feat: implement file upload and callback handling for Web Agent

This commit is contained in:
jxxghp
2026-06-16 22:53:11 +08:00
parent e8ae686d4f
commit e78efe3e34
6 changed files with 413 additions and 19 deletions

View File

@@ -4,7 +4,7 @@ from typing import List, Optional, Type
from pydantic import BaseModel, Field, model_validator
from app.agent.tools.base import MoviePilotTool, ToolChain
from app.agent.tools.base import MoviePilotTool
from app.agent.tools.tags import ToolTag
from app.helper.interaction import (
AgentInteractionOption,
@@ -188,7 +188,7 @@ class AskUserChoiceTool(MoviePilotTool):
len(choice_options),
)
await ToolChain().async_post_message(
await self.send_notification_message(
Notification(
channel=channel,
source=self._source,

View File

@@ -7,14 +7,15 @@ import uuid
from pathlib import Path
from typing import Any, AsyncIterator, Callable, Optional
from fastapi import APIRouter, Depends, HTTPException, Request, status
from fastapi import APIRouter, Depends, File, Form, HTTPException, Request, UploadFile, status
from fastapi.responses import FileResponse, StreamingResponse
from app import schemas
from app.agent import MoviePilotAgent, ReplyMode, StreamingHandler
from app.core.config import global_vars, settings
from app.db.models import User
from app.db.user_oper import get_current_active_superuser
from app.db.user_oper import UserOper, get_current_active_user
from app.helper.interaction import agent_interaction_manager
from app.log import logger
from app.schemas.types import MessageChannel
@@ -24,6 +25,8 @@ WEB_AGENT_SESSION_PREFIX = "web-agent:"
WEB_AGENT_SOURCE = "web-agent"
WEB_AGENT_FILE_TTL_SECONDS = 6 * 60 * 60
WEB_AGENT_FILE_MAX_ITEMS = 256
WEB_AGENT_UPLOAD_MAX_BYTES = 32 * 1024 * 1024
WEB_AGENT_UPLOAD_CHUNK_SIZE = 1024 * 1024
_WEB_AGENT_FILE_REGISTRY: dict[str, dict[str, Any]] = {}
@@ -113,8 +116,17 @@ class _WebAgentMoviePilotAgent(MoviePilotAgent):
return True
async def _is_system_admin_context(self) -> bool:
"""Web Agent 入口已要求超级管理员,工具上下文可直接按管理员处理"""
return True
"""Web Agent 根据当前登录用户 ID 判断工具管理员上下文"""
if not self.user_id:
return False
try:
user = await UserOper().async_get_by_id(int(self.user_id))
except (TypeError, ValueError):
return False
except Exception as e:
logger.error(f"检查 Web Agent 用户管理员身份失败: {e}")
return False
return bool(user and user.is_superuser)
async def _build_tool_context(self, should_dispatch_reply: bool) -> dict[str, object]:
"""向工具上下文注入 Web SSE 通知回调。"""
@@ -153,6 +165,73 @@ def _build_web_agent_sse(event_type: str, data: Optional[dict] = None) -> str:
return f"data: {json.dumps(payload, ensure_ascii=False)}\n\n"
def _sanitize_web_agent_upload_name(
filename: Optional[str], mime_type: Optional[str] = None
) -> str:
"""
规范化 Web Agent 上传文件名,避免路径穿越和空文件名。
:param filename: 浏览器上传的原始文件名
:param mime_type: 浏览器上报的 MIME 类型
:return: 可安全落盘的文件名
"""
name = Path(filename or "attachment").name.strip()
safe_name = "".join(
char for char in name if char.isalnum() or char in (" ", ".", "_", "-")
).strip(" .")
if not safe_name:
safe_name = "attachment"
if "." not in safe_name:
suffix = mimetypes.guess_extension(mime_type or "") or ""
safe_name = f"{safe_name}{suffix}"
return safe_name
def _get_web_agent_upload_dir(user: User, session_id: Optional[str]) -> Path:
"""
计算当前 Web Agent 会话的临时附件目录。
:param user: 当前登录用户
:param session_id: 前端会话标识
:return: 已创建的临时附件目录
"""
server_session_id = _build_web_agent_session_id(user, session_id)
safe_session_id = server_session_id.replace(":", "_")
upload_dir = settings.TEMP_PATH / "agent_uploads" / safe_session_id
upload_dir.mkdir(parents=True, exist_ok=True)
return upload_dir
async def _save_web_agent_upload(upload_file: UploadFile, target_path: Path) -> int:
"""
分块保存 Web Agent 上传文件,并限制单文件体积。
:param upload_file: FastAPI 上传文件对象
:param target_path: 目标落盘路径
:return: 已写入的字节数
"""
size = 0
try:
with target_path.open("wb") as output:
while True:
chunk = await upload_file.read(WEB_AGENT_UPLOAD_CHUNK_SIZE)
if not chunk:
break
size += len(chunk)
if size > WEB_AGENT_UPLOAD_MAX_BYTES:
raise HTTPException(
status_code=status.HTTP_413_REQUEST_ENTITY_TOO_LARGE,
detail="附件超过 32MB无法发送给智能助手",
)
output.write(chunk)
except Exception:
target_path.unlink(missing_ok=True)
raise
finally:
await upload_file.close()
return size
def _cleanup_web_agent_file_registry() -> None:
"""清理过期或过量的 Web Agent 临时附件引用。"""
now = time.time()
@@ -220,6 +299,7 @@ def _register_web_agent_file(
file_path: Optional[str],
file_name: Optional[str] = None,
kind: Optional[str] = None,
mime_type: Optional[str] = None,
) -> Optional[dict]:
"""
注册 Web Agent 本地附件并返回前端可访问的短期下载地址。
@@ -227,6 +307,7 @@ def _register_web_agent_file(
:param file_path: 本地文件路径
:param file_name: 前端展示文件名
:param kind: 附件展示类型
:param mime_type: 已知 MIME 类型
:return: 前端附件描述,文件不可访问时返回 None
"""
if not file_path:
@@ -241,24 +322,129 @@ def _register_web_agent_file(
_cleanup_web_agent_file_registry()
file_id = uuid.uuid4().hex
display_name = file_name or resolved_path.name
mime_type = mimetypes.guess_type(display_name or str(resolved_path))[0]
resolved_mime_type = mime_type or mimetypes.guess_type(
display_name or str(resolved_path)
)[0]
file_url = f"message/agent/file/{file_id}"
_WEB_AGENT_FILE_REGISTRY[file_id] = {
"path": resolved_path,
"name": display_name,
"mime_type": mime_type or "application/octet-stream",
"mime_type": resolved_mime_type or "application/octet-stream",
"created_at": time.time(),
}
return {
"kind": kind or _guess_web_agent_attachment_kind(mime_type),
"kind": kind or _guess_web_agent_attachment_kind(resolved_mime_type),
"url": file_url,
"download_url": file_url,
"name": display_name,
"mime_type": mime_type,
"mime_type": resolved_mime_type,
"size": resolved_path.stat().st_size,
}
def _parse_web_agent_choice_callback(callback_data: str) -> Optional[tuple[str, int]]:
"""
解析 Web Agent 按钮选择回调数据。
:param callback_data: Agent 按钮携带的回调数据
:return: 请求 ID 与选项序号,格式无效时返回 None
"""
if not callback_data.startswith("agent_interaction:choice:"):
return None
try:
_, _, request_id, option_index = callback_data.split(":", 3)
except ValueError:
return None
if not request_id or not option_index.isdigit():
return None
return request_id, int(option_index)
def _flatten_web_agent_choice_buttons(buttons: Optional[list[list[dict]]]) -> list[dict]:
"""
将消息渠道按钮二维结构转换为 Web 前端可渲染的一维选项列表。
:param buttons: Notification 中的按钮行
:return: Web 选择卡片按钮列表
"""
flattened = []
for row in buttons or []:
for button in row or []:
text = str(button.get("text") or "").strip()
callback_data = str(button.get("callback_data") or "").strip()
if not text or not callback_data:
continue
flattened.append(
{
"label": text,
"callback_data": callback_data,
}
)
return flattened
def _build_web_agent_choice_event(notification: schemas.Notification) -> Optional[dict]:
"""
将带按钮通知转换为 Web Agent 选择卡片事件。
:param notification: Agent 工具发出的按钮通知
:return: 选择卡片事件,按钮为空时返回 None
"""
buttons = _flatten_web_agent_choice_buttons(notification.buttons)
if not buttons:
return None
choice_id = None
parsed = _parse_web_agent_choice_callback(buttons[0]["callback_data"])
if parsed:
choice_id = parsed[0]
return {
"type": "choice",
"choice": {
"id": choice_id or uuid.uuid4().hex,
"title": notification.title,
"prompt": notification.text or "",
"buttons": buttons,
},
}
def _resolve_web_agent_choice_payload(callback_data: str, user_id: str) -> Optional[dict]:
"""
解析并消费 Web Agent 按钮选择,生成前端反馈与下一条用户消息。
:param callback_data: 前端点击的按钮回调数据
:param user_id: 当前登录用户 ID
:return: 可返回给前端的数据,选择无效时返回 None
"""
parsed = _parse_web_agent_choice_callback(callback_data)
if not parsed:
return None
request_id, option_index = parsed
resolved = agent_interaction_manager.resolve(
request_id=request_id,
option_index=option_index,
user_id=str(user_id),
)
if not resolved:
return None
request, option = resolved
return {
"message": option.value,
"session_id": request.session_id,
"feedback": {
"request_id": request.request_id,
"title": request.title,
"prompt": request.prompt,
"selected_label": option.label,
"selected_value": option.value,
},
}
def _build_web_agent_notification_events(
notification: schemas.Notification,
) -> list[dict]:
@@ -269,12 +455,16 @@ def _build_web_agent_notification_events(
:return: 前端可直接应用到当前助手消息的事件列表
"""
events = []
choice_event = _build_web_agent_choice_event(notification)
if choice_event:
events.append(choice_event)
text_parts = [
str(item).strip()
for item in (notification.title, notification.text)
if str(item or "").strip()
]
if text_parts:
if text_parts and not choice_event:
events.append({"type": "delta", "content": "\n\n".join(text_parts)})
if notification.image:
@@ -402,18 +592,79 @@ async def download_web_agent_file(file_id: str) -> FileResponse:
)
@router.post("/upload", summary="上传 Web 智能助手附件", response_model=schemas.Response)
async def upload_web_agent_file(
file: UploadFile = File(...),
session_id: Optional[str] = Form(None),
current_user: User = Depends(get_current_active_user),
) -> schemas.Response:
"""
上传 Web 智能助手对话附件。
:param file: 浏览器选择的文件
:param session_id: 前端会话标识
:param current_user: 当前登录用户
:return: Agent 可消费的附件描述
"""
mime_type = file.content_type or mimetypes.guess_type(file.filename or "")[0]
safe_name = _sanitize_web_agent_upload_name(file.filename, mime_type)
upload_dir = _get_web_agent_upload_dir(current_user, session_id)
target_path = upload_dir / f"{uuid.uuid4().hex[:8]}_{safe_name}"
size = await _save_web_agent_upload(file, target_path)
attachment = _register_web_agent_file(
str(target_path),
file_name=safe_name,
kind=_guess_web_agent_attachment_kind(mime_type),
mime_type=mime_type,
)
if not attachment:
target_path.unlink(missing_ok=True)
return schemas.Response(success=False, message="附件保存失败")
attachment.update(
{
"ref": attachment["url"],
"local_path": str(target_path),
"status": "ready",
"size": size,
}
)
return schemas.Response(success=True, data=attachment)
@router.post("/callback", summary="Web 智能助手按钮回调", response_model=schemas.Response)
async def web_agent_callback(
payload: schemas.AgentWebChoiceRequest,
current_user: User = Depends(get_current_active_user),
) -> schemas.Response:
"""
接收 Web 智能助手选择卡片回调。
:param payload: 按钮选择请求
:param current_user: 当前登录用户
:return: 下一条需要发送给 Agent 的用户消息与卡片反馈
"""
result = _resolve_web_agent_choice_payload(
callback_data=payload.callback_data,
user_id=str(current_user.id),
)
if not result:
return schemas.Response(success=False, message="该选择已失效,请重新发起选择")
return schemas.Response(success=True, data=result)
@router.post("/stream", summary="Web智能助手流式对话")
async def web_agent_stream(
payload: schemas.AgentWebChatRequest,
request: Request,
current_user: User = Depends(get_current_active_superuser),
current_user: User = Depends(get_current_active_user),
) -> StreamingResponse:
"""
Web 智能助手流式对话。
:param payload: 对话请求
:param request: 当前 HTTP 请求
:param current_user: 当前登录管理员
:param current_user: 当前登录用户
:return: SSE 流式响应
"""
if not settings.AI_AGENT_ENABLE:
@@ -428,12 +679,12 @@ async def web_agent_stream(
)
prompt = payload.text.strip()
if not prompt:
if not prompt and not payload.images and not payload.files and not payload.audio_refs:
return StreamingResponse(
iter([
_build_web_agent_sse(
"error",
{"message": "请输入要发送给智能助手的内容。"},
{"message": "请输入要发送给智能助手的内容或选择附件"},
)
]),
media_type="text/event-stream",

View File

@@ -114,6 +114,12 @@ class UserOper(DbOper):
"""
return await User.async_get_by_name(self._db, name)
async def async_get_by_id(self, user_id: int) -> User:
"""
异步根据用户 ID 获取用户。
"""
return await User.async_get_by_id(self._db, user_id)
def get_permissions(self, name: str) -> dict:
"""
获取用户权限

View File

@@ -264,9 +264,11 @@ class AgentWebChatRequest(BaseModel):
name: Optional[str] = Field(None)
mime_type: Optional[str] = Field(None)
size: Optional[int] = Field(None)
local_path: Optional[str] = Field(None)
status: Optional[str] = Field(None)
# 用户本轮输入
text: str = Field(..., min_length=1)
text: str = Field(default="")
# 前端会话标识,相同标识复用同一段 Agent 记忆
session_id: Optional[str] = Field(None)
# 图片 URL 或 data URL 列表
@@ -277,6 +279,17 @@ class AgentWebChatRequest(BaseModel):
files: Optional[List[AgentWebChatFile]] = Field(default_factory=list)
class AgentWebChoiceRequest(BaseModel):
"""
Web 智能助手按钮选择请求。
"""
# 前端会话标识,用于保持与原对话窗口的关联
session_id: Optional[str] = Field(None)
# Agent 工具生成的按钮回调数据
callback_data: str = Field(..., min_length=1)
class ChannelCapability(Enum):
"""
渠道能力枚举
@@ -474,6 +487,8 @@ class ChannelCapabilityManager:
MessageChannel.WebAgent: ChannelCapabilities(
channel=MessageChannel.WebAgent,
capabilities={
ChannelCapability.INLINE_BUTTONS,
ChannelCapability.CALLBACK_QUERIES,
ChannelCapability.MESSAGE_EDITING,
ChannelCapability.MARKDOWN,
ChannelCapability.RICH_TEXT,

View File

@@ -25,11 +25,15 @@ class TestAgentInteraction(unittest.TestCase):
telegram_prompt = prompt_manager.get_agent_prompt(
channel=MessageChannel.Telegram.value
)
web_agent_prompt = prompt_manager.get_agent_prompt(
channel=MessageChannel.WebAgent.value
)
wechat_prompt = prompt_manager.get_agent_prompt(
channel=MessageChannel.Wechat.value
)
self.assertIn("ask_user_choice", telegram_prompt)
self.assertIn("ask_user_choice", web_agent_prompt)
self.assertIn("terminal interaction tool", telegram_prompt)
self.assertIn("do not write a final text reply after it", telegram_prompt)
self.assertNotIn("ask_user_choice", wechat_prompt)
@@ -46,6 +50,13 @@ class TestAgentInteraction(unittest.TestCase):
source="telegram-test",
username="tester",
)
web_agent_tools = MoviePilotToolFactory.create_tools(
session_id="session-web",
user_id="10001",
channel=MessageChannel.WebAgent.value,
source="web-agent",
username="tester",
)
wechat_tools = MoviePilotToolFactory.create_tools(
session_id="session-2",
user_id="10001",
@@ -55,6 +66,7 @@ class TestAgentInteraction(unittest.TestCase):
)
self.assertIn("ask_user_choice", [tool.name for tool in telegram_tools])
self.assertIn("ask_user_choice", [tool.name for tool in web_agent_tools])
self.assertNotIn("ask_user_choice", [tool.name for tool in wechat_tools])
def test_choice_tool_returns_direct_after_sending_interaction(self):
@@ -74,7 +86,7 @@ class TestAgentInteraction(unittest.TestCase):
tool.set_agent_context(agent_context={})
with patch(
"app.agent.tools.impl.ask_user_choice.ToolChain.async_post_message",
"app.agent.tools.base.ToolChain.async_post_message",
new=AsyncMock(),
) as async_post_message:
result = asyncio.run(
@@ -115,7 +127,7 @@ class TestAgentInteraction(unittest.TestCase):
)
with patch(
"app.agent.tools.impl.ask_user_choice.ToolChain.async_post_message",
"app.agent.tools.base.ToolChain.async_post_message",
new=AsyncMock(),
) as async_post_message:
result = asyncio.run(

View File

@@ -1,11 +1,17 @@
import asyncio
from types import SimpleNamespace
from unittest.mock import AsyncMock, patch
from app import schemas
from app.agent import ReplyMode
from app.api.endpoints.agent import (
_build_web_agent_session_id,
_WebAgentMoviePilotAgent,
_build_web_agent_notification_events,
_build_web_agent_session_id,
_resolve_web_agent_choice_payload,
_split_web_agent_output,
)
from app.helper.interaction import AgentInteractionOption, agent_interaction_manager
from app.schemas.message import ChannelCapability, ChannelCapabilityManager
from app.schemas.types import MessageChannel, NotificationType
@@ -44,8 +50,34 @@ def test_build_web_agent_session_id_is_stable_per_user_and_seed():
assert first.startswith("web-agent:")
def test_web_agent_admin_context_uses_current_user_id():
"""Web Agent 工具权限应按当前登录用户 ID 判断管理员身份。"""
agent = _WebAgentMoviePilotAgent(
session_id="web-agent:session",
user_id="7",
channel=MessageChannel.WebAgent.value,
source="web-agent",
username="normal-user",
replay_mode=ReplyMode.CAPTURE_ONLY,
)
with patch("app.api.endpoints.agent.UserOper") as user_oper:
user_oper.return_value.async_get_by_id = AsyncMock(
return_value=SimpleNamespace(is_superuser=True)
)
assert asyncio.run(agent._is_system_admin_context()) is True
user_oper.return_value.async_get_by_id.assert_awaited_once_with(7)
def test_web_agent_channel_supports_streaming_and_attachments():
"""WebAgent 渠道应声明流式、多媒体和文件发送能力。"""
assert ChannelCapabilityManager.supports_capability(
MessageChannel.WebAgent, ChannelCapability.INLINE_BUTTONS
)
assert ChannelCapabilityManager.supports_capability(
MessageChannel.WebAgent, ChannelCapability.CALLBACK_QUERIES
)
assert ChannelCapabilityManager.supports_capability(
MessageChannel.WebAgent, ChannelCapability.MESSAGE_EDITING
)
@@ -109,3 +141,81 @@ def test_build_web_agent_notification_events_registers_local_file(tmp_path):
assert attachment["mime_type"] == "text/plain"
assert attachment["size"] == 5
assert attachment["url"].startswith("message/agent/file/")
def test_build_web_agent_notification_events_extracts_choice_card():
"""Agent 按钮通知应转换为 Web 选择卡片事件而非普通文本。"""
events = _build_web_agent_notification_events(
schemas.Notification(
channel=MessageChannel.WebAgent,
mtype=NotificationType.Agent,
title="需要你的选择",
text="请选择要执行的操作",
buttons=[
[
{
"text": "继续下载",
"callback_data": "agent_interaction:choice:req-1:1",
}
],
[
{
"text": "查看详情",
"callback_data": "agent_interaction:choice:req-1:2",
}
],
],
)
)
assert events == [
{
"type": "choice",
"choice": {
"id": "req-1",
"title": "需要你的选择",
"prompt": "请选择要执行的操作",
"buttons": [
{
"label": "继续下载",
"callback_data": "agent_interaction:choice:req-1:1",
},
{
"label": "查看详情",
"callback_data": "agent_interaction:choice:req-1:2",
},
],
},
}
]
def test_resolve_web_agent_choice_payload_returns_next_message():
"""Web 按钮回调应解析为下一条用户消息并返回卡片反馈。"""
agent_interaction_manager.clear()
request = agent_interaction_manager.create_request(
session_id="web-agent:session",
user_id="1",
channel=MessageChannel.WebAgent.value,
source="web-agent",
username="admin",
title="需要你的选择",
prompt="请选择",
options=[
AgentInteractionOption(label="电影", value="我选择电影"),
AgentInteractionOption(label="电视剧", value="我选择电视剧"),
],
)
try:
result = _resolve_web_agent_choice_payload(
callback_data=f"agent_interaction:choice:{request.request_id}:2",
user_id="1",
)
finally:
agent_interaction_manager.clear()
assert result["message"] == "我选择电视剧"
assert result["session_id"] == "web-agent:session"
assert result["feedback"]["prompt"] == "请选择"
assert result["feedback"]["selected_label"] == "电视剧"