diff --git a/app/router/gemini_routes.py b/app/router/gemini_routes.py index 56b69c0..8769e33 100644 --- a/app/router/gemini_routes.py +++ b/app/router/gemini_routes.py @@ -1,8 +1,7 @@ -from fastapi import APIRouter, Depends, HTTPException, Request +from fastapi import APIRouter, Depends, HTTPException from fastapi.responses import StreamingResponse, JSONResponse from copy import deepcopy import asyncio -import json from app.config.config import settings from app.log.logger import get_gemini_logger from app.core.security import SecurityService @@ -101,7 +100,6 @@ async def list_models( async def generate_content( model_name: str, request: GeminiRequest, - raw_request: Request, _=Depends(security_service.verify_key_or_goog_api_key), api_key: str = Depends(get_next_working_key), key_manager: KeyManager = Depends(get_key_manager), @@ -115,26 +113,17 @@ async def generate_content( # 检测是否为原生Gemini TTS请求 is_native_tts = False - if "tts" in model_name.lower(): - try: - raw_body = await raw_request.body() - raw_data = json.loads(raw_body.decode('utf-8')) + if "tts" in model_name.lower() and request.generationConfig: + # 直接从解析后的request对象获取TTS配置 + response_modalities = request.generationConfig.responseModalities or [] + speech_config = request.generationConfig.speechConfig or {} - # 检查是否包含原生TTS配置(responseModalities和speechConfig) - generation_config = raw_data.get("generationConfig", {}) - response_modalities = generation_config.get("responseModalities", []) - speech_config = generation_config.get("speechConfig", {}) - - # 如果包含AUDIO模态和语音配置,则认为是原生TTS请求 - if "AUDIO" in response_modalities and speech_config: - is_native_tts = True - logger.info("Detected native Gemini TTS request") - logger.info(f"Raw request data for native TTS: {json.dumps(raw_data, indent=2, ensure_ascii=False)}") - - # 将TTS字段添加到请求对象中 - setattr(request, '_raw_tts_data', raw_data) - except Exception as e: - logger.warning(f"Failed to parse request for native TTS detection: {e}") + # 如果包含AUDIO模态和语音配置,则认为是原生TTS请求 + if "AUDIO" in response_modalities and speech_config: + is_native_tts = True + logger.info("Detected native Gemini TTS request") + logger.info(f"TTS responseModalities: {response_modalities}") + logger.info(f"TTS speechConfig: {speech_config}") logger.info(f"Using API key: {api_key}") diff --git a/app/service/tts/native/README.md b/app/service/tts/native/README.md index 3acae81..e228df3 100644 --- a/app/service/tts/native/README.md +++ b/app/service/tts/native/README.md @@ -53,11 +53,10 @@ python -m uvicorn app.main:app --host 0.0.0.0 --port 8000 --reload ```python # app/router/gemini_routes.py 中的智能检测逻辑 -if "tts" in model_name.lower(): - # 检查是否包含原生TTS配置 - generation_config = raw_data.get("generationConfig", {}) - response_modalities = generation_config.get("responseModalities", []) - speech_config = generation_config.get("speechConfig", {}) +if "tts" in model_name.lower() and request.generationConfig: + # 直接从解析后的request对象获取TTS配置 + response_modalities = request.generationConfig.responseModalities or [] + speech_config = request.generationConfig.speechConfig or {} # 如果包含AUDIO模态和语音配置,则认为是原生TTS请求 if "AUDIO" in response_modalities and speech_config: @@ -210,14 +209,14 @@ TTSGenerationConfig 1. **请求接收**:系统接收到API请求 2. **智能检测**: - 检查模型名称是否包含 "tts" - - 如果是TTS模型,解析请求体检查是否包含 `responseModalities: ["AUDIO"]` 和 `speechConfig` + - 如果是TTS模型,从 `request.generationConfig` 检查是否包含 `responseModalities: ["AUDIO"]` 和 `speechConfig` 3. **服务选择**: - **原生TTS请求**:使用 `TTSGeminiChatService` 增强服务 - **普通请求**:使用原有 `GeminiChatService` 4. **请求处理**: - **原生TTS**:使用 `_handle_tts_request()` 特殊处理 - **其他请求**:使用标准 `generate_content()` 方法 -5. **字段处理**:从原始HTTP请求体提取TTS字段(`responseModalities`, `speechConfig`) +5. **字段处理**:从 `request.generationConfig` 直接获取TTS字段(`responseModalities`, `speechConfig`) 6. **API调用**:构建优化的payload并调用Gemini API 7. **自动回退**:如果原生TTS处理失败,自动回退到标准服务 8. **响应处理**: diff --git a/app/service/tts/native/tts_chat_service.py b/app/service/tts/native/tts_chat_service.py index e570257..29d22aa 100644 --- a/app/service/tts/native/tts_chat_service.py +++ b/app/service/tts/native/tts_chat_service.py @@ -85,21 +85,18 @@ class TTSGeminiChatService(GeminiChatService): if payload["generationConfig"] is None: payload["generationConfig"] = {} - # 从原始请求中提取TTS相关字段 - if hasattr(request, '_raw_tts_data'): - raw_data = getattr(request, '_raw_tts_data') - raw_generation_config = raw_data.get("generationConfig", {}) - + # 从request.generationConfig直接获取TTS相关字段 + if request.generationConfig: # 添加TTS特定字段 - if "responseModalities" in raw_generation_config: - payload["generationConfig"]["responseModalities"] = raw_generation_config["responseModalities"] - logger.info(f"Added responseModalities: {raw_generation_config['responseModalities']}") + if request.generationConfig.responseModalities: + payload["generationConfig"]["responseModalities"] = request.generationConfig.responseModalities + logger.info(f"Added responseModalities: {request.generationConfig.responseModalities}") - if "speechConfig" in raw_generation_config: - payload["generationConfig"]["speechConfig"] = raw_generation_config["speechConfig"] - logger.info(f"Added speechConfig: {raw_generation_config['speechConfig']}") + if request.generationConfig.speechConfig: + payload["generationConfig"]["speechConfig"] = request.generationConfig.speechConfig + logger.info(f"Added speechConfig: {request.generationConfig.speechConfig}") else: - logger.warning("No raw TTS data found in request, TTS fields may be missing") + logger.warning("No generationConfig found in request, TTS fields may be missing") logger.info(f"TTS payload before API call: {payload}")