Files
gemini-balance/app/services/chat/api_client.py
yinpeng cd45f4b5ab refactor: 重构Gemini和OpenAI聊天服务以支持工具和安全设置
- 将 `_build_payload`、`_build_tools`、`_get_safety_settings` 和 `_has_image_parts` 函数从 `OpenAIChatService` 和 `GeminiChatService` 类中提取为独立的函数。
- 将 `_handle_stream_response` 和 `_handle_normal_response` 函数从 `GeminiResponseHandler` 和 `OpenAIResponseHandler` 类中提取为独立的函数。
- 将 `_extract_text` 函数从 `OpenAIResponseHandler` 类中提取为独立的函数, 并在 `GeminiResponseHandler` 中复用。
- 将 `_convert_image` 函数从 `OpenAIMessageConverter` 类中提取为独立的函数。
- 优化 `OpenAIChatService` 和 `GeminiChatService` 中的代码结构, 使其更清晰。
- 优化 `app/api/openai_routes.py` 和 `app/api/gemini_routes.py` 中的路由函数, 移除不必要的参数。
2025-02-06 21:35:19 +08:00

52 lines
2.2 KiB
Python

# app/services/chat/api_client.py
from typing import Dict, Any, AsyncGenerator
import httpx
from abc import ABC, abstractmethod
class ApiClient(ABC):
    """Abstract base class for API clients."""

    @abstractmethod
    async def generate_content(
        self,
        payload: Dict[str, Any],
        model: str,
        api_key: str,
    ) -> Dict[str, Any]:
        """Send ``payload`` for ``model`` and return the response as a dict.

        This base declaration is a coroutine with no behaviour of its own;
        concrete clients must override it.

        NOTE(review): ``GeminiApiClient`` in this module overrides this with
        a plain (non-async) method, so the two signatures do not agree on
        awaitability.
        """

    @abstractmethod
    async def stream_generate_content(
        self,
        payload: Dict[str, Any],
        model: str,
        api_key: str,
    ) -> AsyncGenerator[str, None]:
        """Stream the response for ``payload`` / ``model`` as strings.

        This base declaration has no behaviour of its own; concrete clients
        must override it (``GeminiApiClient`` does so with an async
        generator).
        """
class GeminiApiClient(ApiClient):
    """Gemini API client.

    Requests are sent with ``httpx`` to ``{base_url}/models/{model}:<method>``
    and the API key is passed as the ``key`` query parameter.
    """

    # Model-name suffix that is removed before the model name is placed in
    # the request URL.
    _SEARCH_SUFFIX = "-search"

    def __init__(self, base_url: str, timeout: int = 300):
        """Store the connection settings.

        Args:
            base_url: URL prefix to which ``/models/{model}:...`` is appended.
            timeout: Timeout in seconds used for every request made by this
                client.
        """
        self.base_url = base_url
        self.timeout = timeout

    def _resolve_model(self, model: str) -> str:
        """Return ``model`` without a trailing ``-search`` suffix, if any."""
        suffix = self._SEARCH_SUFFIX
        if model.endswith(suffix):
            # Slice by the suffix length rather than a hard-coded number so
            # the constant and the slice cannot drift apart.
            return model[: -len(suffix)]
        return model

    def _build_timeout(self) -> "httpx.Timeout":
        """Build the ``httpx.Timeout`` shared by both request methods."""
        return httpx.Timeout(self.timeout, read=self.timeout)

    def generate_content(self, payload: Dict[str, Any], model: str, api_key: str) -> Dict[str, Any]:
        """POST ``payload`` to ``:generateContent`` and return the parsed JSON.

        NOTE: this is a synchronous call made with ``httpx.Client`` even
        though ``ApiClient`` declares the method as ``async``; it returns the
        result directly and must not be awaited.

        Args:
            payload: JSON-serialisable request body.
            model: Model name; a trailing ``-search`` is stripped before use.
            api_key: Key sent as the ``key`` query parameter.

        Returns:
            The decoded JSON body of the response.

        Raises:
            Exception: If the response status code is not 200. The message
                contains the status code and the response text.
        """
        model = self._resolve_model(model)
        # NOTE(review): the API key is embedded in the URL, so it will show
        # up wherever request URLs are logged — confirm this is acceptable.
        url = f"{self.base_url}/models/{model}:generateContent?key={api_key}"
        with httpx.Client(timeout=self._build_timeout()) as client:
            response = client.post(url, json=payload)
            if response.status_code != 200:
                error_content = response.text
                raise Exception(f"API call failed with status code {response.status_code}, {error_content}")
            return response.json()

    async def stream_generate_content(self, payload: Dict[str, Any], model: str, api_key: str) -> AsyncGenerator[str, None]:
        """POST ``payload`` to ``:streamGenerateContent`` and yield its lines.

        The request is made with ``alt=sse`` and the response body is yielded
        line by line, unmodified, as it is received.

        Args:
            payload: JSON-serialisable request body.
            model: Model name; a trailing ``-search`` is stripped before use.
            api_key: Key sent as the ``key`` query parameter.

        Yields:
            Each line of the streamed response body.

        Raises:
            Exception: If the response status code is not 200. The message
                contains the status code and the decoded error body.
        """
        model = self._resolve_model(model)
        url = f"{self.base_url}/models/{model}:streamGenerateContent?alt=sse&key={api_key}"
        async with httpx.AsyncClient(timeout=self._build_timeout()) as client:
            async with client.stream(method="POST", url=url, json=payload) as response:
                if response.status_code != 200:
                    # A streamed response has no body loaded yet; read it
                    # explicitly to build the error message.
                    error_content = await response.aread()
                    # Decode leniently so a non-UTF-8 error body cannot raise
                    # UnicodeDecodeError and mask the real API failure.
                    error_msg = error_content.decode("utf-8", errors="replace")
                    raise Exception(f"API call failed with status code {response.status_code}, {error_msg}")
                async for line in response.aiter_lines():
                    yield line