mirror of
https://github.com/snailyp/gemini-balance.git
synced 2026-05-11 18:09:55 +08:00
refactor: 简化智能路由中间件,优化混合格式URL处理
- 重构智能路由逻辑,在保证聊天的同时尽量简化 - 只会修改常见错误,其余的透传(方便以后维护或者不用维护) - 常见错误都能正常聊天 - 统一前端样式
This commit is contained in:
@@ -34,33 +34,30 @@ class SmartRoutingMiddleware(BaseHTTPMiddleware):
|
||||
return await call_next(request)
|
||||
|
||||
def fix_request_url(self, path: str, method: str, request: Request) -> tuple:
|
||||
"""修复错误的请求URL - 简化版本"""
|
||||
"""简化的URL修复逻辑"""
|
||||
|
||||
# 首先检查是否已经是正确的格式,如果是则不处理
|
||||
if self.is_already_correct_format(path):
|
||||
return path, None
|
||||
|
||||
# 检测是否为流式请求
|
||||
is_stream_request = self.detect_stream_request(path, request)
|
||||
# 1. 最高优先级:包含generateContent → Gemini格式
|
||||
if 'generatecontent' in path.lower():
|
||||
return self.fix_gemini_by_operation(path, method, request)
|
||||
|
||||
# 1. 优先检测OpenAI格式请求(避免被v1beta误判)
|
||||
if self.is_openai_request(path, request):
|
||||
return self.fix_openai_request(path, method, request)
|
||||
# 2. 第二优先级:包含/openai/ → OpenAI格式
|
||||
if '/openai/' in path.lower():
|
||||
return self.fix_openai_by_operation(path, method)
|
||||
|
||||
# 2. 检测HF格式请求
|
||||
if self.is_hf_request(path, request):
|
||||
return self.fix_hf_request(path, method, request)
|
||||
# 3. 第三优先级:包含/v1/ → v1格式
|
||||
if '/v1/' in path.lower():
|
||||
return self.fix_v1_by_operation(path, method)
|
||||
|
||||
# 3. 检测Vertex Express格式请求(优先级高于Gemini)
|
||||
if self.is_vertex_express_request(path, request):
|
||||
return self.fix_vertex_express_request(path, method, request, is_stream_request)
|
||||
# 4. 第四优先级:包含/chat/completions → chat功能
|
||||
if '/chat/completions' in path.lower():
|
||||
return '/v1/chat/completions', {'type': 'v1_chat'}
|
||||
|
||||
# 4. 检测Gemini请求
|
||||
if self.is_gemini_request(path):
|
||||
return self.fix_gemini_request(path, method, request, is_stream_request)
|
||||
|
||||
# 5. 默认处理其他请求(转为最快的v1端点)
|
||||
return self.fix_default_request(path, method, request)
|
||||
# 5. 默认:原样传递
|
||||
return path, None
|
||||
|
||||
def is_already_correct_format(self, path: str) -> bool:
|
||||
"""检查是否已经是正确的API格式"""
|
||||
@@ -73,8 +70,9 @@ class SmartRoutingMiddleware(BaseHTTPMiddleware):
|
||||
r'^/v1/(chat/completions|models|embeddings|images/generations)$', # v1格式
|
||||
r'^/openai/v1/(chat/completions|models|embeddings|images/generations)$', # OpenAI格式
|
||||
r'^/hf/v1/(chat/completions|models|embeddings|images/generations)$', # HF格式
|
||||
r'^/vertex-express/v1beta/models/[^/:]+:(generate|streamGenerate)Content$', # Vertex Express
|
||||
r'^/vertex-express/v1beta/models/[^/:]+:(generate|streamGenerate)Content$', # Vertex Express Gemini格式
|
||||
r'^/vertex-express/v1beta/models$', # Vertex Express模型列表
|
||||
r'^/vertex-express/v1/(chat/completions|models|embeddings|images/generations)$', # Vertex Express OpenAI格式
|
||||
]
|
||||
|
||||
for pattern in correct_patterns:
|
||||
@@ -83,20 +81,52 @@ class SmartRoutingMiddleware(BaseHTTPMiddleware):
|
||||
|
||||
return False
|
||||
|
||||
def is_openai_request(self, path: str, request: Request) -> bool:
|
||||
"""检测OpenAI格式请求"""
|
||||
return '/openai/' in path.lower()
|
||||
def fix_gemini_by_operation(self, path: str, method: str, request: Request) -> tuple:
|
||||
"""根据Gemini操作修复,考虑端点偏好"""
|
||||
if method != 'POST':
|
||||
return path, None
|
||||
|
||||
def is_hf_request(self, path: str, request: Request) -> bool:
|
||||
"""检测HF格式请求"""
|
||||
return '/hf/' in path.lower()
|
||||
# 提取模型名称
|
||||
try:
|
||||
model_name = self.extract_model_name(path, request)
|
||||
except ValueError:
|
||||
# 无法提取模型名称,返回原路径不做处理
|
||||
return path, None
|
||||
|
||||
def is_vertex_express_request(self, path: str, request: Request) -> bool:
|
||||
"""检测Vertex Express格式请求"""
|
||||
return '/vertex-express/' in path.lower()
|
||||
# 检测是否为流式请求
|
||||
is_stream = self.detect_stream_request(path, request)
|
||||
|
||||
def fix_openai_request(self, path: str, method: str, request: Request) -> tuple:
|
||||
"""修复OpenAI格式请求"""
|
||||
# 检查是否有vertex-express偏好
|
||||
if '/vertex-express/' in path.lower():
|
||||
if is_stream:
|
||||
target_url = f'/vertex-express/v1beta/models/{model_name}:streamGenerateContent'
|
||||
else:
|
||||
target_url = f'/vertex-express/v1beta/models/{model_name}:generateContent'
|
||||
|
||||
fix_info = {
|
||||
'rule': 'vertex_express_generate' if not is_stream else 'vertex_express_stream',
|
||||
'preference': 'vertex_express_format',
|
||||
'is_stream': is_stream,
|
||||
'model': model_name
|
||||
}
|
||||
else:
|
||||
# 标准Gemini端点
|
||||
if is_stream:
|
||||
target_url = f'/v1beta/models/{model_name}:streamGenerateContent'
|
||||
else:
|
||||
target_url = f'/v1beta/models/{model_name}:generateContent'
|
||||
|
||||
fix_info = {
|
||||
'rule': 'gemini_generate' if not is_stream else 'gemini_stream',
|
||||
'preference': 'gemini_format',
|
||||
'is_stream': is_stream,
|
||||
'model': model_name
|
||||
}
|
||||
|
||||
return target_url, fix_info
|
||||
|
||||
def fix_openai_by_operation(self, path: str, method: str) -> tuple:
|
||||
"""根据操作类型修复OpenAI格式"""
|
||||
if method == 'POST':
|
||||
if 'chat' in path.lower() or 'completion' in path.lower():
|
||||
return '/openai/v1/chat/completions', {'type': 'openai_chat'}
|
||||
@@ -110,94 +140,21 @@ class SmartRoutingMiddleware(BaseHTTPMiddleware):
|
||||
|
||||
return path, None
|
||||
|
||||
def fix_hf_request(self, path: str, method: str, request: Request) -> tuple:
|
||||
"""修复HF格式请求"""
|
||||
def fix_v1_by_operation(self, path: str, method: str) -> tuple:
|
||||
"""根据操作类型修复v1格式"""
|
||||
if method == 'POST':
|
||||
if 'chat' in path.lower() or 'completion' in path.lower():
|
||||
return '/hf/v1/chat/completions', {'type': 'hf_chat'}
|
||||
return '/v1/chat/completions', {'type': 'v1_chat'}
|
||||
elif 'embedding' in path.lower():
|
||||
return '/hf/v1/embeddings', {'type': 'hf_embeddings'}
|
||||
return '/v1/embeddings', {'type': 'v1_embeddings'}
|
||||
elif 'image' in path.lower():
|
||||
return '/hf/v1/images/generations', {'type': 'hf_images'}
|
||||
return '/v1/images/generations', {'type': 'v1_images'}
|
||||
elif method == 'GET':
|
||||
if 'model' in path.lower():
|
||||
return '/hf/v1/models', {'type': 'hf_models'}
|
||||
return '/v1/models', {'type': 'v1_models'}
|
||||
|
||||
return path, None
|
||||
|
||||
def fix_vertex_express_request(self, path: str, method: str, request: Request, is_stream: bool) -> tuple:
|
||||
"""修复Vertex Express请求"""
|
||||
if method != 'POST':
|
||||
if method == 'GET' and 'models' in path.lower():
|
||||
return '/vertex-express/v1beta/models', {'rule': 'vertex_express_models', 'preference': 'vertex_express_format'}
|
||||
return path, None
|
||||
|
||||
# 提取模型名称
|
||||
try:
|
||||
model_name = self.extract_model_name(path, request)
|
||||
except ValueError:
|
||||
# 无法提取模型名称,返回原路径不做处理
|
||||
return path, None
|
||||
|
||||
# 构建目标URL
|
||||
if is_stream:
|
||||
target_url = f'/vertex-express/v1beta/models/{model_name}:streamGenerateContent'
|
||||
else:
|
||||
target_url = f'/vertex-express/v1beta/models/{model_name}:generateContent'
|
||||
|
||||
fix_info = {
|
||||
'rule': 'vertex_express_generate' if not is_stream else 'vertex_express_stream',
|
||||
'preference': 'vertex_express_format',
|
||||
'is_stream': is_stream,
|
||||
'model': model_name
|
||||
}
|
||||
|
||||
return target_url, fix_info
|
||||
|
||||
def fix_default_request(self, path: str, method: str, request: Request) -> tuple:
|
||||
"""修复默认请求(转为最快的v1端点)"""
|
||||
if method == 'POST':
|
||||
if 'chat' in path.lower() or 'completion' in path.lower():
|
||||
return '/v1/chat/completions', {'type': 'default_chat'}
|
||||
elif 'embedding' in path.lower():
|
||||
return '/v1/embeddings', {'type': 'default_embeddings'}
|
||||
elif 'image' in path.lower():
|
||||
return '/v1/images/generations', {'type': 'default_images'}
|
||||
elif method == 'GET':
|
||||
if 'model' in path.lower():
|
||||
return '/v1/models', {'type': 'default_models'}
|
||||
|
||||
return path, None
|
||||
|
||||
def fix_gemini_request(self, path: str, method: str, request: Request, is_stream: bool) -> tuple:
|
||||
"""修复Gemini请求"""
|
||||
if method != 'POST':
|
||||
if method == 'GET' and 'models' in path.lower():
|
||||
return '/v1beta/models', {'rule': 'gemini_models', 'preference': 'gemini_format'}
|
||||
return path, None
|
||||
|
||||
# 提取模型名称
|
||||
try:
|
||||
model_name = self.extract_model_name(path, request)
|
||||
except ValueError:
|
||||
# 无法提取模型名称,返回原路径不做处理
|
||||
return path, None
|
||||
|
||||
# 构建目标URL
|
||||
if is_stream:
|
||||
target_url = f'/v1beta/models/{model_name}:streamGenerateContent'
|
||||
else:
|
||||
target_url = f'/v1beta/models/{model_name}:generateContent'
|
||||
|
||||
fix_info = {
|
||||
'rule': 'gemini_generate' if not is_stream else 'gemini_stream',
|
||||
'preference': 'gemini_format',
|
||||
'is_stream': is_stream,
|
||||
'model': model_name
|
||||
}
|
||||
|
||||
return target_url, fix_info
|
||||
|
||||
def detect_stream_request(self, path: str, request: Request) -> bool:
|
||||
"""检测是否为流式请求"""
|
||||
# 1. 路径中包含stream关键词
|
||||
@@ -210,31 +167,6 @@ class SmartRoutingMiddleware(BaseHTTPMiddleware):
|
||||
|
||||
return False
|
||||
|
||||
|
||||
def is_gemini_request(self, path: str) -> bool:
|
||||
"""判断是否为Gemini API请求"""
|
||||
path_lower = path.lower()
|
||||
|
||||
# 如果已经是OpenAI、HF或Vertex Express格式,不应该被识别为Gemini
|
||||
if '/openai/' in path_lower or '/hf/' in path_lower or '/vertex-express/' in path_lower:
|
||||
return False
|
||||
|
||||
# 1. 检查是否是明确的Gemini路径模式
|
||||
gemini_path_patterns = [
|
||||
r'/v1beta/models/', # Gemini原生API路径
|
||||
r'/gemini/v1beta/', # 带gemini前缀的路径
|
||||
]
|
||||
|
||||
for pattern in gemini_path_patterns:
|
||||
if re.search(pattern, path_lower):
|
||||
return True
|
||||
|
||||
# 2. 检查是否包含Gemini模型名称
|
||||
if 'gemini' in path_lower and ('models' in path_lower or 'generatecontent' in path_lower):
|
||||
return True
|
||||
|
||||
return False
|
||||
|
||||
def extract_model_name(self, path: str, request: Request) -> str:
|
||||
"""从请求中提取模型名称,用于构建Gemini API URL"""
|
||||
# 1. 从请求体中提取
|
||||
@@ -258,6 +190,4 @@ class SmartRoutingMiddleware(BaseHTTPMiddleware):
|
||||
return match.group(1)
|
||||
|
||||
# 4. 如果无法提取模型名称,抛出异常
|
||||
raise ValueError("Unable to extract model name from request")
|
||||
|
||||
|
||||
raise ValueError("Unable to extract model name from request")
|
||||
@@ -937,13 +937,29 @@ endblock %} {% block head_extra_styles %}
|
||||
</div>
|
||||
<!-- 智能路由配置 -->
|
||||
<div class="mb-6">
|
||||
<label class="flex items-center">
|
||||
<input type="checkbox" id="URL_NORMALIZATION_ENABLED" name="URL_NORMALIZATION_ENABLED"
|
||||
class="mr-2 rounded" />
|
||||
<span class="font-semibold text-gray-700">启用智能路由映射</span>
|
||||
</label>
|
||||
<div class="flex items-center justify-between">
|
||||
<label
|
||||
for="URL_NORMALIZATION_ENABLED"
|
||||
class="font-semibold text-gray-700"
|
||||
>启用智能路由映射</label
|
||||
>
|
||||
<div
|
||||
class="relative inline-block w-10 mr-2 align-middle select-none transition duration-200 ease-in"
|
||||
>
|
||||
<input
|
||||
type="checkbox"
|
||||
name="URL_NORMALIZATION_ENABLED"
|
||||
id="URL_NORMALIZATION_ENABLED"
|
||||
class="toggle-checkbox absolute block w-6 h-6 rounded-full bg-white border-4 appearance-none cursor-pointer"
|
||||
/>
|
||||
<label
|
||||
for="URL_NORMALIZATION_ENABLED"
|
||||
class="toggle-label block overflow-hidden h-6 rounded-full bg-gray-300 cursor-pointer"
|
||||
></label>
|
||||
</div>
|
||||
</div>
|
||||
<small class="text-gray-500 mt-1 block">
|
||||
自动将客户端的各种URL格式映射到正确的API端点
|
||||
自动客户端请求的url拼接为正确格式(仅保证正常聊天,出现问题请关闭)
|
||||
</small>
|
||||
</div>
|
||||
<!-- 最大失败次数 -->
|
||||
|
||||
Reference in New Issue
Block a user