Files
MoviePilot/app/agent/tools/impl/execute_command.py
jxxghp 140d224a9a fix agent stream blocking during command execution
Offload synchronous message edits from the event loop and stream subprocess output so long-running commands stay responsive.
2026-04-27 07:57:32 +08:00

233 lines
7.7 KiB
Python

"""执行Shell命令工具"""
import asyncio
import codecs
import os
import signal
from typing import Any, Dict, List, Optional, Type

from pydantic import BaseModel, Field

from app.agent.tools.base import MoviePilotTool
from app.log import logger
class ExecuteCommandInput(BaseModel):
    """Arguments accepted by the ``execute_command`` tool."""

    # Free-text justification supplied by the caller for running the command.
    explanation: str = Field(
        ...,
        description="Clear explanation of why this command is being executed",
    )
    # The shell command line to run.
    command: str = Field(..., description="The shell command to execute")
    # Upper bound on run time, in seconds.
    timeout: Optional[int] = Field(
        default=60,
        description="Max execution time in seconds (default: 60)",
    )
class ExecuteCommandTool(MoviePilotTool):
    """Agent tool that runs a shell command on the server (admin only).

    stdout/stderr are captured (bounded) for the final result and, when the
    agent's stream handler is actively auto-flushing, echoed incrementally so
    long-running commands show progress.
    """

    name: str = "execute_command"
    description: str = "Safely execute shell commands on the server. Useful for system maintenance, checking status, or running custom scripts. Includes timeout and output limits."
    args_schema: Type[BaseModel] = ExecuteCommandInput
    require_admin: bool = True

    # Max length (chars) of the string returned to the agent.
    RESULT_LIMIT = 3000
    # Max chars kept per stream (stdout / stderr) for the final result.
    STREAM_CAPTURE_LIMIT = 2000
    # Max chars echoed live, shared by both streams (headers included).
    LIVE_OUTPUT_LIMIT = 1200
    # Timeout used when the caller explicitly passes ``timeout=None``.
    DEFAULT_TIMEOUT = 60
    # Seconds to wait for stdout/stderr to reach EOF once the shell has exited
    # or been killed. A background child that inherited the pipes would
    # otherwise keep the readers (and this tool) blocked indefinitely.
    STREAM_DRAIN_TIMEOUT = 5

    def get_tool_message(self, **kwargs) -> Optional[str]:
        """Build the user-facing notice shown when this tool is invoked."""
        command = kwargs.get("command", "")
        return f"执行系统命令: {command}"

    def _build_result(
        self,
        message: str,
        stdout_capture: Dict[str, Any],
        stderr_capture: Dict[str, Any],
    ) -> str:
        """Format the final tool result from a status line and captured output.

        The result is capped at ``RESULT_LIMIT`` characters. A truncation note
        is appended whenever either stream overflowed its capture buffer or the
        combined text is too long.
        """
        stdout_str = "".join(stdout_capture["chunks"]).strip()
        stderr_str = "".join(stderr_capture["chunks"]).strip()

        result = message
        if stdout_str:
            result += f"\n\n标准输出:\n{stdout_str}"
        if stderr_str:
            result += f"\n\n错误输出:\n{stderr_str}"
        if not stdout_str and not stderr_str:
            result += "\n\n(无输出内容)"

        was_truncated = stdout_capture["truncated"] or stderr_capture["truncated"]
        overflow_suffix = "\n\n...(输出内容过长,已截断)"
        if was_truncated or len(result) > self.RESULT_LIMIT:
            result = (
                result[: self.RESULT_LIMIT - len(overflow_suffix)] + overflow_suffix
            )
        return result

    @staticmethod
    def _new_capture() -> Dict[str, Any]:
        """Return an empty bounded-capture buffer for one output stream."""
        return {"chunks": [], "length": 0, "truncated": False}

    def _append_capture(self, capture: Dict[str, Any], text: str):
        """Append ``text`` to ``capture`` up to ``STREAM_CAPTURE_LIMIT`` chars.

        Anything beyond the limit is dropped and ``capture["truncated"]`` is set.
        """
        if not text:
            return
        remaining = self.STREAM_CAPTURE_LIMIT - capture["length"]
        if remaining <= 0:
            capture["truncated"] = True
            return
        fragment = text[:remaining]
        capture["chunks"].append(fragment)
        capture["length"] += len(fragment)
        if len(text) > remaining:
            capture["truncated"] = True

    def _should_emit_live_output(self) -> bool:
        """True when a stream handler exists and is streaming with auto-flush.

        ``_stream_handler`` is presumably provided by ``MoviePilotTool``.
        """
        return bool(
            self._stream_handler
            and self._stream_handler.is_streaming
            and self._stream_handler.is_auto_flushing
        )

    def _emit_live_output(
        self, text: str, stream_name: str, live_state: Dict[str, Any]
    ):
        """Echo ``text`` through the stream handler, within ``LIVE_OUTPUT_LIMIT``.

        Each stream's first output is prefixed with a header. Once the shared
        character budget is exhausted, a single notice is emitted and further
        output is no longer shown live. The final result still gets it, since
        it is captured separately.
        """
        if not text or not live_state["enabled"]:
            return
        header_key = f"{stream_name}_header_sent"
        prefix = ""
        if not live_state[header_key]:
            prefix = "标准输出:\n" if stream_name == "stdout" else "\n错误输出:\n"
            live_state[header_key] = True
        payload = prefix + text
        remaining = self.LIVE_OUTPUT_LIMIT - live_state["chars"]
        if remaining <= 0:
            if not live_state["truncated"]:
                self._stream_handler.emit("\n...(命令输出过长,停止实时展示)\n")
                live_state["truncated"] = True
            return
        fragment = payload[:remaining]
        if fragment:
            self._stream_handler.emit(fragment)
            live_state["chars"] += len(fragment)
        if len(payload) > remaining and not live_state["truncated"]:
            self._stream_handler.emit("\n...(命令输出过长,停止实时展示)\n")
            live_state["truncated"] = True

    def _consume_text(
        self,
        text: str,
        stream_name: str,
        capture: Dict[str, Any],
        live_state: Dict[str, Any],
    ):
        """Capture a decoded chunk and echo it live on a best-effort basis."""
        self._append_capture(capture, text)
        try:
            self._emit_live_output(text, stream_name, live_state)
        except Exception as e:
            # Live display is cosmetic: a failing handler must not abort output
            # collection (which would also abort the whole command report).
            logger.warning(f"实时展示命令输出失败: {e}")
            live_state["enabled"] = False

    async def _collect_stream(
        self,
        stream: Optional[asyncio.StreamReader],
        stream_name: str,
        capture: Dict[str, Any],
        live_state: Dict[str, Any],
    ):
        """Read ``stream`` until EOF, capturing and live-echoing decoded text.

        Bytes are decoded incrementally as UTF-8 (invalid sequences replaced),
        so a multi-byte character split across two reads is not mangled.
        """
        if stream is None:
            return
        decoder = codecs.getincrementaldecoder("utf-8")(errors="replace")
        while True:
            chunk = await stream.read(512)
            at_eof = not chunk
            # At EOF, flush whatever bytes the decoder is still holding.
            text = decoder.decode(chunk, final=at_eof)
            if text:
                self._consume_text(text, stream_name, capture, live_state)
            if at_eof:
                return

    async def _drain_streams(self, tasks: List[asyncio.Task]):
        """Wait (bounded) for the stream readers to reach EOF.

        Readers still running after ``STREAM_DRAIN_TIMEOUT`` seconds are
        cancelled. Typically a background process inherited the pipes and
        keeps them open. Output captured so far is kept. An exception raised
        by a finished reader is re-raised.
        """
        if not tasks:
            return
        done, pending = await asyncio.wait(tasks, timeout=self.STREAM_DRAIN_TIMEOUT)
        if pending:
            logger.warning(
                "命令输出流在限定时间内未结束,已停止读取(可能有后台进程仍占用输出管道)"
            )
            for task in pending:
                task.cancel()
            await asyncio.gather(*pending, return_exceptions=True)
        for task in done:
            if not task.cancelled():
                exc = task.exception()
                if exc is not None:
                    raise exc

    @staticmethod
    async def _terminate_process(process: asyncio.subprocess.Process):
        """Kill the command's whole process group (POSIX) and reap the shell.

        The shell is started with ``start_new_session=True``, so its pid is
        also its process-group id. POSIX does not reuse a pid while a process
        group with that id still exists. That makes signalling the group safe
        even after the shell has exited while children it spawned live on.
        Waiting for the shell is bounded to 5 seconds.
        """
        if hasattr(os, "killpg"):
            try:
                os.killpg(process.pid, signal.SIGKILL)
            except OSError:
                # Group already gone (ProcessLookupError) or not ours to signal.
                pass
        if process.returncode is None:
            try:
                process.kill()
            except ProcessLookupError:
                pass
        try:
            await asyncio.wait_for(process.wait(), timeout=5)
        except asyncio.TimeoutError:
            logger.warning("终止命令进程超时")

    @staticmethod
    def _find_forbidden_keyword(command: str) -> Optional[str]:
        """Return the first denylisted keyword found in ``command``, if any.

        The check also runs against a whitespace-collapsed copy of the command,
        so padding such as ``rm  -rf  /`` does not slip past it. This is a
        coarse guard against accidents, not a security boundary.
        """
        forbidden_keywords = (
            "rm -rf /",
            ":(){ :|:& };:",
            "dd if=/dev/zero",
            "mkfs",
            "reboot",
            "shutdown",
        )
        collapsed = " ".join(command.split())
        for keyword in forbidden_keywords:
            if keyword in command or keyword in collapsed:
                return keyword
        return None

    async def run(self, command: str, timeout: Optional[int] = 60, **kwargs) -> str:
        """Run ``command`` through the system shell and return a text report.

        :param command: shell command line, passed verbatim to the shell.
        :param timeout: seconds before the command is killed. On POSIX this
            also kills every process it spawned. ``None`` falls back to
            ``DEFAULT_TIMEOUT``.
        :return: a status line plus captured stdout/stderr, or an error message.
            Failures are reported as strings rather than raised.
        """
        if timeout is None:
            # An explicit null from the agent must not mean "run forever".
            timeout = self.DEFAULT_TIMEOUT
        logger.info(
            f"执行工具: {self.name}, 参数: command={command}, timeout={timeout}"
        )

        forbidden = self._find_forbidden_keyword(command)
        if forbidden:
            return f"错误:命令包含禁止使用的关键字 '{forbidden}'"

        process: Optional[asyncio.subprocess.Process] = None
        stream_tasks: List[asyncio.Task] = []
        timed_out = False
        try:
            process = await asyncio.create_subprocess_shell(
                command,
                # Nobody can answer interactive prompts; fail fast instead of
                # waiting on (or reading from) the server's own stdin.
                stdin=asyncio.subprocess.DEVNULL,
                stdout=asyncio.subprocess.PIPE,
                stderr=asyncio.subprocess.PIPE,
                # Make the shell the leader of a new session/process group so a
                # timeout can kill everything it spawned, not only the shell.
                # Ignored on Windows.
                start_new_session=True,
            )
            stdout_capture = self._new_capture()
            stderr_capture = self._new_capture()
            live_state: Dict[str, Any] = {
                "enabled": self._should_emit_live_output(),
                "chars": 0,
                "truncated": False,
                "stdout_header_sent": False,
                "stderr_header_sent": False,
            }
            # Read both pipes concurrently so a full pipe buffer can never
            # stall the child.
            stream_tasks = [
                asyncio.create_task(
                    self._collect_stream(
                        process.stdout, "stdout", stdout_capture, live_state
                    )
                ),
                asyncio.create_task(
                    self._collect_stream(
                        process.stderr, "stderr", stderr_capture, live_state
                    )
                ),
            ]

            try:
                # Wait for the shell to exit, bounded by the timeout.
                await asyncio.wait_for(process.wait(), timeout=timeout)
            except asyncio.TimeoutError:
                timed_out = True
                await self._terminate_process(process)

            # Bounded: never block on pipes held open by leftover children.
            await self._drain_streams(stream_tasks)

            if timed_out:
                message = f"命令执行超时 (限制: {timeout}秒)"
            else:
                message = f"命令执行完成 (退出码: {process.returncode})"
            return self._build_result(message, stdout_capture, stderr_capture)
        except Exception as e:
            logger.error(f"执行命令失败: {e}", exc_info=True)
            return f"执行命令时发生错误: {str(e)}"
        finally:
            # Runs on success, error and cancellation alike: never leave the
            # reader tasks or a still-running command behind.
            for task in stream_tasks:
                if not task.done():
                    task.cancel()
            if process is not None and process.returncode is None and not timed_out:
                await self._terminate_process(process)