From a5adcdae4883dd178b620f983b8ffc2a7301d3ea Mon Sep 17 00:00:00 2001 From: yinpeng <2291314224@qq.com> Date: Sun, 15 Dec 2024 15:23:25 +0800 Subject: [PATCH] =?UTF-8?q?=E4=BC=98=E5=8C=96Gemini=E5=93=8D=E5=BA=94?= =?UTF-8?q?=E5=A4=84=E7=90=86=E9=80=BB=E8=BE=91=EF=BC=8C=E6=94=B9=E8=BF=9B?= =?UTF-8?q?=E9=94=99=E8=AF=AF=E5=A4=84=E7=90=86=E5=92=8C=E4=BB=A3=E7=A0=81?= =?UTF-8?q?=E7=BB=93=E6=9E=84?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- app/services/chat_service.py | 88 +++++++++++++++--------------------- 1 file changed, 37 insertions(+), 51 deletions(-) diff --git a/app/services/chat_service.py b/app/services/chat_service.py index 8a0c6af..2b33a39 100644 --- a/app/services/chat_service.py +++ b/app/services/chat_service.py @@ -65,27 +65,24 @@ class ChatService: ) -> Optional[Dict[str, Any]]: """Convert Gemini response to OpenAI format""" if stream: - if not response.get("candidates"): - return None - try: - candidate = response["candidates"][0] - content = candidate.get("content", {}) - parts = content.get("parts", []) + if response.get("candidates"): + candidate = response["candidates"][0] + content = candidate.get("content", {}) + parts = content.get("parts", []) - if not parts: - return None - - if "text" in parts[0]: - text = parts[0].get("text") - elif "executableCode" in parts[0]: - text = self.format_code_block(parts[0]["executableCode"]) - elif "codeExecution" in parts[0]: - text = self.format_code_block(parts[0]["codeExecution"]) - elif "executableCodeResult" in parts[0]: - text = self.format_execution_result(parts[0]["executableCodeResult"]) - elif "codeExecutionResult" in parts[0]: - text = self.format_execution_result(parts[0]["codeExecutionResult"]) + if "text" in parts[0]: + text = parts[0].get("text") + elif "executableCode" in parts[0]: + text = self.format_code_block(parts[0]["executableCode"]) + elif "codeExecution" in parts[0]: + text = self.format_code_block(parts[0]["codeExecution"]) + elif "executableCodeResult" in parts[0]: + text = self.format_execution_result(parts[0]["executableCodeResult"]) + elif "codeExecutionResult" in parts[0]: + text = self.format_execution_result(parts[0]["codeExecutionResult"]) + else: + text = "" else: text = "" @@ -180,7 +177,6 @@ class ChatService: } if stream: - async def generate(): retries = 0 MAX_RETRIES = 3 @@ -190,22 +186,12 @@ class ChatService: try: async with httpx.AsyncClient() as client: stream_url = f"https://generativelanguage.googleapis.com/v1beta/models/{gemini_model}:streamGenerateContent?alt=sse&key={current_api_key}" - async with client.stream( - "POST", stream_url, json=payload - ) as response: + async with client.stream("POST", stream_url, json=payload) as response: if response.status_code != 200: + error_msg = await response.text() + logger.error(f"API error: {response.status_code}, {error_msg}") if retries < MAX_RETRIES - 1: - logger.warning( - f"API error: {response.status_code}, attempting retry {retries + 1}" - ) - current_api_key = ( - await self.key_manager.handle_api_failure( - current_api_key - ) - ) - logger.info( - f"Switched to new API key: {current_api_key}" - ) + current_api_key = await self.key_manager.handle_api_failure(current_api_key) retries += 1 continue else: @@ -219,27 +205,21 @@ class ChatService: if line.startswith("data: "): try: chunk = json.loads(line[6:]) - openai_chunk = ( - self.convert_gemini_response_to_openai( - chunk, model, stream=True, finish_reason=None - ) + openai_chunk = self.convert_gemini_response_to_openai( + chunk, model, stream=True, finish_reason=None ) if openai_chunk: yield f"data: {json.dumps(openai_chunk)}\n\n" except json.JSONDecodeError: continue - yield f"data: {json.dumps({'finish_reason': 'stop'})}\n\n" + yield f"data: {json.dumps(self.convert_gemini_response_to_openai({}, model,stream=True, finish_reason='stop'))}\n\n" yield "data: [DONE]\n\n" - return # 成功完成,退出重试循环 + return except Exception as e: + logger.warning(f"Stream error: {str(e)}, attempting retry {retries + 1}") if retries < MAX_RETRIES - 1: - logger.warning( - f"Stream error: {str(e)}, attempting retry {retries + 1}" - ) - current_api_key = await self.key_manager.handle_api_failure( - current_api_key - ) + current_api_key = await self.key_manager.handle_api_failure(current_api_key) retries += 1 continue else: @@ -249,11 +229,17 @@ class ChatService: return generate() else: - async with httpx.AsyncClient() as client: - url = f"https://generativelanguage.googleapis.com/v1beta/models/{gemini_model}:generateContent?key={api_key}" - response = await client.post(url, json=payload) - gemini_response = response.json() - return self.convert_gemini_response_to_openai(gemini_response, model, finish_reason="stop") + try: + async with httpx.AsyncClient() as client: + url = f"https://generativelanguage.googleapis.com/v1beta/models/{gemini_model}:generateContent?key={api_key}" + response = await client.post(url, json=payload) + if response.status_code != 200: + raise Exception(f"API error: {response.status_code}") + gemini_response = response.json() + return self.convert_gemini_response_to_openai(gemini_response, model, finish_reason="stop") + except Exception as e: + logger.error(f"Error in non-stream completion: {str(e)}") + raise async def _openai_chat_completion( self,