diff --git a/app/agent/prompt/Agent Prompt.txt b/app/agent/prompt/Agent Prompt.txt index 4af9e6ef..b215baa7 100644 --- a/app/agent/prompt/Agent Prompt.txt +++ b/app/agent/prompt/Agent Prompt.txt @@ -23,6 +23,7 @@ Core Capabilities: - Do not stop for approval on read-only operations. Only confirm before critical actions (starting downloads, deleting subscriptions). - If the current channel supports image sending and an image would materially help, you may use the `send_message` tool with `image_url` to send it. - If the current channel supports file sending and you need to return a local image/file for the user to download, use `send_local_file`. +{button_choice_spec} - Voice replies: {voice_reply_spec} - NOT a coding assistant. Do not offer code snippets. - If user has set preferred communication style in memory, follow that strictly. diff --git a/app/agent/prompt/__init__.py b/app/agent/prompt/__init__.py index 4c34ceaa..b283eb08 100644 --- a/app/agent/prompt/__init__.py +++ b/app/agent/prompt/__init__.py @@ -76,6 +76,7 @@ class PromptManager: caps = ChannelCapabilityManager.get_capabilities(msg_channel) if caps: markdown_spec = self._generate_formatting_instructions(caps) + button_choice_spec = self._generate_button_choice_instructions(msg_channel) # 啰嗦模式 verbose_spec = "" @@ -100,6 +101,7 @@ class PromptManager: verbose_spec=verbose_spec, moviepilot_info=moviepilot_info, voice_reply_spec=voice_reply_spec, + button_choice_spec=button_choice_spec, ) return base_prompt @@ -187,6 +189,23 @@ class PromptManager: "- Do not repeat the same full reply again after calling `send_voice_message`." ) + @staticmethod + def _generate_button_choice_instructions( + channel: MessageChannel = None, + ) -> str: + if channel and ChannelCapabilityManager.supports_buttons( + channel + ) and ChannelCapabilityManager.supports_callbacks(channel): + return ( + "- User questions: If you need the user to choose from a few clear options, " + "call `ask_user_choice` to send button options. After the user clicks a button, " + "the selected value will come back as the user's next message. After calling this tool, " + "wait for the user's selection instead of repeating the question in plain text." + ) + return ( + "- User questions: When you truly need user input, ask briefly in plain text." + ) + def clear_cache(self): """ 清空缓存 diff --git a/app/agent/tools/factory.py b/app/agent/tools/factory.py index a5e3ad3f..a98f39ca 100644 --- a/app/agent/tools/factory.py +++ b/app/agent/tools/factory.py @@ -30,6 +30,7 @@ from app.agent.tools.impl.search_torrents import SearchTorrentsTool from app.agent.tools.impl.get_search_results import GetSearchResultsTool from app.agent.tools.impl.search_web import SearchWebTool from app.agent.tools.impl.send_message import SendMessageTool +from app.agent.tools.impl.ask_user_choice import AskUserChoiceTool from app.agent.tools.impl.send_local_file import SendLocalFileTool from app.agent.tools.impl.send_voice_message import SendVoiceMessageTool from app.agent.tools.impl.query_schedulers import QuerySchedulersTool @@ -58,6 +59,8 @@ from app.agent.tools.impl.query_custom_identifiers import QueryCustomIdentifiers from app.agent.tools.impl.update_custom_identifiers import UpdateCustomIdentifiersTool from app.core.plugin import PluginManager from app.log import logger +from app.schemas.message import ChannelCapabilityManager +from app.schemas.types import MessageChannel from .base import MoviePilotTool @@ -66,6 +69,18 @@ class MoviePilotToolFactory: MoviePilot工具工厂 """ + @staticmethod + def _should_enable_choice_tool(channel: str = None) -> bool: + if not channel: + return False + try: + message_channel = MessageChannel(channel) + except ValueError: + return False + return ChannelCapabilityManager.supports_buttons( + message_channel + ) and ChannelCapabilityManager.supports_callbacks(message_channel) + @staticmethod def create_tools( session_id: str, @@ -120,8 +135,6 @@ class MoviePilotToolFactory: QueryTransferHistoryTool, TransferFileTool, SendMessageTool, - SendLocalFileTool, - SendVoiceMessageTool, QuerySchedulersTool, RunSchedulerTool, QueryWorkflowsTool, @@ -138,6 +151,14 @@ class MoviePilotToolFactory: QueryCustomIdentifiersTool, UpdateCustomIdentifiersTool, ] + if MoviePilotToolFactory._should_enable_choice_tool(channel): + tool_definitions.append(AskUserChoiceTool) + tool_definitions.extend( + [ + SendLocalFileTool, + SendVoiceMessageTool, + ] + ) # 创建内置工具 for ToolClass in tool_definitions: tool = ToolClass(session_id=session_id, user_id=user_id) diff --git a/app/agent/tools/impl/ask_user_choice.py b/app/agent/tools/impl/ask_user_choice.py new file mode 100644 index 00000000..5f1eed21 --- /dev/null +++ b/app/agent/tools/impl/ask_user_choice.py @@ -0,0 +1,165 @@ +"""让用户通过按钮进行选择的工具。""" + +from typing import List, Optional, Type + +from pydantic import BaseModel, Field, model_validator + +from app.agent.tools.base import MoviePilotTool, ToolChain +from app.agent.user_choice import AgentChoiceOption, agent_user_choice_manager +from app.log import logger +from app.schemas import Notification, NotificationType +from app.schemas.message import ChannelCapabilityManager +from app.schemas.types import MessageChannel + + +class UserChoiceOptionInput(BaseModel): + """单个按钮选项。""" + + label: str = Field(..., description="Text shown on the button") + value: str = Field( + ..., + description="The exact content that will be sent back to the agent after the user clicks this button", + ) + + @model_validator(mode="after") + def validate_option(self): + if not self.label.strip(): + raise ValueError("label 不能为空") + if not self.value.strip(): + raise ValueError("value 不能为空") + return self + + +class AskUserChoiceInput(BaseModel): + """按钮选择工具输入。""" + + explanation: str = Field( + ..., + description="Clear explanation of why the agent needs the user to choose from buttons", + ) + message: str = Field( + ..., + description="Question or prompt shown to the user together with the buttons", + ) + title: Optional[str] = Field( + None, + description="Optional short title displayed above the question", + ) + options: List[UserChoiceOptionInput] = Field( + ..., + description="Button options to show to the user", + ) + + @model_validator(mode="after") + def validate_payload(self): + if not self.message.strip(): + raise ValueError("message 不能为空") + if not self.options: + raise ValueError("options 至少需要提供一个") + return self + + +class AskUserChoiceTool(MoviePilotTool): + name: str = "ask_user_choice" + description: str = ( + "Ask the user to choose from button options on channels that support interactive buttons. " + "After the user clicks a button, the selected value will come back as the user's next message." + ) + args_schema: Type[BaseModel] = AskUserChoiceInput + require_admin: bool = False + + def get_tool_message(self, **kwargs) -> Optional[str]: + message = kwargs.get("message", "") or "" + if len(message) > 40: + message = message[:40] + "..." + return f"正在发送按钮选择: {message}" + + @staticmethod + def _truncate_button_text(text: str, max_length: int) -> str: + if max_length <= 0 or len(text) <= max_length: + return text + if max_length <= 3: + return text[:max_length] + return text[: max_length - 3] + "..." + + async def run( + self, + message: str, + options: List[UserChoiceOptionInput], + title: Optional[str] = None, + **kwargs, + ) -> str: + if not self._channel or not self._source: + return "当前不在可回传消息的会话中,无法发起按钮选择" + + try: + channel = MessageChannel(self._channel) + except ValueError: + return f"不支持的消息渠道: {self._channel}" + + if not ( + ChannelCapabilityManager.supports_buttons(channel) + and ChannelCapabilityManager.supports_callbacks(channel) + ): + return f"当前渠道 {channel.value} 不支持按钮选择" + + max_per_row = ChannelCapabilityManager.get_max_buttons_per_row(channel) + max_rows = ChannelCapabilityManager.get_max_button_rows(channel) + max_text_length = ChannelCapabilityManager.get_max_button_text_length(channel) + max_options = max_per_row * max_rows + if len(options) > max_options: + return f"当前渠道最多支持 {max_options} 个按钮选项" + + choice_options = [ + AgentChoiceOption(label=option.label.strip(), value=option.value.strip()) + for option in options + ] + request = agent_user_choice_manager.create_request( + session_id=self._session_id, + user_id=str(self._user_id), + channel=channel.value, + source=self._source, + username=self._username, + prompt=message.strip(), + options=choice_options, + ) + + buttons = [] + current_row = [] + for index, option in enumerate(choice_options, start=1): + current_row.append( + { + "text": self._truncate_button_text(option.label, max_text_length), + "callback_data": f"agent_choice:{request.request_id}:{index}", + } + ) + if len(current_row) >= max_per_row: + buttons.append(current_row) + current_row = [] + if current_row: + buttons.append(current_row) + + logger.info( + "执行工具: %s, channel=%s, session_id=%s, options=%s", + self.name, + channel.value, + self._session_id, + len(choice_options), + ) + + await ToolChain().async_post_message( + Notification( + channel=channel, + source=self._source, + mtype=NotificationType.Agent, + userid=self._user_id, + username=self._username, + title=title, + text=message.strip(), + buttons=buttons, + ) + ) + + self._agent_context["user_reply_sent"] = True + self._agent_context["reply_mode"] = "button_choice" + return f"已发送 {len(choice_options)} 个按钮选项,等待用户选择" diff --git a/app/agent/user_choice.py b/app/agent/user_choice.py new file mode 100644 index 00000000..5c1e9154 --- /dev/null +++ b/app/agent/user_choice.py @@ -0,0 +1,104 @@ +"""Agent 用户按钮选择请求管理。""" + +from dataclasses import dataclass, field +from datetime import datetime, timedelta +from threading import Lock +from typing import Dict, List, Optional +import uuid + + +@dataclass(frozen=True) +class AgentChoiceOption: + """按钮选项。""" + + label: str + value: str + + +@dataclass +class PendingAgentChoice: + """待处理的按钮选择请求。""" + + request_id: str + session_id: str + user_id: str + channel: Optional[str] + source: Optional[str] + username: Optional[str] + prompt: str + options: List[AgentChoiceOption] + created_at: datetime = field(default_factory=datetime.now) + + +class AgentUserChoiceManager: + """管理 Agent 发起的按钮选择请求。""" + + _ttl = timedelta(hours=24) + + def __init__(self): + self._pending_choices: Dict[str, PendingAgentChoice] = {} + self._lock = Lock() + + def _cleanup_locked(self): + expire_before = datetime.now() - self._ttl + expired_ids = [ + request_id + for request_id, request in self._pending_choices.items() + if request.created_at < expire_before + ] + for request_id in expired_ids: + self._pending_choices.pop(request_id, None) + + def create_request( + self, + session_id: str, + user_id: str, + channel: Optional[str], + source: Optional[str], + username: Optional[str], + prompt: str, + options: List[AgentChoiceOption], + ) -> PendingAgentChoice: + with self._lock: + self._cleanup_locked() + request_id = uuid.uuid4().hex[:12] + while request_id in self._pending_choices: + request_id = uuid.uuid4().hex[:12] + request = PendingAgentChoice( + request_id=request_id, + session_id=session_id, + user_id=str(user_id), + channel=channel, + source=source, + username=username, + prompt=prompt, + options=options, + ) + self._pending_choices[request_id] = request + return request + + def resolve( + self, + request_id: str, + option_index: int, + user_id: Optional[str] = None, + ) -> Optional[tuple[PendingAgentChoice, AgentChoiceOption]]: + with self._lock: + self._cleanup_locked() + request = self._pending_choices.get(request_id) + if not request: + return None + if user_id is not None and str(request.user_id) != str(user_id): + return None + if option_index < 1 or option_index > len(request.options): + return None + option = request.options[option_index - 1] + self._pending_choices.pop(request_id, None) + return request, option + + def clear(self): + with self._lock: + self._pending_choices.clear() + + +agent_user_choice_manager = AgentUserChoiceManager() diff --git a/app/chain/message.py b/app/chain/message.py index 695d0cd0..2e6e9073 100644 --- a/app/chain/message.py +++ b/app/chain/message.py @@ -11,6 +11,7 @@ import uuid import base64 from app.agent import agent_manager +from app.agent.user_choice import agent_user_choice_manager from app.chain import ChainBase from app.chain.download import DownloadChain from app.chain.media import MediaChain @@ -215,22 +216,12 @@ class MessageChain(ChainBase): # 保存消息 if not text.startswith("CALLBACK:"): - self.messagehelper.put( - CommingMessage( - userid=userid, - username=username, - channel=channel, - source=source, - text=text, - ), - role="user", - ) - self.messageoper.add( + self._record_user_message( channel=channel, source=source, - userid=username or userid, + userid=userid, + username=username, text=text, - action=0, ) # 处理消息 if text.startswith("CALLBACK:"): @@ -794,6 +785,15 @@ class MessageChain(ChainBase): ): return + if self._handle_agent_choice_callback( + callback_data=callback_data, + channel=channel, + source=source, + userid=userid, + username=username, + ): + return + # 插件消息的事件回调 [PLUGIN]插件ID|内容 if callback_data.startswith("[PLUGIN]"): # 提取插件ID和内容 @@ -888,6 +888,76 @@ class MessageChain(ChainBase): ) return True + @staticmethod + def _parse_agent_choice_callback( + callback_data: str, + ) -> Optional[tuple[str, int]]: + """ + 解析 Agent 按钮选择回调。 + """ + if not callback_data.startswith("agent_choice:"): + return None + try: + _, request_id, option_index = callback_data.split(":", 2) + except ValueError: + return None + if not request_id or not option_index.isdigit(): + return None + return request_id, int(option_index) + + def _handle_agent_choice_callback( + self, + callback_data: str, + channel: MessageChannel, + source: str, + userid: Union[str, int], + username: str, + ) -> bool: + """ + 将 Agent 按钮选择回传为同一会话中的下一条用户消息。 + """ + callback = self._parse_agent_choice_callback(callback_data) + if not callback: + return False + + request_id, option_index = callback + resolved = agent_user_choice_manager.resolve( + request_id=request_id, + option_index=option_index, + user_id=str(userid), + ) + if not resolved: + self.post_message( + Notification( + channel=channel, + source=source, + userid=userid, + username=username, + title="该选择已失效,请重新发起选择", + ) + ) + return True + + request, option = resolved + selected_text = option.value + self._bind_session_id(userid, request.session_id) + self._record_user_message( + channel=channel, + source=source, + userid=userid, + username=username, + text=selected_text, + ) + self._handle_ai_message( + text=selected_text, + channel=channel, + source=source, + userid=userid, + username=username, + session_id=request.session_id, + ) + return True + def _retry_transfer_history( self, history_id: int, @@ -1308,6 +1378,41 @@ class MessageChain(ChainBase): logger.info(f"创建新会话ID: {new_session_id}, 用户: {userid}") return new_session_id + def _bind_session_id(self, userid: Union[str, int], session_id: str) -> None: + """ + 将用户会话绑定到指定的 session_id,并刷新最后活动时间。 + """ + self._user_sessions[userid] = (session_id, datetime.now()) + + def _record_user_message( + self, + channel: MessageChannel, + source: str, + userid: Union[str, int], + username: str, + text: str, + ) -> None: + """ + 保存一条用户消息到消息历史与数据库。 + """ + self.messagehelper.put( + CommingMessage( + userid=userid, + username=username, + channel=channel, + source=source, + text=text, + ), + role="user", + ) + self.messageoper.add( + channel=channel, + source=source, + userid=username or userid, + text=text, + action=0, + ) + def clear_user_session(self, userid: Union[str, int]) -> bool: """ 清除指定用户的会话信息 @@ -1427,6 +1532,7 @@ class MessageChain(ChainBase): images: Optional[List[CommingMessage.MessageImage]] = None, files: Optional[List[CommingMessage.MessageAttachment]] = None, reply_with_voice: bool = False, + session_id: Optional[str] = None, ) -> None: """ 处理AI智能体消息 @@ -1466,7 +1572,8 @@ class MessageChain(ChainBase): return # 生成或复用会话ID - session_id = self._get_or_create_session_id(userid) + session_id = session_id or self._get_or_create_session_id(userid) + self._bind_session_id(userid, session_id) # 下载图片并转为base64 original_images = images diff --git a/tests/test_agent_user_choice.py b/tests/test_agent_user_choice.py new file mode 100644 index 00000000..b798c7cd --- /dev/null +++ b/tests/test_agent_user_choice.py @@ -0,0 +1,126 @@ +import asyncio +import unittest +from unittest.mock import AsyncMock, patch + +from app.agent.prompt import prompt_manager +from app.agent.tools.factory import MoviePilotToolFactory +from app.agent.tools.impl.ask_user_choice import ( + AskUserChoiceTool, + UserChoiceOptionInput, +) +from app.agent.user_choice import AgentChoiceOption, agent_user_choice_manager +from app.chain.message import MessageChain +from app.schemas.types import MessageChannel + + +class TestAgentUserChoice(unittest.TestCase): + def tearDown(self): + agent_user_choice_manager.clear() + + def test_prompt_injects_choice_tool_hint_only_for_button_channels(self): + telegram_prompt = prompt_manager.get_agent_prompt( + channel=MessageChannel.Telegram.value + ) + wechat_prompt = prompt_manager.get_agent_prompt( + channel=MessageChannel.Wechat.value + ) + + self.assertIn("ask_user_choice", telegram_prompt) + self.assertNotIn("ask_user_choice", wechat_prompt) + + def test_factory_injects_choice_tool_only_for_button_channels(self): + with patch( + "app.agent.tools.factory.PluginManager.get_plugin_agent_tools", + return_value=[], + ): + telegram_tools = MoviePilotToolFactory.create_tools( + session_id="session-1", + user_id="10001", + channel=MessageChannel.Telegram.value, + source="telegram-test", + username="tester", + ) + wechat_tools = MoviePilotToolFactory.create_tools( + session_id="session-2", + user_id="10001", + channel=MessageChannel.Wechat.value, + source="wechat-test", + username="tester", + ) + + self.assertIn("ask_user_choice", [tool.name for tool in telegram_tools]) + self.assertNotIn("ask_user_choice", [tool.name for tool in wechat_tools]) + + def test_choice_tool_sends_buttons_and_registers_pending_request(self): + tool = AskUserChoiceTool(session_id="session-1", user_id="10001") + tool.set_message_attr( + channel=MessageChannel.Telegram.value, + source="telegram-test", + username="tester", + ) + tool.set_agent_context(agent_context={}) + + with patch( + "app.agent.tools.impl.ask_user_choice.ToolChain.async_post_message", + new=AsyncMock(), + ) as async_post_message: + result = asyncio.run( + tool.run( + message="请选择要执行的操作", + options=[ + UserChoiceOptionInput(label="继续下载", value="继续下载"), + UserChoiceOptionInput(label="先看详情", value="先看详情"), + ], + title="需要你的选择", + ) + ) + + self.assertIn("等待用户选择", result) + self.assertTrue(tool._agent_context.get("user_reply_sent")) + notification = async_post_message.await_args.args[0] + self.assertEqual(notification.text, "请选择要执行的操作") + self.assertEqual(len(notification.buttons[0]), 2) + + callback_data = notification.buttons[0][0]["callback_data"] + _, request_id, option_index = callback_data.split(":") + resolved = agent_user_choice_manager.resolve(request_id, int(option_index), "10001") + self.assertIsNotNone(resolved) + _, option = resolved + self.assertEqual(option.value, "继续下载") + + def test_agent_choice_callback_routes_selected_value_back_to_agent(self): + chain = MessageChain() + request = agent_user_choice_manager.create_request( + session_id="session-choice", + user_id="10001", + channel=MessageChannel.Telegram.value, + source="telegram-test", + username="tester", + prompt="请选择", + options=[ + AgentChoiceOption(label="电影", value="我选择电影"), + AgentChoiceOption(label="电视剧", value="我选择电视剧"), + ], + ) + + with patch.object(chain, "_handle_ai_message") as handle_ai_message, patch.object( + chain.messagehelper, "put" + ) as message_put, patch.object(chain.messageoper, "add") as message_add: + chain._handle_callback( + text=f"CALLBACK:agent_choice:{request.request_id}:1", + channel=MessageChannel.Telegram, + source="telegram-test", + userid="10001", + username="tester", + ) + + handle_ai_message.assert_called_once() + kwargs = handle_ai_message.call_args.kwargs + self.assertEqual(kwargs["text"], "我选择电影") + self.assertEqual(kwargs["session_id"], "session-choice") + message_put.assert_called_once() + message_add.assert_called_once() + + +if __name__ == "__main__": + unittest.main()