Files
MoviePilot/tests/test_agent_tool_streaming.py
2026-04-29 07:07:33 +08:00

193 lines
6.9 KiB
Python

import asyncio
import unittest
from unittest.mock import AsyncMock, patch
import langchain.agents as langchain_agents
if not hasattr(langchain_agents, "create_agent"):
langchain_agents.create_agent = lambda *args, **kwargs: None
from app.agent.callback import StreamingHandler
from app.agent.tools.base import MoviePilotTool
from app.api.endpoints.openai import _OpenAIStreamingHandler
from app.core.config import settings
from app.schemas.message import MessageResponse
from app.schemas.types import MessageChannel
class DummyTool(MoviePilotTool):
name: str = "dummy_tool"
description: str = "Dummy tool for streaming tests."
async def run(self, **kwargs) -> str:
return "ok"
class TestAgentToolStreaming(unittest.TestCase):
async def _run_tool(self, initial_buffer: str) -> tuple[str, str]:
tool = DummyTool(session_id="session-1", user_id="10001")
handler = StreamingHandler()
await handler.start_streaming()
if initial_buffer:
handler.emit(initial_buffer)
tool.set_stream_handler(handler)
with patch.object(settings, "AI_AGENT_VERBOSE", False):
result = await tool._arun(explanation="run test tool")
buffered_message = await handler.take()
return result, buffered_message
def test_non_verbose_tool_call_flushes_summary_on_take(self):
result, buffered_message = asyncio.run(self._run_tool("prefix"))
self.assertEqual(result, "ok")
self.assertEqual(buffered_message, "prefix\n(调用了 1 次工具)\n")
def test_non_verbose_tool_call_reuses_existing_newline_before_summary(self):
result, buffered_message = asyncio.run(self._run_tool("prefix\n"))
self.assertEqual(result, "ok")
self.assertEqual(buffered_message, "prefix\n(调用了 1 次工具)\n")
def test_non_verbose_tool_call_emits_summary_even_when_buffer_was_empty(self):
result, buffered_message = asyncio.run(self._run_tool(""))
self.assertEqual(result, "ok")
self.assertEqual(buffered_message, "(调用了 1 次工具)\n")
def test_non_verbose_tool_summary_is_inserted_before_next_text(self):
async def _run():
tool = DummyTool(session_id="session-1", user_id="10001")
handler = StreamingHandler()
await handler.start_streaming()
handler.emit("让我来检查一下:")
tool.set_stream_handler(handler)
with patch.object(settings, "AI_AGENT_VERBOSE", False):
await tool._arun(explanation="run test tool")
handler.emit("已经拿到结果")
return await handler.take()
buffered_message = asyncio.run(_run())
self.assertEqual(
buffered_message,
"让我来检查一下:\n(调用了 1 次工具)\n已经拿到结果",
)
def test_non_verbose_tool_summary_aggregates_multiple_categories(self):
async def _run():
handler = StreamingHandler()
await handler.start_streaming()
handler.emit("处理中:")
handler.record_tool_call(
tool_name="search_web",
tool_message="搜索网络内容: MoviePilot",
tool_kwargs={"query": "MoviePilot"},
)
handler.record_tool_call(
tool_name="search_web",
tool_message="搜索网络内容: agent streaming",
tool_kwargs={"query": "agent streaming"},
)
handler.record_tool_call(
tool_name="read_file",
tool_message="读取文件: a.py",
tool_kwargs={"file_path": "/tmp/a.py"},
)
handler.record_tool_call(
tool_name="read_file",
tool_message="读取文件: b.py",
tool_kwargs={"file_path": "/tmp/b.py"},
)
handler.emit("继续分析")
return await handler.take()
buffered_message = asyncio.run(_run())
self.assertEqual(
buffered_message,
"处理中:\n(执行了 2 次搜索,读取了 2 个文件)\n继续分析",
)
def test_openai_streaming_handler_flushes_pending_summary_to_queue(self):
async def _run():
handler = _OpenAIStreamingHandler()
queue: asyncio.Queue = asyncio.Queue()
handler.bind_queue(queue)
await handler.start_streaming()
handler.record_tool_call(
tool_name="read_file",
tool_message="读取文件: app.py",
tool_kwargs={"file_path": "/tmp/app.py"},
)
emitted = handler.flush_pending_tool_summary()
queued = await queue.get()
buffered_message = await handler.take()
return emitted, queued, buffered_message
emitted, queued, buffered_message = asyncio.run(_run())
self.assertEqual(emitted, "(读取了 1 个文件)\n")
self.assertEqual(queued, emitted)
self.assertEqual(buffered_message, emitted)
def test_flush_sends_direct_message_via_threadpool(self):
handler = StreamingHandler()
handler._channel = MessageChannel.Telegram.value
handler._source = "telegram"
handler._user_id = "10001"
handler._username = "tester"
handler._streaming_enabled = True
handler.emit("hello")
with patch(
"app.agent.callback.run_in_threadpool", new_callable=AsyncMock
) as run_in_threadpool_mock:
run_in_threadpool_mock.return_value = MessageResponse(
message_id=1,
chat_id=2,
source="telegram",
success=True,
)
asyncio.run(handler._flush())
self.assertEqual(run_in_threadpool_mock.await_count, 1)
self.assertEqual(
run_in_threadpool_mock.await_args.args[0].__name__, "send_direct_message"
)
self.assertTrue(handler.has_sent_message)
def test_flush_edits_message_via_threadpool(self):
handler = StreamingHandler()
handler._channel = MessageChannel.Telegram.value
handler._streaming_enabled = True
handler._message_response = MessageResponse(
message_id=1,
chat_id=2,
source="telegram",
success=True,
)
handler._sent_text = "hello"
handler.emit("hello world")
with patch(
"app.agent.callback.run_in_threadpool", new_callable=AsyncMock
) as run_in_threadpool_mock:
run_in_threadpool_mock.return_value = True
asyncio.run(handler._flush())
self.assertEqual(run_in_threadpool_mock.await_count, 1)
self.assertEqual(
run_in_threadpool_mock.await_args.args[0].__name__, "edit_message"
)
self.assertEqual(handler._sent_text, "hello world")
if __name__ == "__main__":
unittest.main()