Files
MoviePilot/app/api/endpoints/agent.py

788 lines
26 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
import asyncio
import hashlib
import json
import mimetypes
import time
import uuid
from pathlib import Path
from typing import Any, AsyncIterator, Callable, Optional
from fastapi import APIRouter, Depends, File, Form, HTTPException, Request, UploadFile, status
from fastapi.responses import FileResponse, StreamingResponse
from app import schemas
from app.agent import MoviePilotAgent, ReplyMode, StreamingHandler
from app.core.config import global_vars, settings
from app.db.models import User
from app.db.user_oper import UserOper, get_current_active_user
from app.helper.interaction import agent_interaction_manager
from app.log import logger
from app.schemas.types import MessageChannel
router = APIRouter()
WEB_AGENT_SESSION_PREFIX = "web-agent:"
WEB_AGENT_SOURCE = "web-agent"
WEB_AGENT_FILE_TTL_SECONDS = 6 * 60 * 60
WEB_AGENT_FILE_MAX_ITEMS = 256
WEB_AGENT_UPLOAD_MAX_BYTES = 32 * 1024 * 1024
WEB_AGENT_UPLOAD_CHUNK_SIZE = 1024 * 1024
_WEB_AGENT_FILE_REGISTRY: dict[str, dict[str, Any]] = {}
class _WebAgentStreamingHandler(StreamingHandler):
"""
Web 前端专用流式处理器,将工具提示和文本统一回调给 SSE。
"""
def __init__(self, on_emit: Callable[[str], None]) -> None:
super().__init__()
self._on_emit = on_emit
def emit(self, token: str) -> str:
"""追加 token 并同步通知 SSE 生产者。"""
emitted = super().emit(token)
if emitted:
self._on_emit(emitted)
return emitted
def flush_pending_tool_summary(self) -> str:
"""输出延迟聚合的工具摘要。"""
emitted = super().flush_pending_tool_summary()
if emitted:
self._on_emit(emitted)
return ""
async def start_streaming(
self,
channel: Optional[str] = None,
source: Optional[str] = None,
user_id: Optional[str] = None,
username: Optional[str] = None,
original_message_id: Optional[str] = None,
original_chat_id: Optional[str] = None,
title: str = "",
) -> None:
"""Web SSE 自身负责外发,不启动消息模块编辑循环。"""
self._channel = channel
self._source = source
self._user_id = user_id
self._username = username
self._original_message_id = original_message_id
self._original_chat_id = original_chat_id
self._title = title
self._streaming_enabled = True
self._sent_text = ""
self._message_response = None
self._msg_start_offset = 0
self._pending_tool_stats = {}
async def stop_streaming(self) -> tuple[bool, str]:
"""停止 Web SSE 流式状态,保留缓冲区给 Agent 收口逻辑去重。"""
if not self._streaming_enabled:
return False, ""
self._streaming_enabled = False
self.flush_pending_tool_summary()
with self._lock:
self._sent_text = ""
self._message_response = None
self._msg_start_offset = 0
self._pending_tool_stats = {}
return False, ""
@property
def is_auto_flushing(self) -> bool:
"""让工具执行提示进入缓冲区,由 SSE 回调负责外发。"""
return True
class _WebAgentMoviePilotAgent(MoviePilotAgent):
"""
Web 前端专用 Agent强制使用流式推理。
"""
def __init__(
self,
*args: Any,
notification_callback: Optional[Callable[[schemas.Notification], None]] = None,
**kwargs: Any,
) -> None:
super().__init__(*args, **kwargs)
self._notification_callback = notification_callback
self.stream_handler = _WebAgentStreamingHandler(self._emit_output)
def _should_stream(self) -> bool:
"""Web 面板需要实时输出,即使 Web 渠道本身不支持消息编辑。"""
return True
async def _is_system_admin_context(self) -> bool:
"""Web Agent 根据当前登录用户 ID 判断工具管理员上下文。"""
if not self.user_id:
return False
try:
user = await UserOper().async_get_by_id(int(self.user_id))
except (TypeError, ValueError):
return False
except Exception as e:
logger.error(f"检查 Web Agent 用户管理员身份失败: {e}")
return False
return bool(user and user.is_superuser)
async def _build_tool_context(self, should_dispatch_reply: bool) -> dict[str, object]:
"""向工具上下文注入 Web SSE 通知回调。"""
context = await super()._build_tool_context(should_dispatch_reply)
context["notification_callback"] = self._notification_callback
return context
def _handle_stream_text(self, text: str) -> None:
"""文本输出交由 Web 流式处理器统一回调,避免重复增量。"""
self.stream_handler.emit(text)
def _build_web_agent_session_id(user: User, session_id: Optional[str]) -> str:
"""
构建前端 Agent 会话 ID。
:param user: 当前登录用户
:param session_id: 前端传入的会话标识
:return: 可用于 Agent 记忆隔离的服务端会话 ID
"""
seed = str(session_id or "").strip() or uuid.uuid4().hex
user_part = user.name or str(user.id)
digest = hashlib.sha256(f"{user_part}:{seed}".encode("utf-8")).hexdigest()
return f"{WEB_AGENT_SESSION_PREFIX}{digest[:32]}"
def _build_web_agent_sse(event_type: str, data: Optional[dict] = None) -> str:
"""
构建 Web Agent SSE 消息。
:param event_type: 前端事件类型
:param data: 事件数据
:return: 符合 SSE 格式的字符串
"""
payload = {"type": event_type, **(data or {})}
return f"data: {json.dumps(payload, ensure_ascii=False)}\n\n"
def _sanitize_web_agent_upload_name(
filename: Optional[str], mime_type: Optional[str] = None
) -> str:
"""
规范化 Web Agent 上传文件名,避免路径穿越和空文件名。
:param filename: 浏览器上传的原始文件名
:param mime_type: 浏览器上报的 MIME 类型
:return: 可安全落盘的文件名
"""
name = Path(filename or "attachment").name.strip()
safe_name = "".join(
char for char in name if char.isalnum() or char in (" ", ".", "_", "-")
).strip(" .")
if not safe_name:
safe_name = "attachment"
if "." not in safe_name:
suffix = mimetypes.guess_extension(mime_type or "") or ""
safe_name = f"{safe_name}{suffix}"
return safe_name
def _get_web_agent_upload_dir(user: User, session_id: Optional[str]) -> Path:
"""
计算当前 Web Agent 会话的临时附件目录。
:param user: 当前登录用户
:param session_id: 前端会话标识
:return: 已创建的临时附件目录
"""
server_session_id = _build_web_agent_session_id(user, session_id)
safe_session_id = server_session_id.replace(":", "_")
upload_dir = settings.TEMP_PATH / "agent_uploads" / safe_session_id
upload_dir.mkdir(parents=True, exist_ok=True)
return upload_dir
async def _save_web_agent_upload(upload_file: UploadFile, target_path: Path) -> int:
"""
分块保存 Web Agent 上传文件,并限制单文件体积。
:param upload_file: FastAPI 上传文件对象
:param target_path: 目标落盘路径
:return: 已写入的字节数
"""
size = 0
try:
with target_path.open("wb") as output:
while True:
chunk = await upload_file.read(WEB_AGENT_UPLOAD_CHUNK_SIZE)
if not chunk:
break
size += len(chunk)
if size > WEB_AGENT_UPLOAD_MAX_BYTES:
raise HTTPException(
status_code=status.HTTP_413_REQUEST_ENTITY_TOO_LARGE,
detail="附件超过 32MB无法发送给智能助手",
)
output.write(chunk)
except Exception:
target_path.unlink(missing_ok=True)
raise
finally:
await upload_file.close()
return size
def _cleanup_web_agent_file_registry() -> None:
"""清理过期或过量的 Web Agent 临时附件引用。"""
now = time.time()
expired_ids = [
file_id
for file_id, info in _WEB_AGENT_FILE_REGISTRY.items()
if now - info.get("created_at", now) > WEB_AGENT_FILE_TTL_SECONDS
]
for file_id in expired_ids:
_WEB_AGENT_FILE_REGISTRY.pop(file_id, None)
overflow = len(_WEB_AGENT_FILE_REGISTRY) - WEB_AGENT_FILE_MAX_ITEMS
if overflow <= 0:
return
sorted_items = sorted(
_WEB_AGENT_FILE_REGISTRY.items(),
key=lambda item: item[1].get("created_at", 0),
)
for file_id, _ in sorted_items[:overflow]:
_WEB_AGENT_FILE_REGISTRY.pop(file_id, None)
def _guess_web_agent_attachment_kind(
mime_type: Optional[str], fallback: str = "file"
) -> str:
"""
根据 MIME 类型推断前端附件展示方式。
:param mime_type: 文件 MIME 类型
:param fallback: 无法推断时使用的类型
:return: image、audio 或 file
"""
if mime_type and mime_type.startswith("image/"):
return "image"
if mime_type and mime_type.startswith("audio/"):
return "audio"
return fallback
def _build_web_agent_url_attachment(
url: str,
kind: str,
name: Optional[str] = None,
mime_type: Optional[str] = None,
) -> dict:
"""
构建远程或 data URL 附件事件。
:param url: 前端可访问的附件地址
:param kind: 附件展示类型
:param name: 展示名称
:param mime_type: MIME 类型
:return: 前端附件描述
"""
return {
"kind": kind,
"url": url,
"download_url": url,
"name": name,
"mime_type": mime_type,
}
def _register_web_agent_file(
file_path: Optional[str],
file_name: Optional[str] = None,
kind: Optional[str] = None,
mime_type: Optional[str] = None,
) -> Optional[dict]:
"""
注册 Web Agent 本地附件并返回前端可访问的短期下载地址。
:param file_path: 本地文件路径
:param file_name: 前端展示文件名
:param kind: 附件展示类型
:param mime_type: 已知 MIME 类型
:return: 前端附件描述,文件不可访问时返回 None
"""
if not file_path:
return None
try:
resolved_path = Path(file_path).expanduser().resolve(strict=True)
except OSError:
return None
if not resolved_path.is_file():
return None
_cleanup_web_agent_file_registry()
file_id = uuid.uuid4().hex
display_name = file_name or resolved_path.name
resolved_mime_type = mime_type or mimetypes.guess_type(
display_name or str(resolved_path)
)[0]
file_url = f"message/agent/file/{file_id}"
_WEB_AGENT_FILE_REGISTRY[file_id] = {
"path": resolved_path,
"name": display_name,
"mime_type": resolved_mime_type or "application/octet-stream",
"created_at": time.time(),
}
return {
"kind": kind or _guess_web_agent_attachment_kind(resolved_mime_type),
"url": file_url,
"download_url": file_url,
"name": display_name,
"mime_type": resolved_mime_type,
"size": resolved_path.stat().st_size,
}
def _parse_web_agent_choice_callback(callback_data: str) -> Optional[tuple[str, int]]:
"""
解析 Web Agent 按钮选择回调数据。
:param callback_data: Agent 按钮携带的回调数据
:return: 请求 ID 与选项序号,格式无效时返回 None
"""
if not callback_data.startswith("agent_interaction:choice:"):
return None
try:
_, _, request_id, option_index = callback_data.split(":", 3)
except ValueError:
return None
if not request_id or not option_index.isdigit():
return None
return request_id, int(option_index)
def _flatten_web_agent_choice_buttons(buttons: Optional[list[list[dict]]]) -> list[dict]:
"""
将消息渠道按钮二维结构转换为 Web 前端可渲染的一维选项列表。
:param buttons: Notification 中的按钮行
:return: Web 选择卡片按钮列表
"""
flattened = []
for row in buttons or []:
for button in row or []:
text = str(button.get("text") or "").strip()
callback_data = str(button.get("callback_data") or "").strip()
if not text or not callback_data:
continue
flattened.append(
{
"label": text,
"callback_data": callback_data,
}
)
return flattened
def _build_web_agent_choice_event(notification: schemas.Notification) -> Optional[dict]:
"""
将带按钮通知转换为 Web Agent 选择卡片事件。
:param notification: Agent 工具发出的按钮通知
:return: 选择卡片事件,按钮为空时返回 None
"""
buttons = _flatten_web_agent_choice_buttons(notification.buttons)
if not buttons:
return None
choice_id = None
parsed = _parse_web_agent_choice_callback(buttons[0]["callback_data"])
if parsed:
choice_id = parsed[0]
return {
"type": "choice",
"choice": {
"id": choice_id or uuid.uuid4().hex,
"title": notification.title,
"prompt": notification.text or "",
"buttons": buttons,
},
}
def _resolve_web_agent_choice_payload(callback_data: str, user_id: str) -> Optional[dict]:
"""
解析并消费 Web Agent 按钮选择,生成前端反馈与下一条用户消息。
:param callback_data: 前端点击的按钮回调数据
:param user_id: 当前登录用户 ID
:return: 可返回给前端的数据,选择无效时返回 None
"""
parsed = _parse_web_agent_choice_callback(callback_data)
if not parsed:
return None
request_id, option_index = parsed
resolved = agent_interaction_manager.resolve(
request_id=request_id,
option_index=option_index,
user_id=str(user_id),
)
if not resolved:
return None
request, option = resolved
return {
"message": option.value,
"session_id": request.session_id,
"feedback": {
"request_id": request.request_id,
"title": request.title,
"prompt": request.prompt,
"selected_label": option.label,
"selected_value": option.value,
},
}
def _build_web_agent_notification_events(
notification: schemas.Notification,
) -> list[dict]:
"""
将 Agent 工具通知转换为 Web SSE 事件。
:param notification: 工具产生的通知消息
:return: 前端可直接应用到当前助手消息的事件列表
"""
events = []
choice_event = _build_web_agent_choice_event(notification)
if choice_event:
events.append(choice_event)
text_parts = [
str(item).strip()
for item in (notification.title, notification.text)
if str(item or "").strip()
]
if text_parts and not choice_event:
events.append({"type": "delta", "content": "\n\n".join(text_parts)})
if notification.image:
image_ref = notification.image
image_path = Path(image_ref).expanduser()
attachment = None
if not image_ref.startswith(("http://", "https://", "data:", "blob:")):
attachment = _register_web_agent_file(
image_ref, file_name=Path(image_ref).name, kind="image"
)
if not attachment:
attachment = _build_web_agent_url_attachment(
image_ref,
kind="image",
name=notification.title or image_path.name or "image",
)
events.append({"type": "attachment", "attachment": attachment})
if notification.voice_path:
attachment = _register_web_agent_file(
notification.voice_path,
file_name=Path(notification.voice_path).name,
kind="audio",
)
if attachment:
events.append({"type": "attachment", "attachment": attachment})
if notification.file_path:
attachment = _register_web_agent_file(
notification.file_path,
file_name=notification.file_name or Path(notification.file_path).name,
)
if attachment:
events.append({"type": "attachment", "attachment": attachment})
return events
def _split_web_agent_output(text: str) -> list[dict]:
"""
将 Agent 输出拆成普通文本与工具提示事件。
:param text: 本次新增的 Agent 文本
:return: 前端可直接渲染的事件片段
"""
if not text:
return []
events = []
def append_text(content: str) -> None:
"""将工具汇总行从普通文本中拆出,便于前端独立展示。"""
if not content:
return
lines = content.splitlines(keepends=True)
buffer = ""
for line in lines:
stripped_line = line.strip()
if (
stripped_line.startswith("")
and stripped_line.endswith("")
):
if buffer:
events.append({"type": "delta", "content": buffer})
buffer = ""
events.append(
{
"type": "tool",
"message": stripped_line.strip(""),
}
)
else:
buffer += line
if buffer:
events.append({"type": "delta", "content": buffer})
marker = "⚙️ => "
remaining = text
while remaining:
marker_index = remaining.find(marker)
if marker_index < 0:
append_text(remaining)
break
if marker_index > 0:
append_text(remaining[:marker_index])
after_marker = remaining[marker_index + len(marker):]
line_end = after_marker.find("\n")
if line_end < 0:
message = after_marker.strip()
remaining = ""
else:
message = after_marker[:line_end].strip()
remaining = after_marker[line_end:].lstrip("\n")
if message:
events.append({"type": "tool", "message": message})
return events
@router.get("/file/{file_id}", summary="下载 Web 智能助手附件")
async def download_web_agent_file(file_id: str) -> FileResponse:
"""
下载 Web 智能助手本轮生成的临时附件。
:param file_id: 附件随机标识
:return: 附件文件响应
"""
_cleanup_web_agent_file_registry()
file_info = _WEB_AGENT_FILE_REGISTRY.get(file_id)
if not file_info:
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="附件不存在或已过期")
file_path = file_info["path"]
if not file_path.exists() or not file_path.is_file():
_WEB_AGENT_FILE_REGISTRY.pop(file_id, None)
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="附件不存在或已过期")
return FileResponse(
path=file_path,
media_type=file_info.get("mime_type") or "application/octet-stream",
filename=file_info.get("name") or file_path.name,
)
@router.post("/upload", summary="上传 Web 智能助手附件", response_model=schemas.Response)
async def upload_web_agent_file(
file: UploadFile = File(...),
session_id: Optional[str] = Form(None),
current_user: User = Depends(get_current_active_user),
) -> schemas.Response:
"""
上传 Web 智能助手对话附件。
:param file: 浏览器选择的文件
:param session_id: 前端会话标识
:param current_user: 当前登录用户
:return: Agent 可消费的附件描述
"""
mime_type = file.content_type or mimetypes.guess_type(file.filename or "")[0]
safe_name = _sanitize_web_agent_upload_name(file.filename, mime_type)
upload_dir = _get_web_agent_upload_dir(current_user, session_id)
target_path = upload_dir / f"{uuid.uuid4().hex[:8]}_{safe_name}"
size = await _save_web_agent_upload(file, target_path)
attachment = _register_web_agent_file(
str(target_path),
file_name=safe_name,
kind=_guess_web_agent_attachment_kind(mime_type),
mime_type=mime_type,
)
if not attachment:
target_path.unlink(missing_ok=True)
return schemas.Response(success=False, message="附件保存失败")
attachment.update(
{
"ref": attachment["url"],
"local_path": str(target_path),
"status": "ready",
"size": size,
}
)
return schemas.Response(success=True, data=attachment)
@router.post("/callback", summary="Web 智能助手按钮回调", response_model=schemas.Response)
async def web_agent_callback(
payload: schemas.AgentWebChoiceRequest,
current_user: User = Depends(get_current_active_user),
) -> schemas.Response:
"""
接收 Web 智能助手选择卡片回调。
:param payload: 按钮选择请求
:param current_user: 当前登录用户
:return: 下一条需要发送给 Agent 的用户消息与卡片反馈
"""
result = _resolve_web_agent_choice_payload(
callback_data=payload.callback_data,
user_id=str(current_user.id),
)
if not result:
return schemas.Response(success=False, message="该选择已失效,请重新发起选择")
return schemas.Response(success=True, data=result)
@router.post("/stream", summary="Web智能助手流式对话")
async def web_agent_stream(
payload: schemas.AgentWebChatRequest,
request: Request,
current_user: User = Depends(get_current_active_user),
) -> StreamingResponse:
"""
Web 智能助手流式对话。
:param payload: 对话请求
:param request: 当前 HTTP 请求
:param current_user: 当前登录用户
:return: SSE 流式响应
"""
if not settings.AI_AGENT_ENABLE:
return StreamingResponse(
iter([
_build_web_agent_sse(
"error",
{"message": "智能助手未启用,请先在系统设置中开启。"},
)
]),
media_type="text/event-stream",
)
prompt = payload.text.strip()
if not prompt and not payload.images and not payload.files and not payload.audio_refs:
return StreamingResponse(
iter([
_build_web_agent_sse(
"error",
{"message": "请输入要发送给智能助手的内容或选择附件。"},
)
]),
media_type="text/event-stream",
)
session_id = _build_web_agent_session_id(current_user, payload.session_id)
event_queue: asyncio.Queue = asyncio.Queue()
last_output = ""
def output_callback(output: str) -> None:
"""
接收 Agent 累积输出并转成增量事件。
"""
nonlocal last_output
delta = output[len(last_output):] if output.startswith(last_output) else output
last_output = output
for item in _split_web_agent_output(delta):
event_queue.put_nowait(item)
def notification_callback(notification: schemas.Notification) -> None:
"""
接收 Agent 工具主动发送的 Web 通知。
"""
for item in _build_web_agent_notification_events(notification):
event_queue.put_nowait(item)
async def event_generator() -> AsyncIterator[str]:
"""
生成前端 Agent SSE 事件。
"""
files = [
file.model_dump(exclude_none=True)
for file in (payload.files or [])
]
for audio_ref in payload.audio_refs or []:
files.append({"ref": audio_ref, "mime_type": "audio/*"})
agent = _WebAgentMoviePilotAgent(
session_id=session_id,
user_id=str(current_user.id),
channel=MessageChannel.WebAgent.value,
source=WEB_AGENT_SOURCE,
username=current_user.name,
replay_mode=ReplyMode.CAPTURE_ONLY,
persist_output_message=False,
allow_message_tools=True,
output_callback=output_callback,
notification_callback=notification_callback,
)
async def run_agent() -> None:
"""后台执行 Agent并将结果写入事件队列。"""
try:
await agent.process(
message=prompt,
images=payload.images or [],
files=files or None,
has_audio_input=bool(payload.audio_refs),
)
except Exception as err:
logger.error(f"Web智能助手执行失败: {str(err)}")
await event_queue.put(
{"type": "error", "message": f"智能助手执行失败: {str(err)}"}
)
finally:
await event_queue.put({"type": "done"})
task = asyncio.create_task(run_agent())
try:
yield _build_web_agent_sse(
"start",
{"session_id": payload.session_id or session_id},
)
while not global_vars.is_system_stopped:
if await request.is_disconnected():
break
event = await event_queue.get()
yield _build_web_agent_sse(event.pop("type"), event)
if task.done() and event_queue.empty():
break
except asyncio.CancelledError:
return
finally:
if not task.done():
task.cancel()
try:
await task
except asyncio.CancelledError:
pass
await agent.cleanup()
return StreamingResponse(
event_generator(),
media_type="text/event-stream",
headers={
"Cache-Control": "no-cache",
"Connection": "keep-alive",
"X-Accel-Buffering": "no",
},
)