diff --git a/app/services/chat_service.py b/app/services/chat_service.py index 14cf39d..90ff3bd 100644 --- a/app/services/chat_service.py +++ b/app/services/chat_service.py @@ -71,42 +71,69 @@ class ChatService: def __init__(self, base_url: str, key_manager=None): self.base_url = base_url self.key_manager = key_manager + self.thinking_first = True def convert_gemini_response_to_openai( - self, - response: Dict[str, Any], - model: str, - stream: bool = False, - finish_reason: str = None, + self, + response: Dict[str, Any], + model: str, + stream: bool = False, + finish_reason: str = None, ) -> Optional[Dict[str, Any]]: """Convert Gemini response to OpenAI format""" if stream: try: + text = "" if response.get("candidates"): candidate = response["candidates"][0] content = candidate.get("content", {}) parts = content.get("parts", []) - if "text" in parts[0]: - text = parts[0].get("text") - elif "executableCode" in parts[0]: - text = self.format_code_block(parts[0]["executableCode"]) - elif "codeExecution" in parts[0]: - text = self.format_code_block(parts[0]["codeExecution"]) - elif "executableCodeResult" in parts[0]: - text = format_execution_result( - parts[0]["executableCodeResult"] - ) - elif "codeExecutionResult" in parts[0]: - text = format_execution_result( - parts[0]["codeExecutionResult"] - ) + if "thinking" in model: + if len(parts) == 1: + if self.thinking_first: + self.thinking_first = False + text = "\n🤔 **思考过程** 🤔\n---\n```\n" + parts[ + 0 + ].get("text") + else: + text = parts[0].get("text") + elif len(parts) == 2: + if self.thinking_first: + self.thinking_first = False + text = ( + "\n🤔 **思考过程** 🤔\n---\n```\n" + + parts[0].get("text") + + "\n```\n---\n" + + parts[1].get("text") + ) + else: + text = ( + parts[0].get("text") + + "\n```\n---\n" + + parts[1].get("text") + ) + else: + text = "" else: - text = "" + if "text" in parts[0]: + text = parts[0].get("text") + elif "executableCode" in parts[0]: + text = self.format_code_block(parts[0]["executableCode"]) + elif "codeExecution" in parts[0]: + text = self.format_code_block(parts[0]["codeExecution"]) + elif "executableCodeResult" in parts[0]: + text = format_execution_result( + parts[0]["executableCodeResult"] + ) + elif "codeExecutionResult" in parts[0]: + text = format_execution_result( + parts[0]["codeExecutionResult"] + ) + else: + text = "" text = self.add_search_link_text(model, candidate, text) - else: - text = "" return { "id": f"chatcmpl-{uuid.uuid4()}", @@ -136,7 +163,9 @@ class ChatService: "index": 0, "message": { "role": "assistant", - "content": response["candidates"][0]["content"]["parts"][0]["text"], + "content": response["candidates"][0]["content"]["parts"][0][ + "text" + ], }, "finish_reason": finish_reason, } @@ -149,8 +178,17 @@ class ChatService: } try: if response.get("candidates"): - text = response["candidates"][0]["content"]["parts"][0]["text"] candidate = response["candidates"][0] + if "thinking" in model: + text = ( + "\n🤔 **思考过程** 🤔\n---\n```\n" + + candidate["content"]["parts"][0]["text"] + + "\n```\n---\n" + + candidate["content"]["parts"][1]["text"] + ) + else: + text = candidate["content"]["parts"][0]["text"] + text = self.add_search_link_text(model, candidate, text) res["choices"][0]["message"]["content"] = text return res @@ -160,30 +198,32 @@ class ChatService: except Exception as e: logger.error(f"Error converting Gemini response: {str(e)}") logger.debug(f"Raw response: {response}") - res["choices"][0]["message"]["content"] = f"Error converting Gemini response: {str(e)}" + res["choices"][0]["message"][ + "content" + ] = f"Error converting Gemini response: {str(e)}" return res def add_search_link_text(self, model, candidate, text): if ( - settings.SHOW_SEARCH_LINK - and model.endswith("-search") - and "groundingMetadata" in candidate - and "groundingChunks" in candidate["groundingMetadata"] + settings.SHOW_SEARCH_LINK + and model.endswith("-search") + and "groundingMetadata" in candidate + and "groundingChunks" in candidate["groundingMetadata"] ): grounding_chunks = candidate["groundingMetadata"]["groundingChunks"] text += "\n\n---\n\n" - text += f"**【引用来源】**\n\n" + text += "**【引用来源】**\n\n" for _, grounding_chunk in enumerate(grounding_chunks, 1): if "web" in grounding_chunk: text += create_search_link(grounding_chunk["web"]) return text else: return text - + async def create_chat_completion( - self, - request: ChatRequest, - api_key: str, + self, + request: ChatRequest, + api_key: str, ) -> Union[Dict[str, Any], AsyncGenerator[str, None]]: """Create chat completion using either Gemini or OpenAI API""" model = request.model @@ -191,7 +231,7 @@ class ChatService: if tools is None: tools = [] if settings.TOOLS_CODE_EXECUTION_ENABLED and not ( - model.endswith("-search") or "-thinking" in model + model.endswith("-search") or "-thinking" in model ): tools.append({"code_execution": {}}) if model.endswith("-search"): @@ -199,10 +239,10 @@ class ChatService: return await self._gemini_chat_completion(request, api_key, tools) async def _gemini_chat_completion( - self, - request: ChatRequest, - api_key: str, - tools: Optional[list] = None, + self, + request: ChatRequest, + api_key: str, + tools: Optional[list] = None, ) -> Union[Dict[str, Any], AsyncGenerator[str, None]]: """Handle Gemini API chat completion""" model = request.model @@ -252,6 +292,7 @@ class ChatService: } if stream: + async def generate(): retries = 0 max_retries = 3 @@ -265,7 +306,7 @@ class ChatService: async with httpx.AsyncClient(timeout=timeout) as async_client: stream_url = f"https://generativelanguage.googleapis.com/v1beta/models/{gemini_model}:streamGenerateContent?alt=sse&key={current_api_key}" async with async_client.stream( - "POST", stream_url, json=payload + "POST", stream_url, json=payload ) as async_response: if async_response.status_code != 200: error_content = await async_response.read() @@ -344,7 +385,9 @@ class ChatService: return generate() else: try: - timeout = httpx.Timeout(300.0, read=300.0) # 连接超时300秒,读取超时300秒 + timeout = httpx.Timeout( + 300.0, read=300.0 + ) # 连接超时300秒,读取超时300秒 async with httpx.AsyncClient(timeout=timeout) as client: url = f"https://generativelanguage.googleapis.com/v1beta/models/{gemini_model}:generateContent?key={api_key}" response = await client.post(url, json=payload) @@ -370,7 +413,7 @@ class ChatService: return f"""\n【代码执行】\n```{language}\n{code}\n```\n""" async def generate_content( - self, model_name: str, request: GeminiRequest, api_key: str + self, model_name: str, request: GeminiRequest, api_key: str ) -> dict: """调用Gemini API生成内容""" url = f"{self.base_url}/models/{model_name}:generateContent?key={api_key}" @@ -393,7 +436,7 @@ class ChatService: raise async def stream_generate_content( - self, model_name: str, request: GeminiRequest, api_key: str + self, model_name: str, request: GeminiRequest, api_key: str ) -> AsyncGenerator: """调用Gemini API流式生成内容""" retries = 0 @@ -407,7 +450,7 @@ class ChatService: async with httpx.AsyncClient(timeout=timeout) as client: async with client.stream( - "POST", url, json=request.model_dump() + "POST", url, json=request.model_dump() ) as response: if response.status_code != 200: error_text = await response.text()