mirror of
https://github.com/JefferyHcool/BiliNote.git
synced 2026-06-06 16:20:05 +08:00
Merge pull request #299 from JefferyHcool/feature/note-qa-chat-optimize
Feature/note qa chat optimize
This commit is contained in:
@@ -1,6 +1,6 @@
|
||||
from fastapi import FastAPI
|
||||
|
||||
from .routers import note, provider, model, config
|
||||
from .routers import note, provider, model, config, chat
|
||||
|
||||
|
||||
|
||||
@@ -10,5 +10,6 @@ def create_app(lifespan) -> FastAPI:
|
||||
app.include_router(provider.router, prefix="/api")
|
||||
app.include_router(model.router,prefix="/api")
|
||||
app.include_router(config.router, prefix="/api")
|
||||
app.include_router(chat.router, prefix="/api")
|
||||
|
||||
return app
|
||||
|
||||
101
backend/app/routers/chat.py
Normal file
101
backend/app/routers/chat.py
Normal file
@@ -0,0 +1,101 @@
|
||||
from fastapi import APIRouter, BackgroundTasks
|
||||
from pydantic import BaseModel
|
||||
|
||||
from app.services.chat_service import chat as chat_service
|
||||
from app.services.vector_store import VectorStoreManager
|
||||
from app.utils.logger import get_logger
|
||||
from app.utils.response import ResponseWrapper as R
|
||||
|
||||
logger = get_logger(__name__)

router = APIRouter()

# In-memory indexing status per task: task_id -> "indexing" | "indexed" | "failed".
# A task with no entry here has not been seen by this process yet.
_index_status: dict[str, str] = {}
|
||||
|
||||
|
||||
class IndexRequest(BaseModel):
    """Request body for ``POST /chat/index``."""

    # Identifier of the note task whose content should be indexed.
    task_id: str
|
||||
|
||||
|
||||
class ChatMessage(BaseModel):
    """One prior conversation turn sent along with a question."""

    # Speaker of the turn; forwarded unchanged as the LLM message role.
    role: str
    # Text of the turn.
    content: str
|
||||
|
||||
|
||||
class AskRequest(BaseModel):
    """Request body for ``POST /chat/ask``."""

    # Note task the question is about.
    task_id: str
    # The user's question.
    question: str
    # Earlier turns of the conversation (may be empty).
    history: list[ChatMessage] = []
    # Which configured provider and which of its models should answer.
    provider_id: str
    model_name: str
|
||||
|
||||
|
||||
def _do_index(task_id: str):
    """Build the vector index for ``task_id``; meant to run as a background task.

    Progress is recorded in ``_index_status``: "indexing" while running, then
    "indexed" on success or "failed" if anything raises. Exceptions are logged
    and never propagated.

    NOTE(review): ``VectorStoreManager.index_task`` returns without raising
    when the note file is missing or has no content, so such tasks are still
    recorded as "indexed" here.
    """
    _index_status[task_id] = "indexing"
    try:
        VectorStoreManager().index_task(task_id)
    except Exception as e:
        _index_status[task_id] = "failed"
        # Keep the traceback: the exception is swallowed here, so the log is
        # the only record of what went wrong.
        logger.error(f"索引失败: {task_id}, {e}", exc_info=True)
    else:
        _index_status[task_id] = "indexed"
        logger.info(f"索引完成: {task_id}")
|
||||
|
||||
|
||||
@router.post("/chat/index")
def index_task(data: IndexRequest, background_tasks: BackgroundTasks):
    """Schedule background indexing for a task and return immediately."""
    task_id = data.task_id

    # A job for this task is already in flight; do not schedule another.
    if _index_status.get(task_id) == "indexing":
        return R.success(msg="正在索引中")

    # The persisted index already exists: remember that and skip the work.
    if VectorStoreManager().is_indexed(task_id):
        _index_status[task_id] = "indexed"
        return R.success(msg="已完成索引")

    # Mark the task before scheduling so a repeated request sees "indexing".
    _index_status[task_id] = "indexing"
    background_tasks.add_task(_do_index, task_id)
    return R.success(msg="开始索引")
|
||||
|
||||
|
||||
@router.get("/chat/status")
def chat_status(task_id: str):
    """Report the index status of a task: idle / indexing / indexed / failed.

    The in-memory status wins when present; otherwise the persisted vector
    store is consulted. Any failure is logged and reported as "idle".
    """
    try:
        cached = _index_status.get(task_id)
        if cached:
            payload = {"status": cached, "indexed": cached == "indexed"}
        else:
            # Nothing recorded in memory: fall back to the persisted index.
            indexed = VectorStoreManager().is_indexed(task_id)
            if indexed:
                _index_status[task_id] = "indexed"
                payload = {"status": "indexed", "indexed": indexed}
            else:
                payload = {"status": "idle", "indexed": indexed}
        return R.success(data=payload)
    except Exception as e:
        logger.error(f"查询索引状态失败: {e}")
        return R.success(data={"status": "idle", "indexed": False})
|
||||
|
||||
|
||||
@router.post("/chat/ask")
def ask_question(data: AskRequest):
    """Answer a question about a note via the RAG chat service.

    A ``ValueError`` from the service is returned as its message; any other
    exception is logged with a traceback and returned as a generic failure.
    """
    try:
        # The service works on plain dicts rather than pydantic models.
        turns = [
            {"role": item.role, "content": item.content}
            for item in data.history
        ]
        answer = chat_service(
            task_id=data.task_id,
            question=data.question,
            history=turns,
            provider_id=data.provider_id,
            model_name=data.model_name,
        )
        return R.success(data=answer)
    except ValueError as e:
        return R.error(msg=str(e))
    except Exception as e:
        logger.error(f"Chat 问答失败: {e}", exc_info=True)
        return R.error(msg=f"问答失败: {str(e)}")
|
||||
@@ -109,6 +109,12 @@ def run_note_task(task_id: str, video_url: str, platform: str, quality: Download
|
||||
return
|
||||
save_note_to_file(task_id, note)
|
||||
|
||||
# 自动建立向量索引(用于 AI 问答),失败不影响笔记生成
|
||||
try:
|
||||
from app.services.vector_store import VectorStoreManager
|
||||
VectorStoreManager().index_task(task_id)
|
||||
except Exception as e:
|
||||
logger.warning(f"向量索引失败(不影响笔记): {e}")
|
||||
|
||||
|
||||
@router.post('/delete_task')
|
||||
|
||||
158
backend/app/services/chat_service.py
Normal file
158
backend/app/services/chat_service.py
Normal file
@@ -0,0 +1,158 @@
|
||||
import json
|
||||
from typing import Optional
|
||||
|
||||
from app.gpt.gpt_factory import GPTFactory
|
||||
from app.models.model_config import ModelConfig
|
||||
from app.services.provider import ProviderService
|
||||
from app.services.vector_store import VectorStoreManager
|
||||
from app.services.chat_tools import TOOLS, execute_tool
|
||||
from app.utils.logger import get_logger
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
# System prompt template for the note Q&A assistant. The single ``{context}``
# placeholder is filled with the initially retrieved chunks via ``str.format``.
SYSTEM_PROMPT = "\n".join((
    "你是一个视频笔记问答助手。你拥有以下能力:",
    "",
    "1. 系统已自动检索了一些相关内容作为初始参考(见下方)",
    "2. 你可以调用工具主动查询更多信息:",
    "- lookup_transcript: 查询视频原始转录文本(支持按时间、关键词、位置筛选)",
    "- get_video_info: 获取视频元信息(标题、作者、简介、标签等)",
    "- get_note_content: 获取完整笔记内容",
    "",
    "--- 初始检索内容 ---",
    "{context}",
    "---",
    "",
    "回答要求:",
    "- 如果初始检索内容不足以回答问题,请主动调用工具获取更多信息",
    "- 回答关于视频具体原话、细节时,用 lookup_transcript 查询原文",
    "- 回答关于作者、标题等基本信息时,用 get_video_info 查询",
    "- 请用中文回答,保持简洁准确",
))
|
||||
|
||||
|
||||
def _build_context(chunks: list[dict]) -> str:
    """Join retrieved chunks into one context string for the system prompt.

    Each chunk is rendered as a label line followed by its text, and chunks
    are separated by a blank line. The label depends on the chunk's
    ``source_type``: video info, a note section (with its title), or, for
    anything else, a transcript span with start/end seconds.

    Args:
        chunks: Dicts with a ``"text"`` key and an optional ``"metadata"``
            dict.

    Returns:
        The concatenated context, or an empty string for no chunks.
    """
    parts = []
    for chunk in chunks:
        # Metadata may be absent or explicitly None; treat both as empty.
        meta = chunk.get("metadata") or {}
        source_type = meta.get("source_type", "unknown")
        if source_type == "meta":
            label = "[视频信息]"
        elif source_type == "markdown":
            label = f"[笔记 - {meta.get('section_title', '')}]"
        else:
            # A key that is present but None would break the ``:.0f`` format,
            # so fall back to 0 in that case as well.
            start = meta.get("start_time") or 0
            end = meta.get("end_time") or 0
            label = f"[转录 - {start:.0f}s~{end:.0f}s]"
        parts.append(f"{label}\n{chunk['text']}")
    return "\n\n".join(parts)
|
||||
|
||||
|
||||
def _build_sources(chunks: list[dict]) -> list[dict]:
    """Extract source descriptors from retrieved chunks.

    Args:
        chunks: Dicts with a ``"text"`` key and an optional ``"metadata"``
            dict.

    Returns:
        One dict per chunk with the first 200 characters of its text and its
        ``source_type``, plus ``section_title`` when non-empty and
        ``start_time`` / ``end_time`` when not None.
    """
    sources = []
    for chunk in chunks:
        # Metadata may be absent or explicitly None; treat both as empty.
        meta = chunk.get("metadata") or {}
        source = {
            "text": chunk["text"][:200],
            "source_type": meta.get("source_type", "unknown"),
        }
        if meta.get("section_title"):
            source["section_title"] = meta["section_title"]
        # 0 is a valid timestamp, so test against None rather than truthiness.
        for key in ("start_time", "end_time"):
            if meta.get(key) is not None:
                source[key] = meta[key]
        sources.append(source)
    return sources
|
||||
|
||||
|
||||
def chat(
    task_id: str,
    question: str,
    history: list[dict],
    provider_id: str,
    model_name: str,
) -> dict:
    """Answer a question about a note using retrieval plus LLM tool calling.

    Steps:
    1. Query the vector store for initial context.
    2. Call the LLM with the tool definitions.
    3. Execute any tools the LLM requests and feed the results back.
    4. Repeat for at most three rounds, then make one final call without
       tools to force an answer.

    Args:
        task_id: Note task the question is about.
        question: The user's question.
        history: Earlier turns as ``{"role", "content"}`` dicts; only the
            last 20 are sent to the model.
        provider_id: Identifier of the configured provider.
        model_name: Model to use with that provider.

    Returns:
        ``{"answer": str, "sources": list[dict]}`` where ``sources`` describes
        the initially retrieved chunks.

    Raises:
        ValueError: If no provider exists for ``provider_id``.
    """
    vector_store = VectorStoreManager()

    # 1. Retrieve the initial context.
    chunks = vector_store.query(task_id, question, n_results=6)
    if chunks:
        context = _build_context(chunks)
        sources = _build_sources(chunks)
    else:
        context = "(未检索到相关内容,请使用工具查询)"
        sources = []

    # 2. Build the message list: system prompt, recent history, new question.
    messages = [{"role": "system", "content": SYSTEM_PROMPT.format(context=context)}]
    messages.extend(
        {"role": msg["role"], "content": msg["content"]} for msg in history[-20:]
    )
    messages.append({"role": "user", "content": question})

    # 3. Resolve the LLM client.
    provider = ProviderService.get_provider_by_id(provider_id)
    if not provider:
        raise ValueError(f"未找到模型供应商: {provider_id}")

    config = ModelConfig(
        api_key=provider["api_key"],
        base_url=provider["base_url"],
        model_name=model_name,
        provider=provider["type"],
        name=provider["name"],
    )
    gpt = GPTFactory.from_config(config)

    logger.info(f"Chat: task_id={task_id}, model={model_name}")

    # 4. Tool-calling loop (at most 3 rounds).
    max_rounds = 3
    for round_i in range(max_rounds):
        response = gpt.client.chat.completions.create(
            model=gpt.model,
            messages=messages,
            tools=TOOLS,
            temperature=0.7,
        )
        msg = response.choices[0].message

        # No tool call means this is the final answer.
        if not msg.tool_calls:
            return {"answer": msg.content or "", "sources": sources}

        # Echo the assistant turn back so the tool results can reference it.
        messages.append(msg)

        for tool_call in msg.tool_calls:
            fn_name = tool_call.function.name
            try:
                fn_args = json.loads(tool_call.function.arguments)
            except (json.JSONDecodeError, TypeError):
                # Malformed or missing arguments: call the tool with none.
                fn_args = {}
            # Valid JSON that is not an object (null, list, number) would
            # break the tools' ``args.get`` lookups.
            if not isinstance(fn_args, dict):
                fn_args = {}

            logger.info(f"Tool call [{round_i+1}/{max_rounds}]: {fn_name}({fn_args})")

            try:
                result = execute_tool(task_id, fn_name, fn_args)
            except Exception as e:
                # Report the failure to the model instead of aborting the
                # whole request; every tool call still gets a reply.
                logger.error(f"Tool call failed: {fn_name}, {e}", exc_info=True)
                result = json.dumps({"error": f"工具执行失败: {e}"}, ensure_ascii=False)

            messages.append({
                "role": "tool",
                "tool_call_id": tool_call.id,
                "content": result,
            })

    # Round limit reached: one last call without tools to force an answer.
    response = gpt.client.chat.completions.create(
        model=gpt.model,
        messages=messages,
        temperature=0.7,
    )
    return {"answer": response.choices[0].message.content or "", "sources": sources}
|
||||
184
backend/app/services/chat_tools.py
Normal file
184
backend/app/services/chat_tools.py
Normal file
@@ -0,0 +1,184 @@
|
||||
"""
|
||||
Chat function calling 工具定义与执行。
|
||||
提供给 LLM 调用,用于主动查询视频原文、笔记、元信息。
|
||||
"""
|
||||
|
||||
import json
|
||||
import os
|
||||
from typing import Optional
|
||||
|
||||
from app.utils.logger import get_logger
|
||||
|
||||
logger = get_logger(__name__)

# Directory containing the per-task note JSON files (``<task_id>.json``);
# overridable through the NOTE_OUTPUT_DIR environment variable.
NOTE_OUTPUT_DIR = os.environ.get("NOTE_OUTPUT_DIR", "note_results")
|
||||
|
||||
|
||||
def _load_note_data(task_id: str) -> Optional[dict]:
    """Load the saved note JSON for ``task_id``.

    Args:
        task_id: Task identifier; the file read is ``<task_id>.json`` inside
            ``NOTE_OUTPUT_DIR``.

    Returns:
        The parsed JSON content, or None when the file does not exist or
        ``task_id`` would resolve to a path outside ``NOTE_OUTPUT_DIR``.
    """
    base_dir = os.path.abspath(NOTE_OUTPUT_DIR)
    path = os.path.abspath(os.path.join(base_dir, f"{task_id}.json"))

    # task_id arrives from the chat request, so refuse anything (e.g. "../x")
    # that escapes the notes directory.
    try:
        inside = os.path.commonpath([base_dir, path]) == base_dir
    except ValueError:
        # Raised e.g. for paths on different drives on Windows.
        inside = False
    if not inside:
        return None

    if not os.path.exists(path):
        return None
    with open(path, "r", encoding="utf-8") as f:
        return json.load(f)
|
||||
|
||||
|
||||
# ── Tool definitions (OpenAI function-calling format) ─────────────────

def _function_tool(name: str, description: str, properties: Optional[dict] = None) -> dict:
    """Build one function-calling tool entry with no required parameters."""
    return {
        "type": "function",
        "function": {
            "name": name,
            "description": description,
            "parameters": {
                "type": "object",
                "properties": properties or {},
                "required": [],
            },
        },
    }


# Tools offered to the LLM; the names must match the dispatch in execute_tool.
TOOLS = [
    _function_tool(
        "lookup_transcript",
        "查询视频原始转录文本。可按时间范围筛选、按关键词搜索、或获取指定位置的内容。",
        {
            "start_time": {
                "type": "number",
                "description": "起始时间(秒),例如 0 表示视频开头,60 表示第1分钟",
            },
            "end_time": {
                "type": "number",
                "description": "结束时间(秒),不传则到末尾",
            },
            "keyword": {
                "type": "string",
                "description": "搜索关键词,返回包含该关键词的转录片段",
            },
            "position": {
                "type": "string",
                "enum": ["start", "end"],
                "description": "快捷位置:start=视频开头前30句,end=视频结尾后30句",
            },
        },
    ),
    _function_tool(
        "get_video_info",
        "获取视频的完整元信息,包括标题、作者、简介、标签、时长、播放量等。",
    ),
    _function_tool(
        "get_note_content",
        "获取 AI 生成的完整笔记内容(Markdown 格式)。",
    ),
]
|
||||
|
||||
|
||||
# ── 工具执行 ──────────────────────────────────────────────────
|
||||
|
||||
def execute_tool(task_id: str, tool_name: str, arguments: dict) -> str:
    """Run one tool call and return its result as a JSON string.

    The note data for ``task_id`` is loaded first; when it is missing an
    error payload is returned regardless of the tool name. Unknown tool
    names also produce an error payload rather than an exception.
    """
    note = _load_note_data(task_id)
    if not note:
        return json.dumps({"error": "笔记数据不存在"}, ensure_ascii=False)

    # Dispatch table; only lookup_transcript consumes the arguments.
    handlers = {
        "lookup_transcript": lambda: _lookup_transcript(note, arguments),
        "get_video_info": lambda: _get_video_info(note),
        "get_note_content": lambda: _get_note_content(note),
    }
    handler = handlers.get(tool_name)
    if handler is None:
        return json.dumps({"error": f"未知工具: {tool_name}"}, ensure_ascii=False)
    return handler()
|
||||
|
||||
|
||||
def _lookup_transcript(data: dict, args: dict) -> str:
    """Return transcript segments matching the tool arguments, as JSON.

    Filters are applied in order: position shortcut (first/last 30
    segments), time range overlap, then case-insensitive keyword match. At
    most 50 segments are returned to limit token usage.

    Args:
        data: Note data; segments are read from ``data["transcript"]["segments"]``.
        args: Tool arguments produced by the LLM. Recognised keys are
            ``position``, ``start_time``, ``end_time`` and ``keyword``.
            Missing, None or unparsable values are ignored.

    Returns:
        A JSON string with ``total_segments``, ``returned``, ``truncated``
        and ``segments``, or an ``error`` payload when there is no transcript.
    """
    segments = (data.get("transcript") or {}).get("segments") or []
    if not segments:
        return json.dumps({"error": "没有转录数据"}, ensure_ascii=False)

    args = args or {}
    position = args.get("position")
    start_time = _as_seconds(args.get("start_time"))
    end_time = _as_seconds(args.get("end_time"))
    # The model may send null or a non-string keyword; normalise before use.
    keyword = str(args.get("keyword") or "").strip().lower()

    # Position shortcut.
    if position == "start":
        filtered = segments[:30]
    elif position == "end":
        filtered = segments[-30:]
    else:
        filtered = segments

    # Time range: keep segments that overlap [start_time, end_time].
    if start_time is not None:
        filtered = [s for s in filtered if (s.get("end") or 0) >= start_time]
    if end_time is not None:
        filtered = [s for s in filtered if (s.get("start") or 0) <= end_time]

    # Keyword filter.
    if keyword:
        filtered = [s for s in filtered if keyword in (s.get("text") or "").lower()]

    # Cap the amount returned to avoid blowing up the token count.
    truncated = len(filtered) > 50
    if truncated:
        filtered = filtered[:50]

    result = {
        "total_segments": len(segments),
        "returned": len(filtered),
        "truncated": truncated,
        "segments": [
            {
                "start": round(s.get("start") or 0, 1),
                "end": round(s.get("end") or 0, 1),
                "text": s.get("text", ""),
            }
            for s in filtered
        ],
    }
    return json.dumps(result, ensure_ascii=False)


def _as_seconds(value) -> Optional[float]:
    """Convert a tool argument to float seconds; None if absent or invalid."""
    try:
        return float(value)
    except (TypeError, ValueError):
        return None
|
||||
|
||||
|
||||
def _get_video_info(data: dict) -> str:
    """Return the video's metadata as a JSON string.

    Values are read from ``data["audio_meta"]`` and its nested ``raw_info``.
    The description is capped at 1000 characters and tags at 20 entries.
    Entries whose value is None or an empty string are omitted.
    """
    am = data.get("audio_meta") or {}
    raw = am.get("raw_info") or {}
    tags = raw.get("tags")

    info = {
        "title": am.get("title") or raw.get("title", ""),
        "uploader": raw.get("uploader", ""),
        # The key may be present with a None value, which cannot be sliced.
        "description": (raw.get("description") or "")[:1000],
        "tags": tags[:20] if isinstance(tags, list) else [],
        "duration_seconds": am.get("duration", 0),
        "platform": am.get("platform", ""),
        "video_id": am.get("video_id", ""),
        "url": raw.get("webpage_url", ""),
        "view_count": raw.get("view_count"),
        "like_count": raw.get("like_count"),
        "comment_count": raw.get("comment_count"),
    }
    # Drop None and empty-string values (0 and empty lists are kept).
    info = {k: v for k, v in info.items() if v is not None and v != ""}
    return json.dumps(info, ensure_ascii=False)
|
||||
|
||||
|
||||
def _get_note_content(data: dict) -> str:
    """Return the generated note's Markdown as a JSON string.

    ``data["markdown"]`` may be a plain string or a list of versions, in
    which case the last entry's ``content`` is used. Content longer than
    5000 characters is truncated and a truncation notice is appended.
    """
    md = data.get("markdown", "")
    if isinstance(md, list):
        # Multiple versions: take the latest one.
        latest = md[-1] if md else None
        if isinstance(latest, dict):
            md = latest.get("content", "")
        elif isinstance(latest, str):
            md = latest
        else:
            md = ""
    # A null or non-string value would break len() and slicing below.
    if not isinstance(md, str):
        md = ""
    # Limit the length.
    if len(md) > 5000:
        md = md[:5000] + "\n\n... (内容过长已截断)"
    return json.dumps({"markdown": md}, ensure_ascii=False)
|
||||
226
backend/app/services/vector_store.py
Normal file
226
backend/app/services/vector_store.py
Normal file
@@ -0,0 +1,226 @@
|
||||
import json
|
||||
import os
|
||||
import re
|
||||
from typing import Optional
|
||||
|
||||
import chromadb
|
||||
from chromadb.config import Settings
|
||||
|
||||
from app.utils.logger import get_logger
|
||||
|
||||
logger = get_logger(__name__)

# Directory containing the per-task note JSON files (``<task_id>.json``);
# overridable through the NOTE_OUTPUT_DIR environment variable.
NOTE_OUTPUT_DIR = os.environ.get("NOTE_OUTPUT_DIR", "note_results")
# Directory handed to the persistent ChromaDB client; overridable through the
# VECTOR_DB_DIR environment variable.
VECTOR_DB_DIR = os.environ.get("VECTOR_DB_DIR", "vector_db")
|
||||
|
||||
|
||||
def _chunk_markdown(markdown: str) -> list[dict]:
    """Split Markdown into chunks at H2/H3 headings.

    Each chunk carries the section text plus metadata with ``source_type``
    "markdown" and the heading text as ``section_title`` ("intro" when the
    section does not start with an H2/H3 heading). Sections shorter than 30
    characters after stripping are dropped.
    """
    result = []
    # The zero-width lookahead keeps each heading attached to its section.
    for raw_section in re.split(r'(?=^#{2,3}\s)', markdown, flags=re.MULTILINE):
        body = raw_section.strip()
        if len(body) < 30:
            # Also covers empty pieces.
            continue
        heading = re.match(r'^(#{2,3})\s+(.+)', body)
        if heading is None:
            title = "intro"
        else:
            title = heading.group(2).strip()
        result.append({
            "text": body,
            "metadata": {"source_type": "markdown", "section_title": title},
        })
    return result
|
||||
|
||||
|
||||
def _chunk_transcript(segments: list[dict], window_size: int = 15, overlap: int = 3) -> list[dict]:
    """Group transcript segments into overlapping sliding-window chunks.

    Args:
        segments: Transcript segments; each is read for ``start``, ``end``
            and ``text``.
        window_size: Maximum number of segments per chunk.
        overlap: Number of segments shared by consecutive chunks. The step
            between windows is ``window_size - overlap``, but at least 1.

    Returns:
        One chunk per window. The text has one ``[<start>s] <text>`` line per
        segment; the metadata holds ``source_type`` "transcript" and the
        window's ``start_time`` / ``end_time``. Empty input or a non-positive
        ``window_size`` yields an empty list.
    """
    if not segments or window_size <= 0:
        return []

    total = len(segments)
    step = max(window_size - overlap, 1)
    chunks = []
    for i in range(0, total, step):
        window = segments[i:i + window_size]
        text = "\n".join(
            f"[{(seg.get('start') or 0):.0f}s] {seg.get('text', '')}" for seg in window
        )
        chunks.append({
            "text": text,
            "metadata": {
                "source_type": "transcript",
                # Fall back to 0 so a None timestamp never reaches the store.
                "start_time": window[0].get("start") or 0,
                "end_time": window[-1].get("end") or 0,
            },
        })
        # This window already reaches the last segment; any further window
        # would only repeat segments that are already covered.
        if i + window_size >= total:
            break
    return chunks
|
||||
|
||||
|
||||
def _build_meta_chunk(audio_meta: dict) -> list[dict]:
    """Turn video metadata into a single retrievable chunk.

    Title, uploader, description (first 500 characters), tags (first 20),
    duration, platform and URL are each rendered on their own line when
    present. Returns a one-element list, or an empty list when
    ``audio_meta`` is empty or yields no lines.
    """
    if not audio_meta:
        return []

    raw_info = audio_meta.get("raw_info", {}) or {}

    title = audio_meta.get("title") or raw_info.get("title", "")
    uploader = raw_info.get("uploader", "")
    description = raw_info.get("description", "")
    tags = raw_info.get("tags", [])
    duration = audio_meta.get("duration", 0)
    platform = audio_meta.get("platform", "")
    url = raw_info.get("webpage_url", "")

    if duration:
        m, s = divmod(int(duration), 60)
        duration_line = f"视频时长:{m}分{s}秒"
    else:
        duration_line = None

    # Candidate lines in output order; None marks a field that is absent.
    candidates = [
        f"视频标题:{title}" if title else None,
        f"视频作者/UP主:{uploader}" if uploader else None,
        f"视频简介:{description[:500]}" if description else None,
        f"标签:{', '.join(str(t) for t in tags[:20])}" if tags and isinstance(tags, list) else None,
        duration_line,
        f"平台:{platform}" if platform else None,
        f"链接:{url}" if url else None,
    ]
    lines = [line for line in candidates if line is not None]
    if not lines:
        return []

    return [{
        "text": "\n".join(lines),
        "metadata": {"source_type": "meta"},
    }]
|
||||
|
||||
|
||||
class VectorStoreManager:
    """ChromaDB-backed vector store for note content, one collection per task."""

    def __init__(self):
        """Open (creating the directory if needed) the persistent client at ``VECTOR_DB_DIR``."""
        os.makedirs(VECTOR_DB_DIR, exist_ok=True)
        self._client = chromadb.PersistentClient(
            path=VECTOR_DB_DIR,
            settings=Settings(anonymized_telemetry=False),
        )

    def _collection_name(self, task_id: str) -> str:
        """Return the collection name for a task: the task_id itself.

        NOTE(review): this relies on task_id being a UUID, which is a valid
        collection name; other values are passed through unchecked.
        """
        return task_id

    def index_task(self, task_id: str) -> None:
        """Read the saved note for ``task_id`` and (re)build its vector index.

        The note's metadata, Markdown and transcript are chunked and stored in
        a fresh collection; any existing collection for the task is replaced.
        Returns without indexing (after logging a warning) when the note file
        is missing or produces no chunks.
        """
        result_path = os.path.join(NOTE_OUTPUT_DIR, f"{task_id}.json")
        if not os.path.exists(result_path):
            logger.warning(f"笔记文件不存在,跳过索引: {result_path}")
            return

        with open(result_path, "r", encoding="utf-8") as f:
            note_data = json.load(f)

        markdown = note_data.get("markdown", "")
        if isinstance(markdown, list):
            # Multi-version form (also handled by the chat tools): index the
            # latest version instead of failing in the regex split.
            latest = markdown[-1] if markdown else None
            markdown = latest.get("content", "") if isinstance(latest, dict) else ""
        if not isinstance(markdown, str):
            markdown = ""

        # Keys may be present with a null value; treat that as empty.
        segments = (note_data.get("transcript") or {}).get("segments") or []
        audio_meta = note_data.get("audio_meta") or {}

        all_chunks = (
            _build_meta_chunk(audio_meta)
            + _chunk_markdown(markdown)
            + _chunk_transcript(segments)
        )
        if not all_chunks:
            logger.warning(f"笔记内容为空,跳过索引: {task_id}")
            return

        col_name = self._collection_name(task_id)

        # Drop any previous collection so re-indexing is idempotent. A failure
        # here normally just means there was nothing to delete.
        try:
            self._client.delete_collection(col_name)
        except Exception:
            pass

        collection = self._client.create_collection(
            name=col_name,
            metadata={"hnsw:space": "cosine"},
        )
        collection.add(
            documents=[c["text"] for c in all_chunks],
            metadatas=[c["metadata"] for c in all_chunks],
            ids=[f"{task_id}_{i}" for i in range(len(all_chunks))],
        )
        logger.info(f"向量索引完成: task_id={task_id}, chunks={len(all_chunks)}")

    def _parse_results(self, results: dict) -> list[dict]:
        """Convert a ChromaDB query result into a list of chunk dicts.

        Only the first query's rows are read. Each chunk has ``text``,
        ``metadata`` (always a dict) and ``distance`` (None when the result
        carries no distances).
        """
        if not results or not results.get("documents") or not results["documents"][0]:
            return []

        metadatas = results.get("metadatas")
        distances = results.get("distances")
        meta_row = metadatas[0] if metadatas else None
        dist_row = distances[0] if distances else None

        chunks = []
        for i, text in enumerate(results["documents"][0]):
            chunks.append({
                "text": text,
                # An individual metadata entry may be None; callers expect a dict.
                "metadata": (meta_row[i] or {}) if meta_row else {},
                "distance": dist_row[i] if dist_row else None,
            })
        return chunks

    def query(self, task_id: str, query_text: str, n_results: int = 6) -> list[dict]:
        """Retrieve chunks for ``query_text`` using a fixed quota per source.

        Up to 1 meta, 2 markdown and 3 transcript chunks are fetched so that
        all three sources are represented. Returns an empty list when the
        task has no collection.

        NOTE(review): ``n_results`` is currently not used; the total is
        determined by the quotas, which sum to the default of 6.
        """
        col_name = self._collection_name(task_id)
        try:
            collection = self._client.get_collection(col_name)
        except Exception:
            logger.warning(f"Collection 不存在: {col_name}")
            return []

        # Quota per source type.
        quotas = {"meta": 1, "markdown": 2, "transcript": 3}

        all_chunks = []
        for source_type, quota in quotas.items():
            try:
                results = collection.query(
                    query_texts=[query_text],
                    n_results=quota,
                    where={"source_type": source_type},
                )
            except Exception as e:
                # Best effort: one failing source must not block the others,
                # but the failure should be visible in the logs.
                logger.warning(
                    f"Vector query failed: task_id={task_id}, source_type={source_type}, {e}"
                )
                continue
            all_chunks.extend(self._parse_results(results))

        return all_chunks

    def delete_index(self, task_id: str) -> None:
        """Delete the vector index of a task; errors are ignored."""
        col_name = self._collection_name(task_id)
        try:
            self._client.delete_collection(col_name)
            logger.info(f"已删除向量索引: {task_id}")
        except Exception:
            pass

    def is_indexed(self, task_id: str) -> bool:
        """Return True when the task has a complete index (including meta).

        An index counts as complete only if its collection is non-empty and
        contains a meta chunk, which older indexes may lack. Any error is
        treated as "not indexed".
        """
        col_name = self._collection_name(task_id)
        try:
            col = self._client.get_collection(col_name)
            if col.count() == 0:
                return False
            # Older indexes may be missing the meta chunk.
            meta = col.get(where={"source_type": "meta"}, limit=1)
            return len(meta["ids"]) > 0
        except Exception:
            return False
|
||||
Reference in New Issue
Block a user