mirror of
https://github.com/jxxghp/MoviePilot.git
synced 2026-06-17 13:41:57 +08:00
feat: enhance user permissions handling for admin and non-admin contexts
This commit is contained in:
@@ -50,6 +50,7 @@ from app.agent.tools.factory import MoviePilotToolFactory
|
||||
from app.chain import ChainBase
|
||||
from app.core.config import settings
|
||||
from app.core.event import eventmanager
|
||||
from app.db.user_oper import UserOper
|
||||
from app.log import logger
|
||||
from app.schemas import AgentLLMProviderEventData, AgentTokensUsageEventData, Notification, NotificationType
|
||||
from app.schemas.message import ChannelCapabilityManager, ChannelCapability
|
||||
@@ -418,6 +419,38 @@ class MoviePilotAgent:
|
||||
"""
|
||||
return self.session_id.startswith(HEARTBEAT_SESSION_PREFIX)
|
||||
|
||||
async def _is_system_admin_context(self) -> bool:
|
||||
"""
|
||||
判断当前 Agent 会话是否应按系统管理员上下文运行工具。
|
||||
"""
|
||||
if self.is_background:
|
||||
return True
|
||||
if self.channel == MessageChannel.Web.value and self.source in {
|
||||
"openai",
|
||||
"openai.responses",
|
||||
"anthropic",
|
||||
}:
|
||||
return True
|
||||
if not self.username:
|
||||
return False
|
||||
try:
|
||||
user = await UserOper().async_get_by_name(self.username)
|
||||
except Exception as e:
|
||||
logger.error(f"检查 Agent 用户管理员身份失败: {e}")
|
||||
return False
|
||||
return bool(user and user.is_superuser)
|
||||
|
||||
async def _build_tool_context(self, should_dispatch_reply: bool) -> Dict[str, object]:
|
||||
"""
|
||||
构造本轮工具共享上下文。
|
||||
"""
|
||||
return {
|
||||
"user_reply_sent": False,
|
||||
"reply_mode": None,
|
||||
"should_dispatch_reply": should_dispatch_reply,
|
||||
"is_admin": await self._is_system_admin_context(),
|
||||
}
|
||||
|
||||
def _should_stream(self) -> bool:
|
||||
"""
|
||||
判断是否应启用流式输出:
|
||||
@@ -804,6 +837,7 @@ class MoviePilotAgent:
|
||||
"user_reply_sent": False,
|
||||
"reply_mode": None,
|
||||
"should_dispatch_reply": False,
|
||||
"is_admin": bool(self._tool_context.get("is_admin")),
|
||||
},
|
||||
allow_message_tools=False,
|
||||
)
|
||||
@@ -920,11 +954,9 @@ class MoviePilotAgent:
|
||||
f"images={len(images) if images else 0}, files={len(files) if files else 0}, "
|
||||
f"audio_input={has_audio_input}"
|
||||
)
|
||||
self._tool_context = {
|
||||
"user_reply_sent": False,
|
||||
"reply_mode": None,
|
||||
"should_dispatch_reply": self.should_dispatch_reply,
|
||||
}
|
||||
self._tool_context = await self._build_tool_context(
|
||||
should_dispatch_reply=self.should_dispatch_reply
|
||||
)
|
||||
self._streamed_output = ""
|
||||
|
||||
# 获取历史消息
|
||||
|
||||
@@ -39,6 +39,7 @@ SUBAGENT_MAX_ACTIVE_TASKS = 8
|
||||
SUBAGENT_MAX_CONCURRENT_TASKS = 4
|
||||
SUBAGENT_RESULT_MAX_CHARS = 12000
|
||||
SUBAGENT_DESCRIPTION_MAX_CHARS = 500
|
||||
SUBAGENT_PIPELINE_CONTEXT_MAX_CHARS = 12000
|
||||
|
||||
SUBAGENT_PARENT_PROMPT = """<subagents>
|
||||
You may use subagent tools to delegate independent research, retrieval,
|
||||
@@ -51,6 +52,9 @@ Delegation modes:
|
||||
`action=wait`, or `action=cancel` with the returned task IDs.
|
||||
- Use `subagent_task` with `action=run` when you want to launch a bounded
|
||||
batch and wait for the batch in one tool call.
|
||||
- Use `subagent_task` with `action=pipeline` when later subtasks must use
|
||||
previous subagent results. Pipeline steps run sequentially, and each step's
|
||||
result is passed as private context to the next step.
|
||||
|
||||
Rules:
|
||||
- Delegate when a task benefits from focused investigation, such as media identity checks, site/resource search, subscription analysis, download/transfer diagnosis, MoviePilot code/config exploration, or read-only system inspection.
|
||||
@@ -71,7 +75,9 @@ SUBAGENT_CONTROL_DESCRIPTION = (
|
||||
"Use action=start with tasks=[{description, subagent_type}] to launch a batch "
|
||||
"and get task IDs immediately. Use action=status to inspect tasks, action=wait "
|
||||
"to wait for all or any task result, action=cancel to stop running tasks, and "
|
||||
"action=run to launch a bounded batch and wait in one call."
|
||||
"action=run to launch a bounded batch and wait in one call. Use action=pipeline "
|
||||
"to run tasks sequentially while passing each result as private context to the "
|
||||
"next task."
|
||||
)
|
||||
|
||||
SUBAGENT_BASE_PROMPT = """You are a silent subagent working for the MoviePilot main agent.
|
||||
@@ -120,9 +126,9 @@ class _SubAgentTaskSpec(BaseModel):
|
||||
class _SubAgentControlInput(BaseModel):
|
||||
"""异步子代理管控工具输入。"""
|
||||
|
||||
action: Literal["start", "status", "wait", "cancel", "run"] = Field(
|
||||
action: Literal["start", "status", "wait", "cancel", "run", "pipeline"] = Field(
|
||||
default="start",
|
||||
description="Task action: start, status, wait, cancel, or run.",
|
||||
description="Task action: start, status, wait, cancel, run, or pipeline.",
|
||||
)
|
||||
description: Optional[str] = Field(
|
||||
default=None,
|
||||
@@ -150,7 +156,10 @@ class _SubAgentControlInput(BaseModel):
|
||||
)
|
||||
timeout_ms: Optional[int] = Field(
|
||||
default=SUBAGENT_DEFAULT_WAIT_TIMEOUT_MS,
|
||||
description="Maximum wait time in milliseconds for action=wait or action=run.",
|
||||
description=(
|
||||
"Maximum wait time in milliseconds for action=wait, action=run, "
|
||||
"or each action=pipeline step."
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
@@ -742,7 +751,8 @@ class SubAgentTaskControlMiddleware(AgentMiddleware):
|
||||
f"pending={len(pending_tasks) - finished_count}"
|
||||
)
|
||||
|
||||
async def _cancel_records(self, records: list[_SubAgentRuntimeTask]) -> None:
|
||||
@staticmethod
|
||||
async def _cancel_records(records: list[_SubAgentRuntimeTask]) -> None:
|
||||
"""取消一组尚未完成的任务。"""
|
||||
cancellable_tasks = [
|
||||
record.task for record in records if not record.task.done()
|
||||
@@ -755,6 +765,156 @@ class SubAgentTaskControlMiddleware(AgentMiddleware):
|
||||
await asyncio.gather(*cancellable_tasks, return_exceptions=True)
|
||||
logger.info(f"子代理任务取消完成: tasks={len(cancellable_tasks)}")
|
||||
|
||||
@staticmethod
|
||||
def _pipeline_description(
|
||||
*,
|
||||
description: str,
|
||||
previous_results: list[tuple[_SubAgentRuntimeTask, str]],
|
||||
) -> str:
|
||||
"""追加上游子代理结果,生成当前管道步骤的任务描述。"""
|
||||
normalized_description = description.strip()
|
||||
if not previous_results:
|
||||
return normalized_description
|
||||
|
||||
context_parts = []
|
||||
for step_index, (record, result) in enumerate(previous_results, start=1):
|
||||
clipped_result, result_truncated = _clip_text(
|
||||
result,
|
||||
SUBAGENT_RESULT_MAX_CHARS,
|
||||
)
|
||||
truncated_note = "\n[Result truncated]" if result_truncated else ""
|
||||
context_parts.append(
|
||||
f"Step {step_index} ({record.subagent_type}) result:\n"
|
||||
f"{clipped_result}{truncated_note}"
|
||||
)
|
||||
context_text, context_truncated = _clip_text(
|
||||
"\n\n".join(context_parts),
|
||||
SUBAGENT_PIPELINE_CONTEXT_MAX_CHARS,
|
||||
)
|
||||
truncated_note = "\n[Pipeline context truncated]" if context_truncated else ""
|
||||
return (
|
||||
f"{normalized_description}\n\n"
|
||||
"<pipeline_context>\n"
|
||||
"Previous subagent results are private context for this delegated "
|
||||
"subtask. Use them to complete the current task, but do not expose "
|
||||
"the prior reports verbatim.\n\n"
|
||||
f"{context_text}{truncated_note}\n"
|
||||
"</pipeline_context>"
|
||||
)
|
||||
|
||||
async def _execute_pipeline_task(
|
||||
self,
|
||||
*,
|
||||
record: _SubAgentRuntimeTask,
|
||||
description: str,
|
||||
) -> str:
|
||||
"""执行单个管道步骤,保留原始步骤描述用于状态展示。"""
|
||||
async with self._semaphore:
|
||||
record.started_at = datetime.now()
|
||||
logger.info(
|
||||
f"管道子代理任务开始执行: task_id={record.task_id}, "
|
||||
f"subagent_type={record.subagent_type}"
|
||||
)
|
||||
try:
|
||||
result = await self._provider.run_task(
|
||||
description=description,
|
||||
subagent_type=record.subagent_type,
|
||||
task_id=record.task_id,
|
||||
)
|
||||
logger.info(
|
||||
f"管道子代理任务执行完成: task_id={record.task_id}, "
|
||||
f"subagent_type={record.subagent_type}, result_chars={len(result)}"
|
||||
)
|
||||
return result
|
||||
except asyncio.CancelledError:
|
||||
logger.info(
|
||||
f"管道子代理任务已取消: task_id={record.task_id}, "
|
||||
f"subagent_type={record.subagent_type}"
|
||||
)
|
||||
raise
|
||||
except Exception as err:
|
||||
logger.error(f"管道子代理任务执行失败: task_id={record.task_id}, error={err}")
|
||||
raise
|
||||
|
||||
@staticmethod
|
||||
def _create_pipeline_record(
|
||||
spec: _SubAgentTaskSpec,
|
||||
) -> _SubAgentRuntimeTask:
|
||||
"""创建一个管道步骤记录。"""
|
||||
task_id = f"subagent-{uuid.uuid4().hex[:12]}"
|
||||
return _SubAgentRuntimeTask(
|
||||
task_id=task_id,
|
||||
description=spec.description.strip(),
|
||||
subagent_type=spec.subagent_type or "general-purpose",
|
||||
task=None,
|
||||
created_at=datetime.now(),
|
||||
)
|
||||
|
||||
def _track_pipeline_task(
|
||||
self,
|
||||
record: _SubAgentRuntimeTask,
|
||||
task: asyncio.Task,
|
||||
) -> None:
|
||||
"""登记管道步骤任务,复用统一的状态和异常收口逻辑。"""
|
||||
record.task = task
|
||||
task.add_done_callback(
|
||||
lambda finished_task, finished_task_id=record.task_id: self._mark_task_finished(
|
||||
finished_task_id,
|
||||
finished_task,
|
||||
)
|
||||
)
|
||||
self._tasks[record.task_id] = record
|
||||
|
||||
async def _run_pipeline(
|
||||
self,
|
||||
specs: list[_SubAgentTaskSpec],
|
||||
timeout_ms: Optional[int],
|
||||
) -> tuple[list[_SubAgentRuntimeTask], Optional[str]]:
|
||||
"""按顺序执行管道任务,并把每一步结果传给下一步。"""
|
||||
normalized_timeout_ms = self._normalize_timeout_ms(timeout_ms)
|
||||
if normalized_timeout_ms <= 0:
|
||||
return [], "管道任务需要大于 0 的等待时间。"
|
||||
|
||||
records: list[_SubAgentRuntimeTask] = []
|
||||
previous_results: list[tuple[_SubAgentRuntimeTask, str]] = []
|
||||
timeout = normalized_timeout_ms / 1000
|
||||
for step_index, spec in enumerate(specs, start=1):
|
||||
record = self._create_pipeline_record(spec)
|
||||
records.append(record)
|
||||
pipeline_description = self._pipeline_description(
|
||||
description=record.description,
|
||||
previous_results=previous_results,
|
||||
)
|
||||
task = asyncio.create_task(
|
||||
self._execute_pipeline_task(
|
||||
record=record,
|
||||
description=pipeline_description,
|
||||
),
|
||||
name=record.task_id,
|
||||
)
|
||||
self._track_pipeline_task(record, task)
|
||||
logger.info(
|
||||
f"已启动管道子代理任务: step={step_index}, task_id={record.task_id}, "
|
||||
f"subagent_type={record.subagent_type}"
|
||||
)
|
||||
|
||||
try:
|
||||
result = await asyncio.wait_for(task, timeout=timeout)
|
||||
except asyncio.TimeoutError:
|
||||
error = f"第 {step_index} 个管道子代理任务等待超时。"
|
||||
logger.info(
|
||||
f"{error} task_id={record.task_id}, timeout_ms={normalized_timeout_ms}"
|
||||
)
|
||||
return records, error
|
||||
except Exception as err:
|
||||
error = f"第 {step_index} 个管道子代理任务执行失败: {err}"
|
||||
logger.info(f"{error} task_id={record.task_id}")
|
||||
return records, error
|
||||
|
||||
previous_results.append((record, result))
|
||||
|
||||
return records, None
|
||||
|
||||
async def _control_task(
|
||||
self,
|
||||
action: str = "start",
|
||||
@@ -768,7 +928,7 @@ class SubAgentTaskControlMiddleware(AgentMiddleware):
|
||||
) -> str:
|
||||
"""管理异步子代理任务。"""
|
||||
logger.info(f"收到子代理管控操作: action={action}")
|
||||
if action in {"start", "run"}:
|
||||
if action in {"start", "run", "pipeline"}:
|
||||
specs, error = self._normalize_specs(
|
||||
description=description,
|
||||
subagent_type=subagent_type,
|
||||
@@ -779,6 +939,20 @@ class SubAgentTaskControlMiddleware(AgentMiddleware):
|
||||
return self._json_response({"success": False, "error": error})
|
||||
|
||||
logger.info(f"准备启动子代理任务: action={action}, tasks={len(specs)}")
|
||||
if action == "pipeline":
|
||||
records, pipeline_error = await self._run_pipeline(
|
||||
specs=specs,
|
||||
timeout_ms=timeout_ms,
|
||||
)
|
||||
return self._json_response(
|
||||
{
|
||||
"success": pipeline_error is None,
|
||||
"action": action,
|
||||
"error": pipeline_error,
|
||||
"tasks": [self._task_output(record) for record in records],
|
||||
}
|
||||
)
|
||||
|
||||
records = self._start_tasks(specs)
|
||||
if action == "run":
|
||||
await self._wait_records(
|
||||
|
||||
@@ -4,6 +4,7 @@ import threading
|
||||
from abc import ABCMeta, abstractmethod
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
from functools import partial
|
||||
from pathlib import Path
|
||||
from typing import Any, Callable, ClassVar, Optional
|
||||
|
||||
from langchain_core.tools import BaseTool
|
||||
@@ -373,6 +374,119 @@ class MoviePilotTool(BaseTool, metaclass=ABCMeta):
|
||||
# 独立的新 dict,跨工具状态(例如质量门槛拒绝标记)无法传播。
|
||||
self._agent_context = {} if agent_context is None else agent_context
|
||||
|
||||
async def is_admin_user(self) -> bool:
|
||||
"""
|
||||
判断当前工具调用者是否拥有管理员级权限。
|
||||
|
||||
:return: 当前调用者是系统管理员、渠道管理员或显式管理员上下文时返回 True
|
||||
"""
|
||||
if bool(self._agent_context.get("is_admin")):
|
||||
return True
|
||||
|
||||
if not self._channel or not self._source:
|
||||
return False
|
||||
|
||||
return await self._has_channel_admin_permission()
|
||||
|
||||
@staticmethod
|
||||
def _resolve_local_path(path: str) -> Path:
|
||||
"""
|
||||
解析本地路径并展开符号链接。
|
||||
|
||||
:param path: 用户传入的本地文件或目录路径
|
||||
:return: 规范化后的绝对路径
|
||||
"""
|
||||
return Path(path).expanduser().resolve(strict=False)
|
||||
|
||||
@staticmethod
|
||||
def _is_path_relative_to(path: Path, root: Path) -> bool:
|
||||
"""
|
||||
判断路径是否位于指定目录内。
|
||||
|
||||
:param path: 待检查路径
|
||||
:param root: 允许访问的根目录
|
||||
:return: 路径在根目录内或等于根目录时返回 True
|
||||
"""
|
||||
try:
|
||||
path.relative_to(root)
|
||||
return True
|
||||
except ValueError:
|
||||
return False
|
||||
|
||||
@classmethod
|
||||
def _get_non_admin_local_file_roots(cls) -> list[Path]:
|
||||
"""
|
||||
获取普通用户可访问的本地文件根目录。
|
||||
|
||||
:return: 普通用户允许读写的本地目录列表
|
||||
"""
|
||||
roots = [
|
||||
settings.CONFIG_PATH,
|
||||
settings.LOG_PATH,
|
||||
settings.CONFIG_PATH / "agent" / "memory",
|
||||
settings.CONFIG_PATH / "agent" / "activity",
|
||||
]
|
||||
resolved_roots = []
|
||||
for root in roots:
|
||||
resolved_root = cls._resolve_local_path(str(root))
|
||||
if resolved_root not in resolved_roots:
|
||||
resolved_roots.append(resolved_root)
|
||||
return resolved_roots
|
||||
|
||||
async def _check_local_file_access(
|
||||
self, path: str, operation: str = "访问"
|
||||
) -> tuple[Optional[Path], Optional[str]]:
|
||||
"""
|
||||
检查当前用户是否可访问指定本地路径。
|
||||
|
||||
:param path: 用户传入的本地文件或目录路径
|
||||
:param operation: 当前操作名称,用于生成拒绝提示
|
||||
:return: 解析后的路径和拒绝原因;拒绝原因为空表示允许访问
|
||||
"""
|
||||
if not path:
|
||||
return None, "错误:路径不能为空"
|
||||
|
||||
resolved_path = self._resolve_local_path(path)
|
||||
if await self.is_admin_user():
|
||||
return resolved_path, None
|
||||
|
||||
allowed_roots = self._get_non_admin_local_file_roots()
|
||||
if any(
|
||||
self._is_path_relative_to(resolved_path, root)
|
||||
for root in allowed_roots
|
||||
):
|
||||
return resolved_path, None
|
||||
|
||||
allowed_text = "、".join(str(root) for root in allowed_roots)
|
||||
return (
|
||||
resolved_path,
|
||||
f"抱歉,普通用户只能{operation}配置目录、Agent记忆目录和日志目录内的文件或目录:{allowed_text}",
|
||||
)
|
||||
|
||||
async def _check_local_storage_access(
|
||||
self,
|
||||
path: str,
|
||||
storage: Optional[str] = "local",
|
||||
operation: str = "访问",
|
||||
) -> tuple[Optional[Path], Optional[str]]:
|
||||
"""
|
||||
检查当前用户是否可访问指定存储路径。
|
||||
|
||||
:param path: 用户传入的文件或目录路径
|
||||
:param storage: 存储类型,普通用户只允许 local
|
||||
:param operation: 当前操作名称,用于生成拒绝提示
|
||||
:return: 本地存储时返回解析后的路径和拒绝原因;远程存储无本地路径
|
||||
"""
|
||||
if (storage or "local") != "local":
|
||||
if await self.is_admin_user():
|
||||
return None, None
|
||||
return (
|
||||
None,
|
||||
f"抱歉,普通用户只能{operation}本地配置目录、Agent记忆目录和日志目录,不能访问远程存储。",
|
||||
)
|
||||
|
||||
return await self._check_local_file_access(path=path, operation=operation)
|
||||
|
||||
async def _check_permission(self) -> Optional[str]:
|
||||
"""
|
||||
检查用户权限:
|
||||
@@ -385,9 +499,28 @@ class MoviePilotTool(BaseTool, metaclass=ABCMeta):
|
||||
if not self._require_admin:
|
||||
return None
|
||||
|
||||
if await self.is_admin_user():
|
||||
return None
|
||||
|
||||
if not self._channel or not self._source:
|
||||
return None
|
||||
|
||||
return (
|
||||
"抱歉,您没有执行此工具的权限。"
|
||||
"只有渠道管理员或系统管理员才能执行工具操作。"
|
||||
"如需执行工具,请联系渠道管理员将您的用户ID添加到渠道管理员列表中,"
|
||||
"或联系系统管理员为您设置权限。"
|
||||
)
|
||||
|
||||
async def _has_channel_admin_permission(self) -> bool:
|
||||
"""
|
||||
检查当前消息渠道身份是否具备管理员权限。
|
||||
|
||||
:return: 当前渠道用户是渠道管理员、系统管理员或默认接收人时返回 True
|
||||
"""
|
||||
if not self._channel or not self._source:
|
||||
return False
|
||||
|
||||
# 渠道配置来自 SystemConfigOper 内存缓存,可以直接读取;
|
||||
# 只有用户信息需要走异步数据库查询。
|
||||
user_id_str = str(self._user_id) if self._user_id else None
|
||||
@@ -411,7 +544,7 @@ class MoviePilotTool(BaseTool, metaclass=ABCMeta):
|
||||
break
|
||||
|
||||
if not channel_type:
|
||||
return None
|
||||
return False
|
||||
|
||||
admin_key_map = {
|
||||
"telegram": "TELEGRAM_ADMINS",
|
||||
@@ -451,7 +584,7 @@ class MoviePilotTool(BaseTool, metaclass=ABCMeta):
|
||||
if aid.strip()
|
||||
]
|
||||
if user_id_str and user_id_str in admin_list:
|
||||
return None
|
||||
return True
|
||||
|
||||
user = (
|
||||
await UserOper().async_get_by_name(self._username)
|
||||
@@ -459,14 +592,9 @@ class MoviePilotTool(BaseTool, metaclass=ABCMeta):
|
||||
else None
|
||||
)
|
||||
if user and user.is_superuser:
|
||||
return None
|
||||
return True
|
||||
|
||||
return (
|
||||
"抱歉,您没有执行此工具的权限。"
|
||||
"只有渠道管理员或系统管理员才能执行工具操作。"
|
||||
"如需执行工具,请联系渠道管理员将您的用户ID添加到渠道管理员列表中,"
|
||||
"或联系系统管理员为您设置权限。"
|
||||
)
|
||||
return False
|
||||
else:
|
||||
user = (
|
||||
await UserOper().async_get_by_name(self._username)
|
||||
@@ -474,22 +602,18 @@ class MoviePilotTool(BaseTool, metaclass=ABCMeta):
|
||||
else None
|
||||
)
|
||||
if user and user.is_superuser:
|
||||
return None
|
||||
return True
|
||||
|
||||
if user_id_key:
|
||||
config_user_id = config.config.get(user_id_key)
|
||||
if config_user_id and str(config_user_id) == user_id_str:
|
||||
return None
|
||||
return True
|
||||
|
||||
return (
|
||||
"抱歉,您没有执行此工具的权限。"
|
||||
"只有系统管理员才能执行工具操作。"
|
||||
"如需执行工具,请联系系统管理员为您设置权限。"
|
||||
)
|
||||
return False
|
||||
except Exception as e:
|
||||
logger.error(f"检查权限失败: {e}")
|
||||
|
||||
return None
|
||||
return False
|
||||
|
||||
async def send_tool_message(
|
||||
self, message: str, title: str = "", image: Optional[str] = None
|
||||
|
||||
@@ -83,7 +83,6 @@ class AskUserChoiceTool(MoviePilotTool):
|
||||
"back as the user's next message. Do not also send the same question as plain text."
|
||||
)
|
||||
args_schema: Type[BaseModel] = AskUserChoiceInput
|
||||
require_admin: bool = False
|
||||
|
||||
def get_tool_message(self, **kwargs) -> Optional[str]:
|
||||
message = kwargs.get("message", "") or ""
|
||||
|
||||
@@ -24,11 +24,13 @@ class EditFileTool(MoviePilotTool):
|
||||
tags: list[str] = [
|
||||
ToolTag.Write,
|
||||
ToolTag.File,
|
||||
ToolTag.Admin,
|
||||
]
|
||||
description: str = "Edit a file by replacing specific old text with new text. Useful for modifying configuration files, code, or scripts."
|
||||
description: str = (
|
||||
"Edit a local text file by replacing specific old text with new text. "
|
||||
"Non-admin users can only edit files inside the MoviePilot config, "
|
||||
"Agent memory/activity, and log directories."
|
||||
)
|
||||
args_schema: Type[BaseModel] = EditFileInput
|
||||
require_admin: bool = True
|
||||
|
||||
def get_tool_message(self, **kwargs) -> Optional[str]:
|
||||
"""根据参数生成友好的提示消息"""
|
||||
@@ -40,21 +42,27 @@ class EditFileTool(MoviePilotTool):
|
||||
logger.info(f"执行工具: {self.name}, 参数: file_path={file_path}")
|
||||
|
||||
try:
|
||||
path = AsyncPath(file_path)
|
||||
resolved_path, access_error = await self._check_local_file_access(
|
||||
file_path, operation="编辑"
|
||||
)
|
||||
if access_error:
|
||||
return access_error
|
||||
|
||||
path = AsyncPath(resolved_path)
|
||||
# 校验逻辑:如果要替换特定文本,文件必须存在且包含该文本
|
||||
if not await path.exists():
|
||||
# 如果 old_text 为空,可能用户想直接创建文件,但通常 edit_file 需要匹配旧内容
|
||||
if old_text:
|
||||
return f"错误:文件 {file_path} 不存在,无法进行内容替换。"
|
||||
return f"错误:文件 {resolved_path} 不存在,无法进行内容替换。"
|
||||
|
||||
if await path.exists() and not await path.is_file():
|
||||
return f"错误:{file_path} 不是一个文件"
|
||||
return f"错误:{resolved_path} 不是一个文件"
|
||||
|
||||
if await path.exists():
|
||||
content = await path.read_text(encoding="utf-8")
|
||||
if old_text not in content:
|
||||
logger.warning(f"编辑文件 {file_path} 失败:未找到指定的旧文本块")
|
||||
return f"错误:在文件 {file_path} 中未找到指定的旧文本。请确保包含所有的空格、缩进 and 换行符。"
|
||||
logger.warning(f"编辑文件 {resolved_path} 失败:未找到指定的旧文本块")
|
||||
return f"错误:在文件 {resolved_path} 中未找到指定的旧文本。请确保包含所有的空格、缩进 and 换行符。"
|
||||
occurrences = content.count(old_text)
|
||||
new_content = content.replace(old_text, new_text)
|
||||
else:
|
||||
@@ -68,8 +76,8 @@ class EditFileTool(MoviePilotTool):
|
||||
# 写入文件
|
||||
await path.write_text(new_content, encoding="utf-8")
|
||||
|
||||
logger.info(f"成功编辑文件 {file_path},替换了 {occurrences} 处内容")
|
||||
return f"成功编辑文件 {file_path} (替换了 {occurrences} 处匹配内容)"
|
||||
logger.info(f"成功编辑文件 {resolved_path},替换了 {occurrences} 处内容")
|
||||
return f"成功编辑文件 {resolved_path} (替换了 {occurrences} 处匹配内容)"
|
||||
|
||||
except PermissionError:
|
||||
return f"错误:没有访问/修改 {file_path} 的权限"
|
||||
|
||||
@@ -116,6 +116,13 @@ class ListDirectoryTool(MoviePilotTool):
|
||||
logger.info(f"执行工具: {self.name}, 参数: path={path}, storage={storage}, sort_by={sort_by}")
|
||||
|
||||
try:
|
||||
resolved_path, access_error = await self._check_local_storage_access(
|
||||
path=path, storage=storage, operation="列出"
|
||||
)
|
||||
if access_error:
|
||||
return access_error
|
||||
if resolved_path:
|
||||
path = str(resolved_path)
|
||||
return await self.run_blocking(
|
||||
"storage", self._list_directory_sync, path, storage, sort_by
|
||||
)
|
||||
|
||||
@@ -22,10 +22,12 @@ class QueryDownloadersTool(MoviePilotTool):
|
||||
tags: list[str] = [
|
||||
ToolTag.Read,
|
||||
ToolTag.Download,
|
||||
ToolTag.Admin,
|
||||
]
|
||||
description: str = "Query downloader configuration and list all available downloaders. Shows downloader status, connection details, and configuration settings."
|
||||
require_admin: bool = True
|
||||
description: str = (
|
||||
"Query downloader configuration and list available downloaders. Non-admin users receive "
|
||||
"a safe view with only the fields needed to choose a downloader, without host, account, "
|
||||
"password, token or API key values."
|
||||
)
|
||||
args_schema: Type[BaseModel] = QueryDownloadersInput
|
||||
|
||||
def get_tool_message(self, **kwargs) -> Optional[str]:
|
||||
@@ -37,11 +39,35 @@ class QueryDownloadersTool(MoviePilotTool):
|
||||
"""从内存配置缓存中读取下载器配置。"""
|
||||
return SystemConfigOper().get(SystemConfigKey.Downloaders)
|
||||
|
||||
@staticmethod
|
||||
def _sanitize_downloaders_config(downloaders_config: list) -> list:
|
||||
"""
|
||||
生成普通用户可见的下载器配置视图。
|
||||
|
||||
:param downloaders_config: 系统下载器完整配置列表
|
||||
:return: 仅包含名称、类型和启用状态的安全配置列表
|
||||
"""
|
||||
safe_fields = ("name", "type", "enabled", "default", "priority")
|
||||
safe_downloaders = []
|
||||
for downloader in downloaders_config:
|
||||
if not isinstance(downloader, dict):
|
||||
continue
|
||||
safe_downloaders.append({
|
||||
key: downloader.get(key)
|
||||
for key in safe_fields
|
||||
if key in downloader
|
||||
})
|
||||
return safe_downloaders
|
||||
|
||||
async def run(self, **kwargs) -> str:
|
||||
logger.info(f"执行工具: {self.name}")
|
||||
try:
|
||||
downloaders_config = self._load_downloaders_config()
|
||||
if downloaders_config:
|
||||
if not await self.is_admin_user():
|
||||
downloaders_config = self._sanitize_downloaders_config(
|
||||
downloaders_config
|
||||
)
|
||||
return json.dumps(downloaders_config, ensure_ascii=False, indent=2)
|
||||
return "未配置下载器。"
|
||||
except Exception as e:
|
||||
|
||||
@@ -30,10 +30,12 @@ class QuerySitesTool(MoviePilotTool):
|
||||
tags: list[str] = [
|
||||
ToolTag.Read,
|
||||
ToolTag.Site,
|
||||
ToolTag.Admin,
|
||||
]
|
||||
description: str = "Query site status and list all configured sites. Shows site name, domain, status, priority, and basic configuration. Site priority (pri): smaller values have higher priority (e.g., pri=1 has higher priority than pri=10)."
|
||||
require_admin: bool = True
|
||||
description: str = (
|
||||
"Query site status and list configured sites. Non-admin users receive a safe view "
|
||||
"that omits sensitive fields: cookie, token, API key and RSS URL. "
|
||||
"Site priority (pri): smaller values have higher priority (e.g., pri=1 has higher priority than pri=10)."
|
||||
)
|
||||
args_schema: Type[BaseModel] = QuerySitesInput
|
||||
|
||||
def get_tool_message(self, **kwargs) -> Optional[str]:
|
||||
@@ -57,6 +59,7 @@ class QuerySitesTool(MoviePilotTool):
|
||||
) -> str:
|
||||
logger.info(f"执行工具: {self.name}, 参数: status={status}, name={name}")
|
||||
try:
|
||||
is_admin = await self.is_admin_user()
|
||||
site_oper = SiteOper()
|
||||
# 获取所有站点(按优先级排序)
|
||||
sites = await site_oper.async_list()
|
||||
@@ -82,11 +85,25 @@ class QuerySitesTool(MoviePilotTool):
|
||||
"url": s.url,
|
||||
"pri": s.pri,
|
||||
"is_active": s.is_active,
|
||||
"cookie": s.cookie,
|
||||
"downloader": s.downloader,
|
||||
"ua": s.ua,
|
||||
"proxy": s.proxy,
|
||||
"filter": s.filter,
|
||||
"render": s.render,
|
||||
"public": s.public,
|
||||
"note": s.note,
|
||||
"limit_interval": s.limit_interval,
|
||||
"limit_count": s.limit_count,
|
||||
"limit_seconds": s.limit_seconds,
|
||||
"timeout": s.timeout,
|
||||
}
|
||||
if is_admin:
|
||||
simplified.update({
|
||||
"rss": s.rss,
|
||||
"cookie": s.cookie,
|
||||
"apikey": s.apikey,
|
||||
"token": s.token,
|
||||
})
|
||||
simplified_sites.append(simplified)
|
||||
result_json = json.dumps(simplified_sites, ensure_ascii=False, indent=2)
|
||||
return result_json
|
||||
|
||||
@@ -41,13 +41,19 @@ class ReadFileTool(MoviePilotTool):
|
||||
logger.info(f"执行工具: {self.name}, 参数: file_path={file_path}, start_line={start_line}, end_line={end_line}")
|
||||
|
||||
try:
|
||||
path = AsyncPath(file_path)
|
||||
resolved_path, access_error = await self._check_local_file_access(
|
||||
file_path, operation="读取"
|
||||
)
|
||||
if access_error:
|
||||
return access_error
|
||||
|
||||
path = AsyncPath(resolved_path)
|
||||
|
||||
if not await path.exists():
|
||||
return f"错误:文件 {file_path} 不存在"
|
||||
return f"错误:文件 {resolved_path} 不存在"
|
||||
|
||||
if not await path.is_file():
|
||||
return f"错误:{file_path} 不是一个文件"
|
||||
return f"错误:{resolved_path} 不是一个文件"
|
||||
|
||||
content = await path.read_text(encoding="utf-8")
|
||||
truncated = False
|
||||
|
||||
@@ -55,7 +55,7 @@ class SendLocalFileTool(MoviePilotTool):
|
||||
"Use this when you have generated or identified a local file the user should download."
|
||||
)
|
||||
args_schema: Type[BaseModel] = SendLocalFileInput
|
||||
require_admin: bool = False
|
||||
require_admin: bool = True
|
||||
|
||||
def get_tool_message(self, **kwargs) -> Optional[str]:
|
||||
file_path = kwargs.get("file_path", "")
|
||||
|
||||
@@ -44,7 +44,6 @@ class SendVoiceMessageTool(MoviePilotTool):
|
||||
"or call `send_message` with the same content."
|
||||
)
|
||||
args_schema: Type[BaseModel] = SendVoiceMessageInput
|
||||
require_admin: bool = False
|
||||
|
||||
def get_tool_message(self, **kwargs) -> Optional[str]:
|
||||
"""生成语音回复工具的执行提示。"""
|
||||
|
||||
@@ -23,11 +23,12 @@ class WriteFileTool(MoviePilotTool):
|
||||
tags: list[str] = [
|
||||
ToolTag.Write,
|
||||
ToolTag.File,
|
||||
ToolTag.Admin,
|
||||
]
|
||||
description: str = "Write full content to a file. If the file already exists, it will be overwritten. Automatically creates parent directories if they don't exist."
|
||||
description: str = (
|
||||
"Write full content to a local text file. Non-admin users can only write "
|
||||
"inside the MoviePilot config, Agent memory/activity, and log directories."
|
||||
)
|
||||
args_schema: Type[BaseModel] = WriteFileInput
|
||||
require_admin: bool = True
|
||||
|
||||
def get_tool_message(self, **kwargs) -> Optional[str]:
|
||||
"""根据参数生成友好的提示消息"""
|
||||
@@ -39,10 +40,16 @@ class WriteFileTool(MoviePilotTool):
|
||||
logger.info(f"执行工具: {self.name}, 参数: file_path={file_path}")
|
||||
|
||||
try:
|
||||
path = AsyncPath(file_path)
|
||||
resolved_path, access_error = await self._check_local_file_access(
|
||||
file_path, operation="写入"
|
||||
)
|
||||
if access_error:
|
||||
return access_error
|
||||
|
||||
path = AsyncPath(resolved_path)
|
||||
|
||||
if await path.exists() and not await path.is_file():
|
||||
return f"错误:{file_path} 路径已存在但不是一个文件"
|
||||
return f"错误:{resolved_path} 路径已存在但不是一个文件"
|
||||
|
||||
# 自动创建父目录
|
||||
await path.parent.mkdir(parents=True, exist_ok=True)
|
||||
@@ -50,8 +57,8 @@ class WriteFileTool(MoviePilotTool):
|
||||
# 写入文件
|
||||
await path.write_text(content, encoding="utf-8")
|
||||
|
||||
logger.info(f"成功写入文件 {file_path}")
|
||||
return f"成功写入文件 {file_path}"
|
||||
logger.info(f"成功写入文件 {resolved_path}")
|
||||
return f"成功写入文件 {resolved_path}"
|
||||
|
||||
except PermissionError:
|
||||
return f"错误:没有权限写入 {file_path}"
|
||||
|
||||
@@ -55,6 +55,7 @@ class MoviePilotToolsManager:
|
||||
source="api",
|
||||
username="API Client",
|
||||
stream_handler=None,
|
||||
agent_context={"is_admin": self.is_admin},
|
||||
)
|
||||
logger.info(f"成功加载 {len(self.tools)} 个工具")
|
||||
except Exception as e:
|
||||
|
||||
@@ -108,8 +108,6 @@ MoviePilot 也提供普通 REST API 给前端和自动化客户端使用。所
|
||||
| GET | `/api/v1/download/paths` | 查询可用于下载接口 `save_path` 参数的下载路径 |
|
||||
| DELETE | `/api/v1/download/{hashString}` | 删除下载任务,参数:`name` |
|
||||
|
||||
MCP 工具 `query_download_tasks` 支持 `status=all|downloading|completed|paused`;其中 `completed` 表示下载器任务既不是下载中,也不是暂停状态。默认仅查询带 MoviePilot 内置标签的任务;如需诊断下载器中未打内置标签的任务,可传 `include_all_tags=true`。
|
||||
|
||||
#### 系统
|
||||
|
||||
| 方法 | 路径 | 说明 |
|
||||
|
||||
@@ -49,6 +49,8 @@ dedicated tool can complete the task more directly and safely.
|
||||
`google`, `brave`, etc.) and `site_url` for limiting results to a specified
|
||||
domain or URL path. It uses the configured system proxy by default.
|
||||
- `query_sites` - Get MoviePilot site IDs before site-specific operations.
|
||||
Non-admin callers receive a safe view without Cookie, RSS, Token, or API Key
|
||||
fields.
|
||||
- `update_site_cookie` - Update a configured site's Cookie and User-Agent using
|
||||
username, password, and optional two-step code.
|
||||
- `test_site` - Verify configured site connectivity and login status.
|
||||
|
||||
317
tests/test_agent_resource_flow_permissions.py
Normal file
317
tests/test_agent_resource_flow_permissions.py
Normal file
@@ -0,0 +1,317 @@
|
||||
"""Agent 资源流程工具权限测试。"""
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
from types import SimpleNamespace
|
||||
from unittest.mock import AsyncMock, patch
|
||||
|
||||
from app.agent.tools.impl.edit_file import EditFileTool
|
||||
from app.agent.tools.impl.list_directory import ListDirectoryTool
|
||||
from app.agent.tools.impl.query_downloaders import QueryDownloadersTool
|
||||
from app.agent.tools.impl.query_sites import QuerySitesTool
|
||||
from app.agent.tools.impl.read_file import ReadFileTool
|
||||
from app.agent.tools.impl.send_local_file import SendLocalFileTool
|
||||
from app.agent.tools.impl.write_file import WriteFileTool
|
||||
from app.agent.tools.manager import MoviePilotToolsManager
|
||||
from app.agent import MoviePilotAgent
|
||||
from app.core.config import settings
|
||||
from app.schemas.types import MessageChannel
|
||||
|
||||
|
||||
def test_non_admin_manager_exposes_resource_flow_helper_tools():
|
||||
"""普通用户应能看到搜索、订阅、下载流程所需的辅助工具。"""
|
||||
site_tool = QuerySitesTool(session_id="session-1", user_id="10001")
|
||||
downloader_tool = QueryDownloadersTool(session_id="session-1", user_id="10001")
|
||||
|
||||
with patch(
|
||||
"app.agent.tools.manager.MoviePilotToolFactory.create_tools",
|
||||
return_value=[site_tool, downloader_tool],
|
||||
):
|
||||
manager = MoviePilotToolsManager(is_admin=False)
|
||||
|
||||
tool_names = {tool.name for tool in manager.list_tools()}
|
||||
assert "query_sites" in tool_names
|
||||
assert "query_downloaders" in tool_names
|
||||
|
||||
|
||||
def test_non_admin_manager_exposes_restricted_file_tools():
|
||||
"""普通用户应能看到受目录边界限制的文件读写工具。"""
|
||||
tools = [
|
||||
ReadFileTool(session_id="session-1", user_id="10001"),
|
||||
WriteFileTool(session_id="session-1", user_id="10001"),
|
||||
EditFileTool(session_id="session-1", user_id="10001"),
|
||||
ListDirectoryTool(session_id="session-1", user_id="10001"),
|
||||
]
|
||||
|
||||
with patch(
|
||||
"app.agent.tools.manager.MoviePilotToolFactory.create_tools",
|
||||
return_value=tools,
|
||||
):
|
||||
manager = MoviePilotToolsManager(is_admin=False)
|
||||
|
||||
tool_names = {tool.name for tool in manager.list_tools()}
|
||||
assert {"read_file", "write_file", "edit_file", "list_directory"} <= tool_names
|
||||
|
||||
|
||||
def test_query_sites_hides_only_sensitive_fields_for_non_admin_user():
|
||||
"""普通用户查询站点时只隐藏 Cookie、API Key、Token 和 RSS。"""
|
||||
tool = QuerySitesTool(session_id="session-1", user_id="10001")
|
||||
site = SimpleNamespace(
|
||||
id=1,
|
||||
name="TestSite",
|
||||
domain="secret.example",
|
||||
url="https://secret.example/",
|
||||
pri=1,
|
||||
rss="https://secret.example/rss",
|
||||
cookie="uid=1; passkey=secret",
|
||||
ua="SecretUA",
|
||||
apikey="site-api-key",
|
||||
token="site-token",
|
||||
proxy=1,
|
||||
filter="",
|
||||
render=0,
|
||||
public=0,
|
||||
note={"secret": True},
|
||||
limit_interval=0,
|
||||
limit_count=0,
|
||||
limit_seconds=0,
|
||||
timeout=15,
|
||||
is_active=True,
|
||||
downloader="qb",
|
||||
)
|
||||
|
||||
with patch(
|
||||
"app.agent.tools.impl.query_sites.SiteOper"
|
||||
) as site_oper:
|
||||
site_oper.return_value.async_list = AsyncMock(return_value=[site])
|
||||
result = asyncio.run(tool.run())
|
||||
|
||||
payload = json.loads(result)
|
||||
assert payload == [
|
||||
{
|
||||
"id": 1,
|
||||
"name": "TestSite",
|
||||
"domain": "secret.example",
|
||||
"url": "https://secret.example/",
|
||||
"pri": 1,
|
||||
"is_active": True,
|
||||
"downloader": "qb",
|
||||
"ua": "SecretUA",
|
||||
"proxy": 1,
|
||||
"filter": "",
|
||||
"render": 0,
|
||||
"public": 0,
|
||||
"note": {"secret": True},
|
||||
"limit_interval": 0,
|
||||
"limit_count": 0,
|
||||
"limit_seconds": 0,
|
||||
"timeout": 15,
|
||||
}
|
||||
]
|
||||
assert "cookie" not in payload[0]
|
||||
assert "rss" not in payload[0]
|
||||
assert "token" not in payload[0]
|
||||
assert "apikey" not in payload[0]
|
||||
|
||||
|
||||
def test_query_sites_keeps_full_fields_for_admin_context():
|
||||
"""管理员查询站点时保留完整配置视图。"""
|
||||
tool = QuerySitesTool(session_id="session-1", user_id="admin")
|
||||
tool.set_agent_context({"is_admin": True})
|
||||
site = SimpleNamespace(
|
||||
id=1,
|
||||
name="TestSite",
|
||||
domain="secret.example",
|
||||
url="https://secret.example/",
|
||||
pri=1,
|
||||
rss="https://secret.example/rss",
|
||||
cookie="uid=1; passkey=secret",
|
||||
ua="SecretUA",
|
||||
apikey="site-api-key",
|
||||
token="site-token",
|
||||
proxy=1,
|
||||
filter="",
|
||||
render=0,
|
||||
public=0,
|
||||
note={"secret": True},
|
||||
limit_interval=0,
|
||||
limit_count=0,
|
||||
limit_seconds=0,
|
||||
timeout=15,
|
||||
is_active=True,
|
||||
downloader="qb",
|
||||
)
|
||||
|
||||
with patch(
|
||||
"app.agent.tools.impl.query_sites.SiteOper"
|
||||
) as site_oper:
|
||||
site_oper.return_value.async_list = AsyncMock(return_value=[site])
|
||||
result = asyncio.run(tool.run())
|
||||
|
||||
payload = json.loads(result)
|
||||
assert payload[0]["cookie"] == "uid=1; passkey=secret"
|
||||
assert payload[0]["token"] == "site-token"
|
||||
assert payload[0]["apikey"] == "site-api-key"
|
||||
assert payload[0]["url"] == "https://secret.example/"
|
||||
|
||||
|
||||
def test_non_admin_file_tools_can_access_config_directory(tmp_path, monkeypatch):
|
||||
"""普通用户可在配置目录内读写和编辑文件。"""
|
||||
config_path = tmp_path / "config"
|
||||
monkeypatch.setattr(settings, "CONFIG_DIR", str(config_path))
|
||||
memory_path = settings.CONFIG_PATH / "agent" / "memory" / "MEMORY.md"
|
||||
|
||||
write_tool = WriteFileTool(session_id="session-1", user_id="10001")
|
||||
read_tool = ReadFileTool(session_id="session-1", user_id="10001")
|
||||
edit_tool = EditFileTool(session_id="session-1", user_id="10001")
|
||||
|
||||
write_result = asyncio.run(write_tool.run(str(memory_path), "hello"))
|
||||
read_result = asyncio.run(read_tool.run(str(memory_path)))
|
||||
edit_result = asyncio.run(edit_tool.run(str(memory_path), "hello", "hello mp"))
|
||||
edited_content = memory_path.read_text(encoding="utf-8")
|
||||
|
||||
assert "成功写入文件" in write_result
|
||||
assert read_result == "hello"
|
||||
assert "成功编辑文件" in edit_result
|
||||
assert edited_content == "hello mp"
|
||||
|
||||
|
||||
def test_non_admin_file_tools_block_paths_outside_allowed_roots(
|
||||
tmp_path, monkeypatch
|
||||
):
|
||||
"""普通用户不能通过文件工具访问配置、记忆和日志目录外的路径。"""
|
||||
config_path = tmp_path / "config"
|
||||
outside_path = tmp_path / "outside.txt"
|
||||
outside_path.write_text("secret", encoding="utf-8")
|
||||
monkeypatch.setattr(settings, "CONFIG_DIR", str(config_path))
|
||||
|
||||
read_tool = ReadFileTool(session_id="session-1", user_id="10001")
|
||||
write_tool = WriteFileTool(session_id="session-1", user_id="10001")
|
||||
edit_tool = EditFileTool(session_id="session-1", user_id="10001")
|
||||
list_tool = ListDirectoryTool(session_id="session-1", user_id="10001")
|
||||
send_tool = SendLocalFileTool(session_id="session-1", user_id="10001")
|
||||
send_tool.set_message_attr(
|
||||
channel=MessageChannel.Telegram.value,
|
||||
source="telegram-main",
|
||||
username="normal-user",
|
||||
)
|
||||
|
||||
read_result = asyncio.run(read_tool.run(str(outside_path)))
|
||||
write_result = asyncio.run(write_tool.run(str(outside_path), "changed"))
|
||||
edit_result = asyncio.run(edit_tool.run(str(outside_path), "secret", "changed"))
|
||||
with patch.object(ListDirectoryTool, "_list_directory_sync") as list_directory:
|
||||
list_result = asyncio.run(list_tool.run(str(tmp_path)))
|
||||
send_result = asyncio.run(send_tool.run(str(outside_path)))
|
||||
|
||||
assert "普通用户只能读取" in read_result
|
||||
assert "普通用户只能写入" in write_result
|
||||
assert "普通用户只能编辑" in edit_result
|
||||
assert "普通用户只能列出" in list_result
|
||||
assert "普通用户只能发送" in send_result
|
||||
assert outside_path.read_text(encoding="utf-8") == "secret"
|
||||
list_directory.assert_not_called()
|
||||
|
||||
|
||||
def test_admin_file_tool_can_access_paths_outside_allowed_roots(
|
||||
tmp_path, monkeypatch
|
||||
):
|
||||
"""管理员上下文不受普通用户文件访问边界限制。"""
|
||||
config_path = tmp_path / "config"
|
||||
outside_path = tmp_path / "outside.txt"
|
||||
monkeypatch.setattr(settings, "CONFIG_DIR", str(config_path))
|
||||
|
||||
tool = WriteFileTool(session_id="session-1", user_id="admin")
|
||||
tool.set_agent_context({"is_admin": True})
|
||||
|
||||
result = asyncio.run(tool.run(str(outside_path), "admin write"))
|
||||
|
||||
assert "成功写入文件" in result
|
||||
assert outside_path.read_text(encoding="utf-8") == "admin write"
|
||||
|
||||
|
||||
def test_query_downloaders_hides_sensitive_fields_for_non_admin_user():
|
||||
"""普通用户查询下载器时只返回选择下载器所需的安全字段。"""
|
||||
tool = QueryDownloadersTool(session_id="session-1", user_id="10001")
|
||||
downloaders = [
|
||||
{
|
||||
"name": "qb",
|
||||
"type": "qbittorrent",
|
||||
"enabled": True,
|
||||
"host": "http://127.0.0.1",
|
||||
"port": 8080,
|
||||
"username": "admin",
|
||||
"password": "secret",
|
||||
"apikey": "downloader-api-key",
|
||||
"token": "downloader-token",
|
||||
}
|
||||
]
|
||||
|
||||
with patch(
|
||||
"app.agent.tools.impl.query_downloaders.SystemConfigOper"
|
||||
) as system_config_oper:
|
||||
system_config_oper.return_value.get.return_value = downloaders
|
||||
result = asyncio.run(tool.run())
|
||||
|
||||
payload = json.loads(result)
|
||||
assert payload == [
|
||||
{
|
||||
"name": "qb",
|
||||
"type": "qbittorrent",
|
||||
"enabled": True,
|
||||
}
|
||||
]
|
||||
assert "host" not in payload[0]
|
||||
assert "username" not in payload[0]
|
||||
assert "password" not in payload[0]
|
||||
assert "apikey" not in payload[0]
|
||||
assert "token" not in payload[0]
|
||||
|
||||
|
||||
def test_query_downloaders_keeps_full_fields_for_admin_context():
|
||||
"""管理员查询下载器时保留完整配置视图。"""
|
||||
tool = QueryDownloadersTool(session_id="session-1", user_id="admin")
|
||||
tool.set_agent_context({"is_admin": True})
|
||||
downloaders = [
|
||||
{
|
||||
"name": "qb",
|
||||
"type": "qbittorrent",
|
||||
"enabled": True,
|
||||
"host": "http://127.0.0.1",
|
||||
"username": "admin",
|
||||
"password": "secret",
|
||||
"apikey": "downloader-api-key",
|
||||
}
|
||||
]
|
||||
|
||||
with patch(
|
||||
"app.agent.tools.impl.query_downloaders.SystemConfigOper"
|
||||
) as system_config_oper:
|
||||
system_config_oper.return_value.get.return_value = downloaders
|
||||
result = asyncio.run(tool.run())
|
||||
|
||||
payload = json.loads(result)
|
||||
assert payload[0]["host"] == "http://127.0.0.1"
|
||||
assert payload[0]["username"] == "admin"
|
||||
assert payload[0]["password"] == "secret"
|
||||
assert payload[0]["apikey"] == "downloader-api-key"
|
||||
|
||||
|
||||
def test_channel_agent_admin_user_id_does_not_bypass_user_lookup():
|
||||
"""渠道用户 ID 恰好为 admin 时,不应绕过真实系统用户权限判断。"""
|
||||
agent = MoviePilotAgent(
|
||||
session_id="session-1",
|
||||
user_id="admin",
|
||||
channel=MessageChannel.Telegram.value,
|
||||
source="telegram-main",
|
||||
username="normal-user",
|
||||
)
|
||||
|
||||
with patch("app.agent.UserOper") as user_oper:
|
||||
user_oper.return_value.async_get_by_name.return_value = SimpleNamespace(
|
||||
is_superuser=False
|
||||
)
|
||||
context = asyncio.run(
|
||||
agent._build_tool_context(should_dispatch_reply=True)
|
||||
)
|
||||
|
||||
assert context["is_admin"] is False
|
||||
@@ -1,6 +1,5 @@
|
||||
import asyncio
|
||||
import json
|
||||
import unittest
|
||||
from pathlib import Path
|
||||
from types import SimpleNamespace
|
||||
from unittest.mock import patch
|
||||
@@ -19,134 +18,132 @@ from app.agent.middleware.subagents import (
|
||||
from app.agent.tools.tags import ToolTag
|
||||
|
||||
|
||||
class TestAgentSubagents(unittest.TestCase):
|
||||
def test_create_subagent_middlewares_registers_task_tool(self):
|
||||
"""子代理中间件应向主 Agent 注册 task 委派工具。"""
|
||||
model = FakeListChatModel(responses=["ok"])
|
||||
def test_create_subagent_middlewares_registers_task_tool():
|
||||
"""子代理中间件应向主 Agent 注册 task 委派工具。"""
|
||||
model = FakeListChatModel(responses=["ok"])
|
||||
|
||||
middlewares, task_tools = create_subagent_middlewares(
|
||||
model=model,
|
||||
tools=[],
|
||||
stream_handler=None,
|
||||
)
|
||||
middlewares, task_tools = create_subagent_middlewares(
|
||||
model=model,
|
||||
tools=[],
|
||||
stream_handler=None,
|
||||
)
|
||||
|
||||
self.assertEqual(len(middlewares), 3)
|
||||
self.assertEqual(
|
||||
[tool.name for tool in task_tools],
|
||||
[SUBAGENT_TASK_TOOL_NAME, SUBAGENT_CONTROL_TOOL_NAME],
|
||||
)
|
||||
self.assertIn("media-researcher", task_tools[0].description)
|
||||
self.assertIn("moviepilot-explorer", task_tools[0].description)
|
||||
self.assertIn("system-diagnostician", task_tools[0].description)
|
||||
self.assertIn("action=start", task_tools[1].description)
|
||||
self.assertIn("action=wait", task_tools[1].description)
|
||||
|
||||
def test_subagent_tools_are_selected_by_tags(self):
|
||||
"""子代理应根据工具标签筛选工具,而不是依赖工具名名单。"""
|
||||
model = FakeListChatModel(responses=["ok"])
|
||||
tools = [
|
||||
SimpleNamespace(
|
||||
name="custom_media_lookup",
|
||||
tags=[ToolTag.Read.value, ToolTag.Media.value],
|
||||
),
|
||||
SimpleNamespace(
|
||||
name="custom_media_writer",
|
||||
tags=[ToolTag.Read.value, ToolTag.Write.value, ToolTag.Media.value],
|
||||
),
|
||||
SimpleNamespace(
|
||||
name="custom_site_lookup",
|
||||
tags=[ToolTag.Read.value, ToolTag.Site.value],
|
||||
),
|
||||
]
|
||||
captured = {}
|
||||
|
||||
def _fake_create_agent(**kwargs):
|
||||
captured.update(kwargs)
|
||||
return kwargs
|
||||
|
||||
middleware = MoviePilotSubAgentMiddleware(
|
||||
model=model,
|
||||
profiles=subagent_module._builtin_subagent_profiles(),
|
||||
tools=tools,
|
||||
)
|
||||
|
||||
with patch.object(subagent_module, "create_agent", side_effect=_fake_create_agent):
|
||||
middleware._get_agent("media-researcher")
|
||||
|
||||
self.assertEqual(
|
||||
[tool.name for tool in captured["tools"]],
|
||||
["custom_media_lookup"],
|
||||
)
|
||||
|
||||
def test_moviepilot_explorer_selects_code_and_settings_tools(self):
|
||||
"""MoviePilot 探索子代理应能读取代码、目录、设置和命令诊断工具。"""
|
||||
model = FakeListChatModel(responses=["ok"])
|
||||
tools = [
|
||||
SimpleNamespace(
|
||||
name="custom_code_reader",
|
||||
tags=[ToolTag.Read.value, ToolTag.File.value],
|
||||
),
|
||||
SimpleNamespace(
|
||||
name="custom_directory_lister",
|
||||
tags=[ToolTag.Read.value, ToolTag.Directory.value],
|
||||
),
|
||||
SimpleNamespace(
|
||||
name="custom_settings_reader",
|
||||
tags=[ToolTag.Read.value, ToolTag.Settings.value],
|
||||
),
|
||||
SimpleNamespace(
|
||||
name="custom_command_runner",
|
||||
tags=[ToolTag.Read.value, ToolTag.Command.value],
|
||||
),
|
||||
SimpleNamespace(
|
||||
name="custom_code_writer",
|
||||
tags=[ToolTag.Read.value, ToolTag.Write.value, ToolTag.File.value],
|
||||
),
|
||||
]
|
||||
captured = {}
|
||||
|
||||
def _fake_create_agent(**kwargs):
|
||||
captured.update(kwargs)
|
||||
return kwargs
|
||||
|
||||
middleware = MoviePilotSubAgentMiddleware(
|
||||
model=model,
|
||||
profiles=subagent_module._builtin_subagent_profiles(),
|
||||
tools=tools,
|
||||
)
|
||||
|
||||
with patch.object(subagent_module, "create_agent", side_effect=_fake_create_agent):
|
||||
middleware._get_agent("moviepilot-explorer")
|
||||
|
||||
self.assertEqual(
|
||||
[tool.name for tool in captured["tools"]],
|
||||
[
|
||||
"custom_code_reader",
|
||||
"custom_directory_lister",
|
||||
"custom_settings_reader",
|
||||
"custom_command_runner",
|
||||
],
|
||||
)
|
||||
|
||||
def test_builtin_tools_declare_tags_in_implementation(self):
|
||||
"""所有内置工具实现都应显式声明 tags。"""
|
||||
impl_dir = Path(__file__).resolve().parents[1] / "app" / "agent" / "tools" / "impl"
|
||||
missing_tools = []
|
||||
for path in sorted(impl_dir.glob("*.py")):
|
||||
text = path.read_text()
|
||||
for block in text.split("\nclass "):
|
||||
if "(MoviePilotTool)" not in block:
|
||||
continue
|
||||
class_name = block.split("(", 1)[0].strip()
|
||||
if "tags: list[str]" not in block:
|
||||
missing_tools.append(f"{path.name}:{class_name}")
|
||||
|
||||
self.assertEqual([], missing_tools)
|
||||
assert len(middlewares) == 3
|
||||
assert [tool.name for tool in task_tools] == [
|
||||
SUBAGENT_TASK_TOOL_NAME,
|
||||
SUBAGENT_CONTROL_TOOL_NAME,
|
||||
]
|
||||
assert "media-researcher" in task_tools[0].description
|
||||
assert "moviepilot-explorer" in task_tools[0].description
|
||||
assert "system-diagnostician" in task_tools[0].description
|
||||
assert "action=start" in task_tools[1].description
|
||||
assert "action=wait" in task_tools[1].description
|
||||
assert "action=pipeline" in task_tools[1].description
|
||||
|
||||
|
||||
class TestSubAgentTaskControlMiddleware(unittest.IsolatedAsyncioTestCase):
|
||||
async def test_call_summary_middleware_logs_subagent_tool_operations(self):
|
||||
"""子代理工具包装层应记录工具执行开始和完成日志。"""
|
||||
def test_subagent_tools_are_selected_by_tags():
|
||||
"""子代理应根据工具标签筛选工具,而不是依赖工具名名单。"""
|
||||
model = FakeListChatModel(responses=["ok"])
|
||||
tools = [
|
||||
SimpleNamespace(
|
||||
name="custom_media_lookup",
|
||||
tags=[ToolTag.Read.value, ToolTag.Media.value],
|
||||
),
|
||||
SimpleNamespace(
|
||||
name="custom_media_writer",
|
||||
tags=[ToolTag.Read.value, ToolTag.Write.value, ToolTag.Media.value],
|
||||
),
|
||||
SimpleNamespace(
|
||||
name="custom_site_lookup",
|
||||
tags=[ToolTag.Read.value, ToolTag.Site.value],
|
||||
),
|
||||
]
|
||||
captured = {}
|
||||
|
||||
def _fake_create_agent(**kwargs):
|
||||
captured.update(kwargs)
|
||||
return kwargs
|
||||
|
||||
middleware = MoviePilotSubAgentMiddleware(
|
||||
model=model,
|
||||
profiles=subagent_module._builtin_subagent_profiles(),
|
||||
tools=tools,
|
||||
)
|
||||
|
||||
with patch.object(subagent_module, "create_agent", side_effect=_fake_create_agent):
|
||||
middleware._get_agent("media-researcher")
|
||||
|
||||
assert [tool.name for tool in captured["tools"]] == ["custom_media_lookup"]
|
||||
|
||||
|
||||
def test_moviepilot_explorer_selects_code_and_settings_tools():
|
||||
"""MoviePilot 探索子代理应能读取代码、目录、设置和命令诊断工具。"""
|
||||
model = FakeListChatModel(responses=["ok"])
|
||||
tools = [
|
||||
SimpleNamespace(
|
||||
name="custom_code_reader",
|
||||
tags=[ToolTag.Read.value, ToolTag.File.value],
|
||||
),
|
||||
SimpleNamespace(
|
||||
name="custom_directory_lister",
|
||||
tags=[ToolTag.Read.value, ToolTag.Directory.value],
|
||||
),
|
||||
SimpleNamespace(
|
||||
name="custom_settings_reader",
|
||||
tags=[ToolTag.Read.value, ToolTag.Settings.value],
|
||||
),
|
||||
SimpleNamespace(
|
||||
name="custom_command_runner",
|
||||
tags=[ToolTag.Read.value, ToolTag.Command.value],
|
||||
),
|
||||
SimpleNamespace(
|
||||
name="custom_code_writer",
|
||||
tags=[ToolTag.Read.value, ToolTag.Write.value, ToolTag.File.value],
|
||||
),
|
||||
]
|
||||
captured = {}
|
||||
|
||||
def _fake_create_agent(**kwargs):
|
||||
captured.update(kwargs)
|
||||
return kwargs
|
||||
|
||||
middleware = MoviePilotSubAgentMiddleware(
|
||||
model=model,
|
||||
profiles=subagent_module._builtin_subagent_profiles(),
|
||||
tools=tools,
|
||||
)
|
||||
|
||||
with patch.object(subagent_module, "create_agent", side_effect=_fake_create_agent):
|
||||
middleware._get_agent("moviepilot-explorer")
|
||||
|
||||
assert [tool.name for tool in captured["tools"]] == [
|
||||
"custom_code_reader",
|
||||
"custom_directory_lister",
|
||||
"custom_settings_reader",
|
||||
"custom_command_runner",
|
||||
]
|
||||
|
||||
|
||||
def test_builtin_tools_declare_tags_in_implementation():
|
||||
"""所有内置工具实现都应显式声明 tags。"""
|
||||
impl_dir = Path(__file__).resolve().parents[1] / "app" / "agent" / "tools" / "impl"
|
||||
missing_tools = []
|
||||
for path in sorted(impl_dir.glob("*.py")):
|
||||
text = path.read_text()
|
||||
for block in text.split("\nclass "):
|
||||
if "(MoviePilotTool)" not in block:
|
||||
continue
|
||||
class_name = block.split("(", 1)[0].strip()
|
||||
if "tags: list[str]" not in block:
|
||||
missing_tools.append(f"{path.name}:{class_name}")
|
||||
|
||||
assert missing_tools == []
|
||||
|
||||
|
||||
def test_call_summary_middleware_logs_subagent_tool_operations():
|
||||
"""子代理工具包装层应记录工具执行开始和完成日志。"""
|
||||
|
||||
async def _run_test():
|
||||
middleware = SubAgentCallSummaryMiddleware()
|
||||
request = SimpleNamespace(
|
||||
tool=SimpleNamespace(name=SUBAGENT_CONTROL_TOOL_NAME),
|
||||
@@ -165,12 +162,17 @@ class TestSubAgentTaskControlMiddleware(unittest.IsolatedAsyncioTestCase):
|
||||
result = await middleware.awrap_tool_call(request, _fake_handler)
|
||||
|
||||
messages = [call.args[0] for call in log_info.call_args_list]
|
||||
self.assertEqual("ok", result)
|
||||
self.assertTrue(any("开始执行子代理工具" in message for message in messages))
|
||||
self.assertTrue(any("子代理工具执行完成" in message for message in messages))
|
||||
assert result == "ok"
|
||||
assert any("开始执行子代理工具" in message for message in messages)
|
||||
assert any("子代理工具执行完成" in message for message in messages)
|
||||
|
||||
async def test_control_tool_starts_tasks_concurrently_and_waits(self):
|
||||
"""异步子代理管控工具应批量启动任务,并在 wait 时收集结果。"""
|
||||
asyncio.run(_run_test())
|
||||
|
||||
|
||||
def test_control_tool_starts_tasks_concurrently_and_waits():
|
||||
"""异步子代理管控工具应批量启动任务,并在 wait 时收集结果。"""
|
||||
|
||||
async def _run_test():
|
||||
model = FakeListChatModel(responses=["ok"])
|
||||
middleware = SubAgentTaskControlMiddleware(
|
||||
model=model,
|
||||
@@ -221,21 +223,154 @@ class TestSubAgentTaskControlMiddleware(unittest.IsolatedAsyncioTestCase):
|
||||
)
|
||||
)
|
||||
|
||||
self.assertTrue(start_payload["success"])
|
||||
self.assertEqual(2, len(task_ids))
|
||||
self.assertEqual(["检查媒体库", "检查下载器"], running_descriptions)
|
||||
self.assertEqual(
|
||||
["completed", "completed"],
|
||||
[task["status"] for task in wait_payload["tasks"]],
|
||||
)
|
||||
self.assertIn("media-researcher:检查媒体库", wait_payload["tasks"][0]["result"])
|
||||
self.assertIn(
|
||||
"download-diagnostician:检查下载器",
|
||||
wait_payload["tasks"][1]["result"],
|
||||
assert start_payload["success"]
|
||||
assert len(task_ids) == 2
|
||||
assert running_descriptions == ["检查媒体库", "检查下载器"]
|
||||
assert [task["status"] for task in wait_payload["tasks"]] == [
|
||||
"completed",
|
||||
"completed",
|
||||
]
|
||||
assert "media-researcher:检查媒体库" in wait_payload["tasks"][0]["result"]
|
||||
assert (
|
||||
"download-diagnostician:检查下载器"
|
||||
in wait_payload["tasks"][1]["result"]
|
||||
)
|
||||
|
||||
async def test_after_agent_cancels_unfinished_tasks(self):
|
||||
"""Agent 结束时应取消仍在运行的异步子代理任务。"""
|
||||
asyncio.run(_run_test())
|
||||
|
||||
|
||||
def test_control_tool_pipeline_passes_previous_results_to_next_step():
|
||||
"""管道模式应顺序执行子代理,并把上一步结果作为下一步私有上下文。"""
|
||||
|
||||
async def _run_test():
|
||||
model = FakeListChatModel(responses=["ok"])
|
||||
middleware = SubAgentTaskControlMiddleware(
|
||||
model=model,
|
||||
profiles=subagent_module._builtin_subagent_profiles(),
|
||||
tools=[],
|
||||
)
|
||||
calls = []
|
||||
|
||||
async def _fake_run_task(self, *, description, subagent_type, task_id=None):
|
||||
calls.append(
|
||||
{
|
||||
"description": description,
|
||||
"subagent_type": subagent_type,
|
||||
"task_id": task_id,
|
||||
}
|
||||
)
|
||||
return f"结果-{len(calls)}"
|
||||
|
||||
with patch.object(
|
||||
subagent_module._SubAgentAgentProvider,
|
||||
"run_task",
|
||||
new=_fake_run_task,
|
||||
):
|
||||
payload = json.loads(
|
||||
await middleware._control_task(
|
||||
action="pipeline",
|
||||
tasks=[
|
||||
{
|
||||
"description": "识别媒体",
|
||||
"subagent_type": "media-researcher",
|
||||
},
|
||||
{
|
||||
"description": "检查下载",
|
||||
"subagent_type": "download-diagnostician",
|
||||
},
|
||||
{
|
||||
"description": "汇总结论",
|
||||
"subagent_type": "general-purpose",
|
||||
},
|
||||
],
|
||||
timeout_ms=1000,
|
||||
)
|
||||
)
|
||||
|
||||
assert payload["success"]
|
||||
assert [call["subagent_type"] for call in calls] == [
|
||||
"media-researcher",
|
||||
"download-diagnostician",
|
||||
"general-purpose",
|
||||
]
|
||||
assert calls[0]["description"] == "识别媒体"
|
||||
assert "结果-1" in calls[1]["description"]
|
||||
assert "结果-1" in calls[2]["description"]
|
||||
assert "结果-2" in calls[2]["description"]
|
||||
assert [task["status"] for task in payload["tasks"]] == [
|
||||
"completed",
|
||||
"completed",
|
||||
"completed",
|
||||
]
|
||||
assert [task["result"] for task in payload["tasks"]] == [
|
||||
"结果-1",
|
||||
"结果-2",
|
||||
"结果-3",
|
||||
]
|
||||
|
||||
asyncio.run(_run_test())
|
||||
|
||||
|
||||
def test_control_tool_pipeline_stops_after_failed_step():
|
||||
"""管道模式遇到失败步骤时应中断后续子代理。"""
|
||||
|
||||
async def _run_test():
|
||||
model = FakeListChatModel(responses=["ok"])
|
||||
middleware = SubAgentTaskControlMiddleware(
|
||||
model=model,
|
||||
profiles=subagent_module._builtin_subagent_profiles(),
|
||||
tools=[],
|
||||
)
|
||||
calls = []
|
||||
|
||||
async def _fake_run_task(self, *, description, subagent_type, task_id=None):
|
||||
calls.append(subagent_type)
|
||||
if subagent_type == "download-diagnostician":
|
||||
raise RuntimeError("下载器不可用")
|
||||
return f"{subagent_type}:ok"
|
||||
|
||||
with patch.object(
|
||||
subagent_module._SubAgentAgentProvider,
|
||||
"run_task",
|
||||
new=_fake_run_task,
|
||||
):
|
||||
payload = json.loads(
|
||||
await middleware._control_task(
|
||||
action="pipeline",
|
||||
tasks=[
|
||||
{
|
||||
"description": "识别媒体",
|
||||
"subagent_type": "media-researcher",
|
||||
},
|
||||
{
|
||||
"description": "检查下载",
|
||||
"subagent_type": "download-diagnostician",
|
||||
},
|
||||
{
|
||||
"description": "汇总结论",
|
||||
"subagent_type": "general-purpose",
|
||||
},
|
||||
],
|
||||
timeout_ms=1000,
|
||||
)
|
||||
)
|
||||
|
||||
assert not payload["success"]
|
||||
assert "第 2 个管道子代理任务执行失败" in payload["error"]
|
||||
assert calls == ["media-researcher", "download-diagnostician"]
|
||||
assert [task["status"] for task in payload["tasks"]] == [
|
||||
"completed",
|
||||
"failed",
|
||||
]
|
||||
assert "下载器不可用" in payload["tasks"][1]["error"]
|
||||
|
||||
asyncio.run(_run_test())
|
||||
|
||||
|
||||
def test_after_agent_cancels_unfinished_tasks():
|
||||
"""Agent 结束时应取消仍在运行的异步子代理任务。"""
|
||||
|
||||
async def _run_test():
|
||||
model = FakeListChatModel(responses=["ok"])
|
||||
middleware = SubAgentTaskControlMiddleware(
|
||||
model=model,
|
||||
@@ -269,4 +404,6 @@ class TestSubAgentTaskControlMiddleware(unittest.IsolatedAsyncioTestCase):
|
||||
)
|
||||
)
|
||||
|
||||
self.assertEqual("cancelled", status_payload["tasks"][0]["status"])
|
||||
assert status_payload["tasks"][0]["status"] == "cancelled"
|
||||
|
||||
asyncio.run(_run_test())
|
||||
|
||||
Reference in New Issue
Block a user