# MoviePilot/tests/test_search_ai_recommend.py
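"""Unit tests for the SearchChain AI recommendation flow.

Covers the background recommend task, the flags passed to the LLM invocation,
recommend-state reset when a new search is cached, and message channel
suppression in post_message.
"""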
import asyncio
import sys
import unittest
from types import ModuleType, SimpleNamespace
from unittest.mock import AsyncMock, patch


def _stub_module(name: str, **attrs):
    module = sys.modules.get(name)
    if module is None:
        module = ModuleType(name)
        sys.modules[name] = module
    for key, value in attrs.items():
        setattr(module, key, value)
    return module


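# Stand-ins for the downloader client packages, so the app.chain imports
# below do not require the real qbittorrentapi / transmission_rpc modules.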
_stub_module("qbittorrentapi", TorrentFilesList=list)
_stub_module("transmission_rpc", File=object)
from app.chain.search import SearchChain
from app.core.config import settings
from app.core.message_context import suppress_message_channel
from app.schemas import Notification
from app.schemas.types import NotificationType


def _make_result(title: str, size: int, seeders: int):
    return SimpleNamespace(
        torrent_info=SimpleNamespace(title=title, size=size, seeders=seeders)
    )


class SearchChainAIRecommendTest(unittest.IsolatedAsyncioTestCase):
    def setUp(self):
        SearchChain._ai_recommend_running = False
        SearchChain._ai_recommend_task = None
        SearchChain._current_recommend_request_hash = None
        SearchChain._ai_recommend_result = None
        SearchChain._ai_recommend_error = None

    async def asyncTearDown(self):
        task = SearchChain._ai_recommend_task
        if task and not task.done():
            task.cancel()
            try:
                await task
            except asyncio.CancelledError:
                pass
        SearchChain._ai_recommend_running = False
        SearchChain._ai_recommend_task = None
        SearchChain._current_recommend_request_hash = None
        SearchChain._ai_recommend_result = None
        SearchChain._ai_recommend_error = None

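    # Build a bare SearchChain without running __init__; the cache hooks are
    # no-ops that individual tests override as needed.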
    @staticmethod
    def _make_chain() -> SearchChain:
        chain = object.__new__(SearchChain)
        chain.load_cache = lambda _filename: None
        chain.save_cache = lambda _cache, _filename: None
        chain.remove_cache = lambda _filename: None
        return chain

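    # The mocked LLM reply mixes valid picks, a duplicate, a non-integer and an
    # out-of-range index; only the unique valid picks should be mapped back to
    # the original result indices (1 -> 4, 0 -> 2).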
    async def test_start_recommend_task_restores_original_indices(self):
        chain = self._make_chain()
        saved = []
        chain.save_cache = lambda cache, filename: saved.append((filename, cache))
        results = [
            _make_result(f"item-{index}", 1024 * (index + 1), index)
            for index in range(7)
        ]
        with (
            patch.object(settings, "AI_AGENT_ENABLE", True, create=True),
            patch.object(settings, "AI_RECOMMEND_ENABLED", True, create=True),
            patch.object(settings, "AI_RECOMMEND_MAX_ITEMS", 50, create=True),
            patch.object(
                settings,
                "AI_RECOMMEND_USER_PREFERENCE",
                "Prefer high seeders",
                create=True,
            ),
            patch.object(
                SearchChain,
                "_invoke_recommend_llm",
                new=AsyncMock(return_value='[1, 0, 1, "bad", 9]'),
            ),
        ):
            chain.start_recommend_task(
                filtered_indices=[2, 4, 6],
                search_results_count=len(results),
                results=results,
            )
            self.assertIsNotNone(SearchChain._ai_recommend_task)
            await SearchChain._ai_recommend_task

        self.assertEqual([4, 2], SearchChain._ai_recommend_result)
        self.assertEqual([("__ai_recommend_indices__", [4, 2])], saved)
        self.assertFalse(SearchChain._ai_recommend_running)
        self.assertIsNone(SearchChain._ai_recommend_task)

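    # The fake background prompt captures the kwargs it receives and feeds the
    # output callback, so the flags passed for recommend runs can be asserted.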
    async def test_invoke_recommend_llm_disables_output_message_persistence(self):
        chain = self._make_chain()
        from app.agent import agent_manager
        from app.agent.prompt import prompt_manager

        captured = {}

        async def _fake_run_background_prompt(**kwargs):
            captured.update(kwargs)
            kwargs["output_callback"]("[0, 2]")

        with (
            patch.object(
                prompt_manager,
                "render_system_task_message",
                return_value="PROMPT",
            ),
            patch.object(
                agent_manager,
                "run_background_prompt",
                new=AsyncMock(side_effect=_fake_run_background_prompt),
            ),
        ):
            result = await chain._invoke_recommend_llm("Candidates")

        self.assertEqual("[0, 2]", result)
        self.assertTrue(captured["suppress_user_reply"])
        self.assertFalse(captured["persist_output_message"])
        self.assertTrue(captured["suppress_message_channel_dispatch"])

    def test_search_by_title_clears_previous_recommend_state_when_caching(self):
        chain = self._make_chain()
        removed = []
        cached = []
        chain.remove_cache = lambda filename: removed.append(filename)
        chain.save_cache = lambda cache, filename: cached.append((filename, cache))
        chain._SearchChain__search_all_sites = lambda keyword, sites, page: [
            SimpleNamespace(title="Test Title", description="Test Desc")
        ]
        SearchChain._current_recommend_request_hash = "stale-hash"
        SearchChain._ai_recommend_result = [3, 1]
        SearchChain._ai_recommend_error = "stale-error"

        results = chain.search_by_title("keyword", cache_local=True)

        self.assertEqual(1, len(results))
        self.assertEqual(["__ai_recommend_indices__"], removed)
        self.assertEqual("__search_result__", cached[0][0])
        self.assertIsNone(SearchChain._current_recommend_request_hash)
        self.assertIsNone(SearchChain._ai_recommend_result)
        self.assertIsNone(SearchChain._ai_recommend_error)

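    # With suppress_message_channel() active, post_message should still record
    # the message locally but skip queue and event dispatch.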
    def test_post_message_skips_channel_dispatch_when_suppressed(self):
        chain = object.__new__(SearchChain)
        queue_calls = []
        event_calls = []
        saved_messages = []
        saved_records = []
        chain.messagehelper = SimpleNamespace(
            put=lambda *args, **kwargs: saved_messages.append((args, kwargs))
        )
        chain.messageoper = SimpleNamespace(
            add=lambda **kwargs: saved_records.append(kwargs)
        )
        chain.messagequeue = SimpleNamespace(
            send_message=lambda *args, **kwargs: queue_calls.append((args, kwargs))
        )
        chain.eventmanager = SimpleNamespace(
            send_event=lambda *args, **kwargs: event_calls.append((args, kwargs))
        )
        notification = Notification(
            mtype=NotificationType.Manual,
            title="Title",
            text="Body",
        )
        with (
            patch.object(
                chain_module.MessageTemplateHelper,
                "render",
                return_value=notification,
            ),
            suppress_message_channel(),
        ):
            chain.post_message(message=notification)

        self.assertEqual(1, len(saved_messages))
        self.assertEqual(1, len(saved_records))
        self.assertEqual([], queue_calls)
        self.assertEqual([], event_calls)


if __name__ == "__main__":
    unittest.main()