diff --git a/app/agent/tools/impl/execute_command.py b/app/agent/tools/impl/execute_command.py index c05bbc0f..a5cfb4b1 100644 --- a/app/agent/tools/impl/execute_command.py +++ b/app/agent/tools/impl/execute_command.py @@ -1,8 +1,7 @@ """执行Shell命令工具""" import asyncio -import codecs -from typing import Any, Dict, Optional, Type +from typing import Optional, Type from pydantic import BaseModel, Field @@ -27,133 +26,12 @@ class ExecuteCommandTool(MoviePilotTool): description: str = "Safely execute shell commands on the server. Useful for system maintenance, checking status, or running custom scripts. Includes timeout and output limits." args_schema: Type[BaseModel] = ExecuteCommandInput require_admin: bool = True - RESULT_LIMIT = 3000 - STREAM_CAPTURE_LIMIT = 2000 - LIVE_OUTPUT_LIMIT = 1200 def get_tool_message(self, **kwargs) -> Optional[str]: """根据命令生成友好的提示消息""" command = kwargs.get("command", "") return f"执行系统命令: {command}" - def _build_result( - self, - message: str, - stdout_capture: Dict[str, Any], - stderr_capture: Dict[str, Any], - ) -> str: - stdout_str = "".join(stdout_capture["chunks"]).strip() - stderr_str = "".join(stderr_capture["chunks"]).strip() - - result = message - if stdout_str: - result += f"\n\n标准输出:\n{stdout_str}" - if stderr_str: - result += f"\n\n错误输出:\n{stderr_str}" - if not stdout_str and not stderr_str: - result += "\n\n(无输出内容)" - - was_truncated = stdout_capture["truncated"] or stderr_capture["truncated"] - overflow_suffix = "\n\n...(输出内容过长,已截断)" - if was_truncated or len(result) > self.RESULT_LIMIT: - result = ( - result[: self.RESULT_LIMIT - len(overflow_suffix)] + overflow_suffix - ) - return result - - def _append_capture(self, capture: Dict[str, Any], text: str): - if not text: - return - - remaining = self.STREAM_CAPTURE_LIMIT - capture["length"] - if remaining <= 0: - capture["truncated"] = True - return - - fragment = text[:remaining] - capture["chunks"].append(fragment) - capture["length"] += len(fragment) - if len(text) > remaining: - capture["truncated"] = True - - def _should_emit_live_output(self) -> bool: - return bool( - self._stream_handler - and self._stream_handler.is_streaming - and self._stream_handler.is_auto_flushing - ) - - def _emit_live_output( - self, text: str, stream_name: str, live_state: Dict[str, Any] - ): - if not text or not live_state["enabled"]: - return - - header_key = f"{stream_name}_header_sent" - prefix = "" - if not live_state[header_key]: - prefix = "标准输出:\n" if stream_name == "stdout" else "\n错误输出:\n" - live_state[header_key] = True - - payload = prefix + text - remaining = self.LIVE_OUTPUT_LIMIT - live_state["chars"] - if remaining <= 0: - if not live_state["truncated"]: - self._stream_handler.emit("\n...(命令输出过长,停止实时展示)\n") - live_state["truncated"] = True - return - - fragment = payload[:remaining] - if fragment: - self._stream_handler.emit(fragment) - live_state["chars"] += len(fragment) - - if len(payload) > remaining and not live_state["truncated"]: - self._stream_handler.emit("\n...(命令输出过长,停止实时展示)\n") - live_state["truncated"] = True - - async def _collect_stream( - self, - stream: Optional[asyncio.StreamReader], - stream_name: str, - capture: Dict[str, Any], - live_state: Dict[str, Any], - ): - if not stream: - return - - decoder = codecs.getincrementaldecoder("utf-8")(errors="replace") - while True: - chunk = await stream.read(512) - if not chunk: - tail = decoder.decode(b"", final=True) - if tail: - self._append_capture(capture, tail) - self._emit_live_output(tail, stream_name, live_state) - return - - text = decoder.decode(chunk) - if not text: - continue - - self._append_capture(capture, text) - self._emit_live_output(text, stream_name, live_state) - - @staticmethod - async def _terminate_process(process: asyncio.subprocess.Process): - if process.returncode is not None: - return - - try: - process.kill() - except ProcessLookupError: - return - - try: - await asyncio.wait_for(process.wait(), timeout=5) - except asyncio.TimeoutError: - logger.warning("终止命令进程超时") - async def run(self, command: str, timeout: Optional[int] = 60, **kwargs) -> str: logger.info( f"执行工具: {self.name}, 参数: command={command}, timeout={timeout}" @@ -178,54 +56,40 @@ class ExecuteCommandTool(MoviePilotTool): command, stdout=asyncio.subprocess.PIPE, stderr=asyncio.subprocess.PIPE ) - stdout_capture: Dict[str, Any] = { - "chunks": [], - "length": 0, - "truncated": False, - } - stderr_capture: Dict[str, Any] = { - "chunks": [], - "length": 0, - "truncated": False, - } - live_state: Dict[str, Any] = { - "enabled": self._should_emit_live_output(), - "chars": 0, - "truncated": False, - "stdout_header_sent": False, - "stderr_header_sent": False, - } - - stdout_task = asyncio.create_task( - self._collect_stream( - process.stdout, "stdout", stdout_capture, live_state - ) - ) - stderr_task = asyncio.create_task( - self._collect_stream( - process.stderr, "stderr", stderr_capture, live_state - ) - ) - try: # 等待完成,带超时 - await asyncio.wait_for(process.wait(), timeout=timeout) - await asyncio.gather(stdout_task, stderr_task) - return self._build_result( - f"命令执行完成 (退出码: {process.returncode})", - stdout_capture, - stderr_capture, + stdout, stderr = await asyncio.wait_for( + process.communicate(), timeout=timeout ) + # 处理输出 + stdout_str = stdout.decode("utf-8", errors="replace").strip() + stderr_str = stderr.decode("utf-8", errors="replace").strip() + exit_code = process.returncode + + result = f"命令执行完成 (退出码: {exit_code})" + if stdout_str: + result += f"\n\n标准输出:\n{stdout_str}" + if stderr_str: + result += f"\n\n错误输出:\n{stderr_str}" + + # 如果没有输出 + if not stdout_str and not stderr_str: + result += "\n\n(无输出内容)" + + # 限制输出长度,防止上下文过长 + if len(result) > 3000: + result = result[:3000] + "\n\n...(输出内容过长,已截断)" + + return result + except asyncio.TimeoutError: # 超时处理 - await self._terminate_process(process) - await asyncio.gather(stdout_task, stderr_task) - return self._build_result( - f"命令执行超时 (限制: {timeout}秒)", - stdout_capture, - stderr_capture, - ) + try: + process.kill() + except ProcessLookupError: + pass + return f"命令执行超时 (限制: {timeout}秒)" except Exception as e: logger.error(f"执行命令失败: {e}", exc_info=True) diff --git a/tests/test_execute_command_tool.py b/tests/test_execute_command_tool.py deleted file mode 100644 index 8a9e74fe..00000000 --- a/tests/test_execute_command_tool.py +++ /dev/null @@ -1,57 +0,0 @@ -import asyncio -import shlex -import sys -import unittest - -import langchain.agents as langchain_agents - -if not hasattr(langchain_agents, "create_agent"): - langchain_agents.create_agent = lambda *args, **kwargs: None - -from app.agent.callback import StreamingHandler -from app.agent.tools.impl.execute_command import ExecuteCommandTool - - -class TestExecuteCommandTool(unittest.TestCase): - @staticmethod - def _build_python_command(script: str) -> str: - return f"{shlex.quote(sys.executable)} -c '{script}'" - - @staticmethod - def _build_streaming_tool() -> tuple[ExecuteCommandTool, StreamingHandler]: - tool = ExecuteCommandTool(session_id="session-1", user_id="10001") - handler = StreamingHandler() - handler._streaming_enabled = True - handler._flush_task = object() - tool.set_stream_handler(handler) - return tool, handler - - def test_run_streams_live_output_and_collects_result(self): - tool, handler = self._build_streaming_tool() - command = self._build_python_command( - 'import sys; print("out"); print("err", file=sys.stderr)' - ) - - result = asyncio.run(tool.run(command=command, timeout=5)) - live_output = asyncio.run(handler.take()) - - self.assertIn("命令执行完成 (退出码: 0)", result) - self.assertIn("标准输出:\nout", result) - self.assertIn("错误输出:\nerr", result) - self.assertIn("标准输出:\nout", live_output) - self.assertIn("错误输出:\nerr", live_output) - - def test_run_timeout_keeps_partial_output(self): - tool = ExecuteCommandTool(session_id="session-1", user_id="10001") - command = self._build_python_command( - 'import sys,time; print("start"); sys.stdout.flush(); time.sleep(0.2)' - ) - - result = asyncio.run(tool.run(command=command, timeout=0.05)) - - self.assertIn("命令执行超时", result) - self.assertIn("标准输出:\nstart", result) - - -if __name__ == "__main__": - unittest.main()