From 870b1ecc1786db217c8f3fefc9b02f0786272f44 Mon Sep 17 00:00:00 2001 From: yinpeng <2291314224@qq.com> Date: Fri, 27 Dec 2024 20:07:43 +0800 Subject: [PATCH] =?UTF-8?q?feat:=20=E6=B7=BB=E5=8A=A0=E9=87=8D=E8=AF=95?= =?UTF-8?q?=E6=9C=BA=E5=88=B6=E5=92=8C=E6=B6=88=E6=81=AF=E8=BD=AC=E6=8D=A2?= =?UTF-8?q?=E5=99=A8=EF=BC=8C=E5=B9=B6=E6=94=AF=E6=8C=81Gemini=20v1beta=20?= =?UTF-8?q?API?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- app/api/gemini_routes.py | 82 ++-- app/api/openai_routes.py | 56 ++- app/core/logger.py | 4 + app/main.py | 1 + app/services/chat/api_client.py | 49 +++ app/services/chat/message_converter.py | 50 +++ app/services/chat/response_handler.py | 322 +++++++++++++++ app/services/chat/retry_handler.py | 40 ++ app/services/chat_service.py | 523 ------------------------- app/services/gemini_chat_service.py | 89 +++++ app/services/openai_chat_service.py | 136 +++++++ app/utils/helpers.py | 0 12 files changed, 755 insertions(+), 597 deletions(-) create mode 100644 app/services/chat/api_client.py create mode 100644 app/services/chat/message_converter.py create mode 100644 app/services/chat/response_handler.py create mode 100644 app/services/chat/retry_handler.py delete mode 100644 app/services/chat_service.py create mode 100644 app/services/gemini_chat_service.py create mode 100644 app/services/openai_chat_service.py delete mode 100644 app/utils/helpers.py diff --git a/app/api/gemini_routes.py b/app/api/gemini_routes.py index acaeb4c..e4b965c 100644 --- a/app/api/gemini_routes.py +++ b/app/api/gemini_routes.py @@ -1,3 +1,4 @@ +from fastapi import HTTPException from fastapi import APIRouter, Depends from fastapi.responses import StreamingResponse @@ -5,18 +6,18 @@ from app.core.config import settings from app.core.logger import get_gemini_logger from app.core.security import SecurityService from app.schemas.gemini_models import GeminiRequest -from app.services.chat_service import ChatService +from 
app.services.gemini_chat_service import GeminiChatService from app.services.key_manager import KeyManager from app.services.model_service import ModelService - +from app.services.chat.retry_handler import RetryHandler router = APIRouter(prefix="/gemini/v1beta") +router_v1beta = APIRouter(prefix="/v1beta") logger = get_gemini_logger() # 初始化服务 security_service = SecurityService(settings.ALLOWED_TOKENS, settings.AUTH_TOKEN) key_manager = KeyManager(settings.API_KEYS) model_service = ModelService(settings.MODEL_SEARCH) -chat_service = ChatService(base_url=settings.BASE_URL, key_manager=key_manager) @router.get("/models") @@ -34,58 +35,51 @@ async def list_models( return models_json @router.post("/models/{model_name}:generateContent") +@RetryHandler(max_retries=3, key_manager=key_manager, key_arg="api_key") async def generate_content( model_name: str, request: GeminiRequest, - x_goog_api_key: str = Depends(security_service.verify_goog_api_key), + # x_goog_api_key: str = Depends(security_service.verify_goog_api_key), + api_key: str = Depends(key_manager.get_next_working_key), ): + chat_service = GeminiChatService(settings.BASE_URL, key_manager) """非流式生成内容""" logger.info("-" * 50 + "gemini_generate_content" + "-" * 50) logger.info(f"Handling Gemini content generation request for model: {model_name}") logger.info(f"Request: \n{request.model_dump_json(indent=2)}") - - api_key = await key_manager.get_next_working_key() - logger.info(f"Using API key: {api_key}") - retries = 0 - MAX_RETRIES = 3 - - while retries < MAX_RETRIES: - try: - response = await chat_service.generate_content( - model_name=model_name, - request=request, - api_key=api_key - ) - return response - - except Exception as e: - logger.warning( - f"API call failed with error: {str(e)}. 
Attempt {retries + 1} of {MAX_RETRIES}" - ) - api_key = await key_manager.handle_api_failure(api_key) - logger.info(f"Switched to new API key: {api_key}") - retries += 1 - if retries >= MAX_RETRIES: - logger.error(f"Max retries ({MAX_RETRIES}) reached. Raising error") - - -@router.post("/models/{model_name}:streamGenerateContent") -async def stream_generate_content( - model_name: str, - request: GeminiRequest, - x_goog_api_key: str = Depends(security_service.verify_goog_api_key), -): - """流式生成内容""" - logger.info("-" * 50 + "gemini_stream_generate_content" + "-" * 50) - logger.info(f"Handling Gemini streaming content generation for model: {model_name}") - - api_key = await key_manager.get_next_working_key() logger.info(f"Using API key: {api_key}") try: - chat_service = ChatService(base_url=settings.BASE_URL, key_manager=key_manager) - response_stream = chat_service.stream_generate_content( - model_name=model_name, + response = chat_service.generate_content( + model=model_name, + request=request, + api_key=api_key + ) + return response + + except Exception as e: + logger.error(f"Chat completion failed after retries: {str(e)}") + raise HTTPException(status_code=500, detail="Chat completion failed") from e + + +@router.post("/models/{model_name}:streamGenerateContent") +@RetryHandler(max_retries=3, key_manager=key_manager, key_arg="api_key") +async def stream_generate_content( + model_name: str, + request: GeminiRequest, + # x_goog_api_key: str = Depends(security_service.verify_goog_api_key), + api_key: str = Depends(key_manager.get_next_working_key), +): + chat_service = GeminiChatService(settings.BASE_URL, key_manager) + """流式生成内容""" + logger.info("-" * 50 + "gemini_stream_generate_content" + "-" * 50) + logger.info(f"Handling Gemini streaming content generation for model: {model_name}") + logger.info(f"Request: \n{request.model_dump_json(indent=2)}") + logger.info(f"Using API key: {api_key}") + + try: + response_stream =chat_service.stream_generate_content( + 
model=model_name, request=request, api_key=api_key ) diff --git a/app/api/openai_routes.py b/app/api/openai_routes.py index 007b403..32e03b7 100644 --- a/app/api/openai_routes.py +++ b/app/api/openai_routes.py @@ -3,9 +3,10 @@ from fastapi import APIRouter, Depends, Header from fastapi.responses import StreamingResponse from app.core.security import SecurityService +from app.services.chat.retry_handler import RetryHandler from app.services.key_manager import KeyManager from app.services.model_service import ModelService -from app.services.chat_service import ChatService +from app.services.openai_chat_service import OpenAIChatService from app.services.embedding_service import EmbeddingService from app.schemas.openai_models import ChatRequest, EmbeddingRequest from app.core.config import settings @@ -31,47 +32,42 @@ async def list_models( logger.info("Handling models list request") api_key = await key_manager.get_next_working_key() logger.info(f"Using API key: {api_key}") - return model_service.get_gemini_openai_models(api_key) + try: + return model_service.get_gemini_openai_models(api_key) + except Exception as e: + logger.error(f"Error getting models list: {str(e)}") + raise HTTPException(status_code=500, detail="Internal server error while fetching models list") from e @router.post("/v1/chat/completions") @router.post("/hf/v1/chat/completions") +@RetryHandler(max_retries=3, key_manager=key_manager, key_arg="api_key") async def chat_completion( request: ChatRequest, authorization: str = Header(None), token: str = Depends(security_service.verify_authorization), + api_key: str = Depends(key_manager.get_next_working_key), ): - chat_service = ChatService(settings.BASE_URL, key_manager) + chat_service = OpenAIChatService(settings.BASE_URL, key_manager) logger.info("-" * 50 + "chat_completion" + "-" * 50) logger.info(f"Handling chat completion request for model: {request.model}") logger.info(f"Request: \n{request.model_dump_json(indent=2)}") - api_key = await 
key_manager.get_next_working_key() logger.info(f"Using API key: {api_key}") - retries = 0 - max_retries = 3 + try: + response = await chat_service.create_chat_completion( + request=request, + api_key=api_key, + ) + # 处理流式响应 + if request.stream: + return StreamingResponse(response, media_type="text/event-stream") + logger.info("Chat completion request successful") + return response - while retries < max_retries: - try: - response = await chat_service.create_chat_completion( - request=request, - api_key=api_key, - ) - - # 处理流式响应 - if request.stream: - return StreamingResponse(response, media_type="text/event-stream") - return response - - except Exception as e: - logger.warning( - f"API call failed with error: {str(e)}. Attempt {retries + 1} of {max_retries}" - ) - api_key = await key_manager.handle_api_failure(api_key) - logger.info(f"Switched to new API key: {api_key}") - retries += 1 - if retries >= max_retries: - logger.error(f"Max retries ({max_retries}) reached. Raising error") - raise + except Exception as e: + logger.error(f"Chat completion failed after retries: {str(e)}") + raise HTTPException(status_code=500, detail="Chat completion failed") from e + @router.post("/v1/embeddings") @@ -93,7 +89,7 @@ async def embedding( return response except Exception as e: logger.error(f"Embedding request failed: {str(e)}") - raise + raise HTTPException(status_code=500, detail="Embedding request failed") from e @router.get("/v1/keys/list") @@ -120,4 +116,4 @@ async def get_keys_list( raise HTTPException( status_code=500, detail="Internal server error while fetching keys list" - ) + ) from e diff --git a/app/core/logger.py b/app/core/logger.py index bc84e23..62f722a 100644 --- a/app/core/logger.py +++ b/app/core/logger.py @@ -125,3 +125,7 @@ def get_embeddings_logger(): def get_request_logger(): return Logger.setup_logger("request") + + +def get_retry_logger(): + return Logger.setup_logger("retry") \ No newline at end of file diff --git a/app/main.py b/app/main.py index 
f2f4d4a..85bc6d9 100644 --- a/app/main.py +++ b/app/main.py @@ -29,6 +29,7 @@ app.add_middleware( # 包含所有路由 app.include_router(openai_routes.router) app.include_router(gemini_routes.router) +app.include_router(gemini_routes.router_v1beta) @app.get("/health") diff --git a/app/services/chat/api_client.py b/app/services/chat/api_client.py new file mode 100644 index 0000000..a27302b --- /dev/null +++ b/app/services/chat/api_client.py @@ -0,0 +1,49 @@ +# app/services/chat/api_client.py + +from typing import Dict, Any, AsyncGenerator +import httpx +from abc import ABC, abstractmethod + +class ApiClient(ABC): + """API客户端基类""" + + @abstractmethod + async def generate_content(self, payload: Dict[str, Any], model: str, api_key: str) -> Dict[str, Any]: + pass + + @abstractmethod + async def stream_generate_content(self, payload: Dict[str, Any], model: str, api_key: str) -> AsyncGenerator[str, None]: + pass + +class GeminiApiClient(ApiClient): + """Gemini API客户端""" + + def __init__(self, base_url: str, timeout: int = 300): + self.base_url = base_url + self.timeout = timeout + + def generate_content(self, payload: Dict[str, Any], model: str, api_key: str) -> Dict[str, Any]: + timeout = httpx.Timeout(self.timeout, read=self.timeout) + if model.endswith("-search"): + model = model[:-7] + with httpx.Client(timeout=timeout) as client: + url = f"{self.base_url}/models/{model}:generateContent?key={api_key}" + response = client.post(url, json=payload) + if response.status_code != 200: + error_content = response.text + raise Exception(f"API call failed with status code {response.status_code}, {error_content}") + return response.json() + + async def stream_generate_content(self, payload: Dict[str, Any], model: str, api_key: str) -> AsyncGenerator[str, None]: + timeout = httpx.Timeout(self.timeout, read=self.timeout) + if model.endswith("-search"): + model = model[:-7] + async with httpx.AsyncClient(timeout=timeout) as client: + url = 
f"{self.base_url}/models/{model}:streamGenerateContent?alt=sse&key={api_key}" + async with client.stream("POST", url, json=payload) as response: + if response.status_code != 200: + error_content = await response.aread() + error_msg = error_content.decode("utf-8") + raise Exception(f"API call failed with status code {response.status_code}, {error_msg}") + async for line in response.aiter_lines(): + yield line diff --git a/app/services/chat/message_converter.py b/app/services/chat/message_converter.py new file mode 100644 index 0000000..17372ae --- /dev/null +++ b/app/services/chat/message_converter.py @@ -0,0 +1,50 @@ +# app/services/chat/message_converter.py + +from abc import ABC, abstractmethod +from typing import List, Dict, Any + +class MessageConverter(ABC): + """消息转换器基类""" + + @abstractmethod + def convert(self, messages: List[Dict[str, Any]]) -> List[Dict[str, Any]]: + pass + +class OpenAIMessageConverter(MessageConverter): + """OpenAI消息格式转换器""" + + def convert(self, messages: List[Dict[str, Any]]) -> List[Dict[str, Any]]: + converted_messages = [] + for msg in messages: + role = "user" if msg["role"] == "user" else "model" + parts = [] + + if isinstance(msg["content"], str): + parts.append({"text": msg["content"]}) + elif isinstance(msg["content"], list): + for content in msg["content"]: + if isinstance(content, str): + parts.append({"text": content}) + elif isinstance(content, dict): + if content["type"] == "text": + parts.append({"text": content["text"]}) + elif content["type"] == "image_url": + parts.append(self._convert_image(content["image_url"]["url"])) + + converted_messages.append({"role": role, "parts": parts}) + + return converted_messages + + def _convert_image(self, image_url: str) -> Dict[str, Any]: + if image_url.startswith("data:image"): + return { + "inline_data": { + "mime_type": "image/jpeg", + "data": image_url.split(",")[1] + } + } + return { + "image_url": { + "url": image_url + } + } \ No newline at end of file diff --git 
a/app/services/chat/response_handler.py b/app/services/chat/response_handler.py new file mode 100644 index 0000000..4c919a3 --- /dev/null +++ b/app/services/chat/response_handler.py @@ -0,0 +1,322 @@ +# app/services/chat/response_handler.py + +from abc import ABC, abstractmethod +from typing import Dict, Any, Optional +import time +import uuid +from app.core.config import settings + +class ResponseHandler(ABC): + """响应处理器基类""" + + @abstractmethod + def handle_response(self, response: Dict[str, Any], model: str, stream: bool = False) -> Dict[str, Any]: + pass + +class GeminiResponseHandler(ResponseHandler): + """Gemini响应处理器""" + + def __init__(self): + self.thinking_first = True + self.thinking_status = False + + def handle_response(self, response: Dict[str, Any], model: str, stream: bool = False) -> Dict[str, Any]: + if stream: + return self._handle_stream_response(response, model, stream) + return self._handle_normal_response(response, model, stream) + + def _handle_stream_response(self, response: Dict[str, Any], model: str, stream: bool) -> Dict[str, Any]: + text = self._extract_text(response, model, stream=stream) + content = {"parts": [{"text": text}],"role": "model"} + response["candidates"][0]["content"] = content + return response + + def _handle_normal_response(self, response: Dict[str, Any], model: str, stream: bool) -> Dict[str, Any]: + text = self._extract_text(response, model, stream=stream) + content = {"parts": [{"text": text}],"role": "model"} + response["candidates"][0]["content"] = content + return response + + def _extract_text(self, response: Dict[str, Any], model: str, stream: bool = False) -> str: + text = "" + if stream: + if response.get("candidates"): + candidate = response["candidates"][0] + content = candidate.get("content", {}) + parts = content.get("parts", []) + if "thinking" in model: + if settings.SHOW_THINKING_PROCESS: + if len(parts) == 1: + if self.thinking_first: + self.thinking_first = False + self.thinking_status = True + text = 
"> thinking\n\n" + parts[0].get("text") + else: + text = parts[0].get("text") + + if len(parts) == 2: + self.thinking_status = False + if self.thinking_first: + self.thinking_first = False + text = ( + "> thinking\n\n" + + parts[0].get("text") + + "\n\n---\n> output\n\n" + + parts[1].get("text") + ) + else: + text = ( + parts[0].get("text") + + "\n\n---\n> output\n\n" + + parts[1].get("text") + ) + else: + if len(parts) == 1: + if self.thinking_first: + self.thinking_first = False + self.thinking_status = True + text = "" + elif self.thinking_status: + text = "" + else: + text = parts[0].get("text") + + if len(parts) == 2: + self.thinking_status = False + if self.thinking_first: + self.thinking_first = False + text = parts[1].get("text") + else: + text = parts[1].get("text") + else: + if "text" in parts[0]: + text = parts[0].get("text") + elif "executableCode" in parts[0]: + text = _format_code_block(parts[0]["executableCode"]) + elif "codeExecution" in parts[0]: + text = _format_code_block(parts[0]["codeExecution"]) + elif "executableCodeResult" in parts[0]: + text = _format_execution_result( + parts[0]["executableCodeResult"] + ) + elif "codeExecutionResult" in parts[0]: + text = _format_execution_result( + parts[0]["codeExecutionResult"] + ) + else: + text = "" + text = _add_search_link_text(model, candidate, text) + else: + if response.get("candidates"): + candidate = response["candidates"][0] + if "thinking" in model: + if settings.SHOW_THINKING_PROCESS: + if len(candidate["content"]["parts"]) == 2: + text = ( + "> thinking\n\n" + + candidate["content"]["parts"][0]["text"] + + "\n\n---\n> output\n\n" + + candidate["content"]["parts"][1]["text"] + ) + else: + text = candidate["content"]["parts"][0]["text"] + else: + if len(candidate["content"]["parts"]) == 2: + text = candidate["content"]["parts"][1]["text"] + else: + text = candidate["content"]["parts"][0]["text"] + else: + text = candidate["content"]["parts"][0]["text"] + text = _add_search_link_text(model, 
candidate, text) + else: + text = "暂无返回" + return text + + + +class OpenAIResponseHandler(ResponseHandler): + """OpenAI响应处理器""" + + def __init__(self, config): + self.config = config + self.thinking_first = True + self.thinking_status = False + + def handle_response( + self, + response: Dict[str, Any], + model: str, + stream: bool = False, + finish_reason: str = None + ) -> Optional[Dict[str, Any]]: + if stream: + return self._handle_stream_response(response, model, finish_reason) + return self._handle_normal_response(response, model, finish_reason) + + def _handle_stream_response(self, response: Dict[str, Any], model: str, finish_reason: str) -> Dict[str, Any]: + text = self._extract_text(response, model, stream=True) + return { + "id": f"chatcmpl-{uuid.uuid4()}", + "object": "chat.completion.chunk", + "created": int(time.time()), + "model": model, + "choices": [{ + "index": 0, + "delta": {"content": text} if text else {}, + "finish_reason": finish_reason + }] + } + + + def _handle_normal_response(self, response: Dict[str, Any], model: str, finish_reason: str) -> Dict[str, Any]: + text = self._extract_text(response, model, stream=False) + return { + "id": f"chatcmpl-{uuid.uuid4()}", + "object": "chat.completion", + "created": int(time.time()), + "model": model, + "choices": [{ + "index": 0, + "message": { + "role": "assistant", + "content": text + }, + "finish_reason": finish_reason + }], + "usage": { + "prompt_tokens": 0, + "completion_tokens": 0, + "total_tokens": 0 + } + } + + def _extract_text(self, response: Dict[str, Any], model: str, stream: bool = False) -> str: + text = "" + if stream: + if response.get("candidates"): + candidate = response["candidates"][0] + content = candidate.get("content", {}) + parts = content.get("parts", []) + if "thinking" in model: + if settings.SHOW_THINKING_PROCESS: + if len(parts) == 1: + if self.thinking_first: + self.thinking_first = False + self.thinking_status = True + text = "> thinking\n\n" + parts[0].get("text") + else: 
+ text = parts[0].get("text") + + if len(parts) == 2: + self.thinking_status = False + if self.thinking_first: + self.thinking_first = False + text = ( + "> thinking\n\n" + + parts[0].get("text") + + "\n\n---\n> output\n\n" + + parts[1].get("text") + ) + else: + text = ( + parts[0].get("text") + + "\n\n---\n> output\n\n" + + parts[1].get("text") + ) + else: + if len(parts) == 1: + if self.thinking_first: + self.thinking_first = False + self.thinking_status = True + text = "" + elif self.thinking_status: + text = "" + else: + text = parts[0].get("text") + + if len(parts) == 2: + self.thinking_status = False + if self.thinking_first: + self.thinking_first = False + text = parts[1].get("text") + else: + text = parts[1].get("text") + else: + if "text" in parts[0]: + text = parts[0].get("text") + elif "executableCode" in parts[0]: + text = _format_code_block(parts[0]["executableCode"]) + elif "codeExecution" in parts[0]: + text = _format_code_block(parts[0]["codeExecution"]) + elif "executableCodeResult" in parts[0]: + text = _format_execution_result( + parts[0]["executableCodeResult"] + ) + elif "codeExecutionResult" in parts[0]: + text = _format_execution_result( + parts[0]["codeExecutionResult"] + ) + else: + text = "" + text = _add_search_link_text(model, candidate, text) + else: + if response.get("candidates"): + candidate = response["candidates"][0] + if "thinking" in model: + if settings.SHOW_THINKING_PROCESS: + if len(candidate["content"]["parts"]) == 2: + text = ( + "> thinking\n\n" + + candidate["content"]["parts"][0]["text"] + + "\n\n---\n> output\n\n" + + candidate["content"]["parts"][1]["text"] + ) + else: + text = candidate["content"]["parts"][0]["text"] + else: + if len(candidate["content"]["parts"]) == 2: + text = candidate["content"]["parts"][1]["text"] + else: + text = candidate["content"]["parts"][0]["text"] + else: + text = candidate["content"]["parts"][0]["text"] + text = _add_search_link_text(model, candidate, text) + else: + text = "暂无返回" + return 
text + + +def _format_code_block(code_data: dict) -> str: + """格式化代码块输出""" + language = code_data.get("language", "").lower() + code = code_data.get("code", "").strip() + return f"""\n\n---\n\n【代码执行】\n```{language}\n{code}\n```\n""" + + +def _add_search_link_text(model:str, candidate:dict, text:str) -> str: + if ( + settings.SHOW_SEARCH_LINK + and model.endswith("-search") + and "groundingMetadata" in candidate + and "groundingChunks" in candidate["groundingMetadata"] + ): + grounding_chunks = candidate["groundingMetadata"]["groundingChunks"] + text += "\n\n---\n\n" + text += "**【引用来源】**\n\n" + for _, grounding_chunk in enumerate(grounding_chunks, 1): + if "web" in grounding_chunk: + text += _create_search_link(grounding_chunk["web"]) + return text + else: + return text + + +def _create_search_link(grounding_chunk: dict) -> str: + return f'\n- [{grounding_chunk["title"]}]({grounding_chunk["uri"]})' + + +def _format_execution_result(result_data: dict) -> str: + """格式化执行结果输出""" + outcome = result_data.get("outcome", "") + output = result_data.get("output", "").strip() + return f"""\n【执行结果】\n> outcome: {outcome}\n\n【输出结果】\n```plaintext\n{output}\n```\n\n---\n\n""" \ No newline at end of file diff --git a/app/services/chat/retry_handler.py b/app/services/chat/retry_handler.py new file mode 100644 index 0000000..a8915aa --- /dev/null +++ b/app/services/chat/retry_handler.py @@ -0,0 +1,40 @@ +# app/services/chat/retry_handler.py + +from typing import TypeVar, Callable +from functools import wraps +from app.core.logger import get_retry_logger +from app.services.key_manager import KeyManager + +T = TypeVar('T') +logger = get_retry_logger() + +class RetryHandler: + """重试处理装饰器""" + + def __init__(self, max_retries: int = 3, key_manager: KeyManager = None, key_arg: str = "api_key"): + self.max_retries = max_retries + self.key_manager = key_manager + self.key_arg = key_arg + + def __call__(self, func: Callable[..., T]) -> Callable[..., T]: + @wraps(func) + async def 
wrapper(*args, **kwargs) -> T: + last_exception = None + + for attempt in range(self.max_retries): + try: + return await func(*args, **kwargs) + except Exception as e: + last_exception = e + logger.warning(f"API call failed with error: {str(e)}. Attempt {attempt + 1} of {self.max_retries}") + + if self.key_manager: + old_key = kwargs.get(self.key_arg) + new_key = await self.key_manager.handle_api_failure(old_key) + kwargs[self.key_arg] = new_key + logger.info(f"Switched to new API key: {new_key}") + + logger.error(f"All retry attempts failed, raising final exception: {str(last_exception)}") + raise last_exception + + return wrapper \ No newline at end of file diff --git a/app/services/chat_service.py b/app/services/chat_service.py deleted file mode 100644 index 63777bb..0000000 --- a/app/services/chat_service.py +++ /dev/null @@ -1,523 +0,0 @@ -import httpx -import json -import time -import uuid -from typing import Dict, Any, Optional, AsyncGenerator, Union -from app.core.config import settings -from app.core.logger import get_chat_logger -from app.schemas.gemini_models import GeminiRequest -from app.schemas.openai_models import ChatRequest - -logger = get_chat_logger() - - -def convert_messages_to_gemini_format(messages: list) -> list: - """Convert OpenAI message format to Gemini format""" - converted_messages = [] - for msg in messages: - role = "user" if msg["role"] == "user" else "model" - parts = [] - - # 处理文本内容 - if isinstance(msg["content"], str): - parts.append({"text": msg["content"]}) - # 处理包含图片的消息 - elif isinstance(msg["content"], list): - for content in msg["content"]: - if isinstance(content, str): - parts.append({"text": content}) - elif isinstance(content, dict) and content["type"] == "text": - parts.append({"text": content["text"]}) - elif isinstance(content, dict) and content["type"] == "image_url": - # 处理图片URL - image_url = content["image_url"]["url"] - if image_url.startswith("data:image"): - # 处理base64图片 - parts.append( - { - "inline_data": { - 
"mime_type": "image/jpeg", - "data": image_url.split(",")[1], - } - } - ) - else: - # 处理普通URL图片 - parts.append( - { - "image_url": { - "url": image_url, - } - } - ) - - converted_messages.append({"role": role, "parts": parts}) - - return converted_messages - - -def format_execution_result(result_data: dict) -> str: - """格式化执行结果输出""" - outcome = result_data.get("outcome", "") - output = result_data.get("output", "").strip() - return f"""\n【执行结果】\n> outcome: {outcome}\n\n【输出结果】\n```plaintext\n{output}\n```\n""" - - -def create_search_link(web): - return f'\n- [{web["title"]}]({web["uri"]})' - - -class ChatService: - def __init__(self, base_url: str, key_manager=None): - self.base_url = base_url - self.key_manager = key_manager - self.thinking_first = True - self.thinking_status = False - - def convert_gemini_response_to_openai( - self, - response: Dict[str, Any], - model: str, - stream: bool = False, - finish_reason: str = None, - ) -> Optional[Dict[str, Any]]: - """Convert Gemini response to OpenAI format""" - if stream: - try: - text = "" - if response.get("candidates"): - candidate = response["candidates"][0] - content = candidate.get("content", {}) - parts = content.get("parts", []) - - if "thinking" in model: - if settings.SHOW_THINKING_PROCESS: - if len(parts) == 1: - if self.thinking_first: - self.thinking_first = False - self.thinking_status = True - text = "> thinking\n\n" + parts[0].get("text") - else: - text = parts[0].get("text") - - if len(parts) == 2: - self.thinking_status = False - if self.thinking_first: - self.thinking_first = False - text = ( - "> thinking\n\n" - + parts[0].get("text") - + "\n\n---\n> output\n\n" - + parts[1].get("text") - ) - else: - text = ( - parts[0].get("text") - + "\n\n---\n> output\n\n" - + parts[1].get("text") - ) - else: - if len(parts) == 1: - if self.thinking_first: - self.thinking_first = False - self.thinking_status = True - text = "" - elif self.thinking_status: - text = "" - else: - text = parts[0].get("text") - - if 
len(parts) == 2: - self.thinking_status = False - if self.thinking_first: - self.thinking_first = False - text = parts[1].get("text") - else: - text = parts[1].get("text") - else: - if "text" in parts[0]: - text = parts[0].get("text") - elif "executableCode" in parts[0]: - text = self.format_code_block(parts[0]["executableCode"]) - elif "codeExecution" in parts[0]: - text = self.format_code_block(parts[0]["codeExecution"]) - elif "executableCodeResult" in parts[0]: - text = format_execution_result( - parts[0]["executableCodeResult"] - ) - elif "codeExecutionResult" in parts[0]: - text = format_execution_result( - parts[0]["codeExecutionResult"] - ) - else: - text = "" - - text = self.add_search_link_text(model, candidate, text) - - return { - "id": f"chatcmpl-{uuid.uuid4()}", - "object": "chat.completion.chunk", - "created": int(time.time()), - "model": model, - "choices": [ - { - "index": 0, - "delta": {"content": text} if text else {}, - "finish_reason": finish_reason, - } - ], - } - except Exception as e: - logger.error(f"Error converting Gemini response: {str(e)}") - logger.debug(f"Raw response: {response}") - return None - else: - res = { - "id": f"chatcmpl-{uuid.uuid4()}", - "object": "chat.completion", - "created": int(time.time()), - "model": model, - "choices": [ - { - "index": 0, - "message": { - "role": "assistant", - "content": response["candidates"][0]["content"]["parts"][0]["text"], - }, - "finish_reason": finish_reason, - } - ], - "usage": { - "prompt_tokens": 0, - "completion_tokens": 0, - "total_tokens": 0, - }, - } - try: - if response.get("candidates"): - candidate = response["candidates"][0] - if "thinking" in model: - if settings.SHOW_THINKING_PROCESS: - if len(candidate["content"]["parts"]) == 2: - text = ( - "> thinking\n\n" - + candidate["content"]["parts"][0]["text"] - + "\n\n---\n> output\n\n" - + candidate["content"]["parts"][1]["text"] - ) - else: - text = candidate["content"]["parts"][0]["text"] - else: - if 
len(candidate["content"]["parts"]) == 2: - text = candidate["content"]["parts"][1]["text"] - else: - text = candidate["content"]["parts"][0]["text"] - else: - text = candidate["content"]["parts"][0]["text"] - - text = self.add_search_link_text(model, candidate, text) - res["choices"][0]["message"]["content"] = text - return res - else: - res["choices"][0]["message"]["content"] = "暂无返回" - return res - except Exception as e: - logger.error(f"Error converting Gemini response: {str(e)}") - logger.debug(f"Raw response: {response}") - res["choices"][0]["message"][ - "content" - ] = f"Error converting Gemini response: {str(e)}" - return res - - def add_search_link_text(self, model, candidate, text): - if ( - settings.SHOW_SEARCH_LINK - and model.endswith("-search") - and "groundingMetadata" in candidate - and "groundingChunks" in candidate["groundingMetadata"] - ): - grounding_chunks = candidate["groundingMetadata"]["groundingChunks"] - text += "\n\n---\n\n" - text += "**【引用来源】**\n\n" - for _, grounding_chunk in enumerate(grounding_chunks, 1): - if "web" in grounding_chunk: - text += create_search_link(grounding_chunk["web"]) - return text - else: - return text - - async def create_chat_completion( - self, - request: ChatRequest, - api_key: str, - ) -> Union[Dict[str, Any], AsyncGenerator[str, None]]: - """Create chat completion using either Gemini or OpenAI API""" - model = request.model - tools = request.tools - if tools is None: - tools = [] - if settings.TOOLS_CODE_EXECUTION_ENABLED and not ( - model.endswith("-search") or "-thinking" in model - ): - tools.append({"code_execution": {}}) - if model.endswith("-search"): - tools.append({"googleSearch": {}}) - return await self._gemini_chat_completion(request, api_key, tools) - - async def _gemini_chat_completion( - self, - request: ChatRequest, - api_key: str, - tools: Optional[list] = None, - ) -> Union[Dict[str, Any], AsyncGenerator[str, None]]: - """Handle Gemini API chat completion""" - model = request.model - 
messages = request.messages - temperature = request.temperature - stream = request.stream - max_tokens = request.max_tokens - stop = request.stop - top_p = request.top_p - top_k = request.top_k - if model.endswith("-search"): - gemini_model = model[:-7] # Remove -search suffix - else: - gemini_model = model - gemini_messages = convert_messages_to_gemini_format(messages) - - if not stream: - # 非流式模式下,移除代码执行工具 - if {"code_execution": {}} in tools: - tools.remove({"code_execution": {}}) - payload = { - "contents": gemini_messages, - "generationConfig": { - "temperature": temperature, - "maxOutputTokens": max_tokens, - "stopSequences": stop, - "topP": top_p, - "topK": top_k, - }, - "tools": tools, - "safetySettings": [ - {"category": "HARM_CATEGORY_HARASSMENT", "threshold": "BLOCK_NONE"}, - {"category": "HARM_CATEGORY_HATE_SPEECH", "threshold": "BLOCK_NONE"}, - { - "category": "HARM_CATEGORY_SEXUALLY_EXPLICIT", - "threshold": "BLOCK_NONE", - }, - { - "category": "HARM_CATEGORY_DANGEROUS_CONTENT", - "threshold": "BLOCK_NONE", - }, - { - "category": "HARM_CATEGORY_CIVIC_INTEGRITY", - "threshold": "BLOCK_NONE", - }, - ], - } - - if stream: - - async def generate(): - retries = 0 - max_retries = 3 - current_api_key = api_key - - while retries < max_retries: - try: - timeout = httpx.Timeout( - 300.0, read=300.0 - ) # 连接超时300秒,读取超时300秒 - async with httpx.AsyncClient(timeout=timeout) as async_client: - stream_url = f"https://generativelanguage.googleapis.com/v1beta/models/{gemini_model}:streamGenerateContent?alt=sse&key={current_api_key}" - async with async_client.stream( - "POST", stream_url, json=payload - ) as async_response: - if async_response.status_code != 200: - error_content = await async_response.aread() - error_msg = error_content.decode("utf-8") - logger.error( - f"API error: {async_response.status_code}, {error_msg}" - ) - if retries < max_retries - 1: - current_api_key = ( - await self.key_manager.handle_api_failure( - current_api_key - ) - ) - retries += 1 - 
continue - else: - logger.error( - f"Max retries reached. Final error: {async_response.status_code}, {error_msg}" - ) - yield f"data: {json.dumps({'error': f'API error: {async_response.status_code}, {error_msg}'})}\n\n" - return - - async for line in async_response.aiter_lines(): - if line.startswith("data: "): - try: - chunk = json.loads(line[6:]) - openai_chunk = ( - self.convert_gemini_response_to_openai( - chunk, - model, - stream=True, - finish_reason=None, - ) - ) - if openai_chunk: - yield f"data: {json.dumps(openai_chunk)}\n\n" - except json.JSONDecodeError: - continue - yield f"data: {json.dumps(self.convert_gemini_response_to_openai({}, model, stream=True, finish_reason='stop'))}\n\n" - yield "data: [DONE]\n\n" - return - - except httpx.ReadTimeout: - logger.warning( - f"Read timeout occurred, attempting retry {retries + 1}" - ) - if retries < max_retries - 1: - current_api_key = await self.key_manager.handle_api_failure( - current_api_key - ) - logger.info(f"Switched to new API key: {current_api_key}") - retries += 1 - continue - else: - logger.error( - f"Max retries reached. Final error: Read timeout" - ) - yield f"data: {json.dumps({'error': 'Read timeout'})}\n\n" - return - except Exception as e: - logger.exception( - f"Stream error: {str(e)}, attempting retry {retries + 1}" - ) - if retries < max_retries - 1: - current_api_key = await self.key_manager.handle_api_failure( - current_api_key - ) - logger.info(f"Switched to new API key: {current_api_key}") - retries += 1 - continue - else: - logger.error(f"Max retries reached. 
Final error: {e}") - yield f"data: {json.dumps({'error': str(e)})}\n\n" - return - - return generate() - else: - try: - timeout = httpx.Timeout( - 300.0, read=300.0 - ) # 连接超时300秒,读取超时300秒 - async with httpx.AsyncClient(timeout=timeout) as client: - url = f"https://generativelanguage.googleapis.com/v1beta/models/{gemini_model}:generateContent?key={api_key}" - response = await client.post(url, json=payload) - if response.status_code != 200: - error_text = response.text - error_code = response.status_code - raise Exception( - f"API调用错误 - 状态码: {error_code}, 响应内容: {error_text}" - ) - gemini_response = response.json() - return self.convert_gemini_response_to_openai( - gemini_response, model, stream=False, finish_reason="stop" - ) - except Exception as e: - logger.error(f"Error in non-stream completion") - raise - - def format_code_block(self, code_data: dict) -> str: - """格式化代码块输出""" - language = code_data.get("language", "").lower() - code = code_data.get("code", "").strip() - - return f"""\n【代码执行】\n```{language}\n{code}\n```\n""" - - async def generate_content( - self, model_name: str, request: GeminiRequest, api_key: str - ) -> dict: - """调用Gemini API生成内容""" - url = f"{self.base_url}/models/{model_name}:generateContent?key={api_key}" - - timeout = httpx.Timeout(300.0, read=300.0) # 连接超时300秒,读取超时300秒 - async with httpx.AsyncClient(timeout=timeout) as client: - try: - response = await client.post(url, json=request.model_dump()) - if response.status_code == 200: - return response.json() - else: - error_text = response.text - logger.error(f"Error: {response.status_code}") - logger.error(error_text) - raise Exception( - f"API request failed with status {response.status_code}: {error_text}" - ) - except Exception as e: - logger.error(f"Request failed: {str(e)}") - raise - - async def stream_generate_content( - self, model_name: str, request: GeminiRequest, api_key: str - ) -> AsyncGenerator: - """调用Gemini API流式生成内容""" - retries = 0 - MAX_RETRIES = 3 - current_api_key = 
api_key - - while retries < MAX_RETRIES: - try: - url = f"{self.base_url}/models/{model_name}:streamGenerateContent?alt=sse&key={current_api_key}" - timeout = httpx.Timeout(300.0, read=300.0) - - async with httpx.AsyncClient(timeout=timeout) as client: - async with client.stream( - "POST", url, json=request.model_dump() - ) as response: - if response.status_code != 200: - error_text = await response.text() - logger.error(f"Error: {response.status_code}: {error_text}") - if retries < MAX_RETRIES - 1: - current_api_key = ( - await self.key_manager.handle_api_failure( - current_api_key - ) - ) - logger.info( - f"Switched to new API key: {current_api_key}" - ) - retries += 1 - continue - raise Exception( - f"API request failed with status {response.status_code}: {error_text}" - ) - - async for line in response.aiter_lines(): - yield line + "\n\n" - return - - except httpx.ReadTimeout: - logger.warning(f"Read timeout occurred, attempting retry {retries + 1}") - if retries < MAX_RETRIES - 1: - current_api_key = await self.key_manager.handle_api_failure( - current_api_key - ) - logger.info(f"Switched to new API key: {current_api_key}") - retries += 1 - continue - raise - - except Exception as e: - logger.error(f"Streaming request failed: {str(e)}") - if retries < MAX_RETRIES - 1: - current_api_key = await self.key_manager.handle_api_failure( - current_api_key - ) - logger.info(f"Switched to new API key: {current_api_key}") - retries += 1 - continue - raise diff --git a/app/services/gemini_chat_service.py b/app/services/gemini_chat_service.py new file mode 100644 index 0000000..0ee8772 --- /dev/null +++ b/app/services/gemini_chat_service.py @@ -0,0 +1,89 @@ +# app/services/chat_service.py + +import json +from typing import Dict, Any, AsyncGenerator, List +from app.core.logger import get_gemini_logger +from app.services.chat.api_client import GeminiApiClient +from app.schemas.gemini_models import GeminiRequest +from app.core.config import settings +from 
app.services.chat.response_handler import GeminiResponseHandler +from app.services.key_manager import KeyManager + +logger = get_gemini_logger() +class GeminiChatService: + """聊天服务""" + + def __init__(self, base_url: str, key_manager: KeyManager): + self.api_client = GeminiApiClient(base_url) + self.key_manager = key_manager + self.response_handler = GeminiResponseHandler() + + def generate_content(self, model: str, request: GeminiRequest, api_key: str) -> Dict[str, Any]: + """生成内容""" + payload = self._build_payload(model, request) + response = self.api_client.generate_content(payload, model, api_key) + return self.response_handler.handle_response(response, model, stream=False) + + async def stream_generate_content(self, model: str, request: GeminiRequest, api_key: str) -> AsyncGenerator[str, None]: + """流式生成内容""" + retries = 0 + max_retries = 3 + payload = self._build_payload(model, request) + while retries < max_retries: + try: + async for line in self.api_client.stream_generate_content(payload, model, api_key): + if line.startswith("data:"): + line = line[6:] + line = json.dumps(self.response_handler.handle_response(json.loads(line), model, stream=True)) + yield "data: " + line + "\n\n" + logger.info("Streaming completed successfully") + break + except Exception as e: + retries += 1 + logger.warning(f"Streaming API call failed with error: {str(e)}. Attempt {retries} of {max_retries}") + api_key = await self.key_manager.handle_api_failure(api_key) + logger.info(f"Switched to new API key: {api_key}") + if retries >= max_retries: + logger.error(f"Max retries ({max_retries}) reached for streaming. 
Raising error") + break + + def _build_payload(self,model: str, request: GeminiRequest) -> Dict[str, Any]: + """构建请求payload""" + payload = request.model_dump() + return { + "contents": payload.get("contents", []), + "tools": self._build_tools(model, payload), + "safetySettings": self._get_safety_settings(), + "generationConfig": payload.get("generationConfig", {}), + "systemInstruction": payload.get("systemInstruction", []) + } + + def _build_tools(self, model: str, payload: Dict[str, Any]) -> List[Dict[str, Any]]: + """构建工具""" + tools = [] + if settings.TOOLS_CODE_EXECUTION_ENABLED and not ( + model.endswith("-search") or "-thinking" in model + ) and not self._has_image_parts(payload.get("contents", [])): + tools.append({"code_execution": {}}) + if model.endswith("-search"): + tools.append({"googleSearch": {}}) + return tools + + def _has_image_parts(self, contents: List[Dict[str, Any]]) -> bool: + """判断消息是否包含图片部分""" + for content in contents: + if "parts" in content: + for part in content["parts"]: + if "image_url" in part or "inline_data" in part: + return True + return False + + def _get_safety_settings(self) -> List[Dict[str, str]]: + """获取安全设置""" + return [ + {"category": "HARM_CATEGORY_HARASSMENT", "threshold": "BLOCK_NONE"}, + {"category": "HARM_CATEGORY_HATE_SPEECH", "threshold": "BLOCK_NONE"}, + {"category": "HARM_CATEGORY_SEXUALLY_EXPLICIT", "threshold": "BLOCK_NONE"}, + {"category": "HARM_CATEGORY_DANGEROUS_CONTENT", "threshold": "BLOCK_NONE"}, + {"category": "HARM_CATEGORY_CIVIC_INTEGRITY", "threshold": "BLOCK_NONE"} + ] \ No newline at end of file diff --git a/app/services/openai_chat_service.py b/app/services/openai_chat_service.py new file mode 100644 index 0000000..ca67bae --- /dev/null +++ b/app/services/openai_chat_service.py @@ -0,0 +1,136 @@ +# app/services/chat_service.py + +import json +from typing import Dict, Any, AsyncGenerator, List, Union +from app.core.logger import get_openai_logger +from app.services.chat.message_converter import 
OpenAIMessageConverter +from app.services.chat.response_handler import OpenAIResponseHandler +from app.services.chat.api_client import GeminiApiClient +from app.schemas.openai_models import ChatRequest +from app.core.config import settings +from app.services.key_manager import KeyManager + +logger = get_openai_logger() +class OpenAIChatService: + """聊天服务""" + + def __init__(self, base_url: str, key_manager: KeyManager): + self.message_converter = OpenAIMessageConverter() + self.response_handler = OpenAIResponseHandler(config=None) + self.api_client = GeminiApiClient(base_url) + self.key_manager = key_manager + + async def create_chat_completion( + self, + request: ChatRequest, + api_key: str, + ) -> Union[Dict[str, Any], AsyncGenerator[str, None]]: + """创建聊天完成""" + # 转换消息格式 + messages = self.message_converter.convert(request.messages) + + # 构建请求payload + payload = self._build_payload(request, messages) + + if request.stream: + return self._handle_stream_completion(request.model, payload, api_key) + return self._handle_normal_completion(request.model, payload, api_key) + + def _handle_normal_completion( + self, + model: str, + payload: Dict[str, Any], + api_key: str + ) -> Dict[str, Any]: + """处理普通聊天完成""" + response = self.api_client.generate_content(payload, model, api_key) + return self.response_handler.handle_response( + response, + model, + stream=False, + finish_reason="stop" + ) + + async def _handle_stream_completion( + self, + model: str, + payload: Dict[str, Any], + api_key: str + ) -> AsyncGenerator[str, None]: + """处理流式聊天完成,添加重试逻辑""" + retries = 0 + max_retries = 3 + while retries < max_retries: + try: + async for line in self.api_client.stream_generate_content(payload, model, api_key): + if line.startswith("data:"): + chunk = json.loads(line[6:]) + openai_chunk = self.response_handler.handle_response( + chunk, + model, + stream=True, + finish_reason=None + ) + if openai_chunk: + yield f"data: {json.dumps(openai_chunk)}\n\n" + yield f"data: 
{json.dumps(self.response_handler.handle_response({}, model, stream=True, finish_reason='stop'))}\n\n" + yield "data: [DONE]\n\n" + logger.info("Streaming completed successfully") + break # 成功后退出循环 + except Exception as e: + retries += 1 + logger.warning(f"Streaming API call failed with error: {str(e)}. Attempt {retries} of {max_retries}") + api_key = await self.key_manager.handle_api_failure(api_key) + logger.info(f"Switched to new API key: {api_key}") + if retries >= max_retries: + logger.error(f"Max retries ({max_retries}) reached for streaming. Raising error") + yield f"data: {json.dumps({'error': 'Streaming failed after retries'})}\n\n" + yield "data: [DONE]\n\n" + break + + def _build_payload(self, request: ChatRequest, messages: List[Dict[str, Any]]) -> Dict[str, Any]: + """构建请求payload""" + return { + "contents": messages, + "generationConfig": { + "temperature": request.temperature, + "maxOutputTokens": request.max_tokens, + "stopSequences": request.stop, + "topP": request.top_p, + "topK": request.top_k + }, + "tools": self._build_tools(request, messages), + "safetySettings": self._get_safety_settings() + } + + def _build_tools(self, request: ChatRequest, messages: List[Dict[str, Any]]) -> List[Dict[str, Any]]: + """构建工具""" + tools = [] + model = request.model + + if settings.TOOLS_CODE_EXECUTION_ENABLED and not ( + model.endswith("-search") or "-thinking" in model + ) and not self._has_image_parts(messages): + tools.append({"code_execution": {}}) + if model.endswith("-search"): + tools.append({"googleSearch": {}}) + return tools + + def _has_image_parts(self, contents: List[Dict[str, Any]]) -> bool: + """判断消息是否包含图片部分""" + for content in contents: + if "parts" in content: + for part in content["parts"]: + if "image_url" in part or "inline_data" in part: + return True + return False + + def _get_safety_settings(self) -> List[Dict[str, str]]: + """获取安全设置""" + return [ + {"category": "HARM_CATEGORY_HARASSMENT", "threshold": "BLOCK_NONE"}, + {"category": 
"HARM_CATEGORY_HATE_SPEECH", "threshold": "BLOCK_NONE"}, + {"category": "HARM_CATEGORY_SEXUALLY_EXPLICIT", "threshold": "BLOCK_NONE"}, + {"category": "HARM_CATEGORY_DANGEROUS_CONTENT", "threshold": "BLOCK_NONE"}, + {"category": "HARM_CATEGORY_CIVIC_INTEGRITY", "threshold": "BLOCK_NONE"} + ] \ No newline at end of file diff --git a/app/utils/helpers.py b/app/utils/helpers.py deleted file mode 100644 index e69de29..0000000