diff --git a/app/agent/tools/factory.py b/app/agent/tools/factory.py index b57c7e7a..2597169c 100644 --- a/app/agent/tools/factory.py +++ b/app/agent/tools/factory.py @@ -53,7 +53,7 @@ from app.agent.tools.impl.update_site_cookie import UpdateSiteCookieTool from app.agent.tools.impl.delete_download import DeleteDownloadTool from app.agent.tools.impl.delete_download_history import DeleteDownloadHistoryTool from app.agent.tools.impl.delete_transfer_history import DeleteTransferHistoryTool -from app.agent.tools.impl.modify_download import ModifyDownloadTool +from app.agent.tools.impl.update_download_tasks import UpdateDownloadTasksTool from app.agent.tools.impl.query_directory_settings import QueryDirectorySettingsTool from app.agent.tools.impl.list_directory import ListDirectoryTool from app.agent.tools.impl.query_transfer_history import QueryTransferHistoryTool @@ -186,7 +186,7 @@ class MoviePilotToolFactory: DeleteDownloadTool, DeleteDownloadHistoryTool, DeleteTransferHistoryTool, - ModifyDownloadTool, + UpdateDownloadTasksTool, QueryDownloadersTool, QuerySitesTool, UpdateSiteTool, diff --git a/app/agent/tools/impl/modify_download.py b/app/agent/tools/impl/modify_download.py deleted file mode 100644 index 3fdcc836..00000000 --- a/app/agent/tools/impl/modify_download.py +++ /dev/null @@ -1,143 +0,0 @@ -"""修改下载任务工具""" - -from typing import Optional, Type, List - -from pydantic import BaseModel, Field - -from app.agent.tools.base import MoviePilotTool -from app.agent.tools.tags import ToolTag -from app.chain.download import DownloadChain -from app.log import logger - - -class ModifyDownloadInput(BaseModel): - """修改下载任务工具的输入参数模型""" - - explanation: Optional[str] = Field(None, - description="Clear explanation of why this tool is being used in the current context",) - hash: str = Field( - ..., description="Task hash (can be obtained from query_download_tasks tool)" - ) - action: Optional[str] = Field( - None, - description="Action to perform on the task: 'start' to resume downloading, 'stop' to pause downloading. " - "If not provided, no start/stop action will be performed.", - ) - tags: Optional[List[str]] = Field( - None, - description="List of tags to set on the download task. If provided, these tags will be added to the task. " - "Example: ['movie', 'hd']", - ) - downloader: Optional[str] = Field( - None, - description="Name of specific downloader (optional, if not provided will search all downloaders)", - ) - - -class ModifyDownloadTool(MoviePilotTool): - """修改下载任务工具""" - - name: str = "modify_download" - tags: list[str] = [ - ToolTag.Write, - ToolTag.Download, - ToolTag.Admin, - ] - description: str = ( - "Modify a download task in the downloader by task hash. " - "Supports: 1) Setting tags on a download task, " - "2) Starting (resuming) a paused download task, " - "3) Stopping (pausing) a downloading task. " - "Multiple operations can be performed in a single call." - ) - args_schema: Type[BaseModel] = ModifyDownloadInput - require_admin: bool = True - - def get_tool_message(self, **kwargs) -> Optional[str]: - hash_value = kwargs.get("hash", "") - action = kwargs.get("action") - tags = kwargs.get("tags") - downloader = kwargs.get("downloader") - - parts = [f"修改下载任务: {hash_value}"] - if action == "start": - parts.append("操作: 开始下载") - elif action == "stop": - parts.append("操作: 暂停下载") - if tags: - parts.append(f"标签: {', '.join(tags)}") - if downloader: - parts.append(f"下载器: {downloader}") - return " | ".join(parts) - - @staticmethod - def _modify_download_sync( - hash_value: str, - action: Optional[str] = None, - tags: Optional[List[str]] = None, - downloader: Optional[str] = None, - ) -> List[str]: - """同步修改下载任务状态和标签,避免下载器 SDK 阻塞事件循环。""" - download_chain = DownloadChain() - results = [] - - if tags: - tag_result = download_chain.set_torrents_tag( - hashs=[hash_value], tags=tags, downloader=downloader - ) - if tag_result: - results.append(f"成功设置标签:{', '.join(tags)}") - else: - results.append("设置标签失败,请检查任务是否存在或下载器是否可用") - - if action: - action_result = download_chain.set_downloading( - hash_str=hash_value, oper=action, name=downloader - ) - action_desc = "开始" if action == "start" else "暂停" - if action_result: - results.append(f"成功{action_desc}下载任务") - else: - results.append(f"{action_desc}下载任务失败,请检查任务是否存在或下载器是否可用") - - return results - - async def run( - self, - hash: str, - action: Optional[str] = None, - tags: Optional[List[str]] = None, - downloader: Optional[str] = None, - **kwargs, - ) -> str: - logger.info( - f"执行工具: {self.name}, 参数: hash={hash}, action={action}, tags={tags}, downloader={downloader}" - ) - - try: - # 校验 hash 格式 - if len(hash) != 40 or not all(c in "0123456789abcdefABCDEF" for c in hash): - return "参数错误:hash 格式无效,请先使用 query_download_tasks 工具获取正确的 hash。" - - # 校验参数:至少需要一个操作 - if not action and not tags: - return "参数错误:至少需要指定 action(start/stop)或 tags 中的一个。" - - # 校验 action 参数 - if action and action not in ("start", "stop"): - return f"参数错误:action 只支持 'start'(开始下载)或 'stop'(暂停下载),收到: '{action}'。" - - results = await self.run_blocking( - "downloader", - self._modify_download_sync, - hash, - action, - tags, - downloader, - ) - - return f"下载任务 {hash}:" + ";".join(results) - - except Exception as e: - logger.error(f"修改下载任务失败: {e}", exc_info=True) - return f"修改下载任务时发生错误: {str(e)}" diff --git a/app/agent/tools/impl/query_download_tasks.py b/app/agent/tools/impl/query_download_tasks.py index 9bdb547c..c3d1230a 100644 --- a/app/agent/tools/impl/query_download_tasks.py +++ b/app/agent/tools/impl/query_download_tasks.py @@ -25,6 +25,10 @@ class QueryDownloadTasksInput(BaseModel): False, description="Include tasks without the MoviePilot built-in tag. Default false keeps the normal MoviePilot task scope.", ) + include_trackers: Optional[bool] = Field( + False, + description="Include tracker URLs when supported. Hash queries always include trackers.", + ) hash: Optional[str] = Field(None, description="Query specific download task by hash (optional, if provided will search for this specific task regardless of status)") title: Optional[str] = Field(None, description="Query download tasks by title/name (optional, supports partial match, searches all tasks if provided)") tag: Optional[str] = Field(None, description="Filter download tasks by tag (optional, supports partial match, e.g. 'movie' will match tasks with tag 'movie' or 'movie_2024')") @@ -131,6 +135,7 @@ class QueryDownloadTasksTool(MoviePilotTool): title: Optional[str] = None, tag: Optional[str] = None, include_all_tags: bool = False, + include_trackers: bool = False, ) -> Dict[str, Any]: """ 同步查询下载器和下载历史,整个链路放在线程池中执行。 @@ -214,6 +219,16 @@ class QueryDownloadTasksTool(MoviePilotTool): if not filtered_downloads: return {"message": "未找到相关下载任务"} + if hash_value or include_trackers: + for torrent in filtered_downloads: + if not getattr(torrent, "hash", None): + continue + tracker_map = download_chain.get_torrent_trackers( + hash_string=torrent.hash, + downloader=getattr(torrent, "downloader", None) or downloader, + ) or {} + torrent.trackers = tracker_map.get(getattr(torrent, "downloader", None)) or [] + return {"downloads": filtered_downloads} def get_tool_message(self, **kwargs) -> Optional[str]: @@ -245,6 +260,8 @@ class QueryDownloadTasksTool(MoviePilotTool): parts.append(f"标签: {tag}") if include_all_tags: parts.append("范围: 全部标签") + if kwargs.get("include_trackers"): + parts.append("包含Tracker") return " | ".join(parts) if len(parts) > 1 else parts[0] @@ -254,10 +271,12 @@ class QueryDownloadTasksTool(MoviePilotTool): title: Optional[str] = None, tag: Optional[str] = None, include_all_tags: Optional[bool] = False, + include_trackers: Optional[bool] = False, **kwargs) -> str: logger.info( f"执行工具: {self.name}, 参数: downloader={downloader}, status={status}, " - f"hash={hash}, title={title}, tag={tag}, include_all_tags={include_all_tags}" + f"hash={hash}, title={title}, tag={tag}, include_all_tags={include_all_tags}, " + f"include_trackers={include_trackers}" ) try: payload = await self.run_blocking( @@ -269,6 +288,7 @@ class QueryDownloadTasksTool(MoviePilotTool): title, tag, self._normalize_include_all_tags(include_all_tags), + self._normalize_include_all_tags(include_trackers), ) if payload.get("message"): return payload["message"] @@ -294,6 +314,16 @@ class QueryDownloadTasksTool(MoviePilotTool): "upspeed": getattr(d, "upspeed", None), "dlspeed": getattr(d, "dlspeed", None), "tags": d.tags, + "save_path": getattr(d, "save_path", None), + "content_path": getattr(d, "content_path", None) or ( + d.path.as_posix() if getattr(d, "path", None) else None + ), + "category": getattr(d, "category", None), + "download_limit": getattr(d, "download_limit", None), + "upload_limit": getattr(d, "upload_limit", None), + "ratio_limit": getattr(d, "ratio_limit", None), + "seeding_time_limit": getattr(d, "seeding_time_limit", None), + "trackers": getattr(d, "trackers", None) or [], "left_time": getattr(d, "left_time", None) } # 精简 media 字段 diff --git a/app/agent/tools/impl/update_download_tasks.py b/app/agent/tools/impl/update_download_tasks.py new file mode 100644 index 00000000..49aa0bcc --- /dev/null +++ b/app/agent/tools/impl/update_download_tasks.py @@ -0,0 +1,310 @@ +"""更新下载任务工具""" + +import json +from typing import Any, Dict, List, Optional, Type + +from pydantic import BaseModel, Field + +from app.agent.tools.base import MoviePilotTool +from app.agent.tools.tags import ToolTag +from app.chain.download import DownloadChain +from app.log import logger + + +class UpdateDownloadTasksInput(BaseModel): + """更新下载任务工具的输入参数模型""" + + explanation: Optional[str] = Field( + None, + description="Clear explanation of why this tool is being used in the current context", + ) + hash: str = Field( + ..., description="Task hash (can be obtained from query_download_tasks tool)" + ) + action: Optional[str] = Field( + None, + description="Action to perform on the task: 'start' to resume downloading, 'stop' to pause downloading.", + ) + tags: Optional[List[str]] = Field( + None, + description="List of tags to add to the download task. Example: ['movie', 'hd']", + ) + downloader: Optional[str] = Field( + None, + description="Name of specific downloader. If omitted, the tool resolves it from the task hash.", + ) + download_limit: Optional[float] = Field( + None, + description="Per-task download speed limit in KB/s. Use 0 to disable the limit when supported.", + ) + upload_limit: Optional[float] = Field( + None, + description="Per-task upload speed limit in KB/s. Use 0 to disable the limit when supported.", + ) + trackers: Optional[List[str]] = Field( + None, + description="Tracker URL list to add or set, depending on downloader support.", + ) + save_path: Optional[str] = Field( + None, + description="New save/download directory for the task, when supported.", + ) + category: Optional[str] = Field( + None, + description="Downloader category to set, when supported.", + ) + ratio_limit: Optional[float] = Field( + None, + description="Per-task share ratio limit, when supported.", + ) + seeding_time_limit: Optional[int] = Field( + None, + description="Per-task seeding time limit in minutes, when supported.", + ) + + +class UpdateDownloadTasksTool(MoviePilotTool): + """更新下载任务工具""" + + name: str = "update_download_tasks" + tags: list[str] = [ + ToolTag.Write, + ToolTag.Download, + ToolTag.Admin, + ] + description: str = ( + "Update a download task by hash. Supports start/stop, adding tags, per-task " + "upload/download speed limits, trackers, save directory, category, share ratio, " + "and seeding time where the configured downloader supports them. " + "Use query_download_tasks first to get the hash and current downloader." + ) + args_schema: Type[BaseModel] = UpdateDownloadTasksInput + require_admin: bool = True + + @staticmethod + def _is_valid_hash(hash_value: str) -> bool: + """校验下载任务Hash格式。""" + return len(hash_value) == 40 and all(c in "0123456789abcdefABCDEF" for c in hash_value) + + @staticmethod + def _normalize_non_empty_list(values: Optional[List[str]]) -> Optional[List[str]]: + """清理字符串列表中的空值。""" + if values is None: + return None + return [str(value).strip() for value in values if str(value).strip()] + + @staticmethod + def _has_update_params(**kwargs) -> bool: + """判断是否传入至少一个修改参数。""" + return any(value is not None and value != [] for value in kwargs.values()) + + @staticmethod + def _build_result(operation: str, success: bool, message: str) -> Dict[str, Any]: + """构造单项操作结果。""" + return { + "operation": operation, + "success": success, + "message": message, + } + + @classmethod + def _resolve_downloader( + cls, + download_chain: DownloadChain, + hash_value: str, + downloader: Optional[str], + ) -> Optional[str]: + """根据Hash解析下载任务所在下载器。""" + if downloader: + return downloader + torrents = download_chain.list_torrents( + hashs=[hash_value], + include_all_tags=True, + ) or [] + return getattr(torrents[0], "downloader", None) if torrents else None + + @classmethod + def _update_download_sync( + cls, + hash_value: str, + action: Optional[str] = None, + tags: Optional[List[str]] = None, + downloader: Optional[str] = None, + download_limit: Optional[float] = None, + upload_limit: Optional[float] = None, + trackers: Optional[List[str]] = None, + save_path: Optional[str] = None, + category: Optional[str] = None, + ratio_limit: Optional[float] = None, + seeding_time_limit: Optional[int] = None, + ) -> Dict[str, Any]: + """同步更新下载任务,避免下载器 SDK 阻塞事件循环。""" + download_chain = DownloadChain() + resolved_downloader = cls._resolve_downloader( + download_chain=download_chain, + hash_value=hash_value, + downloader=downloader, + ) + if not resolved_downloader: + return { + "hash": hash_value, + "downloader": downloader, + "results": [ + cls._build_result("resolve_downloader", False, "未找到下载任务或下载器不可用") + ], + } + + results = [] + if tags: + tag_result = download_chain.set_torrents_tag( + hashs=[hash_value], tags=tags, downloader=resolved_downloader + ) + results.append( + cls._build_result( + "tags", + bool(tag_result), + f"成功设置标签:{', '.join(tags)}" if tag_result else "设置标签失败", + ) + ) + + if action: + action_result = download_chain.set_downloading( + hash_str=hash_value, oper=action, name=resolved_downloader + ) + action_desc = "开始" if action == "start" else "暂停" + results.append( + cls._build_result( + action, + bool(action_result), + f"成功{action_desc}下载任务" if action_result else f"{action_desc}下载任务失败", + ) + ) + + update_result = {} + if cls._has_update_params( + download_limit=download_limit, + upload_limit=upload_limit, + trackers=trackers, + save_path=save_path, + category=category, + ratio_limit=ratio_limit, + seeding_time_limit=seeding_time_limit, + ): + update_result = download_chain.update_torrent( + hash_string=hash_value, + downloader=resolved_downloader, + download_limit=download_limit, + upload_limit=upload_limit, + tracker_list=trackers, + save_path=save_path, + category=category, + ratio_limit=ratio_limit, + seeding_time_limit=seeding_time_limit, + ) + operation_messages = { + "limits": "限速/做种策略", + "trackers": "Tracker", + "save_path": "保存目录", + "category": "分类", + } + for operation, success in (update_result or {}).items(): + label = operation_messages.get(operation, operation) + results.append( + cls._build_result( + operation, + bool(success), + f"{label}修改成功" if success else f"{label}修改失败或下载器不支持", + ) + ) + + return { + "hash": hash_value, + "downloader": resolved_downloader, + "results": results, + } + + def get_tool_message(self, **kwargs) -> Optional[str]: + """根据更新参数生成友好的提示消息。""" + hash_value = kwargs.get("hash", "") + parts = [f"更新下载任务: {hash_value}"] + action = kwargs.get("action") + if action == "start": + parts.append("操作: 开始下载") + elif action == "stop": + parts.append("操作: 暂停下载") + if kwargs.get("tags"): + parts.append(f"标签: {', '.join(kwargs.get('tags'))}") + if kwargs.get("download_limit") is not None or kwargs.get("upload_limit") is not None: + parts.append("限速") + if kwargs.get("trackers") is not None: + parts.append("Tracker") + if kwargs.get("save_path"): + parts.append("保存目录") + if kwargs.get("category") is not None: + parts.append("分类") + if kwargs.get("downloader"): + parts.append(f"下载器: {kwargs.get('downloader')}") + return " | ".join(parts) + + async def run( + self, + hash: str, + action: Optional[str] = None, + tags: Optional[List[str]] = None, + downloader: Optional[str] = None, + download_limit: Optional[float] = None, + upload_limit: Optional[float] = None, + trackers: Optional[List[str]] = None, + save_path: Optional[str] = None, + category: Optional[str] = None, + ratio_limit: Optional[float] = None, + seeding_time_limit: Optional[int] = None, + **kwargs, + ) -> str: + """执行下载任务更新。""" + logger.info( + f"执行工具: {self.name}, 参数: hash={hash}, action={action}, tags={tags}, " + f"downloader={downloader}, download_limit={download_limit}, upload_limit={upload_limit}, " + f"trackers={trackers}, save_path={save_path}, category={category}, " + f"ratio_limit={ratio_limit}, seeding_time_limit={seeding_time_limit}" + ) + try: + if not self._is_valid_hash(hash): + return "参数错误:hash 格式无效,请先使用 query_download_tasks 工具获取正确的 hash。" + + tags = self._normalize_non_empty_list(tags) + trackers = self._normalize_non_empty_list(trackers) + if action and action not in ("start", "stop"): + return f"参数错误:action 只支持 'start'(开始下载)或 'stop'(暂停下载),收到: '{action}'。" + if not self._has_update_params( + action=action, + tags=tags, + download_limit=download_limit, + upload_limit=upload_limit, + trackers=trackers, + save_path=save_path, + category=category, + ratio_limit=ratio_limit, + seeding_time_limit=seeding_time_limit, + ): + return "参数错误:至少需要指定一个要更新的字段。" + + result = await self.run_blocking( + "downloader", + self._update_download_sync, + hash, + action, + tags, + downloader, + download_limit, + upload_limit, + trackers, + save_path, + category, + ratio_limit, + seeding_time_limit, + ) + return json.dumps(result, ensure_ascii=False, indent=2) + except Exception as e: + logger.error(f"更新下载任务失败: {e}", exc_info=True) + return f"更新下载任务时发生错误: {str(e)}" diff --git a/app/chain/__init__.py b/app/chain/__init__.py index 576a9c93..a00ea96d 100644 --- a/app/chain/__init__.py +++ b/app/chain/__init__.py @@ -1353,6 +1353,61 @@ class ChainBase(metaclass=ABCMeta): """ return self.run_module("set_torrents_tag", hashs=hashs, tags=tags, downloader=downloader) + def update_torrent( + self, + hash_string: str, + downloader: Optional[str] = None, + download_limit: Optional[float] = None, + upload_limit: Optional[float] = None, + tracker_list: Optional[list] = None, + save_path: Optional[str] = None, + category: Optional[str] = None, + ratio_limit: Optional[float] = None, + seeding_time_limit: Optional[int] = None, + ) -> Optional[Dict[str, bool]]: + """ + 修改下载任务属性。 + :param hash_string: 种子Hash + :param downloader: 下载器 + :param download_limit: 下载限速,单位 KB/s + :param upload_limit: 上传限速,单位 KB/s + :param tracker_list: Tracker URL列表 + :param save_path: 保存目录 + :param category: 分类 + :param ratio_limit: 分享率限制 + :param seeding_time_limit: 做种时间限制,单位分钟 + :return: 各项修改结果 + """ + return self.run_module( + "update_torrent", + hash_string=hash_string, + downloader=downloader, + download_limit=download_limit, + upload_limit=upload_limit, + tracker_list=tracker_list, + save_path=save_path, + category=category, + ratio_limit=ratio_limit, + seeding_time_limit=seeding_time_limit, + ) + + def get_torrent_trackers( + self, + hash_string: str, + downloader: Optional[str] = None, + ) -> Optional[Dict[str, List[str]]]: + """ + 查询下载任务Tracker列表。 + :param hash_string: 种子Hash + :param downloader: 下载器 + :return: 下载器名称到Tracker列表的映射 + """ + return self.run_module( + "get_torrent_trackers", + hash_string=hash_string, + downloader=downloader, + ) + def torrent_files( self, tid: str, downloader: Optional[str] = None ) -> Optional[Union[TorrentFilesList, List[File]]]: diff --git a/app/modules/qbittorrent/__init__.py b/app/modules/qbittorrent/__init__.py index 70902f66..324e9558 100644 --- a/app/modules/qbittorrent/__init__.py +++ b/app/modules/qbittorrent/__init__.py @@ -307,9 +307,18 @@ class QbittorrentModule(_ModuleBase, _DownloaderBase[Qbittorrent]): year=meta.year, season_episode=meta.season_episode, path=Path(self.normalize_return_path(torrent_path, downloader_name)), + save_path=self.normalize_return_path( + Path(torrent_data.get('save_path') or ""), downloader_name + ) if torrent_data.get('save_path') else None, + content_path=self.normalize_return_path(torrent_path, downloader_name), hash=torrent_data.get('hash'), size=total_size, tags=torrent_data.get('tags'), + category=torrent_data.get('category'), + download_limit=(torrent_data.get('dl_limit') or 0) / 1024, + upload_limit=(torrent_data.get('up_limit') or 0) / 1024, + ratio_limit=torrent_data.get('ratio_limit'), + seeding_time_limit=torrent_data.get('seeding_time_limit'), progress=(torrent_data.get('progress') or 0) * 100, state=self.__normalize_torrent_state(torrent_data.get('state')), dlspeed=StringUtils.str_filesize(dlspeed), @@ -464,6 +473,86 @@ class QbittorrentModule(_ModuleBase, _DownloaderBase[Qbittorrent]): server.set_torrents_tag(ids=hashs, tags=tags) return True + def update_torrent( + self, + hash_string: str, + downloader: Optional[str] = None, + download_limit: Optional[float] = None, + upload_limit: Optional[float] = None, + tracker_list: Optional[list] = None, + save_path: Optional[str] = None, + category: Optional[str] = None, + ratio_limit: Optional[float] = None, + seeding_time_limit: Optional[int] = None, + ) -> Optional[Dict[str, bool]]: + """ + 修改下载任务属性。 + :param hash_string: 种子Hash + :param downloader: 下载器 + :param download_limit: 下载限速,单位 KB/s + :param upload_limit: 上传限速,单位 KB/s + :param tracker_list: Tracker URL列表 + :param save_path: 保存目录 + :param category: 分类 + :param ratio_limit: 分享率限制 + :param seeding_time_limit: 做种时间限制,单位分钟 + :return: 各项修改结果 + """ + server: Qbittorrent = self.get_instance(downloader) + if not server: + return None + results = {} + if any( + value is not None + for value in (download_limit, upload_limit, ratio_limit, seeding_time_limit) + ): + results["limits"] = server.change_torrent( + hash_string=hash_string, + download_limit=download_limit, + upload_limit=upload_limit, + ratio_limit=ratio_limit, + seeding_time_limit=seeding_time_limit, + ) + if tracker_list is not None: + results["trackers"] = server.update_tracker( + hash_string=hash_string, tracker_list=tracker_list + ) + if save_path is not None: + results["save_path"] = server.set_torrent_location( + hash_string=hash_string, + location=self.normalize_path(Path(save_path), downloader), + ) + if category is not None: + results["category"] = server.set_torrent_category( + hash_string=hash_string, category=category + ) + return results + + def get_torrent_trackers( + self, + hash_string: str, + downloader: Optional[str] = None, + ) -> Optional[Dict[str, List[str]]]: + """ + 查询下载任务Tracker列表。 + :param hash_string: 种子Hash + :param downloader: 下载器 + :return: 下载器名称到Tracker列表的映射 + """ + if downloader: + server: Qbittorrent = self.get_instance(downloader) + if not server: + return None + servers = {downloader: server} + else: + servers: Dict[str, Qbittorrent] = self.get_instances() + ret_trackers = {} + for name, server in servers.items(): + trackers = server.get_trackers(hash_string) + if trackers is not None: + ret_trackers[name] = trackers + return ret_trackers + def start_torrents(self, hashs: Union[list, str], downloader: Optional[str] = None) -> Optional[bool]: """ diff --git a/app/modules/qbittorrent/qbittorrent.py b/app/modules/qbittorrent/qbittorrent.py index 11f02a2e..19766c16 100644 --- a/app/modules/qbittorrent/qbittorrent.py +++ b/app/modules/qbittorrent/qbittorrent.py @@ -537,8 +537,8 @@ class Qbittorrent: """ if not self.qbc: return False - download_limit = download_limit * 1024 - upload_limit = upload_limit * 1024 + download_limit = (download_limit or 0) * 1024 + upload_limit = (upload_limit or 0) * 1024 try: self.qbc.transfer.upload_limit = int(upload_limit) self.qbc.transfer.download_limit = int(download_limit) @@ -578,6 +578,87 @@ class Qbittorrent: logger.error(f"重新校验种子出错:{str(err)}") return False + def change_torrent( + self, + hash_string: str, + upload_limit: Optional[float] = None, + download_limit: Optional[float] = None, + ratio_limit: Optional[float] = None, + seeding_time_limit: Optional[int] = None, + ) -> bool: + """ + 修改单个种子的限速和做种策略。 + :param hash_string: 种子Hash + :param upload_limit: 上传限速,单位 KB/s,0 表示不限速 + :param download_limit: 下载限速,单位 KB/s,0 表示不限速 + :param ratio_limit: 分享率限制 + :param seeding_time_limit: 做种时间限制,单位分钟 + :return: 是否修改成功 + """ + if not self.qbc or not hash_string: + return False + try: + if upload_limit is not None: + self.qbc.torrents_set_upload_limit( + limit=int(float(upload_limit) * 1024), + torrent_hashes=hash_string, + ) + if download_limit is not None: + self.qbc.torrents_set_download_limit( + limit=int(float(download_limit) * 1024), + torrent_hashes=hash_string, + ) + if ratio_limit is not None: + self.qbc.torrents_set_share_limits( + ratio_limit=round(float(ratio_limit), 2), + seeding_time_limit=int(seeding_time_limit or -1), + inactive_seeding_time_limit=-1, + torrent_hashes=hash_string, + ) + elif seeding_time_limit is not None: + self.qbc.torrents_set_share_limits( + ratio_limit=-2, + seeding_time_limit=int(seeding_time_limit), + inactive_seeding_time_limit=-1, + torrent_hashes=hash_string, + ) + return True + except Exception as err: + logger.error(f"设置种子属性出错:{str(err)}") + return False + + def set_torrent_location(self, hash_string: str, location: str) -> bool: + """ + 修改种子保存目录。 + :param hash_string: 种子Hash + :param location: 新保存目录 + :return: 是否修改成功 + """ + if not self.qbc or not hash_string or not location: + return False + try: + self.qbc.torrents_set_location(location=location, torrent_hashes=hash_string) + return True + except Exception as err: + logger.error(f"设置种子保存目录出错:{str(err)}") + return False + + def set_torrent_category(self, hash_string: str, category: str) -> bool: + """ + 修改种子分类。 + :param hash_string: 种子Hash + :param category: 分类名称 + :return: 是否修改成功 + """ + if not self.qbc or not hash_string: + return False + try: + self.qbc.torrents_set_category(category=category or "", torrent_hashes=hash_string) + return True + except Exception as err: + logger.error(f"设置种子分类出错:{str(err)}") + return False + def update_tracker(self, hash_string: str, tracker_list: list) -> bool: """ 添加tracker @@ -591,6 +672,25 @@ class Qbittorrent: logger.error(f"修改tracker出错:{str(err)}") return False + def get_trackers(self, hash_string: str) -> Optional[List[str]]: + """ + 获取种子Tracker列表。 + :param hash_string: 种子Hash + :return: Tracker URL列表 + """ + if not self.qbc or not hash_string: + return None + try: + trackers = self.qbc.torrents_trackers(torrent_hash=hash_string) or [] + return [ + tracker.get("url") + for tracker in trackers + if tracker.get("url") + ] + except Exception as err: + logger.error(f"获取tracker出错:{str(err)}") + return None + def get_content_layout(self) -> Optional[str]: """ 获取内容布局 diff --git a/app/modules/rtorrent/__init__.py b/app/modules/rtorrent/__init__.py index fb078e3c..f8719046 100644 --- a/app/modules/rtorrent/__init__.py +++ b/app/modules/rtorrent/__init__.py @@ -338,6 +338,10 @@ class RtorrentModule(_ModuleBase, _DownloaderBase[Rtorrent]): year=meta.year, season_episode=meta.season_episode, path=Path(self.normalize_return_path(torrent_path, downloader_name)), + save_path=self.normalize_return_path( + Path(torrent_data.get("save_path") or ""), downloader_name + ) if torrent_data.get("save_path") else None, + content_path=self.normalize_return_path(torrent_path, downloader_name), progress=torrent_data.get("progress", 0), size=total_size, state=self.__normalize_torrent_state( @@ -522,6 +526,54 @@ class RtorrentModule(_ModuleBase, _DownloaderBase[Rtorrent]): return None return server.set_torrents_tag(ids=hashs, tags=tags) + def update_torrent( + self, + hash_string: str, + downloader: Optional[str] = None, + download_limit: Optional[float] = None, + upload_limit: Optional[float] = None, + tracker_list: Optional[list] = None, + save_path: Optional[str] = None, + category: Optional[str] = None, + ratio_limit: Optional[float] = None, + seeding_time_limit: Optional[int] = None, + ) -> Optional[Dict[str, bool]]: + """ + 修改下载任务属性。 + :param hash_string: 种子Hash + :param downloader: 下载器 + :param download_limit: 下载限速,单位 KB/s + :param upload_limit: 上传限速,单位 KB/s + :param tracker_list: Tracker URL列表,rTorrent 当前封装不支持 + :param save_path: 保存目录 + :param category: 分类,rTorrent 当前封装不支持 + :param ratio_limit: 分享率限制,rTorrent 当前封装不支持 + :param seeding_time_limit: 做种时间限制,rTorrent 当前封装不支持 + :return: 各项修改结果 + """ + server: Rtorrent = self.get_instance(downloader) + if not server: + return None + results = {} + if download_limit is not None or upload_limit is not None: + results["limits"] = server.change_torrent( + hash_string=hash_string, + download_limit=download_limit, + upload_limit=upload_limit, + ) + if ratio_limit is not None or seeding_time_limit is not None: + results["seeding_limits"] = False + if tracker_list is not None: + results["trackers"] = False + if save_path is not None: + results["save_path"] = server.set_torrent_location( + hash_string=hash_string, + location=self.normalize_path(Path(save_path), downloader), + ) + if category is not None: + results["category"] = False + return results + def start_torrents( self, hashs: Union[list, str], downloader: Optional[str] = None ) -> Optional[bool]: diff --git a/app/modules/rtorrent/rtorrent.py b/app/modules/rtorrent/rtorrent.py index 672c4525..09fc978d 100644 --- a/app/modules/rtorrent/rtorrent.py +++ b/app/modules/rtorrent/rtorrent.py @@ -530,6 +530,62 @@ class Rtorrent: break return torrent_id + @staticmethod + def __build_throttle_name(torrent_hash: str) -> str: + """ + 生成单任务限速组名称。 + """ + return f"mp_{torrent_hash.lower()[:16]}" + + def change_torrent( + self, + hash_string: str, + upload_limit: Optional[float] = None, + download_limit: Optional[float] = None, + ) -> bool: + """ + 修改单个种子的上传和下载限速。 + :param hash_string: 种子Hash + :param upload_limit: 上传限速,单位 KB/s,0 表示不限速 + :param download_limit: 下载限速,单位 KB/s,0 表示不限速 + :return: 是否修改成功 + """ + if not self._proxy or not hash_string: + return False + try: + throttle_name = self.__build_throttle_name(hash_string) + if download_limit is not None: + self._proxy.throttle.down.max.set( + throttle_name, + int(float(download_limit) * 1024), + ) + if upload_limit is not None: + self._proxy.throttle.up.max.set( + throttle_name, + int(float(upload_limit) * 1024), + ) + self._proxy.d.throttle_name.set(hash_string, throttle_name) + return True + except Exception as err: + logger.error(f"设置种子限速出错:{str(err)}") + return False + + def set_torrent_location(self, hash_string: str, location: str) -> bool: + """ + 修改种子保存目录。 + :param hash_string: 种子Hash + :param location: 新保存目录 + :return: 是否修改成功 + """ + if not self._proxy or not hash_string or not location: + return False + try: + self._proxy.d.directory.set(hash_string, location) + return True + except Exception as err: + logger.error(f"设置种子保存目录出错:{str(err)}") + return False + def transfer_info(self) -> Optional[Dict]: """ 获取传输信息 diff --git a/app/modules/transmission/__init__.py b/app/modules/transmission/__init__.py index bbae4358..82c5c113 100644 --- a/app/modules/transmission/__init__.py +++ b/app/modules/transmission/__init__.py @@ -310,6 +310,8 @@ class TransmissionModule(_ModuleBase, _DownloaderBase[Transmission]): torrent_data, "left_until_done", "leftUntilDone" ) or 0 torrent_path = __get_torrent_path(torrent_data) + ratio_limit = __get_torrent_attr(torrent_data, "seed_ratio_limit", "seedRatioLimit") + seeding_time_limit = __get_torrent_attr(torrent_data, "seed_idle_limit", "seedIdleLimit") return DownloaderTorrent( downloader=downloader_name, hash=torrent_data.hashString, @@ -318,12 +320,20 @@ class TransmissionModule(_ModuleBase, _DownloaderBase[Transmission]): year=meta.year, season_episode=meta.season_episode, path=Path(self.normalize_return_path(torrent_path, downloader_name)), + save_path=self.normalize_return_path( + Path(torrent_data.download_dir), downloader_name + ) if getattr(torrent_data, "download_dir", None) else None, + content_path=self.normalize_return_path(torrent_path, downloader_name), progress=__get_torrent_progress(torrent_data), size=__get_torrent_size(torrent_data), state=self.__normalize_torrent_state(torrent_data.status), dlspeed=StringUtils.str_filesize(dlspeed), upspeed=StringUtils.str_filesize(upspeed), tags=__get_torrent_labels(torrent_data), + download_limit=__get_torrent_attr(torrent_data, "download_limit", "downloadLimit"), + upload_limit=__get_torrent_attr(torrent_data, "upload_limit", "uploadLimit"), + ratio_limit=ratio_limit, + seeding_time_limit=seeding_time_limit, left_time=StringUtils.str_secends( left_until_done / dlspeed ) if dlspeed > 0 else '' @@ -491,6 +501,85 @@ class TransmissionModule(_ModuleBase, _DownloaderBase[Transmission]): org_tags = server.get_torrent_tags(ids=hashs) return server.set_torrent_tag(ids=hashs, tags=tags, org_tags=org_tags) + def update_torrent( + self, + hash_string: str, + downloader: Optional[str] = None, + download_limit: Optional[float] = None, + upload_limit: Optional[float] = None, + tracker_list: Optional[list] = None, + save_path: Optional[str] = None, + category: Optional[str] = None, + ratio_limit: Optional[float] = None, + seeding_time_limit: Optional[int] = None, + ) -> Optional[Dict[str, bool]]: + """ + 修改下载任务属性。 + :param hash_string: 种子Hash + :param downloader: 下载器 + :param download_limit: 下载限速,单位 KB/s + :param upload_limit: 上传限速,单位 KB/s + :param tracker_list: Tracker URL列表 + :param save_path: 保存目录 + :param category: 分类,Transmission 不支持 + :param ratio_limit: 分享率限制 + :param seeding_time_limit: 做种时间限制,单位分钟 + :return: 各项修改结果 + """ + server: Transmission = self.get_instance(downloader) + if not server: + return None + results = {} + if any( + value is not None + for value in (download_limit, upload_limit, ratio_limit, seeding_time_limit) + ): + change_result = server.change_torrent( + hash_string=hash_string, + download_limit=download_limit, + upload_limit=upload_limit, + ratio_limit=ratio_limit, + seeding_time_limit=seeding_time_limit, + ) + results["limits"] = change_result + if save_path is not None: + results["save_path"] = server.set_torrent_location( + hash_string=hash_string, + location=self.normalize_path(Path(save_path), downloader), + ) + if tracker_list is not None: + results["trackers"] = server.update_tracker( + hash_string=hash_string, tracker_list=tracker_list + ) + if category is not None: + results["category"] = False + return results + + def get_torrent_trackers( + self, + hash_string: str, + downloader: Optional[str] = None, + ) -> Optional[Dict[str, List[str]]]: + """ + 查询下载任务Tracker列表。 + :param hash_string: 种子Hash + :param downloader: 下载器 + :return: 下载器名称到Tracker列表的映射 + """ + if downloader: + server: Transmission = self.get_instance(downloader) + if not server: + return None + servers = {downloader: server} + else: + servers: Dict[str, Transmission] = self.get_instances() + ret_trackers = {} + for name, server in servers.items(): + trackers = server.get_trackers(hash_string) + if trackers is not None: + ret_trackers[name] = trackers + return ret_trackers + def start_torrents(self, hashs: Union[list, str], downloader: Optional[str] = None) -> Optional[bool]: """ diff --git a/app/modules/transmission/transmission.py b/app/modules/transmission/transmission.py index e91fb3fc..72293ba7 100755 --- a/app/modules/transmission/transmission.py +++ b/app/modules/transmission/transmission.py @@ -402,45 +402,46 @@ class Transmission: """ if not hash_string: return False - if upload_limit: - uploadLimited = True - uploadLimit = int(upload_limit) - else: - uploadLimited = False - uploadLimit = 0 - if download_limit: - downloadLimited = True - downloadLimit = int(download_limit) - else: - downloadLimited = False - downloadLimit = 0 - if ratio_limit: - seedRatioMode = 1 - seedRatioLimit = round(float(ratio_limit), 2) - else: - seedRatioMode = 2 - seedRatioLimit = 0 - if seeding_time_limit: - seedIdleMode = 1 - seedIdleLimit = int(seeding_time_limit) - else: - seedIdleMode = 2 - seedIdleLimit = 0 + change_kwargs = {"ids": hash_string} + if upload_limit is not None: + change_kwargs["uploadLimited"] = bool(upload_limit) + change_kwargs["uploadLimit"] = int(upload_limit) + if download_limit is not None: + change_kwargs["downloadLimited"] = bool(download_limit) + change_kwargs["downloadLimit"] = int(download_limit) + if ratio_limit is not None: + change_kwargs["seedRatioMode"] = 1 if ratio_limit else 2 + change_kwargs["seedRatioLimit"] = round(float(ratio_limit), 2) if ratio_limit else 0 + if seeding_time_limit is not None: + change_kwargs["seedIdleMode"] = 1 if seeding_time_limit else 2 + change_kwargs["seedIdleLimit"] = int(seeding_time_limit) if seeding_time_limit else 0 try: - self.trc.change_torrent(ids=hash_string, - uploadLimited=uploadLimited, - uploadLimit=uploadLimit, - downloadLimited=downloadLimited, - downloadLimit=downloadLimit, - seedRatioMode=seedRatioMode, - seedRatioLimit=seedRatioLimit, - seedIdleMode=seedIdleMode, - seedIdleLimit=seedIdleLimit) + self.trc.change_torrent(**change_kwargs) return True except Exception as err: logger.error(f"设置种子出错:{str(err)}") return False + def set_torrent_location(self, hash_string: str, location: str) -> bool: + """ + 修改种子保存目录。 + :param hash_string: 种子Hash + :param location: 新保存目录 + :return: 是否修改成功 + """ + if not self.trc or not hash_string or not location: + return False + try: + move_torrent_data = getattr(self.trc, "move_torrent_data", None) + if callable(move_torrent_data): + move_torrent_data(ids=hash_string, location=location) + else: + self.trc.change_torrent(ids=hash_string, download_dir=location) + return True + except Exception as err: + logger.error(f"设置种子保存目录出错:{str(err)}") + return False + def update_tracker(self, hash_string: str, tracker_list: list = None) -> bool: """ tr4.0及以上弃用直接设置tracker 共用change方法 @@ -456,6 +457,34 @@ class Transmission: logger.error(f"修改tracker出错:{str(err)}") return False + def get_trackers(self, hash_string: str) -> Optional[List[str]]: + """ + 获取种子Tracker列表。 + :param hash_string: 种子Hash + :return: Tracker URL列表 + """ + if not self.trc or not hash_string: + return None + try: + torrents = self.trc.get_torrents(ids=hash_string, arguments=self._trarg) + if not torrents: + return [] + torrent = torrents[0] + tracker_list = getattr(torrent, "tracker_list", None) \ + or getattr(torrent, "trackerList", None) \ + or [] + if tracker_list: + return list(tracker_list) + trackers = getattr(torrent, "trackers", None) or [] + return [ + tracker.get("announce") + for tracker in trackers + if isinstance(tracker, dict) and tracker.get("announce") + ] + except Exception as err: + logger.error(f"获取tracker出错:{str(err)}") + return None + def get_session(self) -> Optional[Session]: """ 获取Transmission当前的会话信息和配置设置 diff --git a/app/schemas/transfer.py b/app/schemas/transfer.py index e9f343e1..46cfeb65 100644 --- a/app/schemas/transfer.py +++ b/app/schemas/transfer.py @@ -27,6 +27,14 @@ class DownloaderTorrent(BaseModel): upspeed: Optional[str] = None dlspeed: Optional[str] = None tags: Optional[str] = None + save_path: Optional[str] = None + content_path: Optional[str] = None + category: Optional[str] = None + download_limit: Optional[float] = None + upload_limit: Optional[float] = None + ratio_limit: Optional[float] = None + seeding_time_limit: Optional[int] = None + trackers: Optional[List[str]] = Field(default_factory=list) media: Optional[dict] = Field(default_factory=dict) userid: Optional[str] = None username: Optional[str] = None diff --git a/docs/mcp-api.md b/docs/mcp-api.md index 0c0654ec..04102ba5 100644 --- a/docs/mcp-api.md +++ b/docs/mcp-api.md @@ -236,6 +236,11 @@ MoviePilot 也提供普通 REST API 给前端和自动化客户端使用。所 `recognize_captcha` 用于浏览器自动化登录时识别普通图形验证码。智能体可先通过 `browse_webpage` 的 `evaluate` 动作从页面元素中提取 `img.src`,再把图片地址传给该工具;支持 `http/https` 图片地址和 `data:image/...;base64,...`。当验证码图片依赖当前浏览器会话时,可传入 Cookie 与 User-Agent。出于安全考虑,默认拒绝访问 localhost、环回地址、私网地址和链路本地地址;确需访问可信内网或本机验证码图片时,可显式传入 `allow_private_network: true`。 +**下载任务工具说明**: + +- `query_download_tasks` 用于查询下载任务,支持按下载器、状态、Hash、标题、标签过滤;返回保存目录、内容路径、上传/下载速度、上传/下载限速、分类、分享率、做种时间等下载器可提供的字段。按 `hash` 查询或传入 `include_trackers=true` 时,会尽量返回 Tracker URL 列表。 +- `update_download_tasks` 用于修改下载任务,统一支持 `start`/`stop`、标签、上传/下载限速、Tracker、保存目录、分类、分享率、做种时间等字段;具体字段是否成功取决于下载器能力,返回结果会按操作项逐条标记成功或失败。 + ### 3. 获取工具详情 **GET** `/api/v1/mcp/tools/{tool_name}` diff --git a/skills/moviepilot-cli/SKILL.md b/skills/moviepilot-cli/SKILL.md index 74ebc89d..ebd6b01c 100644 --- a/skills/moviepilot-cli/SKILL.md +++ b/skills/moviepilot-cli/SKILL.md @@ -24,7 +24,7 @@ Always run `show ` before calling a command — parameter names are not |---|---| | Media Search | search_media, recognize_media, query_media_detail, get_recommendations, search_person, search_person_credits | | Torrent | search_torrents, get_search_results | -| Download | add_download, query_download_tasks, delete_download, query_downloaders | +| Download | add_download, query_download_tasks, update_download_tasks, delete_download, query_downloaders | | Subscription | add_subscribe, query_subscribes, update_subscribe, delete_subscribe, search_subscribe, query_subscribe_history, query_popular_subscribes, query_subscribe_shares | | Library | query_library_exists, query_library_latest, transfer_file, scrape_metadata, query_transfer_history | | Files | list_directory, query_directory_settings | @@ -126,7 +126,13 @@ Subscribe starting from a specific episode: List download tasks and get hash for further operations: `node scripts/mp-cli.js query_download_tasks status=downloading` -Use `status=completed` for tasks that are neither downloading nor paused in the downloader; use `status=all` to include every MoviePilot-tagged downloader task. Add `include_all_tags=true` when diagnosing tasks that do not have the MoviePilot built-in tag. +Use `status=completed` for tasks that are neither downloading nor paused in the downloader; use `status=all` to include every MoviePilot-tagged downloader task. Add `include_all_tags=true` when diagnosing tasks that do not have the MoviePilot built-in tag. Add `include_trackers=true` or query by `hash` when tracker URLs are needed. + +Update a download task (supports start/stop, tags, speed limits, trackers, save path, category, ratio, and seeding time where the downloader supports them): +`node scripts/mp-cli.js update_download_tasks hash= action=stop upload_limit=512 download_limit=2048` + +Add trackers to a download task: +`node scripts/mp-cli.js update_download_tasks hash= trackers='https://tracker.example/announce,udp://tracker.example:80/announce'` Delete a download task (confirm with user first — irreversible): `node scripts/mp-cli.js delete_download hash=` diff --git a/tests/test_agent_query_download_tasks_tool.py b/tests/test_agent_query_download_tasks_tool.py index 9cfbe85e..c3be862d 100644 --- a/tests/test_agent_query_download_tasks_tool.py +++ b/tests/test_agent_query_download_tasks_tool.py @@ -64,6 +64,14 @@ def test_run_completed_status_formats_completed_download_tasks(): progress=100, state="completed", tags="moviepilot", + save_path="/downloads", + content_path="/downloads/QB Done", + category="电影", + download_limit=1024, + upload_limit=512, + ratio_limit=2.0, + seeding_time_limit=1440, + trackers=["https://tracker.example/announce"], ) ] @@ -81,6 +89,50 @@ def test_run_completed_status_formats_completed_download_tasks(): payload = json.loads(result) assert payload[0]["hash"] == "hash-qb" assert payload[0]["state"] == "completed" + assert payload[0]["save_path"] == "/downloads" + assert payload[0]["content_path"] == "/downloads/QB Done" + assert payload[0]["category"] == "电影" + assert payload[0]["download_limit"] == 1024 + assert payload[0]["upload_limit"] == 512 + assert payload[0]["ratio_limit"] == 2.0 + assert payload[0]["seeding_time_limit"] == 1440 + assert payload[0]["trackers"] == ["https://tracker.example/announce"] + + +def test_hash_query_loads_trackers_for_matching_task(): + """ + 按 Hash 查询详情时应额外加载下载器支持的 Tracker 列表。 + """ + torrent = DownloaderTorrent( + downloader="qb", + hash="a" * 40, + title="Task With Trackers", + size=1024, + progress=10, + state="downloading", + tags="moviepilot", + ) + download_chain = MagicMock() + download_chain.list_torrents.return_value = [torrent] + download_chain.get_torrent_trackers.return_value = { + "qb": ["https://tracker.example/announce"] + } + + with patch( + "app.agent.tools.impl.query_download_tasks.DownloadChain", + return_value=download_chain, + ), patch.object( + QueryDownloadTasksTool, + "_load_history_map", + return_value={}, + ): + result = QueryDownloadTasksTool._query_downloads_sync(hash_value="a" * 40) + + assert result["downloads"][0].trackers == ["https://tracker.example/announce"] + download_chain.get_torrent_trackers.assert_called_once_with( + hash_string="a" * 40, + downloader="qb", + ) def test_include_all_tags_passes_scope_to_downloader_query(): diff --git a/tests/test_agent_update_download_tasks_tool.py b/tests/test_agent_update_download_tasks_tool.py new file mode 100644 index 00000000..4ab3d221 --- /dev/null +++ b/tests/test_agent_update_download_tasks_tool.py @@ -0,0 +1,192 @@ +import asyncio +import json +from unittest.mock import MagicMock, patch + +from app.agent.tools.factory import MoviePilotToolFactory +from app.agent.tools.impl.update_download_tasks import UpdateDownloadTasksTool +from app.schemas import DownloaderTorrent + + +def test_update_download_tasks_resolves_downloader_and_updates_all_supported_fields(): + """ + 未显式传下载器时,应先按 Hash 解析任务所属下载器,再一次性执行多项修改。 + """ + hash_value = "a" * 40 + download_chain = MagicMock() + download_chain.list_torrents.return_value = [ + DownloaderTorrent(downloader="qb", hash=hash_value, title="Demo") + ] + download_chain.set_torrents_tag.return_value = True + download_chain.set_downloading.return_value = True + download_chain.update_torrent.return_value = { + "limits": True, + "trackers": True, + "save_path": True, + "category": True, + } + + with patch( + "app.agent.tools.impl.update_download_tasks.DownloadChain", + return_value=download_chain, + ): + result = UpdateDownloadTasksTool._update_download_sync( + hash_value=hash_value, + action="stop", + tags=["movie", "hd"], + download_limit=1024, + upload_limit=512, + trackers=["https://tracker.example/announce"], + save_path="/downloads/new", + category="电影", + ratio_limit=2.5, + seeding_time_limit=1440, + ) + + assert result["downloader"] == "qb" + assert {item["operation"] for item in result["results"]} == { + "tags", + "stop", + "limits", + "trackers", + "save_path", + "category", + } + assert all(item["success"] for item in result["results"]) + download_chain.list_torrents.assert_called_once_with( + hashs=[hash_value], + include_all_tags=True, + ) + download_chain.set_torrents_tag.assert_called_once_with( + hashs=[hash_value], + tags=["movie", "hd"], + downloader="qb", + ) + download_chain.set_downloading.assert_called_once_with( + hash_str=hash_value, + oper="stop", + name="qb", + ) + download_chain.update_torrent.assert_called_once_with( + hash_string=hash_value, + downloader="qb", + download_limit=1024, + upload_limit=512, + tracker_list=["https://tracker.example/announce"], + save_path="/downloads/new", + category="电影", + ratio_limit=2.5, + seeding_time_limit=1440, + ) + + +def test_update_download_tasks_skips_property_update_when_only_action_is_requested(): + """ + 仅开始或暂停任务时,不应额外调用属性修改接口。 + """ + hash_value = "e" * 40 + download_chain = MagicMock() + download_chain.list_torrents.return_value = [ + DownloaderTorrent(downloader="tr", hash=hash_value, title="Demo") + ] + download_chain.set_downloading.return_value = True + + with patch( + "app.agent.tools.impl.update_download_tasks.DownloadChain", + return_value=download_chain, + ): + result = UpdateDownloadTasksTool._update_download_sync( + hash_value=hash_value, + action="start", + ) + + assert result["results"] == [ + {"operation": "start", "success": True, "message": "成功开始下载任务"} + ] + download_chain.update_torrent.assert_not_called() + + +def test_update_download_tasks_reports_missing_task_when_downloader_cannot_be_resolved(): + """ + 找不到任务时,应返回明确的解析失败结果。 + """ + hash_value = "b" * 40 + download_chain = MagicMock() + download_chain.list_torrents.return_value = [] + + with patch( + "app.agent.tools.impl.update_download_tasks.DownloadChain", + return_value=download_chain, + ): + result = UpdateDownloadTasksTool._update_download_sync( + hash_value=hash_value, + action="start", + ) + + assert result["results"] == [ + { + "operation": "resolve_downloader", + "success": False, + "message": "未找到下载任务或下载器不可用", + } + ] + download_chain.set_downloading.assert_not_called() + download_chain.update_torrent.assert_not_called() + + +def test_update_download_tasks_run_rejects_empty_update(): + """ + 没有任何修改字段时,应拒绝调用下载器。 + """ + result = asyncio.run( + UpdateDownloadTasksTool(session_id="session-1", user_id="10001").run( + hash="c" * 40 + ) + ) + + assert "至少需要指定一个要更新的字段" in result + + +def test_update_download_tasks_run_outputs_structured_result(): + """ + 工具运行结果应是结构化 JSON,方便 Agent 判断每项修改是否成功。 + """ + with patch.object( + UpdateDownloadTasksTool, + "_update_download_sync", + return_value={ + "hash": "d" * 40, + "downloader": "tr", + "results": [ + {"operation": "limits", "success": True, "message": "限速/做种策略修改成功"} + ], + }, + ): + result = asyncio.run( + UpdateDownloadTasksTool(session_id="session-1", user_id="10001").run( + hash="d" * 40, + download_limit=100, + ) + ) + + payload = json.loads(result) + assert payload["downloader"] == "tr" + assert payload["results"][0]["operation"] == "limits" + + +def test_factory_registers_update_download_tasks_without_old_modify_name(): + """ + 工具工厂应只暴露统一后的下载任务更新工具名。 + """ + with patch( + "app.agent.tools.factory.PluginManager.get_plugin_agent_tools", + return_value=[], + ): + tools = MoviePilotToolFactory.create_tools( + session_id="download-task-session", + user_id="10001", + ) + + tool_names = {tool.name for tool in tools} + assert "query_download_tasks" in tool_names + assert "update_download_tasks" in tool_names + assert "modify_download" not in tool_names diff --git a/tests/test_qbittorrent_compat.py b/tests/test_qbittorrent_compat.py index 7fea30ec..8a7b490f 100644 --- a/tests/test_qbittorrent_compat.py +++ b/tests/test_qbittorrent_compat.py @@ -492,3 +492,17 @@ def test_download_falls_back_to_tag_lookup_when_added_ids_missing(): assert result == ("qb", "def456", "Original", "添加下载成功") fake_server.delete_torrents_tag.assert_not_called() fake_server.get_torrent_id_by_tag.assert_called_once_with(tags="tmp-tag-01") + + +def test_set_speed_limit_allows_single_direction_limit(): + """ + 设置全局限速时允许只传一个方向,未传方向按不限速处理。 + """ + fake_client = MagicMock() + + with patch.object(Qbittorrent, "_Qbittorrent__login_qbittorrent", return_value=fake_client): + downloader = Qbittorrent(host="http://127.0.0.1", port=8080, username="admin", password="adminadmin") + + assert downloader.set_speed_limit(download_limit=1024) + assert fake_client.transfer.download_limit == 1024 * 1024 + assert fake_client.transfer.upload_limit == 0 diff --git a/tests/test_rtorrent_compat.py b/tests/test_rtorrent_compat.py new file mode 100644 index 00000000..d812e121 --- /dev/null +++ b/tests/test_rtorrent_compat.py @@ -0,0 +1,135 @@ +import importlib.util +import sys +import types +from pathlib import Path +from unittest.mock import MagicMock, patch + + +def _load_rtorrent_client_module(): + """ + 使用轻量桩加载 rTorrent 客户端封装,避免测试依赖完整应用启动。 + """ + repo_root = Path(__file__).resolve().parents[1] + + app_module = types.ModuleType("app") + app_module.__path__ = [] + log_module = types.ModuleType("app.log") + + class _Logger: + """ + 测试日志桩,仅提供被客户端封装调用的方法。 + """ + + def info(self, *_args, **_kwargs): + """ + 忽略信息日志。 + """ + pass + + def warning(self, *_args, **_kwargs): + """ + 忽略警告日志。 + """ + pass + + def error(self, *_args, **_kwargs): + """ + 忽略错误日志。 + """ + pass + + log_module.logger = _Logger() + app_module.log = log_module + + stub_modules = { + "app": app_module, + "app.log": log_module, + } + + rtorrent_path = repo_root / "app" / "modules" / "rtorrent" / "rtorrent.py" + rtorrent_spec = importlib.util.spec_from_file_location( + "app.modules.rtorrent.rtorrent", + rtorrent_path, + ) + rtorrent_module = importlib.util.module_from_spec(rtorrent_spec) + assert rtorrent_spec and rtorrent_spec.loader + + with patch.dict(sys.modules, stub_modules): + rtorrent_spec.loader.exec_module(rtorrent_module) + + return rtorrent_module + + +rtorrent_module = _load_rtorrent_client_module() +Rtorrent = rtorrent_module.Rtorrent + + +def test_change_torrent_sets_per_task_speed_limits(): + """ + rTorrent 单任务限速应创建限速组并绑定到任务。 + """ + downloader = Rtorrent.__new__(Rtorrent) + fake_proxy = MagicMock() + downloader._proxy = fake_proxy + + assert downloader.change_torrent( + hash_string="ABCDEF1234567890ABCDEF1234567890ABCDEF12", + download_limit=1024, + upload_limit=512, + ) + + fake_proxy.throttle.down.max.set.assert_called_once_with( + "mp_abcdef1234567890", + 1024 * 1024, + ) + fake_proxy.throttle.up.max.set.assert_called_once_with( + "mp_abcdef1234567890", + 512 * 1024, + ) + fake_proxy.d.throttle_name.set.assert_called_once_with( + "ABCDEF1234567890ABCDEF1234567890ABCDEF12", + "mp_abcdef1234567890", + ) + + +def test_change_torrent_allows_zero_limit_to_disable_limit(): + """ + rTorrent 单任务限速传 0 时应写入 0,表示关闭对应限速。 + """ + downloader = Rtorrent.__new__(Rtorrent) + fake_proxy = MagicMock() + downloader._proxy = fake_proxy + + assert downloader.change_torrent( + hash_string="ABCDEF1234567890ABCDEF1234567890ABCDEF12", + download_limit=0, + ) + + fake_proxy.throttle.down.max.set.assert_called_once_with( + "mp_abcdef1234567890", + 0, + ) + fake_proxy.throttle.up.max.set.assert_not_called() + fake_proxy.d.throttle_name.set.assert_called_once_with( + "ABCDEF1234567890ABCDEF1234567890ABCDEF12", + "mp_abcdef1234567890", + ) + + +def test_set_torrent_location_updates_directory(): + """ + rTorrent 保存目录修改应调用 d.directory.set。 + """ + downloader = Rtorrent.__new__(Rtorrent) + fake_proxy = MagicMock() + downloader._proxy = fake_proxy + + assert downloader.set_torrent_location( + hash_string="ABCDEF1234567890ABCDEF1234567890ABCDEF12", + location="/downloads/new", + ) + + fake_proxy.d.directory.set.assert_called_once_with( + "ABCDEF1234567890ABCDEF1234567890ABCDEF12", + "/downloads/new", + ) diff --git a/tests/test_transmission_compat.py b/tests/test_transmission_compat.py index 5b11d939..d5c14fc6 100644 --- a/tests/test_transmission_compat.py +++ b/tests/test_transmission_compat.py @@ -164,3 +164,73 @@ def test_get_files_falls_back_to_legacy_files_method(): assert downloader.get_files("1") == torrent_files fake_client.get_torrent.assert_called_once_with("1") torrent.files.assert_called_once_with() + + +def test_change_torrent_only_sends_explicit_fields(): + """ + 修改单个任务时只能写入显式传入的策略字段。 + """ + downloader = Transmission.__new__(Transmission) + fake_client = MagicMock() + downloader.trc = fake_client + + assert downloader.change_torrent("hash", ratio_limit=2.5) + + fake_client.change_torrent.assert_called_once_with( + ids="hash", + seedRatioMode=1, + seedRatioLimit=2.5, + ) + + +def test_change_torrent_disables_speed_limit_with_zero_value(): + """ + 单任务限速传 0 时应显式关闭对应限速。 + """ + downloader = Transmission.__new__(Transmission) + fake_client = MagicMock() + downloader.trc = fake_client + + assert downloader.change_torrent("hash", download_limit=0, upload_limit=512) + + fake_client.change_torrent.assert_called_once_with( + ids="hash", + uploadLimited=True, + uploadLimit=512, + downloadLimited=False, + downloadLimit=0, + ) + + +def test_set_torrent_location_prefers_move_torrent_data(): + """ + Transmission 修改保存目录应优先使用移动数据接口。 + """ + downloader = Transmission.__new__(Transmission) + fake_client = MagicMock() + downloader.trc = fake_client + + assert downloader.set_torrent_location("hash", "/downloads/new") + + fake_client.move_torrent_data.assert_called_once_with( + ids="hash", + location="/downloads/new", + ) + fake_client.change_torrent.assert_not_called() + + +def test_set_torrent_location_falls_back_to_change_torrent(): + """ + 旧版 transmission-rpc 没有移动数据接口时回退到 change_torrent。 + """ + downloader = Transmission.__new__(Transmission) + fake_client = MagicMock() + fake_client.move_torrent_data = None + downloader.trc = fake_client + + assert downloader.set_torrent_location("hash", "/downloads/new") + + fake_client.change_torrent.assert_called_once_with( + ids="hash", + download_dir="/downloads/new", + )