refactor: remove persist_output_message functionality and related database save logic

This commit is contained in:
jxxghp
2026-06-17 16:14:35 +08:00
parent 039558d240
commit d4b6d3f332
20 changed files with 442 additions and 170 deletions

View File

@@ -246,7 +246,6 @@ class MoviePilotAgent:
original_message_id: Optional[str] = None,
original_chat_id: Optional[str] = None,
replay_mode: ReplyMode = ReplyMode.DISPATCH,
persist_output_message: bool = True,
allow_message_tools: bool = True,
output_callback: Optional[Callable[[str], None]] = None,
):
@@ -258,7 +257,6 @@ class MoviePilotAgent:
self.original_message_id = original_message_id
self.original_chat_id = original_chat_id
self.reply_mode = replay_mode
self.persist_output_message = persist_output_message
self.allow_message_tools = allow_message_tools
self.output_callback = output_callback
self._tool_context: Dict[str, object] = {}
@@ -782,8 +780,6 @@ class MoviePilotAgent:
title = "MoviePilot助手" if self.is_background else ""
if self.should_dispatch_reply:
await self.send_agent_message(message, title=title)
elif self.persist_output_message:
await self._save_agent_message_to_db(message, title=title)
def _emit_output(self, text: str):
"""
@@ -1126,19 +1122,6 @@ class MoviePilotAgent:
and not self._tool_context.get("user_reply_sent")
):
await self.send_agent_message(remaining_text)
elif (
remaining_text
and self.persist_output_message
and not self._tool_context.get("user_reply_sent")
):
title = "MoviePilot助手" if self.is_background else ""
await self._save_agent_message_to_db(
remaining_text,
title=title,
)
elif streamed_text and self.persist_output_message:
# 流式输出已发送全部内容,但未记录到数据库,补充保存消息记录
await self._save_agent_message_to_db(streamed_text)
else:
# 非流式模式:后台任务或渠道不支持消息编辑
@@ -1180,13 +1163,6 @@ class MoviePilotAgent:
else:
# 非流式渠道:发送最终回复
await self.send_agent_message(final_text)
elif (
final_text
and self.persist_output_message
and not self._tool_context.get("user_reply_sent")
):
title = "MoviePilot助手" if self.is_background else ""
await self._save_agent_message_to_db(final_text, title=title)
# 保存消息
memory_manager.save_agent_messages(
@@ -1238,26 +1214,6 @@ class MoviePilotAgent:
)
)
async def _save_agent_message_to_db(self, message: str, title: str = ""):
"""
仅保存Agent回复消息到数据库和SSE队列不重新发送到渠道
用于流式输出场景:消息已通过 send_direct_message/edit_message 发送给用户,
但未记录到数据库中,此方法补充保存消息历史记录。
"""
chain = AgentChain()
notification = Notification(
channel=self.channel,
source=self.source,
userid=self.user_id,
username=self.username,
title=title,
text=message,
)
# 保存到SSE消息队列供前端展示
chain.messagehelper.put(notification, role="user", title=title)
# 保存到数据库
await chain.messageoper.async_add(**notification.model_dump())
async def cleanup(self):
"""
清理智能体资源
@@ -1284,7 +1240,6 @@ class _MessageTask:
original_chat_id: Optional[str] = None
processing_status: Optional[dict] = None
reply_mode: ReplyMode = ReplyMode.DISPATCH
persist_output_message: bool = True
allow_message_tools: bool = True
@@ -1430,7 +1385,6 @@ class AgentManager:
original_message_id: Optional[str] = None,
original_chat_id: Optional[str] = None,
reply_mode: ReplyMode = ReplyMode.DISPATCH,
persist_output_message: bool = True,
allow_message_tools: bool = True,
) -> str:
"""
@@ -1450,7 +1404,6 @@ class AgentManager:
original_message_id=original_message_id,
original_chat_id=original_chat_id,
reply_mode=reply_mode,
persist_output_message=persist_output_message,
allow_message_tools=allow_message_tools,
)
self._record_session_activity(session_id, user_id)
@@ -1561,7 +1514,6 @@ class AgentManager:
original_message_id=task.original_message_id,
original_chat_id=task.original_chat_id,
replay_mode=task.reply_mode,
persist_output_message=task.persist_output_message,
allow_message_tools=task.allow_message_tools,
)
self.active_agents[session_id] = agent
@@ -1577,7 +1529,6 @@ class AgentManager:
agent.original_message_id = task.original_message_id
agent.original_chat_id = task.original_chat_id
agent.reply_mode = task.reply_mode
agent.persist_output_message = task.persist_output_message
agent.allow_message_tools = task.allow_message_tools
process_kwargs = {
@@ -1656,7 +1607,6 @@ class AgentManager:
session_prefix: str = "__agent_background",
output_callback: Optional[Callable[[str], None]] = None,
reply_mode: ReplyMode = ReplyMode.CAPTURE_ONLY,
persist_output_message: bool = True,
allow_message_tools: Optional[bool] = None,
) -> None:
"""
@@ -1677,7 +1627,6 @@ class AgentManager:
source=None,
username=settings.SUPERUSER,
replay_mode=reply_mode,
persist_output_message=persist_output_message,
output_callback=output_callback,
allow_message_tools=allow_message_tools,
)
@@ -1723,7 +1672,6 @@ class AgentManager:
source=None,
username=settings.SUPERUSER,
reply_mode=ReplyMode.CAPTURE_ONLY,
persist_output_message=False,
allow_message_tools=True,
)

View File

@@ -7,6 +7,8 @@ from pydantic import BaseModel, Field, model_validator
from app.agent.tools.base import MoviePilotTool
from app.agent.tools.tags import ToolTag
from app.log import logger
from app.schemas import Notification
from app.schemas.types import NotificationType
class SendMessageInput(BaseModel):
@@ -77,7 +79,18 @@ class SendMessageTool(MoviePilotTool):
f"执行工具: {self.name}, 参数: title={title}, message={text}, image_url={image_url}"
)
try:
await self.send_tool_message(text, title=title, image=image_url)
await self.send_notification_message(
Notification(
channel=self._channel,
source=self._source,
mtype=NotificationType.Other,
userid=self._user_id,
username=self._username,
title=title,
text=text,
image=image_url,
)
)
return "消息已发送"
except Exception as e:
logger.error(f"发送消息失败: {e}")

View File

@@ -881,7 +881,6 @@ async def web_agent_stream(
source=WEB_AGENT_SOURCE,
username=current_user.name,
replay_mode=ReplyMode.CAPTURE_ONLY,
persist_output_message=False,
allow_message_tools=True,
output_callback=output_callback,
notification_callback=notification_callback,

View File

@@ -136,7 +136,6 @@ def _start_ai_redo_task(history_id: int, prompt: str, progress_key: str):
session_prefix=f"__agent_manual_redo_{history_id}",
output_callback=update_output,
reply_mode=ReplyMode.CAPTURE_ONLY,
persist_output_message=False,
allow_message_tools=False,
)
progress.update(
@@ -182,7 +181,6 @@ def _start_batch_ai_redo_task(
session_prefix="__agent_manual_redo_batch",
output_callback=update_output,
reply_mode=ReplyMode.CAPTURE_ONLY,
persist_output_message=False,
allow_message_tools=False,
)
progress.update(

View File

@@ -12,7 +12,7 @@ from app.core.config import settings, global_vars
from app.core.security import verify_token, verify_apitoken
from app.db import get_async_db
from app.db.models import User
from app.db.models.message import Message
from app.db.message_oper import MessageOper
from app.db.user_oper import get_current_active_superuser
from app.helper.service import ServiceConfigHelper
from app.helper.webpush import is_webpush_subscription_gone
@@ -120,7 +120,7 @@ async def get_web_message(
获取WEB消息列表
"""
ret_messages = []
messages = await Message.async_list_by_page(db, page=page, count=count)
messages = await MessageOper(db).async_list_by_page(page=page, count=count)
for message in messages:
try:
ret_messages.append(message.to_dict())
@@ -130,6 +130,20 @@ async def get_web_message(
return ret_messages
@router.get("/notification", summary="获取通知消息", response_model=List[schemas.NotificationHistoryItem])
async def get_notification_message(
_: schemas.TokenPayload = Depends(verify_token),
db: AsyncSession = Depends(get_async_db),
page: Optional[int] = 1,
count: Optional[int] = 20,
):
"""
获取系统发送的通知消息列表。
"""
messages = await MessageOper(db).async_list_sent_by_page(page=page, count=count)
return [schemas.NotificationHistoryItem(**message.to_dict()) for message in messages]
def wechat_verify(
echostr: str,
msg_signature: str,

View File

@@ -52,9 +52,6 @@ class _CollectingMoviePilotAgent(MoviePilotAgent):
if self.stream_mode:
self.stream_handler.emit(text)
async def _save_agent_message_to_db(self, message: str, title: str = ""):
return None
class _OpenAIStreamingHandler(StreamingHandler):
"""

View File

@@ -1478,8 +1478,6 @@ class ChainBase(metaclass=ABCMeta):
if not message:
logger.warning("消息为空,跳过发送")
return
# 保存消息
self.messagehelper.put(message, role="user", title=message.title)
self.messageoper.add(**message.model_dump())
dispatch_message = self._normalize_notification_for_dispatch(message)
# 发送消息按设置隔离
@@ -1595,8 +1593,6 @@ class ChainBase(metaclass=ABCMeta):
if not message:
logger.warning("消息为空,跳过发送")
return
# 保存消息
self.messagehelper.put(message, role="user", title=message.title)
await self.messageoper.async_add(**message.model_dump())
dispatch_message = self._normalize_notification_for_dispatch(message)
# 发送消息按设置隔离
@@ -1688,9 +1684,6 @@ class ChainBase(metaclass=ABCMeta):
:return: 成功或失败
"""
note_list = [media.to_dict() for media in medias]
self.messagehelper.put(
message, role="user", note=note_list, title=message.title
)
self.messageoper.add(**message.model_dump(), note=note_list)
dispatch_message = self._normalize_notification_for_dispatch(message)
return self.messagequeue.send_message(
@@ -1710,9 +1703,6 @@ class ChainBase(metaclass=ABCMeta):
:return: 成功或失败
"""
note_list = [torrent.torrent_info.to_dict() for torrent in torrents]
self.messagehelper.put(
message, role="user", note=note_list, title=message.title
)
self.messageoper.add(**message.model_dump(), note=note_list)
dispatch_message = self._normalize_notification_for_dispatch(message)
return self.messagequeue.send_message(

View File

@@ -197,7 +197,15 @@ class MessageChain(ChainBase):
)
return
if not text.startswith("CALLBACK:"):
is_agent_message = self._is_agent_message(
userid=userid,
text=text,
images=images,
files=files,
has_audio_input=has_audio_input,
)
if not text.startswith("CALLBACK:") and not is_agent_message:
self._record_user_message(
channel=channel,
source=source,
@@ -206,14 +214,7 @@ class MessageChain(ChainBase):
text=text,
)
if not self._is_agent_message(
channel=channel,
userid=userid,
text=text,
images=images,
files=files,
has_audio_input=has_audio_input,
):
if not is_agent_message:
processing_status = self._mark_message_processing_started(
channel=channel,
source=source,
@@ -430,7 +431,6 @@ class MessageChain(ChainBase):
def _is_agent_message(
self,
channel: MessageChannel,
userid: Union[str, int],
text: str,
images: Optional[List[CommingMessage.MessageImage]] = None,
@@ -766,13 +766,6 @@ class MessageChain(ChainBase):
selected_label=option.label,
)
self._bind_session_id(userid, request.session_id)
self._record_user_message(
channel=channel,
source=source,
userid=userid,
username=username,
text=selected_text,
)
return self._handle_ai_message(
text=selected_text,
channel=channel,
@@ -954,7 +947,6 @@ class MessageChain(ChainBase):
session_prefix=f"__agent_manual_redo_{history_id}",
output_callback=_capture_output,
reply_mode=ReplyMode.CAPTURE_ONLY,
persist_output_message=False,
allow_message_tools=False,
)
await self.async_post_message(

View File

@@ -395,7 +395,6 @@ class SearchChain(ChainBase):
session_prefix="__agent_search_recommend",
output_callback=on_output,
reply_mode=ReplyMode.CAPTURE_ONLY,
persist_output_message=False,
allow_message_tools=False,
)
return full_output[0].strip()

View File

@@ -1,6 +1,7 @@
import time
from typing import Optional, Union
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.orm import Session
from app.db import DbOper
@@ -13,7 +14,7 @@ class MessageOper(DbOper):
消息数据管理
"""
def __init__(self, db: Session = None):
def __init__(self, db: Union[Session, AsyncSession] = None):
super().__init__(db)
def add(self,
@@ -27,7 +28,7 @@ class MessageOper(DbOper):
userid: Optional[str] = None,
action: Optional[int] = 1,
note: Union[list, dict] = None,
**kwargs):
**kwargs) -> dict:
"""
新增消息
:param channel: 消息渠道
@@ -60,7 +61,7 @@ class MessageOper(DbOper):
if k not in Message.__table__.columns.keys(): # noqa
kwargs.pop(k)
Message(**kwargs).create(self._db)
return Message(**kwargs).create_and_to_dict(self._db)
async def async_add(self,
channel: MessageChannel = None,
@@ -73,7 +74,7 @@ class MessageOper(DbOper):
userid: Optional[str] = None,
action: Optional[int] = 1,
note: Union[list, dict] = None,
**kwargs):
**kwargs) -> Message:
"""
异步新增消息
"""
@@ -96,10 +97,26 @@ class MessageOper(DbOper):
if k not in Message.__table__.columns.keys(): # noqa
kwargs.pop(k)
await Message(**kwargs).async_create(self._db)
return await Message(**kwargs).async_create(self._db)
def list_by_page(self, page: Optional[int] = 1, count: Optional[int] = 30) -> Optional[str]:
def list_by_page(self, page: Optional[int] = 1, count: Optional[int] = 30) -> list[Message]:
"""
获取媒体服务器数据ID
分页获取消息记录。
"""
return Message.list_by_page(self._db, page, count)
async def async_list_by_page(
self, page: Optional[int] = 1, count: Optional[int] = 30
) -> list[Message]:
"""
分页获取消息记录。
"""
return await Message.async_list_by_page(self._db, page, count)
async def async_list_sent_by_page(
self, page: Optional[int] = 1, count: Optional[int] = 30
) -> list[Message]:
"""
分页获取系统发送的通知消息。
"""
return await Message.async_list_sent_by_page(self._db, page, count)

View File

@@ -1,4 +1,4 @@
from typing import Optional
from typing import List, Optional
from sqlalchemy import Column, Integer, String, JSON, Index, select
from sqlalchemy.ext.asyncio import AsyncSession
@@ -39,16 +39,59 @@ class Message(Base):
Index('ix_message_reg_time_id', 'reg_time', 'id'),
)
@db_update
def create_and_to_dict(self, db: Session) -> dict:
"""
创建消息记录并返回写入后的字段字典。
"""
db.add(self)
db.flush()
return self.to_dict()
@classmethod
@db_query
def list_by_page(cls, db: Session, page: Optional[int] = 1, count: Optional[int] = 30):
return db.query(cls).order_by(cls.reg_time.desc()).offset((page - 1) * count).limit(count).all()
def list_by_page(cls, db: Session, page: Optional[int] = 1, count: Optional[int] = 30) -> List["Message"]:
"""
分页获取消息记录。
"""
return (
db.query(cls)
.order_by(cls.reg_time.desc(), cls.id.desc())
.offset((page - 1) * count)
.limit(count)
.all()
)
@classmethod
@async_db_query
async def async_list_by_page(cls, db: AsyncSession, page: Optional[int] = 1, count: Optional[int] = 30):
async def async_list_by_page(
cls, db: AsyncSession, page: Optional[int] = 1, count: Optional[int] = 30
) -> List["Message"]:
"""
异步分页获取消息记录。
"""
result = await db.execute(
select(cls).order_by(cls.reg_time.desc()).offset((page - 1) * count).limit(count)
select(cls)
.order_by(cls.reg_time.desc(), cls.id.desc())
.offset((page - 1) * count)
.limit(count)
)
return result.scalars().all()
@classmethod
@async_db_query
async def async_list_sent_by_page(
cls, db: AsyncSession, page: Optional[int] = 1, count: Optional[int] = 30
) -> List["Message"]:
"""
分页获取系统发送的通知消息。
"""
result = await db.execute(
select(cls)
.where(cls.action == 1)
.order_by(cls.reg_time.desc(), cls.id.desc())
.offset((page - 1) * count)
.limit(count)
)
return result.scalars().all()

View File

@@ -764,61 +764,74 @@ class MessageQueueManager(metaclass=SingletonClass):
class MessageHelper(metaclass=Singleton):
"""
消息队列管理器,包括系统消息和用户消息
消息队列管理器,负责系统和插件实时消息的 SSE 推送
"""
def __init__(self):
self.sys_queue = queue.Queue()
self.user_queue = queue.Queue()
self._recent_notification_keys = TTLCache(region="message:notification", maxsize=500, ttl=60)
@staticmethod
def _build_system_notification_key(
message: Any, role: str, title: str = None, note: Union[list, dict] = None
) -> str:
"""
构建系统通知短期去重键。
"""
return json.dumps(
{
"role": role,
"title": title or "",
"text": str(message),
"note": note or {},
"time": time.strftime("%Y-%m-%d %H:%M", time.localtime()),
},
ensure_ascii=False,
sort_keys=True,
)
def _is_recent_system_notification(
self, message: Any, role: str, title: str = None, note: Union[list, dict] = None
) -> bool:
"""
判断系统通知是否在短时间内重复。
"""
key = self._build_system_notification_key(message, role, title=title, note=note)
if self._recent_notification_keys.get(key):
return True
self._recent_notification_keys.set(key, True)
return False
def put(self, message: Any, role: str = "plugin", title: str = None, note: Union[list, dict] = None):
"""
存消息
:param message: 消息
:param role: 消息通道 systm系统消息plugin插件消息user用户消息
:param role: 消息通道 system系统消息plugin插件消息
:param title: 标题
:param note: 附件json
"""
if role in ["system", "plugin"]:
# 没有标题时获取插件名称
if role == "plugin" and not title:
title = "插件通知"
# 系统通知,默认
self.sys_queue.put(json.dumps({
"type": role,
"title": title,
"text": message,
"date": time.strftime("%Y-%m-%d %H:%M:%S", time.localtime()),
"note": note
}))
else:
if isinstance(message, str):
# 非系统的文本通知
self.user_queue.put(json.dumps({
"title": title,
"text": message,
"date": time.strftime("%Y-%m-%d %H:%M:%S", time.localtime()),
"note": note
}))
elif hasattr(message, "to_dict"):
# 非系统的复杂结构通知,如媒体信息/种子列表等。
content = message.to_dict()
content['title'] = title
content['date'] = time.strftime("%Y-%m-%d %H:%M:%S", time.localtime())
content['note'] = note
self.user_queue.put(json.dumps(content))
if role not in ["system", "plugin"]:
return
# 没有标题时获取插件名称
if role == "plugin" and not title:
title = "插件通知"
if self._is_recent_system_notification(message, role, title=title, note=note):
return
self.sys_queue.put(json.dumps({
"type": role,
"title": title,
"text": message,
"date": time.strftime("%Y-%m-%d %H:%M:%S", time.localtime()),
"note": note
}))
def get(self, role: str = "system") -> Optional[str]:
"""
取消息
:param role: 消息通道 systm系统消息plugin插件消息user用户消息
:param role: 兼容旧参数,当前所有 SSE 消息共用一个队列
"""
if role == "system":
if not self.sys_queue.empty():
return self.sys_queue.get(block=False)
else:
if not self.user_queue.empty():
return self.user_queue.get(block=False)
if not self.sys_queue.empty():
return self.sys_queue.get(block=False)
return None

View File

@@ -26,6 +26,37 @@ class MessageResponse(BaseModel):
success: bool = False
class NotificationHistoryItem(BaseModel):
"""
通知历史记录。
"""
# 消息ID
id: Optional[int] = None
# 消息渠道
channel: Optional[str] = None
# 消息来源
source: Optional[str] = None
# 消息类型
mtype: Optional[str] = None
# 标题
title: Optional[str] = None
# 文本内容
text: Optional[str] = None
# 图片
image: Optional[str] = None
# 链接
link: Optional[str] = None
# 用户ID
userid: Optional[str] = None
# 登记时间
reg_time: Optional[str] = None
# 消息方向0-接收消息1-发送消息
action: Optional[int] = None
# 附件json
note: Optional[Union[list, dict]] = None
class CommingMessage(BaseModel):
"""
外来消息

View File

@@ -66,7 +66,6 @@ class AgentBackgroundOutputTest(unittest.IsolatedAsyncioTestCase):
agent.channel = None
agent.source = None
agent.reply_mode = ReplyMode.CAPTURE_ONLY
agent.persist_output_message = True
agent._tool_context = {"user_reply_sent": False}
agent._streamed_output = ""
agent.stream_handler = SimpleNamespace(
@@ -77,15 +76,11 @@ class AgentBackgroundOutputTest(unittest.IsolatedAsyncioTestCase):
return_value=_FakeAgent([AIMessage(content="后台结果")])
)
agent.send_agent_message = AsyncMock()
agent._save_agent_message_to_db = AsyncMock()
with patch.object(memory_manager, "save_agent_messages") as save_messages:
await agent._execute_agent([])
agent.send_agent_message.assert_not_awaited()
agent._save_agent_message_to_db.assert_awaited_once_with(
"后台结果", title="MoviePilot助手"
)
save_messages.assert_called_once()
self.assertEqual("后台结果", agent._streamed_output)
@@ -105,7 +100,6 @@ class AgentBackgroundOutputTest(unittest.IsolatedAsyncioTestCase):
)
)
agent.send_agent_message = AsyncMock()
agent._save_agent_message_to_db = AsyncMock()
result, _ = await agent._execute_agent(
[
@@ -122,7 +116,6 @@ class AgentBackgroundOutputTest(unittest.IsolatedAsyncioTestCase):
agent.send_agent_message.assert_awaited_once_with(
UNSUPPORTED_IMAGE_INPUT_MESSAGE, title=""
)
agent._save_agent_message_to_db.assert_not_awaited()
self.assertEqual(UNSUPPORTED_IMAGE_INPUT_MESSAGE, agent._streamed_output)
async def test_streaming_image_unsupported_error_sends_friendly_notice(self):
@@ -144,7 +137,6 @@ class AgentBackgroundOutputTest(unittest.IsolatedAsyncioTestCase):
)
)
agent.send_agent_message = AsyncMock()
agent._save_agent_message_to_db = AsyncMock()
result, _ = await agent._execute_agent(
[
@@ -161,7 +153,6 @@ class AgentBackgroundOutputTest(unittest.IsolatedAsyncioTestCase):
agent.send_agent_message.assert_awaited_once_with(
UNSUPPORTED_IMAGE_INPUT_MESSAGE, title=""
)
agent._save_agent_message_to_db.assert_not_awaited()
async def test_streaming_model_chunk_timeout_sends_friendly_notice(self):
"""流式模型分块超时时应只把主错误信息发给用户。"""
@@ -186,7 +177,6 @@ class AgentBackgroundOutputTest(unittest.IsolatedAsyncioTestCase):
return_value=_FakeStreamingFailingAgent(raw_error)
)
agent.send_agent_message = AsyncMock()
agent._save_agent_message_to_db = AsyncMock()
result, _ = await agent._execute_agent([HumanMessage(content="测试超时")])
@@ -200,14 +190,12 @@ class AgentBackgroundOutputTest(unittest.IsolatedAsyncioTestCase):
self.assertIn("No streaming chunk received for 120.0s", sent_message)
self.assertNotIn("Tune or disable", sent_message)
self.assertEqual(expected, agent._streamed_output)
agent._save_agent_message_to_db.assert_not_awaited()
async def test_background_non_streaming_sends_when_reply_mode_dispatch(self):
agent = MoviePilotAgent(session_id="bg-test", user_id="system")
agent.channel = None
agent.source = None
agent.reply_mode = ReplyMode.DISPATCH
agent.persist_output_message = False
agent._tool_context = {"user_reply_sent": False}
agent._streamed_output = ""
agent.stream_handler = SimpleNamespace(
@@ -218,7 +206,6 @@ class AgentBackgroundOutputTest(unittest.IsolatedAsyncioTestCase):
return_value=_FakeAgent([AIMessage(content="后台结果")])
)
agent.send_agent_message = AsyncMock()
agent._save_agent_message_to_db = AsyncMock()
with patch.object(memory_manager, "save_agent_messages") as save_messages:
await agent._execute_agent([])
@@ -226,16 +213,14 @@ class AgentBackgroundOutputTest(unittest.IsolatedAsyncioTestCase):
agent.send_agent_message.assert_awaited_once_with(
"后台结果", title="MoviePilot助手"
)
agent._save_agent_message_to_db.assert_not_awaited()
save_messages.assert_called_once()
self.assertEqual("后台结果", agent._streamed_output)
async def test_background_non_streaming_persists_without_sending_when_capture_only(self):
async def test_background_non_streaming_captures_without_sending_when_capture_only(self):
agent = MoviePilotAgent(session_id="bg-test", user_id="system")
agent.channel = None
agent.source = None
agent.reply_mode = ReplyMode.CAPTURE_ONLY
agent.persist_output_message = True
agent._tool_context = {"user_reply_sent": False}
agent._streamed_output = ""
agent.stream_handler = SimpleNamespace(
@@ -246,15 +231,11 @@ class AgentBackgroundOutputTest(unittest.IsolatedAsyncioTestCase):
return_value=_FakeAgent([AIMessage(content="后台结果")])
)
agent.send_agent_message = AsyncMock()
agent._save_agent_message_to_db = AsyncMock()
with patch.object(memory_manager, "save_agent_messages") as save_messages:
await agent._execute_agent([])
agent.send_agent_message.assert_not_awaited()
agent._save_agent_message_to_db.assert_awaited_once_with(
"后台结果", title="MoviePilot助手"
)
save_messages.assert_called_once()
self.assertEqual("后台结果", agent._streamed_output)
@@ -279,7 +260,6 @@ class AgentBackgroundOutputTest(unittest.IsolatedAsyncioTestCase):
process_message.assert_awaited_once()
kwargs = process_message.await_args.kwargs
self.assertEqual(ReplyMode.CAPTURE_ONLY, kwargs["reply_mode"])
self.assertFalse(kwargs["persist_output_message"])
self.assertTrue(kwargs["allow_message_tools"])
async def test_heartbeat_check_jobs_skips_when_no_active_jobs(self):

View File

@@ -569,8 +569,8 @@ class AgentImageSupportTest(unittest.TestCase):
self.assertEqual(payload.image_url, "https://example.com/poster.png")
def test_send_message_tool_uses_agent_notification_type(self):
"""发送消息工具应固定使用智能体消息类型"""
def test_send_message_tool_uses_regular_notification_type(self):
"""发送消息工具应按普通通知消息登记"""
async def _run():
tool = SendMessageTool(session_id="session-1", user_id="10001")
@@ -595,7 +595,7 @@ class AgentImageSupportTest(unittest.TestCase):
notification = async_post_message.await_args.args[0]
self.assertEqual(result, "消息已发送")
self.assertEqual(notification.mtype, NotificationType.Agent)
self.assertEqual(notification.mtype, NotificationType.Other)
self.assertEqual(notification.channel, MessageChannel.Telegram)
self.assertEqual(notification.source, "telegram-test")
self.assertEqual(notification.title, "智能体通知")

View File

@@ -204,8 +204,8 @@ class TestAgentInteraction(unittest.TestCase):
self.assertEqual(kwargs["channel"], MessageChannel.Telegram.value)
self.assertEqual(kwargs["source"], "telegram-test")
self.assertNotIn("processing_status", kwargs)
message_put.assert_called_once()
message_add.assert_called_once()
message_put.assert_not_called()
message_add.assert_not_called()
def test_legacy_agent_choice_callback_still_supported(self):
chain = MessageChain()

View File

@@ -1,7 +1,8 @@
from unittest.mock import patch
from unittest.mock import AsyncMock, Mock, patch
from app.chain.message import MessageChain
from app.helper.interaction import media_interaction_manager
from app.core.config import settings
from app.helper.interaction import AgentInteractionOption, agent_interaction_manager, media_interaction_manager
from app.schemas.types import MessageChannel
@@ -38,3 +39,73 @@ def test_explicit_ai_message_bypasses_pending_media_interaction():
handle_ai_message.assert_called_once()
handle_media_interaction.assert_not_called()
def test_explicit_ai_message_is_not_recorded_to_message_history():
"""显式 /ai 消息不登记到数据库或实时消息队列。"""
chain = MessageChain()
with patch.object(settings, "AI_AGENT_ENABLE", True), patch.object(
chain, "_record_user_message"
) as record_user_message, patch(
"app.chain.message.agent_manager.process_message",
new_callable=AsyncMock,
) as process_message, patch(
"app.chain.message.asyncio.run_coroutine_threadsafe",
side_effect=lambda coro, _loop: (coro.close(), Mock())[1],
):
chain.handle_message(
channel=MessageChannel.Telegram,
source="telegram-test",
userid="10001",
username="tester",
text="/ai 帮我检查订阅",
)
record_user_message.assert_not_called()
process_message.assert_called_once()
def test_agent_choice_callback_is_not_recorded_to_message_history():
"""Agent 按钮选择回传不登记到数据库或实时消息队列。"""
chain = MessageChain()
request = agent_interaction_manager.create_request(
session_id="session-choice",
user_id="10001",
channel=MessageChannel.Telegram.value,
source="telegram-test",
username="tester",
title="需要你的选择",
prompt="请选择",
options=[
AgentInteractionOption(label="电影", value="我选择电影"),
AgentInteractionOption(label="电视剧", value="我选择电视剧"),
],
)
try:
with patch.object(settings, "AI_AGENT_ENABLE", True), patch.object(
chain, "_record_user_message"
) as record_user_message, patch.object(
chain, "edit_message", return_value=True
), patch(
"app.chain.message.agent_manager.process_message",
new_callable=AsyncMock,
) as process_message, patch(
"app.chain.message.asyncio.run_coroutine_threadsafe",
side_effect=lambda coro, _loop: (coro.close(), Mock())[1],
):
chain._handle_callback(
text=f"CALLBACK:agent_interaction:choice:{request.request_id}:1",
channel=MessageChannel.Telegram,
source="telegram-test",
userid="10001",
username="tester",
original_message_id=123,
original_chat_id="456",
)
finally:
agent_interaction_manager.clear()
record_user_message.assert_not_called()
process_message.assert_called_once()

View File

@@ -94,7 +94,6 @@ class AgentTokensEventsTest(unittest.IsolatedAsyncioTestCase):
stop_streaming=AsyncMock(return_value=(False, ""))
)
agent.send_agent_message = AsyncMock()
agent._save_agent_message_to_db = AsyncMock()
async def create_agent(_streaming=False, streaming=False):
"""模拟创建 Agent 时完成供应商选择和用量统计。"""

View File

@@ -0,0 +1,169 @@
import json
from unittest.mock import Mock
from app.db import SessionFactory
from app.db.message_oper import MessageOper
from app.db.models.message import Message
from app.chain import ChainBase
from app.helper.message import MessageHelper
from app.schemas import Notification
from app.schemas.types import NotificationType
def _clear_messages() -> None:
"""
清空消息表,隔离通知测试数据。
"""
with SessionFactory() as db:
db.query(Message).delete()
db.commit()
def _reset_message_helper(helper: MessageHelper) -> None:
"""
清空单例消息队列和去重缓存,避免用例间互相影响。
"""
while helper.get() is not None:
pass
helper._recent_notification_keys.clear()
def test_notification_history_only_lists_sent_messages() -> None:
"""
通知历史应返回已发送消息,包含通过消息链登记的智能体消息。
"""
_clear_messages()
oper = MessageOper()
oper.add(title="系统通知", text="下载完成", action=1, mtype=NotificationType.Download)
oper.add(title="用户消息", text="帮我搜索", action=0)
oper.add(title="智能体回复", text="已处理", action=1, mtype=NotificationType.Agent)
messages = MessageOper().list_by_page(page=1, count=10)
assert [message.title for message in messages if message.action == 1] == ["智能体回复", "系统通知"]
def test_web_message_history_returns_all_messages() -> None:
"""
Web 消息历史返回消息表中的全部记录。
"""
_clear_messages()
oper = MessageOper()
oper.add(title="智能体回复", text="已处理", action=1, mtype=NotificationType.Agent)
oper.add(title="用户消息", text="/ai 帮我处理", action=0)
oper.add(title="普通通知", text="下载完成", action=1, mtype=NotificationType.Download)
messages = MessageOper().list_by_page(page=1, count=10)
assert [message.title for message in messages] == ["普通通知", "用户消息", "智能体回复"]
def test_system_helper_message_only_enters_sse_queue() -> None:
"""
系统实时消息只进入前端 SSE 队列,不写入通知历史。
"""
_clear_messages()
helper = MessageHelper()
_reset_message_helper(helper)
helper.put("调度任务执行失败", role="system", title="系统错误")
assert MessageOper().list_by_page(page=1, count=10) == []
realtime_message = json.loads(helper.get())
assert realtime_message["type"] == "system"
assert realtime_message["title"] == "系统错误"
assert realtime_message["text"] == "调度任务执行失败"
def test_plugin_helper_message_deduplicates_recent_sse_messages() -> None:
"""
短时间内相同插件实时消息只应推送一次,不写入通知历史。
"""
_clear_messages()
helper = MessageHelper()
_reset_message_helper(helper)
helper.put("站点刷流任务出错,获取下载器实例失败,请检查配置", role="plugin", title="站点刷流")
helper.put("站点刷流任务出错,获取下载器实例失败,请检查配置", role="plugin", title="站点刷流")
assert MessageOper().list_by_page(page=1, count=10) == []
assert json.loads(helper.get())["title"] == "站点刷流"
assert helper.get() is None
def test_agent_helper_message_does_not_enter_sse_queue() -> None:
"""
智能体消息不进入前端 SSE 队列。
"""
helper = MessageHelper()
_reset_message_helper(helper)
helper.put("智能体回复", role="agent", title="MoviePilot助手")
assert helper.get() is None
def test_user_helper_message_does_not_enter_sse_queue() -> None:
"""
用户消息不进入前端 SSE 队列。
"""
helper = MessageHelper()
_reset_message_helper(helper)
helper.put("用户输入", role="user", title="admin")
assert helper.get() is None
def test_notification_post_message_is_persisted_without_sse_queue() -> None:
"""
业务通知通过消息链发送时只登记数据库,不进入前端 SSE 队列。
"""
_clear_messages()
helper = MessageHelper()
_reset_message_helper(helper)
chain = ChainBase()
chain.messagequeue.send_message = Mock()
chain.eventmanager.send_event = Mock()
chain.post_message(
Notification(
mtype=NotificationType.Download,
title="下载完成",
text="影片已加入下载器",
)
)
messages = MessageOper().list_by_page(page=1, count=10)
assert len(messages) == 1
assert messages[0].title == "下载完成"
assert messages[0].mtype == NotificationType.Download.value
assert helper.get() is None
chain.messagequeue.send_message.assert_called_once()
def test_agent_notification_post_message_is_persisted_without_sse_queue() -> None:
"""
智能体消息通过消息链发送时登记数据库,但不进入前端 SSE 队列。
"""
_clear_messages()
helper = MessageHelper()
_reset_message_helper(helper)
chain = ChainBase()
chain.messagequeue.send_message = Mock()
chain.eventmanager.send_event = Mock()
chain.post_message(
Notification(
mtype=NotificationType.Agent,
title="MoviePilot助手",
text="已完成处理",
)
)
messages = MessageOper().list_by_page(page=1, count=10)
assert len(messages) == 1
assert messages[0].title == "MoviePilot助手"
assert messages[0].mtype == NotificationType.Agent.value
assert helper.get() is None
chain.messagequeue.send_message.assert_called_once()

View File

@@ -122,7 +122,6 @@ class SearchChainAIRecommendTest(unittest.IsolatedAsyncioTestCase):
self.assertEqual("[0, 2]", result)
self.assertEqual(ReplyMode.CAPTURE_ONLY, captured["reply_mode"])
self.assertFalse(captured["persist_output_message"])
self.assertFalse(captured["allow_message_tools"])
def test_search_by_title_clears_previous_recommend_state_when_caching(self):