feat: ensure essential tools are always included in LLM tool selection and update tests

- Add mechanism to always include core tools (e.g., file operations, command execution) in LLMToolSelectorMiddleware
- Update MoviePilotToolFactory to provide filtered always-include tool names based on loaded tools
- Set default LLM_MAX_TOOLS to 30 in config
- Refactor agent initialization to support always_include parameter
- Enhance tests to cover always_include logic and async agent creation
This commit is contained in:
jxxghp
2026-04-30 13:04:52 +08:00
parent 28a2386f2f
commit 53f6897d62
4 changed files with 74 additions and 9 deletions

View File

@@ -407,6 +407,10 @@ class MoviePilotAgent:
# 工具列表
tools = self._initialize_tools()
max_tools = settings.LLM_MAX_TOOLS
always_include_tools = (
MoviePilotToolFactory.get_tool_selector_always_include_names(tools)
)
# 中间件
middlewares = [
@@ -438,11 +442,12 @@ class MoviePilotAgent:
]
# 工具选择
if settings.LLM_MAX_TOOLS > 0:
if max_tools > 0:
middlewares.append(
LLMToolSelectorMiddleware(
model=non_streaming_llm,
max_tools=settings.LLM_MAX_TOOLS,
max_tools=max_tools,
always_include=always_include_tools,
)
)

View File

@@ -87,6 +87,18 @@ class MoviePilotToolFactory:
MoviePilot工具工厂
"""
# 这些通用工具需要始终保留,避免大工具集裁剪后让 Agent 丢失基础的
# 文件系统、命令执行或交互确认能力。AskUserChoiceTool 仅在支持按钮
# 的渠道中才会实际注入,因此后续会再按已加载工具做一次求交集。
TOOL_SELECTOR_ALWAYS_INCLUDE_NAMES = (
"list_directory",
"write_file",
"read_file",
"edit_file",
"execute_command",
"ask_user_choice",
)
@staticmethod
def _should_enable_choice_tool(channel: str = None) -> bool:
if not channel:
@@ -99,6 +111,25 @@ class MoviePilotToolFactory:
message_channel
) and ChannelCapabilityManager.supports_callbacks(message_channel)
@classmethod
def get_tool_selector_always_include_names(
cls, tools: List[MoviePilotTool]
) -> List[str]:
"""
返回当前实际已加载且需要绕过工具筛选的工具名。
`LLMToolSelectorMiddleware` 会校验 `always_include` 中的工具名是否
存在于当前请求里,因此这里必须根据运行时工具列表做交集过滤。
"""
available_tool_names = {
tool.name for tool in tools if getattr(tool, "name", None)
}
return [
tool_name
for tool_name in cls.TOOL_SELECTOR_ALWAYS_INCLUDE_NAMES
if tool_name in available_tool_names
]
@staticmethod
def create_tools(
session_id: str,

View File

@@ -547,7 +547,7 @@ class ConfigModel(BaseModel):
# AI推荐条目数量限制
AI_RECOMMEND_MAX_ITEMS: int = 50
# LLM工具选择中间件最大工具数量0为不启用工具选择中间件
LLM_MAX_TOOLS: int = 0
LLM_MAX_TOOLS: int = 30
# AI智能体定时任务检查间隔小时0为不启用默认24小时
AI_AGENT_JOB_INTERVAL: int = 0
# AI智能体啰嗦模式开启后会回复工具调用过程

View File

@@ -1,3 +1,4 @@
import asyncio
import unittest
from unittest.mock import patch
@@ -35,8 +36,9 @@ class TestAgentSummarizationStreaming(unittest.TestCase):
agent_module.prompt_manager, "get_agent_prompt", return_value="prompt"
),
patch.object(agent_module, "create_agent", side_effect=_fake_create_agent),
patch.object(agent_module.settings, "LLM_MAX_TOOLS", 0),
):
agent._create_agent(streaming=True)
asyncio.run(agent._create_agent(streaming=True))
summary_middleware = next(
middleware
@@ -54,19 +56,33 @@ class TestAgentSummarizationStreaming(unittest.TestCase):
captured: dict = {}
class _FakeToolSelectorMiddleware:
def __init__(self, model, max_tools):
def __init__(self, model, max_tools, always_include=None):
self.model = model
self.max_tools = max_tools
self.always_include = always_include or []
def _fake_create_agent(**kwargs):
captured.update(kwargs)
return object()
class _FakeTool:
def __init__(self, name: str):
self.name = name
fake_tools = [
_FakeTool("list_directory"),
_FakeTool("write_file"),
_FakeTool("read_file"),
_FakeTool("edit_file"),
_FakeTool("execute_command"),
_FakeTool("search_media"),
]
with (
patch.object(
agent, "_initialize_llm", side_effect=[main_llm, non_streaming_llm]
),
patch.object(agent, "_initialize_tools", return_value=[]),
patch.object(agent, "_initialize_tools", return_value=fake_tools),
patch.object(
agent_module.prompt_manager, "get_agent_prompt", return_value="prompt"
),
@@ -78,7 +94,7 @@ class TestAgentSummarizationStreaming(unittest.TestCase):
patch.object(agent_module, "create_agent", side_effect=_fake_create_agent),
patch.object(agent_module.settings, "LLM_MAX_TOOLS", 3),
):
agent._create_agent(streaming=True)
asyncio.run(agent._create_agent(streaming=True))
tool_selector_middleware = next(
middleware
@@ -87,6 +103,17 @@ class TestAgentSummarizationStreaming(unittest.TestCase):
)
self.assertIs(tool_selector_middleware.model, non_streaming_llm)
self.assertEqual(tool_selector_middleware.max_tools, 3)
self.assertEqual(
tool_selector_middleware.always_include,
[
"list_directory",
"write_file",
"read_file",
"edit_file",
"execute_command",
],
)
def test_non_streaming_agent_reuses_main_llm_for_summary(self):
agent = agent_module.MoviePilotAgent(session_id="session-1", user_id="10001")
@@ -104,8 +131,9 @@ class TestAgentSummarizationStreaming(unittest.TestCase):
agent_module.prompt_manager, "get_agent_prompt", return_value="prompt"
),
patch.object(agent_module, "create_agent", side_effect=_fake_create_agent),
patch.object(agent_module.settings, "LLM_MAX_TOOLS", 0),
):
agent._create_agent(streaming=False)
asyncio.run(agent._create_agent(streaming=False))
summary_middleware = next(
middleware
@@ -132,8 +160,9 @@ class TestAgentSummarizationStreaming(unittest.TestCase):
agent_module.prompt_manager, "get_agent_prompt", return_value="prompt"
),
patch.object(agent_module, "create_agent", side_effect=_fake_create_agent),
patch.object(agent_module.settings, "LLM_MAX_TOOLS", 0),
):
agent._create_agent(streaming=False)
asyncio.run(agent._create_agent(streaming=False))
self.assertTrue(
any(