move ui background message suppression into agent context

This commit is contained in:
jxxghp
2026-04-29 23:22:37 +08:00
parent 7ab643d34a
commit ef5bd29759
5 changed files with 47 additions and 42 deletions

View File

@@ -3,7 +3,6 @@ import json
import re
import traceback
import uuid
from contextlib import nullcontext
from dataclasses import dataclass
from datetime import datetime
from typing import Any, Callable, Dict, List, Optional
@@ -31,9 +30,9 @@ from app.agent.middleware.usage import UsageMiddleware
from app.agent.prompt import prompt_manager
from app.agent.runtime import agent_runtime_manager
from app.agent.tools.factory import MoviePilotToolFactory
from app.agent_context import agent_execution_context
from app.chain import ChainBase
from app.core.config import settings
from app.core.message_context import suppress_message_channel
from app.helper.llm import LLMHelper
from app.log import logger
from app.schemas import Notification, NotificationType
@@ -175,6 +174,7 @@ class MoviePilotAgent:
self.force_streaming = False
self.suppress_user_reply = False
self.persist_output_message = True
self.suppress_message_channel_dispatch = False
self._streamed_output = ""
self._session_usage = _SessionUsageSnapshot()
@@ -457,6 +457,7 @@ class MoviePilotAgent:
self._tool_context = {
"user_reply_sent": False,
"reply_mode": None,
"suppress_message_channel_dispatch": self.suppress_message_channel_dispatch,
}
self._streamed_output = ""
@@ -485,7 +486,10 @@ class MoviePilotAgent:
messages.append(HumanMessage(content=content))
# 执行推理
await self._execute_agent(messages)
with agent_execution_context(
suppress_message_channel_dispatch=self.suppress_message_channel_dispatch
):
await self._execute_agent(messages)
except Exception as e:
error_message = f"处理消息时发生错误: {str(e)}"
@@ -1008,14 +1012,10 @@ class AgentManager:
agent.force_streaming = bool(output_callback)
agent.suppress_user_reply = suppress_user_reply
agent.persist_output_message = persist_output_message
agent.suppress_message_channel_dispatch = suppress_message_channel_dispatch
try:
with (
suppress_message_channel()
if suppress_message_channel_dispatch
else nullcontext()
):
await agent.process(message)
await agent.process(message)
finally:
await agent.cleanup()
memory_manager.clear_memory(session_id, user_id)

31
app/agent_context.py Normal file
View File

@@ -0,0 +1,31 @@
import contextvars
from contextlib import contextmanager
from typing import Iterator
_suppress_message_channel_dispatch = contextvars.ContextVar(
"suppress_message_channel_dispatch",
default=False,
)
def is_message_channel_dispatch_suppressed() -> bool:
"""
当前 Agent 执行上下文是否禁止向外部消息渠道派发通知。
"""
return bool(_suppress_message_channel_dispatch.get())
@contextmanager
def agent_execution_context(
*, suppress_message_channel_dispatch: bool = False
) -> Iterator[None]:
"""
绑定当前 Agent 执行期的上下文参数。
"""
token = _suppress_message_channel_dispatch.set(
bool(suppress_message_channel_dispatch)
)
try:
yield
finally:
_suppress_message_channel_dispatch.reset(token)

View File

@@ -12,12 +12,12 @@ from fastapi.concurrency import run_in_threadpool
from qbittorrentapi import TorrentFilesList
from transmission_rpc import File
from app.agent_context import is_message_channel_dispatch_suppressed
from app.core.cache import FileCache, AsyncFileCache, fresh, async_fresh
from app.core.config import settings
from app.core.context import Context, MediaInfo, TorrentInfo
from app.core.event import EventManager
from app.core.meta import MetaBase
from app.core.message_context import is_message_channel_suppressed
from app.core.module import ModuleManager
from app.core.plugin import PluginManager
from app.db.message_oper import MessageOper
@@ -1137,7 +1137,7 @@ class ChainBase(metaclass=ABCMeta):
# 保存消息
self.messagehelper.put(message, role="user", title=message.title)
self.messageoper.add(**message.model_dump())
if is_message_channel_suppressed():
if is_message_channel_dispatch_suppressed():
logger.info("当前上下文已禁用消息渠道派发,仅保存消息记录")
return
dispatch_message = self._normalize_notification_for_dispatch(message)
@@ -1257,7 +1257,7 @@ class ChainBase(metaclass=ABCMeta):
# 保存消息
self.messagehelper.put(message, role="user", title=message.title)
await self.messageoper.async_add(**message.model_dump())
if is_message_channel_suppressed():
if is_message_channel_dispatch_suppressed():
logger.info("当前上下文已禁用消息渠道派发,仅保存消息记录")
return
dispatch_message = self._normalize_notification_for_dispatch(message)
@@ -1354,7 +1354,7 @@ class ChainBase(metaclass=ABCMeta):
message, role="user", note=note_list, title=message.title
)
self.messageoper.add(**message.model_dump(), note=note_list)
if is_message_channel_suppressed():
if is_message_channel_dispatch_suppressed():
logger.info("当前上下文已禁用消息渠道派发,仅保存媒体消息记录")
return None
dispatch_message = self._normalize_notification_for_dispatch(message)
@@ -1379,7 +1379,7 @@ class ChainBase(metaclass=ABCMeta):
message, role="user", note=note_list, title=message.title
)
self.messageoper.add(**message.model_dump(), note=note_list)
if is_message_channel_suppressed():
if is_message_channel_dispatch_suppressed():
logger.info("当前上下文已禁用消息渠道派发,仅保存种子消息记录")
return None
dispatch_message = self._normalize_notification_for_dispatch(message)

View File

@@ -1,26 +0,0 @@
import contextvars
from contextlib import contextmanager
from typing import Iterator
_suppress_message_channel = contextvars.ContextVar(
"suppress_message_channel", default=False
)
def is_message_channel_suppressed() -> bool:
"""
当前上下文是否禁止向外部消息渠道派发通知。
"""
return bool(_suppress_message_channel.get())
@contextmanager
def suppress_message_channel() -> Iterator[None]:
"""
在当前上下文中临时禁用外部消息渠道派发。
"""
token = _suppress_message_channel.set(True)
try:
yield
finally:
_suppress_message_channel.reset(token)

View File

@@ -22,8 +22,8 @@ _stub_module("qbittorrentapi", TorrentFilesList=list)
_stub_module("transmission_rpc", File=object)
from app.chain.search import SearchChain
from app.agent_context import agent_execution_context
from app.core.config import settings
from app.core.message_context import suppress_message_channel
from app.schemas import Notification
from app.schemas.types import NotificationType
@@ -187,7 +187,7 @@ class SearchChainAIRecommendTest(unittest.IsolatedAsyncioTestCase):
"render",
return_value=notification,
),
suppress_message_channel(),
agent_execution_context(suppress_message_channel_dispatch=True),
):
chain.post_message(message=notification)