mirror of
https://github.com/jxxghp/MoviePilot.git
synced 2026-05-06 20:42:43 +08:00
refactor: generalize agent interaction requests
This commit is contained in:
@@ -1,4 +1,4 @@
|
||||
"""Agent 用户按钮选择请求管理。"""
|
||||
"""Agent 客户端交互请求管理。"""
|
||||
|
||||
from dataclasses import dataclass, field
|
||||
from datetime import datetime, timedelta
|
||||
@@ -8,16 +8,16 @@ import uuid
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class AgentChoiceOption:
|
||||
"""按钮选项。"""
|
||||
class AgentInteractionOption:
|
||||
"""交互选项。"""
|
||||
|
||||
label: str
|
||||
value: str
|
||||
|
||||
|
||||
@dataclass
|
||||
class PendingAgentChoice:
|
||||
"""待处理的按钮选择请求。"""
|
||||
class PendingAgentInteraction:
|
||||
"""待处理的 Agent 客户端交互请求。"""
|
||||
|
||||
request_id: str
|
||||
session_id: str
|
||||
@@ -25,29 +25,30 @@ class PendingAgentChoice:
|
||||
channel: Optional[str]
|
||||
source: Optional[str]
|
||||
username: Optional[str]
|
||||
title: Optional[str]
|
||||
prompt: str
|
||||
options: List[AgentChoiceOption]
|
||||
options: List[AgentInteractionOption]
|
||||
created_at: datetime = field(default_factory=datetime.now)
|
||||
|
||||
|
||||
class AgentUserChoiceManager:
|
||||
"""管理 Agent 发起的按钮选择请求。"""
|
||||
class AgentInteractionManager:
|
||||
"""管理 Agent 发起的客户端交互请求。"""
|
||||
|
||||
_ttl = timedelta(hours=24)
|
||||
|
||||
def __init__(self):
|
||||
self._pending_choices: Dict[str, PendingAgentChoice] = {}
|
||||
self._pending_interactions: Dict[str, PendingAgentInteraction] = {}
|
||||
self._lock = Lock()
|
||||
|
||||
def _cleanup_locked(self):
|
||||
expire_before = datetime.now() - self._ttl
|
||||
expired_ids = [
|
||||
request_id
|
||||
for request_id, request in self._pending_choices.items()
|
||||
for request_id, request in self._pending_interactions.items()
|
||||
if request.created_at < expire_before
|
||||
]
|
||||
for request_id in expired_ids:
|
||||
self._pending_choices.pop(request_id, None)
|
||||
self._pending_interactions.pop(request_id, None)
|
||||
|
||||
def create_request(
|
||||
self,
|
||||
@@ -56,25 +57,27 @@ class AgentUserChoiceManager:
|
||||
channel: Optional[str],
|
||||
source: Optional[str],
|
||||
username: Optional[str],
|
||||
title: Optional[str],
|
||||
prompt: str,
|
||||
options: List[AgentChoiceOption],
|
||||
) -> PendingAgentChoice:
|
||||
options: List[AgentInteractionOption],
|
||||
) -> PendingAgentInteraction:
|
||||
with self._lock:
|
||||
self._cleanup_locked()
|
||||
request_id = uuid.uuid4().hex[:12]
|
||||
while request_id in self._pending_choices:
|
||||
while request_id in self._pending_interactions:
|
||||
request_id = uuid.uuid4().hex[:12]
|
||||
request = PendingAgentChoice(
|
||||
request = PendingAgentInteraction(
|
||||
request_id=request_id,
|
||||
session_id=session_id,
|
||||
user_id=str(user_id),
|
||||
channel=channel,
|
||||
source=source,
|
||||
username=username,
|
||||
title=title,
|
||||
prompt=prompt,
|
||||
options=options,
|
||||
)
|
||||
self._pending_choices[request_id] = request
|
||||
self._pending_interactions[request_id] = request
|
||||
return request
|
||||
|
||||
def resolve(
|
||||
@@ -82,10 +85,10 @@ class AgentUserChoiceManager:
|
||||
request_id: str,
|
||||
option_index: int,
|
||||
user_id: Optional[str] = None,
|
||||
) -> Optional[tuple[PendingAgentChoice, AgentChoiceOption]]:
|
||||
) -> Optional[tuple[PendingAgentInteraction, AgentInteractionOption]]:
|
||||
with self._lock:
|
||||
self._cleanup_locked()
|
||||
request = self._pending_choices.get(request_id)
|
||||
request = self._pending_interactions.get(request_id)
|
||||
if not request:
|
||||
return None
|
||||
if user_id is not None and str(request.user_id) != str(user_id):
|
||||
@@ -93,12 +96,12 @@ class AgentUserChoiceManager:
|
||||
if option_index < 1 or option_index > len(request.options):
|
||||
return None
|
||||
option = request.options[option_index - 1]
|
||||
self._pending_choices.pop(request_id, None)
|
||||
self._pending_interactions.pop(request_id, None)
|
||||
return request, option
|
||||
|
||||
def clear(self):
|
||||
with self._lock:
|
||||
self._pending_choices.clear()
|
||||
self._pending_interactions.clear()
|
||||
|
||||
|
||||
agent_user_choice_manager = AgentUserChoiceManager()
|
||||
agent_interaction_manager = AgentInteractionManager()
|
||||
@@ -5,7 +5,10 @@ from typing import List, Optional, Type
|
||||
from pydantic import BaseModel, Field, model_validator
|
||||
|
||||
from app.agent.tools.base import MoviePilotTool, ToolChain
|
||||
from app.agent.user_choice import AgentChoiceOption, agent_user_choice_manager
|
||||
from app.agent.interaction import (
|
||||
AgentInteractionOption,
|
||||
agent_interaction_manager,
|
||||
)
|
||||
from app.log import logger
|
||||
from app.schemas import Notification, NotificationType
|
||||
from app.schemas.message import ChannelCapabilityManager
|
||||
@@ -111,15 +114,18 @@ class AskUserChoiceTool(MoviePilotTool):
|
||||
return f"当前渠道最多支持 {max_options} 个按钮选项"
|
||||
|
||||
choice_options = [
|
||||
AgentChoiceOption(label=option.label.strip(), value=option.value.strip())
|
||||
AgentInteractionOption(
|
||||
label=option.label.strip(), value=option.value.strip()
|
||||
)
|
||||
for option in options
|
||||
]
|
||||
request = agent_user_choice_manager.create_request(
|
||||
request = agent_interaction_manager.create_request(
|
||||
session_id=self._session_id,
|
||||
user_id=str(self._user_id),
|
||||
channel=channel.value,
|
||||
source=self._source,
|
||||
username=self._username,
|
||||
title=title,
|
||||
prompt=message.strip(),
|
||||
options=choice_options,
|
||||
)
|
||||
@@ -130,7 +136,9 @@ class AskUserChoiceTool(MoviePilotTool):
|
||||
current_row.append(
|
||||
{
|
||||
"text": self._truncate_button_text(option.label, max_text_length),
|
||||
"callback_data": f"agent_choice:{request.request_id}:{index}",
|
||||
"callback_data": (
|
||||
f"agent_interaction:choice:{request.request_id}:{index}"
|
||||
),
|
||||
}
|
||||
)
|
||||
if len(current_row) >= max_per_row:
|
||||
|
||||
@@ -11,7 +11,7 @@ import uuid
|
||||
import base64
|
||||
|
||||
from app.agent import agent_manager
|
||||
from app.agent.user_choice import agent_user_choice_manager
|
||||
from app.agent.interaction import agent_interaction_manager
|
||||
from app.chain import ChainBase
|
||||
from app.chain.download import DownloadChain
|
||||
from app.chain.media import MediaChain
|
||||
@@ -791,6 +791,8 @@ class MessageChain(ChainBase):
|
||||
source=source,
|
||||
userid=userid,
|
||||
username=username,
|
||||
original_message_id=original_message_id,
|
||||
original_chat_id=original_chat_id,
|
||||
):
|
||||
return
|
||||
|
||||
@@ -895,11 +897,18 @@ class MessageChain(ChainBase):
|
||||
"""
|
||||
解析 Agent 按钮选择回调。
|
||||
"""
|
||||
if not callback_data.startswith("agent_choice:"):
|
||||
return None
|
||||
try:
|
||||
_, request_id, option_index = callback_data.split(":", 2)
|
||||
except ValueError:
|
||||
if callback_data.startswith("agent_interaction:choice:"):
|
||||
try:
|
||||
_, _, request_id, option_index = callback_data.split(":", 3)
|
||||
except ValueError:
|
||||
return None
|
||||
elif callback_data.startswith("agent_choice:"):
|
||||
# 兼容旧格式,避免已发送的按钮失效
|
||||
try:
|
||||
_, request_id, option_index = callback_data.split(":", 2)
|
||||
except ValueError:
|
||||
return None
|
||||
else:
|
||||
return None
|
||||
if not request_id or not option_index.isdigit():
|
||||
return None
|
||||
@@ -912,6 +921,8 @@ class MessageChain(ChainBase):
|
||||
source: str,
|
||||
userid: Union[str, int],
|
||||
username: str,
|
||||
original_message_id: Optional[Union[str, int]] = None,
|
||||
original_chat_id: Optional[str] = None,
|
||||
) -> bool:
|
||||
"""
|
||||
将 Agent 按钮选择回传为同一会话中的下一条用户消息。
|
||||
@@ -921,7 +932,7 @@ class MessageChain(ChainBase):
|
||||
return False
|
||||
|
||||
request_id, option_index = callback
|
||||
resolved = agent_user_choice_manager.resolve(
|
||||
resolved = agent_interaction_manager.resolve(
|
||||
request_id=request_id,
|
||||
option_index=option_index,
|
||||
user_id=str(userid),
|
||||
@@ -940,6 +951,15 @@ class MessageChain(ChainBase):
|
||||
|
||||
request, option = resolved
|
||||
selected_text = option.value
|
||||
self._update_interaction_message_feedback(
|
||||
channel=channel,
|
||||
source=source,
|
||||
original_message_id=original_message_id,
|
||||
original_chat_id=original_chat_id,
|
||||
title=request.title,
|
||||
prompt=request.prompt,
|
||||
selected_label=option.label,
|
||||
)
|
||||
self._bind_session_id(userid, request.session_id)
|
||||
self._record_user_message(
|
||||
channel=channel,
|
||||
@@ -958,6 +978,35 @@ class MessageChain(ChainBase):
|
||||
)
|
||||
return True
|
||||
|
||||
def _update_interaction_message_feedback(
|
||||
self,
|
||||
channel: MessageChannel,
|
||||
source: str,
|
||||
original_message_id: Optional[Union[str, int]],
|
||||
original_chat_id: Optional[str],
|
||||
prompt: str,
|
||||
selected_label: str,
|
||||
title: Optional[str] = None,
|
||||
) -> None:
|
||||
"""
|
||||
在用户点击交互按钮后,立即更新原消息,明确显示已选择的内容。
|
||||
"""
|
||||
if not original_message_id or not original_chat_id:
|
||||
return
|
||||
|
||||
lines = [prompt.strip()]
|
||||
if selected_label:
|
||||
lines.append(f"已选择:{selected_label}")
|
||||
feedback_text = "\n\n".join(line for line in lines if line)
|
||||
self.edit_message(
|
||||
channel=channel,
|
||||
source=source,
|
||||
message_id=original_message_id,
|
||||
chat_id=original_chat_id,
|
||||
title=title,
|
||||
text=feedback_text,
|
||||
)
|
||||
|
||||
def _retry_transfer_history(
|
||||
self,
|
||||
history_id: int,
|
||||
|
||||
@@ -8,14 +8,17 @@ from app.agent.tools.impl.ask_user_choice import (
|
||||
AskUserChoiceTool,
|
||||
UserChoiceOptionInput,
|
||||
)
|
||||
from app.agent.user_choice import AgentChoiceOption, agent_user_choice_manager
|
||||
from app.agent.interaction import (
|
||||
AgentInteractionOption,
|
||||
agent_interaction_manager,
|
||||
)
|
||||
from app.chain.message import MessageChain
|
||||
from app.schemas.types import MessageChannel
|
||||
|
||||
|
||||
class TestAgentUserChoice(unittest.TestCase):
|
||||
class TestAgentInteraction(unittest.TestCase):
|
||||
def tearDown(self):
|
||||
agent_user_choice_manager.clear()
|
||||
agent_interaction_manager.clear()
|
||||
|
||||
def test_prompt_injects_choice_tool_hint_only_for_button_channels(self):
|
||||
telegram_prompt = prompt_manager.get_agent_prompt(
|
||||
@@ -82,30 +85,76 @@ class TestAgentUserChoice(unittest.TestCase):
|
||||
self.assertEqual(len(notification.buttons[0]), 2)
|
||||
|
||||
callback_data = notification.buttons[0][0]["callback_data"]
|
||||
_, request_id, option_index = callback_data.split(":")
|
||||
resolved = agent_user_choice_manager.resolve(request_id, int(option_index), "10001")
|
||||
_, _, request_id, option_index = callback_data.split(":")
|
||||
resolved = agent_interaction_manager.resolve(
|
||||
request_id, int(option_index), "10001"
|
||||
)
|
||||
self.assertIsNotNone(resolved)
|
||||
_, option = resolved
|
||||
self.assertEqual(option.value, "继续下载")
|
||||
|
||||
def test_agent_choice_callback_routes_selected_value_back_to_agent(self):
|
||||
def test_agent_interaction_callback_routes_selected_value_back_to_agent(self):
|
||||
chain = MessageChain()
|
||||
request = agent_user_choice_manager.create_request(
|
||||
request = agent_interaction_manager.create_request(
|
||||
session_id="session-choice",
|
||||
user_id="10001",
|
||||
channel=MessageChannel.Telegram.value,
|
||||
source="telegram-test",
|
||||
username="tester",
|
||||
title="需要你的选择",
|
||||
prompt="请选择",
|
||||
options=[
|
||||
AgentChoiceOption(label="电影", value="我选择电影"),
|
||||
AgentChoiceOption(label="电视剧", value="我选择电视剧"),
|
||||
AgentInteractionOption(label="电影", value="我选择电影"),
|
||||
AgentInteractionOption(label="电视剧", value="我选择电视剧"),
|
||||
],
|
||||
)
|
||||
|
||||
with patch.object(chain, "_handle_ai_message") as handle_ai_message, patch.object(
|
||||
chain.messagehelper, "put"
|
||||
) as message_put, patch.object(chain.messageoper, "add") as message_add:
|
||||
) as message_put, patch.object(chain.messageoper, "add") as message_add, patch.object(
|
||||
chain, "edit_message", return_value=True
|
||||
) as edit_message:
|
||||
chain._handle_callback(
|
||||
text=f"CALLBACK:agent_interaction:choice:{request.request_id}:1",
|
||||
channel=MessageChannel.Telegram,
|
||||
source="telegram-test",
|
||||
userid="10001",
|
||||
username="tester",
|
||||
original_message_id=123,
|
||||
original_chat_id="456",
|
||||
)
|
||||
|
||||
handle_ai_message.assert_called_once()
|
||||
edit_message.assert_called_once_with(
|
||||
channel=MessageChannel.Telegram,
|
||||
source="telegram-test",
|
||||
message_id=123,
|
||||
chat_id="456",
|
||||
title="需要你的选择",
|
||||
text="请选择\n\n已选择:电影",
|
||||
)
|
||||
kwargs = handle_ai_message.call_args.kwargs
|
||||
self.assertEqual(kwargs["text"], "我选择电影")
|
||||
self.assertEqual(kwargs["session_id"], "session-choice")
|
||||
message_put.assert_called_once()
|
||||
message_add.assert_called_once()
|
||||
|
||||
def test_legacy_agent_choice_callback_still_supported(self):
|
||||
chain = MessageChain()
|
||||
request = agent_interaction_manager.create_request(
|
||||
session_id="session-choice",
|
||||
user_id="10001",
|
||||
channel=MessageChannel.Telegram.value,
|
||||
source="telegram-test",
|
||||
username="tester",
|
||||
title=None,
|
||||
prompt="请选择",
|
||||
options=[AgentInteractionOption(label="电影", value="我选择电影")],
|
||||
)
|
||||
|
||||
with patch.object(chain, "_handle_ai_message") as handle_ai_message, patch.object(
|
||||
chain.messagehelper, "put"
|
||||
), patch.object(chain.messageoper, "add"):
|
||||
chain._handle_callback(
|
||||
text=f"CALLBACK:agent_choice:{request.request_id}:1",
|
||||
channel=MessageChannel.Telegram,
|
||||
@@ -115,11 +164,6 @@ class TestAgentUserChoice(unittest.TestCase):
|
||||
)
|
||||
|
||||
handle_ai_message.assert_called_once()
|
||||
kwargs = handle_ai_message.call_args.kwargs
|
||||
self.assertEqual(kwargs["text"], "我选择电影")
|
||||
self.assertEqual(kwargs["session_id"], "session-choice")
|
||||
message_put.assert_called_once()
|
||||
message_add.assert_called_once()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
Reference in New Issue
Block a user