diff --git a/Dockerfile b/Dockerfile index 97292c7..908d59d 100644 --- a/Dockerfile +++ b/Dockerfile @@ -14,6 +14,7 @@ ENV BASE_URL=https://generativelanguage.googleapis.com/v1beta ENV TOOLS_CODE_EXECUTION_ENABLED=false ENV IMAGE_MODELS='["gemini-2.0-flash-exp"]' ENV SEARCH_MODELS='["gemini-2.0-flash-exp","gemini-2.0-pro-exp"]' +ENV URL_NORMALIZATION_ENABLED=false # Expose port EXPOSE 8000 diff --git a/app/config/config.py b/app/config/config.py index e6709e0..124da01 100644 --- a/app/config/config.py +++ b/app/config/config.py @@ -63,7 +63,10 @@ class Settings(BaseSettings): PROXIES_USE_CONSISTENCY_HASH_BY_API_KEY: bool = True # 是否使用一致性哈希来选择代理 VERTEX_API_KEYS: List[str] = [] VERTEX_EXPRESS_BASE_URL: str = "https://aiplatform.googleapis.com/v1beta1/publishers/google" - + + # 智能路由配置 + URL_NORMALIZATION_ENABLED: bool = False # 是否启用智能路由映射功能 + # 模型相关配置 SEARCH_MODELS: List[str] = ["gemini-2.0-flash-exp"] IMAGE_MODELS: List[str] = ["gemini-2.0-flash-exp"] @@ -111,6 +114,9 @@ class Settings(BaseSettings): AUTO_DELETE_REQUEST_LOGS_DAYS: int = 30 SAFETY_SETTINGS: List[Dict[str, str]] = DEFAULT_SAFETY_SETTINGS + #是否开启新手模式 + URL_NORMALIZATION_ENABLED: bool = False + def __init__(self, **kwargs): super().__init__(**kwargs) # 设置默认AUTH_TOKEN(如果未提供) diff --git a/app/middleware/middleware.py b/app/middleware/middleware.py index 05dded3..85d512f 100644 --- a/app/middleware/middleware.py +++ b/app/middleware/middleware.py @@ -8,6 +8,7 @@ from fastapi.responses import RedirectResponse from starlette.middleware.base import BaseHTTPMiddleware # from app.middleware.request_logging_middleware import RequestLoggingMiddleware +from app.middleware.smart_routing_middleware import SmartRoutingMiddleware from app.core.constants import API_VERSION from app.core.security import verify_auth_token from app.log.logger import get_middleware_logger @@ -52,6 +53,9 @@ def setup_middlewares(app: FastAPI) -> None: Args: app: FastAPI应用程序实例 """ + # 添加智能路由中间件(必须在认证中间件之前) + app.add_middleware(SmartRoutingMiddleware) + # 添加认证中间件 app.add_middleware(AuthMiddleware) diff --git a/app/middleware/smart_routing_middleware.py b/app/middleware/smart_routing_middleware.py new file mode 100644 index 0000000..391d061 --- /dev/null +++ b/app/middleware/smart_routing_middleware.py @@ -0,0 +1,227 @@ +from fastapi import Request +from starlette.middleware.base import BaseHTTPMiddleware +from app.config.config import settings +from app.log.logger import get_main_logger +import re + +logger = get_main_logger() + +class SmartRoutingMiddleware(BaseHTTPMiddleware): + def __init__(self, app): + super().__init__(app) + # 简化的路由规则 - 直接根据检测结果路由 + pass + + async def dispatch(self, request: Request, call_next): + if not settings.URL_NORMALIZATION_ENABLED: + return await call_next(request) + + original_path = str(request.url.path) + method = request.method + + # 尝试修复URL + fixed_path, fix_info = self.fix_request_url(original_path, method, request) + + if fixed_path != original_path: + logger.info(f"URL fixed: {method} {original_path} → {fixed_path}") + if fix_info: + logger.debug(f"Fix details: {fix_info}") + + # 重写请求路径 + request.scope["path"] = fixed_path + request.scope["raw_path"] = fixed_path.encode() + + return await call_next(request) + + def fix_request_url(self, path: str, method: str, request: Request) -> tuple: + """修复错误的请求URL - 简化版本""" + + # 首先检查是否已经是正确的格式,如果是则不处理 + if self.is_already_correct_format(path): + return path, None + + # 检测是否为流式请求 + is_stream_request = self.detect_stream_request(path, request) + + # 1. 优先检测OpenAI格式请求(避免被v1beta误判) + if self.is_openai_request(path, request): + return self.fix_openai_request(path, method, request) + + # 2. 检测HF格式请求 + if self.is_hf_request(path, request): + return self.fix_hf_request(path, method, request) + + # 3. 检测Gemini请求 + if self.is_gemini_request(path): + return self.fix_gemini_request(path, method, request, is_stream_request) + + # 4. 默认处理其他请求(转为最快的v1端点) + return self.fix_default_request(path, method, request) + + def is_already_correct_format(self, path: str) -> bool: + """检查是否已经是正确的API格式""" + # 检查是否已经是正确的端点格式 + correct_patterns = [ + r'^/v1beta/models/[^/:]+:(generate|streamGenerate)Content$', # Gemini原生 + r'^/gemini/v1beta/models/[^/:]+:(generate|streamGenerate)Content$', # Gemini带前缀 + r'^/v1beta/models$', # Gemini模型列表 + r'^/v1/(chat/completions|models|embeddings|images/generations)$', # v1格式 + r'^/openai/v1/(chat/completions|models|embeddings|images/generations)$', # OpenAI格式 + r'^/hf/v1/(chat/completions|models|embeddings|images/generations)$', # HF格式 + ] + + for pattern in correct_patterns: + if re.match(pattern, path): + return True + + return False + + def is_openai_request(self, path: str, request: Request) -> bool: + """检测OpenAI格式请求""" + return '/openai/' in path.lower() + + def is_hf_request(self, path: str, request: Request) -> bool: + """检测HF格式请求""" + return '/hf/' in path.lower() + + def fix_openai_request(self, path: str, method: str, request: Request) -> tuple: + """修复OpenAI格式请求""" + if method == 'POST': + if 'chat' in path.lower() or 'completion' in path.lower(): + return '/openai/v1/chat/completions', {'type': 'openai_chat'} + elif 'embedding' in path.lower(): + return '/openai/v1/embeddings', {'type': 'openai_embeddings'} + elif 'image' in path.lower(): + return '/openai/v1/images/generations', {'type': 'openai_images'} + elif method == 'GET': + if 'model' in path.lower(): + return '/openai/v1/models', {'type': 'openai_models'} + + return path, None + + def fix_hf_request(self, path: str, method: str, request: Request) -> tuple: + """修复HF格式请求""" + if method == 'POST': + if 'chat' in path.lower() or 'completion' in path.lower(): + return '/hf/v1/chat/completions', {'type': 'hf_chat'} + elif 'embedding' in path.lower(): + return '/hf/v1/embeddings', {'type': 'hf_embeddings'} + elif 'image' in path.lower(): + return '/hf/v1/images/generations', {'type': 'hf_images'} + elif method == 'GET': + if 'model' in path.lower(): + return '/hf/v1/models', {'type': 'hf_models'} + + return path, None + + def fix_default_request(self, path: str, method: str, request: Request) -> tuple: + """修复默认请求(转为最快的v1端点)""" + if method == 'POST': + if 'chat' in path.lower() or 'completion' in path.lower(): + return '/v1/chat/completions', {'type': 'default_chat'} + elif 'embedding' in path.lower(): + return '/v1/embeddings', {'type': 'default_embeddings'} + elif 'image' in path.lower(): + return '/v1/images/generations', {'type': 'default_images'} + elif method == 'GET': + if 'model' in path.lower(): + return '/v1/models', {'type': 'default_models'} + + return path, None + + def fix_gemini_request(self, path: str, method: str, request: Request, is_stream: bool) -> tuple: + """修复Gemini请求""" + if method != 'POST': + if method == 'GET' and 'models' in path.lower(): + return '/v1beta/models', {'rule': 'gemini_models', 'preference': 'gemini_format'} + return path, None + + # 提取模型名称 + try: + model_name = self.extract_model_name(path, request) + except ValueError: + # 无法提取模型名称,返回原路径不做处理 + return path, None + + # 构建目标URL + if is_stream: + target_url = f'/v1beta/models/{model_name}:streamGenerateContent' + else: + target_url = f'/v1beta/models/{model_name}:generateContent' + + fix_info = { + 'rule': 'gemini_generate' if not is_stream else 'gemini_stream', + 'preference': 'gemini_format', + 'is_stream': is_stream, + 'model': model_name + } + + return target_url, fix_info + + def detect_stream_request(self, path: str, request: Request) -> bool: + """检测是否为流式请求""" + # 1. 路径中包含stream关键词 + if 'stream' in path.lower(): + return True + + # 2. 查询参数 + if request.query_params.get('stream') == 'true': + return True + + # 3. Accept头部 + accept = request.headers.get('accept', '') + if 'text/event-stream' in accept: + return True + + return False + + def is_gemini_request(self, path: str) -> bool: + """判断是否为Gemini API请求""" + path_lower = path.lower() + + # 如果已经是OpenAI或HF格式,不应该被识别为Gemini + if '/openai/' in path_lower or '/hf/' in path_lower: + return False + + # 1. 检查是否是明确的Gemini路径模式 + gemini_path_patterns = [ + r'/v1beta/models/', # Gemini原生API路径 + r'/gemini/v1beta/', # 带gemini前缀的路径 + ] + + for pattern in gemini_path_patterns: + if re.search(pattern, path_lower): + return True + + # 2. 检查是否包含Gemini模型名称 + if 'gemini' in path_lower and ('models' in path_lower or 'generatecontent' in path_lower): + return True + + return False + + def extract_model_name(self, path: str, request: Request) -> str: + """从请求中提取模型名称,用于构建Gemini API URL""" + # 1. 从请求体中提取 + try: + if hasattr(request, '_body') and request._body: + import json + body = json.loads(request._body.decode()) + if 'model' in body and body['model']: + return body['model'] + except: + pass + + # 2. 从查询参数中提取 + model_param = request.query_params.get('model') + if model_param: + return model_param + + # 3. 从路径中提取(用于已包含模型名称的路径) + match = re.search(r'/models/([^/:]+)', path, re.IGNORECASE) + if match: + return match.group(1) + + # 4. 如果无法提取模型名称,抛出异常 + raise ValueError("Unable to extract model name from request") + + diff --git a/app/templates/config_editor.html b/app/templates/config_editor.html index 524ad31..b07081e 100644 --- a/app/templates/config_editor.html +++ b/app/templates/config_editor.html @@ -1,2371 +1,2382 @@ {% extends "base.html" %} {% block title %}配置编辑器 - Gemini Balance{% -endblock %} {% block head_extra_styles %} - -{% endblock %} {% block content %} -