refactor(backend): 重构后端异常处理和模型管理

- 新增自定义异常类 BizException、NoteError 和 ProviderError
- 优化了模型管理相关的逻辑,包括加载、删除和测试连接等功能
- 改进了 Douyin 下载器的错误处理
- 调整了任务重试逻辑和笔记生成的异常处理- 更新了相关组件和页面以适应新的异常处理机制
This commit is contained in:
JefferyHcool
2025-06-06 21:30:23 +08:00
parent df5c0f771a
commit 8b1bc54f2d
34 changed files with 661 additions and 660 deletions

View File

@@ -11,6 +11,7 @@ VITE_FRONTEND_PORT=3015
ENV=production
STATIC=/static
OUT_DIR=./static/screenshots
NOTE_OUTPUT_DIR=note_results
IMAGE_BASE_URL=/static/screenshots
DATA_DIR=data
# FFMPEG 配置

View File

@@ -36,7 +36,7 @@ const DownloaderForm = () => {
setLoading(true) // 🔁 切换平台时显示 loading
try {
const res = await getDownloaderCookie(id)
const cookie = res?.data?.data?.cookie || ''
const cookie = res?.cookie || ''
form.reset({ cookie }) // ✅ 正确重置表单值
} catch (e) {
toast.error('加载 Cookie 失败: ' + e)

View File

@@ -129,11 +129,10 @@ const ProviderForm = ({ isCreate = false }: { isCreate?: boolean }) => {
try {
const res = await deleteModelById(modelId)
if (res.data.code === 0) {
toast.success('删除成功')
} else {
toast.error(res.data.msg || '删除失败')
}
console.log('🔧 删除结果:', res)
toast.success('删除成功')
} catch (e) {
toast.error('删除异常')
}
@@ -151,16 +150,16 @@ const ProviderForm = ({ isCreate = false }: { isCreate?: boolean }) => {
return
}
setTesting(true)
const data = await testConnection({
id
})
if (data.data.code === 0) {
await testConnection({
id
})
toast.success('测试连通性成功 🎉')
} else {
toast.error(`连接失败: ${data.data.msg || '未知错误'}`)
}
} catch (error) {
toast.error('测试连通性异常')
toast.error(`连接失败: ${data.data.msg || '未知错误'}`)
// toast.error('测试连通性异常')
} finally {
setTesting(false)
}

View File

@@ -26,11 +26,11 @@ export const useTaskPolling = (interval = 3000) => {
try {
console.log('🔄 正在轮询任务:', task.id)
const res = await get_task_status(task.id)
const { status } = res.data
const { status } = res
if (status && status !== task.status) {
if (status === 'SUCCESS') {
const { markdown, transcript, audio_meta } = res.data.result
const { markdown, transcript, audio_meta } = res.result
toast.success('笔记生成成功')
updateTaskContent(task.id, {
status,
@@ -47,7 +47,7 @@ export const useTaskPolling = (interval = 3000) => {
}
} catch (e) {
console.error('❌ 任务轮询失败:', e)
toast.error(`生成失败 ${e.message || e}`)
// toast.error(`生成失败 ${e.message || e}`)
updateTaskContent(task.id, { status: 'FAILED' })
// removeTask(task.id)
}

View File

@@ -173,6 +173,7 @@ const MarkdownViewer: FC<MarkdownViewerProps> = ({ status }) => {
<div className="text-center">
<p className="text-lg font-bold text-red-500"></p>
<p className="mt-2 mb-2 text-xs text-red-400"></p>
<Button onClick={() => retryTask(currentTask.id)} size="lg">
</Button>

View File

@@ -37,6 +37,7 @@ import {
import { Input } from '@/components/ui/input.tsx'
import { Textarea } from '@/components/ui/textarea.tsx'
import { noteStyles, noteFormats, videoPlatforms } from '@/constant/note.ts'
import { fetchModels } from '@/services/model.ts'
/* -------------------- 校验 Schema -------------------- */
const formSchema = z
@@ -206,7 +207,7 @@ const NoteForm = () => {
}
message.success('已提交任务')
const { data } = await generateNote(payload)
const data = await generateNote(payload)
addPendingTask(data.task_id, values.platform, payload)
}
const onInvalid = (errors: FieldErrors<NoteFormValues>) => {
@@ -355,6 +356,9 @@ const NoteForm = () => {
<FormItem>
<SectionHeader title="模型选择" tip="不同模型效果不同,建议自行测试" />
<Select
onOpenChange={()=>{
loadEnabledModels()
}}
value={field.value}
onValueChange={field.onChange}
defaultValue={field.value}

View File

@@ -26,7 +26,7 @@ export default function AboutPage() {
height={50}
className="rounded-lg"
/>
<h1 className="text-4xl font-bold">BiliNote v1.7.3</h1>
<h1 className="text-4xl font-bold">BiliNote v1.7.4</h1>
</div>
<p className="text-muted-foreground mb-6 text-xl italic">
AI AI

View File

@@ -1,7 +1,6 @@
import request from '@/utils/request'
import toast from 'react-hot-toast'
import { useTaskStore } from '@/store/taskStore'
import request from '@/utils/request'
export const generateNote = async (data: {
video_url: string
platform: string
@@ -14,12 +13,13 @@ export const generateNote = async (data: {
extras?: string
video_understand?: boolean
video_interval?: number
grid_size:Array<number>
grid_size: Array<number>
}) => {
try {
console.log('generateNote', data)
const response = await request.post('/generate_note', data)
if (response.data.code != 0) {
if (!response) {
if (response.data.msg) {
toast.error(response.data.msg)
}
@@ -30,12 +30,12 @@ export const generateNote = async (data: {
console.log('res', response)
// 成功提示
return response.data
return response
} catch (e: any) {
console.error('❌ 请求出错', e)
// 错误提示
toast.error('笔记生成失败,请稍后重试')
// toast.error('笔记生成失败,请稍后重试')
throw e // 抛出错误以便调用方处理
}
@@ -65,15 +65,9 @@ export const delete_task = async ({ video_id, platform }) => {
export const get_task_status = async (task_id: string) => {
try {
const response = await request.get('/task_status/' + task_id)
if (response.data.code == 0 && response.data.status == 'SUCCESS') {
// toast.success("笔记生成成功")
}
console.log('res', response)
// 成功提示
return response.data
return await request.get('/task_status/' + task_id)
} catch (e) {
console.error('❌ 请求出错', e)

View File

@@ -1,6 +1,12 @@
import { create } from 'zustand'
import { devtools } from 'zustand/middleware'
import { fetchModels, addModel, fetchEnableModels, fetchEnableModelById, deleteModelById } from '@/services/model.ts'
import {
fetchModels,
addModel,
fetchEnableModels,
fetchEnableModelById,
deleteModelById
} from '@/services/model'
interface IModel {
id: string
@@ -11,81 +17,93 @@ interface IModel {
root: string
}
interface IModelListItem {
id: string
provider_id: string
model_name: string
created_at?: string
}
interface ModelStore {
models: IModel[]
modelList: []
modelList: IModelListItem[]
loading: boolean
selectedModel: string
loadModels: (providerId: string) => Promise<void>
loadModelsById: (providerId: string) => Promise<IModelListItem[]>
loadEnabledModels: () => Promise<void>
loadModelsById : (providerId: string) => Promise<void>
addNewModel: (providerId: string, modelId: string) => Promise<void>
setSelectedModel: (modelId: string) => void
deleteModel: (modelId: number) => Promise<void>
setSelectedModel: (modelId: string) => void
clearModels: () => void
}
export const useModelStore = create<ModelStore>()(
devtools(set => ({
devtools((set) => ({
models: [],
modelList: [],
loading: false,
selectedModel: '',
modelList: [],
// 获取所有可用模型 (全局可用模型列表)
loadEnabledModels: async () => {
try {
set({ loading: true })
const res = await fetchEnableModels()
if (res.data.code === 0 && res.data.data.length > 0) {
set({ modelList: res.data.data })
} else {
set({ modelList: [] })
console.error('模型列表加载失败')
}
const list = await fetchEnableModels()
set({ modelList: list })
} catch (error) {
set({ modelList: [] })
console.error('加载模型出错', error)
}
},
deleteModel: async (modelId: number) => {
await deleteModelById( modelId)
},
// 加载模型列表
loadModels: async (providerId: string) => {
try {
set({ loading: true })
const res = await fetchModels(providerId)
if (res.data.code === 0 && res.data.data.models.data.length > 0) {
set({ models: res.data.data.models.data })
} else if (res.data.code === 0 && res.data.data.models.length > 0) {
set({ models: res.data.data.models.data })
} else {
set({ models: [] })
console.error('模型列表加载失败')
}
} catch (error) {
set({ models: [] })
console.error('加载模型出错', error)
console.error('加载可用模型失败', error)
} finally {
set({ loading: false })
}
},
loadModelsById: async (providerId: string)=>{
const models = await fetchEnableModelById(providerId)
if (models.data.code === 0) {
console.log('模型列表加载成功:', models.data)
return models.data.data
// 通过 provider 获取该供应商的模型列表
loadModels: async (providerId: string) => {
try {
set({ loading: true })
const res = await fetchModels(providerId)
let models: IModel[] = []
// 兼容 SyncPage 分页对象与普通数组两种格式
if (Array.isArray(res.models)) {
models = res.models
} else if (res.models?.data && Array.isArray(res.models.data)) {
models = res.models.data
}
set({ models })
} catch (error) {
set({ models: [] })
console.error('加载模型列表失败', error)
} finally {
set({ loading: false })
}
},
// 新增模型
},
// 单独获取某个供应商下已启用模型
loadModelsById: async (providerId: string) => {
try {
const models = await fetchEnableModelById(providerId)
console.log('获取供应商模型成功:', models)
return models
} catch (error) {
console.error('加载供应商模型失败', error)
return []
}
},
// 新增模型逻辑
addNewModel: async (providerId: string, modelId: string) => {
try {
const res = await addModel({ provider_id: providerId, model_name: modelId })
if (res.code === 0) {
console.log('新增模型成功:', modelId)
// ✅ 新增成功以后,前端直接追加一条到 models 列表
set(state => ({
set((state) => ({
models: [
...state.models,
{
@@ -99,17 +117,30 @@ export const useModelStore = create<ModelStore>()(
],
}))
} else {
console.error('新增模型失败')
console.error('新增模型失败', res.msg)
}
} catch (error) {
console.error('添加模型出错', error)
}
},
// 设置选中的模型
setSelectedModel: modelId => set({ selectedModel: modelId }),
// 删除模型
deleteModel: async (modelId: number) => {
try {
await deleteModelById(modelId)
// 删除后更新本地状态(可选)
set((state) => ({
models: state.models.filter((model) => model.id !== modelId.toString())
}))
} catch (error) {
console.error('删除模型失败', error)
}
},
// 清空
clearModels: () => set({ models: [], selectedModel: '' }),
// 切换选中模型
setSelectedModel: (modelId: string) => set({ selectedModel: modelId }),
// 清空
clearModels: () => set({ models: [], selectedModel: '', modelList: [] }),
}))
)
)

View File

@@ -1,5 +1,5 @@
import { create } from 'zustand'
import { IProvider } from '@/types'
import { IProvider, IResponse } from '@/types'
import {
addProvider,
getProviderById,
@@ -38,10 +38,9 @@ export const useProviderStore = create<ProviderStore>((set, get) => ({
// 设置整个 provider 列表
setAllProviders: providers => set({ provider: providers }),
loadProviderById: async (id: string) => {
const res = await getProviderById(id)
if (res.data.code === 0) {
const item = res.data.data
console.log('Provider ', item)
const res:IResponse<IProvider> = await getProviderById(id)
const item = res
return {
id: item.id,
name: item.name,
@@ -51,9 +50,7 @@ export const useProviderStore = create<ProviderStore>((set, get) => ({
type: item.type,
enabled: item.enabled,
}
} else {
console.log('Provider not found')
}
},
addNewProvider: async (provider: IProvider) => {
const payload = {
@@ -96,16 +93,18 @@ export const useProviderStore = create<ProviderStore>((set, get) => ({
getProviderList: () => get().provider,
fetchProviderList: async () => {
try {
const res = await getProviderList()
if (res.data.code === 0) {
const res = await getProviderList()
set({
provider: res.data.data.map(
provider: res.map(
(item: {
id: string
name: string
logo: string
api_key: string
base_url: string
type: string
enabled: number
}) => {
return {
id: item.id,
@@ -119,7 +118,6 @@ export const useProviderStore = create<ProviderStore>((set, get) => ({
}
),
})
}
} catch (error) {
console.error('Error fetching provider list:', error)
}

View File

@@ -2,6 +2,7 @@ import { create } from 'zustand'
import { persist } from 'zustand/middleware'
import { delete_task, generateNote } from '@/services/note.ts'
import { v4 as uuidv4 } from 'uuid'
import toast from 'react-hot-toast'
export type TaskStatus = 'PENDING' | 'RUNNING' | 'SUCCESS' | 'FAILD'
@@ -157,14 +158,19 @@ export const useTaskStore = create<TaskStore>()(
return get().tasks.find(task => task.id === currentTaskId) || null
},
retryTask: async (id: string, payload?: any) => {
if (!id){
toast.error('任务不存在')
return
}
const task = get().tasks.find(task => task.id === id)
console.log('retry',task)
if (!task) return
const newFormData = payload || task.formData
await generateNote({
task_id: id,
...newFormData,
task_id: id,
})
set(state => ({

View File

@@ -7,3 +7,8 @@ export interface IProvider {
baseUrl: string
enabled: number
}
export interface IResponse<T> {
code: number
data:T
msg: string
}

View File

@@ -1,27 +1,58 @@
import axios from 'axios'
const request = axios.create({
baseURL: '/api',
timeout: 10000,
})
function handleErrorResponse(response: any) {
if (!response) return '请求失败,请检查网络连接'
if (typeof response.code !== 'number') return '系统异常'
import axios, { AxiosInstance, AxiosResponse } from 'axios';
import toast from 'react-hot-toast'
// 错误码判断
switch (response.code) {
case 1001:
return response.msg || '下载失败,请检查视频链接'
case 1002:
return response.msg || '转写失败,请稍后重试'
case 1003:
return response.msg || '总结失败,可能是模型服务异常'
case 2001:
case 2002:
return Array.isArray(response.data)
? response.data.map(e => `${e.field}: ${e.error}`).join('\n')
: response.msg || '参数错误'
default:
return response.msg || '系统异常'
}
// 统一响应类型
export interface IResponse<T = any> {
code: number;
msg: string;
data: T;
}
// 模拟一个消息提示函数 (实际项目中会使用UI库的组件如 Ant Design 的 message 或 Element UI 的 ElMessage)
// This function simulates a message display (in real projects, you'd use a UI library's component)
// 创建实例
const request: AxiosInstance = axios.create({
baseURL: '/api', // 请确保你的开发服务器代理设置正确
timeout: 10000,
});
// 响应拦截器
request.interceptors.response.use(
(response: AxiosResponse<IResponse>) => {
const res = response.data;
if (res.code === 0) {
// 业务成功,可以根据需要显示成功消息,或者不显示(如果操作本身就是可见的)
// showMessage('success', res.msg || '操作成功'); // 如果需要显示成功消息
return res.data; // 返回data部分简化后续业务代码
} else {
// 业务错误,统一显示后端返回的错误消息
// Business error, uniformly display the error message returned from the backend
toast.error(res.msg || '操作失败,请稍后再试');
return Promise.reject(res); // 拒绝Promise让业务代码可以捕获并处理
}
},
(error) => {
// 网络/服务器错误
const res = error?.response?.data as IResponse | undefined;
if (res) {
// 如果后端有返回错误信息,则显示后端信息
// If the backend returns an error message, display it
toast.error(res.msg || '服务器错误,请稍后再试');
return Promise.reject(res);
} else {
// 没有响应数据(如网络中断),显示通用网络错误
// No response data (e.g., network disconnected), display generic network error
toast.error( '请求失败,请检查网络连接或稍后再试')
return Promise.reject({
code: -1,
msg: '请求失败,请检查网络连接',
data: null
} as IResponse);
}
}
);
export default request

View File

@@ -3,7 +3,7 @@
<p align="center">
<img src="./doc/icon.svg" alt="BiliNote Banner" width="50" height="50" />
</p>
<h1 align="center" > BiliNote v1.7.3</h1>
<h1 align="center" > BiliNote v1.7.4</h1>
</div>
<p align="center"><i>AI 视频笔记生成工具 让 AI 为你的视频做笔记</i></p>

View File

@@ -1,11 +1,14 @@
from fastapi import FastAPI
from .routers import note, provider, model, config
def create_app() -> FastAPI:
app = FastAPI(title="BiliNote")
app.include_router(note.router, prefix="/api")
app.include_router(provider.router, prefix="/api")
app.include_router(model.router,prefix="/api")
app.include_router(config.router, prefix="/api")
return app

View File

@@ -1,38 +0,0 @@
# app/core/exception_handlers.py
from fastapi import Request, HTTPException
from fastapi.exceptions import RequestValidationError
from fastapi.responses import JSONResponse
from app.utils.logger import get_logger
from app.utils.response import ResponseWrapper
from app.utils.status_code import StatusCode
logger = get_logger(__name__)
def register_exception_handlers(app):
@app.exception_handler(RequestValidationError)
async def validation_exception_handler(request: Request, exc: RequestValidationError):
errors = []
for err in exc.errors():
loc = err.get("loc", [])
field = loc[-1] if loc else "body"
msg = err.get("msg", "参数不合法")
errors.append({"field": field, "error": msg})
return JSONResponse(
status_code=400,
content=ResponseWrapper.error(msg="参数验证失败", code=StatusCode.PARAM_ERROR, data=errors)
)
@app.exception_handler(HTTPException)
async def http_exception_handler(request: Request, exc: HTTPException):
return JSONResponse(
status_code=exc.status_code,
content=ResponseWrapper.error(msg=str(exc.detail), code=StatusCode.FAIL)
)
@app.exception_handler(Exception)
async def global_exception_handler(request: Request, exc: Exception):
logger.exception(f"服务器内部错误: {exc}")
return JSONResponse(
status_code=500,
content=ResponseWrapper.error(msg="服务器内部错误", code=StatusCode.FAIL, data=str(exc))
)

View File

@@ -136,7 +136,8 @@ def get_provider_by_name(name: str):
if row is None:
logger.info(f"Provider not found: {name}")
return None
logger.info(f"Provider found: {row}")
logger.info(f"Provider found: {row[0]}")
return row
except Exception as e:
logger.error(f"Failed to get provider by name: {e}")
@@ -155,7 +156,7 @@ def get_provider_by_id(id: int):
if row is None:
logger.info(f"Provider not found: {id}")
return None
logger.info(f"Provider found: {row}")
logger.info(f"Provider found: {row[0]}")
return row
except Exception as e:
logger.error(f"Failed to get provider by id: {e}")
@@ -173,7 +174,7 @@ def get_all_providers():
if rows is None:
logger.info("No providers found")
return None
logger.info(f"Providers found: {rows}")
logger.info(f"Providers found total {len(rows) }")
return rows
except Exception as e:
logger.error(f"Failed to get all providers: {e}")

View File

@@ -145,53 +145,59 @@ class DouyinDownloader(Downloader):
return ""
def gen_real_msToken(self) -> str:
payload = json.dumps(
{
"magic": self.ms_token_config["magic"],
"version": self.ms_token_config["version"],
"dataType": self.ms_token_config["dataType"],
"strData": self.ms_token_config["strData"],
"tspFromClient": get_timestamp(),
try:
payload = json.dumps(
{
"magic": self.ms_token_config["magic"],
"version": self.ms_token_config["version"],
"dataType": self.ms_token_config["dataType"],
"strData": self.ms_token_config["strData"],
"tspFromClient": get_timestamp(),
}
)
headers = {
"User-Agent": self.headers_config["User-Agent"],
"Content-Type": "application/json",
}
)
headers = {
"User-Agent": self.headers_config["User-Agent"],
"Content-Type": "application/json",
}
transport = httpx.HTTPTransport(retries=5)
with httpx.Client(transport=transport) as client:
try:
response = client.post(
self.ms_token_config["url"], content=payload, headers=headers
)
response.raise_for_status()
transport = httpx.HTTPTransport(retries=5)
with httpx.Client(transport=transport) as client:
try:
response = client.post(
self.ms_token_config["url"], content=payload, headers=headers
)
response.raise_for_status()
msToken = str(httpx.Cookies(response.cookies).get("msToken"))
if len(msToken) not in [120, 128]:
raise ValueError("响应内容:{0} Douyin msToken API 的响应内容不符合要求。".format(msToken))
msToken = str(httpx.Cookies(response.cookies).get("msToken"))
if len(msToken) not in [120, 128]:
raise ValueError("响应内容:{0} Douyin msToken API 的响应内容不符合要求。".format(msToken))
return msToken
except Exception as e:
raise ValueError("Douyin msToken API 请求失败:{0}".format(e))
return msToken
except Exception as e:
raise ValueError("Douyin msToken API 请求失败:{0}".format(e))
except Exception as e:
raise ValueError("Douyin msToken API{0}".format(e))
def fetch_video_info(self, video_url: str) -> json:
aweme_id = self.extract_video_id(video_url)
kwargs = self.headers_config
print("kwargs:", kwargs)
base_params = BaseRequestModel().model_dump()
base_params["msToken"] = self.gen_real_msToken()
base_params["aweme_id"] = aweme_id
bogus = ABogus()
ab_value = bogus.get_value(base_params)
a_bogus = quote(ab_value, safe='')
print(base_params)
query_str = urlencode(base_params)
full_url = f"{DOUYIN_DOMAIN}/aweme/v1/web/aweme/detail/?{query_str}&a_bogus={a_bogus}"
print("Request URL:", full_url)
try:
aweme_id = self.extract_video_id(video_url)
kwargs = self.headers_config
print("@kwargs:", kwargs)
base_params = BaseRequestModel().model_dump()
base_params["msToken"] = self.gen_real_msToken()
base_params["aweme_id"] = aweme_id
bogus = ABogus()
ab_value = bogus.get_value(base_params)
a_bogus = quote(ab_value, safe='')
print("@a_bogus:", a_bogus)
print(base_params)
query_str = urlencode(base_params)
full_url = f"{DOUYIN_DOMAIN}/aweme/v1/web/aweme/detail/?{query_str}&a_bogus={a_bogus}"
print("Request URL:", full_url)
response = requests.get(full_url, headers=kwargs)
print("Response JSON:", response.content)
@@ -208,46 +214,49 @@ class DouyinDownloader(Downloader):
quality: DownloadQuality = "fast",
need_video: Optional[bool] = False
) -> AudioDownloadResult:
print(
f"正在下载视频: {video_url},保存路径: {output_dir},质量: {quality}"
)
if output_dir is None:
output_dir = get_data_dir()
if not output_dir:
output_dir = self.cache_data
os.makedirs(output_dir, exist_ok=True)
try:
print(
f"正在下载视频: {video_url},保存路径: {output_dir},质量: {quality}"
)
if output_dir is None:
output_dir = get_data_dir()
if not output_dir:
output_dir = self.cache_data
os.makedirs(output_dir, exist_ok=True)
output_path = os.path.join(output_dir, "%(id)s.%(ext)s")
output_path = os.path.join(output_dir, "%(id)s.%(ext)s")
video_data = self.fetch_video_info(video_url)
output_path = output_path % {
"id": video_data['aweme_detail']['aweme_id'],
"ext": "mp3",
}
url = video_data['aweme_detail']['music']['play_url']['uri']
# 下载音频
audio_data = requests.get(url)
with open(output_path, 'wb') as f:
f.write(audio_data.content)
print(url)
tags = []
for tag in video_data['aweme_detail']['video_tag']:
if tag['tag_name']:
tags.append(tag['tag_name'])
video_data = self.fetch_video_info(video_url)
output_path = output_path % {
"id": video_data['aweme_detail']['aweme_id'],
"ext": "mp3",
}
url = video_data['aweme_detail']['music']['play_url']['uri']
# 下载音频
audio_data = requests.get(url)
with open(output_path, 'wb') as f:
f.write(audio_data.content)
print(url)
tags = []
for tag in video_data['aweme_detail']['video_tag']:
if tag['tag_name']:
tags.append(tag['tag_name'])
return AudioDownloadResult(
file_path=output_path,
title=video_data['aweme_detail']['item_title'],
duration=video_data['aweme_detail']['video']['duration'],
cover_url=video_data['aweme_detail']['video']['cover_original_scale']['url_list'][0] if
video_data['aweme_detail']['video']['cover'] else video_data['video']['big_thumbs']['img_url'],
platform="douyin",
video_id=video_data['aweme_detail']['aweme_id'],
raw_info={
'tags': video_data['aweme_detail']['caption'] + ''.join(tags),
},
video_path=None # ❗音频下载不包含视频路径
)
return AudioDownloadResult(
file_path=output_path,
title=video_data['aweme_detail']['item_title'],
duration=video_data['aweme_detail']['video']['duration'],
cover_url=video_data['aweme_detail']['video']['cover_original_scale']['url_list'][0] if
video_data['aweme_detail']['video']['cover'] else video_data['video']['big_thumbs']['img_url'],
platform="douyin",
video_id=video_data['aweme_detail']['aweme_id'],
raw_info={
'tags': video_data['aweme_detail']['caption'] + ''.join(tags),
},
video_path=None # ❗音频下载不包含视频路径
)
except Exception as e:
raise e
def download_video(self, video_url: str, output_dir: Union[str, None] = None) -> str:

View File

@@ -0,0 +1,21 @@
import enum
class ProviderErrorEnum(enum.Enum):
CONNECTION_TEST_FAILED = (200101, "供应商连接测试失败")
SAVE_FAILED = (200102, "供应商保存失败")
CREATE_FAILED = (200103, "供应商创建失败")
NOT_FOUND = (200104, "供应商不存在/未保存")
WRONG_PARAMETER = (200105, "API / API 地址不正确")
UNKNOW_ERROR = (200106, "未知错误")
def __init__(self, code, message):
self.code = code
self.message = message
class NoteErrorEnum(enum.Enum):
PLATFORM_NOT_SUPPORTED = (300101 ,"选择的平台不受支持")
def __init__(self, code, message):
self.code = code
self.message = message

View File

View File

@@ -0,0 +1,6 @@
# exceptions/biz_exception.py
class BizException(Exception):
def __init__(self, code: int, message: str = "业务异常"):
self.code = code
self.message = message

View File

@@ -0,0 +1,33 @@
# middlewares/exception_handler.py
from fastapi import Request
from fastapi import FastAPI
from app.enmus.exception import NoteErrorEnum
from app.exceptions.biz_exception import BizException
from app.exceptions.note import NoteError
from app.exceptions.provider import ProviderError
from app.utils.logger import get_logger
from app.utils.response import ResponseWrapper as R
import traceback
logger = get_logger(__name__)
def register_exception_handlers(app: FastAPI):
@app.exception_handler(BizException)
async def biz_exception_handler(request: Request, exc: BizException):
logger.error(f"BizException: {exc.code} - {exc.message}")
return R.error(code=exc.code, msg=str(exc.message))
@app.exception_handler(NoteError)
async def note_exception_handler(request: Request, exc: NoteError):
logger.error(f"NoteError: {exc.code} - {exc.message}")
return R.error(code=exc.code, msg=str(exc.message))
@app.exception_handler(ProviderError)
async def provider_exception_handler(request: Request, exc: ProviderError):
logger.error(f"供应商模块错误: {exc.code} - {exc.message}")
return R.error(code=exc.code, msg=str(exc.message))
@app.exception_handler(Exception)
async def general_exception_handler(request: Request, exc: Exception):
logger.error(f"系统异常: {str(exc)}\n{traceback.format_exc()}")
return R.error(code=500000, msg="系统异常")

View File

@@ -0,0 +1,9 @@
# exceptions.py
from app.enmus.exception import ProviderErrorEnum
class NoteError(Exception):
def __init__(self, message: str,code: ProviderErrorEnum) -> None:
super().__init__(message)
self.code=code
self.message = message

View File

@@ -1,5 +1,12 @@
# exceptions.py
class ConnectionTestError(Exception):
def __init__(self, message: str):
from app.enmus.exception import ProviderErrorEnum
class ProviderError(Exception):
def __init__(self, message: str,code: ProviderErrorEnum) -> None:
super().__init__(message)
self.message = message
self.code=code
self.message = message

View File

@@ -2,6 +2,9 @@ from typing import Optional, Union
from openai import OpenAI
from app.utils.logger import get_logger
logging= get_logger(__name__)
class OpenAICompatibleProvider:
def __init__(self, api_key: str, base_url: str, model: Union[str, None]=None):
self.client = OpenAI(api_key=api_key, base_url=base_url)
@@ -13,11 +16,16 @@ class OpenAICompatibleProvider:
@staticmethod
def test_connection(api_key: str, base_url: str) -> bool:
print(api_key)
try:
client = OpenAI(api_key=api_key, base_url=base_url)
client.models.list()
model = client.models.list()
# for segment in model:
# print(segment)
# print(model)
logging.info("连通性测试成功")
return True
except Exception as e:
print(f"Error connecting to OpenAI API: {e}")
logging.info(f"连通性测试失败:{e}")
# print(f"Error connecting to OpenAI API: {e}")
return False

View File

@@ -27,4 +27,6 @@ def get_cookie(platform: str):
@router.post("/update_downloader_cookie")
def update_cookie(data: CookieUpdateRequest):
cookie_manager.set(data.platform, data.cookie)
return {"message": "Cookie updated successfully"}
return R.success(
)

View File

@@ -31,10 +31,9 @@ def delete_model(model_id: int):
return R.error(f"删除模型失败: {e}")
@router.get("/model_list/{provider_id}")
def model_list(provider_id):
try:
return R.success(modelService.get_all_models_by_id(provider_id))
except Exception as e:
return R.error(e)
return R.success(modelService.get_all_models_by_id(provider_id))
@router.post("/models")
def create_model(data: CreateModelRequest):

View File

@@ -2,6 +2,7 @@
import json
import os
import uuid
from pathlib import Path
from typing import Optional
from urllib.parse import urlparse
@@ -10,7 +11,9 @@ from pydantic import BaseModel, validator, field_validator
from dataclasses import asdict
from app.db.video_task_dao import get_task_by_video
from app.enmus.exception import NoteErrorEnum
from app.enmus.note_enums import DownloadQuality
from app.exceptions.note import NoteError
from app.services.note import NoteGenerator, logger
from app.utils.response import ResponseWrapper as R
from app.utils.url_parser import extract_video_id
@@ -54,12 +57,13 @@ class VideoRequest(BaseModel):
if parsed.scheme in ("http", "https"):
# 是网络链接,继续用原有平台校验
if not is_supported_video_url(url):
raise ValueError("暂不支持该视频平台或链接格式无效")
raise NoteError(code=NoteErrorEnum.PLATFORM_NOT_SUPPORTED.code,
message=NoteErrorEnum.PLATFORM_NOT_SUPPORTED.message)
return v
NOTE_OUTPUT_DIR = "note_results"
NOTE_OUTPUT_DIR = os.getenv("NOTE_OUTPUT_DIR", "note_results")
UPLOAD_DIR = "uploads"
@@ -74,30 +78,32 @@ def run_note_task(task_id: str, video_url: str, platform: str, quality: Download
_format: list = None, style: str = None, extras: str = None, video_understanding: bool = False,
video_interval=0, grid_size=[]
):
try:
if not model_name or not provider_id:
raise HTTPException(status_code=400, detail="请选择模型和提供者")
note = NoteGenerator().generate(
video_url=video_url,
platform=platform,
quality=quality,
task_id=task_id,
model_name=model_name,
provider_id=provider_id,
link=link,
_format=_format,
style=style,
extras=extras,
screenshot=screenshot
, video_understanding=video_understanding,
video_interval=video_interval,
grid_size=grid_size
)
logger.info(f"Note generated: {task_id}")
save_note_to_file(task_id, note)
except Exception as e:
save_note_to_file(task_id, {"error": str(e)})
if not model_name or not provider_id:
raise HTTPException(status_code=400, detail="请选择模型和提供者")
note = NoteGenerator().generate(
video_url=video_url,
platform=platform,
quality=quality,
task_id=task_id,
model_name=model_name,
provider_id=provider_id,
link=link,
_format=_format,
style=style,
extras=extras,
screenshot=screenshot
, video_understanding=video_understanding,
video_interval=video_interval,
grid_size=grid_size
)
logger.info(f"Note generated: {task_id}")
if not note or not note.markdown:
logger.warning(f"任务 {task_id} 执行失败,跳过保存")
return
save_note_to_file(task_id, note)
@router.post('/delete_task')
@@ -135,7 +141,6 @@ def generate_note(data: VideoRequest, background_tasks: BackgroundTasks):
# msg='笔记已生成,请勿重复发起',
#
# )
if data.task_id:
# 如果传了task_id说明是重试
task_id = data.task_id

View File

@@ -2,7 +2,7 @@ from typing import Optional
from fastapi import APIRouter
from pydantic import BaseModel
from app.exceptions.provider import ConnectionTestError
from app.exceptions.provider import ProviderError
from app.models.model_config import ModelConfig
from app.services.model import ModelService
from app.utils.response import ResponseWrapper as R
@@ -88,9 +88,5 @@ def update_provider(data: ProviderUpdateRequest):
@router.post('/connect_test')
def gpt_connect_test(data: TestRequest):
try:
ModelService().connect_test(data.id)
return R.success(msg='连接成功')
except Exception as e:
print("捕获到异常类型:", type(e))
return R.error(msg=str(e))
ModelService().connect_test(data.id)
return R.success(msg='连接成功')

View File

@@ -1,12 +1,16 @@
from app.db.model_dao import insert_model, get_all_models, get_model_by_provider_and_name, delete_model
from app.db.provider_dao import get_enabled_providers
from app.exceptions.provider import ConnectionTestError
from app.enmus.exception import ProviderErrorEnum
from app.exceptions.provider import ProviderError
from app.gpt.gpt_factory import GPTFactory
from app.gpt.provider.OpenAI_compatible_provider import OpenAICompatibleProvider
from app.models.model_config import ModelConfig
from app.services.provider import ProviderService
from app.utils.logger import get_logger
logger=get_logger(__name__)
class ModelService:
@staticmethod
@@ -83,37 +87,38 @@ class ModelService:
provider = ProviderService.get_provider_by_id(provider_id)
models = ModelService.get_model_list(provider["id"], verbose=verbose)
model_list={
"models": models
print(type(models))
serializable_models = [m.dict() for m in models.data]
model_list = {
"models": serializable_models
}
logger.info(f"[{provider['name']}] 获取模型成功")
return model_list
except Exception as e:
print(f"[{provider_id}] 获取模型失败: {e}")
# print(f"[{provider_id}] 获取模型失败: {e}")
logger.error(f"[{provider_id}] 获取模型失败: {e}")
return []
@staticmethod
def connect_test(id: str) -> bool:
try:
provider = ProviderService.get_provider_by_id(id)
if provider:
if not provider.get('api_key'):
raise ConnectionTestError(f"供应商信息未找到,请先保存重试")
result = OpenAICompatibleProvider.test_connection(
api_key=provider.get('api_key'),
base_url=provider.get('base_url')
)
if result:
return True
else:
raise ConnectionTestError("请检查API Key 和 API 地址是否正确")
provider = ProviderService.get_provider_by_id(id)
if provider:
if not provider.get('api_key'):
raise ProviderError(code=ProviderErrorEnum.NOT_FOUND.code, message=ProviderErrorEnum.NOT_FOUND.message)
result = OpenAICompatibleProvider.test_connection(
api_key=provider.get('api_key'),
base_url=provider.get('base_url')
)
if result:
return True
else:
raise ProviderError(code=ProviderErrorEnum.WRONG_PARAMETER.code,message=ProviderErrorEnum.WRONG_PARAMETER.message)
raise ProviderError(code=ProviderErrorEnum.NOT_FOUND.code, message=ProviderErrorEnum.NOT_FOUND.message)
raise ConnectionTestError("供应商信息未找到,请先保存重试")
except Exception as e:
# 抛出业务异常,交由 Controller 处理
raise ConnectionTestError(f"{str(e)}") from e
@staticmethod
def delete_model_by_id( model_id: int) -> bool:

View File

@@ -1,75 +1,63 @@
import json
from dataclasses import asdict
from fastapi import HTTPException
from app.downloaders.local_downloader import LocalDownloader
from app.enmus.task_status_enums import TaskStatus
import logging
import os
from typing import Union, Optional
import re
from dataclasses import asdict
from pathlib import Path
from typing import List, Optional, Tuple, Union, Any
from pydantic import HttpUrl
from dotenv import load_dotenv
from app.db.video_task_dao import insert_video_task, delete_task_by_video
from app.downloaders.base import Downloader
from app.downloaders.bilibili_downloader import BilibiliDownloader
from app.downloaders.douyin_downloader import DouyinDownloader
from app.downloaders.youtube_downloader import YoutubeDownloader
from app.services.constant import SUPPORT_PLATFORM_MAP
from app.enmus.task_status_enums import TaskStatus
from app.enmus.exception import NoteErrorEnum, ProviderErrorEnum
from app.exceptions.note import NoteError
from app.exceptions.provider import ProviderError
from app.db.video_task_dao import delete_task_by_video, insert_video_task
from app.gpt.base import GPT
from app.gpt.deepseek_gpt import DeepSeekGPT
from app.gpt.gpt_factory import GPTFactory
from app.gpt.openai_gpt import OpenaiGPT
from app.gpt.qwen_gpt import QwenGPT
from app.models.audio_model import AudioDownloadResult
from app.models.gpt_model import GPTSource
from app.models.model_config import ModelConfig
from app.models.notes_model import NoteResult
from app.models.notes_model import AudioDownloadResult
from app.enmus.note_enums import DownloadQuality
from app.models.transcriber_model import TranscriptResult, TranscriptSegment
from app.services.constant import SUPPORT_PLATFORM_MAP
from app.services.provider import ProviderService
from app.transcriber.base import Transcriber
from app.transcriber.transcriber_provider import get_transcriber, _transcribers
from app.transcriber.whisper import WhisperTranscriber
import re
from app.utils.note_helper import replace_content_markers
from app.utils.status_code import StatusCode
from app.utils.video_helper import generate_screenshot
# from app.services.whisperer import transcribe_audio
# from app.services.gpt import summarize_text
from dotenv import load_dotenv
from app.utils.logger import get_logger
from app.utils.video_reader import VideoReader
from events import transcription_finished
from app.utils.video_helper import generate_screenshot
from app.utils.note_helper import replace_content_markers
from app.enmus.note_enums import DownloadQuality
logger = get_logger(__name__)
# 环境变量
load_dotenv()
api_path = os.getenv("API_BASE_URL", "http://localhost")
BACKEND_PORT = os.getenv("BACKEND_PORT", 8000)
NOTE_OUTPUT_DIR = Path(os.getenv("NOTE_OUTPUT_DIR", "note_results"))
NOTE_OUTPUT_DIR.mkdir(parents=True, exist_ok=True)
IMAGE_BASE_URL = os.getenv("IMAGE_BASE_URL", "/static/screenshots")
IMAGE_OUTPUT_DIR = os.getenv("OUT_DIR", "images")
BACKEND_BASE_URL = f"{api_path}:{BACKEND_PORT}"
output_dir = os.getenv('OUT_DIR')
image_base_url = os.getenv('IMAGE_BASE_URL')
logger.info("starting up")
NOTE_OUTPUT_DIR = "note_results"
logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO)
class NoteGenerator:
class States:
INIT = 'INIT'
PARSING = 'PARSING'
DOWNLOADING = 'DOWNLOADING'
TRANSCRIBING = 'TRANSCRIBING'
SUMMARIZING = 'SUMMARIZING'
SAVING = 'SAVING'
SUCCESS = 'SUCCESS'
FAILED = 'FAILED'
def __init__(self):
self.model_size: str = 'base'
self.device: Union[str, None] = None
self.transcriber_type = os.getenv('TRANSCRIBER_TYPE', 'fast-whisper')
self.transcriber = self.get_transcriber()
self.video_path = None
logger.info("初始化NoteGenerator")
import logging
logger = logging.getLogger(__name__)
self.transcriber_type = os.getenv("TRANSCRIBER_TYPE", "fast-whisper")
self.transcriber: Transcriber = self._init_transcriber()
self.video_img_urls = []
@staticmethod
def update_task_status(task_id: str, status: Union[str, TaskStatus], message: Optional[str] = None):
@@ -81,310 +69,179 @@ class NoteGenerator:
with open(path, "w", encoding="utf-8") as f:
json.dump(content, f, ensure_ascii=False, indent=2)
def get_gpt(self, model_name: str = None, provider_id: str = None) -> GPT:
def generate(
self,
video_url: Union[str, HttpUrl],
platform: str,
quality: DownloadQuality = DownloadQuality.medium,
task_id: Optional[str] = None,
model_name: Optional[str] = None,
provider_id: Optional[str] = None,
link: bool = False,
screenshot: bool = False,
_format: Optional[List[str]] = None,
style: Optional[str] = None,
extras: Optional[str] = None,
output_path: Optional[str] = None,
video_understanding: bool = False,
video_interval: int = 0,
grid_size: Optional[List[int]] = None,
) -> NoteResult | None:
self.task_id = task_id
self._change_state(self.States.INIT)
try:
self._change_state(self.States.PARSING)
downloader = self._get_downloader(platform)
gpt = self._get_gpt(model_name, provider_id)
self.audio_cache_file = NOTE_OUTPUT_DIR / f"{task_id}_audio.json"
self.transcript_cache_file = NOTE_OUTPUT_DIR / f"{task_id}_transcript.json"
self.markdown_cache_file = NOTE_OUTPUT_DIR / f"{task_id}_markdown.md"
self.audio_meta = self._download_audio_video(
downloader, video_url, quality, output_path,
screenshot, video_understanding, video_interval, grid_size or []
)
self.transcript = self._transcribe_audio()
self.markdown = self._summarize_text(
gpt, link, screenshot, _format or [], style, extras
)
self.markdown = self._post_process_markdown(
self.markdown, self.video_path, _format or [], self.audio_meta, platform
)
self._change_state(self.States.SAVING)
self._save_metadata(self.audio_meta.video_id, platform, task_id)
self._change_state(self.States.SUCCESS)
return NoteResult(markdown=self.markdown, transcript=self.transcript, audio_meta=self.audio_meta)
except Exception as e:
logger.exception(f"任务 {self.task_id} 失败: {e}")
self._change_state(self.States.FAILED, str(e))
return None
def _change_state(self, state: str, message: Optional[str] = None):
if not self.task_id:
return
NOTE_OUTPUT_DIR.mkdir(parents=True, exist_ok=True)
status_file = NOTE_OUTPUT_DIR / f"{self.task_id}.status.json"
data = {"status": state}
if message:
data["message"] = message
temp_file = status_file.with_suffix('.tmp')
with temp_file.open('w', encoding='utf-8') as f:
json.dump(data, f, ensure_ascii=False, indent=2)
temp_file.replace(status_file)
def _init_transcriber(self) -> Transcriber:
if self.transcriber_type not in _transcribers:
raise Exception(f"不支持的转写器:{self.transcriber_type}")
return get_transcriber(self.transcriber_type)
def _get_gpt(self, model_name: Optional[str], provider_id: Optional[str]) -> GPT:
provider = ProviderService.get_provider_by_id(provider_id)
if not provider:
logger.error(f"[get_gpt] 未找到对应的模型供应商: provider_id={provider_id}")
raise ValueError(f"未找到对应的模型供应商: provider_id={provider_id}")
gpt = GPTFactory().from_config(
ModelConfig(
api_key=provider.get('api_key'),
base_url=provider.get('base_url'),
model_name=model_name,
provider=provider.get('type'),
name=provider.get('name')
)
raise ProviderError(code=ProviderErrorEnum.NOT_FOUND, message=ProviderErrorEnum.NOT_FOUND.message)
config = ModelConfig(
api_key=provider["api_key"], base_url=provider["base_url"],
model_name=model_name, provider=provider["type"], name=provider["name"]
)
return gpt
return GPTFactory().from_config(config)
def get_downloader(self, platform: str) -> Downloader:
downloader = SUPPORT_PLATFORM_MAP[platform]
if downloader:
logger.info(f"使用{downloader}下载器")
return downloader
else:
logger.warning("不支持的平台")
raise ValueError(f"不支持的平台:{platform}")
def _get_downloader(self, platform: str) -> Downloader:
downloader_cls = SUPPORT_PLATFORM_MAP.get(platform)
if not downloader_cls:
raise NoteError(code=NoteErrorEnum.PLATFORM_NOT_SUPPORTED.code,
message=NoteErrorEnum.PLATFORM_NOT_SUPPORTED.message)
return downloader_cls
def get_transcriber(self) -> Transcriber:
'''
def _download_audio_video(self, downloader, video_url, quality, output_path,
screenshot, video_understanding, video_interval, grid_size):
self._change_state(self.States.DOWNLOADING)
:param transcriber: 选择的转义器
:return:
'''
if self.transcriber_type in _transcribers.keys():
logger.info(f"使用{self.transcriber_type}转义器")
return get_transcriber(transcriber_type=self.transcriber_type)
else:
logger.warning("不支持的转义器")
raise ValueError(f"不支持的转义器:{self.transcriber}")
need_video = screenshot or video_understanding
if need_video:
self.video_path = Path(downloader.download_video(video_url, output_path))
if grid_size:
self.video_img_urls = VideoReader(
video_path=str(self.video_path),
grid_size=tuple(grid_size),
frame_interval=video_interval,
unit_width=1280, unit_height=720,
save_quality=90,
).run()
def save_meta(self, video_id, platform, task_id):
logger.info(f"记录已经生成的数据信息")
insert_video_task(video_id=video_id, platform=platform, task_id=task_id)
if self.audio_cache_file.exists():
with open(self.audio_cache_file, "r", encoding="utf-8") as f:
data = json.load(f)
return AudioDownloadResult(**data)
def insert_screenshots_into_markdown(self, markdown: str, video_path: str, image_base_url: str,
output_dir: str, _format: list) -> str:
"""
扫描 markdown 中的 *Screenshot-xx:xx生成截图并插入 markdown 图片
:param markdown:
:param image_base_url: 最终返回给前端的路径前缀(如 /static/screenshots
"""
matches = self.extract_screenshot_timestamps(markdown)
new_markdown = markdown
audio = downloader.download(
video_url=video_url, quality=quality, output_dir=output_path, need_video=need_video
)
with open(self.audio_cache_file, "w", encoding="utf-8") as f:
json.dump(asdict(audio), f, ensure_ascii=False, indent=2)
return audio
logger.info(f"开始为笔记生成截图")
try:
for idx, (marker, ts) in enumerate(matches):
image_path = generate_screenshot(video_path, output_dir, ts, idx)
image_relative_path = os.path.join(image_base_url, os.path.basename(image_path)).replace("\\", "/")
image_url = f"/static/screenshots/{os.path.basename(image_path)}"
replacement = f"![]({image_url})"
new_markdown = new_markdown.replace(marker, replacement, 1)
def _transcribe_audio(self):
self._change_state(self.States.TRANSCRIBING)
if self.transcript_cache_file.exists():
with open(self.transcript_cache_file, "r", encoding="utf-8") as f:
data = json.load(f)
segments = [TranscriptSegment(**seg) for seg in data.get("segments", [])]
return TranscriptResult(language=data["language"], full_text=data["full_text"], segments=segments)
return new_markdown
except Exception as e:
logger.error(f"截图生成失败:{e}")
raise HTTPException(
status_code=500,
detail={
"code": StatusCode.DOWNLOAD_ERROR,
"msg": f"截图生成失败",
"error": str(e)
}
)
transcript = self.transcriber.transcript(self.audio_meta.file_path)
with open(self.transcript_cache_file, "w", encoding="utf-8") as f:
json.dump(asdict(transcript), f, ensure_ascii=False, indent=2)
return transcript
def _summarize_text(self, gpt, link, screenshot, formats, style, extras):
self._change_state(self.States.SUMMARIZING)
source = GPTSource(
title=self.audio_meta.title,
segment=self.transcript.segments,
tags=self.audio_meta.raw_info.get("tags", []),
screenshot=screenshot,
video_img_urls=self.video_img_urls,
link=link, _format=formats, style=style, extras=extras
)
markdown = gpt.summarize(source)
with open(self.markdown_cache_file, "w", encoding="utf-8") as f:
f.write(markdown)
return markdown
@staticmethod
def delete_note(video_id: str, platform: str):
logger.info(f"删除生成的笔记记录")
return delete_task_by_video(video_id, platform)
def _post_process_markdown(self, markdown, video_path, formats, audio_meta, platform):
if "screenshot" in formats and video_path:
markdown = self._insert_screenshots(markdown, video_path)
if "link" in formats:
markdown = replace_content_markers(markdown, video_id=audio_meta.video_id, platform=platform)
return markdown
import re
def extract_screenshot_timestamps(self, markdown: str) -> list[tuple[str, int]]:
"""
从 Markdown 中提取 Screenshot 时间标记(如 *Screenshot-03:39 或 Screenshot-[03:39]
并返回匹配文本和对应时间戳(秒)
"""
logger.info(f"开始提取截图时间标记")
def _insert_screenshots(self, markdown, video_path):
pattern = r"(?:\*Screenshot-(\d{2}):(\d{2})|Screenshot-\[(\d{2}):(\d{2})\])"
matches = list(re.finditer(pattern, markdown))
results = []
for match in matches:
matches = []
for match in re.finditer(pattern, markdown):
mm = match.group(1) or match.group(3)
ss = match.group(2) or match.group(4)
total_seconds = int(mm) * 60 + int(ss)
results.append((match.group(0), total_seconds))
return results
matches.append((match.group(0), int(mm)*60+int(ss)))
for idx, (marker, ts) in enumerate(matches):
img_path = generate_screenshot(str(video_path), str(IMAGE_OUTPUT_DIR), ts, idx)
filename = Path(img_path).name
img_url = f"{IMAGE_BASE_URL.rstrip('/')}/{filename}"
markdown = markdown.replace(marker, f"![]({img_url})", 1)
return markdown
def generate(
self,
video_url: Union[str, HttpUrl],
platform: str,
quality: DownloadQuality = DownloadQuality.medium,
task_id: Union[str, None] = None,
model_name: str = None,
provider_id: str = None,
link: bool = False,
screenshot: bool = False,
_format: list = None,
style: str = None,
extras: str = None,
path: Union[str, None] = None,
video_understanding: bool = False,
video_interval=0,
grid_size=[]
) -> NoteResult:
def _save_metadata(self, video_id: str, platform: str, task_id: str):
insert_video_task(video_id=video_id, platform=platform, task_id=task_id)
try:
logger.info(f"🎯 开始解析并生成笔记task_id={task_id}")
self.update_task_status(task_id, TaskStatus.PARSING)
downloader = self.get_downloader(platform)
gpt = self.get_gpt(model_name=model_name, provider_id=provider_id)
video_img_urls = []
audio_cache_path = os.path.join(NOTE_OUTPUT_DIR, f"{task_id}_audio.json")
transcript_cache_path = os.path.join(NOTE_OUTPUT_DIR, f"{task_id}_transcript.json")
markdown_cache_path = os.path.join(NOTE_OUTPUT_DIR, f"{task_id}_markdown.md")
# -------- 1. 下载音频 --------
try:
self.update_task_status(task_id, TaskStatus.DOWNLOADING)
# 加载音频缓存(如果存在)
audio = None
if os.path.exists(audio_cache_path):
logger.info(f"检测到已有音频缓存直接读取task_id={task_id}")
with open(audio_cache_path, "r", encoding="utf-8") as f:
audio_data = json.load(f)
audio = AudioDownloadResult(**audio_data)
# 需要视频的情况(截图 or 视频理解)
need_video = 'screenshot' in _format or video_understanding
if need_video:
try:
video_path = downloader.download_video(video_url)
self.video_path = video_path
logger.info(f"成功下载视频文件: {video_path}")
video_img_urls = VideoReader(
video_path=video_path,
grid_size=tuple(grid_size),
frame_interval=video_interval,
unit_width=1280,
unit_height=720,
save_quality=90,
).run()
except Exception as e:
logger.error(f"Error 下载视频失败task_id={task_id},错误信息:{e}")
self.update_task_status(task_id, TaskStatus.FAILED, message=f"下载音频失败:{e}")
raise HTTPException(
status_code=500,
detail={
"code": StatusCode.DOWNLOAD_ERROR,
"msg": f"下载视频失败task_id={task_id}",
"error": str(e)
}
)
# 没有音频缓存就下载音频(可能同时也带上视频)
if audio is None:
audio = downloader.download(
video_url=video_url,
quality=quality,
output_dir=path,
need_video='screenshot' in _format, # 注意这里只为了截图需要
)
with open(audio_cache_path, "w", encoding="utf-8") as f:
json.dump(asdict(audio), f, ensure_ascii=False, indent=2)
logger.info(f"音频下载并缓存成功task_id={task_id}")
except Exception as e:
logger.error(f"Error 下载音频失败task_id={task_id},错误信息:{e}")
self.update_task_status(task_id, TaskStatus.FAILED, message=f"下载音频失败:{e}")
raise HTTPException(
status_code=500,
detail={
"code": StatusCode.DOWNLOAD_ERROR,
"msg": f"下载音频失败task_id={task_id}",
"error": str(e)
}
)
# -------- 2. 转写文字 --------
try:
self.update_task_status(task_id, TaskStatus.TRANSCRIBING)
if os.path.exists(transcript_cache_path):
logger.info(f"检测到已有转写缓存直接读取task_id={task_id}")
try:
with open(transcript_cache_path, "r", encoding="utf-8") as f:
transcript_data = json.load(f)
transcript = TranscriptResult(
language=transcript_data["language"],
full_text=transcript_data["full_text"],
segments=[TranscriptSegment(**seg) for seg in transcript_data["segments"]]
)
except (json.JSONDecodeError, KeyError) as e:
logger.warning(f"Warning 读取转录缓存失败重新转录task_id={task_id},错误信息:{e}")
transcript: TranscriptResult = self.transcriber.transcript(file_path=audio.file_path)
with open(transcript_cache_path, "w", encoding="utf-8") as f:
json.dump(asdict(transcript), f, ensure_ascii=False, indent=2)
else:
transcript: TranscriptResult = self.transcriber.transcript(file_path=audio.file_path)
with open(transcript_cache_path, "w", encoding="utf-8") as f:
json.dump(asdict(transcript), f, ensure_ascii=False, indent=2)
logger.info(f"文字转写并缓存成功task_id={task_id}")
except Exception as e:
logger.error(f"Error 转写文字失败task_id={task_id},错误信息:{e}")
self.update_task_status(task_id, TaskStatus.FAILED, message=f"转写文字失败:{e}")
raise HTTPException(
status_code=500,
detail={
"code": StatusCode.GENERATE_ERROR, # =1003
"msg": f"转写文字失败task_id={task_id}",
"error": str(e)
}
)
# -------- 3. 总结内容 --------
try:
self.update_task_status(task_id, TaskStatus.SUMMARIZING)
# if os.path.exists(markdown_cache_path):
# logger.info(f"检测到已有总结缓存直接读取task_id={task_id}")
# with open(markdown_cache_path, "r", encoding="utf-8") as f:
# markdown = f.read()
# else:
source = GPTSource(
title=audio.title,
segment=transcript.segments,
tags=audio.raw_info.get('tags'),
screenshot=screenshot,
video_img_urls=video_img_urls,
link=link,
_format=_format,
style=style,
extras=extras
)
markdown: str = gpt.summarize(source)
with open(markdown_cache_path, "w", encoding="utf-8") as f:
f.write(markdown)
logger.info(f"GPT总结并缓存成功task_id={task_id}")
except Exception as e:
logger.error(f"Error 总结内容失败task_id={task_id},错误信息:{e}")
self.update_task_status(task_id, TaskStatus.FAILED, message=f"总结内容失败:{e}")
raise HTTPException(
status_code=500,
detail={
"code": StatusCode.GENERATE_ERROR, # =1003
"msg": f"总结内容失败task_id={task_id}",
"error": str(e)
}
)
# -------- 4. 插入截图 --------
if _format and 'screenshot' in _format:
try:
markdown = self.insert_screenshots_into_markdown(markdown, self.video_path, image_base_url,
output_dir, _format)
except Exception as e:
logger.warning(f"Warning 插入截图失败跳过处理task_id={task_id},错误信息:{e}")
if _format and 'link' in _format:
try:
markdown = replace_content_markers(markdown, video_id=audio.video_id, platform=platform)
except Exception as e:
logger.warning(f"Warning 插入链接失败跳过处理task_id={task_id},错误信息:{e}")
# 注意:截图失败不终止整体流程
# -------- 5. 保存数据库记录 --------
self.update_task_status(task_id, TaskStatus.SAVING)
self.save_meta(video_id=audio.video_id, platform=platform, task_id=task_id)
# -------- 6. 完成 --------
self.update_task_status(task_id, TaskStatus.SUCCESS)
logger.info(f"succeed 笔记生成成功task_id={task_id}")
# TODO :改为前端一键清除缓存
# if platform != 'local':
# transcription_finished.send({
# "file_path": audio.file_path,
# })
return NoteResult(
markdown=markdown,
transcript=transcript,
audio_meta=audio
)
except Exception as e:
logger.error(f"Error 笔记生成流程异常终止task_id={task_id},错误信息:{e}")
self.update_task_status(task_id, TaskStatus.FAILED, message=str(e))
# 返回结构化错误信息给前端(可以用于日志 + 显示 + 错误定位)
raise HTTPException(
status_code=500,
detail={
"code": StatusCode.FAIL,
"msg": f"笔记生成流程异常终止task_id={task_id}",
"error": str(e)
}
)
@staticmethod
def delete_note(video_id: str, platform: str) -> int:
return delete_task_by_video(video_id, platform)

View File

@@ -1,18 +1,24 @@
from fastapi.responses import JSONResponse
from app.utils.status_code import StatusCode
from pydantic import BaseModel
from typing import Optional, Any
from fastapi.responses import JSONResponse
class ResponseWrapper:
@staticmethod
def success(data=None, msg="success", code=StatusCode.SUCCESS):
return {
"code": int(code),
def success(data=None, msg="success", code=0):
return JSONResponse(content={
"code": code,
"msg": msg,
"data": data
}
})
@staticmethod
def error(msg="error", code=StatusCode.FAIL, data=None):
return {
"code": int(code),
def error(msg="error", code=500, data=None):
return JSONResponse(content={
"code": code,
"msg": msg,
"data": data
}
})

View File

@@ -4,7 +4,7 @@ import uvicorn
from starlette.staticfiles import StaticFiles
from dotenv import load_dotenv
from app.core.exception_handlers import register_exception_handlers
from app.exceptions.exception_handlers import register_exception_handlers
from app.db.model_dao import init_model_table
from app.db.provider_dao import init_provider_table
from app.utils.logger import get_logger
@@ -33,11 +33,14 @@ if not os.path.exists(out_dir):
os.makedirs(out_dir)
app = create_app()
register_exception_handlers(app)
app.mount(static_path, StaticFiles(directory=static_dir), name="static")
app.mount("/uploads", StaticFiles(directory=uploads_dir), name="uploads")
@app.on_event("startup")
async def startup_event():
register_exception_handlers(app)
register_handler()
ensure_ffmpeg_or_raise()
register_handler()
@@ -46,8 +49,9 @@ async def startup_event():
init_provider_table()
init_model_table()
if __name__ == "__main__":
port = int(os.getenv("BACKEND_PORT", 8000))
host = os.getenv("BACKEND_HOST", "0.0.0.0")
logger.info(f"Starting server on {host}:{port}")
uvicorn.run("main:app", host=host, port=port, reload=False)
uvicorn.run(app, host=host, port=port, reload=False)

View File

@@ -1,6 +1,6 @@
server {
listen 80;
client_max_body_size 10G;
# 所有非 /api 请求全部代理给 frontend 容器
location / {
proxy_pass http://frontend:80;
@@ -11,8 +11,6 @@ server {
proxy_pass http://backend:8000;
proxy_set_header Host $host;
proxy_set_header X-Real-IP $remote_addr;
client_max_body_size 10G;
client_body_buffer_size 128k;
}
location /static/ {