mirror of
https://github.com/jxxghp/MoviePilot.git
synced 2026-06-11 18:50:59 +08:00
feat: add async subagent task control
This commit is contained in:
@@ -1,3 +1,5 @@
|
||||
import asyncio
|
||||
import json
|
||||
import unittest
|
||||
from pathlib import Path
|
||||
from types import SimpleNamespace
|
||||
@@ -8,7 +10,9 @@ from langchain_core.language_models.fake_chat_models import FakeListChatModel
|
||||
import app.agent.middleware.subagents as subagent_module
|
||||
from app.agent.middleware.subagents import (
|
||||
MoviePilotSubAgentMiddleware,
|
||||
SUBAGENT_CONTROL_TOOL_NAME,
|
||||
SUBAGENT_TASK_TOOL_NAME,
|
||||
SubAgentTaskControlMiddleware,
|
||||
create_subagent_middlewares,
|
||||
)
|
||||
from app.agent.tools.tags import ToolTag
|
||||
@@ -25,10 +29,15 @@ class TestAgentSubagents(unittest.TestCase):
|
||||
stream_handler=None,
|
||||
)
|
||||
|
||||
self.assertEqual(len(middlewares), 2)
|
||||
self.assertEqual([tool.name for tool in task_tools], [SUBAGENT_TASK_TOOL_NAME])
|
||||
self.assertEqual(len(middlewares), 3)
|
||||
self.assertEqual(
|
||||
[tool.name for tool in task_tools],
|
||||
[SUBAGENT_TASK_TOOL_NAME, SUBAGENT_CONTROL_TOOL_NAME],
|
||||
)
|
||||
self.assertIn("media-researcher", task_tools[0].description)
|
||||
self.assertIn("system-diagnostician", task_tools[0].description)
|
||||
self.assertIn("action=start", task_tools[1].description)
|
||||
self.assertIn("action=wait", task_tools[1].description)
|
||||
|
||||
def test_subagent_tools_are_selected_by_tags(self):
|
||||
"""子代理应根据工具标签筛选工具,而不是依赖工具名名单。"""
|
||||
@@ -83,5 +92,109 @@ class TestAgentSubagents(unittest.TestCase):
|
||||
self.assertEqual([], missing_tools)
|
||||
|
||||
|
||||
class TestSubAgentTaskControlMiddleware(unittest.IsolatedAsyncioTestCase):
|
||||
async def test_control_tool_starts_tasks_concurrently_and_waits(self):
|
||||
"""异步子代理管控工具应批量启动任务,并在 wait 时收集结果。"""
|
||||
model = FakeListChatModel(responses=["ok"])
|
||||
middleware = SubAgentTaskControlMiddleware(
|
||||
model=model,
|
||||
profiles=subagent_module._builtin_subagent_profiles(),
|
||||
tools=[],
|
||||
)
|
||||
running_descriptions = []
|
||||
both_started = asyncio.Event()
|
||||
allow_finish = asyncio.Event()
|
||||
|
||||
async def _fake_run_task(self, *, description, subagent_type, task_id=None):
|
||||
running_descriptions.append(description)
|
||||
if len(running_descriptions) == 2:
|
||||
both_started.set()
|
||||
await allow_finish.wait()
|
||||
return f"{subagent_type}:{description}:{task_id}"
|
||||
|
||||
with patch.object(
|
||||
subagent_module._SubAgentAgentProvider,
|
||||
"run_task",
|
||||
new=_fake_run_task,
|
||||
):
|
||||
start_payload = json.loads(
|
||||
await middleware._control_task(
|
||||
action="start",
|
||||
tasks=[
|
||||
{
|
||||
"description": "检查媒体库",
|
||||
"subagent_type": "media-researcher",
|
||||
},
|
||||
{
|
||||
"description": "检查下载器",
|
||||
"subagent_type": "download-diagnostician",
|
||||
},
|
||||
],
|
||||
)
|
||||
)
|
||||
|
||||
await asyncio.wait_for(both_started.wait(), timeout=1)
|
||||
allow_finish.set()
|
||||
task_ids = [task["task_id"] for task in start_payload["tasks"]]
|
||||
wait_payload = json.loads(
|
||||
await middleware._control_task(
|
||||
action="wait",
|
||||
task_ids=task_ids,
|
||||
wait_mode="all",
|
||||
timeout_ms=1000,
|
||||
)
|
||||
)
|
||||
|
||||
self.assertTrue(start_payload["success"])
|
||||
self.assertEqual(2, len(task_ids))
|
||||
self.assertEqual(["检查媒体库", "检查下载器"], running_descriptions)
|
||||
self.assertEqual(
|
||||
["completed", "completed"],
|
||||
[task["status"] for task in wait_payload["tasks"]],
|
||||
)
|
||||
self.assertIn("media-researcher:检查媒体库", wait_payload["tasks"][0]["result"])
|
||||
self.assertIn(
|
||||
"download-diagnostician:检查下载器",
|
||||
wait_payload["tasks"][1]["result"],
|
||||
)
|
||||
|
||||
async def test_after_agent_cancels_unfinished_tasks(self):
|
||||
"""Agent 结束时应取消仍在运行的异步子代理任务。"""
|
||||
model = FakeListChatModel(responses=["ok"])
|
||||
middleware = SubAgentTaskControlMiddleware(
|
||||
model=model,
|
||||
profiles=subagent_module._builtin_subagent_profiles(),
|
||||
tools=[],
|
||||
)
|
||||
task_started = asyncio.Event()
|
||||
|
||||
async def _fake_run_task(self, *, description, subagent_type, task_id=None):
|
||||
task_started.set()
|
||||
await asyncio.Event().wait()
|
||||
|
||||
with patch.object(
|
||||
subagent_module._SubAgentAgentProvider,
|
||||
"run_task",
|
||||
new=_fake_run_task,
|
||||
):
|
||||
start_payload = json.loads(
|
||||
await middleware._control_task(
|
||||
action="start",
|
||||
description="长时间诊断",
|
||||
subagent_type="system-diagnostician",
|
||||
)
|
||||
)
|
||||
await asyncio.wait_for(task_started.wait(), timeout=1)
|
||||
await middleware.aafter_agent({}, None)
|
||||
status_payload = json.loads(
|
||||
await middleware._control_task(
|
||||
action="status",
|
||||
task_ids=[start_payload["tasks"][0]["task_id"]],
|
||||
)
|
||||
)
|
||||
|
||||
self.assertEqual("cancelled", status_payload["tasks"][0]["status"])
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
|
||||
Reference in New Issue
Block a user