feat(activity_log): simplify user text matching and enhance logging for captcha tool

This commit is contained in:
jxxghp
2026-06-19 21:25:38 +08:00
parent 7f1cb40421
commit 66feacb48d
6 changed files with 442 additions and 165 deletions

View File

@@ -47,11 +47,6 @@ MAX_LOG_FILE_SIZE = 256 * 1024
# 提取本轮对话上下文的最大字符数(避免过长的对话消耗太多 token
MAX_CONTEXT_FOR_SUMMARY = 4000
TRIVIAL_USER_TEXT_PATTERN = re.compile(
r"^\s*(你好|您好|hi|hello|hey|谢谢|谢了|多谢|ok|好的|收到|嗯|嗯嗯|是的|对|可以|行|好)\s*[。.!]?\s*$",
re.IGNORECASE,
)
SUMMARY_SKIP_MARKER = "SKIP"
# LLM 总结的提示词
@@ -343,15 +338,7 @@ def _should_skip_activity_summary(round_messages: list) -> bool:
if has_tool_activity:
return False
user_text_parts = []
for msg in round_messages:
if not isinstance(msg, HumanMessage):
continue
content = msg.content if isinstance(msg.content, str) else str(msg.content)
if content:
user_text_parts.append(content)
user_text = " ".join(user_text_parts).strip()
return bool(user_text and TRIVIAL_USER_TEXT_PATTERN.match(user_text))
return True
async def _summarize_with_llm(conversation_text: str) -> Optional[str]:

View File

@@ -1,5 +1,6 @@
"""MoviePilot 自定义工具筛选中间件。"""
from dataclasses import replace
import json
from collections.abc import Awaitable, Callable
from typing import Annotated, Any, NotRequired
@@ -19,137 +20,39 @@ from langchain.agents.middleware.tool_selection import (
LLMToolSelectorMiddleware,
)
from langchain_core.language_models.chat_models import BaseChatModel
from langchain_core.messages import AIMessage, HumanMessage
from langchain_core.runnables import RunnableConfig
from langchain_core.tools import BaseTool
from langgraph.runtime import Runtime
from typing_extensions import TypedDict # noqa
from app.agent.tools.tags import ToolTag
from app.log import logger
MIN_SELECTED_TOOL_COUNT = 4
RECENT_SELECTION_CONTEXT_MESSAGE_LIMIT = 6
RECENT_SELECTION_CONTEXT_MAX_CHARS = 6000
RECENT_SELECTION_CONTEXT_TRUNCATION_PREFIX = "..."
TOOL_GROUP_EXCLUDED_TAGS = frozenset(
{
ToolTag.AgentTool.value,
ToolTag.Read.value,
ToolTag.Write.value,
ToolTag.Admin.value,
ToolTag.Message.value,
ToolTag.UserInteraction.value,
ToolTag.TerminalResponse.value,
}
)
MOVIEPILOT_TOOL_SELECTION_HINT = """
MoviePilot tool-chain hints:
- For media search and download tasks, keep related steps together when relevant:
search_media, search_torrents, get_search_results, add_download_tasks, query_download_tasks.
- For file organization and library transfer tasks, keep related steps together when relevant:
list_directory, query_directory_settings, recognize_media, query_library_exists, transfer_file, query_transfer_history, scrape_metadata.
- For subscription tasks, keep related steps together when relevant:
search_subscribe, add_subscribe, query_subscribes, update_subscribe, query_subscribe_history, query_popular_subscribes.
- For download management tasks, keep related steps together when relevant:
query_download_tasks, update_download_tasks, delete_download_tasks, query_downloaders.
- For site diagnostics or maintenance tasks, keep related steps together when relevant:
query_sites, query_site_userdata, test_site, update_site, update_site_cookie.
- For scheduler and workflow tasks, keep related steps together when relevant:
query_schedulers, run_scheduler, query_workflows, run_workflow, query_episode_schedule.
- For plugin tasks, keep related steps together when relevant:
query_installed_plugins, query_market_plugins, query_plugin_capabilities, query_plugin_config, update_plugin_config, query_plugin_data, install_plugin, uninstall_plugin, reload_plugin.
- For rule, identifier, or system setting tasks, keep related steps together when relevant:
query_rule_groups, query_builtin_filter_rules, query_custom_filter_rules, add_custom_filter_rule, update_custom_filter_rule, delete_custom_filter_rule, add_rule_group, update_rule_group, delete_rule_group, query_custom_identifiers, update_custom_identifiers, query_system_settings, update_system_settings.
- Prefer including the likely next-step tools in the same workflow instead of selecting only the first tool.
- Tools with the same capability tag belong to the same functional group.
- For multi-step MoviePilot tasks, keep same-tag tools together when relevant.
- Prefer selecting likely next-step tools in the same capability group instead of selecting only the first tool.
"""
TOOL_CHAIN_GROUPS = (
(
"media_download",
(
"search_media",
"search_torrents",
"get_search_results",
"add_download_tasks",
"query_download_tasks",
"query_downloaders",
),
),
(
"library_transfer",
(
"list_directory",
"query_directory_settings",
"recognize_media",
"query_library_exists",
"transfer_file",
"query_transfer_history",
"scrape_metadata",
),
),
(
"subscription",
(
"search_subscribe",
"add_subscribe",
"query_subscribes",
"update_subscribe",
"delete_subscribe",
"query_subscribe_history",
"query_popular_subscribes",
"query_subscribe_shares",
),
),
(
"download_management",
(
"query_download_tasks",
"update_download_tasks",
"delete_download_tasks",
"query_downloaders",
),
),
(
"site_management",
(
"query_sites",
"query_site_userdata",
"test_site",
"update_site",
"update_site_cookie",
),
),
(
"workflow_scheduler",
(
"query_schedulers",
"run_scheduler",
"query_workflows",
"run_workflow",
"query_episode_schedule",
),
),
(
"plugin_management",
(
"query_installed_plugins",
"query_market_plugins",
"query_plugin_capabilities",
"query_plugin_config",
"update_plugin_config",
"query_plugin_data",
"install_plugin",
"uninstall_plugin",
"reload_plugin",
),
),
(
"rule_settings",
(
"query_rule_groups",
"query_builtin_filter_rules",
"query_custom_filter_rules",
"add_custom_filter_rule",
"update_custom_filter_rule",
"delete_custom_filter_rule",
"add_rule_group",
"update_rule_group",
"delete_rule_group",
"query_custom_identifiers",
"update_custom_identifiers",
"query_system_settings",
"update_system_settings",
),
),
)
class ToolSelectionState(AgentState):
"""工具筛选中间件私有状态。"""
@@ -203,6 +106,73 @@ class ToolSelectorMiddleware(LLMToolSelectorMiddleware):
)
self.selection_tools = selection_tools or []
@classmethod
def _render_recent_conversation_context(
cls,
messages: list[Any],
) -> tuple[str, int]:
"""渲染最近对话上下文,供工具筛选模型理解多轮追问。"""
rendered_messages = []
for message in messages:
if isinstance(message, HumanMessage):
role = "User"
elif isinstance(message, AIMessage):
role = "Assistant"
else:
continue
content = cls._extract_text_content(message.content).strip()
if not content:
continue
rendered_messages.append(f"{role}: {content}")
recent_messages = rendered_messages[-RECENT_SELECTION_CONTEXT_MESSAGE_LIMIT:]
context = "\n\n".join(recent_messages)
if len(context) > RECENT_SELECTION_CONTEXT_MAX_CHARS:
context = (
f"{RECENT_SELECTION_CONTEXT_TRUNCATION_PREFIX}"
f"{context[-RECENT_SELECTION_CONTEXT_MAX_CHARS:]}"
)
return context, len(recent_messages)
@classmethod
def _build_contextual_user_message(
cls,
messages: list[Any],
last_user_message: HumanMessage,
) -> HumanMessage:
"""根据最近对话构造工具筛选专用用户消息。"""
context, message_count = cls._render_recent_conversation_context(messages)
if message_count <= 1:
return last_user_message
return HumanMessage(
content=(
"Recent conversation context for tool selection:\n"
f"{context}\n\n"
"Select tools for the latest user instruction. Use prior assistant "
"messages and earlier user requests when the latest user message "
"depends on previous context."
)
)
def _prepare_selection_request(
self,
request: ModelRequest[ContextT],
) -> Any | None:
"""准备带最近对话上下文的工具筛选请求。"""
selection_request = super()._prepare_selection_request(request)
if selection_request is None:
return None
contextual_user_message = self._build_contextual_user_message(
messages=request.messages,
last_user_message=selection_request.last_user_message,
)
if contextual_user_message is selection_request.last_user_message:
return selection_request
return replace(selection_request, last_user_message=contextual_user_message)
@staticmethod
def _append_tool_selection_hint(system_prompt: str) -> str:
"""追加 MoviePilot 工具组选择提示,避免复杂链路只选中首个工具。"""
@@ -216,34 +186,123 @@ class ToolSelectorMiddleware(LLMToolSelectorMiddleware):
return min(self.max_tools, len(valid_tool_names))
return len(valid_tool_names)
@staticmethod
def _normalize_tool_tags(tool: BaseTool) -> list[str]:
"""读取工具的业务标签,过滤掉无法表达工具组的通用标签。"""
tags = getattr(tool, "tags", None) or []
if isinstance(tags, str):
tags = [tags]
normalized_tags = []
for tag in tags:
tag_value = getattr(tag, "value", tag)
if not tag_value:
continue
tag_name = str(tag_value)
if tag_name in TOOL_GROUP_EXCLUDED_TAGS or tag_name in normalized_tags:
continue
normalized_tags.append(tag_name)
return normalized_tags
@classmethod
def _build_tool_groups(
cls,
available_tools: list[BaseTool],
valid_tool_names: list[str],
) -> list[tuple[str, list[str]]]:
"""根据工具标签构造能力组,保留当前工具列表中的稳定顺序。"""
valid_tool_set = set(valid_tool_names)
tool_groups: dict[str, list[str]] = {}
for tool in available_tools:
tool_name = getattr(tool, "name", None)
if not tool_name or tool_name not in valid_tool_set:
continue
for tag in cls._normalize_tool_tags(tool):
group_tool_names = tool_groups.setdefault(tag, [])
if tool_name not in group_tool_names:
group_tool_names.append(tool_name)
return [
(tag, tool_names)
for tag, tool_names in tool_groups.items()
if len(tool_names) > 1
]
@classmethod
def _get_matched_tool_groups(
cls,
selected_names: list[str],
available_tools: list[BaseTool],
valid_tool_names: list[str],
) -> list[tuple[str, list[str]]]:
"""返回已选工具命中的标签能力组。"""
groups_by_tag = {
tag: tool_names
for tag, tool_names in cls._build_tool_groups(
available_tools=available_tools,
valid_tool_names=valid_tool_names,
)
}
tools_by_name = {
tool.name: tool
for tool in available_tools
if getattr(tool, "name", None)
}
matched_groups: list[tuple[str, list[str]]] = []
seen_tags = set()
for tool_name in selected_names:
tool = tools_by_name.get(tool_name)
if not tool:
continue
for tag in cls._normalize_tool_tags(tool):
if tag in seen_tags or tag not in groups_by_tag:
continue
matched_groups.append((tag, groups_by_tag[tag]))
seen_tags.add(tag)
return matched_groups
def _complete_low_count_selection(
self,
selected_tool_names: list[str],
valid_tool_names: list[str],
available_tools: list[BaseTool],
) -> list[str]:
"""
当模型只选出极少工具时,按 MoviePilot 常见工具链补齐相邻工具。
当模型只选出极少工具时,按工具标签补齐同组工具。
这只补齐已经命中的工具组,不会把所有工具组都展开,因此能降低
“选了搜索工具但漏了结果/下载工具”这类链式任务失败概率
工具标签是工具自身声明的能力归属。这里只补齐已经命中的标签组,
不会把所有工具组都展开
"""
limit = self._get_tool_selection_limit(valid_tool_names)
target_count = min(MIN_SELECTED_TOOL_COUNT, limit)
selected_names = [
tool_name
for tool_name in selected_tool_names
if tool_name in valid_tool_names
]
if len(selected_names) >= target_count:
return selected_names[:limit]
selected_set = set(selected_names)
valid_tool_set = set(valid_tool_names)
completed_names = list(selected_names)
matched_groups = self._get_matched_tool_groups(
selected_names=selected_names,
available_tools=available_tools,
valid_tool_names=valid_tool_names,
)
if not matched_groups:
return completed_names[:limit]
for _, group_tool_names in TOOL_CHAIN_GROUPS:
if not selected_set.intersection(group_tool_names):
continue
matched_group_tool_names = {
tool_name
for _, group_tool_names in matched_groups
for tool_name in group_tool_names
}
target_count = min(
max(MIN_SELECTED_TOOL_COUNT, len(matched_group_tool_names)),
limit,
)
if len(selected_names) >= target_count:
return selected_names[:limit]
for _, group_tool_names in matched_groups:
for tool_name in group_tool_names:
if tool_name in selected_set or tool_name not in valid_tool_set:
continue
@@ -285,6 +344,7 @@ class ToolSelectorMiddleware(LLMToolSelectorMiddleware):
if isinstance(tool_name, str)
],
valid_tool_names=valid_tool_names,
available_tools=available_tools,
)
return super()._process_selection_response(
response,
@@ -382,12 +442,35 @@ class ToolSelectorMiddleware(LLMToolSelectorMiddleware):
raise ValueError("工具筛选 JSON 顶层必须是对象")
return payload
@staticmethod
def _render_tool_list(available_tools: list[Any]) -> str:
@classmethod
def _render_tool_list(cls, available_tools: list[Any]) -> str:
"""把工具名和描述渲染成稳定的文本列表。"""
return "\n".join(
f"- {tool.name}: {tool.description}" for tool in available_tools
lines = []
for tool in available_tools:
tags = cls._normalize_tool_tags(tool)
tag_text = f" [group tags: {', '.join(tags)}]" if tags else ""
lines.append(f"- {tool.name}{tag_text}: {tool.description}")
return "\n".join(lines)
@classmethod
def _render_tool_groups(cls, available_tools: list[BaseTool]) -> str:
"""把当前可用工具按标签渲染成能力组提示。"""
valid_tool_names = [
tool.name
for tool in available_tools
if getattr(tool, "name", None)
]
groups = cls._build_tool_groups(
available_tools=available_tools,
valid_tool_names=valid_tool_names,
)
if not groups:
return ""
rendered_groups = "\n".join(
f"- {tag}: {', '.join(tool_names)}"
for tag, tool_names in groups
)
return f"Capability groups from tool tags:\n{rendered_groups}\n\n"
def _build_deepseek_selection_prompt(self, selection_request: Any) -> str:
"""
@@ -408,8 +491,10 @@ class ToolSelectorMiddleware(LLMToolSelectorMiddleware):
"- The `tools` field must be a JSON array of strings.\n"
"- Only use tool names from the allowed list below.\n"
"- Order tools by relevance, with the most relevant first.\n"
"- Tools sharing the same capability tag are in the same group; include same-group tools together when relevant.\n"
f"{limit_instruction}\n"
"- Do not add explanations, markdown, or extra keys.\n\n"
f"{self._render_tool_groups(selection_request.available_tools)}"
"Allowed tools:\n"
f"{self._render_tool_list(selection_request.available_tools)}"
)

View File

@@ -52,6 +52,7 @@ class RecognizeCaptchaTool(MoviePilotTool):
tags: list[str] = [
ToolTag.Read,
ToolTag.Web,
ToolTag.Site,
]
description: str = (
"Recognize a graphic captcha image and return the captcha text. "
@@ -70,6 +71,21 @@ class RecognizeCaptchaTool(MoviePilotTool):
return "识别图形验证码: data image"
return f"识别图形验证码: {image_url}"
@staticmethod
def _format_image_url_for_log(image_url: str) -> str:
"""生成验证码图片地址的安全日志摘要,避免 data URL 图片刷屏。"""
clean_url = (image_url or "").strip()
if not clean_url:
return ""
if clean_url.lower().startswith("data:image/"):
metadata, separator, data = clean_url.partition(",")
if separator:
return f"{metadata},<base64:{len(data)} chars>"
return f"data:image,<invalid:{len(clean_url)} chars>"
if len(clean_url) > 300:
return f"{clean_url[:300]}...(已截断,总长度: {len(clean_url)})"
return clean_url
@staticmethod
def _recognize_captcha_sync(
image_url: str,
@@ -117,7 +133,10 @@ class RecognizeCaptchaTool(MoviePilotTool):
:param allow_private_network: 是否允许访问本机或私网地址
:return: JSON 格式的识别结果
"""
logger.info(f"执行工具: {self.name}, 参数: image_url={image_url}")
logger.info(
f"执行工具: {self.name}, "
f"参数: image_url={self._format_image_url_for_log(image_url)}"
)
try:
captcha_text = await self.run_blocking(

View File

@@ -31,7 +31,7 @@ class OcrHelper:
if image_url:
data_url_b64 = self._extract_data_url_base64(image_url)
if data_url_b64:
image_b64 = data_url_b64
image_b64 = self._normalize_image_base64(data_url_b64)
else:
ret = RequestUtils(ua=ua,
cookies=cookie).get_res(image_url)
@@ -54,7 +54,14 @@ class OcrHelper:
"""规范化外部传入的图片 base64 内容。"""
if not image_b64:
return ""
return OcrHelper._extract_data_url_base64(image_b64) or image_b64.strip()
clean_image_b64 = OcrHelper._extract_data_url_base64(image_b64) or image_b64
clean_image_b64 = "".join(clean_image_b64.split())
if not clean_image_b64:
return ""
padding_size = len(clean_image_b64) % 4
if padding_size:
clean_image_b64 = f"{clean_image_b64}{'=' * (4 - padding_size)}"
return clean_image_b64
@staticmethod
def _extract_data_url_base64(image_url: Optional[str]) -> str:

View File

@@ -82,6 +82,35 @@ def test_ocr_helper_extracts_data_url_base64_without_downloading_image():
}
def test_ocr_helper_normalizes_data_url_base64_padding():
"""data:image 地址缺少 padding 时应补齐后提交给 OCR 服务。"""
image_url = "data:image/jpeg;base64,YWJjZA"
with patch("app.helper.ocr.RequestUtils") as request_utils:
request_utils.return_value.post_res.return_value = _FakeResponse(
payload={"result": "z9k2"}
)
result = OcrHelper().get_captcha_text(image_url=image_url)
assert result == "z9k2"
request_utils.return_value.get_res.assert_not_called()
assert request_utils.return_value.post_res.call_args.kwargs["json"] == {
"base64_img": "YWJjZA=="
}
def test_recognize_captcha_tool_formats_data_url_for_log():
"""验证码工具日志应隐藏 data:image 的完整图片内容。"""
image_b64 = base64.b64encode(b"captcha-image").decode()
image_url = f"data:image/jpeg;base64,{image_b64}"
result = RecognizeCaptchaTool._format_image_url_for_log(image_url)
assert result == f"data:image/jpeg;base64,<base64:{len(image_b64)} chars>"
assert image_b64 not in result
def test_recognize_captcha_tool_returns_captcha_text_from_ocr_helper():
"""验证码工具应返回结构化识别结果,便于 Agent 继续填写表单。"""
tool = RecognizeCaptchaTool(session_id="captcha-session", user_id="10001")

View File

@@ -3,9 +3,10 @@ import unittest
from types import SimpleNamespace
from unittest.mock import patch
from langchain_core.messages import HumanMessage
from langchain_core.messages import AIMessage, HumanMessage
from app.agent.middleware import tool_selection as tool_selector_module
from app.agent.tools.tags import ToolTag
class _FakeBoundModel:
@@ -60,6 +61,11 @@ class _FakeRequest:
return _FakeRequest(**data)
def _tool(name, description, tags=None):
"""构造测试用工具对象。"""
return SimpleNamespace(name=name, description=description, tags=tags or [])
class ToolSelectorMiddlewareTest(unittest.TestCase):
def test_awrap_model_call_uses_json_mode_for_deepseek(self):
tools = [
@@ -223,14 +229,115 @@ class ToolSelectorMiddlewareTest(unittest.TestCase):
self.assertEqual(normalized, {"tools": ["search"]})
def test_process_selection_response_completes_low_count_tool_chain(self):
"""筛选结果过少时应按已命中的工具链补齐相邻工具"""
def test_deepseek_selection_uses_recent_conversation_context(self):
"""多轮追问时工具筛选应看到上一轮用户需求和助手回复"""
tools = [
SimpleNamespace(name="search_media", description="Search media"),
SimpleNamespace(name="search_torrents", description="Search torrents"),
SimpleNamespace(name="get_search_results", description="Get results"),
SimpleNamespace(name="add_download_tasks", description="Add downloads"),
SimpleNamespace(name="query_download_tasks", description="Query downloads"),
_tool(
"query_plugin_config",
"Query plugin config",
[ToolTag.Read, ToolTag.Plugin, ToolTag.Settings],
),
_tool(
"update_plugin_config",
"Update plugin config",
[ToolTag.Write, ToolTag.Plugin, ToolTag.Settings],
),
_tool(
"reload_plugin",
"Reload plugin",
[ToolTag.Write, ToolTag.Plugin],
),
]
model = _FakeModel(content='{"tools": ["query_plugin_config"]}')
middleware = tool_selector_module.ToolSelectorMiddleware(
max_tools=3,
selection_tools=tools,
)
middleware.model = model
request = _FakeRequest(
tools=tools,
messages=[
HumanMessage(content="帮我检查插件 DemoPlugin 的配置"),
AIMessage(content="我建议先查询插件配置,然后根据结果决定是否重载插件。"),
HumanMessage(content="按你说的来"),
],
model=model,
)
state_update = asyncio.run(
middleware.abefore_agent(request.state, runtime=None, config=None)
)
user_message = model.bound_model.messages[1]
self.assertEqual(
state_update,
{"selected_tool_names": ["query_plugin_config", "update_plugin_config", "reload_plugin"]},
)
self.assertIsInstance(user_message, HumanMessage)
self.assertIn(
"Recent conversation context for tool selection",
user_message.content,
)
self.assertIn("帮我检查插件 DemoPlugin 的配置", user_message.content)
self.assertIn("我建议先查询插件配置", user_message.content)
self.assertIn("按你说的来", user_message.content)
def test_single_turn_selection_keeps_original_user_message(self):
"""单轮对话不应额外包裹上下文提示。"""
tools = [
_tool("search", "Search for information", [ToolTag.Read, ToolTag.Web]),
_tool("calendar", "Manage events", [ToolTag.Write]),
]
model = _FakeModel(content='{"tools": ["search"]}')
middleware = tool_selector_module.ToolSelectorMiddleware(
max_tools=2,
selection_tools=tools,
)
middleware.model = model
original_message = HumanMessage(content="帮我查一下最近的更新")
request = _FakeRequest(
tools=tools,
messages=[original_message],
model=model,
)
asyncio.run(middleware.abefore_agent(request.state, runtime=None, config=None))
user_message = model.bound_model.messages[1]
self.assertIs(user_message, original_message)
self.assertNotIn(
"Recent conversation context for tool selection",
user_message.content,
)
def test_process_selection_response_completes_low_count_tool_group_by_tags(self):
"""筛选结果过少时应按已命中的工具标签组补齐同组工具。"""
tools = [
_tool(
"search_media",
"Search media",
[ToolTag.Read, ToolTag.Media],
),
_tool(
"search_torrents",
"Search torrents",
[ToolTag.Read, ToolTag.Resource, ToolTag.Site, ToolTag.Media],
),
_tool(
"get_search_results",
"Get results",
[ToolTag.Read, ToolTag.Resource],
),
_tool(
"add_download_tasks",
"Add downloads",
[ToolTag.Write, ToolTag.Download, ToolTag.Resource],
),
_tool(
"query_download_tasks",
"Query downloads",
[ToolTag.Read, ToolTag.Download],
),
]
middleware = tool_selector_module.ToolSelectorMiddleware(
max_tools=4,
@@ -243,20 +350,21 @@ class ToolSelectorMiddlewareTest(unittest.TestCase):
)
result = middleware._process_selection_response(
{"tools": ["search_media"]},
{"tools": ["search_torrents"]},
available_tools=tools,
valid_tool_names=[tool.name for tool in tools],
request=request,
)
self.assertEqual(len(result.tools), 4)
self.assertEqual(
[tool.name for tool in result.tools],
[
{tool.name for tool in result.tools},
{
"search_media",
"search_torrents",
"get_search_results",
"add_download_tasks",
],
},
)
def test_process_selection_response_keeps_high_count_selection(self):
@@ -302,12 +410,28 @@ class ToolSelectorMiddlewareTest(unittest.TestCase):
)
def test_process_selection_response_respects_max_tools_when_completing(self):
"""工具链补齐不应突破 max_tools 上限。"""
"""标签组补齐不应突破 max_tools 上限。"""
tools = [
SimpleNamespace(name="list_directory", description="List directory"),
SimpleNamespace(name="query_directory_settings", description="Query settings"),
SimpleNamespace(name="recognize_media", description="Recognize media"),
SimpleNamespace(name="transfer_file", description="Transfer file"),
_tool(
"list_directory",
"List directory",
[ToolTag.Read, ToolTag.Directory, ToolTag.File],
),
_tool(
"query_directory_settings",
"Query settings",
[ToolTag.Read, ToolTag.Directory, ToolTag.Settings],
),
_tool(
"recognize_media",
"Recognize media",
[ToolTag.Read, ToolTag.Media],
),
_tool(
"transfer_file",
"Transfer file",
[ToolTag.Write, ToolTag.Transfer, ToolTag.Library, ToolTag.File],
),
]
middleware = tool_selector_module.ToolSelectorMiddleware(
max_tools=2,
@@ -331,3 +455,29 @@ class ToolSelectorMiddlewareTest(unittest.TestCase):
{tool.name for tool in result.tools},
{"transfer_file", "list_directory"},
)
def test_process_selection_response_ignores_generic_tags_when_completing(self):
"""通用权限标签不应被当作工具组使用。"""
tools = [
_tool("read_one", "Read one", [ToolTag.Read]),
_tool("read_two", "Read two", [ToolTag.Read]),
_tool("write_one", "Write one", [ToolTag.Write, ToolTag.Admin]),
]
middleware = tool_selector_module.ToolSelectorMiddleware(
max_tools=4,
selection_tools=tools,
)
request = _FakeRequest(
tools=tools,
messages=[HumanMessage(content="查一下信息")],
model=_FakeModel(),
)
result = middleware._process_selection_response(
{"tools": ["read_one"]},
available_tools=tools,
valid_tool_names=[tool.name for tool in tools],
request=request,
)
self.assertEqual([tool.name for tool in result.tools], ["read_one"])