mirror of
https://github.com/jxxghp/MoviePilot.git
synced 2026-06-21 23:44:31 +08:00
feat: add WebAgent support for streaming and notifications
This commit is contained in:
@@ -612,13 +612,30 @@ class MoviePilotTool(BaseTool, metaclass=ABCMeta):
|
||||
|
||||
return False
|
||||
|
||||
async def send_notification_message(self, notification: Notification) -> None:
|
||||
"""
|
||||
发送工具通知消息。
|
||||
|
||||
WebAgent 渠道没有后端模块实例,前端流式面板通过 Agent 上下文中的
|
||||
回调直接接收通知;其它渠道继续走统一消息链。
|
||||
"""
|
||||
callback = self._agent_context.get("notification_callback")
|
||||
if (
|
||||
self._channel == MessageChannel.WebAgent.value
|
||||
and callable(callback)
|
||||
):
|
||||
callback(notification)
|
||||
return
|
||||
|
||||
await ToolChain().async_post_message(notification)
|
||||
|
||||
async def send_tool_message(
|
||||
self, message: str, title: str = "", image: Optional[str] = None
|
||||
) -> None:
|
||||
"""
|
||||
发送工具消息
|
||||
"""
|
||||
await ToolChain().async_post_message(
|
||||
await self.send_notification_message(
|
||||
Notification(
|
||||
channel=self._channel,
|
||||
source=self._source,
|
||||
|
||||
@@ -5,7 +5,7 @@ from typing import Optional, Type
|
||||
|
||||
from pydantic import BaseModel, Field, model_validator
|
||||
|
||||
from app.agent.tools.base import MoviePilotTool, ToolChain
|
||||
from app.agent.tools.base import MoviePilotTool
|
||||
from app.agent.tools.tags import ToolTag
|
||||
from app.log import logger
|
||||
from app.schemas import Notification, NotificationType
|
||||
@@ -96,7 +96,7 @@ class SendLocalFileTool(MoviePilotTool):
|
||||
resolved_path,
|
||||
)
|
||||
|
||||
await ToolChain().async_post_message(
|
||||
await self.send_notification_message(
|
||||
Notification(
|
||||
channel=channel,
|
||||
source=self._source,
|
||||
|
||||
@@ -4,7 +4,7 @@ from typing import Optional, Type
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from app.agent.llm.capability import AgentCapabilityManager
|
||||
from app.agent.tools.base import MoviePilotTool, ToolChain
|
||||
from app.agent.tools.base import MoviePilotTool
|
||||
from app.agent.tools.tags import ToolTag
|
||||
from app.core.config import settings
|
||||
from app.log import logger
|
||||
@@ -86,7 +86,7 @@ class SendVoiceMessageTool(MoviePilotTool):
|
||||
f"use_voice={used_voice}, text_len={len(message)}"
|
||||
)
|
||||
|
||||
await ToolChain().async_post_message(
|
||||
await self.send_notification_message(
|
||||
Notification(
|
||||
channel=self._channel,
|
||||
source=self._source,
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
from fastapi import APIRouter
|
||||
|
||||
from app.api.endpoints import auth, login, user, webhook, message, site, subscribe, \
|
||||
from app.api.endpoints import auth, login, user, webhook, message, agent, site, subscribe, \
|
||||
media, douban, search, plugin, tmdb, history, system, download, dashboard, \
|
||||
transfer, mediaserver, bangumi, storage, discover, recommend, workflow, torrent, mcp, mfa, openai, anthropic, llm, notification
|
||||
|
||||
@@ -11,6 +11,7 @@ api_router.include_router(user.router, prefix="/user", tags=["user"])
|
||||
api_router.include_router(mfa.router, prefix="/mfa", tags=["mfa"])
|
||||
api_router.include_router(site.router, prefix="/site", tags=["site"])
|
||||
api_router.include_router(message.router, prefix="/message", tags=["message"])
|
||||
api_router.include_router(agent.router, prefix="/message/agent", tags=["agent"])
|
||||
api_router.include_router(webhook.router, prefix="/webhook", tags=["webhook"])
|
||||
api_router.include_router(subscribe.router, prefix="/subscribe", tags=["subscribe"])
|
||||
api_router.include_router(media.router, prefix="/media", tags=["media"])
|
||||
|
||||
536
app/api/endpoints/agent.py
Normal file
536
app/api/endpoints/agent.py
Normal file
@@ -0,0 +1,536 @@
|
||||
import asyncio
|
||||
import hashlib
|
||||
import json
|
||||
import mimetypes
|
||||
import time
|
||||
import uuid
|
||||
from pathlib import Path
|
||||
from typing import Any, AsyncIterator, Callable, Optional
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException, Request, status
|
||||
from fastapi.responses import FileResponse, StreamingResponse
|
||||
|
||||
from app import schemas
|
||||
from app.agent import MoviePilotAgent, ReplyMode, StreamingHandler
|
||||
from app.core.config import global_vars, settings
|
||||
from app.db.models import User
|
||||
from app.db.user_oper import get_current_active_superuser
|
||||
from app.log import logger
|
||||
from app.schemas.types import MessageChannel
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
WEB_AGENT_SESSION_PREFIX = "web-agent:"
|
||||
WEB_AGENT_SOURCE = "web-agent"
|
||||
WEB_AGENT_FILE_TTL_SECONDS = 6 * 60 * 60
|
||||
WEB_AGENT_FILE_MAX_ITEMS = 256
|
||||
_WEB_AGENT_FILE_REGISTRY: dict[str, dict[str, Any]] = {}
|
||||
|
||||
|
||||
class _WebAgentStreamingHandler(StreamingHandler):
|
||||
"""
|
||||
Web 前端专用流式处理器,将工具提示和文本统一回调给 SSE。
|
||||
"""
|
||||
|
||||
def __init__(self, on_emit: Callable[[str], None]) -> None:
|
||||
super().__init__()
|
||||
self._on_emit = on_emit
|
||||
|
||||
def emit(self, token: str) -> str:
|
||||
"""追加 token 并同步通知 SSE 生产者。"""
|
||||
emitted = super().emit(token)
|
||||
if emitted:
|
||||
self._on_emit(emitted)
|
||||
return emitted
|
||||
|
||||
def flush_pending_tool_summary(self) -> str:
|
||||
"""输出延迟聚合的工具摘要。"""
|
||||
emitted = super().flush_pending_tool_summary()
|
||||
if emitted:
|
||||
self._on_emit(emitted)
|
||||
return ""
|
||||
|
||||
async def start_streaming(
|
||||
self,
|
||||
channel: Optional[str] = None,
|
||||
source: Optional[str] = None,
|
||||
user_id: Optional[str] = None,
|
||||
username: Optional[str] = None,
|
||||
original_message_id: Optional[str] = None,
|
||||
original_chat_id: Optional[str] = None,
|
||||
title: str = "",
|
||||
) -> None:
|
||||
"""Web SSE 自身负责外发,不启动消息模块编辑循环。"""
|
||||
self._channel = channel
|
||||
self._source = source
|
||||
self._user_id = user_id
|
||||
self._username = username
|
||||
self._original_message_id = original_message_id
|
||||
self._original_chat_id = original_chat_id
|
||||
self._title = title
|
||||
self._streaming_enabled = True
|
||||
self._sent_text = ""
|
||||
self._message_response = None
|
||||
self._msg_start_offset = 0
|
||||
self._pending_tool_stats = {}
|
||||
|
||||
async def stop_streaming(self) -> tuple[bool, str]:
|
||||
"""停止 Web SSE 流式状态,保留缓冲区给 Agent 收口逻辑去重。"""
|
||||
if not self._streaming_enabled:
|
||||
return False, ""
|
||||
self._streaming_enabled = False
|
||||
self.flush_pending_tool_summary()
|
||||
with self._lock:
|
||||
self._sent_text = ""
|
||||
self._message_response = None
|
||||
self._msg_start_offset = 0
|
||||
self._pending_tool_stats = {}
|
||||
return False, ""
|
||||
|
||||
@property
|
||||
def is_auto_flushing(self) -> bool:
|
||||
"""让工具执行提示进入缓冲区,由 SSE 回调负责外发。"""
|
||||
return True
|
||||
|
||||
|
||||
class _WebAgentMoviePilotAgent(MoviePilotAgent):
|
||||
"""
|
||||
Web 前端专用 Agent,强制使用流式推理。
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
*args: Any,
|
||||
notification_callback: Optional[Callable[[schemas.Notification], None]] = None,
|
||||
**kwargs: Any,
|
||||
) -> None:
|
||||
super().__init__(*args, **kwargs)
|
||||
self._notification_callback = notification_callback
|
||||
self.stream_handler = _WebAgentStreamingHandler(self._emit_output)
|
||||
|
||||
def _should_stream(self) -> bool:
|
||||
"""Web 面板需要实时输出,即使 Web 渠道本身不支持消息编辑。"""
|
||||
return True
|
||||
|
||||
async def _is_system_admin_context(self) -> bool:
|
||||
"""Web Agent 入口已要求超级管理员,工具上下文可直接按管理员处理。"""
|
||||
return True
|
||||
|
||||
async def _build_tool_context(self, should_dispatch_reply: bool) -> dict[str, object]:
|
||||
"""向工具上下文注入 Web SSE 通知回调。"""
|
||||
context = await super()._build_tool_context(should_dispatch_reply)
|
||||
context["notification_callback"] = self._notification_callback
|
||||
return context
|
||||
|
||||
def _handle_stream_text(self, text: str) -> None:
|
||||
"""文本输出交由 Web 流式处理器统一回调,避免重复增量。"""
|
||||
self.stream_handler.emit(text)
|
||||
|
||||
|
||||
def _build_web_agent_session_id(user: User, session_id: Optional[str]) -> str:
|
||||
"""
|
||||
构建前端 Agent 会话 ID。
|
||||
|
||||
:param user: 当前登录用户
|
||||
:param session_id: 前端传入的会话标识
|
||||
:return: 可用于 Agent 记忆隔离的服务端会话 ID
|
||||
"""
|
||||
seed = str(session_id or "").strip() or uuid.uuid4().hex
|
||||
user_part = user.name or str(user.id)
|
||||
digest = hashlib.sha256(f"{user_part}:{seed}".encode("utf-8")).hexdigest()
|
||||
return f"{WEB_AGENT_SESSION_PREFIX}{digest[:32]}"
|
||||
|
||||
|
||||
def _build_web_agent_sse(event_type: str, data: Optional[dict] = None) -> str:
|
||||
"""
|
||||
构建 Web Agent SSE 消息。
|
||||
|
||||
:param event_type: 前端事件类型
|
||||
:param data: 事件数据
|
||||
:return: 符合 SSE 格式的字符串
|
||||
"""
|
||||
payload = {"type": event_type, **(data or {})}
|
||||
return f"data: {json.dumps(payload, ensure_ascii=False)}\n\n"
|
||||
|
||||
|
||||
def _cleanup_web_agent_file_registry() -> None:
|
||||
"""清理过期或过量的 Web Agent 临时附件引用。"""
|
||||
now = time.time()
|
||||
expired_ids = [
|
||||
file_id
|
||||
for file_id, info in _WEB_AGENT_FILE_REGISTRY.items()
|
||||
if now - info.get("created_at", now) > WEB_AGENT_FILE_TTL_SECONDS
|
||||
]
|
||||
for file_id in expired_ids:
|
||||
_WEB_AGENT_FILE_REGISTRY.pop(file_id, None)
|
||||
|
||||
overflow = len(_WEB_AGENT_FILE_REGISTRY) - WEB_AGENT_FILE_MAX_ITEMS
|
||||
if overflow <= 0:
|
||||
return
|
||||
sorted_items = sorted(
|
||||
_WEB_AGENT_FILE_REGISTRY.items(),
|
||||
key=lambda item: item[1].get("created_at", 0),
|
||||
)
|
||||
for file_id, _ in sorted_items[:overflow]:
|
||||
_WEB_AGENT_FILE_REGISTRY.pop(file_id, None)
|
||||
|
||||
|
||||
def _guess_web_agent_attachment_kind(
|
||||
mime_type: Optional[str], fallback: str = "file"
|
||||
) -> str:
|
||||
"""
|
||||
根据 MIME 类型推断前端附件展示方式。
|
||||
|
||||
:param mime_type: 文件 MIME 类型
|
||||
:param fallback: 无法推断时使用的类型
|
||||
:return: image、audio 或 file
|
||||
"""
|
||||
if mime_type and mime_type.startswith("image/"):
|
||||
return "image"
|
||||
if mime_type and mime_type.startswith("audio/"):
|
||||
return "audio"
|
||||
return fallback
|
||||
|
||||
|
||||
def _build_web_agent_url_attachment(
|
||||
url: str,
|
||||
kind: str,
|
||||
name: Optional[str] = None,
|
||||
mime_type: Optional[str] = None,
|
||||
) -> dict:
|
||||
"""
|
||||
构建远程或 data URL 附件事件。
|
||||
|
||||
:param url: 前端可访问的附件地址
|
||||
:param kind: 附件展示类型
|
||||
:param name: 展示名称
|
||||
:param mime_type: MIME 类型
|
||||
:return: 前端附件描述
|
||||
"""
|
||||
return {
|
||||
"kind": kind,
|
||||
"url": url,
|
||||
"download_url": url,
|
||||
"name": name,
|
||||
"mime_type": mime_type,
|
||||
}
|
||||
|
||||
|
||||
def _register_web_agent_file(
|
||||
file_path: Optional[str],
|
||||
file_name: Optional[str] = None,
|
||||
kind: Optional[str] = None,
|
||||
) -> Optional[dict]:
|
||||
"""
|
||||
注册 Web Agent 本地附件并返回前端可访问的短期下载地址。
|
||||
|
||||
:param file_path: 本地文件路径
|
||||
:param file_name: 前端展示文件名
|
||||
:param kind: 附件展示类型
|
||||
:return: 前端附件描述,文件不可访问时返回 None
|
||||
"""
|
||||
if not file_path:
|
||||
return None
|
||||
try:
|
||||
resolved_path = Path(file_path).expanduser().resolve(strict=True)
|
||||
except OSError:
|
||||
return None
|
||||
if not resolved_path.is_file():
|
||||
return None
|
||||
|
||||
_cleanup_web_agent_file_registry()
|
||||
file_id = uuid.uuid4().hex
|
||||
display_name = file_name or resolved_path.name
|
||||
mime_type = mimetypes.guess_type(display_name or str(resolved_path))[0]
|
||||
file_url = f"message/agent/file/{file_id}"
|
||||
_WEB_AGENT_FILE_REGISTRY[file_id] = {
|
||||
"path": resolved_path,
|
||||
"name": display_name,
|
||||
"mime_type": mime_type or "application/octet-stream",
|
||||
"created_at": time.time(),
|
||||
}
|
||||
return {
|
||||
"kind": kind or _guess_web_agent_attachment_kind(mime_type),
|
||||
"url": file_url,
|
||||
"download_url": file_url,
|
||||
"name": display_name,
|
||||
"mime_type": mime_type,
|
||||
"size": resolved_path.stat().st_size,
|
||||
}
|
||||
|
||||
|
||||
def _build_web_agent_notification_events(
|
||||
notification: schemas.Notification,
|
||||
) -> list[dict]:
|
||||
"""
|
||||
将 Agent 工具通知转换为 Web SSE 事件。
|
||||
|
||||
:param notification: 工具产生的通知消息
|
||||
:return: 前端可直接应用到当前助手消息的事件列表
|
||||
"""
|
||||
events = []
|
||||
text_parts = [
|
||||
str(item).strip()
|
||||
for item in (notification.title, notification.text)
|
||||
if str(item or "").strip()
|
||||
]
|
||||
if text_parts:
|
||||
events.append({"type": "delta", "content": "\n\n".join(text_parts)})
|
||||
|
||||
if notification.image:
|
||||
image_ref = notification.image
|
||||
image_path = Path(image_ref).expanduser()
|
||||
attachment = None
|
||||
if not image_ref.startswith(("http://", "https://", "data:", "blob:")):
|
||||
attachment = _register_web_agent_file(
|
||||
image_ref, file_name=Path(image_ref).name, kind="image"
|
||||
)
|
||||
if not attachment:
|
||||
attachment = _build_web_agent_url_attachment(
|
||||
image_ref,
|
||||
kind="image",
|
||||
name=notification.title or image_path.name or "image",
|
||||
)
|
||||
events.append({"type": "attachment", "attachment": attachment})
|
||||
|
||||
if notification.voice_path:
|
||||
attachment = _register_web_agent_file(
|
||||
notification.voice_path,
|
||||
file_name=Path(notification.voice_path).name,
|
||||
kind="audio",
|
||||
)
|
||||
if attachment:
|
||||
events.append({"type": "attachment", "attachment": attachment})
|
||||
|
||||
if notification.file_path:
|
||||
attachment = _register_web_agent_file(
|
||||
notification.file_path,
|
||||
file_name=notification.file_name or Path(notification.file_path).name,
|
||||
)
|
||||
if attachment:
|
||||
events.append({"type": "attachment", "attachment": attachment})
|
||||
|
||||
return events
|
||||
|
||||
|
||||
def _split_web_agent_output(text: str) -> list[dict]:
|
||||
"""
|
||||
将 Agent 输出拆成普通文本与工具提示事件。
|
||||
|
||||
:param text: 本次新增的 Agent 文本
|
||||
:return: 前端可直接渲染的事件片段
|
||||
"""
|
||||
if not text:
|
||||
return []
|
||||
|
||||
events = []
|
||||
|
||||
def append_text(content: str) -> None:
|
||||
"""将工具汇总行从普通文本中拆出,便于前端独立展示。"""
|
||||
if not content or not content.strip():
|
||||
return
|
||||
lines = content.splitlines(keepends=True)
|
||||
buffer = ""
|
||||
for line in lines:
|
||||
stripped_line = line.strip()
|
||||
if (
|
||||
stripped_line.startswith("(")
|
||||
and stripped_line.endswith(")")
|
||||
):
|
||||
if buffer:
|
||||
events.append({"type": "delta", "content": buffer})
|
||||
buffer = ""
|
||||
events.append(
|
||||
{
|
||||
"type": "tool",
|
||||
"message": stripped_line.strip("()"),
|
||||
}
|
||||
)
|
||||
else:
|
||||
buffer += line
|
||||
if buffer:
|
||||
events.append({"type": "delta", "content": buffer})
|
||||
|
||||
marker = "⚙️ => "
|
||||
remaining = text
|
||||
while remaining:
|
||||
marker_index = remaining.find(marker)
|
||||
if marker_index < 0:
|
||||
append_text(remaining)
|
||||
break
|
||||
|
||||
if marker_index > 0:
|
||||
append_text(remaining[:marker_index])
|
||||
|
||||
after_marker = remaining[marker_index + len(marker):]
|
||||
line_end = after_marker.find("\n")
|
||||
if line_end < 0:
|
||||
message = after_marker.strip()
|
||||
remaining = ""
|
||||
else:
|
||||
message = after_marker[:line_end].strip()
|
||||
remaining = after_marker[line_end:].lstrip("\n")
|
||||
|
||||
if message:
|
||||
events.append({"type": "tool", "message": message})
|
||||
|
||||
return events
|
||||
|
||||
|
||||
@router.get("/file/{file_id}", summary="下载 Web 智能助手附件")
|
||||
async def download_web_agent_file(file_id: str) -> FileResponse:
|
||||
"""
|
||||
下载 Web 智能助手本轮生成的临时附件。
|
||||
|
||||
:param file_id: 附件随机标识
|
||||
:return: 附件文件响应
|
||||
"""
|
||||
_cleanup_web_agent_file_registry()
|
||||
file_info = _WEB_AGENT_FILE_REGISTRY.get(file_id)
|
||||
if not file_info:
|
||||
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="附件不存在或已过期")
|
||||
|
||||
file_path = file_info["path"]
|
||||
if not file_path.exists() or not file_path.is_file():
|
||||
_WEB_AGENT_FILE_REGISTRY.pop(file_id, None)
|
||||
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="附件不存在或已过期")
|
||||
|
||||
return FileResponse(
|
||||
path=file_path,
|
||||
media_type=file_info.get("mime_type") or "application/octet-stream",
|
||||
filename=file_info.get("name") or file_path.name,
|
||||
)
|
||||
|
||||
|
||||
@router.post("/stream", summary="Web智能助手流式对话")
|
||||
async def web_agent_stream(
|
||||
payload: schemas.AgentWebChatRequest,
|
||||
request: Request,
|
||||
current_user: User = Depends(get_current_active_superuser),
|
||||
) -> StreamingResponse:
|
||||
"""
|
||||
Web 智能助手流式对话。
|
||||
|
||||
:param payload: 对话请求
|
||||
:param request: 当前 HTTP 请求
|
||||
:param current_user: 当前登录管理员
|
||||
:return: SSE 流式响应
|
||||
"""
|
||||
if not settings.AI_AGENT_ENABLE:
|
||||
return StreamingResponse(
|
||||
iter([
|
||||
_build_web_agent_sse(
|
||||
"error",
|
||||
{"message": "智能助手未启用,请先在系统设置中开启。"},
|
||||
)
|
||||
]),
|
||||
media_type="text/event-stream",
|
||||
)
|
||||
|
||||
prompt = payload.text.strip()
|
||||
if not prompt:
|
||||
return StreamingResponse(
|
||||
iter([
|
||||
_build_web_agent_sse(
|
||||
"error",
|
||||
{"message": "请输入要发送给智能助手的内容。"},
|
||||
)
|
||||
]),
|
||||
media_type="text/event-stream",
|
||||
)
|
||||
|
||||
session_id = _build_web_agent_session_id(current_user, payload.session_id)
|
||||
event_queue: asyncio.Queue = asyncio.Queue()
|
||||
last_output = ""
|
||||
|
||||
def output_callback(output: str) -> None:
|
||||
"""
|
||||
接收 Agent 累积输出并转成增量事件。
|
||||
"""
|
||||
nonlocal last_output
|
||||
delta = output[len(last_output):] if output.startswith(last_output) else output
|
||||
last_output = output
|
||||
for item in _split_web_agent_output(delta):
|
||||
event_queue.put_nowait(item)
|
||||
|
||||
def notification_callback(notification: schemas.Notification) -> None:
|
||||
"""
|
||||
接收 Agent 工具主动发送的 Web 通知。
|
||||
"""
|
||||
for item in _build_web_agent_notification_events(notification):
|
||||
event_queue.put_nowait(item)
|
||||
|
||||
async def event_generator() -> AsyncIterator[str]:
|
||||
"""
|
||||
生成前端 Agent SSE 事件。
|
||||
"""
|
||||
files = [
|
||||
file.model_dump(exclude_none=True)
|
||||
for file in (payload.files or [])
|
||||
]
|
||||
for audio_ref in payload.audio_refs or []:
|
||||
files.append({"ref": audio_ref, "mime_type": "audio/*"})
|
||||
|
||||
agent = _WebAgentMoviePilotAgent(
|
||||
session_id=session_id,
|
||||
user_id=str(current_user.id),
|
||||
channel=MessageChannel.WebAgent.value,
|
||||
source=WEB_AGENT_SOURCE,
|
||||
username=current_user.name,
|
||||
replay_mode=ReplyMode.CAPTURE_ONLY,
|
||||
persist_output_message=False,
|
||||
allow_message_tools=True,
|
||||
output_callback=output_callback,
|
||||
notification_callback=notification_callback,
|
||||
)
|
||||
|
||||
async def run_agent() -> None:
|
||||
"""后台执行 Agent,并将结果写入事件队列。"""
|
||||
try:
|
||||
await agent.process(
|
||||
message=prompt,
|
||||
images=payload.images or [],
|
||||
files=files or None,
|
||||
has_audio_input=bool(payload.audio_refs),
|
||||
)
|
||||
except Exception as err:
|
||||
logger.error(f"Web智能助手执行失败: {str(err)}")
|
||||
await event_queue.put(
|
||||
{"type": "error", "message": f"智能助手执行失败: {str(err)}"}
|
||||
)
|
||||
finally:
|
||||
await event_queue.put({"type": "done"})
|
||||
|
||||
task = asyncio.create_task(run_agent())
|
||||
try:
|
||||
yield _build_web_agent_sse(
|
||||
"start",
|
||||
{"session_id": payload.session_id or session_id},
|
||||
)
|
||||
while not global_vars.is_system_stopped:
|
||||
if await request.is_disconnected():
|
||||
break
|
||||
event = await event_queue.get()
|
||||
yield _build_web_agent_sse(event.pop("type"), event)
|
||||
if task.done() and event_queue.empty():
|
||||
break
|
||||
except asyncio.CancelledError:
|
||||
return
|
||||
finally:
|
||||
if not task.done():
|
||||
task.cancel()
|
||||
try:
|
||||
await task
|
||||
except asyncio.CancelledError:
|
||||
pass
|
||||
await agent.cleanup()
|
||||
|
||||
return StreamingResponse(
|
||||
event_generator(),
|
||||
media_type="text/event-stream",
|
||||
headers={
|
||||
"Cache-Control": "no-cache",
|
||||
"Connection": "keep-alive",
|
||||
"X-Accel-Buffering": "no",
|
||||
},
|
||||
)
|
||||
@@ -250,6 +250,33 @@ class SubscriptionMessage(BaseModel):
|
||||
data: Optional[dict] = Field(default_factory=dict)
|
||||
|
||||
|
||||
class AgentWebChatRequest(BaseModel):
|
||||
"""
|
||||
Web 智能助手对话请求。
|
||||
"""
|
||||
|
||||
class AgentWebChatFile(BaseModel):
|
||||
"""
|
||||
Web 智能助手输入附件。
|
||||
"""
|
||||
|
||||
ref: str = Field(..., min_length=1)
|
||||
name: Optional[str] = Field(None)
|
||||
mime_type: Optional[str] = Field(None)
|
||||
size: Optional[int] = Field(None)
|
||||
|
||||
# 用户本轮输入
|
||||
text: str = Field(..., min_length=1)
|
||||
# 前端会话标识,相同标识复用同一段 Agent 记忆
|
||||
session_id: Optional[str] = Field(None)
|
||||
# 图片 URL 或 data URL 列表
|
||||
images: Optional[List[str]] = Field(default_factory=list)
|
||||
# 语音/音频引用列表
|
||||
audio_refs: Optional[List[str]] = Field(default_factory=list)
|
||||
# 文件附件列表
|
||||
files: Optional[List[AgentWebChatFile]] = Field(default_factory=list)
|
||||
|
||||
|
||||
class ChannelCapability(Enum):
|
||||
"""
|
||||
渠道能力枚举
|
||||
@@ -444,6 +471,19 @@ class ChannelCapabilityManager:
|
||||
},
|
||||
fallback_enabled=True,
|
||||
),
|
||||
MessageChannel.WebAgent: ChannelCapabilities(
|
||||
channel=MessageChannel.WebAgent,
|
||||
capabilities={
|
||||
ChannelCapability.MESSAGE_EDITING,
|
||||
ChannelCapability.MARKDOWN,
|
||||
ChannelCapability.RICH_TEXT,
|
||||
ChannelCapability.IMAGES,
|
||||
ChannelCapability.LINKS,
|
||||
ChannelCapability.AUDIO_OUTPUT,
|
||||
ChannelCapability.FILE_SENDING,
|
||||
},
|
||||
fallback_enabled=False,
|
||||
),
|
||||
MessageChannel.QQ: ChannelCapabilities(
|
||||
channel=MessageChannel.QQ,
|
||||
capabilities={
|
||||
|
||||
@@ -342,6 +342,7 @@ class MessageChannel(Enum):
|
||||
SynologyChat = "SynologyChat"
|
||||
VoceChat = "VoceChat"
|
||||
Web = "Web"
|
||||
WebAgent = "WebAgent"
|
||||
WebPush = "WebPush"
|
||||
QQ = "QQ"
|
||||
|
||||
|
||||
111
tests/test_web_agent_stream.py
Normal file
111
tests/test_web_agent_stream.py
Normal file
@@ -0,0 +1,111 @@
|
||||
from types import SimpleNamespace
|
||||
|
||||
from app import schemas
|
||||
from app.api.endpoints.agent import (
|
||||
_build_web_agent_session_id,
|
||||
_build_web_agent_notification_events,
|
||||
_split_web_agent_output,
|
||||
)
|
||||
from app.schemas.message import ChannelCapability, ChannelCapabilityManager
|
||||
from app.schemas.types import MessageChannel, NotificationType
|
||||
|
||||
|
||||
def test_split_web_agent_output_extracts_verbose_tool_message():
|
||||
"""应将啰嗦模式工具提示拆成独立工具事件。"""
|
||||
events = _split_web_agent_output("准备查询。\n\n⚙️ => 查询站点\n\n已完成")
|
||||
|
||||
assert events == [
|
||||
{"type": "delta", "content": "准备查询。\n\n"},
|
||||
{"type": "tool", "message": "查询站点"},
|
||||
{"type": "delta", "content": "已完成"},
|
||||
]
|
||||
|
||||
|
||||
def test_split_web_agent_output_extracts_summary_tool_message():
|
||||
"""应将非啰嗦模式工具汇总行拆成独立工具事件。"""
|
||||
events = _split_web_agent_output("(查询了 2 次数据)\n\n这里是结果")
|
||||
|
||||
assert events == [
|
||||
{"type": "tool", "message": "查询了 2 次数据"},
|
||||
{"type": "delta", "content": "\n这里是结果"},
|
||||
]
|
||||
|
||||
|
||||
def test_build_web_agent_session_id_is_stable_per_user_and_seed():
|
||||
"""同一用户和前端会话标识应生成稳定的服务端会话 ID。"""
|
||||
user = SimpleNamespace(id=1, name="admin")
|
||||
|
||||
first = _build_web_agent_session_id(user, "browser-session")
|
||||
second = _build_web_agent_session_id(user, "browser-session")
|
||||
other = _build_web_agent_session_id(user, "other-session")
|
||||
|
||||
assert first == second
|
||||
assert first != other
|
||||
assert first.startswith("web-agent:")
|
||||
|
||||
|
||||
def test_web_agent_channel_supports_streaming_and_attachments():
|
||||
"""WebAgent 渠道应声明流式、多媒体和文件发送能力。"""
|
||||
assert ChannelCapabilityManager.supports_capability(
|
||||
MessageChannel.WebAgent, ChannelCapability.MESSAGE_EDITING
|
||||
)
|
||||
assert ChannelCapabilityManager.supports_capability(
|
||||
MessageChannel.WebAgent, ChannelCapability.IMAGES
|
||||
)
|
||||
assert ChannelCapabilityManager.supports_capability(
|
||||
MessageChannel.WebAgent, ChannelCapability.AUDIO_OUTPUT
|
||||
)
|
||||
assert ChannelCapabilityManager.supports_capability(
|
||||
MessageChannel.WebAgent, ChannelCapability.FILE_SENDING
|
||||
)
|
||||
|
||||
|
||||
def test_build_web_agent_notification_events_extracts_image():
|
||||
"""Agent 工具发送图片消息时应转换为图片附件事件。"""
|
||||
events = _build_web_agent_notification_events(
|
||||
schemas.Notification(
|
||||
channel=MessageChannel.WebAgent,
|
||||
mtype=NotificationType.Agent,
|
||||
title="海报",
|
||||
text="已找到图片",
|
||||
image="https://example.com/poster.jpg",
|
||||
)
|
||||
)
|
||||
|
||||
assert events == [
|
||||
{"type": "delta", "content": "海报\n\n已找到图片"},
|
||||
{
|
||||
"type": "attachment",
|
||||
"attachment": {
|
||||
"kind": "image",
|
||||
"url": "https://example.com/poster.jpg",
|
||||
"download_url": "https://example.com/poster.jpg",
|
||||
"name": "海报",
|
||||
"mime_type": None,
|
||||
},
|
||||
},
|
||||
]
|
||||
|
||||
|
||||
def test_build_web_agent_notification_events_registers_local_file(tmp_path):
|
||||
"""Agent 工具发送本地文件时应生成可下载附件事件。"""
|
||||
file_path = tmp_path / "report.txt"
|
||||
file_path.write_text("hello", encoding="utf-8")
|
||||
|
||||
events = _build_web_agent_notification_events(
|
||||
schemas.Notification(
|
||||
channel=MessageChannel.WebAgent,
|
||||
mtype=NotificationType.Agent,
|
||||
file_path=str(file_path),
|
||||
file_name="report.txt",
|
||||
)
|
||||
)
|
||||
|
||||
assert len(events) == 1
|
||||
attachment = events[0]["attachment"]
|
||||
assert events[0]["type"] == "attachment"
|
||||
assert attachment["kind"] == "file"
|
||||
assert attachment["name"] == "report.txt"
|
||||
assert attachment["mime_type"] == "text/plain"
|
||||
assert attachment["size"] == 5
|
||||
assert attachment["url"].startswith("message/agent/file/")
|
||||
Reference in New Issue
Block a user