mirror of
https://github.com/snailyp/gemini-balance.git
synced 2026-05-31 21:29:44 +08:00
Implement smart multi-speaker TTS detection
- Only activate multi-speaker TTS when multiSpeakerVoiceConfig is present - Preserve original TTS functionality for single-speaker requests - Support dynamic model selection from user request - Add fallback mechanism to standard service if multi-speaker TTS fails - Maintain full backward compatibility with existing TTS systems
This commit is contained in:
@@ -113,24 +113,34 @@ async def generate_content(
|
||||
logger.info(f"Handling Gemini content generation request for model: {model_name}")
|
||||
logger.debug(f"Request: \n{request.model_dump_json(indent=2)}")
|
||||
|
||||
# 检测是否为多人TTS请求
|
||||
is_multi_speaker_tts = False
|
||||
if "tts" in model_name.lower():
|
||||
try:
|
||||
raw_body = await raw_request.body()
|
||||
raw_data = json.loads(raw_body.decode('utf-8'))
|
||||
|
||||
# 检查是否包含多人语音配置
|
||||
speech_config = raw_data.get("generationConfig", {}).get("speechConfig", {})
|
||||
if "multiSpeakerVoiceConfig" in speech_config:
|
||||
is_multi_speaker_tts = True
|
||||
logger.info("Detected multi-speaker TTS request")
|
||||
logger.info(f"Raw request data for multi-speaker TTS: {json.dumps(raw_data, indent=2, ensure_ascii=False)}")
|
||||
|
||||
# 将TTS字段添加到请求对象中
|
||||
setattr(request, '_raw_tts_data', raw_data)
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to parse request for multi-speaker TTS detection: {e}")
|
||||
|
||||
logger.info(f"Using API key: {api_key}")
|
||||
|
||||
if not await model_service.check_model_support(model_name):
|
||||
raise HTTPException(status_code=400, detail=f"Model {model_name} is not supported")
|
||||
|
||||
# 检测是否为TTS模型,如果是则使用TTS服务
|
||||
if "tts" in model_name.lower():
|
||||
logger.info("Detected TTS model, using TTS-enhanced service")
|
||||
# 只有多人TTS请求才使用增强服务
|
||||
if is_multi_speaker_tts:
|
||||
try:
|
||||
# 从原始请求体中提取TTS字段
|
||||
raw_body = await raw_request.body()
|
||||
raw_data = json.loads(raw_body.decode('utf-8'))
|
||||
logger.info(f"Raw request data for TTS: {json.dumps(raw_data, indent=2, ensure_ascii=False)}")
|
||||
|
||||
# 将TTS字段添加到请求对象中
|
||||
setattr(request, '_raw_tts_data', raw_data)
|
||||
|
||||
# 使用TTS增强服务
|
||||
logger.info("Using multi-speaker TTS enhanced service")
|
||||
tts_service = await get_tts_chat_service(key_manager)
|
||||
response = await tts_service.generate_content(
|
||||
model=model_name,
|
||||
@@ -139,10 +149,9 @@ async def generate_content(
|
||||
)
|
||||
return response
|
||||
except Exception as e:
|
||||
logger.warning(f"TTS processing failed, falling back to standard service: {e}")
|
||||
# 如果TTS处理失败,回退到标准服务
|
||||
logger.warning(f"Multi-speaker TTS processing failed, falling back to standard service: {e}")
|
||||
|
||||
# 使用标准服务处理非TTS模型或TTS失败的情况
|
||||
# 使用标准服务处理所有其他请求(包括单人TTS)
|
||||
response = await chat_service.generate_content(
|
||||
model=model_name,
|
||||
request=request,
|
||||
|
||||
Reference in New Issue
Block a user