mirror of
https://github.com/jxxghp/MoviePilot.git
synced 2026-06-30 12:11:49 +08:00
test: fix agent voice message streaming tests
This commit is contained in:
@@ -1,5 +1,4 @@
|
||||
import asyncio
|
||||
import unittest
|
||||
from pathlib import Path
|
||||
from unittest.mock import AsyncMock, patch
|
||||
|
||||
@@ -19,15 +18,21 @@ from app.schemas.types import MessageChannel, NotificationType
|
||||
|
||||
|
||||
class DummyTool(MoviePilotTool):
|
||||
"""用于流式输出测试的固定结果工具。"""
|
||||
|
||||
name: str = "dummy_tool"
|
||||
description: str = "Dummy tool for streaming tests."
|
||||
|
||||
async def run(self, **kwargs) -> str:
|
||||
"""返回固定工具执行结果。"""
|
||||
return "ok"
|
||||
|
||||
|
||||
class TestAgentToolStreaming(unittest.TestCase):
|
||||
class TestAgentToolStreaming:
|
||||
"""Agent 工具流式输出测试。"""
|
||||
|
||||
async def _run_tool(self, initial_buffer: str) -> tuple[str, str]:
|
||||
"""运行测试工具并返回工具结果与缓冲内容。"""
|
||||
tool = DummyTool(session_id="session-1", user_id="10001")
|
||||
handler = StreamingHandler()
|
||||
await handler.start_streaming()
|
||||
@@ -42,24 +47,28 @@ class TestAgentToolStreaming(unittest.TestCase):
|
||||
return result, buffered_message
|
||||
|
||||
def test_non_verbose_tool_call_flushes_summary_on_take(self):
|
||||
"""校验非详细模式在读取时补充工具调用摘要。"""
|
||||
result, buffered_message = asyncio.run(self._run_tool("prefix"))
|
||||
|
||||
self.assertEqual(result, "ok")
|
||||
self.assertEqual(buffered_message, "prefix\n\n(调用了 1 次工具)\n\n")
|
||||
assert result == "ok"
|
||||
assert buffered_message == "prefix\n\n(调用了 1 次工具)\n\n"
|
||||
|
||||
def test_non_verbose_tool_call_reuses_existing_newline_before_summary(self):
|
||||
"""校验非详细模式复用已有换行追加工具摘要。"""
|
||||
result, buffered_message = asyncio.run(self._run_tool("prefix\n"))
|
||||
|
||||
self.assertEqual(result, "ok")
|
||||
self.assertEqual(buffered_message, "prefix\n(调用了 1 次工具)\n\n")
|
||||
assert result == "ok"
|
||||
assert buffered_message == "prefix\n(调用了 1 次工具)\n\n"
|
||||
|
||||
def test_non_verbose_tool_call_emits_summary_even_when_buffer_was_empty(self):
|
||||
"""校验空缓冲区仍会输出工具调用摘要。"""
|
||||
result, buffered_message = asyncio.run(self._run_tool(""))
|
||||
|
||||
self.assertEqual(result, "ok")
|
||||
self.assertEqual(buffered_message, "(调用了 1 次工具)\n\n")
|
||||
assert result == "ok"
|
||||
assert buffered_message == "(调用了 1 次工具)\n\n"
|
||||
|
||||
def test_non_verbose_tool_summary_is_inserted_before_next_text(self):
|
||||
"""校验工具摘要插入在后续文本之前。"""
|
||||
async def _run():
|
||||
tool = DummyTool(session_id="session-1", user_id="10001")
|
||||
handler = StreamingHandler()
|
||||
@@ -75,12 +84,10 @@ class TestAgentToolStreaming(unittest.TestCase):
|
||||
|
||||
buffered_message = asyncio.run(_run())
|
||||
|
||||
self.assertEqual(
|
||||
buffered_message,
|
||||
"让我来检查一下:\n\n(调用了 1 次工具)\n\n已经拿到结果",
|
||||
)
|
||||
assert buffered_message == "让我来检查一下:\n\n(调用了 1 次工具)\n\n已经拿到结果"
|
||||
|
||||
def test_non_verbose_tool_summary_aggregates_multiple_categories(self):
|
||||
"""校验非详细模式按工具类别聚合摘要。"""
|
||||
async def _run():
|
||||
handler = StreamingHandler()
|
||||
await handler.start_streaming()
|
||||
@@ -110,12 +117,10 @@ class TestAgentToolStreaming(unittest.TestCase):
|
||||
|
||||
buffered_message = asyncio.run(_run())
|
||||
|
||||
self.assertEqual(
|
||||
buffered_message,
|
||||
"处理中:\n\n(执行了 2 次搜索,读取了 2 个文件)\n\n继续分析",
|
||||
)
|
||||
assert buffered_message == "处理中:\n\n(执行了 2 次搜索,读取了 2 个文件)\n\n继续分析"
|
||||
|
||||
def test_non_verbose_tool_summary_counts_subagents(self):
|
||||
"""校验非详细模式统计子代理调用次数。"""
|
||||
async def _run():
|
||||
handler = StreamingHandler()
|
||||
await handler.start_streaming()
|
||||
@@ -134,18 +139,18 @@ class TestAgentToolStreaming(unittest.TestCase):
|
||||
|
||||
buffered_message = asyncio.run(_run())
|
||||
|
||||
self.assertEqual(buffered_message, "处理中:\n\n(已调用 2 个子代理)\n\n")
|
||||
assert buffered_message == "处理中:\n\n(已调用 2 个子代理)\n\n"
|
||||
|
||||
def test_subagent_stream_metadata_is_suppressed(self):
|
||||
self.assertTrue(
|
||||
is_subagent_stream_metadata(
|
||||
{"metadata": {"ls_agent_type": "subagent"}}
|
||||
)
|
||||
"""校验子代理流式元数据会被识别并抑制。"""
|
||||
assert is_subagent_stream_metadata(
|
||||
{"metadata": {"ls_agent_type": "subagent"}}
|
||||
)
|
||||
self.assertTrue(is_subagent_stream_metadata({"lc_agent_name": "media-researcher"}))
|
||||
self.assertFalse(is_subagent_stream_metadata({"lc_agent_name": "main"}))
|
||||
assert is_subagent_stream_metadata({"lc_agent_name": "media-researcher"})
|
||||
assert not is_subagent_stream_metadata({"lc_agent_name": "main"})
|
||||
|
||||
def test_openai_streaming_handler_flushes_pending_summary_to_queue(self):
|
||||
"""校验 OpenAI 流式处理器将待发送摘要推入队列。"""
|
||||
async def _run():
|
||||
handler = _OpenAIStreamingHandler()
|
||||
queue: asyncio.Queue = asyncio.Queue()
|
||||
@@ -163,11 +168,12 @@ class TestAgentToolStreaming(unittest.TestCase):
|
||||
|
||||
emitted, queued, buffered_message = asyncio.run(_run())
|
||||
|
||||
self.assertEqual(emitted, "(读取了 1 个文件)\n\n")
|
||||
self.assertEqual(queued, emitted)
|
||||
self.assertEqual(buffered_message, emitted)
|
||||
assert emitted == "(读取了 1 个文件)\n\n"
|
||||
assert queued == emitted
|
||||
assert buffered_message == emitted
|
||||
|
||||
def test_flush_sends_direct_message_via_threadpool(self):
|
||||
"""校验刷新时通过线程池发送首条直连消息。"""
|
||||
handler = StreamingHandler()
|
||||
handler._channel = MessageChannel.Telegram.value
|
||||
handler._source = "telegram"
|
||||
@@ -188,17 +194,13 @@ class TestAgentToolStreaming(unittest.TestCase):
|
||||
|
||||
asyncio.run(handler._flush())
|
||||
|
||||
self.assertEqual(run_in_threadpool_mock.await_count, 1)
|
||||
self.assertEqual(
|
||||
run_in_threadpool_mock.await_args.args[0].__name__, "send_direct_message"
|
||||
)
|
||||
self.assertEqual(
|
||||
run_in_threadpool_mock.await_args.args[1].mtype,
|
||||
NotificationType.Agent,
|
||||
)
|
||||
self.assertTrue(handler.has_sent_message)
|
||||
assert run_in_threadpool_mock.await_count == 1
|
||||
assert run_in_threadpool_mock.await_args.args[0].__name__ == "send_direct_message"
|
||||
assert run_in_threadpool_mock.await_args.args[1].mtype == NotificationType.Agent
|
||||
assert handler.has_sent_message
|
||||
|
||||
def test_flush_edits_message_via_threadpool(self):
|
||||
"""校验刷新时通过线程池编辑已有消息。"""
|
||||
handler = StreamingHandler()
|
||||
handler._channel = MessageChannel.Telegram.value
|
||||
handler._source = "telegram"
|
||||
@@ -219,13 +221,12 @@ class TestAgentToolStreaming(unittest.TestCase):
|
||||
|
||||
asyncio.run(handler._flush())
|
||||
|
||||
self.assertEqual(run_in_threadpool_mock.await_count, 1)
|
||||
self.assertEqual(
|
||||
run_in_threadpool_mock.await_args.args[0].__name__, "edit_message"
|
||||
)
|
||||
self.assertEqual(handler._sent_text, "hello world")
|
||||
assert run_in_threadpool_mock.await_count == 1
|
||||
assert run_in_threadpool_mock.await_args.args[0].__name__ == "edit_message"
|
||||
assert handler._sent_text == "hello world"
|
||||
|
||||
def test_stop_streaming_waits_inflight_initial_flush_before_final_edit(self):
|
||||
"""校验停止流式输出会等待首条消息发送完成再编辑。"""
|
||||
async def _run():
|
||||
handler = StreamingHandler()
|
||||
handler._channel = MessageChannel.Feishu.value
|
||||
@@ -263,7 +264,7 @@ class TestAgentToolStreaming(unittest.TestCase):
|
||||
|
||||
stop_task = asyncio.create_task(handler.stop_streaming())
|
||||
await asyncio.sleep(0)
|
||||
self.assertFalse(stop_task.done())
|
||||
assert not stop_task.done()
|
||||
|
||||
allow_send_finish.set()
|
||||
all_sent, final_text = await stop_task
|
||||
@@ -272,17 +273,19 @@ class TestAgentToolStreaming(unittest.TestCase):
|
||||
|
||||
all_sent, final_text, calls = asyncio.run(_run())
|
||||
|
||||
self.assertTrue(all_sent)
|
||||
self.assertEqual(final_text, "hello world")
|
||||
self.assertEqual(
|
||||
[call[0] for call in calls],
|
||||
["send_direct_message", "edit_message", "finalize_message"],
|
||||
)
|
||||
assert all_sent
|
||||
assert final_text == "hello world"
|
||||
assert [call[0] for call in calls] == [
|
||||
"send_direct_message",
|
||||
"edit_message",
|
||||
"finalize_message",
|
||||
]
|
||||
edit_kwargs = calls[1][2]
|
||||
self.assertEqual(edit_kwargs["message_id"], "om_stream")
|
||||
self.assertEqual(edit_kwargs["text"], "hello world")
|
||||
assert edit_kwargs["message_id"] == "om_stream"
|
||||
assert edit_kwargs["text"] == "hello world"
|
||||
|
||||
def test_stop_streaming_uses_generic_finalize_message(self):
|
||||
"""校验停止流式输出会调用通用消息完成接口。"""
|
||||
handler = StreamingHandler()
|
||||
handler._message_response = MessageResponse(
|
||||
message_id="om_stream",
|
||||
@@ -305,16 +308,12 @@ class TestAgentToolStreaming(unittest.TestCase):
|
||||
):
|
||||
asyncio.run(handler.stop_streaming())
|
||||
|
||||
self.assertEqual(run_in_threadpool_mock.await_count, 1)
|
||||
self.assertEqual(
|
||||
run_in_threadpool_mock.await_args.args[0].__name__, "finalize_message"
|
||||
)
|
||||
self.assertEqual(
|
||||
run_in_threadpool_mock.await_args.args[1].message_id,
|
||||
"om_stream",
|
||||
)
|
||||
assert run_in_threadpool_mock.await_count == 1
|
||||
assert run_in_threadpool_mock.await_args.args[0].__name__ == "finalize_message"
|
||||
assert run_in_threadpool_mock.await_args.args[1].message_id == "om_stream"
|
||||
|
||||
def test_flush_without_channel_context_does_not_send_direct_message(self):
|
||||
"""校验缺少渠道上下文时不会发送直连消息。"""
|
||||
handler = StreamingHandler()
|
||||
handler._streaming_enabled = True
|
||||
handler.emit("hello")
|
||||
@@ -325,9 +324,10 @@ class TestAgentToolStreaming(unittest.TestCase):
|
||||
asyncio.run(handler._flush())
|
||||
|
||||
run_in_threadpool_mock.assert_not_awaited()
|
||||
self.assertFalse(handler.has_sent_message)
|
||||
assert not handler.has_sent_message
|
||||
|
||||
def test_flush_without_channel_context_dispatch_allowed_sends_direct_message(self):
|
||||
"""校验允许后台派发时缺少渠道上下文也能发送消息。"""
|
||||
handler = StreamingHandler()
|
||||
handler._user_id = "10001"
|
||||
handler._username = "tester"
|
||||
@@ -347,13 +347,12 @@ class TestAgentToolStreaming(unittest.TestCase):
|
||||
|
||||
asyncio.run(handler._flush())
|
||||
|
||||
self.assertEqual(run_in_threadpool_mock.await_count, 1)
|
||||
self.assertEqual(
|
||||
run_in_threadpool_mock.await_args.args[0].__name__, "send_direct_message"
|
||||
)
|
||||
self.assertTrue(handler.has_sent_message)
|
||||
assert run_in_threadpool_mock.await_count == 1
|
||||
assert run_in_threadpool_mock.await_args.args[0].__name__ == "send_direct_message"
|
||||
assert handler.has_sent_message
|
||||
|
||||
def test_flush_passes_original_message_context_to_send_direct_message(self):
|
||||
"""校验刷新发送时保留原始消息上下文。"""
|
||||
handler = StreamingHandler()
|
||||
handler._channel = MessageChannel.Feishu.value
|
||||
handler._source = "feishu-main"
|
||||
@@ -377,10 +376,11 @@ class TestAgentToolStreaming(unittest.TestCase):
|
||||
asyncio.run(handler._flush())
|
||||
|
||||
notification = run_in_threadpool_mock.await_args.args[1]
|
||||
self.assertEqual(notification.original_message_id, "om_origin")
|
||||
self.assertEqual(notification.original_chat_id, "oc_origin")
|
||||
assert notification.original_message_id == "om_origin"
|
||||
assert notification.original_chat_id == "oc_origin"
|
||||
|
||||
def test_verbose_background_tool_call_does_not_post_message(self):
|
||||
"""校验详细模式后台工具调用不会主动发送工具消息。"""
|
||||
async def _run():
|
||||
tool = DummyTool(session_id="session-1", user_id="10001")
|
||||
handler = StreamingHandler()
|
||||
@@ -400,11 +400,12 @@ class TestAgentToolStreaming(unittest.TestCase):
|
||||
|
||||
result, buffered_message, send_tool_message = asyncio.run(_run())
|
||||
|
||||
self.assertEqual(result, "ok")
|
||||
assert result == "ok"
|
||||
send_tool_message.assert_not_awaited()
|
||||
self.assertEqual(buffered_message, "(调用了 1 次工具)\n\n")
|
||||
assert buffered_message == "(调用了 1 次工具)\n\n"
|
||||
|
||||
def test_verbose_background_dispatch_tool_call_can_post_message(self):
|
||||
"""校验允许后台派发时详细模式工具调用可以发送消息。"""
|
||||
async def _run():
|
||||
tool = DummyTool(session_id="session-1", user_id="10001")
|
||||
handler = StreamingHandler()
|
||||
@@ -426,9 +427,9 @@ class TestAgentToolStreaming(unittest.TestCase):
|
||||
|
||||
result, buffered_message, send_tool_message = asyncio.run(_run())
|
||||
|
||||
self.assertEqual(result, "ok")
|
||||
assert result == "ok"
|
||||
send_tool_message.assert_awaited_once_with("前置内容\n\n⚙️ => run test tool")
|
||||
self.assertEqual(buffered_message, "")
|
||||
assert buffered_message == ""
|
||||
|
||||
def test_send_voice_message_uses_native_voice_for_supported_channels(self):
|
||||
"""校验支持语音输出的渠道会发送原生语音消息。"""
|
||||
@@ -453,28 +454,30 @@ class TestAgentToolStreaming(unittest.TestCase):
|
||||
"app.agent.tools.impl.send_voice_message.AgentCapabilityManager.synthesize_speech",
|
||||
return_value=Path("/tmp/reply.opus"),
|
||||
) as synthesize_speech,
|
||||
patch(
|
||||
"app.agent.tools.impl.send_voice_message.ToolChain.async_post_message",
|
||||
patch.object(
|
||||
SendVoiceMessageTool,
|
||||
"send_notification_message",
|
||||
new_callable=AsyncMock,
|
||||
) as async_post_message,
|
||||
) as send_notification_message,
|
||||
):
|
||||
result = await tool.run("你好")
|
||||
return result, synthesize_speech, async_post_message
|
||||
return result, synthesize_speech, send_notification_message
|
||||
|
||||
for channel in (MessageChannel.Telegram, MessageChannel.Feishu):
|
||||
result, synthesize_speech, async_post_message = asyncio.run(
|
||||
result, synthesize_speech, send_notification_message = asyncio.run(
|
||||
_run(channel)
|
||||
)
|
||||
notification = async_post_message.await_args.args[0]
|
||||
notification = send_notification_message.await_args.args[-1]
|
||||
|
||||
self.assertEqual(result, "语音回复已发送")
|
||||
assert result == "语音回复已发送"
|
||||
synthesize_speech.assert_called_once_with("你好")
|
||||
self.assertEqual(notification.channel, channel)
|
||||
self.assertEqual(notification.voice_path, "/tmp/reply.opus")
|
||||
self.assertEqual(notification.voice_caption, "你好")
|
||||
send_notification_message.assert_awaited_once()
|
||||
assert notification.channel == channel
|
||||
assert notification.voice_path == "/tmp/reply.opus"
|
||||
assert notification.voice_caption == "你好"
|
||||
voice_tool = SendVoiceMessageTool(session_id="session-1", user_id="10001")
|
||||
self.assertTrue(voice_tool.return_direct)
|
||||
self.assertIn("terminal response tool", voice_tool.description)
|
||||
assert voice_tool.return_direct
|
||||
assert "terminal response tool" in voice_tool.description
|
||||
|
||||
def test_send_voice_message_falls_back_for_unsupported_channels(self):
|
||||
"""校验不支持语音输出的渠道继续回退为文字消息。"""
|
||||
@@ -495,18 +498,20 @@ class TestAgentToolStreaming(unittest.TestCase):
|
||||
patch(
|
||||
"app.agent.tools.impl.send_voice_message.AgentCapabilityManager.synthesize_speech"
|
||||
) as synthesize_speech,
|
||||
patch(
|
||||
"app.agent.tools.impl.send_voice_message.ToolChain.async_post_message",
|
||||
patch.object(
|
||||
SendVoiceMessageTool,
|
||||
"send_notification_message",
|
||||
new_callable=AsyncMock,
|
||||
) as async_post_message,
|
||||
) as send_notification_message,
|
||||
):
|
||||
result = await tool.run("你好")
|
||||
return result, synthesize_speech, async_post_message
|
||||
return result, synthesize_speech, send_notification_message
|
||||
|
||||
result, synthesize_speech, async_post_message = asyncio.run(_run())
|
||||
notification = async_post_message.await_args.args[0]
|
||||
result, synthesize_speech, send_notification_message = asyncio.run(_run())
|
||||
notification = send_notification_message.await_args.args[-1]
|
||||
|
||||
self.assertEqual(result, "当前渠道不支持语音回复,已自动回退为文字回复")
|
||||
assert result == "当前渠道不支持语音回复,已自动回退为文字回复"
|
||||
synthesize_speech.assert_not_called()
|
||||
self.assertEqual(notification.text, "你好")
|
||||
self.assertIsNone(notification.voice_path)
|
||||
send_notification_message.assert_awaited_once()
|
||||
assert notification.text == "你好"
|
||||
assert notification.voice_path is None
|
||||
|
||||
Reference in New Issue
Block a user