From ef5bd2975907e44741a57912e56c6be69d3032d6 Mon Sep 17 00:00:00 2001 From: jxxghp Date: Wed, 29 Apr 2026 23:22:37 +0800 Subject: [PATCH] move ui background message suppression into agent context --- app/agent/__init__.py | 18 +++++++++--------- app/agent_context.py | 31 +++++++++++++++++++++++++++++++ app/chain/__init__.py | 10 +++++----- app/core/message_context.py | 26 -------------------------- tests/test_search_ai_recommend.py | 4 ++-- 5 files changed, 47 insertions(+), 42 deletions(-) create mode 100644 app/agent_context.py delete mode 100644 app/core/message_context.py diff --git a/app/agent/__init__.py b/app/agent/__init__.py index 4dfaa886..6f7406db 100644 --- a/app/agent/__init__.py +++ b/app/agent/__init__.py @@ -3,7 +3,6 @@ import json import re import traceback import uuid -from contextlib import nullcontext from dataclasses import dataclass from datetime import datetime from typing import Any, Callable, Dict, List, Optional @@ -31,9 +30,9 @@ from app.agent.middleware.usage import UsageMiddleware from app.agent.prompt import prompt_manager from app.agent.runtime import agent_runtime_manager from app.agent.tools.factory import MoviePilotToolFactory +from app.agent_context import agent_execution_context from app.chain import ChainBase from app.core.config import settings -from app.core.message_context import suppress_message_channel from app.helper.llm import LLMHelper from app.log import logger from app.schemas import Notification, NotificationType @@ -175,6 +174,7 @@ class MoviePilotAgent: self.force_streaming = False self.suppress_user_reply = False self.persist_output_message = True + self.suppress_message_channel_dispatch = False self._streamed_output = "" self._session_usage = _SessionUsageSnapshot() @@ -457,6 +457,7 @@ class MoviePilotAgent: self._tool_context = { "user_reply_sent": False, "reply_mode": None, + "suppress_message_channel_dispatch": self.suppress_message_channel_dispatch, } self._streamed_output = "" @@ -485,7 +486,10 @@ class MoviePilotAgent: messages.append(HumanMessage(content=content)) # 执行推理 - await self._execute_agent(messages) + with agent_execution_context( + suppress_message_channel_dispatch=self.suppress_message_channel_dispatch + ): + await self._execute_agent(messages) except Exception as e: error_message = f"处理消息时发生错误: {str(e)}" @@ -1008,14 +1012,10 @@ class AgentManager: agent.force_streaming = bool(output_callback) agent.suppress_user_reply = suppress_user_reply agent.persist_output_message = persist_output_message + agent.suppress_message_channel_dispatch = suppress_message_channel_dispatch try: - with ( - suppress_message_channel() - if suppress_message_channel_dispatch - else nullcontext() - ): - await agent.process(message) + await agent.process(message) finally: await agent.cleanup() memory_manager.clear_memory(session_id, user_id) diff --git a/app/agent_context.py b/app/agent_context.py new file mode 100644 index 00000000..0d9c0ba1 --- /dev/null +++ b/app/agent_context.py @@ -0,0 +1,31 @@ +import contextvars +from contextlib import contextmanager +from typing import Iterator + +_suppress_message_channel_dispatch = contextvars.ContextVar( + "suppress_message_channel_dispatch", + default=False, +) + + +def is_message_channel_dispatch_suppressed() -> bool: + """ + 当前 Agent 执行上下文是否禁止向外部消息渠道派发通知。 + """ + return bool(_suppress_message_channel_dispatch.get()) + + +@contextmanager +def agent_execution_context( + *, suppress_message_channel_dispatch: bool = False +) -> Iterator[None]: + """ + 绑定当前 Agent 执行期的上下文参数。 + """ + token = _suppress_message_channel_dispatch.set( + bool(suppress_message_channel_dispatch) + ) + try: + yield + finally: + _suppress_message_channel_dispatch.reset(token) diff --git a/app/chain/__init__.py b/app/chain/__init__.py index 07e12d7b..aa74e3f2 100644 --- a/app/chain/__init__.py +++ b/app/chain/__init__.py @@ -12,12 +12,12 @@ from fastapi.concurrency import run_in_threadpool from qbittorrentapi import TorrentFilesList from transmission_rpc import File +from app.agent_context import is_message_channel_dispatch_suppressed from app.core.cache import FileCache, AsyncFileCache, fresh, async_fresh from app.core.config import settings from app.core.context import Context, MediaInfo, TorrentInfo from app.core.event import EventManager from app.core.meta import MetaBase -from app.core.message_context import is_message_channel_suppressed from app.core.module import ModuleManager from app.core.plugin import PluginManager from app.db.message_oper import MessageOper @@ -1137,7 +1137,7 @@ class ChainBase(metaclass=ABCMeta): # 保存消息 self.messagehelper.put(message, role="user", title=message.title) self.messageoper.add(**message.model_dump()) - if is_message_channel_suppressed(): + if is_message_channel_dispatch_suppressed(): logger.info("当前上下文已禁用消息渠道派发,仅保存消息记录") return dispatch_message = self._normalize_notification_for_dispatch(message) @@ -1257,7 +1257,7 @@ class ChainBase(metaclass=ABCMeta): # 保存消息 self.messagehelper.put(message, role="user", title=message.title) await self.messageoper.async_add(**message.model_dump()) - if is_message_channel_suppressed(): + if is_message_channel_dispatch_suppressed(): logger.info("当前上下文已禁用消息渠道派发,仅保存消息记录") return dispatch_message = self._normalize_notification_for_dispatch(message) @@ -1354,7 +1354,7 @@ class ChainBase(metaclass=ABCMeta): message, role="user", note=note_list, title=message.title ) self.messageoper.add(**message.model_dump(), note=note_list) - if is_message_channel_suppressed(): + if is_message_channel_dispatch_suppressed(): logger.info("当前上下文已禁用消息渠道派发,仅保存媒体消息记录") return None dispatch_message = self._normalize_notification_for_dispatch(message) @@ -1379,7 +1379,7 @@ class ChainBase(metaclass=ABCMeta): message, role="user", note=note_list, title=message.title ) self.messageoper.add(**message.model_dump(), note=note_list) - if is_message_channel_suppressed(): + if is_message_channel_dispatch_suppressed(): logger.info("当前上下文已禁用消息渠道派发,仅保存种子消息记录") return None dispatch_message = self._normalize_notification_for_dispatch(message) diff --git a/app/core/message_context.py b/app/core/message_context.py deleted file mode 100644 index eb2fe9fb..00000000 --- a/app/core/message_context.py +++ /dev/null @@ -1,26 +0,0 @@ -import contextvars -from contextlib import contextmanager -from typing import Iterator - -_suppress_message_channel = contextvars.ContextVar( - "suppress_message_channel", default=False -) - - -def is_message_channel_suppressed() -> bool: - """ - 当前上下文是否禁止向外部消息渠道派发通知。 - """ - return bool(_suppress_message_channel.get()) - - -@contextmanager -def suppress_message_channel() -> Iterator[None]: - """ - 在当前上下文中临时禁用外部消息渠道派发。 - """ - token = _suppress_message_channel.set(True) - try: - yield - finally: - _suppress_message_channel.reset(token) diff --git a/tests/test_search_ai_recommend.py b/tests/test_search_ai_recommend.py index 2ecb5527..e7ea36fd 100644 --- a/tests/test_search_ai_recommend.py +++ b/tests/test_search_ai_recommend.py @@ -22,8 +22,8 @@ _stub_module("qbittorrentapi", TorrentFilesList=list) _stub_module("transmission_rpc", File=object) from app.chain.search import SearchChain +from app.agent_context import agent_execution_context from app.core.config import settings -from app.core.message_context import suppress_message_channel from app.schemas import Notification from app.schemas.types import NotificationType @@ -187,7 +187,7 @@ class SearchChainAIRecommendTest(unittest.IsolatedAsyncioTestCase): "render", return_value=notification, ), - suppress_message_channel(), + agent_execution_context(suppress_message_channel_dispatch=True), ): chain.post_message(message=notification)