mirror of
https://github.com/snailyp/gemini-balance.git
synced 2026-05-07 06:03:11 +08:00
refactor: 代码结构优化与常量化
将日志系统从 app/logger/ 移至 app/log/ 目录；将路由配置从 routers.py 重命名为 routes.py；将硬编码配置值移至 constants.py 中的默认常量；统一代码格式和导入排序；优化函数参数对齐方式。
This commit is contained in:
@@ -4,7 +4,7 @@
|
||||
from typing import List
|
||||
from pydantic_settings import BaseSettings
|
||||
|
||||
from app.core.constants import API_VERSION, DEFAULT_MODEL
|
||||
from app.core.constants import API_VERSION, DEFAULT_CREATE_IMAGE_MODEL, DEFAULT_FILTER_MODELS, DEFAULT_MODEL, DEFAULT_STREAM_CHUNK_SIZE, DEFAULT_STREAM_LONG_TEXT_THRESHOLD, DEFAULT_STREAM_MAX_DELAY, DEFAULT_STREAM_MIN_DELAY, DEFAULT_STREAM_SHORT_TEXT_THRESHOLD
|
||||
|
||||
|
||||
class Settings(BaseSettings):
|
||||
@@ -20,20 +20,14 @@ class Settings(BaseSettings):
|
||||
# 模型相关配置
|
||||
SEARCH_MODELS: List[str] = ["gemini-2.0-flash-exp"]
|
||||
IMAGE_MODELS: List[str] = ["gemini-2.0-flash-exp"]
|
||||
FILTERED_MODELS: List[str] = [
|
||||
"gemini-1.0-pro-vision-latest",
|
||||
"gemini-pro-vision",
|
||||
"chat-bison-001",
|
||||
"text-bison-001",
|
||||
"embedding-gecko-001"
|
||||
]
|
||||
FILTERED_MODELS: List[str] = DEFAULT_FILTER_MODELS
|
||||
TOOLS_CODE_EXECUTION_ENABLED: bool = False
|
||||
SHOW_SEARCH_LINK: bool = True
|
||||
SHOW_THINKING_PROCESS: bool = True
|
||||
|
||||
# 图像生成相关配置
|
||||
PAID_KEY: str = ""
|
||||
CREATE_IMAGE_MODEL: str = "imagen-3.0-generate-002"
|
||||
CREATE_IMAGE_MODEL: str = DEFAULT_CREATE_IMAGE_MODEL
|
||||
UPLOAD_PROVIDER: str = "smms"
|
||||
SMMS_SECRET_TOKEN: str = ""
|
||||
PICGO_API_KEY: str = ""
|
||||
@@ -41,11 +35,11 @@ class Settings(BaseSettings):
|
||||
CLOUDFLARE_IMGBED_AUTH_CODE: str = ""
|
||||
|
||||
# 流式输出优化器配置
|
||||
STREAM_MIN_DELAY: float = 0.016
|
||||
STREAM_MAX_DELAY: float = 0.024
|
||||
STREAM_SHORT_TEXT_THRESHOLD: int = 10
|
||||
STREAM_LONG_TEXT_THRESHOLD: int = 50
|
||||
STREAM_CHUNK_SIZE: int = 5
|
||||
STREAM_MIN_DELAY: float = DEFAULT_STREAM_MIN_DELAY
|
||||
STREAM_MAX_DELAY: float = DEFAULT_STREAM_MAX_DELAY
|
||||
STREAM_SHORT_TEXT_THRESHOLD: int = DEFAULT_STREAM_SHORT_TEXT_THRESHOLD
|
||||
STREAM_LONG_TEXT_THRESHOLD: int = DEFAULT_STREAM_LONG_TEXT_THRESHOLD
|
||||
STREAM_CHUNK_SIZE: int = DEFAULT_STREAM_CHUNK_SIZE
|
||||
|
||||
def __init__(self, **kwargs):
|
||||
super().__init__(**kwargs)
|
||||
|
||||
@@ -6,14 +6,14 @@ from fastapi import FastAPI
|
||||
from fastapi.staticfiles import StaticFiles
|
||||
|
||||
from app.config.config import settings
|
||||
from app.logger.logger import get_main_logger
|
||||
from app.log.logger import get_application_logger
|
||||
from app.middleware.middleware import setup_middlewares
|
||||
from app.exception.exceptions import setup_exception_handlers
|
||||
from app.router.routers import setup_routers
|
||||
from app.router.routes import setup_routers
|
||||
from app.service.key.key_manager import get_key_manager_instance
|
||||
from app.core.initialization import initialize_app
|
||||
|
||||
logger = get_main_logger()
|
||||
logger = get_application_logger()
|
||||
|
||||
@asynccontextmanager
|
||||
async def lifespan(app: FastAPI):
|
||||
|
||||
@@ -13,12 +13,21 @@ DEFAULT_TEMPERATURE = 0.7
|
||||
DEFAULT_MAX_TOKENS = 8192
|
||||
DEFAULT_TOP_P = 0.9
|
||||
DEFAULT_TOP_K = 40
|
||||
DEFAULT_FILTER_MODELS = [
|
||||
"gemini-1.0-pro-vision-latest",
|
||||
"gemini-pro-vision",
|
||||
"chat-bison-001",
|
||||
"text-bison-001",
|
||||
"embedding-gecko-001"
|
||||
]
|
||||
DEFAULT_CREATE_IMAGE_MODEL = "imagen-3.0-generate-002"
|
||||
|
||||
# 图像生成相关常量
|
||||
VALID_IMAGE_RATIOS = ["1:1", "3:4", "4:3", "9:16", "16:9"]
|
||||
|
||||
# 上传提供商
|
||||
UPLOAD_PROVIDERS = ["smms", "picgo", "cloudflare_imgbed"]
|
||||
DEFAULT_UPLOAD_PROVIDER = "smms"
|
||||
|
||||
# 流式输出相关常量
|
||||
DEFAULT_STREAM_MIN_DELAY = 0.016
|
||||
|
||||
@@ -1,11 +1,12 @@
|
||||
"""
|
||||
应用程序初始化模块
|
||||
"""
|
||||
import logging
|
||||
from pathlib import Path
|
||||
from typing import List
|
||||
|
||||
logger = logging.getLogger("initialization")
|
||||
from app.log.logger import get_initialization_logger
|
||||
|
||||
logger = get_initialization_logger()
|
||||
|
||||
|
||||
def ensure_directories_exist(directories: List[str]) -> None:
|
||||
|
||||
@@ -1,13 +1,17 @@
|
||||
from fastapi import HTTPException, Header
|
||||
from typing import Optional
|
||||
from app.logger.logger import get_security_logger
|
||||
|
||||
from fastapi import Header, HTTPException
|
||||
|
||||
from app.config.config import settings
|
||||
from app.log.logger import get_security_logger
|
||||
|
||||
logger = get_security_logger()
|
||||
|
||||
|
||||
def verify_auth_token(token: str) -> bool:
    """Return True when *token* equals the configured ``settings.AUTH_TOKEN``.

    The comparison is done with ``hmac.compare_digest`` so that the time it
    takes does not depend on how many leading characters of the candidate
    match the secret (a plain ``==`` short-circuits on the first mismatch).

    Args:
        token: Candidate token supplied by the client.

    Returns:
        True if the token matches the configured value, False otherwise.
    """
    # Local import keeps this function self-contained; the module's import
    # block does not need to change.
    import hmac

    expected = settings.AUTH_TOKEN
    if not isinstance(token, str) or not isinstance(expected, str):
        # Unexpected types (e.g. None): keep the original equality semantics
        # instead of raising from .encode().
        return token == expected

    # Compare as bytes: compare_digest rejects non-ASCII str arguments.
    # "surrogatepass" avoids an encode error on malformed input strings.
    return hmac.compare_digest(
        token.encode("utf-8", "surrogatepass"),
        expected.encode("utf-8", "surrogatepass"),
    )
|
||||
|
||||
|
||||
class SecurityService:
|
||||
def __init__(self, allowed_tokens: list, auth_token: str):
|
||||
self.allowed_tokens = allowed_tokens
|
||||
@@ -20,7 +24,7 @@ class SecurityService:
|
||||
return key
|
||||
|
||||
async def verify_authorization(
|
||||
self, authorization: Optional[str] = Header(None)
|
||||
self, authorization: Optional[str] = Header(None)
|
||||
) -> str:
|
||||
if not authorization:
|
||||
logger.error("Missing Authorization header")
|
||||
@@ -39,19 +43,26 @@ class SecurityService:
|
||||
|
||||
return token
|
||||
|
||||
async def verify_goog_api_key(self, x_goog_api_key: Optional[str] = Header(None)) -> str:
|
||||
async def verify_goog_api_key(
|
||||
self, x_goog_api_key: Optional[str] = Header(None)
|
||||
) -> str:
|
||||
"""验证Google API Key"""
|
||||
if not x_goog_api_key:
|
||||
logger.error("Missing x-goog-api-key header")
|
||||
raise HTTPException(status_code=401, detail="Missing x-goog-api-key header")
|
||||
|
||||
if x_goog_api_key not in self.allowed_tokens and x_goog_api_key != self.auth_token:
|
||||
if (
|
||||
x_goog_api_key not in self.allowed_tokens
|
||||
and x_goog_api_key != self.auth_token
|
||||
):
|
||||
logger.error("Invalid x-goog-api-key")
|
||||
raise HTTPException(status_code=401, detail="Invalid x-goog-api-key")
|
||||
|
||||
return x_goog_api_key
|
||||
|
||||
async def verify_auth_token(self, authorization: Optional[str] = Header(None)) -> str:
|
||||
async def verify_auth_token(
|
||||
self, authorization: Optional[str] = Header(None)
|
||||
) -> str:
|
||||
if not authorization:
|
||||
logger.error("Missing auth_token header")
|
||||
raise HTTPException(status_code=401, detail="Missing auth_token header")
|
||||
|
||||
@@ -1,18 +1,20 @@
|
||||
"""
|
||||
异常处理模块,定义应用程序中使用的自定义异常和异常处理器
|
||||
"""
|
||||
from fastapi import Request, FastAPI
|
||||
from fastapi.responses import JSONResponse
|
||||
|
||||
from fastapi import FastAPI, Request
|
||||
from fastapi.exceptions import RequestValidationError
|
||||
from fastapi.responses import JSONResponse
|
||||
from starlette.exceptions import HTTPException as StarletteHTTPException
|
||||
|
||||
from app.logger.logger import get_main_logger
|
||||
from app.log.logger import get_exceptions_logger
|
||||
|
||||
logger = get_main_logger()
|
||||
logger = get_exceptions_logger()
|
||||
|
||||
|
||||
class APIError(Exception):
|
||||
"""API错误基类"""
|
||||
|
||||
def __init__(self, status_code: int, detail: str, error_code: str = None):
|
||||
self.status_code = status_code
|
||||
self.detail = detail
|
||||
@@ -22,90 +24,95 @@ class APIError(Exception):
|
||||
|
||||
class AuthenticationError(APIError):
|
||||
"""认证错误"""
|
||||
|
||||
def __init__(self, detail: str = "Authentication failed"):
|
||||
super().__init__(status_code=401, detail=detail, error_code="authentication_error")
|
||||
super().__init__(
|
||||
status_code=401, detail=detail, error_code="authentication_error"
|
||||
)
|
||||
|
||||
|
||||
class AuthorizationError(APIError):
|
||||
"""授权错误"""
|
||||
|
||||
def __init__(self, detail: str = "Not authorized to access this resource"):
|
||||
super().__init__(status_code=403, detail=detail, error_code="authorization_error")
|
||||
super().__init__(
|
||||
status_code=403, detail=detail, error_code="authorization_error"
|
||||
)
|
||||
|
||||
|
||||
class ResourceNotFoundError(APIError):
|
||||
"""资源未找到错误"""
|
||||
|
||||
def __init__(self, detail: str = "Resource not found"):
|
||||
super().__init__(status_code=404, detail=detail, error_code="resource_not_found")
|
||||
super().__init__(
|
||||
status_code=404, detail=detail, error_code="resource_not_found"
|
||||
)
|
||||
|
||||
|
||||
class ModelNotSupportedError(APIError):
|
||||
"""模型不支持错误"""
|
||||
|
||||
def __init__(self, model: str):
|
||||
super().__init__(
|
||||
status_code=400,
|
||||
detail=f"Model {model} is not supported",
|
||||
error_code="model_not_supported"
|
||||
status_code=400,
|
||||
detail=f"Model {model} is not supported",
|
||||
error_code="model_not_supported",
|
||||
)
|
||||
|
||||
|
||||
class APIKeyError(APIError):
|
||||
"""API密钥错误"""
|
||||
|
||||
def __init__(self, detail: str = "Invalid or expired API key"):
|
||||
super().__init__(status_code=401, detail=detail, error_code="api_key_error")
|
||||
|
||||
|
||||
class ServiceUnavailableError(APIError):
|
||||
"""服务不可用错误"""
|
||||
|
||||
def __init__(self, detail: str = "Service temporarily unavailable"):
|
||||
super().__init__(status_code=503, detail=detail, error_code="service_unavailable")
|
||||
super().__init__(
|
||||
status_code=503, detail=detail, error_code="service_unavailable"
|
||||
)
|
||||
|
||||
|
||||
def setup_exception_handlers(app: FastAPI) -> None:
|
||||
"""
|
||||
设置应用程序的异常处理器
|
||||
|
||||
|
||||
Args:
|
||||
app: FastAPI应用程序实例
|
||||
"""
|
||||
|
||||
@app.exception_handler(APIError)
|
||||
async def api_error_handler(request: Request, exc: APIError):
|
||||
"""处理API错误"""
|
||||
logger.error(f"API Error: {exc.detail} (Code: {exc.error_code})")
|
||||
return JSONResponse(
|
||||
status_code=exc.status_code,
|
||||
content={
|
||||
"error": {
|
||||
"code": exc.error_code,
|
||||
"message": exc.detail
|
||||
}
|
||||
}
|
||||
content={"error": {"code": exc.error_code, "message": exc.detail}},
|
||||
)
|
||||
|
||||
|
||||
@app.exception_handler(StarletteHTTPException)
|
||||
async def http_exception_handler(request: Request, exc: StarletteHTTPException):
|
||||
"""处理HTTP异常"""
|
||||
logger.error(f"HTTP Exception: {exc.detail} (Status: {exc.status_code})")
|
||||
return JSONResponse(
|
||||
status_code=exc.status_code,
|
||||
content={
|
||||
"error": {
|
||||
"code": "http_error",
|
||||
"message": exc.detail
|
||||
}
|
||||
}
|
||||
content={"error": {"code": "http_error", "message": exc.detail}},
|
||||
)
|
||||
|
||||
|
||||
@app.exception_handler(RequestValidationError)
|
||||
async def validation_exception_handler(request: Request, exc: RequestValidationError):
|
||||
async def validation_exception_handler(
|
||||
request: Request, exc: RequestValidationError
|
||||
):
|
||||
"""处理请求验证错误"""
|
||||
error_details = []
|
||||
for error in exc.errors():
|
||||
error_details.append({
|
||||
"loc": error["loc"],
|
||||
"msg": error["msg"],
|
||||
"type": error["type"]
|
||||
})
|
||||
|
||||
error_details.append(
|
||||
{"loc": error["loc"], "msg": error["msg"], "type": error["type"]}
|
||||
)
|
||||
|
||||
logger.error(f"Validation Error: {error_details}")
|
||||
return JSONResponse(
|
||||
status_code=422,
|
||||
@@ -113,11 +120,11 @@ def setup_exception_handlers(app: FastAPI) -> None:
|
||||
"error": {
|
||||
"code": "validation_error",
|
||||
"message": "Request validation failed",
|
||||
"details": error_details
|
||||
"details": error_details,
|
||||
}
|
||||
}
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
@app.exception_handler(Exception)
|
||||
async def general_exception_handler(request: Request, exc: Exception):
|
||||
"""处理通用异常"""
|
||||
@@ -127,7 +134,7 @@ def setup_exception_handlers(app: FastAPI) -> None:
|
||||
content={
|
||||
"error": {
|
||||
"code": "internal_server_error",
|
||||
"message": "An unexpected error occurred"
|
||||
"message": "An unexpected error occurred",
|
||||
}
|
||||
}
|
||||
},
|
||||
)
|
||||
|
||||
@@ -1,10 +1,11 @@
|
||||
# app/services/chat/retry_handler.py
|
||||
|
||||
from typing import TypeVar, Callable
|
||||
from functools import wraps
|
||||
from app.logger.logger import get_retry_logger
|
||||
from typing import Callable, TypeVar
|
||||
|
||||
T = TypeVar('T')
|
||||
from app.log.logger import get_retry_logger
|
||||
|
||||
T = TypeVar("T")
|
||||
logger = get_retry_logger()
|
||||
|
||||
|
||||
@@ -25,17 +26,21 @@ class RetryHandler:
|
||||
return await func(*args, **kwargs)
|
||||
except Exception as e:
|
||||
last_exception = e
|
||||
logger.warning(f"API call failed with error: {str(e)}. Attempt {attempt + 1} of {self.max_retries}")
|
||||
logger.warning(
|
||||
f"API call failed with error: {str(e)}. Attempt {attempt + 1} of {self.max_retries}"
|
||||
)
|
||||
|
||||
# 从函数参数中获取 key_manager
|
||||
key_manager = kwargs.get('key_manager')
|
||||
key_manager = kwargs.get("key_manager")
|
||||
if key_manager:
|
||||
old_key = kwargs.get(self.key_arg)
|
||||
new_key = await key_manager.handle_api_failure(old_key)
|
||||
kwargs[self.key_arg] = new_key
|
||||
logger.info(f"Switched to new API key: {new_key}")
|
||||
|
||||
logger.error(f"All retry attempts failed, raising final exception: {str(last_exception)}")
|
||||
logger.error(
|
||||
f"All retry attempts failed, raising final exception: {str(last_exception)}"
|
||||
)
|
||||
raise last_exception
|
||||
|
||||
return wrapper
|
||||
|
||||
@@ -2,10 +2,17 @@
|
||||
|
||||
import asyncio
|
||||
import math
|
||||
from typing import Any, List, AsyncGenerator, Callable
|
||||
from app.logger.logger import get_openai_logger, get_gemini_logger
|
||||
from typing import Any, AsyncGenerator, Callable, List
|
||||
|
||||
from app.config.config import settings
|
||||
from app.core.constants import DEFAULT_STREAM_CHUNK_SIZE, DEFAULT_STREAM_LONG_TEXT_THRESHOLD, DEFAULT_STREAM_MAX_DELAY, DEFAULT_STREAM_MIN_DELAY, DEFAULT_STREAM_SHORT_TEXT_THRESHOLD
|
||||
from app.core.constants import (
|
||||
DEFAULT_STREAM_CHUNK_SIZE,
|
||||
DEFAULT_STREAM_LONG_TEXT_THRESHOLD,
|
||||
DEFAULT_STREAM_MAX_DELAY,
|
||||
DEFAULT_STREAM_MIN_DELAY,
|
||||
DEFAULT_STREAM_SHORT_TEXT_THRESHOLD,
|
||||
)
|
||||
from app.log.logger import get_gemini_logger, get_openai_logger
|
||||
|
||||
logger_openai = get_openai_logger()
|
||||
logger_gemini = get_gemini_logger()
|
||||
@@ -13,19 +20,21 @@ logger_gemini = get_gemini_logger()
|
||||
|
||||
class StreamOptimizer:
|
||||
"""流式输出优化器
|
||||
|
||||
|
||||
提供流式输出优化功能,包括智能延迟调整和长文本分块输出。
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
logger=None,
|
||||
min_delay: float = DEFAULT_STREAM_MIN_DELAY,
|
||||
max_delay: float = DEFAULT_STREAM_MAX_DELAY,
|
||||
short_text_threshold: int = DEFAULT_STREAM_SHORT_TEXT_THRESHOLD,
|
||||
long_text_threshold: int = DEFAULT_STREAM_LONG_TEXT_THRESHOLD,
|
||||
chunk_size: int = DEFAULT_STREAM_CHUNK_SIZE):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
logger=None,
|
||||
min_delay: float = DEFAULT_STREAM_MIN_DELAY,
|
||||
max_delay: float = DEFAULT_STREAM_MAX_DELAY,
|
||||
short_text_threshold: int = DEFAULT_STREAM_SHORT_TEXT_THRESHOLD,
|
||||
long_text_threshold: int = DEFAULT_STREAM_LONG_TEXT_THRESHOLD,
|
||||
chunk_size: int = DEFAULT_STREAM_CHUNK_SIZE,
|
||||
):
|
||||
"""初始化流式输出优化器
|
||||
|
||||
|
||||
参数:
|
||||
logger: 日志记录器
|
||||
min_delay: 最小延迟时间(秒)
|
||||
@@ -40,13 +49,13 @@ class StreamOptimizer:
|
||||
self.short_text_threshold = short_text_threshold
|
||||
self.long_text_threshold = long_text_threshold
|
||||
self.chunk_size = chunk_size
|
||||
|
||||
|
||||
def calculate_delay(self, text_length: int) -> float:
|
||||
"""根据文本长度计算延迟时间
|
||||
|
||||
|
||||
参数:
|
||||
text_length: 文本长度
|
||||
|
||||
|
||||
返回:
|
||||
延迟时间(秒)
|
||||
"""
|
||||
@@ -59,42 +68,48 @@ class StreamOptimizer:
|
||||
else:
|
||||
# 中等长度文本使用线性插值计算延迟
|
||||
# 使用对数函数使延迟变化更平滑
|
||||
ratio = math.log(text_length / self.short_text_threshold) / math.log(self.long_text_threshold / self.short_text_threshold)
|
||||
ratio = math.log(text_length / self.short_text_threshold) / math.log(
|
||||
self.long_text_threshold / self.short_text_threshold
|
||||
)
|
||||
return self.max_delay - ratio * (self.max_delay - self.min_delay)
|
||||
|
||||
|
||||
def split_text_into_chunks(self, text: str) -> List[str]:
|
||||
"""将文本分割成小块
|
||||
|
||||
|
||||
参数:
|
||||
text: 要分割的文本
|
||||
|
||||
|
||||
返回:
|
||||
文本块列表
|
||||
"""
|
||||
return [text[i:i+self.chunk_size] for i in range(0, len(text), self.chunk_size)]
|
||||
|
||||
async def optimize_stream_output(self,
|
||||
text: str,
|
||||
create_response_chunk: Callable[[str], Any],
|
||||
format_chunk: Callable[[Any], str]) -> AsyncGenerator[str, None]:
|
||||
return [
|
||||
text[i : i + self.chunk_size] for i in range(0, len(text), self.chunk_size)
|
||||
]
|
||||
|
||||
async def optimize_stream_output(
|
||||
self,
|
||||
text: str,
|
||||
create_response_chunk: Callable[[str], Any],
|
||||
format_chunk: Callable[[Any], str],
|
||||
) -> AsyncGenerator[str, None]:
|
||||
"""优化流式输出
|
||||
|
||||
|
||||
参数:
|
||||
text: 要输出的文本
|
||||
create_response_chunk: 创建响应块的函数,接收文本,返回响应块
|
||||
format_chunk: 格式化响应块的函数,接收响应块,返回格式化后的字符串
|
||||
|
||||
|
||||
返回:
|
||||
异步生成器,生成格式化后的响应块
|
||||
"""
|
||||
if not text:
|
||||
return
|
||||
|
||||
|
||||
# 计算智能延迟时间
|
||||
delay = self.calculate_delay(len(text))
|
||||
if self.logger:
|
||||
self.logger.info(f"Text length: {len(text)}, delay: {delay:.4f}s")
|
||||
|
||||
|
||||
# 根据文本长度决定输出方式
|
||||
if len(text) >= self.long_text_threshold:
|
||||
# 长文本:分块输出
|
||||
@@ -120,7 +135,7 @@ openai_optimizer = StreamOptimizer(
|
||||
max_delay=settings.STREAM_MAX_DELAY,
|
||||
short_text_threshold=settings.STREAM_SHORT_TEXT_THRESHOLD,
|
||||
long_text_threshold=settings.STREAM_LONG_TEXT_THRESHOLD,
|
||||
chunk_size=settings.STREAM_CHUNK_SIZE
|
||||
chunk_size=settings.STREAM_CHUNK_SIZE,
|
||||
)
|
||||
|
||||
gemini_optimizer = StreamOptimizer(
|
||||
@@ -129,5 +144,5 @@ gemini_optimizer = StreamOptimizer(
|
||||
max_delay=settings.STREAM_MAX_DELAY,
|
||||
short_text_threshold=settings.STREAM_SHORT_TEXT_THRESHOLD,
|
||||
long_text_threshold=settings.STREAM_LONG_TEXT_THRESHOLD,
|
||||
chunk_size=settings.STREAM_CHUNK_SIZE
|
||||
chunk_size=settings.STREAM_CHUNK_SIZE,
|
||||
)
|
||||
|
||||
@@ -133,3 +133,23 @@ def get_retry_logger():
|
||||
|
||||
def get_image_create_logger():
|
||||
return Logger.setup_logger("image_create")
|
||||
|
||||
|
||||
def get_exceptions_logger():
|
||||
return Logger.setup_logger("exceptions")
|
||||
|
||||
|
||||
def get_application_logger():
|
||||
return Logger.setup_logger("application")
|
||||
|
||||
|
||||
def get_initialization_logger():
|
||||
return Logger.setup_logger("initialization")
|
||||
|
||||
|
||||
def get_middleware_logger():
|
||||
return Logger.setup_logger("middleware")
|
||||
|
||||
|
||||
def get_routes_logger():
|
||||
return Logger.setup_logger("routes")
|
||||
@@ -1,9 +1,11 @@
|
||||
"""
|
||||
应用程序入口模块
|
||||
"""
|
||||
|
||||
import uvicorn
|
||||
|
||||
from app.core.application import create_app
|
||||
from app.logger.logger import get_main_logger
|
||||
from app.log.logger import get_main_logger
|
||||
|
||||
# 创建应用程序实例
|
||||
app = create_app()
|
||||
|
||||
@@ -1,60 +1,72 @@
|
||||
"""
|
||||
中间件配置模块,负责设置和配置应用程序的中间件
|
||||
"""
|
||||
|
||||
from fastapi import FastAPI, Request
|
||||
from fastapi.middleware.cors import CORSMiddleware
|
||||
from fastapi.responses import RedirectResponse
|
||||
from starlette.middleware.base import BaseHTTPMiddleware
|
||||
|
||||
from app.logger.logger import get_main_logger
|
||||
from app.core.security import verify_auth_token
|
||||
# from app.middleware.request_logging_middleware import RequestLoggingMiddleware
|
||||
from app.core.constants import API_VERSION
|
||||
from app.core.security import verify_auth_token
|
||||
from app.log.logger import get_middleware_logger
|
||||
|
||||
logger = get_middleware_logger()
|
||||
|
||||
logger = get_main_logger()
|
||||
|
||||
class AuthMiddleware(BaseHTTPMiddleware):
|
||||
"""
|
||||
认证中间件,处理未经身份验证的请求
|
||||
"""
|
||||
|
||||
async def dispatch(self, request: Request, call_next):
|
||||
# 允许特定路径绕过身份验证
|
||||
if (request.url.path not in ["/", "/auth"] and
|
||||
not request.url.path.startswith("/static") and
|
||||
not request.url.path.startswith("/gemini") and
|
||||
not request.url.path.startswith("/v1") and
|
||||
not request.url.path.startswith(f"/{API_VERSION}") and
|
||||
not request.url.path.startswith("/health") and
|
||||
not request.url.path.startswith("/hf")):
|
||||
|
||||
if (
|
||||
request.url.path not in ["/", "/auth"]
|
||||
and not request.url.path.startswith("/static")
|
||||
and not request.url.path.startswith("/gemini")
|
||||
and not request.url.path.startswith("/v1")
|
||||
and not request.url.path.startswith(f"/{API_VERSION}")
|
||||
and not request.url.path.startswith("/health")
|
||||
and not request.url.path.startswith("/hf")
|
||||
):
|
||||
|
||||
auth_token = request.cookies.get("auth_token")
|
||||
if not auth_token or not verify_auth_token(auth_token):
|
||||
logger.warning(f"Unauthorized access attempt to {request.url.path}")
|
||||
return RedirectResponse(url="/")
|
||||
logger.debug("Request authenticated successfully")
|
||||
|
||||
|
||||
response = await call_next(request)
|
||||
return response
|
||||
|
||||
|
||||
def setup_middlewares(app: FastAPI) -> None:
|
||||
"""
|
||||
设置应用程序的中间件
|
||||
|
||||
|
||||
Args:
|
||||
app: FastAPI应用程序实例
|
||||
"""
|
||||
# 添加认证中间件
|
||||
app.add_middleware(AuthMiddleware)
|
||||
|
||||
|
||||
# 添加请求日志中间件(可选,默认注释掉)
|
||||
# app.add_middleware(RequestLoggingMiddleware)
|
||||
|
||||
|
||||
# 配置CORS中间件
|
||||
app.add_middleware(
|
||||
CORSMiddleware,
|
||||
allow_origins=["*"], # 生产环境建议配置具体的域名
|
||||
allow_credentials=True,
|
||||
allow_methods=["GET", "POST", "PUT", "DELETE", "OPTIONS"], # 明确指定允许的HTTP方法
|
||||
allow_methods=[
|
||||
"GET",
|
||||
"POST",
|
||||
"PUT",
|
||||
"DELETE",
|
||||
"OPTIONS",
|
||||
], # 明确指定允许的HTTP方法
|
||||
allow_headers=["*"], # 生产环境建议配置具体的请求头
|
||||
expose_headers=["*"], # 允许前端访问的响应头
|
||||
max_age=600, # 预检请求缓存时间(秒)
|
||||
|
||||
@@ -1,7 +1,9 @@
|
||||
import json
|
||||
|
||||
from fastapi import Request
|
||||
from starlette.middleware.base import BaseHTTPMiddleware
|
||||
import json
|
||||
from app.logger.logger import get_request_logger
|
||||
|
||||
from app.log.logger import get_request_logger
|
||||
|
||||
logger = get_request_logger()
|
||||
|
||||
@@ -20,7 +22,9 @@ class RequestLoggingMiddleware(BaseHTTPMiddleware):
|
||||
# 尝试格式化JSON
|
||||
try:
|
||||
formatted_body = json.loads(body_str)
|
||||
logger.info(f"Formatted request body:\n{json.dumps(formatted_body, indent=2, ensure_ascii=False)}")
|
||||
logger.info(
|
||||
f"Formatted request body:\n{json.dumps(formatted_body, indent=2, ensure_ascii=False)}"
|
||||
)
|
||||
except json.JSONDecodeError:
|
||||
logger.info("Request body is not valid JSON.")
|
||||
except Exception as e:
|
||||
|
||||
@@ -2,7 +2,7 @@ from fastapi import APIRouter, Depends, HTTPException
|
||||
from fastapi.responses import StreamingResponse, JSONResponse
|
||||
from copy import deepcopy
|
||||
from app.config.config import settings
|
||||
from app.logger.logger import get_gemini_logger
|
||||
from app.log.logger import get_gemini_logger
|
||||
from app.core.security import SecurityService
|
||||
from app.domain.gemini_models import GeminiContent, GeminiRequest
|
||||
from app.service.chat.gemini_chat_service import GeminiChatService
|
||||
|
||||
@@ -1,37 +1,46 @@
|
||||
from fastapi import HTTPException, APIRouter, Depends
|
||||
from fastapi import APIRouter, Depends, HTTPException
|
||||
from fastapi.responses import StreamingResponse
|
||||
|
||||
from app.config.config import settings
|
||||
from app.logger.logger import get_openai_logger
|
||||
from app.core.security import SecurityService
|
||||
from app.domain.openai_models import ChatRequest, EmbeddingRequest, ImageGenerationRequest
|
||||
from app.domain.openai_models import (
|
||||
ChatRequest,
|
||||
EmbeddingRequest,
|
||||
ImageGenerationRequest,
|
||||
)
|
||||
from app.handler.retry_handler import RetryHandler
|
||||
from app.log.logger import get_openai_logger
|
||||
from app.service.chat.openai_chat_service import OpenAIChatService
|
||||
from app.service.embedding.embedding_service import EmbeddingService
|
||||
from app.service.image.image_create_service import ImageCreateService
|
||||
from app.service.key.key_manager import KeyManager, get_key_manager_instance
|
||||
from app.service.model.model_service import ModelService
|
||||
from app.service.chat.openai_chat_service import OpenAIChatService
|
||||
|
||||
router = APIRouter()
|
||||
logger = get_openai_logger()
|
||||
|
||||
# 初始化服务
|
||||
security_service = SecurityService(settings.ALLOWED_TOKENS, settings.AUTH_TOKEN)
|
||||
model_service = ModelService(settings.SEARCH_MODELS,settings.IMAGE_MODELS)
|
||||
model_service = ModelService(settings.SEARCH_MODELS, settings.IMAGE_MODELS)
|
||||
embedding_service = EmbeddingService(settings.BASE_URL)
|
||||
image_create_service = ImageCreateService()
|
||||
|
||||
|
||||
async def get_key_manager():
|
||||
return await get_key_manager_instance()
|
||||
|
||||
async def get_next_working_key_wrapper(key_manager: KeyManager = Depends(get_key_manager)):
|
||||
|
||||
async def get_next_working_key_wrapper(
|
||||
key_manager: KeyManager = Depends(get_key_manager),
|
||||
):
|
||||
return await key_manager.get_next_working_key()
|
||||
|
||||
|
||||
@router.get("/v1/models")
|
||||
@router.get("/hf/v1/models")
|
||||
async def list_models(
|
||||
_=Depends(security_service.verify_authorization),
|
||||
key_manager: KeyManager = Depends(get_key_manager)
|
||||
key_manager: KeyManager = Depends(get_key_manager),
|
||||
):
|
||||
logger.info("-" * 50 + "list_models" + "-" * 50)
|
||||
logger.info("Handling models list request")
|
||||
@@ -41,7 +50,9 @@ async def list_models(
|
||||
return model_service.get_gemini_openai_models(api_key)
|
||||
except Exception as e:
|
||||
logger.error(f"Error getting models list: {str(e)}")
|
||||
raise HTTPException(status_code=500, detail="Internal server error while fetching models list") from e
|
||||
raise HTTPException(
|
||||
status_code=500, detail="Internal server error while fetching models list"
|
||||
) from e
|
||||
|
||||
|
||||
@router.post("/v1/chat/completions")
|
||||
@@ -51,7 +62,7 @@ async def chat_completion(
|
||||
request: ChatRequest,
|
||||
_=Depends(security_service.verify_authorization),
|
||||
api_key: str = Depends(get_next_working_key_wrapper),
|
||||
key_manager: KeyManager = Depends(get_key_manager)
|
||||
key_manager: KeyManager = Depends(get_key_manager),
|
||||
):
|
||||
# 如果model是imagen3,使用paid_key
|
||||
if request.model == f"{settings.CREATE_IMAGE_MODEL}-chat":
|
||||
@@ -63,8 +74,10 @@ async def chat_completion(
|
||||
logger.info(f"Using API key: {api_key}")
|
||||
|
||||
if not model_service.check_model_support(request.model):
|
||||
raise HTTPException(status_code=400, detail=f"Model {request.model} is not supported")
|
||||
|
||||
raise HTTPException(
|
||||
status_code=400, detail=f"Model {request.model} is not supported"
|
||||
)
|
||||
|
||||
try:
|
||||
# 如果model是imagen3,使用paid_key
|
||||
if request.model == f"{settings.CREATE_IMAGE_MODEL}-chat":
|
||||
@@ -80,6 +93,7 @@ async def chat_completion(
|
||||
logger.error(f"Chat completion failed after retries: {str(e)}")
|
||||
raise HTTPException(status_code=500, detail="Chat completion failed") from e
|
||||
|
||||
|
||||
@router.post("/v1/images/generations")
|
||||
@router.post("/hf/v1/images/generations")
|
||||
async def generate_image(
|
||||
@@ -95,14 +109,17 @@ async def generate_image(
|
||||
return response
|
||||
except Exception as e:
|
||||
logger.error(f"Image generation request failed: {str(e)}")
|
||||
raise HTTPException(status_code=500, detail="Image generation request failed") from e
|
||||
raise HTTPException(
|
||||
status_code=500, detail="Image generation request failed"
|
||||
) from e
|
||||
|
||||
|
||||
@router.post("/v1/embeddings")
|
||||
@router.post("/hf/v1/embeddings")
|
||||
async def embedding(
|
||||
request: EmbeddingRequest,
|
||||
_=Depends(security_service.verify_authorization),
|
||||
key_manager: KeyManager = Depends(get_key_manager)
|
||||
key_manager: KeyManager = Depends(get_key_manager),
|
||||
):
|
||||
logger.info("-" * 50 + "embedding" + "-" * 50)
|
||||
logger.info(f"Handling embedding request for model: {request.model}")
|
||||
@@ -118,11 +135,12 @@ async def embedding(
|
||||
logger.error(f"Embedding request failed: {str(e)}")
|
||||
raise HTTPException(status_code=500, detail="Embedding request failed") from e
|
||||
|
||||
|
||||
@router.get("/v1/keys/list")
|
||||
@router.get("/hf/v1/keys/list")
|
||||
async def get_keys_list(
|
||||
_=Depends(security_service.verify_auth_token),
|
||||
key_manager: KeyManager = Depends(get_key_manager)
|
||||
key_manager: KeyManager = Depends(get_key_manager),
|
||||
):
|
||||
"""获取有效和无效的API key列表"""
|
||||
logger.info("-" * 50 + "get_keys_list" + "-" * 50)
|
||||
@@ -133,13 +151,12 @@ async def get_keys_list(
|
||||
"status": "success",
|
||||
"data": {
|
||||
"valid_keys": keys_status["valid_keys"],
|
||||
"invalid_keys": keys_status["invalid_keys"]
|
||||
"invalid_keys": keys_status["invalid_keys"],
|
||||
},
|
||||
"total": len(keys_status["valid_keys"]) + len(keys_status["invalid_keys"])
|
||||
"total": len(keys_status["valid_keys"]) + len(keys_status["invalid_keys"]),
|
||||
}
|
||||
except Exception as e:
|
||||
logger.error(f"Error getting keys list: {str(e)}")
|
||||
raise HTTPException(
|
||||
status_code=500,
|
||||
detail="Internal server error while fetching keys list"
|
||||
status_code=500, detail="Internal server error while fetching keys list"
|
||||
) from e
|
||||
|
||||
@@ -1,24 +1,26 @@
|
||||
"""
|
||||
路由配置模块,负责设置和配置应用程序的路由
|
||||
"""
|
||||
|
||||
from fastapi import FastAPI, Request
|
||||
from fastapi.responses import HTMLResponse, RedirectResponse
|
||||
from fastapi.templating import Jinja2Templates
|
||||
|
||||
from app.logger.logger import get_main_logger
|
||||
from app.core.security import verify_auth_token
|
||||
from app.log.logger import get_routes_logger
|
||||
from app.router import gemini_routes, openai_routes
|
||||
from app.service.key.key_manager import get_key_manager_instance
|
||||
|
||||
logger = get_main_logger()
|
||||
logger = get_routes_logger()
|
||||
|
||||
# 配置Jinja2模板
|
||||
templates = Jinja2Templates(directory="app/templates")
|
||||
|
||||
|
||||
def setup_routers(app: FastAPI) -> None:
|
||||
"""
|
||||
设置应用程序的路由
|
||||
|
||||
|
||||
Args:
|
||||
app: FastAPI应用程序实例
|
||||
"""
|
||||
@@ -26,20 +28,22 @@ def setup_routers(app: FastAPI) -> None:
|
||||
app.include_router(openai_routes.router)
|
||||
app.include_router(gemini_routes.router)
|
||||
app.include_router(gemini_routes.router_v1beta)
|
||||
|
||||
|
||||
# 添加页面路由
|
||||
setup_page_routes(app)
|
||||
|
||||
|
||||
# 添加健康检查路由
|
||||
setup_health_routes(app)
|
||||
|
||||
|
||||
def setup_page_routes(app: FastAPI) -> None:
|
||||
"""
|
||||
设置页面相关的路由
|
||||
|
||||
|
||||
Args:
|
||||
app: FastAPI应用程序实例
|
||||
"""
|
||||
|
||||
@app.get("/", response_class=HTMLResponse)
|
||||
async def auth_page(request: Request):
|
||||
"""认证页面"""
|
||||
@@ -54,11 +58,13 @@ def setup_page_routes(app: FastAPI) -> None:
|
||||
if not auth_token:
|
||||
logger.warning("Authentication attempt with empty token")
|
||||
return RedirectResponse(url="/", status_code=302)
|
||||
|
||||
|
||||
if verify_auth_token(auth_token):
|
||||
logger.info("Successful authentication")
|
||||
response = RedirectResponse(url="/keys", status_code=302)
|
||||
response.set_cookie(key="auth_token", value=auth_token, httponly=True, max_age=3600)
|
||||
response.set_cookie(
|
||||
key="auth_token", value=auth_token, httponly=True, max_age=3600
|
||||
)
|
||||
return response
|
||||
logger.warning("Failed authentication attempt with invalid token")
|
||||
return RedirectResponse(url="/", status_code=302)
|
||||
@@ -74,28 +80,33 @@ def setup_page_routes(app: FastAPI) -> None:
|
||||
if not auth_token or not verify_auth_token(auth_token):
|
||||
logger.warning("Unauthorized access attempt to keys page")
|
||||
return RedirectResponse(url="/", status_code=302)
|
||||
|
||||
|
||||
key_manager = await get_key_manager_instance()
|
||||
keys_status = await key_manager.get_keys_by_status()
|
||||
total = len(keys_status["valid_keys"]) + len(keys_status["invalid_keys"])
|
||||
logger.info(f"Keys status retrieved successfully. Total keys: {total}")
|
||||
return templates.TemplateResponse("keys_status.html", {
|
||||
"request": request,
|
||||
"valid_keys": keys_status["valid_keys"],
|
||||
"invalid_keys": keys_status["invalid_keys"],
|
||||
"total": total
|
||||
})
|
||||
return templates.TemplateResponse(
|
||||
"keys_status.html",
|
||||
{
|
||||
"request": request,
|
||||
"valid_keys": keys_status["valid_keys"],
|
||||
"invalid_keys": keys_status["invalid_keys"],
|
||||
"total": total,
|
||||
},
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"Error retrieving keys status: {str(e)}")
|
||||
raise
|
||||
|
||||
|
||||
def setup_health_routes(app: FastAPI) -> None:
|
||||
"""
|
||||
设置健康检查相关的路由
|
||||
|
||||
|
||||
Args:
|
||||
app: FastAPI应用程序实例
|
||||
"""
|
||||
|
||||
@app.get("/health")
|
||||
async def health_check(request: Request):
|
||||
"""健康检查端点"""
|
||||
@@ -1,13 +1,14 @@
|
||||
# app/services/chat_service.py
|
||||
|
||||
import json
|
||||
from typing import Dict, Any, AsyncGenerator, List
|
||||
from app.logger.logger import get_gemini_logger
|
||||
from app.service.client.api_client import GeminiApiClient
|
||||
from app.handler.stream_optimizer import gemini_optimizer
|
||||
from app.domain.gemini_models import GeminiRequest
|
||||
from typing import Any, AsyncGenerator, Dict, List
|
||||
|
||||
from app.config.config import settings
|
||||
from app.domain.gemini_models import GeminiRequest
|
||||
from app.handler.response_handler import GeminiResponseHandler
|
||||
from app.handler.stream_optimizer import gemini_optimizer
|
||||
from app.log.logger import get_gemini_logger
|
||||
from app.service.client.api_client import GeminiApiClient
|
||||
from app.service.key.key_manager import KeyManager
|
||||
|
||||
logger = get_gemini_logger()
|
||||
@@ -26,9 +27,11 @@ def _has_image_parts(contents: List[Dict[str, Any]]) -> bool:
|
||||
def _build_tools(model: str, payload: Dict[str, Any]) -> List[Dict[str, Any]]:
|
||||
"""构建工具"""
|
||||
tools = []
|
||||
if settings.TOOLS_CODE_EXECUTION_ENABLED and not (
|
||||
model.endswith("-search") or "-thinking" in model
|
||||
) and not _has_image_parts(payload.get("contents", [])):
|
||||
if (
|
||||
settings.TOOLS_CODE_EXECUTION_ENABLED
|
||||
and not (model.endswith("-search") or "-thinking" in model)
|
||||
and not _has_image_parts(payload.get("contents", []))
|
||||
):
|
||||
tools.append({"code_execution": {}})
|
||||
if model.endswith("-search"):
|
||||
tools.append({"googleSearch": {}})
|
||||
@@ -49,14 +52,14 @@ def _get_safety_settings(model: str) -> List[Dict[str, str]]:
|
||||
{"category": "HARM_CATEGORY_HATE_SPEECH", "threshold": "OFF"},
|
||||
{"category": "HARM_CATEGORY_SEXUALLY_EXPLICIT", "threshold": "OFF"},
|
||||
{"category": "HARM_CATEGORY_DANGEROUS_CONTENT", "threshold": "OFF"},
|
||||
{"category": "HARM_CATEGORY_CIVIC_INTEGRITY", "threshold": "OFF"}
|
||||
{"category": "HARM_CATEGORY_CIVIC_INTEGRITY", "threshold": "OFF"},
|
||||
]
|
||||
return [
|
||||
{"category": "HARM_CATEGORY_HARASSMENT", "threshold": "BLOCK_NONE"},
|
||||
{"category": "HARM_CATEGORY_HATE_SPEECH", "threshold": "BLOCK_NONE"},
|
||||
{"category": "HARM_CATEGORY_SEXUALLY_EXPLICIT", "threshold": "BLOCK_NONE"},
|
||||
{"category": "HARM_CATEGORY_DANGEROUS_CONTENT", "threshold": "BLOCK_NONE"},
|
||||
{"category": "HARM_CATEGORY_CIVIC_INTEGRITY", "threshold": "BLOCK_NONE"}
|
||||
{"category": "HARM_CATEGORY_CIVIC_INTEGRITY", "threshold": "BLOCK_NONE"},
|
||||
]
|
||||
|
||||
|
||||
@@ -68,12 +71,12 @@ def _build_payload(model: str, request: GeminiRequest) -> Dict[str, Any]:
|
||||
"tools": _build_tools(model, request_dict),
|
||||
"safetySettings": _get_safety_settings(model),
|
||||
"generationConfig": request_dict.get("generationConfig", {}),
|
||||
"systemInstruction": request_dict.get("systemInstruction", "")
|
||||
"systemInstruction": request_dict.get("systemInstruction", ""),
|
||||
}
|
||||
|
||||
|
||||
if model.endswith("-image") or model.endswith("-image-generation"):
|
||||
payload.pop("systemInstruction")
|
||||
payload["generationConfig"]["responseModalities"] = ["Text","Image"]
|
||||
payload["generationConfig"]["responseModalities"] = ["Text", "Image"]
|
||||
return payload
|
||||
|
||||
|
||||
@@ -84,54 +87,68 @@ class GeminiChatService:
|
||||
self.api_client = GeminiApiClient(base_url)
|
||||
self.key_manager = key_manager
|
||||
self.response_handler = GeminiResponseHandler()
|
||||
|
||||
|
||||
def _extract_text_from_response(self, response: Dict[str, Any]) -> str:
|
||||
"""从响应中提取文本内容"""
|
||||
if not response.get("candidates"):
|
||||
return ""
|
||||
|
||||
|
||||
candidate = response["candidates"][0]
|
||||
content = candidate.get("content", {})
|
||||
parts = content.get("parts", [])
|
||||
|
||||
|
||||
if parts and "text" in parts[0]:
|
||||
return parts[0].get("text", "")
|
||||
return ""
|
||||
|
||||
def _create_char_response(self, original_response: Dict[str, Any], text: str) -> Dict[str, Any]:
|
||||
|
||||
def _create_char_response(
|
||||
self, original_response: Dict[str, Any], text: str
|
||||
) -> Dict[str, Any]:
|
||||
"""创建包含指定文本的响应"""
|
||||
response_copy = json.loads(json.dumps(original_response)) # 深拷贝
|
||||
if response_copy.get("candidates") and response_copy["candidates"][0].get("content", {}).get("parts"):
|
||||
if response_copy.get("candidates") and response_copy["candidates"][0].get(
|
||||
"content", {}
|
||||
).get("parts"):
|
||||
response_copy["candidates"][0]["content"]["parts"][0]["text"] = text
|
||||
return response_copy
|
||||
|
||||
async def generate_content(self, model: str, request: GeminiRequest, api_key: str) -> Dict[str, Any]:
|
||||
async def generate_content(
|
||||
self, model: str, request: GeminiRequest, api_key: str
|
||||
) -> Dict[str, Any]:
|
||||
"""生成内容"""
|
||||
payload = _build_payload(model, request)
|
||||
response = await self.api_client.generate_content(payload, model, api_key)
|
||||
return self.response_handler.handle_response(response, model, stream=False)
|
||||
|
||||
async def stream_generate_content(self, model: str, request: GeminiRequest, api_key: str) -> AsyncGenerator[str, None]:
|
||||
async def stream_generate_content(
|
||||
self, model: str, request: GeminiRequest, api_key: str
|
||||
) -> AsyncGenerator[str, None]:
|
||||
"""流式生成内容"""
|
||||
retries = 0
|
||||
max_retries = 3
|
||||
payload = _build_payload(model, request)
|
||||
while retries < max_retries:
|
||||
try:
|
||||
async for line in self.api_client.stream_generate_content(payload, model, api_key):
|
||||
async for line in self.api_client.stream_generate_content(
|
||||
payload, model, api_key
|
||||
):
|
||||
# print(line)
|
||||
if line.startswith("data:"):
|
||||
line = line[6:]
|
||||
response_data = self.response_handler.handle_response(json.loads(line), model, stream=True)
|
||||
response_data = self.response_handler.handle_response(
|
||||
json.loads(line), model, stream=True
|
||||
)
|
||||
text = self._extract_text_from_response(response_data)
|
||||
|
||||
|
||||
# 如果有文本内容,使用流式输出优化器处理
|
||||
if text:
|
||||
# 使用流式输出优化器处理文本输出
|
||||
async for optimized_chunk in gemini_optimizer.optimize_stream_output(
|
||||
async for (
|
||||
optimized_chunk
|
||||
) in gemini_optimizer.optimize_stream_output(
|
||||
text,
|
||||
lambda t: self._create_char_response(response_data, t),
|
||||
lambda c: "data: " + json.dumps(c) + "\n\n"
|
||||
lambda c: "data: " + json.dumps(c) + "\n\n",
|
||||
):
|
||||
yield optimized_chunk
|
||||
else:
|
||||
@@ -141,9 +158,13 @@ class GeminiChatService:
|
||||
break
|
||||
except Exception as e:
|
||||
retries += 1
|
||||
logger.warning(f"Streaming API call failed with error: {str(e)}. Attempt {retries} of {max_retries}")
|
||||
logger.warning(
|
||||
f"Streaming API call failed with error: {str(e)}. Attempt {retries} of {max_retries}"
|
||||
)
|
||||
api_key = await self.key_manager.handle_api_failure(api_key)
|
||||
logger.info(f"Switched to new API key: {api_key}")
|
||||
if retries >= max_retries:
|
||||
logger.error(f"Max retries ({max_retries}) reached for streaming. Raising error")
|
||||
logger.error(
|
||||
f"Max retries ({max_retries}) reached for streaming. Raising error"
|
||||
)
|
||||
break
|
||||
|
||||
@@ -1,15 +1,16 @@
|
||||
# app/services/chat_service.py
|
||||
|
||||
from copy import deepcopy
|
||||
import json
|
||||
from typing import Dict, Any, AsyncGenerator, List, Optional, Union
|
||||
from app.logger.logger import get_openai_logger
|
||||
from copy import deepcopy
|
||||
from typing import Any, AsyncGenerator, Dict, List, Optional, Union
|
||||
|
||||
from app.config.config import settings
|
||||
from app.domain.openai_models import ChatRequest, ImageGenerationRequest
|
||||
from app.handler.message_converter import OpenAIMessageConverter
|
||||
from app.handler.response_handler import OpenAIResponseHandler
|
||||
from app.service.client.api_client import GeminiApiClient
|
||||
from app.handler.stream_optimizer import openai_optimizer
|
||||
from app.domain.openai_models import ChatRequest, ImageGenerationRequest
|
||||
from app.config.config import settings
|
||||
from app.log.logger import get_openai_logger
|
||||
from app.service.client.api_client import GeminiApiClient
|
||||
from app.service.image.image_create_service import ImageCreateService
|
||||
from app.service.key.key_manager import KeyManager
|
||||
|
||||
@@ -27,16 +28,21 @@ def _has_image_parts(contents: List[Dict[str, Any]]) -> bool:
|
||||
|
||||
|
||||
def _build_tools(
|
||||
request: ChatRequest, messages: List[Dict[str, Any]]
|
||||
request: ChatRequest, messages: List[Dict[str, Any]]
|
||||
) -> List[Dict[str, Any]]:
|
||||
"""构建工具"""
|
||||
tools = []
|
||||
model = request.model
|
||||
|
||||
if (
|
||||
settings.TOOLS_CODE_EXECUTION_ENABLED
|
||||
and not (model.endswith("-search") or "-thinking" in model or model.endswith("-image") or model.endswith("-image-generation"))
|
||||
and not _has_image_parts(messages)
|
||||
settings.TOOLS_CODE_EXECUTION_ENABLED
|
||||
and not (
|
||||
model.endswith("-search")
|
||||
or "-thinking" in model
|
||||
or model.endswith("-image")
|
||||
or model.endswith("-image-generation")
|
||||
)
|
||||
and not _has_image_parts(messages)
|
||||
):
|
||||
tools.append({"code_execution": {}})
|
||||
if model.endswith("-search"):
|
||||
@@ -52,7 +58,9 @@ def _build_tools(
|
||||
if tool.get("type", "") == "function" and tool.get("function"):
|
||||
function = deepcopy(tool.get("function"))
|
||||
parameters = function.get("parameters", {})
|
||||
if parameters.get("type") == "object" and not parameters.get("properties", {}):
|
||||
if parameters.get("type") == "object" and not parameters.get(
|
||||
"properties", {}
|
||||
):
|
||||
function.pop("parameters", None)
|
||||
|
||||
function_declarations.append(function)
|
||||
@@ -66,7 +74,7 @@ def _build_tools(
|
||||
functions.append(item)
|
||||
|
||||
tools.append({"functionDeclarations": functions})
|
||||
|
||||
|
||||
return tools
|
||||
|
||||
|
||||
@@ -95,7 +103,9 @@ def _get_safety_settings(model: str) -> List[Dict[str, str]]:
|
||||
|
||||
|
||||
def _build_payload(
|
||||
request: ChatRequest, messages: List[Dict[str, Any]], instruction: Optional[Dict[str, Any]] = None
|
||||
request: ChatRequest,
|
||||
messages: List[Dict[str, Any]],
|
||||
instruction: Optional[Dict[str, Any]] = None,
|
||||
) -> Dict[str, Any]:
|
||||
"""构建请求payload"""
|
||||
payload = {
|
||||
@@ -111,8 +121,8 @@ def _build_payload(
|
||||
"safetySettings": _get_safety_settings(request.model),
|
||||
}
|
||||
if request.model.endswith("-image") or request.model.endswith("-image-generation"):
|
||||
payload["generationConfig"]["responseModalities"] = ["Text","Image"]
|
||||
|
||||
payload["generationConfig"]["responseModalities"] = ["Text", "Image"]
|
||||
|
||||
if (
|
||||
instruction
|
||||
and isinstance(instruction, dict)
|
||||
@@ -128,24 +138,27 @@ def _build_payload(
|
||||
|
||||
class OpenAIChatService:
|
||||
"""聊天服务"""
|
||||
|
||||
def __init__(self, base_url: str, key_manager: KeyManager = None):
|
||||
self.message_converter = OpenAIMessageConverter()
|
||||
self.response_handler = OpenAIResponseHandler(config=None)
|
||||
self.api_client = GeminiApiClient(base_url)
|
||||
self.key_manager = key_manager
|
||||
self.image_create_service = ImageCreateService()
|
||||
|
||||
|
||||
def _extract_text_from_openai_chunk(self, chunk: Dict[str, Any]) -> str:
|
||||
"""从OpenAI响应块中提取文本内容"""
|
||||
if not chunk.get("choices"):
|
||||
return ""
|
||||
|
||||
|
||||
choice = chunk["choices"][0]
|
||||
if "delta" in choice and "content" in choice["delta"]:
|
||||
return choice["delta"]["content"]
|
||||
return ""
|
||||
|
||||
def _create_char_openai_chunk(self, original_chunk: Dict[str, Any], text: str) -> Dict[str, Any]:
|
||||
|
||||
def _create_char_openai_chunk(
|
||||
self, original_chunk: Dict[str, Any], text: str
|
||||
) -> Dict[str, Any]:
|
||||
"""创建包含指定文本的OpenAI响应块"""
|
||||
chunk_copy = json.loads(json.dumps(original_chunk)) # 深拷贝
|
||||
if chunk_copy.get("choices") and "delta" in chunk_copy["choices"][0]:
|
||||
@@ -153,9 +166,9 @@ class OpenAIChatService:
|
||||
return chunk_copy
|
||||
|
||||
async def create_chat_completion(
|
||||
self,
|
||||
request: ChatRequest,
|
||||
api_key: str,
|
||||
self,
|
||||
request: ChatRequest,
|
||||
api_key: str,
|
||||
) -> Union[Dict[str, Any], AsyncGenerator[str, None]]:
|
||||
"""创建聊天完成"""
|
||||
# 转换消息格式
|
||||
@@ -169,7 +182,7 @@ class OpenAIChatService:
|
||||
return await self._handle_normal_completion(request.model, payload, api_key)
|
||||
|
||||
async def _handle_normal_completion(
|
||||
self, model: str, payload: Dict[str, Any], api_key: str
|
||||
self, model: str, payload: Dict[str, Any], api_key: str
|
||||
) -> Dict[str, Any]:
|
||||
"""处理普通聊天完成"""
|
||||
response = await self.api_client.generate_content(payload, model, api_key)
|
||||
@@ -178,7 +191,7 @@ class OpenAIChatService:
|
||||
)
|
||||
|
||||
async def _handle_stream_completion(
|
||||
self, model: str, payload: Dict[str, Any], api_key: str
|
||||
self, model: str, payload: Dict[str, Any], api_key: str
|
||||
) -> AsyncGenerator[str, None]:
|
||||
"""处理流式聊天完成,添加重试逻辑"""
|
||||
retries = 0
|
||||
@@ -186,7 +199,7 @@ class OpenAIChatService:
|
||||
while retries < max_retries:
|
||||
try:
|
||||
async for line in self.api_client.stream_generate_content(
|
||||
payload, model, api_key
|
||||
payload, model, api_key
|
||||
):
|
||||
# print(line)
|
||||
if line.startswith("data:"):
|
||||
@@ -199,10 +212,14 @@ class OpenAIChatService:
|
||||
text = self._extract_text_from_openai_chunk(openai_chunk)
|
||||
if text:
|
||||
# 使用流式输出优化器处理文本输出
|
||||
async for optimized_chunk in openai_optimizer.optimize_stream_output(
|
||||
async for (
|
||||
optimized_chunk
|
||||
) in openai_optimizer.optimize_stream_output(
|
||||
text,
|
||||
lambda t: self._create_char_openai_chunk(openai_chunk, t),
|
||||
lambda c: f"data: {json.dumps(c)}\n\n"
|
||||
lambda t: self._create_char_openai_chunk(
|
||||
openai_chunk, t
|
||||
),
|
||||
lambda c: f"data: {json.dumps(c)}\n\n",
|
||||
):
|
||||
yield optimized_chunk
|
||||
else:
|
||||
@@ -228,21 +245,23 @@ class OpenAIChatService:
|
||||
break
|
||||
|
||||
async def create_image_chat_completion(
|
||||
self,
|
||||
request: ChatRequest,
|
||||
self,
|
||||
request: ChatRequest,
|
||||
) -> Union[Dict[str, Any], AsyncGenerator[str, None]]:
|
||||
|
||||
|
||||
image_generate_request = ImageGenerationRequest()
|
||||
image_generate_request.prompt = request.messages[-1]["content"]
|
||||
image_res = self.image_create_service.generate_images_chat(image_generate_request)
|
||||
|
||||
image_res = self.image_create_service.generate_images_chat(
|
||||
image_generate_request
|
||||
)
|
||||
|
||||
if request.stream:
|
||||
return self._handle_stream_image_completion(request.model,image_res)
|
||||
return self._handle_stream_image_completion(request.model, image_res)
|
||||
else:
|
||||
return self._handle_normal_image_completion(request.model,image_res)
|
||||
|
||||
return self._handle_normal_image_completion(request.model, image_res)
|
||||
|
||||
async def _handle_stream_image_completion(
|
||||
self, model: str, image_data: str
|
||||
self, model: str, image_data: str
|
||||
) -> AsyncGenerator[str, None]:
|
||||
if image_data:
|
||||
openai_chunk = self.response_handler.handle_image_chat_response(
|
||||
@@ -253,10 +272,12 @@ class OpenAIChatService:
|
||||
text = self._extract_text_from_openai_chunk(openai_chunk)
|
||||
if text:
|
||||
# 使用流式输出优化器处理文本输出
|
||||
async for optimized_chunk in openai_optimizer.optimize_stream_output(
|
||||
async for (
|
||||
optimized_chunk
|
||||
) in openai_optimizer.optimize_stream_output(
|
||||
text,
|
||||
lambda t: self._create_char_openai_chunk(openai_chunk, t),
|
||||
lambda c: f"data: {json.dumps(c)}\n\n"
|
||||
lambda c: f"data: {json.dumps(c)}\n\n",
|
||||
):
|
||||
yield optimized_chunk
|
||||
else:
|
||||
@@ -265,11 +286,11 @@ class OpenAIChatService:
|
||||
yield f"data: {json.dumps(self.response_handler.handle_response({}, model, stream=True, finish_reason='stop'))}\n\n"
|
||||
yield "data: [DONE]\n\n"
|
||||
logger.info("Image chat streaming completed successfully")
|
||||
|
||||
|
||||
def _handle_normal_image_completion(
|
||||
self, model: str, image_data: str
|
||||
self, model: str, image_data: str
|
||||
) -> Dict[str, Any]:
|
||||
|
||||
|
||||
return self.response_handler.handle_image_chat_response(
|
||||
image_data, model, stream=False, finish_reason="stop"
|
||||
)
|
||||
|
||||
@@ -1,9 +1,9 @@
|
||||
from typing import Union, List
|
||||
from typing import List, Union
|
||||
|
||||
import openai
|
||||
from openai.types import CreateEmbeddingResponse
|
||||
|
||||
from app.logger.logger import get_embeddings_logger
|
||||
from app.log.logger import get_embeddings_logger
|
||||
|
||||
logger = get_embeddings_logger()
|
||||
|
||||
|
||||
@@ -1,15 +1,15 @@
|
||||
import base64
|
||||
import time
|
||||
import uuid
|
||||
|
||||
from google import genai
|
||||
from google.genai import types
|
||||
import base64
|
||||
|
||||
from app.config.config import settings
|
||||
from app.logger.logger import get_image_create_logger
|
||||
from app.utils.uploader import ImageUploaderFactory
|
||||
from app.domain.openai_models import ImageGenerationRequest
|
||||
from app.core.constants import VALID_IMAGE_RATIOS
|
||||
from app.domain.openai_models import ImageGenerationRequest
|
||||
from app.log.logger import get_image_create_logger
|
||||
from app.utils.uploader import ImageUploaderFactory
|
||||
|
||||
logger = get_image_create_logger()
|
||||
|
||||
@@ -27,34 +27,34 @@ class ImageCreateService:
|
||||
- {ratio:比例} 例如: {ratio:16:9} 使用16:9比例
|
||||
"""
|
||||
import re
|
||||
|
||||
|
||||
# 默认值
|
||||
n = 1
|
||||
aspect_ratio = self.aspect_ratio
|
||||
|
||||
|
||||
# 解析n参数
|
||||
n_match = re.search(r'{n:(\d+)}', prompt)
|
||||
n_match = re.search(r"{n:(\d+)}", prompt)
|
||||
if n_match:
|
||||
n = int(n_match.group(1))
|
||||
if n < 1 or n > 4:
|
||||
raise ValueError(f"Invalid n value: {n}. Must be between 1 and 4.")
|
||||
prompt = prompt.replace(n_match.group(0), '').strip()
|
||||
|
||||
# 解析ratio参数
|
||||
ratio_match = re.search(r'{ratio:(\d+:\d+)}', prompt)
|
||||
prompt = prompt.replace(n_match.group(0), "").strip()
|
||||
|
||||
# 解析ratio参数
|
||||
ratio_match = re.search(r"{ratio:(\d+:\d+)}", prompt)
|
||||
if ratio_match:
|
||||
aspect_ratio = ratio_match.group(1)
|
||||
if aspect_ratio not in VALID_IMAGE_RATIOS:
|
||||
raise ValueError(
|
||||
f"Invalid ratio: {aspect_ratio}. Must be one of: {', '.join(VALID_IMAGE_RATIOS)}"
|
||||
)
|
||||
prompt = prompt.replace(ratio_match.group(0), '').strip()
|
||||
|
||||
prompt = prompt.replace(ratio_match.group(0), "").strip()
|
||||
|
||||
return prompt, n, aspect_ratio
|
||||
|
||||
def generate_images(self, request: ImageGenerationRequest):
|
||||
client = genai.Client(api_key=self.paid_key)
|
||||
|
||||
|
||||
if request.size == "1024x1024":
|
||||
self.aspect_ratio = "1:1"
|
||||
elif request.size == "1792x1024":
|
||||
@@ -67,13 +67,15 @@ class ImageCreateService:
|
||||
)
|
||||
|
||||
# 解析prompt中的参数
|
||||
cleaned_prompt, prompt_n, prompt_ratio = self.parse_prompt_parameters(request.prompt)
|
||||
cleaned_prompt, prompt_n, prompt_ratio = self.parse_prompt_parameters(
|
||||
request.prompt
|
||||
)
|
||||
request.prompt = cleaned_prompt
|
||||
|
||||
|
||||
# 如果prompt中指定了n,则覆盖请求中的n
|
||||
if prompt_n > 1:
|
||||
request.n = prompt_n
|
||||
|
||||
|
||||
# 如果prompt中指定了ratio,则覆盖默认的aspect_ratio
|
||||
if prompt_ratio != self.aspect_ratio:
|
||||
self.aspect_ratio = prompt_ratio
|
||||
@@ -96,46 +98,49 @@ class ImageCreateService:
|
||||
for index, generated_image in enumerate(response.generated_images):
|
||||
image_data = generated_image.image.image_bytes
|
||||
image_uploader = None
|
||||
|
||||
|
||||
if request.response_format == "b64_json":
|
||||
base64_image = base64.b64encode(image_data).decode('utf-8')
|
||||
images_data.append({
|
||||
"b64_json": base64_image,
|
||||
"revised_prompt": request.prompt
|
||||
})
|
||||
base64_image = base64.b64encode(image_data).decode("utf-8")
|
||||
images_data.append(
|
||||
{"b64_json": base64_image, "revised_prompt": request.prompt}
|
||||
)
|
||||
else:
|
||||
current_date = time.strftime("%Y/%m/%d")
|
||||
filename = f"{current_date}/{uuid.uuid4().hex[:8]}.png"
|
||||
|
||||
|
||||
if settings.UPLOAD_PROVIDER == "smms":
|
||||
image_uploader = ImageUploaderFactory.create(
|
||||
provider=settings.UPLOAD_PROVIDER,
|
||||
api_key=settings.SMMS_SECRET_TOKEN
|
||||
api_key=settings.SMMS_SECRET_TOKEN,
|
||||
)
|
||||
elif settings.UPLOAD_PROVIDER == "picgo":
|
||||
image_uploader = ImageUploaderFactory.create(
|
||||
provider=settings.UPLOAD_PROVIDER,
|
||||
api_key=settings.PICGO_API_KEY
|
||||
api_key=settings.PICGO_API_KEY,
|
||||
)
|
||||
elif settings.UPLOAD_PROVIDER == "cloudflare_imgbed":
|
||||
image_uploader = ImageUploaderFactory.create(
|
||||
provider=settings.UPLOAD_PROVIDER,
|
||||
base_url=settings.CLOUDFLARE_IMGBED_URL,
|
||||
auth_code=settings.CLOUDFLARE_IMGBED_AUTH_CODE
|
||||
auth_code=settings.CLOUDFLARE_IMGBED_AUTH_CODE,
|
||||
)
|
||||
else:
|
||||
raise ValueError(f"Unsupported upload provider: {settings.UPLOAD_PROVIDER}")
|
||||
|
||||
raise ValueError(
|
||||
f"Unsupported upload provider: {settings.UPLOAD_PROVIDER}"
|
||||
)
|
||||
|
||||
upload_response = image_uploader.upload(image_data, filename)
|
||||
|
||||
images_data.append({
|
||||
"url": f"{upload_response.data.url}",
|
||||
"revised_prompt": request.prompt
|
||||
})
|
||||
images_data.append(
|
||||
{
|
||||
"url": f"{upload_response.data.url}",
|
||||
"revised_prompt": request.prompt,
|
||||
}
|
||||
)
|
||||
|
||||
response_data = {
|
||||
"created": int(time.time()), # Current timestamp
|
||||
"data": images_data
|
||||
"data": images_data,
|
||||
}
|
||||
return response_data
|
||||
else:
|
||||
@@ -147,9 +152,13 @@ class ImageCreateService:
|
||||
if image_datas:
|
||||
markdown_images = []
|
||||
for index, image_data in enumerate(image_datas):
|
||||
if 'url' in image_data:
|
||||
markdown_images.append(f"![Generated Image {index+1}]({image_data['url']})")
|
||||
if "url" in image_data:
|
||||
markdown_images.append(
|
||||
f"![Generated Image {index+1}]({image_data['url']})"
|
||||
)
|
||||
else:
|
||||
# 如果是base64格式,创建data URL
|
||||
markdown_images.append(f"![Generated Image {index+1}](data:image/png;base64,{image_data['b64_json']})")
|
||||
markdown_images.append(
|
||||
f"![Generated Image {index+1}](data:image/png;base64,{image_data['b64_json']})"
|
||||
)
|
||||
return "\n".join(markdown_images)
|
||||
|
||||
@@ -1,9 +1,9 @@
|
||||
import asyncio
|
||||
from itertools import cycle
|
||||
from typing import Dict
|
||||
from app.logger.logger import get_key_manager_logger
|
||||
from app.config.config import settings
|
||||
|
||||
from app.config.config import settings
|
||||
from app.log.logger import get_key_manager_logger
|
||||
|
||||
logger = get_key_manager_logger()
|
||||
|
||||
@@ -20,7 +20,7 @@ class KeyManager:
|
||||
|
||||
async def get_paid_key(self) -> str:
|
||||
return self.paid_key
|
||||
|
||||
|
||||
async def get_next_key(self) -> str:
|
||||
"""获取下一个API key"""
|
||||
async with self.key_cycle_lock:
|
||||
@@ -70,7 +70,7 @@ class KeyManager:
|
||||
"""获取分类后的API key列表,包括失败次数"""
|
||||
valid_keys = {}
|
||||
invalid_keys = {}
|
||||
|
||||
|
||||
async with self.failure_count_lock:
|
||||
for key in self.api_keys:
|
||||
fail_count = self.key_failure_counts[key]
|
||||
@@ -78,16 +78,14 @@ class KeyManager:
|
||||
valid_keys[key] = fail_count
|
||||
else:
|
||||
invalid_keys[key] = fail_count
|
||||
|
||||
return {
|
||||
"valid_keys": valid_keys,
|
||||
"invalid_keys": invalid_keys
|
||||
}
|
||||
|
||||
|
||||
|
||||
return {"valid_keys": valid_keys, "invalid_keys": invalid_keys}
|
||||
|
||||
|
||||
_singleton_instance = None
|
||||
_singleton_lock = asyncio.Lock()
|
||||
|
||||
|
||||
async def get_key_manager_instance(api_keys: list = None) -> KeyManager:
|
||||
"""
|
||||
获取 KeyManager 单例实例。
|
||||
|
||||
@@ -1,11 +1,14 @@
|
||||
import requests
|
||||
from datetime import datetime, timezone
|
||||
from typing import Optional, Dict, Any
|
||||
from app.logger.logger import get_model_logger
|
||||
from typing import Any, Dict, Optional
|
||||
|
||||
import requests
|
||||
|
||||
from app.config.config import settings
|
||||
from app.log.logger import get_model_logger
|
||||
|
||||
logger = get_model_logger()
|
||||
|
||||
|
||||
class ModelService:
|
||||
def __init__(self, search_models: list, image_models: list):
|
||||
self.search_models = search_models
|
||||
@@ -20,7 +23,7 @@ class ModelService:
|
||||
response = requests.get(url)
|
||||
if response.status_code == 200:
|
||||
gemini_models = response.json()
|
||||
|
||||
|
||||
filtered_models_list = []
|
||||
for model in gemini_models.get("models", []):
|
||||
model_id = model["name"].split("/")[-1]
|
||||
@@ -28,7 +31,7 @@ class ModelService:
|
||||
filtered_models_list.append(model)
|
||||
else:
|
||||
logger.info(f"Filtered out model: {model_id}")
|
||||
|
||||
|
||||
gemini_models["models"] = filtered_models_list
|
||||
return gemini_models
|
||||
else:
|
||||
@@ -48,7 +51,7 @@ class ModelService:
|
||||
return None
|
||||
|
||||
def convert_to_openai_models_format(
|
||||
self, gemini_models: Dict[str, Any]
|
||||
self, gemini_models: Dict[str, Any]
|
||||
) -> Dict[str, Any]:
|
||||
openai_format = {"object": "list", "data": [], "success": True}
|
||||
|
||||
|
||||
Reference in New Issue
Block a user