Improve agent image capability routing

This commit is contained in:
jxxghp
2026-04-15 08:55:32 +08:00
parent bf127d6a70
commit 13c3c082b8
14 changed files with 417 additions and 57 deletions

View File

@@ -10,7 +10,7 @@ Core Capabilities:
3. Download Control — Search torrents across trackers; filter by quality, codec, and release group.
4. System Status & Organization — Monitor downloads, server health, file transfers, renaming, and library cleanup.
5. Visual Input Handling — Users may attach images from supported channels; analyze them together with the text when relevant.
6. File Context Handling — User messages may arrive as structured JSON. Treat the `message` field as the user's text. Non-image attachments appear in `files`; when `local_path` is present, use local file tools to inspect the uploaded file directly.
6. File Context Handling — User messages may arrive as structured JSON. Treat the `message` field as the user's text. Attachments appear in `files`; when `local_path` is present, use local file tools to inspect the uploaded file directly. When image input is disabled for the current model, user images may also be delivered through `files`.
<communication>
{verbose_spec}

View File

@@ -1,10 +1,11 @@
import asyncio
import mimetypes
import re
import time
from datetime import datetime, timedelta
from pathlib import Path
from typing import Any, Optional, Dict, Union, List
from urllib.parse import unquote
from urllib.parse import unquote, urlparse
import uuid
import base64
@@ -20,6 +21,7 @@ from app.core.context import MediaInfo, Context
from app.core.meta import MetaBase
from app.db.user_oper import UserOper
from app.helper.torrent import TorrentHelper
from app.helper.llm import LLMHelper
from app.helper.voice import VoiceHelper
from app.log import logger
from app.schemas import Notification, NotExistMediaInfo, CommingMessage
@@ -169,7 +171,7 @@ class MessageChain(ChainBase):
text: str,
original_message_id: Optional[Union[str, int]] = None,
original_chat_id: Optional[str] = None,
images: Optional[List[str]] = None,
images: Optional[List[CommingMessage.MessageImage]] = None,
audio_refs: Optional[List[str]] = None,
files: Optional[List[CommingMessage.MessageAttachment]] = None,
) -> None:
@@ -183,6 +185,8 @@ class MessageChain(ChainBase):
user_cache: Dict[str, dict] = self.load_cache(self._cache_file) or {}
try:
images = CommingMessage.MessageImage.normalize_list(images)
# 识别语音为文本
reply_with_voice = bool(audio_refs)
if audio_refs:
@@ -1238,7 +1242,7 @@ class MessageChain(ChainBase):
source: str,
userid: Union[str, int],
username: str,
images: Optional[List[str]] = None,
images: Optional[List[CommingMessage.MessageImage]] = None,
files: Optional[List[CommingMessage.MessageAttachment]] = None,
reply_with_voice: bool = False,
) -> None:
@@ -1259,6 +1263,8 @@ class MessageChain(ChainBase):
)
return
images = CommingMessage.MessageImage.normalize_list(images)
# 提取用户消息
if text.lower().startswith("/ai"):
user_message = text[3:].strip() # 移除 "/ai" 前缀(大小写不敏感)
@@ -1282,7 +1288,8 @@ class MessageChain(ChainBase):
# 下载图片并转为base64
original_images = images
if images:
all_files = list(files or [])
if images and LLMHelper.supports_image_input():
images = self._download_images_to_base64(images, channel, source)
if original_images and not images and not user_message and not files:
self.post_message(
@@ -1295,14 +1302,29 @@ class MessageChain(ChainBase):
)
)
return
elif images:
image_attachments = self._build_image_attachments(images)
if original_images and not image_attachments and not user_message and not files:
self.post_message(
Notification(
channel=channel,
source=source,
userid=userid,
username=username,
title="图片读取失败,请稍后重试",
)
)
return
all_files.extend(image_attachments)
images = None
prepared_files = self._prepare_agent_files(
session_id=session_id,
files=files,
files=all_files,
channel=channel,
source=source,
)
if files and not prepared_files and not user_message and not images:
if all_files and not prepared_files and not user_message and not images:
self.post_message(
Notification(
channel=channel,
@@ -1452,15 +1474,20 @@ class MessageChain(ChainBase):
return default
def _download_images_to_base64(
self, images: List[str], channel: MessageChannel, source: str
self,
images: List[CommingMessage.MessageImage],
channel: MessageChannel,
source: str,
) -> List[str]:
"""
下载图片并转为base64
"""
images = CommingMessage.MessageImage.normalize_list(images)
if not images:
return None
base64_images = []
for img in images:
for image in images:
img = image.ref
try:
if img.startswith("data:"):
base64_images.append(img)
@@ -1511,6 +1538,33 @@ class MessageChain(ChainBase):
logger.error(f"下载图片失败: {img}, error: {e}")
return base64_images if base64_images else None
def _build_image_attachments(
    self, images: List[CommingMessage.MessageImage]
) -> List[CommingMessage.MessageAttachment]:
    """
    Turn image references into attachment descriptors so the images can be
    handed to the Agent as ordinary files.

    :param images: image entries; they are passed through
        ``MessageImage.normalize_list`` first.
    :return: one ``MessageAttachment`` per image that has a non-empty ref;
        an empty list when nothing usable was supplied.
    """
    normalized = CommingMessage.MessageImage.normalize_list(images) or []
    result = []
    # Positions are 1-based and count every normalized entry, including
    # the ones skipped below, so generated names stay tied to the input order.
    for position, item in enumerate(normalized, start=1):
        ref = item.ref
        if ref:
            filename = item.name
            if not filename:
                filename = self._guess_image_attachment_name(ref, position)
            content_type = item.mime_type
            if not content_type:
                content_type = self._guess_image_mime_type(ref, filename)
            descriptor = CommingMessage.MessageAttachment(
                ref=ref,
                name=filename,
                mime_type=content_type,
                size=item.size,
            )
            result.append(descriptor)
    return result
def _prepare_agent_files(
self,
session_id: str,
@@ -1570,15 +1624,31 @@ class MessageChain(ChainBase):
"""
if not file_ref:
return None
if file_ref.startswith("data:"):
return self._decode_data_url_bytes(file_ref)
if file_ref.startswith("tg://file_id/"):
file_id = file_ref.replace("tg://file_id/", "", 1)
return self.run_module(
"download_telegram_file_bytes", file_id=file_id, source=source
)
if file_ref.startswith("tg://document_file_id/"):
file_id = file_ref.replace("tg://document_file_id/", "", 1)
return self.run_module(
"download_telegram_file_bytes", file_id=file_id, source=source
)
if file_ref.startswith("wxwork://media_id/"):
return self.run_module(
"download_wechat_media_bytes", media_ref=file_ref, source=source
)
if file_ref.startswith("wxwork://file_media_id/"):
return self.run_module(
"download_wechat_media_bytes", media_ref=file_ref, source=source
)
if file_ref.startswith("wxbot://image/"):
data_url = self.run_module(
"download_wechat_image_to_data_url", image_ref=file_ref, source=source
)
return self._decode_data_url_bytes(data_url) if data_url else None
if file_ref.startswith("wxbot://file/"):
file_url = unquote(file_ref.replace("wxbot://file/", "", 1))
resp = RequestUtils(timeout=30).get_res(file_url)
@@ -1604,6 +1674,11 @@ class MessageChain(ChainBase):
"download_synologychat_file_bytes", file_ref=file_ref, source=source
)
if file_ref.startswith("http"):
if channel == MessageChannel.Slack:
data_url = self.run_module(
"download_slack_file_to_data_url", file_url=file_ref, source=source
)
return self._decode_data_url_bytes(data_url) if data_url else None
resp = RequestUtils(timeout=30).get_res(file_ref)
return resp.content if resp and resp.content else None
logger.debug(
@@ -1647,6 +1722,11 @@ class MessageChain(ChainBase):
if "." not in name:
mime = (mime_type or "").split(";", 1)[0].strip().lower()
default_ext = {
"image/jpeg": ".jpg",
"image/png": ".png",
"image/gif": ".gif",
"image/webp": ".webp",
"image/bmp": ".bmp",
"application/json": ".json",
"text/plain": ".txt",
"text/markdown": ".md",
@@ -1655,3 +1735,50 @@ class MessageChain(ChainBase):
if default_ext:
name = f"{name}{default_ext}"
return name
@staticmethod
def _guess_image_attachment_name(image_ref: str, index: int) -> str:
"""
根据图片引用推测附件名。
"""
if not image_ref:
return f"image_{index}.jpg"
if image_ref.startswith("data:"):
mime_part = image_ref[5:].split(";", 1)[0].strip().lower()
ext = mimetypes.guess_extension(mime_part) or ".jpg"
return f"image_{index}{ext}"
parsed = urlparse(unquote(image_ref))
name = Path(parsed.path).name if parsed.path else ""
if name and "." in name:
return name
return f"image_{index}.jpg"
@staticmethod
def _guess_image_mime_type(image_ref: str, filename: Optional[str]) -> str:
"""
根据图片引用或文件名推测 MIME 类型。
"""
if image_ref and image_ref.startswith("data:"):
mime = image_ref[5:].split(";", 1)[0].strip().lower()
return mime or "image/jpeg"
guessed, _ = mimetypes.guess_type(filename or "")
if guessed and guessed.startswith("image/"):
return guessed
return "image/jpeg"
@staticmethod
def _decode_data_url_bytes(data_url: Optional[str]) -> Optional[bytes]:
"""
将 data URL 解码为原始字节。
"""
if not data_url or not data_url.startswith("data:"):
return None
try:
_, payload = data_url.split(",", 1)
except ValueError:
return None
try:
return base64.b64decode(payload)
except Exception:
return None

View File

@@ -494,6 +494,8 @@ class ConfigModel(BaseModel):
LLM_PROVIDER: str = "deepseek"
# LLM模型名称
LLM_MODEL: str = "deepseek-chat"
# Whether the LLM supports image input; when enabled, message images are sent to the model as multimodal input
LLM_SUPPORT_IMAGE_INPUT: bool = False
# LLM API密钥
LLM_API_KEY: Optional[str] = None
# LLM基础URL用于自定义API端点

View File

@@ -59,6 +59,13 @@ def _get_httpx_proxy_key() -> str:
class LLMHelper:
"""LLM模型相关辅助功能"""
@staticmethod
def supports_image_input() -> bool:
    """
    Report whether image input is enabled for the current model.

    :return: the truthiness of ``settings.LLM_SUPPORT_IMAGE_INPUT``.
    """
    flag = settings.LLM_SUPPORT_IMAGE_INPUT
    return True if flag else False
@staticmethod
def get_llm(streaming: bool = False):
"""

View File

@@ -181,7 +181,9 @@ class DiscordModule(_ModuleBase, _MessageBase[Discord]):
return None
@staticmethod
def _extract_images(msg_json: dict) -> Optional[List[str]]:
def _extract_images(
msg_json: dict,
) -> Optional[List[CommingMessage.MessageImage]]:
"""
从Discord消息中提取图片URL
"""
@@ -200,7 +202,14 @@ class DiscordModule(_ModuleBase, _MessageBase[Discord]):
or content_type.startswith("image/")
or filename.endswith(DiscordModule._IMAGE_SUFFIXES)
):
images.append(url)
images.append(
CommingMessage.MessageImage(
ref=url,
name=attachment.get("filename"),
mime_type=attachment.get("content_type"),
size=attachment.get("size"),
)
)
return images if images else None
@classmethod

View File

@@ -155,8 +155,10 @@ class QQBotModule(_ModuleBase, _MessageBase[QQBot]):
return None
@classmethod
def _extract_images(cls, msg_body: dict) -> Optional[List[str]]:
images: List[str] = []
def _extract_images(
cls, msg_body: dict
) -> Optional[List[CommingMessage.MessageImage]]:
images: List[CommingMessage.MessageImage] = []
attachments = msg_body.get("attachments") or []
if isinstance(attachments, list):
for attachment in attachments:
@@ -176,26 +178,42 @@ class QQBotModule(_ModuleBase, _MessageBase[QQBot]):
or ""
).lower()
if content_type.startswith("image/") or filename.endswith(cls._IMAGE_SUFFIXES):
images.append(url)
images.append(
CommingMessage.MessageImage(
ref=url,
name=attachment.get("filename") or attachment.get("name"),
mime_type=attachment.get("content_type")
or attachment.get("mime_type"),
size=attachment.get("size"),
)
)
for key in ("image", "image_url", "pic_url"):
value = msg_body.get(key)
if isinstance(value, str) and value.startswith("http"):
images.append(value)
images.append(CommingMessage.MessageImage(ref=value))
extra_images = msg_body.get("images")
if isinstance(extra_images, list):
for item in extra_images:
if isinstance(item, str) and item.startswith("http"):
images.append(item)
images.append(CommingMessage.MessageImage(ref=item))
elif isinstance(item, dict):
url = item.get("url") or item.get("image_url")
if isinstance(url, str) and url.startswith("http"):
images.append(url)
images.append(
CommingMessage.MessageImage(
ref=url,
name=item.get("name") or item.get("filename"),
mime_type=item.get("content_type")
or item.get("mime_type"),
size=item.get("size"),
)
)
deduped = []
for image in images:
if image not in deduped:
if image.ref not in [item.ref for item in deduped]:
deduped.append(image)
return deduped or None

View File

@@ -301,7 +301,9 @@ class SlackModule(_ModuleBase, _MessageBase[Slack]):
return None
@staticmethod
def _extract_images(msg_json: dict) -> Optional[List[str]]:
def _extract_images(
msg_json: dict,
) -> Optional[List[CommingMessage.MessageImage]]:
"""
从Slack消息中提取图片URL
"""
@@ -320,7 +322,14 @@ class SlackModule(_ModuleBase, _MessageBase[Slack]):
):
url = file.get("url_private") or file.get("url_private_download")
if url:
images.append(url)
images.append(
CommingMessage.MessageImage(
ref=url,
name=file.get("name") or file.get("title"),
mime_type=file.get("mimetype"),
size=file.get("size"),
)
)
return images if images else None
@classmethod

View File

@@ -141,12 +141,14 @@ class SynologyChatModule(_ModuleBase, _MessageBase[SynologyChat]):
return None
@classmethod
def _extract_images(cls, message: dict) -> Optional[List[str]]:
def _extract_images(
cls, message: dict
) -> Optional[List[CommingMessage.MessageImage]]:
images = []
for key in ("file_url", "image_url", "pic_url"):
value = message.get(key)
if isinstance(value, str) and cls._looks_like_image(value):
images.append(value)
images.append(CommingMessage.MessageImage(ref=value))
for key in ("attachments", "files"):
raw_value = message.get(key)
@@ -159,15 +161,23 @@ class SynologyChatModule(_ModuleBase, _MessageBase[SynologyChat]):
items = parsed if isinstance(parsed, list) else [parsed]
for item in items:
if isinstance(item, str) and cls._looks_like_image(item):
images.append(item)
images.append(CommingMessage.MessageImage(ref=item))
elif isinstance(item, dict):
url = item.get("url") or item.get("file_url") or item.get("image_url")
if isinstance(url, str) and cls._looks_like_image(url):
images.append(url)
images.append(
CommingMessage.MessageImage(
ref=url,
name=item.get("name") or item.get("filename"),
mime_type=item.get("content_type")
or item.get("mime_type"),
size=item.get("size"),
)
)
deduped = []
for image in images:
if image not in deduped:
if image.ref not in [item.ref for item in deduped]:
deduped.append(image)
return deduped or None

View File

@@ -273,7 +273,7 @@ class TelegramModule(_ModuleBase, _MessageBase[Telegram]):
return None
@staticmethod
def _extract_images(msg: dict) -> Optional[List[str]]:
def _extract_images(msg: dict) -> Optional[List[CommingMessage.MessageImage]]:
"""
从Telegram消息中提取图片file_id
"""
@@ -283,14 +283,27 @@ class TelegramModule(_ModuleBase, _MessageBase[Telegram]):
largest_photo = photo[-1]
file_id = largest_photo.get("file_id")
if file_id:
images.append(f"tg://file_id/{file_id}")
images.append(
CommingMessage.MessageImage(
ref=f"tg://file_id/{file_id}",
mime_type="image/jpeg",
size=largest_photo.get("file_size"),
)
)
document = msg.get("document")
if document:
file_id = document.get("file_id")
mime_type = document.get("mime_type", "")
if file_id and mime_type.startswith("image/"):
images.append(f"tg://file_id/{file_id}")
images.append(
CommingMessage.MessageImage(
ref=f"tg://file_id/{file_id}",
name=document.get("file_name"),
mime_type=document.get("mime_type"),
size=document.get("file_size"),
)
)
return images if images else None

View File

@@ -162,7 +162,9 @@ class VoceChatModule(_ModuleBase, _MessageBase[VoceChat]):
return None
@classmethod
def _extract_images(cls, detail: dict) -> Optional[List[str]]:
def _extract_images(
cls, detail: dict
) -> Optional[List[CommingMessage.MessageImage]]:
content_type = detail.get("content_type") or ""
if content_type != "vocechat/file":
return None
@@ -194,9 +196,23 @@ class VoceChatModule(_ModuleBase, _MessageBase[VoceChat]):
if not is_image:
return None
if isinstance(direct_url, str) and direct_url.startswith("http"):
return [direct_url]
return [
CommingMessage.MessageImage(
ref=direct_url,
name=properties.get("name") or properties.get("filename"),
mime_type=mime_type or None,
size=properties.get("size"),
)
]
if isinstance(file_path, str) and file_path:
return [f"vocechat://file/{quote(file_path, safe='')}"]
return [
CommingMessage.MessageImage(
ref=f"vocechat://file/{quote(file_path, safe='')}",
name=properties.get("name") or properties.get("filename"),
mime_type=mime_type or None,
size=properties.get("size"),
)
]
return None
@classmethod

View File

@@ -189,9 +189,9 @@ class WechatModule(_ModuleBase, _MessageBase[WeChat]):
media_id = DomUtils.tag_value(root_node, "MediaId")
pic_url = DomUtils.tag_value(root_node, "PicUrl")
if media_id:
images = [f"wxwork://media_id/{media_id}"]
images = [CommingMessage.MessageImage(ref=f"wxwork://media_id/{media_id}")]
elif pic_url:
images = [pic_url]
images = [CommingMessage.MessageImage(ref=pic_url)]
logger.info(
f"收到来自 {client_config.name} 的微信图片消息userid={user_id}, images={len(images) if images else 0}"
)

View File

@@ -16,6 +16,7 @@ from app.core.config import settings
from app.core.context import MediaInfo, Context
from app.core.metainfo import MetaInfo
from app.log import logger
from app.schemas import CommingMessage
from app.utils.http import RequestUtils
from app.utils.string import StringUtils
@@ -359,27 +360,50 @@ class WeChatBot:
return f"wxbot://image/{encoded}"
@classmethod
def _extract_images_from_body(cls, body: dict) -> Optional[List[str]]:
images: List[str] = []
def _extract_images_from_body(
cls, body: dict
) -> Optional[List["CommingMessage.MessageImage"]]:
images: List["CommingMessage.MessageImage"] = []
msgtype = body.get("msgtype")
if msgtype == "image":
image_ref = cls._build_image_ref(body.get("image") or {})
image_payload = body.get("image") or {}
image_ref = cls._build_image_ref(image_payload)
if image_ref:
images.append(image_ref)
images.append(
CommingMessage.MessageImage(
ref=image_ref,
mime_type=image_payload.get("mime_type")
or image_payload.get("content_type"),
)
)
elif msgtype == "mixed":
for item in (body.get("mixed") or {}).get("msg_item") or []:
if item.get("msgtype") != "image":
continue
image_ref = cls._build_image_ref(item.get("image") or {})
image_payload = item.get("image") or {}
image_ref = cls._build_image_ref(image_payload)
if image_ref:
images.append(image_ref)
images.append(
CommingMessage.MessageImage(
ref=image_ref,
mime_type=image_payload.get("mime_type")
or image_payload.get("content_type"),
)
)
quote = body.get("quote") or {}
if not images and quote.get("msgtype") == "image":
image_ref = cls._build_image_ref(quote.get("image") or {})
image_payload = quote.get("image") or {}
image_ref = cls._build_image_ref(image_payload)
if image_ref:
images.append(image_ref)
images.append(
CommingMessage.MessageImage(
ref=image_ref,
mime_type=image_payload.get("mime_type")
or image_payload.get("content_type"),
)
)
return images or None

View File

@@ -1,8 +1,8 @@
from dataclasses import dataclass
from enum import Enum
from typing import Optional, Union, List, Dict, Set
from typing import Optional, Union, List, Dict, Set, Any
from pydantic import BaseModel, Field
from pydantic import BaseModel, Field, field_validator
from app.schemas.types import ContentType, NotificationType, MessageChannel
@@ -29,6 +29,61 @@ class CommingMessage(BaseModel):
外来消息
"""
class MessageImage(BaseModel):
    """
    An image attached to an incoming message.
    """

    # Reference used to fetch the image (URL, channel-specific ref or data URL)
    ref: str
    # Original file name, when the channel provides one
    name: Optional[str] = None
    # MIME type, when the channel provides one
    mime_type: Optional[str] = None
    # Size as reported by the channel, when available
    size: Optional[int] = None

    @classmethod
    def from_value(cls, value: Any) -> Optional["CommingMessage.MessageImage"]:
        """
        Build a MessageImage from a loosely typed value.

        :param value: an existing instance (returned as-is), a non-empty
            reference string, or a dict carrying the reference under
            ``ref`` / ``url`` / ``image_url`` / ``file_url``.
        :return: the image, or ``None`` when no usable reference is found.
        """
        if value is None:
            return None
        if isinstance(value, cls):
            return value
        if isinstance(value, str):
            # Same rule as the dict branch: an empty reference is unusable.
            return cls(ref=value) if value else None
        if isinstance(value, dict):
            ref = (
                value.get("ref")
                or value.get("url")
                or value.get("image_url")
                or value.get("file_url")
            )
            # Skip non-string references rather than letting model
            # validation fail for the whole message.
            if not ref or not isinstance(ref, str):
                return None
            size = value.get("size")
            try:
                size = int(size) if size is not None else None
            except (TypeError, ValueError):
                # An unparsable size is dropped; the image itself is kept.
                size = None
            return cls(
                ref=ref,
                name=value.get("name") or value.get("filename"),
                mime_type=value.get("mime_type") or value.get("content_type"),
                size=size,
            )
        return None

    @classmethod
    def normalize_list(
        cls, values: Optional[Any]
    ) -> Optional[List["CommingMessage.MessageImage"]]:
        """
        Normalize arbitrary image input into a list of MessageImage.

        :param values: a list or tuple of values accepted by ``from_value``,
            or a single such value.
        :return: the converted images in input order, or ``None`` when the
            input is empty or nothing could be converted.
        """
        if not values:
            return None
        if isinstance(values, tuple):
            # A tuple is a sequence of images, not a single image value.
            values = list(values)
        elif not isinstance(values, list):
            values = [values]
        normalized = [
            item for item in map(cls.from_value, values) if item is not None
        ]
        return normalized or None
class MessageAttachment(BaseModel):
"""
外来消息附件(非图片/非语音)
@@ -64,12 +119,19 @@ class CommingMessage(BaseModel):
# 完整的回调查询信息(原始数据)
callback_query: Optional[Dict] = None
# Image list (each entry references an image URL or a file_id)
images: Optional[List[str]] = None
images: Optional[List[MessageImage]] = None
# 语音/音频引用列表
audio_refs: Optional[List[str]] = None
# 文件附件列表
files: Optional[List[MessageAttachment]] = None
@field_validator("images", mode="before")
@classmethod
def _normalize_images(
    cls, value: Any
) -> Optional[List["CommingMessage.MessageImage"]]:
    """
    Pre-validation hook for ``images``: hand the raw input to
    ``MessageImage.normalize_list`` and use its result as the field value.
    """
    image_model = cls.MessageImage
    return image_model.normalize_list(value)
def to_dict(self):
"""
转换为字典

View File

@@ -14,6 +14,7 @@ from app.agent.tools.impl.send_local_file import SendLocalFileInput
from app.agent import MoviePilotAgent, AgentChain
from app.chain.message import MessageChain
from app.core.config import settings
from app.helper.llm import LLMHelper
from app.helper.voice import VoiceHelper
from app.modules.discord import DiscordModule
from app.modules.qqbot import QQBotModule
@@ -46,14 +47,18 @@ class AgentImageSupportTest(unittest.TestCase):
images = TelegramModule._extract_images(
{
"photo": [{"file_id": "small"}, {"file_id": "large"}],
"document": {"file_id": "doc-image", "mime_type": "image/png"},
"document": {
"file_id": "doc-image",
"mime_type": "image/png",
"file_name": "poster.png",
},
}
)
self.assertEqual(
images,
["tg://file_id/large", "tg://file_id/doc-image"],
)
self.assertEqual([image.ref for image in images], ["tg://file_id/large", "tg://file_id/doc-image"])
self.assertEqual(images[0].mime_type, "image/jpeg")
self.assertEqual(images[1].mime_type, "image/png")
self.assertEqual(images[1].name, "poster.png")
def test_telegram_message_parser_accepts_double_encoded_body(self):
module = TelegramModule()
@@ -83,7 +88,7 @@ class AgentImageSupportTest(unittest.TestCase):
)
self.assertIsNotNone(message)
self.assertEqual(message.images, ["tg://file_id/large"])
self.assertEqual([image.ref for image in message.images], ["tg://file_id/large"])
def test_telegram_forward_payload_uses_dict_not_json_string(self):
payload = Telegram._serialize_update_payload(
@@ -143,7 +148,7 @@ class AgentImageSupportTest(unittest.TestCase):
handle_kwargs = handle_message.call_args.kwargs
self.assertEqual(handle_kwargs["text"], "")
self.assertEqual(handle_kwargs["images"], ["tg://file_id/image-1"])
self.assertEqual([image.ref for image in handle_kwargs["images"]], ["tg://file_id/image-1"])
def test_process_allows_audio_only_message(self):
chain = MessageChain()
@@ -370,6 +375,61 @@ class AgentImageSupportTest(unittest.TestCase):
self.assertEqual(payload["message"], "帮我总结这个文件")
self.assertEqual(payload["files"][0]["local_path"], "/tmp/report.txt")
def test_llm_supports_image_input_respects_explicit_override(self):
    """With the setting forced to False, image input is reported as disabled."""
    override = patch.object(settings, "LLM_SUPPORT_IMAGE_INPUT", False)
    with override:
        enabled = LLMHelper.supports_image_input()
        self.assertFalse(enabled)
def test_llm_supports_image_input_uses_boolean_setting(self):
    """The helper mirrors the boolean setting for both True and False."""
    cases = ((True, self.assertTrue), (False, self.assertFalse))
    for configured, check in cases:
        with patch.object(settings, "LLM_SUPPORT_IMAGE_INPUT", configured):
            check(LLMHelper.supports_image_input())
def test_handle_ai_message_routes_images_to_files_when_image_input_disabled(self):
    """
    With image input disabled, an incoming image must not be downloaded as
    base64; it is converted to a file attachment, prepared like any other
    file, and the agent receives it through ``files`` with ``images`` unset.
    """
    chain = MessageChain()
    prepared_file = {
        "name": "image_1.jpg",
        "mime_type": "image/jpeg",
        "local_path": "/tmp/image_1.jpg",
        "status": "ready",
    }

    with patch.object(settings, "AI_AGENT_ENABLE", True), patch.object(
        settings, "LLM_SUPPORT_IMAGE_INPUT", False
    ), patch.object(
        chain, "_get_or_create_session_id", return_value="session-1"
    ), patch.object(
        chain, "_download_images_to_base64"
    ) as download_images, patch.object(
        chain,
        "_prepare_agent_files",
        return_value=[prepared_file],
    ) as prepare_files, patch(
        "app.chain.message.agent_manager.process_message", new_callable=AsyncMock
    ) as process_message, patch(
        "app.chain.message.asyncio.run_coroutine_threadsafe"
    ) as run_coroutine_threadsafe:
        chain._handle_ai_message(
            text="/ai 帮我看看这张图",
            channel=MessageChannel.Telegram,
            source="telegram-test",
            userid="10001",
            username="tester",
            images=["tg://file_id/image-1"],
        )

    # The multimodal download path must be bypassed entirely.
    download_images.assert_not_called()
    prepare_files.assert_called_once()

    # Exactly one attachment is built from the single image; it has no
    # name or MIME type of its own, so both are generated/guessed.
    attachments = prepare_files.call_args.kwargs["files"]
    self.assertEqual(len(attachments), 1)
    self.assertEqual(attachments[0].ref, "tg://file_id/image-1")
    self.assertEqual(attachments[0].name, "image_1.jpg")
    self.assertEqual(attachments[0].mime_type, "image/jpeg")

    run_coroutine_threadsafe.assert_called_once()
    self.assertIsNone(process_message.call_args.kwargs["images"])
    self.assertEqual(
        process_message.call_args.kwargs["files"][0]["local_path"],
        "/tmp/image_1.jpg",
    )
def test_slack_images_use_authenticated_data_url_download(self):
chain = MessageChain()
@@ -460,7 +520,8 @@ class AgentImageSupportTest(unittest.TestCase):
}
)
self.assertEqual(images, ["https://cdn.discordapp.com/test.png"])
self.assertEqual([image.ref for image in images], ["https://cdn.discordapp.com/test.png"])
self.assertEqual(images[0].mime_type, "image/png")
def test_discord_extract_audio_refs_supports_attachment_content_type(self):
audio_refs = DiscordModule._extract_audio_refs(
@@ -587,7 +648,7 @@ class AgentImageSupportTest(unittest.TestCase):
)
self.assertIsNotNone(message)
self.assertEqual(message.images, ["wxwork://media_id/media-1"])
self.assertEqual([image.ref for image in message.images], ["wxwork://media_id/media-1"])
def test_wechat_message_parser_extracts_file_media_id(self):
module = WechatModule()
@@ -661,7 +722,7 @@ class AgentImageSupportTest(unittest.TestCase):
)
self.assertIsNotNone(message)
self.assertTrue(message.images[0].startswith("wxbot://image/"))
self.assertTrue(message.images[0].ref.startswith("wxbot://image/"))
def test_wechat_bot_handles_image_only_callback(self):
bot = WeChatBot.__new__(WeChatBot)
@@ -718,9 +779,10 @@ class AgentImageSupportTest(unittest.TestCase):
self.assertIsNotNone(message)
self.assertEqual(
message.images,
[image.ref for image in message.images],
["vocechat://file/%2Fuploads%2Fposter.png"],
)
self.assertEqual(message.images[0].mime_type, "image/png")
def test_vocechat_message_parser_extracts_audio_file_payload(self):
module = VoceChatModule()
@@ -901,7 +963,8 @@ class AgentImageSupportTest(unittest.TestCase):
)
self.assertIsNotNone(message)
self.assertEqual(message.images, ["https://example.com/qq-image.png"])
self.assertEqual([image.ref for image in message.images], ["https://example.com/qq-image.png"])
self.assertEqual(message.images[0].mime_type, "image/png")
def test_qq_message_parser_accepts_audio_only_attachment(self):
module = QQBotModule()
@@ -959,7 +1022,7 @@ class AgentImageSupportTest(unittest.TestCase):
)
self.assertIsNotNone(message)
self.assertEqual(message.images, ["https://example.com/image.png"])
self.assertEqual([image.ref for image in message.images], ["https://example.com/image.png"])
def test_synology_message_parser_accepts_audio_only_form(self):
module = SynologyChatModule()