@dataclass
class _CommandOutput:
    """Keep a bounded in-memory preview of command output and spill the rest to disk.

    Output chunks are accumulated in ``preview_entries`` until their UTF-8
    size would exceed ``preview_limit_bytes``.  At that point a temporary log
    file is created, everything captured so far is replayed into it, and all
    later chunks are written to the file only.  The preview is frozen once
    the spill file exists.

    Attributes are documented inline below.  Call :meth:`close` when the
    producer is finished so the spill file is flushed and released; the file
    itself is created with ``delete=False`` and is NOT removed by this class.
    """

    # Maximum number of UTF-8 bytes kept in memory for the preview.
    preview_limit_bytes: int
    # Ordered ``(stream_name, text)`` pairs that make up the preview.
    preview_entries: list[tuple[str, str]] = field(default_factory=list)
    # UTF-8 size of everything currently stored in ``preview_entries``.
    captured_bytes: int = 0
    # True once at least one chunk did not fit into the preview.
    preview_truncated: bool = False
    # Path of the spill file; ``None`` until the preview overflows.
    temp_file_path: Optional[str] = None
    # Open handle of the spill file; ``None`` before spilling and after close().
    temp_file_handle: Optional[TextIO] = None
    # Stream whose section header was written last into the spill file.
    last_written_stream: Optional[str] = None
    # Optional cap on payload bytes written to the spill file.  ``None`` keeps
    # the file unbounded; callers that run untrusted or very chatty commands
    # should set it to avoid filling the disk.
    max_file_bytes: Optional[int] = None
    # Payload bytes (section headers excluded) written to the spill file.
    file_bytes_written: int = 0
    # True once ``max_file_bytes`` forced output to be dropped from the file.
    file_truncated: bool = False

    @staticmethod
    def _clip_text_to_bytes(text: str, byte_limit: int) -> str:
        """Return the longest prefix of *text* whose UTF-8 encoding fits *byte_limit*."""
        if byte_limit <= 0:
            return ""
        # errors="ignore" drops a multi-byte character that was cut in half
        # by the byte slice instead of raising.
        return text.encode("utf-8")[:byte_limit].decode("utf-8", errors="ignore")

    def _write_chunk(self, stream_name: str, text: str) -> None:
        """Append *text* to the spill file, emitting a section header on stream change.

        Does nothing when no spill file is open, when *text* is empty, or
        when the optional ``max_file_bytes`` cap has already been reached.
        """
        handle = self.temp_file_handle
        if handle is None or not text or self.file_truncated:
            return

        if self.max_file_bytes is not None:
            budget = self.max_file_bytes - self.file_bytes_written
            clipped = self._clip_text_to_bytes(text, budget)
            if len(clipped) < len(text):
                self.file_truncated = True
            text = clipped
            if not text:
                return

        if self.last_written_stream != stream_name:
            # ``last_written_stream`` is only set after a header was written,
            # so "not None" means the file already has content and needs a
            # separating newline (no need to query the file position).
            if self.last_written_stream is not None:
                handle.write("\n")
            title = "标准输出" if stream_name == "stdout" else "错误输出"
            handle.write(f"[{title}]\n")
            self.last_written_stream = stream_name

        handle.write(text)
        self.file_bytes_written += len(text.encode("utf-8"))

    def _ensure_temp_file(self) -> None:
        """Create the spill file once and replay the preview captured so far into it."""
        # Keyed on the path (not the handle) so that a closed collector never
        # creates a second file and orphans the first one.
        if self.temp_file_path is not None:
            return

        temp_file = NamedTemporaryFile(
            mode="w",
            encoding="utf-8",
            suffix=".log",
            prefix="moviepilot-command-",
            delete=False,
        )
        self.temp_file_path = temp_file.name
        self.temp_file_handle = temp_file
        for stream_name, chunk in self.preview_entries:
            self._write_chunk(stream_name, chunk)

    def close(self) -> None:
        """Flush and close the spill file, if any.  Safe to call repeatedly."""
        handle = self.temp_file_handle
        if handle is None:
            return
        # Drop our reference first so a failing close() cannot leave a stale
        # handle behind; close() flushes buffered data itself.
        self.temp_file_handle = None
        handle.close()

    def append(self, stream_name: str, text: str) -> None:
        """Record one decoded output chunk from *stream_name* ("stdout" or "stderr").

        While the preview has room the chunk is kept in memory.  The first
        chunk that does not fit triggers creation of the spill file; the part
        of it that still fits is added to the preview and the whole chunk is
        written to the file.  After that, chunks go to the file only.  Chunks
        appended after :meth:`close` are discarded.
        """
        if not text:
            return

        if self.temp_file_path is not None:
            self._write_chunk(stream_name, text)
            return

        chunk_bytes = len(text.encode("utf-8"))
        remaining = self.preview_limit_bytes - self.captured_bytes
        if chunk_bytes <= remaining:
            self.preview_entries.append((stream_name, text))
            self.captured_bytes += chunk_bytes
            return

        self.preview_truncated = True
        # Replay must happen before the clipped tail is added to the preview,
        # otherwise that tail would be written to the file twice.
        self._ensure_temp_file()
        self._write_chunk(stream_name, text)

        preview = self._clip_text_to_bytes(text, remaining)
        if preview:
            self.preview_entries.append((stream_name, preview))
            self.captured_bytes += len(preview.encode("utf-8"))

    @property
    def stdout(self) -> str:
        """Previewed standard-output text, stripped of surrounding whitespace."""
        return "".join(
            text for stream_name, text in self.preview_entries if stream_name == "stdout"
        ).strip()

    @property
    def stderr(self) -> str:
        """Previewed standard-error text, stripped of surrounding whitespace."""
        return "".join(
            text for stream_name, text in self.preview_entries if stream_name == "stderr"
        ).strip()
description: str = ( "Safely execute shell commands on the server. Useful for system " "maintenance, checking status, or running custom scripts. Includes " - "timeout, concurrency, and hard output limits." + "timeout, concurrency, and output preview limits." ) args_schema: Type[BaseModel] = ExecuteCommandInput require_admin: bool = True @@ -107,7 +161,7 @@ class ExecuteCommandTool(MoviePilotTool): @staticmethod def _subprocess_kwargs() -> dict: - """为子进程创建独立进程组,便于超时或输出过大时清理整棵子进程。""" + """为子进程创建独立进程组,便于超时场景清理整棵子进程。""" kwargs = { "stdin": subprocess.DEVNULL, "stdout": asyncio.subprocess.PIPE, @@ -124,23 +178,14 @@ class ExecuteCommandTool(MoviePilotTool): stream: asyncio.StreamReader, stream_name: str, output: _CommandOutput, - limit_reached: asyncio.Event, ) -> None: - """按块读取输出,达到上限后通知主流程终止命令。""" + """按块读取输出,始终只把前 10KB 保留在返回结果中。""" while True: chunk = await stream.read(READ_CHUNK_SIZE) if not chunk: break - if output.truncated: - limit_reached.set() - continue - output.append(stream_name, chunk.decode("utf-8", errors="replace")) - if output.truncated: - limit_reached.set() - # 达到上限后继续排空管道但不再保存内容,避免子进程因 pipe 反压卡住。 - continue @staticmethod def _terminate_process(process: asyncio.subprocess.Process, sig: int): @@ -205,27 +250,33 @@ class ExecuteCommandTool(MoviePilotTool): output: _CommandOutput, timeout: int, timed_out: bool, - output_limited: bool, timeout_note: Optional[str], ) -> str: if timed_out: result = f"命令执行超时 (限制: {timeout}秒,已终止进程)" - elif output_limited: - result = ( - f"命令输出超过限制 (限制: {MAX_OUTPUT_CHARS}字符," - f"已截断并终止进程,退出码: {exit_code})" - ) else: result = f"命令执行完成 (退出码: {exit_code})" if timeout_note: result += f"\n\n提示:\n{timeout_note}" + if output.temp_file_path: + file_note = ( + "截至命令终止前的完整输出" + if timed_out + else "完整输出" + ) + result += ( + "\n\n提示:\n" + f"命令输出超过 10KB,仅返回前 {MAX_OUTPUT_PREVIEW_BYTES} 字节内容。\n" + f"{file_note}已写入临时文件: {output.temp_file_path}\n" + "如需完整内容,请继续读取该文件。" + ) if output.stdout: result += f"\n\n标准输出:\n{output.stdout}" if 
output.stderr: result += f"\n\n错误输出:\n{output.stderr}" - if output.truncated: - result += "\n\n...(输出内容过长,已截断)" + if output.preview_truncated: + result += "\n\n...(仅展示前 10KB 内容)" if not output.stdout and not output.stderr: result += "\n\n(无输出内容)" return result @@ -252,51 +303,40 @@ class ExecuteCommandTool(MoviePilotTool): try: async with _command_semaphore: - # 命令输出可能非常大,必须边读边截断,不能使用 communicate() 一次性收集。 + # 命令输出可能非常大,必须边读边落盘,不能使用 communicate() 一次性收集。 process = await asyncio.create_subprocess_shell( command, **self._subprocess_kwargs() ) - output = _CommandOutput(limit=MAX_OUTPUT_CHARS) - limit_reached = asyncio.Event() + output = _CommandOutput(preview_limit_bytes=MAX_OUTPUT_PREVIEW_BYTES) wait_task = asyncio.create_task(process.wait()) - limit_task = asyncio.create_task(limit_reached.wait()) reader_tasks = [ asyncio.create_task( - self._read_stream( - process.stdout, "stdout", output, limit_reached - ) + self._read_stream(process.stdout, "stdout", output) ), asyncio.create_task( - self._read_stream( - process.stderr, "stderr", output, limit_reached - ) + self._read_stream(process.stderr, "stderr", output) ), ] timed_out = False - output_limited = False - done, _ = await asyncio.wait( - {wait_task, limit_task}, - timeout=normalized_timeout, - return_when=asyncio.FIRST_COMPLETED, - ) - - if wait_task not in done: - if limit_task in done: - output_limited = True - else: - timed_out = True + try: + await asyncio.wait_for( + asyncio.shield(wait_task), timeout=normalized_timeout + ) + except asyncio.TimeoutError: + timed_out = True await self._cleanup_process(process, wait_task) - limit_task.cancel() - await self._finish_reader_tasks(reader_tasks) + try: + await self._finish_reader_tasks(reader_tasks) + finally: + output.close() return self._format_result( exit_code=process.returncode, output=output, timeout=normalized_timeout, timed_out=timed_out, - output_limited=output_limited, timeout_note=timeout_note, ) diff --git a/requirements.in b/requirements.in index 
7a989ba8..8423f0ce 100644 --- a/requirements.in +++ b/requirements.in @@ -89,3 +89,4 @@ openai~=2.33.0 google-genai~=1.74.0 ddgs~=9.10.0 websocket-client~=1.8.0 +pytest~=8.4.0 diff --git a/tests/test_execute_command_tool.py b/tests/test_execute_command_tool.py index 82a3df9b..d23ab467 100644 --- a/tests/test_execute_command_tool.py +++ b/tests/test_execute_command_tool.py @@ -1,5 +1,6 @@ import asyncio import os +import re import shlex import subprocess import sys @@ -8,7 +9,7 @@ import unittest from app.agent.tools.impl.execute_command import ( ExecuteCommandTool, - MAX_OUTPUT_CHARS, + MAX_OUTPUT_PREVIEW_BYTES, ) @@ -21,6 +22,11 @@ def _python_command(code: str) -> str: class TestExecuteCommandTool(unittest.TestCase): + def _temp_file_path_from_result(self, result: str) -> str: + match = re.search(r"临时文件: (.+)", result) + self.assertIsNotNone(match) + return match.group(1).strip() + def _run_command(self, command: str, timeout: int = 60) -> str: tool = ExecuteCommandTool(session_id="session-1", user_id="10001") return asyncio.run(tool.run(command=command, timeout=timeout)) @@ -31,9 +37,19 @@ class TestExecuteCommandTool(unittest.TestCase): ) result = self._run_command(command) + temp_file_path = self._temp_file_path_from_result(result) - self.assertIn("输出内容过长,已截断", result) - self.assertLess(len(result), MAX_OUTPUT_CHARS + 500) + self.addCleanup(lambda: os.path.exists(temp_file_path) and os.unlink(temp_file_path)) + self.assertIn("命令输出超过 10KB", result) + self.assertIn("仅展示前 10KB 内容", result) + self.assertIn("如需完整内容,请继续读取该文件", result) + self.assertLess(len(result), MAX_OUTPUT_PREVIEW_BYTES + 600) + + with open(temp_file_path, encoding="utf-8") as file_handle: + file_content = file_handle.read() + + self.assertIn("[标准输出]", file_content) + self.assertGreater(len(file_content), 100000) def test_timeout_returns_partial_output_promptly(self): command = _python_command( @@ -48,6 +64,24 @@ class TestExecuteCommandTool(unittest.TestCase): self.assertIn("命令执行超时", result) 
self.assertIn("started", result) + def test_timeout_with_large_output_writes_partial_full_log_to_temp_file(self): + command = _python_command( + "import sys, time; sys.stdout.write('x' * 20000); sys.stdout.flush(); time.sleep(5)" + ) + + result = self._run_command(command, timeout=1) + temp_file_path = self._temp_file_path_from_result(result) + + self.addCleanup(lambda: os.path.exists(temp_file_path) and os.unlink(temp_file_path)) + self.assertIn("命令执行超时", result) + self.assertIn("截至命令终止前的完整输出已写入临时文件", result) + + with open(temp_file_path, encoding="utf-8") as file_handle: + file_content = file_handle.read() + + self.assertIn("[标准输出]", file_content) + self.assertGreaterEqual(file_content.count("x"), 20000) + def test_timeout_is_capped(self): command = _python_command("print('ok')")