mirror of
https://github.com/JefferyHcool/BiliNote.git
synced 2026-06-26 02:01:38 +08:00
feat: 新增模型管理和供应商配置功能
### v1.1.0 - #### Added - 新增 AI 笔记风格选择 - 新增 AI 笔记返回格式选择 - 添加 AI 自定义笔记备注 Prompt - 添加任务失败重试 - 添加全局设置页,可在设置页进行模型设置 - #### Optimize - 优化前端样式,优化用户体验 - 增加生成中间产物,可用于失败后加快生成速度 - #### Fix - 修复视频截图视频过早删除错误
This commit is contained in:
19
.env.example
19
.env.example
@@ -1,3 +1,11 @@
|
||||
###
|
||||
# @Author: 思诺特 jefferyhcool@gmail.com
|
||||
# @Date: 2025-04-14 08:49:59
|
||||
# @LastEditors: 思诺特 jefferyhcool@gmail.com
|
||||
# @LastEditTime: 2025-04-26 19:56:50
|
||||
# @FilePath: \BiliNote\.env.example
|
||||
# @Description: 这是默认设置,请设置`customMade`, 打开koroFileHeader查看配置 进行设置: https://github.com/OBKoro1/koro1FileHeader/wiki/%E9%85%8D%E7%BD%AE
|
||||
###
|
||||
# 通用端口配置
|
||||
BACKEND_PORT=8001
|
||||
FRONTEND_PORT=3015
|
||||
@@ -13,17 +21,6 @@ STATIC=/static
|
||||
OUT_DIR=./static/screenshots
|
||||
IMAGE_BASE_URL=/static/screenshots
|
||||
DATA_DIR=data
|
||||
# AI 相关配置
|
||||
OPENAI_API_KEY=
|
||||
OPENAI_API_BASE_URL=
|
||||
OPENAI_MODEL=
|
||||
DEEP_SEEK_API_KEY=
|
||||
DEEP_SEEK_API_BASE_URL=
|
||||
DEEP_SEEK_MODEL=
|
||||
QWEN_API_KEY=
|
||||
QWEN_API_BASE_URL=
|
||||
QWEN_MODEL=
|
||||
MODEl_PROVIDER= #如果不是openai 请修改 deepseek/qwen
|
||||
# FFMPEG 配置
|
||||
FFMPEG_BIN_PATH=
|
||||
|
||||
|
||||
@@ -1,13 +1,32 @@
|
||||
# === 前端构建阶段 ===
|
||||
FROM node:18-alpine AS build
|
||||
|
||||
# 安装 pnpm
|
||||
RUN npm install -g pnpm
|
||||
|
||||
# 设置工作目录
|
||||
WORKDIR /app
|
||||
|
||||
# 拷贝前端源码
|
||||
COPY ./BillNote_frontend /app
|
||||
|
||||
RUN npm install && npm run build
|
||||
# 安装依赖并构建
|
||||
RUN pnpm install && pnpm run build
|
||||
|
||||
# === nginx 运行阶段 ===
|
||||
FROM nginx:alpine
|
||||
|
||||
# 拷贝模板配置
|
||||
COPY ./BillNote_frontend/deploy/default.conf.template /etc/nginx/templates/default.conf.template
|
||||
|
||||
# 拷贝构建产物
|
||||
COPY --from=build /app/dist /usr/share/nginx/html
|
||||
|
||||
# 拷贝启动脚本
|
||||
COPY ./BillNote_frontend/deploy/start.sh /start.sh
|
||||
RUN chmod +x /start.sh
|
||||
|
||||
EXPOSE 80
|
||||
|
||||
# 使用启动脚本启动容器
|
||||
CMD ["/start.sh"]
|
||||
18
BillNote_frontend/deploy/default.conf.template
Normal file
18
BillNote_frontend/deploy/default.conf.template
Normal file
@@ -0,0 +1,18 @@
|
||||
server {
|
||||
listen 80;
|
||||
resolver 127.0.0.11 valid=10s;
|
||||
|
||||
location / {
|
||||
root /usr/share/nginx/html;
|
||||
index index.html;
|
||||
try_files $uri $uri/ /index.html;
|
||||
}
|
||||
|
||||
location /api/ {
|
||||
proxy_pass http://backend:${BACKEND_PORT};
|
||||
proxy_set_header Host $host;
|
||||
proxy_set_header X-Real-IP $remote_addr;
|
||||
proxy_set_header X-Forwarded-For $proxy_add_x_forwarded_for;
|
||||
proxy_set_header X-Forwarded-Proto $scheme;
|
||||
}
|
||||
}
|
||||
20
BillNote_frontend/deploy/start.sh
Normal file
20
BillNote_frontend/deploy/start.sh
Normal file
@@ -0,0 +1,20 @@
|
||||
#!/bin/sh
|
||||
###
|
||||
# @Author: Jefferyhcool 1063474837@qq.com
|
||||
# @Date: 2025-04-16 01:57:05
|
||||
# @LastEditors: Jefferyhcool 1063474837@qq.com
|
||||
# @LastEditTime: 2025-04-16 01:59:37
|
||||
# @FilePath: /hotfix-dev/BillNote_frontend/deploy/start.sh
|
||||
# @Description: 这是默认设置,请设置`customMade`, 打开koroFileHeader查看配置 进行设置: https://github.com/OBKoro1/koro1FileHeader/wiki/%E9%85%8D%E7%BD%AE
|
||||
###
|
||||
# 等待后端健康检查通过
|
||||
until curl -s "http://backend:${BACKEND_PORT}/health" > /dev/null; do
|
||||
echo "等待后端服务就绪..."
|
||||
sleep 2
|
||||
done
|
||||
|
||||
# 生成 nginx 配置文件(动态变量替换)
|
||||
envsubst '${BACKEND_HOST} ${BACKEND_PORT}' < /etc/nginx/templates/default.conf.template > /etc/nginx/conf.d/default.conf
|
||||
|
||||
# 启动 Nginx(在前台运行)
|
||||
exec nginx -g 'daemon off;'
|
||||
@@ -25,6 +25,7 @@
|
||||
"@radix-ui/react-tooltip": "^1.1.8",
|
||||
"@tailwindcss/vite": "^4.1.3",
|
||||
"@uiw/react-markdown-preview": "^5.1.3",
|
||||
"antd": "^5.24.8",
|
||||
"axios": "^1.8.4",
|
||||
"class-variance-authority": "^0.7.1",
|
||||
"clsx": "^2.1.1",
|
||||
|
||||
@@ -7,10 +7,19 @@ import { Route } from 'react-router-dom'
|
||||
import Index from '@/pages/Index.tsx'
|
||||
import NotFoundPage from '@/pages/NotFoundPage' //
|
||||
import Model from '@/pages/SettingPage/Model.tsx'
|
||||
import Transcriber from '@/pages/SettingPage/transcriber.tsx'
|
||||
import ProviderForm from '@/components/Form/modelForm/Form.tsx'
|
||||
import StepBar from '@/pages/HomePage/components/StepBar.tsx'
|
||||
import Downloading from '@/components/Lottie/download.tsx'
|
||||
function App() {
|
||||
useTaskPolling(3000) // 每 3 秒轮询一次
|
||||
|
||||
const steps = [
|
||||
{ label: '解析链接', key: 'PARSING', icon: <Downloading /> },
|
||||
{ label: '下载音频', key: 'DOWNLOADING' },
|
||||
{ label: '转写文字', key: 'TRANSCRIBING' },
|
||||
{ label: '总结内容', key: 'SUMMARIZING' },
|
||||
{ label: '保存完成', key: 'SUCCESS' },
|
||||
]
|
||||
return (
|
||||
<>
|
||||
<BrowserRouter>
|
||||
@@ -20,9 +29,11 @@ function App() {
|
||||
<Route path="settings" element={<SettingPage />}>
|
||||
<Route index element={<Navigate to="model" replace />} />
|
||||
<Route path="model" element={<Model />}>
|
||||
<Route index element={<Navigate to="openai" replace />} />
|
||||
<Route path="new" element={<ProviderForm isCreate />} />
|
||||
{/*<Route index element={<Navigate to="openai" replace />} />*/}
|
||||
<Route path=":id" element={<ProviderForm />} />
|
||||
</Route>
|
||||
<Route path="transcriber" elment={<Transcriber />}></Route>
|
||||
<Route path="*" element={<NotFoundPage />} />
|
||||
</Route>
|
||||
<Route path="*" element={<NotFoundPage />} />
|
||||
|
||||
BIN
BillNote_frontend/src/assets/customAI.png
Normal file
BIN
BillNote_frontend/src/assets/customAI.png
Normal file
Binary file not shown.
|
After Width: | Height: | Size: 1.7 MiB |
@@ -1,153 +1,281 @@
|
||||
import { useForm } from 'react-hook-form';
|
||||
import { z } from 'zod';
|
||||
import { zodResolver } from '@hookform/resolvers/zod';
|
||||
import { useForm } from 'react-hook-form'
|
||||
import { z } from 'zod'
|
||||
import { zodResolver } from '@hookform/resolvers/zod'
|
||||
import {
|
||||
Form,
|
||||
FormField,
|
||||
FormItem,
|
||||
FormLabel,
|
||||
FormControl,
|
||||
FormDescription,
|
||||
FormMessage,
|
||||
} from '@/components/ui/form';
|
||||
import { Input } from '@/components/ui/input';
|
||||
import { Button } from '@/components/ui/button';
|
||||
import { useParams } from 'react-router-dom';
|
||||
import { useProviderStore } from '@/store/providerStore';
|
||||
import {useEffect, useState} from 'react';
|
||||
FormDescription,
|
||||
} from '@/components/ui/form'
|
||||
import { Input } from '@/components/ui/input'
|
||||
import { Button } from '@/components/ui/button'
|
||||
import { useParams, useNavigate } from 'react-router-dom'
|
||||
import { useProviderStore } from '@/store/providerStore'
|
||||
import { useEffect, useState } from 'react'
|
||||
import toast from 'react-hot-toast'
|
||||
import { testConnection, fetchModels } from '@/services/model.ts'
|
||||
import {
|
||||
Select,
|
||||
SelectContent,
|
||||
SelectItem,
|
||||
SelectTrigger,
|
||||
SelectValue,
|
||||
} from '@/components/ui/select.tsx' // ⚡新增 fetchModels
|
||||
import { ModelSelector } from '@/components/Form/modelForm/ModelSelector.tsx'
|
||||
import { Alert, AlertDescription, AlertTitle } from '@/components/ui/alert.tsx'
|
||||
|
||||
// ✅ 表单校验 schema
|
||||
// ✅ Provider表单schema
|
||||
const ProviderSchema = z.object({
|
||||
name: z.string().min(2, '名称不能少于 2 个字符'),
|
||||
apiKey: z.string().optional(),
|
||||
baseUrl: z.string().url('必须是合法 URL'),
|
||||
type: z.string(), // 只展示,不可改
|
||||
});
|
||||
type: z.string(),
|
||||
})
|
||||
|
||||
type ProviderFormValues = z.infer<typeof ProviderSchema>;
|
||||
type ProviderFormValues = z.infer<typeof ProviderSchema>
|
||||
|
||||
const ProviderForm = () => {
|
||||
const rawId= useParams();
|
||||
console.log('rawId',rawId)
|
||||
// @ts-ignore
|
||||
const [providerName, idPart] = rawId.id.split('&');
|
||||
const [id,setId ]= useState(Number(idPart?.split('=')[1])) // => "1"
|
||||
const getProviderById = useProviderStore((state) => state.getProviderById);
|
||||
const provider = getProviderById(id);
|
||||
// ✅ Model表单schema
|
||||
const ModelSchema = z.object({
|
||||
modelName: z.string().min(1, '请选择或填写模型名称'),
|
||||
})
|
||||
|
||||
const form = useForm<ProviderFormValues>({
|
||||
type ModelFormValues = z.infer<typeof ModelSchema>
|
||||
interface IModel {
|
||||
id: string
|
||||
created: number
|
||||
object: string
|
||||
owned_by: string
|
||||
permission: string
|
||||
root: string
|
||||
}
|
||||
const ProviderForm = ({ isCreate = false }: { isCreate?: boolean }) => {
|
||||
const { id } = useParams()
|
||||
const navigate = useNavigate()
|
||||
const isEditMode = !isCreate
|
||||
|
||||
const getProviderById = useProviderStore(state => state.getProviderById)
|
||||
const loadProviderById = useProviderStore(state => state.loadProviderById)
|
||||
const updateProvider = useProviderStore(state => state.updateProvider)
|
||||
const addNewProvider = useProviderStore(state => state.addNewProvider)
|
||||
|
||||
const [loading, setLoading] = useState(true)
|
||||
const [testing, setTesting] = useState(false)
|
||||
const [isBuiltIn, setIsBuiltIn] = useState(false)
|
||||
const [modelOptions, setModelOptions] = useState<IModel[]>([]) // ⚡新增,保存模型列表
|
||||
const [modelLoading, setModelLoading] = useState(false)
|
||||
|
||||
const [search, setSearch] = useState('')
|
||||
const providerForm = useForm<ProviderFormValues>({
|
||||
resolver: zodResolver(ProviderSchema),
|
||||
defaultValues: {
|
||||
name: '',
|
||||
apiKey: '',
|
||||
baseUrl: '',
|
||||
type: '',
|
||||
type: 'custom',
|
||||
},
|
||||
});
|
||||
})
|
||||
const filteredModelOptions = modelOptions.filter(model => {
|
||||
const keywords = search.trim().toLowerCase().split(/\s+/) // 支持多个关键词
|
||||
const target = model.id.toLowerCase()
|
||||
return keywords.every(kw => target.includes(kw))
|
||||
})
|
||||
|
||||
const modelForm = useForm<ModelFormValues>({
|
||||
resolver: zodResolver(ModelSchema),
|
||||
defaultValues: {
|
||||
modelName: '',
|
||||
},
|
||||
})
|
||||
|
||||
useEffect(() => {
|
||||
console.log(provider)
|
||||
// if (provider) {
|
||||
// form.reset({
|
||||
// name: provider.name,
|
||||
// apiKey: provider.apiKey,
|
||||
// baseUrl: provider.baseUrl,
|
||||
// type: provider.type,
|
||||
// });
|
||||
// }
|
||||
}, [id,provider, form]);
|
||||
const load = async () => {
|
||||
if (isEditMode) {
|
||||
const data = await loadProviderById(id!)
|
||||
providerForm.reset(data)
|
||||
setIsBuiltIn(data.type === 'built-in')
|
||||
} else {
|
||||
providerForm.reset({
|
||||
name: '',
|
||||
apiKey: '',
|
||||
baseUrl: '',
|
||||
type: 'custom',
|
||||
})
|
||||
setIsBuiltIn(false)
|
||||
}
|
||||
setLoading(false)
|
||||
}
|
||||
load()
|
||||
}, [id])
|
||||
|
||||
const isBuiltIn = provider?.type === 'built-in';
|
||||
// 测试连通性
|
||||
const handleTest = async () => {
|
||||
const values = providerForm.getValues()
|
||||
if (!values.apiKey || !values.baseUrl) {
|
||||
toast.error('请填写 API Key 和 Base URL')
|
||||
return
|
||||
}
|
||||
try {
|
||||
setTesting(true)
|
||||
const data = await testConnection({
|
||||
api_key: values.apiKey,
|
||||
base_url: values.baseUrl,
|
||||
})
|
||||
if (data.data.code === 0) {
|
||||
toast.success('测试连通性成功 🎉')
|
||||
} else {
|
||||
toast.error(`连接失败: ${data.data.msg || '未知错误'}`)
|
||||
}
|
||||
} catch (error) {
|
||||
toast.error('测试连通性异常')
|
||||
} finally {
|
||||
setTesting(false)
|
||||
}
|
||||
}
|
||||
|
||||
const onSubmit = (values: ProviderFormValues) => {
|
||||
console.log('📝 提交表单数据:', values);
|
||||
// TODO: 提交接口 /update_provider
|
||||
};
|
||||
// 加载模型列表
|
||||
const handleModelLoad = async () => {
|
||||
const values = providerForm.getValues()
|
||||
if (!values.apiKey || !values.baseUrl) {
|
||||
toast.error('请先填写 API Key 和 Base URL')
|
||||
return
|
||||
}
|
||||
try {
|
||||
setModelLoading(true) // ✅ 开始 loading
|
||||
const res = await fetchModels(id!, { noCache: true }) // 这里稍后解释
|
||||
if (res.data.code === 0 && res.data.data.models.data.length > 0) {
|
||||
setModelOptions(res.data.data.models.data)
|
||||
console.log('🔧 模型列表:', res.data.data)
|
||||
toast.success('模型列表加载成功 🎉')
|
||||
} else {
|
||||
toast.error('未获取到模型列表')
|
||||
}
|
||||
} catch (error) {
|
||||
toast.error('加载模型列表失败')
|
||||
} finally {
|
||||
setModelLoading(false) // ✅ 结束 loading
|
||||
}
|
||||
}
|
||||
|
||||
// if (!provider) return <div className="p-4">加载中...</div>;
|
||||
// 保存Provider信息
|
||||
const onProviderSubmit = async (values: ProviderFormValues) => {
|
||||
if (isEditMode) {
|
||||
updateProvider({ ...values, id: id! })
|
||||
toast.success('更新供应商成功')
|
||||
} else {
|
||||
addNewProvider({ ...values })
|
||||
toast.success('新增供应商成功')
|
||||
}
|
||||
}
|
||||
|
||||
// 保存Model信息
|
||||
const onModelSubmit = async (values: ModelFormValues) => {
|
||||
console.log('🔧 选择的模型:', values.modelName)
|
||||
toast.success(`保存模型: ${values.modelName}`)
|
||||
}
|
||||
|
||||
if (loading) return <div className="p-4">加载中...</div>
|
||||
|
||||
return (
|
||||
|
||||
<Form {...form}>
|
||||
|
||||
<div className="flex flex-col gap-8 p-4">
|
||||
{/* Provider信息表单 */}
|
||||
<Form {...providerForm}>
|
||||
<form
|
||||
onSubmit={form.handleSubmit(onSubmit)}
|
||||
className="w-full max-w-xl p-4 flex flex-col gap-4"
|
||||
onSubmit={providerForm.handleSubmit(onProviderSubmit)}
|
||||
className="flex max-w-xl flex-col gap-4"
|
||||
>
|
||||
<div className="text-lg font-bold">模型供应商配置</div>
|
||||
|
||||
{/* 名称 */}
|
||||
<div className="text-lg font-bold">
|
||||
{isEditMode ? '编辑模型供应商' : '新增模型供应商'}
|
||||
</div>
|
||||
{!isBuiltIn && (
|
||||
<div className="text-sm text-red-500 italic">
|
||||
自定义模型供应商需要确保兼容 OpenAI SDK
|
||||
</div>
|
||||
)}
|
||||
<FormField
|
||||
control={form.control}
|
||||
name="name"
|
||||
render={({ field }) => (
|
||||
<FormItem className="flex items-center gap-4">
|
||||
<FormLabel className="w-24 text-right">名称</FormLabel>
|
||||
<FormControl>
|
||||
<Input {...field} disabled={isBuiltIn} className="flex-1" />
|
||||
</FormControl>
|
||||
<FormMessage />
|
||||
</FormItem>
|
||||
)}
|
||||
control={providerForm.control}
|
||||
name="name"
|
||||
render={({ field }) => (
|
||||
<FormItem className="flex items-center gap-4">
|
||||
<FormLabel className="w-24 text-right">名称</FormLabel>
|
||||
<FormControl>
|
||||
<Input {...field} disabled={isBuiltIn} className="flex-1" />
|
||||
</FormControl>
|
||||
<FormMessage />
|
||||
</FormItem>
|
||||
)}
|
||||
/>
|
||||
|
||||
{/* API Key */}
|
||||
<FormField
|
||||
control={form.control}
|
||||
name="apiKey"
|
||||
render={({ field }) => (
|
||||
<FormItem className="flex items-center gap-4">
|
||||
<FormLabel className="w-24 text-right">API Key</FormLabel>
|
||||
<FormControl>
|
||||
<Input placeholder={'sk-xxx'} {...field} className="flex-1" />
|
||||
</FormControl>
|
||||
<FormMessage />
|
||||
|
||||
|
||||
</FormItem>
|
||||
|
||||
|
||||
|
||||
)}
|
||||
control={providerForm.control}
|
||||
name="apiKey"
|
||||
render={({ field }) => (
|
||||
<FormItem className="flex items-center gap-4">
|
||||
<FormLabel className="w-24 text-right">API Key</FormLabel>
|
||||
<FormControl>
|
||||
<Input {...field} className="flex-1" />
|
||||
</FormControl>
|
||||
<FormMessage />
|
||||
</FormItem>
|
||||
)}
|
||||
/>
|
||||
|
||||
{/* Base URL */}
|
||||
<FormField
|
||||
control={form.control}
|
||||
name="baseUrl"
|
||||
render={({ field }) => (
|
||||
<FormItem className="flex items-center gap-4">
|
||||
<FormLabel className="w-24 text-right">API 代理地址</FormLabel>
|
||||
<FormControl>
|
||||
<Input {...field} className="flex-1" />
|
||||
</FormControl>
|
||||
<FormMessage />
|
||||
</FormItem>
|
||||
)}
|
||||
control={providerForm.control}
|
||||
name="baseUrl"
|
||||
render={({ field }) => (
|
||||
<FormItem className="flex items-center gap-4">
|
||||
<FormLabel className="w-24 text-right">API地址</FormLabel>
|
||||
<FormControl>
|
||||
<Input {...field} className="flex-1" />
|
||||
</FormControl>
|
||||
<Button type="button" onClick={handleTest} variant="ghost" disabled={testing}>
|
||||
{testing ? '测试中...' : '测试连通性'}
|
||||
</Button>
|
||||
<FormMessage />
|
||||
</FormItem>
|
||||
)}
|
||||
/>
|
||||
|
||||
{/* 类型 */}
|
||||
<FormField
|
||||
control={form.control}
|
||||
name="type"
|
||||
render={({ field }) => (
|
||||
<FormItem className="flex items-center gap-4">
|
||||
<FormLabel className="w-24 text-right">类型</FormLabel>
|
||||
<FormControl>
|
||||
<Input {...field} disabled className="flex-1" />
|
||||
</FormControl>
|
||||
</FormItem>
|
||||
)}
|
||||
control={providerForm.control}
|
||||
name="type"
|
||||
render={({ field }) => (
|
||||
<FormItem className="flex items-center gap-4">
|
||||
<FormLabel className="w-24 text-right">类型</FormLabel>
|
||||
<FormControl>
|
||||
<Input {...field} disabled className="flex-1" />
|
||||
</FormControl>
|
||||
<FormMessage />
|
||||
</FormItem>
|
||||
)}
|
||||
/>
|
||||
|
||||
<div className="pt-2">
|
||||
<Button type="submit" disabled={!form.formState.isDirty}>
|
||||
保存修改
|
||||
<Button type="submit" disabled={!providerForm.formState.isDirty}>
|
||||
{isEditMode ? '保存修改' : '保存创建'}
|
||||
</Button>
|
||||
</div>
|
||||
</form>
|
||||
</Form>
|
||||
);
|
||||
};
|
||||
|
||||
export default ProviderForm;
|
||||
{/* 模型信息表单 */}
|
||||
<div className="flex max-w-xl flex-col gap-4">
|
||||
<div className="flex flex-col gap-2">
|
||||
<span className="font-bold">模型列表</span>
|
||||
<div className={'flex flex-col gap-2 rounded bg-[#FEF0F0] p-2.5'}>
|
||||
<h2 className={'font-bold'}>注意!</h2>
|
||||
<span>请确保已经保存供应商信息,以及通过测试连通性.</span>
|
||||
</div>
|
||||
<ModelSelector providerId={id!} />
|
||||
|
||||
{/*<datalist id="model-options">*/}
|
||||
{/* {modelOptions.map(model => (*/}
|
||||
{/* <option key={model.id + '1'} value={model.id} />*/}
|
||||
{/* ))}*/}
|
||||
{/*</datalist>*/}
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
)
|
||||
}
|
||||
|
||||
export default ProviderForm
|
||||
|
||||
@@ -0,0 +1,4 @@
|
||||
// iconMap.ts
|
||||
import * as Icons from '@lobehub/icons'
|
||||
|
||||
export const IconMap = Icons;
|
||||
@@ -0,0 +1,29 @@
|
||||
import * as Icons from '@lobehub/icons'
|
||||
import CustomLogo from '@/assets/customAI.png'
|
||||
|
||||
interface AILogoProps {
|
||||
name: string // 图标名称(区分大小写!如 OpenAI、DeepSeek)
|
||||
style?: 'Color' | 'Text' | 'Outlined' | 'Glyph'
|
||||
size?: number
|
||||
}
|
||||
|
||||
const AILogo = ({ name, style = 'Color', size = 24 }: AILogoProps) => {
|
||||
const Icon = Icons[name as keyof typeof Icons]
|
||||
if (!Icon) {
|
||||
console.error(`❌ 图标组件不存在: ${name}`)
|
||||
return (
|
||||
<span style={{ fontSize: size }}>
|
||||
<img src={CustomLogo} alt="CustomLogo" style={{ width: size, height: size }} />
|
||||
</span>
|
||||
)
|
||||
}
|
||||
|
||||
const Variant = Icon[style as keyof typeof Icon]
|
||||
if (!Variant) {
|
||||
return <Icon size={size} />
|
||||
}
|
||||
|
||||
return <Variant size={size} />
|
||||
}
|
||||
|
||||
export default AILogo
|
||||
@@ -0,0 +1,92 @@
|
||||
import { useState, useEffect } from 'react'
|
||||
import { useModelStore } from '@/store/modelStore'
|
||||
import { Input } from '@/components/ui/input'
|
||||
import {
|
||||
Select,
|
||||
SelectContent,
|
||||
SelectItem,
|
||||
SelectTrigger,
|
||||
SelectValue,
|
||||
} from '@/components/ui/select'
|
||||
import { Button } from '@/components/ui/button'
|
||||
import toast from 'react-hot-toast'
|
||||
|
||||
interface ModelSelectorProps {
|
||||
providerId: string
|
||||
}
|
||||
|
||||
export function ModelSelector({ providerId }: ModelSelectorProps) {
|
||||
const { models, loading, selectedModel, loadModels, setSelectedModel, addNewModel } =
|
||||
useModelStore()
|
||||
const [search, setSearch] = useState('')
|
||||
const [submitting, setSubmitting] = useState(false)
|
||||
|
||||
const filteredModels = models.filter(model => {
|
||||
const keywords = search.trim().toLowerCase().split(/\s+/)
|
||||
const target = model.id.toLowerCase()
|
||||
return keywords.every(kw => target.includes(kw))
|
||||
})
|
||||
|
||||
useEffect(() => {
|
||||
if (providerId) {
|
||||
loadModels(providerId)
|
||||
}
|
||||
}, [providerId])
|
||||
|
||||
const handleSubmit = async () => {
|
||||
if (!selectedModel) {
|
||||
toast.error('请选择一个模型')
|
||||
return
|
||||
}
|
||||
try {
|
||||
setSubmitting(true)
|
||||
await addNewModel(providerId, selectedModel)
|
||||
toast.success('保存模型成功 🎉')
|
||||
} catch (error) {
|
||||
toast.error('保存失败')
|
||||
} finally {
|
||||
setSubmitting(false)
|
||||
}
|
||||
}
|
||||
|
||||
return (
|
||||
<div className="flex flex-col gap-4">
|
||||
<div className="flex items-center gap-2 font-bold">
|
||||
<span>选择模型</span>
|
||||
<Button
|
||||
variant="ghost"
|
||||
type="button"
|
||||
onClick={() => loadModels(providerId)}
|
||||
disabled={loading}
|
||||
>
|
||||
{loading ? '加载中...' : '刷新模型'}
|
||||
</Button>
|
||||
</div>
|
||||
|
||||
<Select value={selectedModel} onValueChange={setSelectedModel}>
|
||||
<SelectTrigger className="w-[300px]">
|
||||
<SelectValue placeholder="请选择模型" />
|
||||
</SelectTrigger>
|
||||
<SelectContent>
|
||||
<div className="p-2">
|
||||
<Input
|
||||
placeholder="搜索模型..."
|
||||
value={search}
|
||||
onChange={e => setSearch(e.target.value)}
|
||||
className="h-8"
|
||||
/>
|
||||
</div>
|
||||
{filteredModels.map(model => (
|
||||
<SelectItem key={model.id} value={model.id}>
|
||||
{model.id}
|
||||
</SelectItem>
|
||||
))}
|
||||
</SelectContent>
|
||||
</Select>
|
||||
|
||||
<Button onClick={handleSubmit} disabled={submitting || !selectedModel}>
|
||||
{submitting ? '保存中...' : '保存模型'}
|
||||
</Button>
|
||||
</div>
|
||||
)
|
||||
}
|
||||
@@ -1,28 +1,39 @@
|
||||
import ProviderCard from '@/components/Form/modelForm/components/providerCard.tsx'
|
||||
import { Button } from '@/components/ui/button.tsx'
|
||||
import { useProviderStore } from '@/store/providerStore'
|
||||
import { useNavigate } from 'react-router-dom'
|
||||
|
||||
const Provider = () => {
|
||||
const providers = useProviderStore(state => state.provider)
|
||||
|
||||
const providers = useProviderStore(state => state.provider)
|
||||
const navigate = useNavigate()
|
||||
const handleClick = () => {
|
||||
navigate(`/settings/model/new`)
|
||||
}
|
||||
|
||||
return (
|
||||
<div className="flex flex-col gap-2">
|
||||
<div className={'search flex gap-1 py-1.5'}>
|
||||
<Button type={'button'} className="w-full">
|
||||
<Button
|
||||
type={'button'}
|
||||
onClick={() => {
|
||||
handleClick()
|
||||
}}
|
||||
className="w-full"
|
||||
>
|
||||
添加模型供应商
|
||||
</Button>
|
||||
</div>
|
||||
<div className="text-sm font-light">模型供应商列表</div>
|
||||
<div>
|
||||
{providers &&
|
||||
providers.map((provider, index) => {
|
||||
providers.map((provider, index) => {
|
||||
return (
|
||||
<ProviderCard
|
||||
key={index}
|
||||
providerName={provider.name}
|
||||
Icon={provider.logo}
|
||||
id={provider.id}
|
||||
enable={provider.enabled}
|
||||
/>
|
||||
)
|
||||
})}
|
||||
|
||||
@@ -1,22 +1,37 @@
|
||||
import { Switch } from '@/components/ui/switch'
|
||||
import { FC } from 'react'
|
||||
import styles from './index.module.css'
|
||||
import {useNavigate, useParams} from 'react-router-dom'
|
||||
import AILogo from "@/components/Icons";
|
||||
import { useNavigate, useParams } from 'react-router-dom'
|
||||
import AILogo from '@/components/Form/modelForm/Icons'
|
||||
import { useProviderStore } from '@/store/providerStore'
|
||||
export interface IProviderCardProps {
|
||||
id: string
|
||||
providerName: string
|
||||
Icon: string
|
||||
enable: number
|
||||
}
|
||||
const ProviderCard: FC<IProviderCardProps> = ({ providerName, Icon, id }: IProviderCardProps) => {
|
||||
const ProviderCard: FC<IProviderCardProps> = ({
|
||||
providerName,
|
||||
Icon,
|
||||
id,
|
||||
enable,
|
||||
}: IProviderCardProps) => {
|
||||
const navigate = useNavigate()
|
||||
const updateProvider = useProviderStore(state => state.updateProvider)
|
||||
const handleClick = () => {
|
||||
navigate(`/settings/model/${providerName}&id=${id}`)
|
||||
navigate(`/settings/model/${id}`)
|
||||
}
|
||||
const rawId= useParams();
|
||||
console.log('rawId',rawId)
|
||||
const handleEnable = () => {
|
||||
console.log('enable', enable)
|
||||
updateProvider({
|
||||
id,
|
||||
enabled: enable == 1 ? 0 : 1,
|
||||
})
|
||||
}
|
||||
const rawId = useParams()
|
||||
console.log('rawId', rawId)
|
||||
// @ts-ignore
|
||||
const { id: currentId } = useParams();
|
||||
const { id: currentId } = useParams()
|
||||
const isActive = currentId === id
|
||||
return (
|
||||
<div
|
||||
@@ -24,18 +39,26 @@ const ProviderCard: FC<IProviderCardProps> = ({ providerName, Icon, id }: IProvi
|
||||
handleClick()
|
||||
}}
|
||||
className={
|
||||
styles.card + ' flex h-14 items-center justify-between rounded border border-[#f3f3f3] p-2'
|
||||
+(isActive ? ' bg-[#F0F0F0] font-semibold text-blue-600' : '')
|
||||
styles.card +
|
||||
' flex h-14 items-center justify-between rounded border border-[#f3f3f3] p-2' +
|
||||
(isActive ? ' bg-[#F0F0F0] font-semibold text-blue-600' : '')
|
||||
}
|
||||
>
|
||||
<div className="flex items-center text-lg">
|
||||
<div className="h-9 w-9 flex items-center">
|
||||
<AILogo name={Icon} />
|
||||
<div className="flex h-9 w-9 items-center">
|
||||
<AILogo name={Icon} />
|
||||
</div>
|
||||
<div className="font-semibold">{providerName}</div>
|
||||
</div>
|
||||
|
||||
<div>
|
||||
<Switch />
|
||||
<Switch
|
||||
onClick={e => {
|
||||
e.preventDefault()
|
||||
handleEnable()
|
||||
}}
|
||||
checked={enable == 1}
|
||||
/>
|
||||
</div>
|
||||
</div>
|
||||
)
|
||||
|
||||
40
BillNote_frontend/src/components/Lottie/download.tsx
Normal file
40
BillNote_frontend/src/components/Lottie/download.tsx
Normal file
@@ -0,0 +1,40 @@
|
||||
import { FC, useRef, useEffect } from 'react'
|
||||
import Lottie, { LottieRefCurrentProps } from 'lottie-react'
|
||||
import download from '@/assets/Lottie/download.json'
|
||||
|
||||
interface LoadingProps {
|
||||
play?: boolean // 是否播放
|
||||
color?: string // 控制主色,比如 "#00BFFF"
|
||||
}
|
||||
|
||||
const Downloading: FC<LoadingProps> = ({ play = true, color = '#00BFFF' }) => {
|
||||
const lottieRef = useRef<LottieRefCurrentProps>(null)
|
||||
|
||||
useEffect(() => {
|
||||
if (!lottieRef.current) return
|
||||
|
||||
if (play) {
|
||||
lottieRef.current.play()
|
||||
} else {
|
||||
lottieRef.current.pause()
|
||||
}
|
||||
}, [play])
|
||||
|
||||
return (
|
||||
<div className="flex items-center justify-center">
|
||||
<Lottie
|
||||
lottieRef={lottieRef}
|
||||
animationData={download}
|
||||
loop
|
||||
autoplay={play}
|
||||
style={{
|
||||
width: 150,
|
||||
height: 150,
|
||||
filter: `drop-shadow(0 0 4px ${color}) saturate(2) brightness(1.2)`,
|
||||
}}
|
||||
/>
|
||||
</div>
|
||||
)
|
||||
}
|
||||
|
||||
export default Downloading
|
||||
21
BillNote_frontend/src/components/Lottie/error.tsx
Normal file
21
BillNote_frontend/src/components/Lottie/error.tsx
Normal file
@@ -0,0 +1,21 @@
|
||||
import { FC } from 'react'
|
||||
import Lottie from 'lottie-react'
|
||||
import error from '@/assets/Lottie/error.json'
|
||||
|
||||
const Error: FC = () => {
|
||||
return (
|
||||
<div className="flex items-center justify-center">
|
||||
<Lottie
|
||||
animationData={error}
|
||||
loop
|
||||
autoplay
|
||||
style={{
|
||||
width: 450,
|
||||
height: 450,
|
||||
}}
|
||||
/>
|
||||
</div>
|
||||
)
|
||||
}
|
||||
|
||||
export default Error
|
||||
64
BillNote_frontend/src/components/ui/alert.tsx
Normal file
64
BillNote_frontend/src/components/ui/alert.tsx
Normal file
@@ -0,0 +1,64 @@
|
||||
import * as React from 'react'
|
||||
import { cva, type VariantProps } from 'class-variance-authority'
|
||||
|
||||
import { cn } from '@/lib/utils'
|
||||
|
||||
const alertVariants = cva(
|
||||
'relative w-full rounded-lg border px-4 py-3 text-sm grid has-[>svg]:grid-cols-[calc(var(--spacing)*4)_1fr] grid-cols-[0_1fr] has-[>svg]:gap-x-3 gap-y-0.5 items-start [&>svg]:size-4 [&>svg]:translate-y-0.5 [&>svg]:text-current',
|
||||
{
|
||||
variants: {
|
||||
variant: {
|
||||
default: 'bg-card text-card-foreground',
|
||||
destructive:
|
||||
'text-destructive bg-card [&>svg]:text-current *:data-[slot=alert-description]:text-destructive/90',
|
||||
success:
|
||||
'text-success bg-card [&>svg]:text-current *:data-[slot=alert-description]:text-success/90',
|
||||
warning:
|
||||
'text-[#303133] bg-[#FEF0F0] [&>svg]:text-current *:data-[slot=alert-description]:text-warning/90',
|
||||
},
|
||||
},
|
||||
defaultVariants: {
|
||||
variant: 'default',
|
||||
},
|
||||
}
|
||||
)
|
||||
|
||||
function Alert({
|
||||
className,
|
||||
variant,
|
||||
...props
|
||||
}: React.ComponentProps<'div'> & VariantProps<typeof alertVariants>) {
|
||||
return (
|
||||
<div
|
||||
data-slot="alert"
|
||||
role="alert"
|
||||
className={cn(alertVariants({ variant }), className)}
|
||||
{...props}
|
||||
/>
|
||||
)
|
||||
}
|
||||
|
||||
function AlertTitle({ className, ...props }: React.ComponentProps<'div'>) {
|
||||
return (
|
||||
<div
|
||||
data-slot="alert-title"
|
||||
className={cn('col-start-2 line-clamp-1 min-h-4 font-medium tracking-tight', className)}
|
||||
{...props}
|
||||
/>
|
||||
)
|
||||
}
|
||||
|
||||
function AlertDescription({ className, ...props }: React.ComponentProps<'div'>) {
|
||||
return (
|
||||
<div
|
||||
data-slot="alert-description"
|
||||
className={cn(
|
||||
'text-muted-foreground col-start-2 grid justify-items-start gap-1 text-sm [&_p]:leading-relaxed',
|
||||
className
|
||||
)}
|
||||
{...props}
|
||||
/>
|
||||
)
|
||||
}
|
||||
|
||||
export { Alert, AlertTitle, AlertDescription }
|
||||
29
BillNote_frontend/src/components/ui/switch.tsx
Normal file
29
BillNote_frontend/src/components/ui/switch.tsx
Normal file
@@ -0,0 +1,29 @@
|
||||
import * as React from "react"
|
||||
import * as SwitchPrimitive from "@radix-ui/react-switch"
|
||||
|
||||
import { cn } from "@/lib/utils"
|
||||
|
||||
function Switch({
|
||||
className,
|
||||
...props
|
||||
}: React.ComponentProps<typeof SwitchPrimitive.Root>) {
|
||||
return (
|
||||
<SwitchPrimitive.Root
|
||||
data-slot="switch"
|
||||
className={cn(
|
||||
"peer data-[state=checked]:bg-primary data-[state=unchecked]:bg-input focus-visible:border-ring focus-visible:ring-ring/50 dark:data-[state=unchecked]:bg-input/80 inline-flex h-[1.15rem] w-8 shrink-0 items-center rounded-full border border-transparent shadow-xs transition-all outline-none focus-visible:ring-[3px] disabled:cursor-not-allowed disabled:opacity-50",
|
||||
className
|
||||
)}
|
||||
{...props}
|
||||
>
|
||||
<SwitchPrimitive.Thumb
|
||||
data-slot="switch-thumb"
|
||||
className={cn(
|
||||
"bg-background dark:data-[state=unchecked]:bg-foreground dark:data-[state=checked]:bg-primary-foreground pointer-events-none block size-4 rounded-full ring-0 transition-transform data-[state=checked]:translate-x-[calc(100%-2px)] data-[state=unchecked]:translate-x-0"
|
||||
)}
|
||||
/>
|
||||
</SwitchPrimitive.Root>
|
||||
)
|
||||
}
|
||||
|
||||
export { Switch }
|
||||
@@ -1,45 +1,59 @@
|
||||
// hooks/useTaskPolling.ts
|
||||
import { useEffect } from 'react'
|
||||
import { useEffect, useRef } from 'react'
|
||||
import { useTaskStore } from '@/store/taskStore'
|
||||
import { get_task_status } from '@/services/note.ts'
|
||||
import toast from 'react-hot-toast'
|
||||
|
||||
export const useTaskPolling = (interval = 3000) => {
|
||||
const tasks = useTaskStore(state => state.tasks)
|
||||
const updateTaskContent = useTaskStore(state => state.updateTaskContent)
|
||||
const updateTaskStatus = useTaskStore(state => state.updateTaskStatus)
|
||||
const removeTask = useTaskStore(state => state.removeTask)
|
||||
|
||||
const tasksRef = useRef(tasks)
|
||||
|
||||
// 每次 tasks 更新,把最新的 tasks 同步进去
|
||||
useEffect(() => {
|
||||
tasksRef.current = tasks
|
||||
}, [tasks])
|
||||
|
||||
useEffect(() => {
|
||||
const timer = setInterval(async () => {
|
||||
const pendingTasks = tasks.filter(
|
||||
task => task.status === 'PENDING' || task.status === 'running'
|
||||
const pendingTasks = tasksRef.current.filter(
|
||||
task => task.status != 'SUCCESS' && task.status != 'FAILED'
|
||||
)
|
||||
|
||||
for (const task of pendingTasks) {
|
||||
try {
|
||||
console.log(task)
|
||||
console.log('🔄 正在轮询任务:', task.id)
|
||||
const res = await get_task_status(task.id)
|
||||
const { status } = res.data
|
||||
|
||||
if (status && status !== task.status) {
|
||||
if (status === 'SUCCESS') {
|
||||
const { markdown, transcript, audio_meta } = res.data.result
|
||||
|
||||
toast.success('笔记生成成功')
|
||||
updateTaskContent(task.id, {
|
||||
status,
|
||||
markdown,
|
||||
transcript,
|
||||
audioMeta: audio_meta,
|
||||
})
|
||||
} else if (status === 'FAILED') {
|
||||
updateTaskContent(task.id, { status })
|
||||
console.warn(`⚠️ 任务 ${task.id} 失败`)
|
||||
} else {
|
||||
updateTaskStatus(task.id, status)
|
||||
updateTaskContent(task.id, { status })
|
||||
}
|
||||
}
|
||||
} catch (e) {
|
||||
console.error('❌ 任务轮询失败:', e)
|
||||
removeTask(task.id)
|
||||
toast.error(`生成失败 ${e.message || e}`)
|
||||
updateTaskContent(task.id, { status: 'FAILED' })
|
||||
// removeTask(task.id)
|
||||
}
|
||||
}
|
||||
}, interval)
|
||||
|
||||
return () => clearInterval(timer)
|
||||
}, [interval, tasks])
|
||||
}, [interval])
|
||||
}
|
||||
|
||||
@@ -3,7 +3,7 @@ import HomeLayout from '@/layouts/HomeLayout.tsx'
|
||||
import NoteForm from '@/pages/HomePage/components/NoteForm.tsx'
|
||||
import MarkdownViewer from '@/pages/HomePage/components/MarkdownViewer.tsx'
|
||||
import { useTaskStore } from '@/store/taskStore'
|
||||
type ViewStatus = 'idle' | 'loading' | 'success'
|
||||
type ViewStatus = 'idle' | 'loading' | 'success' | 'failed'
|
||||
export const HomePage: FC = () => {
|
||||
const tasks = useTaskStore(state => state.tasks)
|
||||
const currentTaskId = useTaskStore(state => state.currentTaskId)
|
||||
@@ -21,6 +21,8 @@ export const HomePage: FC = () => {
|
||||
setStatus('loading')
|
||||
} else if (currentTask.status === 'SUCCESS') {
|
||||
setStatus('success')
|
||||
} else if (currentTask.status === 'FAILED') {
|
||||
setStatus('failed')
|
||||
}
|
||||
}, [currentTask])
|
||||
|
||||
|
||||
@@ -3,7 +3,7 @@ import ReactMarkdown from 'react-markdown'
|
||||
import { Button } from '@/components/ui/button.tsx'
|
||||
import { Copy, Download, FileText, ArrowRight } from 'lucide-react'
|
||||
import { toast } from 'sonner' // 你可以换成自己的通知组件
|
||||
|
||||
import Error from '@/components/Lottie/error.tsx'
|
||||
import { Prism as SyntaxHighlighter } from 'react-syntax-highlighter'
|
||||
import { solarizedlight as codeStyle } from 'react-syntax-highlighter/dist/cjs/styles/prism'
|
||||
import 'github-markdown-css/github-markdown-light.css'
|
||||
@@ -11,14 +11,26 @@ import { FC } from 'react'
|
||||
import Loading from '@/components/Lottie/Loading.tsx'
|
||||
import Idle from '@/components/Lottie/Idle.tsx'
|
||||
import { useTaskStore } from '@/store/taskStore'
|
||||
import StepBar from '@/pages/HomePage/components/StepBar.tsx'
|
||||
interface MarkdownViewerProps {
|
||||
content: string
|
||||
status: 'idle' | 'loading' | 'success'
|
||||
status: 'idle' | 'loading' | 'success' | 'failed'
|
||||
}
|
||||
|
||||
const steps = [
|
||||
{ label: '解析链接', key: 'PARSING' },
|
||||
{ label: '下载音频', key: 'DOWNLOADING' },
|
||||
{ label: '转写文字', key: 'TRANSCRIBING' },
|
||||
{ label: '总结内容', key: 'SUMMARIZING' },
|
||||
{ label: '保存完成', key: 'SUCCESS' },
|
||||
]
|
||||
|
||||
const MarkdownViewer: FC<MarkdownViewerProps> = ({ content, status }) => {
|
||||
const [copied, setCopied] = useState(false)
|
||||
const getCurrentTask = useTaskStore.getState().getCurrentTask
|
||||
const currentTask = useTaskStore(state => state.getCurrentTask())
|
||||
const taskStatus = currentTask?.status || 'PENDING'
|
||||
const retryTask = useTaskStore.getState().retryTask
|
||||
const handleCopy = async () => {
|
||||
try {
|
||||
await navigator.clipboard.writeText(content)
|
||||
@@ -34,6 +46,7 @@ const MarkdownViewer: FC<MarkdownViewerProps> = ({ content, status }) => {
|
||||
const handleDownload = () => {
|
||||
const currentTask = getCurrentTask()
|
||||
const currentTaskName = currentTask?.audioMeta.title
|
||||
|
||||
const blob = new Blob([content], { type: 'text/markdown;charset=utf-8' })
|
||||
const link = document.createElement('a')
|
||||
link.href = URL.createObjectURL(blob)
|
||||
@@ -45,6 +58,7 @@ const MarkdownViewer: FC<MarkdownViewerProps> = ({ content, status }) => {
|
||||
if (status === 'loading') {
|
||||
return (
|
||||
<div className="flex h-screen w-full flex-col items-center justify-center space-y-4 text-neutral-500">
|
||||
<StepBar steps={steps} currentStep={taskStatus} />
|
||||
<Loading className="h-5 w-5" />
|
||||
<div className="text-center text-sm">
|
||||
<p className="text-lg font-bold">正在生成笔记,请稍候…</p>
|
||||
@@ -63,6 +77,24 @@ const MarkdownViewer: FC<MarkdownViewerProps> = ({ content, status }) => {
|
||||
</div>
|
||||
</div>
|
||||
)
|
||||
} else if (status === 'failed') {
|
||||
return (
|
||||
<div className="flex h-screen w-full flex-col items-center justify-center gap-4 space-y-3">
|
||||
<Error /> {/* 你可以换成 Failed 动画 */}
|
||||
<div className="text-center">
|
||||
<p className="text-lg font-bold text-red-500">笔记生成失败</p>
|
||||
<p className="mt-2 mb-2 text-xs text-red-400">请检查后台或稍后再试</p>
|
||||
<Button
|
||||
onClick={() => {
|
||||
retryTask(currentTask.id)
|
||||
}}
|
||||
size="lg"
|
||||
>
|
||||
重试
|
||||
</Button>
|
||||
</div>
|
||||
</div>
|
||||
)
|
||||
}
|
||||
|
||||
return (
|
||||
|
||||
@@ -6,6 +6,7 @@ import {
|
||||
FormLabel,
|
||||
FormMessage,
|
||||
} from '@/components/ui/form.tsx'
|
||||
import { useEffect } from 'react'
|
||||
import { Input } from '@/components/ui/input.tsx'
|
||||
import {
|
||||
Select,
|
||||
@@ -30,7 +31,9 @@ import {
|
||||
import { generateNote } from '@/services/note.ts'
|
||||
import { useTaskStore } from '@/store/taskStore'
|
||||
import NoteHistory from '@/pages/HomePage/components/NoteHistory.tsx'
|
||||
|
||||
import { useModelStore } from '@/store/modelStore'
|
||||
import { Alert } from 'antd'
|
||||
import { Textarea } from '@/components/ui/textarea.tsx'
|
||||
// ✅ 定义表单 schema
|
||||
const formSchema = z.object({
|
||||
video_url: z.string().url('请输入正确的视频链接'),
|
||||
@@ -40,15 +43,70 @@ const formSchema = z.object({
|
||||
}),
|
||||
screenshot: z.boolean().optional(),
|
||||
link: z.boolean().optional(),
|
||||
model_name: z.string().nonempty('请选择模型'),
|
||||
format: z.array(z.string()).default([]), // ✨ 确保默认是空数组
|
||||
style: z.string().nonempty('请选择笔记生成风格'),
|
||||
extras: z.string().optional(),
|
||||
})
|
||||
|
||||
type NoteFormValues = z.infer<typeof formSchema>
|
||||
const noteFormats = [
|
||||
{
|
||||
label: '目录',
|
||||
value: 'toc',
|
||||
},
|
||||
{ label: '原片跳转', value: 'link' },
|
||||
{ label: '原片截图', value: 'screenshot' },
|
||||
{ label: 'AI总结', value: 'summary' },
|
||||
]
|
||||
const noteStyles = [
|
||||
{
|
||||
label: '精简',
|
||||
value: 'minimal', // 简洁、快速呈现要点
|
||||
},
|
||||
{
|
||||
label: '详细',
|
||||
value: 'detailed', // 详细记录,包含时间戳、关键点
|
||||
},
|
||||
{
|
||||
label: '教程',
|
||||
value: 'tutorial', // 详细记录,包含时间戳、关键点
|
||||
},
|
||||
{
|
||||
label: '学术',
|
||||
value: 'academic', // 适合学术报告,正式且结构化
|
||||
},
|
||||
{
|
||||
label: '小红书',
|
||||
value: 'xiaohongshu', // 适合社交平台分享,亲切、口语化
|
||||
},
|
||||
{
|
||||
label: '生活向',
|
||||
value: 'life_journal', // 记录个人生活感悟,情感化表达
|
||||
},
|
||||
{
|
||||
label: '任务导向',
|
||||
value: 'task_oriented', // 强调任务、目标,适合工作和待办事项
|
||||
},
|
||||
{
|
||||
label: '商业风格',
|
||||
value: 'business', // 适合商业报告、会议纪要,正式且精准
|
||||
},
|
||||
{
|
||||
label: '会议纪要',
|
||||
value: 'meeting_minutes', // 适合商业报告、会议纪要,正式且精准
|
||||
},
|
||||
]
|
||||
|
||||
const NoteForm = () => {
|
||||
useTaskStore(state => state.tasks)
|
||||
const setCurrentTask = useTaskStore(state => state.setCurrentTask)
|
||||
const currentTaskId = useTaskStore(state => state.currentTaskId)
|
||||
const getCurrentTask = useTaskStore(state => state.getCurrentTask)
|
||||
const loadEnabledModels = useModelStore(state => state.loadEnabledModels)
|
||||
const modelList = useModelStore(state => state.modelList)
|
||||
const showFeatureHint = useModelStore(state => state.showFeatureHint)
|
||||
const setShowFeatureHint = useModelStore(state => state.setShowFeatureHint)
|
||||
const form = useForm<NoteFormValues>({
|
||||
resolver: zodResolver(formSchema),
|
||||
defaultValues: {
|
||||
@@ -56,9 +114,16 @@ const NoteForm = () => {
|
||||
platform: 'bilibili',
|
||||
quality: 'medium', // 默认中等质量
|
||||
screenshot: false,
|
||||
model_name: modelList[0]?.model_name || '', // 确保有值
|
||||
format: [], // 初始化为空数组
|
||||
style: 'minimal', // 默认选择精简风格
|
||||
extras: '', // 初始化为空字符串
|
||||
},
|
||||
})
|
||||
|
||||
const onClose = () => {
|
||||
setShowFeatureHint(false)
|
||||
}
|
||||
const isGenerating = () => {
|
||||
console.log('🚀 isGenerating', getCurrentTask()?.status)
|
||||
return getCurrentTask()?.status === 'PENDING'
|
||||
@@ -66,14 +131,23 @@ const NoteForm = () => {
|
||||
|
||||
const onSubmit = async (data: NoteFormValues) => {
|
||||
console.log('🎯 提交内容:', data)
|
||||
await generateNote({
|
||||
const payload = {
|
||||
video_url: data.video_url,
|
||||
platform: data.platform,
|
||||
quality: data.quality,
|
||||
screenshot: data.screenshot,
|
||||
link: data.link,
|
||||
})
|
||||
model_name: data.model_name,
|
||||
provider_id: modelList.find(model => model.model_name === data.model_name).provider_id,
|
||||
format: data.format,
|
||||
style: data.style,
|
||||
extras: data.extras,
|
||||
}
|
||||
const res = await generateNote(payload)
|
||||
const taskId = res.data.task_id
|
||||
useTaskStore.getState().addPendingTask(taskId, data.platform, payload)
|
||||
}
|
||||
useEffect(() => {
|
||||
loadEnabledModels()
|
||||
}, [])
|
||||
|
||||
return (
|
||||
<div className="flex h-full flex-col">
|
||||
@@ -173,48 +247,157 @@ const NoteForm = () => {
|
||||
</FormItem>
|
||||
)}
|
||||
/>
|
||||
|
||||
<FormField
|
||||
control={form.control}
|
||||
name="model_name"
|
||||
render={({ field }) => (
|
||||
<FormItem>
|
||||
<div className="my-3 flex items-center justify-between">
|
||||
<h2 className="block">模型选择</h2>
|
||||
<TooltipProvider>
|
||||
<Tooltip>
|
||||
<TooltipTrigger asChild>
|
||||
<Info className="hover:text-primary h-4 w-4 cursor-pointer text-neutral-400" />
|
||||
</TooltipTrigger>
|
||||
<TooltipContent>
|
||||
<p className="max-w-[200px] text-xs">不同模型返回质量不同,可自行测试</p>
|
||||
</TooltipContent>
|
||||
</Tooltip>
|
||||
</TooltipProvider>
|
||||
</div>
|
||||
<Select onValueChange={field.onChange} defaultValue={field.value}>
|
||||
<FormControl>
|
||||
<SelectTrigger className="w-full">
|
||||
<SelectValue placeholder="选择配置好的模型" />
|
||||
</SelectTrigger>
|
||||
</FormControl>
|
||||
<SelectContent>
|
||||
{modelList.map(item => {
|
||||
return <SelectItem value={item.model_name}>{item.model_name}</SelectItem>
|
||||
})}
|
||||
</SelectContent>
|
||||
</Select>
|
||||
{/*<FormDescription className="text-xs text-neutral-500">*/}
|
||||
{/* 质量越高,下载体积越大,速度越慢*/}
|
||||
{/*</FormDescription>*/}
|
||||
<FormMessage />
|
||||
</FormItem>
|
||||
)}
|
||||
/>
|
||||
</div>
|
||||
|
||||
{/* 是否需要原片位置 */}
|
||||
<FormField
|
||||
control={form.control}
|
||||
name="link"
|
||||
name="style"
|
||||
render={({ field }) => (
|
||||
<FormItem className="flex items-center space-x-2">
|
||||
{/* Tooltip 部分 */}
|
||||
<FormItem>
|
||||
<div className="my-3 flex items-center justify-between">
|
||||
<h2 className="block">笔记风格</h2>
|
||||
<TooltipProvider>
|
||||
<Tooltip>
|
||||
<TooltipTrigger asChild>
|
||||
<Info className="hover:text-primary h-4 w-4 cursor-pointer text-neutral-400" />
|
||||
</TooltipTrigger>
|
||||
<TooltipContent>
|
||||
<p className="max-w-[200px] text-xs">选择你希望生成的笔记风格</p>
|
||||
</TooltipContent>
|
||||
</Tooltip>
|
||||
</TooltipProvider>
|
||||
</div>
|
||||
|
||||
<FormControl>
|
||||
<Checkbox checked={field.value} onCheckedChange={field.onChange} id="link" />
|
||||
</FormControl>
|
||||
<Select onValueChange={field.onChange} defaultValue={field.value}>
|
||||
<FormControl>
|
||||
<SelectTrigger className="w-full">
|
||||
<SelectValue placeholder="选择笔记风格" />
|
||||
</SelectTrigger>
|
||||
</FormControl>
|
||||
<SelectContent>
|
||||
{noteStyles.map(item => (
|
||||
<SelectItem key={item.value} value={item.value}>
|
||||
{item.label}
|
||||
</SelectItem>
|
||||
))}
|
||||
</SelectContent>
|
||||
</Select>
|
||||
|
||||
<FormLabel htmlFor="link" className="text-sm leading-none font-medium">
|
||||
是否插入内容跳转链接
|
||||
</FormLabel>
|
||||
<FormMessage />
|
||||
</FormItem>
|
||||
)}
|
||||
/>
|
||||
{/* 是否需要下载 */}
|
||||
|
||||
<FormField
|
||||
control={form.control}
|
||||
name="screenshot"
|
||||
name="format"
|
||||
render={({ field }) => (
|
||||
<FormItem className="flex items-center space-x-2">
|
||||
{/* Tooltip 部分 */}
|
||||
<FormItem>
|
||||
<div className="my-3 flex items-center justify-between">
|
||||
<h2 className="block">笔记格式</h2>
|
||||
<TooltipProvider>
|
||||
<Tooltip>
|
||||
<TooltipTrigger asChild>
|
||||
<Info className="hover:text-primary h-4 w-4 cursor-pointer text-neutral-400" />
|
||||
</TooltipTrigger>
|
||||
<TooltipContent>
|
||||
<p className="text-xs">选择要包含的笔记元素,比如时间戳、截图提示或总结</p>
|
||||
</TooltipContent>
|
||||
</Tooltip>
|
||||
</TooltipProvider>
|
||||
</div>
|
||||
|
||||
<FormControl>
|
||||
<Checkbox
|
||||
checked={field.value}
|
||||
onCheckedChange={field.onChange}
|
||||
id="screenshot"
|
||||
/>
|
||||
<div className="flex space-x-1.5">
|
||||
{noteFormats.map(item => (
|
||||
<label key={item.value} className="flex items-center space-x-2">
|
||||
<Checkbox
|
||||
checked={field.value?.includes(item.value)}
|
||||
onCheckedChange={checked => {
|
||||
const currentValue = field.value ?? [] // ✨ 保底是数组
|
||||
if (checked) {
|
||||
field.onChange([...currentValue, item.value])
|
||||
} else {
|
||||
field.onChange(currentValue.filter(v => v !== item.value))
|
||||
}
|
||||
}}
|
||||
/>
|
||||
<span>{item.label}</span>
|
||||
</label>
|
||||
))}
|
||||
</div>
|
||||
</FormControl>
|
||||
<FormField
|
||||
control={form.control}
|
||||
name="extras"
|
||||
render={({ field }) => (
|
||||
<FormItem>
|
||||
<div className="my-3 flex items-center justify-between">
|
||||
<h2 className="block">备注</h2>
|
||||
<TooltipProvider>
|
||||
<Tooltip>
|
||||
<TooltipTrigger asChild>
|
||||
<Info className="hover:text-primary h-4 w-4 cursor-pointer text-neutral-400" />
|
||||
</TooltipTrigger>
|
||||
<TooltipContent>
|
||||
<p className="text-xs">会把这段加入到Prompt最后 可自行测试</p>
|
||||
</TooltipContent>
|
||||
</Tooltip>
|
||||
</TooltipProvider>
|
||||
</div>
|
||||
<Textarea placeholder={'笔记需要罗列出 xxx 关键点'} />
|
||||
|
||||
<FormLabel htmlFor="screenshot" className="text-sm leading-none font-medium">
|
||||
是否插入视频截图
|
||||
</FormLabel>
|
||||
{/*<FormDescription className="text-xs text-neutral-500">*/}
|
||||
{/* 质量越高,下载体积越大,速度越慢*/}
|
||||
{/*</FormDescription>*/}
|
||||
<FormMessage />
|
||||
</FormItem>
|
||||
)}
|
||||
/>
|
||||
|
||||
<FormMessage />
|
||||
</FormItem>
|
||||
)}
|
||||
/>
|
||||
|
||||
<div className={'flex w-full items-center gap-2 py-1.5'}>
|
||||
{/* 提交按钮 */}
|
||||
<Button type="submit" className="bg-primary w-full" disabled={isGenerating()}>
|
||||
@@ -235,27 +418,35 @@ const NoteForm = () => {
|
||||
</div>
|
||||
|
||||
{/* 添加一些额外的说明或功能介绍 */}
|
||||
<div className="bg-primary-light mt-6 rounded-lg p-4">
|
||||
<h3 className="text-primary mb-2 font-medium">功能介绍</h3>
|
||||
<ul className="space-y-2 text-sm text-neutral-600">
|
||||
<li className="flex items-start gap-2">
|
||||
<span className="text-primary font-bold">•</span>
|
||||
<span>自动提取视频内容,生成结构化笔记</span>
|
||||
</li>
|
||||
<li className="flex items-start gap-2">
|
||||
<span className="text-primary font-bold">•</span>
|
||||
<span>支持多个视频平台,包括哔哩哔哩、YouTube等</span>
|
||||
</li>
|
||||
<li className="flex items-start gap-2">
|
||||
<span className="text-primary font-bold">•</span>
|
||||
<span>一键复制笔记,支持Markdown格式</span>
|
||||
</li>
|
||||
<li className="flex items-start gap-2">
|
||||
<span className="text-primary font-bold">•</span>
|
||||
<span>可选择是否插入图片</span>
|
||||
</li>
|
||||
</ul>
|
||||
</div>
|
||||
{showFeatureHint && (
|
||||
<Alert
|
||||
message="功能介绍 v2.0.0"
|
||||
description={
|
||||
<ul className="space-y-2 text-sm text-neutral-600">
|
||||
<li className="flex items-start gap-2">
|
||||
<span className="text-primary font-bold">•</span>
|
||||
<span>自动提取视频内容,生成结构化笔记</span>
|
||||
</li>
|
||||
<li className="flex items-start gap-2">
|
||||
<span className="text-primary font-bold">•</span>
|
||||
<span>支持多个视频平台,包括哔哩哔哩、YouTube等</span>
|
||||
</li>
|
||||
<li className="flex items-start gap-2">
|
||||
<span className="text-primary font-bold">•</span>
|
||||
<span>一键复制笔记,支持Markdown格式</span>
|
||||
</li>
|
||||
<li className="flex items-start gap-2">
|
||||
<span className="text-primary font-bold">•</span>
|
||||
<span>可选择是否插入图片</span>
|
||||
</li>
|
||||
</ul>
|
||||
}
|
||||
type="info"
|
||||
onClose={onClose}
|
||||
closable
|
||||
/>
|
||||
)}
|
||||
{/*<div className="bg-primary-light mt-6 rounded-lg p-4"></div>*/}
|
||||
</div>
|
||||
)
|
||||
}
|
||||
|
||||
54
BillNote_frontend/src/pages/HomePage/components/StepBar.tsx
Normal file
54
BillNote_frontend/src/pages/HomePage/components/StepBar.tsx
Normal file
@@ -0,0 +1,54 @@
|
||||
import { FC } from 'react'
|
||||
|
||||
interface Step {
|
||||
label: string
|
||||
key: string
|
||||
Icon?: React.ReactNode // 加一个可选的 Lottie 动画
|
||||
}
|
||||
|
||||
interface StepBarProps {
|
||||
steps: Step[]
|
||||
currentStep: string
|
||||
}
|
||||
|
||||
const StepBar: FC<StepBarProps> = ({ steps, currentStep }) => {
|
||||
const currentIndex = steps.findIndex(step => step.key === currentStep)
|
||||
|
||||
return (
|
||||
<div className="flex w-full items-center justify-between">
|
||||
{steps.map((step, index) => {
|
||||
const isActive = index <= currentIndex
|
||||
const isCurrent = index === currentIndex
|
||||
const isLast = index === steps.length - 1
|
||||
return (
|
||||
<div key={step.key} className="relative flex flex-1 flex-col items-center">
|
||||
{/* 圆圈或者Lottie */}
|
||||
<div className="relative flex flex-col items-center justify-center">
|
||||
<div
|
||||
className={`flex h-8 w-8 items-center justify-center rounded-full text-xs font-bold ${
|
||||
isActive ? 'bg-primary text-white' : 'bg-gray-300 text-gray-600'
|
||||
}`}
|
||||
>
|
||||
{index + 1}
|
||||
</div>
|
||||
{/* 当前步骤显示动画 */}
|
||||
{isCurrent && step.Icon && (
|
||||
<div className="absolute top-10 h-16 w-16">{step.Icon}</div>
|
||||
)}
|
||||
</div>
|
||||
|
||||
{/* 步骤名称 */}
|
||||
<div className="mt-4 text-center text-xs text-gray-700">{step.label}</div>
|
||||
|
||||
{/* 连接线 */}
|
||||
{!isLast && (
|
||||
<div className={`h-1 w-full ${isActive ? 'bg-primary' : 'bg-gray-300'}`}></div>
|
||||
)}
|
||||
</div>
|
||||
)
|
||||
})}
|
||||
</div>
|
||||
)
|
||||
}
|
||||
|
||||
export default StepBar
|
||||
@@ -9,26 +9,27 @@ const Menu = () => {
|
||||
icon: <BotMessageSquare />,
|
||||
path: '/settings/model',
|
||||
},
|
||||
{
|
||||
id: ' transcriber',
|
||||
name: '音频转译配置',
|
||||
icon: <Captions />,
|
||||
path: '/settings/transcriber',
|
||||
},
|
||||
//下载配置
|
||||
{
|
||||
id: 'download',
|
||||
name: '下载配置',
|
||||
icon: <HardDriveDownload />,
|
||||
path: '/settings/download',
|
||||
},
|
||||
//其他配置
|
||||
{
|
||||
id: 'other',
|
||||
name: '其他配置',
|
||||
icon: <Wrench />,
|
||||
path: '/settings/other',
|
||||
},
|
||||
// TODO :下一版本升级优化
|
||||
// {
|
||||
// id: ' transcriber',
|
||||
// name: '音频转译配置',
|
||||
// icon: <Captions />,
|
||||
// path: '/settings/transcriber',
|
||||
// },
|
||||
// //下载配置
|
||||
// {
|
||||
// id: 'download',
|
||||
// name: '下载配置',
|
||||
// icon: <HardDriveDownload />,
|
||||
// path: '/settings/download',
|
||||
// },
|
||||
// //其他配置
|
||||
// {
|
||||
// id: 'other',
|
||||
// name: '其他配置',
|
||||
// icon: <Wrench />,
|
||||
// path: '/settings/other',
|
||||
// },
|
||||
]
|
||||
return (
|
||||
<div className="flex h-full flex-col">
|
||||
|
||||
8
BillNote_frontend/src/pages/SettingPage/transcriber.tsx
Normal file
8
BillNote_frontend/src/pages/SettingPage/transcriber.tsx
Normal file
@@ -0,0 +1,8 @@
|
||||
const Transcriber = () => {
|
||||
return (
|
||||
<div className="flex h-screen w-full flex-col items-center justify-center">
|
||||
<h1 className="text-center text-4xl font-bold">Transcriber is under development</h1>
|
||||
</div>
|
||||
)
|
||||
}
|
||||
export default Transcriber
|
||||
@@ -3,3 +3,29 @@ import request from '@/utils/request.ts'
|
||||
export const getProviderList = async () => {
|
||||
return await request.get('/get_all_providers')
|
||||
}
|
||||
export const getProviderById = async (id: string) => {
|
||||
return await request.get(`/get_provider_by_id/${id}`)
|
||||
}
|
||||
export const updateProviderById = async (data: any) => {
|
||||
return await request.post('/update_provider', data)
|
||||
}
|
||||
|
||||
export const addProvider = async (data: any) => {
|
||||
return await request.post('/add_provider', data)
|
||||
}
|
||||
|
||||
export const testConnection = async (data: any) => {
|
||||
return await request.post('/connect_test', data)
|
||||
}
|
||||
|
||||
export const fetchModels = async (providerId: any) => {
|
||||
return await request.get('/model_list/' + providerId)
|
||||
}
|
||||
|
||||
export async function addModel(data: { provider_id: string; model_name: string }) {
|
||||
return request.post('/models', data)
|
||||
}
|
||||
|
||||
export const fetchEnableModels = async () => {
|
||||
return await request.get('/model_list')
|
||||
}
|
||||
|
||||
@@ -4,10 +4,14 @@ import { useTaskStore } from '@/store/taskStore'
|
||||
import request from '@/utils/request'
|
||||
export const generateNote = async (data: {
|
||||
video_url: string
|
||||
link: undefined | boolean
|
||||
screenshot: undefined | boolean
|
||||
platform: string
|
||||
quality: string
|
||||
model_name: string
|
||||
provider_id: string
|
||||
task_id?: string
|
||||
format: Array<string>
|
||||
style: string
|
||||
extras?: string
|
||||
}) => {
|
||||
try {
|
||||
const response = await request.post('/generate_note', data)
|
||||
@@ -20,11 +24,8 @@ export const generateNote = async (data: {
|
||||
}
|
||||
toast.success('笔记生成任务已提交!')
|
||||
|
||||
const taskId = response.data.data.task_id
|
||||
|
||||
console.log('res', response)
|
||||
// 成功提示
|
||||
useTaskStore.getState().addPendingTask(taskId, data.platform)
|
||||
|
||||
return response.data
|
||||
} catch (e: any) {
|
||||
|
||||
25
BillNote_frontend/src/store/configStore/index.ts
Normal file
25
BillNote_frontend/src/store/configStore/index.ts
Normal file
@@ -0,0 +1,25 @@
|
||||
import { create } from 'zustand'
|
||||
import { persist } from 'zustand/middleware'
|
||||
interface SystemState {
|
||||
showFeatureHint: boolean // ✅ 是否显示功能提示
|
||||
setShowFeatureHint: (value: boolean) => void
|
||||
|
||||
// 后续如果有其他全局状态,可以继续加
|
||||
sidebarCollapsed: boolean // ✅ 侧边栏是否收起
|
||||
setSidebarCollapsed: (value: boolean) => void
|
||||
}
|
||||
// 暂不启用
|
||||
export const useSystemStore = create<SystemState>()(
|
||||
persist(
|
||||
set => ({
|
||||
showFeatureHint: true,
|
||||
setShowFeatureHint: value => set({ showFeatureHint: value }),
|
||||
|
||||
sidebarCollapsed: false,
|
||||
setSidebarCollapsed: value => set({ sidebarCollapsed: value }),
|
||||
}),
|
||||
{
|
||||
name: 'system-store', // 本地存储的 key
|
||||
}
|
||||
)
|
||||
)
|
||||
101
BillNote_frontend/src/store/modelStore/index.ts
Normal file
101
BillNote_frontend/src/store/modelStore/index.ts
Normal file
@@ -0,0 +1,101 @@
|
||||
import { create } from 'zustand'
|
||||
import { devtools } from 'zustand/middleware'
|
||||
import { fetchModels, addModel, fetchEnableModels } from '@/services/model.ts'
|
||||
|
||||
interface IModel {
|
||||
id: string
|
||||
created: number
|
||||
object: string
|
||||
owned_by: string
|
||||
permission: string
|
||||
root: string
|
||||
}
|
||||
|
||||
interface ModelStore {
|
||||
models: IModel[]
|
||||
modelList: []
|
||||
loading: boolean
|
||||
selectedModel: string
|
||||
loadModels: (providerId: string) => Promise<void>
|
||||
loadEnabledModels: () => Promise<void>
|
||||
addNewModel: (providerId: string, modelId: string) => Promise<void>
|
||||
setSelectedModel: (modelId: string) => void
|
||||
clearModels: () => void
|
||||
}
|
||||
|
||||
export const useModelStore = create<ModelStore>()(
|
||||
devtools(set => ({
|
||||
models: [],
|
||||
loading: false,
|
||||
selectedModel: '',
|
||||
modelList: [],
|
||||
|
||||
loadEnabledModels: async () => {
|
||||
try {
|
||||
set({ loading: true })
|
||||
const res = await fetchEnableModels()
|
||||
if (res.data.code === 0 && res.data.data.length > 0) {
|
||||
set({ modelList: res.data.data })
|
||||
} else {
|
||||
set({ modelList: [] })
|
||||
console.error('模型列表加载失败')
|
||||
}
|
||||
} catch (error) {
|
||||
set({ modelList: [] })
|
||||
console.error('加载模型出错', error)
|
||||
}
|
||||
},
|
||||
// 加载模型列表
|
||||
loadModels: async (providerId: string) => {
|
||||
try {
|
||||
set({ loading: true })
|
||||
const res = await fetchModels(providerId)
|
||||
if (res.data.code === 0 && res.data.data.models.data.length > 0) {
|
||||
set({ models: res.data.data.models.data })
|
||||
} else {
|
||||
set({ models: [] })
|
||||
console.error('模型列表加载失败')
|
||||
}
|
||||
} catch (error) {
|
||||
set({ models: [] })
|
||||
console.error('加载模型出错', error)
|
||||
} finally {
|
||||
set({ loading: false })
|
||||
}
|
||||
},
|
||||
|
||||
// 新增模型
|
||||
addNewModel: async (providerId: string, modelId: string) => {
|
||||
try {
|
||||
const res = await addModel({ provider_id: providerId, model_name: modelId })
|
||||
if (res.code === 0) {
|
||||
console.log('新增模型成功:', modelId)
|
||||
// ✅ 新增成功以后,前端直接追加一条到 models 列表
|
||||
set(state => ({
|
||||
models: [
|
||||
...state.models,
|
||||
{
|
||||
id: modelId,
|
||||
created: Date.now(),
|
||||
object: 'model',
|
||||
owned_by: '',
|
||||
permission: '',
|
||||
root: '',
|
||||
},
|
||||
],
|
||||
}))
|
||||
} else {
|
||||
console.error('新增模型失败')
|
||||
}
|
||||
} catch (error) {
|
||||
console.error('添加模型出错', error)
|
||||
}
|
||||
},
|
||||
|
||||
// 设置选中的模型
|
||||
setSelectedModel: modelId => set({ selectedModel: modelId }),
|
||||
|
||||
// 清空
|
||||
clearModels: () => set({ models: [], selectedModel: '' }),
|
||||
}))
|
||||
)
|
||||
@@ -1,6 +1,11 @@
|
||||
import { create } from 'zustand'
|
||||
import { IProvider } from '@/types'
|
||||
import { getProviderList } from '@/services/model.ts'
|
||||
import {
|
||||
addProvider,
|
||||
getProviderById,
|
||||
getProviderList,
|
||||
updateProviderById,
|
||||
} from '@/services/model.ts'
|
||||
|
||||
interface ProviderStore {
|
||||
provider: IProvider[]
|
||||
@@ -9,12 +14,14 @@ interface ProviderStore {
|
||||
getProviderById: (id: number) => IProvider | undefined
|
||||
getProviderList: () => IProvider[]
|
||||
fetchProviderList: () => Promise<void>
|
||||
loadProviderById: (id: string) => Promise<void>
|
||||
addNewProvider: (provider: IProvider) => Promise<void>
|
||||
updateProvider: (provider: IProvider) => Promise<void>
|
||||
}
|
||||
|
||||
export const useProviderStore = create<ProviderStore>((set, get) => ({
|
||||
provider: [],
|
||||
|
||||
|
||||
// 添加或更新一个 provider
|
||||
setProvider: newProvider =>
|
||||
set(state => {
|
||||
@@ -30,10 +37,60 @@ export const useProviderStore = create<ProviderStore>((set, get) => ({
|
||||
|
||||
// 设置整个 provider 列表
|
||||
setAllProviders: providers => set({ provider: providers }),
|
||||
|
||||
loadProviderById: async (id: string) => {
|
||||
const res = await getProviderById(id)
|
||||
if (res.data.code === 0) {
|
||||
const item = res.data.data
|
||||
console.log('Provider ', item)
|
||||
return {
|
||||
id: item.id,
|
||||
name: item.name,
|
||||
logo: item.logo,
|
||||
apiKey: item.api_key,
|
||||
baseUrl: item.base_url,
|
||||
type: item.type,
|
||||
enabled: item.enabled,
|
||||
}
|
||||
} else {
|
||||
console.log('Provider not found')
|
||||
}
|
||||
},
|
||||
addNewProvider: async (provider: IProvider) => {
|
||||
const payload = {
|
||||
...provider,
|
||||
api_key: provider.apiKey,
|
||||
base_url: provider.baseUrl,
|
||||
}
|
||||
try {
|
||||
const res = await addProvider(payload)
|
||||
if (res.data.code === 0) {
|
||||
const item = res.data.data
|
||||
console.log('Provider ', item)
|
||||
await get().fetchProviderList()
|
||||
}
|
||||
} catch (error) {
|
||||
console.error('Error fetching provider:', error)
|
||||
}
|
||||
},
|
||||
// 按 id 获取单个 provider
|
||||
getProviderById: id => get().provider.find(p => p.id === id),
|
||||
|
||||
updateProvider: async (provider: IProvider) => {
|
||||
try {
|
||||
const data = {
|
||||
...provider,
|
||||
api_key: provider.apiKey,
|
||||
base_url: provider.baseUrl,
|
||||
}
|
||||
const res = await updateProviderById(data)
|
||||
if (res.data.code === 0) {
|
||||
const item = res.data.data
|
||||
console.log('Provider ', item)
|
||||
await get().fetchProviderList()
|
||||
}
|
||||
} catch (error) {
|
||||
console.error('Error fetching provider:', error)
|
||||
}
|
||||
},
|
||||
getProviderList: () => get().provider,
|
||||
fetchProviderList: async () => {
|
||||
try {
|
||||
@@ -55,6 +112,7 @@ export const useProviderStore = create<ProviderStore>((set, get) => ({
|
||||
apiKey: item.api_key,
|
||||
baseUrl: item.base_url,
|
||||
type: item.type,
|
||||
enabled: item.enabled,
|
||||
}
|
||||
}
|
||||
),
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
import { create } from 'zustand'
|
||||
import { persist } from 'zustand/middleware'
|
||||
import { delete_task } from '@/services/note.ts'
|
||||
import { delete_task, generateNote } from '@/services/note.ts'
|
||||
|
||||
export type TaskStatus = 'PENDING' | 'RUNNING' | 'SUCCESS' | 'FAILD'
|
||||
|
||||
@@ -34,6 +34,15 @@ export interface Task {
|
||||
status: TaskStatus
|
||||
audioMeta: AudioMeta
|
||||
createdAt: string
|
||||
formData: {
|
||||
video_url: string
|
||||
link: undefined | boolean
|
||||
screenshot: undefined | boolean
|
||||
platform: string
|
||||
quality: string
|
||||
model_name: string
|
||||
provider_id: string
|
||||
}
|
||||
}
|
||||
|
||||
interface TaskStore {
|
||||
@@ -45,6 +54,7 @@ interface TaskStore {
|
||||
clearTasks: () => void
|
||||
setCurrentTask: (taskId: string | null) => void
|
||||
getCurrentTask: () => Task | null
|
||||
retryTask: (id: string) => void
|
||||
}
|
||||
|
||||
export const useTaskStore = create<TaskStore>()(
|
||||
@@ -53,10 +63,11 @@ export const useTaskStore = create<TaskStore>()(
|
||||
tasks: [],
|
||||
currentTaskId: null,
|
||||
|
||||
addPendingTask: (taskId: string, platform: string) =>
|
||||
addPendingTask: (taskId: string, platform: string, formData: any) =>
|
||||
set(state => ({
|
||||
tasks: [
|
||||
{
|
||||
formData: formData,
|
||||
id: taskId,
|
||||
status: 'PENDING',
|
||||
markdown: '',
|
||||
@@ -91,6 +102,17 @@ export const useTaskStore = create<TaskStore>()(
|
||||
const currentTaskId = get().currentTaskId
|
||||
return get().tasks.find(task => task.id === currentTaskId) || null
|
||||
},
|
||||
retryTask: async (id: string) => {
|
||||
const task = get().tasks.find(task => task.id === id).formData
|
||||
await generateNote({
|
||||
task_id: id,
|
||||
...task,
|
||||
})
|
||||
set(state => ({
|
||||
tasks: state.tasks.map(task => (task.id === id ? { ...task, status: 'PENDING' } : task)),
|
||||
}))
|
||||
},
|
||||
|
||||
removeTask: async id => {
|
||||
const task = get().tasks.find(t => t.id === id)
|
||||
|
||||
|
||||
1
BillNote_frontend/src/types/index.d.ts
vendored
1
BillNote_frontend/src/types/index.d.ts
vendored
@@ -5,4 +5,5 @@ export interface IProvider {
|
||||
type: string
|
||||
apiKey: string
|
||||
baseUrl: string
|
||||
enabled: number
|
||||
}
|
||||
|
||||
25
README.md
25
README.md
@@ -3,7 +3,7 @@
|
||||
<p align="center">
|
||||
<img src="./doc/icon.svg" alt="BiliNote Banner" width="50" height="50" />
|
||||
</p>
|
||||
<h1 align="center" > BiliNote v1.0.1</h1>
|
||||
<h1 align="center" > BiliNote v1.1.0</h1>
|
||||
</div>
|
||||
|
||||
<p align="center"><i>AI 视频笔记生成工具 让 AI 为你的视频做笔记</i></p>
|
||||
@@ -119,6 +119,7 @@ docker compose up --build
|
||||
|
||||
|
||||
## ⚙️ 环境变量配置
|
||||
> ⚠️ v.1.1.0 以后无需通过环境变量配置 AI
|
||||
|
||||
后端 `.env` 示例:
|
||||
|
||||
@@ -131,15 +132,29 @@ OPENAI_API_KEY=sk-xxxxxx
|
||||
DEEP_SEEK_API_KEY=xxx
|
||||
QWEN_API_KEY=xxx
|
||||
```
|
||||
## Changelog
|
||||
### v1.1.0
|
||||
- #### Added
|
||||
- 新增 AI 笔记风格选择
|
||||
- 新增 AI 笔记返回格式选择
|
||||
- 添加 AI 自定义笔记备注 Prompt
|
||||
- 添加任务失败重试
|
||||
- 添加全局设置页,可在设置页进行模型设置
|
||||
|
||||
- #### Optimize
|
||||
- 优化前端样式,优化用户体验
|
||||
- 增加生成中间产物,可用于失败后加快生成速度
|
||||
- #### Fix
|
||||
- 修复视频截图视频过早删除错误
|
||||
|
||||
## 🧠 TODO
|
||||
|
||||
- [ ] 支持抖音及快手等视频平台
|
||||
- [ ] 支持前端设置切换 AI 模型切换、语音转文字模型
|
||||
- [ ] AI 摘要风格自定义(学术风、口语风、重点提取等)
|
||||
- [x] 支持前端设置切换 AI 模型切换、语音转文字模型
|
||||
- [x] AI 摘要风格自定义(学术风、口语风、重点提取等)
|
||||
- [ ] 笔记导出为 PDF / Word / Notion
|
||||
- [ ] 加入更多模型支持
|
||||
- [ ] 加入更多音频转文本模型支持
|
||||
- [x] 加入更多模型支持
|
||||
- [x] 加入更多音频转文本模型支持
|
||||
|
||||
### Contact and Join-联系和加入社区
|
||||
- BiliNote 交流QQ群:785367111
|
||||
|
||||
@@ -1,9 +1,10 @@
|
||||
from fastapi import FastAPI
|
||||
from .routers import note, provider
|
||||
from .routers import note, provider,model
|
||||
|
||||
|
||||
def create_app() -> FastAPI:
|
||||
app = FastAPI(title="BiliNote")
|
||||
app.include_router(note.router, prefix="/api")
|
||||
app.include_router(provider.router, prefix="/api")
|
||||
app.include_router(model.router,prefix="/api")
|
||||
return app
|
||||
|
||||
42
backend/app/db/builtin_providers.json
Normal file
42
backend/app/db/builtin_providers.json
Normal file
@@ -0,0 +1,42 @@
|
||||
[
|
||||
{
|
||||
"id": "openai",
|
||||
"name": "OpenAI",
|
||||
"type": "built-in",
|
||||
"logo": "OpenAI",
|
||||
"api_key": "",
|
||||
"base_url": "https://api.openai.com/v1"
|
||||
},
|
||||
{
|
||||
"id": "deepseek",
|
||||
"name": "DeepSeek",
|
||||
"type": "built-in",
|
||||
"logo": "DeepSeek",
|
||||
"api_key": "",
|
||||
"base_url": "https://api.deepseek.com"
|
||||
},
|
||||
{
|
||||
"id": "qwen",
|
||||
"name": "Qwen",
|
||||
"type": "built-in",
|
||||
"logo": "Qwen",
|
||||
"api_key": "",
|
||||
"base_url": "https://qwen.aliyun.com/api"
|
||||
},
|
||||
{
|
||||
"id": "doubao",
|
||||
"name": "豆包 (Doubao)",
|
||||
"type": "built-in",
|
||||
"logo": "Doubao",
|
||||
"api_key": "",
|
||||
"base_url": "https://open.doubao.com/api"
|
||||
},
|
||||
{
|
||||
"id": "Claude",
|
||||
"name": "Claude",
|
||||
"type": "built-in",
|
||||
"logo": "Claude",
|
||||
"api_key": "",
|
||||
"base_url": "https://"
|
||||
}
|
||||
]
|
||||
58
backend/app/db/model_dao.py
Normal file
58
backend/app/db/model_dao.py
Normal file
@@ -0,0 +1,58 @@
|
||||
from app.db.sqlite_client import get_connection
|
||||
|
||||
def init_model_table():
|
||||
conn = get_connection()
|
||||
cursor = conn.cursor()
|
||||
cursor.execute("""
|
||||
CREATE TABLE IF NOT EXISTS models (
|
||||
id INTEGER PRIMARY KEY AUTOINCREMENT,
|
||||
provider_id INTEGER NOT NULL,
|
||||
model_name TEXT NOT NULL,
|
||||
created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP
|
||||
)
|
||||
""")
|
||||
conn.commit()
|
||||
conn.close()
|
||||
|
||||
# 插入模型
|
||||
def insert_model(provider_id: int, model_name: str):
|
||||
conn = get_connection()
|
||||
cursor = conn.cursor()
|
||||
cursor.execute("""
|
||||
INSERT INTO models (provider_id, model_name)
|
||||
VALUES (?, ?)
|
||||
""", (provider_id, model_name))
|
||||
conn.commit()
|
||||
conn.close()
|
||||
|
||||
# 根据provider查模型
|
||||
def get_models_by_provider(provider_id: int):
|
||||
conn = get_connection()
|
||||
cursor = conn.cursor()
|
||||
cursor.execute("""
|
||||
SELECT id, model_name FROM models
|
||||
WHERE provider_id = ?
|
||||
""", (provider_id,))
|
||||
rows = cursor.fetchall()
|
||||
conn.close()
|
||||
return [{"id": row[0], "model_name": row[1]} for row in rows]
|
||||
|
||||
# 删除某个模型
|
||||
def delete_model(model_id: int):
|
||||
conn = get_connection()
|
||||
cursor = conn.cursor()
|
||||
cursor.execute("""
|
||||
DELETE FROM models WHERE id = ?
|
||||
""", (model_id,))
|
||||
conn.commit()
|
||||
conn.close()
|
||||
|
||||
def get_all_models():
|
||||
conn = get_connection()
|
||||
cursor = conn.cursor()
|
||||
cursor.execute("""
|
||||
SELECT id, provider_id, model_name FROM models
|
||||
""")
|
||||
rows = cursor.fetchall()
|
||||
conn.close()
|
||||
return [{"id": row[0], "provider_id": row[1], "model_name": row[2]} for row in rows]
|
||||
@@ -1,8 +1,59 @@
|
||||
import json
|
||||
import os
|
||||
|
||||
from app.db.sqlite_client import get_connection
|
||||
from app.utils.logger import get_logger
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
|
||||
def seed_default_providers():
|
||||
conn = get_connection()
|
||||
if conn is None:
|
||||
logger.error("Failed to connect to database.")
|
||||
return
|
||||
|
||||
cursor = conn.cursor()
|
||||
|
||||
# 检查已有数据
|
||||
cursor.execute("SELECT COUNT(*) FROM providers")
|
||||
count = cursor.fetchone()[0]
|
||||
if count > 0:
|
||||
logger.info("Providers already exist, skipping seed.")
|
||||
conn.close()
|
||||
return
|
||||
|
||||
json_path = os.path.join(os.path.dirname(__file__), 'builtin_providers.json')
|
||||
try:
|
||||
with open(json_path, 'r', encoding='utf-8') as f:
|
||||
providers = json.load(f)
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to read builtin_providers.json: {e}")
|
||||
conn.close()
|
||||
return
|
||||
|
||||
try:
|
||||
for p in providers:
|
||||
cursor.execute("""
|
||||
INSERT INTO providers (id, name, api_key, base_url, logo, type, enabled)
|
||||
VALUES (?, ?, ?, ?, ?, ?, ?)
|
||||
""", (
|
||||
p['id'],
|
||||
p['name'],
|
||||
p['api_key'],
|
||||
p['base_url'],
|
||||
p['logo'],
|
||||
p['type'],
|
||||
p.get('enabled', 1)
|
||||
))
|
||||
|
||||
conn.commit()
|
||||
logger.info("Default providers seeded successfully.")
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to seed default providers: {e}")
|
||||
finally:
|
||||
conn.close()
|
||||
def init_provider_table():
|
||||
conn = get_connection()
|
||||
if conn is None:
|
||||
@@ -11,40 +62,60 @@ def init_provider_table():
|
||||
cursor = conn.cursor()
|
||||
cursor.execute("""
|
||||
CREATE TABLE IF NOT EXISTS providers (
|
||||
id INTEGER PRIMARY KEY AUTOINCREMENT,
|
||||
id TEXT PRIMARY KEY,
|
||||
name TEXT NOT NULL,
|
||||
logo TEXT NOT NULL,
|
||||
type TEXT NOT NULL, -- ✅ 新增字段
|
||||
type TEXT NOT NULL,
|
||||
api_key TEXT NOT NULL,
|
||||
base_url TEXT NOT NULL,
|
||||
enabled INTEGER DEFAULT 1,
|
||||
created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP
|
||||
)
|
||||
""")
|
||||
|
||||
try:
|
||||
conn.commit()
|
||||
conn.close()
|
||||
logger.info("provider table created successfully.")
|
||||
seed_default_providers()
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to create provider table: {e}")
|
||||
def insert_provider(name: str, api_key: str, base_url: str, logo: str, type_: str):
|
||||
def insert_provider(id: str, name: str, api_key: str, base_url: str, logo: str, type_: str,enabled:int=1):
|
||||
conn = get_connection()
|
||||
if conn is None:
|
||||
logger.error("Failed to connect to the database.")
|
||||
return
|
||||
cursor = conn.cursor()
|
||||
cursor.execute("""
|
||||
INSERT INTO providers (name, api_key, base_url, logo, type)
|
||||
VALUES (?, ?, ?, ?, ?)
|
||||
""", (name, api_key, base_url, logo, type_))
|
||||
INSERT INTO providers (id, name, api_key, base_url, logo, type, enabled)
|
||||
VALUES (?, ?, ?, ?, ?, ?, ?)
|
||||
""", (id, name, api_key, base_url, logo, type_, enabled))
|
||||
try:
|
||||
conn.commit()
|
||||
cursor_id = cursor.lastrowid
|
||||
conn.close()
|
||||
logger.info(f"Provider inserted successfully. name: {name}, type: {type_}")
|
||||
return cursor_id
|
||||
logger.info(f"Provider inserted successfully. id: {id}, name: {name}, type: {type_}")
|
||||
return id
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to insert provider: {e}")
|
||||
return None
|
||||
|
||||
def get_enabled_providers():
|
||||
conn = get_connection()
|
||||
if conn is None:
|
||||
logger.error("Failed to connect to the database.")
|
||||
return
|
||||
cursor = conn.cursor()
|
||||
cursor.execute("SELECT * FROM providers WHERE enabled = 1")
|
||||
try:
|
||||
rows = cursor.fetchall()
|
||||
conn.close()
|
||||
if rows is None:
|
||||
logger.info("No providers found")
|
||||
return None
|
||||
logger.info(f"Providers found: {rows}")
|
||||
return rows
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to get enabled providers: {e}")
|
||||
def get_provider_by_name(name: str):
|
||||
conn = get_connection()
|
||||
if conn is None:
|
||||
@@ -70,6 +141,7 @@ def get_provider_by_id(id: int):
|
||||
return
|
||||
cursor = conn.cursor()
|
||||
cursor.execute("SELECT * FROM providers WHERE id = ?", (id,))
|
||||
|
||||
try:
|
||||
row = cursor.fetchone()
|
||||
conn.close()
|
||||
@@ -99,23 +171,40 @@ def get_all_providers():
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to get all providers: {e}")
|
||||
|
||||
def update_provider(id: int, name: str, api_key: str, base_url: str, logo: str, type_: str):
|
||||
def update_provider(id: str, **kwargs):
|
||||
conn = get_connection()
|
||||
if conn is None:
|
||||
logger.error("Failed to connect to the database.")
|
||||
return
|
||||
cursor = conn.cursor()
|
||||
cursor.execute("""
|
||||
UPDATE providers
|
||||
SET name = ?, api_key = ?, base_url = ?, logo = ?, type = ?
|
||||
|
||||
fields = []
|
||||
values = []
|
||||
|
||||
for key, value in kwargs.items():
|
||||
fields.append(f"{key} = ?")
|
||||
values.append(value)
|
||||
|
||||
if not fields:
|
||||
logger.warning("No fields provided for update.")
|
||||
return
|
||||
|
||||
sql = f"""
|
||||
UPDATE providers
|
||||
SET {', '.join(fields)}
|
||||
WHERE id = ?
|
||||
""", (name, api_key, base_url, logo, type_, id))
|
||||
"""
|
||||
|
||||
values.append(id) # id 最后加
|
||||
cursor = conn.cursor()
|
||||
|
||||
try:
|
||||
cursor.execute(sql, values)
|
||||
conn.commit()
|
||||
conn.close()
|
||||
logger.info(f"Provider updated successfully. id: {id}, type: {type_}")
|
||||
logger.info(f"Provider updated successfully. id: {id}, updated_fields: {fields}")
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to update provider: {e}")
|
||||
|
||||
def delete_provider(id: int):
|
||||
conn = get_connection()
|
||||
if conn is None:
|
||||
|
||||
28
backend/app/enmus/task_status_enums.py
Normal file
28
backend/app/enmus/task_status_enums.py
Normal file
@@ -0,0 +1,28 @@
|
||||
import enum
|
||||
|
||||
|
||||
class TaskStatus(str, enum.Enum):
|
||||
PENDING = "PENDING"
|
||||
PARSING = "PARSING"
|
||||
DOWNLOADING = "DOWNLOADING"
|
||||
TRANSCRIBING = "TRANSCRIBING"
|
||||
SUMMARIZING = "SUMMARIZING"
|
||||
FORMATTING = "FORMATTING"
|
||||
SAVING = "SAVING"
|
||||
SUCCESS = "SUCCESS"
|
||||
FAILED = "FAILED"
|
||||
|
||||
@classmethod
|
||||
def description(cls, status):
|
||||
desc_map = {
|
||||
cls.PENDING: "排队中",
|
||||
cls.PARSING: "解析链接",
|
||||
cls.DOWNLOADING: "下载中",
|
||||
cls.TRANSCRIBING: "转录中",
|
||||
cls.SUMMARIZING: "总结中",
|
||||
cls.FORMATTING: "格式化中",
|
||||
cls.SAVING: "保存中",
|
||||
cls.SUCCESS: "完成",
|
||||
cls.FAILED: "失败",
|
||||
}
|
||||
return desc_map.get(status, "未知状态")
|
||||
@@ -9,5 +9,5 @@ from app.models.model_config import ModelConfig
|
||||
class GPTFactory:
|
||||
@staticmethod
|
||||
def from_config(config: ModelConfig) -> GPT:
|
||||
client = OpenAICompatibleProvider(api_key=config.api_key, base_url=config.base_url).get_client()
|
||||
client = OpenAICompatibleProvider(api_key=config.api_key, base_url=config.base_url).get_client
|
||||
return UniversalGPT(client=client, model=config.model_name)
|
||||
100
backend/app/gpt/prompt_builder.py
Normal file
100
backend/app/gpt/prompt_builder.py
Normal file
@@ -0,0 +1,100 @@
|
||||
from app.gpt.prompt import BASE_PROMPT
|
||||
|
||||
note_formats = [
|
||||
{'label': '目录', 'value': 'toc'},
|
||||
{'label': '原片跳转', 'value': 'link'},
|
||||
{'label': '原片截图', 'value': 'screenshot'},
|
||||
{'label': 'AI总结', 'value': 'summary'}
|
||||
]
|
||||
|
||||
note_styles = [
|
||||
{'label': '精简', 'value': 'minimal'},
|
||||
{'label': '详细', 'value': 'detailed'},
|
||||
{'label': '学术', 'value': 'academic'},
|
||||
{"label": '教程',"value": 'tutorial', },
|
||||
{'label': '小红书', 'value': 'xiaohongshu'},
|
||||
{'label': '生活向', 'value': 'life_journal'},
|
||||
{'label': '任务导向', 'value': 'task_oriented'},
|
||||
{'label': '商业风格', 'value': 'business'},
|
||||
{'label': '会议纪要', 'value': 'meeting_minutes'}
|
||||
]
|
||||
|
||||
|
||||
# 生成 BASE_PROMPT 函数
|
||||
def generate_base_prompt(title, segment_text, tags, _format=None, style=None, extras=None):
|
||||
# 生成 Base Prompt 开头部分
|
||||
prompt = BASE_PROMPT.format(
|
||||
video_title=title,
|
||||
segment_text=segment_text,
|
||||
tags=tags
|
||||
)
|
||||
|
||||
# 添加用户选择的格式
|
||||
if _format:
|
||||
prompt += "\n" + "\n".join([get_format_function(f) for f in _format])
|
||||
|
||||
# 根据用户选择的笔记风格添加描述
|
||||
if style:
|
||||
prompt += "\n" + get_style_format(style)
|
||||
|
||||
# 添加额外内容
|
||||
if extras:
|
||||
prompt += f"\n{extras}"
|
||||
|
||||
return prompt
|
||||
|
||||
|
||||
# 获取格式函数
|
||||
def get_format_function(format_type):
|
||||
format_map = {
|
||||
'toc': get_toc_format,
|
||||
'link': get_link_format,
|
||||
'screenshot': get_screenshot_format,
|
||||
'summary': get_summary_format
|
||||
}
|
||||
return format_map.get(format_type, lambda: '')()
|
||||
|
||||
|
||||
# 风格描述的处理
|
||||
def get_style_format(style):
|
||||
style_map = {
|
||||
'minimal': '1. **精简信息**: 仅记录最重要的内容,简洁明了。',
|
||||
'detailed': '2. **详细记录**: 包含完整的时间戳和每个部分的详细讨论。',
|
||||
'academic': '3. **学术风格**: 适合学术报告,正式且结构化。',
|
||||
|
||||
'xiaohongshu': '4. **小红书风格**: 适合社交平台分享,亲切、口语化。',
|
||||
'life_journal': '5. **生活向**: 记录个人生活感悟,情感化表达。',
|
||||
'task_oriented': '6. **任务导向**: 强调任务、目标,适合工作和待办事项。',
|
||||
'business': '7. **商业风格**: 适合商业报告、会议纪要,正式且精准。',
|
||||
'meeting_minutes': '8. **会议纪要**: 适合商业报告、会议纪要,正式且精准。',
|
||||
"tutorial":"9.**教程笔记**:尽可能详细的记录教程,特别是关键点和一些重要的结论步骤"
|
||||
}
|
||||
return style_map.get(style, '')
|
||||
|
||||
|
||||
# 格式化输出内容
|
||||
def get_toc_format():
|
||||
return '''
|
||||
9. **目录**: 自动生成一个基于 `##` 级标题的目录。不需要插入原片跳转
|
||||
'''
|
||||
|
||||
|
||||
def get_link_format():
|
||||
return '''
|
||||
10. **原片跳转**: 为每个主要章节添加时间戳,使用格式 `*Content-[mm:ss]`。
|
||||
重要:**始终**在章节标题前加上 `*Content` 前缀,例如:`AI 的发展史 *Content-[01:23]`。一定是标题在前 插入标记在后
|
||||
'''
|
||||
|
||||
|
||||
def get_screenshot_format():
|
||||
return '''
|
||||
11. **原片截图**: 如果某个部分涉及**视觉演示**或任何能帮助理解的内容,插入截图提示:
|
||||
- 格式:`*Screenshot-[mm:ss]`
|
||||
至少插入 1-3张截图
|
||||
'''
|
||||
|
||||
|
||||
def get_summary_format():
|
||||
return '''
|
||||
12. **AI总结**: 在笔记末尾加入简短的AI生成总结,并且二级标题 就是 AI 总结 例如 ## AI 总结。
|
||||
'''
|
||||
@@ -1,4 +1,5 @@
|
||||
from app.gpt.base import GPT
|
||||
from app.gpt.prompt_builder import generate_base_prompt
|
||||
from app.models.gpt_model import GPTSource
|
||||
from app.gpt.prompt import BASE_PROMPT, AI_SUM, SCREENSHOT, LINK
|
||||
from app.gpt.utils import fix_markdown
|
||||
@@ -28,29 +29,35 @@ class UniversalGPT(GPT):
|
||||
return [TranscriptSegment(**seg) if isinstance(seg, dict) else seg for seg in segments]
|
||||
|
||||
def create_messages(self, segments: List[TranscriptSegment],**kwargs):
|
||||
content = BASE_PROMPT.format(
|
||||
video_title=kwargs.get('title'),
|
||||
print("UniversalGPT",kwargs)
|
||||
content =generate_base_prompt(
|
||||
title=kwargs.get('title'),
|
||||
segment_text=self._build_segment_text(segments),
|
||||
tags=kwargs.get('tags')
|
||||
tags=kwargs.get('tags'),
|
||||
_format=kwargs.get('_format'),
|
||||
style=kwargs.get('style'),
|
||||
extras=kwargs.get('extras')
|
||||
)
|
||||
if self.screenshot:
|
||||
print(":需要截图")
|
||||
content += SCREENSHOT
|
||||
if self.link:
|
||||
print(":需要链接")
|
||||
content += LINK
|
||||
|
||||
print(content)
|
||||
return [{"role": "user", "content": content + AI_SUM}]
|
||||
return [{"role": "user", "content": content }]
|
||||
|
||||
def list_models(self):
|
||||
return self.client.list_models()
|
||||
return self.client.models.list()
|
||||
def summarize(self, source: GPTSource) -> str:
|
||||
self.screenshot = source.screenshot
|
||||
self.link = source.link
|
||||
source.segment = self.ensure_segments_type(source.segment)
|
||||
messages = self.create_messages(source.segment, source.title,source.tags)
|
||||
response = self.client.chat(
|
||||
|
||||
messages = self.create_messages(
|
||||
source.segment,
|
||||
title=source.title,
|
||||
tags=source.tags
|
||||
,
|
||||
_format=source._format,
|
||||
style=source.style,
|
||||
extras=source.extras
|
||||
)
|
||||
response = self.client.chat.completions.create(
|
||||
model=self.model,
|
||||
messages=messages,
|
||||
temperature=0.7
|
||||
|
||||
36
backend/app/routers/model.py
Normal file
36
backend/app/routers/model.py
Normal file
@@ -0,0 +1,36 @@
|
||||
from fastapi import APIRouter
|
||||
from pydantic import BaseModel
|
||||
|
||||
from app.services.model import ModelService
|
||||
from app.utils.response import ResponseWrapper as R
|
||||
router = APIRouter()
|
||||
modelService = ModelService()
|
||||
class CreateModelRequest(BaseModel):
|
||||
provider_id: str
|
||||
model_name: str
|
||||
|
||||
# 返回体:模型信息
|
||||
class ModelItem(BaseModel):
|
||||
id: int
|
||||
model_name: str
|
||||
@router.get("/model_list")
|
||||
def model_list():
|
||||
try:
|
||||
return R.success(modelService.get_all_models(True),msg="获取模型列表成功")
|
||||
except Exception as e:
|
||||
return R.error(e)
|
||||
|
||||
@router.get("/model_list/{provider_id}")
|
||||
def model_list(provider_id):
|
||||
try:
|
||||
return R.success(modelService.get_all_models_by_id(provider_id))
|
||||
except Exception as e:
|
||||
return R.error(e)
|
||||
|
||||
@router.post("/models")
|
||||
def create_model(data: CreateModelRequest):
|
||||
success = ModelService.add_new_model(data.provider_id, data.model_name)
|
||||
if not success:
|
||||
raise R.error("模型添加失败")
|
||||
return R.success(msg="模型添加成功")
|
||||
|
||||
@@ -10,13 +10,14 @@ from dataclasses import asdict
|
||||
|
||||
from app.db.video_task_dao import get_task_by_video
|
||||
from app.enmus.note_enums import DownloadQuality
|
||||
from app.services.note import NoteGenerator
|
||||
from app.services.note import NoteGenerator, logger
|
||||
from app.utils.response import ResponseWrapper as R
|
||||
from app.utils.url_parser import extract_video_id
|
||||
from app.validators.video_url_validator import is_supported_video_url
|
||||
from fastapi import APIRouter, Request, HTTPException
|
||||
from fastapi.responses import StreamingResponse
|
||||
import httpx
|
||||
from app.enmus.task_status_enums import TaskStatus
|
||||
|
||||
# from app.services.downloader import download_raw_audio
|
||||
# from app.services.whisperer import transcribe_audio
|
||||
@@ -35,6 +36,12 @@ class VideoRequest(BaseModel):
|
||||
quality: DownloadQuality
|
||||
screenshot: Optional[bool] = False
|
||||
link: Optional[bool] = False
|
||||
model_name:str
|
||||
provider_id:str
|
||||
task_id: Optional[str] = None
|
||||
format:Optional[list]=[]
|
||||
style:str=None
|
||||
extras:Optional[str]
|
||||
|
||||
@validator("video_url")
|
||||
def validate_supported_url(cls, v):
|
||||
@@ -54,14 +61,24 @@ def save_note_to_file(task_id: str, note):
|
||||
json.dump(asdict(note), f, ensure_ascii=False, indent=2)
|
||||
|
||||
|
||||
def run_note_task(task_id: str, video_url: str, platform: str, quality: DownloadQuality, link: bool = False,screenshot: bool = False):
|
||||
def run_note_task(task_id: str, video_url: str, platform: str, quality: DownloadQuality,
|
||||
link: bool = False,screenshot: bool = False,model_name:str=None,provider_id:str=None,
|
||||
_format:list=None,style:str=None,extras:str=None):
|
||||
try:
|
||||
if not model_name or not provider_id:
|
||||
raise HTTPException(status_code=400, detail="请选择模型和提供者")
|
||||
|
||||
note = NoteGenerator().generate(
|
||||
video_url=video_url,
|
||||
platform=platform,
|
||||
quality=quality,
|
||||
task_id=task_id,
|
||||
model_name=model_name,
|
||||
provider_id=provider_id,
|
||||
link=link,
|
||||
_format=_format,
|
||||
style=style,
|
||||
extras=extras,
|
||||
screenshot=screenshot
|
||||
)
|
||||
print('Note 结果',note)
|
||||
@@ -85,38 +102,91 @@ def generate_note(data: VideoRequest, background_tasks: BackgroundTasks):
|
||||
try:
|
||||
|
||||
video_id = extract_video_id(data.video_url, data.platform)
|
||||
if not video_id:
|
||||
raise HTTPException(status_code=400, detail="无法提取视频 ID")
|
||||
existing = get_task_by_video(video_id, data.platform)
|
||||
if existing:
|
||||
return R.error(
|
||||
msg='笔记已生成,请勿重复发起',
|
||||
# if not video_id:
|
||||
# raise HTTPException(status_code=400, detail="无法提取视频 ID")
|
||||
# existing = get_task_by_video(video_id, data.platform)
|
||||
# if existing:
|
||||
# return R.error(
|
||||
# msg='笔记已生成,请勿重复发起',
|
||||
#
|
||||
# )
|
||||
|
||||
)
|
||||
|
||||
task_id = str(uuid.uuid4())
|
||||
if data.task_id:
|
||||
# 如果传了task_id,说明是重试!
|
||||
task_id = data.task_id
|
||||
# 更新之前的状态
|
||||
NoteGenerator.update_task_status(task_id, TaskStatus.PENDING)
|
||||
logger.info(f"重试模式,复用已有 task_id={task_id}")
|
||||
else:
|
||||
# 正常新建任务
|
||||
task_id = str(uuid.uuid4())
|
||||
|
||||
background_tasks.add_task(run_note_task, task_id, data.video_url, data.platform, data.quality,data.link ,data.screenshot)
|
||||
background_tasks.add_task(run_note_task, task_id, data.video_url, data.platform, data.quality,data.link ,data.screenshot,data.model_name,data.provider_id,data.format,data.style,data.extras)
|
||||
return R.success({"task_id": task_id})
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
|
||||
|
||||
|
||||
@router.get("/task_status/{task_id}")
|
||||
def get_task_status(task_id: str):
|
||||
path = os.path.join(NOTE_OUTPUT_DIR, f"{task_id}.json")
|
||||
if not os.path.exists(path):
|
||||
return R.success({"status": "PENDING"})
|
||||
status_path = os.path.join(NOTE_OUTPUT_DIR, f"{task_id}.status.json")
|
||||
result_path = os.path.join(NOTE_OUTPUT_DIR, f"{task_id}.json")
|
||||
|
||||
with open(path, "r", encoding="utf-8") as f:
|
||||
content = json.load(f)
|
||||
# 优先读状态文件
|
||||
if os.path.exists(status_path):
|
||||
with open(status_path, "r", encoding="utf-8") as f:
|
||||
status_content = json.load(f)
|
||||
|
||||
if "error" in content:
|
||||
return R.error(content["error"], code=500)
|
||||
content['id'] = task_id
|
||||
status = status_content.get("status")
|
||||
message = status_content.get("message", "")
|
||||
|
||||
if status == TaskStatus.SUCCESS.value:
|
||||
# 成功状态的话,继续读取最终笔记内容
|
||||
if os.path.exists(result_path):
|
||||
with open(result_path, "r", encoding="utf-8") as rf:
|
||||
result_content = json.load(rf)
|
||||
return R.success({
|
||||
"status": status,
|
||||
"result": result_content,
|
||||
"message": message,
|
||||
"task_id": task_id
|
||||
})
|
||||
else:
|
||||
# 理论上不会出现,保险处理
|
||||
return R.success({
|
||||
"status": TaskStatus.PENDING.value,
|
||||
"message": "任务完成,但结果文件未找到",
|
||||
"task_id": task_id
|
||||
})
|
||||
|
||||
if status == TaskStatus.FAILED.value:
|
||||
return R.error(message or "任务失败", code=500)
|
||||
|
||||
# 处理中状态
|
||||
return R.success({
|
||||
"status": status,
|
||||
"message": message,
|
||||
"task_id": task_id
|
||||
})
|
||||
|
||||
# 没有状态文件,但有结果
|
||||
if os.path.exists(result_path):
|
||||
with open(result_path, "r", encoding="utf-8") as f:
|
||||
result_content = json.load(f)
|
||||
return R.success({
|
||||
"status": TaskStatus.SUCCESS.value,
|
||||
"result": result_content,
|
||||
"task_id": task_id
|
||||
})
|
||||
|
||||
# 什么都没有,默认PENDING
|
||||
return R.success({
|
||||
"status": "SUCCESS",
|
||||
"result": content
|
||||
"status": TaskStatus.PENDING.value,
|
||||
"message": "任务排队中",
|
||||
"task_id": task_id
|
||||
})
|
||||
|
||||
|
||||
|
||||
@@ -1,6 +1,9 @@
|
||||
from typing import Optional
|
||||
from fastapi import APIRouter
|
||||
from pydantic import BaseModel
|
||||
|
||||
from app.models.model_config import ModelConfig
|
||||
from app.services.model import ModelService
|
||||
from app.utils.response import ResponseWrapper as R
|
||||
from app.services.provider import ProviderService
|
||||
|
||||
@@ -11,16 +14,21 @@ class ProviderRequest(BaseModel):
|
||||
name: str
|
||||
api_key: str
|
||||
base_url: str
|
||||
logo: str
|
||||
logo: Optional[str] = None
|
||||
type: str
|
||||
|
||||
class TestRequest(BaseModel):
|
||||
|
||||
api_key: str
|
||||
base_url:str
|
||||
class ProviderUpdateRequest(BaseModel):
|
||||
id: int
|
||||
id: str
|
||||
name: Optional[str] = None
|
||||
api_key: Optional[str] = None
|
||||
base_url: Optional[str] = None
|
||||
logo: Optional[str] = None
|
||||
type: Optional[str] = None
|
||||
enabled:Optional[int] = None
|
||||
|
||||
@router.post("/add_provider")
|
||||
def add_provider(data: ProviderRequest):
|
||||
@@ -45,7 +53,7 @@ def get_all_providers():
|
||||
return R.error(msg=e)
|
||||
|
||||
@router.get("/get_provider_by_id/{id}")
|
||||
def get_provider_by_id(id: int):
|
||||
def get_provider_by_id(id: str):
|
||||
try:
|
||||
res = ProviderService.get_provider_by_id(id)
|
||||
return R.success(data=res)
|
||||
@@ -60,23 +68,33 @@ def get_provider_by_name(name: str):
|
||||
except Exception as e:
|
||||
return R.error(msg=e)
|
||||
|
||||
@router.post("/update_provider/")
|
||||
|
||||
@router.post("/update_provider")
|
||||
def update_provider(data: ProviderUpdateRequest):
|
||||
try:
|
||||
if all(
|
||||
field is None
|
||||
for field in [data.name, data.api_key, data.base_url, data.logo, data.type]
|
||||
for field in [data.name, data.api_key, data.base_url, data.logo, data.type,data.enabled]
|
||||
):
|
||||
return R.error(msg='请至少填写一个参数')
|
||||
|
||||
ProviderService.update_provider(
|
||||
id=data.id,
|
||||
name=data.name or '',
|
||||
api_key=data.api_key or '',
|
||||
base_url=data.base_url or '',
|
||||
logo=data.logo or '',
|
||||
type_=data.type or ''
|
||||
data=dict(data)
|
||||
)
|
||||
return R.success(msg='更新模型供应商成功')
|
||||
except Exception as e:
|
||||
print(e)
|
||||
return R.error(msg=e)
|
||||
|
||||
@router.post('/connect_test')
|
||||
def gpt_connect_test(data:TestRequest):
|
||||
try:
|
||||
|
||||
res= ModelService().connect_test(data.api_key,data.base_url)
|
||||
if not res:
|
||||
return R.error(msg='连接失败')
|
||||
return R.success(msg='连接成功')
|
||||
except Exception as e:
|
||||
print(e)
|
||||
return R.error(msg=e)
|
||||
@@ -1,23 +1,109 @@
|
||||
from app.db.model_dao import insert_model, get_all_models
|
||||
from app.db.provider_dao import get_enabled_providers
|
||||
from app.gpt.gpt_factory import GPTFactory
|
||||
from app.gpt.provider.OpenAI_compatible_provider import OpenAICompatibleProvider
|
||||
from app.models.model_config import ModelConfig
|
||||
from app.services.provider import ProviderService
|
||||
|
||||
|
||||
class ModelService:
|
||||
|
||||
@staticmethod
|
||||
def get_model_list(provider_id: int):
|
||||
provider=ProviderService.get_provider_by_id(provider_id)
|
||||
def _build_model_config(provider: dict) -> ModelConfig:
|
||||
return ModelConfig(
|
||||
api_key=provider["api_key"],
|
||||
base_url=provider["base_url"],
|
||||
provider=provider["name"],
|
||||
model_name='',
|
||||
name=provider["name"],
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def get_model_list(provider_id: int, verbose: bool = False):
|
||||
provider = ProviderService.get_provider_by_id(provider_id)
|
||||
if not provider:
|
||||
return []
|
||||
config=ModelConfig(
|
||||
api_key=provider.api_key,
|
||||
base_url=provider.base_url,
|
||||
provider=provider.name,
|
||||
model_name='',
|
||||
name=provider.name,
|
||||
)
|
||||
GPT=GPTFactory().from_config(config)
|
||||
return GPT.list_models()
|
||||
|
||||
try:
|
||||
config = ModelService._build_model_config(provider)
|
||||
gpt = GPTFactory().from_config(config)
|
||||
models = gpt.list_models()
|
||||
if verbose:
|
||||
print(f"[{provider['name']}] 模型列表: {models}")
|
||||
return models
|
||||
except Exception as e:
|
||||
print(f"[{provider['name']}] 获取模型失败: {e}")
|
||||
return []
|
||||
|
||||
@staticmethod
|
||||
def get_all_models(verbose: bool = False):
|
||||
try:
|
||||
raw_models = get_all_models()
|
||||
if verbose:
|
||||
print(f"所有模型列表: {raw_models}")
|
||||
return ModelService._format_models(raw_models)
|
||||
except Exception as e:
|
||||
print(f"获取所有模型失败: {e}")
|
||||
return []
|
||||
|
||||
@staticmethod
|
||||
def _format_models(raw_models: list) -> list:
|
||||
"""
|
||||
格式化模型列表
|
||||
"""
|
||||
formatted = []
|
||||
for model in raw_models:
|
||||
formatted.append({
|
||||
"id": model.get("id"),
|
||||
"provider_id": model.get("provider_id"),
|
||||
"model_name": model.get("model_name"),
|
||||
"created_at": model.get("created_at", None), # 如果有created_at字段
|
||||
})
|
||||
return formatted
|
||||
@staticmethod
|
||||
def get_all_models_by_id(provider_id: str, verbose: bool = False):
|
||||
try:
|
||||
provider = ProviderService.get_provider_by_id(provider_id)
|
||||
|
||||
models = ModelService.get_model_list(provider["id"], verbose=verbose)
|
||||
|
||||
model_list={
|
||||
|
||||
"models": models
|
||||
}
|
||||
|
||||
return model_list
|
||||
except Exception as e:
|
||||
print(f"[{provider_id}] 获取模型失败: {e}")
|
||||
return []
|
||||
@staticmethod
|
||||
def connect_test(api_key: str, base_url: str) -> bool:
|
||||
try:
|
||||
return OpenAICompatibleProvider.test_connection(api_key=api_key, base_url=base_url)
|
||||
except Exception as e:
|
||||
print(f"连接测试失败:{e}")
|
||||
return False
|
||||
|
||||
@staticmethod
|
||||
def add_new_model(provider_id: int, model_name: str) -> bool:
|
||||
try:
|
||||
# 先查供应商是否存在
|
||||
provider = ProviderService.get_provider_by_id(provider_id)
|
||||
if not provider:
|
||||
print(f"供应商ID {provider_id} 不存在,无法添加模型")
|
||||
return False
|
||||
|
||||
# 插入模型
|
||||
insert_model(provider_id=provider_id, model_name=model_name)
|
||||
print(f"模型 {model_name} 已成功添加到供应商ID {provider_id}")
|
||||
return True
|
||||
except Exception as e:
|
||||
print(f"添加模型失败: {e}")
|
||||
return False
|
||||
|
||||
if __name__ == '__main__':
|
||||
print(ModelService.get_model_list(1))
|
||||
# 单个 Provider 测试
|
||||
print(ModelService.get_model_list(1, verbose=True))
|
||||
|
||||
# 所有 Provider 模型测试
|
||||
# print(ModelService.get_all_models(verbose=True))
|
||||
|
||||
@@ -1,5 +1,9 @@
|
||||
import json
|
||||
from dataclasses import asdict
|
||||
|
||||
from app.enmus.task_status_enums import TaskStatus
|
||||
import os
|
||||
from typing import Union
|
||||
from typing import Union, Optional
|
||||
|
||||
from pydantic import HttpUrl
|
||||
|
||||
@@ -10,13 +14,17 @@ from app.downloaders.douyin_downloader import DouyinDownloader
|
||||
from app.downloaders.youtube_downloader import YoutubeDownloader
|
||||
from app.gpt.base import GPT
|
||||
from app.gpt.deepseek_gpt import DeepSeekGPT
|
||||
from app.gpt.gpt_factory import GPTFactory
|
||||
from app.gpt.openai_gpt import OpenaiGPT
|
||||
from app.gpt.qwen_gpt import QwenGPT
|
||||
from app.models.gpt_model import GPTSource
|
||||
from app.models.model_config import ModelConfig
|
||||
from app.models.notes_model import NoteResult
|
||||
from app.models.notes_model import AudioDownloadResult
|
||||
from app.enmus.note_enums import DownloadQuality
|
||||
from app.models.transcriber_model import TranscriptResult
|
||||
from app.models.transcriber_model import TranscriptResult, TranscriptSegment
|
||||
|
||||
from app.services.provider import ProviderService
|
||||
from app.transcriber.base import Transcriber
|
||||
from app.transcriber.transcriber_provider import get_transcriber,_transcribers
|
||||
from app.transcriber.whisper import WhisperTranscriber
|
||||
@@ -29,6 +37,8 @@ from app.utils.video_helper import generate_screenshot
|
||||
# from app.services.gpt import summarize_text
|
||||
from dotenv import load_dotenv
|
||||
from app.utils.logger import get_logger
|
||||
from events import transcription_finished
|
||||
|
||||
logger = get_logger(__name__)
|
||||
load_dotenv()
|
||||
BACKEND_BASE_URL = os.getenv("API_BASE_URL", "http://localhost:8000")
|
||||
@@ -37,7 +47,7 @@ output_dir = os.getenv('OUT_DIR')
|
||||
image_base_url = os.getenv('IMAGE_BASE_URL')
|
||||
logger.info("starting up")
|
||||
|
||||
|
||||
NOTE_OUTPUT_DIR = "note_results"
|
||||
|
||||
class NoteGenerator:
|
||||
def __init__(self):
|
||||
@@ -45,26 +55,39 @@ class NoteGenerator:
|
||||
self.device: Union[str, None] = None
|
||||
self.transcriber_type = os.getenv('TRANSCRIBER_TYPE','fast-whisper')
|
||||
self.transcriber = self.get_transcriber()
|
||||
# TODO 需要更换为可调节
|
||||
|
||||
self.provider = os.getenv('MODEl_PROVIDER','openai')
|
||||
self.video_path = None
|
||||
logger.info("初始化NoteGenerator")
|
||||
|
||||
import logging
|
||||
|
||||
def get_gpt(self) -> GPT:
|
||||
if self.provider == 'openai':
|
||||
logger.info("使用OpenAI")
|
||||
return OpenaiGPT()
|
||||
elif self.provider == 'deepSeek':
|
||||
logger.info("使用DeepSeek")
|
||||
return DeepSeekGPT()
|
||||
elif self.provider == 'qwen':
|
||||
logger.info("使用Qwen")
|
||||
return QwenGPT()
|
||||
else:
|
||||
logger.warning("不支持的AI提供商")
|
||||
raise ValueError(f"不支持的AI提供商:{self.provider}")
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@staticmethod
|
||||
def update_task_status(task_id: str, status: Union[str, TaskStatus], message: Optional[str] = None):
|
||||
os.makedirs(NOTE_OUTPUT_DIR, exist_ok=True)
|
||||
path = os.path.join(NOTE_OUTPUT_DIR, f"{task_id}.status.json")
|
||||
content = {"status": status.value if isinstance(status, TaskStatus) else status}
|
||||
if message:
|
||||
content["message"] = message
|
||||
with open(path, "w", encoding="utf-8") as f:
|
||||
json.dump(content, f, ensure_ascii=False, indent=2)
|
||||
|
||||
def get_gpt(self, model_name: str = None, provider_id: str = None) -> GPT:
|
||||
provider = ProviderService.get_provider_by_id(provider_id)
|
||||
if not provider:
|
||||
logger.error(f"[get_gpt] 未找到对应的模型供应商: provider_id={provider_id}")
|
||||
raise ValueError(f"未找到对应的模型供应商: provider_id={provider_id}")
|
||||
|
||||
gpt = GPTFactory().from_config(
|
||||
ModelConfig(
|
||||
api_key=provider.get('api_key'),
|
||||
base_url=provider.get('base_url'),
|
||||
model_name=model_name,
|
||||
provider=provider.get('type'),
|
||||
name=provider.get('name')
|
||||
)
|
||||
)
|
||||
return gpt
|
||||
|
||||
def get_downloader(self, platform: str) -> Downloader:
|
||||
if platform == "bilibili":
|
||||
@@ -98,7 +121,7 @@ class NoteGenerator:
|
||||
insert_video_task(video_id=video_id, platform=platform, task_id=task_id)
|
||||
|
||||
def insert_screenshots_into_markdown(self, markdown: str, video_path: str, image_base_url: str,
|
||||
output_dir: str) -> str:
|
||||
output_dir: str,_format:list) -> str:
|
||||
"""
|
||||
扫描 markdown 中的 *Screenshot-xx:xx,生成截图并插入 markdown 图片
|
||||
:param markdown:
|
||||
@@ -145,62 +168,143 @@ class NoteGenerator:
|
||||
|
||||
def generate(
|
||||
self,
|
||||
|
||||
video_url: Union[str, HttpUrl],
|
||||
platform: str,
|
||||
quality: DownloadQuality = DownloadQuality.medium,
|
||||
task_id: Union[str, None] = None,
|
||||
model_name: str = None,
|
||||
provider_id: str = None,
|
||||
link: bool = False,
|
||||
screenshot: bool = False,
|
||||
_format: list = None,
|
||||
style: str = None,
|
||||
extras: str = None,
|
||||
path: Union[str, None] = None
|
||||
|
||||
) -> NoteResult:
|
||||
logger.info(f"开始解析并生成笔记")
|
||||
# 1. 选择下载器
|
||||
downloader = self.get_downloader(platform)
|
||||
gpt = self.get_gpt()
|
||||
logger.info(f'使用{downloader.__class__.__name__}下载器\n'
|
||||
f'使用{gpt.__class__.__name__}GPT\n'
|
||||
f'视频地址:{video_url}')
|
||||
if screenshot:
|
||||
try:
|
||||
logger.info(f"🎯 开始解析并生成笔记,task_id={task_id}")
|
||||
self.update_task_status(task_id, TaskStatus.PARSING)
|
||||
_path=''
|
||||
downloader = self.get_downloader(platform)
|
||||
gpt = self.get_gpt(model_name=model_name, provider_id=provider_id)
|
||||
|
||||
video_path = downloader.download_video(video_url)
|
||||
self.video_path = video_path
|
||||
print(video_path)
|
||||
audio_cache_path = os.path.join(NOTE_OUTPUT_DIR, f"{task_id}_audio.json")
|
||||
transcript_cache_path = os.path.join(NOTE_OUTPUT_DIR, f"{task_id}_transcript.json")
|
||||
markdown_cache_path = os.path.join(NOTE_OUTPUT_DIR, f"{task_id}_markdown.md")
|
||||
|
||||
# 2. 下载音频
|
||||
audio: AudioDownloadResult = downloader.download(
|
||||
video_url=video_url,
|
||||
quality=quality,
|
||||
output_dir=path,
|
||||
need_video=screenshot
|
||||
# -------- 1. 下载音频 --------
|
||||
try:
|
||||
self.update_task_status(task_id, TaskStatus.DOWNLOADING)
|
||||
if os.path.exists(audio_cache_path):
|
||||
logger.info(f"检测到已有音频缓存,直接读取,task_id={task_id}")
|
||||
with open(audio_cache_path, "r", encoding="utf-8") as f:
|
||||
audio_data = json.load(f)
|
||||
audio = AudioDownloadResult(**audio_data)
|
||||
else:
|
||||
if 'screenshot' in _format:
|
||||
video_path = downloader.download_video(video_url)
|
||||
self.video_path = video_path
|
||||
logger.info(f"成功下载视频文件: {video_path}")
|
||||
screenshot= 'screenshot' in _format
|
||||
audio: AudioDownloadResult = downloader.download(
|
||||
video_url=video_url,
|
||||
quality=quality,
|
||||
output_dir=path,
|
||||
need_video=screenshot
|
||||
)
|
||||
_path=audio.raw_info.get('path')
|
||||
with open(audio_cache_path, "w", encoding="utf-8") as f:
|
||||
json.dump(audio.__dict__, f, ensure_ascii=False, indent=2)
|
||||
logger.info(f"音频下载并缓存成功,task_id={task_id}")
|
||||
except Exception as e:
|
||||
logger.error(f"❌ 下载音频失败,task_id={task_id},错误信息:{e}")
|
||||
self.update_task_status(task_id, TaskStatus.FAILED, message=f"下载音频失败:{e}")
|
||||
raise e
|
||||
|
||||
# -------- 2. 转写文字 --------
|
||||
try:
|
||||
self.update_task_status(task_id, TaskStatus.TRANSCRIBING)
|
||||
if os.path.exists(transcript_cache_path):
|
||||
logger.info(f"检测到已有转写缓存,直接读取,task_id={task_id}")
|
||||
with open(transcript_cache_path, "r", encoding="utf-8") as f:
|
||||
transcript_data = json.load(f)
|
||||
transcript = TranscriptResult(
|
||||
language=transcript_data["language"],
|
||||
full_text=transcript_data["full_text"],
|
||||
segments=[TranscriptSegment(**seg) for seg in transcript_data["segments"]]
|
||||
)
|
||||
else:
|
||||
transcript: TranscriptResult = self.transcriber.transcript(file_path=audio.file_path)
|
||||
with open(transcript_cache_path, "w", encoding="utf-8") as f:
|
||||
json.dump(asdict(transcript), f, ensure_ascii=False, indent=2)
|
||||
logger.info(f"文字转写并缓存成功,task_id={task_id}")
|
||||
except Exception as e:
|
||||
logger.error(f"❌ 转写文字失败,task_id={task_id},错误信息:{e}")
|
||||
self.update_task_status(task_id, TaskStatus.FAILED, message=f"转写文字失败:{e}")
|
||||
raise e
|
||||
|
||||
# -------- 3. 总结内容 --------
|
||||
try:
|
||||
self.update_task_status(task_id, TaskStatus.SUMMARIZING)
|
||||
if os.path.exists(markdown_cache_path):
|
||||
logger.info(f"检测到已有总结缓存,直接读取,task_id={task_id}")
|
||||
with open(markdown_cache_path, "r", encoding="utf-8") as f:
|
||||
markdown = f.read()
|
||||
else:
|
||||
source = GPTSource(
|
||||
title=audio.title,
|
||||
segment=transcript.segments,
|
||||
tags=audio.raw_info.get('tags'),
|
||||
screenshot=screenshot,
|
||||
link=link,
|
||||
_format=_format,
|
||||
style=style,
|
||||
extras=extras
|
||||
)
|
||||
|
||||
markdown: str = gpt.summarize(source)
|
||||
with open(markdown_cache_path, "w", encoding="utf-8") as f:
|
||||
f.write(markdown)
|
||||
logger.info(f"GPT总结并缓存成功,task_id={task_id}")
|
||||
except Exception as e:
|
||||
logger.error(f"❌ 总结内容失败,task_id={task_id},错误信息:{e}")
|
||||
self.update_task_status(task_id, TaskStatus.FAILED, message=f"总结内容失败:{e}")
|
||||
raise e
|
||||
|
||||
# -------- 4. 插入截图 --------
|
||||
if _format and 'screenshot' in _format:
|
||||
try:
|
||||
markdown = self.insert_screenshots_into_markdown(markdown, self.video_path, image_base_url, output_dir,_format)
|
||||
except Exception as e:
|
||||
logger.warning(f"⚠️ 插入截图失败,跳过处理,task_id={task_id},错误信息:{e}")
|
||||
if _format and 'link' in _format:
|
||||
try:
|
||||
markdown = replace_content_markers(markdown, video_id=audio.video_id,platform=platform)
|
||||
except Exception as e:
|
||||
logger.warning(f"⚠️ 插入链接失败,跳过处理,task_id={task_id},错误信息:{e}")
|
||||
# 注意:截图失败不终止整体流程
|
||||
|
||||
# -------- 5. 保存数据库记录 --------
|
||||
self.update_task_status(task_id, TaskStatus.SAVING)
|
||||
self.save_meta(video_id=audio.video_id, platform=platform, task_id=task_id)
|
||||
|
||||
# -------- 6. 完成 --------
|
||||
self.update_task_status(task_id, TaskStatus.SUCCESS)
|
||||
logger.info(f"✅ 笔记生成成功,task_id={task_id}")
|
||||
transcription_finished.send({
|
||||
"file_path": audio.file_path,
|
||||
})
|
||||
return NoteResult(
|
||||
markdown=markdown,
|
||||
transcript=transcript,
|
||||
audio_meta=audio
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"❌ 笔记生成流程异常终止,task_id={task_id},错误信息:{e}")
|
||||
self.update_task_status(task_id, TaskStatus.FAILED, message=str(e))
|
||||
raise f'❌ 笔记生成流程异常终止,task_id={task_id},错误信息:{e}'
|
||||
|
||||
)
|
||||
logger.info(f"下载音频成功,文件路径:{audio.file_path}")
|
||||
# 3. Whisper 转写
|
||||
transcript: TranscriptResult = self.transcriber.transcript(file_path=audio.file_path)
|
||||
logger.info(f"Whisper 转写成功,转写结果:{transcript.full_text}")
|
||||
# 4. GPT 总结
|
||||
source = GPTSource(
|
||||
title=audio.title,
|
||||
segment=transcript.segments,
|
||||
tags=audio.raw_info.get('tags'),
|
||||
screenshot=screenshot,
|
||||
link=link
|
||||
)
|
||||
logger.info(f"GPT 总结完成,总结结果:{source}")
|
||||
markdown: str = gpt.summarize(source)
|
||||
print("markdown结果", markdown)
|
||||
|
||||
markdown = replace_content_markers(markdown=markdown, video_id=audio.video_id, platform=platform)
|
||||
if self.video_path:
|
||||
markdown = self.insert_screenshots_into_markdown(markdown, self.video_path, image_base_url, output_dir)
|
||||
self.save_meta(video_id=audio.video_id, platform=platform, task_id=task_id)
|
||||
# 5. 返回结构体
|
||||
return NoteResult(
|
||||
markdown=markdown,
|
||||
transcript=transcript,
|
||||
audio_meta=audio
|
||||
)
|
||||
|
||||
|
||||
|
||||
@@ -1,3 +1,5 @@
|
||||
from kombu import uuid
|
||||
|
||||
from app.db.provider_dao import (
|
||||
insert_provider,
|
||||
init_provider_table,
|
||||
@@ -5,50 +7,65 @@ from app.db.provider_dao import (
|
||||
get_provider_by_name,
|
||||
get_provider_by_id,
|
||||
update_provider,
|
||||
delete_provider,
|
||||
delete_provider, get_enabled_providers,
|
||||
)
|
||||
from app.gpt.gpt_factory import GPTFactory
|
||||
from app.models.model_config import ModelConfig
|
||||
|
||||
|
||||
class ProviderService:
|
||||
@staticmethod
|
||||
def serialize_provider(row: tuple) -> dict:
|
||||
if not row:
|
||||
return None
|
||||
return {
|
||||
"id": row[0],
|
||||
"name": row[1],
|
||||
"logo": row[2],
|
||||
"type": row[3],
|
||||
"api_key": row[4],
|
||||
"base_url": row[5],
|
||||
"enabled": row[6],
|
||||
"created_at": row[7],
|
||||
}
|
||||
|
||||
@staticmethod
|
||||
def add_provider(name: str, api_key: str, base_url: str, logo: str, type_: str):
|
||||
return insert_provider(name, api_key, base_url, logo, type_)
|
||||
def add_provider( name: str, api_key: str, base_url: str, logo: str, type_: str, enabled: int = 1):
|
||||
try:
|
||||
id = uuid().lower()
|
||||
logo='custom'
|
||||
return insert_provider(id, name, api_key, base_url, logo, type_, enabled)
|
||||
except Exception as e:
|
||||
print('创建模式失败',e)
|
||||
|
||||
@staticmethod
|
||||
def get_all_providers():
|
||||
provider_list = []
|
||||
provider = get_all_providers()
|
||||
|
||||
for i in provider:
|
||||
provider_list.append({
|
||||
"id": i[0],
|
||||
"name": i[1],
|
||||
"logo": i[2],
|
||||
"type": i[3], # ✅ 加上类型
|
||||
"api_key": i[4],
|
||||
"base_url": i[5],
|
||||
})
|
||||
return provider_list
|
||||
rows = get_all_providers()
|
||||
return [ProviderService.serialize_provider(row) for row in rows] if rows else []
|
||||
|
||||
@staticmethod
|
||||
def get_provider_by_name(name: str):
|
||||
return get_provider_by_name(name)
|
||||
row = get_provider_by_name(name)
|
||||
return ProviderService.serialize_provider(row)
|
||||
|
||||
@staticmethod
|
||||
def get_provider_by_id(id: int):
|
||||
return get_provider_by_id(id)
|
||||
def get_provider_by_id(id: str): # 已改为 str 类型
|
||||
row = get_provider_by_id(id)
|
||||
return ProviderService.serialize_provider(row)
|
||||
|
||||
# all_models.extend(provider['models'])
|
||||
|
||||
@staticmethod
|
||||
def update_provider(
|
||||
id: int,
|
||||
name: str,
|
||||
api_key: str,
|
||||
base_url: str,
|
||||
logo: str,
|
||||
type_: str
|
||||
):
|
||||
return update_provider(id, name, api_key, base_url, logo, type_)
|
||||
def update_provider(id: str, data: dict):
|
||||
try:
|
||||
# 过滤掉空值
|
||||
filtered_data = {k: v for k, v in data.items() if v is not None and k != 'id'}
|
||||
print('更新模型供应商',filtered_data)
|
||||
return update_provider(id, **filtered_data)
|
||||
|
||||
except Exception as e:
|
||||
print('更新模型供应商失败:',e)
|
||||
|
||||
@staticmethod
|
||||
def delete_provider(id: int):
|
||||
return delete_provider(id)
|
||||
def delete_provider(id: str):
|
||||
return delete_provider(id)
|
||||
|
||||
251
backend/app/transcriber/bcut.py
Normal file
251
backend/app/transcriber/bcut.py
Normal file
@@ -0,0 +1,251 @@
|
||||
import json
|
||||
import logging
|
||||
import time
|
||||
from typing import Optional, List, Dict, Union
|
||||
|
||||
import requests
|
||||
|
||||
from app.decorators.timeit import timeit
|
||||
from app.models.transcriber_model import TranscriptSegment, TranscriptResult
|
||||
from app.transcriber.base import Transcriber
|
||||
from app.utils.logger import get_logger
|
||||
from events import transcription_finished
|
||||
|
||||
__version__ = "0.0.3"
|
||||
|
||||
API_BASE_URL = "https://member.bilibili.com/x/bcut/rubick-interface"
|
||||
|
||||
# 申请上传
|
||||
API_REQ_UPLOAD = API_BASE_URL + "/resource/create"
|
||||
|
||||
# 提交上传
|
||||
API_COMMIT_UPLOAD = API_BASE_URL + "/resource/create/complete"
|
||||
|
||||
# 创建任务
|
||||
API_CREATE_TASK = API_BASE_URL + "/task"
|
||||
|
||||
# 查询结果
|
||||
API_QUERY_RESULT = API_BASE_URL + "/task/result"
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
class BcutTranscriber(Transcriber):
|
||||
"""必剪 语音识别接口"""
|
||||
headers = {
|
||||
'User-Agent': 'Bilibili/1.0.0 (https://www.bilibili.com)',
|
||||
'Content-Type': 'application/json'
|
||||
}
|
||||
|
||||
def __init__(self):
|
||||
self.session = requests.Session()
|
||||
self.task_id = None
|
||||
self.__etags = []
|
||||
|
||||
self.__in_boss_key: Optional[str] = None
|
||||
self.__resource_id: Optional[str] = None
|
||||
self.__upload_id: Optional[str] = None
|
||||
self.__upload_urls: List[str] = []
|
||||
self.__per_size: Optional[int] = None
|
||||
self.__clips: Optional[int] = None
|
||||
|
||||
self.__etags: List[str] = []
|
||||
self.__download_url: Optional[str] = None
|
||||
self.task_id: Optional[str] = None
|
||||
|
||||
def _load_file(self, file_path: str) -> bytes:
|
||||
"""读取文件内容"""
|
||||
with open(file_path, 'rb') as f:
|
||||
return f.read()
|
||||
|
||||
def _upload(self, file_path: str) -> None:
|
||||
"""申请上传"""
|
||||
file_binary = self._load_file(file_path)
|
||||
if not file_binary:
|
||||
raise ValueError("无法读取文件数据")
|
||||
|
||||
payload = json.dumps({
|
||||
"type": 2,
|
||||
"name": "audio.mp3",
|
||||
"size": len(file_binary),
|
||||
"ResourceFileType": "mp3",
|
||||
"model_id": "8",
|
||||
})
|
||||
|
||||
resp = self.session.post(
|
||||
API_REQ_UPLOAD,
|
||||
data=payload,
|
||||
headers=self.headers
|
||||
)
|
||||
resp.raise_for_status()
|
||||
resp = resp.json()
|
||||
resp_data = resp["data"]
|
||||
|
||||
self.__in_boss_key = resp_data["in_boss_key"]
|
||||
self.__resource_id = resp_data["resource_id"]
|
||||
self.__upload_id = resp_data["upload_id"]
|
||||
self.__upload_urls = resp_data["upload_urls"]
|
||||
self.__per_size = resp_data["per_size"]
|
||||
self.__clips = len(resp_data["upload_urls"])
|
||||
|
||||
logger.info(
|
||||
f"申请上传成功, 总计大小{resp_data['size'] // 1024}KB, {self.__clips}分片, 分片大小{resp_data['per_size'] // 1024}KB: {self.__in_boss_key}"
|
||||
)
|
||||
self.__upload_part(file_binary)
|
||||
self.__commit_upload()
|
||||
|
||||
def __upload_part(self, file_binary: bytes) -> None:
|
||||
"""上传音频数据"""
|
||||
for clip in range(self.__clips):
|
||||
start_range = clip * self.__per_size
|
||||
end_range = min((clip + 1) * self.__per_size, len(file_binary))
|
||||
logger.info(f"开始上传分片{clip}: {start_range}-{end_range}")
|
||||
resp = self.session.put(
|
||||
self.__upload_urls[clip],
|
||||
data=file_binary[start_range:end_range],
|
||||
headers={'Content-Type': 'application/octet-stream'}
|
||||
)
|
||||
resp.raise_for_status()
|
||||
etag = resp.headers.get("Etag", "").strip('"')
|
||||
self.__etags.append(etag)
|
||||
logger.info(f"分片{clip}上传成功: {etag}")
|
||||
|
||||
def __commit_upload(self) -> None:
|
||||
"""提交上传数据"""
|
||||
data = json.dumps({
|
||||
"InBossKey": self.__in_boss_key,
|
||||
"ResourceId": self.__resource_id,
|
||||
"Etags": ",".join(self.__etags),
|
||||
"UploadId": self.__upload_id,
|
||||
"model_id": "8",
|
||||
})
|
||||
resp = self.session.post(
|
||||
API_COMMIT_UPLOAD,
|
||||
data=data,
|
||||
headers=self.headers
|
||||
)
|
||||
resp.raise_for_status()
|
||||
resp = resp.json()
|
||||
print('Bili',resp)
|
||||
if resp.get("code") != 0:
|
||||
error_msg = f"上传提交失败: {resp.get('message', '未知错误')}"
|
||||
logger.error(error_msg)
|
||||
raise Exception(error_msg)
|
||||
|
||||
self.__download_url = resp["data"]["download_url"]
|
||||
logger.info(f"提交成功,下载链接: {self.__download_url}")
|
||||
|
||||
def _create_task(self) -> str:
|
||||
"""开始创建转换任务"""
|
||||
resp = self.session.post(
|
||||
API_CREATE_TASK, json={"resource": self.__download_url, "model_id": "8"}, headers=self.headers
|
||||
)
|
||||
resp.raise_for_status()
|
||||
resp = resp.json()
|
||||
if resp.get("code") != 0:
|
||||
error_msg = f"创建任务失败: {resp.get('message', '未知错误')}"
|
||||
logger.error(error_msg)
|
||||
raise Exception(error_msg)
|
||||
|
||||
self.task_id = resp["data"]["task_id"]
|
||||
logger.info(f"任务已创建: {self.task_id}")
|
||||
return self.task_id
|
||||
|
||||
def _query_result(self) -> dict:
|
||||
"""查询转换结果"""
|
||||
resp = self.session.get(
|
||||
API_QUERY_RESULT,
|
||||
params={"model_id": 7, "task_id": self.task_id},
|
||||
headers=self.headers
|
||||
)
|
||||
resp.raise_for_status()
|
||||
resp = resp.json()
|
||||
if resp.get("code") != 0:
|
||||
error_msg = f"查询结果失败: {resp.get('message', '未知错误')}"
|
||||
logger.error(error_msg)
|
||||
raise Exception(error_msg)
|
||||
|
||||
return resp["data"]
|
||||
|
||||
@timeit
|
||||
def transcript(self, file_path: str) -> TranscriptResult:
|
||||
"""执行识别过程,符合 Transcriber 接口"""
|
||||
try:
|
||||
logger.info(f"开始处理文件: {file_path}")
|
||||
|
||||
# 上传文件
|
||||
logger.info("正在上传文件...")
|
||||
self._upload(file_path)
|
||||
|
||||
# 创建任务
|
||||
logger.info("提交转录任务...")
|
||||
self._create_task()
|
||||
|
||||
# 轮询检查任务状态
|
||||
logger.info("等待转录结果...")
|
||||
task_resp = None
|
||||
max_retries = 500
|
||||
for i in range(max_retries):
|
||||
task_resp = self._query_result()
|
||||
|
||||
if task_resp["state"] == 4: # 完成状态
|
||||
break
|
||||
elif task_resp["state"] == 3: # 失败状态
|
||||
error_msg = f"B站ASR任务失败,状态码: {task_resp['state']}"
|
||||
logger.error(error_msg)
|
||||
raise Exception(error_msg)
|
||||
|
||||
# 每隔一段时间打印进度
|
||||
if i % 10 == 0:
|
||||
logger.info(f"转录进行中... {i}/{max_retries}")
|
||||
|
||||
time.sleep(1)
|
||||
|
||||
if not task_resp or task_resp["state"] != 4:
|
||||
error_msg = f"B站ASR任务未能完成,状态: {task_resp.get('state') if task_resp else 'Unknown'}"
|
||||
logger.error(error_msg)
|
||||
raise Exception(error_msg)
|
||||
|
||||
# 解析结果
|
||||
logger.info("转录成功,处理结果...")
|
||||
result_json = json.loads(task_resp["result"])
|
||||
|
||||
# 提取分段数据
|
||||
segments = []
|
||||
full_text = ""
|
||||
|
||||
for u in result_json.get("utterances", []):
|
||||
text = u.get("transcript", "").strip()
|
||||
# B站ASR返回的时间戳是毫秒,需要转换为秒
|
||||
start_time = float(u.get("start_time", 0)) / 1000.0
|
||||
end_time = float(u.get("end_time", 0)) / 1000.0
|
||||
|
||||
full_text += text + " "
|
||||
segments.append(TranscriptSegment(
|
||||
start=start_time,
|
||||
end=end_time,
|
||||
text=text
|
||||
))
|
||||
|
||||
# 创建结果对象
|
||||
result = TranscriptResult(
|
||||
language=result_json.get("language", "zh"),
|
||||
full_text=full_text.strip(),
|
||||
segments=segments,
|
||||
raw=result_json
|
||||
)
|
||||
|
||||
# 触发完成事件
|
||||
# self.on_finish(file_path, result)
|
||||
|
||||
return result
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"B站ASR处理失败: {str(e)}")
|
||||
raise
|
||||
|
||||
def on_finish(self, video_path: str, result: TranscriptResult) -> None:
|
||||
"""转录完成的回调"""
|
||||
logger.info(f"B站ASR转写完成: {video_path}")
|
||||
transcription_finished.send({
|
||||
"file_path": video_path,
|
||||
})
|
||||
115
backend/app/transcriber/kuaishou.py
Normal file
115
backend/app/transcriber/kuaishou.py
Normal file
@@ -0,0 +1,115 @@
|
||||
import requests
|
||||
import logging
|
||||
import os
|
||||
from typing import Union, List, Dict, Optional
|
||||
|
||||
from app.decorators.timeit import timeit
|
||||
from app.models.transcriber_model import TranscriptSegment, TranscriptResult
|
||||
from app.transcriber.base import Transcriber
|
||||
from app.utils.logger import get_logger
|
||||
from events import transcription_finished
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
class KuaishouTranscriber(Transcriber):
|
||||
"""快手语音识别实现"""
|
||||
|
||||
API_URL = "https://ai.kuaishou.com/api/effects/subtitle_generate"
|
||||
|
||||
def __init__(self):
|
||||
pass
|
||||
|
||||
def _load_file(self, file_path: str) -> bytes:
|
||||
"""读取文件内容"""
|
||||
with open(file_path, 'rb') as f:
|
||||
return f.read()
|
||||
|
||||
def _submit(self, file_path: str) -> dict:
|
||||
"""提交识别请求"""
|
||||
try:
|
||||
file_binary = self._load_file(file_path)
|
||||
|
||||
payload = {
|
||||
"typeId": "1"
|
||||
}
|
||||
|
||||
# 使用文件名作为上传文件名
|
||||
file_name = os.path.basename(file_path)
|
||||
files = [('file', (file_name, file_binary, 'audio/mpeg'))]
|
||||
|
||||
logger.info(f"开始向快手API提交请求,文件: {file_name}")
|
||||
response = requests.post(self.API_URL, data=payload, files=files, timeout=300)
|
||||
response.raise_for_status() # 检查HTTP错误
|
||||
|
||||
result = response.json()
|
||||
print('result',result)
|
||||
# 检查快手API返回是否包含错误
|
||||
if "data" not in result or result.get("code", 0) != 0:
|
||||
error_msg = f"快手API返回错误: {result.get('message', '未知错误')}"
|
||||
logger.error(error_msg)
|
||||
raise Exception(error_msg)
|
||||
|
||||
return result
|
||||
|
||||
except requests.exceptions.RequestException as e:
|
||||
error_msg = f"快手ASR请求网络错误: {str(e)}"
|
||||
logger.error(error_msg)
|
||||
raise
|
||||
except Exception as e:
|
||||
error_msg = f"快手ASR请求处理错误: {str(e)}"
|
||||
logger.error(error_msg)
|
||||
raise
|
||||
|
||||
@timeit
|
||||
def transcript(self, file_path: str) -> TranscriptResult:
|
||||
"""执行转录过程,符合 Transcriber 接口"""
|
||||
try:
|
||||
logger.info(f"开始处理文件: {file_path}")
|
||||
|
||||
# 提交请求并获取结果
|
||||
logger.info("向快手API提交识别请求...")
|
||||
result_data = self._submit(file_path)
|
||||
|
||||
logger.info("请求成功,处理结果...")
|
||||
|
||||
# 提取分段数据
|
||||
segments = []
|
||||
full_text = ""
|
||||
|
||||
# 解析快手API返回的文本段
|
||||
texts = result_data.get('data', {}).get('text', [])
|
||||
for u in texts:
|
||||
text = u.get('text', '').strip()
|
||||
start_time = float(u.get('start_time', 0))
|
||||
end_time = float(u.get('end_time', 0))
|
||||
|
||||
full_text += text + " "
|
||||
segments.append(TranscriptSegment(
|
||||
start=start_time,
|
||||
end=end_time,
|
||||
text=text
|
||||
))
|
||||
|
||||
# 创建结果对象
|
||||
result = TranscriptResult(
|
||||
language="zh", # 快手API可能不返回语言信息,默认为中文
|
||||
full_text=full_text.strip(),
|
||||
segments=segments,
|
||||
raw=result_data
|
||||
)
|
||||
|
||||
# 触发完成事件
|
||||
# self.on_finish(file_path, result)
|
||||
|
||||
return result
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"快手ASR处理失败: {str(e)}")
|
||||
raise
|
||||
|
||||
def on_finish(self, video_path: str, result: TranscriptResult) -> None:
|
||||
"""转录完成的回调"""
|
||||
logger.info(f"快手ASR转写完成: {video_path}")
|
||||
transcription_finished.send({
|
||||
"file_path": video_path,
|
||||
})
|
||||
@@ -31,15 +31,15 @@ _transcribers = {
|
||||
|
||||
def get_whisper_transcriber(model_size="base", device="cuda"):
|
||||
"""获取 Whisper 转录器实例"""
|
||||
if _transcribers['fast-whisper'] is None:
|
||||
if _transcribers['fast-whisper'] is None:
|
||||
logger.info(f'创建 Whisper 转录器实例,参数:{model_size}, {device}')
|
||||
try:
|
||||
_transcribers['fast-whisper'] = WhisperTranscriber(model_size=model_size, device=device)
|
||||
_transcribers['whisper'] = WhisperTranscriber(model_size=model_size, device=device)
|
||||
logger.info('Whisper 转录器创建成功')
|
||||
except Exception as e:
|
||||
logger.error(f"Whisper 转录器创建失败: {e}")
|
||||
raise
|
||||
return _transcribers['fast-whisper']
|
||||
return _transcribers['whisper']
|
||||
|
||||
def get_bcut_transcriber():
|
||||
"""获取 Bcut 转录器实例"""
|
||||
|
||||
@@ -4,14 +4,19 @@ from app.decorators.timeit import timeit
|
||||
from app.models.transcriber_model import TranscriptSegment, TranscriptResult
|
||||
from app.transcriber.base import Transcriber
|
||||
from app.utils.env_checker import is_cuda_available, is_torch_installed
|
||||
from app.utils.logger import get_logger
|
||||
from app.utils.path_helper import get_model_dir
|
||||
|
||||
from events import transcription_finished
|
||||
from pathlib import Path
|
||||
import os
|
||||
from tqdm import tqdm
|
||||
from huggingface_hub import snapshot_download
|
||||
|
||||
'''
|
||||
Size of the model to use (tiny, tiny.en, base, base.en, small, small.en, distil-small.en, medium, medium.en, distil-medium.en, large-v1, large-v2, large-v3, large, distil-large-v2, distil-large-v3, large-v3-turbo, or turbo
|
||||
'''
|
||||
|
||||
logger=get_logger(__name__)
|
||||
|
||||
class WhisperTranscriber(Transcriber):
|
||||
# TODO:修改为可配置
|
||||
@@ -31,15 +36,25 @@ class WhisperTranscriber(Transcriber):
|
||||
|
||||
self.compute_type = compute_type or ("float16" if self.device == "cuda" else "int8")
|
||||
|
||||
model_path = get_model_dir("whisper")
|
||||
model_dir = get_model_dir("whisper")
|
||||
model_path = os.path.join(model_dir, f"whisper-{model_size}")
|
||||
if not Path(model_path).exists():
|
||||
logger.info(f"模型 whisper-{model_size} 不存在,开始下载...")
|
||||
repo_id = f"guillaumekln/faster-whisper-{model_size}"
|
||||
snapshot_download(
|
||||
repo_id,
|
||||
local_dir=model_path,
|
||||
local_dir_use_symlinks=False,
|
||||
)
|
||||
logger.info("模型下载完成")
|
||||
|
||||
self.model = WhisperModel(
|
||||
model_size,
|
||||
device=self.device,
|
||||
# compute_type="int8", # 或 "float16"
|
||||
compute_type=self.compute_type,
|
||||
cpu_threads=cpu_threads,
|
||||
download_root=model_path
|
||||
download_root=model_dir
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def is_torch_installed() -> bool:
|
||||
try:
|
||||
@@ -88,7 +103,7 @@ class WhisperTranscriber(Transcriber):
|
||||
segments=segments,
|
||||
raw=info
|
||||
)
|
||||
self.on_finish(file_path, result)
|
||||
# self.on_finish(file_path, result)
|
||||
return result
|
||||
except Exception as e:
|
||||
print(f"转写失败:{e}")
|
||||
|
||||
@@ -4,6 +4,7 @@ import uvicorn
|
||||
from starlette.staticfiles import StaticFiles
|
||||
from dotenv import load_dotenv
|
||||
|
||||
from app.db.model_dao import init_model_table
|
||||
from app.db.provider_dao import init_provider_table
|
||||
from app.utils.logger import get_logger
|
||||
from app import create_app
|
||||
@@ -39,6 +40,7 @@ async def startup_event():
|
||||
get_transcriber(transcriber_type=os.getenv("TRANSCRIBER_TYPE","fast-whisper"))
|
||||
init_video_task_table()
|
||||
init_provider_table()
|
||||
init_model_table()
|
||||
|
||||
if __name__ == "__main__":
|
||||
port = int(os.getenv("BACKEND_PORT", 8000))
|
||||
|
||||
Reference in New Issue
Block a user