mirror of
https://github.com/jxxghp/MoviePilot.git
synced 2026-06-30 04:01:54 +08:00
feat(activity_log): simplify user text matching and enhance logging for captcha tool
This commit is contained in:
@@ -82,6 +82,35 @@ def test_ocr_helper_extracts_data_url_base64_without_downloading_image():
|
||||
}
|
||||
|
||||
|
||||
def test_ocr_helper_normalizes_data_url_base64_padding():
|
||||
"""data:image 地址缺少 padding 时应补齐后提交给 OCR 服务。"""
|
||||
image_url = "data:image/jpeg;base64,YWJjZA"
|
||||
|
||||
with patch("app.helper.ocr.RequestUtils") as request_utils:
|
||||
request_utils.return_value.post_res.return_value = _FakeResponse(
|
||||
payload={"result": "z9k2"}
|
||||
)
|
||||
|
||||
result = OcrHelper().get_captcha_text(image_url=image_url)
|
||||
|
||||
assert result == "z9k2"
|
||||
request_utils.return_value.get_res.assert_not_called()
|
||||
assert request_utils.return_value.post_res.call_args.kwargs["json"] == {
|
||||
"base64_img": "YWJjZA=="
|
||||
}
|
||||
|
||||
|
||||
def test_recognize_captcha_tool_formats_data_url_for_log():
|
||||
"""验证码工具日志应隐藏 data:image 的完整图片内容。"""
|
||||
image_b64 = base64.b64encode(b"captcha-image").decode()
|
||||
image_url = f"data:image/jpeg;base64,{image_b64}"
|
||||
|
||||
result = RecognizeCaptchaTool._format_image_url_for_log(image_url)
|
||||
|
||||
assert result == f"data:image/jpeg;base64,<base64:{len(image_b64)} chars>"
|
||||
assert image_b64 not in result
|
||||
|
||||
|
||||
def test_recognize_captcha_tool_returns_captcha_text_from_ocr_helper():
|
||||
"""验证码工具应返回结构化识别结果,便于 Agent 继续填写表单。"""
|
||||
tool = RecognizeCaptchaTool(session_id="captcha-session", user_id="10001")
|
||||
|
||||
@@ -3,9 +3,10 @@ import unittest
|
||||
from types import SimpleNamespace
|
||||
from unittest.mock import patch
|
||||
|
||||
from langchain_core.messages import HumanMessage
|
||||
from langchain_core.messages import AIMessage, HumanMessage
|
||||
|
||||
from app.agent.middleware import tool_selection as tool_selector_module
|
||||
from app.agent.tools.tags import ToolTag
|
||||
|
||||
|
||||
class _FakeBoundModel:
|
||||
@@ -60,6 +61,11 @@ class _FakeRequest:
|
||||
return _FakeRequest(**data)
|
||||
|
||||
|
||||
def _tool(name, description, tags=None):
|
||||
"""构造测试用工具对象。"""
|
||||
return SimpleNamespace(name=name, description=description, tags=tags or [])
|
||||
|
||||
|
||||
class ToolSelectorMiddlewareTest(unittest.TestCase):
|
||||
def test_awrap_model_call_uses_json_mode_for_deepseek(self):
|
||||
tools = [
|
||||
@@ -223,14 +229,115 @@ class ToolSelectorMiddlewareTest(unittest.TestCase):
|
||||
|
||||
self.assertEqual(normalized, {"tools": ["search"]})
|
||||
|
||||
def test_process_selection_response_completes_low_count_tool_chain(self):
|
||||
"""筛选结果过少时应按已命中的工具链补齐相邻工具。"""
|
||||
def test_deepseek_selection_uses_recent_conversation_context(self):
|
||||
"""多轮追问时工具筛选应看到上一轮用户需求和助手回复。"""
|
||||
tools = [
|
||||
SimpleNamespace(name="search_media", description="Search media"),
|
||||
SimpleNamespace(name="search_torrents", description="Search torrents"),
|
||||
SimpleNamespace(name="get_search_results", description="Get results"),
|
||||
SimpleNamespace(name="add_download_tasks", description="Add downloads"),
|
||||
SimpleNamespace(name="query_download_tasks", description="Query downloads"),
|
||||
_tool(
|
||||
"query_plugin_config",
|
||||
"Query plugin config",
|
||||
[ToolTag.Read, ToolTag.Plugin, ToolTag.Settings],
|
||||
),
|
||||
_tool(
|
||||
"update_plugin_config",
|
||||
"Update plugin config",
|
||||
[ToolTag.Write, ToolTag.Plugin, ToolTag.Settings],
|
||||
),
|
||||
_tool(
|
||||
"reload_plugin",
|
||||
"Reload plugin",
|
||||
[ToolTag.Write, ToolTag.Plugin],
|
||||
),
|
||||
]
|
||||
model = _FakeModel(content='{"tools": ["query_plugin_config"]}')
|
||||
middleware = tool_selector_module.ToolSelectorMiddleware(
|
||||
max_tools=3,
|
||||
selection_tools=tools,
|
||||
)
|
||||
middleware.model = model
|
||||
request = _FakeRequest(
|
||||
tools=tools,
|
||||
messages=[
|
||||
HumanMessage(content="帮我检查插件 DemoPlugin 的配置"),
|
||||
AIMessage(content="我建议先查询插件配置,然后根据结果决定是否重载插件。"),
|
||||
HumanMessage(content="按你说的来"),
|
||||
],
|
||||
model=model,
|
||||
)
|
||||
|
||||
state_update = asyncio.run(
|
||||
middleware.abefore_agent(request.state, runtime=None, config=None)
|
||||
)
|
||||
|
||||
user_message = model.bound_model.messages[1]
|
||||
self.assertEqual(
|
||||
state_update,
|
||||
{"selected_tool_names": ["query_plugin_config", "update_plugin_config", "reload_plugin"]},
|
||||
)
|
||||
self.assertIsInstance(user_message, HumanMessage)
|
||||
self.assertIn(
|
||||
"Recent conversation context for tool selection",
|
||||
user_message.content,
|
||||
)
|
||||
self.assertIn("帮我检查插件 DemoPlugin 的配置", user_message.content)
|
||||
self.assertIn("我建议先查询插件配置", user_message.content)
|
||||
self.assertIn("按你说的来", user_message.content)
|
||||
|
||||
def test_single_turn_selection_keeps_original_user_message(self):
|
||||
"""单轮对话不应额外包裹上下文提示。"""
|
||||
tools = [
|
||||
_tool("search", "Search for information", [ToolTag.Read, ToolTag.Web]),
|
||||
_tool("calendar", "Manage events", [ToolTag.Write]),
|
||||
]
|
||||
model = _FakeModel(content='{"tools": ["search"]}')
|
||||
middleware = tool_selector_module.ToolSelectorMiddleware(
|
||||
max_tools=2,
|
||||
selection_tools=tools,
|
||||
)
|
||||
middleware.model = model
|
||||
original_message = HumanMessage(content="帮我查一下最近的更新")
|
||||
request = _FakeRequest(
|
||||
tools=tools,
|
||||
messages=[original_message],
|
||||
model=model,
|
||||
)
|
||||
|
||||
asyncio.run(middleware.abefore_agent(request.state, runtime=None, config=None))
|
||||
|
||||
user_message = model.bound_model.messages[1]
|
||||
self.assertIs(user_message, original_message)
|
||||
self.assertNotIn(
|
||||
"Recent conversation context for tool selection",
|
||||
user_message.content,
|
||||
)
|
||||
|
||||
def test_process_selection_response_completes_low_count_tool_group_by_tags(self):
|
||||
"""筛选结果过少时应按已命中的工具标签组补齐同组工具。"""
|
||||
tools = [
|
||||
_tool(
|
||||
"search_media",
|
||||
"Search media",
|
||||
[ToolTag.Read, ToolTag.Media],
|
||||
),
|
||||
_tool(
|
||||
"search_torrents",
|
||||
"Search torrents",
|
||||
[ToolTag.Read, ToolTag.Resource, ToolTag.Site, ToolTag.Media],
|
||||
),
|
||||
_tool(
|
||||
"get_search_results",
|
||||
"Get results",
|
||||
[ToolTag.Read, ToolTag.Resource],
|
||||
),
|
||||
_tool(
|
||||
"add_download_tasks",
|
||||
"Add downloads",
|
||||
[ToolTag.Write, ToolTag.Download, ToolTag.Resource],
|
||||
),
|
||||
_tool(
|
||||
"query_download_tasks",
|
||||
"Query downloads",
|
||||
[ToolTag.Read, ToolTag.Download],
|
||||
),
|
||||
]
|
||||
middleware = tool_selector_module.ToolSelectorMiddleware(
|
||||
max_tools=4,
|
||||
@@ -243,20 +350,21 @@ class ToolSelectorMiddlewareTest(unittest.TestCase):
|
||||
)
|
||||
|
||||
result = middleware._process_selection_response(
|
||||
{"tools": ["search_media"]},
|
||||
{"tools": ["search_torrents"]},
|
||||
available_tools=tools,
|
||||
valid_tool_names=[tool.name for tool in tools],
|
||||
request=request,
|
||||
)
|
||||
|
||||
self.assertEqual(len(result.tools), 4)
|
||||
self.assertEqual(
|
||||
[tool.name for tool in result.tools],
|
||||
[
|
||||
{tool.name for tool in result.tools},
|
||||
{
|
||||
"search_media",
|
||||
"search_torrents",
|
||||
"get_search_results",
|
||||
"add_download_tasks",
|
||||
],
|
||||
},
|
||||
)
|
||||
|
||||
def test_process_selection_response_keeps_high_count_selection(self):
|
||||
@@ -302,12 +410,28 @@ class ToolSelectorMiddlewareTest(unittest.TestCase):
|
||||
)
|
||||
|
||||
def test_process_selection_response_respects_max_tools_when_completing(self):
|
||||
"""工具链补齐不应突破 max_tools 上限。"""
|
||||
"""标签组补齐不应突破 max_tools 上限。"""
|
||||
tools = [
|
||||
SimpleNamespace(name="list_directory", description="List directory"),
|
||||
SimpleNamespace(name="query_directory_settings", description="Query settings"),
|
||||
SimpleNamespace(name="recognize_media", description="Recognize media"),
|
||||
SimpleNamespace(name="transfer_file", description="Transfer file"),
|
||||
_tool(
|
||||
"list_directory",
|
||||
"List directory",
|
||||
[ToolTag.Read, ToolTag.Directory, ToolTag.File],
|
||||
),
|
||||
_tool(
|
||||
"query_directory_settings",
|
||||
"Query settings",
|
||||
[ToolTag.Read, ToolTag.Directory, ToolTag.Settings],
|
||||
),
|
||||
_tool(
|
||||
"recognize_media",
|
||||
"Recognize media",
|
||||
[ToolTag.Read, ToolTag.Media],
|
||||
),
|
||||
_tool(
|
||||
"transfer_file",
|
||||
"Transfer file",
|
||||
[ToolTag.Write, ToolTag.Transfer, ToolTag.Library, ToolTag.File],
|
||||
),
|
||||
]
|
||||
middleware = tool_selector_module.ToolSelectorMiddleware(
|
||||
max_tools=2,
|
||||
@@ -331,3 +455,29 @@ class ToolSelectorMiddlewareTest(unittest.TestCase):
|
||||
{tool.name for tool in result.tools},
|
||||
{"transfer_file", "list_directory"},
|
||||
)
|
||||
|
||||
def test_process_selection_response_ignores_generic_tags_when_completing(self):
|
||||
"""通用权限标签不应被当作工具组使用。"""
|
||||
tools = [
|
||||
_tool("read_one", "Read one", [ToolTag.Read]),
|
||||
_tool("read_two", "Read two", [ToolTag.Read]),
|
||||
_tool("write_one", "Write one", [ToolTag.Write, ToolTag.Admin]),
|
||||
]
|
||||
middleware = tool_selector_module.ToolSelectorMiddleware(
|
||||
max_tools=4,
|
||||
selection_tools=tools,
|
||||
)
|
||||
request = _FakeRequest(
|
||||
tools=tools,
|
||||
messages=[HumanMessage(content="查一下信息")],
|
||||
model=_FakeModel(),
|
||||
)
|
||||
|
||||
result = middleware._process_selection_response(
|
||||
{"tools": ["read_one"]},
|
||||
available_tools=tools,
|
||||
valid_tool_names=[tool.name for tool in tools],
|
||||
request=request,
|
||||
)
|
||||
|
||||
self.assertEqual([tool.name for tool in result.tools], ["read_one"])
|
||||
|
||||
Reference in New Issue
Block a user