diff --git a/app/chain/message.py b/app/chain/message.py index f924f101..fb05955b 100644 --- a/app/chain/message.py +++ b/app/chain/message.py @@ -27,10 +27,11 @@ from app.core.meta import MetaBase from app.db.models import TransferHistory from app.db.transferhistory_oper import TransferHistoryOper from app.db.user_oper import UserOper +from app.helper.directory import DirectoryHelper from app.helper.interaction import agent_interaction_manager, media_interaction_manager, PendingMediaInteraction from app.helper.torrent import TorrentHelper from app.log import logger -from app.schemas import Notification, CommingMessage, NotExistMediaInfo +from app.schemas import CommingMessage, DownloadDirectory, FileURI, NotExistMediaInfo, Notification from app.schemas.message import ChannelCapabilityManager, ChannelCapability from app.schemas.types import EventType, MessageChannel, MediaType from app.utils.http import RequestUtils @@ -2075,6 +2076,17 @@ class MediaInteractionChain(ChainBase): ) return True + if action == "download-dir": + self._handle_download_dir_selection( + request=request, + page_index=index, + channel=channel, + source=source, + userid=userid, + username=username, + ) + return True + return False def handle_text_interaction( @@ -2120,7 +2132,16 @@ class MediaInteractionChain(ChainBase): request.source = source request.username = username index = int(normalized) - if request.phase == "torrent": + if request.phase == "download-dir": + self._handle_download_dir_selection( + request=request, + page_index=index, + channel=channel, + source=source, + userid=userid, + username=username, + ) + elif request.phase == "torrent": self._handle_torrent_selection( request=request, page_index=index, @@ -2414,6 +2435,22 @@ class MediaInteractionChain(ChainBase): contexts = TorrentHelper().sort_torrents(contexts) if self._should_auto_download(userid): logger.info("用户 %s 在自动下载用户中,开始自动择优下载 ...", userid) + request.phase = "torrent" + request.page = 0 + request.title = mediainfo.title + request.items = list(contexts) + if self._prompt_download_dir_selection( + request=request, + download_mode="auto", + channel=channel, + source=source, + userid=userid, + username=username, + no_exists=no_exists, + original_message_id=original_message_id, + original_chat_id=original_chat_id, + ): + return self._auto_download( request=request, cache_list=contexts, @@ -2508,6 +2545,15 @@ class MediaInteractionChain(ChainBase): return if page_index == 0: + if self._prompt_download_dir_selection( + request=request, + download_mode="auto", + channel=channel, + source=source, + userid=userid, + username=username, + ): + return self._auto_download( request=request, cache_list=request.items, @@ -2534,6 +2580,16 @@ class MediaInteractionChain(ChainBase): return context: Context = page_items[page_index - 1] + if self._prompt_download_dir_selection( + request=request, + download_mode="single", + channel=channel, + source=source, + userid=userid, + username=username, + context=context, + ): + return DownloadChain().download_single( context, channel=channel, @@ -2542,6 +2598,163 @@ class MediaInteractionChain(ChainBase): username=username, ) + def _prompt_download_dir_selection( + self, + request: PendingMediaInteraction, + download_mode: str, + channel: MessageChannel, + source: str, + userid: Union[str, int], + username: str, + context: Optional[Context] = None, + no_exists: Optional[Dict[Union[int, str], Dict[int, NotExistMediaInfo]]] = None, + original_message_id: Optional[Union[str, int]] = None, + original_chat_id: Optional[str] = None, + ) -> bool: + """ + 在下载前进入目录选择阶段;没有配置下载目录时保持原下载流程。 + """ + download_dirs = self._get_download_dirs() + if not download_dirs: + return False + + request.pending_torrent_page = request.page + request.phase = "download-dir" + request.page = 0 + request.download_dirs = download_dirs + request.pending_download_mode = download_mode + request.pending_download_context = context + request.pending_no_exists = no_exists + self._post_download_dirs_message( + request=request, + channel=channel, + source=source, + userid=userid, + original_message_id=original_message_id, + original_chat_id=original_chat_id, + ) + return True + + def _handle_download_dir_selection( + self, + request: PendingMediaInteraction, + page_index: Optional[int], + channel: MessageChannel, + source: str, + userid: Union[str, int], + username: str, + ) -> None: + """ + 处理下载目录阶段的序号输入,并继续执行挂起的下载动作。 + """ + if request.phase != "download-dir": + self._post_invalid_input( + channel=channel, + source=source, + userid=userid, + username=username, + ) + return + + page_items, page, _ = self._page_items( + items=request.download_dirs, + page=request.page, + page_size=self._page_size(request.channel), + ) + request.page = page + if not page_index or page_index < 1 or page_index > len(page_items): + self._post_invalid_input( + channel=channel, + source=source, + userid=userid, + username=username, + ) + return + + download_dir = page_items[page_index - 1] + save_path = download_dir.save_path or download_dir.download_path + if not save_path: + self._post_invalid_input( + channel=channel, + source=source, + userid=userid, + username=username, + title="下载目录配置无效!", + ) + return + self._execute_pending_download( + request=request, + channel=channel, + source=source, + userid=userid, + username=username, + save_path=save_path, + ) + + def _execute_pending_download( + self, + request: PendingMediaInteraction, + channel: MessageChannel, + source: str, + userid: Union[str, int], + username: str, + save_path: str, + ) -> None: + """ + 使用用户确认的下载目录执行单资源下载或自动择优下载。 + """ + download_mode = request.pending_download_mode + if download_mode == "single" and request.pending_download_context: + context = request.pending_download_context + self._restore_torrent_phase(request) + DownloadChain().download_single( + context, + channel=channel, + source=source, + userid=userid, + username=username, + save_path=save_path, + ) + return + + if download_mode == "auto": + cache_list = list(request.items or []) + no_exists = request.pending_no_exists + self._restore_torrent_phase(request) + self._auto_download( + request=request, + cache_list=cache_list, + channel=channel, + source=source, + userid=userid, + username=username, + no_exists=no_exists, + save_path=save_path, + ) + return + + self._restore_torrent_phase(request) + self._post_invalid_input( + channel=channel, + source=source, + userid=userid, + username=username, + title="下载操作已失效,请重新选择资源", + ) + + @staticmethod + def _restore_torrent_phase(request: PendingMediaInteraction) -> None: + """ + 下载动作完成或失效后恢复到资源列表阶段,便于用户继续选择其它资源。 + """ + request.phase = "torrent" + request.page = request.pending_torrent_page + request.download_dirs = [] + request.pending_download_mode = None + request.pending_download_context = None + request.pending_no_exists = None + request.pending_torrent_page = 0 + def _auto_download( self, request: PendingMediaInteraction, @@ -2551,6 +2764,7 @@ class MediaInteractionChain(ChainBase): userid: Union[str, int], username: str, no_exists: Optional[Dict[Union[int, str], Dict[int, NotExistMediaInfo]]] = None, + save_path: Optional[str] = None, ) -> None: """ 自动择优下载当前资源列表,并在未完成时补建订阅。 @@ -2567,6 +2781,7 @@ class MediaInteractionChain(ChainBase): downloads, lefts = downloadchain.batch_download( contexts=cache_list, no_exists=no_exists, + save_path=save_path, channel=channel, source=source, userid=userid, @@ -2617,7 +2832,16 @@ class MediaInteractionChain(ChainBase): """ 按当前阶段渲染媒体列表或资源列表。 """ - if request.phase == "torrent": + if request.phase == "download-dir": + self._post_download_dirs_message( + request=request, + channel=channel, + source=source, + userid=userid, + original_message_id=original_message_id, + original_chat_id=original_chat_id, + ) + elif request.phase == "torrent": self._post_torrents_message( request=request, channel=channel, @@ -2733,6 +2957,58 @@ class MediaInteractionChain(ChainBase): torrents=page_items, ) + def _post_download_dirs_message( + self, + request: PendingMediaInteraction, + channel: MessageChannel, + source: str, + userid: Union[str, int], + original_message_id: Optional[Union[str, int]] = None, + original_chat_id: Optional[str] = None, + ) -> None: + """ + 发送或更新下载目录选择列表。 + """ + page_items, page, total_pages = self._page_items( + items=request.download_dirs, + page=request.page, + page_size=self._page_size(channel), + ) + request.page = page + total = len(request.download_dirs) + if self._supports_interactive_buttons(channel): + title = f"【{request.title}】请选择下载目录" + buttons = self._create_download_dir_buttons( + channel=channel, + request=request, + items=page_items, + total=total, + total_pages=total_pages, + ) + else: + if total > self._page_size(channel): + title = f"【{request.title}】请选择下载目录,请回复对应数字(p: 上一页 n: 下一页)" + else: + title = f"【{request.title}】请选择下载目录,请回复对应数字" + buttons = None + + text = "\n".join( + f"{index}. {self._format_download_dir_label(download_dir)}" + for index, download_dir in enumerate(page_items, start=1) + ) + self.post_message( + Notification( + channel=channel, + source=source, + title=title, + text=text, + userid=userid, + buttons=buttons, + original_message_id=original_message_id, + original_chat_id=original_chat_id, + ) + ) + def _create_media_buttons( self, channel: MessageChannel, @@ -2831,17 +3107,71 @@ class MediaInteractionChain(ChainBase): buttons.extend(self._navigation_buttons(request, total_pages)) return buttons + def _create_download_dir_buttons( + self, + channel: MessageChannel, + request: PendingMediaInteraction, + items: List[DownloadDirectory], + total: int, + total_pages: int, + ) -> List[List[Dict[str, str]]]: + """ + 为下载目录列表生成选择和翻页按钮。 + """ + buttons: List[List[Dict[str, str]]] = [] + max_text_length = ChannelCapabilityManager.get_max_button_text_length(channel) + max_per_row = ChannelCapabilityManager.get_max_buttons_per_row(channel) + + current_row: List[Dict[str, str]] = [] + for index, download_dir in enumerate(items, start=1): + if max_per_row == 1: + button_text = f"{index}. {self._format_download_dir_label(download_dir)}" + if len(button_text) > max_text_length: + button_text = button_text[: max_text_length - 3] + "..." + buttons.append( + [ + { + "text": button_text, + "callback_data": f"media:{request.request_id}:download-dir:{index}", + } + ] + ) + continue + + current_row.append( + { + "text": f"{index}", + "callback_data": f"media:{request.request_id}:download-dir:{index}", + } + ) + if len(current_row) == max_per_row or index == len(items): + buttons.append(current_row) + current_row = [] + + if total > self._page_size(channel): + buttons.extend(self._navigation_buttons(request, total_pages)) + return buttons + def _has_next_page(self, request: PendingMediaInteraction) -> bool: """ 判断当前视图是否还有下一页。 """ _, page, total_pages = self._page_items( - items=request.items, + items=self._get_current_phase_items(request), page=request.page, page_size=self._page_size(request.channel), ) return page < total_pages - 1 + @staticmethod + def _get_current_phase_items(request: PendingMediaInteraction) -> List[Any]: + """ + 获取当前阶段用于分页的数据列表。 + """ + if request.phase == "download-dir": + return request.download_dirs + return request.items + @staticmethod def _navigation_buttons( request: PendingMediaInteraction, @@ -2885,6 +3215,39 @@ class MediaInteractionChain(ChainBase): end = start + page_size return items[start:end], page, total_pages + @staticmethod + def _get_download_dirs() -> List[DownloadDirectory]: + """ + 获取可供消息交互选择的下载目录。 + """ + return [ + DownloadDirectory( + name=dir_info.name, + storage=dir_info.storage or "local", + download_path=dir_info.download_path, + save_path=FileURI( + storage=dir_info.storage or "local", + path=dir_info.download_path, + ).uri, + priority=dir_info.priority, + media_type=dir_info.media_type, + media_category=dir_info.media_category, + ) + for dir_info in DirectoryHelper().get_download_dirs() + if dir_info.download_path + ] + + @staticmethod + def _format_download_dir_label(download_dir: DownloadDirectory) -> str: + """ + 格式化下载目录展示名称,优先显示用户配置的目录名称。 + """ + save_path = download_dir.save_path or download_dir.download_path or "" + name = download_dir.name or save_path or "下载目录" + if save_path and name != save_path: + return f"{name} ({save_path})" + return name + def _page_size(self, channel: Optional[MessageChannel]) -> int: """ 按渠道交互能力选择分页大小。 diff --git a/app/helper/interaction.py b/app/helper/interaction.py index 8cc69d65..d78e2fdf 100644 --- a/app/helper/interaction.py +++ b/app/helper/interaction.py @@ -271,6 +271,11 @@ class PendingMediaInteraction: meta: Optional[MetaBase] = None current_media: Optional[MediaInfo] = None items: List[Any] = field(default_factory=list) + download_dirs: List[Any] = field(default_factory=list) + pending_download_mode: Optional[str] = None + pending_download_context: Optional[Any] = None + pending_no_exists: Optional[Dict[Any, Any]] = None + pending_torrent_page: int = 0 created_at: datetime = field(default_factory=datetime.now) diff --git a/tests/test_media_interaction.py b/tests/test_media_interaction.py index a5fc0630..4771ded1 100644 --- a/tests/test_media_interaction.py +++ b/tests/test_media_interaction.py @@ -3,10 +3,11 @@ from unittest.mock import patch import pytest from app.chain.message import MediaInteractionChain, MessageChain -from app.core.context import MediaInfo +from app.core.context import Context, MediaInfo, TorrentInfo from app.core.meta import MetaBase from app.helper.interaction import media_interaction_manager -from app.schemas.types import MessageChannel +from app.schemas import TransferDirectoryConf +from app.schemas.types import MediaType, MessageChannel @pytest.fixture(autouse=True) @@ -24,6 +25,43 @@ def _build_meta(name: str) -> MetaBase: return meta +def _build_context(title: str = "星际穿越") -> Context: + """构造可用于媒体交互下载测试的资源上下文。""" + return Context( + meta_info=_build_meta(title), + media_info=MediaInfo( + type=MediaType.MOVIE, + title=title, + year="2014", + tmdb_id=1, + ), + torrent_info=TorrentInfo( + title=f"{title}.2014.1080p", + site_name="TestSite", + enclosure="https://example.com/demo.torrent", + seeders=10, + ), + ) + + +def _build_download_dirs() -> list[TransferDirectoryConf]: + """构造消息交互可选择的下载目录配置。""" + return [ + TransferDirectoryConf( + name="电影下载", + storage="local", + download_path="/downloads/movies", + priority=1, + ), + TransferDirectoryConf( + name="动画下载", + storage="rclone", + download_path="/media/anime", + priority=2, + ), + ] + + def test_message_routes_text_reply_to_media_interaction_before_ai(): """已有传统媒体交互时,用户回复应优先交给传统交互处理。""" chain = MessageChain() @@ -230,3 +268,164 @@ def test_media_interaction_legacy_page_callback_updates_existing_request(): notification = post_medias_message.call_args.args[0] assert notification.original_message_id == 123 assert notification.original_chat_id == "456" + + +def test_torrent_selection_prompts_download_dir_buttons_before_download(): + """支持按钮的渠道选择资源后,应先发送下载目录按钮而不是立即下载。""" + chain = MediaInteractionChain() + context = _build_context() + request = media_interaction_manager.create_or_replace( + user_id="10001", + channel=MessageChannel.Telegram, + source="telegram-test", + username="tester", + action="Search", + keyword="星际穿越", + title="星际穿越", + meta=_build_meta("星际穿越"), + items=[context], + ) + request.phase = "torrent" + + with patch( + "app.chain.message.DirectoryHelper.get_download_dirs", + return_value=_build_download_dirs(), + ), patch.object(chain, "post_message") as post_message, patch( + "app.chain.message.DownloadChain.download_single" + ) as download_single: + handled = chain.handle_text_interaction( + channel=MessageChannel.Telegram, + source="telegram-test", + userid="10001", + username="tester", + text="1", + ) + + assert handled + download_single.assert_not_called() + assert request.phase == "download-dir" + post_message.assert_called_once() + notification = post_message.call_args.args[0] + assert "请选择下载目录" in notification.title + assert "电影下载 (/downloads/movies)" in notification.text + assert notification.buttons[0][0]["callback_data"] == f"media:{request.request_id}:download-dir:1" + + +def test_torrent_selection_prompts_text_download_dir_for_plain_channel(): + """不支持按钮的渠道选择资源后,应提示用户回复数字选择下载目录。""" + chain = MediaInteractionChain() + context = _build_context() + request = media_interaction_manager.create_or_replace( + user_id="wechat-user", + channel=MessageChannel.Wechat, + source="wechat-test", + username="tester", + action="Search", + keyword="星际穿越", + title="星际穿越", + meta=_build_meta("星际穿越"), + items=[context], + ) + request.phase = "torrent" + + with patch( + "app.chain.message.DirectoryHelper.get_download_dirs", + return_value=_build_download_dirs(), + ), patch.object(chain, "post_message") as post_message: + handled = chain.handle_text_interaction( + channel=MessageChannel.Wechat, + source="wechat-test", + userid="wechat-user", + username="tester", + text="1", + ) + + assert handled + notification = post_message.call_args.args[0] + assert "请回复对应数字" in notification.title + assert notification.buttons is None + assert "2. 动画下载 (rclone:/media/anime)" in notification.text + + +def test_download_dir_callback_runs_pending_single_download_with_save_path(): + """下载目录按钮回调应使用所选 save_path 继续执行挂起的单资源下载。""" + chain = MediaInteractionChain() + context = _build_context() + request = media_interaction_manager.create_or_replace( + user_id="10001", + channel=MessageChannel.Telegram, + source="telegram-test", + username="tester", + action="Search", + keyword="星际穿越", + title="星际穿越", + meta=_build_meta("星际穿越"), + items=[context], + ) + request.phase = "download-dir" + request.pending_download_mode = "single" + request.pending_download_context = context + + with patch( + "app.chain.message.DirectoryHelper.get_download_dirs", + return_value=_build_download_dirs(), + ), patch( + "app.chain.message.DownloadChain.download_single", + return_value="hash", + ) as download_single: + request.download_dirs = chain._get_download_dirs() + handled = chain.handle_callback_interaction( + callback_data=f"media:{request.request_id}:download-dir:2", + channel=MessageChannel.Telegram, + source="telegram-test", + userid="10001", + username="tester", + ) + + assert handled + assert request.phase == "torrent" + download_single.assert_called_once() + assert download_single.call_args.args[0] is context + assert download_single.call_args.kwargs["save_path"] == "rclone:/media/anime" + + +def test_download_dir_text_reply_runs_pending_single_download_with_save_path(): + """下载目录文本回复应使用所选 save_path 继续执行挂起的单资源下载。""" + chain = MediaInteractionChain() + context = _build_context() + request = media_interaction_manager.create_or_replace( + user_id="wechat-user", + channel=MessageChannel.Wechat, + source="wechat-test", + username="tester", + action="Search", + keyword="星际穿越", + title="星际穿越", + meta=_build_meta("星际穿越"), + items=[context], + ) + request.phase = "download-dir" + request.pending_download_mode = "single" + request.pending_download_context = context + + with patch( + "app.chain.message.DirectoryHelper.get_download_dirs", + return_value=_build_download_dirs(), + ), patch( + "app.chain.message.DownloadChain.download_single", + return_value="hash", + ) as download_single: + request.download_dirs = chain._get_download_dirs() + handled = chain.handle_text_interaction( + channel=MessageChannel.Wechat, + source="wechat-test", + userid="wechat-user", + username="tester", + text="1", + ) + + assert handled + assert request.phase == "torrent" + download_single.assert_called_once() + assert download_single.call_args.args[0] is context + assert download_single.call_args.kwargs["save_path"] == "/downloads/movies"