mirror of
https://github.com/jxxghp/MoviePilot.git
synced 2026-05-06 20:42:43 +08:00
refactor: rename MoviePilotToolSelectorMiddleware to ToolSelectorMiddleware and enhance tool selection logic
This commit is contained in:
@@ -27,7 +27,7 @@ from app.agent.middleware.memory import MemoryMiddleware
|
||||
from app.agent.middleware.patch_tool_calls import PatchToolCallsMiddleware
|
||||
from app.agent.middleware.runtime_config import RuntimeConfigMiddleware
|
||||
from app.agent.middleware.skills import SkillsMiddleware
|
||||
from app.agent.middleware.tool_selection import MoviePilotToolSelectorMiddleware
|
||||
from app.agent.middleware.tool_selection import ToolSelectorMiddleware
|
||||
from app.agent.middleware.usage import UsageMiddleware
|
||||
from app.agent.prompt import prompt_manager
|
||||
from app.agent.runtime import agent_runtime_manager
|
||||
@@ -447,11 +447,11 @@ class MoviePilotAgent:
|
||||
# 工具选择
|
||||
if max_tools > 0:
|
||||
middlewares.append(
|
||||
MoviePilotToolSelectorMiddleware(
|
||||
ToolSelectorMiddleware(
|
||||
model=non_streaming_model,
|
||||
selection_tools=tools,
|
||||
max_tools=max_tools,
|
||||
always_include=always_include_tools,
|
||||
selection_tools=tools,
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
@@ -1,23 +1,91 @@
|
||||
"""MoviePilot 自定义工具筛选中间件。"""
|
||||
import json
|
||||
from collections.abc import Awaitable, Callable
|
||||
from typing import Annotated, Any, NotRequired, TypedDict
|
||||
from dataclasses import dataclass
|
||||
from typing import Annotated, Any, Literal, Union, NotRequired
|
||||
|
||||
from langchain.agents.middleware import LLMToolSelectorMiddleware
|
||||
from langchain.agents.middleware.types import (
|
||||
AgentMiddleware,
|
||||
AgentState,
|
||||
ContextT,
|
||||
ModelRequest,
|
||||
ModelResponse,
|
||||
PrivateStateAttr, # noqa
|
||||
ResponseT,
|
||||
)
|
||||
from langchain.agents.middleware.types import (
|
||||
PrivateStateAttr, # noqa
|
||||
)
|
||||
from langchain_core.language_models.chat_models import BaseChatModel
|
||||
from langchain_core.messages import HumanMessage
|
||||
from langchain_core.runnables import RunnableConfig
|
||||
from langchain_core.tools import BaseTool
|
||||
from langgraph.runtime import Runtime
|
||||
from pydantic import Field, TypeAdapter
|
||||
from typing_extensions import TypedDict # noqa
|
||||
|
||||
from app.log import logger
|
||||
|
||||
DEFAULT_SYSTEM_PROMPT = (
|
||||
"Your goal is to select the most relevant tools for answering the user's query."
|
||||
)
|
||||
|
||||
|
||||
@dataclass
|
||||
class _SelectionRequest:
|
||||
"""Prepared inputs for tool selection."""
|
||||
|
||||
available_tools: list[BaseTool]
|
||||
system_message: str
|
||||
last_user_message: HumanMessage
|
||||
model: BaseChatModel
|
||||
valid_tool_names: list[str]
|
||||
|
||||
|
||||
def _create_tool_selection_response(tools: list[BaseTool]) -> TypeAdapter[Any]:
|
||||
"""Create a structured output schema for tool selection.
|
||||
|
||||
Args:
|
||||
tools: Available tools to include in the schema.
|
||||
|
||||
Returns:
|
||||
`TypeAdapter` for a schema where each tool name is a `Literal` with its
|
||||
description.
|
||||
|
||||
Raises:
|
||||
AssertionError: If `tools` is empty.
|
||||
"""
|
||||
if not tools:
|
||||
msg = "Invalid usage: tools must be non-empty"
|
||||
raise AssertionError(msg)
|
||||
|
||||
# Create a Union of Annotated Literal types for each tool name with description
|
||||
# For instance: Union[Annotated[Literal["tool1"], Field(description="...")], ...]
|
||||
literals = [
|
||||
Annotated[Literal[tool.name], Field(description=tool.description)] for tool in tools # noqa
|
||||
]
|
||||
selected_tool_type = Union[tuple(literals)] # type: ignore[valid-type] # noqa: UP007
|
||||
|
||||
description = "Tools to use. Place the most relevant tools first."
|
||||
|
||||
class ToolSelectionResponse(TypedDict):
|
||||
"""Use to select relevant tools."""
|
||||
|
||||
tools: Annotated[list[selected_tool_type], Field(description=description)] # type: ignore[valid-type]
|
||||
|
||||
return TypeAdapter(ToolSelectionResponse)
|
||||
|
||||
|
||||
def _render_tool_list(tools: list[BaseTool]) -> str:
|
||||
"""Format tools as markdown list.
|
||||
|
||||
Args:
|
||||
tools: Tools to format.
|
||||
|
||||
Returns:
|
||||
Markdown string with each tool on a new line.
|
||||
"""
|
||||
return "\n".join(f"- {tool.name}: {tool.description}" for tool in tools)
|
||||
|
||||
|
||||
class ToolSelectionState(AgentState):
|
||||
"""工具筛选中间件私有状态。"""
|
||||
@@ -34,7 +102,7 @@ class ToolSelectionStateUpdate(TypedDict):
|
||||
selected_tool_names: list[str] | None
|
||||
|
||||
|
||||
class MoviePilotToolSelectorMiddleware(LLMToolSelectorMiddleware):
|
||||
class ToolSelectorMiddleware(AgentMiddleware[AgentState[ResponseT], ContextT, ResponseT]):
|
||||
"""
|
||||
为 DeepSeek 兼容端点提供更稳妥的工具筛选实现。
|
||||
|
||||
@@ -57,11 +125,134 @@ class MoviePilotToolSelectorMiddleware(LLMToolSelectorMiddleware):
|
||||
|
||||
state_schema = ToolSelectionState
|
||||
|
||||
def __init__(self, *args, selection_tools: list[Any] | None = None, **kwargs) -> None:
|
||||
super().__init__(*args, **kwargs)
|
||||
# `abefore_agent()` 无法直接拿到 ModelRequest,因此把首次可见的工具集
|
||||
# 通过初始化参数传入,后续在进入模型循环前完成一次真实筛选。
|
||||
self._selection_tools = selection_tools or []
|
||||
def __init__(self,
|
||||
model: BaseChatModel,
|
||||
system_prompt: str = DEFAULT_SYSTEM_PROMPT,
|
||||
selection_tools: list[Any] | None = None,
|
||||
max_tools: int | None = None,
|
||||
always_include: list[str] | None = None, ) -> None:
|
||||
super().__init__()
|
||||
self.model = model
|
||||
self.system_prompt = system_prompt
|
||||
self.max_tools = max_tools
|
||||
self.always_include = always_include or []
|
||||
self.selection_tools = selection_tools or []
|
||||
|
||||
def _prepare_selection_request(
|
||||
self, request: ModelRequest[ContextT]
|
||||
) -> _SelectionRequest | None:
|
||||
"""Prepare inputs for tool selection.
|
||||
|
||||
Args:
|
||||
request: the model request.
|
||||
|
||||
Returns:
|
||||
`SelectionRequest` with prepared inputs, or `None` if no selection is
|
||||
needed.
|
||||
|
||||
Raises:
|
||||
ValueError: If tools in `always_include` are not found in the request.
|
||||
AssertionError: If no user message is found in the request messages.
|
||||
"""
|
||||
# If no tools available, return None
|
||||
if not request.tools or len(request.tools) == 0:
|
||||
return None
|
||||
|
||||
# Filter to only BaseTool instances (exclude provider-specific tool dicts)
|
||||
base_tools = [tool for tool in request.tools if not isinstance(tool, dict)]
|
||||
|
||||
# Validate that always_include tools exist
|
||||
if self.always_include:
|
||||
available_tool_names = {tool.name for tool in base_tools}
|
||||
missing_tools = [
|
||||
name for name in self.always_include if name not in available_tool_names
|
||||
]
|
||||
if missing_tools:
|
||||
msg = (
|
||||
f"Tools in always_include not found in request: {missing_tools}. "
|
||||
f"Available tools: {sorted(available_tool_names)}"
|
||||
)
|
||||
raise ValueError(msg)
|
||||
|
||||
# Separate tools that are always included from those available for selection
|
||||
available_tools = [tool for tool in base_tools if tool.name not in self.always_include]
|
||||
|
||||
# If no tools available for selection, return None
|
||||
if not available_tools:
|
||||
return None
|
||||
|
||||
system_message = self.system_prompt
|
||||
# If there's a max_tools limit, append instructions to the system prompt
|
||||
if self.max_tools is not None:
|
||||
system_message += (
|
||||
f"\nIMPORTANT: List the tool names in order of relevance, "
|
||||
f"with the most relevant first. "
|
||||
f"If you exceed the maximum number of tools, "
|
||||
f"only the first {self.max_tools} will be used."
|
||||
)
|
||||
|
||||
# Get the last user message from the conversation history
|
||||
last_user_message: HumanMessage
|
||||
for message in reversed(request.messages):
|
||||
if isinstance(message, HumanMessage):
|
||||
last_user_message = message
|
||||
break
|
||||
else:
|
||||
msg = "No user message found in request messages"
|
||||
raise AssertionError(msg)
|
||||
|
||||
model = self.model or request.model
|
||||
valid_tool_names = [tool.name for tool in available_tools]
|
||||
|
||||
return _SelectionRequest(
|
||||
available_tools=available_tools,
|
||||
system_message=system_message,
|
||||
last_user_message=last_user_message,
|
||||
model=model,
|
||||
valid_tool_names=valid_tool_names,
|
||||
)
|
||||
|
||||
def _process_selection_response(
|
||||
self,
|
||||
response: dict[str, Any],
|
||||
available_tools: list[BaseTool],
|
||||
valid_tool_names: list[str],
|
||||
request: ModelRequest[ContextT],
|
||||
) -> ModelRequest[ContextT]:
|
||||
"""Process the selection response and return filtered `ModelRequest`."""
|
||||
selected_tool_names: list[str] = []
|
||||
invalid_tool_selections = []
|
||||
|
||||
for tool_name in response["tools"]:
|
||||
if tool_name not in valid_tool_names:
|
||||
invalid_tool_selections.append(tool_name)
|
||||
continue
|
||||
|
||||
# Only add if not already selected and within max_tools limit
|
||||
if tool_name not in selected_tool_names and (
|
||||
self.max_tools is None or len(selected_tool_names) < self.max_tools
|
||||
):
|
||||
selected_tool_names.append(tool_name)
|
||||
|
||||
if invalid_tool_selections:
|
||||
msg = f"Model selected invalid tools: {invalid_tool_selections}"
|
||||
raise ValueError(msg)
|
||||
|
||||
# Filter tools based on selection and append always-included tools
|
||||
selected_tools: list[BaseTool] = [
|
||||
tool for tool in available_tools if tool.name in selected_tool_names
|
||||
]
|
||||
always_included_tools: list[BaseTool] = [
|
||||
tool
|
||||
for tool in request.tools
|
||||
if not isinstance(tool, dict) and tool.name in self.always_include
|
||||
]
|
||||
selected_tools.extend(always_included_tools)
|
||||
|
||||
# Also preserve any provider-specific tool dicts from the original request
|
||||
provider_tools = [tool for tool in request.tools if isinstance(tool, dict)]
|
||||
|
||||
return request.override(tools=[*selected_tools, *provider_tools])
|
||||
|
||||
@staticmethod
|
||||
def _is_deepseek_compatible_model(model: BaseChatModel) -> bool:
|
||||
@@ -295,12 +486,12 @@ class MoviePilotToolSelectorMiddleware(LLMToolSelectorMiddleware):
|
||||
if "selected_tool_names" in state:
|
||||
return None
|
||||
|
||||
if not self._selection_tools or self.model is None:
|
||||
if not self.selection_tools or self.model is None:
|
||||
return ToolSelectionStateUpdate(selected_tool_names=None)
|
||||
|
||||
selection_request = ModelRequest(
|
||||
model=self.model,
|
||||
tools=list(self._selection_tools),
|
||||
tools=list(self.selection_tools),
|
||||
messages=state["messages"],
|
||||
state=state,
|
||||
runtime=runtime,
|
||||
@@ -325,7 +516,7 @@ class MoviePilotToolSelectorMiddleware(LLMToolSelectorMiddleware):
|
||||
|
||||
# 正常路径下,`abefore_agent()` 已经提前写入状态;这里只保留一层兜底,
|
||||
# 兼容直接单测或未来某些绕过 before_agent 的调用场景。
|
||||
if selected_tool_names is None and self._selection_tools and self.model is not None:
|
||||
if selected_tool_names is None and self.selection_tools and self.model is not None:
|
||||
request = await self._aselect_request_once(request)
|
||||
selected_tool_names = self._extract_selected_tool_names(request) or None
|
||||
request.state["selected_tool_names"] = selected_tool_names # noqa
|
||||
|
||||
@@ -95,7 +95,7 @@ class TestAgentSummarizationStreaming(unittest.TestCase):
|
||||
),
|
||||
patch.object(
|
||||
agent_module,
|
||||
"MoviePilotToolSelectorMiddleware",
|
||||
"ToolSelectorMiddleware",
|
||||
_FakeToolSelectorMiddleware,
|
||||
),
|
||||
patch.object(agent_module, "create_agent", side_effect=_fake_create_agent),
|
||||
|
||||
@@ -100,7 +100,7 @@ class ToolSelectorMiddlewareTest(unittest.TestCase):
|
||||
SimpleNamespace(name="translate", description="Translate text"),
|
||||
]
|
||||
model = _FakeModel()
|
||||
middleware = tool_selector_module.MoviePilotToolSelectorMiddleware(
|
||||
middleware = tool_selector_module.ToolSelectorMiddleware(
|
||||
max_tools=2,
|
||||
selection_tools=tools,
|
||||
)
|
||||
@@ -144,7 +144,7 @@ class ToolSelectorMiddlewareTest(unittest.TestCase):
|
||||
SimpleNamespace(name="translate", description="Translate text"),
|
||||
]
|
||||
model = _FakeModel(content='{"tools": ["calendar", "search"]}')
|
||||
middleware = tool_selector_module.MoviePilotToolSelectorMiddleware(
|
||||
middleware = tool_selector_module.ToolSelectorMiddleware(
|
||||
max_tools=2,
|
||||
selection_tools=tools,
|
||||
)
|
||||
@@ -192,7 +192,7 @@ class ToolSelectorMiddlewareTest(unittest.TestCase):
|
||||
model_name="gpt-4o-mini",
|
||||
base_url="https://api.openai.com/v1",
|
||||
)
|
||||
middleware = tool_selector_module.MoviePilotToolSelectorMiddleware(
|
||||
middleware = tool_selector_module.ToolSelectorMiddleware(
|
||||
max_tools=2,
|
||||
selection_tools=tools,
|
||||
)
|
||||
@@ -240,7 +240,7 @@ class ToolSelectorMiddlewareTest(unittest.TestCase):
|
||||
)
|
||||
|
||||
def test_normalize_selection_response_accepts_code_fence_json(self):
|
||||
middleware = tool_selector_module.MoviePilotToolSelectorMiddleware()
|
||||
middleware = tool_selector_module.ToolSelectorMiddleware()
|
||||
response = SimpleNamespace(
|
||||
content=[
|
||||
{
|
||||
|
||||
Reference in New Issue
Block a user