refactor: generalize agent interaction requests

This commit is contained in:
jxxghp
2026-04-16 22:51:51 +08:00
parent cc31c66b93
commit e0e21e39a2
4 changed files with 151 additions and 47 deletions

View File

@@ -1,4 +1,4 @@
"""Agent 用户按钮选择请求管理。"""
"""Agent 客户端交互请求管理。"""
from dataclasses import dataclass, field
from datetime import datetime, timedelta
@@ -8,16 +8,16 @@ import uuid
@dataclass(frozen=True)
class AgentChoiceOption:
"""按钮选项。"""
class AgentInteractionOption:
"""交互选项。"""
label: str
value: str
@dataclass
class PendingAgentChoice:
"""待处理的按钮选择请求。"""
class PendingAgentInteraction:
"""待处理的 Agent 客户端交互请求。"""
request_id: str
session_id: str
@@ -25,29 +25,30 @@ class PendingAgentChoice:
channel: Optional[str]
source: Optional[str]
username: Optional[str]
title: Optional[str]
prompt: str
options: List[AgentChoiceOption]
options: List[AgentInteractionOption]
created_at: datetime = field(default_factory=datetime.now)
class AgentUserChoiceManager:
"""管理 Agent 发起的按钮选择请求。"""
class AgentInteractionManager:
"""管理 Agent 发起的客户端交互请求。"""
_ttl = timedelta(hours=24)
def __init__(self):
self._pending_choices: Dict[str, PendingAgentChoice] = {}
self._pending_interactions: Dict[str, PendingAgentInteraction] = {}
self._lock = Lock()
def _cleanup_locked(self):
expire_before = datetime.now() - self._ttl
expired_ids = [
request_id
for request_id, request in self._pending_choices.items()
for request_id, request in self._pending_interactions.items()
if request.created_at < expire_before
]
for request_id in expired_ids:
self._pending_choices.pop(request_id, None)
self._pending_interactions.pop(request_id, None)
def create_request(
self,
@@ -56,25 +57,27 @@ class AgentUserChoiceManager:
channel: Optional[str],
source: Optional[str],
username: Optional[str],
title: Optional[str],
prompt: str,
options: List[AgentChoiceOption],
) -> PendingAgentChoice:
options: List[AgentInteractionOption],
) -> PendingAgentInteraction:
with self._lock:
self._cleanup_locked()
request_id = uuid.uuid4().hex[:12]
while request_id in self._pending_choices:
while request_id in self._pending_interactions:
request_id = uuid.uuid4().hex[:12]
request = PendingAgentChoice(
request = PendingAgentInteraction(
request_id=request_id,
session_id=session_id,
user_id=str(user_id),
channel=channel,
source=source,
username=username,
title=title,
prompt=prompt,
options=options,
)
self._pending_choices[request_id] = request
self._pending_interactions[request_id] = request
return request
def resolve(
@@ -82,10 +85,10 @@ class AgentUserChoiceManager:
request_id: str,
option_index: int,
user_id: Optional[str] = None,
) -> Optional[tuple[PendingAgentChoice, AgentChoiceOption]]:
) -> Optional[tuple[PendingAgentInteraction, AgentInteractionOption]]:
with self._lock:
self._cleanup_locked()
request = self._pending_choices.get(request_id)
request = self._pending_interactions.get(request_id)
if not request:
return None
if user_id is not None and str(request.user_id) != str(user_id):
@@ -93,12 +96,12 @@ class AgentUserChoiceManager:
if option_index < 1 or option_index > len(request.options):
return None
option = request.options[option_index - 1]
self._pending_choices.pop(request_id, None)
self._pending_interactions.pop(request_id, None)
return request, option
def clear(self):
with self._lock:
self._pending_choices.clear()
self._pending_interactions.clear()
agent_user_choice_manager = AgentUserChoiceManager()
agent_interaction_manager = AgentInteractionManager()

View File

@@ -5,7 +5,10 @@ from typing import List, Optional, Type
from pydantic import BaseModel, Field, model_validator
from app.agent.tools.base import MoviePilotTool, ToolChain
from app.agent.user_choice import AgentChoiceOption, agent_user_choice_manager
from app.agent.interaction import (
AgentInteractionOption,
agent_interaction_manager,
)
from app.log import logger
from app.schemas import Notification, NotificationType
from app.schemas.message import ChannelCapabilityManager
@@ -111,15 +114,18 @@ class AskUserChoiceTool(MoviePilotTool):
return f"当前渠道最多支持 {max_options} 个按钮选项"
choice_options = [
AgentChoiceOption(label=option.label.strip(), value=option.value.strip())
AgentInteractionOption(
label=option.label.strip(), value=option.value.strip()
)
for option in options
]
request = agent_user_choice_manager.create_request(
request = agent_interaction_manager.create_request(
session_id=self._session_id,
user_id=str(self._user_id),
channel=channel.value,
source=self._source,
username=self._username,
title=title,
prompt=message.strip(),
options=choice_options,
)
@@ -130,7 +136,9 @@ class AskUserChoiceTool(MoviePilotTool):
current_row.append(
{
"text": self._truncate_button_text(option.label, max_text_length),
"callback_data": f"agent_choice:{request.request_id}:{index}",
"callback_data": (
f"agent_interaction:choice:{request.request_id}:{index}"
),
}
)
if len(current_row) >= max_per_row:

View File

@@ -11,7 +11,7 @@ import uuid
import base64
from app.agent import agent_manager
from app.agent.user_choice import agent_user_choice_manager
from app.agent.interaction import agent_interaction_manager
from app.chain import ChainBase
from app.chain.download import DownloadChain
from app.chain.media import MediaChain
@@ -791,6 +791,8 @@ class MessageChain(ChainBase):
source=source,
userid=userid,
username=username,
original_message_id=original_message_id,
original_chat_id=original_chat_id,
):
return
@@ -895,11 +897,18 @@ class MessageChain(ChainBase):
"""
解析 Agent 按钮选择回调。
"""
if not callback_data.startswith("agent_choice:"):
return None
try:
_, request_id, option_index = callback_data.split(":", 2)
except ValueError:
if callback_data.startswith("agent_interaction:choice:"):
try:
_, _, request_id, option_index = callback_data.split(":", 3)
except ValueError:
return None
elif callback_data.startswith("agent_choice:"):
# 兼容旧格式,避免已发送的按钮失效
try:
_, request_id, option_index = callback_data.split(":", 2)
except ValueError:
return None
else:
return None
if not request_id or not option_index.isdigit():
return None
@@ -912,6 +921,8 @@ class MessageChain(ChainBase):
source: str,
userid: Union[str, int],
username: str,
original_message_id: Optional[Union[str, int]] = None,
original_chat_id: Optional[str] = None,
) -> bool:
"""
将 Agent 按钮选择回传为同一会话中的下一条用户消息。
@@ -921,7 +932,7 @@ class MessageChain(ChainBase):
return False
request_id, option_index = callback
resolved = agent_user_choice_manager.resolve(
resolved = agent_interaction_manager.resolve(
request_id=request_id,
option_index=option_index,
user_id=str(userid),
@@ -940,6 +951,15 @@ class MessageChain(ChainBase):
request, option = resolved
selected_text = option.value
self._update_interaction_message_feedback(
channel=channel,
source=source,
original_message_id=original_message_id,
original_chat_id=original_chat_id,
title=request.title,
prompt=request.prompt,
selected_label=option.label,
)
self._bind_session_id(userid, request.session_id)
self._record_user_message(
channel=channel,
@@ -958,6 +978,35 @@ class MessageChain(ChainBase):
)
return True
def _update_interaction_message_feedback(
self,
channel: MessageChannel,
source: str,
original_message_id: Optional[Union[str, int]],
original_chat_id: Optional[str],
prompt: str,
selected_label: str,
title: Optional[str] = None,
) -> None:
"""
在用户点击交互按钮后,立即更新原消息,明确显示已选择的内容。
"""
if not original_message_id or not original_chat_id:
return
lines = [prompt.strip()]
if selected_label:
lines.append(f"已选择:{selected_label}")
feedback_text = "\n\n".join(line for line in lines if line)
self.edit_message(
channel=channel,
source=source,
message_id=original_message_id,
chat_id=original_chat_id,
title=title,
text=feedback_text,
)
def _retry_transfer_history(
self,
history_id: int,