fix(chat): RAG 检索同时召回笔记和转录内容

之前 query 只做一次全局检索,embedding 模型倾向匹配笔记,
导致转录原文几乎不会被召回。

- 改为分别对 markdown 和 transcript 各检索 n_results 条,
  合并后按距离排序取 top-n
- 更新 system prompt,明确区分笔记和转录两种来源,
  引导 LLM 根据问题类型选择合适的来源回答

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
huangjianwu
2026-03-23 15:35:31 +08:00
parent ef1dec1e47
commit a92c779dd6
2 changed files with 45 additions and 15 deletions

View File

@@ -8,14 +8,21 @@ from app.utils.logger import get_logger
logger = get_logger(__name__)
SYSTEM_PROMPT = """你是一个视频笔记问答助手。根据以下笔记内容回答用户的问题
如果笔记内容中没有相关信息,请诚实告知用户。回答时尽量引用笔记中的具体内容。
SYSTEM_PROMPT = """你是一个视频笔记问答助手。你可以参考两种来源回答用户的问题
1. [笔记] — AI 生成的视频摘要笔记
2. [转录] — 视频原始语音转录文本(含时间戳)
--- 相关笔记内容 ---
以下是检索到的相关内容:
--- 相关内容 ---
{context}
---
请用中文回答,保持简洁准确。"""
回答要求:
- 优先使用转录原文回答关于视频具体内容、原话、细节的问题
- 优先使用笔记回答关于总结、要点、结构的问题
- 如果确实没有相关信息,请诚实告知
- 请用中文回答,保持简洁准确"""
def _build_context(chunks: list[dict]) -> str:

View File

@@ -111,18 +111,11 @@ class VectorStoreManager:
collection.add(documents=documents, metadatas=metadatas, ids=ids)
logger.info(f"向量索引完成: task_id={task_id}, chunks={len(all_chunks)}")
def query(self, task_id: str, query_text: str, n_results: int = 5) -> list[dict]:
"""检索与查询最相关的文档片段"""
col_name = self._collection_name(task_id)
try:
collection = self._client.get_collection(col_name)
except Exception:
logger.warning(f"Collection 不存在: {col_name}")
return []
results = collection.query(query_texts=[query_text], n_results=n_results)
def _parse_results(self, results: dict) -> list[dict]:
"""将 ChromaDB query 结果转换为 chunk 列表"""
chunks = []
if not results or not results.get("documents") or not results["documents"][0]:
return chunks
for i in range(len(results["documents"][0])):
chunks.append({
"text": results["documents"][0][i],
@@ -131,6 +124,36 @@ class VectorStoreManager:
})
return chunks
def query(self, task_id: str, query_text: str, n_results: int = 5) -> list[dict]:
"""
分别从 markdown 和 transcript 各检索,确保两种来源都被召回,
最后按距离排序返回 top-n。
"""
col_name = self._collection_name(task_id)
try:
collection = self._client.get_collection(col_name)
except Exception:
logger.warning(f"Collection 不存在: {col_name}")
return []
all_chunks = []
# 分别从两种来源各检索 n_results 条
for source_type in ("markdown", "transcript"):
try:
results = collection.query(
query_texts=[query_text],
n_results=n_results,
where={"source_type": source_type},
)
all_chunks.extend(self._parse_results(results))
except Exception:
pass
# 按距离排序,取 top-n
all_chunks.sort(key=lambda c: c.get("distance", 999))
return all_chunks[:n_results]
def delete_index(self, task_id: str) -> None:
"""删除指定任务的向量索引。"""
col_name = self._collection_name(task_id)