mirror of
https://github.com/jxxghp/MoviePilot.git
synced 2026-06-25 09:34:19 +08:00
fix(agent): stabilize tool selector logging
This commit is contained in:
@@ -1,6 +1,6 @@
|
||||
"""MoviePilot 自定义工具筛选中间件。"""
|
||||
|
||||
from dataclasses import replace
|
||||
from dataclasses import dataclass, replace
|
||||
import json
|
||||
from collections.abc import Awaitable, Callable
|
||||
from typing import Annotated, Any, NotRequired
|
||||
@@ -68,6 +68,16 @@ class ToolSelectionStateUpdate(TypedDict):
|
||||
selected_tool_names: list[str] | None
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class _ToolSelectionAttempt:
|
||||
"""工具筛选尝试结果,用于统一记录最终日志。"""
|
||||
|
||||
request: ModelRequest
|
||||
selected_tool_names: list[str]
|
||||
status: str
|
||||
detail: str = ""
|
||||
|
||||
|
||||
class ToolSelectorMiddleware(LLMToolSelectorMiddleware):
|
||||
"""
|
||||
使用 provider-neutral JSON 提示执行工具筛选。
|
||||
@@ -321,8 +331,6 @@ class ToolSelectorMiddleware(LLMToolSelectorMiddleware):
|
||||
处理工具筛选响应,并保留空结果回退所有工具的 MoviePilot 策略。
|
||||
"""
|
||||
if response.get("tools") == []:
|
||||
logger.info("工具筛选结果为空,将恢复使用所有工具。")
|
||||
|
||||
always_included_tools: list[BaseTool] = [
|
||||
tool
|
||||
for tool in request.tools
|
||||
@@ -349,8 +357,6 @@ class ToolSelectorMiddleware(LLMToolSelectorMiddleware):
|
||||
valid_tool_names,
|
||||
request,
|
||||
)
|
||||
selected_tool_names = self._extract_selected_tool_names(modified_request)
|
||||
logger.info(f"工具筛选结果: {', '.join(selected_tool_names) or '无有效工具'}")
|
||||
return modified_request
|
||||
|
||||
@staticmethod
|
||||
@@ -483,6 +489,34 @@ class ToolSelectorMiddleware(LLMToolSelectorMiddleware):
|
||||
"""从已筛选后的请求中提取最终工具名,保留原有顺序。"""
|
||||
return [tool.name for tool in request.tools if not isinstance(tool, dict)]
|
||||
|
||||
@staticmethod
|
||||
def _count_request_tools(request: ModelRequest) -> int:
|
||||
"""统计当前请求中的 LangChain 工具数量,不包含 provider 原生工具字典。"""
|
||||
return len([tool for tool in request.tools if not isinstance(tool, dict)])
|
||||
|
||||
@classmethod
|
||||
def _log_selection_attempt(cls, attempt: _ToolSelectionAttempt) -> None:
|
||||
"""按工具筛选最终状态记录稳定日志。"""
|
||||
tool_count = cls._count_request_tools(attempt.request)
|
||||
if attempt.status == "selected":
|
||||
selected_text = ", ".join(attempt.selected_tool_names) or "无有效工具"
|
||||
logger.info(f"工具筛选结果: {selected_text}")
|
||||
return
|
||||
if attempt.status == "empty_fallback":
|
||||
logger.info(f"工具筛选结果为空,将恢复使用所有工具(共 {tool_count} 个)。")
|
||||
return
|
||||
if attempt.status == "failed_fallback":
|
||||
logger.warning(
|
||||
f"工具筛选失败,将恢复使用所有工具(共 {tool_count} 个): {attempt.detail}"
|
||||
)
|
||||
return
|
||||
if attempt.status == "skipped":
|
||||
logger.info(f"工具筛选跳过: {attempt.detail}。")
|
||||
return
|
||||
if attempt.status == "reused":
|
||||
selected_text = ", ".join(attempt.selected_tool_names) or "无有效工具"
|
||||
logger.info(f"工具筛选复用已有结果: {selected_text}")
|
||||
|
||||
@staticmethod
|
||||
def _apply_selected_tools(
|
||||
request: ModelRequest[ContextT],
|
||||
@@ -517,21 +551,48 @@ class ToolSelectorMiddleware(LLMToolSelectorMiddleware):
|
||||
这里单独抽成 helper,便于首次筛选后缓存结果,也便于测试覆盖
|
||||
“首轮筛选,后续复用”的行为。
|
||||
"""
|
||||
return (await self._aselect_request_once_with_status(request)).request
|
||||
|
||||
async def _aselect_request_once_with_status(
|
||||
self, request: ModelRequest[ContextT]
|
||||
) -> _ToolSelectionAttempt:
|
||||
"""
|
||||
执行一次真实工具筛选,并携带最终状态供调用方统一记录日志。
|
||||
"""
|
||||
selection_request = self._prepare_selection_request(request)
|
||||
if selection_request is None:
|
||||
return request
|
||||
return _ToolSelectionAttempt(
|
||||
request=request,
|
||||
selected_tool_names=self._extract_selected_tool_names(request),
|
||||
status="skipped",
|
||||
detail="没有需要筛选的工具",
|
||||
)
|
||||
|
||||
try:
|
||||
response = await self._aselect_tools_with_json_prompt(selection_request)
|
||||
return self._process_selection_response(
|
||||
modified_request = self._process_selection_response(
|
||||
response,
|
||||
selection_request.available_tools,
|
||||
selection_request.valid_tool_names,
|
||||
request,
|
||||
)
|
||||
status = (
|
||||
"empty_fallback"
|
||||
if response.get("tools") == []
|
||||
else "selected"
|
||||
)
|
||||
return _ToolSelectionAttempt(
|
||||
request=modified_request,
|
||||
selected_tool_names=self._extract_selected_tool_names(modified_request),
|
||||
status=status,
|
||||
)
|
||||
except Exception as err:
|
||||
logger.warning(f"工具筛选失败,将恢复使用所有工具: {str(err)}")
|
||||
return request
|
||||
return _ToolSelectionAttempt(
|
||||
request=request,
|
||||
selected_tool_names=self._extract_selected_tool_names(request),
|
||||
status="failed_fallback",
|
||||
detail=str(err),
|
||||
)
|
||||
|
||||
async def abefore_agent( # noqa
|
||||
self,
|
||||
@@ -546,9 +607,37 @@ class ToolSelectorMiddleware(LLMToolSelectorMiddleware):
|
||||
不会为每次模型回合重复追加一笔 selector LLM 开销。
|
||||
"""
|
||||
if "selected_tool_names" in state:
|
||||
self._log_selection_attempt(
|
||||
_ToolSelectionAttempt(
|
||||
request=ModelRequest(
|
||||
model=self.model,
|
||||
tools=list(self.selection_tools),
|
||||
messages=state["messages"],
|
||||
state=state,
|
||||
runtime=runtime,
|
||||
),
|
||||
selected_tool_names=state.get("selected_tool_names") or [],
|
||||
status="reused",
|
||||
)
|
||||
)
|
||||
return None
|
||||
|
||||
if not self.selection_tools or self.model is None:
|
||||
detail = "没有可筛选工具" if not self.selection_tools else "未配置筛选模型"
|
||||
self._log_selection_attempt(
|
||||
_ToolSelectionAttempt(
|
||||
request=ModelRequest(
|
||||
model=self.model,
|
||||
tools=list(self.selection_tools),
|
||||
messages=state["messages"],
|
||||
state=state,
|
||||
runtime=runtime,
|
||||
),
|
||||
selected_tool_names=[],
|
||||
status="skipped",
|
||||
detail=detail,
|
||||
)
|
||||
)
|
||||
return ToolSelectionStateUpdate(selected_tool_names=None)
|
||||
|
||||
selection_request = ModelRequest(
|
||||
@@ -558,8 +647,9 @@ class ToolSelectorMiddleware(LLMToolSelectorMiddleware):
|
||||
state=state,
|
||||
runtime=runtime,
|
||||
)
|
||||
modified_request = await self._aselect_request_once(selection_request)
|
||||
selected_tool_names = self._extract_selected_tool_names(modified_request)
|
||||
attempt = await self._aselect_request_once_with_status(selection_request)
|
||||
self._log_selection_attempt(attempt)
|
||||
selected_tool_names = attempt.selected_tool_names
|
||||
return ToolSelectionStateUpdate(selected_tool_names=selected_tool_names or None)
|
||||
|
||||
async def awrap_model_call(
|
||||
@@ -581,8 +671,10 @@ class ToolSelectorMiddleware(LLMToolSelectorMiddleware):
|
||||
and self.selection_tools
|
||||
and self.model is not None
|
||||
):
|
||||
request = await self._aselect_request_once(request)
|
||||
selected_tool_names = self._extract_selected_tool_names(request) or None
|
||||
attempt = await self._aselect_request_once_with_status(request)
|
||||
self._log_selection_attempt(attempt)
|
||||
request = attempt.request
|
||||
selected_tool_names = attempt.selected_tool_names or None
|
||||
request.state["selected_tool_names"] = selected_tool_names # noqa
|
||||
|
||||
if selected_tool_names:
|
||||
|
||||
@@ -298,58 +298,70 @@ def test_empty_tool_selection_logs_info_not_warning():
|
||||
SimpleNamespace(name="search", description="Search for information"),
|
||||
SimpleNamespace(name="calendar", description="Manage events"),
|
||||
]
|
||||
model = _FakeModel(content='{"tools": []}')
|
||||
middleware = tool_selector_module.ToolSelectorMiddleware(
|
||||
max_tools=2,
|
||||
selection_tools=tools,
|
||||
)
|
||||
middleware.model = model
|
||||
request = _FakeRequest(
|
||||
tools=tools,
|
||||
messages=[HumanMessage(content="帮我安排明天的行程并查天气")],
|
||||
model=_FakeModel(),
|
||||
model=model,
|
||||
)
|
||||
|
||||
with patch.object(tool_selector_module.logger, "info") as logger_info, \
|
||||
patch.object(tool_selector_module.logger, "warning") as logger_warning:
|
||||
result = middleware._process_selection_response(
|
||||
{"tools": []},
|
||||
available_tools=tools,
|
||||
valid_tool_names=[tool.name for tool in tools],
|
||||
request=request,
|
||||
state_update = asyncio.run(
|
||||
middleware.abefore_agent(request.state, runtime=None, config=None)
|
||||
)
|
||||
|
||||
assert [tool.name for tool in result.tools] == ["search", "calendar"]
|
||||
logger_info.assert_called_once_with("工具筛选结果为空,将恢复使用所有工具。")
|
||||
assert state_update == {"selected_tool_names": ["search", "calendar"]}
|
||||
logger_info.assert_called_once_with("工具筛选结果为空,将恢复使用所有工具(共 2 个)。")
|
||||
logger_warning.assert_not_called()
|
||||
|
||||
|
||||
def test_process_selection_response_logs_selected_tools():
|
||||
def test_abefore_agent_logs_selected_tools():
|
||||
"""工具筛选返回有效工具时应记录最终生效的工具名。"""
|
||||
tools = [
|
||||
SimpleNamespace(name="search", description="Search for information"),
|
||||
SimpleNamespace(name="calendar", description="Manage events"),
|
||||
]
|
||||
model = _FakeModel(content='{"tools": ["calendar"]}')
|
||||
middleware = tool_selector_module.ToolSelectorMiddleware(
|
||||
max_tools=2,
|
||||
selection_tools=tools,
|
||||
)
|
||||
middleware.model = model
|
||||
request = _FakeRequest(
|
||||
tools=tools,
|
||||
messages=[HumanMessage(content="帮我安排明天的行程并查天气")],
|
||||
model=_FakeModel(),
|
||||
model=model,
|
||||
)
|
||||
|
||||
with patch.object(tool_selector_module.logger, "info") as logger_info:
|
||||
result = middleware._process_selection_response(
|
||||
{"tools": ["calendar"]},
|
||||
available_tools=tools,
|
||||
valid_tool_names=[tool.name for tool in tools],
|
||||
request=request,
|
||||
state_update = asyncio.run(
|
||||
middleware.abefore_agent(request.state, runtime=None, config=None)
|
||||
)
|
||||
|
||||
assert [tool.name for tool in result.tools] == ["calendar"]
|
||||
assert state_update == {"selected_tool_names": ["calendar"]}
|
||||
logger_info.assert_called_once_with("工具筛选结果: calendar")
|
||||
|
||||
|
||||
def test_abefore_agent_logs_skipped_selection():
|
||||
"""工具筛选未启用时也应记录跳过原因。"""
|
||||
middleware = tool_selector_module.ToolSelectorMiddleware(selection_tools=[])
|
||||
request_state = {"messages": [HumanMessage(content="帮我安排明天的行程")]}
|
||||
|
||||
with patch.object(tool_selector_module.logger, "info") as logger_info:
|
||||
state_update = asyncio.run(
|
||||
middleware.abefore_agent(request_state, runtime=None, config=None)
|
||||
)
|
||||
|
||||
assert state_update == {"selected_tool_names": None}
|
||||
logger_info.assert_called_once_with("工具筛选跳过: 没有可筛选工具。")
|
||||
|
||||
|
||||
def test_normalize_selection_response_accepts_code_fence_json():
|
||||
"""工具筛选响应应兼容 Markdown 代码围栏包裹的 JSON。"""
|
||||
middleware = tool_selector_module.ToolSelectorMiddleware()
|
||||
|
||||
Reference in New Issue
Block a user