mirror of
https://github.com/jxxghp/MoviePilot.git
synced 2026-06-25 17:54:43 +08:00
Refine agent background reply handling
This commit is contained in:
@@ -5,6 +5,7 @@ import traceback
|
||||
import uuid
|
||||
from dataclasses import dataclass
|
||||
from datetime import datetime
|
||||
from enum import Enum
|
||||
from typing import Any, Callable, Dict, List, Optional
|
||||
|
||||
from langchain.agents import create_agent
|
||||
@@ -150,6 +151,15 @@ class _ThinkTagStripper:
|
||||
self.buffer = ""
|
||||
|
||||
|
||||
class ReplyMode(str, Enum):
|
||||
"""
|
||||
Agent 最终回复处理模式。
|
||||
"""
|
||||
|
||||
DISPATCH = "dispatch"
|
||||
CAPTURE_ONLY = "capture_only"
|
||||
|
||||
|
||||
class MoviePilotAgent:
|
||||
"""
|
||||
MoviePilot AI智能体(基于 LangChain v1 + LangGraph)
|
||||
@@ -171,7 +181,7 @@ class MoviePilotAgent:
|
||||
self._tool_context: Dict[str, object] = {}
|
||||
self.output_callback: Optional[Callable[[str], None]] = None
|
||||
self.force_streaming = False
|
||||
self.suppress_user_reply = False
|
||||
self.reply_mode = ReplyMode.DISPATCH
|
||||
self.persist_output_message = True
|
||||
self.allow_message_tools = True
|
||||
self._streamed_output = ""
|
||||
@@ -268,6 +278,13 @@ class MoviePilotAgent:
|
||||
"""
|
||||
return not self.channel or not self.source
|
||||
|
||||
@property
|
||||
def should_dispatch_reply(self) -> bool:
|
||||
"""
|
||||
是否应将最终回复真正发送到消息渠道。
|
||||
"""
|
||||
return self.reply_mode == ReplyMode.DISPATCH
|
||||
|
||||
def _should_stream(self) -> bool:
|
||||
"""
|
||||
判断是否应启用流式输出:
|
||||
@@ -490,7 +507,7 @@ class MoviePilotAgent:
|
||||
except Exception as e:
|
||||
error_message = f"处理消息时发生错误: {str(e)}"
|
||||
logger.error(error_message)
|
||||
if self.suppress_user_reply:
|
||||
if not self.should_dispatch_reply:
|
||||
raise
|
||||
await self.send_agent_message(error_message)
|
||||
return error_message
|
||||
@@ -543,7 +560,7 @@ class MoviePilotAgent:
|
||||
"""
|
||||
调用 LangGraph Agent 执行推理。
|
||||
根据运行环境选择不同的执行模式:
|
||||
- 后台任务模式(无渠道信息):非流式 LLM + ainvoke,仅广播最终结果
|
||||
- 后台任务模式(无渠道信息):非流式 LLM + ainvoke,由 reply_mode 决定是发送还是仅捕获
|
||||
- 渠道不支持消息编辑:非流式 LLM + ainvoke,完成后发送最终回复
|
||||
- 渠道支持消息编辑:流式 LLM + astream,实时推送 token
|
||||
"""
|
||||
@@ -602,10 +619,20 @@ class MoviePilotAgent:
|
||||
self._emit_output(unsent_text)
|
||||
if (
|
||||
remaining_text
|
||||
and not self.suppress_user_reply
|
||||
and self.should_dispatch_reply
|
||||
and not self._tool_context.get("user_reply_sent")
|
||||
):
|
||||
await self.send_agent_message(remaining_text)
|
||||
elif (
|
||||
remaining_text
|
||||
and self.persist_output_message
|
||||
and not self._tool_context.get("user_reply_sent")
|
||||
):
|
||||
title = "MoviePilot助手" if self.is_background else ""
|
||||
await self._save_agent_message_to_db(
|
||||
remaining_text,
|
||||
title=title,
|
||||
)
|
||||
elif streamed_text and self.persist_output_message:
|
||||
# 流式输出已发送全部内容,但未记录到数据库,补充保存消息记录
|
||||
await self._save_agent_message_to_db(streamed_text)
|
||||
@@ -639,11 +666,11 @@ class MoviePilotAgent:
|
||||
|
||||
if (
|
||||
final_text
|
||||
and not self.suppress_user_reply
|
||||
and self.should_dispatch_reply
|
||||
and not self._tool_context.get("user_reply_sent")
|
||||
):
|
||||
if self.is_background:
|
||||
# 后台任务仅广播最终回复,带标题
|
||||
# 后台任务发送最终回复时统一带标题
|
||||
await self.send_agent_message(
|
||||
final_text, title="MoviePilot助手"
|
||||
)
|
||||
@@ -732,6 +759,7 @@ class _MessageTask:
|
||||
channel: Optional[str] = None
|
||||
source: Optional[str] = None
|
||||
username: Optional[str] = None
|
||||
reply_mode: ReplyMode = ReplyMode.DISPATCH
|
||||
|
||||
|
||||
class AgentManager:
|
||||
@@ -815,6 +843,7 @@ class AgentManager:
|
||||
channel: str = None,
|
||||
source: str = None,
|
||||
username: str = None,
|
||||
reply_mode: ReplyMode = ReplyMode.DISPATCH,
|
||||
) -> str:
|
||||
"""
|
||||
处理用户消息:将消息放入会话队列,按顺序依次处理。
|
||||
@@ -829,6 +858,7 @@ class AgentManager:
|
||||
channel=channel,
|
||||
source=source,
|
||||
username=username,
|
||||
reply_mode=reply_mode,
|
||||
)
|
||||
|
||||
# 获取或创建会话队列
|
||||
@@ -916,6 +946,7 @@ class AgentManager:
|
||||
source=task.source,
|
||||
username=task.username,
|
||||
)
|
||||
agent.reply_mode = task.reply_mode
|
||||
self.active_agents[session_id] = agent
|
||||
else:
|
||||
agent = self.active_agents[session_id]
|
||||
@@ -926,6 +957,7 @@ class AgentManager:
|
||||
agent.source = task.source
|
||||
if task.username:
|
||||
agent.username = task.username
|
||||
agent.reply_mode = task.reply_mode
|
||||
|
||||
return await agent.process(task.message, images=task.images, files=task.files)
|
||||
|
||||
@@ -992,10 +1024,10 @@ class AgentManager:
|
||||
|
||||
@staticmethod
|
||||
async def run_background_prompt(
|
||||
message: str,
|
||||
message: str,
|
||||
session_prefix: str = "__agent_background",
|
||||
output_callback: Optional[Callable[[str], None]] = None,
|
||||
suppress_user_reply: bool = False,
|
||||
reply_mode: ReplyMode = ReplyMode.CAPTURE_ONLY,
|
||||
persist_output_message: bool = True,
|
||||
allow_message_tools: bool = True,
|
||||
) -> None:
|
||||
@@ -1013,7 +1045,7 @@ class AgentManager:
|
||||
)
|
||||
agent.output_callback = output_callback
|
||||
agent.force_streaming = bool(output_callback)
|
||||
agent.suppress_user_reply = suppress_user_reply
|
||||
agent.reply_mode = reply_mode
|
||||
agent.persist_output_message = persist_output_message
|
||||
agent.allow_message_tools = allow_message_tools
|
||||
|
||||
@@ -1048,6 +1080,7 @@ class AgentManager:
|
||||
channel=None,
|
||||
source=None,
|
||||
username=settings.SUPERUSER,
|
||||
reply_mode=ReplyMode.DISPATCH,
|
||||
)
|
||||
|
||||
# 等待消息队列处理完成
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
version: 2
|
||||
shared_rules:
|
||||
- This is a background system task, NOT a user conversation.
|
||||
- Your final response will be broadcast as a notification.
|
||||
- Your final response will be consumed by the system. Keep it concise and task-focused.
|
||||
- Do NOT include greetings, explanations, or conversational text.
|
||||
- Respond in Chinese (中文).
|
||||
task_types:
|
||||
|
||||
@@ -9,7 +9,7 @@ from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from app import schemas
|
||||
from app.agent import prompt_manager, agent_manager
|
||||
from app.agent import ReplyMode, prompt_manager, agent_manager
|
||||
from app.chain.storage import StorageChain
|
||||
from app.core.config import settings, global_vars
|
||||
from app.core.event import eventmanager
|
||||
@@ -130,7 +130,7 @@ def _start_ai_redo_task(history_id: int, prompt: str, progress_key: str):
|
||||
message=prompt,
|
||||
session_prefix=f"__agent_manual_redo_{history_id}",
|
||||
output_callback=update_output,
|
||||
suppress_user_reply=True,
|
||||
reply_mode=ReplyMode.CAPTURE_ONLY,
|
||||
persist_output_message=False,
|
||||
allow_message_tools=False,
|
||||
)
|
||||
@@ -176,7 +176,7 @@ def _start_batch_ai_redo_task(
|
||||
message=prompt,
|
||||
session_prefix="__agent_manual_redo_batch",
|
||||
output_callback=update_output,
|
||||
suppress_user_reply=True,
|
||||
reply_mode=ReplyMode.CAPTURE_ONLY,
|
||||
persist_output_message=False,
|
||||
allow_message_tools=False,
|
||||
)
|
||||
|
||||
@@ -9,7 +9,7 @@ from pathlib import Path
|
||||
from typing import Any, Optional, Dict, Union, List
|
||||
from urllib.parse import unquote, urlparse
|
||||
|
||||
from app.agent import agent_manager, prompt_manager
|
||||
from app.agent import ReplyMode, agent_manager, prompt_manager
|
||||
from app.chain import ChainBase
|
||||
from app.chain.interaction import (
|
||||
MediaInteractionChain,
|
||||
@@ -635,7 +635,7 @@ class MessageChain(ChainBase):
|
||||
message=redo_prompt,
|
||||
session_prefix=f"__agent_manual_redo_{history_id}",
|
||||
output_callback=_capture_output,
|
||||
suppress_user_reply=True,
|
||||
reply_mode=ReplyMode.CAPTURE_ONLY,
|
||||
)
|
||||
await self.async_post_message(
|
||||
Notification(
|
||||
|
||||
@@ -210,7 +210,7 @@ class SearchChain(ChainBase):
|
||||
"""
|
||||
通过统一后台提示词机制执行资源推荐。
|
||||
"""
|
||||
from app.agent import agent_manager
|
||||
from app.agent import ReplyMode, agent_manager
|
||||
from app.agent.prompt import prompt_manager
|
||||
|
||||
prompt = prompt_manager.render_system_task_message(
|
||||
@@ -226,7 +226,7 @@ class SearchChain(ChainBase):
|
||||
message=prompt,
|
||||
session_prefix="__agent_search_recommend",
|
||||
output_callback=on_output,
|
||||
suppress_user_reply=True,
|
||||
reply_mode=ReplyMode.CAPTURE_ONLY,
|
||||
persist_output_message=False,
|
||||
allow_message_tools=False,
|
||||
)
|
||||
|
||||
@@ -8,7 +8,7 @@ from pathlib import Path
|
||||
from typing import List, Optional, Tuple, Union, Dict, Callable
|
||||
|
||||
from app import schemas
|
||||
from app.agent import prompt_manager, agent_manager
|
||||
from app.agent import ReplyMode, prompt_manager, agent_manager
|
||||
from app.chain import ChainBase
|
||||
from app.chain.media import MediaChain
|
||||
from app.chain.storage import StorageChain
|
||||
@@ -688,6 +688,7 @@ class FailedRetryScheduler:
|
||||
await agent_manager.run_background_prompt(
|
||||
message=self._build_retry_transfer_prompt(history_ids),
|
||||
session_prefix="__agent_retry_transfer_batch",
|
||||
reply_mode=ReplyMode.DISPATCH,
|
||||
)
|
||||
logger.info(
|
||||
f"智能体重试整理:批量处理完成 IDs=[{ids_str}] (group={group_key})"
|
||||
|
||||
@@ -4,7 +4,7 @@ from unittest.mock import AsyncMock, patch
|
||||
|
||||
from langchain_core.messages import AIMessage
|
||||
|
||||
from app.agent import MoviePilotAgent
|
||||
from app.agent import MoviePilotAgent, AgentManager, ReplyMode
|
||||
from app.agent.memory import memory_manager
|
||||
|
||||
|
||||
@@ -25,39 +25,11 @@ class _FakeAgent:
|
||||
|
||||
|
||||
class AgentBackgroundOutputTest(unittest.IsolatedAsyncioTestCase):
|
||||
async def test_background_non_streaming_still_sends_when_reply_not_suppressed(self):
|
||||
async def test_background_non_streaming_does_not_send_by_default(self):
|
||||
agent = MoviePilotAgent(session_id="bg-test", user_id="system")
|
||||
agent.channel = None
|
||||
agent.source = None
|
||||
agent.suppress_user_reply = False
|
||||
agent.persist_output_message = False
|
||||
agent._tool_context = {"user_reply_sent": False}
|
||||
agent._streamed_output = ""
|
||||
agent.stream_handler = SimpleNamespace(
|
||||
stop_streaming=AsyncMock(return_value=(False, ""))
|
||||
)
|
||||
agent._should_stream = lambda: False
|
||||
agent._create_agent = lambda streaming=False: _FakeAgent(
|
||||
[AIMessage(content="后台结果")]
|
||||
)
|
||||
agent.send_agent_message = AsyncMock()
|
||||
agent._save_agent_message_to_db = AsyncMock()
|
||||
|
||||
with patch.object(memory_manager, "save_agent_messages") as save_messages:
|
||||
await agent._execute_agent([])
|
||||
|
||||
agent.send_agent_message.assert_awaited_once_with(
|
||||
"后台结果", title="MoviePilot助手"
|
||||
)
|
||||
agent._save_agent_message_to_db.assert_not_awaited()
|
||||
save_messages.assert_called_once()
|
||||
self.assertEqual("后台结果", agent._streamed_output)
|
||||
|
||||
async def test_background_non_streaming_persists_without_sending_when_reply_suppressed(self):
|
||||
agent = MoviePilotAgent(session_id="bg-test", user_id="system")
|
||||
agent.channel = None
|
||||
agent.source = None
|
||||
agent.suppress_user_reply = True
|
||||
agent.reply_mode = ReplyMode.CAPTURE_ONLY
|
||||
agent.persist_output_message = True
|
||||
agent._tool_context = {"user_reply_sent": False}
|
||||
agent._streamed_output = ""
|
||||
@@ -81,6 +53,77 @@ class AgentBackgroundOutputTest(unittest.IsolatedAsyncioTestCase):
|
||||
save_messages.assert_called_once()
|
||||
self.assertEqual("后台结果", agent._streamed_output)
|
||||
|
||||
async def test_background_non_streaming_sends_when_reply_mode_dispatch(self):
|
||||
agent = MoviePilotAgent(session_id="bg-test", user_id="system")
|
||||
agent.channel = None
|
||||
agent.source = None
|
||||
agent.reply_mode = ReplyMode.DISPATCH
|
||||
agent.persist_output_message = False
|
||||
agent._tool_context = {"user_reply_sent": False}
|
||||
agent._streamed_output = ""
|
||||
agent.stream_handler = SimpleNamespace(
|
||||
stop_streaming=AsyncMock(return_value=(False, ""))
|
||||
)
|
||||
agent._should_stream = lambda: False
|
||||
agent._create_agent = lambda streaming=False: _FakeAgent(
|
||||
[AIMessage(content="后台结果")]
|
||||
)
|
||||
agent.send_agent_message = AsyncMock()
|
||||
agent._save_agent_message_to_db = AsyncMock()
|
||||
|
||||
with patch.object(memory_manager, "save_agent_messages") as save_messages:
|
||||
await agent._execute_agent([])
|
||||
|
||||
agent.send_agent_message.assert_awaited_once_with(
|
||||
"后台结果", title="MoviePilot助手"
|
||||
)
|
||||
agent._save_agent_message_to_db.assert_not_awaited()
|
||||
save_messages.assert_called_once()
|
||||
self.assertEqual("后台结果", agent._streamed_output)
|
||||
|
||||
async def test_background_non_streaming_persists_without_sending_when_capture_only(self):
|
||||
agent = MoviePilotAgent(session_id="bg-test", user_id="system")
|
||||
agent.channel = None
|
||||
agent.source = None
|
||||
agent.reply_mode = ReplyMode.CAPTURE_ONLY
|
||||
agent.persist_output_message = True
|
||||
agent._tool_context = {"user_reply_sent": False}
|
||||
agent._streamed_output = ""
|
||||
agent.stream_handler = SimpleNamespace(
|
||||
stop_streaming=AsyncMock(return_value=(False, ""))
|
||||
)
|
||||
agent._should_stream = lambda: False
|
||||
agent._create_agent = lambda streaming=False: _FakeAgent(
|
||||
[AIMessage(content="后台结果")]
|
||||
)
|
||||
agent.send_agent_message = AsyncMock()
|
||||
agent._save_agent_message_to_db = AsyncMock()
|
||||
|
||||
with patch.object(memory_manager, "save_agent_messages") as save_messages:
|
||||
await agent._execute_agent([])
|
||||
|
||||
agent.send_agent_message.assert_not_awaited()
|
||||
agent._save_agent_message_to_db.assert_awaited_once_with(
|
||||
"后台结果", title="MoviePilot助手"
|
||||
)
|
||||
save_messages.assert_called_once()
|
||||
self.assertEqual("后台结果", agent._streamed_output)
|
||||
|
||||
async def test_heartbeat_check_jobs_uses_dispatch_reply_mode(self):
|
||||
manager = AgentManager()
|
||||
|
||||
with (
|
||||
patch.object(manager, "_build_heartbeat_prompt", return_value="HEARTBEAT"),
|
||||
patch.object(manager, "process_message", new=AsyncMock()) as process_message,
|
||||
):
|
||||
await manager.heartbeat_check_jobs()
|
||||
|
||||
process_message.assert_awaited_once()
|
||||
self.assertEqual(
|
||||
ReplyMode.DISPATCH,
|
||||
process_message.await_args.kwargs["reply_mode"],
|
||||
)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
|
||||
@@ -20,6 +20,7 @@ _stub_module("qbittorrentapi", TorrentFilesList=list)
|
||||
_stub_module("transmission_rpc", File=object)
|
||||
|
||||
from app.agent.tools.factory import MoviePilotToolFactory
|
||||
from app.agent import ReplyMode
|
||||
from app.chain.search import SearchChain
|
||||
from app.core.config import settings
|
||||
|
||||
@@ -125,7 +126,7 @@ class SearchChainAIRecommendTest(unittest.IsolatedAsyncioTestCase):
|
||||
result = await chain._invoke_recommend_llm("Candidates")
|
||||
|
||||
self.assertEqual("[0, 2]", result)
|
||||
self.assertTrue(captured["suppress_user_reply"])
|
||||
self.assertEqual(ReplyMode.CAPTURE_ONLY, captured["reply_mode"])
|
||||
self.assertFalse(captured["persist_output_message"])
|
||||
self.assertFalse(captured["allow_message_tools"])
|
||||
|
||||
|
||||
Reference in New Issue
Block a user