Compare commits

..

7 Commits

Author SHA1 Message Date
snaily
5a44a76c48 Merge remote-tracking branch 'BetterAndBetterII/main' 2025-03-03 18:45:56 +08:00
Toddy
7b5b6c7d4c if role is tool then set to user 2025-03-03 08:23:04 +00:00
Yuzhong Zhang
68ed4da789 Update Dockerfile 2025-03-03 14:09:45 +08:00
Yuzhong Zhang
cdbca7ec62 优化dockerfile,增加docker-compose,async openai 2025-03-03 13:55:09 +08:00
Yuzhong Zhang
48d58ef2e8 异步生成完成 2025-03-03 13:41:06 +08:00
snaily
88d483c1ef Merge pull request #4 from toddyoe/main
chore: add system instruction to enhance compliance with function call
2025-02-27 19:17:39 +08:00
Toddy
8d48db026c chore: add system instruction to enhance compliance with function call 2025-02-27 10:35:25 +00:00
7 changed files with 71 additions and 29 deletions

View File

@@ -3,10 +3,10 @@ FROM python:3.10-slim
WORKDIR /app
# 复制所需文件到容器中
COPY ./app /app/app
COPY ./requirements.txt /app
RUN pip install --no-cache-dir -r requirements.txt
COPY ./app /app/app
ENV API_KEYS='["your_api_key_1"]'
ENV ALLOWED_TOKENS='["your_token_1"]'
ENV BASE_URL=https://generativelanguage.googleapis.com/v1beta

View File

@@ -63,7 +63,7 @@ async def generate_content(
logger.info(f"Using API key: {api_key}")
try:
response = chat_service.generate_content(
response = await chat_service.generate_content(
model=model_name,
request=request,
api_key=api_key
@@ -122,7 +122,7 @@ async def verify_key(api_key: str):
)
]
)
response = chat_service.generate_content(settings.TEST_MODEL,gemini_requset, api_key)
response = await chat_service.generate_content(settings.TEST_MODEL,gemini_requset, api_key)
if response:
return JSONResponse({"status": "valid"})
return JSONResponse({"status": "invalid"})

View File

@@ -24,13 +24,13 @@ class GeminiApiClient(ApiClient):
self.base_url = base_url
self.timeout = timeout
def generate_content(self, payload: Dict[str, Any], model: str, api_key: str) -> Dict[str, Any]:
async def generate_content(self, payload: Dict[str, Any], model: str, api_key: str) -> Dict[str, Any]:
timeout = httpx.Timeout(self.timeout, read=self.timeout)
if model.endswith("-search"):
model = model[:-7]
with httpx.Client(timeout=timeout) as client:
async with httpx.AsyncClient(timeout=timeout) as client:
url = f"{self.base_url}/models/{model}:generateContent?key={api_key}"
response = client.post(url, json=payload)
response = await client.post(url, json=payload)
if response.status_code != 200:
error_content = response.text
raise Exception(f"API call failed with status code {response.status_code}, {error_content}")

View File

@@ -1,14 +1,16 @@
# app/services/chat/message_converter.py
from abc import ABC, abstractmethod
from typing import List, Dict, Any
from typing import Any, Dict, List, Optional
SUPPORTED_ROLES = ["user", "model", "system"]
class MessageConverter(ABC):
"""消息转换器基类"""
@abstractmethod
def convert(self, messages: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
def convert(self, messages: List[Dict[str, Any]]) -> tuple[List[Dict[str, Any]], Optional[Dict[str, Any]]]:
pass
@@ -30,16 +32,26 @@ def _convert_image(image_url: str) -> Dict[str, Any]:
class OpenAIMessageConverter(MessageConverter):
"""OpenAI消息格式转换器"""
def convert(self, messages: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
def convert(self, messages: List[Dict[str, Any]]) -> tuple[List[Dict[str, Any]], Optional[Dict[str, Any]]]:
converted_messages = []
for msg in messages:
role = "user" if msg["role"] == "user" else "model"
parts = []
system_instruction = None
if isinstance(msg["content"], str):
# 请求 gemini 接口时如果包含 content 字段但内容为空时会返回 400 错误,所以需要判断是否为空
if msg["content"]:
parts.append({"text": msg["content"]})
for idx, msg in enumerate(messages):
role = msg.get("role", "")
if role not in SUPPORTED_ROLES:
if role == "tool":
role = "user"
else:
# 如果是最后一条消息,则认为是用户消息
if idx == len(messages) - 1:
role = "user"
else:
role = "model"
parts = []
if isinstance(msg["content"], str) and msg["content"]:
# 请求 gemini 接口时如果包含 content 字段但内容为空时会返回 400 错误,所以需要判断是否为空并移除
parts.append({"text": msg["content"]})
elif isinstance(msg["content"], list):
for content in msg["content"]:
if isinstance(content, str) and content:
@@ -50,6 +62,10 @@ class OpenAIMessageConverter(MessageConverter):
elif content["type"] == "image_url":
parts.append(_convert_image(content["image_url"]["url"]))
converted_messages.append({"role": role, "parts": parts})
if parts:
if role == "system":
system_instruction = {"role": "system", "parts": parts}
else:
converted_messages.append({"role": role, "parts": parts})
return converted_messages
return converted_messages, system_instruction

View File

@@ -79,10 +79,10 @@ class GeminiChatService:
self.key_manager = key_manager
self.response_handler = GeminiResponseHandler()
def generate_content(self, model: str, request: GeminiRequest, api_key: str) -> Dict[str, Any]:
async def generate_content(self, model: str, request: GeminiRequest, api_key: str) -> Dict[str, Any]:
"""生成内容"""
payload = _build_payload(model, request)
response = self.api_client.generate_content(payload, model, api_key)
response = await self.api_client.generate_content(payload, model, api_key)
return self.response_handler.handle_response(response, model, stream=False)
async def stream_generate_content(self, model: str, request: GeminiRequest, api_key: str) -> AsyncGenerator[str, None]:

View File

@@ -2,7 +2,7 @@
from copy import deepcopy
import json
from typing import Dict, Any, AsyncGenerator, List, Union
from typing import Dict, Any, AsyncGenerator, List, Optional, Union
from app.core.logger import get_openai_logger
from app.services.chat.message_converter import OpenAIMessageConverter
from app.services.chat.response_handler import OpenAIResponseHandler
@@ -57,7 +57,14 @@ def _build_tools(
function_declarations.append(function)
if function_declarations:
tools.append({"functionDeclarations": function_declarations})
# 按照 function 的 name 去重
names, functions = set(), []
for item in function_declarations:
if item.get("name") not in names:
names.add(item.get("name"))
functions.append(item)
tools.append({"functionDeclarations": functions})
return tools
@@ -87,10 +94,10 @@ def _get_safety_settings(model: str) -> List[Dict[str, str]]:
def _build_payload(
request: ChatRequest, messages: List[Dict[str, Any]]
request: ChatRequest, messages: List[Dict[str, Any]], instruction: Optional[Dict[str, Any]] = None
) -> Dict[str, Any]:
"""构建请求payload"""
return {
payload = {
"contents": messages,
"generationConfig": {
"temperature": request.temperature,
@@ -103,6 +110,16 @@ def _build_payload(
"safetySettings": _get_safety_settings(request.model),
}
if (
instruction
and isinstance(instruction, dict)
and instruction.get("role") == "system"
and instruction.get("parts")
):
payload["systemInstruction"] = instruction
return payload
class OpenAIChatService:
"""聊天服务"""
@@ -120,20 +137,20 @@ class OpenAIChatService:
) -> Union[Dict[str, Any], AsyncGenerator[str, None]]:
"""创建聊天完成"""
# 转换消息格式
messages = self.message_converter.convert(request.messages)
messages, instruction = self.message_converter.convert(request.messages)
# 构建请求payload
payload = _build_payload(request, messages)
payload = _build_payload(request, messages, instruction)
if request.stream:
return self._handle_stream_completion(request.model, payload, api_key)
return self._handle_normal_completion(request.model, payload, api_key)
return await self._handle_normal_completion(request.model, payload, api_key)
def _handle_normal_completion(
async def _handle_normal_completion(
self, model: str, payload: Dict[str, Any], api_key: str
) -> Dict[str, Any]:
"""处理普通聊天完成"""
response = self.api_client.generate_content(payload, model, api_key)
response = await self.api_client.generate_content(payload, model, api_key)
return self.response_handler.handle_response(
response, model, stream=False, finish_reason="stop"
)

9
docker-compose.yml Normal file
View File

@@ -0,0 +1,9 @@
version: '3'
services:
gemini-balance:
build: .
ports:
- "8000:8000"
env_file:
- .env