mirror of
https://github.com/JefferyHcool/BiliNote.git
synced 2026-05-19 12:29:31 +08:00
之前 query 只做一次全局检索,而 embedding 模型倾向于匹配笔记,导致转录原文几乎不会被召回。
- 改为分别对 markdown 和 transcript 各检索 n_results 条,合并后按距离排序取 top-n
- 更新 system prompt,明确区分笔记和转录两种来源,引导 LLM 根据问题类型选择合适的来源回答

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
174 lines
5.9 KiB
Python
174 lines
5.9 KiB
Python
import json
|
||
import os
|
||
import re
|
||
from typing import Optional
|
||
|
||
import chromadb
|
||
from chromadb.config import Settings
|
||
|
||
from app.utils.logger import get_logger
|
||
|
||
# Module-level logger obtained from the project's logging helper.
logger = get_logger(__name__)

# Directory that index_task() reads "<task_id>.json" note results from;
# overridable through the NOTE_OUTPUT_DIR environment variable.
NOTE_OUTPUT_DIR = os.environ.get("NOTE_OUTPUT_DIR", "note_results")
# Directory passed to the persistent ChromaDB client as its storage path;
# overridable through the VECTOR_DB_DIR environment variable.
VECTOR_DB_DIR = os.environ.get("VECTOR_DB_DIR", "vector_db")
|
||
|
||
|
||
def _chunk_markdown(markdown: str) -> list[dict]:
|
||
"""按 H2/H3 标题拆分 markdown 为语义块。"""
|
||
sections = re.split(r'(?=^#{2,3}\s)', markdown, flags=re.MULTILINE)
|
||
chunks = []
|
||
for section in sections:
|
||
section = section.strip()
|
||
if not section or len(section) < 30:
|
||
continue
|
||
heading_match = re.match(r'^(#{2,3})\s+(.+)', section)
|
||
title = heading_match.group(2).strip() if heading_match else "intro"
|
||
chunks.append({
|
||
"text": section,
|
||
"metadata": {"source_type": "markdown", "section_title": title},
|
||
})
|
||
return chunks
|
||
|
||
|
||
def _chunk_transcript(segments: list[dict], window_size: int = 15, overlap: int = 3) -> list[dict]:
|
||
"""将转录 segments 按滑动窗口分组。"""
|
||
if not segments:
|
||
return []
|
||
chunks = []
|
||
step = max(window_size - overlap, 1)
|
||
for i in range(0, len(segments), step):
|
||
window = segments[i:i + window_size]
|
||
if not window:
|
||
break
|
||
text = "\n".join(
|
||
f"[{seg.get('start', 0):.0f}s] {seg.get('text', '')}" for seg in window
|
||
)
|
||
chunks.append({
|
||
"text": text,
|
||
"metadata": {
|
||
"source_type": "transcript",
|
||
"start_time": window[0].get("start", 0),
|
||
"end_time": window[-1].get("end", 0),
|
||
},
|
||
})
|
||
return chunks
|
||
|
||
|
||
class VectorStoreManager:
    """ChromaDB-backed vector store manager for generated notes.

    Each task gets its own collection, named after the task id, holding both
    the markdown note chunks and the transcript chunks of that task.
    """

    def __init__(self):
        """Create the storage directory and open a persistent ChromaDB client."""
        os.makedirs(VECTOR_DB_DIR, exist_ok=True)
        self._client = chromadb.PersistentClient(
            path=VECTOR_DB_DIR,
            settings=Settings(anonymized_telemetry=False),
        )

    def _collection_name(self, task_id: str) -> str:
        """Return the collection name for a task.

        The task id is used unchanged; the original author notes that task
        ids are UUIDs and therefore valid collection names.
        """
        return task_id

    def index_task(self, task_id: str) -> None:
        """Read a task's note result file and (re)build its vector index.

        Loads ``<NOTE_OUTPUT_DIR>/<task_id>.json``, chunks its ``markdown``
        text and its ``transcript.segments``, and stores all chunks in a
        freshly created cosine-distance collection. Nothing is indexed when
        the file is missing or yields no chunks; both cases are logged.

        Args:
            task_id: Identifier of the task whose note should be indexed.
        """
        result_path = os.path.join(NOTE_OUTPUT_DIR, f"{task_id}.json")
        if not os.path.exists(result_path):
            logger.warning(f"笔记文件不存在,跳过索引: {result_path}")
            return

        with open(result_path, "r", encoding="utf-8") as f:
            note_data = json.load(f)

        # `or` fallbacks also cover keys that are present but JSON null,
        # which a plain .get(key, default) would pass through as None.
        markdown = note_data.get("markdown") or ""
        transcript = note_data.get("transcript") or {}
        segments = transcript.get("segments") or []

        all_chunks = _chunk_markdown(markdown) + _chunk_transcript(segments)
        if not all_chunks:
            logger.warning(f"笔记内容为空,跳过索引: {task_id}")
            return

        col_name = self._collection_name(task_id)

        # Drop any previous collection so re-indexing is idempotent. A failure
        # here is deliberately ignored (best effort), as in the original code.
        try:
            self._client.delete_collection(col_name)
        except Exception:
            pass

        collection = self._client.create_collection(
            name=col_name,
            metadata={"hnsw:space": "cosine"},
        )

        documents = [chunk["text"] for chunk in all_chunks]
        metadatas = [chunk["metadata"] for chunk in all_chunks]
        ids = [f"{task_id}_{i}" for i in range(len(all_chunks))]

        collection.add(documents=documents, metadatas=metadatas, ids=ids)
        logger.info(f"向量索引完成: task_id={task_id}, chunks={len(all_chunks)}")

    def _parse_results(self, results: dict) -> list[dict]:
        """Convert a ChromaDB query result into a flat list of chunk dicts.

        Only the first row of each result field is read, matching a query
        issued with a single query text.

        Args:
            results: Mapping with ``documents`` and optionally ``metadatas``
                and ``distances``.

        Returns:
            One ``{"text", "metadata", "distance"}`` dict per document.
            ``metadata`` falls back to ``{}`` and ``distance`` to ``None``
            when the corresponding field is missing or empty. An empty list
            is returned when there are no documents.
        """
        if not results:
            return []

        doc_rows = results.get("documents")
        if not doc_rows or not doc_rows[0]:
            return []

        # .get() tolerates results that omit the optional fields entirely.
        meta_rows = results.get("metadatas")
        dist_rows = results.get("distances")
        metadatas = meta_rows[0] if meta_rows else None
        distances = dist_rows[0] if dist_rows else None

        return [
            {
                "text": document,
                "metadata": metadatas[i] if metadatas is not None else {},
                "distance": distances[i] if distances is not None else None,
            }
            for i, document in enumerate(doc_rows[0])
        ]

    def query(self, task_id: str, query_text: str, n_results: int = 5) -> list[dict]:
        """Retrieve the chunks of a task that best match ``query_text``.

        Runs one filtered query per source type (``markdown`` and
        ``transcript``), each asking for ``n_results`` hits, merges the hits,
        sorts them by ascending distance and returns the first ``n_results``.
        Hits without a distance are ranked last.

        NOTE(review): because the merged list is truncated to the overall
        top-n by distance, a source whose hits are all farther away can still
        be absent from the result. Confirm whether a per-source quota is
        intended.

        Args:
            task_id: Identifier of the task whose index is searched.
            query_text: Natural-language query.
            n_results: Maximum number of chunks to return.

        Returns:
            Up to ``n_results`` ``{"text", "metadata", "distance"}`` dicts,
            or an empty list when the collection cannot be opened.
        """
        col_name = self._collection_name(task_id)
        try:
            collection = self._client.get_collection(col_name)
        except Exception:
            logger.warning(f"Collection 不存在: {col_name}")
            return []

        all_chunks = []

        # Query each source separately with the same n_results budget.
        for source_type in ("markdown", "transcript"):
            try:
                results = collection.query(
                    query_texts=[query_text],
                    n_results=n_results,
                    where={"source_type": source_type},
                )
                all_chunks.extend(self._parse_results(results))
            except Exception as exc:
                # Still best effort per source, but leave a trace instead of
                # silently returning fewer results.
                logger.warning(
                    f"向量检索失败: collection={col_name}, "
                    f"source_type={source_type}, error={exc}"
                )

        def _distance_key(chunk: dict) -> float:
            # "distance" is always present but may be None; None cannot be
            # compared with floats, so rank such chunks last.
            distance = chunk.get("distance")
            return float("inf") if distance is None else distance

        all_chunks.sort(key=_distance_key)
        return all_chunks[:n_results]

    def delete_index(self, task_id: str) -> None:
        """Delete the vector index of a task; failures are ignored."""
        col_name = self._collection_name(task_id)
        try:
            self._client.delete_collection(col_name)
            logger.info(f"已删除向量索引: {task_id}")
        except Exception:
            # Best effort, as in the original code.
            pass

    def is_indexed(self, task_id: str) -> bool:
        """Return True when the task has a collection with at least one entry."""
        col_name = self._collection_name(task_id)
        try:
            col = self._client.get_collection(col_name)
            return col.count() > 0
        except Exception:
            return False
|