feat(search): cache and expose last search parameters for replay and context retrieval

- Add methods to save and retrieve last search parameters in SearchChain
- Persist search params alongside results for replayable search context
- Add /last/context endpoint to fetch last search results and parameters
- Update tests to cover search param caching logic
- Allow images.tmdb.org in SECURITY_IMAGE_DOMAINS
This commit is contained in:
jxxghp
2026-05-15 22:43:40 +08:00
parent 1f49f9b454
commit 9b23265c3b
4 changed files with 200 additions and 1 deletions

View File

@@ -144,6 +144,23 @@ async def search_latest(_: schemas.TokenPayload = Depends(verify_token)) -> Any:
return [torrent.to_dict() for torrent in torrents]
@router.get("/last/context", summary="查询上次搜索上下文", response_model=schemas.Response)
async def search_latest_context(_: schemas.TokenPayload = Depends(verify_token)) -> Any:
"""
查询上次搜索结果及其对应的搜索参数。
"""
search_chain = SearchChain()
torrents = await search_chain.async_last_search_results() or []
params = await search_chain.async_last_search_params() or {}
return schemas.Response(
success=True,
data={
"params": params,
"results": [torrent.to_dict() for torrent in torrents],
},
)
@router.get("/media/{mediaid}/stream", summary="渐进式精确搜索资源")
async def search_by_id_stream(
request: Request,

View File

@@ -33,6 +33,7 @@ class SearchChain(ChainBase):
"""
__result_temp_file = "__search_result__"
__search_params_temp_file = "__search_params__"
__ai_indices_cache_file = "__ai_recommend_indices__"
_ai_recommend_running = False
@@ -121,6 +122,115 @@ class SearchChain(ChainBase):
state._ai_recommend_error = None
self.remove_cache(self.__ai_indices_cache_file)
@staticmethod
def _build_search_keyword(
tmdbid: Optional[int] = None, doubanid: Optional[str] = None
) -> str:
"""
根据媒体ID生成可重放的搜索关键字。
"""
if tmdbid is not None:
return f"tmdb:{tmdbid}"
if doubanid:
return f"douban:{doubanid}"
return ""
@staticmethod
def _stringify_sites(sites: Optional[List[int]]) -> str:
"""
将站点ID列表转换为前端可直接复用的查询字符串。
"""
return ",".join(str(site) for site in sites) if sites else ""
@staticmethod
def _normalize_search_params(params: Optional[Dict[str, Any]]) -> Optional[Dict[str, str]]:
"""
规范化上次搜索参数,供前端结果页重新搜索使用。
"""
if not isinstance(params, dict):
return None
normalized = {
"keyword": str(params.get("keyword") or ""),
"type": str(params.get("type") or ""),
"area": str(params.get("area") or ""),
"title": str(params.get("title") or ""),
"year": str(params.get("year") or ""),
"season": str(params.get("season") or ""),
"sites": str(params.get("sites") or ""),
}
return normalized if normalized["keyword"] else None
def save_last_search_params(
self,
*,
keyword: Optional[str],
mtype: Optional[MediaType] = None,
area: Optional[str] = "title",
title: Optional[str] = None,
year: Optional[str] = None,
season: Optional[int] = None,
sites: Optional[List[int]] = None,
) -> None:
"""
保存最后一次资源搜索参数。
"""
params = self._normalize_search_params(
{
"keyword": keyword,
"type": mtype.value if isinstance(mtype, MediaType) else mtype,
"area": area,
"title": title,
"year": year,
"season": season,
"sites": self._stringify_sites(sites),
}
)
if params:
self.save_cache(params, self.__search_params_temp_file)
async def async_save_last_search_params(
self,
*,
keyword: Optional[str],
mtype: Optional[MediaType] = None,
area: Optional[str] = "title",
title: Optional[str] = None,
year: Optional[str] = None,
season: Optional[int] = None,
sites: Optional[List[int]] = None,
) -> None:
"""
异步保存最后一次资源搜索参数。
"""
params = self._normalize_search_params(
{
"keyword": keyword,
"type": mtype.value if isinstance(mtype, MediaType) else mtype,
"area": area,
"title": title,
"year": year,
"season": season,
"sites": self._stringify_sites(sites),
}
)
if params:
await self.async_save_cache(params, self.__search_params_temp_file)
def last_search_params(self) -> Optional[Dict[str, str]]:
"""
获取上次搜索使用的参数。
"""
return self._normalize_search_params(self.load_cache(self.__search_params_temp_file))
async def async_last_search_params(self) -> Optional[Dict[str, str]]:
"""
异步获取上次搜索使用的参数。
"""
return self._normalize_search_params(
await self.async_load_cache(self.__search_params_temp_file)
)
@staticmethod
def _normalize_ai_indices(ai_indices: List[Any]) -> List[int]:
"""
@@ -337,6 +447,13 @@ class SearchChain(ChainBase):
"""
if cache_local:
self.cancel_ai_recommend()
self.save_last_search_params(
keyword=self._build_search_keyword(tmdbid=tmdbid, doubanid=doubanid),
mtype=mtype,
area=area,
season=season,
sites=sites,
)
mediainfo = self.recognize_media(tmdbid=tmdbid, doubanid=doubanid, mtype=mtype)
if not mediainfo:
logger.error(f'{tmdbid} 媒体信息识别失败!')
@@ -365,6 +482,11 @@ class SearchChain(ChainBase):
"""
if cache_local:
self.cancel_ai_recommend()
self.save_last_search_params(
keyword=title,
area="title",
sites=sites,
)
if title:
logger.info(f'开始搜索资源,关键词:{title} ...')
else:
@@ -414,6 +536,13 @@ class SearchChain(ChainBase):
"""
if cache_local:
self.cancel_ai_recommend()
await self.async_save_last_search_params(
keyword=self._build_search_keyword(tmdbid=tmdbid, doubanid=doubanid),
mtype=mtype,
area=area,
season=season,
sites=sites,
)
mediainfo = await self.async_recognize_media(tmdbid=tmdbid, doubanid=doubanid, mtype=mtype)
if not mediainfo:
logger.error(f'{tmdbid} 媒体信息识别失败!')
@@ -442,6 +571,11 @@ class SearchChain(ChainBase):
"""
if cache_local:
self.cancel_ai_recommend()
await self.async_save_last_search_params(
keyword=title,
area="title",
sites=sites,
)
if title:
logger.info(f'开始搜索资源,关键词:{title} ...')
else:
@@ -472,6 +606,11 @@ class SearchChain(ChainBase):
"""
if cache_local:
self.cancel_ai_recommend()
await self.async_save_last_search_params(
keyword=title,
area="title",
sites=sites,
)
if title:
logger.info(f'开始渐进式搜索资源,关键词:{title} ...')
else:
@@ -518,6 +657,13 @@ class SearchChain(ChainBase):
"""
if cache_local:
self.cancel_ai_recommend()
await self.async_save_last_search_params(
keyword=self._build_search_keyword(tmdbid=tmdbid, doubanid=doubanid),
mtype=mtype,
area=area,
season=season,
sites=sites,
)
mediainfo = await self.async_recognize_media(tmdbid=tmdbid, doubanid=doubanid, mtype=mtype)
if not mediainfo:
logger.error(f'{tmdbid} 媒体信息识别失败!')

View File

@@ -486,6 +486,7 @@ class ConfigModel(BaseModel):
SECURITY_IMAGE_DOMAINS: list = Field(
default=[
"image.tmdb.org",
"images.tmdb.org",
"static-mdb.v.geilijiasu.com",
"bing.com",
"doubanio.com",