refactor: generalize agent interaction requests

This commit is contained in:
jxxghp
2026-04-16 22:51:51 +08:00
parent cc31c66b93
commit e0e21e39a2
4 changed files with 151 additions and 47 deletions

View File

@@ -1,4 +1,4 @@
"""Agent 用户按钮选择请求管理。"""
"""Agent 客户端交互请求管理。"""
from dataclasses import dataclass, field
from datetime import datetime, timedelta
@@ -8,16 +8,16 @@ import uuid
@dataclass(frozen=True)
class AgentChoiceOption:
"""按钮选项。"""
class AgentInteractionOption:
"""交互选项。"""
label: str
value: str
@dataclass
class PendingAgentChoice:
"""待处理的按钮选择请求。"""
class PendingAgentInteraction:
"""待处理的 Agent 客户端交互请求。"""
request_id: str
session_id: str
@@ -25,29 +25,30 @@ class PendingAgentChoice:
channel: Optional[str]
source: Optional[str]
username: Optional[str]
title: Optional[str]
prompt: str
options: List[AgentChoiceOption]
options: List[AgentInteractionOption]
created_at: datetime = field(default_factory=datetime.now)
class AgentUserChoiceManager:
"""管理 Agent 发起的按钮选择请求。"""
class AgentInteractionManager:
"""管理 Agent 发起的客户端交互请求。"""
_ttl = timedelta(hours=24)
def __init__(self):
self._pending_choices: Dict[str, PendingAgentChoice] = {}
self._pending_interactions: Dict[str, PendingAgentInteraction] = {}
self._lock = Lock()
def _cleanup_locked(self):
expire_before = datetime.now() - self._ttl
expired_ids = [
request_id
for request_id, request in self._pending_choices.items()
for request_id, request in self._pending_interactions.items()
if request.created_at < expire_before
]
for request_id in expired_ids:
self._pending_choices.pop(request_id, None)
self._pending_interactions.pop(request_id, None)
def create_request(
self,
@@ -56,25 +57,27 @@ class AgentUserChoiceManager:
channel: Optional[str],
source: Optional[str],
username: Optional[str],
title: Optional[str],
prompt: str,
options: List[AgentChoiceOption],
) -> PendingAgentChoice:
options: List[AgentInteractionOption],
) -> PendingAgentInteraction:
with self._lock:
self._cleanup_locked()
request_id = uuid.uuid4().hex[:12]
while request_id in self._pending_choices:
while request_id in self._pending_interactions:
request_id = uuid.uuid4().hex[:12]
request = PendingAgentChoice(
request = PendingAgentInteraction(
request_id=request_id,
session_id=session_id,
user_id=str(user_id),
channel=channel,
source=source,
username=username,
title=title,
prompt=prompt,
options=options,
)
self._pending_choices[request_id] = request
self._pending_interactions[request_id] = request
return request
def resolve(
@@ -82,10 +85,10 @@ class AgentUserChoiceManager:
request_id: str,
option_index: int,
user_id: Optional[str] = None,
) -> Optional[tuple[PendingAgentChoice, AgentChoiceOption]]:
) -> Optional[tuple[PendingAgentInteraction, AgentInteractionOption]]:
with self._lock:
self._cleanup_locked()
request = self._pending_choices.get(request_id)
request = self._pending_interactions.get(request_id)
if not request:
return None
if user_id is not None and str(request.user_id) != str(user_id):
@@ -93,12 +96,12 @@ class AgentUserChoiceManager:
if option_index < 1 or option_index > len(request.options):
return None
option = request.options[option_index - 1]
self._pending_choices.pop(request_id, None)
self._pending_interactions.pop(request_id, None)
return request, option
def clear(self):
with self._lock:
self._pending_choices.clear()
self._pending_interactions.clear()
agent_user_choice_manager = AgentUserChoiceManager()
agent_interaction_manager = AgentInteractionManager()

View File

@@ -5,7 +5,10 @@ from typing import List, Optional, Type
from pydantic import BaseModel, Field, model_validator
from app.agent.tools.base import MoviePilotTool, ToolChain
from app.agent.user_choice import AgentChoiceOption, agent_user_choice_manager
from app.agent.interaction import (
AgentInteractionOption,
agent_interaction_manager,
)
from app.log import logger
from app.schemas import Notification, NotificationType
from app.schemas.message import ChannelCapabilityManager
@@ -111,15 +114,18 @@ class AskUserChoiceTool(MoviePilotTool):
return f"当前渠道最多支持 {max_options} 个按钮选项"
choice_options = [
AgentChoiceOption(label=option.label.strip(), value=option.value.strip())
AgentInteractionOption(
label=option.label.strip(), value=option.value.strip()
)
for option in options
]
request = agent_user_choice_manager.create_request(
request = agent_interaction_manager.create_request(
session_id=self._session_id,
user_id=str(self._user_id),
channel=channel.value,
source=self._source,
username=self._username,
title=title,
prompt=message.strip(),
options=choice_options,
)
@@ -130,7 +136,9 @@ class AskUserChoiceTool(MoviePilotTool):
current_row.append(
{
"text": self._truncate_button_text(option.label, max_text_length),
"callback_data": f"agent_choice:{request.request_id}:{index}",
"callback_data": (
f"agent_interaction:choice:{request.request_id}:{index}"
),
}
)
if len(current_row) >= max_per_row:

View File

@@ -11,7 +11,7 @@ import uuid
import base64
from app.agent import agent_manager
from app.agent.user_choice import agent_user_choice_manager
from app.agent.interaction import agent_interaction_manager
from app.chain import ChainBase
from app.chain.download import DownloadChain
from app.chain.media import MediaChain
@@ -791,6 +791,8 @@ class MessageChain(ChainBase):
source=source,
userid=userid,
username=username,
original_message_id=original_message_id,
original_chat_id=original_chat_id,
):
return
@@ -895,11 +897,18 @@ class MessageChain(ChainBase):
"""
解析 Agent 按钮选择回调。
"""
if not callback_data.startswith("agent_choice:"):
return None
try:
_, request_id, option_index = callback_data.split(":", 2)
except ValueError:
if callback_data.startswith("agent_interaction:choice:"):
try:
_, _, request_id, option_index = callback_data.split(":", 3)
except ValueError:
return None
elif callback_data.startswith("agent_choice:"):
# 兼容旧格式,避免已发送的按钮失效
try:
_, request_id, option_index = callback_data.split(":", 2)
except ValueError:
return None
else:
return None
if not request_id or not option_index.isdigit():
return None
@@ -912,6 +921,8 @@ class MessageChain(ChainBase):
source: str,
userid: Union[str, int],
username: str,
original_message_id: Optional[Union[str, int]] = None,
original_chat_id: Optional[str] = None,
) -> bool:
"""
将 Agent 按钮选择回传为同一会话中的下一条用户消息。
@@ -921,7 +932,7 @@ class MessageChain(ChainBase):
return False
request_id, option_index = callback
resolved = agent_user_choice_manager.resolve(
resolved = agent_interaction_manager.resolve(
request_id=request_id,
option_index=option_index,
user_id=str(userid),
@@ -940,6 +951,15 @@ class MessageChain(ChainBase):
request, option = resolved
selected_text = option.value
self._update_interaction_message_feedback(
channel=channel,
source=source,
original_message_id=original_message_id,
original_chat_id=original_chat_id,
title=request.title,
prompt=request.prompt,
selected_label=option.label,
)
self._bind_session_id(userid, request.session_id)
self._record_user_message(
channel=channel,
@@ -958,6 +978,35 @@ class MessageChain(ChainBase):
)
return True
def _update_interaction_message_feedback(
self,
channel: MessageChannel,
source: str,
original_message_id: Optional[Union[str, int]],
original_chat_id: Optional[str],
prompt: str,
selected_label: str,
title: Optional[str] = None,
) -> None:
"""
在用户点击交互按钮后,立即更新原消息,明确显示已选择的内容。
"""
if not original_message_id or not original_chat_id:
return
lines = [prompt.strip()]
if selected_label:
lines.append(f"已选择:{selected_label}")
feedback_text = "\n\n".join(line for line in lines if line)
self.edit_message(
channel=channel,
source=source,
message_id=original_message_id,
chat_id=original_chat_id,
title=title,
text=feedback_text,
)
def _retry_transfer_history(
self,
history_id: int,

View File

@@ -8,14 +8,17 @@ from app.agent.tools.impl.ask_user_choice import (
AskUserChoiceTool,
UserChoiceOptionInput,
)
from app.agent.user_choice import AgentChoiceOption, agent_user_choice_manager
from app.agent.interaction import (
AgentInteractionOption,
agent_interaction_manager,
)
from app.chain.message import MessageChain
from app.schemas.types import MessageChannel
class TestAgentUserChoice(unittest.TestCase):
class TestAgentInteraction(unittest.TestCase):
def tearDown(self):
agent_user_choice_manager.clear()
agent_interaction_manager.clear()
def test_prompt_injects_choice_tool_hint_only_for_button_channels(self):
telegram_prompt = prompt_manager.get_agent_prompt(
@@ -82,30 +85,76 @@ class TestAgentUserChoice(unittest.TestCase):
self.assertEqual(len(notification.buttons[0]), 2)
callback_data = notification.buttons[0][0]["callback_data"]
_, request_id, option_index = callback_data.split(":")
resolved = agent_user_choice_manager.resolve(request_id, int(option_index), "10001")
_, _, request_id, option_index = callback_data.split(":")
resolved = agent_interaction_manager.resolve(
request_id, int(option_index), "10001"
)
self.assertIsNotNone(resolved)
_, option = resolved
self.assertEqual(option.value, "继续下载")
def test_agent_choice_callback_routes_selected_value_back_to_agent(self):
def test_agent_interaction_callback_routes_selected_value_back_to_agent(self):
chain = MessageChain()
request = agent_user_choice_manager.create_request(
request = agent_interaction_manager.create_request(
session_id="session-choice",
user_id="10001",
channel=MessageChannel.Telegram.value,
source="telegram-test",
username="tester",
title="需要你的选择",
prompt="请选择",
options=[
AgentChoiceOption(label="电影", value="我选择电影"),
AgentChoiceOption(label="电视剧", value="我选择电视剧"),
AgentInteractionOption(label="电影", value="我选择电影"),
AgentInteractionOption(label="电视剧", value="我选择电视剧"),
],
)
with patch.object(chain, "_handle_ai_message") as handle_ai_message, patch.object(
chain.messagehelper, "put"
) as message_put, patch.object(chain.messageoper, "add") as message_add:
) as message_put, patch.object(chain.messageoper, "add") as message_add, patch.object(
chain, "edit_message", return_value=True
) as edit_message:
chain._handle_callback(
text=f"CALLBACK:agent_interaction:choice:{request.request_id}:1",
channel=MessageChannel.Telegram,
source="telegram-test",
userid="10001",
username="tester",
original_message_id=123,
original_chat_id="456",
)
handle_ai_message.assert_called_once()
edit_message.assert_called_once_with(
channel=MessageChannel.Telegram,
source="telegram-test",
message_id=123,
chat_id="456",
title="需要你的选择",
text="请选择\n\n已选择:电影",
)
kwargs = handle_ai_message.call_args.kwargs
self.assertEqual(kwargs["text"], "我选择电影")
self.assertEqual(kwargs["session_id"], "session-choice")
message_put.assert_called_once()
message_add.assert_called_once()
def test_legacy_agent_choice_callback_still_supported(self):
chain = MessageChain()
request = agent_interaction_manager.create_request(
session_id="session-choice",
user_id="10001",
channel=MessageChannel.Telegram.value,
source="telegram-test",
username="tester",
title=None,
prompt="请选择",
options=[AgentInteractionOption(label="电影", value="我选择电影")],
)
with patch.object(chain, "_handle_ai_message") as handle_ai_message, patch.object(
chain.messagehelper, "put"
), patch.object(chain.messageoper, "add"):
chain._handle_callback(
text=f"CALLBACK:agent_choice:{request.request_id}:1",
channel=MessageChannel.Telegram,
@@ -115,11 +164,6 @@ class TestAgentUserChoice(unittest.TestCase):
)
handle_ai_message.assert_called_once()
kwargs = handle_ai_message.call_args.kwargs
self.assertEqual(kwargs["text"], "我选择电影")
self.assertEqual(kwargs["session_id"], "session-choice")
message_put.assert_called_once()
message_add.assert_called_once()
if __name__ == "__main__":