Files
gemini-balance/app/services/chat_service.py

302 lines
13 KiB
Python

import httpx
import json
import time
import uuid
from typing import Dict, Any, Optional, AsyncGenerator, Union
import openai
from app.core.config import settings
from app.core.logger import get_chat_logger
# Module-level logger used throughout this file; created by the project's get_chat_logger() helper.
logger = get_chat_logger()
class ChatService:
    """Bridge OpenAI-style chat-completion requests to the Gemini REST API.

    Incoming OpenAI messages are converted to Gemini ``contents``, posted to
    the ``generateContent`` / ``streamGenerateContent`` endpoints, and the
    replies are converted back into OpenAI ``chat.completion`` (or
    ``chat.completion.chunk``) dictionaries.
    """

    _API_ROOT = "https://generativelanguage.googleapis.com/v1beta/models"
    _SEARCH_SUFFIX = "-search"
    _MAX_RETRIES = 3
    _SAFETY_CATEGORIES = (
        "HARM_CATEGORY_HARASSMENT",
        "HARM_CATEGORY_HATE_SPEECH",
        "HARM_CATEGORY_SEXUALLY_EXPLICIT",
        "HARM_CATEGORY_DANGEROUS_CONTENT",
        "HARM_CATEGORY_CIVIC_INTEGRITY",
    )

    def __init__(self, base_url: str, key_manager=None):
        """Store the service configuration.

        Args:
            base_url: Base URL kept on the instance; only the disabled OpenAI
                passthrough (commented out below) refers to it.
            key_manager: Optional object whose awaitable
                ``handle_api_failure(key)`` result is used as the API key for
                the next streaming attempt.
        """
        self.base_url = base_url
        self.key_manager = key_manager

    def convert_messages_to_gemini_format(self, messages: list) -> list:
        """Convert OpenAI chat messages to Gemini ``contents`` entries.

        Args:
            messages: OpenAI-style message dicts. ``content`` may be a string
                or a list of strings / ``text`` items / ``image_url`` items.

        Returns:
            A list of ``{"role": ..., "parts": [...]}`` dicts. Messages whose
            content is missing or of another type produce an empty ``parts``.
        """
        converted_messages = []
        for msg in messages:
            # NOTE(review): every role other than "user" (assistant, system,
            # tool, ...) is sent as "model" - confirm this is intended for
            # system prompts.
            role = "user" if msg["role"] == "user" else "model"
            content = msg.get("content")
            parts = []
            if isinstance(content, str):
                parts.append({"text": content})
            elif isinstance(content, list):
                for item in content:
                    part = self._convert_content_item(item)
                    if part is not None:
                        parts.append(part)
            converted_messages.append({"role": role, "parts": parts})
        return converted_messages

    @staticmethod
    def _convert_content_item(item) -> Optional[Dict[str, Any]]:
        """Convert one element of a list-valued ``content``; None if unsupported."""
        if isinstance(item, str):
            return {"text": item}
        if not isinstance(item, dict):
            return None
        item_type = item.get("type")
        if item_type == "text":
            return {"text": item["text"]}
        if item_type != "image_url":
            return None
        image_url = item["image_url"]["url"]
        if image_url.startswith("data:image"):
            # Data URL layout: "data:<mime>[;base64],<payload>". Take the MIME
            # type from the header instead of assuming JPEG for every image.
            header, _, payload = image_url.partition(",")
            mime_type = header[len("data:"):].split(";", 1)[0]
            if "/" not in mime_type:
                mime_type = "image/jpeg"
            return {"inline_data": {"mime_type": mime_type, "data": payload}}
        # NOTE(review): a plain http(s) URL is forwarded unchanged in the
        # ``data`` field, as before; the image is not downloaded or encoded
        # here - confirm the upstream API accepts this.
        return {"inline_data": {"mime_type": "image/jpeg", "data": image_url}}

    def _format_part(self, part: Dict[str, Any]) -> str:
        """Render a single Gemini response part as text."""
        if "text" in part:
            return part.get("text") or ""
        for key in ("executableCode", "codeExecution"):
            if key in part:
                return self.format_code_block(part[key])
        for key in ("executableCodeResult", "codeExecutionResult"):
            if key in part:
                return self.format_execution_result(part[key])
        return ""

    def _extract_text(self, response: Dict[str, Any]) -> str:
        """Concatenate the rendered parts of the first candidate ("" if none)."""
        candidates = response.get("candidates")
        if not candidates:
            return ""
        content = candidates[0].get("content") or {}
        parts = content.get("parts") or []
        return "".join(self._format_part(part) for part in parts)

    def convert_gemini_response_to_openai(
        self,
        response: Dict[str, Any],
        model: str,
        stream: bool = False,
        finish_reason: Optional[str] = None,
    ) -> Optional[Dict[str, Any]]:
        """Convert a Gemini response (or stream chunk) to OpenAI format.

        Args:
            response: Decoded Gemini JSON body or one decoded SSE chunk.
            model: Model name echoed back in the result.
            stream: True to build a ``chat.completion.chunk``, False to build
                a full ``chat.completion``.
            finish_reason: Value placed in the choice's ``finish_reason``.

        Returns:
            The OpenAI-style dict. In stream mode None is returned when the
            chunk cannot be converted (the error is logged). All parts of the
            first candidate are rendered and concatenated; a response with no
            candidates or no parts yields empty content.
        """
        if stream:
            try:
                text = self._extract_text(response)
                return {
                    "id": f"chatcmpl-{uuid.uuid4()}",
                    "object": "chat.completion.chunk",
                    "created": int(time.time()),
                    "model": model,
                    "choices": [
                        {
                            "index": 0,
                            "delta": {"content": text} if text else {},
                            "finish_reason": finish_reason,
                        }
                    ],
                }
            except Exception as e:
                logger.error(f"Error converting Gemini response: {str(e)}")
                logger.debug(f"Raw response: {response}")
                return None

        # Token counts are reported by Gemini under "usageMetadata"; fall back
        # to zero when the field is absent.
        usage = response.get("usageMetadata") or {}
        return {
            "id": f"chatcmpl-{uuid.uuid4()}",
            "object": "chat.completion",
            "created": int(time.time()),
            "model": model,
            "choices": [
                {
                    "index": 0,
                    "message": {
                        "role": "assistant",
                        "content": self._extract_text(response),
                    },
                    "finish_reason": finish_reason,
                }
            ],
            "usage": {
                "prompt_tokens": usage.get("promptTokenCount", 0),
                "completion_tokens": usage.get("candidatesTokenCount", 0),
                "total_tokens": usage.get("totalTokenCount", 0),
            },
        }

    async def create_chat_completion(
        self,
        messages: list,
        model: str,
        temperature: float,
        stream: bool,
        api_key: str,
        tools: Optional[list] = None,
        tool_choice: Optional[str] = None,
    ) -> Union[Dict[str, Any], AsyncGenerator[str, None]]:
        """Create a chat completion through the Gemini API.

        Models whose name ends in ``-search`` get the ``googleSearch`` tool;
        other models get ``code_execution`` when
        ``settings.TOOLS_CODE_EXECUTION_ENABLED`` is set.

        Args:
            messages: OpenAI-style messages.
            model: Requested model name (optionally with a ``-search`` suffix).
            temperature: Sampling temperature forwarded to Gemini.
            stream: True to get an async generator of SSE strings.
            api_key: Gemini API key for the first attempt.
            tools: Extra tool entries to send; the list is not modified.
            tool_choice: Accepted for interface compatibility; currently unused.

        Returns:
            An OpenAI-style completion dict, or an async generator of SSE
            ``data:`` strings when ``stream`` is True.
        """
        # Work on a copy so the caller's list does not grow on every request.
        request_tools = list(tools) if tools else []
        is_search_model = model.endswith(self._SEARCH_SUFFIX)
        if settings.TOOLS_CODE_EXECUTION_ENABLED and not is_search_model:
            request_tools.append({"code_execution": {}})
        if is_search_model:
            request_tools.append({"googleSearch": {}})
        return await self._gemini_chat_completion(
            messages, model, temperature, stream, api_key, request_tools
        )
        # else:
        #     return await self._openai_chat_completion(
        #         messages, model, temperature, stream, api_key, tools
        #     )

    def _build_payload(
        self, messages: list, temperature: float, tools: Optional[list], stream: bool
    ) -> Dict[str, Any]:
        """Build the JSON body shared by the streaming and non-streaming calls."""
        request_tools = list(tools) if tools else []
        if not stream:
            # Code execution is only used in streaming mode. Filter instead of
            # list.remove() so a missing entry (e.g. "-search" models) is not
            # an error and the caller's list is left untouched.
            request_tools = [
                tool for tool in request_tools if tool != {"code_execution": {}}
            ]
        return {
            "contents": self.convert_messages_to_gemini_format(messages),
            "generationConfig": {"temperature": temperature},
            "tools": request_tools,
            "safetySettings": [
                {"category": category, "threshold": "BLOCK_NONE"}
                for category in self._SAFETY_CATEGORIES
            ],
        }

    async def _next_api_key(self, failed_key: str) -> str:
        """Report a failed key to the key manager and return the key to use next."""
        if self.key_manager is None:
            # No manager configured: retry with the same key.
            return failed_key
        return await self.key_manager.handle_api_failure(failed_key)

    async def _gemini_chat_completion(
        self,
        messages: list,
        model: str,
        temperature: float,
        stream: bool,
        api_key: str,
        tools: Optional[list] = None,
    ) -> Union[Dict[str, Any], AsyncGenerator[str, None]]:
        """Send the request to Gemini and return the OpenAI-style result.

        Args:
            messages: OpenAI-style messages.
            model: Requested model name; a ``-search`` suffix is stripped
                before calling Gemini but kept in the returned ``model`` field.
            temperature: Sampling temperature.
            stream: True to return an async generator of SSE strings.
            api_key: Gemini API key for the first attempt.
            tools: Tool entries for the request; may be None.

        Returns:
            A completion dict, or an async generator when ``stream`` is True.

        Raises:
            Exception: In non-stream mode, when the request fails or Gemini
                answers with a non-200 status.
        """
        if model.endswith(self._SEARCH_SUFFIX):
            gemini_model = model[: -len(self._SEARCH_SUFFIX)]
        else:
            gemini_model = model
        payload = self._build_payload(messages, temperature, tools, stream)
        if stream:
            return self._stream_completion(gemini_model, model, payload, api_key)
        return await self._single_completion(gemini_model, model, payload, api_key)

    async def _stream_completion(
        self, gemini_model: str, model: str, payload: Dict[str, Any], api_key: str
    ) -> AsyncGenerator[str, None]:
        """Yield OpenAI-style SSE strings for a streaming Gemini call.

        Up to ``_MAX_RETRIES`` attempts are made, switching API key between
        attempts. Once any chunk has been sent to the consumer the call is not
        retried, because replaying the request would duplicate output.
        """
        current_api_key = api_key
        retries = 0
        emitted = False
        while retries < self._MAX_RETRIES:
            try:
                async with httpx.AsyncClient() as client:
                    stream_url = (
                        f"{self._API_ROOT}/{gemini_model}"
                        f":streamGenerateContent?alt=sse&key={current_api_key}"
                    )
                    async with client.stream("POST", stream_url, json=payload) as response:
                        if response.status_code != 200:
                            # The body of a streamed response must be read
                            # explicitly before ``.text`` can be accessed.
                            await response.aread()
                            error_msg = response.text
                            logger.error(f"API error: {response.status_code}, {error_msg}")
                            if retries < self._MAX_RETRIES - 1:
                                current_api_key = await self._next_api_key(current_api_key)
                                retries += 1
                                continue
                            logger.error(
                                f"Max retries reached. Final error: {response.status_code}"
                            )
                            error_body = json.dumps(
                                {"error": f"API error: {response.status_code}"}
                            )
                            yield f"data: {error_body}\n\n"
                            return
                        async for line in response.aiter_lines():
                            if not line.startswith("data: "):
                                continue
                            try:
                                chunk = json.loads(line[6:])
                            except json.JSONDecodeError:
                                # Skip lines that are not valid JSON.
                                continue
                            openai_chunk = self.convert_gemini_response_to_openai(
                                chunk, model, stream=True, finish_reason=None
                            )
                            if openai_chunk:
                                emitted = True
                                yield f"data: {json.dumps(openai_chunk)}\n\n"
                        # Closing chunk carrying finish_reason, then the
                        # OpenAI end-of-stream marker.
                        final_chunk = self.convert_gemini_response_to_openai(
                            {}, model, stream=True, finish_reason="stop"
                        )
                        yield f"data: {json.dumps(final_chunk)}\n\n"
                        yield "data: [DONE]\n\n"
                        return
            except Exception as e:
                if emitted:
                    logger.error(
                        f"Stream failed after partial output, not retrying: {str(e)}"
                    )
                    yield f"data: {json.dumps({'error': str(e)})}\n\n"
                    return
                logger.warning(f"Stream error: {str(e)}, attempting retry {retries + 1}")
                if retries < self._MAX_RETRIES - 1:
                    current_api_key = await self._next_api_key(current_api_key)
                    retries += 1
                    continue
                logger.error(f"Max retries reached. Final error: {str(e)}")
                yield f"data: {json.dumps({'error': str(e)})}\n\n"
                return

    async def _single_completion(
        self, gemini_model: str, model: str, payload: Dict[str, Any], api_key: str
    ) -> Dict[str, Any]:
        """Perform one non-streaming Gemini call; log and re-raise any failure."""
        try:
            async with httpx.AsyncClient() as client:
                url = f"{self._API_ROOT}/{gemini_model}:generateContent?key={api_key}"
                response = await client.post(url, json=payload)
                if response.status_code != 200:
                    error_text = response.text
                    error_code = response.status_code
                    raise Exception(
                        f"API调用错误 - 状态码: {error_code}, 响应内容: {error_text}"
                    )
                gemini_response = response.json()
                return self.convert_gemini_response_to_openai(
                    gemini_response, model, finish_reason="stop"
                )
        except Exception as e:
            logger.error(f"Error in non-stream completion: {str(e)}")
            raise

    # NOTE: disabled OpenAI passthrough, retained for reference.
    # async def _openai_chat_completion(
    #     self,
    #     messages: list,
    #     model: str,
    #     temperature: float,
    #     stream: bool,
    #     api_key: str,
    #     tools: Optional[list] = None,
    # ) -> Union[Dict[str, Any], AsyncGenerator[str, None]]:
    #     """Handle OpenAI API chat completion"""
    #     client = openai.OpenAI(api_key=api_key, base_url=self.base_url)
    #     if tools:
    #         response = client.chat.completions.create(
    #             model=model,
    #             messages=messages,
    #             temperature=temperature,
    #             stream=stream,
    #             tools=tools,
    #         )
    #     else:
    #         response = client.chat.completions.create(
    #             model=model, messages=messages, temperature=temperature, stream=stream
    #         )
    #     if stream:
    #         async def generate():
    #             for chunk in response:
    #                 yield f"data: {chunk.model_dump_json()}\n\n"
    #         return generate()
    #     return response

    def format_code_block(self, code_data: dict) -> str:
        """Render an executable-code part as a fenced Markdown code block.

        Args:
            code_data: Dict with optional ``language`` and ``code`` entries.

        Returns:
            A labelled code fence tagged with the lower-cased language.
        """
        language = (code_data.get("language") or "").lower()
        code = (code_data.get("code") or "").strip()
        return f"\n【代码执行】\n```{language}\n{code}\n```\n"

    def format_execution_result(self, result_data: dict) -> str:
        """Render a code-execution result part as Markdown.

        Args:
            result_data: Dict with optional ``outcome`` and ``output`` entries.

        Returns:
            The outcome line followed by the output in a fenced block. The
            fences sit on their own lines so the first line of the output is
            not read as the fence's language tag and the block closes properly.
        """
        outcome = result_data.get("outcome", "")
        output = (result_data.get("output") or "").strip()
        return f"\n【执行结果】\noutcome: {outcome}\noutput:\n```\n{output}\n```\n"