feat: add async subagent task control

This commit is contained in:
jxxghp
2026-05-31 21:55:25 +08:00
parent 1922cce499
commit 40d0b60aa2
4 changed files with 679 additions and 45 deletions

View File

@@ -13,6 +13,10 @@ from app.agent import (
_MessageTask,
)
from app.agent.memory import memory_manager
from app.agent.middleware.subagents import (
SUBAGENT_CONTROL_TOOL_NAME,
SUBAGENT_TASK_TOOL_NAME,
)
from app.agent.tools.factory import MoviePilotToolFactory
from app.core.config import settings
from app.utils.identity import SYSTEM_INTERNAL_USER_ID
@@ -355,6 +359,52 @@ class AgentBackgroundOutputTest(unittest.IsolatedAsyncioTestCase):
self.assertIn("send_message", always_include)
async def test_create_agent_always_includes_subagent_tools(self):
"""工具筛选开启时应保留同步和异步子代理入口。"""
captured = {}
agent = MoviePilotAgent(session_id="normal-session", user_id="system")
agent._initialize_tools = lambda: []
agent._initialize_subagent_tools = lambda: []
def _tool_selector(**kwargs):
captured["always_include"] = kwargs["always_include"]
return "selector"
with (
patch.object(settings, "LLM_MAX_TOOLS", 5),
patch.object(agent, "_initialize_llm", new=AsyncMock(return_value=object())),
patch("app.agent.prompt_manager.get_agent_prompt", return_value="PROMPT"),
patch(
"app.agent.create_subagent_middlewares",
return_value=(
["subagent"],
[
SimpleNamespace(name=SUBAGENT_TASK_TOOL_NAME),
SimpleNamespace(name=SUBAGENT_CONTROL_TOOL_NAME),
],
),
),
patch(
"app.agent.MoviePilotToolFactory.get_tool_selector_always_include_names",
return_value=[],
),
patch("app.agent.SkillsMiddleware", side_effect=lambda *args, **kwargs: "skills"),
patch("app.agent.JobsMiddleware", side_effect=lambda *args, **kwargs: "jobs"),
patch("app.agent.RuntimeConfigMiddleware", side_effect=lambda *args, **kwargs: "runtime"),
patch("app.agent.MemoryMiddleware", side_effect=lambda *args, **kwargs: "memory"),
patch("app.agent.ActivityLogMiddleware", side_effect=lambda *args, **kwargs: "activity"),
patch("app.agent.SummarizationMiddleware", side_effect=lambda *args, **kwargs: "summary"),
patch("app.agent.PatchToolCallsMiddleware", side_effect=lambda *args, **kwargs: "patch"),
patch("app.agent.UsageMiddleware", side_effect=lambda *args, **kwargs: "usage"),
patch("app.agent.ToolSelectorMiddleware", side_effect=_tool_selector),
patch("app.agent.InMemorySaver", return_value="checkpointer"),
patch("app.agent.create_agent", side_effect=lambda **kwargs: kwargs),
):
await agent._create_agent(streaming=False)
self.assertIn(SUBAGENT_TASK_TOOL_NAME, captured["always_include"])
self.assertIn(SUBAGENT_CONTROL_TOOL_NAME, captured["always_include"])
async def test_create_agent_keeps_activity_log_for_normal_session(self):
agent = MoviePilotAgent(session_id="normal-session", user_id="system")
agent._initialize_tools = lambda: []

View File

@@ -1,3 +1,5 @@
import asyncio
import json
import unittest
from pathlib import Path
from types import SimpleNamespace
@@ -8,7 +10,9 @@ from langchain_core.language_models.fake_chat_models import FakeListChatModel
import app.agent.middleware.subagents as subagent_module
from app.agent.middleware.subagents import (
MoviePilotSubAgentMiddleware,
SUBAGENT_CONTROL_TOOL_NAME,
SUBAGENT_TASK_TOOL_NAME,
SubAgentTaskControlMiddleware,
create_subagent_middlewares,
)
from app.agent.tools.tags import ToolTag
@@ -25,10 +29,15 @@ class TestAgentSubagents(unittest.TestCase):
stream_handler=None,
)
self.assertEqual(len(middlewares), 2)
self.assertEqual([tool.name for tool in task_tools], [SUBAGENT_TASK_TOOL_NAME])
self.assertEqual(len(middlewares), 3)
self.assertEqual(
[tool.name for tool in task_tools],
[SUBAGENT_TASK_TOOL_NAME, SUBAGENT_CONTROL_TOOL_NAME],
)
self.assertIn("media-researcher", task_tools[0].description)
self.assertIn("system-diagnostician", task_tools[0].description)
self.assertIn("action=start", task_tools[1].description)
self.assertIn("action=wait", task_tools[1].description)
def test_subagent_tools_are_selected_by_tags(self):
"""子代理应根据工具标签筛选工具,而不是依赖工具名名单。"""
@@ -83,5 +92,109 @@ class TestAgentSubagents(unittest.TestCase):
self.assertEqual([], missing_tools)
class TestSubAgentTaskControlMiddleware(unittest.IsolatedAsyncioTestCase):
async def test_control_tool_starts_tasks_concurrently_and_waits(self):
"""异步子代理管控工具应批量启动任务,并在 wait 时收集结果。"""
model = FakeListChatModel(responses=["ok"])
middleware = SubAgentTaskControlMiddleware(
model=model,
profiles=subagent_module._builtin_subagent_profiles(),
tools=[],
)
running_descriptions = []
both_started = asyncio.Event()
allow_finish = asyncio.Event()
async def _fake_run_task(self, *, description, subagent_type, task_id=None):
running_descriptions.append(description)
if len(running_descriptions) == 2:
both_started.set()
await allow_finish.wait()
return f"{subagent_type}:{description}:{task_id}"
with patch.object(
subagent_module._SubAgentAgentProvider,
"run_task",
new=_fake_run_task,
):
start_payload = json.loads(
await middleware._control_task(
action="start",
tasks=[
{
"description": "检查媒体库",
"subagent_type": "media-researcher",
},
{
"description": "检查下载器",
"subagent_type": "download-diagnostician",
},
],
)
)
await asyncio.wait_for(both_started.wait(), timeout=1)
allow_finish.set()
task_ids = [task["task_id"] for task in start_payload["tasks"]]
wait_payload = json.loads(
await middleware._control_task(
action="wait",
task_ids=task_ids,
wait_mode="all",
timeout_ms=1000,
)
)
self.assertTrue(start_payload["success"])
self.assertEqual(2, len(task_ids))
self.assertEqual(["检查媒体库", "检查下载器"], running_descriptions)
self.assertEqual(
["completed", "completed"],
[task["status"] for task in wait_payload["tasks"]],
)
self.assertIn("media-researcher:检查媒体库", wait_payload["tasks"][0]["result"])
self.assertIn(
"download-diagnostician:检查下载器",
wait_payload["tasks"][1]["result"],
)
async def test_after_agent_cancels_unfinished_tasks(self):
"""Agent 结束时应取消仍在运行的异步子代理任务。"""
model = FakeListChatModel(responses=["ok"])
middleware = SubAgentTaskControlMiddleware(
model=model,
profiles=subagent_module._builtin_subagent_profiles(),
tools=[],
)
task_started = asyncio.Event()
async def _fake_run_task(self, *, description, subagent_type, task_id=None):
task_started.set()
await asyncio.Event().wait()
with patch.object(
subagent_module._SubAgentAgentProvider,
"run_task",
new=_fake_run_task,
):
start_payload = json.loads(
await middleware._control_task(
action="start",
description="长时间诊断",
subagent_type="system-diagnostician",
)
)
await asyncio.wait_for(task_started.wait(), timeout=1)
await middleware.aafter_agent({}, None)
status_payload = json.loads(
await middleware._control_task(
action="status",
task_ids=[start_payload["tasks"][0]["task_id"]],
)
)
self.assertEqual("cancelled", status_payload["tasks"][0]["status"])
if __name__ == "__main__":
unittest.main()