refactor: remove persist_output_message functionality and related database save logic

This commit is contained in:
jxxghp
2026-06-17 16:14:35 +08:00
parent 039558d240
commit d4b6d3f332
20 changed files with 442 additions and 170 deletions

View File

@@ -246,7 +246,6 @@ class MoviePilotAgent:
original_message_id: Optional[str] = None,
original_chat_id: Optional[str] = None,
replay_mode: ReplyMode = ReplyMode.DISPATCH,
persist_output_message: bool = True,
allow_message_tools: bool = True,
output_callback: Optional[Callable[[str], None]] = None,
):
@@ -258,7 +257,6 @@ class MoviePilotAgent:
self.original_message_id = original_message_id
self.original_chat_id = original_chat_id
self.reply_mode = replay_mode
self.persist_output_message = persist_output_message
self.allow_message_tools = allow_message_tools
self.output_callback = output_callback
self._tool_context: Dict[str, object] = {}
@@ -782,8 +780,6 @@ class MoviePilotAgent:
title = "MoviePilot助手" if self.is_background else ""
if self.should_dispatch_reply:
await self.send_agent_message(message, title=title)
elif self.persist_output_message:
await self._save_agent_message_to_db(message, title=title)
def _emit_output(self, text: str):
"""
@@ -1126,19 +1122,6 @@ class MoviePilotAgent:
and not self._tool_context.get("user_reply_sent")
):
await self.send_agent_message(remaining_text)
elif (
remaining_text
and self.persist_output_message
and not self._tool_context.get("user_reply_sent")
):
title = "MoviePilot助手" if self.is_background else ""
await self._save_agent_message_to_db(
remaining_text,
title=title,
)
elif streamed_text and self.persist_output_message:
# 流式输出已发送全部内容,但未记录到数据库,补充保存消息记录
await self._save_agent_message_to_db(streamed_text)
else:
# 非流式模式:后台任务或渠道不支持消息编辑
@@ -1180,13 +1163,6 @@ class MoviePilotAgent:
else:
# 非流式渠道:发送最终回复
await self.send_agent_message(final_text)
elif (
final_text
and self.persist_output_message
and not self._tool_context.get("user_reply_sent")
):
title = "MoviePilot助手" if self.is_background else ""
await self._save_agent_message_to_db(final_text, title=title)
# 保存消息
memory_manager.save_agent_messages(
@@ -1238,26 +1214,6 @@ class MoviePilotAgent:
)
)
async def _save_agent_message_to_db(self, message: str, title: str = ""):
"""
仅保存Agent回复消息到数据库和SSE队列不重新发送到渠道
用于流式输出场景:消息已通过 send_direct_message/edit_message 发送给用户,
但未记录到数据库中,此方法补充保存消息历史记录。
"""
chain = AgentChain()
notification = Notification(
channel=self.channel,
source=self.source,
userid=self.user_id,
username=self.username,
title=title,
text=message,
)
# 保存到SSE消息队列供前端展示
chain.messagehelper.put(notification, role="user", title=title)
# 保存到数据库
await chain.messageoper.async_add(**notification.model_dump())
async def cleanup(self):
"""
清理智能体资源
@@ -1284,7 +1240,6 @@ class _MessageTask:
original_chat_id: Optional[str] = None
processing_status: Optional[dict] = None
reply_mode: ReplyMode = ReplyMode.DISPATCH
persist_output_message: bool = True
allow_message_tools: bool = True
@@ -1430,7 +1385,6 @@ class AgentManager:
original_message_id: Optional[str] = None,
original_chat_id: Optional[str] = None,
reply_mode: ReplyMode = ReplyMode.DISPATCH,
persist_output_message: bool = True,
allow_message_tools: bool = True,
) -> str:
"""
@@ -1450,7 +1404,6 @@ class AgentManager:
original_message_id=original_message_id,
original_chat_id=original_chat_id,
reply_mode=reply_mode,
persist_output_message=persist_output_message,
allow_message_tools=allow_message_tools,
)
self._record_session_activity(session_id, user_id)
@@ -1561,7 +1514,6 @@ class AgentManager:
original_message_id=task.original_message_id,
original_chat_id=task.original_chat_id,
replay_mode=task.reply_mode,
persist_output_message=task.persist_output_message,
allow_message_tools=task.allow_message_tools,
)
self.active_agents[session_id] = agent
@@ -1577,7 +1529,6 @@ class AgentManager:
agent.original_message_id = task.original_message_id
agent.original_chat_id = task.original_chat_id
agent.reply_mode = task.reply_mode
agent.persist_output_message = task.persist_output_message
agent.allow_message_tools = task.allow_message_tools
process_kwargs = {
@@ -1656,7 +1607,6 @@ class AgentManager:
session_prefix: str = "__agent_background",
output_callback: Optional[Callable[[str], None]] = None,
reply_mode: ReplyMode = ReplyMode.CAPTURE_ONLY,
persist_output_message: bool = True,
allow_message_tools: Optional[bool] = None,
) -> None:
"""
@@ -1677,7 +1627,6 @@ class AgentManager:
source=None,
username=settings.SUPERUSER,
replay_mode=reply_mode,
persist_output_message=persist_output_message,
output_callback=output_callback,
allow_message_tools=allow_message_tools,
)
@@ -1723,7 +1672,6 @@ class AgentManager:
source=None,
username=settings.SUPERUSER,
reply_mode=ReplyMode.CAPTURE_ONLY,
persist_output_message=False,
allow_message_tools=True,
)

View File

@@ -7,6 +7,8 @@ from pydantic import BaseModel, Field, model_validator
from app.agent.tools.base import MoviePilotTool
from app.agent.tools.tags import ToolTag
from app.log import logger
from app.schemas import Notification
from app.schemas.types import NotificationType
class SendMessageInput(BaseModel):
@@ -77,7 +79,18 @@ class SendMessageTool(MoviePilotTool):
f"执行工具: {self.name}, 参数: title={title}, message={text}, image_url={image_url}"
)
try:
await self.send_tool_message(text, title=title, image=image_url)
await self.send_notification_message(
Notification(
channel=self._channel,
source=self._source,
mtype=NotificationType.Other,
userid=self._user_id,
username=self._username,
title=title,
text=text,
image=image_url,
)
)
return "消息已发送"
except Exception as e:
logger.error(f"发送消息失败: {e}")

View File

@@ -881,7 +881,6 @@ async def web_agent_stream(
source=WEB_AGENT_SOURCE,
username=current_user.name,
replay_mode=ReplyMode.CAPTURE_ONLY,
persist_output_message=False,
allow_message_tools=True,
output_callback=output_callback,
notification_callback=notification_callback,

View File

@@ -136,7 +136,6 @@ def _start_ai_redo_task(history_id: int, prompt: str, progress_key: str):
session_prefix=f"__agent_manual_redo_{history_id}",
output_callback=update_output,
reply_mode=ReplyMode.CAPTURE_ONLY,
persist_output_message=False,
allow_message_tools=False,
)
progress.update(
@@ -182,7 +181,6 @@ def _start_batch_ai_redo_task(
session_prefix="__agent_manual_redo_batch",
output_callback=update_output,
reply_mode=ReplyMode.CAPTURE_ONLY,
persist_output_message=False,
allow_message_tools=False,
)
progress.update(

View File

@@ -12,7 +12,7 @@ from app.core.config import settings, global_vars
from app.core.security import verify_token, verify_apitoken
from app.db import get_async_db
from app.db.models import User
from app.db.models.message import Message
from app.db.message_oper import MessageOper
from app.db.user_oper import get_current_active_superuser
from app.helper.service import ServiceConfigHelper
from app.helper.webpush import is_webpush_subscription_gone
@@ -120,7 +120,7 @@ async def get_web_message(
获取WEB消息列表
"""
ret_messages = []
messages = await Message.async_list_by_page(db, page=page, count=count)
messages = await MessageOper(db).async_list_by_page(page=page, count=count)
for message in messages:
try:
ret_messages.append(message.to_dict())
@@ -130,6 +130,20 @@ async def get_web_message(
return ret_messages
@router.get("/notification", summary="获取通知消息", response_model=List[schemas.NotificationHistoryItem])
async def get_notification_message(
_: schemas.TokenPayload = Depends(verify_token),
db: AsyncSession = Depends(get_async_db),
page: Optional[int] = 1,
count: Optional[int] = 20,
):
"""
获取系统发送的通知消息列表。
"""
messages = await MessageOper(db).async_list_sent_by_page(page=page, count=count)
return [schemas.NotificationHistoryItem(**message.to_dict()) for message in messages]
def wechat_verify(
echostr: str,
msg_signature: str,

View File

@@ -52,9 +52,6 @@ class _CollectingMoviePilotAgent(MoviePilotAgent):
if self.stream_mode:
self.stream_handler.emit(text)
async def _save_agent_message_to_db(self, message: str, title: str = ""):
return None
class _OpenAIStreamingHandler(StreamingHandler):
"""

View File

@@ -1478,8 +1478,6 @@ class ChainBase(metaclass=ABCMeta):
if not message:
logger.warning("消息为空,跳过发送")
return
# 保存消息
self.messagehelper.put(message, role="user", title=message.title)
self.messageoper.add(**message.model_dump())
dispatch_message = self._normalize_notification_for_dispatch(message)
# 发送消息按设置隔离
@@ -1595,8 +1593,6 @@ class ChainBase(metaclass=ABCMeta):
if not message:
logger.warning("消息为空,跳过发送")
return
# 保存消息
self.messagehelper.put(message, role="user", title=message.title)
await self.messageoper.async_add(**message.model_dump())
dispatch_message = self._normalize_notification_for_dispatch(message)
# 发送消息按设置隔离
@@ -1688,9 +1684,6 @@ class ChainBase(metaclass=ABCMeta):
:return: 成功或失败
"""
note_list = [media.to_dict() for media in medias]
self.messagehelper.put(
message, role="user", note=note_list, title=message.title
)
self.messageoper.add(**message.model_dump(), note=note_list)
dispatch_message = self._normalize_notification_for_dispatch(message)
return self.messagequeue.send_message(
@@ -1710,9 +1703,6 @@ class ChainBase(metaclass=ABCMeta):
:return: 成功或失败
"""
note_list = [torrent.torrent_info.to_dict() for torrent in torrents]
self.messagehelper.put(
message, role="user", note=note_list, title=message.title
)
self.messageoper.add(**message.model_dump(), note=note_list)
dispatch_message = self._normalize_notification_for_dispatch(message)
return self.messagequeue.send_message(

View File

@@ -197,7 +197,15 @@ class MessageChain(ChainBase):
)
return
if not text.startswith("CALLBACK:"):
is_agent_message = self._is_agent_message(
userid=userid,
text=text,
images=images,
files=files,
has_audio_input=has_audio_input,
)
if not text.startswith("CALLBACK:") and not is_agent_message:
self._record_user_message(
channel=channel,
source=source,
@@ -206,14 +214,7 @@ class MessageChain(ChainBase):
text=text,
)
if not self._is_agent_message(
channel=channel,
userid=userid,
text=text,
images=images,
files=files,
has_audio_input=has_audio_input,
):
if not is_agent_message:
processing_status = self._mark_message_processing_started(
channel=channel,
source=source,
@@ -430,7 +431,6 @@ class MessageChain(ChainBase):
def _is_agent_message(
self,
channel: MessageChannel,
userid: Union[str, int],
text: str,
images: Optional[List[CommingMessage.MessageImage]] = None,
@@ -766,13 +766,6 @@ class MessageChain(ChainBase):
selected_label=option.label,
)
self._bind_session_id(userid, request.session_id)
self._record_user_message(
channel=channel,
source=source,
userid=userid,
username=username,
text=selected_text,
)
return self._handle_ai_message(
text=selected_text,
channel=channel,
@@ -954,7 +947,6 @@ class MessageChain(ChainBase):
session_prefix=f"__agent_manual_redo_{history_id}",
output_callback=_capture_output,
reply_mode=ReplyMode.CAPTURE_ONLY,
persist_output_message=False,
allow_message_tools=False,
)
await self.async_post_message(

View File

@@ -395,7 +395,6 @@ class SearchChain(ChainBase):
session_prefix="__agent_search_recommend",
output_callback=on_output,
reply_mode=ReplyMode.CAPTURE_ONLY,
persist_output_message=False,
allow_message_tools=False,
)
return full_output[0].strip()

View File

@@ -1,6 +1,7 @@
import time
from typing import Optional, Union
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.orm import Session
from app.db import DbOper
@@ -13,7 +14,7 @@ class MessageOper(DbOper):
消息数据管理
"""
def __init__(self, db: Session = None):
def __init__(self, db: Union[Session, AsyncSession] = None):
super().__init__(db)
def add(self,
@@ -27,7 +28,7 @@ class MessageOper(DbOper):
userid: Optional[str] = None,
action: Optional[int] = 1,
note: Union[list, dict] = None,
**kwargs):
**kwargs) -> dict:
"""
新增消息
:param channel: 消息渠道
@@ -60,7 +61,7 @@ class MessageOper(DbOper):
if k not in Message.__table__.columns.keys(): # noqa
kwargs.pop(k)
Message(**kwargs).create(self._db)
return Message(**kwargs).create_and_to_dict(self._db)
async def async_add(self,
channel: MessageChannel = None,
@@ -73,7 +74,7 @@ class MessageOper(DbOper):
userid: Optional[str] = None,
action: Optional[int] = 1,
note: Union[list, dict] = None,
**kwargs):
**kwargs) -> Message:
"""
异步新增消息
"""
@@ -96,10 +97,26 @@ class MessageOper(DbOper):
if k not in Message.__table__.columns.keys(): # noqa
kwargs.pop(k)
await Message(**kwargs).async_create(self._db)
return await Message(**kwargs).async_create(self._db)
def list_by_page(self, page: Optional[int] = 1, count: Optional[int] = 30) -> Optional[str]:
def list_by_page(self, page: Optional[int] = 1, count: Optional[int] = 30) -> list[Message]:
"""
获取媒体服务器数据ID
分页获取消息记录。
"""
return Message.list_by_page(self._db, page, count)
async def async_list_by_page(
self, page: Optional[int] = 1, count: Optional[int] = 30
) -> list[Message]:
"""
分页获取消息记录。
"""
return await Message.async_list_by_page(self._db, page, count)
async def async_list_sent_by_page(
self, page: Optional[int] = 1, count: Optional[int] = 30
) -> list[Message]:
"""
分页获取系统发送的通知消息。
"""
return await Message.async_list_sent_by_page(self._db, page, count)

View File

@@ -1,4 +1,4 @@
from typing import Optional
from typing import List, Optional
from sqlalchemy import Column, Integer, String, JSON, Index, select
from sqlalchemy.ext.asyncio import AsyncSession
@@ -39,16 +39,59 @@ class Message(Base):
Index('ix_message_reg_time_id', 'reg_time', 'id'),
)
@db_update
def create_and_to_dict(self, db: Session) -> dict:
"""
创建消息记录并返回写入后的字段字典。
"""
db.add(self)
db.flush()
return self.to_dict()
@classmethod
@db_query
def list_by_page(cls, db: Session, page: Optional[int] = 1, count: Optional[int] = 30):
return db.query(cls).order_by(cls.reg_time.desc()).offset((page - 1) * count).limit(count).all()
def list_by_page(cls, db: Session, page: Optional[int] = 1, count: Optional[int] = 30) -> List["Message"]:
"""
分页获取消息记录。
"""
return (
db.query(cls)
.order_by(cls.reg_time.desc(), cls.id.desc())
.offset((page - 1) * count)
.limit(count)
.all()
)
@classmethod
@async_db_query
async def async_list_by_page(cls, db: AsyncSession, page: Optional[int] = 1, count: Optional[int] = 30):
async def async_list_by_page(
cls, db: AsyncSession, page: Optional[int] = 1, count: Optional[int] = 30
) -> List["Message"]:
"""
异步分页获取消息记录。
"""
result = await db.execute(
select(cls).order_by(cls.reg_time.desc()).offset((page - 1) * count).limit(count)
select(cls)
.order_by(cls.reg_time.desc(), cls.id.desc())
.offset((page - 1) * count)
.limit(count)
)
return result.scalars().all()
@classmethod
@async_db_query
async def async_list_sent_by_page(
cls, db: AsyncSession, page: Optional[int] = 1, count: Optional[int] = 30
) -> List["Message"]:
"""
分页获取系统发送的通知消息。
"""
result = await db.execute(
select(cls)
.where(cls.action == 1)
.order_by(cls.reg_time.desc(), cls.id.desc())
.offset((page - 1) * count)
.limit(count)
)
return result.scalars().all()

View File

@@ -764,61 +764,74 @@ class MessageQueueManager(metaclass=SingletonClass):
class MessageHelper(metaclass=Singleton):
"""
消息队列管理器,包括系统消息和用户消息
消息队列管理器,负责系统和插件实时消息的 SSE 推送
"""
def __init__(self):
self.sys_queue = queue.Queue()
self.user_queue = queue.Queue()
self._recent_notification_keys = TTLCache(region="message:notification", maxsize=500, ttl=60)
@staticmethod
def _build_system_notification_key(
message: Any, role: str, title: str = None, note: Union[list, dict] = None
) -> str:
"""
构建系统通知短期去重键。
"""
return json.dumps(
{
"role": role,
"title": title or "",
"text": str(message),
"note": note or {},
"time": time.strftime("%Y-%m-%d %H:%M", time.localtime()),
},
ensure_ascii=False,
sort_keys=True,
)
def _is_recent_system_notification(
self, message: Any, role: str, title: str = None, note: Union[list, dict] = None
) -> bool:
"""
判断系统通知是否在短时间内重复。
"""
key = self._build_system_notification_key(message, role, title=title, note=note)
if self._recent_notification_keys.get(key):
return True
self._recent_notification_keys.set(key, True)
return False
def put(self, message: Any, role: str = "plugin", title: str = None, note: Union[list, dict] = None):
"""
存消息
:param message: 消息
:param role: 消息通道 systm系统消息plugin插件消息user用户消息
:param role: 消息通道 system系统消息plugin插件消息
:param title: 标题
:param note: 附件json
"""
if role in ["system", "plugin"]:
# 没有标题时获取插件名称
if role == "plugin" and not title:
title = "插件通知"
# 系统通知,默认
self.sys_queue.put(json.dumps({
"type": role,
"title": title,
"text": message,
"date": time.strftime("%Y-%m-%d %H:%M:%S", time.localtime()),
"note": note
}))
else:
if isinstance(message, str):
# 非系统的文本通知
self.user_queue.put(json.dumps({
"title": title,
"text": message,
"date": time.strftime("%Y-%m-%d %H:%M:%S", time.localtime()),
"note": note
}))
elif hasattr(message, "to_dict"):
# 非系统的复杂结构通知,如媒体信息/种子列表等。
content = message.to_dict()
content['title'] = title
content['date'] = time.strftime("%Y-%m-%d %H:%M:%S", time.localtime())
content['note'] = note
self.user_queue.put(json.dumps(content))
if role not in ["system", "plugin"]:
return
# 没有标题时获取插件名称
if role == "plugin" and not title:
title = "插件通知"
if self._is_recent_system_notification(message, role, title=title, note=note):
return
self.sys_queue.put(json.dumps({
"type": role,
"title": title,
"text": message,
"date": time.strftime("%Y-%m-%d %H:%M:%S", time.localtime()),
"note": note
}))
def get(self, role: str = "system") -> Optional[str]:
"""
取消息
:param role: 消息通道 systm系统消息plugin插件消息user用户消息
:param role: 兼容旧参数,当前所有 SSE 消息共用一个队列
"""
if role == "system":
if not self.sys_queue.empty():
return self.sys_queue.get(block=False)
else:
if not self.user_queue.empty():
return self.user_queue.get(block=False)
if not self.sys_queue.empty():
return self.sys_queue.get(block=False)
return None

View File

@@ -26,6 +26,37 @@ class MessageResponse(BaseModel):
success: bool = False
class NotificationHistoryItem(BaseModel):
"""
通知历史记录。
"""
# 消息ID
id: Optional[int] = None
# 消息渠道
channel: Optional[str] = None
# 消息来源
source: Optional[str] = None
# 消息类型
mtype: Optional[str] = None
# 标题
title: Optional[str] = None
# 文本内容
text: Optional[str] = None
# 图片
image: Optional[str] = None
# 链接
link: Optional[str] = None
# 用户ID
userid: Optional[str] = None
# 登记时间
reg_time: Optional[str] = None
# 消息方向0-接收消息1-发送消息
action: Optional[int] = None
# 附件json
note: Optional[Union[list, dict]] = None
class CommingMessage(BaseModel):
"""
外来消息