fix(agent): stabilize tool selector logging

This commit is contained in:
jxxghp
2026-06-22 19:06:14 +08:00
parent a6afa0fbc0
commit ad73434e2c
2 changed files with 133 additions and 29 deletions

View File

@@ -1,6 +1,6 @@
"""MoviePilot 自定义工具筛选中间件。"""
from dataclasses import replace
from dataclasses import dataclass, replace
import json
from collections.abc import Awaitable, Callable
from typing import Annotated, Any, NotRequired
@@ -68,6 +68,16 @@ class ToolSelectionStateUpdate(TypedDict):
selected_tool_names: list[str] | None
@dataclass(frozen=True)
class _ToolSelectionAttempt:
"""工具筛选尝试结果,用于统一记录最终日志。"""
request: ModelRequest
selected_tool_names: list[str]
status: str
detail: str = ""
class ToolSelectorMiddleware(LLMToolSelectorMiddleware):
"""
使用 provider-neutral JSON 提示执行工具筛选。
@@ -321,8 +331,6 @@ class ToolSelectorMiddleware(LLMToolSelectorMiddleware):
处理工具筛选响应,并保留空结果回退所有工具的 MoviePilot 策略。
"""
if response.get("tools") == []:
logger.info("工具筛选结果为空,将恢复使用所有工具。")
always_included_tools: list[BaseTool] = [
tool
for tool in request.tools
@@ -349,8 +357,6 @@ class ToolSelectorMiddleware(LLMToolSelectorMiddleware):
valid_tool_names,
request,
)
selected_tool_names = self._extract_selected_tool_names(modified_request)
logger.info(f"工具筛选结果: {', '.join(selected_tool_names) or '无有效工具'}")
return modified_request
@staticmethod
@@ -483,6 +489,34 @@ class ToolSelectorMiddleware(LLMToolSelectorMiddleware):
"""从已筛选后的请求中提取最终工具名,保留原有顺序。"""
return [tool.name for tool in request.tools if not isinstance(tool, dict)]
@staticmethod
def _count_request_tools(request: ModelRequest) -> int:
"""统计当前请求中的 LangChain 工具数量,不包含 provider 原生工具字典。"""
return len([tool for tool in request.tools if not isinstance(tool, dict)])
@classmethod
def _log_selection_attempt(cls, attempt: _ToolSelectionAttempt) -> None:
"""按工具筛选最终状态记录稳定日志。"""
tool_count = cls._count_request_tools(attempt.request)
if attempt.status == "selected":
selected_text = ", ".join(attempt.selected_tool_names) or "无有效工具"
logger.info(f"工具筛选结果: {selected_text}")
return
if attempt.status == "empty_fallback":
logger.info(f"工具筛选结果为空,将恢复使用所有工具(共 {tool_count} 个)。")
return
if attempt.status == "failed_fallback":
logger.warning(
f"工具筛选失败,将恢复使用所有工具(共 {tool_count} 个): {attempt.detail}"
)
return
if attempt.status == "skipped":
logger.info(f"工具筛选跳过: {attempt.detail}")
return
if attempt.status == "reused":
selected_text = ", ".join(attempt.selected_tool_names) or "无有效工具"
logger.info(f"工具筛选复用已有结果: {selected_text}")
@staticmethod
def _apply_selected_tools(
request: ModelRequest[ContextT],
@@ -517,21 +551,48 @@ class ToolSelectorMiddleware(LLMToolSelectorMiddleware):
这里单独抽成 helper便于首次筛选后缓存结果也便于测试覆盖
“首轮筛选,后续复用”的行为。
"""
return (await self._aselect_request_once_with_status(request)).request
async def _aselect_request_once_with_status(
self, request: ModelRequest[ContextT]
) -> _ToolSelectionAttempt:
"""
执行一次真实工具筛选,并携带最终状态供调用方统一记录日志。
"""
selection_request = self._prepare_selection_request(request)
if selection_request is None:
return request
return _ToolSelectionAttempt(
request=request,
selected_tool_names=self._extract_selected_tool_names(request),
status="skipped",
detail="没有需要筛选的工具",
)
try:
response = await self._aselect_tools_with_json_prompt(selection_request)
return self._process_selection_response(
modified_request = self._process_selection_response(
response,
selection_request.available_tools,
selection_request.valid_tool_names,
request,
)
status = (
"empty_fallback"
if response.get("tools") == []
else "selected"
)
return _ToolSelectionAttempt(
request=modified_request,
selected_tool_names=self._extract_selected_tool_names(modified_request),
status=status,
)
except Exception as err:
logger.warning(f"工具筛选失败,将恢复使用所有工具: {str(err)}")
return request
return _ToolSelectionAttempt(
request=request,
selected_tool_names=self._extract_selected_tool_names(request),
status="failed_fallback",
detail=str(err),
)
async def abefore_agent( # noqa
self,
@@ -546,9 +607,37 @@ class ToolSelectorMiddleware(LLMToolSelectorMiddleware):
不会为每次模型回合重复追加一笔 selector LLM 开销。
"""
if "selected_tool_names" in state:
self._log_selection_attempt(
_ToolSelectionAttempt(
request=ModelRequest(
model=self.model,
tools=list(self.selection_tools),
messages=state["messages"],
state=state,
runtime=runtime,
),
selected_tool_names=state.get("selected_tool_names") or [],
status="reused",
)
)
return None
if not self.selection_tools or self.model is None:
detail = "没有可筛选工具" if not self.selection_tools else "未配置筛选模型"
self._log_selection_attempt(
_ToolSelectionAttempt(
request=ModelRequest(
model=self.model,
tools=list(self.selection_tools),
messages=state["messages"],
state=state,
runtime=runtime,
),
selected_tool_names=[],
status="skipped",
detail=detail,
)
)
return ToolSelectionStateUpdate(selected_tool_names=None)
selection_request = ModelRequest(
@@ -558,8 +647,9 @@ class ToolSelectorMiddleware(LLMToolSelectorMiddleware):
state=state,
runtime=runtime,
)
modified_request = await self._aselect_request_once(selection_request)
selected_tool_names = self._extract_selected_tool_names(modified_request)
attempt = await self._aselect_request_once_with_status(selection_request)
self._log_selection_attempt(attempt)
selected_tool_names = attempt.selected_tool_names
return ToolSelectionStateUpdate(selected_tool_names=selected_tool_names or None)
async def awrap_model_call(
@@ -581,8 +671,10 @@ class ToolSelectorMiddleware(LLMToolSelectorMiddleware):
and self.selection_tools
and self.model is not None
):
request = await self._aselect_request_once(request)
selected_tool_names = self._extract_selected_tool_names(request) or None
attempt = await self._aselect_request_once_with_status(request)
self._log_selection_attempt(attempt)
request = attempt.request
selected_tool_names = attempt.selected_tool_names or None
request.state["selected_tool_names"] = selected_tool_names # noqa
if selected_tool_names: