diff --git a/app/agent/__init__.py b/app/agent/__init__.py index 605a6551..4720e0f2 100644 --- a/app/agent/__init__.py +++ b/app/agent/__init__.py @@ -16,6 +16,7 @@ from langchain.agents.middleware import ( from langchain_core.messages import ( # noqa: F401 HumanMessage, BaseMessage, + SystemMessage, ) import warnings @@ -50,6 +51,7 @@ from app.agent.tools.factory import MoviePilotToolFactory from app.chain import ChainBase from app.core.config import settings from app.core.event import eventmanager +from app.db.agentchat_oper import AgentChatOper from app.db.user_oper import UserOper from app.log import logger from app.schemas import AgentLLMProviderEventData, AgentTokensUsageEventData, Notification, NotificationType @@ -229,6 +231,11 @@ HEARTBEAT_SESSION_PREFIX = "__agent_heartbeat_" UNSUPPORTED_IMAGE_INPUT_MESSAGE = "当前模型不支持图片输入,请更换支持图片输入的模型,或在系统设置中关闭图片输入支持后重试。" AGENT_EXECUTION_ERROR_PREFIX = "智能助手执行失败" AGENT_EXECUTION_ERROR_MESSAGE = "智能助手执行失败,请稍后重试。" +AGENT_DISPLAY_HISTORY_SKIP_CHANNELS = {MessageChannel.WebAgent.value} +AGENT_CHAT_TITLE_PROMPT = ( + "你是 MoviePilot 智能助手的会话标题生成器。请根据用户的第一条消息生成一个简洁中文标题," + "不超过 18 个汉字或 36 个英文字符,只输出标题本身,不要引号、编号或解释。" +) class MoviePilotAgent: @@ -269,6 +276,129 @@ class MoviePilotAgent: # 流式token管理 self.stream_handler = StreamingHandler() + @staticmethod + def _current_timestamp_ms() -> int: + """返回当前毫秒时间戳。""" + return int(datetime.now().timestamp() * 1000) + + @classmethod + def build_display_message( + cls, + role: str, + content: str = "", + attachments: Optional[List[dict]] = None, + status: str = "done", + ) -> dict[str, Any]: + """ + 构造可展示的 Agent 会话消息。 + """ + return { + "id": f"{role}-{uuid.uuid4().hex}", + "role": role, + "content": content or "", + "createdAt": cls._current_timestamp_ms(), + "status": status, + "tools": [], + "attachments": attachments or [], + "choices": [], + } + + def _should_save_display_history(self) -> bool: + """ + 判断当前 Agent 是否由通用渠道保存展示历史。 + """ + return bool( + self.channel + and self.source + and self.channel not in AGENT_DISPLAY_HISTORY_SKIP_CHANNELS + ) + + def _save_display_history_messages(self, messages: List[dict]) -> None: + """ + 将一组可见消息追加到 Agent 会话历史表。 + """ + if not messages or not self._should_save_display_history(): + return + try: + AgentChatOper().append_display_messages( + session_id=self.session_id, + user_id=self.user_id, + username=self.username, + channel=self.channel, + source=self.source, + original_chat_id=self.original_chat_id, + messages=messages, + ) + except Exception as e: + logger.debug(f"写入Agent展示历史失败: {e}") + + def _save_assistant_display_message_once(self, message: str) -> None: + """ + 保存一条助手回复展示记录,并标记本轮已写入。 + """ + if not message or self._tool_context.get("assistant_display_saved"): + return + self._save_display_history_messages( + [self.build_display_message(role="assistant", content=message)] + ) + self._tool_context["assistant_display_saved"] = True + + @staticmethod + def _sanitize_chat_title(value: str) -> str: + """清理模型返回的会话标题。""" + title = str(value or "").strip() + title = re.sub(r"^[#\-*\d.、\s]+", "", title) + title = title.strip("「」『』“”\"'` \n\t") + title = re.sub(r"\s+", " ", title) + return title[:120] + + async def _generate_chat_title(self, message: str) -> str: + """ + 使用当前 Agent 模型生成会话标题。 + """ + if not str(message or "").strip(): + return "" + model = await self._initialize_llm(streaming=False) + response = await model.ainvoke( + [ + SystemMessage(content=AGENT_CHAT_TITLE_PROMPT), + HumanMessage(content=str(message).strip()[:1000]), + ] + ) + content = LLMHelper._extract_text_content(getattr(response, "content", response)) + return self._sanitize_chat_title(content) + + async def prepare_chat_title(self, message: str) -> None: + """ + 首次对话时生成并保存会话标题。 + """ + if self._tool_context.get("chat_title_prepared"): + return + self._tool_context["chat_title_prepared"] = True + try: + chat = await run_in_threadpool( + AgentChatOper().get, + session_id=self.session_id, + user_id=self.user_id, + ) + if chat and AgentChatOper.has_custom_title(chat.title): + return + title = await self._generate_chat_title(message) + if not title: + return + await run_in_threadpool( + AgentChatOper().update_title_if_empty, + session_id=self.session_id, + user_id=self.user_id, + title=title, + username=self.username, + channel=self.channel, + source=self.source, + original_chat_id=self.original_chat_id, + ) + except Exception as e: + logger.debug(f"生成Agent会话标题失败: {e}") + @staticmethod def _coerce_int(value: Any) -> Optional[int]: if value is None: @@ -944,6 +1074,7 @@ class MoviePilotAgent: """ 处理用户消息,流式推理并返回 Agent 回复 """ + user_display_saved = False try: logger.info( f"Agent推理: session_id={self.session_id}, input={message}, " @@ -982,6 +1113,21 @@ class MoviePilotAgent: for img in images or []: content.append({"type": "image_url", "image_url": {"url": img}}) messages.append(HumanMessage(content=content)) + await self.prepare_chat_title(message) + self._save_display_history_messages( + [ + self.build_display_message( + role="user", + content=message, + attachments=self._build_input_display_attachments( + images=images, + files=files, + has_audio_input=has_audio_input, + ), + ) + ] + ) + user_display_saved = True # 执行推理 result = await self._execute_agent(messages) @@ -992,11 +1138,63 @@ class MoviePilotAgent: except Exception as e: error_message = f"处理消息时发生错误: {str(e)}" logger.error(error_message) + if not user_display_saved: + self._save_display_history_messages( + [self.build_display_message(role="user", content=message)] + ) if not self.should_dispatch_reply: raise await self.send_agent_message(error_message) return error_message + @staticmethod + def _guess_file_attachment_kind(mime_type: Optional[str], fallback: str = "file") -> str: + """ + 根据 MIME 类型推断展示附件类型。 + """ + if mime_type and mime_type.startswith("image/"): + return "image" + if mime_type and mime_type.startswith("audio/"): + return "audio" + return fallback + + def _build_input_display_attachments( + self, + images: Optional[List[str]] = None, + files: Optional[List[dict]] = None, + has_audio_input: bool = False, + ) -> List[dict]: + """ + 构造用户输入附件的展示记录。 + """ + attachments: List[dict] = [] + for index, image in enumerate(images or [], start=1): + attachments.append( + { + "kind": "image", + "url": image, + "download_url": image, + "name": f"image-{index}", + "mime_type": "image/*", + } + ) + for index, file in enumerate(files or [], start=1): + ref = file.get("ref") or file.get("local_path") or "" + mime_type = file.get("mime_type") + fallback = "audio" if has_audio_input and mime_type == "audio/*" else "file" + attachments.append( + { + "kind": self._guess_file_attachment_kind(mime_type, fallback=fallback), + "url": ref, + "download_url": ref, + "name": file.get("name") or f"attachment-{index}", + "mime_type": mime_type, + "size": file.get("size"), + "local_path": file.get("local_path"), + } + ) + return attachments + async def _stream_agent_tokens( self, agent, messages: dict, config: dict, on_token: Callable[[str], None] ): @@ -1164,6 +1362,17 @@ class MoviePilotAgent: # 非流式渠道:发送最终回复 await self.send_agent_message(final_text) + display_text = self._streamed_output + if not display_text: + final_messages = agent.get_state(agent_config).values.get( + "messages", [] + ) + for msg in reversed(final_messages): + if hasattr(msg, "type") and msg.type == "ai" and msg.content: + display_text = self._extract_text_content(msg.content).strip() + break + self._save_assistant_display_message_once(display_text) + # 保存消息 memory_manager.save_agent_messages( session_id=self.session_id, @@ -1200,6 +1409,7 @@ class MoviePilotAgent: """ 通过原渠道发送消息给用户 """ + self._save_assistant_display_message_once(message) await AgentChain().async_post_message( Notification( channel=self.channel, diff --git a/app/agent/memory/__init__.py b/app/agent/memory/__init__.py index ad9782b8..85843d00 100644 --- a/app/agent/memory/__init__.py +++ b/app/agent/memory/__init__.py @@ -4,9 +4,10 @@ import asyncio from datetime import datetime from typing import Dict, List, Optional -from langchain_core.messages import BaseMessage +from langchain_core.messages import BaseMessage, messages_from_dict, messages_to_dict from app.core.config import settings +from app.db.agentchat_oper import AgentChatOper from app.log import logger from app.schemas.agent import ConversationMemory @@ -70,24 +71,43 @@ class MemoryManager: self, session_id: str, user_id: str ) -> List[BaseMessage]: """ - 为Agent获取最近的消息(仅内存缓存) + 为Agent获取最近的消息。 - 如果消息Token数量超过模型最大上下文长度的阀值,会自动进行摘要裁剪 + 优先使用内存缓存,缓存不存在时从数据库恢复上一轮持久化的原始 messages。 """ memory = self.get_memory(session_id, user_id) - if not memory: + if memory: + return memory.messages + + try: + chat = AgentChatOper().get(session_id=session_id, user_id=user_id) + if not chat: + chat = AgentChatOper().get(session_id=session_id) + except Exception as e: + logger.debug(f"读取持久化Agent会话失败: {e}") + return [] + if not chat or not chat.agent_messages: return [] - # 获取所有消息 + try: + messages = messages_from_dict(chat.agent_messages) + except Exception as e: + logger.debug(f"恢复持久化Agent消息失败: {e}") + return [] + + memory = ConversationMemory( + session_id=session_id, + user_id=user_id, + messages=messages, + ) + self.save_memory(memory) return memory.messages def save_agent_messages( self, session_id: str, user_id: str, messages: List[BaseMessage] ): """ - 保存Agent消息(仅内存缓存) - - 注意:Redis中的记忆通过TTL机制自动过期,这里只更新内存缓存,Redis会在下次访问时自动过期 + 保存Agent消息到内存缓存与持久化会话表。 """ memory = self.get_memory(session_id, user_id) if not memory: @@ -98,6 +118,14 @@ class MemoryManager: # 更新内存缓存 self.save_memory(memory) + try: + AgentChatOper().save_agent_messages( + session_id=session_id, + user_id=user_id, + messages=messages_to_dict(messages), + ) + except Exception as e: + logger.debug(f"持久化Agent消息失败: {e}") def save_memory(self, memory: ConversationMemory): """ diff --git a/app/api/endpoints/agent.py b/app/api/endpoints/agent.py index 321033ff..1a8a5cd1 100644 --- a/app/api/endpoints/agent.py +++ b/app/api/endpoints/agent.py @@ -10,13 +10,18 @@ from pathlib import Path from typing import Any, AsyncIterator, Callable, Optional from fastapi import APIRouter, Depends, File, Form, HTTPException, Request, UploadFile, status +from fastapi.concurrency import run_in_threadpool from fastapi.responses import FileResponse, StreamingResponse +from sqlalchemy.ext.asyncio import AsyncSession from app import schemas from app.agent import MoviePilotAgent, ReplyMode, StreamingHandler from app.agent.llm.capability import AgentCapabilityManager from app.core.config import global_vars, settings +from app.db import get_async_db +from app.db.agentchat_oper import AgentChatOper from app.db.models import User +from app.db.models.agentchat import AgentChat from app.db.user_oper import UserOper, get_current_active_user from app.helper.interaction import agent_interaction_manager from app.log import logger @@ -152,11 +157,122 @@ def _build_web_agent_session_id(user: User, session_id: Optional[str]) -> str: :return: 可用于 Agent 记忆隔离的服务端会话 ID """ seed = str(session_id or "").strip() or uuid.uuid4().hex + if seed.startswith(WEB_AGENT_SESSION_PREFIX): + return seed + try: + existing_chat = AgentChatOper().get(session_id=seed) + if existing_chat and _can_access_agent_chat(existing_chat, user): + return seed + except Exception as e: + logger.debug(f"读取WebAgent历史会话失败: {e}") user_part = user.name or str(user.id) digest = hashlib.sha256(f"{user_part}:{seed}".encode("utf-8")).hexdigest() return f"{WEB_AGENT_SESSION_PREFIX}{digest[:32]}" +def _can_access_agent_chat(chat: AgentChat, user: User) -> bool: + """ + 判断当前登录用户是否可以访问指定 Agent 会话。 + + 超级用户可查看所有渠道历史;普通用户仅能查看 user_id 或 username 匹配自己的会话。 + """ + if not chat or not user: + return False + if getattr(user, "is_superuser", False): + return True + user_id = str(user.id) + username = str(user.name or "") + return chat.user_id == user_id or (bool(username) and chat.username == username) + + +async def _get_accessible_agent_chat( + oper: AgentChatOper, session_id: str, user: User +) -> Optional[AgentChat]: + """ + 读取当前用户可访问的 Agent 会话。 + """ + chat = await oper.async_get(session_id=session_id) + if not chat or not _can_access_agent_chat(chat, user): + return None + return chat + + +def _apply_web_agent_display_event(event: dict, assistant_message: dict) -> None: + """ + 将 WebAgent SSE 事件同步应用到服务端展示消息快照。 + """ + event_type = event.get("type") + if event_type == "delta": + assistant_message["content"] += event.get("content") or "" + elif event_type == "tool": + for tool in assistant_message["tools"]: + tool["status"] = "done" + assistant_message["tools"].append( + { + "id": f"tool-{uuid.uuid4().hex}", + "message": str(event.get("message") or "").replace("=>", "", 1).strip(), + "status": "running", + } + ) + elif event_type == "attachment" and event.get("attachment"): + assistant_message["attachments"].append(event["attachment"]) + elif event_type == "choice" and event.get("choice"): + assistant_message["choices"].append({**event["choice"], "status": "pending"}) + elif event_type == "error": + assistant_message["status"] = "error" + assistant_message["content"] = ( + assistant_message["content"] + or event.get("message") + or "智能助手响应失败" + ) + for tool in assistant_message["tools"]: + tool["status"] = "done" + elif event_type == "done": + if assistant_message.get("status") != "error": + assistant_message["status"] = "done" + for tool in assistant_message["tools"]: + tool["status"] = "done" + + +def _save_web_agent_display_snapshot( + *, + session_id: str, + current_user: User, + messages: list[dict], + client_session_id: Optional[str] = None, +) -> None: + """ + 保存 WebAgent 当前展示消息快照。 + """ + try: + oper = AgentChatOper() + existing_chat = oper.get(session_id=session_id) + AgentChatOper().save_display_messages( + session_id=session_id, + user_id=(existing_chat.user_id if existing_chat else str(current_user.id)), + username=(existing_chat.username if existing_chat else current_user.name), + channel=( + existing_chat.channel + if existing_chat and existing_chat.channel + else MessageChannel.WebAgent + ), + source=( + existing_chat.source + if existing_chat and existing_chat.source + else WEB_AGENT_SOURCE + ), + original_chat_id=existing_chat.original_chat_id if existing_chat else None, + client_session_id=( + existing_chat.client_session_id + if existing_chat and existing_chat.client_session_id + else client_session_id + ), + messages=messages, + ) + except Exception as e: + logger.debug(f"保存WebAgent展示历史失败: {e}") + + def _build_web_agent_sse(event_type: str, data: Optional[dict] = None) -> str: """ 构建 Web Agent SSE 消息。 @@ -299,6 +415,52 @@ def _build_web_agent_url_attachment( } +def _build_web_agent_input_attachments( + images: list[str], + files: list[dict], + audio_refs: list[str], +) -> list[dict]: + """ + 构造 WebAgent 用户输入附件展示记录。 + """ + attachments = [] + for index, image in enumerate(images or [], start=1): + attachments.append( + { + "kind": "image", + "url": image, + "download_url": image, + "name": f"image-{index}", + "mime_type": "image/*", + } + ) + for index, file in enumerate(files or [], start=1): + ref = file.get("ref") or file.get("url") or file.get("local_path") or "" + mime_type = file.get("mime_type") + attachments.append( + { + "kind": _guess_web_agent_attachment_kind(mime_type), + "url": ref, + "download_url": ref, + "name": file.get("name") or f"attachment-{index}", + "mime_type": mime_type, + "size": file.get("size"), + "local_path": file.get("local_path"), + } + ) + for index, audio_ref in enumerate(audio_refs or [], start=1): + attachments.append( + { + "kind": "audio", + "url": audio_ref, + "download_url": audio_ref, + "name": f"voice-{index}", + "mime_type": "audio/*", + } + ) + return attachments + + def _register_web_agent_file( file_path: Optional[str], file_name: Optional[str] = None, @@ -790,6 +952,116 @@ async def web_agent_callback( return schemas.Response(success=True, data=result) +@router.get("/sessions", summary="获取 Agent 历史会话", response_model=schemas.Response) +async def list_agent_chat_sessions( + current_user: User = Depends(get_current_active_user), + db: AsyncSession = Depends(get_async_db), + page: Optional[int] = 1, + count: Optional[int] = 30, +) -> schemas.Response: + """ + 获取当前用户可访问的 Agent 历史会话列表。 + + :param current_user: 当前登录用户 + :param db: 异步数据库会话 + :param page: 页码 + :param count: 每页数量 + :return: 会话摘要列表 + """ + user_id = None if current_user.is_superuser else str(current_user.id) + username = None if current_user.is_superuser else current_user.name + chats = await AgentChatOper(db).async_list_by_page( + page=page, + count=count, + user_id=user_id, + username=username, + ) + return schemas.Response( + success=True, + data=[AgentChatOper.to_summary(chat) for chat in chats], + ) + + +@router.get("/sessions/{session_id}", summary="获取 Agent 历史会话详情", response_model=schemas.Response) +async def get_agent_chat_session( + session_id: str, + current_user: User = Depends(get_current_active_user), + db: AsyncSession = Depends(get_async_db), +) -> schemas.Response: + """ + 获取一条 Agent 历史会话详情。 + + :param session_id: Agent 会话 ID + :param current_user: 当前登录用户 + :param db: 异步数据库会话 + :return: 会话详情 + """ + chat = await _get_accessible_agent_chat(AgentChatOper(db), session_id, current_user) + if not chat: + return schemas.Response(success=False, message="会话不存在或无权访问") + return schemas.Response(success=True, data=AgentChatOper.to_detail(chat)) + + +@router.put("/sessions/{session_id}/display", summary="保存 Agent 展示会话", response_model=schemas.Response) +async def save_agent_chat_display( + session_id: str, + payload: schemas.AgentChatDisplaySaveRequest, + current_user: User = Depends(get_current_active_user), + db: AsyncSession = Depends(get_async_db), +) -> schemas.Response: + """ + 保存前端聚合后的 Agent 展示消息。 + + :param session_id: Agent 会话 ID + :param payload: 展示消息保存请求 + :param current_user: 当前登录用户 + :param db: 异步数据库会话 + :return: 保存后的会话摘要 + """ + oper = AgentChatOper(db) + existing_chat = await oper.async_get(session_id=session_id) + if existing_chat and not _can_access_agent_chat(existing_chat, current_user): + return schemas.Response(success=False, message="会话不存在或无权访问") + + messages = [ + message.model_dump(exclude_none=True) + for message in payload.messages + ] + await run_in_threadpool( + _save_web_agent_display_snapshot, + session_id=session_id, + current_user=current_user, + messages=messages, + client_session_id=existing_chat.client_session_id if existing_chat else session_id, + ) + chat = await oper.async_get(session_id=session_id) + if not chat: + return schemas.Response(success=False, message="会话保存失败") + return schemas.Response(success=True, data=AgentChatOper.to_summary(chat)) + + +@router.delete("/sessions/{session_id}", summary="删除 Agent 历史会话", response_model=schemas.Response) +async def delete_agent_chat_session( + session_id: str, + current_user: User = Depends(get_current_active_user), + db: AsyncSession = Depends(get_async_db), +) -> schemas.Response: + """ + 删除一条 Agent 历史会话。 + + :param session_id: Agent 会话 ID + :param current_user: 当前登录用户 + :param db: 异步数据库会话 + :return: 删除结果 + """ + oper = AgentChatOper(db) + chat = await _get_accessible_agent_chat(oper, session_id, current_user) + if not chat: + return schemas.Response(success=False, message="会话不存在或无权访问") + deleted = await oper.async_delete(session_id=session_id) + return schemas.Response(success=deleted, message="删除成功" if deleted else "删除失败") + + @router.post("/stream", summary="Web智能助手流式对话") async def web_agent_stream( payload: schemas.AgentWebChatRequest, @@ -843,6 +1115,28 @@ async def web_agent_stream( session_id = _build_web_agent_session_id(current_user, payload.session_id) event_queue: asyncio.Queue = asyncio.Queue() last_output = "" + user_attachments = _build_web_agent_input_attachments( + images=payload.images or [], + files=[ + file.model_dump(exclude_none=True) + for file in (payload.files or []) + ], + audio_refs=payload.audio_refs or [], + ) + display_messages = [] + if payload.echo_user: + display_messages.append( + MoviePilotAgent.build_display_message( + role="user", + content=prompt, + attachments=user_attachments, + ) + ) + assistant_display_message = MoviePilotAgent.build_display_message( + role="assistant", + status="streaming", + ) + display_messages.append(assistant_display_message) def output_callback(output: str) -> None: """ @@ -852,6 +1146,7 @@ async def web_agent_stream( delta = output[len(last_output):] if output.startswith(last_output) else output last_output = output for item in _split_web_agent_output(delta): + _apply_web_agent_display_event(item, assistant_display_message) event_queue.put_nowait(item) def notification_callback(notification: schemas.Notification) -> None: @@ -859,6 +1154,7 @@ async def web_agent_stream( 接收 Agent 工具主动发送的 Web 通知。 """ for item in _build_web_agent_notification_events(notification): + _apply_web_agent_display_event(item, assistant_display_message) event_queue.put_nowait(item) async def event_generator() -> AsyncIterator[str]: @@ -897,17 +1193,29 @@ async def web_agent_stream( ) except Exception as err: logger.error(f"Web智能助手执行失败: {str(err)}") - await event_queue.put( - {"type": "error", "message": f"智能助手执行失败: {str(err)}"} - ) + error_event = { + "type": "error", + "message": f"智能助手执行失败: {str(err)}", + } + _apply_web_agent_display_event(error_event, assistant_display_message) + await event_queue.put(error_event) finally: - await event_queue.put({"type": "done"}) + done_event = {"type": "done"} + _apply_web_agent_display_event(done_event, assistant_display_message) + await run_in_threadpool( + _save_web_agent_display_snapshot, + session_id=session_id, + current_user=current_user, + messages=display_messages, + client_session_id=payload.session_id or session_id, + ) + await event_queue.put(done_event) task = asyncio.create_task(run_agent()) try: yield _build_web_agent_sse( "start", - {"session_id": payload.session_id or session_id}, + {"session_id": session_id}, ) while not global_vars.is_system_stopped: if await request.is_disconnected(): diff --git a/app/db/agentchat_oper.py b/app/db/agentchat_oper.py new file mode 100644 index 00000000..3c9dc51e --- /dev/null +++ b/app/db/agentchat_oper.py @@ -0,0 +1,336 @@ +import time +from typing import Any, Optional, Union + +from sqlalchemy.ext.asyncio import AsyncSession +from sqlalchemy.orm import Session + +from app.db import DbOper +from app.db.models.agentchat import AgentChat +from app.schemas.types import MessageChannel + +DEFAULT_AGENT_CHAT_TITLE = "未命名会话" + + +class AgentChatOper(DbOper): + """ + Agent 会话历史数据管理。 + """ + + def __init__(self, db: Union[Session, AsyncSession] = None): + super().__init__(db) + + @staticmethod + def _now() -> str: + """返回数据库统一使用的当前时间字符串。""" + return time.strftime("%Y-%m-%d %H:%M:%S", time.localtime()) + + @staticmethod + def _channel_value(channel: Optional[Union[MessageChannel, str]]) -> Optional[str]: + """获取渠道枚举的字符串值。""" + if isinstance(channel, MessageChannel): + return channel.value + return channel + + @staticmethod + def _normalize_messages(messages: Optional[list[dict]]) -> list[dict]: + """规范化展示消息列表,避免 JSON 字段存入 None。""" + return messages if isinstance(messages, list) else [] + + @staticmethod + def _normalize_title(value: Optional[str], messages: list[dict]) -> str: + """生成会话标题。""" + if value and value.strip(): + return value.strip()[:120] + for message in messages: + if message.get("role") != "user": + continue + content = str(message.get("content") or "").strip() + if content: + return content.replace("\n", " ")[:120] + attachments = message.get("attachments") + if isinstance(attachments, list) and attachments: + name = attachments[0].get("name") or "附件消息" + return str(name)[:120] + return DEFAULT_AGENT_CHAT_TITLE + + @staticmethod + def has_custom_title(value: Optional[str]) -> bool: + """判断会话是否已有真实标题。""" + return bool(value and value.strip() and value.strip() != DEFAULT_AGENT_CHAT_TITLE) + + @staticmethod + def _normalize_preview(messages: list[dict]) -> str: + """生成会话预览文本。""" + for message in reversed(messages): + content = str(message.get("content") or "").strip() + if content: + return content.replace("\n", " ")[:240] + attachments = message.get("attachments") + if isinstance(attachments, list) and attachments: + name = attachments[0].get("name") or "附件消息" + return str(name)[:240] + return "" + + def get( + self, session_id: str, user_id: Optional[str] = None + ) -> Optional[AgentChat]: + """ + 获取 Agent 会话。 + """ + return AgentChat.get_by_session(self._db, session_id, user_id) + + async def async_get( + self, session_id: str, user_id: Optional[str] = None + ) -> Optional[AgentChat]: + """ + 异步获取 Agent 会话。 + """ + return await AgentChat.async_get_by_session(self._db, session_id, user_id) + + def ensure_session( + self, + session_id: str, + user_id: Optional[str] = None, + username: Optional[str] = None, + channel: Optional[Union[MessageChannel, str]] = None, + source: Optional[str] = None, + original_chat_id: Optional[str] = None, + client_session_id: Optional[str] = None, + ) -> AgentChat: + """ + 确保 Agent 会话记录存在,并刷新基础渠道信息。 + """ + now = self._now() + chat = self.get(session_id=session_id, user_id=user_id) + if not chat: + chat = self.get(session_id=session_id) + payload = { + "user_id": user_id, + "username": username, + "channel": self._channel_value(channel), + "source": source, + "original_chat_id": original_chat_id, + "client_session_id": client_session_id, + "updated_at": now, + } + payload = {key: value for key, value in payload.items() if value is not None} + if chat: + chat.update(self._db, payload) + return self.get(session_id=session_id, user_id=user_id) or self.get(session_id=session_id) + + chat = AgentChat( + session_id=session_id, + user_id=user_id, + username=username, + channel=self._channel_value(channel), + source=source, + original_chat_id=original_chat_id, + client_session_id=client_session_id, + title=DEFAULT_AGENT_CHAT_TITLE, + preview="", + agent_messages=[], + display_messages=[], + message_count=0, + created_at=now, + updated_at=now, + ) + chat.create(self._db) + return self.get(session_id=session_id, user_id=user_id) or self.get(session_id=session_id) + + def save_agent_messages( + self, + session_id: str, + user_id: Optional[str], + messages: list[dict], + ) -> None: + """ + 保存可恢复 Agent 上下文的原始消息。 + """ + chat = self.get(session_id=session_id, user_id=user_id) + if not chat: + chat = self.get(session_id=session_id) + if not chat: + chat = self.ensure_session(session_id=session_id, user_id=user_id) + chat.update( + self._db, + { + "agent_messages": messages or [], + "updated_at": self._now(), + }, + ) + + def update_title_if_empty( + self, + session_id: str, + user_id: Optional[str], + title: Optional[str], + username: Optional[str] = None, + channel: Optional[Union[MessageChannel, str]] = None, + source: Optional[str] = None, + original_chat_id: Optional[str] = None, + client_session_id: Optional[str] = None, + ) -> None: + """ + 在会话尚未生成标题时写入标题。 + """ + normalized_title = self._normalize_title(title, []) + if normalized_title == DEFAULT_AGENT_CHAT_TITLE: + return + + chat = self.ensure_session( + session_id=session_id, + user_id=user_id, + username=username, + channel=channel, + source=source, + original_chat_id=original_chat_id, + client_session_id=client_session_id, + ) + if self.has_custom_title(chat.title): + return + chat.update( + self._db, + { + "title": normalized_title, + "updated_at": self._now(), + }, + ) + + def save_display_messages( + self, + session_id: str, + user_id: Optional[str] = None, + messages: Optional[list[dict]] = None, + username: Optional[str] = None, + channel: Optional[Union[MessageChannel, str]] = None, + source: Optional[str] = None, + original_chat_id: Optional[str] = None, + client_session_id: Optional[str] = None, + title: Optional[str] = None, + ) -> AgentChat: + """ + 保存用户可见的 Agent 会话消息。 + """ + normalized_messages = self._normalize_messages(messages) + chat = self.ensure_session( + session_id=session_id, + user_id=user_id, + username=username, + channel=channel, + source=source, + original_chat_id=original_chat_id, + client_session_id=client_session_id, + ) + normalized_title = ( + chat.title + if self.has_custom_title(chat.title) + else self._normalize_title(title, normalized_messages) + ) + chat.update( + self._db, + { + "title": normalized_title, + "preview": self._normalize_preview(normalized_messages), + "display_messages": normalized_messages, + "message_count": len(normalized_messages), + "updated_at": self._now(), + }, + ) + return self.get(session_id=session_id, user_id=user_id) or self.get(session_id=session_id) + + def append_display_messages( + self, + session_id: str, + user_id: Optional[str] = None, + messages: Optional[list[dict]] = None, + username: Optional[str] = None, + channel: Optional[Union[MessageChannel, str]] = None, + source: Optional[str] = None, + original_chat_id: Optional[str] = None, + client_session_id: Optional[str] = None, + ) -> AgentChat: + """ + 追加一组用户可见的 Agent 会话消息。 + """ + chat = self.ensure_session( + session_id=session_id, + user_id=user_id, + username=username, + channel=channel, + source=source, + original_chat_id=original_chat_id, + client_session_id=client_session_id, + ) + display_messages = self._normalize_messages(chat.display_messages) + display_messages.extend(self._normalize_messages(messages)) + title = chat.title if self.has_custom_title(chat.title) else None + return self.save_display_messages( + session_id=session_id, + user_id=user_id, + messages=display_messages, + username=username or chat.username, + channel=channel or chat.channel, + source=source or chat.source, + original_chat_id=original_chat_id or chat.original_chat_id, + client_session_id=client_session_id or chat.client_session_id, + title=title, + ) + + async def async_list_by_page( + self, + page: Optional[int] = 1, + count: Optional[int] = 30, + user_id: Optional[str] = None, + username: Optional[str] = None, + ) -> list[AgentChat]: + """ + 异步分页获取 Agent 会话历史。 + """ + return await AgentChat.async_list_by_page( + self._db, + page=page, + count=count, + user_id=user_id, + username=username, + ) + + async def async_delete( + self, session_id: str, user_id: Optional[str] = None + ) -> bool: + """ + 异步删除 Agent 会话历史。 + """ + chat = await self.async_get(session_id=session_id, user_id=user_id) + if not chat: + return False + await AgentChat.async_delete(self._db, chat.id) + return True + + @staticmethod + def to_summary(chat: AgentChat) -> dict[str, Any]: + """ + 转换为历史会话摘要。 + """ + return { + "id": chat.id, + "session_id": chat.session_id, + "client_session_id": chat.client_session_id, + "title": chat.title, + "channel": chat.channel, + "source": chat.source, + "user_id": chat.user_id, + "username": chat.username, + "original_chat_id": chat.original_chat_id, + "message_count": chat.message_count or 0, + "created_at": chat.created_at, + "updated_at": chat.updated_at, + } + + @classmethod + def to_detail(cls, chat: AgentChat) -> dict[str, Any]: + """ + 转换为历史会话详情。 + """ + data = cls.to_summary(chat) + data["messages"] = chat.display_messages or [] + return data diff --git a/app/db/models/__init__.py b/app/db/models/__init__.py index 9f772c02..81b7227b 100644 --- a/app/db/models/__init__.py +++ b/app/db/models/__init__.py @@ -1,3 +1,4 @@ +from .agentchat import AgentChat from .downloadhistory import DownloadHistory, DownloadFiles from .mediaserver import MediaServerItem from .message import Message diff --git a/app/db/models/agentchat.py b/app/db/models/agentchat.py new file mode 100644 index 00000000..56d0160c --- /dev/null +++ b/app/db/models/agentchat.py @@ -0,0 +1,130 @@ +from typing import Optional + +from sqlalchemy import Column, Integer, String, JSON, Index, select +from sqlalchemy.ext.asyncio import AsyncSession +from sqlalchemy.orm import Session + +from app.db import Base, async_db_query, db_query, get_id_column + + +class AgentChat(Base): + """ + Agent 会话历史表。 + """ + + id = get_id_column() + # Agent 内部会话 ID,用于恢复 LangGraph 对话上下文 + session_id = Column(String, nullable=False) + # 前端或渠道侧传入的原始会话标识 + client_session_id = Column(String) + # 用户 ID + user_id = Column(String) + # 用户名称 + username = Column(String) + # 消息渠道 + channel = Column(String) + # 渠道来源配置名 + source = Column(String) + # 原聊天 ID,用于区分群聊、频道或私聊 + original_chat_id = Column(String) + # 会话标题 + title = Column(String) + # 会话预览文本 + preview = Column(String) + # 原始 LangChain messages,用于继续会话 + agent_messages = Column(JSON) + # 展示给用户的消息记录,包含文字、工具提示、附件与选择卡片 + display_messages = Column(JSON) + # 展示消息数量 + message_count = Column(Integer, default=0) + # 创建时间 + created_at = Column(String) + # 更新时间 + updated_at = Column(String) + + __table_args__ = ( + Index("ix_agentchat_session_user", "session_id", "user_id"), + Index("ix_agentchat_user_updated", "user_id", "updated_at", "id"), + Index("ix_agentchat_channel_updated", "channel", "updated_at", "id"), + ) + + @classmethod + @db_query + def get_by_session( + cls, db: Session, session_id: str, user_id: Optional[str] = None + ) -> Optional["AgentChat"]: + """ + 根据会话 ID 获取 Agent 会话。 + """ + query = db.query(cls).filter(cls.session_id == session_id) + if user_id is not None: + query = query.filter(cls.user_id == user_id) + return query.order_by(cls.id.desc()).first() + + @classmethod + @async_db_query + async def async_get_by_session( + cls, db: AsyncSession, session_id: str, user_id: Optional[str] = None + ) -> Optional["AgentChat"]: + """ + 异步根据会话 ID 获取 Agent 会话。 + """ + statement = select(cls).where(cls.session_id == session_id) + if user_id is not None: + statement = statement.where(cls.user_id == user_id) + result = await db.execute(statement.order_by(cls.id.desc())) + return result.scalars().first() + + @classmethod + @db_query + def list_by_page( + cls, + db: Session, + page: Optional[int] = 1, + count: Optional[int] = 30, + user_id: Optional[str] = None, + username: Optional[str] = None, + ) -> list["AgentChat"]: + """ + 分页获取 Agent 会话历史。 + """ + query = db.query(cls) + if user_id is not None and username is not None: + query = query.filter((cls.user_id == user_id) | (cls.username == username)) + elif user_id is not None: + query = query.filter(cls.user_id == user_id) + elif username is not None: + query = query.filter(cls.username == username) + return ( + query.order_by(cls.updated_at.desc(), cls.id.desc()) + .offset((page - 1) * count) + .limit(count) + .all() + ) + + @classmethod + @async_db_query + async def async_list_by_page( + cls, + db: AsyncSession, + page: Optional[int] = 1, + count: Optional[int] = 30, + user_id: Optional[str] = None, + username: Optional[str] = None, + ) -> list["AgentChat"]: + """ + 异步分页获取 Agent 会话历史。 + """ + statement = select(cls) + if user_id is not None and username is not None: + statement = statement.where((cls.user_id == user_id) | (cls.username == username)) + elif user_id is not None: + statement = statement.where(cls.user_id == user_id) + elif username is not None: + statement = statement.where(cls.username == username) + result = await db.execute( + statement.order_by(cls.updated_at.desc(), cls.id.desc()) + .offset((page - 1) * count) + .limit(count) + ) + return result.scalars().all() diff --git a/app/schemas/__init__.py b/app/schemas/__init__.py index 1fcbf9c6..85e8df1c 100644 --- a/app/schemas/__init__.py +++ b/app/schemas/__init__.py @@ -1,3 +1,4 @@ +from .agent import * from .context import * from .dashboard import * from .download import * diff --git a/app/schemas/agent.py b/app/schemas/agent.py index 2bc88062..ff94452d 100644 --- a/app/schemas/agent.py +++ b/app/schemas/agent.py @@ -1,7 +1,7 @@ """AI智能体相关数据模型""" from datetime import datetime -from typing import List, Optional +from typing import List, Optional, Union from langchain_core.messages import BaseMessage from pydantic import BaseModel, Field, ConfigDict, field_serializer @@ -55,3 +55,95 @@ class ToolResult(BaseModel): success: bool = Field(description="是否成功") result: Optional[str] = Field(default=None, description="执行结果") error: Optional[str] = Field(default=None, description="错误信息") + + +class AgentChatAttachment(BaseModel): + """ + Agent 会话展示附件。 + """ + + kind: str = Field(..., description="附件类型") + url: str = Field(..., description="附件访问地址") + download_url: Optional[str] = Field(None, description="附件下载地址") + name: Optional[str] = Field(None, description="附件名称") + mime_type: Optional[str] = Field(None, description="MIME 类型") + size: Optional[int] = Field(None, description="附件大小") + local_path: Optional[str] = Field(None, description="服务端本地路径") + + +class AgentChatToolCall(BaseModel): + """ + Agent 会话工具调用展示项。 + """ + + id: str = Field(..., description="展示 ID") + message: str = Field(..., description="工具提示") + status: str = Field(default="done", description="工具状态") + + +class AgentChatChoiceButton(BaseModel): + """ + Agent 会话选择按钮。 + """ + + label: str = Field(..., description="按钮文案") + callback_data: str = Field(..., description="回调数据") + + +class AgentChatChoiceCard(BaseModel): + """ + Agent 会话选择卡片。 + """ + + id: str = Field(..., description="选择卡片 ID") + title: Optional[str] = Field(None, description="标题") + prompt: str = Field(default="", description="提示语") + buttons: list[AgentChatChoiceButton] = Field(default_factory=list, description="按钮列表") + status: str = Field(default="pending", description="选择状态") + selected_label: Optional[str] = Field(None, description="已选择文案") + selected_value: Optional[str] = Field(None, description="已选择值") + + +class AgentChatMessage(BaseModel): + """ + Agent 会话展示消息。 + """ + + id: str = Field(..., description="展示消息 ID") + role: str = Field(..., description="消息角色") + content: str = Field(default="", description="消息文本") + createdAt: Union[int, float] = Field(..., description="创建时间戳") + status: str = Field(default="done", description="消息状态") + tools: list[AgentChatToolCall] = Field(default_factory=list, description="工具提示列表") + attachments: list[AgentChatAttachment] = Field(default_factory=list, description="附件列表") + choices: list[AgentChatChoiceCard] = Field(default_factory=list, description="选择卡片列表") + + +class AgentChatSession(BaseModel): + """ + Agent 会话历史详情。 + """ + + id: Optional[int] = Field(None, description="数据库 ID") + session_id: str = Field(..., description="Agent 内部会话 ID") + client_session_id: Optional[str] = Field(None, description="客户端原始会话 ID") + title: Optional[str] = Field(None, description="会话标题") + preview: Optional[str] = Field(None, description="会话预览") + channel: Optional[str] = Field(None, description="消息渠道") + source: Optional[str] = Field(None, description="渠道来源") + user_id: Optional[str] = Field(None, description="用户 ID") + username: Optional[str] = Field(None, description="用户名") + original_chat_id: Optional[str] = Field(None, description="原聊天 ID") + message_count: int = Field(default=0, description="展示消息数量") + created_at: Optional[str] = Field(None, description="创建时间") + updated_at: Optional[str] = Field(None, description="更新时间") + messages: list[AgentChatMessage] = Field(default_factory=list, description="展示消息列表") + + +class AgentChatDisplaySaveRequest(BaseModel): + """ + Agent 会话展示消息保存请求。 + """ + + messages: list[AgentChatMessage] = Field(default_factory=list, description="展示消息列表") + title: Optional[str] = Field(None, description="会话标题") diff --git a/app/schemas/message.py b/app/schemas/message.py index 400a6b51..d3624ad1 100644 --- a/app/schemas/message.py +++ b/app/schemas/message.py @@ -310,6 +310,8 @@ class AgentWebChatRequest(BaseModel): audio_refs: Optional[List[str]] = Field(default_factory=list) # 文件附件列表 files: Optional[List[AgentWebChatFile]] = Field(default_factory=list) + # 是否在展示历史中记录本轮用户消息 + echo_user: bool = Field(default=True) class AgentWebChoiceRequest(BaseModel): diff --git a/database/versions/8ab72c49d1e3_2_2_10.py b/database/versions/8ab72c49d1e3_2_2_10.py new file mode 100644 index 00000000..71020cd2 --- /dev/null +++ b/database/versions/8ab72c49d1e3_2_2_10.py @@ -0,0 +1,73 @@ +"""2.2.10 +新增 Agent 会话历史表 + +Revision ID: 8ab72c49d1e3 +Revises: 7c1a2b3d4e5f +Create Date: 2026-06-18 +""" + +from alembic import op +import sqlalchemy as sa + +revision = "8ab72c49d1e3" +down_revision = "7c1a2b3d4e5f" +branch_labels = None +depends_on = None + + +def _has_table(inspector: sa.Inspector, table_name: str) -> bool: + """检查数据表是否已存在。""" + return table_name in inspector.get_table_names() + + +def upgrade() -> None: + """升级数据库结构。""" + inspector = sa.inspect(op.get_bind()) + if _has_table(inspector, "agentchat"): + return + + op.create_table( + "agentchat", + sa.Column("id", sa.Integer(), nullable=False), + sa.Column("session_id", sa.String(), nullable=False), + sa.Column("client_session_id", sa.String(), nullable=True), + sa.Column("user_id", sa.String(), nullable=True), + sa.Column("username", sa.String(), nullable=True), + sa.Column("channel", sa.String(), nullable=True), + sa.Column("source", sa.String(), nullable=True), + sa.Column("original_chat_id", sa.String(), nullable=True), + sa.Column("title", sa.String(), nullable=True), + sa.Column("preview", sa.String(), nullable=True), + sa.Column("agent_messages", sa.JSON(), nullable=True), + sa.Column("display_messages", sa.JSON(), nullable=True), + sa.Column("message_count", sa.Integer(), nullable=True), + sa.Column("created_at", sa.String(), nullable=True), + sa.Column("updated_at", sa.String(), nullable=True), + sa.PrimaryKeyConstraint("id"), + ) + op.create_index( + "ix_agentchat_session_user", + "agentchat", + ["session_id", "user_id"], + ) + op.create_index( + "ix_agentchat_user_updated", + "agentchat", + ["user_id", "updated_at", "id"], + ) + op.create_index( + "ix_agentchat_channel_updated", + "agentchat", + ["channel", "updated_at", "id"], + ) + + +def downgrade() -> None: + """回滚数据库结构。""" + inspector = sa.inspect(op.get_bind()) + if not _has_table(inspector, "agentchat"): + return + op.drop_index("ix_agentchat_channel_updated", table_name="agentchat") + op.drop_index("ix_agentchat_user_updated", table_name="agentchat") + op.drop_index("ix_agentchat_session_user", table_name="agentchat") + op.drop_table("agentchat") diff --git a/tests/test_agent_chat_history.py b/tests/test_agent_chat_history.py new file mode 100644 index 00000000..01c96d0d --- /dev/null +++ b/tests/test_agent_chat_history.py @@ -0,0 +1,141 @@ +import asyncio +from types import SimpleNamespace + +from langchain_core.messages import HumanMessage + +from app.agent import MoviePilotAgent +from app.agent.memory import memory_manager +from app.db.agentchat_oper import AgentChatOper + + +def test_agent_chat_oper_saves_display_messages_with_channel(): + """Agent 会话历史应保存展示消息与渠道标识。""" + oper = AgentChatOper() + oper.save_display_messages( + session_id="session-chat", + user_id="1", + username="admin", + channel="Telegram", + source="telegram-main", + original_chat_id="chat-1", + messages=[ + { + "id": "user-1", + "role": "user", + "content": "帮我看看下载器", + "createdAt": 1, + "status": "done", + "tools": [], + "attachments": [], + "choices": [], + } + ], + ) + chat = AgentChatOper().get(session_id="session-chat", user_id="1") + + assert chat.channel == "Telegram" + assert chat.source == "telegram-main" + assert chat.original_chat_id == "chat-1" + assert chat.message_count == 1 + assert chat.title == "帮我看看下载器" + + +def test_agent_chat_oper_keeps_generated_title_when_saving_display_messages(): + """保存展示消息时不应覆盖已生成的模型标题。""" + oper = AgentChatOper() + oper.update_title_if_empty( + session_id="session-title", + user_id="1", + username="admin", + channel="WebAgent", + source="web-agent", + title="下载器状态排查", + ) + oper.save_display_messages( + session_id="session-title", + user_id="1", + messages=[ + { + "id": "user-1", + "role": "user", + "content": "帮我看看下载器现在是不是正常", + "createdAt": 1, + "status": "done", + "tools": [], + "attachments": [], + "choices": [], + } + ], + title="帮我看看下载器现在是不是正常", + ) + + chat = AgentChatOper().get(session_id="session-title", user_id="1") + summary = AgentChatOper.to_summary(chat) + + assert chat.title == "下载器状态排查" + assert "preview" not in summary + assert "messages" not in summary + + +def test_agent_prepare_chat_title_generates_title(monkeypatch): + """首次调用 Agent 时应使用模型生成会话标题并写入渠道信息。""" + + class FakeTitleModel: + """测试用标题模型。""" + + async def ainvoke(self, messages): + """返回固定标题。""" + assert "标题生成器" in messages[0].content + assert messages[1].content == "帮我看看下载器现在是不是正常" + return SimpleNamespace(content="「下载器状态排查」") + + async def fake_initialize_llm(self, streaming=False): + """返回测试标题模型。""" + return FakeTitleModel() + + monkeypatch.setattr(MoviePilotAgent, "_initialize_llm", fake_initialize_llm) + agent = MoviePilotAgent( + session_id="session-ai-title", + user_id="3", + channel="WebAgent", + source="web-agent", + username="admin", + ) + + asyncio.run(agent.prepare_chat_title("帮我看看下载器现在是不是正常")) + chat = AgentChatOper().get(session_id="session-ai-title", user_id="3") + + assert chat.title == "下载器状态排查" + assert chat.channel == "WebAgent" + assert chat.source == "web-agent" + + +def test_memory_manager_restores_agent_messages_from_database(): + """内存缓存缺失时应从 Agent 会话历史表恢复原始 messages。""" + session_id = "session-memory" + user_id = "2" + memory_manager.clear_memory(session_id, user_id) + AgentChatOper().save_agent_messages( + session_id=session_id, + user_id=user_id, + messages=[ + { + "type": "human", + "data": { + "content": "继续之前的话题", + "additional_kwargs": {}, + "response_metadata": {}, + "type": "human", + "name": None, + "id": None, + "example": False, + }, + } + ], + ) + + messages = memory_manager.get_agent_messages(session_id=session_id, user_id=user_id) + + assert len(messages) == 1 + assert isinstance(messages[0], HumanMessage) + assert messages[0].content == "继续之前的话题" diff --git a/tests/test_web_agent_stream.py b/tests/test_web_agent_stream.py index 34fe6ffc..2e5841e9 100644 --- a/tests/test_web_agent_stream.py +++ b/tests/test_web_agent_stream.py @@ -8,6 +8,8 @@ from app.agent import ReplyMode from app.api.endpoints.agent import ( _WebAgentMoviePilotAgent, _WEB_AGENT_FILE_REGISTRY, + _apply_web_agent_display_event, + _build_web_agent_input_attachments, _build_web_agent_notification_events, _build_web_agent_session_id, _prepare_web_agent_audio_attachment_path, @@ -16,6 +18,7 @@ from app.api.endpoints.agent import ( _resolve_web_agent_choice_payload, _split_web_agent_output, ) +from app.db.agentchat_oper import AgentChatOper from app.helper.interaction import AgentInteractionOption, agent_interaction_manager from app.schemas.message import ChannelCapability, ChannelCapabilityManager from app.schemas.types import MessageChannel, NotificationType @@ -74,6 +77,73 @@ def test_build_web_agent_session_id_is_stable_per_user_and_seed(): assert first.startswith("web-agent:") +def test_build_web_agent_session_id_reuses_accessible_history(): + """传入已有历史会话 ID 时应直接复用,避免跨渠道继续对话丢上下文。""" + user = SimpleNamespace(id=1, name="admin", is_superuser=True) + AgentChatOper().save_display_messages( + session_id="telegram-session", + user_id="telegram-user", + username="tester", + channel=MessageChannel.Telegram.value, + source="telegram-main", + messages=[], + title="Telegram 会话", + ) + + assert _build_web_agent_session_id(user, "telegram-session") == "telegram-session" + + +def test_apply_web_agent_display_event_updates_snapshot(): + """WebAgent SSE 事件应可聚合为服务端展示快照。""" + message = { + "id": "assistant-1", + "role": "assistant", + "content": "", + "createdAt": 1, + "status": "streaming", + "tools": [], + "attachments": [], + "choices": [], + } + + _apply_web_agent_display_event({"type": "delta", "content": "你好"}, message) + _apply_web_agent_display_event({"type": "tool", "message": "查询订阅"}, message) + _apply_web_agent_display_event( + { + "type": "attachment", + "attachment": {"kind": "file", "url": "message/agent/file/a"}, + }, + message, + ) + _apply_web_agent_display_event({"type": "done"}, message) + + assert message["content"] == "你好" + assert message["status"] == "done" + assert len(message["tools"]) == 1 + assert message["tools"][0]["message"] == "查询订阅" + assert message["tools"][0]["status"] == "done" + assert message["attachments"] == [{"kind": "file", "url": "message/agent/file/a"}] + + +def test_build_web_agent_input_attachments_marks_kinds(): + """WebAgent 用户输入附件应转换为可展示的附件记录。""" + attachments = _build_web_agent_input_attachments( + images=["data:image/png;base64,abc"], + files=[ + { + "ref": "message/agent/file/file-1", + "name": "report.txt", + "mime_type": "text/plain", + "size": 5, + } + ], + audio_refs=["message/agent/file/audio-1"], + ) + + assert [item["kind"] for item in attachments] == ["image", "file", "audio"] + assert attachments[1]["name"] == "report.txt" + + def test_web_agent_admin_context_uses_current_user_id(): """Web Agent 工具权限应按当前登录用户 ID 判断管理员身份。""" agent = _WebAgentMoviePilotAgent(