mirror of
https://github.com/JefferyHcool/BiliNote.git
synced 2026-05-06 20:42:52 +08:00
Merge pull request #130 from JefferyHcool/feature/v1.7.4
refactor(backend): 重构后端异常处理和模型管理
This commit is contained in:
@@ -11,6 +11,7 @@ VITE_FRONTEND_PORT=3015
|
||||
ENV=production
|
||||
STATIC=/static
|
||||
OUT_DIR=./static/screenshots
|
||||
NOTE_OUTPUT_DIR=note_results
|
||||
IMAGE_BASE_URL=/static/screenshots
|
||||
DATA_DIR=data
|
||||
# FFMPEG 配置
|
||||
|
||||
@@ -36,7 +36,7 @@ const DownloaderForm = () => {
|
||||
setLoading(true) // 🔁 切换平台时显示 loading
|
||||
try {
|
||||
const res = await getDownloaderCookie(id)
|
||||
const cookie = res?.data?.data?.cookie || ''
|
||||
const cookie = res?.cookie || ''
|
||||
form.reset({ cookie }) // ✅ 正确重置表单值
|
||||
} catch (e) {
|
||||
toast.error('加载 Cookie 失败: ' + e)
|
||||
|
||||
@@ -129,11 +129,10 @@ const ProviderForm = ({ isCreate = false }: { isCreate?: boolean }) => {
|
||||
|
||||
try {
|
||||
const res = await deleteModelById(modelId)
|
||||
if (res.data.code === 0) {
|
||||
toast.success('删除成功')
|
||||
} else {
|
||||
toast.error(res.data.msg || '删除失败')
|
||||
}
|
||||
console.log('🔧 删除结果:', res)
|
||||
|
||||
toast.success('删除成功')
|
||||
|
||||
} catch (e) {
|
||||
toast.error('删除异常')
|
||||
}
|
||||
@@ -151,16 +150,16 @@ const ProviderForm = ({ isCreate = false }: { isCreate?: boolean }) => {
|
||||
return
|
||||
}
|
||||
setTesting(true)
|
||||
const data = await testConnection({
|
||||
id
|
||||
})
|
||||
if (data.data.code === 0) {
|
||||
await testConnection({
|
||||
id
|
||||
})
|
||||
|
||||
toast.success('测试连通性成功 🎉')
|
||||
} else {
|
||||
toast.error(`连接失败: ${data.data.msg || '未知错误'}`)
|
||||
}
|
||||
|
||||
} catch (error) {
|
||||
toast.error('测试连通性异常')
|
||||
|
||||
toast.error(`连接失败: ${data.data.msg || '未知错误'}`)
|
||||
// toast.error('测试连通性异常')
|
||||
} finally {
|
||||
setTesting(false)
|
||||
}
|
||||
|
||||
@@ -26,11 +26,11 @@ export const useTaskPolling = (interval = 3000) => {
|
||||
try {
|
||||
console.log('🔄 正在轮询任务:', task.id)
|
||||
const res = await get_task_status(task.id)
|
||||
const { status } = res.data
|
||||
const { status } = res
|
||||
|
||||
if (status && status !== task.status) {
|
||||
if (status === 'SUCCESS') {
|
||||
const { markdown, transcript, audio_meta } = res.data.result
|
||||
const { markdown, transcript, audio_meta } = res.result
|
||||
toast.success('笔记生成成功')
|
||||
updateTaskContent(task.id, {
|
||||
status,
|
||||
@@ -47,7 +47,7 @@ export const useTaskPolling = (interval = 3000) => {
|
||||
}
|
||||
} catch (e) {
|
||||
console.error('❌ 任务轮询失败:', e)
|
||||
toast.error(`生成失败 ${e.message || e}`)
|
||||
// toast.error(`生成失败 ${e.message || e}`)
|
||||
updateTaskContent(task.id, { status: 'FAILED' })
|
||||
// removeTask(task.id)
|
||||
}
|
||||
|
||||
@@ -173,6 +173,7 @@ const MarkdownViewer: FC<MarkdownViewerProps> = ({ status }) => {
|
||||
<div className="text-center">
|
||||
<p className="text-lg font-bold text-red-500">笔记生成失败</p>
|
||||
<p className="mt-2 mb-2 text-xs text-red-400">请检查后台或稍后再试</p>
|
||||
|
||||
<Button onClick={() => retryTask(currentTask.id)} size="lg">
|
||||
重试
|
||||
</Button>
|
||||
|
||||
@@ -37,6 +37,7 @@ import {
|
||||
import { Input } from '@/components/ui/input.tsx'
|
||||
import { Textarea } from '@/components/ui/textarea.tsx'
|
||||
import { noteStyles, noteFormats, videoPlatforms } from '@/constant/note.ts'
|
||||
import { fetchModels } from '@/services/model.ts'
|
||||
|
||||
/* -------------------- 校验 Schema -------------------- */
|
||||
const formSchema = z
|
||||
@@ -206,7 +207,7 @@ const NoteForm = () => {
|
||||
}
|
||||
|
||||
message.success('已提交任务')
|
||||
const { data } = await generateNote(payload)
|
||||
const data = await generateNote(payload)
|
||||
addPendingTask(data.task_id, values.platform, payload)
|
||||
}
|
||||
const onInvalid = (errors: FieldErrors<NoteFormValues>) => {
|
||||
@@ -355,6 +356,9 @@ const NoteForm = () => {
|
||||
<FormItem>
|
||||
<SectionHeader title="模型选择" tip="不同模型效果不同,建议自行测试" />
|
||||
<Select
|
||||
onOpenChange={()=>{
|
||||
loadEnabledModels()
|
||||
}}
|
||||
value={field.value}
|
||||
onValueChange={field.onChange}
|
||||
defaultValue={field.value}
|
||||
|
||||
@@ -26,7 +26,7 @@ export default function AboutPage() {
|
||||
height={50}
|
||||
className="rounded-lg"
|
||||
/>
|
||||
<h1 className="text-4xl font-bold">BiliNote v1.7.3</h1>
|
||||
<h1 className="text-4xl font-bold">BiliNote v1.7.4</h1>
|
||||
</div>
|
||||
<p className="text-muted-foreground mb-6 text-xl italic">
|
||||
AI 视频笔记生成工具 让 AI 为你的视频做笔记
|
||||
|
||||
@@ -1,7 +1,6 @@
|
||||
import request from '@/utils/request'
|
||||
import toast from 'react-hot-toast'
|
||||
import { useTaskStore } from '@/store/taskStore'
|
||||
import request from '@/utils/request'
|
||||
|
||||
export const generateNote = async (data: {
|
||||
video_url: string
|
||||
platform: string
|
||||
@@ -14,12 +13,13 @@ export const generateNote = async (data: {
|
||||
extras?: string
|
||||
video_understand?: boolean
|
||||
video_interval?: number
|
||||
grid_size:Array<number>
|
||||
grid_size: Array<number>
|
||||
}) => {
|
||||
try {
|
||||
console.log('generateNote', data)
|
||||
const response = await request.post('/generate_note', data)
|
||||
|
||||
if (response.data.code != 0) {
|
||||
if (!response) {
|
||||
if (response.data.msg) {
|
||||
toast.error(response.data.msg)
|
||||
}
|
||||
@@ -30,12 +30,12 @@ export const generateNote = async (data: {
|
||||
console.log('res', response)
|
||||
// 成功提示
|
||||
|
||||
return response.data
|
||||
return response
|
||||
} catch (e: any) {
|
||||
console.error('❌ 请求出错', e)
|
||||
|
||||
// 错误提示
|
||||
toast.error('笔记生成失败,请稍后重试')
|
||||
// toast.error('笔记生成失败,请稍后重试')
|
||||
|
||||
throw e // 抛出错误以便调用方处理
|
||||
}
|
||||
@@ -65,15 +65,9 @@ export const delete_task = async ({ video_id, platform }) => {
|
||||
|
||||
export const get_task_status = async (task_id: string) => {
|
||||
try {
|
||||
const response = await request.get('/task_status/' + task_id)
|
||||
|
||||
if (response.data.code == 0 && response.data.status == 'SUCCESS') {
|
||||
// toast.success("笔记生成成功")
|
||||
}
|
||||
console.log('res', response)
|
||||
// 成功提示
|
||||
|
||||
return response.data
|
||||
return await request.get('/task_status/' + task_id)
|
||||
} catch (e) {
|
||||
console.error('❌ 请求出错', e)
|
||||
|
||||
|
||||
@@ -1,6 +1,12 @@
|
||||
import { create } from 'zustand'
|
||||
import { devtools } from 'zustand/middleware'
|
||||
import { fetchModels, addModel, fetchEnableModels, fetchEnableModelById, deleteModelById } from '@/services/model.ts'
|
||||
import {
|
||||
fetchModels,
|
||||
addModel,
|
||||
fetchEnableModels,
|
||||
fetchEnableModelById,
|
||||
deleteModelById
|
||||
} from '@/services/model'
|
||||
|
||||
interface IModel {
|
||||
id: string
|
||||
@@ -11,81 +17,93 @@ interface IModel {
|
||||
root: string
|
||||
}
|
||||
|
||||
interface IModelListItem {
|
||||
id: string
|
||||
provider_id: string
|
||||
model_name: string
|
||||
created_at?: string
|
||||
}
|
||||
|
||||
interface ModelStore {
|
||||
models: IModel[]
|
||||
modelList: []
|
||||
modelList: IModelListItem[]
|
||||
loading: boolean
|
||||
selectedModel: string
|
||||
|
||||
loadModels: (providerId: string) => Promise<void>
|
||||
loadModelsById: (providerId: string) => Promise<IModelListItem[]>
|
||||
loadEnabledModels: () => Promise<void>
|
||||
loadModelsById : (providerId: string) => Promise<void>
|
||||
addNewModel: (providerId: string, modelId: string) => Promise<void>
|
||||
setSelectedModel: (modelId: string) => void
|
||||
deleteModel: (modelId: number) => Promise<void>
|
||||
setSelectedModel: (modelId: string) => void
|
||||
clearModels: () => void
|
||||
}
|
||||
|
||||
export const useModelStore = create<ModelStore>()(
|
||||
devtools(set => ({
|
||||
devtools((set) => ({
|
||||
models: [],
|
||||
modelList: [],
|
||||
loading: false,
|
||||
selectedModel: '',
|
||||
modelList: [],
|
||||
|
||||
// 获取所有可用模型 (全局可用模型列表)
|
||||
loadEnabledModels: async () => {
|
||||
try {
|
||||
set({ loading: true })
|
||||
const res = await fetchEnableModels()
|
||||
if (res.data.code === 0 && res.data.data.length > 0) {
|
||||
set({ modelList: res.data.data })
|
||||
} else {
|
||||
set({ modelList: [] })
|
||||
console.error('模型列表加载失败')
|
||||
}
|
||||
const list = await fetchEnableModels()
|
||||
set({ modelList: list })
|
||||
} catch (error) {
|
||||
set({ modelList: [] })
|
||||
console.error('加载模型出错', error)
|
||||
}
|
||||
},
|
||||
|
||||
deleteModel: async (modelId: number) => {
|
||||
await deleteModelById( modelId)
|
||||
},
|
||||
// 加载模型列表
|
||||
loadModels: async (providerId: string) => {
|
||||
try {
|
||||
set({ loading: true })
|
||||
const res = await fetchModels(providerId)
|
||||
if (res.data.code === 0 && res.data.data.models.data.length > 0) {
|
||||
set({ models: res.data.data.models.data })
|
||||
} else if (res.data.code === 0 && res.data.data.models.length > 0) {
|
||||
set({ models: res.data.data.models.data })
|
||||
} else {
|
||||
set({ models: [] })
|
||||
console.error('模型列表加载失败')
|
||||
}
|
||||
} catch (error) {
|
||||
set({ models: [] })
|
||||
console.error('加载模型出错', error)
|
||||
console.error('加载可用模型失败', error)
|
||||
} finally {
|
||||
set({ loading: false })
|
||||
}
|
||||
},
|
||||
loadModelsById: async (providerId: string)=>{
|
||||
const models = await fetchEnableModelById(providerId)
|
||||
if (models.data.code === 0) {
|
||||
console.log('模型列表加载成功:', models.data)
|
||||
return models.data.data
|
||||
|
||||
// 通过 provider 获取该供应商的模型列表
|
||||
loadModels: async (providerId: string) => {
|
||||
try {
|
||||
set({ loading: true })
|
||||
const res = await fetchModels(providerId)
|
||||
|
||||
let models: IModel[] = []
|
||||
|
||||
// 兼容 SyncPage 分页对象与普通数组两种格式
|
||||
if (Array.isArray(res.models)) {
|
||||
models = res.models
|
||||
} else if (res.models?.data && Array.isArray(res.models.data)) {
|
||||
models = res.models.data
|
||||
}
|
||||
|
||||
set({ models })
|
||||
} catch (error) {
|
||||
set({ models: [] })
|
||||
console.error('加载模型列表失败', error)
|
||||
} finally {
|
||||
set({ loading: false })
|
||||
}
|
||||
},
|
||||
// 新增模型
|
||||
},
|
||||
|
||||
// 单独获取某个供应商下已启用模型
|
||||
loadModelsById: async (providerId: string) => {
|
||||
try {
|
||||
const models = await fetchEnableModelById(providerId)
|
||||
console.log('获取供应商模型成功:', models)
|
||||
return models
|
||||
} catch (error) {
|
||||
console.error('加载供应商模型失败', error)
|
||||
return []
|
||||
}
|
||||
},
|
||||
|
||||
// 新增模型逻辑
|
||||
addNewModel: async (providerId: string, modelId: string) => {
|
||||
try {
|
||||
const res = await addModel({ provider_id: providerId, model_name: modelId })
|
||||
|
||||
if (res.code === 0) {
|
||||
console.log('新增模型成功:', modelId)
|
||||
// ✅ 新增成功以后,前端直接追加一条到 models 列表
|
||||
set(state => ({
|
||||
set((state) => ({
|
||||
models: [
|
||||
...state.models,
|
||||
{
|
||||
@@ -99,17 +117,30 @@ export const useModelStore = create<ModelStore>()(
|
||||
],
|
||||
}))
|
||||
} else {
|
||||
console.error('新增模型失败')
|
||||
console.error('新增模型失败', res.msg)
|
||||
}
|
||||
} catch (error) {
|
||||
console.error('添加模型出错', error)
|
||||
}
|
||||
},
|
||||
|
||||
// 设置选中的模型
|
||||
setSelectedModel: modelId => set({ selectedModel: modelId }),
|
||||
// 删除模型
|
||||
deleteModel: async (modelId: number) => {
|
||||
try {
|
||||
await deleteModelById(modelId)
|
||||
// 删除后更新本地状态(可选)
|
||||
set((state) => ({
|
||||
models: state.models.filter((model) => model.id !== modelId.toString())
|
||||
}))
|
||||
} catch (error) {
|
||||
console.error('删除模型失败', error)
|
||||
}
|
||||
},
|
||||
|
||||
// 清空
|
||||
clearModels: () => set({ models: [], selectedModel: '' }),
|
||||
// 切换选中模型
|
||||
setSelectedModel: (modelId: string) => set({ selectedModel: modelId }),
|
||||
|
||||
// 清空
|
||||
clearModels: () => set({ models: [], selectedModel: '', modelList: [] }),
|
||||
}))
|
||||
)
|
||||
)
|
||||
@@ -1,5 +1,5 @@
|
||||
import { create } from 'zustand'
|
||||
import { IProvider } from '@/types'
|
||||
import { IProvider, IResponse } from '@/types'
|
||||
import {
|
||||
addProvider,
|
||||
getProviderById,
|
||||
@@ -38,10 +38,9 @@ export const useProviderStore = create<ProviderStore>((set, get) => ({
|
||||
// 设置整个 provider 列表
|
||||
setAllProviders: providers => set({ provider: providers }),
|
||||
loadProviderById: async (id: string) => {
|
||||
const res = await getProviderById(id)
|
||||
if (res.data.code === 0) {
|
||||
const item = res.data.data
|
||||
console.log('Provider ', item)
|
||||
const res:IResponse<IProvider> = await getProviderById(id)
|
||||
|
||||
const item = res
|
||||
return {
|
||||
id: item.id,
|
||||
name: item.name,
|
||||
@@ -51,9 +50,7 @@ export const useProviderStore = create<ProviderStore>((set, get) => ({
|
||||
type: item.type,
|
||||
enabled: item.enabled,
|
||||
}
|
||||
} else {
|
||||
console.log('Provider not found')
|
||||
}
|
||||
|
||||
},
|
||||
addNewProvider: async (provider: IProvider) => {
|
||||
const payload = {
|
||||
@@ -96,16 +93,18 @@ export const useProviderStore = create<ProviderStore>((set, get) => ({
|
||||
getProviderList: () => get().provider,
|
||||
fetchProviderList: async () => {
|
||||
try {
|
||||
const res = await getProviderList()
|
||||
if (res.data.code === 0) {
|
||||
const res = await getProviderList()
|
||||
|
||||
set({
|
||||
provider: res.data.data.map(
|
||||
provider: res.map(
|
||||
(item: {
|
||||
id: string
|
||||
name: string
|
||||
logo: string
|
||||
api_key: string
|
||||
base_url: string
|
||||
type: string
|
||||
enabled: number
|
||||
}) => {
|
||||
return {
|
||||
id: item.id,
|
||||
@@ -119,7 +118,6 @@ export const useProviderStore = create<ProviderStore>((set, get) => ({
|
||||
}
|
||||
),
|
||||
})
|
||||
}
|
||||
} catch (error) {
|
||||
console.error('Error fetching provider list:', error)
|
||||
}
|
||||
|
||||
@@ -2,6 +2,7 @@ import { create } from 'zustand'
|
||||
import { persist } from 'zustand/middleware'
|
||||
import { delete_task, generateNote } from '@/services/note.ts'
|
||||
import { v4 as uuidv4 } from 'uuid'
|
||||
import toast from 'react-hot-toast'
|
||||
|
||||
|
||||
export type TaskStatus = 'PENDING' | 'RUNNING' | 'SUCCESS' | 'FAILD'
|
||||
@@ -157,14 +158,19 @@ export const useTaskStore = create<TaskStore>()(
|
||||
return get().tasks.find(task => task.id === currentTaskId) || null
|
||||
},
|
||||
retryTask: async (id: string, payload?: any) => {
|
||||
|
||||
if (!id){
|
||||
toast.error('任务不存在')
|
||||
return
|
||||
}
|
||||
const task = get().tasks.find(task => task.id === id)
|
||||
console.log('retry',task)
|
||||
if (!task) return
|
||||
|
||||
const newFormData = payload || task.formData
|
||||
|
||||
await generateNote({
|
||||
task_id: id,
|
||||
...newFormData,
|
||||
task_id: id,
|
||||
})
|
||||
|
||||
set(state => ({
|
||||
|
||||
5
BillNote_frontend/src/types/index.d.ts
vendored
5
BillNote_frontend/src/types/index.d.ts
vendored
@@ -7,3 +7,8 @@ export interface IProvider {
|
||||
baseUrl: string
|
||||
enabled: number
|
||||
}
|
||||
export interface IResponse<T> {
|
||||
code: number
|
||||
data:T
|
||||
msg: string
|
||||
}
|
||||
@@ -1,27 +1,58 @@
|
||||
import axios from 'axios'
|
||||
const request = axios.create({
|
||||
baseURL: '/api',
|
||||
timeout: 10000,
|
||||
})
|
||||
function handleErrorResponse(response: any) {
|
||||
if (!response) return '请求失败,请检查网络连接'
|
||||
if (typeof response.code !== 'number') return '系统异常'
|
||||
import axios, { AxiosInstance, AxiosResponse } from 'axios';
|
||||
import toast from 'react-hot-toast'
|
||||
|
||||
// 错误码判断
|
||||
switch (response.code) {
|
||||
case 1001:
|
||||
return response.msg || '下载失败,请检查视频链接'
|
||||
case 1002:
|
||||
return response.msg || '转写失败,请稍后重试'
|
||||
case 1003:
|
||||
return response.msg || '总结失败,可能是模型服务异常'
|
||||
case 2001:
|
||||
case 2002:
|
||||
return Array.isArray(response.data)
|
||||
? response.data.map(e => `${e.field}: ${e.error}`).join('\n')
|
||||
: response.msg || '参数错误'
|
||||
default:
|
||||
return response.msg || '系统异常'
|
||||
}
|
||||
// 统一响应类型
|
||||
export interface IResponse<T = any> {
|
||||
code: number;
|
||||
msg: string;
|
||||
data: T;
|
||||
}
|
||||
|
||||
// 模拟一个消息提示函数 (实际项目中会使用UI库的组件,如 Ant Design 的 message 或 Element UI 的 ElMessage)
|
||||
// This function simulates a message display (in real projects, you'd use a UI library's component)
|
||||
|
||||
|
||||
// 创建实例
|
||||
const request: AxiosInstance = axios.create({
|
||||
baseURL: '/api', // 请确保你的开发服务器代理设置正确
|
||||
timeout: 10000,
|
||||
});
|
||||
|
||||
// 响应拦截器
|
||||
request.interceptors.response.use(
|
||||
(response: AxiosResponse<IResponse>) => {
|
||||
const res = response.data;
|
||||
if (res.code === 0) {
|
||||
// 业务成功,可以根据需要显示成功消息,或者不显示(如果操作本身就是可见的)
|
||||
// showMessage('success', res.msg || '操作成功'); // 如果需要显示成功消息
|
||||
return res.data; // 返回data部分,简化后续业务代码
|
||||
} else {
|
||||
// 业务错误,统一显示后端返回的错误消息
|
||||
// Business error, uniformly display the error message returned from the backend
|
||||
toast.error(res.msg || '操作失败,请稍后再试');
|
||||
return Promise.reject(res); // 拒绝Promise,让业务代码可以捕获并处理
|
||||
}
|
||||
},
|
||||
(error) => {
|
||||
// 网络/服务器错误
|
||||
const res = error?.response?.data as IResponse | undefined;
|
||||
if (res) {
|
||||
// 如果后端有返回错误信息,则显示后端信息
|
||||
// If the backend returns an error message, display it
|
||||
|
||||
toast.error(res.msg || '服务器错误,请稍后再试');
|
||||
return Promise.reject(res);
|
||||
} else {
|
||||
// 没有响应数据(如网络中断),显示通用网络错误
|
||||
// No response data (e.g., network disconnected), display generic network error
|
||||
toast.error( '请求失败,请检查网络连接或稍后再试')
|
||||
return Promise.reject({
|
||||
code: -1,
|
||||
msg: '请求失败,请检查网络连接',
|
||||
data: null
|
||||
} as IResponse);
|
||||
}
|
||||
}
|
||||
);
|
||||
|
||||
export default request
|
||||
|
||||
@@ -3,7 +3,7 @@
|
||||
<p align="center">
|
||||
<img src="./doc/icon.svg" alt="BiliNote Banner" width="50" height="50" />
|
||||
</p>
|
||||
<h1 align="center" > BiliNote v1.7.3</h1>
|
||||
<h1 align="center" > BiliNote v1.7.4</h1>
|
||||
</div>
|
||||
|
||||
<p align="center"><i>AI 视频笔记生成工具 让 AI 为你的视频做笔记</i></p>
|
||||
|
||||
@@ -1,11 +1,14 @@
|
||||
from fastapi import FastAPI
|
||||
|
||||
from .routers import note, provider, model, config
|
||||
|
||||
|
||||
|
||||
def create_app() -> FastAPI:
|
||||
app = FastAPI(title="BiliNote")
|
||||
app.include_router(note.router, prefix="/api")
|
||||
app.include_router(provider.router, prefix="/api")
|
||||
app.include_router(model.router,prefix="/api")
|
||||
app.include_router(config.router, prefix="/api")
|
||||
|
||||
return app
|
||||
|
||||
@@ -1,38 +0,0 @@
|
||||
# app/core/exception_handlers.py
|
||||
from fastapi import Request, HTTPException
|
||||
from fastapi.exceptions import RequestValidationError
|
||||
from fastapi.responses import JSONResponse
|
||||
|
||||
from app.utils.logger import get_logger
|
||||
from app.utils.response import ResponseWrapper
|
||||
from app.utils.status_code import StatusCode
|
||||
logger = get_logger(__name__)
|
||||
|
||||
def register_exception_handlers(app):
|
||||
@app.exception_handler(RequestValidationError)
|
||||
async def validation_exception_handler(request: Request, exc: RequestValidationError):
|
||||
errors = []
|
||||
for err in exc.errors():
|
||||
loc = err.get("loc", [])
|
||||
field = loc[-1] if loc else "body"
|
||||
msg = err.get("msg", "参数不合法")
|
||||
errors.append({"field": field, "error": msg})
|
||||
return JSONResponse(
|
||||
status_code=400,
|
||||
content=ResponseWrapper.error(msg="参数验证失败", code=StatusCode.PARAM_ERROR, data=errors)
|
||||
)
|
||||
|
||||
@app.exception_handler(HTTPException)
|
||||
async def http_exception_handler(request: Request, exc: HTTPException):
|
||||
return JSONResponse(
|
||||
status_code=exc.status_code,
|
||||
content=ResponseWrapper.error(msg=str(exc.detail), code=StatusCode.FAIL)
|
||||
)
|
||||
|
||||
@app.exception_handler(Exception)
|
||||
async def global_exception_handler(request: Request, exc: Exception):
|
||||
logger.exception(f"服务器内部错误: {exc}")
|
||||
return JSONResponse(
|
||||
status_code=500,
|
||||
content=ResponseWrapper.error(msg="服务器内部错误", code=StatusCode.FAIL, data=str(exc))
|
||||
)
|
||||
@@ -136,7 +136,8 @@ def get_provider_by_name(name: str):
|
||||
if row is None:
|
||||
logger.info(f"Provider not found: {name}")
|
||||
return None
|
||||
logger.info(f"Provider found: {row}")
|
||||
logger.info(f"Provider found: {row[0]}")
|
||||
|
||||
return row
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to get provider by name: {e}")
|
||||
@@ -155,7 +156,7 @@ def get_provider_by_id(id: int):
|
||||
if row is None:
|
||||
logger.info(f"Provider not found: {id}")
|
||||
return None
|
||||
logger.info(f"Provider found: {row}")
|
||||
logger.info(f"Provider found: {row[0]}")
|
||||
return row
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to get provider by id: {e}")
|
||||
@@ -173,7 +174,7 @@ def get_all_providers():
|
||||
if rows is None:
|
||||
logger.info("No providers found")
|
||||
return None
|
||||
logger.info(f"Providers found: {rows}")
|
||||
logger.info(f"Providers found total {len(rows) }")
|
||||
return rows
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to get all providers: {e}")
|
||||
|
||||
@@ -145,53 +145,59 @@ class DouyinDownloader(Downloader):
|
||||
return ""
|
||||
|
||||
def gen_real_msToken(self) -> str:
|
||||
payload = json.dumps(
|
||||
{
|
||||
"magic": self.ms_token_config["magic"],
|
||||
"version": self.ms_token_config["version"],
|
||||
"dataType": self.ms_token_config["dataType"],
|
||||
"strData": self.ms_token_config["strData"],
|
||||
"tspFromClient": get_timestamp(),
|
||||
try:
|
||||
payload = json.dumps(
|
||||
{
|
||||
"magic": self.ms_token_config["magic"],
|
||||
"version": self.ms_token_config["version"],
|
||||
"dataType": self.ms_token_config["dataType"],
|
||||
"strData": self.ms_token_config["strData"],
|
||||
"tspFromClient": get_timestamp(),
|
||||
}
|
||||
)
|
||||
headers = {
|
||||
"User-Agent": self.headers_config["User-Agent"],
|
||||
"Content-Type": "application/json",
|
||||
}
|
||||
)
|
||||
headers = {
|
||||
"User-Agent": self.headers_config["User-Agent"],
|
||||
"Content-Type": "application/json",
|
||||
}
|
||||
transport = httpx.HTTPTransport(retries=5)
|
||||
with httpx.Client(transport=transport) as client:
|
||||
try:
|
||||
response = client.post(
|
||||
self.ms_token_config["url"], content=payload, headers=headers
|
||||
)
|
||||
response.raise_for_status()
|
||||
transport = httpx.HTTPTransport(retries=5)
|
||||
with httpx.Client(transport=transport) as client:
|
||||
try:
|
||||
response = client.post(
|
||||
self.ms_token_config["url"], content=payload, headers=headers
|
||||
)
|
||||
response.raise_for_status()
|
||||
|
||||
msToken = str(httpx.Cookies(response.cookies).get("msToken"))
|
||||
if len(msToken) not in [120, 128]:
|
||||
raise ValueError("响应内容:{0}, Douyin msToken API 的响应内容不符合要求。".format(msToken))
|
||||
msToken = str(httpx.Cookies(response.cookies).get("msToken"))
|
||||
if len(msToken) not in [120, 128]:
|
||||
raise ValueError("响应内容:{0}, Douyin msToken API 的响应内容不符合要求。".format(msToken))
|
||||
|
||||
return msToken
|
||||
except Exception as e:
|
||||
raise ValueError("Douyin msToken API 请求失败:{0}".format(e))
|
||||
return msToken
|
||||
except Exception as e:
|
||||
raise ValueError("Douyin msToken API 请求失败:{0}".format(e))
|
||||
except Exception as e:
|
||||
raise ValueError("Douyin msToken API{0}".format(e))
|
||||
|
||||
def fetch_video_info(self, video_url: str) -> json:
|
||||
aweme_id = self.extract_video_id(video_url)
|
||||
kwargs = self.headers_config
|
||||
print("kwargs:", kwargs)
|
||||
base_params = BaseRequestModel().model_dump()
|
||||
base_params["msToken"] = self.gen_real_msToken()
|
||||
base_params["aweme_id"] = aweme_id
|
||||
bogus = ABogus()
|
||||
ab_value = bogus.get_value(base_params)
|
||||
a_bogus = quote(ab_value, safe='')
|
||||
print(base_params)
|
||||
query_str = urlencode(base_params)
|
||||
full_url = f"{DOUYIN_DOMAIN}/aweme/v1/web/aweme/detail/?{query_str}&a_bogus={a_bogus}"
|
||||
|
||||
print("Request URL:", full_url)
|
||||
|
||||
try:
|
||||
|
||||
aweme_id = self.extract_video_id(video_url)
|
||||
kwargs = self.headers_config
|
||||
print("@kwargs:", kwargs)
|
||||
base_params = BaseRequestModel().model_dump()
|
||||
base_params["msToken"] = self.gen_real_msToken()
|
||||
|
||||
base_params["aweme_id"] = aweme_id
|
||||
bogus = ABogus()
|
||||
ab_value = bogus.get_value(base_params)
|
||||
a_bogus = quote(ab_value, safe='')
|
||||
print("@a_bogus:", a_bogus)
|
||||
print(base_params)
|
||||
query_str = urlencode(base_params)
|
||||
full_url = f"{DOUYIN_DOMAIN}/aweme/v1/web/aweme/detail/?{query_str}&a_bogus={a_bogus}"
|
||||
|
||||
print("Request URL:", full_url)
|
||||
|
||||
|
||||
response = requests.get(full_url, headers=kwargs)
|
||||
|
||||
print("Response JSON:", response.content)
|
||||
@@ -208,46 +214,49 @@ class DouyinDownloader(Downloader):
|
||||
quality: DownloadQuality = "fast",
|
||||
need_video: Optional[bool] = False
|
||||
) -> AudioDownloadResult:
|
||||
print(
|
||||
f"正在下载视频: {video_url},保存路径: {output_dir},质量: {quality}"
|
||||
)
|
||||
if output_dir is None:
|
||||
output_dir = get_data_dir()
|
||||
if not output_dir:
|
||||
output_dir = self.cache_data
|
||||
os.makedirs(output_dir, exist_ok=True)
|
||||
try:
|
||||
print(
|
||||
f"正在下载视频: {video_url},保存路径: {output_dir},质量: {quality}"
|
||||
)
|
||||
if output_dir is None:
|
||||
output_dir = get_data_dir()
|
||||
if not output_dir:
|
||||
output_dir = self.cache_data
|
||||
os.makedirs(output_dir, exist_ok=True)
|
||||
|
||||
output_path = os.path.join(output_dir, "%(id)s.%(ext)s")
|
||||
output_path = os.path.join(output_dir, "%(id)s.%(ext)s")
|
||||
|
||||
video_data = self.fetch_video_info(video_url)
|
||||
output_path = output_path % {
|
||||
"id": video_data['aweme_detail']['aweme_id'],
|
||||
"ext": "mp3",
|
||||
}
|
||||
url = video_data['aweme_detail']['music']['play_url']['uri']
|
||||
# 下载音频
|
||||
audio_data = requests.get(url)
|
||||
with open(output_path, 'wb') as f:
|
||||
f.write(audio_data.content)
|
||||
print(url)
|
||||
tags = []
|
||||
for tag in video_data['aweme_detail']['video_tag']:
|
||||
if tag['tag_name']:
|
||||
tags.append(tag['tag_name'])
|
||||
video_data = self.fetch_video_info(video_url)
|
||||
output_path = output_path % {
|
||||
"id": video_data['aweme_detail']['aweme_id'],
|
||||
"ext": "mp3",
|
||||
}
|
||||
url = video_data['aweme_detail']['music']['play_url']['uri']
|
||||
# 下载音频
|
||||
audio_data = requests.get(url)
|
||||
with open(output_path, 'wb') as f:
|
||||
f.write(audio_data.content)
|
||||
print(url)
|
||||
tags = []
|
||||
for tag in video_data['aweme_detail']['video_tag']:
|
||||
if tag['tag_name']:
|
||||
tags.append(tag['tag_name'])
|
||||
|
||||
return AudioDownloadResult(
|
||||
file_path=output_path,
|
||||
title=video_data['aweme_detail']['item_title'],
|
||||
duration=video_data['aweme_detail']['video']['duration'],
|
||||
cover_url=video_data['aweme_detail']['video']['cover_original_scale']['url_list'][0] if
|
||||
video_data['aweme_detail']['video']['cover'] else video_data['video']['big_thumbs']['img_url'],
|
||||
platform="douyin",
|
||||
video_id=video_data['aweme_detail']['aweme_id'],
|
||||
raw_info={
|
||||
'tags': video_data['aweme_detail']['caption'] + ''.join(tags),
|
||||
},
|
||||
video_path=None # ❗音频下载不包含视频路径
|
||||
)
|
||||
return AudioDownloadResult(
|
||||
file_path=output_path,
|
||||
title=video_data['aweme_detail']['item_title'],
|
||||
duration=video_data['aweme_detail']['video']['duration'],
|
||||
cover_url=video_data['aweme_detail']['video']['cover_original_scale']['url_list'][0] if
|
||||
video_data['aweme_detail']['video']['cover'] else video_data['video']['big_thumbs']['img_url'],
|
||||
platform="douyin",
|
||||
video_id=video_data['aweme_detail']['aweme_id'],
|
||||
raw_info={
|
||||
'tags': video_data['aweme_detail']['caption'] + ''.join(tags),
|
||||
},
|
||||
video_path=None # ❗音频下载不包含视频路径
|
||||
)
|
||||
except Exception as e:
|
||||
raise e
|
||||
|
||||
def download_video(self, video_url: str, output_dir: Union[str, None] = None) -> str:
|
||||
|
||||
|
||||
21
backend/app/enmus/exception.py
Normal file
21
backend/app/enmus/exception.py
Normal file
@@ -0,0 +1,21 @@
|
||||
import enum
|
||||
|
||||
|
||||
class ProviderErrorEnum(enum.Enum):
|
||||
CONNECTION_TEST_FAILED = (200101, "供应商连接测试失败")
|
||||
SAVE_FAILED = (200102, "供应商保存失败")
|
||||
CREATE_FAILED = (200103, "供应商创建失败")
|
||||
NOT_FOUND = (200104, "供应商不存在/未保存")
|
||||
WRONG_PARAMETER = (200105, "API / API 地址不正确")
|
||||
UNKNOW_ERROR = (200106, "未知错误")
|
||||
|
||||
def __init__(self, code, message):
|
||||
self.code = code
|
||||
self.message = message
|
||||
|
||||
class NoteErrorEnum(enum.Enum):
|
||||
PLATFORM_NOT_SUPPORTED = (300101 ,"选择的平台不受支持")
|
||||
|
||||
def __init__(self, code, message):
|
||||
self.code = code
|
||||
self.message = message
|
||||
0
backend/app/exceptions/__init__.py
Normal file
0
backend/app/exceptions/__init__.py
Normal file
6
backend/app/exceptions/biz_exception.py
Normal file
6
backend/app/exceptions/biz_exception.py
Normal file
@@ -0,0 +1,6 @@
|
||||
# exceptions/biz_exception.py
|
||||
|
||||
class BizException(Exception):
|
||||
def __init__(self, code: int, message: str = "业务异常"):
|
||||
self.code = code
|
||||
self.message = message
|
||||
33
backend/app/exceptions/exception_handlers.py
Normal file
33
backend/app/exceptions/exception_handlers.py
Normal file
@@ -0,0 +1,33 @@
|
||||
# middlewares/exception_handler.py
|
||||
|
||||
from fastapi import Request
|
||||
from fastapi import FastAPI
|
||||
|
||||
from app.enmus.exception import NoteErrorEnum
|
||||
from app.exceptions.biz_exception import BizException
|
||||
from app.exceptions.note import NoteError
|
||||
from app.exceptions.provider import ProviderError
|
||||
from app.utils.logger import get_logger
|
||||
from app.utils.response import ResponseWrapper as R
|
||||
import traceback
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
def register_exception_handlers(app: FastAPI):
|
||||
@app.exception_handler(BizException)
|
||||
async def biz_exception_handler(request: Request, exc: BizException):
|
||||
logger.error(f"BizException: {exc.code} - {exc.message}")
|
||||
return R.error(code=exc.code, msg=str(exc.message))
|
||||
@app.exception_handler(NoteError)
|
||||
async def note_exception_handler(request: Request, exc: NoteError):
|
||||
logger.error(f"NoteError: {exc.code} - {exc.message}")
|
||||
return R.error(code=exc.code, msg=str(exc.message))
|
||||
@app.exception_handler(ProviderError)
|
||||
async def provider_exception_handler(request: Request, exc: ProviderError):
|
||||
logger.error(f"供应商模块错误: {exc.code} - {exc.message}")
|
||||
return R.error(code=exc.code, msg=str(exc.message))
|
||||
|
||||
@app.exception_handler(Exception)
|
||||
async def general_exception_handler(request: Request, exc: Exception):
|
||||
logger.error(f"系统异常: {str(exc)}\n{traceback.format_exc()}")
|
||||
return R.error(code=500000, msg="系统异常")
|
||||
9
backend/app/exceptions/note.py
Normal file
9
backend/app/exceptions/note.py
Normal file
@@ -0,0 +1,9 @@
|
||||
# exceptions.py
|
||||
from app.enmus.exception import ProviderErrorEnum
|
||||
|
||||
|
||||
class NoteError(Exception):
|
||||
def __init__(self, message: str,code: ProviderErrorEnum) -> None:
|
||||
super().__init__(message)
|
||||
self.code=code
|
||||
self.message = message
|
||||
@@ -1,5 +1,12 @@
|
||||
# exceptions.py
|
||||
class ConnectionTestError(Exception):
|
||||
def __init__(self, message: str):
|
||||
from app.enmus.exception import ProviderErrorEnum
|
||||
|
||||
|
||||
class ProviderError(Exception):
|
||||
def __init__(self, message: str,code: ProviderErrorEnum) -> None:
|
||||
super().__init__(message)
|
||||
self.message = message
|
||||
self.code=code
|
||||
self.message = message
|
||||
|
||||
|
||||
|
||||
|
||||
@@ -2,6 +2,9 @@ from typing import Optional, Union
|
||||
|
||||
from openai import OpenAI
|
||||
|
||||
from app.utils.logger import get_logger
|
||||
|
||||
logging= get_logger(__name__)
|
||||
class OpenAICompatibleProvider:
|
||||
def __init__(self, api_key: str, base_url: str, model: Union[str, None]=None):
|
||||
self.client = OpenAI(api_key=api_key, base_url=base_url)
|
||||
@@ -13,11 +16,16 @@ class OpenAICompatibleProvider:
|
||||
|
||||
@staticmethod
|
||||
def test_connection(api_key: str, base_url: str) -> bool:
|
||||
print(api_key)
|
||||
try:
|
||||
client = OpenAI(api_key=api_key, base_url=base_url)
|
||||
client.models.list()
|
||||
model = client.models.list()
|
||||
# for segment in model:
|
||||
# print(segment)
|
||||
# print(model)
|
||||
logging.info("连通性测试成功")
|
||||
return True
|
||||
except Exception as e:
|
||||
print(f"Error connecting to OpenAI API: {e}")
|
||||
logging.info(f"连通性测试失败:{e}")
|
||||
|
||||
# print(f"Error connecting to OpenAI API: {e}")
|
||||
return False
|
||||
@@ -27,4 +27,6 @@ def get_cookie(platform: str):
|
||||
@router.post("/update_downloader_cookie")
|
||||
def update_cookie(data: CookieUpdateRequest):
|
||||
cookie_manager.set(data.platform, data.cookie)
|
||||
return {"message": "Cookie updated successfully"}
|
||||
return R.success(
|
||||
|
||||
)
|
||||
|
||||
@@ -31,10 +31,9 @@ def delete_model(model_id: int):
|
||||
return R.error(f"删除模型失败: {e}")
|
||||
@router.get("/model_list/{provider_id}")
|
||||
def model_list(provider_id):
|
||||
try:
|
||||
return R.success(modelService.get_all_models_by_id(provider_id))
|
||||
except Exception as e:
|
||||
return R.error(e)
|
||||
|
||||
return R.success(modelService.get_all_models_by_id(provider_id))
|
||||
|
||||
|
||||
@router.post("/models")
|
||||
def create_model(data: CreateModelRequest):
|
||||
|
||||
@@ -2,6 +2,7 @@
|
||||
import json
|
||||
import os
|
||||
import uuid
|
||||
from pathlib import Path
|
||||
from typing import Optional
|
||||
from urllib.parse import urlparse
|
||||
|
||||
@@ -10,7 +11,9 @@ from pydantic import BaseModel, validator, field_validator
|
||||
from dataclasses import asdict
|
||||
|
||||
from app.db.video_task_dao import get_task_by_video
|
||||
from app.enmus.exception import NoteErrorEnum
|
||||
from app.enmus.note_enums import DownloadQuality
|
||||
from app.exceptions.note import NoteError
|
||||
from app.services.note import NoteGenerator, logger
|
||||
from app.utils.response import ResponseWrapper as R
|
||||
from app.utils.url_parser import extract_video_id
|
||||
@@ -54,12 +57,13 @@ class VideoRequest(BaseModel):
|
||||
if parsed.scheme in ("http", "https"):
|
||||
# 是网络链接,继续用原有平台校验
|
||||
if not is_supported_video_url(url):
|
||||
raise ValueError("暂不支持该视频平台或链接格式无效")
|
||||
raise NoteError(code=NoteErrorEnum.PLATFORM_NOT_SUPPORTED.code,
|
||||
message=NoteErrorEnum.PLATFORM_NOT_SUPPORTED.message)
|
||||
|
||||
return v
|
||||
|
||||
|
||||
NOTE_OUTPUT_DIR = "note_results"
|
||||
NOTE_OUTPUT_DIR = os.getenv("NOTE_OUTPUT_DIR", "note_results")
|
||||
UPLOAD_DIR = "uploads"
|
||||
|
||||
|
||||
@@ -74,30 +78,32 @@ def run_note_task(task_id: str, video_url: str, platform: str, quality: Download
|
||||
_format: list = None, style: str = None, extras: str = None, video_understanding: bool = False,
|
||||
video_interval=0, grid_size=[]
|
||||
):
|
||||
try:
|
||||
if not model_name or not provider_id:
|
||||
raise HTTPException(status_code=400, detail="请选择模型和提供者")
|
||||
|
||||
note = NoteGenerator().generate(
|
||||
video_url=video_url,
|
||||
platform=platform,
|
||||
quality=quality,
|
||||
task_id=task_id,
|
||||
model_name=model_name,
|
||||
provider_id=provider_id,
|
||||
link=link,
|
||||
_format=_format,
|
||||
style=style,
|
||||
extras=extras,
|
||||
screenshot=screenshot
|
||||
, video_understanding=video_understanding,
|
||||
video_interval=video_interval,
|
||||
grid_size=grid_size
|
||||
)
|
||||
logger.info(f"Note generated: {task_id}")
|
||||
save_note_to_file(task_id, note)
|
||||
except Exception as e:
|
||||
save_note_to_file(task_id, {"error": str(e)})
|
||||
if not model_name or not provider_id:
|
||||
raise HTTPException(status_code=400, detail="请选择模型和提供者")
|
||||
|
||||
note = NoteGenerator().generate(
|
||||
video_url=video_url,
|
||||
platform=platform,
|
||||
quality=quality,
|
||||
task_id=task_id,
|
||||
model_name=model_name,
|
||||
provider_id=provider_id,
|
||||
link=link,
|
||||
_format=_format,
|
||||
style=style,
|
||||
extras=extras,
|
||||
screenshot=screenshot
|
||||
, video_understanding=video_understanding,
|
||||
video_interval=video_interval,
|
||||
grid_size=grid_size
|
||||
)
|
||||
logger.info(f"Note generated: {task_id}")
|
||||
if not note or not note.markdown:
|
||||
logger.warning(f"任务 {task_id} 执行失败,跳过保存")
|
||||
return
|
||||
save_note_to_file(task_id, note)
|
||||
|
||||
|
||||
|
||||
@router.post('/delete_task')
|
||||
@@ -135,7 +141,6 @@ def generate_note(data: VideoRequest, background_tasks: BackgroundTasks):
|
||||
# msg='笔记已生成,请勿重复发起',
|
||||
#
|
||||
# )
|
||||
|
||||
if data.task_id:
|
||||
# 如果传了task_id,说明是重试!
|
||||
task_id = data.task_id
|
||||
|
||||
@@ -2,7 +2,7 @@ from typing import Optional
|
||||
from fastapi import APIRouter
|
||||
from pydantic import BaseModel
|
||||
|
||||
from app.exceptions.provider import ConnectionTestError
|
||||
from app.exceptions.provider import ProviderError
|
||||
from app.models.model_config import ModelConfig
|
||||
from app.services.model import ModelService
|
||||
from app.utils.response import ResponseWrapper as R
|
||||
@@ -88,9 +88,5 @@ def update_provider(data: ProviderUpdateRequest):
|
||||
|
||||
@router.post('/connect_test')
|
||||
def gpt_connect_test(data: TestRequest):
|
||||
try:
|
||||
ModelService().connect_test(data.id)
|
||||
return R.success(msg='连接成功')
|
||||
except Exception as e:
|
||||
print("捕获到异常类型:", type(e))
|
||||
return R.error(msg=str(e))
|
||||
ModelService().connect_test(data.id)
|
||||
return R.success(msg='连接成功')
|
||||
|
||||
@@ -1,12 +1,16 @@
|
||||
|
||||
|
||||
from app.db.model_dao import insert_model, get_all_models, get_model_by_provider_and_name, delete_model
|
||||
from app.db.provider_dao import get_enabled_providers
|
||||
from app.exceptions.provider import ConnectionTestError
|
||||
from app.enmus.exception import ProviderErrorEnum
|
||||
from app.exceptions.provider import ProviderError
|
||||
from app.gpt.gpt_factory import GPTFactory
|
||||
from app.gpt.provider.OpenAI_compatible_provider import OpenAICompatibleProvider
|
||||
from app.models.model_config import ModelConfig
|
||||
from app.services.provider import ProviderService
|
||||
from app.utils.logger import get_logger
|
||||
|
||||
|
||||
logger=get_logger(__name__)
|
||||
class ModelService:
|
||||
|
||||
@staticmethod
|
||||
@@ -83,37 +87,38 @@ class ModelService:
|
||||
provider = ProviderService.get_provider_by_id(provider_id)
|
||||
|
||||
models = ModelService.get_model_list(provider["id"], verbose=verbose)
|
||||
|
||||
model_list={
|
||||
|
||||
"models": models
|
||||
print(type(models))
|
||||
serializable_models = [m.dict() for m in models.data]
|
||||
model_list = {
|
||||
"models": serializable_models
|
||||
}
|
||||
|
||||
logger.info(f"[{provider['name']}] 获取模型成功")
|
||||
return model_list
|
||||
except Exception as e:
|
||||
print(f"[{provider_id}] 获取模型失败: {e}")
|
||||
# print(f"[{provider_id}] 获取模型失败: {e}")
|
||||
logger.error(f"[{provider_id}] 获取模型失败: {e}")
|
||||
return []
|
||||
@staticmethod
|
||||
def connect_test(id: str) -> bool:
|
||||
try:
|
||||
provider = ProviderService.get_provider_by_id(id)
|
||||
|
||||
if provider:
|
||||
if not provider.get('api_key'):
|
||||
raise ConnectionTestError(f"供应商信息未找到,请先保存重试")
|
||||
result = OpenAICompatibleProvider.test_connection(
|
||||
api_key=provider.get('api_key'),
|
||||
base_url=provider.get('base_url')
|
||||
)
|
||||
if result:
|
||||
return True
|
||||
else:
|
||||
raise ConnectionTestError("请检查API Key 和 API 地址是否正确")
|
||||
provider = ProviderService.get_provider_by_id(id)
|
||||
|
||||
if provider:
|
||||
if not provider.get('api_key'):
|
||||
raise ProviderError(code=ProviderErrorEnum.NOT_FOUND.code, message=ProviderErrorEnum.NOT_FOUND.message)
|
||||
result = OpenAICompatibleProvider.test_connection(
|
||||
api_key=provider.get('api_key'),
|
||||
base_url=provider.get('base_url')
|
||||
)
|
||||
if result:
|
||||
return True
|
||||
else:
|
||||
raise ProviderError(code=ProviderErrorEnum.WRONG_PARAMETER.code,message=ProviderErrorEnum.WRONG_PARAMETER.message)
|
||||
|
||||
raise ProviderError(code=ProviderErrorEnum.NOT_FOUND.code, message=ProviderErrorEnum.NOT_FOUND.message)
|
||||
|
||||
|
||||
raise ConnectionTestError("供应商信息未找到,请先保存重试")
|
||||
except Exception as e:
|
||||
# 抛出业务异常,交由 Controller 处理
|
||||
raise ConnectionTestError(f"{str(e)}") from e
|
||||
|
||||
@staticmethod
|
||||
def delete_model_by_id( model_id: int) -> bool:
|
||||
|
||||
@@ -1,75 +1,63 @@
|
||||
import json
|
||||
from dataclasses import asdict
|
||||
|
||||
from fastapi import HTTPException
|
||||
|
||||
from app.downloaders.local_downloader import LocalDownloader
|
||||
from app.enmus.task_status_enums import TaskStatus
|
||||
import logging
|
||||
import os
|
||||
from typing import Union, Optional
|
||||
import re
|
||||
from dataclasses import asdict
|
||||
from pathlib import Path
|
||||
from typing import List, Optional, Tuple, Union, Any
|
||||
|
||||
from pydantic import HttpUrl
|
||||
from dotenv import load_dotenv
|
||||
|
||||
from app.db.video_task_dao import insert_video_task, delete_task_by_video
|
||||
from app.downloaders.base import Downloader
|
||||
from app.downloaders.bilibili_downloader import BilibiliDownloader
|
||||
from app.downloaders.douyin_downloader import DouyinDownloader
|
||||
from app.downloaders.youtube_downloader import YoutubeDownloader
|
||||
from app.services.constant import SUPPORT_PLATFORM_MAP
|
||||
from app.enmus.task_status_enums import TaskStatus
|
||||
from app.enmus.exception import NoteErrorEnum, ProviderErrorEnum
|
||||
from app.exceptions.note import NoteError
|
||||
from app.exceptions.provider import ProviderError
|
||||
from app.db.video_task_dao import delete_task_by_video, insert_video_task
|
||||
from app.gpt.base import GPT
|
||||
from app.gpt.deepseek_gpt import DeepSeekGPT
|
||||
from app.gpt.gpt_factory import GPTFactory
|
||||
from app.gpt.openai_gpt import OpenaiGPT
|
||||
from app.gpt.qwen_gpt import QwenGPT
|
||||
from app.models.audio_model import AudioDownloadResult
|
||||
from app.models.gpt_model import GPTSource
|
||||
from app.models.model_config import ModelConfig
|
||||
from app.models.notes_model import NoteResult
|
||||
from app.models.notes_model import AudioDownloadResult
|
||||
from app.enmus.note_enums import DownloadQuality
|
||||
from app.models.transcriber_model import TranscriptResult, TranscriptSegment
|
||||
from app.services.constant import SUPPORT_PLATFORM_MAP
|
||||
|
||||
from app.services.provider import ProviderService
|
||||
from app.transcriber.base import Transcriber
|
||||
from app.transcriber.transcriber_provider import get_transcriber, _transcribers
|
||||
from app.transcriber.whisper import WhisperTranscriber
|
||||
import re
|
||||
|
||||
from app.utils.note_helper import replace_content_markers
|
||||
from app.utils.status_code import StatusCode
|
||||
from app.utils.video_helper import generate_screenshot
|
||||
|
||||
# from app.services.whisperer import transcribe_audio
|
||||
# from app.services.gpt import summarize_text
|
||||
from dotenv import load_dotenv
|
||||
from app.utils.logger import get_logger
|
||||
from app.utils.video_reader import VideoReader
|
||||
from events import transcription_finished
|
||||
from app.utils.video_helper import generate_screenshot
|
||||
from app.utils.note_helper import replace_content_markers
|
||||
from app.enmus.note_enums import DownloadQuality
|
||||
|
||||
logger = get_logger(__name__)
|
||||
# 环境变量
|
||||
load_dotenv()
|
||||
api_path = os.getenv("API_BASE_URL", "http://localhost")
|
||||
BACKEND_PORT = os.getenv("BACKEND_PORT", 8000)
|
||||
NOTE_OUTPUT_DIR = Path(os.getenv("NOTE_OUTPUT_DIR", "note_results"))
|
||||
NOTE_OUTPUT_DIR.mkdir(parents=True, exist_ok=True)
|
||||
IMAGE_BASE_URL = os.getenv("IMAGE_BASE_URL", "/static/screenshots")
|
||||
IMAGE_OUTPUT_DIR = os.getenv("OUT_DIR", "images")
|
||||
|
||||
BACKEND_BASE_URL = f"{api_path}:{BACKEND_PORT}"
|
||||
output_dir = os.getenv('OUT_DIR')
|
||||
image_base_url = os.getenv('IMAGE_BASE_URL')
|
||||
logger.info("starting up")
|
||||
|
||||
NOTE_OUTPUT_DIR = "note_results"
|
||||
logger = logging.getLogger(__name__)
|
||||
logger.setLevel(logging.INFO)
|
||||
|
||||
|
||||
class NoteGenerator:
|
||||
|
||||
class States:
|
||||
INIT = 'INIT'
|
||||
PARSING = 'PARSING'
|
||||
DOWNLOADING = 'DOWNLOADING'
|
||||
TRANSCRIBING = 'TRANSCRIBING'
|
||||
SUMMARIZING = 'SUMMARIZING'
|
||||
SAVING = 'SAVING'
|
||||
SUCCESS = 'SUCCESS'
|
||||
FAILED = 'FAILED'
|
||||
|
||||
def __init__(self):
|
||||
self.model_size: str = 'base'
|
||||
self.device: Union[str, None] = None
|
||||
self.transcriber_type = os.getenv('TRANSCRIBER_TYPE', 'fast-whisper')
|
||||
self.transcriber = self.get_transcriber()
|
||||
self.video_path = None
|
||||
logger.info("初始化NoteGenerator")
|
||||
|
||||
import logging
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
self.transcriber_type = os.getenv("TRANSCRIBER_TYPE", "fast-whisper")
|
||||
self.transcriber: Transcriber = self._init_transcriber()
|
||||
self.video_img_urls = []
|
||||
|
||||
@staticmethod
|
||||
def update_task_status(task_id: str, status: Union[str, TaskStatus], message: Optional[str] = None):
|
||||
@@ -81,310 +69,179 @@ class NoteGenerator:
|
||||
with open(path, "w", encoding="utf-8") as f:
|
||||
json.dump(content, f, ensure_ascii=False, indent=2)
|
||||
|
||||
def get_gpt(self, model_name: str = None, provider_id: str = None) -> GPT:
|
||||
def generate(
|
||||
self,
|
||||
video_url: Union[str, HttpUrl],
|
||||
platform: str,
|
||||
quality: DownloadQuality = DownloadQuality.medium,
|
||||
task_id: Optional[str] = None,
|
||||
model_name: Optional[str] = None,
|
||||
provider_id: Optional[str] = None,
|
||||
link: bool = False,
|
||||
screenshot: bool = False,
|
||||
_format: Optional[List[str]] = None,
|
||||
style: Optional[str] = None,
|
||||
extras: Optional[str] = None,
|
||||
output_path: Optional[str] = None,
|
||||
video_understanding: bool = False,
|
||||
video_interval: int = 0,
|
||||
grid_size: Optional[List[int]] = None,
|
||||
) -> NoteResult | None:
|
||||
|
||||
self.task_id = task_id
|
||||
self._change_state(self.States.INIT)
|
||||
|
||||
try:
|
||||
self._change_state(self.States.PARSING)
|
||||
|
||||
downloader = self._get_downloader(platform)
|
||||
gpt = self._get_gpt(model_name, provider_id)
|
||||
|
||||
self.audio_cache_file = NOTE_OUTPUT_DIR / f"{task_id}_audio.json"
|
||||
self.transcript_cache_file = NOTE_OUTPUT_DIR / f"{task_id}_transcript.json"
|
||||
self.markdown_cache_file = NOTE_OUTPUT_DIR / f"{task_id}_markdown.md"
|
||||
|
||||
self.audio_meta = self._download_audio_video(
|
||||
downloader, video_url, quality, output_path,
|
||||
screenshot, video_understanding, video_interval, grid_size or []
|
||||
)
|
||||
|
||||
self.transcript = self._transcribe_audio()
|
||||
|
||||
self.markdown = self._summarize_text(
|
||||
gpt, link, screenshot, _format or [], style, extras
|
||||
)
|
||||
|
||||
self.markdown = self._post_process_markdown(
|
||||
self.markdown, self.video_path, _format or [], self.audio_meta, platform
|
||||
)
|
||||
|
||||
self._change_state(self.States.SAVING)
|
||||
self._save_metadata(self.audio_meta.video_id, platform, task_id)
|
||||
|
||||
self._change_state(self.States.SUCCESS)
|
||||
return NoteResult(markdown=self.markdown, transcript=self.transcript, audio_meta=self.audio_meta)
|
||||
|
||||
except Exception as e:
|
||||
logger.exception(f"任务 {self.task_id} 失败: {e}")
|
||||
self._change_state(self.States.FAILED, str(e))
|
||||
return None
|
||||
|
||||
def _change_state(self, state: str, message: Optional[str] = None):
|
||||
if not self.task_id:
|
||||
return
|
||||
NOTE_OUTPUT_DIR.mkdir(parents=True, exist_ok=True)
|
||||
status_file = NOTE_OUTPUT_DIR / f"{self.task_id}.status.json"
|
||||
data = {"status": state}
|
||||
if message:
|
||||
data["message"] = message
|
||||
temp_file = status_file.with_suffix('.tmp')
|
||||
with temp_file.open('w', encoding='utf-8') as f:
|
||||
json.dump(data, f, ensure_ascii=False, indent=2)
|
||||
temp_file.replace(status_file)
|
||||
|
||||
def _init_transcriber(self) -> Transcriber:
|
||||
if self.transcriber_type not in _transcribers:
|
||||
raise Exception(f"不支持的转写器:{self.transcriber_type}")
|
||||
return get_transcriber(self.transcriber_type)
|
||||
|
||||
def _get_gpt(self, model_name: Optional[str], provider_id: Optional[str]) -> GPT:
|
||||
provider = ProviderService.get_provider_by_id(provider_id)
|
||||
if not provider:
|
||||
logger.error(f"[get_gpt] 未找到对应的模型供应商: provider_id={provider_id}")
|
||||
raise ValueError(f"未找到对应的模型供应商: provider_id={provider_id}")
|
||||
|
||||
gpt = GPTFactory().from_config(
|
||||
ModelConfig(
|
||||
api_key=provider.get('api_key'),
|
||||
base_url=provider.get('base_url'),
|
||||
model_name=model_name,
|
||||
provider=provider.get('type'),
|
||||
name=provider.get('name')
|
||||
)
|
||||
raise ProviderError(code=ProviderErrorEnum.NOT_FOUND, message=ProviderErrorEnum.NOT_FOUND.message)
|
||||
config = ModelConfig(
|
||||
api_key=provider["api_key"], base_url=provider["base_url"],
|
||||
model_name=model_name, provider=provider["type"], name=provider["name"]
|
||||
)
|
||||
return gpt
|
||||
return GPTFactory().from_config(config)
|
||||
|
||||
def get_downloader(self, platform: str) -> Downloader:
|
||||
downloader = SUPPORT_PLATFORM_MAP[platform]
|
||||
if downloader:
|
||||
logger.info(f"使用{downloader}下载器")
|
||||
return downloader
|
||||
else:
|
||||
logger.warning("不支持的平台")
|
||||
raise ValueError(f"不支持的平台:{platform}")
|
||||
def _get_downloader(self, platform: str) -> Downloader:
|
||||
downloader_cls = SUPPORT_PLATFORM_MAP.get(platform)
|
||||
if not downloader_cls:
|
||||
raise NoteError(code=NoteErrorEnum.PLATFORM_NOT_SUPPORTED.code,
|
||||
message=NoteErrorEnum.PLATFORM_NOT_SUPPORTED.message)
|
||||
return downloader_cls
|
||||
|
||||
def get_transcriber(self) -> Transcriber:
|
||||
'''
|
||||
def _download_audio_video(self, downloader, video_url, quality, output_path,
|
||||
screenshot, video_understanding, video_interval, grid_size):
|
||||
self._change_state(self.States.DOWNLOADING)
|
||||
|
||||
:param transcriber: 选择的转义器
|
||||
:return:
|
||||
'''
|
||||
if self.transcriber_type in _transcribers.keys():
|
||||
logger.info(f"使用{self.transcriber_type}转义器")
|
||||
return get_transcriber(transcriber_type=self.transcriber_type)
|
||||
else:
|
||||
logger.warning("不支持的转义器")
|
||||
raise ValueError(f"不支持的转义器:{self.transcriber}")
|
||||
need_video = screenshot or video_understanding
|
||||
if need_video:
|
||||
self.video_path = Path(downloader.download_video(video_url, output_path))
|
||||
if grid_size:
|
||||
self.video_img_urls = VideoReader(
|
||||
video_path=str(self.video_path),
|
||||
grid_size=tuple(grid_size),
|
||||
frame_interval=video_interval,
|
||||
unit_width=1280, unit_height=720,
|
||||
save_quality=90,
|
||||
).run()
|
||||
|
||||
def save_meta(self, video_id, platform, task_id):
|
||||
logger.info(f"记录已经生成的数据信息")
|
||||
insert_video_task(video_id=video_id, platform=platform, task_id=task_id)
|
||||
if self.audio_cache_file.exists():
|
||||
with open(self.audio_cache_file, "r", encoding="utf-8") as f:
|
||||
data = json.load(f)
|
||||
return AudioDownloadResult(**data)
|
||||
|
||||
def insert_screenshots_into_markdown(self, markdown: str, video_path: str, image_base_url: str,
|
||||
output_dir: str, _format: list) -> str:
|
||||
"""
|
||||
扫描 markdown 中的 *Screenshot-xx:xx,生成截图并插入 markdown 图片
|
||||
:param markdown:
|
||||
:param image_base_url: 最终返回给前端的路径前缀(如 /static/screenshots)
|
||||
"""
|
||||
matches = self.extract_screenshot_timestamps(markdown)
|
||||
new_markdown = markdown
|
||||
audio = downloader.download(
|
||||
video_url=video_url, quality=quality, output_dir=output_path, need_video=need_video
|
||||
)
|
||||
with open(self.audio_cache_file, "w", encoding="utf-8") as f:
|
||||
json.dump(asdict(audio), f, ensure_ascii=False, indent=2)
|
||||
return audio
|
||||
|
||||
logger.info(f"开始为笔记生成截图")
|
||||
try:
|
||||
for idx, (marker, ts) in enumerate(matches):
|
||||
image_path = generate_screenshot(video_path, output_dir, ts, idx)
|
||||
image_relative_path = os.path.join(image_base_url, os.path.basename(image_path)).replace("\\", "/")
|
||||
image_url = f"/static/screenshots/{os.path.basename(image_path)}"
|
||||
replacement = f""
|
||||
new_markdown = new_markdown.replace(marker, replacement, 1)
|
||||
def _transcribe_audio(self):
|
||||
self._change_state(self.States.TRANSCRIBING)
|
||||
if self.transcript_cache_file.exists():
|
||||
with open(self.transcript_cache_file, "r", encoding="utf-8") as f:
|
||||
data = json.load(f)
|
||||
segments = [TranscriptSegment(**seg) for seg in data.get("segments", [])]
|
||||
return TranscriptResult(language=data["language"], full_text=data["full_text"], segments=segments)
|
||||
|
||||
return new_markdown
|
||||
except Exception as e:
|
||||
logger.error(f"截图生成失败:{e}")
|
||||
raise HTTPException(
|
||||
status_code=500,
|
||||
detail={
|
||||
"code": StatusCode.DOWNLOAD_ERROR,
|
||||
"msg": f"截图生成失败",
|
||||
"error": str(e)
|
||||
}
|
||||
)
|
||||
transcript = self.transcriber.transcript(self.audio_meta.file_path)
|
||||
with open(self.transcript_cache_file, "w", encoding="utf-8") as f:
|
||||
json.dump(asdict(transcript), f, ensure_ascii=False, indent=2)
|
||||
return transcript
|
||||
|
||||
def _summarize_text(self, gpt, link, screenshot, formats, style, extras):
|
||||
self._change_state(self.States.SUMMARIZING)
|
||||
source = GPTSource(
|
||||
title=self.audio_meta.title,
|
||||
segment=self.transcript.segments,
|
||||
tags=self.audio_meta.raw_info.get("tags", []),
|
||||
screenshot=screenshot,
|
||||
video_img_urls=self.video_img_urls,
|
||||
link=link, _format=formats, style=style, extras=extras
|
||||
)
|
||||
markdown = gpt.summarize(source)
|
||||
with open(self.markdown_cache_file, "w", encoding="utf-8") as f:
|
||||
f.write(markdown)
|
||||
return markdown
|
||||
|
||||
@staticmethod
|
||||
def delete_note(video_id: str, platform: str):
|
||||
logger.info(f"删除生成的笔记记录")
|
||||
return delete_task_by_video(video_id, platform)
|
||||
def _post_process_markdown(self, markdown, video_path, formats, audio_meta, platform):
|
||||
if "screenshot" in formats and video_path:
|
||||
markdown = self._insert_screenshots(markdown, video_path)
|
||||
if "link" in formats:
|
||||
markdown = replace_content_markers(markdown, video_id=audio_meta.video_id, platform=platform)
|
||||
return markdown
|
||||
|
||||
import re
|
||||
|
||||
def extract_screenshot_timestamps(self, markdown: str) -> list[tuple[str, int]]:
|
||||
"""
|
||||
从 Markdown 中提取 Screenshot 时间标记(如 *Screenshot-03:39 或 Screenshot-[03:39]),
|
||||
并返回匹配文本和对应时间戳(秒)
|
||||
"""
|
||||
logger.info(f"开始提取截图时间标记")
|
||||
def _insert_screenshots(self, markdown, video_path):
|
||||
pattern = r"(?:\*Screenshot-(\d{2}):(\d{2})|Screenshot-\[(\d{2}):(\d{2})\])"
|
||||
matches = list(re.finditer(pattern, markdown))
|
||||
results = []
|
||||
for match in matches:
|
||||
matches = []
|
||||
for match in re.finditer(pattern, markdown):
|
||||
mm = match.group(1) or match.group(3)
|
||||
ss = match.group(2) or match.group(4)
|
||||
total_seconds = int(mm) * 60 + int(ss)
|
||||
results.append((match.group(0), total_seconds))
|
||||
return results
|
||||
matches.append((match.group(0), int(mm)*60+int(ss)))
|
||||
for idx, (marker, ts) in enumerate(matches):
|
||||
img_path = generate_screenshot(str(video_path), str(IMAGE_OUTPUT_DIR), ts, idx)
|
||||
filename = Path(img_path).name
|
||||
img_url = f"{IMAGE_BASE_URL.rstrip('/')}/{filename}"
|
||||
markdown = markdown.replace(marker, f"", 1)
|
||||
return markdown
|
||||
|
||||
def generate(
|
||||
self,
|
||||
video_url: Union[str, HttpUrl],
|
||||
platform: str,
|
||||
quality: DownloadQuality = DownloadQuality.medium,
|
||||
task_id: Union[str, None] = None,
|
||||
model_name: str = None,
|
||||
provider_id: str = None,
|
||||
link: bool = False,
|
||||
screenshot: bool = False,
|
||||
_format: list = None,
|
||||
style: str = None,
|
||||
extras: str = None,
|
||||
path: Union[str, None] = None,
|
||||
video_understanding: bool = False,
|
||||
video_interval=0,
|
||||
grid_size=[]
|
||||
) -> NoteResult:
|
||||
def _save_metadata(self, video_id: str, platform: str, task_id: str):
|
||||
insert_video_task(video_id=video_id, platform=platform, task_id=task_id)
|
||||
|
||||
try:
|
||||
logger.info(f"🎯 开始解析并生成笔记,task_id={task_id}")
|
||||
self.update_task_status(task_id, TaskStatus.PARSING)
|
||||
downloader = self.get_downloader(platform)
|
||||
gpt = self.get_gpt(model_name=model_name, provider_id=provider_id)
|
||||
video_img_urls = []
|
||||
audio_cache_path = os.path.join(NOTE_OUTPUT_DIR, f"{task_id}_audio.json")
|
||||
transcript_cache_path = os.path.join(NOTE_OUTPUT_DIR, f"{task_id}_transcript.json")
|
||||
markdown_cache_path = os.path.join(NOTE_OUTPUT_DIR, f"{task_id}_markdown.md")
|
||||
|
||||
# -------- 1. 下载音频 --------
|
||||
try:
|
||||
self.update_task_status(task_id, TaskStatus.DOWNLOADING)
|
||||
|
||||
# 加载音频缓存(如果存在)
|
||||
audio = None
|
||||
if os.path.exists(audio_cache_path):
|
||||
logger.info(f"检测到已有音频缓存,直接读取,task_id={task_id}")
|
||||
with open(audio_cache_path, "r", encoding="utf-8") as f:
|
||||
audio_data = json.load(f)
|
||||
audio = AudioDownloadResult(**audio_data)
|
||||
|
||||
# 需要视频的情况(截图 or 视频理解)
|
||||
need_video = 'screenshot' in _format or video_understanding
|
||||
if need_video:
|
||||
try:
|
||||
video_path = downloader.download_video(video_url)
|
||||
self.video_path = video_path
|
||||
logger.info(f"成功下载视频文件: {video_path}")
|
||||
|
||||
video_img_urls = VideoReader(
|
||||
video_path=video_path,
|
||||
grid_size=tuple(grid_size),
|
||||
frame_interval=video_interval,
|
||||
unit_width=1280,
|
||||
unit_height=720,
|
||||
save_quality=90,
|
||||
).run()
|
||||
except Exception as e:
|
||||
logger.error(f"Error 下载视频失败,task_id={task_id},错误信息:{e}")
|
||||
self.update_task_status(task_id, TaskStatus.FAILED, message=f"下载音频失败:{e}")
|
||||
raise HTTPException(
|
||||
status_code=500,
|
||||
detail={
|
||||
"code": StatusCode.DOWNLOAD_ERROR,
|
||||
"msg": f"下载视频失败,task_id={task_id}",
|
||||
"error": str(e)
|
||||
}
|
||||
)
|
||||
|
||||
# 没有音频缓存就下载音频(可能同时也带上视频)
|
||||
if audio is None:
|
||||
audio = downloader.download(
|
||||
video_url=video_url,
|
||||
quality=quality,
|
||||
output_dir=path,
|
||||
need_video='screenshot' in _format, # 注意这里只为了截图需要
|
||||
)
|
||||
with open(audio_cache_path, "w", encoding="utf-8") as f:
|
||||
json.dump(asdict(audio), f, ensure_ascii=False, indent=2)
|
||||
logger.info(f"音频下载并缓存成功,task_id={task_id}")
|
||||
except Exception as e:
|
||||
logger.error(f"Error 下载音频失败,task_id={task_id},错误信息:{e}")
|
||||
self.update_task_status(task_id, TaskStatus.FAILED, message=f"下载音频失败:{e}")
|
||||
|
||||
raise HTTPException(
|
||||
status_code=500,
|
||||
detail={
|
||||
"code": StatusCode.DOWNLOAD_ERROR,
|
||||
"msg": f"下载音频失败,task_id={task_id}",
|
||||
"error": str(e)
|
||||
}
|
||||
)
|
||||
|
||||
# -------- 2. 转写文字 --------
|
||||
try:
|
||||
self.update_task_status(task_id, TaskStatus.TRANSCRIBING)
|
||||
if os.path.exists(transcript_cache_path):
|
||||
logger.info(f"检测到已有转写缓存,直接读取,task_id={task_id}")
|
||||
try:
|
||||
with open(transcript_cache_path, "r", encoding="utf-8") as f:
|
||||
transcript_data = json.load(f)
|
||||
transcript = TranscriptResult(
|
||||
language=transcript_data["language"],
|
||||
full_text=transcript_data["full_text"],
|
||||
segments=[TranscriptSegment(**seg) for seg in transcript_data["segments"]]
|
||||
)
|
||||
except (json.JSONDecodeError, KeyError) as e:
|
||||
logger.warning(f"Warning 读取转录缓存失败,重新转录,task_id={task_id},错误信息:{e}")
|
||||
transcript: TranscriptResult = self.transcriber.transcript(file_path=audio.file_path)
|
||||
with open(transcript_cache_path, "w", encoding="utf-8") as f:
|
||||
json.dump(asdict(transcript), f, ensure_ascii=False, indent=2)
|
||||
else:
|
||||
transcript: TranscriptResult = self.transcriber.transcript(file_path=audio.file_path)
|
||||
with open(transcript_cache_path, "w", encoding="utf-8") as f:
|
||||
json.dump(asdict(transcript), f, ensure_ascii=False, indent=2)
|
||||
logger.info(f"文字转写并缓存成功,task_id={task_id}")
|
||||
except Exception as e:
|
||||
logger.error(f"Error 转写文字失败,task_id={task_id},错误信息:{e}")
|
||||
self.update_task_status(task_id, TaskStatus.FAILED, message=f"转写文字失败:{e}")
|
||||
raise HTTPException(
|
||||
status_code=500,
|
||||
detail={
|
||||
"code": StatusCode.GENERATE_ERROR, # =1003
|
||||
"msg": f"转写文字失败,task_id={task_id}",
|
||||
"error": str(e)
|
||||
}
|
||||
)
|
||||
|
||||
# -------- 3. 总结内容 --------
|
||||
try:
|
||||
self.update_task_status(task_id, TaskStatus.SUMMARIZING)
|
||||
# if os.path.exists(markdown_cache_path):
|
||||
# logger.info(f"检测到已有总结缓存,直接读取,task_id={task_id}")
|
||||
# with open(markdown_cache_path, "r", encoding="utf-8") as f:
|
||||
# markdown = f.read()
|
||||
# else:
|
||||
source = GPTSource(
|
||||
title=audio.title,
|
||||
segment=transcript.segments,
|
||||
tags=audio.raw_info.get('tags'),
|
||||
screenshot=screenshot,
|
||||
video_img_urls=video_img_urls,
|
||||
link=link,
|
||||
_format=_format,
|
||||
style=style,
|
||||
extras=extras
|
||||
)
|
||||
|
||||
markdown: str = gpt.summarize(source)
|
||||
with open(markdown_cache_path, "w", encoding="utf-8") as f:
|
||||
f.write(markdown)
|
||||
logger.info(f"GPT总结并缓存成功,task_id={task_id}")
|
||||
except Exception as e:
|
||||
logger.error(f"Error 总结内容失败,task_id={task_id},错误信息:{e}")
|
||||
self.update_task_status(task_id, TaskStatus.FAILED, message=f"总结内容失败:{e}")
|
||||
raise HTTPException(
|
||||
status_code=500,
|
||||
detail={
|
||||
"code": StatusCode.GENERATE_ERROR, # =1003
|
||||
"msg": f"总结内容失败,task_id={task_id}",
|
||||
"error": str(e)
|
||||
}
|
||||
)
|
||||
|
||||
# -------- 4. 插入截图 --------
|
||||
if _format and 'screenshot' in _format:
|
||||
try:
|
||||
markdown = self.insert_screenshots_into_markdown(markdown, self.video_path, image_base_url,
|
||||
output_dir, _format)
|
||||
except Exception as e:
|
||||
logger.warning(f"Warning 插入截图失败,跳过处理,task_id={task_id},错误信息:{e}")
|
||||
if _format and 'link' in _format:
|
||||
try:
|
||||
markdown = replace_content_markers(markdown, video_id=audio.video_id, platform=platform)
|
||||
except Exception as e:
|
||||
logger.warning(f"Warning 插入链接失败,跳过处理,task_id={task_id},错误信息:{e}")
|
||||
# 注意:截图失败不终止整体流程
|
||||
|
||||
# -------- 5. 保存数据库记录 --------
|
||||
self.update_task_status(task_id, TaskStatus.SAVING)
|
||||
self.save_meta(video_id=audio.video_id, platform=platform, task_id=task_id)
|
||||
|
||||
# -------- 6. 完成 --------
|
||||
self.update_task_status(task_id, TaskStatus.SUCCESS)
|
||||
logger.info(f"succeed 笔记生成成功,task_id={task_id}")
|
||||
# TODO :改为前端一键清除缓存
|
||||
# if platform != 'local':
|
||||
# transcription_finished.send({
|
||||
# "file_path": audio.file_path,
|
||||
# })
|
||||
return NoteResult(
|
||||
markdown=markdown,
|
||||
transcript=transcript,
|
||||
audio_meta=audio
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error 笔记生成流程异常终止,task_id={task_id},错误信息:{e}")
|
||||
self.update_task_status(task_id, TaskStatus.FAILED, message=str(e))
|
||||
|
||||
# 返回结构化错误信息给前端(可以用于日志 + 显示 + 错误定位)
|
||||
raise HTTPException(
|
||||
status_code=500,
|
||||
detail={
|
||||
"code": StatusCode.FAIL,
|
||||
"msg": f"笔记生成流程异常终止,task_id={task_id}",
|
||||
"error": str(e)
|
||||
}
|
||||
)
|
||||
@staticmethod
|
||||
def delete_note(video_id: str, platform: str) -> int:
|
||||
return delete_task_by_video(video_id, platform)
|
||||
@@ -1,18 +1,24 @@
|
||||
from fastapi.responses import JSONResponse
|
||||
from app.utils.status_code import StatusCode
|
||||
from pydantic import BaseModel
|
||||
from typing import Optional, Any
|
||||
|
||||
|
||||
from fastapi.responses import JSONResponse
|
||||
|
||||
class ResponseWrapper:
|
||||
@staticmethod
|
||||
def success(data=None, msg="success", code=StatusCode.SUCCESS):
|
||||
return {
|
||||
"code": int(code),
|
||||
def success(data=None, msg="success", code=0):
|
||||
return JSONResponse(content={
|
||||
"code": code,
|
||||
"msg": msg,
|
||||
"data": data
|
||||
}
|
||||
})
|
||||
|
||||
@staticmethod
|
||||
def error(msg="error", code=StatusCode.FAIL, data=None):
|
||||
return {
|
||||
"code": int(code),
|
||||
def error(msg="error", code=500, data=None):
|
||||
return JSONResponse(content={
|
||||
"code": code,
|
||||
"msg": msg,
|
||||
"data": data
|
||||
}
|
||||
})
|
||||
@@ -4,7 +4,7 @@ import uvicorn
|
||||
from starlette.staticfiles import StaticFiles
|
||||
from dotenv import load_dotenv
|
||||
|
||||
from app.core.exception_handlers import register_exception_handlers
|
||||
from app.exceptions.exception_handlers import register_exception_handlers
|
||||
from app.db.model_dao import init_model_table
|
||||
from app.db.provider_dao import init_provider_table
|
||||
from app.utils.logger import get_logger
|
||||
@@ -33,11 +33,14 @@ if not os.path.exists(out_dir):
|
||||
os.makedirs(out_dir)
|
||||
|
||||
app = create_app()
|
||||
register_exception_handlers(app)
|
||||
app.mount(static_path, StaticFiles(directory=static_dir), name="static")
|
||||
app.mount("/uploads", StaticFiles(directory=uploads_dir), name="uploads")
|
||||
@app.on_event("startup")
|
||||
async def startup_event():
|
||||
register_exception_handlers(app)
|
||||
|
||||
|
||||
|
||||
register_handler()
|
||||
ensure_ffmpeg_or_raise()
|
||||
register_handler()
|
||||
@@ -46,8 +49,9 @@ async def startup_event():
|
||||
init_provider_table()
|
||||
init_model_table()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
port = int(os.getenv("BACKEND_PORT", 8000))
|
||||
host = os.getenv("BACKEND_HOST", "0.0.0.0")
|
||||
logger.info(f"Starting server on {host}:{port}")
|
||||
uvicorn.run("main:app", host=host, port=port, reload=False)
|
||||
uvicorn.run(app, host=host, port=port, reload=False)
|
||||
@@ -1,6 +1,6 @@
|
||||
server {
|
||||
listen 80;
|
||||
|
||||
client_max_body_size 10G;
|
||||
# 所有非 /api 请求全部代理给 frontend 容器
|
||||
location / {
|
||||
proxy_pass http://frontend:80;
|
||||
@@ -11,8 +11,6 @@ server {
|
||||
proxy_pass http://backend:8000;
|
||||
proxy_set_header Host $host;
|
||||
proxy_set_header X-Real-IP $remote_addr;
|
||||
client_max_body_size 10G;
|
||||
client_body_buffer_size 128k;
|
||||
}
|
||||
|
||||
location /static/ {
|
||||
|
||||
Reference in New Issue
Block a user