diff --git a/app/agent/__init__.py b/app/agent/__init__.py index 2e6d6338..e031b3d0 100644 --- a/app/agent/__init__.py +++ b/app/agent/__init__.py @@ -50,6 +50,7 @@ from app.agent.tools.factory import MoviePilotToolFactory from app.chain import ChainBase from app.core.config import settings from app.core.event import eventmanager +from app.db.user_oper import UserOper from app.log import logger from app.schemas import AgentLLMProviderEventData, AgentTokensUsageEventData, Notification, NotificationType from app.schemas.message import ChannelCapabilityManager, ChannelCapability @@ -418,6 +419,38 @@ class MoviePilotAgent: """ return self.session_id.startswith(HEARTBEAT_SESSION_PREFIX) + async def _is_system_admin_context(self) -> bool: + """ + 判断当前 Agent 会话是否应按系统管理员上下文运行工具。 + """ + if self.is_background: + return True + if self.channel == MessageChannel.Web.value and self.source in { + "openai", + "openai.responses", + "anthropic", + }: + return True + if not self.username: + return False + try: + user = await UserOper().async_get_by_name(self.username) + except Exception as e: + logger.error(f"检查 Agent 用户管理员身份失败: {e}") + return False + return bool(user and user.is_superuser) + + async def _build_tool_context(self, should_dispatch_reply: bool) -> Dict[str, object]: + """ + 构造本轮工具共享上下文。 + """ + return { + "user_reply_sent": False, + "reply_mode": None, + "should_dispatch_reply": should_dispatch_reply, + "is_admin": await self._is_system_admin_context(), + } + def _should_stream(self) -> bool: """ 判断是否应启用流式输出: @@ -804,6 +837,7 @@ class MoviePilotAgent: "user_reply_sent": False, "reply_mode": None, "should_dispatch_reply": False, + "is_admin": bool(self._tool_context.get("is_admin")), }, allow_message_tools=False, ) @@ -920,11 +954,9 @@ class MoviePilotAgent: f"images={len(images) if images else 0}, files={len(files) if files else 0}, " f"audio_input={has_audio_input}" ) - self._tool_context = { - "user_reply_sent": False, - "reply_mode": None, - "should_dispatch_reply": self.should_dispatch_reply, - } + self._tool_context = await self._build_tool_context( + should_dispatch_reply=self.should_dispatch_reply + ) self._streamed_output = "" # 获取历史消息 diff --git a/app/agent/middleware/subagents.py b/app/agent/middleware/subagents.py index 42778707..e84289a9 100644 --- a/app/agent/middleware/subagents.py +++ b/app/agent/middleware/subagents.py @@ -39,6 +39,7 @@ SUBAGENT_MAX_ACTIVE_TASKS = 8 SUBAGENT_MAX_CONCURRENT_TASKS = 4 SUBAGENT_RESULT_MAX_CHARS = 12000 SUBAGENT_DESCRIPTION_MAX_CHARS = 500 +SUBAGENT_PIPELINE_CONTEXT_MAX_CHARS = 12000 SUBAGENT_PARENT_PROMPT = """ You may use subagent tools to delegate independent research, retrieval, @@ -51,6 +52,9 @@ Delegation modes: `action=wait`, or `action=cancel` with the returned task IDs. - Use `subagent_task` with `action=run` when you want to launch a bounded batch and wait for the batch in one tool call. +- Use `subagent_task` with `action=pipeline` when later subtasks must use + previous subagent results. Pipeline steps run sequentially, and each step's + result is passed as private context to the next step. Rules: - Delegate when a task benefits from focused investigation, such as media identity checks, site/resource search, subscription analysis, download/transfer diagnosis, MoviePilot code/config exploration, or read-only system inspection. @@ -71,7 +75,9 @@ SUBAGENT_CONTROL_DESCRIPTION = ( "Use action=start with tasks=[{description, subagent_type}] to launch a batch " "and get task IDs immediately. Use action=status to inspect tasks, action=wait " "to wait for all or any task result, action=cancel to stop running tasks, and " - "action=run to launch a bounded batch and wait in one call." + "action=run to launch a bounded batch and wait in one call. Use action=pipeline " + "to run tasks sequentially while passing each result as private context to the " + "next task." ) SUBAGENT_BASE_PROMPT = """You are a silent subagent working for the MoviePilot main agent. @@ -120,9 +126,9 @@ class _SubAgentTaskSpec(BaseModel): class _SubAgentControlInput(BaseModel): """异步子代理管控工具输入。""" - action: Literal["start", "status", "wait", "cancel", "run"] = Field( + action: Literal["start", "status", "wait", "cancel", "run", "pipeline"] = Field( default="start", - description="Task action: start, status, wait, cancel, or run.", + description="Task action: start, status, wait, cancel, run, or pipeline.", ) description: Optional[str] = Field( default=None, @@ -150,7 +156,10 @@ class _SubAgentControlInput(BaseModel): ) timeout_ms: Optional[int] = Field( default=SUBAGENT_DEFAULT_WAIT_TIMEOUT_MS, - description="Maximum wait time in milliseconds for action=wait or action=run.", + description=( + "Maximum wait time in milliseconds for action=wait, action=run, " + "or each action=pipeline step." + ), ) @@ -742,7 +751,8 @@ class SubAgentTaskControlMiddleware(AgentMiddleware): f"pending={len(pending_tasks) - finished_count}" ) - async def _cancel_records(self, records: list[_SubAgentRuntimeTask]) -> None: + @staticmethod + async def _cancel_records(records: list[_SubAgentRuntimeTask]) -> None: """取消一组尚未完成的任务。""" cancellable_tasks = [ record.task for record in records if not record.task.done() @@ -755,6 +765,156 @@ class SubAgentTaskControlMiddleware(AgentMiddleware): await asyncio.gather(*cancellable_tasks, return_exceptions=True) logger.info(f"子代理任务取消完成: tasks={len(cancellable_tasks)}") + @staticmethod + def _pipeline_description( + *, + description: str, + previous_results: list[tuple[_SubAgentRuntimeTask, str]], + ) -> str: + """追加上游子代理结果,生成当前管道步骤的任务描述。""" + normalized_description = description.strip() + if not previous_results: + return normalized_description + + context_parts = [] + for step_index, (record, result) in enumerate(previous_results, start=1): + clipped_result, result_truncated = _clip_text( + result, + SUBAGENT_RESULT_MAX_CHARS, + ) + truncated_note = "\n[Result truncated]" if result_truncated else "" + context_parts.append( + f"Step {step_index} ({record.subagent_type}) result:\n" + f"{clipped_result}{truncated_note}" + ) + context_text, context_truncated = _clip_text( + "\n\n".join(context_parts), + SUBAGENT_PIPELINE_CONTEXT_MAX_CHARS, + ) + truncated_note = "\n[Pipeline context truncated]" if context_truncated else "" + return ( + f"{normalized_description}\n\n" + "\n" + "Previous subagent results are private context for this delegated " + "subtask. Use them to complete the current task, but do not expose " + "the prior reports verbatim.\n\n" + f"{context_text}{truncated_note}\n" + "" + ) + + async def _execute_pipeline_task( + self, + *, + record: _SubAgentRuntimeTask, + description: str, + ) -> str: + """执行单个管道步骤,保留原始步骤描述用于状态展示。""" + async with self._semaphore: + record.started_at = datetime.now() + logger.info( + f"管道子代理任务开始执行: task_id={record.task_id}, " + f"subagent_type={record.subagent_type}" + ) + try: + result = await self._provider.run_task( + description=description, + subagent_type=record.subagent_type, + task_id=record.task_id, + ) + logger.info( + f"管道子代理任务执行完成: task_id={record.task_id}, " + f"subagent_type={record.subagent_type}, result_chars={len(result)}" + ) + return result + except asyncio.CancelledError: + logger.info( + f"管道子代理任务已取消: task_id={record.task_id}, " + f"subagent_type={record.subagent_type}" + ) + raise + except Exception as err: + logger.error(f"管道子代理任务执行失败: task_id={record.task_id}, error={err}") + raise + + @staticmethod + def _create_pipeline_record( + spec: _SubAgentTaskSpec, + ) -> _SubAgentRuntimeTask: + """创建一个管道步骤记录。""" + task_id = f"subagent-{uuid.uuid4().hex[:12]}" + return _SubAgentRuntimeTask( + task_id=task_id, + description=spec.description.strip(), + subagent_type=spec.subagent_type or "general-purpose", + task=None, + created_at=datetime.now(), + ) + + def _track_pipeline_task( + self, + record: _SubAgentRuntimeTask, + task: asyncio.Task, + ) -> None: + """登记管道步骤任务,复用统一的状态和异常收口逻辑。""" + record.task = task + task.add_done_callback( + lambda finished_task, finished_task_id=record.task_id: self._mark_task_finished( + finished_task_id, + finished_task, + ) + ) + self._tasks[record.task_id] = record + + async def _run_pipeline( + self, + specs: list[_SubAgentTaskSpec], + timeout_ms: Optional[int], + ) -> tuple[list[_SubAgentRuntimeTask], Optional[str]]: + """按顺序执行管道任务,并把每一步结果传给下一步。""" + normalized_timeout_ms = self._normalize_timeout_ms(timeout_ms) + if normalized_timeout_ms <= 0: + return [], "管道任务需要大于 0 的等待时间。" + + records: list[_SubAgentRuntimeTask] = [] + previous_results: list[tuple[_SubAgentRuntimeTask, str]] = [] + timeout = normalized_timeout_ms / 1000 + for step_index, spec in enumerate(specs, start=1): + record = self._create_pipeline_record(spec) + records.append(record) + pipeline_description = self._pipeline_description( + description=record.description, + previous_results=previous_results, + ) + task = asyncio.create_task( + self._execute_pipeline_task( + record=record, + description=pipeline_description, + ), + name=record.task_id, + ) + self._track_pipeline_task(record, task) + logger.info( + f"已启动管道子代理任务: step={step_index}, task_id={record.task_id}, " + f"subagent_type={record.subagent_type}" + ) + + try: + result = await asyncio.wait_for(task, timeout=timeout) + except asyncio.TimeoutError: + error = f"第 {step_index} 个管道子代理任务等待超时。" + logger.info( + f"{error} task_id={record.task_id}, timeout_ms={normalized_timeout_ms}" + ) + return records, error + except Exception as err: + error = f"第 {step_index} 个管道子代理任务执行失败: {err}" + logger.info(f"{error} task_id={record.task_id}") + return records, error + + previous_results.append((record, result)) + + return records, None + async def _control_task( self, action: str = "start", @@ -768,7 +928,7 @@ class SubAgentTaskControlMiddleware(AgentMiddleware): ) -> str: """管理异步子代理任务。""" logger.info(f"收到子代理管控操作: action={action}") - if action in {"start", "run"}: + if action in {"start", "run", "pipeline"}: specs, error = self._normalize_specs( description=description, subagent_type=subagent_type, @@ -779,6 +939,20 @@ class SubAgentTaskControlMiddleware(AgentMiddleware): return self._json_response({"success": False, "error": error}) logger.info(f"准备启动子代理任务: action={action}, tasks={len(specs)}") + if action == "pipeline": + records, pipeline_error = await self._run_pipeline( + specs=specs, + timeout_ms=timeout_ms, + ) + return self._json_response( + { + "success": pipeline_error is None, + "action": action, + "error": pipeline_error, + "tasks": [self._task_output(record) for record in records], + } + ) + records = self._start_tasks(specs) if action == "run": await self._wait_records( diff --git a/app/agent/tools/base.py b/app/agent/tools/base.py index 1a85b76f..2ff7d2f1 100644 --- a/app/agent/tools/base.py +++ b/app/agent/tools/base.py @@ -4,6 +4,7 @@ import threading from abc import ABCMeta, abstractmethod from concurrent.futures import ThreadPoolExecutor from functools import partial +from pathlib import Path from typing import Any, Callable, ClassVar, Optional from langchain_core.tools import BaseTool @@ -373,6 +374,119 @@ class MoviePilotTool(BaseTool, metaclass=ABCMeta): # 独立的新 dict,跨工具状态(例如质量门槛拒绝标记)无法传播。 self._agent_context = {} if agent_context is None else agent_context + async def is_admin_user(self) -> bool: + """ + 判断当前工具调用者是否拥有管理员级权限。 + + :return: 当前调用者是系统管理员、渠道管理员或显式管理员上下文时返回 True + """ + if bool(self._agent_context.get("is_admin")): + return True + + if not self._channel or not self._source: + return False + + return await self._has_channel_admin_permission() + + @staticmethod + def _resolve_local_path(path: str) -> Path: + """ + 解析本地路径并展开符号链接。 + + :param path: 用户传入的本地文件或目录路径 + :return: 规范化后的绝对路径 + """ + return Path(path).expanduser().resolve(strict=False) + + @staticmethod + def _is_path_relative_to(path: Path, root: Path) -> bool: + """ + 判断路径是否位于指定目录内。 + + :param path: 待检查路径 + :param root: 允许访问的根目录 + :return: 路径在根目录内或等于根目录时返回 True + """ + try: + path.relative_to(root) + return True + except ValueError: + return False + + @classmethod + def _get_non_admin_local_file_roots(cls) -> list[Path]: + """ + 获取普通用户可访问的本地文件根目录。 + + :return: 普通用户允许读写的本地目录列表 + """ + roots = [ + settings.CONFIG_PATH, + settings.LOG_PATH, + settings.CONFIG_PATH / "agent" / "memory", + settings.CONFIG_PATH / "agent" / "activity", + ] + resolved_roots = [] + for root in roots: + resolved_root = cls._resolve_local_path(str(root)) + if resolved_root not in resolved_roots: + resolved_roots.append(resolved_root) + return resolved_roots + + async def _check_local_file_access( + self, path: str, operation: str = "访问" + ) -> tuple[Optional[Path], Optional[str]]: + """ + 检查当前用户是否可访问指定本地路径。 + + :param path: 用户传入的本地文件或目录路径 + :param operation: 当前操作名称,用于生成拒绝提示 + :return: 解析后的路径和拒绝原因;拒绝原因为空表示允许访问 + """ + if not path: + return None, "错误:路径不能为空" + + resolved_path = self._resolve_local_path(path) + if await self.is_admin_user(): + return resolved_path, None + + allowed_roots = self._get_non_admin_local_file_roots() + if any( + self._is_path_relative_to(resolved_path, root) + for root in allowed_roots + ): + return resolved_path, None + + allowed_text = "、".join(str(root) for root in allowed_roots) + return ( + resolved_path, + f"抱歉,普通用户只能{operation}配置目录、Agent记忆目录和日志目录内的文件或目录:{allowed_text}", + ) + + async def _check_local_storage_access( + self, + path: str, + storage: Optional[str] = "local", + operation: str = "访问", + ) -> tuple[Optional[Path], Optional[str]]: + """ + 检查当前用户是否可访问指定存储路径。 + + :param path: 用户传入的文件或目录路径 + :param storage: 存储类型,普通用户只允许 local + :param operation: 当前操作名称,用于生成拒绝提示 + :return: 本地存储时返回解析后的路径和拒绝原因;远程存储无本地路径 + """ + if (storage or "local") != "local": + if await self.is_admin_user(): + return None, None + return ( + None, + f"抱歉,普通用户只能{operation}本地配置目录、Agent记忆目录和日志目录,不能访问远程存储。", + ) + + return await self._check_local_file_access(path=path, operation=operation) + async def _check_permission(self) -> Optional[str]: """ 检查用户权限: @@ -385,9 +499,28 @@ class MoviePilotTool(BaseTool, metaclass=ABCMeta): if not self._require_admin: return None + if await self.is_admin_user(): + return None + if not self._channel or not self._source: return None + return ( + "抱歉,您没有执行此工具的权限。" + "只有渠道管理员或系统管理员才能执行工具操作。" + "如需执行工具,请联系渠道管理员将您的用户ID添加到渠道管理员列表中," + "或联系系统管理员为您设置权限。" + ) + + async def _has_channel_admin_permission(self) -> bool: + """ + 检查当前消息渠道身份是否具备管理员权限。 + + :return: 当前渠道用户是渠道管理员、系统管理员或默认接收人时返回 True + """ + if not self._channel or not self._source: + return False + # 渠道配置来自 SystemConfigOper 内存缓存,可以直接读取; # 只有用户信息需要走异步数据库查询。 user_id_str = str(self._user_id) if self._user_id else None @@ -411,7 +544,7 @@ class MoviePilotTool(BaseTool, metaclass=ABCMeta): break if not channel_type: - return None + return False admin_key_map = { "telegram": "TELEGRAM_ADMINS", @@ -451,7 +584,7 @@ class MoviePilotTool(BaseTool, metaclass=ABCMeta): if aid.strip() ] if user_id_str and user_id_str in admin_list: - return None + return True user = ( await UserOper().async_get_by_name(self._username) @@ -459,14 +592,9 @@ class MoviePilotTool(BaseTool, metaclass=ABCMeta): else None ) if user and user.is_superuser: - return None + return True - return ( - "抱歉,您没有执行此工具的权限。" - "只有渠道管理员或系统管理员才能执行工具操作。" - "如需执行工具,请联系渠道管理员将您的用户ID添加到渠道管理员列表中," - "或联系系统管理员为您设置权限。" - ) + return False else: user = ( await UserOper().async_get_by_name(self._username) @@ -474,22 +602,18 @@ class MoviePilotTool(BaseTool, metaclass=ABCMeta): else None ) if user and user.is_superuser: - return None + return True if user_id_key: config_user_id = config.config.get(user_id_key) if config_user_id and str(config_user_id) == user_id_str: - return None + return True - return ( - "抱歉,您没有执行此工具的权限。" - "只有系统管理员才能执行工具操作。" - "如需执行工具,请联系系统管理员为您设置权限。" - ) + return False except Exception as e: logger.error(f"检查权限失败: {e}") - return None + return False async def send_tool_message( self, message: str, title: str = "", image: Optional[str] = None diff --git a/app/agent/tools/impl/ask_user_choice.py b/app/agent/tools/impl/ask_user_choice.py index 525f363f..37eed9bd 100644 --- a/app/agent/tools/impl/ask_user_choice.py +++ b/app/agent/tools/impl/ask_user_choice.py @@ -83,7 +83,6 @@ class AskUserChoiceTool(MoviePilotTool): "back as the user's next message. Do not also send the same question as plain text." ) args_schema: Type[BaseModel] = AskUserChoiceInput - require_admin: bool = False def get_tool_message(self, **kwargs) -> Optional[str]: message = kwargs.get("message", "") or "" diff --git a/app/agent/tools/impl/edit_file.py b/app/agent/tools/impl/edit_file.py index 2c0c3699..16b1d860 100644 --- a/app/agent/tools/impl/edit_file.py +++ b/app/agent/tools/impl/edit_file.py @@ -24,11 +24,13 @@ class EditFileTool(MoviePilotTool): tags: list[str] = [ ToolTag.Write, ToolTag.File, - ToolTag.Admin, ] - description: str = "Edit a file by replacing specific old text with new text. Useful for modifying configuration files, code, or scripts." + description: str = ( + "Edit a local text file by replacing specific old text with new text. " + "Non-admin users can only edit files inside the MoviePilot config, " + "Agent memory/activity, and log directories." + ) args_schema: Type[BaseModel] = EditFileInput - require_admin: bool = True def get_tool_message(self, **kwargs) -> Optional[str]: """根据参数生成友好的提示消息""" @@ -40,21 +42,27 @@ class EditFileTool(MoviePilotTool): logger.info(f"执行工具: {self.name}, 参数: file_path={file_path}") try: - path = AsyncPath(file_path) + resolved_path, access_error = await self._check_local_file_access( + file_path, operation="编辑" + ) + if access_error: + return access_error + + path = AsyncPath(resolved_path) # 校验逻辑:如果要替换特定文本,文件必须存在且包含该文本 if not await path.exists(): # 如果 old_text 为空,可能用户想直接创建文件,但通常 edit_file 需要匹配旧内容 if old_text: - return f"错误:文件 {file_path} 不存在,无法进行内容替换。" + return f"错误:文件 {resolved_path} 不存在,无法进行内容替换。" if await path.exists() and not await path.is_file(): - return f"错误:{file_path} 不是一个文件" + return f"错误:{resolved_path} 不是一个文件" if await path.exists(): content = await path.read_text(encoding="utf-8") if old_text not in content: - logger.warning(f"编辑文件 {file_path} 失败:未找到指定的旧文本块") - return f"错误:在文件 {file_path} 中未找到指定的旧文本。请确保包含所有的空格、缩进 and 换行符。" + logger.warning(f"编辑文件 {resolved_path} 失败:未找到指定的旧文本块") + return f"错误:在文件 {resolved_path} 中未找到指定的旧文本。请确保包含所有的空格、缩进 and 换行符。" occurrences = content.count(old_text) new_content = content.replace(old_text, new_text) else: @@ -68,8 +76,8 @@ class EditFileTool(MoviePilotTool): # 写入文件 await path.write_text(new_content, encoding="utf-8") - logger.info(f"成功编辑文件 {file_path},替换了 {occurrences} 处内容") - return f"成功编辑文件 {file_path} (替换了 {occurrences} 处匹配内容)" + logger.info(f"成功编辑文件 {resolved_path},替换了 {occurrences} 处内容") + return f"成功编辑文件 {resolved_path} (替换了 {occurrences} 处匹配内容)" except PermissionError: return f"错误:没有访问/修改 {file_path} 的权限" diff --git a/app/agent/tools/impl/list_directory.py b/app/agent/tools/impl/list_directory.py index 336511d2..a179be5f 100644 --- a/app/agent/tools/impl/list_directory.py +++ b/app/agent/tools/impl/list_directory.py @@ -116,6 +116,13 @@ class ListDirectoryTool(MoviePilotTool): logger.info(f"执行工具: {self.name}, 参数: path={path}, storage={storage}, sort_by={sort_by}") try: + resolved_path, access_error = await self._check_local_storage_access( + path=path, storage=storage, operation="列出" + ) + if access_error: + return access_error + if resolved_path: + path = str(resolved_path) return await self.run_blocking( "storage", self._list_directory_sync, path, storage, sort_by ) diff --git a/app/agent/tools/impl/query_downloaders.py b/app/agent/tools/impl/query_downloaders.py index c9623f21..a2d51dfd 100644 --- a/app/agent/tools/impl/query_downloaders.py +++ b/app/agent/tools/impl/query_downloaders.py @@ -22,10 +22,12 @@ class QueryDownloadersTool(MoviePilotTool): tags: list[str] = [ ToolTag.Read, ToolTag.Download, - ToolTag.Admin, ] - description: str = "Query downloader configuration and list all available downloaders. Shows downloader status, connection details, and configuration settings." - require_admin: bool = True + description: str = ( + "Query downloader configuration and list available downloaders. Non-admin users receive " + "a safe view with only the fields needed to choose a downloader, without host, account, " + "password, token or API key values." + ) args_schema: Type[BaseModel] = QueryDownloadersInput def get_tool_message(self, **kwargs) -> Optional[str]: @@ -37,11 +39,35 @@ class QueryDownloadersTool(MoviePilotTool): """从内存配置缓存中读取下载器配置。""" return SystemConfigOper().get(SystemConfigKey.Downloaders) + @staticmethod + def _sanitize_downloaders_config(downloaders_config: list) -> list: + """ + 生成普通用户可见的下载器配置视图。 + + :param downloaders_config: 系统下载器完整配置列表 + :return: 仅包含名称、类型和启用状态的安全配置列表 + """ + safe_fields = ("name", "type", "enabled", "default", "priority") + safe_downloaders = [] + for downloader in downloaders_config: + if not isinstance(downloader, dict): + continue + safe_downloaders.append({ + key: downloader.get(key) + for key in safe_fields + if key in downloader + }) + return safe_downloaders + async def run(self, **kwargs) -> str: logger.info(f"执行工具: {self.name}") try: downloaders_config = self._load_downloaders_config() if downloaders_config: + if not await self.is_admin_user(): + downloaders_config = self._sanitize_downloaders_config( + downloaders_config + ) return json.dumps(downloaders_config, ensure_ascii=False, indent=2) return "未配置下载器。" except Exception as e: diff --git a/app/agent/tools/impl/query_sites.py b/app/agent/tools/impl/query_sites.py index f3bbe25d..4088e932 100644 --- a/app/agent/tools/impl/query_sites.py +++ b/app/agent/tools/impl/query_sites.py @@ -30,10 +30,12 @@ class QuerySitesTool(MoviePilotTool): tags: list[str] = [ ToolTag.Read, ToolTag.Site, - ToolTag.Admin, ] - description: str = "Query site status and list all configured sites. Shows site name, domain, status, priority, and basic configuration. Site priority (pri): smaller values have higher priority (e.g., pri=1 has higher priority than pri=10)." - require_admin: bool = True + description: str = ( + "Query site status and list configured sites. Non-admin users receive a safe view " + "that omits sensitive fields: cookie, token, API key and RSS URL. " + "Site priority (pri): smaller values have higher priority (e.g., pri=1 has higher priority than pri=10)." + ) args_schema: Type[BaseModel] = QuerySitesInput def get_tool_message(self, **kwargs) -> Optional[str]: @@ -57,6 +59,7 @@ class QuerySitesTool(MoviePilotTool): ) -> str: logger.info(f"执行工具: {self.name}, 参数: status={status}, name={name}") try: + is_admin = await self.is_admin_user() site_oper = SiteOper() # 获取所有站点(按优先级排序) sites = await site_oper.async_list() @@ -82,11 +85,25 @@ class QuerySitesTool(MoviePilotTool): "url": s.url, "pri": s.pri, "is_active": s.is_active, - "cookie": s.cookie, "downloader": s.downloader, + "ua": s.ua, "proxy": s.proxy, + "filter": s.filter, + "render": s.render, + "public": s.public, + "note": s.note, + "limit_interval": s.limit_interval, + "limit_count": s.limit_count, + "limit_seconds": s.limit_seconds, "timeout": s.timeout, } + if is_admin: + simplified.update({ + "rss": s.rss, + "cookie": s.cookie, + "apikey": s.apikey, + "token": s.token, + }) simplified_sites.append(simplified) result_json = json.dumps(simplified_sites, ensure_ascii=False, indent=2) return result_json diff --git a/app/agent/tools/impl/read_file.py b/app/agent/tools/impl/read_file.py index 6533e0ca..303cfdf5 100644 --- a/app/agent/tools/impl/read_file.py +++ b/app/agent/tools/impl/read_file.py @@ -41,13 +41,19 @@ class ReadFileTool(MoviePilotTool): logger.info(f"执行工具: {self.name}, 参数: file_path={file_path}, start_line={start_line}, end_line={end_line}") try: - path = AsyncPath(file_path) + resolved_path, access_error = await self._check_local_file_access( + file_path, operation="读取" + ) + if access_error: + return access_error + + path = AsyncPath(resolved_path) if not await path.exists(): - return f"错误:文件 {file_path} 不存在" + return f"错误:文件 {resolved_path} 不存在" if not await path.is_file(): - return f"错误:{file_path} 不是一个文件" + return f"错误:{resolved_path} 不是一个文件" content = await path.read_text(encoding="utf-8") truncated = False diff --git a/app/agent/tools/impl/send_local_file.py b/app/agent/tools/impl/send_local_file.py index 87d8c476..59926829 100644 --- a/app/agent/tools/impl/send_local_file.py +++ b/app/agent/tools/impl/send_local_file.py @@ -55,7 +55,7 @@ class SendLocalFileTool(MoviePilotTool): "Use this when you have generated or identified a local file the user should download." ) args_schema: Type[BaseModel] = SendLocalFileInput - require_admin: bool = False + require_admin: bool = True def get_tool_message(self, **kwargs) -> Optional[str]: file_path = kwargs.get("file_path", "") diff --git a/app/agent/tools/impl/send_voice_message.py b/app/agent/tools/impl/send_voice_message.py index 7b16dfb7..73b2d749 100644 --- a/app/agent/tools/impl/send_voice_message.py +++ b/app/agent/tools/impl/send_voice_message.py @@ -44,7 +44,6 @@ class SendVoiceMessageTool(MoviePilotTool): "or call `send_message` with the same content." ) args_schema: Type[BaseModel] = SendVoiceMessageInput - require_admin: bool = False def get_tool_message(self, **kwargs) -> Optional[str]: """生成语音回复工具的执行提示。""" diff --git a/app/agent/tools/impl/write_file.py b/app/agent/tools/impl/write_file.py index 57369d9b..ad19cc5f 100644 --- a/app/agent/tools/impl/write_file.py +++ b/app/agent/tools/impl/write_file.py @@ -23,11 +23,12 @@ class WriteFileTool(MoviePilotTool): tags: list[str] = [ ToolTag.Write, ToolTag.File, - ToolTag.Admin, ] - description: str = "Write full content to a file. If the file already exists, it will be overwritten. Automatically creates parent directories if they don't exist." + description: str = ( + "Write full content to a local text file. Non-admin users can only write " + "inside the MoviePilot config, Agent memory/activity, and log directories." + ) args_schema: Type[BaseModel] = WriteFileInput - require_admin: bool = True def get_tool_message(self, **kwargs) -> Optional[str]: """根据参数生成友好的提示消息""" @@ -39,10 +40,16 @@ class WriteFileTool(MoviePilotTool): logger.info(f"执行工具: {self.name}, 参数: file_path={file_path}") try: - path = AsyncPath(file_path) + resolved_path, access_error = await self._check_local_file_access( + file_path, operation="写入" + ) + if access_error: + return access_error + + path = AsyncPath(resolved_path) if await path.exists() and not await path.is_file(): - return f"错误:{file_path} 路径已存在但不是一个文件" + return f"错误:{resolved_path} 路径已存在但不是一个文件" # 自动创建父目录 await path.parent.mkdir(parents=True, exist_ok=True) @@ -50,8 +57,8 @@ class WriteFileTool(MoviePilotTool): # 写入文件 await path.write_text(content, encoding="utf-8") - logger.info(f"成功写入文件 {file_path}") - return f"成功写入文件 {file_path}" + logger.info(f"成功写入文件 {resolved_path}") + return f"成功写入文件 {resolved_path}" except PermissionError: return f"错误:没有权限写入 {file_path}" diff --git a/app/agent/tools/manager.py b/app/agent/tools/manager.py index 96301d23..17ec6ef9 100644 --- a/app/agent/tools/manager.py +++ b/app/agent/tools/manager.py @@ -55,6 +55,7 @@ class MoviePilotToolsManager: source="api", username="API Client", stream_handler=None, + agent_context={"is_admin": self.is_admin}, ) logger.info(f"成功加载 {len(self.tools)} 个工具") except Exception as e: diff --git a/docs/mcp-api.md b/docs/mcp-api.md index 345a494b..0c0654ec 100644 --- a/docs/mcp-api.md +++ b/docs/mcp-api.md @@ -108,8 +108,6 @@ MoviePilot 也提供普通 REST API 给前端和自动化客户端使用。所 | GET | `/api/v1/download/paths` | 查询可用于下载接口 `save_path` 参数的下载路径 | | DELETE | `/api/v1/download/{hashString}` | 删除下载任务,参数:`name` | -MCP 工具 `query_download_tasks` 支持 `status=all|downloading|completed|paused`;其中 `completed` 表示下载器任务既不是下载中,也不是暂停状态。默认仅查询带 MoviePilot 内置标签的任务;如需诊断下载器中未打内置标签的任务,可传 `include_all_tags=true`。 - #### 系统 | 方法 | 路径 | 说明 | diff --git a/skills/browser-use/SKILL.md b/skills/browser-use/SKILL.md index e197c266..9faf6384 100644 --- a/skills/browser-use/SKILL.md +++ b/skills/browser-use/SKILL.md @@ -49,6 +49,8 @@ dedicated tool can complete the task more directly and safely. `google`, `brave`, etc.) and `site_url` for limiting results to a specified domain or URL path. It uses the configured system proxy by default. - `query_sites` - Get MoviePilot site IDs before site-specific operations. + Non-admin callers receive a safe view without Cookie, RSS, Token, or API Key + fields. - `update_site_cookie` - Update a configured site's Cookie and User-Agent using username, password, and optional two-step code. - `test_site` - Verify configured site connectivity and login status. diff --git a/tests/test_agent_resource_flow_permissions.py b/tests/test_agent_resource_flow_permissions.py new file mode 100644 index 00000000..978409a9 --- /dev/null +++ b/tests/test_agent_resource_flow_permissions.py @@ -0,0 +1,317 @@ +"""Agent 资源流程工具权限测试。""" + +import asyncio +import json +from types import SimpleNamespace +from unittest.mock import AsyncMock, patch + +from app.agent.tools.impl.edit_file import EditFileTool +from app.agent.tools.impl.list_directory import ListDirectoryTool +from app.agent.tools.impl.query_downloaders import QueryDownloadersTool +from app.agent.tools.impl.query_sites import QuerySitesTool +from app.agent.tools.impl.read_file import ReadFileTool +from app.agent.tools.impl.send_local_file import SendLocalFileTool +from app.agent.tools.impl.write_file import WriteFileTool +from app.agent.tools.manager import MoviePilotToolsManager +from app.agent import MoviePilotAgent +from app.core.config import settings +from app.schemas.types import MessageChannel + + +def test_non_admin_manager_exposes_resource_flow_helper_tools(): + """普通用户应能看到搜索、订阅、下载流程所需的辅助工具。""" + site_tool = QuerySitesTool(session_id="session-1", user_id="10001") + downloader_tool = QueryDownloadersTool(session_id="session-1", user_id="10001") + + with patch( + "app.agent.tools.manager.MoviePilotToolFactory.create_tools", + return_value=[site_tool, downloader_tool], + ): + manager = MoviePilotToolsManager(is_admin=False) + + tool_names = {tool.name for tool in manager.list_tools()} + assert "query_sites" in tool_names + assert "query_downloaders" in tool_names + + +def test_non_admin_manager_exposes_restricted_file_tools(): + """普通用户应能看到受目录边界限制的文件读写工具。""" + tools = [ + ReadFileTool(session_id="session-1", user_id="10001"), + WriteFileTool(session_id="session-1", user_id="10001"), + EditFileTool(session_id="session-1", user_id="10001"), + ListDirectoryTool(session_id="session-1", user_id="10001"), + ] + + with patch( + "app.agent.tools.manager.MoviePilotToolFactory.create_tools", + return_value=tools, + ): + manager = MoviePilotToolsManager(is_admin=False) + + tool_names = {tool.name for tool in manager.list_tools()} + assert {"read_file", "write_file", "edit_file", "list_directory"} <= tool_names + + +def test_query_sites_hides_only_sensitive_fields_for_non_admin_user(): + """普通用户查询站点时只隐藏 Cookie、API Key、Token 和 RSS。""" + tool = QuerySitesTool(session_id="session-1", user_id="10001") + site = SimpleNamespace( + id=1, + name="TestSite", + domain="secret.example", + url="https://secret.example/", + pri=1, + rss="https://secret.example/rss", + cookie="uid=1; passkey=secret", + ua="SecretUA", + apikey="site-api-key", + token="site-token", + proxy=1, + filter="", + render=0, + public=0, + note={"secret": True}, + limit_interval=0, + limit_count=0, + limit_seconds=0, + timeout=15, + is_active=True, + downloader="qb", + ) + + with patch( + "app.agent.tools.impl.query_sites.SiteOper" + ) as site_oper: + site_oper.return_value.async_list = AsyncMock(return_value=[site]) + result = asyncio.run(tool.run()) + + payload = json.loads(result) + assert payload == [ + { + "id": 1, + "name": "TestSite", + "domain": "secret.example", + "url": "https://secret.example/", + "pri": 1, + "is_active": True, + "downloader": "qb", + "ua": "SecretUA", + "proxy": 1, + "filter": "", + "render": 0, + "public": 0, + "note": {"secret": True}, + "limit_interval": 0, + "limit_count": 0, + "limit_seconds": 0, + "timeout": 15, + } + ] + assert "cookie" not in payload[0] + assert "rss" not in payload[0] + assert "token" not in payload[0] + assert "apikey" not in payload[0] + + +def test_query_sites_keeps_full_fields_for_admin_context(): + """管理员查询站点时保留完整配置视图。""" + tool = QuerySitesTool(session_id="session-1", user_id="admin") + tool.set_agent_context({"is_admin": True}) + site = SimpleNamespace( + id=1, + name="TestSite", + domain="secret.example", + url="https://secret.example/", + pri=1, + rss="https://secret.example/rss", + cookie="uid=1; passkey=secret", + ua="SecretUA", + apikey="site-api-key", + token="site-token", + proxy=1, + filter="", + render=0, + public=0, + note={"secret": True}, + limit_interval=0, + limit_count=0, + limit_seconds=0, + timeout=15, + is_active=True, + downloader="qb", + ) + + with patch( + "app.agent.tools.impl.query_sites.SiteOper" + ) as site_oper: + site_oper.return_value.async_list = AsyncMock(return_value=[site]) + result = asyncio.run(tool.run()) + + payload = json.loads(result) + assert payload[0]["cookie"] == "uid=1; passkey=secret" + assert payload[0]["token"] == "site-token" + assert payload[0]["apikey"] == "site-api-key" + assert payload[0]["url"] == "https://secret.example/" + + +def test_non_admin_file_tools_can_access_config_directory(tmp_path, monkeypatch): + """普通用户可在配置目录内读写和编辑文件。""" + config_path = tmp_path / "config" + monkeypatch.setattr(settings, "CONFIG_DIR", str(config_path)) + memory_path = settings.CONFIG_PATH / "agent" / "memory" / "MEMORY.md" + + write_tool = WriteFileTool(session_id="session-1", user_id="10001") + read_tool = ReadFileTool(session_id="session-1", user_id="10001") + edit_tool = EditFileTool(session_id="session-1", user_id="10001") + + write_result = asyncio.run(write_tool.run(str(memory_path), "hello")) + read_result = asyncio.run(read_tool.run(str(memory_path))) + edit_result = asyncio.run(edit_tool.run(str(memory_path), "hello", "hello mp")) + edited_content = memory_path.read_text(encoding="utf-8") + + assert "成功写入文件" in write_result + assert read_result == "hello" + assert "成功编辑文件" in edit_result + assert edited_content == "hello mp" + + +def test_non_admin_file_tools_block_paths_outside_allowed_roots( + tmp_path, monkeypatch +): + """普通用户不能通过文件工具访问配置、记忆和日志目录外的路径。""" + config_path = tmp_path / "config" + outside_path = tmp_path / "outside.txt" + outside_path.write_text("secret", encoding="utf-8") + monkeypatch.setattr(settings, "CONFIG_DIR", str(config_path)) + + read_tool = ReadFileTool(session_id="session-1", user_id="10001") + write_tool = WriteFileTool(session_id="session-1", user_id="10001") + edit_tool = EditFileTool(session_id="session-1", user_id="10001") + list_tool = ListDirectoryTool(session_id="session-1", user_id="10001") + send_tool = SendLocalFileTool(session_id="session-1", user_id="10001") + send_tool.set_message_attr( + channel=MessageChannel.Telegram.value, + source="telegram-main", + username="normal-user", + ) + + read_result = asyncio.run(read_tool.run(str(outside_path))) + write_result = asyncio.run(write_tool.run(str(outside_path), "changed")) + edit_result = asyncio.run(edit_tool.run(str(outside_path), "secret", "changed")) + with patch.object(ListDirectoryTool, "_list_directory_sync") as list_directory: + list_result = asyncio.run(list_tool.run(str(tmp_path))) + send_result = asyncio.run(send_tool.run(str(outside_path))) + + assert "普通用户只能读取" in read_result + assert "普通用户只能写入" in write_result + assert "普通用户只能编辑" in edit_result + assert "普通用户只能列出" in list_result + assert "普通用户只能发送" in send_result + assert outside_path.read_text(encoding="utf-8") == "secret" + list_directory.assert_not_called() + + +def test_admin_file_tool_can_access_paths_outside_allowed_roots( + tmp_path, monkeypatch +): + """管理员上下文不受普通用户文件访问边界限制。""" + config_path = tmp_path / "config" + outside_path = tmp_path / "outside.txt" + monkeypatch.setattr(settings, "CONFIG_DIR", str(config_path)) + + tool = WriteFileTool(session_id="session-1", user_id="admin") + tool.set_agent_context({"is_admin": True}) + + result = asyncio.run(tool.run(str(outside_path), "admin write")) + + assert "成功写入文件" in result + assert outside_path.read_text(encoding="utf-8") == "admin write" + + +def test_query_downloaders_hides_sensitive_fields_for_non_admin_user(): + """普通用户查询下载器时只返回选择下载器所需的安全字段。""" + tool = QueryDownloadersTool(session_id="session-1", user_id="10001") + downloaders = [ + { + "name": "qb", + "type": "qbittorrent", + "enabled": True, + "host": "http://127.0.0.1", + "port": 8080, + "username": "admin", + "password": "secret", + "apikey": "downloader-api-key", + "token": "downloader-token", + } + ] + + with patch( + "app.agent.tools.impl.query_downloaders.SystemConfigOper" + ) as system_config_oper: + system_config_oper.return_value.get.return_value = downloaders + result = asyncio.run(tool.run()) + + payload = json.loads(result) + assert payload == [ + { + "name": "qb", + "type": "qbittorrent", + "enabled": True, + } + ] + assert "host" not in payload[0] + assert "username" not in payload[0] + assert "password" not in payload[0] + assert "apikey" not in payload[0] + assert "token" not in payload[0] + + +def test_query_downloaders_keeps_full_fields_for_admin_context(): + """管理员查询下载器时保留完整配置视图。""" + tool = QueryDownloadersTool(session_id="session-1", user_id="admin") + tool.set_agent_context({"is_admin": True}) + downloaders = [ + { + "name": "qb", + "type": "qbittorrent", + "enabled": True, + "host": "http://127.0.0.1", + "username": "admin", + "password": "secret", + "apikey": "downloader-api-key", + } + ] + + with patch( + "app.agent.tools.impl.query_downloaders.SystemConfigOper" + ) as system_config_oper: + system_config_oper.return_value.get.return_value = downloaders + result = asyncio.run(tool.run()) + + payload = json.loads(result) + assert payload[0]["host"] == "http://127.0.0.1" + assert payload[0]["username"] == "admin" + assert payload[0]["password"] == "secret" + assert payload[0]["apikey"] == "downloader-api-key" + + +def test_channel_agent_admin_user_id_does_not_bypass_user_lookup(): + """渠道用户 ID 恰好为 admin 时,不应绕过真实系统用户权限判断。""" + agent = MoviePilotAgent( + session_id="session-1", + user_id="admin", + channel=MessageChannel.Telegram.value, + source="telegram-main", + username="normal-user", + ) + + with patch("app.agent.UserOper") as user_oper: + user_oper.return_value.async_get_by_name.return_value = SimpleNamespace( + is_superuser=False + ) + context = asyncio.run( + agent._build_tool_context(should_dispatch_reply=True) + ) + + assert context["is_admin"] is False diff --git a/tests/test_agent_subagents.py b/tests/test_agent_subagents.py index 2a78810c..50df5f23 100644 --- a/tests/test_agent_subagents.py +++ b/tests/test_agent_subagents.py @@ -1,6 +1,5 @@ import asyncio import json -import unittest from pathlib import Path from types import SimpleNamespace from unittest.mock import patch @@ -19,134 +18,132 @@ from app.agent.middleware.subagents import ( from app.agent.tools.tags import ToolTag -class TestAgentSubagents(unittest.TestCase): - def test_create_subagent_middlewares_registers_task_tool(self): - """子代理中间件应向主 Agent 注册 task 委派工具。""" - model = FakeListChatModel(responses=["ok"]) +def test_create_subagent_middlewares_registers_task_tool(): + """子代理中间件应向主 Agent 注册 task 委派工具。""" + model = FakeListChatModel(responses=["ok"]) - middlewares, task_tools = create_subagent_middlewares( - model=model, - tools=[], - stream_handler=None, - ) + middlewares, task_tools = create_subagent_middlewares( + model=model, + tools=[], + stream_handler=None, + ) - self.assertEqual(len(middlewares), 3) - self.assertEqual( - [tool.name for tool in task_tools], - [SUBAGENT_TASK_TOOL_NAME, SUBAGENT_CONTROL_TOOL_NAME], - ) - self.assertIn("media-researcher", task_tools[0].description) - self.assertIn("moviepilot-explorer", task_tools[0].description) - self.assertIn("system-diagnostician", task_tools[0].description) - self.assertIn("action=start", task_tools[1].description) - self.assertIn("action=wait", task_tools[1].description) - - def test_subagent_tools_are_selected_by_tags(self): - """子代理应根据工具标签筛选工具,而不是依赖工具名名单。""" - model = FakeListChatModel(responses=["ok"]) - tools = [ - SimpleNamespace( - name="custom_media_lookup", - tags=[ToolTag.Read.value, ToolTag.Media.value], - ), - SimpleNamespace( - name="custom_media_writer", - tags=[ToolTag.Read.value, ToolTag.Write.value, ToolTag.Media.value], - ), - SimpleNamespace( - name="custom_site_lookup", - tags=[ToolTag.Read.value, ToolTag.Site.value], - ), - ] - captured = {} - - def _fake_create_agent(**kwargs): - captured.update(kwargs) - return kwargs - - middleware = MoviePilotSubAgentMiddleware( - model=model, - profiles=subagent_module._builtin_subagent_profiles(), - tools=tools, - ) - - with patch.object(subagent_module, "create_agent", side_effect=_fake_create_agent): - middleware._get_agent("media-researcher") - - self.assertEqual( - [tool.name for tool in captured["tools"]], - ["custom_media_lookup"], - ) - - def test_moviepilot_explorer_selects_code_and_settings_tools(self): - """MoviePilot 探索子代理应能读取代码、目录、设置和命令诊断工具。""" - model = FakeListChatModel(responses=["ok"]) - tools = [ - SimpleNamespace( - name="custom_code_reader", - tags=[ToolTag.Read.value, ToolTag.File.value], - ), - SimpleNamespace( - name="custom_directory_lister", - tags=[ToolTag.Read.value, ToolTag.Directory.value], - ), - SimpleNamespace( - name="custom_settings_reader", - tags=[ToolTag.Read.value, ToolTag.Settings.value], - ), - SimpleNamespace( - name="custom_command_runner", - tags=[ToolTag.Read.value, ToolTag.Command.value], - ), - SimpleNamespace( - name="custom_code_writer", - tags=[ToolTag.Read.value, ToolTag.Write.value, ToolTag.File.value], - ), - ] - captured = {} - - def _fake_create_agent(**kwargs): - captured.update(kwargs) - return kwargs - - middleware = MoviePilotSubAgentMiddleware( - model=model, - profiles=subagent_module._builtin_subagent_profiles(), - tools=tools, - ) - - with patch.object(subagent_module, "create_agent", side_effect=_fake_create_agent): - middleware._get_agent("moviepilot-explorer") - - self.assertEqual( - [tool.name for tool in captured["tools"]], - [ - "custom_code_reader", - "custom_directory_lister", - "custom_settings_reader", - "custom_command_runner", - ], - ) - - def test_builtin_tools_declare_tags_in_implementation(self): - """所有内置工具实现都应显式声明 tags。""" - impl_dir = Path(__file__).resolve().parents[1] / "app" / "agent" / "tools" / "impl" - missing_tools = [] - for path in sorted(impl_dir.glob("*.py")): - text = path.read_text() - for block in text.split("\nclass "): - if "(MoviePilotTool)" not in block: - continue - class_name = block.split("(", 1)[0].strip() - if "tags: list[str]" not in block: - missing_tools.append(f"{path.name}:{class_name}") - - self.assertEqual([], missing_tools) + assert len(middlewares) == 3 + assert [tool.name for tool in task_tools] == [ + SUBAGENT_TASK_TOOL_NAME, + SUBAGENT_CONTROL_TOOL_NAME, + ] + assert "media-researcher" in task_tools[0].description + assert "moviepilot-explorer" in task_tools[0].description + assert "system-diagnostician" in task_tools[0].description + assert "action=start" in task_tools[1].description + assert "action=wait" in task_tools[1].description + assert "action=pipeline" in task_tools[1].description -class TestSubAgentTaskControlMiddleware(unittest.IsolatedAsyncioTestCase): - async def test_call_summary_middleware_logs_subagent_tool_operations(self): - """子代理工具包装层应记录工具执行开始和完成日志。""" +def test_subagent_tools_are_selected_by_tags(): + """子代理应根据工具标签筛选工具,而不是依赖工具名名单。""" + model = FakeListChatModel(responses=["ok"]) + tools = [ + SimpleNamespace( + name="custom_media_lookup", + tags=[ToolTag.Read.value, ToolTag.Media.value], + ), + SimpleNamespace( + name="custom_media_writer", + tags=[ToolTag.Read.value, ToolTag.Write.value, ToolTag.Media.value], + ), + SimpleNamespace( + name="custom_site_lookup", + tags=[ToolTag.Read.value, ToolTag.Site.value], + ), + ] + captured = {} + + def _fake_create_agent(**kwargs): + captured.update(kwargs) + return kwargs + + middleware = MoviePilotSubAgentMiddleware( + model=model, + profiles=subagent_module._builtin_subagent_profiles(), + tools=tools, + ) + + with patch.object(subagent_module, "create_agent", side_effect=_fake_create_agent): + middleware._get_agent("media-researcher") + + assert [tool.name for tool in captured["tools"]] == ["custom_media_lookup"] + + +def test_moviepilot_explorer_selects_code_and_settings_tools(): + """MoviePilot 探索子代理应能读取代码、目录、设置和命令诊断工具。""" + model = FakeListChatModel(responses=["ok"]) + tools = [ + SimpleNamespace( + name="custom_code_reader", + tags=[ToolTag.Read.value, ToolTag.File.value], + ), + SimpleNamespace( + name="custom_directory_lister", + tags=[ToolTag.Read.value, ToolTag.Directory.value], + ), + SimpleNamespace( + name="custom_settings_reader", + tags=[ToolTag.Read.value, ToolTag.Settings.value], + ), + SimpleNamespace( + name="custom_command_runner", + tags=[ToolTag.Read.value, ToolTag.Command.value], + ), + SimpleNamespace( + name="custom_code_writer", + tags=[ToolTag.Read.value, ToolTag.Write.value, ToolTag.File.value], + ), + ] + captured = {} + + def _fake_create_agent(**kwargs): + captured.update(kwargs) + return kwargs + + middleware = MoviePilotSubAgentMiddleware( + model=model, + profiles=subagent_module._builtin_subagent_profiles(), + tools=tools, + ) + + with patch.object(subagent_module, "create_agent", side_effect=_fake_create_agent): + middleware._get_agent("moviepilot-explorer") + + assert [tool.name for tool in captured["tools"]] == [ + "custom_code_reader", + "custom_directory_lister", + "custom_settings_reader", + "custom_command_runner", + ] + + +def test_builtin_tools_declare_tags_in_implementation(): + """所有内置工具实现都应显式声明 tags。""" + impl_dir = Path(__file__).resolve().parents[1] / "app" / "agent" / "tools" / "impl" + missing_tools = [] + for path in sorted(impl_dir.glob("*.py")): + text = path.read_text() + for block in text.split("\nclass "): + if "(MoviePilotTool)" not in block: + continue + class_name = block.split("(", 1)[0].strip() + if "tags: list[str]" not in block: + missing_tools.append(f"{path.name}:{class_name}") + + assert missing_tools == [] + + +def test_call_summary_middleware_logs_subagent_tool_operations(): + """子代理工具包装层应记录工具执行开始和完成日志。""" + + async def _run_test(): middleware = SubAgentCallSummaryMiddleware() request = SimpleNamespace( tool=SimpleNamespace(name=SUBAGENT_CONTROL_TOOL_NAME), @@ -165,12 +162,17 @@ class TestSubAgentTaskControlMiddleware(unittest.IsolatedAsyncioTestCase): result = await middleware.awrap_tool_call(request, _fake_handler) messages = [call.args[0] for call in log_info.call_args_list] - self.assertEqual("ok", result) - self.assertTrue(any("开始执行子代理工具" in message for message in messages)) - self.assertTrue(any("子代理工具执行完成" in message for message in messages)) + assert result == "ok" + assert any("开始执行子代理工具" in message for message in messages) + assert any("子代理工具执行完成" in message for message in messages) - async def test_control_tool_starts_tasks_concurrently_and_waits(self): - """异步子代理管控工具应批量启动任务,并在 wait 时收集结果。""" + asyncio.run(_run_test()) + + +def test_control_tool_starts_tasks_concurrently_and_waits(): + """异步子代理管控工具应批量启动任务,并在 wait 时收集结果。""" + + async def _run_test(): model = FakeListChatModel(responses=["ok"]) middleware = SubAgentTaskControlMiddleware( model=model, @@ -221,21 +223,154 @@ class TestSubAgentTaskControlMiddleware(unittest.IsolatedAsyncioTestCase): ) ) - self.assertTrue(start_payload["success"]) - self.assertEqual(2, len(task_ids)) - self.assertEqual(["检查媒体库", "检查下载器"], running_descriptions) - self.assertEqual( - ["completed", "completed"], - [task["status"] for task in wait_payload["tasks"]], - ) - self.assertIn("media-researcher:检查媒体库", wait_payload["tasks"][0]["result"]) - self.assertIn( - "download-diagnostician:检查下载器", - wait_payload["tasks"][1]["result"], + assert start_payload["success"] + assert len(task_ids) == 2 + assert running_descriptions == ["检查媒体库", "检查下载器"] + assert [task["status"] for task in wait_payload["tasks"]] == [ + "completed", + "completed", + ] + assert "media-researcher:检查媒体库" in wait_payload["tasks"][0]["result"] + assert ( + "download-diagnostician:检查下载器" + in wait_payload["tasks"][1]["result"] ) - async def test_after_agent_cancels_unfinished_tasks(self): - """Agent 结束时应取消仍在运行的异步子代理任务。""" + asyncio.run(_run_test()) + + +def test_control_tool_pipeline_passes_previous_results_to_next_step(): + """管道模式应顺序执行子代理,并把上一步结果作为下一步私有上下文。""" + + async def _run_test(): + model = FakeListChatModel(responses=["ok"]) + middleware = SubAgentTaskControlMiddleware( + model=model, + profiles=subagent_module._builtin_subagent_profiles(), + tools=[], + ) + calls = [] + + async def _fake_run_task(self, *, description, subagent_type, task_id=None): + calls.append( + { + "description": description, + "subagent_type": subagent_type, + "task_id": task_id, + } + ) + return f"结果-{len(calls)}" + + with patch.object( + subagent_module._SubAgentAgentProvider, + "run_task", + new=_fake_run_task, + ): + payload = json.loads( + await middleware._control_task( + action="pipeline", + tasks=[ + { + "description": "识别媒体", + "subagent_type": "media-researcher", + }, + { + "description": "检查下载", + "subagent_type": "download-diagnostician", + }, + { + "description": "汇总结论", + "subagent_type": "general-purpose", + }, + ], + timeout_ms=1000, + ) + ) + + assert payload["success"] + assert [call["subagent_type"] for call in calls] == [ + "media-researcher", + "download-diagnostician", + "general-purpose", + ] + assert calls[0]["description"] == "识别媒体" + assert "结果-1" in calls[1]["description"] + assert "结果-1" in calls[2]["description"] + assert "结果-2" in calls[2]["description"] + assert [task["status"] for task in payload["tasks"]] == [ + "completed", + "completed", + "completed", + ] + assert [task["result"] for task in payload["tasks"]] == [ + "结果-1", + "结果-2", + "结果-3", + ] + + asyncio.run(_run_test()) + + +def test_control_tool_pipeline_stops_after_failed_step(): + """管道模式遇到失败步骤时应中断后续子代理。""" + + async def _run_test(): + model = FakeListChatModel(responses=["ok"]) + middleware = SubAgentTaskControlMiddleware( + model=model, + profiles=subagent_module._builtin_subagent_profiles(), + tools=[], + ) + calls = [] + + async def _fake_run_task(self, *, description, subagent_type, task_id=None): + calls.append(subagent_type) + if subagent_type == "download-diagnostician": + raise RuntimeError("下载器不可用") + return f"{subagent_type}:ok" + + with patch.object( + subagent_module._SubAgentAgentProvider, + "run_task", + new=_fake_run_task, + ): + payload = json.loads( + await middleware._control_task( + action="pipeline", + tasks=[ + { + "description": "识别媒体", + "subagent_type": "media-researcher", + }, + { + "description": "检查下载", + "subagent_type": "download-diagnostician", + }, + { + "description": "汇总结论", + "subagent_type": "general-purpose", + }, + ], + timeout_ms=1000, + ) + ) + + assert not payload["success"] + assert "第 2 个管道子代理任务执行失败" in payload["error"] + assert calls == ["media-researcher", "download-diagnostician"] + assert [task["status"] for task in payload["tasks"]] == [ + "completed", + "failed", + ] + assert "下载器不可用" in payload["tasks"][1]["error"] + + asyncio.run(_run_test()) + + +def test_after_agent_cancels_unfinished_tasks(): + """Agent 结束时应取消仍在运行的异步子代理任务。""" + + async def _run_test(): model = FakeListChatModel(responses=["ok"]) middleware = SubAgentTaskControlMiddleware( model=model, @@ -269,4 +404,6 @@ class TestSubAgentTaskControlMiddleware(unittest.IsolatedAsyncioTestCase): ) ) - self.assertEqual("cancelled", status_payload["tasks"][0]["status"]) + assert status_payload["tasks"][0]["status"] == "cancelled" + + asyncio.run(_run_test())