mirror of
https://github.com/jxxghp/MoviePilot.git
synced 2026-05-06 20:42:43 +08:00
feat(agent): support file attachments and local file replies
This commit is contained in:
@@ -1,4 +1,5 @@
|
||||
import asyncio
|
||||
import json
|
||||
import re
|
||||
import traceback
|
||||
import uuid
|
||||
@@ -281,13 +282,19 @@ class MoviePilotAgent:
|
||||
logger.error(f"创建 Agent 失败: {e}")
|
||||
raise e
|
||||
|
||||
async def process(self, message: str, images: List[str] = None) -> str:
|
||||
async def process(
|
||||
self,
|
||||
message: str,
|
||||
images: List[str] = None,
|
||||
files: Optional[List[dict]] = None,
|
||||
) -> str:
|
||||
"""
|
||||
处理用户消息,流式推理并返回 Agent 回复
|
||||
"""
|
||||
try:
|
||||
logger.info(
|
||||
f"Agent推理: session_id={self.session_id}, input={message}, images={len(images) if images else 0}"
|
||||
f"Agent推理: session_id={self.session_id}, input={message}, "
|
||||
f"images={len(images) if images else 0}, files={len(files) if files else 0}"
|
||||
)
|
||||
self._tool_context = {
|
||||
"incoming_voice": self.reply_with_voice,
|
||||
@@ -300,16 +307,24 @@ class MoviePilotAgent:
|
||||
session_id=self.session_id, user_id=self.user_id
|
||||
)
|
||||
|
||||
# 构建用户消息内容
|
||||
if images:
|
||||
content = []
|
||||
if message:
|
||||
content.append({"type": "text", "text": message})
|
||||
for img in images:
|
||||
content.append({"type": "image_url", "image_url": {"url": img}})
|
||||
messages.append(HumanMessage(content=content))
|
||||
else:
|
||||
messages.append(HumanMessage(content=message))
|
||||
# 构建结构化用户消息内容
|
||||
request_payload = {
|
||||
"message": message or "",
|
||||
"images": [
|
||||
{"index": index + 1, "type": "image"}
|
||||
for index, _ in enumerate(images or [])
|
||||
],
|
||||
"files": files or [],
|
||||
}
|
||||
content = [
|
||||
{
|
||||
"type": "text",
|
||||
"text": json.dumps(request_payload, ensure_ascii=False, indent=2),
|
||||
}
|
||||
]
|
||||
for img in images or []:
|
||||
content.append({"type": "image_url", "image_url": {"url": img}})
|
||||
messages.append(HumanMessage(content=content))
|
||||
|
||||
# 执行推理
|
||||
await self._execute_agent(messages)
|
||||
@@ -544,6 +559,7 @@ class _MessageTask:
|
||||
user_id: str
|
||||
message: str
|
||||
images: Optional[List[str]] = None
|
||||
files: Optional[List[dict]] = None
|
||||
channel: Optional[str] = None
|
||||
source: Optional[str] = None
|
||||
username: Optional[str] = None
|
||||
@@ -610,6 +626,7 @@ class AgentManager:
|
||||
user_id: str,
|
||||
message: str,
|
||||
images: List[str] = None,
|
||||
files: Optional[List[dict]] = None,
|
||||
channel: str = None,
|
||||
source: str = None,
|
||||
username: str = None,
|
||||
@@ -624,6 +641,7 @@ class AgentManager:
|
||||
user_id=user_id,
|
||||
message=message,
|
||||
images=images,
|
||||
files=files,
|
||||
channel=channel,
|
||||
source=source,
|
||||
username=username,
|
||||
@@ -727,7 +745,7 @@ class AgentManager:
|
||||
agent.username = task.username
|
||||
agent.reply_with_voice = task.reply_with_voice
|
||||
|
||||
return await agent.process(task.message, images=task.images)
|
||||
return await agent.process(task.message, images=task.images, files=task.files)
|
||||
|
||||
async def stop_current_task(self, session_id: str):
|
||||
"""
|
||||
|
||||
@@ -10,6 +10,7 @@ Core Capabilities:
|
||||
3. Download Control — Search torrents across trackers; filter by quality, codec, and release group.
|
||||
4. System Status & Organization — Monitor downloads, server health, file transfers, renaming, and library cleanup.
|
||||
5. Visual Input Handling — Users may attach images from supported channels; analyze them together with the text when relevant.
|
||||
6. File Context Handling — User messages may arrive as structured JSON. Treat the `message` field as the user's text. Non-image attachments appear in `files`; when `local_path` is present, use local file tools to inspect the uploaded file directly.
|
||||
|
||||
<communication>
|
||||
{verbose_spec}
|
||||
@@ -21,6 +22,7 @@ Core Capabilities:
|
||||
- Include key details (year, rating, resolution) but do NOT over-explain.
|
||||
- Do not stop for approval on read-only operations. Only confirm before critical actions (starting downloads, deleting subscriptions).
|
||||
- If the current channel supports image sending and an image would materially help, you may use the `send_message` tool with `image_url` to send it.
|
||||
- If the current channel supports file sending and you need to return a local image/file for the user to download, use `send_local_file`.
|
||||
- Voice replies: {voice_reply_spec}
|
||||
- NOT a coding assistant. Do not offer code snippets.
|
||||
- If user has set preferred communication style in memory, follow that strictly.
|
||||
|
||||
@@ -30,6 +30,7 @@ from app.agent.tools.impl.search_torrents import SearchTorrentsTool
|
||||
from app.agent.tools.impl.get_search_results import GetSearchResultsTool
|
||||
from app.agent.tools.impl.search_web import SearchWebTool
|
||||
from app.agent.tools.impl.send_message import SendMessageTool
|
||||
from app.agent.tools.impl.send_local_file import SendLocalFileTool
|
||||
from app.agent.tools.impl.send_voice_message import SendVoiceMessageTool
|
||||
from app.agent.tools.impl.query_schedulers import QuerySchedulersTool
|
||||
from app.agent.tools.impl.run_scheduler import RunSchedulerTool
|
||||
@@ -119,6 +120,7 @@ class MoviePilotToolFactory:
|
||||
QueryTransferHistoryTool,
|
||||
TransferFileTool,
|
||||
SendMessageTool,
|
||||
SendLocalFileTool,
|
||||
SendVoiceMessageTool,
|
||||
QuerySchedulersTool,
|
||||
RunSchedulerTool,
|
||||
|
||||
107
app/agent/tools/impl/send_local_file.py
Normal file
107
app/agent/tools/impl/send_local_file.py
Normal file
@@ -0,0 +1,107 @@
|
||||
"""发送本地附件工具。"""
|
||||
|
||||
from pathlib import Path
|
||||
from typing import Optional, Type
|
||||
|
||||
from pydantic import BaseModel, Field, model_validator
|
||||
|
||||
from app.agent.tools.base import MoviePilotTool, ToolChain
|
||||
from app.log import logger
|
||||
from app.schemas import Notification, NotificationType
|
||||
from app.schemas.message import ChannelCapabilityManager, ChannelCapability
|
||||
from app.schemas.types import MessageChannel
|
||||
|
||||
|
||||
class SendLocalFileInput(BaseModel):
|
||||
"""发送本地附件工具输入。"""
|
||||
|
||||
explanation: str = Field(
|
||||
...,
|
||||
description="Clear explanation of why sending this local file helps the user",
|
||||
)
|
||||
file_path: str = Field(
|
||||
...,
|
||||
description="Absolute path to the local image or file to send to the user",
|
||||
)
|
||||
message: Optional[str] = Field(
|
||||
None,
|
||||
description="Optional message or caption to send with the attachment",
|
||||
)
|
||||
title: Optional[str] = Field(
|
||||
None,
|
||||
description="Optional short title shown together with the attachment",
|
||||
)
|
||||
file_name: Optional[str] = Field(
|
||||
None,
|
||||
description="Optional override filename presented to the user when downloading",
|
||||
)
|
||||
|
||||
@model_validator(mode="after")
|
||||
def validate_file_path(self):
|
||||
if not self.file_path:
|
||||
raise ValueError("file_path 不能为空")
|
||||
return self
|
||||
|
||||
|
||||
class SendLocalFileTool(MoviePilotTool):
|
||||
name: str = "send_local_file"
|
||||
description: str = (
|
||||
"Send a local image or file from the server filesystem to the current user. "
|
||||
"Use this when you have generated or identified a local file the user should download."
|
||||
)
|
||||
args_schema: Type[BaseModel] = SendLocalFileInput
|
||||
require_admin: bool = False
|
||||
|
||||
def get_tool_message(self, **kwargs) -> Optional[str]:
|
||||
file_path = kwargs.get("file_path", "")
|
||||
file_name = Path(file_path).name if file_path else "未知文件"
|
||||
return f"正在发送本地附件: {file_name}"
|
||||
|
||||
async def run(
|
||||
self,
|
||||
file_path: str,
|
||||
message: Optional[str] = None,
|
||||
title: Optional[str] = None,
|
||||
file_name: Optional[str] = None,
|
||||
**kwargs,
|
||||
) -> str:
|
||||
if not self._channel or not self._source:
|
||||
return "当前不在可回传消息的会话中,无法发送附件"
|
||||
|
||||
try:
|
||||
channel = MessageChannel(self._channel)
|
||||
except ValueError:
|
||||
return f"不支持的消息渠道: {self._channel}"
|
||||
|
||||
if not ChannelCapabilityManager.supports_capability(
|
||||
channel, ChannelCapability.FILE_SENDING
|
||||
):
|
||||
return f"当前渠道 {channel.value} 暂不支持发送本地文件"
|
||||
|
||||
resolved_path = Path(file_path).expanduser()
|
||||
if not resolved_path.is_absolute():
|
||||
resolved_path = resolved_path.resolve()
|
||||
if not resolved_path.exists() or not resolved_path.is_file():
|
||||
return f"文件不存在: {resolved_path}"
|
||||
|
||||
logger.info(
|
||||
"执行工具: %s, channel=%s, file=%s",
|
||||
self.name,
|
||||
channel.value,
|
||||
resolved_path,
|
||||
)
|
||||
|
||||
await ToolChain().async_post_message(
|
||||
Notification(
|
||||
channel=channel,
|
||||
source=self._source,
|
||||
mtype=NotificationType.Agent,
|
||||
userid=self._user_id,
|
||||
username=self._username,
|
||||
title=title,
|
||||
text=message,
|
||||
file_path=str(resolved_path),
|
||||
file_name=file_name or resolved_path.name,
|
||||
)
|
||||
)
|
||||
return "本地附件已发送"
|
||||
@@ -19,7 +19,7 @@ class SendMessageInput(BaseModel):
|
||||
None,
|
||||
description="The message content to send to the user (should be clear and informative)",
|
||||
)
|
||||
message_type: Optional[str] = Field(
|
||||
title: Optional[str] = Field(
|
||||
None,
|
||||
description="Title of the message, a short summary of the message content",
|
||||
)
|
||||
@@ -30,8 +30,8 @@ class SendMessageInput(BaseModel):
|
||||
|
||||
@model_validator(mode="after")
|
||||
def validate_payload(self):
|
||||
if not self.message and not self.message_type and not self.image_url:
|
||||
raise ValueError("message、message_type、image_url 至少需要提供一个")
|
||||
if not self.message and not self.title and not self.image_url:
|
||||
raise ValueError("message、title、image_url 至少需要提供一个")
|
||||
return self
|
||||
|
||||
|
||||
@@ -44,7 +44,7 @@ class SendMessageTool(MoviePilotTool):
|
||||
def get_tool_message(self, **kwargs) -> Optional[str]:
|
||||
"""根据消息参数生成友好的提示消息"""
|
||||
message = kwargs.get("message", "") or ""
|
||||
title = kwargs.get("message_type") or ""
|
||||
title = kwargs.get("title") or ""
|
||||
image_url = kwargs.get("image_url")
|
||||
|
||||
# 截断过长的消息
|
||||
@@ -62,11 +62,11 @@ class SendMessageTool(MoviePilotTool):
|
||||
async def run(
|
||||
self,
|
||||
message: Optional[str] = None,
|
||||
message_type: Optional[str] = None,
|
||||
title: Optional[str] = None,
|
||||
image_url: Optional[str] = None,
|
||||
**kwargs,
|
||||
) -> str:
|
||||
title = message_type or ("图片" if image_url and not message else "")
|
||||
title = title or ("图片" if image_url and not message else "")
|
||||
text = message or ""
|
||||
logger.info(
|
||||
f"执行工具: {self.name}, 参数: title={title}, message={text}, image_url={image_url}"
|
||||
|
||||
@@ -2,8 +2,10 @@ import asyncio
|
||||
import re
|
||||
import time
|
||||
from datetime import datetime, timedelta
|
||||
from pathlib import Path
|
||||
from typing import Any, Optional, Dict, Union, List
|
||||
from urllib.parse import unquote
|
||||
import uuid
|
||||
|
||||
import base64
|
||||
|
||||
@@ -135,7 +137,8 @@ class MessageChain(ChainBase):
|
||||
text = str(info.text).strip() if info.text else ""
|
||||
images = info.images
|
||||
audio_refs = info.audio_refs
|
||||
if not text and not images and not audio_refs:
|
||||
files = info.files
|
||||
if not text and not images and not audio_refs and not files:
|
||||
logger.debug(f"未识别到消息内容::{body}{form}{args}")
|
||||
return
|
||||
|
||||
@@ -154,6 +157,7 @@ class MessageChain(ChainBase):
|
||||
original_chat_id=original_chat_id,
|
||||
images=images,
|
||||
audio_refs=audio_refs,
|
||||
files=files,
|
||||
)
|
||||
|
||||
def handle_message(
|
||||
@@ -167,6 +171,7 @@ class MessageChain(ChainBase):
|
||||
original_chat_id: Optional[str] = None,
|
||||
images: Optional[List[str]] = None,
|
||||
audio_refs: Optional[List[str]] = None,
|
||||
files: Optional[List[CommingMessage.MessageAttachment]] = None,
|
||||
) -> None:
|
||||
"""
|
||||
识别消息内容,执行操作
|
||||
@@ -253,9 +258,12 @@ class MessageChain(ChainBase):
|
||||
userid=userid,
|
||||
username=username,
|
||||
images=images,
|
||||
files=files,
|
||||
reply_with_voice=reply_with_voice,
|
||||
)
|
||||
elif settings.AI_AGENT_ENABLE and settings.AI_AGENT_GLOBAL:
|
||||
elif settings.AI_AGENT_ENABLE and (
|
||||
settings.AI_AGENT_GLOBAL or images or files
|
||||
):
|
||||
# 普通消息,全局智能体响应
|
||||
self._handle_ai_message(
|
||||
text=text,
|
||||
@@ -264,6 +272,7 @@ class MessageChain(ChainBase):
|
||||
userid=userid,
|
||||
username=username,
|
||||
images=images,
|
||||
files=files,
|
||||
reply_with_voice=reply_with_voice,
|
||||
)
|
||||
else:
|
||||
@@ -1230,6 +1239,7 @@ class MessageChain(ChainBase):
|
||||
userid: Union[str, int],
|
||||
username: str,
|
||||
images: Optional[List[str]] = None,
|
||||
files: Optional[List[CommingMessage.MessageAttachment]] = None,
|
||||
reply_with_voice: bool = False,
|
||||
) -> None:
|
||||
"""
|
||||
@@ -1255,7 +1265,7 @@ class MessageChain(ChainBase):
|
||||
else:
|
||||
user_message = text.strip() # 按原消息处理
|
||||
|
||||
if not user_message and not images:
|
||||
if not user_message and not images and not files:
|
||||
self.post_message(
|
||||
Notification(
|
||||
channel=channel,
|
||||
@@ -1274,7 +1284,7 @@ class MessageChain(ChainBase):
|
||||
original_images = images
|
||||
if images:
|
||||
images = self._download_images_to_base64(images, channel, source)
|
||||
if original_images and not images and not user_message:
|
||||
if original_images and not images and not user_message and not files:
|
||||
self.post_message(
|
||||
Notification(
|
||||
channel=channel,
|
||||
@@ -1286,6 +1296,24 @@ class MessageChain(ChainBase):
|
||||
)
|
||||
return
|
||||
|
||||
prepared_files = self._prepare_agent_files(
|
||||
session_id=session_id,
|
||||
files=files,
|
||||
channel=channel,
|
||||
source=source,
|
||||
)
|
||||
if files and not prepared_files and not user_message and not images:
|
||||
self.post_message(
|
||||
Notification(
|
||||
channel=channel,
|
||||
source=source,
|
||||
userid=userid,
|
||||
username=username,
|
||||
title="文件读取失败,请稍后重试",
|
||||
)
|
||||
)
|
||||
return
|
||||
|
||||
# 在事件循环中处理
|
||||
asyncio.run_coroutine_threadsafe(
|
||||
agent_manager.process_message(
|
||||
@@ -1293,6 +1321,7 @@ class MessageChain(ChainBase):
|
||||
user_id=str(userid),
|
||||
message=user_message,
|
||||
images=images,
|
||||
files=prepared_files,
|
||||
channel=channel.value if channel else None,
|
||||
source=source,
|
||||
username=username,
|
||||
@@ -1481,3 +1510,148 @@ class MessageChain(ChainBase):
|
||||
except Exception as e:
|
||||
logger.error(f"下载图片失败: {img}, error: {e}")
|
||||
return base64_images if base64_images else None
|
||||
|
||||
def _prepare_agent_files(
|
||||
self,
|
||||
session_id: str,
|
||||
files: Optional[List[CommingMessage.MessageAttachment]],
|
||||
channel: MessageChannel,
|
||||
source: str,
|
||||
) -> Optional[List[dict]]:
|
||||
"""
|
||||
下载用户上传的文件,落盘到临时目录,并生成文本镜像供 Agent 使用。
|
||||
"""
|
||||
if not files:
|
||||
return None
|
||||
|
||||
prepared_files = []
|
||||
for attachment in files:
|
||||
payload = {
|
||||
"name": attachment.name,
|
||||
"mime_type": attachment.mime_type,
|
||||
"size": attachment.size,
|
||||
"ref": attachment.ref,
|
||||
"status": "download_failed",
|
||||
}
|
||||
try:
|
||||
content = self._download_message_file_bytes(
|
||||
file_ref=attachment.ref,
|
||||
channel=channel,
|
||||
source=source,
|
||||
)
|
||||
if not content:
|
||||
prepared_files.append(payload)
|
||||
continue
|
||||
|
||||
local_path = self._save_agent_attachment(
|
||||
session_id=session_id,
|
||||
filename=attachment.name,
|
||||
content=content,
|
||||
mime_type=attachment.mime_type,
|
||||
)
|
||||
payload.update(
|
||||
{
|
||||
"local_path": str(local_path),
|
||||
"status": "ready",
|
||||
}
|
||||
)
|
||||
except Exception as err:
|
||||
logger.error(f"准备文件上下文失败: {attachment.ref}, error: {err}")
|
||||
payload["error"] = str(err)
|
||||
prepared_files.append(payload)
|
||||
|
||||
return prepared_files or None
|
||||
|
||||
def _download_message_file_bytes(
|
||||
self, file_ref: str, channel: MessageChannel, source: str
|
||||
) -> Optional[bytes]:
|
||||
"""
|
||||
下载消息附件的原始字节。
|
||||
"""
|
||||
if not file_ref:
|
||||
return None
|
||||
if file_ref.startswith("tg://document_file_id/"):
|
||||
file_id = file_ref.replace("tg://document_file_id/", "", 1)
|
||||
return self.run_module(
|
||||
"download_telegram_file_bytes", file_id=file_id, source=source
|
||||
)
|
||||
if file_ref.startswith("wxwork://file_media_id/"):
|
||||
return self.run_module(
|
||||
"download_wechat_media_bytes", media_ref=file_ref, source=source
|
||||
)
|
||||
if file_ref.startswith("wxbot://file/"):
|
||||
file_url = unquote(file_ref.replace("wxbot://file/", "", 1))
|
||||
resp = RequestUtils(timeout=30).get_res(file_url)
|
||||
return resp.content if resp and resp.content else None
|
||||
if file_ref.startswith("slack://file/"):
|
||||
return self.run_module(
|
||||
"download_slack_file_bytes", file_ref=file_ref, source=source
|
||||
)
|
||||
if file_ref.startswith("discord://file/"):
|
||||
return self.run_module(
|
||||
"download_discord_file_bytes", file_ref=file_ref, source=source
|
||||
)
|
||||
if file_ref.startswith("qq://file/"):
|
||||
return self.run_module(
|
||||
"download_qq_file_bytes", file_ref=file_ref, source=source
|
||||
)
|
||||
if file_ref.startswith("vocechat://file/"):
|
||||
return self.run_module(
|
||||
"download_vocechat_file_bytes", file_ref=file_ref, source=source
|
||||
)
|
||||
if file_ref.startswith("synology://file/"):
|
||||
return self.run_module(
|
||||
"download_synologychat_file_bytes", file_ref=file_ref, source=source
|
||||
)
|
||||
if file_ref.startswith("http"):
|
||||
resp = RequestUtils(timeout=30).get_res(file_ref)
|
||||
return resp.content if resp and resp.content else None
|
||||
logger.debug(
|
||||
"暂不支持的文件引用: channel=%s, source=%s, ref=%s",
|
||||
channel.value if channel else None,
|
||||
source,
|
||||
file_ref,
|
||||
)
|
||||
return None
|
||||
|
||||
def _save_agent_attachment(
|
||||
self,
|
||||
session_id: str,
|
||||
filename: Optional[str],
|
||||
content: bytes,
|
||||
mime_type: Optional[str] = None,
|
||||
) -> Path:
|
||||
"""
|
||||
将用户上传文件写入临时目录,并返回本地路径。
|
||||
"""
|
||||
safe_name = self._sanitize_attachment_name(filename, mime_type)
|
||||
base_dir = settings.TEMP_PATH / "agent_uploads" / session_id
|
||||
base_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
file_id = uuid.uuid4().hex[:8]
|
||||
local_path = base_dir / f"{file_id}_{safe_name}"
|
||||
local_path.write_bytes(content or b"")
|
||||
return local_path
|
||||
|
||||
@staticmethod
|
||||
def _sanitize_attachment_name(
|
||||
filename: Optional[str], mime_type: Optional[str] = None
|
||||
) -> str:
|
||||
"""
|
||||
规范化附件文件名,避免路径穿越和非法字符。
|
||||
"""
|
||||
name = Path(filename or "attachment").name
|
||||
name = re.sub(r"[^\w.\-]+", "_", name, flags=re.ASCII).strip("._")
|
||||
if not name:
|
||||
name = "attachment"
|
||||
if "." not in name:
|
||||
mime = (mime_type or "").split(";", 1)[0].strip().lower()
|
||||
default_ext = {
|
||||
"application/json": ".json",
|
||||
"text/plain": ".txt",
|
||||
"text/markdown": ".md",
|
||||
"text/csv": ".csv",
|
||||
}.get(mime)
|
||||
if default_ext:
|
||||
name = f"{name}{default_ext}"
|
||||
return name
|
||||
|
||||
@@ -159,11 +159,13 @@ class DiscordModule(_ModuleBase, _MessageBase[Discord]):
|
||||
chat_id = msg_json.get("chat_id")
|
||||
images = self._extract_images(msg_json)
|
||||
audio_refs = self._extract_audio_refs(msg_json)
|
||||
if (text or images or audio_refs) and userid:
|
||||
files = self._extract_files(msg_json)
|
||||
if (text or images or audio_refs or files) and userid:
|
||||
logger.info(
|
||||
f"收到来自 {client_config.name} 的 Discord 消息:"
|
||||
f"userid={userid}, username={username}, text={text}, "
|
||||
f"images={len(images) if images else 0}, audios={len(audio_refs) if audio_refs else 0}"
|
||||
f"images={len(images) if images else 0}, audios={len(audio_refs) if audio_refs else 0}, "
|
||||
f"files={len(files) if files else 0}"
|
||||
)
|
||||
return CommingMessage(
|
||||
channel=MessageChannel.Discord,
|
||||
@@ -174,6 +176,7 @@ class DiscordModule(_ModuleBase, _MessageBase[Discord]):
|
||||
chat_id=str(chat_id) if chat_id else None,
|
||||
images=images,
|
||||
audio_refs=audio_refs,
|
||||
files=files,
|
||||
)
|
||||
return None
|
||||
|
||||
@@ -219,6 +222,44 @@ class DiscordModule(_ModuleBase, _MessageBase[Discord]):
|
||||
audio_refs.append(f"discord://file/{quote(url, safe='')}")
|
||||
return audio_refs if audio_refs else None
|
||||
|
||||
@classmethod
|
||||
def _extract_files(
|
||||
cls, msg_json: dict
|
||||
) -> Optional[List[CommingMessage.MessageAttachment]]:
|
||||
"""
|
||||
从 Discord 消息中提取非图片/非音频文件。
|
||||
"""
|
||||
attachments = msg_json.get("attachments", [])
|
||||
if not attachments:
|
||||
return None
|
||||
|
||||
files = []
|
||||
for attachment in attachments:
|
||||
url = attachment.get("url") or attachment.get("proxy_url")
|
||||
if not url:
|
||||
continue
|
||||
content_type = (attachment.get("content_type") or "").lower()
|
||||
filename = (attachment.get("filename") or "").lower()
|
||||
is_image = (
|
||||
attachment.get("type") == "image"
|
||||
or content_type.startswith("image/")
|
||||
or filename.endswith(cls._IMAGE_SUFFIXES)
|
||||
)
|
||||
is_audio = content_type.startswith("audio/") or filename.endswith(
|
||||
cls._AUDIO_SUFFIXES
|
||||
)
|
||||
if is_image or is_audio:
|
||||
continue
|
||||
files.append(
|
||||
CommingMessage.MessageAttachment(
|
||||
ref=f"discord://file/{quote(url, safe='')}",
|
||||
name=attachment.get("filename"),
|
||||
mime_type=attachment.get("content_type"),
|
||||
size=attachment.get("size"),
|
||||
)
|
||||
)
|
||||
return files or None
|
||||
|
||||
def download_discord_file_bytes(self, file_ref: str, source: str) -> Optional[bytes]:
|
||||
"""
|
||||
下载Discord附件并返回原始字节
|
||||
@@ -278,19 +319,29 @@ class DiscordModule(_ModuleBase, _MessageBase[Discord]):
|
||||
)
|
||||
if client:
|
||||
logger.debug(
|
||||
f"[Discord] 调用 client.send_msg, userid={userid}, title={message.title[:50] if message.title else None}..."
|
||||
)
|
||||
result = client.send_msg(
|
||||
title=message.title,
|
||||
text=message.text,
|
||||
image=message.image,
|
||||
userid=userid,
|
||||
link=message.link,
|
||||
buttons=message.buttons,
|
||||
original_message_id=message.original_message_id,
|
||||
original_chat_id=message.original_chat_id,
|
||||
mtype=message.mtype,
|
||||
f"[Discord] 调用 client 发送, userid={userid}, title={message.title[:50] if message.title else None}..."
|
||||
)
|
||||
if message.file_path:
|
||||
result = client.send_file(
|
||||
file_path=message.file_path,
|
||||
file_name=message.file_name,
|
||||
title=message.title,
|
||||
text=message.text,
|
||||
userid=userid,
|
||||
original_chat_id=message.original_chat_id,
|
||||
)
|
||||
else:
|
||||
result = client.send_msg(
|
||||
title=message.title,
|
||||
text=message.text,
|
||||
image=message.image,
|
||||
userid=userid,
|
||||
link=message.link,
|
||||
buttons=message.buttons,
|
||||
original_message_id=message.original_message_id,
|
||||
original_chat_id=message.original_chat_id,
|
||||
mtype=message.mtype,
|
||||
)
|
||||
logger.debug(f"[Discord] send_msg 返回结果: {result}")
|
||||
else:
|
||||
logger.warning(
|
||||
@@ -427,11 +478,20 @@ class DiscordModule(_ModuleBase, _MessageBase[Discord]):
|
||||
return None
|
||||
client: Discord = self.get_instance(conf.name)
|
||||
if client:
|
||||
result = client.send_msg(
|
||||
title=message.title or "",
|
||||
text=message.text,
|
||||
userid=userid,
|
||||
)
|
||||
if message.file_path:
|
||||
result = client.send_file(
|
||||
file_path=message.file_path,
|
||||
file_name=message.file_name,
|
||||
title=message.title,
|
||||
text=message.text,
|
||||
userid=userid,
|
||||
)
|
||||
else:
|
||||
result = client.send_msg(
|
||||
title=message.title or "",
|
||||
text=message.text,
|
||||
userid=userid,
|
||||
)
|
||||
if result:
|
||||
success, response_data = (
|
||||
(result[0], result[1])
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
import asyncio
|
||||
import re
|
||||
import threading
|
||||
from pathlib import Path
|
||||
from typing import Optional, List, Dict, Any, Tuple, Union
|
||||
from urllib.parse import quote
|
||||
|
||||
@@ -273,6 +274,37 @@ class Discord:
|
||||
logger.error(f"发送 Discord 消息失败:{err}")
|
||||
return False
|
||||
|
||||
def send_file(
|
||||
self,
|
||||
file_path: str,
|
||||
title: Optional[str] = None,
|
||||
text: Optional[str] = None,
|
||||
userid: Optional[str] = None,
|
||||
file_name: Optional[str] = None,
|
||||
original_chat_id: Optional[str] = None,
|
||||
) -> Optional[bool]:
|
||||
if not self.get_state():
|
||||
return False
|
||||
if not file_path:
|
||||
return False
|
||||
|
||||
try:
|
||||
future = asyncio.run_coroutine_threadsafe(
|
||||
self._send_file(
|
||||
file_path=file_path,
|
||||
title=title,
|
||||
text=text,
|
||||
userid=userid,
|
||||
file_name=file_name,
|
||||
original_chat_id=original_chat_id,
|
||||
),
|
||||
self._loop,
|
||||
)
|
||||
return future.result(timeout=30)
|
||||
except Exception as err:
|
||||
logger.error(f"发送 Discord 文件失败:{err}")
|
||||
return False
|
||||
|
||||
def send_medias_msg(
|
||||
self,
|
||||
medias: List[MediaInfo],
|
||||
@@ -414,6 +446,46 @@ class Discord:
|
||||
logger.error(f"[Discord] 发送消息到频道失败: {e}")
|
||||
return False, None
|
||||
|
||||
async def _send_file(
|
||||
self,
|
||||
file_path: str,
|
||||
title: Optional[str],
|
||||
text: Optional[str],
|
||||
userid: Optional[str],
|
||||
file_name: Optional[str],
|
||||
original_chat_id: Optional[str],
|
||||
) -> Tuple[bool, Optional[Dict[str, str]]]:
|
||||
channel = await self._resolve_channel(userid=userid, chat_id=original_chat_id)
|
||||
if not channel:
|
||||
logger.error("未找到可用的 Discord 频道或私聊")
|
||||
return False, None
|
||||
|
||||
local_file = Path(file_path)
|
||||
if not local_file.exists() or not local_file.is_file():
|
||||
logger.error(f"Discord发送文件失败,文件不存在: {local_file}")
|
||||
return False, None
|
||||
|
||||
content_parts = [part for part in [title, text] if part]
|
||||
content = "\n".join(content_parts) if content_parts else None
|
||||
if content and len(content) > 1900:
|
||||
content = content[:1900] + "..."
|
||||
|
||||
try:
|
||||
discord_file = discord.File(
|
||||
str(local_file), filename=file_name or local_file.name
|
||||
)
|
||||
sent_message = await channel.send(content=content, file=discord_file)
|
||||
return (
|
||||
True,
|
||||
{
|
||||
"message_id": str(sent_message.id),
|
||||
"chat_id": str(channel.id),
|
||||
},
|
||||
)
|
||||
except Exception as err:
|
||||
logger.error(f"Discord发送文件失败: {err}")
|
||||
return False, None
|
||||
|
||||
async def _send_list_message(
|
||||
self,
|
||||
embeds: List[discord.Embed],
|
||||
|
||||
@@ -107,7 +107,8 @@ class QQBotModule(_ModuleBase, _MessageBase[QQBot]):
|
||||
content = (msg_body.get("content") or "").strip()
|
||||
images = self._extract_images(msg_body)
|
||||
audio_refs = self._extract_audio_refs(msg_body)
|
||||
if not content and not images and not audio_refs:
|
||||
files = self._extract_files(msg_body)
|
||||
if not content and not images and not audio_refs and not files:
|
||||
return None
|
||||
|
||||
if msg_type == "C2C_MESSAGE_CREATE":
|
||||
@@ -118,7 +119,7 @@ class QQBotModule(_ModuleBase, _MessageBase[QQBot]):
|
||||
logger.info(
|
||||
f"收到 QQ 私聊消息: userid={user_openid}, "
|
||||
f"text={(content or '')[:50]}..., images={len(images) if images else 0}, "
|
||||
f"audios={len(audio_refs) if audio_refs else 0}"
|
||||
f"audios={len(audio_refs) if audio_refs else 0}, files={len(files) if files else 0}"
|
||||
)
|
||||
return CommingMessage(
|
||||
channel=MessageChannel.QQ,
|
||||
@@ -128,6 +129,7 @@ class QQBotModule(_ModuleBase, _MessageBase[QQBot]):
|
||||
text=content,
|
||||
images=images,
|
||||
audio_refs=audio_refs,
|
||||
files=files,
|
||||
)
|
||||
elif msg_type == "GROUP_AT_MESSAGE_CREATE":
|
||||
author = msg_body.get("author", {})
|
||||
@@ -138,7 +140,7 @@ class QQBotModule(_ModuleBase, _MessageBase[QQBot]):
|
||||
logger.info(
|
||||
f"收到 QQ 群消息: group={group_openid}, userid={member_openid}, "
|
||||
f"text={(content or '')[:50]}..., images={len(images) if images else 0}, "
|
||||
f"audios={len(audio_refs) if audio_refs else 0}"
|
||||
f"audios={len(audio_refs) if audio_refs else 0}, files={len(files) if files else 0}"
|
||||
)
|
||||
return CommingMessage(
|
||||
channel=MessageChannel.QQ,
|
||||
@@ -148,6 +150,7 @@ class QQBotModule(_ModuleBase, _MessageBase[QQBot]):
|
||||
text=content,
|
||||
images=images,
|
||||
audio_refs=audio_refs,
|
||||
files=files,
|
||||
)
|
||||
return None
|
||||
|
||||
@@ -226,6 +229,46 @@ class QQBotModule(_ModuleBase, _MessageBase[QQBot]):
|
||||
deduped.append(audio_ref)
|
||||
return deduped or None
|
||||
|
||||
@classmethod
|
||||
def _extract_files(
|
||||
cls, msg_body: dict
|
||||
) -> Optional[List[CommingMessage.MessageAttachment]]:
|
||||
files: List[CommingMessage.MessageAttachment] = []
|
||||
attachments = msg_body.get("attachments") or []
|
||||
if isinstance(attachments, list):
|
||||
for attachment in attachments:
|
||||
if not isinstance(attachment, dict):
|
||||
continue
|
||||
url = attachment.get("url") or attachment.get("proxy_url")
|
||||
if not url:
|
||||
continue
|
||||
content_type = (
|
||||
attachment.get("content_type")
|
||||
or attachment.get("mime_type")
|
||||
or ""
|
||||
).lower()
|
||||
filename = (
|
||||
attachment.get("filename") or attachment.get("name") or ""
|
||||
).lower()
|
||||
is_image = content_type.startswith("image/") or filename.endswith(
|
||||
cls._IMAGE_SUFFIXES
|
||||
)
|
||||
is_audio = content_type.startswith("audio/") or filename.endswith(
|
||||
cls._AUDIO_SUFFIXES
|
||||
)
|
||||
if is_image or is_audio:
|
||||
continue
|
||||
files.append(
|
||||
CommingMessage.MessageAttachment(
|
||||
ref=f"qq://file/{quote(url, safe='')}",
|
||||
name=attachment.get("filename") or attachment.get("name"),
|
||||
mime_type=attachment.get("content_type")
|
||||
or attachment.get("mime_type"),
|
||||
size=attachment.get("size"),
|
||||
)
|
||||
)
|
||||
return files or None
|
||||
|
||||
def download_qq_file_bytes(self, file_ref: str, source: str) -> Optional[bytes]:
|
||||
"""
|
||||
下载QQ音频附件并返回原始字节
|
||||
|
||||
@@ -221,12 +221,14 @@ class SlackModule(_ModuleBase, _MessageBase[Slack]):
|
||||
if msg_json:
|
||||
images = None
|
||||
audio_refs = None
|
||||
files = None
|
||||
if msg_json.get("type") == "message":
|
||||
userid = msg_json.get("user")
|
||||
text = msg_json.get("text")
|
||||
username = msg_json.get("user")
|
||||
images = self._extract_images(msg_json)
|
||||
audio_refs = self._extract_audio_refs(msg_json)
|
||||
files = self._extract_files(msg_json)
|
||||
elif msg_json.get("type") == "block_actions":
|
||||
userid = msg_json.get("user", {}).get("id")
|
||||
callback_data = msg_json.get("actions")[0].get("value")
|
||||
@@ -270,6 +272,7 @@ class SlackModule(_ModuleBase, _MessageBase[Slack]):
|
||||
username = ""
|
||||
images = self._extract_images(msg_json.get("event", {}))
|
||||
audio_refs = self._extract_audio_refs(msg_json.get("event", {}))
|
||||
files = self._extract_files(msg_json.get("event", {}))
|
||||
elif msg_json.get("type") == "shortcut":
|
||||
userid = msg_json.get("user", {}).get("id")
|
||||
text = msg_json.get("callback_id")
|
||||
@@ -282,7 +285,8 @@ class SlackModule(_ModuleBase, _MessageBase[Slack]):
|
||||
return None
|
||||
logger.info(
|
||||
f"收到来自 {client_config.name} 的Slack消息:userid={userid}, username={username}, "
|
||||
f"text={text}, images={len(images) if images else 0}, audios={len(audio_refs) if audio_refs else 0}"
|
||||
f"text={text}, images={len(images) if images else 0}, audios={len(audio_refs) if audio_refs else 0}, "
|
||||
f"files={len(files) if files else 0}"
|
||||
)
|
||||
return CommingMessage(
|
||||
channel=MessageChannel.Slack,
|
||||
@@ -292,6 +296,7 @@ class SlackModule(_ModuleBase, _MessageBase[Slack]):
|
||||
text=text,
|
||||
images=images,
|
||||
audio_refs=audio_refs,
|
||||
files=files,
|
||||
)
|
||||
return None
|
||||
|
||||
@@ -341,6 +346,48 @@ class SlackModule(_ModuleBase, _MessageBase[Slack]):
|
||||
audio_refs.append(f"slack://file/{quote(url, safe='')}")
|
||||
return audio_refs if audio_refs else None
|
||||
|
||||
@classmethod
|
||||
def _extract_files(
|
||||
cls, msg_json: dict
|
||||
) -> Optional[List[CommingMessage.MessageAttachment]]:
|
||||
"""
|
||||
从 Slack 消息中提取非图片/非音频文件。
|
||||
"""
|
||||
files = msg_json.get("files", [])
|
||||
if not files:
|
||||
return None
|
||||
|
||||
attachments = []
|
||||
for file in files:
|
||||
file_type = str(file.get("type", "")).lower()
|
||||
file_ext = f".{str(file.get('filetype', '')).lower().lstrip('.')}"
|
||||
mime_type = str(file.get("mimetype", "")).lower()
|
||||
is_image = (
|
||||
file_type == "image"
|
||||
or file_ext in (".jpg", ".jpeg", ".png", ".gif", ".webp", ".bmp")
|
||||
or mime_type.startswith("image/")
|
||||
)
|
||||
is_audio = (
|
||||
file_type == "audio"
|
||||
or mime_type.startswith("audio/")
|
||||
or file_ext in cls._AUDIO_SUFFIXES
|
||||
)
|
||||
if is_image or is_audio:
|
||||
continue
|
||||
|
||||
url = file.get("url_private_download") or file.get("url_private")
|
||||
if not url:
|
||||
continue
|
||||
attachments.append(
|
||||
CommingMessage.MessageAttachment(
|
||||
ref=f"slack://file/{quote(url, safe='')}",
|
||||
name=file.get("name") or file.get("title"),
|
||||
mime_type=file.get("mimetype"),
|
||||
size=file.get("size"),
|
||||
)
|
||||
)
|
||||
return attachments or None
|
||||
|
||||
def download_slack_file_to_data_url(self, file_url: str, source: str) -> Optional[str]:
|
||||
"""
|
||||
下载Slack文件并转为data URL
|
||||
@@ -399,16 +446,25 @@ class SlackModule(_ModuleBase, _MessageBase[Slack]):
|
||||
return
|
||||
client: Slack = self.get_instance(conf.name)
|
||||
if client:
|
||||
client.send_msg(
|
||||
title=message.title,
|
||||
text=message.text,
|
||||
image=message.image,
|
||||
userid=userid,
|
||||
link=message.link,
|
||||
buttons=message.buttons,
|
||||
original_message_id=message.original_message_id,
|
||||
original_chat_id=message.original_chat_id,
|
||||
)
|
||||
if message.file_path:
|
||||
client.send_file(
|
||||
file_path=message.file_path,
|
||||
file_name=message.file_name,
|
||||
title=message.title,
|
||||
text=message.text,
|
||||
userid=userid,
|
||||
)
|
||||
else:
|
||||
client.send_msg(
|
||||
title=message.title,
|
||||
text=message.text,
|
||||
image=message.image,
|
||||
userid=userid,
|
||||
link=message.link,
|
||||
buttons=message.buttons,
|
||||
original_message_id=message.original_message_id,
|
||||
original_chat_id=message.original_chat_id,
|
||||
)
|
||||
|
||||
def post_medias_message(
|
||||
self, message: Notification, medias: List[MediaInfo]
|
||||
@@ -538,26 +594,40 @@ class SlackModule(_ModuleBase, _MessageBase[Slack]):
|
||||
return None
|
||||
client: Slack = self.get_instance(conf.name)
|
||||
if client:
|
||||
result = client.send_msg(
|
||||
title=message.title or "",
|
||||
text=message.text,
|
||||
userid=userid,
|
||||
)
|
||||
if message.file_path:
|
||||
result = client.send_file(
|
||||
file_path=message.file_path,
|
||||
file_name=message.file_name,
|
||||
title=message.title,
|
||||
text=message.text,
|
||||
userid=userid,
|
||||
)
|
||||
else:
|
||||
result = client.send_msg(
|
||||
title=message.title or "",
|
||||
text=message.text,
|
||||
userid=userid,
|
||||
)
|
||||
if result and result[0]:
|
||||
# Slack 使用时间戳作为 message_id,chat_id 是频道ID
|
||||
# 注意:这里返回的是发送后的结果,需要获取实际的 message_id
|
||||
# 由于 Slack API 返回的是 result[1],包含完整响应,我们需要从中提取
|
||||
response_data = result[1]
|
||||
message_id = (
|
||||
response_data.get("ts")
|
||||
if isinstance(response_data, dict)
|
||||
else None
|
||||
)
|
||||
channel_id = (
|
||||
response_data.get("channel")
|
||||
if isinstance(response_data, dict)
|
||||
else None
|
||||
)
|
||||
message_id = None
|
||||
channel_id = None
|
||||
if hasattr(response_data, "get"):
|
||||
message_id = response_data.get("ts")
|
||||
channel_id = response_data.get("channel")
|
||||
if not message_id and hasattr(response_data, "data"):
|
||||
files = (response_data.data or {}).get("files") or []
|
||||
if files:
|
||||
message_id = files[0].get("id")
|
||||
shares = (
|
||||
files[0].get("shares", {})
|
||||
.get("private", {})
|
||||
)
|
||||
if shares:
|
||||
channel_id = next(iter(shares.keys()), None)
|
||||
return MessageResponse(
|
||||
message_id=message_id,
|
||||
chat_id=channel_id,
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
import re
|
||||
from threading import Lock
|
||||
from pathlib import Path
|
||||
from typing import List, Optional, Tuple
|
||||
from urllib.parse import quote
|
||||
|
||||
@@ -246,6 +247,48 @@ class Slack:
|
||||
logger.error(f"Slack消息发送失败: {msg_e}")
|
||||
return False, str(msg_e)
|
||||
|
||||
def send_file(
|
||||
self,
|
||||
file_path: str,
|
||||
title: Optional[str] = None,
|
||||
text: Optional[str] = None,
|
||||
userid: Optional[str] = None,
|
||||
file_name: Optional[str] = None,
|
||||
):
|
||||
"""
|
||||
发送本地文件到 Slack。
|
||||
"""
|
||||
if not self._client:
|
||||
return False, "消息客户端未就绪"
|
||||
if not file_path:
|
||||
return False, "文件路径不能为空"
|
||||
|
||||
local_file = Path(file_path)
|
||||
if not local_file.exists() or not local_file.is_file():
|
||||
return False, f"文件不存在: {local_file}"
|
||||
|
||||
try:
|
||||
if userid:
|
||||
channel = userid
|
||||
else:
|
||||
channel = self.__find_public_channel()
|
||||
|
||||
comment_parts = [part for part in [title, text] if part]
|
||||
initial_comment = "\n".join(comment_parts) if comment_parts else None
|
||||
|
||||
with local_file.open("rb") as fp:
|
||||
result = self._client.files_upload_v2(
|
||||
channel=channel,
|
||||
file=fp,
|
||||
filename=file_name or local_file.name,
|
||||
title=title or (file_name or local_file.name),
|
||||
initial_comment=initial_comment,
|
||||
)
|
||||
return True, result
|
||||
except Exception as err:
|
||||
logger.error(f"Slack文件发送失败: {err}")
|
||||
return False, str(err)
|
||||
|
||||
def send_medias_msg(self, medias: List[MediaInfo], userid: Optional[str] = None, title: Optional[str] = None,
|
||||
buttons: Optional[List[List[dict]]] = None,
|
||||
original_message_id: Optional[str] = None,
|
||||
|
||||
@@ -125,15 +125,17 @@ class SynologyChatModule(_ModuleBase, _MessageBase[SynologyChat]):
|
||||
user_name = message.get("username")
|
||||
images = self._extract_images(message)
|
||||
audio_refs = self._extract_audio_refs(message)
|
||||
if (text or images or audio_refs) and user_id:
|
||||
files = self._extract_files(message)
|
||||
if (text or images or audio_refs or files) and user_id:
|
||||
logger.info(
|
||||
f"收到来自 {client_config.name} 的SynologyChat消息:"
|
||||
f"userid={user_id}, username={user_name}, text={text}, "
|
||||
f"images={len(images) if images else 0}, audios={len(audio_refs) if audio_refs else 0}"
|
||||
f"images={len(images) if images else 0}, audios={len(audio_refs) if audio_refs else 0}, "
|
||||
f"files={len(files) if files else 0}"
|
||||
)
|
||||
return CommingMessage(channel=MessageChannel.SynologyChat, source=client_config.name,
|
||||
userid=user_id, username=user_name, text=text or "",
|
||||
images=images, audio_refs=audio_refs)
|
||||
images=images, audio_refs=audio_refs, files=files)
|
||||
except Exception as err:
|
||||
logger.debug(f"解析SynologyChat消息失败:{str(err)}")
|
||||
return None
|
||||
@@ -230,6 +232,56 @@ class SynologyChatModule(_ModuleBase, _MessageBase[SynologyChat]):
|
||||
suffix in lowered for suffix in cls._AUDIO_SUFFIXES
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def _extract_files(
|
||||
cls, message: dict
|
||||
) -> Optional[List[CommingMessage.MessageAttachment]]:
|
||||
files = []
|
||||
for key in ("attachments", "files"):
|
||||
raw_value = message.get(key)
|
||||
if not raw_value:
|
||||
continue
|
||||
try:
|
||||
parsed = json.loads(raw_value) if isinstance(raw_value, str) else raw_value
|
||||
except Exception:
|
||||
parsed = raw_value
|
||||
items = parsed if isinstance(parsed, list) else [parsed]
|
||||
for item in items:
|
||||
if not isinstance(item, dict):
|
||||
continue
|
||||
url = item.get("url") or item.get("file_url") or item.get("download_url")
|
||||
if not isinstance(url, str) or not url.startswith("http"):
|
||||
continue
|
||||
content_type = (
|
||||
item.get("content_type") or item.get("mime_type") or ""
|
||||
).lower()
|
||||
name = (item.get("name") or item.get("filename") or "").lower()
|
||||
is_image = content_type.startswith("image/") or name.endswith(
|
||||
cls._IMAGE_SUFFIXES
|
||||
) or cls._looks_like_image(url)
|
||||
is_audio = content_type.startswith("audio/") or name.endswith(
|
||||
cls._AUDIO_SUFFIXES
|
||||
) or cls._looks_like_audio(url)
|
||||
if is_image or is_audio:
|
||||
continue
|
||||
files.append(
|
||||
CommingMessage.MessageAttachment(
|
||||
ref=f"synology://file/{quote(url, safe='')}",
|
||||
name=item.get("name") or item.get("filename"),
|
||||
mime_type=item.get("content_type") or item.get("mime_type"),
|
||||
size=item.get("size"),
|
||||
)
|
||||
)
|
||||
|
||||
deduped = []
|
||||
seen_refs = set()
|
||||
for file_item in files:
|
||||
if file_item.ref in seen_refs:
|
||||
continue
|
||||
seen_refs.add(file_item.ref)
|
||||
deduped.append(file_item)
|
||||
return deduped or None
|
||||
|
||||
def download_synologychat_file_bytes(self, file_ref: str, source: str) -> Optional[bytes]:
|
||||
"""
|
||||
下载 Synology Chat 音频文件并返回原始字节
|
||||
|
||||
@@ -215,18 +215,20 @@ class TelegramModule(_ModuleBase, _MessageBase[Telegram]):
|
||||
|
||||
images = self._extract_images(msg)
|
||||
audio_refs = self._extract_audio_refs(msg)
|
||||
files = self._extract_files(msg)
|
||||
|
||||
if user_id:
|
||||
if not text and not images and not audio_refs:
|
||||
if not text and not images and not audio_refs and not files:
|
||||
logger.debug(
|
||||
f"收到来自 {client_config.name} 的Telegram消息无文本、图片和语音"
|
||||
f"收到来自 {client_config.name} 的Telegram消息无文本、图片、语音和文件"
|
||||
)
|
||||
return None
|
||||
|
||||
logger.info(
|
||||
f"收到来自 {client_config.name} 的Telegram消息:"
|
||||
f"userid={user_id}, username={user_name}, chat_id={chat_id}, text={text}, "
|
||||
f"images={len(images) if images else 0}, audios={len(audio_refs) if audio_refs else 0}"
|
||||
f"images={len(images) if images else 0}, audios={len(audio_refs) if audio_refs else 0}, "
|
||||
f"files={len(files) if files else 0}"
|
||||
)
|
||||
|
||||
cleaned_text = (
|
||||
@@ -266,6 +268,7 @@ class TelegramModule(_ModuleBase, _MessageBase[Telegram]):
|
||||
chat_id=str(chat_id) if chat_id else None,
|
||||
images=images if images else None,
|
||||
audio_refs=audio_refs if audio_refs else None,
|
||||
files=files if files else None,
|
||||
)
|
||||
return None
|
||||
|
||||
@@ -311,6 +314,29 @@ class TelegramModule(_ModuleBase, _MessageBase[Telegram]):
|
||||
|
||||
return audio_refs if audio_refs else None
|
||||
|
||||
@staticmethod
|
||||
def _extract_files(msg: dict) -> Optional[List[CommingMessage.MessageAttachment]]:
|
||||
"""
|
||||
从 Telegram 消息中提取非图片文件附件。
|
||||
"""
|
||||
document = msg.get("document")
|
||||
if not isinstance(document, dict):
|
||||
return None
|
||||
|
||||
file_id = document.get("file_id")
|
||||
mime_type = (document.get("mime_type") or "").lower()
|
||||
if not file_id or mime_type.startswith("image/"):
|
||||
return None
|
||||
|
||||
return [
|
||||
CommingMessage.MessageAttachment(
|
||||
ref=f"tg://document_file_id/{file_id}",
|
||||
name=document.get("file_name"),
|
||||
mime_type=document.get("mime_type"),
|
||||
size=document.get("file_size"),
|
||||
)
|
||||
]
|
||||
|
||||
@staticmethod
|
||||
def _embed_entity_links(text: str, entities: Optional[List[dict]]) -> str:
|
||||
"""
|
||||
@@ -412,7 +438,16 @@ class TelegramModule(_ModuleBase, _MessageBase[Telegram]):
|
||||
return
|
||||
client: Telegram = self.get_instance(conf.name)
|
||||
if client:
|
||||
if message.voice_path:
|
||||
if message.file_path:
|
||||
client.send_file(
|
||||
file_path=message.file_path,
|
||||
file_name=message.file_name,
|
||||
title=message.title,
|
||||
text=message.text,
|
||||
userid=userid,
|
||||
original_chat_id=message.original_chat_id,
|
||||
)
|
||||
elif message.voice_path:
|
||||
client.send_voice(
|
||||
voice_path=message.voice_path,
|
||||
userid=userid,
|
||||
|
||||
@@ -507,6 +507,70 @@ class Telegram:
|
||||
except Exception as cleanup_err:
|
||||
logger.debug(f"清理语音临时文件失败: {cleanup_err}")
|
||||
|
||||
def send_file(
|
||||
self,
|
||||
file_path: str,
|
||||
userid: Optional[str] = None,
|
||||
title: Optional[str] = None,
|
||||
text: Optional[str] = None,
|
||||
file_name: Optional[str] = None,
|
||||
original_chat_id: Optional[str] = None,
|
||||
) -> Optional[dict]:
|
||||
"""
|
||||
发送本地图片或文件给 Telegram 用户。
|
||||
"""
|
||||
if not self._bot or not file_path:
|
||||
return None
|
||||
|
||||
local_file = Path(file_path)
|
||||
if not local_file.exists() or not local_file.is_file():
|
||||
logger.error(f"附件文件不存在: {local_file}")
|
||||
return {"success": False}
|
||||
|
||||
chat_id = self._determine_target_chat_id(userid, original_chat_id)
|
||||
send_name = file_name or local_file.name
|
||||
suffix = local_file.suffix.lower()
|
||||
is_image = suffix in {".png", ".jpg", ".jpeg", ".gif", ".webp", ".bmp"}
|
||||
|
||||
try:
|
||||
bold_title = (
|
||||
f"**{standardize(title).removesuffix('\n')}**" if title else None
|
||||
)
|
||||
if bold_title and text:
|
||||
caption = f"{bold_title}\n{text}"
|
||||
elif bold_title:
|
||||
caption = bold_title
|
||||
else:
|
||||
caption = text or ""
|
||||
|
||||
with local_file.open("rb") as fp:
|
||||
if is_image:
|
||||
sent = self._bot.send_photo(
|
||||
chat_id=chat_id,
|
||||
photo=fp,
|
||||
caption=standardize(caption) if caption else None,
|
||||
parse_mode="MarkdownV2" if caption else None,
|
||||
)
|
||||
else:
|
||||
sent = self._bot.send_document(
|
||||
chat_id=chat_id,
|
||||
document=(send_name, fp),
|
||||
caption=standardize(caption) if caption else None,
|
||||
parse_mode="MarkdownV2" if caption else None,
|
||||
)
|
||||
self._stop_typing_task(chat_id)
|
||||
if sent and hasattr(sent, "message_id"):
|
||||
return {
|
||||
"success": True,
|
||||
"message_id": sent.message_id,
|
||||
"chat_id": sent.chat.id if hasattr(sent, "chat") else chat_id,
|
||||
}
|
||||
return {"success": bool(sent)}
|
||||
except Exception as err:
|
||||
logger.error(f"发送本地附件失败: {err}")
|
||||
self._stop_typing_task(chat_id)
|
||||
return {"success": False}
|
||||
|
||||
def _determine_target_chat_id(
|
||||
self, userid: Optional[str] = None, original_chat_id: Optional[str] = None
|
||||
) -> str:
|
||||
|
||||
@@ -133,6 +133,7 @@ class VoceChatModule(_ModuleBase, _MessageBase[VoceChat]):
|
||||
content = detail.get("content")
|
||||
images = self._extract_images(detail)
|
||||
audio_refs = self._extract_audio_refs(detail)
|
||||
files = self._extract_files(detail)
|
||||
text = None
|
||||
if content_type in ("text/plain", "text/markdown") and isinstance(content, str):
|
||||
text = content
|
||||
@@ -147,15 +148,15 @@ class VoceChatModule(_ModuleBase, _MessageBase[VoceChat]):
|
||||
userid = f"UID#{msg_body.get('from_uid')}"
|
||||
|
||||
# 处理消息内容
|
||||
if (text or images or audio_refs) and userid:
|
||||
if (text or images or audio_refs or files) and userid:
|
||||
logger.info(
|
||||
f"收到来自 {client_config.name} 的VoceChat消息:"
|
||||
f"userid={userid}, text={text}, images={len(images) if images else 0}, "
|
||||
f"audios={len(audio_refs) if audio_refs else 0}"
|
||||
f"audios={len(audio_refs) if audio_refs else 0}, files={len(files) if files else 0}"
|
||||
)
|
||||
return CommingMessage(channel=MessageChannel.VoceChat, source=client_config.name,
|
||||
userid=userid, username=userid, text=text or "",
|
||||
images=images, audio_refs=audio_refs)
|
||||
images=images, audio_refs=audio_refs, files=files)
|
||||
except Exception as err:
|
||||
logger.error(f"VoceChat消息处理发生错误:{str(err)}")
|
||||
return None
|
||||
@@ -229,6 +230,51 @@ class VoceChatModule(_ModuleBase, _MessageBase[VoceChat]):
|
||||
return [f"vocechat://file/{quote(file_path, safe='')}"]
|
||||
return None
|
||||
|
||||
@classmethod
|
||||
def _extract_files(
|
||||
cls, detail: dict
|
||||
) -> Optional[List[CommingMessage.MessageAttachment]]:
|
||||
content_type = detail.get("content_type") or ""
|
||||
if content_type != "vocechat/file":
|
||||
return None
|
||||
properties = detail.get("properties") or {}
|
||||
mime_type = (
|
||||
properties.get("content_type")
|
||||
or properties.get("mime_type")
|
||||
or properties.get("contentType")
|
||||
or ""
|
||||
).lower()
|
||||
file_path = (
|
||||
properties.get("path")
|
||||
or properties.get("file_path")
|
||||
or properties.get("storage_path")
|
||||
or detail.get("content")
|
||||
)
|
||||
file_name = (
|
||||
properties.get("name")
|
||||
or properties.get("filename")
|
||||
or (str(file_path).rsplit("/", 1)[-1] if file_path else "")
|
||||
)
|
||||
lowered_name = str(file_name).lower()
|
||||
is_image = mime_type.startswith("image/") or lowered_name.endswith(
|
||||
cls._IMAGE_SUFFIXES
|
||||
)
|
||||
is_audio = mime_type.startswith("audio/") or lowered_name.endswith(
|
||||
cls._AUDIO_SUFFIXES
|
||||
)
|
||||
if is_image or is_audio or not isinstance(file_path, str) or not file_path:
|
||||
return None
|
||||
return [
|
||||
CommingMessage.MessageAttachment(
|
||||
ref=f"vocechat://file/{quote(file_path, safe='')}",
|
||||
name=file_name,
|
||||
mime_type=properties.get("content_type")
|
||||
or properties.get("mime_type")
|
||||
or properties.get("contentType"),
|
||||
size=properties.get("size"),
|
||||
)
|
||||
]
|
||||
|
||||
def post_message(self, message: Notification, **kwargs) -> None:
|
||||
"""
|
||||
发送消息
|
||||
|
||||
@@ -3,6 +3,7 @@ import json
|
||||
import re
|
||||
import xml.dom.minidom
|
||||
from typing import Optional, Union, List, Tuple, Any, Dict
|
||||
from urllib.parse import quote
|
||||
|
||||
from app.core.context import Context, MediaInfo
|
||||
from app.core.event import eventmanager
|
||||
@@ -168,6 +169,7 @@ class WechatModule(_ModuleBase, _MessageBase[WeChat]):
|
||||
content = None
|
||||
images = None
|
||||
audio_refs = None
|
||||
files = None
|
||||
if msg_type == "event" and event == "click":
|
||||
# 校验用户有权限执行交互命令
|
||||
if client_config.config.get('WECHAT_ADMINS'):
|
||||
@@ -203,14 +205,27 @@ class WechatModule(_ModuleBase, _MessageBase[WeChat]):
|
||||
f"收到来自 {client_config.name} 的微信语音消息:userid={user_id}, "
|
||||
f"text={content}, audios={len(audio_refs) if audio_refs else 0}"
|
||||
)
|
||||
elif msg_type == "file":
|
||||
media_id = DomUtils.tag_value(root_node, "MediaId")
|
||||
file_name = DomUtils.tag_value(root_node, "FileName")
|
||||
if media_id:
|
||||
files = [
|
||||
CommingMessage.MessageAttachment(
|
||||
ref=f"wxwork://file_media_id/{media_id}",
|
||||
name=file_name,
|
||||
)
|
||||
]
|
||||
logger.info(
|
||||
f"收到来自 {client_config.name} 的微信文件消息:userid={user_id}, files={len(files) if files else 0}"
|
||||
)
|
||||
else:
|
||||
return None
|
||||
|
||||
if content or images or audio_refs:
|
||||
if content or images or audio_refs or files:
|
||||
# 处理消息内容
|
||||
return CommingMessage(channel=MessageChannel.Wechat, source=client_config.name,
|
||||
userid=user_id, username=user_id, text=content or "",
|
||||
images=images, audio_refs=audio_refs)
|
||||
images=images, audio_refs=audio_refs, files=files)
|
||||
except Exception as err:
|
||||
logger.error(f"微信消息处理发生错误:{str(err)}")
|
||||
return None
|
||||
@@ -242,6 +257,20 @@ class WechatModule(_ModuleBase, _MessageBase[WeChat]):
|
||||
text = WeChatBot._extract_text_from_body(payload_body)
|
||||
images = WeChatBot._extract_images_from_body(payload_body)
|
||||
audio_refs = ["wxbot://voice"] if payload_body.get("msgtype") == "voice" else None
|
||||
files = None
|
||||
if payload_body.get("msgtype") == "file":
|
||||
file_payload = payload_body.get("file") or {}
|
||||
download_url = file_payload.get("download_url")
|
||||
if download_url:
|
||||
files = [
|
||||
CommingMessage.MessageAttachment(
|
||||
ref=f"wxbot://file/{quote(download_url, safe='')}",
|
||||
name=file_payload.get("name") or file_payload.get("filename"),
|
||||
mime_type=file_payload.get("content_type")
|
||||
or file_payload.get("mime_type"),
|
||||
size=file_payload.get("size"),
|
||||
)
|
||||
]
|
||||
if text:
|
||||
text = re.sub(r"@\S+", "", text).strip()
|
||||
|
||||
@@ -257,7 +286,7 @@ class WechatModule(_ModuleBase, _MessageBase[WeChat]):
|
||||
client.send_msg(title="只有管理员才有权限执行此命令", userid=sender)
|
||||
return None
|
||||
|
||||
if not text and not images and not audio_refs:
|
||||
if not text and not images and not audio_refs and not files:
|
||||
return None
|
||||
|
||||
logger.info(
|
||||
@@ -272,6 +301,7 @@ class WechatModule(_ModuleBase, _MessageBase[WeChat]):
|
||||
text=text or "",
|
||||
images=images,
|
||||
audio_refs=audio_refs,
|
||||
files=files,
|
||||
)
|
||||
|
||||
def post_message(self, message: Notification, **kwargs) -> None:
|
||||
@@ -338,6 +368,9 @@ class WechatModule(_ModuleBase, _MessageBase[WeChat]):
|
||||
if media_ref.startswith("wxwork://voice_media_id/"):
|
||||
media_id = media_ref.replace("wxwork://voice_media_id/", "", 1)
|
||||
return client.download_media_bytes(media_id)
|
||||
if media_ref.startswith("wxwork://file_media_id/"):
|
||||
media_id = media_ref.replace("wxwork://file_media_id/", "", 1)
|
||||
return client.download_media_bytes(media_id)
|
||||
return None
|
||||
|
||||
def post_medias_message(self, message: Notification, medias: List[MediaInfo]) -> None:
|
||||
|
||||
@@ -29,6 +29,16 @@ class CommingMessage(BaseModel):
|
||||
外来消息
|
||||
"""
|
||||
|
||||
class MessageAttachment(BaseModel):
|
||||
"""
|
||||
外来消息附件(非图片/非语音)
|
||||
"""
|
||||
|
||||
ref: str
|
||||
name: Optional[str] = None
|
||||
mime_type: Optional[str] = None
|
||||
size: Optional[int] = None
|
||||
|
||||
# 用户ID
|
||||
userid: Optional[Union[str, int]] = None
|
||||
# 用户名称
|
||||
@@ -57,6 +67,8 @@ class CommingMessage(BaseModel):
|
||||
images: Optional[List[str]] = None
|
||||
# 语音/音频引用列表
|
||||
audio_refs: Optional[List[str]] = None
|
||||
# 文件附件列表
|
||||
files: Optional[List[MessageAttachment]] = None
|
||||
|
||||
def to_dict(self):
|
||||
"""
|
||||
@@ -90,6 +102,10 @@ class Notification(BaseModel):
|
||||
image: Optional[str] = None
|
||||
# 语音文件路径
|
||||
voice_path: Optional[str] = None
|
||||
# 本地文件路径
|
||||
file_path: Optional[str] = None
|
||||
# 发送时展示的文件名
|
||||
file_name: Optional[str] = None
|
||||
# 语音消息附带说明文字
|
||||
voice_caption: Optional[str] = None
|
||||
# 链接
|
||||
@@ -254,6 +270,7 @@ class ChannelCapabilityManager:
|
||||
ChannelCapability.IMAGES,
|
||||
ChannelCapability.LINKS,
|
||||
ChannelCapability.MENU_COMMANDS,
|
||||
ChannelCapability.FILE_SENDING,
|
||||
},
|
||||
max_buttons_per_row=3,
|
||||
max_button_rows=8,
|
||||
@@ -272,6 +289,7 @@ class ChannelCapabilityManager:
|
||||
ChannelCapability.RICH_TEXT,
|
||||
ChannelCapability.IMAGES,
|
||||
ChannelCapability.LINKS,
|
||||
ChannelCapability.FILE_SENDING,
|
||||
},
|
||||
max_buttons_per_row=5,
|
||||
max_button_rows=5,
|
||||
|
||||
@@ -1,6 +1,8 @@
|
||||
import base64
|
||||
import json
|
||||
import tempfile
|
||||
import unittest
|
||||
from pathlib import Path
|
||||
from types import SimpleNamespace
|
||||
from unittest.mock import AsyncMock, Mock, patch
|
||||
from urllib.parse import quote
|
||||
@@ -8,6 +10,7 @@ from urllib.parse import quote
|
||||
from telebot import apihelper
|
||||
|
||||
from app.agent.tools.impl.send_message import SendMessageInput
|
||||
from app.agent.tools.impl.send_local_file import SendLocalFileInput
|
||||
from app.agent import MoviePilotAgent, AgentChain
|
||||
from app.chain.message import MessageChain
|
||||
from app.core.config import settings
|
||||
@@ -161,6 +164,32 @@ class AgentImageSupportTest(unittest.TestCase):
|
||||
self.assertEqual(handle_kwargs["text"], "")
|
||||
self.assertEqual(handle_kwargs["audio_refs"], ["tg://voice_file_id/voice-1"])
|
||||
|
||||
def test_process_allows_file_only_message(self):
|
||||
chain = MessageChain()
|
||||
message = CommingMessage(
|
||||
channel=MessageChannel.Telegram,
|
||||
source="telegram-test",
|
||||
userid="10001",
|
||||
username="tester",
|
||||
files=[
|
||||
CommingMessage.MessageAttachment(
|
||||
ref="tg://document_file_id/doc-1",
|
||||
name="note.txt",
|
||||
mime_type="text/plain",
|
||||
size=12,
|
||||
)
|
||||
],
|
||||
)
|
||||
|
||||
with patch.object(chain, "message_parser", return_value=message), patch.object(
|
||||
chain, "handle_message"
|
||||
) as handle_message:
|
||||
chain.process(body="{}", form={}, args={"source": "telegram-test"})
|
||||
|
||||
handle_kwargs = handle_message.call_args.kwargs
|
||||
self.assertEqual(handle_kwargs["text"], "")
|
||||
self.assertEqual(handle_kwargs["files"][0].ref, "tg://document_file_id/doc-1")
|
||||
|
||||
def test_image_message_routes_to_agent_even_when_global_agent_is_disabled(self):
|
||||
chain = MessageChain()
|
||||
|
||||
@@ -205,6 +234,36 @@ class AgentImageSupportTest(unittest.TestCase):
|
||||
self.assertEqual(handle_ai_message.call_args.kwargs["text"], "帮我推荐一部电影")
|
||||
self.assertTrue(handle_ai_message.call_args.kwargs["reply_with_voice"])
|
||||
|
||||
def test_file_message_routes_to_agent_even_when_global_agent_is_disabled(self):
|
||||
chain = MessageChain()
|
||||
|
||||
with patch.object(chain, "load_cache", return_value={}), patch.object(
|
||||
chain.messagehelper, "put"
|
||||
), patch.object(chain.messageoper, "add"), patch.object(
|
||||
chain, "_handle_ai_message"
|
||||
) as handle_ai_message, patch.object(
|
||||
settings, "AI_AGENT_ENABLE", True
|
||||
), patch.object(
|
||||
settings, "AI_AGENT_GLOBAL", False
|
||||
):
|
||||
chain.handle_message(
|
||||
channel=MessageChannel.Telegram,
|
||||
source="telegram-test",
|
||||
userid="10001",
|
||||
username="tester",
|
||||
text="",
|
||||
files=[
|
||||
CommingMessage.MessageAttachment(
|
||||
ref="tg://document_file_id/doc-1",
|
||||
name="report.txt",
|
||||
mime_type="text/plain",
|
||||
)
|
||||
],
|
||||
)
|
||||
|
||||
handle_ai_message.assert_called_once()
|
||||
self.assertEqual(handle_ai_message.call_args.kwargs["files"][0].name, "report.txt")
|
||||
|
||||
def test_transcribe_audio_refs_supports_new_channel_refs(self):
|
||||
chain = MessageChain()
|
||||
audio_refs = [
|
||||
@@ -276,6 +335,41 @@ class AgentImageSupportTest(unittest.TestCase):
|
||||
self.assertIsNone(notification.voice_path)
|
||||
self.assertEqual(notification.text, "这是语音回复")
|
||||
|
||||
def test_agent_process_wraps_request_as_structured_json(self):
|
||||
agent = MoviePilotAgent(
|
||||
session_id="session-1",
|
||||
user_id="user-1",
|
||||
channel=MessageChannel.Telegram.value,
|
||||
source="telegram-test",
|
||||
username="tester",
|
||||
)
|
||||
|
||||
with patch(
|
||||
"app.agent.memory.memory_manager.get_agent_messages", return_value=[]
|
||||
), patch.object(agent, "_execute_agent", new_callable=AsyncMock) as execute_agent:
|
||||
import asyncio
|
||||
|
||||
asyncio.run(
|
||||
agent.process(
|
||||
"帮我总结这个文件",
|
||||
files=[
|
||||
{
|
||||
"name": "report.txt",
|
||||
"local_path": "/tmp/report.txt",
|
||||
"status": "ready",
|
||||
}
|
||||
],
|
||||
)
|
||||
)
|
||||
|
||||
messages = execute_agent.await_args.args[0]
|
||||
human_message = messages[-1]
|
||||
content = human_message.content
|
||||
self.assertIsInstance(content, list)
|
||||
payload = json.loads(content[0]["text"])
|
||||
self.assertEqual(payload["message"], "帮我总结这个文件")
|
||||
self.assertEqual(payload["files"][0]["local_path"], "/tmp/report.txt")
|
||||
|
||||
def test_slack_images_use_authenticated_data_url_download(self):
|
||||
chain = MessageChain()
|
||||
|
||||
@@ -345,6 +439,15 @@ class AgentImageSupportTest(unittest.TestCase):
|
||||
|
||||
self.assertEqual(payload.image_url, "https://example.com/poster.png")
|
||||
|
||||
def test_send_local_file_input_accepts_file_payload(self):
|
||||
payload = SendLocalFileInput(
|
||||
explanation="send generated report",
|
||||
file_path="/tmp/report.txt",
|
||||
message="请下载查看",
|
||||
)
|
||||
|
||||
self.assertEqual(payload.file_path, "/tmp/report.txt")
|
||||
|
||||
def test_discord_extract_images_supports_attachment_content_type(self):
|
||||
images = DiscordModule._extract_images(
|
||||
{
|
||||
@@ -380,6 +483,26 @@ class AgentImageSupportTest(unittest.TestCase):
|
||||
],
|
||||
)
|
||||
|
||||
def test_discord_extract_files_supports_non_media_attachment(self):
|
||||
files = DiscordModule._extract_files(
|
||||
{
|
||||
"attachments": [
|
||||
{
|
||||
"content_type": "application/pdf",
|
||||
"filename": "guide.pdf",
|
||||
"url": "https://cdn.discordapp.com/guide.pdf",
|
||||
"size": 1024,
|
||||
}
|
||||
]
|
||||
}
|
||||
)
|
||||
|
||||
self.assertEqual(files[0].name, "guide.pdf")
|
||||
self.assertEqual(
|
||||
files[0].ref,
|
||||
"discord://file/" + quote("https://cdn.discordapp.com/guide.pdf", safe=""),
|
||||
)
|
||||
|
||||
def test_discord_send_direct_message_returns_chat_id(self):
|
||||
module = DiscordModule()
|
||||
client = Mock()
|
||||
@@ -466,6 +589,46 @@ class AgentImageSupportTest(unittest.TestCase):
|
||||
self.assertIsNotNone(message)
|
||||
self.assertEqual(message.images, ["wxwork://media_id/media-1"])
|
||||
|
||||
def test_wechat_message_parser_extracts_file_media_id(self):
|
||||
module = WechatModule()
|
||||
xml_message = b"""
|
||||
<xml>
|
||||
<FromUserName><![CDATA[user-1]]></FromUserName>
|
||||
<MsgType><![CDATA[file]]></MsgType>
|
||||
<MediaId><![CDATA[file-media-1]]></MediaId>
|
||||
<FileName><![CDATA[manual.pdf]]></FileName>
|
||||
</xml>
|
||||
"""
|
||||
crypt = Mock()
|
||||
crypt.DecryptMsg.return_value = (0, xml_message)
|
||||
|
||||
with patch.object(
|
||||
module,
|
||||
"get_config",
|
||||
return_value=SimpleNamespace(
|
||||
name="wechat-test",
|
||||
config={
|
||||
"WECHAT_TOKEN": "token",
|
||||
"WECHAT_ENCODING_AESKEY": "encoding",
|
||||
"WECHAT_CORPID": "corpid",
|
||||
},
|
||||
),
|
||||
), patch.object(
|
||||
module, "get_instance", return_value=SimpleNamespace(send_msg=Mock())
|
||||
), patch(
|
||||
"app.modules.wechat.WXBizMsgCrypt",
|
||||
return_value=crypt,
|
||||
):
|
||||
message = module.message_parser(
|
||||
source="wechat-test",
|
||||
body=b"encrypted",
|
||||
form={},
|
||||
args={"msg_signature": "sig", "timestamp": "1", "nonce": "n"},
|
||||
)
|
||||
|
||||
self.assertIsNotNone(message)
|
||||
self.assertEqual(message.files[0].ref, "wxwork://file_media_id/file-media-1")
|
||||
|
||||
def test_wechat_bot_parser_accepts_image_only_payload(self):
|
||||
module = WechatModule()
|
||||
body = json.dumps(
|
||||
@@ -594,6 +757,38 @@ class AgentImageSupportTest(unittest.TestCase):
|
||||
["vocechat://file/%2Fuploads%2Fvoice.ogg"],
|
||||
)
|
||||
|
||||
def test_vocechat_message_parser_extracts_generic_file_payload(self):
|
||||
module = VoceChatModule()
|
||||
body = json.dumps(
|
||||
{
|
||||
"detail": {
|
||||
"type": "normal",
|
||||
"content_type": "vocechat/file",
|
||||
"content": "/uploads/manual.pdf",
|
||||
"properties": {"content_type": "application/pdf"},
|
||||
},
|
||||
"from_uid": 7910,
|
||||
"target": {"gid": 2},
|
||||
}
|
||||
)
|
||||
|
||||
with patch.object(
|
||||
module,
|
||||
"get_config",
|
||||
return_value=SimpleNamespace(
|
||||
name="vocechat-test", config={"channel_id": "2"}
|
||||
),
|
||||
):
|
||||
message = module.message_parser(
|
||||
source="vocechat-test",
|
||||
body=body,
|
||||
form={},
|
||||
args={},
|
||||
)
|
||||
|
||||
self.assertIsNotNone(message)
|
||||
self.assertEqual(message.files[0].ref, "vocechat://file/%2Fuploads%2Fmanual.pdf")
|
||||
|
||||
def test_vocechat_post_message_passes_image_and_correct_target(self):
|
||||
module = VoceChatModule()
|
||||
client = Mock()
|
||||
@@ -623,6 +818,64 @@ class AgentImageSupportTest(unittest.TestCase):
|
||||
link=None,
|
||||
)
|
||||
|
||||
def test_slack_post_message_passes_local_file(self):
|
||||
module = SlackModule()
|
||||
client = Mock()
|
||||
|
||||
with tempfile.TemporaryDirectory() as tempdir:
|
||||
file_path = Path(tempdir) / "guide.pdf"
|
||||
file_path.write_bytes(b"pdf")
|
||||
|
||||
with patch.object(
|
||||
module,
|
||||
"get_configs",
|
||||
return_value={"slack-test": SimpleNamespace(name="slack-test")},
|
||||
), patch.object(
|
||||
module, "check_message", return_value=True
|
||||
), patch.object(
|
||||
module, "get_instance", return_value=client
|
||||
):
|
||||
module.post_message(
|
||||
Notification(
|
||||
title="手册",
|
||||
text="请下载",
|
||||
file_path=str(file_path),
|
||||
file_name="guide.pdf",
|
||||
userid="U123",
|
||||
)
|
||||
)
|
||||
|
||||
client.send_file.assert_called_once()
|
||||
|
||||
def test_discord_post_message_passes_local_file(self):
|
||||
module = DiscordModule()
|
||||
client = Mock()
|
||||
|
||||
with tempfile.TemporaryDirectory() as tempdir:
|
||||
file_path = Path(tempdir) / "guide.pdf"
|
||||
file_path.write_bytes(b"pdf")
|
||||
|
||||
with patch.object(
|
||||
module,
|
||||
"get_configs",
|
||||
return_value={"discord-test": SimpleNamespace(name="discord-test")},
|
||||
), patch.object(
|
||||
module, "check_message", return_value=True
|
||||
), patch.object(
|
||||
module, "get_instance", return_value=client
|
||||
):
|
||||
module.post_message(
|
||||
Notification(
|
||||
title="手册",
|
||||
text="请下载",
|
||||
file_path=str(file_path),
|
||||
file_name="guide.pdf",
|
||||
userid="user-1",
|
||||
)
|
||||
)
|
||||
|
||||
client.send_file.assert_called_once()
|
||||
|
||||
def test_qq_message_parser_accepts_image_only_attachment(self):
|
||||
module = QQBotModule()
|
||||
|
||||
@@ -745,5 +998,97 @@ class AgentImageSupportTest(unittest.TestCase):
|
||||
["synology://file/" + quote("https://example.com/voice.ogg", safe="")],
|
||||
)
|
||||
|
||||
def test_synology_message_parser_accepts_generic_file_attachment(self):
|
||||
module = SynologyChatModule()
|
||||
|
||||
with patch.object(
|
||||
module,
|
||||
"get_config",
|
||||
return_value=SimpleNamespace(name="synology-test", config={}),
|
||||
), patch.object(
|
||||
module,
|
||||
"get_instance",
|
||||
return_value=SimpleNamespace(check_token=lambda token: token == "token-1"),
|
||||
):
|
||||
message = module.message_parser(
|
||||
source="synology-test",
|
||||
body={},
|
||||
form={
|
||||
"token": "token-1",
|
||||
"user_id": "42",
|
||||
"username": "tester",
|
||||
"attachments": json.dumps(
|
||||
[
|
||||
{
|
||||
"url": "https://example.com/manual.pdf",
|
||||
"content_type": "application/pdf",
|
||||
"filename": "manual.pdf",
|
||||
}
|
||||
]
|
||||
),
|
||||
},
|
||||
args={},
|
||||
)
|
||||
|
||||
self.assertIsNotNone(message)
|
||||
self.assertEqual(
|
||||
message.files[0].ref,
|
||||
"synology://file/" + quote("https://example.com/manual.pdf", safe=""),
|
||||
)
|
||||
|
||||
def test_prepare_agent_files_saves_local_file(self):
|
||||
chain = MessageChain()
|
||||
with tempfile.TemporaryDirectory() as tempdir, patch.object(
|
||||
settings, "TEMP_PATH", Path(tempdir)
|
||||
), patch.object(
|
||||
chain,
|
||||
"_download_message_file_bytes",
|
||||
return_value="你好,MoviePilot".encode("utf-8"),
|
||||
):
|
||||
prepared = chain._prepare_agent_files(
|
||||
session_id="session-1",
|
||||
files=[
|
||||
CommingMessage.MessageAttachment(
|
||||
ref="tg://document_file_id/doc-1",
|
||||
name="note.txt",
|
||||
mime_type="text/plain",
|
||||
)
|
||||
],
|
||||
channel=MessageChannel.Telegram,
|
||||
source="telegram-test",
|
||||
)
|
||||
|
||||
self.assertEqual(prepared[0]["status"], "ready")
|
||||
self.assertTrue(Path(prepared[0]["local_path"]).exists())
|
||||
|
||||
def test_telegram_post_message_passes_file_to_client(self):
|
||||
module = TelegramModule()
|
||||
client = Mock()
|
||||
|
||||
with tempfile.TemporaryDirectory() as tempdir:
|
||||
file_path = Path(tempdir) / "report.txt"
|
||||
file_path.write_text("hello", encoding="utf-8")
|
||||
|
||||
with patch.object(
|
||||
module,
|
||||
"get_configs",
|
||||
return_value={"telegram-test": SimpleNamespace(name="telegram-test")},
|
||||
), patch.object(
|
||||
module, "check_message", return_value=True
|
||||
), patch.object(
|
||||
module, "get_instance", return_value=client
|
||||
):
|
||||
module.post_message(
|
||||
Notification(
|
||||
title="报告",
|
||||
text="请下载",
|
||||
file_path=str(file_path),
|
||||
file_name="report.txt",
|
||||
userid="user-1",
|
||||
)
|
||||
)
|
||||
|
||||
client.send_file.assert_called_once()
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
|
||||
Reference in New Issue
Block a user