From 10543eedd014bfce01105156c314cab4dd51f682 Mon Sep 17 00:00:00 2001 From: jxxghp Date: Fri, 10 Apr 2026 16:50:23 +0800 Subject: [PATCH] =?UTF-8?q?feat(search):=20=E6=94=AF=E6=8C=81=E6=B8=90?= =?UTF-8?q?=E8=BF=9B=E5=BC=8F=EF=BC=88SSE=EF=BC=89=E6=90=9C=E7=B4=A2?= =?UTF-8?q?=E8=B5=84=E6=BA=90=E5=B9=B6=E5=AE=9E=E6=97=B6=E8=BF=94=E5=9B=9E?= =?UTF-8?q?=E6=90=9C=E7=B4=A2=E8=BF=9B=E5=BA=A6=E4=B8=8E=E7=BB=93=E6=9E=9C?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - 新增 /media/{mediaid}/stream 和 /title/stream 接口,支持基于 SSE 的渐进式搜索 - SearchChain 增加 async_search_by_title_stream、async_search_by_id_stream、async_process_stream、__async_search_all_sites_stream 方法 - 搜索结果按站点完成顺序实时推送,支持进度、候选、过滤、完成等阶段事件 - 优化参数解析与异常处理,提升大规模搜索体验 --- app/api/endpoints/search.py | 195 +++++++++++++++++++++++- app/chain/search.py | 290 +++++++++++++++++++++++++++++++++++- 2 files changed, 480 insertions(+), 5 deletions(-) diff --git a/app/api/endpoints/search.py b/app/api/endpoints/search.py index d8a8e705..42239760 100644 --- a/app/api/endpoints/search.py +++ b/app/api/endpoints/search.py @@ -1,6 +1,8 @@ -from typing import List, Any, Optional +import json +from typing import List, Any, Optional, AsyncIterator -from fastapi import APIRouter, Depends, Body +from fastapi import APIRouter, Depends, Body, Request +from fastapi.responses import StreamingResponse from app import schemas from app.chain.media import MediaChain @@ -9,7 +11,7 @@ from app.chain.ai_recommend import AIRecommendChain from app.core.config import settings from app.core.event import eventmanager from app.core.metainfo import MetaInfo -from app.core.security import verify_token +from app.core.security import verify_resource_token, verify_token from app.log import logger from app.schemas import MediaRecognizeConvertEventData from app.schemas.types import MediaType, ChainEventType @@ -17,6 +19,38 @@ from app.schemas.types import MediaType, ChainEventType router = APIRouter() +def _parse_site_list(sites: Optional[str]) -> Optional[List[int]]: + """ + 解析站点ID列表 + """ + return [int(site) for site in sites.split(",") if site] if sites else None + + +def _sse_event(data: dict) -> str: + """ + 转换为SSE事件 + """ + return f"data: {json.dumps(data, ensure_ascii=False)}\n\n" + + +async def _stream_search_events(request: Request, event_source: AsyncIterator[dict]): + """ + 输出搜索SSE事件 + """ + try: + async for event in event_source: + if await request.is_disconnected(): + break + yield _sse_event(event) + except Exception as err: + logger.error(f"渐进式搜索出错:{err}", exc_info=True) + yield _sse_event({ + "type": "error", + "success": False, + "message": str(err) + }) + + @router.get("/last", summary="查询搜索结果", response_model=List[schemas.Context]) async def search_latest(_: schemas.TokenPayload = Depends(verify_token)) -> Any: """ @@ -26,6 +60,139 @@ async def search_latest(_: schemas.TokenPayload = Depends(verify_token)) -> Any: return [torrent.to_dict() for torrent in torrents] +@router.get("/media/{mediaid}/stream", summary="渐进式精确搜索资源") +async def search_by_id_stream(request: Request, + mediaid: str, + mtype: Optional[str] = None, + area: Optional[str] = "title", + title: Optional[str] = None, + year: Optional[str] = None, + season: Optional[str] = None, + sites: Optional[str] = None, + _: schemas.TokenPayload = Depends(verify_resource_token)) -> Any: + """ + 根据TMDBID/豆瓣ID渐进式搜索站点资源,返回格式为SSE + """ + AIRecommendChain().cancel_ai_recommend() + + media_type = MediaType(mtype) if mtype else None + media_season = int(season) if season else None + site_list = _parse_site_list(sites) + media_chain = MediaChain() + search_chain = SearchChain() + + async def event_source(): + nonlocal media_season + torrents = None + if mediaid.startswith("tmdb:"): + tmdbid = int(mediaid.replace("tmdb:", "")) + if settings.RECOGNIZE_SOURCE == "douban": + doubaninfo = await media_chain.async_get_doubaninfo_by_tmdbid(tmdbid=tmdbid, mtype=media_type) + if doubaninfo: + torrents = search_chain.async_search_by_id_stream(doubanid=doubaninfo.get("id"), + mtype=media_type, area=area, + season=media_season, sites=site_list, + cache_local=True) + else: + yield {"type": "error", "success": False, "message": "未识别到豆瓣媒体信息"} + return + else: + torrents = search_chain.async_search_by_id_stream(tmdbid=tmdbid, mtype=media_type, area=area, + season=media_season, sites=site_list, + cache_local=True) + elif mediaid.startswith("douban:"): + doubanid = mediaid.replace("douban:", "") + if settings.RECOGNIZE_SOURCE == "themoviedb": + tmdbinfo = await media_chain.async_get_tmdbinfo_by_doubanid(doubanid=doubanid, mtype=media_type) + if tmdbinfo: + if tmdbinfo.get('season') and not media_season: + media_season = tmdbinfo.get('season') + torrents = search_chain.async_search_by_id_stream(tmdbid=tmdbinfo.get("id"), + mtype=media_type, area=area, + season=media_season, sites=site_list, + cache_local=True) + else: + yield {"type": "error", "success": False, "message": "未识别到TMDB媒体信息"} + return + else: + torrents = search_chain.async_search_by_id_stream(doubanid=doubanid, mtype=media_type, area=area, + season=media_season, sites=site_list, + cache_local=True) + elif mediaid.startswith("bangumi:"): + bangumiid = int(mediaid.replace("bangumi:", "")) + if settings.RECOGNIZE_SOURCE == "themoviedb": + tmdbinfo = await media_chain.async_get_tmdbinfo_by_bangumiid(bangumiid=bangumiid) + if tmdbinfo: + torrents = search_chain.async_search_by_id_stream(tmdbid=tmdbinfo.get("id"), + mtype=media_type, area=area, + season=media_season, sites=site_list, + cache_local=True) + else: + yield {"type": "error", "success": False, "message": "未识别到TMDB媒体信息"} + return + else: + doubaninfo = await media_chain.async_get_doubaninfo_by_bangumiid(bangumiid=bangumiid) + if doubaninfo: + torrents = search_chain.async_search_by_id_stream(doubanid=doubaninfo.get("id"), + mtype=media_type, area=area, + season=media_season, sites=site_list, + cache_local=True) + else: + yield {"type": "error", "success": False, "message": "未识别到豆瓣媒体信息"} + return + else: + event_data = MediaRecognizeConvertEventData( + mediaid=mediaid, + convert_type=settings.RECOGNIZE_SOURCE + ) + event = await eventmanager.async_send_event(ChainEventType.MediaRecognizeConvert, event_data) + if event and event.event_data: + event_data = event.event_data + if event_data.media_dict: + search_id = event_data.media_dict.get("id") + if event_data.convert_type == "themoviedb": + torrents = search_chain.async_search_by_id_stream(tmdbid=search_id, mtype=media_type, + area=area, season=media_season, + sites=site_list, cache_local=True) + elif event_data.convert_type == "douban": + torrents = search_chain.async_search_by_id_stream(doubanid=search_id, mtype=media_type, + area=area, season=media_season, + sites=site_list, cache_local=True) + else: + if not title: + yield {"type": "error", "success": False, "message": "未知的媒体ID"} + return + meta = MetaInfo(title) + if year: + meta.year = year + if media_type: + meta.type = media_type + if media_season: + meta.type = MediaType.TV + meta.begin_season = media_season + mediainfo = await media_chain.async_recognize_media(meta=meta) + if mediainfo: + if settings.RECOGNIZE_SOURCE == "themoviedb": + torrents = search_chain.async_search_by_id_stream(tmdbid=mediainfo.tmdb_id, + mtype=media_type, area=area, + season=media_season, sites=site_list, + cache_local=True) + else: + torrents = search_chain.async_search_by_id_stream(doubanid=mediainfo.douban_id, + mtype=media_type, area=area, + season=media_season, sites=site_list, + cache_local=True) + + if not torrents: + yield {"type": "error", "success": False, "message": "未搜索到任何资源"} + return + + async for event in torrents: + yield event + + return StreamingResponse(_stream_search_events(request, event_source()), media_type="text/event-stream") + + @router.get("/media/{mediaid}", summary="精确搜索资源", response_model=schemas.Response) async def search_by_id(mediaid: str, mtype: Optional[str] = None, @@ -156,6 +323,26 @@ async def search_by_id(mediaid: str, return schemas.Response(success=True, data=[torrent.to_dict() for torrent in torrents]) +@router.get("/title/stream", summary="渐进式模糊搜索资源") +async def search_by_title_stream(request: Request, + keyword: Optional[str] = None, + page: Optional[int] = 0, + sites: Optional[str] = None, + _: schemas.TokenPayload = Depends(verify_resource_token)) -> Any: + """ + 根据名称渐进式模糊搜索站点资源,返回格式为SSE + """ + AIRecommendChain().cancel_ai_recommend() + + event_source = SearchChain().async_search_by_title_stream( + title=keyword, + page=page, + sites=_parse_site_list(sites), + cache_local=True + ) + return StreamingResponse(_stream_search_events(request, event_source), media_type="text/event-stream") + + @router.get("/title", summary="模糊搜索资源", response_model=schemas.Response) async def search_by_title(keyword: Optional[str] = None, page: Optional[int] = 0, @@ -169,7 +356,7 @@ async def search_by_title(keyword: Optional[str] = None, torrents = await SearchChain().async_search_by_title( title=keyword, page=page, - sites=[int(site) for site in sites.split(",") if site] if sites else None, + sites=_parse_site_list(sites), cache_local=True ) if not torrents: diff --git a/app/chain/search.py b/app/chain/search.py index c24d6cfc..fd254302 100644 --- a/app/chain/search.py +++ b/app/chain/search.py @@ -3,7 +3,7 @@ import random import time from concurrent.futures import ThreadPoolExecutor, as_completed from datetime import datetime -from typing import Dict, Tuple +from typing import AsyncIterator, Any, Dict, Tuple from typing import List, Optional from app.helper.sites import SitesHelper # noqa @@ -167,6 +167,85 @@ class SearchChain(ChainBase): await self.async_save_cache(contexts, self.__result_temp_file) return contexts + async def async_search_by_title_stream(self, title: str, page: Optional[int] = 0, + sites: List[int] = None, + cache_local: Optional[bool] = False) -> AsyncIterator[dict]: + """ + 根据标题渐进式搜索资源,不识别不过滤,按站点完成顺序返回结果 + """ + if title: + logger.info(f'开始渐进式搜索资源,关键词:{title} ...') + else: + logger.info(f'开始渐进式浏览资源,站点:{sites} ...') + + contexts: List[Context] = [] + async for event in self.__async_search_all_sites_stream(keyword=title, sites=sites, page=page): + result = event.pop("items", []) or [] + batch_contexts = [ + Context(meta_info=MetaInfo(title=torrent.title, subtitle=torrent.description), + torrent_info=torrent) + for torrent in result + ] + if batch_contexts: + contexts.extend(batch_contexts) + yield { + **event, + "type": "append", + "items": [context.to_dict() for context in batch_contexts], + "total_items": len(contexts) + } + + if cache_local: + await self.async_save_cache(contexts, self.__result_temp_file) + + if not contexts: + logger.warn(f'{title} 未搜索到资源') + yield { + "type": "done", + "text": f"搜索完成,共 {len(contexts)} 个资源", + "items": [context.to_dict() for context in contexts], + "total_items": len(contexts) + } + + async def async_search_by_id_stream(self, tmdbid: Optional[int] = None, doubanid: Optional[str] = None, + mtype: MediaType = None, area: Optional[str] = "title", + season: Optional[int] = None, sites: List[int] = None, + cache_local: bool = False) -> AsyncIterator[dict]: + """ + 根据TMDBID/豆瓣ID渐进式搜索资源,先返回站点原始候选,再返回过滤匹配后的最终结果 + """ + mediainfo = await self.async_recognize_media(tmdbid=tmdbid, doubanid=doubanid, mtype=mtype) + if not mediainfo: + logger.error(f'{tmdbid} 媒体信息识别失败!') + yield { + "type": "error", + "success": False, + "message": "媒体信息识别失败" + } + return + + no_exists = None + if season is not None: + no_exists = { + tmdbid or doubanid: { + season: NotExistMediaInfo(episodes=[]) + } + } + + contexts: List[Context] = [] + async for event in self.async_process_stream(mediainfo=mediainfo, sites=sites, area=area, no_exists=no_exists): + if event.get("type") == "done": + contexts = event.get("contexts") or [] + event = { + key: value + for key, value in event.items() + if key != "contexts" + } + yield event + + if cache_local: + await self.async_save_cache(contexts, self.__result_temp_file) + @staticmethod def __prepare_params(mediainfo: MediaInfo, keyword: Optional[str] = None, @@ -503,6 +582,115 @@ class SearchChain(ChainBase): filter_params=filter_params ) + async def async_process_stream(self, mediainfo: MediaInfo, + keyword: Optional[str] = None, + no_exists: Dict[int, Dict[int, NotExistMediaInfo]] = None, + sites: List[int] = None, + rule_groups: List[str] = None, + area: Optional[str] = "title", + custom_words: List[str] = None, + filter_params: Dict[str, str] = None) -> AsyncIterator[dict]: + """ + 根据媒体信息渐进式搜索种子资源,先返回站点候选,再返回过滤匹配后的最终结果 + """ + + # 豆瓣标题处理 + if not mediainfo.tmdb_id: + meta = MetaInfo(title=mediainfo.title) + mediainfo.title = meta.name + mediainfo.season = meta.begin_season + logger.info(f'开始渐进式搜索资源,关键词:{keyword or mediainfo.title} ...') + + # 补充媒体信息 + if not mediainfo.names: + mediainfo = await self.async_recognize_media(mtype=mediainfo.type, + tmdbid=mediainfo.tmdb_id, + doubanid=mediainfo.douban_id) + if not mediainfo: + logger.error(f'媒体信息识别失败!') + yield { + "type": "error", + "success": False, + "message": "媒体信息识别失败" + } + return + + # 准备搜索参数 + season_episodes, keywords = self.__prepare_params( + mediainfo=mediainfo, + keyword=keyword, + no_exists=no_exists + ) + + torrents: List[TorrentInfo] = [] + candidate_contexts: List[Context] = [] + search_count = 0 + + for search_word in keywords: + if search_count > 0: + logger.info(f"已搜索 {search_count} 次,强制休眠 1-10 秒 ...") + await asyncio.sleep(random.randint(1, 10)) + + async for event in self.__async_search_all_sites_stream( + mediainfo=mediainfo, + keyword=search_word, + sites=sites, + area=area): + result = event.pop("items", []) or [] + torrents.extend(result) + batch_contexts = [ + Context(meta_info=MetaInfo(title=torrent.title, subtitle=torrent.description), + media_info=mediainfo, + torrent_info=torrent) + for torrent in result + ] + candidate_contexts.extend(batch_contexts) + yield { + **event, + "type": "append", + "stage": "searching", + "items": [context.to_dict() for context in batch_contexts], + "total_items": len(candidate_contexts) + } + + search_count += 1 + if torrents: + logger.info(f"共搜索到 {len(torrents)} 个资源,停止搜索") + break + + yield { + "type": "progress", + "stage": "filtering", + "value": 98, + "text": f"正在过滤匹配 {len(torrents)} 个候选资源 ..." + } + + contexts = await run_in_threadpool(self.__parse_result, + torrents=torrents, + mediainfo=mediainfo, + keyword=keyword, + rule_groups=rule_groups, + season_episodes=season_episodes, + custom_words=custom_words, + filter_params=filter_params) + final_items = [context.to_dict() for context in contexts] + yield { + "type": "replace", + "stage": "filtered", + "value": 100, + "text": f"过滤匹配完成,共 {len(contexts)} 个资源", + "items": final_items, + "total_items": len(contexts) + } + yield { + "type": "done", + "stage": "done", + "text": f"搜索完成,共 {len(contexts)} 个资源", + "items": final_items, + "total_items": len(contexts), + "contexts": contexts + } + def __search_all_sites(self, keyword: str, mediainfo: Optional[MediaInfo] = None, sites: List[int] = None, @@ -670,6 +858,106 @@ class SearchChain(ChainBase): # 返回 return results + async def __async_search_all_sites_stream(self, keyword: str, + mediainfo: Optional[MediaInfo] = None, + sites: List[int] = None, + page: Optional[int] = 0, + area: Optional[str] = "title") -> AsyncIterator[Dict[str, Any]]: + """ + 异步搜索多个站点,按站点完成顺序渐进式返回结果 + :param mediainfo: 识别的媒体信息 + :param keyword: 搜索关键词 + :param sites: 指定站点ID列表,如有则只搜索指定站点,否则搜索所有站点 + :param page: 搜索页码 + :param area: 搜索区域 title or imdbid + """ + indexer_sites = [] + + if not sites: + sites = SystemConfigOper().get(SystemConfigKey.IndexerSites) or [] + + for indexer in await SitesHelper().async_get_indexers(): + if not sites or indexer.get("id") in sites: + indexer_sites.append(indexer) + if not indexer_sites: + logger.warn('未开启任何有效站点,无法搜索资源') + yield { + "type": "done", + "stage": "searching", + "value": 100, + "text": "未开启任何有效站点,无法搜索资源", + "items": [], + "finished": 0, + "total": 0 + } + return + + progress = ProgressHelper(ProgressKey.Search) + progress.start() + start_time = datetime.now() + total_num = len(indexer_sites) + finish_count = 0 + progress.update(value=0, + text=f"开始搜索,共 {total_num} 个站点 ...") + yield { + "type": "progress", + "stage": "searching", + "value": 0, + "text": f"开始搜索,共 {total_num} 个站点 ...", + "items": [], + "finished": 0, + "total": total_num + } + + async def search_site(site: dict) -> Tuple[dict, List[TorrentInfo]]: + if area == "imdbid": + result = await self.async_search_torrents(site=site, + keyword=mediainfo.imdb_id if mediainfo else None, + mtype=mediainfo.type if mediainfo else None, + page=page) + else: + result = await self.async_search_torrents(site=site, + keyword=keyword, + mtype=mediainfo.type if mediainfo else None, + page=page) + return site, result or [] + + tasks = [asyncio.create_task(search_site(site)) for site in indexer_sites] + results_count = 0 + try: + for future in asyncio.as_completed(tasks): + if global_vars.is_system_stopped: + break + finish_count += 1 + site, result = await future + results_count += len(result) + logger.info(f"站点搜索进度:{finish_count} / {total_num}") + progress_value = finish_count / total_num * 100 + progress_text = f"正在搜索{keyword or ''},已完成 {finish_count} / {total_num} 个站点 ..." + progress.update(value=progress_value, text=progress_text) + yield { + "type": "append", + "stage": "searching", + "value": progress_value, + "text": progress_text, + "items": result, + "site": site.get("name"), + "site_id": site.get("id"), + "finished": finish_count, + "total": total_num, + "total_items": results_count + } + finally: + for task in tasks: + if not task.done(): + task.cancel() + + end_time = datetime.now() + progress.update(value=100, + text=f"站点搜索完成,有效资源数:{results_count},总耗时 {(end_time - start_time).seconds} 秒") + logger.info(f"站点搜索完成,有效资源数:{results_count},总耗时 {(end_time - start_time).seconds} 秒") + progress.end() + @eventmanager.register(EventType.SiteDeleted) def remove_site(self, event: Event): """