refactor: 重构Gemini和OpenAI聊天服务以支持工具和安全设置

- 将 `_build_payload`、`_build_tools`、`_get_safety_settings` 和 `_has_image_parts` 函数从 `OpenAIChatService` 和 `GeminiChatService` 类中提取为独立的函数。
- 将 `_handle_stream_response` 和 `_handle_normal_response` 函数从 `GeminiResponseHandler` 和 `OpenAIResponseHandler` 类中提取为独立的函数。
- 将 `_extract_text` 函数从 `OpenAIResponseHandler` 类中提取为独立的函数, 并在 `GeminiResponseHandler` 中复用。
- 将 `_convert_image` 函数从 `OpenAIMessageConverter` 类中提取为独立的函数。
- 优化 `OpenAIChatService` 和 `GeminiChatService` 中的代码结构, 使其更清晰。
- 优化 `app/api/openai_routes.py` 和 `app/api/gemini_routes.py` 中的路由函数, 移除不必要的参数。
This commit is contained in:
yinpeng
2025-02-06 21:35:19 +08:00
parent b60b063034
commit cd45f4b5ab
12 changed files with 378 additions and 475 deletions

View File

@@ -1,5 +1,4 @@
from http.client import HTTPException
from fastapi import APIRouter, Depends
from fastapi import APIRouter, Depends, HTTPException
from fastapi.responses import StreamingResponse
from app.core.config import settings
@@ -10,6 +9,7 @@ from app.services.gemini_chat_service import GeminiChatService
from app.services.key_manager import KeyManager
from app.services.model_service import ModelService
from app.services.chat.retry_handler import RetryHandler
router = APIRouter(prefix="/gemini/v1beta")
router_v1beta = APIRouter(prefix="/v1beta")
logger = get_gemini_logger()
@@ -22,26 +22,29 @@ model_service = ModelService(settings.MODEL_SEARCH)
@router.get("/models")
@router_v1beta.get("/models")
async def list_models(
key: str = None,
token: str = Depends(security_service.verify_key),
):
async def list_models(_=Depends(security_service.verify_key)):
"""获取可用的Gemini模型列表"""
logger.info("-" * 50 + "list_gemini_models" + "-" * 50)
logger.info("Handling Gemini models list request")
api_key = await key_manager.get_next_working_key()
logger.info(f"Using API key: {api_key}")
models_json = model_service.get_gemini_models(api_key)
models_json["models"].append({"name": "models/gemini-2.0-flash-exp-search", "version": "2.0", "displayName": "Gemini 2.0 Flash Search Experimental", "description": "Gemini 2.0 Flash Search Experimental", "inputTokenLimit": 32767, "outputTokenLimit": 8192, "supportedGenerationMethods": ["generateContent", "countTokens"], "temperature": 1, "topP": 0.95, "topK": 64, "maxTemperature": 2})
models_json["models"].append({"name": "models/gemini-2.0-flash-exp-search", "version": "2.0",
"displayName": "Gemini 2.0 Flash Search Experimental",
"description": "Gemini 2.0 Flash Search Experimental", "inputTokenLimit": 32767,
"outputTokenLimit": 8192,
"supportedGenerationMethods": ["generateContent", "countTokens"], "temperature": 1,
"topP": 0.95, "topK": 64, "maxTemperature": 2})
return models_json
@router.post("/models/{model_name}:generateContent")
@router_v1beta.post("/models/{model_name}:generateContent")
@RetryHandler(max_retries=3, key_manager=key_manager, key_arg="api_key")
async def generate_content(
model_name: str,
request: GeminiRequest,
x_goog_api_key: str = Depends(security_service.verify_goog_api_key),
_=Depends(security_service.verify_goog_api_key),
api_key: str = Depends(key_manager.get_next_working_key),
):
chat_service = GeminiChatService(settings.BASE_URL, key_manager)
@@ -70,7 +73,7 @@ async def generate_content(
async def stream_generate_content(
model_name: str,
request: GeminiRequest,
x_goog_api_key: str = Depends(security_service.verify_goog_api_key),
_=Depends(security_service.verify_goog_api_key),
api_key: str = Depends(key_manager.get_next_working_key),
):
chat_service = GeminiChatService(settings.BASE_URL, key_manager)
@@ -81,7 +84,7 @@ async def stream_generate_content(
logger.info(f"Using API key: {api_key}")
try:
response_stream =chat_service.stream_generate_content(
response_stream = chat_service.stream_generate_content(
model=model_name,
request=request,
api_key=api_key

View File

@@ -1,16 +1,15 @@
from http.client import HTTPException
from fastapi import APIRouter, Depends, Header
from fastapi import HTTPException, APIRouter, Depends
from fastapi.responses import StreamingResponse
from app.core.config import settings
from app.core.logger import get_openai_logger
from app.core.security import SecurityService
from app.schemas.openai_models import ChatRequest, EmbeddingRequest
from app.services.chat.retry_handler import RetryHandler
from app.services.embedding_service import EmbeddingService
from app.services.key_manager import KeyManager
from app.services.model_service import ModelService
from app.services.openai_chat_service import OpenAIChatService
from app.services.embedding_service import EmbeddingService
from app.schemas.openai_models import ChatRequest, EmbeddingRequest
from app.core.config import settings
from app.core.logger import get_openai_logger
router = APIRouter()
logger = get_openai_logger()
@@ -24,10 +23,7 @@ embedding_service = EmbeddingService(settings.BASE_URL)
@router.get("/v1/models")
@router.get("/hf/v1/models")
async def list_models(
authorization: str = Header(None),
token: str = Depends(security_service.verify_authorization),
):
async def list_models(_=Depends(security_service.verify_authorization)):
logger.info("-" * 50 + "list_models" + "-" * 50)
logger.info("Handling models list request")
api_key = await key_manager.get_next_working_key()
@@ -43,10 +39,9 @@ async def list_models(
@router.post("/hf/v1/chat/completions")
@RetryHandler(max_retries=3, key_manager=key_manager, key_arg="api_key")
async def chat_completion(
request: ChatRequest,
authorization: str = Header(None),
token: str = Depends(security_service.verify_authorization),
api_key: str = Depends(key_manager.get_next_working_key),
request: ChatRequest,
_=Depends(security_service.verify_authorization),
api_key: str = Depends(key_manager.get_next_working_key),
):
chat_service = OpenAIChatService(settings.BASE_URL, key_manager)
logger.info("-" * 50 + "chat_completion" + "-" * 50)
@@ -67,15 +62,13 @@ async def chat_completion(
except Exception as e:
logger.error(f"Chat completion failed after retries: {str(e)}")
raise HTTPException(status_code=500, detail="Chat completion failed") from e
@router.post("/v1/embeddings")
@router.post("/hf/v1/embeddings")
async def embedding(
request: EmbeddingRequest,
authorization: str = Header(None),
token: str = Depends(security_service.verify_authorization),
request: EmbeddingRequest,
_=Depends(security_service.verify_authorization),
):
logger.info("-" * 50 + "embedding" + "-" * 50)
logger.info(f"Handling embedding request for model: {request.model}")
@@ -95,8 +88,7 @@ async def embedding(
@router.get("/v1/keys/list")
@router.get("/hf/v1/keys/list")
async def get_keys_list(
authorization: str = Header(None),
token: str = Depends(security_service.verify_auth_token),
_=Depends(security_service.verify_auth_token),
):
"""获取有效和无效的API key列表"""
logger.info("-" * 50 + "get_keys_list" + "-" * 50)

View File

@@ -128,4 +128,4 @@ def get_request_logger():
def get_retry_logger():
return Logger.setup_logger("retry")
return Logger.setup_logger("retry")

View File

@@ -17,7 +17,7 @@ class SecurityService:
return key
async def verify_authorization(
self, authorization: Optional[str] = Header(None)
self, authorization: Optional[str] = Header(None)
) -> str:
if not authorization:
logger.error("Missing Authorization header")
@@ -45,7 +45,7 @@ class SecurityService:
if x_goog_api_key not in self.allowed_tokens and x_goog_api_key != self.auth_token:
logger.error("Invalid x-goog-api-key")
raise HTTPException(status_code=401, detail="Invalid x-goog-api-key")
return x_goog_api_key
async def verify_auth_token(self, authorization: Optional[str] = Header(None)) -> str:
@@ -56,5 +56,5 @@ class SecurityService:
if token != self.auth_token:
logger.error("Invalid auth_token")
raise HTTPException(status_code=401, detail="Invalid auth_token")
return token
return token

View File

@@ -3,10 +3,8 @@ from pydantic import BaseModel
class SafetySetting(BaseModel):
category: Optional[Literal[
"HARM_CATEGORY_HATE_SPEECH", "HARM_CATEGORY_DANGEROUS_CONTENT", "HARM_CATEGORY_HARASSMENT", "HARM_CATEGORY_SEXUALLY_EXPLICIT", "HARM_CATEGORY_CIVIC_INTEGRITY"]] = None
threshold: Optional[Literal[
"HARM_BLOCK_THRESHOLD_UNSPECIFIED", "BLOCK_LOW_AND_ABOVE", "BLOCK_MEDIUM_AND_ABOVE", "BLOCK_ONLY_HIGH", "BLOCK_NONE", "OFF"]] = None
category: Optional[Literal["HARM_CATEGORY_HATE_SPEECH", "HARM_CATEGORY_DANGEROUS_CONTENT", "HARM_CATEGORY_HARASSMENT", "HARM_CATEGORY_SEXUALLY_EXPLICIT", "HARM_CATEGORY_CIVIC_INTEGRITY"]] = None
threshold: Optional[Literal["HARM_BLOCK_THRESHOLD_UNSPECIFIED", "BLOCK_LOW_AND_ABOVE", "BLOCK_MEDIUM_AND_ABOVE", "BLOCK_ONLY_HIGH", "BLOCK_NONE", "OFF"]] = None
class GenerationConfig(BaseModel):

View File

@@ -4,24 +4,26 @@ from typing import Dict, Any, AsyncGenerator
import httpx
from abc import ABC, abstractmethod
class ApiClient(ABC):
"""API客户端基类"""
@abstractmethod
async def generate_content(self, payload: Dict[str, Any], model: str, api_key: str) -> Dict[str, Any]:
pass
@abstractmethod
async def stream_generate_content(self, payload: Dict[str, Any], model: str, api_key: str) -> AsyncGenerator[str, None]:
pass
class GeminiApiClient(ApiClient):
"""Gemini API客户端"""
def __init__(self, base_url: str, timeout: int = 300):
self.base_url = base_url
self.timeout = timeout
def generate_content(self, payload: Dict[str, Any], model: str, api_key: str) -> Dict[str, Any]:
timeout = httpx.Timeout(self.timeout, read=self.timeout)
if model.endswith("-search"):
@@ -33,14 +35,14 @@ class GeminiApiClient(ApiClient):
error_content = response.text
raise Exception(f"API call failed with status code {response.status_code}, {error_content}")
return response.json()
async def stream_generate_content(self, payload: Dict[str, Any], model: str, api_key: str) -> AsyncGenerator[str, None]:
timeout = httpx.Timeout(self.timeout, read=self.timeout)
if model.endswith("-search"):
model = model[:-7]
async with httpx.AsyncClient(timeout=timeout) as client:
url = f"{self.base_url}/models/{model}:streamGenerateContent?alt=sse&key={api_key}"
async with client.stream("POST", url, json=payload) as response:
async with client.stream(method="POST", url=url, json=payload) as response:
if response.status_code != 200:
error_content = await response.aread()
error_msg = error_content.decode("utf-8")

View File

@@ -3,22 +3,39 @@
from abc import ABC, abstractmethod
from typing import List, Dict, Any
class MessageConverter(ABC):
"""消息转换器基类"""
@abstractmethod
def convert(self, messages: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
pass
def _convert_image(image_url: str) -> Dict[str, Any]:
if image_url.startswith("data:image"):
return {
"inline_data": {
"mime_type": "image/jpeg",
"data": image_url.split(",")[1]
}
}
return {
"image_url": {
"url": image_url
}
}
class OpenAIMessageConverter(MessageConverter):
"""OpenAI消息格式转换器"""
def convert(self, messages: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
converted_messages = []
for msg in messages:
role = "user" if msg["role"] == "user" else "model"
parts = []
if isinstance(msg["content"], str):
parts.append({"text": msg["content"]})
elif isinstance(msg["content"], list):
@@ -29,22 +46,8 @@ class OpenAIMessageConverter(MessageConverter):
if content["type"] == "text":
parts.append({"text": content["text"]})
elif content["type"] == "image_url":
parts.append(self._convert_image(content["image_url"]["url"]))
parts.append(_convert_image(content["image_url"]["url"]))
converted_messages.append({"role": role, "parts": parts})
return converted_messages
def _convert_image(self, image_url: str) -> Dict[str, Any]:
if image_url.startswith("data:image"):
return {
"inline_data": {
"mime_type": "image/jpeg",
"data": image_url.split(",")[1]
}
}
return {
"image_url": {
"url": image_url
}
}

View File

@@ -6,331 +6,225 @@ import time
import uuid
from app.core.config import settings
class ResponseHandler(ABC):
"""响应处理器基类"""
@abstractmethod
def handle_response(self, response: Dict[str, Any], model: str, stream: bool = False) -> Dict[str, Any]:
pass
class GeminiResponseHandler(ResponseHandler):
"""Gemini响应处理器"""
def __init__(self):
self.thinking_first = True
self.thinking_status = False
def handle_response(self, response: Dict[str, Any], model: str, stream: bool = False) -> Dict[str, Any]:
if stream:
return self._handle_stream_response(response, model, stream)
return self._handle_normal_response(response, model, stream)
def _handle_stream_response(self, response: Dict[str, Any], model: str, stream: bool) -> Dict[str, Any]:
text = self._extract_text(response, model, stream=stream)
content = {"parts": [{"text": text}],"role": "model"}
response["candidates"][0]["content"] = content
return response
return _handle_gemini_stream_response(response, model, stream)
return _handle_gemini_normal_response(response, model, stream)
def _handle_normal_response(self, response: Dict[str, Any], model: str, stream: bool) -> Dict[str, Any]:
text = self._extract_text(response, model, stream=stream)
content = {"parts": [{"text": text}],"role": "model"}
response["candidates"][0]["content"] = content
return response
def _extract_text(self, response: Dict[str, Any], model: str, stream: bool = False) -> str:
text = ""
if stream:
if response.get("candidates"):
candidate = response["candidates"][0]
content = candidate.get("content", {})
parts = content.get("parts", [])
# if "thinking" in model:
# if settings.SHOW_THINKING_PROCESS:
# if len(parts) == 1:
# if self.thinking_first:
# self.thinking_first = False
# self.thinking_status = True
# text = "> thinking\n\n" + parts[0].get("text")
# else:
# text = parts[0].get("text")
def _handle_openai_stream_response(response: Dict[str, Any], model: str, finish_reason: str) -> Dict[str, Any]:
text = _extract_text(response, model, stream=True)
return {
"id": f"chatcmpl-{uuid.uuid4()}",
"object": "chat.completion.chunk",
"created": int(time.time()),
"model": model,
"choices": [{
"index": 0,
"delta": {"content": text} if text else {},
"finish_reason": finish_reason
}]
}
# if len(parts) == 2:
# self.thinking_status = False
# if self.thinking_first:
# self.thinking_first = False
# text = (
# "> thinking\n\n"
# + parts[0].get("text")
# + "\n\n---\n> output\n\n"
# + parts[1].get("text")
# )
# else:
# text = (
# parts[0].get("text")
# + "\n\n---\n> output\n\n"
# + parts[1].get("text")
# )
# else:
# if len(parts) == 1:
# if self.thinking_first:
# self.thinking_first = False
# self.thinking_status = True
# text = ""
# elif self.thinking_status:
# text = ""
# else:
# text = parts[0].get("text")
# if len(parts) == 2:
# self.thinking_status = False
# if self.thinking_first:
# self.thinking_first = False
# text = parts[1].get("text")
# else:
# text = parts[1].get("text")
# else:
# if "text" in parts[0]:
# text = parts[0].get("text")
# elif "executableCode" in parts[0]:
# text = _format_code_block(parts[0]["executableCode"])
# elif "codeExecution" in parts[0]:
# text = _format_code_block(parts[0]["codeExecution"])
# elif "executableCodeResult" in parts[0]:
# text = _format_execution_result(
# parts[0]["executableCodeResult"]
# )
# elif "codeExecutionResult" in parts[0]:
# text = _format_execution_result(
# parts[0]["codeExecutionResult"]
# )
# else:
# text = ""
if "text" in parts[0]:
text = parts[0].get("text")
elif "executableCode" in parts[0]:
text = _format_code_block(parts[0]["executableCode"])
elif "codeExecution" in parts[0]:
text = _format_code_block(parts[0]["codeExecution"])
elif "executableCodeResult" in parts[0]:
text = _format_execution_result(
parts[0]["executableCodeResult"]
)
elif "codeExecutionResult" in parts[0]:
text = _format_execution_result(
parts[0]["codeExecutionResult"]
)
else:
text = ""
text = _add_search_link_text(model, candidate, text)
else:
if response.get("candidates"):
candidate = response["candidates"][0]
if "thinking" in model:
if settings.SHOW_THINKING_PROCESS:
if len(candidate["content"]["parts"]) == 2:
text = (
"> thinking\n\n"
+ candidate["content"]["parts"][0]["text"]
+ "\n\n---\n> output\n\n"
+ candidate["content"]["parts"][1]["text"]
)
else:
text = candidate["content"]["parts"][0]["text"]
else:
if len(candidate["content"]["parts"]) == 2:
text = candidate["content"]["parts"][1]["text"]
else:
text = candidate["content"]["parts"][0]["text"]
else:
text = candidate["content"]["parts"][0]["text"]
text = _add_search_link_text(model, candidate, text)
else:
text = "暂无返回"
return text
def _handle_openai_normal_response(response: Dict[str, Any], model: str, finish_reason: str) -> Dict[str, Any]:
text = _extract_text(response, model, stream=False)
return {
"id": f"chatcmpl-{uuid.uuid4()}",
"object": "chat.completion",
"created": int(time.time()),
"model": model,
"choices": [{
"index": 0,
"message": {
"role": "assistant",
"content": text
},
"finish_reason": finish_reason
}],
"usage": {
"prompt_tokens": 0,
"completion_tokens": 0,
"total_tokens": 0
}
}
class OpenAIResponseHandler(ResponseHandler):
"""OpenAI响应处理器"""
def __init__(self, config):
self.config = config
self.thinking_first = True
self.thinking_status = False
def handle_response(
self,
response: Dict[str, Any],
model: str,
stream: bool = False,
finish_reason: str = None
self,
response: Dict[str, Any],
model: str,
stream: bool = False,
finish_reason: str = None
) -> Optional[Dict[str, Any]]:
if stream:
return self._handle_stream_response(response, model, finish_reason)
return self._handle_normal_response(response, model, finish_reason)
def _handle_stream_response(self, response: Dict[str, Any], model: str, finish_reason: str) -> Dict[str, Any]:
text = self._extract_text(response, model, stream=True)
return {
"id": f"chatcmpl-{uuid.uuid4()}",
"object": "chat.completion.chunk",
"created": int(time.time()),
"model": model,
"choices": [{
"index": 0,
"delta": {"content": text} if text else {},
"finish_reason": finish_reason
}]
}
def _handle_normal_response(self, response: Dict[str, Any], model: str, finish_reason: str) -> Dict[str, Any]:
text = self._extract_text(response, model, stream=False)
return {
"id": f"chatcmpl-{uuid.uuid4()}",
"object": "chat.completion",
"created": int(time.time()),
"model": model,
"choices": [{
"index": 0,
"message": {
"role": "assistant",
"content": text
},
"finish_reason": finish_reason
}],
"usage": {
"prompt_tokens": 0,
"completion_tokens": 0,
"total_tokens": 0
}
}
def _extract_text(self, response: Dict[str, Any], model: str, stream: bool = False) -> str:
text = ""
if stream:
if response.get("candidates"):
candidate = response["candidates"][0]
content = candidate.get("content", {})
parts = content.get("parts", [])
# if "thinking" in model:
# if settings.SHOW_THINKING_PROCESS:
# if len(parts) == 1:
# if self.thinking_first:
# self.thinking_first = False
# self.thinking_status = True
# text = "> thinking\n\n" + parts[0].get("text")
# else:
# text = parts[0].get("text")
return _handle_openai_stream_response(response, model, finish_reason)
return _handle_openai_normal_response(response, model, finish_reason)
# if len(parts) == 2:
# self.thinking_status = False
# if self.thinking_first:
# self.thinking_first = False
# text = (
# "> thinking\n\n"
# + parts[0].get("text")
# + "\n\n---\n> output\n\n"
# + parts[1].get("text")
# )
# else:
# text = (
# parts[0].get("text")
# + "\n\n---\n> output\n\n"
# + parts[1].get("text")
# )
# else:
# if len(parts) == 1:
# if self.thinking_first:
# self.thinking_first = False
# self.thinking_status = True
# text = ""
# elif self.thinking_status:
# text = ""
# else:
# text = parts[0].get("text")
# if len(parts) == 2:
# self.thinking_status = False
# if self.thinking_first:
# self.thinking_first = False
# text = parts[1].get("text")
# else:
# text = parts[1].get("text")
# else:
# if "text" in parts[0]:
# text = parts[0].get("text")
# elif "executableCode" in parts[0]:
# text = _format_code_block(parts[0]["executableCode"])
# elif "codeExecution" in parts[0]:
# text = _format_code_block(parts[0]["codeExecution"])
# elif "executableCodeResult" in parts[0]:
# text = _format_execution_result(
# parts[0]["executableCodeResult"]
# )
# elif "codeExecutionResult" in parts[0]:
# text = _format_execution_result(
# parts[0]["codeExecutionResult"]
# )
# else:
# text = ""
# text = _add_search_link_text(model, candidate, text)
if "text" in parts[0]:
text = parts[0].get("text")
elif "executableCode" in parts[0]:
text = _format_code_block(parts[0]["executableCode"])
elif "codeExecution" in parts[0]:
text = _format_code_block(parts[0]["codeExecution"])
elif "executableCodeResult" in parts[0]:
text = _format_execution_result(
parts[0]["executableCodeResult"]
)
elif "codeExecutionResult" in parts[0]:
text = _format_execution_result(
parts[0]["codeExecutionResult"]
)
else:
text = ""
text = _add_search_link_text(model, candidate, text)
else:
if response.get("candidates"):
candidate = response["candidates"][0]
if "thinking" in model:
if settings.SHOW_THINKING_PROCESS:
if len(candidate["content"]["parts"]) == 2:
text = (
def _extract_text(response: Dict[str, Any], model: str, stream: bool = False) -> str:
text = ""
if stream:
if response.get("candidates"):
candidate = response["candidates"][0]
content = candidate.get("content", {})
parts = content.get("parts", [])
# if "thinking" in model:
# if settings.SHOW_THINKING_PROCESS:
# if len(parts) == 1:
# if self.thinking_first:
# self.thinking_first = False
# self.thinking_status = True
# text = "> thinking\n\n" + parts[0].get("text")
# else:
# text = parts[0].get("text")
# if len(parts) == 2:
# self.thinking_status = False
# if self.thinking_first:
# self.thinking_first = False
# text = (
# "> thinking\n\n"
# + parts[0].get("text")
# + "\n\n---\n> output\n\n"
# + parts[1].get("text")
# )
# else:
# text = (
# parts[0].get("text")
# + "\n\n---\n> output\n\n"
# + parts[1].get("text")
# )
# else:
# if len(parts) == 1:
# if self.thinking_first:
# self.thinking_first = False
# self.thinking_status = True
# text = ""
# elif self.thinking_status:
# text = ""
# else:
# text = parts[0].get("text")
# if len(parts) == 2:
# self.thinking_status = False
# if self.thinking_first:
# self.thinking_first = False
# text = parts[1].get("text")
# else:
# text = parts[1].get("text")
# else:
# if "text" in parts[0]:
# text = parts[0].get("text")
# elif "executableCode" in parts[0]:
# text = _format_code_block(parts[0]["executableCode"])
# elif "codeExecution" in parts[0]:
# text = _format_code_block(parts[0]["codeExecution"])
# elif "executableCodeResult" in parts[0]:
# text = _format_execution_result(
# parts[0]["executableCodeResult"]
# )
# elif "codeExecutionResult" in parts[0]:
# text = _format_execution_result(
# parts[0]["codeExecutionResult"]
# )
# else:
# text = ""
if "text" in parts[0]:
text = parts[0].get("text")
elif "executableCode" in parts[0]:
text = _format_code_block(parts[0]["executableCode"])
elif "codeExecution" in parts[0]:
text = _format_code_block(parts[0]["codeExecution"])
elif "executableCodeResult" in parts[0]:
text = _format_execution_result(
parts[0]["executableCodeResult"]
)
elif "codeExecutionResult" in parts[0]:
text = _format_execution_result(
parts[0]["codeExecutionResult"]
)
else:
text = ""
text = _add_search_link_text(model, candidate, text)
else:
if response.get("candidates"):
candidate = response["candidates"][0]
if "thinking" in model:
if settings.SHOW_THINKING_PROCESS:
if len(candidate["content"]["parts"]) == 2:
text = (
"> thinking\n\n"
+ candidate["content"]["parts"][0]["text"]
+ "\n\n---\n> output\n\n"
+ candidate["content"]["parts"][1]["text"]
)
else:
text = candidate["content"]["parts"][0]["text"]
)
else:
if len(candidate["content"]["parts"]) == 2:
text = candidate["content"]["parts"][1]["text"]
else:
text = candidate["content"]["parts"][0]["text"]
text = candidate["content"]["parts"][0]["text"]
else:
text = candidate["content"]["parts"][0]["text"]
text = _add_search_link_text(model, candidate, text)
if len(candidate["content"]["parts"]) == 2:
text = candidate["content"]["parts"][1]["text"]
else:
text = candidate["content"]["parts"][0]["text"]
else:
text = "暂无返回"
return text
text = candidate["content"]["parts"][0]["text"]
text = _add_search_link_text(model, candidate, text)
else:
text = "暂无返回"
return text
def _handle_gemini_stream_response(response: Dict[str, Any], model: str, stream: bool) -> Dict[str, Any]:
text = _extract_text(response, model, stream=stream)
content = {"parts": [{"text": text}], "role": "model"}
response["candidates"][0]["content"] = content
return response
def _handle_gemini_normal_response(response: Dict[str, Any], model: str, stream: bool) -> Dict[str, Any]:
text = _extract_text(response, model, stream=stream)
content = {"parts": [{"text": text}], "role": "model"}
response["candidates"][0]["content"] = content
return response
def _format_code_block(code_data: dict) -> str:
"""格式化代码块输出"""
language = code_data.get("language", "").lower()
code = code_data.get("code", "").strip()
return f"""\n\n---\n\n【代码执行】\n```{language}\n{code}\n```\n"""
def _add_search_link_text(model:str, candidate:dict, text:str) -> str:
def _add_search_link_text(model: str, candidate: dict, text: str) -> str:
if (
settings.SHOW_SEARCH_LINK
and model.endswith("-search")
and "groundingMetadata" in candidate
and "groundingChunks" in candidate["groundingMetadata"]
settings.SHOW_SEARCH_LINK
and model.endswith("-search")
and "groundingMetadata" in candidate
and "groundingChunks" in candidate["groundingMetadata"]
):
grounding_chunks = candidate["groundingMetadata"]["groundingChunks"]
text += "\n\n---\n\n"
@@ -351,4 +245,4 @@ def _format_execution_result(result_data: dict) -> str:
"""格式化执行结果输出"""
outcome = result_data.get("outcome", "")
output = result_data.get("output", "").strip()
return f"""\n【执行结果】\n> outcome: {outcome}\n\n【输出结果】\n```plaintext\n{output}\n```\n\n---\n\n"""
return f"""\n【执行结果】\n> outcome: {outcome}\n\n【输出结果】\n```plaintext\n{output}\n```\n\n---\n\n"""

View File

@@ -8,26 +8,27 @@ from app.services.key_manager import KeyManager
T = TypeVar('T')
logger = get_retry_logger()
class RetryHandler:
"""重试处理装饰器"""
def __init__(self, max_retries: int = 3, key_manager: KeyManager = None, key_arg: str = "api_key"):
self.max_retries = max_retries
self.key_manager = key_manager
self.key_arg = key_arg
def __call__(self, func: Callable[..., T]) -> Callable[..., T]:
@wraps(func)
async def wrapper(*args, **kwargs) -> T:
last_exception = None
for attempt in range(self.max_retries):
try:
return await func(*args, **kwargs)
except Exception as e:
last_exception = e
logger.warning(f"API call failed with error: {str(e)}. Attempt {attempt + 1} of {self.max_retries}")
if self.key_manager:
old_key = kwargs.get(self.key_arg)
new_key = await self.key_manager.handle_api_failure(old_key)
@@ -36,5 +37,5 @@ class RetryHandler:
logger.error(f"All retry attempts failed, raising final exception: {str(last_exception)}")
raise last_exception
return wrapper
return wrapper

View File

@@ -10,6 +10,61 @@ from app.services.chat.response_handler import GeminiResponseHandler
from app.services.key_manager import KeyManager
logger = get_gemini_logger()
def _has_image_parts(contents: List[Dict[str, Any]]) -> bool:
"""判断消息是否包含图片部分"""
for content in contents:
if "parts" in content:
for part in content["parts"]:
if "image_url" in part or "inline_data" in part:
return True
return False
def _build_tools(model: str, payload: Dict[str, Any]) -> List[Dict[str, Any]]:
"""构建工具"""
tools = []
if settings.TOOLS_CODE_EXECUTION_ENABLED and not (
model.endswith("-search") or "-thinking" in model
) and not _has_image_parts(payload.get("contents", [])):
tools.append({"code_execution": {}})
if model.endswith("-search"):
tools.append({"googleSearch": {}})
return tools
def _get_safety_settings(model: str) -> List[Dict[str, str]]:
"""获取安全设置"""
if model == "gemini-2.0-flash-exp":
return [
{"category": "HARM_CATEGORY_HARASSMENT", "threshold": "OFF"},
{"category": "HARM_CATEGORY_HATE_SPEECH", "threshold": "OFF"},
{"category": "HARM_CATEGORY_SEXUALLY_EXPLICIT", "threshold": "OFF"},
{"category": "HARM_CATEGORY_DANGEROUS_CONTENT", "threshold": "OFF"},
{"category": "HARM_CATEGORY_CIVIC_INTEGRITY", "threshold": "OFF"}
]
return [
{"category": "HARM_CATEGORY_HARASSMENT", "threshold": "BLOCK_NONE"},
{"category": "HARM_CATEGORY_HATE_SPEECH", "threshold": "BLOCK_NONE"},
{"category": "HARM_CATEGORY_SEXUALLY_EXPLICIT", "threshold": "BLOCK_NONE"},
{"category": "HARM_CATEGORY_DANGEROUS_CONTENT", "threshold": "BLOCK_NONE"},
{"category": "HARM_CATEGORY_CIVIC_INTEGRITY", "threshold": "BLOCK_NONE"}
]
def _build_payload(model: str, request: GeminiRequest) -> Dict[str, Any]:
"""构建请求payload"""
payload = request.model_dump()
return {
"contents": payload.get("contents", []),
"tools": _build_tools(model, payload),
"safetySettings": _get_safety_settings(model),
"generationConfig": payload.get("generationConfig", {}),
"systemInstruction": payload.get("systemInstruction", [])
}
class GeminiChatService:
"""聊天服务"""
@@ -17,18 +72,18 @@ class GeminiChatService:
self.api_client = GeminiApiClient(base_url)
self.key_manager = key_manager
self.response_handler = GeminiResponseHandler()
def generate_content(self, model: str, request: GeminiRequest, api_key: str) -> Dict[str, Any]:
"""生成内容"""
payload = self._build_payload(model, request)
payload = _build_payload(model, request)
response = self.api_client.generate_content(payload, model, api_key)
return self.response_handler.handle_response(response, model, stream=False)
async def stream_generate_content(self, model: str, request: GeminiRequest, api_key: str) -> AsyncGenerator[str, None]:
"""流式生成内容"""
retries = 0
max_retries = 3
payload = self._build_payload(model, request)
payload = _build_payload(model, request)
while retries < max_retries:
try:
async for line in self.api_client.stream_generate_content(payload, model, api_key):
@@ -47,52 +102,3 @@ class GeminiChatService:
if retries >= max_retries:
logger.error(f"Max retries ({max_retries}) reached for streaming. Raising error")
break
def _build_payload(self,model: str, request: GeminiRequest) -> Dict[str, Any]:
"""构建请求payload"""
payload = request.model_dump()
return {
"contents": payload.get("contents", []),
"tools": self._build_tools(model, payload),
"safetySettings": self._get_safety_settings(model),
"generationConfig": payload.get("generationConfig", {}),
"systemInstruction": payload.get("systemInstruction", [])
}
def _build_tools(self, model: str, payload: Dict[str, Any]) -> List[Dict[str, Any]]:
"""构建工具"""
tools = []
if settings.TOOLS_CODE_EXECUTION_ENABLED and not (
model.endswith("-search") or "-thinking" in model
) and not self._has_image_parts(payload.get("contents", [])):
tools.append({"code_execution": {}})
if model.endswith("-search"):
tools.append({"googleSearch": {}})
return tools
def _has_image_parts(self, contents: List[Dict[str, Any]]) -> bool:
"""判断消息是否包含图片部分"""
for content in contents:
if "parts" in content:
for part in content["parts"]:
if "image_url" in part or "inline_data" in part:
return True
return False
def _get_safety_settings(self, model: str) -> List[Dict[str, str]]:
"""获取安全设置"""
if model == "gemini-2.0-flash-exp":
return [
{"category": "HARM_CATEGORY_HARASSMENT", "threshold": "OFF"},
{"category": "HARM_CATEGORY_HATE_SPEECH", "threshold": "OFF"},
{"category": "HARM_CATEGORY_SEXUALLY_EXPLICIT", "threshold": "OFF"},
{"category": "HARM_CATEGORY_DANGEROUS_CONTENT", "threshold": "OFF"},
{"category": "HARM_CATEGORY_CIVIC_INTEGRITY", "threshold": "OFF"}
]
return [
{"category": "HARM_CATEGORY_HARASSMENT", "threshold": "BLOCK_NONE"},
{"category": "HARM_CATEGORY_HATE_SPEECH", "threshold": "BLOCK_NONE"},
{"category": "HARM_CATEGORY_SEXUALLY_EXPLICIT", "threshold": "BLOCK_NONE"},
{"category": "HARM_CATEGORY_DANGEROUS_CONTENT", "threshold": "BLOCK_NONE"},
{"category": "HARM_CATEGORY_CIVIC_INTEGRITY", "threshold": "BLOCK_NONE"}
]

View File

@@ -36,9 +36,9 @@ class ModelService:
return None
def convert_to_openai_models_format(
self, gemini_models: Dict[str, Any]
self, gemini_models: Dict[str, Any]
) -> Dict[str, Any]:
openai_format = {"object": "list", "data": [],"success": True}
openai_format = {"object": "list", "data": [], "success": True}
for model in gemini_models.get("models", []):
model_id = model["name"].split("/")[-1]

View File

@@ -13,6 +13,76 @@ from app.services.key_manager import KeyManager
logger = get_openai_logger()
def _has_image_parts(contents: List[Dict[str, Any]]) -> bool:
"""判断消息是否包含图片部分"""
for content in contents:
if "parts" in content:
for part in content["parts"]:
if "image_url" in part or "inline_data" in part:
return True
return False
def _build_tools(
request: ChatRequest, messages: List[Dict[str, Any]]
) -> List[Dict[str, Any]]:
"""构建工具"""
tools = []
model = request.model
if (
settings.TOOLS_CODE_EXECUTION_ENABLED
and not (model.endswith("-search") or "-thinking" in model)
and not _has_image_parts(messages)
):
tools.append({"code_execution": {}})
if model.endswith("-search"):
tools.append({"googleSearch": {}})
return tools
def _get_safety_settings(model: str) -> List[Dict[str, str]]:
"""获取安全设置"""
# if (
# "2.0" in model
# and "gemini-2.0-flash-thinking-exp" not in model
# and "gemini-2.0-pro-exp" not in model
# ):
if model == "gemini-2.0-flash-exp":
return [
{"category": "HARM_CATEGORY_HARASSMENT", "threshold": "OFF"},
{"category": "HARM_CATEGORY_HATE_SPEECH", "threshold": "OFF"},
{"category": "HARM_CATEGORY_SEXUALLY_EXPLICIT", "threshold": "OFF"},
{"category": "HARM_CATEGORY_DANGEROUS_CONTENT", "threshold": "OFF"},
{"category": "HARM_CATEGORY_CIVIC_INTEGRITY", "threshold": "OFF"},
]
return [
{"category": "HARM_CATEGORY_HARASSMENT", "threshold": "BLOCK_NONE"},
{"category": "HARM_CATEGORY_HATE_SPEECH", "threshold": "BLOCK_NONE"},
{"category": "HARM_CATEGORY_SEXUALLY_EXPLICIT", "threshold": "BLOCK_NONE"},
{"category": "HARM_CATEGORY_DANGEROUS_CONTENT", "threshold": "BLOCK_NONE"},
{"category": "HARM_CATEGORY_CIVIC_INTEGRITY", "threshold": "BLOCK_NONE"},
]
def _build_payload(
request: ChatRequest, messages: List[Dict[str, Any]]
) -> Dict[str, Any]:
"""构建请求payload"""
return {
"contents": messages,
"generationConfig": {
"temperature": request.temperature,
"maxOutputTokens": request.max_tokens,
"stopSequences": request.stop,
"topP": request.top_p,
"topK": request.top_k,
},
"tools": _build_tools(request, messages),
"safetySettings": _get_safety_settings(request.model),
}
class OpenAIChatService:
"""聊天服务"""
@@ -32,7 +102,7 @@ class OpenAIChatService:
messages = self.message_converter.convert(request.messages)
# 构建请求payload
payload = self._build_payload(request, messages)
payload = _build_payload(request, messages)
if request.stream:
return self._handle_stream_completion(request.model, payload, api_key)
@@ -84,69 +154,3 @@ class OpenAIChatService:
yield f"data: {json.dumps({'error': 'Streaming failed after retries'})}\n\n"
yield "data: [DONE]\n\n"
break
def _build_payload(
self, request: ChatRequest, messages: List[Dict[str, Any]]
) -> Dict[str, Any]:
"""构建请求payload"""
return {
"contents": messages,
"generationConfig": {
"temperature": request.temperature,
"maxOutputTokens": request.max_tokens,
"stopSequences": request.stop,
"topP": request.top_p,
"topK": request.top_k,
},
"tools": self._build_tools(request, messages),
"safetySettings": self._get_safety_settings(request.model),
}
def _build_tools(
self, request: ChatRequest, messages: List[Dict[str, Any]]
) -> List[Dict[str, Any]]:
"""构建工具"""
tools = []
model = request.model
if (
settings.TOOLS_CODE_EXECUTION_ENABLED
and not (model.endswith("-search") or "-thinking" in model)
and not self._has_image_parts(messages)
):
tools.append({"code_execution": {}})
if model.endswith("-search"):
tools.append({"googleSearch": {}})
return tools
def _has_image_parts(self, contents: List[Dict[str, Any]]) -> bool:
"""判断消息是否包含图片部分"""
for content in contents:
if "parts" in content:
for part in content["parts"]:
if "image_url" in part or "inline_data" in part:
return True
return False
def _get_safety_settings(self, model: str) -> List[Dict[str, str]]:
"""获取安全设置"""
# if (
# "2.0" in model
# and "gemini-2.0-flash-thinking-exp" not in model
# and "gemini-2.0-pro-exp" not in model
# ):
if model == "gemini-2.0-flash-exp":
return [
{"category": "HARM_CATEGORY_HARASSMENT", "threshold": "OFF"},
{"category": "HARM_CATEGORY_HATE_SPEECH", "threshold": "OFF"},
{"category": "HARM_CATEGORY_SEXUALLY_EXPLICIT", "threshold": "OFF"},
{"category": "HARM_CATEGORY_DANGEROUS_CONTENT", "threshold": "OFF"},
{"category": "HARM_CATEGORY_CIVIC_INTEGRITY", "threshold": "OFF"},
]
return [
{"category": "HARM_CATEGORY_HARASSMENT", "threshold": "BLOCK_NONE"},
{"category": "HARM_CATEGORY_HATE_SPEECH", "threshold": "BLOCK_NONE"},
{"category": "HARM_CATEGORY_SEXUALLY_EXPLICIT", "threshold": "BLOCK_NONE"},
{"category": "HARM_CATEGORY_DANGEROUS_CONTENT", "threshold": "BLOCK_NONE"},
{"category": "HARM_CATEGORY_CIVIC_INTEGRITY", "threshold": "BLOCK_NONE"},
]