refactor: 代码结构优化与常量化

将日志系统从 app/logger/ 移至 app/log/ 目录
将路由配置从 routers.py 重命名为 routes.py
将硬编码配置值移至 constants.py 中的默认常量
统一代码格式和导入排序
优化函数参数对齐方式
This commit is contained in:
snaily
2025-03-20 21:59:18 +08:00
parent b14bb93d8f
commit b3a057b6ba
21 changed files with 442 additions and 282 deletions

View File

@@ -4,7 +4,7 @@
from typing import List
from pydantic_settings import BaseSettings
from app.core.constants import API_VERSION, DEFAULT_MODEL
from app.core.constants import API_VERSION, DEFAULT_CREATE_IMAGE_MODEL, DEFAULT_FILTER_MODELS, DEFAULT_MODEL, DEFAULT_STREAM_CHUNK_SIZE, DEFAULT_STREAM_LONG_TEXT_THRESHOLD, DEFAULT_STREAM_MAX_DELAY, DEFAULT_STREAM_MIN_DELAY, DEFAULT_STREAM_SHORT_TEXT_THRESHOLD
class Settings(BaseSettings):
@@ -20,20 +20,14 @@ class Settings(BaseSettings):
# 模型相关配置
SEARCH_MODELS: List[str] = ["gemini-2.0-flash-exp"]
IMAGE_MODELS: List[str] = ["gemini-2.0-flash-exp"]
FILTERED_MODELS: List[str] = [
"gemini-1.0-pro-vision-latest",
"gemini-pro-vision",
"chat-bison-001",
"text-bison-001",
"embedding-gecko-001"
]
FILTERED_MODELS: List[str] = DEFAULT_FILTER_MODELS
TOOLS_CODE_EXECUTION_ENABLED: bool = False
SHOW_SEARCH_LINK: bool = True
SHOW_THINKING_PROCESS: bool = True
# 图像生成相关配置
PAID_KEY: str = ""
CREATE_IMAGE_MODEL: str = "imagen-3.0-generate-002"
CREATE_IMAGE_MODEL: str = DEFAULT_CREATE_IMAGE_MODEL
UPLOAD_PROVIDER: str = "smms"
SMMS_SECRET_TOKEN: str = ""
PICGO_API_KEY: str = ""
@@ -41,11 +35,11 @@ class Settings(BaseSettings):
CLOUDFLARE_IMGBED_AUTH_CODE: str = ""
# 流式输出优化器配置
STREAM_MIN_DELAY: float = 0.016
STREAM_MAX_DELAY: float = 0.024
STREAM_SHORT_TEXT_THRESHOLD: int = 10
STREAM_LONG_TEXT_THRESHOLD: int = 50
STREAM_CHUNK_SIZE: int = 5
STREAM_MIN_DELAY: float = DEFAULT_STREAM_MIN_DELAY
STREAM_MAX_DELAY: float = DEFAULT_STREAM_MAX_DELAY
STREAM_SHORT_TEXT_THRESHOLD: int = DEFAULT_STREAM_SHORT_TEXT_THRESHOLD
STREAM_LONG_TEXT_THRESHOLD: int = DEFAULT_STREAM_LONG_TEXT_THRESHOLD
STREAM_CHUNK_SIZE: int = DEFAULT_STREAM_CHUNK_SIZE
def __init__(self, **kwargs):
super().__init__(**kwargs)

View File

@@ -6,14 +6,14 @@ from fastapi import FastAPI
from fastapi.staticfiles import StaticFiles
from app.config.config import settings
from app.logger.logger import get_main_logger
from app.log.logger import get_application_logger
from app.middleware.middleware import setup_middlewares
from app.exception.exceptions import setup_exception_handlers
from app.router.routers import setup_routers
from app.router.routes import setup_routers
from app.service.key.key_manager import get_key_manager_instance
from app.core.initialization import initialize_app
logger = get_main_logger()
logger = get_application_logger()
@asynccontextmanager
async def lifespan(app: FastAPI):

View File

@@ -13,12 +13,21 @@ DEFAULT_TEMPERATURE = 0.7
DEFAULT_MAX_TOKENS = 8192
DEFAULT_TOP_P = 0.9
DEFAULT_TOP_K = 40
DEFAULT_FILTER_MODELS = [
"gemini-1.0-pro-vision-latest",
"gemini-pro-vision",
"chat-bison-001",
"text-bison-001",
"embedding-gecko-001"
]
DEFAULT_CREATE_IMAGE_MODEL = "imagen-3.0-generate-002"
# 图像生成相关常量
VALID_IMAGE_RATIOS = ["1:1", "3:4", "4:3", "9:16", "16:9"]
# 上传提供商
UPLOAD_PROVIDERS = ["smms", "picgo", "cloudflare_imgbed"]
DEFAULT_UPLOAD_PROVIDER = "smms"
# 流式输出相关常量
DEFAULT_STREAM_MIN_DELAY = 0.016

View File

@@ -1,11 +1,12 @@
"""
应用程序初始化模块
"""
import logging
from pathlib import Path
from typing import List
logger = logging.getLogger("initialization")
from app.log.logger import get_initialization_logger
logger = get_initialization_logger()
def ensure_directories_exist(directories: List[str]) -> None:

View File

@@ -1,13 +1,17 @@
from fastapi import HTTPException, Header
from typing import Optional
from app.logger.logger import get_security_logger
from fastapi import Header, HTTPException
from app.config.config import settings
from app.log.logger import get_security_logger
logger = get_security_logger()
def verify_auth_token(token: str) -> bool:
return token == settings.AUTH_TOKEN
class SecurityService:
def __init__(self, allowed_tokens: list, auth_token: str):
self.allowed_tokens = allowed_tokens
@@ -20,7 +24,7 @@ class SecurityService:
return key
async def verify_authorization(
self, authorization: Optional[str] = Header(None)
self, authorization: Optional[str] = Header(None)
) -> str:
if not authorization:
logger.error("Missing Authorization header")
@@ -39,19 +43,26 @@ class SecurityService:
return token
async def verify_goog_api_key(self, x_goog_api_key: Optional[str] = Header(None)) -> str:
async def verify_goog_api_key(
self, x_goog_api_key: Optional[str] = Header(None)
) -> str:
"""验证Google API Key"""
if not x_goog_api_key:
logger.error("Missing x-goog-api-key header")
raise HTTPException(status_code=401, detail="Missing x-goog-api-key header")
if x_goog_api_key not in self.allowed_tokens and x_goog_api_key != self.auth_token:
if (
x_goog_api_key not in self.allowed_tokens
and x_goog_api_key != self.auth_token
):
logger.error("Invalid x-goog-api-key")
raise HTTPException(status_code=401, detail="Invalid x-goog-api-key")
return x_goog_api_key
async def verify_auth_token(self, authorization: Optional[str] = Header(None)) -> str:
async def verify_auth_token(
self, authorization: Optional[str] = Header(None)
) -> str:
if not authorization:
logger.error("Missing auth_token header")
raise HTTPException(status_code=401, detail="Missing auth_token header")

View File

@@ -1,18 +1,20 @@
"""
异常处理模块,定义应用程序中使用的自定义异常和异常处理器
"""
from fastapi import Request, FastAPI
from fastapi.responses import JSONResponse
from fastapi import FastAPI, Request
from fastapi.exceptions import RequestValidationError
from fastapi.responses import JSONResponse
from starlette.exceptions import HTTPException as StarletteHTTPException
from app.logger.logger import get_main_logger
from app.log.logger import get_exceptions_logger
logger = get_main_logger()
logger = get_exceptions_logger()
class APIError(Exception):
"""API错误基类"""
def __init__(self, status_code: int, detail: str, error_code: str = None):
self.status_code = status_code
self.detail = detail
@@ -22,90 +24,95 @@ class APIError(Exception):
class AuthenticationError(APIError):
"""认证错误"""
def __init__(self, detail: str = "Authentication failed"):
super().__init__(status_code=401, detail=detail, error_code="authentication_error")
super().__init__(
status_code=401, detail=detail, error_code="authentication_error"
)
class AuthorizationError(APIError):
"""授权错误"""
def __init__(self, detail: str = "Not authorized to access this resource"):
super().__init__(status_code=403, detail=detail, error_code="authorization_error")
super().__init__(
status_code=403, detail=detail, error_code="authorization_error"
)
class ResourceNotFoundError(APIError):
"""资源未找到错误"""
def __init__(self, detail: str = "Resource not found"):
super().__init__(status_code=404, detail=detail, error_code="resource_not_found")
super().__init__(
status_code=404, detail=detail, error_code="resource_not_found"
)
class ModelNotSupportedError(APIError):
"""模型不支持错误"""
def __init__(self, model: str):
super().__init__(
status_code=400,
detail=f"Model {model} is not supported",
error_code="model_not_supported"
status_code=400,
detail=f"Model {model} is not supported",
error_code="model_not_supported",
)
class APIKeyError(APIError):
"""API密钥错误"""
def __init__(self, detail: str = "Invalid or expired API key"):
super().__init__(status_code=401, detail=detail, error_code="api_key_error")
class ServiceUnavailableError(APIError):
"""服务不可用错误"""
def __init__(self, detail: str = "Service temporarily unavailable"):
super().__init__(status_code=503, detail=detail, error_code="service_unavailable")
super().__init__(
status_code=503, detail=detail, error_code="service_unavailable"
)
def setup_exception_handlers(app: FastAPI) -> None:
"""
设置应用程序的异常处理器
Args:
app: FastAPI应用程序实例
"""
@app.exception_handler(APIError)
async def api_error_handler(request: Request, exc: APIError):
"""处理API错误"""
logger.error(f"API Error: {exc.detail} (Code: {exc.error_code})")
return JSONResponse(
status_code=exc.status_code,
content={
"error": {
"code": exc.error_code,
"message": exc.detail
}
}
content={"error": {"code": exc.error_code, "message": exc.detail}},
)
@app.exception_handler(StarletteHTTPException)
async def http_exception_handler(request: Request, exc: StarletteHTTPException):
"""处理HTTP异常"""
logger.error(f"HTTP Exception: {exc.detail} (Status: {exc.status_code})")
return JSONResponse(
status_code=exc.status_code,
content={
"error": {
"code": "http_error",
"message": exc.detail
}
}
content={"error": {"code": "http_error", "message": exc.detail}},
)
@app.exception_handler(RequestValidationError)
async def validation_exception_handler(request: Request, exc: RequestValidationError):
async def validation_exception_handler(
request: Request, exc: RequestValidationError
):
"""处理请求验证错误"""
error_details = []
for error in exc.errors():
error_details.append({
"loc": error["loc"],
"msg": error["msg"],
"type": error["type"]
})
error_details.append(
{"loc": error["loc"], "msg": error["msg"], "type": error["type"]}
)
logger.error(f"Validation Error: {error_details}")
return JSONResponse(
status_code=422,
@@ -113,11 +120,11 @@ def setup_exception_handlers(app: FastAPI) -> None:
"error": {
"code": "validation_error",
"message": "Request validation failed",
"details": error_details
"details": error_details,
}
}
},
)
@app.exception_handler(Exception)
async def general_exception_handler(request: Request, exc: Exception):
"""处理通用异常"""
@@ -127,7 +134,7 @@ def setup_exception_handlers(app: FastAPI) -> None:
content={
"error": {
"code": "internal_server_error",
"message": "An unexpected error occurred"
"message": "An unexpected error occurred",
}
}
},
)

View File

@@ -1,10 +1,11 @@
# app/services/chat/retry_handler.py
from typing import TypeVar, Callable
from functools import wraps
from app.logger.logger import get_retry_logger
from typing import Callable, TypeVar
T = TypeVar('T')
from app.log.logger import get_retry_logger
T = TypeVar("T")
logger = get_retry_logger()
@@ -25,17 +26,21 @@ class RetryHandler:
return await func(*args, **kwargs)
except Exception as e:
last_exception = e
logger.warning(f"API call failed with error: {str(e)}. Attempt {attempt + 1} of {self.max_retries}")
logger.warning(
f"API call failed with error: {str(e)}. Attempt {attempt + 1} of {self.max_retries}"
)
# 从函数参数中获取 key_manager
key_manager = kwargs.get('key_manager')
key_manager = kwargs.get("key_manager")
if key_manager:
old_key = kwargs.get(self.key_arg)
new_key = await key_manager.handle_api_failure(old_key)
kwargs[self.key_arg] = new_key
logger.info(f"Switched to new API key: {new_key}")
logger.error(f"All retry attempts failed, raising final exception: {str(last_exception)}")
logger.error(
f"All retry attempts failed, raising final exception: {str(last_exception)}"
)
raise last_exception
return wrapper

View File

@@ -2,10 +2,17 @@
import asyncio
import math
from typing import Any, List, AsyncGenerator, Callable
from app.logger.logger import get_openai_logger, get_gemini_logger
from typing import Any, AsyncGenerator, Callable, List
from app.config.config import settings
from app.core.constants import DEFAULT_STREAM_CHUNK_SIZE, DEFAULT_STREAM_LONG_TEXT_THRESHOLD, DEFAULT_STREAM_MAX_DELAY, DEFAULT_STREAM_MIN_DELAY, DEFAULT_STREAM_SHORT_TEXT_THRESHOLD
from app.core.constants import (
DEFAULT_STREAM_CHUNK_SIZE,
DEFAULT_STREAM_LONG_TEXT_THRESHOLD,
DEFAULT_STREAM_MAX_DELAY,
DEFAULT_STREAM_MIN_DELAY,
DEFAULT_STREAM_SHORT_TEXT_THRESHOLD,
)
from app.log.logger import get_gemini_logger, get_openai_logger
logger_openai = get_openai_logger()
logger_gemini = get_gemini_logger()
@@ -13,19 +20,21 @@ logger_gemini = get_gemini_logger()
class StreamOptimizer:
"""流式输出优化器
提供流式输出优化功能,包括智能延迟调整和长文本分块输出。
"""
def __init__(self,
logger=None,
min_delay: float = DEFAULT_STREAM_MIN_DELAY,
max_delay: float = DEFAULT_STREAM_MAX_DELAY,
short_text_threshold: int = DEFAULT_STREAM_SHORT_TEXT_THRESHOLD,
long_text_threshold: int = DEFAULT_STREAM_LONG_TEXT_THRESHOLD,
chunk_size: int = DEFAULT_STREAM_CHUNK_SIZE):
def __init__(
self,
logger=None,
min_delay: float = DEFAULT_STREAM_MIN_DELAY,
max_delay: float = DEFAULT_STREAM_MAX_DELAY,
short_text_threshold: int = DEFAULT_STREAM_SHORT_TEXT_THRESHOLD,
long_text_threshold: int = DEFAULT_STREAM_LONG_TEXT_THRESHOLD,
chunk_size: int = DEFAULT_STREAM_CHUNK_SIZE,
):
"""初始化流式输出优化器
参数:
logger: 日志记录器
min_delay: 最小延迟时间(秒)
@@ -40,13 +49,13 @@ class StreamOptimizer:
self.short_text_threshold = short_text_threshold
self.long_text_threshold = long_text_threshold
self.chunk_size = chunk_size
def calculate_delay(self, text_length: int) -> float:
"""根据文本长度计算延迟时间
参数:
text_length: 文本长度
返回:
延迟时间(秒)
"""
@@ -59,42 +68,48 @@ class StreamOptimizer:
else:
# 中等长度文本使用线性插值计算延迟
# 使用对数函数使延迟变化更平滑
ratio = math.log(text_length / self.short_text_threshold) / math.log(self.long_text_threshold / self.short_text_threshold)
ratio = math.log(text_length / self.short_text_threshold) / math.log(
self.long_text_threshold / self.short_text_threshold
)
return self.max_delay - ratio * (self.max_delay - self.min_delay)
def split_text_into_chunks(self, text: str) -> List[str]:
"""将文本分割成小块
参数:
text: 要分割的文本
返回:
文本块列表
"""
return [text[i:i+self.chunk_size] for i in range(0, len(text), self.chunk_size)]
async def optimize_stream_output(self,
text: str,
create_response_chunk: Callable[[str], Any],
format_chunk: Callable[[Any], str]) -> AsyncGenerator[str, None]:
return [
text[i : i + self.chunk_size] for i in range(0, len(text), self.chunk_size)
]
async def optimize_stream_output(
self,
text: str,
create_response_chunk: Callable[[str], Any],
format_chunk: Callable[[Any], str],
) -> AsyncGenerator[str, None]:
"""优化流式输出
参数:
text: 要输出的文本
create_response_chunk: 创建响应块的函数,接收文本,返回响应块
format_chunk: 格式化响应块的函数,接收响应块,返回格式化后的字符串
返回:
异步生成器,生成格式化后的响应块
"""
if not text:
return
# 计算智能延迟时间
delay = self.calculate_delay(len(text))
if self.logger:
self.logger.info(f"Text length: {len(text)}, delay: {delay:.4f}s")
# 根据文本长度决定输出方式
if len(text) >= self.long_text_threshold:
# 长文本:分块输出
@@ -120,7 +135,7 @@ openai_optimizer = StreamOptimizer(
max_delay=settings.STREAM_MAX_DELAY,
short_text_threshold=settings.STREAM_SHORT_TEXT_THRESHOLD,
long_text_threshold=settings.STREAM_LONG_TEXT_THRESHOLD,
chunk_size=settings.STREAM_CHUNK_SIZE
chunk_size=settings.STREAM_CHUNK_SIZE,
)
gemini_optimizer = StreamOptimizer(
@@ -129,5 +144,5 @@ gemini_optimizer = StreamOptimizer(
max_delay=settings.STREAM_MAX_DELAY,
short_text_threshold=settings.STREAM_SHORT_TEXT_THRESHOLD,
long_text_threshold=settings.STREAM_LONG_TEXT_THRESHOLD,
chunk_size=settings.STREAM_CHUNK_SIZE
chunk_size=settings.STREAM_CHUNK_SIZE,
)

View File

@@ -133,3 +133,23 @@ def get_retry_logger():
def get_image_create_logger():
return Logger.setup_logger("image_create")
def get_exceptions_logger():
return Logger.setup_logger("exceptions")
def get_application_logger():
return Logger.setup_logger("application")
def get_initialization_logger():
return Logger.setup_logger("initialization")
def get_middleware_logger():
return Logger.setup_logger("middleware")
def get_routes_logger():
return Logger.setup_logger("routes")

View File

@@ -1,9 +1,11 @@
"""
应用程序入口模块
"""
import uvicorn
from app.core.application import create_app
from app.logger.logger import get_main_logger
from app.log.logger import get_main_logger
# 创建应用程序实例
app = create_app()

View File

@@ -1,60 +1,72 @@
"""
中间件配置模块,负责设置和配置应用程序的中间件
"""
from fastapi import FastAPI, Request
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import RedirectResponse
from starlette.middleware.base import BaseHTTPMiddleware
from app.logger.logger import get_main_logger
from app.core.security import verify_auth_token
# from app.middleware.request_logging_middleware import RequestLoggingMiddleware
from app.core.constants import API_VERSION
from app.core.security import verify_auth_token
from app.log.logger import get_middleware_logger
logger = get_middleware_logger()
logger = get_main_logger()
class AuthMiddleware(BaseHTTPMiddleware):
"""
认证中间件,处理未经身份验证的请求
"""
async def dispatch(self, request: Request, call_next):
# 允许特定路径绕过身份验证
if (request.url.path not in ["/", "/auth"] and
not request.url.path.startswith("/static") and
not request.url.path.startswith("/gemini") and
not request.url.path.startswith("/v1") and
not request.url.path.startswith(f"/{API_VERSION}") and
not request.url.path.startswith("/health") and
not request.url.path.startswith("/hf")):
if (
request.url.path not in ["/", "/auth"]
and not request.url.path.startswith("/static")
and not request.url.path.startswith("/gemini")
and not request.url.path.startswith("/v1")
and not request.url.path.startswith(f"/{API_VERSION}")
and not request.url.path.startswith("/health")
and not request.url.path.startswith("/hf")
):
auth_token = request.cookies.get("auth_token")
if not auth_token or not verify_auth_token(auth_token):
logger.warning(f"Unauthorized access attempt to {request.url.path}")
return RedirectResponse(url="/")
logger.debug("Request authenticated successfully")
response = await call_next(request)
return response
def setup_middlewares(app: FastAPI) -> None:
"""
设置应用程序的中间件
Args:
app: FastAPI应用程序实例
"""
# 添加认证中间件
app.add_middleware(AuthMiddleware)
# 添加请求日志中间件(可选,默认注释掉)
# app.add_middleware(RequestLoggingMiddleware)
# 配置CORS中间件
app.add_middleware(
CORSMiddleware,
allow_origins=["*"], # 生产环境建议配置具体的域名
allow_credentials=True,
allow_methods=["GET", "POST", "PUT", "DELETE", "OPTIONS"], # 明确指定允许的HTTP方法
allow_methods=[
"GET",
"POST",
"PUT",
"DELETE",
"OPTIONS",
], # 明确指定允许的HTTP方法
allow_headers=["*"], # 生产环境建议配置具体的请求头
expose_headers=["*"], # 允许前端访问的响应头
max_age=600, # 预检请求缓存时间(秒)

View File

@@ -1,7 +1,9 @@
import json
from fastapi import Request
from starlette.middleware.base import BaseHTTPMiddleware
import json
from app.logger.logger import get_request_logger
from app.log.logger import get_request_logger
logger = get_request_logger()
@@ -20,7 +22,9 @@ class RequestLoggingMiddleware(BaseHTTPMiddleware):
# 尝试格式化JSON
try:
formatted_body = json.loads(body_str)
logger.info(f"Formatted request body:\n{json.dumps(formatted_body, indent=2, ensure_ascii=False)}")
logger.info(
f"Formatted request body:\n{json.dumps(formatted_body, indent=2, ensure_ascii=False)}"
)
except json.JSONDecodeError:
logger.info("Request body is not valid JSON.")
except Exception as e:

View File

@@ -2,7 +2,7 @@ from fastapi import APIRouter, Depends, HTTPException
from fastapi.responses import StreamingResponse, JSONResponse
from copy import deepcopy
from app.config.config import settings
from app.logger.logger import get_gemini_logger
from app.log.logger import get_gemini_logger
from app.core.security import SecurityService
from app.domain.gemini_models import GeminiContent, GeminiRequest
from app.service.chat.gemini_chat_service import GeminiChatService

View File

@@ -1,37 +1,46 @@
from fastapi import HTTPException, APIRouter, Depends
from fastapi import APIRouter, Depends, HTTPException
from fastapi.responses import StreamingResponse
from app.config.config import settings
from app.logger.logger import get_openai_logger
from app.core.security import SecurityService
from app.domain.openai_models import ChatRequest, EmbeddingRequest, ImageGenerationRequest
from app.domain.openai_models import (
ChatRequest,
EmbeddingRequest,
ImageGenerationRequest,
)
from app.handler.retry_handler import RetryHandler
from app.log.logger import get_openai_logger
from app.service.chat.openai_chat_service import OpenAIChatService
from app.service.embedding.embedding_service import EmbeddingService
from app.service.image.image_create_service import ImageCreateService
from app.service.key.key_manager import KeyManager, get_key_manager_instance
from app.service.model.model_service import ModelService
from app.service.chat.openai_chat_service import OpenAIChatService
router = APIRouter()
logger = get_openai_logger()
# 初始化服务
security_service = SecurityService(settings.ALLOWED_TOKENS, settings.AUTH_TOKEN)
model_service = ModelService(settings.SEARCH_MODELS,settings.IMAGE_MODELS)
model_service = ModelService(settings.SEARCH_MODELS, settings.IMAGE_MODELS)
embedding_service = EmbeddingService(settings.BASE_URL)
image_create_service = ImageCreateService()
async def get_key_manager():
return await get_key_manager_instance()
async def get_next_working_key_wrapper(key_manager: KeyManager = Depends(get_key_manager)):
async def get_next_working_key_wrapper(
key_manager: KeyManager = Depends(get_key_manager),
):
return await key_manager.get_next_working_key()
@router.get("/v1/models")
@router.get("/hf/v1/models")
async def list_models(
_=Depends(security_service.verify_authorization),
key_manager: KeyManager = Depends(get_key_manager)
key_manager: KeyManager = Depends(get_key_manager),
):
logger.info("-" * 50 + "list_models" + "-" * 50)
logger.info("Handling models list request")
@@ -41,7 +50,9 @@ async def list_models(
return model_service.get_gemini_openai_models(api_key)
except Exception as e:
logger.error(f"Error getting models list: {str(e)}")
raise HTTPException(status_code=500, detail="Internal server error while fetching models list") from e
raise HTTPException(
status_code=500, detail="Internal server error while fetching models list"
) from e
@router.post("/v1/chat/completions")
@@ -51,7 +62,7 @@ async def chat_completion(
request: ChatRequest,
_=Depends(security_service.verify_authorization),
api_key: str = Depends(get_next_working_key_wrapper),
key_manager: KeyManager = Depends(get_key_manager)
key_manager: KeyManager = Depends(get_key_manager),
):
# 如果model是imagen3,使用paid_key
if request.model == f"{settings.CREATE_IMAGE_MODEL}-chat":
@@ -63,8 +74,10 @@ async def chat_completion(
logger.info(f"Using API key: {api_key}")
if not model_service.check_model_support(request.model):
raise HTTPException(status_code=400, detail=f"Model {request.model} is not supported")
raise HTTPException(
status_code=400, detail=f"Model {request.model} is not supported"
)
try:
# 如果model是imagen3,使用paid_key
if request.model == f"{settings.CREATE_IMAGE_MODEL}-chat":
@@ -80,6 +93,7 @@ async def chat_completion(
logger.error(f"Chat completion failed after retries: {str(e)}")
raise HTTPException(status_code=500, detail="Chat completion failed") from e
@router.post("/v1/images/generations")
@router.post("/hf/v1/images/generations")
async def generate_image(
@@ -95,14 +109,17 @@ async def generate_image(
return response
except Exception as e:
logger.error(f"Image generation request failed: {str(e)}")
raise HTTPException(status_code=500, detail="Image generation request failed") from e
raise HTTPException(
status_code=500, detail="Image generation request failed"
) from e
@router.post("/v1/embeddings")
@router.post("/hf/v1/embeddings")
async def embedding(
request: EmbeddingRequest,
_=Depends(security_service.verify_authorization),
key_manager: KeyManager = Depends(get_key_manager)
key_manager: KeyManager = Depends(get_key_manager),
):
logger.info("-" * 50 + "embedding" + "-" * 50)
logger.info(f"Handling embedding request for model: {request.model}")
@@ -118,11 +135,12 @@ async def embedding(
logger.error(f"Embedding request failed: {str(e)}")
raise HTTPException(status_code=500, detail="Embedding request failed") from e
@router.get("/v1/keys/list")
@router.get("/hf/v1/keys/list")
async def get_keys_list(
_=Depends(security_service.verify_auth_token),
key_manager: KeyManager = Depends(get_key_manager)
key_manager: KeyManager = Depends(get_key_manager),
):
"""获取有效和无效的API key列表"""
logger.info("-" * 50 + "get_keys_list" + "-" * 50)
@@ -133,13 +151,12 @@ async def get_keys_list(
"status": "success",
"data": {
"valid_keys": keys_status["valid_keys"],
"invalid_keys": keys_status["invalid_keys"]
"invalid_keys": keys_status["invalid_keys"],
},
"total": len(keys_status["valid_keys"]) + len(keys_status["invalid_keys"])
"total": len(keys_status["valid_keys"]) + len(keys_status["invalid_keys"]),
}
except Exception as e:
logger.error(f"Error getting keys list: {str(e)}")
raise HTTPException(
status_code=500,
detail="Internal server error while fetching keys list"
status_code=500, detail="Internal server error while fetching keys list"
) from e

View File

@@ -1,24 +1,26 @@
"""
路由配置模块负责设置和配置应用程序的路由
"""
from fastapi import FastAPI, Request
from fastapi.responses import HTMLResponse, RedirectResponse
from fastapi.templating import Jinja2Templates
from app.logger.logger import get_main_logger
from app.core.security import verify_auth_token
from app.log.logger import get_routes_logger
from app.router import gemini_routes, openai_routes
from app.service.key.key_manager import get_key_manager_instance
logger = get_main_logger()
logger = get_routes_logger()
# 配置Jinja2模板
templates = Jinja2Templates(directory="app/templates")
def setup_routers(app: FastAPI) -> None:
"""
设置应用程序的路由
Args:
app: FastAPI应用程序实例
"""
@@ -26,20 +28,22 @@ def setup_routers(app: FastAPI) -> None:
app.include_router(openai_routes.router)
app.include_router(gemini_routes.router)
app.include_router(gemini_routes.router_v1beta)
# 添加页面路由
setup_page_routes(app)
# 添加健康检查路由
setup_health_routes(app)
def setup_page_routes(app: FastAPI) -> None:
"""
设置页面相关的路由
Args:
app: FastAPI应用程序实例
"""
@app.get("/", response_class=HTMLResponse)
async def auth_page(request: Request):
"""认证页面"""
@@ -54,11 +58,13 @@ def setup_page_routes(app: FastAPI) -> None:
if not auth_token:
logger.warning("Authentication attempt with empty token")
return RedirectResponse(url="/", status_code=302)
if verify_auth_token(auth_token):
logger.info("Successful authentication")
response = RedirectResponse(url="/keys", status_code=302)
response.set_cookie(key="auth_token", value=auth_token, httponly=True, max_age=3600)
response.set_cookie(
key="auth_token", value=auth_token, httponly=True, max_age=3600
)
return response
logger.warning("Failed authentication attempt with invalid token")
return RedirectResponse(url="/", status_code=302)
@@ -74,28 +80,33 @@ def setup_page_routes(app: FastAPI) -> None:
if not auth_token or not verify_auth_token(auth_token):
logger.warning("Unauthorized access attempt to keys page")
return RedirectResponse(url="/", status_code=302)
key_manager = await get_key_manager_instance()
keys_status = await key_manager.get_keys_by_status()
total = len(keys_status["valid_keys"]) + len(keys_status["invalid_keys"])
logger.info(f"Keys status retrieved successfully. Total keys: {total}")
return templates.TemplateResponse("keys_status.html", {
"request": request,
"valid_keys": keys_status["valid_keys"],
"invalid_keys": keys_status["invalid_keys"],
"total": total
})
return templates.TemplateResponse(
"keys_status.html",
{
"request": request,
"valid_keys": keys_status["valid_keys"],
"invalid_keys": keys_status["invalid_keys"],
"total": total,
},
)
except Exception as e:
logger.error(f"Error retrieving keys status: {str(e)}")
raise
def setup_health_routes(app: FastAPI) -> None:
"""
设置健康检查相关的路由
Args:
app: FastAPI应用程序实例
"""
@app.get("/health")
async def health_check(request: Request):
"""健康检查端点"""

View File

@@ -1,13 +1,14 @@
# app/services/chat_service.py
import json
from typing import Dict, Any, AsyncGenerator, List
from app.logger.logger import get_gemini_logger
from app.service.client.api_client import GeminiApiClient
from app.handler.stream_optimizer import gemini_optimizer
from app.domain.gemini_models import GeminiRequest
from typing import Any, AsyncGenerator, Dict, List
from app.config.config import settings
from app.domain.gemini_models import GeminiRequest
from app.handler.response_handler import GeminiResponseHandler
from app.handler.stream_optimizer import gemini_optimizer
from app.log.logger import get_gemini_logger
from app.service.client.api_client import GeminiApiClient
from app.service.key.key_manager import KeyManager
logger = get_gemini_logger()
@@ -26,9 +27,11 @@ def _has_image_parts(contents: List[Dict[str, Any]]) -> bool:
def _build_tools(model: str, payload: Dict[str, Any]) -> List[Dict[str, Any]]:
"""构建工具"""
tools = []
if settings.TOOLS_CODE_EXECUTION_ENABLED and not (
model.endswith("-search") or "-thinking" in model
) and not _has_image_parts(payload.get("contents", [])):
if (
settings.TOOLS_CODE_EXECUTION_ENABLED
and not (model.endswith("-search") or "-thinking" in model)
and not _has_image_parts(payload.get("contents", []))
):
tools.append({"code_execution": {}})
if model.endswith("-search"):
tools.append({"googleSearch": {}})
@@ -49,14 +52,14 @@ def _get_safety_settings(model: str) -> List[Dict[str, str]]:
{"category": "HARM_CATEGORY_HATE_SPEECH", "threshold": "OFF"},
{"category": "HARM_CATEGORY_SEXUALLY_EXPLICIT", "threshold": "OFF"},
{"category": "HARM_CATEGORY_DANGEROUS_CONTENT", "threshold": "OFF"},
{"category": "HARM_CATEGORY_CIVIC_INTEGRITY", "threshold": "OFF"}
{"category": "HARM_CATEGORY_CIVIC_INTEGRITY", "threshold": "OFF"},
]
return [
{"category": "HARM_CATEGORY_HARASSMENT", "threshold": "BLOCK_NONE"},
{"category": "HARM_CATEGORY_HATE_SPEECH", "threshold": "BLOCK_NONE"},
{"category": "HARM_CATEGORY_SEXUALLY_EXPLICIT", "threshold": "BLOCK_NONE"},
{"category": "HARM_CATEGORY_DANGEROUS_CONTENT", "threshold": "BLOCK_NONE"},
{"category": "HARM_CATEGORY_CIVIC_INTEGRITY", "threshold": "BLOCK_NONE"}
{"category": "HARM_CATEGORY_CIVIC_INTEGRITY", "threshold": "BLOCK_NONE"},
]
@@ -68,12 +71,12 @@ def _build_payload(model: str, request: GeminiRequest) -> Dict[str, Any]:
"tools": _build_tools(model, request_dict),
"safetySettings": _get_safety_settings(model),
"generationConfig": request_dict.get("generationConfig", {}),
"systemInstruction": request_dict.get("systemInstruction", "")
"systemInstruction": request_dict.get("systemInstruction", ""),
}
if model.endswith("-image") or model.endswith("-image-generation"):
payload.pop("systemInstruction")
payload["generationConfig"]["responseModalities"] = ["Text","Image"]
payload["generationConfig"]["responseModalities"] = ["Text", "Image"]
return payload
@@ -84,54 +87,68 @@ class GeminiChatService:
self.api_client = GeminiApiClient(base_url)
self.key_manager = key_manager
self.response_handler = GeminiResponseHandler()
def _extract_text_from_response(self, response: Dict[str, Any]) -> str:
"""从响应中提取文本内容"""
if not response.get("candidates"):
return ""
candidate = response["candidates"][0]
content = candidate.get("content", {})
parts = content.get("parts", [])
if parts and "text" in parts[0]:
return parts[0].get("text", "")
return ""
def _create_char_response(self, original_response: Dict[str, Any], text: str) -> Dict[str, Any]:
def _create_char_response(
self, original_response: Dict[str, Any], text: str
) -> Dict[str, Any]:
"""创建包含指定文本的响应"""
response_copy = json.loads(json.dumps(original_response)) # 深拷贝
if response_copy.get("candidates") and response_copy["candidates"][0].get("content", {}).get("parts"):
if response_copy.get("candidates") and response_copy["candidates"][0].get(
"content", {}
).get("parts"):
response_copy["candidates"][0]["content"]["parts"][0]["text"] = text
return response_copy
async def generate_content(self, model: str, request: GeminiRequest, api_key: str) -> Dict[str, Any]:
async def generate_content(
self, model: str, request: GeminiRequest, api_key: str
) -> Dict[str, Any]:
"""生成内容"""
payload = _build_payload(model, request)
response = await self.api_client.generate_content(payload, model, api_key)
return self.response_handler.handle_response(response, model, stream=False)
async def stream_generate_content(self, model: str, request: GeminiRequest, api_key: str) -> AsyncGenerator[str, None]:
async def stream_generate_content(
self, model: str, request: GeminiRequest, api_key: str
) -> AsyncGenerator[str, None]:
"""流式生成内容"""
retries = 0
max_retries = 3
payload = _build_payload(model, request)
while retries < max_retries:
try:
async for line in self.api_client.stream_generate_content(payload, model, api_key):
async for line in self.api_client.stream_generate_content(
payload, model, api_key
):
# print(line)
if line.startswith("data:"):
line = line[6:]
response_data = self.response_handler.handle_response(json.loads(line), model, stream=True)
response_data = self.response_handler.handle_response(
json.loads(line), model, stream=True
)
text = self._extract_text_from_response(response_data)
# 如果有文本内容,使用流式输出优化器处理
if text:
# 使用流式输出优化器处理文本输出
async for optimized_chunk in gemini_optimizer.optimize_stream_output(
async for (
optimized_chunk
) in gemini_optimizer.optimize_stream_output(
text,
lambda t: self._create_char_response(response_data, t),
lambda c: "data: " + json.dumps(c) + "\n\n"
lambda c: "data: " + json.dumps(c) + "\n\n",
):
yield optimized_chunk
else:
@@ -141,9 +158,13 @@ class GeminiChatService:
break
except Exception as e:
retries += 1
logger.warning(f"Streaming API call failed with error: {str(e)}. Attempt {retries} of {max_retries}")
logger.warning(
f"Streaming API call failed with error: {str(e)}. Attempt {retries} of {max_retries}"
)
api_key = await self.key_manager.handle_api_failure(api_key)
logger.info(f"Switched to new API key: {api_key}")
if retries >= max_retries:
logger.error(f"Max retries ({max_retries}) reached for streaming. Raising error")
logger.error(
f"Max retries ({max_retries}) reached for streaming. Raising error"
)
break

View File

@@ -1,15 +1,16 @@
# app/services/chat_service.py
from copy import deepcopy
import json
from typing import Dict, Any, AsyncGenerator, List, Optional, Union
from app.logger.logger import get_openai_logger
from copy import deepcopy
from typing import Any, AsyncGenerator, Dict, List, Optional, Union
from app.config.config import settings
from app.domain.openai_models import ChatRequest, ImageGenerationRequest
from app.handler.message_converter import OpenAIMessageConverter
from app.handler.response_handler import OpenAIResponseHandler
from app.service.client.api_client import GeminiApiClient
from app.handler.stream_optimizer import openai_optimizer
from app.domain.openai_models import ChatRequest, ImageGenerationRequest
from app.config.config import settings
from app.log.logger import get_openai_logger
from app.service.client.api_client import GeminiApiClient
from app.service.image.image_create_service import ImageCreateService
from app.service.key.key_manager import KeyManager
@@ -27,16 +28,21 @@ def _has_image_parts(contents: List[Dict[str, Any]]) -> bool:
def _build_tools(
request: ChatRequest, messages: List[Dict[str, Any]]
request: ChatRequest, messages: List[Dict[str, Any]]
) -> List[Dict[str, Any]]:
"""构建工具"""
tools = []
model = request.model
if (
settings.TOOLS_CODE_EXECUTION_ENABLED
and not (model.endswith("-search") or "-thinking" in model or model.endswith("-image") or model.endswith("-image-generation"))
and not _has_image_parts(messages)
settings.TOOLS_CODE_EXECUTION_ENABLED
and not (
model.endswith("-search")
or "-thinking" in model
or model.endswith("-image")
or model.endswith("-image-generation")
)
and not _has_image_parts(messages)
):
tools.append({"code_execution": {}})
if model.endswith("-search"):
@@ -52,7 +58,9 @@ def _build_tools(
if tool.get("type", "") == "function" and tool.get("function"):
function = deepcopy(tool.get("function"))
parameters = function.get("parameters", {})
if parameters.get("type") == "object" and not parameters.get("properties", {}):
if parameters.get("type") == "object" and not parameters.get(
"properties", {}
):
function.pop("parameters", None)
function_declarations.append(function)
@@ -66,7 +74,7 @@ def _build_tools(
functions.append(item)
tools.append({"functionDeclarations": functions})
return tools
@@ -95,7 +103,9 @@ def _get_safety_settings(model: str) -> List[Dict[str, str]]:
def _build_payload(
request: ChatRequest, messages: List[Dict[str, Any]], instruction: Optional[Dict[str, Any]] = None
request: ChatRequest,
messages: List[Dict[str, Any]],
instruction: Optional[Dict[str, Any]] = None,
) -> Dict[str, Any]:
"""构建请求payload"""
payload = {
@@ -111,8 +121,8 @@ def _build_payload(
"safetySettings": _get_safety_settings(request.model),
}
if request.model.endswith("-image") or request.model.endswith("-image-generation"):
payload["generationConfig"]["responseModalities"] = ["Text","Image"]
payload["generationConfig"]["responseModalities"] = ["Text", "Image"]
if (
instruction
and isinstance(instruction, dict)
@@ -128,24 +138,27 @@ def _build_payload(
class OpenAIChatService:
"""聊天服务"""
def __init__(self, base_url: str, key_manager: KeyManager = None):
self.message_converter = OpenAIMessageConverter()
self.response_handler = OpenAIResponseHandler(config=None)
self.api_client = GeminiApiClient(base_url)
self.key_manager = key_manager
self.image_create_service = ImageCreateService()
def _extract_text_from_openai_chunk(self, chunk: Dict[str, Any]) -> str:
"""从OpenAI响应块中提取文本内容"""
if not chunk.get("choices"):
return ""
choice = chunk["choices"][0]
if "delta" in choice and "content" in choice["delta"]:
return choice["delta"]["content"]
return ""
def _create_char_openai_chunk(self, original_chunk: Dict[str, Any], text: str) -> Dict[str, Any]:
def _create_char_openai_chunk(
self, original_chunk: Dict[str, Any], text: str
) -> Dict[str, Any]:
"""创建包含指定文本的OpenAI响应块"""
chunk_copy = json.loads(json.dumps(original_chunk)) # 深拷贝
if chunk_copy.get("choices") and "delta" in chunk_copy["choices"][0]:
@@ -153,9 +166,9 @@ class OpenAIChatService:
return chunk_copy
async def create_chat_completion(
self,
request: ChatRequest,
api_key: str,
self,
request: ChatRequest,
api_key: str,
) -> Union[Dict[str, Any], AsyncGenerator[str, None]]:
"""创建聊天完成"""
# 转换消息格式
@@ -169,7 +182,7 @@ class OpenAIChatService:
return await self._handle_normal_completion(request.model, payload, api_key)
async def _handle_normal_completion(
self, model: str, payload: Dict[str, Any], api_key: str
self, model: str, payload: Dict[str, Any], api_key: str
) -> Dict[str, Any]:
"""处理普通聊天完成"""
response = await self.api_client.generate_content(payload, model, api_key)
@@ -178,7 +191,7 @@ class OpenAIChatService:
)
async def _handle_stream_completion(
self, model: str, payload: Dict[str, Any], api_key: str
self, model: str, payload: Dict[str, Any], api_key: str
) -> AsyncGenerator[str, None]:
"""处理流式聊天完成,添加重试逻辑"""
retries = 0
@@ -186,7 +199,7 @@ class OpenAIChatService:
while retries < max_retries:
try:
async for line in self.api_client.stream_generate_content(
payload, model, api_key
payload, model, api_key
):
# print(line)
if line.startswith("data:"):
@@ -199,10 +212,14 @@ class OpenAIChatService:
text = self._extract_text_from_openai_chunk(openai_chunk)
if text:
# 使用流式输出优化器处理文本输出
async for optimized_chunk in openai_optimizer.optimize_stream_output(
async for (
optimized_chunk
) in openai_optimizer.optimize_stream_output(
text,
lambda t: self._create_char_openai_chunk(openai_chunk, t),
lambda c: f"data: {json.dumps(c)}\n\n"
lambda t: self._create_char_openai_chunk(
openai_chunk, t
),
lambda c: f"data: {json.dumps(c)}\n\n",
):
yield optimized_chunk
else:
@@ -228,21 +245,23 @@ class OpenAIChatService:
break
async def create_image_chat_completion(
self,
request: ChatRequest,
self,
request: ChatRequest,
) -> Union[Dict[str, Any], AsyncGenerator[str, None]]:
image_generate_request = ImageGenerationRequest()
image_generate_request.prompt = request.messages[-1]["content"]
image_res = self.image_create_service.generate_images_chat(image_generate_request)
image_res = self.image_create_service.generate_images_chat(
image_generate_request
)
if request.stream:
return self._handle_stream_image_completion(request.model,image_res)
return self._handle_stream_image_completion(request.model, image_res)
else:
return self._handle_normal_image_completion(request.model,image_res)
return self._handle_normal_image_completion(request.model, image_res)
async def _handle_stream_image_completion(
self, model: str, image_data: str
self, model: str, image_data: str
) -> AsyncGenerator[str, None]:
if image_data:
openai_chunk = self.response_handler.handle_image_chat_response(
@@ -253,10 +272,12 @@ class OpenAIChatService:
text = self._extract_text_from_openai_chunk(openai_chunk)
if text:
# 使用流式输出优化器处理文本输出
async for optimized_chunk in openai_optimizer.optimize_stream_output(
async for (
optimized_chunk
) in openai_optimizer.optimize_stream_output(
text,
lambda t: self._create_char_openai_chunk(openai_chunk, t),
lambda c: f"data: {json.dumps(c)}\n\n"
lambda c: f"data: {json.dumps(c)}\n\n",
):
yield optimized_chunk
else:
@@ -265,11 +286,11 @@ class OpenAIChatService:
yield f"data: {json.dumps(self.response_handler.handle_response({}, model, stream=True, finish_reason='stop'))}\n\n"
yield "data: [DONE]\n\n"
logger.info("Image chat streaming completed successfully")
def _handle_normal_image_completion(
self, model: str, image_data: str
self, model: str, image_data: str
) -> Dict[str, Any]:
return self.response_handler.handle_image_chat_response(
image_data, model, stream=False, finish_reason="stop"
)

View File

@@ -1,9 +1,9 @@
from typing import Union, List
from typing import List, Union
import openai
from openai.types import CreateEmbeddingResponse
from app.logger.logger import get_embeddings_logger
from app.log.logger import get_embeddings_logger
logger = get_embeddings_logger()

View File

@@ -1,15 +1,15 @@
import base64
import time
import uuid
from google import genai
from google.genai import types
import base64
from app.config.config import settings
from app.logger.logger import get_image_create_logger
from app.utils.uploader import ImageUploaderFactory
from app.domain.openai_models import ImageGenerationRequest
from app.core.constants import VALID_IMAGE_RATIOS
from app.domain.openai_models import ImageGenerationRequest
from app.log.logger import get_image_create_logger
from app.utils.uploader import ImageUploaderFactory
logger = get_image_create_logger()
@@ -27,34 +27,34 @@ class ImageCreateService:
- {ratio:比例} 例如: {ratio:16:9} 使用16:9比例
"""
import re
# 默认值
n = 1
aspect_ratio = self.aspect_ratio
# 解析n参数
n_match = re.search(r'{n:(\d+)}', prompt)
n_match = re.search(r"{n:(\d+)}", prompt)
if n_match:
n = int(n_match.group(1))
if n < 1 or n > 4:
raise ValueError(f"Invalid n value: {n}. Must be between 1 and 4.")
prompt = prompt.replace(n_match.group(0), '').strip()
# 解析ratio参数
ratio_match = re.search(r'{ratio:(\d+:\d+)}', prompt)
prompt = prompt.replace(n_match.group(0), "").strip()
# 解析ratio参数
ratio_match = re.search(r"{ratio:(\d+:\d+)}", prompt)
if ratio_match:
aspect_ratio = ratio_match.group(1)
if aspect_ratio not in VALID_IMAGE_RATIOS:
raise ValueError(
f"Invalid ratio: {aspect_ratio}. Must be one of: {', '.join(VALID_IMAGE_RATIOS)}"
)
prompt = prompt.replace(ratio_match.group(0), '').strip()
prompt = prompt.replace(ratio_match.group(0), "").strip()
return prompt, n, aspect_ratio
def generate_images(self, request: ImageGenerationRequest):
client = genai.Client(api_key=self.paid_key)
if request.size == "1024x1024":
self.aspect_ratio = "1:1"
elif request.size == "1792x1024":
@@ -67,13 +67,15 @@ class ImageCreateService:
)
# 解析prompt中的参数
cleaned_prompt, prompt_n, prompt_ratio = self.parse_prompt_parameters(request.prompt)
cleaned_prompt, prompt_n, prompt_ratio = self.parse_prompt_parameters(
request.prompt
)
request.prompt = cleaned_prompt
# 如果prompt中指定了n则覆盖请求中的n
if prompt_n > 1:
request.n = prompt_n
# 如果prompt中指定了ratio则覆盖默认的aspect_ratio
if prompt_ratio != self.aspect_ratio:
self.aspect_ratio = prompt_ratio
@@ -96,46 +98,49 @@ class ImageCreateService:
for index, generated_image in enumerate(response.generated_images):
image_data = generated_image.image.image_bytes
image_uploader = None
if request.response_format == "b64_json":
base64_image = base64.b64encode(image_data).decode('utf-8')
images_data.append({
"b64_json": base64_image,
"revised_prompt": request.prompt
})
base64_image = base64.b64encode(image_data).decode("utf-8")
images_data.append(
{"b64_json": base64_image, "revised_prompt": request.prompt}
)
else:
current_date = time.strftime("%Y/%m/%d")
filename = f"{current_date}/{uuid.uuid4().hex[:8]}.png"
if settings.UPLOAD_PROVIDER == "smms":
image_uploader = ImageUploaderFactory.create(
provider=settings.UPLOAD_PROVIDER,
api_key=settings.SMMS_SECRET_TOKEN
api_key=settings.SMMS_SECRET_TOKEN,
)
elif settings.UPLOAD_PROVIDER == "picgo":
image_uploader = ImageUploaderFactory.create(
provider=settings.UPLOAD_PROVIDER,
api_key=settings.PICGO_API_KEY
api_key=settings.PICGO_API_KEY,
)
elif settings.UPLOAD_PROVIDER == "cloudflare_imgbed":
image_uploader = ImageUploaderFactory.create(
provider=settings.UPLOAD_PROVIDER,
base_url=settings.CLOUDFLARE_IMGBED_URL,
auth_code=settings.CLOUDFLARE_IMGBED_AUTH_CODE
auth_code=settings.CLOUDFLARE_IMGBED_AUTH_CODE,
)
else:
raise ValueError(f"Unsupported upload provider: {settings.UPLOAD_PROVIDER}")
raise ValueError(
f"Unsupported upload provider: {settings.UPLOAD_PROVIDER}"
)
upload_response = image_uploader.upload(image_data, filename)
images_data.append({
"url": f"{upload_response.data.url}",
"revised_prompt": request.prompt
})
images_data.append(
{
"url": f"{upload_response.data.url}",
"revised_prompt": request.prompt,
}
)
response_data = {
"created": int(time.time()), # Current timestamp
"data": images_data
"data": images_data,
}
return response_data
else:
@@ -147,9 +152,13 @@ class ImageCreateService:
if image_datas:
markdown_images = []
for index, image_data in enumerate(image_datas):
if 'url' in image_data:
markdown_images.append(f"![Generated Image {index+1}]({image_data['url']})")
if "url" in image_data:
markdown_images.append(
f"![Generated Image {index+1}]({image_data['url']})"
)
else:
# 如果是base64格式创建data URL
markdown_images.append(f"![Generated Image {index+1}](data:image/png;base64,{image_data['b64_json']})")
markdown_images.append(
f"![Generated Image {index+1}](data:image/png;base64,{image_data['b64_json']})"
)
return "\n".join(markdown_images)

View File

@@ -1,9 +1,9 @@
import asyncio
from itertools import cycle
from typing import Dict
from app.logger.logger import get_key_manager_logger
from app.config.config import settings
from app.config.config import settings
from app.log.logger import get_key_manager_logger
logger = get_key_manager_logger()
@@ -20,7 +20,7 @@ class KeyManager:
async def get_paid_key(self) -> str:
return self.paid_key
async def get_next_key(self) -> str:
"""获取下一个API key"""
async with self.key_cycle_lock:
@@ -70,7 +70,7 @@ class KeyManager:
"""获取分类后的API key列表包括失败次数"""
valid_keys = {}
invalid_keys = {}
async with self.failure_count_lock:
for key in self.api_keys:
fail_count = self.key_failure_counts[key]
@@ -78,16 +78,14 @@ class KeyManager:
valid_keys[key] = fail_count
else:
invalid_keys[key] = fail_count
return {
"valid_keys": valid_keys,
"invalid_keys": invalid_keys
}
return {"valid_keys": valid_keys, "invalid_keys": invalid_keys}
_singleton_instance = None
_singleton_lock = asyncio.Lock()
async def get_key_manager_instance(api_keys: list = None) -> KeyManager:
"""
获取 KeyManager 单例实例。

View File

@@ -1,11 +1,14 @@
import requests
from datetime import datetime, timezone
from typing import Optional, Dict, Any
from app.logger.logger import get_model_logger
from typing import Any, Dict, Optional
import requests
from app.config.config import settings
from app.log.logger import get_model_logger
logger = get_model_logger()
class ModelService:
def __init__(self, search_models: list, image_models: list):
self.search_models = search_models
@@ -20,7 +23,7 @@ class ModelService:
response = requests.get(url)
if response.status_code == 200:
gemini_models = response.json()
filtered_models_list = []
for model in gemini_models.get("models", []):
model_id = model["name"].split("/")[-1]
@@ -28,7 +31,7 @@ class ModelService:
filtered_models_list.append(model)
else:
logger.info(f"Filtered out model: {model_id}")
gemini_models["models"] = filtered_models_list
return gemini_models
else:
@@ -48,7 +51,7 @@ class ModelService:
return None
def convert_to_openai_models_format(
self, gemini_models: Dict[str, Any]
self, gemini_models: Dict[str, Any]
) -> Dict[str, Any]:
openai_format = {"object": "list", "data": [], "success": True}