diff --git a/app/router/gemini_routes.py b/app/router/gemini_routes.py index e349d04..0e91955 100644 --- a/app/router/gemini_routes.py +++ b/app/router/gemini_routes.py @@ -113,24 +113,34 @@ async def generate_content( logger.info(f"Handling Gemini content generation request for model: {model_name}") logger.debug(f"Request: \n{request.model_dump_json(indent=2)}") + # 检测是否为多人TTS请求 + is_multi_speaker_tts = False + if "tts" in model_name.lower(): + try: + raw_body = await raw_request.body() + raw_data = json.loads(raw_body.decode('utf-8')) + + # 检查是否包含多人语音配置 + speech_config = raw_data.get("generationConfig", {}).get("speechConfig", {}) + if "multiSpeakerVoiceConfig" in speech_config: + is_multi_speaker_tts = True + logger.info("Detected multi-speaker TTS request") + logger.info(f"Raw request data for multi-speaker TTS: {json.dumps(raw_data, indent=2, ensure_ascii=False)}") + + # 将TTS字段添加到请求对象中 + setattr(request, '_raw_tts_data', raw_data) + except Exception as e: + logger.warning(f"Failed to parse request for multi-speaker TTS detection: {e}") + logger.info(f"Using API key: {api_key}") if not await model_service.check_model_support(model_name): raise HTTPException(status_code=400, detail=f"Model {model_name} is not supported") - # 检测是否为TTS模型,如果是则使用TTS服务 - if "tts" in model_name.lower(): - logger.info("Detected TTS model, using TTS-enhanced service") + # 只有多人TTS请求才使用增强服务 + if is_multi_speaker_tts: try: - # 从原始请求体中提取TTS字段 - raw_body = await raw_request.body() - raw_data = json.loads(raw_body.decode('utf-8')) - logger.info(f"Raw request data for TTS: {json.dumps(raw_data, indent=2, ensure_ascii=False)}") - - # 将TTS字段添加到请求对象中 - setattr(request, '_raw_tts_data', raw_data) - - # 使用TTS增强服务 + logger.info("Using multi-speaker TTS enhanced service") tts_service = await get_tts_chat_service(key_manager) response = await tts_service.generate_content( model=model_name, @@ -139,10 +149,9 @@ async def generate_content( ) return response except Exception as e: - logger.warning(f"TTS processing failed, falling back to standard service: {e}") - # 如果TTS处理失败,回退到标准服务 + logger.warning(f"Multi-speaker TTS processing failed, falling back to standard service: {e}") - # 使用标准服务处理非TTS模型或TTS失败的情况 + # 使用标准服务处理所有其他请求(包括单人TTS) response = await chat_service.generate_content( model=model_name, request=request,