Files
gemini-balance/app/services/gemini_chat_service.py
yinpeng cd45f4b5ab refactor: 重构Gemini和OpenAI聊天服务以支持工具和安全设置
- 将 `_build_payload`、`_build_tools`、`_get_safety_settings` 和 `_has_image_parts` 函数从 `OpenAIChatService` 和 `GeminiChatService` 类中提取为独立的函数。
- 将 `_handle_stream_response` 和 `_handle_normal_response` 函数从 `GeminiResponseHandler` 和 `OpenAIResponseHandler` 类中提取为独立的函数。
- 将 `_extract_text` 函数从 `OpenAIResponseHandler` 类中提取为独立的函数, 并在 `GeminiResponseHandler` 中复用。
- 将 `_convert_image` 函数从 `OpenAIMessageConverter` 类中提取为独立的函数。
- 优化 `OpenAIChatService` 和 `GeminiChatService` 中的代码结构, 使其更清晰。
- 优化 `app/api/openai_routes.py` 和 `app/api/gemini_routes.py` 中的路由函数, 移除不必要的参数。
2025-02-06 21:35:19 +08:00

105 lines
4.4 KiB
Python

# app/services/gemini_chat_service.py
import json
from typing import Dict, Any, AsyncGenerator, List
from app.core.logger import get_gemini_logger
from app.services.chat.api_client import GeminiApiClient
from app.schemas.gemini_models import GeminiRequest
from app.core.config import settings
from app.services.chat.response_handler import GeminiResponseHandler
from app.services.key_manager import KeyManager
logger = get_gemini_logger()
def _has_image_parts(contents: List[Dict[str, Any]]) -> bool:
"""判断消息是否包含图片部分"""
for content in contents:
if "parts" in content:
for part in content["parts"]:
if "image_url" in part or "inline_data" in part:
return True
return False
def _build_tools(model: str, payload: Dict[str, Any]) -> List[Dict[str, Any]]:
    """Assemble the Gemini ``tools`` list for a request.

    - ``{"code_execution": {}}`` is added when
      ``settings.TOOLS_CODE_EXECUTION_ENABLED`` is truthy, the model name
      neither ends with ``-search`` nor contains ``-thinking``, and the
      payload's ``contents`` hold no image parts.
    - ``{"googleSearch": {}}`` is added for models ending with ``-search``.

    Args:
        model: Requested model name.
        payload: Request body as a dict; only its ``contents`` entry is read.

    Returns:
        A new list with zero, one or two tool entries, in the order above.
    """
    is_search_model = model.endswith("-search")
    is_thinking_model = "-thinking" in model

    wants_code_execution = (
        settings.TOOLS_CODE_EXECUTION_ENABLED
        and not is_search_model
        and not is_thinking_model
        and not _has_image_parts(payload.get("contents", []))
    )

    tools: List[Dict[str, Any]] = []
    if wants_code_execution:
        tools.append({"code_execution": {}})
    if is_search_model:
        tools.append({"googleSearch": {}})
    return tools
def _get_safety_settings(model: str) -> List[Dict[str, str]]:
"""获取安全设置"""
if model == "gemini-2.0-flash-exp":
return [
{"category": "HARM_CATEGORY_HARASSMENT", "threshold": "OFF"},
{"category": "HARM_CATEGORY_HATE_SPEECH", "threshold": "OFF"},
{"category": "HARM_CATEGORY_SEXUALLY_EXPLICIT", "threshold": "OFF"},
{"category": "HARM_CATEGORY_DANGEROUS_CONTENT", "threshold": "OFF"},
{"category": "HARM_CATEGORY_CIVIC_INTEGRITY", "threshold": "OFF"}
]
return [
{"category": "HARM_CATEGORY_HARASSMENT", "threshold": "BLOCK_NONE"},
{"category": "HARM_CATEGORY_HATE_SPEECH", "threshold": "BLOCK_NONE"},
{"category": "HARM_CATEGORY_SEXUALLY_EXPLICIT", "threshold": "BLOCK_NONE"},
{"category": "HARM_CATEGORY_DANGEROUS_CONTENT", "threshold": "BLOCK_NONE"},
{"category": "HARM_CATEGORY_CIVIC_INTEGRITY", "threshold": "BLOCK_NONE"}
]
def _build_payload(model: str, request: GeminiRequest) -> Dict[str, Any]:
    """Translate a ``GeminiRequest`` into the JSON body sent upstream.

    The request is dumped to a plain dict. ``contents``, ``generationConfig``
    and ``systemInstruction`` are copied from it, with a fallback of
    ``[]``/``{}``/``[]`` when the key is absent from the dump. ``tools`` and
    ``safetySettings`` are always computed from *model* and are never taken
    from the request.

    Args:
        model: Requested model name; drives tool and safety selection.
        request: Incoming request model.

    Returns:
        A new dict with the keys ``contents``, ``tools``, ``safetySettings``,
        ``generationConfig`` and ``systemInstruction``, in that order.
    """
    dumped = request.model_dump()
    body: Dict[str, Any] = {}
    body["contents"] = dumped.get("contents", [])
    body["tools"] = _build_tools(model, dumped)
    body["safetySettings"] = _get_safety_settings(model)
    body["generationConfig"] = dumped.get("generationConfig", {})
    # NOTE(review): the list fallback only applies when the dump lacks the
    # key; confirm against the upstream schema whether a list is acceptable
    # for systemInstruction if that ever happens.
    body["systemInstruction"] = dumped.get("systemInstruction", [])
    return body
class GeminiChatService:
    """Chat service that forwards native Gemini requests upstream.

    Payloads are built with ``_build_payload``, sent through
    ``GeminiApiClient`` and post-processed by ``GeminiResponseHandler``.
    Streaming calls report failing keys to the ``KeyManager`` and retry with
    the key it returns.
    """

    # Maximum number of streaming attempts, including the first one.
    MAX_STREAM_RETRIES = 3

    def __init__(self, base_url: str, key_manager: KeyManager):
        """Create the API client and response handler.

        Args:
            base_url: Upstream base URL handed to ``GeminiApiClient``.
            key_manager: Used to rotate keys after a failed streaming call.
        """
        self.api_client = GeminiApiClient(base_url)
        self.key_manager = key_manager
        self.response_handler = GeminiResponseHandler()

    @staticmethod
    def _mask_key(api_key: str) -> str:
        """Return a log-safe form of *api_key* that shows only its last 4 characters."""
        if not api_key:
            return str(api_key)
        return f"...{api_key[-4:]}"

    def generate_content(self, model: str, request: GeminiRequest, api_key: str) -> Dict[str, Any]:
        """Generate a non-streaming response.

        Args:
            model: Requested model name.
            request: Incoming Gemini request.
            api_key: Key used for the upstream call.

        Returns:
            The upstream response after
            ``handle_response(response, model, stream=False)``.
        """
        payload = _build_payload(model, request)
        response = self.api_client.generate_content(payload, model, api_key)
        return self.response_handler.handle_response(response, model, stream=False)

    async def stream_generate_content(self, model: str, request: GeminiRequest, api_key: str) -> AsyncGenerator[str, None]:
        """Stream a response as Server-Sent-Event ``data:`` lines.

        Each upstream line starting with ``data:`` is decoded as JSON, passed
        through ``handle_response(..., stream=True)`` and re-emitted as
        ``"data: <json>\\n\\n"``. Upstream lines without that prefix are
        dropped.

        On an exception, the failing key is reported to
        ``key_manager.handle_api_failure`` and the key it returns is used for
        the next attempt, up to ``MAX_STREAM_RETRIES`` attempts in total. A
        retry only happens while nothing has been yielded yet, because
        replaying the request after a partial answer would send duplicate
        text to the client. When retries run out, or the stream breaks after
        partial output, the generator ends without raising.
        """
        max_retries = self.MAX_STREAM_RETRIES
        payload = _build_payload(model, request)
        retries = 0
        while retries < max_retries:
            emitted_any = False
            try:
                async for line in self.api_client.stream_generate_content(payload, model, api_key):
                    if not line.startswith("data:"):
                        continue
                    # Strip exactly the "data:" field name. The SSE space after
                    # the colon is optional, and json.loads ignores it when
                    # present, so "data:{...}" no longer loses its "{".
                    chunk = json.loads(line[len("data:"):])
                    handled = self.response_handler.handle_response(chunk, model, stream=True)
                    # Mark before yielding: once handed out, the chunk is
                    # visible to the client.
                    emitted_any = True
                    yield "data: " + json.dumps(handled) + "\n\n"
                logger.info("Streaming completed successfully")
                break
            except Exception as e:
                retries += 1
                logger.warning(f"Streaming API call failed with error: {e}. Attempt {retries} of {max_retries}")
                # Report the failing key; the manager returns the key to use next.
                api_key = await self.key_manager.handle_api_failure(api_key)
                # Never write the full secret to the logs.
                logger.info(f"Switched to new API key: {self._mask_key(api_key)}")
                if emitted_any:
                    logger.error(
                        "Streaming failed after partial output was sent; "
                        "not retrying to avoid duplicated content"
                    )
                    break
                if retries >= max_retries:
                    logger.error(f"Max retries ({max_retries}) reached for streaming. Giving up")
                    break