import json import os import re from typing import Optional import chromadb from chromadb.config import Settings from app.utils.logger import get_logger logger = get_logger(__name__) NOTE_OUTPUT_DIR = os.getenv("NOTE_OUTPUT_DIR", "note_results") VECTOR_DB_DIR = os.getenv("VECTOR_DB_DIR", "vector_db") def _chunk_markdown(markdown: str) -> list[dict]: """按 H2/H3 标题拆分 markdown 为语义块。""" sections = re.split(r'(?=^#{2,3}\s)', markdown, flags=re.MULTILINE) chunks = [] for section in sections: section = section.strip() if not section or len(section) < 30: continue heading_match = re.match(r'^(#{2,3})\s+(.+)', section) title = heading_match.group(2).strip() if heading_match else "intro" chunks.append({ "text": section, "metadata": {"source_type": "markdown", "section_title": title}, }) return chunks def _chunk_transcript(segments: list[dict], window_size: int = 15, overlap: int = 3) -> list[dict]: """将转录 segments 按滑动窗口分组。""" if not segments: return [] chunks = [] step = max(window_size - overlap, 1) for i in range(0, len(segments), step): window = segments[i:i + window_size] if not window: break text = "\n".join( f"[{seg.get('start', 0):.0f}s] {seg.get('text', '')}" for seg in window ) chunks.append({ "text": text, "metadata": { "source_type": "transcript", "start_time": window[0].get("start", 0), "end_time": window[-1].get("end", 0), }, }) return chunks class VectorStoreManager: """基于 ChromaDB 的笔记向量存储管理器。""" def __init__(self): os.makedirs(VECTOR_DB_DIR, exist_ok=True) self._client = chromadb.PersistentClient( path=VECTOR_DB_DIR, settings=Settings(anonymized_telemetry=False), ) def _collection_name(self, task_id: str) -> str: """ChromaDB collection 名称:直接使用 task_id(UUID 格式合法)。""" return task_id def index_task(self, task_id: str) -> None: """读取笔记结果并建立向量索引。""" result_path = os.path.join(NOTE_OUTPUT_DIR, f"{task_id}.json") if not os.path.exists(result_path): logger.warning(f"笔记文件不存在,跳过索引: {result_path}") return with open(result_path, "r", encoding="utf-8") as f: note_data = json.load(f) markdown = note_data.get("markdown", "") transcript = note_data.get("transcript", {}) segments = transcript.get("segments", []) md_chunks = _chunk_markdown(markdown) tr_chunks = _chunk_transcript(segments) all_chunks = md_chunks + tr_chunks if not all_chunks: logger.warning(f"笔记内容为空,跳过索引: {task_id}") return col_name = self._collection_name(task_id) # 删除旧 collection(幂等) try: self._client.delete_collection(col_name) except Exception: pass collection = self._client.create_collection( name=col_name, metadata={"hnsw:space": "cosine"}, ) documents = [c["text"] for c in all_chunks] metadatas = [c["metadata"] for c in all_chunks] ids = [f"{task_id}_{i}" for i in range(len(all_chunks))] collection.add(documents=documents, metadatas=metadatas, ids=ids) logger.info(f"向量索引完成: task_id={task_id}, chunks={len(all_chunks)}") def _parse_results(self, results: dict) -> list[dict]: """将 ChromaDB query 结果转换为 chunk 列表。""" chunks = [] if not results or not results.get("documents") or not results["documents"][0]: return chunks for i in range(len(results["documents"][0])): chunks.append({ "text": results["documents"][0][i], "metadata": results["metadatas"][0][i] if results["metadatas"] else {}, "distance": results["distances"][0][i] if results["distances"] else None, }) return chunks def query(self, task_id: str, query_text: str, n_results: int = 5) -> list[dict]: """ 分别从 markdown 和 transcript 各检索,确保两种来源都被召回, 最后按距离排序返回 top-n。 """ col_name = self._collection_name(task_id) try: collection = self._client.get_collection(col_name) except Exception: logger.warning(f"Collection 不存在: {col_name}") return [] all_chunks = [] # 分别从两种来源各检索 n_results 条 for source_type in ("markdown", "transcript"): try: results = collection.query( query_texts=[query_text], n_results=n_results, where={"source_type": source_type}, ) all_chunks.extend(self._parse_results(results)) except Exception: pass # 按距离排序,取 top-n all_chunks.sort(key=lambda c: c.get("distance", 999)) return all_chunks[:n_results] def delete_index(self, task_id: str) -> None: """删除指定任务的向量索引。""" col_name = self._collection_name(task_id) try: self._client.delete_collection(col_name) logger.info(f"已删除向量索引: {task_id}") except Exception: pass def is_indexed(self, task_id: str) -> bool: """检查指定任务是否已建立索引。""" col_name = self._collection_name(task_id) try: col = self._client.get_collection(col_name) return col.count() > 0 except Exception: return False