Merge pull request #299 from JefferyHcool/feature/note-qa-chat-optimize

Feature/note qa chat optimize
This commit is contained in:
Jianwu Huang
2026-03-23 16:00:15 +08:00
committed by GitHub
15 changed files with 1197 additions and 4 deletions

1
.gitignore vendored
View File

@@ -320,5 +320,6 @@ cython_debug/
/backend/uploads/* /backend/uploads/*
/backend/.idea/* /backend/.idea/*
/backend/config/* /backend/config/*
/backend/vector_db/
/BiliNote_frontend/.idea/* /BiliNote_frontend/.idea/*
/BiliNote_frontend/src-tauri/bin/ /BiliNote_frontend/src-tauri/bin/

View File

@@ -10,6 +10,7 @@
"preview": "vite preview" "preview": "vite preview"
}, },
"dependencies": { "dependencies": {
"@ant-design/x": "^2.4.0",
"@hookform/resolvers": "^5.0.1", "@hookform/resolvers": "^5.0.1",
"@lobehub/icons": "^1.97.1", "@lobehub/icons": "^1.97.1",
"@lobehub/icons-static-svg": "^1.45.0", "@lobehub/icons-static-svg": "^1.45.0",

View File

@@ -0,0 +1,294 @@
import { useState, useEffect, useCallback, useMemo } from 'react'
import { Bubble, Sender } from '@ant-design/x'
import ReactMarkdown from 'react-markdown'
import remarkGfm from 'remark-gfm'
import { Button } from '@/components/ui/button'
import { Badge } from '@/components/ui/badge'
import { Loader2, Trash2, ChevronDown, ChevronUp, BookOpen, UserRound, Bot, Maximize2, Minimize2 } from 'lucide-react'
import { toast } from 'react-hot-toast'
import { useChatStore } from '@/store/chatStore'
import { useTaskStore } from '@/store/taskStore'
import { askQuestion, getChatStatus, indexTask, type ChatSource, type IndexStatus } from '@/services/chat'
/** Chat panel layout mode: 'half' = side-by-side with the note, 'full' = whole content area. */
type ChatMode = 'half' | 'full'

/** Props for ChatPanel. */
interface ChatPanelProps {
  /** Task whose note/transcript grounds the Q&A. */
  taskId: string
  /** Current layout mode (controlled by the parent). */
  mode: ChatMode
  /** Ask the parent to switch between half/full layout. */
  onModeChange: (mode: ChatMode) => void
}
/**
 * Collapsible list of citation badges for one assistant answer.
 * Markdown sources show their section title; transcript sources show the
 * time range. Renders nothing when there are no sources.
 */
function SourceBadges({ sources }: { sources: ChatSource[] }) {
  const [expanded, setExpanded] = useState(false)
  // Early return is after the hook call, so the hook order stays stable.
  if (!sources || sources.length === 0) return null
  return (
    <div className="mt-1.5">
      <button
        onClick={() => setExpanded(!expanded)}
        className="flex items-center gap-1 text-xs text-neutral-400 hover:text-neutral-600"
      >
        <BookOpen className="h-3 w-3" />
        <span> ({sources.length})</span>
        {expanded ? <ChevronUp className="h-3 w-3" /> : <ChevronDown className="h-3 w-3" />}
      </button>
      {expanded && (
        <div className="mt-1 flex flex-wrap gap-1">
          {sources.map((s, i) => (
            <Badge key={i} variant="outline" className="text-xs font-normal">
              {s.source_type === 'markdown'
                ? s.section_title || '笔记'
                : `${(s.start_time ?? 0).toFixed(0)}s ~ ${(s.end_time ?? 0).toFixed(0)}s`}
            </Badge>
          ))}
        </div>
      )}
    </div>
  )
}
/**
 * AI Q&A side panel for a single note task.
 *
 * On mount it polls the backend index status, auto-triggers indexing when the
 * task has never been indexed, and renders the chat UI only once the vector
 * index is ready. Chat history is persisted per-task via useChatStore.
 */
export default function ChatPanel({ taskId, mode, onModeChange }: ChatPanelProps) {
  const [input, setInput] = useState('')
  const [loading, setLoading] = useState(false)
  // null = status unknown (first poll has not returned yet)
  const [indexStatus, setIndexStatus] = useState<IndexStatus | null>(null)
  const messages = useChatStore(state => state.chatHistory[taskId]) ?? []
  const addMessage = useChatStore(state => state.addMessage)
  const clearChat = useChatStore(state => state.clearChat)
  const currentTaskId = useTaskStore(state => state.currentTaskId)
  const tasks = useTaskStore(state => state.tasks)
  const currentTask = useMemo(
    () => tasks.find(t => t.id === currentTaskId) ?? null,
    [tasks, currentTaskId],
  )
  // Check index status; auto-trigger indexing when idle; keep polling while indexing.
  useEffect(() => {
    if (!taskId) return
    let cancelled = false
    let timer: ReturnType<typeof setTimeout> | null = null
    const poll = async () => {
      try {
        const res = await getChatStatus(taskId)
        if (cancelled) return
        setIndexStatus(res.status)
        if (res.status === 'idle') {
          // Never indexed — kick off background indexing.
          await indexTask(taskId)
          if (!cancelled) setIndexStatus('indexing')
        }
        // Keep polling while indexing is in flight.
        if (res.status === 'indexing' || res.status === 'idle') {
          timer = setTimeout(poll, 2000)
        }
      } catch {
        if (!cancelled) setIndexStatus('failed')
      }
    }
    poll()
    return () => {
      cancelled = true
      if (timer) clearTimeout(timer)
    }
  }, [taskId])
  const handleSend = useCallback(
    async (value: string) => {
      const question = value.trim()
      if (!question || loading) return
      const providerId = currentTask?.formData?.provider_id
      const modelName = currentTask?.formData?.model_name
      if (!providerId || !modelName) {
        toast.error('无法获取模型配置,请确认任务已完成')
        return
      }
      addMessage(taskId, { role: 'user', content: question })
      setInput('')
      setLoading(true)
      try {
        // `messages` is captured before addMessage re-renders, so the history
        // sent to the backend excludes the question just added (the question
        // itself goes in the dedicated `question` field).
        const history = messages.map(m => ({ role: m.role, content: m.content }))
        const res = await askQuestion({
          task_id: taskId,
          question,
          history,
          provider_id: providerId,
          model_name: modelName,
        })
        addMessage(taskId, {
          role: 'assistant',
          content: res.answer,
          sources: res.sources,
        })
      } catch {
        toast.error('问答请求失败')
      } finally {
        setLoading(false)
      }
    },
    [loading, taskId, currentTask, messages, addMessage],
  )
  // Convert stored messages into Bubble.List items.
  const bubbleItems = useMemo(() => {
    const items = messages.map((msg, i) => ({
      key: `msg-${i}`,
      role: msg.role === 'user' ? ('user' as const) : ('ai' as const),
      content: msg.content,
      footer:
        msg.role === 'assistant' && msg.sources ? (
          <SourceBadges sources={msg.sources} />
        ) : undefined,
    }))
    if (loading) {
      // Synthetic "thinking" bubble; `loading` is not part of the inferred
      // item shape above, hence the cast.
      items.push({
        key: 'loading',
        role: 'ai' as const,
        content: '思考中...',
        loading: true,
      } as any)
    }
    return items
  }, [messages, loading])
  // Per-role Bubble presentation config (avatar, placement, markdown rendering).
  const roles = useMemo(
    () => ({
      user: {
        placement: 'end' as const,
        avatar: (
          <div className="flex h-7 w-7 items-center justify-center rounded-full bg-blue-500 text-white">
            <UserRound className="h-4 w-4" />
          </div>
        ),
        variant: 'filled' as const,
        styles: { content: { background: '#3b82f6', color: '#fff' } },
      },
      ai: {
        placement: 'start' as const,
        avatar: (
          <div className="flex h-7 w-7 items-center justify-center rounded-full bg-neutral-500 text-white">
            <Bot className="h-4 w-4" />
          </div>
        ),
        variant: 'outlined' as const,
        contentRender: (content: any) => (
          <div className="markdown-body prose prose-sm max-w-none prose-p:my-1 prose-li:my-0.5 prose-headings:my-2">
            <ReactMarkdown remarkPlugins={[remarkGfm]}>
              {typeof content === 'string' ? content : String(content)}
            </ReactMarkdown>
          </div>
        ),
      },
    }),
    [],
  )
  // Index not ready yet: show the spinner screen (idle also shows it, since
  // the poll effect immediately escalates idle -> indexing).
  if (indexStatus === null || indexStatus === 'indexing' || indexStatus === 'idle') {
    return (
      <div className="flex h-full flex-col items-center justify-center gap-3 text-neutral-400">
        <Loader2 className="h-6 w-6 animate-spin" />
        <div className="text-center">
          <p className="text-sm font-medium">...</p>
          <p className="mt-1 text-xs">使 Embedding 80MB</p>
        </div>
      </div>
    )
  }
  if (indexStatus === 'failed') {
    return (
      <div className="flex h-full flex-col items-center justify-center gap-2 text-neutral-400">
        <span className="text-sm"></span>
        <Button
          size="sm"
          variant="outline"
          onClick={async () => {
            // Retry: optimistically flip to indexing; roll back on failure.
            setIndexStatus('indexing')
            try {
              await indexTask(taskId)
            } catch {
              toast.error('索引请求失败')
              setIndexStatus('failed')
            }
          }}
        >
        </Button>
      </div>
    )
  }
  return (
    <div className="flex h-full flex-col border-l">
      {/* Header: title + layout toggle + clear-history */}
      <div className="flex items-center justify-between border-b px-3 py-2">
        <span className="text-sm font-medium">AI </span>
        <div className="flex items-center gap-1">
          <Button
            variant="ghost"
            size="sm"
            className="h-7 px-2 text-neutral-400 hover:text-neutral-600"
            onClick={() => onModeChange(mode === 'half' ? 'full' : 'half')}
            title={mode === 'half' ? '全屏' : '半屏'}
          >
            {mode === 'half' ? (
              <Maximize2 className="h-3.5 w-3.5" />
            ) : (
              <Minimize2 className="h-3.5 w-3.5" />
            )}
          </Button>
          {messages.length > 0 && (
            <Button
              variant="ghost"
              size="sm"
              className="h-7 px-2 text-neutral-400 hover:text-red-500"
              onClick={() => clearChat(taskId)}
            >
              <Trash2 className="h-3.5 w-3.5" />
            </Button>
          )}
        </div>
      </div>
      {/* Message list */}
      <div className="flex-1 overflow-hidden">
        {messages.length === 0 && !loading ? (
          <div className="flex h-full items-center justify-center text-center text-sm text-neutral-400">
            <div>
              <p></p>
              <p className="mt-1 text-xs"></p>
            </div>
          </div>
        ) : (
          /* NOTE(review): @ant-design/x Bubble.List documents this prop as
             `roles` (plural) — verify `role` is actually honored here. */
          <Bubble.List
            items={bubbleItems}
            role={roles}
            style={{ height: '100%' }}
          />
        )}
      </div>
      {/* Input area */}
      <div className="border-t px-3 py-2">
        <Sender
          value={input}
          onChange={setInput}
          onSubmit={handleSend}
          loading={loading}
          placeholder="输入你的问题..."
        />
      </div>
    </div>
  )
}

View File

@@ -1,7 +1,7 @@
'use client' 'use client'
import { useEffect, useState } from 'react' import { useEffect, useState } from 'react'
import { Copy, Download, BrainCircuit } from 'lucide-react' import { Copy, Download, BrainCircuit, MessageSquare } from 'lucide-react'
import { Button } from '@/components/ui/button' import { Button } from '@/components/ui/button'
import { Select, SelectContent, SelectItem, SelectTrigger } from '@/components/ui/select' import { Select, SelectContent, SelectItem, SelectTrigger } from '@/components/ui/select'
import { Tooltip, TooltipContent, TooltipProvider, TooltipTrigger } from '@/components/ui/tooltip' import { Tooltip, TooltipContent, TooltipProvider, TooltipTrigger } from '@/components/ui/tooltip'
@@ -28,6 +28,8 @@ interface NoteHeaderProps {
onDownload: () => void onDownload: () => void
createAt?: string | Date createAt?: string | Date
setShowTranscribe: (show: boolean) => void setShowTranscribe: (show: boolean) => void
showChat?: false | 'half' | 'full'
setShowChat?: (mode: false | 'half' | 'full') => void
} }
export function MarkdownHeader({ export function MarkdownHeader({
@@ -43,6 +45,8 @@ export function MarkdownHeader({
createAt, createAt,
showTranscribe, showTranscribe,
setShowTranscribe, setShowTranscribe,
showChat,
setShowChat,
viewMode, viewMode,
setViewMode, setViewMode,
}: NoteHeaderProps) { }: NoteHeaderProps) {
@@ -183,6 +187,24 @@ export function MarkdownHeader({
<TooltipContent></TooltipContent> <TooltipContent></TooltipContent>
</Tooltip> </Tooltip>
</TooltipProvider> </TooltipProvider>
{setShowChat && (
<TooltipProvider>
<Tooltip>
<TooltipTrigger asChild>
<Button
onClick={() => setShowChat(showChat ? false : 'half')}
variant={showChat ? 'default' : 'ghost'}
size="sm"
className="h-8 px-2"
>
<MessageSquare className="mr-1.5 h-4 w-4" />
<span className="text-sm">AI </span>
</Button>
</TooltipTrigger>
<TooltipContent> AI </TooltipContent>
</Tooltip>
</TooltipProvider>
)}
</div> </div>
</div> </div>
) )

View File

@@ -22,6 +22,8 @@ import { noteStyles } from '@/constant/note.ts'
import { MarkdownHeader } from '@/pages/HomePage/components/MarkdownHeader.tsx' import { MarkdownHeader } from '@/pages/HomePage/components/MarkdownHeader.tsx'
import TranscriptViewer from '@/pages/HomePage/components/transcriptViewer.tsx' import TranscriptViewer from '@/pages/HomePage/components/transcriptViewer.tsx'
import MarkmapEditor from '@/pages/HomePage/components/MarkmapComponent.tsx' import MarkmapEditor from '@/pages/HomePage/components/MarkmapComponent.tsx'
import ChatPanel from '@/pages/HomePage/components/ChatPanel.tsx'
import VideoBanner from '@/pages/HomePage/components/VideoBanner.tsx'
interface VersionNote { interface VersionNote {
ver_id: string ver_id: string
@@ -280,6 +282,7 @@ const MarkdownViewer: FC<MarkdownViewerProps> = memo(({ status }) => {
const retryTask = useTaskStore.getState().retryTask const retryTask = useTaskStore.getState().retryTask
const isMultiVersion = Array.isArray(currentTask?.markdown) const isMultiVersion = Array.isArray(currentTask?.markdown)
const [showTranscribe, setShowTranscribe] = useState(false) const [showTranscribe, setShowTranscribe] = useState(false)
const [showChat, setShowChat] = useState<false | 'half' | 'full'>(false)
const [viewMode, setViewMode] = useState<'map' | 'preview'>('preview') const [viewMode, setViewMode] = useState<'map' | 'preview'>('preview')
const svgRef = useRef<SVGSVGElement>(null) const svgRef = useRef<SVGSVGElement>(null)
@@ -422,6 +425,8 @@ const MarkdownViewer: FC<MarkdownViewerProps> = memo(({ status }) => {
createAt={createTime} createAt={createTime}
showTranscribe={showTranscribe} showTranscribe={showTranscribe}
setShowTranscribe={setShowTranscribe} setShowTranscribe={setShowTranscribe}
showChat={showChat}
setShowChat={setShowChat}
viewMode={viewMode} viewMode={viewMode}
setViewMode={setViewMode} setViewMode={setViewMode}
/> />
@@ -441,14 +446,26 @@ const MarkdownViewer: FC<MarkdownViewerProps> = memo(({ status }) => {
<div className="flex flex-1 overflow-hidden bg-white py-2"> <div className="flex flex-1 overflow-hidden bg-white py-2">
{selectedContent && selectedContent !== 'loading' && selectedContent !== 'empty' ? ( {selectedContent && selectedContent !== 'loading' && selectedContent !== 'empty' ? (
<> <>
<ScrollArea className="w-full"> {showChat === 'full' && currentTask ? (
<div className="h-full w-full">
<ChatPanel taskId={currentTask.id} mode="full" onModeChange={setShowChat} />
</div>
) : (
<>
<ScrollArea className="min-w-0 flex-1">
<div className="px-2">
<VideoBanner
audioMeta={currentTask?.audioMeta}
videoUrl={currentTask?.formData?.video_url}
/>
</div>
<div className={'markdown-body w-full px-2'}> <div className={'markdown-body w-full px-2'}>
<ReactMarkdown <ReactMarkdown
remarkPlugins={remarkPlugins} remarkPlugins={remarkPlugins}
rehypePlugins={rehypePlugins} rehypePlugins={rehypePlugins}
components={markdownComponents} components={markdownComponents}
> >
{selectedContent} {selectedContent.replace(/^>\s*来源链接:[^\n]*\n*/m, '')}
</ReactMarkdown> </ReactMarkdown>
</div> </div>
</ScrollArea> </ScrollArea>
@@ -457,6 +474,14 @@ const MarkdownViewer: FC<MarkdownViewerProps> = memo(({ status }) => {
<TranscriptViewer /> <TranscriptViewer />
</div> </div>
)} )}
{/* 侧边问答模式markdown + ChatPanel 各占一半 */}
{showChat === 'half' && currentTask && (
<div className="ml-2 h-full w-1/2 shrink-0">
<ChatPanel taskId={currentTask.id} mode="half" onModeChange={setShowChat} />
</div>
)}
</>
)}
</> </>
) : ( ) : (
<div className="flex h-full w-full items-center justify-center"> <div className="flex h-full w-full items-center justify-center">

View File

@@ -0,0 +1,86 @@
import { ExternalLink } from 'lucide-react'
import type { AudioMeta } from '@/store/taskStore'
/** Props for VideoBanner. */
interface VideoBannerProps {
  /** Video metadata captured at download time; banner renders nothing without it. */
  audioMeta?: AudioMeta
  /** Original page URL; falls back to audioMeta.raw_info.webpage_url when absent. */
  videoUrl?: string
}
/** Platform id → human-readable display label. */
const platformLabel: Record<string, string> = {
  bilibili: '哔哩哔哩',
  youtube: 'YouTube',
  douyin: '抖音',
  xiaohongshu: '小红书',
}
/**
 * Banner showing the source video's cover, title, uploader and platform,
 * with a link back to the original page.
 */
export default function VideoBanner({ audioMeta, videoUrl }: VideoBannerProps) {
  if (!audioMeta) return null
  const rawCover = audioMeta.cover_url
  // Load the cover through the backend proxy to avoid CORS/Referer blocking.
  // NOTE(review): default base 'api' has no leading slash — resolves relative
  // to the current path; confirm VITE_API_BASE_URL is always set.
  const apiBase = String(import.meta.env.VITE_API_BASE_URL || 'api').replace(/\/$/, '')
  const coverUrl = rawCover
    ? `${apiBase}/image_proxy?url=${encodeURIComponent(rawCover)}`
    : ''
  const title = audioMeta.title
  const uploader = audioMeta.raw_info?.uploader || ''
  const platform = platformLabel[audioMeta.platform] || audioMeta.platform || ''
  const originalUrl = videoUrl || audioMeta.raw_info?.webpage_url || ''
  return (
    <div className="relative mb-4 overflow-hidden rounded-lg">
      {/* Blurred cover backdrop (gradient fallback when no cover) */}
      <div className="absolute inset-0">
        {coverUrl ? (
          <img
            src={coverUrl}
            alt=""
            referrerPolicy="no-referrer"
            className="h-full w-full object-cover blur-md brightness-[0.4] scale-110"
          />
        ) : (
          <div className="h-full w-full bg-gradient-to-r from-blue-600 to-indigo-700" />
        )}
      </div>
      {/* Content layer */}
      <div className="relative flex items-center gap-4 px-5 py-4">
        {/* Cover thumbnail */}
        {coverUrl && (
          <img
            src={coverUrl}
            alt={title}
            referrerPolicy="no-referrer"
            className="h-16 w-28 shrink-0 rounded-md object-cover shadow-md"
          />
        )}
        {/* Title / uploader / platform */}
        <div className="min-w-0 flex-1">
          <h2 className="truncate text-base font-bold text-white" title={title}>
            {title}
          </h2>
          <div className="mt-1 flex flex-wrap items-center gap-2 text-sm text-white/70">
            {uploader && <span>{uploader}</span>}
            {uploader && platform && <span className="text-white/40">·</span>}
            {platform && <span>{platform}</span>}
          </div>
        </div>
        {/* Link to the original video */}
        {originalUrl && (
          <a
            href={originalUrl}
            target="_blank"
            rel="noopener noreferrer"
            className="flex shrink-0 items-center gap-1.5 rounded-full bg-white/15 px-3 py-1.5 text-xs font-medium text-white backdrop-blur-sm transition-colors hover:bg-white/25"
          >
            <ExternalLink className="h-3.5 w-3.5" />
            <span></span>
          </a>
        )}
      </div>
    </div>
  )
}

View File

@@ -0,0 +1,44 @@
import request from '@/utils/request'
/** One conversation turn as sent to the backend. */
export interface ChatMessage {
  role: 'user' | 'assistant'
  content: string
}
/** A citation chunk returned alongside an answer. */
export interface ChatSource {
  /** Excerpt of the matched chunk (backend truncates to 200 chars). */
  text: string
  source_type: 'markdown' | 'transcript'
  /** Present for markdown sources. */
  section_title?: string
  /** Present for transcript sources (seconds). */
  start_time?: number
  end_time?: number
}
/** Response of POST /chat/ask. */
export interface AskResponse {
  answer: string
  sources: ChatSource[]
}
/** Backend vector-index lifecycle states. */
export type IndexStatus = 'idle' | 'indexing' | 'indexed' | 'failed'
/** Response of GET /chat/status. */
export interface ChatStatusResponse {
  indexed: boolean
  status: IndexStatus
}
/** Kick off (or confirm) backend vector indexing for a task. */
export const indexTask = (taskId: string): Promise<void> =>
  request.post('/chat/index', { task_id: taskId })

/** Ask a question grounded in the task's note; allows up to 60s for the LLM round-trip. */
export const askQuestion = (data: {
  task_id: string
  question: string
  history: ChatMessage[]
  provider_id: string
  model_name: string
}): Promise<AskResponse> => request.post('/chat/ask', data, { timeout: 60000 })

/** Fetch the current index status for a task. */
export const getChatStatus = (taskId: string): Promise<ChatStatusResponse> =>
  request.get(`/chat/status?task_id=${taskId}`)

View File

@@ -0,0 +1,43 @@
import { create } from 'zustand'
import { persist } from 'zustand/middleware'
import type { ChatSource } from '@/services/chat'
/** One persisted chat turn; assistant turns may carry citation sources. */
export interface ChatMessage {
  role: 'user' | 'assistant'
  content: string
  sources?: ChatSource[]
}

/** Zustand state: per-task chat transcripts keyed by task id. */
interface ChatState {
  chatHistory: Record<string, ChatMessage[]>
  addMessage: (taskId: string, msg: ChatMessage) => void
  clearChat: (taskId: string) => void
  getMessages: (taskId: string) => ChatMessage[]
}

/** Chat history store, persisted to localStorage under 'bilinote-chat-storage'. */
export const useChatStore = create<ChatState>()(
  persist(
    (set, get) => ({
      chatHistory: {},
      // Append one message to a task's transcript, creating it on first use.
      addMessage: (taskId, msg) =>
        set(state => {
          const existing = state.chatHistory[taskId] ?? []
          return {
            chatHistory: { ...state.chatHistory, [taskId]: [...existing, msg] },
          }
        }),
      // Drop the task's transcript entirely.
      clearChat: taskId =>
        set(state => {
          const next = { ...state.chatHistory }
          delete next[taskId]
          return { chatHistory: next }
        }),
      getMessages: taskId => get().chatHistory[taskId] ?? [],
    }),
    {
      name: 'bilinote-chat-storage',
    },
  ),
)

View File

@@ -1,6 +1,6 @@
from fastapi import FastAPI from fastapi import FastAPI
from .routers import note, provider, model, config from .routers import note, provider, model, config, chat
@@ -10,5 +10,6 @@ def create_app(lifespan) -> FastAPI:
app.include_router(provider.router, prefix="/api") app.include_router(provider.router, prefix="/api")
app.include_router(model.router,prefix="/api") app.include_router(model.router,prefix="/api")
app.include_router(config.router, prefix="/api") app.include_router(config.router, prefix="/api")
app.include_router(chat.router, prefix="/api")
return app return app

101
backend/app/routers/chat.py Normal file
View File

@@ -0,0 +1,101 @@
from fastapi import APIRouter, BackgroundTasks
from pydantic import BaseModel
from app.services.chat_service import chat as chat_service
from app.services.vector_store import VectorStoreManager
from app.utils.logger import get_logger
from app.utils.response import ResponseWrapper as R
logger = get_logger(__name__)
router = APIRouter()
# 索引状态追踪: task_id -> "indexing" | "indexed" | "failed"
_index_status: dict[str, str] = {}
class IndexRequest(BaseModel):
    # Task whose note should be (re)indexed into the vector store.
    task_id: str
class ChatMessage(BaseModel):
    # One prior conversation turn; role is "user" or "assistant".
    role: str
    content: str
class AskRequest(BaseModel):
    # Task to answer against, the new question, prior turns, and the
    # provider/model pair used to run the LLM call.
    task_id: str
    question: str
    history: list[ChatMessage] = []
    provider_id: str
    model_name: str
def _do_index(task_id: str):
    """Background job: build the vector index for one task and record the outcome
    in the module-level ``_index_status`` map."""
    _index_status[task_id] = "indexing"
    try:
        VectorStoreManager().index_task(task_id)
    except Exception as exc:
        _index_status[task_id] = "failed"
        logger.error(f"索引失败: {task_id}, {exc}")
    else:
        _index_status[task_id] = "indexed"
        logger.info(f"索引完成: {task_id}")
@router.post("/chat/index")
def index_task(data: IndexRequest, background_tasks: BackgroundTasks):
    """Trigger background indexing for a task; returns immediately.

    No-ops when an index run is already in flight or the task is already
    persisted in the vector store.
    """
    task_id = data.task_id
    if _index_status.get(task_id) == "indexing":
        return R.success(msg="正在索引中")
    # Already persisted from an earlier run — just refresh the cache.
    if VectorStoreManager().is_indexed(task_id):
        _index_status[task_id] = "indexed"
        return R.success(msg="已完成索引")
    _index_status[task_id] = "indexing"
    background_tasks.add_task(_do_index, task_id)
    return R.success(msg="开始索引")
@router.get("/chat/status")
def chat_status(task_id: str):
    """Report a task's index status: idle / indexing / indexed / failed.

    Checks the in-memory status map first, then falls back to the persisted
    vector store; any lookup error degrades to "idle".
    """
    try:
        cached = _index_status.get(task_id)
        if cached:
            return R.success(data={"status": cached, "indexed": cached == "indexed"})
        # Nothing in memory — consult the persisted store and cache the result.
        if VectorStoreManager().is_indexed(task_id):
            _index_status[task_id] = "indexed"
            return R.success(data={"status": "indexed", "indexed": True})
        return R.success(data={"status": "idle", "indexed": False})
    except Exception as e:
        logger.error(f"查询索引状态失败: {e}")
        return R.success(data={"status": "idle", "indexed": False})
@router.post("/chat/ask")
def ask_question(data: AskRequest):
    """RAG question answering grounded in the task's note content.

    ValueError from the service (e.g. unknown provider) is surfaced as-is;
    anything else is logged and wrapped in a generic failure message.
    """
    try:
        result = chat_service(
            task_id=data.task_id,
            question=data.question,
            history=[{"role": m.role, "content": m.content} for m in data.history],
            provider_id=data.provider_id,
            model_name=data.model_name,
        )
        return R.success(data=result)
    except ValueError as e:
        return R.error(msg=str(e))
    except Exception as e:
        logger.error(f"Chat 问答失败: {e}", exc_info=True)
        return R.error(msg=f"问答失败: {str(e)}")

View File

@@ -109,6 +109,12 @@ def run_note_task(task_id: str, video_url: str, platform: str, quality: Download
return return
save_note_to_file(task_id, note) save_note_to_file(task_id, note)
# 自动建立向量索引(用于 AI 问答),失败不影响笔记生成
try:
from app.services.vector_store import VectorStoreManager
VectorStoreManager().index_task(task_id)
except Exception as e:
logger.warning(f"向量索引失败(不影响笔记): {e}")
@router.post('/delete_task') @router.post('/delete_task')

View File

@@ -0,0 +1,158 @@
import json
from typing import Optional
from app.gpt.gpt_factory import GPTFactory
from app.models.model_config import ModelConfig
from app.services.provider import ProviderService
from app.services.vector_store import VectorStoreManager
from app.services.chat_tools import TOOLS, execute_tool
from app.utils.logger import get_logger
logger = get_logger(__name__)
# System prompt template for the RAG chat loop; {context} is filled with the
# initially-retrieved chunks. This is a runtime string sent to the LLM — kept
# in Chinese intentionally, do not translate.
SYSTEM_PROMPT = """你是一个视频笔记问答助手。你拥有以下能力:
1. 系统已自动检索了一些相关内容作为初始参考(见下方)
2. 你可以调用工具主动查询更多信息:
- lookup_transcript: 查询视频原始转录文本(支持按时间、关键词、位置筛选)
- get_video_info: 获取视频元信息(标题、作者、简介、标签等)
- get_note_content: 获取完整笔记内容
--- 初始检索内容 ---
{context}
---
回答要求:
- 如果初始检索内容不足以回答问题,请主动调用工具获取更多信息
- 回答关于视频具体原话、细节时,用 lookup_transcript 查询原文
- 回答关于作者、标题等基本信息时,用 get_video_info 查询
- 请用中文回答,保持简洁准确"""
def _build_context(chunks: list[dict]) -> str:
"""将检索到的片段拼接为上下文文本。"""
parts = []
for chunk in chunks:
meta = chunk.get("metadata", {})
source_type = meta.get("source_type", "unknown")
if source_type == "meta":
label = "[视频信息]"
elif source_type == "markdown":
label = f"[笔记 - {meta.get('section_title', '')}]"
else:
start = meta.get("start_time", 0)
end = meta.get("end_time", 0)
label = f"[转录 - {start:.0f}s~{end:.0f}s]"
parts.append(f"{label}\n{chunk['text']}")
return "\n\n".join(parts)
def _build_sources(chunks: list[dict]) -> list[dict]:
"""从检索片段中提取来源信息。"""
sources = []
for chunk in chunks:
meta = chunk.get("metadata", {})
source = {
"text": chunk["text"][:200],
"source_type": meta.get("source_type", "unknown"),
}
if meta.get("section_title"):
source["section_title"] = meta["section_title"]
if meta.get("start_time") is not None:
source["start_time"] = meta["start_time"]
if meta.get("end_time") is not None:
source["end_time"] = meta["end_time"]
sources.append(source)
return sources
def chat(
    task_id: str,
    question: str,
    history: list[dict],
    provider_id: str,
    model_name: str,
) -> dict:
    """
    RAG + tool-calling Q&A over a task's note/transcript.

    1. Retrieve initial context from the vector store.
    2. Call the LLM with the tool schema attached.
    3. If the LLM requests tools, execute them and feed the results back.
    4. Loop until the LLM answers (at most 3 tool rounds, then one final
       forced call without tools).

    Returns {"answer": str, "sources": list[dict]}. Note: sources reflect
    only the initial retrieval, never tool lookups made during the loop.
    Raises ValueError when provider_id does not resolve to a provider.
    """
    vector_store = VectorStoreManager()
    # 1. Initial retrieval for grounding context.
    chunks = vector_store.query(task_id, question, n_results=6)
    context = _build_context(chunks) if chunks else "(未检索到相关内容,请使用工具查询)"
    sources = _build_sources(chunks) if chunks else []
    # 2. Build the message list: system prompt (with context), capped history, question.
    system_msg = SYSTEM_PROMPT.format(context=context)
    messages = [{"role": "system", "content": system_msg}]
    # Only the last 20 turns are forwarded to bound prompt size.
    for msg in history[-20:]:
        messages.append({"role": msg["role"], "content": msg["content"]})
    messages.append({"role": "user", "content": question})
    # 3. Resolve the LLM client from the configured provider.
    provider = ProviderService.get_provider_by_id(provider_id)
    if not provider:
        raise ValueError(f"未找到模型供应商: {provider_id}")
    config = ModelConfig(
        api_key=provider["api_key"],
        base_url=provider["base_url"],
        model_name=model_name,
        provider=provider["type"],
        name=provider["name"],
    )
    gpt = GPTFactory.from_config(config)
    logger.info(f"Chat: task_id={task_id}, model={model_name}")
    # 4. Tool-calling loop (at most 3 rounds of tool use).
    max_rounds = 3
    for round_i in range(max_rounds):
        response = gpt.client.chat.completions.create(
            model=gpt.model,
            messages=messages,
            tools=TOOLS,
            temperature=0.7,
        )
        msg = response.choices[0].message
        # No tool calls requested — this is the final answer.
        if not msg.tool_calls:
            return {"answer": msg.content or "", "sources": sources}
        # Append the assistant turn (the SDK message object itself — the
        # OpenAI client accepts it in place of a dict), then one "tool"
        # message per requested call.
        messages.append(msg)
        for tool_call in msg.tool_calls:
            fn_name = tool_call.function.name
            try:
                fn_args = json.loads(tool_call.function.arguments)
            except json.JSONDecodeError:
                # Malformed arguments from the model — run the tool with defaults.
                fn_args = {}
            logger.info(f"Tool call [{round_i+1}/{max_rounds}]: {fn_name}({fn_args})")
            result = execute_tool(task_id, fn_name, fn_args)
            messages.append({
                "role": "tool",
                "tool_call_id": tool_call.id,
                "content": result,
            })
    # Round budget exhausted — one last call WITHOUT tools to force an answer.
    response = gpt.client.chat.completions.create(
        model=gpt.model,
        messages=messages,
        temperature=0.7,
    )
    return {"answer": response.choices[0].message.content or "", "sources": sources}

View File

@@ -0,0 +1,184 @@
"""
Chat function calling 工具定义与执行。
提供给 LLM 调用,用于主动查询视频原文、笔记、元信息。
"""
import json
import os
from typing import Optional
from app.utils.logger import get_logger
logger = get_logger(__name__)
NOTE_OUTPUT_DIR = os.getenv("NOTE_OUTPUT_DIR", "note_results")
def _load_note_data(task_id: str) -> Optional[dict]:
path = os.path.join(NOTE_OUTPUT_DIR, f"{task_id}.json")
if not os.path.exists(path):
return None
with open(path, "r", encoding="utf-8") as f:
return json.load(f)
# ── Tool definitions (OpenAI function-calling schema) ─────────────────
# Exposed to the LLM so it can proactively pull transcript text, video
# metadata, or the full note while answering. The descriptions below are
# runtime strings read by the model — keep them in Chinese.
TOOLS = [
    # Query raw transcript segments (by time range, keyword, or position).
    {
        "type": "function",
        "function": {
            "name": "lookup_transcript",
            "description": "查询视频原始转录文本。可按时间范围筛选、按关键词搜索、或获取指定位置的内容。",
            "parameters": {
                "type": "object",
                "properties": {
                    "start_time": {
                        "type": "number",
                        "description": "起始时间(秒),例如 0 表示视频开头60 表示第1分钟",
                    },
                    "end_time": {
                        "type": "number",
                        "description": "结束时间(秒),不传则到末尾",
                    },
                    "keyword": {
                        "type": "string",
                        "description": "搜索关键词,返回包含该关键词的转录片段",
                    },
                    "position": {
                        "type": "string",
                        "enum": ["start", "end"],
                        "description": "快捷位置start=视频开头前30句end=视频结尾后30句",
                    },
                },
                "required": [],
            },
        },
    },
    # Fetch full video metadata (title, author, description, tags, stats...).
    {
        "type": "function",
        "function": {
            "name": "get_video_info",
            "description": "获取视频的完整元信息,包括标题、作者、简介、标签、时长、播放量等。",
            "parameters": {
                "type": "object",
                "properties": {},
                "required": [],
            },
        },
    },
    # Fetch the full AI-generated markdown note.
    {
        "type": "function",
        "function": {
            "name": "get_note_content",
            "description": "获取 AI 生成的完整笔记内容Markdown 格式)。",
            "parameters": {
                "type": "object",
                "properties": {},
                "required": [],
            },
        },
    },
]
# ── 工具执行 ──────────────────────────────────────────────────
def execute_tool(task_id: str, tool_name: str, arguments: dict) -> str:
    """Dispatch one LLM tool call and return its JSON-encoded result string.

    Unknown tools and missing note data are reported as {"error": ...}
    payloads rather than exceptions, so the LLM sees them as tool output.
    """
    data = _load_note_data(task_id)
    if not data:
        return json.dumps({"error": "笔记数据不存在"}, ensure_ascii=False)
    handlers = {
        "lookup_transcript": lambda: _lookup_transcript(data, arguments),
        "get_video_info": lambda: _get_video_info(data),
        "get_note_content": lambda: _get_note_content(data),
    }
    handler = handlers.get(tool_name)
    if handler is None:
        return json.dumps({"error": f"未知工具: {tool_name}"}, ensure_ascii=False)
    return handler()
def _lookup_transcript(data: dict, args: dict) -> str:
segments = data.get("transcript", {}).get("segments", [])
if not segments:
return json.dumps({"error": "没有转录数据"}, ensure_ascii=False)
position = args.get("position")
start_time = args.get("start_time")
end_time = args.get("end_time")
keyword = args.get("keyword", "").strip()
# 快捷位置
if position == "start":
filtered = segments[:30]
elif position == "end":
filtered = segments[-30:]
else:
filtered = segments
# 时间筛选
if start_time is not None:
filtered = [s for s in filtered if s.get("end", 0) >= start_time]
if end_time is not None:
filtered = [s for s in filtered if s.get("start", 0) <= end_time]
# 关键词筛选
if keyword:
filtered = [s for s in filtered if keyword.lower() in s.get("text", "").lower()]
# 限制返回量,避免 token 爆炸
if len(filtered) > 50:
filtered = filtered[:50]
truncated = True
else:
truncated = False
result = {
"total_segments": len(data.get("transcript", {}).get("segments", [])),
"returned": len(filtered),
"truncated": truncated,
"segments": [
{
"start": round(s.get("start", 0), 1),
"end": round(s.get("end", 0), 1),
"text": s.get("text", ""),
}
for s in filtered
],
}
return json.dumps(result, ensure_ascii=False)
def _get_video_info(data: dict) -> str:
am = data.get("audio_meta", {})
raw = am.get("raw_info", {}) or {}
info = {
"title": am.get("title") or raw.get("title", ""),
"uploader": raw.get("uploader", ""),
"description": raw.get("description", "")[:1000],
"tags": raw.get("tags", [])[:20] if isinstance(raw.get("tags"), list) else [],
"duration_seconds": am.get("duration", 0),
"platform": am.get("platform", ""),
"video_id": am.get("video_id", ""),
"url": raw.get("webpage_url", ""),
"view_count": raw.get("view_count"),
"like_count": raw.get("like_count"),
"comment_count": raw.get("comment_count"),
}
# 去除 None 值
info = {k: v for k, v in info.items() if v is not None and v != ""}
return json.dumps(info, ensure_ascii=False)
def _get_note_content(data: dict) -> str:
md = data.get("markdown", "")
if isinstance(md, list):
# 多版本,取最新
md = md[-1].get("content", "") if md else ""
# 限制长度
if len(md) > 5000:
md = md[:5000] + "\n\n... (内容过长已截断)"
return json.dumps({"markdown": md}, ensure_ascii=False)

View File

@@ -0,0 +1,226 @@
import json
import os
import re
from typing import Optional
import chromadb
from chromadb.config import Settings
from app.utils.logger import get_logger
logger = get_logger(__name__)
NOTE_OUTPUT_DIR = os.getenv("NOTE_OUTPUT_DIR", "note_results")
VECTOR_DB_DIR = os.getenv("VECTOR_DB_DIR", "vector_db")
def _chunk_markdown(markdown: str) -> list[dict]:
"""按 H2/H3 标题拆分 markdown 为语义块。"""
sections = re.split(r'(?=^#{2,3}\s)', markdown, flags=re.MULTILINE)
chunks = []
for section in sections:
section = section.strip()
if not section or len(section) < 30:
continue
heading_match = re.match(r'^(#{2,3})\s+(.+)', section)
title = heading_match.group(2).strip() if heading_match else "intro"
chunks.append({
"text": section,
"metadata": {"source_type": "markdown", "section_title": title},
})
return chunks
def _chunk_transcript(segments: list[dict], window_size: int = 15, overlap: int = 3) -> list[dict]:
"""将转录 segments 按滑动窗口分组。"""
if not segments:
return []
chunks = []
step = max(window_size - overlap, 1)
for i in range(0, len(segments), step):
window = segments[i:i + window_size]
if not window:
break
text = "\n".join(
f"[{seg.get('start', 0):.0f}s] {seg.get('text', '')}" for seg in window
)
chunks.append({
"text": text,
"metadata": {
"source_type": "transcript",
"start_time": window[0].get("start", 0),
"end_time": window[-1].get("end", 0),
},
})
return chunks
def _build_meta_chunk(audio_meta: dict) -> list[dict]:
    """Build a single retrievable chunk out of the video's metadata
    (title, uploader, description, tags, duration, platform, URL).

    Returns a one-element list (or [] when no usable metadata), so it can be
    concatenated with the other chunk lists.
    """
    if not audio_meta:
        return []
    raw = audio_meta.get("raw_info", {}) or {}
    parts = []
    title = audio_meta.get("title") or raw.get("title", "")
    if title:
        parts.append(f"视频标题:{title}")
    uploader = raw.get("uploader", "")
    if uploader:
        parts.append(f"视频作者/UP主{uploader}")
    desc = raw.get("description", "")
    if desc:
        # Cap the description so the chunk stays embedding-friendly.
        parts.append(f"视频简介:{desc[:500]}")
    tags = raw.get("tags", [])
    if tags and isinstance(tags, list):
        parts.append(f"标签:{', '.join(str(t) for t in tags[:20])}")
    duration = audio_meta.get("duration", 0)
    if duration:
        m, s = divmod(int(duration), 60)
        parts.append(f"视频时长:{m}分{s}秒")
    platform = audio_meta.get("platform", "")
    if platform:
        parts.append(f"平台:{platform}")
    url = raw.get("webpage_url", "")
    if url:
        parts.append(f"链接:{url}")
    if not parts:
        return []
    return [{
        "text": "\n".join(parts),
        "metadata": {"source_type": "meta"},
    }]
class VectorStoreManager:
    """ChromaDB-backed vector store for note/transcript retrieval.

    Each task gets its own collection (named by task_id) holding three chunk
    kinds: video metadata ("meta"), markdown note sections ("markdown"), and
    transcript windows ("transcript").
    """

    def __init__(self):
        # Persist to a local directory so indexes survive restarts.
        os.makedirs(VECTOR_DB_DIR, exist_ok=True)
        self._client = chromadb.PersistentClient(
            path=VECTOR_DB_DIR,
            settings=Settings(anonymized_telemetry=False),
        )

    def _collection_name(self, task_id: str) -> str:
        """Collection name: the task_id itself (UUIDs are valid Chroma names)."""
        return task_id

    def index_task(self, task_id: str) -> None:
        """Read the task's note-result JSON and (re)build its vector index.

        Missing or empty note files are logged and skipped (no exception).
        Re-indexing is idempotent: any existing collection is dropped first.
        """
        result_path = os.path.join(NOTE_OUTPUT_DIR, f"{task_id}.json")
        if not os.path.exists(result_path):
            logger.warning(f"笔记文件不存在,跳过索引: {result_path}")
            return
        with open(result_path, "r", encoding="utf-8") as f:
            note_data = json.load(f)
        markdown = note_data.get("markdown", "")
        if isinstance(markdown, list):
            # Fix: multi-version notes are stored as a list of {"content": ...}
            # dicts (see chat_tools._get_note_content); index the newest
            # version instead of crashing in re.split on a list.
            markdown = markdown[-1].get("content", "") if markdown else ""
        transcript = note_data.get("transcript", {})
        segments = transcript.get("segments", [])
        audio_meta = note_data.get("audio_meta", {})
        all_chunks = (
            _build_meta_chunk(audio_meta)
            + _chunk_markdown(markdown)
            + _chunk_transcript(segments)
        )
        if not all_chunks:
            logger.warning(f"笔记内容为空,跳过索引: {task_id}")
            return
        col_name = self._collection_name(task_id)
        # Drop any stale collection first so re-indexing is idempotent.
        try:
            self._client.delete_collection(col_name)
        except Exception:
            pass
        collection = self._client.create_collection(
            name=col_name,
            metadata={"hnsw:space": "cosine"},
        )
        collection.add(
            documents=[c["text"] for c in all_chunks],
            metadatas=[c["metadata"] for c in all_chunks],
            ids=[f"{task_id}_{i}" for i in range(len(all_chunks))],
        )
        logger.info(f"向量索引完成: task_id={task_id}, chunks={len(all_chunks)}")

    def _parse_results(self, results: dict) -> list[dict]:
        """Flatten a ChromaDB query result (single query text) into chunk dicts."""
        documents = (results or {}).get("documents")
        if not documents or not documents[0]:
            return []
        # Robustness: tolerate result dicts without metadatas/distances keys.
        metadatas = results.get("metadatas")
        distances = results.get("distances")
        chunks = []
        for i, text in enumerate(documents[0]):
            chunks.append({
                "text": text,
                "metadata": metadatas[0][i] if metadatas else {},
                "distance": distances[0][i] if distances else None,
            })
        return chunks

    def query(self, task_id: str, query_text: str, n_results: int = 6) -> list[dict]:
        """Retrieve chunks with a per-source quota so every source type
        (meta/markdown/transcript) is represented in the result.

        Fix: n_results was previously ignored. The base 1:2:3
        meta/markdown/transcript ratio is now scaled to n_results, with each
        source keeping at least one slot; the default n_results=6 yields
        exactly 1/2/3, preserving the original behavior.
        """
        col_name = self._collection_name(task_id)
        try:
            collection = self._client.get_collection(col_name)
        except Exception:
            logger.warning(f"Collection 不存在: {col_name}")
            return []
        scale = n_results / 6
        quotas = {
            "meta": max(1, round(1 * scale)),
            "markdown": max(1, round(2 * scale)),
            "transcript": max(1, round(3 * scale)),
        }
        all_chunks: list[dict] = []
        for source_type, quota in quotas.items():
            try:
                results = collection.query(
                    query_texts=[query_text],
                    n_results=quota,
                    where={"source_type": source_type},
                )
                all_chunks.extend(self._parse_results(results))
            except Exception:
                # A source type may have no chunks in this collection; skip it.
                pass
        return all_chunks

    def delete_index(self, task_id: str) -> None:
        """Remove a task's collection; silently no-op when it does not exist."""
        try:
            self._client.delete_collection(self._collection_name(task_id))
            logger.info(f"已删除向量索引: {task_id}")
        except Exception:
            pass

    def is_indexed(self, task_id: str) -> bool:
        """True when the task has a non-empty index that includes a meta chunk.

        Older indexes lacking the meta chunk report False so callers rebuild them.
        """
        try:
            col = self._client.get_collection(self._collection_name(task_id))
            if col.count() == 0:
                return False
            meta = col.get(where={"source_type": "meta"}, limit=1)
            return len(meta["ids"]) > 0
        except Exception:
            return False

View File

@@ -16,6 +16,7 @@ celery==5.5.1
certifi==2025.1.31 certifi==2025.1.31
cffi==1.17.1 cffi==1.17.1
charset-normalizer==3.4.1 charset-normalizer==3.4.1
chromadb>=0.5.0
click==8.1.8 click==8.1.8
click-didyoumean==0.3.1 click-didyoumean==0.3.1
click-plugins==1.1.1 click-plugins==1.1.1