Files
MoviePilot/app/db/models/agentchat.py
2026-06-18 11:45:50 +08:00

131 lines
4.3 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
from typing import Optional
from sqlalchemy import Column, Integer, String, JSON, Index, select
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.orm import Session
from app.db import Base, async_db_query, db_query, get_id_column
class AgentChat(Base):
    """
    Agent chat session history table.

    Each row stores one agent conversation: identifying metadata (session,
    user, channel), the raw agent messages needed to resume the conversation,
    and the messages rendered for the user.
    """

    id = get_id_column()
    # Agent-internal session ID, used to restore the LangGraph conversation context
    session_id = Column(String, nullable=False)
    # Original session identifier passed in by the frontend or channel side
    client_session_id = Column(String)
    # User ID
    user_id = Column(String)
    # User name
    username = Column(String)
    # Message channel
    channel = Column(String)
    # Name of the channel source configuration
    source = Column(String)
    # Original chat ID, used to tell group chats, channels and private chats apart
    original_chat_id = Column(String)
    # Session title
    title = Column(String)
    # Session preview text
    preview = Column(String)
    # Raw LangChain messages, used to continue the session
    agent_messages = Column(JSON)
    # Messages shown to the user: text, tool hints, attachments and choice cards
    display_messages = Column(JSON)
    # Number of displayed messages
    message_count = Column(Integer, default=0)
    # Creation time
    created_at = Column(String)
    # Update time
    # NOTE(review): stored as a string and used for ordering below, so it is
    # presumably written in a lexicographically sortable format - confirm.
    updated_at = Column(String)

    __table_args__ = (
        Index("ix_agentchat_session_user", "session_id", "user_id"),
        Index("ix_agentchat_user_updated", "user_id", "updated_at", "id"),
        Index("ix_agentchat_channel_updated", "channel", "updated_at", "id"),
    )

    @staticmethod
    def _page_window(page: Optional[int], count: Optional[int]) -> tuple[int, int]:
        """
        Translate 1-based paging arguments into an ``(offset, limit)`` pair.

        ``None`` falls back to the public defaults (page 1, 30 rows) instead of
        raising ``TypeError`` in the arithmetic. The offset is clamped to zero
        because a page below 1 would otherwise produce a negative OFFSET, which
        some backends (e.g. PostgreSQL) reject.
        """
        effective_page = 1 if page is None else max(page, 1)
        effective_count = 30 if count is None else count
        offset = max((effective_page - 1) * effective_count, 0)
        return offset, effective_count

    @classmethod
    def _owner_criterion(cls, user_id: Optional[str], username: Optional[str]):
        """
        Build the owner filter shared by the sync and async list queries.

        Returns a SQL expression matching rows by ``user_id`` OR ``username``
        when both are given, by whichever single value is given otherwise, or
        ``None`` when no owner filter should be applied.
        """
        if user_id is not None and username is not None:
            return (cls.user_id == user_id) | (cls.username == username)
        if user_id is not None:
            return cls.user_id == user_id
        if username is not None:
            return cls.username == username
        return None

    @classmethod
    @db_query
    def get_by_session(
        cls, db: Session, session_id: str, user_id: Optional[str] = None
    ) -> Optional["AgentChat"]:
        """
        Get an agent session by its session ID.

        :param db: synchronous database session
        :param session_id: agent session ID to look up
        :param user_id: when given, additionally restrict the match to this user
        :return: the matching row with the highest ``id``, or ``None`` if none exists
        """
        query = db.query(cls).filter(cls.session_id == session_id)
        if user_id is not None:
            query = query.filter(cls.user_id == user_id)
        # Several rows may share a session_id; the highest id wins.
        return query.order_by(cls.id.desc()).first()

    @classmethod
    @async_db_query
    async def async_get_by_session(
        cls, db: AsyncSession, session_id: str, user_id: Optional[str] = None
    ) -> Optional["AgentChat"]:
        """
        Asynchronously get an agent session by its session ID.

        :param db: asynchronous database session
        :param session_id: agent session ID to look up
        :param user_id: when given, additionally restrict the match to this user
        :return: the matching row with the highest ``id``, or ``None`` if none exists
        """
        statement = select(cls).where(cls.session_id == session_id)
        if user_id is not None:
            statement = statement.where(cls.user_id == user_id)
        # Several rows may share a session_id; the highest id wins.
        result = await db.execute(statement.order_by(cls.id.desc()))
        return result.scalars().first()

    @classmethod
    @db_query
    def list_by_page(
        cls,
        db: Session,
        page: Optional[int] = 1,
        count: Optional[int] = 30,
        user_id: Optional[str] = None,
        username: Optional[str] = None,
    ) -> list["AgentChat"]:
        """
        List agent session history page by page.

        :param db: synchronous database session
        :param page: 1-based page number; ``None`` or values below 1 mean the first page
        :param count: page size; ``None`` means 30
        :param user_id: optional owner filter by user ID
        :param username: optional owner filter by user name; combined with
            ``user_id`` using OR when both are given
        :return: rows ordered by ``updated_at`` then ``id``, both descending
        """
        query = db.query(cls)
        criterion = cls._owner_criterion(user_id, username)
        # SQL expressions cannot be truth-tested, so compare against None explicitly.
        if criterion is not None:
            query = query.filter(criterion)
        offset, limit = cls._page_window(page, count)
        return (
            query.order_by(cls.updated_at.desc(), cls.id.desc())
            .offset(offset)
            .limit(limit)
            .all()
        )

    @classmethod
    @async_db_query
    async def async_list_by_page(
        cls,
        db: AsyncSession,
        page: Optional[int] = 1,
        count: Optional[int] = 30,
        user_id: Optional[str] = None,
        username: Optional[str] = None,
    ) -> list["AgentChat"]:
        """
        Asynchronously list agent session history page by page.

        :param db: asynchronous database session
        :param page: 1-based page number; ``None`` or values below 1 mean the first page
        :param count: page size; ``None`` means 30
        :param user_id: optional owner filter by user ID
        :param username: optional owner filter by user name; combined with
            ``user_id`` using OR when both are given
        :return: rows ordered by ``updated_at`` then ``id``, both descending
        """
        statement = select(cls)
        criterion = cls._owner_criterion(user_id, username)
        # SQL expressions cannot be truth-tested, so compare against None explicitly.
        if criterion is not None:
            statement = statement.where(criterion)
        offset, limit = cls._page_window(page, count)
        result = await db.execute(
            statement.order_by(cls.updated_at.desc(), cls.id.desc())
            .offset(offset)
            .limit(limit)
        )
        return result.scalars().all()