"""获取搜索结果工具""" import json import re from typing import List, Optional, Type from pydantic import BaseModel, Field from app.agent.tools.base import MoviePilotTool from app.chain.search import SearchChain from app.log import logger from ._torrent_search_utils import ( TORRENT_RESULT_LIMIT, build_filter_options, filter_contexts, simplify_search_result, ) class GetSearchResultsInput(BaseModel): """获取搜索结果工具的输入参数模型""" explanation: str = Field( ..., description="Clear explanation of why this tool is being used in the current context", ) site: Optional[List[str]] = Field(None, description="Site name filters") season: Optional[List[str]] = Field(None, description="Season or episode filters") free_state: Optional[List[str]] = Field(None, description="Promotion state filters") video_code: Optional[List[str]] = Field(None, description="Video codec filters") edition: Optional[List[str]] = Field(None, description="Edition filters") resolution: Optional[List[str]] = Field(None, description="Resolution filters") release_group: Optional[List[str]] = Field( None, description="Release group filters" ) title_pattern: Optional[str] = Field( None, description="Regular expression pattern to filter torrent titles (e.g., '4K|2160p|UHD', '1080p.*BluRay')", ) show_filter_options: Optional[bool] = Field( False, description="Whether to return only optional filter options for re-checking available conditions", ) page: Optional[int] = Field( 1, description="Page number for pagination (default: 1, each page returns up to 50 results)", ) class GetSearchResultsTool(MoviePilotTool): name: str = "get_search_results" description: str = "Get cached torrent search results from search_torrents with optional filters. Supports pagination with up to 50 results per page." args_schema: Type[BaseModel] = GetSearchResultsInput def get_tool_message(self, **kwargs) -> Optional[str]: return "获取搜索结果" async def run( self, site: Optional[List[str]] = None, season: Optional[List[str]] = None, free_state: Optional[List[str]] = None, video_code: Optional[List[str]] = None, edition: Optional[List[str]] = None, resolution: Optional[List[str]] = None, release_group: Optional[List[str]] = None, title_pattern: Optional[str] = None, show_filter_options: bool = False, page: Optional[int] = 1, **kwargs, ) -> str: page = max(1, page or 1) logger.info( f"执行工具: {self.name}, 参数: site={site}, season={season}, free_state={free_state}, video_code={video_code}, edition={edition}, resolution={resolution}, release_group={release_group}, title_pattern={title_pattern}, show_filter_options={show_filter_options}, page={page}" ) try: items = await SearchChain().async_last_search_results() or [] if not items: return "没有可用的搜索结果,请先使用 search_torrents 搜索" if show_filter_options: payload = { "total_count": len(items), "filter_options": build_filter_options(items), } return json.dumps(payload, ensure_ascii=False, indent=2) regex_pattern = None if title_pattern: try: regex_pattern = re.compile(title_pattern, re.IGNORECASE) except re.error as e: logger.warning(f"正则表达式编译失败: {title_pattern}, 错误: {e}") return f"正则表达式格式错误: {str(e)}" filtered_items = filter_contexts( items=items, site=site, season=season, free_state=free_state, video_code=video_code, edition=edition, resolution=resolution, release_group=release_group, ) if regex_pattern: filtered_items = [ item for item in filtered_items if item.torrent_info and item.torrent_info.title and regex_pattern.search(item.torrent_info.title) ] if not filtered_items: return "没有符合筛选条件的搜索结果,请调整筛选条件" total_count = len(filtered_items) filtered_ids = {id(item) for item in filtered_items} matched_indices = [ index for index, item in enumerate(items, start=1) if id(item) in filtered_ids ] # 分页 page_size = TORRENT_RESULT_LIMIT start = (page - 1) * page_size end = start + page_size page_items = filtered_items[start:end] page_indices = matched_indices[start:end] if not page_items: return f"第 {page} 页没有数据,共 {total_count} 条结果,共 {(total_count + page_size - 1) // page_size} 页。" results = [ simplify_search_result(item, index) for item, index in zip(page_items, page_indices) ] total_pages = (total_count + page_size - 1) // page_size payload = { "total_count": total_count, "page": page, "total_pages": total_pages, "results": results, } if page < total_pages: payload["message"] = ( f"搜索结果共 {total_count} 条,当前第 {page}/{total_pages} 页,可使用 page={page + 1} 获取下一页。" ) return json.dumps(payload, ensure_ascii=False, indent=2) except Exception as e: error_message = f"获取搜索结果失败: {str(e)}" logger.error(f"获取搜索结果失败: {e}", exc_info=True) return error_message