mirror of
https://github.com/snailyp/gemini-balance.git
synced 2026-07-04 06:11:32 +08:00
Compare commits
5 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
88d483c1ef | ||
|
|
8d48db026c | ||
|
|
a592269198 | ||
|
|
18a5fe6109 | ||
|
|
348cbbdf2a |
@@ -1,14 +1,16 @@
|
||||
# app/services/chat/message_converter.py
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import List, Dict, Any
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
SUPPORTED_ROLES = ["user", "model", "system"]
|
||||
|
||||
|
||||
class MessageConverter(ABC):
|
||||
"""消息转换器基类"""
|
||||
|
||||
@abstractmethod
|
||||
def convert(self, messages: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
|
||||
def convert(self, messages: List[Dict[str, Any]]) -> tuple[List[Dict[str, Any]], Optional[Dict[str, Any]]]:
|
||||
pass
|
||||
|
||||
|
||||
@@ -30,24 +32,33 @@ def _convert_image(image_url: str) -> Dict[str, Any]:
|
||||
class OpenAIMessageConverter(MessageConverter):
|
||||
"""OpenAI消息格式转换器"""
|
||||
|
||||
def convert(self, messages: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
|
||||
def convert(self, messages: List[Dict[str, Any]]) -> tuple[List[Dict[str, Any]], Optional[Dict[str, Any]]]:
|
||||
converted_messages = []
|
||||
for msg in messages:
|
||||
role = "user" if msg["role"] == "user" else "model"
|
||||
parts = []
|
||||
system_instruction = None
|
||||
|
||||
if isinstance(msg["content"], str):
|
||||
for msg in messages:
|
||||
role = msg.get("role", "")
|
||||
if role not in SUPPORTED_ROLES:
|
||||
role = "model"
|
||||
|
||||
parts = []
|
||||
if isinstance(msg["content"], str) and msg["content"]:
|
||||
# 请求 gemini 接口时如果包含 content 字段但内容为空时会返回 400 错误,所以需要判断是否为空并移除
|
||||
parts.append({"text": msg["content"]})
|
||||
elif isinstance(msg["content"], list):
|
||||
for content in msg["content"]:
|
||||
if isinstance(content, str):
|
||||
if isinstance(content, str) and content:
|
||||
parts.append({"text": content})
|
||||
elif isinstance(content, dict):
|
||||
if content["type"] == "text":
|
||||
if content["type"] == "text" and content["text"]:
|
||||
parts.append({"text": content["text"]})
|
||||
elif content["type"] == "image_url":
|
||||
parts.append(_convert_image(content["image_url"]["url"]))
|
||||
|
||||
converted_messages.append({"role": role, "parts": parts})
|
||||
if parts:
|
||||
if role == "system":
|
||||
system_instruction = {"role": "system", "parts": parts}
|
||||
else:
|
||||
converted_messages.append({"role": role, "parts": parts})
|
||||
|
||||
return converted_messages
|
||||
return converted_messages, system_instruction
|
||||
|
||||
@@ -1,7 +1,10 @@
|
||||
# app/services/chat/response_handler.py
|
||||
|
||||
import json
|
||||
import random
|
||||
import string
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Dict, Any, Optional
|
||||
from typing import Dict, Any, List, Optional
|
||||
import time
|
||||
import uuid
|
||||
from app.core.config import settings
|
||||
@@ -29,40 +32,38 @@ class GeminiResponseHandler(ResponseHandler):
|
||||
|
||||
|
||||
def _handle_openai_stream_response(response: Dict[str, Any], model: str, finish_reason: str) -> Dict[str, Any]:
|
||||
text = _extract_text(response, model, stream=True)
|
||||
text, tool_calls = _extract_result(response, model, stream=True, gemini_format=False)
|
||||
if not text and not tool_calls:
|
||||
delta = {}
|
||||
else:
|
||||
delta = {"content": text, "role": "assistant"}
|
||||
if tool_calls:
|
||||
delta["tool_calls"] = tool_calls
|
||||
|
||||
return {
|
||||
"id": f"chatcmpl-{uuid.uuid4()}",
|
||||
"object": "chat.completion.chunk",
|
||||
"created": int(time.time()),
|
||||
"model": model,
|
||||
"choices": [{
|
||||
"index": 0,
|
||||
"delta": {"content": text} if text else {},
|
||||
"finish_reason": finish_reason
|
||||
}]
|
||||
"choices": [{"index": 0, "delta": delta, "finish_reason": finish_reason}],
|
||||
}
|
||||
|
||||
|
||||
def _handle_openai_normal_response(response: Dict[str, Any], model: str, finish_reason: str) -> Dict[str, Any]:
|
||||
text = _extract_text(response, model, stream=False)
|
||||
text, tool_calls = _extract_result(response, model, stream=False, gemini_format=False)
|
||||
return {
|
||||
"id": f"chatcmpl-{uuid.uuid4()}",
|
||||
"object": "chat.completion",
|
||||
"created": int(time.time()),
|
||||
"model": model,
|
||||
"choices": [{
|
||||
"index": 0,
|
||||
"message": {
|
||||
"role": "assistant",
|
||||
"content": text
|
||||
},
|
||||
"finish_reason": finish_reason
|
||||
}],
|
||||
"usage": {
|
||||
"prompt_tokens": 0,
|
||||
"completion_tokens": 0,
|
||||
"total_tokens": 0
|
||||
}
|
||||
"choices": [
|
||||
{
|
||||
"index": 0,
|
||||
"message": {"role": "assistant", "content": text, "tool_calls": tool_calls},
|
||||
"finish_reason": finish_reason,
|
||||
}
|
||||
],
|
||||
"usage": {"prompt_tokens": 0, "completion_tokens": 0, "total_tokens": 0},
|
||||
}
|
||||
|
||||
|
||||
@@ -127,8 +128,8 @@ def _handle_openai_normal_image_response(image_str: str,model: str,finish_reason
|
||||
}
|
||||
|
||||
|
||||
def _extract_text(response: Dict[str, Any], model: str, stream: bool = False) -> str:
|
||||
text = ""
|
||||
def _extract_result(response: Dict[str, Any], model: str, stream: bool = False, gemini_format: bool = False) -> tuple[str, List[Dict[str, Any]]]:
|
||||
text, tool_calls = "", []
|
||||
if stream:
|
||||
if response.get("candidates"):
|
||||
candidate = response["candidates"][0]
|
||||
@@ -212,6 +213,7 @@ def _extract_text(response: Dict[str, Any], model: str, stream: bool = False) ->
|
||||
else:
|
||||
text = ""
|
||||
text = _add_search_link_text(model, candidate, text)
|
||||
tool_calls = _extract_tool_calls(parts, gemini_format)
|
||||
else:
|
||||
if response.get("candidates"):
|
||||
candidate = response["candidates"][0]
|
||||
@@ -234,23 +236,65 @@ def _extract_text(response: Dict[str, Any], model: str, stream: bool = False) ->
|
||||
else:
|
||||
text = ""
|
||||
for part in candidate["content"]["parts"]:
|
||||
text += part["text"]
|
||||
text += part.get("text", "")
|
||||
text = _add_search_link_text(model, candidate, text)
|
||||
tool_calls = _extract_tool_calls(candidate["content"]["parts"], gemini_format)
|
||||
else:
|
||||
text = "暂无返回"
|
||||
return text
|
||||
return text, tool_calls
|
||||
|
||||
def _extract_tool_calls(parts: List[Dict[str, Any]], gemini_format: bool) -> List[Dict[str, Any]]:
|
||||
"""提取工具调用信息"""
|
||||
if not parts or not isinstance(parts, list):
|
||||
return []
|
||||
|
||||
letters = string.ascii_lowercase + string.digits
|
||||
|
||||
tool_calls = list()
|
||||
for i in range(len(parts)):
|
||||
part = parts[i]
|
||||
if not part or not isinstance(part, dict):
|
||||
continue
|
||||
|
||||
item = part.get("functionCall", {})
|
||||
if not item or not isinstance(item, dict):
|
||||
continue
|
||||
|
||||
if gemini_format:
|
||||
tool_calls.append(part)
|
||||
else:
|
||||
id = f"call_{''.join(random.sample(letters, 32))}"
|
||||
name = item.get("name", "")
|
||||
arguments = json.dumps(item.get("args", None) or {})
|
||||
|
||||
tool_calls.append(
|
||||
{
|
||||
"index": i,
|
||||
"id": id,
|
||||
"type": "function",
|
||||
"function": {"name": name, "arguments": arguments},
|
||||
}
|
||||
)
|
||||
|
||||
return tool_calls
|
||||
|
||||
|
||||
def _handle_gemini_stream_response(response: Dict[str, Any], model: str, stream: bool) -> Dict[str, Any]:
|
||||
text = _extract_text(response, model, stream=stream)
|
||||
content = {"parts": [{"text": text}], "role": "model"}
|
||||
text, tool_calls = _extract_result(response, model, stream=stream, gemini_format=True)
|
||||
if tool_calls:
|
||||
content = {"parts": tool_calls, "role": "model"}
|
||||
else:
|
||||
content = {"parts": [{"text": text}], "role": "model"}
|
||||
response["candidates"][0]["content"] = content
|
||||
return response
|
||||
|
||||
|
||||
def _handle_gemini_normal_response(response: Dict[str, Any], model: str, stream: bool) -> Dict[str, Any]:
|
||||
text = _extract_text(response, model, stream=stream)
|
||||
content = {"parts": [{"text": text}], "role": "model"}
|
||||
text, tool_calls = _extract_result(response, model, stream=stream, gemini_format=True)
|
||||
if tool_calls:
|
||||
content = {"parts": tool_calls, "role": "model"}
|
||||
else:
|
||||
content = {"parts": [{"text": text}], "role": "model"}
|
||||
response["candidates"][0]["content"] = content
|
||||
return response
|
||||
|
||||
|
||||
@@ -31,6 +31,12 @@ def _build_tools(model: str, payload: Dict[str, Any]) -> List[Dict[str, Any]]:
|
||||
tools.append({"code_execution": {}})
|
||||
if model.endswith("-search"):
|
||||
tools.append({"googleSearch": {}})
|
||||
|
||||
if payload and isinstance(payload, dict) and "tools" in payload:
|
||||
items = payload.get("tools", [])
|
||||
if items and isinstance(items, list):
|
||||
tools.extend(items)
|
||||
|
||||
return tools
|
||||
|
||||
|
||||
|
||||
@@ -1,7 +1,8 @@
|
||||
# app/services/chat_service.py
|
||||
|
||||
from copy import deepcopy
|
||||
import json
|
||||
from typing import Dict, Any, AsyncGenerator, List, Union
|
||||
from typing import Dict, Any, AsyncGenerator, List, Optional, Union
|
||||
from app.core.logger import get_openai_logger
|
||||
from app.services.chat.message_converter import OpenAIMessageConverter
|
||||
from app.services.chat.response_handler import OpenAIResponseHandler
|
||||
@@ -39,6 +40,25 @@ def _build_tools(
|
||||
tools.append({"code_execution": {}})
|
||||
if model.endswith("-search"):
|
||||
tools.append({"googleSearch": {}})
|
||||
|
||||
# 将 request 中的 tools 合并到 tools 中
|
||||
if request.tools:
|
||||
function_declarations = []
|
||||
for tool in request.tools:
|
||||
if not tool or not isinstance(tool, dict):
|
||||
continue
|
||||
|
||||
if tool.get("type", "") == "function" and tool.get("function"):
|
||||
function = deepcopy(tool.get("function"))
|
||||
parameters = function.get("parameters", {})
|
||||
if parameters.get("type") == "object" and not parameters.get("properties", {}):
|
||||
function.pop("parameters", None)
|
||||
|
||||
function_declarations.append(function)
|
||||
|
||||
if function_declarations:
|
||||
tools.append({"functionDeclarations": function_declarations})
|
||||
|
||||
return tools
|
||||
|
||||
|
||||
@@ -67,10 +87,10 @@ def _get_safety_settings(model: str) -> List[Dict[str, str]]:
|
||||
|
||||
|
||||
def _build_payload(
|
||||
request: ChatRequest, messages: List[Dict[str, Any]]
|
||||
request: ChatRequest, messages: List[Dict[str, Any]], instruction: Optional[Dict[str, Any]] = None
|
||||
) -> Dict[str, Any]:
|
||||
"""构建请求payload"""
|
||||
return {
|
||||
payload = {
|
||||
"contents": messages,
|
||||
"generationConfig": {
|
||||
"temperature": request.temperature,
|
||||
@@ -83,6 +103,16 @@ def _build_payload(
|
||||
"safetySettings": _get_safety_settings(request.model),
|
||||
}
|
||||
|
||||
if (
|
||||
instruction
|
||||
and isinstance(instruction, dict)
|
||||
and instruction.get("role") == "system"
|
||||
and instruction.get("parts")
|
||||
):
|
||||
payload["systemInstruction"] = instruction
|
||||
|
||||
return payload
|
||||
|
||||
|
||||
class OpenAIChatService:
|
||||
"""聊天服务"""
|
||||
@@ -100,10 +130,10 @@ class OpenAIChatService:
|
||||
) -> Union[Dict[str, Any], AsyncGenerator[str, None]]:
|
||||
"""创建聊天完成"""
|
||||
# 转换消息格式
|
||||
messages = self.message_converter.convert(request.messages)
|
||||
messages, instruction = self.message_converter.convert(request.messages)
|
||||
|
||||
# 构建请求payload
|
||||
payload = _build_payload(request, messages)
|
||||
payload = _build_payload(request, messages, instruction)
|
||||
|
||||
if request.stream:
|
||||
return self._handle_stream_completion(request.model, payload, api_key)
|
||||
|
||||
Reference in New Issue
Block a user