mirror of
https://github.com/jxxghp/MoviePilot.git
synced 2026-05-06 20:42:43 +08:00
feat: improve tool selection prompt with clearer instructions
This commit is contained in:
@@ -1,4 +1,5 @@
|
||||
"""MoviePilot 自定义工具筛选中间件。"""
|
||||
|
||||
import json
|
||||
from collections.abc import Awaitable, Callable
|
||||
from dataclasses import dataclass
|
||||
@@ -61,7 +62,8 @@ def _create_tool_selection_response(tools: list[BaseTool]) -> TypeAdapter[Any]:
|
||||
# Create a Union of Annotated Literal types for each tool name with description
|
||||
# For instance: Union[Annotated[Literal["tool1"], Field(description="...")], ...]
|
||||
literals = [
|
||||
Annotated[Literal[tool.name], Field(description=tool.description)] for tool in tools # noqa
|
||||
Annotated[Literal[tool.name], Field(description=tool.description)]
|
||||
for tool in tools # noqa
|
||||
]
|
||||
selected_tool_type = Union[tuple(literals)] # type: ignore[valid-type] # noqa: UP007
|
||||
|
||||
@@ -90,9 +92,7 @@ def _render_tool_list(tools: list[BaseTool]) -> str:
|
||||
class ToolSelectionState(AgentState):
|
||||
"""工具筛选中间件私有状态。"""
|
||||
|
||||
selected_tool_names: NotRequired[
|
||||
Annotated[list[str] | None, PrivateStateAttr]
|
||||
]
|
||||
selected_tool_names: NotRequired[Annotated[list[str] | None, PrivateStateAttr]]
|
||||
"""当前这条用户请求首轮筛选得到的工具名列表。"""
|
||||
|
||||
|
||||
@@ -102,7 +102,9 @@ class ToolSelectionStateUpdate(TypedDict):
|
||||
selected_tool_names: list[str] | None
|
||||
|
||||
|
||||
class ToolSelectorMiddleware(AgentMiddleware[AgentState[ResponseT], ContextT, ResponseT]):
|
||||
class ToolSelectorMiddleware(
|
||||
AgentMiddleware[AgentState[ResponseT], ContextT, ResponseT]
|
||||
):
|
||||
"""
|
||||
为 DeepSeek 兼容端点提供更稳妥的工具筛选实现。
|
||||
|
||||
@@ -125,12 +127,14 @@ class ToolSelectorMiddleware(AgentMiddleware[AgentState[ResponseT], ContextT, Re
|
||||
|
||||
state_schema = ToolSelectionState
|
||||
|
||||
def __init__(self,
|
||||
model: BaseChatModel,
|
||||
system_prompt: str = DEFAULT_SYSTEM_PROMPT,
|
||||
selection_tools: list[Any] | None = None,
|
||||
max_tools: int | None = None,
|
||||
always_include: list[str] | None = None, ) -> None:
|
||||
def __init__(
|
||||
self,
|
||||
model: BaseChatModel,
|
||||
system_prompt: str = DEFAULT_SYSTEM_PROMPT,
|
||||
selection_tools: list[Any] | None = None,
|
||||
max_tools: int | None = None,
|
||||
always_include: list[str] | None = None,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
self.model = model
|
||||
self.system_prompt = system_prompt
|
||||
@@ -175,7 +179,9 @@ class ToolSelectorMiddleware(AgentMiddleware[AgentState[ResponseT], ContextT, Re
|
||||
raise ValueError(msg)
|
||||
|
||||
# Separate tools that are always included from those available for selection
|
||||
available_tools = [tool for tool in base_tools if tool.name not in self.always_include]
|
||||
available_tools = [
|
||||
tool for tool in base_tools if tool.name not in self.always_include
|
||||
]
|
||||
|
||||
# If no tools available for selection, return None
|
||||
if not available_tools:
|
||||
@@ -239,9 +245,15 @@ class ToolSelectorMiddleware(AgentMiddleware[AgentState[ResponseT], ContextT, Re
|
||||
raise ValueError(msg)
|
||||
|
||||
# Filter tools based on selection and append always-included tools
|
||||
selected_tools: list[BaseTool] = [
|
||||
tool for tool in available_tools if tool.name in selected_tool_names
|
||||
]
|
||||
if selected_tool_names:
|
||||
selected_tools: list[BaseTool] = [
|
||||
tool for tool in available_tools if tool.name in selected_tool_names
|
||||
]
|
||||
else:
|
||||
# 如果模型筛选结果为空,则不对工具进行裁剪,使用所有可用工具
|
||||
logger.warning("工具筛选结果为空,将恢复使用所有工具。")
|
||||
selected_tools = available_tools
|
||||
|
||||
always_included_tools: list[BaseTool] = [
|
||||
tool
|
||||
for tool in request.tools
|
||||
@@ -264,12 +276,16 @@ class ToolSelectorMiddleware(AgentMiddleware[AgentState[ResponseT], ContextT, Re
|
||||
和 Base URL,避免只靠单一条件漏判。
|
||||
"""
|
||||
module_name = type(model).__module__.lower()
|
||||
model_name = str(
|
||||
getattr(model, "model_name", "") or getattr(model, "model", "")
|
||||
).strip().lower()
|
||||
base_url = str(
|
||||
getattr(model, "openai_api_base", "") or getattr(model, "api_base", "")
|
||||
).strip().lower()
|
||||
model_name = (
|
||||
str(getattr(model, "model_name", "") or getattr(model, "model", ""))
|
||||
.strip()
|
||||
.lower()
|
||||
)
|
||||
base_url = (
|
||||
str(getattr(model, "openai_api_base", "") or getattr(model, "api_base", ""))
|
||||
.strip()
|
||||
.lower()
|
||||
)
|
||||
|
||||
return (
|
||||
"deepseek" in module_name
|
||||
@@ -353,6 +369,10 @@ class ToolSelectorMiddleware(AgentMiddleware[AgentState[ResponseT], ContextT, Re
|
||||
DeepSeek 官方文档要求在 JSON 输出模式下,提示词中必须明确包含 JSON
|
||||
约束,否则兼容端点可能返回空内容或无意义输出。
|
||||
"""
|
||||
limit_instruction = ""
|
||||
if self.max_tools:
|
||||
limit_instruction = f"- Select up to {self.max_tools} tools. IF NO TOOLS ARE RELEVANT, DO NOT RETURN AN EMPTY ARRAY. SELECT THE MOST APPLICABLE ONES TO ENSURE THE REQUEST IS HANDLED."
|
||||
|
||||
return (
|
||||
f"{selection_request.system_message}\n\n"
|
||||
"Return the answer in JSON only.\n"
|
||||
@@ -361,8 +381,10 @@ class ToolSelectorMiddleware(AgentMiddleware[AgentState[ResponseT], ContextT, Re
|
||||
"- The `tools` field must be a JSON array of strings.\n"
|
||||
"- Only use tool names from the allowed list below.\n"
|
||||
"- Order tools by relevance, with the most relevant first.\n"
|
||||
f"{limit_instruction}\n"
|
||||
"- Do not add explanations, markdown, or extra keys.\n\n"
|
||||
f"Allowed tools:\n{self._render_tool_list(selection_request.available_tools)}"
|
||||
"Allowed tools:\n"
|
||||
f"{self._render_tool_list(selection_request.available_tools)}"
|
||||
)
|
||||
|
||||
def _normalize_selection_response(self, response: Any) -> dict[str, list[str]]:
|
||||
@@ -371,13 +393,17 @@ class ToolSelectorMiddleware(AgentMiddleware[AgentState[ResponseT], ContextT, Re
|
||||
"""
|
||||
content = getattr(response, "content", response)
|
||||
text = self._extract_text_content(content)
|
||||
logger.debug(f"工具筛选原始响应: {text}")
|
||||
payload = self._parse_json_object(text)
|
||||
|
||||
tools = payload.get("tools")
|
||||
if not isinstance(tools, list):
|
||||
raise ValueError(f"工具筛选 JSON 缺少 `tools` 数组: {payload}")
|
||||
|
||||
normalized_tools = [tool_name for tool_name in tools if isinstance(tool_name, str)]
|
||||
normalized_tools = [
|
||||
tool_name for tool_name in tools if isinstance(tool_name, str)
|
||||
]
|
||||
logger.debug(f"工具筛选标准化结果: {normalized_tools}")
|
||||
return {"tools": normalized_tools}
|
||||
|
||||
async def _aselect_tools_with_deepseek(
|
||||
@@ -394,9 +420,7 @@ class ToolSelectorMiddleware(AgentMiddleware[AgentState[ResponseT], ContextT, Re
|
||||
[
|
||||
{
|
||||
"role": "system",
|
||||
"content": self._build_deepseek_selection_prompt(
|
||||
selection_request
|
||||
),
|
||||
"content": self._build_deepseek_selection_prompt(selection_request),
|
||||
},
|
||||
selection_request.last_user_message,
|
||||
]
|
||||
@@ -406,9 +430,7 @@ class ToolSelectorMiddleware(AgentMiddleware[AgentState[ResponseT], ContextT, Re
|
||||
@staticmethod
|
||||
def _extract_selected_tool_names(request: ModelRequest) -> list[str]:
|
||||
"""从已筛选后的请求中提取最终工具名,保留原有顺序。"""
|
||||
return [
|
||||
tool.name for tool in request.tools if not isinstance(tool, dict)
|
||||
]
|
||||
return [tool.name for tool in request.tools if not isinstance(tool, dict)]
|
||||
|
||||
@staticmethod
|
||||
def _apply_selected_tools(
|
||||
@@ -425,9 +447,7 @@ class ToolSelectorMiddleware(AgentMiddleware[AgentState[ResponseT], ContextT, Re
|
||||
return request
|
||||
|
||||
current_tools_by_name = {
|
||||
tool.name: tool
|
||||
for tool in request.tools
|
||||
if not isinstance(tool, dict)
|
||||
tool.name: tool for tool in request.tools if not isinstance(tool, dict)
|
||||
}
|
||||
selected_tools = [
|
||||
current_tools_by_name[tool_name]
|
||||
@@ -498,9 +518,7 @@ class ToolSelectorMiddleware(AgentMiddleware[AgentState[ResponseT], ContextT, Re
|
||||
)
|
||||
modified_request = await self._aselect_request_once(selection_request)
|
||||
selected_tool_names = self._extract_selected_tool_names(modified_request)
|
||||
return ToolSelectionStateUpdate(
|
||||
selected_tool_names=selected_tool_names or None
|
||||
)
|
||||
return ToolSelectionStateUpdate(selected_tool_names=selected_tool_names or None)
|
||||
|
||||
async def awrap_model_call(
|
||||
self,
|
||||
@@ -516,7 +534,11 @@ class ToolSelectorMiddleware(AgentMiddleware[AgentState[ResponseT], ContextT, Re
|
||||
|
||||
# 正常路径下,`abefore_agent()` 已经提前写入状态;这里只保留一层兜底,
|
||||
# 兼容直接单测或未来某些绕过 before_agent 的调用场景。
|
||||
if selected_tool_names is None and self.selection_tools and self.model is not None:
|
||||
if (
|
||||
selected_tool_names is None
|
||||
and self.selection_tools
|
||||
and self.model is not None
|
||||
):
|
||||
request = await self._aselect_request_once(request)
|
||||
selected_tool_names = self._extract_selected_tool_names(request) or None
|
||||
request.state["selected_tool_names"] = selected_tool_names # noqa
|
||||
|
||||
Reference in New Issue
Block a user