diff --git a/.gitignore b/.gitignore
index 77b781e..963cf5e 100644
--- a/.gitignore
+++ b/.gitignore
@@ -320,5 +320,6 @@ cython_debug/
/backend/uploads/*
/backend/.idea/*
/backend/config/*
+/backend/vector_db/
/BiliNote_frontend/.idea/*
/BiliNote_frontend/src-tauri/bin/
\ No newline at end of file
diff --git a/BillNote_frontend/package.json b/BillNote_frontend/package.json
index 7465337..cbc0ae4 100644
--- a/BillNote_frontend/package.json
+++ b/BillNote_frontend/package.json
@@ -10,6 +10,7 @@
"preview": "vite preview"
},
"dependencies": {
+ "@ant-design/x": "^2.4.0",
"@hookform/resolvers": "^5.0.1",
"@lobehub/icons": "^1.97.1",
"@lobehub/icons-static-svg": "^1.45.0",
diff --git a/BillNote_frontend/src/pages/HomePage/components/ChatPanel.tsx b/BillNote_frontend/src/pages/HomePage/components/ChatPanel.tsx
new file mode 100644
index 0000000..c81c06f
--- /dev/null
+++ b/BillNote_frontend/src/pages/HomePage/components/ChatPanel.tsx
@@ -0,0 +1,294 @@
+import { useState, useEffect, useCallback, useMemo } from 'react'
+import { Bubble, Sender } from '@ant-design/x'
+import ReactMarkdown from 'react-markdown'
+import remarkGfm from 'remark-gfm'
+import { Button } from '@/components/ui/button'
+import { Badge } from '@/components/ui/badge'
+import { Loader2, Trash2, ChevronDown, ChevronUp, BookOpen, UserRound, Bot, Maximize2, Minimize2 } from 'lucide-react'
+import { toast } from 'react-hot-toast'
+import { useChatStore } from '@/store/chatStore'
+import { useTaskStore } from '@/store/taskStore'
+import { askQuestion, getChatStatus, indexTask, type ChatSource, type IndexStatus } from '@/services/chat'
+
+type ChatMode = 'half' | 'full'
+
+interface ChatPanelProps {
+ taskId: string
+ mode: ChatMode
+ onModeChange: (mode: ChatMode) => void
+}
+
+function SourceBadges({ sources }: { sources: ChatSource[] }) {
+ const [expanded, setExpanded] = useState(false)
+
+ if (!sources || sources.length === 0) return null
+
+ return (
+
+
+ {expanded && (
+
+ {sources.map((s, i) => (
+
+ {s.source_type === 'markdown'
+ ? s.section_title || '笔记'
+ : `${(s.start_time ?? 0).toFixed(0)}s ~ ${(s.end_time ?? 0).toFixed(0)}s`}
+
+ ))}
+
+ )}
+
+ )
+}
+
+export default function ChatPanel({ taskId, mode, onModeChange }: ChatPanelProps) {
+ const [input, setInput] = useState('')
+ const [loading, setLoading] = useState(false)
+ const [indexStatus, setIndexStatus] = useState(null)
+
+ const messages = useChatStore(state => state.chatHistory[taskId]) ?? []
+ const addMessage = useChatStore(state => state.addMessage)
+ const clearChat = useChatStore(state => state.clearChat)
+
+ const currentTaskId = useTaskStore(state => state.currentTaskId)
+ const tasks = useTaskStore(state => state.tasks)
+ const currentTask = useMemo(
+ () => tasks.find(t => t.id === currentTaskId) ?? null,
+ [tasks, currentTaskId],
+ )
+
+ // 检查索引状态,未索引时自动触发,indexing 时轮询
+ useEffect(() => {
+ if (!taskId) return
+ let cancelled = false
+ let timer: ReturnType | null = null
+
+ const poll = async () => {
+ try {
+ const res = await getChatStatus(taskId)
+ if (cancelled) return
+ setIndexStatus(res.status)
+
+ if (res.status === 'idle') {
+ // 未索引,触发后台索引
+ await indexTask(taskId)
+ if (!cancelled) setIndexStatus('indexing')
+ }
+
+ // indexing 状态持续轮询
+ if (res.status === 'indexing' || res.status === 'idle') {
+ timer = setTimeout(poll, 2000)
+ }
+ } catch {
+ if (!cancelled) setIndexStatus('failed')
+ }
+ }
+
+ poll()
+ return () => {
+ cancelled = true
+ if (timer) clearTimeout(timer)
+ }
+ }, [taskId])
+
+ const handleSend = useCallback(
+ async (value: string) => {
+ const question = value.trim()
+ if (!question || loading) return
+
+ const providerId = currentTask?.formData?.provider_id
+ const modelName = currentTask?.formData?.model_name
+ if (!providerId || !modelName) {
+ toast.error('无法获取模型配置,请确认任务已完成')
+ return
+ }
+
+ addMessage(taskId, { role: 'user', content: question })
+ setInput('')
+ setLoading(true)
+
+ try {
+ const history = messages.map(m => ({ role: m.role, content: m.content }))
+ const res = await askQuestion({
+ task_id: taskId,
+ question,
+ history,
+ provider_id: providerId,
+ model_name: modelName,
+ })
+ addMessage(taskId, {
+ role: 'assistant',
+ content: res.answer,
+ sources: res.sources,
+ })
+ } catch {
+ toast.error('问答请求失败')
+ } finally {
+ setLoading(false)
+ }
+ },
+ [loading, taskId, currentTask, messages, addMessage],
+ )
+
+ // 转换为 Bubble.List 的数据格式
+ const bubbleItems = useMemo(() => {
+ const items = messages.map((msg, i) => ({
+ key: `msg-${i}`,
+ role: msg.role === 'user' ? ('user' as const) : ('ai' as const),
+ content: msg.content,
+ footer:
+ msg.role === 'assistant' && msg.sources ? (
+
+ ) : undefined,
+ }))
+
+ if (loading) {
+ items.push({
+ key: 'loading',
+ role: 'ai' as const,
+ content: '思考中...',
+ loading: true,
+ } as any)
+ }
+
+ return items
+ }, [messages, loading])
+
+ // Bubble 角色配置
+ const roles = useMemo(
+ () => ({
+ user: {
+ placement: 'end' as const,
+ avatar: (
+
+
+
+ ),
+ variant: 'filled' as const,
+ styles: { content: { background: '#3b82f6', color: '#fff' } },
+ },
+ ai: {
+ placement: 'start' as const,
+ avatar: (
+
+
+
+ ),
+ variant: 'outlined' as const,
+ contentRender: (content: any) => (
+
+
+ {typeof content === 'string' ? content : String(content)}
+
+
+ ),
+ },
+ }),
+ [],
+ )
+
+ if (indexStatus === null || indexStatus === 'indexing' || indexStatus === 'idle') {
+ return (
+
+
+
+
正在索引笔记内容...
+
首次使用需下载 Embedding 模型(约 80MB),请耐心等待
+
+
+ )
+ }
+
+ if (indexStatus === 'failed') {
+ return (
+
+ 索引失败,请重试
+
+
+ )
+ }
+
+ return (
+
+ {/* 头部 */}
+
+
AI 问答
+
+
+ {messages.length > 0 && (
+
+ )}
+
+
+
+ {/* 消息列表 */}
+
+ {messages.length === 0 && !loading ? (
+
+
+
针对笔记内容提问
+
例如:这个视频的核心观点是什么?
+
+
+ ) : (
+
+ )}
+
+
+ {/* 输入区域 */}
+
+
+
+
+ )
+}
diff --git a/BillNote_frontend/src/pages/HomePage/components/MarkdownHeader.tsx b/BillNote_frontend/src/pages/HomePage/components/MarkdownHeader.tsx
index 27e7405..89934b5 100644
--- a/BillNote_frontend/src/pages/HomePage/components/MarkdownHeader.tsx
+++ b/BillNote_frontend/src/pages/HomePage/components/MarkdownHeader.tsx
@@ -1,7 +1,7 @@
'use client'
import { useEffect, useState } from 'react'
-import { Copy, Download, BrainCircuit } from 'lucide-react'
+import { Copy, Download, BrainCircuit, MessageSquare } from 'lucide-react'
import { Button } from '@/components/ui/button'
import { Select, SelectContent, SelectItem, SelectTrigger } from '@/components/ui/select'
import { Tooltip, TooltipContent, TooltipProvider, TooltipTrigger } from '@/components/ui/tooltip'
@@ -28,6 +28,8 @@ interface NoteHeaderProps {
onDownload: () => void
createAt?: string | Date
setShowTranscribe: (show: boolean) => void
+ showChat?: false | 'half' | 'full'
+ setShowChat?: (mode: false | 'half' | 'full') => void
}
export function MarkdownHeader({
@@ -43,6 +45,8 @@ export function MarkdownHeader({
createAt,
showTranscribe,
setShowTranscribe,
+ showChat,
+ setShowChat,
viewMode,
setViewMode,
}: NoteHeaderProps) {
@@ -183,6 +187,24 @@ export function MarkdownHeader({
原文参照
+ {setShowChat && (
+
+
+
+
+
+ 基于笔记内容的 AI 问答
+
+
+ )}
)
diff --git a/BillNote_frontend/src/pages/HomePage/components/MarkdownViewer.tsx b/BillNote_frontend/src/pages/HomePage/components/MarkdownViewer.tsx
index e08f076..ac95cd3 100644
--- a/BillNote_frontend/src/pages/HomePage/components/MarkdownViewer.tsx
+++ b/BillNote_frontend/src/pages/HomePage/components/MarkdownViewer.tsx
@@ -22,6 +22,8 @@ import { noteStyles } from '@/constant/note.ts'
import { MarkdownHeader } from '@/pages/HomePage/components/MarkdownHeader.tsx'
import TranscriptViewer from '@/pages/HomePage/components/transcriptViewer.tsx'
import MarkmapEditor from '@/pages/HomePage/components/MarkmapComponent.tsx'
+import ChatPanel from '@/pages/HomePage/components/ChatPanel.tsx'
+import VideoBanner from '@/pages/HomePage/components/VideoBanner.tsx'
interface VersionNote {
ver_id: string
@@ -280,6 +282,7 @@ const MarkdownViewer: FC = memo(({ status }) => {
const retryTask = useTaskStore.getState().retryTask
const isMultiVersion = Array.isArray(currentTask?.markdown)
const [showTranscribe, setShowTranscribe] = useState(false)
+ const [showChat, setShowChat] = useState(false)
const [viewMode, setViewMode] = useState<'map' | 'preview'>('preview')
const svgRef = useRef(null)
@@ -422,6 +425,8 @@ const MarkdownViewer: FC = memo(({ status }) => {
createAt={createTime}
showTranscribe={showTranscribe}
setShowTranscribe={setShowTranscribe}
+ showChat={showChat}
+ setShowChat={setShowChat}
viewMode={viewMode}
setViewMode={setViewMode}
/>
@@ -441,14 +446,26 @@ const MarkdownViewer: FC = memo(({ status }) => {
{selectedContent && selectedContent !== 'loading' && selectedContent !== 'empty' ? (
<>
-
+ {showChat === 'full' && currentTask ? (
+
+
+
+ ) : (
+ <>
+
+
+
+
- {selectedContent}
+ {selectedContent.replace(/^>\s*来源链接:[^\n]*\n*/m, '')}
@@ -457,6 +474,14 @@ const MarkdownViewer: FC = memo(({ status }) => {
)}
+ {/* 侧边问答模式:markdown + ChatPanel 各占一半 */}
+ {showChat === 'half' && currentTask && (
+
+
+
+ )}
+ >
+ )}
>
) : (
diff --git a/BillNote_frontend/src/pages/HomePage/components/VideoBanner.tsx b/BillNote_frontend/src/pages/HomePage/components/VideoBanner.tsx
new file mode 100644
index 0000000..6acd3a9
--- /dev/null
+++ b/BillNote_frontend/src/pages/HomePage/components/VideoBanner.tsx
@@ -0,0 +1,86 @@
+import { ExternalLink } from 'lucide-react'
+import type { AudioMeta } from '@/store/taskStore'
+
+interface VideoBannerProps {
+ audioMeta?: AudioMeta
+ videoUrl?: string
+}
+
+/** 平台 label 映射 */
+const platformLabel: Record
= {
+ bilibili: '哔哩哔哩',
+ youtube: 'YouTube',
+ douyin: '抖音',
+ xiaohongshu: '小红书',
+}
+
+export default function VideoBanner({ audioMeta, videoUrl }: VideoBannerProps) {
+ if (!audioMeta) return null
+
+ const rawCover = audioMeta.cover_url
+ // 通过后端代理加载封面,避免跨域/Referrer 限制
+ const apiBase = String(import.meta.env.VITE_API_BASE_URL || 'api').replace(/\/$/, '')
+ const coverUrl = rawCover
+ ? `${apiBase}/image_proxy?url=${encodeURIComponent(rawCover)}`
+ : ''
+ const title = audioMeta.title
+ const uploader = audioMeta.raw_info?.uploader || ''
+ const platform = platformLabel[audioMeta.platform] || audioMeta.platform || ''
+ const originalUrl = videoUrl || audioMeta.raw_info?.webpage_url || ''
+
+ return (
+
+ {/* 模糊背景封面 */}
+
+ {coverUrl ? (
+

+ ) : (
+
+ )}
+
+
+ {/* 内容层 */}
+
+ {/* 封面缩略图 */}
+ {coverUrl && (
+

+ )}
+
+ {/* 文字信息 */}
+
+
+ {title}
+
+
+ {uploader && {uploader}}
+ {uploader && platform && ·}
+ {platform && {platform}}
+
+
+
+ {/* 跳转原视频 */}
+ {originalUrl && (
+
+
+ 原视频
+
+ )}
+
+
+ )
+}
diff --git a/BillNote_frontend/src/services/chat.ts b/BillNote_frontend/src/services/chat.ts
new file mode 100644
index 0000000..32f541f
--- /dev/null
+++ b/BillNote_frontend/src/services/chat.ts
@@ -0,0 +1,44 @@
+import request from '@/utils/request'
+
+export interface ChatMessage {
+ role: 'user' | 'assistant'
+ content: string
+}
+
+export interface ChatSource {
+ text: string
+ source_type: 'markdown' | 'transcript'
+ section_title?: string
+ start_time?: number
+ end_time?: number
+}
+
+export interface AskResponse {
+ answer: string
+ sources: ChatSource[]
+}
+
+export type IndexStatus = 'idle' | 'indexing' | 'indexed' | 'failed'
+
+export interface ChatStatusResponse {
+ indexed: boolean
+ status: IndexStatus
+}
+
+export const indexTask = async (taskId: string): Promise => {
+ return await request.post('/chat/index', { task_id: taskId })
+}
+
+export const askQuestion = async (data: {
+ task_id: string
+ question: string
+ history: ChatMessage[]
+ provider_id: string
+ model_name: string
+}): Promise => {
+ return await request.post('/chat/ask', data, { timeout: 60000 })
+}
+
+export const getChatStatus = async (taskId: string): Promise => {
+ return await request.get(`/chat/status?task_id=${taskId}`)
+}
diff --git a/BillNote_frontend/src/store/chatStore/index.ts b/BillNote_frontend/src/store/chatStore/index.ts
new file mode 100644
index 0000000..51baa57
--- /dev/null
+++ b/BillNote_frontend/src/store/chatStore/index.ts
@@ -0,0 +1,43 @@
+import { create } from 'zustand'
+import { persist } from 'zustand/middleware'
+import type { ChatSource } from '@/services/chat'
+
+export interface ChatMessage {
+ role: 'user' | 'assistant'
+ content: string
+ sources?: ChatSource[]
+}
+
+interface ChatState {
+ chatHistory: Record
+ addMessage: (taskId: string, msg: ChatMessage) => void
+ clearChat: (taskId: string) => void
+ getMessages: (taskId: string) => ChatMessage[]
+}
+
+export const useChatStore = create()(
+ persist(
+ (set, get) => ({
+ chatHistory: {},
+
+ addMessage: (taskId, msg) =>
+ set(state => ({
+ chatHistory: {
+ ...state.chatHistory,
+ [taskId]: [...(state.chatHistory[taskId] || []), msg],
+ },
+ })),
+
+ clearChat: (taskId) =>
+ set(state => {
+ const { [taskId]: _, ...rest } = state.chatHistory
+ return { chatHistory: rest }
+ }),
+
+ getMessages: (taskId) => get().chatHistory[taskId] || [],
+ }),
+ {
+ name: 'bilinote-chat-storage',
+ },
+ ),
+)
diff --git a/backend/app/__init__.py b/backend/app/__init__.py
index 56179dc..f97ca9a 100644
--- a/backend/app/__init__.py
+++ b/backend/app/__init__.py
@@ -1,6 +1,6 @@
from fastapi import FastAPI
-from .routers import note, provider, model, config
+from .routers import note, provider, model, config, chat
@@ -10,5 +10,6 @@ def create_app(lifespan) -> FastAPI:
app.include_router(provider.router, prefix="/api")
app.include_router(model.router,prefix="/api")
app.include_router(config.router, prefix="/api")
+ app.include_router(chat.router, prefix="/api")
return app
diff --git a/backend/app/routers/chat.py b/backend/app/routers/chat.py
new file mode 100644
index 0000000..a5633c9
--- /dev/null
+++ b/backend/app/routers/chat.py
@@ -0,0 +1,101 @@
+from fastapi import APIRouter, BackgroundTasks
+from pydantic import BaseModel
+
+from app.services.chat_service import chat as chat_service
+from app.services.vector_store import VectorStoreManager
+from app.utils.logger import get_logger
+from app.utils.response import ResponseWrapper as R
+
+logger = get_logger(__name__)
+
+router = APIRouter()
+
+# In-memory, per-process index-status tracker: task_id -> "indexing" | "indexed" | "failed"
+_index_status: dict[str, str] = {}
+
+
+class IndexRequest(BaseModel):
+    task_id: str
+
+
+class ChatMessage(BaseModel):
+    role: str
+    content: str
+
+
+class AskRequest(BaseModel):
+    task_id: str
+    question: str
+    history: list[ChatMessage] = []
+    provider_id: str
+    model_name: str
+
+
+def _do_index(task_id: str):
+    """Run the indexing job in the background (scheduled via FastAPI BackgroundTasks)."""
+    try:
+        _index_status[task_id] = "indexing"
+        store = VectorStoreManager()
+        store.index_task(task_id)
+        _index_status[task_id] = "indexed"
+        logger.info(f"索引完成: {task_id}")
+    except Exception as e:
+        _index_status[task_id] = "failed"
+        logger.error(f"索引失败: {task_id}, {e}")
+
+
+@router.post("/chat/index")
+def index_task(data: IndexRequest, background_tasks: BackgroundTasks):
+ """触发后台索引,立即返回。"""
+ if _index_status.get(data.task_id) == "indexing":
+ return R.success(msg="正在索引中")
+
+ # 如果已经索引过,直接返回
+ store = VectorStoreManager()
+ if store.is_indexed(data.task_id):
+ _index_status[data.task_id] = "indexed"
+ return R.success(msg="已完成索引")
+
+ _index_status[data.task_id] = "indexing"
+ background_tasks.add_task(_do_index, data.task_id)
+ return R.success(msg="开始索引")
+
+
+@router.get("/chat/status")
+def chat_status(task_id: str):
+ """返回索引状态:idle / indexing / indexed / failed。"""
+ try:
+ # 优先检查内存状态
+ status = _index_status.get(task_id)
+ if status:
+ return R.success(data={"status": status, "indexed": status == "indexed"})
+
+ # 内存没有记录,检查持久化
+ store = VectorStoreManager()
+ indexed = store.is_indexed(task_id)
+ if indexed:
+ _index_status[task_id] = "indexed"
+ return R.success(data={"status": "indexed" if indexed else "idle", "indexed": indexed})
+ except Exception as e:
+ logger.error(f"查询索引状态失败: {e}")
+ return R.success(data={"status": "idle", "indexed": False})
+
+
+@router.post("/chat/ask")
+def ask_question(data: AskRequest):
+ """基于笔记内容的 RAG 问答。"""
+ try:
+ history = [{"role": m.role, "content": m.content} for m in data.history]
+ result = chat_service(
+ task_id=data.task_id,
+ question=data.question,
+ history=history,
+ provider_id=data.provider_id,
+ model_name=data.model_name,
+ )
+ return R.success(data=result)
+ except ValueError as e:
+ return R.error(msg=str(e))
+ except Exception as e:
+ logger.error(f"Chat 问答失败: {e}", exc_info=True)
+ return R.error(msg=f"问答失败: {str(e)}")
diff --git a/backend/app/routers/note.py b/backend/app/routers/note.py
index c050b6e..a9e2d4c 100644
--- a/backend/app/routers/note.py
+++ b/backend/app/routers/note.py
@@ -109,6 +109,12 @@ def run_note_task(task_id: str, video_url: str, platform: str, quality: Download
return
save_note_to_file(task_id, note)
+ # 自动建立向量索引(用于 AI 问答),失败不影响笔记生成
+ try:
+ from app.services.vector_store import VectorStoreManager
+ VectorStoreManager().index_task(task_id)
+ except Exception as e:
+ logger.warning(f"向量索引失败(不影响笔记): {e}")
@router.post('/delete_task')
diff --git a/backend/app/services/chat_service.py b/backend/app/services/chat_service.py
new file mode 100644
index 0000000..e7239cf
--- /dev/null
+++ b/backend/app/services/chat_service.py
@@ -0,0 +1,158 @@
+import json
+from typing import Optional
+
+from app.gpt.gpt_factory import GPTFactory
+from app.models.model_config import ModelConfig
+from app.services.provider import ProviderService
+from app.services.vector_store import VectorStoreManager
+from app.services.chat_tools import TOOLS, execute_tool
+from app.utils.logger import get_logger
+
+logger = get_logger(__name__)
+
+SYSTEM_PROMPT = """你是一个视频笔记问答助手。你拥有以下能力:
+
+1. 系统已自动检索了一些相关内容作为初始参考(见下方)
+2. 你可以调用工具主动查询更多信息:
+ - lookup_transcript: 查询视频原始转录文本(支持按时间、关键词、位置筛选)
+ - get_video_info: 获取视频元信息(标题、作者、简介、标签等)
+ - get_note_content: 获取完整笔记内容
+
+--- 初始检索内容 ---
+{context}
+---
+
+回答要求:
+- 如果初始检索内容不足以回答问题,请主动调用工具获取更多信息
+- 回答关于视频具体原话、细节时,用 lookup_transcript 查询原文
+- 回答关于作者、标题等基本信息时,用 get_video_info 查询
+- 请用中文回答,保持简洁准确"""
+
+
+def _build_context(chunks: list[dict]) -> str:
+    """Join retrieved chunks into a labeled context string for the system prompt."""
+    parts = []
+    for chunk in chunks:
+        meta = chunk.get("metadata", {})
+        source_type = meta.get("source_type", "unknown")
+        if source_type == "meta":
+            label = "[视频信息]"
+        elif source_type == "markdown":
+            label = f"[笔记 - {meta.get('section_title', '')}]"
+        else:
+            start = meta.get("start_time", 0)
+            end = meta.get("end_time", 0)
+            label = f"[转录 - {start:.0f}s~{end:.0f}s]"
+        parts.append(f"{label}\n{chunk['text']}")
+    return "\n\n".join(parts)
+
+
+def _build_sources(chunks: list[dict]) -> list[dict]:
+    """Extract citation info (truncated text + location metadata) from retrieved chunks."""
+    sources = []
+    for chunk in chunks:
+        meta = chunk.get("metadata", {})
+        source = {
+            "text": chunk["text"][:200],
+            "source_type": meta.get("source_type", "unknown"),
+        }
+        if meta.get("section_title"):
+            source["section_title"] = meta["section_title"]
+        if meta.get("start_time") is not None:
+            source["start_time"] = meta["start_time"]
+        if meta.get("end_time") is not None:
+            source["end_time"] = meta["end_time"]
+        sources.append(source)
+    return sources
+
+
+def chat(
+    task_id: str,
+    question: str,
+    history: list[dict],
+    provider_id: str,
+    model_name: str,
+) -> dict:
+    """
+    RAG + tool-calling chat. Returns {"answer": str, "sources": list[dict]}.
+    1. Retrieve initial context from the vector store
+    2. Call the LLM with tools enabled
+    3. If the LLM requests tools, execute them and feed results back
+    4. Repeat until the LLM produces a final answer (bounded rounds)
+    """
+    vector_store = VectorStoreManager()
+
+    # 1. Retrieve initial context via vector search (quota-based, see VectorStoreManager.query)
+    chunks = vector_store.query(task_id, question, n_results=6)
+    context = _build_context(chunks) if chunks else "(未检索到相关内容,请使用工具查询)"
+    sources = _build_sources(chunks) if chunks else []
+
+    # 2. Build the message list; the system prompt embeds the retrieved context
+    system_msg = SYSTEM_PROMPT.format(context=context)
+    messages = [{"role": "system", "content": system_msg}]
+
+    for msg in history[-20:]:
+        messages.append({"role": msg["role"], "content": msg["content"]})
+
+    messages.append({"role": "user", "content": question})
+
+    # 3. Resolve the LLM client from the configured provider
+    provider = ProviderService.get_provider_by_id(provider_id)
+    if not provider:
+        raise ValueError(f"未找到模型供应商: {provider_id}")
+
+    config = ModelConfig(
+        api_key=provider["api_key"],
+        base_url=provider["base_url"],
+        model_name=model_name,
+        provider=provider["type"],
+        name=provider["name"],
+    )
+    gpt = GPTFactory.from_config(config)
+
+    logger.info(f"Chat: task_id={task_id}, model={model_name}")
+
+    # 4. Tool-calling loop (at most 3 rounds)
+    max_rounds = 3
+    for round_i in range(max_rounds):
+        response = gpt.client.chat.completions.create(
+            model=gpt.model,
+            messages=messages,
+            tools=TOOLS,
+            temperature=0.7,
+        )
+
+        msg = response.choices[0].message
+
+        # No tool calls: this is the final answer.
+        if not msg.tool_calls:
+            return {"answer": msg.content or "", "sources": sources}
+
+        # Tool calls requested: echo the assistant message, then append each tool result.
+        messages.append(msg)
+
+        for tool_call in msg.tool_calls:
+            fn_name = tool_call.function.name
+            try:
+                fn_args = json.loads(tool_call.function.arguments)
+            except json.JSONDecodeError:
+                fn_args = {}
+
+            logger.info(f"Tool call [{round_i+1}/{max_rounds}]: {fn_name}({fn_args})")
+
+            result = execute_tool(task_id, fn_name, fn_args)
+
+            messages.append({
+                "role": "tool",
+                "tool_call_id": tool_call.id,
+                "content": result,
+            })
+
+    # Max rounds exhausted: one final call WITHOUT tools to force a plain answer.
+    response = gpt.client.chat.completions.create(
+        model=gpt.model,
+        messages=messages,
+        temperature=0.7,
+    )
+
+    return {"answer": response.choices[0].message.content or "", "sources": sources}
diff --git a/backend/app/services/chat_tools.py b/backend/app/services/chat_tools.py
new file mode 100644
index 0000000..186002b
--- /dev/null
+++ b/backend/app/services/chat_tools.py
@@ -0,0 +1,184 @@
+"""
+Chat function calling 工具定义与执行。
+提供给 LLM 调用,用于主动查询视频原文、笔记、元信息。
+"""
+
+import json
+import os
+from typing import Optional
+
+from app.utils.logger import get_logger
+
+logger = get_logger(__name__)
+
+NOTE_OUTPUT_DIR = os.getenv("NOTE_OUTPUT_DIR", "note_results")
+
+
+def _load_note_data(task_id: str) -> Optional[dict]:
+ path = os.path.join(NOTE_OUTPUT_DIR, f"{task_id}.json")
+ if not os.path.exists(path):
+ return None
+ with open(path, "r", encoding="utf-8") as f:
+ return json.load(f)
+
+
+# ── 工具定义(OpenAI function calling 格式)──────────────────────
+
+TOOLS = [
+ {
+ "type": "function",
+ "function": {
+ "name": "lookup_transcript",
+ "description": "查询视频原始转录文本。可按时间范围筛选、按关键词搜索、或获取指定位置的内容。",
+ "parameters": {
+ "type": "object",
+ "properties": {
+ "start_time": {
+ "type": "number",
+ "description": "起始时间(秒),例如 0 表示视频开头,60 表示第1分钟",
+ },
+ "end_time": {
+ "type": "number",
+ "description": "结束时间(秒),不传则到末尾",
+ },
+ "keyword": {
+ "type": "string",
+ "description": "搜索关键词,返回包含该关键词的转录片段",
+ },
+ "position": {
+ "type": "string",
+ "enum": ["start", "end"],
+ "description": "快捷位置:start=视频开头前30句,end=视频结尾后30句",
+ },
+ },
+ "required": [],
+ },
+ },
+ },
+ {
+ "type": "function",
+ "function": {
+ "name": "get_video_info",
+ "description": "获取视频的完整元信息,包括标题、作者、简介、标签、时长、播放量等。",
+ "parameters": {
+ "type": "object",
+ "properties": {},
+ "required": [],
+ },
+ },
+ },
+ {
+ "type": "function",
+ "function": {
+ "name": "get_note_content",
+ "description": "获取 AI 生成的完整笔记内容(Markdown 格式)。",
+ "parameters": {
+ "type": "object",
+ "properties": {},
+ "required": [],
+ },
+ },
+ },
+]
+
+
+# ── 工具执行 ──────────────────────────────────────────────────
+
+def execute_tool(task_id: str, tool_name: str, arguments: dict) -> str:
+    """Dispatch a tool call by name; always returns a JSON string (result or error)."""
+    data = _load_note_data(task_id)
+    if not data:
+        return json.dumps({"error": "笔记数据不存在"}, ensure_ascii=False)
+
+    if tool_name == "lookup_transcript":
+        return _lookup_transcript(data, arguments)
+    elif tool_name == "get_video_info":
+        return _get_video_info(data)
+    elif tool_name == "get_note_content":
+        return _get_note_content(data)
+    else:
+        return json.dumps({"error": f"未知工具: {tool_name}"}, ensure_ascii=False)
+
+
+def _lookup_transcript(data: dict, args: dict) -> str:
+    segments = data.get("transcript", {}).get("segments", [])
+    if not segments:
+        return json.dumps({"error": "没有转录数据"}, ensure_ascii=False)
+
+    position = args.get("position")
+    start_time = args.get("start_time")
+    end_time = args.get("end_time")
+    keyword = args.get("keyword", "").strip()
+
+    # Shortcut positions. NOTE(review): `keyword` above assumes a string — a JSON null would crash .strip(); confirm.
+    if position == "start":
+        filtered = segments[:30]
+    elif position == "end":
+        filtered = segments[-30:]
+    else:
+        filtered = segments
+
+    # Time-range filter: keep segments overlapping [start_time, end_time]
+    if start_time is not None:
+        filtered = [s for s in filtered if s.get("end", 0) >= start_time]
+    if end_time is not None:
+        filtered = [s for s in filtered if s.get("start", 0) <= end_time]
+
+    # Keyword filter (case-insensitive substring match)
+    if keyword:
+        filtered = [s for s in filtered if keyword.lower() in s.get("text", "").lower()]
+
+    # Cap the payload at 50 segments to avoid token blow-up
+    if len(filtered) > 50:
+        filtered = filtered[:50]
+        truncated = True
+    else:
+        truncated = False
+
+    result = {
+        "total_segments": len(data.get("transcript", {}).get("segments", [])),
+        "returned": len(filtered),
+        "truncated": truncated,
+        "segments": [
+            {
+                "start": round(s.get("start", 0), 1),
+                "end": round(s.get("end", 0), 1),
+                "text": s.get("text", ""),
+            }
+            for s in filtered
+        ],
+    }
+    return json.dumps(result, ensure_ascii=False)
+
+
+def _get_video_info(data: dict) -> str:
+    am = data.get("audio_meta", {})
+    raw = am.get("raw_info", {}) or {}
+
+    info = {
+        "title": am.get("title") or raw.get("title", ""),
+        "uploader": raw.get("uploader", ""),
+        "description": raw.get("description", "")[:1000],
+        "tags": raw.get("tags", [])[:20] if isinstance(raw.get("tags"), list) else [],
+        "duration_seconds": am.get("duration", 0),
+        "platform": am.get("platform", ""),
+        "video_id": am.get("video_id", ""),
+        "url": raw.get("webpage_url", ""),
+        "view_count": raw.get("view_count"),
+        "like_count": raw.get("like_count"),
+        "comment_count": raw.get("comment_count"),
+    }
+    # Drop None/empty values. NOTE(review): "description" may be present-but-None from yt-dlp — the [:1000] slice above would raise; confirm.
+    info = {k: v for k, v in info.items() if v is not None and v != ""}
+    return json.dumps(info, ensure_ascii=False)
+
+
+def _get_note_content(data: dict) -> str:
+    md = data.get("markdown", "")
+    if isinstance(md, list):
+        # Multi-version note: take the latest version's content
+        md = md[-1].get("content", "") if md else ""
+    # Truncate so the tool result stays token-friendly
+    if len(md) > 5000:
+        md = md[:5000] + "\n\n... (内容过长已截断)"
+    return json.dumps({"markdown": md}, ensure_ascii=False)
diff --git a/backend/app/services/vector_store.py b/backend/app/services/vector_store.py
new file mode 100644
index 0000000..464e9f2
--- /dev/null
+++ b/backend/app/services/vector_store.py
@@ -0,0 +1,226 @@
+import json
+import os
+import re
+from typing import Optional
+
+import chromadb
+from chromadb.config import Settings
+
+from app.utils.logger import get_logger
+
+logger = get_logger(__name__)
+
+NOTE_OUTPUT_DIR = os.getenv("NOTE_OUTPUT_DIR", "note_results")
+VECTOR_DB_DIR = os.getenv("VECTOR_DB_DIR", "vector_db")
+
+
+def _chunk_markdown(markdown: str) -> list[dict]:
+    """Split markdown into semantic chunks at H2/H3 headings (drops chunks < 30 chars)."""
+    sections = re.split(r'(?=^#{2,3}\s)', markdown, flags=re.MULTILINE)
+    chunks = []
+    for section in sections:
+        section = section.strip()
+        if not section or len(section) < 30:
+            continue
+        heading_match = re.match(r'^(#{2,3})\s+(.+)', section)
+        title = heading_match.group(2).strip() if heading_match else "intro"
+        chunks.append({
+            "text": section,
+            "metadata": {"source_type": "markdown", "section_title": title},
+        })
+    return chunks
+
+
+def _chunk_transcript(segments: list[dict], window_size: int = 15, overlap: int = 3) -> list[dict]:
+    """Group transcript segments into sliding windows of `window_size` with `overlap` shared segments."""
+    if not segments:
+        return []
+    chunks = []
+    step = max(window_size - overlap, 1)
+    for i in range(0, len(segments), step):
+        window = segments[i:i + window_size]
+        if not window:
+            break
+        text = "\n".join(
+            f"[{seg.get('start', 0):.0f}s] {seg.get('text', '')}" for seg in window
+        )
+        chunks.append({
+            "text": text,
+            "metadata": {
+                "source_type": "transcript",
+                "start_time": window[0].get("start", 0),
+                "end_time": window[-1].get("end", 0),
+            },
+        })
+    return chunks
+
+
+def _build_meta_chunk(audio_meta: dict) -> list[dict]:
+    """Build one retrievable chunk from video metadata (title, uploader, description, tags, duration, platform, URL)."""
+    if not audio_meta:
+        return []
+
+    raw = audio_meta.get("raw_info", {}) or {}
+    parts = []
+
+    title = audio_meta.get("title") or raw.get("title", "")
+    if title:
+        parts.append(f"视频标题:{title}")
+
+    uploader = raw.get("uploader", "")
+    if uploader:
+        parts.append(f"视频作者/UP主:{uploader}")
+
+    desc = raw.get("description", "")
+    if desc:
+        parts.append(f"视频简介:{desc[:500]}")
+
+    tags = raw.get("tags", [])
+    if tags and isinstance(tags, list):
+        parts.append(f"标签:{', '.join(str(t) for t in tags[:20])}")
+
+    duration = audio_meta.get("duration", 0)
+    if duration:
+        m, s = divmod(int(duration), 60)
+        parts.append(f"视频时长:{m}分{s}秒")
+
+    platform = audio_meta.get("platform", "")
+    if platform:
+        parts.append(f"平台:{platform}")
+
+    url = raw.get("webpage_url", "")
+    if url:
+        parts.append(f"链接:{url}")
+
+    if not parts:
+        return []
+
+    return [{
+        "text": "\n".join(parts),
+        "metadata": {"source_type": "meta"},
+    }]
+
+
+class VectorStoreManager:
+    """Note vector-store manager backed by ChromaDB (persistent, local)."""
+
+    def __init__(self):
+        os.makedirs(VECTOR_DB_DIR, exist_ok=True)
+        self._client = chromadb.PersistentClient(
+            path=VECTOR_DB_DIR,
+            settings=Settings(anonymized_telemetry=False),
+        )
+
+    def _collection_name(self, task_id: str) -> str:
+        """Collection name is the task_id itself (UUID strings are valid ChromaDB names)."""
+        return task_id
+
+    def index_task(self, task_id: str) -> None:
+        """Read the saved note result and (re)build its vector index. NOTE(review): assumes `markdown` is a str — multi-version notes store a list (see chat_tools._get_note_content); confirm."""
+        result_path = os.path.join(NOTE_OUTPUT_DIR, f"{task_id}.json")
+        if not os.path.exists(result_path):
+            logger.warning(f"笔记文件不存在,跳过索引: {result_path}")
+            return
+
+        with open(result_path, "r", encoding="utf-8") as f:
+            note_data = json.load(f)
+
+        markdown = note_data.get("markdown", "")
+        transcript = note_data.get("transcript", {})
+        segments = transcript.get("segments", [])
+
+        audio_meta = note_data.get("audio_meta", {})
+
+        meta_chunks = _build_meta_chunk(audio_meta)
+        md_chunks = _chunk_markdown(markdown)
+        tr_chunks = _chunk_transcript(segments)
+        all_chunks = meta_chunks + md_chunks + tr_chunks
+
+        if not all_chunks:
+            logger.warning(f"笔记内容为空,跳过索引: {task_id}")
+            return
+
+        col_name = self._collection_name(task_id)
+
+        # Drop any stale collection first so re-indexing is idempotent
+        try:
+            self._client.delete_collection(col_name)
+        except Exception:
+            pass
+
+        collection = self._client.create_collection(
+            name=col_name,
+            metadata={"hnsw:space": "cosine"},
+        )
+
+        documents = [c["text"] for c in all_chunks]
+        metadatas = [c["metadata"] for c in all_chunks]
+        ids = [f"{task_id}_{i}" for i in range(len(all_chunks))]
+
+        collection.add(documents=documents, metadatas=metadatas, ids=ids)
+        logger.info(f"向量索引完成: task_id={task_id}, chunks={len(all_chunks)}")
+
+    def _parse_results(self, results: dict) -> list[dict]:
+        """Flatten a ChromaDB query result (first query only) into a list of chunk dicts."""
+        chunks = []
+        if not results or not results.get("documents") or not results["documents"][0]:
+            return chunks
+        for i in range(len(results["documents"][0])):
+            chunks.append({
+                "text": results["documents"][0][i],
+                "metadata": results["metadatas"][0][i] if results["metadatas"] else {},
+                "distance": results["distances"][0][i] if results["distances"] else None,
+            })
+        return chunks
+
+    def query(self, task_id: str, query_text: str, n_results: int = 6) -> list[dict]:
+        """
+        Retrieve with fixed per-source quotas (1 meta, 2 markdown, 3 transcript) so every
+        source type is represented. NOTE(review): the `n_results` parameter is unused.
+        """
+        col_name = self._collection_name(task_id)
+        try:
+            collection = self._client.get_collection(col_name)
+        except Exception:
+            logger.warning(f"Collection 不存在: {col_name}")
+            return []
+
+        all_chunks = []
+
+        # Per-source retrieval quotas
+        quotas = {"meta": 1, "markdown": 2, "transcript": 3}
+
+        for source_type, quota in quotas.items():
+            try:
+                results = collection.query(
+                    query_texts=[query_text],
+                    n_results=quota,
+                    where={"source_type": source_type},
+                )
+                all_chunks.extend(self._parse_results(results))
+            except Exception:
+                pass
+
+        return all_chunks
+
+    def delete_index(self, task_id: str) -> None:
+        """Delete the vector index for a task (no-op if it does not exist)."""
+        col_name = self._collection_name(task_id)
+        try:
+            self._client.delete_collection(col_name)
+            logger.info(f"已删除向量索引: {task_id}")
+        except Exception:
+            pass
+
+    def is_indexed(self, task_id: str) -> bool:
+        """Return True only if the task has a complete index (including the meta chunk)."""
+        col_name = self._collection_name(task_id)
+        try:
+            col = self._client.get_collection(col_name)
+            if col.count() == 0:
+                return False
+            # Older indexes may lack the meta chunk; treat those as incomplete so they get rebuilt
+            meta = col.get(where={"source_type": "meta"}, limit=1)
+            return len(meta["ids"]) > 0
+        except Exception:
+            return False
diff --git a/backend/requirements.txt b/backend/requirements.txt
index e936f3e..b3afa64 100644
--- a/backend/requirements.txt
+++ b/backend/requirements.txt
@@ -16,6 +16,7 @@ celery==5.5.1
certifi==2025.1.31
cffi==1.17.1
charset-normalizer==3.4.1
+chromadb>=0.5.0
click==8.1.8
click-didyoumean==0.3.1
click-plugins==1.1.1