mirror of
https://github.com/jxxghp/MoviePilot.git
synced 2026-05-07 05:42:40 +08:00
move ui background message suppression into agent context
This commit is contained in:
@@ -3,7 +3,6 @@ import json
|
||||
import re
|
||||
import traceback
|
||||
import uuid
|
||||
from contextlib import nullcontext
|
||||
from dataclasses import dataclass
|
||||
from datetime import datetime
|
||||
from typing import Any, Callable, Dict, List, Optional
|
||||
@@ -31,9 +30,9 @@ from app.agent.middleware.usage import UsageMiddleware
|
||||
from app.agent.prompt import prompt_manager
|
||||
from app.agent.runtime import agent_runtime_manager
|
||||
from app.agent.tools.factory import MoviePilotToolFactory
|
||||
from app.agent_context import agent_execution_context
|
||||
from app.chain import ChainBase
|
||||
from app.core.config import settings
|
||||
from app.core.message_context import suppress_message_channel
|
||||
from app.helper.llm import LLMHelper
|
||||
from app.log import logger
|
||||
from app.schemas import Notification, NotificationType
|
||||
@@ -175,6 +174,7 @@ class MoviePilotAgent:
|
||||
self.force_streaming = False
|
||||
self.suppress_user_reply = False
|
||||
self.persist_output_message = True
|
||||
self.suppress_message_channel_dispatch = False
|
||||
self._streamed_output = ""
|
||||
self._session_usage = _SessionUsageSnapshot()
|
||||
|
||||
@@ -457,6 +457,7 @@ class MoviePilotAgent:
|
||||
self._tool_context = {
|
||||
"user_reply_sent": False,
|
||||
"reply_mode": None,
|
||||
"suppress_message_channel_dispatch": self.suppress_message_channel_dispatch,
|
||||
}
|
||||
self._streamed_output = ""
|
||||
|
||||
@@ -485,7 +486,10 @@ class MoviePilotAgent:
|
||||
messages.append(HumanMessage(content=content))
|
||||
|
||||
# 执行推理
|
||||
await self._execute_agent(messages)
|
||||
with agent_execution_context(
|
||||
suppress_message_channel_dispatch=self.suppress_message_channel_dispatch
|
||||
):
|
||||
await self._execute_agent(messages)
|
||||
|
||||
except Exception as e:
|
||||
error_message = f"处理消息时发生错误: {str(e)}"
|
||||
@@ -1008,14 +1012,10 @@ class AgentManager:
|
||||
agent.force_streaming = bool(output_callback)
|
||||
agent.suppress_user_reply = suppress_user_reply
|
||||
agent.persist_output_message = persist_output_message
|
||||
agent.suppress_message_channel_dispatch = suppress_message_channel_dispatch
|
||||
|
||||
try:
|
||||
with (
|
||||
suppress_message_channel()
|
||||
if suppress_message_channel_dispatch
|
||||
else nullcontext()
|
||||
):
|
||||
await agent.process(message)
|
||||
await agent.process(message)
|
||||
finally:
|
||||
await agent.cleanup()
|
||||
memory_manager.clear_memory(session_id, user_id)
|
||||
|
||||
31
app/agent_context.py
Normal file
31
app/agent_context.py
Normal file
@@ -0,0 +1,31 @@
|
||||
import contextvars
|
||||
from contextlib import contextmanager
|
||||
from typing import Iterator
|
||||
|
||||
_suppress_message_channel_dispatch = contextvars.ContextVar(
|
||||
"suppress_message_channel_dispatch",
|
||||
default=False,
|
||||
)
|
||||
|
||||
|
||||
def is_message_channel_dispatch_suppressed() -> bool:
|
||||
"""
|
||||
当前 Agent 执行上下文是否禁止向外部消息渠道派发通知。
|
||||
"""
|
||||
return bool(_suppress_message_channel_dispatch.get())
|
||||
|
||||
|
||||
@contextmanager
|
||||
def agent_execution_context(
|
||||
*, suppress_message_channel_dispatch: bool = False
|
||||
) -> Iterator[None]:
|
||||
"""
|
||||
绑定当前 Agent 执行期的上下文参数。
|
||||
"""
|
||||
token = _suppress_message_channel_dispatch.set(
|
||||
bool(suppress_message_channel_dispatch)
|
||||
)
|
||||
try:
|
||||
yield
|
||||
finally:
|
||||
_suppress_message_channel_dispatch.reset(token)
|
||||
@@ -12,12 +12,12 @@ from fastapi.concurrency import run_in_threadpool
|
||||
from qbittorrentapi import TorrentFilesList
|
||||
from transmission_rpc import File
|
||||
|
||||
from app.agent_context import is_message_channel_dispatch_suppressed
|
||||
from app.core.cache import FileCache, AsyncFileCache, fresh, async_fresh
|
||||
from app.core.config import settings
|
||||
from app.core.context import Context, MediaInfo, TorrentInfo
|
||||
from app.core.event import EventManager
|
||||
from app.core.meta import MetaBase
|
||||
from app.core.message_context import is_message_channel_suppressed
|
||||
from app.core.module import ModuleManager
|
||||
from app.core.plugin import PluginManager
|
||||
from app.db.message_oper import MessageOper
|
||||
@@ -1137,7 +1137,7 @@ class ChainBase(metaclass=ABCMeta):
|
||||
# 保存消息
|
||||
self.messagehelper.put(message, role="user", title=message.title)
|
||||
self.messageoper.add(**message.model_dump())
|
||||
if is_message_channel_suppressed():
|
||||
if is_message_channel_dispatch_suppressed():
|
||||
logger.info("当前上下文已禁用消息渠道派发,仅保存消息记录")
|
||||
return
|
||||
dispatch_message = self._normalize_notification_for_dispatch(message)
|
||||
@@ -1257,7 +1257,7 @@ class ChainBase(metaclass=ABCMeta):
|
||||
# 保存消息
|
||||
self.messagehelper.put(message, role="user", title=message.title)
|
||||
await self.messageoper.async_add(**message.model_dump())
|
||||
if is_message_channel_suppressed():
|
||||
if is_message_channel_dispatch_suppressed():
|
||||
logger.info("当前上下文已禁用消息渠道派发,仅保存消息记录")
|
||||
return
|
||||
dispatch_message = self._normalize_notification_for_dispatch(message)
|
||||
@@ -1354,7 +1354,7 @@ class ChainBase(metaclass=ABCMeta):
|
||||
message, role="user", note=note_list, title=message.title
|
||||
)
|
||||
self.messageoper.add(**message.model_dump(), note=note_list)
|
||||
if is_message_channel_suppressed():
|
||||
if is_message_channel_dispatch_suppressed():
|
||||
logger.info("当前上下文已禁用消息渠道派发,仅保存媒体消息记录")
|
||||
return None
|
||||
dispatch_message = self._normalize_notification_for_dispatch(message)
|
||||
@@ -1379,7 +1379,7 @@ class ChainBase(metaclass=ABCMeta):
|
||||
message, role="user", note=note_list, title=message.title
|
||||
)
|
||||
self.messageoper.add(**message.model_dump(), note=note_list)
|
||||
if is_message_channel_suppressed():
|
||||
if is_message_channel_dispatch_suppressed():
|
||||
logger.info("当前上下文已禁用消息渠道派发,仅保存种子消息记录")
|
||||
return None
|
||||
dispatch_message = self._normalize_notification_for_dispatch(message)
|
||||
|
||||
@@ -1,26 +0,0 @@
|
||||
import contextvars
|
||||
from contextlib import contextmanager
|
||||
from typing import Iterator
|
||||
|
||||
_suppress_message_channel = contextvars.ContextVar(
|
||||
"suppress_message_channel", default=False
|
||||
)
|
||||
|
||||
|
||||
def is_message_channel_suppressed() -> bool:
|
||||
"""
|
||||
当前上下文是否禁止向外部消息渠道派发通知。
|
||||
"""
|
||||
return bool(_suppress_message_channel.get())
|
||||
|
||||
|
||||
@contextmanager
|
||||
def suppress_message_channel() -> Iterator[None]:
|
||||
"""
|
||||
在当前上下文中临时禁用外部消息渠道派发。
|
||||
"""
|
||||
token = _suppress_message_channel.set(True)
|
||||
try:
|
||||
yield
|
||||
finally:
|
||||
_suppress_message_channel.reset(token)
|
||||
@@ -22,8 +22,8 @@ _stub_module("qbittorrentapi", TorrentFilesList=list)
|
||||
_stub_module("transmission_rpc", File=object)
|
||||
|
||||
from app.chain.search import SearchChain
|
||||
from app.agent_context import agent_execution_context
|
||||
from app.core.config import settings
|
||||
from app.core.message_context import suppress_message_channel
|
||||
from app.schemas import Notification
|
||||
from app.schemas.types import NotificationType
|
||||
|
||||
@@ -187,7 +187,7 @@ class SearchChainAIRecommendTest(unittest.IsolatedAsyncioTestCase):
|
||||
"render",
|
||||
return_value=notification,
|
||||
),
|
||||
suppress_message_channel(),
|
||||
agent_execution_context(suppress_message_channel_dispatch=True),
|
||||
):
|
||||
chain.post_message(message=notification)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user