diff --git a/app/agent/middleware/activity_log.py b/app/agent/middleware/activity_log.py index 4a5f5fca..ceddef12 100644 --- a/app/agent/middleware/activity_log.py +++ b/app/agent/middleware/activity_log.py @@ -47,11 +47,6 @@ MAX_LOG_FILE_SIZE = 256 * 1024 # 提取本轮对话上下文的最大字符数(避免过长的对话消耗太多 token) MAX_CONTEXT_FOR_SUMMARY = 4000 -TRIVIAL_USER_TEXT_PATTERN = re.compile( - r"^\s*(你好|您好|hi|hello|hey|谢谢|谢了|多谢|ok|好的|收到|嗯|嗯嗯|是的|对|可以|行|好)\s*[。.!!]?\s*$", - re.IGNORECASE, -) - SUMMARY_SKIP_MARKER = "SKIP" # LLM 总结的提示词 @@ -343,15 +338,7 @@ def _should_skip_activity_summary(round_messages: list) -> bool: if has_tool_activity: return False - user_text_parts = [] - for msg in round_messages: - if not isinstance(msg, HumanMessage): - continue - content = msg.content if isinstance(msg.content, str) else str(msg.content) - if content: - user_text_parts.append(content) - user_text = " ".join(user_text_parts).strip() - return bool(user_text and TRIVIAL_USER_TEXT_PATTERN.match(user_text)) + return True async def _summarize_with_llm(conversation_text: str) -> Optional[str]: diff --git a/app/agent/middleware/tool_selection.py b/app/agent/middleware/tool_selection.py index 2ba3dc79..1a0082f3 100644 --- a/app/agent/middleware/tool_selection.py +++ b/app/agent/middleware/tool_selection.py @@ -1,5 +1,6 @@ """MoviePilot 自定义工具筛选中间件。""" +from dataclasses import replace import json from collections.abc import Awaitable, Callable from typing import Annotated, Any, NotRequired @@ -19,137 +20,39 @@ from langchain.agents.middleware.tool_selection import ( LLMToolSelectorMiddleware, ) from langchain_core.language_models.chat_models import BaseChatModel +from langchain_core.messages import AIMessage, HumanMessage from langchain_core.runnables import RunnableConfig from langchain_core.tools import BaseTool from langgraph.runtime import Runtime from typing_extensions import TypedDict # noqa +from app.agent.tools.tags import ToolTag from app.log import logger MIN_SELECTED_TOOL_COUNT = 4 +RECENT_SELECTION_CONTEXT_MESSAGE_LIMIT = 6 +RECENT_SELECTION_CONTEXT_MAX_CHARS = 6000 +RECENT_SELECTION_CONTEXT_TRUNCATION_PREFIX = "..." +TOOL_GROUP_EXCLUDED_TAGS = frozenset( + { + ToolTag.AgentTool.value, + ToolTag.Read.value, + ToolTag.Write.value, + ToolTag.Admin.value, + ToolTag.Message.value, + ToolTag.UserInteraction.value, + ToolTag.TerminalResponse.value, + } +) MOVIEPILOT_TOOL_SELECTION_HINT = """ MoviePilot tool-chain hints: -- For media search and download tasks, keep related steps together when relevant: - search_media, search_torrents, get_search_results, add_download_tasks, query_download_tasks. -- For file organization and library transfer tasks, keep related steps together when relevant: - list_directory, query_directory_settings, recognize_media, query_library_exists, transfer_file, query_transfer_history, scrape_metadata. -- For subscription tasks, keep related steps together when relevant: - search_subscribe, add_subscribe, query_subscribes, update_subscribe, query_subscribe_history, query_popular_subscribes. -- For download management tasks, keep related steps together when relevant: - query_download_tasks, update_download_tasks, delete_download_tasks, query_downloaders. -- For site diagnostics or maintenance tasks, keep related steps together when relevant: - query_sites, query_site_userdata, test_site, update_site, update_site_cookie. -- For scheduler and workflow tasks, keep related steps together when relevant: - query_schedulers, run_scheduler, query_workflows, run_workflow, query_episode_schedule. -- For plugin tasks, keep related steps together when relevant: - query_installed_plugins, query_market_plugins, query_plugin_capabilities, query_plugin_config, update_plugin_config, query_plugin_data, install_plugin, uninstall_plugin, reload_plugin. -- For rule, identifier, or system setting tasks, keep related steps together when relevant: - query_rule_groups, query_builtin_filter_rules, query_custom_filter_rules, add_custom_filter_rule, update_custom_filter_rule, delete_custom_filter_rule, add_rule_group, update_rule_group, delete_rule_group, query_custom_identifiers, update_custom_identifiers, query_system_settings, update_system_settings. -- Prefer including the likely next-step tools in the same workflow instead of selecting only the first tool. +- Tools with the same capability tag belong to the same functional group. +- For multi-step MoviePilot tasks, keep same-tag tools together when relevant. +- Prefer selecting likely next-step tools in the same capability group instead of selecting only the first tool. """ -TOOL_CHAIN_GROUPS = ( - ( - "media_download", - ( - "search_media", - "search_torrents", - "get_search_results", - "add_download_tasks", - "query_download_tasks", - "query_downloaders", - ), - ), - ( - "library_transfer", - ( - "list_directory", - "query_directory_settings", - "recognize_media", - "query_library_exists", - "transfer_file", - "query_transfer_history", - "scrape_metadata", - ), - ), - ( - "subscription", - ( - "search_subscribe", - "add_subscribe", - "query_subscribes", - "update_subscribe", - "delete_subscribe", - "query_subscribe_history", - "query_popular_subscribes", - "query_subscribe_shares", - ), - ), - ( - "download_management", - ( - "query_download_tasks", - "update_download_tasks", - "delete_download_tasks", - "query_downloaders", - ), - ), - ( - "site_management", - ( - "query_sites", - "query_site_userdata", - "test_site", - "update_site", - "update_site_cookie", - ), - ), - ( - "workflow_scheduler", - ( - "query_schedulers", - "run_scheduler", - "query_workflows", - "run_workflow", - "query_episode_schedule", - ), - ), - ( - "plugin_management", - ( - "query_installed_plugins", - "query_market_plugins", - "query_plugin_capabilities", - "query_plugin_config", - "update_plugin_config", - "query_plugin_data", - "install_plugin", - "uninstall_plugin", - "reload_plugin", - ), - ), - ( - "rule_settings", - ( - "query_rule_groups", - "query_builtin_filter_rules", - "query_custom_filter_rules", - "add_custom_filter_rule", - "update_custom_filter_rule", - "delete_custom_filter_rule", - "add_rule_group", - "update_rule_group", - "delete_rule_group", - "query_custom_identifiers", - "update_custom_identifiers", - "query_system_settings", - "update_system_settings", - ), - ), -) - class ToolSelectionState(AgentState): """工具筛选中间件私有状态。""" @@ -203,6 +106,73 @@ class ToolSelectorMiddleware(LLMToolSelectorMiddleware): ) self.selection_tools = selection_tools or [] + @classmethod + def _render_recent_conversation_context( + cls, + messages: list[Any], + ) -> tuple[str, int]: + """渲染最近对话上下文,供工具筛选模型理解多轮追问。""" + rendered_messages = [] + for message in messages: + if isinstance(message, HumanMessage): + role = "User" + elif isinstance(message, AIMessage): + role = "Assistant" + else: + continue + + content = cls._extract_text_content(message.content).strip() + if not content: + continue + rendered_messages.append(f"{role}: {content}") + + recent_messages = rendered_messages[-RECENT_SELECTION_CONTEXT_MESSAGE_LIMIT:] + context = "\n\n".join(recent_messages) + if len(context) > RECENT_SELECTION_CONTEXT_MAX_CHARS: + context = ( + f"{RECENT_SELECTION_CONTEXT_TRUNCATION_PREFIX}" + f"{context[-RECENT_SELECTION_CONTEXT_MAX_CHARS:]}" + ) + return context, len(recent_messages) + + @classmethod + def _build_contextual_user_message( + cls, + messages: list[Any], + last_user_message: HumanMessage, + ) -> HumanMessage: + """根据最近对话构造工具筛选专用用户消息。""" + context, message_count = cls._render_recent_conversation_context(messages) + if message_count <= 1: + return last_user_message + + return HumanMessage( + content=( + "Recent conversation context for tool selection:\n" + f"{context}\n\n" + "Select tools for the latest user instruction. Use prior assistant " + "messages and earlier user requests when the latest user message " + "depends on previous context." + ) + ) + + def _prepare_selection_request( + self, + request: ModelRequest[ContextT], + ) -> Any | None: + """准备带最近对话上下文的工具筛选请求。""" + selection_request = super()._prepare_selection_request(request) + if selection_request is None: + return None + + contextual_user_message = self._build_contextual_user_message( + messages=request.messages, + last_user_message=selection_request.last_user_message, + ) + if contextual_user_message is selection_request.last_user_message: + return selection_request + return replace(selection_request, last_user_message=contextual_user_message) + @staticmethod def _append_tool_selection_hint(system_prompt: str) -> str: """追加 MoviePilot 工具组选择提示,避免复杂链路只选中首个工具。""" @@ -216,34 +186,123 @@ class ToolSelectorMiddleware(LLMToolSelectorMiddleware): return min(self.max_tools, len(valid_tool_names)) return len(valid_tool_names) + @staticmethod + def _normalize_tool_tags(tool: BaseTool) -> list[str]: + """读取工具的业务标签,过滤掉无法表达工具组的通用标签。""" + tags = getattr(tool, "tags", None) or [] + if isinstance(tags, str): + tags = [tags] + + normalized_tags = [] + for tag in tags: + tag_value = getattr(tag, "value", tag) + if not tag_value: + continue + tag_name = str(tag_value) + if tag_name in TOOL_GROUP_EXCLUDED_TAGS or tag_name in normalized_tags: + continue + normalized_tags.append(tag_name) + return normalized_tags + + @classmethod + def _build_tool_groups( + cls, + available_tools: list[BaseTool], + valid_tool_names: list[str], + ) -> list[tuple[str, list[str]]]: + """根据工具标签构造能力组,保留当前工具列表中的稳定顺序。""" + valid_tool_set = set(valid_tool_names) + tool_groups: dict[str, list[str]] = {} + for tool in available_tools: + tool_name = getattr(tool, "name", None) + if not tool_name or tool_name not in valid_tool_set: + continue + for tag in cls._normalize_tool_tags(tool): + group_tool_names = tool_groups.setdefault(tag, []) + if tool_name not in group_tool_names: + group_tool_names.append(tool_name) + + return [ + (tag, tool_names) + for tag, tool_names in tool_groups.items() + if len(tool_names) > 1 + ] + + @classmethod + def _get_matched_tool_groups( + cls, + selected_names: list[str], + available_tools: list[BaseTool], + valid_tool_names: list[str], + ) -> list[tuple[str, list[str]]]: + """返回已选工具命中的标签能力组。""" + groups_by_tag = { + tag: tool_names + for tag, tool_names in cls._build_tool_groups( + available_tools=available_tools, + valid_tool_names=valid_tool_names, + ) + } + tools_by_name = { + tool.name: tool + for tool in available_tools + if getattr(tool, "name", None) + } + matched_groups: list[tuple[str, list[str]]] = [] + seen_tags = set() + for tool_name in selected_names: + tool = tools_by_name.get(tool_name) + if not tool: + continue + for tag in cls._normalize_tool_tags(tool): + if tag in seen_tags or tag not in groups_by_tag: + continue + matched_groups.append((tag, groups_by_tag[tag])) + seen_tags.add(tag) + return matched_groups + def _complete_low_count_selection( self, selected_tool_names: list[str], valid_tool_names: list[str], + available_tools: list[BaseTool], ) -> list[str]: """ - 当模型只选出极少工具时,按 MoviePilot 常见工具链补齐相邻工具。 + 当模型只选出极少工具时,按工具标签补齐同组工具。 - 这只补齐已经命中的工具组,不会把所有工具组都展开,因此能降低 - “选了搜索工具但漏了结果/下载工具”这类链式任务失败概率。 + 工具标签是工具自身声明的能力归属。这里只补齐已经命中的标签组, + 不会把所有工具组都展开。 """ limit = self._get_tool_selection_limit(valid_tool_names) - target_count = min(MIN_SELECTED_TOOL_COUNT, limit) selected_names = [ tool_name for tool_name in selected_tool_names if tool_name in valid_tool_names ] - if len(selected_names) >= target_count: - return selected_names[:limit] - selected_set = set(selected_names) valid_tool_set = set(valid_tool_names) completed_names = list(selected_names) + matched_groups = self._get_matched_tool_groups( + selected_names=selected_names, + available_tools=available_tools, + valid_tool_names=valid_tool_names, + ) + if not matched_groups: + return completed_names[:limit] - for _, group_tool_names in TOOL_CHAIN_GROUPS: - if not selected_set.intersection(group_tool_names): - continue + matched_group_tool_names = { + tool_name + for _, group_tool_names in matched_groups + for tool_name in group_tool_names + } + target_count = min( + max(MIN_SELECTED_TOOL_COUNT, len(matched_group_tool_names)), + limit, + ) + if len(selected_names) >= target_count: + return selected_names[:limit] + + for _, group_tool_names in matched_groups: for tool_name in group_tool_names: if tool_name in selected_set or tool_name not in valid_tool_set: continue @@ -285,6 +344,7 @@ class ToolSelectorMiddleware(LLMToolSelectorMiddleware): if isinstance(tool_name, str) ], valid_tool_names=valid_tool_names, + available_tools=available_tools, ) return super()._process_selection_response( response, @@ -382,12 +442,35 @@ class ToolSelectorMiddleware(LLMToolSelectorMiddleware): raise ValueError("工具筛选 JSON 顶层必须是对象") return payload - @staticmethod - def _render_tool_list(available_tools: list[Any]) -> str: + @classmethod + def _render_tool_list(cls, available_tools: list[Any]) -> str: """把工具名和描述渲染成稳定的文本列表。""" - return "\n".join( - f"- {tool.name}: {tool.description}" for tool in available_tools + lines = [] + for tool in available_tools: + tags = cls._normalize_tool_tags(tool) + tag_text = f" [group tags: {', '.join(tags)}]" if tags else "" + lines.append(f"- {tool.name}{tag_text}: {tool.description}") + return "\n".join(lines) + + @classmethod + def _render_tool_groups(cls, available_tools: list[BaseTool]) -> str: + """把当前可用工具按标签渲染成能力组提示。""" + valid_tool_names = [ + tool.name + for tool in available_tools + if getattr(tool, "name", None) + ] + groups = cls._build_tool_groups( + available_tools=available_tools, + valid_tool_names=valid_tool_names, ) + if not groups: + return "" + rendered_groups = "\n".join( + f"- {tag}: {', '.join(tool_names)}" + for tag, tool_names in groups + ) + return f"Capability groups from tool tags:\n{rendered_groups}\n\n" def _build_deepseek_selection_prompt(self, selection_request: Any) -> str: """ @@ -408,8 +491,10 @@ class ToolSelectorMiddleware(LLMToolSelectorMiddleware): "- The `tools` field must be a JSON array of strings.\n" "- Only use tool names from the allowed list below.\n" "- Order tools by relevance, with the most relevant first.\n" + "- Tools sharing the same capability tag are in the same group; include same-group tools together when relevant.\n" f"{limit_instruction}\n" "- Do not add explanations, markdown, or extra keys.\n\n" + f"{self._render_tool_groups(selection_request.available_tools)}" "Allowed tools:\n" f"{self._render_tool_list(selection_request.available_tools)}" ) diff --git a/app/agent/tools/impl/recognize_captcha.py b/app/agent/tools/impl/recognize_captcha.py index 5f7feaf1..f36c57ca 100644 --- a/app/agent/tools/impl/recognize_captcha.py +++ b/app/agent/tools/impl/recognize_captcha.py @@ -52,6 +52,7 @@ class RecognizeCaptchaTool(MoviePilotTool): tags: list[str] = [ ToolTag.Read, ToolTag.Web, + ToolTag.Site, ] description: str = ( "Recognize a graphic captcha image and return the captcha text. " @@ -70,6 +71,21 @@ class RecognizeCaptchaTool(MoviePilotTool): return "识别图形验证码: data image" return f"识别图形验证码: {image_url}" + @staticmethod + def _format_image_url_for_log(image_url: str) -> str: + """生成验证码图片地址的安全日志摘要,避免 data URL 图片刷屏。""" + clean_url = (image_url or "").strip() + if not clean_url: + return "" + if clean_url.lower().startswith("data:image/"): + metadata, separator, data = clean_url.partition(",") + if separator: + return f"{metadata}," + return f"data:image," + if len(clean_url) > 300: + return f"{clean_url[:300]}...(已截断,总长度: {len(clean_url)})" + return clean_url + @staticmethod def _recognize_captcha_sync( image_url: str, @@ -117,7 +133,10 @@ class RecognizeCaptchaTool(MoviePilotTool): :param allow_private_network: 是否允许访问本机或私网地址 :return: JSON 格式的识别结果 """ - logger.info(f"执行工具: {self.name}, 参数: image_url={image_url}") + logger.info( + f"执行工具: {self.name}, " + f"参数: image_url={self._format_image_url_for_log(image_url)}" + ) try: captcha_text = await self.run_blocking( diff --git a/app/helper/ocr.py b/app/helper/ocr.py index 93630ed8..b92be11f 100644 --- a/app/helper/ocr.py +++ b/app/helper/ocr.py @@ -31,7 +31,7 @@ class OcrHelper: if image_url: data_url_b64 = self._extract_data_url_base64(image_url) if data_url_b64: - image_b64 = data_url_b64 + image_b64 = self._normalize_image_base64(data_url_b64) else: ret = RequestUtils(ua=ua, cookies=cookie).get_res(image_url) @@ -54,7 +54,14 @@ class OcrHelper: """规范化外部传入的图片 base64 内容。""" if not image_b64: return "" - return OcrHelper._extract_data_url_base64(image_b64) or image_b64.strip() + clean_image_b64 = OcrHelper._extract_data_url_base64(image_b64) or image_b64 + clean_image_b64 = "".join(clean_image_b64.split()) + if not clean_image_b64: + return "" + padding_size = len(clean_image_b64) % 4 + if padding_size: + clean_image_b64 = f"{clean_image_b64}{'=' * (4 - padding_size)}" + return clean_image_b64 @staticmethod def _extract_data_url_base64(image_url: Optional[str]) -> str: diff --git a/tests/test_agent_recognize_captcha_tool.py b/tests/test_agent_recognize_captcha_tool.py index 96c9c741..59f079e5 100644 --- a/tests/test_agent_recognize_captcha_tool.py +++ b/tests/test_agent_recognize_captcha_tool.py @@ -82,6 +82,35 @@ def test_ocr_helper_extracts_data_url_base64_without_downloading_image(): } +def test_ocr_helper_normalizes_data_url_base64_padding(): + """data:image 地址缺少 padding 时应补齐后提交给 OCR 服务。""" + image_url = "data:image/jpeg;base64,YWJjZA" + + with patch("app.helper.ocr.RequestUtils") as request_utils: + request_utils.return_value.post_res.return_value = _FakeResponse( + payload={"result": "z9k2"} + ) + + result = OcrHelper().get_captcha_text(image_url=image_url) + + assert result == "z9k2" + request_utils.return_value.get_res.assert_not_called() + assert request_utils.return_value.post_res.call_args.kwargs["json"] == { + "base64_img": "YWJjZA==" + } + + +def test_recognize_captcha_tool_formats_data_url_for_log(): + """验证码工具日志应隐藏 data:image 的完整图片内容。""" + image_b64 = base64.b64encode(b"captcha-image").decode() + image_url = f"data:image/jpeg;base64,{image_b64}" + + result = RecognizeCaptchaTool._format_image_url_for_log(image_url) + + assert result == f"data:image/jpeg;base64," + assert image_b64 not in result + + def test_recognize_captcha_tool_returns_captcha_text_from_ocr_helper(): """验证码工具应返回结构化识别结果,便于 Agent 继续填写表单。""" tool = RecognizeCaptchaTool(session_id="captcha-session", user_id="10001") diff --git a/tests/test_agent_tool_selector_middleware.py b/tests/test_agent_tool_selector_middleware.py index 21a38e51..32ab6dff 100644 --- a/tests/test_agent_tool_selector_middleware.py +++ b/tests/test_agent_tool_selector_middleware.py @@ -3,9 +3,10 @@ import unittest from types import SimpleNamespace from unittest.mock import patch -from langchain_core.messages import HumanMessage +from langchain_core.messages import AIMessage, HumanMessage from app.agent.middleware import tool_selection as tool_selector_module +from app.agent.tools.tags import ToolTag class _FakeBoundModel: @@ -60,6 +61,11 @@ class _FakeRequest: return _FakeRequest(**data) +def _tool(name, description, tags=None): + """构造测试用工具对象。""" + return SimpleNamespace(name=name, description=description, tags=tags or []) + + class ToolSelectorMiddlewareTest(unittest.TestCase): def test_awrap_model_call_uses_json_mode_for_deepseek(self): tools = [ @@ -223,14 +229,115 @@ class ToolSelectorMiddlewareTest(unittest.TestCase): self.assertEqual(normalized, {"tools": ["search"]}) - def test_process_selection_response_completes_low_count_tool_chain(self): - """筛选结果过少时应按已命中的工具链补齐相邻工具。""" + def test_deepseek_selection_uses_recent_conversation_context(self): + """多轮追问时工具筛选应看到上一轮用户需求和助手回复。""" tools = [ - SimpleNamespace(name="search_media", description="Search media"), - SimpleNamespace(name="search_torrents", description="Search torrents"), - SimpleNamespace(name="get_search_results", description="Get results"), - SimpleNamespace(name="add_download_tasks", description="Add downloads"), - SimpleNamespace(name="query_download_tasks", description="Query downloads"), + _tool( + "query_plugin_config", + "Query plugin config", + [ToolTag.Read, ToolTag.Plugin, ToolTag.Settings], + ), + _tool( + "update_plugin_config", + "Update plugin config", + [ToolTag.Write, ToolTag.Plugin, ToolTag.Settings], + ), + _tool( + "reload_plugin", + "Reload plugin", + [ToolTag.Write, ToolTag.Plugin], + ), + ] + model = _FakeModel(content='{"tools": ["query_plugin_config"]}') + middleware = tool_selector_module.ToolSelectorMiddleware( + max_tools=3, + selection_tools=tools, + ) + middleware.model = model + request = _FakeRequest( + tools=tools, + messages=[ + HumanMessage(content="帮我检查插件 DemoPlugin 的配置"), + AIMessage(content="我建议先查询插件配置,然后根据结果决定是否重载插件。"), + HumanMessage(content="按你说的来"), + ], + model=model, + ) + + state_update = asyncio.run( + middleware.abefore_agent(request.state, runtime=None, config=None) + ) + + user_message = model.bound_model.messages[1] + self.assertEqual( + state_update, + {"selected_tool_names": ["query_plugin_config", "update_plugin_config", "reload_plugin"]}, + ) + self.assertIsInstance(user_message, HumanMessage) + self.assertIn( + "Recent conversation context for tool selection", + user_message.content, + ) + self.assertIn("帮我检查插件 DemoPlugin 的配置", user_message.content) + self.assertIn("我建议先查询插件配置", user_message.content) + self.assertIn("按你说的来", user_message.content) + + def test_single_turn_selection_keeps_original_user_message(self): + """单轮对话不应额外包裹上下文提示。""" + tools = [ + _tool("search", "Search for information", [ToolTag.Read, ToolTag.Web]), + _tool("calendar", "Manage events", [ToolTag.Write]), + ] + model = _FakeModel(content='{"tools": ["search"]}') + middleware = tool_selector_module.ToolSelectorMiddleware( + max_tools=2, + selection_tools=tools, + ) + middleware.model = model + original_message = HumanMessage(content="帮我查一下最近的更新") + request = _FakeRequest( + tools=tools, + messages=[original_message], + model=model, + ) + + asyncio.run(middleware.abefore_agent(request.state, runtime=None, config=None)) + + user_message = model.bound_model.messages[1] + self.assertIs(user_message, original_message) + self.assertNotIn( + "Recent conversation context for tool selection", + user_message.content, + ) + + def test_process_selection_response_completes_low_count_tool_group_by_tags(self): + """筛选结果过少时应按已命中的工具标签组补齐同组工具。""" + tools = [ + _tool( + "search_media", + "Search media", + [ToolTag.Read, ToolTag.Media], + ), + _tool( + "search_torrents", + "Search torrents", + [ToolTag.Read, ToolTag.Resource, ToolTag.Site, ToolTag.Media], + ), + _tool( + "get_search_results", + "Get results", + [ToolTag.Read, ToolTag.Resource], + ), + _tool( + "add_download_tasks", + "Add downloads", + [ToolTag.Write, ToolTag.Download, ToolTag.Resource], + ), + _tool( + "query_download_tasks", + "Query downloads", + [ToolTag.Read, ToolTag.Download], + ), ] middleware = tool_selector_module.ToolSelectorMiddleware( max_tools=4, @@ -243,20 +350,21 @@ class ToolSelectorMiddlewareTest(unittest.TestCase): ) result = middleware._process_selection_response( - {"tools": ["search_media"]}, + {"tools": ["search_torrents"]}, available_tools=tools, valid_tool_names=[tool.name for tool in tools], request=request, ) + self.assertEqual(len(result.tools), 4) self.assertEqual( - [tool.name for tool in result.tools], - [ + {tool.name for tool in result.tools}, + { "search_media", "search_torrents", "get_search_results", "add_download_tasks", - ], + }, ) def test_process_selection_response_keeps_high_count_selection(self): @@ -302,12 +410,28 @@ class ToolSelectorMiddlewareTest(unittest.TestCase): ) def test_process_selection_response_respects_max_tools_when_completing(self): - """工具链补齐不应突破 max_tools 上限。""" + """标签组补齐不应突破 max_tools 上限。""" tools = [ - SimpleNamespace(name="list_directory", description="List directory"), - SimpleNamespace(name="query_directory_settings", description="Query settings"), - SimpleNamespace(name="recognize_media", description="Recognize media"), - SimpleNamespace(name="transfer_file", description="Transfer file"), + _tool( + "list_directory", + "List directory", + [ToolTag.Read, ToolTag.Directory, ToolTag.File], + ), + _tool( + "query_directory_settings", + "Query settings", + [ToolTag.Read, ToolTag.Directory, ToolTag.Settings], + ), + _tool( + "recognize_media", + "Recognize media", + [ToolTag.Read, ToolTag.Media], + ), + _tool( + "transfer_file", + "Transfer file", + [ToolTag.Write, ToolTag.Transfer, ToolTag.Library, ToolTag.File], + ), ] middleware = tool_selector_module.ToolSelectorMiddleware( max_tools=2, @@ -331,3 +455,29 @@ class ToolSelectorMiddlewareTest(unittest.TestCase): {tool.name for tool in result.tools}, {"transfer_file", "list_directory"}, ) + + def test_process_selection_response_ignores_generic_tags_when_completing(self): + """通用权限标签不应被当作工具组使用。""" + tools = [ + _tool("read_one", "Read one", [ToolTag.Read]), + _tool("read_two", "Read two", [ToolTag.Read]), + _tool("write_one", "Write one", [ToolTag.Write, ToolTag.Admin]), + ] + middleware = tool_selector_module.ToolSelectorMiddleware( + max_tools=4, + selection_tools=tools, + ) + request = _FakeRequest( + tools=tools, + messages=[HumanMessage(content="查一下信息")], + model=_FakeModel(), + ) + + result = middleware._process_selection_response( + {"tools": ["read_one"]}, + available_tools=tools, + valid_tool_names=[tool.name for tool in tools], + request=request, + ) + + self.assertEqual([tool.name for tool in result.tools], ["read_one"])