From d2103f91b82308ade34f8428f364cc2fb2443401 Mon Sep 17 00:00:00 2001 From: jxxghp Date: Sat, 20 Jun 2026 12:59:37 +0800 Subject: [PATCH] fix(message): keep auto directory matching for interactive downloads --- app/chain/message.py | 76 ++++++++++++++++++++++-- tests/test_media_interaction.py | 101 +++++++++++++++++++++++++++++--- 2 files changed, 164 insertions(+), 13 deletions(-) diff --git a/app/chain/message.py b/app/chain/message.py index 13ed616c..884f2944 100644 --- a/app/chain/message.py +++ b/app/chain/message.py @@ -33,6 +33,7 @@ from app.helper.torrent import TorrentHelper from app.log import logger from app.schemas import CommingMessage, DownloadDirectory, FileURI, NotExistMediaInfo, Notification from app.schemas.message import ChannelCapabilityManager, ChannelCapability +from app.schemas.system import TransferDirectoryConf from app.schemas.types import EventType, MessageChannel, MediaType from app.utils.http import RequestUtils from app.utils.string import StringUtils @@ -1908,6 +1909,7 @@ class MediaInteractionChain(ChainBase): _button_page_size = 8 _text_page_size = 8 + _auto_download_dir_name = "自动匹配目录" @staticmethod def has_pending_interaction(user_id: Union[str, int]) -> bool: @@ -2646,7 +2648,8 @@ class MediaInteractionChain(ChainBase): """ 在下载前进入目录选择阶段;没有配置下载目录时保持原下载流程。 """ - download_dirs = self._get_download_dirs() + media_info = context.media_info if context else request.current_media + download_dirs = self._get_download_dirs(media_info) if not download_dirs: return False @@ -2704,6 +2707,17 @@ class MediaInteractionChain(ChainBase): return download_dir = page_items[page_index - 1] + if self._is_auto_download_dir(download_dir): + self._execute_pending_download( + request=request, + channel=channel, + source=source, + userid=userid, + username=username, + save_path=None, + ) + return + save_path = download_dir.save_path or download_dir.download_path if not save_path: self._post_invalid_input( @@ -2730,7 +2744,7 @@ class MediaInteractionChain(ChainBase): source: str, userid: Union[str, int], username: str, - save_path: str, + save_path: Optional[str], ) -> None: """ 使用用户确认的下载目录执行单资源下载或自动择优下载。 @@ -3250,12 +3264,12 @@ class MediaInteractionChain(ChainBase): end = start + page_size return items[start:end], page, total_pages - @staticmethod - def _get_download_dirs() -> List[DownloadDirectory]: + @classmethod + def _get_download_dirs(cls, media_info: Optional[MediaInfo] = None) -> List[DownloadDirectory]: """ 获取可供消息交互选择的下载目录。 """ - return [ + download_dirs = [ DownloadDirectory( name=dir_info.name, storage=dir_info.storage or "local", @@ -3269,8 +3283,58 @@ class MediaInteractionChain(ChainBase): media_category=dir_info.media_category, ) for dir_info in DirectoryHelper().get_download_dirs() - if dir_info.download_path + if dir_info.download_path and cls._match_download_dir_media(dir_info, media_info) ] + if not download_dirs: + return [] + return [cls._build_auto_download_dir(), *download_dirs] + + @classmethod + def _build_auto_download_dir(cls) -> DownloadDirectory: + """ + 构造自动匹配下载目录选项。 + """ + return DownloadDirectory( + name=cls._auto_download_dir_name, + storage="local", + priority=-1, + ) + + @classmethod + def _is_auto_download_dir(cls, download_dir: DownloadDirectory) -> bool: + """ + 判断是否为自动匹配下载目录选项。 + """ + return ( + download_dir.name == cls._auto_download_dir_name + and not download_dir.download_path + and not download_dir.save_path + ) + + @staticmethod + def _match_download_dir_media( + dir_info: TransferDirectoryConf, + media_info: Optional[MediaInfo], + ) -> bool: + """ + 判断下载目录是否适用于当前媒体。 + """ + if not media_info or not media_info.type: + return True + + if dir_info.media_type: + media_type_values = ( + {media_info.type.value, media_info.type.to_agent()} + if isinstance(media_info.type, MediaType) + else {str(media_info.type)} + ) + if dir_info.media_type not in media_type_values: + return False + + if dir_info.media_category and dir_info.media_category != media_info.category: + return False + + return True @staticmethod def _format_download_dir_label(download_dir: DownloadDirectory) -> str: diff --git a/tests/test_media_interaction.py b/tests/test_media_interaction.py index 88c983de..b98eea3f 100644 --- a/tests/test_media_interaction.py +++ b/tests/test_media_interaction.py @@ -44,6 +44,26 @@ def _build_context(title: str = "星际穿越") -> Context: ) +def _build_tv_context(title: str = "葬送的芙莉莲") -> Context: + """构造可用于媒体交互下载测试的电视剧上下文。""" + return Context( + meta_info=_build_meta(title), + media_info=MediaInfo( + type=MediaType.TV, + title=title, + year="2023", + tmdb_id=2, + category="动漫", + ), + torrent_info=TorrentInfo( + title=f"{title}.S01.1080p", + site_name="TestSite", + enclosure="https://example.com/demo-tv.torrent", + seeders=10, + ), + ) + + def _build_download_dirs() -> list[TransferDirectoryConf]: """构造消息交互可选择的下载目录配置。""" return [ @@ -52,12 +72,15 @@ def _build_download_dirs() -> list[TransferDirectoryConf]: storage="local", download_path="/downloads/movies", priority=1, + media_type=MediaType.MOVIE.value, ), TransferDirectoryConf( name="动画下载", storage="rclone", download_path="/media/anime", priority=2, + media_type=MediaType.TV.value, + media_category="动漫", ), ] @@ -309,7 +332,9 @@ def test_torrent_selection_prompts_download_dir_buttons_before_download(): notification = post_message.call_args.args[0] assert notification.save_history is False assert "请选择下载目录" in notification.title - assert "电影下载 (/downloads/movies)" in notification.text + assert "1. 自动匹配目录" in notification.text + assert "2. 电影下载 (/downloads/movies)" in notification.text + assert "动画下载" not in notification.text assert notification.buttons[0][0]["callback_data"] == f"media:{request.request_id}:download-dir:1" @@ -347,7 +372,51 @@ def test_torrent_selection_prompts_text_download_dir_for_plain_channel(): assert notification.save_history is False assert "请回复对应数字" in notification.title assert notification.buttons is None - assert "2. 动画下载 (rclone:/media/anime)" in notification.text + assert "1. 自动匹配目录" in notification.text + assert "2. 电影下载 (/downloads/movies)" in notification.text + assert "动画下载" not in notification.text + + +def test_download_dir_callback_runs_pending_single_download_without_save_path_for_auto(): + """下载目录选择自动匹配时,应不传 save_path 继续执行挂起的单资源下载。""" + chain = MediaInteractionChain() + context = _build_context() + request = media_interaction_manager.create_or_replace( + user_id="10001", + channel=MessageChannel.Telegram, + source="telegram-test", + username="tester", + action="Search", + keyword="星际穿越", + title="星际穿越", + meta=_build_meta("星际穿越"), + items=[context], + ) + request.phase = "download-dir" + request.pending_download_mode = "single" + request.pending_download_context = context + + with patch( + "app.chain.message.DirectoryHelper.get_download_dirs", + return_value=_build_download_dirs(), + ), patch( + "app.chain.message.DownloadChain.download_single", + return_value="hash", + ) as download_single: + request.download_dirs = chain._get_download_dirs(context.media_info) + handled = chain.handle_callback_interaction( + callback_data=f"media:{request.request_id}:download-dir:1", + channel=MessageChannel.Telegram, + source="telegram-test", + userid="10001", + username="tester", + ) + + assert handled + assert request.phase == "torrent" + download_single.assert_called_once() + assert download_single.call_args.args[0] is context + assert download_single.call_args.kwargs["save_path"] is None def test_download_dir_callback_runs_pending_single_download_with_save_path(): @@ -376,7 +445,7 @@ def test_download_dir_callback_runs_pending_single_download_with_save_path(): "app.chain.message.DownloadChain.download_single", return_value="hash", ) as download_single: - request.download_dirs = chain._get_download_dirs() + request.download_dirs = chain._get_download_dirs(context.media_info) handled = chain.handle_callback_interaction( callback_data=f"media:{request.request_id}:download-dir:2", channel=MessageChannel.Telegram, @@ -389,11 +458,11 @@ def test_download_dir_callback_runs_pending_single_download_with_save_path(): assert request.phase == "torrent" download_single.assert_called_once() assert download_single.call_args.args[0] is context - assert download_single.call_args.kwargs["save_path"] == "rclone:/media/anime" + assert download_single.call_args.kwargs["save_path"] == "/downloads/movies" -def test_download_dir_text_reply_runs_pending_single_download_with_save_path(): - """下载目录文本回复应使用所选 save_path 继续执行挂起的单资源下载。""" +def test_download_dir_text_reply_runs_pending_single_download_without_save_path(): + """下载目录文本回复选择自动匹配时应不传 save_path。""" chain = MediaInteractionChain() context = _build_context() request = media_interaction_manager.create_or_replace( @@ -431,4 +500,22 @@ def test_download_dir_text_reply_runs_pending_single_download_with_save_path(): assert request.phase == "torrent" download_single.assert_called_once() assert download_single.call_args.args[0] is context - assert download_single.call_args.kwargs["save_path"] == "/downloads/movies" + assert download_single.call_args.kwargs["save_path"] is None + + +def test_get_download_dirs_keeps_matching_tv_category_dir(): + """目录列表应保留匹配当前电视剧类别的下载目录。""" + chain = MediaInteractionChain() + context = _build_tv_context() + + with patch( + "app.chain.message.DirectoryHelper.get_download_dirs", + return_value=_build_download_dirs(), + ): + download_dirs = chain._get_download_dirs(context.media_info) + + assert [download_dir.name for download_dir in download_dirs] == [ + "自动匹配目录", + "动画下载", + ] + assert download_dirs[1].save_path == "rclone:/media/anime"