mirror of
https://github.com/snailyp/gemini-balance.git
synced 2026-05-11 10:00:37 +08:00
Refactor: Use TTS service only for TTS models, keep original service for others
- Remove ENABLE_TTS environment variable dependency - Detect TTS models dynamically by model name - Use TTS-enhanced service only when needed - Fallback to standard service if TTS processing fails - Maintain full backward compatibility
This commit is contained in:
@@ -35,12 +35,7 @@ async def get_next_working_key(key_manager: KeyManager = Depends(get_key_manager
|
||||
|
||||
async def get_chat_service(key_manager: KeyManager = Depends(get_key_manager)):
|
||||
"""获取Gemini聊天服务实例"""
|
||||
# 检查是否启用TTS功能
|
||||
import os
|
||||
if os.getenv("ENABLE_TTS", "false").lower() in ("true", "1", "yes", "on"):
|
||||
return await get_tts_chat_service(key_manager)
|
||||
else:
|
||||
return GeminiChatService(settings.BASE_URL, key_manager)
|
||||
return GeminiChatService(settings.BASE_URL, key_manager)
|
||||
|
||||
|
||||
@router.get("/models")
|
||||
@@ -118,27 +113,36 @@ async def generate_content(
|
||||
logger.info(f"Handling Gemini content generation request for model: {model_name}")
|
||||
logger.debug(f"Request: \n{request.model_dump_json(indent=2)}")
|
||||
|
||||
# 对于TTS模型,我们需要从原始请求体中提取TTS字段
|
||||
if "tts" in model_name.lower():
|
||||
try:
|
||||
raw_body = await raw_request.body()
|
||||
raw_data = json.loads(raw_body.decode('utf-8'))
|
||||
logger.info(f"Raw request data for TTS: {json.dumps(raw_data, indent=2, ensure_ascii=False)}")
|
||||
|
||||
# 将TTS字段添加到请求对象中
|
||||
if hasattr(request, '_raw_tts_data'):
|
||||
request._raw_tts_data = raw_data
|
||||
else:
|
||||
# 动态添加属性
|
||||
setattr(request, '_raw_tts_data', raw_data)
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to parse raw request for TTS: {e}")
|
||||
|
||||
logger.info(f"Using API key: {api_key}")
|
||||
|
||||
if not await model_service.check_model_support(model_name):
|
||||
raise HTTPException(status_code=400, detail=f"Model {model_name} is not supported")
|
||||
|
||||
# 检测是否为TTS模型,如果是则使用TTS服务
|
||||
if "tts" in model_name.lower():
|
||||
logger.info("Detected TTS model, using TTS-enhanced service")
|
||||
try:
|
||||
# 从原始请求体中提取TTS字段
|
||||
raw_body = await raw_request.body()
|
||||
raw_data = json.loads(raw_body.decode('utf-8'))
|
||||
logger.info(f"Raw request data for TTS: {json.dumps(raw_data, indent=2, ensure_ascii=False)}")
|
||||
|
||||
# 将TTS字段添加到请求对象中
|
||||
setattr(request, '_raw_tts_data', raw_data)
|
||||
|
||||
# 使用TTS增强服务
|
||||
tts_service = await get_tts_chat_service(key_manager)
|
||||
response = await tts_service.generate_content(
|
||||
model=model_name,
|
||||
request=request,
|
||||
api_key=api_key
|
||||
)
|
||||
return response
|
||||
except Exception as e:
|
||||
logger.warning(f"TTS processing failed, falling back to standard service: {e}")
|
||||
# 如果TTS处理失败,回退到标准服务
|
||||
|
||||
# 使用标准服务处理非TTS模型或TTS失败的情况
|
||||
response = await chat_service.generate_content(
|
||||
model=model_name,
|
||||
request=request,
|
||||
|
||||
Reference in New Issue
Block a user