mirror of
https://github.com/jxxghp/MoviePilot.git
synced 2026-07-02 05:11:31 +08:00
feat: ensure essential tools are always included in LLM tool selection and update tests
- Add mechanism to always include core tools (e.g., file operations, command execution) in LLMToolSelectorMiddleware - Update MoviePilotToolFactory to provide filtered always-include tool names based on loaded tools - Set default LLM_MAX_TOOLS to 30 in config - Refactor agent initialization to support always_include parameter - Enhance tests to cover always_include logic and async agent creation
This commit is contained in:
@@ -407,6 +407,10 @@ class MoviePilotAgent:
|
||||
|
||||
# 工具列表
|
||||
tools = self._initialize_tools()
|
||||
max_tools = settings.LLM_MAX_TOOLS
|
||||
always_include_tools = (
|
||||
MoviePilotToolFactory.get_tool_selector_always_include_names(tools)
|
||||
)
|
||||
|
||||
# 中间件
|
||||
middlewares = [
|
||||
@@ -438,11 +442,12 @@ class MoviePilotAgent:
|
||||
]
|
||||
|
||||
# 工具选择
|
||||
if settings.LLM_MAX_TOOLS > 0:
|
||||
if max_tools > 0:
|
||||
middlewares.append(
|
||||
LLMToolSelectorMiddleware(
|
||||
model=non_streaming_llm,
|
||||
max_tools=settings.LLM_MAX_TOOLS,
|
||||
max_tools=max_tools,
|
||||
always_include=always_include_tools,
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
@@ -87,6 +87,18 @@ class MoviePilotToolFactory:
|
||||
MoviePilot工具工厂
|
||||
"""
|
||||
|
||||
# 这些通用工具需要始终保留,避免大工具集裁剪后让 Agent 丢失基础的
|
||||
# 文件系统、命令执行或交互确认能力。AskUserChoiceTool 仅在支持按钮
|
||||
# 的渠道中才会实际注入,因此后续会再按已加载工具做一次求交集。
|
||||
TOOL_SELECTOR_ALWAYS_INCLUDE_NAMES = (
|
||||
"list_directory",
|
||||
"write_file",
|
||||
"read_file",
|
||||
"edit_file",
|
||||
"execute_command",
|
||||
"ask_user_choice",
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def _should_enable_choice_tool(channel: str = None) -> bool:
|
||||
if not channel:
|
||||
@@ -99,6 +111,25 @@ class MoviePilotToolFactory:
|
||||
message_channel
|
||||
) and ChannelCapabilityManager.supports_callbacks(message_channel)
|
||||
|
||||
@classmethod
|
||||
def get_tool_selector_always_include_names(
|
||||
cls, tools: List[MoviePilotTool]
|
||||
) -> List[str]:
|
||||
"""
|
||||
返回当前实际已加载且需要绕过工具筛选的工具名。
|
||||
|
||||
`LLMToolSelectorMiddleware` 会校验 `always_include` 中的工具名是否
|
||||
存在于当前请求里,因此这里必须根据运行时工具列表做交集过滤。
|
||||
"""
|
||||
available_tool_names = {
|
||||
tool.name for tool in tools if getattr(tool, "name", None)
|
||||
}
|
||||
return [
|
||||
tool_name
|
||||
for tool_name in cls.TOOL_SELECTOR_ALWAYS_INCLUDE_NAMES
|
||||
if tool_name in available_tool_names
|
||||
]
|
||||
|
||||
@staticmethod
|
||||
def create_tools(
|
||||
session_id: str,
|
||||
|
||||
@@ -547,7 +547,7 @@ class ConfigModel(BaseModel):
|
||||
# AI推荐条目数量限制
|
||||
AI_RECOMMEND_MAX_ITEMS: int = 50
|
||||
# LLM工具选择中间件最大工具数量,0为不启用工具选择中间件
|
||||
LLM_MAX_TOOLS: int = 0
|
||||
LLM_MAX_TOOLS: int = 30
|
||||
# AI智能体定时任务检查间隔(小时),0为不启用,默认24小时
|
||||
AI_AGENT_JOB_INTERVAL: int = 0
|
||||
# AI智能体啰嗦模式,开启后会回复工具调用过程
|
||||
|
||||
@@ -1,3 +1,4 @@
|
||||
import asyncio
|
||||
import unittest
|
||||
from unittest.mock import patch
|
||||
|
||||
@@ -35,8 +36,9 @@ class TestAgentSummarizationStreaming(unittest.TestCase):
|
||||
agent_module.prompt_manager, "get_agent_prompt", return_value="prompt"
|
||||
),
|
||||
patch.object(agent_module, "create_agent", side_effect=_fake_create_agent),
|
||||
patch.object(agent_module.settings, "LLM_MAX_TOOLS", 0),
|
||||
):
|
||||
agent._create_agent(streaming=True)
|
||||
asyncio.run(agent._create_agent(streaming=True))
|
||||
|
||||
summary_middleware = next(
|
||||
middleware
|
||||
@@ -54,19 +56,33 @@ class TestAgentSummarizationStreaming(unittest.TestCase):
|
||||
captured: dict = {}
|
||||
|
||||
class _FakeToolSelectorMiddleware:
|
||||
def __init__(self, model, max_tools):
|
||||
def __init__(self, model, max_tools, always_include=None):
|
||||
self.model = model
|
||||
self.max_tools = max_tools
|
||||
self.always_include = always_include or []
|
||||
|
||||
def _fake_create_agent(**kwargs):
|
||||
captured.update(kwargs)
|
||||
return object()
|
||||
|
||||
class _FakeTool:
|
||||
def __init__(self, name: str):
|
||||
self.name = name
|
||||
|
||||
fake_tools = [
|
||||
_FakeTool("list_directory"),
|
||||
_FakeTool("write_file"),
|
||||
_FakeTool("read_file"),
|
||||
_FakeTool("edit_file"),
|
||||
_FakeTool("execute_command"),
|
||||
_FakeTool("search_media"),
|
||||
]
|
||||
|
||||
with (
|
||||
patch.object(
|
||||
agent, "_initialize_llm", side_effect=[main_llm, non_streaming_llm]
|
||||
),
|
||||
patch.object(agent, "_initialize_tools", return_value=[]),
|
||||
patch.object(agent, "_initialize_tools", return_value=fake_tools),
|
||||
patch.object(
|
||||
agent_module.prompt_manager, "get_agent_prompt", return_value="prompt"
|
||||
),
|
||||
@@ -78,7 +94,7 @@ class TestAgentSummarizationStreaming(unittest.TestCase):
|
||||
patch.object(agent_module, "create_agent", side_effect=_fake_create_agent),
|
||||
patch.object(agent_module.settings, "LLM_MAX_TOOLS", 3),
|
||||
):
|
||||
agent._create_agent(streaming=True)
|
||||
asyncio.run(agent._create_agent(streaming=True))
|
||||
|
||||
tool_selector_middleware = next(
|
||||
middleware
|
||||
@@ -87,6 +103,17 @@ class TestAgentSummarizationStreaming(unittest.TestCase):
|
||||
)
|
||||
|
||||
self.assertIs(tool_selector_middleware.model, non_streaming_llm)
|
||||
self.assertEqual(tool_selector_middleware.max_tools, 3)
|
||||
self.assertEqual(
|
||||
tool_selector_middleware.always_include,
|
||||
[
|
||||
"list_directory",
|
||||
"write_file",
|
||||
"read_file",
|
||||
"edit_file",
|
||||
"execute_command",
|
||||
],
|
||||
)
|
||||
|
||||
def test_non_streaming_agent_reuses_main_llm_for_summary(self):
|
||||
agent = agent_module.MoviePilotAgent(session_id="session-1", user_id="10001")
|
||||
@@ -104,8 +131,9 @@ class TestAgentSummarizationStreaming(unittest.TestCase):
|
||||
agent_module.prompt_manager, "get_agent_prompt", return_value="prompt"
|
||||
),
|
||||
patch.object(agent_module, "create_agent", side_effect=_fake_create_agent),
|
||||
patch.object(agent_module.settings, "LLM_MAX_TOOLS", 0),
|
||||
):
|
||||
agent._create_agent(streaming=False)
|
||||
asyncio.run(agent._create_agent(streaming=False))
|
||||
|
||||
summary_middleware = next(
|
||||
middleware
|
||||
@@ -132,8 +160,9 @@ class TestAgentSummarizationStreaming(unittest.TestCase):
|
||||
agent_module.prompt_manager, "get_agent_prompt", return_value="prompt"
|
||||
),
|
||||
patch.object(agent_module, "create_agent", side_effect=_fake_create_agent),
|
||||
patch.object(agent_module.settings, "LLM_MAX_TOOLS", 0),
|
||||
):
|
||||
agent._create_agent(streaming=False)
|
||||
asyncio.run(agent._create_agent(streaming=False))
|
||||
|
||||
self.assertTrue(
|
||||
any(
|
||||
|
||||
Reference in New Issue
Block a user