refactor: rename MoviePilotToolSelectorMiddleware to ToolSelectorMiddleware and enhance tool selection logic

jxxghp
2026-04-30 19:05:49 +08:00
parent afcc071d07
commit 53bf81aede
4 changed files with 211 additions and 20 deletions

View File

@@ -27,7 +27,7 @@ from app.agent.middleware.memory import MemoryMiddleware
from app.agent.middleware.patch_tool_calls import PatchToolCallsMiddleware
from app.agent.middleware.runtime_config import RuntimeConfigMiddleware
from app.agent.middleware.skills import SkillsMiddleware
from app.agent.middleware.tool_selection import MoviePilotToolSelectorMiddleware
from app.agent.middleware.tool_selection import ToolSelectorMiddleware
from app.agent.middleware.usage import UsageMiddleware
from app.agent.prompt import prompt_manager
from app.agent.runtime import agent_runtime_manager
@@ -447,11 +447,11 @@ class MoviePilotAgent:
# Tool selection
if max_tools > 0:
middlewares.append(
MoviePilotToolSelectorMiddleware(
ToolSelectorMiddleware(
model=non_streaming_model,
selection_tools=tools,
max_tools=max_tools,
always_include=always_include_tools,
selection_tools=tools,
)
)

View File

@@ -1,23 +1,91 @@
"""MoviePilot 自定义工具筛选中间件。"""
import json
from collections.abc import Awaitable, Callable
from typing import Annotated, Any, NotRequired, TypedDict
from dataclasses import dataclass
from typing import Annotated, Any, Literal, Union, NotRequired
from langchain.agents.middleware import LLMToolSelectorMiddleware
from langchain.agents.middleware.types import (
AgentMiddleware,
AgentState,
ContextT,
ModelRequest,
ModelResponse,
PrivateStateAttr, # noqa
ResponseT,
)
from langchain.agents.middleware.types import (
PrivateStateAttr, # noqa
)
from langchain_core.language_models.chat_models import BaseChatModel
from langchain_core.messages import HumanMessage
from langchain_core.runnables import RunnableConfig
from langchain_core.tools import BaseTool
from langgraph.runtime import Runtime
from pydantic import Field, TypeAdapter
from typing_extensions import TypedDict # noqa
from app.log import logger
DEFAULT_SYSTEM_PROMPT = (
"Your goal is to select the most relevant tools for answering the user's query."
)
@dataclass
class _SelectionRequest:
"""Prepared inputs for tool selection."""
available_tools: list[BaseTool]
system_message: str
last_user_message: HumanMessage
model: BaseChatModel
valid_tool_names: list[str]
def _create_tool_selection_response(tools: list[BaseTool]) -> TypeAdapter[Any]:
"""Create a structured output schema for tool selection.
Args:
tools: Available tools to include in the schema.
Returns:
`TypeAdapter` for a schema where each tool name is a `Literal` with its
description.
Raises:
AssertionError: If `tools` is empty.
"""
if not tools:
msg = "Invalid usage: tools must be non-empty"
raise AssertionError(msg)
# Create a Union of Annotated Literal types for each tool name with description
# For instance: Union[Annotated[Literal["tool1"], Field(description="...")], ...]
literals = [
Annotated[Literal[tool.name], Field(description=tool.description)] for tool in tools # noqa
]
selected_tool_type = Union[tuple(literals)] # type: ignore[valid-type] # noqa: UP007
description = "Tools to use. Place the most relevant tools first."
class ToolSelectionResponse(TypedDict):
"""Use to select relevant tools."""
tools: Annotated[list[selected_tool_type], Field(description=description)] # type: ignore[valid-type]
return TypeAdapter(ToolSelectionResponse)
def _render_tool_list(tools: list[BaseTool]) -> str:
"""Format tools as markdown list.
Args:
tools: Tools to format.
Returns:
Markdown string with each tool on a new line.
"""
return "\n".join(f"- {tool.name}: {tool.description}" for tool in tools)
class ToolSelectionState(AgentState):
"""工具筛选中间件私有状态。"""
@@ -34,7 +102,7 @@ class ToolSelectionStateUpdate(TypedDict):
selected_tool_names: list[str] | None
class MoviePilotToolSelectorMiddleware(LLMToolSelectorMiddleware):
class ToolSelectorMiddleware(AgentMiddleware[AgentState[ResponseT], ContextT, ResponseT]):
"""
Provides a more robust tool-selection implementation for DeepSeek-compatible endpoints.
@@ -57,11 +125,134 @@ class MoviePilotToolSelectorMiddleware(LLMToolSelectorMiddleware):
state_schema = ToolSelectionState
def __init__(self, *args, selection_tools: list[Any] | None = None, **kwargs) -> None:
super().__init__(*args, **kwargs)
# `abefore_agent()` cannot see the ModelRequest directly, so the tool set that is
# visible at startup is passed in via this init parameter; the real selection then
# runs once before entering the model loop.
self._selection_tools = selection_tools or []
def __init__(
    self,
    model: BaseChatModel,
    system_prompt: str = DEFAULT_SYSTEM_PROMPT,
    selection_tools: list[Any] | None = None,
    max_tools: int | None = None,
    always_include: list[str] | None = None,
) -> None:
super().__init__()
self.model = model
self.system_prompt = system_prompt
self.max_tools = max_tools
self.always_include = always_include or []
self.selection_tools = selection_tools or []
def _prepare_selection_request(
self, request: ModelRequest[ContextT]
) -> _SelectionRequest | None:
"""Prepare inputs for tool selection.
Args:
request: The model request.
Returns:
`_SelectionRequest` with prepared inputs, or `None` if no selection is
needed.
Raises:
ValueError: If tools in `always_include` are not found in the request.
AssertionError: If no user message is found in the request messages.
"""
# If no tools available, return None
if not request.tools:
return None
# Filter to only BaseTool instances (exclude provider-specific tool dicts)
base_tools = [tool for tool in request.tools if not isinstance(tool, dict)]
# Validate that always_include tools exist
if self.always_include:
available_tool_names = {tool.name for tool in base_tools}
missing_tools = [
name for name in self.always_include if name not in available_tool_names
]
if missing_tools:
msg = (
f"Tools in always_include not found in request: {missing_tools}. "
f"Available tools: {sorted(available_tool_names)}"
)
raise ValueError(msg)
# Separate tools that are always included from those available for selection
available_tools = [tool for tool in base_tools if tool.name not in self.always_include]
# If no tools available for selection, return None
if not available_tools:
return None
system_message = self.system_prompt
# If there's a max_tools limit, append instructions to the system prompt
if self.max_tools is not None:
system_message += (
f"\nIMPORTANT: List the tool names in order of relevance, "
f"with the most relevant first. "
f"If you exceed the maximum number of tools, "
f"only the first {self.max_tools} will be used."
)
# Get the last user message from the conversation history
last_user_message: HumanMessage
for message in reversed(request.messages):
if isinstance(message, HumanMessage):
last_user_message = message
break
else:
msg = "No user message found in request messages"
raise AssertionError(msg)
model = self.model or request.model
valid_tool_names = [tool.name for tool in available_tools]
return _SelectionRequest(
available_tools=available_tools,
system_message=system_message,
last_user_message=last_user_message,
model=model,
valid_tool_names=valid_tool_names,
)
def _process_selection_response(
self,
response: dict[str, Any],
available_tools: list[BaseTool],
valid_tool_names: list[str],
request: ModelRequest[ContextT],
) -> ModelRequest[ContextT]:
"""Process the selection response and return filtered `ModelRequest`."""
selected_tool_names: list[str] = []
invalid_tool_selections = []
for tool_name in response["tools"]:
if tool_name not in valid_tool_names:
invalid_tool_selections.append(tool_name)
continue
# Only add if not already selected and within max_tools limit
if tool_name not in selected_tool_names and (
self.max_tools is None or len(selected_tool_names) < self.max_tools
):
selected_tool_names.append(tool_name)
if invalid_tool_selections:
msg = f"Model selected invalid tools: {invalid_tool_selections}"
raise ValueError(msg)
# Filter tools based on selection and append always-included tools
selected_tools: list[BaseTool] = [
tool for tool in available_tools if tool.name in selected_tool_names
]
always_included_tools: list[BaseTool] = [
tool
for tool in request.tools
if not isinstance(tool, dict) and tool.name in self.always_include
]
selected_tools.extend(always_included_tools)
# Also preserve any provider-specific tool dicts from the original request
provider_tools = [tool for tool in request.tools if isinstance(tool, dict)]
return request.override(tools=[*selected_tools, *provider_tools])
@staticmethod
def _is_deepseek_compatible_model(model: BaseChatModel) -> bool:
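
The selection loop in `_process_selection_response` preserves the model's ranking, skips duplicates, and truncates at `max_tools`; a pure-Python sketch of just that bookkeeping (tool names hypothetical):

max_tools = 2
valid_tool_names = ["search", "calendar", "translate"]
response = {"tools": ["search", "search", "calendar", "translate"]}

selected_tool_names: list[str] = []
for tool_name in response["tools"]:
    if tool_name not in valid_tool_names:
        continue  # the real code collects these and raises ValueError afterwards
    if tool_name not in selected_tool_names and len(selected_tool_names) < max_tools:
        selected_tool_names.append(tool_name)

assert selected_tool_names == ["search", "calendar"]  # "translate" cut by the cap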
@@ -295,12 +486,12 @@ class MoviePilotToolSelectorMiddleware(LLMToolSelectorMiddleware):
if "selected_tool_names" in state:
return None
if not self._selection_tools or self.model is None:
if not self.selection_tools or self.model is None:
return ToolSelectionStateUpdate(selected_tool_names=None)
selection_request = ModelRequest(
model=self.model,
tools=list(self._selection_tools),
tools=list(self.selection_tools),
messages=state["messages"],
state=state,
runtime=runtime,
@@ -325,7 +516,7 @@ class MoviePilotToolSelectorMiddleware(LLMToolSelectorMiddleware):
# On the normal path `abefore_agent()` has already written the state; this is only
# kept as a fallback for direct unit tests or future call paths that bypass before_agent.
if selected_tool_names is None and self._selection_tools and self.model is not None:
if selected_tool_names is None and self.selection_tools and self.model is not None:
request = await self._aselect_request_once(request)
selected_tool_names = self._extract_selected_tool_names(request) or None
request.state["selected_tool_names"] = selected_tool_names # noqa
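
In short: `abefore_agent()` normally caches the selection in state, and the model-call path re-selects only when that cache is missing. A toy sketch of the guard, where `needs_fallback` is a hypothetical helper and plain dicts stand in for the agent state:

def needs_fallback(state: dict) -> bool:
    # Mirrors the guard above: re-select only if abefore_agent() never wrote the key.
    return state.get("selected_tool_names") is None

normal_state = {"messages": [], "selected_tool_names": ["search", "calendar"]}
bypass_state = {"messages": []}  # e.g. a direct unit test that skips before_agent

assert not needs_fallback(normal_state)
assert needs_fallback(bypass_state)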

View File

@@ -95,7 +95,7 @@ class TestAgentSummarizationStreaming(unittest.TestCase):
),
patch.object(
agent_module,
"MoviePilotToolSelectorMiddleware",
"ToolSelectorMiddleware",
_FakeToolSelectorMiddleware,
),
patch.object(agent_module, "create_agent", side_effect=_fake_create_agent),

View File

@@ -100,7 +100,7 @@ class ToolSelectorMiddlewareTest(unittest.TestCase):
SimpleNamespace(name="translate", description="Translate text"),
]
model = _FakeModel()
middleware = tool_selector_module.MoviePilotToolSelectorMiddleware(
middleware = tool_selector_module.ToolSelectorMiddleware(
max_tools=2,
selection_tools=tools,
)
@@ -144,7 +144,7 @@ class ToolSelectorMiddlewareTest(unittest.TestCase):
SimpleNamespace(name="translate", description="Translate text"),
]
model = _FakeModel(content='{"tools": ["calendar", "search"]}')
middleware = tool_selector_module.MoviePilotToolSelectorMiddleware(
middleware = tool_selector_module.ToolSelectorMiddleware(
max_tools=2,
selection_tools=tools,
)
@@ -192,7 +192,7 @@ class ToolSelectorMiddlewareTest(unittest.TestCase):
model_name="gpt-4o-mini",
base_url="https://api.openai.com/v1",
)
middleware = tool_selector_module.MoviePilotToolSelectorMiddleware(
middleware = tool_selector_module.ToolSelectorMiddleware(
max_tools=2,
selection_tools=tools,
)
@@ -240,7 +240,7 @@ class ToolSelectorMiddlewareTest(unittest.TestCase):
)
def test_normalize_selection_response_accepts_code_fence_json(self):
middleware = tool_selector_module.MoviePilotToolSelectorMiddleware()
middleware = tool_selector_module.ToolSelectorMiddleware()
response = SimpleNamespace(
content=[
{
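
The test name indicates that a selection response wrapped in a Markdown code fence is accepted as well; a standalone sketch of that normalization step, where `strip_code_fence` is a hypothetical stand-in for the middleware's actual normalization:

import json
import re

def strip_code_fence(text: str) -> str:
    # Remove a surrounding ```json ... ``` fence, if present.
    match = re.match(r"^```(?:json)?\s*(.*?)\s*```$", text.strip(), re.DOTALL)
    return match.group(1) if match else text

raw = '```json\n{"tools": ["calendar", "search"]}\n```'
assert json.loads(strip_code_fence(raw)) == {"tools": ["calendar", "search"]}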