mirror of
https://github.com/jxxghp/MoviePilot.git
synced 2026-06-21 15:36:37 +08:00
131 lines
4.3 KiB
Python
131 lines
4.3 KiB
Python
from typing import Optional
|
||
|
||
from sqlalchemy import Column, Integer, String, JSON, Index, select
|
||
from sqlalchemy.ext.asyncio import AsyncSession
|
||
from sqlalchemy.orm import Session
|
||
|
||
from app.db import Base, async_db_query, db_query, get_id_column
|
||
|
||
|
||
class AgentChat(Base):
|
||
"""
|
||
Agent 会话历史表。
|
||
"""
|
||
|
||
id = get_id_column()
|
||
# Agent 内部会话 ID,用于恢复 LangGraph 对话上下文
|
||
session_id = Column(String, nullable=False)
|
||
# 前端或渠道侧传入的原始会话标识
|
||
client_session_id = Column(String)
|
||
# 用户 ID
|
||
user_id = Column(String)
|
||
# 用户名称
|
||
username = Column(String)
|
||
# 消息渠道
|
||
channel = Column(String)
|
||
# 渠道来源配置名
|
||
source = Column(String)
|
||
# 原聊天 ID,用于区分群聊、频道或私聊
|
||
original_chat_id = Column(String)
|
||
# 会话标题
|
||
title = Column(String)
|
||
# 会话预览文本
|
||
preview = Column(String)
|
||
# 原始 LangChain messages,用于继续会话
|
||
agent_messages = Column(JSON)
|
||
# 展示给用户的消息记录,包含文字、工具提示、附件与选择卡片
|
||
display_messages = Column(JSON)
|
||
# 展示消息数量
|
||
message_count = Column(Integer, default=0)
|
||
# 创建时间
|
||
created_at = Column(String)
|
||
# 更新时间
|
||
updated_at = Column(String)
|
||
|
||
__table_args__ = (
|
||
Index("ix_agentchat_session_user", "session_id", "user_id"),
|
||
Index("ix_agentchat_user_updated", "user_id", "updated_at", "id"),
|
||
Index("ix_agentchat_channel_updated", "channel", "updated_at", "id"),
|
||
)
|
||
|
||
@classmethod
|
||
@db_query
|
||
def get_by_session(
|
||
cls, db: Session, session_id: str, user_id: Optional[str] = None
|
||
) -> Optional["AgentChat"]:
|
||
"""
|
||
根据会话 ID 获取 Agent 会话。
|
||
"""
|
||
query = db.query(cls).filter(cls.session_id == session_id)
|
||
if user_id is not None:
|
||
query = query.filter(cls.user_id == user_id)
|
||
return query.order_by(cls.id.desc()).first()
|
||
|
||
@classmethod
|
||
@async_db_query
|
||
async def async_get_by_session(
|
||
cls, db: AsyncSession, session_id: str, user_id: Optional[str] = None
|
||
) -> Optional["AgentChat"]:
|
||
"""
|
||
异步根据会话 ID 获取 Agent 会话。
|
||
"""
|
||
statement = select(cls).where(cls.session_id == session_id)
|
||
if user_id is not None:
|
||
statement = statement.where(cls.user_id == user_id)
|
||
result = await db.execute(statement.order_by(cls.id.desc()))
|
||
return result.scalars().first()
|
||
|
||
@classmethod
|
||
@db_query
|
||
def list_by_page(
|
||
cls,
|
||
db: Session,
|
||
page: Optional[int] = 1,
|
||
count: Optional[int] = 30,
|
||
user_id: Optional[str] = None,
|
||
username: Optional[str] = None,
|
||
) -> list["AgentChat"]:
|
||
"""
|
||
分页获取 Agent 会话历史。
|
||
"""
|
||
query = db.query(cls)
|
||
if user_id is not None and username is not None:
|
||
query = query.filter((cls.user_id == user_id) | (cls.username == username))
|
||
elif user_id is not None:
|
||
query = query.filter(cls.user_id == user_id)
|
||
elif username is not None:
|
||
query = query.filter(cls.username == username)
|
||
return (
|
||
query.order_by(cls.updated_at.desc(), cls.id.desc())
|
||
.offset((page - 1) * count)
|
||
.limit(count)
|
||
.all()
|
||
)
|
||
|
||
@classmethod
|
||
@async_db_query
|
||
async def async_list_by_page(
|
||
cls,
|
||
db: AsyncSession,
|
||
page: Optional[int] = 1,
|
||
count: Optional[int] = 30,
|
||
user_id: Optional[str] = None,
|
||
username: Optional[str] = None,
|
||
) -> list["AgentChat"]:
|
||
"""
|
||
异步分页获取 Agent 会话历史。
|
||
"""
|
||
statement = select(cls)
|
||
if user_id is not None and username is not None:
|
||
statement = statement.where((cls.user_id == user_id) | (cls.username == username))
|
||
elif user_id is not None:
|
||
statement = statement.where(cls.user_id == user_id)
|
||
elif username is not None:
|
||
statement = statement.where(cls.username == username)
|
||
result = await db.execute(
|
||
statement.order_by(cls.updated_at.desc(), cls.id.desc())
|
||
.offset((page - 1) * count)
|
||
.limit(count)
|
||
)
|
||
return result.scalars().all()
|