From 9076acc52e7a453c754c305beeb3e57e9eb947fd Mon Sep 17 00:00:00 2001 From: jxxghp Date: Mon, 18 May 2026 20:17:59 +0800 Subject: [PATCH] feat: add managed agent command sessions --- app/agent/callback/__init__.py | 5 +- app/agent/prompt/System Core Prompt.txt | 1 + app/agent/tools/impl/execute_command.py | 332 +++++++++--- app/agent/tools/impl/terminal_session.py | 628 +++++++++++++++++++++++ app/chain/skills.py | 2 +- skills/database-operation/SKILL.md | 8 +- tests/test_execute_command_tool.py | 137 ++++- 7 files changed, 1028 insertions(+), 85 deletions(-) create mode 100644 app/agent/tools/impl/terminal_session.py diff --git a/app/agent/callback/__init__.py b/app/agent/callback/__init__.py index 652ee8b7..f0e641d6 100644 --- a/app/agent/callback/__init__.py +++ b/app/agent/callback/__init__.py @@ -307,7 +307,10 @@ class StreamingHandler: or tool_kwargs.get("path"), ) if tool_name == "execute_command": - return "command", tool_kwargs.get("command") + return ( + "command", + tool_kwargs.get("command") or tool_kwargs.get("session_id"), + ) if tool_name == "ask_user_choice": return "interaction", tool_kwargs.get("message") if tool_name.startswith("search_") or tool_name in {"get_search_results"}: diff --git a/app/agent/prompt/System Core Prompt.txt b/app/agent/prompt/System Core Prompt.txt index 91a2d941..7050af60 100644 --- a/app/agent/prompt/System Core Prompt.txt +++ b/app/agent/prompt/System Core Prompt.txt @@ -56,6 +56,7 @@ Tool Calling Strategy: - Reuse the latest torrent search cache for `get_search_results` and `add_download` instead of re-running the same search unnecessarily. - Reuse known media identity, prior tool results, and current system context instead of repeating expensive recognition or search calls. - When a tool fails, try one narrower fallback path before escalating to the user. +- Use `execute_command` for shell work. 
Its default `action=start` starts a managed background session and returns `session_id`, `status`, `last_seq`, and `output_until_seq`; call the same tool again with `action=read`, `action=wait`, `action=write`, or `action=kill` to poll output, wait in short segments, send stdin, or stop the process. Media Management Rules: 1. Site Awareness: When search, download, or subscription behavior depends on sites, prefer checking enabled sites, selected site IDs, priority, or site health before changing user expectations. diff --git a/app/agent/tools/impl/execute_command.py b/app/agent/tools/impl/execute_command.py index aa4c43f8..5e2a3815 100644 --- a/app/agent/tools/impl/execute_command.py +++ b/app/agent/tools/impl/execute_command.py @@ -1,16 +1,25 @@ -"""执行Shell命令工具""" +"""执行 Shell 命令工具。""" + +from __future__ import annotations import asyncio +import json import os import signal import subprocess from dataclasses import dataclass, field from tempfile import NamedTemporaryFile -from typing import Any, Optional, TextIO, Type +from typing import Any, Literal, Optional, TextIO, Type from pydantic import BaseModel, Field from app.agent.tools.base import MoviePilotTool +from app.agent.tools.impl.terminal_session import ( + TERMINAL_DEFAULT_READ_BYTES, + TERMINAL_MAX_READ_BYTES, + TERMINAL_WAIT_DEFAULT_MS, + terminal_session_manager, +) from app.log import logger @@ -20,6 +29,14 @@ MAX_OUTPUT_PREVIEW_BYTES = 10 * 1024 READ_CHUNK_SIZE = 4096 KILL_GRACE_SECONDS = 3 COMMAND_CONCURRENCY_LIMIT = 2 +COMMAND_FORBIDDEN_KEYWORDS = ( + "rm -rf /", + ":(){ :|:& };:", + "dd if=/dev/zero", + "mkfs", + "reboot", + "shutdown", +) _command_semaphore = asyncio.Semaphore(COMMAND_CONCURRENCY_LIMIT) @@ -38,11 +55,13 @@ class _CommandOutput: @staticmethod def _clip_text_to_bytes(text: str, byte_limit: int) -> str: + """按 UTF-8 字节数截断文本,避免截断后出现非法字符。""" if byte_limit <= 0: return "" return text.encode("utf-8")[:byte_limit].decode("utf-8", errors="ignore") def _write_chunk(self, stream_name: str, 
text: str) -> None: + """把输出分片按 stdout/stderr 分段写入临时文件。""" if not self.temp_file_handle or not text: return @@ -56,6 +75,7 @@ class _CommandOutput: self.temp_file_handle.write(text) def _ensure_temp_file(self) -> None: + """首次超出预览上限时创建临时文件并补写已缓存预览。""" if self.temp_file_handle: return @@ -72,6 +92,7 @@ class _CommandOutput: self._write_chunk(stream_name, chunk) def close(self) -> None: + """关闭临时文件句柄,确保输出落盘。""" if not self.temp_file_handle: return self.temp_file_handle.flush() @@ -79,6 +100,7 @@ class _CommandOutput: self.temp_file_handle = None def append(self, stream_name: str, text: str) -> None: + """追加一段输出,超出预览上限后只保留完整日志文件。""" if not text: return @@ -104,47 +126,141 @@ class _CommandOutput: @property def stdout(self) -> str: + """返回当前保留的 stdout 预览。""" return "".join( text for stream_name, text in self.preview_entries if stream_name == "stdout" ).strip() @property def stderr(self) -> str: + """返回当前保留的 stderr 预览。""" return "".join( text for stream_name, text in self.preview_entries if stream_name == "stderr" ).strip() class ExecuteCommandInput(BaseModel): - """执行Shell命令工具的输入参数模型""" + """执行 Shell 命令工具的输入参数模型。""" explanation: str = Field( - ..., description="Clear explanation of why this command is being executed" + ..., description="Clear explanation of why this command action is needed" + ) + action: Optional[Literal["start", "read", "wait", "write", "kill", "run"]] = Field( + "start", + description=( + "Command action. start launches a managed background session and returns " + "session_id. read/wait/write/kill operate on that session. run executes " + "once and waits until completion or timeout." + ), + ) + command: Optional[str] = Field( + None, + description="Shell command. Required for action=start or action=run.", + ) + session_id: Optional[str] = Field( + None, + description="Command session id returned by action=start.", + ) + input_text: Optional[str] = Field( + None, + description="Text to send to stdin for action=write. 
Use \\u0003 for Ctrl+C.", + ) + signal_name: Optional[str] = Field( + "TERM", + description="Signal for action=kill, such as TERM, INT, KILL, or 15.", + ) + cwd: Optional[str] = Field( + None, + description="Working directory for action=start or action=run.", + ) + env: Optional[dict[str, Any]] = Field( + None, + description="Additional environment variables for action=start.", + ) + use_pty: Optional[bool] = Field( + True, + description="Use a pseudo terminal for action=start when supported.", + ) + since_seq: Optional[int] = Field( + None, + description="For action=read/wait, return output chunks after this seq.", + ) + max_bytes: Optional[int] = Field( + TERMINAL_DEFAULT_READ_BYTES, + description="For action=read/wait, maximum output bytes to return.", + ) + timeout_ms: Optional[int] = Field( + TERMINAL_WAIT_DEFAULT_MS, + description="For action=wait, maximum segmented wait time in milliseconds.", ) - command: str = Field(..., description="The shell command to execute") timeout: Optional[int] = Field( - 60, description="Max execution time in seconds (default: 60)" + 60, + description="For action=run, max execution time in seconds.", ) class ExecuteCommandTool(MoviePilotTool): + """统一执行和管理 Shell 命令的 Agent 工具。""" + name: str = "execute_command" description: str = ( - "Safely execute shell commands on the server. Useful for system " - "maintenance, checking status, or running custom scripts. Includes " - "timeout, concurrency, and output preview limits." + "Start and manage shell commands on the server. By default action=start " + "launches a background session and immediately returns session_id/status/" + "last_seq/output_until_seq. Call the same tool with action=read, wait, " + "write, or kill to poll output, wait in short segments, send stdin, or " + "terminate it. Use action=run only when a one-shot bounded command result " + "is preferred." 
) args_schema: Type[BaseModel] = ExecuteCommandInput require_admin: bool = True + result_max_chars = TERMINAL_MAX_READ_BYTES + 4096 def get_tool_message(self, **kwargs) -> Optional[str]: - """根据命令生成友好的提示消息""" - command = kwargs.get("command", "") - return f"执行系统命令: {command}" + """根据命令动作生成友好的提示消息。""" + action = kwargs.get("action") or "start" + command = kwargs.get("command") + session_id = kwargs.get("session_id") + if action in {"start", "run"}: + return f"执行系统命令: {command or ''}" + if action == "read": + return f"读取命令输出: {session_id or ''}" + if action == "wait": + return f"等待命令会话: {session_id or ''}" + if action == "write": + return f"写入命令输入: {session_id or ''}" + if action == "kill": + return f"终止命令会话: {session_id or ''}" + return f"处理命令会话: {session_id or command or ''}" + + @staticmethod + def _dump(payload: dict[str, Any]) -> str: + """把结构化命令会话结果转换为 Agent 容易解析的 JSON 字符串。""" + return json.dumps(payload, ensure_ascii=False, indent=2) + + @staticmethod + def _require_session_id(session_id: Optional[str]) -> str: + """校验会话型 action 必须传入 session_id。""" + if not session_id: + raise ValueError("action 需要传入 session_id") + return session_id + + @staticmethod + def _require_command(command: Optional[str]) -> str: + """校验启动型 action 必须传入 command。""" + if not command or not command.strip(): + raise ValueError("action 需要传入 command") + return command + + @staticmethod + def _validate_command(command: str) -> None: + """复用旧工具的基础危险命令过滤,避免明显破坏性命令进入 shell。""" + for keyword in COMMAND_FORBIDDEN_KEYWORDS: + if keyword in command: + raise ValueError(f"命令包含禁止使用的关键字 '{keyword}'") @staticmethod def _normalize_timeout(timeout: Optional[int]) -> tuple[int, Optional[str]]: - """限制命令最长运行时间,避免 Agent 传入过大的 timeout。""" + """限制一次性执行命令的最长运行时间。""" try: normalized = int(timeout or DEFAULT_TIMEOUT_SECONDS) except (TypeError, ValueError): @@ -161,7 +277,7 @@ class ExecuteCommandTool(MoviePilotTool): @staticmethod def _subprocess_kwargs() -> dict: - """为子进程创建独立进程组,便于超时场景清理整棵子进程。""" + 
"""为一次性命令创建独立进程组,便于超时清理整棵子进程。""" kwargs = { "stdin": subprocess.DEVNULL, "stdout": asyncio.subprocess.PIPE, @@ -179,17 +295,16 @@ class ExecuteCommandTool(MoviePilotTool): stream_name: str, output: _CommandOutput, ) -> None: - """按块读取输出,始终只把前 10KB 保留在返回结果中。""" + """按块读取一次性命令输出,只把前 10KB 保留在返回结果中。""" while True: chunk = await stream.read(READ_CHUNK_SIZE) if not chunk: break - output.append(stream_name, chunk.decode("utf-8", errors="replace")) @staticmethod - def _terminate_process(process: Any, sig: int): - """向进程组发送终止信号;不支持进程组的平台回退为单进程终止。""" + def _terminate_process(process: Any, sig: int) -> None: + """向进程组发送终止信号,不支持进程组的平台回退为单进程终止。""" try: if os.name == "posix": os.killpg(process.pid, sig) @@ -230,7 +345,7 @@ class ExecuteCommandTool(MoviePilotTool): @staticmethod async def _finish_reader_tasks(reader_tasks: list[asyncio.Task]) -> None: - """等待输出读取任务退出,异常只记录不影响工具返回。""" + """等待一次性命令输出读取任务退出,异常只记录不影响工具返回。""" if not reader_tasks: return done, pending = await asyncio.wait(reader_tasks, timeout=1) @@ -244,7 +359,7 @@ class ExecuteCommandTool(MoviePilotTool): logger.debug("命令输出读取任务异常: %s", result) @staticmethod - def _format_result( + def _format_run_result( *, exit_code: Optional[int], output: _CommandOutput, @@ -252,6 +367,7 @@ class ExecuteCommandTool(MoviePilotTool): timed_out: bool, timeout_note: Optional[str], ) -> str: + """格式化 action=run 的兼容文本结果。""" if timed_out: result = f"命令执行超时 (限制: {timeout}秒,已终止进程)" else: @@ -260,11 +376,7 @@ class ExecuteCommandTool(MoviePilotTool): if timeout_note: result += f"\n\n提示:\n{timeout_note}" if output.temp_file_path: - file_note = ( - "截至命令终止前的完整输出" - if timed_out - else "完整输出" - ) + file_note = "截至命令终止前的完整输出" if timed_out else "完整输出" result += ( "\n\n提示:\n" f"命令输出超过 10KB,仅返回前 {MAX_OUTPUT_PREVIEW_BYTES} 字节内容。\n" @@ -281,65 +393,129 @@ class ExecuteCommandTool(MoviePilotTool): result += "\n\n(无输出内容)" return result - async def run(self, command: str, timeout: Optional[int] = 60, **kwargs) -> str: - logger.info( - f"执行工具: 
{self.name}, 参数: command={command}, timeout={timeout}" - ) - - # 简单安全过滤 - forbidden_keywords = [ - "rm -rf /", - ":(){ :|:& };:", - "dd if=/dev/zero", - "mkfs", - "reboot", - "shutdown", - ] - for keyword in forbidden_keywords: - if keyword in command: - return f"错误:命令包含禁止使用的关键字 '{keyword}'" - + async def _run_once( + self, + *, + command: str, + timeout: Optional[int], + cwd: Optional[str] = None, + ) -> str: + """按旧模式一次性执行命令,等待完成或超时后返回文本结果。""" + self._validate_command(command) normalized_timeout, timeout_note = self._normalize_timeout(timeout) + async with _command_semaphore: + process = await asyncio.create_subprocess_shell( + command, + cwd=cwd, + **self._subprocess_kwargs(), + ) + output = _CommandOutput(preview_limit_bytes=MAX_OUTPUT_PREVIEW_BYTES) + wait_task = asyncio.create_task(process.wait()) + reader_tasks = [ + asyncio.create_task(self._read_stream(process.stdout, "stdout", output)), + asyncio.create_task(self._read_stream(process.stderr, "stderr", output)), + ] + + timed_out = False + try: + await asyncio.wait_for( + asyncio.shield(wait_task), timeout=normalized_timeout + ) + except asyncio.TimeoutError: + timed_out = True + await self._cleanup_process(process, wait_task) + + try: + await self._finish_reader_tasks(reader_tasks) + finally: + output.close() + + return self._format_run_result( + exit_code=process.returncode, + output=output, + timeout=normalized_timeout, + timed_out=timed_out, + timeout_note=timeout_note, + ) + + async def run( + self, + action: Optional[str] = "start", + command: Optional[str] = None, + session_id: Optional[str] = None, + input_text: Optional[str] = None, + signal_name: Optional[str] = "TERM", + cwd: Optional[str] = None, + env: Optional[dict[str, Any]] = None, + use_pty: Optional[bool] = True, + since_seq: Optional[int] = None, + max_bytes: Optional[int] = TERMINAL_DEFAULT_READ_BYTES, + timeout_ms: Optional[int] = TERMINAL_WAIT_DEFAULT_MS, + timeout: Optional[int] = 60, + **kwargs, + ) -> str: + 
"""执行命令动作:默认后台启动,也支持读取、等待、写入、终止和一次性执行。""" + normalized_action = (action or "start").strip().lower() + logger.info( + "执行工具: %s, action=%s, command=%s, session_id=%s", + self.name, + normalized_action, + command, + session_id, + ) + try: - async with _command_semaphore: - # 命令输出可能非常大,必须边读边落盘,不能使用 communicate() 一次性收集。 - process = await asyncio.create_subprocess_shell( - command, **self._subprocess_kwargs() + if normalized_action == "start": + start_command = self._require_command(command) + self._validate_command(start_command) + payload = await terminal_session_manager.start( + command=start_command, + cwd=cwd, + env=env, + use_pty=use_pty, ) - output = _CommandOutput(preview_limit_bytes=MAX_OUTPUT_PREVIEW_BYTES) - wait_task = asyncio.create_task(process.wait()) - reader_tasks = [ - asyncio.create_task( - self._read_stream(process.stdout, "stdout", output) - ), - asyncio.create_task( - self._read_stream(process.stderr, "stderr", output) - ), - ] + return self._dump(payload) - timed_out = False - try: - await asyncio.wait_for( - asyncio.shield(wait_task), timeout=normalized_timeout - ) - except asyncio.TimeoutError: - timed_out = True - await self._cleanup_process(process, wait_task) + if normalized_action == "read": + payload = await terminal_session_manager.read( + session_id=self._require_session_id(session_id), + since_seq=since_seq, + max_bytes=max_bytes, + ) + return self._dump(payload) - try: - await self._finish_reader_tasks(reader_tasks) - finally: - output.close() + if normalized_action == "wait": + payload = await terminal_session_manager.wait( + session_id=self._require_session_id(session_id), + timeout_ms=timeout_ms, + since_seq=since_seq, + max_bytes=max_bytes, + ) + return self._dump(payload) - return self._format_result( - exit_code=process.returncode, - output=output, - timeout=normalized_timeout, - timed_out=timed_out, - timeout_note=timeout_note, + if normalized_action == "write": + payload = await terminal_session_manager.write( + 
session_id=self._require_session_id(session_id), + input_text=input_text or "", + ) + return self._dump(payload) + + if normalized_action == "kill": + payload = await terminal_session_manager.kill( + session_id=self._require_session_id(session_id), + sig=signal_name, + ) + return self._dump(payload) + + if normalized_action == "run": + return await self._run_once( + command=self._require_command(command), + timeout=timeout, + cwd=cwd, ) - except Exception as e: - logger.error(f"执行命令失败: {e}", exc_info=True) - return f"执行命令时发生错误: {str(e)}" + raise ValueError(f"不支持的 action: {action}") + except Exception as err: + logger.error("执行命令 action 失败: %s", err, exc_info=True) + return self._dump({"error": str(err), "status": "error", "action": normalized_action}) diff --git a/app/agent/tools/impl/terminal_session.py b/app/agent/tools/impl/terminal_session.py new file mode 100644 index 00000000..c17f7d53 --- /dev/null +++ b/app/agent/tools/impl/terminal_session.py @@ -0,0 +1,628 @@ +"""Agent 终端会话管理器。""" + +from __future__ import annotations + +import asyncio +import errno +import os +import signal +import subprocess +import time +import uuid +from dataclasses import dataclass, field +from pathlib import Path +from typing import Any, Optional + +from app.core.config import settings +from app.log import logger + +if os.name == "posix": + import fcntl as _fcntl + import pty as _pty +else: + _fcntl = None + _pty = None + + +TERMINAL_CONCURRENCY_LIMIT = 4 +TERMINAL_RETENTION_SECONDS = 30 * 60 +TERMINAL_MAX_RETAINED_BYTES = 1024 * 1024 +TERMINAL_DEFAULT_READ_BYTES = 10 * 1024 +TERMINAL_MAX_READ_BYTES = 64 * 1024 +TERMINAL_READ_CHUNK_SIZE = 4096 +TERMINAL_PTY_POLL_INTERVAL = 0.05 +TERMINAL_WAIT_DEFAULT_MS = 1000 +TERMINAL_WAIT_MAX_MS = 60 * 1000 +TERMINAL_KILL_GRACE_SECONDS = 3 +TERMINAL_FORBIDDEN_KEYWORDS = ( + "rm -rf /", + ":(){ :|:& };:", + "dd if=/dev/zero", + "mkfs", + "reboot", + "shutdown", +) + + +@dataclass +class _TerminalChunk: + """记录终端输出分片,供增量读取时按 seq 过滤。""" + + seq: int 
+ stream: str + text: str + byte_size: int + created_at: float + + +@dataclass +class _TerminalSession: + """保存一个后台命令会话的进程、输出和状态。""" + + session_id: str + command: str + cwd: str + pid: int + use_pty: bool + created_at: float = field(default_factory=time.time) + updated_at: float = field(default_factory=time.time) + status: str = "running" + exit_code: Optional[int] = None + process: Optional[asyncio.subprocess.Process] = None + master_fd: Optional[int] = None + chunks: list[_TerminalChunk] = field(default_factory=list) + next_seq: int = 1 + retained_from_seq: int = 1 + retained_bytes: int = 0 + kill_requested: bool = False + error: Optional[str] = None + reader_tasks: list[asyncio.Task] = field(default_factory=list) + wait_task: Optional[asyncio.Task] = None + + def append_output(self, stream: str, data: bytes) -> None: + """追加输出并按容量上限丢弃最旧分片,避免长任务撑爆内存。""" + if not data: + return + + text = data.decode("utf-8", errors="replace") + chunk = _TerminalChunk( + seq=self.next_seq, + stream=stream, + text=text, + byte_size=len(data), + created_at=time.time(), + ) + self.next_seq += 1 + self.chunks.append(chunk) + self.retained_bytes += chunk.byte_size + self.updated_at = chunk.created_at + self._trim_output() + + def _trim_output(self) -> None: + """移除超出保留上限的旧输出分片。""" + while self.retained_bytes > TERMINAL_MAX_RETAINED_BYTES and self.chunks: + removed = self.chunks.pop(0) + self.retained_bytes -= removed.byte_size + self.retained_from_seq = removed.seq + 1 + + def mark_finished(self, exit_code: Optional[int]) -> None: + """标记进程已经结束,并记录退出码。""" + self.exit_code = exit_code + self.status = "killed" if self.kill_requested else "exited" + self.updated_at = time.time() + + def mark_error(self, message: str) -> None: + """标记会话异常,保留错误信息供后续读取。""" + self.error = message + self.status = "error" + self.updated_at = time.time() + + def close_pty(self) -> None: + """关闭父进程持有的 PTY master fd。""" + if self.master_fd is None: + return + try: + os.close(self.master_fd) + except OSError: + 
pass + self.master_fd = None + + +class _TerminalSessionManager: + """管理 Agent 后台终端会话的生命周期。""" + + def __init__(self) -> None: + """初始化会话表和并发保护锁。""" + self._sessions: dict[str, _TerminalSession] = {} + self._lock = asyncio.Lock() + + @staticmethod + def _normalize_bool(value: Any, default: bool = True) -> bool: + """兼容 LLM 或 HTTP 传入的 bool/string/int 布尔值。""" + if value is None: + return default + if isinstance(value, bool): + return value + if isinstance(value, str): + return value.strip().lower() not in {"false", "0", "no", "off"} + return bool(value) + + @staticmethod + def _normalize_cwd(cwd: Optional[str]) -> str: + """解析工作目录,未传入时默认使用 MoviePilot 项目根目录。""" + if not cwd: + return str(settings.ROOT_PATH) + path = Path(cwd).expanduser() + if not path.is_absolute(): + path = (settings.ROOT_PATH / path).resolve() + else: + path = path.resolve() + if not path.exists(): + raise FileNotFoundError(f"工作目录不存在: {path}") + if not path.is_dir(): + raise NotADirectoryError(f"工作目录不是目录: {path}") + return str(path) + + @staticmethod + def _build_env(env: Optional[dict[str, Any]]) -> dict[str, str]: + """合并环境变量,并把值稳定转换为字符串。""" + merged_env = os.environ.copy() + if not env: + return merged_env + for key, value in env.items(): + if value is None: + continue + merged_env[str(key)] = str(value) + return merged_env + + @staticmethod + def _validate_command(command: str) -> None: + """拒绝明显危险或空白命令。""" + if not command or not command.strip(): + raise ValueError("命令不能为空") + for keyword in TERMINAL_FORBIDDEN_KEYWORDS: + if keyword in command: + raise ValueError(f"命令包含禁止使用的关键字 '{keyword}'") + + @staticmethod + def _set_nonblocking(fd: int) -> None: + """将 PTY master fd 设置为非阻塞,避免后台读取任务卡住事件循环。""" + if _fcntl is None: + raise RuntimeError("当前平台不支持 PTY 非阻塞设置") + flags = _fcntl.fcntl(fd, _fcntl.F_GETFL) + _fcntl.fcntl(fd, _fcntl.F_SETFL, flags | os.O_NONBLOCK) + + @staticmethod + def _pipe_subprocess_kwargs() -> dict[str, Any]: + """生成普通管道模式的子进程参数。""" + kwargs: dict[str, Any] = { + "stdin": 
asyncio.subprocess.PIPE, + "stdout": asyncio.subprocess.PIPE, + "stderr": asyncio.subprocess.PIPE, + } + if os.name == "posix": + kwargs["start_new_session"] = True + elif os.name == "nt": + kwargs["creationflags"] = subprocess.CREATE_NEW_PROCESS_GROUP + return kwargs + + async def start( + self, + *, + command: str, + cwd: Optional[str] = None, + env: Optional[dict[str, Any]] = None, + use_pty: Any = True, + ) -> dict[str, Any]: + """启动后台命令并立即返回会话 ID。""" + self._validate_command(command) + normalized_cwd = self._normalize_cwd(cwd) + normalized_env = self._build_env(env) + should_use_pty = self._normalize_bool(use_pty, default=True) and os.name == "posix" + + async with self._lock: + self._cleanup_finished_sessions_locked() + if self._active_session_count_locked() >= TERMINAL_CONCURRENCY_LIMIT: + raise RuntimeError( + f"后台终端会话数已达到上限 {TERMINAL_CONCURRENCY_LIMIT}" + ) + + session = ( + await self._start_pty_session(command, normalized_cwd, normalized_env) + if should_use_pty + else await self._start_pipe_session(command, normalized_cwd, normalized_env) + ) + + async with self._lock: + self._sessions[session.session_id] = session + + logger.info( + "启动后台终端会话: session_id=%s, pid=%s, use_pty=%s, command=%s", + session.session_id, + session.pid, + session.use_pty, + command, + ) + await asyncio.sleep(0) + return self._session_payload(session, output="", output_truncated=False) + + async def _start_pty_session( + self, command: str, cwd: str, env: dict[str, str] + ) -> _TerminalSession: + """通过 PTY fork 启动交互式命令会话。""" + if _pty is None: + raise RuntimeError("当前平台不支持 PTY 会话") + pid, master_fd = _pty.fork() + if pid == 0: + os.chdir(cwd) + os.environ.clear() + os.environ.update(env) + shell = os.environ.get("SHELL") or "/bin/sh" + os.execl(shell, shell, "-lc", command) + + self._set_nonblocking(master_fd) + session = _TerminalSession( + session_id=f"term_{uuid.uuid4().hex[:12]}", + command=command, + cwd=cwd, + pid=pid, + use_pty=True, + master_fd=master_fd, + ) + 
session.reader_tasks.append(asyncio.create_task(self._read_pty(session))) + session.wait_task = asyncio.create_task(self._wait_pty_process(session)) + return session + + async def _start_pipe_session( + self, command: str, cwd: str, env: dict[str, str] + ) -> _TerminalSession: + """通过普通 stdin/stdout/stderr 管道启动命令会话。""" + process = await asyncio.create_subprocess_shell( + command, + cwd=cwd, + env=env, + **self._pipe_subprocess_kwargs(), + ) + session = _TerminalSession( + session_id=f"term_{uuid.uuid4().hex[:12]}", + command=command, + cwd=cwd, + pid=process.pid or 0, + use_pty=False, + process=process, + ) + if process.stdout: + session.reader_tasks.append( + asyncio.create_task(self._read_pipe(session, process.stdout, "stdout")) + ) + if process.stderr: + session.reader_tasks.append( + asyncio.create_task(self._read_pipe(session, process.stderr, "stderr")) + ) + session.wait_task = asyncio.create_task(self._wait_pipe_process(session)) + return session + + async def _read_pty(self, session: _TerminalSession) -> None: + """持续从 PTY 读取增量输出。""" + while session.master_fd is not None: + try: + data = os.read(session.master_fd, TERMINAL_READ_CHUNK_SIZE) + except BlockingIOError: + await asyncio.sleep(TERMINAL_PTY_POLL_INTERVAL) + continue + except OSError as err: + if err.errno not in {errno.EIO, errno.EBADF}: + logger.debug("PTY 输出读取异常: session_id=%s, error=%s", session.session_id, err) + break + + if not data: + break + session.append_output("pty", data) + + async def _read_pipe( + self, + session: _TerminalSession, + stream: asyncio.StreamReader, + stream_name: str, + ) -> None: + """持续从普通管道读取增量输出。""" + while True: + data = await stream.read(TERMINAL_READ_CHUNK_SIZE) + if not data: + break + session.append_output(stream_name, data) + + async def _wait_pty_process(self, session: _TerminalSession) -> None: + """等待 PTY 子进程结束并完成输出读取任务收尾。""" + try: + _, status = await asyncio.to_thread(os.waitpid, session.pid, 0) + exit_code = os.waitstatus_to_exitcode(status) + 
session.mark_finished(exit_code) + except ChildProcessError: + session.mark_finished(session.exit_code) + except Exception as err: + session.mark_error(str(err)) + logger.warning("等待 PTY 进程失败: session_id=%s, error=%s", session.session_id, err) + finally: + await self._finish_reader_tasks(session) + session.close_pty() + + async def _wait_pipe_process(self, session: _TerminalSession) -> None: + """等待普通管道子进程结束并完成输出读取任务收尾。""" + try: + if not session.process: + session.mark_error("进程对象不存在") + return + exit_code = await session.process.wait() + session.mark_finished(exit_code) + except Exception as err: + session.mark_error(str(err)) + logger.warning("等待管道进程失败: session_id=%s, error=%s", session.session_id, err) + finally: + await self._finish_reader_tasks(session) + + async def _finish_reader_tasks(self, session: _TerminalSession) -> None: + """等待输出读取任务退出,超时后取消残留任务。""" + if not session.reader_tasks: + return + done, pending = await asyncio.wait(session.reader_tasks, timeout=1) + for task in pending: + task.cancel() + await asyncio.gather(*done, *pending, return_exceptions=True) + + async def read( + self, + *, + session_id: str, + since_seq: Optional[int] = None, + max_bytes: Optional[int] = TERMINAL_DEFAULT_READ_BYTES, + ) -> dict[str, Any]: + """读取会话当前保留的增量输出。""" + session = self.get_session(session_id) + output, output_truncated, output_until_seq = self._collect_output( + session, + since_seq=since_seq, + max_bytes=max_bytes, + ) + return self._session_payload( + session, + output=output, + output_truncated=output_truncated, + output_until_seq=output_until_seq, + ) + + async def wait( + self, + *, + session_id: str, + timeout_ms: Optional[int] = TERMINAL_WAIT_DEFAULT_MS, + since_seq: Optional[int] = None, + max_bytes: Optional[int] = TERMINAL_DEFAULT_READ_BYTES, + ) -> dict[str, Any]: + """短暂等待会话结束,并返回等待期间可见的增量输出。""" + session = self.get_session(session_id) + normalized_timeout = self._normalize_wait_timeout(timeout_ms) + if session.wait_task and not 
session.wait_task.done(): + try: + await asyncio.wait_for( + asyncio.shield(session.wait_task), + timeout=normalized_timeout / 1000, + ) + except asyncio.TimeoutError: + pass + + output, output_truncated, output_until_seq = self._collect_output( + session, + since_seq=since_seq, + max_bytes=max_bytes, + ) + payload = self._session_payload( + session, + output=output, + output_truncated=output_truncated, + output_until_seq=output_until_seq, + ) + payload["wait_timeout_ms"] = normalized_timeout + return payload + + async def write(self, *, session_id: str, input_text: str) -> dict[str, Any]: + """向会话 stdin 写入文本,PTY 模式下写入 master fd。""" + session = self.get_session(session_id) + if session.status != "running": + raise RuntimeError(f"会话已结束,当前状态: {session.status}") + + data = (input_text or "").encode("utf-8") + if session.use_pty: + if session.master_fd is None: + raise RuntimeError("PTY 已关闭") + await asyncio.to_thread(os.write, session.master_fd, data) + else: + if not session.process or not session.process.stdin: + raise RuntimeError("进程 stdin 不可写") + session.process.stdin.write(data) + await session.process.stdin.drain() + + session.updated_at = time.time() + payload = self._session_payload(session, output="", output_truncated=False) + payload["written_bytes"] = len(data) + return payload + + async def kill( + self, + *, + session_id: str, + sig: Optional[str | int] = "TERM", + ) -> dict[str, Any]: + """向会话进程组发送信号并等待短暂清理。""" + session = self.get_session(session_id) + if session.status != "running": + return self._session_payload(session, output="", output_truncated=False) + + session.kill_requested = True + signal_number = self._resolve_signal(sig) + self._send_signal(session, signal_number) + + if session.wait_task and not session.wait_task.done(): + try: + await asyncio.wait_for( + asyncio.shield(session.wait_task), + timeout=TERMINAL_KILL_GRACE_SECONDS, + ) + except asyncio.TimeoutError: + force_signal = getattr(signal, "SIGKILL", signal.SIGTERM) + 
def get_session(self, session_id: str) -> "_TerminalSession":
    """Return the session registered under ``session_id``.

    Raises:
        KeyError: If no session with that ID is registered; the message
            names the missing ID.
    """
    session = self._sessions.get(session_id)
    if session is None:
        raise KeyError(f"终端会话不存在: {session_id}")
    return session

@staticmethod
def _normalize_wait_timeout(timeout_ms: Optional[int]) -> int:
    """Clamp one wait to ``[0, TERMINAL_WAIT_MAX_MS]`` milliseconds.

    ``None`` and values that cannot be converted to ``int`` select
    ``TERMINAL_WAIT_DEFAULT_MS``; negative values become 0.
    """
    try:
        # Only a missing value selects the default. An explicit 0 is a
        # valid result (negatives are clamped to it) and must be kept
        # instead of being silently replaced by the default.
        normalized = (
            TERMINAL_WAIT_DEFAULT_MS if timeout_ms is None else int(timeout_ms)
        )
    except (TypeError, ValueError, OverflowError):
        normalized = TERMINAL_WAIT_DEFAULT_MS
    return max(0, min(normalized, TERMINAL_WAIT_MAX_MS))

@staticmethod
def _normalize_read_limit(max_bytes: Optional[int]) -> int:
    """Clamp the size of one read to ``[1, TERMINAL_MAX_READ_BYTES]`` bytes.

    Missing, zero, negative, or unparsable values select
    ``TERMINAL_DEFAULT_READ_BYTES``.
    """
    try:
        normalized = int(max_bytes or TERMINAL_DEFAULT_READ_BYTES)
    except (TypeError, ValueError, OverflowError):
        normalized = TERMINAL_DEFAULT_READ_BYTES
    if normalized <= 0:
        return TERMINAL_DEFAULT_READ_BYTES
    return min(normalized, TERMINAL_MAX_READ_BYTES)

def _collect_output(
    self,
    session: "_TerminalSession",
    *,
    since_seq: Optional[int],
    max_bytes: Optional[int],
) -> tuple[str, bool, int]:
    """Collect retained output newer than ``since_seq`` within a byte budget.

    Args:
        session: Session whose ``chunks`` are read.
        since_seq: Only chunks with ``seq`` greater than this are returned;
            ``None`` returns every retained chunk.
        max_bytes: Requested UTF-8 byte budget, normalized by
            ``_normalize_read_limit``.

    Returns:
        ``(output, output_truncated, output_until_seq)``. ``output_truncated``
        is set when the byte budget was exhausted or when chunks between
        ``since_seq`` and ``retained_from_seq`` are no longer retained.
        ``output_until_seq`` is the seq of the last chunk fully contained in
        ``output``, so passing it back as ``since_seq`` continues the stream.
    """
    read_limit = self._normalize_read_limit(max_bytes)
    selected_chunks = [
        chunk
        for chunk in session.chunks
        if since_seq is None or chunk.seq > since_seq
    ]
    output_parts: list[str] = []
    output_bytes = 0
    output_truncated = False
    last_stream: Optional[str] = None
    # A missing (or zero) cursor starts just before the oldest retained chunk.
    output_until_seq = since_seq or session.retained_from_seq - 1

    for chunk in selected_chunks:
        prefix = self._stream_prefix(chunk.stream, last_stream, session.use_pty)
        text = f"{prefix}{chunk.text}" if prefix else chunk.text
        encoded = text.encode("utf-8")
        remaining = read_limit - output_bytes
        if len(encoded) > remaining:
            output_truncated = True
            # A chunk that does not fit is left for the next read: emitting
            # part of it without advancing output_until_seq would make the
            # follow-up read repeat that part. The only exception is when
            # nothing has been emitted yet, so the caller still gets a
            # preview of a chunk larger than the whole budget (the cursor
            # stays put; a larger max_bytes is needed to move past it).
            if not output_parts and remaining > 0:
                output_parts.append(
                    encoded[:remaining].decode("utf-8", errors="ignore")
                )
            break
        output_parts.append(text)
        output_bytes += len(encoded)
        last_stream = chunk.stream
        output_until_seq = chunk.seq

    # The caller asked for chunks that have already been dropped.
    if since_seq is not None and since_seq < session.retained_from_seq - 1:
        output_truncated = True
    if not output_truncated:
        output_until_seq = session.next_seq - 1
    return "".join(output_parts), output_truncated, output_until_seq

@staticmethod
def _stream_prefix(stream: str, last_stream: Optional[str], use_pty: bool) -> str:
    """Return the section header to insert when pipe output switches stream.

    No header is produced in PTY mode or while the stream is unchanged.
    Any stream other than ``"stdout"`` is labelled as error output.
    """
    if use_pty or stream == last_stream:
        return ""
    title = "标准输出" if stream == "stdout" else "错误输出"
    return f"\n[{title}]\n"

@staticmethod
def _resolve_signal(sig: "Optional[str | int]") -> int:
    """Resolve a signal given as a number, numeric string, or name.

    Names are case-insensitive and the ``SIG`` prefix is optional. A missing
    or unknown name resolves to ``SIGTERM``.
    """
    if isinstance(sig, int):
        return sig
    signal_name = str(sig or "TERM").strip().upper()
    # isdecimal() only accepts characters that int() can parse.
    if signal_name.isdecimal():
        return int(signal_name)
    if not signal_name.startswith("SIG"):
        signal_name = f"SIG{signal_name}"
    resolved = getattr(signal, signal_name, None)
    # The signal module also exposes non-signal constants such as SIG_IGN
    # and SIG_DFL; only real signal numbers may be returned.
    if not isinstance(resolved, signal.Signals):
        return int(signal.SIGTERM)
    return int(resolved)

@staticmethod
def _send_signal(session: "_TerminalSession", sig: int) -> None:
    """Signal the session's process group, falling back to the process itself.

    A process or group that no longer exists is ignored. On non-POSIX
    systems the process handle is killed for ``SIGKILL`` and terminated for
    any other signal.
    """
    if os.name == "posix":
        try:
            os.killpg(session.pid, sig)
            return
        except (ProcessLookupError, PermissionError):
            # No group with that ID (the process is not a group leader) or
            # the group cannot be signalled: target the process directly.
            pass
        try:
            os.kill(session.pid, sig)
        except ProcessLookupError:
            pass
        return
    if session.process:
        try:
            if sig == getattr(signal, "SIGKILL", None):
                session.process.kill()
            else:
                session.process.terminate()
        except ProcessLookupError:
            pass

def _active_session_count_locked(self) -> int:
    """Count the sessions whose status is still ``"running"``."""
    return sum(1 for session in self._sessions.values() if session.status == "running")

def _cleanup_finished_sessions_locked(self) -> None:
    """Drop finished sessions idle for longer than ``TERMINAL_RETENTION_SECONDS``.

    Each removed session has ``close_pty()`` called on it.
    """
    now = time.time()
    expired_ids = [
        session_id
        for session_id, session in self._sessions.items()
        if session.status != "running"
        and now - session.updated_at > TERMINAL_RETENTION_SECONDS
    ]
    for session_id in expired_ids:
        session = self._sessions.pop(session_id)
        session.close_pty()

@staticmethod
def _session_payload(
    session: "_TerminalSession",
    *,
    output: str,
    output_truncated: bool,
    output_until_seq: Optional[int] = None,
) -> dict[str, Any]:
    """Build the structured session state returned to the tool caller.

    ``output_until_seq`` defaults to the session's last assigned seq when
    the caller does not supply one.
    """
    last_seq = session.next_seq - 1
    return {
        "session_id": session.session_id,
        "command": session.command,
        "cwd": session.cwd,
        "pid": session.pid,
        "status": session.status,
        "exit_code": session.exit_code,
        "use_pty": session.use_pty,
        "last_seq": last_seq,
        "output_until_seq": (
            last_seq if output_until_seq is None else output_until_seq
        ),
        "retained_from_seq": session.retained_from_seq,
        "output_truncated": output_truncated,
        "output": output,
        "error": session.error,
    }
-Use `execute_command` to run queries: +Use `execute_command` with `action=run` to run queries: ```bash sqlite3 -header -column "YOUR SQL QUERY HERE;" @@ -66,7 +66,7 @@ sqlite3 ".schema tablename" Extract the connection parameters from `` (parse `user:password@host:port/database` from the parentheses after `PostgreSQL`). -Use `execute_command` to run queries via `psql`: +Use `execute_command` with `action=run` to run queries via `psql`: ```bash PGPASSWORD= psql -h -p -U -d -c "YOUR SQL QUERY HERE;" diff --git a/tests/test_execute_command_tool.py b/tests/test_execute_command_tool.py index d23ab467..ef988cf1 100644 --- a/tests/test_execute_command_tool.py +++ b/tests/test_execute_command_tool.py @@ -1,4 +1,5 @@ import asyncio +import json import os import re import shlex @@ -23,15 +24,18 @@ def _python_command(code: str) -> str: class TestExecuteCommandTool(unittest.TestCase): def _temp_file_path_from_result(self, result: str) -> str: + """从工具返回文本中提取完整输出临时文件路径。""" match = re.search(r"临时文件: (.+)", result) self.assertIsNotNone(match) return match.group(1).strip() def _run_command(self, command: str, timeout: int = 60) -> str: + """按一次性执行模式运行命令,兼容旧测试断言。""" tool = ExecuteCommandTool(session_id="session-1", user_id="10001") - return asyncio.run(tool.run(command=command, timeout=timeout)) + return asyncio.run(tool.run(action="run", command=command, timeout=timeout)) def test_large_output_is_truncated_before_returning_to_agent(self): + """大输出一次性命令只把预览返回给 Agent,并把完整内容写到临时文件。""" command = _python_command( "import sys; sys.stdout.write('x' * 200000); sys.stdout.flush()" ) @@ -52,6 +56,7 @@ class TestExecuteCommandTool(unittest.TestCase): self.assertGreater(len(file_content), 100000) def test_timeout_returns_partial_output_promptly(self): + """一次性命令超时后应及时返回已经读取到的部分输出。""" command = _python_command( "import time; print('started', flush=True); time.sleep(5)" ) @@ -65,6 +70,7 @@ class TestExecuteCommandTool(unittest.TestCase): self.assertIn("started", result) def 
class TestExecuteCommandSessionTool(unittest.IsolatedAsyncioTestCase):
    """Exercise the managed-session actions of the ``execute_command`` tool."""

    async def asyncSetUp(self):
        """Create the tool under test and the list of sessions to clean up."""
        self.tool = ExecuteCommandTool(session_id="session-1", user_id="10001")
        self._created_sessions: list[str] = []

    async def asyncTearDown(self):
        """Kill every session the test started so none leaks into later cases."""
        for session_id in self._created_sessions:
            await self.tool.run(action="kill", session_id=session_id)

    @staticmethod
    def _loads(result: str) -> dict:
        """Parse the JSON string returned by ``execute_command``."""
        return json.loads(result)

    async def _start(self, command: str, *, use_pty: bool = False) -> dict:
        """Start a background session and register its ID for teardown."""
        payload = self._loads(
            await self.tool.run(action="start", command=command, use_pty=use_pty)
        )
        session_id = payload.get("session_id")
        if session_id:
            self._created_sessions.append(session_id)
        return payload

    async def test_default_action_starts_session_promptly(self):
        """Omitting ``action`` starts a background session and returns quickly."""
        command = _python_command(
            "import time; print('ready', flush=True); time.sleep(1); print('done', flush=True)"
        )

        started_at = time.monotonic()
        start_payload = self._loads(await self.tool.run(command=command, use_pty=False))
        duration = time.monotonic() - started_at

        # Check for the key before indexing it, so a missing ID is reported
        # as an assertion failure instead of a KeyError.
        self.assertIn("session_id", start_payload)
        self._created_sessions.append(start_payload["session_id"])

        self.assertLess(duration, 0.8)
        self.assertEqual(start_payload["status"], "running")

    async def test_read_and_wait_get_incremental_output(self):
        """Successive ``wait`` calls return output incrementally via ``since_seq``."""
        command = _python_command(
            "import time; print('ready', flush=True); time.sleep(1); print('done', flush=True)"
        )
        start_payload = await self._start(command)

        wait_payload = self._loads(
            await self.tool.run(
                action="wait",
                session_id=start_payload["session_id"],
                timeout_ms=200,
                since_seq=0,
            )
        )

        self.assertEqual(wait_payload["status"], "running")
        self.assertIn("ready", wait_payload["output"])

        # Resume from where the previous call stopped.
        final_payload = self._loads(
            await self.tool.run(
                action="wait",
                session_id=start_payload["session_id"],
                timeout_ms=3000,
                since_seq=wait_payload["output_until_seq"],
            )
        )

        self.assertEqual(final_payload["status"], "exited")
        self.assertEqual(final_payload["exit_code"], 0)
        self.assertIn("done", final_payload["output"])

    async def test_write_sends_input_to_running_process(self):
        """The ``write`` action delivers text to the process's stdin."""
        command = _python_command(
            "line = input('name: '); print('hello ' + line, flush=True)"
        )
        start_payload = await self._start(command)

        await self.tool.run(
            action="write",
            session_id=start_payload["session_id"],
            input_text="moviepilot\n",
        )
        wait_payload = self._loads(
            await self.tool.run(
                action="wait",
                session_id=start_payload["session_id"],
                timeout_ms=3000,
                since_seq=0,
            )
        )

        self.assertEqual(wait_payload["status"], "exited")
        self.assertIn("hello moviepilot", wait_payload["output"])

    async def test_kill_stops_long_running_process(self):
        """The ``kill`` action stops a long-running background session."""
        command = _python_command(
            "import time; print('started', flush=True); time.sleep(20)"
        )
        start_payload = await self._start(command)

        wait_payload = self._loads(
            await self.tool.run(
                action="wait",
                session_id=start_payload["session_id"],
                timeout_ms=500,
                since_seq=0,
            )
        )
        kill_payload = self._loads(
            await self.tool.run(action="kill", session_id=start_payload["session_id"])
        )

        self.assertIn("started", wait_payload["output"])
        self.assertIn(kill_payload["status"], {"killed", "exited"})