mirror of
https://github.com/snailyp/gemini-balance.git
synced 2026-05-11 18:09:55 +08:00
删除冗余代码
This commit is contained in:
@@ -1,8 +1,7 @@
|
||||
from fastapi import APIRouter, Depends, HTTPException, Request
|
||||
from fastapi import APIRouter, Depends, HTTPException
|
||||
from fastapi.responses import StreamingResponse, JSONResponse
|
||||
from copy import deepcopy
|
||||
import asyncio
|
||||
import json
|
||||
from app.config.config import settings
|
||||
from app.log.logger import get_gemini_logger
|
||||
from app.core.security import SecurityService
|
||||
@@ -101,7 +100,6 @@ async def list_models(
|
||||
async def generate_content(
|
||||
model_name: str,
|
||||
request: GeminiRequest,
|
||||
raw_request: Request,
|
||||
_=Depends(security_service.verify_key_or_goog_api_key),
|
||||
api_key: str = Depends(get_next_working_key),
|
||||
key_manager: KeyManager = Depends(get_key_manager),
|
||||
@@ -115,26 +113,17 @@ async def generate_content(
|
||||
|
||||
# 检测是否为原生Gemini TTS请求
|
||||
is_native_tts = False
|
||||
if "tts" in model_name.lower():
|
||||
try:
|
||||
raw_body = await raw_request.body()
|
||||
raw_data = json.loads(raw_body.decode('utf-8'))
|
||||
if "tts" in model_name.lower() and request.generationConfig:
|
||||
# 直接从解析后的request对象获取TTS配置
|
||||
response_modalities = request.generationConfig.responseModalities or []
|
||||
speech_config = request.generationConfig.speechConfig or {}
|
||||
|
||||
# 检查是否包含原生TTS配置(responseModalities和speechConfig)
|
||||
generation_config = raw_data.get("generationConfig", {})
|
||||
response_modalities = generation_config.get("responseModalities", [])
|
||||
speech_config = generation_config.get("speechConfig", {})
|
||||
|
||||
# 如果包含AUDIO模态和语音配置,则认为是原生TTS请求
|
||||
if "AUDIO" in response_modalities and speech_config:
|
||||
is_native_tts = True
|
||||
logger.info("Detected native Gemini TTS request")
|
||||
logger.info(f"Raw request data for native TTS: {json.dumps(raw_data, indent=2, ensure_ascii=False)}")
|
||||
|
||||
# 将TTS字段添加到请求对象中
|
||||
setattr(request, '_raw_tts_data', raw_data)
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to parse request for native TTS detection: {e}")
|
||||
# 如果包含AUDIO模态和语音配置,则认为是原生TTS请求
|
||||
if "AUDIO" in response_modalities and speech_config:
|
||||
is_native_tts = True
|
||||
logger.info("Detected native Gemini TTS request")
|
||||
logger.info(f"TTS responseModalities: {response_modalities}")
|
||||
logger.info(f"TTS speechConfig: {speech_config}")
|
||||
|
||||
logger.info(f"Using API key: {api_key}")
|
||||
|
||||
|
||||
@@ -53,11 +53,10 @@ python -m uvicorn app.main:app --host 0.0.0.0 --port 8000 --reload
|
||||
|
||||
```python
|
||||
# app/router/gemini_routes.py 中的智能检测逻辑
|
||||
if "tts" in model_name.lower():
|
||||
# 检查是否包含原生TTS配置
|
||||
generation_config = raw_data.get("generationConfig", {})
|
||||
response_modalities = generation_config.get("responseModalities", [])
|
||||
speech_config = generation_config.get("speechConfig", {})
|
||||
if "tts" in model_name.lower() and request.generationConfig:
|
||||
# 直接从解析后的request对象获取TTS配置
|
||||
response_modalities = request.generationConfig.responseModalities or []
|
||||
speech_config = request.generationConfig.speechConfig or {}
|
||||
|
||||
# 如果包含AUDIO模态和语音配置,则认为是原生TTS请求
|
||||
if "AUDIO" in response_modalities and speech_config:
|
||||
@@ -210,14 +209,14 @@ TTSGenerationConfig
|
||||
1. **请求接收**:系统接收到API请求
|
||||
2. **智能检测**:
|
||||
- 检查模型名称是否包含 "tts"
|
||||
- 如果是TTS模型,解析请求体检查是否包含 `responseModalities: ["AUDIO"]` 和 `speechConfig`
|
||||
- 如果是TTS模型,从 `request.generationConfig` 检查是否包含 `responseModalities: ["AUDIO"]` 和 `speechConfig`
|
||||
3. **服务选择**:
|
||||
- **原生TTS请求**:使用 `TTSGeminiChatService` 增强服务
|
||||
- **普通请求**:使用原有 `GeminiChatService`
|
||||
4. **请求处理**:
|
||||
- **原生TTS**:使用 `_handle_tts_request()` 特殊处理
|
||||
- **其他请求**:使用标准 `generate_content()` 方法
|
||||
5. **字段处理**:从原始HTTP请求体提取TTS字段(`responseModalities`, `speechConfig`)
|
||||
5. **字段处理**:从 `request.generationConfig` 直接获取TTS字段(`responseModalities`, `speechConfig`)
|
||||
6. **API调用**:构建优化的payload并调用Gemini API
|
||||
7. **自动回退**:如果原生TTS处理失败,自动回退到标准服务
|
||||
8. **响应处理**:
|
||||
|
||||
@@ -85,21 +85,18 @@ class TTSGeminiChatService(GeminiChatService):
|
||||
if payload["generationConfig"] is None:
|
||||
payload["generationConfig"] = {}
|
||||
|
||||
# 从原始请求中提取TTS相关字段
|
||||
if hasattr(request, '_raw_tts_data'):
|
||||
raw_data = getattr(request, '_raw_tts_data')
|
||||
raw_generation_config = raw_data.get("generationConfig", {})
|
||||
|
||||
# 从request.generationConfig直接获取TTS相关字段
|
||||
if request.generationConfig:
|
||||
# 添加TTS特定字段
|
||||
if "responseModalities" in raw_generation_config:
|
||||
payload["generationConfig"]["responseModalities"] = raw_generation_config["responseModalities"]
|
||||
logger.info(f"Added responseModalities: {raw_generation_config['responseModalities']}")
|
||||
if request.generationConfig.responseModalities:
|
||||
payload["generationConfig"]["responseModalities"] = request.generationConfig.responseModalities
|
||||
logger.info(f"Added responseModalities: {request.generationConfig.responseModalities}")
|
||||
|
||||
if "speechConfig" in raw_generation_config:
|
||||
payload["generationConfig"]["speechConfig"] = raw_generation_config["speechConfig"]
|
||||
logger.info(f"Added speechConfig: {raw_generation_config['speechConfig']}")
|
||||
if request.generationConfig.speechConfig:
|
||||
payload["generationConfig"]["speechConfig"] = request.generationConfig.speechConfig
|
||||
logger.info(f"Added speechConfig: {request.generationConfig.speechConfig}")
|
||||
else:
|
||||
logger.warning("No raw TTS data found in request, TTS fields may be missing")
|
||||
logger.warning("No generationConfig found in request, TTS fields may be missing")
|
||||
|
||||
logger.info(f"TTS payload before API call: {payload}")
|
||||
|
||||
|
||||
Reference in New Issue
Block a user