revert execute_command streaming changes

Restore the previous subprocess handling for execute_command and drop the new command streaming test so agent startup is unblocked.
This commit is contained in:
jxxghp
2026-04-27 08:12:37 +08:00
parent 140d224a9a
commit 323844b26d
2 changed files with 29 additions and 222 deletions

View File

@@ -1,8 +1,7 @@
"""执行Shell命令工具"""
import asyncio
import codecs
from typing import Any, Dict, Optional, Type
from typing import Optional, Type
from pydantic import BaseModel, Field
@@ -27,133 +26,12 @@ class ExecuteCommandTool(MoviePilotTool):
description: str = "Safely execute shell commands on the server. Useful for system maintenance, checking status, or running custom scripts. Includes timeout and output limits."
args_schema: Type[BaseModel] = ExecuteCommandInput
require_admin: bool = True
RESULT_LIMIT = 3000
STREAM_CAPTURE_LIMIT = 2000
LIVE_OUTPUT_LIMIT = 1200
def get_tool_message(self, **kwargs) -> Optional[str]:
"""根据命令生成友好的提示消息"""
command = kwargs.get("command", "")
return f"执行系统命令: {command}"
def _build_result(
self,
message: str,
stdout_capture: Dict[str, Any],
stderr_capture: Dict[str, Any],
) -> str:
stdout_str = "".join(stdout_capture["chunks"]).strip()
stderr_str = "".join(stderr_capture["chunks"]).strip()
result = message
if stdout_str:
result += f"\n\n标准输出:\n{stdout_str}"
if stderr_str:
result += f"\n\n错误输出:\n{stderr_str}"
if not stdout_str and not stderr_str:
result += "\n\n(无输出内容)"
was_truncated = stdout_capture["truncated"] or stderr_capture["truncated"]
overflow_suffix = "\n\n...(输出内容过长,已截断)"
if was_truncated or len(result) > self.RESULT_LIMIT:
result = (
result[: self.RESULT_LIMIT - len(overflow_suffix)] + overflow_suffix
)
return result
def _append_capture(self, capture: Dict[str, Any], text: str):
if not text:
return
remaining = self.STREAM_CAPTURE_LIMIT - capture["length"]
if remaining <= 0:
capture["truncated"] = True
return
fragment = text[:remaining]
capture["chunks"].append(fragment)
capture["length"] += len(fragment)
if len(text) > remaining:
capture["truncated"] = True
def _should_emit_live_output(self) -> bool:
return bool(
self._stream_handler
and self._stream_handler.is_streaming
and self._stream_handler.is_auto_flushing
)
def _emit_live_output(
self, text: str, stream_name: str, live_state: Dict[str, Any]
):
if not text or not live_state["enabled"]:
return
header_key = f"{stream_name}_header_sent"
prefix = ""
if not live_state[header_key]:
prefix = "标准输出:\n" if stream_name == "stdout" else "\n错误输出:\n"
live_state[header_key] = True
payload = prefix + text
remaining = self.LIVE_OUTPUT_LIMIT - live_state["chars"]
if remaining <= 0:
if not live_state["truncated"]:
self._stream_handler.emit("\n...(命令输出过长,停止实时展示)\n")
live_state["truncated"] = True
return
fragment = payload[:remaining]
if fragment:
self._stream_handler.emit(fragment)
live_state["chars"] += len(fragment)
if len(payload) > remaining and not live_state["truncated"]:
self._stream_handler.emit("\n...(命令输出过长,停止实时展示)\n")
live_state["truncated"] = True
async def _collect_stream(
self,
stream: Optional[asyncio.StreamReader],
stream_name: str,
capture: Dict[str, Any],
live_state: Dict[str, Any],
):
if not stream:
return
decoder = codecs.getincrementaldecoder("utf-8")(errors="replace")
while True:
chunk = await stream.read(512)
if not chunk:
tail = decoder.decode(b"", final=True)
if tail:
self._append_capture(capture, tail)
self._emit_live_output(tail, stream_name, live_state)
return
text = decoder.decode(chunk)
if not text:
continue
self._append_capture(capture, text)
self._emit_live_output(text, stream_name, live_state)
@staticmethod
async def _terminate_process(process: asyncio.subprocess.Process):
if process.returncode is not None:
return
try:
process.kill()
except ProcessLookupError:
return
try:
await asyncio.wait_for(process.wait(), timeout=5)
except asyncio.TimeoutError:
logger.warning("终止命令进程超时")
async def run(self, command: str, timeout: Optional[int] = 60, **kwargs) -> str:
logger.info(
f"执行工具: {self.name}, 参数: command={command}, timeout={timeout}"
@@ -178,54 +56,40 @@ class ExecuteCommandTool(MoviePilotTool):
command, stdout=asyncio.subprocess.PIPE, stderr=asyncio.subprocess.PIPE
)
stdout_capture: Dict[str, Any] = {
"chunks": [],
"length": 0,
"truncated": False,
}
stderr_capture: Dict[str, Any] = {
"chunks": [],
"length": 0,
"truncated": False,
}
live_state: Dict[str, Any] = {
"enabled": self._should_emit_live_output(),
"chars": 0,
"truncated": False,
"stdout_header_sent": False,
"stderr_header_sent": False,
}
stdout_task = asyncio.create_task(
self._collect_stream(
process.stdout, "stdout", stdout_capture, live_state
)
)
stderr_task = asyncio.create_task(
self._collect_stream(
process.stderr, "stderr", stderr_capture, live_state
)
)
try:
# 等待完成,带超时
await asyncio.wait_for(process.wait(), timeout=timeout)
await asyncio.gather(stdout_task, stderr_task)
return self._build_result(
f"命令执行完成 (退出码: {process.returncode})",
stdout_capture,
stderr_capture,
stdout, stderr = await asyncio.wait_for(
process.communicate(), timeout=timeout
)
# 处理输出
stdout_str = stdout.decode("utf-8", errors="replace").strip()
stderr_str = stderr.decode("utf-8", errors="replace").strip()
exit_code = process.returncode
result = f"命令执行完成 (退出码: {exit_code})"
if stdout_str:
result += f"\n\n标准输出:\n{stdout_str}"
if stderr_str:
result += f"\n\n错误输出:\n{stderr_str}"
# 如果没有输出
if not stdout_str and not stderr_str:
result += "\n\n(无输出内容)"
# 限制输出长度,防止上下文过长
if len(result) > 3000:
result = result[:3000] + "\n\n...(输出内容过长,已截断)"
return result
except asyncio.TimeoutError:
# 超时处理
await self._terminate_process(process)
await asyncio.gather(stdout_task, stderr_task)
return self._build_result(
f"命令执行超时 (限制: {timeout}秒)",
stdout_capture,
stderr_capture,
)
try:
process.kill()
except ProcessLookupError:
pass
return f"命令执行超时 (限制: {timeout}秒)"
except Exception as e:
logger.error(f"执行命令失败: {e}", exc_info=True)

View File

@@ -1,57 +0,0 @@
import asyncio
import shlex
import sys
import unittest
import langchain.agents as langchain_agents
if not hasattr(langchain_agents, "create_agent"):
langchain_agents.create_agent = lambda *args, **kwargs: None
from app.agent.callback import StreamingHandler
from app.agent.tools.impl.execute_command import ExecuteCommandTool
class TestExecuteCommandTool(unittest.TestCase):
@staticmethod
def _build_python_command(script: str) -> str:
return f"{shlex.quote(sys.executable)} -c '{script}'"
@staticmethod
def _build_streaming_tool() -> tuple[ExecuteCommandTool, StreamingHandler]:
tool = ExecuteCommandTool(session_id="session-1", user_id="10001")
handler = StreamingHandler()
handler._streaming_enabled = True
handler._flush_task = object()
tool.set_stream_handler(handler)
return tool, handler
def test_run_streams_live_output_and_collects_result(self):
tool, handler = self._build_streaming_tool()
command = self._build_python_command(
'import sys; print("out"); print("err", file=sys.stderr)'
)
result = asyncio.run(tool.run(command=command, timeout=5))
live_output = asyncio.run(handler.take())
self.assertIn("命令执行完成 (退出码: 0)", result)
self.assertIn("标准输出:\nout", result)
self.assertIn("错误输出:\nerr", result)
self.assertIn("标准输出:\nout", live_output)
self.assertIn("错误输出:\nerr", live_output)
def test_run_timeout_keeps_partial_output(self):
tool = ExecuteCommandTool(session_id="session-1", user_id="10001")
command = self._build_python_command(
'import sys,time; print("start"); sys.stdout.flush(); time.sleep(0.2)'
)
result = asyncio.run(tool.run(command=command, timeout=0.05))
self.assertIn("命令执行超时", result)
self.assertIn("标准输出:\nstart", result)
if __name__ == "__main__":
unittest.main()