mirror of
https://github.com/jxxghp/MoviePilot.git
synced 2026-06-01 13:40:54 +08:00
feat: add managed agent command sessions
This commit is contained in:
@@ -307,7 +307,10 @@ class StreamingHandler:
|
||||
or tool_kwargs.get("path"),
|
||||
)
|
||||
if tool_name == "execute_command":
|
||||
return "command", tool_kwargs.get("command")
|
||||
return (
|
||||
"command",
|
||||
tool_kwargs.get("command") or tool_kwargs.get("session_id"),
|
||||
)
|
||||
if tool_name == "ask_user_choice":
|
||||
return "interaction", tool_kwargs.get("message")
|
||||
if tool_name.startswith("search_") or tool_name in {"get_search_results"}:
|
||||
|
||||
@@ -56,6 +56,7 @@ Tool Calling Strategy:
|
||||
- Reuse the latest torrent search cache for `get_search_results` and `add_download` instead of re-running the same search unnecessarily.
|
||||
- Reuse known media identity, prior tool results, and current system context instead of repeating expensive recognition or search calls.
|
||||
- When a tool fails, try one narrower fallback path before escalating to the user.
|
||||
- Use `execute_command` for shell work. Its default `action=start` starts a managed background session and returns `session_id`, `status`, `last_seq`, and `output_until_seq`; call the same tool again with `action=read`, `action=wait`, `action=write`, or `action=kill` to poll output, wait in short segments, send stdin, or stop the process.
|
||||
|
||||
Media Management Rules:
|
||||
1. Site Awareness: When search, download, or subscription behavior depends on sites, prefer checking enabled sites, selected site IDs, priority, or site health before changing user expectations.
|
||||
|
||||
@@ -1,16 +1,25 @@
|
||||
"""执行Shell命令工具"""
|
||||
"""执行 Shell 命令工具。"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
import os
|
||||
import signal
|
||||
import subprocess
|
||||
from dataclasses import dataclass, field
|
||||
from tempfile import NamedTemporaryFile
|
||||
from typing import Any, Optional, TextIO, Type
|
||||
from typing import Any, Literal, Optional, TextIO, Type
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from app.agent.tools.base import MoviePilotTool
|
||||
from app.agent.tools.impl.terminal_session import (
|
||||
TERMINAL_DEFAULT_READ_BYTES,
|
||||
TERMINAL_MAX_READ_BYTES,
|
||||
TERMINAL_WAIT_DEFAULT_MS,
|
||||
terminal_session_manager,
|
||||
)
|
||||
from app.log import logger
|
||||
|
||||
|
||||
@@ -20,6 +29,14 @@ MAX_OUTPUT_PREVIEW_BYTES = 10 * 1024
|
||||
READ_CHUNK_SIZE = 4096
|
||||
KILL_GRACE_SECONDS = 3
|
||||
COMMAND_CONCURRENCY_LIMIT = 2
|
||||
COMMAND_FORBIDDEN_KEYWORDS = (
|
||||
"rm -rf /",
|
||||
":(){ :|:& };:",
|
||||
"dd if=/dev/zero",
|
||||
"mkfs",
|
||||
"reboot",
|
||||
"shutdown",
|
||||
)
|
||||
|
||||
_command_semaphore = asyncio.Semaphore(COMMAND_CONCURRENCY_LIMIT)
|
||||
|
||||
@@ -38,11 +55,13 @@ class _CommandOutput:
|
||||
|
||||
@staticmethod
|
||||
def _clip_text_to_bytes(text: str, byte_limit: int) -> str:
|
||||
"""按 UTF-8 字节数截断文本,避免截断后出现非法字符。"""
|
||||
if byte_limit <= 0:
|
||||
return ""
|
||||
return text.encode("utf-8")[:byte_limit].decode("utf-8", errors="ignore")
|
||||
|
||||
def _write_chunk(self, stream_name: str, text: str) -> None:
|
||||
"""把输出分片按 stdout/stderr 分段写入临时文件。"""
|
||||
if not self.temp_file_handle or not text:
|
||||
return
|
||||
|
||||
@@ -56,6 +75,7 @@ class _CommandOutput:
|
||||
self.temp_file_handle.write(text)
|
||||
|
||||
def _ensure_temp_file(self) -> None:
|
||||
"""首次超出预览上限时创建临时文件并补写已缓存预览。"""
|
||||
if self.temp_file_handle:
|
||||
return
|
||||
|
||||
@@ -72,6 +92,7 @@ class _CommandOutput:
|
||||
self._write_chunk(stream_name, chunk)
|
||||
|
||||
def close(self) -> None:
|
||||
"""关闭临时文件句柄,确保输出落盘。"""
|
||||
if not self.temp_file_handle:
|
||||
return
|
||||
self.temp_file_handle.flush()
|
||||
@@ -79,6 +100,7 @@ class _CommandOutput:
|
||||
self.temp_file_handle = None
|
||||
|
||||
def append(self, stream_name: str, text: str) -> None:
|
||||
"""追加一段输出,超出预览上限后只保留完整日志文件。"""
|
||||
if not text:
|
||||
return
|
||||
|
||||
@@ -104,47 +126,141 @@ class _CommandOutput:
|
||||
|
||||
@property
|
||||
def stdout(self) -> str:
|
||||
"""返回当前保留的 stdout 预览。"""
|
||||
return "".join(
|
||||
text for stream_name, text in self.preview_entries if stream_name == "stdout"
|
||||
).strip()
|
||||
|
||||
@property
|
||||
def stderr(self) -> str:
|
||||
"""返回当前保留的 stderr 预览。"""
|
||||
return "".join(
|
||||
text for stream_name, text in self.preview_entries if stream_name == "stderr"
|
||||
).strip()
|
||||
|
||||
|
||||
class ExecuteCommandInput(BaseModel):
|
||||
"""执行Shell命令工具的输入参数模型"""
|
||||
"""执行 Shell 命令工具的输入参数模型。"""
|
||||
|
||||
explanation: str = Field(
|
||||
..., description="Clear explanation of why this command is being executed"
|
||||
..., description="Clear explanation of why this command action is needed"
|
||||
)
|
||||
action: Optional[Literal["start", "read", "wait", "write", "kill", "run"]] = Field(
|
||||
"start",
|
||||
description=(
|
||||
"Command action. start launches a managed background session and returns "
|
||||
"session_id. read/wait/write/kill operate on that session. run executes "
|
||||
"once and waits until completion or timeout."
|
||||
),
|
||||
)
|
||||
command: Optional[str] = Field(
|
||||
None,
|
||||
description="Shell command. Required for action=start or action=run.",
|
||||
)
|
||||
session_id: Optional[str] = Field(
|
||||
None,
|
||||
description="Command session id returned by action=start.",
|
||||
)
|
||||
input_text: Optional[str] = Field(
|
||||
None,
|
||||
description="Text to send to stdin for action=write. Use \\u0003 for Ctrl+C.",
|
||||
)
|
||||
signal_name: Optional[str] = Field(
|
||||
"TERM",
|
||||
description="Signal for action=kill, such as TERM, INT, KILL, or 15.",
|
||||
)
|
||||
cwd: Optional[str] = Field(
|
||||
None,
|
||||
description="Working directory for action=start or action=run.",
|
||||
)
|
||||
env: Optional[dict[str, Any]] = Field(
|
||||
None,
|
||||
description="Additional environment variables for action=start.",
|
||||
)
|
||||
use_pty: Optional[bool] = Field(
|
||||
True,
|
||||
description="Use a pseudo terminal for action=start when supported.",
|
||||
)
|
||||
since_seq: Optional[int] = Field(
|
||||
None,
|
||||
description="For action=read/wait, return output chunks after this seq.",
|
||||
)
|
||||
max_bytes: Optional[int] = Field(
|
||||
TERMINAL_DEFAULT_READ_BYTES,
|
||||
description="For action=read/wait, maximum output bytes to return.",
|
||||
)
|
||||
timeout_ms: Optional[int] = Field(
|
||||
TERMINAL_WAIT_DEFAULT_MS,
|
||||
description="For action=wait, maximum segmented wait time in milliseconds.",
|
||||
)
|
||||
command: str = Field(..., description="The shell command to execute")
|
||||
timeout: Optional[int] = Field(
|
||||
60, description="Max execution time in seconds (default: 60)"
|
||||
60,
|
||||
description="For action=run, max execution time in seconds.",
|
||||
)
|
||||
|
||||
|
||||
class ExecuteCommandTool(MoviePilotTool):
|
||||
"""统一执行和管理 Shell 命令的 Agent 工具。"""
|
||||
|
||||
name: str = "execute_command"
|
||||
description: str = (
|
||||
"Safely execute shell commands on the server. Useful for system "
|
||||
"maintenance, checking status, or running custom scripts. Includes "
|
||||
"timeout, concurrency, and output preview limits."
|
||||
"Start and manage shell commands on the server. By default action=start "
|
||||
"launches a background session and immediately returns session_id/status/"
|
||||
"last_seq/output_until_seq. Call the same tool with action=read, wait, "
|
||||
"write, or kill to poll output, wait in short segments, send stdin, or "
|
||||
"terminate it. Use action=run only when a one-shot bounded command result "
|
||||
"is preferred."
|
||||
)
|
||||
args_schema: Type[BaseModel] = ExecuteCommandInput
|
||||
require_admin: bool = True
|
||||
result_max_chars = TERMINAL_MAX_READ_BYTES + 4096
|
||||
|
||||
def get_tool_message(self, **kwargs) -> Optional[str]:
|
||||
"""根据命令生成友好的提示消息"""
|
||||
command = kwargs.get("command", "")
|
||||
return f"执行系统命令: {command}"
|
||||
"""根据命令动作生成友好的提示消息。"""
|
||||
action = kwargs.get("action") or "start"
|
||||
command = kwargs.get("command")
|
||||
session_id = kwargs.get("session_id")
|
||||
if action in {"start", "run"}:
|
||||
return f"执行系统命令: {command or ''}"
|
||||
if action == "read":
|
||||
return f"读取命令输出: {session_id or ''}"
|
||||
if action == "wait":
|
||||
return f"等待命令会话: {session_id or ''}"
|
||||
if action == "write":
|
||||
return f"写入命令输入: {session_id or ''}"
|
||||
if action == "kill":
|
||||
return f"终止命令会话: {session_id or ''}"
|
||||
return f"处理命令会话: {session_id or command or ''}"
|
||||
|
||||
@staticmethod
|
||||
def _dump(payload: dict[str, Any]) -> str:
|
||||
"""把结构化命令会话结果转换为 Agent 容易解析的 JSON 字符串。"""
|
||||
return json.dumps(payload, ensure_ascii=False, indent=2)
|
||||
|
||||
@staticmethod
|
||||
def _require_session_id(session_id: Optional[str]) -> str:
|
||||
"""校验会话型 action 必须传入 session_id。"""
|
||||
if not session_id:
|
||||
raise ValueError("action 需要传入 session_id")
|
||||
return session_id
|
||||
|
||||
@staticmethod
|
||||
def _require_command(command: Optional[str]) -> str:
|
||||
"""校验启动型 action 必须传入 command。"""
|
||||
if not command or not command.strip():
|
||||
raise ValueError("action 需要传入 command")
|
||||
return command
|
||||
|
||||
@staticmethod
|
||||
def _validate_command(command: str) -> None:
|
||||
"""复用旧工具的基础危险命令过滤,避免明显破坏性命令进入 shell。"""
|
||||
for keyword in COMMAND_FORBIDDEN_KEYWORDS:
|
||||
if keyword in command:
|
||||
raise ValueError(f"命令包含禁止使用的关键字 '{keyword}'")
|
||||
|
||||
@staticmethod
|
||||
def _normalize_timeout(timeout: Optional[int]) -> tuple[int, Optional[str]]:
|
||||
"""限制命令最长运行时间,避免 Agent 传入过大的 timeout。"""
|
||||
"""限制一次性执行命令的最长运行时间。"""
|
||||
try:
|
||||
normalized = int(timeout or DEFAULT_TIMEOUT_SECONDS)
|
||||
except (TypeError, ValueError):
|
||||
@@ -161,7 +277,7 @@ class ExecuteCommandTool(MoviePilotTool):
|
||||
|
||||
@staticmethod
|
||||
def _subprocess_kwargs() -> dict:
|
||||
"""为子进程创建独立进程组,便于超时场景清理整棵子进程。"""
|
||||
"""为一次性命令创建独立进程组,便于超时清理整棵子进程。"""
|
||||
kwargs = {
|
||||
"stdin": subprocess.DEVNULL,
|
||||
"stdout": asyncio.subprocess.PIPE,
|
||||
@@ -179,17 +295,16 @@ class ExecuteCommandTool(MoviePilotTool):
|
||||
stream_name: str,
|
||||
output: _CommandOutput,
|
||||
) -> None:
|
||||
"""按块读取输出,始终只把前 10KB 保留在返回结果中。"""
|
||||
"""按块读取一次性命令输出,只把前 10KB 保留在返回结果中。"""
|
||||
while True:
|
||||
chunk = await stream.read(READ_CHUNK_SIZE)
|
||||
if not chunk:
|
||||
break
|
||||
|
||||
output.append(stream_name, chunk.decode("utf-8", errors="replace"))
|
||||
|
||||
@staticmethod
|
||||
def _terminate_process(process: Any, sig: int):
|
||||
"""向进程组发送终止信号;不支持进程组的平台回退为单进程终止。"""
|
||||
def _terminate_process(process: Any, sig: int) -> None:
|
||||
"""向进程组发送终止信号,不支持进程组的平台回退为单进程终止。"""
|
||||
try:
|
||||
if os.name == "posix":
|
||||
os.killpg(process.pid, sig)
|
||||
@@ -230,7 +345,7 @@ class ExecuteCommandTool(MoviePilotTool):
|
||||
|
||||
@staticmethod
|
||||
async def _finish_reader_tasks(reader_tasks: list[asyncio.Task]) -> None:
|
||||
"""等待输出读取任务退出,异常只记录不影响工具返回。"""
|
||||
"""等待一次性命令输出读取任务退出,异常只记录不影响工具返回。"""
|
||||
if not reader_tasks:
|
||||
return
|
||||
done, pending = await asyncio.wait(reader_tasks, timeout=1)
|
||||
@@ -244,7 +359,7 @@ class ExecuteCommandTool(MoviePilotTool):
|
||||
logger.debug("命令输出读取任务异常: %s", result)
|
||||
|
||||
@staticmethod
|
||||
def _format_result(
|
||||
def _format_run_result(
|
||||
*,
|
||||
exit_code: Optional[int],
|
||||
output: _CommandOutput,
|
||||
@@ -252,6 +367,7 @@ class ExecuteCommandTool(MoviePilotTool):
|
||||
timed_out: bool,
|
||||
timeout_note: Optional[str],
|
||||
) -> str:
|
||||
"""格式化 action=run 的兼容文本结果。"""
|
||||
if timed_out:
|
||||
result = f"命令执行超时 (限制: {timeout}秒,已终止进程)"
|
||||
else:
|
||||
@@ -260,11 +376,7 @@ class ExecuteCommandTool(MoviePilotTool):
|
||||
if timeout_note:
|
||||
result += f"\n\n提示:\n{timeout_note}"
|
||||
if output.temp_file_path:
|
||||
file_note = (
|
||||
"截至命令终止前的完整输出"
|
||||
if timed_out
|
||||
else "完整输出"
|
||||
)
|
||||
file_note = "截至命令终止前的完整输出" if timed_out else "完整输出"
|
||||
result += (
|
||||
"\n\n提示:\n"
|
||||
f"命令输出超过 10KB,仅返回前 {MAX_OUTPUT_PREVIEW_BYTES} 字节内容。\n"
|
||||
@@ -281,65 +393,129 @@ class ExecuteCommandTool(MoviePilotTool):
|
||||
result += "\n\n(无输出内容)"
|
||||
return result
|
||||
|
||||
async def run(self, command: str, timeout: Optional[int] = 60, **kwargs) -> str:
|
||||
logger.info(
|
||||
f"执行工具: {self.name}, 参数: command={command}, timeout={timeout}"
|
||||
)
|
||||
|
||||
# 简单安全过滤
|
||||
forbidden_keywords = [
|
||||
"rm -rf /",
|
||||
":(){ :|:& };:",
|
||||
"dd if=/dev/zero",
|
||||
"mkfs",
|
||||
"reboot",
|
||||
"shutdown",
|
||||
]
|
||||
for keyword in forbidden_keywords:
|
||||
if keyword in command:
|
||||
return f"错误:命令包含禁止使用的关键字 '{keyword}'"
|
||||
|
||||
async def _run_once(
|
||||
self,
|
||||
*,
|
||||
command: str,
|
||||
timeout: Optional[int],
|
||||
cwd: Optional[str] = None,
|
||||
) -> str:
|
||||
"""按旧模式一次性执行命令,等待完成或超时后返回文本结果。"""
|
||||
self._validate_command(command)
|
||||
normalized_timeout, timeout_note = self._normalize_timeout(timeout)
|
||||
|
||||
async with _command_semaphore:
|
||||
process = await asyncio.create_subprocess_shell(
|
||||
command,
|
||||
cwd=cwd,
|
||||
**self._subprocess_kwargs(),
|
||||
)
|
||||
output = _CommandOutput(preview_limit_bytes=MAX_OUTPUT_PREVIEW_BYTES)
|
||||
wait_task = asyncio.create_task(process.wait())
|
||||
reader_tasks = [
|
||||
asyncio.create_task(self._read_stream(process.stdout, "stdout", output)),
|
||||
asyncio.create_task(self._read_stream(process.stderr, "stderr", output)),
|
||||
]
|
||||
|
||||
timed_out = False
|
||||
try:
|
||||
await asyncio.wait_for(
|
||||
asyncio.shield(wait_task), timeout=normalized_timeout
|
||||
)
|
||||
except asyncio.TimeoutError:
|
||||
timed_out = True
|
||||
await self._cleanup_process(process, wait_task)
|
||||
|
||||
try:
|
||||
await self._finish_reader_tasks(reader_tasks)
|
||||
finally:
|
||||
output.close()
|
||||
|
||||
return self._format_run_result(
|
||||
exit_code=process.returncode,
|
||||
output=output,
|
||||
timeout=normalized_timeout,
|
||||
timed_out=timed_out,
|
||||
timeout_note=timeout_note,
|
||||
)
|
||||
|
||||
async def run(
|
||||
self,
|
||||
action: Optional[str] = "start",
|
||||
command: Optional[str] = None,
|
||||
session_id: Optional[str] = None,
|
||||
input_text: Optional[str] = None,
|
||||
signal_name: Optional[str] = "TERM",
|
||||
cwd: Optional[str] = None,
|
||||
env: Optional[dict[str, Any]] = None,
|
||||
use_pty: Optional[bool] = True,
|
||||
since_seq: Optional[int] = None,
|
||||
max_bytes: Optional[int] = TERMINAL_DEFAULT_READ_BYTES,
|
||||
timeout_ms: Optional[int] = TERMINAL_WAIT_DEFAULT_MS,
|
||||
timeout: Optional[int] = 60,
|
||||
**kwargs,
|
||||
) -> str:
|
||||
"""执行命令动作:默认后台启动,也支持读取、等待、写入、终止和一次性执行。"""
|
||||
normalized_action = (action or "start").strip().lower()
|
||||
logger.info(
|
||||
"执行工具: %s, action=%s, command=%s, session_id=%s",
|
||||
self.name,
|
||||
normalized_action,
|
||||
command,
|
||||
session_id,
|
||||
)
|
||||
|
||||
try:
|
||||
async with _command_semaphore:
|
||||
# 命令输出可能非常大,必须边读边落盘,不能使用 communicate() 一次性收集。
|
||||
process = await asyncio.create_subprocess_shell(
|
||||
command, **self._subprocess_kwargs()
|
||||
if normalized_action == "start":
|
||||
start_command = self._require_command(command)
|
||||
self._validate_command(start_command)
|
||||
payload = await terminal_session_manager.start(
|
||||
command=start_command,
|
||||
cwd=cwd,
|
||||
env=env,
|
||||
use_pty=use_pty,
|
||||
)
|
||||
output = _CommandOutput(preview_limit_bytes=MAX_OUTPUT_PREVIEW_BYTES)
|
||||
wait_task = asyncio.create_task(process.wait())
|
||||
reader_tasks = [
|
||||
asyncio.create_task(
|
||||
self._read_stream(process.stdout, "stdout", output)
|
||||
),
|
||||
asyncio.create_task(
|
||||
self._read_stream(process.stderr, "stderr", output)
|
||||
),
|
||||
]
|
||||
return self._dump(payload)
|
||||
|
||||
timed_out = False
|
||||
try:
|
||||
await asyncio.wait_for(
|
||||
asyncio.shield(wait_task), timeout=normalized_timeout
|
||||
)
|
||||
except asyncio.TimeoutError:
|
||||
timed_out = True
|
||||
await self._cleanup_process(process, wait_task)
|
||||
if normalized_action == "read":
|
||||
payload = await terminal_session_manager.read(
|
||||
session_id=self._require_session_id(session_id),
|
||||
since_seq=since_seq,
|
||||
max_bytes=max_bytes,
|
||||
)
|
||||
return self._dump(payload)
|
||||
|
||||
try:
|
||||
await self._finish_reader_tasks(reader_tasks)
|
||||
finally:
|
||||
output.close()
|
||||
if normalized_action == "wait":
|
||||
payload = await terminal_session_manager.wait(
|
||||
session_id=self._require_session_id(session_id),
|
||||
timeout_ms=timeout_ms,
|
||||
since_seq=since_seq,
|
||||
max_bytes=max_bytes,
|
||||
)
|
||||
return self._dump(payload)
|
||||
|
||||
return self._format_result(
|
||||
exit_code=process.returncode,
|
||||
output=output,
|
||||
timeout=normalized_timeout,
|
||||
timed_out=timed_out,
|
||||
timeout_note=timeout_note,
|
||||
if normalized_action == "write":
|
||||
payload = await terminal_session_manager.write(
|
||||
session_id=self._require_session_id(session_id),
|
||||
input_text=input_text or "",
|
||||
)
|
||||
return self._dump(payload)
|
||||
|
||||
if normalized_action == "kill":
|
||||
payload = await terminal_session_manager.kill(
|
||||
session_id=self._require_session_id(session_id),
|
||||
sig=signal_name,
|
||||
)
|
||||
return self._dump(payload)
|
||||
|
||||
if normalized_action == "run":
|
||||
return await self._run_once(
|
||||
command=self._require_command(command),
|
||||
timeout=timeout,
|
||||
cwd=cwd,
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"执行命令失败: {e}", exc_info=True)
|
||||
return f"执行命令时发生错误: {str(e)}"
|
||||
raise ValueError(f"不支持的 action: {action}")
|
||||
except Exception as err:
|
||||
logger.error("执行命令 action 失败: %s", err, exc_info=True)
|
||||
return self._dump({"error": str(err), "status": "error", "action": normalized_action})
|
||||
|
||||
628
app/agent/tools/impl/terminal_session.py
Normal file
628
app/agent/tools/impl/terminal_session.py
Normal file
@@ -0,0 +1,628 @@
|
||||
"""Agent 终端会话管理器。"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import errno
|
||||
import os
|
||||
import signal
|
||||
import subprocess
|
||||
import time
|
||||
import uuid
|
||||
from dataclasses import dataclass, field
|
||||
from pathlib import Path
|
||||
from typing import Any, Optional
|
||||
|
||||
from app.core.config import settings
|
||||
from app.log import logger
|
||||
|
||||
if os.name == "posix":
|
||||
import fcntl as _fcntl
|
||||
import pty as _pty
|
||||
else:
|
||||
_fcntl = None
|
||||
_pty = None
|
||||
|
||||
|
||||
TERMINAL_CONCURRENCY_LIMIT = 4
|
||||
TERMINAL_RETENTION_SECONDS = 30 * 60
|
||||
TERMINAL_MAX_RETAINED_BYTES = 1024 * 1024
|
||||
TERMINAL_DEFAULT_READ_BYTES = 10 * 1024
|
||||
TERMINAL_MAX_READ_BYTES = 64 * 1024
|
||||
TERMINAL_READ_CHUNK_SIZE = 4096
|
||||
TERMINAL_PTY_POLL_INTERVAL = 0.05
|
||||
TERMINAL_WAIT_DEFAULT_MS = 1000
|
||||
TERMINAL_WAIT_MAX_MS = 60 * 1000
|
||||
TERMINAL_KILL_GRACE_SECONDS = 3
|
||||
TERMINAL_FORBIDDEN_KEYWORDS = (
|
||||
"rm -rf /",
|
||||
":(){ :|:& };:",
|
||||
"dd if=/dev/zero",
|
||||
"mkfs",
|
||||
"reboot",
|
||||
"shutdown",
|
||||
)
|
||||
|
||||
|
||||
@dataclass
|
||||
class _TerminalChunk:
|
||||
"""记录终端输出分片,供增量读取时按 seq 过滤。"""
|
||||
|
||||
seq: int
|
||||
stream: str
|
||||
text: str
|
||||
byte_size: int
|
||||
created_at: float
|
||||
|
||||
|
||||
@dataclass
|
||||
class _TerminalSession:
|
||||
"""保存一个后台命令会话的进程、输出和状态。"""
|
||||
|
||||
session_id: str
|
||||
command: str
|
||||
cwd: str
|
||||
pid: int
|
||||
use_pty: bool
|
||||
created_at: float = field(default_factory=time.time)
|
||||
updated_at: float = field(default_factory=time.time)
|
||||
status: str = "running"
|
||||
exit_code: Optional[int] = None
|
||||
process: Optional[asyncio.subprocess.Process] = None
|
||||
master_fd: Optional[int] = None
|
||||
chunks: list[_TerminalChunk] = field(default_factory=list)
|
||||
next_seq: int = 1
|
||||
retained_from_seq: int = 1
|
||||
retained_bytes: int = 0
|
||||
kill_requested: bool = False
|
||||
error: Optional[str] = None
|
||||
reader_tasks: list[asyncio.Task] = field(default_factory=list)
|
||||
wait_task: Optional[asyncio.Task] = None
|
||||
|
||||
def append_output(self, stream: str, data: bytes) -> None:
|
||||
"""追加输出并按容量上限丢弃最旧分片,避免长任务撑爆内存。"""
|
||||
if not data:
|
||||
return
|
||||
|
||||
text = data.decode("utf-8", errors="replace")
|
||||
chunk = _TerminalChunk(
|
||||
seq=self.next_seq,
|
||||
stream=stream,
|
||||
text=text,
|
||||
byte_size=len(data),
|
||||
created_at=time.time(),
|
||||
)
|
||||
self.next_seq += 1
|
||||
self.chunks.append(chunk)
|
||||
self.retained_bytes += chunk.byte_size
|
||||
self.updated_at = chunk.created_at
|
||||
self._trim_output()
|
||||
|
||||
def _trim_output(self) -> None:
|
||||
"""移除超出保留上限的旧输出分片。"""
|
||||
while self.retained_bytes > TERMINAL_MAX_RETAINED_BYTES and self.chunks:
|
||||
removed = self.chunks.pop(0)
|
||||
self.retained_bytes -= removed.byte_size
|
||||
self.retained_from_seq = removed.seq + 1
|
||||
|
||||
def mark_finished(self, exit_code: Optional[int]) -> None:
|
||||
"""标记进程已经结束,并记录退出码。"""
|
||||
self.exit_code = exit_code
|
||||
self.status = "killed" if self.kill_requested else "exited"
|
||||
self.updated_at = time.time()
|
||||
|
||||
def mark_error(self, message: str) -> None:
|
||||
"""标记会话异常,保留错误信息供后续读取。"""
|
||||
self.error = message
|
||||
self.status = "error"
|
||||
self.updated_at = time.time()
|
||||
|
||||
def close_pty(self) -> None:
|
||||
"""关闭父进程持有的 PTY master fd。"""
|
||||
if self.master_fd is None:
|
||||
return
|
||||
try:
|
||||
os.close(self.master_fd)
|
||||
except OSError:
|
||||
pass
|
||||
self.master_fd = None
|
||||
|
||||
|
||||
class _TerminalSessionManager:
|
||||
"""管理 Agent 后台终端会话的生命周期。"""
|
||||
|
||||
def __init__(self) -> None:
|
||||
"""初始化会话表和并发保护锁。"""
|
||||
self._sessions: dict[str, _TerminalSession] = {}
|
||||
self._lock = asyncio.Lock()
|
||||
|
||||
@staticmethod
|
||||
def _normalize_bool(value: Any, default: bool = True) -> bool:
|
||||
"""兼容 LLM 或 HTTP 传入的 bool/string/int 布尔值。"""
|
||||
if value is None:
|
||||
return default
|
||||
if isinstance(value, bool):
|
||||
return value
|
||||
if isinstance(value, str):
|
||||
return value.strip().lower() not in {"false", "0", "no", "off"}
|
||||
return bool(value)
|
||||
|
||||
@staticmethod
|
||||
def _normalize_cwd(cwd: Optional[str]) -> str:
|
||||
"""解析工作目录,未传入时默认使用 MoviePilot 项目根目录。"""
|
||||
if not cwd:
|
||||
return str(settings.ROOT_PATH)
|
||||
path = Path(cwd).expanduser()
|
||||
if not path.is_absolute():
|
||||
path = (settings.ROOT_PATH / path).resolve()
|
||||
else:
|
||||
path = path.resolve()
|
||||
if not path.exists():
|
||||
raise FileNotFoundError(f"工作目录不存在: {path}")
|
||||
if not path.is_dir():
|
||||
raise NotADirectoryError(f"工作目录不是目录: {path}")
|
||||
return str(path)
|
||||
|
||||
@staticmethod
|
||||
def _build_env(env: Optional[dict[str, Any]]) -> dict[str, str]:
|
||||
"""合并环境变量,并把值稳定转换为字符串。"""
|
||||
merged_env = os.environ.copy()
|
||||
if not env:
|
||||
return merged_env
|
||||
for key, value in env.items():
|
||||
if value is None:
|
||||
continue
|
||||
merged_env[str(key)] = str(value)
|
||||
return merged_env
|
||||
|
||||
@staticmethod
|
||||
def _validate_command(command: str) -> None:
|
||||
"""拒绝明显危险或空白命令。"""
|
||||
if not command or not command.strip():
|
||||
raise ValueError("命令不能为空")
|
||||
for keyword in TERMINAL_FORBIDDEN_KEYWORDS:
|
||||
if keyword in command:
|
||||
raise ValueError(f"命令包含禁止使用的关键字 '{keyword}'")
|
||||
|
||||
@staticmethod
|
||||
def _set_nonblocking(fd: int) -> None:
|
||||
"""将 PTY master fd 设置为非阻塞,避免后台读取任务卡住事件循环。"""
|
||||
if _fcntl is None:
|
||||
raise RuntimeError("当前平台不支持 PTY 非阻塞设置")
|
||||
flags = _fcntl.fcntl(fd, _fcntl.F_GETFL)
|
||||
_fcntl.fcntl(fd, _fcntl.F_SETFL, flags | os.O_NONBLOCK)
|
||||
|
||||
@staticmethod
|
||||
def _pipe_subprocess_kwargs() -> dict[str, Any]:
|
||||
"""生成普通管道模式的子进程参数。"""
|
||||
kwargs: dict[str, Any] = {
|
||||
"stdin": asyncio.subprocess.PIPE,
|
||||
"stdout": asyncio.subprocess.PIPE,
|
||||
"stderr": asyncio.subprocess.PIPE,
|
||||
}
|
||||
if os.name == "posix":
|
||||
kwargs["start_new_session"] = True
|
||||
elif os.name == "nt":
|
||||
kwargs["creationflags"] = subprocess.CREATE_NEW_PROCESS_GROUP
|
||||
return kwargs
|
||||
|
||||
async def start(
|
||||
self,
|
||||
*,
|
||||
command: str,
|
||||
cwd: Optional[str] = None,
|
||||
env: Optional[dict[str, Any]] = None,
|
||||
use_pty: Any = True,
|
||||
) -> dict[str, Any]:
|
||||
"""启动后台命令并立即返回会话 ID。"""
|
||||
self._validate_command(command)
|
||||
normalized_cwd = self._normalize_cwd(cwd)
|
||||
normalized_env = self._build_env(env)
|
||||
should_use_pty = self._normalize_bool(use_pty, default=True) and os.name == "posix"
|
||||
|
||||
async with self._lock:
|
||||
self._cleanup_finished_sessions_locked()
|
||||
if self._active_session_count_locked() >= TERMINAL_CONCURRENCY_LIMIT:
|
||||
raise RuntimeError(
|
||||
f"后台终端会话数已达到上限 {TERMINAL_CONCURRENCY_LIMIT}"
|
||||
)
|
||||
|
||||
session = (
|
||||
await self._start_pty_session(command, normalized_cwd, normalized_env)
|
||||
if should_use_pty
|
||||
else await self._start_pipe_session(command, normalized_cwd, normalized_env)
|
||||
)
|
||||
|
||||
async with self._lock:
|
||||
self._sessions[session.session_id] = session
|
||||
|
||||
logger.info(
|
||||
"启动后台终端会话: session_id=%s, pid=%s, use_pty=%s, command=%s",
|
||||
session.session_id,
|
||||
session.pid,
|
||||
session.use_pty,
|
||||
command,
|
||||
)
|
||||
await asyncio.sleep(0)
|
||||
return self._session_payload(session, output="", output_truncated=False)
|
||||
|
||||
async def _start_pty_session(
|
||||
self, command: str, cwd: str, env: dict[str, str]
|
||||
) -> _TerminalSession:
|
||||
"""通过 PTY fork 启动交互式命令会话。"""
|
||||
if _pty is None:
|
||||
raise RuntimeError("当前平台不支持 PTY 会话")
|
||||
pid, master_fd = _pty.fork()
|
||||
if pid == 0:
|
||||
os.chdir(cwd)
|
||||
os.environ.clear()
|
||||
os.environ.update(env)
|
||||
shell = os.environ.get("SHELL") or "/bin/sh"
|
||||
os.execl(shell, shell, "-lc", command)
|
||||
|
||||
self._set_nonblocking(master_fd)
|
||||
session = _TerminalSession(
|
||||
session_id=f"term_{uuid.uuid4().hex[:12]}",
|
||||
command=command,
|
||||
cwd=cwd,
|
||||
pid=pid,
|
||||
use_pty=True,
|
||||
master_fd=master_fd,
|
||||
)
|
||||
session.reader_tasks.append(asyncio.create_task(self._read_pty(session)))
|
||||
session.wait_task = asyncio.create_task(self._wait_pty_process(session))
|
||||
return session
|
||||
|
||||
async def _start_pipe_session(
|
||||
self, command: str, cwd: str, env: dict[str, str]
|
||||
) -> _TerminalSession:
|
||||
"""通过普通 stdin/stdout/stderr 管道启动命令会话。"""
|
||||
process = await asyncio.create_subprocess_shell(
|
||||
command,
|
||||
cwd=cwd,
|
||||
env=env,
|
||||
**self._pipe_subprocess_kwargs(),
|
||||
)
|
||||
session = _TerminalSession(
|
||||
session_id=f"term_{uuid.uuid4().hex[:12]}",
|
||||
command=command,
|
||||
cwd=cwd,
|
||||
pid=process.pid or 0,
|
||||
use_pty=False,
|
||||
process=process,
|
||||
)
|
||||
if process.stdout:
|
||||
session.reader_tasks.append(
|
||||
asyncio.create_task(self._read_pipe(session, process.stdout, "stdout"))
|
||||
)
|
||||
if process.stderr:
|
||||
session.reader_tasks.append(
|
||||
asyncio.create_task(self._read_pipe(session, process.stderr, "stderr"))
|
||||
)
|
||||
session.wait_task = asyncio.create_task(self._wait_pipe_process(session))
|
||||
return session
|
||||
|
||||
async def _read_pty(self, session: _TerminalSession) -> None:
|
||||
"""持续从 PTY 读取增量输出。"""
|
||||
while session.master_fd is not None:
|
||||
try:
|
||||
data = os.read(session.master_fd, TERMINAL_READ_CHUNK_SIZE)
|
||||
except BlockingIOError:
|
||||
await asyncio.sleep(TERMINAL_PTY_POLL_INTERVAL)
|
||||
continue
|
||||
except OSError as err:
|
||||
if err.errno not in {errno.EIO, errno.EBADF}:
|
||||
logger.debug("PTY 输出读取异常: session_id=%s, error=%s", session.session_id, err)
|
||||
break
|
||||
|
||||
if not data:
|
||||
break
|
||||
session.append_output("pty", data)
|
||||
|
||||
async def _read_pipe(
|
||||
self,
|
||||
session: _TerminalSession,
|
||||
stream: asyncio.StreamReader,
|
||||
stream_name: str,
|
||||
) -> None:
|
||||
"""持续从普通管道读取增量输出。"""
|
||||
while True:
|
||||
data = await stream.read(TERMINAL_READ_CHUNK_SIZE)
|
||||
if not data:
|
||||
break
|
||||
session.append_output(stream_name, data)
|
||||
|
||||
async def _wait_pty_process(self, session: _TerminalSession) -> None:
|
||||
"""等待 PTY 子进程结束并完成输出读取任务收尾。"""
|
||||
try:
|
||||
_, status = await asyncio.to_thread(os.waitpid, session.pid, 0)
|
||||
exit_code = os.waitstatus_to_exitcode(status)
|
||||
session.mark_finished(exit_code)
|
||||
except ChildProcessError:
|
||||
session.mark_finished(session.exit_code)
|
||||
except Exception as err:
|
||||
session.mark_error(str(err))
|
||||
logger.warning("等待 PTY 进程失败: session_id=%s, error=%s", session.session_id, err)
|
||||
finally:
|
||||
await self._finish_reader_tasks(session)
|
||||
session.close_pty()
|
||||
|
||||
async def _wait_pipe_process(self, session: _TerminalSession) -> None:
|
||||
"""等待普通管道子进程结束并完成输出读取任务收尾。"""
|
||||
try:
|
||||
if not session.process:
|
||||
session.mark_error("进程对象不存在")
|
||||
return
|
||||
exit_code = await session.process.wait()
|
||||
session.mark_finished(exit_code)
|
||||
except Exception as err:
|
||||
session.mark_error(str(err))
|
||||
logger.warning("等待管道进程失败: session_id=%s, error=%s", session.session_id, err)
|
||||
finally:
|
||||
await self._finish_reader_tasks(session)
|
||||
|
||||
async def _finish_reader_tasks(self, session: _TerminalSession) -> None:
|
||||
"""等待输出读取任务退出,超时后取消残留任务。"""
|
||||
if not session.reader_tasks:
|
||||
return
|
||||
done, pending = await asyncio.wait(session.reader_tasks, timeout=1)
|
||||
for task in pending:
|
||||
task.cancel()
|
||||
await asyncio.gather(*done, *pending, return_exceptions=True)
|
||||
|
||||
async def read(
|
||||
self,
|
||||
*,
|
||||
session_id: str,
|
||||
since_seq: Optional[int] = None,
|
||||
max_bytes: Optional[int] = TERMINAL_DEFAULT_READ_BYTES,
|
||||
) -> dict[str, Any]:
|
||||
"""读取会话当前保留的增量输出。"""
|
||||
session = self.get_session(session_id)
|
||||
output, output_truncated, output_until_seq = self._collect_output(
|
||||
session,
|
||||
since_seq=since_seq,
|
||||
max_bytes=max_bytes,
|
||||
)
|
||||
return self._session_payload(
|
||||
session,
|
||||
output=output,
|
||||
output_truncated=output_truncated,
|
||||
output_until_seq=output_until_seq,
|
||||
)
|
||||
|
||||
async def wait(
|
||||
self,
|
||||
*,
|
||||
session_id: str,
|
||||
timeout_ms: Optional[int] = TERMINAL_WAIT_DEFAULT_MS,
|
||||
since_seq: Optional[int] = None,
|
||||
max_bytes: Optional[int] = TERMINAL_DEFAULT_READ_BYTES,
|
||||
) -> dict[str, Any]:
|
||||
"""短暂等待会话结束,并返回等待期间可见的增量输出。"""
|
||||
session = self.get_session(session_id)
|
||||
normalized_timeout = self._normalize_wait_timeout(timeout_ms)
|
||||
if session.wait_task and not session.wait_task.done():
|
||||
try:
|
||||
await asyncio.wait_for(
|
||||
asyncio.shield(session.wait_task),
|
||||
timeout=normalized_timeout / 1000,
|
||||
)
|
||||
except asyncio.TimeoutError:
|
||||
pass
|
||||
|
||||
output, output_truncated, output_until_seq = self._collect_output(
|
||||
session,
|
||||
since_seq=since_seq,
|
||||
max_bytes=max_bytes,
|
||||
)
|
||||
payload = self._session_payload(
|
||||
session,
|
||||
output=output,
|
||||
output_truncated=output_truncated,
|
||||
output_until_seq=output_until_seq,
|
||||
)
|
||||
payload["wait_timeout_ms"] = normalized_timeout
|
||||
return payload
|
||||
|
||||
async def write(self, *, session_id: str, input_text: str) -> dict[str, Any]:
|
||||
"""向会话 stdin 写入文本,PTY 模式下写入 master fd。"""
|
||||
session = self.get_session(session_id)
|
||||
if session.status != "running":
|
||||
raise RuntimeError(f"会话已结束,当前状态: {session.status}")
|
||||
|
||||
data = (input_text or "").encode("utf-8")
|
||||
if session.use_pty:
|
||||
if session.master_fd is None:
|
||||
raise RuntimeError("PTY 已关闭")
|
||||
await asyncio.to_thread(os.write, session.master_fd, data)
|
||||
else:
|
||||
if not session.process or not session.process.stdin:
|
||||
raise RuntimeError("进程 stdin 不可写")
|
||||
session.process.stdin.write(data)
|
||||
await session.process.stdin.drain()
|
||||
|
||||
session.updated_at = time.time()
|
||||
payload = self._session_payload(session, output="", output_truncated=False)
|
||||
payload["written_bytes"] = len(data)
|
||||
return payload
|
||||
|
||||
async def kill(
|
||||
self,
|
||||
*,
|
||||
session_id: str,
|
||||
sig: Optional[str | int] = "TERM",
|
||||
) -> dict[str, Any]:
|
||||
"""向会话进程组发送信号并等待短暂清理。"""
|
||||
session = self.get_session(session_id)
|
||||
if session.status != "running":
|
||||
return self._session_payload(session, output="", output_truncated=False)
|
||||
|
||||
session.kill_requested = True
|
||||
signal_number = self._resolve_signal(sig)
|
||||
self._send_signal(session, signal_number)
|
||||
|
||||
if session.wait_task and not session.wait_task.done():
|
||||
try:
|
||||
await asyncio.wait_for(
|
||||
asyncio.shield(session.wait_task),
|
||||
timeout=TERMINAL_KILL_GRACE_SECONDS,
|
||||
)
|
||||
except asyncio.TimeoutError:
|
||||
force_signal = getattr(signal, "SIGKILL", signal.SIGTERM)
|
||||
self._send_signal(session, force_signal)
|
||||
|
||||
return self._session_payload(session, output="", output_truncated=False)
|
||||
|
||||
def get_session(self, session_id: str) -> _TerminalSession:
|
||||
"""按 ID 获取会话,不存在时抛出清晰错误。"""
|
||||
session = self._sessions.get(session_id)
|
||||
if not session:
|
||||
raise KeyError(f"终端会话不存在: {session_id}")
|
||||
return session
|
||||
|
||||
@staticmethod
|
||||
def _normalize_wait_timeout(timeout_ms: Optional[int]) -> int:
|
||||
"""限制 wait 单次等待时间,避免工具调用长时间占用模型回合。"""
|
||||
try:
|
||||
normalized = int(timeout_ms or TERMINAL_WAIT_DEFAULT_MS)
|
||||
except (TypeError, ValueError):
|
||||
normalized = TERMINAL_WAIT_DEFAULT_MS
|
||||
if normalized < 0:
|
||||
return 0
|
||||
return min(normalized, TERMINAL_WAIT_MAX_MS)
|
||||
|
||||
@staticmethod
|
||||
def _normalize_read_limit(max_bytes: Optional[int]) -> int:
|
||||
"""限制单次读取返回的输出大小。"""
|
||||
try:
|
||||
normalized = int(max_bytes or TERMINAL_DEFAULT_READ_BYTES)
|
||||
except (TypeError, ValueError):
|
||||
normalized = TERMINAL_DEFAULT_READ_BYTES
|
||||
if normalized <= 0:
|
||||
return TERMINAL_DEFAULT_READ_BYTES
|
||||
return min(normalized, TERMINAL_MAX_READ_BYTES)
|
||||
|
||||
def _collect_output(
|
||||
self,
|
||||
session: _TerminalSession,
|
||||
*,
|
||||
since_seq: Optional[int],
|
||||
max_bytes: Optional[int],
|
||||
) -> tuple[str, bool, int]:
|
||||
"""按 seq 和大小限制收集输出文本。"""
|
||||
read_limit = self._normalize_read_limit(max_bytes)
|
||||
selected_chunks = [
|
||||
chunk
|
||||
for chunk in session.chunks
|
||||
if since_seq is None or chunk.seq > since_seq
|
||||
]
|
||||
output_parts: list[str] = []
|
||||
output_bytes = 0
|
||||
output_truncated = False
|
||||
last_stream: Optional[str] = None
|
||||
output_until_seq = since_seq or session.retained_from_seq - 1
|
||||
|
||||
for chunk in selected_chunks:
|
||||
prefix = self._stream_prefix(chunk.stream, last_stream, session.use_pty)
|
||||
text = f"{prefix}{chunk.text}" if prefix else chunk.text
|
||||
encoded = text.encode("utf-8")
|
||||
remaining = read_limit - output_bytes
|
||||
if len(encoded) > remaining:
|
||||
if remaining > 0:
|
||||
output_parts.append(
|
||||
encoded[:remaining].decode("utf-8", errors="ignore")
|
||||
)
|
||||
output_truncated = True
|
||||
break
|
||||
output_parts.append(text)
|
||||
output_bytes += len(encoded)
|
||||
last_stream = chunk.stream
|
||||
output_until_seq = chunk.seq
|
||||
|
||||
if since_seq is not None and since_seq < session.retained_from_seq - 1:
|
||||
output_truncated = True
|
||||
if not output_truncated:
|
||||
output_until_seq = session.next_seq - 1
|
||||
return "".join(output_parts), output_truncated, output_until_seq
|
||||
|
||||
@staticmethod
|
||||
def _stream_prefix(stream: str, last_stream: Optional[str], use_pty: bool) -> str:
|
||||
"""为普通管道输出增加 stdout/stderr 分段标识。"""
|
||||
if use_pty or stream == last_stream:
|
||||
return ""
|
||||
title = "标准输出" if stream == "stdout" else "错误输出"
|
||||
return f"\n[{title}]\n"
|
||||
|
||||
@staticmethod
|
||||
def _resolve_signal(sig: Optional[str | int]) -> int:
|
||||
"""解析字符串或数字形式的信号名。"""
|
||||
if isinstance(sig, int):
|
||||
return sig
|
||||
signal_name = str(sig or "TERM").strip().upper()
|
||||
if signal_name.isdigit():
|
||||
return int(signal_name)
|
||||
if not signal_name.startswith("SIG"):
|
||||
signal_name = f"SIG{signal_name}"
|
||||
return int(getattr(signal, signal_name, signal.SIGTERM))
|
||||
|
||||
@staticmethod
|
||||
def _send_signal(session: _TerminalSession, sig: int) -> None:
|
||||
"""优先向进程组发信号,失败时回退到单进程。"""
|
||||
try:
|
||||
if os.name == "posix":
|
||||
os.killpg(session.pid, sig)
|
||||
elif session.process:
|
||||
if sig == getattr(signal, "SIGKILL", None):
|
||||
session.process.kill()
|
||||
else:
|
||||
session.process.terminate()
|
||||
except ProcessLookupError:
|
||||
pass
|
||||
|
||||
def _active_session_count_locked(self) -> int:
|
||||
"""统计仍在运行的会话数量。"""
|
||||
return sum(1 for session in self._sessions.values() if session.status == "running")
|
||||
|
||||
def _cleanup_finished_sessions_locked(self) -> None:
|
||||
"""清理已经结束且超过保留时间的会话。"""
|
||||
now = time.time()
|
||||
expired_ids = [
|
||||
session_id
|
||||
for session_id, session in self._sessions.items()
|
||||
if session.status != "running"
|
||||
and now - session.updated_at > TERMINAL_RETENTION_SECONDS
|
||||
]
|
||||
for session_id in expired_ids:
|
||||
session = self._sessions.pop(session_id)
|
||||
session.close_pty()
|
||||
|
||||
@staticmethod
|
||||
def _session_payload(
|
||||
session: _TerminalSession,
|
||||
*,
|
||||
output: str,
|
||||
output_truncated: bool,
|
||||
output_until_seq: Optional[int] = None,
|
||||
) -> dict[str, Any]:
|
||||
"""生成工具返回的结构化会话状态。"""
|
||||
return {
|
||||
"session_id": session.session_id,
|
||||
"command": session.command,
|
||||
"cwd": session.cwd,
|
||||
"pid": session.pid,
|
||||
"status": session.status,
|
||||
"exit_code": session.exit_code,
|
||||
"use_pty": session.use_pty,
|
||||
"last_seq": session.next_seq - 1,
|
||||
"output_until_seq": (
|
||||
session.next_seq - 1 if output_until_seq is None else output_until_seq
|
||||
),
|
||||
"retained_from_seq": session.retained_from_seq,
|
||||
"output_truncated": output_truncated,
|
||||
"output": output,
|
||||
"error": session.error,
|
||||
}
|
||||
|
||||
|
||||
terminal_session_manager = _TerminalSessionManager()
|
||||
@@ -823,7 +823,7 @@ class SkillsChain(ChainBase):
|
||||
)
|
||||
if search_row:
|
||||
buttons.append(search_row)
|
||||
for index, _skill in enumerate(page_items, start=1):
|
||||
for index, _skill in enumerate(items, start=1):
|
||||
buttons.append(
|
||||
[
|
||||
{
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
---
|
||||
name: database-operation
|
||||
version: 1
|
||||
version: 2
|
||||
description: >-
|
||||
Use this skill when you need to execute SQL against the MoviePilot database.
|
||||
This skill guides you through connecting to the database and executing SQL statements.
|
||||
@@ -20,7 +20,7 @@ This skill guides you through executing SQL against the MoviePilot database. Bot
|
||||
## Prerequisites
|
||||
|
||||
You need the following tools:
|
||||
- `execute_command` - Execute shell commands to run database queries
|
||||
- `execute_command` - Execute shell commands to run database queries. Use `action=run` when you need the query result immediately.
|
||||
|
||||
## Getting Database Connection Info
|
||||
|
||||
@@ -38,7 +38,7 @@ The system prompt `<system_info>` section already contains all the database conn
|
||||
|
||||
Extract the database file path from `<system_info>` (the path inside the parentheses after `SQLite`).
|
||||
|
||||
Use `execute_command` to run queries:
|
||||
Use `execute_command` with `action=run` to run queries:
|
||||
|
||||
```bash
|
||||
sqlite3 -header -column <DB_PATH> "YOUR SQL QUERY HERE;"
|
||||
@@ -66,7 +66,7 @@ sqlite3 <DB_PATH> ".schema tablename"
|
||||
|
||||
Extract the connection parameters from `<system_info>` (parse `user:password@host:port/database` from the parentheses after `PostgreSQL`).
|
||||
|
||||
Use `execute_command` to run queries via `psql`:
|
||||
Use `execute_command` with `action=run` to run queries via `psql`:
|
||||
|
||||
```bash
|
||||
PGPASSWORD=<password> psql -h <host> -p <port> -U <user> -d <database> -c "YOUR SQL QUERY HERE;"
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
import asyncio
|
||||
import json
|
||||
import os
|
||||
import re
|
||||
import shlex
|
||||
@@ -23,15 +24,18 @@ def _python_command(code: str) -> str:
|
||||
|
||||
class TestExecuteCommandTool(unittest.TestCase):
|
||||
def _temp_file_path_from_result(self, result: str) -> str:
|
||||
"""从工具返回文本中提取完整输出临时文件路径。"""
|
||||
match = re.search(r"临时文件: (.+)", result)
|
||||
self.assertIsNotNone(match)
|
||||
return match.group(1).strip()
|
||||
|
||||
def _run_command(self, command: str, timeout: int = 60) -> str:
|
||||
"""按一次性执行模式运行命令,兼容旧测试断言。"""
|
||||
tool = ExecuteCommandTool(session_id="session-1", user_id="10001")
|
||||
return asyncio.run(tool.run(command=command, timeout=timeout))
|
||||
return asyncio.run(tool.run(action="run", command=command, timeout=timeout))
|
||||
|
||||
def test_large_output_is_truncated_before_returning_to_agent(self):
|
||||
"""大输出一次性命令只把预览返回给 Agent,并把完整内容写到临时文件。"""
|
||||
command = _python_command(
|
||||
"import sys; sys.stdout.write('x' * 200000); sys.stdout.flush()"
|
||||
)
|
||||
@@ -52,6 +56,7 @@ class TestExecuteCommandTool(unittest.TestCase):
|
||||
self.assertGreater(len(file_content), 100000)
|
||||
|
||||
def test_timeout_returns_partial_output_promptly(self):
|
||||
"""一次性命令超时后应及时返回已经读取到的部分输出。"""
|
||||
command = _python_command(
|
||||
"import time; print('started', flush=True); time.sleep(5)"
|
||||
)
|
||||
@@ -65,6 +70,7 @@ class TestExecuteCommandTool(unittest.TestCase):
|
||||
self.assertIn("started", result)
|
||||
|
||||
def test_timeout_with_large_output_writes_partial_full_log_to_temp_file(self):
|
||||
"""超时且输出较大时,终止前完整输出应写入临时文件。"""
|
||||
command = _python_command(
|
||||
"import sys, time; sys.stdout.write('x' * 20000); sys.stdout.flush(); time.sleep(5)"
|
||||
)
|
||||
@@ -83,6 +89,7 @@ class TestExecuteCommandTool(unittest.TestCase):
|
||||
self.assertGreaterEqual(file_content.count("x"), 20000)
|
||||
|
||||
def test_timeout_is_capped(self):
|
||||
"""一次性执行的 timeout 参数超过上限时应自动限幅。"""
|
||||
command = _python_command("print('ok')")
|
||||
|
||||
result = self._run_command(command, timeout=9999)
|
||||
@@ -90,6 +97,134 @@ class TestExecuteCommandTool(unittest.TestCase):
|
||||
self.assertIn("timeout 参数超过上限", result)
|
||||
self.assertIn("ok", result)
|
||||
|
||||
def test_forbidden_command_is_rejected(self):
|
||||
"""明显危险命令在进入 shell 前应被拒绝。"""
|
||||
result = self._run_command("echo ok && rm -rf /")
|
||||
|
||||
payload = json.loads(result)
|
||||
self.assertEqual(payload["status"], "error")
|
||||
self.assertIn("禁止使用", payload["error"])
|
||||
|
||||
|
||||
class TestExecuteCommandSessionTool(unittest.IsolatedAsyncioTestCase):
|
||||
async def asyncSetUp(self):
|
||||
"""创建每个测试复用的统一命令工具。"""
|
||||
self.tool = ExecuteCommandTool(session_id="session-1", user_id="10001")
|
||||
self._created_sessions: list[str] = []
|
||||
|
||||
async def asyncTearDown(self):
|
||||
"""清理测试中残留的后台会话,避免影响后续用例。"""
|
||||
for session_id in self._created_sessions:
|
||||
await self.tool.run(action="kill", session_id=session_id)
|
||||
|
||||
@staticmethod
|
||||
def _loads(result: str) -> dict:
|
||||
"""解析 execute_command 返回的 JSON 字符串。"""
|
||||
return json.loads(result)
|
||||
|
||||
async def _start(self, command: str, *, use_pty: bool = False) -> dict:
|
||||
"""通过 execute_command 启动后台会话并记录 ID。"""
|
||||
payload = self._loads(
|
||||
await self.tool.run(action="start", command=command, use_pty=use_pty)
|
||||
)
|
||||
session_id = payload.get("session_id")
|
||||
if session_id:
|
||||
self._created_sessions.append(session_id)
|
||||
return payload
|
||||
|
||||
async def test_default_action_starts_session_promptly(self):
|
||||
"""不传 action 时应默认后台启动,并快速返回会话 ID。"""
|
||||
command = _python_command(
|
||||
"import time; print('ready', flush=True); time.sleep(1); print('done', flush=True)"
|
||||
)
|
||||
|
||||
started_at = time.monotonic()
|
||||
start_payload = self._loads(await self.tool.run(command=command, use_pty=False))
|
||||
duration = time.monotonic() - started_at
|
||||
self._created_sessions.append(start_payload["session_id"])
|
||||
|
||||
self.assertLess(duration, 0.8)
|
||||
self.assertEqual(start_payload["status"], "running")
|
||||
self.assertIn("session_id", start_payload)
|
||||
|
||||
async def test_read_and_wait_get_incremental_output(self):
|
||||
"""同一个 execute_command 工具应能分段等待并读取增量输出。"""
|
||||
command = _python_command(
|
||||
"import time; print('ready', flush=True); time.sleep(1); print('done', flush=True)"
|
||||
)
|
||||
start_payload = await self._start(command)
|
||||
|
||||
wait_payload = self._loads(
|
||||
await self.tool.run(
|
||||
action="wait",
|
||||
session_id=start_payload["session_id"],
|
||||
timeout_ms=200,
|
||||
since_seq=0,
|
||||
)
|
||||
)
|
||||
|
||||
self.assertEqual(wait_payload["status"], "running")
|
||||
self.assertIn("ready", wait_payload["output"])
|
||||
|
||||
final_payload = self._loads(
|
||||
await self.tool.run(
|
||||
action="wait",
|
||||
session_id=start_payload["session_id"],
|
||||
timeout_ms=3000,
|
||||
since_seq=wait_payload["output_until_seq"],
|
||||
)
|
||||
)
|
||||
|
||||
self.assertEqual(final_payload["status"], "exited")
|
||||
self.assertEqual(final_payload["exit_code"], 0)
|
||||
self.assertIn("done", final_payload["output"])
|
||||
|
||||
async def test_write_sends_input_to_running_process(self):
|
||||
"""write 动作应能向后台进程 stdin 写入交互输入。"""
|
||||
command = _python_command(
|
||||
"line = input('name: '); print('hello ' + line, flush=True)"
|
||||
)
|
||||
start_payload = await self._start(command)
|
||||
|
||||
await self.tool.run(
|
||||
action="write",
|
||||
session_id=start_payload["session_id"],
|
||||
input_text="moviepilot\n",
|
||||
)
|
||||
wait_payload = self._loads(
|
||||
await self.tool.run(
|
||||
action="wait",
|
||||
session_id=start_payload["session_id"],
|
||||
timeout_ms=3000,
|
||||
since_seq=0,
|
||||
)
|
||||
)
|
||||
|
||||
self.assertEqual(wait_payload["status"], "exited")
|
||||
self.assertIn("hello moviepilot", wait_payload["output"])
|
||||
|
||||
async def test_kill_stops_long_running_process(self):
|
||||
"""kill 动作应能终止长时间运行的后台命令会话。"""
|
||||
command = _python_command(
|
||||
"import time; print('started', flush=True); time.sleep(20)"
|
||||
)
|
||||
start_payload = await self._start(command)
|
||||
|
||||
read_payload = self._loads(
|
||||
await self.tool.run(
|
||||
action="wait",
|
||||
session_id=start_payload["session_id"],
|
||||
timeout_ms=500,
|
||||
since_seq=0,
|
||||
)
|
||||
)
|
||||
kill_payload = self._loads(
|
||||
await self.tool.run(action="kill", session_id=start_payload["session_id"])
|
||||
)
|
||||
|
||||
self.assertIn("started", read_payload["output"])
|
||||
self.assertIn(kill_payload["status"], {"killed", "exited"})
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
|
||||
Reference in New Issue
Block a user