From 0b7505a604ce2772c888a4d4252fceb01582c3bf Mon Sep 17 00:00:00 2001 From: jxxghp Date: Wed, 29 Apr 2026 22:55:27 +0800 Subject: [PATCH] refactor search AI recommendation flow --- app/api/endpoints/search.py | 13 +- app/chain/agent.py | 265 ------------------------- app/chain/search.py | 314 ++++++++++++++++++++++++++++-- tests/test_search_ai_recommend.py | 126 ++++++++++++ 4 files changed, 431 insertions(+), 287 deletions(-) delete mode 100644 app/chain/agent.py create mode 100644 tests/test_search_ai_recommend.py diff --git a/app/api/endpoints/search.py b/app/api/endpoints/search.py index 2c33106b..f6ea2491 100644 --- a/app/api/endpoints/search.py +++ b/app/api/endpoints/search.py @@ -5,7 +5,6 @@ from fastapi import APIRouter, Depends, Body, Request from fastapi.responses import StreamingResponse from app import schemas -from app.chain.agent import AIRecommendChain from app.chain.media import MediaChain from app.chain.search import SearchChain from app.core.config import settings @@ -204,8 +203,6 @@ async def search_by_id(mediaid: str, """ 根据TMDBID/豆瓣ID精确搜索站点资源 tmdb:/douban:/bangumi: """ - # 取消正在运行的AI推荐(会清除数据库缓存) - if mtype: media_type = MediaType(mtype) else: @@ -348,8 +345,6 @@ async def search_by_title(keyword: Optional[str] = None, """ 根据名称模糊搜索站点资源,支持分页,关键词为空是返回首页资源 """ - # 取消正在运行的AI推荐并清除数据库缓存 - torrents = await SearchChain().async_search_by_title( title=keyword, page=page, sites=_parse_site_list(sites), @@ -393,12 +388,12 @@ async def recommend_search_results( "status": "error" }) - recommend_chain = AIRecommendChain() + recommend_chain = SearchChain() # 如果是强制模式,先取消并清除旧结果,然后直接启动新任务 if force: # 检查功能是否启用 - if not settings.AI_AGENT_ENABLE or not settings.AI_RECOMMEND_ENABLED: + if not recommend_chain.is_ai_recommend_enabled: return schemas.Response(success=True, data={ "status": "disabled" }) @@ -413,7 +408,7 @@ async def recommend_search_results( # 如果是仅检查模式,不传递 filtered_indices(避免触发请求变化检测) if check_only: # 返回当前运行状态,不做任何任务启动或取消操作 - current_status = recommend_chain.get_current_status_only() + current_status = recommend_chain.get_current_recommend_status_only() # 如果有错误,将错误信息放到message中 if current_status.get("status") == "error": error_msg = current_status.pop("error", "未知错误") @@ -421,7 +416,7 @@ async def recommend_search_results( return schemas.Response(success=True, data=current_status) # 获取当前状态(会检测请求是否变化) - status_data = recommend_chain.get_status(filtered_indices, len(results)) + status_data = recommend_chain.get_recommend_status(filtered_indices, len(results)) # 如果功能未启用,直接返回禁用状态 if status_data.get("status") == "disabled": diff --git a/app/chain/agent.py b/app/chain/agent.py deleted file mode 100644 index 3188c6d1..00000000 --- a/app/chain/agent.py +++ /dev/null @@ -1,265 +0,0 @@ -import asyncio -import hashlib -import json -import re -from typing import Any, Dict, List, Optional - -from app.agent import agent_manager, prompt_manager -from app.chain import ChainBase -from app.core.config import settings -from app.log import logger -from app.utils.singleton import Singleton -from app.utils.string import StringUtils - - -class AIRecommendChain(ChainBase, metaclass=Singleton): - """ - AI推荐处理链,单例运行 - 用于基于搜索结果的AI智能推荐 - 使用 agent_manager.run_background_prompt 统一后台任务机制 - """ - - __ai_indices_cache_file = "__ai_recommend_indices__" - - _ai_recommend_running = False - _ai_recommend_task: Optional[asyncio.Task] = None - _current_request_hash: Optional[str] = None - _ai_recommend_result: Optional[List[int]] = None - _ai_recommend_error: Optional[str] = None - - @staticmethod - def _calculate_request_hash( - filtered_indices: Optional[List[int]], search_results_count: int - ) -> str: - """ - 计算请求的哈希值,用于判断请求是否变化 - """ - request_data = { - "filtered_indices": filtered_indices or [], - "search_results_count": search_results_count, - } - return hashlib.md5( - json.dumps(request_data, sort_keys=True).encode() - ).hexdigest() - - @property - def is_enabled(self) -> bool: - """ - 检查AI推荐功能是否已启用。 - """ - return settings.AI_AGENT_ENABLE and settings.AI_RECOMMEND_ENABLED - - def _build_status(self) -> Dict[str, Any]: - """ - 构建AI推荐状态字典 - """ - if not self.is_enabled: - return {"status": "disabled"} - - if self._ai_recommend_running: - return {"status": "running"} - - if self._ai_recommend_result is None: - cached_indices = self.load_cache(self.__ai_indices_cache_file) - if cached_indices is not None: - self._ai_recommend_result = cached_indices - - if self._ai_recommend_result is not None: - return {"status": "completed", "results": self._ai_recommend_result} - - if self._ai_recommend_error is not None: - return {"status": "error", "error": self._ai_recommend_error} - - return {"status": "idle"} - - def get_current_status_only(self) -> Dict[str, Any]: - """ - 获取当前状态(不校验hash,用于check_only模式) - """ - return self._build_status() - - def get_status( - self, filtered_indices: Optional[List[int]], search_results_count: int - ) -> Dict[str, Any]: - """ - 获取AI推荐状态并检查请求是否变化(用于首次请求或force模式) - 如果请求变化(筛选条件变化),返回idle状态 - """ - request_hash = self._calculate_request_hash( - filtered_indices, search_results_count - ) - is_same_request = request_hash == self._current_request_hash - - if not is_same_request: - return {"status": "idle"} if self.is_enabled else {"status": "disabled"} - - return self._build_status() - - def is_ai_recommend_running(self) -> bool: - """ - 检查AI推荐是否正在运行 - """ - return self._ai_recommend_running - - def cancel_ai_recommend(self): - """ - 取消正在运行的AI推荐任务 - """ - if self._ai_recommend_task and not self._ai_recommend_task.done(): - self._ai_recommend_task.cancel() - self._ai_recommend_running = False - self._ai_recommend_task = None - self._current_request_hash = None - self._ai_recommend_result = None - self._ai_recommend_error = None - self.remove_cache(self.__ai_indices_cache_file) - - def start_recommend_task( - self, - filtered_indices: Optional[List[int]], - search_results_count: int, - results: List[Any], - ) -> None: - """ - 启动AI推荐任务 - 使用 agent_manager.run_background_prompt 后台Agent机制执行推荐 - :param filtered_indices: 筛选后的索引列表 - :param search_results_count: 搜索结果总数 - :param results: 搜索结果列表 - """ - if not self.is_enabled: - logger.warning("AI推荐功能未启用,跳过任务执行") - return - - new_request_hash = self._calculate_request_hash( - filtered_indices, search_results_count - ) - - if new_request_hash != self._current_request_hash: - self.cancel_ai_recommend() - self._current_request_hash = new_request_hash - self._ai_recommend_result = None - self._ai_recommend_error = None - - async def run_recommend(): - current_task = asyncio.current_task() - try: - self._ai_recommend_running = True - - items = [] - valid_indices = [] - max_items = settings.AI_RECOMMEND_MAX_ITEMS or 50 - - if filtered_indices is not None and len(filtered_indices) > 0: - results_to_process = [ - results[i] - for i in filtered_indices - if 0 <= i < len(results) - ] - else: - results_to_process = results - - for i, torrent in enumerate(results_to_process): - if len(items) >= max_items: - break - - if not torrent.torrent_info: - continue - - valid_indices.append(i) - - item_info = { - "index": i, - "title": torrent.torrent_info.title or "未知", - "size": ( - StringUtils.format_size(torrent.torrent_info.size) - if torrent.torrent_info.size - else "0 B" - ), - "seeders": torrent.torrent_info.seeders or 0, - } - - items.append(json.dumps(item_info, ensure_ascii=False)) - - if not items: - self._ai_recommend_error = "没有可用于AI推荐的资源" - return - - user_preference = ( - settings.AI_RECOMMEND_USER_PREFERENCE - or "Prefer high-quality resources with more seeders" - ) - - search_results_text = "User Preference: {preference}\n\nCandidate Resources:\n{items}".format( - preference=user_preference, items="\n".join(items) - ) - - prompt = prompt_manager.render_system_task_message( - "search_recommend", - template_context={"search_results": search_results_text}, - ) - - full_output = [""] - - def on_output(text: str): - full_output[0] = text - - await agent_manager.run_background_prompt( - message=prompt, - session_prefix="__agent_search_recommend", - output_callback=on_output, - suppress_user_reply=True, - ) - - ai_response = full_output[0] - if not ai_response: - self._ai_recommend_error = "AI推荐未返回结果" - return - - try: - json_match = re.search(r"\[.*?]", ai_response, re.DOTALL) - if not json_match: - raise ValueError(f"无法从响应中提取JSON数组: {ai_response}") - - ai_indices = json.loads(json_match.group()) - if not isinstance(ai_indices, list): - raise ValueError(f"AI返回格式错误: {ai_response}") - - if filtered_indices: - original_indices = [ - filtered_indices[valid_indices[i]] - for i in ai_indices - if i < len(valid_indices) - and 0 - <= filtered_indices[valid_indices[i]] - < len(results) - ] - else: - original_indices = [ - valid_indices[i] - for i in ai_indices - if i < len(valid_indices) - and 0 <= valid_indices[i] < len(results) - ] - - self._ai_recommend_result = original_indices - self.save_cache(original_indices, self.__ai_indices_cache_file) - logger.info(f"AI推荐完成: {len(original_indices)}项") - - except Exception as e: - logger.error( - f"解析AI返回结果失败: {e}, 原始响应: {ai_response}" - ) - self._ai_recommend_error = str(e) - - except asyncio.CancelledError: - logger.info("AI推荐任务被取消") - except Exception as e: - logger.error(f"AI推荐任务失败: {e}") - self._ai_recommend_error = str(e) - finally: - if self._ai_recommend_task == current_task: - self._ai_recommend_running = False - self._ai_recommend_task = None - - self._ai_recommend_task = asyncio.create_task(run_recommend()) diff --git a/app/chain/search.py b/app/chain/search.py index 4e56b794..10cb3eed 100644 --- a/app/chain/search.py +++ b/app/chain/search.py @@ -1,5 +1,8 @@ import asyncio +import hashlib +import json import random +import re import time from concurrent.futures import ThreadPoolExecutor, as_completed from datetime import datetime @@ -21,6 +24,7 @@ from app.helper.torrent import TorrentHelper from app.log import logger from app.schemas import NotExistMediaInfo from app.schemas.types import MediaType, ProgressKey, SystemConfigKey, EventType +from app.utils.string import StringUtils class SearchChain(ChainBase): @@ -29,7 +33,291 @@ class SearchChain(ChainBase): """ __result_temp_file = "__search_result__" - __ai_result_temp_file = "__ai_search_result__" + __ai_indices_cache_file = "__ai_recommend_indices__" + + _ai_recommend_running = False + _ai_recommend_task: Optional[asyncio.Task] = None + _current_recommend_request_hash: Optional[str] = None + _ai_recommend_result: Optional[List[int]] = None + _ai_recommend_error: Optional[str] = None + + @property + def is_ai_recommend_enabled(self) -> bool: + """ + 检查AI推荐功能是否已启用。 + """ + return settings.AI_AGENT_ENABLE and settings.AI_RECOMMEND_ENABLED + + @staticmethod + def _calculate_recommend_request_hash( + filtered_indices: Optional[List[int]], search_results_count: int + ) -> str: + """ + 计算当前推荐请求哈希,用于识别筛选条件是否变化。 + """ + request_data = { + "filtered_indices": filtered_indices or [], + "search_results_count": search_results_count, + } + return hashlib.md5( + json.dumps(request_data, sort_keys=True).encode() + ).hexdigest() + + def _build_ai_recommend_status(self) -> Dict[str, Any]: + """ + 构建AI推荐状态字典。 + """ + state = type(self) + if not self.is_ai_recommend_enabled: + return {"status": "disabled"} + + if state._ai_recommend_running: + return {"status": "running"} + + if state._ai_recommend_result is None: + cached_indices = self.load_cache(self.__ai_indices_cache_file) + if cached_indices is not None: + state._ai_recommend_result = cached_indices + + if state._ai_recommend_result is not None: + return {"status": "completed", "results": state._ai_recommend_result} + + if state._ai_recommend_error is not None: + return {"status": "error", "error": state._ai_recommend_error} + + return {"status": "idle"} + + def get_current_recommend_status_only(self) -> Dict[str, Any]: + """ + 获取当前推荐状态,不校验请求是否变化。 + """ + return self._build_ai_recommend_status() + + def get_recommend_status( + self, filtered_indices: Optional[List[int]], search_results_count: int + ) -> Dict[str, Any]: + """ + 获取AI推荐状态,并在筛选条件变化时返回 idle。 + """ + state = type(self) + request_hash = self._calculate_recommend_request_hash( + filtered_indices, search_results_count + ) + if request_hash != state._current_recommend_request_hash: + return {"status": "idle"} if self.is_ai_recommend_enabled else {"status": "disabled"} + return self._build_ai_recommend_status() + + def cancel_ai_recommend(self): + """ + 取消当前AI推荐任务并清空缓存状态。 + """ + state = type(self) + if state._ai_recommend_task and not state._ai_recommend_task.done(): + state._ai_recommend_task.cancel() + state._ai_recommend_running = False + state._ai_recommend_task = None + state._current_recommend_request_hash = None + state._ai_recommend_result = None + state._ai_recommend_error = None + self.remove_cache(self.__ai_indices_cache_file) + + @staticmethod + def _normalize_ai_indices(ai_indices: List[Any]) -> List[int]: + """ + 过滤模型返回的非法或重复索引,保留原顺序。 + """ + normalized = [] + seen = set() + for index in ai_indices: + try: + value = int(index) + except (TypeError, ValueError): + continue + if value in seen: + continue + seen.add(value) + normalized.append(value) + return normalized + + @staticmethod + def _extract_recommend_items( + filtered_indices: Optional[List[int]], results: List[Any] + ) -> tuple[List[str], List[int]]: + """ + 构建发送给模型的候选列表和索引映射。 + """ + items: List[str] = [] + valid_indices: List[int] = [] + max_items = settings.AI_RECOMMEND_MAX_ITEMS or 50 + + if filtered_indices: + results_to_process = [ + results[index] for index in filtered_indices if 0 <= index < len(results) + ] + else: + results_to_process = results + + for index, torrent in enumerate(results_to_process): + if len(items) >= max_items: + break + if not torrent.torrent_info: + continue + + valid_indices.append(index) + item_info = { + "index": index, + "title": torrent.torrent_info.title or "未知", + "size": ( + StringUtils.format_size(torrent.torrent_info.size) + if torrent.torrent_info.size + else "0 B" + ), + "seeders": torrent.torrent_info.seeders or 0, + } + items.append(json.dumps(item_info, ensure_ascii=False)) + + return items, valid_indices + + @staticmethod + def _restore_original_indices( + ai_indices: List[int], + filtered_indices: Optional[List[int]], + valid_indices: List[int], + results_count: int, + ) -> List[int]: + """ + 将模型输出的局部索引映射回原始搜索结果索引。 + """ + original_indices = [] + seen = set() + + for index in ai_indices: + if not 0 <= index < len(valid_indices): + continue + original_index = ( + filtered_indices[valid_indices[index]] + if filtered_indices + else valid_indices[index] + ) + if not 0 <= original_index < results_count or original_index in seen: + continue + seen.add(original_index) + original_indices.append(original_index) + + return original_indices + + async def _invoke_recommend_llm(self, search_results_text: str) -> str: + """ + 通过统一后台提示词机制执行资源推荐。 + """ + from app.agent import agent_manager + from app.agent.prompt import prompt_manager + + prompt = prompt_manager.render_system_task_message( + "search_recommend", + template_context={"search_results": search_results_text}, + ) + full_output = [""] + + def on_output(text: str): + full_output[0] = text + + await agent_manager.run_background_prompt( + message=prompt, + session_prefix="__agent_search_recommend", + output_callback=on_output, + suppress_user_reply=True, + ) + return full_output[0].strip() + + def start_recommend_task( + self, + filtered_indices: Optional[List[int]], + search_results_count: int, + results: List[Any], + ) -> None: + """ + 启动AI推荐任务。 + """ + if not self.is_ai_recommend_enabled: + logger.warning("AI推荐功能未启用,跳过任务执行") + return + + state = type(self) + request_hash = self._calculate_recommend_request_hash( + filtered_indices, search_results_count + ) + if request_hash == state._current_recommend_request_hash: + return + + self.cancel_ai_recommend() + state._current_recommend_request_hash = request_hash + + async def run_recommend(): + current_task = asyncio.current_task() + + def is_current_request() -> bool: + return state._current_recommend_request_hash == request_hash + + try: + state._ai_recommend_running = True + + items, valid_indices = self._extract_recommend_items( + filtered_indices=filtered_indices, + results=results, + ) + if not items: + if is_current_request(): + state._ai_recommend_error = "没有可用于AI推荐的资源" + return + + user_preference = ( + settings.AI_RECOMMEND_USER_PREFERENCE + or "Prefer high-quality resources with more seeders" + ) + search_results_text = ( + f"User Preference: {user_preference}\n\n" + f"Candidate Resources:\n{chr(10).join(items)}" + ) + ai_response = await self._invoke_recommend_llm(search_results_text) + if not ai_response: + if is_current_request(): + state._ai_recommend_error = "AI推荐未返回结果" + return + + json_match = re.search(r"\[.*?]", ai_response, re.DOTALL) + if not json_match: + raise ValueError(f"无法从响应中提取JSON数组: {ai_response}") + + ai_indices = json.loads(json_match.group()) + if not isinstance(ai_indices, list): + raise ValueError(f"AI返回格式错误: {ai_response}") + + original_indices = self._restore_original_indices( + ai_indices=self._normalize_ai_indices(ai_indices), + filtered_indices=filtered_indices, + valid_indices=valid_indices, + results_count=len(results), + ) + if not is_current_request(): + logger.info("AI推荐结果已过期,丢弃旧结果") + return + + state._ai_recommend_result = original_indices + self.save_cache(original_indices, self.__ai_indices_cache_file) + logger.info(f"AI推荐完成: {len(original_indices)}项") + except asyncio.CancelledError: + logger.info("AI推荐任务被取消") + except Exception as err: + logger.error(f"AI推荐任务失败: {err}") + if is_current_request(): + state._ai_recommend_error = str(err) + finally: + if state._ai_recommend_task == current_task: + state._ai_recommend_running = False + state._ai_recommend_task = None + + state._ai_recommend_task = asyncio.create_task(run_recommend()) def search_by_id(self, tmdbid: Optional[int] = None, doubanid: Optional[str] = None, mtype: MediaType = None, area: Optional[str] = "title", season: Optional[int] = None, @@ -44,6 +332,8 @@ class SearchChain(ChainBase): :param sites: 站点ID列表 :param cache_local: 是否缓存到本地 """ + if cache_local: + self.cancel_ai_recommend() mediainfo = self.recognize_media(tmdbid=tmdbid, doubanid=doubanid, mtype=mtype) if not mediainfo: logger.error(f'{tmdbid} 媒体信息识别失败!') @@ -70,6 +360,8 @@ class SearchChain(ChainBase): :param sites: 站点ID列表 :param cache_local: 是否缓存到本地 """ + if cache_local: + self.cancel_ai_recommend() if title: logger.info(f'开始搜索资源,关键词:{title} ...') else: @@ -99,18 +391,6 @@ class SearchChain(ChainBase): """ return await self.async_load_cache(self.__result_temp_file) - async def async_last_ai_results(self) -> Optional[List[Context]]: - """ - 异步获取上次AI推荐结果 - """ - return await self.async_load_cache(self.__ai_result_temp_file) - - async def async_save_ai_results(self, results: List[Context]): - """ - 异步保存AI推荐结果 - """ - await self.async_save_cache(results, self.__ai_result_temp_file) - async def async_search_by_id(self, tmdbid: Optional[int] = None, doubanid: Optional[str] = None, mtype: MediaType = None, area: Optional[str] = "title", season: Optional[int] = None, sites: List[int] = None, cache_local: bool = False) -> List[Context]: @@ -124,6 +404,8 @@ class SearchChain(ChainBase): :param sites: 站点ID列表 :param cache_local: 是否缓存到本地 """ + if cache_local: + self.cancel_ai_recommend() mediainfo = await self.async_recognize_media(tmdbid=tmdbid, doubanid=doubanid, mtype=mtype) if not mediainfo: logger.error(f'{tmdbid} 媒体信息识别失败!') @@ -150,6 +432,8 @@ class SearchChain(ChainBase): :param sites: 站点ID列表 :param cache_local: 是否缓存到本地 """ + if cache_local: + self.cancel_ai_recommend() if title: logger.info(f'开始搜索资源,关键词:{title} ...') else: @@ -173,6 +457,8 @@ class SearchChain(ChainBase): """ 根据标题渐进式搜索资源,不识别不过滤,按站点完成顺序返回结果 """ + if cache_local: + self.cancel_ai_recommend() if title: logger.info(f'开始渐进式搜索资源,关键词:{title} ...') else: @@ -214,6 +500,8 @@ class SearchChain(ChainBase): """ 根据TMDBID/豆瓣ID渐进式搜索资源,先返回站点原始候选,再返回过滤匹配后的最终结果 """ + if cache_local: + self.cancel_ai_recommend() mediainfo = await self.async_recognize_media(tmdbid=tmdbid, doubanid=doubanid, mtype=mtype) if not mediainfo: logger.error(f'{tmdbid} 媒体信息识别失败!') diff --git a/tests/test_search_ai_recommend.py b/tests/test_search_ai_recommend.py new file mode 100644 index 00000000..aeb9dc70 --- /dev/null +++ b/tests/test_search_ai_recommend.py @@ -0,0 +1,126 @@ +import asyncio +import sys +import unittest +from types import SimpleNamespace +from types import ModuleType +from unittest.mock import AsyncMock, patch + + +def _stub_module(name: str, **attrs): + module = sys.modules.get(name) + if module is None: + module = ModuleType(name) + sys.modules[name] = module + for key, value in attrs.items(): + setattr(module, key, value) + return module + + +_stub_module("qbittorrentapi", TorrentFilesList=list) +_stub_module("transmission_rpc", File=object) + +from app.chain.search import SearchChain +from app.core.config import settings + + +def _make_result(title: str, size: int, seeders: int): + return SimpleNamespace( + torrent_info=SimpleNamespace(title=title, size=size, seeders=seeders) + ) + + +class SearchChainAIRecommendTest(unittest.IsolatedAsyncioTestCase): + def setUp(self): + SearchChain._ai_recommend_running = False + SearchChain._ai_recommend_task = None + SearchChain._current_recommend_request_hash = None + SearchChain._ai_recommend_result = None + SearchChain._ai_recommend_error = None + + async def asyncTearDown(self): + task = SearchChain._ai_recommend_task + if task and not task.done(): + task.cancel() + try: + await task + except asyncio.CancelledError: + pass + + SearchChain._ai_recommend_running = False + SearchChain._ai_recommend_task = None + SearchChain._current_recommend_request_hash = None + SearchChain._ai_recommend_result = None + SearchChain._ai_recommend_error = None + + @staticmethod + def _make_chain() -> SearchChain: + chain = object.__new__(SearchChain) + chain.load_cache = lambda _filename: None + chain.save_cache = lambda _cache, _filename: None + chain.remove_cache = lambda _filename: None + return chain + + async def test_start_recommend_task_restores_original_indices(self): + chain = self._make_chain() + saved = [] + chain.save_cache = lambda cache, filename: saved.append((filename, cache)) + results = [_make_result(f"item-{index}", 1024 * (index + 1), index) for index in range(7)] + + with ( + patch.object(settings, "AI_AGENT_ENABLE", True, create=True), + patch.object(settings, "AI_RECOMMEND_ENABLED", True, create=True), + patch.object(settings, "AI_RECOMMEND_MAX_ITEMS", 50, create=True), + patch.object( + settings, + "AI_RECOMMEND_USER_PREFERENCE", + "Prefer high seeders", + create=True, + ), + patch.object( + SearchChain, + "_invoke_recommend_llm", + new=AsyncMock(return_value='[1, 0, 1, "bad", 9]'), + ), + ): + chain.start_recommend_task( + filtered_indices=[2, 4, 6], + search_results_count=len(results), + results=results, + ) + self.assertIsNotNone(SearchChain._ai_recommend_task) + await SearchChain._ai_recommend_task + + self.assertEqual([4, 2], SearchChain._ai_recommend_result) + self.assertEqual( + [("__ai_recommend_indices__", [4, 2])], + saved, + ) + self.assertFalse(SearchChain._ai_recommend_running) + self.assertIsNone(SearchChain._ai_recommend_task) + + def test_search_by_title_clears_previous_recommend_state_when_caching(self): + chain = self._make_chain() + removed = [] + cached = [] + chain.remove_cache = lambda filename: removed.append(filename) + chain.save_cache = lambda cache, filename: cached.append((filename, cache)) + chain._SearchChain__search_all_sites = lambda keyword, sites, page: [ + SimpleNamespace(title="Test Title", description="Test Desc") + ] + + SearchChain._current_recommend_request_hash = "stale-hash" + SearchChain._ai_recommend_result = [3, 1] + SearchChain._ai_recommend_error = "stale-error" + + results = chain.search_by_title("keyword", cache_local=True) + + self.assertEqual(1, len(results)) + self.assertEqual(["__ai_recommend_indices__"], removed) + self.assertEqual("__search_result__", cached[0][0]) + self.assertIsNone(SearchChain._current_recommend_request_hash) + self.assertIsNone(SearchChain._ai_recommend_result) + self.assertIsNone(SearchChain._ai_recommend_error) + + +if __name__ == "__main__": + unittest.main()