Files
MoviePilot/tests/test_agent_summarization_streaming.py
2026-06-21 18:49:27 +08:00

212 lines
7.0 KiB
Python

import asyncio
from unittest.mock import patch
from langchain.agents.middleware import SummarizationMiddleware
import app.agent as agent_module
from app.agent.middleware.runtime_config import RuntimeConfigMiddleware
class _FakeLLM:
_llm_type = "openai-chat"
def __init__(self, model: str):
self.model = model
self.profile = {"max_input_tokens": 64000}
def test_streaming_agent_uses_non_streaming_llm_for_summary():
    """A streaming agent's summarization middleware uses a non-streaming LLM."""
    streaming_llm = _FakeLLM("main")
    summary_llm = _FakeLLM("non-streaming")
    create_agent_kwargs: dict = {}

    def _record_create_agent(**kwargs):
        """Record the keyword arguments passed to create_agent."""
        create_agent_kwargs.update(kwargs)
        return object()

    agent = agent_module.MoviePilotAgent(session_id="session-1", user_id="10001")

    with (
        # Two LLMs are handed out in call order: first "main", then
        # "non-streaming".
        patch.object(
            agent, "_initialize_llm", side_effect=[streaming_llm, summary_llm]
        ),
        patch.object(agent, "_initialize_tools", return_value=[]),
        patch.object(
            agent_module.prompt_manager, "get_agent_prompt", return_value="prompt"
        ),
        patch.object(
            agent_module, "create_subagent_middlewares", return_value=([], [])
        ),
        patch.object(agent_module, "create_agent", side_effect=_record_create_agent),
        patch.object(agent_module.settings, "LLM_MAX_TOOLS", 0),
    ):
        asyncio.run(agent._create_agent(streaming=True))

    middlewares = create_agent_kwargs["middleware"]
    summarizer = next(
        candidate
        for candidate in middlewares
        if isinstance(candidate, SummarizationMiddleware)
    )

    assert create_agent_kwargs["model"] is streaming_llm
    assert summarizer.model is summary_llm
def test_streaming_agent_uses_non_streaming_llm_for_model_middlewares():
    """Model-backed middlewares of a streaming agent use the non-streaming LLM."""

    class _FakeTool:
        """Placeholder tool object for the test."""

        def __init__(self, name: str):
            """Store the tool name."""
            self.name = name

    class _FakeToolSelectorMiddleware:
        """Record the arguments the tool-selector middleware is built with."""

        def __init__(
            self,
            model,
            max_tools,
            always_include=None,
            selection_tools=None,
        ):
            """Keep the arguments the assertions below inspect."""
            self.model = model
            self.max_tools = max_tools
            self.always_include = always_include or []
            self.selection_tools = selection_tools or []

    streaming_llm = _FakeLLM("main")
    selector_llm = _FakeLLM("non-streaming")
    create_agent_kwargs: dict = {}

    def _record_create_agent(**kwargs):
        """Record the keyword arguments passed to create_agent."""
        create_agent_kwargs.update(kwargs)
        return object()

    tool_names = (
        "list_directory",
        "write_file",
        "read_file",
        "edit_file",
        "execute_command",
        "search_media",
    )
    fake_tools = [_FakeTool(tool_name) for tool_name in tool_names]

    agent = agent_module.MoviePilotAgent(session_id="session-1", user_id="10001")

    with (
        # Two LLMs are handed out in call order: first "main", then
        # "non-streaming".
        patch.object(
            agent, "_initialize_llm", side_effect=[streaming_llm, selector_llm]
        ),
        patch.object(agent, "_initialize_tools", return_value=fake_tools),
        patch.object(
            agent_module.prompt_manager, "get_agent_prompt", return_value="prompt"
        ),
        patch.object(
            agent_module, "create_subagent_middlewares", return_value=([], [])
        ),
        patch.object(
            agent_module,
            "ToolSelectorMiddleware",
            _FakeToolSelectorMiddleware,
        ),
        patch.object(agent_module, "create_agent", side_effect=_record_create_agent),
        # Patched to 3; the other tests in this file patch 0 and look for no
        # tool selector.
        patch.object(agent_module.settings, "LLM_MAX_TOOLS", 3),
    ):
        asyncio.run(agent._create_agent(streaming=True))

    middlewares = create_agent_kwargs["middleware"]
    selector = next(
        candidate
        for candidate in middlewares
        if isinstance(candidate, _FakeToolSelectorMiddleware)
    )

    assert selector.model is selector_llm
    assert selector.max_tools == 3
    assert selector.always_include == [
        "list_directory",
        "write_file",
        "read_file",
        "edit_file",
        "execute_command",
        "skill",
    ]

    # The selectable tools must start with the initialised tools, in order,
    # followed by exactly one extra tool named "skill".
    base_count = len(fake_tools)
    offered_tools = selector.selection_tools
    assert offered_tools[:base_count] == fake_tools
    extra_tool_names = [
        getattr(extra_tool, "name", None) for extra_tool in offered_tools[base_count:]
    ]
    assert extra_tool_names == ["skill"]
def test_non_streaming_agent_reuses_main_llm_for_summary():
    """A non-streaming agent's summarization middleware reuses the main LLM."""
    shared_llm = _FakeLLM("main")
    create_agent_kwargs: dict = {}

    def _record_create_agent(**kwargs):
        """Record the keyword arguments passed to create_agent."""
        create_agent_kwargs.update(kwargs)
        return object()

    agent = agent_module.MoviePilotAgent(session_id="session-1", user_id="10001")

    with (
        # return_value (not side_effect): every initialisation call yields
        # the very same LLM object.
        patch.object(agent, "_initialize_llm", return_value=shared_llm),
        patch.object(agent, "_initialize_tools", return_value=[]),
        patch.object(
            agent_module.prompt_manager, "get_agent_prompt", return_value="prompt"
        ),
        patch.object(
            agent_module, "create_subagent_middlewares", return_value=([], [])
        ),
        patch.object(agent_module, "create_agent", side_effect=_record_create_agent),
        patch.object(agent_module.settings, "LLM_MAX_TOOLS", 0),
    ):
        asyncio.run(agent._create_agent(streaming=False))

    middlewares = create_agent_kwargs["middleware"]
    summarizer = next(
        candidate
        for candidate in middlewares
        if isinstance(candidate, SummarizationMiddleware)
    )

    assert create_agent_kwargs["model"] is shared_llm
    assert summarizer.model is shared_llm
def test_agent_uses_runtime_config_middleware_instead_of_hooks():
    """The agent uses the runtime-config middleware, not the legacy hooks."""
    shared_llm = _FakeLLM("main")
    create_agent_kwargs: dict = {}

    def _record_create_agent(**kwargs):
        """Record the keyword arguments passed to create_agent."""
        create_agent_kwargs.update(kwargs)
        return object()

    agent = agent_module.MoviePilotAgent(session_id="session-1", user_id="10001")

    with (
        patch.object(agent, "_initialize_llm", return_value=shared_llm),
        patch.object(agent, "_initialize_tools", return_value=[]),
        patch.object(
            agent_module.prompt_manager, "get_agent_prompt", return_value="prompt"
        ),
        patch.object(
            agent_module, "create_subagent_middlewares", return_value=([], [])
        ),
        patch.object(agent_module, "create_agent", side_effect=_record_create_agent),
        patch.object(agent_module.settings, "LLM_MAX_TOOLS", 0),
    ):
        asyncio.run(agent._create_agent(streaming=False))

    middlewares = create_agent_kwargs["middleware"]

    assert any(
        isinstance(candidate, RuntimeConfigMiddleware) for candidate in middlewares
    )

    # The legacy class is matched by name only; this file does not import it.
    middleware_type_names = [type(candidate).__name__ for candidate in middlewares]
    assert "AgentHooksMiddleware" not in middleware_type_names