支持 Slack 和 Discord 自动注册命令

This commit is contained in:
jxxghp
2026-06-15 08:03:29 +08:00
parent 0f42a0fb8c
commit c87b856ddf
6 changed files with 665 additions and 9 deletions

View File

@@ -1,13 +1,22 @@
import copy
import json
from typing import Any, Dict, List, Optional, Tuple, Union
from urllib.parse import quote, unquote
from typing import Optional, Union, List, Tuple, Any
from app.core.context import MediaInfo, Context
from app.core.event import eventmanager
from app.log import logger
from app.modules import _ModuleBase, _MessageBase
from app.schemas import MessageChannel, CommingMessage, Notification, MessageResponse
from app.schemas.types import ModuleType
from app.schemas import (
CommandRegisterEventData,
CommingMessage,
MessageChannel,
MessageResponse,
Notification,
)
from app.schemas.types import ChainEventType, ModuleType
from app.utils.http import RequestUtils
from app.utils.structures import DictUtils
try:
from app.modules.discord.discord import Discord
@@ -530,6 +539,54 @@ class DiscordModule(_ModuleBase, _MessageBase[Discord]):
return True
return False
def register_commands(self, commands: Dict[str, dict]) -> None:
"""
注册命令,实现这个函数接收系统可用的命令菜单。
:param commands: 命令字典
"""
for client_config in self.get_configs().values():
client = self.get_instance(client_config.name)
if not client:
continue
scoped_commands = copy.deepcopy(commands)
event = eventmanager.send_event(
ChainEventType.CommandRegister,
CommandRegisterEventData(
commands=scoped_commands,
origin="Discord",
service=client_config.name,
),
)
if event and event.event_data:
event_data: CommandRegisterEventData = event.event_data
if event_data.cancel:
client.delete_commands()
logger.debug(
f"Command registration for {client_config.name} canceled by event: {event_data.source}"
)
continue
scoped_commands = event_data.commands or {}
if not scoped_commands:
logger.debug("Filtered commands are empty, skipping registration.")
client.delete_commands()
filtered_scoped_commands = DictUtils.filter_keys_to_subset(
scoped_commands,
commands,
)
if not filtered_scoped_commands:
logger.debug("Filtered commands are empty, skipping registration.")
client.delete_commands()
continue
if filtered_scoped_commands != commands:
logger.debug(
f"Command set has changed, Updating new commands: {filtered_scoped_commands}"
)
client.register_commands(filtered_scoped_commands)
def mark_message_processing_started(
self,
channel: MessageChannel,

View File

@@ -31,6 +31,8 @@ class Discord:
Discord Bot 通知与交互实现(基于 discord.py 2.6.4
"""
_MAX_SLASH_COMMANDS = 100
def __init__(
self,
DISCORD_BOT_TOKEN: Optional[str] = None,
@@ -69,7 +71,7 @@ class Discord:
self._client: Optional[discord.Client] = discord.Client(
intents=intents, proxy=settings.PROXY_HOST
)
self._tree: Optional[app_commands.CommandTree] = None
self._tree: Optional[app_commands.CommandTree] = app_commands.CommandTree(self._client)
self._loop: asyncio.AbstractEventLoop = asyncio.new_event_loop()
self._thread: Optional[threading.Thread] = None
self._ready_event = threading.Event()
@@ -84,6 +86,7 @@ class Discord:
self._typing_interval_seconds = 5
self._typing_initial_delay_seconds = 1
self._typing_max_duration_seconds = 10 * 60
self._registered_commands: Optional[Dict[str, dict]] = None
self._register_events()
self._start()
@@ -101,6 +104,11 @@ class Discord:
self._bot_user_id = self._client.user.id if self._client.user else None
self._ready_event.set()
logger.info(f"Discord Bot 已登录:{self._client.user}")
if self._registered_commands is not None:
try:
await self._sync_registered_commands()
except Exception as err:
logger.error(f"同步 Discord 斜杠命令失败:{err}")
@self._client.event
async def on_message(message: discord.Message):
@@ -232,6 +240,169 @@ class Discord:
def get_state(self) -> bool:
return self._ready_event.is_set() and self._client is not None
def register_commands(self, commands: Dict[str, dict]) -> bool:
"""
注册 Discord 斜杠命令。
:param commands: 命令字典,键为斜杠命令,值包含描述和分类等元数据
:return: 是否成功提交同步任务
"""
self._registered_commands = dict(commands or {})
return self._schedule_command_sync()
def delete_commands(self) -> bool:
"""
清理 Discord 斜杠命令。
:return: 是否成功提交同步任务
"""
self._registered_commands = {}
return self._schedule_command_sync()
def _schedule_command_sync(self) -> bool:
"""在 Discord 事件循环中提交命令同步任务。"""
if not self._tree or not self._loop:
return False
if not self.get_state():
logger.debug("Discord Bot 未就绪,斜杠命令将在登录后同步")
return True
try:
future = asyncio.run_coroutine_threadsafe(
self._sync_registered_commands(), self._loop
)
return bool(future.result(timeout=30))
except Exception as err:
logger.error(f"同步 Discord 斜杠命令失败:{err}")
return False
async def _sync_registered_commands(self) -> bool:
"""将当前命令集合同步到 Discord 应用命令树。"""
if not self._tree or not self._client:
return False
if not self._client.is_ready():
await self._client.wait_until_ready()
guild = discord.Object(id=self._guild_id) if self._guild_id else None
self._tree.clear_commands(guild=guild)
commands = self._registered_commands or {}
registered_count = 0
seen_names = set()
for command_text, command_data in commands.items():
if registered_count >= self._MAX_SLASH_COMMANDS:
logger.warning(
f"Discord 斜杠命令数量超过 {self._MAX_SLASH_COMMANDS} 个,后续命令已跳过"
)
break
command_name = self._normalize_slash_command_name(command_text)
if not command_name or command_name in seen_names:
logger.warning(f"跳过无效或重复的 Discord 斜杠命令:{command_text}")
continue
seen_names.add(command_name)
description = self._normalize_slash_command_description(
command_data.get("description") if isinstance(command_data, dict) else None,
command_name,
)
self._tree.add_command(
self._build_slash_command(command_text, command_name, description),
guild=guild,
override=True,
)
registered_count += 1
synced_commands = await self._tree.sync(guild=guild)
logger.info(f"Discord 斜杠命令已同步:{len(synced_commands)}")
return True
@staticmethod
def _normalize_slash_command_name(command_text: str) -> str:
"""转换为 Discord 允许的斜杠命令名称。"""
command_name = str(command_text or "").strip().lstrip("/").lower()
if not re.fullmatch(r"[a-z0-9_-]{1,32}", command_name):
return ""
return command_name
@staticmethod
def _normalize_slash_command_description(
description: Optional[str],
fallback: str,
) -> str:
"""整理 Discord 斜杠命令描述,满足长度要求。"""
normalized = str(description or fallback or "MoviePilot").strip()
return normalized[:100] or "MoviePilot"
def _build_slash_command(
self,
command_text: str,
command_name: str,
description: str,
) -> app_commands.Command:
"""构建 Discord 斜杠命令对象。"""
async def _callback(
interaction: discord.Interaction,
args: Optional[str] = None,
) -> None:
await self._handle_slash_command(interaction, command_text, args)
_callback.__name__ = f"moviepilot_{command_name}"
_callback = app_commands.describe(args="命令参数")(_callback)
return app_commands.Command(
name=command_name,
description=description,
callback=_callback,
)
async def _handle_slash_command(
self,
interaction: discord.Interaction,
command_text: str,
args: Optional[str] = None,
) -> None:
"""处理 Discord 斜杠命令回调,并转发到统一消息入口。"""
try:
await interaction.response.defer(ephemeral=True, thinking=True)
except Exception as err:
logger.debug(f"延迟响应 Discord 斜杠命令失败:{err}")
userid = str(interaction.user.id) if interaction.user else None
chat_id = str(interaction.channel.id) if interaction.channel else None
username = None
if interaction.user:
username = (
getattr(interaction.user, "display_name", None)
or getattr(interaction.user, "global_name", None)
or getattr(interaction.user, "name", None)
)
if userid and chat_id:
self._update_user_chat_mapping(userid, chat_id)
arg_text = str(args or "").strip()
payload = {
"type": "message",
"userid": userid,
"username": username,
"user_tag": str(interaction.user) if interaction.user else None,
"text": f"{command_text} {arg_text}".strip(),
"message_id": str(interaction.id),
"chat_id": chat_id,
"channel_type": "dm"
if isinstance(interaction.channel, discord.DMChannel)
else "guild",
}
await self._post_to_ds(payload)
try:
if interaction.response.is_done():
await interaction.followup.send("命令已提交,请稍等...", ephemeral=True)
else:
await interaction.response.send_message(
"命令已提交,请稍等...",
ephemeral=True,
)
except Exception as err:
logger.debug(f"发送 Discord 斜杠命令确认失败:{err}")
def send_msg(
self,
title: str,

View File

@@ -1,14 +1,23 @@
import copy
import json
import re
from typing import Any, Dict, List, Optional, Tuple, Union
from urllib.parse import quote, unquote
from typing import Optional, Union, List, Tuple, Any
from app.core.context import MediaInfo, Context
from app.core.event import eventmanager
from app.log import logger
from app.modules import _ModuleBase, _MessageBase
from app.modules.slack.slack import Slack
from app.schemas import MessageChannel, CommingMessage, Notification, MessageResponse
from app.schemas.types import ModuleType
from app.schemas import (
CommandRegisterEventData,
CommingMessage,
MessageChannel,
MessageResponse,
Notification,
)
from app.schemas.types import ChainEventType, ModuleType
from app.utils.structures import DictUtils
class SlackModule(_ModuleBase, _MessageBase[Slack]):
@@ -661,6 +670,54 @@ class SlackModule(_ModuleBase, _MessageBase[Slack]):
return True
return False
def register_commands(self, commands: Dict[str, dict]) -> None:
"""
注册命令,实现这个函数接收系统可用的命令菜单。
:param commands: 命令字典
"""
for client_config in self.get_configs().values():
client = self.get_instance(client_config.name)
if not client:
continue
scoped_commands = copy.deepcopy(commands)
event = eventmanager.send_event(
ChainEventType.CommandRegister,
CommandRegisterEventData(
commands=scoped_commands,
origin="Slack",
service=client_config.name,
),
)
if event and event.event_data:
event_data: CommandRegisterEventData = event.event_data
if event_data.cancel:
client.delete_commands()
logger.debug(
f"Command registration for {client_config.name} canceled by event: {event_data.source}"
)
continue
scoped_commands = event_data.commands or {}
if not scoped_commands:
logger.debug("Filtered commands are empty, skipping registration.")
client.delete_commands()
filtered_scoped_commands = DictUtils.filter_keys_to_subset(
scoped_commands,
commands,
)
if not filtered_scoped_commands:
logger.debug("Filtered commands are empty, skipping registration.")
client.delete_commands()
continue
if filtered_scoped_commands != commands:
logger.debug(
f"Command set has changed, Updating new commands: {filtered_scoped_commands}"
)
client.register_commands(filtered_scoped_commands)
def mark_message_processing_started(
self,
channel: MessageChannel,

View File

@@ -1,7 +1,8 @@
import json
import re
from threading import Lock
from pathlib import Path
from typing import List, Optional, Tuple
from typing import Any, Dict, List, Optional, Tuple
from urllib.parse import quote
import requests
@@ -20,14 +21,32 @@ lock = Lock()
class Slack:
"""Slack 通知与交互客户端。"""
_client: WebClient = None
_service: SocketModeHandler = None
_ds_url = f"http://127.0.0.1:{settings.PORT}/api/v1/message?token={settings.API_TOKEN}"
_channel = ""
_oauth_token = ""
_MAX_SLASH_COMMANDS = 50
_SLASH_COMMAND_USAGE_HINT = "MoviePilot 可选参数"
def __init__(self, SLACK_OAUTH_TOKEN: Optional[str] = None, SLACK_APP_TOKEN: Optional[str] = None,
SLACK_CHANNEL: Optional[str] = None, **kwargs):
SLACK_CHANNEL: Optional[str] = None,
SLACK_APP_ID: Optional[str] = None,
SLACK_APP_CONFIG_TOKEN: Optional[str] = None,
SLACK_COMMAND_REQUEST_URL: Optional[str] = None,
**kwargs):
"""
初始化 Slack 客户端。
:param SLACK_OAUTH_TOKEN: Slack Bot User OAuth Token
:param SLACK_APP_TOKEN: Slack Socket Mode App Token
:param SLACK_CHANNEL: 默认发送频道
:param SLACK_APP_ID: Slack App ID用于可选的 Manifest 命令自动注册
:param SLACK_APP_CONFIG_TOKEN: Slack App Configuration Token用于可选的 Manifest 命令自动注册
:param SLACK_COMMAND_REQUEST_URL: Slash Command 请求 URLSocket Mode 下可为空
"""
if not SLACK_OAUTH_TOKEN or not SLACK_APP_TOKEN:
logger.error("Slack 配置不完整!")
@@ -44,6 +63,14 @@ class Slack:
self._client = slack_app.client
self._channel = SLACK_CHANNEL
self._oauth_token = SLACK_OAUTH_TOKEN
self._app_id = (SLACK_APP_ID or "").strip()
self._command_request_url = (SLACK_COMMAND_REQUEST_URL or "").strip()
self._manifest_client = (
WebClient(token=SLACK_APP_CONFIG_TOKEN)
if SLACK_APP_CONFIG_TOKEN and self._app_id
else None
)
self._registered_command_names: set[str] = set()
# 标记消息来源
if kwargs.get("name"):
@@ -106,6 +133,127 @@ class Slack:
"""
return True if self._client else False
def register_commands(self, commands: Dict[str, dict]) -> bool:
"""
通过 Slack App Manifest 注册 Slash Commands。
:param commands: 命令字典,键为斜杠命令,值包含描述和分类等元数据
:return: 注册是否成功
"""
if not self._manifest_client or not self._app_id:
logger.debug("Slack 未配置 SLACK_APP_ID/SLACK_APP_CONFIG_TOKEN跳过命令自动注册")
return False
return self._update_manifest_commands(commands or {})
def delete_commands(self) -> bool:
"""
清理本实例自动注册过的 Slack Slash Commands。
:return: 清理是否成功
"""
if not self._manifest_client or not self._app_id:
logger.debug("Slack 未配置 SLACK_APP_ID/SLACK_APP_CONFIG_TOKEN跳过命令清理")
return False
return self._update_manifest_commands({})
def _update_manifest_commands(self, commands: Dict[str, dict]) -> bool:
"""更新 Slack Manifest 中的 Slash Commands保留非本实例管理的命令。"""
try:
manifest = self._export_manifest()
if not manifest:
return False
features = manifest.setdefault("features", {})
existing_commands = features.get("slash_commands") or []
generated_commands = self._build_slash_commands(commands)
managed_names = self._registered_command_names | {
item["command"] for item in generated_commands
}
preserved_commands = [
item
for item in existing_commands
if (
isinstance(item, dict)
and item.get("command") not in managed_names
and item.get("usage_hint") != self._SLASH_COMMAND_USAGE_HINT
)
]
available = max(self._MAX_SLASH_COMMANDS - len(preserved_commands), 0)
if len(generated_commands) > available:
logger.warning(
f"Slack Slash Commands 超过平台上限,仅注册前 {available}"
)
generated_commands = generated_commands[:available]
features["slash_commands"] = preserved_commands + generated_commands
result = self._manifest_client.apps_manifest_update(
app_id=self._app_id,
manifest=manifest,
)
if result and result.get("ok") is False:
logger.error(f"Slack Manifest 更新失败:{result.get('error')}")
return False
self._registered_command_names = {
item["command"] for item in generated_commands
}
logger.info(f"Slack Slash Commands 已同步:{len(generated_commands)}")
return True
except Exception as err:
logger.error(f"Slack Slash Commands 自动注册失败:{err}")
return False
def _export_manifest(self) -> Optional[Dict[str, Any]]:
"""导出 Slack App Manifest。"""
result = self._manifest_client.apps_manifest_export(app_id=self._app_id)
if result and result.get("ok") is False:
logger.error(f"Slack Manifest 导出失败:{result.get('error')}")
return None
manifest = result.get("manifest") if result else None
if isinstance(manifest, str):
manifest = json.loads(manifest)
return manifest if isinstance(manifest, dict) else None
def _build_slash_commands(self, commands: Dict[str, dict]) -> List[Dict[str, Any]]:
"""构建 Slack Manifest Slash Commands 配置。"""
slash_commands = []
seen_commands = set()
for command_text, command_data in commands.items():
command = self._normalize_slack_command(command_text)
if not command or command in seen_commands:
logger.warning(f"跳过无效或重复的 Slack Slash Command{command_text}")
continue
seen_commands.add(command)
description = self._normalize_slack_description(
command_data.get("description") if isinstance(command_data, dict) else None,
command,
)
item = {
"command": command,
"description": description,
"should_escape": False,
"usage_hint": self._SLASH_COMMAND_USAGE_HINT,
}
if self._command_request_url:
item["url"] = self._command_request_url
slash_commands.append(item)
return slash_commands
@staticmethod
def _normalize_slack_command(command_text: str) -> str:
"""转换为 Slack Slash Command 名称。"""
command = f"/{str(command_text or '').strip().lstrip('/').lower()}"
if not re.fullmatch(r"/[a-z0-9_-]{1,31}", command):
return ""
return command
@staticmethod
def _normalize_slack_description(
description: Optional[str],
fallback: str,
) -> str:
"""整理 Slack Slash Command 描述。"""
normalized = str(description or fallback or "MoviePilot").strip()
return normalized[:2000] or "MoviePilot"
def download_file(self, file_url: str) -> Optional[Tuple[bytes, str]]:
"""
下载Slack私有文件

View File

@@ -0,0 +1,118 @@
import asyncio
from types import SimpleNamespace
from unittest.mock import AsyncMock, Mock, patch
from app.modules.discord import DiscordModule
from app.modules.discord.discord import Discord
from app.schemas import CommandRegisterEventData
def test_discord_module_register_commands_filters_event_subset():
"""Discord 模块注册命令时应复用渠道级 CommandRegister 事件过滤结果。"""
module = DiscordModule()
client = SimpleNamespace(register_commands=Mock(), delete_commands=Mock())
original_commands = {
"/sites": {"description": "管理站点"},
"/version": {"description": "当前版本"},
}
event = SimpleNamespace(
event_data=CommandRegisterEventData(
commands={"/sites": {"description": "管理站点"}, "/unknown": {"description": "无效"}},
origin="DiscordFilter",
service="discord-main",
)
)
with (
patch.object(
module,
"get_configs",
return_value={"discord-main": SimpleNamespace(name="discord-main", config={})},
),
patch.object(module, "get_instance", return_value=client),
patch("app.modules.discord.eventmanager.send_event", return_value=event),
):
module.register_commands(original_commands)
client.register_commands.assert_called_once_with(
{"/sites": {"description": "管理站点"}}
)
client.delete_commands.assert_not_called()
def test_discord_module_register_commands_deletes_when_event_canceled():
"""Discord 模块注册命令被事件取消时应清理应用命令。"""
module = DiscordModule()
client = SimpleNamespace(register_commands=Mock(), delete_commands=Mock())
event = SimpleNamespace(
event_data=CommandRegisterEventData(
commands={"/sites": {"description": "管理站点"}},
origin="DiscordFilter",
service="discord-main",
cancel=True,
)
)
with (
patch.object(
module,
"get_configs",
return_value={"discord-main": SimpleNamespace(name="discord-main", config={})},
),
patch.object(module, "get_instance", return_value=client),
patch("app.modules.discord.eventmanager.send_event", return_value=event),
):
module.register_commands({"/sites": {"description": "管理站点"}})
client.delete_commands.assert_called_once_with()
client.register_commands.assert_not_called()
def test_discord_normalizes_slash_command_names():
"""Discord 命令名称应符合平台只允许小写字母数字下划线连字符的约束。"""
assert Discord._normalize_slash_command_name("/sites") == "sites"
assert Discord._normalize_slash_command_name("/clear_cache") == "clear_cache"
assert Discord._normalize_slash_command_name("/INVALID") == "invalid"
assert Discord._normalize_slash_command_name("/中文") == ""
assert Discord._normalize_slash_command_name("/" + "a" * 33) == ""
def test_discord_handle_slash_command_forwards_to_message_chain():
"""Discord 斜杠命令回调应转发为统一消息入口可识别的命令文本。"""
client = Discord.__new__(Discord)
client._update_user_chat_mapping = Mock()
client._post_to_ds = AsyncMock()
user = SimpleNamespace(id=10001, display_name="tester", global_name=None, name="tester")
channel = SimpleNamespace(id=20001)
interaction = SimpleNamespace(
id=30001,
user=user,
channel=channel,
response=SimpleNamespace(
defer=AsyncMock(),
is_done=Mock(return_value=True),
send_message=AsyncMock(),
),
followup=SimpleNamespace(send=AsyncMock()),
)
asyncio.run(client._handle_slash_command(interaction, "/sites", "refresh"))
client._update_user_chat_mapping.assert_called_once_with("10001", "20001")
client._post_to_ds.assert_awaited_once_with(
{
"type": "message",
"userid": "10001",
"username": "tester",
"user_tag": str(user),
"text": "/sites refresh",
"message_id": "30001",
"chat_id": "20001",
"channel_type": "guild",
}
)
interaction.followup.send.assert_awaited_once_with(
"命令已提交,请稍等...",
ephemeral=True,
)

View File

@@ -0,0 +1,105 @@
from types import SimpleNamespace
from unittest.mock import Mock, patch
from app.modules.slack import SlackModule
from app.modules.slack.slack import Slack
from app.schemas import CommandRegisterEventData
def test_slack_module_register_commands_filters_event_subset():
"""Slack 模块注册命令时应复用渠道级 CommandRegister 事件过滤结果。"""
module = SlackModule()
client = SimpleNamespace(register_commands=Mock(), delete_commands=Mock())
original_commands = {
"/sites": {"description": "管理站点"},
"/version": {"description": "当前版本"},
}
event = SimpleNamespace(
event_data=CommandRegisterEventData(
commands={"/sites": {"description": "管理站点"}, "/unknown": {"description": "无效"}},
origin="SlackFilter",
service="slack-main",
)
)
with (
patch.object(
module,
"get_configs",
return_value={"slack-main": SimpleNamespace(name="slack-main", config={})},
),
patch.object(module, "get_instance", return_value=client),
patch("app.modules.slack.eventmanager.send_event", return_value=event),
):
module.register_commands(original_commands)
client.register_commands.assert_called_once_with(
{"/sites": {"description": "管理站点"}}
)
client.delete_commands.assert_not_called()
def test_slack_manifest_update_preserves_foreign_commands():
"""Slack Manifest 更新应只替换 MoviePilot 本次管理的 Slash Commands。"""
client = Slack.__new__(Slack)
manifest_client = SimpleNamespace(
apps_manifest_export=Mock(
return_value={
"ok": True,
"manifest": {
"features": {
"slash_commands": [
{"command": "/keep", "description": "保留"},
{"command": "/sites", "description": "旧站点"},
{
"command": "/old_moviepilot",
"description": "旧命令",
"usage_hint": "MoviePilot 可选参数",
},
]
}
},
}
),
apps_manifest_update=Mock(return_value={"ok": True}),
)
client._manifest_client = manifest_client
client._app_id = "A123"
client._command_request_url = "https://example.com/slack/commands"
client._registered_command_names = {"/sites"}
assert client.register_commands(
{
"/sites": {"description": "管理站点"},
"/version": {"description": "当前版本"},
}
)
manifest_client.apps_manifest_update.assert_called_once()
_, kwargs = manifest_client.apps_manifest_update.call_args
assert kwargs["app_id"] == "A123"
slash_commands = kwargs["manifest"]["features"]["slash_commands"]
assert [item["command"] for item in slash_commands] == [
"/keep",
"/sites",
"/version",
]
assert slash_commands[1]["url"] == "https://example.com/slack/commands"
assert client._registered_command_names == {"/sites", "/version"}
def test_slack_command_registration_skips_without_manifest_credentials():
"""未配置 Slack Manifest 凭据时不应尝试自动注册。"""
client = Slack.__new__(Slack)
client._manifest_client = None
client._app_id = ""
assert client.register_commands({"/sites": {"description": "管理站点"}}) is False
def test_slack_normalizes_slash_command_names():
"""Slack 命令名称应符合平台 Slash Command 约束。"""
assert Slack._normalize_slack_command("/sites") == "/sites"
assert Slack._normalize_slack_command("CLEAR_CACHE") == "/clear_cache"
assert Slack._normalize_slack_command("/中文") == ""
assert Slack._normalize_slack_command("/" + "a" * 32) == ""