diff --git a/app/agent/__init__.py b/app/agent/__init__.py index dae14d29..974d8285 100644 --- a/app/agent/__init__.py +++ b/app/agent/__init__.py @@ -1,4 +1,5 @@ import asyncio +import json import re import traceback import uuid @@ -281,13 +282,19 @@ class MoviePilotAgent: logger.error(f"创建 Agent 失败: {e}") raise e - async def process(self, message: str, images: List[str] = None) -> str: + async def process( + self, + message: str, + images: List[str] = None, + files: Optional[List[dict]] = None, + ) -> str: """ 处理用户消息,流式推理并返回 Agent 回复 """ try: logger.info( - f"Agent推理: session_id={self.session_id}, input={message}, images={len(images) if images else 0}" + f"Agent推理: session_id={self.session_id}, input={message}, " + f"images={len(images) if images else 0}, files={len(files) if files else 0}" ) self._tool_context = { "incoming_voice": self.reply_with_voice, @@ -300,16 +307,24 @@ class MoviePilotAgent: session_id=self.session_id, user_id=self.user_id ) - # 构建用户消息内容 - if images: - content = [] - if message: - content.append({"type": "text", "text": message}) - for img in images: - content.append({"type": "image_url", "image_url": {"url": img}}) - messages.append(HumanMessage(content=content)) - else: - messages.append(HumanMessage(content=message)) + # 构建结构化用户消息内容 + request_payload = { + "message": message or "", + "images": [ + {"index": index + 1, "type": "image"} + for index, _ in enumerate(images or []) + ], + "files": files or [], + } + content = [ + { + "type": "text", + "text": json.dumps(request_payload, ensure_ascii=False, indent=2), + } + ] + for img in images or []: + content.append({"type": "image_url", "image_url": {"url": img}}) + messages.append(HumanMessage(content=content)) # 执行推理 await self._execute_agent(messages) @@ -544,6 +559,7 @@ class _MessageTask: user_id: str message: str images: Optional[List[str]] = None + files: Optional[List[dict]] = None channel: Optional[str] = None source: Optional[str] = None username: Optional[str] = None @@ -610,6 +626,7 @@ class AgentManager: user_id: str, message: str, images: List[str] = None, + files: Optional[List[dict]] = None, channel: str = None, source: str = None, username: str = None, @@ -624,6 +641,7 @@ class AgentManager: user_id=user_id, message=message, images=images, + files=files, channel=channel, source=source, username=username, @@ -727,7 +745,7 @@ class AgentManager: agent.username = task.username agent.reply_with_voice = task.reply_with_voice - return await agent.process(task.message, images=task.images) + return await agent.process(task.message, images=task.images, files=task.files) async def stop_current_task(self, session_id: str): """ diff --git a/app/agent/prompt/Agent Prompt.txt b/app/agent/prompt/Agent Prompt.txt index 2d4e9c4b..8672f93a 100644 --- a/app/agent/prompt/Agent Prompt.txt +++ b/app/agent/prompt/Agent Prompt.txt @@ -10,6 +10,7 @@ Core Capabilities: 3. Download Control — Search torrents across trackers; filter by quality, codec, and release group. 4. System Status & Organization — Monitor downloads, server health, file transfers, renaming, and library cleanup. 5. Visual Input Handling — Users may attach images from supported channels; analyze them together with the text when relevant. +6. File Context Handling — User messages may arrive as structured JSON. Treat the `message` field as the user's text. Non-image attachments appear in `files`; when `local_path` is present, use local file tools to inspect the uploaded file directly. {verbose_spec} @@ -21,6 +22,7 @@ Core Capabilities: - Include key details (year, rating, resolution) but do NOT over-explain. - Do not stop for approval on read-only operations. Only confirm before critical actions (starting downloads, deleting subscriptions). - If the current channel supports image sending and an image would materially help, you may use the `send_message` tool with `image_url` to send it. +- If the current channel supports file sending and you need to return a local image/file for the user to download, use `send_local_file`. - Voice replies: {voice_reply_spec} - NOT a coding assistant. Do not offer code snippets. - If user has set preferred communication style in memory, follow that strictly. diff --git a/app/agent/tools/factory.py b/app/agent/tools/factory.py index f54251af..a5e3ad3f 100644 --- a/app/agent/tools/factory.py +++ b/app/agent/tools/factory.py @@ -30,6 +30,7 @@ from app.agent.tools.impl.search_torrents import SearchTorrentsTool from app.agent.tools.impl.get_search_results import GetSearchResultsTool from app.agent.tools.impl.search_web import SearchWebTool from app.agent.tools.impl.send_message import SendMessageTool +from app.agent.tools.impl.send_local_file import SendLocalFileTool from app.agent.tools.impl.send_voice_message import SendVoiceMessageTool from app.agent.tools.impl.query_schedulers import QuerySchedulersTool from app.agent.tools.impl.run_scheduler import RunSchedulerTool @@ -119,6 +120,7 @@ class MoviePilotToolFactory: QueryTransferHistoryTool, TransferFileTool, SendMessageTool, + SendLocalFileTool, SendVoiceMessageTool, QuerySchedulersTool, RunSchedulerTool, diff --git a/app/agent/tools/impl/send_local_file.py b/app/agent/tools/impl/send_local_file.py new file mode 100644 index 00000000..8f34edd1 --- /dev/null +++ b/app/agent/tools/impl/send_local_file.py @@ -0,0 +1,107 @@ +"""发送本地附件工具。""" + +from pathlib import Path +from typing import Optional, Type + +from pydantic import BaseModel, Field, model_validator + +from app.agent.tools.base import MoviePilotTool, ToolChain +from app.log import logger +from app.schemas import Notification, NotificationType +from app.schemas.message import ChannelCapabilityManager, ChannelCapability +from app.schemas.types import MessageChannel + + +class SendLocalFileInput(BaseModel): + """发送本地附件工具输入。""" + + explanation: str = Field( + ..., + description="Clear explanation of why sending this local file helps the user", + ) + file_path: str = Field( + ..., + description="Absolute path to the local image or file to send to the user", + ) + message: Optional[str] = Field( + None, + description="Optional message or caption to send with the attachment", + ) + title: Optional[str] = Field( + None, + description="Optional short title shown together with the attachment", + ) + file_name: Optional[str] = Field( + None, + description="Optional override filename presented to the user when downloading", + ) + + @model_validator(mode="after") + def validate_file_path(self): + if not self.file_path: + raise ValueError("file_path 不能为空") + return self + + +class SendLocalFileTool(MoviePilotTool): + name: str = "send_local_file" + description: str = ( + "Send a local image or file from the server filesystem to the current user. " + "Use this when you have generated or identified a local file the user should download." + ) + args_schema: Type[BaseModel] = SendLocalFileInput + require_admin: bool = False + + def get_tool_message(self, **kwargs) -> Optional[str]: + file_path = kwargs.get("file_path", "") + file_name = Path(file_path).name if file_path else "未知文件" + return f"正在发送本地附件: {file_name}" + + async def run( + self, + file_path: str, + message: Optional[str] = None, + title: Optional[str] = None, + file_name: Optional[str] = None, + **kwargs, + ) -> str: + if not self._channel or not self._source: + return "当前不在可回传消息的会话中,无法发送附件" + + try: + channel = MessageChannel(self._channel) + except ValueError: + return f"不支持的消息渠道: {self._channel}" + + if not ChannelCapabilityManager.supports_capability( + channel, ChannelCapability.FILE_SENDING + ): + return f"当前渠道 {channel.value} 暂不支持发送本地文件" + + resolved_path = Path(file_path).expanduser() + if not resolved_path.is_absolute(): + resolved_path = resolved_path.resolve() + if not resolved_path.exists() or not resolved_path.is_file(): + return f"文件不存在: {resolved_path}" + + logger.info( + "执行工具: %s, channel=%s, file=%s", + self.name, + channel.value, + resolved_path, + ) + + await ToolChain().async_post_message( + Notification( + channel=channel, + source=self._source, + mtype=NotificationType.Agent, + userid=self._user_id, + username=self._username, + title=title, + text=message, + file_path=str(resolved_path), + file_name=file_name or resolved_path.name, + ) + ) + return "本地附件已发送" diff --git a/app/agent/tools/impl/send_message.py b/app/agent/tools/impl/send_message.py index 9e2dadc7..e4a17d2f 100644 --- a/app/agent/tools/impl/send_message.py +++ b/app/agent/tools/impl/send_message.py @@ -19,7 +19,7 @@ class SendMessageInput(BaseModel): None, description="The message content to send to the user (should be clear and informative)", ) - message_type: Optional[str] = Field( + title: Optional[str] = Field( None, description="Title of the message, a short summary of the message content", ) @@ -30,8 +30,8 @@ class SendMessageInput(BaseModel): @model_validator(mode="after") def validate_payload(self): - if not self.message and not self.message_type and not self.image_url: - raise ValueError("message、message_type、image_url 至少需要提供一个") + if not self.message and not self.title and not self.image_url: + raise ValueError("message、title、image_url 至少需要提供一个") return self @@ -44,7 +44,7 @@ class SendMessageTool(MoviePilotTool): def get_tool_message(self, **kwargs) -> Optional[str]: """根据消息参数生成友好的提示消息""" message = kwargs.get("message", "") or "" - title = kwargs.get("message_type") or "" + title = kwargs.get("title") or "" image_url = kwargs.get("image_url") # 截断过长的消息 @@ -62,11 +62,11 @@ class SendMessageTool(MoviePilotTool): async def run( self, message: Optional[str] = None, - message_type: Optional[str] = None, + title: Optional[str] = None, image_url: Optional[str] = None, **kwargs, ) -> str: - title = message_type or ("图片" if image_url and not message else "") + title = title or ("图片" if image_url and not message else "") text = message or "" logger.info( f"执行工具: {self.name}, 参数: title={title}, message={text}, image_url={image_url}" diff --git a/app/chain/message.py b/app/chain/message.py index 4c717e17..8af455c5 100644 --- a/app/chain/message.py +++ b/app/chain/message.py @@ -2,8 +2,10 @@ import asyncio import re import time from datetime import datetime, timedelta +from pathlib import Path from typing import Any, Optional, Dict, Union, List from urllib.parse import unquote +import uuid import base64 @@ -135,7 +137,8 @@ class MessageChain(ChainBase): text = str(info.text).strip() if info.text else "" images = info.images audio_refs = info.audio_refs - if not text and not images and not audio_refs: + files = info.files + if not text and not images and not audio_refs and not files: logger.debug(f"未识别到消息内容::{body}{form}{args}") return @@ -154,6 +157,7 @@ class MessageChain(ChainBase): original_chat_id=original_chat_id, images=images, audio_refs=audio_refs, + files=files, ) def handle_message( @@ -167,6 +171,7 @@ class MessageChain(ChainBase): original_chat_id: Optional[str] = None, images: Optional[List[str]] = None, audio_refs: Optional[List[str]] = None, + files: Optional[List[CommingMessage.MessageAttachment]] = None, ) -> None: """ 识别消息内容,执行操作 @@ -253,9 +258,12 @@ class MessageChain(ChainBase): userid=userid, username=username, images=images, + files=files, reply_with_voice=reply_with_voice, ) - elif settings.AI_AGENT_ENABLE and settings.AI_AGENT_GLOBAL: + elif settings.AI_AGENT_ENABLE and ( + settings.AI_AGENT_GLOBAL or images or files + ): # 普通消息,全局智能体响应 self._handle_ai_message( text=text, @@ -264,6 +272,7 @@ class MessageChain(ChainBase): userid=userid, username=username, images=images, + files=files, reply_with_voice=reply_with_voice, ) else: @@ -1230,6 +1239,7 @@ class MessageChain(ChainBase): userid: Union[str, int], username: str, images: Optional[List[str]] = None, + files: Optional[List[CommingMessage.MessageAttachment]] = None, reply_with_voice: bool = False, ) -> None: """ @@ -1255,7 +1265,7 @@ class MessageChain(ChainBase): else: user_message = text.strip() # 按原消息处理 - if not user_message and not images: + if not user_message and not images and not files: self.post_message( Notification( channel=channel, @@ -1274,7 +1284,7 @@ class MessageChain(ChainBase): original_images = images if images: images = self._download_images_to_base64(images, channel, source) - if original_images and not images and not user_message: + if original_images and not images and not user_message and not files: self.post_message( Notification( channel=channel, @@ -1286,6 +1296,24 @@ class MessageChain(ChainBase): ) return + prepared_files = self._prepare_agent_files( + session_id=session_id, + files=files, + channel=channel, + source=source, + ) + if files and not prepared_files and not user_message and not images: + self.post_message( + Notification( + channel=channel, + source=source, + userid=userid, + username=username, + title="文件读取失败,请稍后重试", + ) + ) + return + # 在事件循环中处理 asyncio.run_coroutine_threadsafe( agent_manager.process_message( @@ -1293,6 +1321,7 @@ class MessageChain(ChainBase): user_id=str(userid), message=user_message, images=images, + files=prepared_files, channel=channel.value if channel else None, source=source, username=username, @@ -1481,3 +1510,148 @@ class MessageChain(ChainBase): except Exception as e: logger.error(f"下载图片失败: {img}, error: {e}") return base64_images if base64_images else None + + def _prepare_agent_files( + self, + session_id: str, + files: Optional[List[CommingMessage.MessageAttachment]], + channel: MessageChannel, + source: str, + ) -> Optional[List[dict]]: + """ + 下载用户上传的文件,落盘到临时目录,并生成文本镜像供 Agent 使用。 + """ + if not files: + return None + + prepared_files = [] + for attachment in files: + payload = { + "name": attachment.name, + "mime_type": attachment.mime_type, + "size": attachment.size, + "ref": attachment.ref, + "status": "download_failed", + } + try: + content = self._download_message_file_bytes( + file_ref=attachment.ref, + channel=channel, + source=source, + ) + if not content: + prepared_files.append(payload) + continue + + local_path = self._save_agent_attachment( + session_id=session_id, + filename=attachment.name, + content=content, + mime_type=attachment.mime_type, + ) + payload.update( + { + "local_path": str(local_path), + "status": "ready", + } + ) + except Exception as err: + logger.error(f"准备文件上下文失败: {attachment.ref}, error: {err}") + payload["error"] = str(err) + prepared_files.append(payload) + + return prepared_files or None + + def _download_message_file_bytes( + self, file_ref: str, channel: MessageChannel, source: str + ) -> Optional[bytes]: + """ + 下载消息附件的原始字节。 + """ + if not file_ref: + return None + if file_ref.startswith("tg://document_file_id/"): + file_id = file_ref.replace("tg://document_file_id/", "", 1) + return self.run_module( + "download_telegram_file_bytes", file_id=file_id, source=source + ) + if file_ref.startswith("wxwork://file_media_id/"): + return self.run_module( + "download_wechat_media_bytes", media_ref=file_ref, source=source + ) + if file_ref.startswith("wxbot://file/"): + file_url = unquote(file_ref.replace("wxbot://file/", "", 1)) + resp = RequestUtils(timeout=30).get_res(file_url) + return resp.content if resp and resp.content else None + if file_ref.startswith("slack://file/"): + return self.run_module( + "download_slack_file_bytes", file_ref=file_ref, source=source + ) + if file_ref.startswith("discord://file/"): + return self.run_module( + "download_discord_file_bytes", file_ref=file_ref, source=source + ) + if file_ref.startswith("qq://file/"): + return self.run_module( + "download_qq_file_bytes", file_ref=file_ref, source=source + ) + if file_ref.startswith("vocechat://file/"): + return self.run_module( + "download_vocechat_file_bytes", file_ref=file_ref, source=source + ) + if file_ref.startswith("synology://file/"): + return self.run_module( + "download_synologychat_file_bytes", file_ref=file_ref, source=source + ) + if file_ref.startswith("http"): + resp = RequestUtils(timeout=30).get_res(file_ref) + return resp.content if resp and resp.content else None + logger.debug( + "暂不支持的文件引用: channel=%s, source=%s, ref=%s", + channel.value if channel else None, + source, + file_ref, + ) + return None + + def _save_agent_attachment( + self, + session_id: str, + filename: Optional[str], + content: bytes, + mime_type: Optional[str] = None, + ) -> Path: + """ + 将用户上传文件写入临时目录,并返回本地路径。 + """ + safe_name = self._sanitize_attachment_name(filename, mime_type) + base_dir = settings.TEMP_PATH / "agent_uploads" / session_id + base_dir.mkdir(parents=True, exist_ok=True) + + file_id = uuid.uuid4().hex[:8] + local_path = base_dir / f"{file_id}_{safe_name}" + local_path.write_bytes(content or b"") + return local_path + + @staticmethod + def _sanitize_attachment_name( + filename: Optional[str], mime_type: Optional[str] = None + ) -> str: + """ + 规范化附件文件名,避免路径穿越和非法字符。 + """ + name = Path(filename or "attachment").name + name = re.sub(r"[^\w.\-]+", "_", name, flags=re.ASCII).strip("._") + if not name: + name = "attachment" + if "." not in name: + mime = (mime_type or "").split(";", 1)[0].strip().lower() + default_ext = { + "application/json": ".json", + "text/plain": ".txt", + "text/markdown": ".md", + "text/csv": ".csv", + }.get(mime) + if default_ext: + name = f"{name}{default_ext}" + return name diff --git a/app/modules/discord/__init__.py b/app/modules/discord/__init__.py index 9efb46f9..9c469c30 100644 --- a/app/modules/discord/__init__.py +++ b/app/modules/discord/__init__.py @@ -159,11 +159,13 @@ class DiscordModule(_ModuleBase, _MessageBase[Discord]): chat_id = msg_json.get("chat_id") images = self._extract_images(msg_json) audio_refs = self._extract_audio_refs(msg_json) - if (text or images or audio_refs) and userid: + files = self._extract_files(msg_json) + if (text or images or audio_refs or files) and userid: logger.info( f"收到来自 {client_config.name} 的 Discord 消息:" f"userid={userid}, username={username}, text={text}, " - f"images={len(images) if images else 0}, audios={len(audio_refs) if audio_refs else 0}" + f"images={len(images) if images else 0}, audios={len(audio_refs) if audio_refs else 0}, " + f"files={len(files) if files else 0}" ) return CommingMessage( channel=MessageChannel.Discord, @@ -174,6 +176,7 @@ class DiscordModule(_ModuleBase, _MessageBase[Discord]): chat_id=str(chat_id) if chat_id else None, images=images, audio_refs=audio_refs, + files=files, ) return None @@ -219,6 +222,44 @@ class DiscordModule(_ModuleBase, _MessageBase[Discord]): audio_refs.append(f"discord://file/{quote(url, safe='')}") return audio_refs if audio_refs else None + @classmethod + def _extract_files( + cls, msg_json: dict + ) -> Optional[List[CommingMessage.MessageAttachment]]: + """ + 从 Discord 消息中提取非图片/非音频文件。 + """ + attachments = msg_json.get("attachments", []) + if not attachments: + return None + + files = [] + for attachment in attachments: + url = attachment.get("url") or attachment.get("proxy_url") + if not url: + continue + content_type = (attachment.get("content_type") or "").lower() + filename = (attachment.get("filename") or "").lower() + is_image = ( + attachment.get("type") == "image" + or content_type.startswith("image/") + or filename.endswith(cls._IMAGE_SUFFIXES) + ) + is_audio = content_type.startswith("audio/") or filename.endswith( + cls._AUDIO_SUFFIXES + ) + if is_image or is_audio: + continue + files.append( + CommingMessage.MessageAttachment( + ref=f"discord://file/{quote(url, safe='')}", + name=attachment.get("filename"), + mime_type=attachment.get("content_type"), + size=attachment.get("size"), + ) + ) + return files or None + def download_discord_file_bytes(self, file_ref: str, source: str) -> Optional[bytes]: """ 下载Discord附件并返回原始字节 @@ -278,19 +319,29 @@ class DiscordModule(_ModuleBase, _MessageBase[Discord]): ) if client: logger.debug( - f"[Discord] 调用 client.send_msg, userid={userid}, title={message.title[:50] if message.title else None}..." - ) - result = client.send_msg( - title=message.title, - text=message.text, - image=message.image, - userid=userid, - link=message.link, - buttons=message.buttons, - original_message_id=message.original_message_id, - original_chat_id=message.original_chat_id, - mtype=message.mtype, + f"[Discord] 调用 client 发送, userid={userid}, title={message.title[:50] if message.title else None}..." ) + if message.file_path: + result = client.send_file( + file_path=message.file_path, + file_name=message.file_name, + title=message.title, + text=message.text, + userid=userid, + original_chat_id=message.original_chat_id, + ) + else: + result = client.send_msg( + title=message.title, + text=message.text, + image=message.image, + userid=userid, + link=message.link, + buttons=message.buttons, + original_message_id=message.original_message_id, + original_chat_id=message.original_chat_id, + mtype=message.mtype, + ) logger.debug(f"[Discord] send_msg 返回结果: {result}") else: logger.warning( @@ -427,11 +478,20 @@ class DiscordModule(_ModuleBase, _MessageBase[Discord]): return None client: Discord = self.get_instance(conf.name) if client: - result = client.send_msg( - title=message.title or "", - text=message.text, - userid=userid, - ) + if message.file_path: + result = client.send_file( + file_path=message.file_path, + file_name=message.file_name, + title=message.title, + text=message.text, + userid=userid, + ) + else: + result = client.send_msg( + title=message.title or "", + text=message.text, + userid=userid, + ) if result: success, response_data = ( (result[0], result[1]) diff --git a/app/modules/discord/discord.py b/app/modules/discord/discord.py index e2ddb858..aa4a3953 100644 --- a/app/modules/discord/discord.py +++ b/app/modules/discord/discord.py @@ -1,6 +1,7 @@ import asyncio import re import threading +from pathlib import Path from typing import Optional, List, Dict, Any, Tuple, Union from urllib.parse import quote @@ -273,6 +274,37 @@ class Discord: logger.error(f"发送 Discord 消息失败:{err}") return False + def send_file( + self, + file_path: str, + title: Optional[str] = None, + text: Optional[str] = None, + userid: Optional[str] = None, + file_name: Optional[str] = None, + original_chat_id: Optional[str] = None, + ) -> Optional[bool]: + if not self.get_state(): + return False + if not file_path: + return False + + try: + future = asyncio.run_coroutine_threadsafe( + self._send_file( + file_path=file_path, + title=title, + text=text, + userid=userid, + file_name=file_name, + original_chat_id=original_chat_id, + ), + self._loop, + ) + return future.result(timeout=30) + except Exception as err: + logger.error(f"发送 Discord 文件失败:{err}") + return False + def send_medias_msg( self, medias: List[MediaInfo], @@ -414,6 +446,46 @@ class Discord: logger.error(f"[Discord] 发送消息到频道失败: {e}") return False, None + async def _send_file( + self, + file_path: str, + title: Optional[str], + text: Optional[str], + userid: Optional[str], + file_name: Optional[str], + original_chat_id: Optional[str], + ) -> Tuple[bool, Optional[Dict[str, str]]]: + channel = await self._resolve_channel(userid=userid, chat_id=original_chat_id) + if not channel: + logger.error("未找到可用的 Discord 频道或私聊") + return False, None + + local_file = Path(file_path) + if not local_file.exists() or not local_file.is_file(): + logger.error(f"Discord发送文件失败,文件不存在: {local_file}") + return False, None + + content_parts = [part for part in [title, text] if part] + content = "\n".join(content_parts) if content_parts else None + if content and len(content) > 1900: + content = content[:1900] + "..." + + try: + discord_file = discord.File( + str(local_file), filename=file_name or local_file.name + ) + sent_message = await channel.send(content=content, file=discord_file) + return ( + True, + { + "message_id": str(sent_message.id), + "chat_id": str(channel.id), + }, + ) + except Exception as err: + logger.error(f"Discord发送文件失败: {err}") + return False, None + async def _send_list_message( self, embeds: List[discord.Embed], diff --git a/app/modules/qqbot/__init__.py b/app/modules/qqbot/__init__.py index 8ef699f7..eb76e1f3 100644 --- a/app/modules/qqbot/__init__.py +++ b/app/modules/qqbot/__init__.py @@ -107,7 +107,8 @@ class QQBotModule(_ModuleBase, _MessageBase[QQBot]): content = (msg_body.get("content") or "").strip() images = self._extract_images(msg_body) audio_refs = self._extract_audio_refs(msg_body) - if not content and not images and not audio_refs: + files = self._extract_files(msg_body) + if not content and not images and not audio_refs and not files: return None if msg_type == "C2C_MESSAGE_CREATE": @@ -118,7 +119,7 @@ class QQBotModule(_ModuleBase, _MessageBase[QQBot]): logger.info( f"收到 QQ 私聊消息: userid={user_openid}, " f"text={(content or '')[:50]}..., images={len(images) if images else 0}, " - f"audios={len(audio_refs) if audio_refs else 0}" + f"audios={len(audio_refs) if audio_refs else 0}, files={len(files) if files else 0}" ) return CommingMessage( channel=MessageChannel.QQ, @@ -128,6 +129,7 @@ class QQBotModule(_ModuleBase, _MessageBase[QQBot]): text=content, images=images, audio_refs=audio_refs, + files=files, ) elif msg_type == "GROUP_AT_MESSAGE_CREATE": author = msg_body.get("author", {}) @@ -138,7 +140,7 @@ class QQBotModule(_ModuleBase, _MessageBase[QQBot]): logger.info( f"收到 QQ 群消息: group={group_openid}, userid={member_openid}, " f"text={(content or '')[:50]}..., images={len(images) if images else 0}, " - f"audios={len(audio_refs) if audio_refs else 0}" + f"audios={len(audio_refs) if audio_refs else 0}, files={len(files) if files else 0}" ) return CommingMessage( channel=MessageChannel.QQ, @@ -148,6 +150,7 @@ class QQBotModule(_ModuleBase, _MessageBase[QQBot]): text=content, images=images, audio_refs=audio_refs, + files=files, ) return None @@ -226,6 +229,46 @@ class QQBotModule(_ModuleBase, _MessageBase[QQBot]): deduped.append(audio_ref) return deduped or None + @classmethod + def _extract_files( + cls, msg_body: dict + ) -> Optional[List[CommingMessage.MessageAttachment]]: + files: List[CommingMessage.MessageAttachment] = [] + attachments = msg_body.get("attachments") or [] + if isinstance(attachments, list): + for attachment in attachments: + if not isinstance(attachment, dict): + continue + url = attachment.get("url") or attachment.get("proxy_url") + if not url: + continue + content_type = ( + attachment.get("content_type") + or attachment.get("mime_type") + or "" + ).lower() + filename = ( + attachment.get("filename") or attachment.get("name") or "" + ).lower() + is_image = content_type.startswith("image/") or filename.endswith( + cls._IMAGE_SUFFIXES + ) + is_audio = content_type.startswith("audio/") or filename.endswith( + cls._AUDIO_SUFFIXES + ) + if is_image or is_audio: + continue + files.append( + CommingMessage.MessageAttachment( + ref=f"qq://file/{quote(url, safe='')}", + name=attachment.get("filename") or attachment.get("name"), + mime_type=attachment.get("content_type") + or attachment.get("mime_type"), + size=attachment.get("size"), + ) + ) + return files or None + def download_qq_file_bytes(self, file_ref: str, source: str) -> Optional[bytes]: """ 下载QQ音频附件并返回原始字节 diff --git a/app/modules/slack/__init__.py b/app/modules/slack/__init__.py index 722e1ad7..d84774de 100644 --- a/app/modules/slack/__init__.py +++ b/app/modules/slack/__init__.py @@ -221,12 +221,14 @@ class SlackModule(_ModuleBase, _MessageBase[Slack]): if msg_json: images = None audio_refs = None + files = None if msg_json.get("type") == "message": userid = msg_json.get("user") text = msg_json.get("text") username = msg_json.get("user") images = self._extract_images(msg_json) audio_refs = self._extract_audio_refs(msg_json) + files = self._extract_files(msg_json) elif msg_json.get("type") == "block_actions": userid = msg_json.get("user", {}).get("id") callback_data = msg_json.get("actions")[0].get("value") @@ -270,6 +272,7 @@ class SlackModule(_ModuleBase, _MessageBase[Slack]): username = "" images = self._extract_images(msg_json.get("event", {})) audio_refs = self._extract_audio_refs(msg_json.get("event", {})) + files = self._extract_files(msg_json.get("event", {})) elif msg_json.get("type") == "shortcut": userid = msg_json.get("user", {}).get("id") text = msg_json.get("callback_id") @@ -282,7 +285,8 @@ class SlackModule(_ModuleBase, _MessageBase[Slack]): return None logger.info( f"收到来自 {client_config.name} 的Slack消息:userid={userid}, username={username}, " - f"text={text}, images={len(images) if images else 0}, audios={len(audio_refs) if audio_refs else 0}" + f"text={text}, images={len(images) if images else 0}, audios={len(audio_refs) if audio_refs else 0}, " + f"files={len(files) if files else 0}" ) return CommingMessage( channel=MessageChannel.Slack, @@ -292,6 +296,7 @@ class SlackModule(_ModuleBase, _MessageBase[Slack]): text=text, images=images, audio_refs=audio_refs, + files=files, ) return None @@ -341,6 +346,48 @@ class SlackModule(_ModuleBase, _MessageBase[Slack]): audio_refs.append(f"slack://file/{quote(url, safe='')}") return audio_refs if audio_refs else None + @classmethod + def _extract_files( + cls, msg_json: dict + ) -> Optional[List[CommingMessage.MessageAttachment]]: + """ + 从 Slack 消息中提取非图片/非音频文件。 + """ + files = msg_json.get("files", []) + if not files: + return None + + attachments = [] + for file in files: + file_type = str(file.get("type", "")).lower() + file_ext = f".{str(file.get('filetype', '')).lower().lstrip('.')}" + mime_type = str(file.get("mimetype", "")).lower() + is_image = ( + file_type == "image" + or file_ext in (".jpg", ".jpeg", ".png", ".gif", ".webp", ".bmp") + or mime_type.startswith("image/") + ) + is_audio = ( + file_type == "audio" + or mime_type.startswith("audio/") + or file_ext in cls._AUDIO_SUFFIXES + ) + if is_image or is_audio: + continue + + url = file.get("url_private_download") or file.get("url_private") + if not url: + continue + attachments.append( + CommingMessage.MessageAttachment( + ref=f"slack://file/{quote(url, safe='')}", + name=file.get("name") or file.get("title"), + mime_type=file.get("mimetype"), + size=file.get("size"), + ) + ) + return attachments or None + def download_slack_file_to_data_url(self, file_url: str, source: str) -> Optional[str]: """ 下载Slack文件并转为data URL @@ -399,16 +446,25 @@ class SlackModule(_ModuleBase, _MessageBase[Slack]): return client: Slack = self.get_instance(conf.name) if client: - client.send_msg( - title=message.title, - text=message.text, - image=message.image, - userid=userid, - link=message.link, - buttons=message.buttons, - original_message_id=message.original_message_id, - original_chat_id=message.original_chat_id, - ) + if message.file_path: + client.send_file( + file_path=message.file_path, + file_name=message.file_name, + title=message.title, + text=message.text, + userid=userid, + ) + else: + client.send_msg( + title=message.title, + text=message.text, + image=message.image, + userid=userid, + link=message.link, + buttons=message.buttons, + original_message_id=message.original_message_id, + original_chat_id=message.original_chat_id, + ) def post_medias_message( self, message: Notification, medias: List[MediaInfo] @@ -538,26 +594,40 @@ class SlackModule(_ModuleBase, _MessageBase[Slack]): return None client: Slack = self.get_instance(conf.name) if client: - result = client.send_msg( - title=message.title or "", - text=message.text, - userid=userid, - ) + if message.file_path: + result = client.send_file( + file_path=message.file_path, + file_name=message.file_name, + title=message.title, + text=message.text, + userid=userid, + ) + else: + result = client.send_msg( + title=message.title or "", + text=message.text, + userid=userid, + ) if result and result[0]: # Slack 使用时间戳作为 message_id,chat_id 是频道ID # 注意:这里返回的是发送后的结果,需要获取实际的 message_id # 由于 Slack API 返回的是 result[1],包含完整响应,我们需要从中提取 response_data = result[1] - message_id = ( - response_data.get("ts") - if isinstance(response_data, dict) - else None - ) - channel_id = ( - response_data.get("channel") - if isinstance(response_data, dict) - else None - ) + message_id = None + channel_id = None + if hasattr(response_data, "get"): + message_id = response_data.get("ts") + channel_id = response_data.get("channel") + if not message_id and hasattr(response_data, "data"): + files = (response_data.data or {}).get("files") or [] + if files: + message_id = files[0].get("id") + shares = ( + files[0].get("shares", {}) + .get("private", {}) + ) + if shares: + channel_id = next(iter(shares.keys()), None) return MessageResponse( message_id=message_id, chat_id=channel_id, diff --git a/app/modules/slack/slack.py b/app/modules/slack/slack.py index 982dc717..77b95f94 100644 --- a/app/modules/slack/slack.py +++ b/app/modules/slack/slack.py @@ -1,5 +1,6 @@ import re from threading import Lock +from pathlib import Path from typing import List, Optional, Tuple from urllib.parse import quote @@ -246,6 +247,48 @@ class Slack: logger.error(f"Slack消息发送失败: {msg_e}") return False, str(msg_e) + def send_file( + self, + file_path: str, + title: Optional[str] = None, + text: Optional[str] = None, + userid: Optional[str] = None, + file_name: Optional[str] = None, + ): + """ + 发送本地文件到 Slack。 + """ + if not self._client: + return False, "消息客户端未就绪" + if not file_path: + return False, "文件路径不能为空" + + local_file = Path(file_path) + if not local_file.exists() or not local_file.is_file(): + return False, f"文件不存在: {local_file}" + + try: + if userid: + channel = userid + else: + channel = self.__find_public_channel() + + comment_parts = [part for part in [title, text] if part] + initial_comment = "\n".join(comment_parts) if comment_parts else None + + with local_file.open("rb") as fp: + result = self._client.files_upload_v2( + channel=channel, + file=fp, + filename=file_name or local_file.name, + title=title or (file_name or local_file.name), + initial_comment=initial_comment, + ) + return True, result + except Exception as err: + logger.error(f"Slack文件发送失败: {err}") + return False, str(err) + def send_medias_msg(self, medias: List[MediaInfo], userid: Optional[str] = None, title: Optional[str] = None, buttons: Optional[List[List[dict]]] = None, original_message_id: Optional[str] = None, diff --git a/app/modules/synologychat/__init__.py b/app/modules/synologychat/__init__.py index bb0390c9..179e90cd 100644 --- a/app/modules/synologychat/__init__.py +++ b/app/modules/synologychat/__init__.py @@ -125,15 +125,17 @@ class SynologyChatModule(_ModuleBase, _MessageBase[SynologyChat]): user_name = message.get("username") images = self._extract_images(message) audio_refs = self._extract_audio_refs(message) - if (text or images or audio_refs) and user_id: + files = self._extract_files(message) + if (text or images or audio_refs or files) and user_id: logger.info( f"收到来自 {client_config.name} 的SynologyChat消息:" f"userid={user_id}, username={user_name}, text={text}, " - f"images={len(images) if images else 0}, audios={len(audio_refs) if audio_refs else 0}" + f"images={len(images) if images else 0}, audios={len(audio_refs) if audio_refs else 0}, " + f"files={len(files) if files else 0}" ) return CommingMessage(channel=MessageChannel.SynologyChat, source=client_config.name, userid=user_id, username=user_name, text=text or "", - images=images, audio_refs=audio_refs) + images=images, audio_refs=audio_refs, files=files) except Exception as err: logger.debug(f"解析SynologyChat消息失败:{str(err)}") return None @@ -230,6 +232,56 @@ class SynologyChatModule(_ModuleBase, _MessageBase[SynologyChat]): suffix in lowered for suffix in cls._AUDIO_SUFFIXES ) + @classmethod + def _extract_files( + cls, message: dict + ) -> Optional[List[CommingMessage.MessageAttachment]]: + files = [] + for key in ("attachments", "files"): + raw_value = message.get(key) + if not raw_value: + continue + try: + parsed = json.loads(raw_value) if isinstance(raw_value, str) else raw_value + except Exception: + parsed = raw_value + items = parsed if isinstance(parsed, list) else [parsed] + for item in items: + if not isinstance(item, dict): + continue + url = item.get("url") or item.get("file_url") or item.get("download_url") + if not isinstance(url, str) or not url.startswith("http"): + continue + content_type = ( + item.get("content_type") or item.get("mime_type") or "" + ).lower() + name = (item.get("name") or item.get("filename") or "").lower() + is_image = content_type.startswith("image/") or name.endswith( + cls._IMAGE_SUFFIXES + ) or cls._looks_like_image(url) + is_audio = content_type.startswith("audio/") or name.endswith( + cls._AUDIO_SUFFIXES + ) or cls._looks_like_audio(url) + if is_image or is_audio: + continue + files.append( + CommingMessage.MessageAttachment( + ref=f"synology://file/{quote(url, safe='')}", + name=item.get("name") or item.get("filename"), + mime_type=item.get("content_type") or item.get("mime_type"), + size=item.get("size"), + ) + ) + + deduped = [] + seen_refs = set() + for file_item in files: + if file_item.ref in seen_refs: + continue + seen_refs.add(file_item.ref) + deduped.append(file_item) + return deduped or None + def download_synologychat_file_bytes(self, file_ref: str, source: str) -> Optional[bytes]: """ 下载 Synology Chat 音频文件并返回原始字节 diff --git a/app/modules/telegram/__init__.py b/app/modules/telegram/__init__.py index 79adf911..50385563 100644 --- a/app/modules/telegram/__init__.py +++ b/app/modules/telegram/__init__.py @@ -215,18 +215,20 @@ class TelegramModule(_ModuleBase, _MessageBase[Telegram]): images = self._extract_images(msg) audio_refs = self._extract_audio_refs(msg) + files = self._extract_files(msg) if user_id: - if not text and not images and not audio_refs: + if not text and not images and not audio_refs and not files: logger.debug( - f"收到来自 {client_config.name} 的Telegram消息无文本、图片和语音" + f"收到来自 {client_config.name} 的Telegram消息无文本、图片、语音和文件" ) return None logger.info( f"收到来自 {client_config.name} 的Telegram消息:" f"userid={user_id}, username={user_name}, chat_id={chat_id}, text={text}, " - f"images={len(images) if images else 0}, audios={len(audio_refs) if audio_refs else 0}" + f"images={len(images) if images else 0}, audios={len(audio_refs) if audio_refs else 0}, " + f"files={len(files) if files else 0}" ) cleaned_text = ( @@ -266,6 +268,7 @@ class TelegramModule(_ModuleBase, _MessageBase[Telegram]): chat_id=str(chat_id) if chat_id else None, images=images if images else None, audio_refs=audio_refs if audio_refs else None, + files=files if files else None, ) return None @@ -311,6 +314,29 @@ class TelegramModule(_ModuleBase, _MessageBase[Telegram]): return audio_refs if audio_refs else None + @staticmethod + def _extract_files(msg: dict) -> Optional[List[CommingMessage.MessageAttachment]]: + """ + 从 Telegram 消息中提取非图片文件附件。 + """ + document = msg.get("document") + if not isinstance(document, dict): + return None + + file_id = document.get("file_id") + mime_type = (document.get("mime_type") or "").lower() + if not file_id or mime_type.startswith("image/"): + return None + + return [ + CommingMessage.MessageAttachment( + ref=f"tg://document_file_id/{file_id}", + name=document.get("file_name"), + mime_type=document.get("mime_type"), + size=document.get("file_size"), + ) + ] + @staticmethod def _embed_entity_links(text: str, entities: Optional[List[dict]]) -> str: """ @@ -412,7 +438,16 @@ class TelegramModule(_ModuleBase, _MessageBase[Telegram]): return client: Telegram = self.get_instance(conf.name) if client: - if message.voice_path: + if message.file_path: + client.send_file( + file_path=message.file_path, + file_name=message.file_name, + title=message.title, + text=message.text, + userid=userid, + original_chat_id=message.original_chat_id, + ) + elif message.voice_path: client.send_voice( voice_path=message.voice_path, userid=userid, diff --git a/app/modules/telegram/telegram.py b/app/modules/telegram/telegram.py index 538a7069..da1e3909 100644 --- a/app/modules/telegram/telegram.py +++ b/app/modules/telegram/telegram.py @@ -507,6 +507,70 @@ class Telegram: except Exception as cleanup_err: logger.debug(f"清理语音临时文件失败: {cleanup_err}") + def send_file( + self, + file_path: str, + userid: Optional[str] = None, + title: Optional[str] = None, + text: Optional[str] = None, + file_name: Optional[str] = None, + original_chat_id: Optional[str] = None, + ) -> Optional[dict]: + """ + 发送本地图片或文件给 Telegram 用户。 + """ + if not self._bot or not file_path: + return None + + local_file = Path(file_path) + if not local_file.exists() or not local_file.is_file(): + logger.error(f"附件文件不存在: {local_file}") + return {"success": False} + + chat_id = self._determine_target_chat_id(userid, original_chat_id) + send_name = file_name or local_file.name + suffix = local_file.suffix.lower() + is_image = suffix in {".png", ".jpg", ".jpeg", ".gif", ".webp", ".bmp"} + + try: + bold_title = ( + f"**{standardize(title).removesuffix('\n')}**" if title else None + ) + if bold_title and text: + caption = f"{bold_title}\n{text}" + elif bold_title: + caption = bold_title + else: + caption = text or "" + + with local_file.open("rb") as fp: + if is_image: + sent = self._bot.send_photo( + chat_id=chat_id, + photo=fp, + caption=standardize(caption) if caption else None, + parse_mode="MarkdownV2" if caption else None, + ) + else: + sent = self._bot.send_document( + chat_id=chat_id, + document=(send_name, fp), + caption=standardize(caption) if caption else None, + parse_mode="MarkdownV2" if caption else None, + ) + self._stop_typing_task(chat_id) + if sent and hasattr(sent, "message_id"): + return { + "success": True, + "message_id": sent.message_id, + "chat_id": sent.chat.id if hasattr(sent, "chat") else chat_id, + } + return {"success": bool(sent)} + except Exception as err: + logger.error(f"发送本地附件失败: {err}") + self._stop_typing_task(chat_id) + return {"success": False} + def _determine_target_chat_id( self, userid: Optional[str] = None, original_chat_id: Optional[str] = None ) -> str: diff --git a/app/modules/vocechat/__init__.py b/app/modules/vocechat/__init__.py index 93f3fad5..ca88f6d2 100644 --- a/app/modules/vocechat/__init__.py +++ b/app/modules/vocechat/__init__.py @@ -133,6 +133,7 @@ class VoceChatModule(_ModuleBase, _MessageBase[VoceChat]): content = detail.get("content") images = self._extract_images(detail) audio_refs = self._extract_audio_refs(detail) + files = self._extract_files(detail) text = None if content_type in ("text/plain", "text/markdown") and isinstance(content, str): text = content @@ -147,15 +148,15 @@ class VoceChatModule(_ModuleBase, _MessageBase[VoceChat]): userid = f"UID#{msg_body.get('from_uid')}" # 处理消息内容 - if (text or images or audio_refs) and userid: + if (text or images or audio_refs or files) and userid: logger.info( f"收到来自 {client_config.name} 的VoceChat消息:" f"userid={userid}, text={text}, images={len(images) if images else 0}, " - f"audios={len(audio_refs) if audio_refs else 0}" + f"audios={len(audio_refs) if audio_refs else 0}, files={len(files) if files else 0}" ) return CommingMessage(channel=MessageChannel.VoceChat, source=client_config.name, userid=userid, username=userid, text=text or "", - images=images, audio_refs=audio_refs) + images=images, audio_refs=audio_refs, files=files) except Exception as err: logger.error(f"VoceChat消息处理发生错误:{str(err)}") return None @@ -229,6 +230,51 @@ class VoceChatModule(_ModuleBase, _MessageBase[VoceChat]): return [f"vocechat://file/{quote(file_path, safe='')}"] return None + @classmethod + def _extract_files( + cls, detail: dict + ) -> Optional[List[CommingMessage.MessageAttachment]]: + content_type = detail.get("content_type") or "" + if content_type != "vocechat/file": + return None + properties = detail.get("properties") or {} + mime_type = ( + properties.get("content_type") + or properties.get("mime_type") + or properties.get("contentType") + or "" + ).lower() + file_path = ( + properties.get("path") + or properties.get("file_path") + or properties.get("storage_path") + or detail.get("content") + ) + file_name = ( + properties.get("name") + or properties.get("filename") + or (str(file_path).rsplit("/", 1)[-1] if file_path else "") + ) + lowered_name = str(file_name).lower() + is_image = mime_type.startswith("image/") or lowered_name.endswith( + cls._IMAGE_SUFFIXES + ) + is_audio = mime_type.startswith("audio/") or lowered_name.endswith( + cls._AUDIO_SUFFIXES + ) + if is_image or is_audio or not isinstance(file_path, str) or not file_path: + return None + return [ + CommingMessage.MessageAttachment( + ref=f"vocechat://file/{quote(file_path, safe='')}", + name=file_name, + mime_type=properties.get("content_type") + or properties.get("mime_type") + or properties.get("contentType"), + size=properties.get("size"), + ) + ] + def post_message(self, message: Notification, **kwargs) -> None: """ 发送消息 diff --git a/app/modules/wechat/__init__.py b/app/modules/wechat/__init__.py index cee2c55b..16aaec42 100644 --- a/app/modules/wechat/__init__.py +++ b/app/modules/wechat/__init__.py @@ -3,6 +3,7 @@ import json import re import xml.dom.minidom from typing import Optional, Union, List, Tuple, Any, Dict +from urllib.parse import quote from app.core.context import Context, MediaInfo from app.core.event import eventmanager @@ -168,6 +169,7 @@ class WechatModule(_ModuleBase, _MessageBase[WeChat]): content = None images = None audio_refs = None + files = None if msg_type == "event" and event == "click": # 校验用户有权限执行交互命令 if client_config.config.get('WECHAT_ADMINS'): @@ -203,14 +205,27 @@ class WechatModule(_ModuleBase, _MessageBase[WeChat]): f"收到来自 {client_config.name} 的微信语音消息:userid={user_id}, " f"text={content}, audios={len(audio_refs) if audio_refs else 0}" ) + elif msg_type == "file": + media_id = DomUtils.tag_value(root_node, "MediaId") + file_name = DomUtils.tag_value(root_node, "FileName") + if media_id: + files = [ + CommingMessage.MessageAttachment( + ref=f"wxwork://file_media_id/{media_id}", + name=file_name, + ) + ] + logger.info( + f"收到来自 {client_config.name} 的微信文件消息:userid={user_id}, files={len(files) if files else 0}" + ) else: return None - if content or images or audio_refs: + if content or images or audio_refs or files: # 处理消息内容 return CommingMessage(channel=MessageChannel.Wechat, source=client_config.name, userid=user_id, username=user_id, text=content or "", - images=images, audio_refs=audio_refs) + images=images, audio_refs=audio_refs, files=files) except Exception as err: logger.error(f"微信消息处理发生错误:{str(err)}") return None @@ -242,6 +257,20 @@ class WechatModule(_ModuleBase, _MessageBase[WeChat]): text = WeChatBot._extract_text_from_body(payload_body) images = WeChatBot._extract_images_from_body(payload_body) audio_refs = ["wxbot://voice"] if payload_body.get("msgtype") == "voice" else None + files = None + if payload_body.get("msgtype") == "file": + file_payload = payload_body.get("file") or {} + download_url = file_payload.get("download_url") + if download_url: + files = [ + CommingMessage.MessageAttachment( + ref=f"wxbot://file/{quote(download_url, safe='')}", + name=file_payload.get("name") or file_payload.get("filename"), + mime_type=file_payload.get("content_type") + or file_payload.get("mime_type"), + size=file_payload.get("size"), + ) + ] if text: text = re.sub(r"@\S+", "", text).strip() @@ -257,7 +286,7 @@ class WechatModule(_ModuleBase, _MessageBase[WeChat]): client.send_msg(title="只有管理员才有权限执行此命令", userid=sender) return None - if not text and not images and not audio_refs: + if not text and not images and not audio_refs and not files: return None logger.info( @@ -272,6 +301,7 @@ class WechatModule(_ModuleBase, _MessageBase[WeChat]): text=text or "", images=images, audio_refs=audio_refs, + files=files, ) def post_message(self, message: Notification, **kwargs) -> None: @@ -338,6 +368,9 @@ class WechatModule(_ModuleBase, _MessageBase[WeChat]): if media_ref.startswith("wxwork://voice_media_id/"): media_id = media_ref.replace("wxwork://voice_media_id/", "", 1) return client.download_media_bytes(media_id) + if media_ref.startswith("wxwork://file_media_id/"): + media_id = media_ref.replace("wxwork://file_media_id/", "", 1) + return client.download_media_bytes(media_id) return None def post_medias_message(self, message: Notification, medias: List[MediaInfo]) -> None: diff --git a/app/schemas/message.py b/app/schemas/message.py index 89512cc4..8766c559 100644 --- a/app/schemas/message.py +++ b/app/schemas/message.py @@ -29,6 +29,16 @@ class CommingMessage(BaseModel): 外来消息 """ + class MessageAttachment(BaseModel): + """ + 外来消息附件(非图片/非语音) + """ + + ref: str + name: Optional[str] = None + mime_type: Optional[str] = None + size: Optional[int] = None + # 用户ID userid: Optional[Union[str, int]] = None # 用户名称 @@ -57,6 +67,8 @@ class CommingMessage(BaseModel): images: Optional[List[str]] = None # 语音/音频引用列表 audio_refs: Optional[List[str]] = None + # 文件附件列表 + files: Optional[List[MessageAttachment]] = None def to_dict(self): """ @@ -90,6 +102,10 @@ class Notification(BaseModel): image: Optional[str] = None # 语音文件路径 voice_path: Optional[str] = None + # 本地文件路径 + file_path: Optional[str] = None + # 发送时展示的文件名 + file_name: Optional[str] = None # 语音消息附带说明文字 voice_caption: Optional[str] = None # 链接 @@ -254,6 +270,7 @@ class ChannelCapabilityManager: ChannelCapability.IMAGES, ChannelCapability.LINKS, ChannelCapability.MENU_COMMANDS, + ChannelCapability.FILE_SENDING, }, max_buttons_per_row=3, max_button_rows=8, @@ -272,6 +289,7 @@ class ChannelCapabilityManager: ChannelCapability.RICH_TEXT, ChannelCapability.IMAGES, ChannelCapability.LINKS, + ChannelCapability.FILE_SENDING, }, max_buttons_per_row=5, max_button_rows=5, diff --git a/tests/test_agent_image_support.py b/tests/test_agent_image_support.py index 6ad64470..73c6a5bb 100644 --- a/tests/test_agent_image_support.py +++ b/tests/test_agent_image_support.py @@ -1,6 +1,8 @@ import base64 import json +import tempfile import unittest +from pathlib import Path from types import SimpleNamespace from unittest.mock import AsyncMock, Mock, patch from urllib.parse import quote @@ -8,6 +10,7 @@ from urllib.parse import quote from telebot import apihelper from app.agent.tools.impl.send_message import SendMessageInput +from app.agent.tools.impl.send_local_file import SendLocalFileInput from app.agent import MoviePilotAgent, AgentChain from app.chain.message import MessageChain from app.core.config import settings @@ -161,6 +164,32 @@ class AgentImageSupportTest(unittest.TestCase): self.assertEqual(handle_kwargs["text"], "") self.assertEqual(handle_kwargs["audio_refs"], ["tg://voice_file_id/voice-1"]) + def test_process_allows_file_only_message(self): + chain = MessageChain() + message = CommingMessage( + channel=MessageChannel.Telegram, + source="telegram-test", + userid="10001", + username="tester", + files=[ + CommingMessage.MessageAttachment( + ref="tg://document_file_id/doc-1", + name="note.txt", + mime_type="text/plain", + size=12, + ) + ], + ) + + with patch.object(chain, "message_parser", return_value=message), patch.object( + chain, "handle_message" + ) as handle_message: + chain.process(body="{}", form={}, args={"source": "telegram-test"}) + + handle_kwargs = handle_message.call_args.kwargs + self.assertEqual(handle_kwargs["text"], "") + self.assertEqual(handle_kwargs["files"][0].ref, "tg://document_file_id/doc-1") + def test_image_message_routes_to_agent_even_when_global_agent_is_disabled(self): chain = MessageChain() @@ -205,6 +234,36 @@ class AgentImageSupportTest(unittest.TestCase): self.assertEqual(handle_ai_message.call_args.kwargs["text"], "帮我推荐一部电影") self.assertTrue(handle_ai_message.call_args.kwargs["reply_with_voice"]) + def test_file_message_routes_to_agent_even_when_global_agent_is_disabled(self): + chain = MessageChain() + + with patch.object(chain, "load_cache", return_value={}), patch.object( + chain.messagehelper, "put" + ), patch.object(chain.messageoper, "add"), patch.object( + chain, "_handle_ai_message" + ) as handle_ai_message, patch.object( + settings, "AI_AGENT_ENABLE", True + ), patch.object( + settings, "AI_AGENT_GLOBAL", False + ): + chain.handle_message( + channel=MessageChannel.Telegram, + source="telegram-test", + userid="10001", + username="tester", + text="", + files=[ + CommingMessage.MessageAttachment( + ref="tg://document_file_id/doc-1", + name="report.txt", + mime_type="text/plain", + ) + ], + ) + + handle_ai_message.assert_called_once() + self.assertEqual(handle_ai_message.call_args.kwargs["files"][0].name, "report.txt") + def test_transcribe_audio_refs_supports_new_channel_refs(self): chain = MessageChain() audio_refs = [ @@ -276,6 +335,41 @@ class AgentImageSupportTest(unittest.TestCase): self.assertIsNone(notification.voice_path) self.assertEqual(notification.text, "这是语音回复") + def test_agent_process_wraps_request_as_structured_json(self): + agent = MoviePilotAgent( + session_id="session-1", + user_id="user-1", + channel=MessageChannel.Telegram.value, + source="telegram-test", + username="tester", + ) + + with patch( + "app.agent.memory.memory_manager.get_agent_messages", return_value=[] + ), patch.object(agent, "_execute_agent", new_callable=AsyncMock) as execute_agent: + import asyncio + + asyncio.run( + agent.process( + "帮我总结这个文件", + files=[ + { + "name": "report.txt", + "local_path": "/tmp/report.txt", + "status": "ready", + } + ], + ) + ) + + messages = execute_agent.await_args.args[0] + human_message = messages[-1] + content = human_message.content + self.assertIsInstance(content, list) + payload = json.loads(content[0]["text"]) + self.assertEqual(payload["message"], "帮我总结这个文件") + self.assertEqual(payload["files"][0]["local_path"], "/tmp/report.txt") + def test_slack_images_use_authenticated_data_url_download(self): chain = MessageChain() @@ -345,6 +439,15 @@ class AgentImageSupportTest(unittest.TestCase): self.assertEqual(payload.image_url, "https://example.com/poster.png") + def test_send_local_file_input_accepts_file_payload(self): + payload = SendLocalFileInput( + explanation="send generated report", + file_path="/tmp/report.txt", + message="请下载查看", + ) + + self.assertEqual(payload.file_path, "/tmp/report.txt") + def test_discord_extract_images_supports_attachment_content_type(self): images = DiscordModule._extract_images( { @@ -380,6 +483,26 @@ class AgentImageSupportTest(unittest.TestCase): ], ) + def test_discord_extract_files_supports_non_media_attachment(self): + files = DiscordModule._extract_files( + { + "attachments": [ + { + "content_type": "application/pdf", + "filename": "guide.pdf", + "url": "https://cdn.discordapp.com/guide.pdf", + "size": 1024, + } + ] + } + ) + + self.assertEqual(files[0].name, "guide.pdf") + self.assertEqual( + files[0].ref, + "discord://file/" + quote("https://cdn.discordapp.com/guide.pdf", safe=""), + ) + def test_discord_send_direct_message_returns_chat_id(self): module = DiscordModule() client = Mock() @@ -466,6 +589,46 @@ class AgentImageSupportTest(unittest.TestCase): self.assertIsNotNone(message) self.assertEqual(message.images, ["wxwork://media_id/media-1"]) + def test_wechat_message_parser_extracts_file_media_id(self): + module = WechatModule() + xml_message = b""" + + + + + + + """ + crypt = Mock() + crypt.DecryptMsg.return_value = (0, xml_message) + + with patch.object( + module, + "get_config", + return_value=SimpleNamespace( + name="wechat-test", + config={ + "WECHAT_TOKEN": "token", + "WECHAT_ENCODING_AESKEY": "encoding", + "WECHAT_CORPID": "corpid", + }, + ), + ), patch.object( + module, "get_instance", return_value=SimpleNamespace(send_msg=Mock()) + ), patch( + "app.modules.wechat.WXBizMsgCrypt", + return_value=crypt, + ): + message = module.message_parser( + source="wechat-test", + body=b"encrypted", + form={}, + args={"msg_signature": "sig", "timestamp": "1", "nonce": "n"}, + ) + + self.assertIsNotNone(message) + self.assertEqual(message.files[0].ref, "wxwork://file_media_id/file-media-1") + def test_wechat_bot_parser_accepts_image_only_payload(self): module = WechatModule() body = json.dumps( @@ -594,6 +757,38 @@ class AgentImageSupportTest(unittest.TestCase): ["vocechat://file/%2Fuploads%2Fvoice.ogg"], ) + def test_vocechat_message_parser_extracts_generic_file_payload(self): + module = VoceChatModule() + body = json.dumps( + { + "detail": { + "type": "normal", + "content_type": "vocechat/file", + "content": "/uploads/manual.pdf", + "properties": {"content_type": "application/pdf"}, + }, + "from_uid": 7910, + "target": {"gid": 2}, + } + ) + + with patch.object( + module, + "get_config", + return_value=SimpleNamespace( + name="vocechat-test", config={"channel_id": "2"} + ), + ): + message = module.message_parser( + source="vocechat-test", + body=body, + form={}, + args={}, + ) + + self.assertIsNotNone(message) + self.assertEqual(message.files[0].ref, "vocechat://file/%2Fuploads%2Fmanual.pdf") + def test_vocechat_post_message_passes_image_and_correct_target(self): module = VoceChatModule() client = Mock() @@ -623,6 +818,64 @@ class AgentImageSupportTest(unittest.TestCase): link=None, ) + def test_slack_post_message_passes_local_file(self): + module = SlackModule() + client = Mock() + + with tempfile.TemporaryDirectory() as tempdir: + file_path = Path(tempdir) / "guide.pdf" + file_path.write_bytes(b"pdf") + + with patch.object( + module, + "get_configs", + return_value={"slack-test": SimpleNamespace(name="slack-test")}, + ), patch.object( + module, "check_message", return_value=True + ), patch.object( + module, "get_instance", return_value=client + ): + module.post_message( + Notification( + title="手册", + text="请下载", + file_path=str(file_path), + file_name="guide.pdf", + userid="U123", + ) + ) + + client.send_file.assert_called_once() + + def test_discord_post_message_passes_local_file(self): + module = DiscordModule() + client = Mock() + + with tempfile.TemporaryDirectory() as tempdir: + file_path = Path(tempdir) / "guide.pdf" + file_path.write_bytes(b"pdf") + + with patch.object( + module, + "get_configs", + return_value={"discord-test": SimpleNamespace(name="discord-test")}, + ), patch.object( + module, "check_message", return_value=True + ), patch.object( + module, "get_instance", return_value=client + ): + module.post_message( + Notification( + title="手册", + text="请下载", + file_path=str(file_path), + file_name="guide.pdf", + userid="user-1", + ) + ) + + client.send_file.assert_called_once() + def test_qq_message_parser_accepts_image_only_attachment(self): module = QQBotModule() @@ -745,5 +998,97 @@ class AgentImageSupportTest(unittest.TestCase): ["synology://file/" + quote("https://example.com/voice.ogg", safe="")], ) + def test_synology_message_parser_accepts_generic_file_attachment(self): + module = SynologyChatModule() + + with patch.object( + module, + "get_config", + return_value=SimpleNamespace(name="synology-test", config={}), + ), patch.object( + module, + "get_instance", + return_value=SimpleNamespace(check_token=lambda token: token == "token-1"), + ): + message = module.message_parser( + source="synology-test", + body={}, + form={ + "token": "token-1", + "user_id": "42", + "username": "tester", + "attachments": json.dumps( + [ + { + "url": "https://example.com/manual.pdf", + "content_type": "application/pdf", + "filename": "manual.pdf", + } + ] + ), + }, + args={}, + ) + + self.assertIsNotNone(message) + self.assertEqual( + message.files[0].ref, + "synology://file/" + quote("https://example.com/manual.pdf", safe=""), + ) + + def test_prepare_agent_files_saves_local_file(self): + chain = MessageChain() + with tempfile.TemporaryDirectory() as tempdir, patch.object( + settings, "TEMP_PATH", Path(tempdir) + ), patch.object( + chain, + "_download_message_file_bytes", + return_value="你好,MoviePilot".encode("utf-8"), + ): + prepared = chain._prepare_agent_files( + session_id="session-1", + files=[ + CommingMessage.MessageAttachment( + ref="tg://document_file_id/doc-1", + name="note.txt", + mime_type="text/plain", + ) + ], + channel=MessageChannel.Telegram, + source="telegram-test", + ) + + self.assertEqual(prepared[0]["status"], "ready") + self.assertTrue(Path(prepared[0]["local_path"]).exists()) + + def test_telegram_post_message_passes_file_to_client(self): + module = TelegramModule() + client = Mock() + + with tempfile.TemporaryDirectory() as tempdir: + file_path = Path(tempdir) / "report.txt" + file_path.write_text("hello", encoding="utf-8") + + with patch.object( + module, + "get_configs", + return_value={"telegram-test": SimpleNamespace(name="telegram-test")}, + ), patch.object( + module, "check_message", return_value=True + ), patch.object( + module, "get_instance", return_value=client + ): + module.post_message( + Notification( + title="报告", + text="请下载", + file_path=str(file_path), + file_name="report.txt", + userid="user-1", + ) + ) + + client.send_file.assert_called_once() + if __name__ == "__main__": unittest.main()