mirror of
https://github.com/snailyp/gemini-balance.git
synced 2026-05-31 21:29:44 +08:00
192 lines
7.6 KiB
Python
192 lines
7.6 KiB
Python
# app/services/chat_service.py
|
|
|
|
import json
|
|
from typing import Dict, Any, AsyncGenerator, List, Union
|
|
from app.core.logger import get_openai_logger
|
|
from app.services.chat.message_converter import OpenAIMessageConverter
|
|
from app.services.chat.response_handler import OpenAIResponseHandler
|
|
from app.services.chat.api_client import GeminiApiClient
|
|
from app.schemas.openai_models import ChatRequest, ImageGenerationRequest
|
|
from app.core.config import settings
|
|
from app.services.image_create_service import ImageCreateService
|
|
from app.services.key_manager import KeyManager
|
|
|
|
logger = get_openai_logger()
|
|
|
|
|
|
def _has_image_parts(contents: List[Dict[str, Any]]) -> bool:
|
|
"""判断消息是否包含图片部分"""
|
|
for content in contents:
|
|
if "parts" in content:
|
|
for part in content["parts"]:
|
|
if "image_url" in part or "inline_data" in part:
|
|
return True
|
|
return False
|
|
|
|
|
|
def _build_tools(
    request: ChatRequest, messages: List[Dict[str, Any]]
) -> List[Dict[str, Any]]:
    """Build the ``tools`` list for a request.

    * ``{"code_execution": {}}`` is added when
      ``settings.TOOLS_CODE_EXECUTION_ENABLED`` is truthy, the model name
      neither ends with ``-search`` nor contains ``-thinking``, and none of
      ``messages`` contains an image part.
    * ``{"googleSearch": {}}`` is added when the model name ends with
      ``-search``.
    """
    model_name = request.model
    selected: List[Dict[str, Any]] = []

    # Conditions are evaluated left to right and short-circuit, so the
    # message scan only runs when the cheaper checks have already passed.
    code_execution_allowed = (
        settings.TOOLS_CODE_EXECUTION_ENABLED
        and not model_name.endswith("-search")
        and "-thinking" not in model_name
        and not _has_image_parts(messages)
    )
    if code_execution_allowed:
        selected.append({"code_execution": {}})

    if model_name.endswith("-search"):
        selected.append({"googleSearch": {}})

    return selected
|
|
|
|
|
|
def _get_safety_settings(model: str) -> List[Dict[str, str]]:
|
|
"""获取安全设置"""
|
|
# if (
|
|
# "2.0" in model
|
|
# and "gemini-2.0-flash-thinking-exp" not in model
|
|
# and "gemini-2.0-pro-exp" not in model
|
|
# ):
|
|
if model == "gemini-2.0-flash-exp":
|
|
return [
|
|
{"category": "HARM_CATEGORY_HARASSMENT", "threshold": "OFF"},
|
|
{"category": "HARM_CATEGORY_HATE_SPEECH", "threshold": "OFF"},
|
|
{"category": "HARM_CATEGORY_SEXUALLY_EXPLICIT", "threshold": "OFF"},
|
|
{"category": "HARM_CATEGORY_DANGEROUS_CONTENT", "threshold": "OFF"},
|
|
{"category": "HARM_CATEGORY_CIVIC_INTEGRITY", "threshold": "OFF"},
|
|
]
|
|
return [
|
|
{"category": "HARM_CATEGORY_HARASSMENT", "threshold": "BLOCK_NONE"},
|
|
{"category": "HARM_CATEGORY_HATE_SPEECH", "threshold": "BLOCK_NONE"},
|
|
{"category": "HARM_CATEGORY_SEXUALLY_EXPLICIT", "threshold": "BLOCK_NONE"},
|
|
{"category": "HARM_CATEGORY_DANGEROUS_CONTENT", "threshold": "BLOCK_NONE"},
|
|
{"category": "HARM_CATEGORY_CIVIC_INTEGRITY", "threshold": "BLOCK_NONE"},
|
|
]
|
|
|
|
|
|
def _build_payload(
    request: ChatRequest, messages: List[Dict[str, Any]]
) -> Dict[str, Any]:
    """Assemble the request payload from ``request`` and converted ``messages``.

    The returned dict has four keys, in this order: ``contents`` (the
    messages as given), ``generationConfig`` (sampling parameters copied
    from the request as-is), ``tools`` and ``safetySettings``.
    """
    generation_config = {
        "temperature": request.temperature,
        "maxOutputTokens": request.max_tokens,
        "stopSequences": request.stop,
        "topP": request.top_p,
        "topK": request.top_k,
    }

    payload: Dict[str, Any] = {"contents": messages}
    payload["generationConfig"] = generation_config
    payload["tools"] = _build_tools(request, messages)
    payload["safetySettings"] = _get_safety_settings(request.model)
    return payload
|
|
|
|
|
|
class OpenAIChatService:
    """Chat service that answers chat requests in streaming or normal mode.

    Text chat goes through ``GeminiApiClient``; image chat goes through
    ``ImageCreateService``. Results are shaped by ``OpenAIResponseHandler``.
    """

    def __init__(self, base_url: str, key_manager: KeyManager = None):
        """Create the service.

        Args:
            base_url: Base URL handed to ``GeminiApiClient``.
            key_manager: Optional key manager used to obtain a replacement
                API key after a failed streaming attempt. When ``None``,
                streaming retries reuse the same key.
        """
        self.message_converter = OpenAIMessageConverter()
        self.response_handler = OpenAIResponseHandler(config=None)
        self.api_client = GeminiApiClient(base_url)
        self.key_manager = key_manager
        self.image_create_service = ImageCreateService()

    @staticmethod
    def _mask_key(api_key: str) -> str:
        """Return a log-safe form of ``api_key`` (at most its last 4 characters)."""
        if not api_key:
            return "<none>"
        # Short keys are fully hidden so the tail cannot reveal most of the key.
        tail = api_key[-4:] if len(api_key) > 8 else ""
        return f"****{tail}"

    async def create_chat_completion(
        self,
        request: ChatRequest,
        api_key: str,
    ) -> Union[Dict[str, Any], AsyncGenerator[str, None]]:
        """Create a chat completion.

        Args:
            request: The incoming chat request.
            api_key: API key used for the upstream call.

        Returns:
            An async generator of SSE-formatted strings when
            ``request.stream`` is truthy, otherwise the value returned by
            ``_handle_normal_completion``.
        """
        # Convert the message format.
        messages = self.message_converter.convert(request.messages)

        # Build the request payload.
        payload = _build_payload(request, messages)

        if request.stream:
            return self._handle_stream_completion(request.model, payload, api_key)
        return self._handle_normal_completion(request.model, payload, api_key)

    def _handle_normal_completion(
        self, model: str, payload: Dict[str, Any], api_key: str
    ) -> Dict[str, Any]:
        """Handle a non-streaming chat completion.

        Calls ``api_client.generate_content`` (without awaiting it) and passes
        the result to the response handler with ``finish_reason="stop"``.
        """
        response = self.api_client.generate_content(payload, model, api_key)
        return self.response_handler.handle_response(
            response, model, stream=False, finish_reason="stop"
        )

    async def _handle_stream_completion(
        self, model: str, payload: Dict[str, Any], api_key: str
    ) -> AsyncGenerator[str, None]:
        """Handle a streaming chat completion, with retry logic.

        Each upstream line starting with ``data:`` is parsed as JSON,
        converted by the response handler and re-emitted as an SSE ``data:``
        event. On success a final chunk with ``finish_reason='stop'`` and a
        ``[DONE]`` marker are emitted. On any exception the attempt is
        retried, up to three attempts in total, after asking the key manager
        (if one is configured) for another key. After the last failed attempt
        an error event and ``[DONE]`` are emitted instead of raising.

        NOTE(review): a failure after some chunks were already yielded
        restarts the upstream call from scratch, so the client may receive
        duplicated content. Confirm whether that is acceptable.
        """
        sse_prefix = "data:"
        retries = 0
        max_retries = 3
        while retries < max_retries:
            try:
                async for line in self.api_client.stream_generate_content(
                    payload, model, api_key
                ):
                    if not line.startswith(sse_prefix):
                        continue
                    # Slice at the prefix length: the space after "data:" is
                    # optional, and json.loads ignores leading whitespace.
                    chunk = json.loads(line[len(sse_prefix):])
                    openai_chunk = self.response_handler.handle_response(
                        chunk, model, stream=True, finish_reason=None
                    )
                    if openai_chunk:
                        yield f"data: {json.dumps(openai_chunk)}\n\n"
                final_chunk = self.response_handler.handle_response(
                    {}, model, stream=True, finish_reason="stop"
                )
                yield f"data: {json.dumps(final_chunk)}\n\n"
                yield "data: [DONE]\n\n"
                logger.info("Streaming completed successfully")
                break  # Leave the retry loop on success.
            except Exception as e:
                retries += 1
                logger.warning(
                    f"Streaming API call failed with error: {str(e)}. Attempt {retries} of {max_retries}"
                )
                if self.key_manager is not None:
                    api_key = await self.key_manager.handle_api_failure(api_key)
                    # Never write the full key to the log.
                    logger.info(
                        f"Switched to new API key: {self._mask_key(api_key)}"
                    )
                else:
                    # Without a key manager there is nothing to rotate to.
                    logger.warning(
                        "No key manager configured; retrying with the same API key"
                    )
                if retries >= max_retries:
                    logger.error(
                        f"Max retries ({max_retries}) reached for streaming. Raising error"
                    )
                    yield f"data: {json.dumps({'error': 'Streaming failed after retries'})}\n\n"
                    yield "data: [DONE]\n\n"
                    break

    async def create_image_chat_completion(
        self,
        request: ChatRequest,
    ) -> Union[Dict[str, Any], AsyncGenerator[str, None]]:
        """Create an image chat completion.

        The ``"content"`` of the last message in ``request.messages`` is used
        as the image prompt.

        Returns:
            An async generator of SSE-formatted strings when
            ``request.stream`` is truthy, otherwise the value returned by
            ``_handle_normal_image_completion``.
        """
        image_generate_request = ImageGenerationRequest()
        image_generate_request.prompt = request.messages[-1]["content"]
        image_res = self.image_create_service.generate_images_chat(
            image_generate_request
        )

        if request.stream:
            return self._handle_stream_image_completion(request.model, image_res)
        return self._handle_normal_image_completion(request.model, image_res)

    async def _handle_stream_image_completion(
        self, model: str, image_data: str
    ) -> AsyncGenerator[str, None]:
        """Emit the image chat result as SSE events.

        When ``image_data`` is truthy and the handler returns a chunk, that
        chunk is emitted first. A final ``finish_reason='stop'`` chunk and a
        ``[DONE]`` marker are always emitted.
        """
        if image_data:
            openai_chunk = self.response_handler.handle_image_chat_response(
                image_data, model, stream=True, finish_reason=None
            )
            if openai_chunk:
                yield f"data: {json.dumps(openai_chunk)}\n\n"
        final_chunk = self.response_handler.handle_response(
            {}, model, stream=True, finish_reason="stop"
        )
        yield f"data: {json.dumps(final_chunk)}\n\n"
        yield "data: [DONE]\n\n"
        logger.info("Image chat streaming completed successfully")

    def _handle_normal_image_completion(
        self, model: str, image_data: str
    ) -> Dict[str, Any]:
        """Return the non-streaming image chat response built by the handler."""
        return self.response_handler.handle_image_chat_response(
            image_data, model, stream=False, finish_reason="stop"
        )