diff --git a/app/agent/__init__.py b/app/agent/__init__.py index e031b3d0..605a6551 100644 --- a/app/agent/__init__.py +++ b/app/agent/__init__.py @@ -246,7 +246,6 @@ class MoviePilotAgent: original_message_id: Optional[str] = None, original_chat_id: Optional[str] = None, replay_mode: ReplyMode = ReplyMode.DISPATCH, - persist_output_message: bool = True, allow_message_tools: bool = True, output_callback: Optional[Callable[[str], None]] = None, ): @@ -258,7 +257,6 @@ class MoviePilotAgent: self.original_message_id = original_message_id self.original_chat_id = original_chat_id self.reply_mode = replay_mode - self.persist_output_message = persist_output_message self.allow_message_tools = allow_message_tools self.output_callback = output_callback self._tool_context: Dict[str, object] = {} @@ -782,8 +780,6 @@ class MoviePilotAgent: title = "MoviePilot助手" if self.is_background else "" if self.should_dispatch_reply: await self.send_agent_message(message, title=title) - elif self.persist_output_message: - await self._save_agent_message_to_db(message, title=title) def _emit_output(self, text: str): """ @@ -1126,19 +1122,6 @@ class MoviePilotAgent: and not self._tool_context.get("user_reply_sent") ): await self.send_agent_message(remaining_text) - elif ( - remaining_text - and self.persist_output_message - and not self._tool_context.get("user_reply_sent") - ): - title = "MoviePilot助手" if self.is_background else "" - await self._save_agent_message_to_db( - remaining_text, - title=title, - ) - elif streamed_text and self.persist_output_message: - # 流式输出已发送全部内容,但未记录到数据库,补充保存消息记录 - await self._save_agent_message_to_db(streamed_text) else: # 非流式模式:后台任务或渠道不支持消息编辑 @@ -1180,13 +1163,6 @@ class MoviePilotAgent: else: # 非流式渠道:发送最终回复 await self.send_agent_message(final_text) - elif ( - final_text - and self.persist_output_message - and not self._tool_context.get("user_reply_sent") - ): - title = "MoviePilot助手" if self.is_background else "" - await self._save_agent_message_to_db(final_text, title=title) # 保存消息 memory_manager.save_agent_messages( @@ -1238,26 +1214,6 @@ class MoviePilotAgent: ) ) - async def _save_agent_message_to_db(self, message: str, title: str = ""): - """ - 仅保存Agent回复消息到数据库和SSE队列(不重新发送到渠道) - 用于流式输出场景:消息已通过 send_direct_message/edit_message 发送给用户, - 但未记录到数据库中,此方法补充保存消息历史记录。 - """ - chain = AgentChain() - notification = Notification( - channel=self.channel, - source=self.source, - userid=self.user_id, - username=self.username, - title=title, - text=message, - ) - # 保存到SSE消息队列(供前端展示) - chain.messagehelper.put(notification, role="user", title=title) - # 保存到数据库 - await chain.messageoper.async_add(**notification.model_dump()) - async def cleanup(self): """ 清理智能体资源 @@ -1284,7 +1240,6 @@ class _MessageTask: original_chat_id: Optional[str] = None processing_status: Optional[dict] = None reply_mode: ReplyMode = ReplyMode.DISPATCH - persist_output_message: bool = True allow_message_tools: bool = True @@ -1430,7 +1385,6 @@ class AgentManager: original_message_id: Optional[str] = None, original_chat_id: Optional[str] = None, reply_mode: ReplyMode = ReplyMode.DISPATCH, - persist_output_message: bool = True, allow_message_tools: bool = True, ) -> str: """ @@ -1450,7 +1404,6 @@ class AgentManager: original_message_id=original_message_id, original_chat_id=original_chat_id, reply_mode=reply_mode, - persist_output_message=persist_output_message, allow_message_tools=allow_message_tools, ) self._record_session_activity(session_id, user_id) @@ -1561,7 +1514,6 @@ class AgentManager: original_message_id=task.original_message_id, original_chat_id=task.original_chat_id, replay_mode=task.reply_mode, - persist_output_message=task.persist_output_message, allow_message_tools=task.allow_message_tools, ) self.active_agents[session_id] = agent @@ -1577,7 +1529,6 @@ class AgentManager: agent.original_message_id = task.original_message_id agent.original_chat_id = task.original_chat_id agent.reply_mode = task.reply_mode - agent.persist_output_message = task.persist_output_message agent.allow_message_tools = task.allow_message_tools process_kwargs = { @@ -1656,7 +1607,6 @@ class AgentManager: session_prefix: str = "__agent_background", output_callback: Optional[Callable[[str], None]] = None, reply_mode: ReplyMode = ReplyMode.CAPTURE_ONLY, - persist_output_message: bool = True, allow_message_tools: Optional[bool] = None, ) -> None: """ @@ -1677,7 +1627,6 @@ class AgentManager: source=None, username=settings.SUPERUSER, replay_mode=reply_mode, - persist_output_message=persist_output_message, output_callback=output_callback, allow_message_tools=allow_message_tools, ) @@ -1723,7 +1672,6 @@ class AgentManager: source=None, username=settings.SUPERUSER, reply_mode=ReplyMode.CAPTURE_ONLY, - persist_output_message=False, allow_message_tools=True, ) diff --git a/app/agent/tools/impl/send_message.py b/app/agent/tools/impl/send_message.py index 71ae1655..42d5774a 100644 --- a/app/agent/tools/impl/send_message.py +++ b/app/agent/tools/impl/send_message.py @@ -7,6 +7,8 @@ from pydantic import BaseModel, Field, model_validator from app.agent.tools.base import MoviePilotTool from app.agent.tools.tags import ToolTag from app.log import logger +from app.schemas import Notification +from app.schemas.types import NotificationType class SendMessageInput(BaseModel): @@ -77,7 +79,18 @@ class SendMessageTool(MoviePilotTool): f"执行工具: {self.name}, 参数: title={title}, message={text}, image_url={image_url}" ) try: - await self.send_tool_message(text, title=title, image=image_url) + await self.send_notification_message( + Notification( + channel=self._channel, + source=self._source, + mtype=NotificationType.Other, + userid=self._user_id, + username=self._username, + title=title, + text=text, + image=image_url, + ) + ) return "消息已发送" except Exception as e: logger.error(f"发送消息失败: {e}") diff --git a/app/api/endpoints/agent.py b/app/api/endpoints/agent.py index d36a13ec..321033ff 100644 --- a/app/api/endpoints/agent.py +++ b/app/api/endpoints/agent.py @@ -881,7 +881,6 @@ async def web_agent_stream( source=WEB_AGENT_SOURCE, username=current_user.name, replay_mode=ReplyMode.CAPTURE_ONLY, - persist_output_message=False, allow_message_tools=True, output_callback=output_callback, notification_callback=notification_callback, diff --git a/app/api/endpoints/history.py b/app/api/endpoints/history.py index 328f96ee..dbdc2488 100644 --- a/app/api/endpoints/history.py +++ b/app/api/endpoints/history.py @@ -136,7 +136,6 @@ def _start_ai_redo_task(history_id: int, prompt: str, progress_key: str): session_prefix=f"__agent_manual_redo_{history_id}", output_callback=update_output, reply_mode=ReplyMode.CAPTURE_ONLY, - persist_output_message=False, allow_message_tools=False, ) progress.update( @@ -182,7 +181,6 @@ def _start_batch_ai_redo_task( session_prefix="__agent_manual_redo_batch", output_callback=update_output, reply_mode=ReplyMode.CAPTURE_ONLY, - persist_output_message=False, allow_message_tools=False, ) progress.update( diff --git a/app/api/endpoints/message.py b/app/api/endpoints/message.py index c0e29f93..cd68fb5b 100644 --- a/app/api/endpoints/message.py +++ b/app/api/endpoints/message.py @@ -12,7 +12,7 @@ from app.core.config import settings, global_vars from app.core.security import verify_token, verify_apitoken from app.db import get_async_db from app.db.models import User -from app.db.models.message import Message +from app.db.message_oper import MessageOper from app.db.user_oper import get_current_active_superuser from app.helper.service import ServiceConfigHelper from app.helper.webpush import is_webpush_subscription_gone @@ -120,7 +120,7 @@ async def get_web_message( 获取WEB消息列表 """ ret_messages = [] - messages = await Message.async_list_by_page(db, page=page, count=count) + messages = await MessageOper(db).async_list_by_page(page=page, count=count) for message in messages: try: ret_messages.append(message.to_dict()) @@ -130,6 +130,20 @@ async def get_web_message( return ret_messages +@router.get("/notification", summary="获取通知消息", response_model=List[schemas.NotificationHistoryItem]) +async def get_notification_message( + _: schemas.TokenPayload = Depends(verify_token), + db: AsyncSession = Depends(get_async_db), + page: Optional[int] = 1, + count: Optional[int] = 20, +): + """ + 获取系统发送的通知消息列表。 + """ + messages = await MessageOper(db).async_list_sent_by_page(page=page, count=count) + return [schemas.NotificationHistoryItem(**message.to_dict()) for message in messages] + + def wechat_verify( echostr: str, msg_signature: str, diff --git a/app/api/endpoints/openai.py b/app/api/endpoints/openai.py index 6c244fd7..9bda42dc 100644 --- a/app/api/endpoints/openai.py +++ b/app/api/endpoints/openai.py @@ -52,9 +52,6 @@ class _CollectingMoviePilotAgent(MoviePilotAgent): if self.stream_mode: self.stream_handler.emit(text) - async def _save_agent_message_to_db(self, message: str, title: str = ""): - return None - class _OpenAIStreamingHandler(StreamingHandler): """ diff --git a/app/chain/__init__.py b/app/chain/__init__.py index a00ea96d..8d488909 100644 --- a/app/chain/__init__.py +++ b/app/chain/__init__.py @@ -1478,8 +1478,6 @@ class ChainBase(metaclass=ABCMeta): if not message: logger.warning("消息为空,跳过发送") return - # 保存消息 - self.messagehelper.put(message, role="user", title=message.title) self.messageoper.add(**message.model_dump()) dispatch_message = self._normalize_notification_for_dispatch(message) # 发送消息按设置隔离 @@ -1595,8 +1593,6 @@ class ChainBase(metaclass=ABCMeta): if not message: logger.warning("消息为空,跳过发送") return - # 保存消息 - self.messagehelper.put(message, role="user", title=message.title) await self.messageoper.async_add(**message.model_dump()) dispatch_message = self._normalize_notification_for_dispatch(message) # 发送消息按设置隔离 @@ -1688,9 +1684,6 @@ class ChainBase(metaclass=ABCMeta): :return: 成功或失败 """ note_list = [media.to_dict() for media in medias] - self.messagehelper.put( - message, role="user", note=note_list, title=message.title - ) self.messageoper.add(**message.model_dump(), note=note_list) dispatch_message = self._normalize_notification_for_dispatch(message) return self.messagequeue.send_message( @@ -1710,9 +1703,6 @@ class ChainBase(metaclass=ABCMeta): :return: 成功或失败 """ note_list = [torrent.torrent_info.to_dict() for torrent in torrents] - self.messagehelper.put( - message, role="user", note=note_list, title=message.title - ) self.messageoper.add(**message.model_dump(), note=note_list) dispatch_message = self._normalize_notification_for_dispatch(message) return self.messagequeue.send_message( diff --git a/app/chain/message.py b/app/chain/message.py index 506423b5..f924f101 100644 --- a/app/chain/message.py +++ b/app/chain/message.py @@ -197,7 +197,15 @@ class MessageChain(ChainBase): ) return - if not text.startswith("CALLBACK:"): + is_agent_message = self._is_agent_message( + userid=userid, + text=text, + images=images, + files=files, + has_audio_input=has_audio_input, + ) + + if not text.startswith("CALLBACK:") and not is_agent_message: self._record_user_message( channel=channel, source=source, @@ -206,14 +214,7 @@ class MessageChain(ChainBase): text=text, ) - if not self._is_agent_message( - channel=channel, - userid=userid, - text=text, - images=images, - files=files, - has_audio_input=has_audio_input, - ): + if not is_agent_message: processing_status = self._mark_message_processing_started( channel=channel, source=source, @@ -430,7 +431,6 @@ class MessageChain(ChainBase): def _is_agent_message( self, - channel: MessageChannel, userid: Union[str, int], text: str, images: Optional[List[CommingMessage.MessageImage]] = None, @@ -766,13 +766,6 @@ class MessageChain(ChainBase): selected_label=option.label, ) self._bind_session_id(userid, request.session_id) - self._record_user_message( - channel=channel, - source=source, - userid=userid, - username=username, - text=selected_text, - ) return self._handle_ai_message( text=selected_text, channel=channel, @@ -954,7 +947,6 @@ class MessageChain(ChainBase): session_prefix=f"__agent_manual_redo_{history_id}", output_callback=_capture_output, reply_mode=ReplyMode.CAPTURE_ONLY, - persist_output_message=False, allow_message_tools=False, ) await self.async_post_message( diff --git a/app/chain/search.py b/app/chain/search.py index a7553148..6a2d52dc 100644 --- a/app/chain/search.py +++ b/app/chain/search.py @@ -395,7 +395,6 @@ class SearchChain(ChainBase): session_prefix="__agent_search_recommend", output_callback=on_output, reply_mode=ReplyMode.CAPTURE_ONLY, - persist_output_message=False, allow_message_tools=False, ) return full_output[0].strip() diff --git a/app/db/message_oper.py b/app/db/message_oper.py index aff7dd25..29ddb0cb 100644 --- a/app/db/message_oper.py +++ b/app/db/message_oper.py @@ -1,6 +1,7 @@ import time from typing import Optional, Union +from sqlalchemy.ext.asyncio import AsyncSession from sqlalchemy.orm import Session from app.db import DbOper @@ -13,7 +14,7 @@ class MessageOper(DbOper): 消息数据管理 """ - def __init__(self, db: Session = None): + def __init__(self, db: Union[Session, AsyncSession] = None): super().__init__(db) def add(self, @@ -27,7 +28,7 @@ class MessageOper(DbOper): userid: Optional[str] = None, action: Optional[int] = 1, note: Union[list, dict] = None, - **kwargs): + **kwargs) -> dict: """ 新增消息 :param channel: 消息渠道 @@ -60,7 +61,7 @@ class MessageOper(DbOper): if k not in Message.__table__.columns.keys(): # noqa kwargs.pop(k) - Message(**kwargs).create(self._db) + return Message(**kwargs).create_and_to_dict(self._db) async def async_add(self, channel: MessageChannel = None, @@ -73,7 +74,7 @@ class MessageOper(DbOper): userid: Optional[str] = None, action: Optional[int] = 1, note: Union[list, dict] = None, - **kwargs): + **kwargs) -> Message: """ 异步新增消息 """ @@ -96,10 +97,26 @@ class MessageOper(DbOper): if k not in Message.__table__.columns.keys(): # noqa kwargs.pop(k) - await Message(**kwargs).async_create(self._db) + return await Message(**kwargs).async_create(self._db) - def list_by_page(self, page: Optional[int] = 1, count: Optional[int] = 30) -> Optional[str]: + def list_by_page(self, page: Optional[int] = 1, count: Optional[int] = 30) -> list[Message]: """ - 获取媒体服务器数据ID + 分页获取消息记录。 """ return Message.list_by_page(self._db, page, count) + + async def async_list_by_page( + self, page: Optional[int] = 1, count: Optional[int] = 30 + ) -> list[Message]: + """ + 分页获取消息记录。 + """ + return await Message.async_list_by_page(self._db, page, count) + + async def async_list_sent_by_page( + self, page: Optional[int] = 1, count: Optional[int] = 30 + ) -> list[Message]: + """ + 分页获取系统发送的通知消息。 + """ + return await Message.async_list_sent_by_page(self._db, page, count) diff --git a/app/db/models/message.py b/app/db/models/message.py index e27e9675..f13b7380 100644 --- a/app/db/models/message.py +++ b/app/db/models/message.py @@ -1,4 +1,4 @@ -from typing import Optional +from typing import List, Optional from sqlalchemy import Column, Integer, String, JSON, Index, select from sqlalchemy.ext.asyncio import AsyncSession @@ -39,16 +39,59 @@ class Message(Base): Index('ix_message_reg_time_id', 'reg_time', 'id'), ) + @db_update + def create_and_to_dict(self, db: Session) -> dict: + """ + 创建消息记录并返回写入后的字段字典。 + """ + db.add(self) + db.flush() + return self.to_dict() + @classmethod @db_query - def list_by_page(cls, db: Session, page: Optional[int] = 1, count: Optional[int] = 30): - return db.query(cls).order_by(cls.reg_time.desc()).offset((page - 1) * count).limit(count).all() + def list_by_page(cls, db: Session, page: Optional[int] = 1, count: Optional[int] = 30) -> List["Message"]: + """ + 分页获取消息记录。 + """ + return ( + db.query(cls) + .order_by(cls.reg_time.desc(), cls.id.desc()) + .offset((page - 1) * count) + .limit(count) + .all() + ) @classmethod @async_db_query - async def async_list_by_page(cls, db: AsyncSession, page: Optional[int] = 1, count: Optional[int] = 30): + async def async_list_by_page( + cls, db: AsyncSession, page: Optional[int] = 1, count: Optional[int] = 30 + ) -> List["Message"]: + """ + 异步分页获取消息记录。 + """ result = await db.execute( - select(cls).order_by(cls.reg_time.desc()).offset((page - 1) * count).limit(count) + select(cls) + .order_by(cls.reg_time.desc(), cls.id.desc()) + .offset((page - 1) * count) + .limit(count) + ) + return result.scalars().all() + + @classmethod + @async_db_query + async def async_list_sent_by_page( + cls, db: AsyncSession, page: Optional[int] = 1, count: Optional[int] = 30 + ) -> List["Message"]: + """ + 分页获取系统发送的通知消息。 + """ + result = await db.execute( + select(cls) + .where(cls.action == 1) + .order_by(cls.reg_time.desc(), cls.id.desc()) + .offset((page - 1) * count) + .limit(count) ) return result.scalars().all() diff --git a/app/helper/message.py b/app/helper/message.py index c1b11641..1e2fc167 100644 --- a/app/helper/message.py +++ b/app/helper/message.py @@ -764,61 +764,74 @@ class MessageQueueManager(metaclass=SingletonClass): class MessageHelper(metaclass=Singleton): """ - 消息队列管理器,包括系统消息和用户消息 + 消息队列管理器,负责系统和插件实时消息的 SSE 推送 """ def __init__(self): self.sys_queue = queue.Queue() - self.user_queue = queue.Queue() + self._recent_notification_keys = TTLCache(region="message:notification", maxsize=500, ttl=60) + + @staticmethod + def _build_system_notification_key( + message: Any, role: str, title: str = None, note: Union[list, dict] = None + ) -> str: + """ + 构建系统通知短期去重键。 + """ + return json.dumps( + { + "role": role, + "title": title or "", + "text": str(message), + "note": note or {}, + "time": time.strftime("%Y-%m-%d %H:%M", time.localtime()), + }, + ensure_ascii=False, + sort_keys=True, + ) + + def _is_recent_system_notification( + self, message: Any, role: str, title: str = None, note: Union[list, dict] = None + ) -> bool: + """ + 判断系统通知是否在短时间内重复。 + """ + key = self._build_system_notification_key(message, role, title=title, note=note) + if self._recent_notification_keys.get(key): + return True + self._recent_notification_keys.set(key, True) + return False def put(self, message: Any, role: str = "plugin", title: str = None, note: Union[list, dict] = None): """ 存消息 :param message: 消息 - :param role: 消息通道 systm:系统消息,plugin:插件消息,user:用户消息 + :param role: 消息通道 system:系统消息,plugin:插件消息 :param title: 标题 :param note: 附件json """ - if role in ["system", "plugin"]: - # 没有标题时获取插件名称 - if role == "plugin" and not title: - title = "插件通知" - # 系统通知,默认 - self.sys_queue.put(json.dumps({ - "type": role, - "title": title, - "text": message, - "date": time.strftime("%Y-%m-%d %H:%M:%S", time.localtime()), - "note": note - })) - else: - if isinstance(message, str): - # 非系统的文本通知 - self.user_queue.put(json.dumps({ - "title": title, - "text": message, - "date": time.strftime("%Y-%m-%d %H:%M:%S", time.localtime()), - "note": note - })) - elif hasattr(message, "to_dict"): - # 非系统的复杂结构通知,如媒体信息/种子列表等。 - content = message.to_dict() - content['title'] = title - content['date'] = time.strftime("%Y-%m-%d %H:%M:%S", time.localtime()) - content['note'] = note - self.user_queue.put(json.dumps(content)) + if role not in ["system", "plugin"]: + return + # 没有标题时获取插件名称 + if role == "plugin" and not title: + title = "插件通知" + if self._is_recent_system_notification(message, role, title=title, note=note): + return + self.sys_queue.put(json.dumps({ + "type": role, + "title": title, + "text": message, + "date": time.strftime("%Y-%m-%d %H:%M:%S", time.localtime()), + "note": note + })) def get(self, role: str = "system") -> Optional[str]: """ 取消息 - :param role: 消息通道 systm:系统消息,plugin:插件消息,user:用户消息 + :param role: 兼容旧参数,当前所有 SSE 消息共用一个队列 """ - if role == "system": - if not self.sys_queue.empty(): - return self.sys_queue.get(block=False) - else: - if not self.user_queue.empty(): - return self.user_queue.get(block=False) + if not self.sys_queue.empty(): + return self.sys_queue.get(block=False) return None diff --git a/app/schemas/message.py b/app/schemas/message.py index 7da7c3a2..cfe1d348 100644 --- a/app/schemas/message.py +++ b/app/schemas/message.py @@ -26,6 +26,37 @@ class MessageResponse(BaseModel): success: bool = False +class NotificationHistoryItem(BaseModel): + """ + 通知历史记录。 + """ + + # 消息ID + id: Optional[int] = None + # 消息渠道 + channel: Optional[str] = None + # 消息来源 + source: Optional[str] = None + # 消息类型 + mtype: Optional[str] = None + # 标题 + title: Optional[str] = None + # 文本内容 + text: Optional[str] = None + # 图片 + image: Optional[str] = None + # 链接 + link: Optional[str] = None + # 用户ID + userid: Optional[str] = None + # 登记时间 + reg_time: Optional[str] = None + # 消息方向:0-接收消息,1-发送消息 + action: Optional[int] = None + # 附件json + note: Optional[Union[list, dict]] = None + + class CommingMessage(BaseModel): """ 外来消息 diff --git a/tests/test_agent_background_output.py b/tests/test_agent_background_output.py index 949da315..daea15a3 100644 --- a/tests/test_agent_background_output.py +++ b/tests/test_agent_background_output.py @@ -66,7 +66,6 @@ class AgentBackgroundOutputTest(unittest.IsolatedAsyncioTestCase): agent.channel = None agent.source = None agent.reply_mode = ReplyMode.CAPTURE_ONLY - agent.persist_output_message = True agent._tool_context = {"user_reply_sent": False} agent._streamed_output = "" agent.stream_handler = SimpleNamespace( @@ -77,15 +76,11 @@ class AgentBackgroundOutputTest(unittest.IsolatedAsyncioTestCase): return_value=_FakeAgent([AIMessage(content="后台结果")]) ) agent.send_agent_message = AsyncMock() - agent._save_agent_message_to_db = AsyncMock() with patch.object(memory_manager, "save_agent_messages") as save_messages: await agent._execute_agent([]) agent.send_agent_message.assert_not_awaited() - agent._save_agent_message_to_db.assert_awaited_once_with( - "后台结果", title="MoviePilot助手" - ) save_messages.assert_called_once() self.assertEqual("后台结果", agent._streamed_output) @@ -105,7 +100,6 @@ class AgentBackgroundOutputTest(unittest.IsolatedAsyncioTestCase): ) ) agent.send_agent_message = AsyncMock() - agent._save_agent_message_to_db = AsyncMock() result, _ = await agent._execute_agent( [ @@ -122,7 +116,6 @@ class AgentBackgroundOutputTest(unittest.IsolatedAsyncioTestCase): agent.send_agent_message.assert_awaited_once_with( UNSUPPORTED_IMAGE_INPUT_MESSAGE, title="" ) - agent._save_agent_message_to_db.assert_not_awaited() self.assertEqual(UNSUPPORTED_IMAGE_INPUT_MESSAGE, agent._streamed_output) async def test_streaming_image_unsupported_error_sends_friendly_notice(self): @@ -144,7 +137,6 @@ class AgentBackgroundOutputTest(unittest.IsolatedAsyncioTestCase): ) ) agent.send_agent_message = AsyncMock() - agent._save_agent_message_to_db = AsyncMock() result, _ = await agent._execute_agent( [ @@ -161,7 +153,6 @@ class AgentBackgroundOutputTest(unittest.IsolatedAsyncioTestCase): agent.send_agent_message.assert_awaited_once_with( UNSUPPORTED_IMAGE_INPUT_MESSAGE, title="" ) - agent._save_agent_message_to_db.assert_not_awaited() async def test_streaming_model_chunk_timeout_sends_friendly_notice(self): """流式模型分块超时时应只把主错误信息发给用户。""" @@ -186,7 +177,6 @@ class AgentBackgroundOutputTest(unittest.IsolatedAsyncioTestCase): return_value=_FakeStreamingFailingAgent(raw_error) ) agent.send_agent_message = AsyncMock() - agent._save_agent_message_to_db = AsyncMock() result, _ = await agent._execute_agent([HumanMessage(content="测试超时")]) @@ -200,14 +190,12 @@ class AgentBackgroundOutputTest(unittest.IsolatedAsyncioTestCase): self.assertIn("No streaming chunk received for 120.0s", sent_message) self.assertNotIn("Tune or disable", sent_message) self.assertEqual(expected, agent._streamed_output) - agent._save_agent_message_to_db.assert_not_awaited() async def test_background_non_streaming_sends_when_reply_mode_dispatch(self): agent = MoviePilotAgent(session_id="bg-test", user_id="system") agent.channel = None agent.source = None agent.reply_mode = ReplyMode.DISPATCH - agent.persist_output_message = False agent._tool_context = {"user_reply_sent": False} agent._streamed_output = "" agent.stream_handler = SimpleNamespace( @@ -218,7 +206,6 @@ class AgentBackgroundOutputTest(unittest.IsolatedAsyncioTestCase): return_value=_FakeAgent([AIMessage(content="后台结果")]) ) agent.send_agent_message = AsyncMock() - agent._save_agent_message_to_db = AsyncMock() with patch.object(memory_manager, "save_agent_messages") as save_messages: await agent._execute_agent([]) @@ -226,16 +213,14 @@ class AgentBackgroundOutputTest(unittest.IsolatedAsyncioTestCase): agent.send_agent_message.assert_awaited_once_with( "后台结果", title="MoviePilot助手" ) - agent._save_agent_message_to_db.assert_not_awaited() save_messages.assert_called_once() self.assertEqual("后台结果", agent._streamed_output) - async def test_background_non_streaming_persists_without_sending_when_capture_only(self): + async def test_background_non_streaming_captures_without_sending_when_capture_only(self): agent = MoviePilotAgent(session_id="bg-test", user_id="system") agent.channel = None agent.source = None agent.reply_mode = ReplyMode.CAPTURE_ONLY - agent.persist_output_message = True agent._tool_context = {"user_reply_sent": False} agent._streamed_output = "" agent.stream_handler = SimpleNamespace( @@ -246,15 +231,11 @@ class AgentBackgroundOutputTest(unittest.IsolatedAsyncioTestCase): return_value=_FakeAgent([AIMessage(content="后台结果")]) ) agent.send_agent_message = AsyncMock() - agent._save_agent_message_to_db = AsyncMock() with patch.object(memory_manager, "save_agent_messages") as save_messages: await agent._execute_agent([]) agent.send_agent_message.assert_not_awaited() - agent._save_agent_message_to_db.assert_awaited_once_with( - "后台结果", title="MoviePilot助手" - ) save_messages.assert_called_once() self.assertEqual("后台结果", agent._streamed_output) @@ -279,7 +260,6 @@ class AgentBackgroundOutputTest(unittest.IsolatedAsyncioTestCase): process_message.assert_awaited_once() kwargs = process_message.await_args.kwargs self.assertEqual(ReplyMode.CAPTURE_ONLY, kwargs["reply_mode"]) - self.assertFalse(kwargs["persist_output_message"]) self.assertTrue(kwargs["allow_message_tools"]) async def test_heartbeat_check_jobs_skips_when_no_active_jobs(self): diff --git a/tests/test_agent_image_support.py b/tests/test_agent_image_support.py index f972ee9e..66cb042c 100644 --- a/tests/test_agent_image_support.py +++ b/tests/test_agent_image_support.py @@ -569,8 +569,8 @@ class AgentImageSupportTest(unittest.TestCase): self.assertEqual(payload.image_url, "https://example.com/poster.png") - def test_send_message_tool_uses_agent_notification_type(self): - """发送消息工具应固定使用智能体消息类型。""" + def test_send_message_tool_uses_regular_notification_type(self): + """发送消息工具应按普通通知消息登记。""" async def _run(): tool = SendMessageTool(session_id="session-1", user_id="10001") @@ -595,7 +595,7 @@ class AgentImageSupportTest(unittest.TestCase): notification = async_post_message.await_args.args[0] self.assertEqual(result, "消息已发送") - self.assertEqual(notification.mtype, NotificationType.Agent) + self.assertEqual(notification.mtype, NotificationType.Other) self.assertEqual(notification.channel, MessageChannel.Telegram) self.assertEqual(notification.source, "telegram-test") self.assertEqual(notification.title, "智能体通知") diff --git a/tests/test_agent_interaction.py b/tests/test_agent_interaction.py index 3e309c6f..873319f3 100644 --- a/tests/test_agent_interaction.py +++ b/tests/test_agent_interaction.py @@ -204,8 +204,8 @@ class TestAgentInteraction(unittest.TestCase): self.assertEqual(kwargs["channel"], MessageChannel.Telegram.value) self.assertEqual(kwargs["source"], "telegram-test") self.assertNotIn("processing_status", kwargs) - message_put.assert_called_once() - message_add.assert_called_once() + message_put.assert_not_called() + message_add.assert_not_called() def test_legacy_agent_choice_callback_still_supported(self): chain = MessageChain() diff --git a/tests/test_agent_message_routing.py b/tests/test_agent_message_routing.py index b8f81335..378641c2 100644 --- a/tests/test_agent_message_routing.py +++ b/tests/test_agent_message_routing.py @@ -1,7 +1,8 @@ -from unittest.mock import patch +from unittest.mock import AsyncMock, Mock, patch from app.chain.message import MessageChain -from app.helper.interaction import media_interaction_manager +from app.core.config import settings +from app.helper.interaction import AgentInteractionOption, agent_interaction_manager, media_interaction_manager from app.schemas.types import MessageChannel @@ -38,3 +39,73 @@ def test_explicit_ai_message_bypasses_pending_media_interaction(): handle_ai_message.assert_called_once() handle_media_interaction.assert_not_called() + + +def test_explicit_ai_message_is_not_recorded_to_message_history(): + """显式 /ai 消息不登记到数据库或实时消息队列。""" + chain = MessageChain() + + with patch.object(settings, "AI_AGENT_ENABLE", True), patch.object( + chain, "_record_user_message" + ) as record_user_message, patch( + "app.chain.message.agent_manager.process_message", + new_callable=AsyncMock, + ) as process_message, patch( + "app.chain.message.asyncio.run_coroutine_threadsafe", + side_effect=lambda coro, _loop: (coro.close(), Mock())[1], + ): + chain.handle_message( + channel=MessageChannel.Telegram, + source="telegram-test", + userid="10001", + username="tester", + text="/ai 帮我检查订阅", + ) + + record_user_message.assert_not_called() + process_message.assert_called_once() + + +def test_agent_choice_callback_is_not_recorded_to_message_history(): + """Agent 按钮选择回传不登记到数据库或实时消息队列。""" + chain = MessageChain() + request = agent_interaction_manager.create_request( + session_id="session-choice", + user_id="10001", + channel=MessageChannel.Telegram.value, + source="telegram-test", + username="tester", + title="需要你的选择", + prompt="请选择", + options=[ + AgentInteractionOption(label="电影", value="我选择电影"), + AgentInteractionOption(label="电视剧", value="我选择电视剧"), + ], + ) + + try: + with patch.object(settings, "AI_AGENT_ENABLE", True), patch.object( + chain, "_record_user_message" + ) as record_user_message, patch.object( + chain, "edit_message", return_value=True + ), patch( + "app.chain.message.agent_manager.process_message", + new_callable=AsyncMock, + ) as process_message, patch( + "app.chain.message.asyncio.run_coroutine_threadsafe", + side_effect=lambda coro, _loop: (coro.close(), Mock())[1], + ): + chain._handle_callback( + text=f"CALLBACK:agent_interaction:choice:{request.request_id}:1", + channel=MessageChannel.Telegram, + source="telegram-test", + userid="10001", + username="tester", + original_message_id=123, + original_chat_id="456", + ) + finally: + agent_interaction_manager.clear() + + record_user_message.assert_not_called() + process_message.assert_called_once() diff --git a/tests/test_agent_tokens_events.py b/tests/test_agent_tokens_events.py index 6462def4..fcc28d90 100644 --- a/tests/test_agent_tokens_events.py +++ b/tests/test_agent_tokens_events.py @@ -94,7 +94,6 @@ class AgentTokensEventsTest(unittest.IsolatedAsyncioTestCase): stop_streaming=AsyncMock(return_value=(False, "")) ) agent.send_agent_message = AsyncMock() - agent._save_agent_message_to_db = AsyncMock() async def create_agent(_streaming=False, streaming=False): """模拟创建 Agent 时完成供应商选择和用量统计。""" diff --git a/tests/test_message_notifications.py b/tests/test_message_notifications.py new file mode 100644 index 00000000..01d596ea --- /dev/null +++ b/tests/test_message_notifications.py @@ -0,0 +1,169 @@ +import json +from unittest.mock import Mock + +from app.db import SessionFactory +from app.db.message_oper import MessageOper +from app.db.models.message import Message +from app.chain import ChainBase +from app.helper.message import MessageHelper +from app.schemas import Notification +from app.schemas.types import NotificationType + + +def _clear_messages() -> None: + """ + 清空消息表,隔离通知测试数据。 + """ + with SessionFactory() as db: + db.query(Message).delete() + db.commit() + + +def _reset_message_helper(helper: MessageHelper) -> None: + """ + 清空单例消息队列和去重缓存,避免用例间互相影响。 + """ + while helper.get() is not None: + pass + helper._recent_notification_keys.clear() + + +def test_notification_history_only_lists_sent_messages() -> None: + """ + 通知历史应返回已发送消息,包含通过消息链登记的智能体消息。 + """ + _clear_messages() + oper = MessageOper() + oper.add(title="系统通知", text="下载完成", action=1, mtype=NotificationType.Download) + oper.add(title="用户消息", text="帮我搜索", action=0) + oper.add(title="智能体回复", text="已处理", action=1, mtype=NotificationType.Agent) + + messages = MessageOper().list_by_page(page=1, count=10) + assert [message.title for message in messages if message.action == 1] == ["智能体回复", "系统通知"] + + +def test_web_message_history_returns_all_messages() -> None: + """ + Web 消息历史返回消息表中的全部记录。 + """ + _clear_messages() + oper = MessageOper() + oper.add(title="智能体回复", text="已处理", action=1, mtype=NotificationType.Agent) + oper.add(title="用户消息", text="/ai 帮我处理", action=0) + oper.add(title="普通通知", text="下载完成", action=1, mtype=NotificationType.Download) + + messages = MessageOper().list_by_page(page=1, count=10) + assert [message.title for message in messages] == ["普通通知", "用户消息", "智能体回复"] + + +def test_system_helper_message_only_enters_sse_queue() -> None: + """ + 系统实时消息只进入前端 SSE 队列,不写入通知历史。 + """ + _clear_messages() + helper = MessageHelper() + _reset_message_helper(helper) + + helper.put("调度任务执行失败", role="system", title="系统错误") + + assert MessageOper().list_by_page(page=1, count=10) == [] + realtime_message = json.loads(helper.get()) + assert realtime_message["type"] == "system" + assert realtime_message["title"] == "系统错误" + assert realtime_message["text"] == "调度任务执行失败" + + +def test_plugin_helper_message_deduplicates_recent_sse_messages() -> None: + """ + 短时间内相同插件实时消息只应推送一次,不写入通知历史。 + """ + _clear_messages() + helper = MessageHelper() + _reset_message_helper(helper) + + helper.put("站点刷流任务出错,获取下载器实例失败,请检查配置", role="plugin", title="站点刷流") + helper.put("站点刷流任务出错,获取下载器实例失败,请检查配置", role="plugin", title="站点刷流") + + assert MessageOper().list_by_page(page=1, count=10) == [] + assert json.loads(helper.get())["title"] == "站点刷流" + assert helper.get() is None + + +def test_agent_helper_message_does_not_enter_sse_queue() -> None: + """ + 智能体消息不进入前端 SSE 队列。 + """ + helper = MessageHelper() + _reset_message_helper(helper) + + helper.put("智能体回复", role="agent", title="MoviePilot助手") + + assert helper.get() is None + + +def test_user_helper_message_does_not_enter_sse_queue() -> None: + """ + 用户消息不进入前端 SSE 队列。 + """ + helper = MessageHelper() + _reset_message_helper(helper) + + helper.put("用户输入", role="user", title="admin") + + assert helper.get() is None + + +def test_notification_post_message_is_persisted_without_sse_queue() -> None: + """ + 业务通知通过消息链发送时只登记数据库,不进入前端 SSE 队列。 + """ + _clear_messages() + helper = MessageHelper() + _reset_message_helper(helper) + chain = ChainBase() + + chain.messagequeue.send_message = Mock() + chain.eventmanager.send_event = Mock() + + chain.post_message( + Notification( + mtype=NotificationType.Download, + title="下载完成", + text="影片已加入下载器", + ) + ) + + messages = MessageOper().list_by_page(page=1, count=10) + assert len(messages) == 1 + assert messages[0].title == "下载完成" + assert messages[0].mtype == NotificationType.Download.value + assert helper.get() is None + chain.messagequeue.send_message.assert_called_once() + + +def test_agent_notification_post_message_is_persisted_without_sse_queue() -> None: + """ + 智能体消息通过消息链发送时登记数据库,但不进入前端 SSE 队列。 + """ + _clear_messages() + helper = MessageHelper() + _reset_message_helper(helper) + chain = ChainBase() + + chain.messagequeue.send_message = Mock() + chain.eventmanager.send_event = Mock() + + chain.post_message( + Notification( + mtype=NotificationType.Agent, + title="MoviePilot助手", + text="已完成处理", + ) + ) + + messages = MessageOper().list_by_page(page=1, count=10) + assert len(messages) == 1 + assert messages[0].title == "MoviePilot助手" + assert messages[0].mtype == NotificationType.Agent.value + assert helper.get() is None + chain.messagequeue.send_message.assert_called_once() diff --git a/tests/test_search_ai_recommend.py b/tests/test_search_ai_recommend.py index 0ea642d9..2550a815 100644 --- a/tests/test_search_ai_recommend.py +++ b/tests/test_search_ai_recommend.py @@ -122,7 +122,6 @@ class SearchChainAIRecommendTest(unittest.IsolatedAsyncioTestCase): self.assertEqual("[0, 2]", result) self.assertEqual(ReplyMode.CAPTURE_ONLY, captured["reply_mode"]) - self.assertFalse(captured["persist_output_message"]) self.assertFalse(captured["allow_message_tools"]) def test_search_by_title_clears_previous_recommend_state_when_caching(self):