refactor search AI recommendation flow

This commit is contained in:
jxxghp
2026-04-29 22:55:27 +08:00
parent 460d716512
commit 0b7505a604
4 changed files with 431 additions and 287 deletions

View File

@@ -5,7 +5,6 @@ from fastapi import APIRouter, Depends, Body, Request
from fastapi.responses import StreamingResponse
from app import schemas
from app.chain.agent import AIRecommendChain
from app.chain.media import MediaChain
from app.chain.search import SearchChain
from app.core.config import settings
@@ -204,8 +203,6 @@ async def search_by_id(mediaid: str,
"""
根据TMDBID/豆瓣ID精确搜索站点资源 tmdb:/douban:/bangumi:
"""
# 取消正在运行的AI推荐会清除数据库缓存
if mtype:
media_type = MediaType(mtype)
else:
@@ -348,8 +345,6 @@ async def search_by_title(keyword: Optional[str] = None,
"""
根据名称模糊搜索站点资源,支持分页,关键词为空时返回首页资源
"""
# 取消正在运行的AI推荐并清除数据库缓存
torrents = await SearchChain().async_search_by_title(
title=keyword, page=page,
sites=_parse_site_list(sites),
@@ -393,12 +388,12 @@ async def recommend_search_results(
"status": "error"
})
recommend_chain = AIRecommendChain()
recommend_chain = SearchChain()
# 如果是强制模式,先取消并清除旧结果,然后直接启动新任务
if force:
# 检查功能是否启用
if not settings.AI_AGENT_ENABLE or not settings.AI_RECOMMEND_ENABLED:
if not recommend_chain.is_ai_recommend_enabled:
return schemas.Response(success=True, data={
"status": "disabled"
})
@@ -413,7 +408,7 @@ async def recommend_search_results(
# 如果是仅检查模式,不传递 filtered_indices避免触发请求变化检测
if check_only:
# 返回当前运行状态,不做任何任务启动或取消操作
current_status = recommend_chain.get_current_status_only()
current_status = recommend_chain.get_current_recommend_status_only()
# 如果有错误将错误信息放到message中
if current_status.get("status") == "error":
error_msg = current_status.pop("error", "未知错误")
@@ -421,7 +416,7 @@ async def recommend_search_results(
return schemas.Response(success=True, data=current_status)
# 获取当前状态(会检测请求是否变化)
status_data = recommend_chain.get_status(filtered_indices, len(results))
status_data = recommend_chain.get_recommend_status(filtered_indices, len(results))
# 如果功能未启用,直接返回禁用状态
if status_data.get("status") == "disabled":

View File

@@ -1,265 +0,0 @@
import asyncio
import hashlib
import json
import re
from typing import Any, Dict, List, Optional
from app.agent import agent_manager, prompt_manager
from app.chain import ChainBase
from app.core.config import settings
from app.log import logger
from app.utils.singleton import Singleton
from app.utils.string import StringUtils
class AIRecommendChain(ChainBase, metaclass=Singleton):
    """
    AI recommendation chain (singleton).

    Runs AI-based recommendation over a set of search results, using the
    unified background-task mechanism ``agent_manager.run_background_prompt``.
    """

    # File name under which the recommended indices are cached.
    __ai_indices_cache_file = "__ai_recommend_indices__"
    # Shared task state: the Singleton metaclass means one instance, so these
    # class attributes act as the single source of truth for the workflow.
    _ai_recommend_running = False
    _ai_recommend_task: Optional[asyncio.Task] = None
    _current_request_hash: Optional[str] = None
    _ai_recommend_result: Optional[List[int]] = None
    _ai_recommend_error: Optional[str] = None

    @staticmethod
    def _calculate_request_hash(
        filtered_indices: Optional[List[int]], search_results_count: int
    ) -> str:
        """
        Compute a hash of the request, used to decide whether the request
        (filter selection or result count) has changed.
        """
        request_data = {
            "filtered_indices": filtered_indices or [],
            "search_results_count": search_results_count,
        }
        # sort_keys makes the JSON serialization stable across calls.
        return hashlib.md5(
            json.dumps(request_data, sort_keys=True).encode()
        ).hexdigest()

    @property
    def is_enabled(self) -> bool:
        """
        Whether the AI recommendation feature is enabled.
        """
        return settings.AI_AGENT_ENABLE and settings.AI_RECOMMEND_ENABLED

    def _build_status(self) -> Dict[str, Any]:
        """
        Build the AI recommend status dict.

        Returns one of: ``disabled`` / ``running`` / ``completed`` (with
        ``results``) / ``error`` (with ``error``) / ``idle``.
        """
        if not self.is_enabled:
            return {"status": "disabled"}
        if self._ai_recommend_running:
            return {"status": "running"}
        if self._ai_recommend_result is None:
            # Fall back to the persisted cache from an earlier run.
            cached_indices = self.load_cache(self.__ai_indices_cache_file)
            if cached_indices is not None:
                self._ai_recommend_result = cached_indices
        if self._ai_recommend_result is not None:
            return {"status": "completed", "results": self._ai_recommend_result}
        if self._ai_recommend_error is not None:
            return {"status": "error", "error": self._ai_recommend_error}
        return {"status": "idle"}

    def get_current_status_only(self) -> Dict[str, Any]:
        """
        Current status without request-hash validation (check_only mode).
        """
        return self._build_status()

    def get_status(
        self, filtered_indices: Optional[List[int]], search_results_count: int
    ) -> Dict[str, Any]:
        """
        AI recommend status with request-change detection (first request or
        force mode). If the request (filter conditions) changed, report
        "idle" instead of stale state.
        """
        request_hash = self._calculate_request_hash(
            filtered_indices, search_results_count
        )
        is_same_request = request_hash == self._current_request_hash
        if not is_same_request:
            return {"status": "idle"} if self.is_enabled else {"status": "disabled"}
        return self._build_status()

    def is_ai_recommend_running(self) -> bool:
        """
        Whether an AI recommend task is currently running.
        """
        return self._ai_recommend_running

    def cancel_ai_recommend(self):
        """
        Cancel any running AI recommend task and clear all state and the
        persisted indices cache.
        """
        if self._ai_recommend_task and not self._ai_recommend_task.done():
            self._ai_recommend_task.cancel()
        self._ai_recommend_running = False
        self._ai_recommend_task = None
        self._current_request_hash = None
        self._ai_recommend_result = None
        self._ai_recommend_error = None
        self.remove_cache(self.__ai_indices_cache_file)

    def start_recommend_task(
        self,
        filtered_indices: Optional[List[int]],
        search_results_count: int,
        results: List[Any],
    ) -> None:
        """
        Start the AI recommend task via
        ``agent_manager.run_background_prompt``.

        :param filtered_indices: filtered index list
        :param search_results_count: total number of search results
        :param results: search result list
        """
        if not self.is_enabled:
            logger.warning("AI推荐功能未启用跳过任务执行")
            return
        new_request_hash = self._calculate_request_hash(
            filtered_indices, search_results_count
        )
        if new_request_hash != self._current_request_hash:
            # NOTE(review): indentation reconstructed from a stripped render —
            # on a changed request the old task is cancelled and state reset
            # before a new task is started; verify against history.
            self.cancel_ai_recommend()
            self._current_request_hash = new_request_hash
            self._ai_recommend_result = None
            self._ai_recommend_error = None

        async def run_recommend():
            current_task = asyncio.current_task()
            try:
                self._ai_recommend_running = True
                items = []
                valid_indices = []
                # Cap the number of candidates sent to the model.
                max_items = settings.AI_RECOMMEND_MAX_ITEMS or 50
                if filtered_indices is not None and len(filtered_indices) > 0:
                    # Out-of-range filter indices are silently dropped.
                    results_to_process = [
                        results[i]
                        for i in filtered_indices
                        if 0 <= i < len(results)
                    ]
                else:
                    results_to_process = results
                for i, torrent in enumerate(results_to_process):
                    if len(items) >= max_items:
                        break
                    # Skip entries without torrent details.
                    if not torrent.torrent_info:
                        continue
                    valid_indices.append(i)
                    item_info = {
                        "index": i,
                        "title": torrent.torrent_info.title or "未知",
                        "size": (
                            StringUtils.format_size(torrent.torrent_info.size)
                            if torrent.torrent_info.size
                            else "0 B"
                        ),
                        "seeders": torrent.torrent_info.seeders or 0,
                    }
                    items.append(json.dumps(item_info, ensure_ascii=False))
                if not items:
                    self._ai_recommend_error = "没有可用于AI推荐的资源"
                    return
                user_preference = (
                    settings.AI_RECOMMEND_USER_PREFERENCE
                    or "Prefer high-quality resources with more seeders"
                )
                search_results_text = "User Preference: {preference}\n\nCandidate Resources:\n{items}".format(
                    preference=user_preference, items="\n".join(items)
                )
                prompt = prompt_manager.render_system_task_message(
                    "search_recommend",
                    template_context={"search_results": search_results_text},
                )
                # Mutable cell so the output callback can hand the latest
                # model text back to this coroutine scope.
                full_output = [""]

                def on_output(text: str):
                    full_output[0] = text

                await agent_manager.run_background_prompt(
                    message=prompt,
                    session_prefix="__agent_search_recommend",
                    output_callback=on_output,
                    suppress_user_reply=True,
                )
                ai_response = full_output[0]
                if not ai_response:
                    self._ai_recommend_error = "AI推荐未返回结果"
                    return
                try:
                    # Extract the first JSON array from the free-form reply.
                    json_match = re.search(r"\[.*?]", ai_response, re.DOTALL)
                    if not json_match:
                        raise ValueError(f"无法从响应中提取JSON数组: {ai_response}")
                    ai_indices = json.loads(json_match.group())
                    if not isinstance(ai_indices, list):
                        raise ValueError(f"AI返回格式错误: {ai_response}")
                    if filtered_indices:
                        # Map model-local indices back to original result
                        # indices. NOTE(review): negative model indices are
                        # not rejected here and would index from the end.
                        original_indices = [
                            filtered_indices[valid_indices[i]]
                            for i in ai_indices
                            if i < len(valid_indices)
                            and 0
                            <= filtered_indices[valid_indices[i]]
                            < len(results)
                        ]
                    else:
                        original_indices = [
                            valid_indices[i]
                            for i in ai_indices
                            if i < len(valid_indices)
                            and 0 <= valid_indices[i] < len(results)
                        ]
                    self._ai_recommend_result = original_indices
                    self.save_cache(original_indices, self.__ai_indices_cache_file)
                    logger.info(f"AI推荐完成: {len(original_indices)}")
                except Exception as e:
                    logger.error(
                        f"解析AI返回结果失败: {e}, 原始响应: {ai_response}"
                    )
                    self._ai_recommend_error = str(e)
            except asyncio.CancelledError:
                logger.info("AI推荐任务被取消")
            except Exception as e:
                logger.error(f"AI推荐任务失败: {e}")
                self._ai_recommend_error = str(e)
            finally:
                # Only clear the flags if this coroutine is still the task
                # being tracked (a newer task may have replaced it).
                if self._ai_recommend_task == current_task:
                    self._ai_recommend_running = False
                    self._ai_recommend_task = None

        self._ai_recommend_task = asyncio.create_task(run_recommend())

View File

@@ -1,5 +1,8 @@
import asyncio
import hashlib
import json
import random
import re
import time
from concurrent.futures import ThreadPoolExecutor, as_completed
from datetime import datetime
@@ -21,6 +24,7 @@ from app.helper.torrent import TorrentHelper
from app.log import logger
from app.schemas import NotExistMediaInfo
from app.schemas.types import MediaType, ProgressKey, SystemConfigKey, EventType
from app.utils.string import StringUtils
class SearchChain(ChainBase):
@@ -29,7 +33,291 @@ class SearchChain(ChainBase):
"""
# Cache file names for the last search results, the last AI-search
# results, and the AI-recommended indices.
__result_temp_file = "__search_result__"
__ai_result_temp_file = "__ai_search_result__"
__ai_indices_cache_file = "__ai_recommend_indices__"
# Class-level AI recommend state, shared by all SearchChain instances.
_ai_recommend_running = False
_ai_recommend_task: Optional[asyncio.Task] = None
_current_recommend_request_hash: Optional[str] = None
_ai_recommend_result: Optional[List[int]] = None
_ai_recommend_error: Optional[str] = None
@property
def is_ai_recommend_enabled(self) -> bool:
    """
    Whether the AI recommendation feature is enabled.

    Coerced with bool() so the property always matches its annotated
    return type even when the underlying settings hold falsy non-bool
    values (e.g. None); truthiness for callers is unchanged.
    """
    return bool(settings.AI_AGENT_ENABLE and settings.AI_RECOMMEND_ENABLED)
@staticmethod
def _calculate_recommend_request_hash(
    filtered_indices: Optional[List[int]], search_results_count: int
) -> str:
    """
    Build a stable fingerprint of a recommend request.

    The fingerprint changes whenever the filtered index list or the
    total number of search results changes, which is how stale
    recommendation requests are detected.
    """
    fingerprint_source = json.dumps(
        {
            "filtered_indices": filtered_indices or [],
            "search_results_count": search_results_count,
        },
        sort_keys=True,
    )
    return hashlib.md5(fingerprint_source.encode()).hexdigest()
def _build_ai_recommend_status(self) -> Dict[str, Any]:
    """
    Build the AI recommend status payload.

    Returns one of:
      * ``{"status": "disabled"}`` - feature switched off
      * ``{"status": "running"}`` - a recommend task is in flight
      * ``{"status": "completed", "results": [...]}`` - indices available
      * ``{"status": "error", "error": "..."}`` - last run failed
      * ``{"status": "idle"}`` - nothing started yet
    """
    # All recommend state lives on the class so every SearchChain
    # instance observes the same background task.
    state = type(self)
    if not self.is_ai_recommend_enabled:
        return {"status": "disabled"}
    if state._ai_recommend_running:
        return {"status": "running"}
    if state._ai_recommend_result is None:
        # Fall back to the persisted cache from an earlier run.
        cached_indices = self.load_cache(self.__ai_indices_cache_file)
        if cached_indices is not None:
            state._ai_recommend_result = cached_indices
    if state._ai_recommend_result is not None:
        return {"status": "completed", "results": state._ai_recommend_result}
    if state._ai_recommend_error is not None:
        return {"status": "error", "error": state._ai_recommend_error}
    return {"status": "idle"}
def get_current_recommend_status_only(self) -> Dict[str, Any]:
    """
    Return the current recommend status without request-change detection.

    Used by check-only polling so that merely asking for status can
    never reset or restart the recommend workflow.
    """
    current_status = self._build_ai_recommend_status()
    return current_status
def get_recommend_status(
    self, filtered_indices: Optional[List[int]], search_results_count: int
) -> Dict[str, Any]:
    """
    Return the recommend status for the given request.

    When the request fingerprint differs from that of the running or
    finished task, the stored state does not apply to this request:
    report "idle" (or "disabled" when the feature is off) instead of
    stale data.
    """
    cls = type(self)
    fingerprint = self._calculate_recommend_request_hash(
        filtered_indices, search_results_count
    )
    if fingerprint == cls._current_recommend_request_hash:
        return self._build_ai_recommend_status()
    if self.is_ai_recommend_enabled:
        return {"status": "idle"}
    return {"status": "disabled"}
def cancel_ai_recommend(self):
    """
    Cancel any in-flight AI recommend task and reset all shared state.

    Also removes the persisted indices cache so a later status query
    cannot resurrect results that belong to an outdated request.
    """
    state = type(self)
    if state._ai_recommend_task and not state._ai_recommend_task.done():
        state._ai_recommend_task.cancel()
    # Reset state unconditionally, even when no task was running.
    state._ai_recommend_running = False
    state._ai_recommend_task = None
    state._current_recommend_request_hash = None
    state._ai_recommend_result = None
    state._ai_recommend_error = None
    self.remove_cache(self.__ai_indices_cache_file)
@staticmethod
def _normalize_ai_indices(ai_indices: List[Any]) -> List[int]:
    """
    Drop invalid and duplicate indices from the model output.

    Values that cannot be coerced to int are skipped; first occurrence
    wins, and the original order is preserved.
    """
    seen_values = set()
    cleaned: List[int] = []
    for raw in ai_indices:
        try:
            candidate = int(raw)
        except (TypeError, ValueError):
            continue
        if candidate not in seen_values:
            seen_values.add(candidate)
            cleaned.append(candidate)
    return cleaned
@staticmethod
def _extract_recommend_items(
    filtered_indices: Optional[List[int]], results: List[Any]
) -> tuple[List[str], List[int]]:
    """
    Build the candidate payload sent to the model.

    :param filtered_indices: indices into ``results`` selected by the UI
        filter, or None/empty to consider every result
    :param results: search result contexts, each carrying ``torrent_info``
    :return: ``(items, valid_indices)`` where ``items`` are JSON-encoded
        candidate descriptions and ``valid_indices[j]`` is the position in
        the processed list that produced ``items[j]``
    """
    items: List[str] = []
    valid_indices: List[int] = []
    # Cap the prompt size; fall back to 50 when the setting is unset or 0.
    max_items = settings.AI_RECOMMEND_MAX_ITEMS or 50
    if filtered_indices:
        # Out-of-range filter indices are silently dropped.
        results_to_process = [
            results[index] for index in filtered_indices if 0 <= index < len(results)
        ]
    else:
        results_to_process = results
    for index, torrent in enumerate(results_to_process):
        if len(items) >= max_items:
            break
        # Skip entries without torrent details - nothing to describe.
        if not torrent.torrent_info:
            continue
        valid_indices.append(index)
        item_info = {
            "index": index,
            "title": torrent.torrent_info.title or "未知",
            "size": (
                StringUtils.format_size(torrent.torrent_info.size)
                if torrent.torrent_info.size
                else "0 B"
            ),
            "seeders": torrent.torrent_info.seeders or 0,
        }
        items.append(json.dumps(item_info, ensure_ascii=False))
    return items, valid_indices
@staticmethod
def _restore_original_indices(
    ai_indices: List[int],
    filtered_indices: Optional[List[int]],
    valid_indices: List[int],
    results_count: int,
) -> List[int]:
    """
    Map model-local indices back to indices into the original results.

    Each local index is first resolved through ``valid_indices`` and,
    when a filter was active, through ``filtered_indices``. Out-of-range
    values and duplicates are dropped; order is preserved.
    """
    mapped: List[int] = []
    emitted = set()
    local_limit = len(valid_indices)
    for local_index in ai_indices:
        if local_index < 0 or local_index >= local_limit:
            continue
        candidate = valid_indices[local_index]
        if filtered_indices:
            candidate = filtered_indices[candidate]
        if candidate < 0 or candidate >= results_count:
            continue
        if candidate in emitted:
            continue
        emitted.add(candidate)
        mapped.append(candidate)
    return mapped
async def _invoke_recommend_llm(self, search_results_text: str) -> str:
    """
    Run the recommend prompt through the shared background agent.

    :param search_results_text: rendered candidate list plus user preference
    :return: the final model output, stripped; empty string when the agent
        produced nothing
    """
    # Imported lazily, presumably to avoid a circular import between the
    # chain and agent modules at load time - TODO confirm.
    from app.agent import agent_manager
    from app.agent.prompt import prompt_manager
    prompt = prompt_manager.render_system_task_message(
        "search_recommend",
        template_context={"search_results": search_results_text},
    )
    # Mutable cell so the output callback can hand the latest model text
    # back to this coroutine scope.
    full_output = [""]

    def on_output(text: str):
        full_output[0] = text

    await agent_manager.run_background_prompt(
        message=prompt,
        session_prefix="__agent_search_recommend",
        output_callback=on_output,
        suppress_user_reply=True,
    )
    return full_output[0].strip()
def start_recommend_task(
    self,
    filtered_indices: Optional[List[int]],
    search_results_count: int,
    results: List[Any],
) -> None:
    """
    Start the AI recommend background task for the given request.

    A no-op when the feature is disabled or when a task for the same
    request fingerprint is already tracked; otherwise any previous task
    is cancelled and a fresh one is scheduled on the running event loop.

    :param filtered_indices: filtered index list from the UI
    :param search_results_count: total number of search results
    :param results: search result list
    """
    if not self.is_ai_recommend_enabled:
        logger.warning("AI推荐功能未启用跳过任务执行")
        return
    state = type(self)
    request_hash = self._calculate_recommend_request_hash(
        filtered_indices, search_results_count
    )
    # Same request already started (or finished): keep its state.
    if request_hash == state._current_recommend_request_hash:
        return
    self.cancel_ai_recommend()
    state._current_recommend_request_hash = request_hash

    async def run_recommend():
        current_task = asyncio.current_task()

        def is_current_request() -> bool:
            # Guards against writing results/errors for a request that was
            # superseded while this coroutine was running.
            return state._current_recommend_request_hash == request_hash

        try:
            state._ai_recommend_running = True
            items, valid_indices = self._extract_recommend_items(
                filtered_indices=filtered_indices,
                results=results,
            )
            if not items:
                if is_current_request():
                    state._ai_recommend_error = "没有可用于AI推荐的资源"
                return
            user_preference = (
                settings.AI_RECOMMEND_USER_PREFERENCE
                or "Prefer high-quality resources with more seeders"
            )
            # chr(10) is "\n"; joins candidates one per line in the prompt.
            search_results_text = (
                f"User Preference: {user_preference}\n\n"
                f"Candidate Resources:\n{chr(10).join(items)}"
            )
            ai_response = await self._invoke_recommend_llm(search_results_text)
            if not ai_response:
                if is_current_request():
                    state._ai_recommend_error = "AI推荐未返回结果"
                return
            # Extract the first JSON array from the free-form model reply.
            json_match = re.search(r"\[.*?]", ai_response, re.DOTALL)
            if not json_match:
                raise ValueError(f"无法从响应中提取JSON数组: {ai_response}")
            ai_indices = json.loads(json_match.group())
            if not isinstance(ai_indices, list):
                raise ValueError(f"AI返回格式错误: {ai_response}")
            original_indices = self._restore_original_indices(
                ai_indices=self._normalize_ai_indices(ai_indices),
                filtered_indices=filtered_indices,
                valid_indices=valid_indices,
                results_count=len(results),
            )
            if not is_current_request():
                logger.info("AI推荐结果已过期丢弃旧结果")
                return
            state._ai_recommend_result = original_indices
            self.save_cache(original_indices, self.__ai_indices_cache_file)
            logger.info(f"AI推荐完成: {len(original_indices)}")
        except asyncio.CancelledError:
            logger.info("AI推荐任务被取消")
        except Exception as err:
            logger.error(f"AI推荐任务失败: {err}")
            if is_current_request():
                state._ai_recommend_error = str(err)
        finally:
            # Only clear the flags if this coroutine is still the task being
            # tracked (a newer task may already have replaced it).
            if state._ai_recommend_task == current_task:
                state._ai_recommend_running = False
                state._ai_recommend_task = None

    state._ai_recommend_task = asyncio.create_task(run_recommend())
def search_by_id(self, tmdbid: Optional[int] = None, doubanid: Optional[str] = None,
mtype: MediaType = None, area: Optional[str] = "title", season: Optional[int] = None,
@@ -44,6 +332,8 @@ class SearchChain(ChainBase):
:param sites: 站点ID列表
:param cache_local: 是否缓存到本地
"""
if cache_local:
self.cancel_ai_recommend()
mediainfo = self.recognize_media(tmdbid=tmdbid, doubanid=doubanid, mtype=mtype)
if not mediainfo:
logger.error(f'{tmdbid} 媒体信息识别失败!')
@@ -70,6 +360,8 @@ class SearchChain(ChainBase):
:param sites: 站点ID列表
:param cache_local: 是否缓存到本地
"""
if cache_local:
self.cancel_ai_recommend()
if title:
logger.info(f'开始搜索资源,关键词:{title} ...')
else:
@@ -99,18 +391,6 @@ class SearchChain(ChainBase):
"""
return await self.async_load_cache(self.__result_temp_file)
async def async_last_ai_results(self) -> Optional[List[Context]]:
"""
异步获取上次AI推荐结果
"""
return await self.async_load_cache(self.__ai_result_temp_file)
async def async_save_ai_results(self, results: List[Context]):
"""
异步保存AI推荐结果
"""
await self.async_save_cache(results, self.__ai_result_temp_file)
async def async_search_by_id(self, tmdbid: Optional[int] = None, doubanid: Optional[str] = None,
mtype: MediaType = None, area: Optional[str] = "title", season: Optional[int] = None,
sites: List[int] = None, cache_local: bool = False) -> List[Context]:
@@ -124,6 +404,8 @@ class SearchChain(ChainBase):
:param sites: 站点ID列表
:param cache_local: 是否缓存到本地
"""
if cache_local:
self.cancel_ai_recommend()
mediainfo = await self.async_recognize_media(tmdbid=tmdbid, doubanid=doubanid, mtype=mtype)
if not mediainfo:
logger.error(f'{tmdbid} 媒体信息识别失败!')
@@ -150,6 +432,8 @@ class SearchChain(ChainBase):
:param sites: 站点ID列表
:param cache_local: 是否缓存到本地
"""
if cache_local:
self.cancel_ai_recommend()
if title:
logger.info(f'开始搜索资源,关键词:{title} ...')
else:
@@ -173,6 +457,8 @@ class SearchChain(ChainBase):
"""
根据标题渐进式搜索资源,不识别不过滤,按站点完成顺序返回结果
"""
if cache_local:
self.cancel_ai_recommend()
if title:
logger.info(f'开始渐进式搜索资源,关键词:{title} ...')
else:
@@ -214,6 +500,8 @@ class SearchChain(ChainBase):
"""
根据TMDBID/豆瓣ID渐进式搜索资源先返回站点原始候选再返回过滤匹配后的最终结果
"""
if cache_local:
self.cancel_ai_recommend()
mediainfo = await self.async_recognize_media(tmdbid=tmdbid, doubanid=doubanid, mtype=mtype)
if not mediainfo:
logger.error(f'{tmdbid} 媒体信息识别失败!')

View File

@@ -0,0 +1,126 @@
import asyncio
import sys
import unittest
from types import SimpleNamespace
from types import ModuleType
from unittest.mock import AsyncMock, patch
def _stub_module(name: str, **attrs):
    """
    Ensure that ``sys.modules[name]`` exists and set the given attributes on it.

    Returns the (possibly newly created) module object.
    """
    stub = sys.modules.get(name)
    if stub is None:
        stub = ModuleType(name)
        sys.modules[name] = stub
    for attr_name, attr_value in attrs.items():
        setattr(stub, attr_name, attr_value)
    return stub
# Stub the heavy downloader-client modules before importing app code, so
# that importing SearchChain does not require qbittorrentapi or
# transmission_rpc to be installed in the test environment.
_stub_module("qbittorrentapi", TorrentFilesList=list)
_stub_module("transmission_rpc", File=object)
from app.chain.search import SearchChain
from app.core.config import settings
def _make_result(title: str, size: int, seeders: int):
    """
    Build a minimal stand-in for a search result context, carrying only the
    ``torrent_info`` fields the recommend flow reads.
    """
    torrent_info = SimpleNamespace(title=title, size=size, seeders=seeders)
    return SimpleNamespace(torrent_info=torrent_info)
class SearchChainAIRecommendTest(unittest.IsolatedAsyncioTestCase):
    """
    Unit tests for the AI recommend workflow embedded in SearchChain.
    """

    def setUp(self):
        # Recommend state is class-level; reset it so tests are isolated.
        SearchChain._ai_recommend_running = False
        SearchChain._ai_recommend_task = None
        SearchChain._current_recommend_request_hash = None
        SearchChain._ai_recommend_result = None
        SearchChain._ai_recommend_error = None

    async def asyncTearDown(self):
        # Cancel any recommend task a test left behind, then reset state.
        task = SearchChain._ai_recommend_task
        if task and not task.done():
            task.cancel()
            try:
                await task
            except asyncio.CancelledError:
                pass
        SearchChain._ai_recommend_running = False
        SearchChain._ai_recommend_task = None
        SearchChain._current_recommend_request_hash = None
        SearchChain._ai_recommend_result = None
        SearchChain._ai_recommend_error = None

    @staticmethod
    def _make_chain() -> SearchChain:
        # Bypass __init__ (which may touch external services) and replace
        # the cache helpers with no-ops.
        chain = object.__new__(SearchChain)
        chain.load_cache = lambda _filename: None
        chain.save_cache = lambda _cache, _filename: None
        chain.remove_cache = lambda _filename: None
        return chain

    async def test_start_recommend_task_restores_original_indices(self):
        # Model output [1, 0, 1, "bad", 9] over filter [2, 4, 6] must map
        # back to original indices [4, 2]: duplicates, non-ints and
        # out-of-range entries are dropped.
        chain = self._make_chain()
        saved = []
        chain.save_cache = lambda cache, filename: saved.append((filename, cache))
        results = [_make_result(f"item-{index}", 1024 * (index + 1), index) for index in range(7)]
        with (
            patch.object(settings, "AI_AGENT_ENABLE", True, create=True),
            patch.object(settings, "AI_RECOMMEND_ENABLED", True, create=True),
            patch.object(settings, "AI_RECOMMEND_MAX_ITEMS", 50, create=True),
            patch.object(
                settings,
                "AI_RECOMMEND_USER_PREFERENCE",
                "Prefer high seeders",
                create=True,
            ),
            patch.object(
                SearchChain,
                "_invoke_recommend_llm",
                new=AsyncMock(return_value='[1, 0, 1, "bad", 9]'),
            ),
        ):
            chain.start_recommend_task(
                filtered_indices=[2, 4, 6],
                search_results_count=len(results),
                results=results,
            )
            self.assertIsNotNone(SearchChain._ai_recommend_task)
            await SearchChain._ai_recommend_task
        self.assertEqual([4, 2], SearchChain._ai_recommend_result)
        self.assertEqual(
            [("__ai_recommend_indices__", [4, 2])],
            saved,
        )
        # The finally-block must have cleared the running flag and task ref.
        self.assertFalse(SearchChain._ai_recommend_running)
        self.assertIsNone(SearchChain._ai_recommend_task)

    def test_search_by_title_clears_previous_recommend_state_when_caching(self):
        # search_by_title(cache_local=True) must wipe stale recommend state
        # and the indices cache before caching the new search results.
        chain = self._make_chain()
        removed = []
        cached = []
        chain.remove_cache = lambda filename: removed.append(filename)
        chain.save_cache = lambda cache, filename: cached.append((filename, cache))
        # Stub the (name-mangled) site-search helper to avoid real searches.
        chain._SearchChain__search_all_sites = lambda keyword, sites, page: [
            SimpleNamespace(title="Test Title", description="Test Desc")
        ]
        SearchChain._current_recommend_request_hash = "stale-hash"
        SearchChain._ai_recommend_result = [3, 1]
        SearchChain._ai_recommend_error = "stale-error"
        results = chain.search_by_title("keyword", cache_local=True)
        self.assertEqual(1, len(results))
        self.assertEqual(["__ai_recommend_indices__"], removed)
        self.assertEqual("__search_result__", cached[0][0])
        self.assertIsNone(SearchChain._current_recommend_request_hash)
        self.assertIsNone(SearchChain._ai_recommend_result)
        self.assertIsNone(SearchChain._ai_recommend_error)
# Allow running this test module directly with `python <file>`.
if __name__ == "__main__":
    unittest.main()