mirror of
https://github.com/snailyp/gemini-balance.git
synced 2026-07-04 06:11:32 +08:00
chore: add system instruction to enhance compliance with function call
This commit is contained in:
@@ -1,14 +1,16 @@
|
||||
# app/services/chat/message_converter.py
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import List, Dict, Any
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
SUPPORTED_ROLES = ["user", "model", "system"]
|
||||
|
||||
|
||||
class MessageConverter(ABC):
|
||||
"""消息转换器基类"""
|
||||
|
||||
@abstractmethod
|
||||
def convert(self, messages: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
|
||||
def convert(self, messages: List[Dict[str, Any]]) -> tuple[List[Dict[str, Any]], Optional[Dict[str, Any]]]:
|
||||
pass
|
||||
|
||||
|
||||
@@ -30,16 +32,19 @@ def _convert_image(image_url: str) -> Dict[str, Any]:
|
||||
class OpenAIMessageConverter(MessageConverter):
|
||||
"""OpenAI消息格式转换器"""
|
||||
|
||||
def convert(self, messages: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
|
||||
def convert(self, messages: List[Dict[str, Any]]) -> tuple[List[Dict[str, Any]], Optional[Dict[str, Any]]]:
|
||||
converted_messages = []
|
||||
for msg in messages:
|
||||
role = "user" if msg["role"] == "user" else "model"
|
||||
parts = []
|
||||
system_instruction = None
|
||||
|
||||
if isinstance(msg["content"], str):
|
||||
# 请求 gemini 接口时如果包含 content 字段但内容为空时会返回 400 错误,所以需要判断是否为空
|
||||
if msg["content"]:
|
||||
parts.append({"text": msg["content"]})
|
||||
for msg in messages:
|
||||
role = msg.get("role", "")
|
||||
if role not in SUPPORTED_ROLES:
|
||||
role = "model"
|
||||
|
||||
parts = []
|
||||
if isinstance(msg["content"], str) and msg["content"]:
|
||||
# 请求 gemini 接口时如果包含 content 字段但内容为空时会返回 400 错误,所以需要判断是否为空并移除
|
||||
parts.append({"text": msg["content"]})
|
||||
elif isinstance(msg["content"], list):
|
||||
for content in msg["content"]:
|
||||
if isinstance(content, str) and content:
|
||||
@@ -50,6 +55,10 @@ class OpenAIMessageConverter(MessageConverter):
|
||||
elif content["type"] == "image_url":
|
||||
parts.append(_convert_image(content["image_url"]["url"]))
|
||||
|
||||
converted_messages.append({"role": role, "parts": parts})
|
||||
if parts:
|
||||
if role == "system":
|
||||
system_instruction = {"role": "system", "parts": parts}
|
||||
else:
|
||||
converted_messages.append({"role": role, "parts": parts})
|
||||
|
||||
return converted_messages
|
||||
return converted_messages, system_instruction
|
||||
|
||||
@@ -2,7 +2,7 @@
|
||||
|
||||
from copy import deepcopy
|
||||
import json
|
||||
from typing import Dict, Any, AsyncGenerator, List, Union
|
||||
from typing import Dict, Any, AsyncGenerator, List, Optional, Union
|
||||
from app.core.logger import get_openai_logger
|
||||
from app.services.chat.message_converter import OpenAIMessageConverter
|
||||
from app.services.chat.response_handler import OpenAIResponseHandler
|
||||
@@ -87,10 +87,10 @@ def _get_safety_settings(model: str) -> List[Dict[str, str]]:
|
||||
|
||||
|
||||
def _build_payload(
|
||||
request: ChatRequest, messages: List[Dict[str, Any]]
|
||||
request: ChatRequest, messages: List[Dict[str, Any]], instruction: Optional[Dict[str, Any]] = None
|
||||
) -> Dict[str, Any]:
|
||||
"""构建请求payload"""
|
||||
return {
|
||||
payload = {
|
||||
"contents": messages,
|
||||
"generationConfig": {
|
||||
"temperature": request.temperature,
|
||||
@@ -103,6 +103,16 @@ def _build_payload(
|
||||
"safetySettings": _get_safety_settings(request.model),
|
||||
}
|
||||
|
||||
if (
|
||||
instruction
|
||||
and isinstance(instruction, dict)
|
||||
and instruction.get("role") == "system"
|
||||
and instruction.get("parts")
|
||||
):
|
||||
payload["systemInstruction"] = instruction
|
||||
|
||||
return payload
|
||||
|
||||
|
||||
class OpenAIChatService:
|
||||
"""聊天服务"""
|
||||
@@ -120,10 +130,10 @@ class OpenAIChatService:
|
||||
) -> Union[Dict[str, Any], AsyncGenerator[str, None]]:
|
||||
"""创建聊天完成"""
|
||||
# 转换消息格式
|
||||
messages = self.message_converter.convert(request.messages)
|
||||
messages, instruction = self.message_converter.convert(request.messages)
|
||||
|
||||
# 构建请求payload
|
||||
payload = _build_payload(request, messages)
|
||||
payload = _build_payload(request, messages, instruction)
|
||||
|
||||
if request.stream:
|
||||
return self._handle_stream_completion(request.model, payload, api_key)
|
||||
|
||||
Reference in New Issue
Block a user