mirror of
https://github.com/jxxghp/MoviePilot.git
synced 2026-05-06 20:42:43 +08:00
Add agent image support for Telegram and Slack
This commit is contained in:
@@ -9,6 +9,7 @@ Core Capabilities:
|
||||
2. Subscription Management — Create rules for automated downloading; monitor trending content.
|
||||
3. Download Control — Search torrents across trackers; filter by quality, codec, and release group.
|
||||
4. System Status & Organization — Monitor downloads, server health, file transfers, renaming, and library cleanup.
|
||||
5. Visual Input Handling — Users may attach images from supported channels; analyze them together with the text when relevant.
|
||||
|
||||
<communication>
|
||||
{verbose_spec}
|
||||
@@ -19,6 +20,7 @@ Core Capabilities:
|
||||
- Use Markdown for structured data. Use `inline code` for media titles/paths.
|
||||
- Include key details (year, rating, resolution) but do NOT over-explain.
|
||||
- Do not stop for approval on read-only operations. Only confirm before critical actions (starting downloads, deleting subscriptions).
|
||||
- If the current channel supports image sending and an image would materially help, you may use the `send_message` tool with `image_url` to send it.
|
||||
- NOT a coding assistant. Do not offer code snippets.
|
||||
- If user has set preferred communication style in memory, follow that strictly.
|
||||
</communication>
|
||||
|
||||
@@ -249,7 +249,9 @@ class MoviePilotTool(BaseTool, metaclass=ABCMeta):
|
||||
|
||||
return None
|
||||
|
||||
async def send_tool_message(self, message: str, title: str = ""):
|
||||
async def send_tool_message(
|
||||
self, message: str, title: str = "", image: Optional[str] = None
|
||||
):
|
||||
"""
|
||||
发送工具消息
|
||||
"""
|
||||
@@ -261,5 +263,6 @@ class MoviePilotTool(BaseTool, metaclass=ABCMeta):
|
||||
username=self._username,
|
||||
title=title,
|
||||
text=message,
|
||||
image=image,
|
||||
)
|
||||
)
|
||||
|
||||
@@ -2,7 +2,7 @@
|
||||
|
||||
from typing import Optional, Type
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
from pydantic import BaseModel, Field, model_validator
|
||||
|
||||
from app.agent.tools.base import MoviePilotTool
|
||||
from app.log import logger
|
||||
@@ -15,42 +15,64 @@ class SendMessageInput(BaseModel):
|
||||
...,
|
||||
description="Clear explanation of why this tool is being used in the current context",
|
||||
)
|
||||
message: str = Field(
|
||||
...,
|
||||
message: Optional[str] = Field(
|
||||
None,
|
||||
description="The message content to send to the user (should be clear and informative)",
|
||||
)
|
||||
message_type: Optional[str] = Field(
|
||||
None,
|
||||
description="Title of the message, a short summary of the message content",
|
||||
)
|
||||
image_url: Optional[str] = Field(
|
||||
None,
|
||||
description="Optional image URL to send together with the message on channels that support images (such as Telegram and Slack)",
|
||||
)
|
||||
|
||||
@model_validator(mode="after")
|
||||
def validate_payload(self):
|
||||
if not self.message and not self.message_type and not self.image_url:
|
||||
raise ValueError("message、message_type、image_url 至少需要提供一个")
|
||||
return self
|
||||
|
||||
|
||||
class SendMessageTool(MoviePilotTool):
|
||||
name: str = "send_message"
|
||||
description: str = "Send notification message to the user through configured notification channels (Telegram, Slack, WeChat, etc.). Used to inform users about operation results, errors, or important updates."
|
||||
description: str = "Send notification message to the user through configured notification channels (Telegram, Slack, WeChat, etc.). Supports optional image_url on channels that can send images. Used to inform users about operation results, errors, important updates, or proactively send a relevant image."
|
||||
args_schema: Type[BaseModel] = SendMessageInput
|
||||
require_admin: bool = True
|
||||
|
||||
def get_tool_message(self, **kwargs) -> Optional[str]:
|
||||
"""根据消息参数生成友好的提示消息"""
|
||||
message = kwargs.get("message", "")
|
||||
message = kwargs.get("message", "") or ""
|
||||
title = kwargs.get("message_type") or ""
|
||||
image_url = kwargs.get("image_url")
|
||||
|
||||
# 截断过长的消息
|
||||
if len(message) > 50:
|
||||
message = message[:50] + "..."
|
||||
|
||||
if title and image_url:
|
||||
return f"正在发送图文消息: [{title}] {message}"
|
||||
if title:
|
||||
return f"正在发送消息: [{title}] {message}"
|
||||
if image_url:
|
||||
return f"正在发送图片消息: {message}"
|
||||
return f"正在发送消息: {message}"
|
||||
|
||||
async def run(
|
||||
self, message: str, message_type: Optional[str] = None, **kwargs
|
||||
self,
|
||||
message: Optional[str] = None,
|
||||
message_type: Optional[str] = None,
|
||||
image_url: Optional[str] = None,
|
||||
**kwargs,
|
||||
) -> str:
|
||||
title = message_type or ""
|
||||
logger.info(f"执行工具: {self.name}, 参数: title={title}, message={message}")
|
||||
title = message_type or ("图片" if image_url and not message else "")
|
||||
text = message or ""
|
||||
logger.info(
|
||||
f"执行工具: {self.name}, 参数: title={title}, message={text}, image_url={image_url}"
|
||||
)
|
||||
try:
|
||||
await self.send_tool_message(message, title=title)
|
||||
await self.send_tool_message(text, title=title, image=image_url)
|
||||
return "消息已发送"
|
||||
except Exception as e:
|
||||
logger.error(f"发送消息失败: {e}")
|
||||
|
||||
@@ -126,15 +126,15 @@ class MessageChain(ChainBase):
|
||||
logger.debug(f"未识别到用户ID:{body}{form}{args}")
|
||||
return
|
||||
# 消息内容
|
||||
text = str(info.text).strip() if info.text else None
|
||||
if not text:
|
||||
text = str(info.text).strip() if info.text else ""
|
||||
images = info.images
|
||||
if not text and not images:
|
||||
logger.debug(f"未识别到消息内容::{body}{form}{args}")
|
||||
return
|
||||
|
||||
# 获取原消息ID信息
|
||||
original_message_id = info.message_id
|
||||
original_chat_id = info.chat_id
|
||||
images = info.images
|
||||
|
||||
# 处理消息
|
||||
self.handle_message(
|
||||
@@ -221,6 +221,16 @@ class MessageChain(ChainBase):
|
||||
username=username,
|
||||
images=images,
|
||||
)
|
||||
elif settings.AI_AGENT_ENABLE and images:
|
||||
# 带图消息优先交给智能体处理,避免图片在传统消息链路中丢失
|
||||
self._handle_ai_message(
|
||||
text=text,
|
||||
channel=channel,
|
||||
source=source,
|
||||
userid=userid,
|
||||
username=username,
|
||||
images=images,
|
||||
)
|
||||
elif settings.AI_AGENT_ENABLE and settings.AI_AGENT_GLOBAL:
|
||||
# 普通消息,全局智能体响应
|
||||
self._handle_ai_message(
|
||||
@@ -1234,8 +1244,20 @@ class MessageChain(ChainBase):
|
||||
session_id = self._get_or_create_session_id(userid)
|
||||
|
||||
# 下载图片并转为base64
|
||||
original_images = images
|
||||
if images:
|
||||
images = self._download_images_to_base64(images, channel, source)
|
||||
if original_images and not images and not user_message:
|
||||
self.post_message(
|
||||
Notification(
|
||||
channel=channel,
|
||||
source=source,
|
||||
userid=userid,
|
||||
username=username,
|
||||
title="图片读取失败,请稍后重试",
|
||||
)
|
||||
)
|
||||
return
|
||||
|
||||
# 在事件循环中处理
|
||||
asyncio.run_coroutine_threadsafe(
|
||||
@@ -1275,6 +1297,12 @@ class MessageChain(ChainBase):
|
||||
)
|
||||
if base64_data:
|
||||
base64_images.append(f"data:image/jpeg;base64,{base64_data}")
|
||||
elif channel == MessageChannel.Slack:
|
||||
data_url = self.run_module(
|
||||
"download_file_to_data_url", file_url=img, source=source
|
||||
)
|
||||
if data_url:
|
||||
base64_images.append(data_url)
|
||||
elif img.startswith("http"):
|
||||
resp = RequestUtils(timeout=30).get_res(img)
|
||||
if resp and resp.content:
|
||||
|
||||
@@ -279,12 +279,40 @@ class SlackModule(_ModuleBase, _MessageBase[Slack]):
|
||||
return None
|
||||
images = []
|
||||
for file in files:
|
||||
if file.get("type") in ("image", "jpg", "jpeg", "png", "gif", "webp"):
|
||||
file_type = str(file.get("type", "")).lower()
|
||||
file_ext = str(file.get("filetype", "")).lower()
|
||||
mime_type = str(file.get("mimetype", "")).lower()
|
||||
if (
|
||||
file_type == "image"
|
||||
or file_ext in ("jpg", "jpeg", "png", "gif", "webp", "bmp")
|
||||
or mime_type.startswith("image/")
|
||||
):
|
||||
url = file.get("url_private") or file.get("url_private_download")
|
||||
if url:
|
||||
images.append(url)
|
||||
return images if images else None
|
||||
|
||||
def download_file_to_data_url(self, file_url: str, source: str) -> Optional[str]:
|
||||
"""
|
||||
下载Slack文件并转为data URL
|
||||
:param file_url: Slack私有文件URL
|
||||
:param source: 来源名称
|
||||
:return: data URL
|
||||
"""
|
||||
config = self.get_config(source)
|
||||
if not config:
|
||||
return None
|
||||
client = self.get_instance(config.name)
|
||||
if not client:
|
||||
return None
|
||||
file_data = client.download_file(file_url)
|
||||
if file_data:
|
||||
import base64
|
||||
|
||||
content, mime_type = file_data
|
||||
return f"data:{mime_type};base64,{base64.b64encode(content).decode()}"
|
||||
return None
|
||||
|
||||
def post_message(self, message: Notification, **kwargs) -> None:
|
||||
"""
|
||||
发送消息
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
import re
|
||||
from threading import Lock
|
||||
from typing import List, Optional
|
||||
from typing import List, Optional, Tuple
|
||||
from urllib.parse import quote
|
||||
|
||||
import requests
|
||||
@@ -12,6 +12,7 @@ from app.core.config import settings
|
||||
from app.core.context import MediaInfo, Context
|
||||
from app.core.metainfo import MetaInfo
|
||||
from app.log import logger
|
||||
from app.utils.http import RequestUtils
|
||||
from app.utils.string import StringUtils
|
||||
|
||||
lock = Lock()
|
||||
@@ -22,6 +23,7 @@ class Slack:
|
||||
_service: SocketModeHandler = None
|
||||
_ds_url = f"http://127.0.0.1:{settings.PORT}/api/v1/message?token={settings.API_TOKEN}"
|
||||
_channel = ""
|
||||
_oauth_token = ""
|
||||
|
||||
def __init__(self, SLACK_OAUTH_TOKEN: Optional[str] = None, SLACK_APP_TOKEN: Optional[str] = None,
|
||||
SLACK_CHANNEL: Optional[str] = None, **kwargs):
|
||||
@@ -40,6 +42,7 @@ class Slack:
|
||||
|
||||
self._client = slack_app.client
|
||||
self._channel = SLACK_CHANNEL
|
||||
self._oauth_token = SLACK_OAUTH_TOKEN
|
||||
|
||||
# 标记消息来源
|
||||
if kwargs.get("name"):
|
||||
@@ -102,6 +105,28 @@ class Slack:
|
||||
"""
|
||||
return True if self._client else False
|
||||
|
||||
def download_file(self, file_url: str) -> Optional[Tuple[bytes, str]]:
|
||||
"""
|
||||
下载Slack私有文件
|
||||
:param file_url: Slack文件URL
|
||||
:return: (文件内容, MIME类型)
|
||||
"""
|
||||
if not self._client or not self._oauth_token or not file_url:
|
||||
return None
|
||||
try:
|
||||
headers = {
|
||||
"Authorization": f"Bearer {self._oauth_token}",
|
||||
"User-Agent": settings.USER_AGENT,
|
||||
"Accept": "*/*",
|
||||
}
|
||||
resp = RequestUtils(headers=headers, timeout=30).get_res(file_url)
|
||||
if resp and resp.content:
|
||||
mime_type = resp.headers.get("Content-Type", "image/jpeg")
|
||||
return resp.content, mime_type.split(";")[0]
|
||||
except Exception as e:
|
||||
logger.error(f"下载Slack文件失败: {e}")
|
||||
return None
|
||||
|
||||
def send_msg(self, title: str, text: Optional[str] = None,
|
||||
image: Optional[str] = None, link: Optional[str] = None,
|
||||
userid: Optional[str] = None, buttons: Optional[List[List[dict]]] = None,
|
||||
|
||||
@@ -267,14 +267,14 @@ class TelegramModule(_ModuleBase, _MessageBase[Telegram]):
|
||||
largest_photo = photo[-1]
|
||||
file_id = largest_photo.get("file_id")
|
||||
if file_id:
|
||||
images.append(file_id)
|
||||
images.append(f"tg://file_id/{file_id}")
|
||||
|
||||
document = msg.get("document")
|
||||
if document:
|
||||
file_id = document.get("file_id")
|
||||
mime_type = document.get("mime_type", "")
|
||||
if file_id and mime_type.startswith("image/"):
|
||||
images.append(file_id)
|
||||
images.append(f"tg://file_id/{file_id}")
|
||||
|
||||
return images if images else None
|
||||
|
||||
|
||||
119
tests/test_agent_image_support.py
Normal file
119
tests/test_agent_image_support.py
Normal file
@@ -0,0 +1,119 @@
|
||||
import base64
|
||||
import unittest
|
||||
from types import SimpleNamespace
|
||||
from unittest.mock import Mock, patch
|
||||
|
||||
from app.agent.tools.impl.send_message import SendMessageInput
|
||||
from app.chain.message import MessageChain
|
||||
from app.core.config import settings
|
||||
from app.modules.slack import SlackModule
|
||||
from app.modules.telegram import TelegramModule
|
||||
from app.schemas import CommingMessage
|
||||
from app.schemas.types import MessageChannel
|
||||
|
||||
|
||||
class AgentImageSupportTest(unittest.TestCase):
|
||||
def test_telegram_extract_images_returns_prefixed_file_ids(self):
|
||||
images = TelegramModule._extract_images(
|
||||
{
|
||||
"photo": [{"file_id": "small"}, {"file_id": "large"}],
|
||||
"document": {"file_id": "doc-image", "mime_type": "image/png"},
|
||||
}
|
||||
)
|
||||
|
||||
self.assertEqual(
|
||||
images,
|
||||
["tg://file_id/large", "tg://file_id/doc-image"],
|
||||
)
|
||||
|
||||
def test_process_allows_image_only_message(self):
|
||||
chain = MessageChain()
|
||||
message = CommingMessage(
|
||||
channel=MessageChannel.Telegram,
|
||||
source="telegram-test",
|
||||
userid="10001",
|
||||
username="tester",
|
||||
images=["tg://file_id/image-1"],
|
||||
)
|
||||
|
||||
with patch.object(chain, "message_parser", return_value=message), patch.object(
|
||||
chain, "handle_message"
|
||||
) as handle_message:
|
||||
chain.process(body="{}", form={}, args={"source": "telegram-test"})
|
||||
|
||||
handle_kwargs = handle_message.call_args.kwargs
|
||||
self.assertEqual(handle_kwargs["text"], "")
|
||||
self.assertEqual(handle_kwargs["images"], ["tg://file_id/image-1"])
|
||||
|
||||
def test_image_message_routes_to_agent_even_when_global_agent_is_disabled(self):
|
||||
chain = MessageChain()
|
||||
|
||||
with patch.object(chain, "load_cache", return_value={}), patch.object(
|
||||
chain.messagehelper, "put"
|
||||
), patch.object(chain.messageoper, "add"), patch.object(
|
||||
chain, "_handle_ai_message"
|
||||
) as handle_ai_message, patch.object(
|
||||
settings, "AI_AGENT_ENABLE", True
|
||||
), patch.object(
|
||||
settings, "AI_AGENT_GLOBAL", False
|
||||
):
|
||||
chain.handle_message(
|
||||
channel=MessageChannel.Telegram,
|
||||
source="telegram-test",
|
||||
userid="10001",
|
||||
username="tester",
|
||||
text="",
|
||||
images=["tg://file_id/image-1"],
|
||||
)
|
||||
|
||||
handle_ai_message.assert_called_once()
|
||||
|
||||
def test_slack_images_use_authenticated_data_url_download(self):
|
||||
chain = MessageChain()
|
||||
|
||||
with patch.object(
|
||||
chain,
|
||||
"run_module",
|
||||
return_value="data:image/png;base64,abc123",
|
||||
) as run_module:
|
||||
images = chain._download_images_to_base64(
|
||||
images=["https://files.slack.com/files-pri/T1-F1/test.png"],
|
||||
channel=MessageChannel.Slack,
|
||||
source="slack-test",
|
||||
)
|
||||
|
||||
self.assertEqual(images, ["data:image/png;base64,abc123"])
|
||||
run_module.assert_called_once_with(
|
||||
"download_file_to_data_url",
|
||||
file_url="https://files.slack.com/files-pri/T1-F1/test.png",
|
||||
source="slack-test",
|
||||
)
|
||||
|
||||
def test_slack_module_download_file_to_data_url(self):
|
||||
module = SlackModule()
|
||||
client = Mock()
|
||||
client.download_file.return_value = (b"png-binary", "image/png")
|
||||
|
||||
with patch.object(
|
||||
module, "get_config", return_value=SimpleNamespace(name="slack-test")
|
||||
), patch.object(module, "get_instance", return_value=client):
|
||||
data_url = module.download_file_to_data_url(
|
||||
"https://files.slack.com/files-pri/T1-F1/test.png",
|
||||
"slack-test",
|
||||
)
|
||||
|
||||
self.assertEqual(
|
||||
data_url,
|
||||
f"data:image/png;base64,{base64.b64encode(b'png-binary').decode()}",
|
||||
)
|
||||
|
||||
def test_send_message_input_accepts_image_only_payload(self):
|
||||
payload = SendMessageInput(
|
||||
explanation="send poster image",
|
||||
image_url="https://example.com/poster.png",
|
||||
)
|
||||
|
||||
self.assertEqual(payload.image_url, "https://example.com/poster.png")
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
Reference in New Issue
Block a user