fix agent stream blocking during command execution

Offload synchronous message edits from the event loop and stream subprocess output so long-running commands stay responsive.
This commit is contained in:
jxxghp
2026-04-27 07:57:32 +08:00
parent 7bc032d17c
commit 140d224a9a
4 changed files with 293 additions and 35 deletions

View File

@@ -1,10 +1,17 @@
import asyncio
import unittest
from unittest.mock import patch
from unittest.mock import AsyncMock, patch
import langchain.agents as langchain_agents
if not hasattr(langchain_agents, "create_agent"):
langchain_agents.create_agent = lambda *args, **kwargs: None
from app.agent.callback import StreamingHandler
from app.agent.tools.base import MoviePilotTool
from app.core.config import settings
from app.schemas.message import MessageResponse
from app.schemas.types import MessageChannel
class DummyTool(MoviePilotTool):
@@ -48,6 +55,59 @@ class TestAgentToolStreaming(unittest.TestCase):
self.assertEqual(result, "ok")
self.assertEqual(buffered_message, "")
def test_flush_sends_direct_message_via_threadpool(self):
handler = StreamingHandler()
handler._channel = MessageChannel.Telegram.value
handler._source = "telegram"
handler._user_id = "10001"
handler._username = "tester"
handler._streaming_enabled = True
handler.emit("hello")
with patch(
"app.agent.callback.run_in_threadpool", new_callable=AsyncMock
) as run_in_threadpool_mock:
run_in_threadpool_mock.return_value = MessageResponse(
message_id=1,
chat_id=2,
source="telegram",
success=True,
)
asyncio.run(handler._flush())
self.assertEqual(run_in_threadpool_mock.await_count, 1)
self.assertEqual(
run_in_threadpool_mock.await_args.args[0].__name__, "send_direct_message"
)
self.assertTrue(handler.has_sent_message)
def test_flush_edits_message_via_threadpool(self):
handler = StreamingHandler()
handler._channel = MessageChannel.Telegram.value
handler._streaming_enabled = True
handler._message_response = MessageResponse(
message_id=1,
chat_id=2,
source="telegram",
success=True,
)
handler._sent_text = "hello"
handler.emit("hello world")
with patch(
"app.agent.callback.run_in_threadpool", new_callable=AsyncMock
) as run_in_threadpool_mock:
run_in_threadpool_mock.return_value = True
asyncio.run(handler._flush())
self.assertEqual(run_in_threadpool_mock.await_count, 1)
self.assertEqual(
run_in_threadpool_mock.await_args.args[0].__name__, "edit_message"
)
self.assertEqual(handler._sent_text, "hello world")
if __name__ == "__main__":
unittest.main()

View File

@@ -0,0 +1,57 @@
import asyncio
import shlex
import sys
import unittest
import langchain.agents as langchain_agents
if not hasattr(langchain_agents, "create_agent"):
langchain_agents.create_agent = lambda *args, **kwargs: None
from app.agent.callback import StreamingHandler
from app.agent.tools.impl.execute_command import ExecuteCommandTool
class TestExecuteCommandTool(unittest.TestCase):
@staticmethod
def _build_python_command(script: str) -> str:
return f"{shlex.quote(sys.executable)} -c '{script}'"
@staticmethod
def _build_streaming_tool() -> tuple[ExecuteCommandTool, StreamingHandler]:
tool = ExecuteCommandTool(session_id="session-1", user_id="10001")
handler = StreamingHandler()
handler._streaming_enabled = True
handler._flush_task = object()
tool.set_stream_handler(handler)
return tool, handler
def test_run_streams_live_output_and_collects_result(self):
tool, handler = self._build_streaming_tool()
command = self._build_python_command(
'import sys; print("out"); print("err", file=sys.stderr)'
)
result = asyncio.run(tool.run(command=command, timeout=5))
live_output = asyncio.run(handler.take())
self.assertIn("命令执行完成 (退出码: 0)", result)
self.assertIn("标准输出:\nout", result)
self.assertIn("错误输出:\nerr", result)
self.assertIn("标准输出:\nout", live_output)
self.assertIn("错误输出:\nerr", live_output)
def test_run_timeout_keeps_partial_output(self):
tool = ExecuteCommandTool(session_id="session-1", user_id="10001")
command = self._build_python_command(
'import sys,time; print("start"); sys.stdout.flush(); time.sleep(0.2)'
)
result = asyncio.run(tool.run(command=command, timeout=0.05))
self.assertIn("命令执行超时", result)
self.assertIn("标准输出:\nstart", result)
if __name__ == "__main__":
unittest.main()