# Mirror of https://github.com/jxxghp/MoviePilot.git
# Synced 2026-06-01 13:40:54 +08:00 · 482 lines · 18 KiB · Python
import asyncio
|
|
import importlib.machinery
|
|
import sys
|
|
import unittest
|
|
from types import SimpleNamespace
|
|
from types import ModuleType
|
|
from unittest.mock import AsyncMock, patch
|
|
|
|
|
|
def _stub_module(name: str, **attrs):
    """Return the module registered in ``sys.modules`` under *name*, creating it if absent.

    A fresh, empty ``ModuleType`` is created and registered when *name* is not
    present in ``sys.modules`` (or is mapped to ``None``). Every keyword argument
    is then set as an attribute on the module, overwriting any existing attribute
    of the same name.

    :param name: fully qualified module name to look up / register
    :param attrs: attributes to set on the module
    :return: the (possibly newly created) module object
    """
    target = sys.modules.get(name)
    if target is None:
        target = ModuleType(name)
        sys.modules[name] = target

    for attr_name in attrs:
        setattr(target, attr_name, attrs[attr_name])
    return target
|
|
|
|
|
|
# Register placeholder modules for third-party packages before the app modules
# below are imported — presumably so those imports succeed when the packages
# are not installed. If a package is already imported, the listed attributes
# are overwritten on the existing module instead.
# NOTE(review): psutil gets an explicit ``__spec__`` (a bare ModuleType has
# ``__spec__ = None``), presumably because something inspects it via importlib
# — confirm.
for _stub_name, _stub_attrs in (
    ("qbittorrentapi", {"TorrentFilesList": list}),
    ("transmission_rpc", {"File": object}),
    ("psutil", {"__spec__": importlib.machinery.ModuleSpec("psutil", loader=None)}),
):
    _stub_module(_stub_name, **_stub_attrs)
# Keep the module namespace free of the loop temporaries.
del _stub_name, _stub_attrs
|
|
|
|
from app.agent.tools.factory import MoviePilotToolFactory
|
|
from app.agent import ReplyMode
|
|
from app.chain.search import SearchChain
|
|
from app.core.config import settings
|
|
from app.modules.indexer import IndexerModule
|
|
from app.schemas.types import MediaType
|
|
|
|
|
|
def _make_result(title: str, size: int, seeders: int):
    """Build a minimal stand-in for a search result.

    The returned object exposes only ``torrent_info``, which in turn carries
    ``title``, ``size`` and ``seeders``.
    """
    info = SimpleNamespace(title=title, size=size, seeders=seeders)
    return SimpleNamespace(torrent_info=info)
|
|
|
|
|
|
class SearchChainAIRecommendTest(unittest.IsolatedAsyncioTestCase):
    """Tests for ``SearchChain`` AI recommendations, search-result caching and site paging.

    The AI recommendation state lives in ``SearchChain`` class attributes, so it
    is reset before and after every test to keep tests independent.
    """

    @staticmethod
    def _reset_recommend_state():
        """Reset every class-level AI recommendation attribute on ``SearchChain``.

        Shared by ``setUp`` and ``asyncTearDown`` so the two resets cannot drift apart.
        """
        SearchChain._ai_recommend_running = False
        SearchChain._ai_recommend_task = None
        SearchChain._current_recommend_request_hash = None
        SearchChain._ai_recommend_result = None
        SearchChain._ai_recommend_error = None

    def setUp(self):
        self._reset_recommend_state()

    async def asyncTearDown(self):
        # A test may leave a recommendation task pending; cancel it and wait for
        # the cancellation to finish before clearing the shared state.
        task = SearchChain._ai_recommend_task
        if task and not task.done():
            task.cancel()
            try:
                await task
            except asyncio.CancelledError:
                pass

        self._reset_recommend_state()

    @staticmethod
    def _make_chain() -> SearchChain:
        """Create a ``SearchChain`` without calling its ``__init__``.

        The cache hooks are replaced with no-ops (individual tests override them
        to record calls), and ``get_search_page_size`` is taken from ``IndexerModule``.
        """
        chain = object.__new__(SearchChain)
        chain.load_cache = lambda _filename: None
        chain.save_cache = lambda _cache, _filename: None
        chain.remove_cache = lambda _filename: None
        chain.get_search_page_size = IndexerModule.get_search_page_size
        return chain

    @staticmethod
    def _page_results(page: int, count: int) -> list:
        """Return ``count`` fake torrent results whose titles are tagged with ``page``."""
        return [
            SimpleNamespace(title=f"Result Page {page}-{index}", description="")
            for index in range(count)
        ]

    @staticmethod
    def _configure_single_site_search(
        system_config_oper,
        sites_helper,
        progress_helper,
        indexer: dict,
        use_async: bool = False,
    ):
        """Wire the patched search collaborators for a search over a single indexer.

        :param system_config_oper: mock patched over ``app.chain.search.SystemConfigOper``;
            its instance's ``get`` is made to return ``[1]``
        :param sites_helper: mock patched over ``app.chain.search.SitesHelper``
        :param progress_helper: mock patched over ``app.chain.search.ProgressHelper``;
            its instance becomes a no-op progress reporter
        :param indexer: the only indexer dict the sites helper returns
        :param use_async: stub ``async_get_indexers`` (async search paths) instead
            of ``get_indexers``
        """
        system_config_oper.return_value.get.return_value = [1]
        if use_async:
            sites_helper.return_value.async_get_indexers = AsyncMock(
                return_value=[indexer]
            )
        else:
            sites_helper.return_value.get_indexers.return_value = [indexer]
        progress_helper.return_value = SimpleNamespace(
            start=lambda: None,
            update=lambda **_kwargs: None,
            end=lambda: None,
        )

    async def test_start_recommend_task_restores_original_indices(self):
        chain = self._make_chain()
        saved = []
        chain.save_cache = lambda cache, filename: saved.append((filename, cache))
        results = [_make_result(f"item-{index}", 1024 * (index + 1), index) for index in range(7)]

        # The LLM answers with positions into ``filtered_indices``; the expected
        # outcome drops the duplicate, the non-integer and the out-of-range entry
        # and maps the rest back to original indices ([1, 0] -> [4, 2]).
        with (
            patch.object(settings, "AI_AGENT_ENABLE", True, create=True),
            patch.object(settings, "AI_RECOMMEND_ENABLED", True, create=True),
            patch.object(settings, "AI_RECOMMEND_MAX_ITEMS", 50, create=True),
            patch.object(
                settings,
                "AI_RECOMMEND_USER_PREFERENCE",
                "Prefer high seeders",
                create=True,
            ),
            patch.object(
                SearchChain,
                "_invoke_recommend_llm",
                new=AsyncMock(return_value='[1, 0, 1, "bad", 9]'),
            ),
        ):
            chain.start_recommend_task(
                filtered_indices=[2, 4, 6],
                search_results_count=len(results),
                results=results,
            )
            self.assertIsNotNone(SearchChain._ai_recommend_task)
            await SearchChain._ai_recommend_task

        self.assertEqual([4, 2], SearchChain._ai_recommend_result)
        self.assertEqual(
            [("__ai_recommend_indices__", [4, 2])],
            saved,
        )
        self.assertFalse(SearchChain._ai_recommend_running)
        self.assertIsNone(SearchChain._ai_recommend_task)

    async def test_invoke_recommend_llm_disables_output_message_persistence(self):
        chain = self._make_chain()
        from app.agent import agent_manager
        from app.agent.prompt import prompt_manager

        captured = {}

        async def _fake_run_background_prompt(**kwargs):
            # Record the call arguments and feed a canned answer back through
            # the output callback.
            captured.update(kwargs)
            kwargs["output_callback"]("[0, 2]")

        with (
            patch.object(
                prompt_manager,
                "render_system_task_message",
                return_value="PROMPT",
            ),
            patch.object(
                agent_manager,
                "run_background_prompt",
                new=AsyncMock(side_effect=_fake_run_background_prompt),
            ),
        ):
            result = await chain._invoke_recommend_llm("Candidates")

        self.assertEqual("[0, 2]", result)
        self.assertEqual(ReplyMode.CAPTURE_ONLY, captured["reply_mode"])
        self.assertFalse(captured["persist_output_message"])
        self.assertFalse(captured["allow_message_tools"])

    def test_search_by_title_clears_previous_recommend_state_when_caching(self):
        chain = self._make_chain()
        removed = []
        cached = []
        chain.remove_cache = lambda filename: removed.append(filename)
        chain.save_cache = lambda cache, filename: cached.append((filename, cache))
        chain._SearchChain__search_all_sites = lambda keyword, sites, page: [
            SimpleNamespace(title="Test Title", description="Test Desc")
        ]

        # Leftovers from a previous recommendation that a new search must clear.
        SearchChain._current_recommend_request_hash = "stale-hash"
        SearchChain._ai_recommend_result = [3, 1]
        SearchChain._ai_recommend_error = "stale-error"

        results = chain.search_by_title("keyword", cache_local=True)

        self.assertEqual(1, len(results))
        self.assertEqual(["__ai_recommend_indices__"], removed)
        self.assertTrue(any(filename == "__search_result__" for filename, _ in cached))
        self.assertTrue(any(filename == "__search_params__" for filename, _ in cached))
        self.assertIsNone(SearchChain._current_recommend_request_hash)
        self.assertIsNone(SearchChain._ai_recommend_result)
        self.assertIsNone(SearchChain._ai_recommend_error)

    def test_build_search_pages_uses_search_resource_pages_setting(self):
        with patch.object(settings, "SEARCH_RESOURCE_PAGES", 3, create=True):
            self.assertEqual([2, 3, 4], SearchChain._build_search_pages(page=2))

    def test_build_search_pages_falls_back_to_one_page_for_invalid_setting(self):
        with patch.object(settings, "SEARCH_RESOURCE_PAGES", 0, create=True):
            self.assertEqual([0], SearchChain._build_search_pages(page=0))
        with patch.object(settings, "SEARCH_RESOURCE_PAGES", "bad", create=True):
            self.assertEqual([0], SearchChain._build_search_pages(page="bad"))

    def test_search_all_sites_stops_after_short_page(self):
        """
        Verify that ordinary sites use the default page size of 100 to decide whether to keep paging.
        """
        chain = self._make_chain()
        requested_pages = []

        def search_torrents(**kwargs):
            """
            Pages 0 and 1 are full and page 2 has fewer than 100 items, so page 3 must not be requested.
            """
            page = kwargs["page"]
            requested_pages.append(page)
            count = 100 if page in (0, 1) else 1
            return self._page_results(page, count)

        chain.search_torrents = search_torrents

        with (
            patch.object(settings, "SEARCH_RESOURCE_PAGES", 4, create=True),
            patch("app.chain.search.SystemConfigOper") as system_config_oper,
            patch("app.chain.search.SitesHelper") as sites_helper,
            patch("app.chain.search.ProgressHelper") as progress_helper,
        ):
            self._configure_single_site_search(
                system_config_oper,
                sites_helper,
                progress_helper,
                {"id": 1, "name": "测试站点"},
            )

            results = chain._SearchChain__search_all_sites(
                keyword="keyword",
                sites=None,
                page=0,
            )

        # NOTE(review): compared sorted, presumably because the sync search may
        # request pages out of order — confirm.
        self.assertEqual([0, 1, 2], sorted(requested_pages))
        self.assertEqual(201, len(results))

    def test_search_all_sites_uses_configured_result_num_for_common_site(self):
        """
        Verify that a configured site uses its result_num to decide whether to keep paging.
        """
        chain = self._make_chain()
        requested_pages = []

        def search_torrents(**kwargs):
            """
            The configured site returns 50 items per page; paging stops once page 1 returns fewer than 50.
            """
            page = kwargs["page"]
            requested_pages.append(page)
            count = 50 if page == 0 else 49
            return self._page_results(page, count)

        chain.search_torrents = search_torrents

        with (
            patch.object(settings, "SEARCH_RESOURCE_PAGES", 3, create=True),
            patch("app.chain.search.SystemConfigOper") as system_config_oper,
            patch("app.chain.search.SitesHelper") as sites_helper,
            patch("app.chain.search.ProgressHelper") as progress_helper,
        ):
            self._configure_single_site_search(
                system_config_oper,
                sites_helper,
                progress_helper,
                {"id": 1, "name": "测试站点", "result_num": 50},
            )

            results = chain._SearchChain__search_all_sites(
                keyword="keyword",
                sites=None,
                page=0,
            )

        self.assertEqual([0, 1], requested_pages)
        self.assertEqual(99, len(results))

    def test_search_all_sites_uses_parser_page_size_for_yema(self):
        """
        Verify that dedicated parsers use their own page capacity, so Yema's 40-item pages are not mistaken for short pages.
        """
        chain = self._make_chain()
        requested_pages = []

        def search_torrents(**kwargs):
            """
            Yema returns a full page of 40 first, then fewer than 40, after which paging stops.
            """
            page = kwargs["page"]
            requested_pages.append(page)
            count = 40 if page == 0 else 39
            return self._page_results(page, count)

        chain.search_torrents = search_torrents

        with (
            patch.object(settings, "SEARCH_RESOURCE_PAGES", 3, create=True),
            patch("app.chain.search.SystemConfigOper") as system_config_oper,
            patch("app.chain.search.SitesHelper") as sites_helper,
            patch("app.chain.search.ProgressHelper") as progress_helper,
        ):
            self._configure_single_site_search(
                system_config_oper,
                sites_helper,
                progress_helper,
                {"id": 1, "name": "测试站点", "parser": "Yema"},
            )

            results = chain._SearchChain__search_all_sites(
                keyword="keyword",
                sites=None,
                page=0,
            )

        self.assertEqual([0, 1], requested_pages)
        self.assertEqual(79, len(results))

    def test_indexer_module_search_page_size_uses_spider_metadata(self):
        """
        Verify that per-site page capacity is read centrally by the indexer module, so the search chain does not hard-code parser capacities.
        """
        self.assertEqual(
            40,
            IndexerModule.get_search_page_size({"parser": "Yema"}, keyword="keyword")
        )
        self.assertEqual(
            50,
            IndexerModule.get_search_page_size({"result_num": 50}, keyword="keyword")
        )
        self.assertIsNone(
            IndexerModule.get_search_page_size({"parser": "Haidan"}, keyword="keyword")
        )
        self.assertIsNone(
            IndexerModule.get_search_page_size({"parser": "TorrentLeech"}, keyword="keyword")
        )

    async def test_async_search_all_sites_stops_after_empty_page(self):
        """
        Verify that the async search stops paging after an empty page.
        """
        chain = self._make_chain()
        requested_pages = []

        async def async_search_torrents(**kwargs):
            """
            Page 1 is empty, so the async search must not request any further page.
            """
            page = kwargs["page"]
            requested_pages.append(page)
            count = 100 if page == 0 else 0
            return self._page_results(page, count)

        chain.async_search_torrents = async_search_torrents

        with (
            patch.object(settings, "SEARCH_RESOURCE_PAGES", 4, create=True),
            patch("app.chain.search.SystemConfigOper") as system_config_oper,
            patch("app.chain.search.SitesHelper") as sites_helper,
            patch("app.chain.search.ProgressHelper") as progress_helper,
        ):
            self._configure_single_site_search(
                system_config_oper,
                sites_helper,
                progress_helper,
                {"id": 1, "name": "测试站点"},
                use_async=True,
            )

            results = await chain._SearchChain__async_search_all_sites(
                keyword="keyword",
                sites=None,
                page=0,
            )

        self.assertEqual([0, 1], requested_pages)
        self.assertEqual(100, len(results))

    async def test_async_search_all_sites_stream_stops_after_short_page(self):
        """
        Verify that the progressive (streaming) search stops paging after a non-full page.
        """
        chain = self._make_chain()
        requested_pages = []

        async def async_search_torrents(**kwargs):
            """
            Page 1 has fewer than 100 items, so the streaming search must not request page 2.
            """
            page = kwargs["page"]
            requested_pages.append(page)
            count = 100 if page == 0 else 99
            return self._page_results(page, count)

        chain.async_search_torrents = async_search_torrents

        with (
            patch.object(settings, "SEARCH_RESOURCE_PAGES", 3, create=True),
            patch("app.chain.search.SystemConfigOper") as system_config_oper,
            patch("app.chain.search.SitesHelper") as sites_helper,
            patch("app.chain.search.ProgressHelper") as progress_helper,
        ):
            self._configure_single_site_search(
                system_config_oper,
                sites_helper,
                progress_helper,
                {"id": 1, "name": "测试站点"},
                use_async=True,
            )

            events = [
                event
                async for event in chain._SearchChain__async_search_all_sites_stream(
                    keyword="keyword",
                    sites=None,
                    page=0,
                )
            ]

        append_events = [event for event in events if event.get("type") == "append"]
        self.assertEqual([0, 1], requested_pages)
        self.assertEqual([0, 1], [event["page"] for event in append_events])
        self.assertEqual(199, append_events[-1]["total_items"])

    def test_search_by_id_caches_replayable_search_params_when_caching(self):
        chain = self._make_chain()
        cached = []
        chain.save_cache = lambda cache, filename: cached.append((filename, cache))
        chain.recognize_media = lambda **_kwargs: SimpleNamespace(title="Test")
        chain.process = lambda **_kwargs: [SimpleNamespace(title="Result")]

        chain.search_by_id(
            tmdbid=123,
            mtype=MediaType.MOVIE,
            area="title",
            season=2,
            sites=[1, 3],
            cache_local=True,
        )

        self.assertIn(
            (
                "__search_params__",
                {
                    "keyword": "tmdb:123",
                    "type": "电影",
                    "area": "title",
                    "title": "",
                    "year": "",
                    "season": "2",
                    "sites": "1,3",
                },
            ),
            cached,
        )
        self.assertTrue(any(filename == "__search_result__" for filename, _ in cached))

    def test_tool_factory_excludes_message_tools_when_disabled(self):
        with patch(
            "app.agent.tools.factory.PluginManager.get_plugin_agent_tools",
            return_value=[],
        ):
            tools = MoviePilotToolFactory.create_tools(
                session_id="test-session",
                user_id="test-user",
                allow_message_tools=False,
            )

        tool_names = {tool.name for tool in tools}
        self.assertNotIn("send_message", tool_names)
        self.assertNotIn("ask_user_choice", tool_names)
        self.assertNotIn("send_local_file", tool_names)
        self.assertNotIn("send_voice_message", tool_names)
|
|
|
|
|
if __name__ == "__main__":
    # Run this module's tests when executed directly as a script; ``module``
    # is spelled out but equals unittest.main's default.
    unittest.main(module="__main__")
|