mirror of
https://github.com/jxxghp/MoviePilot.git
synced 2026-07-05 14:51:28 +08:00
feat: enhance user permissions handling for admin and non-admin contexts
This commit is contained in:
317
tests/test_agent_resource_flow_permissions.py
Normal file
317
tests/test_agent_resource_flow_permissions.py
Normal file
@@ -0,0 +1,317 @@
|
||||
"""Agent 资源流程工具权限测试。"""
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
from types import SimpleNamespace
|
||||
from unittest.mock import AsyncMock, patch
|
||||
|
||||
from app.agent.tools.impl.edit_file import EditFileTool
|
||||
from app.agent.tools.impl.list_directory import ListDirectoryTool
|
||||
from app.agent.tools.impl.query_downloaders import QueryDownloadersTool
|
||||
from app.agent.tools.impl.query_sites import QuerySitesTool
|
||||
from app.agent.tools.impl.read_file import ReadFileTool
|
||||
from app.agent.tools.impl.send_local_file import SendLocalFileTool
|
||||
from app.agent.tools.impl.write_file import WriteFileTool
|
||||
from app.agent.tools.manager import MoviePilotToolsManager
|
||||
from app.agent import MoviePilotAgent
|
||||
from app.core.config import settings
|
||||
from app.schemas.types import MessageChannel
|
||||
|
||||
|
||||
def test_non_admin_manager_exposes_resource_flow_helper_tools():
|
||||
"""普通用户应能看到搜索、订阅、下载流程所需的辅助工具。"""
|
||||
site_tool = QuerySitesTool(session_id="session-1", user_id="10001")
|
||||
downloader_tool = QueryDownloadersTool(session_id="session-1", user_id="10001")
|
||||
|
||||
with patch(
|
||||
"app.agent.tools.manager.MoviePilotToolFactory.create_tools",
|
||||
return_value=[site_tool, downloader_tool],
|
||||
):
|
||||
manager = MoviePilotToolsManager(is_admin=False)
|
||||
|
||||
tool_names = {tool.name for tool in manager.list_tools()}
|
||||
assert "query_sites" in tool_names
|
||||
assert "query_downloaders" in tool_names
|
||||
|
||||
|
||||
def test_non_admin_manager_exposes_restricted_file_tools():
|
||||
"""普通用户应能看到受目录边界限制的文件读写工具。"""
|
||||
tools = [
|
||||
ReadFileTool(session_id="session-1", user_id="10001"),
|
||||
WriteFileTool(session_id="session-1", user_id="10001"),
|
||||
EditFileTool(session_id="session-1", user_id="10001"),
|
||||
ListDirectoryTool(session_id="session-1", user_id="10001"),
|
||||
]
|
||||
|
||||
with patch(
|
||||
"app.agent.tools.manager.MoviePilotToolFactory.create_tools",
|
||||
return_value=tools,
|
||||
):
|
||||
manager = MoviePilotToolsManager(is_admin=False)
|
||||
|
||||
tool_names = {tool.name for tool in manager.list_tools()}
|
||||
assert {"read_file", "write_file", "edit_file", "list_directory"} <= tool_names
|
||||
|
||||
|
||||
def test_query_sites_hides_only_sensitive_fields_for_non_admin_user():
|
||||
"""普通用户查询站点时只隐藏 Cookie、API Key、Token 和 RSS。"""
|
||||
tool = QuerySitesTool(session_id="session-1", user_id="10001")
|
||||
site = SimpleNamespace(
|
||||
id=1,
|
||||
name="TestSite",
|
||||
domain="secret.example",
|
||||
url="https://secret.example/",
|
||||
pri=1,
|
||||
rss="https://secret.example/rss",
|
||||
cookie="uid=1; passkey=secret",
|
||||
ua="SecretUA",
|
||||
apikey="site-api-key",
|
||||
token="site-token",
|
||||
proxy=1,
|
||||
filter="",
|
||||
render=0,
|
||||
public=0,
|
||||
note={"secret": True},
|
||||
limit_interval=0,
|
||||
limit_count=0,
|
||||
limit_seconds=0,
|
||||
timeout=15,
|
||||
is_active=True,
|
||||
downloader="qb",
|
||||
)
|
||||
|
||||
with patch(
|
||||
"app.agent.tools.impl.query_sites.SiteOper"
|
||||
) as site_oper:
|
||||
site_oper.return_value.async_list = AsyncMock(return_value=[site])
|
||||
result = asyncio.run(tool.run())
|
||||
|
||||
payload = json.loads(result)
|
||||
assert payload == [
|
||||
{
|
||||
"id": 1,
|
||||
"name": "TestSite",
|
||||
"domain": "secret.example",
|
||||
"url": "https://secret.example/",
|
||||
"pri": 1,
|
||||
"is_active": True,
|
||||
"downloader": "qb",
|
||||
"ua": "SecretUA",
|
||||
"proxy": 1,
|
||||
"filter": "",
|
||||
"render": 0,
|
||||
"public": 0,
|
||||
"note": {"secret": True},
|
||||
"limit_interval": 0,
|
||||
"limit_count": 0,
|
||||
"limit_seconds": 0,
|
||||
"timeout": 15,
|
||||
}
|
||||
]
|
||||
assert "cookie" not in payload[0]
|
||||
assert "rss" not in payload[0]
|
||||
assert "token" not in payload[0]
|
||||
assert "apikey" not in payload[0]
|
||||
|
||||
|
||||
def test_query_sites_keeps_full_fields_for_admin_context():
|
||||
"""管理员查询站点时保留完整配置视图。"""
|
||||
tool = QuerySitesTool(session_id="session-1", user_id="admin")
|
||||
tool.set_agent_context({"is_admin": True})
|
||||
site = SimpleNamespace(
|
||||
id=1,
|
||||
name="TestSite",
|
||||
domain="secret.example",
|
||||
url="https://secret.example/",
|
||||
pri=1,
|
||||
rss="https://secret.example/rss",
|
||||
cookie="uid=1; passkey=secret",
|
||||
ua="SecretUA",
|
||||
apikey="site-api-key",
|
||||
token="site-token",
|
||||
proxy=1,
|
||||
filter="",
|
||||
render=0,
|
||||
public=0,
|
||||
note={"secret": True},
|
||||
limit_interval=0,
|
||||
limit_count=0,
|
||||
limit_seconds=0,
|
||||
timeout=15,
|
||||
is_active=True,
|
||||
downloader="qb",
|
||||
)
|
||||
|
||||
with patch(
|
||||
"app.agent.tools.impl.query_sites.SiteOper"
|
||||
) as site_oper:
|
||||
site_oper.return_value.async_list = AsyncMock(return_value=[site])
|
||||
result = asyncio.run(tool.run())
|
||||
|
||||
payload = json.loads(result)
|
||||
assert payload[0]["cookie"] == "uid=1; passkey=secret"
|
||||
assert payload[0]["token"] == "site-token"
|
||||
assert payload[0]["apikey"] == "site-api-key"
|
||||
assert payload[0]["url"] == "https://secret.example/"
|
||||
|
||||
|
||||
def test_non_admin_file_tools_can_access_config_directory(tmp_path, monkeypatch):
|
||||
"""普通用户可在配置目录内读写和编辑文件。"""
|
||||
config_path = tmp_path / "config"
|
||||
monkeypatch.setattr(settings, "CONFIG_DIR", str(config_path))
|
||||
memory_path = settings.CONFIG_PATH / "agent" / "memory" / "MEMORY.md"
|
||||
|
||||
write_tool = WriteFileTool(session_id="session-1", user_id="10001")
|
||||
read_tool = ReadFileTool(session_id="session-1", user_id="10001")
|
||||
edit_tool = EditFileTool(session_id="session-1", user_id="10001")
|
||||
|
||||
write_result = asyncio.run(write_tool.run(str(memory_path), "hello"))
|
||||
read_result = asyncio.run(read_tool.run(str(memory_path)))
|
||||
edit_result = asyncio.run(edit_tool.run(str(memory_path), "hello", "hello mp"))
|
||||
edited_content = memory_path.read_text(encoding="utf-8")
|
||||
|
||||
assert "成功写入文件" in write_result
|
||||
assert read_result == "hello"
|
||||
assert "成功编辑文件" in edit_result
|
||||
assert edited_content == "hello mp"
|
||||
|
||||
|
||||
def test_non_admin_file_tools_block_paths_outside_allowed_roots(
|
||||
tmp_path, monkeypatch
|
||||
):
|
||||
"""普通用户不能通过文件工具访问配置、记忆和日志目录外的路径。"""
|
||||
config_path = tmp_path / "config"
|
||||
outside_path = tmp_path / "outside.txt"
|
||||
outside_path.write_text("secret", encoding="utf-8")
|
||||
monkeypatch.setattr(settings, "CONFIG_DIR", str(config_path))
|
||||
|
||||
read_tool = ReadFileTool(session_id="session-1", user_id="10001")
|
||||
write_tool = WriteFileTool(session_id="session-1", user_id="10001")
|
||||
edit_tool = EditFileTool(session_id="session-1", user_id="10001")
|
||||
list_tool = ListDirectoryTool(session_id="session-1", user_id="10001")
|
||||
send_tool = SendLocalFileTool(session_id="session-1", user_id="10001")
|
||||
send_tool.set_message_attr(
|
||||
channel=MessageChannel.Telegram.value,
|
||||
source="telegram-main",
|
||||
username="normal-user",
|
||||
)
|
||||
|
||||
read_result = asyncio.run(read_tool.run(str(outside_path)))
|
||||
write_result = asyncio.run(write_tool.run(str(outside_path), "changed"))
|
||||
edit_result = asyncio.run(edit_tool.run(str(outside_path), "secret", "changed"))
|
||||
with patch.object(ListDirectoryTool, "_list_directory_sync") as list_directory:
|
||||
list_result = asyncio.run(list_tool.run(str(tmp_path)))
|
||||
send_result = asyncio.run(send_tool.run(str(outside_path)))
|
||||
|
||||
assert "普通用户只能读取" in read_result
|
||||
assert "普通用户只能写入" in write_result
|
||||
assert "普通用户只能编辑" in edit_result
|
||||
assert "普通用户只能列出" in list_result
|
||||
assert "普通用户只能发送" in send_result
|
||||
assert outside_path.read_text(encoding="utf-8") == "secret"
|
||||
list_directory.assert_not_called()
|
||||
|
||||
|
||||
def test_admin_file_tool_can_access_paths_outside_allowed_roots(
|
||||
tmp_path, monkeypatch
|
||||
):
|
||||
"""管理员上下文不受普通用户文件访问边界限制。"""
|
||||
config_path = tmp_path / "config"
|
||||
outside_path = tmp_path / "outside.txt"
|
||||
monkeypatch.setattr(settings, "CONFIG_DIR", str(config_path))
|
||||
|
||||
tool = WriteFileTool(session_id="session-1", user_id="admin")
|
||||
tool.set_agent_context({"is_admin": True})
|
||||
|
||||
result = asyncio.run(tool.run(str(outside_path), "admin write"))
|
||||
|
||||
assert "成功写入文件" in result
|
||||
assert outside_path.read_text(encoding="utf-8") == "admin write"
|
||||
|
||||
|
||||
def test_query_downloaders_hides_sensitive_fields_for_non_admin_user():
|
||||
"""普通用户查询下载器时只返回选择下载器所需的安全字段。"""
|
||||
tool = QueryDownloadersTool(session_id="session-1", user_id="10001")
|
||||
downloaders = [
|
||||
{
|
||||
"name": "qb",
|
||||
"type": "qbittorrent",
|
||||
"enabled": True,
|
||||
"host": "http://127.0.0.1",
|
||||
"port": 8080,
|
||||
"username": "admin",
|
||||
"password": "secret",
|
||||
"apikey": "downloader-api-key",
|
||||
"token": "downloader-token",
|
||||
}
|
||||
]
|
||||
|
||||
with patch(
|
||||
"app.agent.tools.impl.query_downloaders.SystemConfigOper"
|
||||
) as system_config_oper:
|
||||
system_config_oper.return_value.get.return_value = downloaders
|
||||
result = asyncio.run(tool.run())
|
||||
|
||||
payload = json.loads(result)
|
||||
assert payload == [
|
||||
{
|
||||
"name": "qb",
|
||||
"type": "qbittorrent",
|
||||
"enabled": True,
|
||||
}
|
||||
]
|
||||
assert "host" not in payload[0]
|
||||
assert "username" not in payload[0]
|
||||
assert "password" not in payload[0]
|
||||
assert "apikey" not in payload[0]
|
||||
assert "token" not in payload[0]
|
||||
|
||||
|
||||
def test_query_downloaders_keeps_full_fields_for_admin_context():
|
||||
"""管理员查询下载器时保留完整配置视图。"""
|
||||
tool = QueryDownloadersTool(session_id="session-1", user_id="admin")
|
||||
tool.set_agent_context({"is_admin": True})
|
||||
downloaders = [
|
||||
{
|
||||
"name": "qb",
|
||||
"type": "qbittorrent",
|
||||
"enabled": True,
|
||||
"host": "http://127.0.0.1",
|
||||
"username": "admin",
|
||||
"password": "secret",
|
||||
"apikey": "downloader-api-key",
|
||||
}
|
||||
]
|
||||
|
||||
with patch(
|
||||
"app.agent.tools.impl.query_downloaders.SystemConfigOper"
|
||||
) as system_config_oper:
|
||||
system_config_oper.return_value.get.return_value = downloaders
|
||||
result = asyncio.run(tool.run())
|
||||
|
||||
payload = json.loads(result)
|
||||
assert payload[0]["host"] == "http://127.0.0.1"
|
||||
assert payload[0]["username"] == "admin"
|
||||
assert payload[0]["password"] == "secret"
|
||||
assert payload[0]["apikey"] == "downloader-api-key"
|
||||
|
||||
|
||||
def test_channel_agent_admin_user_id_does_not_bypass_user_lookup():
|
||||
"""渠道用户 ID 恰好为 admin 时,不应绕过真实系统用户权限判断。"""
|
||||
agent = MoviePilotAgent(
|
||||
session_id="session-1",
|
||||
user_id="admin",
|
||||
channel=MessageChannel.Telegram.value,
|
||||
source="telegram-main",
|
||||
username="normal-user",
|
||||
)
|
||||
|
||||
with patch("app.agent.UserOper") as user_oper:
|
||||
user_oper.return_value.async_get_by_name.return_value = SimpleNamespace(
|
||||
is_superuser=False
|
||||
)
|
||||
context = asyncio.run(
|
||||
agent._build_tool_context(should_dispatch_reply=True)
|
||||
)
|
||||
|
||||
assert context["is_admin"] is False
|
||||
@@ -1,6 +1,5 @@
|
||||
import asyncio
|
||||
import json
|
||||
import unittest
|
||||
from pathlib import Path
|
||||
from types import SimpleNamespace
|
||||
from unittest.mock import patch
|
||||
@@ -19,134 +18,132 @@ from app.agent.middleware.subagents import (
|
||||
from app.agent.tools.tags import ToolTag
|
||||
|
||||
|
||||
class TestAgentSubagents(unittest.TestCase):
|
||||
def test_create_subagent_middlewares_registers_task_tool(self):
|
||||
"""子代理中间件应向主 Agent 注册 task 委派工具。"""
|
||||
model = FakeListChatModel(responses=["ok"])
|
||||
def test_create_subagent_middlewares_registers_task_tool():
|
||||
"""子代理中间件应向主 Agent 注册 task 委派工具。"""
|
||||
model = FakeListChatModel(responses=["ok"])
|
||||
|
||||
middlewares, task_tools = create_subagent_middlewares(
|
||||
model=model,
|
||||
tools=[],
|
||||
stream_handler=None,
|
||||
)
|
||||
middlewares, task_tools = create_subagent_middlewares(
|
||||
model=model,
|
||||
tools=[],
|
||||
stream_handler=None,
|
||||
)
|
||||
|
||||
self.assertEqual(len(middlewares), 3)
|
||||
self.assertEqual(
|
||||
[tool.name for tool in task_tools],
|
||||
[SUBAGENT_TASK_TOOL_NAME, SUBAGENT_CONTROL_TOOL_NAME],
|
||||
)
|
||||
self.assertIn("media-researcher", task_tools[0].description)
|
||||
self.assertIn("moviepilot-explorer", task_tools[0].description)
|
||||
self.assertIn("system-diagnostician", task_tools[0].description)
|
||||
self.assertIn("action=start", task_tools[1].description)
|
||||
self.assertIn("action=wait", task_tools[1].description)
|
||||
|
||||
def test_subagent_tools_are_selected_by_tags(self):
|
||||
"""子代理应根据工具标签筛选工具,而不是依赖工具名名单。"""
|
||||
model = FakeListChatModel(responses=["ok"])
|
||||
tools = [
|
||||
SimpleNamespace(
|
||||
name="custom_media_lookup",
|
||||
tags=[ToolTag.Read.value, ToolTag.Media.value],
|
||||
),
|
||||
SimpleNamespace(
|
||||
name="custom_media_writer",
|
||||
tags=[ToolTag.Read.value, ToolTag.Write.value, ToolTag.Media.value],
|
||||
),
|
||||
SimpleNamespace(
|
||||
name="custom_site_lookup",
|
||||
tags=[ToolTag.Read.value, ToolTag.Site.value],
|
||||
),
|
||||
]
|
||||
captured = {}
|
||||
|
||||
def _fake_create_agent(**kwargs):
|
||||
captured.update(kwargs)
|
||||
return kwargs
|
||||
|
||||
middleware = MoviePilotSubAgentMiddleware(
|
||||
model=model,
|
||||
profiles=subagent_module._builtin_subagent_profiles(),
|
||||
tools=tools,
|
||||
)
|
||||
|
||||
with patch.object(subagent_module, "create_agent", side_effect=_fake_create_agent):
|
||||
middleware._get_agent("media-researcher")
|
||||
|
||||
self.assertEqual(
|
||||
[tool.name for tool in captured["tools"]],
|
||||
["custom_media_lookup"],
|
||||
)
|
||||
|
||||
def test_moviepilot_explorer_selects_code_and_settings_tools(self):
|
||||
"""MoviePilot 探索子代理应能读取代码、目录、设置和命令诊断工具。"""
|
||||
model = FakeListChatModel(responses=["ok"])
|
||||
tools = [
|
||||
SimpleNamespace(
|
||||
name="custom_code_reader",
|
||||
tags=[ToolTag.Read.value, ToolTag.File.value],
|
||||
),
|
||||
SimpleNamespace(
|
||||
name="custom_directory_lister",
|
||||
tags=[ToolTag.Read.value, ToolTag.Directory.value],
|
||||
),
|
||||
SimpleNamespace(
|
||||
name="custom_settings_reader",
|
||||
tags=[ToolTag.Read.value, ToolTag.Settings.value],
|
||||
),
|
||||
SimpleNamespace(
|
||||
name="custom_command_runner",
|
||||
tags=[ToolTag.Read.value, ToolTag.Command.value],
|
||||
),
|
||||
SimpleNamespace(
|
||||
name="custom_code_writer",
|
||||
tags=[ToolTag.Read.value, ToolTag.Write.value, ToolTag.File.value],
|
||||
),
|
||||
]
|
||||
captured = {}
|
||||
|
||||
def _fake_create_agent(**kwargs):
|
||||
captured.update(kwargs)
|
||||
return kwargs
|
||||
|
||||
middleware = MoviePilotSubAgentMiddleware(
|
||||
model=model,
|
||||
profiles=subagent_module._builtin_subagent_profiles(),
|
||||
tools=tools,
|
||||
)
|
||||
|
||||
with patch.object(subagent_module, "create_agent", side_effect=_fake_create_agent):
|
||||
middleware._get_agent("moviepilot-explorer")
|
||||
|
||||
self.assertEqual(
|
||||
[tool.name for tool in captured["tools"]],
|
||||
[
|
||||
"custom_code_reader",
|
||||
"custom_directory_lister",
|
||||
"custom_settings_reader",
|
||||
"custom_command_runner",
|
||||
],
|
||||
)
|
||||
|
||||
def test_builtin_tools_declare_tags_in_implementation(self):
|
||||
"""所有内置工具实现都应显式声明 tags。"""
|
||||
impl_dir = Path(__file__).resolve().parents[1] / "app" / "agent" / "tools" / "impl"
|
||||
missing_tools = []
|
||||
for path in sorted(impl_dir.glob("*.py")):
|
||||
text = path.read_text()
|
||||
for block in text.split("\nclass "):
|
||||
if "(MoviePilotTool)" not in block:
|
||||
continue
|
||||
class_name = block.split("(", 1)[0].strip()
|
||||
if "tags: list[str]" not in block:
|
||||
missing_tools.append(f"{path.name}:{class_name}")
|
||||
|
||||
self.assertEqual([], missing_tools)
|
||||
assert len(middlewares) == 3
|
||||
assert [tool.name for tool in task_tools] == [
|
||||
SUBAGENT_TASK_TOOL_NAME,
|
||||
SUBAGENT_CONTROL_TOOL_NAME,
|
||||
]
|
||||
assert "media-researcher" in task_tools[0].description
|
||||
assert "moviepilot-explorer" in task_tools[0].description
|
||||
assert "system-diagnostician" in task_tools[0].description
|
||||
assert "action=start" in task_tools[1].description
|
||||
assert "action=wait" in task_tools[1].description
|
||||
assert "action=pipeline" in task_tools[1].description
|
||||
|
||||
|
||||
class TestSubAgentTaskControlMiddleware(unittest.IsolatedAsyncioTestCase):
|
||||
async def test_call_summary_middleware_logs_subagent_tool_operations(self):
|
||||
"""子代理工具包装层应记录工具执行开始和完成日志。"""
|
||||
def test_subagent_tools_are_selected_by_tags():
|
||||
"""子代理应根据工具标签筛选工具,而不是依赖工具名名单。"""
|
||||
model = FakeListChatModel(responses=["ok"])
|
||||
tools = [
|
||||
SimpleNamespace(
|
||||
name="custom_media_lookup",
|
||||
tags=[ToolTag.Read.value, ToolTag.Media.value],
|
||||
),
|
||||
SimpleNamespace(
|
||||
name="custom_media_writer",
|
||||
tags=[ToolTag.Read.value, ToolTag.Write.value, ToolTag.Media.value],
|
||||
),
|
||||
SimpleNamespace(
|
||||
name="custom_site_lookup",
|
||||
tags=[ToolTag.Read.value, ToolTag.Site.value],
|
||||
),
|
||||
]
|
||||
captured = {}
|
||||
|
||||
def _fake_create_agent(**kwargs):
|
||||
captured.update(kwargs)
|
||||
return kwargs
|
||||
|
||||
middleware = MoviePilotSubAgentMiddleware(
|
||||
model=model,
|
||||
profiles=subagent_module._builtin_subagent_profiles(),
|
||||
tools=tools,
|
||||
)
|
||||
|
||||
with patch.object(subagent_module, "create_agent", side_effect=_fake_create_agent):
|
||||
middleware._get_agent("media-researcher")
|
||||
|
||||
assert [tool.name for tool in captured["tools"]] == ["custom_media_lookup"]
|
||||
|
||||
|
||||
def test_moviepilot_explorer_selects_code_and_settings_tools():
|
||||
"""MoviePilot 探索子代理应能读取代码、目录、设置和命令诊断工具。"""
|
||||
model = FakeListChatModel(responses=["ok"])
|
||||
tools = [
|
||||
SimpleNamespace(
|
||||
name="custom_code_reader",
|
||||
tags=[ToolTag.Read.value, ToolTag.File.value],
|
||||
),
|
||||
SimpleNamespace(
|
||||
name="custom_directory_lister",
|
||||
tags=[ToolTag.Read.value, ToolTag.Directory.value],
|
||||
),
|
||||
SimpleNamespace(
|
||||
name="custom_settings_reader",
|
||||
tags=[ToolTag.Read.value, ToolTag.Settings.value],
|
||||
),
|
||||
SimpleNamespace(
|
||||
name="custom_command_runner",
|
||||
tags=[ToolTag.Read.value, ToolTag.Command.value],
|
||||
),
|
||||
SimpleNamespace(
|
||||
name="custom_code_writer",
|
||||
tags=[ToolTag.Read.value, ToolTag.Write.value, ToolTag.File.value],
|
||||
),
|
||||
]
|
||||
captured = {}
|
||||
|
||||
def _fake_create_agent(**kwargs):
|
||||
captured.update(kwargs)
|
||||
return kwargs
|
||||
|
||||
middleware = MoviePilotSubAgentMiddleware(
|
||||
model=model,
|
||||
profiles=subagent_module._builtin_subagent_profiles(),
|
||||
tools=tools,
|
||||
)
|
||||
|
||||
with patch.object(subagent_module, "create_agent", side_effect=_fake_create_agent):
|
||||
middleware._get_agent("moviepilot-explorer")
|
||||
|
||||
assert [tool.name for tool in captured["tools"]] == [
|
||||
"custom_code_reader",
|
||||
"custom_directory_lister",
|
||||
"custom_settings_reader",
|
||||
"custom_command_runner",
|
||||
]
|
||||
|
||||
|
||||
def test_builtin_tools_declare_tags_in_implementation():
|
||||
"""所有内置工具实现都应显式声明 tags。"""
|
||||
impl_dir = Path(__file__).resolve().parents[1] / "app" / "agent" / "tools" / "impl"
|
||||
missing_tools = []
|
||||
for path in sorted(impl_dir.glob("*.py")):
|
||||
text = path.read_text()
|
||||
for block in text.split("\nclass "):
|
||||
if "(MoviePilotTool)" not in block:
|
||||
continue
|
||||
class_name = block.split("(", 1)[0].strip()
|
||||
if "tags: list[str]" not in block:
|
||||
missing_tools.append(f"{path.name}:{class_name}")
|
||||
|
||||
assert missing_tools == []
|
||||
|
||||
|
||||
def test_call_summary_middleware_logs_subagent_tool_operations():
|
||||
"""子代理工具包装层应记录工具执行开始和完成日志。"""
|
||||
|
||||
async def _run_test():
|
||||
middleware = SubAgentCallSummaryMiddleware()
|
||||
request = SimpleNamespace(
|
||||
tool=SimpleNamespace(name=SUBAGENT_CONTROL_TOOL_NAME),
|
||||
@@ -165,12 +162,17 @@ class TestSubAgentTaskControlMiddleware(unittest.IsolatedAsyncioTestCase):
|
||||
result = await middleware.awrap_tool_call(request, _fake_handler)
|
||||
|
||||
messages = [call.args[0] for call in log_info.call_args_list]
|
||||
self.assertEqual("ok", result)
|
||||
self.assertTrue(any("开始执行子代理工具" in message for message in messages))
|
||||
self.assertTrue(any("子代理工具执行完成" in message for message in messages))
|
||||
assert result == "ok"
|
||||
assert any("开始执行子代理工具" in message for message in messages)
|
||||
assert any("子代理工具执行完成" in message for message in messages)
|
||||
|
||||
async def test_control_tool_starts_tasks_concurrently_and_waits(self):
|
||||
"""异步子代理管控工具应批量启动任务,并在 wait 时收集结果。"""
|
||||
asyncio.run(_run_test())
|
||||
|
||||
|
||||
def test_control_tool_starts_tasks_concurrently_and_waits():
|
||||
"""异步子代理管控工具应批量启动任务,并在 wait 时收集结果。"""
|
||||
|
||||
async def _run_test():
|
||||
model = FakeListChatModel(responses=["ok"])
|
||||
middleware = SubAgentTaskControlMiddleware(
|
||||
model=model,
|
||||
@@ -221,21 +223,154 @@ class TestSubAgentTaskControlMiddleware(unittest.IsolatedAsyncioTestCase):
|
||||
)
|
||||
)
|
||||
|
||||
self.assertTrue(start_payload["success"])
|
||||
self.assertEqual(2, len(task_ids))
|
||||
self.assertEqual(["检查媒体库", "检查下载器"], running_descriptions)
|
||||
self.assertEqual(
|
||||
["completed", "completed"],
|
||||
[task["status"] for task in wait_payload["tasks"]],
|
||||
)
|
||||
self.assertIn("media-researcher:检查媒体库", wait_payload["tasks"][0]["result"])
|
||||
self.assertIn(
|
||||
"download-diagnostician:检查下载器",
|
||||
wait_payload["tasks"][1]["result"],
|
||||
assert start_payload["success"]
|
||||
assert len(task_ids) == 2
|
||||
assert running_descriptions == ["检查媒体库", "检查下载器"]
|
||||
assert [task["status"] for task in wait_payload["tasks"]] == [
|
||||
"completed",
|
||||
"completed",
|
||||
]
|
||||
assert "media-researcher:检查媒体库" in wait_payload["tasks"][0]["result"]
|
||||
assert (
|
||||
"download-diagnostician:检查下载器"
|
||||
in wait_payload["tasks"][1]["result"]
|
||||
)
|
||||
|
||||
async def test_after_agent_cancels_unfinished_tasks(self):
|
||||
"""Agent 结束时应取消仍在运行的异步子代理任务。"""
|
||||
asyncio.run(_run_test())
|
||||
|
||||
|
||||
def test_control_tool_pipeline_passes_previous_results_to_next_step():
|
||||
"""管道模式应顺序执行子代理,并把上一步结果作为下一步私有上下文。"""
|
||||
|
||||
async def _run_test():
|
||||
model = FakeListChatModel(responses=["ok"])
|
||||
middleware = SubAgentTaskControlMiddleware(
|
||||
model=model,
|
||||
profiles=subagent_module._builtin_subagent_profiles(),
|
||||
tools=[],
|
||||
)
|
||||
calls = []
|
||||
|
||||
async def _fake_run_task(self, *, description, subagent_type, task_id=None):
|
||||
calls.append(
|
||||
{
|
||||
"description": description,
|
||||
"subagent_type": subagent_type,
|
||||
"task_id": task_id,
|
||||
}
|
||||
)
|
||||
return f"结果-{len(calls)}"
|
||||
|
||||
with patch.object(
|
||||
subagent_module._SubAgentAgentProvider,
|
||||
"run_task",
|
||||
new=_fake_run_task,
|
||||
):
|
||||
payload = json.loads(
|
||||
await middleware._control_task(
|
||||
action="pipeline",
|
||||
tasks=[
|
||||
{
|
||||
"description": "识别媒体",
|
||||
"subagent_type": "media-researcher",
|
||||
},
|
||||
{
|
||||
"description": "检查下载",
|
||||
"subagent_type": "download-diagnostician",
|
||||
},
|
||||
{
|
||||
"description": "汇总结论",
|
||||
"subagent_type": "general-purpose",
|
||||
},
|
||||
],
|
||||
timeout_ms=1000,
|
||||
)
|
||||
)
|
||||
|
||||
assert payload["success"]
|
||||
assert [call["subagent_type"] for call in calls] == [
|
||||
"media-researcher",
|
||||
"download-diagnostician",
|
||||
"general-purpose",
|
||||
]
|
||||
assert calls[0]["description"] == "识别媒体"
|
||||
assert "结果-1" in calls[1]["description"]
|
||||
assert "结果-1" in calls[2]["description"]
|
||||
assert "结果-2" in calls[2]["description"]
|
||||
assert [task["status"] for task in payload["tasks"]] == [
|
||||
"completed",
|
||||
"completed",
|
||||
"completed",
|
||||
]
|
||||
assert [task["result"] for task in payload["tasks"]] == [
|
||||
"结果-1",
|
||||
"结果-2",
|
||||
"结果-3",
|
||||
]
|
||||
|
||||
asyncio.run(_run_test())
|
||||
|
||||
|
||||
def test_control_tool_pipeline_stops_after_failed_step():
|
||||
"""管道模式遇到失败步骤时应中断后续子代理。"""
|
||||
|
||||
async def _run_test():
|
||||
model = FakeListChatModel(responses=["ok"])
|
||||
middleware = SubAgentTaskControlMiddleware(
|
||||
model=model,
|
||||
profiles=subagent_module._builtin_subagent_profiles(),
|
||||
tools=[],
|
||||
)
|
||||
calls = []
|
||||
|
||||
async def _fake_run_task(self, *, description, subagent_type, task_id=None):
|
||||
calls.append(subagent_type)
|
||||
if subagent_type == "download-diagnostician":
|
||||
raise RuntimeError("下载器不可用")
|
||||
return f"{subagent_type}:ok"
|
||||
|
||||
with patch.object(
|
||||
subagent_module._SubAgentAgentProvider,
|
||||
"run_task",
|
||||
new=_fake_run_task,
|
||||
):
|
||||
payload = json.loads(
|
||||
await middleware._control_task(
|
||||
action="pipeline",
|
||||
tasks=[
|
||||
{
|
||||
"description": "识别媒体",
|
||||
"subagent_type": "media-researcher",
|
||||
},
|
||||
{
|
||||
"description": "检查下载",
|
||||
"subagent_type": "download-diagnostician",
|
||||
},
|
||||
{
|
||||
"description": "汇总结论",
|
||||
"subagent_type": "general-purpose",
|
||||
},
|
||||
],
|
||||
timeout_ms=1000,
|
||||
)
|
||||
)
|
||||
|
||||
assert not payload["success"]
|
||||
assert "第 2 个管道子代理任务执行失败" in payload["error"]
|
||||
assert calls == ["media-researcher", "download-diagnostician"]
|
||||
assert [task["status"] for task in payload["tasks"]] == [
|
||||
"completed",
|
||||
"failed",
|
||||
]
|
||||
assert "下载器不可用" in payload["tasks"][1]["error"]
|
||||
|
||||
asyncio.run(_run_test())
|
||||
|
||||
|
||||
def test_after_agent_cancels_unfinished_tasks():
|
||||
"""Agent 结束时应取消仍在运行的异步子代理任务。"""
|
||||
|
||||
async def _run_test():
|
||||
model = FakeListChatModel(responses=["ok"])
|
||||
middleware = SubAgentTaskControlMiddleware(
|
||||
model=model,
|
||||
@@ -269,4 +404,6 @@ class TestSubAgentTaskControlMiddleware(unittest.IsolatedAsyncioTestCase):
|
||||
)
|
||||
)
|
||||
|
||||
self.assertEqual("cancelled", status_payload["tasks"][0]["status"])
|
||||
assert status_payload["tasks"][0]["status"] == "cancelled"
|
||||
|
||||
asyncio.run(_run_test())
|
||||
|
||||
Reference in New Issue
Block a user