"""MoviePilot 自定义工具筛选中间件。"""
|
||
import json
from collections.abc import Awaitable, Callable
from dataclasses import dataclass
from typing import Annotated, Any, Literal, NotRequired, Union

from langchain.agents.middleware.types import (
    AgentMiddleware,
    AgentState,
    ContextT,
    ModelRequest,
    ModelResponse,
    PrivateStateAttr,  # noqa
    ResponseT,
)
from langchain_core.language_models.chat_models import BaseChatModel
from langchain_core.messages import HumanMessage
from langchain_core.runnables import RunnableConfig
from langchain_core.tools import BaseTool
from langgraph.runtime import Runtime
from pydantic import Field, TypeAdapter
from typing_extensions import TypedDict  # noqa

from app.log import logger

DEFAULT_SYSTEM_PROMPT = (
    "Your goal is to select the most relevant tools for answering the user's query."
)


@dataclass
class _SelectionRequest:
    """Prepared inputs for tool selection."""

    available_tools: list[BaseTool]
    system_message: str
    last_user_message: HumanMessage
    model: BaseChatModel
    valid_tool_names: list[str]


def _create_tool_selection_response(tools: list[BaseTool]) -> TypeAdapter[Any]:
    """Create a structured output schema for tool selection.

    Args:
        tools: Available tools to include in the schema.

    Returns:
        `TypeAdapter` for a schema where each tool name is a `Literal` with its
        description.

    Raises:
        AssertionError: If `tools` is empty.
    """
    if not tools:
        msg = "Invalid usage: tools must be non-empty"
        raise AssertionError(msg)

    # Create a Union of Annotated Literal types for each tool name with description
    # For instance: Union[Annotated[Literal["tool1"], Field(description="...")], ...]
    literals = [
        Annotated[Literal[tool.name], Field(description=tool.description)] for tool in tools  # noqa
    ]
    selected_tool_type = Union[tuple(literals)]  # type: ignore[valid-type]  # noqa: UP007

    description = "Tools to use. Place the most relevant tools first."

    class ToolSelectionResponse(TypedDict):
        """Use to select relevant tools."""

        tools: Annotated[list[selected_tool_type], Field(description=description)]  # type: ignore[valid-type]

    return TypeAdapter(ToolSelectionResponse)


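# A sketch of the payloads the generated schema accepts, assuming two tools
# named "search_media" and "download_torrent" (hypothetical names, purely for
# illustration):
#   valid:   {"tools": ["search_media", "download_torrent"]}
#   invalid: {"tools": ["unknown_tool"]}  # rejected: not one of the Literal names
#   invalid: {"tools": "search_media"}    # rejected: `tools` must be a list

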
def _render_tool_list(tools: list[BaseTool]) -> str:
    """Format tools as a markdown list.

    Args:
        tools: Tools to format.

    Returns:
        Markdown string with each tool on its own line.
    """
    return "\n".join(f"- {tool.name}: {tool.description}" for tool in tools)


class ToolSelectionState(AgentState):
    """Private state for the tool-selection middleware."""

    selected_tool_names: NotRequired[
        Annotated[list[str] | None, PrivateStateAttr]
    ]
    """Tool names produced by the first-pass selection for the current user request."""


class ToolSelectionStateUpdate(TypedDict):
    """State update emitted by the tool-selection middleware."""

    selected_tool_names: list[str] | None


class ToolSelectorMiddleware(AgentMiddleware[AgentState[ResponseT], ContextT, ResponseT]):
    """
    A more robust tool-selection implementation for DeepSeek-compatible endpoints.

    By default, LangChain routes `with_structured_output()` through OpenAI's
    `response_format=json_schema` path, but DeepSeek's official OpenAI-compatible
    endpoint publicly documents only `json_object` mode as available. With
    `deepseek-reasoner`, this triggers a 400 early, during tool selection, so the
    agent fails before it has actually started executing any tool.

    Therefore, only when a DeepSeek model/endpoint is detected does this
    middleware fall back to explicit JSON output mode:
    1. use `response_format={"type": "json_object"}`;
    2. explicitly constrain the prompt to return the JSON structure;
    3. parse `{"tools": [...]}` manually. All other models keep LangChain's
       default implementation.

    In addition, LangChain's native tool selection hooks into `wrap_model_call`,
    which re-selects tools before every "model turn" of the same user request.
    For complex tasks that invoke tools over multiple turns, this repeatedly
    burns an extra LLM call. This implementation changes that to:
    - `abefore_agent()`: select tools once when this agent run starts;
    - `awrap_model_call()`: read the first-pass selection from `request.state`
      and reuse it.
    """

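    # Minimal usage sketch. `create_agent`, `selector_model`, `chat_model`, and
    # the tool names below are hypothetical, shown only to illustrate the wiring:
    #
    #   selector = ToolSelectorMiddleware(
    #       model=selector_model,
    #       selection_tools=all_tools,
    #       max_tools=8,
    #       always_include=["current_time"],
    #   )
    #   agent = create_agent(model=chat_model, tools=all_tools, middleware=[selector])
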
    state_schema = ToolSelectionState

    def __init__(
        self,
        model: BaseChatModel,
        system_prompt: str = DEFAULT_SYSTEM_PROMPT,
        selection_tools: list[Any] | None = None,
        max_tools: int | None = None,
        always_include: list[str] | None = None,
    ) -> None:
        """
        Args:
            model: Model used to run the tool-selection pass.
            system_prompt: Instructions for the selection model.
            selection_tools: Full tool list the selection chooses from.
            max_tools: Maximum number of selected tools (`None` means no limit).
            always_include: Tool names that are always kept, in addition to
                `max_tools`.
        """
        super().__init__()
        self.model = model
        self.system_prompt = system_prompt
        self.max_tools = max_tools
        self.always_include = always_include or []
        self.selection_tools = selection_tools or []

    def _prepare_selection_request(
        self, request: ModelRequest[ContextT]
    ) -> _SelectionRequest | None:
        """Prepare inputs for tool selection.

        Args:
            request: The model request.

        Returns:
            `_SelectionRequest` with prepared inputs, or `None` if no selection
            is needed.

        Raises:
            ValueError: If tools in `always_include` are not found in the request.
            AssertionError: If no user message is found in the request messages.
        """
        # If no tools are available, there is nothing to select
        if not request.tools:
            return None

        # Filter to only BaseTool instances (exclude provider-specific tool dicts)
        base_tools = [tool for tool in request.tools if not isinstance(tool, dict)]

        # Validate that always_include tools exist
        if self.always_include:
            available_tool_names = {tool.name for tool in base_tools}
            missing_tools = [
                name for name in self.always_include if name not in available_tool_names
            ]
            if missing_tools:
                msg = (
                    f"Tools in always_include not found in request: {missing_tools}. "
                    f"Available tools: {sorted(available_tool_names)}"
                )
                raise ValueError(msg)

        # Separate always-included tools from those available for selection
        available_tools = [tool for tool in base_tools if tool.name not in self.always_include]

        # If no tools remain for selection, return None
        if not available_tools:
            return None

        system_message = self.system_prompt
        # If there is a max_tools limit, append instructions to the system prompt
        if self.max_tools is not None:
            system_message += (
                f"\nIMPORTANT: List the tool names in order of relevance, "
                f"with the most relevant first. "
                f"If you exceed the maximum number of tools, "
                f"only the first {self.max_tools} will be used."
            )

        # Get the last user message from the conversation history
        last_user_message: HumanMessage
        for message in reversed(request.messages):
            if isinstance(message, HumanMessage):
                last_user_message = message
                break
        else:
            msg = "No user message found in request messages"
            raise AssertionError(msg)

        model = self.model or request.model
        valid_tool_names = [tool.name for tool in available_tools]

        return _SelectionRequest(
            available_tools=available_tools,
            system_message=system_message,
            last_user_message=last_user_message,
            model=model,
            valid_tool_names=valid_tool_names,
        )

    def _process_selection_response(
        self,
        response: dict[str, Any],
        available_tools: list[BaseTool],
        valid_tool_names: list[str],
        request: ModelRequest[ContextT],
    ) -> ModelRequest[ContextT]:
        """Process the selection response and return a filtered `ModelRequest`."""
        selected_tool_names: list[str] = []
        invalid_tool_selections: list[str] = []

        for tool_name in response["tools"]:
            if tool_name not in valid_tool_names:
                invalid_tool_selections.append(tool_name)
                continue

            # Only add if not already selected and within the max_tools limit
            if tool_name not in selected_tool_names and (
                self.max_tools is None or len(selected_tool_names) < self.max_tools
            ):
                selected_tool_names.append(tool_name)

        if invalid_tool_selections:
            msg = f"Model selected invalid tools: {invalid_tool_selections}"
            raise ValueError(msg)

        # Filter tools based on the selection and append always-included tools
        selected_tools: list[BaseTool] = [
            tool for tool in available_tools if tool.name in selected_tool_names
        ]
        always_included_tools: list[BaseTool] = [
            tool
            for tool in request.tools
            if not isinstance(tool, dict) and tool.name in self.always_include
        ]
        selected_tools.extend(always_included_tools)

        # Also preserve any provider-specific tool dicts from the original request
        provider_tools = [tool for tool in request.tools if isinstance(tool, dict)]

        return request.override(tools=[*selected_tools, *provider_tools])

    @staticmethod
    def _is_deepseek_compatible_model(model: BaseChatModel) -> bool:
        """
        Decide whether the current model should take the DeepSeek JSON-compatibility branch.

        Besides the official `langchain_deepseek` package, users may also attach
        a DeepSeek endpoint to `ChatOpenAI` through an OpenAI-compatible
        configuration. The module name, model name, and base URL are therefore
        checked together, so a single condition alone cannot cause a miss.
        """
        module_name = type(model).__module__.lower()
        model_name = str(
            getattr(model, "model_name", "") or getattr(model, "model", "")
        ).strip().lower()
        base_url = str(
            getattr(model, "openai_api_base", "") or getattr(model, "api_base", "")
        ).strip().lower()

        return (
            "deepseek" in module_name
            or model_name.startswith("deepseek-")
            or "api.deepseek.com" in base_url
        )

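    # Examples (hypothetical configurations) that would take the DeepSeek branch:
    #   ChatDeepSeek(model="deepseek-reasoner")          -> module name contains "deepseek"
    #   ChatOpenAI(model="deepseek-chat", ...)           -> model name starts with "deepseek-"
    #   ChatOpenAI(base_url="https://api.deepseek.com")  -> base URL matches
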
    @staticmethod
    def _extract_text_content(content: Any) -> str:
        """
        Extract plain text from a model response.

        This deliberately does not depend on the upper-level LLMHelper, to avoid
        coupling the middleware to the LLM construction logic.
        """
        if content is None:
            return ""
        if isinstance(content, str):
            return content
        if isinstance(content, list):
            text_parts: list[str] = []
            for block in content:
                if isinstance(block, str):
                    text_parts.append(block)
                    continue
                if isinstance(block, dict):
                    if block.get("type") == "text" and isinstance(
                        block.get("text"), str
                    ):
                        text_parts.append(block["text"])
                        continue
                    if not block.get("type") and isinstance(block.get("text"), str):
                        text_parts.append(block["text"])
            return "".join(text_parts)
        if isinstance(content, dict):
            if content.get("type") == "text" and isinstance(content.get("text"), str):
                return content["text"]
            if not content.get("type") and isinstance(content.get("text"), str):
                return content["text"]
        return ""

    @staticmethod
    def _parse_json_object(text: str) -> dict[str, Any]:
        """
        Parse the JSON returned by the model.

        In JSON mode DeepSeek usually returns pure JSON, but a fallback layer is
        kept here to tolerate the occasional code fence or leading/trailing
        explanatory text.
        """
        stripped_text = text.strip()
        if not stripped_text:
            raise ValueError("Tool selection returned an empty response")

        try:
            payload = json.loads(stripped_text)
            if isinstance(payload, dict):
                return payload
        except json.JSONDecodeError:
            pass

        start = stripped_text.find("{")
        end = stripped_text.rfind("}")
        if start == -1 or end == -1 or end <= start:
            raise ValueError(f"Tool selection did not return valid JSON: {stripped_text}")

        payload = json.loads(stripped_text[start: end + 1])
        if not isinstance(payload, dict):
            raise ValueError("The top level of the tool-selection JSON must be an object")
        return payload

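    # All of the following inputs (illustrative, not real model transcripts)
    # parse to {"tools": ["search_media"]}:
    #   '{"tools": ["search_media"]}'
    #   '```json\n{"tools": ["search_media"]}\n```'
    #   'Sure, here is the selection: {"tools": ["search_media"]}'
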
    @staticmethod
    def _render_tool_list(available_tools: list[Any]) -> str:
        """Render tool names and descriptions as a stable text list."""
        return "\n".join(
            f"- {tool.name}: {tool.description}" for tool in available_tools
        )

    def _build_deepseek_selection_prompt(self, selection_request: Any) -> str:
        """
        Build an explicit JSON-output prompt for DeepSeek.

        DeepSeek's official documentation requires that, in JSON output mode,
        the prompt explicitly spell out the JSON constraint; otherwise the
        compatible endpoint may return empty or meaningless content.
        """
        return (
            f"{selection_request.system_message}\n\n"
            "Return the answer in JSON only.\n"
            'Use exactly this shape: {"tools": ["tool_name_1", "tool_name_2"]}\n'
            "Rules:\n"
            "- The `tools` field must be a JSON array of strings.\n"
            "- Only use tool names from the allowed list below.\n"
            "- Order tools by relevance, with the most relevant first.\n"
            "- Do not add explanations, markdown, or extra keys.\n\n"
            f"Allowed tools:\n{self._render_tool_list(selection_request.available_tools)}"
        )

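    # With a single hypothetical tool named "search_media" (description text is
    # illustrative only), the rendered prompt ends with:
    #   Allowed tools:
    #   - search_media: Search for movies and TV shows
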
    def _normalize_selection_response(self, response: Any) -> dict[str, list[str]]:
        """
        Parse and normalize the tool-selection result from DeepSeek JSON mode.
        """
        content = getattr(response, "content", response)
        text = self._extract_text_content(content)
        payload = self._parse_json_object(text)

        tools = payload.get("tools")
        if not isinstance(tools, list):
            raise ValueError(f"Tool-selection JSON is missing the `tools` array: {payload}")

        normalized_tools = [tool_name for tool_name in tools if isinstance(tool_name, str)]
        return {"tools": normalized_tools}

    async def _aselect_tools_with_deepseek(
        self, selection_request: Any
    ) -> dict[str, list[str]]:
        """
        Run asynchronous tool selection using DeepSeek-compatible JSON output mode.
        """
        logger.debug("Tool selection is taking the DeepSeek JSON-compatibility branch")
        structured_model = selection_request.model.bind(
            response_format={"type": "json_object"}
        )
        response = await structured_model.ainvoke(
            [
                {
                    "role": "system",
                    "content": self._build_deepseek_selection_prompt(
                        selection_request
                    ),
                },
                selection_request.last_user_message,
            ]
        )
        return self._normalize_selection_response(response)

    @staticmethod
    def _extract_selected_tool_names(request: ModelRequest) -> list[str]:
        """Extract the final tool names from an already-filtered request, preserving order."""
        return [
            tool.name for tool in request.tools if not isinstance(tool, dict)
        ]

    @staticmethod
    def _apply_selected_tools(
        request: ModelRequest[ContextT],
        selected_tool_names: list[str],
    ) -> ModelRequest[ContextT]:
        """
        Apply the already-selected tool set to the current model request.

        Only the client-side tool names chosen in the first pass are reused;
        provider-specific dict tools are kept as-is, so the tool-binding
        conventions of LangChain and the provider are not disturbed.
        """
        if not selected_tool_names:
            return request

        current_tools_by_name = {
            tool.name: tool
            for tool in request.tools
            if not isinstance(tool, dict)
        }
        selected_tools = [
            current_tools_by_name[tool_name]
            for tool_name in selected_tool_names
            if tool_name in current_tools_by_name
        ]
        provider_tools = [tool for tool in request.tools if isinstance(tool, dict)]
        return request.override(tools=[*selected_tools, *provider_tools])

    async def _aselect_request_once(
        self, request: ModelRequest[ContextT]
    ) -> ModelRequest[ContextT]:
        """
        Run one real tool-selection pass and return the filtered request object.

        This is factored into its own helper so the result can be cached after
        the first pass, and so tests can cover the "select on the first turn,
        reuse afterwards" behavior.
        """
        selection_request = self._prepare_selection_request(request)
        if selection_request is None:
            return request

        if not self._is_deepseek_compatible_model(selection_request.model):
            captured_request: ModelRequest[ContextT] = request

            async def _capture_handler(
                updated_request: ModelRequest[ContextT],
            ) -> ModelRequest[ContextT]:
                nonlocal captured_request
                captured_request = updated_request
                return updated_request

            await super().awrap_model_call(request, _capture_handler)
            return captured_request

        response = await self._aselect_tools_with_deepseek(selection_request)
        return self._process_selection_response(
            response,
            selection_request.available_tools,
            selection_request.valid_tool_names,
            request,
        )

    async def abefore_agent(  # noqa
        self,
        state: ToolSelectionState,
        runtime: Runtime,  # noqa
        config: RunnableConfig,
    ) -> ToolSelectionStateUpdate | None:  # ty: ignore[invalid-method-override]
        """
        Complete one real tool-selection pass before this agent run starts.

        Every subsequent `model -> tools -> model` cycle then reuses this single
        result, so each model turn does not add a repeated selector-LLM cost.
        """
        if "selected_tool_names" in state:
            return None

        if not self.selection_tools or self.model is None:
            return ToolSelectionStateUpdate(selected_tool_names=None)

        selection_request = ModelRequest(
            model=self.model,
            tools=list(self.selection_tools),
            messages=state["messages"],
            state=state,
            runtime=runtime,
        )
        modified_request = await self._aselect_request_once(selection_request)
        selected_tool_names = self._extract_selected_tool_names(modified_request)
        return ToolSelectionStateUpdate(
            selected_tool_names=selected_tool_names or None
        )

    async def awrap_model_call(
        self,
        request: ModelRequest[ContextT],
        handler: Callable[
            [ModelRequest[ContextT]], Awaitable[ModelResponse[ResponseT]]
        ],
    ) -> ModelResponse[ResponseT]:
        """
        Read the first-pass selection from state and apply it to every model turn.
        """
        selected_tool_names = request.state.get("selected_tool_names")  # noqa

        # On the normal path, `abefore_agent()` has already written this into
        # state; this is only a fallback for direct unit tests or any future
        # call sites that bypass before_agent.
        if selected_tool_names is None and self.selection_tools and self.model is not None:
            request = await self._aselect_request_once(request)
            selected_tool_names = self._extract_selected_tool_names(request) or None
            request.state["selected_tool_names"] = selected_tool_names  # noqa

        if selected_tool_names:
            request = self._apply_selected_tools(request, selected_tool_names)

        return await handler(request)