From b3a057b6ba75f06bf49762bc65b1807bc65429da Mon Sep 17 00:00:00 2001 From: snaily Date: Thu, 20 Mar 2025 21:59:18 +0800 Subject: [PATCH] =?UTF-8?q?refactor:=20=E4=BB=A3=E7=A0=81=E7=BB=93?= =?UTF-8?q?=E6=9E=84=E4=BC=98=E5=8C=96=E4=B8=8E=E5=B8=B8=E9=87=8F=E5=8C=96?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit 将日志系统从 app/logger/ 移至 app/log/ 目录 将路由配置从 routers.py 重命名为 routes.py 将硬编码配置值移至 constants.py 中的默认常量 统一代码格式和导入排序 优化函数参数对齐方式 --- app/config/config.py | 22 ++-- app/core/application.py | 6 +- app/core/constants.py | 9 ++ app/core/initialization.py | 5 +- app/core/security.py | 23 ++-- app/exception/exceptions.py | 83 ++++++++------- app/handler/retry_handler.py | 17 +-- app/handler/stream_optimizer.py | 79 ++++++++------ app/{logger => log}/logger.py | 20 ++++ app/main.py | 4 +- app/middleware/middleware.py | 44 +++++--- app/middleware/request_logging_middleware.py | 10 +- app/router/gemini_routes.py | 2 +- app/router/openai_routes.py | 53 ++++++---- app/router/{routers.py => routes.py} | 43 +++++--- app/service/chat/gemini_chat_service.py | 77 +++++++++----- app/service/chat/openai_chat_service.py | 105 +++++++++++-------- app/service/embedding/embedding_service.py | 4 +- app/service/image/image_create_service.py | 83 ++++++++------- app/service/key/key_manager.py | 20 ++-- app/service/model/model_service.py | 15 +-- 21 files changed, 442 insertions(+), 282 deletions(-) rename app/{logger => log}/logger.py (88%) rename app/router/{routers.py => routes.py} (85%) diff --git a/app/config/config.py b/app/config/config.py index d12f6b8..8630cc6 100644 --- a/app/config/config.py +++ b/app/config/config.py @@ -4,7 +4,7 @@ from typing import List from pydantic_settings import BaseSettings -from app.core.constants import API_VERSION, DEFAULT_MODEL +from app.core.constants import API_VERSION, DEFAULT_CREATE_IMAGE_MODEL, DEFAULT_FILTER_MODELS, DEFAULT_MODEL, DEFAULT_STREAM_CHUNK_SIZE, DEFAULT_STREAM_LONG_TEXT_THRESHOLD, DEFAULT_STREAM_MAX_DELAY, DEFAULT_STREAM_MIN_DELAY, DEFAULT_STREAM_SHORT_TEXT_THRESHOLD class Settings(BaseSettings): @@ -20,20 +20,14 @@ class Settings(BaseSettings): # 模型相关配置 SEARCH_MODELS: List[str] = ["gemini-2.0-flash-exp"] IMAGE_MODELS: List[str] = ["gemini-2.0-flash-exp"] - FILTERED_MODELS: List[str] = [ - "gemini-1.0-pro-vision-latest", - "gemini-pro-vision", - "chat-bison-001", - "text-bison-001", - "embedding-gecko-001" - ] + FILTERED_MODELS: List[str] = DEFAULT_FILTER_MODELS TOOLS_CODE_EXECUTION_ENABLED: bool = False SHOW_SEARCH_LINK: bool = True SHOW_THINKING_PROCESS: bool = True # 图像生成相关配置 PAID_KEY: str = "" - CREATE_IMAGE_MODEL: str = "imagen-3.0-generate-002" + CREATE_IMAGE_MODEL: str = DEFAULT_CREATE_IMAGE_MODEL UPLOAD_PROVIDER: str = "smms" SMMS_SECRET_TOKEN: str = "" PICGO_API_KEY: str = "" @@ -41,11 +35,11 @@ class Settings(BaseSettings): CLOUDFLARE_IMGBED_AUTH_CODE: str = "" # 流式输出优化器配置 - STREAM_MIN_DELAY: float = 0.016 - STREAM_MAX_DELAY: float = 0.024 - STREAM_SHORT_TEXT_THRESHOLD: int = 10 - STREAM_LONG_TEXT_THRESHOLD: int = 50 - STREAM_CHUNK_SIZE: int = 5 + STREAM_MIN_DELAY: float = DEFAULT_STREAM_MIN_DELAY + STREAM_MAX_DELAY: float = DEFAULT_STREAM_MAX_DELAY + STREAM_SHORT_TEXT_THRESHOLD: int = DEFAULT_STREAM_SHORT_TEXT_THRESHOLD + STREAM_LONG_TEXT_THRESHOLD: int = DEFAULT_STREAM_LONG_TEXT_THRESHOLD + STREAM_CHUNK_SIZE: int = DEFAULT_STREAM_CHUNK_SIZE def __init__(self, **kwargs): super().__init__(**kwargs) diff --git a/app/core/application.py b/app/core/application.py index 1d9f825..d3aa02b 100644 --- a/app/core/application.py +++ b/app/core/application.py @@ -6,14 +6,14 @@ from fastapi import FastAPI from fastapi.staticfiles import StaticFiles from app.config.config import settings -from app.logger.logger import get_main_logger +from app.log.logger import get_application_logger from app.middleware.middleware import setup_middlewares from app.exception.exceptions import setup_exception_handlers -from app.router.routers import setup_routers +from app.router.routes import setup_routers from app.service.key.key_manager import get_key_manager_instance from app.core.initialization import initialize_app -logger = get_main_logger() +logger = get_application_logger() @asynccontextmanager async def lifespan(app: FastAPI): diff --git a/app/core/constants.py b/app/core/constants.py index 0701048..32ccf78 100644 --- a/app/core/constants.py +++ b/app/core/constants.py @@ -13,12 +13,21 @@ DEFAULT_TEMPERATURE = 0.7 DEFAULT_MAX_TOKENS = 8192 DEFAULT_TOP_P = 0.9 DEFAULT_TOP_K = 40 +DEFAULT_FILTER_MODELS = [ + "gemini-1.0-pro-vision-latest", + "gemini-pro-vision", + "chat-bison-001", + "text-bison-001", + "embedding-gecko-001" + ] +DEFAULT_CREATE_IMAGE_MODEL = "imagen-3.0-generate-002" # 图像生成相关常量 VALID_IMAGE_RATIOS = ["1:1", "3:4", "4:3", "9:16", "16:9"] # 上传提供商 UPLOAD_PROVIDERS = ["smms", "picgo", "cloudflare_imgbed"] +DEFAULT_UPLOAD_PROVIDER = "smms" # 流式输出相关常量 DEFAULT_STREAM_MIN_DELAY = 0.016 diff --git a/app/core/initialization.py b/app/core/initialization.py index 40ce4b9..92bff80 100644 --- a/app/core/initialization.py +++ b/app/core/initialization.py @@ -1,11 +1,12 @@ """ 应用程序初始化模块 """ -import logging from pathlib import Path from typing import List -logger = logging.getLogger("initialization") +from app.log.logger import get_initialization_logger + +logger = get_initialization_logger() def ensure_directories_exist(directories: List[str]) -> None: diff --git a/app/core/security.py b/app/core/security.py index 2684c23..060928a 100644 --- a/app/core/security.py +++ b/app/core/security.py @@ -1,13 +1,17 @@ -from fastapi import HTTPException, Header from typing import Optional -from app.logger.logger import get_security_logger + +from fastapi import Header, HTTPException + from app.config.config import settings +from app.log.logger import get_security_logger logger = get_security_logger() + def verify_auth_token(token: str) -> bool: return token == settings.AUTH_TOKEN + class SecurityService: def __init__(self, allowed_tokens: list, auth_token: str): self.allowed_tokens = allowed_tokens @@ -20,7 +24,7 @@ class SecurityService: return key async def verify_authorization( - self, authorization: Optional[str] = Header(None) + self, authorization: Optional[str] = Header(None) ) -> str: if not authorization: logger.error("Missing Authorization header") @@ -39,19 +43,26 @@ class SecurityService: return token - async def verify_goog_api_key(self, x_goog_api_key: Optional[str] = Header(None)) -> str: + async def verify_goog_api_key( + self, x_goog_api_key: Optional[str] = Header(None) + ) -> str: """验证Google API Key""" if not x_goog_api_key: logger.error("Missing x-goog-api-key header") raise HTTPException(status_code=401, detail="Missing x-goog-api-key header") - if x_goog_api_key not in self.allowed_tokens and x_goog_api_key != self.auth_token: + if ( + x_goog_api_key not in self.allowed_tokens + and x_goog_api_key != self.auth_token + ): logger.error("Invalid x-goog-api-key") raise HTTPException(status_code=401, detail="Invalid x-goog-api-key") return x_goog_api_key - async def verify_auth_token(self, authorization: Optional[str] = Header(None)) -> str: + async def verify_auth_token( + self, authorization: Optional[str] = Header(None) + ) -> str: if not authorization: logger.error("Missing auth_token header") raise HTTPException(status_code=401, detail="Missing auth_token header") diff --git a/app/exception/exceptions.py b/app/exception/exceptions.py index ee86957..0e9fb30 100644 --- a/app/exception/exceptions.py +++ b/app/exception/exceptions.py @@ -1,18 +1,20 @@ """ 异常处理模块,定义应用程序中使用的自定义异常和异常处理器 """ -from fastapi import Request, FastAPI -from fastapi.responses import JSONResponse + +from fastapi import FastAPI, Request from fastapi.exceptions import RequestValidationError +from fastapi.responses import JSONResponse from starlette.exceptions import HTTPException as StarletteHTTPException -from app.logger.logger import get_main_logger +from app.log.logger import get_exceptions_logger -logger = get_main_logger() +logger = get_exceptions_logger() class APIError(Exception): """API错误基类""" + def __init__(self, status_code: int, detail: str, error_code: str = None): self.status_code = status_code self.detail = detail @@ -22,90 +24,95 @@ class APIError(Exception): class AuthenticationError(APIError): """认证错误""" + def __init__(self, detail: str = "Authentication failed"): - super().__init__(status_code=401, detail=detail, error_code="authentication_error") + super().__init__( + status_code=401, detail=detail, error_code="authentication_error" + ) class AuthorizationError(APIError): """授权错误""" + def __init__(self, detail: str = "Not authorized to access this resource"): - super().__init__(status_code=403, detail=detail, error_code="authorization_error") + super().__init__( + status_code=403, detail=detail, error_code="authorization_error" + ) class ResourceNotFoundError(APIError): """资源未找到错误""" + def __init__(self, detail: str = "Resource not found"): - super().__init__(status_code=404, detail=detail, error_code="resource_not_found") + super().__init__( + status_code=404, detail=detail, error_code="resource_not_found" + ) class ModelNotSupportedError(APIError): """模型不支持错误""" + def __init__(self, model: str): super().__init__( - status_code=400, - detail=f"Model {model} is not supported", - error_code="model_not_supported" + status_code=400, + detail=f"Model {model} is not supported", + error_code="model_not_supported", ) class APIKeyError(APIError): """API密钥错误""" + def __init__(self, detail: str = "Invalid or expired API key"): super().__init__(status_code=401, detail=detail, error_code="api_key_error") class ServiceUnavailableError(APIError): """服务不可用错误""" + def __init__(self, detail: str = "Service temporarily unavailable"): - super().__init__(status_code=503, detail=detail, error_code="service_unavailable") + super().__init__( + status_code=503, detail=detail, error_code="service_unavailable" + ) def setup_exception_handlers(app: FastAPI) -> None: """ 设置应用程序的异常处理器 - + Args: app: FastAPI应用程序实例 """ + @app.exception_handler(APIError) async def api_error_handler(request: Request, exc: APIError): """处理API错误""" logger.error(f"API Error: {exc.detail} (Code: {exc.error_code})") return JSONResponse( status_code=exc.status_code, - content={ - "error": { - "code": exc.error_code, - "message": exc.detail - } - } + content={"error": {"code": exc.error_code, "message": exc.detail}}, ) - + @app.exception_handler(StarletteHTTPException) async def http_exception_handler(request: Request, exc: StarletteHTTPException): """处理HTTP异常""" logger.error(f"HTTP Exception: {exc.detail} (Status: {exc.status_code})") return JSONResponse( status_code=exc.status_code, - content={ - "error": { - "code": "http_error", - "message": exc.detail - } - } + content={"error": {"code": "http_error", "message": exc.detail}}, ) - + @app.exception_handler(RequestValidationError) - async def validation_exception_handler(request: Request, exc: RequestValidationError): + async def validation_exception_handler( + request: Request, exc: RequestValidationError + ): """处理请求验证错误""" error_details = [] for error in exc.errors(): - error_details.append({ - "loc": error["loc"], - "msg": error["msg"], - "type": error["type"] - }) - + error_details.append( + {"loc": error["loc"], "msg": error["msg"], "type": error["type"]} + ) + logger.error(f"Validation Error: {error_details}") return JSONResponse( status_code=422, @@ -113,11 +120,11 @@ def setup_exception_handlers(app: FastAPI) -> None: "error": { "code": "validation_error", "message": "Request validation failed", - "details": error_details + "details": error_details, } - } + }, ) - + @app.exception_handler(Exception) async def general_exception_handler(request: Request, exc: Exception): """处理通用异常""" @@ -127,7 +134,7 @@ def setup_exception_handlers(app: FastAPI) -> None: content={ "error": { "code": "internal_server_error", - "message": "An unexpected error occurred" + "message": "An unexpected error occurred", } - } + }, ) diff --git a/app/handler/retry_handler.py b/app/handler/retry_handler.py index 60b0c83..c56e57e 100644 --- a/app/handler/retry_handler.py +++ b/app/handler/retry_handler.py @@ -1,10 +1,11 @@ # app/services/chat/retry_handler.py -from typing import TypeVar, Callable from functools import wraps -from app.logger.logger import get_retry_logger +from typing import Callable, TypeVar -T = TypeVar('T') +from app.log.logger import get_retry_logger + +T = TypeVar("T") logger = get_retry_logger() @@ -25,17 +26,21 @@ class RetryHandler: return await func(*args, **kwargs) except Exception as e: last_exception = e - logger.warning(f"API call failed with error: {str(e)}. Attempt {attempt + 1} of {self.max_retries}") + logger.warning( + f"API call failed with error: {str(e)}. Attempt {attempt + 1} of {self.max_retries}" + ) # 从函数参数中获取 key_manager - key_manager = kwargs.get('key_manager') + key_manager = kwargs.get("key_manager") if key_manager: old_key = kwargs.get(self.key_arg) new_key = await key_manager.handle_api_failure(old_key) kwargs[self.key_arg] = new_key logger.info(f"Switched to new API key: {new_key}") - logger.error(f"All retry attempts failed, raising final exception: {str(last_exception)}") + logger.error( + f"All retry attempts failed, raising final exception: {str(last_exception)}" + ) raise last_exception return wrapper diff --git a/app/handler/stream_optimizer.py b/app/handler/stream_optimizer.py index caeb262..2356da3 100644 --- a/app/handler/stream_optimizer.py +++ b/app/handler/stream_optimizer.py @@ -2,10 +2,17 @@ import asyncio import math -from typing import Any, List, AsyncGenerator, Callable -from app.logger.logger import get_openai_logger, get_gemini_logger +from typing import Any, AsyncGenerator, Callable, List + from app.config.config import settings -from app.core.constants import DEFAULT_STREAM_CHUNK_SIZE, DEFAULT_STREAM_LONG_TEXT_THRESHOLD, DEFAULT_STREAM_MAX_DELAY, DEFAULT_STREAM_MIN_DELAY, DEFAULT_STREAM_SHORT_TEXT_THRESHOLD +from app.core.constants import ( + DEFAULT_STREAM_CHUNK_SIZE, + DEFAULT_STREAM_LONG_TEXT_THRESHOLD, + DEFAULT_STREAM_MAX_DELAY, + DEFAULT_STREAM_MIN_DELAY, + DEFAULT_STREAM_SHORT_TEXT_THRESHOLD, +) +from app.log.logger import get_gemini_logger, get_openai_logger logger_openai = get_openai_logger() logger_gemini = get_gemini_logger() @@ -13,19 +20,21 @@ logger_gemini = get_gemini_logger() class StreamOptimizer: """流式输出优化器 - + 提供流式输出优化功能,包括智能延迟调整和长文本分块输出。 """ - - def __init__(self, - logger=None, - min_delay: float = DEFAULT_STREAM_MIN_DELAY, - max_delay: float = DEFAULT_STREAM_MAX_DELAY, - short_text_threshold: int = DEFAULT_STREAM_SHORT_TEXT_THRESHOLD, - long_text_threshold: int = DEFAULT_STREAM_LONG_TEXT_THRESHOLD, - chunk_size: int = DEFAULT_STREAM_CHUNK_SIZE): + + def __init__( + self, + logger=None, + min_delay: float = DEFAULT_STREAM_MIN_DELAY, + max_delay: float = DEFAULT_STREAM_MAX_DELAY, + short_text_threshold: int = DEFAULT_STREAM_SHORT_TEXT_THRESHOLD, + long_text_threshold: int = DEFAULT_STREAM_LONG_TEXT_THRESHOLD, + chunk_size: int = DEFAULT_STREAM_CHUNK_SIZE, + ): """初始化流式输出优化器 - + 参数: logger: 日志记录器 min_delay: 最小延迟时间(秒) @@ -40,13 +49,13 @@ class StreamOptimizer: self.short_text_threshold = short_text_threshold self.long_text_threshold = long_text_threshold self.chunk_size = chunk_size - + def calculate_delay(self, text_length: int) -> float: """根据文本长度计算延迟时间 - + 参数: text_length: 文本长度 - + 返回: 延迟时间(秒) """ @@ -59,42 +68,48 @@ class StreamOptimizer: else: # 中等长度文本使用线性插值计算延迟 # 使用对数函数使延迟变化更平滑 - ratio = math.log(text_length / self.short_text_threshold) / math.log(self.long_text_threshold / self.short_text_threshold) + ratio = math.log(text_length / self.short_text_threshold) / math.log( + self.long_text_threshold / self.short_text_threshold + ) return self.max_delay - ratio * (self.max_delay - self.min_delay) - + def split_text_into_chunks(self, text: str) -> List[str]: """将文本分割成小块 - + 参数: text: 要分割的文本 - + 返回: 文本块列表 """ - return [text[i:i+self.chunk_size] for i in range(0, len(text), self.chunk_size)] - - async def optimize_stream_output(self, - text: str, - create_response_chunk: Callable[[str], Any], - format_chunk: Callable[[Any], str]) -> AsyncGenerator[str, None]: + return [ + text[i : i + self.chunk_size] for i in range(0, len(text), self.chunk_size) + ] + + async def optimize_stream_output( + self, + text: str, + create_response_chunk: Callable[[str], Any], + format_chunk: Callable[[Any], str], + ) -> AsyncGenerator[str, None]: """优化流式输出 - + 参数: text: 要输出的文本 create_response_chunk: 创建响应块的函数,接收文本,返回响应块 format_chunk: 格式化响应块的函数,接收响应块,返回格式化后的字符串 - + 返回: 异步生成器,生成格式化后的响应块 """ if not text: return - + # 计算智能延迟时间 delay = self.calculate_delay(len(text)) if self.logger: self.logger.info(f"Text length: {len(text)}, delay: {delay:.4f}s") - + # 根据文本长度决定输出方式 if len(text) >= self.long_text_threshold: # 长文本:分块输出 @@ -120,7 +135,7 @@ openai_optimizer = StreamOptimizer( max_delay=settings.STREAM_MAX_DELAY, short_text_threshold=settings.STREAM_SHORT_TEXT_THRESHOLD, long_text_threshold=settings.STREAM_LONG_TEXT_THRESHOLD, - chunk_size=settings.STREAM_CHUNK_SIZE + chunk_size=settings.STREAM_CHUNK_SIZE, ) gemini_optimizer = StreamOptimizer( @@ -129,5 +144,5 @@ gemini_optimizer = StreamOptimizer( max_delay=settings.STREAM_MAX_DELAY, short_text_threshold=settings.STREAM_SHORT_TEXT_THRESHOLD, long_text_threshold=settings.STREAM_LONG_TEXT_THRESHOLD, - chunk_size=settings.STREAM_CHUNK_SIZE + chunk_size=settings.STREAM_CHUNK_SIZE, ) diff --git a/app/logger/logger.py b/app/log/logger.py similarity index 88% rename from app/logger/logger.py rename to app/log/logger.py index d22607e..35e2fa5 100644 --- a/app/logger/logger.py +++ b/app/log/logger.py @@ -133,3 +133,23 @@ def get_retry_logger(): def get_image_create_logger(): return Logger.setup_logger("image_create") + + +def get_exceptions_logger(): + return Logger.setup_logger("exceptions") + + +def get_application_logger(): + return Logger.setup_logger("application") + + +def get_initialization_logger(): + return Logger.setup_logger("initialization") + + +def get_middleware_logger(): + return Logger.setup_logger("middleware") + + +def get_routes_logger(): + return Logger.setup_logger("routes") \ No newline at end of file diff --git a/app/main.py b/app/main.py index c9acbdd..1d1f6d1 100644 --- a/app/main.py +++ b/app/main.py @@ -1,9 +1,11 @@ """ 应用程序入口模块 """ + import uvicorn + from app.core.application import create_app -from app.logger.logger import get_main_logger +from app.log.logger import get_main_logger # 创建应用程序实例 app = create_app() diff --git a/app/middleware/middleware.py b/app/middleware/middleware.py index dc39576..e187223 100644 --- a/app/middleware/middleware.py +++ b/app/middleware/middleware.py @@ -1,60 +1,72 @@ """ 中间件配置模块,负责设置和配置应用程序的中间件 """ + from fastapi import FastAPI, Request from fastapi.middleware.cors import CORSMiddleware from fastapi.responses import RedirectResponse from starlette.middleware.base import BaseHTTPMiddleware -from app.logger.logger import get_main_logger -from app.core.security import verify_auth_token # from app.middleware.request_logging_middleware import RequestLoggingMiddleware from app.core.constants import API_VERSION +from app.core.security import verify_auth_token +from app.log.logger import get_middleware_logger + +logger = get_middleware_logger() -logger = get_main_logger() class AuthMiddleware(BaseHTTPMiddleware): """ 认证中间件,处理未经身份验证的请求 """ + async def dispatch(self, request: Request, call_next): # 允许特定路径绕过身份验证 - if (request.url.path not in ["/", "/auth"] and - not request.url.path.startswith("/static") and - not request.url.path.startswith("/gemini") and - not request.url.path.startswith("/v1") and - not request.url.path.startswith(f"/{API_VERSION}") and - not request.url.path.startswith("/health") and - not request.url.path.startswith("/hf")): - + if ( + request.url.path not in ["/", "/auth"] + and not request.url.path.startswith("/static") + and not request.url.path.startswith("/gemini") + and not request.url.path.startswith("/v1") + and not request.url.path.startswith(f"/{API_VERSION}") + and not request.url.path.startswith("/health") + and not request.url.path.startswith("/hf") + ): + auth_token = request.cookies.get("auth_token") if not auth_token or not verify_auth_token(auth_token): logger.warning(f"Unauthorized access attempt to {request.url.path}") return RedirectResponse(url="/") logger.debug("Request authenticated successfully") - + response = await call_next(request) return response + def setup_middlewares(app: FastAPI) -> None: """ 设置应用程序的中间件 - + Args: app: FastAPI应用程序实例 """ # 添加认证中间件 app.add_middleware(AuthMiddleware) - + # 添加请求日志中间件(可选,默认注释掉) # app.add_middleware(RequestLoggingMiddleware) - + # 配置CORS中间件 app.add_middleware( CORSMiddleware, allow_origins=["*"], # 生产环境建议配置具体的域名 allow_credentials=True, - allow_methods=["GET", "POST", "PUT", "DELETE", "OPTIONS"], # 明确指定允许的HTTP方法 + allow_methods=[ + "GET", + "POST", + "PUT", + "DELETE", + "OPTIONS", + ], # 明确指定允许的HTTP方法 allow_headers=["*"], # 生产环境建议配置具体的请求头 expose_headers=["*"], # 允许前端访问的响应头 max_age=600, # 预检请求缓存时间(秒) diff --git a/app/middleware/request_logging_middleware.py b/app/middleware/request_logging_middleware.py index 814514e..c3d516b 100644 --- a/app/middleware/request_logging_middleware.py +++ b/app/middleware/request_logging_middleware.py @@ -1,7 +1,9 @@ +import json + from fastapi import Request from starlette.middleware.base import BaseHTTPMiddleware -import json -from app.logger.logger import get_request_logger + +from app.log.logger import get_request_logger logger = get_request_logger() @@ -20,7 +22,9 @@ class RequestLoggingMiddleware(BaseHTTPMiddleware): # 尝试格式化JSON try: formatted_body = json.loads(body_str) - logger.info(f"Formatted request body:\n{json.dumps(formatted_body, indent=2, ensure_ascii=False)}") + logger.info( + f"Formatted request body:\n{json.dumps(formatted_body, indent=2, ensure_ascii=False)}" + ) except json.JSONDecodeError: logger.info("Request body is not valid JSON.") except Exception as e: diff --git a/app/router/gemini_routes.py b/app/router/gemini_routes.py index d10b517..06d6700 100644 --- a/app/router/gemini_routes.py +++ b/app/router/gemini_routes.py @@ -2,7 +2,7 @@ from fastapi import APIRouter, Depends, HTTPException from fastapi.responses import StreamingResponse, JSONResponse from copy import deepcopy from app.config.config import settings -from app.logger.logger import get_gemini_logger +from app.log.logger import get_gemini_logger from app.core.security import SecurityService from app.domain.gemini_models import GeminiContent, GeminiRequest from app.service.chat.gemini_chat_service import GeminiChatService diff --git a/app/router/openai_routes.py b/app/router/openai_routes.py index 62211f9..a99dc34 100644 --- a/app/router/openai_routes.py +++ b/app/router/openai_routes.py @@ -1,37 +1,46 @@ -from fastapi import HTTPException, APIRouter, Depends +from fastapi import APIRouter, Depends, HTTPException from fastapi.responses import StreamingResponse from app.config.config import settings -from app.logger.logger import get_openai_logger from app.core.security import SecurityService -from app.domain.openai_models import ChatRequest, EmbeddingRequest, ImageGenerationRequest +from app.domain.openai_models import ( + ChatRequest, + EmbeddingRequest, + ImageGenerationRequest, +) from app.handler.retry_handler import RetryHandler +from app.log.logger import get_openai_logger +from app.service.chat.openai_chat_service import OpenAIChatService from app.service.embedding.embedding_service import EmbeddingService from app.service.image.image_create_service import ImageCreateService from app.service.key.key_manager import KeyManager, get_key_manager_instance from app.service.model.model_service import ModelService -from app.service.chat.openai_chat_service import OpenAIChatService router = APIRouter() logger = get_openai_logger() # 初始化服务 security_service = SecurityService(settings.ALLOWED_TOKENS, settings.AUTH_TOKEN) -model_service = ModelService(settings.SEARCH_MODELS,settings.IMAGE_MODELS) +model_service = ModelService(settings.SEARCH_MODELS, settings.IMAGE_MODELS) embedding_service = EmbeddingService(settings.BASE_URL) image_create_service = ImageCreateService() + async def get_key_manager(): return await get_key_manager_instance() -async def get_next_working_key_wrapper(key_manager: KeyManager = Depends(get_key_manager)): + +async def get_next_working_key_wrapper( + key_manager: KeyManager = Depends(get_key_manager), +): return await key_manager.get_next_working_key() + @router.get("/v1/models") @router.get("/hf/v1/models") async def list_models( _=Depends(security_service.verify_authorization), - key_manager: KeyManager = Depends(get_key_manager) + key_manager: KeyManager = Depends(get_key_manager), ): logger.info("-" * 50 + "list_models" + "-" * 50) logger.info("Handling models list request") @@ -41,7 +50,9 @@ async def list_models( return model_service.get_gemini_openai_models(api_key) except Exception as e: logger.error(f"Error getting models list: {str(e)}") - raise HTTPException(status_code=500, detail="Internal server error while fetching models list") from e + raise HTTPException( + status_code=500, detail="Internal server error while fetching models list" + ) from e @router.post("/v1/chat/completions") @@ -51,7 +62,7 @@ async def chat_completion( request: ChatRequest, _=Depends(security_service.verify_authorization), api_key: str = Depends(get_next_working_key_wrapper), - key_manager: KeyManager = Depends(get_key_manager) + key_manager: KeyManager = Depends(get_key_manager), ): # 如果model是imagen3,使用paid_key if request.model == f"{settings.CREATE_IMAGE_MODEL}-chat": @@ -63,8 +74,10 @@ async def chat_completion( logger.info(f"Using API key: {api_key}") if not model_service.check_model_support(request.model): - raise HTTPException(status_code=400, detail=f"Model {request.model} is not supported") - + raise HTTPException( + status_code=400, detail=f"Model {request.model} is not supported" + ) + try: # 如果model是imagen3,使用paid_key if request.model == f"{settings.CREATE_IMAGE_MODEL}-chat": @@ -80,6 +93,7 @@ async def chat_completion( logger.error(f"Chat completion failed after retries: {str(e)}") raise HTTPException(status_code=500, detail="Chat completion failed") from e + @router.post("/v1/images/generations") @router.post("/hf/v1/images/generations") async def generate_image( @@ -95,14 +109,17 @@ async def generate_image( return response except Exception as e: logger.error(f"Image generation request failed: {str(e)}") - raise HTTPException(status_code=500, detail="Image generation request failed") from e + raise HTTPException( + status_code=500, detail="Image generation request failed" + ) from e + @router.post("/v1/embeddings") @router.post("/hf/v1/embeddings") async def embedding( request: EmbeddingRequest, _=Depends(security_service.verify_authorization), - key_manager: KeyManager = Depends(get_key_manager) + key_manager: KeyManager = Depends(get_key_manager), ): logger.info("-" * 50 + "embedding" + "-" * 50) logger.info(f"Handling embedding request for model: {request.model}") @@ -118,11 +135,12 @@ async def embedding( logger.error(f"Embedding request failed: {str(e)}") raise HTTPException(status_code=500, detail="Embedding request failed") from e + @router.get("/v1/keys/list") @router.get("/hf/v1/keys/list") async def get_keys_list( _=Depends(security_service.verify_auth_token), - key_manager: KeyManager = Depends(get_key_manager) + key_manager: KeyManager = Depends(get_key_manager), ): """获取有效和无效的API key列表""" logger.info("-" * 50 + "get_keys_list" + "-" * 50) @@ -133,13 +151,12 @@ async def get_keys_list( "status": "success", "data": { "valid_keys": keys_status["valid_keys"], - "invalid_keys": keys_status["invalid_keys"] + "invalid_keys": keys_status["invalid_keys"], }, - "total": len(keys_status["valid_keys"]) + len(keys_status["invalid_keys"]) + "total": len(keys_status["valid_keys"]) + len(keys_status["invalid_keys"]), } except Exception as e: logger.error(f"Error getting keys list: {str(e)}") raise HTTPException( - status_code=500, - detail="Internal server error while fetching keys list" + status_code=500, detail="Internal server error while fetching keys list" ) from e diff --git a/app/router/routers.py b/app/router/routes.py similarity index 85% rename from app/router/routers.py rename to app/router/routes.py index 333a644..d842b08 100644 --- a/app/router/routers.py +++ b/app/router/routes.py @@ -1,24 +1,26 @@ """ 路由配置模块,负责设置和配置应用程序的路由 """ + from fastapi import FastAPI, Request from fastapi.responses import HTMLResponse, RedirectResponse from fastapi.templating import Jinja2Templates -from app.logger.logger import get_main_logger from app.core.security import verify_auth_token +from app.log.logger import get_routes_logger from app.router import gemini_routes, openai_routes from app.service.key.key_manager import get_key_manager_instance -logger = get_main_logger() +logger = get_routes_logger() # 配置Jinja2模板 templates = Jinja2Templates(directory="app/templates") + def setup_routers(app: FastAPI) -> None: """ 设置应用程序的路由 - + Args: app: FastAPI应用程序实例 """ @@ -26,20 +28,22 @@ def setup_routers(app: FastAPI) -> None: app.include_router(openai_routes.router) app.include_router(gemini_routes.router) app.include_router(gemini_routes.router_v1beta) - + # 添加页面路由 setup_page_routes(app) - + # 添加健康检查路由 setup_health_routes(app) + def setup_page_routes(app: FastAPI) -> None: """ 设置页面相关的路由 - + Args: app: FastAPI应用程序实例 """ + @app.get("/", response_class=HTMLResponse) async def auth_page(request: Request): """认证页面""" @@ -54,11 +58,13 @@ def setup_page_routes(app: FastAPI) -> None: if not auth_token: logger.warning("Authentication attempt with empty token") return RedirectResponse(url="/", status_code=302) - + if verify_auth_token(auth_token): logger.info("Successful authentication") response = RedirectResponse(url="/keys", status_code=302) - response.set_cookie(key="auth_token", value=auth_token, httponly=True, max_age=3600) + response.set_cookie( + key="auth_token", value=auth_token, httponly=True, max_age=3600 + ) return response logger.warning("Failed authentication attempt with invalid token") return RedirectResponse(url="/", status_code=302) @@ -74,28 +80,33 @@ def setup_page_routes(app: FastAPI) -> None: if not auth_token or not verify_auth_token(auth_token): logger.warning("Unauthorized access attempt to keys page") return RedirectResponse(url="/", status_code=302) - + key_manager = await get_key_manager_instance() keys_status = await key_manager.get_keys_by_status() total = len(keys_status["valid_keys"]) + len(keys_status["invalid_keys"]) logger.info(f"Keys status retrieved successfully. Total keys: {total}") - return templates.TemplateResponse("keys_status.html", { - "request": request, - "valid_keys": keys_status["valid_keys"], - "invalid_keys": keys_status["invalid_keys"], - "total": total - }) + return templates.TemplateResponse( + "keys_status.html", + { + "request": request, + "valid_keys": keys_status["valid_keys"], + "invalid_keys": keys_status["invalid_keys"], + "total": total, + }, + ) except Exception as e: logger.error(f"Error retrieving keys status: {str(e)}") raise + def setup_health_routes(app: FastAPI) -> None: """ 设置健康检查相关的路由 - + Args: app: FastAPI应用程序实例 """ + @app.get("/health") async def health_check(request: Request): """健康检查端点""" diff --git a/app/service/chat/gemini_chat_service.py b/app/service/chat/gemini_chat_service.py index 653264e..61c4b30 100644 --- a/app/service/chat/gemini_chat_service.py +++ b/app/service/chat/gemini_chat_service.py @@ -1,13 +1,14 @@ # app/services/chat_service.py import json -from typing import Dict, Any, AsyncGenerator, List -from app.logger.logger import get_gemini_logger -from app.service.client.api_client import GeminiApiClient -from app.handler.stream_optimizer import gemini_optimizer -from app.domain.gemini_models import GeminiRequest +from typing import Any, AsyncGenerator, Dict, List + from app.config.config import settings +from app.domain.gemini_models import GeminiRequest from app.handler.response_handler import GeminiResponseHandler +from app.handler.stream_optimizer import gemini_optimizer +from app.log.logger import get_gemini_logger +from app.service.client.api_client import GeminiApiClient from app.service.key.key_manager import KeyManager logger = get_gemini_logger() @@ -26,9 +27,11 @@ def _has_image_parts(contents: List[Dict[str, Any]]) -> bool: def _build_tools(model: str, payload: Dict[str, Any]) -> List[Dict[str, Any]]: """构建工具""" tools = [] - if settings.TOOLS_CODE_EXECUTION_ENABLED and not ( - model.endswith("-search") or "-thinking" in model - ) and not _has_image_parts(payload.get("contents", [])): + if ( + settings.TOOLS_CODE_EXECUTION_ENABLED + and not (model.endswith("-search") or "-thinking" in model) + and not _has_image_parts(payload.get("contents", [])) + ): tools.append({"code_execution": {}}) if model.endswith("-search"): tools.append({"googleSearch": {}}) @@ -49,14 +52,14 @@ def _get_safety_settings(model: str) -> List[Dict[str, str]]: {"category": "HARM_CATEGORY_HATE_SPEECH", "threshold": "OFF"}, {"category": "HARM_CATEGORY_SEXUALLY_EXPLICIT", "threshold": "OFF"}, {"category": "HARM_CATEGORY_DANGEROUS_CONTENT", "threshold": "OFF"}, - {"category": "HARM_CATEGORY_CIVIC_INTEGRITY", "threshold": "OFF"} + {"category": "HARM_CATEGORY_CIVIC_INTEGRITY", "threshold": "OFF"}, ] return [ {"category": "HARM_CATEGORY_HARASSMENT", "threshold": "BLOCK_NONE"}, {"category": "HARM_CATEGORY_HATE_SPEECH", "threshold": "BLOCK_NONE"}, {"category": "HARM_CATEGORY_SEXUALLY_EXPLICIT", "threshold": "BLOCK_NONE"}, {"category": "HARM_CATEGORY_DANGEROUS_CONTENT", "threshold": "BLOCK_NONE"}, - {"category": "HARM_CATEGORY_CIVIC_INTEGRITY", "threshold": "BLOCK_NONE"} + {"category": "HARM_CATEGORY_CIVIC_INTEGRITY", "threshold": "BLOCK_NONE"}, ] @@ -68,12 +71,12 @@ def _build_payload(model: str, request: GeminiRequest) -> Dict[str, Any]: "tools": _build_tools(model, request_dict), "safetySettings": _get_safety_settings(model), "generationConfig": request_dict.get("generationConfig", {}), - "systemInstruction": request_dict.get("systemInstruction", "") + "systemInstruction": request_dict.get("systemInstruction", ""), } - + if model.endswith("-image") or model.endswith("-image-generation"): payload.pop("systemInstruction") - payload["generationConfig"]["responseModalities"] = ["Text","Image"] + payload["generationConfig"]["responseModalities"] = ["Text", "Image"] return payload @@ -84,54 +87,68 @@ class GeminiChatService: self.api_client = GeminiApiClient(base_url) self.key_manager = key_manager self.response_handler = GeminiResponseHandler() - + def _extract_text_from_response(self, response: Dict[str, Any]) -> str: """从响应中提取文本内容""" if not response.get("candidates"): return "" - + candidate = response["candidates"][0] content = candidate.get("content", {}) parts = content.get("parts", []) - + if parts and "text" in parts[0]: return parts[0].get("text", "") return "" - - def _create_char_response(self, original_response: Dict[str, Any], text: str) -> Dict[str, Any]: + + def _create_char_response( + self, original_response: Dict[str, Any], text: str + ) -> Dict[str, Any]: """创建包含指定文本的响应""" response_copy = json.loads(json.dumps(original_response)) # 深拷贝 - if response_copy.get("candidates") and response_copy["candidates"][0].get("content", {}).get("parts"): + if response_copy.get("candidates") and response_copy["candidates"][0].get( + "content", {} + ).get("parts"): response_copy["candidates"][0]["content"]["parts"][0]["text"] = text return response_copy - async def generate_content(self, model: str, request: GeminiRequest, api_key: str) -> Dict[str, Any]: + async def generate_content( + self, model: str, request: GeminiRequest, api_key: str + ) -> Dict[str, Any]: """生成内容""" payload = _build_payload(model, request) response = await self.api_client.generate_content(payload, model, api_key) return self.response_handler.handle_response(response, model, stream=False) - async def stream_generate_content(self, model: str, request: GeminiRequest, api_key: str) -> AsyncGenerator[str, None]: + async def stream_generate_content( + self, model: str, request: GeminiRequest, api_key: str + ) -> AsyncGenerator[str, None]: """流式生成内容""" retries = 0 max_retries = 3 payload = _build_payload(model, request) while retries < max_retries: try: - async for line in self.api_client.stream_generate_content(payload, model, api_key): + async for line in self.api_client.stream_generate_content( + payload, model, api_key + ): # print(line) if line.startswith("data:"): line = line[6:] - response_data = self.response_handler.handle_response(json.loads(line), model, stream=True) + response_data = self.response_handler.handle_response( + json.loads(line), model, stream=True + ) text = self._extract_text_from_response(response_data) - + # 如果有文本内容,使用流式输出优化器处理 if text: # 使用流式输出优化器处理文本输出 - async for optimized_chunk in gemini_optimizer.optimize_stream_output( + async for ( + optimized_chunk + ) in gemini_optimizer.optimize_stream_output( text, lambda t: self._create_char_response(response_data, t), - lambda c: "data: " + json.dumps(c) + "\n\n" + lambda c: "data: " + json.dumps(c) + "\n\n", ): yield optimized_chunk else: @@ -141,9 +158,13 @@ class GeminiChatService: break except Exception as e: retries += 1 - logger.warning(f"Streaming API call failed with error: {str(e)}. Attempt {retries} of {max_retries}") + logger.warning( + f"Streaming API call failed with error: {str(e)}. Attempt {retries} of {max_retries}" + ) api_key = await self.key_manager.handle_api_failure(api_key) logger.info(f"Switched to new API key: {api_key}") if retries >= max_retries: - logger.error(f"Max retries ({max_retries}) reached for streaming. Raising error") + logger.error( + f"Max retries ({max_retries}) reached for streaming. Raising error" + ) break diff --git a/app/service/chat/openai_chat_service.py b/app/service/chat/openai_chat_service.py index 5161796..4727ea5 100644 --- a/app/service/chat/openai_chat_service.py +++ b/app/service/chat/openai_chat_service.py @@ -1,15 +1,16 @@ # app/services/chat_service.py -from copy import deepcopy import json -from typing import Dict, Any, AsyncGenerator, List, Optional, Union -from app.logger.logger import get_openai_logger +from copy import deepcopy +from typing import Any, AsyncGenerator, Dict, List, Optional, Union + +from app.config.config import settings +from app.domain.openai_models import ChatRequest, ImageGenerationRequest from app.handler.message_converter import OpenAIMessageConverter from app.handler.response_handler import OpenAIResponseHandler -from app.service.client.api_client import GeminiApiClient from app.handler.stream_optimizer import openai_optimizer -from app.domain.openai_models import ChatRequest, ImageGenerationRequest -from app.config.config import settings +from app.log.logger import get_openai_logger +from app.service.client.api_client import GeminiApiClient from app.service.image.image_create_service import ImageCreateService from app.service.key.key_manager import KeyManager @@ -27,16 +28,21 @@ def _has_image_parts(contents: List[Dict[str, Any]]) -> bool: def _build_tools( - request: ChatRequest, messages: List[Dict[str, Any]] + request: ChatRequest, messages: List[Dict[str, Any]] ) -> List[Dict[str, Any]]: """构建工具""" tools = [] model = request.model if ( - settings.TOOLS_CODE_EXECUTION_ENABLED - and not (model.endswith("-search") or "-thinking" in model or model.endswith("-image") or model.endswith("-image-generation")) - and not _has_image_parts(messages) + settings.TOOLS_CODE_EXECUTION_ENABLED + and not ( + model.endswith("-search") + or "-thinking" in model + or model.endswith("-image") + or model.endswith("-image-generation") + ) + and not _has_image_parts(messages) ): tools.append({"code_execution": {}}) if model.endswith("-search"): @@ -52,7 +58,9 @@ def _build_tools( if tool.get("type", "") == "function" and tool.get("function"): function = deepcopy(tool.get("function")) parameters = function.get("parameters", {}) - if parameters.get("type") == "object" and not parameters.get("properties", {}): + if parameters.get("type") == "object" and not parameters.get( + "properties", {} + ): function.pop("parameters", None) function_declarations.append(function) @@ -66,7 +74,7 @@ def _build_tools( functions.append(item) tools.append({"functionDeclarations": functions}) - + return tools @@ -95,7 +103,9 @@ def _get_safety_settings(model: str) -> List[Dict[str, str]]: def _build_payload( - request: ChatRequest, messages: List[Dict[str, Any]], instruction: Optional[Dict[str, Any]] = None + request: ChatRequest, + messages: List[Dict[str, Any]], + instruction: Optional[Dict[str, Any]] = None, ) -> Dict[str, Any]: """构建请求payload""" payload = { @@ -111,8 +121,8 @@ def _build_payload( "safetySettings": _get_safety_settings(request.model), } if request.model.endswith("-image") or request.model.endswith("-image-generation"): - payload["generationConfig"]["responseModalities"] = ["Text","Image"] - + payload["generationConfig"]["responseModalities"] = ["Text", "Image"] + if ( instruction and isinstance(instruction, dict) @@ -128,24 +138,27 @@ def _build_payload( class OpenAIChatService: """聊天服务""" + def __init__(self, base_url: str, key_manager: KeyManager = None): self.message_converter = OpenAIMessageConverter() self.response_handler = OpenAIResponseHandler(config=None) self.api_client = GeminiApiClient(base_url) self.key_manager = key_manager self.image_create_service = ImageCreateService() - + def _extract_text_from_openai_chunk(self, chunk: Dict[str, Any]) -> str: """从OpenAI响应块中提取文本内容""" if not chunk.get("choices"): return "" - + choice = chunk["choices"][0] if "delta" in choice and "content" in choice["delta"]: return choice["delta"]["content"] return "" - - def _create_char_openai_chunk(self, original_chunk: Dict[str, Any], text: str) -> Dict[str, Any]: + + def _create_char_openai_chunk( + self, original_chunk: Dict[str, Any], text: str + ) -> Dict[str, Any]: """创建包含指定文本的OpenAI响应块""" chunk_copy = json.loads(json.dumps(original_chunk)) # 深拷贝 if chunk_copy.get("choices") and "delta" in chunk_copy["choices"][0]: @@ -153,9 +166,9 @@ class OpenAIChatService: return chunk_copy async def create_chat_completion( - self, - request: ChatRequest, - api_key: str, + self, + request: ChatRequest, + api_key: str, ) -> Union[Dict[str, Any], AsyncGenerator[str, None]]: """创建聊天完成""" # 转换消息格式 @@ -169,7 +182,7 @@ class OpenAIChatService: return await self._handle_normal_completion(request.model, payload, api_key) async def _handle_normal_completion( - self, model: str, payload: Dict[str, Any], api_key: str + self, model: str, payload: Dict[str, Any], api_key: str ) -> Dict[str, Any]: """处理普通聊天完成""" response = await self.api_client.generate_content(payload, model, api_key) @@ -178,7 +191,7 @@ class OpenAIChatService: ) async def _handle_stream_completion( - self, model: str, payload: Dict[str, Any], api_key: str + self, model: str, payload: Dict[str, Any], api_key: str ) -> AsyncGenerator[str, None]: """处理流式聊天完成,添加重试逻辑""" retries = 0 @@ -186,7 +199,7 @@ class OpenAIChatService: while retries < max_retries: try: async for line in self.api_client.stream_generate_content( - payload, model, api_key + payload, model, api_key ): # print(line) if line.startswith("data:"): @@ -199,10 +212,14 @@ class OpenAIChatService: text = self._extract_text_from_openai_chunk(openai_chunk) if text: # 使用流式输出优化器处理文本输出 - async for optimized_chunk in openai_optimizer.optimize_stream_output( + async for ( + optimized_chunk + ) in openai_optimizer.optimize_stream_output( text, - lambda t: self._create_char_openai_chunk(openai_chunk, t), - lambda c: f"data: {json.dumps(c)}\n\n" + lambda t: self._create_char_openai_chunk( + openai_chunk, t + ), + lambda c: f"data: {json.dumps(c)}\n\n", ): yield optimized_chunk else: @@ -228,21 +245,23 @@ class OpenAIChatService: break async def create_image_chat_completion( - self, - request: ChatRequest, + self, + request: ChatRequest, ) -> Union[Dict[str, Any], AsyncGenerator[str, None]]: - + image_generate_request = ImageGenerationRequest() image_generate_request.prompt = request.messages[-1]["content"] - image_res = self.image_create_service.generate_images_chat(image_generate_request) - + image_res = self.image_create_service.generate_images_chat( + image_generate_request + ) + if request.stream: - return self._handle_stream_image_completion(request.model,image_res) + return self._handle_stream_image_completion(request.model, image_res) else: - return self._handle_normal_image_completion(request.model,image_res) - + return self._handle_normal_image_completion(request.model, image_res) + async def _handle_stream_image_completion( - self, model: str, image_data: str + self, model: str, image_data: str ) -> AsyncGenerator[str, None]: if image_data: openai_chunk = self.response_handler.handle_image_chat_response( @@ -253,10 +272,12 @@ class OpenAIChatService: text = self._extract_text_from_openai_chunk(openai_chunk) if text: # 使用流式输出优化器处理文本输出 - async for optimized_chunk in openai_optimizer.optimize_stream_output( + async for ( + optimized_chunk + ) in openai_optimizer.optimize_stream_output( text, lambda t: self._create_char_openai_chunk(openai_chunk, t), - lambda c: f"data: {json.dumps(c)}\n\n" + lambda c: f"data: {json.dumps(c)}\n\n", ): yield optimized_chunk else: @@ -265,11 +286,11 @@ class OpenAIChatService: yield f"data: {json.dumps(self.response_handler.handle_response({}, model, stream=True, finish_reason='stop'))}\n\n" yield "data: [DONE]\n\n" logger.info("Image chat streaming completed successfully") - + def _handle_normal_image_completion( - self, model: str, image_data: str + self, model: str, image_data: str ) -> Dict[str, Any]: - + return self.response_handler.handle_image_chat_response( image_data, model, stream=False, finish_reason="stop" ) diff --git a/app/service/embedding/embedding_service.py b/app/service/embedding/embedding_service.py index a76079e..6823099 100644 --- a/app/service/embedding/embedding_service.py +++ b/app/service/embedding/embedding_service.py @@ -1,9 +1,9 @@ -from typing import Union, List +from typing import List, Union import openai from openai.types import CreateEmbeddingResponse -from app.logger.logger import get_embeddings_logger +from app.log.logger import get_embeddings_logger logger = get_embeddings_logger() diff --git a/app/service/image/image_create_service.py b/app/service/image/image_create_service.py index e70291b..92efcd9 100644 --- a/app/service/image/image_create_service.py +++ b/app/service/image/image_create_service.py @@ -1,15 +1,15 @@ +import base64 import time import uuid from google import genai from google.genai import types -import base64 from app.config.config import settings -from app.logger.logger import get_image_create_logger -from app.utils.uploader import ImageUploaderFactory -from app.domain.openai_models import ImageGenerationRequest from app.core.constants import VALID_IMAGE_RATIOS +from app.domain.openai_models import ImageGenerationRequest +from app.log.logger import get_image_create_logger +from app.utils.uploader import ImageUploaderFactory logger = get_image_create_logger() @@ -27,34 +27,34 @@ class ImageCreateService: - {ratio:比例} 例如: {ratio:16:9} 使用16:9比例 """ import re - + # 默认值 n = 1 aspect_ratio = self.aspect_ratio - + # 解析n参数 - n_match = re.search(r'{n:(\d+)}', prompt) + n_match = re.search(r"{n:(\d+)}", prompt) if n_match: n = int(n_match.group(1)) if n < 1 or n > 4: raise ValueError(f"Invalid n value: {n}. Must be between 1 and 4.") - prompt = prompt.replace(n_match.group(0), '').strip() - - # 解析ratio参数 - ratio_match = re.search(r'{ratio:(\d+:\d+)}', prompt) + prompt = prompt.replace(n_match.group(0), "").strip() + + # 解析ratio参数 + ratio_match = re.search(r"{ratio:(\d+:\d+)}", prompt) if ratio_match: aspect_ratio = ratio_match.group(1) if aspect_ratio not in VALID_IMAGE_RATIOS: raise ValueError( f"Invalid ratio: {aspect_ratio}. Must be one of: {', '.join(VALID_IMAGE_RATIOS)}" ) - prompt = prompt.replace(ratio_match.group(0), '').strip() - + prompt = prompt.replace(ratio_match.group(0), "").strip() + return prompt, n, aspect_ratio def generate_images(self, request: ImageGenerationRequest): client = genai.Client(api_key=self.paid_key) - + if request.size == "1024x1024": self.aspect_ratio = "1:1" elif request.size == "1792x1024": @@ -67,13 +67,15 @@ class ImageCreateService: ) # 解析prompt中的参数 - cleaned_prompt, prompt_n, prompt_ratio = self.parse_prompt_parameters(request.prompt) + cleaned_prompt, prompt_n, prompt_ratio = self.parse_prompt_parameters( + request.prompt + ) request.prompt = cleaned_prompt - + # 如果prompt中指定了n,则覆盖请求中的n if prompt_n > 1: request.n = prompt_n - + # 如果prompt中指定了ratio,则覆盖默认的aspect_ratio if prompt_ratio != self.aspect_ratio: self.aspect_ratio = prompt_ratio @@ -96,46 +98,49 @@ class ImageCreateService: for index, generated_image in enumerate(response.generated_images): image_data = generated_image.image.image_bytes image_uploader = None - + if request.response_format == "b64_json": - base64_image = base64.b64encode(image_data).decode('utf-8') - images_data.append({ - "b64_json": base64_image, - "revised_prompt": request.prompt - }) + base64_image = base64.b64encode(image_data).decode("utf-8") + images_data.append( + {"b64_json": base64_image, "revised_prompt": request.prompt} + ) else: current_date = time.strftime("%Y/%m/%d") filename = f"{current_date}/{uuid.uuid4().hex[:8]}.png" - + if settings.UPLOAD_PROVIDER == "smms": image_uploader = ImageUploaderFactory.create( provider=settings.UPLOAD_PROVIDER, - api_key=settings.SMMS_SECRET_TOKEN + api_key=settings.SMMS_SECRET_TOKEN, ) elif settings.UPLOAD_PROVIDER == "picgo": image_uploader = ImageUploaderFactory.create( provider=settings.UPLOAD_PROVIDER, - api_key=settings.PICGO_API_KEY + api_key=settings.PICGO_API_KEY, ) elif settings.UPLOAD_PROVIDER == "cloudflare_imgbed": image_uploader = ImageUploaderFactory.create( provider=settings.UPLOAD_PROVIDER, base_url=settings.CLOUDFLARE_IMGBED_URL, - auth_code=settings.CLOUDFLARE_IMGBED_AUTH_CODE + auth_code=settings.CLOUDFLARE_IMGBED_AUTH_CODE, ) else: - raise ValueError(f"Unsupported upload provider: {settings.UPLOAD_PROVIDER}") - + raise ValueError( + f"Unsupported upload provider: {settings.UPLOAD_PROVIDER}" + ) + upload_response = image_uploader.upload(image_data, filename) - images_data.append({ - "url": f"{upload_response.data.url}", - "revised_prompt": request.prompt - }) + images_data.append( + { + "url": f"{upload_response.data.url}", + "revised_prompt": request.prompt, + } + ) response_data = { "created": int(time.time()), # Current timestamp - "data": images_data + "data": images_data, } return response_data else: @@ -147,9 +152,13 @@ class ImageCreateService: if image_datas: markdown_images = [] for index, image_data in enumerate(image_datas): - if 'url' in image_data: - markdown_images.append(f"![Generated Image {index+1}]({image_data['url']})") + if "url" in image_data: + markdown_images.append( + f"![Generated Image {index+1}]({image_data['url']})" + ) else: # 如果是base64格式,创建data URL - markdown_images.append(f"![Generated Image {index+1}](data:image/png;base64,{image_data['b64_json']})") + markdown_images.append( + f"![Generated Image {index+1}](data:image/png;base64,{image_data['b64_json']})" + ) return "\n".join(markdown_images) diff --git a/app/service/key/key_manager.py b/app/service/key/key_manager.py index 3f1fadb..e50d3d3 100644 --- a/app/service/key/key_manager.py +++ b/app/service/key/key_manager.py @@ -1,9 +1,9 @@ import asyncio from itertools import cycle from typing import Dict -from app.logger.logger import get_key_manager_logger -from app.config.config import settings +from app.config.config import settings +from app.log.logger import get_key_manager_logger logger = get_key_manager_logger() @@ -20,7 +20,7 @@ class KeyManager: async def get_paid_key(self) -> str: return self.paid_key - + async def get_next_key(self) -> str: """获取下一个API key""" async with self.key_cycle_lock: @@ -70,7 +70,7 @@ class KeyManager: """获取分类后的API key列表,包括失败次数""" valid_keys = {} invalid_keys = {} - + async with self.failure_count_lock: for key in self.api_keys: fail_count = self.key_failure_counts[key] @@ -78,16 +78,14 @@ class KeyManager: valid_keys[key] = fail_count else: invalid_keys[key] = fail_count - - return { - "valid_keys": valid_keys, - "invalid_keys": invalid_keys - } - - + + return {"valid_keys": valid_keys, "invalid_keys": invalid_keys} + + _singleton_instance = None _singleton_lock = asyncio.Lock() + async def get_key_manager_instance(api_keys: list = None) -> KeyManager: """ 获取 KeyManager 单例实例。 diff --git a/app/service/model/model_service.py b/app/service/model/model_service.py index e8b329d..5bcc044 100644 --- a/app/service/model/model_service.py +++ b/app/service/model/model_service.py @@ -1,11 +1,14 @@ -import requests from datetime import datetime, timezone -from typing import Optional, Dict, Any -from app.logger.logger import get_model_logger +from typing import Any, Dict, Optional + +import requests + from app.config.config import settings +from app.log.logger import get_model_logger logger = get_model_logger() + class ModelService: def __init__(self, search_models: list, image_models: list): self.search_models = search_models @@ -20,7 +23,7 @@ class ModelService: response = requests.get(url) if response.status_code == 200: gemini_models = response.json() - + filtered_models_list = [] for model in gemini_models.get("models", []): model_id = model["name"].split("/")[-1] @@ -28,7 +31,7 @@ class ModelService: filtered_models_list.append(model) else: logger.info(f"Filtered out model: {model_id}") - + gemini_models["models"] = filtered_models_list return gemini_models else: @@ -48,7 +51,7 @@ class ModelService: return None def convert_to_openai_models_format( - self, gemini_models: Dict[str, Any] + self, gemini_models: Dict[str, Any] ) -> Dict[str, Any]: openai_format = {"object": "list", "data": [], "success": True}