From ada32d526a7cd73e4799983d0d87d61105ef91f0 Mon Sep 17 00:00:00 2001 From: chinrain <3523213146@qq.com> Date: Thu, 3 Jul 2025 03:01:10 +0800 Subject: [PATCH 1/2] =?UTF-8?q?refactor:=20=E7=AE=80=E5=8C=96=E6=99=BA?= =?UTF-8?q?=E8=83=BD=E8=B7=AF=E7=94=B1=E4=B8=AD=E9=97=B4=E4=BB=B6=EF=BC=8C?= =?UTF-8?q?=E4=BC=98=E5=8C=96=E6=B7=B7=E5=90=88=E6=A0=BC=E5=BC=8FURL?= =?UTF-8?q?=E5=A4=84=E7=90=86?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - 重构智能路由逻辑,在保证聊天的同时尽量简化 - 只会修改常见错误,其余的透传(方便以后维护或者不用维护) - 常见错误都能正常聊天 - 统一前端样式 --- app/middleware/smart_routing_middleware.py | 204 +++++++-------------- app/templates/config_editor.html | 28 ++- 2 files changed, 89 insertions(+), 143 deletions(-) diff --git a/app/middleware/smart_routing_middleware.py b/app/middleware/smart_routing_middleware.py index 732677d..e8a4756 100644 --- a/app/middleware/smart_routing_middleware.py +++ b/app/middleware/smart_routing_middleware.py @@ -34,33 +34,30 @@ class SmartRoutingMiddleware(BaseHTTPMiddleware): return await call_next(request) def fix_request_url(self, path: str, method: str, request: Request) -> tuple: - """修复错误的请求URL - 简化版本""" + """简化的URL修复逻辑""" # 首先检查是否已经是正确的格式,如果是则不处理 if self.is_already_correct_format(path): return path, None - # 检测是否为流式请求 - is_stream_request = self.detect_stream_request(path, request) + # 1. 最高优先级:包含generateContent → Gemini格式 + if 'generatecontent' in path.lower(): + return self.fix_gemini_by_operation(path, method, request) - # 1. 优先检测OpenAI格式请求(避免被v1beta误判) - if self.is_openai_request(path, request): - return self.fix_openai_request(path, method, request) + # 2. 第二优先级:包含/openai/ → OpenAI格式 + if '/openai/' in path.lower(): + return self.fix_openai_by_operation(path, method) - # 2. 检测HF格式请求 - if self.is_hf_request(path, request): - return self.fix_hf_request(path, method, request) + # 3. 第三优先级:包含/v1/ → v1格式 + if '/v1/' in path.lower(): + return self.fix_v1_by_operation(path, method) - # 3. 检测Vertex Express格式请求(优先级高于Gemini) - if self.is_vertex_express_request(path, request): - return self.fix_vertex_express_request(path, method, request, is_stream_request) + # 4. 第四优先级:包含/chat/completions → chat功能 + if '/chat/completions' in path.lower(): + return '/v1/chat/completions', {'type': 'v1_chat'} - # 4. 检测Gemini请求 - if self.is_gemini_request(path): - return self.fix_gemini_request(path, method, request, is_stream_request) - - # 5. 默认处理其他请求(转为最快的v1端点) - return self.fix_default_request(path, method, request) + # 5. 默认:原样传递 + return path, None def is_already_correct_format(self, path: str) -> bool: """检查是否已经是正确的API格式""" @@ -73,8 +70,9 @@ class SmartRoutingMiddleware(BaseHTTPMiddleware): r'^/v1/(chat/completions|models|embeddings|images/generations)$', # v1格式 r'^/openai/v1/(chat/completions|models|embeddings|images/generations)$', # OpenAI格式 r'^/hf/v1/(chat/completions|models|embeddings|images/generations)$', # HF格式 - r'^/vertex-express/v1beta/models/[^/:]+:(generate|streamGenerate)Content$', # Vertex Express + r'^/vertex-express/v1beta/models/[^/:]+:(generate|streamGenerate)Content$', # Vertex Express Gemini格式 r'^/vertex-express/v1beta/models$', # Vertex Express模型列表 + r'^/vertex-express/v1/(chat/completions|models|embeddings|images/generations)$', # Vertex Express OpenAI格式 ] for pattern in correct_patterns: @@ -83,20 +81,52 @@ class SmartRoutingMiddleware(BaseHTTPMiddleware): return False - def is_openai_request(self, path: str, request: Request) -> bool: - """检测OpenAI格式请求""" - return '/openai/' in path.lower() + def fix_gemini_by_operation(self, path: str, method: str, request: Request) -> tuple: + """根据Gemini操作修复,考虑端点偏好""" + if method != 'POST': + return path, None - def is_hf_request(self, path: str, request: Request) -> bool: - """检测HF格式请求""" - return '/hf/' in path.lower() + # 提取模型名称 + try: + model_name = self.extract_model_name(path, request) + except ValueError: + # 无法提取模型名称,返回原路径不做处理 + return path, None - def is_vertex_express_request(self, path: str, request: Request) -> bool: - """检测Vertex Express格式请求""" - return '/vertex-express/' in path.lower() + # 检测是否为流式请求 + is_stream = self.detect_stream_request(path, request) - def fix_openai_request(self, path: str, method: str, request: Request) -> tuple: - """修复OpenAI格式请求""" + # 检查是否有vertex-express偏好 + if '/vertex-express/' in path.lower(): + if is_stream: + target_url = f'/vertex-express/v1beta/models/{model_name}:streamGenerateContent' + else: + target_url = f'/vertex-express/v1beta/models/{model_name}:generateContent' + + fix_info = { + 'rule': 'vertex_express_generate' if not is_stream else 'vertex_express_stream', + 'preference': 'vertex_express_format', + 'is_stream': is_stream, + 'model': model_name + } + else: + # 标准Gemini端点 + if is_stream: + target_url = f'/v1beta/models/{model_name}:streamGenerateContent' + else: + target_url = f'/v1beta/models/{model_name}:generateContent' + + fix_info = { + 'rule': 'gemini_generate' if not is_stream else 'gemini_stream', + 'preference': 'gemini_format', + 'is_stream': is_stream, + 'model': model_name + } + + return target_url, fix_info + + def fix_openai_by_operation(self, path: str, method: str) -> tuple: + """根据操作类型修复OpenAI格式""" if method == 'POST': if 'chat' in path.lower() or 'completion' in path.lower(): return '/openai/v1/chat/completions', {'type': 'openai_chat'} @@ -110,94 +140,21 @@ class SmartRoutingMiddleware(BaseHTTPMiddleware): return path, None - def fix_hf_request(self, path: str, method: str, request: Request) -> tuple: - """修复HF格式请求""" + def fix_v1_by_operation(self, path: str, method: str) -> tuple: + """根据操作类型修复v1格式""" if method == 'POST': if 'chat' in path.lower() or 'completion' in path.lower(): - return '/hf/v1/chat/completions', {'type': 'hf_chat'} + return '/v1/chat/completions', {'type': 'v1_chat'} elif 'embedding' in path.lower(): - return '/hf/v1/embeddings', {'type': 'hf_embeddings'} + return '/v1/embeddings', {'type': 'v1_embeddings'} elif 'image' in path.lower(): - return '/hf/v1/images/generations', {'type': 'hf_images'} + return '/v1/images/generations', {'type': 'v1_images'} elif method == 'GET': if 'model' in path.lower(): - return '/hf/v1/models', {'type': 'hf_models'} + return '/v1/models', {'type': 'v1_models'} return path, None - def fix_vertex_express_request(self, path: str, method: str, request: Request, is_stream: bool) -> tuple: - """修复Vertex Express请求""" - if method != 'POST': - if method == 'GET' and 'models' in path.lower(): - return '/vertex-express/v1beta/models', {'rule': 'vertex_express_models', 'preference': 'vertex_express_format'} - return path, None - - # 提取模型名称 - try: - model_name = self.extract_model_name(path, request) - except ValueError: - # 无法提取模型名称,返回原路径不做处理 - return path, None - - # 构建目标URL - if is_stream: - target_url = f'/vertex-express/v1beta/models/{model_name}:streamGenerateContent' - else: - target_url = f'/vertex-express/v1beta/models/{model_name}:generateContent' - - fix_info = { - 'rule': 'vertex_express_generate' if not is_stream else 'vertex_express_stream', - 'preference': 'vertex_express_format', - 'is_stream': is_stream, - 'model': model_name - } - - return target_url, fix_info - - def fix_default_request(self, path: str, method: str, request: Request) -> tuple: - """修复默认请求(转为最快的v1端点)""" - if method == 'POST': - if 'chat' in path.lower() or 'completion' in path.lower(): - return '/v1/chat/completions', {'type': 'default_chat'} - elif 'embedding' in path.lower(): - return '/v1/embeddings', {'type': 'default_embeddings'} - elif 'image' in path.lower(): - return '/v1/images/generations', {'type': 'default_images'} - elif method == 'GET': - if 'model' in path.lower(): - return '/v1/models', {'type': 'default_models'} - - return path, None - - def fix_gemini_request(self, path: str, method: str, request: Request, is_stream: bool) -> tuple: - """修复Gemini请求""" - if method != 'POST': - if method == 'GET' and 'models' in path.lower(): - return '/v1beta/models', {'rule': 'gemini_models', 'preference': 'gemini_format'} - return path, None - - # 提取模型名称 - try: - model_name = self.extract_model_name(path, request) - except ValueError: - # 无法提取模型名称,返回原路径不做处理 - return path, None - - # 构建目标URL - if is_stream: - target_url = f'/v1beta/models/{model_name}:streamGenerateContent' - else: - target_url = f'/v1beta/models/{model_name}:generateContent' - - fix_info = { - 'rule': 'gemini_generate' if not is_stream else 'gemini_stream', - 'preference': 'gemini_format', - 'is_stream': is_stream, - 'model': model_name - } - - return target_url, fix_info - def detect_stream_request(self, path: str, request: Request) -> bool: """检测是否为流式请求""" # 1. 路径中包含stream关键词 @@ -210,31 +167,6 @@ class SmartRoutingMiddleware(BaseHTTPMiddleware): return False - - def is_gemini_request(self, path: str) -> bool: - """判断是否为Gemini API请求""" - path_lower = path.lower() - - # 如果已经是OpenAI、HF或Vertex Express格式,不应该被识别为Gemini - if '/openai/' in path_lower or '/hf/' in path_lower or '/vertex-express/' in path_lower: - return False - - # 1. 检查是否是明确的Gemini路径模式 - gemini_path_patterns = [ - r'/v1beta/models/', # Gemini原生API路径 - r'/gemini/v1beta/', # 带gemini前缀的路径 - ] - - for pattern in gemini_path_patterns: - if re.search(pattern, path_lower): - return True - - # 2. 检查是否包含Gemini模型名称 - if 'gemini' in path_lower and ('models' in path_lower or 'generatecontent' in path_lower): - return True - - return False - def extract_model_name(self, path: str, request: Request) -> str: """从请求中提取模型名称,用于构建Gemini API URL""" # 1. 从请求体中提取 @@ -258,6 +190,4 @@ class SmartRoutingMiddleware(BaseHTTPMiddleware): return match.group(1) # 4. 如果无法提取模型名称,抛出异常 - raise ValueError("Unable to extract model name from request") - - + raise ValueError("Unable to extract model name from request") \ No newline at end of file diff --git a/app/templates/config_editor.html b/app/templates/config_editor.html index cb82fc8..fe0276a 100644 --- a/app/templates/config_editor.html +++ b/app/templates/config_editor.html @@ -937,13 +937,29 @@ endblock %} {% block head_extra_styles %}
- +
+ +
+ + +
+
- 自动将客户端的各种URL格式映射到正确的API端点 + 自动客户端请求的url拼接为正确格式(仅保证正常聊天,出现问题请关闭)
From f79a52f83998ff504af084ab9b7644dde386cb1c Mon Sep 17 00:00:00 2001 From: snaily Date: Thu, 3 Jul 2025 17:25:50 +0800 Subject: [PATCH 2/2] =?UTF-8?q?fix:=E4=BC=98=E5=8C=96=E6=99=BA=E8=83=BD?= =?UTF-8?q?=E8=B7=AF=E7=94=B1=E4=B8=AD=E9=97=B4=E4=BB=B6=EF=BC=8C=E5=A2=9E?= =?UTF-8?q?=E5=BC=BAURL=E5=A4=84=E7=90=86=E9=80=BB=E8=BE=91?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - 增加对新路径模式的支持,包括对`v1beta/models`的处理 - 统一日志记录格式,提升调试信息的可读性 - 规范化代码风格,确保一致性和可维护性 - 修复了请求体和查询参数的模型名称提取逻辑 --- app/middleware/smart_routing_middleware.py | 139 +++++++++++---------- 1 file changed, 76 insertions(+), 63 deletions(-) diff --git a/app/middleware/smart_routing_middleware.py b/app/middleware/smart_routing_middleware.py index e8a4756..1ba73f7 100644 --- a/app/middleware/smart_routing_middleware.py +++ b/app/middleware/smart_routing_middleware.py @@ -15,18 +15,18 @@ class SmartRoutingMiddleware(BaseHTTPMiddleware): async def dispatch(self, request: Request, call_next): if not settings.URL_NORMALIZATION_ENABLED: return await call_next(request) - + logger.debug(f"request: {request}") original_path = str(request.url.path) method = request.method # 尝试修复URL fixed_path, fix_info = self.fix_request_url(original_path, method, request) - + if fixed_path != original_path: logger.info(f"URL fixed: {method} {original_path} → {fixed_path}") if fix_info: logger.debug(f"Fix details: {fix_info}") - + # 重写请求路径 request.scope["path"] = fixed_path request.scope["raw_path"] = fixed_path.encode() @@ -41,20 +41,20 @@ class SmartRoutingMiddleware(BaseHTTPMiddleware): return path, None # 1. 最高优先级:包含generateContent → Gemini格式 - if 'generatecontent' in path.lower(): + if "generatecontent" in path.lower() or "v1beta/models" in path.lower(): return self.fix_gemini_by_operation(path, method, request) # 2. 第二优先级:包含/openai/ → OpenAI格式 - if '/openai/' in path.lower(): + if "/openai/" in path.lower(): return self.fix_openai_by_operation(path, method) # 3. 第三优先级:包含/v1/ → v1格式 - if '/v1/' in path.lower(): + if "/v1/" in path.lower(): return self.fix_v1_by_operation(path, method) # 4. 第四优先级:包含/chat/completions → chat功能 - if '/chat/completions' in path.lower(): - return '/v1/chat/completions', {'type': 'v1_chat'} + if "/chat/completions" in path.lower(): + return "/v1/chat/completions", {"type": "v1_chat"} # 5. 默认:原样传递 return path, None @@ -63,16 +63,16 @@ class SmartRoutingMiddleware(BaseHTTPMiddleware): """检查是否已经是正确的API格式""" # 检查是否已经是正确的端点格式 correct_patterns = [ - r'^/v1beta/models/[^/:]+:(generate|streamGenerate)Content$', # Gemini原生 - r'^/gemini/v1beta/models/[^/:]+:(generate|streamGenerate)Content$', # Gemini带前缀 - r'^/v1beta/models$', # Gemini模型列表 - r'^/gemini/v1beta/models$', # Gemini带前缀的模型列表 - r'^/v1/(chat/completions|models|embeddings|images/generations)$', # v1格式 - r'^/openai/v1/(chat/completions|models|embeddings|images/generations)$', # OpenAI格式 - r'^/hf/v1/(chat/completions|models|embeddings|images/generations)$', # HF格式 - r'^/vertex-express/v1beta/models/[^/:]+:(generate|streamGenerate)Content$', # Vertex Express Gemini格式 - r'^/vertex-express/v1beta/models$', # Vertex Express模型列表 - r'^/vertex-express/v1/(chat/completions|models|embeddings|images/generations)$', # Vertex Express OpenAI格式 + r"^/v1beta/models/[^/:]+:(generate|streamGenerate)Content$", # Gemini原生 + r"^/gemini/v1beta/models/[^/:]+:(generate|streamGenerate)Content$", # Gemini带前缀 + r"^/v1beta/models$", # Gemini模型列表 + r"^/gemini/v1beta/models$", # Gemini带前缀的模型列表 + r"^/v1/(chat/completions|models|embeddings|images/generations)$", # v1格式 + r"^/openai/v1/(chat/completions|models|embeddings|images/generations)$", # OpenAI格式 + r"^/hf/v1/(chat/completions|models|embeddings|images/generations)$", # HF格式 + r"^/vertex-express/v1beta/models/[^/:]+:(generate|streamGenerate)Content$", # Vertex Express Gemini格式 + r"^/vertex-express/v1beta/models$", # Vertex Express模型列表 + r"^/vertex-express/v1/(chat/completions|models|embeddings|images/generations)$", # Vertex Express OpenAI格式 ] for pattern in correct_patterns: @@ -81,10 +81,14 @@ class SmartRoutingMiddleware(BaseHTTPMiddleware): return False - def fix_gemini_by_operation(self, path: str, method: str, request: Request) -> tuple: + def fix_gemini_by_operation( + self, path: str, method: str, request: Request + ) -> tuple: """根据Gemini操作修复,考虑端点偏好""" - if method != 'POST': - return path, None + if method == "GET": + return "/v1beta/models", { + "role": "gemini_models", + } # 提取模型名称 try: @@ -97,72 +101,80 @@ class SmartRoutingMiddleware(BaseHTTPMiddleware): is_stream = self.detect_stream_request(path, request) # 检查是否有vertex-express偏好 - if '/vertex-express/' in path.lower(): + if "/vertex-express/" in path.lower(): if is_stream: - target_url = f'/vertex-express/v1beta/models/{model_name}:streamGenerateContent' + target_url = ( + f"/vertex-express/v1beta/models/{model_name}:streamGenerateContent" + ) else: - target_url = f'/vertex-express/v1beta/models/{model_name}:generateContent' + target_url = ( + f"/vertex-express/v1beta/models/{model_name}:generateContent" + ) fix_info = { - 'rule': 'vertex_express_generate' if not is_stream else 'vertex_express_stream', - 'preference': 'vertex_express_format', - 'is_stream': is_stream, - 'model': model_name + "rule": ( + "vertex_express_generate" + if not is_stream + else "vertex_express_stream" + ), + "preference": "vertex_express_format", + "is_stream": is_stream, + "model": model_name, } else: # 标准Gemini端点 if is_stream: - target_url = f'/v1beta/models/{model_name}:streamGenerateContent' + target_url = f"/v1beta/models/{model_name}:streamGenerateContent" else: - target_url = f'/v1beta/models/{model_name}:generateContent' + target_url = f"/v1beta/models/{model_name}:generateContent" fix_info = { - 'rule': 'gemini_generate' if not is_stream else 'gemini_stream', - 'preference': 'gemini_format', - 'is_stream': is_stream, - 'model': model_name + "rule": "gemini_generate" if not is_stream else "gemini_stream", + "preference": "gemini_format", + "is_stream": is_stream, + "model": model_name, } return target_url, fix_info def fix_openai_by_operation(self, path: str, method: str) -> tuple: """根据操作类型修复OpenAI格式""" - if method == 'POST': - if 'chat' in path.lower() or 'completion' in path.lower(): - return '/openai/v1/chat/completions', {'type': 'openai_chat'} - elif 'embedding' in path.lower(): - return '/openai/v1/embeddings', {'type': 'openai_embeddings'} - elif 'image' in path.lower(): - return '/openai/v1/images/generations', {'type': 'openai_images'} - elif method == 'GET': - if 'model' in path.lower(): - return '/openai/v1/models', {'type': 'openai_models'} + if method == "POST": + if "chat" in path.lower() or "completion" in path.lower(): + return "/openai/v1/chat/completions", {"type": "openai_chat"} + elif "embedding" in path.lower(): + return "/openai/v1/embeddings", {"type": "openai_embeddings"} + elif "image" in path.lower(): + return "/openai/v1/images/generations", {"type": "openai_images"} + elif method == "GET": + if "model" in path.lower(): + return "/openai/v1/models", {"type": "openai_models"} return path, None def fix_v1_by_operation(self, path: str, method: str) -> tuple: """根据操作类型修复v1格式""" - if method == 'POST': - if 'chat' in path.lower() or 'completion' in path.lower(): - return '/v1/chat/completions', {'type': 'v1_chat'} - elif 'embedding' in path.lower(): - return '/v1/embeddings', {'type': 'v1_embeddings'} - elif 'image' in path.lower(): - return '/v1/images/generations', {'type': 'v1_images'} - elif method == 'GET': - if 'model' in path.lower(): - return '/v1/models', {'type': 'v1_models'} + if method == "POST": + if "chat" in path.lower() or "completion" in path.lower(): + return "/v1/chat/completions", {"type": "v1_chat"} + elif "embedding" in path.lower(): + return "/v1/embeddings", {"type": "v1_embeddings"} + elif "image" in path.lower(): + return "/v1/images/generations", {"type": "v1_images"} + elif method == "GET": + if "model" in path.lower(): + return "/v1/models", {"type": "v1_models"} return path, None def detect_stream_request(self, path: str, request: Request) -> bool: """检测是否为流式请求""" # 1. 路径中包含stream关键词 - if 'stream' in path.lower(): + if "stream" in path.lower(): return True # 2. 查询参数 - if request.query_params.get('stream') == 'true': + if request.query_params.get("stream") == "true": return True return False @@ -171,23 +183,24 @@ class SmartRoutingMiddleware(BaseHTTPMiddleware): """从请求中提取模型名称,用于构建Gemini API URL""" # 1. 从请求体中提取 try: - if hasattr(request, '_body') and request._body: + if hasattr(request, "_body") and request._body: import json + body = json.loads(request._body.decode()) - if 'model' in body and body['model']: - return body['model'] - except: + if "model" in body and body["model"]: + return body["model"] + except Exception: pass # 2. 从查询参数中提取 - model_param = request.query_params.get('model') + model_param = request.query_params.get("model") if model_param: return model_param # 3. 从路径中提取(用于已包含模型名称的路径) - match = re.search(r'/models/([^/:]+)', path, re.IGNORECASE) + match = re.search(r"/models/([^/:]+)", path, re.IGNORECASE) if match: return match.group(1) # 4. 如果无法提取模型名称,抛出异常 - raise ValueError("Unable to extract model name from request") \ No newline at end of file + raise ValueError("Unable to extract model name from request")