Files
gemini-balance/app/services/chat/message_converter.py
zhanghaoyu 28e67cc3fa 1. modify IMAGE_URL_PATTERN
2. modify import
2025-03-15 12:37:56 +08:00

142 lines
4.7 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
# app/services/chat/message_converter.py
from abc import ABC, abstractmethod
import re
from typing import Any, Dict, List, Optional
import requests
import base64
# Roles that the converter in this module keeps as-is; any other role value
# is remapped to "user" or "model" before the message is emitted.
SUPPORTED_ROLES = ["user", "model", "system"]
# Matches the literal markup "[image](<url>)" and captures <url> (non-greedy,
# i.e. everything up to the first closing parenthesis).
IMAGE_URL_PATTERN = r'\[image\]\((.*?)\)'
class MessageConverter(ABC):
    """Abstract base class for message converters."""

    @abstractmethod
    def convert(
        self,
        messages: List[Dict[str, Any]],
    ) -> tuple[List[Dict[str, Any]], Optional[Dict[str, Any]]]:
        """Convert a list of message dicts; subclasses provide the logic.

        Args:
            messages: The messages to convert.

        Returns:
            Per the annotation, a pair of (converted messages, optional
            extra dict). This base implementation does nothing and
            returns ``None``.
        """
def _convert_image(image_url: str) -> Dict[str, Any]:
if image_url.startswith("data:image"):
return {
"inline_data": {
"mime_type": "image/jpeg",
"data": image_url.split(",")[1]
}
}
return {
"image_url": {
"url": image_url
}
}
def _convert_image_to_base64(url: str, timeout: float = 10.0) -> str:
    """Download an image and return its content as a base64 string.

    Args:
        url: Address of the image to fetch with an HTTP GET.
        timeout: Seconds to wait for the server before giving up. Without
            a timeout ``requests`` waits indefinitely on a stalled server.

    Returns:
        str: The response body, base64-encoded and decoded as UTF-8 text.

    Raises:
        Exception: If the response status code is not 200.
        requests.RequestException: Propagated from ``requests.get`` on
            connection failures or when the timeout elapses.
    """
    # NOTE(review): the URL may originate from message text; confirm whether
    # fetching arbitrary hosts from this service is acceptable (SSRF).
    response = requests.get(url, timeout=timeout)
    if response.status_code != 200:
        raise Exception(f"Failed to fetch image: {response.status_code}")
    # Encode the raw image bytes as base64 text.
    return base64.b64encode(response.content).decode('utf-8')
def _process_text_with_image(text: str) -> List[Dict[str, Any]]:
    """Split text into text parts and inline-image parts.

    Every ``[image](url)`` occurrence (see ``IMAGE_URL_PATTERN``) is
    downloaded and replaced by an ``inlineData`` part. Text before, between
    and after the images is kept, in order, as ``text`` parts. If a download
    fails, that occurrence is left untouched inside the surrounding text.

    Args:
        text: Text that may contain image markup.

    Returns:
        List[Dict[str, Any]]: The parts in their original order. When no
        image could be converted, this is a single ``{"text": text}`` part.
    """
    import logging

    parts: List[Dict[str, Any]] = []
    cursor = 0  # start of the text not yet emitted

    for match in re.finditer(IMAGE_URL_PATTERN, text):
        try:
            base64_data = _convert_image_to_base64(match.group(1))
        except Exception as exc:
            # Best effort: keep the markup as plain text. The cursor is not
            # advanced, so this span is emitted with the following text.
            logging.getLogger(__name__).warning(
                "Failed to fetch image, keeping markup as text: %s", exc
            )
            continue

        preceding = text[cursor:match.start()]
        # "![image](url)" is the usual markdown form; the "!" belongs to the
        # image markup and must not leak into the text part.
        if preceding.endswith("!"):
            preceding = preceding[:-1]
        if preceding.strip():
            parts.append({"text": preceding})

        parts.append({
            "inlineData": {
                "mimeType": "image/jpeg",
                "data": base64_data
            }
        })
        cursor = match.end()

    trailing = text[cursor:]
    # Skip blank leftovers after an image, but always return at least one
    # part so input without any converted image still yields {"text": text}.
    if trailing.strip() or not parts:
        parts.append({"text": trailing})
    return parts
class OpenAIMessageConverter(MessageConverter):
    """Converter for OpenAI-format chat messages."""

    def convert(self, messages: List[Dict[str, Any]]) -> tuple[List[Dict[str, Any]], Optional[Dict[str, Any]]]:
        """Convert OpenAI-style messages into role/parts dicts.

        Args:
            messages: Message dicts; ``role`` and ``content`` are read from
                each. ``content`` may be a string or a list of strings and
                ``{"type": ...}`` dicts.

        Returns:
            A pair ``(converted_messages, system_instruction)``.
            ``converted_messages`` holds ``{"role", "parts"}`` dicts for all
            non-system messages that produced at least one part.
            ``system_instruction`` is ``None`` when there are no system
            parts, otherwise ``{"role": "system", "parts": [...]}`` with the
            parts of all system messages concatenated.
        """
        converted_messages = []
        system_instruction_parts = []
        last_index = len(messages) - 1

        for idx, msg in enumerate(messages):
            original_role = msg.get("role", "")
            role = original_role
            if role not in SUPPORTED_ROLES:
                # "tool" messages and an unsupported role on the final
                # message are treated as user input; everything else is
                # treated as model output.
                if role == "tool" or idx == last_index:
                    role = "user"
                else:
                    role = "model"

            # The paragraph split must be keyed on the ORIGINAL role:
            # "assistant" is not a supported role, so it has already been
            # remapped above and `role` can never equal "assistant" here.
            parts = self._build_parts(
                msg.get("content"),
                split_paragraphs=(original_role == "assistant"),
            )
            if not parts:
                # Messages without usable content are dropped; an empty
                # content field makes the Gemini API answer with HTTP 400.
                continue

            if role == "system":
                system_instruction_parts.extend(parts)
            else:
                converted_messages.append({"role": role, "parts": parts})

        system_instruction = (
            None
            if not system_instruction_parts
            else {
                "role": "system",
                "parts": system_instruction_parts,
            }
        )
        return converted_messages, system_instruction

    @staticmethod
    def _build_parts(content: Any, *, split_paragraphs: bool = False) -> List[Dict[str, Any]]:
        """Build the parts list for one message's ``content``.

        Args:
            content: A string, a list of strings/dicts, or anything else
                (including ``None``), which yields no parts.
            split_paragraphs: When true and ``content`` is a string, split
                it on blank lines (``"\\n\\n"``) and process each non-blank
                chunk separately.

        Returns:
            The list of parts; empty when there is nothing to send.
        """
        parts: List[Dict[str, Any]] = []

        if isinstance(content, str):
            if not content:
                return parts
            if split_paragraphs:
                for chunk in content.split("\n\n"):
                    if not chunk.strip():  # skip blank chunks
                        continue
                    # Each chunk may contain image markup.
                    parts.extend(_process_text_with_image(chunk))
            else:
                parts.extend(_process_text_with_image(content))
        elif isinstance(content, list):
            for item in content:
                if isinstance(item, str):
                    if item:
                        parts.append({"text": item})
                elif isinstance(item, dict):
                    item_type = item.get("type")
                    if item_type == "text" and item.get("text"):
                        parts.append({"text": item["text"]})
                    elif item_type == "image_url":
                        parts.append(_convert_image(item["image_url"]["url"]))

        return parts