From e78efe3e34b3facb06e56dda6303ed9fa6751578 Mon Sep 17 00:00:00 2001 From: jxxghp Date: Tue, 16 Jun 2026 22:53:11 +0800 Subject: [PATCH] feat: implement file upload and callback handling for Web Agent --- app/agent/tools/impl/ask_user_choice.py | 4 +- app/api/endpoints/agent.py | 277 ++++++++++++++++++++++-- app/db/user_oper.py | 6 + app/schemas/message.py | 17 +- tests/test_agent_interaction.py | 16 +- tests/test_web_agent_stream.py | 112 +++++++++- 6 files changed, 413 insertions(+), 19 deletions(-) diff --git a/app/agent/tools/impl/ask_user_choice.py b/app/agent/tools/impl/ask_user_choice.py index 37eed9bd..a69d6944 100644 --- a/app/agent/tools/impl/ask_user_choice.py +++ b/app/agent/tools/impl/ask_user_choice.py @@ -4,7 +4,7 @@ from typing import List, Optional, Type from pydantic import BaseModel, Field, model_validator -from app.agent.tools.base import MoviePilotTool, ToolChain +from app.agent.tools.base import MoviePilotTool from app.agent.tools.tags import ToolTag from app.helper.interaction import ( AgentInteractionOption, @@ -188,7 +188,7 @@ class AskUserChoiceTool(MoviePilotTool): len(choice_options), ) - await ToolChain().async_post_message( + await self.send_notification_message( Notification( channel=channel, source=self._source, diff --git a/app/api/endpoints/agent.py b/app/api/endpoints/agent.py index b8052aa0..70c5d9ca 100644 --- a/app/api/endpoints/agent.py +++ b/app/api/endpoints/agent.py @@ -7,14 +7,15 @@ import uuid from pathlib import Path from typing import Any, AsyncIterator, Callable, Optional -from fastapi import APIRouter, Depends, HTTPException, Request, status +from fastapi import APIRouter, Depends, File, Form, HTTPException, Request, UploadFile, status from fastapi.responses import FileResponse, StreamingResponse from app import schemas from app.agent import MoviePilotAgent, ReplyMode, StreamingHandler from app.core.config import global_vars, settings from app.db.models import User -from app.db.user_oper import get_current_active_superuser +from app.db.user_oper import UserOper, get_current_active_user +from app.helper.interaction import agent_interaction_manager from app.log import logger from app.schemas.types import MessageChannel @@ -24,6 +25,8 @@ WEB_AGENT_SESSION_PREFIX = "web-agent:" WEB_AGENT_SOURCE = "web-agent" WEB_AGENT_FILE_TTL_SECONDS = 6 * 60 * 60 WEB_AGENT_FILE_MAX_ITEMS = 256 +WEB_AGENT_UPLOAD_MAX_BYTES = 32 * 1024 * 1024 +WEB_AGENT_UPLOAD_CHUNK_SIZE = 1024 * 1024 _WEB_AGENT_FILE_REGISTRY: dict[str, dict[str, Any]] = {} @@ -113,8 +116,17 @@ class _WebAgentMoviePilotAgent(MoviePilotAgent): return True async def _is_system_admin_context(self) -> bool: - """Web Agent 入口已要求超级管理员,工具上下文可直接按管理员处理。""" - return True + """Web Agent 根据当前登录用户 ID 判断工具管理员上下文。""" + if not self.user_id: + return False + try: + user = await UserOper().async_get_by_id(int(self.user_id)) + except (TypeError, ValueError): + return False + except Exception as e: + logger.error(f"检查 Web Agent 用户管理员身份失败: {e}") + return False + return bool(user and user.is_superuser) async def _build_tool_context(self, should_dispatch_reply: bool) -> dict[str, object]: """向工具上下文注入 Web SSE 通知回调。""" @@ -153,6 +165,73 @@ def _build_web_agent_sse(event_type: str, data: Optional[dict] = None) -> str: return f"data: {json.dumps(payload, ensure_ascii=False)}\n\n" +def _sanitize_web_agent_upload_name( + filename: Optional[str], mime_type: Optional[str] = None +) -> str: + """ + 规范化 Web Agent 上传文件名,避免路径穿越和空文件名。 + + :param filename: 浏览器上传的原始文件名 + :param mime_type: 浏览器上报的 MIME 类型 + :return: 可安全落盘的文件名 + """ + name = Path(filename or "attachment").name.strip() + safe_name = "".join( + char for char in name if char.isalnum() or char in (" ", ".", "_", "-") + ).strip(" .") + if not safe_name: + safe_name = "attachment" + if "." not in safe_name: + suffix = mimetypes.guess_extension(mime_type or "") or "" + safe_name = f"{safe_name}{suffix}" + return safe_name + + +def _get_web_agent_upload_dir(user: User, session_id: Optional[str]) -> Path: + """ + 计算当前 Web Agent 会话的临时附件目录。 + + :param user: 当前登录用户 + :param session_id: 前端会话标识 + :return: 已创建的临时附件目录 + """ + server_session_id = _build_web_agent_session_id(user, session_id) + safe_session_id = server_session_id.replace(":", "_") + upload_dir = settings.TEMP_PATH / "agent_uploads" / safe_session_id + upload_dir.mkdir(parents=True, exist_ok=True) + return upload_dir + + +async def _save_web_agent_upload(upload_file: UploadFile, target_path: Path) -> int: + """ + 分块保存 Web Agent 上传文件,并限制单文件体积。 + + :param upload_file: FastAPI 上传文件对象 + :param target_path: 目标落盘路径 + :return: 已写入的字节数 + """ + size = 0 + try: + with target_path.open("wb") as output: + while True: + chunk = await upload_file.read(WEB_AGENT_UPLOAD_CHUNK_SIZE) + if not chunk: + break + size += len(chunk) + if size > WEB_AGENT_UPLOAD_MAX_BYTES: + raise HTTPException( + status_code=status.HTTP_413_REQUEST_ENTITY_TOO_LARGE, + detail="附件超过 32MB,无法发送给智能助手", + ) + output.write(chunk) + except Exception: + target_path.unlink(missing_ok=True) + raise + finally: + await upload_file.close() + return size + + def _cleanup_web_agent_file_registry() -> None: """清理过期或过量的 Web Agent 临时附件引用。""" now = time.time() @@ -220,6 +299,7 @@ def _register_web_agent_file( file_path: Optional[str], file_name: Optional[str] = None, kind: Optional[str] = None, + mime_type: Optional[str] = None, ) -> Optional[dict]: """ 注册 Web Agent 本地附件并返回前端可访问的短期下载地址。 @@ -227,6 +307,7 @@ def _register_web_agent_file( :param file_path: 本地文件路径 :param file_name: 前端展示文件名 :param kind: 附件展示类型 + :param mime_type: 已知 MIME 类型 :return: 前端附件描述,文件不可访问时返回 None """ if not file_path: @@ -241,24 +322,129 @@ def _register_web_agent_file( _cleanup_web_agent_file_registry() file_id = uuid.uuid4().hex display_name = file_name or resolved_path.name - mime_type = mimetypes.guess_type(display_name or str(resolved_path))[0] + resolved_mime_type = mime_type or mimetypes.guess_type( + display_name or str(resolved_path) + )[0] file_url = f"message/agent/file/{file_id}" _WEB_AGENT_FILE_REGISTRY[file_id] = { "path": resolved_path, "name": display_name, - "mime_type": mime_type or "application/octet-stream", + "mime_type": resolved_mime_type or "application/octet-stream", "created_at": time.time(), } return { - "kind": kind or _guess_web_agent_attachment_kind(mime_type), + "kind": kind or _guess_web_agent_attachment_kind(resolved_mime_type), "url": file_url, "download_url": file_url, "name": display_name, - "mime_type": mime_type, + "mime_type": resolved_mime_type, "size": resolved_path.stat().st_size, } +def _parse_web_agent_choice_callback(callback_data: str) -> Optional[tuple[str, int]]: + """ + 解析 Web Agent 按钮选择回调数据。 + + :param callback_data: Agent 按钮携带的回调数据 + :return: 请求 ID 与选项序号,格式无效时返回 None + """ + if not callback_data.startswith("agent_interaction:choice:"): + return None + try: + _, _, request_id, option_index = callback_data.split(":", 3) + except ValueError: + return None + if not request_id or not option_index.isdigit(): + return None + return request_id, int(option_index) + + +def _flatten_web_agent_choice_buttons(buttons: Optional[list[list[dict]]]) -> list[dict]: + """ + 将消息渠道按钮二维结构转换为 Web 前端可渲染的一维选项列表。 + + :param buttons: Notification 中的按钮行 + :return: Web 选择卡片按钮列表 + """ + flattened = [] + for row in buttons or []: + for button in row or []: + text = str(button.get("text") or "").strip() + callback_data = str(button.get("callback_data") or "").strip() + if not text or not callback_data: + continue + flattened.append( + { + "label": text, + "callback_data": callback_data, + } + ) + return flattened + + +def _build_web_agent_choice_event(notification: schemas.Notification) -> Optional[dict]: + """ + 将带按钮通知转换为 Web Agent 选择卡片事件。 + + :param notification: Agent 工具发出的按钮通知 + :return: 选择卡片事件,按钮为空时返回 None + """ + buttons = _flatten_web_agent_choice_buttons(notification.buttons) + if not buttons: + return None + + choice_id = None + parsed = _parse_web_agent_choice_callback(buttons[0]["callback_data"]) + if parsed: + choice_id = parsed[0] + + return { + "type": "choice", + "choice": { + "id": choice_id or uuid.uuid4().hex, + "title": notification.title, + "prompt": notification.text or "", + "buttons": buttons, + }, + } + + +def _resolve_web_agent_choice_payload(callback_data: str, user_id: str) -> Optional[dict]: + """ + 解析并消费 Web Agent 按钮选择,生成前端反馈与下一条用户消息。 + + :param callback_data: 前端点击的按钮回调数据 + :param user_id: 当前登录用户 ID + :return: 可返回给前端的数据,选择无效时返回 None + """ + parsed = _parse_web_agent_choice_callback(callback_data) + if not parsed: + return None + + request_id, option_index = parsed + resolved = agent_interaction_manager.resolve( + request_id=request_id, + option_index=option_index, + user_id=str(user_id), + ) + if not resolved: + return None + + request, option = resolved + return { + "message": option.value, + "session_id": request.session_id, + "feedback": { + "request_id": request.request_id, + "title": request.title, + "prompt": request.prompt, + "selected_label": option.label, + "selected_value": option.value, + }, + } + + def _build_web_agent_notification_events( notification: schemas.Notification, ) -> list[dict]: @@ -269,12 +455,16 @@ def _build_web_agent_notification_events( :return: 前端可直接应用到当前助手消息的事件列表 """ events = [] + choice_event = _build_web_agent_choice_event(notification) + if choice_event: + events.append(choice_event) + text_parts = [ str(item).strip() for item in (notification.title, notification.text) if str(item or "").strip() ] - if text_parts: + if text_parts and not choice_event: events.append({"type": "delta", "content": "\n\n".join(text_parts)}) if notification.image: @@ -402,18 +592,79 @@ async def download_web_agent_file(file_id: str) -> FileResponse: ) +@router.post("/upload", summary="上传 Web 智能助手附件", response_model=schemas.Response) +async def upload_web_agent_file( + file: UploadFile = File(...), + session_id: Optional[str] = Form(None), + current_user: User = Depends(get_current_active_user), +) -> schemas.Response: + """ + 上传 Web 智能助手对话附件。 + + :param file: 浏览器选择的文件 + :param session_id: 前端会话标识 + :param current_user: 当前登录用户 + :return: Agent 可消费的附件描述 + """ + mime_type = file.content_type or mimetypes.guess_type(file.filename or "")[0] + safe_name = _sanitize_web_agent_upload_name(file.filename, mime_type) + upload_dir = _get_web_agent_upload_dir(current_user, session_id) + target_path = upload_dir / f"{uuid.uuid4().hex[:8]}_{safe_name}" + size = await _save_web_agent_upload(file, target_path) + attachment = _register_web_agent_file( + str(target_path), + file_name=safe_name, + kind=_guess_web_agent_attachment_kind(mime_type), + mime_type=mime_type, + ) + if not attachment: + target_path.unlink(missing_ok=True) + return schemas.Response(success=False, message="附件保存失败") + + attachment.update( + { + "ref": attachment["url"], + "local_path": str(target_path), + "status": "ready", + "size": size, + } + ) + return schemas.Response(success=True, data=attachment) + + +@router.post("/callback", summary="Web 智能助手按钮回调", response_model=schemas.Response) +async def web_agent_callback( + payload: schemas.AgentWebChoiceRequest, + current_user: User = Depends(get_current_active_user), +) -> schemas.Response: + """ + 接收 Web 智能助手选择卡片回调。 + + :param payload: 按钮选择请求 + :param current_user: 当前登录用户 + :return: 下一条需要发送给 Agent 的用户消息与卡片反馈 + """ + result = _resolve_web_agent_choice_payload( + callback_data=payload.callback_data, + user_id=str(current_user.id), + ) + if not result: + return schemas.Response(success=False, message="该选择已失效,请重新发起选择") + return schemas.Response(success=True, data=result) + + @router.post("/stream", summary="Web智能助手流式对话") async def web_agent_stream( payload: schemas.AgentWebChatRequest, request: Request, - current_user: User = Depends(get_current_active_superuser), + current_user: User = Depends(get_current_active_user), ) -> StreamingResponse: """ Web 智能助手流式对话。 :param payload: 对话请求 :param request: 当前 HTTP 请求 - :param current_user: 当前登录管理员 + :param current_user: 当前登录用户 :return: SSE 流式响应 """ if not settings.AI_AGENT_ENABLE: @@ -428,12 +679,12 @@ async def web_agent_stream( ) prompt = payload.text.strip() - if not prompt: + if not prompt and not payload.images and not payload.files and not payload.audio_refs: return StreamingResponse( iter([ _build_web_agent_sse( "error", - {"message": "请输入要发送给智能助手的内容。"}, + {"message": "请输入要发送给智能助手的内容或选择附件。"}, ) ]), media_type="text/event-stream", diff --git a/app/db/user_oper.py b/app/db/user_oper.py index c8d3f9ff..c413a61a 100644 --- a/app/db/user_oper.py +++ b/app/db/user_oper.py @@ -114,6 +114,12 @@ class UserOper(DbOper): """ return await User.async_get_by_name(self._db, name) + async def async_get_by_id(self, user_id: int) -> User: + """ + 异步根据用户 ID 获取用户。 + """ + return await User.async_get_by_id(self._db, user_id) + def get_permissions(self, name: str) -> dict: """ 获取用户权限 diff --git a/app/schemas/message.py b/app/schemas/message.py index cab6e728..7da7c3a2 100644 --- a/app/schemas/message.py +++ b/app/schemas/message.py @@ -264,9 +264,11 @@ class AgentWebChatRequest(BaseModel): name: Optional[str] = Field(None) mime_type: Optional[str] = Field(None) size: Optional[int] = Field(None) + local_path: Optional[str] = Field(None) + status: Optional[str] = Field(None) # 用户本轮输入 - text: str = Field(..., min_length=1) + text: str = Field(default="") # 前端会话标识,相同标识复用同一段 Agent 记忆 session_id: Optional[str] = Field(None) # 图片 URL 或 data URL 列表 @@ -277,6 +279,17 @@ class AgentWebChatRequest(BaseModel): files: Optional[List[AgentWebChatFile]] = Field(default_factory=list) +class AgentWebChoiceRequest(BaseModel): + """ + Web 智能助手按钮选择请求。 + """ + + # 前端会话标识,用于保持与原对话窗口的关联 + session_id: Optional[str] = Field(None) + # Agent 工具生成的按钮回调数据 + callback_data: str = Field(..., min_length=1) + + class ChannelCapability(Enum): """ 渠道能力枚举 @@ -474,6 +487,8 @@ class ChannelCapabilityManager: MessageChannel.WebAgent: ChannelCapabilities( channel=MessageChannel.WebAgent, capabilities={ + ChannelCapability.INLINE_BUTTONS, + ChannelCapability.CALLBACK_QUERIES, ChannelCapability.MESSAGE_EDITING, ChannelCapability.MARKDOWN, ChannelCapability.RICH_TEXT, diff --git a/tests/test_agent_interaction.py b/tests/test_agent_interaction.py index 9658618d..3e309c6f 100644 --- a/tests/test_agent_interaction.py +++ b/tests/test_agent_interaction.py @@ -25,11 +25,15 @@ class TestAgentInteraction(unittest.TestCase): telegram_prompt = prompt_manager.get_agent_prompt( channel=MessageChannel.Telegram.value ) + web_agent_prompt = prompt_manager.get_agent_prompt( + channel=MessageChannel.WebAgent.value + ) wechat_prompt = prompt_manager.get_agent_prompt( channel=MessageChannel.Wechat.value ) self.assertIn("ask_user_choice", telegram_prompt) + self.assertIn("ask_user_choice", web_agent_prompt) self.assertIn("terminal interaction tool", telegram_prompt) self.assertIn("do not write a final text reply after it", telegram_prompt) self.assertNotIn("ask_user_choice", wechat_prompt) @@ -46,6 +50,13 @@ class TestAgentInteraction(unittest.TestCase): source="telegram-test", username="tester", ) + web_agent_tools = MoviePilotToolFactory.create_tools( + session_id="session-web", + user_id="10001", + channel=MessageChannel.WebAgent.value, + source="web-agent", + username="tester", + ) wechat_tools = MoviePilotToolFactory.create_tools( session_id="session-2", user_id="10001", @@ -55,6 +66,7 @@ class TestAgentInteraction(unittest.TestCase): ) self.assertIn("ask_user_choice", [tool.name for tool in telegram_tools]) + self.assertIn("ask_user_choice", [tool.name for tool in web_agent_tools]) self.assertNotIn("ask_user_choice", [tool.name for tool in wechat_tools]) def test_choice_tool_returns_direct_after_sending_interaction(self): @@ -74,7 +86,7 @@ class TestAgentInteraction(unittest.TestCase): tool.set_agent_context(agent_context={}) with patch( - "app.agent.tools.impl.ask_user_choice.ToolChain.async_post_message", + "app.agent.tools.base.ToolChain.async_post_message", new=AsyncMock(), ) as async_post_message: result = asyncio.run( @@ -115,7 +127,7 @@ class TestAgentInteraction(unittest.TestCase): ) with patch( - "app.agent.tools.impl.ask_user_choice.ToolChain.async_post_message", + "app.agent.tools.base.ToolChain.async_post_message", new=AsyncMock(), ) as async_post_message: result = asyncio.run( diff --git a/tests/test_web_agent_stream.py b/tests/test_web_agent_stream.py index 29711c8a..44828e83 100644 --- a/tests/test_web_agent_stream.py +++ b/tests/test_web_agent_stream.py @@ -1,11 +1,17 @@ +import asyncio from types import SimpleNamespace +from unittest.mock import AsyncMock, patch from app import schemas +from app.agent import ReplyMode from app.api.endpoints.agent import ( - _build_web_agent_session_id, + _WebAgentMoviePilotAgent, _build_web_agent_notification_events, + _build_web_agent_session_id, + _resolve_web_agent_choice_payload, _split_web_agent_output, ) +from app.helper.interaction import AgentInteractionOption, agent_interaction_manager from app.schemas.message import ChannelCapability, ChannelCapabilityManager from app.schemas.types import MessageChannel, NotificationType @@ -44,8 +50,34 @@ def test_build_web_agent_session_id_is_stable_per_user_and_seed(): assert first.startswith("web-agent:") +def test_web_agent_admin_context_uses_current_user_id(): + """Web Agent 工具权限应按当前登录用户 ID 判断管理员身份。""" + agent = _WebAgentMoviePilotAgent( + session_id="web-agent:session", + user_id="7", + channel=MessageChannel.WebAgent.value, + source="web-agent", + username="normal-user", + replay_mode=ReplyMode.CAPTURE_ONLY, + ) + + with patch("app.api.endpoints.agent.UserOper") as user_oper: + user_oper.return_value.async_get_by_id = AsyncMock( + return_value=SimpleNamespace(is_superuser=True) + ) + + assert asyncio.run(agent._is_system_admin_context()) is True + user_oper.return_value.async_get_by_id.assert_awaited_once_with(7) + + def test_web_agent_channel_supports_streaming_and_attachments(): """WebAgent 渠道应声明流式、多媒体和文件发送能力。""" + assert ChannelCapabilityManager.supports_capability( + MessageChannel.WebAgent, ChannelCapability.INLINE_BUTTONS + ) + assert ChannelCapabilityManager.supports_capability( + MessageChannel.WebAgent, ChannelCapability.CALLBACK_QUERIES + ) assert ChannelCapabilityManager.supports_capability( MessageChannel.WebAgent, ChannelCapability.MESSAGE_EDITING ) @@ -109,3 +141,81 @@ def test_build_web_agent_notification_events_registers_local_file(tmp_path): assert attachment["mime_type"] == "text/plain" assert attachment["size"] == 5 assert attachment["url"].startswith("message/agent/file/") + + +def test_build_web_agent_notification_events_extracts_choice_card(): + """Agent 按钮通知应转换为 Web 选择卡片事件而非普通文本。""" + events = _build_web_agent_notification_events( + schemas.Notification( + channel=MessageChannel.WebAgent, + mtype=NotificationType.Agent, + title="需要你的选择", + text="请选择要执行的操作", + buttons=[ + [ + { + "text": "继续下载", + "callback_data": "agent_interaction:choice:req-1:1", + } + ], + [ + { + "text": "查看详情", + "callback_data": "agent_interaction:choice:req-1:2", + } + ], + ], + ) + ) + + assert events == [ + { + "type": "choice", + "choice": { + "id": "req-1", + "title": "需要你的选择", + "prompt": "请选择要执行的操作", + "buttons": [ + { + "label": "继续下载", + "callback_data": "agent_interaction:choice:req-1:1", + }, + { + "label": "查看详情", + "callback_data": "agent_interaction:choice:req-1:2", + }, + ], + }, + } + ] + + +def test_resolve_web_agent_choice_payload_returns_next_message(): + """Web 按钮回调应解析为下一条用户消息并返回卡片反馈。""" + agent_interaction_manager.clear() + request = agent_interaction_manager.create_request( + session_id="web-agent:session", + user_id="1", + channel=MessageChannel.WebAgent.value, + source="web-agent", + username="admin", + title="需要你的选择", + prompt="请选择", + options=[ + AgentInteractionOption(label="电影", value="我选择电影"), + AgentInteractionOption(label="电视剧", value="我选择电视剧"), + ], + ) + + try: + result = _resolve_web_agent_choice_payload( + callback_data=f"agent_interaction:choice:{request.request_id}:2", + user_id="1", + ) + finally: + agent_interaction_manager.clear() + + assert result["message"] == "我选择电视剧" + assert result["session_id"] == "web-agent:session" + assert result["feedback"]["prompt"] == "请选择" + assert result["feedback"]["selected_label"] == "电视剧"