feat(image): 支持多模态模型输入base64格式图片

- 在消息转换中,增加对 `data:image/png;base64,...` 格式图片的支持,允许用户直接在输入中提供base64编码的图片。
- 调整图片处理逻辑,使其能够根据模型名称判断是否启用多模态能力,避免非多模态模型错误处理图片链接。
- 当未配置图床时,模型输出的图片将回退为base64格式,确保图片内容始终可用。
- 优化了相关函数的参数传递和代码格式,提高了代码的可读性和健壮性。
This commit is contained in:
snaily
2025-08-31 21:39:12 +08:00
parent b0127e6fc2
commit 611559d298
6 changed files with 181 additions and 110 deletions

View File

@@ -27,7 +27,7 @@ class MessageConverter(ABC):
@abstractmethod
def convert(
self, messages: List[Dict[str, Any]]
self, messages: List[Dict[str, Any]], model: str
) -> tuple[List[Dict[str, Any]], Optional[Dict[str, Any]]]:
pass
@@ -84,7 +84,7 @@ def _convert_image_to_base64(url: str) -> str:
raise Exception(f"Failed to fetch image: {response.status_code}")
def _process_text_with_image(text: str) -> List[Dict[str, Any]]:
def _process_text_with_image(text: str, model: str) -> List[Dict[str, Any]]:
"""
处理可能包含图片URL的文本,提取图片并转换为base64
@@ -94,17 +94,31 @@ def _process_text_with_image(text: str) -> List[Dict[str, Any]]:
Returns:
List[Dict[str, Any]]: 包含文本和图片的部分列表
"""
# 如果模型名中没有包含image,当作普通文本处理
if "image" not in model:
return [{"text": text}]
parts = []
img_url_match = re.search(IMAGE_URL_PATTERN, text)
if img_url_match:
# 提取URL
img_url = img_url_match.group(2)
# 将URL对应的图片转换为base64
# 先判断是否是base64 URL:如果是则直接使用;如果不是,则将URL对应的图片转换为base64
try:
base64_data = _convert_image_to_base64(img_url)
parts.append(
{"inline_data": {"mimeType": "image/png", "data": base64_data}}
)
base64_url_match = re.search(DATA_URL_PATTERN, img_url)
if base64_url_match:
parts.append(
{
"inline_data": {
"mimeType": base64_url_match.group(1),
"data": base64_url_match.group(2),
}
}
)
else:
base64_data = _convert_image_to_base64(img_url)
parts.append(
{"inline_data": {"mimeType": "image/png", "data": base64_data}}
)
except Exception:
# 如果转换失败,回退到文本模式
parts.append({"text": text})
@@ -145,7 +159,7 @@ class OpenAIMessageConverter(MessageConverter):
raise
def convert(
self, messages: List[Dict[str, Any]]
self, messages: List[Dict[str, Any]], model: str
) -> tuple[List[Dict[str, Any]], Optional[Dict[str, Any]]]:
converted_messages = []
system_instruction_parts = []
@@ -296,7 +310,7 @@ class OpenAIMessageConverter(MessageConverter):
elif (
"content" in msg and isinstance(msg["content"], str) and msg["content"]
):
parts.extend(_process_text_with_image(msg["content"]))
parts.extend(_process_text_with_image(msg["content"], model))
elif "tool_calls" in msg and isinstance(msg["tool_calls"], list):
# Keep existing tool call processing
for tool_call in msg["tool_calls"]:

View File

@@ -8,9 +8,9 @@ from abc import ABC, abstractmethod
from typing import Any, Dict, List, Optional
from app.config.config import settings
from app.utils.uploader import ImageUploaderFactory
from app.log.logger import get_openai_logger
from app.utils.helpers import is_image_upload_configured
from app.utils.uploader import ImageUploaderFactory
logger = get_openai_logger()
@@ -33,7 +33,11 @@ class GeminiResponseHandler(ResponseHandler):
self.thinking_status = False
def handle_response(
self, response: Dict[str, Any], model: str, stream: bool = False, usage_metadata: Optional[Dict[str, Any]] = None
self,
response: Dict[str, Any],
model: str,
stream: bool = False,
usage_metadata: Optional[Dict[str, Any]] = None,
) -> Dict[str, Any]:
if stream:
return _handle_gemini_stream_response(response, model, stream)
@@ -41,7 +45,10 @@ class GeminiResponseHandler(ResponseHandler):
def _handle_openai_stream_response(
response: Dict[str, Any], model: str, finish_reason: str, usage_metadata: Optional[Dict[str, Any]]
response: Dict[str, Any],
model: str,
finish_reason: str,
usage_metadata: Optional[Dict[str, Any]],
) -> Dict[str, Any]:
choices = []
candidates = response.get("candidates", [])
@@ -55,15 +62,15 @@ def _handle_openai_stream_response(
if not text and not tool_calls and not reasoning_content:
delta = {}
else:
delta = {"content": text, "reasoning_content": reasoning_content, "role": "assistant"}
delta = {
"content": text,
"reasoning_content": reasoning_content,
"role": "assistant",
}
if tool_calls:
delta["tool_calls"] = tool_calls
choice = {
"index": index,
"delta": delta,
"finish_reason": finish_reason
}
choice = {"index": index, "delta": delta, "finish_reason": finish_reason}
choices.append(choice)
template_chunk = {
@@ -74,16 +81,23 @@ def _handle_openai_stream_response(
"choices": choices,
}
if usage_metadata:
template_chunk["usage"] = {"prompt_tokens": usage_metadata.get("promptTokenCount", 0), "completion_tokens": usage_metadata.get("candidatesTokenCount",0), "total_tokens": usage_metadata.get("totalTokenCount", 0)}
template_chunk["usage"] = {
"prompt_tokens": usage_metadata.get("promptTokenCount", 0),
"completion_tokens": usage_metadata.get("candidatesTokenCount", 0),
"total_tokens": usage_metadata.get("totalTokenCount", 0),
}
return template_chunk
def _handle_openai_normal_response(
response: Dict[str, Any], model: str, finish_reason: str, usage_metadata: Optional[Dict[str, Any]]
response: Dict[str, Any],
model: str,
finish_reason: str,
usage_metadata: Optional[Dict[str, Any]],
) -> Dict[str, Any]:
choices = []
candidates = response.get("candidates", [])
for i, candidate in enumerate(candidates):
text, reasoning_content, tool_calls, _ = _extract_result(
{"candidates": [candidate]}, model, stream=False, gemini_format=False
@@ -106,7 +120,11 @@ def _handle_openai_normal_response(
"created": int(time.time()),
"model": model,
"choices": choices,
"usage": {"prompt_tokens": usage_metadata.get("promptTokenCount", 0), "completion_tokens": usage_metadata.get("candidatesTokenCount",0), "total_tokens": usage_metadata.get("totalTokenCount", 0)},
"usage": {
"prompt_tokens": usage_metadata.get("promptTokenCount", 0),
"completion_tokens": usage_metadata.get("candidatesTokenCount", 0),
"total_tokens": usage_metadata.get("totalTokenCount", 0),
},
}
@@ -127,8 +145,12 @@ class OpenAIResponseHandler(ResponseHandler):
usage_metadata: Optional[Dict[str, Any]] = None,
) -> Optional[Dict[str, Any]]:
if stream:
return _handle_openai_stream_response(response, model, finish_reason, usage_metadata)
return _handle_openai_normal_response(response, model, finish_reason, usage_metadata)
return _handle_openai_stream_response(
response, model, finish_reason, usage_metadata
)
return _handle_openai_normal_response(
response, model, finish_reason, usage_metadata
)
def handle_image_chat_response(
self, image_str: str, model: str, stream=False, finish_reason="stop"
@@ -182,7 +204,7 @@ def _extract_result(
gemini_format: bool = False,
) -> tuple[str, Optional[str], List[Dict[str, Any]], Optional[bool]]:
text, reasoning_content, tool_calls, thought = "", "", [], None
if stream:
if response.get("candidates"):
candidate = response["candidates"][0]
@@ -191,7 +213,7 @@ def _extract_result(
if not parts:
logger.warning("No parts found in stream response")
return "", None, [], None
if "text" in parts[0]:
text = parts[0].get("text")
if "thought" in parts[0]:
@@ -217,13 +239,13 @@ def _extract_result(
if response.get("candidates"):
candidate = response["candidates"][0]
text, reasoning_content = "", ""
# 使用安全的访问方式
content = candidate.get("content", {})
if content and isinstance(content, dict):
parts = content.get("parts", [])
if parts:
for part in parts:
if "text" in part:
@@ -241,14 +263,14 @@ def _extract_result(
logger.error(f"Invalid content structure for model: {model}")
text = _add_search_link_text(model, candidate, text)
# 安全地获取 parts 用于工具调用提取
parts = candidate.get("content", {}).get("parts", [])
tool_calls = _extract_tool_calls(parts, gemini_format)
else:
logger.warning(f"No candidates found in response for model: {model}")
text = "暂无返回"
return text, reasoning_content, tool_calls, thought
@@ -264,10 +286,6 @@ def _has_inline_image_part(response: Dict[str, Any]) -> bool:
def _extract_image_data(part: dict) -> str:
# Return empty string if no uploader is configured
if not is_image_upload_configured():
return ""
image_uploader = None
if settings.UPLOAD_PROVIDER == "smms":
image_uploader = ImageUploaderFactory.create(
@@ -287,13 +305,17 @@ def _extract_image_data(part: dict) -> str:
current_date = time.strftime("%Y/%m/%d")
filename = f"{current_date}/{uuid.uuid4().hex[:8]}.png"
base64_data = part["inlineData"]["data"]
mime_type = part["inlineData"]["mimeType"]
# 将base64_data转成bytes数组
# Return empty string if no uploader is configured
if not is_image_upload_configured(settings):
return f"\n\n![image](data:{mime_type};base64,{base64_data})\n\n"
bytes_data = base64.b64decode(base64_data)
upload_response = image_uploader.upload(bytes_data, filename)
if upload_response.success:
text = f"\n\n![image]({upload_response.data.url})\n\n"
else:
text = ""
text = f"\n\n![image](data:{mime_type};base64,{base64_data})\n\n"
return text
@@ -306,7 +328,7 @@ def _extract_tool_calls(
letters = string.ascii_lowercase + string.digits
tool_calls = list()
for i in range(len(parts)):
part = parts[i]
if not part or not isinstance(part, dict):
@@ -315,7 +337,7 @@ def _extract_tool_calls(
item = part.get("functionCall", {})
if not item or not isinstance(item, dict):
continue
if gemini_format:
tool_calls.append(part)
else:
@@ -339,9 +361,9 @@ def _handle_gemini_stream_response(
response: Dict[str, Any], model: str, stream: bool
) -> Dict[str, Any]:
# Early return raw Gemini response if no uploader configured and contains inline images
if not is_image_upload_configured() and _has_inline_image_part(response):
if not is_image_upload_configured(settings) and _has_inline_image_part(response):
return response
text, reasoning_content, tool_calls, thought = _extract_result(
response, model, stream=stream, gemini_format=True
)
@@ -360,9 +382,9 @@ def _handle_gemini_normal_response(
response: Dict[str, Any], model: str, stream: bool
) -> Dict[str, Any]:
# Early return raw Gemini response if no uploader configured and contains inline images
if not is_image_upload_configured() and _has_inline_image_part(response):
if not is_image_upload_configured(settings) and _has_inline_image_part(response):
return response
text, reasoning_content, tool_calls, thought = _extract_result(
response, model, stream=stream, gemini_format=True
)
@@ -371,7 +393,7 @@ def _handle_gemini_normal_response(
parts = tool_calls
else:
if thought is not None:
parts.append({"text": reasoning_content,"thought": thought})
parts.append({"text": reasoning_content, "thought": thought})
part = {"text": text}
parts.append(part)
content = {"parts": parts, "role": "model"}

View File

@@ -1,9 +1,8 @@
import logging
import platform
import sys
import re
import sys
from typing import Dict, Optional
from app.utils.helpers import redact_key_for_logging as _redact_key_for_logging
# ANSI转义序列颜色代码
COLORS = {
@@ -15,7 +14,6 @@ COLORS = {
}
# Windows系统启用ANSI支持
if platform.system() == "Windows":
import ctypes
@@ -46,14 +44,16 @@ class AccessLogFormatter(logging.Formatter):
# API key patterns to match in URLs
API_KEY_PATTERNS = [
r'\bAIza[0-9A-Za-z_-]{35}', # Google API keys (like Gemini)
r'\bsk-[0-9A-Za-z_-]{20,}', # OpenAI and general sk- prefixed keys
r"\bAIza[0-9A-Za-z_-]{35}", # Google API keys (like Gemini)
r"\bsk-[0-9A-Za-z_-]{20,}", # OpenAI and general sk- prefixed keys
]
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
# Compile regex patterns for better performance
self.compiled_patterns = [re.compile(pattern) for pattern in self.API_KEY_PATTERNS]
self.compiled_patterns = [
re.compile(pattern) for pattern in self.API_KEY_PATTERNS
]
def format(self, record):
# Format the record normally first
@@ -68,9 +68,10 @@ class AccessLogFormatter(logging.Formatter):
"""
try:
for pattern in self.compiled_patterns:
def replace_key(match):
key = match.group(0)
return _redact_key_for_logging(key)
return redact_key_for_logging(key)
message = pattern.sub(replace_key, message)
@@ -78,11 +79,31 @@ class AccessLogFormatter(logging.Formatter):
except Exception as e:
# Log the error but don't expose the original message in case it contains keys
import logging
logger = logging.getLogger(__name__)
logger.error(f"Error redacting API keys in access log: {e}")
return "[LOG_REDACTION_ERROR]"
def redact_key_for_logging(key: str) -> str:
    """
    Redact an API key for secure logging.

    Keys longer than 12 characters keep their first and last 6 characters;
    shorter keys keep only the first and last 3 so less of the secret is
    exposed. NOTE(review): for keys of 6 characters or fewer the two slices
    overlap, so the full key may still be reconstructable — avoid logging
    such short secrets at all.

    Args:
        key: API key to redact. Falsy values ("" etc.) are returned unchanged.

    Returns:
        str: "first6...last6" for keys longer than 12 characters,
        otherwise "first3...last3"; the input itself when falsy.
    """
    if not key:
        return key
    if len(key) <= 12:
        return f"{key[:3]}...{key[-3:]}"
    # No 'else' needed: previous branch returns.
    return f"{key[:6]}...{key[-6:]}"
# 日志格式 - 使用 fileloc 并设置固定宽度 (例如 30)
FORMATTER = ColoredFormatter(
"%(asctime)s | %(levelname)-17s | %(fileloc)-30s | %(message)s"
@@ -326,4 +347,3 @@ def setup_access_logging():
access_logger.propagate = False
return access_logger

View File

@@ -285,7 +285,9 @@ class OpenAIChatService:
api_key: str,
) -> Union[Dict[str, Any], AsyncGenerator[str, None]]:
"""创建聊天完成"""
messages, instruction = self.message_converter.convert(request.messages)
messages, instruction = self.message_converter.convert(
request.messages, request.model
)
payload = _build_payload(request, messages, instruction)

View File

@@ -9,8 +9,8 @@ from app.config.config import settings
from app.core.constants import VALID_IMAGE_RATIOS
from app.domain.openai_models import ImageGenerationRequest
from app.log.logger import get_image_create_logger
from app.utils.uploader import ImageUploaderFactory
from app.utils.helpers import is_image_upload_configured
from app.utils.uploader import ImageUploaderFactory
logger = get_image_create_logger()
@@ -99,7 +99,10 @@ class ImageCreateService:
image_uploader = None
# Return base64 if explicitly requested or if no uploader is configured
if request.response_format == "b64_json" or not is_image_upload_configured():
if (
request.response_format == "b64_json"
or not is_image_upload_configured(settings)
):
base64_image = base64.b64encode(image_data).decode("utf-8")
images_data.append(
{"b64_json": base64_image, "revised_prompt": request.prompt}

View File

@@ -1,14 +1,17 @@
"""
通用工具函数模块
"""
import json
import re
import base64
import requests
from typing import Dict, Any, List, Optional, Tuple
from pathlib import Path
import logging
import base64
import json
import logging
import re
from pathlib import Path
from typing import Any, Dict, List, Optional, Tuple
import requests
from app.config.config import Settings
from app.core.constants import DATA_URL_PATTERN, IMAGE_URL_PATTERN, VALID_IMAGE_RATIOS
helper_logger = logging.getLogger("app.utils")
@@ -20,23 +23,25 @@ VERSION_FILE_PATH = PROJECT_ROOT / "VERSION"
def extract_mime_type_and_data(base64_string: str) -> Tuple[Optional[str], str]:
"""
从 base64 字符串中提取 MIME 类型和数据
Args:
base64_string: 可能包含 MIME 类型信息的 base64 字符串
Returns:
tuple: (mime_type, encoded_data)
"""
# 检查字符串是否以 "data:" 格式开始
if base64_string.startswith('data:'):
if base64_string.startswith("data:"):
# 提取 MIME 类型和数据
pattern = DATA_URL_PATTERN
match = re.match(pattern, base64_string)
if match:
mime_type = "image/jpeg" if match.group(1) == "image/jpg" else match.group(1)
mime_type = (
"image/jpeg" if match.group(1) == "image/jpg" else match.group(1)
)
encoded_data = match.group(2)
return mime_type, encoded_data
# 如果不是预期格式,假定它只是数据部分
return None, base64_string
@@ -44,20 +49,20 @@ def extract_mime_type_and_data(base64_string: str) -> Tuple[Optional[str], str]:
def convert_image_to_base64(url: str) -> str:
"""
将图片URL转换为base64编码
Args:
url: 图片URL
Returns:
str: base64编码的图片数据
Raises:
Exception: 如果获取图片失败
"""
response = requests.get(url)
if response.status_code == 200:
# 将图片内容转换为base64
img_data = base64.b64encode(response.content).decode('utf-8')
img_data = base64.b64encode(response.content).decode("utf-8")
return img_data
else:
raise Exception(f"Failed to fetch image: {response.status_code}")
@@ -66,64 +71,66 @@ def convert_image_to_base64(url: str) -> str:
def format_json_response(data: Dict[str, Any], indent: int = 2) -> str:
    """Serialize *data* as a pretty-printed JSON string.

    Args:
        data: Mapping to serialize.
        indent: Number of spaces per indentation level (default 2).

    Returns:
        str: JSON text; non-ASCII characters are emitted verbatim
        (``ensure_ascii=False``).
    """
    rendered = json.dumps(data, ensure_ascii=False, indent=indent)
    return rendered
def parse_prompt_parameters(prompt: str, default_ratio: str = "1:1") -> Tuple[str, int, str]:
def parse_prompt_parameters(
prompt: str, default_ratio: str = "1:1"
) -> Tuple[str, int, str]:
"""
从prompt中解析参数
支持的格式:
- {n:数量} 例如: {n:2} 生成2张图片
- {ratio:比例} 例如: {ratio:16:9} 使用16:9比例
Args:
prompt: 提示文本
default_ratio: 默认比例
Returns:
tuple: (清理后的提示文本, 图片数量, 比例)
"""
# 默认值
n = 1
aspect_ratio = default_ratio
# 解析n参数
n_match = re.search(r'{n:(\d+)}', prompt)
n_match = re.search(r"{n:(\d+)}", prompt)
if n_match:
n = int(n_match.group(1))
if n < 1 or n > 4:
raise ValueError(f"Invalid n value: {n}. Must be between 1 and 4.")
prompt = prompt.replace(n_match.group(0), '').strip()
# 解析ratio参数
ratio_match = re.search(r'{ratio:(\d+:\d+)}', prompt)
prompt = prompt.replace(n_match.group(0), "").strip()
# 解析ratio参数
ratio_match = re.search(r"{ratio:(\d+:\d+)}", prompt)
if ratio_match:
aspect_ratio = ratio_match.group(1)
if aspect_ratio not in VALID_IMAGE_RATIOS:
raise ValueError(
f"Invalid ratio: {aspect_ratio}. Must be one of: {', '.join(VALID_IMAGE_RATIOS)}"
)
prompt = prompt.replace(ratio_match.group(0), '').strip()
prompt = prompt.replace(ratio_match.group(0), "").strip()
return prompt, n, aspect_ratio
def extract_image_urls_from_markdown(text: str) -> List[str]:
"""
从Markdown文本中提取图片URL
Args:
text: Markdown文本
Returns:
List[str]: 图片URL列表
"""
@@ -135,23 +142,22 @@ def extract_image_urls_from_markdown(text: str) -> List[str]:
def is_valid_api_key(key: str) -> bool:
"""
检查API密钥格式是否有效
Args:
key: API密钥
Returns:
bool: 如果密钥格式有效则返回True
"""
# 检查Gemini API密钥格式
if key.startswith('AIza'):
if key.startswith("AIza"):
return len(key) >= 30
# 检查OpenAI API密钥格式
if key.startswith('sk-'):
return len(key) >= 30
return False
# 检查OpenAI API密钥格式
if key.startswith("sk-"):
return len(key) >= 30
return False
def redact_key_for_logging(key: str) -> str:
@@ -177,26 +183,28 @@ def get_current_version(default_version: str = "0.0.0") -> str:
"""Reads the current version from the VERSION file."""
version_file = VERSION_FILE_PATH
try:
with version_file.open('r', encoding='utf-8') as f:
with version_file.open("r", encoding="utf-8") as f:
version = f.read().strip()
if not version:
helper_logger.warning(f"VERSION file ('{version_file}') is empty. Using default version '{default_version}'.")
helper_logger.warning(
f"VERSION file ('{version_file}') is empty. Using default version '{default_version}'."
)
return default_version
return version
except FileNotFoundError:
helper_logger.warning(f"VERSION file not found at '{version_file}'. Using default version '{default_version}'.")
helper_logger.warning(
f"VERSION file not found at '{version_file}'. Using default version '{default_version}'."
)
return default_version
except IOError as e:
helper_logger.error(f"Error reading VERSION file ('{version_file}'): {e}. Using default version '{default_version}'.")
helper_logger.error(
f"Error reading VERSION file ('{version_file}'): {e}. Using default version '{default_version}'."
)
return default_version
def is_image_upload_configured() -> bool:
"""Return True only if a valid upload provider is selected and all required settings for that provider are present. Uses lazy import to avoid circular imports."""
try:
from app.config.config import settings # local import to avoid circular dependency at module import time
except Exception:
return False
def is_image_upload_configured(settings: Settings) -> bool:
"""Return True only if a valid upload provider is selected and all required settings for that provider are present."""
provider = (getattr(settings, "UPLOAD_PROVIDER", "") or "").strip().lower()
if provider == "smms":
@@ -204,8 +212,10 @@ def is_image_upload_configured() -> bool:
if provider == "picgo":
return bool(getattr(settings, "PICGO_API_KEY", ""))
if provider == "cloudflare_imgbed":
return all([
getattr(settings, "CLOUDFLARE_IMGBED_URL", ""),
getattr(settings, "CLOUDFLARE_IMGBED_AUTH_CODE", ""),
])
return all(
[
getattr(settings, "CLOUDFLARE_IMGBED_URL", ""),
getattr(settings, "CLOUDFLARE_IMGBED_AUTH_CODE", ""),
]
)
return False