mirror of
https://github.com/jxxghp/MoviePilot.git
synced 2026-05-06 20:42:43 +08:00
suppress channel notifications for ui background tasks
This commit is contained in:
@@ -3,6 +3,7 @@ import json
|
||||
import re
|
||||
import traceback
|
||||
import uuid
|
||||
from contextlib import nullcontext
|
||||
from dataclasses import dataclass
|
||||
from datetime import datetime
|
||||
from typing import Any, Callable, Dict, List, Optional
|
||||
@@ -32,6 +33,7 @@ from app.agent.runtime import agent_runtime_manager
|
||||
from app.agent.tools.factory import MoviePilotToolFactory
|
||||
from app.chain import ChainBase
|
||||
from app.core.config import settings
|
||||
from app.core.message_context import suppress_message_channel
|
||||
from app.helper.llm import LLMHelper
|
||||
from app.log import logger
|
||||
from app.schemas import Notification, NotificationType
|
||||
@@ -172,6 +174,7 @@ class MoviePilotAgent:
|
||||
self.output_callback: Optional[Callable[[str], None]] = None
|
||||
self.force_streaming = False
|
||||
self.suppress_user_reply = False
|
||||
self.persist_output_message = True
|
||||
self._streamed_output = ""
|
||||
self._session_usage = _SessionUsageSnapshot()
|
||||
|
||||
@@ -603,7 +606,7 @@ class MoviePilotAgent:
|
||||
and not self._tool_context.get("user_reply_sent")
|
||||
):
|
||||
await self.send_agent_message(remaining_text)
|
||||
elif streamed_text:
|
||||
elif streamed_text and self.persist_output_message:
|
||||
# 流式输出已发送全部内容,但未记录到数据库,补充保存消息记录
|
||||
await self._save_agent_message_to_db(streamed_text)
|
||||
|
||||
@@ -986,6 +989,8 @@ class AgentManager:
|
||||
session_prefix: str = "__agent_background",
|
||||
output_callback: Optional[Callable[[str], None]] = None,
|
||||
suppress_user_reply: bool = False,
|
||||
persist_output_message: bool = True,
|
||||
suppress_message_channel_dispatch: bool = False,
|
||||
) -> None:
|
||||
"""
|
||||
以独立后台会话执行一段 prompt。
|
||||
@@ -1002,9 +1007,15 @@ class AgentManager:
|
||||
agent.output_callback = output_callback
|
||||
agent.force_streaming = bool(output_callback)
|
||||
agent.suppress_user_reply = suppress_user_reply
|
||||
agent.persist_output_message = persist_output_message
|
||||
|
||||
try:
|
||||
await agent.process(message)
|
||||
with (
|
||||
suppress_message_channel()
|
||||
if suppress_message_channel_dispatch
|
||||
else nullcontext()
|
||||
):
|
||||
await agent.process(message)
|
||||
finally:
|
||||
await agent.cleanup()
|
||||
memory_manager.clear_memory(session_id, user_id)
|
||||
|
||||
@@ -131,6 +131,8 @@ def _start_ai_redo_task(history_id: int, prompt: str, progress_key: str):
|
||||
session_prefix=f"__agent_manual_redo_{history_id}",
|
||||
output_callback=update_output,
|
||||
suppress_user_reply=True,
|
||||
persist_output_message=False,
|
||||
suppress_message_channel_dispatch=True,
|
||||
)
|
||||
progress.update(
|
||||
text="智能助手整理完成",
|
||||
@@ -175,6 +177,8 @@ def _start_batch_ai_redo_task(
|
||||
session_prefix="__agent_manual_redo_batch",
|
||||
output_callback=update_output,
|
||||
suppress_user_reply=True,
|
||||
persist_output_message=False,
|
||||
suppress_message_channel_dispatch=True,
|
||||
)
|
||||
progress.update(
|
||||
text="智能助手批量整理完成",
|
||||
|
||||
@@ -17,6 +17,7 @@ from app.core.config import settings
|
||||
from app.core.context import Context, MediaInfo, TorrentInfo
|
||||
from app.core.event import EventManager
|
||||
from app.core.meta import MetaBase
|
||||
from app.core.message_context import is_message_channel_suppressed
|
||||
from app.core.module import ModuleManager
|
||||
from app.core.plugin import PluginManager
|
||||
from app.db.message_oper import MessageOper
|
||||
@@ -1136,6 +1137,9 @@ class ChainBase(metaclass=ABCMeta):
|
||||
# 保存消息
|
||||
self.messagehelper.put(message, role="user", title=message.title)
|
||||
self.messageoper.add(**message.model_dump())
|
||||
if is_message_channel_suppressed():
|
||||
logger.info("当前上下文已禁用消息渠道派发,仅保存消息记录")
|
||||
return
|
||||
dispatch_message = self._normalize_notification_for_dispatch(message)
|
||||
# 发送消息按设置隔离
|
||||
if not dispatch_message.userid and dispatch_message.mtype:
|
||||
@@ -1253,6 +1257,9 @@ class ChainBase(metaclass=ABCMeta):
|
||||
# 保存消息
|
||||
self.messagehelper.put(message, role="user", title=message.title)
|
||||
await self.messageoper.async_add(**message.model_dump())
|
||||
if is_message_channel_suppressed():
|
||||
logger.info("当前上下文已禁用消息渠道派发,仅保存消息记录")
|
||||
return
|
||||
dispatch_message = self._normalize_notification_for_dispatch(message)
|
||||
# 发送消息按设置隔离
|
||||
if not dispatch_message.userid and dispatch_message.mtype:
|
||||
@@ -1347,6 +1354,9 @@ class ChainBase(metaclass=ABCMeta):
|
||||
message, role="user", note=note_list, title=message.title
|
||||
)
|
||||
self.messageoper.add(**message.model_dump(), note=note_list)
|
||||
if is_message_channel_suppressed():
|
||||
logger.info("当前上下文已禁用消息渠道派发,仅保存媒体消息记录")
|
||||
return None
|
||||
dispatch_message = self._normalize_notification_for_dispatch(message)
|
||||
return self.messagequeue.send_message(
|
||||
"post_medias_message",
|
||||
@@ -1369,6 +1379,9 @@ class ChainBase(metaclass=ABCMeta):
|
||||
message, role="user", note=note_list, title=message.title
|
||||
)
|
||||
self.messageoper.add(**message.model_dump(), note=note_list)
|
||||
if is_message_channel_suppressed():
|
||||
logger.info("当前上下文已禁用消息渠道派发,仅保存种子消息记录")
|
||||
return None
|
||||
dispatch_message = self._normalize_notification_for_dispatch(message)
|
||||
return self.messagequeue.send_message(
|
||||
"post_torrents_message",
|
||||
|
||||
@@ -227,6 +227,8 @@ class SearchChain(ChainBase):
|
||||
session_prefix="__agent_search_recommend",
|
||||
output_callback=on_output,
|
||||
suppress_user_reply=True,
|
||||
persist_output_message=False,
|
||||
suppress_message_channel_dispatch=True,
|
||||
)
|
||||
return full_output[0].strip()
|
||||
|
||||
|
||||
26
app/core/message_context.py
Normal file
26
app/core/message_context.py
Normal file
@@ -0,0 +1,26 @@
|
||||
import contextvars
|
||||
from contextlib import contextmanager
|
||||
from typing import Iterator
|
||||
|
||||
_suppress_message_channel = contextvars.ContextVar(
|
||||
"suppress_message_channel", default=False
|
||||
)
|
||||
|
||||
|
||||
def is_message_channel_suppressed() -> bool:
|
||||
"""
|
||||
当前上下文是否禁止向外部消息渠道派发通知。
|
||||
"""
|
||||
return bool(_suppress_message_channel.get())
|
||||
|
||||
|
||||
@contextmanager
|
||||
def suppress_message_channel() -> Iterator[None]:
|
||||
"""
|
||||
在当前上下文中临时禁用外部消息渠道派发。
|
||||
"""
|
||||
token = _suppress_message_channel.set(True)
|
||||
try:
|
||||
yield
|
||||
finally:
|
||||
_suppress_message_channel.reset(token)
|
||||
@@ -5,6 +5,8 @@ from types import SimpleNamespace
|
||||
from types import ModuleType
|
||||
from unittest.mock import AsyncMock, patch
|
||||
|
||||
import app.chain as chain_module
|
||||
|
||||
|
||||
def _stub_module(name: str, **attrs):
|
||||
module = sys.modules.get(name)
|
||||
@@ -21,6 +23,9 @@ _stub_module("transmission_rpc", File=object)
|
||||
|
||||
from app.chain.search import SearchChain
|
||||
from app.core.config import settings
|
||||
from app.core.message_context import suppress_message_channel
|
||||
from app.schemas import Notification
|
||||
from app.schemas.types import NotificationType
|
||||
|
||||
|
||||
def _make_result(title: str, size: int, seeders: int):
|
||||
@@ -98,6 +103,36 @@ class SearchChainAIRecommendTest(unittest.IsolatedAsyncioTestCase):
|
||||
self.assertFalse(SearchChain._ai_recommend_running)
|
||||
self.assertIsNone(SearchChain._ai_recommend_task)
|
||||
|
||||
async def test_invoke_recommend_llm_disables_output_message_persistence(self):
|
||||
chain = self._make_chain()
|
||||
from app.agent import agent_manager
|
||||
from app.agent.prompt import prompt_manager
|
||||
|
||||
captured = {}
|
||||
|
||||
async def _fake_run_background_prompt(**kwargs):
|
||||
captured.update(kwargs)
|
||||
kwargs["output_callback"]("[0, 2]")
|
||||
|
||||
with (
|
||||
patch.object(
|
||||
prompt_manager,
|
||||
"render_system_task_message",
|
||||
return_value="PROMPT",
|
||||
),
|
||||
patch.object(
|
||||
agent_manager,
|
||||
"run_background_prompt",
|
||||
new=AsyncMock(side_effect=_fake_run_background_prompt),
|
||||
),
|
||||
):
|
||||
result = await chain._invoke_recommend_llm("Candidates")
|
||||
|
||||
self.assertEqual("[0, 2]", result)
|
||||
self.assertTrue(captured["suppress_user_reply"])
|
||||
self.assertFalse(captured["persist_output_message"])
|
||||
self.assertTrue(captured["suppress_message_channel_dispatch"])
|
||||
|
||||
def test_search_by_title_clears_previous_recommend_state_when_caching(self):
|
||||
chain = self._make_chain()
|
||||
removed = []
|
||||
@@ -121,6 +156,46 @@ class SearchChainAIRecommendTest(unittest.IsolatedAsyncioTestCase):
|
||||
self.assertIsNone(SearchChain._ai_recommend_result)
|
||||
self.assertIsNone(SearchChain._ai_recommend_error)
|
||||
|
||||
def test_post_message_skips_channel_dispatch_when_suppressed(self):
|
||||
chain = object.__new__(SearchChain)
|
||||
queue_calls = []
|
||||
event_calls = []
|
||||
saved_messages = []
|
||||
saved_records = []
|
||||
chain.messagehelper = SimpleNamespace(
|
||||
put=lambda *args, **kwargs: saved_messages.append((args, kwargs))
|
||||
)
|
||||
chain.messageoper = SimpleNamespace(
|
||||
add=lambda **kwargs: saved_records.append(kwargs)
|
||||
)
|
||||
chain.messagequeue = SimpleNamespace(
|
||||
send_message=lambda *args, **kwargs: queue_calls.append((args, kwargs))
|
||||
)
|
||||
chain.eventmanager = SimpleNamespace(
|
||||
send_event=lambda *args, **kwargs: event_calls.append((args, kwargs))
|
||||
)
|
||||
|
||||
notification = Notification(
|
||||
mtype=NotificationType.Manual,
|
||||
title="Title",
|
||||
text="Body",
|
||||
)
|
||||
|
||||
with (
|
||||
patch.object(
|
||||
chain_module.MessageTemplateHelper,
|
||||
"render",
|
||||
return_value=notification,
|
||||
),
|
||||
suppress_message_channel(),
|
||||
):
|
||||
chain.post_message(message=notification)
|
||||
|
||||
self.assertEqual(1, len(saved_messages))
|
||||
self.assertEqual(1, len(saved_records))
|
||||
self.assertEqual([], queue_calls)
|
||||
self.assertEqual([], event_calls)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
|
||||
Reference in New Issue
Block a user