feat: 新增模型管理和供应商配置功能

### v1.1.0
- #### Added
  - 新增 AI 笔记风格选择
  - 新增 AI 笔记返回格式选择
  - 添加 AI 自定义笔记备注 Prompt
  - 添加任务失败重试
  - 添加全局设置页,可在设置页进行模型设置

- #### Optimize
  - 优化前端样式,优化用户体验
  - 增加生成中间产物,可用于失败后加快生成速度
- #### Fix
  - 修复视频截图视频过早删除错误
This commit is contained in:
思诺特
2025-04-26 23:40:17 +08:00
parent 1323cfd1ec
commit 171dea5e0d
51 changed files with 2511 additions and 414 deletions

View File

@@ -1,3 +1,11 @@
###
# @Author: 思诺特 jefferyhcool@gmail.com
# @Date: 2025-04-14 08:49:59
# @LastEditors: 思诺特 jefferyhcool@gmail.com
# @LastEditTime: 2025-04-26 19:56:50
# @FilePath: \BiliNote\.env.example
# @Description: 这是默认设置,请设置`customMade`, 打开koroFileHeader查看配置 进行设置: https://github.com/OBKoro1/koro1FileHeader/wiki/%E9%85%8D%E7%BD%AE
###
# 通用端口配置
BACKEND_PORT=8001
FRONTEND_PORT=3015
@@ -13,17 +21,6 @@ STATIC=/static
OUT_DIR=./static/screenshots
IMAGE_BASE_URL=/static/screenshots
DATA_DIR=data
# AI 相关配置
OPENAI_API_KEY=
OPENAI_API_BASE_URL=
OPENAI_MODEL=
DEEP_SEEK_API_KEY=
DEEP_SEEK_API_BASE_URL=
DEEP_SEEK_MODEL=
QWEN_API_KEY=
QWEN_API_BASE_URL=
QWEN_MODEL=
MODEl_PROVIDER= #如果不是openai 请修改 deepseek/qwen
# FFMPEG 配置
FFMPEG_BIN_PATH=

View File

@@ -1,13 +1,32 @@
# === 前端构建阶段 ===
FROM node:18-alpine AS build
# 安装 pnpm
RUN npm install -g pnpm
# 设置工作目录
WORKDIR /app
# 拷贝前端源码
COPY ./BillNote_frontend /app
RUN npm install && npm run build
# 安装依赖并构建
RUN pnpm install && pnpm run build
# === nginx 运行阶段 ===
FROM nginx:alpine
# 拷贝模板配置
COPY ./BillNote_frontend/deploy/default.conf.template /etc/nginx/templates/default.conf.template
# 拷贝构建产物
COPY --from=build /app/dist /usr/share/nginx/html
# 拷贝启动脚本
COPY ./BillNote_frontend/deploy/start.sh /start.sh
RUN chmod +x /start.sh
EXPOSE 80
# 使用启动脚本启动容器
CMD ["/start.sh"]

View File

@@ -0,0 +1,18 @@
# Frontend server: serves the built single-page app and forwards /api/ to the backend.
# NOTE(review): this is a template; ${BACKEND_PORT} is a placeholder that must be
# substituted (e.g. with envsubst) before nginx loads the file.
server {
listen 80;
# NOTE(review): 127.0.0.11 is presumably Docker's embedded DNS resolver — confirm.
resolver 127.0.0.11 valid=10s;
# Static assets; unknown paths fall back to index.html so client-side routes work.
location / {
root /usr/share/nginx/html;
index index.html;
try_files $uri $uri/ /index.html;
}
# API requests are proxied to the host named "backend", with the original
# host, client address and scheme passed along in headers.
location /api/ {
proxy_pass http://backend:${BACKEND_PORT};
proxy_set_header Host $host;
proxy_set_header X-Real-IP $remote_addr;
proxy_set_header X-Forwarded-For $proxy_add_x_forwarded_for;
proxy_set_header X-Forwarded-Proto $scheme;
}
}

View File

@@ -0,0 +1,20 @@
#!/bin/sh
###
# @Author: Jefferyhcool 1063474837@qq.com
# @Date: 2025-04-16 01:57:05
# @LastEditors: Jefferyhcool 1063474837@qq.com
# @LastEditTime: 2025-04-16 01:59:37
# @FilePath: /hotfix-dev/BillNote_frontend/deploy/start.sh
# @Description: Default header text; set `customMade` and see the koroFileHeader configuration guide: https://github.com/OBKoro1/koro1FileHeader/wiki/%E9%85%8D%E7%BD%AE
###
# Container entry script: wait for the backend, render the nginx config, start nginx.

# Block until the backend's /health endpoint answers; retry every 2 seconds.
# Note: without `-f`, curl exits 0 on any HTTP response, so this only proves
# that the backend accepts connections, not that it returned a 2xx status.
until curl -s "http://backend:${BACKEND_PORT}/health" > /dev/null; do
echo "等待后端服务就绪..."
sleep 2
done
# Render the nginx config from the template. Only the listed variables are
# substituted, so nginx's own $variables in the template are left intact.
envsubst '${BACKEND_HOST} ${BACKEND_PORT}' < /etc/nginx/templates/default.conf.template > /etc/nginx/conf.d/default.conf
# Start nginx in the foreground; `exec` replaces this shell with the nginx process.
exec nginx -g 'daemon off;'

View File

@@ -25,6 +25,7 @@
"@radix-ui/react-tooltip": "^1.1.8",
"@tailwindcss/vite": "^4.1.3",
"@uiw/react-markdown-preview": "^5.1.3",
"antd": "^5.24.8",
"axios": "^1.8.4",
"class-variance-authority": "^0.7.1",
"clsx": "^2.1.1",

View File

@@ -7,10 +7,19 @@ import { Route } from 'react-router-dom'
import Index from '@/pages/Index.tsx'
import NotFoundPage from '@/pages/NotFoundPage' //
import Model from '@/pages/SettingPage/Model.tsx'
import Transcriber from '@/pages/SettingPage/transcriber.tsx'
import ProviderForm from '@/components/Form/modelForm/Form.tsx'
import StepBar from '@/pages/HomePage/components/StepBar.tsx'
import Downloading from '@/components/Lottie/download.tsx'
function App() {
useTaskPolling(3000) // 每 3 秒轮询一次
const steps = [
{ label: '解析链接', key: 'PARSING', icon: <Downloading /> },
{ label: '下载音频', key: 'DOWNLOADING' },
{ label: '转写文字', key: 'TRANSCRIBING' },
{ label: '总结内容', key: 'SUMMARIZING' },
{ label: '保存完成', key: 'SUCCESS' },
]
return (
<>
<BrowserRouter>
@@ -20,9 +29,11 @@ function App() {
<Route path="settings" element={<SettingPage />}>
<Route index element={<Navigate to="model" replace />} />
<Route path="model" element={<Model />}>
<Route index element={<Navigate to="openai" replace />} />
<Route path="new" element={<ProviderForm isCreate />} />
{/*<Route index element={<Navigate to="openai" replace />} />*/}
<Route path=":id" element={<ProviderForm />} />
</Route>
{/* Transcriber settings page. The prop must be spelled `element`; a misspelled prop is ignored and the route renders nothing. */}
<Route path="transcriber" element={<Transcriber />}></Route>
<Route path="*" element={<NotFoundPage />} />
</Route>
<Route path="*" element={<NotFoundPage />} />

Binary file not shown.

After

Width:  |  Height:  |  Size: 1.7 MiB

View File

@@ -1,153 +1,281 @@
import { useForm } from 'react-hook-form';
import { z } from 'zod';
import { zodResolver } from '@hookform/resolvers/zod';
import { useForm } from 'react-hook-form'
import { z } from 'zod'
import { zodResolver } from '@hookform/resolvers/zod'
import {
Form,
FormField,
FormItem,
FormLabel,
FormControl,
FormDescription,
FormMessage,
} from '@/components/ui/form';
import { Input } from '@/components/ui/input';
import { Button } from '@/components/ui/button';
import { useParams } from 'react-router-dom';
import { useProviderStore } from '@/store/providerStore';
import {useEffect, useState} from 'react';
FormDescription,
} from '@/components/ui/form'
import { Input } from '@/components/ui/input'
import { Button } from '@/components/ui/button'
import { useParams, useNavigate } from 'react-router-dom'
import { useProviderStore } from '@/store/providerStore'
import { useEffect, useState } from 'react'
import toast from 'react-hot-toast'
import { testConnection, fetchModels } from '@/services/model.ts'
import {
Select,
SelectContent,
SelectItem,
SelectTrigger,
SelectValue,
} from '@/components/ui/select.tsx' // ⚡新增 fetchModels
import { ModelSelector } from '@/components/Form/modelForm/ModelSelector.tsx'
import { Alert, AlertDescription, AlertTitle } from '@/components/ui/alert.tsx'
// ✅ 表单校验 schema
// ✅ Provider表单schema
const ProviderSchema = z.object({
name: z.string().min(2, '名称不能少于 2 个字符'),
apiKey: z.string().optional(),
baseUrl: z.string().url('必须是合法 URL'),
type: z.string(), // 只展示,不可改
});
type: z.string(),
})
type ProviderFormValues = z.infer<typeof ProviderSchema>;
type ProviderFormValues = z.infer<typeof ProviderSchema>
const ProviderForm = () => {
const rawId= useParams();
console.log('rawId',rawId)
// @ts-ignore
const [providerName, idPart] = rawId.id.split('&');
const [id,setId ]= useState(Number(idPart?.split('=')[1])) // => "1"
const getProviderById = useProviderStore((state) => state.getProviderById);
const provider = getProviderById(id);
// ✅ Model表单schema
const ModelSchema = z.object({
modelName: z.string().min(1, '请选择或填写模型名称'),
})
const form = useForm<ProviderFormValues>({
type ModelFormValues = z.infer<typeof ModelSchema>
interface IModel {
id: string
created: number
object: string
owned_by: string
permission: string
root: string
}
const ProviderForm = ({ isCreate = false }: { isCreate?: boolean }) => {
const { id } = useParams()
const navigate = useNavigate()
const isEditMode = !isCreate
const getProviderById = useProviderStore(state => state.getProviderById)
const loadProviderById = useProviderStore(state => state.loadProviderById)
const updateProvider = useProviderStore(state => state.updateProvider)
const addNewProvider = useProviderStore(state => state.addNewProvider)
const [loading, setLoading] = useState(true)
const [testing, setTesting] = useState(false)
const [isBuiltIn, setIsBuiltIn] = useState(false)
const [modelOptions, setModelOptions] = useState<IModel[]>([]) // ⚡新增,保存模型列表
const [modelLoading, setModelLoading] = useState(false)
const [search, setSearch] = useState('')
const providerForm = useForm<ProviderFormValues>({
resolver: zodResolver(ProviderSchema),
defaultValues: {
name: '',
apiKey: '',
baseUrl: '',
type: '',
type: 'custom',
},
});
})
const filteredModelOptions = modelOptions.filter(model => {
const keywords = search.trim().toLowerCase().split(/\s+/) // 支持多个关键词
const target = model.id.toLowerCase()
return keywords.every(kw => target.includes(kw))
})
const modelForm = useForm<ModelFormValues>({
resolver: zodResolver(ModelSchema),
defaultValues: {
modelName: '',
},
})
useEffect(() => {
console.log(provider)
// if (provider) {
// form.reset({
// name: provider.name,
// apiKey: provider.apiKey,
// baseUrl: provider.baseUrl,
// type: provider.type,
// });
// }
}, [id,provider, form]);
const load = async () => {
if (isEditMode) {
const data = await loadProviderById(id!)
providerForm.reset(data)
setIsBuiltIn(data.type === 'built-in')
} else {
providerForm.reset({
name: '',
apiKey: '',
baseUrl: '',
type: 'custom',
})
setIsBuiltIn(false)
}
setLoading(false)
}
load()
}, [id])
const isBuiltIn = provider?.type === 'built-in';
/**
 * Connectivity check: sends the API key and base URL currently in the
 * provider form to `testConnection` and reports the outcome with a toast.
 * `testing` is true for the duration of the request.
 */
const handleTest = async () => {
  const { apiKey, baseUrl } = providerForm.getValues()

  // Both credentials are required before a request is attempted.
  if (!apiKey || !baseUrl) {
    toast.error('请填写 API Key 和 Base URL')
    return
  }

  try {
    setTesting(true)
    const response = await testConnection({ api_key: apiKey, base_url: baseUrl })
    const result = response.data

    // A response code of 0 is treated as success; anything else surfaces `msg`.
    if (result.code !== 0) {
      toast.error(`连接失败: ${result.msg || '未知错误'}`)
    } else {
      toast.success('测试连通性成功 🎉')
    }
  } catch {
    toast.error('测试连通性异常')
  } finally {
    setTesting(false)
  }
}
const onSubmit = (values: ProviderFormValues) => {
console.log('📝 提交表单数据:', values);
// TODO: 提交接口 /update_provider
};
/**
 * Loads the model list for the provider in the route (`id`), bypassing the
 * cache, and stores it in `modelOptions`. Requires the API key and base URL
 * to be filled in first. `modelLoading` is true while the request runs.
 */
const handleModelLoad = async () => {
  const { apiKey, baseUrl } = providerForm.getValues()
  if (!apiKey || !baseUrl) {
    toast.error('请先填写 API Key 和 Base URL')
    return
  }

  try {
    setModelLoading(true) // loading starts
    const res = await fetchModels(id!, { noCache: true })
    const body = res.data

    // `&&` short-circuits, so the nested payload is only read when code === 0.
    const hasModels = body.code === 0 && body.data.models.data.length > 0
    if (!hasModels) {
      toast.error('未获取到模型列表')
      return
    }

    setModelOptions(body.data.models.data)
    console.log('🔧 模型列表:', body.data)
    toast.success('模型列表加载成功 🎉')
  } catch {
    toast.error('加载模型列表失败')
  } finally {
    setModelLoading(false) // loading ends, including on the early return above
  }
}
// if (!provider) return <div className="p-4">加载中...</div>;
/**
 * Persists the provider form: updates the provider identified by the route
 * `id` in edit mode, otherwise creates a new provider.
 *
 * The store action is awaited so that the success toast is shown only after
 * it has completed, and a failure produces an error toast instead of a false
 * "success". (If the store actions are synchronous, `await` is a no-op.)
 */
const onProviderSubmit = async (values: ProviderFormValues) => {
  try {
    if (isEditMode) {
      await updateProvider({ ...values, id: id! })
      toast.success('更新供应商成功')
    } else {
      await addNewProvider({ ...values })
      toast.success('新增供应商成功')
    }
  } catch (error) {
    // Keep the underlying error visible for debugging; the toast stays short.
    console.error('❌ 保存供应商失败:', error)
    toast.error(isEditMode ? '更新供应商失败' : '新增供应商失败')
  }
}
/**
 * Submit handler for the model form. It only logs the chosen model name and
 * shows a toast; nothing is persisted here.
 */
const onModelSubmit = async ({ modelName }: ModelFormValues) => {
  console.log('🔧 选择的模型:', modelName)
  toast.success(`保存模型: ${modelName}`)
}
if (loading) return <div className="p-4">...</div>
return (
<Form {...form}>
<div className="flex flex-col gap-8 p-4">
{/* Provider信息表单 */}
<Form {...providerForm}>
<form
onSubmit={form.handleSubmit(onSubmit)}
className="w-full max-w-xl p-4 flex flex-col gap-4"
onSubmit={providerForm.handleSubmit(onProviderSubmit)}
className="flex max-w-xl flex-col gap-4"
>
<div className="text-lg font-bold"></div>
{/* 名称 */}
<div className="text-lg font-bold">
{isEditMode ? '编辑模型供应商' : '新增模型供应商'}
</div>
{!isBuiltIn && (
<div className="text-sm text-red-500 italic">
OpenAI SDK
</div>
)}
<FormField
control={form.control}
name="name"
render={({ field }) => (
<FormItem className="flex items-center gap-4">
<FormLabel className="w-24 text-right"></FormLabel>
<FormControl>
<Input {...field} disabled={isBuiltIn} className="flex-1" />
</FormControl>
<FormMessage />
</FormItem>
)}
control={providerForm.control}
name="name"
render={({ field }) => (
<FormItem className="flex items-center gap-4">
<FormLabel className="w-24 text-right"></FormLabel>
<FormControl>
<Input {...field} disabled={isBuiltIn} className="flex-1" />
</FormControl>
<FormMessage />
</FormItem>
)}
/>
{/* API Key */}
<FormField
control={form.control}
name="apiKey"
render={({ field }) => (
<FormItem className="flex items-center gap-4">
<FormLabel className="w-24 text-right">API Key</FormLabel>
<FormControl>
<Input placeholder={'sk-xxx'} {...field} className="flex-1" />
</FormControl>
<FormMessage />
</FormItem>
)}
control={providerForm.control}
name="apiKey"
render={({ field }) => (
<FormItem className="flex items-center gap-4">
<FormLabel className="w-24 text-right">API Key</FormLabel>
<FormControl>
<Input {...field} className="flex-1" />
</FormControl>
<FormMessage />
</FormItem>
)}
/>
{/* Base URL */}
<FormField
control={form.control}
name="baseUrl"
render={({ field }) => (
<FormItem className="flex items-center gap-4">
<FormLabel className="w-24 text-right">API </FormLabel>
<FormControl>
<Input {...field} className="flex-1" />
</FormControl>
<FormMessage />
</FormItem>
)}
control={providerForm.control}
name="baseUrl"
render={({ field }) => (
<FormItem className="flex items-center gap-4">
<FormLabel className="w-24 text-right">API</FormLabel>
<FormControl>
<Input {...field} className="flex-1" />
</FormControl>
<Button type="button" onClick={handleTest} variant="ghost" disabled={testing}>
{testing ? '测试中...' : '测试连通性'}
</Button>
<FormMessage />
</FormItem>
)}
/>
{/* 类型 */}
<FormField
control={form.control}
name="type"
render={({ field }) => (
<FormItem className="flex items-center gap-4">
<FormLabel className="w-24 text-right"></FormLabel>
<FormControl>
<Input {...field} disabled className="flex-1" />
</FormControl>
</FormItem>
)}
control={providerForm.control}
name="type"
render={({ field }) => (
<FormItem className="flex items-center gap-4">
<FormLabel className="w-24 text-right"></FormLabel>
<FormControl>
<Input {...field} disabled className="flex-1" />
</FormControl>
<FormMessage />
</FormItem>
)}
/>
<div className="pt-2">
<Button type="submit" disabled={!form.formState.isDirty}>
<Button type="submit" disabled={!providerForm.formState.isDirty}>
{isEditMode ? '保存修改' : '保存创建'}
</Button>
</div>
</form>
</Form>
);
};
export default ProviderForm;
{/* 模型信息表单 */}
<div className="flex max-w-xl flex-col gap-4">
<div className="flex flex-col gap-2">
<span className="font-bold"></span>
<div className={'flex flex-col gap-2 rounded bg-[#FEF0F0] p-2.5'}>
<h2 className={'font-bold'}>!</h2>
<span>,.</span>
</div>
<ModelSelector providerId={id!} />
{/*<datalist id="model-options">*/}
{/* {modelOptions.map(model => (*/}
{/* <option key={model.id + '1'} value={model.id} />*/}
{/* ))}*/}
{/*</datalist>*/}
</div>
</div>
</div>
)
}
export default ProviderForm

View File

@@ -0,0 +1,4 @@
// iconMap.ts
import * as Icons from '@lobehub/icons'
// Re-exports the entire `@lobehub/icons` namespace under the name `IconMap`.
export const IconMap = Icons;

View File

@@ -0,0 +1,29 @@
import * as Icons from '@lobehub/icons'
import CustomLogo from '@/assets/customAI.png'
/** Props for {@link AILogo}. */
interface AILogoProps {
  /** Icon export name, case-sensitive (e.g. `OpenAI`, `DeepSeek`). */
  name: string
  /** Which variant of the icon to render; defaults to `Color`. */
  style?: 'Color' | 'Text' | 'Outlined' | 'Glyph'
  /** Icon size passed to the icon component; defaults to 24. */
  size?: number
}

/**
 * Renders the logo exported from `@lobehub/icons` under `name`.
 *
 * - Unknown name: logs an error and renders the bundled custom-AI image.
 * - Known name without the requested variant: renders the base icon.
 * - Otherwise: renders the requested variant.
 */
const AILogo = ({ name, style = 'Color', size = 24 }: AILogoProps) => {
  const Icon = Icons[name as keyof typeof Icons]

  if (Icon) {
    // Variants are looked up as static members of the icon component.
    const Variant = Icon[style as keyof typeof Icon]
    return Variant ? <Variant size={size} /> : <Icon size={size} />
  }

  console.error(`❌ 图标组件不存在: ${name}`)
  const fallbackBox = { width: size, height: size }
  return (
    <span style={{ fontSize: size }}>
      <img src={CustomLogo} alt="CustomLogo" style={fallbackBox} />
    </span>
  )
}

export default AILogo

View File

@@ -0,0 +1,92 @@
import { useState, useEffect } from 'react'
import { useModelStore } from '@/store/modelStore'
import { Input } from '@/components/ui/input'
import {
Select,
SelectContent,
SelectItem,
SelectTrigger,
SelectValue,
} from '@/components/ui/select'
import { Button } from '@/components/ui/button'
import toast from 'react-hot-toast'
/** Props for {@link ModelSelector}. */
interface ModelSelectorProps {
  /** Provider id handed to the model store's `loadModels` / `addNewModel`. */
  providerId: string
}

/**
 * Model picker for a single provider.
 *
 * Loads the provider's models from the model store when `providerId` changes,
 * lets the user narrow the list with a whitespace-separated keyword search
 * (every keyword must appear in the model id, case-insensitively), and saves
 * the selected model through the store's `addNewModel`.
 */
export function ModelSelector({ providerId }: ModelSelectorProps) {
  const { models, loading, selectedModel, loadModels, setSelectedModel, addNewModel } =
    useModelStore()
  const [search, setSearch] = useState('')
  const [submitting, setSubmitting] = useState(false)

  // An empty search yields [''], which every id contains, so nothing is filtered.
  const keywords = search.trim().toLowerCase().split(/\s+/)
  const filteredModels = models.filter(({ id }) => {
    const haystack = id.toLowerCase()
    return keywords.every(kw => haystack.includes(kw))
  })

  useEffect(() => {
    if (!providerId) return
    loadModels(providerId)
  }, [providerId])

  /** Saves the currently selected model for this provider. */
  const handleSubmit = async () => {
    if (!selectedModel) {
      toast.error('请选择一个模型')
      return
    }

    try {
      setSubmitting(true)
      await addNewModel(providerId, selectedModel)
      toast.success('保存模型成功 🎉')
    } catch {
      toast.error('保存失败')
    } finally {
      setSubmitting(false)
    }
  }

  const refreshLabel = loading ? '加载中...' : '刷新模型'
  const saveLabel = submitting ? '保存中...' : '保存模型'
  const saveDisabled = submitting || !selectedModel

  return (
    <div className="flex flex-col gap-4">
      <div className="flex items-center gap-2 font-bold">
        <span></span>
        <Button
          variant="ghost"
          type="button"
          onClick={() => loadModels(providerId)}
          disabled={loading}
        >
          {refreshLabel}
        </Button>
      </div>
      <Select value={selectedModel} onValueChange={setSelectedModel}>
        <SelectTrigger className="w-[300px]">
          <SelectValue placeholder="请选择模型" />
        </SelectTrigger>
        <SelectContent>
          <div className="p-2">
            <Input
              placeholder="搜索模型..."
              value={search}
              onChange={e => setSearch(e.target.value)}
              className="h-8"
            />
          </div>
          {filteredModels.map(model => (
            <SelectItem key={model.id} value={model.id}>
              {model.id}
            </SelectItem>
          ))}
        </SelectContent>
      </Select>
      <Button onClick={handleSubmit} disabled={saveDisabled}>
        {saveLabel}
      </Button>
    </div>
  )
}

View File

@@ -1,28 +1,39 @@
import ProviderCard from '@/components/Form/modelForm/components/providerCard.tsx'
import { Button } from '@/components/ui/button.tsx'
import { useProviderStore } from '@/store/providerStore'
import { useNavigate } from 'react-router-dom'
const Provider = () => {
const providers = useProviderStore(state => state.provider)
const providers = useProviderStore(state => state.provider)
const navigate = useNavigate()
const handleClick = () => {
navigate(`/settings/model/new`)
}
return (
<div className="flex flex-col gap-2">
<div className={'search flex gap-1 py-1.5'}>
<Button type={'button'} className="w-full">
<Button
type={'button'}
onClick={() => {
handleClick()
}}
className="w-full"
>
</Button>
</div>
<div className="text-sm font-light"></div>
<div>
{providers &&
providers.map((provider, index) => {
providers.map((provider, index) => {
return (
<ProviderCard
key={provider.id} // stable identity: an index key mis-associates cards when the list is reordered or items are added/removed
providerName={provider.name}
Icon={provider.logo}
id={provider.id}
enable={provider.enabled}
/>
)
})}

View File

@@ -1,22 +1,37 @@
import { Switch } from '@/components/ui/switch'
import { FC } from 'react'
import styles from './index.module.css'
import {useNavigate, useParams} from 'react-router-dom'
import AILogo from "@/components/Icons";
import { useNavigate, useParams } from 'react-router-dom'
import AILogo from '@/components/Form/modelForm/Icons'
import { useProviderStore } from '@/store/providerStore'
export interface IProviderCardProps {
id: string
providerName: string
Icon: string
enable: number
}
const ProviderCard: FC<IProviderCardProps> = ({ providerName, Icon, id }: IProviderCardProps) => {
const ProviderCard: FC<IProviderCardProps> = ({
providerName,
Icon,
id,
enable,
}: IProviderCardProps) => {
const navigate = useNavigate()
const updateProvider = useProviderStore(state => state.updateProvider)
const handleClick = () => {
navigate(`/settings/model/${providerName}&id=${id}`)
navigate(`/settings/model/${id}`)
}
const rawId= useParams();
console.log('rawId',rawId)
/** Flips this provider's enabled flag (1 ↔ 0) through the provider store. */
const handleEnable = () => {
  console.log('enable', enable)
  const nextEnabled = enable == 1 ? 0 : 1
  updateProvider({ id, enabled: nextEnabled })
}
const rawId = useParams()
console.log('rawId', rawId)
// @ts-ignore
const { id: currentId } = useParams();
const { id: currentId } = useParams()
const isActive = currentId === id
return (
<div
@@ -24,18 +39,26 @@ const ProviderCard: FC<IProviderCardProps> = ({ providerName, Icon, id }: IProvi
handleClick()
}}
className={
styles.card + ' flex h-14 items-center justify-between rounded border border-[#f3f3f3] p-2'
+(isActive ? ' bg-[#F0F0F0] font-semibold text-blue-600' : '')
styles.card +
' flex h-14 items-center justify-between rounded border border-[#f3f3f3] p-2' +
(isActive ? ' bg-[#F0F0F0] font-semibold text-blue-600' : '')
}
>
<div className="flex items-center text-lg">
<div className="h-9 w-9 flex items-center">
<AILogo name={Icon} />
<div className="flex h-9 w-9 items-center">
<AILogo name={Icon} />
</div>
<div className="font-semibold">{providerName}</div>
</div>
<div>
<Switch />
<Switch
onClick={e => {
e.preventDefault()
handleEnable()
}}
checked={enable == 1}
/>
</div>
</div>
)

View File

@@ -0,0 +1,40 @@
import { FC, useRef, useEffect } from 'react'
import Lottie, { LottieRefCurrentProps } from 'lottie-react'
import download from '@/assets/Lottie/download.json'
/** Props for {@link Downloading}. */
interface LoadingProps {
  /** Whether the animation is playing; defaults to true. */
  play?: boolean
  /** Colour used for the drop-shadow glow, e.g. "#00BFFF". */
  color?: string
}

/**
 * Looping "download" Lottie animation (150×150) with a coloured glow.
 * Playback follows the `play` prop: the animation is resumed or paused
 * through the Lottie ref whenever `play` changes.
 */
const Downloading: FC<LoadingProps> = ({ play = true, color = '#00BFFF' }) => {
  const lottieRef = useRef<LottieRefCurrentProps>(null)

  useEffect(() => {
    const player = lottieRef.current
    if (!player) return
    if (!play) {
      player.pause()
      return
    }
    player.play()
  }, [play])

  const animationStyle = {
    width: 150,
    height: 150,
    filter: `drop-shadow(0 0 4px ${color}) saturate(2) brightness(1.2)`,
  }

  return (
    <div className="flex items-center justify-center">
      <Lottie
        lottieRef={lottieRef}
        animationData={download}
        loop
        autoplay={play}
        style={animationStyle}
      />
    </div>
  )
}

export default Downloading

View File

@@ -0,0 +1,21 @@
import { FC } from 'react'
import Lottie from 'lottie-react'
import error from '@/assets/Lottie/error.json'
/**
 * Looping 450×450 Lottie "error" animation, centred in its container.
 *
 * The local binding is deliberately not called `Error`: that name would shadow
 * the global `Error` constructor for the whole module. The component is the
 * default export, so importers are unaffected by the local name.
 */
const ErrorAnimation: FC = () => {
  return (
    <div className="flex items-center justify-center">
      <Lottie
        animationData={error}
        loop
        autoplay
        style={{
          width: 450,
          height: 450,
        }}
      />
    </div>
  )
}

export default ErrorAnimation

View File

@@ -0,0 +1,64 @@
import * as React from 'react'
import { cva, type VariantProps } from 'class-variance-authority'
import { cn } from '@/lib/utils'
// Class-name generator for <Alert>: a shared base layout plus one colour scheme
// per `variant` (default, destructive, success, warning). `default` is used when
// no variant is given.
const alertVariants = cva(
'relative w-full rounded-lg border px-4 py-3 text-sm grid has-[>svg]:grid-cols-[calc(var(--spacing)*4)_1fr] grid-cols-[0_1fr] has-[>svg]:gap-x-3 gap-y-0.5 items-start [&>svg]:size-4 [&>svg]:translate-y-0.5 [&>svg]:text-current',
{
variants: {
variant: {
default: 'bg-card text-card-foreground',
destructive:
'text-destructive bg-card [&>svg]:text-current *:data-[slot=alert-description]:text-destructive/90',
success:
'text-success bg-card [&>svg]:text-current *:data-[slot=alert-description]:text-success/90',
// Unlike the other variants, warning uses hard-coded hex colours for text and background.
warning:
'text-[#303133] bg-[#FEF0F0] [&>svg]:text-current *:data-[slot=alert-description]:text-warning/90',
},
},
defaultVariants: {
variant: 'default',
},
}
)
/**
 * Alert container (`role="alert"`). Accepts every `div` prop plus `variant`,
 * which selects the colour scheme from `alertVariants`; a caller-supplied
 * `className` is merged after the variant classes.
 */
function Alert(allProps: React.ComponentProps<'div'> & VariantProps<typeof alertVariants>) {
  const { className, variant, ...props } = allProps
  const classes = cn(alertVariants({ variant }), className)

  return <div data-slot="alert" role="alert" className={classes} {...props} />
}
/** Title row of an alert; `className` is merged after the default classes. */
function AlertTitle(allProps: React.ComponentProps<'div'>) {
  const { className, ...props } = allProps
  const classes = cn('col-start-2 line-clamp-1 min-h-4 font-medium tracking-tight', className)

  return <div data-slot="alert-title" className={classes} {...props} />
}
/** Body text of an alert; `className` is merged after the default classes. */
function AlertDescription(allProps: React.ComponentProps<'div'>) {
  const { className, ...props } = allProps
  const classes = cn(
    'text-muted-foreground col-start-2 grid justify-items-start gap-1 text-sm [&_p]:leading-relaxed',
    className
  )

  return <div data-slot="alert-description" className={classes} {...props} />
}
export { Alert, AlertTitle, AlertDescription }

View File

@@ -0,0 +1,29 @@
import * as React from "react"
import * as SwitchPrimitive from "@radix-ui/react-switch"
import { cn } from "@/lib/utils"
/**
 * Styled wrapper around the Radix switch primitive. All props are forwarded to
 * `SwitchPrimitive.Root`; `className` is merged after the default track classes.
 */
function Switch(allProps: React.ComponentProps<typeof SwitchPrimitive.Root>) {
  const { className, ...props } = allProps

  const trackClasses = cn(
    "peer data-[state=checked]:bg-primary data-[state=unchecked]:bg-input focus-visible:border-ring focus-visible:ring-ring/50 dark:data-[state=unchecked]:bg-input/80 inline-flex h-[1.15rem] w-8 shrink-0 items-center rounded-full border border-transparent shadow-xs transition-all outline-none focus-visible:ring-[3px] disabled:cursor-not-allowed disabled:opacity-50",
    className
  )
  const thumbClasses = cn(
    "bg-background dark:data-[state=unchecked]:bg-foreground dark:data-[state=checked]:bg-primary-foreground pointer-events-none block size-4 rounded-full ring-0 transition-transform data-[state=checked]:translate-x-[calc(100%-2px)] data-[state=unchecked]:translate-x-0"
  )

  return (
    <SwitchPrimitive.Root data-slot="switch" className={trackClasses} {...props}>
      <SwitchPrimitive.Thumb data-slot="switch-thumb" className={thumbClasses} />
    </SwitchPrimitive.Root>
  )
}
export { Switch }

View File

@@ -1,45 +1,59 @@
// hooks/useTaskPolling.ts
import { useEffect } from 'react'
import { useEffect, useRef } from 'react'
import { useTaskStore } from '@/store/taskStore'
import { get_task_status } from '@/services/note.ts'
import toast from 'react-hot-toast'
export const useTaskPolling = (interval = 3000) => {
const tasks = useTaskStore(state => state.tasks)
const updateTaskContent = useTaskStore(state => state.updateTaskContent)
const updateTaskStatus = useTaskStore(state => state.updateTaskStatus)
const removeTask = useTaskStore(state => state.removeTask)
const tasksRef = useRef(tasks)
// 每次 tasks 更新,把最新的 tasks 同步进去
useEffect(() => {
tasksRef.current = tasks
}, [tasks])
useEffect(() => {
const timer = setInterval(async () => {
const pendingTasks = tasks.filter(
task => task.status === 'PENDING' || task.status === 'running'
const pendingTasks = tasksRef.current.filter(
task => task.status != 'SUCCESS' && task.status != 'FAILED'
)
for (const task of pendingTasks) {
try {
console.log(task)
console.log('🔄 正在轮询任务:', task.id)
const res = await get_task_status(task.id)
const { status } = res.data
if (status && status !== task.status) {
if (status === 'SUCCESS') {
const { markdown, transcript, audio_meta } = res.data.result
toast.success('笔记生成成功')
updateTaskContent(task.id, {
status,
markdown,
transcript,
audioMeta: audio_meta,
})
} else if (status === 'FAILED') {
updateTaskContent(task.id, { status })
console.warn(`⚠️ 任务 ${task.id} 失败`)
} else {
updateTaskStatus(task.id, status)
updateTaskContent(task.id, { status })
}
}
} catch (e) {
console.error('❌ 任务轮询失败:', e)
removeTask(task.id)
// A caught value is `unknown` and may be null or a primitive, so narrow it
// before reading `.message`. Falls back to the value itself, as before.
const failureDetail =
  typeof e === 'object' && e !== null && 'message' in e && e.message ? e.message : e
toast.error(`生成失败 ${String(failureDetail)}`)
updateTaskContent(task.id, { status: 'FAILED' })
// removeTask(task.id)
}
}
}, interval)
return () => clearInterval(timer)
}, [interval, tasks])
}, [interval])
}

View File

@@ -3,7 +3,7 @@ import HomeLayout from '@/layouts/HomeLayout.tsx'
import NoteForm from '@/pages/HomePage/components/NoteForm.tsx'
import MarkdownViewer from '@/pages/HomePage/components/MarkdownViewer.tsx'
import { useTaskStore } from '@/store/taskStore'
type ViewStatus = 'idle' | 'loading' | 'success'
type ViewStatus = 'idle' | 'loading' | 'success' | 'failed'
export const HomePage: FC = () => {
const tasks = useTaskStore(state => state.tasks)
const currentTaskId = useTaskStore(state => state.currentTaskId)
@@ -21,6 +21,8 @@ export const HomePage: FC = () => {
setStatus('loading')
} else if (currentTask.status === 'SUCCESS') {
setStatus('success')
} else if (currentTask.status === 'FAILED') {
setStatus('failed')
}
}, [currentTask])

View File

@@ -3,7 +3,7 @@ import ReactMarkdown from 'react-markdown'
import { Button } from '@/components/ui/button.tsx'
import { Copy, Download, FileText, ArrowRight } from 'lucide-react'
import { toast } from 'sonner' // 你可以换成自己的通知组件
import Error from '@/components/Lottie/error.tsx'
import { Prism as SyntaxHighlighter } from 'react-syntax-highlighter'
import { solarizedlight as codeStyle } from 'react-syntax-highlighter/dist/cjs/styles/prism'
import 'github-markdown-css/github-markdown-light.css'
@@ -11,14 +11,26 @@ import { FC } from 'react'
import Loading from '@/components/Lottie/Loading.tsx'
import Idle from '@/components/Lottie/Idle.tsx'
import { useTaskStore } from '@/store/taskStore'
import StepBar from '@/pages/HomePage/components/StepBar.tsx'
interface MarkdownViewerProps {
content: string
status: 'idle' | 'loading' | 'success'
status: 'idle' | 'loading' | 'success' | 'failed'
}
const steps = [
{ label: '解析链接', key: 'PARSING' },
{ label: '下载音频', key: 'DOWNLOADING' },
{ label: '转写文字', key: 'TRANSCRIBING' },
{ label: '总结内容', key: 'SUMMARIZING' },
{ label: '保存完成', key: 'SUCCESS' },
]
const MarkdownViewer: FC<MarkdownViewerProps> = ({ content, status }) => {
const [copied, setCopied] = useState(false)
const getCurrentTask = useTaskStore.getState().getCurrentTask
const currentTask = useTaskStore(state => state.getCurrentTask())
const taskStatus = currentTask?.status || 'PENDING'
const retryTask = useTaskStore.getState().retryTask
const handleCopy = async () => {
try {
await navigator.clipboard.writeText(content)
@@ -34,6 +46,7 @@ const MarkdownViewer: FC<MarkdownViewerProps> = ({ content, status }) => {
const handleDownload = () => {
const currentTask = getCurrentTask()
const currentTaskName = currentTask?.audioMeta.title
const blob = new Blob([content], { type: 'text/markdown;charset=utf-8' })
const link = document.createElement('a')
link.href = URL.createObjectURL(blob)
@@ -45,6 +58,7 @@ const MarkdownViewer: FC<MarkdownViewerProps> = ({ content, status }) => {
if (status === 'loading') {
return (
<div className="flex h-screen w-full flex-col items-center justify-center space-y-4 text-neutral-500">
<StepBar steps={steps} currentStep={taskStatus} />
<Loading className="h-5 w-5" />
<div className="text-center text-sm">
<p className="text-lg font-bold"></p>
@@ -63,6 +77,24 @@ const MarkdownViewer: FC<MarkdownViewerProps> = ({ content, status }) => {
</div>
</div>
)
} else if (status === 'failed') {
return (
<div className="flex h-screen w-full flex-col items-center justify-center gap-4 space-y-3">
<Error /> {/* 你可以换成 Failed 动画 */}
<div className="text-center">
<p className="text-lg font-bold text-red-500"></p>
<p className="mt-2 mb-2 text-xs text-red-400"></p>
<Button
  onClick={() => {
    // `currentTask` may be undefined (it is read with `?.` elsewhere in this
    // component), so only retry when a task is actually selected.
    if (currentTask) {
      retryTask(currentTask.id)
    }
  }}
  size="lg"
>
</Button>
</div>
</div>
)
}
return (

View File

@@ -6,6 +6,7 @@ import {
FormLabel,
FormMessage,
} from '@/components/ui/form.tsx'
import { useEffect } from 'react'
import { Input } from '@/components/ui/input.tsx'
import {
Select,
@@ -30,7 +31,9 @@ import {
import { generateNote } from '@/services/note.ts'
import { useTaskStore } from '@/store/taskStore'
import NoteHistory from '@/pages/HomePage/components/NoteHistory.tsx'
import { useModelStore } from '@/store/modelStore'
import { Alert } from 'antd'
import { Textarea } from '@/components/ui/textarea.tsx'
// ✅ 定义表单 schema
const formSchema = z.object({
video_url: z.string().url('请输入正确的视频链接'),
@@ -40,15 +43,70 @@ const formSchema = z.object({
}),
screenshot: z.boolean().optional(),
link: z.boolean().optional(),
model_name: z.string().nonempty('请选择模型'),
format: z.array(z.string()).default([]), // ✨ 确保默认是空数组
style: z.string().nonempty('请选择笔记生成风格'),
extras: z.string().optional(),
})
type NoteFormValues = z.infer<typeof formSchema>
// Label/value pairs for the note output-format options. The `format` form
// field is an array of strings, so several values may be selected together.
const noteFormats = [
{
label: '目录',
value: 'toc',
},
{ label: '原片跳转', value: 'link' },
{ label: '原片截图', value: 'screenshot' },
{ label: 'AI总结', value: 'summary' },
]
// Label/value pairs for the note writing-style options (the `style` form field).
const noteStyles = [
{
label: '精简',
value: 'minimal', // concise; surfaces the key points quickly
},
{
label: '详细',
value: 'detailed', // detailed record, including timestamps and key points
},
{
label: '教程',
value: 'tutorial', // NOTE(review): original comment was identical to 'detailed' (likely copy-paste) — confirm intended description
},
{
label: '学术',
value: 'academic', // suited to academic reports; formal and structured
},
{
label: '小红书',
value: 'xiaohongshu', // suited to sharing on social platforms; friendly and colloquial
},
{
label: '生活向',
value: 'life_journal', // records personal reflections; emotional tone
},
{
label: '任务导向',
value: 'task_oriented', // emphasizes tasks and goals; suited to work and to-do items
},
{
label: '商业风格',
value: 'business', // suited to business reports and meeting minutes; formal and precise
},
{
label: '会议纪要',
value: 'meeting_minutes', // NOTE(review): original comment was identical to 'business' (likely copy-paste) — confirm intended description
},
]
const NoteForm = () => {
useTaskStore(state => state.tasks)
const setCurrentTask = useTaskStore(state => state.setCurrentTask)
const currentTaskId = useTaskStore(state => state.currentTaskId)
const getCurrentTask = useTaskStore(state => state.getCurrentTask)
const loadEnabledModels = useModelStore(state => state.loadEnabledModels)
const modelList = useModelStore(state => state.modelList)
const showFeatureHint = useModelStore(state => state.showFeatureHint)
const setShowFeatureHint = useModelStore(state => state.setShowFeatureHint)
const form = useForm<NoteFormValues>({
resolver: zodResolver(formSchema),
defaultValues: {
@@ -56,9 +114,16 @@ const NoteForm = () => {
platform: 'bilibili',
quality: 'medium', // 默认中等质量
screenshot: false,
model_name: modelList[0]?.model_name || '', // 确保有值
format: [], // 初始化为空数组
style: 'minimal', // 默认选择精简风格
extras: '', // 初始化为空字符串
},
})
const onClose = () => {
setShowFeatureHint(false)
}
const isGenerating = () => {
console.log('🚀 isGenerating', getCurrentTask()?.status)
return getCurrentTask()?.status === 'PENDING'
@@ -66,14 +131,23 @@ const NoteForm = () => {
const onSubmit = async (data: NoteFormValues) => {
console.log('🎯 提交内容:', data)
await generateNote({
const payload = {
video_url: data.video_url,
platform: data.platform,
quality: data.quality,
screenshot: data.screenshot,
link: data.link,
})
model_name: data.model_name,
provider_id: modelList.find(model => model.model_name === data.model_name).provider_id,
format: data.format,
style: data.style,
extras: data.extras,
}
const res = await generateNote(payload)
const taskId = res.data.task_id
useTaskStore.getState().addPendingTask(taskId, data.platform, payload)
}
useEffect(() => {
loadEnabledModels()
}, [])
return (
<div className="flex h-full flex-col">
@@ -173,48 +247,157 @@ const NoteForm = () => {
</FormItem>
)}
/>
<FormField
control={form.control}
name="model_name"
render={({ field }) => (
<FormItem>
<div className="my-3 flex items-center justify-between">
<h2 className="block"></h2>
<TooltipProvider>
<Tooltip>
<TooltipTrigger asChild>
<Info className="hover:text-primary h-4 w-4 cursor-pointer text-neutral-400" />
</TooltipTrigger>
<TooltipContent>
<p className="max-w-[200px] text-xs"></p>
</TooltipContent>
</Tooltip>
</TooltipProvider>
</div>
<Select onValueChange={field.onChange} defaultValue={field.value}>
<FormControl>
<SelectTrigger className="w-full">
<SelectValue placeholder="选择配置好的模型" />
</SelectTrigger>
</FormControl>
<SelectContent>
{modelList.map(item => {
return <SelectItem value={item.model_name}>{item.model_name}</SelectItem>
})}
</SelectContent>
</Select>
{/*<FormDescription className="text-xs text-neutral-500">*/}
{/* 质量越高,下载体积越大,速度越慢*/}
{/*</FormDescription>*/}
<FormMessage />
</FormItem>
)}
/>
</div>
{/* 是否需要原片位置 */}
<FormField
control={form.control}
name="link"
name="style"
render={({ field }) => (
<FormItem className="flex items-center space-x-2">
{/* Tooltip 部分 */}
<FormItem>
<div className="my-3 flex items-center justify-between">
<h2 className="block"></h2>
<TooltipProvider>
<Tooltip>
<TooltipTrigger asChild>
<Info className="hover:text-primary h-4 w-4 cursor-pointer text-neutral-400" />
</TooltipTrigger>
<TooltipContent>
<p className="max-w-[200px] text-xs"></p>
</TooltipContent>
</Tooltip>
</TooltipProvider>
</div>
<FormControl>
<Checkbox checked={field.value} onCheckedChange={field.onChange} id="link" />
</FormControl>
<Select onValueChange={field.onChange} defaultValue={field.value}>
<FormControl>
<SelectTrigger className="w-full">
<SelectValue placeholder="选择笔记风格" />
</SelectTrigger>
</FormControl>
<SelectContent>
{noteStyles.map(item => (
<SelectItem key={item.value} value={item.value}>
{item.label}
</SelectItem>
))}
</SelectContent>
</Select>
<FormLabel htmlFor="link" className="text-sm leading-none font-medium">
</FormLabel>
<FormMessage />
</FormItem>
)}
/>
{/* 是否需要下载 */}
<FormField
control={form.control}
name="screenshot"
name="format"
render={({ field }) => (
<FormItem className="flex items-center space-x-2">
{/* Tooltip 部分 */}
<FormItem>
<div className="my-3 flex items-center justify-between">
<h2 className="block"></h2>
<TooltipProvider>
<Tooltip>
<TooltipTrigger asChild>
<Info className="hover:text-primary h-4 w-4 cursor-pointer text-neutral-400" />
</TooltipTrigger>
<TooltipContent>
<p className="text-xs"></p>
</TooltipContent>
</Tooltip>
</TooltipProvider>
</div>
<FormControl>
<Checkbox
checked={field.value}
onCheckedChange={field.onChange}
id="screenshot"
/>
<div className="flex space-x-1.5">
{noteFormats.map(item => (
<label key={item.value} className="flex items-center space-x-2">
<Checkbox
checked={field.value?.includes(item.value)}
onCheckedChange={checked => {
const currentValue = field.value ?? [] // ✨ 保底是数组
if (checked) {
field.onChange([...currentValue, item.value])
} else {
field.onChange(currentValue.filter(v => v !== item.value))
}
}}
/>
<span>{item.label}</span>
</label>
))}
</div>
</FormControl>
<FormField
control={form.control}
name="extras"
render={({ field }) => (
<FormItem>
<div className="my-3 flex items-center justify-between">
<h2 className="block"></h2>
<TooltipProvider>
<Tooltip>
<TooltipTrigger asChild>
<Info className="hover:text-primary h-4 w-4 cursor-pointer text-neutral-400" />
</TooltipTrigger>
<TooltipContent>
<p className="text-xs">Prompt最后 </p>
</TooltipContent>
</Tooltip>
</TooltipProvider>
</div>
<Textarea placeholder={'笔记需要罗列出 xxx 关键点'} />
<FormLabel htmlFor="screenshot" className="text-sm leading-none font-medium">
</FormLabel>
{/*<FormDescription className="text-xs text-neutral-500">*/}
{/* 质量越高,下载体积越大,速度越慢*/}
{/*</FormDescription>*/}
<FormMessage />
</FormItem>
)}
/>
<FormMessage />
</FormItem>
)}
/>
<div className={'flex w-full items-center gap-2 py-1.5'}>
{/* 提交按钮 */}
<Button type="submit" className="bg-primary w-full" disabled={isGenerating()}>
@@ -235,27 +418,35 @@ const NoteForm = () => {
</div>
{/* 添加一些额外的说明或功能介绍 */}
<div className="bg-primary-light mt-6 rounded-lg p-4">
<h3 className="text-primary mb-2 font-medium"></h3>
<ul className="space-y-2 text-sm text-neutral-600">
<li className="flex items-start gap-2">
<span className="text-primary font-bold"></span>
<span></span>
</li>
<li className="flex items-start gap-2">
<span className="text-primary font-bold"></span>
<span>YouTube等</span>
</li>
<li className="flex items-start gap-2">
<span className="text-primary font-bold"></span>
<span>Markdown格式</span>
</li>
<li className="flex items-start gap-2">
<span className="text-primary font-bold"></span>
<span></span>
</li>
</ul>
</div>
{showFeatureHint && (
<Alert
message="功能介绍 v2.0.0"
description={
<ul className="space-y-2 text-sm text-neutral-600">
<li className="flex items-start gap-2">
<span className="text-primary font-bold"></span>
<span></span>
</li>
<li className="flex items-start gap-2">
<span className="text-primary font-bold"></span>
<span>YouTube等</span>
</li>
<li className="flex items-start gap-2">
<span className="text-primary font-bold"></span>
<span>Markdown格式</span>
</li>
<li className="flex items-start gap-2">
<span className="text-primary font-bold"></span>
<span></span>
</li>
</ul>
}
type="info"
onClose={onClose}
closable
/>
)}
{/*<div className="bg-primary-light mt-6 rounded-lg p-4"></div>*/}
</div>
)
}

View File

@@ -0,0 +1,54 @@
import { FC } from 'react'
/** One entry of the step bar: a display label, a unique key, and an optional icon node. */
interface Step {
  label: string
  key: string
  Icon?: React.ReactNode // optional node (e.g. a Lottie animation) shown while this step is current
}
/** Props for StepBar: the ordered steps and the `key` of the step currently in progress. */
interface StepBarProps {
  steps: Step[]
  currentStep: string
}
/**
 * Horizontal progress indicator.
 *
 * Every step up to and including `currentStep` is drawn in the primary colour;
 * the current step additionally shows its `Icon` (when it has one) beneath the
 * numbered circle. A connector bar is rendered for every step except the last.
 */
const StepBar: FC<StepBarProps> = ({ steps, currentStep }) => {
  // findIndex yields -1 when no key matches, in which case no step is highlighted.
  const activeIdx = steps.findIndex(({ key }) => key === currentStep)
  const lastIdx = steps.length - 1

  return (
    <div className="flex w-full items-center justify-between">
      {steps.map((step, idx) => {
        const reached = idx <= activeIdx
        const circleTone = reached ? 'bg-primary text-white' : 'bg-gray-300 text-gray-600'
        const connectorTone = reached ? 'bg-primary' : 'bg-gray-300'

        return (
          <div key={step.key} className="relative flex flex-1 flex-col items-center">
            {/* Numbered circle, optionally with the step icon below it */}
            <div className="relative flex flex-col items-center justify-center">
              <div
                className={`flex h-8 w-8 items-center justify-center rounded-full text-xs font-bold ${circleTone}`}
              >
                {idx + 1}
              </div>
              {/* The icon is only shown for the step in progress */}
              {idx === activeIdx && step.Icon && (
                <div className="absolute top-10 h-16 w-16">{step.Icon}</div>
              )}
            </div>
            {/* Step label */}
            <div className="mt-4 text-center text-xs text-gray-700">{step.label}</div>
            {/* Connector bar */}
            {idx !== lastIdx && <div className={`h-1 w-full ${connectorTone}`}></div>}
          </div>
        )
      })}
    </div>
  )
}
export default StepBar

View File

@@ -9,26 +9,27 @@ const Menu = () => {
icon: <BotMessageSquare />,
path: '/settings/model',
},
{
id: ' transcriber',
name: '音频转译配置',
icon: <Captions />,
path: '/settings/transcriber',
},
//下载配置
{
id: 'download',
name: '下载配置',
icon: <HardDriveDownload />,
path: '/settings/download',
},
//其他配置
{
id: 'other',
name: '其他配置',
icon: <Wrench />,
path: '/settings/other',
},
// TODO :下一版本升级优化
// {
// id: ' transcriber',
// name: '音频转译配置',
// icon: <Captions />,
// path: '/settings/transcriber',
// },
// //下载配置
// {
// id: 'download',
// name: '下载配置',
// icon: <HardDriveDownload />,
// path: '/settings/download',
// },
// //其他配置
// {
// id: 'other',
// name: '其他配置',
// icon: <Wrench />,
// path: '/settings/other',
// },
]
return (
<div className="flex h-full flex-col">

View File

@@ -0,0 +1,8 @@
/** Placeholder page for the audio-transcription settings, which are not implemented yet. */
const Transcriber = () => (
  <div className="flex h-screen w-full flex-col items-center justify-center">
    <h1 className="text-center text-4xl font-bold">Transcriber is under development</h1>
  </div>
)
export default Transcriber

View File

@@ -3,3 +3,29 @@ import request from '@/utils/request.ts'
/** Fetches all providers via GET /get_all_providers. */
export const getProviderList = async () => request.get('/get_all_providers')
/**
 * Fetches a single provider via GET /get_provider_by_id/{id}.
 *
 * The id is percent-encoded so that an id containing `/`, `?`, `#` or
 * whitespace cannot change which route is requested.
 */
export const getProviderById = async (id: string) => {
  return await request.get(`/get_provider_by_id/${encodeURIComponent(id)}`)
}
/**
 * Sends an updated provider to POST /update_provider.
 *
 * `data` is forwarded untouched as the request body; it is typed `unknown`
 * instead of `any` so this module does not switch off type checking.
 */
export const updateProviderById = async (data: unknown) => {
  return await request.post('/update_provider', data)
}
/**
 * Creates a provider via POST /add_provider.
 *
 * `data` is forwarded untouched as the request body; it is typed `unknown`
 * instead of `any` so this module does not switch off type checking.
 */
export const addProvider = async (data: unknown) => {
  return await request.post('/add_provider', data)
}
/**
 * Asks the backend to test a provider connection via POST /connect_test.
 *
 * `data` is forwarded untouched as the request body; it is typed `unknown`
 * instead of `any` so this module does not switch off type checking.
 */
export const testConnection = async (data: unknown) => {
  return await request.post('/connect_test', data)
}
/**
 * Fetches the models of one provider via GET /model_list/{providerId}.
 *
 * The id is stringified and percent-encoded before being placed in the path,
 * so ids containing `/`, `?`, `#` or whitespace still hit the intended route.
 * The parameter accepts any value (as before) but is typed `unknown` rather
 * than `any`.
 */
export const fetchModels = async (providerId: unknown) => {
  return await request.get(`/model_list/${encodeURIComponent(String(providerId))}`)
}
/** Registers `model_name` under `provider_id` via POST /models and resolves with the response. */
export async function addModel(data: { provider_id: string; model_name: string }) {
  const response = await request.post('/models', data)
  return response
}
/** Fetches the model list via GET /model_list (no provider filter). */
export const fetchEnableModels = async () => request.get('/model_list')

View File

@@ -4,10 +4,14 @@ import { useTaskStore } from '@/store/taskStore'
import request from '@/utils/request'
export const generateNote = async (data: {
video_url: string
link: undefined | boolean
screenshot: undefined | boolean
platform: string
quality: string
model_name: string
provider_id: string
task_id?: string
format: Array<string>
style: string
extras?: string
}) => {
try {
const response = await request.post('/generate_note', data)
@@ -20,11 +24,8 @@ export const generateNote = async (data: {
}
toast.success('笔记生成任务已提交!')
const taskId = response.data.data.task_id
console.log('res', response)
// 成功提示
useTaskStore.getState().addPendingTask(taskId, data.platform)
return response.data
} catch (e: any) {

View File

@@ -0,0 +1,25 @@
import { create } from 'zustand'
import { persist } from 'zustand/middleware'
/** Shape of the persisted, app-wide UI state and its setters. */
interface SystemState {
  showFeatureHint: boolean // whether the feature-introduction hint is shown
  setShowFeatureHint: (value: boolean) => void
  // further global state can be added below
  sidebarCollapsed: boolean // whether the sidebar is collapsed
  setSidebarCollapsed: (value: boolean) => void
}
// Not in use yet.
/**
 * Global UI store wrapped in zustand's `persist` middleware.
 * Initial state: feature hint visible, sidebar expanded.
 */
export const useSystemStore = create<SystemState>()(
  persist(
    set => ({
      showFeatureHint: true,
      sidebarCollapsed: false,
      setShowFeatureHint: visible => {
        set({ showFeatureHint: visible })
      },
      setSidebarCollapsed: collapsed => {
        set({ sidebarCollapsed: collapsed })
      },
    }),
    // `name` is the key the persisted state is stored under
    { name: 'system-store' }
  )
)

View File

@@ -0,0 +1,101 @@
import { create } from 'zustand'
import { devtools } from 'zustand/middleware'
import { fetchModels, addModel, fetchEnableModels } from '@/services/model.ts'
/**
 * A model entry as stored in `ModelStore.models`.
 * NOTE(review): the field set looks like an OpenAI-style model-list entry — confirm against the backend.
 */
interface IModel {
  id: string
  created: number
  object: string
  owned_by: string
  permission: string
  root: string
}
/**
 * A model entry of `modelList`, as returned by `fetchEnableModels`.
 * Only the fields the UI reads are declared.
 * NOTE(review): confirm the full payload shape against the backend.
 */
interface EnabledModel {
  model_name: string
  provider_id: string
}

/** State and actions of the model store. */
interface ModelStore {
  /** Models loaded for a single provider by `loadModels`. */
  models: IModel[]
  /**
   * Models loaded by `loadEnabledModels`.
   * Previously typed `[]` (the empty tuple), which made every element access a type error.
   */
  modelList: EnabledModel[]
  /** True while a load action is in flight. */
  loading: boolean
  selectedModel: string
  loadModels: (providerId: string) => Promise<void>
  loadEnabledModels: () => Promise<void>
  addNewModel: (providerId: string, modelId: string) => Promise<void>
  setSelectedModel: (modelId: string) => void
  clearModels: () => void
}
/**
 * Zustand store for AI models.
 *
 * - `modelList` holds the result of `loadEnabledModels` (GET /model_list).
 * - `models` holds the result of `loadModels` for one provider.
 *
 * Every failure path resets the affected list to `[]` and logs to the console;
 * none of the actions reject.
 */
export const useModelStore = create<ModelStore>()(
  devtools(set => ({
    models: [],
    loading: false,
    selectedModel: '',
    modelList: [],

    // Load the model list used by the note form
    loadEnabledModels: async () => {
      try {
        set({ loading: true })
        const res = await fetchEnableModels()
        if (res.data.code === 0 && res.data.data.length > 0) {
          set({ modelList: res.data.data })
        } else {
          set({ modelList: [] })
          console.error('模型列表加载失败')
        }
      } catch (error) {
        set({ modelList: [] })
        console.error('加载模型出错', error)
      } finally {
        // Fix: `loading` used to stay `true` forever after this action ran.
        set({ loading: false })
      }
    },

    // Load the models of a single provider
    loadModels: async (providerId: string) => {
      try {
        set({ loading: true })
        const res = await fetchModels(providerId)
        if (res.data.code === 0 && res.data.data.models.data.length > 0) {
          set({ models: res.data.data.models.data })
        } else {
          set({ models: [] })
          console.error('模型列表加载失败')
        }
      } catch (error) {
        set({ models: [] })
        console.error('加载模型出错', error)
      } finally {
        set({ loading: false })
      }
    },

    // Register a new model for a provider
    addNewModel: async (providerId: string, modelId: string) => {
      try {
        const res = await addModel({ provider_id: providerId, model_name: modelId })
        // Fix: the business code lives in the response body (`res.data.code`), as in
        // the other actions; `res.code` is not where they read it, so success was
        // never detected.
        if (res.data.code === 0) {
          console.log('新增模型成功:', modelId)
          // On success, append the new entry locally instead of refetching.
          set(state => ({
            models: [
              ...state.models,
              {
                id: modelId,
                created: Date.now(),
                object: 'model',
                owned_by: '',
                permission: '',
                root: '',
              },
            ],
          }))
        } else {
          console.error('新增模型失败')
        }
      } catch (error) {
        console.error('添加模型出错', error)
      }
    },

    // Remember which model is selected
    setSelectedModel: modelId => set({ selectedModel: modelId }),

    // Reset the per-provider list and the selection
    clearModels: () => set({ models: [], selectedModel: '' }),
  }))
)

View File

@@ -1,6 +1,11 @@
import { create } from 'zustand'
import { IProvider } from '@/types'
import { getProviderList } from '@/services/model.ts'
import {
addProvider,
getProviderById,
getProviderList,
updateProviderById,
} from '@/services/model.ts'
interface ProviderStore {
provider: IProvider[]
@@ -9,12 +14,14 @@ interface ProviderStore {
getProviderById: (id: number) => IProvider | undefined
getProviderList: () => IProvider[]
fetchProviderList: () => Promise<void>
loadProviderById: (id: string) => Promise<void>
addNewProvider: (provider: IProvider) => Promise<void>
updateProvider: (provider: IProvider) => Promise<void>
}
export const useProviderStore = create<ProviderStore>((set, get) => ({
provider: [],
// 添加或更新一个 provider
setProvider: newProvider =>
set(state => {
@@ -30,10 +37,60 @@ export const useProviderStore = create<ProviderStore>((set, get) => ({
// 设置整个 provider 列表
setAllProviders: providers => set({ provider: providers }),
loadProviderById: async (id: string) => {
const res = await getProviderById(id)
if (res.data.code === 0) {
const item = res.data.data
console.log('Provider ', item)
return {
id: item.id,
name: item.name,
logo: item.logo,
apiKey: item.api_key,
baseUrl: item.base_url,
type: item.type,
enabled: item.enabled,
}
} else {
console.log('Provider not found')
}
},
addNewProvider: async (provider: IProvider) => {
const payload = {
...provider,
api_key: provider.apiKey,
base_url: provider.baseUrl,
}
try {
const res = await addProvider(payload)
if (res.data.code === 0) {
const item = res.data.data
console.log('Provider ', item)
await get().fetchProviderList()
}
} catch (error) {
console.error('Error fetching provider:', error)
}
},
// 按 id 获取单个 provider
getProviderById: id => get().provider.find(p => p.id === id),
updateProvider: async (provider: IProvider) => {
try {
const data = {
...provider,
api_key: provider.apiKey,
base_url: provider.baseUrl,
}
const res = await updateProviderById(data)
if (res.data.code === 0) {
const item = res.data.data
console.log('Provider ', item)
await get().fetchProviderList()
}
} catch (error) {
console.error('Error fetching provider:', error)
}
},
getProviderList: () => get().provider,
fetchProviderList: async () => {
try {
@@ -55,6 +112,7 @@ export const useProviderStore = create<ProviderStore>((set, get) => ({
apiKey: item.api_key,
baseUrl: item.base_url,
type: item.type,
enabled: item.enabled,
}
}
),

View File

@@ -1,6 +1,6 @@
import { create } from 'zustand'
import { persist } from 'zustand/middleware'
import { delete_task } from '@/services/note.ts'
import { delete_task, generateNote } from '@/services/note.ts'
export type TaskStatus = 'PENDING' | 'RUNNING' | 'SUCCESS' | 'FAILD'
@@ -34,6 +34,15 @@ export interface Task {
status: TaskStatus
audioMeta: AudioMeta
createdAt: string
formData: {
video_url: string
link: undefined | boolean
screenshot: undefined | boolean
platform: string
quality: string
model_name: string
provider_id: string
}
}
interface TaskStore {
@@ -45,6 +54,7 @@ interface TaskStore {
clearTasks: () => void
setCurrentTask: (taskId: string | null) => void
getCurrentTask: () => Task | null
retryTask: (id: string) => void
}
export const useTaskStore = create<TaskStore>()(
@@ -53,10 +63,11 @@ export const useTaskStore = create<TaskStore>()(
tasks: [],
currentTaskId: null,
addPendingTask: (taskId: string, platform: string) =>
addPendingTask: (taskId: string, platform: string, formData: any) =>
set(state => ({
tasks: [
{
formData: formData,
id: taskId,
status: 'PENDING',
markdown: '',
@@ -91,6 +102,17 @@ export const useTaskStore = create<TaskStore>()(
const currentTaskId = get().currentTaskId
return get().tasks.find(task => task.id === currentTaskId) || null
},
// Re-submit a task with the form data stored when it was first created,
// reusing its id, then mark it as PENDING again.
retryTask: async (id: string) => {
  const target = get().tasks.find(task => task.id === id)
  // Guard: the task may have been removed, or may have been persisted before
  // form data was stored with tasks; in either case there is nothing to replay.
  if (!target?.formData) {
    console.error('retryTask: no stored form data for task', id)
    return
  }

  await generateNote({
    task_id: id,
    ...target.formData,
  })
  set(state => ({
    tasks: state.tasks.map(task => (task.id === id ? { ...task, status: 'PENDING' } : task)),
  }))
},
removeTask: async id => {
const task = get().tasks.find(t => t.id === id)

View File

@@ -5,4 +5,5 @@ export interface IProvider {
type: string
apiKey: string
baseUrl: string
enabled: number
}

View File

@@ -3,7 +3,7 @@
<p align="center">
<img src="./doc/icon.svg" alt="BiliNote Banner" width="50" height="50" />
</p>
<h1 align="center" > BiliNote v1.0.1</h1>
<h1 align="center" > BiliNote v1.1.0</h1>
</div>
<p align="center"><i>AI 视频笔记生成工具 让 AI 为你的视频做笔记</i></p>
@@ -119,6 +119,7 @@ docker compose up --build
## ⚙️ 环境变量配置
> ⚠️ v1.1.0 以后无需通过环境变量配置 AI
后端 `.env` 示例:
@@ -131,15 +132,29 @@ OPENAI_API_KEY=sk-xxxxxx
DEEP_SEEK_API_KEY=xxx
QWEN_API_KEY=xxx
```
## Changelog
### v1.1.0
- #### Added
- 新增 AI 笔记风格选择
- 新增 AI 笔记返回格式选择
- 添加 AI 自定义笔记备注 Prompt
- 添加任务失败重试
- 添加全局设置页,可在设置页进行模型设置
- #### Optimize
- 优化前端样式,优化用户体验
- 增加生成中间产物,可用于失败后加快生成速度
- #### Fix
- 修复视频截图视频过早删除错误
## 🧠 TODO
- [ ] 支持抖音及快手等视频平台
- [ ] 支持前端设置切换 AI 模型切换、语音转文字模型
- [ ] AI 摘要风格自定义(学术风、口语风、重点提取等)
- [x] 支持前端设置切换 AI 模型切换、语音转文字模型
- [x] AI 摘要风格自定义(学术风、口语风、重点提取等)
- [ ] 笔记导出为 PDF / Word / Notion
- [ ] 加入更多模型支持
- [ ] 加入更多音频转文本模型支持
- [x] 加入更多模型支持
- [x] 加入更多音频转文本模型支持
### Contact and Join-联系和加入社区
- BiliNote 交流QQ群785367111

View File

@@ -1,9 +1,10 @@
from fastapi import FastAPI
from .routers import note, provider
from .routers import note, provider,model
def create_app() -> FastAPI:
    """Build the FastAPI application and mount every router under the /api prefix."""
    app = FastAPI(title="BiliNote")
    # Registration order is kept: note, provider, model.
    for module in (note, provider, model):
        app.include_router(module.router, prefix="/api")
    return app

View File

@@ -0,0 +1,42 @@
[
{
"id": "openai",
"name": "OpenAI",
"type": "built-in",
"logo": "OpenAI",
"api_key": "",
"base_url": "https://api.openai.com/v1"
},
{
"id": "deepseek",
"name": "DeepSeek",
"type": "built-in",
"logo": "DeepSeek",
"api_key": "",
"base_url": "https://api.deepseek.com"
},
{
"id": "qwen",
"name": "Qwen",
"type": "built-in",
"logo": "Qwen",
"api_key": "",
"base_url": "https://qwen.aliyun.com/api"
},
{
"id": "doubao",
"name": "豆包 (Doubao)",
"type": "built-in",
"logo": "Doubao",
"api_key": "",
"base_url": "https://open.doubao.com/api"
},
{
"id": "Claude",
"name": "Claude",
"type": "built-in",
"logo": "Claude",
"api_key": "",
"base_url": "https://"
}
]

View File

@@ -0,0 +1,58 @@
from app.db.sqlite_client import get_connection
def init_model_table():
    """Create the ``models`` table if it does not exist yet.

    ``provider_id`` is declared TEXT so that it matches ``providers.id``, which
    is a TEXT primary key (for example ``"openai"``). Databases that already
    contain the table are left untouched by ``CREATE TABLE IF NOT EXISTS``.
    """
    conn = get_connection()
    try:
        cursor = conn.cursor()
        cursor.execute("""
            CREATE TABLE IF NOT EXISTS models (
                id INTEGER PRIMARY KEY AUTOINCREMENT,
                provider_id TEXT NOT NULL,
                model_name TEXT NOT NULL,
                created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP
            )
        """)
        conn.commit()
    finally:
        # Always release the connection, even when the statement fails.
        conn.close()
# Insert a model
def insert_model(provider_id: str, model_name: str):
    """Insert one ``(provider_id, model_name)`` row into ``models``.

    Args:
        provider_id: Id of the owning provider (provider ids are strings).
        model_name: Name of the model to register.

    Any database error propagates to the caller; the connection is closed
    either way.
    """
    conn = get_connection()
    try:
        cursor = conn.cursor()
        cursor.execute("""
            INSERT INTO models (provider_id, model_name)
            VALUES (?, ?)
        """, (provider_id, model_name))
        conn.commit()
    finally:
        conn.close()
# Look up the models of one provider
def get_models_by_provider(provider_id: str):
    """Return the models registered for ``provider_id``.

    Returns:
        A list of ``{"id": ..., "model_name": ...}`` dicts; empty when the
        provider has no models.

    Any database error propagates to the caller; the connection is closed
    either way.
    """
    conn = get_connection()
    try:
        cursor = conn.cursor()
        cursor.execute("""
            SELECT id, model_name FROM models
            WHERE provider_id = ?
        """, (provider_id,))
        rows = cursor.fetchall()
    finally:
        conn.close()
    return [{"id": row[0], "model_name": row[1]} for row in rows]
# Delete a single model
def delete_model(model_id: int):
    """Delete the ``models`` row whose primary key is ``model_id``.

    Deleting an id that does not exist is a no-op. Any database error
    propagates to the caller; the connection is closed either way.
    """
    conn = get_connection()
    try:
        cursor = conn.cursor()
        cursor.execute("""
            DELETE FROM models WHERE id = ?
        """, (model_id,))
        conn.commit()
    finally:
        conn.close()
def get_all_models():
    """Return every row of ``models``.

    Returns:
        A list of ``{"id": ..., "provider_id": ..., "model_name": ...}`` dicts.

    Any database error propagates to the caller; the connection is closed
    either way.
    """
    conn = get_connection()
    try:
        cursor = conn.cursor()
        cursor.execute("""
            SELECT id, provider_id, model_name FROM models
        """)
        rows = cursor.fetchall()
    finally:
        conn.close()
    return [{"id": row[0], "provider_id": row[1], "model_name": row[2]} for row in rows]

View File

@@ -1,8 +1,59 @@
import json
import os
from app.db.sqlite_client import get_connection
from app.utils.logger import get_logger
logger = get_logger(__name__)
def seed_default_providers():
    """Populate ``providers`` from ``builtin_providers.json`` when the table is empty.

    Does nothing when the table already contains rows. A failure to read the
    JSON file or to insert the rows is logged, not raised. The connection is
    closed on every path, including when the initial COUNT query raises.
    """
    conn = get_connection()
    if conn is None:
        logger.error("Failed to connect to database.")
        return

    try:
        cursor = conn.cursor()

        # Skip seeding when the table already has data
        cursor.execute("SELECT COUNT(*) FROM providers")
        if cursor.fetchone()[0] > 0:
            logger.info("Providers already exist, skipping seed.")
            return

        json_path = os.path.join(os.path.dirname(__file__), 'builtin_providers.json')
        try:
            with open(json_path, 'r', encoding='utf-8') as f:
                providers = json.load(f)
        except Exception as e:
            logger.error(f"Failed to read builtin_providers.json: {e}")
            return

        try:
            rows = [
                (
                    p['id'],
                    p['name'],
                    p['api_key'],
                    p['base_url'],
                    p['logo'],
                    p['type'],
                    p.get('enabled', 1),  # entries without the key default to enabled
                )
                for p in providers
            ]
            cursor.executemany("""
                INSERT INTO providers (id, name, api_key, base_url, logo, type, enabled)
                VALUES (?, ?, ?, ?, ?, ?, ?)
            """, rows)
            conn.commit()
            logger.info("Default providers seeded successfully.")
        except Exception as e:
            # Discard any partially inserted rows before the connection is released.
            conn.rollback()
            logger.error(f"Failed to seed default providers: {e}")
    finally:
        conn.close()
def init_provider_table():
conn = get_connection()
if conn is None:
@@ -11,40 +62,60 @@ def init_provider_table():
cursor = conn.cursor()
cursor.execute("""
CREATE TABLE IF NOT EXISTS providers (
id INTEGER PRIMARY KEY AUTOINCREMENT,
id TEXT PRIMARY KEY,
name TEXT NOT NULL,
logo TEXT NOT NULL,
type TEXT NOT NULL, -- ✅ 新增字段
type TEXT NOT NULL,
api_key TEXT NOT NULL,
base_url TEXT NOT NULL,
enabled INTEGER DEFAULT 1,
created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP
)
""")
try:
conn.commit()
conn.close()
logger.info("provider table created successfully.")
seed_default_providers()
except Exception as e:
logger.error(f"Failed to create provider table: {e}")
def insert_provider(name: str, api_key: str, base_url: str, logo: str, type_: str):
def insert_provider(id: str, name: str, api_key: str, base_url: str, logo: str, type_: str,enabled:int=1):
conn = get_connection()
if conn is None:
logger.error("Failed to connect to the database.")
return
cursor = conn.cursor()
cursor.execute("""
INSERT INTO providers (name, api_key, base_url, logo, type)
VALUES (?, ?, ?, ?, ?)
""", (name, api_key, base_url, logo, type_))
INSERT INTO providers (id, name, api_key, base_url, logo, type, enabled)
VALUES (?, ?, ?, ?, ?, ?, ?)
""", (id, name, api_key, base_url, logo, type_, enabled))
try:
conn.commit()
cursor_id = cursor.lastrowid
conn.close()
logger.info(f"Provider inserted successfully. name: {name}, type: {type_}")
return cursor_id
logger.info(f"Provider inserted successfully. id: {id}, name: {name}, type: {type_}")
return id
except Exception as e:
logger.error(f"Failed to insert provider: {e}")
return None
def get_enabled_providers():
    """Return all rows of ``providers`` whose ``enabled`` column is 1.

    Returns:
        The list from ``cursor.fetchall()`` (empty when nothing is enabled), or
        ``None`` when the connection could not be obtained or the query failed
        (the failure is logged).

    Only the row count is logged: the rows contain ``api_key`` and must not be
    written to the log.
    """
    conn = get_connection()
    if conn is None:
        logger.error("Failed to connect to the database.")
        return None

    try:
        cursor = conn.cursor()
        cursor.execute("SELECT * FROM providers WHERE enabled = 1")
        rows = cursor.fetchall()
        # fetchall() returns an empty list, never None, when nothing matches.
        if not rows:
            logger.info("No providers found")
        else:
            logger.info(f"Providers found: {len(rows)}")
        return rows
    except Exception as e:
        logger.error(f"Failed to get enabled providers: {e}")
        return None
    finally:
        conn.close()
def get_provider_by_name(name: str):
conn = get_connection()
if conn is None:
@@ -70,6 +141,7 @@ def get_provider_by_id(id: int):
return
cursor = conn.cursor()
cursor.execute("SELECT * FROM providers WHERE id = ?", (id,))
try:
row = cursor.fetchone()
conn.close()
@@ -99,23 +171,40 @@ def get_all_providers():
except Exception as e:
logger.error(f"Failed to get all providers: {e}")
def update_provider(id: int, name: str, api_key: str, base_url: str, logo: str, type_: str):
def update_provider(id: str, **kwargs):
conn = get_connection()
if conn is None:
logger.error("Failed to connect to the database.")
return
cursor = conn.cursor()
cursor.execute("""
UPDATE providers
SET name = ?, api_key = ?, base_url = ?, logo = ?, type = ?
fields = []
values = []
for key, value in kwargs.items():
fields.append(f"{key} = ?")
values.append(value)
if not fields:
logger.warning("No fields provided for update.")
return
sql = f"""
UPDATE providers
SET {', '.join(fields)}
WHERE id = ?
""", (name, api_key, base_url, logo, type_, id))
"""
values.append(id) # id 最后加
cursor = conn.cursor()
try:
cursor.execute(sql, values)
conn.commit()
conn.close()
logger.info(f"Provider updated successfully. id: {id}, type: {type_}")
logger.info(f"Provider updated successfully. id: {id}, updated_fields: {fields}")
except Exception as e:
logger.error(f"Failed to update provider: {e}")
def delete_provider(id: int):
conn = get_connection()
if conn is None:

View File

@@ -0,0 +1,28 @@
import enum
class TaskStatus(str, enum.Enum):
    """Lifecycle states of a note-generation task.

    Members are ``str`` subclasses, so each compares (and hashes) equal to its
    plain string value.
    """

    PENDING = "PENDING"
    PARSING = "PARSING"
    DOWNLOADING = "DOWNLOADING"
    TRANSCRIBING = "TRANSCRIBING"
    SUMMARIZING = "SUMMARIZING"
    FORMATTING = "FORMATTING"
    SAVING = "SAVING"
    SUCCESS = "SUCCESS"
    FAILED = "FAILED"

    @classmethod
    def description(cls, status):
        """Return the display label for ``status``, or the "unknown" label if unrecognised."""
        labels = {}
        labels[cls.PENDING] = "排队中"
        labels[cls.PARSING] = "解析链接"
        labels[cls.DOWNLOADING] = "下载中"
        labels[cls.TRANSCRIBING] = "转录中"
        labels[cls.SUMMARIZING] = "总结中"
        labels[cls.FORMATTING] = "格式化中"
        labels[cls.SAVING] = "保存中"
        labels[cls.SUCCESS] = "完成"
        labels[cls.FAILED] = "失败"

        if status in labels:
            return labels[status]
        return "未知状态"

View File

@@ -9,5 +9,5 @@ from app.models.model_config import ModelConfig
class GPTFactory:
@staticmethod
def from_config(config: ModelConfig) -> GPT:
client = OpenAICompatibleProvider(api_key=config.api_key, base_url=config.base_url).get_client()
client = OpenAICompatibleProvider(api_key=config.api_key, base_url=config.base_url).get_client
return UniversalGPT(client=client, model=config.model_name)

View File

@@ -0,0 +1,100 @@
from app.gpt.prompt import BASE_PROMPT
# Optional output sections; each ``value`` is a key understood by get_format_function.
note_formats = [
    dict(label='目录', value='toc'),
    dict(label='原片跳转', value='link'),
    dict(label='原片截图', value='screenshot'),
    dict(label='AI总结', value='summary'),
]
# Selectable note styles; each ``value`` is a key understood by get_style_format.
note_styles = [
    dict(label='精简', value='minimal'),
    dict(label='详细', value='detailed'),
    dict(label='学术', value='academic'),
    dict(label='教程', value='tutorial'),
    dict(label='小红书', value='xiaohongshu'),
    dict(label='生活向', value='life_journal'),
    dict(label='任务导向', value='task_oriented'),
    dict(label='商业风格', value='business'),
    dict(label='会议纪要', value='meeting_minutes'),
]
# Build the full prompt from BASE_PROMPT plus the user's options
def generate_base_prompt(title, segment_text, tags, _format=None, style=None, extras=None):
    """Assemble the prompt sent to the model.

    The result is BASE_PROMPT filled with ``title``, ``segment_text`` and
    ``tags``, followed (each on a new line, and only when truthy) by:
    the instructions for every entry of ``_format``, the description of
    ``style``, and the free-form ``extras`` text.
    """
    sections = [
        BASE_PROMPT.format(
            video_title=title,
            segment_text=segment_text,
            tags=tags,
        )
    ]

    # Instructions for the output sections the user selected
    if _format:
        sections.append("\n".join(get_format_function(f) for f in _format))

    # Description of the selected note style
    if style:
        sections.append(get_style_format(style))

    # Extra user-provided instructions
    if extras:
        sections.append(f"{extras}")

    return "\n".join(sections)
# Resolve a format key to its instruction text
def get_format_function(format_type):
    """Return the prompt instructions for ``format_type``, or ``''`` for an unknown key."""
    builders = {
        'toc': get_toc_format,
        'link': get_link_format,
        'screenshot': get_screenshot_format,
        'summary': get_summary_format,
    }
    builder = builders.get(format_type)
    if builder is None:
        return ''
    return builder()
# Resolve a style key to its description
def get_style_format(style):
    """Return the prompt line describing ``style``, or ``''`` for an unknown key."""
    descriptions = {
        'minimal': '1. **精简信息**: 仅记录最重要的内容,简洁明了。',
        'detailed': '2. **详细记录**: 包含完整的时间戳和每个部分的详细讨论。',
        'academic': '3. **学术风格**: 适合学术报告,正式且结构化。',
        'xiaohongshu': '4. **小红书风格**: 适合社交平台分享,亲切、口语化。',
        'life_journal': '5. **生活向**: 记录个人生活感悟,情感化表达。',
        'task_oriented': '6. **任务导向**: 强调任务、目标,适合工作和待办事项。',
        'business': '7. **商业风格**: 适合商业报告、会议纪要,正式且精准。',
        'meeting_minutes': '8. **会议纪要**: 适合商业报告、会议纪要,正式且精准。',
        "tutorial": "9.**教程笔记**:尽可能详细的记录教程,特别是关键点和一些重要的结论步骤",
    }
    try:
        return descriptions[style]
    except KeyError:
        return ''
# Instruction texts for the individual output sections
def get_toc_format():
    """Return the prompt instruction for the table-of-contents section."""
    instruction = '''
9. **目录**: 自动生成一个基于 `##` 级标题的目录。不需要插入原片跳转
'''
    return instruction
def get_link_format():
    """Return the prompt instruction for the jump-to-original timestamp markers."""
    instruction = '''
10. **原片跳转**: 为每个主要章节添加时间戳,使用格式 `*Content-[mm:ss]`。
重要:**始终**在章节标题前加上 `*Content` 前缀,例如:`AI 的发展史 *Content-[01:23]`。一定是标题在前 插入标记在后
'''
    return instruction
def get_screenshot_format():
    """Return the prompt instruction for the screenshot markers."""
    instruction = '''
11. **原片截图**: 如果某个部分涉及**视觉演示**或任何能帮助理解的内容,插入截图提示:
- 格式:`*Screenshot-[mm:ss]`
至少插入 1-3张截图
'''
    return instruction
def get_summary_format():
    """Return the prompt instruction for the closing AI-summary section."""
    instruction = '''
12. **AI总结**: 在笔记末尾加入简短的AI生成总结,并且二级标题 就是 AI 总结 例如 ## AI 总结。
'''
    return instruction

View File

@@ -1,4 +1,5 @@
from app.gpt.base import GPT
from app.gpt.prompt_builder import generate_base_prompt
from app.models.gpt_model import GPTSource
from app.gpt.prompt import BASE_PROMPT, AI_SUM, SCREENSHOT, LINK
from app.gpt.utils import fix_markdown
@@ -28,29 +29,35 @@ class UniversalGPT(GPT):
return [TranscriptSegment(**seg) if isinstance(seg, dict) else seg for seg in segments]
def create_messages(self, segments: List[TranscriptSegment],**kwargs):
content = BASE_PROMPT.format(
video_title=kwargs.get('title'),
print("UniversalGPT",kwargs)
content =generate_base_prompt(
title=kwargs.get('title'),
segment_text=self._build_segment_text(segments),
tags=kwargs.get('tags')
tags=kwargs.get('tags'),
_format=kwargs.get('_format'),
style=kwargs.get('style'),
extras=kwargs.get('extras')
)
if self.screenshot:
print(":需要截图")
content += SCREENSHOT
if self.link:
print(":需要链接")
content += LINK
print(content)
return [{"role": "user", "content": content + AI_SUM}]
return [{"role": "user", "content": content }]
def list_models(self):
return self.client.list_models()
return self.client.models.list()
def summarize(self, source: GPTSource) -> str:
self.screenshot = source.screenshot
self.link = source.link
source.segment = self.ensure_segments_type(source.segment)
messages = self.create_messages(source.segment, source.title,source.tags)
response = self.client.chat(
messages = self.create_messages(
source.segment,
title=source.title,
tags=source.tags
,
_format=source._format,
style=source.style,
extras=source.extras
)
response = self.client.chat.completions.create(
model=self.model,
messages=messages,
temperature=0.7

View File

@@ -0,0 +1,36 @@
from fastapi import APIRouter
from pydantic import BaseModel
from app.services.model import ModelService
from app.utils.response import ResponseWrapper as R
router = APIRouter()
modelService = ModelService()
class CreateModelRequest(BaseModel):
    """Request body of POST /models: the owning provider and the model name to register."""

    provider_id: str
    model_name: str
# Response body: model information
class ModelItem(BaseModel):
    """A model as exposed to clients: database id and model name."""

    id: int
    model_name: str
@router.get("/model_list")
def model_list():
    """Return the model list wrapped in the standard success response.

    NOTE(review): the ``True`` passed to ``get_all_models`` presumably restricts
    the result to enabled providers — confirm against ModelService.

    Failures are reported through ``R.error`` with the exception's message
    instead of the exception object itself.
    """
    try:
        models = modelService.get_all_models(True)
        return R.success(models, msg="获取模型列表成功")
    except Exception as e:
        return R.error(str(e))
@router.get("/model_list/{provider_id}")
def model_list(provider_id):
    """Return the models of ``provider_id`` wrapped in the standard success response.

    The function name repeats the one used for GET /model_list; both routes stay
    registered because the decorator captures each function when it is defined.

    Failures are reported through ``R.error`` with the exception's message
    instead of the exception object itself.
    """
    try:
        return R.success(modelService.get_all_models_by_id(provider_id))
    except Exception as e:
        return R.error(str(e))
@router.post("/models")
def create_model(data: CreateModelRequest):
    """Register a new model for a provider.

    Returns the standard error response when the service reports failure.
    The original code *raised* the value of ``R.error(...)``; everywhere else
    in this module that value is returned as a response, and raising something
    that is not an exception fails with ``TypeError``.

    The call goes through the module-level ``modelService`` instance, as the
    other routes do; this works whether ``add_new_model`` is an instance,
    class or static method.
    """
    success = modelService.add_new_model(data.provider_id, data.model_name)
    if not success:
        return R.error("模型添加失败")
    return R.success(msg="模型添加成功")

View File

@@ -10,13 +10,14 @@ from dataclasses import asdict
from app.db.video_task_dao import get_task_by_video
from app.enmus.note_enums import DownloadQuality
from app.services.note import NoteGenerator
from app.services.note import NoteGenerator, logger
from app.utils.response import ResponseWrapper as R
from app.utils.url_parser import extract_video_id
from app.validators.video_url_validator import is_supported_video_url
from fastapi import APIRouter, Request, HTTPException
from fastapi.responses import StreamingResponse
import httpx
from app.enmus.task_status_enums import TaskStatus
# from app.services.downloader import download_raw_audio
# from app.services.whisperer import transcribe_audio
@@ -35,6 +36,12 @@ class VideoRequest(BaseModel):
quality: DownloadQuality
screenshot: Optional[bool] = False
link: Optional[bool] = False
model_name:str
provider_id:str
task_id: Optional[str] = None
format:Optional[list]=[]
style:str=None
extras:Optional[str]
@validator("video_url")
def validate_supported_url(cls, v):
@@ -54,14 +61,24 @@ def save_note_to_file(task_id: str, note):
json.dump(asdict(note), f, ensure_ascii=False, indent=2)
def run_note_task(task_id: str, video_url: str, platform: str, quality: DownloadQuality, link: bool = False,screenshot: bool = False):
def run_note_task(task_id: str, video_url: str, platform: str, quality: DownloadQuality,
link: bool = False,screenshot: bool = False,model_name:str=None,provider_id:str=None,
_format:list=None,style:str=None,extras:str=None):
try:
if not model_name or not provider_id:
raise HTTPException(status_code=400, detail="请选择模型和提供者")
note = NoteGenerator().generate(
video_url=video_url,
platform=platform,
quality=quality,
task_id=task_id,
model_name=model_name,
provider_id=provider_id,
link=link,
_format=_format,
style=style,
extras=extras,
screenshot=screenshot
)
print('Note 结果',note)
@@ -85,38 +102,91 @@ def generate_note(data: VideoRequest, background_tasks: BackgroundTasks):
try:
video_id = extract_video_id(data.video_url, data.platform)
if not video_id:
raise HTTPException(status_code=400, detail="无法提取视频 ID")
existing = get_task_by_video(video_id, data.platform)
if existing:
return R.error(
msg='笔记已生成,请勿重复发起',
# if not video_id:
# raise HTTPException(status_code=400, detail="无法提取视频 ID")
# existing = get_task_by_video(video_id, data.platform)
# if existing:
# return R.error(
# msg='笔记已生成,请勿重复发起',
#
# )
)
task_id = str(uuid.uuid4())
if data.task_id:
# 如果传了task_id说明是重试
task_id = data.task_id
# 更新之前的状态
NoteGenerator.update_task_status(task_id, TaskStatus.PENDING)
logger.info(f"重试模式,复用已有 task_id={task_id}")
else:
# 正常新建任务
task_id = str(uuid.uuid4())
background_tasks.add_task(run_note_task, task_id, data.video_url, data.platform, data.quality,data.link ,data.screenshot)
background_tasks.add_task(run_note_task, task_id, data.video_url, data.platform, data.quality,data.link ,data.screenshot,data.model_name,data.provider_id,data.format,data.style,data.extras)
return R.success({"task_id": task_id})
except Exception as e:
raise HTTPException(status_code=500, detail=str(e))
@router.get("/task_status/{task_id}")
def get_task_status(task_id: str):
path = os.path.join(NOTE_OUTPUT_DIR, f"{task_id}.json")
if not os.path.exists(path):
return R.success({"status": "PENDING"})
status_path = os.path.join(NOTE_OUTPUT_DIR, f"{task_id}.status.json")
result_path = os.path.join(NOTE_OUTPUT_DIR, f"{task_id}.json")
with open(path, "r", encoding="utf-8") as f:
content = json.load(f)
# 优先读状态文件
if os.path.exists(status_path):
with open(status_path, "r", encoding="utf-8") as f:
status_content = json.load(f)
if "error" in content:
return R.error(content["error"], code=500)
content['id'] = task_id
status = status_content.get("status")
message = status_content.get("message", "")
if status == TaskStatus.SUCCESS.value:
# 成功状态的话,继续读取最终笔记内容
if os.path.exists(result_path):
with open(result_path, "r", encoding="utf-8") as rf:
result_content = json.load(rf)
return R.success({
"status": status,
"result": result_content,
"message": message,
"task_id": task_id
})
else:
# 理论上不会出现,保险处理
return R.success({
"status": TaskStatus.PENDING.value,
"message": "任务完成,但结果文件未找到",
"task_id": task_id
})
if status == TaskStatus.FAILED.value:
return R.error(message or "任务失败", code=500)
# 处理中状态
return R.success({
"status": status,
"message": message,
"task_id": task_id
})
# 没有状态文件,但有结果
if os.path.exists(result_path):
with open(result_path, "r", encoding="utf-8") as f:
result_content = json.load(f)
return R.success({
"status": TaskStatus.SUCCESS.value,
"result": result_content,
"task_id": task_id
})
# 什么都没有默认PENDING
return R.success({
"status": "SUCCESS",
"result": content
"status": TaskStatus.PENDING.value,
"message": "任务排队中",
"task_id": task_id
})

View File

@@ -1,6 +1,9 @@
from typing import Optional
from fastapi import APIRouter
from pydantic import BaseModel
from app.models.model_config import ModelConfig
from app.services.model import ModelService
from app.utils.response import ResponseWrapper as R
from app.services.provider import ProviderService
@@ -11,16 +14,21 @@ class ProviderRequest(BaseModel):
name: str
api_key: str
base_url: str
logo: str
logo: Optional[str] = None
type: str
class TestRequest(BaseModel):
    """Request body for the provider connectivity test endpoint."""

    # Credential used for the test call.
    api_key: str
    # Base URL of the provider API to probe.
    base_url: str
class ProviderUpdateRequest(BaseModel):
id: int
id: str
name: Optional[str] = None
api_key: Optional[str] = None
base_url: Optional[str] = None
logo: Optional[str] = None
type: Optional[str] = None
enabled:Optional[int] = None
@router.post("/add_provider")
def add_provider(data: ProviderRequest):
@@ -45,7 +53,7 @@ def get_all_providers():
return R.error(msg=e)
@router.get("/get_provider_by_id/{id}")
def get_provider_by_id(id: int):
def get_provider_by_id(id: str):
try:
res = ProviderService.get_provider_by_id(id)
return R.success(data=res)
@@ -60,23 +68,33 @@ def get_provider_by_name(name: str):
except Exception as e:
return R.error(msg=e)
@router.post("/update_provider/")
@router.post("/update_provider")
def update_provider(data: ProviderUpdateRequest):
try:
if all(
field is None
for field in [data.name, data.api_key, data.base_url, data.logo, data.type]
for field in [data.name, data.api_key, data.base_url, data.logo, data.type,data.enabled]
):
return R.error(msg='请至少填写一个参数')
ProviderService.update_provider(
id=data.id,
name=data.name or '',
api_key=data.api_key or '',
base_url=data.base_url or '',
logo=data.logo or '',
type_=data.type or ''
data=dict(data)
)
return R.success(msg='更新模型供应商成功')
except Exception as e:
print(e)
return R.error(msg=e)
@router.post('/connect_test')
def gpt_connect_test(data:TestRequest):
    """Check whether a provider can be reached with the supplied credentials.

    :param data: request body holding ``api_key`` and ``base_url``.
    :return: success envelope when ``ModelService.connect_test`` reports a
        truthy result, error envelope otherwise or when an exception is raised.
    """
    try:
        reachable = ModelService().connect_test(data.api_key, data.base_url)
        if reachable:
            return R.success(msg='连接成功')
        return R.error(msg='连接失败')
    except Exception as exc:
        print(exc)
        return R.error(msg=exc)

View File

@@ -1,23 +1,109 @@
from app.db.model_dao import insert_model, get_all_models
from app.db.provider_dao import get_enabled_providers
from app.gpt.gpt_factory import GPTFactory
from app.gpt.provider.OpenAI_compatible_provider import OpenAICompatibleProvider
from app.models.model_config import ModelConfig
from app.services.provider import ProviderService
class ModelService:
@staticmethod
def get_model_list(provider_id: int):
provider=ProviderService.get_provider_by_id(provider_id)
def _build_model_config(provider: dict) -> ModelConfig:
    """Build a ``ModelConfig`` from a serialized provider record.

    The provider's ``name`` is used for both the ``provider`` and ``name``
    fields, and ``model_name`` is left empty.

    :param provider: mapping with at least ``api_key``, ``base_url`` and ``name``.
    :raises KeyError: if one of those keys is missing.
    """
    api_key = provider["api_key"]
    base_url = provider["base_url"]
    display_name = provider["name"]
    return ModelConfig(
        api_key=api_key,
        base_url=base_url,
        provider=display_name,
        model_name='',
        name=display_name,
    )
@staticmethod
def get_model_list(provider_id: int, verbose: bool = False):
provider = ProviderService.get_provider_by_id(provider_id)
if not provider:
return []
config=ModelConfig(
api_key=provider.api_key,
base_url=provider.base_url,
provider=provider.name,
model_name='',
name=provider.name,
)
GPT=GPTFactory().from_config(config)
return GPT.list_models()
try:
config = ModelService._build_model_config(provider)
gpt = GPTFactory().from_config(config)
models = gpt.list_models()
if verbose:
print(f"[{provider['name']}] 模型列表: {models}")
return models
except Exception as e:
print(f"[{provider['name']}] 获取模型失败: {e}")
return []
@staticmethod
def get_all_models(verbose: bool = False):
    """Load every stored model and return it in the formatted shape.

    :param verbose: when true, print the raw rows before formatting.
    :return: list produced by ``_format_models``; an empty list if loading or
        formatting raises.
    """
    try:
        rows = get_all_models()
        if verbose:
            print(f"所有模型列表: {rows}")
        formatted = ModelService._format_models(rows)
    except Exception as exc:
        print(f"获取所有模型失败: {exc}")
        return []
    return formatted
@staticmethod
def _format_models(raw_models: list) -> list:
"""
格式化模型列表
"""
formatted = []
for model in raw_models:
formatted.append({
"id": model.get("id"),
"provider_id": model.get("provider_id"),
"model_name": model.get("model_name"),
"created_at": model.get("created_at", None), # 如果有created_at字段
})
return formatted
@staticmethod
def get_all_models_by_id(provider_id: str, verbose: bool = False):
    """List the models available from one provider.

    :param provider_id: identifier used to look the provider up.
    :param verbose: forwarded to ``get_model_list``.
    :return: ``{"models": [...]}`` on success. NOTE: on any exception
        (including a provider that was not found) an empty *list* is returned
        instead, so the two paths have different shapes.
    """
    try:
        provider = ProviderService.get_provider_by_id(provider_id)
        return {
            "models": ModelService.get_model_list(provider["id"], verbose=verbose)
        }
    except Exception as exc:
        print(f"[{provider_id}] 获取模型失败: {exc}")
        return []
@staticmethod
def connect_test(api_key: str, base_url: str) -> bool:
    """Probe a provider endpoint with the given credentials.

    :return: the result of ``OpenAICompatibleProvider.test_connection``;
        ``False`` if that call raises.
    """
    try:
        outcome = OpenAICompatibleProvider.test_connection(api_key=api_key, base_url=base_url)
    except Exception as exc:
        print(f"连接测试失败:{exc}")
        return False
    return outcome
@staticmethod
def add_new_model(provider_id: int, model_name: str) -> bool:
    """Store a new model for an existing provider.

    :param provider_id: identifier of the provider that owns the model.
    :param model_name: name of the model to insert.
    :return: ``True`` when the row was inserted; ``False`` when the provider
        does not exist or any step raises.
    """
    try:
        # The provider must exist before a model can be attached to it.
        owner = ProviderService.get_provider_by_id(provider_id)
        if owner:
            insert_model(provider_id=provider_id, model_name=model_name)
            print(f"模型 {model_name} 已成功添加到供应商ID {provider_id}")
            return True
        print(f"供应商ID {provider_id} 不存在,无法添加模型")
        return False
    except Exception as exc:
        print(f"添加模型失败: {exc}")
        return False
if __name__ == '__main__':
print(ModelService.get_model_list(1))
# 单个 Provider 测试
print(ModelService.get_model_list(1, verbose=True))
# 所有 Provider 模型测试
# print(ModelService.get_all_models(verbose=True))

View File

@@ -1,5 +1,9 @@
import json
from dataclasses import asdict
from app.enmus.task_status_enums import TaskStatus
import os
from typing import Union
from typing import Union, Optional
from pydantic import HttpUrl
@@ -10,13 +14,17 @@ from app.downloaders.douyin_downloader import DouyinDownloader
from app.downloaders.youtube_downloader import YoutubeDownloader
from app.gpt.base import GPT
from app.gpt.deepseek_gpt import DeepSeekGPT
from app.gpt.gpt_factory import GPTFactory
from app.gpt.openai_gpt import OpenaiGPT
from app.gpt.qwen_gpt import QwenGPT
from app.models.gpt_model import GPTSource
from app.models.model_config import ModelConfig
from app.models.notes_model import NoteResult
from app.models.notes_model import AudioDownloadResult
from app.enmus.note_enums import DownloadQuality
from app.models.transcriber_model import TranscriptResult
from app.models.transcriber_model import TranscriptResult, TranscriptSegment
from app.services.provider import ProviderService
from app.transcriber.base import Transcriber
from app.transcriber.transcriber_provider import get_transcriber,_transcribers
from app.transcriber.whisper import WhisperTranscriber
@@ -29,6 +37,8 @@ from app.utils.video_helper import generate_screenshot
# from app.services.gpt import summarize_text
from dotenv import load_dotenv
from app.utils.logger import get_logger
from events import transcription_finished
logger = get_logger(__name__)
load_dotenv()
BACKEND_BASE_URL = os.getenv("API_BASE_URL", "http://localhost:8000")
@@ -37,7 +47,7 @@ output_dir = os.getenv('OUT_DIR')
image_base_url = os.getenv('IMAGE_BASE_URL')
logger.info("starting up")
NOTE_OUTPUT_DIR = "note_results"
class NoteGenerator:
def __init__(self):
@@ -45,26 +55,39 @@ class NoteGenerator:
self.device: Union[str, None] = None
self.transcriber_type = os.getenv('TRANSCRIBER_TYPE','fast-whisper')
self.transcriber = self.get_transcriber()
# TODO 需要更换为可调节
self.provider = os.getenv('MODEl_PROVIDER','openai')
self.video_path = None
logger.info("初始化NoteGenerator")
import logging
def get_gpt(self) -> GPT:
if self.provider == 'openai':
logger.info("使用OpenAI")
return OpenaiGPT()
elif self.provider == 'deepSeek':
logger.info("使用DeepSeek")
return DeepSeekGPT()
elif self.provider == 'qwen':
logger.info("使用Qwen")
return QwenGPT()
else:
logger.warning("不支持的AI提供商")
raise ValueError(f"不支持的AI提供商{self.provider}")
logger = logging.getLogger(__name__)
@staticmethod
def update_task_status(task_id: str, status: Union[str, TaskStatus], message: Optional[str] = None):
    """Write the task's current status to ``<task_id>.status.json``.

    :param task_id: task identifier used to name the status file.
    :param status: a ``TaskStatus`` member (its ``.value`` is stored) or a
        plain string stored as-is.
    :param message: optional detail text; only written when truthy.
    """
    os.makedirs(NOTE_OUTPUT_DIR, exist_ok=True)

    status_value = status.value if isinstance(status, TaskStatus) else status
    payload = {"status": status_value}
    if message:
        payload["message"] = message

    status_file = os.path.join(NOTE_OUTPUT_DIR, f"{task_id}.status.json")
    with open(status_file, "w", encoding="utf-8") as fh:
        json.dump(payload, fh, ensure_ascii=False, indent=2)
def get_gpt(self, model_name: str = None, provider_id: str = None) -> GPT:
    """Build the GPT client for the requested provider and model.

    :param model_name: model to use with the provider.
    :param provider_id: identifier of the provider record to load.
    :return: the object produced by ``GPTFactory().from_config``.
    :raises ValueError: if no provider matches ``provider_id``.
    """
    provider = ProviderService.get_provider_by_id(provider_id)
    if not provider:
        logger.error(f"[get_gpt] 未找到对应的模型供应商: provider_id={provider_id}")
        raise ValueError(f"未找到对应的模型供应商: provider_id={provider_id}")

    factory = GPTFactory()
    config = ModelConfig(
        api_key=provider.get('api_key'),
        base_url=provider.get('base_url'),
        model_name=model_name,
        provider=provider.get('type'),
        name=provider.get('name'),
    )
    return factory.from_config(config)
def get_downloader(self, platform: str) -> Downloader:
if platform == "bilibili":
@@ -98,7 +121,7 @@ class NoteGenerator:
insert_video_task(video_id=video_id, platform=platform, task_id=task_id)
def insert_screenshots_into_markdown(self, markdown: str, video_path: str, image_base_url: str,
output_dir: str) -> str:
output_dir: str,_format:list) -> str:
"""
扫描 markdown 中的 *Screenshot-xx:xx生成截图并插入 markdown 图片
:param markdown:
@@ -145,62 +168,143 @@ class NoteGenerator:
def generate(
self,
video_url: Union[str, HttpUrl],
platform: str,
quality: DownloadQuality = DownloadQuality.medium,
task_id: Union[str, None] = None,
model_name: str = None,
provider_id: str = None,
link: bool = False,
screenshot: bool = False,
_format: list = None,
style: str = None,
extras: str = None,
path: Union[str, None] = None
) -> NoteResult:
logger.info(f"开始解析并生成笔记")
# 1. 选择下载器
downloader = self.get_downloader(platform)
gpt = self.get_gpt()
logger.info(f'使用{downloader.__class__.__name__}下载器\n'
f'使用{gpt.__class__.__name__}GPT\n'
f'视频地址:{video_url}')
if screenshot:
try:
logger.info(f"🎯 开始解析并生成笔记task_id={task_id}")
self.update_task_status(task_id, TaskStatus.PARSING)
_path=''
downloader = self.get_downloader(platform)
gpt = self.get_gpt(model_name=model_name, provider_id=provider_id)
video_path = downloader.download_video(video_url)
self.video_path = video_path
print(video_path)
audio_cache_path = os.path.join(NOTE_OUTPUT_DIR, f"{task_id}_audio.json")
transcript_cache_path = os.path.join(NOTE_OUTPUT_DIR, f"{task_id}_transcript.json")
markdown_cache_path = os.path.join(NOTE_OUTPUT_DIR, f"{task_id}_markdown.md")
# 2. 下载音频
audio: AudioDownloadResult = downloader.download(
video_url=video_url,
quality=quality,
output_dir=path,
need_video=screenshot
# -------- 1. 下载音频 --------
try:
self.update_task_status(task_id, TaskStatus.DOWNLOADING)
if os.path.exists(audio_cache_path):
logger.info(f"检测到已有音频缓存直接读取task_id={task_id}")
with open(audio_cache_path, "r", encoding="utf-8") as f:
audio_data = json.load(f)
audio = AudioDownloadResult(**audio_data)
else:
if 'screenshot' in _format:
video_path = downloader.download_video(video_url)
self.video_path = video_path
logger.info(f"成功下载视频文件: {video_path}")
screenshot= 'screenshot' in _format
audio: AudioDownloadResult = downloader.download(
video_url=video_url,
quality=quality,
output_dir=path,
need_video=screenshot
)
_path=audio.raw_info.get('path')
with open(audio_cache_path, "w", encoding="utf-8") as f:
json.dump(audio.__dict__, f, ensure_ascii=False, indent=2)
logger.info(f"音频下载并缓存成功task_id={task_id}")
except Exception as e:
logger.error(f"❌ 下载音频失败task_id={task_id},错误信息:{e}")
self.update_task_status(task_id, TaskStatus.FAILED, message=f"下载音频失败:{e}")
raise e
# -------- 2. 转写文字 --------
try:
self.update_task_status(task_id, TaskStatus.TRANSCRIBING)
if os.path.exists(transcript_cache_path):
logger.info(f"检测到已有转写缓存直接读取task_id={task_id}")
with open(transcript_cache_path, "r", encoding="utf-8") as f:
transcript_data = json.load(f)
transcript = TranscriptResult(
language=transcript_data["language"],
full_text=transcript_data["full_text"],
segments=[TranscriptSegment(**seg) for seg in transcript_data["segments"]]
)
else:
transcript: TranscriptResult = self.transcriber.transcript(file_path=audio.file_path)
with open(transcript_cache_path, "w", encoding="utf-8") as f:
json.dump(asdict(transcript), f, ensure_ascii=False, indent=2)
logger.info(f"文字转写并缓存成功task_id={task_id}")
except Exception as e:
logger.error(f"❌ 转写文字失败task_id={task_id},错误信息:{e}")
self.update_task_status(task_id, TaskStatus.FAILED, message=f"转写文字失败:{e}")
raise e
# -------- 3. 总结内容 --------
try:
self.update_task_status(task_id, TaskStatus.SUMMARIZING)
if os.path.exists(markdown_cache_path):
logger.info(f"检测到已有总结缓存直接读取task_id={task_id}")
with open(markdown_cache_path, "r", encoding="utf-8") as f:
markdown = f.read()
else:
source = GPTSource(
title=audio.title,
segment=transcript.segments,
tags=audio.raw_info.get('tags'),
screenshot=screenshot,
link=link,
_format=_format,
style=style,
extras=extras
)
markdown: str = gpt.summarize(source)
with open(markdown_cache_path, "w", encoding="utf-8") as f:
f.write(markdown)
logger.info(f"GPT总结并缓存成功task_id={task_id}")
except Exception as e:
logger.error(f"❌ 总结内容失败task_id={task_id},错误信息:{e}")
self.update_task_status(task_id, TaskStatus.FAILED, message=f"总结内容失败:{e}")
raise e
# -------- 4. 插入截图 --------
if _format and 'screenshot' in _format:
try:
markdown = self.insert_screenshots_into_markdown(markdown, self.video_path, image_base_url, output_dir,_format)
except Exception as e:
logger.warning(f"⚠️ 插入截图失败跳过处理task_id={task_id},错误信息:{e}")
if _format and 'link' in _format:
try:
markdown = replace_content_markers(markdown, video_id=audio.video_id,platform=platform)
except Exception as e:
logger.warning(f"⚠️ 插入链接失败跳过处理task_id={task_id},错误信息:{e}")
# 注意:截图失败不终止整体流程
# -------- 5. 保存数据库记录 --------
self.update_task_status(task_id, TaskStatus.SAVING)
self.save_meta(video_id=audio.video_id, platform=platform, task_id=task_id)
# -------- 6. 完成 --------
self.update_task_status(task_id, TaskStatus.SUCCESS)
logger.info(f"✅ 笔记生成成功task_id={task_id}")
transcription_finished.send({
"file_path": audio.file_path,
})
return NoteResult(
markdown=markdown,
transcript=transcript,
audio_meta=audio
)
except Exception as e:
logger.error(f"❌ 笔记生成流程异常终止task_id={task_id},错误信息:{e}")
self.update_task_status(task_id, TaskStatus.FAILED, message=str(e))
raise f'❌ 笔记生成流程异常终止task_id={task_id},错误信息:{e}'
)
logger.info(f"下载音频成功,文件路径:{audio.file_path}")
# 3. Whisper 转写
transcript: TranscriptResult = self.transcriber.transcript(file_path=audio.file_path)
logger.info(f"Whisper 转写成功,转写结果:{transcript.full_text}")
# 4. GPT 总结
source = GPTSource(
title=audio.title,
segment=transcript.segments,
tags=audio.raw_info.get('tags'),
screenshot=screenshot,
link=link
)
logger.info(f"GPT 总结完成,总结结果:{source}")
markdown: str = gpt.summarize(source)
print("markdown结果", markdown)
markdown = replace_content_markers(markdown=markdown, video_id=audio.video_id, platform=platform)
if self.video_path:
markdown = self.insert_screenshots_into_markdown(markdown, self.video_path, image_base_url, output_dir)
self.save_meta(video_id=audio.video_id, platform=platform, task_id=task_id)
# 5. 返回结构体
return NoteResult(
markdown=markdown,
transcript=transcript,
audio_meta=audio
)

View File

@@ -1,3 +1,5 @@
from kombu import uuid
from app.db.provider_dao import (
insert_provider,
init_provider_table,
@@ -5,50 +7,65 @@ from app.db.provider_dao import (
get_provider_by_name,
get_provider_by_id,
update_provider,
delete_provider,
delete_provider, get_enabled_providers,
)
from app.gpt.gpt_factory import GPTFactory
from app.models.model_config import ModelConfig
class ProviderService:
@staticmethod
def serialize_provider(row: tuple) -> dict:
if not row:
return None
return {
"id": row[0],
"name": row[1],
"logo": row[2],
"type": row[3],
"api_key": row[4],
"base_url": row[5],
"enabled": row[6],
"created_at": row[7],
}
@staticmethod
def add_provider(name: str, api_key: str, base_url: str, logo: str, type_: str):
return insert_provider(name, api_key, base_url, logo, type_)
def add_provider( name: str, api_key: str, base_url: str, logo: str, type_: str, enabled: int = 1):
    """Insert a new provider under a freshly generated lower-cased id.

    NOTE: the ``logo`` argument is ignored; ``'custom'`` is always stored.

    :return: whatever ``insert_provider`` returns, or ``None`` when an
        exception is raised (the error is only printed).
    """
    try:
        provider_id = uuid().lower()
        return insert_provider(provider_id, name, api_key, base_url, 'custom', type_, enabled)
    except Exception as exc:
        print('创建模式失败', exc)
@staticmethod
def get_all_providers():
provider_list = []
provider = get_all_providers()
for i in provider:
provider_list.append({
"id": i[0],
"name": i[1],
"logo": i[2],
"type": i[3], # ✅ 加上类型
"api_key": i[4],
"base_url": i[5],
})
return provider_list
rows = get_all_providers()
return [ProviderService.serialize_provider(row) for row in rows] if rows else []
@staticmethod
def get_provider_by_name(name: str):
return get_provider_by_name(name)
row = get_provider_by_name(name)
return ProviderService.serialize_provider(row)
@staticmethod
def get_provider_by_id(id: int):
return get_provider_by_id(id)
def get_provider_by_id(id: str): # 已改为 str 类型
row = get_provider_by_id(id)
return ProviderService.serialize_provider(row)
# all_models.extend(provider['models'])
@staticmethod
def update_provider(
id: int,
name: str,
api_key: str,
base_url: str,
logo: str,
type_: str
):
return update_provider(id, name, api_key, base_url, logo, type_)
def update_provider(id: str, data: dict):
    """Apply a partial update to a provider.

    :param id: identifier of the provider to update.
    :param data: candidate field values; entries whose value is ``None`` and
        the ``id`` key itself are dropped before the update.
    :return: whatever the DAO-level ``update_provider`` returns, or ``None``
        when an exception is raised (the error is only printed).
    """
    try:
        # Drop empty values.
        changes = {}
        for field, value in data.items():
            if field == 'id' or value is None:
                continue
            changes[field] = value
        print('更新模型供应商', changes)
        return update_provider(id, **changes)
    except Exception as exc:
        print('更新模型供应商失败:', exc)
@staticmethod
def delete_provider(id: int):
return delete_provider(id)
def delete_provider(id: str):
return delete_provider(id)

View File

@@ -0,0 +1,251 @@
import json
import logging
import time
from typing import Optional, List, Dict, Union
import requests
from app.decorators.timeit import timeit
from app.models.transcriber_model import TranscriptSegment, TranscriptResult
from app.transcriber.base import Transcriber
from app.utils.logger import get_logger
from events import transcription_finished
__version__ = "0.0.3"
API_BASE_URL = "https://member.bilibili.com/x/bcut/rubick-interface"
# 申请上传
API_REQ_UPLOAD = API_BASE_URL + "/resource/create"
# 提交上传
API_COMMIT_UPLOAD = API_BASE_URL + "/resource/create/complete"
# 创建任务
API_CREATE_TASK = API_BASE_URL + "/task"
# 查询结果
API_QUERY_RESULT = API_BASE_URL + "/task/result"
logger = get_logger(__name__)
class BcutTranscriber(Transcriber):
    """Speech recognition backend built on the Bcut (必剪) ASR web endpoints.

    One transcription runs in three stages: upload the audio in parts and
    commit it, create a recognition task for the uploaded resource, then poll
    the task until it reports a final state.
    """

    headers = {
        'User-Agent': 'Bilibili/1.0.0 (https://www.bilibili.com)',
        'Content-Type': 'application/json'
    }

    def __init__(self):
        self.session = requests.Session()
        self.task_id: Optional[str] = None
        # Per-upload state, filled in by _upload().
        self.__in_boss_key: Optional[str] = None
        self.__resource_id: Optional[str] = None
        self.__upload_id: Optional[str] = None
        self.__upload_urls: List[str] = []
        self.__per_size: Optional[int] = None
        self.__clips: Optional[int] = None
        self.__etags: List[str] = []
        self.__download_url: Optional[str] = None

    def _load_file(self, file_path: str) -> bytes:
        """Read the whole file into memory and return its bytes."""
        with open(file_path, 'rb') as f:
            return f.read()

    def _upload(self, file_path: str) -> None:
        """Request an upload slot, then upload and commit the file.

        :raises ValueError: if the file is empty.
        """
        file_binary = self._load_file(file_path)
        if not file_binary:
            raise ValueError("无法读取文件数据")

        # Start each upload with a clean ETag list. Without this, a reused
        # instance would commit the previous file's ETags along with the new ones.
        self.__etags = []

        payload = json.dumps({
            "type": 2,
            "name": "audio.mp3",
            "size": len(file_binary),
            "ResourceFileType": "mp3",
            "model_id": "8",
        })

        resp = self.session.post(
            API_REQ_UPLOAD,
            data=payload,
            headers=self.headers
        )
        resp.raise_for_status()
        resp_data = resp.json()["data"]

        self.__in_boss_key = resp_data["in_boss_key"]
        self.__resource_id = resp_data["resource_id"]
        self.__upload_id = resp_data["upload_id"]
        self.__upload_urls = resp_data["upload_urls"]
        self.__per_size = resp_data["per_size"]
        # One part is uploaded per URL handed back by the server.
        self.__clips = len(resp_data["upload_urls"])

        logger.info(
            f"申请上传成功, 总计大小{resp_data['size'] // 1024}KB, {self.__clips}分片, 分片大小{resp_data['per_size'] // 1024}KB: {self.__in_boss_key}"
        )
        self.__upload_part(file_binary)
        self.__commit_upload()

    def __upload_part(self, file_binary: bytes) -> None:
        """Upload the audio bytes part by part, collecting each part's ETag."""
        for clip in range(self.__clips):
            start_range = clip * self.__per_size
            end_range = min((clip + 1) * self.__per_size, len(file_binary))
            logger.info(f"开始上传分片{clip}: {start_range}-{end_range}")
            resp = self.session.put(
                self.__upload_urls[clip],
                data=file_binary[start_range:end_range],
                headers={'Content-Type': 'application/octet-stream'}
            )
            resp.raise_for_status()
            # The header value is wrapped in double quotes; strip them.
            etag = resp.headers.get("Etag", "").strip('"')
            self.__etags.append(etag)
            logger.info(f"分片{clip}上传成功: {etag}")

    def __commit_upload(self) -> None:
        """Commit the uploaded parts and record the resulting download URL.

        :raises Exception: if the response ``code`` is not 0.
        """
        data = json.dumps({
            "InBossKey": self.__in_boss_key,
            "ResourceId": self.__resource_id,
            "Etags": ",".join(self.__etags),
            "UploadId": self.__upload_id,
            "model_id": "8",
        })
        resp = self.session.post(
            API_COMMIT_UPLOAD,
            data=data,
            headers=self.headers
        )
        resp.raise_for_status()
        resp = resp.json()
        # Logged at debug level through the module logger instead of printed to stdout.
        logger.debug(f"Bili {resp}")
        if resp.get("code") != 0:
            error_msg = f"上传提交失败: {resp.get('message', '未知错误')}"
            logger.error(error_msg)
            raise Exception(error_msg)

        self.__download_url = resp["data"]["download_url"]
        logger.info(f"提交成功,下载链接: {self.__download_url}")

    def _create_task(self) -> str:
        """Create the recognition task for the committed resource.

        :return: the task id, also stored on ``self.task_id``.
        :raises Exception: if the response ``code`` is not 0.
        """
        resp = self.session.post(
            API_CREATE_TASK, json={"resource": self.__download_url, "model_id": "8"}, headers=self.headers
        )
        resp.raise_for_status()
        resp = resp.json()
        if resp.get("code") != 0:
            error_msg = f"创建任务失败: {resp.get('message', '未知错误')}"
            logger.error(error_msg)
            raise Exception(error_msg)

        self.task_id = resp["data"]["task_id"]
        logger.info(f"任务已创建: {self.task_id}")
        return self.task_id

    def _query_result(self) -> dict:
        """Query the current state of the task and return the ``data`` payload.

        :raises Exception: if the response ``code`` is not 0.
        """
        resp = self.session.get(
            API_QUERY_RESULT,
            params={"model_id": 7, "task_id": self.task_id},
            headers=self.headers
        )
        resp.raise_for_status()
        resp = resp.json()
        if resp.get("code") != 0:
            error_msg = f"查询结果失败: {resp.get('message', '未知错误')}"
            logger.error(error_msg)
            raise Exception(error_msg)

        return resp["data"]

    @timeit
    def transcript(self, file_path: str) -> TranscriptResult:
        """Run the full recognition flow; implements the Transcriber interface.

        :param file_path: path of the audio file to transcribe.
        :return: the transcript with per-utterance segments (times in seconds).
        :raises Exception: if any stage fails or the task does not finish
            within the polling budget; the error is logged and re-raised.
        """
        try:
            logger.info(f"开始处理文件: {file_path}")

            # Upload the file.
            logger.info("正在上传文件...")
            self._upload(file_path)

            # Create the task.
            logger.info("提交转录任务...")
            self._create_task()

            # Poll the task state, once per second, up to max_retries times.
            logger.info("等待转录结果...")
            task_resp = None
            max_retries = 500
            for i in range(max_retries):
                task_resp = self._query_result()

                if task_resp["state"] == 4:  # finished
                    break
                elif task_resp["state"] == 3:  # failed
                    error_msg = f"B站ASR任务失败状态码: {task_resp['state']}"
                    logger.error(error_msg)
                    raise Exception(error_msg)

                # Report progress every tenth poll.
                if i % 10 == 0:
                    logger.info(f"转录进行中... {i}/{max_retries}")

                time.sleep(1)

            if not task_resp or task_resp["state"] != 4:
                error_msg = f"B站ASR任务未能完成状态: {task_resp.get('state') if task_resp else 'Unknown'}"
                logger.error(error_msg)
                raise Exception(error_msg)

            # Parse the result, which arrives as a JSON string.
            logger.info("转录成功,处理结果...")
            result_json = json.loads(task_resp["result"])

            # Collect the segments.
            segments = []
            full_text = ""

            for u in result_json.get("utterances", []):
                text = u.get("transcript", "").strip()
                # Timestamps arrive in milliseconds; convert to seconds.
                start_time = float(u.get("start_time", 0)) / 1000.0
                end_time = float(u.get("end_time", 0)) / 1000.0

                full_text += text + " "
                segments.append(TranscriptSegment(
                    start=start_time,
                    end=end_time,
                    text=text
                ))

            # Build the result object.
            result = TranscriptResult(
                language=result_json.get("language", "zh"),
                full_text=full_text.strip(),
                segments=segments,
                raw=result_json
            )

            # The completion event is not fired from here.
            # self.on_finish(file_path, result)

            return result

        except Exception as e:
            logger.error(f"B站ASR处理失败: {str(e)}")
            raise

    def on_finish(self, video_path: str, result: TranscriptResult) -> None:
        """Completion callback: log and emit ``transcription_finished``."""
        logger.info(f"B站ASR转写完成: {video_path}")
        transcription_finished.send({
            "file_path": video_path,
        })

View File

@@ -0,0 +1,115 @@
import requests
import logging
import os
from typing import Union, List, Dict, Optional
from app.decorators.timeit import timeit
from app.models.transcriber_model import TranscriptSegment, TranscriptResult
from app.transcriber.base import Transcriber
from app.utils.logger import get_logger
from events import transcription_finished
logger = get_logger(__name__)
class KuaishouTranscriber(Transcriber):
    """Speech recognition backend that posts the audio file to the Kuaishou subtitle endpoint."""

    API_URL = "https://ai.kuaishou.com/api/effects/subtitle_generate"

    def __init__(self):
        # Intentionally empty: no per-instance state is kept.
        pass

    def _load_file(self, file_path: str) -> bytes:
        """Read the whole file into memory and return its bytes."""
        with open(file_path, 'rb') as f:
            return f.read()

    def _submit(self, file_path: str) -> dict:
        """Send the recognition request and return the decoded JSON response.

        :raises requests.exceptions.RequestException: on network/HTTP errors.
        :raises Exception: if the response has no ``data`` key or a non-zero ``code``.
        """
        try:
            file_binary = self._load_file(file_path)

            payload = {
                "typeId": "1"
            }

            # Upload under the file's own base name.
            file_name = os.path.basename(file_path)
            files = [('file', (file_name, file_binary, 'audio/mpeg'))]

            logger.info(f"开始向快手API提交请求文件: {file_name}")
            response = requests.post(self.API_URL, data=payload, files=files, timeout=300)
            response.raise_for_status()  # surface HTTP errors
            result = response.json()
        except requests.exceptions.RequestException as e:
            error_msg = f"快手ASR请求网络错误: {str(e)}"
            logger.error(error_msg)
            raise
        except Exception as e:
            error_msg = f"快手ASR请求处理错误: {str(e)}"
            logger.error(error_msg)
            raise

        # Logged at debug level through the module logger instead of printed to stdout.
        logger.debug(f"result {result}")

        # Validate outside the try block so an API-level error is logged once,
        # not a second time as a request-handling error.
        if "data" not in result or result.get("code", 0) != 0:
            error_msg = f"快手API返回错误: {result.get('message', '未知错误')}"
            logger.error(error_msg)
            raise Exception(error_msg)

        return result

    @timeit
    def transcript(self, file_path: str) -> TranscriptResult:
        """Run the transcription; implements the Transcriber interface.

        :param file_path: path of the audio file to transcribe.
        :return: the transcript; ``language`` is always reported as ``"zh"``.
        :raises Exception: any failure is logged and re-raised.
        """
        try:
            logger.info(f"开始处理文件: {file_path}")

            # Submit the request and fetch the result.
            logger.info("向快手API提交识别请求...")
            result_data = self._submit(file_path)

            logger.info("请求成功,处理结果...")

            # Collect the segments.
            segments = []
            full_text = ""

            # `or` fallbacks cover a response where "data" or "text" is null.
            texts = (result_data.get('data') or {}).get('text') or []

            for u in texts:
                text = u.get('text', '').strip()
                # NOTE(review): start/end are stored as received — confirm their unit.
                start_time = float(u.get('start_time', 0))
                end_time = float(u.get('end_time', 0))

                full_text += text + " "
                segments.append(TranscriptSegment(
                    start=start_time,
                    end=end_time,
                    text=text
                ))

            # Build the result object.
            result = TranscriptResult(
                language="zh",  # the API may not report a language; default to Chinese
                full_text=full_text.strip(),
                segments=segments,
                raw=result_data
            )

            # The completion event is not fired from here.
            # self.on_finish(file_path, result)

            return result

        except Exception as e:
            logger.error(f"快手ASR处理失败: {str(e)}")
            raise

    def on_finish(self, video_path: str, result: TranscriptResult) -> None:
        """Completion callback: log and emit ``transcription_finished``."""
        logger.info(f"快手ASR转写完成: {video_path}")
        transcription_finished.send({
            "file_path": video_path,
        })

View File

@@ -31,15 +31,15 @@ _transcribers = {
def get_whisper_transcriber(model_size="base", device="cuda"):
"""获取 Whisper 转录器实例"""
if _transcribers['fast-whisper'] is None:
if _transcribers['fast-whisper'] is None:
logger.info(f'创建 Whisper 转录器实例,参数:{model_size}, {device}')
try:
_transcribers['fast-whisper'] = WhisperTranscriber(model_size=model_size, device=device)
_transcribers['whisper'] = WhisperTranscriber(model_size=model_size, device=device)
logger.info('Whisper 转录器创建成功')
except Exception as e:
logger.error(f"Whisper 转录器创建失败: {e}")
raise
return _transcribers['fast-whisper']
return _transcribers['whisper']
def get_bcut_transcriber():
"""获取 Bcut 转录器实例"""

View File

@@ -4,14 +4,19 @@ from app.decorators.timeit import timeit
from app.models.transcriber_model import TranscriptSegment, TranscriptResult
from app.transcriber.base import Transcriber
from app.utils.env_checker import is_cuda_available, is_torch_installed
from app.utils.logger import get_logger
from app.utils.path_helper import get_model_dir
from events import transcription_finished
from pathlib import Path
import os
from tqdm import tqdm
from huggingface_hub import snapshot_download
'''
Size of the model to use (tiny, tiny.en, base, base.en, small, small.en, distil-small.en, medium, medium.en, distil-medium.en, large-v1, large-v2, large-v3, large, distil-large-v2, distil-large-v3, large-v3-turbo, or turbo
'''
logger=get_logger(__name__)
class WhisperTranscriber(Transcriber):
# TODO:修改为可配置
@@ -31,15 +36,25 @@ class WhisperTranscriber(Transcriber):
self.compute_type = compute_type or ("float16" if self.device == "cuda" else "int8")
model_path = get_model_dir("whisper")
model_dir = get_model_dir("whisper")
model_path = os.path.join(model_dir, f"whisper-{model_size}")
if not Path(model_path).exists():
logger.info(f"模型 whisper-{model_size} 不存在,开始下载...")
repo_id = f"guillaumekln/faster-whisper-{model_size}"
snapshot_download(
repo_id,
local_dir=model_path,
local_dir_use_symlinks=False,
)
logger.info("模型下载完成")
self.model = WhisperModel(
model_size,
device=self.device,
# compute_type="int8", # 或 "float16"
compute_type=self.compute_type,
cpu_threads=cpu_threads,
download_root=model_path
download_root=model_dir
)
@staticmethod
def is_torch_installed() -> bool:
try:
@@ -88,7 +103,7 @@ class WhisperTranscriber(Transcriber):
segments=segments,
raw=info
)
self.on_finish(file_path, result)
# self.on_finish(file_path, result)
return result
except Exception as e:
print(f"转写失败:{e}")

View File

@@ -4,6 +4,7 @@ import uvicorn
from starlette.staticfiles import StaticFiles
from dotenv import load_dotenv
from app.db.model_dao import init_model_table
from app.db.provider_dao import init_provider_table
from app.utils.logger import get_logger
from app import create_app
@@ -39,6 +40,7 @@ async def startup_event():
get_transcriber(transcriber_type=os.getenv("TRANSCRIBER_TYPE","fast-whisper"))
init_video_task_table()
init_provider_table()
init_model_table()
if __name__ == "__main__":
port = int(os.getenv("BACKEND_PORT", 8000))