fix(agent): preserve full command output in temp logs

Return only a 10KB preview to the agent so large command results do not overwhelm conversations while keeping the full output available for follow-up reads. Add pytest to the project dependencies to make the regression tests runnable in the project venv.
This commit is contained in:
jxxghp
2026-05-06 20:04:17 +08:00
parent caf615f3bd
commit c762628217
3 changed files with 144 additions and 69 deletions

View File

@@ -5,7 +5,8 @@ import os
import signal
import subprocess
from dataclasses import dataclass, field
from typing import Optional, Type
from tempfile import NamedTemporaryFile
from typing import Optional, TextIO, Type
from pydantic import BaseModel, Field
@@ -15,7 +16,7 @@ from app.log import logger
DEFAULT_TIMEOUT_SECONDS = 60
MAX_TIMEOUT_SECONDS = 300
MAX_OUTPUT_CHARS = 6000
MAX_OUTPUT_PREVIEW_BYTES = 10 * 1024
READ_CHUNK_SIZE = 4096
KILL_GRACE_SECONDS = 3
COMMAND_CONCURRENCY_LIMIT = 2
@@ -25,40 +26,93 @@ _command_semaphore = asyncio.Semaphore(COMMAND_CONCURRENCY_LIMIT)
@dataclass
class _CommandOutput:
"""保存受限命令输出,避免大输出一次性进入内存"""
"""保存前 10KB 预览,并在超限时将完整输出写入临时文件"""
limit: int
stdout_chunks: list[str] = field(default_factory=list)
stderr_chunks: list[str] = field(default_factory=list)
captured_chars: int = 0
truncated: bool = False
preview_limit_bytes: int
preview_entries: list[tuple[str, str]] = field(default_factory=list)
captured_bytes: int = 0
preview_truncated: bool = False
temp_file_path: Optional[str] = None
temp_file_handle: Optional[TextIO] = None
last_written_stream: Optional[str] = None
@staticmethod
def _clip_text_to_bytes(text: str, byte_limit: int) -> str:
if byte_limit <= 0:
return ""
return text.encode("utf-8")[:byte_limit].decode("utf-8", errors="ignore")
def _write_chunk(self, stream_name: str, text: str) -> None:
if not self.temp_file_handle or not text:
return
if self.last_written_stream != stream_name:
if self.temp_file_handle.tell() > 0:
self.temp_file_handle.write("\n")
title = "标准输出" if stream_name == "stdout" else "错误输出"
self.temp_file_handle.write(f"[{title}]\n")
self.last_written_stream = stream_name
self.temp_file_handle.write(text)
def _ensure_temp_file(self) -> None:
if self.temp_file_handle:
return
temp_file = NamedTemporaryFile(
mode="w",
encoding="utf-8",
suffix=".log",
prefix="moviepilot-command-",
delete=False,
)
self.temp_file_path = temp_file.name
self.temp_file_handle = temp_file
for stream_name, chunk in self.preview_entries:
self._write_chunk(stream_name, chunk)
def close(self) -> None:
if not self.temp_file_handle:
return
self.temp_file_handle.flush()
self.temp_file_handle.close()
self.temp_file_handle = None
def append(self, stream_name: str, text: str) -> None:
if not text:
return
remaining = self.limit - self.captured_chars
if remaining <= 0:
self.truncated = True
if self.temp_file_handle:
self._write_chunk(stream_name, text)
return
captured = text[:remaining]
if stream_name == "stdout":
self.stdout_chunks.append(captured)
else:
self.stderr_chunks.append(captured)
chunk_bytes = len(text.encode("utf-8"))
remaining = self.preview_limit_bytes - self.captured_bytes
if chunk_bytes <= remaining:
self.preview_entries.append((stream_name, text))
self.captured_bytes += chunk_bytes
return
self.captured_chars += len(captured)
if len(text) > remaining:
self.truncated = True
self.preview_truncated = True
self._ensure_temp_file()
self._write_chunk(stream_name, text)
preview = self._clip_text_to_bytes(text, remaining)
if preview:
self.preview_entries.append((stream_name, preview))
self.captured_bytes += len(preview.encode("utf-8"))
@property
def stdout(self) -> str:
return "".join(self.stdout_chunks).strip()
return "".join(
text for stream_name, text in self.preview_entries if stream_name == "stdout"
).strip()
@property
def stderr(self) -> str:
return "".join(self.stderr_chunks).strip()
return "".join(
text for stream_name, text in self.preview_entries if stream_name == "stderr"
).strip()
class ExecuteCommandInput(BaseModel):
@@ -78,7 +132,7 @@ class ExecuteCommandTool(MoviePilotTool):
description: str = (
"Safely execute shell commands on the server. Useful for system "
"maintenance, checking status, or running custom scripts. Includes "
"timeout, concurrency, and hard output limits."
"timeout, concurrency, and output preview limits."
)
args_schema: Type[BaseModel] = ExecuteCommandInput
require_admin: bool = True
@@ -107,7 +161,7 @@ class ExecuteCommandTool(MoviePilotTool):
@staticmethod
def _subprocess_kwargs() -> dict:
"""为子进程创建独立进程组,便于超时或输出过大时清理整棵子进程。"""
"""为子进程创建独立进程组,便于超时场景清理整棵子进程。"""
kwargs = {
"stdin": subprocess.DEVNULL,
"stdout": asyncio.subprocess.PIPE,
@@ -124,23 +178,14 @@ class ExecuteCommandTool(MoviePilotTool):
stream: asyncio.StreamReader,
stream_name: str,
output: _CommandOutput,
limit_reached: asyncio.Event,
) -> None:
"""按块读取输出,达到上限后通知主流程终止命令"""
"""按块读取输出,始终只把前 10KB 保留在返回结果中"""
while True:
chunk = await stream.read(READ_CHUNK_SIZE)
if not chunk:
break
if output.truncated:
limit_reached.set()
continue
output.append(stream_name, chunk.decode("utf-8", errors="replace"))
if output.truncated:
limit_reached.set()
# 达到上限后继续排空管道但不再保存内容,避免子进程因 pipe 反压卡住。
continue
@staticmethod
def _terminate_process(process: asyncio.subprocess.Process, sig: int):
@@ -205,27 +250,33 @@ class ExecuteCommandTool(MoviePilotTool):
output: _CommandOutput,
timeout: int,
timed_out: bool,
output_limited: bool,
timeout_note: Optional[str],
) -> str:
if timed_out:
result = f"命令执行超时 (限制: {timeout}秒,已终止进程)"
elif output_limited:
result = (
f"命令输出超过限制 (限制: {MAX_OUTPUT_CHARS}字符,"
f"已截断并终止进程,退出码: {exit_code})"
)
else:
result = f"命令执行完成 (退出码: {exit_code})"
if timeout_note:
result += f"\n\n提示:\n{timeout_note}"
if output.temp_file_path:
file_note = (
"截至命令终止前的完整输出"
if timed_out
else "完整输出"
)
result += (
"\n\n提示:\n"
f"命令输出超过 10KB仅返回前 {MAX_OUTPUT_PREVIEW_BYTES} 字节内容。\n"
f"{file_note}已写入临时文件: {output.temp_file_path}\n"
"如需完整内容,请继续读取该文件。"
)
if output.stdout:
result += f"\n\n标准输出:\n{output.stdout}"
if output.stderr:
result += f"\n\n错误输出:\n{output.stderr}"
if output.truncated:
result += "\n\n...(输出内容过长,已截断)"
if output.preview_truncated:
result += "\n\n...(仅展示前 10KB 内容)"
if not output.stdout and not output.stderr:
result += "\n\n(无输出内容)"
return result
@@ -252,51 +303,40 @@ class ExecuteCommandTool(MoviePilotTool):
try:
async with _command_semaphore:
# 命令输出可能非常大,必须边读边截断,不能使用 communicate() 一次性收集。
# 命令输出可能非常大,必须边读边落盘,不能使用 communicate() 一次性收集。
process = await asyncio.create_subprocess_shell(
command, **self._subprocess_kwargs()
)
output = _CommandOutput(limit=MAX_OUTPUT_CHARS)
limit_reached = asyncio.Event()
output = _CommandOutput(preview_limit_bytes=MAX_OUTPUT_PREVIEW_BYTES)
wait_task = asyncio.create_task(process.wait())
limit_task = asyncio.create_task(limit_reached.wait())
reader_tasks = [
asyncio.create_task(
self._read_stream(
process.stdout, "stdout", output, limit_reached
)
self._read_stream(process.stdout, "stdout", output)
),
asyncio.create_task(
self._read_stream(
process.stderr, "stderr", output, limit_reached
)
self._read_stream(process.stderr, "stderr", output)
),
]
timed_out = False
output_limited = False
done, _ = await asyncio.wait(
{wait_task, limit_task},
timeout=normalized_timeout,
return_when=asyncio.FIRST_COMPLETED,
)
if wait_task not in done:
if limit_task in done:
output_limited = True
else:
timed_out = True
try:
await asyncio.wait_for(
asyncio.shield(wait_task), timeout=normalized_timeout
)
except asyncio.TimeoutError:
timed_out = True
await self._cleanup_process(process, wait_task)
limit_task.cancel()
await self._finish_reader_tasks(reader_tasks)
try:
await self._finish_reader_tasks(reader_tasks)
finally:
output.close()
return self._format_result(
exit_code=process.returncode,
output=output,
timeout=normalized_timeout,
timed_out=timed_out,
output_limited=output_limited,
timeout_note=timeout_note,
)

View File

@@ -89,3 +89,4 @@ openai~=2.33.0
google-genai~=1.74.0
ddgs~=9.10.0
websocket-client~=1.8.0
pytest~=8.4.0

View File

@@ -1,5 +1,6 @@
import asyncio
import os
import re
import shlex
import subprocess
import sys
@@ -8,7 +9,7 @@ import unittest
from app.agent.tools.impl.execute_command import (
ExecuteCommandTool,
MAX_OUTPUT_CHARS,
MAX_OUTPUT_PREVIEW_BYTES,
)
@@ -21,6 +22,11 @@ def _python_command(code: str) -> str:
class TestExecuteCommandTool(unittest.TestCase):
def _temp_file_path_from_result(self, result: str) -> str:
match = re.search(r"临时文件: (.+)", result)
self.assertIsNotNone(match)
return match.group(1).strip()
def _run_command(self, command: str, timeout: int = 60) -> str:
tool = ExecuteCommandTool(session_id="session-1", user_id="10001")
return asyncio.run(tool.run(command=command, timeout=timeout))
@@ -31,9 +37,19 @@ class TestExecuteCommandTool(unittest.TestCase):
)
result = self._run_command(command)
temp_file_path = self._temp_file_path_from_result(result)
self.assertIn("输出内容过长,已截断", result)
self.assertLess(len(result), MAX_OUTPUT_CHARS + 500)
self.addCleanup(lambda: os.path.exists(temp_file_path) and os.unlink(temp_file_path))
self.assertIn("命令输出超过 10KB", result)
self.assertIn("仅展示前 10KB 内容", result)
self.assertIn("如需完整内容,请继续读取该文件", result)
self.assertLess(len(result), MAX_OUTPUT_PREVIEW_BYTES + 600)
with open(temp_file_path, encoding="utf-8") as file_handle:
file_content = file_handle.read()
self.assertIn("[标准输出]", file_content)
self.assertGreater(len(file_content), 100000)
def test_timeout_returns_partial_output_promptly(self):
command = _python_command(
@@ -48,6 +64,24 @@ class TestExecuteCommandTool(unittest.TestCase):
self.assertIn("命令执行超时", result)
self.assertIn("started", result)
def test_timeout_with_large_output_writes_partial_full_log_to_temp_file(self):
command = _python_command(
"import sys, time; sys.stdout.write('x' * 20000); sys.stdout.flush(); time.sleep(5)"
)
result = self._run_command(command, timeout=1)
temp_file_path = self._temp_file_path_from_result(result)
self.addCleanup(lambda: os.path.exists(temp_file_path) and os.unlink(temp_file_path))
self.assertIn("命令执行超时", result)
self.assertIn("截至命令终止前的完整输出已写入临时文件", result)
with open(temp_file_path, encoding="utf-8") as file_handle:
file_content = file_handle.read()
self.assertIn("[标准输出]", file_content)
self.assertGreaterEqual(file_content.count("x"), 20000)
def test_timeout_is_capped(self):
command = _python_command("print('ok')")