feat: 添加智能路由中间件,支持API路径自动规范化

- 新增SmartRoutingMiddleware智能路由中间件
- 支持OpenAI/HF/Gemini/默认格式的自动检测和转换
- 修复错误URL路径格式,提升API兼容性
- 添加URL_NORMALIZATION_ENABLED配置开关,默认关闭
- 智能路由功能默认关闭,需手动启用
This commit is contained in:
chchchchc1023
2025-06-30 22:58:58 +08:00
parent 00f423a622
commit 18a166afb0
5 changed files with 2509 additions and 2260 deletions

View File

@@ -14,6 +14,7 @@ ENV BASE_URL=https://generativelanguage.googleapis.com/v1beta
ENV TOOLS_CODE_EXECUTION_ENABLED=false
ENV IMAGE_MODELS='["gemini-2.0-flash-exp"]'
ENV SEARCH_MODELS='["gemini-2.0-flash-exp","gemini-2.0-pro-exp"]'
ENV URL_NORMALIZATION_ENABLED=false
# Expose port
EXPOSE 8000

View File

@@ -63,7 +63,10 @@ class Settings(BaseSettings):
PROXIES_USE_CONSISTENCY_HASH_BY_API_KEY: bool = True # 是否使用一致性哈希来选择代理
VERTEX_API_KEYS: List[str] = []
VERTEX_EXPRESS_BASE_URL: str = "https://aiplatform.googleapis.com/v1beta1/publishers/google"
# 智能路由配置
URL_NORMALIZATION_ENABLED: bool = False # 是否启用智能路由映射功能
# 模型相关配置
SEARCH_MODELS: List[str] = ["gemini-2.0-flash-exp"]
IMAGE_MODELS: List[str] = ["gemini-2.0-flash-exp"]
@@ -111,6 +114,9 @@ class Settings(BaseSettings):
AUTO_DELETE_REQUEST_LOGS_DAYS: int = 30
SAFETY_SETTINGS: List[Dict[str, str]] = DEFAULT_SAFETY_SETTINGS
#是否开启新手模式
URL_NORMALIZATION_ENABLED: bool = False
def __init__(self, **kwargs):
super().__init__(**kwargs)
# 设置默认AUTH_TOKEN如果未提供

View File

@@ -8,6 +8,7 @@ from fastapi.responses import RedirectResponse
from starlette.middleware.base import BaseHTTPMiddleware
# from app.middleware.request_logging_middleware import RequestLoggingMiddleware
from app.middleware.smart_routing_middleware import SmartRoutingMiddleware
from app.core.constants import API_VERSION
from app.core.security import verify_auth_token
from app.log.logger import get_middleware_logger
@@ -52,6 +53,9 @@ def setup_middlewares(app: FastAPI) -> None:
Args:
app: FastAPI应用程序实例
"""
# 添加智能路由中间件(必须在认证中间件之前)
app.add_middleware(SmartRoutingMiddleware)
# 添加认证中间件
app.add_middleware(AuthMiddleware)

View File

@@ -0,0 +1,227 @@
from fastapi import Request
from starlette.middleware.base import BaseHTTPMiddleware
from app.config.config import settings
from app.log.logger import get_main_logger
import re
logger = get_main_logger()
class SmartRoutingMiddleware(BaseHTTPMiddleware):
def __init__(self, app):
super().__init__(app)
# 简化的路由规则 - 直接根据检测结果路由
pass
async def dispatch(self, request: Request, call_next):
if not settings.URL_NORMALIZATION_ENABLED:
return await call_next(request)
original_path = str(request.url.path)
method = request.method
# 尝试修复URL
fixed_path, fix_info = self.fix_request_url(original_path, method, request)
if fixed_path != original_path:
logger.info(f"URL fixed: {method} {original_path}{fixed_path}")
if fix_info:
logger.debug(f"Fix details: {fix_info}")
# 重写请求路径
request.scope["path"] = fixed_path
request.scope["raw_path"] = fixed_path.encode()
return await call_next(request)
def fix_request_url(self, path: str, method: str, request: Request) -> tuple:
"""修复错误的请求URL - 简化版本"""
# 首先检查是否已经是正确的格式,如果是则不处理
if self.is_already_correct_format(path):
return path, None
# 检测是否为流式请求
is_stream_request = self.detect_stream_request(path, request)
# 1. 优先检测OpenAI格式请求避免被v1beta误判
if self.is_openai_request(path, request):
return self.fix_openai_request(path, method, request)
# 2. 检测HF格式请求
if self.is_hf_request(path, request):
return self.fix_hf_request(path, method, request)
# 3. 检测Gemini请求
if self.is_gemini_request(path):
return self.fix_gemini_request(path, method, request, is_stream_request)
# 4. 默认处理其他请求转为最快的v1端点
return self.fix_default_request(path, method, request)
def is_already_correct_format(self, path: str) -> bool:
"""检查是否已经是正确的API格式"""
# 检查是否已经是正确的端点格式
correct_patterns = [
r'^/v1beta/models/[^/:]+:(generate|streamGenerate)Content$', # Gemini原生
r'^/gemini/v1beta/models/[^/:]+:(generate|streamGenerate)Content$', # Gemini带前缀
r'^/v1beta/models$', # Gemini模型列表
r'^/v1/(chat/completions|models|embeddings|images/generations)$', # v1格式
r'^/openai/v1/(chat/completions|models|embeddings|images/generations)$', # OpenAI格式
r'^/hf/v1/(chat/completions|models|embeddings|images/generations)$', # HF格式
]
for pattern in correct_patterns:
if re.match(pattern, path):
return True
return False
def is_openai_request(self, path: str, request: Request) -> bool:
"""检测OpenAI格式请求"""
return '/openai/' in path.lower()
def is_hf_request(self, path: str, request: Request) -> bool:
"""检测HF格式请求"""
return '/hf/' in path.lower()
def fix_openai_request(self, path: str, method: str, request: Request) -> tuple:
"""修复OpenAI格式请求"""
if method == 'POST':
if 'chat' in path.lower() or 'completion' in path.lower():
return '/openai/v1/chat/completions', {'type': 'openai_chat'}
elif 'embedding' in path.lower():
return '/openai/v1/embeddings', {'type': 'openai_embeddings'}
elif 'image' in path.lower():
return '/openai/v1/images/generations', {'type': 'openai_images'}
elif method == 'GET':
if 'model' in path.lower():
return '/openai/v1/models', {'type': 'openai_models'}
return path, None
def fix_hf_request(self, path: str, method: str, request: Request) -> tuple:
"""修复HF格式请求"""
if method == 'POST':
if 'chat' in path.lower() or 'completion' in path.lower():
return '/hf/v1/chat/completions', {'type': 'hf_chat'}
elif 'embedding' in path.lower():
return '/hf/v1/embeddings', {'type': 'hf_embeddings'}
elif 'image' in path.lower():
return '/hf/v1/images/generations', {'type': 'hf_images'}
elif method == 'GET':
if 'model' in path.lower():
return '/hf/v1/models', {'type': 'hf_models'}
return path, None
def fix_default_request(self, path: str, method: str, request: Request) -> tuple:
"""修复默认请求转为最快的v1端点"""
if method == 'POST':
if 'chat' in path.lower() or 'completion' in path.lower():
return '/v1/chat/completions', {'type': 'default_chat'}
elif 'embedding' in path.lower():
return '/v1/embeddings', {'type': 'default_embeddings'}
elif 'image' in path.lower():
return '/v1/images/generations', {'type': 'default_images'}
elif method == 'GET':
if 'model' in path.lower():
return '/v1/models', {'type': 'default_models'}
return path, None
def fix_gemini_request(self, path: str, method: str, request: Request, is_stream: bool) -> tuple:
"""修复Gemini请求"""
if method != 'POST':
if method == 'GET' and 'models' in path.lower():
return '/v1beta/models', {'rule': 'gemini_models', 'preference': 'gemini_format'}
return path, None
# 提取模型名称
try:
model_name = self.extract_model_name(path, request)
except ValueError:
# 无法提取模型名称,返回原路径不做处理
return path, None
# 构建目标URL
if is_stream:
target_url = f'/v1beta/models/{model_name}:streamGenerateContent'
else:
target_url = f'/v1beta/models/{model_name}:generateContent'
fix_info = {
'rule': 'gemini_generate' if not is_stream else 'gemini_stream',
'preference': 'gemini_format',
'is_stream': is_stream,
'model': model_name
}
return target_url, fix_info
def detect_stream_request(self, path: str, request: Request) -> bool:
"""检测是否为流式请求"""
# 1. 路径中包含stream关键词
if 'stream' in path.lower():
return True
# 2. 查询参数
if request.query_params.get('stream') == 'true':
return True
# 3. Accept头部
accept = request.headers.get('accept', '')
if 'text/event-stream' in accept:
return True
return False
def is_gemini_request(self, path: str) -> bool:
"""判断是否为Gemini API请求"""
path_lower = path.lower()
# 如果已经是OpenAI或HF格式不应该被识别为Gemini
if '/openai/' in path_lower or '/hf/' in path_lower:
return False
# 1. 检查是否是明确的Gemini路径模式
gemini_path_patterns = [
r'/v1beta/models/', # Gemini原生API路径
r'/gemini/v1beta/', # 带gemini前缀的路径
]
for pattern in gemini_path_patterns:
if re.search(pattern, path_lower):
return True
# 2. 检查是否包含Gemini模型名称
if 'gemini' in path_lower and ('models' in path_lower or 'generatecontent' in path_lower):
return True
return False
def extract_model_name(self, path: str, request: Request) -> str:
"""从请求中提取模型名称用于构建Gemini API URL"""
# 1. 从请求体中提取
try:
if hasattr(request, '_body') and request._body:
import json
body = json.loads(request._body.decode())
if 'model' in body and body['model']:
return body['model']
except:
pass
# 2. 从查询参数中提取
model_param = request.query_params.get('model')
if model_param:
return model_param
# 3. 从路径中提取(用于已包含模型名称的路径)
match = re.search(r'/models/([^/:]+)', path, re.IGNORECASE)
if match:
return match.group(1)
# 4. 如果无法提取模型名称,抛出异常
raise ValueError("Unable to extract model name from request")

File diff suppressed because it is too large Load Diff