mirror of
https://github.com/snailyp/gemini-balance.git
synced 2026-07-05 06:41:29 +08:00
Compare commits
1 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
b14bb93d8f |
@@ -263,7 +263,7 @@ uvicorn app.main:app --host 0.0.0.0 --port 8000 --reload
|
|||||||
"content": "你好"
|
"content": "你好"
|
||||||
}
|
}
|
||||||
],
|
],
|
||||||
"model": "gemini-1.5-flash-002",
|
"model": "gemini-1.5-flash",
|
||||||
"temperature": 0.7,
|
"temperature": 0.7,
|
||||||
"stream": false,
|
"stream": false,
|
||||||
"tools": [],
|
"tools": [],
|
||||||
@@ -276,7 +276,7 @@ uvicorn app.main:app --host 0.0.0.0 --port 8000 --reload
|
|||||||
|
|
||||||
- `messages`: 消息列表,格式与 OpenAI API 相同
|
- `messages`: 消息列表,格式与 OpenAI API 相同
|
||||||
- `model`: 模型名称,支持所有Gemini模型,包括:
|
- `model`: 模型名称,支持所有Gemini模型,包括:
|
||||||
- `gemini-1.5-flash-002`: 快速响应模型
|
- `gemini-1.5-flash`: 快速响应模型
|
||||||
- `gemini-2.0-flash-exp`: 实验性快速响应模型
|
- `gemini-2.0-flash-exp`: 实验性快速响应模型
|
||||||
- `gemini-2.0-flash-exp-search`: 支持搜索功能的实验性模型
|
- `gemini-2.0-flash-exp-search`: 支持搜索功能的实验性模型
|
||||||
- `stream`: 是否开启流式响应,`true` 或 `false`
|
- `stream`: 是否开启流式响应,`true` 或 `false`
|
||||||
|
|||||||
@@ -1,19 +1,37 @@
|
|||||||
from pydantic_settings import BaseSettings
|
"""
|
||||||
|
应用程序配置模块
|
||||||
|
"""
|
||||||
from typing import List
|
from typing import List
|
||||||
|
from pydantic_settings import BaseSettings
|
||||||
|
|
||||||
|
from app.core.constants import API_VERSION, DEFAULT_MODEL
|
||||||
|
|
||||||
|
|
||||||
class Settings(BaseSettings):
|
class Settings(BaseSettings):
|
||||||
|
"""应用程序配置"""
|
||||||
|
# API相关配置
|
||||||
API_KEYS: List[str]
|
API_KEYS: List[str]
|
||||||
ALLOWED_TOKENS: List[str]
|
ALLOWED_TOKENS: List[str]
|
||||||
BASE_URL: str = "https://generativelanguage.googleapis.com/v1beta"
|
BASE_URL: str = f"https://generativelanguage.googleapis.com/{API_VERSION}"
|
||||||
|
AUTH_TOKEN: str = ""
|
||||||
|
MAX_FAILURES: int = 3
|
||||||
|
TEST_MODEL: str = DEFAULT_MODEL
|
||||||
|
|
||||||
|
# 模型相关配置
|
||||||
SEARCH_MODELS: List[str] = ["gemini-2.0-flash-exp"]
|
SEARCH_MODELS: List[str] = ["gemini-2.0-flash-exp"]
|
||||||
IMAGE_MODELS: List[str] = ["gemini-2.0-flash-exp"]
|
IMAGE_MODELS: List[str] = ["gemini-2.0-flash-exp"]
|
||||||
FILTERED_MODELS: List[str] = ["gemini-1.0-pro-vision-latest", "gemini-pro-vision", "chat-bison-001", "text-bison-001", "embedding-gecko-001"]
|
FILTERED_MODELS: List[str] = [
|
||||||
|
"gemini-1.0-pro-vision-latest",
|
||||||
|
"gemini-pro-vision",
|
||||||
|
"chat-bison-001",
|
||||||
|
"text-bison-001",
|
||||||
|
"embedding-gecko-001"
|
||||||
|
]
|
||||||
TOOLS_CODE_EXECUTION_ENABLED: bool = False
|
TOOLS_CODE_EXECUTION_ENABLED: bool = False
|
||||||
SHOW_SEARCH_LINK: bool = True
|
SHOW_SEARCH_LINK: bool = True
|
||||||
SHOW_THINKING_PROCESS: bool = True
|
SHOW_THINKING_PROCESS: bool = True
|
||||||
AUTH_TOKEN: str = ""
|
|
||||||
MAX_FAILURES: int = 3
|
# 图像生成相关配置
|
||||||
PAID_KEY: str = ""
|
PAID_KEY: str = ""
|
||||||
CREATE_IMAGE_MODEL: str = "imagen-3.0-generate-002"
|
CREATE_IMAGE_MODEL: str = "imagen-3.0-generate-002"
|
||||||
UPLOAD_PROVIDER: str = "smms"
|
UPLOAD_PROVIDER: str = "smms"
|
||||||
@@ -21,7 +39,6 @@ class Settings(BaseSettings):
|
|||||||
PICGO_API_KEY: str = ""
|
PICGO_API_KEY: str = ""
|
||||||
CLOUDFLARE_IMGBED_URL: str = ""
|
CLOUDFLARE_IMGBED_URL: str = ""
|
||||||
CLOUDFLARE_IMGBED_AUTH_CODE: str = ""
|
CLOUDFLARE_IMGBED_AUTH_CODE: str = ""
|
||||||
TEST_MODEL: str = "gemini-1.5-flash"
|
|
||||||
|
|
||||||
# 流式输出优化器配置
|
# 流式输出优化器配置
|
||||||
STREAM_MIN_DELAY: float = 0.016
|
STREAM_MIN_DELAY: float = 0.016
|
||||||
@@ -29,14 +46,16 @@ class Settings(BaseSettings):
|
|||||||
STREAM_SHORT_TEXT_THRESHOLD: int = 10
|
STREAM_SHORT_TEXT_THRESHOLD: int = 10
|
||||||
STREAM_LONG_TEXT_THRESHOLD: int = 50
|
STREAM_LONG_TEXT_THRESHOLD: int = 50
|
||||||
STREAM_CHUNK_SIZE: int = 5
|
STREAM_CHUNK_SIZE: int = 5
|
||||||
|
|
||||||
def __init__(self):
|
def __init__(self, **kwargs):
|
||||||
super().__init__()
|
super().__init__(**kwargs)
|
||||||
if not self.AUTH_TOKEN:
|
# 设置默认AUTH_TOKEN(如果未提供)
|
||||||
self.AUTH_TOKEN = self.ALLOWED_TOKENS[0] if self.ALLOWED_TOKENS else ""
|
if not self.AUTH_TOKEN and self.ALLOWED_TOKENS:
|
||||||
|
self.AUTH_TOKEN = self.ALLOWED_TOKENS[0]
|
||||||
|
|
||||||
class Config:
|
class Config:
|
||||||
env_file = ".env"
|
env_file = ".env"
|
||||||
|
|
||||||
|
|
||||||
|
# 创建全局配置实例
|
||||||
settings = Settings()
|
settings = Settings()
|
||||||
71
app/core/application.py
Normal file
71
app/core/application.py
Normal file
@@ -0,0 +1,71 @@
|
|||||||
|
"""
|
||||||
|
应用程序工厂模块,负责创建和配置FastAPI应用程序实例
|
||||||
|
"""
|
||||||
|
from contextlib import asynccontextmanager
|
||||||
|
from fastapi import FastAPI
|
||||||
|
from fastapi.staticfiles import StaticFiles
|
||||||
|
|
||||||
|
from app.config.config import settings
|
||||||
|
from app.logger.logger import get_main_logger
|
||||||
|
from app.middleware.middleware import setup_middlewares
|
||||||
|
from app.exception.exceptions import setup_exception_handlers
|
||||||
|
from app.router.routers import setup_routers
|
||||||
|
from app.service.key.key_manager import get_key_manager_instance
|
||||||
|
from app.core.initialization import initialize_app
|
||||||
|
|
||||||
|
logger = get_main_logger()
|
||||||
|
|
||||||
|
@asynccontextmanager
|
||||||
|
async def lifespan(app: FastAPI):
|
||||||
|
"""
|
||||||
|
应用程序生命周期管理器
|
||||||
|
|
||||||
|
Args:
|
||||||
|
app: FastAPI应用实例
|
||||||
|
"""
|
||||||
|
# 启动事件
|
||||||
|
logger.info("Application starting up...")
|
||||||
|
try:
|
||||||
|
# 初始化KeyManager
|
||||||
|
await get_key_manager_instance(settings.API_KEYS)
|
||||||
|
logger.info("KeyManager initialized successfully")
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Failed to initialize KeyManager: {str(e)}")
|
||||||
|
raise
|
||||||
|
|
||||||
|
yield # 应用程序运行期间
|
||||||
|
|
||||||
|
# 关闭事件
|
||||||
|
logger.info("Application shutting down...")
|
||||||
|
|
||||||
|
def create_app() -> FastAPI:
|
||||||
|
"""
|
||||||
|
创建并配置FastAPI应用程序实例
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
FastAPI: 配置好的FastAPI应用程序实例
|
||||||
|
"""
|
||||||
|
# 初始化应用程序
|
||||||
|
initialize_app()
|
||||||
|
|
||||||
|
# 创建FastAPI应用
|
||||||
|
app = FastAPI(
|
||||||
|
title="Gemini Balance API",
|
||||||
|
description="Gemini API代理服务,支持负载均衡和密钥管理",
|
||||||
|
version="1.0.0",
|
||||||
|
lifespan=lifespan
|
||||||
|
)
|
||||||
|
|
||||||
|
# 配置静态文件
|
||||||
|
app.mount("/static", StaticFiles(directory="app/static"), name="static")
|
||||||
|
|
||||||
|
# 配置中间件
|
||||||
|
setup_middlewares(app)
|
||||||
|
|
||||||
|
# 配置异常处理器
|
||||||
|
setup_exception_handlers(app)
|
||||||
|
|
||||||
|
# 配置路由
|
||||||
|
setup_routers(app)
|
||||||
|
|
||||||
|
return app
|
||||||
32
app/core/constants.py
Normal file
32
app/core/constants.py
Normal file
@@ -0,0 +1,32 @@
|
|||||||
|
"""
|
||||||
|
常量定义模块
|
||||||
|
"""
|
||||||
|
|
||||||
|
# API相关常量
|
||||||
|
API_VERSION = "v1beta"
|
||||||
|
DEFAULT_TIMEOUT = 300 # 秒
|
||||||
|
|
||||||
|
# 模型相关常量
|
||||||
|
SUPPORTED_ROLES = ["user", "model", "system"]
|
||||||
|
DEFAULT_MODEL = "gemini-1.5-flash"
|
||||||
|
DEFAULT_TEMPERATURE = 0.7
|
||||||
|
DEFAULT_MAX_TOKENS = 8192
|
||||||
|
DEFAULT_TOP_P = 0.9
|
||||||
|
DEFAULT_TOP_K = 40
|
||||||
|
|
||||||
|
# 图像生成相关常量
|
||||||
|
VALID_IMAGE_RATIOS = ["1:1", "3:4", "4:3", "9:16", "16:9"]
|
||||||
|
|
||||||
|
# 上传提供商
|
||||||
|
UPLOAD_PROVIDERS = ["smms", "picgo", "cloudflare_imgbed"]
|
||||||
|
|
||||||
|
# 流式输出相关常量
|
||||||
|
DEFAULT_STREAM_MIN_DELAY = 0.016
|
||||||
|
DEFAULT_STREAM_MAX_DELAY = 0.024
|
||||||
|
DEFAULT_STREAM_SHORT_TEXT_THRESHOLD = 10
|
||||||
|
DEFAULT_STREAM_LONG_TEXT_THRESHOLD = 50
|
||||||
|
DEFAULT_STREAM_CHUNK_SIZE = 5
|
||||||
|
|
||||||
|
# 正则表达式模式
|
||||||
|
IMAGE_URL_PATTERN = r'!\[(.*?)\]\((.*?)\)'
|
||||||
|
DATA_URL_PATTERN = r'data:([^;]+);base64,(.+)'
|
||||||
39
app/core/initialization.py
Normal file
39
app/core/initialization.py
Normal file
@@ -0,0 +1,39 @@
|
|||||||
|
"""
|
||||||
|
应用程序初始化模块
|
||||||
|
"""
|
||||||
|
import logging
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import List
|
||||||
|
|
||||||
|
logger = logging.getLogger("initialization")
|
||||||
|
|
||||||
|
|
||||||
|
def ensure_directories_exist(directories: List[str]) -> None:
|
||||||
|
"""
|
||||||
|
确保指定的目录存在,如果不存在则创建
|
||||||
|
|
||||||
|
Args:
|
||||||
|
directories: 要确保存在的目录列表
|
||||||
|
"""
|
||||||
|
for directory in directories:
|
||||||
|
try:
|
||||||
|
Path(directory).mkdir(parents=True, exist_ok=True)
|
||||||
|
logger.info(f"Ensured directory exists: {directory}")
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Failed to create directory {directory}: {str(e)}")
|
||||||
|
|
||||||
|
|
||||||
|
def initialize_app() -> None:
|
||||||
|
"""
|
||||||
|
初始化应用程序,确保所需的目录和文件都存在
|
||||||
|
"""
|
||||||
|
# 确保必要的目录存在
|
||||||
|
required_directories = [
|
||||||
|
"app/static/css",
|
||||||
|
"app/static/js",
|
||||||
|
"app/static/icons",
|
||||||
|
"app/templates",
|
||||||
|
]
|
||||||
|
|
||||||
|
ensure_directories_exist(required_directories)
|
||||||
|
logger.info("Application initialization completed")
|
||||||
@@ -1,7 +1,7 @@
|
|||||||
from fastapi import HTTPException, Header
|
from fastapi import HTTPException, Header
|
||||||
from typing import Optional
|
from typing import Optional
|
||||||
from app.core.logger import get_security_logger
|
from app.logger.logger import get_security_logger
|
||||||
from app.core.config import settings
|
from app.config.config import settings
|
||||||
|
|
||||||
logger = get_security_logger()
|
logger = get_security_logger()
|
||||||
|
|
||||||
|
|||||||
@@ -36,5 +36,5 @@ class GeminiRequest(BaseModel):
|
|||||||
contents: List[GeminiContent] = []
|
contents: List[GeminiContent] = []
|
||||||
tools: Optional[List[Dict[str, Any]]] = []
|
tools: Optional[List[Dict[str, Any]]] = []
|
||||||
safetySettings: Optional[List[SafetySetting]] = None
|
safetySettings: Optional[List[SafetySetting]] = None
|
||||||
generationConfig: Optional[GenerationConfig] = {}
|
generationConfig: Optional[GenerationConfig] = None
|
||||||
systemInstruction: Optional[SystemInstruction] = None
|
systemInstruction: Optional[SystemInstruction] = None
|
||||||
@@ -1,17 +1,19 @@
|
|||||||
from pydantic import BaseModel
|
from pydantic import BaseModel
|
||||||
from typing import List, Optional, Union
|
from typing import List, Optional, Union
|
||||||
|
|
||||||
|
from app.core.constants import DEFAULT_MAX_TOKENS, DEFAULT_MODEL, DEFAULT_TEMPERATURE, DEFAULT_TOP_K, DEFAULT_TOP_P
|
||||||
|
|
||||||
|
|
||||||
class ChatRequest(BaseModel):
|
class ChatRequest(BaseModel):
|
||||||
messages: List[dict]
|
messages: List[dict]
|
||||||
model: str = "gemini-1.5-flash-002"
|
model: str = DEFAULT_MODEL
|
||||||
temperature: Optional[float] = 0.7
|
temperature: Optional[float] = DEFAULT_TEMPERATURE
|
||||||
stream: Optional[bool] = False
|
stream: Optional[bool] = False
|
||||||
tools: Optional[List[dict]] = []
|
tools: Optional[List[dict]] = []
|
||||||
max_tokens: Optional[int] = 8192
|
max_tokens: Optional[int] = DEFAULT_MAX_TOKENS
|
||||||
|
top_p: Optional[float] = DEFAULT_TOP_P
|
||||||
|
top_k: Optional[int] = DEFAULT_TOP_K
|
||||||
stop: Optional[List[str]] = []
|
stop: Optional[List[str]] = []
|
||||||
top_p: Optional[float] = 0.9
|
|
||||||
top_k: Optional[int] = 40
|
|
||||||
|
|
||||||
|
|
||||||
class EmbeddingRequest(BaseModel):
|
class EmbeddingRequest(BaseModel):
|
||||||
133
app/exception/exceptions.py
Normal file
133
app/exception/exceptions.py
Normal file
@@ -0,0 +1,133 @@
|
|||||||
|
"""
|
||||||
|
异常处理模块,定义应用程序中使用的自定义异常和异常处理器
|
||||||
|
"""
|
||||||
|
from fastapi import Request, FastAPI
|
||||||
|
from fastapi.responses import JSONResponse
|
||||||
|
from fastapi.exceptions import RequestValidationError
|
||||||
|
from starlette.exceptions import HTTPException as StarletteHTTPException
|
||||||
|
|
||||||
|
from app.logger.logger import get_main_logger
|
||||||
|
|
||||||
|
logger = get_main_logger()
|
||||||
|
|
||||||
|
|
||||||
|
class APIError(Exception):
|
||||||
|
"""API错误基类"""
|
||||||
|
def __init__(self, status_code: int, detail: str, error_code: str = None):
|
||||||
|
self.status_code = status_code
|
||||||
|
self.detail = detail
|
||||||
|
self.error_code = error_code or "api_error"
|
||||||
|
super().__init__(self.detail)
|
||||||
|
|
||||||
|
|
||||||
|
class AuthenticationError(APIError):
|
||||||
|
"""认证错误"""
|
||||||
|
def __init__(self, detail: str = "Authentication failed"):
|
||||||
|
super().__init__(status_code=401, detail=detail, error_code="authentication_error")
|
||||||
|
|
||||||
|
|
||||||
|
class AuthorizationError(APIError):
|
||||||
|
"""授权错误"""
|
||||||
|
def __init__(self, detail: str = "Not authorized to access this resource"):
|
||||||
|
super().__init__(status_code=403, detail=detail, error_code="authorization_error")
|
||||||
|
|
||||||
|
|
||||||
|
class ResourceNotFoundError(APIError):
|
||||||
|
"""资源未找到错误"""
|
||||||
|
def __init__(self, detail: str = "Resource not found"):
|
||||||
|
super().__init__(status_code=404, detail=detail, error_code="resource_not_found")
|
||||||
|
|
||||||
|
|
||||||
|
class ModelNotSupportedError(APIError):
|
||||||
|
"""模型不支持错误"""
|
||||||
|
def __init__(self, model: str):
|
||||||
|
super().__init__(
|
||||||
|
status_code=400,
|
||||||
|
detail=f"Model {model} is not supported",
|
||||||
|
error_code="model_not_supported"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class APIKeyError(APIError):
|
||||||
|
"""API密钥错误"""
|
||||||
|
def __init__(self, detail: str = "Invalid or expired API key"):
|
||||||
|
super().__init__(status_code=401, detail=detail, error_code="api_key_error")
|
||||||
|
|
||||||
|
|
||||||
|
class ServiceUnavailableError(APIError):
|
||||||
|
"""服务不可用错误"""
|
||||||
|
def __init__(self, detail: str = "Service temporarily unavailable"):
|
||||||
|
super().__init__(status_code=503, detail=detail, error_code="service_unavailable")
|
||||||
|
|
||||||
|
|
||||||
|
def setup_exception_handlers(app: FastAPI) -> None:
|
||||||
|
"""
|
||||||
|
设置应用程序的异常处理器
|
||||||
|
|
||||||
|
Args:
|
||||||
|
app: FastAPI应用程序实例
|
||||||
|
"""
|
||||||
|
@app.exception_handler(APIError)
|
||||||
|
async def api_error_handler(request: Request, exc: APIError):
|
||||||
|
"""处理API错误"""
|
||||||
|
logger.error(f"API Error: {exc.detail} (Code: {exc.error_code})")
|
||||||
|
return JSONResponse(
|
||||||
|
status_code=exc.status_code,
|
||||||
|
content={
|
||||||
|
"error": {
|
||||||
|
"code": exc.error_code,
|
||||||
|
"message": exc.detail
|
||||||
|
}
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
@app.exception_handler(StarletteHTTPException)
|
||||||
|
async def http_exception_handler(request: Request, exc: StarletteHTTPException):
|
||||||
|
"""处理HTTP异常"""
|
||||||
|
logger.error(f"HTTP Exception: {exc.detail} (Status: {exc.status_code})")
|
||||||
|
return JSONResponse(
|
||||||
|
status_code=exc.status_code,
|
||||||
|
content={
|
||||||
|
"error": {
|
||||||
|
"code": "http_error",
|
||||||
|
"message": exc.detail
|
||||||
|
}
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
@app.exception_handler(RequestValidationError)
|
||||||
|
async def validation_exception_handler(request: Request, exc: RequestValidationError):
|
||||||
|
"""处理请求验证错误"""
|
||||||
|
error_details = []
|
||||||
|
for error in exc.errors():
|
||||||
|
error_details.append({
|
||||||
|
"loc": error["loc"],
|
||||||
|
"msg": error["msg"],
|
||||||
|
"type": error["type"]
|
||||||
|
})
|
||||||
|
|
||||||
|
logger.error(f"Validation Error: {error_details}")
|
||||||
|
return JSONResponse(
|
||||||
|
status_code=422,
|
||||||
|
content={
|
||||||
|
"error": {
|
||||||
|
"code": "validation_error",
|
||||||
|
"message": "Request validation failed",
|
||||||
|
"details": error_details
|
||||||
|
}
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
@app.exception_handler(Exception)
|
||||||
|
async def general_exception_handler(request: Request, exc: Exception):
|
||||||
|
"""处理通用异常"""
|
||||||
|
logger.exception(f"Unhandled Exception: {str(exc)}")
|
||||||
|
return JSONResponse(
|
||||||
|
status_code=500,
|
||||||
|
content={
|
||||||
|
"error": {
|
||||||
|
"code": "internal_server_error",
|
||||||
|
"message": "An unexpected error occurred"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
)
|
||||||
@@ -6,8 +6,7 @@ from typing import Any, Dict, List, Optional
|
|||||||
import requests
|
import requests
|
||||||
import base64
|
import base64
|
||||||
|
|
||||||
SUPPORTED_ROLES = ["user", "model", "system"]
|
from app.core.constants import DATA_URL_PATTERN, IMAGE_URL_PATTERN, SUPPORTED_ROLES
|
||||||
IMAGE_URL_PATTERN = r'\[(.*?)\]\((.*?)\)'
|
|
||||||
|
|
||||||
|
|
||||||
class MessageConverter(ABC):
|
class MessageConverter(ABC):
|
||||||
@@ -30,7 +29,7 @@ def _get_mime_type_and_data(base64_string):
|
|||||||
# 检查字符串是否以 "data:" 格式开始
|
# 检查字符串是否以 "data:" 格式开始
|
||||||
if base64_string.startswith('data:'):
|
if base64_string.startswith('data:'):
|
||||||
# 提取 MIME 类型和数据
|
# 提取 MIME 类型和数据
|
||||||
pattern = r'data:([^;]+);base64,(.+)'
|
pattern = DATA_URL_PATTERN
|
||||||
match = re.match(pattern, base64_string)
|
match = re.match(pattern, base64_string)
|
||||||
if match:
|
if match:
|
||||||
mime_type = "image/jpeg" if match.group(1) == "image/jpg" else match.group(1)
|
mime_type = "image/jpeg" if match.group(1) == "image/jpg" else match.group(1)
|
||||||
@@ -8,8 +8,8 @@ from abc import ABC, abstractmethod
|
|||||||
from typing import Dict, Any, List, Optional
|
from typing import Dict, Any, List, Optional
|
||||||
import time
|
import time
|
||||||
import uuid
|
import uuid
|
||||||
from app.core.config import settings
|
from app.config.config import settings
|
||||||
from app.core.uploader import ImageUploaderFactory
|
from app.utils.uploader import ImageUploaderFactory
|
||||||
|
|
||||||
|
|
||||||
class ResponseHandler(ABC):
|
class ResponseHandler(ABC):
|
||||||
@@ -2,7 +2,7 @@
|
|||||||
|
|
||||||
from typing import TypeVar, Callable
|
from typing import TypeVar, Callable
|
||||||
from functools import wraps
|
from functools import wraps
|
||||||
from app.core.logger import get_retry_logger
|
from app.logger.logger import get_retry_logger
|
||||||
|
|
||||||
T = TypeVar('T')
|
T = TypeVar('T')
|
||||||
logger = get_retry_logger()
|
logger = get_retry_logger()
|
||||||
@@ -3,8 +3,9 @@
|
|||||||
import asyncio
|
import asyncio
|
||||||
import math
|
import math
|
||||||
from typing import Any, List, AsyncGenerator, Callable
|
from typing import Any, List, AsyncGenerator, Callable
|
||||||
from app.core.logger import get_openai_logger, get_gemini_logger
|
from app.logger.logger import get_openai_logger, get_gemini_logger
|
||||||
from app.core.config import settings
|
from app.config.config import settings
|
||||||
|
from app.core.constants import DEFAULT_STREAM_CHUNK_SIZE, DEFAULT_STREAM_LONG_TEXT_THRESHOLD, DEFAULT_STREAM_MAX_DELAY, DEFAULT_STREAM_MIN_DELAY, DEFAULT_STREAM_SHORT_TEXT_THRESHOLD
|
||||||
|
|
||||||
logger_openai = get_openai_logger()
|
logger_openai = get_openai_logger()
|
||||||
logger_gemini = get_gemini_logger()
|
logger_gemini = get_gemini_logger()
|
||||||
@@ -18,11 +19,11 @@ class StreamOptimizer:
|
|||||||
|
|
||||||
def __init__(self,
|
def __init__(self,
|
||||||
logger=None,
|
logger=None,
|
||||||
min_delay: float = 0.016,
|
min_delay: float = DEFAULT_STREAM_MIN_DELAY,
|
||||||
max_delay: float = 0.024,
|
max_delay: float = DEFAULT_STREAM_MAX_DELAY,
|
||||||
short_text_threshold: int = 10,
|
short_text_threshold: int = DEFAULT_STREAM_SHORT_TEXT_THRESHOLD,
|
||||||
long_text_threshold: int = 50,
|
long_text_threshold: int = DEFAULT_STREAM_LONG_TEXT_THRESHOLD,
|
||||||
chunk_size: int = 5):
|
chunk_size: int = DEFAULT_STREAM_CHUNK_SIZE):
|
||||||
"""初始化流式输出优化器
|
"""初始化流式输出优化器
|
||||||
|
|
||||||
参数:
|
参数:
|
||||||
132
app/main.py
132
app/main.py
@@ -1,134 +1,16 @@
|
|||||||
from fastapi import FastAPI, Request
|
"""
|
||||||
from fastapi.middleware.cors import CORSMiddleware
|
应用程序入口模块
|
||||||
from fastapi.responses import HTMLResponse, RedirectResponse
|
"""
|
||||||
from fastapi.templating import Jinja2Templates
|
|
||||||
from fastapi.staticfiles import StaticFiles
|
|
||||||
from app.core.logger import get_main_logger
|
|
||||||
from app.core.security import verify_auth_token
|
|
||||||
from app.services.key_manager import get_key_manager_instance
|
|
||||||
from app.core.config import settings
|
|
||||||
|
|
||||||
from app.api import gemini_routes, openai_routes
|
|
||||||
import uvicorn
|
import uvicorn
|
||||||
|
from app.core.application import create_app
|
||||||
|
from app.logger.logger import get_main_logger
|
||||||
|
|
||||||
|
# 创建应用程序实例
|
||||||
|
app = create_app()
|
||||||
|
|
||||||
# 配置日志
|
# 配置日志
|
||||||
logger = get_main_logger()
|
logger = get_main_logger()
|
||||||
|
|
||||||
app = FastAPI()
|
|
||||||
|
|
||||||
# 配置Jinja2模板
|
|
||||||
templates = Jinja2Templates(directory="app/templates")
|
|
||||||
|
|
||||||
# 配置静态文件
|
|
||||||
app.mount("/static", StaticFiles(directory="app/static"), name="static")
|
|
||||||
|
|
||||||
# 创建 KeyManager 实例
|
|
||||||
key_manager = None
|
|
||||||
|
|
||||||
@app.on_event("startup")
|
|
||||||
async def startup_event():
|
|
||||||
global key_manager
|
|
||||||
logger.info("Application starting up...")
|
|
||||||
try:
|
|
||||||
key_manager = await get_key_manager_instance(settings.API_KEYS)
|
|
||||||
logger.info("KeyManager initialized successfully")
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"Failed to initialize KeyManager: {str(e)}")
|
|
||||||
raise
|
|
||||||
|
|
||||||
# 添加中间件来处理未经身份验证的请求
|
|
||||||
@app.middleware("http")
|
|
||||||
async def auth_middleware(request: Request, call_next):
|
|
||||||
# 允许 gemini_routes 和 openai_routes 中的端点绕过身份验证
|
|
||||||
if (request.url.path not in ["/", "/auth"] and
|
|
||||||
not request.url.path.startswith("/static") and
|
|
||||||
not request.url.path.startswith("/gemini") and
|
|
||||||
not request.url.path.startswith("/v1") and
|
|
||||||
not request.url.path.startswith("/v1beta") and
|
|
||||||
not request.url.path.startswith("/health") and
|
|
||||||
not request.url.path.startswith("/hf")):
|
|
||||||
auth_token = request.cookies.get("auth_token")
|
|
||||||
if not auth_token or not verify_auth_token(auth_token):
|
|
||||||
logger.warning(f"Unauthorized access attempt to {request.url.path}")
|
|
||||||
return RedirectResponse(url="/")
|
|
||||||
logger.debug("Request authenticated successfully")
|
|
||||||
response = await call_next(request)
|
|
||||||
return response
|
|
||||||
|
|
||||||
# 添加请求日志中间件
|
|
||||||
# app.add_middleware(RequestLoggingMiddleware)
|
|
||||||
|
|
||||||
# 配置CORS中间件
|
|
||||||
app.add_middleware(
|
|
||||||
CORSMiddleware,
|
|
||||||
allow_origins=["*"], # 生产环境建议配置具体的域名
|
|
||||||
allow_credentials=True,
|
|
||||||
allow_methods=["GET", "POST", "PUT", "DELETE", "OPTIONS"], # 明确指定允许的HTTP方法
|
|
||||||
allow_headers=["*"], # 生产环境建议配置具体的请求头
|
|
||||||
expose_headers=["*"], # 允许前端访问的响应头
|
|
||||||
max_age=600, # 预检请求缓存时间(秒)
|
|
||||||
)
|
|
||||||
|
|
||||||
# 包含所有路由
|
|
||||||
app.include_router(openai_routes.router)
|
|
||||||
app.include_router(gemini_routes.router)
|
|
||||||
app.include_router(gemini_routes.router_v1beta)
|
|
||||||
|
|
||||||
|
|
||||||
@app.get("/", response_class=HTMLResponse)
|
|
||||||
async def auth_page(request: Request):
|
|
||||||
return templates.TemplateResponse("auth.html", {"request": request})
|
|
||||||
|
|
||||||
|
|
||||||
@app.post("/auth")
|
|
||||||
async def authenticate(request: Request):
|
|
||||||
try:
|
|
||||||
form = await request.form()
|
|
||||||
auth_token = form.get("auth_token")
|
|
||||||
if not auth_token:
|
|
||||||
logger.warning("Authentication attempt with empty token")
|
|
||||||
return RedirectResponse(url="/", status_code=302)
|
|
||||||
|
|
||||||
if verify_auth_token(auth_token):
|
|
||||||
logger.info("Successful authentication")
|
|
||||||
response = RedirectResponse(url="/keys", status_code=302)
|
|
||||||
response.set_cookie(key="auth_token", value=auth_token, httponly=True, max_age=3600)
|
|
||||||
return response
|
|
||||||
logger.warning("Failed authentication attempt with invalid token")
|
|
||||||
return RedirectResponse(url="/", status_code=302)
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"Authentication error: {str(e)}")
|
|
||||||
return RedirectResponse(url="/", status_code=302)
|
|
||||||
|
|
||||||
@app.get("/keys", response_class=HTMLResponse)
|
|
||||||
async def keys_page(request: Request):
|
|
||||||
try:
|
|
||||||
auth_token = request.cookies.get("auth_token")
|
|
||||||
if not auth_token or not verify_auth_token(auth_token):
|
|
||||||
logger.warning("Unauthorized access attempt to keys page")
|
|
||||||
return RedirectResponse(url="/", status_code=302)
|
|
||||||
|
|
||||||
keys_status = await key_manager.get_keys_by_status()
|
|
||||||
total = len(keys_status["valid_keys"]) + len(keys_status["invalid_keys"])
|
|
||||||
logger.info(f"Keys status retrieved successfully. Total keys: {total}")
|
|
||||||
return templates.TemplateResponse("keys_status.html", {
|
|
||||||
"request": request,
|
|
||||||
"valid_keys": keys_status["valid_keys"],
|
|
||||||
"invalid_keys": keys_status["invalid_keys"],
|
|
||||||
"total": total
|
|
||||||
})
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"Error retrieving keys status: {str(e)}")
|
|
||||||
raise
|
|
||||||
|
|
||||||
|
|
||||||
@app.get("/health")
|
|
||||||
async def health_check(request: Request):
|
|
||||||
logger.info("Health check endpoint called")
|
|
||||||
return {"status": "healthy"}
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
logger.info("Starting application server...")
|
logger.info("Starting application server...")
|
||||||
uvicorn.run(app, host="0.0.0.0", port=8001)
|
uvicorn.run(app, host="0.0.0.0", port=8001)
|
||||||
|
|||||||
61
app/middleware/middleware.py
Normal file
61
app/middleware/middleware.py
Normal file
@@ -0,0 +1,61 @@
|
|||||||
|
"""
|
||||||
|
中间件配置模块,负责设置和配置应用程序的中间件
|
||||||
|
"""
|
||||||
|
from fastapi import FastAPI, Request
|
||||||
|
from fastapi.middleware.cors import CORSMiddleware
|
||||||
|
from fastapi.responses import RedirectResponse
|
||||||
|
from starlette.middleware.base import BaseHTTPMiddleware
|
||||||
|
|
||||||
|
from app.logger.logger import get_main_logger
|
||||||
|
from app.core.security import verify_auth_token
|
||||||
|
# from app.middleware.request_logging_middleware import RequestLoggingMiddleware
|
||||||
|
from app.core.constants import API_VERSION
|
||||||
|
|
||||||
|
logger = get_main_logger()
|
||||||
|
|
||||||
|
class AuthMiddleware(BaseHTTPMiddleware):
|
||||||
|
"""
|
||||||
|
认证中间件,处理未经身份验证的请求
|
||||||
|
"""
|
||||||
|
async def dispatch(self, request: Request, call_next):
|
||||||
|
# 允许特定路径绕过身份验证
|
||||||
|
if (request.url.path not in ["/", "/auth"] and
|
||||||
|
not request.url.path.startswith("/static") and
|
||||||
|
not request.url.path.startswith("/gemini") and
|
||||||
|
not request.url.path.startswith("/v1") and
|
||||||
|
not request.url.path.startswith(f"/{API_VERSION}") and
|
||||||
|
not request.url.path.startswith("/health") and
|
||||||
|
not request.url.path.startswith("/hf")):
|
||||||
|
|
||||||
|
auth_token = request.cookies.get("auth_token")
|
||||||
|
if not auth_token or not verify_auth_token(auth_token):
|
||||||
|
logger.warning(f"Unauthorized access attempt to {request.url.path}")
|
||||||
|
return RedirectResponse(url="/")
|
||||||
|
logger.debug("Request authenticated successfully")
|
||||||
|
|
||||||
|
response = await call_next(request)
|
||||||
|
return response
|
||||||
|
|
||||||
|
def setup_middlewares(app: FastAPI) -> None:
|
||||||
|
"""
|
||||||
|
设置应用程序的中间件
|
||||||
|
|
||||||
|
Args:
|
||||||
|
app: FastAPI应用程序实例
|
||||||
|
"""
|
||||||
|
# 添加认证中间件
|
||||||
|
app.add_middleware(AuthMiddleware)
|
||||||
|
|
||||||
|
# 添加请求日志中间件(可选,默认注释掉)
|
||||||
|
# app.add_middleware(RequestLoggingMiddleware)
|
||||||
|
|
||||||
|
# 配置CORS中间件
|
||||||
|
app.add_middleware(
|
||||||
|
CORSMiddleware,
|
||||||
|
allow_origins=["*"], # 生产环境建议配置具体的域名
|
||||||
|
allow_credentials=True,
|
||||||
|
allow_methods=["GET", "POST", "PUT", "DELETE", "OPTIONS"], # 明确指定允许的HTTP方法
|
||||||
|
allow_headers=["*"], # 生产环境建议配置具体的请求头
|
||||||
|
expose_headers=["*"], # 允许前端访问的响应头
|
||||||
|
max_age=600, # 预检请求缓存时间(秒)
|
||||||
|
)
|
||||||
@@ -1,7 +1,7 @@
|
|||||||
from fastapi import Request
|
from fastapi import Request
|
||||||
from starlette.middleware.base import BaseHTTPMiddleware
|
from starlette.middleware.base import BaseHTTPMiddleware
|
||||||
import json
|
import json
|
||||||
from app.core.logger import get_request_logger
|
from app.logger.logger import get_request_logger
|
||||||
|
|
||||||
logger = get_request_logger()
|
logger = get_request_logger()
|
||||||
|
|
||||||
|
|||||||
@@ -1,73 +1,80 @@
|
|||||||
from fastapi import APIRouter, Depends, HTTPException
|
from fastapi import APIRouter, Depends, HTTPException
|
||||||
from fastapi.responses import StreamingResponse, JSONResponse
|
from fastapi.responses import StreamingResponse, JSONResponse
|
||||||
from copy import deepcopy
|
from copy import deepcopy
|
||||||
from app.core.config import settings
|
from app.config.config import settings
|
||||||
from app.core.logger import get_gemini_logger
|
from app.logger.logger import get_gemini_logger
|
||||||
from app.core.security import SecurityService
|
from app.core.security import SecurityService
|
||||||
from app.schemas.gemini_models import GeminiContent, GeminiRequest
|
from app.domain.gemini_models import GeminiContent, GeminiRequest
|
||||||
from app.services.gemini_chat_service import GeminiChatService
|
from app.service.chat.gemini_chat_service import GeminiChatService
|
||||||
from app.services.key_manager import KeyManager, get_key_manager_instance
|
from app.service.key.key_manager import KeyManager, get_key_manager_instance
|
||||||
from app.services.model_service import ModelService
|
from app.service.model.model_service import ModelService
|
||||||
from app.services.chat.retry_handler import RetryHandler
|
from app.handler.retry_handler import RetryHandler
|
||||||
|
from app.core.constants import API_VERSION
|
||||||
|
|
||||||
router = APIRouter(prefix="/gemini/v1beta")
|
# 路由设置
|
||||||
router_v1beta = APIRouter(prefix="/v1beta")
|
router = APIRouter(prefix=f"/gemini/{API_VERSION}")
|
||||||
|
router_v1beta = APIRouter(prefix=f"/{API_VERSION}")
|
||||||
logger = get_gemini_logger()
|
logger = get_gemini_logger()
|
||||||
|
|
||||||
# 初始化服务
|
# 初始化服务
|
||||||
security_service = SecurityService(settings.ALLOWED_TOKENS, settings.AUTH_TOKEN)
|
security_service = SecurityService(settings.ALLOWED_TOKENS, settings.AUTH_TOKEN)
|
||||||
|
model_service = ModelService(settings.SEARCH_MODELS, settings.IMAGE_MODELS)
|
||||||
|
|
||||||
|
|
||||||
async def get_key_manager():
|
async def get_key_manager():
|
||||||
|
"""获取密钥管理器实例"""
|
||||||
return await get_key_manager_instance()
|
return await get_key_manager_instance()
|
||||||
|
|
||||||
async def get_next_working_key_wrapper(key_manager: KeyManager = Depends(get_key_manager)):
|
|
||||||
return await key_manager.get_next_working_key()
|
|
||||||
|
|
||||||
model_service = ModelService(settings.SEARCH_MODELS,settings.IMAGE_MODELS)
|
async def get_next_working_key(key_manager: KeyManager = Depends(get_key_manager)):
|
||||||
|
"""获取下一个可用的API密钥"""
|
||||||
|
return await key_manager.get_next_working_key()
|
||||||
|
|
||||||
|
|
||||||
@router.get("/models")
|
@router.get("/models")
|
||||||
@router_v1beta.get("/models")
|
@router_v1beta.get("/models")
|
||||||
async def list_models(_=Depends(security_service.verify_key),
|
async def list_models(
|
||||||
key_manager: KeyManager = Depends(get_key_manager)):
|
_=Depends(security_service.verify_key),
|
||||||
|
key_manager: KeyManager = Depends(get_key_manager)
|
||||||
|
):
|
||||||
"""获取可用的Gemini模型列表"""
|
"""获取可用的Gemini模型列表"""
|
||||||
logger.info("-" * 50 + "list_gemini_models" + "-" * 50)
|
logger.info("-" * 50 + "list_gemini_models" + "-" * 50)
|
||||||
logger.info("Handling Gemini models list request")
|
logger.info("Handling Gemini models list request")
|
||||||
|
|
||||||
api_key = await key_manager.get_next_working_key()
|
api_key = await key_manager.get_next_working_key()
|
||||||
logger.info(f"Using API key: {api_key}")
|
logger.info(f"Using API key: {api_key}")
|
||||||
|
|
||||||
models_json = model_service.get_gemini_models(api_key)
|
models_json = model_service.get_gemini_models(api_key)
|
||||||
|
|
||||||
# 模型名称以及对应的详细信息
|
|
||||||
model_mapping = {x.get("name", "").split("/", maxsplit=1)[1]: x for x in models_json["models"]}
|
model_mapping = {x.get("name", "").split("/", maxsplit=1)[1]: x for x in models_json["models"]}
|
||||||
|
|
||||||
# 添加搜索模型
|
# 添加搜索模型
|
||||||
if model_service.search_models:
|
if model_service.search_models:
|
||||||
for name in model_service.search_models:
|
for name in model_service.search_models:
|
||||||
model = model_mapping.get(name, None)
|
model = model_mapping.get(name)
|
||||||
if not model:
|
if not model:
|
||||||
continue
|
continue
|
||||||
|
|
||||||
item = deepcopy(model)
|
item = deepcopy(model)
|
||||||
item["name"] = f"models/{name}-search"
|
item["name"] = f"models/{name}-search"
|
||||||
display_name = f'{item.get("displayName")} For Search'
|
display_name = f'{item.get("displayName")} For Search'
|
||||||
item["displayName"] = display_name
|
item["displayName"] = display_name
|
||||||
item["description"] = display_name
|
item["description"] = display_name
|
||||||
|
|
||||||
models_json["models"].append(item)
|
models_json["models"].append(item)
|
||||||
|
|
||||||
# 添加图像生成模型
|
# 添加图像生成模型
|
||||||
if model_service.image_models:
|
if model_service.image_models:
|
||||||
for name in model_service.image_models:
|
for name in model_service.image_models:
|
||||||
model = model_mapping.get(name, None)
|
model = model_mapping.get(name)
|
||||||
if not model:
|
if not model:
|
||||||
continue
|
continue
|
||||||
|
|
||||||
item = deepcopy(model)
|
item = deepcopy(model)
|
||||||
item["name"] = f"models/{name}-image"
|
item["name"] = f"models/{name}-image"
|
||||||
display_name = f'{item.get("displayName")} For Image'
|
display_name = f'{item.get("displayName")} For Image'
|
||||||
item["displayName"] = display_name
|
item["displayName"] = display_name
|
||||||
item["description"] = display_name
|
item["description"] = display_name
|
||||||
|
|
||||||
models_json["models"].append(item)
|
models_json["models"].append(item)
|
||||||
|
|
||||||
return models_json
|
return models_json
|
||||||
@@ -77,30 +84,29 @@ async def list_models(_=Depends(security_service.verify_key),
|
|||||||
@router_v1beta.post("/models/{model_name}:generateContent")
|
@router_v1beta.post("/models/{model_name}:generateContent")
|
||||||
@RetryHandler(max_retries=3, key_arg="api_key")
|
@RetryHandler(max_retries=3, key_arg="api_key")
|
||||||
async def generate_content(
|
async def generate_content(
|
||||||
model_name: str,
|
model_name: str,
|
||||||
request: GeminiRequest,
|
request: GeminiRequest,
|
||||||
_=Depends(security_service.verify_goog_api_key),
|
_=Depends(security_service.verify_goog_api_key),
|
||||||
api_key: str = Depends(get_next_working_key_wrapper),
|
api_key: str = Depends(get_next_working_key),
|
||||||
key_manager: KeyManager = Depends(get_key_manager)
|
key_manager: KeyManager = Depends(get_key_manager)
|
||||||
):
|
):
|
||||||
chat_service = GeminiChatService(settings.BASE_URL, key_manager)
|
|
||||||
"""非流式生成内容"""
|
"""非流式生成内容"""
|
||||||
logger.info("-" * 50 + "gemini_generate_content" + "-" * 50)
|
logger.info("-" * 50 + "gemini_generate_content" + "-" * 50)
|
||||||
logger.info(f"Handling Gemini content generation request for model: {model_name}")
|
logger.info(f"Handling Gemini content generation request for model: {model_name}")
|
||||||
logger.info(f"Request: \n{request.model_dump_json(indent=2)}")
|
logger.info(f"Request: \n{request.model_dump_json(indent=2)}")
|
||||||
logger.info(f"Using API key: {api_key}")
|
logger.info(f"Using API key: {api_key}")
|
||||||
|
|
||||||
if not model_service.check_model_support(model_name):
|
if not model_service.check_model_support(model_name):
|
||||||
raise HTTPException(status_code=400, detail=f"Model {model_name} is not supported")
|
raise HTTPException(status_code=400, detail=f"Model {model_name} is not supported")
|
||||||
|
|
||||||
try:
|
try:
|
||||||
|
chat_service = GeminiChatService(settings.BASE_URL, key_manager)
|
||||||
response = await chat_service.generate_content(
|
response = await chat_service.generate_content(
|
||||||
model=model_name,
|
model=model_name,
|
||||||
request=request,
|
request=request,
|
||||||
api_key=api_key
|
api_key=api_key
|
||||||
)
|
)
|
||||||
return response
|
return response
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"Chat completion failed after retries: {str(e)}")
|
logger.error(f"Chat completion failed after retries: {str(e)}")
|
||||||
raise HTTPException(status_code=500, detail="Chat completion failed") from e
|
raise HTTPException(status_code=500, detail="Chat completion failed") from e
|
||||||
@@ -110,45 +116,46 @@ async def generate_content(
|
|||||||
@router_v1beta.post("/models/{model_name}:streamGenerateContent")
|
@router_v1beta.post("/models/{model_name}:streamGenerateContent")
|
||||||
@RetryHandler(max_retries=3, key_arg="api_key")
|
@RetryHandler(max_retries=3, key_arg="api_key")
|
||||||
async def stream_generate_content(
|
async def stream_generate_content(
|
||||||
model_name: str,
|
model_name: str,
|
||||||
request: GeminiRequest,
|
request: GeminiRequest,
|
||||||
_=Depends(security_service.verify_goog_api_key),
|
_=Depends(security_service.verify_goog_api_key),
|
||||||
api_key: str = Depends(get_next_working_key_wrapper),
|
api_key: str = Depends(get_next_working_key),
|
||||||
key_manager: KeyManager = Depends(get_key_manager)
|
key_manager: KeyManager = Depends(get_key_manager)
|
||||||
):
|
):
|
||||||
chat_service = GeminiChatService(settings.BASE_URL, key_manager)
|
|
||||||
"""流式生成内容"""
|
"""流式生成内容"""
|
||||||
logger.info("-" * 50 + "gemini_stream_generate_content" + "-" * 50)
|
logger.info("-" * 50 + "gemini_stream_generate_content" + "-" * 50)
|
||||||
logger.info(f"Handling Gemini streaming content generation for model: {model_name}")
|
logger.info(f"Handling Gemini streaming content generation for model: {model_name}")
|
||||||
logger.info(f"Request: \n{request.model_dump_json(indent=2)}")
|
logger.info(f"Request: \n{request.model_dump_json(indent=2)}")
|
||||||
logger.info(f"Using API key: {api_key}")
|
logger.info(f"Using API key: {api_key}")
|
||||||
|
|
||||||
if not model_service.check_model_support(model_name):
|
if not model_service.check_model_support(model_name):
|
||||||
raise HTTPException(status_code=400, detail=f"Model {model_name} is not supported")
|
raise HTTPException(status_code=400, detail=f"Model {model_name} is not supported")
|
||||||
|
|
||||||
try:
|
try:
|
||||||
|
chat_service = GeminiChatService(settings.BASE_URL, key_manager)
|
||||||
response_stream = chat_service.stream_generate_content(
|
response_stream = chat_service.stream_generate_content(
|
||||||
model=model_name,
|
model=model_name,
|
||||||
request=request,
|
request=request,
|
||||||
api_key=api_key
|
api_key=api_key
|
||||||
)
|
)
|
||||||
return StreamingResponse(response_stream, media_type="text/event-stream")
|
return StreamingResponse(response_stream, media_type="text/event-stream")
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"Streaming request failed: {str(e)}")
|
logger.error(f"Streaming request failed: {str(e)}")
|
||||||
|
raise HTTPException(status_code=500, detail="Streaming request failed") from e
|
||||||
|
|
||||||
|
|
||||||
@router.post("/verify-key/{api_key}")
|
@router.post("/verify-key/{api_key}")
|
||||||
async def verify_key(api_key: str):
|
async def verify_key(api_key: str):
|
||||||
key_manager = await get_key_manager()
|
|
||||||
chat_service = GeminiChatService(settings.BASE_URL, key_manager)
|
|
||||||
"""验证Gemini API密钥的有效性"""
|
"""验证Gemini API密钥的有效性"""
|
||||||
logger.info("-" * 50 + "verify_gemini_key" + "-" * 50)
|
logger.info("-" * 50 + "verify_gemini_key" + "-" * 50)
|
||||||
logger.info("Verifying API key validity")
|
logger.info("Verifying API key validity")
|
||||||
|
|
||||||
try:
|
try:
|
||||||
|
key_manager = await get_key_manager()
|
||||||
|
chat_service = GeminiChatService(settings.BASE_URL, key_manager)
|
||||||
|
|
||||||
# 使用generate_content接口测试key的有效性
|
# 使用generate_content接口测试key的有效性
|
||||||
gemini_requset = GeminiRequest(
|
gemini_request = GeminiRequest(
|
||||||
contents=[
|
contents=[
|
||||||
GeminiContent(
|
GeminiContent(
|
||||||
role="user",
|
role="user",
|
||||||
@@ -156,10 +163,16 @@ async def verify_key(api_key: str):
|
|||||||
)
|
)
|
||||||
]
|
]
|
||||||
)
|
)
|
||||||
response = await chat_service.generate_content(settings.TEST_MODEL,gemini_requset, api_key)
|
|
||||||
|
response = await chat_service.generate_content(
|
||||||
|
settings.TEST_MODEL,
|
||||||
|
gemini_request,
|
||||||
|
api_key
|
||||||
|
)
|
||||||
|
|
||||||
if response:
|
if response:
|
||||||
return JSONResponse({"status": "valid"})
|
return JSONResponse({"status": "valid"})
|
||||||
return JSONResponse({"status": "invalid"})
|
return JSONResponse({"status": "invalid"})
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"Key verification failed: {str(e)}")
|
logger.error(f"Key verification failed: {str(e)}")
|
||||||
return JSONResponse({"status": "invalid", "error": str(e)})
|
return JSONResponse({"status": "invalid", "error": str(e)})
|
||||||
@@ -1,16 +1,16 @@
|
|||||||
from fastapi import HTTPException, APIRouter, Depends
|
from fastapi import HTTPException, APIRouter, Depends
|
||||||
from fastapi.responses import StreamingResponse
|
from fastapi.responses import StreamingResponse
|
||||||
|
|
||||||
from app.core.config import settings
|
from app.config.config import settings
|
||||||
from app.core.logger import get_openai_logger
|
from app.logger.logger import get_openai_logger
|
||||||
from app.core.security import SecurityService
|
from app.core.security import SecurityService
|
||||||
from app.schemas.openai_models import ChatRequest, EmbeddingRequest, ImageGenerationRequest
|
from app.domain.openai_models import ChatRequest, EmbeddingRequest, ImageGenerationRequest
|
||||||
from app.services.chat.retry_handler import RetryHandler
|
from app.handler.retry_handler import RetryHandler
|
||||||
from app.services.embedding_service import EmbeddingService
|
from app.service.embedding.embedding_service import EmbeddingService
|
||||||
from app.services.image_create_service import ImageCreateService
|
from app.service.image.image_create_service import ImageCreateService
|
||||||
from app.services.key_manager import KeyManager, get_key_manager_instance
|
from app.service.key.key_manager import KeyManager, get_key_manager_instance
|
||||||
from app.services.model_service import ModelService
|
from app.service.model.model_service import ModelService
|
||||||
from app.services.openai_chat_service import OpenAIChatService
|
from app.service.chat.openai_chat_service import OpenAIChatService
|
||||||
|
|
||||||
router = APIRouter()
|
router = APIRouter()
|
||||||
logger = get_openai_logger()
|
logger = get_openai_logger()
|
||||||
103
app/router/routers.py
Normal file
103
app/router/routers.py
Normal file
@@ -0,0 +1,103 @@
|
|||||||
|
"""
|
||||||
|
路由配置模块,负责设置和配置应用程序的路由
|
||||||
|
"""
|
||||||
|
from fastapi import FastAPI, Request
|
||||||
|
from fastapi.responses import HTMLResponse, RedirectResponse
|
||||||
|
from fastapi.templating import Jinja2Templates
|
||||||
|
|
||||||
|
from app.logger.logger import get_main_logger
|
||||||
|
from app.core.security import verify_auth_token
|
||||||
|
from app.router import gemini_routes, openai_routes
|
||||||
|
from app.service.key.key_manager import get_key_manager_instance
|
||||||
|
|
||||||
|
logger = get_main_logger()
|
||||||
|
|
||||||
|
# 配置Jinja2模板
|
||||||
|
templates = Jinja2Templates(directory="app/templates")
|
||||||
|
|
||||||
|
def setup_routers(app: FastAPI) -> None:
|
||||||
|
"""
|
||||||
|
设置应用程序的路由
|
||||||
|
|
||||||
|
Args:
|
||||||
|
app: FastAPI应用程序实例
|
||||||
|
"""
|
||||||
|
# 包含API路由
|
||||||
|
app.include_router(openai_routes.router)
|
||||||
|
app.include_router(gemini_routes.router)
|
||||||
|
app.include_router(gemini_routes.router_v1beta)
|
||||||
|
|
||||||
|
# 添加页面路由
|
||||||
|
setup_page_routes(app)
|
||||||
|
|
||||||
|
# 添加健康检查路由
|
||||||
|
setup_health_routes(app)
|
||||||
|
|
||||||
|
def setup_page_routes(app: FastAPI) -> None:
|
||||||
|
"""
|
||||||
|
设置页面相关的路由
|
||||||
|
|
||||||
|
Args:
|
||||||
|
app: FastAPI应用程序实例
|
||||||
|
"""
|
||||||
|
@app.get("/", response_class=HTMLResponse)
|
||||||
|
async def auth_page(request: Request):
|
||||||
|
"""认证页面"""
|
||||||
|
return templates.TemplateResponse("auth.html", {"request": request})
|
||||||
|
|
||||||
|
@app.post("/auth")
|
||||||
|
async def authenticate(request: Request):
|
||||||
|
"""处理认证请求"""
|
||||||
|
try:
|
||||||
|
form = await request.form()
|
||||||
|
auth_token = form.get("auth_token")
|
||||||
|
if not auth_token:
|
||||||
|
logger.warning("Authentication attempt with empty token")
|
||||||
|
return RedirectResponse(url="/", status_code=302)
|
||||||
|
|
||||||
|
if verify_auth_token(auth_token):
|
||||||
|
logger.info("Successful authentication")
|
||||||
|
response = RedirectResponse(url="/keys", status_code=302)
|
||||||
|
response.set_cookie(key="auth_token", value=auth_token, httponly=True, max_age=3600)
|
||||||
|
return response
|
||||||
|
logger.warning("Failed authentication attempt with invalid token")
|
||||||
|
return RedirectResponse(url="/", status_code=302)
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Authentication error: {str(e)}")
|
||||||
|
return RedirectResponse(url="/", status_code=302)
|
||||||
|
|
||||||
|
@app.get("/keys", response_class=HTMLResponse)
|
||||||
|
async def keys_page(request: Request):
|
||||||
|
"""密钥管理页面"""
|
||||||
|
try:
|
||||||
|
auth_token = request.cookies.get("auth_token")
|
||||||
|
if not auth_token or not verify_auth_token(auth_token):
|
||||||
|
logger.warning("Unauthorized access attempt to keys page")
|
||||||
|
return RedirectResponse(url="/", status_code=302)
|
||||||
|
|
||||||
|
key_manager = await get_key_manager_instance()
|
||||||
|
keys_status = await key_manager.get_keys_by_status()
|
||||||
|
total = len(keys_status["valid_keys"]) + len(keys_status["invalid_keys"])
|
||||||
|
logger.info(f"Keys status retrieved successfully. Total keys: {total}")
|
||||||
|
return templates.TemplateResponse("keys_status.html", {
|
||||||
|
"request": request,
|
||||||
|
"valid_keys": keys_status["valid_keys"],
|
||||||
|
"invalid_keys": keys_status["invalid_keys"],
|
||||||
|
"total": total
|
||||||
|
})
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Error retrieving keys status: {str(e)}")
|
||||||
|
raise
|
||||||
|
|
||||||
|
def setup_health_routes(app: FastAPI) -> None:
|
||||||
|
"""
|
||||||
|
设置健康检查相关的路由
|
||||||
|
|
||||||
|
Args:
|
||||||
|
app: FastAPI应用程序实例
|
||||||
|
"""
|
||||||
|
@app.get("/health")
|
||||||
|
async def health_check(request: Request):
|
||||||
|
"""健康检查端点"""
|
||||||
|
logger.info("Health check endpoint called")
|
||||||
|
return {"status": "healthy"}
|
||||||
@@ -2,13 +2,13 @@
|
|||||||
|
|
||||||
import json
|
import json
|
||||||
from typing import Dict, Any, AsyncGenerator, List
|
from typing import Dict, Any, AsyncGenerator, List
|
||||||
from app.core.logger import get_gemini_logger
|
from app.logger.logger import get_gemini_logger
|
||||||
from app.services.chat.api_client import GeminiApiClient
|
from app.service.client.api_client import GeminiApiClient
|
||||||
from app.services.chat.stream_optimizer import gemini_optimizer
|
from app.handler.stream_optimizer import gemini_optimizer
|
||||||
from app.schemas.gemini_models import GeminiRequest
|
from app.domain.gemini_models import GeminiRequest
|
||||||
from app.core.config import settings
|
from app.config.config import settings
|
||||||
from app.services.chat.response_handler import GeminiResponseHandler
|
from app.handler.response_handler import GeminiResponseHandler
|
||||||
from app.services.key_manager import KeyManager
|
from app.service.key.key_manager import KeyManager
|
||||||
|
|
||||||
logger = get_gemini_logger()
|
logger = get_gemini_logger()
|
||||||
|
|
||||||
@@ -3,15 +3,15 @@
|
|||||||
from copy import deepcopy
|
from copy import deepcopy
|
||||||
import json
|
import json
|
||||||
from typing import Dict, Any, AsyncGenerator, List, Optional, Union
|
from typing import Dict, Any, AsyncGenerator, List, Optional, Union
|
||||||
from app.core.logger import get_openai_logger
|
from app.logger.logger import get_openai_logger
|
||||||
from app.services.chat.message_converter import OpenAIMessageConverter
|
from app.handler.message_converter import OpenAIMessageConverter
|
||||||
from app.services.chat.response_handler import OpenAIResponseHandler
|
from app.handler.response_handler import OpenAIResponseHandler
|
||||||
from app.services.chat.api_client import GeminiApiClient
|
from app.service.client.api_client import GeminiApiClient
|
||||||
from app.services.chat.stream_optimizer import openai_optimizer
|
from app.handler.stream_optimizer import openai_optimizer
|
||||||
from app.schemas.openai_models import ChatRequest, ImageGenerationRequest
|
from app.domain.openai_models import ChatRequest, ImageGenerationRequest
|
||||||
from app.core.config import settings
|
from app.config.config import settings
|
||||||
from app.services.image_create_service import ImageCreateService
|
from app.service.image.image_create_service import ImageCreateService
|
||||||
from app.services.key_manager import KeyManager
|
from app.service.key.key_manager import KeyManager
|
||||||
|
|
||||||
logger = get_openai_logger()
|
logger = get_openai_logger()
|
||||||
|
|
||||||
@@ -4,6 +4,8 @@ from typing import Dict, Any, AsyncGenerator
|
|||||||
import httpx
|
import httpx
|
||||||
from abc import ABC, abstractmethod
|
from abc import ABC, abstractmethod
|
||||||
|
|
||||||
|
from app.core.constants import DEFAULT_TIMEOUT
|
||||||
|
|
||||||
|
|
||||||
class ApiClient(ABC):
|
class ApiClient(ABC):
|
||||||
"""API客户端基类"""
|
"""API客户端基类"""
|
||||||
@@ -20,7 +22,7 @@ class ApiClient(ABC):
|
|||||||
class GeminiApiClient(ApiClient):
|
class GeminiApiClient(ApiClient):
|
||||||
"""Gemini API客户端"""
|
"""Gemini API客户端"""
|
||||||
|
|
||||||
def __init__(self, base_url: str, timeout: int = 300):
|
def __init__(self, base_url: str, timeout: int = DEFAULT_TIMEOUT):
|
||||||
self.base_url = base_url
|
self.base_url = base_url
|
||||||
self.timeout = timeout
|
self.timeout = timeout
|
||||||
|
|
||||||
@@ -3,7 +3,7 @@ from typing import Union, List
|
|||||||
import openai
|
import openai
|
||||||
from openai.types import CreateEmbeddingResponse
|
from openai.types import CreateEmbeddingResponse
|
||||||
|
|
||||||
from app.core.logger import get_embeddings_logger
|
from app.logger.logger import get_embeddings_logger
|
||||||
|
|
||||||
logger = get_embeddings_logger()
|
logger = get_embeddings_logger()
|
||||||
|
|
||||||
@@ -5,10 +5,11 @@ from google import genai
|
|||||||
from google.genai import types
|
from google.genai import types
|
||||||
import base64
|
import base64
|
||||||
|
|
||||||
from app.core.config import settings
|
from app.config.config import settings
|
||||||
from app.core.logger import get_image_create_logger
|
from app.logger.logger import get_image_create_logger
|
||||||
from app.core.uploader import ImageUploaderFactory
|
from app.utils.uploader import ImageUploaderFactory
|
||||||
from app.schemas.openai_models import ImageGenerationRequest
|
from app.domain.openai_models import ImageGenerationRequest
|
||||||
|
from app.core.constants import VALID_IMAGE_RATIOS
|
||||||
|
|
||||||
logger = get_image_create_logger()
|
logger = get_image_create_logger()
|
||||||
|
|
||||||
@@ -43,10 +44,9 @@ class ImageCreateService:
|
|||||||
ratio_match = re.search(r'{ratio:(\d+:\d+)}', prompt)
|
ratio_match = re.search(r'{ratio:(\d+:\d+)}', prompt)
|
||||||
if ratio_match:
|
if ratio_match:
|
||||||
aspect_ratio = ratio_match.group(1)
|
aspect_ratio = ratio_match.group(1)
|
||||||
valid_ratios = ["1:1", "3:4", "4:3", "9:16", "16:9"]
|
if aspect_ratio not in VALID_IMAGE_RATIOS:
|
||||||
if aspect_ratio not in valid_ratios:
|
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
f"Invalid ratio: {aspect_ratio}. Must be one of: {', '.join(valid_ratios)}"
|
f"Invalid ratio: {aspect_ratio}. Must be one of: {', '.join(VALID_IMAGE_RATIOS)}"
|
||||||
)
|
)
|
||||||
prompt = prompt.replace(ratio_match.group(0), '').strip()
|
prompt = prompt.replace(ratio_match.group(0), '').strip()
|
||||||
|
|
||||||
@@ -1,8 +1,8 @@
|
|||||||
import asyncio
|
import asyncio
|
||||||
from itertools import cycle
|
from itertools import cycle
|
||||||
from typing import Dict
|
from typing import Dict
|
||||||
from app.core.logger import get_key_manager_logger
|
from app.logger.logger import get_key_manager_logger
|
||||||
from app.core.config import settings
|
from app.config.config import settings
|
||||||
|
|
||||||
|
|
||||||
logger = get_key_manager_logger()
|
logger = get_key_manager_logger()
|
||||||
@@ -1,8 +1,8 @@
|
|||||||
import requests
|
import requests
|
||||||
from datetime import datetime, timezone
|
from datetime import datetime, timezone
|
||||||
from typing import Optional, Dict, Any
|
from typing import Optional, Dict, Any
|
||||||
from app.core.logger import get_model_logger
|
from app.logger.logger import get_model_logger
|
||||||
from app.core.config import settings
|
from app.config.config import settings
|
||||||
|
|
||||||
logger = get_model_logger()
|
logger = get_model_logger()
|
||||||
|
|
||||||
@@ -10,7 +10,7 @@ class ModelService:
|
|||||||
def __init__(self, search_models: list, image_models: list):
|
def __init__(self, search_models: list, image_models: list):
|
||||||
self.search_models = search_models
|
self.search_models = search_models
|
||||||
self.image_models = image_models
|
self.image_models = image_models
|
||||||
self.base_url = "https://generativelanguage.googleapis.com/v1beta"
|
self.base_url = settings.BASE_URL
|
||||||
self.filtered_models = settings.FILTERED_MODELS
|
self.filtered_models = settings.FILTERED_MODELS
|
||||||
|
|
||||||
def get_gemini_models(self, api_key: str) -> Optional[Dict[str, Any]]:
|
def get_gemini_models(self, api_key: str) -> Optional[Dict[str, Any]]:
|
||||||
3
app/utils/__init__.py
Normal file
3
app/utils/__init__.py
Normal file
@@ -0,0 +1,3 @@
|
|||||||
|
"""
|
||||||
|
工具包初始化模块
|
||||||
|
"""
|
||||||
146
app/utils/helpers.py
Normal file
146
app/utils/helpers.py
Normal file
@@ -0,0 +1,146 @@
|
|||||||
|
"""
|
||||||
|
通用工具函数模块
|
||||||
|
"""
|
||||||
|
import json
|
||||||
|
import re
|
||||||
|
import base64
|
||||||
|
import requests
|
||||||
|
from typing import Dict, Any, List, Optional, Tuple
|
||||||
|
|
||||||
|
from app.core.constants import DATA_URL_PATTERN, IMAGE_URL_PATTERN, VALID_IMAGE_RATIOS
|
||||||
|
|
||||||
|
|
||||||
|
def extract_mime_type_and_data(base64_string: str) -> Tuple[Optional[str], str]:
|
||||||
|
"""
|
||||||
|
从 base64 字符串中提取 MIME 类型和数据
|
||||||
|
|
||||||
|
Args:
|
||||||
|
base64_string: 可能包含 MIME 类型信息的 base64 字符串
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
tuple: (mime_type, encoded_data)
|
||||||
|
"""
|
||||||
|
# 检查字符串是否以 "data:" 格式开始
|
||||||
|
if base64_string.startswith('data:'):
|
||||||
|
# 提取 MIME 类型和数据
|
||||||
|
pattern = DATA_URL_PATTERN
|
||||||
|
match = re.match(pattern, base64_string)
|
||||||
|
if match:
|
||||||
|
mime_type = "image/jpeg" if match.group(1) == "image/jpg" else match.group(1)
|
||||||
|
encoded_data = match.group(2)
|
||||||
|
return mime_type, encoded_data
|
||||||
|
|
||||||
|
# 如果不是预期格式,假定它只是数据部分
|
||||||
|
return None, base64_string
|
||||||
|
|
||||||
|
|
||||||
|
def convert_image_to_base64(url: str) -> str:
|
||||||
|
"""
|
||||||
|
将图片URL转换为base64编码
|
||||||
|
|
||||||
|
Args:
|
||||||
|
url: 图片URL
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
str: base64编码的图片数据
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
Exception: 如果获取图片失败
|
||||||
|
"""
|
||||||
|
response = requests.get(url)
|
||||||
|
if response.status_code == 200:
|
||||||
|
# 将图片内容转换为base64
|
||||||
|
img_data = base64.b64encode(response.content).decode('utf-8')
|
||||||
|
return img_data
|
||||||
|
else:
|
||||||
|
raise Exception(f"Failed to fetch image: {response.status_code}")
|
||||||
|
|
||||||
|
|
||||||
|
def format_json_response(data: Dict[str, Any], indent: int = 2) -> str:
|
||||||
|
"""
|
||||||
|
格式化JSON响应
|
||||||
|
|
||||||
|
Args:
|
||||||
|
data: 要格式化的数据
|
||||||
|
indent: 缩进空格数
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
str: 格式化后的JSON字符串
|
||||||
|
"""
|
||||||
|
return json.dumps(data, indent=indent, ensure_ascii=False)
|
||||||
|
|
||||||
|
|
||||||
|
def parse_prompt_parameters(prompt: str, default_ratio: str = "1:1") -> Tuple[str, int, str]:
|
||||||
|
"""
|
||||||
|
从prompt中解析参数
|
||||||
|
|
||||||
|
支持的格式:
|
||||||
|
- {n:数量} 例如: {n:2} 生成2张图片
|
||||||
|
- {ratio:比例} 例如: {ratio:16:9} 使用16:9比例
|
||||||
|
|
||||||
|
Args:
|
||||||
|
prompt: 提示文本
|
||||||
|
default_ratio: 默认比例
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
tuple: (清理后的提示文本, 图片数量, 比例)
|
||||||
|
"""
|
||||||
|
# 默认值
|
||||||
|
n = 1
|
||||||
|
aspect_ratio = default_ratio
|
||||||
|
|
||||||
|
# 解析n参数
|
||||||
|
n_match = re.search(r'{n:(\d+)}', prompt)
|
||||||
|
if n_match:
|
||||||
|
n = int(n_match.group(1))
|
||||||
|
if n < 1 or n > 4:
|
||||||
|
raise ValueError(f"Invalid n value: {n}. Must be between 1 and 4.")
|
||||||
|
prompt = prompt.replace(n_match.group(0), '').strip()
|
||||||
|
|
||||||
|
# 解析ratio参数
|
||||||
|
ratio_match = re.search(r'{ratio:(\d+:\d+)}', prompt)
|
||||||
|
if ratio_match:
|
||||||
|
aspect_ratio = ratio_match.group(1)
|
||||||
|
if aspect_ratio not in VALID_IMAGE_RATIOS:
|
||||||
|
raise ValueError(
|
||||||
|
f"Invalid ratio: {aspect_ratio}. Must be one of: {', '.join(VALID_IMAGE_RATIOS)}"
|
||||||
|
)
|
||||||
|
prompt = prompt.replace(ratio_match.group(0), '').strip()
|
||||||
|
|
||||||
|
return prompt, n, aspect_ratio
|
||||||
|
|
||||||
|
|
||||||
|
def extract_image_urls_from_markdown(text: str) -> List[str]:
|
||||||
|
"""
|
||||||
|
从Markdown文本中提取图片URL
|
||||||
|
|
||||||
|
Args:
|
||||||
|
text: Markdown文本
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
List[str]: 图片URL列表
|
||||||
|
"""
|
||||||
|
pattern = IMAGE_URL_PATTERN
|
||||||
|
matches = re.findall(pattern, text)
|
||||||
|
return [match[1] for match in matches]
|
||||||
|
|
||||||
|
|
||||||
|
def is_valid_api_key(key: str) -> bool:
|
||||||
|
"""
|
||||||
|
检查API密钥格式是否有效
|
||||||
|
|
||||||
|
Args:
|
||||||
|
key: API密钥
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
bool: 如果密钥格式有效则返回True
|
||||||
|
"""
|
||||||
|
# 检查Gemini API密钥格式
|
||||||
|
if key.startswith('AIza'):
|
||||||
|
return len(key) >= 30
|
||||||
|
|
||||||
|
# 检查OpenAI API密钥格式
|
||||||
|
if key.startswith('sk-'):
|
||||||
|
return len(key) >= 30
|
||||||
|
|
||||||
|
return False
|
||||||
@@ -1,5 +1,5 @@
|
|||||||
import requests
|
import requests
|
||||||
from app.schemas.image_models import ImageMetadata, ImageUploader, UploadResponse
|
from app.domain.image_models import ImageMetadata, ImageUploader, UploadResponse
|
||||||
from enum import Enum
|
from enum import Enum
|
||||||
from typing import Optional, Any
|
from typing import Optional, Any
|
||||||
|
|
||||||
Reference in New Issue
Block a user