From 7ab643d34a1d18b524cf9da8b6a8d7ac062d932c Mon Sep 17 00:00:00 2001 From: jxxghp Date: Wed, 29 Apr 2026 23:13:57 +0800 Subject: [PATCH] suppress channel notifications for ui background tasks --- app/agent/__init__.py | 15 ++++++- app/api/endpoints/history.py | 4 ++ app/chain/__init__.py | 13 ++++++ app/chain/search.py | 2 + app/core/message_context.py | 26 +++++++++++ tests/test_search_ai_recommend.py | 75 +++++++++++++++++++++++++++++++ 6 files changed, 133 insertions(+), 2 deletions(-) create mode 100644 app/core/message_context.py diff --git a/app/agent/__init__.py b/app/agent/__init__.py index 4631900f..4dfaa886 100644 --- a/app/agent/__init__.py +++ b/app/agent/__init__.py @@ -3,6 +3,7 @@ import json import re import traceback import uuid +from contextlib import nullcontext from dataclasses import dataclass from datetime import datetime from typing import Any, Callable, Dict, List, Optional @@ -32,6 +33,7 @@ from app.agent.runtime import agent_runtime_manager from app.agent.tools.factory import MoviePilotToolFactory from app.chain import ChainBase from app.core.config import settings +from app.core.message_context import suppress_message_channel from app.helper.llm import LLMHelper from app.log import logger from app.schemas import Notification, NotificationType @@ -172,6 +174,7 @@ class MoviePilotAgent: self.output_callback: Optional[Callable[[str], None]] = None self.force_streaming = False self.suppress_user_reply = False + self.persist_output_message = True self._streamed_output = "" self._session_usage = _SessionUsageSnapshot() @@ -603,7 +606,7 @@ class MoviePilotAgent: and not self._tool_context.get("user_reply_sent") ): await self.send_agent_message(remaining_text) - elif streamed_text: + elif streamed_text and self.persist_output_message: # 流式输出已发送全部内容,但未记录到数据库,补充保存消息记录 await self._save_agent_message_to_db(streamed_text) @@ -986,6 +989,8 @@ class AgentManager: session_prefix: str = "__agent_background", output_callback: Optional[Callable[[str], None]] = None, suppress_user_reply: bool = False, + persist_output_message: bool = True, + suppress_message_channel_dispatch: bool = False, ) -> None: """ 以独立后台会话执行一段 prompt。 @@ -1002,9 +1007,15 @@ class AgentManager: agent.output_callback = output_callback agent.force_streaming = bool(output_callback) agent.suppress_user_reply = suppress_user_reply + agent.persist_output_message = persist_output_message try: - await agent.process(message) + with ( + suppress_message_channel() + if suppress_message_channel_dispatch + else nullcontext() + ): + await agent.process(message) finally: await agent.cleanup() memory_manager.clear_memory(session_id, user_id) diff --git a/app/api/endpoints/history.py b/app/api/endpoints/history.py index f42c4916..4e5f64ea 100644 --- a/app/api/endpoints/history.py +++ b/app/api/endpoints/history.py @@ -131,6 +131,8 @@ def _start_ai_redo_task(history_id: int, prompt: str, progress_key: str): session_prefix=f"__agent_manual_redo_{history_id}", output_callback=update_output, suppress_user_reply=True, + persist_output_message=False, + suppress_message_channel_dispatch=True, ) progress.update( text="智能助手整理完成", @@ -175,6 +177,8 @@ def _start_batch_ai_redo_task( session_prefix="__agent_manual_redo_batch", output_callback=update_output, suppress_user_reply=True, + persist_output_message=False, + suppress_message_channel_dispatch=True, ) progress.update( text="智能助手批量整理完成", diff --git a/app/chain/__init__.py b/app/chain/__init__.py index 0073a537..07e12d7b 100644 --- a/app/chain/__init__.py +++ b/app/chain/__init__.py @@ -17,6 +17,7 @@ from app.core.config import settings from app.core.context import Context, MediaInfo, TorrentInfo from app.core.event import EventManager from app.core.meta import MetaBase +from app.core.message_context import is_message_channel_suppressed from app.core.module import ModuleManager from app.core.plugin import PluginManager from app.db.message_oper import MessageOper @@ -1136,6 +1137,9 @@ class ChainBase(metaclass=ABCMeta): # 保存消息 self.messagehelper.put(message, role="user", title=message.title) self.messageoper.add(**message.model_dump()) + if is_message_channel_suppressed(): + logger.info("当前上下文已禁用消息渠道派发,仅保存消息记录") + return dispatch_message = self._normalize_notification_for_dispatch(message) # 发送消息按设置隔离 if not dispatch_message.userid and dispatch_message.mtype: @@ -1253,6 +1257,9 @@ class ChainBase(metaclass=ABCMeta): # 保存消息 self.messagehelper.put(message, role="user", title=message.title) await self.messageoper.async_add(**message.model_dump()) + if is_message_channel_suppressed(): + logger.info("当前上下文已禁用消息渠道派发,仅保存消息记录") + return dispatch_message = self._normalize_notification_for_dispatch(message) # 发送消息按设置隔离 if not dispatch_message.userid and dispatch_message.mtype: @@ -1347,6 +1354,9 @@ class ChainBase(metaclass=ABCMeta): message, role="user", note=note_list, title=message.title ) self.messageoper.add(**message.model_dump(), note=note_list) + if is_message_channel_suppressed(): + logger.info("当前上下文已禁用消息渠道派发,仅保存媒体消息记录") + return None dispatch_message = self._normalize_notification_for_dispatch(message) return self.messagequeue.send_message( "post_medias_message", @@ -1369,6 +1379,9 @@ class ChainBase(metaclass=ABCMeta): message, role="user", note=note_list, title=message.title ) self.messageoper.add(**message.model_dump(), note=note_list) + if is_message_channel_suppressed(): + logger.info("当前上下文已禁用消息渠道派发,仅保存种子消息记录") + return None dispatch_message = self._normalize_notification_for_dispatch(message) return self.messagequeue.send_message( "post_torrents_message", diff --git a/app/chain/search.py b/app/chain/search.py index 10cb3eed..04acb577 100644 --- a/app/chain/search.py +++ b/app/chain/search.py @@ -227,6 +227,8 @@ class SearchChain(ChainBase): session_prefix="__agent_search_recommend", output_callback=on_output, suppress_user_reply=True, + persist_output_message=False, + suppress_message_channel_dispatch=True, ) return full_output[0].strip() diff --git a/app/core/message_context.py b/app/core/message_context.py new file mode 100644 index 00000000..eb2fe9fb --- /dev/null +++ b/app/core/message_context.py @@ -0,0 +1,26 @@ +import contextvars +from contextlib import contextmanager +from typing import Iterator + +_suppress_message_channel = contextvars.ContextVar( + "suppress_message_channel", default=False +) + + +def is_message_channel_suppressed() -> bool: + """ + 当前上下文是否禁止向外部消息渠道派发通知。 + """ + return bool(_suppress_message_channel.get()) + + +@contextmanager +def suppress_message_channel() -> Iterator[None]: + """ + 在当前上下文中临时禁用外部消息渠道派发。 + """ + token = _suppress_message_channel.set(True) + try: + yield + finally: + _suppress_message_channel.reset(token) diff --git a/tests/test_search_ai_recommend.py b/tests/test_search_ai_recommend.py index aeb9dc70..2ecb5527 100644 --- a/tests/test_search_ai_recommend.py +++ b/tests/test_search_ai_recommend.py @@ -5,6 +5,8 @@ from types import SimpleNamespace from types import ModuleType from unittest.mock import AsyncMock, patch +import app.chain as chain_module + def _stub_module(name: str, **attrs): module = sys.modules.get(name) @@ -21,6 +23,9 @@ _stub_module("transmission_rpc", File=object) from app.chain.search import SearchChain from app.core.config import settings +from app.core.message_context import suppress_message_channel +from app.schemas import Notification +from app.schemas.types import NotificationType def _make_result(title: str, size: int, seeders: int): @@ -98,6 +103,36 @@ class SearchChainAIRecommendTest(unittest.IsolatedAsyncioTestCase): self.assertFalse(SearchChain._ai_recommend_running) self.assertIsNone(SearchChain._ai_recommend_task) + async def test_invoke_recommend_llm_disables_output_message_persistence(self): + chain = self._make_chain() + from app.agent import agent_manager + from app.agent.prompt import prompt_manager + + captured = {} + + async def _fake_run_background_prompt(**kwargs): + captured.update(kwargs) + kwargs["output_callback"]("[0, 2]") + + with ( + patch.object( + prompt_manager, + "render_system_task_message", + return_value="PROMPT", + ), + patch.object( + agent_manager, + "run_background_prompt", + new=AsyncMock(side_effect=_fake_run_background_prompt), + ), + ): + result = await chain._invoke_recommend_llm("Candidates") + + self.assertEqual("[0, 2]", result) + self.assertTrue(captured["suppress_user_reply"]) + self.assertFalse(captured["persist_output_message"]) + self.assertTrue(captured["suppress_message_channel_dispatch"]) + def test_search_by_title_clears_previous_recommend_state_when_caching(self): chain = self._make_chain() removed = [] @@ -121,6 +156,46 @@ class SearchChainAIRecommendTest(unittest.IsolatedAsyncioTestCase): self.assertIsNone(SearchChain._ai_recommend_result) self.assertIsNone(SearchChain._ai_recommend_error) + def test_post_message_skips_channel_dispatch_when_suppressed(self): + chain = object.__new__(SearchChain) + queue_calls = [] + event_calls = [] + saved_messages = [] + saved_records = [] + chain.messagehelper = SimpleNamespace( + put=lambda *args, **kwargs: saved_messages.append((args, kwargs)) + ) + chain.messageoper = SimpleNamespace( + add=lambda **kwargs: saved_records.append(kwargs) + ) + chain.messagequeue = SimpleNamespace( + send_message=lambda *args, **kwargs: queue_calls.append((args, kwargs)) + ) + chain.eventmanager = SimpleNamespace( + send_event=lambda *args, **kwargs: event_calls.append((args, kwargs)) + ) + + notification = Notification( + mtype=NotificationType.Manual, + title="Title", + text="Body", + ) + + with ( + patch.object( + chain_module.MessageTemplateHelper, + "render", + return_value=notification, + ), + suppress_message_channel(), + ): + chain.post_message(message=notification) + + self.assertEqual(1, len(saved_messages)) + self.assertEqual(1, len(saved_records)) + self.assertEqual([], queue_calls) + self.assertEqual([], event_calls) + if __name__ == "__main__": unittest.main()