feat: optimize tool selection middleware to cache and reuse tool selection per agent run

- Refactor MoviePilotToolSelectorMiddleware to perform tool selection once per agent execution and cache the result in state, avoiding redundant LLM calls for each model round.
- Add an abefore_agent hook that selects tools once at the start of agent execution and stores the selected tool names in state.
- Update awrap_model_call to reuse the cached tool selection from state on every subsequent model call.
- Enhance test coverage for tool selection caching and reuse logic.
- Improve error logging in skill version extraction.
This commit is contained in:
jxxghp
2026-04-30 18:29:54 +08:00
parent 2ea617655c
commit afcc071d07
5 changed files with 283 additions and 15 deletions

View File

@@ -56,10 +56,17 @@ class TestAgentSummarizationStreaming(unittest.TestCase):
captured: dict = {}
# Test double that the test patches in for agent_module's
# "MoviePilotToolSelectorMiddleware"; it performs no selection and only
# records the constructor arguments so the test can assert on them.
class _FakeToolSelectorMiddleware:
# NOTE(review): diff residue - the one-line signature below appears to be the
# pre-change version; it is superseded by the multi-line signature after it,
# which adds the optional `selection_tools` parameter.
def __init__(self, model, max_tools, always_include=None):
def __init__(
self,
model,
max_tools,
always_include=None,
selection_tools=None,
):
self.model = model
self.max_tools = max_tools
# A missing/None (or otherwise falsy) value is normalised to an empty list.
self.always_include = always_include or []
# Same normalisation as above: None/falsy becomes an empty list.
self.selection_tools = selection_tools or []
def _fake_create_agent(**kwargs):
captured.update(kwargs)
@@ -88,7 +95,7 @@ class TestAgentSummarizationStreaming(unittest.TestCase):
),
patch.object(
agent_module,
"LLMToolSelectorMiddleware",
"MoviePilotToolSelectorMiddleware",
_FakeToolSelectorMiddleware,
),
patch.object(agent_module, "create_agent", side_effect=_fake_create_agent),
@@ -114,6 +121,7 @@ class TestAgentSummarizationStreaming(unittest.TestCase):
"execute_command",
],
)
self.assertEqual(tool_selector_middleware.selection_tools, fake_tools)
def test_non_streaming_agent_reuses_main_llm_for_summary(self):
agent = agent_module.MoviePilotAgent(session_id="session-1", user_id="10001")

View File

@@ -4,6 +4,7 @@ import sys
import unittest
from pathlib import Path
from types import ModuleType, SimpleNamespace
from unittest.mock import patch
from langchain_core.messages import HumanMessage
@@ -22,6 +23,8 @@ sys.modules.pop("app.agent.middleware.tool_selection", None)
# Stub out "app.log" with inert replacements: a logger whose debug() accepts
# anything and returns None, a no-op log_settings callable, and an empty
# LogConfigModel class created on the fly.
# NOTE(review): _stub_module is defined outside this view - presumably it
# registers the stub in sys.modules before the module under test is loaded;
# confirm against its definition.
_stub_module(
"app.log",
logger=SimpleNamespace(debug=lambda *args, **kwargs: None),
log_settings=lambda *args, **kwargs: None,
LogConfigModel=type("LogConfigModel", (), {}),
)
module_path = (
@@ -70,16 +73,20 @@ class _FakeModel:
# Minimal stand-in for the model request object handed to the middleware in
# these tests. It carries tools, messages, model, state and runtime, and can
# produce modified copies of itself through override().
class _FakeRequest:
# NOTE(review): diff residue - the signature on the next line appears to be
# the pre-change version; the one after it (adding keyword-only `state` and
# `runtime`) is the current signature.
def __init__(self, *, tools, messages, model):
def __init__(self, *, tools, messages, model, state=None, runtime=None):
self.tools = tools
self.messages = messages
self.model = model
# When no state is supplied, build a fresh dict that exposes the messages
# under the "messages" key; an explicitly passed state (even an empty dict)
# is kept as-is.
self.state = state if state is not None else {"messages": messages}
self.runtime = runtime
# Return a new _FakeRequest whose fields are this instance's fields overlaid
# with `kwargs`; this instance is not modified. The state object is passed
# through by reference, so the copy shares it unless `state` is overridden.
def override(self, **kwargs):
data = {
"tools": self.tools,
"messages": self.messages,
"model": self.model,
"state": self.state,
"runtime": self.runtime,
}
data.update(kwargs)
return _FakeRequest(**data)
@@ -87,13 +94,17 @@ class _FakeRequest:
class ToolSelectorMiddlewareTest(unittest.TestCase):
def test_awrap_model_call_uses_json_mode_for_deepseek(self):
middleware = tool_selector_module.MoviePilotToolSelectorMiddleware(max_tools=2)
tools = [
SimpleNamespace(name="search", description="Search for information"),
SimpleNamespace(name="calendar", description="Manage events"),
SimpleNamespace(name="translate", description="Translate text"),
]
model = _FakeModel()
middleware = tool_selector_module.MoviePilotToolSelectorMiddleware(
max_tools=2,
selection_tools=tools,
)
middleware.model = model
request = _FakeRequest(
tools=tools,
messages=[HumanMessage(content="帮我安排明天的行程并查天气")],
@@ -105,6 +116,11 @@ class ToolSelectorMiddlewareTest(unittest.TestCase):
handled_requests.append(updated_request)
return updated_request
state_update = asyncio.run(
middleware.abefore_agent(request.state, runtime=None, config=None)
)
if state_update:
request.state.update(state_update)
result = asyncio.run(middleware.awrap_model_call(request, handler))
self.assertEqual(
@@ -121,6 +137,108 @@ class ToolSelectorMiddlewareTest(unittest.TestCase):
self.assertIn('- calendar: Manage events', prompt)
self.assertEqual(len(handled_requests), 1)
def test_awrap_model_call_reuses_first_selection_for_later_model_rounds(self):
    """Two model rounds after one ``abefore_agent`` call share one selection.

    The fake model answers with ``calendar`` then ``search``. The test
    expects exactly one JSON-mode bind on the model, both rounds to expose
    ``search`` and ``calendar`` (in that order), and the handler to have
    been invoked once per round.
    """
    available_tools = [
        SimpleNamespace(name=tool_name, description=tool_description)
        for tool_name, tool_description in (
            ("search", "Search for information"),
            ("calendar", "Manage events"),
            ("translate", "Translate text"),
        )
    ]
    fake_model = _FakeModel(content='{"tools": ["calendar", "search"]}')
    selector = tool_selector_module.MoviePilotToolSelectorMiddleware(
        max_tools=2,
        selection_tools=available_tools,
    )
    selector.model = fake_model
    request = _FakeRequest(
        tools=available_tools,
        messages=[HumanMessage(content="帮我安排明天的行程并查天气")],
        model=fake_model,
    )
    seen_requests = []

    async def handler(updated_request):
        seen_requests.append(updated_request)
        return updated_request

    # Run the agent-start hook and merge whatever update it returns into the
    # request state before any model round takes place.
    cached_selection = asyncio.run(
        selector.abefore_agent(request.state, runtime=None, config=None)
    )
    if cached_selection:
        request.state.update(cached_selection)

    # Two consecutive model rounds against the same request and state.
    round_results = [
        asyncio.run(selector.awrap_model_call(request, handler))
        for _ in range(2)
    ]

    self.assertEqual(
        fake_model.bind_calls,
        [{"response_format": {"type": "json_object"}}],
    )
    for round_result in round_results:
        self.assertEqual(
            [tool.name for tool in round_result.tools],
            ["search", "calendar"],
        )
    self.assertEqual(len(seen_requests), 2)
# Checks the caching path for a model configured as gpt-4o-mini on
# api.openai.com: with the parent LLMToolSelectorMiddleware.awrap_model_call
# replaced by a counting fake, one abefore_agent call followed by two
# awrap_model_call rounds must hit the parent exactly once, and both rounds
# must expose the tools in the order the fake parent produced
# (calendar, search).
def test_awrap_model_call_caches_non_deepseek_selection_too(self):
tools = [
SimpleNamespace(name="search", description="Search for information"),
SimpleNamespace(name="calendar", description="Manage events"),
SimpleNamespace(name="translate", description="Translate text"),
]
model = _FakeModel(
model_name="gpt-4o-mini",
base_url="https://api.openai.com/v1",
)
middleware = tool_selector_module.MoviePilotToolSelectorMiddleware(
max_tools=2,
selection_tools=tools,
)
middleware.model = model
request = _FakeRequest(
tools=tools,
messages=[HumanMessage(content="帮我安排明天的行程并查天气")],
model=model,
)
# Pass-through handler: hands the (possibly rewritten) request back so the
# test can inspect its tools.
async def handler(updated_request):
return updated_request
# Number of times the patched parent implementation has been entered.
parent_calls = 0
# Replacement for the parent's awrap_model_call: counts the invocation, then
# forwards a request whose tools are the second and first input tools, in
# that (swapped) order; the third tool is dropped.
async def _fake_parent_awrap(self, request_arg, handler_arg):
nonlocal parent_calls
parent_calls += 1
selected_request = request_arg.override(
tools=[request_arg.tools[1], request_arg.tools[0]]
)
return await handler_arg(selected_request)
# NOTE(review): indentation is lost in this view, so the extent of the
# `with` body is not visible - presumably it covers the abefore_agent call
# and both model rounds below; confirm against the real file.
with patch.object(
tool_selector_module.LLMToolSelectorMiddleware,
"awrap_model_call",
_fake_parent_awrap,
):
state_update = asyncio.run(
middleware.abefore_agent(request.state, runtime=None, config=None)
)
# Merge the hook's returned update (if any) into the request state.
if state_update:
request.state.update(state_update)
first_result = asyncio.run(middleware.awrap_model_call(request, handler))
second_result = asyncio.run(middleware.awrap_model_call(request, handler))
# The parent selector must have been consulted only once in total.
self.assertEqual(parent_calls, 1)
self.assertEqual(
[tool.name for tool in first_result.tools],
["calendar", "search"],
)
self.assertEqual(
[tool.name for tool in second_result.tools],
["calendar", "search"],
)
def test_normalize_selection_response_accepts_code_fence_json(self):
middleware = tool_selector_module.MoviePilotToolSelectorMiddleware()
response = SimpleNamespace(