fix agent stream blocking during command execution

Offload synchronous message edits from the event loop and stream subprocess output so long-running commands stay responsive.
This commit is contained in:
jxxghp
2026-04-27 07:57:32 +08:00
parent 7bc032d17c
commit 140d224a9a
4 changed files with 293 additions and 35 deletions

View File

@@ -2,6 +2,8 @@ import asyncio
import threading
from typing import Optional, Tuple
from fastapi.concurrency import run_in_threadpool
from app.chain import ChainBase
from app.log import logger
from app.schemas import Notification
@@ -256,7 +258,8 @@ class StreamingHandler:
try:
if self._message_response is None:
# 第一次发送:发送新消息并获取 message_id
response = chain.send_direct_message(
response = await run_in_threadpool(
chain.send_direct_message,
Notification(
channel=self._channel,
source=self._source,
@@ -264,7 +267,7 @@ class StreamingHandler:
username=self._username,
title=self._title,
text=current_text,
)
),
)
if response and response.success and response.message_id:
self._message_response = response
@@ -297,7 +300,8 @@ class StreamingHandler:
# 如果偏移后还有新内容,立即发送为新消息
if current_text:
response = chain.send_direct_message(
response = await run_in_threadpool(
chain.send_direct_message,
Notification(
channel=self._channel,
source=self._source,
@@ -305,7 +309,7 @@ class StreamingHandler:
username=self._username,
title=self._title,
text=current_text,
)
),
)
if response and response.success and response.message_id:
self._message_response = response
@@ -324,7 +328,8 @@ class StreamingHandler:
except (ValueError, KeyError):
return
success = chain.edit_message(
success = await run_in_threadpool(
chain.edit_message,
channel=channel_enum,
source=self._message_response.source,
message_id=self._message_response.message_id,

View File

@@ -1,7 +1,8 @@
"""执行Shell命令工具"""
import asyncio
from typing import Optional, Type
import codecs
from typing import Any, Dict, Optional, Type
from pydantic import BaseModel, Field
@@ -26,12 +27,133 @@ class ExecuteCommandTool(MoviePilotTool):
description: str = "Safely execute shell commands on the server. Useful for system maintenance, checking status, or running custom scripts. Includes timeout and output limits."
args_schema: Type[BaseModel] = ExecuteCommandInput
require_admin: bool = True
RESULT_LIMIT = 3000
STREAM_CAPTURE_LIMIT = 2000
LIVE_OUTPUT_LIMIT = 1200
def get_tool_message(self, **kwargs) -> Optional[str]:
"""根据命令生成友好的提示消息"""
command = kwargs.get("command", "")
return f"执行系统命令: {command}"
def _build_result(
self,
message: str,
stdout_capture: Dict[str, Any],
stderr_capture: Dict[str, Any],
) -> str:
stdout_str = "".join(stdout_capture["chunks"]).strip()
stderr_str = "".join(stderr_capture["chunks"]).strip()
result = message
if stdout_str:
result += f"\n\n标准输出:\n{stdout_str}"
if stderr_str:
result += f"\n\n错误输出:\n{stderr_str}"
if not stdout_str and not stderr_str:
result += "\n\n(无输出内容)"
was_truncated = stdout_capture["truncated"] or stderr_capture["truncated"]
overflow_suffix = "\n\n...(输出内容过长,已截断)"
if was_truncated or len(result) > self.RESULT_LIMIT:
result = (
result[: self.RESULT_LIMIT - len(overflow_suffix)] + overflow_suffix
)
return result
def _append_capture(self, capture: Dict[str, Any], text: str):
if not text:
return
remaining = self.STREAM_CAPTURE_LIMIT - capture["length"]
if remaining <= 0:
capture["truncated"] = True
return
fragment = text[:remaining]
capture["chunks"].append(fragment)
capture["length"] += len(fragment)
if len(text) > remaining:
capture["truncated"] = True
def _should_emit_live_output(self) -> bool:
return bool(
self._stream_handler
and self._stream_handler.is_streaming
and self._stream_handler.is_auto_flushing
)
def _emit_live_output(
self, text: str, stream_name: str, live_state: Dict[str, Any]
):
if not text or not live_state["enabled"]:
return
header_key = f"{stream_name}_header_sent"
prefix = ""
if not live_state[header_key]:
prefix = "标准输出:\n" if stream_name == "stdout" else "\n错误输出:\n"
live_state[header_key] = True
payload = prefix + text
remaining = self.LIVE_OUTPUT_LIMIT - live_state["chars"]
if remaining <= 0:
if not live_state["truncated"]:
self._stream_handler.emit("\n...(命令输出过长,停止实时展示)\n")
live_state["truncated"] = True
return
fragment = payload[:remaining]
if fragment:
self._stream_handler.emit(fragment)
live_state["chars"] += len(fragment)
if len(payload) > remaining and not live_state["truncated"]:
self._stream_handler.emit("\n...(命令输出过长,停止实时展示)\n")
live_state["truncated"] = True
async def _collect_stream(
self,
stream: Optional[asyncio.StreamReader],
stream_name: str,
capture: Dict[str, Any],
live_state: Dict[str, Any],
):
if not stream:
return
decoder = codecs.getincrementaldecoder("utf-8")(errors="replace")
while True:
chunk = await stream.read(512)
if not chunk:
tail = decoder.decode(b"", final=True)
if tail:
self._append_capture(capture, tail)
self._emit_live_output(tail, stream_name, live_state)
return
text = decoder.decode(chunk)
if not text:
continue
self._append_capture(capture, text)
self._emit_live_output(text, stream_name, live_state)
@staticmethod
async def _terminate_process(process: asyncio.subprocess.Process):
if process.returncode is not None:
return
try:
process.kill()
except ProcessLookupError:
return
try:
await asyncio.wait_for(process.wait(), timeout=5)
except asyncio.TimeoutError:
logger.warning("终止命令进程超时")
async def run(self, command: str, timeout: Optional[int] = 60, **kwargs) -> str:
logger.info(
f"执行工具: {self.name}, 参数: command={command}, timeout={timeout}"
@@ -56,40 +178,54 @@ class ExecuteCommandTool(MoviePilotTool):
command, stdout=asyncio.subprocess.PIPE, stderr=asyncio.subprocess.PIPE
)
stdout_capture: Dict[str, Any] = {
"chunks": [],
"length": 0,
"truncated": False,
}
stderr_capture: Dict[str, Any] = {
"chunks": [],
"length": 0,
"truncated": False,
}
live_state: Dict[str, Any] = {
"enabled": self._should_emit_live_output(),
"chars": 0,
"truncated": False,
"stdout_header_sent": False,
"stderr_header_sent": False,
}
stdout_task = asyncio.create_task(
self._collect_stream(
process.stdout, "stdout", stdout_capture, live_state
)
)
stderr_task = asyncio.create_task(
self._collect_stream(
process.stderr, "stderr", stderr_capture, live_state
)
)
try:
# 等待完成,带超时
stdout, stderr = await asyncio.wait_for(
process.communicate(), timeout=timeout
await asyncio.wait_for(process.wait(), timeout=timeout)
await asyncio.gather(stdout_task, stderr_task)
return self._build_result(
f"命令执行完成 (退出码: {process.returncode})",
stdout_capture,
stderr_capture,
)
# 处理输出
stdout_str = stdout.decode("utf-8", errors="replace").strip()
stderr_str = stderr.decode("utf-8", errors="replace").strip()
exit_code = process.returncode
result = f"命令执行完成 (退出码: {exit_code})"
if stdout_str:
result += f"\n\n标准输出:\n{stdout_str}"
if stderr_str:
result += f"\n\n错误输出:\n{stderr_str}"
# 如果没有输出
if not stdout_str and not stderr_str:
result += "\n\n(无输出内容)"
# 限制输出长度,防止上下文过长
if len(result) > 3000:
result = result[:3000] + "\n\n...(输出内容过长,已截断)"
return result
except asyncio.TimeoutError:
# 超时处理
try:
process.kill()
except ProcessLookupError:
pass
return f"命令执行超时 (限制: {timeout}秒)"
await self._terminate_process(process)
await asyncio.gather(stdout_task, stderr_task)
return self._build_result(
f"命令执行超时 (限制: {timeout}秒)",
stdout_capture,
stderr_capture,
)
except Exception as e:
logger.error(f"执行命令失败: {e}", exc_info=True)