mirror of
https://github.com/jxxghp/MoviePilot.git
synced 2026-06-05 07:29:56 +08:00
refactor: generalize agent interaction requests
This commit is contained in:
@@ -1,4 +1,4 @@
|
||||
"""Agent 用户按钮选择请求管理。"""
|
||||
"""Agent 客户端交互请求管理。"""
|
||||
|
||||
from dataclasses import dataclass, field
|
||||
from datetime import datetime, timedelta
|
||||
@@ -8,16 +8,16 @@ import uuid
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class AgentChoiceOption:
|
||||
"""按钮选项。"""
|
||||
class AgentInteractionOption:
|
||||
"""交互选项。"""
|
||||
|
||||
label: str
|
||||
value: str
|
||||
|
||||
|
||||
@dataclass
|
||||
class PendingAgentChoice:
|
||||
"""待处理的按钮选择请求。"""
|
||||
class PendingAgentInteraction:
|
||||
"""待处理的 Agent 客户端交互请求。"""
|
||||
|
||||
request_id: str
|
||||
session_id: str
|
||||
@@ -25,29 +25,30 @@ class PendingAgentChoice:
|
||||
channel: Optional[str]
|
||||
source: Optional[str]
|
||||
username: Optional[str]
|
||||
title: Optional[str]
|
||||
prompt: str
|
||||
options: List[AgentChoiceOption]
|
||||
options: List[AgentInteractionOption]
|
||||
created_at: datetime = field(default_factory=datetime.now)
|
||||
|
||||
|
||||
class AgentUserChoiceManager:
|
||||
"""管理 Agent 发起的按钮选择请求。"""
|
||||
class AgentInteractionManager:
|
||||
"""管理 Agent 发起的客户端交互请求。"""
|
||||
|
||||
_ttl = timedelta(hours=24)
|
||||
|
||||
def __init__(self):
|
||||
self._pending_choices: Dict[str, PendingAgentChoice] = {}
|
||||
self._pending_interactions: Dict[str, PendingAgentInteraction] = {}
|
||||
self._lock = Lock()
|
||||
|
||||
def _cleanup_locked(self):
|
||||
expire_before = datetime.now() - self._ttl
|
||||
expired_ids = [
|
||||
request_id
|
||||
for request_id, request in self._pending_choices.items()
|
||||
for request_id, request in self._pending_interactions.items()
|
||||
if request.created_at < expire_before
|
||||
]
|
||||
for request_id in expired_ids:
|
||||
self._pending_choices.pop(request_id, None)
|
||||
self._pending_interactions.pop(request_id, None)
|
||||
|
||||
def create_request(
|
||||
self,
|
||||
@@ -56,25 +57,27 @@ class AgentUserChoiceManager:
|
||||
channel: Optional[str],
|
||||
source: Optional[str],
|
||||
username: Optional[str],
|
||||
title: Optional[str],
|
||||
prompt: str,
|
||||
options: List[AgentChoiceOption],
|
||||
) -> PendingAgentChoice:
|
||||
options: List[AgentInteractionOption],
|
||||
) -> PendingAgentInteraction:
|
||||
with self._lock:
|
||||
self._cleanup_locked()
|
||||
request_id = uuid.uuid4().hex[:12]
|
||||
while request_id in self._pending_choices:
|
||||
while request_id in self._pending_interactions:
|
||||
request_id = uuid.uuid4().hex[:12]
|
||||
request = PendingAgentChoice(
|
||||
request = PendingAgentInteraction(
|
||||
request_id=request_id,
|
||||
session_id=session_id,
|
||||
user_id=str(user_id),
|
||||
channel=channel,
|
||||
source=source,
|
||||
username=username,
|
||||
title=title,
|
||||
prompt=prompt,
|
||||
options=options,
|
||||
)
|
||||
self._pending_choices[request_id] = request
|
||||
self._pending_interactions[request_id] = request
|
||||
return request
|
||||
|
||||
def resolve(
|
||||
@@ -82,10 +85,10 @@ class AgentUserChoiceManager:
|
||||
request_id: str,
|
||||
option_index: int,
|
||||
user_id: Optional[str] = None,
|
||||
) -> Optional[tuple[PendingAgentChoice, AgentChoiceOption]]:
|
||||
) -> Optional[tuple[PendingAgentInteraction, AgentInteractionOption]]:
|
||||
with self._lock:
|
||||
self._cleanup_locked()
|
||||
request = self._pending_choices.get(request_id)
|
||||
request = self._pending_interactions.get(request_id)
|
||||
if not request:
|
||||
return None
|
||||
if user_id is not None and str(request.user_id) != str(user_id):
|
||||
@@ -93,12 +96,12 @@ class AgentUserChoiceManager:
|
||||
if option_index < 1 or option_index > len(request.options):
|
||||
return None
|
||||
option = request.options[option_index - 1]
|
||||
self._pending_choices.pop(request_id, None)
|
||||
self._pending_interactions.pop(request_id, None)
|
||||
return request, option
|
||||
|
||||
def clear(self):
|
||||
with self._lock:
|
||||
self._pending_choices.clear()
|
||||
self._pending_interactions.clear()
|
||||
|
||||
|
||||
agent_user_choice_manager = AgentUserChoiceManager()
|
||||
agent_interaction_manager = AgentInteractionManager()
|
||||
@@ -5,7 +5,10 @@ from typing import List, Optional, Type
|
||||
from pydantic import BaseModel, Field, model_validator
|
||||
|
||||
from app.agent.tools.base import MoviePilotTool, ToolChain
|
||||
from app.agent.user_choice import AgentChoiceOption, agent_user_choice_manager
|
||||
from app.agent.interaction import (
|
||||
AgentInteractionOption,
|
||||
agent_interaction_manager,
|
||||
)
|
||||
from app.log import logger
|
||||
from app.schemas import Notification, NotificationType
|
||||
from app.schemas.message import ChannelCapabilityManager
|
||||
@@ -111,15 +114,18 @@ class AskUserChoiceTool(MoviePilotTool):
|
||||
return f"当前渠道最多支持 {max_options} 个按钮选项"
|
||||
|
||||
choice_options = [
|
||||
AgentChoiceOption(label=option.label.strip(), value=option.value.strip())
|
||||
AgentInteractionOption(
|
||||
label=option.label.strip(), value=option.value.strip()
|
||||
)
|
||||
for option in options
|
||||
]
|
||||
request = agent_user_choice_manager.create_request(
|
||||
request = agent_interaction_manager.create_request(
|
||||
session_id=self._session_id,
|
||||
user_id=str(self._user_id),
|
||||
channel=channel.value,
|
||||
source=self._source,
|
||||
username=self._username,
|
||||
title=title,
|
||||
prompt=message.strip(),
|
||||
options=choice_options,
|
||||
)
|
||||
@@ -130,7 +136,9 @@ class AskUserChoiceTool(MoviePilotTool):
|
||||
current_row.append(
|
||||
{
|
||||
"text": self._truncate_button_text(option.label, max_text_length),
|
||||
"callback_data": f"agent_choice:{request.request_id}:{index}",
|
||||
"callback_data": (
|
||||
f"agent_interaction:choice:{request.request_id}:{index}"
|
||||
),
|
||||
}
|
||||
)
|
||||
if len(current_row) >= max_per_row:
|
||||
|
||||
@@ -11,7 +11,7 @@ import uuid
|
||||
import base64
|
||||
|
||||
from app.agent import agent_manager
|
||||
from app.agent.user_choice import agent_user_choice_manager
|
||||
from app.agent.interaction import agent_interaction_manager
|
||||
from app.chain import ChainBase
|
||||
from app.chain.download import DownloadChain
|
||||
from app.chain.media import MediaChain
|
||||
@@ -791,6 +791,8 @@ class MessageChain(ChainBase):
|
||||
source=source,
|
||||
userid=userid,
|
||||
username=username,
|
||||
original_message_id=original_message_id,
|
||||
original_chat_id=original_chat_id,
|
||||
):
|
||||
return
|
||||
|
||||
@@ -895,11 +897,18 @@ class MessageChain(ChainBase):
|
||||
"""
|
||||
解析 Agent 按钮选择回调。
|
||||
"""
|
||||
if not callback_data.startswith("agent_choice:"):
|
||||
return None
|
||||
try:
|
||||
_, request_id, option_index = callback_data.split(":", 2)
|
||||
except ValueError:
|
||||
if callback_data.startswith("agent_interaction:choice:"):
|
||||
try:
|
||||
_, _, request_id, option_index = callback_data.split(":", 3)
|
||||
except ValueError:
|
||||
return None
|
||||
elif callback_data.startswith("agent_choice:"):
|
||||
# 兼容旧格式,避免已发送的按钮失效
|
||||
try:
|
||||
_, request_id, option_index = callback_data.split(":", 2)
|
||||
except ValueError:
|
||||
return None
|
||||
else:
|
||||
return None
|
||||
if not request_id or not option_index.isdigit():
|
||||
return None
|
||||
@@ -912,6 +921,8 @@ class MessageChain(ChainBase):
|
||||
source: str,
|
||||
userid: Union[str, int],
|
||||
username: str,
|
||||
original_message_id: Optional[Union[str, int]] = None,
|
||||
original_chat_id: Optional[str] = None,
|
||||
) -> bool:
|
||||
"""
|
||||
将 Agent 按钮选择回传为同一会话中的下一条用户消息。
|
||||
@@ -921,7 +932,7 @@ class MessageChain(ChainBase):
|
||||
return False
|
||||
|
||||
request_id, option_index = callback
|
||||
resolved = agent_user_choice_manager.resolve(
|
||||
resolved = agent_interaction_manager.resolve(
|
||||
request_id=request_id,
|
||||
option_index=option_index,
|
||||
user_id=str(userid),
|
||||
@@ -940,6 +951,15 @@ class MessageChain(ChainBase):
|
||||
|
||||
request, option = resolved
|
||||
selected_text = option.value
|
||||
self._update_interaction_message_feedback(
|
||||
channel=channel,
|
||||
source=source,
|
||||
original_message_id=original_message_id,
|
||||
original_chat_id=original_chat_id,
|
||||
title=request.title,
|
||||
prompt=request.prompt,
|
||||
selected_label=option.label,
|
||||
)
|
||||
self._bind_session_id(userid, request.session_id)
|
||||
self._record_user_message(
|
||||
channel=channel,
|
||||
@@ -958,6 +978,35 @@ class MessageChain(ChainBase):
|
||||
)
|
||||
return True
|
||||
|
||||
def _update_interaction_message_feedback(
|
||||
self,
|
||||
channel: MessageChannel,
|
||||
source: str,
|
||||
original_message_id: Optional[Union[str, int]],
|
||||
original_chat_id: Optional[str],
|
||||
prompt: str,
|
||||
selected_label: str,
|
||||
title: Optional[str] = None,
|
||||
) -> None:
|
||||
"""
|
||||
在用户点击交互按钮后,立即更新原消息,明确显示已选择的内容。
|
||||
"""
|
||||
if not original_message_id or not original_chat_id:
|
||||
return
|
||||
|
||||
lines = [prompt.strip()]
|
||||
if selected_label:
|
||||
lines.append(f"已选择:{selected_label}")
|
||||
feedback_text = "\n\n".join(line for line in lines if line)
|
||||
self.edit_message(
|
||||
channel=channel,
|
||||
source=source,
|
||||
message_id=original_message_id,
|
||||
chat_id=original_chat_id,
|
||||
title=title,
|
||||
text=feedback_text,
|
||||
)
|
||||
|
||||
def _retry_transfer_history(
|
||||
self,
|
||||
history_id: int,
|
||||
|
||||
Reference in New Issue
Block a user