mirror of
https://github.com/jxxghp/MoviePilot.git
synced 2026-06-21 07:24:29 +08:00
feat: persist agent chat history
This commit is contained in:
@@ -16,6 +16,7 @@ from langchain.agents.middleware import (
|
||||
from langchain_core.messages import ( # noqa: F401
|
||||
HumanMessage,
|
||||
BaseMessage,
|
||||
SystemMessage,
|
||||
)
|
||||
|
||||
import warnings
|
||||
@@ -50,6 +51,7 @@ from app.agent.tools.factory import MoviePilotToolFactory
|
||||
from app.chain import ChainBase
|
||||
from app.core.config import settings
|
||||
from app.core.event import eventmanager
|
||||
from app.db.agentchat_oper import AgentChatOper
|
||||
from app.db.user_oper import UserOper
|
||||
from app.log import logger
|
||||
from app.schemas import AgentLLMProviderEventData, AgentTokensUsageEventData, Notification, NotificationType
|
||||
@@ -229,6 +231,11 @@ HEARTBEAT_SESSION_PREFIX = "__agent_heartbeat_"
|
||||
UNSUPPORTED_IMAGE_INPUT_MESSAGE = "当前模型不支持图片输入,请更换支持图片输入的模型,或在系统设置中关闭图片输入支持后重试。"
|
||||
AGENT_EXECUTION_ERROR_PREFIX = "智能助手执行失败"
|
||||
AGENT_EXECUTION_ERROR_MESSAGE = "智能助手执行失败,请稍后重试。"
|
||||
AGENT_DISPLAY_HISTORY_SKIP_CHANNELS = {MessageChannel.WebAgent.value}
|
||||
AGENT_CHAT_TITLE_PROMPT = (
|
||||
"你是 MoviePilot 智能助手的会话标题生成器。请根据用户的第一条消息生成一个简洁中文标题,"
|
||||
"不超过 18 个汉字或 36 个英文字符,只输出标题本身,不要引号、编号或解释。"
|
||||
)
|
||||
|
||||
|
||||
class MoviePilotAgent:
|
||||
@@ -269,6 +276,129 @@ class MoviePilotAgent:
|
||||
# 流式token管理
|
||||
self.stream_handler = StreamingHandler()
|
||||
|
||||
@staticmethod
|
||||
def _current_timestamp_ms() -> int:
|
||||
"""返回当前毫秒时间戳。"""
|
||||
return int(datetime.now().timestamp() * 1000)
|
||||
|
||||
@classmethod
|
||||
def build_display_message(
|
||||
cls,
|
||||
role: str,
|
||||
content: str = "",
|
||||
attachments: Optional[List[dict]] = None,
|
||||
status: str = "done",
|
||||
) -> dict[str, Any]:
|
||||
"""
|
||||
构造可展示的 Agent 会话消息。
|
||||
"""
|
||||
return {
|
||||
"id": f"{role}-{uuid.uuid4().hex}",
|
||||
"role": role,
|
||||
"content": content or "",
|
||||
"createdAt": cls._current_timestamp_ms(),
|
||||
"status": status,
|
||||
"tools": [],
|
||||
"attachments": attachments or [],
|
||||
"choices": [],
|
||||
}
|
||||
|
||||
def _should_save_display_history(self) -> bool:
|
||||
"""
|
||||
判断当前 Agent 是否由通用渠道保存展示历史。
|
||||
"""
|
||||
return bool(
|
||||
self.channel
|
||||
and self.source
|
||||
and self.channel not in AGENT_DISPLAY_HISTORY_SKIP_CHANNELS
|
||||
)
|
||||
|
||||
def _save_display_history_messages(self, messages: List[dict]) -> None:
|
||||
"""
|
||||
将一组可见消息追加到 Agent 会话历史表。
|
||||
"""
|
||||
if not messages or not self._should_save_display_history():
|
||||
return
|
||||
try:
|
||||
AgentChatOper().append_display_messages(
|
||||
session_id=self.session_id,
|
||||
user_id=self.user_id,
|
||||
username=self.username,
|
||||
channel=self.channel,
|
||||
source=self.source,
|
||||
original_chat_id=self.original_chat_id,
|
||||
messages=messages,
|
||||
)
|
||||
except Exception as e:
|
||||
logger.debug(f"写入Agent展示历史失败: {e}")
|
||||
|
||||
def _save_assistant_display_message_once(self, message: str) -> None:
|
||||
"""
|
||||
保存一条助手回复展示记录,并标记本轮已写入。
|
||||
"""
|
||||
if not message or self._tool_context.get("assistant_display_saved"):
|
||||
return
|
||||
self._save_display_history_messages(
|
||||
[self.build_display_message(role="assistant", content=message)]
|
||||
)
|
||||
self._tool_context["assistant_display_saved"] = True
|
||||
|
||||
@staticmethod
|
||||
def _sanitize_chat_title(value: str) -> str:
|
||||
"""清理模型返回的会话标题。"""
|
||||
title = str(value or "").strip()
|
||||
title = re.sub(r"^[#\-*\d.、\s]+", "", title)
|
||||
title = title.strip("「」『』“”\"'` \n\t")
|
||||
title = re.sub(r"\s+", " ", title)
|
||||
return title[:120]
|
||||
|
||||
async def _generate_chat_title(self, message: str) -> str:
|
||||
"""
|
||||
使用当前 Agent 模型生成会话标题。
|
||||
"""
|
||||
if not str(message or "").strip():
|
||||
return ""
|
||||
model = await self._initialize_llm(streaming=False)
|
||||
response = await model.ainvoke(
|
||||
[
|
||||
SystemMessage(content=AGENT_CHAT_TITLE_PROMPT),
|
||||
HumanMessage(content=str(message).strip()[:1000]),
|
||||
]
|
||||
)
|
||||
content = LLMHelper._extract_text_content(getattr(response, "content", response))
|
||||
return self._sanitize_chat_title(content)
|
||||
|
||||
async def prepare_chat_title(self, message: str) -> None:
|
||||
"""
|
||||
首次对话时生成并保存会话标题。
|
||||
"""
|
||||
if self._tool_context.get("chat_title_prepared"):
|
||||
return
|
||||
self._tool_context["chat_title_prepared"] = True
|
||||
try:
|
||||
chat = await run_in_threadpool(
|
||||
AgentChatOper().get,
|
||||
session_id=self.session_id,
|
||||
user_id=self.user_id,
|
||||
)
|
||||
if chat and AgentChatOper.has_custom_title(chat.title):
|
||||
return
|
||||
title = await self._generate_chat_title(message)
|
||||
if not title:
|
||||
return
|
||||
await run_in_threadpool(
|
||||
AgentChatOper().update_title_if_empty,
|
||||
session_id=self.session_id,
|
||||
user_id=self.user_id,
|
||||
title=title,
|
||||
username=self.username,
|
||||
channel=self.channel,
|
||||
source=self.source,
|
||||
original_chat_id=self.original_chat_id,
|
||||
)
|
||||
except Exception as e:
|
||||
logger.debug(f"生成Agent会话标题失败: {e}")
|
||||
|
||||
@staticmethod
|
||||
def _coerce_int(value: Any) -> Optional[int]:
|
||||
if value is None:
|
||||
@@ -944,6 +1074,7 @@ class MoviePilotAgent:
|
||||
"""
|
||||
处理用户消息,流式推理并返回 Agent 回复
|
||||
"""
|
||||
user_display_saved = False
|
||||
try:
|
||||
logger.info(
|
||||
f"Agent推理: session_id={self.session_id}, input={message}, "
|
||||
@@ -982,6 +1113,21 @@ class MoviePilotAgent:
|
||||
for img in images or []:
|
||||
content.append({"type": "image_url", "image_url": {"url": img}})
|
||||
messages.append(HumanMessage(content=content))
|
||||
await self.prepare_chat_title(message)
|
||||
self._save_display_history_messages(
|
||||
[
|
||||
self.build_display_message(
|
||||
role="user",
|
||||
content=message,
|
||||
attachments=self._build_input_display_attachments(
|
||||
images=images,
|
||||
files=files,
|
||||
has_audio_input=has_audio_input,
|
||||
),
|
||||
)
|
||||
]
|
||||
)
|
||||
user_display_saved = True
|
||||
|
||||
# 执行推理
|
||||
result = await self._execute_agent(messages)
|
||||
@@ -992,11 +1138,63 @@ class MoviePilotAgent:
|
||||
except Exception as e:
|
||||
error_message = f"处理消息时发生错误: {str(e)}"
|
||||
logger.error(error_message)
|
||||
if not user_display_saved:
|
||||
self._save_display_history_messages(
|
||||
[self.build_display_message(role="user", content=message)]
|
||||
)
|
||||
if not self.should_dispatch_reply:
|
||||
raise
|
||||
await self.send_agent_message(error_message)
|
||||
return error_message
|
||||
|
||||
@staticmethod
|
||||
def _guess_file_attachment_kind(mime_type: Optional[str], fallback: str = "file") -> str:
|
||||
"""
|
||||
根据 MIME 类型推断展示附件类型。
|
||||
"""
|
||||
if mime_type and mime_type.startswith("image/"):
|
||||
return "image"
|
||||
if mime_type and mime_type.startswith("audio/"):
|
||||
return "audio"
|
||||
return fallback
|
||||
|
||||
def _build_input_display_attachments(
|
||||
self,
|
||||
images: Optional[List[str]] = None,
|
||||
files: Optional[List[dict]] = None,
|
||||
has_audio_input: bool = False,
|
||||
) -> List[dict]:
|
||||
"""
|
||||
构造用户输入附件的展示记录。
|
||||
"""
|
||||
attachments: List[dict] = []
|
||||
for index, image in enumerate(images or [], start=1):
|
||||
attachments.append(
|
||||
{
|
||||
"kind": "image",
|
||||
"url": image,
|
||||
"download_url": image,
|
||||
"name": f"image-{index}",
|
||||
"mime_type": "image/*",
|
||||
}
|
||||
)
|
||||
for index, file in enumerate(files or [], start=1):
|
||||
ref = file.get("ref") or file.get("local_path") or ""
|
||||
mime_type = file.get("mime_type")
|
||||
fallback = "audio" if has_audio_input and mime_type == "audio/*" else "file"
|
||||
attachments.append(
|
||||
{
|
||||
"kind": self._guess_file_attachment_kind(mime_type, fallback=fallback),
|
||||
"url": ref,
|
||||
"download_url": ref,
|
||||
"name": file.get("name") or f"attachment-{index}",
|
||||
"mime_type": mime_type,
|
||||
"size": file.get("size"),
|
||||
"local_path": file.get("local_path"),
|
||||
}
|
||||
)
|
||||
return attachments
|
||||
|
||||
async def _stream_agent_tokens(
|
||||
self, agent, messages: dict, config: dict, on_token: Callable[[str], None]
|
||||
):
|
||||
@@ -1164,6 +1362,17 @@ class MoviePilotAgent:
|
||||
# 非流式渠道:发送最终回复
|
||||
await self.send_agent_message(final_text)
|
||||
|
||||
display_text = self._streamed_output
|
||||
if not display_text:
|
||||
final_messages = agent.get_state(agent_config).values.get(
|
||||
"messages", []
|
||||
)
|
||||
for msg in reversed(final_messages):
|
||||
if hasattr(msg, "type") and msg.type == "ai" and msg.content:
|
||||
display_text = self._extract_text_content(msg.content).strip()
|
||||
break
|
||||
self._save_assistant_display_message_once(display_text)
|
||||
|
||||
# 保存消息
|
||||
memory_manager.save_agent_messages(
|
||||
session_id=self.session_id,
|
||||
@@ -1200,6 +1409,7 @@ class MoviePilotAgent:
|
||||
"""
|
||||
通过原渠道发送消息给用户
|
||||
"""
|
||||
self._save_assistant_display_message_once(message)
|
||||
await AgentChain().async_post_message(
|
||||
Notification(
|
||||
channel=self.channel,
|
||||
|
||||
@@ -4,9 +4,10 @@ import asyncio
|
||||
from datetime import datetime
|
||||
from typing import Dict, List, Optional
|
||||
|
||||
from langchain_core.messages import BaseMessage
|
||||
from langchain_core.messages import BaseMessage, messages_from_dict, messages_to_dict
|
||||
|
||||
from app.core.config import settings
|
||||
from app.db.agentchat_oper import AgentChatOper
|
||||
from app.log import logger
|
||||
from app.schemas.agent import ConversationMemory
|
||||
|
||||
@@ -70,24 +71,43 @@ class MemoryManager:
|
||||
self, session_id: str, user_id: str
|
||||
) -> List[BaseMessage]:
|
||||
"""
|
||||
为Agent获取最近的消息(仅内存缓存)
|
||||
为Agent获取最近的消息。
|
||||
|
||||
如果消息Token数量超过模型最大上下文长度的阀值,会自动进行摘要裁剪
|
||||
优先使用内存缓存,缓存不存在时从数据库恢复上一轮持久化的原始 messages。
|
||||
"""
|
||||
memory = self.get_memory(session_id, user_id)
|
||||
if not memory:
|
||||
if memory:
|
||||
return memory.messages
|
||||
|
||||
try:
|
||||
chat = AgentChatOper().get(session_id=session_id, user_id=user_id)
|
||||
if not chat:
|
||||
chat = AgentChatOper().get(session_id=session_id)
|
||||
except Exception as e:
|
||||
logger.debug(f"读取持久化Agent会话失败: {e}")
|
||||
return []
|
||||
if not chat or not chat.agent_messages:
|
||||
return []
|
||||
|
||||
# 获取所有消息
|
||||
try:
|
||||
messages = messages_from_dict(chat.agent_messages)
|
||||
except Exception as e:
|
||||
logger.debug(f"恢复持久化Agent消息失败: {e}")
|
||||
return []
|
||||
|
||||
memory = ConversationMemory(
|
||||
session_id=session_id,
|
||||
user_id=user_id,
|
||||
messages=messages,
|
||||
)
|
||||
self.save_memory(memory)
|
||||
return memory.messages
|
||||
|
||||
def save_agent_messages(
|
||||
self, session_id: str, user_id: str, messages: List[BaseMessage]
|
||||
):
|
||||
"""
|
||||
保存Agent消息(仅内存缓存)
|
||||
|
||||
注意:Redis中的记忆通过TTL机制自动过期,这里只更新内存缓存,Redis会在下次访问时自动过期
|
||||
保存Agent消息到内存缓存与持久化会话表。
|
||||
"""
|
||||
memory = self.get_memory(session_id, user_id)
|
||||
if not memory:
|
||||
@@ -98,6 +118,14 @@ class MemoryManager:
|
||||
|
||||
# 更新内存缓存
|
||||
self.save_memory(memory)
|
||||
try:
|
||||
AgentChatOper().save_agent_messages(
|
||||
session_id=session_id,
|
||||
user_id=user_id,
|
||||
messages=messages_to_dict(messages),
|
||||
)
|
||||
except Exception as e:
|
||||
logger.debug(f"持久化Agent消息失败: {e}")
|
||||
|
||||
def save_memory(self, memory: ConversationMemory):
|
||||
"""
|
||||
|
||||
@@ -10,13 +10,18 @@ from pathlib import Path
|
||||
from typing import Any, AsyncIterator, Callable, Optional
|
||||
|
||||
from fastapi import APIRouter, Depends, File, Form, HTTPException, Request, UploadFile, status
|
||||
from fastapi.concurrency import run_in_threadpool
|
||||
from fastapi.responses import FileResponse, StreamingResponse
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app import schemas
|
||||
from app.agent import MoviePilotAgent, ReplyMode, StreamingHandler
|
||||
from app.agent.llm.capability import AgentCapabilityManager
|
||||
from app.core.config import global_vars, settings
|
||||
from app.db import get_async_db
|
||||
from app.db.agentchat_oper import AgentChatOper
|
||||
from app.db.models import User
|
||||
from app.db.models.agentchat import AgentChat
|
||||
from app.db.user_oper import UserOper, get_current_active_user
|
||||
from app.helper.interaction import agent_interaction_manager
|
||||
from app.log import logger
|
||||
@@ -152,11 +157,122 @@ def _build_web_agent_session_id(user: User, session_id: Optional[str]) -> str:
|
||||
:return: 可用于 Agent 记忆隔离的服务端会话 ID
|
||||
"""
|
||||
seed = str(session_id or "").strip() or uuid.uuid4().hex
|
||||
if seed.startswith(WEB_AGENT_SESSION_PREFIX):
|
||||
return seed
|
||||
try:
|
||||
existing_chat = AgentChatOper().get(session_id=seed)
|
||||
if existing_chat and _can_access_agent_chat(existing_chat, user):
|
||||
return seed
|
||||
except Exception as e:
|
||||
logger.debug(f"读取WebAgent历史会话失败: {e}")
|
||||
user_part = user.name or str(user.id)
|
||||
digest = hashlib.sha256(f"{user_part}:{seed}".encode("utf-8")).hexdigest()
|
||||
return f"{WEB_AGENT_SESSION_PREFIX}{digest[:32]}"
|
||||
|
||||
|
||||
def _can_access_agent_chat(chat: AgentChat, user: User) -> bool:
|
||||
"""
|
||||
判断当前登录用户是否可以访问指定 Agent 会话。
|
||||
|
||||
超级用户可查看所有渠道历史;普通用户仅能查看 user_id 或 username 匹配自己的会话。
|
||||
"""
|
||||
if not chat or not user:
|
||||
return False
|
||||
if getattr(user, "is_superuser", False):
|
||||
return True
|
||||
user_id = str(user.id)
|
||||
username = str(user.name or "")
|
||||
return chat.user_id == user_id or (bool(username) and chat.username == username)
|
||||
|
||||
|
||||
async def _get_accessible_agent_chat(
|
||||
oper: AgentChatOper, session_id: str, user: User
|
||||
) -> Optional[AgentChat]:
|
||||
"""
|
||||
读取当前用户可访问的 Agent 会话。
|
||||
"""
|
||||
chat = await oper.async_get(session_id=session_id)
|
||||
if not chat or not _can_access_agent_chat(chat, user):
|
||||
return None
|
||||
return chat
|
||||
|
||||
|
||||
def _apply_web_agent_display_event(event: dict, assistant_message: dict) -> None:
|
||||
"""
|
||||
将 WebAgent SSE 事件同步应用到服务端展示消息快照。
|
||||
"""
|
||||
event_type = event.get("type")
|
||||
if event_type == "delta":
|
||||
assistant_message["content"] += event.get("content") or ""
|
||||
elif event_type == "tool":
|
||||
for tool in assistant_message["tools"]:
|
||||
tool["status"] = "done"
|
||||
assistant_message["tools"].append(
|
||||
{
|
||||
"id": f"tool-{uuid.uuid4().hex}",
|
||||
"message": str(event.get("message") or "").replace("=>", "", 1).strip(),
|
||||
"status": "running",
|
||||
}
|
||||
)
|
||||
elif event_type == "attachment" and event.get("attachment"):
|
||||
assistant_message["attachments"].append(event["attachment"])
|
||||
elif event_type == "choice" and event.get("choice"):
|
||||
assistant_message["choices"].append({**event["choice"], "status": "pending"})
|
||||
elif event_type == "error":
|
||||
assistant_message["status"] = "error"
|
||||
assistant_message["content"] = (
|
||||
assistant_message["content"]
|
||||
or event.get("message")
|
||||
or "智能助手响应失败"
|
||||
)
|
||||
for tool in assistant_message["tools"]:
|
||||
tool["status"] = "done"
|
||||
elif event_type == "done":
|
||||
if assistant_message.get("status") != "error":
|
||||
assistant_message["status"] = "done"
|
||||
for tool in assistant_message["tools"]:
|
||||
tool["status"] = "done"
|
||||
|
||||
|
||||
def _save_web_agent_display_snapshot(
|
||||
*,
|
||||
session_id: str,
|
||||
current_user: User,
|
||||
messages: list[dict],
|
||||
client_session_id: Optional[str] = None,
|
||||
) -> None:
|
||||
"""
|
||||
保存 WebAgent 当前展示消息快照。
|
||||
"""
|
||||
try:
|
||||
oper = AgentChatOper()
|
||||
existing_chat = oper.get(session_id=session_id)
|
||||
AgentChatOper().save_display_messages(
|
||||
session_id=session_id,
|
||||
user_id=(existing_chat.user_id if existing_chat else str(current_user.id)),
|
||||
username=(existing_chat.username if existing_chat else current_user.name),
|
||||
channel=(
|
||||
existing_chat.channel
|
||||
if existing_chat and existing_chat.channel
|
||||
else MessageChannel.WebAgent
|
||||
),
|
||||
source=(
|
||||
existing_chat.source
|
||||
if existing_chat and existing_chat.source
|
||||
else WEB_AGENT_SOURCE
|
||||
),
|
||||
original_chat_id=existing_chat.original_chat_id if existing_chat else None,
|
||||
client_session_id=(
|
||||
existing_chat.client_session_id
|
||||
if existing_chat and existing_chat.client_session_id
|
||||
else client_session_id
|
||||
),
|
||||
messages=messages,
|
||||
)
|
||||
except Exception as e:
|
||||
logger.debug(f"保存WebAgent展示历史失败: {e}")
|
||||
|
||||
|
||||
def _build_web_agent_sse(event_type: str, data: Optional[dict] = None) -> str:
|
||||
"""
|
||||
构建 Web Agent SSE 消息。
|
||||
@@ -299,6 +415,52 @@ def _build_web_agent_url_attachment(
|
||||
}
|
||||
|
||||
|
||||
def _build_web_agent_input_attachments(
|
||||
images: list[str],
|
||||
files: list[dict],
|
||||
audio_refs: list[str],
|
||||
) -> list[dict]:
|
||||
"""
|
||||
构造 WebAgent 用户输入附件展示记录。
|
||||
"""
|
||||
attachments = []
|
||||
for index, image in enumerate(images or [], start=1):
|
||||
attachments.append(
|
||||
{
|
||||
"kind": "image",
|
||||
"url": image,
|
||||
"download_url": image,
|
||||
"name": f"image-{index}",
|
||||
"mime_type": "image/*",
|
||||
}
|
||||
)
|
||||
for index, file in enumerate(files or [], start=1):
|
||||
ref = file.get("ref") or file.get("url") or file.get("local_path") or ""
|
||||
mime_type = file.get("mime_type")
|
||||
attachments.append(
|
||||
{
|
||||
"kind": _guess_web_agent_attachment_kind(mime_type),
|
||||
"url": ref,
|
||||
"download_url": ref,
|
||||
"name": file.get("name") or f"attachment-{index}",
|
||||
"mime_type": mime_type,
|
||||
"size": file.get("size"),
|
||||
"local_path": file.get("local_path"),
|
||||
}
|
||||
)
|
||||
for index, audio_ref in enumerate(audio_refs or [], start=1):
|
||||
attachments.append(
|
||||
{
|
||||
"kind": "audio",
|
||||
"url": audio_ref,
|
||||
"download_url": audio_ref,
|
||||
"name": f"voice-{index}",
|
||||
"mime_type": "audio/*",
|
||||
}
|
||||
)
|
||||
return attachments
|
||||
|
||||
|
||||
def _register_web_agent_file(
|
||||
file_path: Optional[str],
|
||||
file_name: Optional[str] = None,
|
||||
@@ -790,6 +952,116 @@ async def web_agent_callback(
|
||||
return schemas.Response(success=True, data=result)
|
||||
|
||||
|
||||
@router.get("/sessions", summary="获取 Agent 历史会话", response_model=schemas.Response)
|
||||
async def list_agent_chat_sessions(
|
||||
current_user: User = Depends(get_current_active_user),
|
||||
db: AsyncSession = Depends(get_async_db),
|
||||
page: Optional[int] = 1,
|
||||
count: Optional[int] = 30,
|
||||
) -> schemas.Response:
|
||||
"""
|
||||
获取当前用户可访问的 Agent 历史会话列表。
|
||||
|
||||
:param current_user: 当前登录用户
|
||||
:param db: 异步数据库会话
|
||||
:param page: 页码
|
||||
:param count: 每页数量
|
||||
:return: 会话摘要列表
|
||||
"""
|
||||
user_id = None if current_user.is_superuser else str(current_user.id)
|
||||
username = None if current_user.is_superuser else current_user.name
|
||||
chats = await AgentChatOper(db).async_list_by_page(
|
||||
page=page,
|
||||
count=count,
|
||||
user_id=user_id,
|
||||
username=username,
|
||||
)
|
||||
return schemas.Response(
|
||||
success=True,
|
||||
data=[AgentChatOper.to_summary(chat) for chat in chats],
|
||||
)
|
||||
|
||||
|
||||
@router.get("/sessions/{session_id}", summary="获取 Agent 历史会话详情", response_model=schemas.Response)
|
||||
async def get_agent_chat_session(
|
||||
session_id: str,
|
||||
current_user: User = Depends(get_current_active_user),
|
||||
db: AsyncSession = Depends(get_async_db),
|
||||
) -> schemas.Response:
|
||||
"""
|
||||
获取一条 Agent 历史会话详情。
|
||||
|
||||
:param session_id: Agent 会话 ID
|
||||
:param current_user: 当前登录用户
|
||||
:param db: 异步数据库会话
|
||||
:return: 会话详情
|
||||
"""
|
||||
chat = await _get_accessible_agent_chat(AgentChatOper(db), session_id, current_user)
|
||||
if not chat:
|
||||
return schemas.Response(success=False, message="会话不存在或无权访问")
|
||||
return schemas.Response(success=True, data=AgentChatOper.to_detail(chat))
|
||||
|
||||
|
||||
@router.put("/sessions/{session_id}/display", summary="保存 Agent 展示会话", response_model=schemas.Response)
|
||||
async def save_agent_chat_display(
|
||||
session_id: str,
|
||||
payload: schemas.AgentChatDisplaySaveRequest,
|
||||
current_user: User = Depends(get_current_active_user),
|
||||
db: AsyncSession = Depends(get_async_db),
|
||||
) -> schemas.Response:
|
||||
"""
|
||||
保存前端聚合后的 Agent 展示消息。
|
||||
|
||||
:param session_id: Agent 会话 ID
|
||||
:param payload: 展示消息保存请求
|
||||
:param current_user: 当前登录用户
|
||||
:param db: 异步数据库会话
|
||||
:return: 保存后的会话摘要
|
||||
"""
|
||||
oper = AgentChatOper(db)
|
||||
existing_chat = await oper.async_get(session_id=session_id)
|
||||
if existing_chat and not _can_access_agent_chat(existing_chat, current_user):
|
||||
return schemas.Response(success=False, message="会话不存在或无权访问")
|
||||
|
||||
messages = [
|
||||
message.model_dump(exclude_none=True)
|
||||
for message in payload.messages
|
||||
]
|
||||
await run_in_threadpool(
|
||||
_save_web_agent_display_snapshot,
|
||||
session_id=session_id,
|
||||
current_user=current_user,
|
||||
messages=messages,
|
||||
client_session_id=existing_chat.client_session_id if existing_chat else session_id,
|
||||
)
|
||||
chat = await oper.async_get(session_id=session_id)
|
||||
if not chat:
|
||||
return schemas.Response(success=False, message="会话保存失败")
|
||||
return schemas.Response(success=True, data=AgentChatOper.to_summary(chat))
|
||||
|
||||
|
||||
@router.delete("/sessions/{session_id}", summary="删除 Agent 历史会话", response_model=schemas.Response)
|
||||
async def delete_agent_chat_session(
|
||||
session_id: str,
|
||||
current_user: User = Depends(get_current_active_user),
|
||||
db: AsyncSession = Depends(get_async_db),
|
||||
) -> schemas.Response:
|
||||
"""
|
||||
删除一条 Agent 历史会话。
|
||||
|
||||
:param session_id: Agent 会话 ID
|
||||
:param current_user: 当前登录用户
|
||||
:param db: 异步数据库会话
|
||||
:return: 删除结果
|
||||
"""
|
||||
oper = AgentChatOper(db)
|
||||
chat = await _get_accessible_agent_chat(oper, session_id, current_user)
|
||||
if not chat:
|
||||
return schemas.Response(success=False, message="会话不存在或无权访问")
|
||||
deleted = await oper.async_delete(session_id=session_id)
|
||||
return schemas.Response(success=deleted, message="删除成功" if deleted else "删除失败")
|
||||
|
||||
|
||||
@router.post("/stream", summary="Web智能助手流式对话")
|
||||
async def web_agent_stream(
|
||||
payload: schemas.AgentWebChatRequest,
|
||||
@@ -843,6 +1115,28 @@ async def web_agent_stream(
|
||||
session_id = _build_web_agent_session_id(current_user, payload.session_id)
|
||||
event_queue: asyncio.Queue = asyncio.Queue()
|
||||
last_output = ""
|
||||
user_attachments = _build_web_agent_input_attachments(
|
||||
images=payload.images or [],
|
||||
files=[
|
||||
file.model_dump(exclude_none=True)
|
||||
for file in (payload.files or [])
|
||||
],
|
||||
audio_refs=payload.audio_refs or [],
|
||||
)
|
||||
display_messages = []
|
||||
if payload.echo_user:
|
||||
display_messages.append(
|
||||
MoviePilotAgent.build_display_message(
|
||||
role="user",
|
||||
content=prompt,
|
||||
attachments=user_attachments,
|
||||
)
|
||||
)
|
||||
assistant_display_message = MoviePilotAgent.build_display_message(
|
||||
role="assistant",
|
||||
status="streaming",
|
||||
)
|
||||
display_messages.append(assistant_display_message)
|
||||
|
||||
def output_callback(output: str) -> None:
|
||||
"""
|
||||
@@ -852,6 +1146,7 @@ async def web_agent_stream(
|
||||
delta = output[len(last_output):] if output.startswith(last_output) else output
|
||||
last_output = output
|
||||
for item in _split_web_agent_output(delta):
|
||||
_apply_web_agent_display_event(item, assistant_display_message)
|
||||
event_queue.put_nowait(item)
|
||||
|
||||
def notification_callback(notification: schemas.Notification) -> None:
|
||||
@@ -859,6 +1154,7 @@ async def web_agent_stream(
|
||||
接收 Agent 工具主动发送的 Web 通知。
|
||||
"""
|
||||
for item in _build_web_agent_notification_events(notification):
|
||||
_apply_web_agent_display_event(item, assistant_display_message)
|
||||
event_queue.put_nowait(item)
|
||||
|
||||
async def event_generator() -> AsyncIterator[str]:
|
||||
@@ -897,17 +1193,29 @@ async def web_agent_stream(
|
||||
)
|
||||
except Exception as err:
|
||||
logger.error(f"Web智能助手执行失败: {str(err)}")
|
||||
await event_queue.put(
|
||||
{"type": "error", "message": f"智能助手执行失败: {str(err)}"}
|
||||
)
|
||||
error_event = {
|
||||
"type": "error",
|
||||
"message": f"智能助手执行失败: {str(err)}",
|
||||
}
|
||||
_apply_web_agent_display_event(error_event, assistant_display_message)
|
||||
await event_queue.put(error_event)
|
||||
finally:
|
||||
await event_queue.put({"type": "done"})
|
||||
done_event = {"type": "done"}
|
||||
_apply_web_agent_display_event(done_event, assistant_display_message)
|
||||
await run_in_threadpool(
|
||||
_save_web_agent_display_snapshot,
|
||||
session_id=session_id,
|
||||
current_user=current_user,
|
||||
messages=display_messages,
|
||||
client_session_id=payload.session_id or session_id,
|
||||
)
|
||||
await event_queue.put(done_event)
|
||||
|
||||
task = asyncio.create_task(run_agent())
|
||||
try:
|
||||
yield _build_web_agent_sse(
|
||||
"start",
|
||||
{"session_id": payload.session_id or session_id},
|
||||
{"session_id": session_id},
|
||||
)
|
||||
while not global_vars.is_system_stopped:
|
||||
if await request.is_disconnected():
|
||||
|
||||
336
app/db/agentchat_oper.py
Normal file
336
app/db/agentchat_oper.py
Normal file
@@ -0,0 +1,336 @@
|
||||
import time
|
||||
from typing import Any, Optional, Union
|
||||
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from app.db import DbOper
|
||||
from app.db.models.agentchat import AgentChat
|
||||
from app.schemas.types import MessageChannel
|
||||
|
||||
DEFAULT_AGENT_CHAT_TITLE = "未命名会话"
|
||||
|
||||
|
||||
class AgentChatOper(DbOper):
|
||||
"""
|
||||
Agent 会话历史数据管理。
|
||||
"""
|
||||
|
||||
def __init__(self, db: Union[Session, AsyncSession] = None):
|
||||
super().__init__(db)
|
||||
|
||||
@staticmethod
|
||||
def _now() -> str:
|
||||
"""返回数据库统一使用的当前时间字符串。"""
|
||||
return time.strftime("%Y-%m-%d %H:%M:%S", time.localtime())
|
||||
|
||||
@staticmethod
|
||||
def _channel_value(channel: Optional[Union[MessageChannel, str]]) -> Optional[str]:
|
||||
"""获取渠道枚举的字符串值。"""
|
||||
if isinstance(channel, MessageChannel):
|
||||
return channel.value
|
||||
return channel
|
||||
|
||||
@staticmethod
|
||||
def _normalize_messages(messages: Optional[list[dict]]) -> list[dict]:
|
||||
"""规范化展示消息列表,避免 JSON 字段存入 None。"""
|
||||
return messages if isinstance(messages, list) else []
|
||||
|
||||
@staticmethod
|
||||
def _normalize_title(value: Optional[str], messages: list[dict]) -> str:
|
||||
"""生成会话标题。"""
|
||||
if value and value.strip():
|
||||
return value.strip()[:120]
|
||||
for message in messages:
|
||||
if message.get("role") != "user":
|
||||
continue
|
||||
content = str(message.get("content") or "").strip()
|
||||
if content:
|
||||
return content.replace("\n", " ")[:120]
|
||||
attachments = message.get("attachments")
|
||||
if isinstance(attachments, list) and attachments:
|
||||
name = attachments[0].get("name") or "附件消息"
|
||||
return str(name)[:120]
|
||||
return DEFAULT_AGENT_CHAT_TITLE
|
||||
|
||||
@staticmethod
|
||||
def has_custom_title(value: Optional[str]) -> bool:
|
||||
"""判断会话是否已有真实标题。"""
|
||||
return bool(value and value.strip() and value.strip() != DEFAULT_AGENT_CHAT_TITLE)
|
||||
|
||||
@staticmethod
|
||||
def _normalize_preview(messages: list[dict]) -> str:
|
||||
"""生成会话预览文本。"""
|
||||
for message in reversed(messages):
|
||||
content = str(message.get("content") or "").strip()
|
||||
if content:
|
||||
return content.replace("\n", " ")[:240]
|
||||
attachments = message.get("attachments")
|
||||
if isinstance(attachments, list) and attachments:
|
||||
name = attachments[0].get("name") or "附件消息"
|
||||
return str(name)[:240]
|
||||
return ""
|
||||
|
||||
def get(
|
||||
self, session_id: str, user_id: Optional[str] = None
|
||||
) -> Optional[AgentChat]:
|
||||
"""
|
||||
获取 Agent 会话。
|
||||
"""
|
||||
return AgentChat.get_by_session(self._db, session_id, user_id)
|
||||
|
||||
async def async_get(
|
||||
self, session_id: str, user_id: Optional[str] = None
|
||||
) -> Optional[AgentChat]:
|
||||
"""
|
||||
异步获取 Agent 会话。
|
||||
"""
|
||||
return await AgentChat.async_get_by_session(self._db, session_id, user_id)
|
||||
|
||||
def ensure_session(
|
||||
self,
|
||||
session_id: str,
|
||||
user_id: Optional[str] = None,
|
||||
username: Optional[str] = None,
|
||||
channel: Optional[Union[MessageChannel, str]] = None,
|
||||
source: Optional[str] = None,
|
||||
original_chat_id: Optional[str] = None,
|
||||
client_session_id: Optional[str] = None,
|
||||
) -> AgentChat:
|
||||
"""
|
||||
确保 Agent 会话记录存在,并刷新基础渠道信息。
|
||||
"""
|
||||
now = self._now()
|
||||
chat = self.get(session_id=session_id, user_id=user_id)
|
||||
if not chat:
|
||||
chat = self.get(session_id=session_id)
|
||||
payload = {
|
||||
"user_id": user_id,
|
||||
"username": username,
|
||||
"channel": self._channel_value(channel),
|
||||
"source": source,
|
||||
"original_chat_id": original_chat_id,
|
||||
"client_session_id": client_session_id,
|
||||
"updated_at": now,
|
||||
}
|
||||
payload = {key: value for key, value in payload.items() if value is not None}
|
||||
if chat:
|
||||
chat.update(self._db, payload)
|
||||
return self.get(session_id=session_id, user_id=user_id) or self.get(session_id=session_id)
|
||||
|
||||
chat = AgentChat(
|
||||
session_id=session_id,
|
||||
user_id=user_id,
|
||||
username=username,
|
||||
channel=self._channel_value(channel),
|
||||
source=source,
|
||||
original_chat_id=original_chat_id,
|
||||
client_session_id=client_session_id,
|
||||
title=DEFAULT_AGENT_CHAT_TITLE,
|
||||
preview="",
|
||||
agent_messages=[],
|
||||
display_messages=[],
|
||||
message_count=0,
|
||||
created_at=now,
|
||||
updated_at=now,
|
||||
)
|
||||
chat.create(self._db)
|
||||
return self.get(session_id=session_id, user_id=user_id) or self.get(session_id=session_id)
|
||||
|
||||
def save_agent_messages(
|
||||
self,
|
||||
session_id: str,
|
||||
user_id: Optional[str],
|
||||
messages: list[dict],
|
||||
) -> None:
|
||||
"""
|
||||
保存可恢复 Agent 上下文的原始消息。
|
||||
"""
|
||||
chat = self.get(session_id=session_id, user_id=user_id)
|
||||
if not chat:
|
||||
chat = self.get(session_id=session_id)
|
||||
if not chat:
|
||||
chat = self.ensure_session(session_id=session_id, user_id=user_id)
|
||||
chat.update(
|
||||
self._db,
|
||||
{
|
||||
"agent_messages": messages or [],
|
||||
"updated_at": self._now(),
|
||||
},
|
||||
)
|
||||
|
||||
def update_title_if_empty(
|
||||
self,
|
||||
session_id: str,
|
||||
user_id: Optional[str],
|
||||
title: Optional[str],
|
||||
username: Optional[str] = None,
|
||||
channel: Optional[Union[MessageChannel, str]] = None,
|
||||
source: Optional[str] = None,
|
||||
original_chat_id: Optional[str] = None,
|
||||
client_session_id: Optional[str] = None,
|
||||
) -> None:
|
||||
"""
|
||||
在会话尚未生成标题时写入标题。
|
||||
"""
|
||||
normalized_title = self._normalize_title(title, [])
|
||||
if normalized_title == DEFAULT_AGENT_CHAT_TITLE:
|
||||
return
|
||||
|
||||
chat = self.ensure_session(
|
||||
session_id=session_id,
|
||||
user_id=user_id,
|
||||
username=username,
|
||||
channel=channel,
|
||||
source=source,
|
||||
original_chat_id=original_chat_id,
|
||||
client_session_id=client_session_id,
|
||||
)
|
||||
if self.has_custom_title(chat.title):
|
||||
return
|
||||
chat.update(
|
||||
self._db,
|
||||
{
|
||||
"title": normalized_title,
|
||||
"updated_at": self._now(),
|
||||
},
|
||||
)
|
||||
|
||||
def save_display_messages(
|
||||
self,
|
||||
session_id: str,
|
||||
user_id: Optional[str] = None,
|
||||
messages: Optional[list[dict]] = None,
|
||||
username: Optional[str] = None,
|
||||
channel: Optional[Union[MessageChannel, str]] = None,
|
||||
source: Optional[str] = None,
|
||||
original_chat_id: Optional[str] = None,
|
||||
client_session_id: Optional[str] = None,
|
||||
title: Optional[str] = None,
|
||||
) -> AgentChat:
|
||||
"""
|
||||
保存用户可见的 Agent 会话消息。
|
||||
"""
|
||||
normalized_messages = self._normalize_messages(messages)
|
||||
chat = self.ensure_session(
|
||||
session_id=session_id,
|
||||
user_id=user_id,
|
||||
username=username,
|
||||
channel=channel,
|
||||
source=source,
|
||||
original_chat_id=original_chat_id,
|
||||
client_session_id=client_session_id,
|
||||
)
|
||||
normalized_title = (
|
||||
chat.title
|
||||
if self.has_custom_title(chat.title)
|
||||
else self._normalize_title(title, normalized_messages)
|
||||
)
|
||||
chat.update(
|
||||
self._db,
|
||||
{
|
||||
"title": normalized_title,
|
||||
"preview": self._normalize_preview(normalized_messages),
|
||||
"display_messages": normalized_messages,
|
||||
"message_count": len(normalized_messages),
|
||||
"updated_at": self._now(),
|
||||
},
|
||||
)
|
||||
return self.get(session_id=session_id, user_id=user_id) or self.get(session_id=session_id)
|
||||
|
||||
def append_display_messages(
|
||||
self,
|
||||
session_id: str,
|
||||
user_id: Optional[str] = None,
|
||||
messages: Optional[list[dict]] = None,
|
||||
username: Optional[str] = None,
|
||||
channel: Optional[Union[MessageChannel, str]] = None,
|
||||
source: Optional[str] = None,
|
||||
original_chat_id: Optional[str] = None,
|
||||
client_session_id: Optional[str] = None,
|
||||
) -> AgentChat:
|
||||
"""
|
||||
追加一组用户可见的 Agent 会话消息。
|
||||
"""
|
||||
chat = self.ensure_session(
|
||||
session_id=session_id,
|
||||
user_id=user_id,
|
||||
username=username,
|
||||
channel=channel,
|
||||
source=source,
|
||||
original_chat_id=original_chat_id,
|
||||
client_session_id=client_session_id,
|
||||
)
|
||||
display_messages = self._normalize_messages(chat.display_messages)
|
||||
display_messages.extend(self._normalize_messages(messages))
|
||||
title = chat.title if self.has_custom_title(chat.title) else None
|
||||
return self.save_display_messages(
|
||||
session_id=session_id,
|
||||
user_id=user_id,
|
||||
messages=display_messages,
|
||||
username=username or chat.username,
|
||||
channel=channel or chat.channel,
|
||||
source=source or chat.source,
|
||||
original_chat_id=original_chat_id or chat.original_chat_id,
|
||||
client_session_id=client_session_id or chat.client_session_id,
|
||||
title=title,
|
||||
)
|
||||
|
||||
async def async_list_by_page(
|
||||
self,
|
||||
page: Optional[int] = 1,
|
||||
count: Optional[int] = 30,
|
||||
user_id: Optional[str] = None,
|
||||
username: Optional[str] = None,
|
||||
) -> list[AgentChat]:
|
||||
"""
|
||||
异步分页获取 Agent 会话历史。
|
||||
"""
|
||||
return await AgentChat.async_list_by_page(
|
||||
self._db,
|
||||
page=page,
|
||||
count=count,
|
||||
user_id=user_id,
|
||||
username=username,
|
||||
)
|
||||
|
||||
async def async_delete(
|
||||
self, session_id: str, user_id: Optional[str] = None
|
||||
) -> bool:
|
||||
"""
|
||||
异步删除 Agent 会话历史。
|
||||
"""
|
||||
chat = await self.async_get(session_id=session_id, user_id=user_id)
|
||||
if not chat:
|
||||
return False
|
||||
await AgentChat.async_delete(self._db, chat.id)
|
||||
return True
|
||||
|
||||
@staticmethod
|
||||
def to_summary(chat: AgentChat) -> dict[str, Any]:
|
||||
"""
|
||||
转换为历史会话摘要。
|
||||
"""
|
||||
return {
|
||||
"id": chat.id,
|
||||
"session_id": chat.session_id,
|
||||
"client_session_id": chat.client_session_id,
|
||||
"title": chat.title,
|
||||
"channel": chat.channel,
|
||||
"source": chat.source,
|
||||
"user_id": chat.user_id,
|
||||
"username": chat.username,
|
||||
"original_chat_id": chat.original_chat_id,
|
||||
"message_count": chat.message_count or 0,
|
||||
"created_at": chat.created_at,
|
||||
"updated_at": chat.updated_at,
|
||||
}
|
||||
|
||||
@classmethod
|
||||
def to_detail(cls, chat: AgentChat) -> dict[str, Any]:
|
||||
"""
|
||||
转换为历史会话详情。
|
||||
"""
|
||||
data = cls.to_summary(chat)
|
||||
data["messages"] = chat.display_messages or []
|
||||
return data
|
||||
@@ -1,3 +1,4 @@
|
||||
from .agentchat import AgentChat
|
||||
from .downloadhistory import DownloadHistory, DownloadFiles
|
||||
from .mediaserver import MediaServerItem
|
||||
from .message import Message
|
||||
|
||||
130
app/db/models/agentchat.py
Normal file
130
app/db/models/agentchat.py
Normal file
@@ -0,0 +1,130 @@
|
||||
from typing import Optional
|
||||
|
||||
from sqlalchemy import Column, Integer, String, JSON, Index, select
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from app.db import Base, async_db_query, db_query, get_id_column
|
||||
|
||||
|
||||
class AgentChat(Base):
|
||||
"""
|
||||
Agent 会话历史表。
|
||||
"""
|
||||
|
||||
id = get_id_column()
|
||||
# Agent 内部会话 ID,用于恢复 LangGraph 对话上下文
|
||||
session_id = Column(String, nullable=False)
|
||||
# 前端或渠道侧传入的原始会话标识
|
||||
client_session_id = Column(String)
|
||||
# 用户 ID
|
||||
user_id = Column(String)
|
||||
# 用户名称
|
||||
username = Column(String)
|
||||
# 消息渠道
|
||||
channel = Column(String)
|
||||
# 渠道来源配置名
|
||||
source = Column(String)
|
||||
# 原聊天 ID,用于区分群聊、频道或私聊
|
||||
original_chat_id = Column(String)
|
||||
# 会话标题
|
||||
title = Column(String)
|
||||
# 会话预览文本
|
||||
preview = Column(String)
|
||||
# 原始 LangChain messages,用于继续会话
|
||||
agent_messages = Column(JSON)
|
||||
# 展示给用户的消息记录,包含文字、工具提示、附件与选择卡片
|
||||
display_messages = Column(JSON)
|
||||
# 展示消息数量
|
||||
message_count = Column(Integer, default=0)
|
||||
# 创建时间
|
||||
created_at = Column(String)
|
||||
# 更新时间
|
||||
updated_at = Column(String)
|
||||
|
||||
__table_args__ = (
|
||||
Index("ix_agentchat_session_user", "session_id", "user_id"),
|
||||
Index("ix_agentchat_user_updated", "user_id", "updated_at", "id"),
|
||||
Index("ix_agentchat_channel_updated", "channel", "updated_at", "id"),
|
||||
)
|
||||
|
||||
@classmethod
|
||||
@db_query
|
||||
def get_by_session(
|
||||
cls, db: Session, session_id: str, user_id: Optional[str] = None
|
||||
) -> Optional["AgentChat"]:
|
||||
"""
|
||||
根据会话 ID 获取 Agent 会话。
|
||||
"""
|
||||
query = db.query(cls).filter(cls.session_id == session_id)
|
||||
if user_id is not None:
|
||||
query = query.filter(cls.user_id == user_id)
|
||||
return query.order_by(cls.id.desc()).first()
|
||||
|
||||
@classmethod
|
||||
@async_db_query
|
||||
async def async_get_by_session(
|
||||
cls, db: AsyncSession, session_id: str, user_id: Optional[str] = None
|
||||
) -> Optional["AgentChat"]:
|
||||
"""
|
||||
异步根据会话 ID 获取 Agent 会话。
|
||||
"""
|
||||
statement = select(cls).where(cls.session_id == session_id)
|
||||
if user_id is not None:
|
||||
statement = statement.where(cls.user_id == user_id)
|
||||
result = await db.execute(statement.order_by(cls.id.desc()))
|
||||
return result.scalars().first()
|
||||
|
||||
@classmethod
|
||||
@db_query
|
||||
def list_by_page(
|
||||
cls,
|
||||
db: Session,
|
||||
page: Optional[int] = 1,
|
||||
count: Optional[int] = 30,
|
||||
user_id: Optional[str] = None,
|
||||
username: Optional[str] = None,
|
||||
) -> list["AgentChat"]:
|
||||
"""
|
||||
分页获取 Agent 会话历史。
|
||||
"""
|
||||
query = db.query(cls)
|
||||
if user_id is not None and username is not None:
|
||||
query = query.filter((cls.user_id == user_id) | (cls.username == username))
|
||||
elif user_id is not None:
|
||||
query = query.filter(cls.user_id == user_id)
|
||||
elif username is not None:
|
||||
query = query.filter(cls.username == username)
|
||||
return (
|
||||
query.order_by(cls.updated_at.desc(), cls.id.desc())
|
||||
.offset((page - 1) * count)
|
||||
.limit(count)
|
||||
.all()
|
||||
)
|
||||
|
||||
@classmethod
|
||||
@async_db_query
|
||||
async def async_list_by_page(
|
||||
cls,
|
||||
db: AsyncSession,
|
||||
page: Optional[int] = 1,
|
||||
count: Optional[int] = 30,
|
||||
user_id: Optional[str] = None,
|
||||
username: Optional[str] = None,
|
||||
) -> list["AgentChat"]:
|
||||
"""
|
||||
异步分页获取 Agent 会话历史。
|
||||
"""
|
||||
statement = select(cls)
|
||||
if user_id is not None and username is not None:
|
||||
statement = statement.where((cls.user_id == user_id) | (cls.username == username))
|
||||
elif user_id is not None:
|
||||
statement = statement.where(cls.user_id == user_id)
|
||||
elif username is not None:
|
||||
statement = statement.where(cls.username == username)
|
||||
result = await db.execute(
|
||||
statement.order_by(cls.updated_at.desc(), cls.id.desc())
|
||||
.offset((page - 1) * count)
|
||||
.limit(count)
|
||||
)
|
||||
return result.scalars().all()
|
||||
@@ -1,3 +1,4 @@
|
||||
from .agent import *
|
||||
from .context import *
|
||||
from .dashboard import *
|
||||
from .download import *
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
"""AI智能体相关数据模型"""
|
||||
|
||||
from datetime import datetime
|
||||
from typing import List, Optional
|
||||
from typing import List, Optional, Union
|
||||
|
||||
from langchain_core.messages import BaseMessage
|
||||
from pydantic import BaseModel, Field, ConfigDict, field_serializer
|
||||
@@ -55,3 +55,95 @@ class ToolResult(BaseModel):
|
||||
success: bool = Field(description="是否成功")
|
||||
result: Optional[str] = Field(default=None, description="执行结果")
|
||||
error: Optional[str] = Field(default=None, description="错误信息")
|
||||
|
||||
|
||||
class AgentChatAttachment(BaseModel):
|
||||
"""
|
||||
Agent 会话展示附件。
|
||||
"""
|
||||
|
||||
kind: str = Field(..., description="附件类型")
|
||||
url: str = Field(..., description="附件访问地址")
|
||||
download_url: Optional[str] = Field(None, description="附件下载地址")
|
||||
name: Optional[str] = Field(None, description="附件名称")
|
||||
mime_type: Optional[str] = Field(None, description="MIME 类型")
|
||||
size: Optional[int] = Field(None, description="附件大小")
|
||||
local_path: Optional[str] = Field(None, description="服务端本地路径")
|
||||
|
||||
|
||||
class AgentChatToolCall(BaseModel):
|
||||
"""
|
||||
Agent 会话工具调用展示项。
|
||||
"""
|
||||
|
||||
id: str = Field(..., description="展示 ID")
|
||||
message: str = Field(..., description="工具提示")
|
||||
status: str = Field(default="done", description="工具状态")
|
||||
|
||||
|
||||
class AgentChatChoiceButton(BaseModel):
|
||||
"""
|
||||
Agent 会话选择按钮。
|
||||
"""
|
||||
|
||||
label: str = Field(..., description="按钮文案")
|
||||
callback_data: str = Field(..., description="回调数据")
|
||||
|
||||
|
||||
class AgentChatChoiceCard(BaseModel):
|
||||
"""
|
||||
Agent 会话选择卡片。
|
||||
"""
|
||||
|
||||
id: str = Field(..., description="选择卡片 ID")
|
||||
title: Optional[str] = Field(None, description="标题")
|
||||
prompt: str = Field(default="", description="提示语")
|
||||
buttons: list[AgentChatChoiceButton] = Field(default_factory=list, description="按钮列表")
|
||||
status: str = Field(default="pending", description="选择状态")
|
||||
selected_label: Optional[str] = Field(None, description="已选择文案")
|
||||
selected_value: Optional[str] = Field(None, description="已选择值")
|
||||
|
||||
|
||||
class AgentChatMessage(BaseModel):
|
||||
"""
|
||||
Agent 会话展示消息。
|
||||
"""
|
||||
|
||||
id: str = Field(..., description="展示消息 ID")
|
||||
role: str = Field(..., description="消息角色")
|
||||
content: str = Field(default="", description="消息文本")
|
||||
createdAt: Union[int, float] = Field(..., description="创建时间戳")
|
||||
status: str = Field(default="done", description="消息状态")
|
||||
tools: list[AgentChatToolCall] = Field(default_factory=list, description="工具提示列表")
|
||||
attachments: list[AgentChatAttachment] = Field(default_factory=list, description="附件列表")
|
||||
choices: list[AgentChatChoiceCard] = Field(default_factory=list, description="选择卡片列表")
|
||||
|
||||
|
||||
class AgentChatSession(BaseModel):
|
||||
"""
|
||||
Agent 会话历史详情。
|
||||
"""
|
||||
|
||||
id: Optional[int] = Field(None, description="数据库 ID")
|
||||
session_id: str = Field(..., description="Agent 内部会话 ID")
|
||||
client_session_id: Optional[str] = Field(None, description="客户端原始会话 ID")
|
||||
title: Optional[str] = Field(None, description="会话标题")
|
||||
preview: Optional[str] = Field(None, description="会话预览")
|
||||
channel: Optional[str] = Field(None, description="消息渠道")
|
||||
source: Optional[str] = Field(None, description="渠道来源")
|
||||
user_id: Optional[str] = Field(None, description="用户 ID")
|
||||
username: Optional[str] = Field(None, description="用户名")
|
||||
original_chat_id: Optional[str] = Field(None, description="原聊天 ID")
|
||||
message_count: int = Field(default=0, description="展示消息数量")
|
||||
created_at: Optional[str] = Field(None, description="创建时间")
|
||||
updated_at: Optional[str] = Field(None, description="更新时间")
|
||||
messages: list[AgentChatMessage] = Field(default_factory=list, description="展示消息列表")
|
||||
|
||||
|
||||
class AgentChatDisplaySaveRequest(BaseModel):
|
||||
"""
|
||||
Agent 会话展示消息保存请求。
|
||||
"""
|
||||
|
||||
messages: list[AgentChatMessage] = Field(default_factory=list, description="展示消息列表")
|
||||
title: Optional[str] = Field(None, description="会话标题")
|
||||
|
||||
@@ -310,6 +310,8 @@ class AgentWebChatRequest(BaseModel):
|
||||
audio_refs: Optional[List[str]] = Field(default_factory=list)
|
||||
# 文件附件列表
|
||||
files: Optional[List[AgentWebChatFile]] = Field(default_factory=list)
|
||||
# 是否在展示历史中记录本轮用户消息
|
||||
echo_user: bool = Field(default=True)
|
||||
|
||||
|
||||
class AgentWebChoiceRequest(BaseModel):
|
||||
|
||||
73
database/versions/8ab72c49d1e3_2_2_10.py
Normal file
73
database/versions/8ab72c49d1e3_2_2_10.py
Normal file
@@ -0,0 +1,73 @@
|
||||
"""2.2.10
|
||||
新增 Agent 会话历史表
|
||||
|
||||
Revision ID: 8ab72c49d1e3
|
||||
Revises: 7c1a2b3d4e5f
|
||||
Create Date: 2026-06-18
|
||||
"""
|
||||
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
|
||||
revision = "8ab72c49d1e3"
|
||||
down_revision = "7c1a2b3d4e5f"
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
|
||||
def _has_table(inspector: sa.Inspector, table_name: str) -> bool:
|
||||
"""检查数据表是否已存在。"""
|
||||
return table_name in inspector.get_table_names()
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
"""升级数据库结构。"""
|
||||
inspector = sa.inspect(op.get_bind())
|
||||
if _has_table(inspector, "agentchat"):
|
||||
return
|
||||
|
||||
op.create_table(
|
||||
"agentchat",
|
||||
sa.Column("id", sa.Integer(), nullable=False),
|
||||
sa.Column("session_id", sa.String(), nullable=False),
|
||||
sa.Column("client_session_id", sa.String(), nullable=True),
|
||||
sa.Column("user_id", sa.String(), nullable=True),
|
||||
sa.Column("username", sa.String(), nullable=True),
|
||||
sa.Column("channel", sa.String(), nullable=True),
|
||||
sa.Column("source", sa.String(), nullable=True),
|
||||
sa.Column("original_chat_id", sa.String(), nullable=True),
|
||||
sa.Column("title", sa.String(), nullable=True),
|
||||
sa.Column("preview", sa.String(), nullable=True),
|
||||
sa.Column("agent_messages", sa.JSON(), nullable=True),
|
||||
sa.Column("display_messages", sa.JSON(), nullable=True),
|
||||
sa.Column("message_count", sa.Integer(), nullable=True),
|
||||
sa.Column("created_at", sa.String(), nullable=True),
|
||||
sa.Column("updated_at", sa.String(), nullable=True),
|
||||
sa.PrimaryKeyConstraint("id"),
|
||||
)
|
||||
op.create_index(
|
||||
"ix_agentchat_session_user",
|
||||
"agentchat",
|
||||
["session_id", "user_id"],
|
||||
)
|
||||
op.create_index(
|
||||
"ix_agentchat_user_updated",
|
||||
"agentchat",
|
||||
["user_id", "updated_at", "id"],
|
||||
)
|
||||
op.create_index(
|
||||
"ix_agentchat_channel_updated",
|
||||
"agentchat",
|
||||
["channel", "updated_at", "id"],
|
||||
)
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
"""回滚数据库结构。"""
|
||||
inspector = sa.inspect(op.get_bind())
|
||||
if not _has_table(inspector, "agentchat"):
|
||||
return
|
||||
op.drop_index("ix_agentchat_channel_updated", table_name="agentchat")
|
||||
op.drop_index("ix_agentchat_user_updated", table_name="agentchat")
|
||||
op.drop_index("ix_agentchat_session_user", table_name="agentchat")
|
||||
op.drop_table("agentchat")
|
||||
141
tests/test_agent_chat_history.py
Normal file
141
tests/test_agent_chat_history.py
Normal file
@@ -0,0 +1,141 @@
|
||||
import asyncio
|
||||
from types import SimpleNamespace
|
||||
|
||||
from langchain_core.messages import HumanMessage
|
||||
|
||||
from app.agent import MoviePilotAgent
|
||||
from app.agent.memory import memory_manager
|
||||
from app.db.agentchat_oper import AgentChatOper
|
||||
|
||||
|
||||
def test_agent_chat_oper_saves_display_messages_with_channel():
|
||||
"""Agent 会话历史应保存展示消息与渠道标识。"""
|
||||
oper = AgentChatOper()
|
||||
oper.save_display_messages(
|
||||
session_id="session-chat",
|
||||
user_id="1",
|
||||
username="admin",
|
||||
channel="Telegram",
|
||||
source="telegram-main",
|
||||
original_chat_id="chat-1",
|
||||
messages=[
|
||||
{
|
||||
"id": "user-1",
|
||||
"role": "user",
|
||||
"content": "帮我看看下载器",
|
||||
"createdAt": 1,
|
||||
"status": "done",
|
||||
"tools": [],
|
||||
"attachments": [],
|
||||
"choices": [],
|
||||
}
|
||||
],
|
||||
)
|
||||
chat = AgentChatOper().get(session_id="session-chat", user_id="1")
|
||||
|
||||
assert chat.channel == "Telegram"
|
||||
assert chat.source == "telegram-main"
|
||||
assert chat.original_chat_id == "chat-1"
|
||||
assert chat.message_count == 1
|
||||
assert chat.title == "帮我看看下载器"
|
||||
|
||||
|
||||
def test_agent_chat_oper_keeps_generated_title_when_saving_display_messages():
|
||||
"""保存展示消息时不应覆盖已生成的模型标题。"""
|
||||
oper = AgentChatOper()
|
||||
oper.update_title_if_empty(
|
||||
session_id="session-title",
|
||||
user_id="1",
|
||||
username="admin",
|
||||
channel="WebAgent",
|
||||
source="web-agent",
|
||||
title="下载器状态排查",
|
||||
)
|
||||
oper.save_display_messages(
|
||||
session_id="session-title",
|
||||
user_id="1",
|
||||
messages=[
|
||||
{
|
||||
"id": "user-1",
|
||||
"role": "user",
|
||||
"content": "帮我看看下载器现在是不是正常",
|
||||
"createdAt": 1,
|
||||
"status": "done",
|
||||
"tools": [],
|
||||
"attachments": [],
|
||||
"choices": [],
|
||||
}
|
||||
],
|
||||
title="帮我看看下载器现在是不是正常",
|
||||
)
|
||||
|
||||
chat = AgentChatOper().get(session_id="session-title", user_id="1")
|
||||
summary = AgentChatOper.to_summary(chat)
|
||||
|
||||
assert chat.title == "下载器状态排查"
|
||||
assert "preview" not in summary
|
||||
assert "messages" not in summary
|
||||
|
||||
|
||||
def test_agent_prepare_chat_title_generates_title(monkeypatch):
|
||||
"""首次调用 Agent 时应使用模型生成会话标题并写入渠道信息。"""
|
||||
|
||||
class FakeTitleModel:
|
||||
"""测试用标题模型。"""
|
||||
|
||||
async def ainvoke(self, messages):
|
||||
"""返回固定标题。"""
|
||||
assert "标题生成器" in messages[0].content
|
||||
assert messages[1].content == "帮我看看下载器现在是不是正常"
|
||||
return SimpleNamespace(content="「下载器状态排查」")
|
||||
|
||||
async def fake_initialize_llm(self, streaming=False):
|
||||
"""返回测试标题模型。"""
|
||||
return FakeTitleModel()
|
||||
|
||||
monkeypatch.setattr(MoviePilotAgent, "_initialize_llm", fake_initialize_llm)
|
||||
agent = MoviePilotAgent(
|
||||
session_id="session-ai-title",
|
||||
user_id="3",
|
||||
channel="WebAgent",
|
||||
source="web-agent",
|
||||
username="admin",
|
||||
)
|
||||
|
||||
asyncio.run(agent.prepare_chat_title("帮我看看下载器现在是不是正常"))
|
||||
chat = AgentChatOper().get(session_id="session-ai-title", user_id="3")
|
||||
|
||||
assert chat.title == "下载器状态排查"
|
||||
assert chat.channel == "WebAgent"
|
||||
assert chat.source == "web-agent"
|
||||
|
||||
|
||||
def test_memory_manager_restores_agent_messages_from_database():
|
||||
"""内存缓存缺失时应从 Agent 会话历史表恢复原始 messages。"""
|
||||
session_id = "session-memory"
|
||||
user_id = "2"
|
||||
memory_manager.clear_memory(session_id, user_id)
|
||||
AgentChatOper().save_agent_messages(
|
||||
session_id=session_id,
|
||||
user_id=user_id,
|
||||
messages=[
|
||||
{
|
||||
"type": "human",
|
||||
"data": {
|
||||
"content": "继续之前的话题",
|
||||
"additional_kwargs": {},
|
||||
"response_metadata": {},
|
||||
"type": "human",
|
||||
"name": None,
|
||||
"id": None,
|
||||
"example": False,
|
||||
},
|
||||
}
|
||||
],
|
||||
)
|
||||
|
||||
messages = memory_manager.get_agent_messages(session_id=session_id, user_id=user_id)
|
||||
|
||||
assert len(messages) == 1
|
||||
assert isinstance(messages[0], HumanMessage)
|
||||
assert messages[0].content == "继续之前的话题"
|
||||
@@ -8,6 +8,8 @@ from app.agent import ReplyMode
|
||||
from app.api.endpoints.agent import (
|
||||
_WebAgentMoviePilotAgent,
|
||||
_WEB_AGENT_FILE_REGISTRY,
|
||||
_apply_web_agent_display_event,
|
||||
_build_web_agent_input_attachments,
|
||||
_build_web_agent_notification_events,
|
||||
_build_web_agent_session_id,
|
||||
_prepare_web_agent_audio_attachment_path,
|
||||
@@ -16,6 +18,7 @@ from app.api.endpoints.agent import (
|
||||
_resolve_web_agent_choice_payload,
|
||||
_split_web_agent_output,
|
||||
)
|
||||
from app.db.agentchat_oper import AgentChatOper
|
||||
from app.helper.interaction import AgentInteractionOption, agent_interaction_manager
|
||||
from app.schemas.message import ChannelCapability, ChannelCapabilityManager
|
||||
from app.schemas.types import MessageChannel, NotificationType
|
||||
@@ -74,6 +77,73 @@ def test_build_web_agent_session_id_is_stable_per_user_and_seed():
|
||||
assert first.startswith("web-agent:")
|
||||
|
||||
|
||||
def test_build_web_agent_session_id_reuses_accessible_history():
|
||||
"""传入已有历史会话 ID 时应直接复用,避免跨渠道继续对话丢上下文。"""
|
||||
user = SimpleNamespace(id=1, name="admin", is_superuser=True)
|
||||
AgentChatOper().save_display_messages(
|
||||
session_id="telegram-session",
|
||||
user_id="telegram-user",
|
||||
username="tester",
|
||||
channel=MessageChannel.Telegram.value,
|
||||
source="telegram-main",
|
||||
messages=[],
|
||||
title="Telegram 会话",
|
||||
)
|
||||
|
||||
assert _build_web_agent_session_id(user, "telegram-session") == "telegram-session"
|
||||
|
||||
|
||||
def test_apply_web_agent_display_event_updates_snapshot():
|
||||
"""WebAgent SSE 事件应可聚合为服务端展示快照。"""
|
||||
message = {
|
||||
"id": "assistant-1",
|
||||
"role": "assistant",
|
||||
"content": "",
|
||||
"createdAt": 1,
|
||||
"status": "streaming",
|
||||
"tools": [],
|
||||
"attachments": [],
|
||||
"choices": [],
|
||||
}
|
||||
|
||||
_apply_web_agent_display_event({"type": "delta", "content": "你好"}, message)
|
||||
_apply_web_agent_display_event({"type": "tool", "message": "查询订阅"}, message)
|
||||
_apply_web_agent_display_event(
|
||||
{
|
||||
"type": "attachment",
|
||||
"attachment": {"kind": "file", "url": "message/agent/file/a"},
|
||||
},
|
||||
message,
|
||||
)
|
||||
_apply_web_agent_display_event({"type": "done"}, message)
|
||||
|
||||
assert message["content"] == "你好"
|
||||
assert message["status"] == "done"
|
||||
assert len(message["tools"]) == 1
|
||||
assert message["tools"][0]["message"] == "查询订阅"
|
||||
assert message["tools"][0]["status"] == "done"
|
||||
assert message["attachments"] == [{"kind": "file", "url": "message/agent/file/a"}]
|
||||
|
||||
|
||||
def test_build_web_agent_input_attachments_marks_kinds():
|
||||
"""WebAgent 用户输入附件应转换为可展示的附件记录。"""
|
||||
attachments = _build_web_agent_input_attachments(
|
||||
images=["data:image/png;base64,abc"],
|
||||
files=[
|
||||
{
|
||||
"ref": "message/agent/file/file-1",
|
||||
"name": "report.txt",
|
||||
"mime_type": "text/plain",
|
||||
"size": 5,
|
||||
}
|
||||
],
|
||||
audio_refs=["message/agent/file/audio-1"],
|
||||
)
|
||||
|
||||
assert [item["kind"] for item in attachments] == ["image", "file", "audio"]
|
||||
assert attachments[1]["name"] == "report.txt"
|
||||
|
||||
|
||||
def test_web_agent_admin_context_uses_current_user_id():
|
||||
"""Web Agent 工具权限应按当前登录用户 ID 判断管理员身份。"""
|
||||
agent = _WebAgentMoviePilotAgent(
|
||||
|
||||
Reference in New Issue
Block a user