mirror of
https://github.com/jxxghp/MoviePilot.git
synced 2026-06-28 11:12:00 +08:00
Improve agent image capability routing
This commit is contained in:
@@ -10,7 +10,7 @@ Core Capabilities:
|
||||
3. Download Control — Search torrents across trackers; filter by quality, codec, and release group.
|
||||
4. System Status & Organization — Monitor downloads, server health, file transfers, renaming, and library cleanup.
|
||||
5. Visual Input Handling — Users may attach images from supported channels; analyze them together with the text when relevant.
|
||||
6. File Context Handling — User messages may arrive as structured JSON. Treat the `message` field as the user's text. Non-image attachments appear in `files`; when `local_path` is present, use local file tools to inspect the uploaded file directly.
|
||||
6. File Context Handling — User messages may arrive as structured JSON. Treat the `message` field as the user's text. Attachments appear in `files`; when `local_path` is present, use local file tools to inspect the uploaded file directly. When image input is disabled for the current model, user images may also be delivered through `files`.
|
||||
|
||||
<communication>
|
||||
{verbose_spec}
|
||||
|
||||
@@ -1,10 +1,11 @@
|
||||
import asyncio
|
||||
import mimetypes
|
||||
import re
|
||||
import time
|
||||
from datetime import datetime, timedelta
|
||||
from pathlib import Path
|
||||
from typing import Any, Optional, Dict, Union, List
|
||||
from urllib.parse import unquote
|
||||
from urllib.parse import unquote, urlparse
|
||||
import uuid
|
||||
|
||||
import base64
|
||||
@@ -20,6 +21,7 @@ from app.core.context import MediaInfo, Context
|
||||
from app.core.meta import MetaBase
|
||||
from app.db.user_oper import UserOper
|
||||
from app.helper.torrent import TorrentHelper
|
||||
from app.helper.llm import LLMHelper
|
||||
from app.helper.voice import VoiceHelper
|
||||
from app.log import logger
|
||||
from app.schemas import Notification, NotExistMediaInfo, CommingMessage
|
||||
@@ -169,7 +171,7 @@ class MessageChain(ChainBase):
|
||||
text: str,
|
||||
original_message_id: Optional[Union[str, int]] = None,
|
||||
original_chat_id: Optional[str] = None,
|
||||
images: Optional[List[str]] = None,
|
||||
images: Optional[List[CommingMessage.MessageImage]] = None,
|
||||
audio_refs: Optional[List[str]] = None,
|
||||
files: Optional[List[CommingMessage.MessageAttachment]] = None,
|
||||
) -> None:
|
||||
@@ -183,6 +185,8 @@ class MessageChain(ChainBase):
|
||||
user_cache: Dict[str, dict] = self.load_cache(self._cache_file) or {}
|
||||
|
||||
try:
|
||||
images = CommingMessage.MessageImage.normalize_list(images)
|
||||
|
||||
# 识别语音为文本
|
||||
reply_with_voice = bool(audio_refs)
|
||||
if audio_refs:
|
||||
@@ -1238,7 +1242,7 @@ class MessageChain(ChainBase):
|
||||
source: str,
|
||||
userid: Union[str, int],
|
||||
username: str,
|
||||
images: Optional[List[str]] = None,
|
||||
images: Optional[List[CommingMessage.MessageImage]] = None,
|
||||
files: Optional[List[CommingMessage.MessageAttachment]] = None,
|
||||
reply_with_voice: bool = False,
|
||||
) -> None:
|
||||
@@ -1259,6 +1263,8 @@ class MessageChain(ChainBase):
|
||||
)
|
||||
return
|
||||
|
||||
images = CommingMessage.MessageImage.normalize_list(images)
|
||||
|
||||
# 提取用户消息
|
||||
if text.lower().startswith("/ai"):
|
||||
user_message = text[3:].strip() # 移除 "/ai" 前缀(大小写不敏感)
|
||||
@@ -1282,7 +1288,8 @@ class MessageChain(ChainBase):
|
||||
|
||||
# 下载图片并转为base64
|
||||
original_images = images
|
||||
if images:
|
||||
all_files = list(files or [])
|
||||
if images and LLMHelper.supports_image_input():
|
||||
images = self._download_images_to_base64(images, channel, source)
|
||||
if original_images and not images and not user_message and not files:
|
||||
self.post_message(
|
||||
@@ -1295,14 +1302,29 @@ class MessageChain(ChainBase):
|
||||
)
|
||||
)
|
||||
return
|
||||
elif images:
|
||||
image_attachments = self._build_image_attachments(images)
|
||||
if original_images and not image_attachments and not user_message and not files:
|
||||
self.post_message(
|
||||
Notification(
|
||||
channel=channel,
|
||||
source=source,
|
||||
userid=userid,
|
||||
username=username,
|
||||
title="图片读取失败,请稍后重试",
|
||||
)
|
||||
)
|
||||
return
|
||||
all_files.extend(image_attachments)
|
||||
images = None
|
||||
|
||||
prepared_files = self._prepare_agent_files(
|
||||
session_id=session_id,
|
||||
files=files,
|
||||
files=all_files,
|
||||
channel=channel,
|
||||
source=source,
|
||||
)
|
||||
if files and not prepared_files and not user_message and not images:
|
||||
if all_files and not prepared_files and not user_message and not images:
|
||||
self.post_message(
|
||||
Notification(
|
||||
channel=channel,
|
||||
@@ -1452,15 +1474,20 @@ class MessageChain(ChainBase):
|
||||
return default
|
||||
|
||||
def _download_images_to_base64(
|
||||
self, images: List[str], channel: MessageChannel, source: str
|
||||
self,
|
||||
images: List[CommingMessage.MessageImage],
|
||||
channel: MessageChannel,
|
||||
source: str,
|
||||
) -> List[str]:
|
||||
"""
|
||||
下载图片并转为base64
|
||||
"""
|
||||
images = CommingMessage.MessageImage.normalize_list(images)
|
||||
if not images:
|
||||
return None
|
||||
base64_images = []
|
||||
for img in images:
|
||||
for image in images:
|
||||
img = image.ref
|
||||
try:
|
||||
if img.startswith("data:"):
|
||||
base64_images.append(img)
|
||||
@@ -1511,6 +1538,33 @@ class MessageChain(ChainBase):
|
||||
logger.error(f"下载图片失败: {img}, error: {e}")
|
||||
return base64_images if base64_images else None
|
||||
|
||||
def _build_image_attachments(
|
||||
self, images: List[CommingMessage.MessageImage]
|
||||
) -> List[CommingMessage.MessageAttachment]:
|
||||
"""
|
||||
将图片引用转换为附件描述,以便按文件方式交给 Agent 处理。
|
||||
"""
|
||||
images = CommingMessage.MessageImage.normalize_list(images)
|
||||
if not images:
|
||||
return []
|
||||
|
||||
attachments = []
|
||||
for index, image in enumerate(images, start=1):
|
||||
image_ref = image.ref
|
||||
if not image_ref:
|
||||
continue
|
||||
name = image.name or self._guess_image_attachment_name(image_ref, index)
|
||||
mime_type = image.mime_type or self._guess_image_mime_type(image_ref, name)
|
||||
attachments.append(
|
||||
CommingMessage.MessageAttachment(
|
||||
ref=image_ref,
|
||||
name=name,
|
||||
mime_type=mime_type,
|
||||
size=image.size,
|
||||
)
|
||||
)
|
||||
return attachments
|
||||
|
||||
def _prepare_agent_files(
|
||||
self,
|
||||
session_id: str,
|
||||
@@ -1570,15 +1624,31 @@ class MessageChain(ChainBase):
|
||||
"""
|
||||
if not file_ref:
|
||||
return None
|
||||
if file_ref.startswith("data:"):
|
||||
return self._decode_data_url_bytes(file_ref)
|
||||
if file_ref.startswith("tg://file_id/"):
|
||||
file_id = file_ref.replace("tg://file_id/", "", 1)
|
||||
return self.run_module(
|
||||
"download_telegram_file_bytes", file_id=file_id, source=source
|
||||
)
|
||||
if file_ref.startswith("tg://document_file_id/"):
|
||||
file_id = file_ref.replace("tg://document_file_id/", "", 1)
|
||||
return self.run_module(
|
||||
"download_telegram_file_bytes", file_id=file_id, source=source
|
||||
)
|
||||
if file_ref.startswith("wxwork://media_id/"):
|
||||
return self.run_module(
|
||||
"download_wechat_media_bytes", media_ref=file_ref, source=source
|
||||
)
|
||||
if file_ref.startswith("wxwork://file_media_id/"):
|
||||
return self.run_module(
|
||||
"download_wechat_media_bytes", media_ref=file_ref, source=source
|
||||
)
|
||||
if file_ref.startswith("wxbot://image/"):
|
||||
data_url = self.run_module(
|
||||
"download_wechat_image_to_data_url", image_ref=file_ref, source=source
|
||||
)
|
||||
return self._decode_data_url_bytes(data_url) if data_url else None
|
||||
if file_ref.startswith("wxbot://file/"):
|
||||
file_url = unquote(file_ref.replace("wxbot://file/", "", 1))
|
||||
resp = RequestUtils(timeout=30).get_res(file_url)
|
||||
@@ -1604,6 +1674,11 @@ class MessageChain(ChainBase):
|
||||
"download_synologychat_file_bytes", file_ref=file_ref, source=source
|
||||
)
|
||||
if file_ref.startswith("http"):
|
||||
if channel == MessageChannel.Slack:
|
||||
data_url = self.run_module(
|
||||
"download_slack_file_to_data_url", file_url=file_ref, source=source
|
||||
)
|
||||
return self._decode_data_url_bytes(data_url) if data_url else None
|
||||
resp = RequestUtils(timeout=30).get_res(file_ref)
|
||||
return resp.content if resp and resp.content else None
|
||||
logger.debug(
|
||||
@@ -1647,6 +1722,11 @@ class MessageChain(ChainBase):
|
||||
if "." not in name:
|
||||
mime = (mime_type or "").split(";", 1)[0].strip().lower()
|
||||
default_ext = {
|
||||
"image/jpeg": ".jpg",
|
||||
"image/png": ".png",
|
||||
"image/gif": ".gif",
|
||||
"image/webp": ".webp",
|
||||
"image/bmp": ".bmp",
|
||||
"application/json": ".json",
|
||||
"text/plain": ".txt",
|
||||
"text/markdown": ".md",
|
||||
@@ -1655,3 +1735,50 @@ class MessageChain(ChainBase):
|
||||
if default_ext:
|
||||
name = f"{name}{default_ext}"
|
||||
return name
|
||||
|
||||
@staticmethod
|
||||
def _guess_image_attachment_name(image_ref: str, index: int) -> str:
|
||||
"""
|
||||
根据图片引用推测附件名。
|
||||
"""
|
||||
if not image_ref:
|
||||
return f"image_{index}.jpg"
|
||||
if image_ref.startswith("data:"):
|
||||
mime_part = image_ref[5:].split(";", 1)[0].strip().lower()
|
||||
ext = mimetypes.guess_extension(mime_part) or ".jpg"
|
||||
return f"image_{index}{ext}"
|
||||
|
||||
parsed = urlparse(unquote(image_ref))
|
||||
name = Path(parsed.path).name if parsed.path else ""
|
||||
if name and "." in name:
|
||||
return name
|
||||
return f"image_{index}.jpg"
|
||||
|
||||
@staticmethod
|
||||
def _guess_image_mime_type(image_ref: str, filename: Optional[str]) -> str:
|
||||
"""
|
||||
根据图片引用或文件名推测 MIME 类型。
|
||||
"""
|
||||
if image_ref and image_ref.startswith("data:"):
|
||||
mime = image_ref[5:].split(";", 1)[0].strip().lower()
|
||||
return mime or "image/jpeg"
|
||||
guessed, _ = mimetypes.guess_type(filename or "")
|
||||
if guessed and guessed.startswith("image/"):
|
||||
return guessed
|
||||
return "image/jpeg"
|
||||
|
||||
@staticmethod
|
||||
def _decode_data_url_bytes(data_url: Optional[str]) -> Optional[bytes]:
|
||||
"""
|
||||
将 data URL 解码为原始字节。
|
||||
"""
|
||||
if not data_url or not data_url.startswith("data:"):
|
||||
return None
|
||||
try:
|
||||
_, payload = data_url.split(",", 1)
|
||||
except ValueError:
|
||||
return None
|
||||
try:
|
||||
return base64.b64decode(payload)
|
||||
except Exception:
|
||||
return None
|
||||
|
||||
@@ -494,6 +494,8 @@ class ConfigModel(BaseModel):
|
||||
LLM_PROVIDER: str = "deepseek"
|
||||
# LLM模型名称
|
||||
LLM_MODEL: str = "deepseek-chat"
|
||||
# LLM是否支持图片输入,开启后消息图片会按多模态输入发送给模型
|
||||
LLM_SUPPORT_IMAGE_INPUT: bool = False
|
||||
# LLM API密钥
|
||||
LLM_API_KEY: Optional[str] = None
|
||||
# LLM基础URL(用于自定义API端点)
|
||||
|
||||
@@ -59,6 +59,13 @@ def _get_httpx_proxy_key() -> str:
|
||||
class LLMHelper:
|
||||
"""LLM模型相关辅助功能"""
|
||||
|
||||
@staticmethod
|
||||
def supports_image_input() -> bool:
|
||||
"""
|
||||
判断当前模型是否启用了图片输入能力。
|
||||
"""
|
||||
return bool(settings.LLM_SUPPORT_IMAGE_INPUT)
|
||||
|
||||
@staticmethod
|
||||
def get_llm(streaming: bool = False):
|
||||
"""
|
||||
|
||||
@@ -181,7 +181,9 @@ class DiscordModule(_ModuleBase, _MessageBase[Discord]):
|
||||
return None
|
||||
|
||||
@staticmethod
|
||||
def _extract_images(msg_json: dict) -> Optional[List[str]]:
|
||||
def _extract_images(
|
||||
msg_json: dict,
|
||||
) -> Optional[List[CommingMessage.MessageImage]]:
|
||||
"""
|
||||
从Discord消息中提取图片URL
|
||||
"""
|
||||
@@ -200,7 +202,14 @@ class DiscordModule(_ModuleBase, _MessageBase[Discord]):
|
||||
or content_type.startswith("image/")
|
||||
or filename.endswith(DiscordModule._IMAGE_SUFFIXES)
|
||||
):
|
||||
images.append(url)
|
||||
images.append(
|
||||
CommingMessage.MessageImage(
|
||||
ref=url,
|
||||
name=attachment.get("filename"),
|
||||
mime_type=attachment.get("content_type"),
|
||||
size=attachment.get("size"),
|
||||
)
|
||||
)
|
||||
return images if images else None
|
||||
|
||||
@classmethod
|
||||
|
||||
@@ -155,8 +155,10 @@ class QQBotModule(_ModuleBase, _MessageBase[QQBot]):
|
||||
return None
|
||||
|
||||
@classmethod
|
||||
def _extract_images(cls, msg_body: dict) -> Optional[List[str]]:
|
||||
images: List[str] = []
|
||||
def _extract_images(
|
||||
cls, msg_body: dict
|
||||
) -> Optional[List[CommingMessage.MessageImage]]:
|
||||
images: List[CommingMessage.MessageImage] = []
|
||||
attachments = msg_body.get("attachments") or []
|
||||
if isinstance(attachments, list):
|
||||
for attachment in attachments:
|
||||
@@ -176,26 +178,42 @@ class QQBotModule(_ModuleBase, _MessageBase[QQBot]):
|
||||
or ""
|
||||
).lower()
|
||||
if content_type.startswith("image/") or filename.endswith(cls._IMAGE_SUFFIXES):
|
||||
images.append(url)
|
||||
images.append(
|
||||
CommingMessage.MessageImage(
|
||||
ref=url,
|
||||
name=attachment.get("filename") or attachment.get("name"),
|
||||
mime_type=attachment.get("content_type")
|
||||
or attachment.get("mime_type"),
|
||||
size=attachment.get("size"),
|
||||
)
|
||||
)
|
||||
|
||||
for key in ("image", "image_url", "pic_url"):
|
||||
value = msg_body.get(key)
|
||||
if isinstance(value, str) and value.startswith("http"):
|
||||
images.append(value)
|
||||
images.append(CommingMessage.MessageImage(ref=value))
|
||||
|
||||
extra_images = msg_body.get("images")
|
||||
if isinstance(extra_images, list):
|
||||
for item in extra_images:
|
||||
if isinstance(item, str) and item.startswith("http"):
|
||||
images.append(item)
|
||||
images.append(CommingMessage.MessageImage(ref=item))
|
||||
elif isinstance(item, dict):
|
||||
url = item.get("url") or item.get("image_url")
|
||||
if isinstance(url, str) and url.startswith("http"):
|
||||
images.append(url)
|
||||
images.append(
|
||||
CommingMessage.MessageImage(
|
||||
ref=url,
|
||||
name=item.get("name") or item.get("filename"),
|
||||
mime_type=item.get("content_type")
|
||||
or item.get("mime_type"),
|
||||
size=item.get("size"),
|
||||
)
|
||||
)
|
||||
|
||||
deduped = []
|
||||
for image in images:
|
||||
if image not in deduped:
|
||||
if image.ref not in [item.ref for item in deduped]:
|
||||
deduped.append(image)
|
||||
return deduped or None
|
||||
|
||||
|
||||
@@ -301,7 +301,9 @@ class SlackModule(_ModuleBase, _MessageBase[Slack]):
|
||||
return None
|
||||
|
||||
@staticmethod
|
||||
def _extract_images(msg_json: dict) -> Optional[List[str]]:
|
||||
def _extract_images(
|
||||
msg_json: dict,
|
||||
) -> Optional[List[CommingMessage.MessageImage]]:
|
||||
"""
|
||||
从Slack消息中提取图片URL
|
||||
"""
|
||||
@@ -320,7 +322,14 @@ class SlackModule(_ModuleBase, _MessageBase[Slack]):
|
||||
):
|
||||
url = file.get("url_private") or file.get("url_private_download")
|
||||
if url:
|
||||
images.append(url)
|
||||
images.append(
|
||||
CommingMessage.MessageImage(
|
||||
ref=url,
|
||||
name=file.get("name") or file.get("title"),
|
||||
mime_type=file.get("mimetype"),
|
||||
size=file.get("size"),
|
||||
)
|
||||
)
|
||||
return images if images else None
|
||||
|
||||
@classmethod
|
||||
|
||||
@@ -141,12 +141,14 @@ class SynologyChatModule(_ModuleBase, _MessageBase[SynologyChat]):
|
||||
return None
|
||||
|
||||
@classmethod
|
||||
def _extract_images(cls, message: dict) -> Optional[List[str]]:
|
||||
def _extract_images(
|
||||
cls, message: dict
|
||||
) -> Optional[List[CommingMessage.MessageImage]]:
|
||||
images = []
|
||||
for key in ("file_url", "image_url", "pic_url"):
|
||||
value = message.get(key)
|
||||
if isinstance(value, str) and cls._looks_like_image(value):
|
||||
images.append(value)
|
||||
images.append(CommingMessage.MessageImage(ref=value))
|
||||
|
||||
for key in ("attachments", "files"):
|
||||
raw_value = message.get(key)
|
||||
@@ -159,15 +161,23 @@ class SynologyChatModule(_ModuleBase, _MessageBase[SynologyChat]):
|
||||
items = parsed if isinstance(parsed, list) else [parsed]
|
||||
for item in items:
|
||||
if isinstance(item, str) and cls._looks_like_image(item):
|
||||
images.append(item)
|
||||
images.append(CommingMessage.MessageImage(ref=item))
|
||||
elif isinstance(item, dict):
|
||||
url = item.get("url") or item.get("file_url") or item.get("image_url")
|
||||
if isinstance(url, str) and cls._looks_like_image(url):
|
||||
images.append(url)
|
||||
images.append(
|
||||
CommingMessage.MessageImage(
|
||||
ref=url,
|
||||
name=item.get("name") or item.get("filename"),
|
||||
mime_type=item.get("content_type")
|
||||
or item.get("mime_type"),
|
||||
size=item.get("size"),
|
||||
)
|
||||
)
|
||||
|
||||
deduped = []
|
||||
for image in images:
|
||||
if image not in deduped:
|
||||
if image.ref not in [item.ref for item in deduped]:
|
||||
deduped.append(image)
|
||||
return deduped or None
|
||||
|
||||
|
||||
@@ -273,7 +273,7 @@ class TelegramModule(_ModuleBase, _MessageBase[Telegram]):
|
||||
return None
|
||||
|
||||
@staticmethod
|
||||
def _extract_images(msg: dict) -> Optional[List[str]]:
|
||||
def _extract_images(msg: dict) -> Optional[List[CommingMessage.MessageImage]]:
|
||||
"""
|
||||
从Telegram消息中提取图片file_id
|
||||
"""
|
||||
@@ -283,14 +283,27 @@ class TelegramModule(_ModuleBase, _MessageBase[Telegram]):
|
||||
largest_photo = photo[-1]
|
||||
file_id = largest_photo.get("file_id")
|
||||
if file_id:
|
||||
images.append(f"tg://file_id/{file_id}")
|
||||
images.append(
|
||||
CommingMessage.MessageImage(
|
||||
ref=f"tg://file_id/{file_id}",
|
||||
mime_type="image/jpeg",
|
||||
size=largest_photo.get("file_size"),
|
||||
)
|
||||
)
|
||||
|
||||
document = msg.get("document")
|
||||
if document:
|
||||
file_id = document.get("file_id")
|
||||
mime_type = document.get("mime_type", "")
|
||||
if file_id and mime_type.startswith("image/"):
|
||||
images.append(f"tg://file_id/{file_id}")
|
||||
images.append(
|
||||
CommingMessage.MessageImage(
|
||||
ref=f"tg://file_id/{file_id}",
|
||||
name=document.get("file_name"),
|
||||
mime_type=document.get("mime_type"),
|
||||
size=document.get("file_size"),
|
||||
)
|
||||
)
|
||||
|
||||
return images if images else None
|
||||
|
||||
|
||||
@@ -162,7 +162,9 @@ class VoceChatModule(_ModuleBase, _MessageBase[VoceChat]):
|
||||
return None
|
||||
|
||||
@classmethod
|
||||
def _extract_images(cls, detail: dict) -> Optional[List[str]]:
|
||||
def _extract_images(
|
||||
cls, detail: dict
|
||||
) -> Optional[List[CommingMessage.MessageImage]]:
|
||||
content_type = detail.get("content_type") or ""
|
||||
if content_type != "vocechat/file":
|
||||
return None
|
||||
@@ -194,9 +196,23 @@ class VoceChatModule(_ModuleBase, _MessageBase[VoceChat]):
|
||||
if not is_image:
|
||||
return None
|
||||
if isinstance(direct_url, str) and direct_url.startswith("http"):
|
||||
return [direct_url]
|
||||
return [
|
||||
CommingMessage.MessageImage(
|
||||
ref=direct_url,
|
||||
name=properties.get("name") or properties.get("filename"),
|
||||
mime_type=mime_type or None,
|
||||
size=properties.get("size"),
|
||||
)
|
||||
]
|
||||
if isinstance(file_path, str) and file_path:
|
||||
return [f"vocechat://file/{quote(file_path, safe='')}"]
|
||||
return [
|
||||
CommingMessage.MessageImage(
|
||||
ref=f"vocechat://file/{quote(file_path, safe='')}",
|
||||
name=properties.get("name") or properties.get("filename"),
|
||||
mime_type=mime_type or None,
|
||||
size=properties.get("size"),
|
||||
)
|
||||
]
|
||||
return None
|
||||
|
||||
@classmethod
|
||||
|
||||
@@ -189,9 +189,9 @@ class WechatModule(_ModuleBase, _MessageBase[WeChat]):
|
||||
media_id = DomUtils.tag_value(root_node, "MediaId")
|
||||
pic_url = DomUtils.tag_value(root_node, "PicUrl")
|
||||
if media_id:
|
||||
images = [f"wxwork://media_id/{media_id}"]
|
||||
images = [CommingMessage.MessageImage(ref=f"wxwork://media_id/{media_id}")]
|
||||
elif pic_url:
|
||||
images = [pic_url]
|
||||
images = [CommingMessage.MessageImage(ref=pic_url)]
|
||||
logger.info(
|
||||
f"收到来自 {client_config.name} 的微信图片消息:userid={user_id}, images={len(images) if images else 0}"
|
||||
)
|
||||
|
||||
@@ -16,6 +16,7 @@ from app.core.config import settings
|
||||
from app.core.context import MediaInfo, Context
|
||||
from app.core.metainfo import MetaInfo
|
||||
from app.log import logger
|
||||
from app.schemas import CommingMessage
|
||||
from app.utils.http import RequestUtils
|
||||
from app.utils.string import StringUtils
|
||||
|
||||
@@ -359,27 +360,50 @@ class WeChatBot:
|
||||
return f"wxbot://image/{encoded}"
|
||||
|
||||
@classmethod
|
||||
def _extract_images_from_body(cls, body: dict) -> Optional[List[str]]:
|
||||
images: List[str] = []
|
||||
def _extract_images_from_body(
|
||||
cls, body: dict
|
||||
) -> Optional[List["CommingMessage.MessageImage"]]:
|
||||
images: List["CommingMessage.MessageImage"] = []
|
||||
msgtype = body.get("msgtype")
|
||||
|
||||
if msgtype == "image":
|
||||
image_ref = cls._build_image_ref(body.get("image") or {})
|
||||
image_payload = body.get("image") or {}
|
||||
image_ref = cls._build_image_ref(image_payload)
|
||||
if image_ref:
|
||||
images.append(image_ref)
|
||||
images.append(
|
||||
CommingMessage.MessageImage(
|
||||
ref=image_ref,
|
||||
mime_type=image_payload.get("mime_type")
|
||||
or image_payload.get("content_type"),
|
||||
)
|
||||
)
|
||||
elif msgtype == "mixed":
|
||||
for item in (body.get("mixed") or {}).get("msg_item") or []:
|
||||
if item.get("msgtype") != "image":
|
||||
continue
|
||||
image_ref = cls._build_image_ref(item.get("image") or {})
|
||||
image_payload = item.get("image") or {}
|
||||
image_ref = cls._build_image_ref(image_payload)
|
||||
if image_ref:
|
||||
images.append(image_ref)
|
||||
images.append(
|
||||
CommingMessage.MessageImage(
|
||||
ref=image_ref,
|
||||
mime_type=image_payload.get("mime_type")
|
||||
or image_payload.get("content_type"),
|
||||
)
|
||||
)
|
||||
|
||||
quote = body.get("quote") or {}
|
||||
if not images and quote.get("msgtype") == "image":
|
||||
image_ref = cls._build_image_ref(quote.get("image") or {})
|
||||
image_payload = quote.get("image") or {}
|
||||
image_ref = cls._build_image_ref(image_payload)
|
||||
if image_ref:
|
||||
images.append(image_ref)
|
||||
images.append(
|
||||
CommingMessage.MessageImage(
|
||||
ref=image_ref,
|
||||
mime_type=image_payload.get("mime_type")
|
||||
or image_payload.get("content_type"),
|
||||
)
|
||||
)
|
||||
|
||||
return images or None
|
||||
|
||||
|
||||
@@ -1,8 +1,8 @@
|
||||
from dataclasses import dataclass
|
||||
from enum import Enum
|
||||
from typing import Optional, Union, List, Dict, Set
|
||||
from typing import Optional, Union, List, Dict, Set, Any
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
from pydantic import BaseModel, Field, field_validator
|
||||
|
||||
from app.schemas.types import ContentType, NotificationType, MessageChannel
|
||||
|
||||
@@ -29,6 +29,61 @@ class CommingMessage(BaseModel):
|
||||
外来消息
|
||||
"""
|
||||
|
||||
class MessageImage(BaseModel):
|
||||
"""
|
||||
外来消息图片
|
||||
"""
|
||||
|
||||
ref: str
|
||||
name: Optional[str] = None
|
||||
mime_type: Optional[str] = None
|
||||
size: Optional[int] = None
|
||||
|
||||
@classmethod
|
||||
def from_value(cls, value: Any) -> Optional["CommingMessage.MessageImage"]:
|
||||
if value is None:
|
||||
return None
|
||||
if isinstance(value, cls):
|
||||
return value
|
||||
if isinstance(value, str):
|
||||
return cls(ref=value)
|
||||
if isinstance(value, dict):
|
||||
ref = (
|
||||
value.get("ref")
|
||||
or value.get("url")
|
||||
or value.get("image_url")
|
||||
or value.get("file_url")
|
||||
)
|
||||
if not ref:
|
||||
return None
|
||||
size = value.get("size")
|
||||
try:
|
||||
size = int(size) if size is not None else None
|
||||
except (TypeError, ValueError):
|
||||
size = None
|
||||
return cls(
|
||||
ref=ref,
|
||||
name=value.get("name") or value.get("filename"),
|
||||
mime_type=value.get("mime_type") or value.get("content_type"),
|
||||
size=size,
|
||||
)
|
||||
return None
|
||||
|
||||
@classmethod
|
||||
def normalize_list(
|
||||
cls, values: Optional[Any]
|
||||
) -> Optional[List["CommingMessage.MessageImage"]]:
|
||||
if not values:
|
||||
return None
|
||||
if not isinstance(values, list):
|
||||
values = [values]
|
||||
normalized = []
|
||||
for value in values:
|
||||
item = cls.from_value(value)
|
||||
if item:
|
||||
normalized.append(item)
|
||||
return normalized or None
|
||||
|
||||
class MessageAttachment(BaseModel):
|
||||
"""
|
||||
外来消息附件(非图片/非语音)
|
||||
@@ -64,12 +119,19 @@ class CommingMessage(BaseModel):
|
||||
# 完整的回调查询信息(原始数据)
|
||||
callback_query: Optional[Dict] = None
|
||||
# 图片列表(图片URL或file_id)
|
||||
images: Optional[List[str]] = None
|
||||
images: Optional[List[MessageImage]] = None
|
||||
# 语音/音频引用列表
|
||||
audio_refs: Optional[List[str]] = None
|
||||
# 文件附件列表
|
||||
files: Optional[List[MessageAttachment]] = None
|
||||
|
||||
@field_validator("images", mode="before")
|
||||
@classmethod
|
||||
def _normalize_images(
|
||||
cls, value: Any
|
||||
) -> Optional[List["CommingMessage.MessageImage"]]:
|
||||
return cls.MessageImage.normalize_list(value)
|
||||
|
||||
def to_dict(self):
|
||||
"""
|
||||
转换为字典
|
||||
|
||||
@@ -14,6 +14,7 @@ from app.agent.tools.impl.send_local_file import SendLocalFileInput
|
||||
from app.agent import MoviePilotAgent, AgentChain
|
||||
from app.chain.message import MessageChain
|
||||
from app.core.config import settings
|
||||
from app.helper.llm import LLMHelper
|
||||
from app.helper.voice import VoiceHelper
|
||||
from app.modules.discord import DiscordModule
|
||||
from app.modules.qqbot import QQBotModule
|
||||
@@ -46,14 +47,18 @@ class AgentImageSupportTest(unittest.TestCase):
|
||||
images = TelegramModule._extract_images(
|
||||
{
|
||||
"photo": [{"file_id": "small"}, {"file_id": "large"}],
|
||||
"document": {"file_id": "doc-image", "mime_type": "image/png"},
|
||||
"document": {
|
||||
"file_id": "doc-image",
|
||||
"mime_type": "image/png",
|
||||
"file_name": "poster.png",
|
||||
},
|
||||
}
|
||||
)
|
||||
|
||||
self.assertEqual(
|
||||
images,
|
||||
["tg://file_id/large", "tg://file_id/doc-image"],
|
||||
)
|
||||
self.assertEqual([image.ref for image in images], ["tg://file_id/large", "tg://file_id/doc-image"])
|
||||
self.assertEqual(images[0].mime_type, "image/jpeg")
|
||||
self.assertEqual(images[1].mime_type, "image/png")
|
||||
self.assertEqual(images[1].name, "poster.png")
|
||||
|
||||
def test_telegram_message_parser_accepts_double_encoded_body(self):
|
||||
module = TelegramModule()
|
||||
@@ -83,7 +88,7 @@ class AgentImageSupportTest(unittest.TestCase):
|
||||
)
|
||||
|
||||
self.assertIsNotNone(message)
|
||||
self.assertEqual(message.images, ["tg://file_id/large"])
|
||||
self.assertEqual([image.ref for image in message.images], ["tg://file_id/large"])
|
||||
|
||||
def test_telegram_forward_payload_uses_dict_not_json_string(self):
|
||||
payload = Telegram._serialize_update_payload(
|
||||
@@ -143,7 +148,7 @@ class AgentImageSupportTest(unittest.TestCase):
|
||||
|
||||
handle_kwargs = handle_message.call_args.kwargs
|
||||
self.assertEqual(handle_kwargs["text"], "")
|
||||
self.assertEqual(handle_kwargs["images"], ["tg://file_id/image-1"])
|
||||
self.assertEqual([image.ref for image in handle_kwargs["images"]], ["tg://file_id/image-1"])
|
||||
|
||||
def test_process_allows_audio_only_message(self):
|
||||
chain = MessageChain()
|
||||
@@ -370,6 +375,61 @@ class AgentImageSupportTest(unittest.TestCase):
|
||||
self.assertEqual(payload["message"], "帮我总结这个文件")
|
||||
self.assertEqual(payload["files"][0]["local_path"], "/tmp/report.txt")
|
||||
|
||||
def test_llm_supports_image_input_respects_explicit_override(self):
|
||||
with patch.object(settings, "LLM_SUPPORT_IMAGE_INPUT", False):
|
||||
self.assertFalse(LLMHelper.supports_image_input())
|
||||
|
||||
def test_llm_supports_image_input_uses_boolean_setting(self):
|
||||
with patch.object(settings, "LLM_SUPPORT_IMAGE_INPUT", True):
|
||||
self.assertTrue(LLMHelper.supports_image_input())
|
||||
|
||||
with patch.object(settings, "LLM_SUPPORT_IMAGE_INPUT", False):
|
||||
self.assertFalse(LLMHelper.supports_image_input())
|
||||
|
||||
def test_handle_ai_message_routes_images_to_files_when_image_input_disabled(self):
|
||||
chain = MessageChain()
|
||||
|
||||
with patch.object(settings, "AI_AGENT_ENABLE", True), patch.object(
|
||||
settings, "LLM_SUPPORT_IMAGE_INPUT", False
|
||||
), patch.object(chain, "_get_or_create_session_id", return_value="session-1"), patch.object(
|
||||
chain, "_download_images_to_base64"
|
||||
) as download_images, patch.object(
|
||||
chain,
|
||||
"_prepare_agent_files",
|
||||
return_value=[
|
||||
{
|
||||
"name": "image_1.jpg",
|
||||
"mime_type": "image/jpeg",
|
||||
"local_path": "/tmp/image_1.jpg",
|
||||
"status": "ready",
|
||||
}
|
||||
],
|
||||
) as prepare_files, patch(
|
||||
"app.chain.message.agent_manager.process_message", new_callable=AsyncMock
|
||||
) as process_message, patch(
|
||||
"app.chain.message.asyncio.run_coroutine_threadsafe"
|
||||
) as run_coroutine_threadsafe:
|
||||
chain._handle_ai_message(
|
||||
text="/ai 帮我看看这张图",
|
||||
channel=MessageChannel.Telegram,
|
||||
source="telegram-test",
|
||||
userid="10001",
|
||||
username="tester",
|
||||
images=["tg://file_id/image-1"],
|
||||
)
|
||||
|
||||
download_images.assert_not_called()
|
||||
prepare_files.assert_called_once()
|
||||
attachments = prepare_files.call_args.kwargs["files"]
|
||||
self.assertEqual(attachments[0].ref, "tg://file_id/image-1")
|
||||
self.assertEqual(attachments[0].mime_type, "image/jpeg")
|
||||
run_coroutine_threadsafe.assert_called_once()
|
||||
self.assertEqual(process_message.call_args.kwargs["images"], None)
|
||||
self.assertEqual(
|
||||
process_message.call_args.kwargs["files"][0]["local_path"],
|
||||
"/tmp/image_1.jpg",
|
||||
)
|
||||
|
||||
def test_slack_images_use_authenticated_data_url_download(self):
|
||||
chain = MessageChain()
|
||||
|
||||
@@ -460,7 +520,8 @@ class AgentImageSupportTest(unittest.TestCase):
|
||||
}
|
||||
)
|
||||
|
||||
self.assertEqual(images, ["https://cdn.discordapp.com/test.png"])
|
||||
self.assertEqual([image.ref for image in images], ["https://cdn.discordapp.com/test.png"])
|
||||
self.assertEqual(images[0].mime_type, "image/png")
|
||||
|
||||
def test_discord_extract_audio_refs_supports_attachment_content_type(self):
|
||||
audio_refs = DiscordModule._extract_audio_refs(
|
||||
@@ -587,7 +648,7 @@ class AgentImageSupportTest(unittest.TestCase):
|
||||
)
|
||||
|
||||
self.assertIsNotNone(message)
|
||||
self.assertEqual(message.images, ["wxwork://media_id/media-1"])
|
||||
self.assertEqual([image.ref for image in message.images], ["wxwork://media_id/media-1"])
|
||||
|
||||
def test_wechat_message_parser_extracts_file_media_id(self):
|
||||
module = WechatModule()
|
||||
@@ -661,7 +722,7 @@ class AgentImageSupportTest(unittest.TestCase):
|
||||
)
|
||||
|
||||
self.assertIsNotNone(message)
|
||||
self.assertTrue(message.images[0].startswith("wxbot://image/"))
|
||||
self.assertTrue(message.images[0].ref.startswith("wxbot://image/"))
|
||||
|
||||
def test_wechat_bot_handles_image_only_callback(self):
|
||||
bot = WeChatBot.__new__(WeChatBot)
|
||||
@@ -718,9 +779,10 @@ class AgentImageSupportTest(unittest.TestCase):
|
||||
|
||||
self.assertIsNotNone(message)
|
||||
self.assertEqual(
|
||||
message.images,
|
||||
[image.ref for image in message.images],
|
||||
["vocechat://file/%2Fuploads%2Fposter.png"],
|
||||
)
|
||||
self.assertEqual(message.images[0].mime_type, "image/png")
|
||||
|
||||
def test_vocechat_message_parser_extracts_audio_file_payload(self):
|
||||
module = VoceChatModule()
|
||||
@@ -901,7 +963,8 @@ class AgentImageSupportTest(unittest.TestCase):
|
||||
)
|
||||
|
||||
self.assertIsNotNone(message)
|
||||
self.assertEqual(message.images, ["https://example.com/qq-image.png"])
|
||||
self.assertEqual([image.ref for image in message.images], ["https://example.com/qq-image.png"])
|
||||
self.assertEqual(message.images[0].mime_type, "image/png")
|
||||
|
||||
def test_qq_message_parser_accepts_audio_only_attachment(self):
|
||||
module = QQBotModule()
|
||||
@@ -959,7 +1022,7 @@ class AgentImageSupportTest(unittest.TestCase):
|
||||
)
|
||||
|
||||
self.assertIsNotNone(message)
|
||||
self.assertEqual(message.images, ["https://example.com/image.png"])
|
||||
self.assertEqual([image.ref for image in message.images], ["https://example.com/image.png"])
|
||||
|
||||
def test_synology_message_parser_accepts_audio_only_form(self):
|
||||
module = SynologyChatModule()
|
||||
|
||||
Reference in New Issue
Block a user