mirror of
https://github.com/jxxghp/MoviePilot.git
synced 2026-06-22 07:54:06 +08:00
feat(activity_log): simplify user text matching and enhance logging for captcha tool
This commit is contained in:
@@ -47,11 +47,6 @@ MAX_LOG_FILE_SIZE = 256 * 1024
|
||||
# 提取本轮对话上下文的最大字符数(避免过长的对话消耗太多 token)
|
||||
MAX_CONTEXT_FOR_SUMMARY = 4000
|
||||
|
||||
TRIVIAL_USER_TEXT_PATTERN = re.compile(
|
||||
r"^\s*(你好|您好|hi|hello|hey|谢谢|谢了|多谢|ok|好的|收到|嗯|嗯嗯|是的|对|可以|行|好)\s*[。.!!]?\s*$",
|
||||
re.IGNORECASE,
|
||||
)
|
||||
|
||||
SUMMARY_SKIP_MARKER = "SKIP"
|
||||
|
||||
# LLM 总结的提示词
|
||||
@@ -343,15 +338,7 @@ def _should_skip_activity_summary(round_messages: list) -> bool:
|
||||
if has_tool_activity:
|
||||
return False
|
||||
|
||||
user_text_parts = []
|
||||
for msg in round_messages:
|
||||
if not isinstance(msg, HumanMessage):
|
||||
continue
|
||||
content = msg.content if isinstance(msg.content, str) else str(msg.content)
|
||||
if content:
|
||||
user_text_parts.append(content)
|
||||
user_text = " ".join(user_text_parts).strip()
|
||||
return bool(user_text and TRIVIAL_USER_TEXT_PATTERN.match(user_text))
|
||||
return True
|
||||
|
||||
|
||||
async def _summarize_with_llm(conversation_text: str) -> Optional[str]:
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
"""MoviePilot 自定义工具筛选中间件。"""
|
||||
|
||||
from dataclasses import replace
|
||||
import json
|
||||
from collections.abc import Awaitable, Callable
|
||||
from typing import Annotated, Any, NotRequired
|
||||
@@ -19,137 +20,39 @@ from langchain.agents.middleware.tool_selection import (
|
||||
LLMToolSelectorMiddleware,
|
||||
)
|
||||
from langchain_core.language_models.chat_models import BaseChatModel
|
||||
from langchain_core.messages import AIMessage, HumanMessage
|
||||
from langchain_core.runnables import RunnableConfig
|
||||
from langchain_core.tools import BaseTool
|
||||
from langgraph.runtime import Runtime
|
||||
from typing_extensions import TypedDict # noqa
|
||||
|
||||
from app.agent.tools.tags import ToolTag
|
||||
from app.log import logger
|
||||
|
||||
# Minimum number of tools a completed selection aims for; the completion
# logic still caps it with the effective selection limit.
MIN_SELECTED_TOOL_COUNT = 4
# How many of the most recent rendered User/Assistant messages are kept as
# context for the tool selector.
RECENT_SELECTION_CONTEXT_MESSAGE_LIMIT = 6
# Maximum length (in characters) of the rendered context; longer context is
# reduced to its trailing part.
RECENT_SELECTION_CONTEXT_MAX_CHARS = 6000
# Marker placed in front of the context when its beginning was cut off.
RECENT_SELECTION_CONTEXT_TRUNCATION_PREFIX = "..."
|
||||
# Tag values that are filtered out when tool tags are normalized, so they
# never form a capability group on their own.
TOOL_GROUP_EXCLUDED_TAGS = frozenset(
    tag.value
    for tag in (
        ToolTag.AgentTool,
        ToolTag.Read,
        ToolTag.Write,
        ToolTag.Admin,
        ToolTag.Message,
        ToolTag.UserInteraction,
        ToolTag.TerminalResponse,
    )
)
|
||||
|
||||
MOVIEPILOT_TOOL_SELECTION_HINT = """
|
||||
|
||||
MoviePilot tool-chain hints:
|
||||
- For media search and download tasks, keep related steps together when relevant:
|
||||
search_media, search_torrents, get_search_results, add_download_tasks, query_download_tasks.
|
||||
- For file organization and library transfer tasks, keep related steps together when relevant:
|
||||
list_directory, query_directory_settings, recognize_media, query_library_exists, transfer_file, query_transfer_history, scrape_metadata.
|
||||
- For subscription tasks, keep related steps together when relevant:
|
||||
search_subscribe, add_subscribe, query_subscribes, update_subscribe, query_subscribe_history, query_popular_subscribes.
|
||||
- For download management tasks, keep related steps together when relevant:
|
||||
query_download_tasks, update_download_tasks, delete_download_tasks, query_downloaders.
|
||||
- For site diagnostics or maintenance tasks, keep related steps together when relevant:
|
||||
query_sites, query_site_userdata, test_site, update_site, update_site_cookie.
|
||||
- For scheduler and workflow tasks, keep related steps together when relevant:
|
||||
query_schedulers, run_scheduler, query_workflows, run_workflow, query_episode_schedule.
|
||||
- For plugin tasks, keep related steps together when relevant:
|
||||
query_installed_plugins, query_market_plugins, query_plugin_capabilities, query_plugin_config, update_plugin_config, query_plugin_data, install_plugin, uninstall_plugin, reload_plugin.
|
||||
- For rule, identifier, or system setting tasks, keep related steps together when relevant:
|
||||
query_rule_groups, query_builtin_filter_rules, query_custom_filter_rules, add_custom_filter_rule, update_custom_filter_rule, delete_custom_filter_rule, add_rule_group, update_rule_group, delete_rule_group, query_custom_identifiers, update_custom_identifiers, query_system_settings, update_system_settings.
|
||||
- Prefer including the likely next-step tools in the same workflow instead of selecting only the first tool.
|
||||
- Tools with the same capability tag belong to the same functional group.
|
||||
- For multi-step MoviePilot tasks, keep same-tag tools together when relevant.
|
||||
- Prefer selecting likely next-step tools in the same capability group instead of selecting only the first tool.
|
||||
"""
|
||||
|
||||
TOOL_CHAIN_GROUPS = (
|
||||
(
|
||||
"media_download",
|
||||
(
|
||||
"search_media",
|
||||
"search_torrents",
|
||||
"get_search_results",
|
||||
"add_download_tasks",
|
||||
"query_download_tasks",
|
||||
"query_downloaders",
|
||||
),
|
||||
),
|
||||
(
|
||||
"library_transfer",
|
||||
(
|
||||
"list_directory",
|
||||
"query_directory_settings",
|
||||
"recognize_media",
|
||||
"query_library_exists",
|
||||
"transfer_file",
|
||||
"query_transfer_history",
|
||||
"scrape_metadata",
|
||||
),
|
||||
),
|
||||
(
|
||||
"subscription",
|
||||
(
|
||||
"search_subscribe",
|
||||
"add_subscribe",
|
||||
"query_subscribes",
|
||||
"update_subscribe",
|
||||
"delete_subscribe",
|
||||
"query_subscribe_history",
|
||||
"query_popular_subscribes",
|
||||
"query_subscribe_shares",
|
||||
),
|
||||
),
|
||||
(
|
||||
"download_management",
|
||||
(
|
||||
"query_download_tasks",
|
||||
"update_download_tasks",
|
||||
"delete_download_tasks",
|
||||
"query_downloaders",
|
||||
),
|
||||
),
|
||||
(
|
||||
"site_management",
|
||||
(
|
||||
"query_sites",
|
||||
"query_site_userdata",
|
||||
"test_site",
|
||||
"update_site",
|
||||
"update_site_cookie",
|
||||
),
|
||||
),
|
||||
(
|
||||
"workflow_scheduler",
|
||||
(
|
||||
"query_schedulers",
|
||||
"run_scheduler",
|
||||
"query_workflows",
|
||||
"run_workflow",
|
||||
"query_episode_schedule",
|
||||
),
|
||||
),
|
||||
(
|
||||
"plugin_management",
|
||||
(
|
||||
"query_installed_plugins",
|
||||
"query_market_plugins",
|
||||
"query_plugin_capabilities",
|
||||
"query_plugin_config",
|
||||
"update_plugin_config",
|
||||
"query_plugin_data",
|
||||
"install_plugin",
|
||||
"uninstall_plugin",
|
||||
"reload_plugin",
|
||||
),
|
||||
),
|
||||
(
|
||||
"rule_settings",
|
||||
(
|
||||
"query_rule_groups",
|
||||
"query_builtin_filter_rules",
|
||||
"query_custom_filter_rules",
|
||||
"add_custom_filter_rule",
|
||||
"update_custom_filter_rule",
|
||||
"delete_custom_filter_rule",
|
||||
"add_rule_group",
|
||||
"update_rule_group",
|
||||
"delete_rule_group",
|
||||
"query_custom_identifiers",
|
||||
"update_custom_identifiers",
|
||||
"query_system_settings",
|
||||
"update_system_settings",
|
||||
),
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
class ToolSelectionState(AgentState):
|
||||
"""工具筛选中间件私有状态。"""
|
||||
@@ -203,6 +106,73 @@ class ToolSelectorMiddleware(LLMToolSelectorMiddleware):
|
||||
)
|
||||
self.selection_tools = selection_tools or []
|
||||
|
||||
@classmethod
def _render_recent_conversation_context(
    cls,
    messages: list[Any],
) -> tuple[str, int]:
    """Render the tail of the conversation as text for the tool selector.

    Only ``HumanMessage`` and ``AIMessage`` entries with non-empty text are
    rendered, as ``"User: ..."`` / ``"Assistant: ..."`` lines. The last
    ``RECENT_SELECTION_CONTEXT_MESSAGE_LIMIT`` rendered lines are joined
    with blank lines; if the result exceeds
    ``RECENT_SELECTION_CONTEXT_MAX_CHARS`` only its trailing characters are
    kept, prefixed with the truncation marker.

    :param messages: conversation messages in chronological order
    :return: ``(context_text, number_of_rendered_messages_kept)``
    """
    speaker_labels = ((HumanMessage, "User"), (AIMessage, "Assistant"))
    lines = []
    for message in messages:
        speaker = next(
            (label for kind, label in speaker_labels if isinstance(message, kind)),
            None,
        )
        if speaker is None:
            # Other message kinds are not rendered.
            continue
        text = cls._extract_text_content(message.content).strip()
        if text:
            lines.append(f"{speaker}: {text}")

    tail = lines[-RECENT_SELECTION_CONTEXT_MESSAGE_LIMIT:]
    rendered = "\n\n".join(tail)
    if len(rendered) > RECENT_SELECTION_CONTEXT_MAX_CHARS:
        # Keep the end of the context: the most recent text.
        kept = rendered[-RECENT_SELECTION_CONTEXT_MAX_CHARS:]
        rendered = f"{RECENT_SELECTION_CONTEXT_TRUNCATION_PREFIX}{kept}"
    return rendered, len(tail)
|
||||
|
||||
@classmethod
def _build_contextual_user_message(
    cls,
    messages: list[Any],
    last_user_message: HumanMessage,
) -> HumanMessage:
    """Build the user message handed to the tool-selection model.

    When at most one message was rendered as recent context, the original
    ``last_user_message`` object is returned unchanged. Otherwise a new
    ``HumanMessage`` is returned that embeds the rendered context together
    with an instruction to select tools for the latest user request.

    :param messages: conversation messages used to render the context
    :param last_user_message: the latest user message of the request
    :return: the original message or a context-enriched replacement
    """
    context, message_count = cls._render_recent_conversation_context(messages)
    if message_count > 1:
        prompt = (
            "Recent conversation context for tool selection:\n"
            f"{context}\n\n"
            "Select tools for the latest user instruction. Use prior assistant "
            "messages and earlier user requests when the latest user message "
            "depends on previous context."
        )
        return HumanMessage(content=prompt)
    return last_user_message
|
||||
|
||||
def _prepare_selection_request(
    self,
    request: ModelRequest[ContextT],
) -> Any | None:
    """Prepare the selection request and attach recent conversation context.

    Delegates to the parent implementation first. If the parent returns
    ``None`` that is passed through. If the contextual user message is the
    very same object as the parent's ``last_user_message``, the parent's
    request is returned untouched; otherwise a copy with the replaced
    ``last_user_message`` is returned.

    :param request: the model request being processed
    :return: the selection request, or ``None`` when the parent returns none
    """
    base_request = super()._prepare_selection_request(request)
    if base_request is None:
        return None

    original_message = base_request.last_user_message
    enriched_message = self._build_contextual_user_message(
        messages=request.messages,
        last_user_message=original_message,
    )
    # Identity check: the builder returns the same object when no context
    # was added.
    if enriched_message is original_message:
        return base_request
    return replace(base_request, last_user_message=enriched_message)
|
||||
|
||||
@staticmethod
|
||||
def _append_tool_selection_hint(system_prompt: str) -> str:
|
||||
"""追加 MoviePilot 工具组选择提示,避免复杂链路只选中首个工具。"""
|
||||
@@ -216,34 +186,123 @@ class ToolSelectorMiddleware(LLMToolSelectorMiddleware):
|
||||
return min(self.max_tools, len(valid_tool_names))
|
||||
return len(valid_tool_names)
|
||||
|
||||
@staticmethod
def _normalize_tool_tags(tool: BaseTool) -> list[str]:
    """Return the tool's tags as de-duplicated strings, minus excluded tags.

    Reads ``tool.tags`` (a missing or falsy value counts as no tags; a single
    string is treated as one tag). Each tag is reduced to its ``value``
    attribute when it has one, falsy values are dropped, and the remainder
    is converted with ``str``. Names listed in ``TOOL_GROUP_EXCLUDED_TAGS``
    and repeated names are skipped; first-seen order is preserved.

    :param tool: object whose ``tags`` attribute is inspected
    :return: ordered list of tag names usable as capability groups
    """
    raw_tags = getattr(tool, "tags", None) or []
    if isinstance(raw_tags, str):
        raw_tags = [raw_tags]

    names: list[str] = []
    for raw_tag in raw_tags:
        value = getattr(raw_tag, "value", raw_tag)
        if value:
            name = str(value)
            if name not in TOOL_GROUP_EXCLUDED_TAGS and name not in names:
                names.append(name)
    return names
|
||||
|
||||
@classmethod
def _build_tool_groups(
    cls,
    available_tools: list[BaseTool],
    valid_tool_names: list[str],
) -> list[tuple[str, list[str]]]:
    """Group valid tools by their normalized tags.

    Tools are visited in the order of ``available_tools``; a tool is
    considered only when it has a name contained in ``valid_tool_names``.
    Groups appear in the order their tag was first seen, and only groups
    holding more than one tool are returned.

    :param available_tools: candidate tool objects
    :param valid_tool_names: names that may take part in grouping
    :return: list of ``(tag, tool_names)`` pairs
    """
    allowed_names = set(valid_tool_names)
    members_by_tag: dict[str, list[str]] = {}
    for candidate in available_tools:
        name = getattr(candidate, "name", None)
        if name and name in allowed_names:
            for tag in cls._normalize_tool_tags(candidate):
                members = members_by_tag.setdefault(tag, [])
                if name not in members:
                    members.append(name)

    groups = []
    for tag, members in members_by_tag.items():
        # A single-tool "group" gives nothing to complete with.
        if len(members) > 1:
            groups.append((tag, members))
    return groups
|
||||
|
||||
@classmethod
def _get_matched_tool_groups(
    cls,
    selected_names: list[str],
    available_tools: list[BaseTool],
    valid_tool_names: list[str],
) -> list[tuple[str, list[str]]]:
    """Return the tag groups hit by the already selected tools.

    For every selected tool (in selection order) its normalized tags are
    looked up among the groups built from the available tools; each matching
    tag is reported once, in the order it is first encountered.

    :param selected_names: names of tools chosen so far
    :param available_tools: candidate tool objects
    :param valid_tool_names: names that may take part in grouping
    :return: list of ``(tag, tool_names)`` pairs for the matched groups
    """
    groups_by_tag = dict(
        cls._build_tool_groups(
            available_tools=available_tools,
            valid_tool_names=valid_tool_names,
        )
    )
    tool_lookup = {}
    for candidate in available_tools:
        if getattr(candidate, "name", None):
            tool_lookup[candidate.name] = candidate

    matched: list[tuple[str, list[str]]] = []
    reported_tags = set()
    for selected in selected_names:
        tool = tool_lookup.get(selected)
        if tool:
            for tag in cls._normalize_tool_tags(tool):
                if tag not in reported_tags and tag in groups_by_tag:
                    matched.append((tag, groups_by_tag[tag]))
                    reported_tags.add(tag)
    return matched
|
||||
|
||||
def _complete_low_count_selection(
|
||||
self,
|
||||
selected_tool_names: list[str],
|
||||
valid_tool_names: list[str],
|
||||
available_tools: list[BaseTool],
|
||||
) -> list[str]:
|
||||
"""
|
||||
当模型只选出极少工具时,按 MoviePilot 常见工具链补齐相邻工具。
|
||||
当模型只选出极少工具时,按工具标签补齐同组工具。
|
||||
|
||||
这只补齐已经命中的工具组,不会把所有工具组都展开,因此能降低
|
||||
“选了搜索工具但漏了结果/下载工具”这类链式任务失败概率。
|
||||
工具标签是工具自身声明的能力归属。这里只补齐已经命中的标签组,
|
||||
不会把所有工具组都展开。
|
||||
"""
|
||||
limit = self._get_tool_selection_limit(valid_tool_names)
|
||||
target_count = min(MIN_SELECTED_TOOL_COUNT, limit)
|
||||
selected_names = [
|
||||
tool_name
|
||||
for tool_name in selected_tool_names
|
||||
if tool_name in valid_tool_names
|
||||
]
|
||||
if len(selected_names) >= target_count:
|
||||
return selected_names[:limit]
|
||||
|
||||
selected_set = set(selected_names)
|
||||
valid_tool_set = set(valid_tool_names)
|
||||
completed_names = list(selected_names)
|
||||
matched_groups = self._get_matched_tool_groups(
|
||||
selected_names=selected_names,
|
||||
available_tools=available_tools,
|
||||
valid_tool_names=valid_tool_names,
|
||||
)
|
||||
if not matched_groups:
|
||||
return completed_names[:limit]
|
||||
|
||||
for _, group_tool_names in TOOL_CHAIN_GROUPS:
|
||||
if not selected_set.intersection(group_tool_names):
|
||||
continue
|
||||
matched_group_tool_names = {
|
||||
tool_name
|
||||
for _, group_tool_names in matched_groups
|
||||
for tool_name in group_tool_names
|
||||
}
|
||||
target_count = min(
|
||||
max(MIN_SELECTED_TOOL_COUNT, len(matched_group_tool_names)),
|
||||
limit,
|
||||
)
|
||||
if len(selected_names) >= target_count:
|
||||
return selected_names[:limit]
|
||||
|
||||
for _, group_tool_names in matched_groups:
|
||||
for tool_name in group_tool_names:
|
||||
if tool_name in selected_set or tool_name not in valid_tool_set:
|
||||
continue
|
||||
@@ -285,6 +344,7 @@ class ToolSelectorMiddleware(LLMToolSelectorMiddleware):
|
||||
if isinstance(tool_name, str)
|
||||
],
|
||||
valid_tool_names=valid_tool_names,
|
||||
available_tools=available_tools,
|
||||
)
|
||||
return super()._process_selection_response(
|
||||
response,
|
||||
@@ -382,12 +442,35 @@ class ToolSelectorMiddleware(LLMToolSelectorMiddleware):
|
||||
raise ValueError("工具筛选 JSON 顶层必须是对象")
|
||||
return payload
|
||||
|
||||
@staticmethod
|
||||
def _render_tool_list(available_tools: list[Any]) -> str:
|
||||
@classmethod
|
||||
def _render_tool_list(cls, available_tools: list[Any]) -> str:
|
||||
"""把工具名和描述渲染成稳定的文本列表。"""
|
||||
return "\n".join(
|
||||
f"- {tool.name}: {tool.description}" for tool in available_tools
|
||||
lines = []
|
||||
for tool in available_tools:
|
||||
tags = cls._normalize_tool_tags(tool)
|
||||
tag_text = f" [group tags: {', '.join(tags)}]" if tags else ""
|
||||
lines.append(f"- {tool.name}{tag_text}: {tool.description}")
|
||||
return "\n".join(lines)
|
||||
|
||||
@classmethod
def _render_tool_groups(cls, available_tools: list[BaseTool]) -> str:
    """Render the tag-based capability groups of the available tools.

    Every tool with a truthy ``name`` is treated as valid for grouping.
    Returns an empty string when no group exists; otherwise a headed bullet
    list with one ``- tag: tool_a, tool_b`` line per group, terminated by a
    blank line.

    :param available_tools: tools currently offered to the selector
    :return: prompt fragment describing the groups, or ``""``
    """
    names = []
    for tool in available_tools:
        if getattr(tool, "name", None):
            names.append(tool.name)

    groups = cls._build_tool_groups(
        available_tools=available_tools,
        valid_tool_names=names,
    )
    if groups:
        bullet_lines = [
            f"- {tag}: {', '.join(tool_names)}" for tag, tool_names in groups
        ]
        rendered_groups = "\n".join(bullet_lines)
        return f"Capability groups from tool tags:\n{rendered_groups}\n\n"
    return ""
|
||||
|
||||
def _build_deepseek_selection_prompt(self, selection_request: Any) -> str:
|
||||
"""
|
||||
@@ -408,8 +491,10 @@ class ToolSelectorMiddleware(LLMToolSelectorMiddleware):
|
||||
"- The `tools` field must be a JSON array of strings.\n"
|
||||
"- Only use tool names from the allowed list below.\n"
|
||||
"- Order tools by relevance, with the most relevant first.\n"
|
||||
"- Tools sharing the same capability tag are in the same group; include same-group tools together when relevant.\n"
|
||||
f"{limit_instruction}\n"
|
||||
"- Do not add explanations, markdown, or extra keys.\n\n"
|
||||
f"{self._render_tool_groups(selection_request.available_tools)}"
|
||||
"Allowed tools:\n"
|
||||
f"{self._render_tool_list(selection_request.available_tools)}"
|
||||
)
|
||||
|
||||
@@ -52,6 +52,7 @@ class RecognizeCaptchaTool(MoviePilotTool):
|
||||
tags: list[str] = [
|
||||
ToolTag.Read,
|
||||
ToolTag.Web,
|
||||
ToolTag.Site,
|
||||
]
|
||||
description: str = (
|
||||
"Recognize a graphic captcha image and return the captcha text. "
|
||||
@@ -70,6 +71,21 @@ class RecognizeCaptchaTool(MoviePilotTool):
|
||||
return "识别图形验证码: data image"
|
||||
return f"识别图形验证码: {image_url}"
|
||||
|
||||
@staticmethod
|
||||
def _format_image_url_for_log(image_url: str) -> str:
|
||||
"""生成验证码图片地址的安全日志摘要,避免 data URL 图片刷屏。"""
|
||||
clean_url = (image_url or "").strip()
|
||||
if not clean_url:
|
||||
return ""
|
||||
if clean_url.lower().startswith("data:image/"):
|
||||
metadata, separator, data = clean_url.partition(",")
|
||||
if separator:
|
||||
return f"{metadata},<base64:{len(data)} chars>"
|
||||
return f"data:image,<invalid:{len(clean_url)} chars>"
|
||||
if len(clean_url) > 300:
|
||||
return f"{clean_url[:300]}...(已截断,总长度: {len(clean_url)})"
|
||||
return clean_url
|
||||
|
||||
@staticmethod
|
||||
def _recognize_captcha_sync(
|
||||
image_url: str,
|
||||
@@ -117,7 +133,10 @@ class RecognizeCaptchaTool(MoviePilotTool):
|
||||
:param allow_private_network: 是否允许访问本机或私网地址
|
||||
:return: JSON 格式的识别结果
|
||||
"""
|
||||
logger.info(f"执行工具: {self.name}, 参数: image_url={image_url}")
|
||||
logger.info(
|
||||
f"执行工具: {self.name}, "
|
||||
f"参数: image_url={self._format_image_url_for_log(image_url)}"
|
||||
)
|
||||
|
||||
try:
|
||||
captcha_text = await self.run_blocking(
|
||||
|
||||
@@ -31,7 +31,7 @@ class OcrHelper:
|
||||
if image_url:
|
||||
data_url_b64 = self._extract_data_url_base64(image_url)
|
||||
if data_url_b64:
|
||||
image_b64 = data_url_b64
|
||||
image_b64 = self._normalize_image_base64(data_url_b64)
|
||||
else:
|
||||
ret = RequestUtils(ua=ua,
|
||||
cookies=cookie).get_res(image_url)
|
||||
@@ -54,7 +54,14 @@ class OcrHelper:
|
||||
"""规范化外部传入的图片 base64 内容。"""
|
||||
if not image_b64:
|
||||
return ""
|
||||
return OcrHelper._extract_data_url_base64(image_b64) or image_b64.strip()
|
||||
clean_image_b64 = OcrHelper._extract_data_url_base64(image_b64) or image_b64
|
||||
clean_image_b64 = "".join(clean_image_b64.split())
|
||||
if not clean_image_b64:
|
||||
return ""
|
||||
padding_size = len(clean_image_b64) % 4
|
||||
if padding_size:
|
||||
clean_image_b64 = f"{clean_image_b64}{'=' * (4 - padding_size)}"
|
||||
return clean_image_b64
|
||||
|
||||
@staticmethod
|
||||
def _extract_data_url_base64(image_url: Optional[str]) -> str:
|
||||
|
||||
@@ -82,6 +82,35 @@ def test_ocr_helper_extracts_data_url_base64_without_downloading_image():
|
||||
}
|
||||
|
||||
|
||||
def test_ocr_helper_normalizes_data_url_base64_padding():
    """A data:image URL with missing padding is padded before being posted to OCR."""
    unpadded_data_url = "data:image/jpeg;base64,YWJjZA"

    with patch("app.helper.ocr.RequestUtils") as request_utils:
        http_client = request_utils.return_value
        http_client.post_res.return_value = _FakeResponse(
            payload={"result": "z9k2"}
        )

        captcha_text = OcrHelper().get_captcha_text(image_url=unpadded_data_url)

    assert captcha_text == "z9k2"
    # The image is embedded in the URL, so it must not be fetched.
    http_client.get_res.assert_not_called()
    posted_json = http_client.post_res.call_args.kwargs["json"]
    assert posted_json == {"base64_img": "YWJjZA=="}
|
||||
|
||||
|
||||
def test_recognize_captcha_tool_formats_data_url_for_log():
    """The captcha tool's log summary must hide the full data:image payload."""
    encoded = base64.b64encode(b"captcha-image").decode()

    summary = RecognizeCaptchaTool._format_image_url_for_log(
        f"data:image/jpeg;base64,{encoded}"
    )

    expected = f"data:image/jpeg;base64,<base64:{len(encoded)} chars>"
    assert summary == expected
    assert encoded not in summary
|
||||
|
||||
|
||||
def test_recognize_captcha_tool_returns_captcha_text_from_ocr_helper():
|
||||
"""验证码工具应返回结构化识别结果,便于 Agent 继续填写表单。"""
|
||||
tool = RecognizeCaptchaTool(session_id="captcha-session", user_id="10001")
|
||||
|
||||
@@ -3,9 +3,10 @@ import unittest
|
||||
from types import SimpleNamespace
|
||||
from unittest.mock import patch
|
||||
|
||||
from langchain_core.messages import HumanMessage
|
||||
from langchain_core.messages import AIMessage, HumanMessage
|
||||
|
||||
from app.agent.middleware import tool_selection as tool_selector_module
|
||||
from app.agent.tools.tags import ToolTag
|
||||
|
||||
|
||||
class _FakeBoundModel:
|
||||
@@ -60,6 +61,11 @@ class _FakeRequest:
|
||||
return _FakeRequest(**data)
|
||||
|
||||
|
||||
def _tool(name, description, tags=None):
|
||||
"""构造测试用工具对象。"""
|
||||
return SimpleNamespace(name=name, description=description, tags=tags or [])
|
||||
|
||||
|
||||
class ToolSelectorMiddlewareTest(unittest.TestCase):
|
||||
def test_awrap_model_call_uses_json_mode_for_deepseek(self):
|
||||
tools = [
|
||||
@@ -223,14 +229,115 @@ class ToolSelectorMiddlewareTest(unittest.TestCase):
|
||||
|
||||
self.assertEqual(normalized, {"tools": ["search"]})
|
||||
|
||||
def test_process_selection_response_completes_low_count_tool_chain(self):
|
||||
"""筛选结果过少时应按已命中的工具链补齐相邻工具。"""
|
||||
def test_deepseek_selection_uses_recent_conversation_context(self):
|
||||
"""多轮追问时工具筛选应看到上一轮用户需求和助手回复。"""
|
||||
tools = [
|
||||
SimpleNamespace(name="search_media", description="Search media"),
|
||||
SimpleNamespace(name="search_torrents", description="Search torrents"),
|
||||
SimpleNamespace(name="get_search_results", description="Get results"),
|
||||
SimpleNamespace(name="add_download_tasks", description="Add downloads"),
|
||||
SimpleNamespace(name="query_download_tasks", description="Query downloads"),
|
||||
_tool(
|
||||
"query_plugin_config",
|
||||
"Query plugin config",
|
||||
[ToolTag.Read, ToolTag.Plugin, ToolTag.Settings],
|
||||
),
|
||||
_tool(
|
||||
"update_plugin_config",
|
||||
"Update plugin config",
|
||||
[ToolTag.Write, ToolTag.Plugin, ToolTag.Settings],
|
||||
),
|
||||
_tool(
|
||||
"reload_plugin",
|
||||
"Reload plugin",
|
||||
[ToolTag.Write, ToolTag.Plugin],
|
||||
),
|
||||
]
|
||||
model = _FakeModel(content='{"tools": ["query_plugin_config"]}')
|
||||
middleware = tool_selector_module.ToolSelectorMiddleware(
|
||||
max_tools=3,
|
||||
selection_tools=tools,
|
||||
)
|
||||
middleware.model = model
|
||||
request = _FakeRequest(
|
||||
tools=tools,
|
||||
messages=[
|
||||
HumanMessage(content="帮我检查插件 DemoPlugin 的配置"),
|
||||
AIMessage(content="我建议先查询插件配置,然后根据结果决定是否重载插件。"),
|
||||
HumanMessage(content="按你说的来"),
|
||||
],
|
||||
model=model,
|
||||
)
|
||||
|
||||
state_update = asyncio.run(
|
||||
middleware.abefore_agent(request.state, runtime=None, config=None)
|
||||
)
|
||||
|
||||
user_message = model.bound_model.messages[1]
|
||||
self.assertEqual(
|
||||
state_update,
|
||||
{"selected_tool_names": ["query_plugin_config", "update_plugin_config", "reload_plugin"]},
|
||||
)
|
||||
self.assertIsInstance(user_message, HumanMessage)
|
||||
self.assertIn(
|
||||
"Recent conversation context for tool selection",
|
||||
user_message.content,
|
||||
)
|
||||
self.assertIn("帮我检查插件 DemoPlugin 的配置", user_message.content)
|
||||
self.assertIn("我建议先查询插件配置", user_message.content)
|
||||
self.assertIn("按你说的来", user_message.content)
|
||||
|
||||
def test_single_turn_selection_keeps_original_user_message(self):
    """A single-turn conversation must not be wrapped in a context prompt."""
    only_message = HumanMessage(content="帮我查一下最近的更新")
    tools = [
        _tool("search", "Search for information", [ToolTag.Read, ToolTag.Web]),
        _tool("calendar", "Manage events", [ToolTag.Write]),
    ]
    model = _FakeModel(content='{"tools": ["search"]}')
    middleware = tool_selector_module.ToolSelectorMiddleware(
        max_tools=2,
        selection_tools=tools,
    )
    middleware.model = model
    request = _FakeRequest(tools=tools, messages=[only_message], model=model)

    asyncio.run(middleware.abefore_agent(request.state, runtime=None, config=None))

    # The second message sent to the bound model is the user message.
    sent_message = model.bound_model.messages[1]
    self.assertIs(sent_message, only_message)
    self.assertNotIn(
        "Recent conversation context for tool selection",
        sent_message.content,
    )
|
||||
|
||||
def test_process_selection_response_completes_low_count_tool_group_by_tags(self):
|
||||
"""筛选结果过少时应按已命中的工具标签组补齐同组工具。"""
|
||||
tools = [
|
||||
_tool(
|
||||
"search_media",
|
||||
"Search media",
|
||||
[ToolTag.Read, ToolTag.Media],
|
||||
),
|
||||
_tool(
|
||||
"search_torrents",
|
||||
"Search torrents",
|
||||
[ToolTag.Read, ToolTag.Resource, ToolTag.Site, ToolTag.Media],
|
||||
),
|
||||
_tool(
|
||||
"get_search_results",
|
||||
"Get results",
|
||||
[ToolTag.Read, ToolTag.Resource],
|
||||
),
|
||||
_tool(
|
||||
"add_download_tasks",
|
||||
"Add downloads",
|
||||
[ToolTag.Write, ToolTag.Download, ToolTag.Resource],
|
||||
),
|
||||
_tool(
|
||||
"query_download_tasks",
|
||||
"Query downloads",
|
||||
[ToolTag.Read, ToolTag.Download],
|
||||
),
|
||||
]
|
||||
middleware = tool_selector_module.ToolSelectorMiddleware(
|
||||
max_tools=4,
|
||||
@@ -243,20 +350,21 @@ class ToolSelectorMiddlewareTest(unittest.TestCase):
|
||||
)
|
||||
|
||||
result = middleware._process_selection_response(
|
||||
{"tools": ["search_media"]},
|
||||
{"tools": ["search_torrents"]},
|
||||
available_tools=tools,
|
||||
valid_tool_names=[tool.name for tool in tools],
|
||||
request=request,
|
||||
)
|
||||
|
||||
self.assertEqual(len(result.tools), 4)
|
||||
self.assertEqual(
|
||||
[tool.name for tool in result.tools],
|
||||
[
|
||||
{tool.name for tool in result.tools},
|
||||
{
|
||||
"search_media",
|
||||
"search_torrents",
|
||||
"get_search_results",
|
||||
"add_download_tasks",
|
||||
],
|
||||
},
|
||||
)
|
||||
|
||||
def test_process_selection_response_keeps_high_count_selection(self):
|
||||
@@ -302,12 +410,28 @@ class ToolSelectorMiddlewareTest(unittest.TestCase):
|
||||
)
|
||||
|
||||
def test_process_selection_response_respects_max_tools_when_completing(self):
|
||||
"""工具链补齐不应突破 max_tools 上限。"""
|
||||
"""标签组补齐不应突破 max_tools 上限。"""
|
||||
tools = [
|
||||
SimpleNamespace(name="list_directory", description="List directory"),
|
||||
SimpleNamespace(name="query_directory_settings", description="Query settings"),
|
||||
SimpleNamespace(name="recognize_media", description="Recognize media"),
|
||||
SimpleNamespace(name="transfer_file", description="Transfer file"),
|
||||
_tool(
|
||||
"list_directory",
|
||||
"List directory",
|
||||
[ToolTag.Read, ToolTag.Directory, ToolTag.File],
|
||||
),
|
||||
_tool(
|
||||
"query_directory_settings",
|
||||
"Query settings",
|
||||
[ToolTag.Read, ToolTag.Directory, ToolTag.Settings],
|
||||
),
|
||||
_tool(
|
||||
"recognize_media",
|
||||
"Recognize media",
|
||||
[ToolTag.Read, ToolTag.Media],
|
||||
),
|
||||
_tool(
|
||||
"transfer_file",
|
||||
"Transfer file",
|
||||
[ToolTag.Write, ToolTag.Transfer, ToolTag.Library, ToolTag.File],
|
||||
),
|
||||
]
|
||||
middleware = tool_selector_module.ToolSelectorMiddleware(
|
||||
max_tools=2,
|
||||
@@ -331,3 +455,29 @@ class ToolSelectorMiddlewareTest(unittest.TestCase):
|
||||
{tool.name for tool in result.tools},
|
||||
{"transfer_file", "list_directory"},
|
||||
)
|
||||
|
||||
def test_process_selection_response_ignores_generic_tags_when_completing(self):
    """Generic permission tags must not be used as tool groups."""
    tool_specs = (
        ("read_one", "Read one", [ToolTag.Read]),
        ("read_two", "Read two", [ToolTag.Read]),
        ("write_one", "Write one", [ToolTag.Write, ToolTag.Admin]),
    )
    tools = [_tool(*spec) for spec in tool_specs]
    middleware = tool_selector_module.ToolSelectorMiddleware(
        max_tools=4,
        selection_tools=tools,
    )
    request = _FakeRequest(
        tools=tools,
        messages=[HumanMessage(content="查一下信息")],
        model=_FakeModel(),
    )

    result = middleware._process_selection_response(
        {"tools": ["read_one"]},
        available_tools=tools,
        valid_tool_names=[tool.name for tool in tools],
        request=request,
    )

    # read_two shares only the Read tag with read_one, so it is not added.
    selected_names = [tool.name for tool in result.tools]
    self.assertEqual(selected_names, ["read_one"])
|
||||
|
||||
Reference in New Issue
Block a user