feat: improve tool selection prompt with clearer instructions

This commit is contained in:
jxxghp
2026-04-30 20:33:46 +08:00
parent b129508304
commit 100eaec38f

View File

@@ -1,4 +1,5 @@
"""MoviePilot 自定义工具筛选中间件。"""
import json
from collections.abc import Awaitable, Callable
from dataclasses import dataclass
@@ -61,7 +62,8 @@ def _create_tool_selection_response(tools: list[BaseTool]) -> TypeAdapter[Any]:
# Create a Union of Annotated Literal types for each tool name with description
# For instance: Union[Annotated[Literal["tool1"], Field(description="...")], ...]
literals = [
Annotated[Literal[tool.name], Field(description=tool.description)] for tool in tools # noqa
Annotated[Literal[tool.name], Field(description=tool.description)]
for tool in tools # noqa
]
selected_tool_type = Union[tuple(literals)] # type: ignore[valid-type] # noqa: UP007
@@ -90,9 +92,7 @@ def _render_tool_list(tools: list[BaseTool]) -> str:
class ToolSelectionState(AgentState):
"""工具筛选中间件私有状态。"""
selected_tool_names: NotRequired[
Annotated[list[str] | None, PrivateStateAttr]
]
selected_tool_names: NotRequired[Annotated[list[str] | None, PrivateStateAttr]]
"""当前这条用户请求首轮筛选得到的工具名列表。"""
@@ -102,7 +102,9 @@ class ToolSelectionStateUpdate(TypedDict):
selected_tool_names: list[str] | None
class ToolSelectorMiddleware(AgentMiddleware[AgentState[ResponseT], ContextT, ResponseT]):
class ToolSelectorMiddleware(
AgentMiddleware[AgentState[ResponseT], ContextT, ResponseT]
):
"""
为 DeepSeek 兼容端点提供更稳妥的工具筛选实现。
@@ -125,12 +127,14 @@ class ToolSelectorMiddleware(AgentMiddleware[AgentState[ResponseT], ContextT, Re
state_schema = ToolSelectionState
def __init__(self,
model: BaseChatModel,
system_prompt: str = DEFAULT_SYSTEM_PROMPT,
selection_tools: list[Any] | None = None,
max_tools: int | None = None,
always_include: list[str] | None = None, ) -> None:
def __init__(
self,
model: BaseChatModel,
system_prompt: str = DEFAULT_SYSTEM_PROMPT,
selection_tools: list[Any] | None = None,
max_tools: int | None = None,
always_include: list[str] | None = None,
) -> None:
super().__init__()
self.model = model
self.system_prompt = system_prompt
@@ -175,7 +179,9 @@ class ToolSelectorMiddleware(AgentMiddleware[AgentState[ResponseT], ContextT, Re
raise ValueError(msg)
# Separate tools that are always included from those available for selection
available_tools = [tool for tool in base_tools if tool.name not in self.always_include]
available_tools = [
tool for tool in base_tools if tool.name not in self.always_include
]
# If no tools available for selection, return None
if not available_tools:
@@ -239,9 +245,15 @@ class ToolSelectorMiddleware(AgentMiddleware[AgentState[ResponseT], ContextT, Re
raise ValueError(msg)
# Filter tools based on selection and append always-included tools
selected_tools: list[BaseTool] = [
tool for tool in available_tools if tool.name in selected_tool_names
]
if selected_tool_names:
selected_tools: list[BaseTool] = [
tool for tool in available_tools if tool.name in selected_tool_names
]
else:
# 如果模型筛选结果为空,则不对工具进行裁剪,使用所有可用工具
logger.warning("工具筛选结果为空,将恢复使用所有工具。")
selected_tools = available_tools
always_included_tools: list[BaseTool] = [
tool
for tool in request.tools
@@ -264,12 +276,16 @@ class ToolSelectorMiddleware(AgentMiddleware[AgentState[ResponseT], ContextT, Re
和 Base URL,避免只靠单一条件漏判。
"""
module_name = type(model).__module__.lower()
model_name = str(
getattr(model, "model_name", "") or getattr(model, "model", "")
).strip().lower()
base_url = str(
getattr(model, "openai_api_base", "") or getattr(model, "api_base", "")
).strip().lower()
model_name = (
str(getattr(model, "model_name", "") or getattr(model, "model", ""))
.strip()
.lower()
)
base_url = (
str(getattr(model, "openai_api_base", "") or getattr(model, "api_base", ""))
.strip()
.lower()
)
return (
"deepseek" in module_name
@@ -353,6 +369,10 @@ class ToolSelectorMiddleware(AgentMiddleware[AgentState[ResponseT], ContextT, Re
DeepSeek 官方文档要求在 JSON 输出模式下,提示词中必须明确包含 JSON
约束,否则兼容端点可能返回空内容或无意义输出。
"""
limit_instruction = ""
if self.max_tools:
limit_instruction = f"- Select up to {self.max_tools} tools. IF NO TOOLS ARE RELEVANT, DO NOT RETURN AN EMPTY ARRAY. SELECT THE MOST APPLICABLE ONES TO ENSURE THE REQUEST IS HANDLED."
return (
f"{selection_request.system_message}\n\n"
"Return the answer in JSON only.\n"
@@ -361,8 +381,10 @@ class ToolSelectorMiddleware(AgentMiddleware[AgentState[ResponseT], ContextT, Re
"- The `tools` field must be a JSON array of strings.\n"
"- Only use tool names from the allowed list below.\n"
"- Order tools by relevance, with the most relevant first.\n"
f"{limit_instruction}\n"
"- Do not add explanations, markdown, or extra keys.\n\n"
f"Allowed tools:\n{self._render_tool_list(selection_request.available_tools)}"
"Allowed tools:\n"
f"{self._render_tool_list(selection_request.available_tools)}"
)
def _normalize_selection_response(self, response: Any) -> dict[str, list[str]]:
@@ -371,13 +393,17 @@ class ToolSelectorMiddleware(AgentMiddleware[AgentState[ResponseT], ContextT, Re
"""
content = getattr(response, "content", response)
text = self._extract_text_content(content)
logger.debug(f"工具筛选原始响应: {text}")
payload = self._parse_json_object(text)
tools = payload.get("tools")
if not isinstance(tools, list):
raise ValueError(f"工具筛选 JSON 缺少 `tools` 数组: {payload}")
normalized_tools = [tool_name for tool_name in tools if isinstance(tool_name, str)]
normalized_tools = [
tool_name for tool_name in tools if isinstance(tool_name, str)
]
logger.debug(f"工具筛选标准化结果: {normalized_tools}")
return {"tools": normalized_tools}
async def _aselect_tools_with_deepseek(
@@ -394,9 +420,7 @@ class ToolSelectorMiddleware(AgentMiddleware[AgentState[ResponseT], ContextT, Re
[
{
"role": "system",
"content": self._build_deepseek_selection_prompt(
selection_request
),
"content": self._build_deepseek_selection_prompt(selection_request),
},
selection_request.last_user_message,
]
@@ -406,9 +430,7 @@ class ToolSelectorMiddleware(AgentMiddleware[AgentState[ResponseT], ContextT, Re
@staticmethod
def _extract_selected_tool_names(request: ModelRequest) -> list[str]:
"""从已筛选后的请求中提取最终工具名,保留原有顺序。"""
return [
tool.name for tool in request.tools if not isinstance(tool, dict)
]
return [tool.name for tool in request.tools if not isinstance(tool, dict)]
@staticmethod
def _apply_selected_tools(
@@ -425,9 +447,7 @@ class ToolSelectorMiddleware(AgentMiddleware[AgentState[ResponseT], ContextT, Re
return request
current_tools_by_name = {
tool.name: tool
for tool in request.tools
if not isinstance(tool, dict)
tool.name: tool for tool in request.tools if not isinstance(tool, dict)
}
selected_tools = [
current_tools_by_name[tool_name]
@@ -498,9 +518,7 @@ class ToolSelectorMiddleware(AgentMiddleware[AgentState[ResponseT], ContextT, Re
)
modified_request = await self._aselect_request_once(selection_request)
selected_tool_names = self._extract_selected_tool_names(modified_request)
return ToolSelectionStateUpdate(
selected_tool_names=selected_tool_names or None
)
return ToolSelectionStateUpdate(selected_tool_names=selected_tool_names or None)
async def awrap_model_call(
self,
@@ -516,7 +534,11 @@ class ToolSelectorMiddleware(AgentMiddleware[AgentState[ResponseT], ContextT, Re
# 正常路径下,`abefore_agent()` 已经提前写入状态;这里只保留一层兜底,
# 兼容直接单测或未来某些绕过 before_agent 的调用场景。
if selected_tool_names is None and self.selection_tools and self.model is not None:
if (
selected_tool_names is None
and self.selection_tools
and self.model is not None
):
request = await self._aselect_request_once(request)
selected_tool_names = self._extract_selected_tool_names(request) or None
request.state["selected_tool_names"] = selected_tool_names # noqa