feat: add managed agent command sessions

This commit is contained in:
jxxghp
2026-05-18 20:17:59 +08:00
parent f5eeeebeba
commit 9076acc52e
7 changed files with 1028 additions and 85 deletions

View File

@@ -307,7 +307,10 @@ class StreamingHandler:
or tool_kwargs.get("path"),
)
if tool_name == "execute_command":
return "command", tool_kwargs.get("command")
return (
"command",
tool_kwargs.get("command") or tool_kwargs.get("session_id"),
)
if tool_name == "ask_user_choice":
return "interaction", tool_kwargs.get("message")
if tool_name.startswith("search_") or tool_name in {"get_search_results"}:

View File

@@ -56,6 +56,7 @@ Tool Calling Strategy:
- Reuse the latest torrent search cache for `get_search_results` and `add_download` instead of re-running the same search unnecessarily.
- Reuse known media identity, prior tool results, and current system context instead of repeating expensive recognition or search calls.
- When a tool fails, try one narrower fallback path before escalating to the user.
- Use `execute_command` for shell work. Its default `action=start` starts a managed background session and returns `session_id`, `status`, `last_seq`, and `output_until_seq`; call the same tool again with `action=read`, `action=wait`, `action=write`, or `action=kill` to poll output, wait in short segments, send stdin, or stop the process.
Media Management Rules:
1. Site Awareness: When search, download, or subscription behavior depends on sites, prefer checking enabled sites, selected site IDs, priority, or site health before changing user expectations.

View File

@@ -1,16 +1,25 @@
"""执行Shell命令工具"""
"""执行 Shell 命令工具"""
from __future__ import annotations
import asyncio
import json
import os
import signal
import subprocess
from dataclasses import dataclass, field
from tempfile import NamedTemporaryFile
from typing import Any, Optional, TextIO, Type
from typing import Any, Literal, Optional, TextIO, Type
from pydantic import BaseModel, Field
from app.agent.tools.base import MoviePilotTool
from app.agent.tools.impl.terminal_session import (
TERMINAL_DEFAULT_READ_BYTES,
TERMINAL_MAX_READ_BYTES,
TERMINAL_WAIT_DEFAULT_MS,
terminal_session_manager,
)
from app.log import logger
@@ -20,6 +29,14 @@ MAX_OUTPUT_PREVIEW_BYTES = 10 * 1024
READ_CHUNK_SIZE = 4096
KILL_GRACE_SECONDS = 3
COMMAND_CONCURRENCY_LIMIT = 2
COMMAND_FORBIDDEN_KEYWORDS = (
"rm -rf /",
":(){ :|:& };:",
"dd if=/dev/zero",
"mkfs",
"reboot",
"shutdown",
)
_command_semaphore = asyncio.Semaphore(COMMAND_CONCURRENCY_LIMIT)
@@ -38,11 +55,13 @@ class _CommandOutput:
@staticmethod
def _clip_text_to_bytes(text: str, byte_limit: int) -> str:
"""按 UTF-8 字节数截断文本,避免截断后出现非法字符。"""
if byte_limit <= 0:
return ""
return text.encode("utf-8")[:byte_limit].decode("utf-8", errors="ignore")
def _write_chunk(self, stream_name: str, text: str) -> None:
"""把输出分片按 stdout/stderr 分段写入临时文件。"""
if not self.temp_file_handle or not text:
return
@@ -56,6 +75,7 @@ class _CommandOutput:
self.temp_file_handle.write(text)
def _ensure_temp_file(self) -> None:
"""首次超出预览上限时创建临时文件并补写已缓存预览。"""
if self.temp_file_handle:
return
@@ -72,6 +92,7 @@ class _CommandOutput:
self._write_chunk(stream_name, chunk)
def close(self) -> None:
"""关闭临时文件句柄,确保输出落盘。"""
if not self.temp_file_handle:
return
self.temp_file_handle.flush()
@@ -79,6 +100,7 @@ class _CommandOutput:
self.temp_file_handle = None
def append(self, stream_name: str, text: str) -> None:
"""追加一段输出,超出预览上限后只保留完整日志文件。"""
if not text:
return
@@ -104,47 +126,141 @@ class _CommandOutput:
@property
def stdout(self) -> str:
    """Return the stdout text currently kept in the preview, stripped of surrounding whitespace."""
    pieces = [text for name, text in self.preview_entries if name == "stdout"]
    return "".join(pieces).strip()
@property
def stderr(self) -> str:
    """Return the stderr text currently kept in the preview, stripped of surrounding whitespace."""
    pieces = [text for name, text in self.preview_entries if name == "stderr"]
    return "".join(pieces).strip()
class ExecuteCommandInput(BaseModel):
"""执行Shell命令工具的输入参数模型"""
"""执行 Shell 命令工具的输入参数模型"""
explanation: str = Field(
..., description="Clear explanation of why this command is being executed"
..., description="Clear explanation of why this command action is needed"
)
action: Optional[Literal["start", "read", "wait", "write", "kill", "run"]] = Field(
"start",
description=(
"Command action. start launches a managed background session and returns "
"session_id. read/wait/write/kill operate on that session. run executes "
"once and waits until completion or timeout."
),
)
command: Optional[str] = Field(
None,
description="Shell command. Required for action=start or action=run.",
)
session_id: Optional[str] = Field(
None,
description="Command session id returned by action=start.",
)
input_text: Optional[str] = Field(
None,
description="Text to send to stdin for action=write. Use \\u0003 for Ctrl+C.",
)
signal_name: Optional[str] = Field(
"TERM",
description="Signal for action=kill, such as TERM, INT, KILL, or 15.",
)
cwd: Optional[str] = Field(
None,
description="Working directory for action=start or action=run.",
)
env: Optional[dict[str, Any]] = Field(
None,
description="Additional environment variables for action=start.",
)
use_pty: Optional[bool] = Field(
True,
description="Use a pseudo terminal for action=start when supported.",
)
since_seq: Optional[int] = Field(
None,
description="For action=read/wait, return output chunks after this seq.",
)
max_bytes: Optional[int] = Field(
TERMINAL_DEFAULT_READ_BYTES,
description="For action=read/wait, maximum output bytes to return.",
)
timeout_ms: Optional[int] = Field(
TERMINAL_WAIT_DEFAULT_MS,
description="For action=wait, maximum segmented wait time in milliseconds.",
)
command: str = Field(..., description="The shell command to execute")
timeout: Optional[int] = Field(
60, description="Max execution time in seconds (default: 60)"
60,
description="For action=run, max execution time in seconds.",
)
class ExecuteCommandTool(MoviePilotTool):
"""统一执行和管理 Shell 命令的 Agent 工具。"""
name: str = "execute_command"
description: str = (
"Safely execute shell commands on the server. Useful for system "
"maintenance, checking status, or running custom scripts. Includes "
"timeout, concurrency, and output preview limits."
"Start and manage shell commands on the server. By default action=start "
"launches a background session and immediately returns session_id/status/"
"last_seq/output_until_seq. Call the same tool with action=read, wait, "
"write, or kill to poll output, wait in short segments, send stdin, or "
"terminate it. Use action=run only when a one-shot bounded command result "
"is preferred."
)
args_schema: Type[BaseModel] = ExecuteCommandInput
require_admin: bool = True
result_max_chars = TERMINAL_MAX_READ_BYTES + 4096
def get_tool_message(self, **kwargs) -> Optional[str]:
"""根据命令生成友好的提示消息"""
command = kwargs.get("command", "")
return f"执行系统命令: {command}"
"""根据命令动作生成友好的提示消息"""
action = kwargs.get("action") or "start"
command = kwargs.get("command")
session_id = kwargs.get("session_id")
if action in {"start", "run"}:
return f"执行系统命令: {command or ''}"
if action == "read":
return f"读取命令输出: {session_id or ''}"
if action == "wait":
return f"等待命令会话: {session_id or ''}"
if action == "write":
return f"写入命令输入: {session_id or ''}"
if action == "kill":
return f"终止命令会话: {session_id or ''}"
return f"处理命令会话: {session_id or command or ''}"
@staticmethod
def _dump(payload: dict[str, Any]) -> str:
"""把结构化命令会话结果转换为 Agent 容易解析的 JSON 字符串。"""
return json.dumps(payload, ensure_ascii=False, indent=2)
@staticmethod
def _require_session_id(session_id: Optional[str]) -> str:
"""校验会话型 action 必须传入 session_id。"""
if not session_id:
raise ValueError("action 需要传入 session_id")
return session_id
@staticmethod
def _require_command(command: Optional[str]) -> str:
"""校验启动型 action 必须传入 command。"""
if not command or not command.strip():
raise ValueError("action 需要传入 command")
return command
@staticmethod
def _validate_command(command: str) -> None:
    """Reject commands containing any entry of COMMAND_FORBIDDEN_KEYWORDS.

    This is a plain substring check. The first keyword (in tuple order) found
    in ``command`` is the one reported.

    Raises:
        ValueError: when a forbidden keyword occurs in ``command``.
    """
    offending = next(
        (keyword for keyword in COMMAND_FORBIDDEN_KEYWORDS if keyword in command),
        None,
    )
    if offending is not None:
        raise ValueError(f"命令包含禁止使用的关键字 '{offending}'")
@staticmethod
def _normalize_timeout(timeout: Optional[int]) -> tuple[int, Optional[str]]:
"""限制命令最长运行时间,避免 Agent 传入过大的 timeout"""
"""限制一次性执行命令最长运行时间。"""
try:
normalized = int(timeout or DEFAULT_TIMEOUT_SECONDS)
except (TypeError, ValueError):
@@ -161,7 +277,7 @@ class ExecuteCommandTool(MoviePilotTool):
@staticmethod
def _subprocess_kwargs() -> dict:
"""子进程创建独立进程组,便于超时场景清理整棵子进程。"""
"""一次性命令创建独立进程组,便于超时清理整棵子进程。"""
kwargs = {
"stdin": subprocess.DEVNULL,
"stdout": asyncio.subprocess.PIPE,
@@ -179,17 +295,16 @@ class ExecuteCommandTool(MoviePilotTool):
stream_name: str,
output: _CommandOutput,
) -> None:
"""按块读取输出,始终只把前 10KB 保留在返回结果中。"""
"""按块读取一次性命令输出,只把前 10KB 保留在返回结果中。"""
while True:
chunk = await stream.read(READ_CHUNK_SIZE)
if not chunk:
break
output.append(stream_name, chunk.decode("utf-8", errors="replace"))
@staticmethod
def _terminate_process(process: Any, sig: int):
"""向进程组发送终止信号不支持进程组的平台回退为单进程终止。"""
def _terminate_process(process: Any, sig: int) -> None:
"""向进程组发送终止信号不支持进程组的平台回退为单进程终止。"""
try:
if os.name == "posix":
os.killpg(process.pid, sig)
@@ -230,7 +345,7 @@ class ExecuteCommandTool(MoviePilotTool):
@staticmethod
async def _finish_reader_tasks(reader_tasks: list[asyncio.Task]) -> None:
"""等待输出读取任务退出,异常只记录不影响工具返回。"""
"""等待一次性命令输出读取任务退出,异常只记录不影响工具返回。"""
if not reader_tasks:
return
done, pending = await asyncio.wait(reader_tasks, timeout=1)
@@ -244,7 +359,7 @@ class ExecuteCommandTool(MoviePilotTool):
logger.debug("命令输出读取任务异常: %s", result)
@staticmethod
def _format_result(
def _format_run_result(
*,
exit_code: Optional[int],
output: _CommandOutput,
@@ -252,6 +367,7 @@ class ExecuteCommandTool(MoviePilotTool):
timed_out: bool,
timeout_note: Optional[str],
) -> str:
"""格式化 action=run 的兼容文本结果。"""
if timed_out:
result = f"命令执行超时 (限制: {timeout}秒,已终止进程)"
else:
@@ -260,11 +376,7 @@ class ExecuteCommandTool(MoviePilotTool):
if timeout_note:
result += f"\n\n提示:\n{timeout_note}"
if output.temp_file_path:
file_note = (
"截至命令终止前的完整输出"
if timed_out
else "完整输出"
)
file_note = "截至命令终止前的完整输出" if timed_out else "完整输出"
result += (
"\n\n提示:\n"
f"命令输出超过 10KB仅返回前 {MAX_OUTPUT_PREVIEW_BYTES} 字节内容。\n"
@@ -281,65 +393,129 @@ class ExecuteCommandTool(MoviePilotTool):
result += "\n\n(无输出内容)"
return result
async def run(self, command: str, timeout: Optional[int] = 60, **kwargs) -> str:
logger.info(
f"执行工具: {self.name}, 参数: command={command}, timeout={timeout}"
)
# 简单安全过滤
forbidden_keywords = [
"rm -rf /",
":(){ :|:& };:",
"dd if=/dev/zero",
"mkfs",
"reboot",
"shutdown",
]
for keyword in forbidden_keywords:
if keyword in command:
return f"错误:命令包含禁止使用的关键字 '{keyword}'"
async def _run_once(
    self,
    *,
    command: str,
    timeout: Optional[int],
    cwd: Optional[str] = None,
) -> str:
    """Run a command once and return a text result after it finishes or times out.

    Args:
        command: Shell command line; checked by ``_validate_command`` first.
        timeout: Requested limit in seconds, normalized by ``_normalize_timeout``.
        cwd: Optional working directory passed to the subprocess.

    Returns:
        The text produced by ``_format_run_result``.

    Raises:
        ValueError: propagated from ``_validate_command``.
    """
    self._validate_command(command)
    normalized_timeout, timeout_note = self._normalize_timeout(timeout)
    # NOTE(review): the source view has lost indentation; the whole run is
    # assumed to happen while holding the semaphore, as in the previous
    # implementation — confirm against the real file.
    async with _command_semaphore:
        process = await asyncio.create_subprocess_shell(
            command,
            cwd=cwd,
            **self._subprocess_kwargs(),
        )
        output = _CommandOutput(preview_limit_bytes=MAX_OUTPUT_PREVIEW_BYTES)
        wait_task = asyncio.create_task(process.wait())
        # One reader task per stream, so output is consumed while the process runs.
        reader_tasks = [
            asyncio.create_task(self._read_stream(process.stdout, "stdout", output)),
            asyncio.create_task(self._read_stream(process.stderr, "stderr", output)),
        ]
        timed_out = False
        try:
            # shield() keeps wait_task alive when wait_for times out, so it can
            # still be handed to _cleanup_process below.
            await asyncio.wait_for(
                asyncio.shield(wait_task), timeout=normalized_timeout
            )
        except asyncio.TimeoutError:
            timed_out = True
            await self._cleanup_process(process, wait_task)
        try:
            await self._finish_reader_tasks(reader_tasks)
        finally:
            # Always close the output object, even if finishing the readers raised.
            output.close()
        return self._format_run_result(
            exit_code=process.returncode,
            output=output,
            timeout=normalized_timeout,
            timed_out=timed_out,
            timeout_note=timeout_note,
        )
async def run(
self,
action: Optional[str] = "start",
command: Optional[str] = None,
session_id: Optional[str] = None,
input_text: Optional[str] = None,
signal_name: Optional[str] = "TERM",
cwd: Optional[str] = None,
env: Optional[dict[str, Any]] = None,
use_pty: Optional[bool] = True,
since_seq: Optional[int] = None,
max_bytes: Optional[int] = TERMINAL_DEFAULT_READ_BYTES,
timeout_ms: Optional[int] = TERMINAL_WAIT_DEFAULT_MS,
timeout: Optional[int] = 60,
**kwargs,
) -> str:
"""执行命令动作:默认后台启动,也支持读取、等待、写入、终止和一次性执行。"""
normalized_action = (action or "start").strip().lower()
logger.info(
"执行工具: %s, action=%s, command=%s, session_id=%s",
self.name,
normalized_action,
command,
session_id,
)
try:
async with _command_semaphore:
# 命令输出可能非常大,必须边读边落盘,不能使用 communicate() 一次性收集。
process = await asyncio.create_subprocess_shell(
command, **self._subprocess_kwargs()
if normalized_action == "start":
start_command = self._require_command(command)
self._validate_command(start_command)
payload = await terminal_session_manager.start(
command=start_command,
cwd=cwd,
env=env,
use_pty=use_pty,
)
output = _CommandOutput(preview_limit_bytes=MAX_OUTPUT_PREVIEW_BYTES)
wait_task = asyncio.create_task(process.wait())
reader_tasks = [
asyncio.create_task(
self._read_stream(process.stdout, "stdout", output)
),
asyncio.create_task(
self._read_stream(process.stderr, "stderr", output)
),
]
return self._dump(payload)
timed_out = False
try:
await asyncio.wait_for(
asyncio.shield(wait_task), timeout=normalized_timeout
)
except asyncio.TimeoutError:
timed_out = True
await self._cleanup_process(process, wait_task)
if normalized_action == "read":
payload = await terminal_session_manager.read(
session_id=self._require_session_id(session_id),
since_seq=since_seq,
max_bytes=max_bytes,
)
return self._dump(payload)
try:
await self._finish_reader_tasks(reader_tasks)
finally:
output.close()
if normalized_action == "wait":
payload = await terminal_session_manager.wait(
session_id=self._require_session_id(session_id),
timeout_ms=timeout_ms,
since_seq=since_seq,
max_bytes=max_bytes,
)
return self._dump(payload)
return self._format_result(
exit_code=process.returncode,
output=output,
timeout=normalized_timeout,
timed_out=timed_out,
timeout_note=timeout_note,
if normalized_action == "write":
payload = await terminal_session_manager.write(
session_id=self._require_session_id(session_id),
input_text=input_text or "",
)
return self._dump(payload)
if normalized_action == "kill":
payload = await terminal_session_manager.kill(
session_id=self._require_session_id(session_id),
sig=signal_name,
)
return self._dump(payload)
if normalized_action == "run":
return await self._run_once(
command=self._require_command(command),
timeout=timeout,
cwd=cwd,
)
except Exception as e:
logger.error(f"执行命令失败: {e}", exc_info=True)
return f"执行命令时发生错误: {str(e)}"
raise ValueError(f"不支持的 action: {action}")
except Exception as err:
logger.error("执行命令 action 失败: %s", err, exc_info=True)
return self._dump({"error": str(err), "status": "error", "action": normalized_action})

View File

@@ -0,0 +1,628 @@
"""Agent 终端会话管理器。"""
from __future__ import annotations
import asyncio
import errno
import os
import signal
import subprocess
import time
import uuid
from dataclasses import dataclass, field
from pathlib import Path
from typing import Any, Optional
from app.core.config import settings
from app.log import logger
if os.name == "posix":
import fcntl as _fcntl
import pty as _pty
else:
_fcntl = None
_pty = None
TERMINAL_CONCURRENCY_LIMIT = 4
TERMINAL_RETENTION_SECONDS = 30 * 60
TERMINAL_MAX_RETAINED_BYTES = 1024 * 1024
TERMINAL_DEFAULT_READ_BYTES = 10 * 1024
TERMINAL_MAX_READ_BYTES = 64 * 1024
TERMINAL_READ_CHUNK_SIZE = 4096
TERMINAL_PTY_POLL_INTERVAL = 0.05
TERMINAL_WAIT_DEFAULT_MS = 1000
TERMINAL_WAIT_MAX_MS = 60 * 1000
TERMINAL_KILL_GRACE_SECONDS = 3
TERMINAL_FORBIDDEN_KEYWORDS = (
"rm -rf /",
":(){ :|:& };:",
"dd if=/dev/zero",
"mkfs",
"reboot",
"shutdown",
)
@dataclass
class _TerminalChunk:
    """One piece of captured terminal output, numbered so readers can filter incrementally by ``seq``."""

    # Sequence number assigned by the owning session when the chunk was appended.
    seq: int
    # Name of the stream the data came from (e.g. "stdout", "stderr" or "pty").
    stream: str
    # Decoded output text.
    text: str
    # Size of the raw data in bytes, used for retention accounting.
    byte_size: int
    # time.time() timestamp taken when the chunk was created.
    created_at: float
@dataclass
class _TerminalSession:
"""保存一个后台命令会话的进程、输出和状态。"""
session_id: str
command: str
cwd: str
pid: int
use_pty: bool
created_at: float = field(default_factory=time.time)
updated_at: float = field(default_factory=time.time)
status: str = "running"
exit_code: Optional[int] = None
process: Optional[asyncio.subprocess.Process] = None
master_fd: Optional[int] = None
chunks: list[_TerminalChunk] = field(default_factory=list)
next_seq: int = 1
retained_from_seq: int = 1
retained_bytes: int = 0
kill_requested: bool = False
error: Optional[str] = None
reader_tasks: list[asyncio.Task] = field(default_factory=list)
wait_task: Optional[asyncio.Task] = None
def append_output(self, stream: str, data: bytes) -> None:
"""追加输出并按容量上限丢弃最旧分片,避免长任务撑爆内存。"""
if not data:
return
text = data.decode("utf-8", errors="replace")
chunk = _TerminalChunk(
seq=self.next_seq,
stream=stream,
text=text,
byte_size=len(data),
created_at=time.time(),
)
self.next_seq += 1
self.chunks.append(chunk)
self.retained_bytes += chunk.byte_size
self.updated_at = chunk.created_at
self._trim_output()
def _trim_output(self) -> None:
"""移除超出保留上限的旧输出分片。"""
while self.retained_bytes > TERMINAL_MAX_RETAINED_BYTES and self.chunks:
removed = self.chunks.pop(0)
self.retained_bytes -= removed.byte_size
self.retained_from_seq = removed.seq + 1
def mark_finished(self, exit_code: Optional[int]) -> None:
"""标记进程已经结束,并记录退出码。"""
self.exit_code = exit_code
self.status = "killed" if self.kill_requested else "exited"
self.updated_at = time.time()
def mark_error(self, message: str) -> None:
"""标记会话异常,保留错误信息供后续读取。"""
self.error = message
self.status = "error"
self.updated_at = time.time()
def close_pty(self) -> None:
"""关闭父进程持有的 PTY master fd。"""
if self.master_fd is None:
return
try:
os.close(self.master_fd)
except OSError:
pass
self.master_fd = None
class _TerminalSessionManager:
"""管理 Agent 后台终端会话的生命周期。"""
def __init__(self) -> None:
"""初始化会话表和并发保护锁。"""
self._sessions: dict[str, _TerminalSession] = {}
self._lock = asyncio.Lock()
@staticmethod
def _normalize_bool(value: Any, default: bool = True) -> bool:
"""兼容 LLM 或 HTTP 传入的 bool/string/int 布尔值。"""
if value is None:
return default
if isinstance(value, bool):
return value
if isinstance(value, str):
return value.strip().lower() not in {"false", "0", "no", "off"}
return bool(value)
@staticmethod
def _normalize_cwd(cwd: Optional[str]) -> str:
"""解析工作目录,未传入时默认使用 MoviePilot 项目根目录。"""
if not cwd:
return str(settings.ROOT_PATH)
path = Path(cwd).expanduser()
if not path.is_absolute():
path = (settings.ROOT_PATH / path).resolve()
else:
path = path.resolve()
if not path.exists():
raise FileNotFoundError(f"工作目录不存在: {path}")
if not path.is_dir():
raise NotADirectoryError(f"工作目录不是目录: {path}")
return str(path)
@staticmethod
def _build_env(env: Optional[dict[str, Any]]) -> dict[str, str]:
"""合并环境变量,并把值稳定转换为字符串。"""
merged_env = os.environ.copy()
if not env:
return merged_env
for key, value in env.items():
if value is None:
continue
merged_env[str(key)] = str(value)
return merged_env
@staticmethod
def _validate_command(command: str) -> None:
"""拒绝明显危险或空白命令。"""
if not command or not command.strip():
raise ValueError("命令不能为空")
for keyword in TERMINAL_FORBIDDEN_KEYWORDS:
if keyword in command:
raise ValueError(f"命令包含禁止使用的关键字 '{keyword}'")
@staticmethod
def _set_nonblocking(fd: int) -> None:
    """Add O_NONBLOCK to the file status flags of ``fd``.

    Raises:
        RuntimeError: the ``fcntl`` module is unavailable on this platform.
    """
    if _fcntl is None:
        raise RuntimeError("当前平台不支持 PTY 非阻塞设置")
    current_flags = _fcntl.fcntl(fd, _fcntl.F_GETFL)
    wanted_flags = current_flags | os.O_NONBLOCK
    _fcntl.fcntl(fd, _fcntl.F_SETFL, wanted_flags)
@staticmethod
def _pipe_subprocess_kwargs() -> dict[str, Any]:
"""生成普通管道模式的子进程参数。"""
kwargs: dict[str, Any] = {
"stdin": asyncio.subprocess.PIPE,
"stdout": asyncio.subprocess.PIPE,
"stderr": asyncio.subprocess.PIPE,
}
if os.name == "posix":
kwargs["start_new_session"] = True
elif os.name == "nt":
kwargs["creationflags"] = subprocess.CREATE_NEW_PROCESS_GROUP
return kwargs
async def start(
    self,
    *,
    command: str,
    cwd: Optional[str] = None,
    env: Optional[dict[str, Any]] = None,
    use_pty: Any = True,
) -> dict[str, Any]:
    """Start a background command and return its session payload immediately.

    Args:
        command: Shell command line; validated before anything is spawned.
        cwd: Working directory, normalized by ``_normalize_cwd``.
        env: Extra environment variables merged over the current environment.
        use_pty: Truthy value (bool/str/int) requesting a PTY; only honored on POSIX.

    Returns:
        The structured session payload with empty ``output``.

    Raises:
        ValueError: the command is blank or contains a forbidden keyword.
        FileNotFoundError / NotADirectoryError: the working directory is invalid.
        RuntimeError: the number of running sessions has reached the limit.
    """
    self._validate_command(command)
    normalized_cwd = self._normalize_cwd(cwd)
    normalized_env = self._build_env(env)
    should_use_pty = self._normalize_bool(use_pty, default=True) and os.name == "posix"
    starter = self._start_pty_session if should_use_pty else self._start_pipe_session
    # Hold the lock across the limit check, the spawn and the registration.
    # Releasing it in between would let concurrent starts all pass the check
    # before any of them is counted, exceeding the session limit.
    async with self._lock:
        self._cleanup_finished_sessions_locked()
        if self._active_session_count_locked() >= TERMINAL_CONCURRENCY_LIMIT:
            raise RuntimeError(
                f"后台终端会话数已达到上限 {TERMINAL_CONCURRENCY_LIMIT}"
            )
        session = await starter(command, normalized_cwd, normalized_env)
        self._sessions[session.session_id] = session
    logger.info(
        "启动后台终端会话: session_id=%s, pid=%s, use_pty=%s, command=%s",
        session.session_id,
        session.pid,
        session.use_pty,
        command,
    )
    # Yield once so the freshly created reader/wait tasks get a chance to start.
    await asyncio.sleep(0)
    return self._session_payload(session, output="", output_truncated=False)
async def _start_pty_session(
    self, command: str, cwd: str, env: dict[str, str]
) -> _TerminalSession:
    """Fork a child attached to a new PTY and run ``command`` through a login shell.

    Args:
        command: Shell command line, passed to the shell as ``-lc <command>``.
        cwd: Working directory the child changes into before exec.
        env: Complete environment for the child; ``SHELL`` selects the shell,
            falling back to ``/bin/sh``.

    Returns:
        A session whose reader and wait tasks are already scheduled.

    Raises:
        RuntimeError: the ``pty`` module is unavailable on this platform.
    """
    if _pty is None:
        raise RuntimeError("当前平台不支持 PTY 会话")
    pid, master_fd = _pty.fork()
    if pid == 0:
        # Child process. It must never return from this branch: falling
        # through would continue running the parent's event-loop code in a
        # forked copy. Any failure before or during exec ends in _exit().
        try:
            os.chdir(cwd)
            shell = env.get("SHELL") or "/bin/sh"
            os.execve(shell, [shell, "-lc", command], env)
        except Exception as err:
            try:
                # fd 2 is the PTY slave here, so the parent's reader sees this.
                os.write(2, f"failed to start command: {err}\n".encode("utf-8", "replace"))
            except OSError:
                pass
        finally:
            # Only reached when exec did not replace the process image.
            os._exit(127)
    self._set_nonblocking(master_fd)
    session = _TerminalSession(
        session_id=f"term_{uuid.uuid4().hex[:12]}",
        command=command,
        cwd=cwd,
        pid=pid,
        use_pty=True,
        master_fd=master_fd,
    )
    session.reader_tasks.append(asyncio.create_task(self._read_pty(session)))
    session.wait_task = asyncio.create_task(self._wait_pty_process(session))
    return session
async def _start_pipe_session(
    self, command: str, cwd: str, env: dict[str, str]
) -> _TerminalSession:
    """Spawn ``command`` through the shell with piped stdin/stdout/stderr.

    One reader task is scheduled per available output stream (stdout first,
    then stderr), followed by the task that waits for process exit.
    """
    process = await asyncio.create_subprocess_shell(
        command,
        cwd=cwd,
        env=env,
        **self._pipe_subprocess_kwargs(),
    )
    new_id = f"term_{uuid.uuid4().hex[:12]}"
    session = _TerminalSession(
        session_id=new_id,
        command=command,
        cwd=cwd,
        pid=process.pid or 0,
        use_pty=False,
        process=process,
    )
    for stream_name, reader in (("stdout", process.stdout), ("stderr", process.stderr)):
        if reader:
            session.reader_tasks.append(
                asyncio.create_task(self._read_pipe(session, reader, stream_name))
            )
    session.wait_task = asyncio.create_task(self._wait_pipe_process(session))
    return session
async def _read_pty(self, session: _TerminalSession) -> None:
    """Poll the PTY master fd and append whatever arrives to the session output.

    The loop ends when the fd has been cleared on the session, a read returns
    no data, or the read raises an OSError other than BlockingIOError.
    """
    quiet_errnos = {errno.EIO, errno.EBADF}
    while (fd := session.master_fd) is not None:
        try:
            data = os.read(fd, TERMINAL_READ_CHUNK_SIZE)
        except BlockingIOError:
            # Nothing to read yet on the non-blocking fd; back off briefly.
            await asyncio.sleep(TERMINAL_PTY_POLL_INTERVAL)
            continue
        except OSError as err:
            # EIO/EBADF are not logged; any other errno is logged at debug level.
            if err.errno not in quiet_errnos:
                logger.debug("PTY 输出读取异常: session_id=%s, error=%s", session.session_id, err)
            break
        if not data:
            break
        session.append_output("pty", data)
async def _read_pipe(
    self,
    session: _TerminalSession,
    stream: asyncio.StreamReader,
    stream_name: str,
) -> None:
    """Read ``stream`` in fixed-size pieces until EOF, appending each piece to the session."""
    while data := await stream.read(TERMINAL_READ_CHUNK_SIZE):
        session.append_output(stream_name, data)
async def _wait_pty_process(self, session: _TerminalSession) -> None:
    """Wait for the PTY child to exit, record the outcome, then finish the readers and close the PTY.

    The session ends up "exited"/"killed" with an exit code, or "error" if
    waiting failed unexpectedly. Reader tasks are always finished and the
    master fd is always closed afterwards.
    """
    try:
        # os.waitpid is a blocking call, so it runs in a worker thread.
        _, status = await asyncio.to_thread(os.waitpid, session.pid, 0)
        exit_code = os.waitstatus_to_exitcode(status)
        session.mark_finished(exit_code)
    except ChildProcessError:
        # waitpid reported no such child (presumably already reaped elsewhere);
        # finish with whatever exit code the session already holds.
        session.mark_finished(session.exit_code)
    except Exception as err:
        session.mark_error(str(err))
        logger.warning("等待 PTY 进程失败: session_id=%s, error=%s", session.session_id, err)
    finally:
        # Finish the readers before closing the fd they read from.
        await self._finish_reader_tasks(session)
        session.close_pty()
async def _wait_pipe_process(self, session: _TerminalSession) -> None:
"""等待普通管道子进程结束并完成输出读取任务收尾。"""
try:
if not session.process:
session.mark_error("进程对象不存在")
return
exit_code = await session.process.wait()
session.mark_finished(exit_code)
except Exception as err:
session.mark_error(str(err))
logger.warning("等待管道进程失败: session_id=%s, error=%s", session.session_id, err)
finally:
await self._finish_reader_tasks(session)
async def _finish_reader_tasks(self, session: _TerminalSession) -> None:
"""等待输出读取任务退出,超时后取消残留任务。"""
if not session.reader_tasks:
return
done, pending = await asyncio.wait(session.reader_tasks, timeout=1)
for task in pending:
task.cancel()
await asyncio.gather(*done, *pending, return_exceptions=True)
async def read(
    self,
    *,
    session_id: str,
    since_seq: Optional[int] = None,
    max_bytes: Optional[int] = TERMINAL_DEFAULT_READ_BYTES,
) -> dict[str, Any]:
    """Return the session payload together with retained output after ``since_seq``.

    Raises:
        KeyError: no session with ``session_id`` exists.
    """
    session = self.get_session(session_id)
    collected = self._collect_output(
        session,
        since_seq=since_seq,
        max_bytes=max_bytes,
    )
    text, truncated, until_seq = collected
    return self._session_payload(
        session,
        output=text,
        output_truncated=truncated,
        output_until_seq=until_seq,
    )
async def wait(
    self,
    *,
    session_id: str,
    timeout_ms: Optional[int] = TERMINAL_WAIT_DEFAULT_MS,
    since_seq: Optional[int] = None,
    max_bytes: Optional[int] = TERMINAL_DEFAULT_READ_BYTES,
) -> dict[str, Any]:
    """Wait briefly for the session to finish, then return its payload and new output.

    The wait is bounded by the normalized ``timeout_ms``; a timeout is not an
    error. The payload additionally carries ``wait_timeout_ms``.

    Raises:
        KeyError: no session with ``session_id`` exists.
    """
    session = self.get_session(session_id)
    normalized_timeout = self._normalize_wait_timeout(timeout_ms)
    pending_wait = session.wait_task
    if pending_wait and not pending_wait.done():
        try:
            # shield() keeps the session's own wait task running when we time out.
            await asyncio.wait_for(
                asyncio.shield(pending_wait),
                timeout=normalized_timeout / 1000,
            )
        except asyncio.TimeoutError:
            pass
    text, truncated, until_seq = self._collect_output(
        session,
        since_seq=since_seq,
        max_bytes=max_bytes,
    )
    result = self._session_payload(
        session,
        output=text,
        output_truncated=truncated,
        output_until_seq=until_seq,
    )
    result["wait_timeout_ms"] = normalized_timeout
    return result
async def write(self, *, session_id: str, input_text: str) -> dict[str, Any]:
    """Send text to the session's stdin (the PTY master fd in PTY mode).

    Args:
        session_id: Session to write to; must still be running.
        input_text: Text to send, encoded as UTF-8.

    Returns:
        The session payload plus ``written_bytes``, the number of bytes
        actually handed to the process.

    Raises:
        KeyError: no session with ``session_id`` exists.
        RuntimeError: the session is not running, or its input is closed.
    """
    session = self.get_session(session_id)
    if session.status != "running":
        raise RuntimeError(f"会话已结束,当前状态: {session.status}")
    data = (input_text or "").encode("utf-8")
    if session.use_pty:
        if session.master_fd is None:
            raise RuntimeError("PTY 已关闭")
        # The master fd is non-blocking: os.write may accept only part of the
        # data or raise BlockingIOError when the PTY buffer is full. Keep
        # writing the remainder, and give up after a few seconds so a child
        # that never reads cannot hang the tool call.
        written = 0
        deadline = time.monotonic() + 5.0
        while written < len(data):
            fd = session.master_fd
            if fd is None or session.status != "running":
                break
            try:
                written += os.write(fd, data[written:])
            except BlockingIOError:
                if time.monotonic() >= deadline:
                    break
                await asyncio.sleep(TERMINAL_PTY_POLL_INTERVAL)
    else:
        if not session.process or not session.process.stdin:
            raise RuntimeError("进程 stdin 不可写")
        session.process.stdin.write(data)
        await session.process.stdin.drain()
        written = len(data)
    session.updated_at = time.time()
    payload = self._session_payload(session, output="", output_truncated=False)
    payload["written_bytes"] = written
    return payload
async def kill(
    self,
    *,
    session_id: str,
    sig: Optional[str | int] = "TERM",
) -> dict[str, Any]:
    """Signal the session's process and wait briefly for it to exit.

    If the process is still alive after the grace period, a forced kill
    signal is sent and the exit is awaited once more, so the returned status
    reflects the final state whenever possible. A session that is not running
    is returned unchanged.

    Args:
        session_id: Session to terminate.
        sig: Signal name or number, resolved by ``_resolve_signal``.

    Raises:
        KeyError: no session with ``session_id`` exists.
    """
    session = self.get_session(session_id)
    if session.status != "running":
        return self._session_payload(session, output="", output_truncated=False)
    session.kill_requested = True
    signal_number = self._resolve_signal(sig)
    self._send_signal(session, signal_number)
    wait_task = session.wait_task
    if wait_task and not wait_task.done():
        try:
            await asyncio.wait_for(
                asyncio.shield(wait_task),
                timeout=TERMINAL_KILL_GRACE_SECONDS,
            )
        except asyncio.TimeoutError:
            force_signal = getattr(signal, "SIGKILL", signal.SIGTERM)
            self._send_signal(session, force_signal)
            # Wait again after escalating; otherwise the payload below would
            # still report "running" for a process that is about to die.
            try:
                await asyncio.wait_for(
                    asyncio.shield(wait_task),
                    timeout=TERMINAL_KILL_GRACE_SECONDS,
                )
            except asyncio.TimeoutError:
                pass
    return self._session_payload(session, output="", output_truncated=False)
def get_session(self, session_id: str) -> _TerminalSession:
    """Look up a session by id.

    Raises:
        KeyError: no session is stored under ``session_id``.
    """
    found = self._sessions.get(session_id)
    if found:
        return found
    raise KeyError(f"终端会话不存在: {session_id}")
@staticmethod
def _normalize_wait_timeout(timeout_ms: Optional[int]) -> int:
"""限制 wait 单次等待时间,避免工具调用长时间占用模型回合。"""
try:
normalized = int(timeout_ms or TERMINAL_WAIT_DEFAULT_MS)
except (TypeError, ValueError):
normalized = TERMINAL_WAIT_DEFAULT_MS
if normalized < 0:
return 0
return min(normalized, TERMINAL_WAIT_MAX_MS)
@staticmethod
def _normalize_read_limit(max_bytes: Optional[int]) -> int:
    """Clamp the per-read output size to at most TERMINAL_MAX_READ_BYTES.

    Missing, zero, negative or unparsable values select
    TERMINAL_DEFAULT_READ_BYTES.
    """
    try:
        requested = int(max_bytes or TERMINAL_DEFAULT_READ_BYTES)
    except (TypeError, ValueError):
        requested = TERMINAL_DEFAULT_READ_BYTES
    if requested > 0:
        return min(requested, TERMINAL_MAX_READ_BYTES)
    return TERMINAL_DEFAULT_READ_BYTES
def _collect_output(
    self,
    session: _TerminalSession,
    *,
    since_seq: Optional[int],
    max_bytes: Optional[int],
) -> tuple[str, bool, int]:
    """Collect retained output after ``since_seq``, bounded by a byte limit.

    Args:
        session: Session whose retained chunks are read.
        since_seq: Only chunks with a larger seq are returned; ``None``
            returns everything retained.
        max_bytes: Requested byte limit, normalized by ``_normalize_read_limit``.

    Returns:
        ``(text, truncated, until_seq)``. ``truncated`` is True when the byte
        limit cut the output short or when chunks after ``since_seq`` were
        already dropped from retention. ``until_seq`` is the seq of the last
        chunk returned in full, or the session's latest seq when nothing was
        truncated.
    """
    read_limit = self._normalize_read_limit(max_bytes)
    selected_chunks = [
        chunk
        for chunk in session.chunks
        if since_seq is None or chunk.seq > since_seq
    ]
    output_parts: list[str] = []
    output_bytes = 0
    output_truncated = False
    last_stream: Optional[str] = None
    # `or` binds looser than `-`: a since_seq of None or 0 falls back to the
    # seq just before the oldest retained chunk.
    output_until_seq = since_seq or session.retained_from_seq - 1
    for chunk in selected_chunks:
        prefix = self._stream_prefix(chunk.stream, last_stream, session.use_pty)
        text = f"{prefix}{chunk.text}" if prefix else chunk.text
        encoded = text.encode("utf-8")
        remaining = read_limit - output_bytes
        if len(encoded) > remaining:
            if remaining > 0:
                # Emit the part that still fits; errors="ignore" drops a
                # character cut in half by the byte slice.
                # NOTE(review): output_until_seq is not advanced for this
                # partial chunk, so a follow-up read starting from
                # output_until_seq returns the chunk again from its start —
                # confirm this repetition is intended.
                output_parts.append(
                    encoded[:remaining].decode("utf-8", errors="ignore")
                )
            output_truncated = True
            break
        output_parts.append(text)
        output_bytes += len(encoded)
        last_stream = chunk.stream
        output_until_seq = chunk.seq
    # The caller asked for output older than what is still retained.
    if since_seq is not None and since_seq < session.retained_from_seq - 1:
        output_truncated = True
    if not output_truncated:
        # Everything was delivered: point at the session's latest seq.
        output_until_seq = session.next_seq - 1
    return "".join(output_parts), output_truncated, output_until_seq
@staticmethod
def _stream_prefix(stream: str, last_stream: Optional[str], use_pty: bool) -> str:
"""为普通管道输出增加 stdout/stderr 分段标识。"""
if use_pty or stream == last_stream:
return ""
title = "标准输出" if stream == "stdout" else "错误输出"
return f"\n[{title}]\n"
@staticmethod
def _resolve_signal(sig: Optional[str | int]) -> int:
"""解析字符串或数字形式的信号名。"""
if isinstance(sig, int):
return sig
signal_name = str(sig or "TERM").strip().upper()
if signal_name.isdigit():
return int(signal_name)
if not signal_name.startswith("SIG"):
signal_name = f"SIG{signal_name}"
return int(getattr(signal, signal_name, signal.SIGTERM))
@staticmethod
def _send_signal(session: _TerminalSession, sig: int) -> None:
    """Deliver ``sig`` to the session's process.

    On POSIX the signal goes to the process group whose id is the session
    pid. Elsewhere the process object is used: ``kill()`` for SIGKILL,
    ``terminate()`` for anything else. A process that no longer exists is
    ignored.
    """
    try:
        if os.name != "posix":
            process = session.process
            if process:
                if sig == getattr(signal, "SIGKILL", None):
                    process.kill()
                else:
                    process.terminate()
            return
        os.killpg(session.pid, sig)
    except ProcessLookupError:
        pass
def _active_session_count_locked(self) -> int:
"""统计仍在运行的会话数量。"""
return sum(1 for session in self._sessions.values() if session.status == "running")
def _cleanup_finished_sessions_locked(self) -> None:
"""清理已经结束且超过保留时间的会话。"""
now = time.time()
expired_ids = [
session_id
for session_id, session in self._sessions.items()
if session.status != "running"
and now - session.updated_at > TERMINAL_RETENTION_SECONDS
]
for session_id in expired_ids:
session = self._sessions.pop(session_id)
session.close_pty()
@staticmethod
def _session_payload(
session: _TerminalSession,
*,
output: str,
output_truncated: bool,
output_until_seq: Optional[int] = None,
) -> dict[str, Any]:
"""生成工具返回的结构化会话状态。"""
return {
"session_id": session.session_id,
"command": session.command,
"cwd": session.cwd,
"pid": session.pid,
"status": session.status,
"exit_code": session.exit_code,
"use_pty": session.use_pty,
"last_seq": session.next_seq - 1,
"output_until_seq": (
session.next_seq - 1 if output_until_seq is None else output_until_seq
),
"retained_from_seq": session.retained_from_seq,
"output_truncated": output_truncated,
"output": output,
"error": session.error,
}
terminal_session_manager = _TerminalSessionManager()

View File

@@ -823,7 +823,7 @@ class SkillsChain(ChainBase):
)
if search_row:
buttons.append(search_row)
for index, _skill in enumerate(page_items, start=1):
for index, _skill in enumerate(items, start=1):
buttons.append(
[
{

View File

@@ -1,6 +1,6 @@
---
name: database-operation
version: 1
version: 2
description: >-
Use this skill when you need to execute SQL against the MoviePilot database.
This skill guides you through connecting to the database and executing SQL statements.
@@ -20,7 +20,7 @@ This skill guides you through executing SQL against the MoviePilot database. Bot
## Prerequisites
You need the following tools:
- `execute_command` - Execute shell commands to run database queries
- `execute_command` - Execute shell commands to run database queries. Use `action=run` when you need the query result immediately.
## Getting Database Connection Info
@@ -38,7 +38,7 @@ The system prompt `<system_info>` section already contains all the database conn
Extract the database file path from `<system_info>` (the path inside the parentheses after `SQLite`).
Use `execute_command` to run queries:
Use `execute_command` with `action=run` to run queries:
```bash
sqlite3 -header -column <DB_PATH> "YOUR SQL QUERY HERE;"
@@ -66,7 +66,7 @@ sqlite3 <DB_PATH> ".schema tablename"
Extract the connection parameters from `<system_info>` (parse `user:password@host:port/database` from the parentheses after `PostgreSQL`).
Use `execute_command` to run queries via `psql`:
Use `execute_command` with `action=run` to run queries via `psql`:
```bash
PGPASSWORD=<password> psql -h <host> -p <port> -U <user> -d <database> -c "YOUR SQL QUERY HERE;"

View File

@@ -1,4 +1,5 @@
import asyncio
import json
import os
import re
import shlex
@@ -23,15 +24,18 @@ def _python_command(code: str) -> str:
class TestExecuteCommandTool(unittest.TestCase):
def _temp_file_path_from_result(self, result: str) -> str:
"""从工具返回文本中提取完整输出临时文件路径。"""
match = re.search(r"临时文件: (.+)", result)
self.assertIsNotNone(match)
return match.group(1).strip()
def _run_command(self, command: str, timeout: int = 60) -> str:
    """Run ``command`` through the tool in one-shot mode and return its result text.

    ``action="run"`` is passed explicitly because the tool's default action
    starts a managed background session; the assertions in this class are
    written against one-shot results.

    Args:
        command: Shell command line to execute.
        timeout: Timeout value forwarded to the tool.

    Returns:
        The string produced by ``ExecuteCommandTool.run``.
    """
    tool = ExecuteCommandTool(session_id="session-1", user_id="10001")
    # Only the explicit one-shot call is kept: an action-less return placed
    # before it would make this line unreachable.
    return asyncio.run(tool.run(action="run", command=command, timeout=timeout))
def test_large_output_is_truncated_before_returning_to_agent(self):
"""大输出一次性命令只把预览返回给 Agent并把完整内容写到临时文件。"""
command = _python_command(
"import sys; sys.stdout.write('x' * 200000); sys.stdout.flush()"
)
@@ -52,6 +56,7 @@ class TestExecuteCommandTool(unittest.TestCase):
self.assertGreater(len(file_content), 100000)
def test_timeout_returns_partial_output_promptly(self):
"""一次性命令超时后应及时返回已经读取到的部分输出。"""
command = _python_command(
"import time; print('started', flush=True); time.sleep(5)"
)
@@ -65,6 +70,7 @@ class TestExecuteCommandTool(unittest.TestCase):
self.assertIn("started", result)
def test_timeout_with_large_output_writes_partial_full_log_to_temp_file(self):
"""超时且输出较大时,终止前完整输出应写入临时文件。"""
command = _python_command(
"import sys, time; sys.stdout.write('x' * 20000); sys.stdout.flush(); time.sleep(5)"
)
@@ -83,6 +89,7 @@ class TestExecuteCommandTool(unittest.TestCase):
self.assertGreaterEqual(file_content.count("x"), 20000)
def test_timeout_is_capped(self):
"""一次性执行的 timeout 参数超过上限时应自动限幅。"""
command = _python_command("print('ok')")
result = self._run_command(command, timeout=9999)
@@ -90,6 +97,134 @@ class TestExecuteCommandTool(unittest.TestCase):
self.assertIn("timeout 参数超过上限", result)
self.assertIn("ok", result)
def test_forbidden_command_is_rejected(self):
    """An obviously dangerous command should be rejected before it reaches the shell."""
    dangerous_command = "echo ok && rm -rf /"
    payload = json.loads(self._run_command(dangerous_command))

    # The rejection is reported as an error payload carrying the refusal text.
    self.assertEqual(payload["status"], "error")
    self.assertIn("禁止使用", payload["error"])
class TestExecuteCommandSessionTool(unittest.IsolatedAsyncioTestCase):
    """Exercise the session actions of ``execute_command``: start, wait, write and kill."""

    async def asyncSetUp(self):
        """Create the tool under test and the list of session IDs to clean up."""
        self.tool = ExecuteCommandTool(session_id="session-1", user_id="10001")
        self._created_sessions: list[str] = []

    async def asyncTearDown(self):
        """Kill every session recorded during the test so none lingers into later cases."""
        for session_id in self._created_sessions:
            await self.tool.run(action="kill", session_id=session_id)

    @staticmethod
    def _loads(result: str) -> dict:
        """Parse the JSON string returned by ``execute_command``."""
        return json.loads(result)

    async def _start(self, command: str, *, use_pty: bool = False) -> dict:
        """Start a background session with ``action="start"`` and record its ID for teardown."""
        payload = self._loads(
            await self.tool.run(action="start", command=command, use_pty=use_pty)
        )
        session_id = payload.get("session_id")
        if session_id:
            self._created_sessions.append(session_id)
        return payload

    async def test_default_action_starts_session_promptly(self):
        """Omitting ``action`` should start a background session and return its ID quickly."""
        command = _python_command(
            "import time; print('ready', flush=True); time.sleep(1); print('done', flush=True)"
        )

        started_at = time.monotonic()
        start_payload = self._loads(await self.tool.run(command=command, use_pty=False))
        duration = time.monotonic() - started_at

        # Register the session for teardown without indexing the payload, so a
        # missing key is reported by the assertion below rather than a KeyError.
        session_id = start_payload.get("session_id")
        if session_id:
            self._created_sessions.append(session_id)

        self.assertIn("session_id", start_payload)
        self.assertLess(duration, 0.8)
        self.assertEqual(start_payload["status"], "running")

    async def test_read_and_wait_get_incremental_output(self):
        """Successive ``wait`` calls should return output incrementally via ``since_seq``."""
        command = _python_command(
            "import time; print('ready', flush=True); time.sleep(1); print('done', flush=True)"
        )
        start_payload = await self._start(command)

        # A short wait should see the first line while the process is still sleeping.
        wait_payload = self._loads(
            await self.tool.run(
                action="wait",
                session_id=start_payload["session_id"],
                timeout_ms=200,
                since_seq=0,
            )
        )
        self.assertEqual(wait_payload["status"], "running")
        self.assertIn("ready", wait_payload["output"])

        # Resume from where the previous read stopped and wait for the process to exit.
        final_payload = self._loads(
            await self.tool.run(
                action="wait",
                session_id=start_payload["session_id"],
                timeout_ms=3000,
                since_seq=wait_payload["output_until_seq"],
            )
        )
        self.assertEqual(final_payload["status"], "exited")
        self.assertEqual(final_payload["exit_code"], 0)
        self.assertIn("done", final_payload["output"])

    async def test_write_sends_input_to_running_process(self):
        """The ``write`` action should deliver input to the background process's stdin."""
        command = _python_command(
            "line = input('name: '); print('hello ' + line, flush=True)"
        )
        start_payload = await self._start(command)

        await self.tool.run(
            action="write",
            session_id=start_payload["session_id"],
            input_text="moviepilot\n",
        )
        wait_payload = self._loads(
            await self.tool.run(
                action="wait",
                session_id=start_payload["session_id"],
                timeout_ms=3000,
                since_seq=0,
            )
        )

        self.assertEqual(wait_payload["status"], "exited")
        self.assertIn("hello moviepilot", wait_payload["output"])

    async def test_kill_stops_long_running_process(self):
        """The ``kill`` action should terminate a long-running background session."""
        command = _python_command(
            "import time; print('started', flush=True); time.sleep(20)"
        )
        start_payload = await self._start(command)

        read_payload = self._loads(
            await self.tool.run(
                action="wait",
                session_id=start_payload["session_id"],
                timeout_ms=500,
                since_seq=0,
            )
        )
        kill_payload = self._loads(
            await self.tool.run(action="kill", session_id=start_payload["session_id"])
        )

        self.assertIn("started", read_payload["output"])
        self.assertIn(kill_payload["status"], {"killed", "exited"})
# Run the test cases with unittest's runner when this file is executed as a script.
if __name__ == "__main__":
    unittest.main()