Compare commits

...

1 Commits

Author SHA1 Message Date
snaily
b14bb93d8f refactor: 项目结构优化与FastAPI生命周期更新
重构项目目录结构,提高代码组织性和可维护性

将schemas目录重命名为domain,更好地表达领域模型概念
将services目录细分为service/chat、service/image等子目录
将api目录重命名为router,更符合FastAPI惯例
创建utils目录存放通用工具函数
更新FastAPI应用程序生命周期管理

替换已弃用的on_event方法为推荐的lifespan事件处理器
添加应用程序关闭时的日志记录
代码质量改进

抽取常量到constants.py,减少硬编码值
添加helpers.py提供通用工具函数
优化配置管理,使用环境变量和默认值
完善文档字符串,提高代码可读性
2025-03-20 17:13:03 +08:00
31 changed files with 754 additions and 248 deletions

View File

@@ -263,7 +263,7 @@ uvicorn app.main:app --host 0.0.0.0 --port 8000 --reload
"content": "你好"
}
],
"model": "gemini-1.5-flash-002",
"model": "gemini-1.5-flash",
"temperature": 0.7,
"stream": false,
"tools": [],
@@ -276,7 +276,7 @@ uvicorn app.main:app --host 0.0.0.0 --port 8000 --reload
- `messages`: 消息列表,格式与 OpenAI API 相同
- `model`: 模型名称支持所有Gemini模型包括:
- `gemini-1.5-flash-002`: 快速响应模型
- `gemini-1.5-flash`: 快速响应模型
- `gemini-2.0-flash-exp`: 实验性快速响应模型
- `gemini-2.0-flash-exp-search`: 支持搜索功能的实验性模型
- `stream`: 是否开启流式响应,`true` 或 `false`

View File

@@ -1,19 +1,37 @@
from pydantic_settings import BaseSettings
"""
应用程序配置模块
"""
from typing import List
from pydantic_settings import BaseSettings
from app.core.constants import API_VERSION, DEFAULT_MODEL
class Settings(BaseSettings):
"""应用程序配置"""
# API相关配置
API_KEYS: List[str]
ALLOWED_TOKENS: List[str]
BASE_URL: str = "https://generativelanguage.googleapis.com/v1beta"
BASE_URL: str = f"https://generativelanguage.googleapis.com/{API_VERSION}"
AUTH_TOKEN: str = ""
MAX_FAILURES: int = 3
TEST_MODEL: str = DEFAULT_MODEL
# 模型相关配置
SEARCH_MODELS: List[str] = ["gemini-2.0-flash-exp"]
IMAGE_MODELS: List[str] = ["gemini-2.0-flash-exp"]
FILTERED_MODELS: List[str] = ["gemini-1.0-pro-vision-latest", "gemini-pro-vision", "chat-bison-001", "text-bison-001", "embedding-gecko-001"]
FILTERED_MODELS: List[str] = [
"gemini-1.0-pro-vision-latest",
"gemini-pro-vision",
"chat-bison-001",
"text-bison-001",
"embedding-gecko-001"
]
TOOLS_CODE_EXECUTION_ENABLED: bool = False
SHOW_SEARCH_LINK: bool = True
SHOW_THINKING_PROCESS: bool = True
AUTH_TOKEN: str = ""
MAX_FAILURES: int = 3
# 图像生成相关配置
PAID_KEY: str = ""
CREATE_IMAGE_MODEL: str = "imagen-3.0-generate-002"
UPLOAD_PROVIDER: str = "smms"
@@ -21,7 +39,6 @@ class Settings(BaseSettings):
PICGO_API_KEY: str = ""
CLOUDFLARE_IMGBED_URL: str = ""
CLOUDFLARE_IMGBED_AUTH_CODE: str = ""
TEST_MODEL: str = "gemini-1.5-flash"
# 流式输出优化器配置
STREAM_MIN_DELAY: float = 0.016
@@ -29,14 +46,16 @@ class Settings(BaseSettings):
STREAM_SHORT_TEXT_THRESHOLD: int = 10
STREAM_LONG_TEXT_THRESHOLD: int = 50
STREAM_CHUNK_SIZE: int = 5
def __init__(self):
super().__init__()
if not self.AUTH_TOKEN:
self.AUTH_TOKEN = self.ALLOWED_TOKENS[0] if self.ALLOWED_TOKENS else ""
def __init__(self, **kwargs):
super().__init__(**kwargs)
# 设置默认AUTH_TOKEN(如果未提供)
if not self.AUTH_TOKEN and self.ALLOWED_TOKENS:
self.AUTH_TOKEN = self.ALLOWED_TOKENS[0]
class Config:
env_file = ".env"
# 创建全局配置实例
settings = Settings()

71
app/core/application.py Normal file
View File

@@ -0,0 +1,71 @@
"""
应用程序工厂模块负责创建和配置FastAPI应用程序实例
"""
from contextlib import asynccontextmanager
from fastapi import FastAPI
from fastapi.staticfiles import StaticFiles
from app.config.config import settings
from app.logger.logger import get_main_logger
from app.middleware.middleware import setup_middlewares
from app.exception.exceptions import setup_exception_handlers
from app.router.routers import setup_routers
from app.service.key.key_manager import get_key_manager_instance
from app.core.initialization import initialize_app
logger = get_main_logger()
@asynccontextmanager
async def lifespan(app: FastAPI):
"""
应用程序生命周期管理器
Args:
app: FastAPI应用实例
"""
# 启动事件
logger.info("Application starting up...")
try:
# 初始化KeyManager
await get_key_manager_instance(settings.API_KEYS)
logger.info("KeyManager initialized successfully")
except Exception as e:
logger.error(f"Failed to initialize KeyManager: {str(e)}")
raise
yield # 应用程序运行期间
# 关闭事件
logger.info("Application shutting down...")
def create_app() -> FastAPI:
"""
创建并配置FastAPI应用程序实例
Returns:
FastAPI: 配置好的FastAPI应用程序实例
"""
# 初始化应用程序
initialize_app()
# 创建FastAPI应用
app = FastAPI(
title="Gemini Balance API",
description="Gemini API代理服务支持负载均衡和密钥管理",
version="1.0.0",
lifespan=lifespan
)
# 配置静态文件
app.mount("/static", StaticFiles(directory="app/static"), name="static")
# 配置中间件
setup_middlewares(app)
# 配置异常处理器
setup_exception_handlers(app)
# 配置路由
setup_routers(app)
return app

32
app/core/constants.py Normal file
View File

@@ -0,0 +1,32 @@
"""
常量定义模块
"""
# API相关常量
API_VERSION = "v1beta"
DEFAULT_TIMEOUT = 300 # 秒
# 模型相关常量
SUPPORTED_ROLES = ["user", "model", "system"]
DEFAULT_MODEL = "gemini-1.5-flash"
DEFAULT_TEMPERATURE = 0.7
DEFAULT_MAX_TOKENS = 8192
DEFAULT_TOP_P = 0.9
DEFAULT_TOP_K = 40
# 图像生成相关常量
VALID_IMAGE_RATIOS = ["1:1", "3:4", "4:3", "9:16", "16:9"]
# 上传提供商
UPLOAD_PROVIDERS = ["smms", "picgo", "cloudflare_imgbed"]
# 流式输出相关常量
DEFAULT_STREAM_MIN_DELAY = 0.016
DEFAULT_STREAM_MAX_DELAY = 0.024
DEFAULT_STREAM_SHORT_TEXT_THRESHOLD = 10
DEFAULT_STREAM_LONG_TEXT_THRESHOLD = 50
DEFAULT_STREAM_CHUNK_SIZE = 5
# 正则表达式模式
IMAGE_URL_PATTERN = r'!\[(.*?)\]\((.*?)\)'
DATA_URL_PATTERN = r'data:([^;]+);base64,(.+)'

View File

@@ -0,0 +1,39 @@
"""
应用程序初始化模块
"""
import logging
from pathlib import Path
from typing import List
logger = logging.getLogger("initialization")
def ensure_directories_exist(directories: List[str]) -> None:
"""
确保指定的目录存在,如果不存在则创建
Args:
directories: 要确保存在的目录列表
"""
for directory in directories:
try:
Path(directory).mkdir(parents=True, exist_ok=True)
logger.info(f"Ensured directory exists: {directory}")
except Exception as e:
logger.error(f"Failed to create directory {directory}: {str(e)}")
def initialize_app() -> None:
"""
初始化应用程序,确保所需的目录和文件都存在
"""
# 确保必要的目录存在
required_directories = [
"app/static/css",
"app/static/js",
"app/static/icons",
"app/templates",
]
ensure_directories_exist(required_directories)
logger.info("Application initialization completed")

View File

@@ -1,7 +1,7 @@
from fastapi import HTTPException, Header
from typing import Optional
from app.core.logger import get_security_logger
from app.core.config import settings
from app.logger.logger import get_security_logger
from app.config.config import settings
logger = get_security_logger()

View File

@@ -36,5 +36,5 @@ class GeminiRequest(BaseModel):
contents: List[GeminiContent] = []
tools: Optional[List[Dict[str, Any]]] = []
safetySettings: Optional[List[SafetySetting]] = None
generationConfig: Optional[GenerationConfig] = {}
generationConfig: Optional[GenerationConfig] = None
systemInstruction: Optional[SystemInstruction] = None

View File

@@ -1,17 +1,19 @@
from pydantic import BaseModel
from typing import List, Optional, Union
from app.core.constants import DEFAULT_MAX_TOKENS, DEFAULT_MODEL, DEFAULT_TEMPERATURE, DEFAULT_TOP_K, DEFAULT_TOP_P
class ChatRequest(BaseModel):
messages: List[dict]
model: str = "gemini-1.5-flash-002"
temperature: Optional[float] = 0.7
model: str = DEFAULT_MODEL
temperature: Optional[float] = DEFAULT_TEMPERATURE
stream: Optional[bool] = False
tools: Optional[List[dict]] = []
max_tokens: Optional[int] = 8192
max_tokens: Optional[int] = DEFAULT_MAX_TOKENS
top_p: Optional[float] = DEFAULT_TOP_P
top_k: Optional[int] = DEFAULT_TOP_K
stop: Optional[List[str]] = []
top_p: Optional[float] = 0.9
top_k: Optional[int] = 40
class EmbeddingRequest(BaseModel):

133
app/exception/exceptions.py Normal file
View File

@@ -0,0 +1,133 @@
"""
异常处理模块,定义应用程序中使用的自定义异常和异常处理器
"""
from fastapi import Request, FastAPI
from fastapi.responses import JSONResponse
from fastapi.exceptions import RequestValidationError
from starlette.exceptions import HTTPException as StarletteHTTPException
from app.logger.logger import get_main_logger
logger = get_main_logger()
class APIError(Exception):
"""API错误基类"""
def __init__(self, status_code: int, detail: str, error_code: str = None):
self.status_code = status_code
self.detail = detail
self.error_code = error_code or "api_error"
super().__init__(self.detail)
class AuthenticationError(APIError):
"""认证错误"""
def __init__(self, detail: str = "Authentication failed"):
super().__init__(status_code=401, detail=detail, error_code="authentication_error")
class AuthorizationError(APIError):
"""授权错误"""
def __init__(self, detail: str = "Not authorized to access this resource"):
super().__init__(status_code=403, detail=detail, error_code="authorization_error")
class ResourceNotFoundError(APIError):
"""资源未找到错误"""
def __init__(self, detail: str = "Resource not found"):
super().__init__(status_code=404, detail=detail, error_code="resource_not_found")
class ModelNotSupportedError(APIError):
"""模型不支持错误"""
def __init__(self, model: str):
super().__init__(
status_code=400,
detail=f"Model {model} is not supported",
error_code="model_not_supported"
)
class APIKeyError(APIError):
"""API密钥错误"""
def __init__(self, detail: str = "Invalid or expired API key"):
super().__init__(status_code=401, detail=detail, error_code="api_key_error")
class ServiceUnavailableError(APIError):
"""服务不可用错误"""
def __init__(self, detail: str = "Service temporarily unavailable"):
super().__init__(status_code=503, detail=detail, error_code="service_unavailable")
def setup_exception_handlers(app: FastAPI) -> None:
"""
设置应用程序的异常处理器
Args:
app: FastAPI应用程序实例
"""
@app.exception_handler(APIError)
async def api_error_handler(request: Request, exc: APIError):
"""处理API错误"""
logger.error(f"API Error: {exc.detail} (Code: {exc.error_code})")
return JSONResponse(
status_code=exc.status_code,
content={
"error": {
"code": exc.error_code,
"message": exc.detail
}
}
)
@app.exception_handler(StarletteHTTPException)
async def http_exception_handler(request: Request, exc: StarletteHTTPException):
"""处理HTTP异常"""
logger.error(f"HTTP Exception: {exc.detail} (Status: {exc.status_code})")
return JSONResponse(
status_code=exc.status_code,
content={
"error": {
"code": "http_error",
"message": exc.detail
}
}
)
@app.exception_handler(RequestValidationError)
async def validation_exception_handler(request: Request, exc: RequestValidationError):
"""处理请求验证错误"""
error_details = []
for error in exc.errors():
error_details.append({
"loc": error["loc"],
"msg": error["msg"],
"type": error["type"]
})
logger.error(f"Validation Error: {error_details}")
return JSONResponse(
status_code=422,
content={
"error": {
"code": "validation_error",
"message": "Request validation failed",
"details": error_details
}
}
)
@app.exception_handler(Exception)
async def general_exception_handler(request: Request, exc: Exception):
"""处理通用异常"""
logger.exception(f"Unhandled Exception: {str(exc)}")
return JSONResponse(
status_code=500,
content={
"error": {
"code": "internal_server_error",
"message": "An unexpected error occurred"
}
}
)

View File

@@ -6,8 +6,7 @@ from typing import Any, Dict, List, Optional
import requests
import base64
SUPPORTED_ROLES = ["user", "model", "system"]
IMAGE_URL_PATTERN = r'\[(.*?)\]\((.*?)\)'
from app.core.constants import DATA_URL_PATTERN, IMAGE_URL_PATTERN, SUPPORTED_ROLES
class MessageConverter(ABC):
@@ -30,7 +29,7 @@ def _get_mime_type_and_data(base64_string):
# 检查字符串是否以 "data:" 格式开始
if base64_string.startswith('data:'):
# 提取 MIME 类型和数据
pattern = r'data:([^;]+);base64,(.+)'
pattern = DATA_URL_PATTERN
match = re.match(pattern, base64_string)
if match:
mime_type = "image/jpeg" if match.group(1) == "image/jpg" else match.group(1)

View File

@@ -8,8 +8,8 @@ from abc import ABC, abstractmethod
from typing import Dict, Any, List, Optional
import time
import uuid
from app.core.config import settings
from app.core.uploader import ImageUploaderFactory
from app.config.config import settings
from app.utils.uploader import ImageUploaderFactory
class ResponseHandler(ABC):

View File

@@ -2,7 +2,7 @@
from typing import TypeVar, Callable
from functools import wraps
from app.core.logger import get_retry_logger
from app.logger.logger import get_retry_logger
T = TypeVar('T')
logger = get_retry_logger()

View File

@@ -3,8 +3,9 @@
import asyncio
import math
from typing import Any, List, AsyncGenerator, Callable
from app.core.logger import get_openai_logger, get_gemini_logger
from app.core.config import settings
from app.logger.logger import get_openai_logger, get_gemini_logger
from app.config.config import settings
from app.core.constants import DEFAULT_STREAM_CHUNK_SIZE, DEFAULT_STREAM_LONG_TEXT_THRESHOLD, DEFAULT_STREAM_MAX_DELAY, DEFAULT_STREAM_MIN_DELAY, DEFAULT_STREAM_SHORT_TEXT_THRESHOLD
logger_openai = get_openai_logger()
logger_gemini = get_gemini_logger()
@@ -18,11 +19,11 @@ class StreamOptimizer:
def __init__(self,
logger=None,
min_delay: float = 0.016,
max_delay: float = 0.024,
short_text_threshold: int = 10,
long_text_threshold: int = 50,
chunk_size: int = 5):
min_delay: float = DEFAULT_STREAM_MIN_DELAY,
max_delay: float = DEFAULT_STREAM_MAX_DELAY,
short_text_threshold: int = DEFAULT_STREAM_SHORT_TEXT_THRESHOLD,
long_text_threshold: int = DEFAULT_STREAM_LONG_TEXT_THRESHOLD,
chunk_size: int = DEFAULT_STREAM_CHUNK_SIZE):
"""初始化流式输出优化器
参数:

View File

@@ -1,134 +1,16 @@
from fastapi import FastAPI, Request
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import HTMLResponse, RedirectResponse
from fastapi.templating import Jinja2Templates
from fastapi.staticfiles import StaticFiles
from app.core.logger import get_main_logger
from app.core.security import verify_auth_token
from app.services.key_manager import get_key_manager_instance
from app.core.config import settings
from app.api import gemini_routes, openai_routes
"""
应用程序入口模块
"""
import uvicorn
from app.core.application import create_app
from app.logger.logger import get_main_logger
# 创建应用程序实例
app = create_app()
# 配置日志
logger = get_main_logger()
app = FastAPI()
# 配置Jinja2模板
templates = Jinja2Templates(directory="app/templates")
# 配置静态文件
app.mount("/static", StaticFiles(directory="app/static"), name="static")
# 创建 KeyManager 实例
key_manager = None
@app.on_event("startup")
async def startup_event():
global key_manager
logger.info("Application starting up...")
try:
key_manager = await get_key_manager_instance(settings.API_KEYS)
logger.info("KeyManager initialized successfully")
except Exception as e:
logger.error(f"Failed to initialize KeyManager: {str(e)}")
raise
# 添加中间件来处理未经身份验证的请求
@app.middleware("http")
async def auth_middleware(request: Request, call_next):
# 允许 gemini_routes 和 openai_routes 中的端点绕过身份验证
if (request.url.path not in ["/", "/auth"] and
not request.url.path.startswith("/static") and
not request.url.path.startswith("/gemini") and
not request.url.path.startswith("/v1") and
not request.url.path.startswith("/v1beta") and
not request.url.path.startswith("/health") and
not request.url.path.startswith("/hf")):
auth_token = request.cookies.get("auth_token")
if not auth_token or not verify_auth_token(auth_token):
logger.warning(f"Unauthorized access attempt to {request.url.path}")
return RedirectResponse(url="/")
logger.debug("Request authenticated successfully")
response = await call_next(request)
return response
# 添加请求日志中间件
# app.add_middleware(RequestLoggingMiddleware)
# 配置CORS中间件
app.add_middleware(
CORSMiddleware,
allow_origins=["*"], # 生产环境建议配置具体的域名
allow_credentials=True,
allow_methods=["GET", "POST", "PUT", "DELETE", "OPTIONS"], # 明确指定允许的HTTP方法
allow_headers=["*"], # 生产环境建议配置具体的请求头
expose_headers=["*"], # 允许前端访问的响应头
max_age=600, # 预检请求缓存时间(秒)
)
# 包含所有路由
app.include_router(openai_routes.router)
app.include_router(gemini_routes.router)
app.include_router(gemini_routes.router_v1beta)
@app.get("/", response_class=HTMLResponse)
async def auth_page(request: Request):
return templates.TemplateResponse("auth.html", {"request": request})
@app.post("/auth")
async def authenticate(request: Request):
try:
form = await request.form()
auth_token = form.get("auth_token")
if not auth_token:
logger.warning("Authentication attempt with empty token")
return RedirectResponse(url="/", status_code=302)
if verify_auth_token(auth_token):
logger.info("Successful authentication")
response = RedirectResponse(url="/keys", status_code=302)
response.set_cookie(key="auth_token", value=auth_token, httponly=True, max_age=3600)
return response
logger.warning("Failed authentication attempt with invalid token")
return RedirectResponse(url="/", status_code=302)
except Exception as e:
logger.error(f"Authentication error: {str(e)}")
return RedirectResponse(url="/", status_code=302)
@app.get("/keys", response_class=HTMLResponse)
async def keys_page(request: Request):
try:
auth_token = request.cookies.get("auth_token")
if not auth_token or not verify_auth_token(auth_token):
logger.warning("Unauthorized access attempt to keys page")
return RedirectResponse(url="/", status_code=302)
keys_status = await key_manager.get_keys_by_status()
total = len(keys_status["valid_keys"]) + len(keys_status["invalid_keys"])
logger.info(f"Keys status retrieved successfully. Total keys: {total}")
return templates.TemplateResponse("keys_status.html", {
"request": request,
"valid_keys": keys_status["valid_keys"],
"invalid_keys": keys_status["invalid_keys"],
"total": total
})
except Exception as e:
logger.error(f"Error retrieving keys status: {str(e)}")
raise
@app.get("/health")
async def health_check(request: Request):
logger.info("Health check endpoint called")
return {"status": "healthy"}
if __name__ == "__main__":
logger.info("Starting application server...")
uvicorn.run(app, host="0.0.0.0", port=8001)

View File

@@ -0,0 +1,61 @@
"""
中间件配置模块,负责设置和配置应用程序的中间件
"""
from fastapi import FastAPI, Request
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import RedirectResponse
from starlette.middleware.base import BaseHTTPMiddleware
from app.logger.logger import get_main_logger
from app.core.security import verify_auth_token
# from app.middleware.request_logging_middleware import RequestLoggingMiddleware
from app.core.constants import API_VERSION
logger = get_main_logger()
class AuthMiddleware(BaseHTTPMiddleware):
"""
认证中间件,处理未经身份验证的请求
"""
async def dispatch(self, request: Request, call_next):
# 允许特定路径绕过身份验证
if (request.url.path not in ["/", "/auth"] and
not request.url.path.startswith("/static") and
not request.url.path.startswith("/gemini") and
not request.url.path.startswith("/v1") and
not request.url.path.startswith(f"/{API_VERSION}") and
not request.url.path.startswith("/health") and
not request.url.path.startswith("/hf")):
auth_token = request.cookies.get("auth_token")
if not auth_token or not verify_auth_token(auth_token):
logger.warning(f"Unauthorized access attempt to {request.url.path}")
return RedirectResponse(url="/")
logger.debug("Request authenticated successfully")
response = await call_next(request)
return response
def setup_middlewares(app: FastAPI) -> None:
"""
设置应用程序的中间件
Args:
app: FastAPI应用程序实例
"""
# 添加认证中间件
app.add_middleware(AuthMiddleware)
# 添加请求日志中间件(可选,默认注释掉)
# app.add_middleware(RequestLoggingMiddleware)
# 配置CORS中间件
app.add_middleware(
CORSMiddleware,
allow_origins=["*"], # 生产环境建议配置具体的域名
allow_credentials=True,
allow_methods=["GET", "POST", "PUT", "DELETE", "OPTIONS"], # 明确指定允许的HTTP方法
allow_headers=["*"], # 生产环境建议配置具体的请求头
expose_headers=["*"], # 允许前端访问的响应头
max_age=600, # 预检请求缓存时间(秒)
)

View File

@@ -1,7 +1,7 @@
from fastapi import Request
from starlette.middleware.base import BaseHTTPMiddleware
import json
from app.core.logger import get_request_logger
from app.logger.logger import get_request_logger
logger = get_request_logger()

View File

@@ -1,73 +1,80 @@
from fastapi import APIRouter, Depends, HTTPException
from fastapi.responses import StreamingResponse, JSONResponse
from copy import deepcopy
from app.core.config import settings
from app.core.logger import get_gemini_logger
from app.config.config import settings
from app.logger.logger import get_gemini_logger
from app.core.security import SecurityService
from app.schemas.gemini_models import GeminiContent, GeminiRequest
from app.services.gemini_chat_service import GeminiChatService
from app.services.key_manager import KeyManager, get_key_manager_instance
from app.services.model_service import ModelService
from app.services.chat.retry_handler import RetryHandler
from app.domain.gemini_models import GeminiContent, GeminiRequest
from app.service.chat.gemini_chat_service import GeminiChatService
from app.service.key.key_manager import KeyManager, get_key_manager_instance
from app.service.model.model_service import ModelService
from app.handler.retry_handler import RetryHandler
from app.core.constants import API_VERSION
router = APIRouter(prefix="/gemini/v1beta")
router_v1beta = APIRouter(prefix="/v1beta")
# 路由设置
router = APIRouter(prefix=f"/gemini/{API_VERSION}")
router_v1beta = APIRouter(prefix=f"/{API_VERSION}")
logger = get_gemini_logger()
# 初始化服务
security_service = SecurityService(settings.ALLOWED_TOKENS, settings.AUTH_TOKEN)
model_service = ModelService(settings.SEARCH_MODELS, settings.IMAGE_MODELS)
async def get_key_manager():
"""获取密钥管理器实例"""
return await get_key_manager_instance()
async def get_next_working_key_wrapper(key_manager: KeyManager = Depends(get_key_manager)):
return await key_manager.get_next_working_key()
model_service = ModelService(settings.SEARCH_MODELS,settings.IMAGE_MODELS)
async def get_next_working_key(key_manager: KeyManager = Depends(get_key_manager)):
"""获取下一个可用的API密钥"""
return await key_manager.get_next_working_key()
@router.get("/models")
@router_v1beta.get("/models")
async def list_models(_=Depends(security_service.verify_key),
key_manager: KeyManager = Depends(get_key_manager)):
async def list_models(
_=Depends(security_service.verify_key),
key_manager: KeyManager = Depends(get_key_manager)
):
"""获取可用的Gemini模型列表"""
logger.info("-" * 50 + "list_gemini_models" + "-" * 50)
logger.info("Handling Gemini models list request")
api_key = await key_manager.get_next_working_key()
logger.info(f"Using API key: {api_key}")
models_json = model_service.get_gemini_models(api_key)
# 模型名称以及对应的详细信息
model_mapping = {x.get("name", "").split("/", maxsplit=1)[1]: x for x in models_json["models"]}
# 添加搜索模型
if model_service.search_models:
for name in model_service.search_models:
model = model_mapping.get(name, None)
model = model_mapping.get(name)
if not model:
continue
item = deepcopy(model)
item["name"] = f"models/{name}-search"
display_name = f'{item.get("displayName")} For Search'
item["displayName"] = display_name
item["description"] = display_name
models_json["models"].append(item)
# 添加图像生成模型
if model_service.image_models:
for name in model_service.image_models:
model = model_mapping.get(name, None)
model = model_mapping.get(name)
if not model:
continue
item = deepcopy(model)
item["name"] = f"models/{name}-image"
display_name = f'{item.get("displayName")} For Image'
item["displayName"] = display_name
item["description"] = display_name
models_json["models"].append(item)
return models_json
@@ -77,30 +84,29 @@ async def list_models(_=Depends(security_service.verify_key),
@router_v1beta.post("/models/{model_name}:generateContent")
@RetryHandler(max_retries=3, key_arg="api_key")
async def generate_content(
model_name: str,
request: GeminiRequest,
_=Depends(security_service.verify_goog_api_key),
api_key: str = Depends(get_next_working_key_wrapper),
key_manager: KeyManager = Depends(get_key_manager)
model_name: str,
request: GeminiRequest,
_=Depends(security_service.verify_goog_api_key),
api_key: str = Depends(get_next_working_key),
key_manager: KeyManager = Depends(get_key_manager)
):
chat_service = GeminiChatService(settings.BASE_URL, key_manager)
"""非流式生成内容"""
logger.info("-" * 50 + "gemini_generate_content" + "-" * 50)
logger.info(f"Handling Gemini content generation request for model: {model_name}")
logger.info(f"Request: \n{request.model_dump_json(indent=2)}")
logger.info(f"Using API key: {api_key}")
if not model_service.check_model_support(model_name):
raise HTTPException(status_code=400, detail=f"Model {model_name} is not supported")
try:
chat_service = GeminiChatService(settings.BASE_URL, key_manager)
response = await chat_service.generate_content(
model=model_name,
request=request,
api_key=api_key
)
return response
except Exception as e:
logger.error(f"Chat completion failed after retries: {str(e)}")
raise HTTPException(status_code=500, detail="Chat completion failed") from e
@@ -110,45 +116,46 @@ async def generate_content(
@router_v1beta.post("/models/{model_name}:streamGenerateContent")
@RetryHandler(max_retries=3, key_arg="api_key")
async def stream_generate_content(
model_name: str,
request: GeminiRequest,
_=Depends(security_service.verify_goog_api_key),
api_key: str = Depends(get_next_working_key_wrapper),
key_manager: KeyManager = Depends(get_key_manager)
model_name: str,
request: GeminiRequest,
_=Depends(security_service.verify_goog_api_key),
api_key: str = Depends(get_next_working_key),
key_manager: KeyManager = Depends(get_key_manager)
):
chat_service = GeminiChatService(settings.BASE_URL, key_manager)
"""流式生成内容"""
logger.info("-" * 50 + "gemini_stream_generate_content" + "-" * 50)
logger.info(f"Handling Gemini streaming content generation for model: {model_name}")
logger.info(f"Request: \n{request.model_dump_json(indent=2)}")
logger.info(f"Using API key: {api_key}")
if not model_service.check_model_support(model_name):
raise HTTPException(status_code=400, detail=f"Model {model_name} is not supported")
try:
chat_service = GeminiChatService(settings.BASE_URL, key_manager)
response_stream = chat_service.stream_generate_content(
model=model_name,
request=request,
api_key=api_key
)
return StreamingResponse(response_stream, media_type="text/event-stream")
except Exception as e:
logger.error(f"Streaming request failed: {str(e)}")
raise HTTPException(status_code=500, detail="Streaming request failed") from e
@router.post("/verify-key/{api_key}")
async def verify_key(api_key: str):
key_manager = await get_key_manager()
chat_service = GeminiChatService(settings.BASE_URL, key_manager)
"""验证Gemini API密钥的有效性"""
logger.info("-" * 50 + "verify_gemini_key" + "-" * 50)
logger.info("Verifying API key validity")
try:
key_manager = await get_key_manager()
chat_service = GeminiChatService(settings.BASE_URL, key_manager)
# 使用generate_content接口测试key的有效性
gemini_requset = GeminiRequest(
gemini_request = GeminiRequest(
contents=[
GeminiContent(
role="user",
@@ -156,10 +163,16 @@ async def verify_key(api_key: str):
)
]
)
response = await chat_service.generate_content(settings.TEST_MODEL,gemini_requset, api_key)
response = await chat_service.generate_content(
settings.TEST_MODEL,
gemini_request,
api_key
)
if response:
return JSONResponse({"status": "valid"})
return JSONResponse({"status": "invalid"})
except Exception as e:
logger.error(f"Key verification failed: {str(e)}")
return JSONResponse({"status": "invalid", "error": str(e)})
return JSONResponse({"status": "invalid", "error": str(e)})

View File

@@ -1,16 +1,16 @@
from fastapi import HTTPException, APIRouter, Depends
from fastapi.responses import StreamingResponse
from app.core.config import settings
from app.core.logger import get_openai_logger
from app.config.config import settings
from app.logger.logger import get_openai_logger
from app.core.security import SecurityService
from app.schemas.openai_models import ChatRequest, EmbeddingRequest, ImageGenerationRequest
from app.services.chat.retry_handler import RetryHandler
from app.services.embedding_service import EmbeddingService
from app.services.image_create_service import ImageCreateService
from app.services.key_manager import KeyManager, get_key_manager_instance
from app.services.model_service import ModelService
from app.services.openai_chat_service import OpenAIChatService
from app.domain.openai_models import ChatRequest, EmbeddingRequest, ImageGenerationRequest
from app.handler.retry_handler import RetryHandler
from app.service.embedding.embedding_service import EmbeddingService
from app.service.image.image_create_service import ImageCreateService
from app.service.key.key_manager import KeyManager, get_key_manager_instance
from app.service.model.model_service import ModelService
from app.service.chat.openai_chat_service import OpenAIChatService
router = APIRouter()
logger = get_openai_logger()

103
app/router/routers.py Normal file
View File

@@ -0,0 +1,103 @@
"""
路由配置模块,负责设置和配置应用程序的路由
"""
from fastapi import FastAPI, Request
from fastapi.responses import HTMLResponse, RedirectResponse
from fastapi.templating import Jinja2Templates
from app.logger.logger import get_main_logger
from app.core.security import verify_auth_token
from app.router import gemini_routes, openai_routes
from app.service.key.key_manager import get_key_manager_instance
logger = get_main_logger()
# 配置Jinja2模板
templates = Jinja2Templates(directory="app/templates")
def setup_routers(app: FastAPI) -> None:
"""
设置应用程序的路由
Args:
app: FastAPI应用程序实例
"""
# 包含API路由
app.include_router(openai_routes.router)
app.include_router(gemini_routes.router)
app.include_router(gemini_routes.router_v1beta)
# 添加页面路由
setup_page_routes(app)
# 添加健康检查路由
setup_health_routes(app)
def setup_page_routes(app: FastAPI) -> None:
"""
设置页面相关的路由
Args:
app: FastAPI应用程序实例
"""
@app.get("/", response_class=HTMLResponse)
async def auth_page(request: Request):
"""认证页面"""
return templates.TemplateResponse("auth.html", {"request": request})
@app.post("/auth")
async def authenticate(request: Request):
"""处理认证请求"""
try:
form = await request.form()
auth_token = form.get("auth_token")
if not auth_token:
logger.warning("Authentication attempt with empty token")
return RedirectResponse(url="/", status_code=302)
if verify_auth_token(auth_token):
logger.info("Successful authentication")
response = RedirectResponse(url="/keys", status_code=302)
response.set_cookie(key="auth_token", value=auth_token, httponly=True, max_age=3600)
return response
logger.warning("Failed authentication attempt with invalid token")
return RedirectResponse(url="/", status_code=302)
except Exception as e:
logger.error(f"Authentication error: {str(e)}")
return RedirectResponse(url="/", status_code=302)
@app.get("/keys", response_class=HTMLResponse)
async def keys_page(request: Request):
"""密钥管理页面"""
try:
auth_token = request.cookies.get("auth_token")
if not auth_token or not verify_auth_token(auth_token):
logger.warning("Unauthorized access attempt to keys page")
return RedirectResponse(url="/", status_code=302)
key_manager = await get_key_manager_instance()
keys_status = await key_manager.get_keys_by_status()
total = len(keys_status["valid_keys"]) + len(keys_status["invalid_keys"])
logger.info(f"Keys status retrieved successfully. Total keys: {total}")
return templates.TemplateResponse("keys_status.html", {
"request": request,
"valid_keys": keys_status["valid_keys"],
"invalid_keys": keys_status["invalid_keys"],
"total": total
})
except Exception as e:
logger.error(f"Error retrieving keys status: {str(e)}")
raise
def setup_health_routes(app: FastAPI) -> None:
"""
设置健康检查相关的路由
Args:
app: FastAPI应用程序实例
"""
@app.get("/health")
async def health_check(request: Request):
"""健康检查端点"""
logger.info("Health check endpoint called")
return {"status": "healthy"}

View File

@@ -2,13 +2,13 @@
import json
from typing import Dict, Any, AsyncGenerator, List
from app.core.logger import get_gemini_logger
from app.services.chat.api_client import GeminiApiClient
from app.services.chat.stream_optimizer import gemini_optimizer
from app.schemas.gemini_models import GeminiRequest
from app.core.config import settings
from app.services.chat.response_handler import GeminiResponseHandler
from app.services.key_manager import KeyManager
from app.logger.logger import get_gemini_logger
from app.service.client.api_client import GeminiApiClient
from app.handler.stream_optimizer import gemini_optimizer
from app.domain.gemini_models import GeminiRequest
from app.config.config import settings
from app.handler.response_handler import GeminiResponseHandler
from app.service.key.key_manager import KeyManager
logger = get_gemini_logger()

View File

@@ -3,15 +3,15 @@
from copy import deepcopy
import json
from typing import Dict, Any, AsyncGenerator, List, Optional, Union
from app.core.logger import get_openai_logger
from app.services.chat.message_converter import OpenAIMessageConverter
from app.services.chat.response_handler import OpenAIResponseHandler
from app.services.chat.api_client import GeminiApiClient
from app.services.chat.stream_optimizer import openai_optimizer
from app.schemas.openai_models import ChatRequest, ImageGenerationRequest
from app.core.config import settings
from app.services.image_create_service import ImageCreateService
from app.services.key_manager import KeyManager
from app.logger.logger import get_openai_logger
from app.handler.message_converter import OpenAIMessageConverter
from app.handler.response_handler import OpenAIResponseHandler
from app.service.client.api_client import GeminiApiClient
from app.handler.stream_optimizer import openai_optimizer
from app.domain.openai_models import ChatRequest, ImageGenerationRequest
from app.config.config import settings
from app.service.image.image_create_service import ImageCreateService
from app.service.key.key_manager import KeyManager
logger = get_openai_logger()

View File

@@ -4,6 +4,8 @@ from typing import Dict, Any, AsyncGenerator
import httpx
from abc import ABC, abstractmethod
from app.core.constants import DEFAULT_TIMEOUT
class ApiClient(ABC):
"""API客户端基类"""
@@ -20,7 +22,7 @@ class ApiClient(ABC):
class GeminiApiClient(ApiClient):
"""Gemini API客户端"""
def __init__(self, base_url: str, timeout: int = 300):
def __init__(self, base_url: str, timeout: int = DEFAULT_TIMEOUT):
self.base_url = base_url
self.timeout = timeout

View File

@@ -3,7 +3,7 @@ from typing import Union, List
import openai
from openai.types import CreateEmbeddingResponse
from app.core.logger import get_embeddings_logger
from app.logger.logger import get_embeddings_logger
logger = get_embeddings_logger()

View File

@@ -5,10 +5,11 @@ from google import genai
from google.genai import types
import base64
from app.core.config import settings
from app.core.logger import get_image_create_logger
from app.core.uploader import ImageUploaderFactory
from app.schemas.openai_models import ImageGenerationRequest
from app.config.config import settings
from app.logger.logger import get_image_create_logger
from app.utils.uploader import ImageUploaderFactory
from app.domain.openai_models import ImageGenerationRequest
from app.core.constants import VALID_IMAGE_RATIOS
logger = get_image_create_logger()
@@ -43,10 +44,9 @@ class ImageCreateService:
ratio_match = re.search(r'{ratio:(\d+:\d+)}', prompt)
if ratio_match:
aspect_ratio = ratio_match.group(1)
valid_ratios = ["1:1", "3:4", "4:3", "9:16", "16:9"]
if aspect_ratio not in valid_ratios:
if aspect_ratio not in VALID_IMAGE_RATIOS:
raise ValueError(
f"Invalid ratio: {aspect_ratio}. Must be one of: {', '.join(valid_ratios)}"
f"Invalid ratio: {aspect_ratio}. Must be one of: {', '.join(VALID_IMAGE_RATIOS)}"
)
prompt = prompt.replace(ratio_match.group(0), '').strip()

View File

@@ -1,8 +1,8 @@
import asyncio
from itertools import cycle
from typing import Dict
from app.core.logger import get_key_manager_logger
from app.core.config import settings
from app.logger.logger import get_key_manager_logger
from app.config.config import settings
logger = get_key_manager_logger()

View File

@@ -1,8 +1,8 @@
import requests
from datetime import datetime, timezone
from typing import Optional, Dict, Any
from app.core.logger import get_model_logger
from app.core.config import settings
from app.logger.logger import get_model_logger
from app.config.config import settings
logger = get_model_logger()
@@ -10,7 +10,7 @@ class ModelService:
def __init__(self, search_models: list, image_models: list):
self.search_models = search_models
self.image_models = image_models
self.base_url = "https://generativelanguage.googleapis.com/v1beta"
self.base_url = settings.BASE_URL
self.filtered_models = settings.FILTERED_MODELS
def get_gemini_models(self, api_key: str) -> Optional[Dict[str, Any]]:

3
app/utils/__init__.py Normal file
View File

@@ -0,0 +1,3 @@
"""
工具包初始化模块
"""

146
app/utils/helpers.py Normal file
View File

@@ -0,0 +1,146 @@
"""
通用工具函数模块
"""
import json
import re
import base64
import requests
from typing import Dict, Any, List, Optional, Tuple
from app.core.constants import DATA_URL_PATTERN, IMAGE_URL_PATTERN, VALID_IMAGE_RATIOS
def extract_mime_type_and_data(base64_string: str) -> Tuple[Optional[str], str]:
"""
从 base64 字符串中提取 MIME 类型和数据
Args:
base64_string: 可能包含 MIME 类型信息的 base64 字符串
Returns:
tuple: (mime_type, encoded_data)
"""
# 检查字符串是否以 "data:" 格式开始
if base64_string.startswith('data:'):
# 提取 MIME 类型和数据
pattern = DATA_URL_PATTERN
match = re.match(pattern, base64_string)
if match:
mime_type = "image/jpeg" if match.group(1) == "image/jpg" else match.group(1)
encoded_data = match.group(2)
return mime_type, encoded_data
# 如果不是预期格式,假定它只是数据部分
return None, base64_string
def convert_image_to_base64(url: str) -> str:
"""
将图片URL转换为base64编码
Args:
url: 图片URL
Returns:
str: base64编码的图片数据
Raises:
Exception: 如果获取图片失败
"""
response = requests.get(url)
if response.status_code == 200:
# 将图片内容转换为base64
img_data = base64.b64encode(response.content).decode('utf-8')
return img_data
else:
raise Exception(f"Failed to fetch image: {response.status_code}")
def format_json_response(data: Dict[str, Any], indent: int = 2) -> str:
"""
格式化JSON响应
Args:
data: 要格式化的数据
indent: 缩进空格数
Returns:
str: 格式化后的JSON字符串
"""
return json.dumps(data, indent=indent, ensure_ascii=False)
def parse_prompt_parameters(prompt: str, default_ratio: str = "1:1") -> Tuple[str, int, str]:
"""
从prompt中解析参数
支持的格式:
- {n:数量} 例如: {n:2} 生成2张图片
- {ratio:比例} 例如: {ratio:16:9} 使用16:9比例
Args:
prompt: 提示文本
default_ratio: 默认比例
Returns:
tuple: (清理后的提示文本, 图片数量, 比例)
"""
# 默认值
n = 1
aspect_ratio = default_ratio
# 解析n参数
n_match = re.search(r'{n:(\d+)}', prompt)
if n_match:
n = int(n_match.group(1))
if n < 1 or n > 4:
raise ValueError(f"Invalid n value: {n}. Must be between 1 and 4.")
prompt = prompt.replace(n_match.group(0), '').strip()
# 解析ratio参数
ratio_match = re.search(r'{ratio:(\d+:\d+)}', prompt)
if ratio_match:
aspect_ratio = ratio_match.group(1)
if aspect_ratio not in VALID_IMAGE_RATIOS:
raise ValueError(
f"Invalid ratio: {aspect_ratio}. Must be one of: {', '.join(VALID_IMAGE_RATIOS)}"
)
prompt = prompt.replace(ratio_match.group(0), '').strip()
return prompt, n, aspect_ratio
def extract_image_urls_from_markdown(text: str) -> List[str]:
"""
从Markdown文本中提取图片URL
Args:
text: Markdown文本
Returns:
List[str]: 图片URL列表
"""
pattern = IMAGE_URL_PATTERN
matches = re.findall(pattern, text)
return [match[1] for match in matches]
def is_valid_api_key(key: str) -> bool:
"""
检查API密钥格式是否有效
Args:
key: API密钥
Returns:
bool: 如果密钥格式有效则返回True
"""
# 检查Gemini API密钥格式
if key.startswith('AIza'):
return len(key) >= 30
# 检查OpenAI API密钥格式
if key.startswith('sk-'):
return len(key) >= 30
return False

View File

@@ -1,5 +1,5 @@
import requests
from app.schemas.image_models import ImageMetadata, ImageUploader, UploadResponse
from app.domain.image_models import ImageMetadata, ImageUploader, UploadResponse
from enum import Enum
from typing import Optional, Any