mirror of
https://github.com/jxxghp/MoviePilot.git
synced 2026-06-28 11:12:00 +08:00
refactor: remove persist_output_message functionality and related database save logic
This commit is contained in:
@@ -66,7 +66,6 @@ class AgentBackgroundOutputTest(unittest.IsolatedAsyncioTestCase):
|
||||
agent.channel = None
|
||||
agent.source = None
|
||||
agent.reply_mode = ReplyMode.CAPTURE_ONLY
|
||||
agent.persist_output_message = True
|
||||
agent._tool_context = {"user_reply_sent": False}
|
||||
agent._streamed_output = ""
|
||||
agent.stream_handler = SimpleNamespace(
|
||||
@@ -77,15 +76,11 @@ class AgentBackgroundOutputTest(unittest.IsolatedAsyncioTestCase):
|
||||
return_value=_FakeAgent([AIMessage(content="后台结果")])
|
||||
)
|
||||
agent.send_agent_message = AsyncMock()
|
||||
agent._save_agent_message_to_db = AsyncMock()
|
||||
|
||||
with patch.object(memory_manager, "save_agent_messages") as save_messages:
|
||||
await agent._execute_agent([])
|
||||
|
||||
agent.send_agent_message.assert_not_awaited()
|
||||
agent._save_agent_message_to_db.assert_awaited_once_with(
|
||||
"后台结果", title="MoviePilot助手"
|
||||
)
|
||||
save_messages.assert_called_once()
|
||||
self.assertEqual("后台结果", agent._streamed_output)
|
||||
|
||||
@@ -105,7 +100,6 @@ class AgentBackgroundOutputTest(unittest.IsolatedAsyncioTestCase):
|
||||
)
|
||||
)
|
||||
agent.send_agent_message = AsyncMock()
|
||||
agent._save_agent_message_to_db = AsyncMock()
|
||||
|
||||
result, _ = await agent._execute_agent(
|
||||
[
|
||||
@@ -122,7 +116,6 @@ class AgentBackgroundOutputTest(unittest.IsolatedAsyncioTestCase):
|
||||
agent.send_agent_message.assert_awaited_once_with(
|
||||
UNSUPPORTED_IMAGE_INPUT_MESSAGE, title=""
|
||||
)
|
||||
agent._save_agent_message_to_db.assert_not_awaited()
|
||||
self.assertEqual(UNSUPPORTED_IMAGE_INPUT_MESSAGE, agent._streamed_output)
|
||||
|
||||
async def test_streaming_image_unsupported_error_sends_friendly_notice(self):
|
||||
@@ -144,7 +137,6 @@ class AgentBackgroundOutputTest(unittest.IsolatedAsyncioTestCase):
|
||||
)
|
||||
)
|
||||
agent.send_agent_message = AsyncMock()
|
||||
agent._save_agent_message_to_db = AsyncMock()
|
||||
|
||||
result, _ = await agent._execute_agent(
|
||||
[
|
||||
@@ -161,7 +153,6 @@ class AgentBackgroundOutputTest(unittest.IsolatedAsyncioTestCase):
|
||||
agent.send_agent_message.assert_awaited_once_with(
|
||||
UNSUPPORTED_IMAGE_INPUT_MESSAGE, title=""
|
||||
)
|
||||
agent._save_agent_message_to_db.assert_not_awaited()
|
||||
|
||||
async def test_streaming_model_chunk_timeout_sends_friendly_notice(self):
|
||||
"""流式模型分块超时时应只把主错误信息发给用户。"""
|
||||
@@ -186,7 +177,6 @@ class AgentBackgroundOutputTest(unittest.IsolatedAsyncioTestCase):
|
||||
return_value=_FakeStreamingFailingAgent(raw_error)
|
||||
)
|
||||
agent.send_agent_message = AsyncMock()
|
||||
agent._save_agent_message_to_db = AsyncMock()
|
||||
|
||||
result, _ = await agent._execute_agent([HumanMessage(content="测试超时")])
|
||||
|
||||
@@ -200,14 +190,12 @@ class AgentBackgroundOutputTest(unittest.IsolatedAsyncioTestCase):
|
||||
self.assertIn("No streaming chunk received for 120.0s", sent_message)
|
||||
self.assertNotIn("Tune or disable", sent_message)
|
||||
self.assertEqual(expected, agent._streamed_output)
|
||||
agent._save_agent_message_to_db.assert_not_awaited()
|
||||
|
||||
async def test_background_non_streaming_sends_when_reply_mode_dispatch(self):
|
||||
agent = MoviePilotAgent(session_id="bg-test", user_id="system")
|
||||
agent.channel = None
|
||||
agent.source = None
|
||||
agent.reply_mode = ReplyMode.DISPATCH
|
||||
agent.persist_output_message = False
|
||||
agent._tool_context = {"user_reply_sent": False}
|
||||
agent._streamed_output = ""
|
||||
agent.stream_handler = SimpleNamespace(
|
||||
@@ -218,7 +206,6 @@ class AgentBackgroundOutputTest(unittest.IsolatedAsyncioTestCase):
|
||||
return_value=_FakeAgent([AIMessage(content="后台结果")])
|
||||
)
|
||||
agent.send_agent_message = AsyncMock()
|
||||
agent._save_agent_message_to_db = AsyncMock()
|
||||
|
||||
with patch.object(memory_manager, "save_agent_messages") as save_messages:
|
||||
await agent._execute_agent([])
|
||||
@@ -226,16 +213,14 @@ class AgentBackgroundOutputTest(unittest.IsolatedAsyncioTestCase):
|
||||
agent.send_agent_message.assert_awaited_once_with(
|
||||
"后台结果", title="MoviePilot助手"
|
||||
)
|
||||
agent._save_agent_message_to_db.assert_not_awaited()
|
||||
save_messages.assert_called_once()
|
||||
self.assertEqual("后台结果", agent._streamed_output)
|
||||
|
||||
async def test_background_non_streaming_persists_without_sending_when_capture_only(self):
|
||||
async def test_background_non_streaming_captures_without_sending_when_capture_only(self):
|
||||
agent = MoviePilotAgent(session_id="bg-test", user_id="system")
|
||||
agent.channel = None
|
||||
agent.source = None
|
||||
agent.reply_mode = ReplyMode.CAPTURE_ONLY
|
||||
agent.persist_output_message = True
|
||||
agent._tool_context = {"user_reply_sent": False}
|
||||
agent._streamed_output = ""
|
||||
agent.stream_handler = SimpleNamespace(
|
||||
@@ -246,15 +231,11 @@ class AgentBackgroundOutputTest(unittest.IsolatedAsyncioTestCase):
|
||||
return_value=_FakeAgent([AIMessage(content="后台结果")])
|
||||
)
|
||||
agent.send_agent_message = AsyncMock()
|
||||
agent._save_agent_message_to_db = AsyncMock()
|
||||
|
||||
with patch.object(memory_manager, "save_agent_messages") as save_messages:
|
||||
await agent._execute_agent([])
|
||||
|
||||
agent.send_agent_message.assert_not_awaited()
|
||||
agent._save_agent_message_to_db.assert_awaited_once_with(
|
||||
"后台结果", title="MoviePilot助手"
|
||||
)
|
||||
save_messages.assert_called_once()
|
||||
self.assertEqual("后台结果", agent._streamed_output)
|
||||
|
||||
@@ -279,7 +260,6 @@ class AgentBackgroundOutputTest(unittest.IsolatedAsyncioTestCase):
|
||||
process_message.assert_awaited_once()
|
||||
kwargs = process_message.await_args.kwargs
|
||||
self.assertEqual(ReplyMode.CAPTURE_ONLY, kwargs["reply_mode"])
|
||||
self.assertFalse(kwargs["persist_output_message"])
|
||||
self.assertTrue(kwargs["allow_message_tools"])
|
||||
|
||||
async def test_heartbeat_check_jobs_skips_when_no_active_jobs(self):
|
||||
|
||||
@@ -569,8 +569,8 @@ class AgentImageSupportTest(unittest.TestCase):
|
||||
|
||||
self.assertEqual(payload.image_url, "https://example.com/poster.png")
|
||||
|
||||
def test_send_message_tool_uses_agent_notification_type(self):
|
||||
"""发送消息工具应固定使用智能体消息类型。"""
|
||||
def test_send_message_tool_uses_regular_notification_type(self):
|
||||
"""发送消息工具应按普通通知消息登记。"""
|
||||
|
||||
async def _run():
|
||||
tool = SendMessageTool(session_id="session-1", user_id="10001")
|
||||
@@ -595,7 +595,7 @@ class AgentImageSupportTest(unittest.TestCase):
|
||||
notification = async_post_message.await_args.args[0]
|
||||
|
||||
self.assertEqual(result, "消息已发送")
|
||||
self.assertEqual(notification.mtype, NotificationType.Agent)
|
||||
self.assertEqual(notification.mtype, NotificationType.Other)
|
||||
self.assertEqual(notification.channel, MessageChannel.Telegram)
|
||||
self.assertEqual(notification.source, "telegram-test")
|
||||
self.assertEqual(notification.title, "智能体通知")
|
||||
|
||||
@@ -204,8 +204,8 @@ class TestAgentInteraction(unittest.TestCase):
|
||||
self.assertEqual(kwargs["channel"], MessageChannel.Telegram.value)
|
||||
self.assertEqual(kwargs["source"], "telegram-test")
|
||||
self.assertNotIn("processing_status", kwargs)
|
||||
message_put.assert_called_once()
|
||||
message_add.assert_called_once()
|
||||
message_put.assert_not_called()
|
||||
message_add.assert_not_called()
|
||||
|
||||
def test_legacy_agent_choice_callback_still_supported(self):
|
||||
chain = MessageChain()
|
||||
|
||||
@@ -1,7 +1,8 @@
|
||||
from unittest.mock import patch
|
||||
from unittest.mock import AsyncMock, Mock, patch
|
||||
|
||||
from app.chain.message import MessageChain
|
||||
from app.helper.interaction import media_interaction_manager
|
||||
from app.core.config import settings
|
||||
from app.helper.interaction import AgentInteractionOption, agent_interaction_manager, media_interaction_manager
|
||||
from app.schemas.types import MessageChannel
|
||||
|
||||
|
||||
@@ -38,3 +39,73 @@ def test_explicit_ai_message_bypasses_pending_media_interaction():
|
||||
|
||||
handle_ai_message.assert_called_once()
|
||||
handle_media_interaction.assert_not_called()
|
||||
|
||||
|
||||
def test_explicit_ai_message_is_not_recorded_to_message_history():
|
||||
"""显式 /ai 消息不登记到数据库或实时消息队列。"""
|
||||
chain = MessageChain()
|
||||
|
||||
with patch.object(settings, "AI_AGENT_ENABLE", True), patch.object(
|
||||
chain, "_record_user_message"
|
||||
) as record_user_message, patch(
|
||||
"app.chain.message.agent_manager.process_message",
|
||||
new_callable=AsyncMock,
|
||||
) as process_message, patch(
|
||||
"app.chain.message.asyncio.run_coroutine_threadsafe",
|
||||
side_effect=lambda coro, _loop: (coro.close(), Mock())[1],
|
||||
):
|
||||
chain.handle_message(
|
||||
channel=MessageChannel.Telegram,
|
||||
source="telegram-test",
|
||||
userid="10001",
|
||||
username="tester",
|
||||
text="/ai 帮我检查订阅",
|
||||
)
|
||||
|
||||
record_user_message.assert_not_called()
|
||||
process_message.assert_called_once()
|
||||
|
||||
|
||||
def test_agent_choice_callback_is_not_recorded_to_message_history():
|
||||
"""Agent 按钮选择回传不登记到数据库或实时消息队列。"""
|
||||
chain = MessageChain()
|
||||
request = agent_interaction_manager.create_request(
|
||||
session_id="session-choice",
|
||||
user_id="10001",
|
||||
channel=MessageChannel.Telegram.value,
|
||||
source="telegram-test",
|
||||
username="tester",
|
||||
title="需要你的选择",
|
||||
prompt="请选择",
|
||||
options=[
|
||||
AgentInteractionOption(label="电影", value="我选择电影"),
|
||||
AgentInteractionOption(label="电视剧", value="我选择电视剧"),
|
||||
],
|
||||
)
|
||||
|
||||
try:
|
||||
with patch.object(settings, "AI_AGENT_ENABLE", True), patch.object(
|
||||
chain, "_record_user_message"
|
||||
) as record_user_message, patch.object(
|
||||
chain, "edit_message", return_value=True
|
||||
), patch(
|
||||
"app.chain.message.agent_manager.process_message",
|
||||
new_callable=AsyncMock,
|
||||
) as process_message, patch(
|
||||
"app.chain.message.asyncio.run_coroutine_threadsafe",
|
||||
side_effect=lambda coro, _loop: (coro.close(), Mock())[1],
|
||||
):
|
||||
chain._handle_callback(
|
||||
text=f"CALLBACK:agent_interaction:choice:{request.request_id}:1",
|
||||
channel=MessageChannel.Telegram,
|
||||
source="telegram-test",
|
||||
userid="10001",
|
||||
username="tester",
|
||||
original_message_id=123,
|
||||
original_chat_id="456",
|
||||
)
|
||||
finally:
|
||||
agent_interaction_manager.clear()
|
||||
|
||||
record_user_message.assert_not_called()
|
||||
process_message.assert_called_once()
|
||||
|
||||
@@ -94,7 +94,6 @@ class AgentTokensEventsTest(unittest.IsolatedAsyncioTestCase):
|
||||
stop_streaming=AsyncMock(return_value=(False, ""))
|
||||
)
|
||||
agent.send_agent_message = AsyncMock()
|
||||
agent._save_agent_message_to_db = AsyncMock()
|
||||
|
||||
async def create_agent(_streaming=False, streaming=False):
|
||||
"""模拟创建 Agent 时完成供应商选择和用量统计。"""
|
||||
|
||||
169
tests/test_message_notifications.py
Normal file
169
tests/test_message_notifications.py
Normal file
@@ -0,0 +1,169 @@
|
||||
import json
|
||||
from unittest.mock import Mock
|
||||
|
||||
from app.db import SessionFactory
|
||||
from app.db.message_oper import MessageOper
|
||||
from app.db.models.message import Message
|
||||
from app.chain import ChainBase
|
||||
from app.helper.message import MessageHelper
|
||||
from app.schemas import Notification
|
||||
from app.schemas.types import NotificationType
|
||||
|
||||
|
||||
def _clear_messages() -> None:
|
||||
"""
|
||||
清空消息表,隔离通知测试数据。
|
||||
"""
|
||||
with SessionFactory() as db:
|
||||
db.query(Message).delete()
|
||||
db.commit()
|
||||
|
||||
|
||||
def _reset_message_helper(helper: MessageHelper) -> None:
|
||||
"""
|
||||
清空单例消息队列和去重缓存,避免用例间互相影响。
|
||||
"""
|
||||
while helper.get() is not None:
|
||||
pass
|
||||
helper._recent_notification_keys.clear()
|
||||
|
||||
|
||||
def test_notification_history_only_lists_sent_messages() -> None:
|
||||
"""
|
||||
通知历史应返回已发送消息,包含通过消息链登记的智能体消息。
|
||||
"""
|
||||
_clear_messages()
|
||||
oper = MessageOper()
|
||||
oper.add(title="系统通知", text="下载完成", action=1, mtype=NotificationType.Download)
|
||||
oper.add(title="用户消息", text="帮我搜索", action=0)
|
||||
oper.add(title="智能体回复", text="已处理", action=1, mtype=NotificationType.Agent)
|
||||
|
||||
messages = MessageOper().list_by_page(page=1, count=10)
|
||||
assert [message.title for message in messages if message.action == 1] == ["智能体回复", "系统通知"]
|
||||
|
||||
|
||||
def test_web_message_history_returns_all_messages() -> None:
|
||||
"""
|
||||
Web 消息历史返回消息表中的全部记录。
|
||||
"""
|
||||
_clear_messages()
|
||||
oper = MessageOper()
|
||||
oper.add(title="智能体回复", text="已处理", action=1, mtype=NotificationType.Agent)
|
||||
oper.add(title="用户消息", text="/ai 帮我处理", action=0)
|
||||
oper.add(title="普通通知", text="下载完成", action=1, mtype=NotificationType.Download)
|
||||
|
||||
messages = MessageOper().list_by_page(page=1, count=10)
|
||||
assert [message.title for message in messages] == ["普通通知", "用户消息", "智能体回复"]
|
||||
|
||||
|
||||
def test_system_helper_message_only_enters_sse_queue() -> None:
|
||||
"""
|
||||
系统实时消息只进入前端 SSE 队列,不写入通知历史。
|
||||
"""
|
||||
_clear_messages()
|
||||
helper = MessageHelper()
|
||||
_reset_message_helper(helper)
|
||||
|
||||
helper.put("调度任务执行失败", role="system", title="系统错误")
|
||||
|
||||
assert MessageOper().list_by_page(page=1, count=10) == []
|
||||
realtime_message = json.loads(helper.get())
|
||||
assert realtime_message["type"] == "system"
|
||||
assert realtime_message["title"] == "系统错误"
|
||||
assert realtime_message["text"] == "调度任务执行失败"
|
||||
|
||||
|
||||
def test_plugin_helper_message_deduplicates_recent_sse_messages() -> None:
|
||||
"""
|
||||
短时间内相同插件实时消息只应推送一次,不写入通知历史。
|
||||
"""
|
||||
_clear_messages()
|
||||
helper = MessageHelper()
|
||||
_reset_message_helper(helper)
|
||||
|
||||
helper.put("站点刷流任务出错,获取下载器实例失败,请检查配置", role="plugin", title="站点刷流")
|
||||
helper.put("站点刷流任务出错,获取下载器实例失败,请检查配置", role="plugin", title="站点刷流")
|
||||
|
||||
assert MessageOper().list_by_page(page=1, count=10) == []
|
||||
assert json.loads(helper.get())["title"] == "站点刷流"
|
||||
assert helper.get() is None
|
||||
|
||||
|
||||
def test_agent_helper_message_does_not_enter_sse_queue() -> None:
|
||||
"""
|
||||
智能体消息不进入前端 SSE 队列。
|
||||
"""
|
||||
helper = MessageHelper()
|
||||
_reset_message_helper(helper)
|
||||
|
||||
helper.put("智能体回复", role="agent", title="MoviePilot助手")
|
||||
|
||||
assert helper.get() is None
|
||||
|
||||
|
||||
def test_user_helper_message_does_not_enter_sse_queue() -> None:
|
||||
"""
|
||||
用户消息不进入前端 SSE 队列。
|
||||
"""
|
||||
helper = MessageHelper()
|
||||
_reset_message_helper(helper)
|
||||
|
||||
helper.put("用户输入", role="user", title="admin")
|
||||
|
||||
assert helper.get() is None
|
||||
|
||||
|
||||
def test_notification_post_message_is_persisted_without_sse_queue() -> None:
|
||||
"""
|
||||
业务通知通过消息链发送时只登记数据库,不进入前端 SSE 队列。
|
||||
"""
|
||||
_clear_messages()
|
||||
helper = MessageHelper()
|
||||
_reset_message_helper(helper)
|
||||
chain = ChainBase()
|
||||
|
||||
chain.messagequeue.send_message = Mock()
|
||||
chain.eventmanager.send_event = Mock()
|
||||
|
||||
chain.post_message(
|
||||
Notification(
|
||||
mtype=NotificationType.Download,
|
||||
title="下载完成",
|
||||
text="影片已加入下载器",
|
||||
)
|
||||
)
|
||||
|
||||
messages = MessageOper().list_by_page(page=1, count=10)
|
||||
assert len(messages) == 1
|
||||
assert messages[0].title == "下载完成"
|
||||
assert messages[0].mtype == NotificationType.Download.value
|
||||
assert helper.get() is None
|
||||
chain.messagequeue.send_message.assert_called_once()
|
||||
|
||||
|
||||
def test_agent_notification_post_message_is_persisted_without_sse_queue() -> None:
|
||||
"""
|
||||
智能体消息通过消息链发送时登记数据库,但不进入前端 SSE 队列。
|
||||
"""
|
||||
_clear_messages()
|
||||
helper = MessageHelper()
|
||||
_reset_message_helper(helper)
|
||||
chain = ChainBase()
|
||||
|
||||
chain.messagequeue.send_message = Mock()
|
||||
chain.eventmanager.send_event = Mock()
|
||||
|
||||
chain.post_message(
|
||||
Notification(
|
||||
mtype=NotificationType.Agent,
|
||||
title="MoviePilot助手",
|
||||
text="已完成处理",
|
||||
)
|
||||
)
|
||||
|
||||
messages = MessageOper().list_by_page(page=1, count=10)
|
||||
assert len(messages) == 1
|
||||
assert messages[0].title == "MoviePilot助手"
|
||||
assert messages[0].mtype == NotificationType.Agent.value
|
||||
assert helper.get() is None
|
||||
chain.messagequeue.send_message.assert_called_once()
|
||||
@@ -122,7 +122,6 @@ class SearchChainAIRecommendTest(unittest.IsolatedAsyncioTestCase):
|
||||
|
||||
self.assertEqual("[0, 2]", result)
|
||||
self.assertEqual(ReplyMode.CAPTURE_ONLY, captured["reply_mode"])
|
||||
self.assertFalse(captured["persist_output_message"])
|
||||
self.assertFalse(captured["allow_message_tools"])
|
||||
|
||||
def test_search_by_title_clears_previous_recommend_state_when_caching(self):
|
||||
|
||||
Reference in New Issue
Block a user