Refine agent background reply handling

This commit is contained in:
jxxghp
2026-04-30 00:25:23 +08:00
parent 11478faff3
commit 6532c60a3c
8 changed files with 128 additions and 50 deletions

View File

@@ -5,6 +5,7 @@ import traceback
import uuid
from dataclasses import dataclass
from datetime import datetime
from enum import Enum
from typing import Any, Callable, Dict, List, Optional
from langchain.agents import create_agent
@@ -150,6 +151,15 @@ class _ThinkTagStripper:
self.buffer = ""
class ReplyMode(str, Enum):
"""
Agent 最终回复处理模式。
"""
DISPATCH = "dispatch"
CAPTURE_ONLY = "capture_only"
class MoviePilotAgent:
"""
MoviePilot AI智能体基于 LangChain v1 + LangGraph
@@ -171,7 +181,7 @@ class MoviePilotAgent:
self._tool_context: Dict[str, object] = {}
self.output_callback: Optional[Callable[[str], None]] = None
self.force_streaming = False
self.suppress_user_reply = False
self.reply_mode = ReplyMode.DISPATCH
self.persist_output_message = True
self.allow_message_tools = True
self._streamed_output = ""
@@ -268,6 +278,13 @@ class MoviePilotAgent:
"""
return not self.channel or not self.source
@property
def should_dispatch_reply(self) -> bool:
"""
是否应将最终回复真正发送到消息渠道。
"""
return self.reply_mode == ReplyMode.DISPATCH
def _should_stream(self) -> bool:
"""
判断是否应启用流式输出:
@@ -490,7 +507,7 @@ class MoviePilotAgent:
except Exception as e:
error_message = f"处理消息时发生错误: {str(e)}"
logger.error(error_message)
if self.suppress_user_reply:
if not self.should_dispatch_reply:
raise
await self.send_agent_message(error_message)
return error_message
@@ -543,7 +560,7 @@ class MoviePilotAgent:
"""
调用 LangGraph Agent 执行推理。
根据运行环境选择不同的执行模式:
- 后台任务模式(无渠道信息):非流式 LLM + ainvoke仅广播最终结果
- 后台任务模式(无渠道信息):非流式 LLM + ainvoke由 reply_mode 决定是发送还是仅捕获
- 渠道不支持消息编辑:非流式 LLM + ainvoke完成后发送最终回复
- 渠道支持消息编辑:流式 LLM + astream实时推送 token
"""
@@ -602,10 +619,20 @@ class MoviePilotAgent:
self._emit_output(unsent_text)
if (
remaining_text
and not self.suppress_user_reply
and self.should_dispatch_reply
and not self._tool_context.get("user_reply_sent")
):
await self.send_agent_message(remaining_text)
elif (
remaining_text
and self.persist_output_message
and not self._tool_context.get("user_reply_sent")
):
title = "MoviePilot助手" if self.is_background else ""
await self._save_agent_message_to_db(
remaining_text,
title=title,
)
elif streamed_text and self.persist_output_message:
# 流式输出已发送全部内容,但未记录到数据库,补充保存消息记录
await self._save_agent_message_to_db(streamed_text)
@@ -639,11 +666,11 @@ class MoviePilotAgent:
if (
final_text
and not self.suppress_user_reply
and self.should_dispatch_reply
and not self._tool_context.get("user_reply_sent")
):
if self.is_background:
# 后台任务仅广播最终回复带标题
# 后台任务发送最终回复时统一带标题
await self.send_agent_message(
final_text, title="MoviePilot助手"
)
@@ -732,6 +759,7 @@ class _MessageTask:
channel: Optional[str] = None
source: Optional[str] = None
username: Optional[str] = None
reply_mode: ReplyMode = ReplyMode.DISPATCH
class AgentManager:
@@ -815,6 +843,7 @@ class AgentManager:
channel: str = None,
source: str = None,
username: str = None,
reply_mode: ReplyMode = ReplyMode.DISPATCH,
) -> str:
"""
处理用户消息:将消息放入会话队列,按顺序依次处理。
@@ -829,6 +858,7 @@ class AgentManager:
channel=channel,
source=source,
username=username,
reply_mode=reply_mode,
)
# 获取或创建会话队列
@@ -916,6 +946,7 @@ class AgentManager:
source=task.source,
username=task.username,
)
agent.reply_mode = task.reply_mode
self.active_agents[session_id] = agent
else:
agent = self.active_agents[session_id]
@@ -926,6 +957,7 @@ class AgentManager:
agent.source = task.source
if task.username:
agent.username = task.username
agent.reply_mode = task.reply_mode
return await agent.process(task.message, images=task.images, files=task.files)
@@ -992,10 +1024,10 @@ class AgentManager:
@staticmethod
async def run_background_prompt(
message: str,
message: str,
session_prefix: str = "__agent_background",
output_callback: Optional[Callable[[str], None]] = None,
suppress_user_reply: bool = False,
reply_mode: ReplyMode = ReplyMode.CAPTURE_ONLY,
persist_output_message: bool = True,
allow_message_tools: bool = True,
) -> None:
@@ -1013,7 +1045,7 @@ class AgentManager:
)
agent.output_callback = output_callback
agent.force_streaming = bool(output_callback)
agent.suppress_user_reply = suppress_user_reply
agent.reply_mode = reply_mode
agent.persist_output_message = persist_output_message
agent.allow_message_tools = allow_message_tools
@@ -1048,6 +1080,7 @@ class AgentManager:
channel=None,
source=None,
username=settings.SUPERUSER,
reply_mode=ReplyMode.DISPATCH,
)
# 等待消息队列处理完成

View File

@@ -1,7 +1,7 @@
version: 2
shared_rules:
- This is a background system task, NOT a user conversation.
- Your final response will be broadcast as a notification.
- Your final response will be consumed by the system. Keep it concise and task-focused.
- Do NOT include greetings, explanations, or conversational text.
- Respond in Chinese (中文).
task_types:

View File

@@ -9,7 +9,7 @@ from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.orm import Session
from app import schemas
from app.agent import prompt_manager, agent_manager
from app.agent import ReplyMode, prompt_manager, agent_manager
from app.chain.storage import StorageChain
from app.core.config import settings, global_vars
from app.core.event import eventmanager
@@ -130,7 +130,7 @@ def _start_ai_redo_task(history_id: int, prompt: str, progress_key: str):
message=prompt,
session_prefix=f"__agent_manual_redo_{history_id}",
output_callback=update_output,
suppress_user_reply=True,
reply_mode=ReplyMode.CAPTURE_ONLY,
persist_output_message=False,
allow_message_tools=False,
)
@@ -176,7 +176,7 @@ def _start_batch_ai_redo_task(
message=prompt,
session_prefix="__agent_manual_redo_batch",
output_callback=update_output,
suppress_user_reply=True,
reply_mode=ReplyMode.CAPTURE_ONLY,
persist_output_message=False,
allow_message_tools=False,
)

View File

@@ -9,7 +9,7 @@ from pathlib import Path
from typing import Any, Optional, Dict, Union, List
from urllib.parse import unquote, urlparse
from app.agent import agent_manager, prompt_manager
from app.agent import ReplyMode, agent_manager, prompt_manager
from app.chain import ChainBase
from app.chain.interaction import (
MediaInteractionChain,
@@ -635,7 +635,7 @@ class MessageChain(ChainBase):
message=redo_prompt,
session_prefix=f"__agent_manual_redo_{history_id}",
output_callback=_capture_output,
suppress_user_reply=True,
reply_mode=ReplyMode.CAPTURE_ONLY,
)
await self.async_post_message(
Notification(

View File

@@ -210,7 +210,7 @@ class SearchChain(ChainBase):
"""
通过统一后台提示词机制执行资源推荐。
"""
from app.agent import agent_manager
from app.agent import ReplyMode, agent_manager
from app.agent.prompt import prompt_manager
prompt = prompt_manager.render_system_task_message(
@@ -226,7 +226,7 @@ class SearchChain(ChainBase):
message=prompt,
session_prefix="__agent_search_recommend",
output_callback=on_output,
suppress_user_reply=True,
reply_mode=ReplyMode.CAPTURE_ONLY,
persist_output_message=False,
allow_message_tools=False,
)

View File

@@ -8,7 +8,7 @@ from pathlib import Path
from typing import List, Optional, Tuple, Union, Dict, Callable
from app import schemas
from app.agent import prompt_manager, agent_manager
from app.agent import ReplyMode, prompt_manager, agent_manager
from app.chain import ChainBase
from app.chain.media import MediaChain
from app.chain.storage import StorageChain
@@ -688,6 +688,7 @@ class FailedRetryScheduler:
await agent_manager.run_background_prompt(
message=self._build_retry_transfer_prompt(history_ids),
session_prefix="__agent_retry_transfer_batch",
reply_mode=ReplyMode.DISPATCH,
)
logger.info(
f"智能体重试整理:批量处理完成 IDs=[{ids_str}] (group={group_key})"

View File

@@ -4,7 +4,7 @@ from unittest.mock import AsyncMock, patch
from langchain_core.messages import AIMessage
from app.agent import MoviePilotAgent
from app.agent import MoviePilotAgent, AgentManager, ReplyMode
from app.agent.memory import memory_manager
@@ -25,39 +25,11 @@ class _FakeAgent:
class AgentBackgroundOutputTest(unittest.IsolatedAsyncioTestCase):
async def test_background_non_streaming_still_sends_when_reply_not_suppressed(self):
async def test_background_non_streaming_does_not_send_by_default(self):
agent = MoviePilotAgent(session_id="bg-test", user_id="system")
agent.channel = None
agent.source = None
agent.suppress_user_reply = False
agent.persist_output_message = False
agent._tool_context = {"user_reply_sent": False}
agent._streamed_output = ""
agent.stream_handler = SimpleNamespace(
stop_streaming=AsyncMock(return_value=(False, ""))
)
agent._should_stream = lambda: False
agent._create_agent = lambda streaming=False: _FakeAgent(
[AIMessage(content="后台结果")]
)
agent.send_agent_message = AsyncMock()
agent._save_agent_message_to_db = AsyncMock()
with patch.object(memory_manager, "save_agent_messages") as save_messages:
await agent._execute_agent([])
agent.send_agent_message.assert_awaited_once_with(
"后台结果", title="MoviePilot助手"
)
agent._save_agent_message_to_db.assert_not_awaited()
save_messages.assert_called_once()
self.assertEqual("后台结果", agent._streamed_output)
async def test_background_non_streaming_persists_without_sending_when_reply_suppressed(self):
agent = MoviePilotAgent(session_id="bg-test", user_id="system")
agent.channel = None
agent.source = None
agent.suppress_user_reply = True
agent.reply_mode = ReplyMode.CAPTURE_ONLY
agent.persist_output_message = True
agent._tool_context = {"user_reply_sent": False}
agent._streamed_output = ""
@@ -81,6 +53,77 @@ class AgentBackgroundOutputTest(unittest.IsolatedAsyncioTestCase):
save_messages.assert_called_once()
self.assertEqual("后台结果", agent._streamed_output)
async def test_background_non_streaming_sends_when_reply_mode_dispatch(self):
agent = MoviePilotAgent(session_id="bg-test", user_id="system")
agent.channel = None
agent.source = None
agent.reply_mode = ReplyMode.DISPATCH
agent.persist_output_message = False
agent._tool_context = {"user_reply_sent": False}
agent._streamed_output = ""
agent.stream_handler = SimpleNamespace(
stop_streaming=AsyncMock(return_value=(False, ""))
)
agent._should_stream = lambda: False
agent._create_agent = lambda streaming=False: _FakeAgent(
[AIMessage(content="后台结果")]
)
agent.send_agent_message = AsyncMock()
agent._save_agent_message_to_db = AsyncMock()
with patch.object(memory_manager, "save_agent_messages") as save_messages:
await agent._execute_agent([])
agent.send_agent_message.assert_awaited_once_with(
"后台结果", title="MoviePilot助手"
)
agent._save_agent_message_to_db.assert_not_awaited()
save_messages.assert_called_once()
self.assertEqual("后台结果", agent._streamed_output)
async def test_background_non_streaming_persists_without_sending_when_capture_only(self):
agent = MoviePilotAgent(session_id="bg-test", user_id="system")
agent.channel = None
agent.source = None
agent.reply_mode = ReplyMode.CAPTURE_ONLY
agent.persist_output_message = True
agent._tool_context = {"user_reply_sent": False}
agent._streamed_output = ""
agent.stream_handler = SimpleNamespace(
stop_streaming=AsyncMock(return_value=(False, ""))
)
agent._should_stream = lambda: False
agent._create_agent = lambda streaming=False: _FakeAgent(
[AIMessage(content="后台结果")]
)
agent.send_agent_message = AsyncMock()
agent._save_agent_message_to_db = AsyncMock()
with patch.object(memory_manager, "save_agent_messages") as save_messages:
await agent._execute_agent([])
agent.send_agent_message.assert_not_awaited()
agent._save_agent_message_to_db.assert_awaited_once_with(
"后台结果", title="MoviePilot助手"
)
save_messages.assert_called_once()
self.assertEqual("后台结果", agent._streamed_output)
async def test_heartbeat_check_jobs_uses_dispatch_reply_mode(self):
manager = AgentManager()
with (
patch.object(manager, "_build_heartbeat_prompt", return_value="HEARTBEAT"),
patch.object(manager, "process_message", new=AsyncMock()) as process_message,
):
await manager.heartbeat_check_jobs()
process_message.assert_awaited_once()
self.assertEqual(
ReplyMode.DISPATCH,
process_message.await_args.kwargs["reply_mode"],
)
if __name__ == "__main__":
unittest.main()

View File

@@ -20,6 +20,7 @@ _stub_module("qbittorrentapi", TorrentFilesList=list)
_stub_module("transmission_rpc", File=object)
from app.agent.tools.factory import MoviePilotToolFactory
from app.agent import ReplyMode
from app.chain.search import SearchChain
from app.core.config import settings
@@ -125,7 +126,7 @@ class SearchChainAIRecommendTest(unittest.IsolatedAsyncioTestCase):
result = await chain._invoke_recommend_llm("Candidates")
self.assertEqual("[0, 2]", result)
self.assertTrue(captured["suppress_user_reply"])
self.assertEqual(ReplyMode.CAPTURE_ONLY, captured["reply_mode"])
self.assertFalse(captured["persist_output_message"])
self.assertFalse(captured["allow_message_tools"])