Refactor: Use TTS service only for TTS models, keep original service for others

- Remove ENABLE_TTS environment variable dependency
- Detect TTS models dynamically by model name
- Use TTS-enhanced service only when needed
- Fallback to standard service if TTS processing fails
- Maintain full backward compatibility
This commit is contained in:
zzh
2025-07-15 15:34:55 +09:00
parent c47f696691
commit 2b48c853fe

View File

@@ -35,12 +35,7 @@ async def get_next_working_key(key_manager: KeyManager = Depends(get_key_manager
async def get_chat_service(key_manager: KeyManager = Depends(get_key_manager)):
"""获取Gemini聊天服务实例"""
# 检查是否启用TTS功能
import os
if os.getenv("ENABLE_TTS", "false").lower() in ("true", "1", "yes", "on"):
return await get_tts_chat_service(key_manager)
else:
return GeminiChatService(settings.BASE_URL, key_manager)
return GeminiChatService(settings.BASE_URL, key_manager)
@router.get("/models")
@@ -118,27 +113,36 @@ async def generate_content(
logger.info(f"Handling Gemini content generation request for model: {model_name}")
logger.debug(f"Request: \n{request.model_dump_json(indent=2)}")
# 对于TTS模型我们需要从原始请求体中提取TTS字段
if "tts" in model_name.lower():
try:
raw_body = await raw_request.body()
raw_data = json.loads(raw_body.decode('utf-8'))
logger.info(f"Raw request data for TTS: {json.dumps(raw_data, indent=2, ensure_ascii=False)}")
# 将TTS字段添加到请求对象中
if hasattr(request, '_raw_tts_data'):
request._raw_tts_data = raw_data
else:
# 动态添加属性
setattr(request, '_raw_tts_data', raw_data)
except Exception as e:
logger.warning(f"Failed to parse raw request for TTS: {e}")
logger.info(f"Using API key: {api_key}")
if not await model_service.check_model_support(model_name):
raise HTTPException(status_code=400, detail=f"Model {model_name} is not supported")
# 检测是否为TTS模型如果是则使用TTS服务
if "tts" in model_name.lower():
logger.info("Detected TTS model, using TTS-enhanced service")
try:
# 从原始请求体中提取TTS字段
raw_body = await raw_request.body()
raw_data = json.loads(raw_body.decode('utf-8'))
logger.info(f"Raw request data for TTS: {json.dumps(raw_data, indent=2, ensure_ascii=False)}")
# 将TTS字段添加到请求对象中
setattr(request, '_raw_tts_data', raw_data)
# 使用TTS增强服务
tts_service = await get_tts_chat_service(key_manager)
response = await tts_service.generate_content(
model=model_name,
request=request,
api_key=api_key
)
return response
except Exception as e:
logger.warning(f"TTS processing failed, falling back to standard service: {e}")
# 如果TTS处理失败回退到标准服务
# 使用标准服务处理非TTS模型或TTS失败的情况
response = await chat_service.generate_content(
model=model_name,
request=request,