From 2b48c853fe880f706829c5b8f705196615bb9c35 Mon Sep 17 00:00:00 2001 From: zzh Date: Tue, 15 Jul 2025 15:34:55 +0900 Subject: [PATCH] Refactor: Use TTS service only for TTS models, keep original service for others - Remove ENABLE_TTS environment variable dependency - Detect TTS models dynamically by model name - Use TTS-enhanced service only when needed - Fallback to standard service if TTS processing fails - Maintain full backward compatibility --- app/router/gemini_routes.py | 48 ++++++++++++++++++++----------------- 1 file changed, 26 insertions(+), 22 deletions(-) diff --git a/app/router/gemini_routes.py b/app/router/gemini_routes.py index 619b54d..e349d04 100644 --- a/app/router/gemini_routes.py +++ b/app/router/gemini_routes.py @@ -35,12 +35,7 @@ async def get_next_working_key(key_manager: KeyManager = Depends(get_key_manager async def get_chat_service(key_manager: KeyManager = Depends(get_key_manager)): """获取Gemini聊天服务实例""" - # 检查是否启用TTS功能 - import os - if os.getenv("ENABLE_TTS", "false").lower() in ("true", "1", "yes", "on"): - return await get_tts_chat_service(key_manager) - else: - return GeminiChatService(settings.BASE_URL, key_manager) + return GeminiChatService(settings.BASE_URL, key_manager) @router.get("/models") @@ -118,27 +113,36 @@ async def generate_content( logger.info(f"Handling Gemini content generation request for model: {model_name}") logger.debug(f"Request: \n{request.model_dump_json(indent=2)}") - # 对于TTS模型,我们需要从原始请求体中提取TTS字段 - if "tts" in model_name.lower(): - try: - raw_body = await raw_request.body() - raw_data = json.loads(raw_body.decode('utf-8')) - logger.info(f"Raw request data for TTS: {json.dumps(raw_data, indent=2, ensure_ascii=False)}") - - # 将TTS字段添加到请求对象中 - if hasattr(request, '_raw_tts_data'): - request._raw_tts_data = raw_data - else: - # 动态添加属性 - setattr(request, '_raw_tts_data', raw_data) - except Exception as e: - logger.warning(f"Failed to parse raw request for TTS: {e}") - logger.info(f"Using API key: {api_key}") if not await model_service.check_model_support(model_name): raise HTTPException(status_code=400, detail=f"Model {model_name} is not supported") + # 检测是否为TTS模型,如果是则使用TTS服务 + if "tts" in model_name.lower(): + logger.info("Detected TTS model, using TTS-enhanced service") + try: + # 从原始请求体中提取TTS字段 + raw_body = await raw_request.body() + raw_data = json.loads(raw_body.decode('utf-8')) + logger.info(f"Raw request data for TTS: {json.dumps(raw_data, indent=2, ensure_ascii=False)}") + + # 将TTS字段添加到请求对象中 + setattr(request, '_raw_tts_data', raw_data) + + # 使用TTS增强服务 + tts_service = await get_tts_chat_service(key_manager) + response = await tts_service.generate_content( + model=model_name, + request=request, + api_key=api_key + ) + return response + except Exception as e: + logger.warning(f"TTS processing failed, falling back to standard service: {e}") + # 如果TTS处理失败,回退到标准服务 + + # 使用标准服务处理非TTS模型或TTS失败的情况 response = await chat_service.generate_content( model=model_name, request=request,