diff --git a/app/agent/middleware/tool_selection.py b/app/agent/middleware/tool_selection.py index ae43ceac..9217c212 100644 --- a/app/agent/middleware/tool_selection.py +++ b/app/agent/middleware/tool_selection.py @@ -1,4 +1,5 @@ """MoviePilot 自定义工具筛选中间件。""" + import json from collections.abc import Awaitable, Callable from dataclasses import dataclass @@ -61,7 +62,8 @@ def _create_tool_selection_response(tools: list[BaseTool]) -> TypeAdapter[Any]: # Create a Union of Annotated Literal types for each tool name with description # For instance: Union[Annotated[Literal["tool1"], Field(description="...")], ...] literals = [ - Annotated[Literal[tool.name], Field(description=tool.description)] for tool in tools # noqa + Annotated[Literal[tool.name], Field(description=tool.description)] + for tool in tools # noqa ] selected_tool_type = Union[tuple(literals)] # type: ignore[valid-type] # noqa: UP007 @@ -90,9 +92,7 @@ def _render_tool_list(tools: list[BaseTool]) -> str: class ToolSelectionState(AgentState): """工具筛选中间件私有状态。""" - selected_tool_names: NotRequired[ - Annotated[list[str] | None, PrivateStateAttr] - ] + selected_tool_names: NotRequired[Annotated[list[str] | None, PrivateStateAttr]] """当前这条用户请求首轮筛选得到的工具名列表。""" @@ -102,7 +102,9 @@ class ToolSelectionStateUpdate(TypedDict): selected_tool_names: list[str] | None -class ToolSelectorMiddleware(AgentMiddleware[AgentState[ResponseT], ContextT, ResponseT]): +class ToolSelectorMiddleware( + AgentMiddleware[AgentState[ResponseT], ContextT, ResponseT] +): """ 为 DeepSeek 兼容端点提供更稳妥的工具筛选实现。 @@ -125,12 +127,14 @@ class ToolSelectorMiddleware(AgentMiddleware[AgentState[ResponseT], ContextT, Re state_schema = ToolSelectionState - def __init__(self, - model: BaseChatModel, - system_prompt: str = DEFAULT_SYSTEM_PROMPT, - selection_tools: list[Any] | None = None, - max_tools: int | None = None, - always_include: list[str] | None = None, ) -> None: + def __init__( + self, + model: BaseChatModel, + system_prompt: str = DEFAULT_SYSTEM_PROMPT, + selection_tools: list[Any] | None = None, + max_tools: int | None = None, + always_include: list[str] | None = None, + ) -> None: super().__init__() self.model = model self.system_prompt = system_prompt @@ -175,7 +179,9 @@ class ToolSelectorMiddleware(AgentMiddleware[AgentState[ResponseT], ContextT, Re raise ValueError(msg) # Separate tools that are always included from those available for selection - available_tools = [tool for tool in base_tools if tool.name not in self.always_include] + available_tools = [ + tool for tool in base_tools if tool.name not in self.always_include + ] # If no tools available for selection, return None if not available_tools: @@ -239,9 +245,15 @@ class ToolSelectorMiddleware(AgentMiddleware[AgentState[ResponseT], ContextT, Re raise ValueError(msg) # Filter tools based on selection and append always-included tools - selected_tools: list[BaseTool] = [ - tool for tool in available_tools if tool.name in selected_tool_names - ] + if selected_tool_names: + selected_tools: list[BaseTool] = [ + tool for tool in available_tools if tool.name in selected_tool_names + ] + else: + # 如果模型筛选结果为空,则不对工具进行裁剪,使用所有可用工具 + logger.warning("工具筛选结果为空,将恢复使用所有工具。") + selected_tools = available_tools + always_included_tools: list[BaseTool] = [ tool for tool in request.tools @@ -264,12 +276,16 @@ class ToolSelectorMiddleware(AgentMiddleware[AgentState[ResponseT], ContextT, Re 和 Base URL,避免只靠单一条件漏判。 """ module_name = type(model).__module__.lower() - model_name = str( - getattr(model, "model_name", "") or getattr(model, "model", "") - ).strip().lower() - base_url = str( - getattr(model, "openai_api_base", "") or getattr(model, "api_base", "") - ).strip().lower() + model_name = ( + str(getattr(model, "model_name", "") or getattr(model, "model", "")) + .strip() + .lower() + ) + base_url = ( + str(getattr(model, "openai_api_base", "") or getattr(model, "api_base", "")) + .strip() + .lower() + ) return ( "deepseek" in module_name @@ -353,6 +369,10 @@ class ToolSelectorMiddleware(AgentMiddleware[AgentState[ResponseT], ContextT, Re DeepSeek 官方文档要求在 JSON 输出模式下,提示词中必须明确包含 JSON 约束,否则兼容端点可能返回空内容或无意义输出。 """ + limit_instruction = "" + if self.max_tools: + limit_instruction = f"- Select up to {self.max_tools} tools. IF NO TOOLS ARE RELEVANT, DO NOT RETURN AN EMPTY ARRAY. SELECT THE MOST APPLICABLE ONES TO ENSURE THE REQUEST IS HANDLED." + return ( f"{selection_request.system_message}\n\n" "Return the answer in JSON only.\n" @@ -361,8 +381,10 @@ class ToolSelectorMiddleware(AgentMiddleware[AgentState[ResponseT], ContextT, Re "- The `tools` field must be a JSON array of strings.\n" "- Only use tool names from the allowed list below.\n" "- Order tools by relevance, with the most relevant first.\n" + f"{limit_instruction}\n" "- Do not add explanations, markdown, or extra keys.\n\n" - f"Allowed tools:\n{self._render_tool_list(selection_request.available_tools)}" + "Allowed tools:\n" + f"{self._render_tool_list(selection_request.available_tools)}" ) def _normalize_selection_response(self, response: Any) -> dict[str, list[str]]: @@ -371,13 +393,17 @@ class ToolSelectorMiddleware(AgentMiddleware[AgentState[ResponseT], ContextT, Re """ content = getattr(response, "content", response) text = self._extract_text_content(content) + logger.debug(f"工具筛选原始响应: {text}") payload = self._parse_json_object(text) tools = payload.get("tools") if not isinstance(tools, list): raise ValueError(f"工具筛选 JSON 缺少 `tools` 数组: {payload}") - normalized_tools = [tool_name for tool_name in tools if isinstance(tool_name, str)] + normalized_tools = [ + tool_name for tool_name in tools if isinstance(tool_name, str) + ] + logger.debug(f"工具筛选标准化结果: {normalized_tools}") return {"tools": normalized_tools} async def _aselect_tools_with_deepseek( @@ -394,9 +420,7 @@ class ToolSelectorMiddleware(AgentMiddleware[AgentState[ResponseT], ContextT, Re [ { "role": "system", - "content": self._build_deepseek_selection_prompt( - selection_request - ), + "content": self._build_deepseek_selection_prompt(selection_request), }, selection_request.last_user_message, ] @@ -406,9 +430,7 @@ class ToolSelectorMiddleware(AgentMiddleware[AgentState[ResponseT], ContextT, Re @staticmethod def _extract_selected_tool_names(request: ModelRequest) -> list[str]: """从已筛选后的请求中提取最终工具名,保留原有顺序。""" - return [ - tool.name for tool in request.tools if not isinstance(tool, dict) - ] + return [tool.name for tool in request.tools if not isinstance(tool, dict)] @staticmethod def _apply_selected_tools( @@ -425,9 +447,7 @@ class ToolSelectorMiddleware(AgentMiddleware[AgentState[ResponseT], ContextT, Re return request current_tools_by_name = { - tool.name: tool - for tool in request.tools - if not isinstance(tool, dict) + tool.name: tool for tool in request.tools if not isinstance(tool, dict) } selected_tools = [ current_tools_by_name[tool_name] @@ -498,9 +518,7 @@ class ToolSelectorMiddleware(AgentMiddleware[AgentState[ResponseT], ContextT, Re ) modified_request = await self._aselect_request_once(selection_request) selected_tool_names = self._extract_selected_tool_names(modified_request) - return ToolSelectionStateUpdate( - selected_tool_names=selected_tool_names or None - ) + return ToolSelectionStateUpdate(selected_tool_names=selected_tool_names or None) async def awrap_model_call( self, @@ -516,7 +534,11 @@ class ToolSelectorMiddleware(AgentMiddleware[AgentState[ResponseT], ContextT, Re # 正常路径下,`abefore_agent()` 已经提前写入状态;这里只保留一层兜底, # 兼容直接单测或未来某些绕过 before_agent 的调用场景。 - if selected_tool_names is None and self.selection_tools and self.model is not None: + if ( + selected_tool_names is None + and self.selection_tools + and self.model is not None + ): request = await self._aselect_request_once(request) selected_tool_names = self._extract_selected_tool_names(request) or None request.state["selected_tool_names"] = selected_tool_names # noqa