feat: Add retry mechanism and message converters, and support the Gemini v1beta API

yinpeng
2024-12-27 20:07:43 +08:00
parent 6e90463251
commit 870b1ecc17
12 changed files with 755 additions and 597 deletions

View File

@@ -1,3 +1,4 @@
from fastapi import APIRouter, Depends, HTTPException
from fastapi.responses import StreamingResponse
@@ -5,18 +6,18 @@ from app.core.config import settings
from app.core.logger import get_gemini_logger
from app.core.security import SecurityService
from app.schemas.gemini_models import GeminiRequest
from app.services.chat_service import ChatService
from app.services.gemini_chat_service import GeminiChatService
from app.services.key_manager import KeyManager
from app.services.model_service import ModelService
from app.services.chat.retry_handler import RetryHandler
router = APIRouter(prefix="/gemini/v1beta")
router_v1beta = APIRouter(prefix="/v1beta")
logger = get_gemini_logger()
# Initialize services
security_service = SecurityService(settings.ALLOWED_TOKENS, settings.AUTH_TOKEN)
key_manager = KeyManager(settings.API_KEYS)
model_service = ModelService(settings.MODEL_SEARCH)
chat_service = ChatService(base_url=settings.BASE_URL, key_manager=key_manager)
@router.get("/models")
@@ -34,58 +35,51 @@ async def list_models(
return models_json
@router.post("/models/{model_name}:generateContent")
@RetryHandler(max_retries=3, key_manager=key_manager, key_arg="api_key")
async def generate_content(
model_name: str,
request: GeminiRequest,
x_goog_api_key: str = Depends(security_service.verify_goog_api_key),
# x_goog_api_key: str = Depends(security_service.verify_goog_api_key),
api_key: str = Depends(key_manager.get_next_working_key),
):
"""Non-streaming content generation"""
chat_service = GeminiChatService(settings.BASE_URL, key_manager)
logger.info("-" * 50 + "gemini_generate_content" + "-" * 50)
logger.info(f"Handling Gemini content generation request for model: {model_name}")
logger.info(f"Request: \n{request.model_dump_json(indent=2)}")
api_key = await key_manager.get_next_working_key()
logger.info(f"Using API key: {api_key}")
retries = 0
MAX_RETRIES = 3
while retries < MAX_RETRIES:
try:
response = await chat_service.generate_content(
model_name=model_name,
request=request,
api_key=api_key
)
return response
except Exception as e:
logger.warning(
f"API call failed with error: {str(e)}. Attempt {retries + 1} of {MAX_RETRIES}"
)
api_key = await key_manager.handle_api_failure(api_key)
logger.info(f"Switched to new API key: {api_key}")
retries += 1
if retries >= MAX_RETRIES:
logger.error(f"Max retries ({MAX_RETRIES}) reached. Raising error")
@router.post("/models/{model_name}:streamGenerateContent")
async def stream_generate_content(
model_name: str,
request: GeminiRequest,
x_goog_api_key: str = Depends(security_service.verify_goog_api_key),
):
"""流式生成内容"""
logger.info("-" * 50 + "gemini_stream_generate_content" + "-" * 50)
logger.info(f"Handling Gemini streaming content generation for model: {model_name}")
api_key = await key_manager.get_next_working_key()
logger.info(f"Using API key: {api_key}")
try:
chat_service = ChatService(base_url=settings.BASE_URL, key_manager=key_manager)
response_stream = chat_service.stream_generate_content(
model_name=model_name,
response = chat_service.generate_content(
model=model_name,
request=request,
api_key=api_key
)
return response
except Exception as e:
logger.error(f"Chat completion failed after retries: {str(e)}")
raise HTTPException(status_code=500, detail="Chat completion failed") from e
@router.post("/models/{model_name}:streamGenerateContent")
@RetryHandler(max_retries=3, key_manager=key_manager, key_arg="api_key")
async def stream_generate_content(
model_name: str,
request: GeminiRequest,
# x_goog_api_key: str = Depends(security_service.verify_goog_api_key),
api_key: str = Depends(key_manager.get_next_working_key),
):
"""Streaming content generation"""
chat_service = GeminiChatService(settings.BASE_URL, key_manager)
logger.info("-" * 50 + "gemini_stream_generate_content" + "-" * 50)
logger.info(f"Handling Gemini streaming content generation for model: {model_name}")
logger.info(f"Request: \n{request.model_dump_json(indent=2)}")
logger.info(f"Using API key: {api_key}")
try:
response_stream = chat_service.stream_generate_content(
model=model_name,
request=request,
api_key=api_key
)
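For context, a minimal client-side sketch of calling the new v1beta route follows. The host, port, and model name are assumptions for illustration, not part of the commit; in this revision the goog-api-key check is commented out and upstream keys come from the server's KeyManager, so a bare POST suffices.

import httpx

payload = {"contents": [{"role": "user", "parts": [{"text": "Hello"}]}]}
resp = httpx.post(
    "http://localhost:8000/v1beta/models/gemini-1.5-flash:generateContent",
    json=payload,
    timeout=30.0,
)
resp.raise_for_status()
print(resp.json()["candidates"][0]["content"]["parts"][0]["text"])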

View File

@@ -3,9 +3,10 @@ from fastapi import APIRouter, Depends, Header
from fastapi.responses import StreamingResponse
from app.core.security import SecurityService
from app.services.chat.retry_handler import RetryHandler
from app.services.key_manager import KeyManager
from app.services.model_service import ModelService
from app.services.chat_service import ChatService
from app.services.openai_chat_service import OpenAIChatService
from app.services.embedding_service import EmbeddingService
from app.schemas.openai_models import ChatRequest, EmbeddingRequest
from app.core.config import settings
@@ -31,47 +32,42 @@ async def list_models(
logger.info("Handling models list request")
api_key = await key_manager.get_next_working_key()
logger.info(f"Using API key: {api_key}")
return model_service.get_gemini_openai_models(api_key)
try:
return model_service.get_gemini_openai_models(api_key)
except Exception as e:
logger.error(f"Error getting models list: {str(e)}")
raise HTTPException(status_code=500, detail="Internal server error while fetching models list") from e
@router.post("/v1/chat/completions")
@router.post("/hf/v1/chat/completions")
@RetryHandler(max_retries=3, key_manager=key_manager, key_arg="api_key")
async def chat_completion(
request: ChatRequest,
authorization: str = Header(None),
token: str = Depends(security_service.verify_authorization),
api_key: str = Depends(key_manager.get_next_working_key),
):
chat_service = ChatService(settings.BASE_URL, key_manager)
chat_service = OpenAIChatService(settings.BASE_URL, key_manager)
logger.info("-" * 50 + "chat_completion" + "-" * 50)
logger.info(f"Handling chat completion request for model: {request.model}")
logger.info(f"Request: \n{request.model_dump_json(indent=2)}")
api_key = await key_manager.get_next_working_key()
logger.info(f"Using API key: {api_key}")
retries = 0
max_retries = 3
try:
response = await chat_service.create_chat_completion(
request=request,
api_key=api_key,
)
# Handle streaming response
if request.stream:
return StreamingResponse(response, media_type="text/event-stream")
logger.info("Chat completion request successful")
return response
while retries < max_retries:
try:
response = await chat_service.create_chat_completion(
request=request,
api_key=api_key,
)
# Handle streaming response
if request.stream:
return StreamingResponse(response, media_type="text/event-stream")
return response
except Exception as e:
logger.warning(
f"API call failed with error: {str(e)}. Attempt {retries + 1} of {max_retries}"
)
api_key = await key_manager.handle_api_failure(api_key)
logger.info(f"Switched to new API key: {api_key}")
retries += 1
if retries >= max_retries:
logger.error(f"Max retries ({max_retries}) reached. Raising error")
raise
except Exception as e:
logger.error(f"Chat completion failed after retries: {str(e)}")
raise HTTPException(status_code=500, detail="Chat completion failed") from e
@router.post("/v1/embeddings")
@@ -93,7 +89,7 @@ async def embedding(
return response
except Exception as e:
logger.error(f"Embedding request failed: {str(e)}")
raise
raise HTTPException(status_code=500, detail="Embedding request failed") from e
@router.get("/v1/keys/list")
@@ -120,4 +116,4 @@ async def get_keys_list(
raise HTTPException(
status_code=500,
detail="Internal server error while fetching keys list"
)
) from e
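Since these routes keep the OpenAI wire format, a standard OpenAI client can point at them. A minimal sketch, assuming the service runs locally and that the key passed is one of the configured ALLOWED_TOKENS:

from openai import OpenAI  # assumes the openai>=1.x client package

client = OpenAI(base_url="http://localhost:8000/v1", api_key="ALLOWED_TOKEN")
completion = client.chat.completions.create(
    model="gemini-1.5-flash",
    messages=[{"role": "user", "content": "Hello"}],
)
print(completion.choices[0].message.content)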

View File

@@ -125,3 +125,7 @@ def get_embeddings_logger():
def get_request_logger():
return Logger.setup_logger("request")
def get_retry_logger():
return Logger.setup_logger("retry")

View File

@@ -29,6 +29,7 @@ app.add_middleware(
# Include all routers
app.include_router(openai_routes.router)
app.include_router(gemini_routes.router)
app.include_router(gemini_routes.router_v1beta)
@app.get("/health")

View File

@@ -0,0 +1,49 @@
# app/services/chat/api_client.py
from typing import Dict, Any, AsyncGenerator
import httpx
from abc import ABC, abstractmethod
class ApiClient(ABC):
"""API客户端基类"""
@abstractmethod
def generate_content(self, payload: Dict[str, Any], model: str, api_key: str) -> Dict[str, Any]:
pass
@abstractmethod
async def stream_generate_content(self, payload: Dict[str, Any], model: str, api_key: str) -> AsyncGenerator[str, None]:
pass
class GeminiApiClient(ApiClient):
"""Gemini API客户端"""
def __init__(self, base_url: str, timeout: int = 300):
self.base_url = base_url
self.timeout = timeout
def generate_content(self, payload: Dict[str, Any], model: str, api_key: str) -> Dict[str, Any]:
timeout = httpx.Timeout(self.timeout, read=self.timeout)
if model.endswith("-search"):
model = model[:-7]  # strip the "-search" suffix before calling the upstream API
with httpx.Client(timeout=timeout) as client:
url = f"{self.base_url}/models/{model}:generateContent?key={api_key}"
response = client.post(url, json=payload)
if response.status_code != 200:
error_content = response.text
raise Exception(f"API call failed with status code {response.status_code}, {error_content}")
return response.json()
async def stream_generate_content(self, payload: Dict[str, Any], model: str, api_key: str) -> AsyncGenerator[str, None]:
timeout = httpx.Timeout(self.timeout, read=self.timeout)
if model.endswith("-search"):
model = model[:-7]  # strip the "-search" suffix before calling the upstream API
async with httpx.AsyncClient(timeout=timeout) as client:
url = f"{self.base_url}/models/{model}:streamGenerateContent?alt=sse&key={api_key}"
async with client.stream("POST", url, json=payload) as response:
if response.status_code != 200:
error_content = await response.aread()
error_msg = error_content.decode("utf-8")
raise Exception(f"API call failed with status code {response.status_code}, {error_msg}")
async for line in response.aiter_lines():
yield line
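A minimal usage sketch of the streaming client, assuming the public Google endpoint as base_url and a placeholder key:

import asyncio

async def demo():
    client = GeminiApiClient("https://generativelanguage.googleapis.com/v1beta")
    payload = {"contents": [{"role": "user", "parts": [{"text": "Hi"}]}]}
    # Each yielded line is a raw SSE line such as 'data: {...}'
    async for line in client.stream_generate_content(payload, "gemini-1.5-flash", "YOUR_KEY"):
        print(line)

asyncio.run(demo())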

View File

@@ -0,0 +1,50 @@
# app/services/chat/message_converter.py
from abc import ABC, abstractmethod
from typing import List, Dict, Any
class MessageConverter(ABC):
"""消息转换器基类"""
@abstractmethod
def convert(self, messages: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
pass
class OpenAIMessageConverter(MessageConverter):
"""OpenAI消息格式转换器"""
def convert(self, messages: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
converted_messages = []
for msg in messages:
role = "user" if msg["role"] == "user" else "model"
parts = []
if isinstance(msg["content"], str):
parts.append({"text": msg["content"]})
elif isinstance(msg["content"], list):
for content in msg["content"]:
if isinstance(content, str):
parts.append({"text": content})
elif isinstance(content, dict):
if content["type"] == "text":
parts.append({"text": content["text"]})
elif content["type"] == "image_url":
parts.append(self._convert_image(content["image_url"]["url"]))
converted_messages.append({"role": role, "parts": parts})
return converted_messages
def _convert_image(self, image_url: str) -> Dict[str, Any]:
if image_url.startswith("data:image"):
return {
"inline_data": {
"mime_type": "image/jpeg",
"data": image_url.split(",")[1]
}
}
return {
"image_url": {
"url": image_url
}
}
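A short round-trip example of the converter; the input shapes follow the OpenAI chat format that the convert() branches expect:

converter = OpenAIMessageConverter()
gemini_messages = converter.convert([
    {"role": "user", "content": "Describe this image"},
    {"role": "assistant", "content": [{"type": "text", "text": "It is a cat."}]},
])
# Result:
# [{"role": "user", "parts": [{"text": "Describe this image"}]},
#  {"role": "model", "parts": [{"text": "It is a cat."}]}]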

View File

@@ -0,0 +1,322 @@
# app/services/chat/response_handler.py
from abc import ABC, abstractmethod
from typing import Dict, Any, Optional
import time
import uuid
from app.core.config import settings
class ResponseHandler(ABC):
"""响应处理器基类"""
@abstractmethod
def handle_response(self, response: Dict[str, Any], model: str, stream: bool = False) -> Dict[str, Any]:
pass
class GeminiResponseHandler(ResponseHandler):
"""Gemini响应处理器"""
def __init__(self):
self.thinking_first = True
self.thinking_status = False
def handle_response(self, response: Dict[str, Any], model: str, stream: bool = False) -> Dict[str, Any]:
if stream:
return self._handle_stream_response(response, model, stream)
return self._handle_normal_response(response, model, stream)
def _handle_stream_response(self, response: Dict[str, Any], model: str, stream: bool) -> Dict[str, Any]:
text = self._extract_text(response, model, stream=stream)
content = {"parts": [{"text": text}],"role": "model"}
response["candidates"][0]["content"] = content
return response
def _handle_normal_response(self, response: Dict[str, Any], model: str, stream: bool) -> Dict[str, Any]:
text = self._extract_text(response, model, stream=stream)
content = {"parts": [{"text": text}],"role": "model"}
response["candidates"][0]["content"] = content
return response
def _extract_text(self, response: Dict[str, Any], model: str, stream: bool = False) -> str:
text = ""
if stream:
if response.get("candidates"):
candidate = response["candidates"][0]
content = candidate.get("content", {})
parts = content.get("parts", [])
if "thinking" in model:
if settings.SHOW_THINKING_PROCESS:
if len(parts) == 1:
if self.thinking_first:
self.thinking_first = False
self.thinking_status = True
text = "> thinking\n\n" + parts[0].get("text")
else:
text = parts[0].get("text")
if len(parts) == 2:
self.thinking_status = False
if self.thinking_first:
self.thinking_first = False
text = (
"> thinking\n\n"
+ parts[0].get("text")
+ "\n\n---\n> output\n\n"
+ parts[1].get("text")
)
else:
text = (
parts[0].get("text")
+ "\n\n---\n> output\n\n"
+ parts[1].get("text")
)
else:
if len(parts) == 1:
if self.thinking_first:
self.thinking_first = False
self.thinking_status = True
text = ""
elif self.thinking_status:
text = ""
else:
text = parts[0].get("text")
if len(parts) == 2:
self.thinking_status = False
if self.thinking_first:
self.thinking_first = False
text = parts[1].get("text")
else:
text = parts[1].get("text")
else:
if "text" in parts[0]:
text = parts[0].get("text")
elif "executableCode" in parts[0]:
text = _format_code_block(parts[0]["executableCode"])
elif "codeExecution" in parts[0]:
text = _format_code_block(parts[0]["codeExecution"])
elif "executableCodeResult" in parts[0]:
text = _format_execution_result(
parts[0]["executableCodeResult"]
)
elif "codeExecutionResult" in parts[0]:
text = _format_execution_result(
parts[0]["codeExecutionResult"]
)
else:
text = ""
text = _add_search_link_text(model, candidate, text)
else:
if response.get("candidates"):
candidate = response["candidates"][0]
if "thinking" in model:
if settings.SHOW_THINKING_PROCESS:
if len(candidate["content"]["parts"]) == 2:
text = (
"> thinking\n\n"
+ candidate["content"]["parts"][0]["text"]
+ "\n\n---\n> output\n\n"
+ candidate["content"]["parts"][1]["text"]
)
else:
text = candidate["content"]["parts"][0]["text"]
else:
if len(candidate["content"]["parts"]) == 2:
text = candidate["content"]["parts"][1]["text"]
else:
text = candidate["content"]["parts"][0]["text"]
else:
text = candidate["content"]["parts"][0]["text"]
text = _add_search_link_text(model, candidate, text)
else:
text = "暂无返回"
return text
class OpenAIResponseHandler(ResponseHandler):
"""OpenAI响应处理器"""
def __init__(self, config):
self.config = config
self.thinking_first = True
self.thinking_status = False
def handle_response(
self,
response: Dict[str, Any],
model: str,
stream: bool = False,
finish_reason: str = None
) -> Optional[Dict[str, Any]]:
if stream:
return self._handle_stream_response(response, model, finish_reason)
return self._handle_normal_response(response, model, finish_reason)
def _handle_stream_response(self, response: Dict[str, Any], model: str, finish_reason: str) -> Dict[str, Any]:
text = self._extract_text(response, model, stream=True)
return {
"id": f"chatcmpl-{uuid.uuid4()}",
"object": "chat.completion.chunk",
"created": int(time.time()),
"model": model,
"choices": [{
"index": 0,
"delta": {"content": text} if text else {},
"finish_reason": finish_reason
}]
}
def _handle_normal_response(self, response: Dict[str, Any], model: str, finish_reason: str) -> Dict[str, Any]:
text = self._extract_text(response, model, stream=False)
return {
"id": f"chatcmpl-{uuid.uuid4()}",
"object": "chat.completion",
"created": int(time.time()),
"model": model,
"choices": [{
"index": 0,
"message": {
"role": "assistant",
"content": text
},
"finish_reason": finish_reason
}],
"usage": {
"prompt_tokens": 0,
"completion_tokens": 0,
"total_tokens": 0
}
}
def _extract_text(self, response: Dict[str, Any], model: str, stream: bool = False) -> str:
text = ""
if stream:
if response.get("candidates"):
candidate = response["candidates"][0]
content = candidate.get("content", {})
parts = content.get("parts", [])
if "thinking" in model:
if settings.SHOW_THINKING_PROCESS:
if len(parts) == 1:
if self.thinking_first:
self.thinking_first = False
self.thinking_status = True
text = "> thinking\n\n" + parts[0].get("text")
else:
text = parts[0].get("text")
if len(parts) == 2:
self.thinking_status = False
if self.thinking_first:
self.thinking_first = False
text = (
"> thinking\n\n"
+ parts[0].get("text")
+ "\n\n---\n> output\n\n"
+ parts[1].get("text")
)
else:
text = (
parts[0].get("text")
+ "\n\n---\n> output\n\n"
+ parts[1].get("text")
)
else:
if len(parts) == 1:
if self.thinking_first:
self.thinking_first = False
self.thinking_status = True
text = ""
elif self.thinking_status:
text = ""
else:
text = parts[0].get("text")
if len(parts) == 2:
self.thinking_status = False
if self.thinking_first:
self.thinking_first = False
text = parts[1].get("text")
else:
text = parts[1].get("text")
else:
if "text" in parts[0]:
text = parts[0].get("text")
elif "executableCode" in parts[0]:
text = _format_code_block(parts[0]["executableCode"])
elif "codeExecution" in parts[0]:
text = _format_code_block(parts[0]["codeExecution"])
elif "executableCodeResult" in parts[0]:
text = _format_execution_result(
parts[0]["executableCodeResult"]
)
elif "codeExecutionResult" in parts[0]:
text = _format_execution_result(
parts[0]["codeExecutionResult"]
)
else:
text = ""
text = _add_search_link_text(model, candidate, text)
else:
if response.get("candidates"):
candidate = response["candidates"][0]
if "thinking" in model:
if settings.SHOW_THINKING_PROCESS:
if len(candidate["content"]["parts"]) == 2:
text = (
"> thinking\n\n"
+ candidate["content"]["parts"][0]["text"]
+ "\n\n---\n> output\n\n"
+ candidate["content"]["parts"][1]["text"]
)
else:
text = candidate["content"]["parts"][0]["text"]
else:
if len(candidate["content"]["parts"]) == 2:
text = candidate["content"]["parts"][1]["text"]
else:
text = candidate["content"]["parts"][0]["text"]
else:
text = candidate["content"]["parts"][0]["text"]
text = _add_search_link_text(model, candidate, text)
else:
text = "暂无返回"
return text
def _format_code_block(code_data: dict) -> str:
"""格式化代码块输出"""
language = code_data.get("language", "").lower()
code = code_data.get("code", "").strip()
return f"""\n\n---\n\n【代码执行】\n```{language}\n{code}\n```\n"""
def _add_search_link_text(model: str, candidate: dict, text: str) -> str:
if (
settings.SHOW_SEARCH_LINK
and model.endswith("-search")
and "groundingMetadata" in candidate
and "groundingChunks" in candidate["groundingMetadata"]
):
grounding_chunks = candidate["groundingMetadata"]["groundingChunks"]
text += "\n\n---\n\n"
text += "**【引用来源】**\n\n"
for grounding_chunk in grounding_chunks:
if "web" in grounding_chunk:
text += _create_search_link(grounding_chunk["web"])
return text
else:
return text
def _create_search_link(grounding_chunk: dict) -> str:
return f'\n- [{grounding_chunk["title"]}]({grounding_chunk["uri"]})'
def _format_execution_result(result_data: dict) -> str:
"""格式化执行结果输出"""
outcome = result_data.get("outcome", "")
output = result_data.get("output", "").strip()
return f"""\n【执行结果】\n> outcome: {outcome}\n\n【输出结果】\n```plaintext\n{output}\n```\n\n---\n\n"""

View File

@@ -0,0 +1,40 @@
# app/services/chat/retry_handler.py
from typing import TypeVar, Callable
from functools import wraps
from app.core.logger import get_retry_logger
from app.services.key_manager import KeyManager
T = TypeVar('T')
logger = get_retry_logger()
class RetryHandler:
"""重试处理装饰器"""
def __init__(self, max_retries: int = 3, key_manager: KeyManager = None, key_arg: str = "api_key"):
self.max_retries = max_retries
self.key_manager = key_manager
self.key_arg = key_arg
def __call__(self, func: Callable[..., T]) -> Callable[..., T]:
@wraps(func)
async def wrapper(*args, **kwargs) -> T:
last_exception = None
for attempt in range(self.max_retries):
try:
return await func(*args, **kwargs)
except Exception as e:
last_exception = e
logger.warning(f"API call failed with error: {str(e)}. Attempt {attempt + 1} of {self.max_retries}")
if self.key_manager:
old_key = kwargs.get(self.key_arg)
new_key = await self.key_manager.handle_api_failure(old_key)
kwargs[self.key_arg] = new_key
logger.info(f"Switched to new API key: {new_key}")
logger.error(f"All retry attempts failed, raising final exception: {str(last_exception)}")
raise last_exception
return wrapper
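A self-contained sketch of the decorator's behaviour; FakeKeyManager below is a stand-in for illustration, not the project's KeyManager:

import asyncio

class FakeKeyManager:
    """Stand-in that hands out the next spare key after each failure."""
    def __init__(self, spare_keys):
        self.spare_keys = list(spare_keys)
    async def handle_api_failure(self, old_key):
        return self.spare_keys.pop(0)

@RetryHandler(max_retries=3, key_manager=FakeKeyManager(["key-2", "key-3"]), key_arg="api_key")
async def flaky_call(api_key: str):
    if api_key != "key-3":
        raise RuntimeError(f"{api_key} rejected")
    return "ok"

# The key must be passed as a keyword so the decorator can rotate kwargs["api_key"]
print(asyncio.run(flaky_call(api_key="key-1")))  # -> "ok" on the third attempt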

View File

@@ -1,523 +0,0 @@
import httpx
import json
import time
import uuid
from typing import Dict, Any, Optional, AsyncGenerator, Union
from app.core.config import settings
from app.core.logger import get_chat_logger
from app.schemas.gemini_models import GeminiRequest
from app.schemas.openai_models import ChatRequest
logger = get_chat_logger()
def convert_messages_to_gemini_format(messages: list) -> list:
"""Convert OpenAI message format to Gemini format"""
converted_messages = []
for msg in messages:
role = "user" if msg["role"] == "user" else "model"
parts = []
# Handle text content
if isinstance(msg["content"], str):
parts.append({"text": msg["content"]})
# Handle messages containing images
elif isinstance(msg["content"], list):
for content in msg["content"]:
if isinstance(content, str):
parts.append({"text": content})
elif isinstance(content, dict) and content["type"] == "text":
parts.append({"text": content["text"]})
elif isinstance(content, dict) and content["type"] == "image_url":
# Handle image URLs
image_url = content["image_url"]["url"]
if image_url.startswith("data:image"):
# Handle base64-encoded images
parts.append(
{
"inline_data": {
"mime_type": "image/jpeg",
"data": image_url.split(",")[1],
}
}
)
else:
# Handle plain URL images
parts.append(
{
"image_url": {
"url": image_url,
}
}
)
converted_messages.append({"role": role, "parts": parts})
return converted_messages
def format_execution_result(result_data: dict) -> str:
"""格式化执行结果输出"""
outcome = result_data.get("outcome", "")
output = result_data.get("output", "").strip()
return f"""\n【执行结果】\n> outcome: {outcome}\n\n【输出结果】\n```plaintext\n{output}\n```\n"""
def create_search_link(web):
return f'\n- [{web["title"]}]({web["uri"]})'
class ChatService:
def __init__(self, base_url: str, key_manager=None):
self.base_url = base_url
self.key_manager = key_manager
self.thinking_first = True
self.thinking_status = False
def convert_gemini_response_to_openai(
self,
response: Dict[str, Any],
model: str,
stream: bool = False,
finish_reason: str = None,
) -> Optional[Dict[str, Any]]:
"""Convert Gemini response to OpenAI format"""
if stream:
try:
text = ""
if response.get("candidates"):
candidate = response["candidates"][0]
content = candidate.get("content", {})
parts = content.get("parts", [])
if "thinking" in model:
if settings.SHOW_THINKING_PROCESS:
if len(parts) == 1:
if self.thinking_first:
self.thinking_first = False
self.thinking_status = True
text = "> thinking\n\n" + parts[0].get("text")
else:
text = parts[0].get("text")
if len(parts) == 2:
self.thinking_status = False
if self.thinking_first:
self.thinking_first = False
text = (
"> thinking\n\n"
+ parts[0].get("text")
+ "\n\n---\n> output\n\n"
+ parts[1].get("text")
)
else:
text = (
parts[0].get("text")
+ "\n\n---\n> output\n\n"
+ parts[1].get("text")
)
else:
if len(parts) == 1:
if self.thinking_first:
self.thinking_first = False
self.thinking_status = True
text = ""
elif self.thinking_status:
text = ""
else:
text = parts[0].get("text")
if len(parts) == 2:
self.thinking_status = False
if self.thinking_first:
self.thinking_first = False
text = parts[1].get("text")
else:
text = parts[1].get("text")
else:
if "text" in parts[0]:
text = parts[0].get("text")
elif "executableCode" in parts[0]:
text = self.format_code_block(parts[0]["executableCode"])
elif "codeExecution" in parts[0]:
text = self.format_code_block(parts[0]["codeExecution"])
elif "executableCodeResult" in parts[0]:
text = format_execution_result(
parts[0]["executableCodeResult"]
)
elif "codeExecutionResult" in parts[0]:
text = format_execution_result(
parts[0]["codeExecutionResult"]
)
else:
text = ""
text = self.add_search_link_text(model, candidate, text)
return {
"id": f"chatcmpl-{uuid.uuid4()}",
"object": "chat.completion.chunk",
"created": int(time.time()),
"model": model,
"choices": [
{
"index": 0,
"delta": {"content": text} if text else {},
"finish_reason": finish_reason,
}
],
}
except Exception as e:
logger.error(f"Error converting Gemini response: {str(e)}")
logger.debug(f"Raw response: {response}")
return None
else:
res = {
"id": f"chatcmpl-{uuid.uuid4()}",
"object": "chat.completion",
"created": int(time.time()),
"model": model,
"choices": [
{
"index": 0,
"message": {
"role": "assistant",
"content": response["candidates"][0]["content"]["parts"][0]["text"],
},
"finish_reason": finish_reason,
}
],
"usage": {
"prompt_tokens": 0,
"completion_tokens": 0,
"total_tokens": 0,
},
}
try:
if response.get("candidates"):
candidate = response["candidates"][0]
if "thinking" in model:
if settings.SHOW_THINKING_PROCESS:
if len(candidate["content"]["parts"]) == 2:
text = (
"> thinking\n\n"
+ candidate["content"]["parts"][0]["text"]
+ "\n\n---\n> output\n\n"
+ candidate["content"]["parts"][1]["text"]
)
else:
text = candidate["content"]["parts"][0]["text"]
else:
if len(candidate["content"]["parts"]) == 2:
text = candidate["content"]["parts"][1]["text"]
else:
text = candidate["content"]["parts"][0]["text"]
else:
text = candidate["content"]["parts"][0]["text"]
text = self.add_search_link_text(model, candidate, text)
res["choices"][0]["message"]["content"] = text
return res
else:
res["choices"][0]["message"]["content"] = "暂无返回"
return res
except Exception as e:
logger.error(f"Error converting Gemini response: {str(e)}")
logger.debug(f"Raw response: {response}")
res["choices"][0]["message"][
"content"
] = f"Error converting Gemini response: {str(e)}"
return res
def add_search_link_text(self, model, candidate, text):
if (
settings.SHOW_SEARCH_LINK
and model.endswith("-search")
and "groundingMetadata" in candidate
and "groundingChunks" in candidate["groundingMetadata"]
):
grounding_chunks = candidate["groundingMetadata"]["groundingChunks"]
text += "\n\n---\n\n"
text += "**【引用来源】**\n\n"
for _, grounding_chunk in enumerate(grounding_chunks, 1):
if "web" in grounding_chunk:
text += create_search_link(grounding_chunk["web"])
return text
else:
return text
async def create_chat_completion(
self,
request: ChatRequest,
api_key: str,
) -> Union[Dict[str, Any], AsyncGenerator[str, None]]:
"""Create chat completion using either Gemini or OpenAI API"""
model = request.model
tools = request.tools
if tools is None:
tools = []
if settings.TOOLS_CODE_EXECUTION_ENABLED and not (
model.endswith("-search") or "-thinking" in model
):
tools.append({"code_execution": {}})
if model.endswith("-search"):
tools.append({"googleSearch": {}})
return await self._gemini_chat_completion(request, api_key, tools)
async def _gemini_chat_completion(
self,
request: ChatRequest,
api_key: str,
tools: Optional[list] = None,
) -> Union[Dict[str, Any], AsyncGenerator[str, None]]:
"""Handle Gemini API chat completion"""
model = request.model
messages = request.messages
temperature = request.temperature
stream = request.stream
max_tokens = request.max_tokens
stop = request.stop
top_p = request.top_p
top_k = request.top_k
if model.endswith("-search"):
gemini_model = model[:-7] # Remove -search suffix
else:
gemini_model = model
gemini_messages = convert_messages_to_gemini_format(messages)
if not stream:
# In non-streaming mode, remove the code-execution tool
if {"code_execution": {}} in tools:
tools.remove({"code_execution": {}})
payload = {
"contents": gemini_messages,
"generationConfig": {
"temperature": temperature,
"maxOutputTokens": max_tokens,
"stopSequences": stop,
"topP": top_p,
"topK": top_k,
},
"tools": tools,
"safetySettings": [
{"category": "HARM_CATEGORY_HARASSMENT", "threshold": "BLOCK_NONE"},
{"category": "HARM_CATEGORY_HATE_SPEECH", "threshold": "BLOCK_NONE"},
{
"category": "HARM_CATEGORY_SEXUALLY_EXPLICIT",
"threshold": "BLOCK_NONE",
},
{
"category": "HARM_CATEGORY_DANGEROUS_CONTENT",
"threshold": "BLOCK_NONE",
},
{
"category": "HARM_CATEGORY_CIVIC_INTEGRITY",
"threshold": "BLOCK_NONE",
},
],
}
if stream:
async def generate():
retries = 0
max_retries = 3
current_api_key = api_key
while retries < max_retries:
try:
timeout = httpx.Timeout(
300.0, read=300.0
)  # 300-second connect and read timeouts
async with httpx.AsyncClient(timeout=timeout) as async_client:
stream_url = f"https://generativelanguage.googleapis.com/v1beta/models/{gemini_model}:streamGenerateContent?alt=sse&key={current_api_key}"
async with async_client.stream(
"POST", stream_url, json=payload
) as async_response:
if async_response.status_code != 200:
error_content = await async_response.aread()
error_msg = error_content.decode("utf-8")
logger.error(
f"API error: {async_response.status_code}, {error_msg}"
)
if retries < max_retries - 1:
current_api_key = (
await self.key_manager.handle_api_failure(
current_api_key
)
)
retries += 1
continue
else:
logger.error(
f"Max retries reached. Final error: {async_response.status_code}, {error_msg}"
)
yield f"data: {json.dumps({'error': f'API error: {async_response.status_code}, {error_msg}'})}\n\n"
return
async for line in async_response.aiter_lines():
if line.startswith("data: "):
try:
chunk = json.loads(line[6:])
openai_chunk = (
self.convert_gemini_response_to_openai(
chunk,
model,
stream=True,
finish_reason=None,
)
)
if openai_chunk:
yield f"data: {json.dumps(openai_chunk)}\n\n"
except json.JSONDecodeError:
continue
yield f"data: {json.dumps(self.convert_gemini_response_to_openai({}, model, stream=True, finish_reason='stop'))}\n\n"
yield "data: [DONE]\n\n"
return
except httpx.ReadTimeout:
logger.warning(
f"Read timeout occurred, attempting retry {retries + 1}"
)
if retries < max_retries - 1:
current_api_key = await self.key_manager.handle_api_failure(
current_api_key
)
logger.info(f"Switched to new API key: {current_api_key}")
retries += 1
continue
else:
logger.error(
f"Max retries reached. Final error: Read timeout"
)
yield f"data: {json.dumps({'error': 'Read timeout'})}\n\n"
return
except Exception as e:
logger.exception(
f"Stream error: {str(e)}, attempting retry {retries + 1}"
)
if retries < max_retries - 1:
current_api_key = await self.key_manager.handle_api_failure(
current_api_key
)
logger.info(f"Switched to new API key: {current_api_key}")
retries += 1
continue
else:
logger.error(f"Max retries reached. Final error: {e}")
yield f"data: {json.dumps({'error': str(e)})}\n\n"
return
return generate()
else:
try:
timeout = httpx.Timeout(
300.0, read=300.0
)  # 300-second connect and read timeouts
async with httpx.AsyncClient(timeout=timeout) as client:
url = f"https://generativelanguage.googleapis.com/v1beta/models/{gemini_model}:generateContent?key={api_key}"
response = await client.post(url, json=payload)
if response.status_code != 200:
error_text = response.text
error_code = response.status_code
raise Exception(
f"API调用错误 - 状态码: {error_code}, 响应内容: {error_text}"
)
gemini_response = response.json()
return self.convert_gemini_response_to_openai(
gemini_response, model, stream=False, finish_reason="stop"
)
except Exception as e:
logger.error(f"Error in non-stream completion")
raise
def format_code_block(self, code_data: dict) -> str:
"""格式化代码块输出"""
language = code_data.get("language", "").lower()
code = code_data.get("code", "").strip()
return f"""\n【代码执行】\n```{language}\n{code}\n```\n"""
async def generate_content(
self, model_name: str, request: GeminiRequest, api_key: str
) -> dict:
"""调用Gemini API生成内容"""
url = f"{self.base_url}/models/{model_name}:generateContent?key={api_key}"
timeout = httpx.Timeout(300.0, read=300.0)  # 300-second connect and read timeouts
async with httpx.AsyncClient(timeout=timeout) as client:
try:
response = await client.post(url, json=request.model_dump())
if response.status_code == 200:
return response.json()
else:
error_text = response.text
logger.error(f"Error: {response.status_code}")
logger.error(error_text)
raise Exception(
f"API request failed with status {response.status_code}: {error_text}"
)
except Exception as e:
logger.error(f"Request failed: {str(e)}")
raise
async def stream_generate_content(
self, model_name: str, request: GeminiRequest, api_key: str
) -> AsyncGenerator:
"""调用Gemini API流式生成内容"""
retries = 0
MAX_RETRIES = 3
current_api_key = api_key
while retries < MAX_RETRIES:
try:
url = f"{self.base_url}/models/{model_name}:streamGenerateContent?alt=sse&key={current_api_key}"
timeout = httpx.Timeout(300.0, read=300.0)
async with httpx.AsyncClient(timeout=timeout) as client:
async with client.stream(
"POST", url, json=request.model_dump()
) as response:
if response.status_code != 200:
error_text = (await response.aread()).decode("utf-8")
logger.error(f"Error: {response.status_code}: {error_text}")
if retries < MAX_RETRIES - 1:
current_api_key = (
await self.key_manager.handle_api_failure(
current_api_key
)
)
logger.info(
f"Switched to new API key: {current_api_key}"
)
retries += 1
continue
raise Exception(
f"API request failed with status {response.status_code}: {error_text}"
)
async for line in response.aiter_lines():
yield line + "\n\n"
return
except httpx.ReadTimeout:
logger.warning(f"Read timeout occurred, attempting retry {retries + 1}")
if retries < MAX_RETRIES - 1:
current_api_key = await self.key_manager.handle_api_failure(
current_api_key
)
logger.info(f"Switched to new API key: {current_api_key}")
retries += 1
continue
raise
except Exception as e:
logger.error(f"Streaming request failed: {str(e)}")
if retries < MAX_RETRIES - 1:
current_api_key = await self.key_manager.handle_api_failure(
current_api_key
)
logger.info(f"Switched to new API key: {current_api_key}")
retries += 1
continue
raise

View File

@@ -0,0 +1,89 @@
# app/services/gemini_chat_service.py
import json
from typing import Dict, Any, AsyncGenerator, List
from app.core.logger import get_gemini_logger
from app.services.chat.api_client import GeminiApiClient
from app.schemas.gemini_models import GeminiRequest
from app.core.config import settings
from app.services.chat.response_handler import GeminiResponseHandler
from app.services.key_manager import KeyManager
logger = get_gemini_logger()
class GeminiChatService:
"""聊天服务"""
def __init__(self, base_url: str, key_manager: KeyManager):
self.api_client = GeminiApiClient(base_url)
self.key_manager = key_manager
self.response_handler = GeminiResponseHandler()
def generate_content(self, model: str, request: GeminiRequest, api_key: str) -> Dict[str, Any]:
"""生成内容"""
payload = self._build_payload(model, request)
response = self.api_client.generate_content(payload, model, api_key)
return self.response_handler.handle_response(response, model, stream=False)
async def stream_generate_content(self, model: str, request: GeminiRequest, api_key: str) -> AsyncGenerator[str, None]:
"""流式生成内容"""
retries = 0
max_retries = 3
payload = self._build_payload(model, request)
while retries < max_retries:
try:
async for line in self.api_client.stream_generate_content(payload, model, api_key):
if line.startswith("data: "):
line = line[6:]
line = json.dumps(self.response_handler.handle_response(json.loads(line), model, stream=True))
yield "data: " + line + "\n\n"
logger.info("Streaming completed successfully")
break
except Exception as e:
retries += 1
logger.warning(f"Streaming API call failed with error: {str(e)}. Attempt {retries} of {max_retries}")
api_key = await self.key_manager.handle_api_failure(api_key)
logger.info(f"Switched to new API key: {api_key}")
if retries >= max_retries:
logger.error(f"Max retries ({max_retries}) reached for streaming. Raising error")
break
def _build_payload(self, model: str, request: GeminiRequest) -> Dict[str, Any]:
"""Build the request payload"""
payload = request.model_dump()
return {
"contents": payload.get("contents", []),
"tools": self._build_tools(model, payload),
"safetySettings": self._get_safety_settings(),
"generationConfig": payload.get("generationConfig", {}),
"systemInstruction": payload.get("systemInstruction", [])
}
def _build_tools(self, model: str, payload: Dict[str, Any]) -> List[Dict[str, Any]]:
"""构建工具"""
tools = []
if settings.TOOLS_CODE_EXECUTION_ENABLED and not (
model.endswith("-search") or "-thinking" in model
) and not self._has_image_parts(payload.get("contents", [])):
tools.append({"code_execution": {}})
if model.endswith("-search"):
tools.append({"googleSearch": {}})
return tools
def _has_image_parts(self, contents: List[Dict[str, Any]]) -> bool:
"""判断消息是否包含图片部分"""
for content in contents:
if "parts" in content:
for part in content["parts"]:
if "image_url" in part or "inline_data" in part:
return True
return False
def _get_safety_settings(self) -> List[Dict[str, str]]:
"""获取安全设置"""
return [
{"category": "HARM_CATEGORY_HARASSMENT", "threshold": "BLOCK_NONE"},
{"category": "HARM_CATEGORY_HATE_SPEECH", "threshold": "BLOCK_NONE"},
{"category": "HARM_CATEGORY_SEXUALLY_EXPLICIT", "threshold": "BLOCK_NONE"},
{"category": "HARM_CATEGORY_DANGEROUS_CONTENT", "threshold": "BLOCK_NONE"},
{"category": "HARM_CATEGORY_CIVIC_INTEGRITY", "threshold": "BLOCK_NONE"}
]
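A minimal wiring sketch of the new service; the model name is an assumption, and GeminiRequest is assumed to accept a contents list as in the v1beta wire format:

service = GeminiChatService(
    base_url="https://generativelanguage.googleapis.com/v1beta",
    key_manager=KeyManager(["YOUR_KEY"]),
)
request = GeminiRequest(contents=[{"role": "user", "parts": [{"text": "Hi"}]}])
result = service.generate_content(model="gemini-1.5-flash", request=request, api_key="YOUR_KEY")
print(result["candidates"][0]["content"]["parts"][0]["text"])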

View File

@@ -0,0 +1,136 @@
# app/services/openai_chat_service.py
import json
from typing import Dict, Any, AsyncGenerator, List, Union
from app.core.logger import get_openai_logger
from app.services.chat.message_converter import OpenAIMessageConverter
from app.services.chat.response_handler import OpenAIResponseHandler
from app.services.chat.api_client import GeminiApiClient
from app.schemas.openai_models import ChatRequest
from app.core.config import settings
from app.services.key_manager import KeyManager
logger = get_openai_logger()
class OpenAIChatService:
"""聊天服务"""
def __init__(self, base_url: str, key_manager: KeyManager):
self.message_converter = OpenAIMessageConverter()
self.response_handler = OpenAIResponseHandler(config=None)
self.api_client = GeminiApiClient(base_url)
self.key_manager = key_manager
async def create_chat_completion(
self,
request: ChatRequest,
api_key: str,
) -> Union[Dict[str, Any], AsyncGenerator[str, None]]:
"""创建聊天完成"""
# 转换消息格式
messages = self.message_converter.convert(request.messages)
# 构建请求payload
payload = self._build_payload(request, messages)
if request.stream:
return self._handle_stream_completion(request.model, payload, api_key)
return self._handle_normal_completion(request.model, payload, api_key)
def _handle_normal_completion(
self,
model: str,
payload: Dict[str, Any],
api_key: str
) -> Dict[str, Any]:
"""处理普通聊天完成"""
response = self.api_client.generate_content(payload, model, api_key)
return self.response_handler.handle_response(
response,
model,
stream=False,
finish_reason="stop"
)
async def _handle_stream_completion(
self,
model: str,
payload: Dict[str, Any],
api_key: str
) -> AsyncGenerator[str, None]:
"""处理流式聊天完成,添加重试逻辑"""
retries = 0
max_retries = 3
while retries < max_retries:
try:
async for line in self.api_client.stream_generate_content(payload, model, api_key):
if line.startswith("data: "):
chunk = json.loads(line[6:])
openai_chunk = self.response_handler.handle_response(
chunk,
model,
stream=True,
finish_reason=None
)
if openai_chunk:
yield f"data: {json.dumps(openai_chunk)}\n\n"
yield f"data: {json.dumps(self.response_handler.handle_response({}, model, stream=True, finish_reason='stop'))}\n\n"
yield "data: [DONE]\n\n"
logger.info("Streaming completed successfully")
break  # exit the retry loop on success
except Exception as e:
retries += 1
logger.warning(f"Streaming API call failed with error: {str(e)}. Attempt {retries} of {max_retries}")
api_key = await self.key_manager.handle_api_failure(api_key)
logger.info(f"Switched to new API key: {api_key}")
if retries >= max_retries:
logger.error(f"Max retries ({max_retries}) reached for streaming. Raising error")
yield f"data: {json.dumps({'error': 'Streaming failed after retries'})}\n\n"
yield "data: [DONE]\n\n"
break
def _build_payload(self, request: ChatRequest, messages: List[Dict[str, Any]]) -> Dict[str, Any]:
"""构建请求payload"""
return {
"contents": messages,
"generationConfig": {
"temperature": request.temperature,
"maxOutputTokens": request.max_tokens,
"stopSequences": request.stop,
"topP": request.top_p,
"topK": request.top_k
},
"tools": self._build_tools(request, messages),
"safetySettings": self._get_safety_settings()
}
def _build_tools(self, request: ChatRequest, messages: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
"""构建工具"""
tools = []
model = request.model
if settings.TOOLS_CODE_EXECUTION_ENABLED and not (
model.endswith("-search") or "-thinking" in model
) and not self._has_image_parts(messages):
tools.append({"code_execution": {}})
if model.endswith("-search"):
tools.append({"googleSearch": {}})
return tools
def _has_image_parts(self, contents: List[Dict[str, Any]]) -> bool:
"""判断消息是否包含图片部分"""
for content in contents:
if "parts" in content:
for part in content["parts"]:
if "image_url" in part or "inline_data" in part:
return True
return False
def _get_safety_settings(self) -> List[Dict[str, str]]:
"""获取安全设置"""
return [
{"category": "HARM_CATEGORY_HARASSMENT", "threshold": "BLOCK_NONE"},
{"category": "HARM_CATEGORY_HATE_SPEECH", "threshold": "BLOCK_NONE"},
{"category": "HARM_CATEGORY_SEXUALLY_EXPLICIT", "threshold": "BLOCK_NONE"},
{"category": "HARM_CATEGORY_DANGEROUS_CONTENT", "threshold": "BLOCK_NONE"},
{"category": "HARM_CATEGORY_CIVIC_INTEGRITY", "threshold": "BLOCK_NONE"}
]
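And the OpenAI-format counterpart, again with a placeholder key; ChatRequest is assumed to take model, messages, and stream fields as the routes above suggest:

import asyncio

async def demo():
    service = OpenAIChatService(settings.BASE_URL, KeyManager(["YOUR_KEY"]))
    request = ChatRequest(
        model="gemini-1.5-flash",
        messages=[{"role": "user", "content": "Hi"}],
        stream=False,
    )
    response = await service.create_chat_completion(request, api_key="YOUR_KEY")
    print(response["choices"][0]["message"]["content"])

asyncio.run(demo())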

View File