mirror of
https://github.com/jxxghp/MoviePilot.git
synced 2026-05-12 04:59:39 +08:00
feat: optimize tool selection middleware to cache and reuse tool selection per agent run
- Refactor MoviePilotToolSelectorMiddleware to perform tool selection once per agent execution and cache the result in state, avoiding a redundant LLM call on every model round.
- Add abefore_agent to select tools at the start of agent execution and store the selected tool names in state.
- Update awrap_model_call to reuse the cached tool selection from state for subsequent model calls.
- Enhance test coverage for the tool selection caching and reuse logic.
- Improve error logging in skill version extraction.
This commit is contained in:
@@ -56,10 +56,17 @@ class TestAgentSummarizationStreaming(unittest.TestCase):
|
||||
captured: dict = {}
|
||||
|
||||
class _FakeToolSelectorMiddleware:
|
||||
def __init__(self, model, max_tools, always_include=None):
|
||||
def __init__(
|
||||
self,
|
||||
model,
|
||||
max_tools,
|
||||
always_include=None,
|
||||
selection_tools=None,
|
||||
):
|
||||
self.model = model
|
||||
self.max_tools = max_tools
|
||||
self.always_include = always_include or []
|
||||
self.selection_tools = selection_tools or []
|
||||
|
||||
def _fake_create_agent(**kwargs):
|
||||
captured.update(kwargs)
|
||||
@@ -88,7 +95,7 @@ class TestAgentSummarizationStreaming(unittest.TestCase):
|
||||
),
|
||||
patch.object(
|
||||
agent_module,
|
||||
"LLMToolSelectorMiddleware",
|
||||
"MoviePilotToolSelectorMiddleware",
|
||||
_FakeToolSelectorMiddleware,
|
||||
),
|
||||
patch.object(agent_module, "create_agent", side_effect=_fake_create_agent),
|
||||
@@ -114,6 +121,7 @@ class TestAgentSummarizationStreaming(unittest.TestCase):
|
||||
"execute_command",
|
||||
],
|
||||
)
|
||||
self.assertEqual(tool_selector_middleware.selection_tools, fake_tools)
|
||||
|
||||
def test_non_streaming_agent_reuses_main_llm_for_summary(self):
|
||||
agent = agent_module.MoviePilotAgent(session_id="session-1", user_id="10001")
|
||||
|
||||
@@ -4,6 +4,7 @@ import sys
|
||||
import unittest
|
||||
from pathlib import Path
|
||||
from types import ModuleType, SimpleNamespace
|
||||
from unittest.mock import patch
|
||||
|
||||
from langchain_core.messages import HumanMessage
|
||||
|
||||
@@ -22,6 +23,8 @@ sys.modules.pop("app.agent.middleware.tool_selection", None)
|
||||
# Provide a lightweight replacement for ``app.log`` before the module under
# test is loaded. NOTE(review): ``_stub_module`` is defined outside this view;
# presumably it registers a fake module exposing the given attributes --
# confirm against its definition.
_stub_module(
    "app.log",
    # No-op debug logger: accepts any arguments and discards them.
    logger=SimpleNamespace(debug=lambda *args, **kwargs: None),
    # No-op callable standing in for the real ``log_settings``.
    log_settings=lambda *args, **kwargs: None,
    # Empty placeholder class so a ``LogConfigModel`` attribute exists.
    LogConfigModel=type("LogConfigModel", (), {}),
)
|
||||
|
||||
module_path = (
|
||||
@@ -70,16 +73,20 @@ class _FakeModel:
|
||||
|
||||
|
||||
class _FakeRequest:
|
||||
def __init__(self, *, tools, messages, model):
|
||||
def __init__(self, *, tools, messages, model, state=None, runtime=None):
|
||||
self.tools = tools
|
||||
self.messages = messages
|
||||
self.model = model
|
||||
self.state = state if state is not None else {"messages": messages}
|
||||
self.runtime = runtime
|
||||
|
||||
def override(self, **kwargs):
|
||||
data = {
|
||||
"tools": self.tools,
|
||||
"messages": self.messages,
|
||||
"model": self.model,
|
||||
"state": self.state,
|
||||
"runtime": self.runtime,
|
||||
}
|
||||
data.update(kwargs)
|
||||
return _FakeRequest(**data)
|
||||
@@ -87,13 +94,17 @@ class _FakeRequest:
|
||||
|
||||
class ToolSelectorMiddlewareTest(unittest.TestCase):
|
||||
def test_awrap_model_call_uses_json_mode_for_deepseek(self):
|
||||
middleware = tool_selector_module.MoviePilotToolSelectorMiddleware(max_tools=2)
|
||||
tools = [
|
||||
SimpleNamespace(name="search", description="Search for information"),
|
||||
SimpleNamespace(name="calendar", description="Manage events"),
|
||||
SimpleNamespace(name="translate", description="Translate text"),
|
||||
]
|
||||
model = _FakeModel()
|
||||
middleware = tool_selector_module.MoviePilotToolSelectorMiddleware(
|
||||
max_tools=2,
|
||||
selection_tools=tools,
|
||||
)
|
||||
middleware.model = model
|
||||
request = _FakeRequest(
|
||||
tools=tools,
|
||||
messages=[HumanMessage(content="帮我安排明天的行程并查天气")],
|
||||
@@ -105,6 +116,11 @@ class ToolSelectorMiddlewareTest(unittest.TestCase):
|
||||
handled_requests.append(updated_request)
|
||||
return updated_request
|
||||
|
||||
state_update = asyncio.run(
|
||||
middleware.abefore_agent(request.state, runtime=None, config=None)
|
||||
)
|
||||
if state_update:
|
||||
request.state.update(state_update)
|
||||
result = asyncio.run(middleware.awrap_model_call(request, handler))
|
||||
|
||||
self.assertEqual(
|
||||
@@ -121,6 +137,108 @@ class ToolSelectorMiddlewareTest(unittest.TestCase):
|
||||
self.assertIn('- calendar: Manage events', prompt)
|
||||
self.assertEqual(len(handled_requests), 1)
|
||||
|
||||
def test_awrap_model_call_reuses_first_selection_for_later_model_rounds(self):
    """Tool selection made in ``abefore_agent`` is reused by later model calls.

    Runs ``abefore_agent`` once, merges its state update into the request
    state, then calls ``awrap_model_call`` twice with the same request.
    Asserts that the fake model recorded exactly one JSON-mode bind and that
    both rounds pass the same two tools to the handler.
    """
    tools = [
        SimpleNamespace(name="search", description="Search for information"),
        SimpleNamespace(name="calendar", description="Manage events"),
        SimpleNamespace(name="translate", description="Translate text"),
    ]
    # The fake model replies with "calendar" before "search"; the assertions
    # below expect the tools in "search", "calendar" order.
    model = _FakeModel(content='{"tools": ["calendar", "search"]}')
    middleware = tool_selector_module.MoviePilotToolSelectorMiddleware(
        max_tools=2,
        selection_tools=tools,
    )
    middleware.model = model
    request = _FakeRequest(
        tools=tools,
        messages=[HumanMessage(content="帮我安排明天的行程并查天气")],
        model=model,
    )
    handled_requests = []

    async def handler(updated_request):
        # Record every request that reaches the handler and echo it back so
        # the test can inspect the tools that were passed through.
        handled_requests.append(updated_request)
        return updated_request

    state_update = asyncio.run(
        middleware.abefore_agent(request.state, runtime=None, config=None)
    )
    # Apply the returned update by hand so the request state carries the
    # selection into the model calls below.
    if state_update:
        request.state.update(state_update)
    first_result = asyncio.run(middleware.awrap_model_call(request, handler))
    second_result = asyncio.run(middleware.awrap_model_call(request, handler))

    # Exactly one bind across the whole run.
    self.assertEqual(
        model.bind_calls,
        [{"response_format": {"type": "json_object"}}],
    )
    self.assertEqual(
        [tool.name for tool in first_result.tools],
        ["search", "calendar"],
    )
    self.assertEqual(
        [tool.name for tool in second_result.tools],
        ["search", "calendar"],
    )
    # Both model rounds still reached the handler.
    self.assertEqual(len(handled_requests), 2)
|
||||
|
||||
def test_awrap_model_call_caches_non_deepseek_selection_too(self):
    """Selection is cached for a non-DeepSeek model configuration as well.

    Uses a fake model configured as ``gpt-4o-mini`` on the OpenAI base URL
    and patches ``LLMToolSelectorMiddleware.awrap_model_call`` with a
    counting fake. After ``abefore_agent`` and two ``awrap_model_call``
    rounds, asserts the patched parent ran exactly once and both rounds
    received the same tools.
    """
    tools = [
        SimpleNamespace(name="search", description="Search for information"),
        SimpleNamespace(name="calendar", description="Manage events"),
        SimpleNamespace(name="translate", description="Translate text"),
    ]
    model = _FakeModel(
        model_name="gpt-4o-mini",
        base_url="https://api.openai.com/v1",
    )
    middleware = tool_selector_module.MoviePilotToolSelectorMiddleware(
        max_tools=2,
        selection_tools=tools,
    )
    middleware.model = model
    request = _FakeRequest(
        tools=tools,
        messages=[HumanMessage(content="帮我安排明天的行程并查天气")],
        model=model,
    )

    async def handler(updated_request):
        # Echo the request back so the selected tools can be inspected.
        return updated_request

    parent_calls = 0

    async def _fake_parent_awrap(self, request_arg, handler_arg):
        # Stand-in for the parent implementation: count the invocation and
        # select the second and first tools, in that order.
        nonlocal parent_calls
        parent_calls += 1
        selected_request = request_arg.override(
            tools=[request_arg.tools[1], request_arg.tools[0]]
        )
        return await handler_arg(selected_request)

    with patch.object(
        tool_selector_module.LLMToolSelectorMiddleware,
        "awrap_model_call",
        _fake_parent_awrap,
    ):
        state_update = asyncio.run(
            middleware.abefore_agent(request.state, runtime=None, config=None)
        )
        # Apply the returned update by hand so the request state carries the
        # selection into the model calls below.
        if state_update:
            request.state.update(state_update)
        first_result = asyncio.run(middleware.awrap_model_call(request, handler))
        second_result = asyncio.run(middleware.awrap_model_call(request, handler))

    # The parent selector ran once for the whole run, not once per round.
    self.assertEqual(parent_calls, 1)
    self.assertEqual(
        [tool.name for tool in first_result.tools],
        ["calendar", "search"],
    )
    self.assertEqual(
        [tool.name for tool in second_result.tools],
        ["calendar", "search"],
    )
|
||||
|
||||
def test_normalize_selection_response_accepts_code_fence_json(self):
|
||||
middleware = tool_selector_module.MoviePilotToolSelectorMiddleware()
|
||||
response = SimpleNamespace(
|
||||
|
||||
Reference in New Issue
Block a user