diff --git a/.gitignore b/.gitignore index 77b781e..963cf5e 100644 --- a/.gitignore +++ b/.gitignore @@ -320,5 +320,6 @@ cython_debug/ /backend/uploads/* /backend/.idea/* /backend/config/* +/backend/vector_db/ /BiliNote_frontend/.idea/* /BiliNote_frontend/src-tauri/bin/ \ No newline at end of file diff --git a/BillNote_frontend/package.json b/BillNote_frontend/package.json index 7465337..cbc0ae4 100644 --- a/BillNote_frontend/package.json +++ b/BillNote_frontend/package.json @@ -10,6 +10,7 @@ "preview": "vite preview" }, "dependencies": { + "@ant-design/x": "^2.4.0", "@hookform/resolvers": "^5.0.1", "@lobehub/icons": "^1.97.1", "@lobehub/icons-static-svg": "^1.45.0", diff --git a/BillNote_frontend/src/pages/HomePage/components/ChatPanel.tsx b/BillNote_frontend/src/pages/HomePage/components/ChatPanel.tsx new file mode 100644 index 0000000..c81c06f --- /dev/null +++ b/BillNote_frontend/src/pages/HomePage/components/ChatPanel.tsx @@ -0,0 +1,294 @@ +import { useState, useEffect, useCallback, useMemo } from 'react' +import { Bubble, Sender } from '@ant-design/x' +import ReactMarkdown from 'react-markdown' +import remarkGfm from 'remark-gfm' +import { Button } from '@/components/ui/button' +import { Badge } from '@/components/ui/badge' +import { Loader2, Trash2, ChevronDown, ChevronUp, BookOpen, UserRound, Bot, Maximize2, Minimize2 } from 'lucide-react' +import { toast } from 'react-hot-toast' +import { useChatStore } from '@/store/chatStore' +import { useTaskStore } from '@/store/taskStore' +import { askQuestion, getChatStatus, indexTask, type ChatSource, type IndexStatus } from '@/services/chat' + +type ChatMode = 'half' | 'full' + +interface ChatPanelProps { + taskId: string + mode: ChatMode + onModeChange: (mode: ChatMode) => void +} + +function SourceBadges({ sources }: { sources: ChatSource[] }) { + const [expanded, setExpanded] = useState(false) + + if (!sources || sources.length === 0) return null + + return ( +
+ + {expanded && ( +
+ {sources.map((s, i) => ( + + {s.source_type === 'markdown' + ? s.section_title || '笔记' + : `${(s.start_time ?? 0).toFixed(0)}s ~ ${(s.end_time ?? 0).toFixed(0)}s`} + + ))} +
+ )} +
+ ) +} + +export default function ChatPanel({ taskId, mode, onModeChange }: ChatPanelProps) { + const [input, setInput] = useState('') + const [loading, setLoading] = useState(false) + const [indexStatus, setIndexStatus] = useState(null) + + const messages = useChatStore(state => state.chatHistory[taskId]) ?? [] + const addMessage = useChatStore(state => state.addMessage) + const clearChat = useChatStore(state => state.clearChat) + + const currentTaskId = useTaskStore(state => state.currentTaskId) + const tasks = useTaskStore(state => state.tasks) + const currentTask = useMemo( + () => tasks.find(t => t.id === currentTaskId) ?? null, + [tasks, currentTaskId], + ) + + // 检查索引状态,未索引时自动触发,indexing 时轮询 + useEffect(() => { + if (!taskId) return + let cancelled = false + let timer: ReturnType | null = null + + const poll = async () => { + try { + const res = await getChatStatus(taskId) + if (cancelled) return + setIndexStatus(res.status) + + if (res.status === 'idle') { + // 未索引,触发后台索引 + await indexTask(taskId) + if (!cancelled) setIndexStatus('indexing') + } + + // indexing 状态持续轮询 + if (res.status === 'indexing' || res.status === 'idle') { + timer = setTimeout(poll, 2000) + } + } catch { + if (!cancelled) setIndexStatus('failed') + } + } + + poll() + return () => { + cancelled = true + if (timer) clearTimeout(timer) + } + }, [taskId]) + + const handleSend = useCallback( + async (value: string) => { + const question = value.trim() + if (!question || loading) return + + const providerId = currentTask?.formData?.provider_id + const modelName = currentTask?.formData?.model_name + if (!providerId || !modelName) { + toast.error('无法获取模型配置,请确认任务已完成') + return + } + + addMessage(taskId, { role: 'user', content: question }) + setInput('') + setLoading(true) + + try { + const history = messages.map(m => ({ role: m.role, content: m.content })) + const res = await askQuestion({ + task_id: taskId, + question, + history, + provider_id: providerId, + model_name: modelName, + }) + 
addMessage(taskId, { + role: 'assistant', + content: res.answer, + sources: res.sources, + }) + } catch { + toast.error('问答请求失败') + } finally { + setLoading(false) + } + }, + [loading, taskId, currentTask, messages, addMessage], + ) + + // 转换为 Bubble.List 的数据格式 + const bubbleItems = useMemo(() => { + const items = messages.map((msg, i) => ({ + key: `msg-${i}`, + role: msg.role === 'user' ? ('user' as const) : ('ai' as const), + content: msg.content, + footer: + msg.role === 'assistant' && msg.sources ? ( + + ) : undefined, + })) + + if (loading) { + items.push({ + key: 'loading', + role: 'ai' as const, + content: '思考中...', + loading: true, + } as any) + } + + return items + }, [messages, loading]) + + // Bubble 角色配置 + const roles = useMemo( + () => ({ + user: { + placement: 'end' as const, + avatar: ( +
+ +
+ ), + variant: 'filled' as const, + styles: { content: { background: '#3b82f6', color: '#fff' } }, + }, + ai: { + placement: 'start' as const, + avatar: ( +
+ +
+ ), + variant: 'outlined' as const, + contentRender: (content: any) => ( +
+ + {typeof content === 'string' ? content : String(content)} + +
+ ), + }, + }), + [], + ) + + if (indexStatus === null || indexStatus === 'indexing' || indexStatus === 'idle') { + return ( +
+ +
+

正在索引笔记内容...

+

首次使用需下载 Embedding 模型(约 80MB),请耐心等待

+
+
+ ) + } + + if (indexStatus === 'failed') { + return ( +
+ 索引失败,请重试 + +
+ ) + } + + return ( +
+ {/* 头部 */} +
+ AI 问答 +
+ + {messages.length > 0 && ( + + )} +
+
+ + {/* 消息列表 */} +
+ {messages.length === 0 && !loading ? ( +
+
+

针对笔记内容提问

+

例如:这个视频的核心观点是什么?

+
+
+ ) : ( + + )} +
+ + {/* 输入区域 */} +
+ +
+
+ ) +} diff --git a/BillNote_frontend/src/pages/HomePage/components/MarkdownHeader.tsx b/BillNote_frontend/src/pages/HomePage/components/MarkdownHeader.tsx index 27e7405..89934b5 100644 --- a/BillNote_frontend/src/pages/HomePage/components/MarkdownHeader.tsx +++ b/BillNote_frontend/src/pages/HomePage/components/MarkdownHeader.tsx @@ -1,7 +1,7 @@ 'use client' import { useEffect, useState } from 'react' -import { Copy, Download, BrainCircuit } from 'lucide-react' +import { Copy, Download, BrainCircuit, MessageSquare } from 'lucide-react' import { Button } from '@/components/ui/button' import { Select, SelectContent, SelectItem, SelectTrigger } from '@/components/ui/select' import { Tooltip, TooltipContent, TooltipProvider, TooltipTrigger } from '@/components/ui/tooltip' @@ -28,6 +28,8 @@ interface NoteHeaderProps { onDownload: () => void createAt?: string | Date setShowTranscribe: (show: boolean) => void + showChat?: false | 'half' | 'full' + setShowChat?: (mode: false | 'half' | 'full') => void } export function MarkdownHeader({ @@ -43,6 +45,8 @@ export function MarkdownHeader({ createAt, showTranscribe, setShowTranscribe, + showChat, + setShowChat, viewMode, setViewMode, }: NoteHeaderProps) { @@ -183,6 +187,24 @@ export function MarkdownHeader({ 原文参照 + {setShowChat && ( + + + + + + 基于笔记内容的 AI 问答 + + + )} ) diff --git a/BillNote_frontend/src/pages/HomePage/components/MarkdownViewer.tsx b/BillNote_frontend/src/pages/HomePage/components/MarkdownViewer.tsx index e08f076..ac95cd3 100644 --- a/BillNote_frontend/src/pages/HomePage/components/MarkdownViewer.tsx +++ b/BillNote_frontend/src/pages/HomePage/components/MarkdownViewer.tsx @@ -22,6 +22,8 @@ import { noteStyles } from '@/constant/note.ts' import { MarkdownHeader } from '@/pages/HomePage/components/MarkdownHeader.tsx' import TranscriptViewer from '@/pages/HomePage/components/transcriptViewer.tsx' import MarkmapEditor from '@/pages/HomePage/components/MarkmapComponent.tsx' +import ChatPanel from 
'@/pages/HomePage/components/ChatPanel.tsx' +import VideoBanner from '@/pages/HomePage/components/VideoBanner.tsx' interface VersionNote { ver_id: string @@ -280,6 +282,7 @@ const MarkdownViewer: FC = memo(({ status }) => { const retryTask = useTaskStore.getState().retryTask const isMultiVersion = Array.isArray(currentTask?.markdown) const [showTranscribe, setShowTranscribe] = useState(false) + const [showChat, setShowChat] = useState(false) const [viewMode, setViewMode] = useState<'map' | 'preview'>('preview') const svgRef = useRef(null) @@ -422,6 +425,8 @@ const MarkdownViewer: FC = memo(({ status }) => { createAt={createTime} showTranscribe={showTranscribe} setShowTranscribe={setShowTranscribe} + showChat={showChat} + setShowChat={setShowChat} viewMode={viewMode} setViewMode={setViewMode} /> @@ -441,14 +446,26 @@ const MarkdownViewer: FC = memo(({ status }) => {
{selectedContent && selectedContent !== 'loading' && selectedContent !== 'empty' ? ( <> - + {showChat === 'full' && currentTask ? ( +
+ +
+ ) : ( + <> + +
+ +
- {selectedContent} + {selectedContent.replace(/^>\s*来源链接:[^\n]*\n*/m, '')}
@@ -457,6 +474,14 @@ const MarkdownViewer: FC = memo(({ status }) => {
)} + {/* 侧边问答模式:markdown + ChatPanel 各占一半 */} + {showChat === 'half' && currentTask && ( +
+ +
+ )} + + )} ) : (
diff --git a/BillNote_frontend/src/pages/HomePage/components/VideoBanner.tsx b/BillNote_frontend/src/pages/HomePage/components/VideoBanner.tsx new file mode 100644 index 0000000..6acd3a9 --- /dev/null +++ b/BillNote_frontend/src/pages/HomePage/components/VideoBanner.tsx @@ -0,0 +1,86 @@ +import { ExternalLink } from 'lucide-react' +import type { AudioMeta } from '@/store/taskStore' + +interface VideoBannerProps { + audioMeta?: AudioMeta + videoUrl?: string +} + +/** 平台 label 映射 */ +const platformLabel: Record = { + bilibili: '哔哩哔哩', + youtube: 'YouTube', + douyin: '抖音', + xiaohongshu: '小红书', +} + +export default function VideoBanner({ audioMeta, videoUrl }: VideoBannerProps) { + if (!audioMeta) return null + + const rawCover = audioMeta.cover_url + // 通过后端代理加载封面,避免跨域/Referrer 限制 + const apiBase = String(import.meta.env.VITE_API_BASE_URL || 'api').replace(/\/$/, '') + const coverUrl = rawCover + ? `${apiBase}/image_proxy?url=${encodeURIComponent(rawCover)}` + : '' + const title = audioMeta.title + const uploader = audioMeta.raw_info?.uploader || '' + const platform = platformLabel[audioMeta.platform] || audioMeta.platform || '' + const originalUrl = videoUrl || audioMeta.raw_info?.webpage_url || '' + + return ( +
+ {/* 模糊背景封面 */} +
+ {coverUrl ? ( + + ) : ( +
+ )} +
+ + {/* 内容层 */} +
+ {/* 封面缩略图 */} + {coverUrl && ( + {title} + )} + + {/* 文字信息 */} +
+

+ {title} +

+
+ {uploader && {uploader}} + {uploader && platform && ·} + {platform && {platform}} +
+
+ + {/* 跳转原视频 */} + {originalUrl && ( + + + 原视频 + + )} +
+
+ ) +} diff --git a/BillNote_frontend/src/services/chat.ts b/BillNote_frontend/src/services/chat.ts new file mode 100644 index 0000000..32f541f --- /dev/null +++ b/BillNote_frontend/src/services/chat.ts @@ -0,0 +1,44 @@ +import request from '@/utils/request' + +export interface ChatMessage { + role: 'user' | 'assistant' + content: string +} + +export interface ChatSource { + text: string + source_type: 'markdown' | 'transcript' + section_title?: string + start_time?: number + end_time?: number +} + +export interface AskResponse { + answer: string + sources: ChatSource[] +} + +export type IndexStatus = 'idle' | 'indexing' | 'indexed' | 'failed' + +export interface ChatStatusResponse { + indexed: boolean + status: IndexStatus +} + +export const indexTask = async (taskId: string): Promise => { + return await request.post('/chat/index', { task_id: taskId }) +} + +export const askQuestion = async (data: { + task_id: string + question: string + history: ChatMessage[] + provider_id: string + model_name: string +}): Promise => { + return await request.post('/chat/ask', data, { timeout: 60000 }) +} + +export const getChatStatus = async (taskId: string): Promise => { + return await request.get(`/chat/status?task_id=${taskId}`) +} diff --git a/BillNote_frontend/src/store/chatStore/index.ts b/BillNote_frontend/src/store/chatStore/index.ts new file mode 100644 index 0000000..51baa57 --- /dev/null +++ b/BillNote_frontend/src/store/chatStore/index.ts @@ -0,0 +1,43 @@ +import { create } from 'zustand' +import { persist } from 'zustand/middleware' +import type { ChatSource } from '@/services/chat' + +export interface ChatMessage { + role: 'user' | 'assistant' + content: string + sources?: ChatSource[] +} + +interface ChatState { + chatHistory: Record + addMessage: (taskId: string, msg: ChatMessage) => void + clearChat: (taskId: string) => void + getMessages: (taskId: string) => ChatMessage[] +} + +export const useChatStore = create()( + persist( + (set, get) => ({ + chatHistory: 
{}, + + addMessage: (taskId, msg) => + set(state => ({ + chatHistory: { + ...state.chatHistory, + [taskId]: [...(state.chatHistory[taskId] || []), msg], + }, + })), + + clearChat: (taskId) => + set(state => { + const { [taskId]: _, ...rest } = state.chatHistory + return { chatHistory: rest } + }), + + getMessages: (taskId) => get().chatHistory[taskId] || [], + }), + { + name: 'bilinote-chat-storage', + }, + ), +) diff --git a/backend/app/__init__.py b/backend/app/__init__.py index 56179dc..f97ca9a 100644 --- a/backend/app/__init__.py +++ b/backend/app/__init__.py @@ -1,6 +1,6 @@ from fastapi import FastAPI -from .routers import note, provider, model, config +from .routers import note, provider, model, config, chat @@ -10,5 +10,6 @@ def create_app(lifespan) -> FastAPI: app.include_router(provider.router, prefix="/api") app.include_router(model.router,prefix="/api") app.include_router(config.router, prefix="/api") + app.include_router(chat.router, prefix="/api") return app diff --git a/backend/app/routers/chat.py b/backend/app/routers/chat.py new file mode 100644 index 0000000..a5633c9 --- /dev/null +++ b/backend/app/routers/chat.py @@ -0,0 +1,101 @@ +from fastapi import APIRouter, BackgroundTasks +from pydantic import BaseModel + +from app.services.chat_service import chat as chat_service +from app.services.vector_store import VectorStoreManager +from app.utils.logger import get_logger +from app.utils.response import ResponseWrapper as R + +logger = get_logger(__name__) + +router = APIRouter() + +# 索引状态追踪: task_id -> "indexing" | "indexed" | "failed" +_index_status: dict[str, str] = {} + + +class IndexRequest(BaseModel): + task_id: str + + +class ChatMessage(BaseModel): + role: str + content: str + + +class AskRequest(BaseModel): + task_id: str + question: str + history: list[ChatMessage] = [] + provider_id: str + model_name: str + + +def _do_index(task_id: str): + """后台执行索引任务。""" + try: + _index_status[task_id] = "indexing" + store = VectorStoreManager() + 
from fastapi import APIRouter, BackgroundTasks
from pydantic import BaseModel

from app.services.chat_service import chat as chat_service
from app.services.vector_store import VectorStoreManager
from app.utils.logger import get_logger
from app.utils.response import ResponseWrapper as R

logger = get_logger(__name__)

router = APIRouter()

# In-memory index-state tracker: task_id -> "indexing" | "indexed" | "failed".
# Lost on restart; chat_status falls back to the persisted vector store then.
_index_status: dict[str, str] = {}


class IndexRequest(BaseModel):
    task_id: str


class ChatMessage(BaseModel):
    role: str
    content: str


class AskRequest(BaseModel):
    task_id: str
    question: str
    history: list[ChatMessage] = []
    provider_id: str
    model_name: str


def _do_index(task_id: str):
    """Background job: build the vector index for one task and record the outcome."""
    try:
        _index_status[task_id] = "indexing"
        VectorStoreManager().index_task(task_id)
        _index_status[task_id] = "indexed"
        logger.info(f"索引完成: {task_id}")
    except Exception as e:
        _index_status[task_id] = "failed"
        logger.error(f"索引失败: {task_id}, {e}")


@router.post("/chat/index")
def index_task(data: IndexRequest, background_tasks: BackgroundTasks):
    """Kick off background indexing; returns immediately."""
    task_id = data.task_id

    # Already running — don't enqueue a second job.
    if _index_status.get(task_id) == "indexing":
        return R.success(msg="正在索引中")

    # Already persisted on disk — nothing to do.
    if VectorStoreManager().is_indexed(task_id):
        _index_status[task_id] = "indexed"
        return R.success(msg="已完成索引")

    _index_status[task_id] = "indexing"
    background_tasks.add_task(_do_index, task_id)
    return R.success(msg="开始索引")


@router.get("/chat/status")
def chat_status(task_id: str):
    """Report index state: idle / indexing / indexed / failed."""
    try:
        # In-memory record wins (covers "indexing"/"failed" which disk can't show).
        mem = _index_status.get(task_id)
        if mem:
            return R.success(data={"status": mem, "indexed": mem == "indexed"})

        # No memory record (e.g. after restart) — consult the persisted store.
        indexed = VectorStoreManager().is_indexed(task_id)
        if indexed:
            _index_status[task_id] = "indexed"
        return R.success(data={"status": "indexed" if indexed else "idle", "indexed": indexed})
    except Exception as e:
        logger.error(f"查询索引状态失败: {e}")
        return R.success(data={"status": "idle", "indexed": False})


@router.post("/chat/ask")
def ask_question(data: AskRequest):
    """RAG question answering over one task's note/transcript."""
    try:
        history = [{"role": m.role, "content": m.content} for m in data.history]
        result = chat_service(
            task_id=data.task_id,
            question=data.question,
            history=history,
            provider_id=data.provider_id,
            model_name=data.model_name,
        )
        return R.success(data=result)
    except ValueError as e:
        # Known, user-facing failures (e.g. unknown provider).
        return R.error(msg=str(e))
    except Exception as e:
        logger.error(f"Chat 问答失败: {e}", exc_info=True)
        return R.error(msg=f"问答失败: {str(e)}")
import json
from typing import Optional

from app.gpt.gpt_factory import GPTFactory
from app.models.model_config import ModelConfig
from app.services.provider import ProviderService
from app.services.vector_store import VectorStoreManager
from app.services.chat_tools import TOOLS, execute_tool
from app.utils.logger import get_logger

logger = get_logger(__name__)

SYSTEM_PROMPT = """你是一个视频笔记问答助手。你拥有以下能力:

1. 系统已自动检索了一些相关内容作为初始参考(见下方)
2. 你可以调用工具主动查询更多信息:
   - lookup_transcript: 查询视频原始转录文本(支持按时间、关键词、位置筛选)
   - get_video_info: 获取视频元信息(标题、作者、简介、标签等)
   - get_note_content: 获取完整笔记内容

--- 初始检索内容 ---
{context}
---

回答要求:
- 如果初始检索内容不足以回答问题,请主动调用工具获取更多信息
- 回答关于视频具体原话、细节时,用 lookup_transcript 查询原文
- 回答关于作者、标题等基本信息时,用 get_video_info 查询
- 请用中文回答,保持简洁准确"""


def _build_context(chunks: list[dict]) -> str:
    """Join retrieved chunks into a labelled context string for the prompt."""
    parts = []
    for chunk in chunks:
        meta = chunk.get("metadata", {})
        source_type = meta.get("source_type", "unknown")
        if source_type == "meta":
            label = "[视频信息]"
        elif source_type == "markdown":
            label = f"[笔记 - {meta.get('section_title', '')}]"
        else:
            start = meta.get("start_time", 0)
            end = meta.get("end_time", 0)
            label = f"[转录 - {start:.0f}s~{end:.0f}s]"
        parts.append(f"{label}\n{chunk['text']}")
    return "\n\n".join(parts)


def _build_sources(chunks: list[dict]) -> list[dict]:
    """Project retrieved chunks into the citation payload returned to the UI."""
    sources = []
    for chunk in chunks:
        meta = chunk.get("metadata", {})
        entry = {
            "text": chunk["text"][:200],  # preview only, keep the response small
            "source_type": meta.get("source_type", "unknown"),
        }
        if meta.get("section_title"):
            entry["section_title"] = meta["section_title"]
        if meta.get("start_time") is not None:
            entry["start_time"] = meta["start_time"]
        if meta.get("end_time") is not None:
            entry["end_time"] = meta["end_time"]
        sources.append(entry)
    return sources


def chat(
    task_id: str,
    question: str,
    history: list[dict],
    provider_id: str,
    model_name: str,
) -> dict:
    """
    RAG + tool-calling QA loop.

    1. Retrieve initial context from the vector store.
    2. Ask the LLM with tools enabled.
    3. Execute any requested tools and feed results back.
    4. Repeat until the LLM answers, or force a final tool-free call.

    Raises ValueError when the provider id is unknown.
    """
    store = VectorStoreManager()

    # 1. Initial retrieval.
    retrieved = store.query(task_id, question, n_results=6)
    context = _build_context(retrieved) if retrieved else "(未检索到相关内容,请使用工具查询)"
    sources = _build_sources(retrieved) if retrieved else []

    # 2. Assemble the message list: system prompt, trimmed history, question.
    messages = [{"role": "system", "content": SYSTEM_PROMPT.format(context=context)}]
    messages.extend(
        {"role": turn["role"], "content": turn["content"]} for turn in history[-20:]
    )
    messages.append({"role": "user", "content": question})

    # 3. Resolve the LLM client from the stored provider config.
    provider = ProviderService.get_provider_by_id(provider_id)
    if not provider:
        raise ValueError(f"未找到模型供应商: {provider_id}")

    config = ModelConfig(
        api_key=provider["api_key"],
        base_url=provider["base_url"],
        model_name=model_name,
        provider=provider["type"],
        name=provider["name"],
    )
    gpt = GPTFactory.from_config(config)

    logger.info(f"Chat: task_id={task_id}, model={model_name}")

    # 4. Tool-calling loop, capped at 3 rounds to bound latency/cost.
    max_rounds = 3
    for attempt in range(max_rounds):
        completion = gpt.client.chat.completions.create(
            model=gpt.model,
            messages=messages,
            tools=TOOLS,
            temperature=0.7,
        )

        reply = completion.choices[0].message

        # No tool request — this is the final answer.
        if not reply.tool_calls:
            return {"answer": reply.content or "", "sources": sources}

        # Record the assistant turn, then run each requested tool.
        messages.append(reply)

        for tool_call in reply.tool_calls:
            tool_name = tool_call.function.name
            try:
                tool_args = json.loads(tool_call.function.arguments)
            except json.JSONDecodeError:
                tool_args = {}

            logger.info(f"Tool call [{attempt+1}/{max_rounds}]: {tool_name}({tool_args})")

            messages.append({
                "role": "tool",
                "tool_call_id": tool_call.id,
                "content": execute_tool(task_id, tool_name, tool_args),
            })

    # Budget exhausted — one last call without tools to force a textual answer.
    completion = gpt.client.chat.completions.create(
        model=gpt.model,
        messages=messages,
        temperature=0.7,
    )

    return {"answer": completion.choices[0].message.content or "", "sources": sources}
"""
Chat function calling 工具定义与执行。
提供给 LLM 调用,用于主动查询视频原文、笔记、元信息。
"""

import json
import os
from typing import Optional

NOTE_OUTPUT_DIR = os.getenv("NOTE_OUTPUT_DIR", "note_results")


def _load_note_data(task_id: str) -> Optional[dict]:
    """Load the persisted note-result JSON for a task; None when missing."""
    path = os.path.join(NOTE_OUTPUT_DIR, f"{task_id}.json")
    if not os.path.exists(path):
        return None
    with open(path, "r", encoding="utf-8") as f:
        return json.load(f)


# ── 工具定义(OpenAI function calling 格式)──────────────────────

TOOLS = [
    {
        "type": "function",
        "function": {
            "name": "lookup_transcript",
            "description": "查询视频原始转录文本。可按时间范围筛选、按关键词搜索、或获取指定位置的内容。",
            "parameters": {
                "type": "object",
                "properties": {
                    "start_time": {
                        "type": "number",
                        "description": "起始时间(秒),例如 0 表示视频开头,60 表示第1分钟",
                    },
                    "end_time": {
                        "type": "number",
                        "description": "结束时间(秒),不传则到末尾",
                    },
                    "keyword": {
                        "type": "string",
                        "description": "搜索关键词,返回包含该关键词的转录片段",
                    },
                    "position": {
                        "type": "string",
                        "enum": ["start", "end"],
                        "description": "快捷位置:start=视频开头前30句,end=视频结尾后30句",
                    },
                },
                "required": [],
            },
        },
    },
    {
        "type": "function",
        "function": {
            "name": "get_video_info",
            "description": "获取视频的完整元信息,包括标题、作者、简介、标签、时长、播放量等。",
            "parameters": {
                "type": "object",
                "properties": {},
                "required": [],
            },
        },
    },
    {
        "type": "function",
        "function": {
            "name": "get_note_content",
            "description": "获取 AI 生成的完整笔记内容(Markdown 格式)。",
            "parameters": {
                "type": "object",
                "properties": {},
                "required": [],
            },
        },
    },
]


# ── 工具执行 ──────────────────────────────────────────────────

def execute_tool(task_id: str, tool_name: str, arguments: dict) -> str:
    """Dispatch one tool call; always returns a JSON string (errors included)."""
    data = _load_note_data(task_id)
    if not data:
        return json.dumps({"error": "笔记数据不存在"}, ensure_ascii=False)

    if tool_name == "lookup_transcript":
        return _lookup_transcript(data, arguments)
    elif tool_name == "get_video_info":
        return _get_video_info(data)
    elif tool_name == "get_note_content":
        return _get_note_content(data)
    else:
        return json.dumps({"error": f"未知工具: {tool_name}"}, ensure_ascii=False)


def _lookup_transcript(data: dict, args: dict) -> str:
    """Filter transcript segments by position / time range / keyword.

    Tolerates explicit JSON nulls from the LLM (e.g. {"keyword": null}).
    """
    # `transcript` itself may be null in the stored JSON — treat as empty.
    segments = (data.get("transcript") or {}).get("segments", [])
    if not segments:
        return json.dumps({"error": "没有转录数据"}, ensure_ascii=False)

    position = args.get("position")
    start_time = args.get("start_time")
    end_time = args.get("end_time")
    # `or ""` guards against an explicit null argument (None.strip() crashes).
    keyword = (args.get("keyword") or "").strip()

    # 快捷位置
    if position == "start":
        filtered = segments[:30]
    elif position == "end":
        filtered = segments[-30:]
    else:
        filtered = segments

    # 时间筛选:keep any segment overlapping [start_time, end_time]
    if start_time is not None:
        filtered = [s for s in filtered if s.get("end", 0) >= start_time]
    if end_time is not None:
        filtered = [s for s in filtered if s.get("start", 0) <= end_time]

    # 关键词筛选(大小写不敏感)
    if keyword:
        filtered = [s for s in filtered if keyword.lower() in s.get("text", "").lower()]

    # 限制返回量,避免 token 爆炸
    if len(filtered) > 50:
        filtered = filtered[:50]
        truncated = True
    else:
        truncated = False

    result = {
        "total_segments": len(segments),
        "returned": len(filtered),
        "truncated": truncated,
        "segments": [
            {
                "start": round(s.get("start", 0), 1),
                "end": round(s.get("end", 0), 1),
                "text": s.get("text", ""),
            }
            for s in filtered
        ],
    }
    return json.dumps(result, ensure_ascii=False)


def _get_video_info(data: dict) -> str:
    """Extract video metadata; drops empty/None fields from the payload."""
    am = data.get("audio_meta", {}) or {}
    raw = am.get("raw_info", {}) or {}

    info = {
        "title": am.get("title") or raw.get("title", ""),
        "uploader": raw.get("uploader", ""),
        # yt-dlp may store description as null — `or ""` before slicing.
        "description": (raw.get("description") or "")[:1000],
        "tags": raw.get("tags", [])[:20] if isinstance(raw.get("tags"), list) else [],
        "duration_seconds": am.get("duration", 0),
        "platform": am.get("platform", ""),
        "video_id": am.get("video_id", ""),
        "url": raw.get("webpage_url", ""),
        "view_count": raw.get("view_count"),
        "like_count": raw.get("like_count"),
        "comment_count": raw.get("comment_count"),
    }
    # 去除 None / 空字符串
    info = {k: v for k, v in info.items() if v is not None and v != ""}
    return json.dumps(info, ensure_ascii=False)


def _get_note_content(data: dict) -> str:
    """Return the generated note markdown (latest version if multi-version)."""
    md = data.get("markdown", "")
    if isinstance(md, list):
        # 多版本,取最新
        md = md[-1].get("content", "") if md else ""
    # 限制长度
    if len(md) > 5000:
        md = md[:5000] + "\n\n... (内容过长已截断)"
    return json.dumps({"markdown": md}, ensure_ascii=False)
import json
import os
import re
from typing import Optional

import chromadb
from chromadb.config import Settings

from app.utils.logger import get_logger

logger = get_logger(__name__)

NOTE_OUTPUT_DIR = os.getenv("NOTE_OUTPUT_DIR", "note_results")
VECTOR_DB_DIR = os.getenv("VECTOR_DB_DIR", "vector_db")


def _chunk_markdown(markdown: str) -> list[dict]:
    """按 H2/H3 标题拆分 markdown 为语义块。Drops chunks shorter than 30 chars."""
    sections = re.split(r'(?=^#{2,3}\s)', markdown, flags=re.MULTILINE)
    chunks = []
    for section in sections:
        section = section.strip()
        if not section or len(section) < 30:
            continue
        heading_match = re.match(r'^(#{2,3})\s+(.+)', section)
        title = heading_match.group(2).strip() if heading_match else "intro"
        chunks.append({
            "text": section,
            "metadata": {"source_type": "markdown", "section_title": title},
        })
    return chunks


def _chunk_transcript(segments: list[dict], window_size: int = 15, overlap: int = 3) -> list[dict]:
    """将转录 segments 按滑动窗口分组(window_size 句一组,overlap 句重叠)。"""
    if not segments:
        return []
    chunks = []
    step = max(window_size - overlap, 1)
    for i in range(0, len(segments), step):
        window = segments[i:i + window_size]
        if not window:
            break
        text = "\n".join(
            f"[{seg.get('start', 0):.0f}s] {seg.get('text', '')}" for seg in window
        )
        chunks.append({
            "text": text,
            "metadata": {
                "source_type": "transcript",
                "start_time": window[0].get("start", 0),
                "end_time": window[-1].get("end", 0),
            },
        })
    return chunks


def _build_meta_chunk(audio_meta: dict) -> list[dict]:
    """将视频元信息(标题、作者、描述、标签等)构建为可检索的 chunk。"""
    if not audio_meta:
        return []

    raw = audio_meta.get("raw_info", {}) or {}
    parts = []

    title = audio_meta.get("title") or raw.get("title", "")
    if title:
        parts.append(f"视频标题:{title}")

    uploader = raw.get("uploader", "")
    if uploader:
        parts.append(f"视频作者/UP主:{uploader}")

    desc = raw.get("description", "")
    if desc:
        parts.append(f"视频简介:{desc[:500]}")

    tags = raw.get("tags", [])
    if tags and isinstance(tags, list):
        parts.append(f"标签:{', '.join(str(t) for t in tags[:20])}")

    duration = audio_meta.get("duration", 0)
    if duration:
        m, s = divmod(int(duration), 60)
        parts.append(f"视频时长:{m}分{s}秒")

    platform = audio_meta.get("platform", "")
    if platform:
        parts.append(f"平台:{platform}")

    url = raw.get("webpage_url", "")
    if url:
        parts.append(f"链接:{url}")

    if not parts:
        return []

    return [{
        "text": "\n".join(parts),
        "metadata": {"source_type": "meta"},
    }]


class VectorStoreManager:
    """基于 ChromaDB 的笔记向量存储管理器(persistent, cosine space)。"""

    def __init__(self):
        os.makedirs(VECTOR_DB_DIR, exist_ok=True)
        self._client = chromadb.PersistentClient(
            path=VECTOR_DB_DIR,
            settings=Settings(anonymized_telemetry=False),
        )

    def _collection_name(self, task_id: str) -> str:
        """ChromaDB collection 名称:直接使用 task_id(UUID 格式合法)。"""
        return task_id

    def index_task(self, task_id: str) -> None:
        """读取笔记结果并建立向量索引。Idempotent:旧索引会被重建。"""
        result_path = os.path.join(NOTE_OUTPUT_DIR, f"{task_id}.json")
        if not os.path.exists(result_path):
            logger.warning(f"笔记文件不存在,跳过索引: {result_path}")
            return

        with open(result_path, "r", encoding="utf-8") as f:
            note_data = json.load(f)

        markdown = note_data.get("markdown", "")
        # 多版本笔记以 list 形式存储(见前端 isMultiVersion / chat_tools),
        # 直接传给 re.split 会 TypeError —— 取最新版本的内容。
        if isinstance(markdown, list):
            markdown = markdown[-1].get("content", "") if markdown else ""

        # transcript 字段可能显式为 null —— 按空处理。
        transcript = note_data.get("transcript") or {}
        segments = transcript.get("segments", [])

        audio_meta = note_data.get("audio_meta", {})

        meta_chunks = _build_meta_chunk(audio_meta)
        md_chunks = _chunk_markdown(markdown)
        tr_chunks = _chunk_transcript(segments)
        all_chunks = meta_chunks + md_chunks + tr_chunks

        if not all_chunks:
            logger.warning(f"笔记内容为空,跳过索引: {task_id}")
            return

        col_name = self._collection_name(task_id)

        # 删除旧 collection(幂等)
        try:
            self._client.delete_collection(col_name)
        except Exception:
            pass

        collection = self._client.create_collection(
            name=col_name,
            metadata={"hnsw:space": "cosine"},
        )

        documents = [c["text"] for c in all_chunks]
        metadatas = [c["metadata"] for c in all_chunks]
        ids = [f"{task_id}_{i}" for i in range(len(all_chunks))]

        collection.add(documents=documents, metadatas=metadatas, ids=ids)
        logger.info(f"向量索引完成: task_id={task_id}, chunks={len(all_chunks)}")

    def _parse_results(self, results: dict) -> list[dict]:
        """将 ChromaDB query 结果(嵌套 list)转换为扁平 chunk 列表。"""
        chunks = []
        if not results or not results.get("documents") or not results["documents"][0]:
            return chunks
        for i in range(len(results["documents"][0])):
            chunks.append({
                "text": results["documents"][0][i],
                "metadata": results["metadatas"][0][i] if results["metadatas"] else {},
                "distance": results["distances"][0][i] if results["distances"] else None,
            })
        return chunks

    def query(self, task_id: str, query_text: str, n_results: int = 6) -> list[dict]:
        """
        按固定配额从各来源检索:meta 1 条、markdown 2 条、transcript 3 条,
        确保三种来源都被召回。(n_results 保留在签名中以兼容调用方。)
        """
        col_name = self._collection_name(task_id)
        try:
            collection = self._client.get_collection(col_name)
        except Exception:
            logger.warning(f"Collection 不存在: {col_name}")
            return []

        all_chunks = []

        # 每种来源的配额
        quotas = {"meta": 1, "markdown": 2, "transcript": 3}

        for source_type, quota in quotas.items():
            try:
                results = collection.query(
                    query_texts=[query_text],
                    n_results=quota,
                    where={"source_type": source_type},
                )
                all_chunks.extend(self._parse_results(results))
            except Exception:
                # 某一来源缺失(如无转录)时继续其余来源
                pass

        return all_chunks

    def delete_index(self, task_id: str) -> None:
        """删除指定任务的向量索引(不存在时静默)。"""
        col_name = self._collection_name(task_id)
        try:
            self._client.delete_collection(col_name)
            logger.info(f"已删除向量索引: {task_id}")
        except Exception:
            pass

    def is_indexed(self, task_id: str) -> bool:
        """检查指定任务是否已建立完整索引(含 meta 信息)。"""
        col_name = self._collection_name(task_id)
        try:
            col = self._client.get_collection(col_name)
            if col.count() == 0:
                return False
            # 检查是否包含 meta chunk,旧索引可能缺失
            meta = col.get(where={"source_type": "meta"}, limit=1)
            return len(meta["ids"]) > 0
        except Exception:
            return False