From 80bcaf5cd4a692ea621a12bc3260f90332714e7b Mon Sep 17 00:00:00 2001 From: yinpeng <2291314224@qq.com> Date: Sat, 21 Dec 2024 00:42:08 +0800 Subject: [PATCH] =?UTF-8?q?feat:=20=E4=BC=98=E5=8C=96=20Gemini=20=E6=A8=A1?= =?UTF-8?q?=E5=9E=8B=E9=85=8D=E7=BD=AE=E5=92=8C=E8=AF=B7=E6=B1=82=E5=8F=82?= =?UTF-8?q?=E6=95=B0?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- app/api/openai_routes.py | 7 +- app/core/config.py | 2 +- app/schemas/openai_models.py | 5 +- app/services/chat_service.py | 257 ++++++++++++++++++++++++----------- 4 files changed, 184 insertions(+), 87 deletions(-) diff --git a/app/api/openai_routes.py b/app/api/openai_routes.py index bc3dc9b..007abb7 100644 --- a/app/api/openai_routes.py +++ b/app/api/openai_routes.py @@ -53,13 +53,8 @@ async def chat_completion( while retries < MAX_RETRIES: try: response = await chat_service.create_chat_completion( - messages=request.messages, - model=request.model, - temperature=request.temperature, - stream=request.stream, + request=request, api_key=api_key, - tools=request.tools, - tool_choice=request.tool_choice, ) # 处理流式响应 diff --git a/app/core/config.py b/app/core/config.py index 38c82da..0ea7503 100644 --- a/app/core/config.py +++ b/app/core/config.py @@ -8,7 +8,7 @@ class Settings(BaseSettings): BASE_URL: str = "https://generativelanguage.googleapis.com/v1beta" MODEL_SEARCH: List[str] = ["gemini-2.0-flash-exp"] TOOLS_CODE_EXECUTION_ENABLED: bool = False - + SHOW_SEARCH_LINK: bool = True class Config: env_file = ".env" diff --git a/app/schemas/openai_models.py b/app/schemas/openai_models.py index 1b6807c..dc9b371 100644 --- a/app/schemas/openai_models.py +++ b/app/schemas/openai_models.py @@ -8,7 +8,10 @@ class ChatRequest(BaseModel): temperature: Optional[float] = 0.7 stream: Optional[bool] = False tools: Optional[List[dict]] = [] - tool_choice: Optional[str] = "auto" + max_tokens: Optional[int] = 8192 + stop: Optional[List[str]] = [] + top_p: Optional[float] = 0.9 + top_k: Optional[int] = 40 class EmbeddingRequest(BaseModel): diff --git a/app/services/chat_service.py b/app/services/chat_service.py index 2684568..c206f60 100644 --- a/app/services/chat_service.py +++ b/app/services/chat_service.py @@ -6,6 +6,7 @@ from typing import Dict, Any, Optional, AsyncGenerator, Union from app.core.config import settings from app.core.logger import get_chat_logger from app.schemas.gemini_models import GeminiRequest +from app.schemas.openai_models import ChatRequest logger = get_chat_logger() @@ -49,9 +50,8 @@ class ChatService: # 处理普通URL图片 parts.append( { - "inline_data": { - "mime_type": "image/jpeg", - "data": image_url, + "image_url": { + "url": image_url, } } ) @@ -61,7 +61,11 @@ class ChatService: return converted_messages def convert_gemini_response_to_openai( - self, response: Dict[str, Any], model: str, stream: bool = False, finish_reason: str = None + self, + response: Dict[str, Any], + model: str, + stream: bool = False, + finish_reason: str = None, ) -> Optional[Dict[str, Any]]: """Convert Gemini response to OpenAI format""" if stream: @@ -78,11 +82,28 @@ class ChatService: elif "codeExecution" in parts[0]: text = self.format_code_block(parts[0]["codeExecution"]) elif "executableCodeResult" in parts[0]: - text = self.format_execution_result(parts[0]["executableCodeResult"]) + text = self.format_execution_result( + parts[0]["executableCodeResult"] + ) elif "codeExecutionResult" in parts[0]: - text = self.format_execution_result(parts[0]["codeExecutionResult"]) + text = self.format_execution_result( + parts[0]["codeExecutionResult"] + ) else: text = "" + + if ( + settings.SHOW_SEARCH_LINK + and model.endswith("-search") + and "groundingMetadata" in candidate + and "groundingChunks" in candidate["groundingMetadata"] + ): + groundingChunks = candidate["groundingMetadata"]["groundingChunks"] + text += "\n\n---\n\n" + text += f"**【引用来源】**\n\n" + for _, groundingChunk in enumerate(groundingChunks, 1): + if "web" in groundingChunk: + text += self.create_search_link(groundingChunk["web"]) else: text = "" @@ -104,66 +125,87 @@ class ChatService: logger.debug(f"Raw response: {response}") return None else: - return { + res = { "id": f"chatcmpl-{uuid.uuid4()}", "object": "chat.completion", "created": int(time.time()), "model": model, "choices": [ - { - "index": 0, - "message": { - "role": "assistant", - "content": response["candidates"][0]["content"]["parts"][0][ - "text" - ], - }, - "finish_reason": finish_reason, - } - ], - "usage": { - "prompt_tokens": 0, - "completion_tokens": 0, - "total_tokens": 0, - }, - } + { + "index": 0, + "message": { + "role": "assistant", + "content": response["candidates"][0]["content"]["parts"][0]["text"], + }, + "finish_reason": finish_reason, + } + ], + "usage": { + "prompt_tokens": 0, + "completion_tokens": 0, + "total_tokens": 0, + }, + } + try: + if response.get("candidates"): + text = response["candidates"][0]["content"]["parts"][0]["text"] + candidate = response["candidates"][0] + if ( + settings.SHOW_SEARCH_LINK + and model.endswith("-search") + and "groundingMetadata" in candidate + and "groundingChunks" in candidate["groundingMetadata"] + ): + groundingChunks = candidate["groundingMetadata"]["groundingChunks"] + text += "\n\n---\n\n" + text += f"**【引用来源】**\n\n" + for _, groundingChunk in enumerate(groundingChunks, 1): + if "web" in groundingChunk: + text += self.create_search_link(groundingChunk["web"]) + res["choices"][0]["message"]["content"] = text + return res + else: + res["choices"][0]["message"]["content"] = "暂无返回" + return res + except Exception as e: + logger.error(f"Error converting Gemini response: {str(e)}") + logger.debug(f"Raw response: {response}") + res["choices"][0]["message"]["content"] = f"Error converting Gemini response: {str(e)}" + return res async def create_chat_completion( self, - messages: list, - model: str, - temperature: float, - stream: bool, + request: ChatRequest, api_key: str, - tools: Optional[list] = None, - tool_choice: Optional[str] = None, ) -> Union[Dict[str, Any], AsyncGenerator[str, None]]: """Create chat completion using either Gemini or OpenAI API""" - + model = request.model + tools = request.tools if tools is None: tools = [] - if settings.TOOLS_CODE_EXECUTION_ENABLED and not (model.endswith("-search") or "-thinking" in model): + if settings.TOOLS_CODE_EXECUTION_ENABLED and not ( + model.endswith("-search") or "-thinking" in model + ): tools.append({"code_execution": {}}) if model.endswith("-search"): tools.append({"googleSearch": {}}) - return await self._gemini_chat_completion( - messages, model, temperature, stream, api_key, tools - ) - # else: - # return await self._openai_chat_completion( - # messages, model, temperature, stream, api_key, tools - # ) + return await self._gemini_chat_completion(request, api_key, tools) async def _gemini_chat_completion( self, - messages: list, - model: str, - temperature: float, - stream: bool, + request: ChatRequest, api_key: str, tools: Optional[list] = None, ) -> Union[Dict[str, Any], AsyncGenerator[str, None]]: """Handle Gemini API chat completion""" + model = request.model + messages = request.messages + temperature = request.temperature + stream = request.stream + max_tokens = request.max_tokens + stop = request.stop + top_p = request.top_p + top_k = request.top_k if model.endswith("-search"): gemini_model = model[:-7] # Remove -search suffix else: @@ -176,14 +218,29 @@ class ChatService: tools.remove({"code_execution": {}}) payload = { "contents": gemini_messages, - "generationConfig": {"temperature": temperature}, + "generationConfig": { + "temperature": temperature, + "maxOutputTokens": max_tokens, + "stopSequences": stop, + "topP": top_p, + "topK": top_k, + }, "tools": tools, "safetySettings": [ {"category": "HARM_CATEGORY_HARASSMENT", "threshold": "BLOCK_NONE"}, {"category": "HARM_CATEGORY_HATE_SPEECH", "threshold": "BLOCK_NONE"}, - {"category": "HARM_CATEGORY_SEXUALLY_EXPLICIT", "threshold": "BLOCK_NONE"}, - {"category": "HARM_CATEGORY_DANGEROUS_CONTENT", "threshold": "BLOCK_NONE"}, - {"category": "HARM_CATEGORY_CIVIC_INTEGRITY", "threshold": "BLOCK_NONE"}, + { + "category": "HARM_CATEGORY_SEXUALLY_EXPLICIT", + "threshold": "BLOCK_NONE", + }, + { + "category": "HARM_CATEGORY_DANGEROUS_CONTENT", + "threshold": "BLOCK_NONE", + }, + { + "category": "HARM_CATEGORY_CIVIC_INTEGRITY", + "threshold": "BLOCK_NONE", + }, ], } @@ -195,16 +252,26 @@ class ChatService: while retries < MAX_RETRIES: try: - timeout = httpx.Timeout(60.0, read=60.0) # 连接超时60秒,读取超时60秒 + timeout = httpx.Timeout( + 60.0, read=60.0 + ) # 连接超时60秒,读取超时60秒 async with httpx.AsyncClient(timeout=timeout) as client: stream_url = f"https://generativelanguage.googleapis.com/v1beta/models/{gemini_model}:streamGenerateContent?alt=sse&key={current_api_key}" - async with client.stream('POST', stream_url, json=payload) as response: + async with client.stream( + "POST", stream_url, json=payload + ) as response: if response.status_code != 200: error_content = await response.read() - error_msg = error_content.decode('utf-8') - logger.error(f"API error: {response.status_code}, {error_msg}") + error_msg = error_content.decode("utf-8") + logger.error( + f"API error: {response.status_code}, {error_msg}" + ) if retries < MAX_RETRIES - 1: - current_api_key = await self.key_manager.handle_api_failure(current_api_key) + current_api_key = ( + await self.key_manager.handle_api_failure( + current_api_key + ) + ) retries += 1 continue else: @@ -218,8 +285,13 @@ class ChatService: if line.startswith("data: "): try: chunk = json.loads(line[6:]) - openai_chunk = self.convert_gemini_response_to_openai( - chunk, model, stream=True, finish_reason=None + openai_chunk = ( + self.convert_gemini_response_to_openai( + chunk, + model, + stream=True, + finish_reason=None, + ) ) if openai_chunk: yield f"data: {json.dumps(openai_chunk)}\n\n" @@ -230,20 +302,30 @@ class ChatService: return except httpx.ReadTimeout: - logger.warning(f"Read timeout occurred, attempting retry {retries + 1}") + logger.warning( + f"Read timeout occurred, attempting retry {retries + 1}" + ) if retries < MAX_RETRIES - 1: - current_api_key = await self.key_manager.handle_api_failure(current_api_key) + current_api_key = await self.key_manager.handle_api_failure( + current_api_key + ) logger.info(f"Switched to new API key: {current_api_key}") retries += 1 continue else: - logger.error(f"Max retries reached. Final error: Read timeout") + logger.error( + f"Max retries reached. Final error: Read timeout" + ) yield f"data: {json.dumps({'error': 'Read timeout'})}\n\n" return except Exception as e: - logger.exception(f"Stream error: {str(e)}, attempting retry {retries + 1}") + logger.exception( + f"Stream error: {str(e)}, attempting retry {retries + 1}" + ) if retries < MAX_RETRIES - 1: - current_api_key = await self.key_manager.handle_api_failure(current_api_key) + current_api_key = await self.key_manager.handle_api_failure( + current_api_key + ) logger.info(f"Switched to new API key: {current_api_key}") retries += 1 continue @@ -262,9 +344,13 @@ class ChatService: if response.status_code != 200: error_text = response.text error_code = response.status_code - raise Exception(f"API调用错误 - 状态码: {error_code}, 响应内容: {error_text}") + raise Exception( + f"API调用错误 - 状态码: {error_code}, 响应内容: {error_text}" + ) gemini_response = response.json() - return self.convert_gemini_response_to_openai(gemini_response, model, finish_reason="stop") + return self.convert_gemini_response_to_openai( + gemini_response, model, stream=False, finish_reason="stop" + ) except Exception as e: logger.error(f"Error in non-stream completion") raise @@ -283,10 +369,7 @@ class ChatService: return f"""\n【执行结果】\n> outcome: {outcome}\n\n【输出结果】\n```plaintext\n{output}\n```\n""" async def generate_content( - self, - model_name: str, - request: GeminiRequest, - api_key: str + self, model_name: str, request: GeminiRequest, api_key: str ) -> dict: """调用Gemini API生成内容""" url = f"{self.base_url}/models/{model_name}:generateContent?key={api_key}" @@ -301,16 +384,15 @@ class ChatService: error_text = response.text logger.error(f"Error: {response.status_code}") logger.error(error_text) - raise Exception(f"API request failed with status {response.status_code}: {error_text}") + raise Exception( + f"API request failed with status {response.status_code}: {error_text}" + ) except Exception as e: logger.error(f"Request failed: {str(e)}") raise async def stream_generate_content( - self, - model_name: str, - request: GeminiRequest, - api_key: str + self, model_name: str, request: GeminiRequest, api_key: str ) -> AsyncGenerator: """调用Gemini API流式生成内容""" retries = 0 @@ -321,19 +403,29 @@ class ChatService: try: url = f"{self.base_url}/models/{model_name}:streamGenerateContent?alt=sse&key={current_api_key}" timeout = httpx.Timeout(60.0, read=60.0) - + async with httpx.AsyncClient(timeout=timeout) as client: - async with client.stream('POST', url, json=request.model_dump()) as response: + async with client.stream( + "POST", url, json=request.model_dump() + ) as response: if response.status_code != 200: error_text = await response.text() logger.error(f"Error: {response.status_code}: {error_text}") if retries < MAX_RETRIES - 1: - current_api_key = await self.key_manager.handle_api_failure(current_api_key) - logger.info(f"Switched to new API key: {current_api_key}") + current_api_key = ( + await self.key_manager.handle_api_failure( + current_api_key + ) + ) + logger.info( + f"Switched to new API key: {current_api_key}" + ) retries += 1 continue - raise Exception(f"API request failed with status {response.status_code}: {error_text}") - + raise Exception( + f"API request failed with status {response.status_code}: {error_text}" + ) + async for line in response.aiter_lines(): yield line + "\n\n" return @@ -341,7 +433,9 @@ class ChatService: except httpx.ReadTimeout: logger.warning(f"Read timeout occurred, attempting retry {retries + 1}") if retries < MAX_RETRIES - 1: - current_api_key = await self.key_manager.handle_api_failure(current_api_key) + current_api_key = await self.key_manager.handle_api_failure( + current_api_key + ) logger.info(f"Switched to new API key: {current_api_key}") retries += 1 continue @@ -350,8 +444,13 @@ class ChatService: except Exception as e: logger.error(f"Streaming request failed: {str(e)}") if retries < MAX_RETRIES - 1: - current_api_key = await self.key_manager.handle_api_failure(current_api_key) + current_api_key = await self.key_manager.handle_api_failure( + current_api_key + ) logger.info(f"Switched to new API key: {current_api_key}") retries += 1 continue - raise \ No newline at end of file + raise + + def create_search_link(self, web): + return f'\n- [{web["title"]}]({web["uri"]})'