mirror of
https://github.com/jxxghp/MoviePilot.git
synced 2026-06-17 13:41:57 +08:00
支持 Slack 和 Discord 自动注册命令
This commit is contained in:
@@ -1,13 +1,22 @@
|
||||
import copy
|
||||
import json
|
||||
from typing import Any, Dict, List, Optional, Tuple, Union
|
||||
from urllib.parse import quote, unquote
|
||||
from typing import Optional, Union, List, Tuple, Any
|
||||
|
||||
from app.core.context import MediaInfo, Context
|
||||
from app.core.event import eventmanager
|
||||
from app.log import logger
|
||||
from app.modules import _ModuleBase, _MessageBase
|
||||
from app.schemas import MessageChannel, CommingMessage, Notification, MessageResponse
|
||||
from app.schemas.types import ModuleType
|
||||
from app.schemas import (
|
||||
CommandRegisterEventData,
|
||||
CommingMessage,
|
||||
MessageChannel,
|
||||
MessageResponse,
|
||||
Notification,
|
||||
)
|
||||
from app.schemas.types import ChainEventType, ModuleType
|
||||
from app.utils.http import RequestUtils
|
||||
from app.utils.structures import DictUtils
|
||||
|
||||
try:
|
||||
from app.modules.discord.discord import Discord
|
||||
@@ -530,6 +539,54 @@ class DiscordModule(_ModuleBase, _MessageBase[Discord]):
|
||||
return True
|
||||
return False
|
||||
|
||||
def register_commands(self, commands: Dict[str, dict]) -> None:
|
||||
"""
|
||||
注册命令,实现这个函数接收系统可用的命令菜单。
|
||||
|
||||
:param commands: 命令字典
|
||||
"""
|
||||
for client_config in self.get_configs().values():
|
||||
client = self.get_instance(client_config.name)
|
||||
if not client:
|
||||
continue
|
||||
|
||||
scoped_commands = copy.deepcopy(commands)
|
||||
event = eventmanager.send_event(
|
||||
ChainEventType.CommandRegister,
|
||||
CommandRegisterEventData(
|
||||
commands=scoped_commands,
|
||||
origin="Discord",
|
||||
service=client_config.name,
|
||||
),
|
||||
)
|
||||
|
||||
if event and event.event_data:
|
||||
event_data: CommandRegisterEventData = event.event_data
|
||||
if event_data.cancel:
|
||||
client.delete_commands()
|
||||
logger.debug(
|
||||
f"Command registration for {client_config.name} canceled by event: {event_data.source}"
|
||||
)
|
||||
continue
|
||||
scoped_commands = event_data.commands or {}
|
||||
if not scoped_commands:
|
||||
logger.debug("Filtered commands are empty, skipping registration.")
|
||||
client.delete_commands()
|
||||
|
||||
filtered_scoped_commands = DictUtils.filter_keys_to_subset(
|
||||
scoped_commands,
|
||||
commands,
|
||||
)
|
||||
if not filtered_scoped_commands:
|
||||
logger.debug("Filtered commands are empty, skipping registration.")
|
||||
client.delete_commands()
|
||||
continue
|
||||
if filtered_scoped_commands != commands:
|
||||
logger.debug(
|
||||
f"Command set has changed, Updating new commands: {filtered_scoped_commands}"
|
||||
)
|
||||
client.register_commands(filtered_scoped_commands)
|
||||
|
||||
def mark_message_processing_started(
|
||||
self,
|
||||
channel: MessageChannel,
|
||||
|
||||
@@ -31,6 +31,8 @@ class Discord:
|
||||
Discord Bot 通知与交互实现(基于 discord.py 2.6.4)
|
||||
"""
|
||||
|
||||
_MAX_SLASH_COMMANDS = 100
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
DISCORD_BOT_TOKEN: Optional[str] = None,
|
||||
@@ -69,7 +71,7 @@ class Discord:
|
||||
self._client: Optional[discord.Client] = discord.Client(
|
||||
intents=intents, proxy=settings.PROXY_HOST
|
||||
)
|
||||
self._tree: Optional[app_commands.CommandTree] = None
|
||||
self._tree: Optional[app_commands.CommandTree] = app_commands.CommandTree(self._client)
|
||||
self._loop: asyncio.AbstractEventLoop = asyncio.new_event_loop()
|
||||
self._thread: Optional[threading.Thread] = None
|
||||
self._ready_event = threading.Event()
|
||||
@@ -84,6 +86,7 @@ class Discord:
|
||||
self._typing_interval_seconds = 5
|
||||
self._typing_initial_delay_seconds = 1
|
||||
self._typing_max_duration_seconds = 10 * 60
|
||||
self._registered_commands: Optional[Dict[str, dict]] = None
|
||||
|
||||
self._register_events()
|
||||
self._start()
|
||||
@@ -101,6 +104,11 @@ class Discord:
|
||||
self._bot_user_id = self._client.user.id if self._client.user else None
|
||||
self._ready_event.set()
|
||||
logger.info(f"Discord Bot 已登录:{self._client.user}")
|
||||
if self._registered_commands is not None:
|
||||
try:
|
||||
await self._sync_registered_commands()
|
||||
except Exception as err:
|
||||
logger.error(f"同步 Discord 斜杠命令失败:{err}")
|
||||
|
||||
@self._client.event
|
||||
async def on_message(message: discord.Message):
|
||||
@@ -232,6 +240,169 @@ class Discord:
|
||||
def get_state(self) -> bool:
|
||||
return self._ready_event.is_set() and self._client is not None
|
||||
|
||||
def register_commands(self, commands: Dict[str, dict]) -> bool:
|
||||
"""
|
||||
注册 Discord 斜杠命令。
|
||||
|
||||
:param commands: 命令字典,键为斜杠命令,值包含描述和分类等元数据
|
||||
:return: 是否成功提交同步任务
|
||||
"""
|
||||
self._registered_commands = dict(commands or {})
|
||||
return self._schedule_command_sync()
|
||||
|
||||
def delete_commands(self) -> bool:
|
||||
"""
|
||||
清理 Discord 斜杠命令。
|
||||
|
||||
:return: 是否成功提交同步任务
|
||||
"""
|
||||
self._registered_commands = {}
|
||||
return self._schedule_command_sync()
|
||||
|
||||
def _schedule_command_sync(self) -> bool:
|
||||
"""在 Discord 事件循环中提交命令同步任务。"""
|
||||
if not self._tree or not self._loop:
|
||||
return False
|
||||
if not self.get_state():
|
||||
logger.debug("Discord Bot 未就绪,斜杠命令将在登录后同步")
|
||||
return True
|
||||
try:
|
||||
future = asyncio.run_coroutine_threadsafe(
|
||||
self._sync_registered_commands(), self._loop
|
||||
)
|
||||
return bool(future.result(timeout=30))
|
||||
except Exception as err:
|
||||
logger.error(f"同步 Discord 斜杠命令失败:{err}")
|
||||
return False
|
||||
|
||||
async def _sync_registered_commands(self) -> bool:
|
||||
"""将当前命令集合同步到 Discord 应用命令树。"""
|
||||
if not self._tree or not self._client:
|
||||
return False
|
||||
if not self._client.is_ready():
|
||||
await self._client.wait_until_ready()
|
||||
|
||||
guild = discord.Object(id=self._guild_id) if self._guild_id else None
|
||||
self._tree.clear_commands(guild=guild)
|
||||
|
||||
commands = self._registered_commands or {}
|
||||
registered_count = 0
|
||||
seen_names = set()
|
||||
for command_text, command_data in commands.items():
|
||||
if registered_count >= self._MAX_SLASH_COMMANDS:
|
||||
logger.warning(
|
||||
f"Discord 斜杠命令数量超过 {self._MAX_SLASH_COMMANDS} 个,后续命令已跳过"
|
||||
)
|
||||
break
|
||||
command_name = self._normalize_slash_command_name(command_text)
|
||||
if not command_name or command_name in seen_names:
|
||||
logger.warning(f"跳过无效或重复的 Discord 斜杠命令:{command_text}")
|
||||
continue
|
||||
seen_names.add(command_name)
|
||||
description = self._normalize_slash_command_description(
|
||||
command_data.get("description") if isinstance(command_data, dict) else None,
|
||||
command_name,
|
||||
)
|
||||
self._tree.add_command(
|
||||
self._build_slash_command(command_text, command_name, description),
|
||||
guild=guild,
|
||||
override=True,
|
||||
)
|
||||
registered_count += 1
|
||||
|
||||
synced_commands = await self._tree.sync(guild=guild)
|
||||
logger.info(f"Discord 斜杠命令已同步:{len(synced_commands)} 个")
|
||||
return True
|
||||
|
||||
@staticmethod
|
||||
def _normalize_slash_command_name(command_text: str) -> str:
|
||||
"""转换为 Discord 允许的斜杠命令名称。"""
|
||||
command_name = str(command_text or "").strip().lstrip("/").lower()
|
||||
if not re.fullmatch(r"[a-z0-9_-]{1,32}", command_name):
|
||||
return ""
|
||||
return command_name
|
||||
|
||||
@staticmethod
|
||||
def _normalize_slash_command_description(
|
||||
description: Optional[str],
|
||||
fallback: str,
|
||||
) -> str:
|
||||
"""整理 Discord 斜杠命令描述,满足长度要求。"""
|
||||
normalized = str(description or fallback or "MoviePilot").strip()
|
||||
return normalized[:100] or "MoviePilot"
|
||||
|
||||
def _build_slash_command(
|
||||
self,
|
||||
command_text: str,
|
||||
command_name: str,
|
||||
description: str,
|
||||
) -> app_commands.Command:
|
||||
"""构建 Discord 斜杠命令对象。"""
|
||||
|
||||
async def _callback(
|
||||
interaction: discord.Interaction,
|
||||
args: Optional[str] = None,
|
||||
) -> None:
|
||||
await self._handle_slash_command(interaction, command_text, args)
|
||||
|
||||
_callback.__name__ = f"moviepilot_{command_name}"
|
||||
_callback = app_commands.describe(args="命令参数")(_callback)
|
||||
return app_commands.Command(
|
||||
name=command_name,
|
||||
description=description,
|
||||
callback=_callback,
|
||||
)
|
||||
|
||||
async def _handle_slash_command(
|
||||
self,
|
||||
interaction: discord.Interaction,
|
||||
command_text: str,
|
||||
args: Optional[str] = None,
|
||||
) -> None:
|
||||
"""处理 Discord 斜杠命令回调,并转发到统一消息入口。"""
|
||||
try:
|
||||
await interaction.response.defer(ephemeral=True, thinking=True)
|
||||
except Exception as err:
|
||||
logger.debug(f"延迟响应 Discord 斜杠命令失败:{err}")
|
||||
|
||||
userid = str(interaction.user.id) if interaction.user else None
|
||||
chat_id = str(interaction.channel.id) if interaction.channel else None
|
||||
username = None
|
||||
if interaction.user:
|
||||
username = (
|
||||
getattr(interaction.user, "display_name", None)
|
||||
or getattr(interaction.user, "global_name", None)
|
||||
or getattr(interaction.user, "name", None)
|
||||
)
|
||||
if userid and chat_id:
|
||||
self._update_user_chat_mapping(userid, chat_id)
|
||||
|
||||
arg_text = str(args or "").strip()
|
||||
payload = {
|
||||
"type": "message",
|
||||
"userid": userid,
|
||||
"username": username,
|
||||
"user_tag": str(interaction.user) if interaction.user else None,
|
||||
"text": f"{command_text} {arg_text}".strip(),
|
||||
"message_id": str(interaction.id),
|
||||
"chat_id": chat_id,
|
||||
"channel_type": "dm"
|
||||
if isinstance(interaction.channel, discord.DMChannel)
|
||||
else "guild",
|
||||
}
|
||||
await self._post_to_ds(payload)
|
||||
|
||||
try:
|
||||
if interaction.response.is_done():
|
||||
await interaction.followup.send("命令已提交,请稍等...", ephemeral=True)
|
||||
else:
|
||||
await interaction.response.send_message(
|
||||
"命令已提交,请稍等...",
|
||||
ephemeral=True,
|
||||
)
|
||||
except Exception as err:
|
||||
logger.debug(f"发送 Discord 斜杠命令确认失败:{err}")
|
||||
|
||||
def send_msg(
|
||||
self,
|
||||
title: str,
|
||||
|
||||
@@ -1,14 +1,23 @@
|
||||
import copy
|
||||
import json
|
||||
import re
|
||||
from typing import Any, Dict, List, Optional, Tuple, Union
|
||||
from urllib.parse import quote, unquote
|
||||
from typing import Optional, Union, List, Tuple, Any
|
||||
|
||||
from app.core.context import MediaInfo, Context
|
||||
from app.core.event import eventmanager
|
||||
from app.log import logger
|
||||
from app.modules import _ModuleBase, _MessageBase
|
||||
from app.modules.slack.slack import Slack
|
||||
from app.schemas import MessageChannel, CommingMessage, Notification, MessageResponse
|
||||
from app.schemas.types import ModuleType
|
||||
from app.schemas import (
|
||||
CommandRegisterEventData,
|
||||
CommingMessage,
|
||||
MessageChannel,
|
||||
MessageResponse,
|
||||
Notification,
|
||||
)
|
||||
from app.schemas.types import ChainEventType, ModuleType
|
||||
from app.utils.structures import DictUtils
|
||||
|
||||
|
||||
class SlackModule(_ModuleBase, _MessageBase[Slack]):
|
||||
@@ -661,6 +670,54 @@ class SlackModule(_ModuleBase, _MessageBase[Slack]):
|
||||
return True
|
||||
return False
|
||||
|
||||
def register_commands(self, commands: Dict[str, dict]) -> None:
|
||||
"""
|
||||
注册命令,实现这个函数接收系统可用的命令菜单。
|
||||
|
||||
:param commands: 命令字典
|
||||
"""
|
||||
for client_config in self.get_configs().values():
|
||||
client = self.get_instance(client_config.name)
|
||||
if not client:
|
||||
continue
|
||||
|
||||
scoped_commands = copy.deepcopy(commands)
|
||||
event = eventmanager.send_event(
|
||||
ChainEventType.CommandRegister,
|
||||
CommandRegisterEventData(
|
||||
commands=scoped_commands,
|
||||
origin="Slack",
|
||||
service=client_config.name,
|
||||
),
|
||||
)
|
||||
|
||||
if event and event.event_data:
|
||||
event_data: CommandRegisterEventData = event.event_data
|
||||
if event_data.cancel:
|
||||
client.delete_commands()
|
||||
logger.debug(
|
||||
f"Command registration for {client_config.name} canceled by event: {event_data.source}"
|
||||
)
|
||||
continue
|
||||
scoped_commands = event_data.commands or {}
|
||||
if not scoped_commands:
|
||||
logger.debug("Filtered commands are empty, skipping registration.")
|
||||
client.delete_commands()
|
||||
|
||||
filtered_scoped_commands = DictUtils.filter_keys_to_subset(
|
||||
scoped_commands,
|
||||
commands,
|
||||
)
|
||||
if not filtered_scoped_commands:
|
||||
logger.debug("Filtered commands are empty, skipping registration.")
|
||||
client.delete_commands()
|
||||
continue
|
||||
if filtered_scoped_commands != commands:
|
||||
logger.debug(
|
||||
f"Command set has changed, Updating new commands: {filtered_scoped_commands}"
|
||||
)
|
||||
client.register_commands(filtered_scoped_commands)
|
||||
|
||||
def mark_message_processing_started(
|
||||
self,
|
||||
channel: MessageChannel,
|
||||
|
||||
@@ -1,7 +1,8 @@
|
||||
import json
|
||||
import re
|
||||
from threading import Lock
|
||||
from pathlib import Path
|
||||
from typing import List, Optional, Tuple
|
||||
from typing import Any, Dict, List, Optional, Tuple
|
||||
from urllib.parse import quote
|
||||
|
||||
import requests
|
||||
@@ -20,14 +21,32 @@ lock = Lock()
|
||||
|
||||
|
||||
class Slack:
|
||||
"""Slack 通知与交互客户端。"""
|
||||
|
||||
_client: WebClient = None
|
||||
_service: SocketModeHandler = None
|
||||
_ds_url = f"http://127.0.0.1:{settings.PORT}/api/v1/message?token={settings.API_TOKEN}"
|
||||
_channel = ""
|
||||
_oauth_token = ""
|
||||
_MAX_SLASH_COMMANDS = 50
|
||||
_SLASH_COMMAND_USAGE_HINT = "MoviePilot 可选参数"
|
||||
|
||||
def __init__(self, SLACK_OAUTH_TOKEN: Optional[str] = None, SLACK_APP_TOKEN: Optional[str] = None,
|
||||
SLACK_CHANNEL: Optional[str] = None, **kwargs):
|
||||
SLACK_CHANNEL: Optional[str] = None,
|
||||
SLACK_APP_ID: Optional[str] = None,
|
||||
SLACK_APP_CONFIG_TOKEN: Optional[str] = None,
|
||||
SLACK_COMMAND_REQUEST_URL: Optional[str] = None,
|
||||
**kwargs):
|
||||
"""
|
||||
初始化 Slack 客户端。
|
||||
|
||||
:param SLACK_OAUTH_TOKEN: Slack Bot User OAuth Token
|
||||
:param SLACK_APP_TOKEN: Slack Socket Mode App Token
|
||||
:param SLACK_CHANNEL: 默认发送频道
|
||||
:param SLACK_APP_ID: Slack App ID,用于可选的 Manifest 命令自动注册
|
||||
:param SLACK_APP_CONFIG_TOKEN: Slack App Configuration Token,用于可选的 Manifest 命令自动注册
|
||||
:param SLACK_COMMAND_REQUEST_URL: Slash Command 请求 URL,Socket Mode 下可为空
|
||||
"""
|
||||
|
||||
if not SLACK_OAUTH_TOKEN or not SLACK_APP_TOKEN:
|
||||
logger.error("Slack 配置不完整!")
|
||||
@@ -44,6 +63,14 @@ class Slack:
|
||||
self._client = slack_app.client
|
||||
self._channel = SLACK_CHANNEL
|
||||
self._oauth_token = SLACK_OAUTH_TOKEN
|
||||
self._app_id = (SLACK_APP_ID or "").strip()
|
||||
self._command_request_url = (SLACK_COMMAND_REQUEST_URL or "").strip()
|
||||
self._manifest_client = (
|
||||
WebClient(token=SLACK_APP_CONFIG_TOKEN)
|
||||
if SLACK_APP_CONFIG_TOKEN and self._app_id
|
||||
else None
|
||||
)
|
||||
self._registered_command_names: set[str] = set()
|
||||
|
||||
# 标记消息来源
|
||||
if kwargs.get("name"):
|
||||
@@ -106,6 +133,127 @@ class Slack:
|
||||
"""
|
||||
return True if self._client else False
|
||||
|
||||
def register_commands(self, commands: Dict[str, dict]) -> bool:
|
||||
"""
|
||||
通过 Slack App Manifest 注册 Slash Commands。
|
||||
|
||||
:param commands: 命令字典,键为斜杠命令,值包含描述和分类等元数据
|
||||
:return: 注册是否成功
|
||||
"""
|
||||
if not self._manifest_client or not self._app_id:
|
||||
logger.debug("Slack 未配置 SLACK_APP_ID/SLACK_APP_CONFIG_TOKEN,跳过命令自动注册")
|
||||
return False
|
||||
return self._update_manifest_commands(commands or {})
|
||||
|
||||
def delete_commands(self) -> bool:
|
||||
"""
|
||||
清理本实例自动注册过的 Slack Slash Commands。
|
||||
|
||||
:return: 清理是否成功
|
||||
"""
|
||||
if not self._manifest_client or not self._app_id:
|
||||
logger.debug("Slack 未配置 SLACK_APP_ID/SLACK_APP_CONFIG_TOKEN,跳过命令清理")
|
||||
return False
|
||||
return self._update_manifest_commands({})
|
||||
|
||||
def _update_manifest_commands(self, commands: Dict[str, dict]) -> bool:
|
||||
"""更新 Slack Manifest 中的 Slash Commands,保留非本实例管理的命令。"""
|
||||
try:
|
||||
manifest = self._export_manifest()
|
||||
if not manifest:
|
||||
return False
|
||||
features = manifest.setdefault("features", {})
|
||||
existing_commands = features.get("slash_commands") or []
|
||||
generated_commands = self._build_slash_commands(commands)
|
||||
managed_names = self._registered_command_names | {
|
||||
item["command"] for item in generated_commands
|
||||
}
|
||||
preserved_commands = [
|
||||
item
|
||||
for item in existing_commands
|
||||
if (
|
||||
isinstance(item, dict)
|
||||
and item.get("command") not in managed_names
|
||||
and item.get("usage_hint") != self._SLASH_COMMAND_USAGE_HINT
|
||||
)
|
||||
]
|
||||
available = max(self._MAX_SLASH_COMMANDS - len(preserved_commands), 0)
|
||||
if len(generated_commands) > available:
|
||||
logger.warning(
|
||||
f"Slack Slash Commands 超过平台上限,仅注册前 {available} 个"
|
||||
)
|
||||
generated_commands = generated_commands[:available]
|
||||
features["slash_commands"] = preserved_commands + generated_commands
|
||||
|
||||
result = self._manifest_client.apps_manifest_update(
|
||||
app_id=self._app_id,
|
||||
manifest=manifest,
|
||||
)
|
||||
if result and result.get("ok") is False:
|
||||
logger.error(f"Slack Manifest 更新失败:{result.get('error')}")
|
||||
return False
|
||||
self._registered_command_names = {
|
||||
item["command"] for item in generated_commands
|
||||
}
|
||||
logger.info(f"Slack Slash Commands 已同步:{len(generated_commands)} 个")
|
||||
return True
|
||||
except Exception as err:
|
||||
logger.error(f"Slack Slash Commands 自动注册失败:{err}")
|
||||
return False
|
||||
|
||||
def _export_manifest(self) -> Optional[Dict[str, Any]]:
|
||||
"""导出 Slack App Manifest。"""
|
||||
result = self._manifest_client.apps_manifest_export(app_id=self._app_id)
|
||||
if result and result.get("ok") is False:
|
||||
logger.error(f"Slack Manifest 导出失败:{result.get('error')}")
|
||||
return None
|
||||
manifest = result.get("manifest") if result else None
|
||||
if isinstance(manifest, str):
|
||||
manifest = json.loads(manifest)
|
||||
return manifest if isinstance(manifest, dict) else None
|
||||
|
||||
def _build_slash_commands(self, commands: Dict[str, dict]) -> List[Dict[str, Any]]:
|
||||
"""构建 Slack Manifest Slash Commands 配置。"""
|
||||
slash_commands = []
|
||||
seen_commands = set()
|
||||
for command_text, command_data in commands.items():
|
||||
command = self._normalize_slack_command(command_text)
|
||||
if not command or command in seen_commands:
|
||||
logger.warning(f"跳过无效或重复的 Slack Slash Command:{command_text}")
|
||||
continue
|
||||
seen_commands.add(command)
|
||||
description = self._normalize_slack_description(
|
||||
command_data.get("description") if isinstance(command_data, dict) else None,
|
||||
command,
|
||||
)
|
||||
item = {
|
||||
"command": command,
|
||||
"description": description,
|
||||
"should_escape": False,
|
||||
"usage_hint": self._SLASH_COMMAND_USAGE_HINT,
|
||||
}
|
||||
if self._command_request_url:
|
||||
item["url"] = self._command_request_url
|
||||
slash_commands.append(item)
|
||||
return slash_commands
|
||||
|
||||
@staticmethod
|
||||
def _normalize_slack_command(command_text: str) -> str:
|
||||
"""转换为 Slack Slash Command 名称。"""
|
||||
command = f"/{str(command_text or '').strip().lstrip('/').lower()}"
|
||||
if not re.fullmatch(r"/[a-z0-9_-]{1,31}", command):
|
||||
return ""
|
||||
return command
|
||||
|
||||
@staticmethod
|
||||
def _normalize_slack_description(
|
||||
description: Optional[str],
|
||||
fallback: str,
|
||||
) -> str:
|
||||
"""整理 Slack Slash Command 描述。"""
|
||||
normalized = str(description or fallback or "MoviePilot").strip()
|
||||
return normalized[:2000] or "MoviePilot"
|
||||
|
||||
def download_file(self, file_url: str) -> Optional[Tuple[bytes, str]]:
|
||||
"""
|
||||
下载Slack私有文件
|
||||
|
||||
118
tests/test_discord_command_registration.py
Normal file
118
tests/test_discord_command_registration.py
Normal file
@@ -0,0 +1,118 @@
|
||||
import asyncio
|
||||
from types import SimpleNamespace
|
||||
from unittest.mock import AsyncMock, Mock, patch
|
||||
|
||||
from app.modules.discord import DiscordModule
|
||||
from app.modules.discord.discord import Discord
|
||||
from app.schemas import CommandRegisterEventData
|
||||
|
||||
|
||||
def test_discord_module_register_commands_filters_event_subset():
|
||||
"""Discord 模块注册命令时应复用渠道级 CommandRegister 事件过滤结果。"""
|
||||
module = DiscordModule()
|
||||
client = SimpleNamespace(register_commands=Mock(), delete_commands=Mock())
|
||||
original_commands = {
|
||||
"/sites": {"description": "管理站点"},
|
||||
"/version": {"description": "当前版本"},
|
||||
}
|
||||
event = SimpleNamespace(
|
||||
event_data=CommandRegisterEventData(
|
||||
commands={"/sites": {"description": "管理站点"}, "/unknown": {"description": "无效"}},
|
||||
origin="DiscordFilter",
|
||||
service="discord-main",
|
||||
)
|
||||
)
|
||||
|
||||
with (
|
||||
patch.object(
|
||||
module,
|
||||
"get_configs",
|
||||
return_value={"discord-main": SimpleNamespace(name="discord-main", config={})},
|
||||
),
|
||||
patch.object(module, "get_instance", return_value=client),
|
||||
patch("app.modules.discord.eventmanager.send_event", return_value=event),
|
||||
):
|
||||
module.register_commands(original_commands)
|
||||
|
||||
client.register_commands.assert_called_once_with(
|
||||
{"/sites": {"description": "管理站点"}}
|
||||
)
|
||||
client.delete_commands.assert_not_called()
|
||||
|
||||
|
||||
def test_discord_module_register_commands_deletes_when_event_canceled():
|
||||
"""Discord 模块注册命令被事件取消时应清理应用命令。"""
|
||||
module = DiscordModule()
|
||||
client = SimpleNamespace(register_commands=Mock(), delete_commands=Mock())
|
||||
event = SimpleNamespace(
|
||||
event_data=CommandRegisterEventData(
|
||||
commands={"/sites": {"description": "管理站点"}},
|
||||
origin="DiscordFilter",
|
||||
service="discord-main",
|
||||
cancel=True,
|
||||
)
|
||||
)
|
||||
|
||||
with (
|
||||
patch.object(
|
||||
module,
|
||||
"get_configs",
|
||||
return_value={"discord-main": SimpleNamespace(name="discord-main", config={})},
|
||||
),
|
||||
patch.object(module, "get_instance", return_value=client),
|
||||
patch("app.modules.discord.eventmanager.send_event", return_value=event),
|
||||
):
|
||||
module.register_commands({"/sites": {"description": "管理站点"}})
|
||||
|
||||
client.delete_commands.assert_called_once_with()
|
||||
client.register_commands.assert_not_called()
|
||||
|
||||
|
||||
def test_discord_normalizes_slash_command_names():
|
||||
"""Discord 命令名称应符合平台只允许小写字母数字下划线连字符的约束。"""
|
||||
assert Discord._normalize_slash_command_name("/sites") == "sites"
|
||||
assert Discord._normalize_slash_command_name("/clear_cache") == "clear_cache"
|
||||
assert Discord._normalize_slash_command_name("/INVALID") == "invalid"
|
||||
assert Discord._normalize_slash_command_name("/中文") == ""
|
||||
assert Discord._normalize_slash_command_name("/" + "a" * 33) == ""
|
||||
|
||||
|
||||
def test_discord_handle_slash_command_forwards_to_message_chain():
|
||||
"""Discord 斜杠命令回调应转发为统一消息入口可识别的命令文本。"""
|
||||
client = Discord.__new__(Discord)
|
||||
client._update_user_chat_mapping = Mock()
|
||||
client._post_to_ds = AsyncMock()
|
||||
|
||||
user = SimpleNamespace(id=10001, display_name="tester", global_name=None, name="tester")
|
||||
channel = SimpleNamespace(id=20001)
|
||||
interaction = SimpleNamespace(
|
||||
id=30001,
|
||||
user=user,
|
||||
channel=channel,
|
||||
response=SimpleNamespace(
|
||||
defer=AsyncMock(),
|
||||
is_done=Mock(return_value=True),
|
||||
send_message=AsyncMock(),
|
||||
),
|
||||
followup=SimpleNamespace(send=AsyncMock()),
|
||||
)
|
||||
|
||||
asyncio.run(client._handle_slash_command(interaction, "/sites", "refresh"))
|
||||
|
||||
client._update_user_chat_mapping.assert_called_once_with("10001", "20001")
|
||||
client._post_to_ds.assert_awaited_once_with(
|
||||
{
|
||||
"type": "message",
|
||||
"userid": "10001",
|
||||
"username": "tester",
|
||||
"user_tag": str(user),
|
||||
"text": "/sites refresh",
|
||||
"message_id": "30001",
|
||||
"chat_id": "20001",
|
||||
"channel_type": "guild",
|
||||
}
|
||||
)
|
||||
interaction.followup.send.assert_awaited_once_with(
|
||||
"命令已提交,请稍等...",
|
||||
ephemeral=True,
|
||||
)
|
||||
105
tests/test_slack_command_registration.py
Normal file
105
tests/test_slack_command_registration.py
Normal file
@@ -0,0 +1,105 @@
|
||||
from types import SimpleNamespace
|
||||
from unittest.mock import Mock, patch
|
||||
|
||||
from app.modules.slack import SlackModule
|
||||
from app.modules.slack.slack import Slack
|
||||
from app.schemas import CommandRegisterEventData
|
||||
|
||||
|
||||
def test_slack_module_register_commands_filters_event_subset():
|
||||
"""Slack 模块注册命令时应复用渠道级 CommandRegister 事件过滤结果。"""
|
||||
module = SlackModule()
|
||||
client = SimpleNamespace(register_commands=Mock(), delete_commands=Mock())
|
||||
original_commands = {
|
||||
"/sites": {"description": "管理站点"},
|
||||
"/version": {"description": "当前版本"},
|
||||
}
|
||||
event = SimpleNamespace(
|
||||
event_data=CommandRegisterEventData(
|
||||
commands={"/sites": {"description": "管理站点"}, "/unknown": {"description": "无效"}},
|
||||
origin="SlackFilter",
|
||||
service="slack-main",
|
||||
)
|
||||
)
|
||||
|
||||
with (
|
||||
patch.object(
|
||||
module,
|
||||
"get_configs",
|
||||
return_value={"slack-main": SimpleNamespace(name="slack-main", config={})},
|
||||
),
|
||||
patch.object(module, "get_instance", return_value=client),
|
||||
patch("app.modules.slack.eventmanager.send_event", return_value=event),
|
||||
):
|
||||
module.register_commands(original_commands)
|
||||
|
||||
client.register_commands.assert_called_once_with(
|
||||
{"/sites": {"description": "管理站点"}}
|
||||
)
|
||||
client.delete_commands.assert_not_called()
|
||||
|
||||
|
||||
def test_slack_manifest_update_preserves_foreign_commands():
|
||||
"""Slack Manifest 更新应只替换 MoviePilot 本次管理的 Slash Commands。"""
|
||||
client = Slack.__new__(Slack)
|
||||
manifest_client = SimpleNamespace(
|
||||
apps_manifest_export=Mock(
|
||||
return_value={
|
||||
"ok": True,
|
||||
"manifest": {
|
||||
"features": {
|
||||
"slash_commands": [
|
||||
{"command": "/keep", "description": "保留"},
|
||||
{"command": "/sites", "description": "旧站点"},
|
||||
{
|
||||
"command": "/old_moviepilot",
|
||||
"description": "旧命令",
|
||||
"usage_hint": "MoviePilot 可选参数",
|
||||
},
|
||||
]
|
||||
}
|
||||
},
|
||||
}
|
||||
),
|
||||
apps_manifest_update=Mock(return_value={"ok": True}),
|
||||
)
|
||||
client._manifest_client = manifest_client
|
||||
client._app_id = "A123"
|
||||
client._command_request_url = "https://example.com/slack/commands"
|
||||
client._registered_command_names = {"/sites"}
|
||||
|
||||
assert client.register_commands(
|
||||
{
|
||||
"/sites": {"description": "管理站点"},
|
||||
"/version": {"description": "当前版本"},
|
||||
}
|
||||
)
|
||||
|
||||
manifest_client.apps_manifest_update.assert_called_once()
|
||||
_, kwargs = manifest_client.apps_manifest_update.call_args
|
||||
assert kwargs["app_id"] == "A123"
|
||||
slash_commands = kwargs["manifest"]["features"]["slash_commands"]
|
||||
assert [item["command"] for item in slash_commands] == [
|
||||
"/keep",
|
||||
"/sites",
|
||||
"/version",
|
||||
]
|
||||
assert slash_commands[1]["url"] == "https://example.com/slack/commands"
|
||||
assert client._registered_command_names == {"/sites", "/version"}
|
||||
|
||||
|
||||
def test_slack_command_registration_skips_without_manifest_credentials():
|
||||
"""未配置 Slack Manifest 凭据时不应尝试自动注册。"""
|
||||
client = Slack.__new__(Slack)
|
||||
client._manifest_client = None
|
||||
client._app_id = ""
|
||||
|
||||
assert client.register_commands({"/sites": {"description": "管理站点"}}) is False
|
||||
|
||||
|
||||
def test_slack_normalizes_slash_command_names():
|
||||
"""Slack 命令名称应符合平台 Slash Command 约束。"""
|
||||
assert Slack._normalize_slack_command("/sites") == "/sites"
|
||||
assert Slack._normalize_slack_command("CLEAR_CACHE") == "/clear_cache"
|
||||
assert Slack._normalize_slack_command("/中文") == ""
|
||||
assert Slack._normalize_slack_command("/" + "a" * 32) == ""
|
||||
Reference in New Issue
Block a user