# Mirror of https://github.com/jxxghp/MoviePilot.git
# Synced 2026-05-11 09:59:51 +08:00
# 202 lines, 7.0 KiB, Python
import asyncio
|
|
import sys
|
|
import unittest
|
|
from types import SimpleNamespace
|
|
from types import ModuleType
|
|
from unittest.mock import AsyncMock, patch
|
|
|
|
import app.chain as chain_module
|
|
|
|
|
|
def _stub_module(name: str, **attrs: object) -> ModuleType:
    """Return the module registered as *name*, creating an empty one if needed.

    If ``sys.modules`` has no entry for *name*, a fresh ``ModuleType`` is
    created and registered under that name. The given keyword arguments are
    then set as attributes on the module, whether it was just created or was
    already registered (existing attributes of the same name are overwritten).

    Args:
        name: Key to look up / register in ``sys.modules``.
        **attrs: Attribute names and values to set on the module.

    Returns:
        The module object now stored in ``sys.modules[name]``.
    """
    module = sys.modules.get(name)
    if module is None:
        module = ModuleType(name)
        sys.modules[name] = module
    for key, value in attrs.items():
        setattr(module, key, value)
    return module
|
|
|
|
|
|
# Register stand-in modules for the downloader client libraries before the
# ``app`` imports below run. NOTE(review): presumably this lets the tests
# import ``app.chain.search`` without the real packages installed -- confirm.
_stub_module("qbittorrentapi", TorrentFilesList=list)
_stub_module("transmission_rpc", File=object)
|
|
|
|
from app.chain.search import SearchChain
|
|
from app.core.config import settings
|
|
from app.core.message_context import suppress_message_channel
|
|
from app.schemas import Notification
|
|
from app.schemas.types import NotificationType
|
|
|
|
|
|
def _make_result(title: str, size: int, seeders: int) -> SimpleNamespace:
    """Build a minimal stand-in for a search result.

    Args:
        title: Value exposed as ``torrent_info.title``.
        size: Value exposed as ``torrent_info.size``.
        seeders: Value exposed as ``torrent_info.seeders``.

    Returns:
        A ``SimpleNamespace`` whose only attribute, ``torrent_info``, is a
        nested ``SimpleNamespace`` carrying the three values above.
    """
    return SimpleNamespace(
        torrent_info=SimpleNamespace(title=title, size=size, seeders=seeders)
    )
|
|
|
|
|
|
class SearchChainAIRecommendTest(unittest.IsolatedAsyncioTestCase):
    """Tests for the AI-recommendation state and messaging of ``SearchChain``.

    The tests read and write class-level attributes on ``SearchChain``
    (``_ai_recommend_*`` and ``_current_recommend_request_hash``), so that
    state is reset before and after every test.
    """

    @staticmethod
    def _reset_recommend_state() -> None:
        """Reset the class-level recommendation attributes on ``SearchChain``."""
        SearchChain._ai_recommend_running = False
        SearchChain._ai_recommend_task = None
        SearchChain._current_recommend_request_hash = None
        SearchChain._ai_recommend_result = None
        SearchChain._ai_recommend_error = None

    def setUp(self) -> None:
        """Start every test from a clean recommendation state."""
        self._reset_recommend_state()

    async def asyncTearDown(self) -> None:
        """Cancel any still-running recommendation task, then reset state."""
        task = SearchChain._ai_recommend_task
        if task and not task.done():
            task.cancel()
            try:
                await task
            except asyncio.CancelledError:
                # Expected outcome of the cancel() call just above.
                pass

        self._reset_recommend_state()

    @staticmethod
    def _make_chain() -> SearchChain:
        """Create a ``SearchChain`` without running ``__init__``.

        The cache methods are replaced with no-op lambdas on the instance;
        individual tests override them when they need to record calls.
        """
        chain = object.__new__(SearchChain)
        chain.load_cache = lambda _filename: None
        chain.save_cache = lambda _cache, _filename: None
        chain.remove_cache = lambda _filename: None
        return chain

    async def test_start_recommend_task_restores_original_indices(self) -> None:
        """The recommend task should report indices into the full result list.

        The mocked LLM answers ``[1, 0, 1, "bad", 9]`` for the filtered
        candidates ``[2, 4, 6]``. The test expects ``[4, 2]``: positions 1
        and 0 mapped back through ``filtered_indices``, with the repeated,
        non-integer and out-of-range entries absent from the result.
        """
        chain = self._make_chain()
        saved = []
        chain.save_cache = lambda cache, filename: saved.append((filename, cache))
        results = [
            _make_result(f"item-{index}", 1024 * (index + 1), index)
            for index in range(7)
        ]

        with (
            patch.object(settings, "AI_AGENT_ENABLE", True, create=True),
            patch.object(settings, "AI_RECOMMEND_ENABLED", True, create=True),
            patch.object(settings, "AI_RECOMMEND_MAX_ITEMS", 50, create=True),
            patch.object(
                settings,
                "AI_RECOMMEND_USER_PREFERENCE",
                "Prefer high seeders",
                create=True,
            ),
            patch.object(
                SearchChain,
                "_invoke_recommend_llm",
                new=AsyncMock(return_value='[1, 0, 1, "bad", 9]'),
            ),
        ):
            chain.start_recommend_task(
                filtered_indices=[2, 4, 6],
                search_results_count=len(results),
                results=results,
            )
            self.assertIsNotNone(SearchChain._ai_recommend_task)
            # Await inside the ``with`` so the patches stay active while the
            # background task runs.
            await SearchChain._ai_recommend_task

        self.assertEqual([4, 2], SearchChain._ai_recommend_result)
        self.assertEqual(
            [("__ai_recommend_indices__", [4, 2])],
            saved,
        )
        self.assertFalse(SearchChain._ai_recommend_running)
        self.assertIsNone(SearchChain._ai_recommend_task)

    async def test_invoke_recommend_llm_disables_output_message_persistence(self) -> None:
        """``_invoke_recommend_llm`` should run the agent prompt silently.

        ``run_background_prompt`` is replaced with a fake that records its
        keyword arguments and feeds ``"[0, 2]"`` to ``output_callback``. The
        test expects that text to be returned and the three suppression /
        persistence flags to be set as asserted below.
        """
        chain = self._make_chain()
        # NOTE(review): imported here rather than at module level; presumably
        # to keep the agent package out of the other tests -- confirm.
        from app.agent import agent_manager
        from app.agent.prompt import prompt_manager

        captured = {}

        async def _fake_run_background_prompt(**kwargs):
            captured.update(kwargs)
            kwargs["output_callback"]("[0, 2]")

        with (
            patch.object(
                prompt_manager,
                "render_system_task_message",
                return_value="PROMPT",
            ),
            patch.object(
                agent_manager,
                "run_background_prompt",
                new=AsyncMock(side_effect=_fake_run_background_prompt),
            ),
        ):
            result = await chain._invoke_recommend_llm("Candidates")

        self.assertEqual("[0, 2]", result)
        self.assertTrue(captured["suppress_user_reply"])
        self.assertFalse(captured["persist_output_message"])
        self.assertTrue(captured["suppress_message_channel_dispatch"])

    def test_search_by_title_clears_previous_recommend_state_when_caching(self) -> None:
        """A cached title search should discard earlier recommendation state.

        Stale hash/result/error values are planted on ``SearchChain`` first.
        After ``search_by_title(..., cache_local=True)`` the test expects the
        recommend-indices cache to be removed, the search result to be cached
        and the three class attributes to be ``None``.
        """
        chain = self._make_chain()
        removed = []
        cached = []
        chain.remove_cache = lambda filename: removed.append(filename)
        chain.save_cache = lambda cache, filename: cached.append((filename, cache))
        # Name-mangled form of the private ``__search_all_sites`` method.
        chain._SearchChain__search_all_sites = lambda keyword, sites, page: [
            SimpleNamespace(title="Test Title", description="Test Desc")
        ]

        SearchChain._current_recommend_request_hash = "stale-hash"
        SearchChain._ai_recommend_result = [3, 1]
        SearchChain._ai_recommend_error = "stale-error"

        results = chain.search_by_title("keyword", cache_local=True)

        self.assertEqual(1, len(results))
        self.assertEqual(["__ai_recommend_indices__"], removed)
        self.assertEqual("__search_result__", cached[0][0])
        self.assertIsNone(SearchChain._current_recommend_request_hash)
        self.assertIsNone(SearchChain._ai_recommend_result)
        self.assertIsNone(SearchChain._ai_recommend_error)

    def test_post_message_skips_channel_dispatch_when_suppressed(self) -> None:
        """``post_message`` should store but not dispatch while suppressed.

        Inside ``suppress_message_channel()`` the test expects one call each
        to ``messagehelper.put`` and ``messageoper.add``, and no calls to
        ``messagequeue.send_message`` or ``eventmanager.send_event``.
        """
        chain = object.__new__(SearchChain)
        queue_calls = []
        event_calls = []
        saved_messages = []
        saved_records = []
        chain.messagehelper = SimpleNamespace(
            put=lambda *args, **kwargs: saved_messages.append((args, kwargs))
        )
        chain.messageoper = SimpleNamespace(
            add=lambda **kwargs: saved_records.append(kwargs)
        )
        chain.messagequeue = SimpleNamespace(
            send_message=lambda *args, **kwargs: queue_calls.append((args, kwargs))
        )
        chain.eventmanager = SimpleNamespace(
            send_event=lambda *args, **kwargs: event_calls.append((args, kwargs))
        )

        notification = Notification(
            mtype=NotificationType.Manual,
            title="Title",
            text="Body",
        )

        with (
            patch.object(
                chain_module.MessageTemplateHelper,
                "render",
                return_value=notification,
            ),
            suppress_message_channel(),
        ):
            chain.post_message(message=notification)

        self.assertEqual(1, len(saved_messages))
        self.assertEqual(1, len(saved_records))
        self.assertEqual([], queue_calls)
        self.assertEqual([], event_calls)
|
|
|
|
|
|
# Allow running this test module directly as a script.
if __name__ == "__main__":
    unittest.main()
|