mirror of
https://github.com/jxxghp/MoviePilot.git
synced 2026-05-06 20:42:43 +08:00
refactor search AI recommendation flow
This commit is contained in:
@@ -5,7 +5,6 @@ from fastapi import APIRouter, Depends, Body, Request
|
||||
from fastapi.responses import StreamingResponse
|
||||
|
||||
from app import schemas
|
||||
from app.chain.agent import AIRecommendChain
|
||||
from app.chain.media import MediaChain
|
||||
from app.chain.search import SearchChain
|
||||
from app.core.config import settings
|
||||
@@ -204,8 +203,6 @@ async def search_by_id(mediaid: str,
|
||||
"""
|
||||
根据TMDBID/豆瓣ID精确搜索站点资源 tmdb:/douban:/bangumi:
|
||||
"""
|
||||
# 取消正在运行的AI推荐(会清除数据库缓存)
|
||||
|
||||
if mtype:
|
||||
media_type = MediaType(mtype)
|
||||
else:
|
||||
@@ -348,8 +345,6 @@ async def search_by_title(keyword: Optional[str] = None,
|
||||
"""
|
||||
根据名称模糊搜索站点资源,支持分页,关键词为空是返回首页资源
|
||||
"""
|
||||
# 取消正在运行的AI推荐并清除数据库缓存
|
||||
|
||||
torrents = await SearchChain().async_search_by_title(
|
||||
title=keyword, page=page,
|
||||
sites=_parse_site_list(sites),
|
||||
@@ -393,12 +388,12 @@ async def recommend_search_results(
|
||||
"status": "error"
|
||||
})
|
||||
|
||||
recommend_chain = AIRecommendChain()
|
||||
recommend_chain = SearchChain()
|
||||
|
||||
# 如果是强制模式,先取消并清除旧结果,然后直接启动新任务
|
||||
if force:
|
||||
# 检查功能是否启用
|
||||
if not settings.AI_AGENT_ENABLE or not settings.AI_RECOMMEND_ENABLED:
|
||||
if not recommend_chain.is_ai_recommend_enabled:
|
||||
return schemas.Response(success=True, data={
|
||||
"status": "disabled"
|
||||
})
|
||||
@@ -413,7 +408,7 @@ async def recommend_search_results(
|
||||
# 如果是仅检查模式,不传递 filtered_indices(避免触发请求变化检测)
|
||||
if check_only:
|
||||
# 返回当前运行状态,不做任何任务启动或取消操作
|
||||
current_status = recommend_chain.get_current_status_only()
|
||||
current_status = recommend_chain.get_current_recommend_status_only()
|
||||
# 如果有错误,将错误信息放到message中
|
||||
if current_status.get("status") == "error":
|
||||
error_msg = current_status.pop("error", "未知错误")
|
||||
@@ -421,7 +416,7 @@ async def recommend_search_results(
|
||||
return schemas.Response(success=True, data=current_status)
|
||||
|
||||
# 获取当前状态(会检测请求是否变化)
|
||||
status_data = recommend_chain.get_status(filtered_indices, len(results))
|
||||
status_data = recommend_chain.get_recommend_status(filtered_indices, len(results))
|
||||
|
||||
# 如果功能未启用,直接返回禁用状态
|
||||
if status_data.get("status") == "disabled":
|
||||
|
||||
@@ -1,265 +0,0 @@
|
||||
import asyncio
|
||||
import hashlib
|
||||
import json
|
||||
import re
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
from app.agent import agent_manager, prompt_manager
|
||||
from app.chain import ChainBase
|
||||
from app.core.config import settings
|
||||
from app.log import logger
|
||||
from app.utils.singleton import Singleton
|
||||
from app.utils.string import StringUtils
|
||||
|
||||
|
||||
class AIRecommendChain(ChainBase, metaclass=Singleton):
|
||||
"""
|
||||
AI推荐处理链,单例运行
|
||||
用于基于搜索结果的AI智能推荐
|
||||
使用 agent_manager.run_background_prompt 统一后台任务机制
|
||||
"""
|
||||
|
||||
__ai_indices_cache_file = "__ai_recommend_indices__"
|
||||
|
||||
_ai_recommend_running = False
|
||||
_ai_recommend_task: Optional[asyncio.Task] = None
|
||||
_current_request_hash: Optional[str] = None
|
||||
_ai_recommend_result: Optional[List[int]] = None
|
||||
_ai_recommend_error: Optional[str] = None
|
||||
|
||||
@staticmethod
|
||||
def _calculate_request_hash(
|
||||
filtered_indices: Optional[List[int]], search_results_count: int
|
||||
) -> str:
|
||||
"""
|
||||
计算请求的哈希值,用于判断请求是否变化
|
||||
"""
|
||||
request_data = {
|
||||
"filtered_indices": filtered_indices or [],
|
||||
"search_results_count": search_results_count,
|
||||
}
|
||||
return hashlib.md5(
|
||||
json.dumps(request_data, sort_keys=True).encode()
|
||||
).hexdigest()
|
||||
|
||||
@property
|
||||
def is_enabled(self) -> bool:
|
||||
"""
|
||||
检查AI推荐功能是否已启用。
|
||||
"""
|
||||
return settings.AI_AGENT_ENABLE and settings.AI_RECOMMEND_ENABLED
|
||||
|
||||
def _build_status(self) -> Dict[str, Any]:
|
||||
"""
|
||||
构建AI推荐状态字典
|
||||
"""
|
||||
if not self.is_enabled:
|
||||
return {"status": "disabled"}
|
||||
|
||||
if self._ai_recommend_running:
|
||||
return {"status": "running"}
|
||||
|
||||
if self._ai_recommend_result is None:
|
||||
cached_indices = self.load_cache(self.__ai_indices_cache_file)
|
||||
if cached_indices is not None:
|
||||
self._ai_recommend_result = cached_indices
|
||||
|
||||
if self._ai_recommend_result is not None:
|
||||
return {"status": "completed", "results": self._ai_recommend_result}
|
||||
|
||||
if self._ai_recommend_error is not None:
|
||||
return {"status": "error", "error": self._ai_recommend_error}
|
||||
|
||||
return {"status": "idle"}
|
||||
|
||||
def get_current_status_only(self) -> Dict[str, Any]:
|
||||
"""
|
||||
获取当前状态(不校验hash,用于check_only模式)
|
||||
"""
|
||||
return self._build_status()
|
||||
|
||||
def get_status(
|
||||
self, filtered_indices: Optional[List[int]], search_results_count: int
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
获取AI推荐状态并检查请求是否变化(用于首次请求或force模式)
|
||||
如果请求变化(筛选条件变化),返回idle状态
|
||||
"""
|
||||
request_hash = self._calculate_request_hash(
|
||||
filtered_indices, search_results_count
|
||||
)
|
||||
is_same_request = request_hash == self._current_request_hash
|
||||
|
||||
if not is_same_request:
|
||||
return {"status": "idle"} if self.is_enabled else {"status": "disabled"}
|
||||
|
||||
return self._build_status()
|
||||
|
||||
def is_ai_recommend_running(self) -> bool:
|
||||
"""
|
||||
检查AI推荐是否正在运行
|
||||
"""
|
||||
return self._ai_recommend_running
|
||||
|
||||
def cancel_ai_recommend(self):
|
||||
"""
|
||||
取消正在运行的AI推荐任务
|
||||
"""
|
||||
if self._ai_recommend_task and not self._ai_recommend_task.done():
|
||||
self._ai_recommend_task.cancel()
|
||||
self._ai_recommend_running = False
|
||||
self._ai_recommend_task = None
|
||||
self._current_request_hash = None
|
||||
self._ai_recommend_result = None
|
||||
self._ai_recommend_error = None
|
||||
self.remove_cache(self.__ai_indices_cache_file)
|
||||
|
||||
def start_recommend_task(
|
||||
self,
|
||||
filtered_indices: Optional[List[int]],
|
||||
search_results_count: int,
|
||||
results: List[Any],
|
||||
) -> None:
|
||||
"""
|
||||
启动AI推荐任务
|
||||
使用 agent_manager.run_background_prompt 后台Agent机制执行推荐
|
||||
:param filtered_indices: 筛选后的索引列表
|
||||
:param search_results_count: 搜索结果总数
|
||||
:param results: 搜索结果列表
|
||||
"""
|
||||
if not self.is_enabled:
|
||||
logger.warning("AI推荐功能未启用,跳过任务执行")
|
||||
return
|
||||
|
||||
new_request_hash = self._calculate_request_hash(
|
||||
filtered_indices, search_results_count
|
||||
)
|
||||
|
||||
if new_request_hash != self._current_request_hash:
|
||||
self.cancel_ai_recommend()
|
||||
self._current_request_hash = new_request_hash
|
||||
self._ai_recommend_result = None
|
||||
self._ai_recommend_error = None
|
||||
|
||||
async def run_recommend():
|
||||
current_task = asyncio.current_task()
|
||||
try:
|
||||
self._ai_recommend_running = True
|
||||
|
||||
items = []
|
||||
valid_indices = []
|
||||
max_items = settings.AI_RECOMMEND_MAX_ITEMS or 50
|
||||
|
||||
if filtered_indices is not None and len(filtered_indices) > 0:
|
||||
results_to_process = [
|
||||
results[i]
|
||||
for i in filtered_indices
|
||||
if 0 <= i < len(results)
|
||||
]
|
||||
else:
|
||||
results_to_process = results
|
||||
|
||||
for i, torrent in enumerate(results_to_process):
|
||||
if len(items) >= max_items:
|
||||
break
|
||||
|
||||
if not torrent.torrent_info:
|
||||
continue
|
||||
|
||||
valid_indices.append(i)
|
||||
|
||||
item_info = {
|
||||
"index": i,
|
||||
"title": torrent.torrent_info.title or "未知",
|
||||
"size": (
|
||||
StringUtils.format_size(torrent.torrent_info.size)
|
||||
if torrent.torrent_info.size
|
||||
else "0 B"
|
||||
),
|
||||
"seeders": torrent.torrent_info.seeders or 0,
|
||||
}
|
||||
|
||||
items.append(json.dumps(item_info, ensure_ascii=False))
|
||||
|
||||
if not items:
|
||||
self._ai_recommend_error = "没有可用于AI推荐的资源"
|
||||
return
|
||||
|
||||
user_preference = (
|
||||
settings.AI_RECOMMEND_USER_PREFERENCE
|
||||
or "Prefer high-quality resources with more seeders"
|
||||
)
|
||||
|
||||
search_results_text = "User Preference: {preference}\n\nCandidate Resources:\n{items}".format(
|
||||
preference=user_preference, items="\n".join(items)
|
||||
)
|
||||
|
||||
prompt = prompt_manager.render_system_task_message(
|
||||
"search_recommend",
|
||||
template_context={"search_results": search_results_text},
|
||||
)
|
||||
|
||||
full_output = [""]
|
||||
|
||||
def on_output(text: str):
|
||||
full_output[0] = text
|
||||
|
||||
await agent_manager.run_background_prompt(
|
||||
message=prompt,
|
||||
session_prefix="__agent_search_recommend",
|
||||
output_callback=on_output,
|
||||
suppress_user_reply=True,
|
||||
)
|
||||
|
||||
ai_response = full_output[0]
|
||||
if not ai_response:
|
||||
self._ai_recommend_error = "AI推荐未返回结果"
|
||||
return
|
||||
|
||||
try:
|
||||
json_match = re.search(r"\[.*?]", ai_response, re.DOTALL)
|
||||
if not json_match:
|
||||
raise ValueError(f"无法从响应中提取JSON数组: {ai_response}")
|
||||
|
||||
ai_indices = json.loads(json_match.group())
|
||||
if not isinstance(ai_indices, list):
|
||||
raise ValueError(f"AI返回格式错误: {ai_response}")
|
||||
|
||||
if filtered_indices:
|
||||
original_indices = [
|
||||
filtered_indices[valid_indices[i]]
|
||||
for i in ai_indices
|
||||
if i < len(valid_indices)
|
||||
and 0
|
||||
<= filtered_indices[valid_indices[i]]
|
||||
< len(results)
|
||||
]
|
||||
else:
|
||||
original_indices = [
|
||||
valid_indices[i]
|
||||
for i in ai_indices
|
||||
if i < len(valid_indices)
|
||||
and 0 <= valid_indices[i] < len(results)
|
||||
]
|
||||
|
||||
self._ai_recommend_result = original_indices
|
||||
self.save_cache(original_indices, self.__ai_indices_cache_file)
|
||||
logger.info(f"AI推荐完成: {len(original_indices)}项")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f"解析AI返回结果失败: {e}, 原始响应: {ai_response}"
|
||||
)
|
||||
self._ai_recommend_error = str(e)
|
||||
|
||||
except asyncio.CancelledError:
|
||||
logger.info("AI推荐任务被取消")
|
||||
except Exception as e:
|
||||
logger.error(f"AI推荐任务失败: {e}")
|
||||
self._ai_recommend_error = str(e)
|
||||
finally:
|
||||
if self._ai_recommend_task == current_task:
|
||||
self._ai_recommend_running = False
|
||||
self._ai_recommend_task = None
|
||||
|
||||
self._ai_recommend_task = asyncio.create_task(run_recommend())
|
||||
@@ -1,5 +1,8 @@
|
||||
import asyncio
|
||||
import hashlib
|
||||
import json
|
||||
import random
|
||||
import re
|
||||
import time
|
||||
from concurrent.futures import ThreadPoolExecutor, as_completed
|
||||
from datetime import datetime
|
||||
@@ -21,6 +24,7 @@ from app.helper.torrent import TorrentHelper
|
||||
from app.log import logger
|
||||
from app.schemas import NotExistMediaInfo
|
||||
from app.schemas.types import MediaType, ProgressKey, SystemConfigKey, EventType
|
||||
from app.utils.string import StringUtils
|
||||
|
||||
|
||||
class SearchChain(ChainBase):
|
||||
@@ -29,7 +33,291 @@ class SearchChain(ChainBase):
|
||||
"""
|
||||
|
||||
__result_temp_file = "__search_result__"
|
||||
__ai_result_temp_file = "__ai_search_result__"
|
||||
__ai_indices_cache_file = "__ai_recommend_indices__"
|
||||
|
||||
_ai_recommend_running = False
|
||||
_ai_recommend_task: Optional[asyncio.Task] = None
|
||||
_current_recommend_request_hash: Optional[str] = None
|
||||
_ai_recommend_result: Optional[List[int]] = None
|
||||
_ai_recommend_error: Optional[str] = None
|
||||
|
||||
@property
|
||||
def is_ai_recommend_enabled(self) -> bool:
|
||||
"""
|
||||
检查AI推荐功能是否已启用。
|
||||
"""
|
||||
return settings.AI_AGENT_ENABLE and settings.AI_RECOMMEND_ENABLED
|
||||
|
||||
@staticmethod
|
||||
def _calculate_recommend_request_hash(
|
||||
filtered_indices: Optional[List[int]], search_results_count: int
|
||||
) -> str:
|
||||
"""
|
||||
计算当前推荐请求哈希,用于识别筛选条件是否变化。
|
||||
"""
|
||||
request_data = {
|
||||
"filtered_indices": filtered_indices or [],
|
||||
"search_results_count": search_results_count,
|
||||
}
|
||||
return hashlib.md5(
|
||||
json.dumps(request_data, sort_keys=True).encode()
|
||||
).hexdigest()
|
||||
|
||||
def _build_ai_recommend_status(self) -> Dict[str, Any]:
|
||||
"""
|
||||
构建AI推荐状态字典。
|
||||
"""
|
||||
state = type(self)
|
||||
if not self.is_ai_recommend_enabled:
|
||||
return {"status": "disabled"}
|
||||
|
||||
if state._ai_recommend_running:
|
||||
return {"status": "running"}
|
||||
|
||||
if state._ai_recommend_result is None:
|
||||
cached_indices = self.load_cache(self.__ai_indices_cache_file)
|
||||
if cached_indices is not None:
|
||||
state._ai_recommend_result = cached_indices
|
||||
|
||||
if state._ai_recommend_result is not None:
|
||||
return {"status": "completed", "results": state._ai_recommend_result}
|
||||
|
||||
if state._ai_recommend_error is not None:
|
||||
return {"status": "error", "error": state._ai_recommend_error}
|
||||
|
||||
return {"status": "idle"}
|
||||
|
||||
def get_current_recommend_status_only(self) -> Dict[str, Any]:
|
||||
"""
|
||||
获取当前推荐状态,不校验请求是否变化。
|
||||
"""
|
||||
return self._build_ai_recommend_status()
|
||||
|
||||
def get_recommend_status(
|
||||
self, filtered_indices: Optional[List[int]], search_results_count: int
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
获取AI推荐状态,并在筛选条件变化时返回 idle。
|
||||
"""
|
||||
state = type(self)
|
||||
request_hash = self._calculate_recommend_request_hash(
|
||||
filtered_indices, search_results_count
|
||||
)
|
||||
if request_hash != state._current_recommend_request_hash:
|
||||
return {"status": "idle"} if self.is_ai_recommend_enabled else {"status": "disabled"}
|
||||
return self._build_ai_recommend_status()
|
||||
|
||||
def cancel_ai_recommend(self):
|
||||
"""
|
||||
取消当前AI推荐任务并清空缓存状态。
|
||||
"""
|
||||
state = type(self)
|
||||
if state._ai_recommend_task and not state._ai_recommend_task.done():
|
||||
state._ai_recommend_task.cancel()
|
||||
state._ai_recommend_running = False
|
||||
state._ai_recommend_task = None
|
||||
state._current_recommend_request_hash = None
|
||||
state._ai_recommend_result = None
|
||||
state._ai_recommend_error = None
|
||||
self.remove_cache(self.__ai_indices_cache_file)
|
||||
|
||||
@staticmethod
|
||||
def _normalize_ai_indices(ai_indices: List[Any]) -> List[int]:
|
||||
"""
|
||||
过滤模型返回的非法或重复索引,保留原顺序。
|
||||
"""
|
||||
normalized = []
|
||||
seen = set()
|
||||
for index in ai_indices:
|
||||
try:
|
||||
value = int(index)
|
||||
except (TypeError, ValueError):
|
||||
continue
|
||||
if value in seen:
|
||||
continue
|
||||
seen.add(value)
|
||||
normalized.append(value)
|
||||
return normalized
|
||||
|
||||
@staticmethod
|
||||
def _extract_recommend_items(
|
||||
filtered_indices: Optional[List[int]], results: List[Any]
|
||||
) -> tuple[List[str], List[int]]:
|
||||
"""
|
||||
构建发送给模型的候选列表和索引映射。
|
||||
"""
|
||||
items: List[str] = []
|
||||
valid_indices: List[int] = []
|
||||
max_items = settings.AI_RECOMMEND_MAX_ITEMS or 50
|
||||
|
||||
if filtered_indices:
|
||||
results_to_process = [
|
||||
results[index] for index in filtered_indices if 0 <= index < len(results)
|
||||
]
|
||||
else:
|
||||
results_to_process = results
|
||||
|
||||
for index, torrent in enumerate(results_to_process):
|
||||
if len(items) >= max_items:
|
||||
break
|
||||
if not torrent.torrent_info:
|
||||
continue
|
||||
|
||||
valid_indices.append(index)
|
||||
item_info = {
|
||||
"index": index,
|
||||
"title": torrent.torrent_info.title or "未知",
|
||||
"size": (
|
||||
StringUtils.format_size(torrent.torrent_info.size)
|
||||
if torrent.torrent_info.size
|
||||
else "0 B"
|
||||
),
|
||||
"seeders": torrent.torrent_info.seeders or 0,
|
||||
}
|
||||
items.append(json.dumps(item_info, ensure_ascii=False))
|
||||
|
||||
return items, valid_indices
|
||||
|
||||
@staticmethod
|
||||
def _restore_original_indices(
|
||||
ai_indices: List[int],
|
||||
filtered_indices: Optional[List[int]],
|
||||
valid_indices: List[int],
|
||||
results_count: int,
|
||||
) -> List[int]:
|
||||
"""
|
||||
将模型输出的局部索引映射回原始搜索结果索引。
|
||||
"""
|
||||
original_indices = []
|
||||
seen = set()
|
||||
|
||||
for index in ai_indices:
|
||||
if not 0 <= index < len(valid_indices):
|
||||
continue
|
||||
original_index = (
|
||||
filtered_indices[valid_indices[index]]
|
||||
if filtered_indices
|
||||
else valid_indices[index]
|
||||
)
|
||||
if not 0 <= original_index < results_count or original_index in seen:
|
||||
continue
|
||||
seen.add(original_index)
|
||||
original_indices.append(original_index)
|
||||
|
||||
return original_indices
|
||||
|
||||
async def _invoke_recommend_llm(self, search_results_text: str) -> str:
|
||||
"""
|
||||
通过统一后台提示词机制执行资源推荐。
|
||||
"""
|
||||
from app.agent import agent_manager
|
||||
from app.agent.prompt import prompt_manager
|
||||
|
||||
prompt = prompt_manager.render_system_task_message(
|
||||
"search_recommend",
|
||||
template_context={"search_results": search_results_text},
|
||||
)
|
||||
full_output = [""]
|
||||
|
||||
def on_output(text: str):
|
||||
full_output[0] = text
|
||||
|
||||
await agent_manager.run_background_prompt(
|
||||
message=prompt,
|
||||
session_prefix="__agent_search_recommend",
|
||||
output_callback=on_output,
|
||||
suppress_user_reply=True,
|
||||
)
|
||||
return full_output[0].strip()
|
||||
|
||||
def start_recommend_task(
|
||||
self,
|
||||
filtered_indices: Optional[List[int]],
|
||||
search_results_count: int,
|
||||
results: List[Any],
|
||||
) -> None:
|
||||
"""
|
||||
启动AI推荐任务。
|
||||
"""
|
||||
if not self.is_ai_recommend_enabled:
|
||||
logger.warning("AI推荐功能未启用,跳过任务执行")
|
||||
return
|
||||
|
||||
state = type(self)
|
||||
request_hash = self._calculate_recommend_request_hash(
|
||||
filtered_indices, search_results_count
|
||||
)
|
||||
if request_hash == state._current_recommend_request_hash:
|
||||
return
|
||||
|
||||
self.cancel_ai_recommend()
|
||||
state._current_recommend_request_hash = request_hash
|
||||
|
||||
async def run_recommend():
|
||||
current_task = asyncio.current_task()
|
||||
|
||||
def is_current_request() -> bool:
|
||||
return state._current_recommend_request_hash == request_hash
|
||||
|
||||
try:
|
||||
state._ai_recommend_running = True
|
||||
|
||||
items, valid_indices = self._extract_recommend_items(
|
||||
filtered_indices=filtered_indices,
|
||||
results=results,
|
||||
)
|
||||
if not items:
|
||||
if is_current_request():
|
||||
state._ai_recommend_error = "没有可用于AI推荐的资源"
|
||||
return
|
||||
|
||||
user_preference = (
|
||||
settings.AI_RECOMMEND_USER_PREFERENCE
|
||||
or "Prefer high-quality resources with more seeders"
|
||||
)
|
||||
search_results_text = (
|
||||
f"User Preference: {user_preference}\n\n"
|
||||
f"Candidate Resources:\n{chr(10).join(items)}"
|
||||
)
|
||||
ai_response = await self._invoke_recommend_llm(search_results_text)
|
||||
if not ai_response:
|
||||
if is_current_request():
|
||||
state._ai_recommend_error = "AI推荐未返回结果"
|
||||
return
|
||||
|
||||
json_match = re.search(r"\[.*?]", ai_response, re.DOTALL)
|
||||
if not json_match:
|
||||
raise ValueError(f"无法从响应中提取JSON数组: {ai_response}")
|
||||
|
||||
ai_indices = json.loads(json_match.group())
|
||||
if not isinstance(ai_indices, list):
|
||||
raise ValueError(f"AI返回格式错误: {ai_response}")
|
||||
|
||||
original_indices = self._restore_original_indices(
|
||||
ai_indices=self._normalize_ai_indices(ai_indices),
|
||||
filtered_indices=filtered_indices,
|
||||
valid_indices=valid_indices,
|
||||
results_count=len(results),
|
||||
)
|
||||
if not is_current_request():
|
||||
logger.info("AI推荐结果已过期,丢弃旧结果")
|
||||
return
|
||||
|
||||
state._ai_recommend_result = original_indices
|
||||
self.save_cache(original_indices, self.__ai_indices_cache_file)
|
||||
logger.info(f"AI推荐完成: {len(original_indices)}项")
|
||||
except asyncio.CancelledError:
|
||||
logger.info("AI推荐任务被取消")
|
||||
except Exception as err:
|
||||
logger.error(f"AI推荐任务失败: {err}")
|
||||
if is_current_request():
|
||||
state._ai_recommend_error = str(err)
|
||||
finally:
|
||||
if state._ai_recommend_task == current_task:
|
||||
state._ai_recommend_running = False
|
||||
state._ai_recommend_task = None
|
||||
|
||||
state._ai_recommend_task = asyncio.create_task(run_recommend())
|
||||
|
||||
def search_by_id(self, tmdbid: Optional[int] = None, doubanid: Optional[str] = None,
|
||||
mtype: MediaType = None, area: Optional[str] = "title", season: Optional[int] = None,
|
||||
@@ -44,6 +332,8 @@ class SearchChain(ChainBase):
|
||||
:param sites: 站点ID列表
|
||||
:param cache_local: 是否缓存到本地
|
||||
"""
|
||||
if cache_local:
|
||||
self.cancel_ai_recommend()
|
||||
mediainfo = self.recognize_media(tmdbid=tmdbid, doubanid=doubanid, mtype=mtype)
|
||||
if not mediainfo:
|
||||
logger.error(f'{tmdbid} 媒体信息识别失败!')
|
||||
@@ -70,6 +360,8 @@ class SearchChain(ChainBase):
|
||||
:param sites: 站点ID列表
|
||||
:param cache_local: 是否缓存到本地
|
||||
"""
|
||||
if cache_local:
|
||||
self.cancel_ai_recommend()
|
||||
if title:
|
||||
logger.info(f'开始搜索资源,关键词:{title} ...')
|
||||
else:
|
||||
@@ -99,18 +391,6 @@ class SearchChain(ChainBase):
|
||||
"""
|
||||
return await self.async_load_cache(self.__result_temp_file)
|
||||
|
||||
async def async_last_ai_results(self) -> Optional[List[Context]]:
|
||||
"""
|
||||
异步获取上次AI推荐结果
|
||||
"""
|
||||
return await self.async_load_cache(self.__ai_result_temp_file)
|
||||
|
||||
async def async_save_ai_results(self, results: List[Context]):
|
||||
"""
|
||||
异步保存AI推荐结果
|
||||
"""
|
||||
await self.async_save_cache(results, self.__ai_result_temp_file)
|
||||
|
||||
async def async_search_by_id(self, tmdbid: Optional[int] = None, doubanid: Optional[str] = None,
|
||||
mtype: MediaType = None, area: Optional[str] = "title", season: Optional[int] = None,
|
||||
sites: List[int] = None, cache_local: bool = False) -> List[Context]:
|
||||
@@ -124,6 +404,8 @@ class SearchChain(ChainBase):
|
||||
:param sites: 站点ID列表
|
||||
:param cache_local: 是否缓存到本地
|
||||
"""
|
||||
if cache_local:
|
||||
self.cancel_ai_recommend()
|
||||
mediainfo = await self.async_recognize_media(tmdbid=tmdbid, doubanid=doubanid, mtype=mtype)
|
||||
if not mediainfo:
|
||||
logger.error(f'{tmdbid} 媒体信息识别失败!')
|
||||
@@ -150,6 +432,8 @@ class SearchChain(ChainBase):
|
||||
:param sites: 站点ID列表
|
||||
:param cache_local: 是否缓存到本地
|
||||
"""
|
||||
if cache_local:
|
||||
self.cancel_ai_recommend()
|
||||
if title:
|
||||
logger.info(f'开始搜索资源,关键词:{title} ...')
|
||||
else:
|
||||
@@ -173,6 +457,8 @@ class SearchChain(ChainBase):
|
||||
"""
|
||||
根据标题渐进式搜索资源,不识别不过滤,按站点完成顺序返回结果
|
||||
"""
|
||||
if cache_local:
|
||||
self.cancel_ai_recommend()
|
||||
if title:
|
||||
logger.info(f'开始渐进式搜索资源,关键词:{title} ...')
|
||||
else:
|
||||
@@ -214,6 +500,8 @@ class SearchChain(ChainBase):
|
||||
"""
|
||||
根据TMDBID/豆瓣ID渐进式搜索资源,先返回站点原始候选,再返回过滤匹配后的最终结果
|
||||
"""
|
||||
if cache_local:
|
||||
self.cancel_ai_recommend()
|
||||
mediainfo = await self.async_recognize_media(tmdbid=tmdbid, doubanid=doubanid, mtype=mtype)
|
||||
if not mediainfo:
|
||||
logger.error(f'{tmdbid} 媒体信息识别失败!')
|
||||
|
||||
126
tests/test_search_ai_recommend.py
Normal file
126
tests/test_search_ai_recommend.py
Normal file
@@ -0,0 +1,126 @@
|
||||
import asyncio
|
||||
import sys
|
||||
import unittest
|
||||
from types import SimpleNamespace
|
||||
from types import ModuleType
|
||||
from unittest.mock import AsyncMock, patch
|
||||
|
||||
|
||||
def _stub_module(name: str, **attrs):
|
||||
module = sys.modules.get(name)
|
||||
if module is None:
|
||||
module = ModuleType(name)
|
||||
sys.modules[name] = module
|
||||
for key, value in attrs.items():
|
||||
setattr(module, key, value)
|
||||
return module
|
||||
|
||||
|
||||
_stub_module("qbittorrentapi", TorrentFilesList=list)
|
||||
_stub_module("transmission_rpc", File=object)
|
||||
|
||||
from app.chain.search import SearchChain
|
||||
from app.core.config import settings
|
||||
|
||||
|
||||
def _make_result(title: str, size: int, seeders: int):
|
||||
return SimpleNamespace(
|
||||
torrent_info=SimpleNamespace(title=title, size=size, seeders=seeders)
|
||||
)
|
||||
|
||||
|
||||
class SearchChainAIRecommendTest(unittest.IsolatedAsyncioTestCase):
|
||||
def setUp(self):
|
||||
SearchChain._ai_recommend_running = False
|
||||
SearchChain._ai_recommend_task = None
|
||||
SearchChain._current_recommend_request_hash = None
|
||||
SearchChain._ai_recommend_result = None
|
||||
SearchChain._ai_recommend_error = None
|
||||
|
||||
async def asyncTearDown(self):
|
||||
task = SearchChain._ai_recommend_task
|
||||
if task and not task.done():
|
||||
task.cancel()
|
||||
try:
|
||||
await task
|
||||
except asyncio.CancelledError:
|
||||
pass
|
||||
|
||||
SearchChain._ai_recommend_running = False
|
||||
SearchChain._ai_recommend_task = None
|
||||
SearchChain._current_recommend_request_hash = None
|
||||
SearchChain._ai_recommend_result = None
|
||||
SearchChain._ai_recommend_error = None
|
||||
|
||||
@staticmethod
|
||||
def _make_chain() -> SearchChain:
|
||||
chain = object.__new__(SearchChain)
|
||||
chain.load_cache = lambda _filename: None
|
||||
chain.save_cache = lambda _cache, _filename: None
|
||||
chain.remove_cache = lambda _filename: None
|
||||
return chain
|
||||
|
||||
async def test_start_recommend_task_restores_original_indices(self):
|
||||
chain = self._make_chain()
|
||||
saved = []
|
||||
chain.save_cache = lambda cache, filename: saved.append((filename, cache))
|
||||
results = [_make_result(f"item-{index}", 1024 * (index + 1), index) for index in range(7)]
|
||||
|
||||
with (
|
||||
patch.object(settings, "AI_AGENT_ENABLE", True, create=True),
|
||||
patch.object(settings, "AI_RECOMMEND_ENABLED", True, create=True),
|
||||
patch.object(settings, "AI_RECOMMEND_MAX_ITEMS", 50, create=True),
|
||||
patch.object(
|
||||
settings,
|
||||
"AI_RECOMMEND_USER_PREFERENCE",
|
||||
"Prefer high seeders",
|
||||
create=True,
|
||||
),
|
||||
patch.object(
|
||||
SearchChain,
|
||||
"_invoke_recommend_llm",
|
||||
new=AsyncMock(return_value='[1, 0, 1, "bad", 9]'),
|
||||
),
|
||||
):
|
||||
chain.start_recommend_task(
|
||||
filtered_indices=[2, 4, 6],
|
||||
search_results_count=len(results),
|
||||
results=results,
|
||||
)
|
||||
self.assertIsNotNone(SearchChain._ai_recommend_task)
|
||||
await SearchChain._ai_recommend_task
|
||||
|
||||
self.assertEqual([4, 2], SearchChain._ai_recommend_result)
|
||||
self.assertEqual(
|
||||
[("__ai_recommend_indices__", [4, 2])],
|
||||
saved,
|
||||
)
|
||||
self.assertFalse(SearchChain._ai_recommend_running)
|
||||
self.assertIsNone(SearchChain._ai_recommend_task)
|
||||
|
||||
def test_search_by_title_clears_previous_recommend_state_when_caching(self):
|
||||
chain = self._make_chain()
|
||||
removed = []
|
||||
cached = []
|
||||
chain.remove_cache = lambda filename: removed.append(filename)
|
||||
chain.save_cache = lambda cache, filename: cached.append((filename, cache))
|
||||
chain._SearchChain__search_all_sites = lambda keyword, sites, page: [
|
||||
SimpleNamespace(title="Test Title", description="Test Desc")
|
||||
]
|
||||
|
||||
SearchChain._current_recommend_request_hash = "stale-hash"
|
||||
SearchChain._ai_recommend_result = [3, 1]
|
||||
SearchChain._ai_recommend_error = "stale-error"
|
||||
|
||||
results = chain.search_by_title("keyword", cache_local=True)
|
||||
|
||||
self.assertEqual(1, len(results))
|
||||
self.assertEqual(["__ai_recommend_indices__"], removed)
|
||||
self.assertEqual("__search_result__", cached[0][0])
|
||||
self.assertIsNone(SearchChain._current_recommend_request_hash)
|
||||
self.assertIsNone(SearchChain._ai_recommend_result)
|
||||
self.assertIsNone(SearchChain._ai_recommend_error)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
Reference in New Issue
Block a user