mirror of
https://github.com/JefferyHcool/BiliNote.git
synced 2026-06-06 08:10:48 +08:00
refactor(backend): 重构后端异常处理和模型管理
- 新增自定义异常类 BizException、NoteError 和 ProviderError - 优化了模型管理相关的逻辑,包括加载、删除和测试连接等功能 - 改进了 Douyin 下载器的错误处理 - 调整了任务重试逻辑和笔记生成的异常处理- 更新了相关组件和页面以适应新的异常处理机制
This commit is contained in:
@@ -11,6 +11,7 @@ VITE_FRONTEND_PORT=3015
|
|||||||
ENV=production
|
ENV=production
|
||||||
STATIC=/static
|
STATIC=/static
|
||||||
OUT_DIR=./static/screenshots
|
OUT_DIR=./static/screenshots
|
||||||
|
NOTE_OUTPUT_DIR=note_results
|
||||||
IMAGE_BASE_URL=/static/screenshots
|
IMAGE_BASE_URL=/static/screenshots
|
||||||
DATA_DIR=data
|
DATA_DIR=data
|
||||||
# FFMPEG 配置
|
# FFMPEG 配置
|
||||||
|
|||||||
@@ -36,7 +36,7 @@ const DownloaderForm = () => {
|
|||||||
setLoading(true) // 🔁 切换平台时显示 loading
|
setLoading(true) // 🔁 切换平台时显示 loading
|
||||||
try {
|
try {
|
||||||
const res = await getDownloaderCookie(id)
|
const res = await getDownloaderCookie(id)
|
||||||
const cookie = res?.data?.data?.cookie || ''
|
const cookie = res?.cookie || ''
|
||||||
form.reset({ cookie }) // ✅ 正确重置表单值
|
form.reset({ cookie }) // ✅ 正确重置表单值
|
||||||
} catch (e) {
|
} catch (e) {
|
||||||
toast.error('加载 Cookie 失败: ' + e)
|
toast.error('加载 Cookie 失败: ' + e)
|
||||||
|
|||||||
@@ -129,11 +129,10 @@ const ProviderForm = ({ isCreate = false }: { isCreate?: boolean }) => {
|
|||||||
|
|
||||||
try {
|
try {
|
||||||
const res = await deleteModelById(modelId)
|
const res = await deleteModelById(modelId)
|
||||||
if (res.data.code === 0) {
|
console.log('🔧 删除结果:', res)
|
||||||
toast.success('删除成功')
|
|
||||||
} else {
|
toast.success('删除成功')
|
||||||
toast.error(res.data.msg || '删除失败')
|
|
||||||
}
|
|
||||||
} catch (e) {
|
} catch (e) {
|
||||||
toast.error('删除异常')
|
toast.error('删除异常')
|
||||||
}
|
}
|
||||||
@@ -151,16 +150,16 @@ const ProviderForm = ({ isCreate = false }: { isCreate?: boolean }) => {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
setTesting(true)
|
setTesting(true)
|
||||||
const data = await testConnection({
|
await testConnection({
|
||||||
id
|
id
|
||||||
})
|
})
|
||||||
if (data.data.code === 0) {
|
|
||||||
toast.success('测试连通性成功 🎉')
|
toast.success('测试连通性成功 🎉')
|
||||||
} else {
|
|
||||||
toast.error(`连接失败: ${data.data.msg || '未知错误'}`)
|
|
||||||
}
|
|
||||||
} catch (error) {
|
} catch (error) {
|
||||||
toast.error('测试连通性异常')
|
|
||||||
|
toast.error(`连接失败: ${data.data.msg || '未知错误'}`)
|
||||||
|
// toast.error('测试连通性异常')
|
||||||
} finally {
|
} finally {
|
||||||
setTesting(false)
|
setTesting(false)
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -26,11 +26,11 @@ export const useTaskPolling = (interval = 3000) => {
|
|||||||
try {
|
try {
|
||||||
console.log('🔄 正在轮询任务:', task.id)
|
console.log('🔄 正在轮询任务:', task.id)
|
||||||
const res = await get_task_status(task.id)
|
const res = await get_task_status(task.id)
|
||||||
const { status } = res.data
|
const { status } = res
|
||||||
|
|
||||||
if (status && status !== task.status) {
|
if (status && status !== task.status) {
|
||||||
if (status === 'SUCCESS') {
|
if (status === 'SUCCESS') {
|
||||||
const { markdown, transcript, audio_meta } = res.data.result
|
const { markdown, transcript, audio_meta } = res.result
|
||||||
toast.success('笔记生成成功')
|
toast.success('笔记生成成功')
|
||||||
updateTaskContent(task.id, {
|
updateTaskContent(task.id, {
|
||||||
status,
|
status,
|
||||||
@@ -47,7 +47,7 @@ export const useTaskPolling = (interval = 3000) => {
|
|||||||
}
|
}
|
||||||
} catch (e) {
|
} catch (e) {
|
||||||
console.error('❌ 任务轮询失败:', e)
|
console.error('❌ 任务轮询失败:', e)
|
||||||
toast.error(`生成失败 ${e.message || e}`)
|
// toast.error(`生成失败 ${e.message || e}`)
|
||||||
updateTaskContent(task.id, { status: 'FAILED' })
|
updateTaskContent(task.id, { status: 'FAILED' })
|
||||||
// removeTask(task.id)
|
// removeTask(task.id)
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -173,6 +173,7 @@ const MarkdownViewer: FC<MarkdownViewerProps> = ({ status }) => {
|
|||||||
<div className="text-center">
|
<div className="text-center">
|
||||||
<p className="text-lg font-bold text-red-500">笔记生成失败</p>
|
<p className="text-lg font-bold text-red-500">笔记生成失败</p>
|
||||||
<p className="mt-2 mb-2 text-xs text-red-400">请检查后台或稍后再试</p>
|
<p className="mt-2 mb-2 text-xs text-red-400">请检查后台或稍后再试</p>
|
||||||
|
|
||||||
<Button onClick={() => retryTask(currentTask.id)} size="lg">
|
<Button onClick={() => retryTask(currentTask.id)} size="lg">
|
||||||
重试
|
重试
|
||||||
</Button>
|
</Button>
|
||||||
|
|||||||
@@ -37,6 +37,7 @@ import {
|
|||||||
import { Input } from '@/components/ui/input.tsx'
|
import { Input } from '@/components/ui/input.tsx'
|
||||||
import { Textarea } from '@/components/ui/textarea.tsx'
|
import { Textarea } from '@/components/ui/textarea.tsx'
|
||||||
import { noteStyles, noteFormats, videoPlatforms } from '@/constant/note.ts'
|
import { noteStyles, noteFormats, videoPlatforms } from '@/constant/note.ts'
|
||||||
|
import { fetchModels } from '@/services/model.ts'
|
||||||
|
|
||||||
/* -------------------- 校验 Schema -------------------- */
|
/* -------------------- 校验 Schema -------------------- */
|
||||||
const formSchema = z
|
const formSchema = z
|
||||||
@@ -206,7 +207,7 @@ const NoteForm = () => {
|
|||||||
}
|
}
|
||||||
|
|
||||||
message.success('已提交任务')
|
message.success('已提交任务')
|
||||||
const { data } = await generateNote(payload)
|
const data = await generateNote(payload)
|
||||||
addPendingTask(data.task_id, values.platform, payload)
|
addPendingTask(data.task_id, values.platform, payload)
|
||||||
}
|
}
|
||||||
const onInvalid = (errors: FieldErrors<NoteFormValues>) => {
|
const onInvalid = (errors: FieldErrors<NoteFormValues>) => {
|
||||||
@@ -355,6 +356,9 @@ const NoteForm = () => {
|
|||||||
<FormItem>
|
<FormItem>
|
||||||
<SectionHeader title="模型选择" tip="不同模型效果不同,建议自行测试" />
|
<SectionHeader title="模型选择" tip="不同模型效果不同,建议自行测试" />
|
||||||
<Select
|
<Select
|
||||||
|
onOpenChange={()=>{
|
||||||
|
loadEnabledModels()
|
||||||
|
}}
|
||||||
value={field.value}
|
value={field.value}
|
||||||
onValueChange={field.onChange}
|
onValueChange={field.onChange}
|
||||||
defaultValue={field.value}
|
defaultValue={field.value}
|
||||||
|
|||||||
@@ -26,7 +26,7 @@ export default function AboutPage() {
|
|||||||
height={50}
|
height={50}
|
||||||
className="rounded-lg"
|
className="rounded-lg"
|
||||||
/>
|
/>
|
||||||
<h1 className="text-4xl font-bold">BiliNote v1.7.3</h1>
|
<h1 className="text-4xl font-bold">BiliNote v1.7.4</h1>
|
||||||
</div>
|
</div>
|
||||||
<p className="text-muted-foreground mb-6 text-xl italic">
|
<p className="text-muted-foreground mb-6 text-xl italic">
|
||||||
AI 视频笔记生成工具 让 AI 为你的视频做笔记
|
AI 视频笔记生成工具 让 AI 为你的视频做笔记
|
||||||
|
|||||||
@@ -1,7 +1,6 @@
|
|||||||
import request from '@/utils/request'
|
import request from '@/utils/request'
|
||||||
import toast from 'react-hot-toast'
|
import toast from 'react-hot-toast'
|
||||||
import { useTaskStore } from '@/store/taskStore'
|
|
||||||
import request from '@/utils/request'
|
|
||||||
export const generateNote = async (data: {
|
export const generateNote = async (data: {
|
||||||
video_url: string
|
video_url: string
|
||||||
platform: string
|
platform: string
|
||||||
@@ -14,12 +13,13 @@ export const generateNote = async (data: {
|
|||||||
extras?: string
|
extras?: string
|
||||||
video_understand?: boolean
|
video_understand?: boolean
|
||||||
video_interval?: number
|
video_interval?: number
|
||||||
grid_size:Array<number>
|
grid_size: Array<number>
|
||||||
}) => {
|
}) => {
|
||||||
try {
|
try {
|
||||||
|
console.log('generateNote', data)
|
||||||
const response = await request.post('/generate_note', data)
|
const response = await request.post('/generate_note', data)
|
||||||
|
|
||||||
if (response.data.code != 0) {
|
if (!response) {
|
||||||
if (response.data.msg) {
|
if (response.data.msg) {
|
||||||
toast.error(response.data.msg)
|
toast.error(response.data.msg)
|
||||||
}
|
}
|
||||||
@@ -30,12 +30,12 @@ export const generateNote = async (data: {
|
|||||||
console.log('res', response)
|
console.log('res', response)
|
||||||
// 成功提示
|
// 成功提示
|
||||||
|
|
||||||
return response.data
|
return response
|
||||||
} catch (e: any) {
|
} catch (e: any) {
|
||||||
console.error('❌ 请求出错', e)
|
console.error('❌ 请求出错', e)
|
||||||
|
|
||||||
// 错误提示
|
// 错误提示
|
||||||
toast.error('笔记生成失败,请稍后重试')
|
// toast.error('笔记生成失败,请稍后重试')
|
||||||
|
|
||||||
throw e // 抛出错误以便调用方处理
|
throw e // 抛出错误以便调用方处理
|
||||||
}
|
}
|
||||||
@@ -65,15 +65,9 @@ export const delete_task = async ({ video_id, platform }) => {
|
|||||||
|
|
||||||
export const get_task_status = async (task_id: string) => {
|
export const get_task_status = async (task_id: string) => {
|
||||||
try {
|
try {
|
||||||
const response = await request.get('/task_status/' + task_id)
|
|
||||||
|
|
||||||
if (response.data.code == 0 && response.data.status == 'SUCCESS') {
|
|
||||||
// toast.success("笔记生成成功")
|
|
||||||
}
|
|
||||||
console.log('res', response)
|
|
||||||
// 成功提示
|
// 成功提示
|
||||||
|
|
||||||
return response.data
|
return await request.get('/task_status/' + task_id)
|
||||||
} catch (e) {
|
} catch (e) {
|
||||||
console.error('❌ 请求出错', e)
|
console.error('❌ 请求出错', e)
|
||||||
|
|
||||||
|
|||||||
@@ -1,6 +1,12 @@
|
|||||||
import { create } from 'zustand'
|
import { create } from 'zustand'
|
||||||
import { devtools } from 'zustand/middleware'
|
import { devtools } from 'zustand/middleware'
|
||||||
import { fetchModels, addModel, fetchEnableModels, fetchEnableModelById, deleteModelById } from '@/services/model.ts'
|
import {
|
||||||
|
fetchModels,
|
||||||
|
addModel,
|
||||||
|
fetchEnableModels,
|
||||||
|
fetchEnableModelById,
|
||||||
|
deleteModelById
|
||||||
|
} from '@/services/model'
|
||||||
|
|
||||||
interface IModel {
|
interface IModel {
|
||||||
id: string
|
id: string
|
||||||
@@ -11,81 +17,93 @@ interface IModel {
|
|||||||
root: string
|
root: string
|
||||||
}
|
}
|
||||||
|
|
||||||
|
interface IModelListItem {
|
||||||
|
id: string
|
||||||
|
provider_id: string
|
||||||
|
model_name: string
|
||||||
|
created_at?: string
|
||||||
|
}
|
||||||
|
|
||||||
interface ModelStore {
|
interface ModelStore {
|
||||||
models: IModel[]
|
models: IModel[]
|
||||||
modelList: []
|
modelList: IModelListItem[]
|
||||||
loading: boolean
|
loading: boolean
|
||||||
selectedModel: string
|
selectedModel: string
|
||||||
|
|
||||||
loadModels: (providerId: string) => Promise<void>
|
loadModels: (providerId: string) => Promise<void>
|
||||||
|
loadModelsById: (providerId: string) => Promise<IModelListItem[]>
|
||||||
loadEnabledModels: () => Promise<void>
|
loadEnabledModels: () => Promise<void>
|
||||||
loadModelsById : (providerId: string) => Promise<void>
|
|
||||||
addNewModel: (providerId: string, modelId: string) => Promise<void>
|
addNewModel: (providerId: string, modelId: string) => Promise<void>
|
||||||
setSelectedModel: (modelId: string) => void
|
|
||||||
deleteModel: (modelId: number) => Promise<void>
|
deleteModel: (modelId: number) => Promise<void>
|
||||||
|
setSelectedModel: (modelId: string) => void
|
||||||
clearModels: () => void
|
clearModels: () => void
|
||||||
}
|
}
|
||||||
|
|
||||||
export const useModelStore = create<ModelStore>()(
|
export const useModelStore = create<ModelStore>()(
|
||||||
devtools(set => ({
|
devtools((set) => ({
|
||||||
models: [],
|
models: [],
|
||||||
|
modelList: [],
|
||||||
loading: false,
|
loading: false,
|
||||||
selectedModel: '',
|
selectedModel: '',
|
||||||
modelList: [],
|
|
||||||
|
|
||||||
|
// 获取所有可用模型 (全局可用模型列表)
|
||||||
loadEnabledModels: async () => {
|
loadEnabledModels: async () => {
|
||||||
try {
|
try {
|
||||||
set({ loading: true })
|
set({ loading: true })
|
||||||
const res = await fetchEnableModels()
|
const list = await fetchEnableModels()
|
||||||
if (res.data.code === 0 && res.data.data.length > 0) {
|
set({ modelList: list })
|
||||||
set({ modelList: res.data.data })
|
|
||||||
} else {
|
|
||||||
set({ modelList: [] })
|
|
||||||
console.error('模型列表加载失败')
|
|
||||||
}
|
|
||||||
} catch (error) {
|
} catch (error) {
|
||||||
set({ modelList: [] })
|
set({ modelList: [] })
|
||||||
console.error('加载模型出错', error)
|
console.error('加载可用模型失败', error)
|
||||||
}
|
|
||||||
},
|
|
||||||
|
|
||||||
deleteModel: async (modelId: number) => {
|
|
||||||
await deleteModelById( modelId)
|
|
||||||
},
|
|
||||||
// 加载模型列表
|
|
||||||
loadModels: async (providerId: string) => {
|
|
||||||
try {
|
|
||||||
set({ loading: true })
|
|
||||||
const res = await fetchModels(providerId)
|
|
||||||
if (res.data.code === 0 && res.data.data.models.data.length > 0) {
|
|
||||||
set({ models: res.data.data.models.data })
|
|
||||||
} else if (res.data.code === 0 && res.data.data.models.length > 0) {
|
|
||||||
set({ models: res.data.data.models.data })
|
|
||||||
} else {
|
|
||||||
set({ models: [] })
|
|
||||||
console.error('模型列表加载失败')
|
|
||||||
}
|
|
||||||
} catch (error) {
|
|
||||||
set({ models: [] })
|
|
||||||
console.error('加载模型出错', error)
|
|
||||||
} finally {
|
} finally {
|
||||||
set({ loading: false })
|
set({ loading: false })
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
loadModelsById: async (providerId: string)=>{
|
|
||||||
const models = await fetchEnableModelById(providerId)
|
// 通过 provider 获取该供应商的模型列表
|
||||||
if (models.data.code === 0) {
|
loadModels: async (providerId: string) => {
|
||||||
console.log('模型列表加载成功:', models.data)
|
try {
|
||||||
return models.data.data
|
set({ loading: true })
|
||||||
|
const res = await fetchModels(providerId)
|
||||||
|
|
||||||
|
let models: IModel[] = []
|
||||||
|
|
||||||
|
// 兼容 SyncPage 分页对象与普通数组两种格式
|
||||||
|
if (Array.isArray(res.models)) {
|
||||||
|
models = res.models
|
||||||
|
} else if (res.models?.data && Array.isArray(res.models.data)) {
|
||||||
|
models = res.models.data
|
||||||
|
}
|
||||||
|
|
||||||
|
set({ models })
|
||||||
|
} catch (error) {
|
||||||
|
set({ models: [] })
|
||||||
|
console.error('加载模型列表失败', error)
|
||||||
|
} finally {
|
||||||
|
set({ loading: false })
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
// 新增模型
|
|
||||||
|
// 单独获取某个供应商下已启用模型
|
||||||
|
loadModelsById: async (providerId: string) => {
|
||||||
|
try {
|
||||||
|
const models = await fetchEnableModelById(providerId)
|
||||||
|
console.log('获取供应商模型成功:', models)
|
||||||
|
return models
|
||||||
|
} catch (error) {
|
||||||
|
console.error('加载供应商模型失败', error)
|
||||||
|
return []
|
||||||
|
}
|
||||||
|
},
|
||||||
|
|
||||||
|
// 新增模型逻辑
|
||||||
addNewModel: async (providerId: string, modelId: string) => {
|
addNewModel: async (providerId: string, modelId: string) => {
|
||||||
try {
|
try {
|
||||||
const res = await addModel({ provider_id: providerId, model_name: modelId })
|
const res = await addModel({ provider_id: providerId, model_name: modelId })
|
||||||
|
|
||||||
if (res.code === 0) {
|
if (res.code === 0) {
|
||||||
console.log('新增模型成功:', modelId)
|
console.log('新增模型成功:', modelId)
|
||||||
// ✅ 新增成功以后,前端直接追加一条到 models 列表
|
set((state) => ({
|
||||||
set(state => ({
|
|
||||||
models: [
|
models: [
|
||||||
...state.models,
|
...state.models,
|
||||||
{
|
{
|
||||||
@@ -99,17 +117,30 @@ export const useModelStore = create<ModelStore>()(
|
|||||||
],
|
],
|
||||||
}))
|
}))
|
||||||
} else {
|
} else {
|
||||||
console.error('新增模型失败')
|
console.error('新增模型失败', res.msg)
|
||||||
}
|
}
|
||||||
} catch (error) {
|
} catch (error) {
|
||||||
console.error('添加模型出错', error)
|
console.error('添加模型出错', error)
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
|
|
||||||
// 设置选中的模型
|
// 删除模型
|
||||||
setSelectedModel: modelId => set({ selectedModel: modelId }),
|
deleteModel: async (modelId: number) => {
|
||||||
|
try {
|
||||||
|
await deleteModelById(modelId)
|
||||||
|
// 删除后更新本地状态(可选)
|
||||||
|
set((state) => ({
|
||||||
|
models: state.models.filter((model) => model.id !== modelId.toString())
|
||||||
|
}))
|
||||||
|
} catch (error) {
|
||||||
|
console.error('删除模型失败', error)
|
||||||
|
}
|
||||||
|
},
|
||||||
|
|
||||||
// 清空
|
// 切换选中模型
|
||||||
clearModels: () => set({ models: [], selectedModel: '' }),
|
setSelectedModel: (modelId: string) => set({ selectedModel: modelId }),
|
||||||
|
|
||||||
|
// 清空
|
||||||
|
clearModels: () => set({ models: [], selectedModel: '', modelList: [] }),
|
||||||
}))
|
}))
|
||||||
)
|
)
|
||||||
@@ -1,5 +1,5 @@
|
|||||||
import { create } from 'zustand'
|
import { create } from 'zustand'
|
||||||
import { IProvider } from '@/types'
|
import { IProvider, IResponse } from '@/types'
|
||||||
import {
|
import {
|
||||||
addProvider,
|
addProvider,
|
||||||
getProviderById,
|
getProviderById,
|
||||||
@@ -38,10 +38,9 @@ export const useProviderStore = create<ProviderStore>((set, get) => ({
|
|||||||
// 设置整个 provider 列表
|
// 设置整个 provider 列表
|
||||||
setAllProviders: providers => set({ provider: providers }),
|
setAllProviders: providers => set({ provider: providers }),
|
||||||
loadProviderById: async (id: string) => {
|
loadProviderById: async (id: string) => {
|
||||||
const res = await getProviderById(id)
|
const res:IResponse<IProvider> = await getProviderById(id)
|
||||||
if (res.data.code === 0) {
|
|
||||||
const item = res.data.data
|
const item = res
|
||||||
console.log('Provider ', item)
|
|
||||||
return {
|
return {
|
||||||
id: item.id,
|
id: item.id,
|
||||||
name: item.name,
|
name: item.name,
|
||||||
@@ -51,9 +50,7 @@ export const useProviderStore = create<ProviderStore>((set, get) => ({
|
|||||||
type: item.type,
|
type: item.type,
|
||||||
enabled: item.enabled,
|
enabled: item.enabled,
|
||||||
}
|
}
|
||||||
} else {
|
|
||||||
console.log('Provider not found')
|
|
||||||
}
|
|
||||||
},
|
},
|
||||||
addNewProvider: async (provider: IProvider) => {
|
addNewProvider: async (provider: IProvider) => {
|
||||||
const payload = {
|
const payload = {
|
||||||
@@ -96,16 +93,18 @@ export const useProviderStore = create<ProviderStore>((set, get) => ({
|
|||||||
getProviderList: () => get().provider,
|
getProviderList: () => get().provider,
|
||||||
fetchProviderList: async () => {
|
fetchProviderList: async () => {
|
||||||
try {
|
try {
|
||||||
const res = await getProviderList()
|
const res = await getProviderList()
|
||||||
if (res.data.code === 0) {
|
|
||||||
set({
|
set({
|
||||||
provider: res.data.data.map(
|
provider: res.map(
|
||||||
(item: {
|
(item: {
|
||||||
id: string
|
id: string
|
||||||
name: string
|
name: string
|
||||||
logo: string
|
logo: string
|
||||||
api_key: string
|
api_key: string
|
||||||
base_url: string
|
base_url: string
|
||||||
|
type: string
|
||||||
|
enabled: number
|
||||||
}) => {
|
}) => {
|
||||||
return {
|
return {
|
||||||
id: item.id,
|
id: item.id,
|
||||||
@@ -119,7 +118,6 @@ export const useProviderStore = create<ProviderStore>((set, get) => ({
|
|||||||
}
|
}
|
||||||
),
|
),
|
||||||
})
|
})
|
||||||
}
|
|
||||||
} catch (error) {
|
} catch (error) {
|
||||||
console.error('Error fetching provider list:', error)
|
console.error('Error fetching provider list:', error)
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -2,6 +2,7 @@ import { create } from 'zustand'
|
|||||||
import { persist } from 'zustand/middleware'
|
import { persist } from 'zustand/middleware'
|
||||||
import { delete_task, generateNote } from '@/services/note.ts'
|
import { delete_task, generateNote } from '@/services/note.ts'
|
||||||
import { v4 as uuidv4 } from 'uuid'
|
import { v4 as uuidv4 } from 'uuid'
|
||||||
|
import toast from 'react-hot-toast'
|
||||||
|
|
||||||
|
|
||||||
export type TaskStatus = 'PENDING' | 'RUNNING' | 'SUCCESS' | 'FAILD'
|
export type TaskStatus = 'PENDING' | 'RUNNING' | 'SUCCESS' | 'FAILD'
|
||||||
@@ -157,14 +158,19 @@ export const useTaskStore = create<TaskStore>()(
|
|||||||
return get().tasks.find(task => task.id === currentTaskId) || null
|
return get().tasks.find(task => task.id === currentTaskId) || null
|
||||||
},
|
},
|
||||||
retryTask: async (id: string, payload?: any) => {
|
retryTask: async (id: string, payload?: any) => {
|
||||||
|
|
||||||
|
if (!id){
|
||||||
|
toast.error('任务不存在')
|
||||||
|
return
|
||||||
|
}
|
||||||
const task = get().tasks.find(task => task.id === id)
|
const task = get().tasks.find(task => task.id === id)
|
||||||
|
console.log('retry',task)
|
||||||
if (!task) return
|
if (!task) return
|
||||||
|
|
||||||
const newFormData = payload || task.formData
|
const newFormData = payload || task.formData
|
||||||
|
|
||||||
await generateNote({
|
await generateNote({
|
||||||
task_id: id,
|
|
||||||
...newFormData,
|
...newFormData,
|
||||||
|
task_id: id,
|
||||||
})
|
})
|
||||||
|
|
||||||
set(state => ({
|
set(state => ({
|
||||||
|
|||||||
5
BillNote_frontend/src/types/index.d.ts
vendored
5
BillNote_frontend/src/types/index.d.ts
vendored
@@ -7,3 +7,8 @@ export interface IProvider {
|
|||||||
baseUrl: string
|
baseUrl: string
|
||||||
enabled: number
|
enabled: number
|
||||||
}
|
}
|
||||||
|
export interface IResponse<T> {
|
||||||
|
code: number
|
||||||
|
data:T
|
||||||
|
msg: string
|
||||||
|
}
|
||||||
@@ -1,27 +1,58 @@
|
|||||||
import axios from 'axios'
|
import axios, { AxiosInstance, AxiosResponse } from 'axios';
|
||||||
const request = axios.create({
|
import toast from 'react-hot-toast'
|
||||||
baseURL: '/api',
|
|
||||||
timeout: 10000,
|
|
||||||
})
|
|
||||||
function handleErrorResponse(response: any) {
|
|
||||||
if (!response) return '请求失败,请检查网络连接'
|
|
||||||
if (typeof response.code !== 'number') return '系统异常'
|
|
||||||
|
|
||||||
// 错误码判断
|
// 统一响应类型
|
||||||
switch (response.code) {
|
export interface IResponse<T = any> {
|
||||||
case 1001:
|
code: number;
|
||||||
return response.msg || '下载失败,请检查视频链接'
|
msg: string;
|
||||||
case 1002:
|
data: T;
|
||||||
return response.msg || '转写失败,请稍后重试'
|
|
||||||
case 1003:
|
|
||||||
return response.msg || '总结失败,可能是模型服务异常'
|
|
||||||
case 2001:
|
|
||||||
case 2002:
|
|
||||||
return Array.isArray(response.data)
|
|
||||||
? response.data.map(e => `${e.field}: ${e.error}`).join('\n')
|
|
||||||
: response.msg || '参数错误'
|
|
||||||
default:
|
|
||||||
return response.msg || '系统异常'
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// 模拟一个消息提示函数 (实际项目中会使用UI库的组件,如 Ant Design 的 message 或 Element UI 的 ElMessage)
|
||||||
|
// This function simulates a message display (in real projects, you'd use a UI library's component)
|
||||||
|
|
||||||
|
|
||||||
|
// 创建实例
|
||||||
|
const request: AxiosInstance = axios.create({
|
||||||
|
baseURL: '/api', // 请确保你的开发服务器代理设置正确
|
||||||
|
timeout: 10000,
|
||||||
|
});
|
||||||
|
|
||||||
|
// 响应拦截器
|
||||||
|
request.interceptors.response.use(
|
||||||
|
(response: AxiosResponse<IResponse>) => {
|
||||||
|
const res = response.data;
|
||||||
|
if (res.code === 0) {
|
||||||
|
// 业务成功,可以根据需要显示成功消息,或者不显示(如果操作本身就是可见的)
|
||||||
|
// showMessage('success', res.msg || '操作成功'); // 如果需要显示成功消息
|
||||||
|
return res.data; // 返回data部分,简化后续业务代码
|
||||||
|
} else {
|
||||||
|
// 业务错误,统一显示后端返回的错误消息
|
||||||
|
// Business error, uniformly display the error message returned from the backend
|
||||||
|
toast.error(res.msg || '操作失败,请稍后再试');
|
||||||
|
return Promise.reject(res); // 拒绝Promise,让业务代码可以捕获并处理
|
||||||
|
}
|
||||||
|
},
|
||||||
|
(error) => {
|
||||||
|
// 网络/服务器错误
|
||||||
|
const res = error?.response?.data as IResponse | undefined;
|
||||||
|
if (res) {
|
||||||
|
// 如果后端有返回错误信息,则显示后端信息
|
||||||
|
// If the backend returns an error message, display it
|
||||||
|
|
||||||
|
toast.error(res.msg || '服务器错误,请稍后再试');
|
||||||
|
return Promise.reject(res);
|
||||||
|
} else {
|
||||||
|
// 没有响应数据(如网络中断),显示通用网络错误
|
||||||
|
// No response data (e.g., network disconnected), display generic network error
|
||||||
|
toast.error( '请求失败,请检查网络连接或稍后再试')
|
||||||
|
return Promise.reject({
|
||||||
|
code: -1,
|
||||||
|
msg: '请求失败,请检查网络连接',
|
||||||
|
data: null
|
||||||
|
} as IResponse);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
);
|
||||||
|
|
||||||
export default request
|
export default request
|
||||||
|
|||||||
@@ -3,7 +3,7 @@
|
|||||||
<p align="center">
|
<p align="center">
|
||||||
<img src="./doc/icon.svg" alt="BiliNote Banner" width="50" height="50" />
|
<img src="./doc/icon.svg" alt="BiliNote Banner" width="50" height="50" />
|
||||||
</p>
|
</p>
|
||||||
<h1 align="center" > BiliNote v1.7.3</h1>
|
<h1 align="center" > BiliNote v1.7.4</h1>
|
||||||
</div>
|
</div>
|
||||||
|
|
||||||
<p align="center"><i>AI 视频笔记生成工具 让 AI 为你的视频做笔记</i></p>
|
<p align="center"><i>AI 视频笔记生成工具 让 AI 为你的视频做笔记</i></p>
|
||||||
|
|||||||
@@ -1,11 +1,14 @@
|
|||||||
from fastapi import FastAPI
|
from fastapi import FastAPI
|
||||||
|
|
||||||
from .routers import note, provider, model, config
|
from .routers import note, provider, model, config
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
def create_app() -> FastAPI:
|
def create_app() -> FastAPI:
|
||||||
app = FastAPI(title="BiliNote")
|
app = FastAPI(title="BiliNote")
|
||||||
app.include_router(note.router, prefix="/api")
|
app.include_router(note.router, prefix="/api")
|
||||||
app.include_router(provider.router, prefix="/api")
|
app.include_router(provider.router, prefix="/api")
|
||||||
app.include_router(model.router,prefix="/api")
|
app.include_router(model.router,prefix="/api")
|
||||||
app.include_router(config.router, prefix="/api")
|
app.include_router(config.router, prefix="/api")
|
||||||
|
|
||||||
return app
|
return app
|
||||||
|
|||||||
@@ -1,38 +0,0 @@
|
|||||||
# app/core/exception_handlers.py
|
|
||||||
from fastapi import Request, HTTPException
|
|
||||||
from fastapi.exceptions import RequestValidationError
|
|
||||||
from fastapi.responses import JSONResponse
|
|
||||||
|
|
||||||
from app.utils.logger import get_logger
|
|
||||||
from app.utils.response import ResponseWrapper
|
|
||||||
from app.utils.status_code import StatusCode
|
|
||||||
logger = get_logger(__name__)
|
|
||||||
|
|
||||||
def register_exception_handlers(app):
|
|
||||||
@app.exception_handler(RequestValidationError)
|
|
||||||
async def validation_exception_handler(request: Request, exc: RequestValidationError):
|
|
||||||
errors = []
|
|
||||||
for err in exc.errors():
|
|
||||||
loc = err.get("loc", [])
|
|
||||||
field = loc[-1] if loc else "body"
|
|
||||||
msg = err.get("msg", "参数不合法")
|
|
||||||
errors.append({"field": field, "error": msg})
|
|
||||||
return JSONResponse(
|
|
||||||
status_code=400,
|
|
||||||
content=ResponseWrapper.error(msg="参数验证失败", code=StatusCode.PARAM_ERROR, data=errors)
|
|
||||||
)
|
|
||||||
|
|
||||||
@app.exception_handler(HTTPException)
|
|
||||||
async def http_exception_handler(request: Request, exc: HTTPException):
|
|
||||||
return JSONResponse(
|
|
||||||
status_code=exc.status_code,
|
|
||||||
content=ResponseWrapper.error(msg=str(exc.detail), code=StatusCode.FAIL)
|
|
||||||
)
|
|
||||||
|
|
||||||
@app.exception_handler(Exception)
|
|
||||||
async def global_exception_handler(request: Request, exc: Exception):
|
|
||||||
logger.exception(f"服务器内部错误: {exc}")
|
|
||||||
return JSONResponse(
|
|
||||||
status_code=500,
|
|
||||||
content=ResponseWrapper.error(msg="服务器内部错误", code=StatusCode.FAIL, data=str(exc))
|
|
||||||
)
|
|
||||||
@@ -136,7 +136,8 @@ def get_provider_by_name(name: str):
|
|||||||
if row is None:
|
if row is None:
|
||||||
logger.info(f"Provider not found: {name}")
|
logger.info(f"Provider not found: {name}")
|
||||||
return None
|
return None
|
||||||
logger.info(f"Provider found: {row}")
|
logger.info(f"Provider found: {row[0]}")
|
||||||
|
|
||||||
return row
|
return row
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"Failed to get provider by name: {e}")
|
logger.error(f"Failed to get provider by name: {e}")
|
||||||
@@ -155,7 +156,7 @@ def get_provider_by_id(id: int):
|
|||||||
if row is None:
|
if row is None:
|
||||||
logger.info(f"Provider not found: {id}")
|
logger.info(f"Provider not found: {id}")
|
||||||
return None
|
return None
|
||||||
logger.info(f"Provider found: {row}")
|
logger.info(f"Provider found: {row[0]}")
|
||||||
return row
|
return row
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"Failed to get provider by id: {e}")
|
logger.error(f"Failed to get provider by id: {e}")
|
||||||
@@ -173,7 +174,7 @@ def get_all_providers():
|
|||||||
if rows is None:
|
if rows is None:
|
||||||
logger.info("No providers found")
|
logger.info("No providers found")
|
||||||
return None
|
return None
|
||||||
logger.info(f"Providers found: {rows}")
|
logger.info(f"Providers found total {len(rows) }")
|
||||||
return rows
|
return rows
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"Failed to get all providers: {e}")
|
logger.error(f"Failed to get all providers: {e}")
|
||||||
|
|||||||
@@ -145,53 +145,59 @@ class DouyinDownloader(Downloader):
|
|||||||
return ""
|
return ""
|
||||||
|
|
||||||
def gen_real_msToken(self) -> str:
|
def gen_real_msToken(self) -> str:
|
||||||
payload = json.dumps(
|
try:
|
||||||
{
|
payload = json.dumps(
|
||||||
"magic": self.ms_token_config["magic"],
|
{
|
||||||
"version": self.ms_token_config["version"],
|
"magic": self.ms_token_config["magic"],
|
||||||
"dataType": self.ms_token_config["dataType"],
|
"version": self.ms_token_config["version"],
|
||||||
"strData": self.ms_token_config["strData"],
|
"dataType": self.ms_token_config["dataType"],
|
||||||
"tspFromClient": get_timestamp(),
|
"strData": self.ms_token_config["strData"],
|
||||||
|
"tspFromClient": get_timestamp(),
|
||||||
|
}
|
||||||
|
)
|
||||||
|
headers = {
|
||||||
|
"User-Agent": self.headers_config["User-Agent"],
|
||||||
|
"Content-Type": "application/json",
|
||||||
}
|
}
|
||||||
)
|
transport = httpx.HTTPTransport(retries=5)
|
||||||
headers = {
|
with httpx.Client(transport=transport) as client:
|
||||||
"User-Agent": self.headers_config["User-Agent"],
|
try:
|
||||||
"Content-Type": "application/json",
|
response = client.post(
|
||||||
}
|
self.ms_token_config["url"], content=payload, headers=headers
|
||||||
transport = httpx.HTTPTransport(retries=5)
|
)
|
||||||
with httpx.Client(transport=transport) as client:
|
response.raise_for_status()
|
||||||
try:
|
|
||||||
response = client.post(
|
|
||||||
self.ms_token_config["url"], content=payload, headers=headers
|
|
||||||
)
|
|
||||||
response.raise_for_status()
|
|
||||||
|
|
||||||
msToken = str(httpx.Cookies(response.cookies).get("msToken"))
|
msToken = str(httpx.Cookies(response.cookies).get("msToken"))
|
||||||
if len(msToken) not in [120, 128]:
|
if len(msToken) not in [120, 128]:
|
||||||
raise ValueError("响应内容:{0}, Douyin msToken API 的响应内容不符合要求。".format(msToken))
|
raise ValueError("响应内容:{0}, Douyin msToken API 的响应内容不符合要求。".format(msToken))
|
||||||
|
|
||||||
return msToken
|
return msToken
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
raise ValueError("Douyin msToken API 请求失败:{0}".format(e))
|
raise ValueError("Douyin msToken API 请求失败:{0}".format(e))
|
||||||
|
except Exception as e:
|
||||||
|
raise ValueError("Douyin msToken API{0}".format(e))
|
||||||
|
|
||||||
def fetch_video_info(self, video_url: str) -> json:
|
def fetch_video_info(self, video_url: str) -> json:
|
||||||
aweme_id = self.extract_video_id(video_url)
|
|
||||||
kwargs = self.headers_config
|
|
||||||
print("kwargs:", kwargs)
|
|
||||||
base_params = BaseRequestModel().model_dump()
|
|
||||||
base_params["msToken"] = self.gen_real_msToken()
|
|
||||||
base_params["aweme_id"] = aweme_id
|
|
||||||
bogus = ABogus()
|
|
||||||
ab_value = bogus.get_value(base_params)
|
|
||||||
a_bogus = quote(ab_value, safe='')
|
|
||||||
print(base_params)
|
|
||||||
query_str = urlencode(base_params)
|
|
||||||
full_url = f"{DOUYIN_DOMAIN}/aweme/v1/web/aweme/detail/?{query_str}&a_bogus={a_bogus}"
|
|
||||||
|
|
||||||
print("Request URL:", full_url)
|
|
||||||
|
|
||||||
try:
|
try:
|
||||||
|
|
||||||
|
aweme_id = self.extract_video_id(video_url)
|
||||||
|
kwargs = self.headers_config
|
||||||
|
print("@kwargs:", kwargs)
|
||||||
|
base_params = BaseRequestModel().model_dump()
|
||||||
|
base_params["msToken"] = self.gen_real_msToken()
|
||||||
|
|
||||||
|
base_params["aweme_id"] = aweme_id
|
||||||
|
bogus = ABogus()
|
||||||
|
ab_value = bogus.get_value(base_params)
|
||||||
|
a_bogus = quote(ab_value, safe='')
|
||||||
|
print("@a_bogus:", a_bogus)
|
||||||
|
print(base_params)
|
||||||
|
query_str = urlencode(base_params)
|
||||||
|
full_url = f"{DOUYIN_DOMAIN}/aweme/v1/web/aweme/detail/?{query_str}&a_bogus={a_bogus}"
|
||||||
|
|
||||||
|
print("Request URL:", full_url)
|
||||||
|
|
||||||
|
|
||||||
response = requests.get(full_url, headers=kwargs)
|
response = requests.get(full_url, headers=kwargs)
|
||||||
|
|
||||||
print("Response JSON:", response.content)
|
print("Response JSON:", response.content)
|
||||||
@@ -208,46 +214,49 @@ class DouyinDownloader(Downloader):
|
|||||||
quality: DownloadQuality = "fast",
|
quality: DownloadQuality = "fast",
|
||||||
need_video: Optional[bool] = False
|
need_video: Optional[bool] = False
|
||||||
) -> AudioDownloadResult:
|
) -> AudioDownloadResult:
|
||||||
print(
|
try:
|
||||||
f"正在下载视频: {video_url},保存路径: {output_dir},质量: {quality}"
|
print(
|
||||||
)
|
f"正在下载视频: {video_url},保存路径: {output_dir},质量: {quality}"
|
||||||
if output_dir is None:
|
)
|
||||||
output_dir = get_data_dir()
|
if output_dir is None:
|
||||||
if not output_dir:
|
output_dir = get_data_dir()
|
||||||
output_dir = self.cache_data
|
if not output_dir:
|
||||||
os.makedirs(output_dir, exist_ok=True)
|
output_dir = self.cache_data
|
||||||
|
os.makedirs(output_dir, exist_ok=True)
|
||||||
|
|
||||||
output_path = os.path.join(output_dir, "%(id)s.%(ext)s")
|
output_path = os.path.join(output_dir, "%(id)s.%(ext)s")
|
||||||
|
|
||||||
video_data = self.fetch_video_info(video_url)
|
video_data = self.fetch_video_info(video_url)
|
||||||
output_path = output_path % {
|
output_path = output_path % {
|
||||||
"id": video_data['aweme_detail']['aweme_id'],
|
"id": video_data['aweme_detail']['aweme_id'],
|
||||||
"ext": "mp3",
|
"ext": "mp3",
|
||||||
}
|
}
|
||||||
url = video_data['aweme_detail']['music']['play_url']['uri']
|
url = video_data['aweme_detail']['music']['play_url']['uri']
|
||||||
# 下载音频
|
# 下载音频
|
||||||
audio_data = requests.get(url)
|
audio_data = requests.get(url)
|
||||||
with open(output_path, 'wb') as f:
|
with open(output_path, 'wb') as f:
|
||||||
f.write(audio_data.content)
|
f.write(audio_data.content)
|
||||||
print(url)
|
print(url)
|
||||||
tags = []
|
tags = []
|
||||||
for tag in video_data['aweme_detail']['video_tag']:
|
for tag in video_data['aweme_detail']['video_tag']:
|
||||||
if tag['tag_name']:
|
if tag['tag_name']:
|
||||||
tags.append(tag['tag_name'])
|
tags.append(tag['tag_name'])
|
||||||
|
|
||||||
return AudioDownloadResult(
|
return AudioDownloadResult(
|
||||||
file_path=output_path,
|
file_path=output_path,
|
||||||
title=video_data['aweme_detail']['item_title'],
|
title=video_data['aweme_detail']['item_title'],
|
||||||
duration=video_data['aweme_detail']['video']['duration'],
|
duration=video_data['aweme_detail']['video']['duration'],
|
||||||
cover_url=video_data['aweme_detail']['video']['cover_original_scale']['url_list'][0] if
|
cover_url=video_data['aweme_detail']['video']['cover_original_scale']['url_list'][0] if
|
||||||
video_data['aweme_detail']['video']['cover'] else video_data['video']['big_thumbs']['img_url'],
|
video_data['aweme_detail']['video']['cover'] else video_data['video']['big_thumbs']['img_url'],
|
||||||
platform="douyin",
|
platform="douyin",
|
||||||
video_id=video_data['aweme_detail']['aweme_id'],
|
video_id=video_data['aweme_detail']['aweme_id'],
|
||||||
raw_info={
|
raw_info={
|
||||||
'tags': video_data['aweme_detail']['caption'] + ''.join(tags),
|
'tags': video_data['aweme_detail']['caption'] + ''.join(tags),
|
||||||
},
|
},
|
||||||
video_path=None # ❗音频下载不包含视频路径
|
video_path=None # ❗音频下载不包含视频路径
|
||||||
)
|
)
|
||||||
|
except Exception as e:
|
||||||
|
raise e
|
||||||
|
|
||||||
def download_video(self, video_url: str, output_dir: Union[str, None] = None) -> str:
|
def download_video(self, video_url: str, output_dir: Union[str, None] = None) -> str:
|
||||||
|
|
||||||
|
|||||||
21
backend/app/enmus/exception.py
Normal file
21
backend/app/enmus/exception.py
Normal file
@@ -0,0 +1,21 @@
|
|||||||
|
import enum
|
||||||
|
|
||||||
|
|
||||||
|
class ProviderErrorEnum(enum.Enum):
|
||||||
|
CONNECTION_TEST_FAILED = (200101, "供应商连接测试失败")
|
||||||
|
SAVE_FAILED = (200102, "供应商保存失败")
|
||||||
|
CREATE_FAILED = (200103, "供应商创建失败")
|
||||||
|
NOT_FOUND = (200104, "供应商不存在/未保存")
|
||||||
|
WRONG_PARAMETER = (200105, "API / API 地址不正确")
|
||||||
|
UNKNOW_ERROR = (200106, "未知错误")
|
||||||
|
|
||||||
|
def __init__(self, code, message):
|
||||||
|
self.code = code
|
||||||
|
self.message = message
|
||||||
|
|
||||||
|
class NoteErrorEnum(enum.Enum):
|
||||||
|
PLATFORM_NOT_SUPPORTED = (300101 ,"选择的平台不受支持")
|
||||||
|
|
||||||
|
def __init__(self, code, message):
|
||||||
|
self.code = code
|
||||||
|
self.message = message
|
||||||
0
backend/app/exceptions/__init__.py
Normal file
0
backend/app/exceptions/__init__.py
Normal file
6
backend/app/exceptions/biz_exception.py
Normal file
6
backend/app/exceptions/biz_exception.py
Normal file
@@ -0,0 +1,6 @@
|
|||||||
|
# exceptions/biz_exception.py
|
||||||
|
|
||||||
|
class BizException(Exception):
|
||||||
|
def __init__(self, code: int, message: str = "业务异常"):
|
||||||
|
self.code = code
|
||||||
|
self.message = message
|
||||||
33
backend/app/exceptions/exception_handlers.py
Normal file
33
backend/app/exceptions/exception_handlers.py
Normal file
@@ -0,0 +1,33 @@
|
|||||||
|
# middlewares/exception_handler.py
|
||||||
|
|
||||||
|
from fastapi import Request
|
||||||
|
from fastapi import FastAPI
|
||||||
|
|
||||||
|
from app.enmus.exception import NoteErrorEnum
|
||||||
|
from app.exceptions.biz_exception import BizException
|
||||||
|
from app.exceptions.note import NoteError
|
||||||
|
from app.exceptions.provider import ProviderError
|
||||||
|
from app.utils.logger import get_logger
|
||||||
|
from app.utils.response import ResponseWrapper as R
|
||||||
|
import traceback
|
||||||
|
|
||||||
|
logger = get_logger(__name__)
|
||||||
|
|
||||||
|
def register_exception_handlers(app: FastAPI):
|
||||||
|
@app.exception_handler(BizException)
|
||||||
|
async def biz_exception_handler(request: Request, exc: BizException):
|
||||||
|
logger.error(f"BizException: {exc.code} - {exc.message}")
|
||||||
|
return R.error(code=exc.code, msg=str(exc.message))
|
||||||
|
@app.exception_handler(NoteError)
|
||||||
|
async def note_exception_handler(request: Request, exc: NoteError):
|
||||||
|
logger.error(f"NoteError: {exc.code} - {exc.message}")
|
||||||
|
return R.error(code=exc.code, msg=str(exc.message))
|
||||||
|
@app.exception_handler(ProviderError)
|
||||||
|
async def provider_exception_handler(request: Request, exc: ProviderError):
|
||||||
|
logger.error(f"供应商模块错误: {exc.code} - {exc.message}")
|
||||||
|
return R.error(code=exc.code, msg=str(exc.message))
|
||||||
|
|
||||||
|
@app.exception_handler(Exception)
|
||||||
|
async def general_exception_handler(request: Request, exc: Exception):
|
||||||
|
logger.error(f"系统异常: {str(exc)}\n{traceback.format_exc()}")
|
||||||
|
return R.error(code=500000, msg="系统异常")
|
||||||
9
backend/app/exceptions/note.py
Normal file
9
backend/app/exceptions/note.py
Normal file
@@ -0,0 +1,9 @@
|
|||||||
|
# exceptions.py
|
||||||
|
from app.enmus.exception import ProviderErrorEnum
|
||||||
|
|
||||||
|
|
||||||
|
class NoteError(Exception):
|
||||||
|
def __init__(self, message: str,code: ProviderErrorEnum) -> None:
|
||||||
|
super().__init__(message)
|
||||||
|
self.code=code
|
||||||
|
self.message = message
|
||||||
@@ -1,5 +1,12 @@
|
|||||||
# exceptions.py
|
# exceptions.py
|
||||||
class ConnectionTestError(Exception):
|
from app.enmus.exception import ProviderErrorEnum
|
||||||
def __init__(self, message: str):
|
|
||||||
|
|
||||||
|
class ProviderError(Exception):
|
||||||
|
def __init__(self, message: str,code: ProviderErrorEnum) -> None:
|
||||||
super().__init__(message)
|
super().__init__(message)
|
||||||
self.message = message
|
self.code=code
|
||||||
|
self.message = message
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -2,6 +2,9 @@ from typing import Optional, Union
|
|||||||
|
|
||||||
from openai import OpenAI
|
from openai import OpenAI
|
||||||
|
|
||||||
|
from app.utils.logger import get_logger
|
||||||
|
|
||||||
|
logging= get_logger(__name__)
|
||||||
class OpenAICompatibleProvider:
|
class OpenAICompatibleProvider:
|
||||||
def __init__(self, api_key: str, base_url: str, model: Union[str, None]=None):
|
def __init__(self, api_key: str, base_url: str, model: Union[str, None]=None):
|
||||||
self.client = OpenAI(api_key=api_key, base_url=base_url)
|
self.client = OpenAI(api_key=api_key, base_url=base_url)
|
||||||
@@ -13,11 +16,16 @@ class OpenAICompatibleProvider:
|
|||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def test_connection(api_key: str, base_url: str) -> bool:
|
def test_connection(api_key: str, base_url: str) -> bool:
|
||||||
print(api_key)
|
|
||||||
try:
|
try:
|
||||||
client = OpenAI(api_key=api_key, base_url=base_url)
|
client = OpenAI(api_key=api_key, base_url=base_url)
|
||||||
client.models.list()
|
model = client.models.list()
|
||||||
|
# for segment in model:
|
||||||
|
# print(segment)
|
||||||
|
# print(model)
|
||||||
|
logging.info("连通性测试成功")
|
||||||
return True
|
return True
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
print(f"Error connecting to OpenAI API: {e}")
|
logging.info(f"连通性测试失败:{e}")
|
||||||
|
|
||||||
|
# print(f"Error connecting to OpenAI API: {e}")
|
||||||
return False
|
return False
|
||||||
@@ -27,4 +27,6 @@ def get_cookie(platform: str):
|
|||||||
@router.post("/update_downloader_cookie")
|
@router.post("/update_downloader_cookie")
|
||||||
def update_cookie(data: CookieUpdateRequest):
|
def update_cookie(data: CookieUpdateRequest):
|
||||||
cookie_manager.set(data.platform, data.cookie)
|
cookie_manager.set(data.platform, data.cookie)
|
||||||
return {"message": "Cookie updated successfully"}
|
return R.success(
|
||||||
|
|
||||||
|
)
|
||||||
|
|||||||
@@ -31,10 +31,9 @@ def delete_model(model_id: int):
|
|||||||
return R.error(f"删除模型失败: {e}")
|
return R.error(f"删除模型失败: {e}")
|
||||||
@router.get("/model_list/{provider_id}")
|
@router.get("/model_list/{provider_id}")
|
||||||
def model_list(provider_id):
|
def model_list(provider_id):
|
||||||
try:
|
|
||||||
return R.success(modelService.get_all_models_by_id(provider_id))
|
return R.success(modelService.get_all_models_by_id(provider_id))
|
||||||
except Exception as e:
|
|
||||||
return R.error(e)
|
|
||||||
|
|
||||||
@router.post("/models")
|
@router.post("/models")
|
||||||
def create_model(data: CreateModelRequest):
|
def create_model(data: CreateModelRequest):
|
||||||
|
|||||||
@@ -2,6 +2,7 @@
|
|||||||
import json
|
import json
|
||||||
import os
|
import os
|
||||||
import uuid
|
import uuid
|
||||||
|
from pathlib import Path
|
||||||
from typing import Optional
|
from typing import Optional
|
||||||
from urllib.parse import urlparse
|
from urllib.parse import urlparse
|
||||||
|
|
||||||
@@ -10,7 +11,9 @@ from pydantic import BaseModel, validator, field_validator
|
|||||||
from dataclasses import asdict
|
from dataclasses import asdict
|
||||||
|
|
||||||
from app.db.video_task_dao import get_task_by_video
|
from app.db.video_task_dao import get_task_by_video
|
||||||
|
from app.enmus.exception import NoteErrorEnum
|
||||||
from app.enmus.note_enums import DownloadQuality
|
from app.enmus.note_enums import DownloadQuality
|
||||||
|
from app.exceptions.note import NoteError
|
||||||
from app.services.note import NoteGenerator, logger
|
from app.services.note import NoteGenerator, logger
|
||||||
from app.utils.response import ResponseWrapper as R
|
from app.utils.response import ResponseWrapper as R
|
||||||
from app.utils.url_parser import extract_video_id
|
from app.utils.url_parser import extract_video_id
|
||||||
@@ -54,12 +57,13 @@ class VideoRequest(BaseModel):
|
|||||||
if parsed.scheme in ("http", "https"):
|
if parsed.scheme in ("http", "https"):
|
||||||
# 是网络链接,继续用原有平台校验
|
# 是网络链接,继续用原有平台校验
|
||||||
if not is_supported_video_url(url):
|
if not is_supported_video_url(url):
|
||||||
raise ValueError("暂不支持该视频平台或链接格式无效")
|
raise NoteError(code=NoteErrorEnum.PLATFORM_NOT_SUPPORTED.code,
|
||||||
|
message=NoteErrorEnum.PLATFORM_NOT_SUPPORTED.message)
|
||||||
|
|
||||||
return v
|
return v
|
||||||
|
|
||||||
|
|
||||||
NOTE_OUTPUT_DIR = "note_results"
|
NOTE_OUTPUT_DIR = os.getenv("NOTE_OUTPUT_DIR", "note_results")
|
||||||
UPLOAD_DIR = "uploads"
|
UPLOAD_DIR = "uploads"
|
||||||
|
|
||||||
|
|
||||||
@@ -74,30 +78,32 @@ def run_note_task(task_id: str, video_url: str, platform: str, quality: Download
|
|||||||
_format: list = None, style: str = None, extras: str = None, video_understanding: bool = False,
|
_format: list = None, style: str = None, extras: str = None, video_understanding: bool = False,
|
||||||
video_interval=0, grid_size=[]
|
video_interval=0, grid_size=[]
|
||||||
):
|
):
|
||||||
try:
|
|
||||||
if not model_name or not provider_id:
|
|
||||||
raise HTTPException(status_code=400, detail="请选择模型和提供者")
|
|
||||||
|
|
||||||
note = NoteGenerator().generate(
|
if not model_name or not provider_id:
|
||||||
video_url=video_url,
|
raise HTTPException(status_code=400, detail="请选择模型和提供者")
|
||||||
platform=platform,
|
|
||||||
quality=quality,
|
note = NoteGenerator().generate(
|
||||||
task_id=task_id,
|
video_url=video_url,
|
||||||
model_name=model_name,
|
platform=platform,
|
||||||
provider_id=provider_id,
|
quality=quality,
|
||||||
link=link,
|
task_id=task_id,
|
||||||
_format=_format,
|
model_name=model_name,
|
||||||
style=style,
|
provider_id=provider_id,
|
||||||
extras=extras,
|
link=link,
|
||||||
screenshot=screenshot
|
_format=_format,
|
||||||
, video_understanding=video_understanding,
|
style=style,
|
||||||
video_interval=video_interval,
|
extras=extras,
|
||||||
grid_size=grid_size
|
screenshot=screenshot
|
||||||
)
|
, video_understanding=video_understanding,
|
||||||
logger.info(f"Note generated: {task_id}")
|
video_interval=video_interval,
|
||||||
save_note_to_file(task_id, note)
|
grid_size=grid_size
|
||||||
except Exception as e:
|
)
|
||||||
save_note_to_file(task_id, {"error": str(e)})
|
logger.info(f"Note generated: {task_id}")
|
||||||
|
if not note or not note.markdown:
|
||||||
|
logger.warning(f"任务 {task_id} 执行失败,跳过保存")
|
||||||
|
return
|
||||||
|
save_note_to_file(task_id, note)
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
@router.post('/delete_task')
|
@router.post('/delete_task')
|
||||||
@@ -135,7 +141,6 @@ def generate_note(data: VideoRequest, background_tasks: BackgroundTasks):
|
|||||||
# msg='笔记已生成,请勿重复发起',
|
# msg='笔记已生成,请勿重复发起',
|
||||||
#
|
#
|
||||||
# )
|
# )
|
||||||
|
|
||||||
if data.task_id:
|
if data.task_id:
|
||||||
# 如果传了task_id,说明是重试!
|
# 如果传了task_id,说明是重试!
|
||||||
task_id = data.task_id
|
task_id = data.task_id
|
||||||
|
|||||||
@@ -2,7 +2,7 @@ from typing import Optional
|
|||||||
from fastapi import APIRouter
|
from fastapi import APIRouter
|
||||||
from pydantic import BaseModel
|
from pydantic import BaseModel
|
||||||
|
|
||||||
from app.exceptions.provider import ConnectionTestError
|
from app.exceptions.provider import ProviderError
|
||||||
from app.models.model_config import ModelConfig
|
from app.models.model_config import ModelConfig
|
||||||
from app.services.model import ModelService
|
from app.services.model import ModelService
|
||||||
from app.utils.response import ResponseWrapper as R
|
from app.utils.response import ResponseWrapper as R
|
||||||
@@ -88,9 +88,5 @@ def update_provider(data: ProviderUpdateRequest):
|
|||||||
|
|
||||||
@router.post('/connect_test')
|
@router.post('/connect_test')
|
||||||
def gpt_connect_test(data: TestRequest):
|
def gpt_connect_test(data: TestRequest):
|
||||||
try:
|
ModelService().connect_test(data.id)
|
||||||
ModelService().connect_test(data.id)
|
return R.success(msg='连接成功')
|
||||||
return R.success(msg='连接成功')
|
|
||||||
except Exception as e:
|
|
||||||
print("捕获到异常类型:", type(e))
|
|
||||||
return R.error(msg=str(e))
|
|
||||||
|
|||||||
@@ -1,12 +1,16 @@
|
|||||||
|
|
||||||
|
|
||||||
from app.db.model_dao import insert_model, get_all_models, get_model_by_provider_and_name, delete_model
|
from app.db.model_dao import insert_model, get_all_models, get_model_by_provider_and_name, delete_model
|
||||||
from app.db.provider_dao import get_enabled_providers
|
from app.db.provider_dao import get_enabled_providers
|
||||||
from app.exceptions.provider import ConnectionTestError
|
from app.enmus.exception import ProviderErrorEnum
|
||||||
|
from app.exceptions.provider import ProviderError
|
||||||
from app.gpt.gpt_factory import GPTFactory
|
from app.gpt.gpt_factory import GPTFactory
|
||||||
from app.gpt.provider.OpenAI_compatible_provider import OpenAICompatibleProvider
|
from app.gpt.provider.OpenAI_compatible_provider import OpenAICompatibleProvider
|
||||||
from app.models.model_config import ModelConfig
|
from app.models.model_config import ModelConfig
|
||||||
from app.services.provider import ProviderService
|
from app.services.provider import ProviderService
|
||||||
|
from app.utils.logger import get_logger
|
||||||
|
|
||||||
|
logger=get_logger(__name__)
|
||||||
class ModelService:
|
class ModelService:
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
@@ -83,37 +87,38 @@ class ModelService:
|
|||||||
provider = ProviderService.get_provider_by_id(provider_id)
|
provider = ProviderService.get_provider_by_id(provider_id)
|
||||||
|
|
||||||
models = ModelService.get_model_list(provider["id"], verbose=verbose)
|
models = ModelService.get_model_list(provider["id"], verbose=verbose)
|
||||||
|
print(type(models))
|
||||||
model_list={
|
serializable_models = [m.dict() for m in models.data]
|
||||||
|
model_list = {
|
||||||
"models": models
|
"models": serializable_models
|
||||||
}
|
}
|
||||||
|
|
||||||
|
logger.info(f"[{provider['name']}] 获取模型成功")
|
||||||
return model_list
|
return model_list
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
print(f"[{provider_id}] 获取模型失败: {e}")
|
# print(f"[{provider_id}] 获取模型失败: {e}")
|
||||||
|
logger.error(f"[{provider_id}] 获取模型失败: {e}")
|
||||||
return []
|
return []
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def connect_test(id: str) -> bool:
|
def connect_test(id: str) -> bool:
|
||||||
try:
|
|
||||||
provider = ProviderService.get_provider_by_id(id)
|
|
||||||
|
|
||||||
if provider:
|
provider = ProviderService.get_provider_by_id(id)
|
||||||
if not provider.get('api_key'):
|
|
||||||
raise ConnectionTestError(f"供应商信息未找到,请先保存重试")
|
if provider:
|
||||||
result = OpenAICompatibleProvider.test_connection(
|
if not provider.get('api_key'):
|
||||||
api_key=provider.get('api_key'),
|
raise ProviderError(code=ProviderErrorEnum.NOT_FOUND.code, message=ProviderErrorEnum.NOT_FOUND.message)
|
||||||
base_url=provider.get('base_url')
|
result = OpenAICompatibleProvider.test_connection(
|
||||||
)
|
api_key=provider.get('api_key'),
|
||||||
if result:
|
base_url=provider.get('base_url')
|
||||||
return True
|
)
|
||||||
else:
|
if result:
|
||||||
raise ConnectionTestError("请检查API Key 和 API 地址是否正确")
|
return True
|
||||||
|
else:
|
||||||
|
raise ProviderError(code=ProviderErrorEnum.WRONG_PARAMETER.code,message=ProviderErrorEnum.WRONG_PARAMETER.message)
|
||||||
|
|
||||||
|
raise ProviderError(code=ProviderErrorEnum.NOT_FOUND.code, message=ProviderErrorEnum.NOT_FOUND.message)
|
||||||
|
|
||||||
|
|
||||||
raise ConnectionTestError("供应商信息未找到,请先保存重试")
|
|
||||||
except Exception as e:
|
|
||||||
# 抛出业务异常,交由 Controller 处理
|
|
||||||
raise ConnectionTestError(f"{str(e)}") from e
|
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def delete_model_by_id( model_id: int) -> bool:
|
def delete_model_by_id( model_id: int) -> bool:
|
||||||
|
|||||||
@@ -1,75 +1,63 @@
|
|||||||
import json
|
import json
|
||||||
from dataclasses import asdict
|
import logging
|
||||||
|
|
||||||
from fastapi import HTTPException
|
|
||||||
|
|
||||||
from app.downloaders.local_downloader import LocalDownloader
|
|
||||||
from app.enmus.task_status_enums import TaskStatus
|
|
||||||
import os
|
import os
|
||||||
from typing import Union, Optional
|
import re
|
||||||
|
from dataclasses import asdict
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import List, Optional, Tuple, Union, Any
|
||||||
|
|
||||||
from pydantic import HttpUrl
|
from pydantic import HttpUrl
|
||||||
|
from dotenv import load_dotenv
|
||||||
|
|
||||||
from app.db.video_task_dao import insert_video_task, delete_task_by_video
|
|
||||||
from app.downloaders.base import Downloader
|
from app.downloaders.base import Downloader
|
||||||
from app.downloaders.bilibili_downloader import BilibiliDownloader
|
from app.services.constant import SUPPORT_PLATFORM_MAP
|
||||||
from app.downloaders.douyin_downloader import DouyinDownloader
|
from app.enmus.task_status_enums import TaskStatus
|
||||||
from app.downloaders.youtube_downloader import YoutubeDownloader
|
from app.enmus.exception import NoteErrorEnum, ProviderErrorEnum
|
||||||
|
from app.exceptions.note import NoteError
|
||||||
|
from app.exceptions.provider import ProviderError
|
||||||
|
from app.db.video_task_dao import delete_task_by_video, insert_video_task
|
||||||
from app.gpt.base import GPT
|
from app.gpt.base import GPT
|
||||||
from app.gpt.deepseek_gpt import DeepSeekGPT
|
|
||||||
from app.gpt.gpt_factory import GPTFactory
|
from app.gpt.gpt_factory import GPTFactory
|
||||||
from app.gpt.openai_gpt import OpenaiGPT
|
from app.models.audio_model import AudioDownloadResult
|
||||||
from app.gpt.qwen_gpt import QwenGPT
|
|
||||||
from app.models.gpt_model import GPTSource
|
from app.models.gpt_model import GPTSource
|
||||||
from app.models.model_config import ModelConfig
|
from app.models.model_config import ModelConfig
|
||||||
from app.models.notes_model import NoteResult
|
from app.models.notes_model import NoteResult
|
||||||
from app.models.notes_model import AudioDownloadResult
|
|
||||||
from app.enmus.note_enums import DownloadQuality
|
|
||||||
from app.models.transcriber_model import TranscriptResult, TranscriptSegment
|
from app.models.transcriber_model import TranscriptResult, TranscriptSegment
|
||||||
from app.services.constant import SUPPORT_PLATFORM_MAP
|
|
||||||
|
|
||||||
from app.services.provider import ProviderService
|
from app.services.provider import ProviderService
|
||||||
from app.transcriber.base import Transcriber
|
from app.transcriber.base import Transcriber
|
||||||
from app.transcriber.transcriber_provider import get_transcriber, _transcribers
|
from app.transcriber.transcriber_provider import get_transcriber, _transcribers
|
||||||
from app.transcriber.whisper import WhisperTranscriber
|
|
||||||
import re
|
|
||||||
|
|
||||||
from app.utils.note_helper import replace_content_markers
|
|
||||||
from app.utils.status_code import StatusCode
|
|
||||||
from app.utils.video_helper import generate_screenshot
|
|
||||||
|
|
||||||
# from app.services.whisperer import transcribe_audio
|
|
||||||
# from app.services.gpt import summarize_text
|
|
||||||
from dotenv import load_dotenv
|
|
||||||
from app.utils.logger import get_logger
|
|
||||||
from app.utils.video_reader import VideoReader
|
from app.utils.video_reader import VideoReader
|
||||||
from events import transcription_finished
|
from app.utils.video_helper import generate_screenshot
|
||||||
|
from app.utils.note_helper import replace_content_markers
|
||||||
|
from app.enmus.note_enums import DownloadQuality
|
||||||
|
|
||||||
logger = get_logger(__name__)
|
# 环境变量
|
||||||
load_dotenv()
|
load_dotenv()
|
||||||
api_path = os.getenv("API_BASE_URL", "http://localhost")
|
NOTE_OUTPUT_DIR = Path(os.getenv("NOTE_OUTPUT_DIR", "note_results"))
|
||||||
BACKEND_PORT = os.getenv("BACKEND_PORT", 8000)
|
NOTE_OUTPUT_DIR.mkdir(parents=True, exist_ok=True)
|
||||||
|
IMAGE_BASE_URL = os.getenv("IMAGE_BASE_URL", "/static/screenshots")
|
||||||
|
IMAGE_OUTPUT_DIR = os.getenv("OUT_DIR", "images")
|
||||||
|
|
||||||
BACKEND_BASE_URL = f"{api_path}:{BACKEND_PORT}"
|
logger = logging.getLogger(__name__)
|
||||||
output_dir = os.getenv('OUT_DIR')
|
logger.setLevel(logging.INFO)
|
||||||
image_base_url = os.getenv('IMAGE_BASE_URL')
|
|
||||||
logger.info("starting up")
|
|
||||||
|
|
||||||
NOTE_OUTPUT_DIR = "note_results"
|
|
||||||
|
|
||||||
|
|
||||||
class NoteGenerator:
|
class NoteGenerator:
|
||||||
|
|
||||||
|
class States:
|
||||||
|
INIT = 'INIT'
|
||||||
|
PARSING = 'PARSING'
|
||||||
|
DOWNLOADING = 'DOWNLOADING'
|
||||||
|
TRANSCRIBING = 'TRANSCRIBING'
|
||||||
|
SUMMARIZING = 'SUMMARIZING'
|
||||||
|
SAVING = 'SAVING'
|
||||||
|
SUCCESS = 'SUCCESS'
|
||||||
|
FAILED = 'FAILED'
|
||||||
|
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
self.model_size: str = 'base'
|
self.transcriber_type = os.getenv("TRANSCRIBER_TYPE", "fast-whisper")
|
||||||
self.device: Union[str, None] = None
|
self.transcriber: Transcriber = self._init_transcriber()
|
||||||
self.transcriber_type = os.getenv('TRANSCRIBER_TYPE', 'fast-whisper')
|
self.video_img_urls = []
|
||||||
self.transcriber = self.get_transcriber()
|
|
||||||
self.video_path = None
|
|
||||||
logger.info("初始化NoteGenerator")
|
|
||||||
|
|
||||||
import logging
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def update_task_status(task_id: str, status: Union[str, TaskStatus], message: Optional[str] = None):
|
def update_task_status(task_id: str, status: Union[str, TaskStatus], message: Optional[str] = None):
|
||||||
@@ -81,310 +69,179 @@ class NoteGenerator:
|
|||||||
with open(path, "w", encoding="utf-8") as f:
|
with open(path, "w", encoding="utf-8") as f:
|
||||||
json.dump(content, f, ensure_ascii=False, indent=2)
|
json.dump(content, f, ensure_ascii=False, indent=2)
|
||||||
|
|
||||||
def get_gpt(self, model_name: str = None, provider_id: str = None) -> GPT:
|
def generate(
|
||||||
|
self,
|
||||||
|
video_url: Union[str, HttpUrl],
|
||||||
|
platform: str,
|
||||||
|
quality: DownloadQuality = DownloadQuality.medium,
|
||||||
|
task_id: Optional[str] = None,
|
||||||
|
model_name: Optional[str] = None,
|
||||||
|
provider_id: Optional[str] = None,
|
||||||
|
link: bool = False,
|
||||||
|
screenshot: bool = False,
|
||||||
|
_format: Optional[List[str]] = None,
|
||||||
|
style: Optional[str] = None,
|
||||||
|
extras: Optional[str] = None,
|
||||||
|
output_path: Optional[str] = None,
|
||||||
|
video_understanding: bool = False,
|
||||||
|
video_interval: int = 0,
|
||||||
|
grid_size: Optional[List[int]] = None,
|
||||||
|
) -> NoteResult | None:
|
||||||
|
|
||||||
|
self.task_id = task_id
|
||||||
|
self._change_state(self.States.INIT)
|
||||||
|
|
||||||
|
try:
|
||||||
|
self._change_state(self.States.PARSING)
|
||||||
|
|
||||||
|
downloader = self._get_downloader(platform)
|
||||||
|
gpt = self._get_gpt(model_name, provider_id)
|
||||||
|
|
||||||
|
self.audio_cache_file = NOTE_OUTPUT_DIR / f"{task_id}_audio.json"
|
||||||
|
self.transcript_cache_file = NOTE_OUTPUT_DIR / f"{task_id}_transcript.json"
|
||||||
|
self.markdown_cache_file = NOTE_OUTPUT_DIR / f"{task_id}_markdown.md"
|
||||||
|
|
||||||
|
self.audio_meta = self._download_audio_video(
|
||||||
|
downloader, video_url, quality, output_path,
|
||||||
|
screenshot, video_understanding, video_interval, grid_size or []
|
||||||
|
)
|
||||||
|
|
||||||
|
self.transcript = self._transcribe_audio()
|
||||||
|
|
||||||
|
self.markdown = self._summarize_text(
|
||||||
|
gpt, link, screenshot, _format or [], style, extras
|
||||||
|
)
|
||||||
|
|
||||||
|
self.markdown = self._post_process_markdown(
|
||||||
|
self.markdown, self.video_path, _format or [], self.audio_meta, platform
|
||||||
|
)
|
||||||
|
|
||||||
|
self._change_state(self.States.SAVING)
|
||||||
|
self._save_metadata(self.audio_meta.video_id, platform, task_id)
|
||||||
|
|
||||||
|
self._change_state(self.States.SUCCESS)
|
||||||
|
return NoteResult(markdown=self.markdown, transcript=self.transcript, audio_meta=self.audio_meta)
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.exception(f"任务 {self.task_id} 失败: {e}")
|
||||||
|
self._change_state(self.States.FAILED, str(e))
|
||||||
|
return None
|
||||||
|
|
||||||
|
def _change_state(self, state: str, message: Optional[str] = None):
|
||||||
|
if not self.task_id:
|
||||||
|
return
|
||||||
|
NOTE_OUTPUT_DIR.mkdir(parents=True, exist_ok=True)
|
||||||
|
status_file = NOTE_OUTPUT_DIR / f"{self.task_id}.status.json"
|
||||||
|
data = {"status": state}
|
||||||
|
if message:
|
||||||
|
data["message"] = message
|
||||||
|
temp_file = status_file.with_suffix('.tmp')
|
||||||
|
with temp_file.open('w', encoding='utf-8') as f:
|
||||||
|
json.dump(data, f, ensure_ascii=False, indent=2)
|
||||||
|
temp_file.replace(status_file)
|
||||||
|
|
||||||
|
def _init_transcriber(self) -> Transcriber:
|
||||||
|
if self.transcriber_type not in _transcribers:
|
||||||
|
raise Exception(f"不支持的转写器:{self.transcriber_type}")
|
||||||
|
return get_transcriber(self.transcriber_type)
|
||||||
|
|
||||||
|
def _get_gpt(self, model_name: Optional[str], provider_id: Optional[str]) -> GPT:
|
||||||
provider = ProviderService.get_provider_by_id(provider_id)
|
provider = ProviderService.get_provider_by_id(provider_id)
|
||||||
if not provider:
|
if not provider:
|
||||||
logger.error(f"[get_gpt] 未找到对应的模型供应商: provider_id={provider_id}")
|
raise ProviderError(code=ProviderErrorEnum.NOT_FOUND, message=ProviderErrorEnum.NOT_FOUND.message)
|
||||||
raise ValueError(f"未找到对应的模型供应商: provider_id={provider_id}")
|
config = ModelConfig(
|
||||||
|
api_key=provider["api_key"], base_url=provider["base_url"],
|
||||||
gpt = GPTFactory().from_config(
|
model_name=model_name, provider=provider["type"], name=provider["name"]
|
||||||
ModelConfig(
|
|
||||||
api_key=provider.get('api_key'),
|
|
||||||
base_url=provider.get('base_url'),
|
|
||||||
model_name=model_name,
|
|
||||||
provider=provider.get('type'),
|
|
||||||
name=provider.get('name')
|
|
||||||
)
|
|
||||||
)
|
)
|
||||||
return gpt
|
return GPTFactory().from_config(config)
|
||||||
|
|
||||||
def get_downloader(self, platform: str) -> Downloader:
|
def _get_downloader(self, platform: str) -> Downloader:
|
||||||
downloader = SUPPORT_PLATFORM_MAP[platform]
|
downloader_cls = SUPPORT_PLATFORM_MAP.get(platform)
|
||||||
if downloader:
|
if not downloader_cls:
|
||||||
logger.info(f"使用{downloader}下载器")
|
raise NoteError(code=NoteErrorEnum.PLATFORM_NOT_SUPPORTED.code,
|
||||||
return downloader
|
message=NoteErrorEnum.PLATFORM_NOT_SUPPORTED.message)
|
||||||
else:
|
return downloader_cls
|
||||||
logger.warning("不支持的平台")
|
|
||||||
raise ValueError(f"不支持的平台:{platform}")
|
|
||||||
|
|
||||||
def get_transcriber(self) -> Transcriber:
|
def _download_audio_video(self, downloader, video_url, quality, output_path,
|
||||||
'''
|
screenshot, video_understanding, video_interval, grid_size):
|
||||||
|
self._change_state(self.States.DOWNLOADING)
|
||||||
|
|
||||||
:param transcriber: 选择的转义器
|
need_video = screenshot or video_understanding
|
||||||
:return:
|
if need_video:
|
||||||
'''
|
self.video_path = Path(downloader.download_video(video_url, output_path))
|
||||||
if self.transcriber_type in _transcribers.keys():
|
if grid_size:
|
||||||
logger.info(f"使用{self.transcriber_type}转义器")
|
self.video_img_urls = VideoReader(
|
||||||
return get_transcriber(transcriber_type=self.transcriber_type)
|
video_path=str(self.video_path),
|
||||||
else:
|
grid_size=tuple(grid_size),
|
||||||
logger.warning("不支持的转义器")
|
frame_interval=video_interval,
|
||||||
raise ValueError(f"不支持的转义器:{self.transcriber}")
|
unit_width=1280, unit_height=720,
|
||||||
|
save_quality=90,
|
||||||
|
).run()
|
||||||
|
|
||||||
def save_meta(self, video_id, platform, task_id):
|
if self.audio_cache_file.exists():
|
||||||
logger.info(f"记录已经生成的数据信息")
|
with open(self.audio_cache_file, "r", encoding="utf-8") as f:
|
||||||
insert_video_task(video_id=video_id, platform=platform, task_id=task_id)
|
data = json.load(f)
|
||||||
|
return AudioDownloadResult(**data)
|
||||||
|
|
||||||
def insert_screenshots_into_markdown(self, markdown: str, video_path: str, image_base_url: str,
|
audio = downloader.download(
|
||||||
output_dir: str, _format: list) -> str:
|
video_url=video_url, quality=quality, output_dir=output_path, need_video=need_video
|
||||||
"""
|
)
|
||||||
扫描 markdown 中的 *Screenshot-xx:xx,生成截图并插入 markdown 图片
|
with open(self.audio_cache_file, "w", encoding="utf-8") as f:
|
||||||
:param markdown:
|
json.dump(asdict(audio), f, ensure_ascii=False, indent=2)
|
||||||
:param image_base_url: 最终返回给前端的路径前缀(如 /static/screenshots)
|
return audio
|
||||||
"""
|
|
||||||
matches = self.extract_screenshot_timestamps(markdown)
|
|
||||||
new_markdown = markdown
|
|
||||||
|
|
||||||
logger.info(f"开始为笔记生成截图")
|
def _transcribe_audio(self):
|
||||||
try:
|
self._change_state(self.States.TRANSCRIBING)
|
||||||
for idx, (marker, ts) in enumerate(matches):
|
if self.transcript_cache_file.exists():
|
||||||
image_path = generate_screenshot(video_path, output_dir, ts, idx)
|
with open(self.transcript_cache_file, "r", encoding="utf-8") as f:
|
||||||
image_relative_path = os.path.join(image_base_url, os.path.basename(image_path)).replace("\\", "/")
|
data = json.load(f)
|
||||||
image_url = f"/static/screenshots/{os.path.basename(image_path)}"
|
segments = [TranscriptSegment(**seg) for seg in data.get("segments", [])]
|
||||||
replacement = f""
|
return TranscriptResult(language=data["language"], full_text=data["full_text"], segments=segments)
|
||||||
new_markdown = new_markdown.replace(marker, replacement, 1)
|
|
||||||
|
|
||||||
return new_markdown
|
transcript = self.transcriber.transcript(self.audio_meta.file_path)
|
||||||
except Exception as e:
|
with open(self.transcript_cache_file, "w", encoding="utf-8") as f:
|
||||||
logger.error(f"截图生成失败:{e}")
|
json.dump(asdict(transcript), f, ensure_ascii=False, indent=2)
|
||||||
raise HTTPException(
|
return transcript
|
||||||
status_code=500,
|
|
||||||
detail={
|
|
||||||
"code": StatusCode.DOWNLOAD_ERROR,
|
|
||||||
"msg": f"截图生成失败",
|
|
||||||
"error": str(e)
|
|
||||||
}
|
|
||||||
)
|
|
||||||
|
|
||||||
|
def _summarize_text(self, gpt, link, screenshot, formats, style, extras):
|
||||||
|
self._change_state(self.States.SUMMARIZING)
|
||||||
|
source = GPTSource(
|
||||||
|
title=self.audio_meta.title,
|
||||||
|
segment=self.transcript.segments,
|
||||||
|
tags=self.audio_meta.raw_info.get("tags", []),
|
||||||
|
screenshot=screenshot,
|
||||||
|
video_img_urls=self.video_img_urls,
|
||||||
|
link=link, _format=formats, style=style, extras=extras
|
||||||
|
)
|
||||||
|
markdown = gpt.summarize(source)
|
||||||
|
with open(self.markdown_cache_file, "w", encoding="utf-8") as f:
|
||||||
|
f.write(markdown)
|
||||||
|
return markdown
|
||||||
|
|
||||||
@staticmethod
|
def _post_process_markdown(self, markdown, video_path, formats, audio_meta, platform):
|
||||||
def delete_note(video_id: str, platform: str):
|
if "screenshot" in formats and video_path:
|
||||||
logger.info(f"删除生成的笔记记录")
|
markdown = self._insert_screenshots(markdown, video_path)
|
||||||
return delete_task_by_video(video_id, platform)
|
if "link" in formats:
|
||||||
|
markdown = replace_content_markers(markdown, video_id=audio_meta.video_id, platform=platform)
|
||||||
|
return markdown
|
||||||
|
|
||||||
import re
|
def _insert_screenshots(self, markdown, video_path):
|
||||||
|
|
||||||
def extract_screenshot_timestamps(self, markdown: str) -> list[tuple[str, int]]:
|
|
||||||
"""
|
|
||||||
从 Markdown 中提取 Screenshot 时间标记(如 *Screenshot-03:39 或 Screenshot-[03:39]),
|
|
||||||
并返回匹配文本和对应时间戳(秒)
|
|
||||||
"""
|
|
||||||
logger.info(f"开始提取截图时间标记")
|
|
||||||
pattern = r"(?:\*Screenshot-(\d{2}):(\d{2})|Screenshot-\[(\d{2}):(\d{2})\])"
|
pattern = r"(?:\*Screenshot-(\d{2}):(\d{2})|Screenshot-\[(\d{2}):(\d{2})\])"
|
||||||
matches = list(re.finditer(pattern, markdown))
|
matches = []
|
||||||
results = []
|
for match in re.finditer(pattern, markdown):
|
||||||
for match in matches:
|
|
||||||
mm = match.group(1) or match.group(3)
|
mm = match.group(1) or match.group(3)
|
||||||
ss = match.group(2) or match.group(4)
|
ss = match.group(2) or match.group(4)
|
||||||
total_seconds = int(mm) * 60 + int(ss)
|
matches.append((match.group(0), int(mm)*60+int(ss)))
|
||||||
results.append((match.group(0), total_seconds))
|
for idx, (marker, ts) in enumerate(matches):
|
||||||
return results
|
img_path = generate_screenshot(str(video_path), str(IMAGE_OUTPUT_DIR), ts, idx)
|
||||||
|
filename = Path(img_path).name
|
||||||
|
img_url = f"{IMAGE_BASE_URL.rstrip('/')}/{filename}"
|
||||||
|
markdown = markdown.replace(marker, f"", 1)
|
||||||
|
return markdown
|
||||||
|
|
||||||
def generate(
|
def _save_metadata(self, video_id: str, platform: str, task_id: str):
|
||||||
self,
|
insert_video_task(video_id=video_id, platform=platform, task_id=task_id)
|
||||||
video_url: Union[str, HttpUrl],
|
|
||||||
platform: str,
|
|
||||||
quality: DownloadQuality = DownloadQuality.medium,
|
|
||||||
task_id: Union[str, None] = None,
|
|
||||||
model_name: str = None,
|
|
||||||
provider_id: str = None,
|
|
||||||
link: bool = False,
|
|
||||||
screenshot: bool = False,
|
|
||||||
_format: list = None,
|
|
||||||
style: str = None,
|
|
||||||
extras: str = None,
|
|
||||||
path: Union[str, None] = None,
|
|
||||||
video_understanding: bool = False,
|
|
||||||
video_interval=0,
|
|
||||||
grid_size=[]
|
|
||||||
) -> NoteResult:
|
|
||||||
|
|
||||||
try:
|
@staticmethod
|
||||||
logger.info(f"🎯 开始解析并生成笔记,task_id={task_id}")
|
def delete_note(video_id: str, platform: str) -> int:
|
||||||
self.update_task_status(task_id, TaskStatus.PARSING)
|
return delete_task_by_video(video_id, platform)
|
||||||
downloader = self.get_downloader(platform)
|
|
||||||
gpt = self.get_gpt(model_name=model_name, provider_id=provider_id)
|
|
||||||
video_img_urls = []
|
|
||||||
audio_cache_path = os.path.join(NOTE_OUTPUT_DIR, f"{task_id}_audio.json")
|
|
||||||
transcript_cache_path = os.path.join(NOTE_OUTPUT_DIR, f"{task_id}_transcript.json")
|
|
||||||
markdown_cache_path = os.path.join(NOTE_OUTPUT_DIR, f"{task_id}_markdown.md")
|
|
||||||
|
|
||||||
# -------- 1. 下载音频 --------
|
|
||||||
try:
|
|
||||||
self.update_task_status(task_id, TaskStatus.DOWNLOADING)
|
|
||||||
|
|
||||||
# 加载音频缓存(如果存在)
|
|
||||||
audio = None
|
|
||||||
if os.path.exists(audio_cache_path):
|
|
||||||
logger.info(f"检测到已有音频缓存,直接读取,task_id={task_id}")
|
|
||||||
with open(audio_cache_path, "r", encoding="utf-8") as f:
|
|
||||||
audio_data = json.load(f)
|
|
||||||
audio = AudioDownloadResult(**audio_data)
|
|
||||||
|
|
||||||
# 需要视频的情况(截图 or 视频理解)
|
|
||||||
need_video = 'screenshot' in _format or video_understanding
|
|
||||||
if need_video:
|
|
||||||
try:
|
|
||||||
video_path = downloader.download_video(video_url)
|
|
||||||
self.video_path = video_path
|
|
||||||
logger.info(f"成功下载视频文件: {video_path}")
|
|
||||||
|
|
||||||
video_img_urls = VideoReader(
|
|
||||||
video_path=video_path,
|
|
||||||
grid_size=tuple(grid_size),
|
|
||||||
frame_interval=video_interval,
|
|
||||||
unit_width=1280,
|
|
||||||
unit_height=720,
|
|
||||||
save_quality=90,
|
|
||||||
).run()
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"Error 下载视频失败,task_id={task_id},错误信息:{e}")
|
|
||||||
self.update_task_status(task_id, TaskStatus.FAILED, message=f"下载音频失败:{e}")
|
|
||||||
raise HTTPException(
|
|
||||||
status_code=500,
|
|
||||||
detail={
|
|
||||||
"code": StatusCode.DOWNLOAD_ERROR,
|
|
||||||
"msg": f"下载视频失败,task_id={task_id}",
|
|
||||||
"error": str(e)
|
|
||||||
}
|
|
||||||
)
|
|
||||||
|
|
||||||
# 没有音频缓存就下载音频(可能同时也带上视频)
|
|
||||||
if audio is None:
|
|
||||||
audio = downloader.download(
|
|
||||||
video_url=video_url,
|
|
||||||
quality=quality,
|
|
||||||
output_dir=path,
|
|
||||||
need_video='screenshot' in _format, # 注意这里只为了截图需要
|
|
||||||
)
|
|
||||||
with open(audio_cache_path, "w", encoding="utf-8") as f:
|
|
||||||
json.dump(asdict(audio), f, ensure_ascii=False, indent=2)
|
|
||||||
logger.info(f"音频下载并缓存成功,task_id={task_id}")
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"Error 下载音频失败,task_id={task_id},错误信息:{e}")
|
|
||||||
self.update_task_status(task_id, TaskStatus.FAILED, message=f"下载音频失败:{e}")
|
|
||||||
|
|
||||||
raise HTTPException(
|
|
||||||
status_code=500,
|
|
||||||
detail={
|
|
||||||
"code": StatusCode.DOWNLOAD_ERROR,
|
|
||||||
"msg": f"下载音频失败,task_id={task_id}",
|
|
||||||
"error": str(e)
|
|
||||||
}
|
|
||||||
)
|
|
||||||
|
|
||||||
# -------- 2. 转写文字 --------
|
|
||||||
try:
|
|
||||||
self.update_task_status(task_id, TaskStatus.TRANSCRIBING)
|
|
||||||
if os.path.exists(transcript_cache_path):
|
|
||||||
logger.info(f"检测到已有转写缓存,直接读取,task_id={task_id}")
|
|
||||||
try:
|
|
||||||
with open(transcript_cache_path, "r", encoding="utf-8") as f:
|
|
||||||
transcript_data = json.load(f)
|
|
||||||
transcript = TranscriptResult(
|
|
||||||
language=transcript_data["language"],
|
|
||||||
full_text=transcript_data["full_text"],
|
|
||||||
segments=[TranscriptSegment(**seg) for seg in transcript_data["segments"]]
|
|
||||||
)
|
|
||||||
except (json.JSONDecodeError, KeyError) as e:
|
|
||||||
logger.warning(f"Warning 读取转录缓存失败,重新转录,task_id={task_id},错误信息:{e}")
|
|
||||||
transcript: TranscriptResult = self.transcriber.transcript(file_path=audio.file_path)
|
|
||||||
with open(transcript_cache_path, "w", encoding="utf-8") as f:
|
|
||||||
json.dump(asdict(transcript), f, ensure_ascii=False, indent=2)
|
|
||||||
else:
|
|
||||||
transcript: TranscriptResult = self.transcriber.transcript(file_path=audio.file_path)
|
|
||||||
with open(transcript_cache_path, "w", encoding="utf-8") as f:
|
|
||||||
json.dump(asdict(transcript), f, ensure_ascii=False, indent=2)
|
|
||||||
logger.info(f"文字转写并缓存成功,task_id={task_id}")
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"Error 转写文字失败,task_id={task_id},错误信息:{e}")
|
|
||||||
self.update_task_status(task_id, TaskStatus.FAILED, message=f"转写文字失败:{e}")
|
|
||||||
raise HTTPException(
|
|
||||||
status_code=500,
|
|
||||||
detail={
|
|
||||||
"code": StatusCode.GENERATE_ERROR, # =1003
|
|
||||||
"msg": f"转写文字失败,task_id={task_id}",
|
|
||||||
"error": str(e)
|
|
||||||
}
|
|
||||||
)
|
|
||||||
|
|
||||||
# -------- 3. 总结内容 --------
|
|
||||||
try:
|
|
||||||
self.update_task_status(task_id, TaskStatus.SUMMARIZING)
|
|
||||||
# if os.path.exists(markdown_cache_path):
|
|
||||||
# logger.info(f"检测到已有总结缓存,直接读取,task_id={task_id}")
|
|
||||||
# with open(markdown_cache_path, "r", encoding="utf-8") as f:
|
|
||||||
# markdown = f.read()
|
|
||||||
# else:
|
|
||||||
source = GPTSource(
|
|
||||||
title=audio.title,
|
|
||||||
segment=transcript.segments,
|
|
||||||
tags=audio.raw_info.get('tags'),
|
|
||||||
screenshot=screenshot,
|
|
||||||
video_img_urls=video_img_urls,
|
|
||||||
link=link,
|
|
||||||
_format=_format,
|
|
||||||
style=style,
|
|
||||||
extras=extras
|
|
||||||
)
|
|
||||||
|
|
||||||
markdown: str = gpt.summarize(source)
|
|
||||||
with open(markdown_cache_path, "w", encoding="utf-8") as f:
|
|
||||||
f.write(markdown)
|
|
||||||
logger.info(f"GPT总结并缓存成功,task_id={task_id}")
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"Error 总结内容失败,task_id={task_id},错误信息:{e}")
|
|
||||||
self.update_task_status(task_id, TaskStatus.FAILED, message=f"总结内容失败:{e}")
|
|
||||||
raise HTTPException(
|
|
||||||
status_code=500,
|
|
||||||
detail={
|
|
||||||
"code": StatusCode.GENERATE_ERROR, # =1003
|
|
||||||
"msg": f"总结内容失败,task_id={task_id}",
|
|
||||||
"error": str(e)
|
|
||||||
}
|
|
||||||
)
|
|
||||||
|
|
||||||
# -------- 4. 插入截图 --------
|
|
||||||
if _format and 'screenshot' in _format:
|
|
||||||
try:
|
|
||||||
markdown = self.insert_screenshots_into_markdown(markdown, self.video_path, image_base_url,
|
|
||||||
output_dir, _format)
|
|
||||||
except Exception as e:
|
|
||||||
logger.warning(f"Warning 插入截图失败,跳过处理,task_id={task_id},错误信息:{e}")
|
|
||||||
if _format and 'link' in _format:
|
|
||||||
try:
|
|
||||||
markdown = replace_content_markers(markdown, video_id=audio.video_id, platform=platform)
|
|
||||||
except Exception as e:
|
|
||||||
logger.warning(f"Warning 插入链接失败,跳过处理,task_id={task_id},错误信息:{e}")
|
|
||||||
# 注意:截图失败不终止整体流程
|
|
||||||
|
|
||||||
# -------- 5. 保存数据库记录 --------
|
|
||||||
self.update_task_status(task_id, TaskStatus.SAVING)
|
|
||||||
self.save_meta(video_id=audio.video_id, platform=platform, task_id=task_id)
|
|
||||||
|
|
||||||
# -------- 6. 完成 --------
|
|
||||||
self.update_task_status(task_id, TaskStatus.SUCCESS)
|
|
||||||
logger.info(f"succeed 笔记生成成功,task_id={task_id}")
|
|
||||||
# TODO :改为前端一键清除缓存
|
|
||||||
# if platform != 'local':
|
|
||||||
# transcription_finished.send({
|
|
||||||
# "file_path": audio.file_path,
|
|
||||||
# })
|
|
||||||
return NoteResult(
|
|
||||||
markdown=markdown,
|
|
||||||
transcript=transcript,
|
|
||||||
audio_meta=audio
|
|
||||||
)
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"Error 笔记生成流程异常终止,task_id={task_id},错误信息:{e}")
|
|
||||||
self.update_task_status(task_id, TaskStatus.FAILED, message=str(e))
|
|
||||||
|
|
||||||
# 返回结构化错误信息给前端(可以用于日志 + 显示 + 错误定位)
|
|
||||||
raise HTTPException(
|
|
||||||
status_code=500,
|
|
||||||
detail={
|
|
||||||
"code": StatusCode.FAIL,
|
|
||||||
"msg": f"笔记生成流程异常终止,task_id={task_id}",
|
|
||||||
"error": str(e)
|
|
||||||
}
|
|
||||||
)
|
|
||||||
@@ -1,18 +1,24 @@
|
|||||||
|
from fastapi.responses import JSONResponse
|
||||||
from app.utils.status_code import StatusCode
|
from app.utils.status_code import StatusCode
|
||||||
|
from pydantic import BaseModel
|
||||||
|
from typing import Optional, Any
|
||||||
|
|
||||||
|
|
||||||
|
from fastapi.responses import JSONResponse
|
||||||
|
|
||||||
class ResponseWrapper:
|
class ResponseWrapper:
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def success(data=None, msg="success", code=StatusCode.SUCCESS):
|
def success(data=None, msg="success", code=0):
|
||||||
return {
|
return JSONResponse(content={
|
||||||
"code": int(code),
|
"code": code,
|
||||||
"msg": msg,
|
"msg": msg,
|
||||||
"data": data
|
"data": data
|
||||||
}
|
})
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def error(msg="error", code=StatusCode.FAIL, data=None):
|
def error(msg="error", code=500, data=None):
|
||||||
return {
|
return JSONResponse(content={
|
||||||
"code": int(code),
|
"code": code,
|
||||||
"msg": msg,
|
"msg": msg,
|
||||||
"data": data
|
"data": data
|
||||||
}
|
})
|
||||||
@@ -4,7 +4,7 @@ import uvicorn
|
|||||||
from starlette.staticfiles import StaticFiles
|
from starlette.staticfiles import StaticFiles
|
||||||
from dotenv import load_dotenv
|
from dotenv import load_dotenv
|
||||||
|
|
||||||
from app.core.exception_handlers import register_exception_handlers
|
from app.exceptions.exception_handlers import register_exception_handlers
|
||||||
from app.db.model_dao import init_model_table
|
from app.db.model_dao import init_model_table
|
||||||
from app.db.provider_dao import init_provider_table
|
from app.db.provider_dao import init_provider_table
|
||||||
from app.utils.logger import get_logger
|
from app.utils.logger import get_logger
|
||||||
@@ -33,11 +33,14 @@ if not os.path.exists(out_dir):
|
|||||||
os.makedirs(out_dir)
|
os.makedirs(out_dir)
|
||||||
|
|
||||||
app = create_app()
|
app = create_app()
|
||||||
|
register_exception_handlers(app)
|
||||||
app.mount(static_path, StaticFiles(directory=static_dir), name="static")
|
app.mount(static_path, StaticFiles(directory=static_dir), name="static")
|
||||||
app.mount("/uploads", StaticFiles(directory=uploads_dir), name="uploads")
|
app.mount("/uploads", StaticFiles(directory=uploads_dir), name="uploads")
|
||||||
@app.on_event("startup")
|
@app.on_event("startup")
|
||||||
async def startup_event():
|
async def startup_event():
|
||||||
register_exception_handlers(app)
|
|
||||||
|
|
||||||
|
|
||||||
register_handler()
|
register_handler()
|
||||||
ensure_ffmpeg_or_raise()
|
ensure_ffmpeg_or_raise()
|
||||||
register_handler()
|
register_handler()
|
||||||
@@ -46,8 +49,9 @@ async def startup_event():
|
|||||||
init_provider_table()
|
init_provider_table()
|
||||||
init_model_table()
|
init_model_table()
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
port = int(os.getenv("BACKEND_PORT", 8000))
|
port = int(os.getenv("BACKEND_PORT", 8000))
|
||||||
host = os.getenv("BACKEND_HOST", "0.0.0.0")
|
host = os.getenv("BACKEND_HOST", "0.0.0.0")
|
||||||
logger.info(f"Starting server on {host}:{port}")
|
logger.info(f"Starting server on {host}:{port}")
|
||||||
uvicorn.run("main:app", host=host, port=port, reload=False)
|
uvicorn.run(app, host=host, port=port, reload=False)
|
||||||
@@ -1,6 +1,6 @@
|
|||||||
server {
|
server {
|
||||||
listen 80;
|
listen 80;
|
||||||
|
client_max_body_size 10G;
|
||||||
# 所有非 /api 请求全部代理给 frontend 容器
|
# 所有非 /api 请求全部代理给 frontend 容器
|
||||||
location / {
|
location / {
|
||||||
proxy_pass http://frontend:80;
|
proxy_pass http://frontend:80;
|
||||||
@@ -11,8 +11,6 @@ server {
|
|||||||
proxy_pass http://backend:8000;
|
proxy_pass http://backend:8000;
|
||||||
proxy_set_header Host $host;
|
proxy_set_header Host $host;
|
||||||
proxy_set_header X-Real-IP $remote_addr;
|
proxy_set_header X-Real-IP $remote_addr;
|
||||||
client_max_body_size 10G;
|
|
||||||
client_body_buffer_size 128k;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
location /static/ {
|
location /static/ {
|
||||||
|
|||||||
Reference in New Issue
Block a user