From 53bf81aede7c9f58db8137904c010f9a13b4f9d8 Mon Sep 17 00:00:00 2001 From: jxxghp Date: Thu, 30 Apr 2026 19:05:49 +0800 Subject: [PATCH] refactor: rename MoviePilotToolSelectorMiddleware to ToolSelectorMiddleware and enhance tool selection logic --- app/agent/__init__.py | 6 +- app/agent/middleware/tool_selection.py | 215 +++++++++++++++++-- tests/test_agent_summarization_streaming.py | 2 +- tests/test_agent_tool_selector_middleware.py | 8 +- 4 files changed, 211 insertions(+), 20 deletions(-) diff --git a/app/agent/__init__.py b/app/agent/__init__.py index afc70238..69233555 100644 --- a/app/agent/__init__.py +++ b/app/agent/__init__.py @@ -27,7 +27,7 @@ from app.agent.middleware.memory import MemoryMiddleware from app.agent.middleware.patch_tool_calls import PatchToolCallsMiddleware from app.agent.middleware.runtime_config import RuntimeConfigMiddleware from app.agent.middleware.skills import SkillsMiddleware -from app.agent.middleware.tool_selection import MoviePilotToolSelectorMiddleware +from app.agent.middleware.tool_selection import ToolSelectorMiddleware from app.agent.middleware.usage import UsageMiddleware from app.agent.prompt import prompt_manager from app.agent.runtime import agent_runtime_manager @@ -447,11 +447,11 @@ class MoviePilotAgent: # 工具选择 if max_tools > 0: middlewares.append( - MoviePilotToolSelectorMiddleware( + ToolSelectorMiddleware( model=non_streaming_model, + selection_tools=tools, max_tools=max_tools, always_include=always_include_tools, - selection_tools=tools, ) ) diff --git a/app/agent/middleware/tool_selection.py b/app/agent/middleware/tool_selection.py index dc52f14d..ae43ceac 100644 --- a/app/agent/middleware/tool_selection.py +++ b/app/agent/middleware/tool_selection.py @@ -1,23 +1,91 @@ """MoviePilot 自定义工具筛选中间件。""" import json from collections.abc import Awaitable, Callable -from typing import Annotated, Any, NotRequired, TypedDict +from dataclasses import dataclass +from typing import Annotated, Any, Literal, Union, NotRequired -from langchain.agents.middleware import LLMToolSelectorMiddleware from langchain.agents.middleware.types import ( + AgentMiddleware, AgentState, ContextT, ModelRequest, ModelResponse, - PrivateStateAttr, # noqa ResponseT, ) +from langchain.agents.middleware.types import ( + PrivateStateAttr, # noqa +) from langchain_core.language_models.chat_models import BaseChatModel +from langchain_core.messages import HumanMessage from langchain_core.runnables import RunnableConfig +from langchain_core.tools import BaseTool from langgraph.runtime import Runtime +from pydantic import Field, TypeAdapter +from typing_extensions import TypedDict # noqa from app.log import logger +DEFAULT_SYSTEM_PROMPT = ( + "Your goal is to select the most relevant tools for answering the user's query." +) + + +@dataclass +class _SelectionRequest: + """Prepared inputs for tool selection.""" + + available_tools: list[BaseTool] + system_message: str + last_user_message: HumanMessage + model: BaseChatModel + valid_tool_names: list[str] + + +def _create_tool_selection_response(tools: list[BaseTool]) -> TypeAdapter[Any]: + """Create a structured output schema for tool selection. + + Args: + tools: Available tools to include in the schema. + + Returns: + `TypeAdapter` for a schema where each tool name is a `Literal` with its + description. + + Raises: + AssertionError: If `tools` is empty. + """ + if not tools: + msg = "Invalid usage: tools must be non-empty" + raise AssertionError(msg) + + # Create a Union of Annotated Literal types for each tool name with description + # For instance: Union[Annotated[Literal["tool1"], Field(description="...")], ...] + literals = [ + Annotated[Literal[tool.name], Field(description=tool.description)] for tool in tools # noqa + ] + selected_tool_type = Union[tuple(literals)] # type: ignore[valid-type] # noqa: UP007 + + description = "Tools to use. Place the most relevant tools first." + + class ToolSelectionResponse(TypedDict): + """Use to select relevant tools.""" + + tools: Annotated[list[selected_tool_type], Field(description=description)] # type: ignore[valid-type] + + return TypeAdapter(ToolSelectionResponse) + + +def _render_tool_list(tools: list[BaseTool]) -> str: + """Format tools as markdown list. + + Args: + tools: Tools to format. + + Returns: + Markdown string with each tool on a new line. + """ + return "\n".join(f"- {tool.name}: {tool.description}" for tool in tools) + class ToolSelectionState(AgentState): """工具筛选中间件私有状态。""" @@ -34,7 +102,7 @@ class ToolSelectionStateUpdate(TypedDict): selected_tool_names: list[str] | None -class MoviePilotToolSelectorMiddleware(LLMToolSelectorMiddleware): +class ToolSelectorMiddleware(AgentMiddleware[AgentState[ResponseT], ContextT, ResponseT]): """ 为 DeepSeek 兼容端点提供更稳妥的工具筛选实现。 @@ -57,11 +125,134 @@ class MoviePilotToolSelectorMiddleware(LLMToolSelectorMiddleware): state_schema = ToolSelectionState - def __init__(self, *args, selection_tools: list[Any] | None = None, **kwargs) -> None: - super().__init__(*args, **kwargs) - # `abefore_agent()` 无法直接拿到 ModelRequest,因此把首次可见的工具集 - # 通过初始化参数传入,后续在进入模型循环前完成一次真实筛选。 - self._selection_tools = selection_tools or [] + def __init__(self, + model: BaseChatModel, + system_prompt: str = DEFAULT_SYSTEM_PROMPT, + selection_tools: list[Any] | None = None, + max_tools: int | None = None, + always_include: list[str] | None = None, ) -> None: + super().__init__() + self.model = model + self.system_prompt = system_prompt + self.max_tools = max_tools + self.always_include = always_include or [] + self.selection_tools = selection_tools or [] + + def _prepare_selection_request( + self, request: ModelRequest[ContextT] + ) -> _SelectionRequest | None: + """Prepare inputs for tool selection. + + Args: + request: the model request. + + Returns: + `SelectionRequest` with prepared inputs, or `None` if no selection is + needed. + + Raises: + ValueError: If tools in `always_include` are not found in the request. + AssertionError: If no user message is found in the request messages. + """ + # If no tools available, return None + if not request.tools or len(request.tools) == 0: + return None + + # Filter to only BaseTool instances (exclude provider-specific tool dicts) + base_tools = [tool for tool in request.tools if not isinstance(tool, dict)] + + # Validate that always_include tools exist + if self.always_include: + available_tool_names = {tool.name for tool in base_tools} + missing_tools = [ + name for name in self.always_include if name not in available_tool_names + ] + if missing_tools: + msg = ( + f"Tools in always_include not found in request: {missing_tools}. " + f"Available tools: {sorted(available_tool_names)}" + ) + raise ValueError(msg) + + # Separate tools that are always included from those available for selection + available_tools = [tool for tool in base_tools if tool.name not in self.always_include] + + # If no tools available for selection, return None + if not available_tools: + return None + + system_message = self.system_prompt + # If there's a max_tools limit, append instructions to the system prompt + if self.max_tools is not None: + system_message += ( + f"\nIMPORTANT: List the tool names in order of relevance, " + f"with the most relevant first. " + f"If you exceed the maximum number of tools, " + f"only the first {self.max_tools} will be used." + ) + + # Get the last user message from the conversation history + last_user_message: HumanMessage + for message in reversed(request.messages): + if isinstance(message, HumanMessage): + last_user_message = message + break + else: + msg = "No user message found in request messages" + raise AssertionError(msg) + + model = self.model or request.model + valid_tool_names = [tool.name for tool in available_tools] + + return _SelectionRequest( + available_tools=available_tools, + system_message=system_message, + last_user_message=last_user_message, + model=model, + valid_tool_names=valid_tool_names, + ) + + def _process_selection_response( + self, + response: dict[str, Any], + available_tools: list[BaseTool], + valid_tool_names: list[str], + request: ModelRequest[ContextT], + ) -> ModelRequest[ContextT]: + """Process the selection response and return filtered `ModelRequest`.""" + selected_tool_names: list[str] = [] + invalid_tool_selections = [] + + for tool_name in response["tools"]: + if tool_name not in valid_tool_names: + invalid_tool_selections.append(tool_name) + continue + + # Only add if not already selected and within max_tools limit + if tool_name not in selected_tool_names and ( + self.max_tools is None or len(selected_tool_names) < self.max_tools + ): + selected_tool_names.append(tool_name) + + if invalid_tool_selections: + msg = f"Model selected invalid tools: {invalid_tool_selections}" + raise ValueError(msg) + + # Filter tools based on selection and append always-included tools + selected_tools: list[BaseTool] = [ + tool for tool in available_tools if tool.name in selected_tool_names + ] + always_included_tools: list[BaseTool] = [ + tool + for tool in request.tools + if not isinstance(tool, dict) and tool.name in self.always_include + ] + selected_tools.extend(always_included_tools) + + # Also preserve any provider-specific tool dicts from the original request + provider_tools = [tool for tool in request.tools if isinstance(tool, dict)] + + return request.override(tools=[*selected_tools, *provider_tools]) @staticmethod def _is_deepseek_compatible_model(model: BaseChatModel) -> bool: @@ -295,12 +486,12 @@ class MoviePilotToolSelectorMiddleware(LLMToolSelectorMiddleware): if "selected_tool_names" in state: return None - if not self._selection_tools or self.model is None: + if not self.selection_tools or self.model is None: return ToolSelectionStateUpdate(selected_tool_names=None) selection_request = ModelRequest( model=self.model, - tools=list(self._selection_tools), + tools=list(self.selection_tools), messages=state["messages"], state=state, runtime=runtime, @@ -325,7 +516,7 @@ class MoviePilotToolSelectorMiddleware(LLMToolSelectorMiddleware): # 正常路径下,`abefore_agent()` 已经提前写入状态;这里只保留一层兜底, # 兼容直接单测或未来某些绕过 before_agent 的调用场景。 - if selected_tool_names is None and self._selection_tools and self.model is not None: + if selected_tool_names is None and self.selection_tools and self.model is not None: request = await self._aselect_request_once(request) selected_tool_names = self._extract_selected_tool_names(request) or None request.state["selected_tool_names"] = selected_tool_names # noqa diff --git a/tests/test_agent_summarization_streaming.py b/tests/test_agent_summarization_streaming.py index ab4f2312..95d762cf 100644 --- a/tests/test_agent_summarization_streaming.py +++ b/tests/test_agent_summarization_streaming.py @@ -95,7 +95,7 @@ class TestAgentSummarizationStreaming(unittest.TestCase): ), patch.object( agent_module, - "MoviePilotToolSelectorMiddleware", + "ToolSelectorMiddleware", _FakeToolSelectorMiddleware, ), patch.object(agent_module, "create_agent", side_effect=_fake_create_agent), diff --git a/tests/test_agent_tool_selector_middleware.py b/tests/test_agent_tool_selector_middleware.py index eb02d475..bdb73aa0 100644 --- a/tests/test_agent_tool_selector_middleware.py +++ b/tests/test_agent_tool_selector_middleware.py @@ -100,7 +100,7 @@ class ToolSelectorMiddlewareTest(unittest.TestCase): SimpleNamespace(name="translate", description="Translate text"), ] model = _FakeModel() - middleware = tool_selector_module.MoviePilotToolSelectorMiddleware( + middleware = tool_selector_module.ToolSelectorMiddleware( max_tools=2, selection_tools=tools, ) @@ -144,7 +144,7 @@ class ToolSelectorMiddlewareTest(unittest.TestCase): SimpleNamespace(name="translate", description="Translate text"), ] model = _FakeModel(content='{"tools": ["calendar", "search"]}') - middleware = tool_selector_module.MoviePilotToolSelectorMiddleware( + middleware = tool_selector_module.ToolSelectorMiddleware( max_tools=2, selection_tools=tools, ) @@ -192,7 +192,7 @@ class ToolSelectorMiddlewareTest(unittest.TestCase): model_name="gpt-4o-mini", base_url="https://api.openai.com/v1", ) - middleware = tool_selector_module.MoviePilotToolSelectorMiddleware( + middleware = tool_selector_module.ToolSelectorMiddleware( max_tools=2, selection_tools=tools, ) @@ -240,7 +240,7 @@ class ToolSelectorMiddlewareTest(unittest.TestCase): ) def test_normalize_selection_response_accepts_code_fence_json(self): - middleware = tool_selector_module.MoviePilotToolSelectorMiddleware() + middleware = tool_selector_module.ToolSelectorMiddleware() response = SimpleNamespace( content=[ {