From b60b0630346be35ca973022c140d8fba3b065479 Mon Sep 17 00:00:00 2001 From: yinpeng <2291314224@qq.com> Date: Thu, 6 Feb 2025 00:33:11 +0800 Subject: [PATCH] =?UTF-8?q?refactor:=20=E4=BF=AE=E6=94=B9=E5=AE=89?= =?UTF-8?q?=E5=85=A8=E8=AE=BE=E7=BD=AE=E9=80=BB=E8=BE=91=E4=BB=A5=E5=8C=B9?= =?UTF-8?q?=E9=85=8D=E7=89=B9=E5=AE=9A=E6=A8=A1=E5=9E=8B?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- app/services/gemini_chat_service.py | 2 +- app/services/openai_chat_service.py | 89 ++++++++++++++++------------- 2 files changed, 49 insertions(+), 42 deletions(-) diff --git a/app/services/gemini_chat_service.py b/app/services/gemini_chat_service.py index 0c068db..2674b11 100644 --- a/app/services/gemini_chat_service.py +++ b/app/services/gemini_chat_service.py @@ -81,7 +81,7 @@ class GeminiChatService: def _get_safety_settings(self, model: str) -> List[Dict[str, str]]: """获取安全设置""" - if "2.0" in model and "gemini-2.0-flash-thinking-exp" not in model: + if model == "gemini-2.0-flash-exp": return [ {"category": "HARM_CATEGORY_HARASSMENT", "threshold": "OFF"}, {"category": "HARM_CATEGORY_HATE_SPEECH", "threshold": "OFF"}, diff --git a/app/services/openai_chat_service.py b/app/services/openai_chat_service.py index 7976c83..79300cd 100644 --- a/app/services/openai_chat_service.py +++ b/app/services/openai_chat_service.py @@ -4,13 +4,15 @@ import json from typing import Dict, Any, AsyncGenerator, List, Union from app.core.logger import get_openai_logger from app.services.chat.message_converter import OpenAIMessageConverter -from app.services.chat.response_handler import OpenAIResponseHandler +from app.services.chat.response_handler import OpenAIResponseHandler from app.services.chat.api_client import GeminiApiClient from app.schemas.openai_models import ChatRequest from app.core.config import settings from app.services.key_manager import KeyManager logger = get_openai_logger() + + class OpenAIChatService: """聊天服务""" @@ -19,7 +21,7 @@ class OpenAIChatService: self.response_handler = OpenAIResponseHandler(config=None) self.api_client = GeminiApiClient(base_url) self.key_manager = key_manager - + async def create_chat_completion( self, request: ChatRequest, @@ -28,68 +30,64 @@ class OpenAIChatService: """创建聊天完成""" # 转换消息格式 messages = self.message_converter.convert(request.messages) - + # 构建请求payload payload = self._build_payload(request, messages) - + if request.stream: return self._handle_stream_completion(request.model, payload, api_key) return self._handle_normal_completion(request.model, payload, api_key) - + def _handle_normal_completion( - self, - model: str, - payload: Dict[str, Any], - api_key: str + self, model: str, payload: Dict[str, Any], api_key: str ) -> Dict[str, Any]: """处理普通聊天完成""" response = self.api_client.generate_content(payload, model, api_key) return self.response_handler.handle_response( - response, - model, - stream=False, - finish_reason="stop" + response, model, stream=False, finish_reason="stop" ) - + async def _handle_stream_completion( - self, - model: str, - payload: Dict[str, Any], - api_key: str + self, model: str, payload: Dict[str, Any], api_key: str ) -> AsyncGenerator[str, None]: """处理流式聊天完成,添加重试逻辑""" retries = 0 max_retries = 3 while retries < max_retries: try: - async for line in self.api_client.stream_generate_content(payload, model, api_key): + async for line in self.api_client.stream_generate_content( + payload, model, api_key + ): # print(line) if line.startswith("data:"): chunk = json.loads(line[6:]) openai_chunk = self.response_handler.handle_response( - chunk, - model, - stream=True, - finish_reason=None + chunk, model, stream=True, finish_reason=None ) if openai_chunk: yield f"data: {json.dumps(openai_chunk)}\n\n" yield f"data: {json.dumps(self.response_handler.handle_response({}, model, stream=True, finish_reason='stop'))}\n\n" yield "data: [DONE]\n\n" logger.info("Streaming completed successfully") - break # 成功后退出循环 + break # 成功后退出循环 except Exception as e: retries += 1 - logger.warning(f"Streaming API call failed with error: {str(e)}. Attempt {retries} of {max_retries}") + logger.warning( + f"Streaming API call failed with error: {str(e)}. Attempt {retries} of {max_retries}" + ) api_key = await self.key_manager.handle_api_failure(api_key) logger.info(f"Switched to new API key: {api_key}") if retries >= max_retries: - logger.error(f"Max retries ({max_retries}) reached for streaming. Raising error") + logger.error( + f"Max retries ({max_retries}) reached for streaming. Raising error" + ) yield f"data: {json.dumps({'error': 'Streaming failed after retries'})}\n\n" yield "data: [DONE]\n\n" break - - def _build_payload(self, request: ChatRequest, messages: List[Dict[str, Any]]) -> Dict[str, Any]: + + def _build_payload( + self, request: ChatRequest, messages: List[Dict[str, Any]] + ) -> Dict[str, Any]: """构建请求payload""" return { "contents": messages, @@ -98,25 +96,29 @@ class OpenAIChatService: "maxOutputTokens": request.max_tokens, "stopSequences": request.stop, "topP": request.top_p, - "topK": request.top_k + "topK": request.top_k, }, "tools": self._build_tools(request, messages), - "safetySettings": self._get_safety_settings(request.model) + "safetySettings": self._get_safety_settings(request.model), } - - def _build_tools(self, request: ChatRequest, messages: List[Dict[str, Any]]) -> List[Dict[str, Any]]: + + def _build_tools( + self, request: ChatRequest, messages: List[Dict[str, Any]] + ) -> List[Dict[str, Any]]: """构建工具""" tools = [] model = request.model - if settings.TOOLS_CODE_EXECUTION_ENABLED and not ( - model.endswith("-search") or "-thinking" in model - ) and not self._has_image_parts(messages): + if ( + settings.TOOLS_CODE_EXECUTION_ENABLED + and not (model.endswith("-search") or "-thinking" in model) + and not self._has_image_parts(messages) + ): tools.append({"code_execution": {}}) if model.endswith("-search"): tools.append({"googleSearch": {}}) return tools - + def _has_image_parts(self, contents: List[Dict[str, Any]]) -> bool: """判断消息是否包含图片部分""" for content in contents: @@ -125,21 +127,26 @@ class OpenAIChatService: if "image_url" in part or "inline_data" in part: return True return False - + def _get_safety_settings(self, model: str) -> List[Dict[str, str]]: """获取安全设置""" - if "2.0" in model and "gemini-2.0-flash-thinking-exp" not in model: + # if ( + # "2.0" in model + # and "gemini-2.0-flash-thinking-exp" not in model + # and "gemini-2.0-pro-exp" not in model + # ): + if model == "gemini-2.0-flash-exp": return [ {"category": "HARM_CATEGORY_HARASSMENT", "threshold": "OFF"}, {"category": "HARM_CATEGORY_HATE_SPEECH", "threshold": "OFF"}, {"category": "HARM_CATEGORY_SEXUALLY_EXPLICIT", "threshold": "OFF"}, {"category": "HARM_CATEGORY_DANGEROUS_CONTENT", "threshold": "OFF"}, - {"category": "HARM_CATEGORY_CIVIC_INTEGRITY", "threshold": "OFF"} + {"category": "HARM_CATEGORY_CIVIC_INTEGRITY", "threshold": "OFF"}, ] return [ {"category": "HARM_CATEGORY_HARASSMENT", "threshold": "BLOCK_NONE"}, {"category": "HARM_CATEGORY_HATE_SPEECH", "threshold": "BLOCK_NONE"}, {"category": "HARM_CATEGORY_SEXUALLY_EXPLICIT", "threshold": "BLOCK_NONE"}, {"category": "HARM_CATEGORY_DANGEROUS_CONTENT", "threshold": "BLOCK_NONE"}, - {"category": "HARM_CATEGORY_CIVIC_INTEGRITY", "threshold": "BLOCK_NONE"} - ] \ No newline at end of file + {"category": "HARM_CATEGORY_CIVIC_INTEGRITY", "threshold": "BLOCK_NONE"}, + ]