From 611559d2980d9c8a57478d53860fc17ca5ac85b9 Mon Sep 17 00:00:00 2001 From: snaily Date: Sun, 31 Aug 2025 21:39:12 +0800 Subject: [PATCH] =?UTF-8?q?feat(image):=20=E6=94=AF=E6=8C=81=E5=A4=9A?= =?UTF-8?q?=E6=A8=A1=E6=80=81=E6=A8=A1=E5=9E=8B=E8=BE=93=E5=85=A5base64?= =?UTF-8?q?=E6=A0=BC=E5=BC=8F=E5=9B=BE=E7=89=87?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - 在消息转换中,增加对 `data:image/png;base64,...` 格式图片的支持,允许用户直接在输入中提供base64编码的图片。 - 调整图片处理逻辑,使其能够根据模型名称判断是否启用多模态能力,避免非多模态模型错误处理图片链接。 - 当未配置图床时,模型输出的图片将回退为base64格式,确保图片内容始终可用。 - 优化了相关函数的参数传递和代码格式,提高了代码的可读性和健壮性。 --- app/handler/message_converter.py | 32 ++++-- app/handler/response_handler.py | 92 ++++++++++------- app/log/logger.py | 36 +++++-- app/service/chat/openai_chat_service.py | 4 +- app/service/image/image_create_service.py | 7 +- app/utils/helpers.py | 120 ++++++++++++---------- 6 files changed, 181 insertions(+), 110 deletions(-) diff --git a/app/handler/message_converter.py b/app/handler/message_converter.py index 378871a..d185ee0 100644 --- a/app/handler/message_converter.py +++ b/app/handler/message_converter.py @@ -27,7 +27,7 @@ class MessageConverter(ABC): @abstractmethod def convert( - self, messages: List[Dict[str, Any]] + self, messages: List[Dict[str, Any]], model: str ) -> tuple[List[Dict[str, Any]], Optional[Dict[str, Any]]]: pass @@ -84,7 +84,7 @@ def _convert_image_to_base64(url: str) -> str: raise Exception(f"Failed to fetch image: {response.status_code}") -def _process_text_with_image(text: str) -> List[Dict[str, Any]]: +def _process_text_with_image(text: str, model: str) -> List[Dict[str, Any]]: """ 处理可能包含图片URL的文本,提取图片并转换为base64 @@ -94,17 +94,31 @@ def _process_text_with_image(text: str) -> List[Dict[str, Any]]: Returns: List[Dict[str, Any]]: 包含文本和图片的部分列表 """ + # 如果模型名中没有包含image,当作普通文本处理 + if "image" not in model: + return [{"text": text}] parts = [] img_url_match = re.search(IMAGE_URL_PATTERN, text) if img_url_match: # 提取URL img_url = img_url_match.group(2) - # 将URL对应的图片转换为base64 + # 先判断是否是base64url如果是,直接用,不过不是,再将URL对应的图片转换为base64 try: - base64_data = _convert_image_to_base64(img_url) - parts.append( - {"inline_data": {"mimeType": "image/png", "data": base64_data}} - ) + base64_url_match = re.search(DATA_URL_PATTERN, img_url) + if base64_url_match: + parts.append( + { + "inline_data": { + "mimeType": base64_url_match.group(1), + "data": base64_url_match.group(2), + } + } + ) + else: + base64_data = _convert_image_to_base64(img_url) + parts.append( + {"inline_data": {"mimeType": "image/png", "data": base64_data}} + ) except Exception: # 如果转换失败,回退到文本模式 parts.append({"text": text}) @@ -145,7 +159,7 @@ class OpenAIMessageConverter(MessageConverter): raise def convert( - self, messages: List[Dict[str, Any]] + self, messages: List[Dict[str, Any]], model: str ) -> tuple[List[Dict[str, Any]], Optional[Dict[str, Any]]]: converted_messages = [] system_instruction_parts = [] @@ -296,7 +310,7 @@ class OpenAIMessageConverter(MessageConverter): elif ( "content" in msg and isinstance(msg["content"], str) and msg["content"] ): - parts.extend(_process_text_with_image(msg["content"])) + parts.extend(_process_text_with_image(msg["content"], model)) elif "tool_calls" in msg and isinstance(msg["tool_calls"], list): # Keep existing tool call processing for tool_call in msg["tool_calls"]: diff --git a/app/handler/response_handler.py b/app/handler/response_handler.py index bbe047d..57a2932 100644 --- a/app/handler/response_handler.py +++ b/app/handler/response_handler.py @@ -8,9 +8,9 @@ from abc import ABC, abstractmethod from typing import Any, Dict, List, Optional from app.config.config import settings -from app.utils.uploader import ImageUploaderFactory from app.log.logger import get_openai_logger from app.utils.helpers import is_image_upload_configured +from app.utils.uploader import ImageUploaderFactory logger = get_openai_logger() @@ -33,7 +33,11 @@ class GeminiResponseHandler(ResponseHandler): self.thinking_status = False def handle_response( - self, response: Dict[str, Any], model: str, stream: bool = False, usage_metadata: Optional[Dict[str, Any]] = None + self, + response: Dict[str, Any], + model: str, + stream: bool = False, + usage_metadata: Optional[Dict[str, Any]] = None, ) -> Dict[str, Any]: if stream: return _handle_gemini_stream_response(response, model, stream) @@ -41,7 +45,10 @@ class GeminiResponseHandler(ResponseHandler): def _handle_openai_stream_response( - response: Dict[str, Any], model: str, finish_reason: str, usage_metadata: Optional[Dict[str, Any]] + response: Dict[str, Any], + model: str, + finish_reason: str, + usage_metadata: Optional[Dict[str, Any]], ) -> Dict[str, Any]: choices = [] candidates = response.get("candidates", []) @@ -55,15 +62,15 @@ def _handle_openai_stream_response( if not text and not tool_calls and not reasoning_content: delta = {} else: - delta = {"content": text, "reasoning_content": reasoning_content, "role": "assistant"} + delta = { + "content": text, + "reasoning_content": reasoning_content, + "role": "assistant", + } if tool_calls: delta["tool_calls"] = tool_calls - - choice = { - "index": index, - "delta": delta, - "finish_reason": finish_reason - } + + choice = {"index": index, "delta": delta, "finish_reason": finish_reason} choices.append(choice) template_chunk = { @@ -74,16 +81,23 @@ def _handle_openai_stream_response( "choices": choices, } if usage_metadata: - template_chunk["usage"] = {"prompt_tokens": usage_metadata.get("promptTokenCount", 0), "completion_tokens": usage_metadata.get("candidatesTokenCount",0), "total_tokens": usage_metadata.get("totalTokenCount", 0)} + template_chunk["usage"] = { + "prompt_tokens": usage_metadata.get("promptTokenCount", 0), + "completion_tokens": usage_metadata.get("candidatesTokenCount", 0), + "total_tokens": usage_metadata.get("totalTokenCount", 0), + } return template_chunk def _handle_openai_normal_response( - response: Dict[str, Any], model: str, finish_reason: str, usage_metadata: Optional[Dict[str, Any]] + response: Dict[str, Any], + model: str, + finish_reason: str, + usage_metadata: Optional[Dict[str, Any]], ) -> Dict[str, Any]: choices = [] candidates = response.get("candidates", []) - + for i, candidate in enumerate(candidates): text, reasoning_content, tool_calls, _ = _extract_result( {"candidates": [candidate]}, model, stream=False, gemini_format=False @@ -106,7 +120,11 @@ def _handle_openai_normal_response( "created": int(time.time()), "model": model, "choices": choices, - "usage": {"prompt_tokens": usage_metadata.get("promptTokenCount", 0), "completion_tokens": usage_metadata.get("candidatesTokenCount",0), "total_tokens": usage_metadata.get("totalTokenCount", 0)}, + "usage": { + "prompt_tokens": usage_metadata.get("promptTokenCount", 0), + "completion_tokens": usage_metadata.get("candidatesTokenCount", 0), + "total_tokens": usage_metadata.get("totalTokenCount", 0), + }, } @@ -127,8 +145,12 @@ class OpenAIResponseHandler(ResponseHandler): usage_metadata: Optional[Dict[str, Any]] = None, ) -> Optional[Dict[str, Any]]: if stream: - return _handle_openai_stream_response(response, model, finish_reason, usage_metadata) - return _handle_openai_normal_response(response, model, finish_reason, usage_metadata) + return _handle_openai_stream_response( + response, model, finish_reason, usage_metadata + ) + return _handle_openai_normal_response( + response, model, finish_reason, usage_metadata + ) def handle_image_chat_response( self, image_str: str, model: str, stream=False, finish_reason="stop" @@ -182,7 +204,7 @@ def _extract_result( gemini_format: bool = False, ) -> tuple[str, Optional[str], List[Dict[str, Any]], Optional[bool]]: text, reasoning_content, tool_calls, thought = "", "", [], None - + if stream: if response.get("candidates"): candidate = response["candidates"][0] @@ -191,7 +213,7 @@ def _extract_result( if not parts: logger.warning("No parts found in stream response") return "", None, [], None - + if "text" in parts[0]: text = parts[0].get("text") if "thought" in parts[0]: @@ -217,13 +239,13 @@ def _extract_result( if response.get("candidates"): candidate = response["candidates"][0] text, reasoning_content = "", "" - + # 使用安全的访问方式 content = candidate.get("content", {}) - + if content and isinstance(content, dict): parts = content.get("parts", []) - + if parts: for part in parts: if "text" in part: @@ -241,14 +263,14 @@ def _extract_result( logger.error(f"Invalid content structure for model: {model}") text = _add_search_link_text(model, candidate, text) - + # 安全地获取 parts 用于工具调用提取 parts = candidate.get("content", {}).get("parts", []) tool_calls = _extract_tool_calls(parts, gemini_format) else: logger.warning(f"No candidates found in response for model: {model}") text = "暂无返回" - + return text, reasoning_content, tool_calls, thought @@ -264,10 +286,6 @@ def _has_inline_image_part(response: Dict[str, Any]) -> bool: def _extract_image_data(part: dict) -> str: - # Return empty string if no uploader is configured - if not is_image_upload_configured(): - return "" - image_uploader = None if settings.UPLOAD_PROVIDER == "smms": image_uploader = ImageUploaderFactory.create( @@ -287,13 +305,17 @@ def _extract_image_data(part: dict) -> str: current_date = time.strftime("%Y/%m/%d") filename = f"{current_date}/{uuid.uuid4().hex[:8]}.png" base64_data = part["inlineData"]["data"] + mime_type = part["inlineData"]["mimeType"] # 将base64_data转成bytes数组 + # Return empty string if no uploader is configured + if not is_image_upload_configured(settings): + return f"\n\n![image](data:{mime_type};base64,{base64_data})\n\n" bytes_data = base64.b64decode(base64_data) upload_response = image_uploader.upload(bytes_data, filename) if upload_response.success: text = f"\n\n![image]({upload_response.data.url})\n\n" else: - text = "" + text = f"\n\n![image](data:{mime_type};base64,{base64_data})\n\n" return text @@ -306,7 +328,7 @@ def _extract_tool_calls( letters = string.ascii_lowercase + string.digits tool_calls = list() - + for i in range(len(parts)): part = parts[i] if not part or not isinstance(part, dict): @@ -315,7 +337,7 @@ def _extract_tool_calls( item = part.get("functionCall", {}) if not item or not isinstance(item, dict): continue - + if gemini_format: tool_calls.append(part) else: @@ -339,9 +361,9 @@ def _handle_gemini_stream_response( response: Dict[str, Any], model: str, stream: bool ) -> Dict[str, Any]: # Early return raw Gemini response if no uploader configured and contains inline images - if not is_image_upload_configured() and _has_inline_image_part(response): + if not is_image_upload_configured(settings) and _has_inline_image_part(response): return response - + text, reasoning_content, tool_calls, thought = _extract_result( response, model, stream=stream, gemini_format=True ) @@ -360,9 +382,9 @@ def _handle_gemini_normal_response( response: Dict[str, Any], model: str, stream: bool ) -> Dict[str, Any]: # Early return raw Gemini response if no uploader configured and contains inline images - if not is_image_upload_configured() and _has_inline_image_part(response): + if not is_image_upload_configured(settings) and _has_inline_image_part(response): return response - + text, reasoning_content, tool_calls, thought = _extract_result( response, model, stream=stream, gemini_format=True ) @@ -371,7 +393,7 @@ def _handle_gemini_normal_response( parts = tool_calls else: if thought is not None: - parts.append({"text": reasoning_content,"thought": thought}) + parts.append({"text": reasoning_content, "thought": thought}) part = {"text": text} parts.append(part) content = {"parts": parts, "role": "model"} diff --git a/app/log/logger.py b/app/log/logger.py index 0b8c0d2..71e9c3c 100644 --- a/app/log/logger.py +++ b/app/log/logger.py @@ -1,9 +1,8 @@ import logging import platform -import sys import re +import sys from typing import Dict, Optional -from app.utils.helpers import redact_key_for_logging as _redact_key_for_logging # ANSI转义序列颜色代码 COLORS = { @@ -15,7 +14,6 @@ COLORS = { } - # Windows系统启用ANSI支持 if platform.system() == "Windows": import ctypes @@ -46,14 +44,16 @@ class AccessLogFormatter(logging.Formatter): # API key patterns to match in URLs API_KEY_PATTERNS = [ - r'\bAIza[0-9A-Za-z_-]{35}', # Google API keys (like Gemini) - r'\bsk-[0-9A-Za-z_-]{20,}', # OpenAI and general sk- prefixed keys + r"\bAIza[0-9A-Za-z_-]{35}", # Google API keys (like Gemini) + r"\bsk-[0-9A-Za-z_-]{20,}", # OpenAI and general sk- prefixed keys ] def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) # Compile regex patterns for better performance - self.compiled_patterns = [re.compile(pattern) for pattern in self.API_KEY_PATTERNS] + self.compiled_patterns = [ + re.compile(pattern) for pattern in self.API_KEY_PATTERNS + ] def format(self, record): # Format the record normally first @@ -68,9 +68,10 @@ class AccessLogFormatter(logging.Formatter): """ try: for pattern in self.compiled_patterns: + def replace_key(match): key = match.group(0) - return _redact_key_for_logging(key) + return redact_key_for_logging(key) message = pattern.sub(replace_key, message) @@ -78,11 +79,31 @@ class AccessLogFormatter(logging.Formatter): except Exception as e: # Log the error but don't expose the original message in case it contains keys import logging + logger = logging.getLogger(__name__) logger.error(f"Error redacting API keys in access log: {e}") return "[LOG_REDACTION_ERROR]" +def redact_key_for_logging(key: str) -> str: + """ + Redacts API key for secure logging by showing only first and last 6 characters. + + Args: + key: API key to redact + + Returns: + str: Redacted key in format "first6...last6" or descriptive placeholder for edge cases + """ + if not key: + return key + + if len(key) <= 12: + return f"{key[:3]}...{key[-3:]}" + else: + return f"{key[:6]}...{key[-6:]}" + + # 日志格式 - 使用 fileloc 并设置固定宽度 (例如 30) FORMATTER = ColoredFormatter( "%(asctime)s | %(levelname)-17s | %(fileloc)-30s | %(message)s" @@ -326,4 +347,3 @@ def setup_access_logging(): access_logger.propagate = False return access_logger - diff --git a/app/service/chat/openai_chat_service.py b/app/service/chat/openai_chat_service.py index 3dfb17e..24cd12d 100644 --- a/app/service/chat/openai_chat_service.py +++ b/app/service/chat/openai_chat_service.py @@ -285,7 +285,9 @@ class OpenAIChatService: api_key: str, ) -> Union[Dict[str, Any], AsyncGenerator[str, None]]: """创建聊天完成""" - messages, instruction = self.message_converter.convert(request.messages) + messages, instruction = self.message_converter.convert( + request.messages, request.model + ) payload = _build_payload(request, messages, instruction) diff --git a/app/service/image/image_create_service.py b/app/service/image/image_create_service.py index b90ef99..7b676d7 100644 --- a/app/service/image/image_create_service.py +++ b/app/service/image/image_create_service.py @@ -9,8 +9,8 @@ from app.config.config import settings from app.core.constants import VALID_IMAGE_RATIOS from app.domain.openai_models import ImageGenerationRequest from app.log.logger import get_image_create_logger -from app.utils.uploader import ImageUploaderFactory from app.utils.helpers import is_image_upload_configured +from app.utils.uploader import ImageUploaderFactory logger = get_image_create_logger() @@ -99,7 +99,10 @@ class ImageCreateService: image_uploader = None # Return base64 if explicitly requested or if no uploader is configured - if request.response_format == "b64_json" or not is_image_upload_configured(): + if ( + request.response_format == "b64_json" + or not is_image_upload_configured(settings) + ): base64_image = base64.b64encode(image_data).decode("utf-8") images_data.append( {"b64_json": base64_image, "revised_prompt": request.prompt} diff --git a/app/utils/helpers.py b/app/utils/helpers.py index 29778ed..6afa69b 100644 --- a/app/utils/helpers.py +++ b/app/utils/helpers.py @@ -1,14 +1,17 @@ """ 通用工具函数模块 """ -import json -import re -import base64 -import requests -from typing import Dict, Any, List, Optional, Tuple -from pathlib import Path -import logging +import base64 +import json +import logging +import re +from pathlib import Path +from typing import Any, Dict, List, Optional, Tuple + +import requests + +from app.config.config import Settings from app.core.constants import DATA_URL_PATTERN, IMAGE_URL_PATTERN, VALID_IMAGE_RATIOS helper_logger = logging.getLogger("app.utils") @@ -20,23 +23,25 @@ VERSION_FILE_PATH = PROJECT_ROOT / "VERSION" def extract_mime_type_and_data(base64_string: str) -> Tuple[Optional[str], str]: """ 从 base64 字符串中提取 MIME 类型和数据 - + Args: base64_string: 可能包含 MIME 类型信息的 base64 字符串 - + Returns: tuple: (mime_type, encoded_data) """ # 检查字符串是否以 "data:" 格式开始 - if base64_string.startswith('data:'): + if base64_string.startswith("data:"): # 提取 MIME 类型和数据 pattern = DATA_URL_PATTERN match = re.match(pattern, base64_string) if match: - mime_type = "image/jpeg" if match.group(1) == "image/jpg" else match.group(1) + mime_type = ( + "image/jpeg" if match.group(1) == "image/jpg" else match.group(1) + ) encoded_data = match.group(2) return mime_type, encoded_data - + # 如果不是预期格式,假定它只是数据部分 return None, base64_string @@ -44,20 +49,20 @@ def extract_mime_type_and_data(base64_string: str) -> Tuple[Optional[str], str]: def convert_image_to_base64(url: str) -> str: """ 将图片URL转换为base64编码 - + Args: url: 图片URL - + Returns: str: base64编码的图片数据 - + Raises: Exception: 如果获取图片失败 """ response = requests.get(url) if response.status_code == 200: # 将图片内容转换为base64 - img_data = base64.b64encode(response.content).decode('utf-8') + img_data = base64.b64encode(response.content).decode("utf-8") return img_data else: raise Exception(f"Failed to fetch image: {response.status_code}") @@ -66,64 +71,66 @@ def convert_image_to_base64(url: str) -> str: def format_json_response(data: Dict[str, Any], indent: int = 2) -> str: """ 格式化JSON响应 - + Args: data: 要格式化的数据 indent: 缩进空格数 - + Returns: str: 格式化后的JSON字符串 """ return json.dumps(data, indent=indent, ensure_ascii=False) -def parse_prompt_parameters(prompt: str, default_ratio: str = "1:1") -> Tuple[str, int, str]: +def parse_prompt_parameters( + prompt: str, default_ratio: str = "1:1" +) -> Tuple[str, int, str]: """ 从prompt中解析参数 - + 支持的格式: - {n:数量} 例如: {n:2} 生成2张图片 - {ratio:比例} 例如: {ratio:16:9} 使用16:9比例 - + Args: prompt: 提示文本 default_ratio: 默认比例 - + Returns: tuple: (清理后的提示文本, 图片数量, 比例) """ # 默认值 n = 1 aspect_ratio = default_ratio - + # 解析n参数 - n_match = re.search(r'{n:(\d+)}', prompt) + n_match = re.search(r"{n:(\d+)}", prompt) if n_match: n = int(n_match.group(1)) if n < 1 or n > 4: raise ValueError(f"Invalid n value: {n}. Must be between 1 and 4.") - prompt = prompt.replace(n_match.group(0), '').strip() - - # 解析ratio参数 - ratio_match = re.search(r'{ratio:(\d+:\d+)}', prompt) + prompt = prompt.replace(n_match.group(0), "").strip() + + # 解析ratio参数 + ratio_match = re.search(r"{ratio:(\d+:\d+)}", prompt) if ratio_match: aspect_ratio = ratio_match.group(1) if aspect_ratio not in VALID_IMAGE_RATIOS: raise ValueError( f"Invalid ratio: {aspect_ratio}. Must be one of: {', '.join(VALID_IMAGE_RATIOS)}" ) - prompt = prompt.replace(ratio_match.group(0), '').strip() - + prompt = prompt.replace(ratio_match.group(0), "").strip() + return prompt, n, aspect_ratio def extract_image_urls_from_markdown(text: str) -> List[str]: """ 从Markdown文本中提取图片URL - + Args: text: Markdown文本 - + Returns: List[str]: 图片URL列表 """ @@ -135,23 +142,22 @@ def extract_image_urls_from_markdown(text: str) -> List[str]: def is_valid_api_key(key: str) -> bool: """ 检查API密钥格式是否有效 - + Args: key: API密钥 - + Returns: bool: 如果密钥格式有效则返回True """ # 检查Gemini API密钥格式 - if key.startswith('AIza'): + if key.startswith("AIza"): return len(key) >= 30 - - # 检查OpenAI API密钥格式 - if key.startswith('sk-'): - return len(key) >= 30 - - return False + # 检查OpenAI API密钥格式 + if key.startswith("sk-"): + return len(key) >= 30 + + return False def redact_key_for_logging(key: str) -> str: @@ -177,26 +183,28 @@ def get_current_version(default_version: str = "0.0.0") -> str: """Reads the current version from the VERSION file.""" version_file = VERSION_FILE_PATH try: - with version_file.open('r', encoding='utf-8') as f: + with version_file.open("r", encoding="utf-8") as f: version = f.read().strip() if not version: - helper_logger.warning(f"VERSION file ('{version_file}') is empty. Using default version '{default_version}'.") + helper_logger.warning( + f"VERSION file ('{version_file}') is empty. Using default version '{default_version}'." + ) return default_version return version except FileNotFoundError: - helper_logger.warning(f"VERSION file not found at '{version_file}'. Using default version '{default_version}'.") + helper_logger.warning( + f"VERSION file not found at '{version_file}'. Using default version '{default_version}'." + ) return default_version except IOError as e: - helper_logger.error(f"Error reading VERSION file ('{version_file}'): {e}. Using default version '{default_version}'.") + helper_logger.error( + f"Error reading VERSION file ('{version_file}'): {e}. Using default version '{default_version}'." + ) return default_version -def is_image_upload_configured() -> bool: - """Return True only if a valid upload provider is selected and all required settings for that provider are present. Uses lazy import to avoid circular imports.""" - try: - from app.config.config import settings # local import to avoid circular dependency at module import time - except Exception: - return False +def is_image_upload_configured(settings: Settings) -> bool: + """Return True only if a valid upload provider is selected and all required settings for that provider are present.""" provider = (getattr(settings, "UPLOAD_PROVIDER", "") or "").strip().lower() if provider == "smms": @@ -204,8 +212,10 @@ def is_image_upload_configured() -> bool: if provider == "picgo": return bool(getattr(settings, "PICGO_API_KEY", "")) if provider == "cloudflare_imgbed": - return all([ - getattr(settings, "CLOUDFLARE_IMGBED_URL", ""), - getattr(settings, "CLOUDFLARE_IMGBED_AUTH_CODE", ""), - ]) + return all( + [ + getattr(settings, "CLOUDFLARE_IMGBED_URL", ""), + getattr(settings, "CLOUDFLARE_IMGBED_AUTH_CODE", ""), + ] + ) return False