diff --git a/app/agent/user_choice.py b/app/agent/interaction.py similarity index 63% rename from app/agent/user_choice.py rename to app/agent/interaction.py index 5c1e9154..f831cc6e 100644 --- a/app/agent/user_choice.py +++ b/app/agent/interaction.py @@ -1,4 +1,4 @@ -"""Agent 用户按钮选择请求管理。""" +"""Agent 客户端交互请求管理。""" from dataclasses import dataclass, field from datetime import datetime, timedelta @@ -8,16 +8,16 @@ import uuid @dataclass(frozen=True) -class AgentChoiceOption: - """按钮选项。""" +class AgentInteractionOption: + """交互选项。""" label: str value: str @dataclass -class PendingAgentChoice: - """待处理的按钮选择请求。""" +class PendingAgentInteraction: + """待处理的 Agent 客户端交互请求。""" request_id: str session_id: str @@ -25,29 +25,30 @@ class PendingAgentChoice: channel: Optional[str] source: Optional[str] username: Optional[str] + title: Optional[str] prompt: str - options: List[AgentChoiceOption] + options: List[AgentInteractionOption] created_at: datetime = field(default_factory=datetime.now) -class AgentUserChoiceManager: - """管理 Agent 发起的按钮选择请求。""" +class AgentInteractionManager: + """管理 Agent 发起的客户端交互请求。""" _ttl = timedelta(hours=24) def __init__(self): - self._pending_choices: Dict[str, PendingAgentChoice] = {} + self._pending_interactions: Dict[str, PendingAgentInteraction] = {} self._lock = Lock() def _cleanup_locked(self): expire_before = datetime.now() - self._ttl expired_ids = [ request_id - for request_id, request in self._pending_choices.items() + for request_id, request in self._pending_interactions.items() if request.created_at < expire_before ] for request_id in expired_ids: - self._pending_choices.pop(request_id, None) + self._pending_interactions.pop(request_id, None) def create_request( self, @@ -56,25 +57,27 @@ class AgentUserChoiceManager: channel: Optional[str], source: Optional[str], username: Optional[str], + title: Optional[str], prompt: str, - options: List[AgentChoiceOption], - ) -> PendingAgentChoice: + options: List[AgentInteractionOption], + ) -> PendingAgentInteraction: with self._lock: self._cleanup_locked() request_id = uuid.uuid4().hex[:12] - while request_id in self._pending_choices: + while request_id in self._pending_interactions: request_id = uuid.uuid4().hex[:12] - request = PendingAgentChoice( + request = PendingAgentInteraction( request_id=request_id, session_id=session_id, user_id=str(user_id), channel=channel, source=source, username=username, + title=title, prompt=prompt, options=options, ) - self._pending_choices[request_id] = request + self._pending_interactions[request_id] = request return request def resolve( @@ -82,10 +85,10 @@ class AgentUserChoiceManager: request_id: str, option_index: int, user_id: Optional[str] = None, - ) -> Optional[tuple[PendingAgentChoice, AgentChoiceOption]]: + ) -> Optional[tuple[PendingAgentInteraction, AgentInteractionOption]]: with self._lock: self._cleanup_locked() - request = self._pending_choices.get(request_id) + request = self._pending_interactions.get(request_id) if not request: return None if user_id is not None and str(request.user_id) != str(user_id): @@ -93,12 +96,12 @@ class AgentUserChoiceManager: if option_index < 1 or option_index > len(request.options): return None option = request.options[option_index - 1] - self._pending_choices.pop(request_id, None) + self._pending_interactions.pop(request_id, None) return request, option def clear(self): with self._lock: - self._pending_choices.clear() + self._pending_interactions.clear() -agent_user_choice_manager = AgentUserChoiceManager() +agent_interaction_manager = AgentInteractionManager() diff --git a/app/agent/tools/impl/ask_user_choice.py b/app/agent/tools/impl/ask_user_choice.py index 5f1eed21..f44e8bac 100644 --- a/app/agent/tools/impl/ask_user_choice.py +++ b/app/agent/tools/impl/ask_user_choice.py @@ -5,7 +5,10 @@ from typing import List, Optional, Type from pydantic import BaseModel, Field, model_validator from app.agent.tools.base import MoviePilotTool, ToolChain -from app.agent.user_choice import AgentChoiceOption, agent_user_choice_manager +from app.agent.interaction import ( + AgentInteractionOption, + agent_interaction_manager, +) from app.log import logger from app.schemas import Notification, NotificationType from app.schemas.message import ChannelCapabilityManager @@ -111,15 +114,18 @@ class AskUserChoiceTool(MoviePilotTool): return f"当前渠道最多支持 {max_options} 个按钮选项" choice_options = [ - AgentChoiceOption(label=option.label.strip(), value=option.value.strip()) + AgentInteractionOption( + label=option.label.strip(), value=option.value.strip() + ) for option in options ] - request = agent_user_choice_manager.create_request( + request = agent_interaction_manager.create_request( session_id=self._session_id, user_id=str(self._user_id), channel=channel.value, source=self._source, username=self._username, + title=title, prompt=message.strip(), options=choice_options, ) @@ -130,7 +136,9 @@ class AskUserChoiceTool(MoviePilotTool): current_row.append( { "text": self._truncate_button_text(option.label, max_text_length), - "callback_data": f"agent_choice:{request.request_id}:{index}", + "callback_data": ( + f"agent_interaction:choice:{request.request_id}:{index}" + ), } ) if len(current_row) >= max_per_row: diff --git a/app/chain/message.py b/app/chain/message.py index 2e6e9073..da83a479 100644 --- a/app/chain/message.py +++ b/app/chain/message.py @@ -11,7 +11,7 @@ import uuid import base64 from app.agent import agent_manager -from app.agent.user_choice import agent_user_choice_manager +from app.agent.interaction import agent_interaction_manager from app.chain import ChainBase from app.chain.download import DownloadChain from app.chain.media import MediaChain @@ -791,6 +791,8 @@ class MessageChain(ChainBase): source=source, userid=userid, username=username, + original_message_id=original_message_id, + original_chat_id=original_chat_id, ): return @@ -895,11 +897,18 @@ class MessageChain(ChainBase): """ 解析 Agent 按钮选择回调。 """ - if not callback_data.startswith("agent_choice:"): - return None - try: - _, request_id, option_index = callback_data.split(":", 2) - except ValueError: + if callback_data.startswith("agent_interaction:choice:"): + try: + _, _, request_id, option_index = callback_data.split(":", 3) + except ValueError: + return None + elif callback_data.startswith("agent_choice:"): + # 兼容旧格式,避免已发送的按钮失效 + try: + _, request_id, option_index = callback_data.split(":", 2) + except ValueError: + return None + else: return None if not request_id or not option_index.isdigit(): return None @@ -912,6 +921,8 @@ class MessageChain(ChainBase): source: str, userid: Union[str, int], username: str, + original_message_id: Optional[Union[str, int]] = None, + original_chat_id: Optional[str] = None, ) -> bool: """ 将 Agent 按钮选择回传为同一会话中的下一条用户消息。 @@ -921,7 +932,7 @@ class MessageChain(ChainBase): return False request_id, option_index = callback - resolved = agent_user_choice_manager.resolve( + resolved = agent_interaction_manager.resolve( request_id=request_id, option_index=option_index, user_id=str(userid), @@ -940,6 +951,15 @@ class MessageChain(ChainBase): request, option = resolved selected_text = option.value + self._update_interaction_message_feedback( + channel=channel, + source=source, + original_message_id=original_message_id, + original_chat_id=original_chat_id, + title=request.title, + prompt=request.prompt, + selected_label=option.label, + ) self._bind_session_id(userid, request.session_id) self._record_user_message( channel=channel, @@ -958,6 +978,35 @@ class MessageChain(ChainBase): ) return True + def _update_interaction_message_feedback( + self, + channel: MessageChannel, + source: str, + original_message_id: Optional[Union[str, int]], + original_chat_id: Optional[str], + prompt: str, + selected_label: str, + title: Optional[str] = None, + ) -> None: + """ + 在用户点击交互按钮后,立即更新原消息,明确显示已选择的内容。 + """ + if not original_message_id or not original_chat_id: + return + + lines = [prompt.strip()] + if selected_label: + lines.append(f"已选择:{selected_label}") + feedback_text = "\n\n".join(line for line in lines if line) + self.edit_message( + channel=channel, + source=source, + message_id=original_message_id, + chat_id=original_chat_id, + title=title, + text=feedback_text, + ) + def _retry_transfer_history( self, history_id: int, diff --git a/tests/test_agent_user_choice.py b/tests/test_agent_interaction.py similarity index 66% rename from tests/test_agent_user_choice.py rename to tests/test_agent_interaction.py index b798c7cd..ca1992a6 100644 --- a/tests/test_agent_user_choice.py +++ b/tests/test_agent_interaction.py @@ -8,14 +8,17 @@ from app.agent.tools.impl.ask_user_choice import ( AskUserChoiceTool, UserChoiceOptionInput, ) -from app.agent.user_choice import AgentChoiceOption, agent_user_choice_manager +from app.agent.interaction import ( + AgentInteractionOption, + agent_interaction_manager, +) from app.chain.message import MessageChain from app.schemas.types import MessageChannel -class TestAgentUserChoice(unittest.TestCase): +class TestAgentInteraction(unittest.TestCase): def tearDown(self): - agent_user_choice_manager.clear() + agent_interaction_manager.clear() def test_prompt_injects_choice_tool_hint_only_for_button_channels(self): telegram_prompt = prompt_manager.get_agent_prompt( @@ -82,30 +85,76 @@ class TestAgentUserChoice(unittest.TestCase): self.assertEqual(len(notification.buttons[0]), 2) callback_data = notification.buttons[0][0]["callback_data"] - _, request_id, option_index = callback_data.split(":") - resolved = agent_user_choice_manager.resolve(request_id, int(option_index), "10001") + _, _, request_id, option_index = callback_data.split(":") + resolved = agent_interaction_manager.resolve( + request_id, int(option_index), "10001" + ) self.assertIsNotNone(resolved) _, option = resolved self.assertEqual(option.value, "继续下载") - def test_agent_choice_callback_routes_selected_value_back_to_agent(self): + def test_agent_interaction_callback_routes_selected_value_back_to_agent(self): chain = MessageChain() - request = agent_user_choice_manager.create_request( + request = agent_interaction_manager.create_request( session_id="session-choice", user_id="10001", channel=MessageChannel.Telegram.value, source="telegram-test", username="tester", + title="需要你的选择", prompt="请选择", options=[ - AgentChoiceOption(label="电影", value="我选择电影"), - AgentChoiceOption(label="电视剧", value="我选择电视剧"), + AgentInteractionOption(label="电影", value="我选择电影"), + AgentInteractionOption(label="电视剧", value="我选择电视剧"), ], ) with patch.object(chain, "_handle_ai_message") as handle_ai_message, patch.object( chain.messagehelper, "put" - ) as message_put, patch.object(chain.messageoper, "add") as message_add: + ) as message_put, patch.object(chain.messageoper, "add") as message_add, patch.object( + chain, "edit_message", return_value=True + ) as edit_message: + chain._handle_callback( + text=f"CALLBACK:agent_interaction:choice:{request.request_id}:1", + channel=MessageChannel.Telegram, + source="telegram-test", + userid="10001", + username="tester", + original_message_id=123, + original_chat_id="456", + ) + + handle_ai_message.assert_called_once() + edit_message.assert_called_once_with( + channel=MessageChannel.Telegram, + source="telegram-test", + message_id=123, + chat_id="456", + title="需要你的选择", + text="请选择\n\n已选择:电影", + ) + kwargs = handle_ai_message.call_args.kwargs + self.assertEqual(kwargs["text"], "我选择电影") + self.assertEqual(kwargs["session_id"], "session-choice") + message_put.assert_called_once() + message_add.assert_called_once() + + def test_legacy_agent_choice_callback_still_supported(self): + chain = MessageChain() + request = agent_interaction_manager.create_request( + session_id="session-choice", + user_id="10001", + channel=MessageChannel.Telegram.value, + source="telegram-test", + username="tester", + title=None, + prompt="请选择", + options=[AgentInteractionOption(label="电影", value="我选择电影")], + ) + + with patch.object(chain, "_handle_ai_message") as handle_ai_message, patch.object( + chain.messagehelper, "put" + ), patch.object(chain.messageoper, "add"): chain._handle_callback( text=f"CALLBACK:agent_choice:{request.request_id}:1", channel=MessageChannel.Telegram, @@ -115,11 +164,6 @@ class TestAgentUserChoice(unittest.TestCase): ) handle_ai_message.assert_called_once() - kwargs = handle_ai_message.call_args.kwargs - self.assertEqual(kwargs["text"], "我选择电影") - self.assertEqual(kwargs["session_id"], "session-choice") - message_put.assert_called_once() - message_add.assert_called_once() if __name__ == "__main__":