Implement smart multi-speaker TTS detection

- Only activate multi-speaker TTS when multiSpeakerVoiceConfig is present
- Preserve original TTS functionality for single-speaker requests
- Support dynamic model selection from user request
- Add fallback mechanism to standard service if multi-speaker TTS fails
- Maintain full backward compatibility with existing TTS systems
This commit is contained in:
zzh
2025-07-15 15:43:12 +09:00
parent 2b48c853fe
commit eeec45274b

View File

@@ -113,24 +113,34 @@ async def generate_content(
logger.info(f"Handling Gemini content generation request for model: {model_name}")
logger.debug(f"Request: \n{request.model_dump_json(indent=2)}")
# 检测是否为多人TTS请求
is_multi_speaker_tts = False
if "tts" in model_name.lower():
try:
raw_body = await raw_request.body()
raw_data = json.loads(raw_body.decode('utf-8'))
# 检查是否包含多人语音配置
speech_config = raw_data.get("generationConfig", {}).get("speechConfig", {})
if "multiSpeakerVoiceConfig" in speech_config:
is_multi_speaker_tts = True
logger.info("Detected multi-speaker TTS request")
logger.info(f"Raw request data for multi-speaker TTS: {json.dumps(raw_data, indent=2, ensure_ascii=False)}")
# 将TTS字段添加到请求对象中
setattr(request, '_raw_tts_data', raw_data)
except Exception as e:
logger.warning(f"Failed to parse request for multi-speaker TTS detection: {e}")
logger.info(f"Using API key: {api_key}")
if not await model_service.check_model_support(model_name):
raise HTTPException(status_code=400, detail=f"Model {model_name} is not supported")
# 检测是否为TTS模型如果是则使用TTS服务
if "tts" in model_name.lower():
logger.info("Detected TTS model, using TTS-enhanced service")
# 只有多人TTS请求才使用增强服务
if is_multi_speaker_tts:
try:
# 从原始请求体中提取TTS字段
raw_body = await raw_request.body()
raw_data = json.loads(raw_body.decode('utf-8'))
logger.info(f"Raw request data for TTS: {json.dumps(raw_data, indent=2, ensure_ascii=False)}")
# 将TTS字段添加到请求对象中
setattr(request, '_raw_tts_data', raw_data)
# 使用TTS增强服务
logger.info("Using multi-speaker TTS enhanced service")
tts_service = await get_tts_chat_service(key_manager)
response = await tts_service.generate_content(
model=model_name,
@@ -139,10 +149,9 @@ async def generate_content(
)
return response
except Exception as e:
logger.warning(f"TTS processing failed, falling back to standard service: {e}")
# 如果TTS处理失败回退到标准服务
logger.warning(f"Multi-speaker TTS processing failed, falling back to standard service: {e}")
# 使用标准服务处理非TTS模型或TTS失败的情况
# 使用标准服务处理所有其他请求包括单人TTS
response = await chat_service.generate_content(
model=model_name,
request=request,