Files
BiliNote/backend/app/services/vector_store.py
huangjianwu fdc888512a fix(chat): 修复 ChromaDB 1.x 兼容性问题导致索引失败
- ChromaDB 1.x delete/get 不存在的 collection 抛 NotFoundError
  而非 ValueError,统一改为 except Exception
- 简化 _collection_name,UUID 格式本身就是合法的 collection name
- requirements.txt 放宽 chromadb 版本约束

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-03-23 14:43:57 +08:00

151 lines
5.1 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
import json
import os
import re
from typing import Optional
import chromadb
from chromadb.config import Settings
from app.utils.logger import get_logger
logger = get_logger(__name__)
NOTE_OUTPUT_DIR = os.getenv("NOTE_OUTPUT_DIR", "note_results")
VECTOR_DB_DIR = os.getenv("VECTOR_DB_DIR", "vector_db")
def _chunk_markdown(markdown: str) -> list[dict]:
"""按 H2/H3 标题拆分 markdown 为语义块。"""
sections = re.split(r'(?=^#{2,3}\s)', markdown, flags=re.MULTILINE)
chunks = []
for section in sections:
section = section.strip()
if not section or len(section) < 30:
continue
heading_match = re.match(r'^(#{2,3})\s+(.+)', section)
title = heading_match.group(2).strip() if heading_match else "intro"
chunks.append({
"text": section,
"metadata": {"source_type": "markdown", "section_title": title},
})
return chunks
def _chunk_transcript(segments: list[dict], window_size: int = 15, overlap: int = 3) -> list[dict]:
"""将转录 segments 按滑动窗口分组。"""
if not segments:
return []
chunks = []
step = max(window_size - overlap, 1)
for i in range(0, len(segments), step):
window = segments[i:i + window_size]
if not window:
break
text = "\n".join(
f"[{seg.get('start', 0):.0f}s] {seg.get('text', '')}" for seg in window
)
chunks.append({
"text": text,
"metadata": {
"source_type": "transcript",
"start_time": window[0].get("start", 0),
"end_time": window[-1].get("end", 0),
},
})
return chunks
class VectorStoreManager:
"""基于 ChromaDB 的笔记向量存储管理器。"""
def __init__(self):
os.makedirs(VECTOR_DB_DIR, exist_ok=True)
self._client = chromadb.PersistentClient(
path=VECTOR_DB_DIR,
settings=Settings(anonymized_telemetry=False),
)
def _collection_name(self, task_id: str) -> str:
"""ChromaDB collection 名称:直接使用 task_idUUID 格式合法)。"""
return task_id
def index_task(self, task_id: str) -> None:
"""读取笔记结果并建立向量索引。"""
result_path = os.path.join(NOTE_OUTPUT_DIR, f"{task_id}.json")
if not os.path.exists(result_path):
logger.warning(f"笔记文件不存在,跳过索引: {result_path}")
return
with open(result_path, "r", encoding="utf-8") as f:
note_data = json.load(f)
markdown = note_data.get("markdown", "")
transcript = note_data.get("transcript", {})
segments = transcript.get("segments", [])
md_chunks = _chunk_markdown(markdown)
tr_chunks = _chunk_transcript(segments)
all_chunks = md_chunks + tr_chunks
if not all_chunks:
logger.warning(f"笔记内容为空,跳过索引: {task_id}")
return
col_name = self._collection_name(task_id)
# 删除旧 collection幂等
try:
self._client.delete_collection(col_name)
except Exception:
pass
collection = self._client.create_collection(
name=col_name,
metadata={"hnsw:space": "cosine"},
)
documents = [c["text"] for c in all_chunks]
metadatas = [c["metadata"] for c in all_chunks]
ids = [f"{task_id}_{i}" for i in range(len(all_chunks))]
collection.add(documents=documents, metadatas=metadatas, ids=ids)
logger.info(f"向量索引完成: task_id={task_id}, chunks={len(all_chunks)}")
def query(self, task_id: str, query_text: str, n_results: int = 5) -> list[dict]:
"""检索与查询最相关的文档片段。"""
col_name = self._collection_name(task_id)
try:
collection = self._client.get_collection(col_name)
except Exception:
logger.warning(f"Collection 不存在: {col_name}")
return []
results = collection.query(query_texts=[query_text], n_results=n_results)
chunks = []
for i in range(len(results["documents"][0])):
chunks.append({
"text": results["documents"][0][i],
"metadata": results["metadatas"][0][i] if results["metadatas"] else {},
"distance": results["distances"][0][i] if results["distances"] else None,
})
return chunks
def delete_index(self, task_id: str) -> None:
"""删除指定任务的向量索引。"""
col_name = self._collection_name(task_id)
try:
self._client.delete_collection(col_name)
logger.info(f"已删除向量索引: {task_id}")
except Exception:
pass
def is_indexed(self, task_id: str) -> bool:
"""检查指定任务是否已建立索引。"""
col_name = self._collection_name(task_id)
try:
col = self._client.get_collection(col_name)
return col.count() > 0
except Exception:
return False