fix(agent): stabilize tool selector logging

This commit is contained in:
jxxghp
2026-06-22 19:06:14 +08:00
parent a6afa0fbc0
commit ad73434e2c
2 changed files with 133 additions and 29 deletions

View File

@@ -1,6 +1,6 @@
"""MoviePilot 自定义工具筛选中间件。"""
from dataclasses import replace
from dataclasses import dataclass, replace
import json
from collections.abc import Awaitable, Callable
from typing import Annotated, Any, NotRequired
@@ -68,6 +68,16 @@ class ToolSelectionStateUpdate(TypedDict):
selected_tool_names: list[str] | None
@dataclass(frozen=True)
class _ToolSelectionAttempt:
"""工具筛选尝试结果,用于统一记录最终日志。"""
request: ModelRequest
selected_tool_names: list[str]
status: str
detail: str = ""
class ToolSelectorMiddleware(LLMToolSelectorMiddleware):
"""
使用 provider-neutral JSON 提示执行工具筛选。
@@ -321,8 +331,6 @@ class ToolSelectorMiddleware(LLMToolSelectorMiddleware):
处理工具筛选响应,并保留空结果回退所有工具的 MoviePilot 策略。
"""
if response.get("tools") == []:
logger.info("工具筛选结果为空,将恢复使用所有工具。")
always_included_tools: list[BaseTool] = [
tool
for tool in request.tools
@@ -349,8 +357,6 @@ class ToolSelectorMiddleware(LLMToolSelectorMiddleware):
valid_tool_names,
request,
)
selected_tool_names = self._extract_selected_tool_names(modified_request)
logger.info(f"工具筛选结果: {', '.join(selected_tool_names) or '无有效工具'}")
return modified_request
@staticmethod
@@ -483,6 +489,34 @@ class ToolSelectorMiddleware(LLMToolSelectorMiddleware):
"""从已筛选后的请求中提取最终工具名,保留原有顺序。"""
return [tool.name for tool in request.tools if not isinstance(tool, dict)]
@staticmethod
def _count_request_tools(request: ModelRequest) -> int:
"""统计当前请求中的 LangChain 工具数量,不包含 provider 原生工具字典。"""
return len([tool for tool in request.tools if not isinstance(tool, dict)])
@classmethod
def _log_selection_attempt(cls, attempt: _ToolSelectionAttempt) -> None:
"""按工具筛选最终状态记录稳定日志。"""
tool_count = cls._count_request_tools(attempt.request)
if attempt.status == "selected":
selected_text = ", ".join(attempt.selected_tool_names) or "无有效工具"
logger.info(f"工具筛选结果: {selected_text}")
return
if attempt.status == "empty_fallback":
logger.info(f"工具筛选结果为空,将恢复使用所有工具(共 {tool_count} 个)。")
return
if attempt.status == "failed_fallback":
logger.warning(
f"工具筛选失败,将恢复使用所有工具(共 {tool_count} 个): {attempt.detail}"
)
return
if attempt.status == "skipped":
logger.info(f"工具筛选跳过: {attempt.detail}")
return
if attempt.status == "reused":
selected_text = ", ".join(attempt.selected_tool_names) or "无有效工具"
logger.info(f"工具筛选复用已有结果: {selected_text}")
@staticmethod
def _apply_selected_tools(
request: ModelRequest[ContextT],
@@ -517,21 +551,48 @@ class ToolSelectorMiddleware(LLMToolSelectorMiddleware):
这里单独抽成 helper便于首次筛选后缓存结果也便于测试覆盖
“首轮筛选,后续复用”的行为。
"""
return (await self._aselect_request_once_with_status(request)).request
async def _aselect_request_once_with_status(
self, request: ModelRequest[ContextT]
) -> _ToolSelectionAttempt:
"""
执行一次真实工具筛选,并携带最终状态供调用方统一记录日志。
"""
selection_request = self._prepare_selection_request(request)
if selection_request is None:
return request
return _ToolSelectionAttempt(
request=request,
selected_tool_names=self._extract_selected_tool_names(request),
status="skipped",
detail="没有需要筛选的工具",
)
try:
response = await self._aselect_tools_with_json_prompt(selection_request)
return self._process_selection_response(
modified_request = self._process_selection_response(
response,
selection_request.available_tools,
selection_request.valid_tool_names,
request,
)
status = (
"empty_fallback"
if response.get("tools") == []
else "selected"
)
return _ToolSelectionAttempt(
request=modified_request,
selected_tool_names=self._extract_selected_tool_names(modified_request),
status=status,
)
except Exception as err:
logger.warning(f"工具筛选失败,将恢复使用所有工具: {str(err)}")
return request
return _ToolSelectionAttempt(
request=request,
selected_tool_names=self._extract_selected_tool_names(request),
status="failed_fallback",
detail=str(err),
)
async def abefore_agent( # noqa
self,
@@ -546,9 +607,37 @@ class ToolSelectorMiddleware(LLMToolSelectorMiddleware):
不会为每次模型回合重复追加一笔 selector LLM 开销。
"""
if "selected_tool_names" in state:
self._log_selection_attempt(
_ToolSelectionAttempt(
request=ModelRequest(
model=self.model,
tools=list(self.selection_tools),
messages=state["messages"],
state=state,
runtime=runtime,
),
selected_tool_names=state.get("selected_tool_names") or [],
status="reused",
)
)
return None
if not self.selection_tools or self.model is None:
detail = "没有可筛选工具" if not self.selection_tools else "未配置筛选模型"
self._log_selection_attempt(
_ToolSelectionAttempt(
request=ModelRequest(
model=self.model,
tools=list(self.selection_tools),
messages=state["messages"],
state=state,
runtime=runtime,
),
selected_tool_names=[],
status="skipped",
detail=detail,
)
)
return ToolSelectionStateUpdate(selected_tool_names=None)
selection_request = ModelRequest(
@@ -558,8 +647,9 @@ class ToolSelectorMiddleware(LLMToolSelectorMiddleware):
state=state,
runtime=runtime,
)
modified_request = await self._aselect_request_once(selection_request)
selected_tool_names = self._extract_selected_tool_names(modified_request)
attempt = await self._aselect_request_once_with_status(selection_request)
self._log_selection_attempt(attempt)
selected_tool_names = attempt.selected_tool_names
return ToolSelectionStateUpdate(selected_tool_names=selected_tool_names or None)
async def awrap_model_call(
@@ -581,8 +671,10 @@ class ToolSelectorMiddleware(LLMToolSelectorMiddleware):
and self.selection_tools
and self.model is not None
):
request = await self._aselect_request_once(request)
selected_tool_names = self._extract_selected_tool_names(request) or None
attempt = await self._aselect_request_once_with_status(request)
self._log_selection_attempt(attempt)
request = attempt.request
selected_tool_names = attempt.selected_tool_names or None
request.state["selected_tool_names"] = selected_tool_names # noqa
if selected_tool_names:

View File

@@ -298,58 +298,70 @@ def test_empty_tool_selection_logs_info_not_warning():
SimpleNamespace(name="search", description="Search for information"),
SimpleNamespace(name="calendar", description="Manage events"),
]
model = _FakeModel(content='{"tools": []}')
middleware = tool_selector_module.ToolSelectorMiddleware(
max_tools=2,
selection_tools=tools,
)
middleware.model = model
request = _FakeRequest(
tools=tools,
messages=[HumanMessage(content="帮我安排明天的行程并查天气")],
model=_FakeModel(),
model=model,
)
with patch.object(tool_selector_module.logger, "info") as logger_info, \
patch.object(tool_selector_module.logger, "warning") as logger_warning:
result = middleware._process_selection_response(
{"tools": []},
available_tools=tools,
valid_tool_names=[tool.name for tool in tools],
request=request,
state_update = asyncio.run(
middleware.abefore_agent(request.state, runtime=None, config=None)
)
assert [tool.name for tool in result.tools] == ["search", "calendar"]
logger_info.assert_called_once_with("工具筛选结果为空,将恢复使用所有工具。")
assert state_update == {"selected_tool_names": ["search", "calendar"]}
logger_info.assert_called_once_with("工具筛选结果为空,将恢复使用所有工具(共 2 个)")
logger_warning.assert_not_called()
def test_process_selection_response_logs_selected_tools():
def test_abefore_agent_logs_selected_tools():
"""工具筛选返回有效工具时应记录最终生效的工具名。"""
tools = [
SimpleNamespace(name="search", description="Search for information"),
SimpleNamespace(name="calendar", description="Manage events"),
]
model = _FakeModel(content='{"tools": ["calendar"]}')
middleware = tool_selector_module.ToolSelectorMiddleware(
max_tools=2,
selection_tools=tools,
)
middleware.model = model
request = _FakeRequest(
tools=tools,
messages=[HumanMessage(content="帮我安排明天的行程并查天气")],
model=_FakeModel(),
model=model,
)
with patch.object(tool_selector_module.logger, "info") as logger_info:
result = middleware._process_selection_response(
{"tools": ["calendar"]},
available_tools=tools,
valid_tool_names=[tool.name for tool in tools],
request=request,
state_update = asyncio.run(
middleware.abefore_agent(request.state, runtime=None, config=None)
)
assert [tool.name for tool in result.tools] == ["calendar"]
assert state_update == {"selected_tool_names": ["calendar"]}
logger_info.assert_called_once_with("工具筛选结果: calendar")
def test_abefore_agent_logs_skipped_selection():
"""工具筛选未启用时也应记录跳过原因。"""
middleware = tool_selector_module.ToolSelectorMiddleware(selection_tools=[])
request_state = {"messages": [HumanMessage(content="帮我安排明天的行程")]}
with patch.object(tool_selector_module.logger, "info") as logger_info:
state_update = asyncio.run(
middleware.abefore_agent(request_state, runtime=None, config=None)
)
assert state_update == {"selected_tool_names": None}
logger_info.assert_called_once_with("工具筛选跳过: 没有可筛选工具。")
def test_normalize_selection_response_accepts_code_fence_json():
"""工具筛选响应应兼容 Markdown 代码围栏包裹的 JSON。"""
middleware = tool_selector_module.ToolSelectorMiddleware()