feat(agent): support file attachments and local file replies

This commit is contained in:
jxxghp
2026-04-14 15:22:01 +08:00
parent 81828948dd
commit 7a5e513f25
18 changed files with 1268 additions and 84 deletions

View File

@@ -1,4 +1,5 @@
import asyncio
import json
import re
import traceback
import uuid
@@ -281,13 +282,19 @@ class MoviePilotAgent:
logger.error(f"创建 Agent 失败: {e}")
raise e
async def process(self, message: str, images: List[str] = None) -> str:
async def process(
self,
message: str,
images: List[str] = None,
files: Optional[List[dict]] = None,
) -> str:
"""
处理用户消息,流式推理并返回 Agent 回复
"""
try:
logger.info(
f"Agent推理: session_id={self.session_id}, input={message}, images={len(images) if images else 0}"
f"Agent推理: session_id={self.session_id}, input={message}, "
f"images={len(images) if images else 0}, files={len(files) if files else 0}"
)
self._tool_context = {
"incoming_voice": self.reply_with_voice,
@@ -300,16 +307,24 @@ class MoviePilotAgent:
session_id=self.session_id, user_id=self.user_id
)
# 构建用户消息内容
if images:
content = []
if message:
content.append({"type": "text", "text": message})
for img in images:
content.append({"type": "image_url", "image_url": {"url": img}})
messages.append(HumanMessage(content=content))
else:
messages.append(HumanMessage(content=message))
# 构建结构化用户消息内容
request_payload = {
"message": message or "",
"images": [
{"index": index + 1, "type": "image"}
for index, _ in enumerate(images or [])
],
"files": files or [],
}
content = [
{
"type": "text",
"text": json.dumps(request_payload, ensure_ascii=False, indent=2),
}
]
for img in images or []:
content.append({"type": "image_url", "image_url": {"url": img}})
messages.append(HumanMessage(content=content))
# 执行推理
await self._execute_agent(messages)
@@ -544,6 +559,7 @@ class _MessageTask:
user_id: str
message: str
images: Optional[List[str]] = None
files: Optional[List[dict]] = None
channel: Optional[str] = None
source: Optional[str] = None
username: Optional[str] = None
@@ -610,6 +626,7 @@ class AgentManager:
user_id: str,
message: str,
images: List[str] = None,
files: Optional[List[dict]] = None,
channel: str = None,
source: str = None,
username: str = None,
@@ -624,6 +641,7 @@ class AgentManager:
user_id=user_id,
message=message,
images=images,
files=files,
channel=channel,
source=source,
username=username,
@@ -727,7 +745,7 @@ class AgentManager:
agent.username = task.username
agent.reply_with_voice = task.reply_with_voice
return await agent.process(task.message, images=task.images)
return await agent.process(task.message, images=task.images, files=task.files)
async def stop_current_task(self, session_id: str):
"""

View File

@@ -10,6 +10,7 @@ Core Capabilities:
3. Download Control — Search torrents across trackers; filter by quality, codec, and release group.
4. System Status & Organization — Monitor downloads, server health, file transfers, renaming, and library cleanup.
5. Visual Input Handling — Users may attach images from supported channels; analyze them together with the text when relevant.
6. File Context Handling — User messages may arrive as structured JSON. Treat the `message` field as the user's text. Non-image attachments appear in `files`; when `local_path` is present, use local file tools to inspect the uploaded file directly.
<communication>
{verbose_spec}
@@ -21,6 +22,7 @@ Core Capabilities:
- Include key details (year, rating, resolution) but do NOT over-explain.
- Do not stop for approval on read-only operations. Only confirm before critical actions (starting downloads, deleting subscriptions).
- If the current channel supports image sending and an image would materially help, you may use the `send_message` tool with `image_url` to send it.
- If the current channel supports file sending and you need to return a local image/file for the user to download, use `send_local_file`.
- Voice replies: {voice_reply_spec}
- NOT a coding assistant. Do not offer code snippets.
- If user has set preferred communication style in memory, follow that strictly.

View File

@@ -30,6 +30,7 @@ from app.agent.tools.impl.search_torrents import SearchTorrentsTool
from app.agent.tools.impl.get_search_results import GetSearchResultsTool
from app.agent.tools.impl.search_web import SearchWebTool
from app.agent.tools.impl.send_message import SendMessageTool
from app.agent.tools.impl.send_local_file import SendLocalFileTool
from app.agent.tools.impl.send_voice_message import SendVoiceMessageTool
from app.agent.tools.impl.query_schedulers import QuerySchedulersTool
from app.agent.tools.impl.run_scheduler import RunSchedulerTool
@@ -119,6 +120,7 @@ class MoviePilotToolFactory:
QueryTransferHistoryTool,
TransferFileTool,
SendMessageTool,
SendLocalFileTool,
SendVoiceMessageTool,
QuerySchedulersTool,
RunSchedulerTool,

View File

@@ -0,0 +1,107 @@
"""发送本地附件工具。"""
from pathlib import Path
from typing import Optional, Type
from pydantic import BaseModel, Field, model_validator
from app.agent.tools.base import MoviePilotTool, ToolChain
from app.log import logger
from app.schemas import Notification, NotificationType
from app.schemas.message import ChannelCapabilityManager, ChannelCapability
from app.schemas.types import MessageChannel
class SendLocalFileInput(BaseModel):
    """发送本地附件工具输入。"""
    # Argument schema for the `send_local_file` agent tool.
    # NOTE(review): the class docstring above is left untouched on purpose —
    # pydantic may surface a model docstring in the generated JSON schema;
    # confirm before rewording it.

    # Required: the reason the agent wants to send this file.
    explanation: str = Field(
        ...,
        description="Clear explanation of why sending this local file helps the user",
    )
    # Required: where the file lives on the server.
    file_path: str = Field(
        ...,
        description="Absolute path to the local image or file to send to the user",
    )
    # Optional caption delivered together with the attachment.
    message: Optional[str] = Field(
        default=None,
        description="Optional message or caption to send with the attachment",
    )
    # Optional short title delivered together with the attachment.
    title: Optional[str] = Field(
        default=None,
        description="Optional short title shown together with the attachment",
    )
    # Optional download name; the tool falls back to the path's own file name.
    file_name: Optional[str] = Field(
        default=None,
        description="Optional override filename presented to the user when downloading",
    )

    @model_validator(mode="after")
    def validate_file_path(self):
        """Reject a payload whose `file_path` is empty."""
        if self.file_path:
            return self
        raise ValueError("file_path 不能为空")
class SendLocalFileTool(MoviePilotTool):
    # Agent tool that posts a file from the server filesystem back to the
    # conversation the request came from.
    name: str = "send_local_file"
    description: str = (
        "Send a local image or file from the server filesystem to the current user. "
        "Use this when you have generated or identified a local file the user should download."
    )
    args_schema: Type[BaseModel] = SendLocalFileInput
    # NOTE(review): with require_admin disabled, `file_path` (chosen by the
    # model) is only checked for existence in `run`; confirm that exposing
    # arbitrary readable server files to non-admin users is intended.
    require_admin: bool = False

    def get_tool_message(self, **kwargs) -> Optional[str]:
        """Build the progress text shown while the tool runs."""
        raw_path = kwargs.get("file_path", "")
        shown_name = Path(raw_path).name if raw_path else "未知文件"
        return f"正在发送本地附件: {shown_name}"

    async def run(
        self,
        file_path: str,
        message: Optional[str] = None,
        title: Optional[str] = None,
        file_name: Optional[str] = None,
        **kwargs,
    ) -> str:
        """
        Send `file_path` to the current user as an attachment.

        Returns a human-readable status string: an explanation when the
        session, channel, capability or file check fails, otherwise a
        confirmation once the notification has been posted.
        """
        # A reply needs both a channel and a source to be routed.
        if not (self._channel and self._source):
            return "当前不在可回传消息的会话中,无法发送附件"

        try:
            channel = MessageChannel(self._channel)
        except ValueError:
            return f"不支持的消息渠道: {self._channel}"

        can_send_files = ChannelCapabilityManager.supports_capability(
            channel, ChannelCapability.FILE_SENDING
        )
        if not can_send_files:
            return f"当前渠道 {channel.value} 暂不支持发送本地文件"

        # Expand "~" and anchor relative paths; absolute paths are used as given.
        target = Path(file_path).expanduser()
        if not target.is_absolute():
            target = target.resolve()
        if not (target.exists() and target.is_file()):
            return f"文件不存在: {target}"

        logger.info(
            "执行工具: %s, channel=%s, file=%s",
            self.name,
            channel.value,
            target,
        )
        notification = Notification(
            channel=channel,
            source=self._source,
            mtype=NotificationType.Agent,
            userid=self._user_id,
            username=self._username,
            title=title,
            text=message,
            file_path=str(target),
            file_name=file_name or target.name,
        )
        await ToolChain().async_post_message(notification)
        return "本地附件已发送"

View File

@@ -19,7 +19,7 @@ class SendMessageInput(BaseModel):
None,
description="The message content to send to the user (should be clear and informative)",
)
message_type: Optional[str] = Field(
title: Optional[str] = Field(
None,
description="Title of the message, a short summary of the message content",
)
@@ -30,8 +30,8 @@ class SendMessageInput(BaseModel):
@model_validator(mode="after")
def validate_payload(self):
if not self.message and not self.message_type and not self.image_url:
raise ValueError("message、message_type、image_url 至少需要提供一个")
if not self.message and not self.title and not self.image_url:
raise ValueError("message、title、image_url 至少需要提供一个")
return self
@@ -44,7 +44,7 @@ class SendMessageTool(MoviePilotTool):
def get_tool_message(self, **kwargs) -> Optional[str]:
"""根据消息参数生成友好的提示消息"""
message = kwargs.get("message", "") or ""
title = kwargs.get("message_type") or ""
title = kwargs.get("title") or ""
image_url = kwargs.get("image_url")
# 截断过长的消息
@@ -62,11 +62,11 @@ class SendMessageTool(MoviePilotTool):
async def run(
self,
message: Optional[str] = None,
message_type: Optional[str] = None,
title: Optional[str] = None,
image_url: Optional[str] = None,
**kwargs,
) -> str:
title = message_type or ("图片" if image_url and not message else "")
title = title or ("图片" if image_url and not message else "")
text = message or ""
logger.info(
f"执行工具: {self.name}, 参数: title={title}, message={text}, image_url={image_url}"

View File

@@ -2,8 +2,10 @@ import asyncio
import re
import time
from datetime import datetime, timedelta
from pathlib import Path
from typing import Any, Optional, Dict, Union, List
from urllib.parse import unquote
import uuid
import base64
@@ -135,7 +137,8 @@ class MessageChain(ChainBase):
text = str(info.text).strip() if info.text else ""
images = info.images
audio_refs = info.audio_refs
if not text and not images and not audio_refs:
files = info.files
if not text and not images and not audio_refs and not files:
logger.debug(f"未识别到消息内容::{body}{form}{args}")
return
@@ -154,6 +157,7 @@ class MessageChain(ChainBase):
original_chat_id=original_chat_id,
images=images,
audio_refs=audio_refs,
files=files,
)
def handle_message(
@@ -167,6 +171,7 @@ class MessageChain(ChainBase):
original_chat_id: Optional[str] = None,
images: Optional[List[str]] = None,
audio_refs: Optional[List[str]] = None,
files: Optional[List[CommingMessage.MessageAttachment]] = None,
) -> None:
"""
识别消息内容,执行操作
@@ -253,9 +258,12 @@ class MessageChain(ChainBase):
userid=userid,
username=username,
images=images,
files=files,
reply_with_voice=reply_with_voice,
)
elif settings.AI_AGENT_ENABLE and settings.AI_AGENT_GLOBAL:
elif settings.AI_AGENT_ENABLE and (
settings.AI_AGENT_GLOBAL or images or files
):
# 普通消息,全局智能体响应
self._handle_ai_message(
text=text,
@@ -264,6 +272,7 @@ class MessageChain(ChainBase):
userid=userid,
username=username,
images=images,
files=files,
reply_with_voice=reply_with_voice,
)
else:
@@ -1230,6 +1239,7 @@ class MessageChain(ChainBase):
userid: Union[str, int],
username: str,
images: Optional[List[str]] = None,
files: Optional[List[CommingMessage.MessageAttachment]] = None,
reply_with_voice: bool = False,
) -> None:
"""
@@ -1255,7 +1265,7 @@ class MessageChain(ChainBase):
else:
user_message = text.strip() # 按原消息处理
if not user_message and not images:
if not user_message and not images and not files:
self.post_message(
Notification(
channel=channel,
@@ -1274,7 +1284,7 @@ class MessageChain(ChainBase):
original_images = images
if images:
images = self._download_images_to_base64(images, channel, source)
if original_images and not images and not user_message:
if original_images and not images and not user_message and not files:
self.post_message(
Notification(
channel=channel,
@@ -1286,6 +1296,24 @@ class MessageChain(ChainBase):
)
return
prepared_files = self._prepare_agent_files(
session_id=session_id,
files=files,
channel=channel,
source=source,
)
if files and not prepared_files and not user_message and not images:
self.post_message(
Notification(
channel=channel,
source=source,
userid=userid,
username=username,
title="文件读取失败,请稍后重试",
)
)
return
# 在事件循环中处理
asyncio.run_coroutine_threadsafe(
agent_manager.process_message(
@@ -1293,6 +1321,7 @@ class MessageChain(ChainBase):
user_id=str(userid),
message=user_message,
images=images,
files=prepared_files,
channel=channel.value if channel else None,
source=source,
username=username,
@@ -1481,3 +1510,148 @@ class MessageChain(ChainBase):
except Exception as e:
logger.error(f"下载图片失败: {img}, error: {e}")
return base64_images if base64_images else None
def _prepare_agent_files(
    self,
    session_id: str,
    files: Optional[List[CommingMessage.MessageAttachment]],
    channel: MessageChannel,
    source: str,
) -> Optional[List[dict]]:
    """
    Download the user's attachments, store them locally and describe them for the Agent.

    Each attachment yields one dict with `name`, `mime_type`, `size`, `ref`
    and `status`. When the bytes were fetched and written to disk the entry
    also carries `local_path` and `status` is "ready"; otherwise `status`
    stays "download_failed", plus an `error` text if an exception was raised.

    :param session_id: session whose upload directory receives the files
    :param files: attachments parsed from the incoming message
    :param channel: channel the message arrived on
    :param source: configured source name, forwarded to the downloader
    :return: one entry per attachment, or None when there are no attachments
    """
    if not files:
        return None

    entries = []
    for item in files:
        entry = {
            "name": item.name,
            "mime_type": item.mime_type,
            "size": item.size,
            "ref": item.ref,
            # Pessimistic default; only flipped once the file is on disk.
            "status": "download_failed",
        }
        try:
            raw_bytes = self._download_message_file_bytes(
                file_ref=item.ref,
                channel=channel,
                source=source,
            )
            if raw_bytes:
                saved_to = self._save_agent_attachment(
                    session_id=session_id,
                    filename=item.name,
                    content=raw_bytes,
                    mime_type=item.mime_type,
                )
                entry["local_path"] = str(saved_to)
                entry["status"] = "ready"
        except Exception as err:
            logger.error(f"准备文件上下文失败: {item.ref}, error: {err}")
            entry["error"] = str(err)
        # Failed attachments are reported too, so the Agent knows about them.
        entries.append(entry)

    return entries or None
def _download_message_file_bytes(
    self, file_ref: str, channel: MessageChannel, source: str
) -> Optional[bytes]:
    """
    Fetch the raw bytes behind an attachment reference.

    The reference prefix selects the transport: channel-specific schemes are
    delegated to the matching module method through `run_module`, while
    `wxbot://file/` references and plain http(s) URLs are fetched directly.

    :param file_ref: attachment reference produced by a channel module
    :param channel: originating channel (only used for the debug log)
    :param source: configured source name passed on to the module
    :return: downloaded bytes, or None for an empty/unsupported reference
             or an empty response
    """
    if not file_ref:
        return None

    # Telegram: the module expects the bare file id, not the full reference.
    telegram_prefix = "tg://document_file_id/"
    if file_ref.startswith(telegram_prefix):
        return self.run_module(
            "download_telegram_file_bytes",
            file_id=file_ref.replace(telegram_prefix, "", 1),
            source=source,
        )

    # WeChat bot: the reference wraps a URL-quoted download link.
    wxbot_prefix = "wxbot://file/"
    if file_ref.startswith(wxbot_prefix):
        file_url = unquote(file_ref.replace(wxbot_prefix, "", 1))
        resp = RequestUtils(timeout=30).get_res(file_url)
        return resp.content if resp and resp.content else None

    # Schemes whose module method receives the whole reference unchanged:
    # (prefix, module method, keyword carrying the reference).
    module_routes = (
        ("wxwork://file_media_id/", "download_wechat_media_bytes", "media_ref"),
        ("slack://file/", "download_slack_file_bytes", "file_ref"),
        ("discord://file/", "download_discord_file_bytes", "file_ref"),
        ("qq://file/", "download_qq_file_bytes", "file_ref"),
        ("vocechat://file/", "download_vocechat_file_bytes", "file_ref"),
        ("synology://file/", "download_synologychat_file_bytes", "file_ref"),
    )
    for prefix, method, ref_keyword in module_routes:
        if file_ref.startswith(prefix):
            return self.run_module(method, **{ref_keyword: file_ref}, source=source)

    # Only real http(s) URLs are fetched directly; a bare "http" prefix check
    # would also accept unrelated references such as "httpd.conf".
    if file_ref.startswith(("http://", "https://")):
        resp = RequestUtils(timeout=30).get_res(file_ref)
        return resp.content if resp and resp.content else None

    logger.debug(
        "暂不支持的文件引用: channel=%s, source=%s, ref=%s",
        channel.value if channel else None,
        source,
        file_ref,
    )
    return None
def _save_agent_attachment(
    self,
    session_id: str,
    filename: Optional[str],
    content: bytes,
    mime_type: Optional[str] = None,
) -> Path:
    """
    Write an uploaded file into the session's temporary upload directory.

    :param session_id: name of the per-session sub-directory
    :param filename: original attachment name (sanitised before use)
    :param content: file bytes; empty/None content produces an empty file
    :param mime_type: MIME type hint passed to the name sanitiser
    :return: path of the file that was written
    """
    clean_name = self._sanitize_attachment_name(filename, mime_type)
    upload_dir = settings.TEMP_PATH / "agent_uploads" / session_id
    upload_dir.mkdir(parents=True, exist_ok=True)
    # A short random prefix keeps same-named uploads from overwriting each other.
    unique_prefix = uuid.uuid4().hex[:8]
    destination = upload_dir / f"{unique_prefix}_{clean_name}"
    destination.write_bytes(content or b"")
    return destination
@staticmethod
def _sanitize_attachment_name(
    filename: Optional[str], mime_type: Optional[str] = None
) -> str:
    """
    Normalise an attachment name so it is safe to use as a local file name.

    Only the last path component is kept (no path traversal), every run of
    characters outside ASCII letters, digits, "_", "." and "-" becomes "_",
    and leading/trailing "." and "_" are removed. The stem and the extension
    are cleaned separately so that a name whose stem is entirely non-ASCII
    (for example "报告.pdf") keeps its extension instead of collapsing to
    just "pdf". If nothing usable is left, "attachment" is used. A name
    without any "." gets an extension derived from a few known MIME types.

    :param filename: name reported by the channel; may be None or a path
    :param mime_type: MIME type, optionally with parameters ("text/plain; charset=...")
    :return: a non-empty ASCII file name of at most 200 characters
    """
    # Callers add their own prefix; the cap leaves room for it — common
    # filesystems limit a name component to roughly 255 bytes.
    max_length = 200

    def _clean(part: str) -> str:
        return re.sub(r"[^\w.\-]+", "_", part, flags=re.ASCII)

    base = Path(filename or "attachment").name
    # Clean stem and suffix independently so stripping the placeholder left
    # by a non-ASCII stem cannot swallow the dot that introduces the extension.
    stem = _clean(Path(base).stem).lstrip("._")
    suffix = _clean(Path(base).suffix)
    if not stem.strip("._"):
        stem = "attachment"
    name = f"{stem}{suffix}".rstrip("._")

    if "." not in name:
        mime = (mime_type or "").split(";", 1)[0].strip().lower()
        default_ext = {
            "application/json": ".json",
            "text/plain": ".txt",
            "text/markdown": ".md",
            "text/csv": ".csv",
        }.get(mime)
        if default_ext:
            name = f"{name}{default_ext}"

    if len(name) > max_length:
        head, dot, ext = name.rpartition(".")
        if dot and head and len(ext) < max_length - 1:
            # Shorten the stem and keep the extension intact.
            name = f"{head[: max_length - len(ext) - 1]}.{ext}"
        else:
            name = name[:max_length]
    return name

View File

@@ -159,11 +159,13 @@ class DiscordModule(_ModuleBase, _MessageBase[Discord]):
chat_id = msg_json.get("chat_id")
images = self._extract_images(msg_json)
audio_refs = self._extract_audio_refs(msg_json)
if (text or images or audio_refs) and userid:
files = self._extract_files(msg_json)
if (text or images or audio_refs or files) and userid:
logger.info(
f"收到来自 {client_config.name} 的 Discord 消息:"
f"userid={userid}, username={username}, text={text}, "
f"images={len(images) if images else 0}, audios={len(audio_refs) if audio_refs else 0}"
f"images={len(images) if images else 0}, audios={len(audio_refs) if audio_refs else 0}, "
f"files={len(files) if files else 0}"
)
return CommingMessage(
channel=MessageChannel.Discord,
@@ -174,6 +176,7 @@ class DiscordModule(_ModuleBase, _MessageBase[Discord]):
chat_id=str(chat_id) if chat_id else None,
images=images,
audio_refs=audio_refs,
files=files,
)
return None
@@ -219,6 +222,44 @@ class DiscordModule(_ModuleBase, _MessageBase[Discord]):
audio_refs.append(f"discord://file/{quote(url, safe='')}")
return audio_refs if audio_refs else None
@classmethod
def _extract_files(
    cls, msg_json: dict
) -> Optional[List[CommingMessage.MessageAttachment]]:
    """
    Collect the attachments of a Discord message that are neither images nor audio.

    :param msg_json: message payload; its "attachments" list is inspected
    :return: attachment descriptors with a `discord://file/` reference,
             or None when no generic file is attached
    """
    raw_attachments = msg_json.get("attachments", [])
    if not raw_attachments:
        return None

    collected = []
    for item in raw_attachments:
        url = item.get("url") or item.get("proxy_url")
        if not url:
            continue

        # Classification is case-insensitive on both MIME type and file name.
        mime = (item.get("content_type") or "").lower()
        lowered_name = (item.get("filename") or "").lower()
        is_image = (
            item.get("type") == "image"
            or mime.startswith("image/")
            or lowered_name.endswith(cls._IMAGE_SUFFIXES)
        )
        is_audio = mime.startswith("audio/") or lowered_name.endswith(
            cls._AUDIO_SUFFIXES
        )
        if not (is_image or is_audio):
            collected.append(
                CommingMessage.MessageAttachment(
                    ref=f"discord://file/{quote(url, safe='')}",
                    name=item.get("filename"),
                    mime_type=item.get("content_type"),
                    size=item.get("size"),
                )
            )

    return collected or None
def download_discord_file_bytes(self, file_ref: str, source: str) -> Optional[bytes]:
"""
下载Discord附件并返回原始字节
@@ -278,19 +319,29 @@ class DiscordModule(_ModuleBase, _MessageBase[Discord]):
)
if client:
logger.debug(
f"[Discord] 调用 client.send_msg, userid={userid}, title={message.title[:50] if message.title else None}..."
)
result = client.send_msg(
title=message.title,
text=message.text,
image=message.image,
userid=userid,
link=message.link,
buttons=message.buttons,
original_message_id=message.original_message_id,
original_chat_id=message.original_chat_id,
mtype=message.mtype,
f"[Discord] 调用 client 发送, userid={userid}, title={message.title[:50] if message.title else None}..."
)
if message.file_path:
result = client.send_file(
file_path=message.file_path,
file_name=message.file_name,
title=message.title,
text=message.text,
userid=userid,
original_chat_id=message.original_chat_id,
)
else:
result = client.send_msg(
title=message.title,
text=message.text,
image=message.image,
userid=userid,
link=message.link,
buttons=message.buttons,
original_message_id=message.original_message_id,
original_chat_id=message.original_chat_id,
mtype=message.mtype,
)
logger.debug(f"[Discord] send_msg 返回结果: {result}")
else:
logger.warning(
@@ -427,11 +478,20 @@ class DiscordModule(_ModuleBase, _MessageBase[Discord]):
return None
client: Discord = self.get_instance(conf.name)
if client:
result = client.send_msg(
title=message.title or "",
text=message.text,
userid=userid,
)
if message.file_path:
result = client.send_file(
file_path=message.file_path,
file_name=message.file_name,
title=message.title,
text=message.text,
userid=userid,
)
else:
result = client.send_msg(
title=message.title or "",
text=message.text,
userid=userid,
)
if result:
success, response_data = (
(result[0], result[1])

View File

@@ -1,6 +1,7 @@
import asyncio
import re
import threading
from pathlib import Path
from typing import Optional, List, Dict, Any, Tuple, Union
from urllib.parse import quote
@@ -273,6 +274,37 @@ class Discord:
logger.error(f"发送 Discord 消息失败:{err}")
return False
def send_file(
    self,
    file_path: str,
    title: Optional[str] = None,
    text: Optional[str] = None,
    userid: Optional[str] = None,
    file_name: Optional[str] = None,
    original_chat_id: Optional[str] = None,
) -> Union[bool, Tuple[bool, Optional[Dict[str, str]]]]:
    """
    Send a local file to Discord from synchronous code.

    The upload itself runs as the `_send_file` coroutine on the client's
    event loop; this method blocks for at most 30 seconds waiting for it.

    :param file_path: path of the file to upload
    :param title: optional title placed before the text
    :param text: optional message text
    :param userid: target user, used to resolve the channel
    :param file_name: name shown to the recipient
    :param original_chat_id: chat to reply in, when known
    :return: the `(success, ids)` tuple produced by `_send_file`, or plain
             False when the client is not ready, no path was given, or the
             call failed or timed out — callers must accept both shapes
    """
    if not self.get_state() or not file_path:
        return False

    try:
        pending = asyncio.run_coroutine_threadsafe(
            self._send_file(
                file_path=file_path,
                title=title,
                text=text,
                userid=userid,
                file_name=file_name,
                original_chat_id=original_chat_id,
            ),
            self._loop,
        )
        return pending.result(timeout=30)
    except Exception as err:
        logger.error(f"发送 Discord 文件失败:{err}")
        return False
def send_medias_msg(
self,
medias: List[MediaInfo],
@@ -414,6 +446,46 @@ class Discord:
logger.error(f"[Discord] 发送消息到频道失败: {e}")
return False, None
async def _send_file(
    self,
    file_path: str,
    title: Optional[str],
    text: Optional[str],
    userid: Optional[str],
    file_name: Optional[str],
    original_chat_id: Optional[str],
) -> Tuple[bool, Optional[Dict[str, str]]]:
    """
    Upload a local file to the resolved Discord channel or DM.

    :return: `(True, {"message_id", "chat_id"})` on success, `(False, None)`
             when no channel is found, the file is missing, or sending fails
    """
    target = await self._resolve_channel(userid=userid, chat_id=original_chat_id)
    if not target:
        logger.error("未找到可用的 Discord 频道或私聊")
        return False, None

    source_file = Path(file_path)
    if not (source_file.exists() and source_file.is_file()):
        logger.error(f"Discord发送文件失败文件不存在: {source_file}")
        return False, None

    # Title and text share one message body, separated by a newline.
    caption = "\n".join(part for part in (title, text) if part) or None
    # Keep the body short: anything past 1900 characters is cut and marked.
    if caption and len(caption) > 1900:
        caption = caption[:1900] + "..."

    try:
        upload = discord.File(
            str(source_file), filename=file_name or source_file.name
        )
        delivered = await target.send(content=caption, file=upload)
        ids = {
            "message_id": str(delivered.id),
            "chat_id": str(target.id),
        }
        return True, ids
    except Exception as err:
        logger.error(f"Discord发送文件失败: {err}")
        return False, None
async def _send_list_message(
self,
embeds: List[discord.Embed],

View File

@@ -107,7 +107,8 @@ class QQBotModule(_ModuleBase, _MessageBase[QQBot]):
content = (msg_body.get("content") or "").strip()
images = self._extract_images(msg_body)
audio_refs = self._extract_audio_refs(msg_body)
if not content and not images and not audio_refs:
files = self._extract_files(msg_body)
if not content and not images and not audio_refs and not files:
return None
if msg_type == "C2C_MESSAGE_CREATE":
@@ -118,7 +119,7 @@ class QQBotModule(_ModuleBase, _MessageBase[QQBot]):
logger.info(
f"收到 QQ 私聊消息: userid={user_openid}, "
f"text={(content or '')[:50]}..., images={len(images) if images else 0}, "
f"audios={len(audio_refs) if audio_refs else 0}"
f"audios={len(audio_refs) if audio_refs else 0}, files={len(files) if files else 0}"
)
return CommingMessage(
channel=MessageChannel.QQ,
@@ -128,6 +129,7 @@ class QQBotModule(_ModuleBase, _MessageBase[QQBot]):
text=content,
images=images,
audio_refs=audio_refs,
files=files,
)
elif msg_type == "GROUP_AT_MESSAGE_CREATE":
author = msg_body.get("author", {})
@@ -138,7 +140,7 @@ class QQBotModule(_ModuleBase, _MessageBase[QQBot]):
logger.info(
f"收到 QQ 群消息: group={group_openid}, userid={member_openid}, "
f"text={(content or '')[:50]}..., images={len(images) if images else 0}, "
f"audios={len(audio_refs) if audio_refs else 0}"
f"audios={len(audio_refs) if audio_refs else 0}, files={len(files) if files else 0}"
)
return CommingMessage(
channel=MessageChannel.QQ,
@@ -148,6 +150,7 @@ class QQBotModule(_ModuleBase, _MessageBase[QQBot]):
text=content,
images=images,
audio_refs=audio_refs,
files=files,
)
return None
@@ -226,6 +229,46 @@ class QQBotModule(_ModuleBase, _MessageBase[QQBot]):
deduped.append(audio_ref)
return deduped or None
@classmethod
def _extract_files(
    cls, msg_body: dict
) -> Optional[List[CommingMessage.MessageAttachment]]:
    """
    Collect the attachments of a QQ message that are neither images nor audio.

    :param msg_body: message payload; its "attachments" value is inspected
    :return: attachment descriptors with a `qq://file/` reference, or None
             when there is no generic file (or "attachments" is not a list)
    """
    raw_attachments = msg_body.get("attachments") or []
    if not isinstance(raw_attachments, list):
        return None

    found: List[CommingMessage.MessageAttachment] = []
    for entry in raw_attachments:
        if not isinstance(entry, dict):
            continue
        url = entry.get("url") or entry.get("proxy_url")
        if not url:
            continue

        # The payload may use either naming scheme for type and name.
        declared_type = entry.get("content_type") or entry.get("mime_type")
        declared_name = entry.get("filename") or entry.get("name")
        mime = (declared_type or "").lower()
        lowered_name = (declared_name or "").lower()

        is_image = mime.startswith("image/") or lowered_name.endswith(
            cls._IMAGE_SUFFIXES
        )
        is_audio = mime.startswith("audio/") or lowered_name.endswith(
            cls._AUDIO_SUFFIXES
        )
        if is_image or is_audio:
            continue

        found.append(
            CommingMessage.MessageAttachment(
                ref=f"qq://file/{quote(url, safe='')}",
                name=declared_name,
                mime_type=declared_type,
                size=entry.get("size"),
            )
        )

    return found or None
def download_qq_file_bytes(self, file_ref: str, source: str) -> Optional[bytes]:
"""
下载QQ音频附件并返回原始字节

View File

@@ -221,12 +221,14 @@ class SlackModule(_ModuleBase, _MessageBase[Slack]):
if msg_json:
images = None
audio_refs = None
files = None
if msg_json.get("type") == "message":
userid = msg_json.get("user")
text = msg_json.get("text")
username = msg_json.get("user")
images = self._extract_images(msg_json)
audio_refs = self._extract_audio_refs(msg_json)
files = self._extract_files(msg_json)
elif msg_json.get("type") == "block_actions":
userid = msg_json.get("user", {}).get("id")
callback_data = msg_json.get("actions")[0].get("value")
@@ -270,6 +272,7 @@ class SlackModule(_ModuleBase, _MessageBase[Slack]):
username = ""
images = self._extract_images(msg_json.get("event", {}))
audio_refs = self._extract_audio_refs(msg_json.get("event", {}))
files = self._extract_files(msg_json.get("event", {}))
elif msg_json.get("type") == "shortcut":
userid = msg_json.get("user", {}).get("id")
text = msg_json.get("callback_id")
@@ -282,7 +285,8 @@ class SlackModule(_ModuleBase, _MessageBase[Slack]):
return None
logger.info(
f"收到来自 {client_config.name} 的Slack消息userid={userid}, username={username}, "
f"text={text}, images={len(images) if images else 0}, audios={len(audio_refs) if audio_refs else 0}"
f"text={text}, images={len(images) if images else 0}, audios={len(audio_refs) if audio_refs else 0}, "
f"files={len(files) if files else 0}"
)
return CommingMessage(
channel=MessageChannel.Slack,
@@ -292,6 +296,7 @@ class SlackModule(_ModuleBase, _MessageBase[Slack]):
text=text,
images=images,
audio_refs=audio_refs,
files=files,
)
return None
@@ -341,6 +346,48 @@ class SlackModule(_ModuleBase, _MessageBase[Slack]):
audio_refs.append(f"slack://file/{quote(url, safe='')}")
return audio_refs if audio_refs else None
@classmethod
def _extract_files(
    cls, msg_json: dict
) -> Optional[List[CommingMessage.MessageAttachment]]:
    """
    Collect the files of a Slack message that are neither images nor audio.

    :param msg_json: message/event payload; its "files" list is inspected
    :return: attachment descriptors with a `slack://file/` reference,
             or None when no generic file is attached
    """
    image_exts = (".jpg", ".jpeg", ".png", ".gif", ".webp", ".bmp")

    slack_files = msg_json.get("files", [])
    if not slack_files:
        return None

    collected = []
    for entry in slack_files:
        kind = str(entry.get("type", "")).lower()
        # Normalise "filetype" to a dotted, lower-case extension.
        ext = "." + str(entry.get("filetype", "")).lower().lstrip(".")
        mime = str(entry.get("mimetype", "")).lower()

        is_image = (
            kind == "image"
            or ext in image_exts
            or mime.startswith("image/")
        )
        is_audio = (
            kind == "audio"
            or mime.startswith("audio/")
            or ext in cls._AUDIO_SUFFIXES
        )
        if is_image or is_audio:
            continue

        # Prefer the direct download link over the generic private link.
        url = entry.get("url_private_download") or entry.get("url_private")
        if not url:
            continue

        collected.append(
            CommingMessage.MessageAttachment(
                ref=f"slack://file/{quote(url, safe='')}",
                name=entry.get("name") or entry.get("title"),
                mime_type=entry.get("mimetype"),
                size=entry.get("size"),
            )
        )

    return collected or None
def download_slack_file_to_data_url(self, file_url: str, source: str) -> Optional[str]:
"""
下载Slack文件并转为data URL
@@ -399,16 +446,25 @@ class SlackModule(_ModuleBase, _MessageBase[Slack]):
return
client: Slack = self.get_instance(conf.name)
if client:
client.send_msg(
title=message.title,
text=message.text,
image=message.image,
userid=userid,
link=message.link,
buttons=message.buttons,
original_message_id=message.original_message_id,
original_chat_id=message.original_chat_id,
)
if message.file_path:
client.send_file(
file_path=message.file_path,
file_name=message.file_name,
title=message.title,
text=message.text,
userid=userid,
)
else:
client.send_msg(
title=message.title,
text=message.text,
image=message.image,
userid=userid,
link=message.link,
buttons=message.buttons,
original_message_id=message.original_message_id,
original_chat_id=message.original_chat_id,
)
def post_medias_message(
self, message: Notification, medias: List[MediaInfo]
@@ -538,26 +594,40 @@ class SlackModule(_ModuleBase, _MessageBase[Slack]):
return None
client: Slack = self.get_instance(conf.name)
if client:
result = client.send_msg(
title=message.title or "",
text=message.text,
userid=userid,
)
if message.file_path:
result = client.send_file(
file_path=message.file_path,
file_name=message.file_name,
title=message.title,
text=message.text,
userid=userid,
)
else:
result = client.send_msg(
title=message.title or "",
text=message.text,
userid=userid,
)
if result and result[0]:
# Slack 使用时间戳作为 message_idchat_id 是频道ID
# 注意:这里返回的是发送后的结果,需要获取实际的 message_id
# 由于 Slack API 返回的是 result[1],包含完整响应,我们需要从中提取
response_data = result[1]
message_id = (
response_data.get("ts")
if isinstance(response_data, dict)
else None
)
channel_id = (
response_data.get("channel")
if isinstance(response_data, dict)
else None
)
message_id = None
channel_id = None
if hasattr(response_data, "get"):
message_id = response_data.get("ts")
channel_id = response_data.get("channel")
if not message_id and hasattr(response_data, "data"):
files = (response_data.data or {}).get("files") or []
if files:
message_id = files[0].get("id")
shares = (
files[0].get("shares", {})
.get("private", {})
)
if shares:
channel_id = next(iter(shares.keys()), None)
return MessageResponse(
message_id=message_id,
chat_id=channel_id,

View File

@@ -1,5 +1,6 @@
import re
from threading import Lock
from pathlib import Path
from typing import List, Optional, Tuple
from urllib.parse import quote
@@ -246,6 +247,48 @@ class Slack:
logger.error(f"Slack消息发送失败: {msg_e}")
return False, str(msg_e)
def send_file(
    self,
    file_path: str,
    title: Optional[str] = None,
    text: Optional[str] = None,
    userid: Optional[str] = None,
    file_name: Optional[str] = None,
):
    """
    Upload a local file to Slack.

    :param file_path: path of the file to upload
    :param title: optional title; also leads the accompanying comment
    :param text: optional text appended to the accompanying comment
    :param userid: upload target; a public channel is looked up when omitted
    :param file_name: name shown in Slack, defaults to the file's own name
    :return: `(True, api_response)` on success, `(False, reason)` otherwise
    """
    if not self._client:
        return False, "消息客户端未就绪"
    if not file_path:
        return False, "文件路径不能为空"

    source_file = Path(file_path)
    if not (source_file.exists() and source_file.is_file()):
        return False, f"文件不存在: {source_file}"

    try:
        target = userid if userid else self.__find_public_channel()
        # Title and text form the comment posted alongside the file.
        caption = "\n".join(part for part in (title, text) if part) or None
        display_name = file_name or source_file.name
        with source_file.open("rb") as handle:
            response = self._client.files_upload_v2(
                channel=target,
                file=handle,
                filename=display_name,
                title=title or display_name,
                initial_comment=caption,
            )
        return True, response
    except Exception as err:
        logger.error(f"Slack文件发送失败: {err}")
        return False, str(err)
def send_medias_msg(self, medias: List[MediaInfo], userid: Optional[str] = None, title: Optional[str] = None,
buttons: Optional[List[List[dict]]] = None,
original_message_id: Optional[str] = None,

View File

@@ -125,15 +125,17 @@ class SynologyChatModule(_ModuleBase, _MessageBase[SynologyChat]):
user_name = message.get("username")
images = self._extract_images(message)
audio_refs = self._extract_audio_refs(message)
if (text or images or audio_refs) and user_id:
files = self._extract_files(message)
if (text or images or audio_refs or files) and user_id:
logger.info(
f"收到来自 {client_config.name} 的SynologyChat消息"
f"userid={user_id}, username={user_name}, text={text}, "
f"images={len(images) if images else 0}, audios={len(audio_refs) if audio_refs else 0}"
f"images={len(images) if images else 0}, audios={len(audio_refs) if audio_refs else 0}, "
f"files={len(files) if files else 0}"
)
return CommingMessage(channel=MessageChannel.SynologyChat, source=client_config.name,
userid=user_id, username=user_name, text=text or "",
images=images, audio_refs=audio_refs)
images=images, audio_refs=audio_refs, files=files)
except Exception as err:
logger.debug(f"解析SynologyChat消息失败{str(err)}")
return None
@@ -230,6 +232,56 @@ class SynologyChatModule(_ModuleBase, _MessageBase[SynologyChat]):
suffix in lowered for suffix in cls._AUDIO_SUFFIXES
)
@classmethod
def _extract_files(
    cls, message: dict
) -> Optional[List[CommingMessage.MessageAttachment]]:
    """
    Collect generic file attachments from a Synology Chat message.

    The "attachments" and "files" fields are inspected in that order. Each
    field may hold a JSON string or an already-decoded value; a single
    object is treated as a one-element list. Entries are skipped when they
    are not dicts, when no http(s)-looking URL is found under "url",
    "file_url" or "download_url", or when the entry looks like an image or
    audio file (by MIME type, file name suffix or URL).

    :param message: incoming message fields
    :return: attachments de-duplicated by ref in first-seen order, or None
             when nothing qualifies
    """
    # Keyed by ref: dict insertion order keeps the first occurrence of each ref.
    unique_by_ref = {}
    for field in ("attachments", "files"):
        payload = message.get(field)
        if not payload:
            continue
        if isinstance(payload, str):
            try:
                payload = json.loads(payload)
            except Exception:
                # Not valid JSON: keep the raw string, it is filtered out below.
                pass
        entries = payload if isinstance(payload, list) else [payload]
        for entry in entries:
            if not isinstance(entry, dict):
                continue
            url = (
                entry.get("url")
                or entry.get("file_url")
                or entry.get("download_url")
            )
            if not (isinstance(url, str) and url.startswith("http")):
                continue
            raw_mime = entry.get("content_type") or entry.get("mime_type")
            raw_name = entry.get("name") or entry.get("filename")
            lowered_mime = (raw_mime or "").lower()
            lowered_name = (raw_name or "").lower()
            looks_image = (
                lowered_mime.startswith("image/")
                or lowered_name.endswith(cls._IMAGE_SUFFIXES)
                or cls._looks_like_image(url)
            )
            looks_audio = (
                lowered_mime.startswith("audio/")
                or lowered_name.endswith(cls._AUDIO_SUFFIXES)
                or cls._looks_like_audio(url)
            )
            if looks_image or looks_audio:
                continue
            attachment = CommingMessage.MessageAttachment(
                # The URL is fully percent-encoded so it can live inside the ref.
                ref=f"synology://file/{quote(url, safe='')}",
                name=raw_name,
                mime_type=raw_mime,
                size=entry.get("size"),
            )
            unique_by_ref.setdefault(attachment.ref, attachment)
    return list(unique_by_ref.values()) or None
def download_synologychat_file_bytes(self, file_ref: str, source: str) -> Optional[bytes]:
"""
下载 Synology Chat 音频文件并返回原始字节

View File

@@ -215,18 +215,20 @@ class TelegramModule(_ModuleBase, _MessageBase[Telegram]):
images = self._extract_images(msg)
audio_refs = self._extract_audio_refs(msg)
files = self._extract_files(msg)
if user_id:
if not text and not images and not audio_refs:
if not text and not images and not audio_refs and not files:
logger.debug(
f"收到来自 {client_config.name} 的Telegram消息无文本、图片语音"
f"收到来自 {client_config.name} 的Telegram消息无文本、图片语音和文件"
)
return None
logger.info(
f"收到来自 {client_config.name} 的Telegram消息"
f"userid={user_id}, username={user_name}, chat_id={chat_id}, text={text}, "
f"images={len(images) if images else 0}, audios={len(audio_refs) if audio_refs else 0}"
f"images={len(images) if images else 0}, audios={len(audio_refs) if audio_refs else 0}, "
f"files={len(files) if files else 0}"
)
cleaned_text = (
@@ -266,6 +268,7 @@ class TelegramModule(_ModuleBase, _MessageBase[Telegram]):
chat_id=str(chat_id) if chat_id else None,
images=images if images else None,
audio_refs=audio_refs if audio_refs else None,
files=files if files else None,
)
return None
@@ -311,6 +314,29 @@ class TelegramModule(_ModuleBase, _MessageBase[Telegram]):
return audio_refs if audio_refs else None
@staticmethod
def _extract_files(msg: dict) -> Optional[List[CommingMessage.MessageAttachment]]:
    """
    Extract the non-image document attachment from a Telegram message.

    :param msg: Telegram message dict
    :return: a single-element list describing the "document" entry, or None
             when there is no document dict, it has no file_id, or its MIME
             type starts with "image/"
    """
    doc = msg.get("document")
    if isinstance(doc, dict):
        doc_file_id = doc.get("file_id")
        raw_mime = doc.get("mime_type")
        is_image_doc = (raw_mime or "").lower().startswith("image/")
        if doc_file_id and not is_image_doc:
            attachment = CommingMessage.MessageAttachment(
                ref=f"tg://document_file_id/{doc_file_id}",
                name=doc.get("file_name"),
                mime_type=raw_mime,
                size=doc.get("file_size"),
            )
            return [attachment]
    return None
@staticmethod
def _embed_entity_links(text: str, entities: Optional[List[dict]]) -> str:
"""
@@ -412,7 +438,16 @@ class TelegramModule(_ModuleBase, _MessageBase[Telegram]):
return
client: Telegram = self.get_instance(conf.name)
if client:
if message.voice_path:
if message.file_path:
client.send_file(
file_path=message.file_path,
file_name=message.file_name,
title=message.title,
text=message.text,
userid=userid,
original_chat_id=message.original_chat_id,
)
elif message.voice_path:
client.send_voice(
voice_path=message.voice_path,
userid=userid,

View File

@@ -507,6 +507,70 @@ class Telegram:
except Exception as cleanup_err:
logger.debug(f"清理语音临时文件失败: {cleanup_err}")
def send_file(
    self,
    file_path: str,
    userid: Optional[str] = None,
    title: Optional[str] = None,
    text: Optional[str] = None,
    file_name: Optional[str] = None,
    original_chat_id: Optional[str] = None,
) -> Optional[dict]:
    """
    Send a local image or file to a Telegram chat.

    Files whose local suffix is a known image extension are sent with
    send_photo; everything else is sent with send_document under
    ``file_name`` (falling back to the local file name). The caption is the
    bold title followed by the text; it is passed through ``standardize``
    and sent with parse_mode "MarkdownV2".

    :param file_path: path of the local file to send
    :param userid: target user id, used to resolve the chat
    :param title: optional caption title, rendered in bold
    :param text: optional caption body
    :param file_name: file name shown to the receiver for documents
    :param original_chat_id: chat id of the originating message, if any
    :return: None when the bot is not ready or file_path is empty;
             {"success": False} when the file is missing or sending raised;
             otherwise a dict with "success" and, when available,
             "message_id" and "chat_id"
    """
    if not self._bot:
        return None
    if not file_path:
        return None

    local_file = Path(file_path)
    if not (local_file.exists() and local_file.is_file()):
        logger.error(f"附件文件不存在: {local_file}")
        return {"success": False}

    chat_id = self._determine_target_chat_id(userid, original_chat_id)
    send_name = file_name or local_file.name
    is_image = local_file.suffix.lower() in {
        ".png",
        ".jpg",
        ".jpeg",
        ".gif",
        ".webp",
        ".bmp",
    }

    try:
        bold_title = None
        if title:
            cleaned_title = standardize(title).removesuffix("\n")
            bold_title = f"**{cleaned_title}**"

        if bold_title:
            caption = f"{bold_title}\n{text}" if text else bold_title
        else:
            caption = text or ""

        with local_file.open("rb") as fp:
            # Caption and parse mode are only supplied when there is a caption.
            shared_kwargs = {
                "chat_id": chat_id,
                "caption": standardize(caption) if caption else None,
                "parse_mode": "MarkdownV2" if caption else None,
            }
            if is_image:
                sent = self._bot.send_photo(photo=fp, **shared_kwargs)
            else:
                sent = self._bot.send_document(
                    document=(send_name, fp), **shared_kwargs
                )

        self._stop_typing_task(chat_id)

        if not (sent and hasattr(sent, "message_id")):
            return {"success": bool(sent)}
        return {
            "success": True,
            "message_id": sent.message_id,
            "chat_id": sent.chat.id if hasattr(sent, "chat") else chat_id,
        }
    except Exception as err:
        logger.error(f"发送本地附件失败: {err}")
        self._stop_typing_task(chat_id)
        return {"success": False}
def _determine_target_chat_id(
self, userid: Optional[str] = None, original_chat_id: Optional[str] = None
) -> str:

View File

@@ -133,6 +133,7 @@ class VoceChatModule(_ModuleBase, _MessageBase[VoceChat]):
content = detail.get("content")
images = self._extract_images(detail)
audio_refs = self._extract_audio_refs(detail)
files = self._extract_files(detail)
text = None
if content_type in ("text/plain", "text/markdown") and isinstance(content, str):
text = content
@@ -147,15 +148,15 @@ class VoceChatModule(_ModuleBase, _MessageBase[VoceChat]):
userid = f"UID#{msg_body.get('from_uid')}"
# 处理消息内容
if (text or images or audio_refs) and userid:
if (text or images or audio_refs or files) and userid:
logger.info(
f"收到来自 {client_config.name} 的VoceChat消息"
f"userid={userid}, text={text}, images={len(images) if images else 0}, "
f"audios={len(audio_refs) if audio_refs else 0}"
f"audios={len(audio_refs) if audio_refs else 0}, files={len(files) if files else 0}"
)
return CommingMessage(channel=MessageChannel.VoceChat, source=client_config.name,
userid=userid, username=userid, text=text or "",
images=images, audio_refs=audio_refs)
images=images, audio_refs=audio_refs, files=files)
except Exception as err:
logger.error(f"VoceChat消息处理发生错误{str(err)}")
return None
@@ -229,6 +230,51 @@ class VoceChatModule(_ModuleBase, _MessageBase[VoceChat]):
return [f"vocechat://file/{quote(file_path, safe='')}"]
return None
@classmethod
def _extract_files(
    cls, detail: dict
) -> Optional[List[CommingMessage.MessageAttachment]]:
    """
    Extract a generic file attachment from a VoceChat message detail.

    Only details whose content_type is "vocechat/file" are considered. The
    file path comes from the "path", "file_path" or "storage_path" property,
    falling back to the detail content; the file name falls back to the last
    path segment. Files that look like images or audio (by MIME type or name
    suffix) are excluded.

    :param detail: the "detail" part of a VoceChat message
    :return: a single-element attachment list, or None when the detail is
             not a generic file or has no usable string path
    """
    if detail.get("content_type") != "vocechat/file":
        return None

    props = detail.get("properties") or {}
    raw_mime = (
        props.get("content_type")
        or props.get("mime_type")
        or props.get("contentType")
    )
    lowered_mime = (raw_mime or "").lower()
    file_path = (
        props.get("path")
        or props.get("file_path")
        or props.get("storage_path")
        or detail.get("content")
    )
    fallback_name = str(file_path).rsplit("/", 1)[-1] if file_path else ""
    file_name = props.get("name") or props.get("filename") or fallback_name
    lowered_name = str(file_name).lower()

    looks_image = lowered_mime.startswith("image/") or lowered_name.endswith(
        cls._IMAGE_SUFFIXES
    )
    looks_audio = lowered_mime.startswith("audio/") or lowered_name.endswith(
        cls._AUDIO_SUFFIXES
    )
    if looks_image or looks_audio:
        return None
    if not isinstance(file_path, str) or not file_path:
        return None

    attachment = CommingMessage.MessageAttachment(
        # The path is fully percent-encoded so it can live inside the ref.
        ref=f"vocechat://file/{quote(file_path, safe='')}",
        name=file_name,
        mime_type=raw_mime,
        size=props.get("size"),
    )
    return [attachment]
def post_message(self, message: Notification, **kwargs) -> None:
"""
发送消息

View File

@@ -3,6 +3,7 @@ import json
import re
import xml.dom.minidom
from typing import Optional, Union, List, Tuple, Any, Dict
from urllib.parse import quote
from app.core.context import Context, MediaInfo
from app.core.event import eventmanager
@@ -168,6 +169,7 @@ class WechatModule(_ModuleBase, _MessageBase[WeChat]):
content = None
images = None
audio_refs = None
files = None
if msg_type == "event" and event == "click":
# 校验用户有权限执行交互命令
if client_config.config.get('WECHAT_ADMINS'):
@@ -203,14 +205,27 @@ class WechatModule(_ModuleBase, _MessageBase[WeChat]):
f"收到来自 {client_config.name} 的微信语音消息userid={user_id}, "
f"text={content}, audios={len(audio_refs) if audio_refs else 0}"
)
elif msg_type == "file":
media_id = DomUtils.tag_value(root_node, "MediaId")
file_name = DomUtils.tag_value(root_node, "FileName")
if media_id:
files = [
CommingMessage.MessageAttachment(
ref=f"wxwork://file_media_id/{media_id}",
name=file_name,
)
]
logger.info(
f"收到来自 {client_config.name} 的微信文件消息userid={user_id}, files={len(files) if files else 0}"
)
else:
return None
if content or images or audio_refs:
if content or images or audio_refs or files:
# 处理消息内容
return CommingMessage(channel=MessageChannel.Wechat, source=client_config.name,
userid=user_id, username=user_id, text=content or "",
images=images, audio_refs=audio_refs)
images=images, audio_refs=audio_refs, files=files)
except Exception as err:
logger.error(f"微信消息处理发生错误:{str(err)}")
return None
@@ -242,6 +257,20 @@ class WechatModule(_ModuleBase, _MessageBase[WeChat]):
text = WeChatBot._extract_text_from_body(payload_body)
images = WeChatBot._extract_images_from_body(payload_body)
audio_refs = ["wxbot://voice"] if payload_body.get("msgtype") == "voice" else None
files = None
if payload_body.get("msgtype") == "file":
file_payload = payload_body.get("file") or {}
download_url = file_payload.get("download_url")
if download_url:
files = [
CommingMessage.MessageAttachment(
ref=f"wxbot://file/{quote(download_url, safe='')}",
name=file_payload.get("name") or file_payload.get("filename"),
mime_type=file_payload.get("content_type")
or file_payload.get("mime_type"),
size=file_payload.get("size"),
)
]
if text:
text = re.sub(r"@\S+", "", text).strip()
@@ -257,7 +286,7 @@ class WechatModule(_ModuleBase, _MessageBase[WeChat]):
client.send_msg(title="只有管理员才有权限执行此命令", userid=sender)
return None
if not text and not images and not audio_refs:
if not text and not images and not audio_refs and not files:
return None
logger.info(
@@ -272,6 +301,7 @@ class WechatModule(_ModuleBase, _MessageBase[WeChat]):
text=text or "",
images=images,
audio_refs=audio_refs,
files=files,
)
def post_message(self, message: Notification, **kwargs) -> None:
@@ -338,6 +368,9 @@ class WechatModule(_ModuleBase, _MessageBase[WeChat]):
if media_ref.startswith("wxwork://voice_media_id/"):
media_id = media_ref.replace("wxwork://voice_media_id/", "", 1)
return client.download_media_bytes(media_id)
if media_ref.startswith("wxwork://file_media_id/"):
media_id = media_ref.replace("wxwork://file_media_id/", "", 1)
return client.download_media_bytes(media_id)
return None
def post_medias_message(self, message: Notification, medias: List[MediaInfo]) -> None:

View File

@@ -29,6 +29,16 @@ class CommingMessage(BaseModel):
外来消息
"""
class MessageAttachment(BaseModel):
    """
    Attachment of an incoming message (neither image nor voice).
    """

    # Channel-specific reference used to locate the file later,
    # e.g. "tg://document_file_id/<id>" or "wxwork://file_media_id/<id>"
    ref: str
    # File name as reported by the channel, if any
    name: Optional[str] = None
    # MIME type as reported by the channel, if any
    mime_type: Optional[str] = None
    # File size as reported by the channel (presumably bytes; confirm per channel)
    size: Optional[int] = None
# 用户ID
userid: Optional[Union[str, int]] = None
# 用户名称
@@ -57,6 +67,8 @@ class CommingMessage(BaseModel):
images: Optional[List[str]] = None
# 语音/音频引用列表
audio_refs: Optional[List[str]] = None
# 文件附件列表
files: Optional[List[MessageAttachment]] = None
def to_dict(self):
"""
@@ -90,6 +102,10 @@ class Notification(BaseModel):
image: Optional[str] = None
# 语音文件路径
voice_path: Optional[str] = None
# 本地文件路径
file_path: Optional[str] = None
# 发送时展示的文件名
file_name: Optional[str] = None
# 语音消息附带说明文字
voice_caption: Optional[str] = None
# 链接
@@ -254,6 +270,7 @@ class ChannelCapabilityManager:
ChannelCapability.IMAGES,
ChannelCapability.LINKS,
ChannelCapability.MENU_COMMANDS,
ChannelCapability.FILE_SENDING,
},
max_buttons_per_row=3,
max_button_rows=8,
@@ -272,6 +289,7 @@ class ChannelCapabilityManager:
ChannelCapability.RICH_TEXT,
ChannelCapability.IMAGES,
ChannelCapability.LINKS,
ChannelCapability.FILE_SENDING,
},
max_buttons_per_row=5,
max_button_rows=5,

View File

@@ -1,6 +1,8 @@
import base64
import json
import tempfile
import unittest
from pathlib import Path
from types import SimpleNamespace
from unittest.mock import AsyncMock, Mock, patch
from urllib.parse import quote
@@ -8,6 +10,7 @@ from urllib.parse import quote
from telebot import apihelper
from app.agent.tools.impl.send_message import SendMessageInput
from app.agent.tools.impl.send_local_file import SendLocalFileInput
from app.agent import MoviePilotAgent, AgentChain
from app.chain.message import MessageChain
from app.core.config import settings
@@ -161,6 +164,32 @@ class AgentImageSupportTest(unittest.TestCase):
self.assertEqual(handle_kwargs["text"], "")
self.assertEqual(handle_kwargs["audio_refs"], ["tg://voice_file_id/voice-1"])
def test_process_allows_file_only_message(self):
    """
    A parsed message that carries only a file attachment must still be
    forwarded to handle_message, with empty text and the file list intact.
    """
    chain = MessageChain()
    attachment = CommingMessage.MessageAttachment(
        ref="tg://document_file_id/doc-1",
        name="note.txt",
        mime_type="text/plain",
        size=12,
    )
    incoming = CommingMessage(
        channel=MessageChannel.Telegram,
        source="telegram-test",
        userid="10001",
        username="tester",
        files=[attachment],
    )

    with patch.object(chain, "message_parser", return_value=incoming):
        with patch.object(chain, "handle_message") as handler:
            chain.process(body="{}", form={}, args={"source": "telegram-test"})

    forwarded = handler.call_args.kwargs
    self.assertEqual(forwarded["text"], "")
    self.assertEqual(forwarded["files"][0].ref, "tg://document_file_id/doc-1")
def test_image_message_routes_to_agent_even_when_global_agent_is_disabled(self):
chain = MessageChain()
@@ -205,6 +234,36 @@ class AgentImageSupportTest(unittest.TestCase):
self.assertEqual(handle_ai_message.call_args.kwargs["text"], "帮我推荐一部电影")
self.assertTrue(handle_ai_message.call_args.kwargs["reply_with_voice"])
def test_file_message_routes_to_agent_even_when_global_agent_is_disabled(self):
    """
    With the agent enabled but the global agent switch off, a file-only
    message must still reach _handle_ai_message exactly once.
    """
    chain = MessageChain()
    attachment = CommingMessage.MessageAttachment(
        ref="tg://document_file_id/doc-1",
        name="report.txt",
        mime_type="text/plain",
    )

    # Patchers are entered below in the same order they are created here.
    cache_patch = patch.object(chain, "load_cache", return_value={})
    helper_put_patch = patch.object(chain.messagehelper, "put")
    oper_add_patch = patch.object(chain.messageoper, "add")
    ai_handler_patch = patch.object(chain, "_handle_ai_message")
    agent_enabled_patch = patch.object(settings, "AI_AGENT_ENABLE", True)
    agent_global_patch = patch.object(settings, "AI_AGENT_GLOBAL", False)

    with cache_patch, helper_put_patch, oper_add_patch:
        with ai_handler_patch as ai_handler, agent_enabled_patch, agent_global_patch:
            chain.handle_message(
                channel=MessageChannel.Telegram,
                source="telegram-test",
                userid="10001",
                username="tester",
                text="",
                files=[attachment],
            )

    ai_handler.assert_called_once()
    routed_files = ai_handler.call_args.kwargs["files"]
    self.assertEqual(routed_files[0].name, "report.txt")
def test_transcribe_audio_refs_supports_new_channel_refs(self):
chain = MessageChain()
audio_refs = [
@@ -276,6 +335,41 @@ class AgentImageSupportTest(unittest.TestCase):
self.assertIsNone(notification.voice_path)
self.assertEqual(notification.text, "这是语音回复")
def test_agent_process_wraps_request_as_structured_json(self):
    """
    MoviePilotAgent.process must hand _execute_agent a final human message
    whose first content part is JSON text carrying the message and files.
    """
    agent = MoviePilotAgent(
        session_id="session-1",
        user_id="user-1",
        channel=MessageChannel.Telegram.value,
        source="telegram-test",
        username="tester",
    )
    prepared_files = [
        {
            "name": "report.txt",
            "local_path": "/tmp/report.txt",
            "status": "ready",
        }
    ]

    history_patch = patch(
        "app.agent.memory.memory_manager.get_agent_messages", return_value=[]
    )
    execute_patch = patch.object(agent, "_execute_agent", new_callable=AsyncMock)
    with history_patch, execute_patch as execute_agent:
        import asyncio

        asyncio.run(agent.process("帮我总结这个文件", files=prepared_files))

    sent_messages = execute_agent.await_args.args[0]
    content = sent_messages[-1].content

    self.assertIsInstance(content, list)
    payload = json.loads(content[0]["text"])
    self.assertEqual(payload["message"], "帮我总结这个文件")
    self.assertEqual(payload["files"][0]["local_path"], "/tmp/report.txt")
def test_slack_images_use_authenticated_data_url_download(self):
chain = MessageChain()
@@ -345,6 +439,15 @@ class AgentImageSupportTest(unittest.TestCase):
self.assertEqual(payload.image_url, "https://example.com/poster.png")
def test_send_local_file_input_accepts_file_payload(self):
    """
    SendLocalFileInput accepts explanation, file_path and message and
    exposes the file path unchanged.
    """
    fields = {
        "explanation": "send generated report",
        "file_path": "/tmp/report.txt",
        "message": "请下载查看",
    }
    payload = SendLocalFileInput(**fields)

    self.assertEqual(payload.file_path, "/tmp/report.txt")
def test_discord_extract_images_supports_attachment_content_type(self):
images = DiscordModule._extract_images(
{
@@ -380,6 +483,26 @@ class AgentImageSupportTest(unittest.TestCase):
],
)
def test_discord_extract_files_supports_non_media_attachment(self):
    """
    A PDF attachment is returned by DiscordModule._extract_files with its
    file name and a discord://file/ ref holding the percent-encoded URL.
    """
    pdf_url = "https://cdn.discordapp.com/guide.pdf"
    pdf_attachment = {
        "content_type": "application/pdf",
        "filename": "guide.pdf",
        "url": "https://cdn.discordapp.com/guide.pdf",
        "size": 1024,
    }

    files = DiscordModule._extract_files({"attachments": [pdf_attachment]})

    first = files[0]
    self.assertEqual(first.name, "guide.pdf")
    self.assertEqual(first.ref, "discord://file/" + quote(pdf_url, safe=""))
def test_discord_send_direct_message_returns_chat_id(self):
module = DiscordModule()
client = Mock()
@@ -466,6 +589,46 @@ class AgentImageSupportTest(unittest.TestCase):
self.assertIsNotNone(message)
self.assertEqual(message.images, ["wxwork://media_id/media-1"])
def test_wechat_message_parser_extracts_file_media_id(self):
    """
    A decrypted WeChat "file" message yields a CommingMessage whose first
    file ref is built from the MediaId.
    """
    module = WechatModule()
    # The literal is kept flush-left so its bytes stay exactly as they were.
    xml_message = b"""
<xml>
<FromUserName><![CDATA[user-1]]></FromUserName>
<MsgType><![CDATA[file]]></MsgType>
<MediaId><![CDATA[file-media-1]]></MediaId>
<FileName><![CDATA[manual.pdf]]></FileName>
</xml>
"""
    crypt = Mock()
    crypt.DecryptMsg.return_value = (0, xml_message)

    wechat_config = SimpleNamespace(
        name="wechat-test",
        config={
            "WECHAT_TOKEN": "token",
            "WECHAT_ENCODING_AESKEY": "encoding",
            "WECHAT_CORPID": "corpid",
        },
    )
    wechat_client = SimpleNamespace(send_msg=Mock())

    config_patch = patch.object(module, "get_config", return_value=wechat_config)
    instance_patch = patch.object(module, "get_instance", return_value=wechat_client)
    crypt_patch = patch("app.modules.wechat.WXBizMsgCrypt", return_value=crypt)
    with config_patch, instance_patch, crypt_patch:
        message = module.message_parser(
            source="wechat-test",
            body=b"encrypted",
            form={},
            args={"msg_signature": "sig", "timestamp": "1", "nonce": "n"},
        )

    self.assertIsNotNone(message)
    self.assertEqual(message.files[0].ref, "wxwork://file_media_id/file-media-1")
def test_wechat_bot_parser_accepts_image_only_payload(self):
module = WechatModule()
body = json.dumps(
@@ -594,6 +757,38 @@ class AgentImageSupportTest(unittest.TestCase):
["vocechat://file/%2Fuploads%2Fvoice.ogg"],
)
def test_vocechat_message_parser_extracts_generic_file_payload(self):
    """
    A VoceChat "vocechat/file" payload with a PDF MIME type is parsed into
    a message whose first file ref holds the percent-encoded path.
    """
    module = VoceChatModule()
    webhook_payload = {
        "detail": {
            "type": "normal",
            "content_type": "vocechat/file",
            "content": "/uploads/manual.pdf",
            "properties": {"content_type": "application/pdf"},
        },
        "from_uid": 7910,
        "target": {"gid": 2},
    }
    body = json.dumps(webhook_payload)
    vocechat_config = SimpleNamespace(
        name="vocechat-test", config={"channel_id": "2"}
    )

    with patch.object(module, "get_config", return_value=vocechat_config):
        message = module.message_parser(
            source="vocechat-test",
            body=body,
            form={},
            args={},
        )

    self.assertIsNotNone(message)
    self.assertEqual(message.files[0].ref, "vocechat://file/%2Fuploads%2Fmanual.pdf")
def test_vocechat_post_message_passes_image_and_correct_target(self):
module = VoceChatModule()
client = Mock()
@@ -623,6 +818,64 @@ class AgentImageSupportTest(unittest.TestCase):
link=None,
)
def test_slack_post_message_passes_local_file(self):
    """
    Posting a Notification that has a file_path through SlackModule must
    call the client's send_file exactly once.
    """
    module = SlackModule()
    client = Mock()

    with tempfile.TemporaryDirectory() as tempdir:
        file_path = Path(tempdir) / "guide.pdf"
        file_path.write_bytes(b"pdf")

        configs_patch = patch.object(
            module,
            "get_configs",
            return_value={"slack-test": SimpleNamespace(name="slack-test")},
        )
        check_patch = patch.object(module, "check_message", return_value=True)
        instance_patch = patch.object(module, "get_instance", return_value=client)
        with configs_patch, check_patch, instance_patch:
            notification = Notification(
                title="手册",
                text="请下载",
                file_path=str(file_path),
                file_name="guide.pdf",
                userid="U123",
            )
            module.post_message(notification)

    client.send_file.assert_called_once()
def test_discord_post_message_passes_local_file(self):
    """
    Posting a Notification that has a file_path through DiscordModule must
    call the client's send_file exactly once.
    """
    module = DiscordModule()
    client = Mock()

    with tempfile.TemporaryDirectory() as tempdir:
        file_path = Path(tempdir) / "guide.pdf"
        file_path.write_bytes(b"pdf")

        configs_patch = patch.object(
            module,
            "get_configs",
            return_value={"discord-test": SimpleNamespace(name="discord-test")},
        )
        check_patch = patch.object(module, "check_message", return_value=True)
        instance_patch = patch.object(module, "get_instance", return_value=client)
        with configs_patch, check_patch, instance_patch:
            notification = Notification(
                title="手册",
                text="请下载",
                file_path=str(file_path),
                file_name="guide.pdf",
                userid="user-1",
            )
            module.post_message(notification)

    client.send_file.assert_called_once()
def test_qq_message_parser_accepts_image_only_attachment(self):
module = QQBotModule()
@@ -745,5 +998,97 @@ class AgentImageSupportTest(unittest.TestCase):
["synology://file/" + quote("https://example.com/voice.ogg", safe="")],
)
def test_synology_message_parser_accepts_generic_file_attachment(self):
    """
    A Synology Chat form whose "attachments" JSON holds a PDF is parsed
    into a message whose first file ref holds the percent-encoded URL.
    """
    module = SynologyChatModule()
    pdf_url = "https://example.com/manual.pdf"
    form = {
        "token": "token-1",
        "user_id": "42",
        "username": "tester",
        "attachments": json.dumps(
            [
                {
                    "url": "https://example.com/manual.pdf",
                    "content_type": "application/pdf",
                    "filename": "manual.pdf",
                }
            ]
        ),
    }
    synology_config = SimpleNamespace(name="synology-test", config={})
    # Stand-in client that only accepts the token used in the form above.
    synology_client = SimpleNamespace(check_token=lambda token: token == "token-1")

    config_patch = patch.object(module, "get_config", return_value=synology_config)
    instance_patch = patch.object(
        module, "get_instance", return_value=synology_client
    )
    with config_patch, instance_patch:
        message = module.message_parser(
            source="synology-test",
            body={},
            form=form,
            args={},
        )

    self.assertIsNotNone(message)
    self.assertEqual(
        message.files[0].ref,
        "synology://file/" + quote(pdf_url, safe=""),
    )
def test_prepare_agent_files_saves_local_file(self):
    """
    _prepare_agent_files must report the downloaded attachment as "ready"
    and point local_path at a file that exists on disk.
    """
    chain = MessageChain()
    attachment = CommingMessage.MessageAttachment(
        ref="tg://document_file_id/doc-1",
        name="note.txt",
        mime_type="text/plain",
    )

    with tempfile.TemporaryDirectory() as tempdir:
        temp_path_patch = patch.object(settings, "TEMP_PATH", Path(tempdir))
        download_patch = patch.object(
            chain,
            "_download_message_file_bytes",
            return_value="你好MoviePilot".encode("utf-8"),
        )
        with temp_path_patch, download_patch:
            prepared = chain._prepare_agent_files(
                session_id="session-1",
                files=[attachment],
                channel=MessageChannel.Telegram,
                source="telegram-test",
            )

            # Checked while the temporary directory is still alive.
            first = prepared[0]
            self.assertEqual(first["status"], "ready")
            self.assertTrue(Path(first["local_path"]).exists())
def test_telegram_post_message_passes_file_to_client(self):
    """
    Posting a Notification that has a file_path through TelegramModule must
    call the client's send_file exactly once.
    """
    module = TelegramModule()
    client = Mock()

    with tempfile.TemporaryDirectory() as tempdir:
        file_path = Path(tempdir) / "report.txt"
        file_path.write_text("hello", encoding="utf-8")

        configs_patch = patch.object(
            module,
            "get_configs",
            return_value={"telegram-test": SimpleNamespace(name="telegram-test")},
        )
        check_patch = patch.object(module, "check_message", return_value=True)
        instance_patch = patch.object(module, "get_instance", return_value=client)
        with configs_patch, check_patch, instance_patch:
            notification = Notification(
                title="报告",
                text="请下载",
                file_path=str(file_path),
                file_name="report.txt",
                userid="user-1",
            )
            module.post_message(notification)

    client.send_file.assert_called_once()
# Run the test suite when this module is executed directly.
if __name__ == "__main__":
    unittest.main()