agent工具支持翻页及取消数量限制

This commit is contained in:
jxxghp
2026-04-08 07:41:34 +08:00
parent 6b01901a4a
commit 5acfd683b9
6 changed files with 420 additions and 198 deletions

View File

@@ -13,40 +13,48 @@ from app.schemas.types import MediaType, media_type_to_agent
class GetRecommendationsInput(BaseModel):
"""获取推荐工具的输入参数模型"""
explanation: str = Field(..., description="Clear explanation of why this tool is being used in the current context")
source: Optional[str] = Field("tmdb_trending",
description="Recommendation source: "
"'tmdb_trending' for TMDB trending content, "
"'tmdb_movies' for TMDB popular movies, "
"'tmdb_tvs' for TMDB popular TV shows, "
"'douban_hot' for Douban popular content, "
"'douban_movie_hot' for Douban hot movies, "
"'douban_tv_hot' for Douban hot TV shows, "
"'douban_movie_showing' for Douban movies currently showing, "
"'douban_movies' for Douban latest movies, "
"'douban_tvs' for Douban latest TV shows, "
"'douban_movie_top250' for Douban movie TOP250, "
"'douban_tv_weekly_chinese' for Douban Chinese TV weekly chart, "
"'douban_tv_weekly_global' for Douban global TV weekly chart, "
"'douban_tv_animation' for Douban popular animation, "
"'bangumi_calendar' for Bangumi anime calendar")
media_type: Optional[str] = Field("all",
description="Allowed values: movie, tv, all")
limit: Optional[int] = Field(20,
description="Maximum number of recommendations to return (default: 20, maximum: 100)")
explanation: str = Field(
...,
description="Clear explanation of why this tool is being used in the current context",
)
source: Optional[str] = Field(
"tmdb_trending",
description="Recommendation source: "
"'tmdb_trending' for TMDB trending content, "
"'tmdb_movies' for TMDB popular movies, "
"'tmdb_tvs' for TMDB popular TV shows, "
"'douban_hot' for Douban popular content, "
"'douban_movie_hot' for Douban hot movies, "
"'douban_tv_hot' for Douban hot TV shows, "
"'douban_movie_showing' for Douban movies currently showing, "
"'douban_movies' for Douban latest movies, "
"'douban_tvs' for Douban latest TV shows, "
"'douban_movie_top250' for Douban movie TOP250, "
"'douban_tv_weekly_chinese' for Douban Chinese TV weekly chart, "
"'douban_tv_weekly_global' for Douban global TV weekly chart, "
"'douban_tv_animation' for Douban popular animation, "
"'bangumi_calendar' for Bangumi anime calendar",
)
media_type: Optional[str] = Field(
"all", description="Allowed values: movie, tv, all"
)
page: Optional[int] = Field(
1, description="Page number for pagination (default: 1, 20 items per page)"
)
class GetRecommendationsTool(MoviePilotTool):
name: str = "get_recommendations"
description: str = "Get trending and popular media recommendations from various sources. Returns curated lists of popular movies, TV shows, and anime based on different criteria like trending, ratings, or calendar schedules."
description: str = "Get trending and popular media recommendations from various sources. Returns curated lists of popular movies, TV shows, and anime based on different criteria like trending, ratings, or calendar schedules. Supports pagination with 20 items per page."
args_schema: Type[BaseModel] = GetRecommendationsInput
def get_tool_message(self, **kwargs) -> Optional[str]:
"""根据推荐参数生成友好的提示消息"""
source = kwargs.get("source", "tmdb_trending")
media_type = kwargs.get("media_type", "all")
limit = kwargs.get("limit", 20)
page = kwargs.get("page", 1)
source_map = {
"tmdb_trending": "TMDB流行趋势",
"tmdb_movies": "TMDB热门电影",
@@ -61,20 +69,29 @@ class GetRecommendationsTool(MoviePilotTool):
"douban_tv_weekly_chinese": "豆瓣国产剧集榜",
"douban_tv_weekly_global": "豆瓣全球剧集榜",
"douban_tv_animation": "豆瓣热门动漫",
"bangumi_calendar": "番组计划"
"bangumi_calendar": "番组计划",
}
source_desc = source_map.get(source, source)
message = f"正在获取推荐: {source_desc}"
if media_type != "all":
message += f" [{media_type}]"
message += f" (限制: {limit})"
message += f" ({page})"
return message
async def run(self, source: Optional[str] = "tmdb_trending",
media_type: Optional[str] = "all", limit: Optional[int] = 20, **kwargs) -> str:
logger.info(f"执行工具: {self.name}, 参数: source={source}, media_type={media_type}, limit={limit}")
async def run(
self,
source: Optional[str] = "tmdb_trending",
media_type: Optional[str] = "all",
page: Optional[int] = 1,
**kwargs,
) -> str:
page = max(1, page or 1)
page_size = 20
logger.info(
f"执行工具: {self.name}, 参数: source={source}, media_type={media_type}, page={page}"
)
try:
if media_type != "all":
media_type_enum = MediaType.from_agent(media_type)
@@ -85,73 +102,103 @@ class GetRecommendationsTool(MoviePilotTool):
recommend_chain = RecommendChain()
results = []
if source == "tmdb_trending":
# async_tmdb_trending 只接受 page 参数,返回固定数量的结果
# 如果需要限制数量,需要在返回后截取
results = await recommend_chain.async_tmdb_trending(page=1)
if limit and limit > 0:
results = results[:limit]
results = await recommend_chain.async_tmdb_trending(page=page)
elif source == "tmdb_movies":
# async_tmdb_movies 接受 page 参数,返回固定数量的结果
results = await recommend_chain.async_tmdb_movies(page=1)
if limit and limit > 0:
results = results[:limit]
results = await recommend_chain.async_tmdb_movies(page=page)
elif source == "tmdb_tvs":
# async_tmdb_tvs 接受 page 参数,返回固定数量的结果
results = await recommend_chain.async_tmdb_tvs(page=1)
if limit and limit > 0:
results = results[:limit]
results = await recommend_chain.async_tmdb_tvs(page=page)
elif source == "douban_hot":
if media_type == "movie":
results = await recommend_chain.async_douban_movie_hot(page=1, count=limit)
results = await recommend_chain.async_douban_movie_hot(
page=page, count=page_size
)
elif media_type == "tv":
results = await recommend_chain.async_douban_tv_hot(page=1, count=limit)
results = await recommend_chain.async_douban_tv_hot(
page=page, count=page_size
)
else: # all
results.extend(await recommend_chain.async_douban_movie_hot(page=1, count=limit))
results.extend(await recommend_chain.async_douban_tv_hot(page=1, count=limit))
results.extend(
await recommend_chain.async_douban_movie_hot(
page=page, count=page_size
)
)
results.extend(
await recommend_chain.async_douban_tv_hot(
page=page, count=page_size
)
)
elif source == "douban_movie_hot":
results = await recommend_chain.async_douban_movie_hot(page=1, count=limit)
results = await recommend_chain.async_douban_movie_hot(
page=page, count=page_size
)
elif source == "douban_tv_hot":
results = await recommend_chain.async_douban_tv_hot(page=1, count=limit)
results = await recommend_chain.async_douban_tv_hot(
page=page, count=page_size
)
elif source == "douban_movie_showing":
results = await recommend_chain.async_douban_movie_showing(page=1, count=limit)
results = await recommend_chain.async_douban_movie_showing(
page=page, count=page_size
)
elif source == "douban_movies":
results = await recommend_chain.async_douban_movies(page=1, count=limit)
results = await recommend_chain.async_douban_movies(
page=page, count=page_size
)
elif source == "douban_tvs":
results = await recommend_chain.async_douban_tvs(page=1, count=limit)
results = await recommend_chain.async_douban_tvs(
page=page, count=page_size
)
elif source == "douban_movie_top250":
results = await recommend_chain.async_douban_movie_top250(page=1, count=limit)
results = await recommend_chain.async_douban_movie_top250(
page=page, count=page_size
)
elif source == "douban_tv_weekly_chinese":
results = await recommend_chain.async_douban_tv_weekly_chinese(page=1, count=limit)
results = await recommend_chain.async_douban_tv_weekly_chinese(
page=page, count=page_size
)
elif source == "douban_tv_weekly_global":
results = await recommend_chain.async_douban_tv_weekly_global(page=1, count=limit)
results = await recommend_chain.async_douban_tv_weekly_global(
page=page, count=page_size
)
elif source == "douban_tv_animation":
results = await recommend_chain.async_douban_tv_animation(page=1, count=limit)
results = await recommend_chain.async_douban_tv_animation(
page=page, count=page_size
)
elif source == "bangumi_calendar":
results = await recommend_chain.async_bangumi_calendar(page=1, count=limit)
results = await recommend_chain.async_bangumi_calendar(
page=page, count=page_size
)
else:
# 不支持的推荐来源
supported_sources = [
"tmdb_trending", "tmdb_movies", "tmdb_tvs",
"douban_hot", "douban_movie_hot", "douban_tv_hot",
"douban_movie_showing", "douban_movies", "douban_tvs",
"douban_movie_top250", "douban_tv_weekly_chinese",
"douban_tv_weekly_global", "douban_tv_animation",
"bangumi_calendar"
"tmdb_trending",
"tmdb_movies",
"tmdb_tvs",
"douban_hot",
"douban_movie_hot",
"douban_tv_hot",
"douban_movie_showing",
"douban_movies",
"douban_tvs",
"douban_movie_top250",
"douban_tv_weekly_chinese",
"douban_tv_weekly_global",
"douban_tv_animation",
"bangumi_calendar",
]
return f"不支持的推荐来源: {source}。支持的来源包括: {', '.join(supported_sources)}"
if results:
# 限制最多20条结果
# 对于TMDB来源API自身按页返回取前page_size条
total_count = len(results)
limited_results = results[:20]
page_results = results[:page_size]
# 精简字段,只保留关键信息
simplified_results = []
for r in limited_results:
for r in page_results:
# r 应该是字典格式to_dict的结果但为了安全起见进行检查
if not isinstance(r, dict):
logger.warning(f"推荐结果格式异常,跳过: {type(r)}")
continue
simplified = {
"title": r.get("title"),
"en_title": r.get("en_title"),
@@ -163,14 +210,19 @@ class GetRecommendationsTool(MoviePilotTool):
"douban_id": r.get("douban_id"),
"vote_average": r.get("vote_average"),
"poster_path": r.get("poster_path"),
"detail_link": r.get("detail_link")
"detail_link": r.get("detail_link"),
}
simplified_results.append(simplified)
result_json = json.dumps(simplified_results, ensure_ascii=False, indent=2)
# 如果结果被裁剪,添加提示信息
if total_count > 20:
return f"注意:推荐结果共找到 {total_count} 条,为节省上下文空间,仅显示前 20 条结果。\n\n{result_json}"
return result_json
result_json = json.dumps(
simplified_results, ensure_ascii=False, indent=2
)
has_more = total_count > page_size
payload_msg = f"{page} 页,当前页 {len(simplified_results)} 条结果。"
if has_more:
payload_msg += (
f" 可能有更多数据,可使用 page={page + 1} 获取下一页。"
)
return f"{payload_msg}\n\n{result_json}"
return "未找到推荐内容。"
except Exception as e:
logger.error(f"获取推荐失败: {e}", exc_info=True)

View File

@@ -19,33 +19,60 @@ from ._torrent_search_utils import (
class GetSearchResultsInput(BaseModel):
"""获取搜索结果工具的输入参数模型"""
explanation: str = Field(..., description="Clear explanation of why this tool is being used in the current context")
explanation: str = Field(
...,
description="Clear explanation of why this tool is being used in the current context",
)
site: Optional[List[str]] = Field(None, description="Site name filters")
season: Optional[List[str]] = Field(None, description="Season or episode filters")
free_state: Optional[List[str]] = Field(None, description="Promotion state filters")
video_code: Optional[List[str]] = Field(None, description="Video codec filters")
edition: Optional[List[str]] = Field(None, description="Edition filters")
resolution: Optional[List[str]] = Field(None, description="Resolution filters")
release_group: Optional[List[str]] = Field(None, description="Release group filters")
title_pattern: Optional[str] = Field(None, description="Regular expression pattern to filter torrent titles (e.g., '4K|2160p|UHD', '1080p.*BluRay')")
show_filter_options: Optional[bool] = Field(False, description="Whether to return only optional filter options for re-checking available conditions")
release_group: Optional[List[str]] = Field(
None, description="Release group filters"
)
title_pattern: Optional[str] = Field(
None,
description="Regular expression pattern to filter torrent titles (e.g., '4K|2160p|UHD', '1080p.*BluRay')",
)
show_filter_options: Optional[bool] = Field(
False,
description="Whether to return only optional filter options for re-checking available conditions",
)
page: Optional[int] = Field(
1,
description="Page number for pagination (default: 1, each page returns up to 50 results)",
)
class GetSearchResultsTool(MoviePilotTool):
name: str = "get_search_results"
description: str = "Get cached torrent search results from search_torrents with optional filters. Returns at most the first 50 matches."
description: str = "Get cached torrent search results from search_torrents with optional filters. Supports pagination with up to 50 results per page."
args_schema: Type[BaseModel] = GetSearchResultsInput
def get_tool_message(self, **kwargs) -> Optional[str]:
return "正在获取搜索结果"
async def run(self, site: Optional[List[str]] = None, season: Optional[List[str]] = None,
free_state: Optional[List[str]] = None, video_code: Optional[List[str]] = None,
edition: Optional[List[str]] = None, resolution: Optional[List[str]] = None,
release_group: Optional[List[str]] = None, title_pattern: Optional[str] = None,
show_filter_options: bool = False,
**kwargs) -> str:
async def run(
self,
site: Optional[List[str]] = None,
season: Optional[List[str]] = None,
free_state: Optional[List[str]] = None,
video_code: Optional[List[str]] = None,
edition: Optional[List[str]] = None,
resolution: Optional[List[str]] = None,
release_group: Optional[List[str]] = None,
title_pattern: Optional[str] = None,
show_filter_options: bool = False,
page: Optional[int] = 1,
**kwargs,
) -> str:
page = max(1, page or 1)
logger.info(
f"执行工具: {self.name}, 参数: site={site}, season={season}, free_state={free_state}, video_code={video_code}, edition={edition}, resolution={resolution}, release_group={release_group}, title_pattern={title_pattern}, show_filter_options={show_filter_options}")
f"执行工具: {self.name}, 参数: site={site}, season={season}, free_state={free_state}, video_code={video_code}, edition={edition}, resolution={resolution}, release_group={release_group}, title_pattern={title_pattern}, show_filter_options={show_filter_options}, page={page}"
)
try:
items = await SearchChain().async_last_search_results() or []
@@ -79,8 +106,10 @@ class GetSearchResultsTool(MoviePilotTool):
)
if regex_pattern:
filtered_items = [
item for item in filtered_items
if item.torrent_info and item.torrent_info.title
item
for item in filtered_items
if item.torrent_info
and item.torrent_info.title
and regex_pattern.search(item.torrent_info.title)
]
if not filtered_items:
@@ -88,19 +117,37 @@ class GetSearchResultsTool(MoviePilotTool):
total_count = len(filtered_items)
filtered_ids = {id(item) for item in filtered_items}
matched_indices = [index for index, item in enumerate(items, start=1) if id(item) in filtered_ids]
limited_items = filtered_items[:TORRENT_RESULT_LIMIT]
limited_indices = matched_indices[:TORRENT_RESULT_LIMIT]
matched_indices = [
index
for index, item in enumerate(items, start=1)
if id(item) in filtered_ids
]
# 分页
page_size = TORRENT_RESULT_LIMIT
start = (page - 1) * page_size
end = start + page_size
page_items = filtered_items[start:end]
page_indices = matched_indices[start:end]
if not page_items:
return f"{page} 页没有数据,共 {total_count} 条结果,共 {(total_count + page_size - 1) // page_size} 页。"
results = [
simplify_search_result(item, index)
for item, index in zip(limited_items, limited_indices)
for item, index in zip(page_items, page_indices)
]
total_pages = (total_count + page_size - 1) // page_size
payload = {
"total_count": total_count,
"page": page,
"total_pages": total_pages,
"results": results,
}
if total_count > TORRENT_RESULT_LIMIT:
payload["message"] = f"搜索结果共找到 {total_count} 条,仅显示前 {TORRENT_RESULT_LIMIT} 条结果。"
if page < total_pages:
payload["message"] = (
f"搜索结果共 {total_count} 条,当前第 {page}/{total_pages} 页,可使用 page={page + 1} 获取下一页。"
)
return json.dumps(payload, ensure_ascii=False, indent=2)
except Exception as e:
error_message = f"获取搜索结果失败: {str(e)}"

View File

@@ -58,14 +58,7 @@ class QueryInstalledPluginsTool(MoviePilotTool):
}
)
total_count = len(plugins_list)
result_json = json.dumps(plugins_list, ensure_ascii=False, indent=2)
if total_count > 100:
limited_plugins = plugins_list[:100]
limited_json = json.dumps(limited_plugins, ensure_ascii=False, indent=2)
return f"注意:共找到 {total_count} 个已安装插件,为节省上下文空间,仅显示前 100 个。\n\n{limited_json}"
return result_json
except Exception as e:
logger.error(f"查询已安装插件失败: {e}", exc_info=True)

View File

@@ -10,52 +10,70 @@ from app.chain.mediaserver import MediaServerChain
from app.helper.service import ServiceConfigHelper
from app.log import logger
PAGE_SIZE = 20
class QueryLibraryLatestInput(BaseModel):
"""查询媒体服务器最近入库影片工具的输入参数模型"""
explanation: str = Field(..., description="Clear explanation of why this tool is being used in the current context")
server: Optional[str] = Field(None, description="Media server name (optional, if not specified queries all enabled media servers)")
count: Optional[int] = Field(20, description="Number of items to return (default: 20)")
explanation: str = Field(
...,
description="Clear explanation of why this tool is being used in the current context",
)
server: Optional[str] = Field(
None,
description="Media server name (optional, if not specified queries all enabled media servers)",
)
page: Optional[int] = Field(
1, description="Page number for pagination (default: 1, 20 items per page)"
)
class QueryLibraryLatestTool(MoviePilotTool):
name: str = "query_library_latest"
description: str = "Query the latest media items added to the media server (Plex, Emby, Jellyfin). Returns recently added movies and TV series with their titles, images, links, and other metadata."
description: str = "Query the latest media items added to the media server (Plex, Emby, Jellyfin). Returns recently added movies and TV series with their titles, images, links, and other metadata. Supports pagination with 20 items per page."
args_schema: Type[BaseModel] = QueryLibraryLatestInput
def get_tool_message(self, **kwargs) -> Optional[str]:
"""根据查询参数生成友好的提示消息"""
server = kwargs.get("server")
count = kwargs.get("count", 20)
page = kwargs.get("page", 1)
parts = ["正在查询媒体服务器最近入库影片"]
if server:
parts.append(f"服务器: {server}")
else:
parts.append("所有服务器")
parts.append(f"数量: {count}")
parts.append(f"{page}")
return " | ".join(parts)
async def run(self, server: Optional[str] = None, count: Optional[int] = 20, **kwargs) -> str:
logger.info(f"执行工具: {self.name}, 参数: server={server}, count={count}")
async def run(
self, server: Optional[str] = None, page: Optional[int] = 1, **kwargs
) -> str:
page = max(1, page or 1)
# 为了支持分页,需要获取足够多的数据再切片
fetch_count = page * PAGE_SIZE
logger.info(f"执行工具: {self.name}, 参数: server={server}, page={page}")
try:
media_chain = MediaServerChain()
results = []
# 如果没有指定服务器,获取所有启用的媒体服务器
if not server:
mediaservers = ServiceConfigHelper.get_mediaserver_configs()
enabled_servers = [ms.name for ms in mediaservers if ms.enabled]
if not enabled_servers:
return "未找到启用的媒体服务器"
# 遍历所有启用的服务器
for server_name in enabled_servers:
latest_items = media_chain.latest(server=server_name, count=count, username=self._username)
latest_items = media_chain.latest(
server=server_name, count=fetch_count, username=self._username
)
if latest_items:
for item in latest_items:
item_dict = item.model_dump(exclude_none=True)
@@ -63,24 +81,37 @@ class QueryLibraryLatestTool(MoviePilotTool):
results.append(item_dict)
else:
# 查询指定服务器
latest_items = media_chain.latest(server=server, count=count, username=self._username)
latest_items = media_chain.latest(
server=server, count=fetch_count, username=self._username
)
if latest_items:
for item in latest_items:
item_dict = item.model_dump(exclude_none=True)
item_dict["server"] = server
results.append(item_dict)
if not results:
server_info = f"服务器 {server}" if server else "所有服务器"
return f"未找到 {server_info} 的最近入库影片"
# 限制返回数量,避免结果过多
if len(results) > count:
results = results[:count]
return json.dumps(results, ensure_ascii=False, indent=2)
# 分页
total_count = len(results)
start = (page - 1) * PAGE_SIZE
end = start + PAGE_SIZE
page_results = results[start:end]
if not page_results:
total_pages = (total_count + PAGE_SIZE - 1) // PAGE_SIZE
return f"{page} 页没有数据,共 {total_count} 条结果,共 {total_pages} 页。"
total_pages = (total_count + PAGE_SIZE - 1) // PAGE_SIZE
payload_msg = f"{page}/{total_pages} 页,当前页 {len(page_results)} 条结果,共 {total_count} 条。"
if page < total_pages:
payload_msg += f" 可使用 page={page + 1} 获取下一页。"
result_json = json.dumps(page_results, ensure_ascii=False, indent=2)
return f"{payload_msg}\n\n{result_json}"
except Exception as e:
logger.error(f"查询媒体服务器最近入库影片失败: {e}", exc_info=True)
return f"查询媒体服务器最近入库影片时发生错误: {str(e)}"

View File

@@ -11,36 +11,59 @@ from app.db.models.subscribehistory import SubscribeHistory
from app.log import logger
from app.schemas.types import media_type_to_agent
PAGE_SIZE = 20
class QuerySubscribeHistoryInput(BaseModel):
"""查询订阅历史工具的输入参数模型"""
explanation: str = Field(..., description="Clear explanation of why this tool is being used in the current context")
media_type: Optional[str] = Field("all", description="Allowed values: movie, tv, all")
name: Optional[str] = Field(None, description="Filter by media name (partial match, optional)")
explanation: str = Field(
...,
description="Clear explanation of why this tool is being used in the current context",
)
media_type: Optional[str] = Field(
"all", description="Allowed values: movie, tv, all"
)
name: Optional[str] = Field(
None, description="Filter by media name (partial match, optional)"
)
page: Optional[int] = Field(
1, description="Page number for pagination (default: 1, 20 items per page)"
)
class QuerySubscribeHistoryTool(MoviePilotTool):
name: str = "query_subscribe_history"
description: str = "Query subscription history records. Shows completed subscriptions with their details including name, type, rating, completion date, and other subscription information. Supports filtering by media type and name. Returns up to 30 records."
description: str = "Query subscription history records. Shows completed subscriptions with their details including name, type, rating, completion date, and other subscription information. Supports filtering by media type and name. Supports pagination with 20 records per page."
args_schema: Type[BaseModel] = QuerySubscribeHistoryInput
def get_tool_message(self, **kwargs) -> Optional[str]:
"""根据查询参数生成友好的提示消息"""
media_type = kwargs.get("media_type", "all")
name = kwargs.get("name")
page = kwargs.get("page", 1)
parts = ["正在查询订阅历史"]
if media_type != "all":
parts.append(f"类型: {media_type}")
if name:
parts.append(f"名称: {name}")
return " | ".join(parts) if len(parts) > 1 else parts[0]
parts.append(f"{page}")
async def run(self, media_type: Optional[str] = "all",
name: Optional[str] = None, **kwargs) -> str:
logger.info(f"执行工具: {self.name}, 参数: media_type={media_type}, name={name}")
return " | ".join(parts)
async def run(
self,
media_type: Optional[str] = "all",
name: Optional[str] = None,
page: Optional[int] = 1,
**kwargs,
) -> str:
page = max(1, page or 1)
logger.info(
f"执行工具: {self.name}, 参数: media_type={media_type}, name={name}, page={page}"
)
try:
if media_type not in ["all", "movie", "tv"]:
@@ -48,38 +71,66 @@ class QuerySubscribeHistoryTool(MoviePilotTool):
# 获取数据库会话
async with AsyncSessionFactory() as db:
# 根据类型查询
if media_type == "all":
# 查询所有类型,需要分别查询电影和电视剧
movie_history = await SubscribeHistory.async_list_by_type(db, mtype="movie", page=1, count=100)
tv_history = await SubscribeHistory.async_list_by_type(db, mtype="tv", page=1, count=100)
all_history = list(movie_history) + list(tv_history)
# 按日期排序
all_history.sort(key=lambda x: x.date or "", reverse=True)
else:
# 查询指定类型
all_history = await SubscribeHistory.async_list_by_type(db, mtype=media_type, page=1, count=100)
# 按名称过滤
filtered_history = []
if name:
# 有名称过滤时,需要获取较多记录在内存中过滤
fetch_count = page * PAGE_SIZE * 5 # 获取足够多的数据用于过滤后分页
if media_type == "all":
movie_history = await SubscribeHistory.async_list_by_type(
db, mtype="movie", page=1, count=fetch_count
)
tv_history = await SubscribeHistory.async_list_by_type(
db, mtype="tv", page=1, count=fetch_count
)
all_history = list(movie_history) + list(tv_history)
all_history.sort(key=lambda x: x.date or "", reverse=True)
else:
all_history = list(
await SubscribeHistory.async_list_by_type(
db, mtype=media_type, page=1, count=fetch_count
)
)
# 按名称过滤
name_lower = name.lower()
for record in all_history:
if record.name and name_lower in record.name.lower():
filtered_history.append(record)
filtered_history = [
record
for record in all_history
if record.name and name_lower in record.name.lower()
]
else:
filtered_history = all_history
# 无名称过滤时,直接利用数据库分页
if media_type == "all":
movie_history = await SubscribeHistory.async_list_by_type(
db, mtype="movie", page=1, count=page * PAGE_SIZE
)
tv_history = await SubscribeHistory.async_list_by_type(
db, mtype="tv", page=1, count=page * PAGE_SIZE
)
all_history = list(movie_history) + list(tv_history)
all_history.sort(key=lambda x: x.date or "", reverse=True)
filtered_history = all_history
else:
filtered_history = list(
await SubscribeHistory.async_list_by_type(
db, mtype=media_type, page=1, count=page * PAGE_SIZE
)
)
if not filtered_history:
return "未找到相关订阅历史记录"
# 限制最多30条
# 分页切片
total_count = len(filtered_history)
limited_history = filtered_history[:30]
start = (page - 1) * PAGE_SIZE
end = start + PAGE_SIZE
page_records = filtered_history[start:end]
if not page_records:
return f"{page} 页没有数据。"
# 转换为字典格式,只保留关键信息
simplified_records = []
for record in limited_history:
for record in page_records:
simplified = {
"id": record.id,
"name": record.name,
@@ -93,7 +144,7 @@ class QuerySubscribeHistoryTool(MoviePilotTool):
"vote": record.vote,
"total_episode": record.total_episode,
"date": record.date,
"username": record.username
"username": record.username,
}
# 添加过滤规则信息(如果有)
if record.filter:
@@ -103,14 +154,19 @@ class QuerySubscribeHistoryTool(MoviePilotTool):
if record.resolution:
simplified["resolution"] = record.resolution
simplified_records.append(simplified)
result_json = json.dumps(simplified_records, ensure_ascii=False, indent=2)
# 如果结果被裁剪,添加提示信息
if total_count > 100:
return f"注意:查询结果共找到 {total_count} 条,为节省上下文空间,仅显示前 100 条结果。\n\n{result_json}"
return result_json
result_json = json.dumps(
simplified_records, ensure_ascii=False, indent=2
)
has_more = total_count > end
payload_msg = f"{page} 页,当前页 {len(simplified_records)} 条结果。"
if has_more:
payload_msg += (
f" 可能有更多数据,可使用 page={page + 1} 获取下一页。"
)
return f"{payload_msg}\n\n{result_json}"
except Exception as e:
logger.error(f"查询订阅历史失败: {e}", exc_info=True)
return f"查询订阅历史时发生错误: {str(e)}"

View File

@@ -11,6 +11,8 @@ from app.log import logger
from app.schemas.subscribe import Subscribe as SubscribeSchema
from app.schemas.types import MediaType
PAGE_SIZE = 100
QUERY_SUBSCRIBE_OUTPUT_FIELDS = [
"id",
"name",
@@ -35,47 +37,76 @@ QUERY_SUBSCRIBE_OUTPUT_FIELDS = [
"custom_words",
"media_category",
"filter_groups",
"episode_group"
"episode_group",
]
class QuerySubscribesInput(BaseModel):
"""查询订阅工具的输入参数模型"""
explanation: str = Field(..., description="Clear explanation of why this tool is being used in the current context")
status: Optional[str] = Field("all",
description="Filter subscriptions by status: 'R' for enabled subscriptions, 'S' for paused ones, 'all' for all subscriptions")
media_type: Optional[str] = Field("all",
description="Allowed values: movie, tv, all")
tmdb_id: Optional[int] = Field(None, description="Filter by TMDB ID to check if a specific media is already subscribed")
douban_id: Optional[str] = Field(None, description="Filter by Douban ID to check if a specific media is already subscribed")
explanation: str = Field(
...,
description="Clear explanation of why this tool is being used in the current context",
)
status: Optional[str] = Field(
"all",
description="Filter subscriptions by status: 'R' for enabled subscriptions, 'S' for paused ones, 'all' for all subscriptions",
)
media_type: Optional[str] = Field(
"all", description="Allowed values: movie, tv, all"
)
tmdb_id: Optional[int] = Field(
None,
description="Filter by TMDB ID to check if a specific media is already subscribed",
)
douban_id: Optional[str] = Field(
None,
description="Filter by Douban ID to check if a specific media is already subscribed",
)
page: Optional[int] = Field(
1, description="Page number for pagination (default: 1, 100 items per page)"
)
class QuerySubscribesTool(MoviePilotTool):
name: str = "query_subscribes"
description: str = "Query subscription status and list user subscriptions. Returns full subscription parameters for each matched subscription."
description: str = "Query subscription status and list user subscriptions. Returns full subscription parameters for each matched subscription. Supports pagination with 100 items per page."
args_schema: Type[BaseModel] = QuerySubscribesInput
def get_tool_message(self, **kwargs) -> Optional[str]:
"""根据查询参数生成友好的提示消息"""
status = kwargs.get("status", "all")
media_type = kwargs.get("media_type", "all")
page = kwargs.get("page", 1)
parts = ["正在查询订阅"]
# 根据状态过滤条件生成提示
if status != "all":
status_map = {"R": "已启用", "S": "已暂停"}
parts.append(f"状态: {status_map.get(status, status)}")
# 根据媒体类型过滤条件生成提示
if media_type != "all":
parts.append(f"类型: {media_type}")
return " | ".join(parts) if len(parts) > 1 else parts[0]
async def run(self, status: Optional[str] = "all", media_type: Optional[str] = "all",
tmdb_id: Optional[int] = None, douban_id: Optional[str] = None, **kwargs) -> str:
logger.info(f"执行工具: {self.name}, 参数: status={status}, media_type={media_type}, tmdb_id={tmdb_id}, douban_id={douban_id}")
parts.append(f"{page}")
return " | ".join(parts)
async def run(
self,
status: Optional[str] = "all",
media_type: Optional[str] = "all",
tmdb_id: Optional[int] = None,
douban_id: Optional[str] = None,
page: Optional[int] = 1,
**kwargs,
) -> str:
page = max(1, page or 1)
logger.info(
f"执行工具: {self.name}, 参数: status={status}, media_type={media_type}, tmdb_id={tmdb_id}, douban_id={douban_id}, page={page}"
)
try:
if media_type != "all" and not MediaType.from_agent(media_type):
return f"错误:无效的媒体类型 '{media_type}',支持的类型:'movie', 'tv', 'all'"
@@ -86,7 +117,10 @@ class QuerySubscribesTool(MoviePilotTool):
for sub in subscribes:
if status != "all" and sub.state != status:
continue
if media_type != "all" and sub.type != MediaType.from_agent(media_type).value:
if (
media_type != "all"
and sub.type != MediaType.from_agent(media_type).value
):
continue
if tmdb_id is not None and sub.tmdbid != tmdb_id:
continue
@@ -94,21 +128,30 @@ class QuerySubscribesTool(MoviePilotTool):
continue
filtered_subscribes.append(sub)
if filtered_subscribes:
# 限制最多50条结果
total_count = len(filtered_subscribes)
limited_subscribes = filtered_subscribes[:50]
# 分页
start = (page - 1) * PAGE_SIZE
end = start + PAGE_SIZE
page_subscribes = filtered_subscribes[start:end]
if not page_subscribes:
total_pages = (total_count + PAGE_SIZE - 1) // PAGE_SIZE
return f"{page} 页没有数据,共 {total_count} 条结果,共 {total_pages} 页。"
full_subscribes = [
SubscribeSchema.model_validate(s, from_attributes=True).model_dump(
include=set(QUERY_SUBSCRIBE_OUTPUT_FIELDS),
exclude_none=True
include=set(QUERY_SUBSCRIBE_OUTPUT_FIELDS), exclude_none=True
)
for s in limited_subscribes
for s in page_subscribes
]
result_json = json.dumps(full_subscribes, ensure_ascii=False, indent=2)
# 如果结果被裁剪,添加提示信息
if total_count > 200:
return f"注意:查询结果共找到 {total_count} 条,为节省上下文空间,仅显示前 200 条结果。\n\n{result_json}"
return result_json
total_pages = (total_count + PAGE_SIZE - 1) // PAGE_SIZE
payload_msg = f"{page}/{total_pages} 页,当前页 {len(page_subscribes)} 条结果,共 {total_count} 条。"
if page < total_pages:
payload_msg += f" 可使用 page={page + 1} 获取下一页。"
return f"{payload_msg}\n\n{result_json}"
return "未找到相关订阅"
except Exception as e:
logger.error(f"查询订阅失败: {e}", exc_info=True)