diff --git a/app/agent/prompt/Agent Prompt.txt b/app/agent/prompt/Agent Prompt.txt index 8672f93a..4af9e6ef 100644 --- a/app/agent/prompt/Agent Prompt.txt +++ b/app/agent/prompt/Agent Prompt.txt @@ -10,7 +10,7 @@ Core Capabilities: 3. Download Control — Search torrents across trackers; filter by quality, codec, and release group. 4. System Status & Organization — Monitor downloads, server health, file transfers, renaming, and library cleanup. 5. Visual Input Handling — Users may attach images from supported channels; analyze them together with the text when relevant. -6. File Context Handling — User messages may arrive as structured JSON. Treat the `message` field as the user's text. Non-image attachments appear in `files`; when `local_path` is present, use local file tools to inspect the uploaded file directly. +6. File Context Handling — User messages may arrive as structured JSON. Treat the `message` field as the user's text. Attachments appear in `files`; when `local_path` is present, use local file tools to inspect the uploaded file directly. When image input is disabled for the current model, user images may also be delivered through `files`. {verbose_spec} diff --git a/app/chain/message.py b/app/chain/message.py index 8af455c5..27e9bb47 100644 --- a/app/chain/message.py +++ b/app/chain/message.py @@ -1,10 +1,11 @@ import asyncio +import mimetypes import re import time from datetime import datetime, timedelta from pathlib import Path from typing import Any, Optional, Dict, Union, List -from urllib.parse import unquote +from urllib.parse import unquote, urlparse import uuid import base64 @@ -20,6 +21,7 @@ from app.core.context import MediaInfo, Context from app.core.meta import MetaBase from app.db.user_oper import UserOper from app.helper.torrent import TorrentHelper +from app.helper.llm import LLMHelper from app.helper.voice import VoiceHelper from app.log import logger from app.schemas import Notification, NotExistMediaInfo, CommingMessage @@ -169,7 +171,7 @@ class MessageChain(ChainBase): text: str, original_message_id: Optional[Union[str, int]] = None, original_chat_id: Optional[str] = None, - images: Optional[List[str]] = None, + images: Optional[List[CommingMessage.MessageImage]] = None, audio_refs: Optional[List[str]] = None, files: Optional[List[CommingMessage.MessageAttachment]] = None, ) -> None: @@ -183,6 +185,8 @@ class MessageChain(ChainBase): user_cache: Dict[str, dict] = self.load_cache(self._cache_file) or {} try: + images = CommingMessage.MessageImage.normalize_list(images) + # 识别语音为文本 reply_with_voice = bool(audio_refs) if audio_refs: @@ -1238,7 +1242,7 @@ class MessageChain(ChainBase): source: str, userid: Union[str, int], username: str, - images: Optional[List[str]] = None, + images: Optional[List[CommingMessage.MessageImage]] = None, files: Optional[List[CommingMessage.MessageAttachment]] = None, reply_with_voice: bool = False, ) -> None: @@ -1259,6 +1263,8 @@ class MessageChain(ChainBase): ) return + images = CommingMessage.MessageImage.normalize_list(images) + # 提取用户消息 if text.lower().startswith("/ai"): user_message = text[3:].strip() # 移除 "/ai" 前缀(大小写不敏感) @@ -1282,7 +1288,8 @@ class MessageChain(ChainBase): # 下载图片并转为base64 original_images = images - if images: + all_files = list(files or []) + if images and LLMHelper.supports_image_input(): images = self._download_images_to_base64(images, channel, source) if original_images and not images and not user_message and not files: self.post_message( @@ -1295,14 +1302,29 @@ class MessageChain(ChainBase): ) ) return + elif images: + image_attachments = self._build_image_attachments(images) + if original_images and not image_attachments and not user_message and not files: + self.post_message( + Notification( + channel=channel, + source=source, + userid=userid, + username=username, + title="图片读取失败,请稍后重试", + ) + ) + return + all_files.extend(image_attachments) + images = None prepared_files = self._prepare_agent_files( session_id=session_id, - files=files, + files=all_files, channel=channel, source=source, ) - if files and not prepared_files and not user_message and not images: + if all_files and not prepared_files and not user_message and not images: self.post_message( Notification( channel=channel, @@ -1452,15 +1474,20 @@ class MessageChain(ChainBase): return default def _download_images_to_base64( - self, images: List[str], channel: MessageChannel, source: str + self, + images: List[CommingMessage.MessageImage], + channel: MessageChannel, + source: str, ) -> List[str]: """ 下载图片并转为base64 """ + images = CommingMessage.MessageImage.normalize_list(images) if not images: return None base64_images = [] - for img in images: + for image in images: + img = image.ref try: if img.startswith("data:"): base64_images.append(img) @@ -1511,6 +1538,33 @@ class MessageChain(ChainBase): logger.error(f"下载图片失败: {img}, error: {e}") return base64_images if base64_images else None + def _build_image_attachments( + self, images: List[CommingMessage.MessageImage] + ) -> List[CommingMessage.MessageAttachment]: + """ + 将图片引用转换为附件描述,以便按文件方式交给 Agent 处理。 + """ + images = CommingMessage.MessageImage.normalize_list(images) + if not images: + return [] + + attachments = [] + for index, image in enumerate(images, start=1): + image_ref = image.ref + if not image_ref: + continue + name = image.name or self._guess_image_attachment_name(image_ref, index) + mime_type = image.mime_type or self._guess_image_mime_type(image_ref, name) + attachments.append( + CommingMessage.MessageAttachment( + ref=image_ref, + name=name, + mime_type=mime_type, + size=image.size, + ) + ) + return attachments + def _prepare_agent_files( self, session_id: str, @@ -1570,15 +1624,31 @@ class MessageChain(ChainBase): """ if not file_ref: return None + if file_ref.startswith("data:"): + return self._decode_data_url_bytes(file_ref) + if file_ref.startswith("tg://file_id/"): + file_id = file_ref.replace("tg://file_id/", "", 1) + return self.run_module( + "download_telegram_file_bytes", file_id=file_id, source=source + ) if file_ref.startswith("tg://document_file_id/"): file_id = file_ref.replace("tg://document_file_id/", "", 1) return self.run_module( "download_telegram_file_bytes", file_id=file_id, source=source ) + if file_ref.startswith("wxwork://media_id/"): + return self.run_module( + "download_wechat_media_bytes", media_ref=file_ref, source=source + ) if file_ref.startswith("wxwork://file_media_id/"): return self.run_module( "download_wechat_media_bytes", media_ref=file_ref, source=source ) + if file_ref.startswith("wxbot://image/"): + data_url = self.run_module( + "download_wechat_image_to_data_url", image_ref=file_ref, source=source + ) + return self._decode_data_url_bytes(data_url) if data_url else None if file_ref.startswith("wxbot://file/"): file_url = unquote(file_ref.replace("wxbot://file/", "", 1)) resp = RequestUtils(timeout=30).get_res(file_url) @@ -1604,6 +1674,11 @@ class MessageChain(ChainBase): "download_synologychat_file_bytes", file_ref=file_ref, source=source ) if file_ref.startswith("http"): + if channel == MessageChannel.Slack: + data_url = self.run_module( + "download_slack_file_to_data_url", file_url=file_ref, source=source + ) + return self._decode_data_url_bytes(data_url) if data_url else None resp = RequestUtils(timeout=30).get_res(file_ref) return resp.content if resp and resp.content else None logger.debug( @@ -1647,6 +1722,11 @@ class MessageChain(ChainBase): if "." not in name: mime = (mime_type or "").split(";", 1)[0].strip().lower() default_ext = { + "image/jpeg": ".jpg", + "image/png": ".png", + "image/gif": ".gif", + "image/webp": ".webp", + "image/bmp": ".bmp", "application/json": ".json", "text/plain": ".txt", "text/markdown": ".md", @@ -1655,3 +1735,50 @@ class MessageChain(ChainBase): if default_ext: name = f"{name}{default_ext}" return name + + @staticmethod + def _guess_image_attachment_name(image_ref: str, index: int) -> str: + """ + 根据图片引用推测附件名。 + """ + if not image_ref: + return f"image_{index}.jpg" + if image_ref.startswith("data:"): + mime_part = image_ref[5:].split(";", 1)[0].strip().lower() + ext = mimetypes.guess_extension(mime_part) or ".jpg" + return f"image_{index}{ext}" + + parsed = urlparse(unquote(image_ref)) + name = Path(parsed.path).name if parsed.path else "" + if name and "." in name: + return name + return f"image_{index}.jpg" + + @staticmethod + def _guess_image_mime_type(image_ref: str, filename: Optional[str]) -> str: + """ + 根据图片引用或文件名推测 MIME 类型。 + """ + if image_ref and image_ref.startswith("data:"): + mime = image_ref[5:].split(";", 1)[0].strip().lower() + return mime or "image/jpeg" + guessed, _ = mimetypes.guess_type(filename or "") + if guessed and guessed.startswith("image/"): + return guessed + return "image/jpeg" + + @staticmethod + def _decode_data_url_bytes(data_url: Optional[str]) -> Optional[bytes]: + """ + 将 data URL 解码为原始字节。 + """ + if not data_url or not data_url.startswith("data:"): + return None + try: + _, payload = data_url.split(",", 1) + except ValueError: + return None + try: + return base64.b64decode(payload) + except Exception: + return None diff --git a/app/core/config.py b/app/core/config.py index e1eb6fd7..f31387af 100644 --- a/app/core/config.py +++ b/app/core/config.py @@ -494,6 +494,8 @@ class ConfigModel(BaseModel): LLM_PROVIDER: str = "deepseek" # LLM模型名称 LLM_MODEL: str = "deepseek-chat" + # LLM是否支持图片输入,开启后消息图片会按多模态输入发送给模型 + LLM_SUPPORT_IMAGE_INPUT: bool = False # LLM API密钥 LLM_API_KEY: Optional[str] = None # LLM基础URL(用于自定义API端点) diff --git a/app/helper/llm.py b/app/helper/llm.py index 08bce910..fe4e252b 100644 --- a/app/helper/llm.py +++ b/app/helper/llm.py @@ -59,6 +59,13 @@ def _get_httpx_proxy_key() -> str: class LLMHelper: """LLM模型相关辅助功能""" + @staticmethod + def supports_image_input() -> bool: + """ + 判断当前模型是否启用了图片输入能力。 + """ + return bool(settings.LLM_SUPPORT_IMAGE_INPUT) + @staticmethod def get_llm(streaming: bool = False): """ diff --git a/app/modules/discord/__init__.py b/app/modules/discord/__init__.py index 9c469c30..8b7bbe89 100644 --- a/app/modules/discord/__init__.py +++ b/app/modules/discord/__init__.py @@ -181,7 +181,9 @@ class DiscordModule(_ModuleBase, _MessageBase[Discord]): return None @staticmethod - def _extract_images(msg_json: dict) -> Optional[List[str]]: + def _extract_images( + msg_json: dict, + ) -> Optional[List[CommingMessage.MessageImage]]: """ 从Discord消息中提取图片URL """ @@ -200,7 +202,14 @@ class DiscordModule(_ModuleBase, _MessageBase[Discord]): or content_type.startswith("image/") or filename.endswith(DiscordModule._IMAGE_SUFFIXES) ): - images.append(url) + images.append( + CommingMessage.MessageImage( + ref=url, + name=attachment.get("filename"), + mime_type=attachment.get("content_type"), + size=attachment.get("size"), + ) + ) return images if images else None @classmethod diff --git a/app/modules/qqbot/__init__.py b/app/modules/qqbot/__init__.py index eb76e1f3..0a5557b3 100644 --- a/app/modules/qqbot/__init__.py +++ b/app/modules/qqbot/__init__.py @@ -155,8 +155,10 @@ class QQBotModule(_ModuleBase, _MessageBase[QQBot]): return None @classmethod - def _extract_images(cls, msg_body: dict) -> Optional[List[str]]: - images: List[str] = [] + def _extract_images( + cls, msg_body: dict + ) -> Optional[List[CommingMessage.MessageImage]]: + images: List[CommingMessage.MessageImage] = [] attachments = msg_body.get("attachments") or [] if isinstance(attachments, list): for attachment in attachments: @@ -176,26 +178,42 @@ class QQBotModule(_ModuleBase, _MessageBase[QQBot]): or "" ).lower() if content_type.startswith("image/") or filename.endswith(cls._IMAGE_SUFFIXES): - images.append(url) + images.append( + CommingMessage.MessageImage( + ref=url, + name=attachment.get("filename") or attachment.get("name"), + mime_type=attachment.get("content_type") + or attachment.get("mime_type"), + size=attachment.get("size"), + ) + ) for key in ("image", "image_url", "pic_url"): value = msg_body.get(key) if isinstance(value, str) and value.startswith("http"): - images.append(value) + images.append(CommingMessage.MessageImage(ref=value)) extra_images = msg_body.get("images") if isinstance(extra_images, list): for item in extra_images: if isinstance(item, str) and item.startswith("http"): - images.append(item) + images.append(CommingMessage.MessageImage(ref=item)) elif isinstance(item, dict): url = item.get("url") or item.get("image_url") if isinstance(url, str) and url.startswith("http"): - images.append(url) + images.append( + CommingMessage.MessageImage( + ref=url, + name=item.get("name") or item.get("filename"), + mime_type=item.get("content_type") + or item.get("mime_type"), + size=item.get("size"), + ) + ) deduped = [] for image in images: - if image not in deduped: + if image.ref not in [item.ref for item in deduped]: deduped.append(image) return deduped or None diff --git a/app/modules/slack/__init__.py b/app/modules/slack/__init__.py index d84774de..1f089c90 100644 --- a/app/modules/slack/__init__.py +++ b/app/modules/slack/__init__.py @@ -301,7 +301,9 @@ class SlackModule(_ModuleBase, _MessageBase[Slack]): return None @staticmethod - def _extract_images(msg_json: dict) -> Optional[List[str]]: + def _extract_images( + msg_json: dict, + ) -> Optional[List[CommingMessage.MessageImage]]: """ 从Slack消息中提取图片URL """ @@ -320,7 +322,14 @@ class SlackModule(_ModuleBase, _MessageBase[Slack]): ): url = file.get("url_private") or file.get("url_private_download") if url: - images.append(url) + images.append( + CommingMessage.MessageImage( + ref=url, + name=file.get("name") or file.get("title"), + mime_type=file.get("mimetype"), + size=file.get("size"), + ) + ) return images if images else None @classmethod diff --git a/app/modules/synologychat/__init__.py b/app/modules/synologychat/__init__.py index 179e90cd..3658ae62 100644 --- a/app/modules/synologychat/__init__.py +++ b/app/modules/synologychat/__init__.py @@ -141,12 +141,14 @@ class SynologyChatModule(_ModuleBase, _MessageBase[SynologyChat]): return None @classmethod - def _extract_images(cls, message: dict) -> Optional[List[str]]: + def _extract_images( + cls, message: dict + ) -> Optional[List[CommingMessage.MessageImage]]: images = [] for key in ("file_url", "image_url", "pic_url"): value = message.get(key) if isinstance(value, str) and cls._looks_like_image(value): - images.append(value) + images.append(CommingMessage.MessageImage(ref=value)) for key in ("attachments", "files"): raw_value = message.get(key) @@ -159,15 +161,23 @@ class SynologyChatModule(_ModuleBase, _MessageBase[SynologyChat]): items = parsed if isinstance(parsed, list) else [parsed] for item in items: if isinstance(item, str) and cls._looks_like_image(item): - images.append(item) + images.append(CommingMessage.MessageImage(ref=item)) elif isinstance(item, dict): url = item.get("url") or item.get("file_url") or item.get("image_url") if isinstance(url, str) and cls._looks_like_image(url): - images.append(url) + images.append( + CommingMessage.MessageImage( + ref=url, + name=item.get("name") or item.get("filename"), + mime_type=item.get("content_type") + or item.get("mime_type"), + size=item.get("size"), + ) + ) deduped = [] for image in images: - if image not in deduped: + if image.ref not in [item.ref for item in deduped]: deduped.append(image) return deduped or None diff --git a/app/modules/telegram/__init__.py b/app/modules/telegram/__init__.py index 50385563..001669e9 100644 --- a/app/modules/telegram/__init__.py +++ b/app/modules/telegram/__init__.py @@ -273,7 +273,7 @@ class TelegramModule(_ModuleBase, _MessageBase[Telegram]): return None @staticmethod - def _extract_images(msg: dict) -> Optional[List[str]]: + def _extract_images(msg: dict) -> Optional[List[CommingMessage.MessageImage]]: """ 从Telegram消息中提取图片file_id """ @@ -283,14 +283,27 @@ class TelegramModule(_ModuleBase, _MessageBase[Telegram]): largest_photo = photo[-1] file_id = largest_photo.get("file_id") if file_id: - images.append(f"tg://file_id/{file_id}") + images.append( + CommingMessage.MessageImage( + ref=f"tg://file_id/{file_id}", + mime_type="image/jpeg", + size=largest_photo.get("file_size"), + ) + ) document = msg.get("document") if document: file_id = document.get("file_id") mime_type = document.get("mime_type", "") if file_id and mime_type.startswith("image/"): - images.append(f"tg://file_id/{file_id}") + images.append( + CommingMessage.MessageImage( + ref=f"tg://file_id/{file_id}", + name=document.get("file_name"), + mime_type=document.get("mime_type"), + size=document.get("file_size"), + ) + ) return images if images else None diff --git a/app/modules/vocechat/__init__.py b/app/modules/vocechat/__init__.py index ca88f6d2..78ec5c74 100644 --- a/app/modules/vocechat/__init__.py +++ b/app/modules/vocechat/__init__.py @@ -162,7 +162,9 @@ class VoceChatModule(_ModuleBase, _MessageBase[VoceChat]): return None @classmethod - def _extract_images(cls, detail: dict) -> Optional[List[str]]: + def _extract_images( + cls, detail: dict + ) -> Optional[List[CommingMessage.MessageImage]]: content_type = detail.get("content_type") or "" if content_type != "vocechat/file": return None @@ -194,9 +196,23 @@ class VoceChatModule(_ModuleBase, _MessageBase[VoceChat]): if not is_image: return None if isinstance(direct_url, str) and direct_url.startswith("http"): - return [direct_url] + return [ + CommingMessage.MessageImage( + ref=direct_url, + name=properties.get("name") or properties.get("filename"), + mime_type=mime_type or None, + size=properties.get("size"), + ) + ] if isinstance(file_path, str) and file_path: - return [f"vocechat://file/{quote(file_path, safe='')}"] + return [ + CommingMessage.MessageImage( + ref=f"vocechat://file/{quote(file_path, safe='')}", + name=properties.get("name") or properties.get("filename"), + mime_type=mime_type or None, + size=properties.get("size"), + ) + ] return None @classmethod diff --git a/app/modules/wechat/__init__.py b/app/modules/wechat/__init__.py index 16aaec42..b82c2a80 100644 --- a/app/modules/wechat/__init__.py +++ b/app/modules/wechat/__init__.py @@ -189,9 +189,9 @@ class WechatModule(_ModuleBase, _MessageBase[WeChat]): media_id = DomUtils.tag_value(root_node, "MediaId") pic_url = DomUtils.tag_value(root_node, "PicUrl") if media_id: - images = [f"wxwork://media_id/{media_id}"] + images = [CommingMessage.MessageImage(ref=f"wxwork://media_id/{media_id}")] elif pic_url: - images = [pic_url] + images = [CommingMessage.MessageImage(ref=pic_url)] logger.info( f"收到来自 {client_config.name} 的微信图片消息:userid={user_id}, images={len(images) if images else 0}" ) diff --git a/app/modules/wechat/wechatbot.py b/app/modules/wechat/wechatbot.py index da349650..1858f5f5 100644 --- a/app/modules/wechat/wechatbot.py +++ b/app/modules/wechat/wechatbot.py @@ -16,6 +16,7 @@ from app.core.config import settings from app.core.context import MediaInfo, Context from app.core.metainfo import MetaInfo from app.log import logger +from app.schemas import CommingMessage from app.utils.http import RequestUtils from app.utils.string import StringUtils @@ -359,27 +360,50 @@ class WeChatBot: return f"wxbot://image/{encoded}" @classmethod - def _extract_images_from_body(cls, body: dict) -> Optional[List[str]]: - images: List[str] = [] + def _extract_images_from_body( + cls, body: dict + ) -> Optional[List["CommingMessage.MessageImage"]]: + images: List["CommingMessage.MessageImage"] = [] msgtype = body.get("msgtype") if msgtype == "image": - image_ref = cls._build_image_ref(body.get("image") or {}) + image_payload = body.get("image") or {} + image_ref = cls._build_image_ref(image_payload) if image_ref: - images.append(image_ref) + images.append( + CommingMessage.MessageImage( + ref=image_ref, + mime_type=image_payload.get("mime_type") + or image_payload.get("content_type"), + ) + ) elif msgtype == "mixed": for item in (body.get("mixed") or {}).get("msg_item") or []: if item.get("msgtype") != "image": continue - image_ref = cls._build_image_ref(item.get("image") or {}) + image_payload = item.get("image") or {} + image_ref = cls._build_image_ref(image_payload) if image_ref: - images.append(image_ref) + images.append( + CommingMessage.MessageImage( + ref=image_ref, + mime_type=image_payload.get("mime_type") + or image_payload.get("content_type"), + ) + ) quote = body.get("quote") or {} if not images and quote.get("msgtype") == "image": - image_ref = cls._build_image_ref(quote.get("image") or {}) + image_payload = quote.get("image") or {} + image_ref = cls._build_image_ref(image_payload) if image_ref: - images.append(image_ref) + images.append( + CommingMessage.MessageImage( + ref=image_ref, + mime_type=image_payload.get("mime_type") + or image_payload.get("content_type"), + ) + ) return images or None diff --git a/app/schemas/message.py b/app/schemas/message.py index 8766c559..b96f9fa3 100644 --- a/app/schemas/message.py +++ b/app/schemas/message.py @@ -1,8 +1,8 @@ from dataclasses import dataclass from enum import Enum -from typing import Optional, Union, List, Dict, Set +from typing import Optional, Union, List, Dict, Set, Any -from pydantic import BaseModel, Field +from pydantic import BaseModel, Field, field_validator from app.schemas.types import ContentType, NotificationType, MessageChannel @@ -29,6 +29,61 @@ class CommingMessage(BaseModel): 外来消息 """ + class MessageImage(BaseModel): + """ + 外来消息图片 + """ + + ref: str + name: Optional[str] = None + mime_type: Optional[str] = None + size: Optional[int] = None + + @classmethod + def from_value(cls, value: Any) -> Optional["CommingMessage.MessageImage"]: + if value is None: + return None + if isinstance(value, cls): + return value + if isinstance(value, str): + return cls(ref=value) + if isinstance(value, dict): + ref = ( + value.get("ref") + or value.get("url") + or value.get("image_url") + or value.get("file_url") + ) + if not ref: + return None + size = value.get("size") + try: + size = int(size) if size is not None else None + except (TypeError, ValueError): + size = None + return cls( + ref=ref, + name=value.get("name") or value.get("filename"), + mime_type=value.get("mime_type") or value.get("content_type"), + size=size, + ) + return None + + @classmethod + def normalize_list( + cls, values: Optional[Any] + ) -> Optional[List["CommingMessage.MessageImage"]]: + if not values: + return None + if not isinstance(values, list): + values = [values] + normalized = [] + for value in values: + item = cls.from_value(value) + if item: + normalized.append(item) + return normalized or None + class MessageAttachment(BaseModel): """ 外来消息附件(非图片/非语音) @@ -64,12 +119,19 @@ class CommingMessage(BaseModel): # 完整的回调查询信息(原始数据) callback_query: Optional[Dict] = None # 图片列表(图片URL或file_id) - images: Optional[List[str]] = None + images: Optional[List[MessageImage]] = None # 语音/音频引用列表 audio_refs: Optional[List[str]] = None # 文件附件列表 files: Optional[List[MessageAttachment]] = None + @field_validator("images", mode="before") + @classmethod + def _normalize_images( + cls, value: Any + ) -> Optional[List["CommingMessage.MessageImage"]]: + return cls.MessageImage.normalize_list(value) + def to_dict(self): """ 转换为字典 diff --git a/tests/test_agent_image_support.py b/tests/test_agent_image_support.py index 73c6a5bb..6e15efa8 100644 --- a/tests/test_agent_image_support.py +++ b/tests/test_agent_image_support.py @@ -14,6 +14,7 @@ from app.agent.tools.impl.send_local_file import SendLocalFileInput from app.agent import MoviePilotAgent, AgentChain from app.chain.message import MessageChain from app.core.config import settings +from app.helper.llm import LLMHelper from app.helper.voice import VoiceHelper from app.modules.discord import DiscordModule from app.modules.qqbot import QQBotModule @@ -46,14 +47,18 @@ class AgentImageSupportTest(unittest.TestCase): images = TelegramModule._extract_images( { "photo": [{"file_id": "small"}, {"file_id": "large"}], - "document": {"file_id": "doc-image", "mime_type": "image/png"}, + "document": { + "file_id": "doc-image", + "mime_type": "image/png", + "file_name": "poster.png", + }, } ) - self.assertEqual( - images, - ["tg://file_id/large", "tg://file_id/doc-image"], - ) + self.assertEqual([image.ref for image in images], ["tg://file_id/large", "tg://file_id/doc-image"]) + self.assertEqual(images[0].mime_type, "image/jpeg") + self.assertEqual(images[1].mime_type, "image/png") + self.assertEqual(images[1].name, "poster.png") def test_telegram_message_parser_accepts_double_encoded_body(self): module = TelegramModule() @@ -83,7 +88,7 @@ class AgentImageSupportTest(unittest.TestCase): ) self.assertIsNotNone(message) - self.assertEqual(message.images, ["tg://file_id/large"]) + self.assertEqual([image.ref for image in message.images], ["tg://file_id/large"]) def test_telegram_forward_payload_uses_dict_not_json_string(self): payload = Telegram._serialize_update_payload( @@ -143,7 +148,7 @@ class AgentImageSupportTest(unittest.TestCase): handle_kwargs = handle_message.call_args.kwargs self.assertEqual(handle_kwargs["text"], "") - self.assertEqual(handle_kwargs["images"], ["tg://file_id/image-1"]) + self.assertEqual([image.ref for image in handle_kwargs["images"]], ["tg://file_id/image-1"]) def test_process_allows_audio_only_message(self): chain = MessageChain() @@ -370,6 +375,61 @@ class AgentImageSupportTest(unittest.TestCase): self.assertEqual(payload["message"], "帮我总结这个文件") self.assertEqual(payload["files"][0]["local_path"], "/tmp/report.txt") + def test_llm_supports_image_input_respects_explicit_override(self): + with patch.object(settings, "LLM_SUPPORT_IMAGE_INPUT", False): + self.assertFalse(LLMHelper.supports_image_input()) + + def test_llm_supports_image_input_uses_boolean_setting(self): + with patch.object(settings, "LLM_SUPPORT_IMAGE_INPUT", True): + self.assertTrue(LLMHelper.supports_image_input()) + + with patch.object(settings, "LLM_SUPPORT_IMAGE_INPUT", False): + self.assertFalse(LLMHelper.supports_image_input()) + + def test_handle_ai_message_routes_images_to_files_when_image_input_disabled(self): + chain = MessageChain() + + with patch.object(settings, "AI_AGENT_ENABLE", True), patch.object( + settings, "LLM_SUPPORT_IMAGE_INPUT", False + ), patch.object(chain, "_get_or_create_session_id", return_value="session-1"), patch.object( + chain, "_download_images_to_base64" + ) as download_images, patch.object( + chain, + "_prepare_agent_files", + return_value=[ + { + "name": "image_1.jpg", + "mime_type": "image/jpeg", + "local_path": "/tmp/image_1.jpg", + "status": "ready", + } + ], + ) as prepare_files, patch( + "app.chain.message.agent_manager.process_message", new_callable=AsyncMock + ) as process_message, patch( + "app.chain.message.asyncio.run_coroutine_threadsafe" + ) as run_coroutine_threadsafe: + chain._handle_ai_message( + text="/ai 帮我看看这张图", + channel=MessageChannel.Telegram, + source="telegram-test", + userid="10001", + username="tester", + images=["tg://file_id/image-1"], + ) + + download_images.assert_not_called() + prepare_files.assert_called_once() + attachments = prepare_files.call_args.kwargs["files"] + self.assertEqual(attachments[0].ref, "tg://file_id/image-1") + self.assertEqual(attachments[0].mime_type, "image/jpeg") + run_coroutine_threadsafe.assert_called_once() + self.assertEqual(process_message.call_args.kwargs["images"], None) + self.assertEqual( + process_message.call_args.kwargs["files"][0]["local_path"], + "/tmp/image_1.jpg", + ) + def test_slack_images_use_authenticated_data_url_download(self): chain = MessageChain() @@ -460,7 +520,8 @@ class AgentImageSupportTest(unittest.TestCase): } ) - self.assertEqual(images, ["https://cdn.discordapp.com/test.png"]) + self.assertEqual([image.ref for image in images], ["https://cdn.discordapp.com/test.png"]) + self.assertEqual(images[0].mime_type, "image/png") def test_discord_extract_audio_refs_supports_attachment_content_type(self): audio_refs = DiscordModule._extract_audio_refs( @@ -587,7 +648,7 @@ class AgentImageSupportTest(unittest.TestCase): ) self.assertIsNotNone(message) - self.assertEqual(message.images, ["wxwork://media_id/media-1"]) + self.assertEqual([image.ref for image in message.images], ["wxwork://media_id/media-1"]) def test_wechat_message_parser_extracts_file_media_id(self): module = WechatModule() @@ -661,7 +722,7 @@ class AgentImageSupportTest(unittest.TestCase): ) self.assertIsNotNone(message) - self.assertTrue(message.images[0].startswith("wxbot://image/")) + self.assertTrue(message.images[0].ref.startswith("wxbot://image/")) def test_wechat_bot_handles_image_only_callback(self): bot = WeChatBot.__new__(WeChatBot) @@ -718,9 +779,10 @@ class AgentImageSupportTest(unittest.TestCase): self.assertIsNotNone(message) self.assertEqual( - message.images, + [image.ref for image in message.images], ["vocechat://file/%2Fuploads%2Fposter.png"], ) + self.assertEqual(message.images[0].mime_type, "image/png") def test_vocechat_message_parser_extracts_audio_file_payload(self): module = VoceChatModule() @@ -901,7 +963,8 @@ class AgentImageSupportTest(unittest.TestCase): ) self.assertIsNotNone(message) - self.assertEqual(message.images, ["https://example.com/qq-image.png"]) + self.assertEqual([image.ref for image in message.images], ["https://example.com/qq-image.png"]) + self.assertEqual(message.images[0].mime_type, "image/png") def test_qq_message_parser_accepts_audio_only_attachment(self): module = QQBotModule() @@ -959,7 +1022,7 @@ class AgentImageSupportTest(unittest.TestCase): ) self.assertIsNotNone(message) - self.assertEqual(message.images, ["https://example.com/image.png"]) + self.assertEqual([image.ref for image in message.images], ["https://example.com/image.png"]) def test_synology_message_parser_accepts_audio_only_form(self): module = SynologyChatModule()