mirror of
https://github.com/jxxghp/MoviePilot.git
synced 2026-05-06 20:42:43 +08:00
feat: add agent button choice workflow
This commit is contained in:
@@ -23,6 +23,7 @@ Core Capabilities:
|
||||
- Do not stop for approval on read-only operations. Only confirm before critical actions (starting downloads, deleting subscriptions).
|
||||
- If the current channel supports image sending and an image would materially help, you may use the `send_message` tool with `image_url` to send it.
|
||||
- If the current channel supports file sending and you need to return a local image/file for the user to download, use `send_local_file`.
|
||||
{button_choice_spec}
|
||||
- Voice replies: {voice_reply_spec}
|
||||
- NOT a coding assistant. Do not offer code snippets.
|
||||
- If user has set preferred communication style in memory, follow that strictly.
|
||||
|
||||
@@ -76,6 +76,7 @@ class PromptManager:
|
||||
caps = ChannelCapabilityManager.get_capabilities(msg_channel)
|
||||
if caps:
|
||||
markdown_spec = self._generate_formatting_instructions(caps)
|
||||
button_choice_spec = self._generate_button_choice_instructions(msg_channel)
|
||||
|
||||
# 啰嗦模式
|
||||
verbose_spec = ""
|
||||
@@ -100,6 +101,7 @@ class PromptManager:
|
||||
verbose_spec=verbose_spec,
|
||||
moviepilot_info=moviepilot_info,
|
||||
voice_reply_spec=voice_reply_spec,
|
||||
button_choice_spec=button_choice_spec,
|
||||
)
|
||||
|
||||
return base_prompt
|
||||
@@ -187,6 +189,23 @@ class PromptManager:
|
||||
"- Do not repeat the same full reply again after calling `send_voice_message`."
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def _generate_button_choice_instructions(
|
||||
channel: MessageChannel = None,
|
||||
) -> str:
|
||||
if channel and ChannelCapabilityManager.supports_buttons(
|
||||
channel
|
||||
) and ChannelCapabilityManager.supports_callbacks(channel):
|
||||
return (
|
||||
"- User questions: If you need the user to choose from a few clear options, "
|
||||
"call `ask_user_choice` to send button options. After the user clicks a button, "
|
||||
"the selected value will come back as the user's next message. After calling this tool, "
|
||||
"wait for the user's selection instead of repeating the question in plain text."
|
||||
)
|
||||
return (
|
||||
"- User questions: When you truly need user input, ask briefly in plain text."
|
||||
)
|
||||
|
||||
def clear_cache(self):
|
||||
"""
|
||||
清空缓存
|
||||
|
||||
@@ -30,6 +30,7 @@ from app.agent.tools.impl.search_torrents import SearchTorrentsTool
|
||||
from app.agent.tools.impl.get_search_results import GetSearchResultsTool
|
||||
from app.agent.tools.impl.search_web import SearchWebTool
|
||||
from app.agent.tools.impl.send_message import SendMessageTool
|
||||
from app.agent.tools.impl.ask_user_choice import AskUserChoiceTool
|
||||
from app.agent.tools.impl.send_local_file import SendLocalFileTool
|
||||
from app.agent.tools.impl.send_voice_message import SendVoiceMessageTool
|
||||
from app.agent.tools.impl.query_schedulers import QuerySchedulersTool
|
||||
@@ -58,6 +59,8 @@ from app.agent.tools.impl.query_custom_identifiers import QueryCustomIdentifiers
|
||||
from app.agent.tools.impl.update_custom_identifiers import UpdateCustomIdentifiersTool
|
||||
from app.core.plugin import PluginManager
|
||||
from app.log import logger
|
||||
from app.schemas.message import ChannelCapabilityManager
|
||||
from app.schemas.types import MessageChannel
|
||||
from .base import MoviePilotTool
|
||||
|
||||
|
||||
@@ -66,6 +69,18 @@ class MoviePilotToolFactory:
|
||||
MoviePilot工具工厂
|
||||
"""
|
||||
|
||||
@staticmethod
|
||||
def _should_enable_choice_tool(channel: str = None) -> bool:
|
||||
if not channel:
|
||||
return False
|
||||
try:
|
||||
message_channel = MessageChannel(channel)
|
||||
except ValueError:
|
||||
return False
|
||||
return ChannelCapabilityManager.supports_buttons(
|
||||
message_channel
|
||||
) and ChannelCapabilityManager.supports_callbacks(message_channel)
|
||||
|
||||
@staticmethod
|
||||
def create_tools(
|
||||
session_id: str,
|
||||
@@ -120,8 +135,6 @@ class MoviePilotToolFactory:
|
||||
QueryTransferHistoryTool,
|
||||
TransferFileTool,
|
||||
SendMessageTool,
|
||||
SendLocalFileTool,
|
||||
SendVoiceMessageTool,
|
||||
QuerySchedulersTool,
|
||||
RunSchedulerTool,
|
||||
QueryWorkflowsTool,
|
||||
@@ -138,6 +151,14 @@ class MoviePilotToolFactory:
|
||||
QueryCustomIdentifiersTool,
|
||||
UpdateCustomIdentifiersTool,
|
||||
]
|
||||
if MoviePilotToolFactory._should_enable_choice_tool(channel):
|
||||
tool_definitions.append(AskUserChoiceTool)
|
||||
tool_definitions.extend(
|
||||
[
|
||||
SendLocalFileTool,
|
||||
SendVoiceMessageTool,
|
||||
]
|
||||
)
|
||||
# 创建内置工具
|
||||
for ToolClass in tool_definitions:
|
||||
tool = ToolClass(session_id=session_id, user_id=user_id)
|
||||
|
||||
165
app/agent/tools/impl/ask_user_choice.py
Normal file
165
app/agent/tools/impl/ask_user_choice.py
Normal file
@@ -0,0 +1,165 @@
|
||||
"""让用户通过按钮进行选择的工具。"""
|
||||
|
||||
from typing import List, Optional, Type
|
||||
|
||||
from pydantic import BaseModel, Field, model_validator
|
||||
|
||||
from app.agent.tools.base import MoviePilotTool, ToolChain
|
||||
from app.agent.user_choice import AgentChoiceOption, agent_user_choice_manager
|
||||
from app.log import logger
|
||||
from app.schemas import Notification, NotificationType
|
||||
from app.schemas.message import ChannelCapabilityManager
|
||||
from app.schemas.types import MessageChannel
|
||||
|
||||
|
||||
class UserChoiceOptionInput(BaseModel):
|
||||
"""单个按钮选项。"""
|
||||
|
||||
label: str = Field(..., description="Text shown on the button")
|
||||
value: str = Field(
|
||||
...,
|
||||
description="The exact content that will be sent back to the agent after the user clicks this button",
|
||||
)
|
||||
|
||||
@model_validator(mode="after")
|
||||
def validate_option(self):
|
||||
if not self.label.strip():
|
||||
raise ValueError("label 不能为空")
|
||||
if not self.value.strip():
|
||||
raise ValueError("value 不能为空")
|
||||
return self
|
||||
|
||||
|
||||
class AskUserChoiceInput(BaseModel):
|
||||
"""按钮选择工具输入。"""
|
||||
|
||||
explanation: str = Field(
|
||||
...,
|
||||
description="Clear explanation of why the agent needs the user to choose from buttons",
|
||||
)
|
||||
message: str = Field(
|
||||
...,
|
||||
description="Question or prompt shown to the user together with the buttons",
|
||||
)
|
||||
title: Optional[str] = Field(
|
||||
None,
|
||||
description="Optional short title displayed above the question",
|
||||
)
|
||||
options: List[UserChoiceOptionInput] = Field(
|
||||
...,
|
||||
description="Button options to show to the user",
|
||||
)
|
||||
|
||||
@model_validator(mode="after")
|
||||
def validate_payload(self):
|
||||
if not self.message.strip():
|
||||
raise ValueError("message 不能为空")
|
||||
if not self.options:
|
||||
raise ValueError("options 至少需要提供一个")
|
||||
return self
|
||||
|
||||
|
||||
class AskUserChoiceTool(MoviePilotTool):
|
||||
name: str = "ask_user_choice"
|
||||
description: str = (
|
||||
"Ask the user to choose from button options on channels that support interactive buttons. "
|
||||
"After the user clicks a button, the selected value will come back as the user's next message."
|
||||
)
|
||||
args_schema: Type[BaseModel] = AskUserChoiceInput
|
||||
require_admin: bool = False
|
||||
|
||||
def get_tool_message(self, **kwargs) -> Optional[str]:
|
||||
message = kwargs.get("message", "") or ""
|
||||
if len(message) > 40:
|
||||
message = message[:40] + "..."
|
||||
return f"正在发送按钮选择: {message}"
|
||||
|
||||
@staticmethod
|
||||
def _truncate_button_text(text: str, max_length: int) -> str:
|
||||
if max_length <= 0 or len(text) <= max_length:
|
||||
return text
|
||||
if max_length <= 3:
|
||||
return text[:max_length]
|
||||
return text[: max_length - 3] + "..."
|
||||
|
||||
async def run(
|
||||
self,
|
||||
message: str,
|
||||
options: List[UserChoiceOptionInput],
|
||||
title: Optional[str] = None,
|
||||
**kwargs,
|
||||
) -> str:
|
||||
if not self._channel or not self._source:
|
||||
return "当前不在可回传消息的会话中,无法发起按钮选择"
|
||||
|
||||
try:
|
||||
channel = MessageChannel(self._channel)
|
||||
except ValueError:
|
||||
return f"不支持的消息渠道: {self._channel}"
|
||||
|
||||
if not (
|
||||
ChannelCapabilityManager.supports_buttons(channel)
|
||||
and ChannelCapabilityManager.supports_callbacks(channel)
|
||||
):
|
||||
return f"当前渠道 {channel.value} 不支持按钮选择"
|
||||
|
||||
max_per_row = ChannelCapabilityManager.get_max_buttons_per_row(channel)
|
||||
max_rows = ChannelCapabilityManager.get_max_button_rows(channel)
|
||||
max_text_length = ChannelCapabilityManager.get_max_button_text_length(channel)
|
||||
max_options = max_per_row * max_rows
|
||||
if len(options) > max_options:
|
||||
return f"当前渠道最多支持 {max_options} 个按钮选项"
|
||||
|
||||
choice_options = [
|
||||
AgentChoiceOption(label=option.label.strip(), value=option.value.strip())
|
||||
for option in options
|
||||
]
|
||||
request = agent_user_choice_manager.create_request(
|
||||
session_id=self._session_id,
|
||||
user_id=str(self._user_id),
|
||||
channel=channel.value,
|
||||
source=self._source,
|
||||
username=self._username,
|
||||
prompt=message.strip(),
|
||||
options=choice_options,
|
||||
)
|
||||
|
||||
buttons = []
|
||||
current_row = []
|
||||
for index, option in enumerate(choice_options, start=1):
|
||||
current_row.append(
|
||||
{
|
||||
"text": self._truncate_button_text(option.label, max_text_length),
|
||||
"callback_data": f"agent_choice:{request.request_id}:{index}",
|
||||
}
|
||||
)
|
||||
if len(current_row) >= max_per_row:
|
||||
buttons.append(current_row)
|
||||
current_row = []
|
||||
if current_row:
|
||||
buttons.append(current_row)
|
||||
|
||||
logger.info(
|
||||
"执行工具: %s, channel=%s, session_id=%s, options=%s",
|
||||
self.name,
|
||||
channel.value,
|
||||
self._session_id,
|
||||
len(choice_options),
|
||||
)
|
||||
|
||||
await ToolChain().async_post_message(
|
||||
Notification(
|
||||
channel=channel,
|
||||
source=self._source,
|
||||
mtype=NotificationType.Agent,
|
||||
userid=self._user_id,
|
||||
username=self._username,
|
||||
title=title,
|
||||
text=message.strip(),
|
||||
buttons=buttons,
|
||||
)
|
||||
)
|
||||
|
||||
self._agent_context["user_reply_sent"] = True
|
||||
self._agent_context["reply_mode"] = "button_choice"
|
||||
return f"已发送 {len(choice_options)} 个按钮选项,等待用户选择"
|
||||
104
app/agent/user_choice.py
Normal file
104
app/agent/user_choice.py
Normal file
@@ -0,0 +1,104 @@
|
||||
"""Agent 用户按钮选择请求管理。"""
|
||||
|
||||
from dataclasses import dataclass, field
|
||||
from datetime import datetime, timedelta
|
||||
from threading import Lock
|
||||
from typing import Dict, List, Optional
|
||||
import uuid
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class AgentChoiceOption:
|
||||
"""按钮选项。"""
|
||||
|
||||
label: str
|
||||
value: str
|
||||
|
||||
|
||||
@dataclass
|
||||
class PendingAgentChoice:
|
||||
"""待处理的按钮选择请求。"""
|
||||
|
||||
request_id: str
|
||||
session_id: str
|
||||
user_id: str
|
||||
channel: Optional[str]
|
||||
source: Optional[str]
|
||||
username: Optional[str]
|
||||
prompt: str
|
||||
options: List[AgentChoiceOption]
|
||||
created_at: datetime = field(default_factory=datetime.now)
|
||||
|
||||
|
||||
class AgentUserChoiceManager:
|
||||
"""管理 Agent 发起的按钮选择请求。"""
|
||||
|
||||
_ttl = timedelta(hours=24)
|
||||
|
||||
def __init__(self):
|
||||
self._pending_choices: Dict[str, PendingAgentChoice] = {}
|
||||
self._lock = Lock()
|
||||
|
||||
def _cleanup_locked(self):
|
||||
expire_before = datetime.now() - self._ttl
|
||||
expired_ids = [
|
||||
request_id
|
||||
for request_id, request in self._pending_choices.items()
|
||||
if request.created_at < expire_before
|
||||
]
|
||||
for request_id in expired_ids:
|
||||
self._pending_choices.pop(request_id, None)
|
||||
|
||||
def create_request(
|
||||
self,
|
||||
session_id: str,
|
||||
user_id: str,
|
||||
channel: Optional[str],
|
||||
source: Optional[str],
|
||||
username: Optional[str],
|
||||
prompt: str,
|
||||
options: List[AgentChoiceOption],
|
||||
) -> PendingAgentChoice:
|
||||
with self._lock:
|
||||
self._cleanup_locked()
|
||||
request_id = uuid.uuid4().hex[:12]
|
||||
while request_id in self._pending_choices:
|
||||
request_id = uuid.uuid4().hex[:12]
|
||||
request = PendingAgentChoice(
|
||||
request_id=request_id,
|
||||
session_id=session_id,
|
||||
user_id=str(user_id),
|
||||
channel=channel,
|
||||
source=source,
|
||||
username=username,
|
||||
prompt=prompt,
|
||||
options=options,
|
||||
)
|
||||
self._pending_choices[request_id] = request
|
||||
return request
|
||||
|
||||
def resolve(
|
||||
self,
|
||||
request_id: str,
|
||||
option_index: int,
|
||||
user_id: Optional[str] = None,
|
||||
) -> Optional[tuple[PendingAgentChoice, AgentChoiceOption]]:
|
||||
with self._lock:
|
||||
self._cleanup_locked()
|
||||
request = self._pending_choices.get(request_id)
|
||||
if not request:
|
||||
return None
|
||||
if user_id is not None and str(request.user_id) != str(user_id):
|
||||
return None
|
||||
if option_index < 1 or option_index > len(request.options):
|
||||
return None
|
||||
option = request.options[option_index - 1]
|
||||
self._pending_choices.pop(request_id, None)
|
||||
return request, option
|
||||
|
||||
def clear(self):
|
||||
with self._lock:
|
||||
self._pending_choices.clear()
|
||||
|
||||
|
||||
agent_user_choice_manager = AgentUserChoiceManager()
|
||||
@@ -11,6 +11,7 @@ import uuid
|
||||
import base64
|
||||
|
||||
from app.agent import agent_manager
|
||||
from app.agent.user_choice import agent_user_choice_manager
|
||||
from app.chain import ChainBase
|
||||
from app.chain.download import DownloadChain
|
||||
from app.chain.media import MediaChain
|
||||
@@ -215,22 +216,12 @@ class MessageChain(ChainBase):
|
||||
|
||||
# 保存消息
|
||||
if not text.startswith("CALLBACK:"):
|
||||
self.messagehelper.put(
|
||||
CommingMessage(
|
||||
userid=userid,
|
||||
username=username,
|
||||
channel=channel,
|
||||
source=source,
|
||||
text=text,
|
||||
),
|
||||
role="user",
|
||||
)
|
||||
self.messageoper.add(
|
||||
self._record_user_message(
|
||||
channel=channel,
|
||||
source=source,
|
||||
userid=username or userid,
|
||||
userid=userid,
|
||||
username=username,
|
||||
text=text,
|
||||
action=0,
|
||||
)
|
||||
# 处理消息
|
||||
if text.startswith("CALLBACK:"):
|
||||
@@ -794,6 +785,15 @@ class MessageChain(ChainBase):
|
||||
):
|
||||
return
|
||||
|
||||
if self._handle_agent_choice_callback(
|
||||
callback_data=callback_data,
|
||||
channel=channel,
|
||||
source=source,
|
||||
userid=userid,
|
||||
username=username,
|
||||
):
|
||||
return
|
||||
|
||||
# 插件消息的事件回调 [PLUGIN]插件ID|内容
|
||||
if callback_data.startswith("[PLUGIN]"):
|
||||
# 提取插件ID和内容
|
||||
@@ -888,6 +888,76 @@ class MessageChain(ChainBase):
|
||||
)
|
||||
return True
|
||||
|
||||
@staticmethod
|
||||
def _parse_agent_choice_callback(
|
||||
callback_data: str,
|
||||
) -> Optional[tuple[str, int]]:
|
||||
"""
|
||||
解析 Agent 按钮选择回调。
|
||||
"""
|
||||
if not callback_data.startswith("agent_choice:"):
|
||||
return None
|
||||
try:
|
||||
_, request_id, option_index = callback_data.split(":", 2)
|
||||
except ValueError:
|
||||
return None
|
||||
if not request_id or not option_index.isdigit():
|
||||
return None
|
||||
return request_id, int(option_index)
|
||||
|
||||
def _handle_agent_choice_callback(
|
||||
self,
|
||||
callback_data: str,
|
||||
channel: MessageChannel,
|
||||
source: str,
|
||||
userid: Union[str, int],
|
||||
username: str,
|
||||
) -> bool:
|
||||
"""
|
||||
将 Agent 按钮选择回传为同一会话中的下一条用户消息。
|
||||
"""
|
||||
callback = self._parse_agent_choice_callback(callback_data)
|
||||
if not callback:
|
||||
return False
|
||||
|
||||
request_id, option_index = callback
|
||||
resolved = agent_user_choice_manager.resolve(
|
||||
request_id=request_id,
|
||||
option_index=option_index,
|
||||
user_id=str(userid),
|
||||
)
|
||||
if not resolved:
|
||||
self.post_message(
|
||||
Notification(
|
||||
channel=channel,
|
||||
source=source,
|
||||
userid=userid,
|
||||
username=username,
|
||||
title="该选择已失效,请重新发起选择",
|
||||
)
|
||||
)
|
||||
return True
|
||||
|
||||
request, option = resolved
|
||||
selected_text = option.value
|
||||
self._bind_session_id(userid, request.session_id)
|
||||
self._record_user_message(
|
||||
channel=channel,
|
||||
source=source,
|
||||
userid=userid,
|
||||
username=username,
|
||||
text=selected_text,
|
||||
)
|
||||
self._handle_ai_message(
|
||||
text=selected_text,
|
||||
channel=channel,
|
||||
source=source,
|
||||
userid=userid,
|
||||
username=username,
|
||||
session_id=request.session_id,
|
||||
)
|
||||
return True
|
||||
|
||||
def _retry_transfer_history(
|
||||
self,
|
||||
history_id: int,
|
||||
@@ -1308,6 +1378,41 @@ class MessageChain(ChainBase):
|
||||
logger.info(f"创建新会话ID: {new_session_id}, 用户: {userid}")
|
||||
return new_session_id
|
||||
|
||||
def _bind_session_id(self, userid: Union[str, int], session_id: str) -> None:
|
||||
"""
|
||||
将用户会话绑定到指定的 session_id,并刷新最后活动时间。
|
||||
"""
|
||||
self._user_sessions[userid] = (session_id, datetime.now())
|
||||
|
||||
def _record_user_message(
|
||||
self,
|
||||
channel: MessageChannel,
|
||||
source: str,
|
||||
userid: Union[str, int],
|
||||
username: str,
|
||||
text: str,
|
||||
) -> None:
|
||||
"""
|
||||
保存一条用户消息到消息历史与数据库。
|
||||
"""
|
||||
self.messagehelper.put(
|
||||
CommingMessage(
|
||||
userid=userid,
|
||||
username=username,
|
||||
channel=channel,
|
||||
source=source,
|
||||
text=text,
|
||||
),
|
||||
role="user",
|
||||
)
|
||||
self.messageoper.add(
|
||||
channel=channel,
|
||||
source=source,
|
||||
userid=username or userid,
|
||||
text=text,
|
||||
action=0,
|
||||
)
|
||||
|
||||
def clear_user_session(self, userid: Union[str, int]) -> bool:
|
||||
"""
|
||||
清除指定用户的会话信息
|
||||
@@ -1427,6 +1532,7 @@ class MessageChain(ChainBase):
|
||||
images: Optional[List[CommingMessage.MessageImage]] = None,
|
||||
files: Optional[List[CommingMessage.MessageAttachment]] = None,
|
||||
reply_with_voice: bool = False,
|
||||
session_id: Optional[str] = None,
|
||||
) -> None:
|
||||
"""
|
||||
处理AI智能体消息
|
||||
@@ -1466,7 +1572,8 @@ class MessageChain(ChainBase):
|
||||
return
|
||||
|
||||
# 生成或复用会话ID
|
||||
session_id = self._get_or_create_session_id(userid)
|
||||
session_id = session_id or self._get_or_create_session_id(userid)
|
||||
self._bind_session_id(userid, session_id)
|
||||
|
||||
# 下载图片并转为base64
|
||||
original_images = images
|
||||
|
||||
126
tests/test_agent_user_choice.py
Normal file
126
tests/test_agent_user_choice.py
Normal file
@@ -0,0 +1,126 @@
|
||||
import asyncio
|
||||
import unittest
|
||||
from unittest.mock import AsyncMock, patch
|
||||
|
||||
from app.agent.prompt import prompt_manager
|
||||
from app.agent.tools.factory import MoviePilotToolFactory
|
||||
from app.agent.tools.impl.ask_user_choice import (
|
||||
AskUserChoiceTool,
|
||||
UserChoiceOptionInput,
|
||||
)
|
||||
from app.agent.user_choice import AgentChoiceOption, agent_user_choice_manager
|
||||
from app.chain.message import MessageChain
|
||||
from app.schemas.types import MessageChannel
|
||||
|
||||
|
||||
class TestAgentUserChoice(unittest.TestCase):
|
||||
def tearDown(self):
|
||||
agent_user_choice_manager.clear()
|
||||
|
||||
def test_prompt_injects_choice_tool_hint_only_for_button_channels(self):
|
||||
telegram_prompt = prompt_manager.get_agent_prompt(
|
||||
channel=MessageChannel.Telegram.value
|
||||
)
|
||||
wechat_prompt = prompt_manager.get_agent_prompt(
|
||||
channel=MessageChannel.Wechat.value
|
||||
)
|
||||
|
||||
self.assertIn("ask_user_choice", telegram_prompt)
|
||||
self.assertNotIn("ask_user_choice", wechat_prompt)
|
||||
|
||||
def test_factory_injects_choice_tool_only_for_button_channels(self):
|
||||
with patch(
|
||||
"app.agent.tools.factory.PluginManager.get_plugin_agent_tools",
|
||||
return_value=[],
|
||||
):
|
||||
telegram_tools = MoviePilotToolFactory.create_tools(
|
||||
session_id="session-1",
|
||||
user_id="10001",
|
||||
channel=MessageChannel.Telegram.value,
|
||||
source="telegram-test",
|
||||
username="tester",
|
||||
)
|
||||
wechat_tools = MoviePilotToolFactory.create_tools(
|
||||
session_id="session-2",
|
||||
user_id="10001",
|
||||
channel=MessageChannel.Wechat.value,
|
||||
source="wechat-test",
|
||||
username="tester",
|
||||
)
|
||||
|
||||
self.assertIn("ask_user_choice", [tool.name for tool in telegram_tools])
|
||||
self.assertNotIn("ask_user_choice", [tool.name for tool in wechat_tools])
|
||||
|
||||
def test_choice_tool_sends_buttons_and_registers_pending_request(self):
|
||||
tool = AskUserChoiceTool(session_id="session-1", user_id="10001")
|
||||
tool.set_message_attr(
|
||||
channel=MessageChannel.Telegram.value,
|
||||
source="telegram-test",
|
||||
username="tester",
|
||||
)
|
||||
tool.set_agent_context(agent_context={})
|
||||
|
||||
with patch(
|
||||
"app.agent.tools.impl.ask_user_choice.ToolChain.async_post_message",
|
||||
new=AsyncMock(),
|
||||
) as async_post_message:
|
||||
result = asyncio.run(
|
||||
tool.run(
|
||||
message="请选择要执行的操作",
|
||||
options=[
|
||||
UserChoiceOptionInput(label="继续下载", value="继续下载"),
|
||||
UserChoiceOptionInput(label="先看详情", value="先看详情"),
|
||||
],
|
||||
title="需要你的选择",
|
||||
)
|
||||
)
|
||||
|
||||
self.assertIn("等待用户选择", result)
|
||||
self.assertTrue(tool._agent_context.get("user_reply_sent"))
|
||||
notification = async_post_message.await_args.args[0]
|
||||
self.assertEqual(notification.text, "请选择要执行的操作")
|
||||
self.assertEqual(len(notification.buttons[0]), 2)
|
||||
|
||||
callback_data = notification.buttons[0][0]["callback_data"]
|
||||
_, request_id, option_index = callback_data.split(":")
|
||||
resolved = agent_user_choice_manager.resolve(request_id, int(option_index), "10001")
|
||||
self.assertIsNotNone(resolved)
|
||||
_, option = resolved
|
||||
self.assertEqual(option.value, "继续下载")
|
||||
|
||||
def test_agent_choice_callback_routes_selected_value_back_to_agent(self):
|
||||
chain = MessageChain()
|
||||
request = agent_user_choice_manager.create_request(
|
||||
session_id="session-choice",
|
||||
user_id="10001",
|
||||
channel=MessageChannel.Telegram.value,
|
||||
source="telegram-test",
|
||||
username="tester",
|
||||
prompt="请选择",
|
||||
options=[
|
||||
AgentChoiceOption(label="电影", value="我选择电影"),
|
||||
AgentChoiceOption(label="电视剧", value="我选择电视剧"),
|
||||
],
|
||||
)
|
||||
|
||||
with patch.object(chain, "_handle_ai_message") as handle_ai_message, patch.object(
|
||||
chain.messagehelper, "put"
|
||||
) as message_put, patch.object(chain.messageoper, "add") as message_add:
|
||||
chain._handle_callback(
|
||||
text=f"CALLBACK:agent_choice:{request.request_id}:1",
|
||||
channel=MessageChannel.Telegram,
|
||||
source="telegram-test",
|
||||
userid="10001",
|
||||
username="tester",
|
||||
)
|
||||
|
||||
handle_ai_message.assert_called_once()
|
||||
kwargs = handle_ai_message.call_args.kwargs
|
||||
self.assertEqual(kwargs["text"], "我选择电影")
|
||||
self.assertEqual(kwargs["session_id"], "session-choice")
|
||||
message_put.assert_called_once()
|
||||
message_add.assert_called_once()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
Reference in New Issue
Block a user