From b14bb93d8f4aa6e883b97e5d3b547855b38eb18c Mon Sep 17 00:00:00 2001 From: snaily Date: Thu, 20 Mar 2025 17:13:03 +0800 Subject: [PATCH] =?UTF-8?q?refactor:=20=E9=A1=B9=E7=9B=AE=E7=BB=93?= =?UTF-8?q?=E6=9E=84=E4=BC=98=E5=8C=96=E4=B8=8EFastAPI=E7=94=9F=E5=91=BD?= =?UTF-8?q?=E5=91=A8=E6=9C=9F=E6=9B=B4=E6=96=B0?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit 重构项目目录结构,提高代码组织性和可维护性 将schemas目录重命名为domain,更好地表达领域模型概念 将services目录细分为service/chat、service/image等子目录 将api目录重命名为router,更符合FastAPI惯例 创建utils目录存放通用工具函数 更新FastAPI应用程序生命周期管理 替换已弃用的on_event方法为推荐的lifespan事件处理器 添加应用程序关闭时的日志记录 代码质量改进 抽取常量到constants.py,减少硬编码值 添加helpers.py提供通用工具函数 优化配置管理,使用环境变量和默认值 完善文档字符串,提高代码可读性 --- README.md | 4 +- app/{core => config}/config.py | 43 ++++-- app/core/application.py | 71 +++++++++ app/core/constants.py | 32 ++++ app/core/initialization.py | 39 +++++ app/core/security.py | 4 +- app/{schemas => domain}/gemini_models.py | 2 +- app/{schemas => domain}/image_models.py | 0 app/{schemas => domain}/openai_models.py | 12 +- app/exception/exceptions.py | 133 ++++++++++++++++ .../chat => handler}/message_converter.py | 5 +- .../chat => handler}/response_handler.py | 4 +- .../chat => handler}/retry_handler.py | 2 +- .../chat => handler}/stream_optimizer.py | 15 +- app/{core => logger}/logger.py | 0 app/main.py | 132 +--------------- app/middleware/middleware.py | 61 ++++++++ app/middleware/request_logging_middleware.py | 2 +- app/{api => router}/gemini_routes.py | 107 +++++++------ app/{api => router}/openai_routes.py | 18 +-- app/router/routers.py | 103 ++++++++++++ .../chat}/gemini_chat_service.py | 14 +- .../chat}/openai_chat_service.py | 18 +-- .../chat => service/client}/api_client.py | 4 +- .../embedding}/embedding_service.py | 2 +- .../image}/image_create_service.py | 14 +- app/{services => service/key}/key_manager.py | 4 +- .../model}/model_service.py | 6 +- app/utils/__init__.py | 3 + app/utils/helpers.py | 146 ++++++++++++++++++ app/{core => utils}/uploader.py | 2 +- 31 files changed, 754 insertions(+), 248 deletions(-) rename app/{core => config}/config.py (55%) create mode 100644 app/core/application.py create mode 100644 app/core/constants.py create mode 100644 app/core/initialization.py rename app/{schemas => domain}/gemini_models.py (96%) rename app/{schemas => domain}/image_models.py (100%) rename app/{schemas => domain}/openai_models.py (65%) create mode 100644 app/exception/exceptions.py rename app/{services/chat => handler}/message_converter.py (97%) rename app/{services/chat => handler}/response_handler.py (99%) rename app/{services/chat => handler}/retry_handler.py (96%) rename app/{services/chat => handler}/stream_optimizer.py (87%) rename app/{core => logger}/logger.py (100%) create mode 100644 app/middleware/middleware.py rename app/{api => router}/gemini_routes.py (68%) rename app/{api => router}/openai_routes.py (90%) create mode 100644 app/router/routers.py rename app/{services => service/chat}/gemini_chat_service.py (94%) rename app/{services => service/chat}/openai_chat_service.py (95%) rename app/{services/chat => service/client}/api_client.py (95%) rename app/{services => service/embedding}/embedding_service.py (93%) rename app/{services => service/image}/image_create_service.py (94%) rename app/{services => service/key}/key_manager.py (97%) rename app/{services => service/model}/model_service.py (95%) create mode 100644 app/utils/__init__.py create mode 100644 app/utils/helpers.py rename app/{core => utils}/uploader.py (99%) diff --git a/README.md b/README.md index ecd2095..da1b45a 100644 --- a/README.md +++ b/README.md @@ -263,7 +263,7 @@ uvicorn app.main:app --host 0.0.0.0 --port 8000 --reload "content": "你好" } ], - "model": "gemini-1.5-flash-002", + "model": "gemini-1.5-flash", "temperature": 0.7, "stream": false, "tools": [], @@ -276,7 +276,7 @@ uvicorn app.main:app --host 0.0.0.0 --port 8000 --reload - `messages`: 消息列表,格式与 OpenAI API 相同 - `model`: 模型名称,支持所有Gemini模型,包括: - - `gemini-1.5-flash-002`: 快速响应模型 + - `gemini-1.5-flash`: 快速响应模型 - `gemini-2.0-flash-exp`: 实验性快速响应模型 - `gemini-2.0-flash-exp-search`: 支持搜索功能的实验性模型 - `stream`: 是否开启流式响应,`true` 或 `false` diff --git a/app/core/config.py b/app/config/config.py similarity index 55% rename from app/core/config.py rename to app/config/config.py index bfa1835..d12f6b8 100644 --- a/app/core/config.py +++ b/app/config/config.py @@ -1,19 +1,37 @@ -from pydantic_settings import BaseSettings +""" +应用程序配置模块 +""" from typing import List +from pydantic_settings import BaseSettings + +from app.core.constants import API_VERSION, DEFAULT_MODEL class Settings(BaseSettings): + """应用程序配置""" + # API相关配置 API_KEYS: List[str] ALLOWED_TOKENS: List[str] - BASE_URL: str = "https://generativelanguage.googleapis.com/v1beta" + BASE_URL: str = f"https://generativelanguage.googleapis.com/{API_VERSION}" + AUTH_TOKEN: str = "" + MAX_FAILURES: int = 3 + TEST_MODEL: str = DEFAULT_MODEL + + # 模型相关配置 SEARCH_MODELS: List[str] = ["gemini-2.0-flash-exp"] IMAGE_MODELS: List[str] = ["gemini-2.0-flash-exp"] - FILTERED_MODELS: List[str] = ["gemini-1.0-pro-vision-latest", "gemini-pro-vision", "chat-bison-001", "text-bison-001", "embedding-gecko-001"] + FILTERED_MODELS: List[str] = [ + "gemini-1.0-pro-vision-latest", + "gemini-pro-vision", + "chat-bison-001", + "text-bison-001", + "embedding-gecko-001" + ] TOOLS_CODE_EXECUTION_ENABLED: bool = False SHOW_SEARCH_LINK: bool = True SHOW_THINKING_PROCESS: bool = True - AUTH_TOKEN: str = "" - MAX_FAILURES: int = 3 + + # 图像生成相关配置 PAID_KEY: str = "" CREATE_IMAGE_MODEL: str = "imagen-3.0-generate-002" UPLOAD_PROVIDER: str = "smms" @@ -21,7 +39,6 @@ class Settings(BaseSettings): PICGO_API_KEY: str = "" CLOUDFLARE_IMGBED_URL: str = "" CLOUDFLARE_IMGBED_AUTH_CODE: str = "" - TEST_MODEL: str = "gemini-1.5-flash" # 流式输出优化器配置 STREAM_MIN_DELAY: float = 0.016 @@ -29,14 +46,16 @@ class Settings(BaseSettings): STREAM_SHORT_TEXT_THRESHOLD: int = 10 STREAM_LONG_TEXT_THRESHOLD: int = 50 STREAM_CHUNK_SIZE: int = 5 - - def __init__(self): - super().__init__() - if not self.AUTH_TOKEN: - self.AUTH_TOKEN = self.ALLOWED_TOKENS[0] if self.ALLOWED_TOKENS else "" - + + def __init__(self, **kwargs): + super().__init__(**kwargs) + # 设置默认AUTH_TOKEN(如果未提供) + if not self.AUTH_TOKEN and self.ALLOWED_TOKENS: + self.AUTH_TOKEN = self.ALLOWED_TOKENS[0] + class Config: env_file = ".env" +# 创建全局配置实例 settings = Settings() diff --git a/app/core/application.py b/app/core/application.py new file mode 100644 index 0000000..1d9f825 --- /dev/null +++ b/app/core/application.py @@ -0,0 +1,71 @@ +""" +应用程序工厂模块,负责创建和配置FastAPI应用程序实例 +""" +from contextlib import asynccontextmanager +from fastapi import FastAPI +from fastapi.staticfiles import StaticFiles + +from app.config.config import settings +from app.logger.logger import get_main_logger +from app.middleware.middleware import setup_middlewares +from app.exception.exceptions import setup_exception_handlers +from app.router.routers import setup_routers +from app.service.key.key_manager import get_key_manager_instance +from app.core.initialization import initialize_app + +logger = get_main_logger() + +@asynccontextmanager +async def lifespan(app: FastAPI): + """ + 应用程序生命周期管理器 + + Args: + app: FastAPI应用实例 + """ + # 启动事件 + logger.info("Application starting up...") + try: + # 初始化KeyManager + await get_key_manager_instance(settings.API_KEYS) + logger.info("KeyManager initialized successfully") + except Exception as e: + logger.error(f"Failed to initialize KeyManager: {str(e)}") + raise + + yield # 应用程序运行期间 + + # 关闭事件 + logger.info("Application shutting down...") + +def create_app() -> FastAPI: + """ + 创建并配置FastAPI应用程序实例 + + Returns: + FastAPI: 配置好的FastAPI应用程序实例 + """ + # 初始化应用程序 + initialize_app() + + # 创建FastAPI应用 + app = FastAPI( + title="Gemini Balance API", + description="Gemini API代理服务,支持负载均衡和密钥管理", + version="1.0.0", + lifespan=lifespan + ) + + # 配置静态文件 + app.mount("/static", StaticFiles(directory="app/static"), name="static") + + # 配置中间件 + setup_middlewares(app) + + # 配置异常处理器 + setup_exception_handlers(app) + + # 配置路由 + setup_routers(app) + + return app diff --git a/app/core/constants.py b/app/core/constants.py new file mode 100644 index 0000000..0701048 --- /dev/null +++ b/app/core/constants.py @@ -0,0 +1,32 @@ +""" +常量定义模块 +""" + +# API相关常量 +API_VERSION = "v1beta" +DEFAULT_TIMEOUT = 300 # 秒 + +# 模型相关常量 +SUPPORTED_ROLES = ["user", "model", "system"] +DEFAULT_MODEL = "gemini-1.5-flash" +DEFAULT_TEMPERATURE = 0.7 +DEFAULT_MAX_TOKENS = 8192 +DEFAULT_TOP_P = 0.9 +DEFAULT_TOP_K = 40 + +# 图像生成相关常量 +VALID_IMAGE_RATIOS = ["1:1", "3:4", "4:3", "9:16", "16:9"] + +# 上传提供商 +UPLOAD_PROVIDERS = ["smms", "picgo", "cloudflare_imgbed"] + +# 流式输出相关常量 +DEFAULT_STREAM_MIN_DELAY = 0.016 +DEFAULT_STREAM_MAX_DELAY = 0.024 +DEFAULT_STREAM_SHORT_TEXT_THRESHOLD = 10 +DEFAULT_STREAM_LONG_TEXT_THRESHOLD = 50 +DEFAULT_STREAM_CHUNK_SIZE = 5 + +# 正则表达式模式 +IMAGE_URL_PATTERN = r'!\[(.*?)\]\((.*?)\)' +DATA_URL_PATTERN = r'data:([^;]+);base64,(.+)' diff --git a/app/core/initialization.py b/app/core/initialization.py new file mode 100644 index 0000000..40ce4b9 --- /dev/null +++ b/app/core/initialization.py @@ -0,0 +1,39 @@ +""" +应用程序初始化模块 +""" +import logging +from pathlib import Path +from typing import List + +logger = logging.getLogger("initialization") + + +def ensure_directories_exist(directories: List[str]) -> None: + """ + 确保指定的目录存在,如果不存在则创建 + + Args: + directories: 要确保存在的目录列表 + """ + for directory in directories: + try: + Path(directory).mkdir(parents=True, exist_ok=True) + logger.info(f"Ensured directory exists: {directory}") + except Exception as e: + logger.error(f"Failed to create directory {directory}: {str(e)}") + + +def initialize_app() -> None: + """ + 初始化应用程序,确保所需的目录和文件都存在 + """ + # 确保必要的目录存在 + required_directories = [ + "app/static/css", + "app/static/js", + "app/static/icons", + "app/templates", + ] + + ensure_directories_exist(required_directories) + logger.info("Application initialization completed") diff --git a/app/core/security.py b/app/core/security.py index e4c78e9..2684c23 100644 --- a/app/core/security.py +++ b/app/core/security.py @@ -1,7 +1,7 @@ from fastapi import HTTPException, Header from typing import Optional -from app.core.logger import get_security_logger -from app.core.config import settings +from app.logger.logger import get_security_logger +from app.config.config import settings logger = get_security_logger() diff --git a/app/schemas/gemini_models.py b/app/domain/gemini_models.py similarity index 96% rename from app/schemas/gemini_models.py rename to app/domain/gemini_models.py index 26515f1..e2fdb19 100644 --- a/app/schemas/gemini_models.py +++ b/app/domain/gemini_models.py @@ -36,5 +36,5 @@ class GeminiRequest(BaseModel): contents: List[GeminiContent] = [] tools: Optional[List[Dict[str, Any]]] = [] safetySettings: Optional[List[SafetySetting]] = None - generationConfig: Optional[GenerationConfig] = {} + generationConfig: Optional[GenerationConfig] = None systemInstruction: Optional[SystemInstruction] = None diff --git a/app/schemas/image_models.py b/app/domain/image_models.py similarity index 100% rename from app/schemas/image_models.py rename to app/domain/image_models.py diff --git a/app/schemas/openai_models.py b/app/domain/openai_models.py similarity index 65% rename from app/schemas/openai_models.py rename to app/domain/openai_models.py index 52b143e..69a28b0 100644 --- a/app/schemas/openai_models.py +++ b/app/domain/openai_models.py @@ -1,17 +1,19 @@ from pydantic import BaseModel from typing import List, Optional, Union +from app.core.constants import DEFAULT_MAX_TOKENS, DEFAULT_MODEL, DEFAULT_TEMPERATURE, DEFAULT_TOP_K, DEFAULT_TOP_P + class ChatRequest(BaseModel): messages: List[dict] - model: str = "gemini-1.5-flash-002" - temperature: Optional[float] = 0.7 + model: str = DEFAULT_MODEL + temperature: Optional[float] = DEFAULT_TEMPERATURE stream: Optional[bool] = False tools: Optional[List[dict]] = [] - max_tokens: Optional[int] = 8192 + max_tokens: Optional[int] = DEFAULT_MAX_TOKENS + top_p: Optional[float] = DEFAULT_TOP_P + top_k: Optional[int] = DEFAULT_TOP_K stop: Optional[List[str]] = [] - top_p: Optional[float] = 0.9 - top_k: Optional[int] = 40 class EmbeddingRequest(BaseModel): diff --git a/app/exception/exceptions.py b/app/exception/exceptions.py new file mode 100644 index 0000000..ee86957 --- /dev/null +++ b/app/exception/exceptions.py @@ -0,0 +1,133 @@ +""" +异常处理模块,定义应用程序中使用的自定义异常和异常处理器 +""" +from fastapi import Request, FastAPI +from fastapi.responses import JSONResponse +from fastapi.exceptions import RequestValidationError +from starlette.exceptions import HTTPException as StarletteHTTPException + +from app.logger.logger import get_main_logger + +logger = get_main_logger() + + +class APIError(Exception): + """API错误基类""" + def __init__(self, status_code: int, detail: str, error_code: str = None): + self.status_code = status_code + self.detail = detail + self.error_code = error_code or "api_error" + super().__init__(self.detail) + + +class AuthenticationError(APIError): + """认证错误""" + def __init__(self, detail: str = "Authentication failed"): + super().__init__(status_code=401, detail=detail, error_code="authentication_error") + + +class AuthorizationError(APIError): + """授权错误""" + def __init__(self, detail: str = "Not authorized to access this resource"): + super().__init__(status_code=403, detail=detail, error_code="authorization_error") + + +class ResourceNotFoundError(APIError): + """资源未找到错误""" + def __init__(self, detail: str = "Resource not found"): + super().__init__(status_code=404, detail=detail, error_code="resource_not_found") + + +class ModelNotSupportedError(APIError): + """模型不支持错误""" + def __init__(self, model: str): + super().__init__( + status_code=400, + detail=f"Model {model} is not supported", + error_code="model_not_supported" + ) + + +class APIKeyError(APIError): + """API密钥错误""" + def __init__(self, detail: str = "Invalid or expired API key"): + super().__init__(status_code=401, detail=detail, error_code="api_key_error") + + +class ServiceUnavailableError(APIError): + """服务不可用错误""" + def __init__(self, detail: str = "Service temporarily unavailable"): + super().__init__(status_code=503, detail=detail, error_code="service_unavailable") + + +def setup_exception_handlers(app: FastAPI) -> None: + """ + 设置应用程序的异常处理器 + + Args: + app: FastAPI应用程序实例 + """ + @app.exception_handler(APIError) + async def api_error_handler(request: Request, exc: APIError): + """处理API错误""" + logger.error(f"API Error: {exc.detail} (Code: {exc.error_code})") + return JSONResponse( + status_code=exc.status_code, + content={ + "error": { + "code": exc.error_code, + "message": exc.detail + } + } + ) + + @app.exception_handler(StarletteHTTPException) + async def http_exception_handler(request: Request, exc: StarletteHTTPException): + """处理HTTP异常""" + logger.error(f"HTTP Exception: {exc.detail} (Status: {exc.status_code})") + return JSONResponse( + status_code=exc.status_code, + content={ + "error": { + "code": "http_error", + "message": exc.detail + } + } + ) + + @app.exception_handler(RequestValidationError) + async def validation_exception_handler(request: Request, exc: RequestValidationError): + """处理请求验证错误""" + error_details = [] + for error in exc.errors(): + error_details.append({ + "loc": error["loc"], + "msg": error["msg"], + "type": error["type"] + }) + + logger.error(f"Validation Error: {error_details}") + return JSONResponse( + status_code=422, + content={ + "error": { + "code": "validation_error", + "message": "Request validation failed", + "details": error_details + } + } + ) + + @app.exception_handler(Exception) + async def general_exception_handler(request: Request, exc: Exception): + """处理通用异常""" + logger.exception(f"Unhandled Exception: {str(exc)}") + return JSONResponse( + status_code=500, + content={ + "error": { + "code": "internal_server_error", + "message": "An unexpected error occurred" + } + } + ) diff --git a/app/services/chat/message_converter.py b/app/handler/message_converter.py similarity index 97% rename from app/services/chat/message_converter.py rename to app/handler/message_converter.py index e7dbc52..0cb3163 100644 --- a/app/services/chat/message_converter.py +++ b/app/handler/message_converter.py @@ -6,8 +6,7 @@ from typing import Any, Dict, List, Optional import requests import base64 -SUPPORTED_ROLES = ["user", "model", "system"] -IMAGE_URL_PATTERN = r'\[(.*?)\]\((.*?)\)' +from app.core.constants import DATA_URL_PATTERN, IMAGE_URL_PATTERN, SUPPORTED_ROLES class MessageConverter(ABC): @@ -30,7 +29,7 @@ def _get_mime_type_and_data(base64_string): # 检查字符串是否以 "data:" 格式开始 if base64_string.startswith('data:'): # 提取 MIME 类型和数据 - pattern = r'data:([^;]+);base64,(.+)' + pattern = DATA_URL_PATTERN match = re.match(pattern, base64_string) if match: mime_type = "image/jpeg" if match.group(1) == "image/jpg" else match.group(1) diff --git a/app/services/chat/response_handler.py b/app/handler/response_handler.py similarity index 99% rename from app/services/chat/response_handler.py rename to app/handler/response_handler.py index 14ee175..c3ef49b 100644 --- a/app/services/chat/response_handler.py +++ b/app/handler/response_handler.py @@ -8,8 +8,8 @@ from abc import ABC, abstractmethod from typing import Dict, Any, List, Optional import time import uuid -from app.core.config import settings -from app.core.uploader import ImageUploaderFactory +from app.config.config import settings +from app.utils.uploader import ImageUploaderFactory class ResponseHandler(ABC): diff --git a/app/services/chat/retry_handler.py b/app/handler/retry_handler.py similarity index 96% rename from app/services/chat/retry_handler.py rename to app/handler/retry_handler.py index 6646e0f..60b0c83 100644 --- a/app/services/chat/retry_handler.py +++ b/app/handler/retry_handler.py @@ -2,7 +2,7 @@ from typing import TypeVar, Callable from functools import wraps -from app.core.logger import get_retry_logger +from app.logger.logger import get_retry_logger T = TypeVar('T') logger = get_retry_logger() diff --git a/app/services/chat/stream_optimizer.py b/app/handler/stream_optimizer.py similarity index 87% rename from app/services/chat/stream_optimizer.py rename to app/handler/stream_optimizer.py index cdf3010..caeb262 100644 --- a/app/services/chat/stream_optimizer.py +++ b/app/handler/stream_optimizer.py @@ -3,8 +3,9 @@ import asyncio import math from typing import Any, List, AsyncGenerator, Callable -from app.core.logger import get_openai_logger, get_gemini_logger -from app.core.config import settings +from app.logger.logger import get_openai_logger, get_gemini_logger +from app.config.config import settings +from app.core.constants import DEFAULT_STREAM_CHUNK_SIZE, DEFAULT_STREAM_LONG_TEXT_THRESHOLD, DEFAULT_STREAM_MAX_DELAY, DEFAULT_STREAM_MIN_DELAY, DEFAULT_STREAM_SHORT_TEXT_THRESHOLD logger_openai = get_openai_logger() logger_gemini = get_gemini_logger() @@ -18,11 +19,11 @@ class StreamOptimizer: def __init__(self, logger=None, - min_delay: float = 0.016, - max_delay: float = 0.024, - short_text_threshold: int = 10, - long_text_threshold: int = 50, - chunk_size: int = 5): + min_delay: float = DEFAULT_STREAM_MIN_DELAY, + max_delay: float = DEFAULT_STREAM_MAX_DELAY, + short_text_threshold: int = DEFAULT_STREAM_SHORT_TEXT_THRESHOLD, + long_text_threshold: int = DEFAULT_STREAM_LONG_TEXT_THRESHOLD, + chunk_size: int = DEFAULT_STREAM_CHUNK_SIZE): """初始化流式输出优化器 参数: diff --git a/app/core/logger.py b/app/logger/logger.py similarity index 100% rename from app/core/logger.py rename to app/logger/logger.py diff --git a/app/main.py b/app/main.py index 0130889..c9acbdd 100644 --- a/app/main.py +++ b/app/main.py @@ -1,134 +1,16 @@ -from fastapi import FastAPI, Request -from fastapi.middleware.cors import CORSMiddleware -from fastapi.responses import HTMLResponse, RedirectResponse -from fastapi.templating import Jinja2Templates -from fastapi.staticfiles import StaticFiles -from app.core.logger import get_main_logger -from app.core.security import verify_auth_token -from app.services.key_manager import get_key_manager_instance -from app.core.config import settings - -from app.api import gemini_routes, openai_routes +""" +应用程序入口模块 +""" import uvicorn +from app.core.application import create_app +from app.logger.logger import get_main_logger +# 创建应用程序实例 +app = create_app() # 配置日志 logger = get_main_logger() -app = FastAPI() - -# 配置Jinja2模板 -templates = Jinja2Templates(directory="app/templates") - -# 配置静态文件 -app.mount("/static", StaticFiles(directory="app/static"), name="static") - -# 创建 KeyManager 实例 -key_manager = None - -@app.on_event("startup") -async def startup_event(): - global key_manager - logger.info("Application starting up...") - try: - key_manager = await get_key_manager_instance(settings.API_KEYS) - logger.info("KeyManager initialized successfully") - except Exception as e: - logger.error(f"Failed to initialize KeyManager: {str(e)}") - raise - -# 添加中间件来处理未经身份验证的请求 -@app.middleware("http") -async def auth_middleware(request: Request, call_next): - # 允许 gemini_routes 和 openai_routes 中的端点绕过身份验证 - if (request.url.path not in ["/", "/auth"] and - not request.url.path.startswith("/static") and - not request.url.path.startswith("/gemini") and - not request.url.path.startswith("/v1") and - not request.url.path.startswith("/v1beta") and - not request.url.path.startswith("/health") and - not request.url.path.startswith("/hf")): - auth_token = request.cookies.get("auth_token") - if not auth_token or not verify_auth_token(auth_token): - logger.warning(f"Unauthorized access attempt to {request.url.path}") - return RedirectResponse(url="/") - logger.debug("Request authenticated successfully") - response = await call_next(request) - return response - -# 添加请求日志中间件 -# app.add_middleware(RequestLoggingMiddleware) - -# 配置CORS中间件 -app.add_middleware( - CORSMiddleware, - allow_origins=["*"], # 生产环境建议配置具体的域名 - allow_credentials=True, - allow_methods=["GET", "POST", "PUT", "DELETE", "OPTIONS"], # 明确指定允许的HTTP方法 - allow_headers=["*"], # 生产环境建议配置具体的请求头 - expose_headers=["*"], # 允许前端访问的响应头 - max_age=600, # 预检请求缓存时间(秒) -) - -# 包含所有路由 -app.include_router(openai_routes.router) -app.include_router(gemini_routes.router) -app.include_router(gemini_routes.router_v1beta) - - -@app.get("/", response_class=HTMLResponse) -async def auth_page(request: Request): - return templates.TemplateResponse("auth.html", {"request": request}) - - -@app.post("/auth") -async def authenticate(request: Request): - try: - form = await request.form() - auth_token = form.get("auth_token") - if not auth_token: - logger.warning("Authentication attempt with empty token") - return RedirectResponse(url="/", status_code=302) - - if verify_auth_token(auth_token): - logger.info("Successful authentication") - response = RedirectResponse(url="/keys", status_code=302) - response.set_cookie(key="auth_token", value=auth_token, httponly=True, max_age=3600) - return response - logger.warning("Failed authentication attempt with invalid token") - return RedirectResponse(url="/", status_code=302) - except Exception as e: - logger.error(f"Authentication error: {str(e)}") - return RedirectResponse(url="/", status_code=302) - -@app.get("/keys", response_class=HTMLResponse) -async def keys_page(request: Request): - try: - auth_token = request.cookies.get("auth_token") - if not auth_token or not verify_auth_token(auth_token): - logger.warning("Unauthorized access attempt to keys page") - return RedirectResponse(url="/", status_code=302) - - keys_status = await key_manager.get_keys_by_status() - total = len(keys_status["valid_keys"]) + len(keys_status["invalid_keys"]) - logger.info(f"Keys status retrieved successfully. Total keys: {total}") - return templates.TemplateResponse("keys_status.html", { - "request": request, - "valid_keys": keys_status["valid_keys"], - "invalid_keys": keys_status["invalid_keys"], - "total": total - }) - except Exception as e: - logger.error(f"Error retrieving keys status: {str(e)}") - raise - - -@app.get("/health") -async def health_check(request: Request): - logger.info("Health check endpoint called") - return {"status": "healthy"} - - if __name__ == "__main__": logger.info("Starting application server...") uvicorn.run(app, host="0.0.0.0", port=8001) diff --git a/app/middleware/middleware.py b/app/middleware/middleware.py new file mode 100644 index 0000000..dc39576 --- /dev/null +++ b/app/middleware/middleware.py @@ -0,0 +1,61 @@ +""" +中间件配置模块,负责设置和配置应用程序的中间件 +""" +from fastapi import FastAPI, Request +from fastapi.middleware.cors import CORSMiddleware +from fastapi.responses import RedirectResponse +from starlette.middleware.base import BaseHTTPMiddleware + +from app.logger.logger import get_main_logger +from app.core.security import verify_auth_token +# from app.middleware.request_logging_middleware import RequestLoggingMiddleware +from app.core.constants import API_VERSION + +logger = get_main_logger() + +class AuthMiddleware(BaseHTTPMiddleware): + """ + 认证中间件,处理未经身份验证的请求 + """ + async def dispatch(self, request: Request, call_next): + # 允许特定路径绕过身份验证 + if (request.url.path not in ["/", "/auth"] and + not request.url.path.startswith("/static") and + not request.url.path.startswith("/gemini") and + not request.url.path.startswith("/v1") and + not request.url.path.startswith(f"/{API_VERSION}") and + not request.url.path.startswith("/health") and + not request.url.path.startswith("/hf")): + + auth_token = request.cookies.get("auth_token") + if not auth_token or not verify_auth_token(auth_token): + logger.warning(f"Unauthorized access attempt to {request.url.path}") + return RedirectResponse(url="/") + logger.debug("Request authenticated successfully") + + response = await call_next(request) + return response + +def setup_middlewares(app: FastAPI) -> None: + """ + 设置应用程序的中间件 + + Args: + app: FastAPI应用程序实例 + """ + # 添加认证中间件 + app.add_middleware(AuthMiddleware) + + # 添加请求日志中间件(可选,默认注释掉) + # app.add_middleware(RequestLoggingMiddleware) + + # 配置CORS中间件 + app.add_middleware( + CORSMiddleware, + allow_origins=["*"], # 生产环境建议配置具体的域名 + allow_credentials=True, + allow_methods=["GET", "POST", "PUT", "DELETE", "OPTIONS"], # 明确指定允许的HTTP方法 + allow_headers=["*"], # 生产环境建议配置具体的请求头 + expose_headers=["*"], # 允许前端访问的响应头 + max_age=600, # 预检请求缓存时间(秒) + ) diff --git a/app/middleware/request_logging_middleware.py b/app/middleware/request_logging_middleware.py index bcef5c4..814514e 100644 --- a/app/middleware/request_logging_middleware.py +++ b/app/middleware/request_logging_middleware.py @@ -1,7 +1,7 @@ from fastapi import Request from starlette.middleware.base import BaseHTTPMiddleware import json -from app.core.logger import get_request_logger +from app.logger.logger import get_request_logger logger = get_request_logger() diff --git a/app/api/gemini_routes.py b/app/router/gemini_routes.py similarity index 68% rename from app/api/gemini_routes.py rename to app/router/gemini_routes.py index f0f5164..d10b517 100644 --- a/app/api/gemini_routes.py +++ b/app/router/gemini_routes.py @@ -1,73 +1,80 @@ from fastapi import APIRouter, Depends, HTTPException from fastapi.responses import StreamingResponse, JSONResponse from copy import deepcopy -from app.core.config import settings -from app.core.logger import get_gemini_logger +from app.config.config import settings +from app.logger.logger import get_gemini_logger from app.core.security import SecurityService -from app.schemas.gemini_models import GeminiContent, GeminiRequest -from app.services.gemini_chat_service import GeminiChatService -from app.services.key_manager import KeyManager, get_key_manager_instance -from app.services.model_service import ModelService -from app.services.chat.retry_handler import RetryHandler +from app.domain.gemini_models import GeminiContent, GeminiRequest +from app.service.chat.gemini_chat_service import GeminiChatService +from app.service.key.key_manager import KeyManager, get_key_manager_instance +from app.service.model.model_service import ModelService +from app.handler.retry_handler import RetryHandler +from app.core.constants import API_VERSION -router = APIRouter(prefix="/gemini/v1beta") -router_v1beta = APIRouter(prefix="/v1beta") +# 路由设置 +router = APIRouter(prefix=f"/gemini/{API_VERSION}") +router_v1beta = APIRouter(prefix=f"/{API_VERSION}") logger = get_gemini_logger() # 初始化服务 security_service = SecurityService(settings.ALLOWED_TOKENS, settings.AUTH_TOKEN) +model_service = ModelService(settings.SEARCH_MODELS, settings.IMAGE_MODELS) + async def get_key_manager(): + """获取密钥管理器实例""" return await get_key_manager_instance() -async def get_next_working_key_wrapper(key_manager: KeyManager = Depends(get_key_manager)): - return await key_manager.get_next_working_key() -model_service = ModelService(settings.SEARCH_MODELS,settings.IMAGE_MODELS) +async def get_next_working_key(key_manager: KeyManager = Depends(get_key_manager)): + """获取下一个可用的API密钥""" + return await key_manager.get_next_working_key() @router.get("/models") @router_v1beta.get("/models") -async def list_models(_=Depends(security_service.verify_key), - key_manager: KeyManager = Depends(get_key_manager)): +async def list_models( + _=Depends(security_service.verify_key), + key_manager: KeyManager = Depends(get_key_manager) +): """获取可用的Gemini模型列表""" logger.info("-" * 50 + "list_gemini_models" + "-" * 50) logger.info("Handling Gemini models list request") + api_key = await key_manager.get_next_working_key() logger.info(f"Using API key: {api_key}") + models_json = model_service.get_gemini_models(api_key) - - # 模型名称以及对应的详细信息 model_mapping = {x.get("name", "").split("/", maxsplit=1)[1]: x for x in models_json["models"]} - + # 添加搜索模型 if model_service.search_models: for name in model_service.search_models: - model = model_mapping.get(name, None) + model = model_mapping.get(name) if not model: continue - + item = deepcopy(model) item["name"] = f"models/{name}-search" display_name = f'{item.get("displayName")} For Search' item["displayName"] = display_name item["description"] = display_name - + models_json["models"].append(item) - + # 添加图像生成模型 if model_service.image_models: for name in model_service.image_models: - model = model_mapping.get(name, None) + model = model_mapping.get(name) if not model: continue - + item = deepcopy(model) item["name"] = f"models/{name}-image" display_name = f'{item.get("displayName")} For Image' item["displayName"] = display_name item["description"] = display_name - + models_json["models"].append(item) return models_json @@ -77,30 +84,29 @@ async def list_models(_=Depends(security_service.verify_key), @router_v1beta.post("/models/{model_name}:generateContent") @RetryHandler(max_retries=3, key_arg="api_key") async def generate_content( - model_name: str, - request: GeminiRequest, - _=Depends(security_service.verify_goog_api_key), - api_key: str = Depends(get_next_working_key_wrapper), - key_manager: KeyManager = Depends(get_key_manager) + model_name: str, + request: GeminiRequest, + _=Depends(security_service.verify_goog_api_key), + api_key: str = Depends(get_next_working_key), + key_manager: KeyManager = Depends(get_key_manager) ): - chat_service = GeminiChatService(settings.BASE_URL, key_manager) """非流式生成内容""" logger.info("-" * 50 + "gemini_generate_content" + "-" * 50) logger.info(f"Handling Gemini content generation request for model: {model_name}") logger.info(f"Request: \n{request.model_dump_json(indent=2)}") logger.info(f"Using API key: {api_key}") - + if not model_service.check_model_support(model_name): raise HTTPException(status_code=400, detail=f"Model {model_name} is not supported") - + try: + chat_service = GeminiChatService(settings.BASE_URL, key_manager) response = await chat_service.generate_content( model=model_name, request=request, api_key=api_key ) return response - except Exception as e: logger.error(f"Chat completion failed after retries: {str(e)}") raise HTTPException(status_code=500, detail="Chat completion failed") from e @@ -110,45 +116,46 @@ async def generate_content( @router_v1beta.post("/models/{model_name}:streamGenerateContent") @RetryHandler(max_retries=3, key_arg="api_key") async def stream_generate_content( - model_name: str, - request: GeminiRequest, - _=Depends(security_service.verify_goog_api_key), - api_key: str = Depends(get_next_working_key_wrapper), - key_manager: KeyManager = Depends(get_key_manager) + model_name: str, + request: GeminiRequest, + _=Depends(security_service.verify_goog_api_key), + api_key: str = Depends(get_next_working_key), + key_manager: KeyManager = Depends(get_key_manager) ): - chat_service = GeminiChatService(settings.BASE_URL, key_manager) """流式生成内容""" logger.info("-" * 50 + "gemini_stream_generate_content" + "-" * 50) logger.info(f"Handling Gemini streaming content generation for model: {model_name}") logger.info(f"Request: \n{request.model_dump_json(indent=2)}") logger.info(f"Using API key: {api_key}") - + if not model_service.check_model_support(model_name): raise HTTPException(status_code=400, detail=f"Model {model_name} is not supported") - + try: + chat_service = GeminiChatService(settings.BASE_URL, key_manager) response_stream = chat_service.stream_generate_content( model=model_name, request=request, api_key=api_key ) return StreamingResponse(response_stream, media_type="text/event-stream") - except Exception as e: logger.error(f"Streaming request failed: {str(e)}") + raise HTTPException(status_code=500, detail="Streaming request failed") from e @router.post("/verify-key/{api_key}") async def verify_key(api_key: str): - key_manager = await get_key_manager() - chat_service = GeminiChatService(settings.BASE_URL, key_manager) """验证Gemini API密钥的有效性""" logger.info("-" * 50 + "verify_gemini_key" + "-" * 50) logger.info("Verifying API key validity") try: + key_manager = await get_key_manager() + chat_service = GeminiChatService(settings.BASE_URL, key_manager) + # 使用generate_content接口测试key的有效性 - gemini_requset = GeminiRequest( + gemini_request = GeminiRequest( contents=[ GeminiContent( role="user", @@ -156,10 +163,16 @@ async def verify_key(api_key: str): ) ] ) - response = await chat_service.generate_content(settings.TEST_MODEL,gemini_requset, api_key) + + response = await chat_service.generate_content( + settings.TEST_MODEL, + gemini_request, + api_key + ) + if response: return JSONResponse({"status": "valid"}) return JSONResponse({"status": "invalid"}) except Exception as e: logger.error(f"Key verification failed: {str(e)}") - return JSONResponse({"status": "invalid", "error": str(e)}) + return JSONResponse({"status": "invalid", "error": str(e)}) \ No newline at end of file diff --git a/app/api/openai_routes.py b/app/router/openai_routes.py similarity index 90% rename from app/api/openai_routes.py rename to app/router/openai_routes.py index 3e1f728..62211f9 100644 --- a/app/api/openai_routes.py +++ b/app/router/openai_routes.py @@ -1,16 +1,16 @@ from fastapi import HTTPException, APIRouter, Depends from fastapi.responses import StreamingResponse -from app.core.config import settings -from app.core.logger import get_openai_logger +from app.config.config import settings +from app.logger.logger import get_openai_logger from app.core.security import SecurityService -from app.schemas.openai_models import ChatRequest, EmbeddingRequest, ImageGenerationRequest -from app.services.chat.retry_handler import RetryHandler -from app.services.embedding_service import EmbeddingService -from app.services.image_create_service import ImageCreateService -from app.services.key_manager import KeyManager, get_key_manager_instance -from app.services.model_service import ModelService -from app.services.openai_chat_service import OpenAIChatService +from app.domain.openai_models import ChatRequest, EmbeddingRequest, ImageGenerationRequest +from app.handler.retry_handler import RetryHandler +from app.service.embedding.embedding_service import EmbeddingService +from app.service.image.image_create_service import ImageCreateService +from app.service.key.key_manager import KeyManager, get_key_manager_instance +from app.service.model.model_service import ModelService +from app.service.chat.openai_chat_service import OpenAIChatService router = APIRouter() logger = get_openai_logger() diff --git a/app/router/routers.py b/app/router/routers.py new file mode 100644 index 0000000..333a644 --- /dev/null +++ b/app/router/routers.py @@ -0,0 +1,103 @@ +""" +路由配置模块,负责设置和配置应用程序的路由 +""" +from fastapi import FastAPI, Request +from fastapi.responses import HTMLResponse, RedirectResponse +from fastapi.templating import Jinja2Templates + +from app.logger.logger import get_main_logger +from app.core.security import verify_auth_token +from app.router import gemini_routes, openai_routes +from app.service.key.key_manager import get_key_manager_instance + +logger = get_main_logger() + +# 配置Jinja2模板 +templates = Jinja2Templates(directory="app/templates") + +def setup_routers(app: FastAPI) -> None: + """ + 设置应用程序的路由 + + Args: + app: FastAPI应用程序实例 + """ + # 包含API路由 + app.include_router(openai_routes.router) + app.include_router(gemini_routes.router) + app.include_router(gemini_routes.router_v1beta) + + # 添加页面路由 + setup_page_routes(app) + + # 添加健康检查路由 + setup_health_routes(app) + +def setup_page_routes(app: FastAPI) -> None: + """ + 设置页面相关的路由 + + Args: + app: FastAPI应用程序实例 + """ + @app.get("/", response_class=HTMLResponse) + async def auth_page(request: Request): + """认证页面""" + return templates.TemplateResponse("auth.html", {"request": request}) + + @app.post("/auth") + async def authenticate(request: Request): + """处理认证请求""" + try: + form = await request.form() + auth_token = form.get("auth_token") + if not auth_token: + logger.warning("Authentication attempt with empty token") + return RedirectResponse(url="/", status_code=302) + + if verify_auth_token(auth_token): + logger.info("Successful authentication") + response = RedirectResponse(url="/keys", status_code=302) + response.set_cookie(key="auth_token", value=auth_token, httponly=True, max_age=3600) + return response + logger.warning("Failed authentication attempt with invalid token") + return RedirectResponse(url="/", status_code=302) + except Exception as e: + logger.error(f"Authentication error: {str(e)}") + return RedirectResponse(url="/", status_code=302) + + @app.get("/keys", response_class=HTMLResponse) + async def keys_page(request: Request): + """密钥管理页面""" + try: + auth_token = request.cookies.get("auth_token") + if not auth_token or not verify_auth_token(auth_token): + logger.warning("Unauthorized access attempt to keys page") + return RedirectResponse(url="/", status_code=302) + + key_manager = await get_key_manager_instance() + keys_status = await key_manager.get_keys_by_status() + total = len(keys_status["valid_keys"]) + len(keys_status["invalid_keys"]) + logger.info(f"Keys status retrieved successfully. Total keys: {total}") + return templates.TemplateResponse("keys_status.html", { + "request": request, + "valid_keys": keys_status["valid_keys"], + "invalid_keys": keys_status["invalid_keys"], + "total": total + }) + except Exception as e: + logger.error(f"Error retrieving keys status: {str(e)}") + raise + +def setup_health_routes(app: FastAPI) -> None: + """ + 设置健康检查相关的路由 + + Args: + app: FastAPI应用程序实例 + """ + @app.get("/health") + async def health_check(request: Request): + """健康检查端点""" + logger.info("Health check endpoint called") + return {"status": "healthy"} diff --git a/app/services/gemini_chat_service.py b/app/service/chat/gemini_chat_service.py similarity index 94% rename from app/services/gemini_chat_service.py rename to app/service/chat/gemini_chat_service.py index 2cf82c9..653264e 100644 --- a/app/services/gemini_chat_service.py +++ b/app/service/chat/gemini_chat_service.py @@ -2,13 +2,13 @@ import json from typing import Dict, Any, AsyncGenerator, List -from app.core.logger import get_gemini_logger -from app.services.chat.api_client import GeminiApiClient -from app.services.chat.stream_optimizer import gemini_optimizer -from app.schemas.gemini_models import GeminiRequest -from app.core.config import settings -from app.services.chat.response_handler import GeminiResponseHandler -from app.services.key_manager import KeyManager +from app.logger.logger import get_gemini_logger +from app.service.client.api_client import GeminiApiClient +from app.handler.stream_optimizer import gemini_optimizer +from app.domain.gemini_models import GeminiRequest +from app.config.config import settings +from app.handler.response_handler import GeminiResponseHandler +from app.service.key.key_manager import KeyManager logger = get_gemini_logger() diff --git a/app/services/openai_chat_service.py b/app/service/chat/openai_chat_service.py similarity index 95% rename from app/services/openai_chat_service.py rename to app/service/chat/openai_chat_service.py index f0462a7..5161796 100644 --- a/app/services/openai_chat_service.py +++ b/app/service/chat/openai_chat_service.py @@ -3,15 +3,15 @@ from copy import deepcopy import json from typing import Dict, Any, AsyncGenerator, List, Optional, Union -from app.core.logger import get_openai_logger -from app.services.chat.message_converter import OpenAIMessageConverter -from app.services.chat.response_handler import OpenAIResponseHandler -from app.services.chat.api_client import GeminiApiClient -from app.services.chat.stream_optimizer import openai_optimizer -from app.schemas.openai_models import ChatRequest, ImageGenerationRequest -from app.core.config import settings -from app.services.image_create_service import ImageCreateService -from app.services.key_manager import KeyManager +from app.logger.logger import get_openai_logger +from app.handler.message_converter import OpenAIMessageConverter +from app.handler.response_handler import OpenAIResponseHandler +from app.service.client.api_client import GeminiApiClient +from app.handler.stream_optimizer import openai_optimizer +from app.domain.openai_models import ChatRequest, ImageGenerationRequest +from app.config.config import settings +from app.service.image.image_create_service import ImageCreateService +from app.service.key.key_manager import KeyManager logger = get_openai_logger() diff --git a/app/services/chat/api_client.py b/app/service/client/api_client.py similarity index 95% rename from app/services/chat/api_client.py rename to app/service/client/api_client.py index 9469395..f540535 100644 --- a/app/services/chat/api_client.py +++ b/app/service/client/api_client.py @@ -4,6 +4,8 @@ from typing import Dict, Any, AsyncGenerator import httpx from abc import ABC, abstractmethod +from app.core.constants import DEFAULT_TIMEOUT + class ApiClient(ABC): """API客户端基类""" @@ -20,7 +22,7 @@ class ApiClient(ABC): class GeminiApiClient(ApiClient): """Gemini API客户端""" - def __init__(self, base_url: str, timeout: int = 300): + def __init__(self, base_url: str, timeout: int = DEFAULT_TIMEOUT): self.base_url = base_url self.timeout = timeout diff --git a/app/services/embedding_service.py b/app/service/embedding/embedding_service.py similarity index 93% rename from app/services/embedding_service.py rename to app/service/embedding/embedding_service.py index 8c3d3c0..a76079e 100644 --- a/app/services/embedding_service.py +++ b/app/service/embedding/embedding_service.py @@ -3,7 +3,7 @@ from typing import Union, List import openai from openai.types import CreateEmbeddingResponse -from app.core.logger import get_embeddings_logger +from app.logger.logger import get_embeddings_logger logger = get_embeddings_logger() diff --git a/app/services/image_create_service.py b/app/service/image/image_create_service.py similarity index 94% rename from app/services/image_create_service.py rename to app/service/image/image_create_service.py index d31f87f..e70291b 100644 --- a/app/services/image_create_service.py +++ b/app/service/image/image_create_service.py @@ -5,10 +5,11 @@ from google import genai from google.genai import types import base64 -from app.core.config import settings -from app.core.logger import get_image_create_logger -from app.core.uploader import ImageUploaderFactory -from app.schemas.openai_models import ImageGenerationRequest +from app.config.config import settings +from app.logger.logger import get_image_create_logger +from app.utils.uploader import ImageUploaderFactory +from app.domain.openai_models import ImageGenerationRequest +from app.core.constants import VALID_IMAGE_RATIOS logger = get_image_create_logger() @@ -43,10 +44,9 @@ class ImageCreateService: ratio_match = re.search(r'{ratio:(\d+:\d+)}', prompt) if ratio_match: aspect_ratio = ratio_match.group(1) - valid_ratios = ["1:1", "3:4", "4:3", "9:16", "16:9"] - if aspect_ratio not in valid_ratios: + if aspect_ratio not in VALID_IMAGE_RATIOS: raise ValueError( - f"Invalid ratio: {aspect_ratio}. Must be one of: {', '.join(valid_ratios)}" + f"Invalid ratio: {aspect_ratio}. Must be one of: {', '.join(VALID_IMAGE_RATIOS)}" ) prompt = prompt.replace(ratio_match.group(0), '').strip() diff --git a/app/services/key_manager.py b/app/service/key/key_manager.py similarity index 97% rename from app/services/key_manager.py rename to app/service/key/key_manager.py index 6895fbe..3f1fadb 100644 --- a/app/services/key_manager.py +++ b/app/service/key/key_manager.py @@ -1,8 +1,8 @@ import asyncio from itertools import cycle from typing import Dict -from app.core.logger import get_key_manager_logger -from app.core.config import settings +from app.logger.logger import get_key_manager_logger +from app.config.config import settings logger = get_key_manager_logger() diff --git a/app/services/model_service.py b/app/service/model/model_service.py similarity index 95% rename from app/services/model_service.py rename to app/service/model/model_service.py index 7cf97be..e8b329d 100644 --- a/app/services/model_service.py +++ b/app/service/model/model_service.py @@ -1,8 +1,8 @@ import requests from datetime import datetime, timezone from typing import Optional, Dict, Any -from app.core.logger import get_model_logger -from app.core.config import settings +from app.logger.logger import get_model_logger +from app.config.config import settings logger = get_model_logger() @@ -10,7 +10,7 @@ class ModelService: def __init__(self, search_models: list, image_models: list): self.search_models = search_models self.image_models = image_models - self.base_url = "https://generativelanguage.googleapis.com/v1beta" + self.base_url = settings.BASE_URL self.filtered_models = settings.FILTERED_MODELS def get_gemini_models(self, api_key: str) -> Optional[Dict[str, Any]]: diff --git a/app/utils/__init__.py b/app/utils/__init__.py new file mode 100644 index 0000000..668784a --- /dev/null +++ b/app/utils/__init__.py @@ -0,0 +1,3 @@ +""" +工具包初始化模块 +""" diff --git a/app/utils/helpers.py b/app/utils/helpers.py new file mode 100644 index 0000000..957f177 --- /dev/null +++ b/app/utils/helpers.py @@ -0,0 +1,146 @@ +""" +通用工具函数模块 +""" +import json +import re +import base64 +import requests +from typing import Dict, Any, List, Optional, Tuple + +from app.core.constants import DATA_URL_PATTERN, IMAGE_URL_PATTERN, VALID_IMAGE_RATIOS + + +def extract_mime_type_and_data(base64_string: str) -> Tuple[Optional[str], str]: + """ + 从 base64 字符串中提取 MIME 类型和数据 + + Args: + base64_string: 可能包含 MIME 类型信息的 base64 字符串 + + Returns: + tuple: (mime_type, encoded_data) + """ + # 检查字符串是否以 "data:" 格式开始 + if base64_string.startswith('data:'): + # 提取 MIME 类型和数据 + pattern = DATA_URL_PATTERN + match = re.match(pattern, base64_string) + if match: + mime_type = "image/jpeg" if match.group(1) == "image/jpg" else match.group(1) + encoded_data = match.group(2) + return mime_type, encoded_data + + # 如果不是预期格式,假定它只是数据部分 + return None, base64_string + + +def convert_image_to_base64(url: str) -> str: + """ + 将图片URL转换为base64编码 + + Args: + url: 图片URL + + Returns: + str: base64编码的图片数据 + + Raises: + Exception: 如果获取图片失败 + """ + response = requests.get(url) + if response.status_code == 200: + # 将图片内容转换为base64 + img_data = base64.b64encode(response.content).decode('utf-8') + return img_data + else: + raise Exception(f"Failed to fetch image: {response.status_code}") + + +def format_json_response(data: Dict[str, Any], indent: int = 2) -> str: + """ + 格式化JSON响应 + + Args: + data: 要格式化的数据 + indent: 缩进空格数 + + Returns: + str: 格式化后的JSON字符串 + """ + return json.dumps(data, indent=indent, ensure_ascii=False) + + +def parse_prompt_parameters(prompt: str, default_ratio: str = "1:1") -> Tuple[str, int, str]: + """ + 从prompt中解析参数 + + 支持的格式: + - {n:数量} 例如: {n:2} 生成2张图片 + - {ratio:比例} 例如: {ratio:16:9} 使用16:9比例 + + Args: + prompt: 提示文本 + default_ratio: 默认比例 + + Returns: + tuple: (清理后的提示文本, 图片数量, 比例) + """ + # 默认值 + n = 1 + aspect_ratio = default_ratio + + # 解析n参数 + n_match = re.search(r'{n:(\d+)}', prompt) + if n_match: + n = int(n_match.group(1)) + if n < 1 or n > 4: + raise ValueError(f"Invalid n value: {n}. Must be between 1 and 4.") + prompt = prompt.replace(n_match.group(0), '').strip() + + # 解析ratio参数 + ratio_match = re.search(r'{ratio:(\d+:\d+)}', prompt) + if ratio_match: + aspect_ratio = ratio_match.group(1) + if aspect_ratio not in VALID_IMAGE_RATIOS: + raise ValueError( + f"Invalid ratio: {aspect_ratio}. Must be one of: {', '.join(VALID_IMAGE_RATIOS)}" + ) + prompt = prompt.replace(ratio_match.group(0), '').strip() + + return prompt, n, aspect_ratio + + +def extract_image_urls_from_markdown(text: str) -> List[str]: + """ + 从Markdown文本中提取图片URL + + Args: + text: Markdown文本 + + Returns: + List[str]: 图片URL列表 + """ + pattern = IMAGE_URL_PATTERN + matches = re.findall(pattern, text) + return [match[1] for match in matches] + + +def is_valid_api_key(key: str) -> bool: + """ + 检查API密钥格式是否有效 + + Args: + key: API密钥 + + Returns: + bool: 如果密钥格式有效则返回True + """ + # 检查Gemini API密钥格式 + if key.startswith('AIza'): + return len(key) >= 30 + + # 检查OpenAI API密钥格式 + if key.startswith('sk-'): + return len(key) >= 30 + + return False diff --git a/app/core/uploader.py b/app/utils/uploader.py similarity index 99% rename from app/core/uploader.py rename to app/utils/uploader.py index 3f2b6ab..37fee6c 100644 --- a/app/core/uploader.py +++ b/app/utils/uploader.py @@ -1,5 +1,5 @@ import requests -from app.schemas.image_models import ImageMetadata, ImageUploader, UploadResponse +from app.domain.image_models import ImageMetadata, ImageUploader, UploadResponse from enum import Enum from typing import Optional, Any