feat(chat): 基于 RAG 的笔记内容 AI 问答功能

实现类似 Google NotebookLM 的效果:笔记生成后自动向量化,
用户可针对笔记内容进行 LLM 问答。

### 后端
- 新增 VectorStoreManager(ChromaDB),按标题/转录分块建立向量索引
- 新增 chat_service.py RAG 问答:检索相关片段 → 构建 prompt → 调用 LLM
- 新增 /chat/index, /chat/ask, /chat/status API 端点
- 笔记生成完成后自动建立向量索引

### 前端
- 使用 @ant-design/x Bubble.List + Sender 组件构建聊天面板
- 新增 chatStore(Zustand + persist)持久化聊天记录
- MarkdownViewer 右侧嵌入 ChatPanel,通过"AI 问答"按钮切换
- 首次打开自动检查/触发索引,支持重新索引

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
huangjianwu
2026-03-23 14:38:39 +08:00
parent 1cd8c33983
commit efadbc267d
13 changed files with 730 additions and 2 deletions

View File

@@ -1,6 +1,6 @@
from fastapi import FastAPI
from .routers import note, provider, model, config
from .routers import note, provider, model, config, chat
@@ -10,5 +10,6 @@ def create_app(lifespan) -> FastAPI:
app.include_router(provider.router, prefix="/api")
app.include_router(model.router,prefix="/api")
app.include_router(config.router, prefix="/api")
app.include_router(chat.router, prefix="/api")
return app

View File

@@ -0,0 +1,74 @@
from typing import Optional
from fastapi import APIRouter
from pydantic import BaseModel
from app.services.chat_service import chat as chat_service
from app.services.vector_store import VectorStoreManager
from app.utils.logger import get_logger
from app.utils.response import ResponseWrapper as R
logger = get_logger(__name__)
router = APIRouter()
class IndexRequest(BaseModel):
task_id: str
class ChatMessage(BaseModel):
role: str
content: str
class AskRequest(BaseModel):
task_id: str
question: str
history: list[ChatMessage] = []
provider_id: str
model_name: str
@router.post("/chat/index")
def index_task(data: IndexRequest):
"""为笔记建立向量索引。"""
try:
store = VectorStoreManager()
store.index_task(data.task_id)
return R.success(msg="索引完成")
except Exception as e:
logger.error(f"索引失败: {e}")
return R.error(msg=f"索引失败: {str(e)}")
@router.get("/chat/status")
def chat_status(task_id: str):
"""检查笔记是否已建立向量索引。"""
try:
store = VectorStoreManager()
indexed = store.is_indexed(task_id)
return R.success(data={"indexed": indexed})
except Exception as e:
logger.error(f"查询索引状态失败: {e}")
return R.success(data={"indexed": False})
@router.post("/chat/ask")
def ask_question(data: AskRequest):
"""基于笔记内容的 RAG 问答。"""
try:
history = [{"role": m.role, "content": m.content} for m in data.history]
result = chat_service(
task_id=data.task_id,
question=data.question,
history=history,
provider_id=data.provider_id,
model_name=data.model_name,
)
return R.success(data=result)
except ValueError as e:
return R.error(msg=str(e))
except Exception as e:
logger.error(f"Chat 问答失败: {e}", exc_info=True)
return R.error(msg=f"问答失败: {str(e)}")

View File

@@ -109,6 +109,12 @@ def run_note_task(task_id: str, video_url: str, platform: str, quality: Download
return
save_note_to_file(task_id, note)
# 自动建立向量索引(用于 AI 问答),失败不影响笔记生成
try:
from app.services.vector_store import VectorStoreManager
VectorStoreManager().index_task(task_id)
except Exception as e:
logger.warning(f"向量索引失败(不影响笔记): {e}")
@router.post('/delete_task')

View File

@@ -0,0 +1,117 @@
from typing import Optional
from app.gpt.gpt_factory import GPTFactory
from app.models.model_config import ModelConfig
from app.services.provider import ProviderService
from app.services.vector_store import VectorStoreManager
from app.utils.logger import get_logger
logger = get_logger(__name__)
SYSTEM_PROMPT = """你是一个视频笔记问答助手。根据以下笔记内容回答用户的问题。
如果笔记内容中没有相关信息,请诚实告知用户。回答时尽量引用笔记中的具体内容。
--- 相关笔记内容 ---
{context}
---
请用中文回答,保持简洁准确。"""
def _build_context(chunks: list[dict]) -> str:
"""将检索到的片段拼接为上下文文本。"""
parts = []
for i, chunk in enumerate(chunks, 1):
meta = chunk.get("metadata", {})
source_type = meta.get("source_type", "unknown")
if source_type == "markdown":
label = f"[笔记 - {meta.get('section_title', '')}]"
else:
start = meta.get("start_time", 0)
end = meta.get("end_time", 0)
label = f"[转录 - {start:.0f}s~{end:.0f}s]"
parts.append(f"{label}\n{chunk['text']}")
return "\n\n".join(parts)
def _build_sources(chunks: list[dict]) -> list[dict]:
"""从检索片段中提取来源信息。"""
sources = []
for chunk in chunks:
meta = chunk.get("metadata", {})
source = {
"text": chunk["text"][:200],
"source_type": meta.get("source_type", "unknown"),
}
if meta.get("section_title"):
source["section_title"] = meta["section_title"]
if meta.get("start_time") is not None:
source["start_time"] = meta["start_time"]
if meta.get("end_time") is not None:
source["end_time"] = meta["end_time"]
sources.append(source)
return sources
def chat(
task_id: str,
question: str,
history: list[dict],
provider_id: str,
model_name: str,
) -> dict:
"""
RAG 问答:检索相关片段 → 构建 prompt → 调用 LLM → 返回答案 + 来源。
Returns:
{"answer": str, "sources": list[dict]}
"""
vector_store = VectorStoreManager()
# 1. 检索相关片段
chunks = vector_store.query(task_id, question, n_results=5)
if not chunks:
return {
"answer": "暂未找到相关笔记内容,请确认笔记已生成并完成索引。",
"sources": [],
}
# 2. 构建上下文和来源
context = _build_context(chunks)
sources = _build_sources(chunks)
# 3. 构建消息
system_msg = SYSTEM_PROMPT.format(context=context)
messages = [{"role": "system", "content": system_msg}]
# 加入历史对话(最近 10 轮)
for msg in history[-20:]:
messages.append({"role": msg["role"], "content": msg["content"]})
messages.append({"role": "user", "content": question})
# 4. 调用 LLM
provider = ProviderService.get_provider_by_id(provider_id)
if not provider:
raise ValueError(f"未找到模型供应商: {provider_id}")
config = ModelConfig(
api_key=provider["api_key"],
base_url=provider["base_url"],
model_name=model_name,
provider=provider["type"],
name=provider["name"],
)
gpt = GPTFactory.from_config(config)
logger.info(f"Chat RAG: task_id={task_id}, provider={provider['name']}, model={model_name}")
response = gpt.client.chat.completions.create(
model=gpt.model,
messages=messages,
temperature=0.7,
)
answer = response.choices[0].message.content
return {"answer": answer, "sources": sources}

View File

@@ -0,0 +1,155 @@
import json
import os
import re
from typing import Optional
import chromadb
from chromadb.config import Settings
from app.utils.logger import get_logger
logger = get_logger(__name__)
NOTE_OUTPUT_DIR = os.getenv("NOTE_OUTPUT_DIR", "note_results")
VECTOR_DB_DIR = os.getenv("VECTOR_DB_DIR", "vector_db")
def _chunk_markdown(markdown: str) -> list[dict]:
"""按 H2/H3 标题拆分 markdown 为语义块。"""
sections = re.split(r'(?=^#{2,3}\s)', markdown, flags=re.MULTILINE)
chunks = []
for section in sections:
section = section.strip()
if not section or len(section) < 30:
continue
heading_match = re.match(r'^(#{2,3})\s+(.+)', section)
title = heading_match.group(2).strip() if heading_match else "intro"
chunks.append({
"text": section,
"metadata": {"source_type": "markdown", "section_title": title},
})
return chunks
def _chunk_transcript(segments: list[dict], window_size: int = 15, overlap: int = 3) -> list[dict]:
"""将转录 segments 按滑动窗口分组。"""
if not segments:
return []
chunks = []
step = max(window_size - overlap, 1)
for i in range(0, len(segments), step):
window = segments[i:i + window_size]
if not window:
break
text = "\n".join(
f"[{seg.get('start', 0):.0f}s] {seg.get('text', '')}" for seg in window
)
chunks.append({
"text": text,
"metadata": {
"source_type": "transcript",
"start_time": window[0].get("start", 0),
"end_time": window[-1].get("end", 0),
},
})
return chunks
class VectorStoreManager:
"""基于 ChromaDB 的笔记向量存储管理器。"""
def __init__(self):
os.makedirs(VECTOR_DB_DIR, exist_ok=True)
self._client = chromadb.PersistentClient(
path=VECTOR_DB_DIR,
settings=Settings(anonymized_telemetry=False),
)
def _collection_name(self, task_id: str) -> str:
"""ChromaDB collection 名称需满足限制3-63字符字母数字开头结尾。"""
safe = re.sub(r'[^a-zA-Z0-9_-]', '_', task_id)[:60]
if not safe or not safe[0].isalnum():
safe = "t" + safe
if not safe[-1].isalnum():
safe = safe + "0"
return safe
def index_task(self, task_id: str) -> None:
"""读取笔记结果并建立向量索引。"""
result_path = os.path.join(NOTE_OUTPUT_DIR, f"{task_id}.json")
if not os.path.exists(result_path):
logger.warning(f"笔记文件不存在,跳过索引: {result_path}")
return
with open(result_path, "r", encoding="utf-8") as f:
note_data = json.load(f)
markdown = note_data.get("markdown", "")
transcript = note_data.get("transcript", {})
segments = transcript.get("segments", [])
md_chunks = _chunk_markdown(markdown)
tr_chunks = _chunk_transcript(segments)
all_chunks = md_chunks + tr_chunks
if not all_chunks:
logger.warning(f"笔记内容为空,跳过索引: {task_id}")
return
col_name = self._collection_name(task_id)
# 删除旧 collection幂等
try:
self._client.delete_collection(col_name)
except ValueError:
pass
collection = self._client.create_collection(
name=col_name,
metadata={"hnsw:space": "cosine"},
)
documents = [c["text"] for c in all_chunks]
metadatas = [c["metadata"] for c in all_chunks]
ids = [f"{task_id}_{i}" for i in range(len(all_chunks))]
collection.add(documents=documents, metadatas=metadatas, ids=ids)
logger.info(f"向量索引完成: task_id={task_id}, chunks={len(all_chunks)}")
def query(self, task_id: str, query_text: str, n_results: int = 5) -> list[dict]:
"""检索与查询最相关的文档片段。"""
col_name = self._collection_name(task_id)
try:
collection = self._client.get_collection(col_name)
except ValueError:
logger.warning(f"Collection 不存在: {col_name}")
return []
results = collection.query(query_texts=[query_text], n_results=n_results)
chunks = []
for i in range(len(results["documents"][0])):
chunks.append({
"text": results["documents"][0][i],
"metadata": results["metadatas"][0][i] if results["metadatas"] else {},
"distance": results["distances"][0][i] if results["distances"] else None,
})
return chunks
def delete_index(self, task_id: str) -> None:
"""删除指定任务的向量索引。"""
col_name = self._collection_name(task_id)
try:
self._client.delete_collection(col_name)
logger.info(f"已删除向量索引: {task_id}")
except ValueError:
pass
def is_indexed(self, task_id: str) -> bool:
"""检查指定任务是否已建立索引。"""
col_name = self._collection_name(task_id)
try:
col = self._client.get_collection(col_name)
return col.count() > 0
except ValueError:
return False