refactor(backend): 重构后端异常处理和模型管理

- 新增自定义异常类 BizException、NoteError 和 ProviderError
- 优化了模型管理相关的逻辑,包括加载、删除和测试连接等功能
- 改进了 Douyin 下载器的错误处理
- 调整了任务重试逻辑和笔记生成的异常处理- 更新了相关组件和页面以适应新的异常处理机制
This commit is contained in:
JefferyHcool
2025-06-06 21:30:23 +08:00
parent df5c0f771a
commit 8b1bc54f2d
34 changed files with 661 additions and 660 deletions

View File

@@ -1,75 +1,63 @@
import json
from dataclasses import asdict
from fastapi import HTTPException
from app.downloaders.local_downloader import LocalDownloader
from app.enmus.task_status_enums import TaskStatus
import logging
import os
from typing import Union, Optional
import re
from dataclasses import asdict
from pathlib import Path
from typing import List, Optional, Tuple, Union, Any
from pydantic import HttpUrl
from dotenv import load_dotenv
from app.db.video_task_dao import insert_video_task, delete_task_by_video
from app.downloaders.base import Downloader
from app.downloaders.bilibili_downloader import BilibiliDownloader
from app.downloaders.douyin_downloader import DouyinDownloader
from app.downloaders.youtube_downloader import YoutubeDownloader
from app.services.constant import SUPPORT_PLATFORM_MAP
from app.enmus.task_status_enums import TaskStatus
from app.enmus.exception import NoteErrorEnum, ProviderErrorEnum
from app.exceptions.note import NoteError
from app.exceptions.provider import ProviderError
from app.db.video_task_dao import delete_task_by_video, insert_video_task
from app.gpt.base import GPT
from app.gpt.deepseek_gpt import DeepSeekGPT
from app.gpt.gpt_factory import GPTFactory
from app.gpt.openai_gpt import OpenaiGPT
from app.gpt.qwen_gpt import QwenGPT
from app.models.audio_model import AudioDownloadResult
from app.models.gpt_model import GPTSource
from app.models.model_config import ModelConfig
from app.models.notes_model import NoteResult
from app.models.notes_model import AudioDownloadResult
from app.enmus.note_enums import DownloadQuality
from app.models.transcriber_model import TranscriptResult, TranscriptSegment
from app.services.constant import SUPPORT_PLATFORM_MAP
from app.services.provider import ProviderService
from app.transcriber.base import Transcriber
from app.transcriber.transcriber_provider import get_transcriber, _transcribers
from app.transcriber.whisper import WhisperTranscriber
import re
from app.utils.note_helper import replace_content_markers
from app.utils.status_code import StatusCode
from app.utils.video_helper import generate_screenshot
# from app.services.whisperer import transcribe_audio
# from app.services.gpt import summarize_text
from dotenv import load_dotenv
from app.utils.logger import get_logger
from app.utils.video_reader import VideoReader
from events import transcription_finished
from app.utils.video_helper import generate_screenshot
from app.utils.note_helper import replace_content_markers
from app.enmus.note_enums import DownloadQuality
logger = get_logger(__name__)
# 环境变量
load_dotenv()
api_path = os.getenv("API_BASE_URL", "http://localhost")
BACKEND_PORT = os.getenv("BACKEND_PORT", 8000)
NOTE_OUTPUT_DIR = Path(os.getenv("NOTE_OUTPUT_DIR", "note_results"))
NOTE_OUTPUT_DIR.mkdir(parents=True, exist_ok=True)
IMAGE_BASE_URL = os.getenv("IMAGE_BASE_URL", "/static/screenshots")
IMAGE_OUTPUT_DIR = os.getenv("OUT_DIR", "images")
BACKEND_BASE_URL = f"{api_path}:{BACKEND_PORT}"
output_dir = os.getenv('OUT_DIR')
image_base_url = os.getenv('IMAGE_BASE_URL')
logger.info("starting up")
NOTE_OUTPUT_DIR = "note_results"
logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO)
class NoteGenerator:
class States:
INIT = 'INIT'
PARSING = 'PARSING'
DOWNLOADING = 'DOWNLOADING'
TRANSCRIBING = 'TRANSCRIBING'
SUMMARIZING = 'SUMMARIZING'
SAVING = 'SAVING'
SUCCESS = 'SUCCESS'
FAILED = 'FAILED'
def __init__(self):
self.model_size: str = 'base'
self.device: Union[str, None] = None
self.transcriber_type = os.getenv('TRANSCRIBER_TYPE', 'fast-whisper')
self.transcriber = self.get_transcriber()
self.video_path = None
logger.info("初始化NoteGenerator")
import logging
logger = logging.getLogger(__name__)
self.transcriber_type = os.getenv("TRANSCRIBER_TYPE", "fast-whisper")
self.transcriber: Transcriber = self._init_transcriber()
self.video_img_urls = []
@staticmethod
def update_task_status(task_id: str, status: Union[str, TaskStatus], message: Optional[str] = None):
@@ -81,310 +69,179 @@ class NoteGenerator:
with open(path, "w", encoding="utf-8") as f:
json.dump(content, f, ensure_ascii=False, indent=2)
def get_gpt(self, model_name: str = None, provider_id: str = None) -> GPT:
def generate(
self,
video_url: Union[str, HttpUrl],
platform: str,
quality: DownloadQuality = DownloadQuality.medium,
task_id: Optional[str] = None,
model_name: Optional[str] = None,
provider_id: Optional[str] = None,
link: bool = False,
screenshot: bool = False,
_format: Optional[List[str]] = None,
style: Optional[str] = None,
extras: Optional[str] = None,
output_path: Optional[str] = None,
video_understanding: bool = False,
video_interval: int = 0,
grid_size: Optional[List[int]] = None,
) -> NoteResult | None:
self.task_id = task_id
self._change_state(self.States.INIT)
try:
self._change_state(self.States.PARSING)
downloader = self._get_downloader(platform)
gpt = self._get_gpt(model_name, provider_id)
self.audio_cache_file = NOTE_OUTPUT_DIR / f"{task_id}_audio.json"
self.transcript_cache_file = NOTE_OUTPUT_DIR / f"{task_id}_transcript.json"
self.markdown_cache_file = NOTE_OUTPUT_DIR / f"{task_id}_markdown.md"
self.audio_meta = self._download_audio_video(
downloader, video_url, quality, output_path,
screenshot, video_understanding, video_interval, grid_size or []
)
self.transcript = self._transcribe_audio()
self.markdown = self._summarize_text(
gpt, link, screenshot, _format or [], style, extras
)
self.markdown = self._post_process_markdown(
self.markdown, self.video_path, _format or [], self.audio_meta, platform
)
self._change_state(self.States.SAVING)
self._save_metadata(self.audio_meta.video_id, platform, task_id)
self._change_state(self.States.SUCCESS)
return NoteResult(markdown=self.markdown, transcript=self.transcript, audio_meta=self.audio_meta)
except Exception as e:
logger.exception(f"任务 {self.task_id} 失败: {e}")
self._change_state(self.States.FAILED, str(e))
return None
def _change_state(self, state: str, message: Optional[str] = None):
if not self.task_id:
return
NOTE_OUTPUT_DIR.mkdir(parents=True, exist_ok=True)
status_file = NOTE_OUTPUT_DIR / f"{self.task_id}.status.json"
data = {"status": state}
if message:
data["message"] = message
temp_file = status_file.with_suffix('.tmp')
with temp_file.open('w', encoding='utf-8') as f:
json.dump(data, f, ensure_ascii=False, indent=2)
temp_file.replace(status_file)
def _init_transcriber(self) -> Transcriber:
if self.transcriber_type not in _transcribers:
raise Exception(f"不支持的转写器:{self.transcriber_type}")
return get_transcriber(self.transcriber_type)
def _get_gpt(self, model_name: Optional[str], provider_id: Optional[str]) -> GPT:
provider = ProviderService.get_provider_by_id(provider_id)
if not provider:
logger.error(f"[get_gpt] 未找到对应的模型供应商: provider_id={provider_id}")
raise ValueError(f"未找到对应的模型供应商: provider_id={provider_id}")
gpt = GPTFactory().from_config(
ModelConfig(
api_key=provider.get('api_key'),
base_url=provider.get('base_url'),
model_name=model_name,
provider=provider.get('type'),
name=provider.get('name')
)
raise ProviderError(code=ProviderErrorEnum.NOT_FOUND, message=ProviderErrorEnum.NOT_FOUND.message)
config = ModelConfig(
api_key=provider["api_key"], base_url=provider["base_url"],
model_name=model_name, provider=provider["type"], name=provider["name"]
)
return gpt
return GPTFactory().from_config(config)
def get_downloader(self, platform: str) -> Downloader:
downloader = SUPPORT_PLATFORM_MAP[platform]
if downloader:
logger.info(f"使用{downloader}下载器")
return downloader
else:
logger.warning("不支持的平台")
raise ValueError(f"不支持的平台:{platform}")
def _get_downloader(self, platform: str) -> Downloader:
downloader_cls = SUPPORT_PLATFORM_MAP.get(platform)
if not downloader_cls:
raise NoteError(code=NoteErrorEnum.PLATFORM_NOT_SUPPORTED.code,
message=NoteErrorEnum.PLATFORM_NOT_SUPPORTED.message)
return downloader_cls
def get_transcriber(self) -> Transcriber:
'''
def _download_audio_video(self, downloader, video_url, quality, output_path,
screenshot, video_understanding, video_interval, grid_size):
self._change_state(self.States.DOWNLOADING)
:param transcriber: 选择的转义器
:return:
'''
if self.transcriber_type in _transcribers.keys():
logger.info(f"使用{self.transcriber_type}转义器")
return get_transcriber(transcriber_type=self.transcriber_type)
else:
logger.warning("不支持的转义器")
raise ValueError(f"不支持的转义器:{self.transcriber}")
need_video = screenshot or video_understanding
if need_video:
self.video_path = Path(downloader.download_video(video_url, output_path))
if grid_size:
self.video_img_urls = VideoReader(
video_path=str(self.video_path),
grid_size=tuple(grid_size),
frame_interval=video_interval,
unit_width=1280, unit_height=720,
save_quality=90,
).run()
def save_meta(self, video_id, platform, task_id):
logger.info(f"记录已经生成的数据信息")
insert_video_task(video_id=video_id, platform=platform, task_id=task_id)
if self.audio_cache_file.exists():
with open(self.audio_cache_file, "r", encoding="utf-8") as f:
data = json.load(f)
return AudioDownloadResult(**data)
def insert_screenshots_into_markdown(self, markdown: str, video_path: str, image_base_url: str,
output_dir: str, _format: list) -> str:
"""
扫描 markdown 中的 *Screenshot-xx:xx生成截图并插入 markdown 图片
:param markdown:
:param image_base_url: 最终返回给前端的路径前缀(如 /static/screenshots
"""
matches = self.extract_screenshot_timestamps(markdown)
new_markdown = markdown
audio = downloader.download(
video_url=video_url, quality=quality, output_dir=output_path, need_video=need_video
)
with open(self.audio_cache_file, "w", encoding="utf-8") as f:
json.dump(asdict(audio), f, ensure_ascii=False, indent=2)
return audio
logger.info(f"开始为笔记生成截图")
try:
for idx, (marker, ts) in enumerate(matches):
image_path = generate_screenshot(video_path, output_dir, ts, idx)
image_relative_path = os.path.join(image_base_url, os.path.basename(image_path)).replace("\\", "/")
image_url = f"/static/screenshots/{os.path.basename(image_path)}"
replacement = f"![]({image_url})"
new_markdown = new_markdown.replace(marker, replacement, 1)
def _transcribe_audio(self):
self._change_state(self.States.TRANSCRIBING)
if self.transcript_cache_file.exists():
with open(self.transcript_cache_file, "r", encoding="utf-8") as f:
data = json.load(f)
segments = [TranscriptSegment(**seg) for seg in data.get("segments", [])]
return TranscriptResult(language=data["language"], full_text=data["full_text"], segments=segments)
return new_markdown
except Exception as e:
logger.error(f"截图生成失败:{e}")
raise HTTPException(
status_code=500,
detail={
"code": StatusCode.DOWNLOAD_ERROR,
"msg": f"截图生成失败",
"error": str(e)
}
)
transcript = self.transcriber.transcript(self.audio_meta.file_path)
with open(self.transcript_cache_file, "w", encoding="utf-8") as f:
json.dump(asdict(transcript), f, ensure_ascii=False, indent=2)
return transcript
def _summarize_text(self, gpt, link, screenshot, formats, style, extras):
self._change_state(self.States.SUMMARIZING)
source = GPTSource(
title=self.audio_meta.title,
segment=self.transcript.segments,
tags=self.audio_meta.raw_info.get("tags", []),
screenshot=screenshot,
video_img_urls=self.video_img_urls,
link=link, _format=formats, style=style, extras=extras
)
markdown = gpt.summarize(source)
with open(self.markdown_cache_file, "w", encoding="utf-8") as f:
f.write(markdown)
return markdown
@staticmethod
def delete_note(video_id: str, platform: str):
logger.info(f"删除生成的笔记记录")
return delete_task_by_video(video_id, platform)
def _post_process_markdown(self, markdown, video_path, formats, audio_meta, platform):
if "screenshot" in formats and video_path:
markdown = self._insert_screenshots(markdown, video_path)
if "link" in formats:
markdown = replace_content_markers(markdown, video_id=audio_meta.video_id, platform=platform)
return markdown
import re
def extract_screenshot_timestamps(self, markdown: str) -> list[tuple[str, int]]:
"""
从 Markdown 中提取 Screenshot 时间标记(如 *Screenshot-03:39 或 Screenshot-[03:39]
并返回匹配文本和对应时间戳(秒)
"""
logger.info(f"开始提取截图时间标记")
def _insert_screenshots(self, markdown, video_path):
pattern = r"(?:\*Screenshot-(\d{2}):(\d{2})|Screenshot-\[(\d{2}):(\d{2})\])"
matches = list(re.finditer(pattern, markdown))
results = []
for match in matches:
matches = []
for match in re.finditer(pattern, markdown):
mm = match.group(1) or match.group(3)
ss = match.group(2) or match.group(4)
total_seconds = int(mm) * 60 + int(ss)
results.append((match.group(0), total_seconds))
return results
matches.append((match.group(0), int(mm)*60+int(ss)))
for idx, (marker, ts) in enumerate(matches):
img_path = generate_screenshot(str(video_path), str(IMAGE_OUTPUT_DIR), ts, idx)
filename = Path(img_path).name
img_url = f"{IMAGE_BASE_URL.rstrip('/')}/{filename}"
markdown = markdown.replace(marker, f"![]({img_url})", 1)
return markdown
def generate(
self,
video_url: Union[str, HttpUrl],
platform: str,
quality: DownloadQuality = DownloadQuality.medium,
task_id: Union[str, None] = None,
model_name: str = None,
provider_id: str = None,
link: bool = False,
screenshot: bool = False,
_format: list = None,
style: str = None,
extras: str = None,
path: Union[str, None] = None,
video_understanding: bool = False,
video_interval=0,
grid_size=[]
) -> NoteResult:
def _save_metadata(self, video_id: str, platform: str, task_id: str):
insert_video_task(video_id=video_id, platform=platform, task_id=task_id)
try:
logger.info(f"🎯 开始解析并生成笔记task_id={task_id}")
self.update_task_status(task_id, TaskStatus.PARSING)
downloader = self.get_downloader(platform)
gpt = self.get_gpt(model_name=model_name, provider_id=provider_id)
video_img_urls = []
audio_cache_path = os.path.join(NOTE_OUTPUT_DIR, f"{task_id}_audio.json")
transcript_cache_path = os.path.join(NOTE_OUTPUT_DIR, f"{task_id}_transcript.json")
markdown_cache_path = os.path.join(NOTE_OUTPUT_DIR, f"{task_id}_markdown.md")
# -------- 1. 下载音频 --------
try:
self.update_task_status(task_id, TaskStatus.DOWNLOADING)
# 加载音频缓存(如果存在)
audio = None
if os.path.exists(audio_cache_path):
logger.info(f"检测到已有音频缓存直接读取task_id={task_id}")
with open(audio_cache_path, "r", encoding="utf-8") as f:
audio_data = json.load(f)
audio = AudioDownloadResult(**audio_data)
# 需要视频的情况(截图 or 视频理解)
need_video = 'screenshot' in _format or video_understanding
if need_video:
try:
video_path = downloader.download_video(video_url)
self.video_path = video_path
logger.info(f"成功下载视频文件: {video_path}")
video_img_urls = VideoReader(
video_path=video_path,
grid_size=tuple(grid_size),
frame_interval=video_interval,
unit_width=1280,
unit_height=720,
save_quality=90,
).run()
except Exception as e:
logger.error(f"Error 下载视频失败task_id={task_id},错误信息:{e}")
self.update_task_status(task_id, TaskStatus.FAILED, message=f"下载音频失败:{e}")
raise HTTPException(
status_code=500,
detail={
"code": StatusCode.DOWNLOAD_ERROR,
"msg": f"下载视频失败task_id={task_id}",
"error": str(e)
}
)
# 没有音频缓存就下载音频(可能同时也带上视频)
if audio is None:
audio = downloader.download(
video_url=video_url,
quality=quality,
output_dir=path,
need_video='screenshot' in _format, # 注意这里只为了截图需要
)
with open(audio_cache_path, "w", encoding="utf-8") as f:
json.dump(asdict(audio), f, ensure_ascii=False, indent=2)
logger.info(f"音频下载并缓存成功task_id={task_id}")
except Exception as e:
logger.error(f"Error 下载音频失败task_id={task_id},错误信息:{e}")
self.update_task_status(task_id, TaskStatus.FAILED, message=f"下载音频失败:{e}")
raise HTTPException(
status_code=500,
detail={
"code": StatusCode.DOWNLOAD_ERROR,
"msg": f"下载音频失败task_id={task_id}",
"error": str(e)
}
)
# -------- 2. 转写文字 --------
try:
self.update_task_status(task_id, TaskStatus.TRANSCRIBING)
if os.path.exists(transcript_cache_path):
logger.info(f"检测到已有转写缓存直接读取task_id={task_id}")
try:
with open(transcript_cache_path, "r", encoding="utf-8") as f:
transcript_data = json.load(f)
transcript = TranscriptResult(
language=transcript_data["language"],
full_text=transcript_data["full_text"],
segments=[TranscriptSegment(**seg) for seg in transcript_data["segments"]]
)
except (json.JSONDecodeError, KeyError) as e:
logger.warning(f"Warning 读取转录缓存失败重新转录task_id={task_id},错误信息:{e}")
transcript: TranscriptResult = self.transcriber.transcript(file_path=audio.file_path)
with open(transcript_cache_path, "w", encoding="utf-8") as f:
json.dump(asdict(transcript), f, ensure_ascii=False, indent=2)
else:
transcript: TranscriptResult = self.transcriber.transcript(file_path=audio.file_path)
with open(transcript_cache_path, "w", encoding="utf-8") as f:
json.dump(asdict(transcript), f, ensure_ascii=False, indent=2)
logger.info(f"文字转写并缓存成功task_id={task_id}")
except Exception as e:
logger.error(f"Error 转写文字失败task_id={task_id},错误信息:{e}")
self.update_task_status(task_id, TaskStatus.FAILED, message=f"转写文字失败:{e}")
raise HTTPException(
status_code=500,
detail={
"code": StatusCode.GENERATE_ERROR, # =1003
"msg": f"转写文字失败task_id={task_id}",
"error": str(e)
}
)
# -------- 3. 总结内容 --------
try:
self.update_task_status(task_id, TaskStatus.SUMMARIZING)
# if os.path.exists(markdown_cache_path):
# logger.info(f"检测到已有总结缓存直接读取task_id={task_id}")
# with open(markdown_cache_path, "r", encoding="utf-8") as f:
# markdown = f.read()
# else:
source = GPTSource(
title=audio.title,
segment=transcript.segments,
tags=audio.raw_info.get('tags'),
screenshot=screenshot,
video_img_urls=video_img_urls,
link=link,
_format=_format,
style=style,
extras=extras
)
markdown: str = gpt.summarize(source)
with open(markdown_cache_path, "w", encoding="utf-8") as f:
f.write(markdown)
logger.info(f"GPT总结并缓存成功task_id={task_id}")
except Exception as e:
logger.error(f"Error 总结内容失败task_id={task_id},错误信息:{e}")
self.update_task_status(task_id, TaskStatus.FAILED, message=f"总结内容失败:{e}")
raise HTTPException(
status_code=500,
detail={
"code": StatusCode.GENERATE_ERROR, # =1003
"msg": f"总结内容失败task_id={task_id}",
"error": str(e)
}
)
# -------- 4. 插入截图 --------
if _format and 'screenshot' in _format:
try:
markdown = self.insert_screenshots_into_markdown(markdown, self.video_path, image_base_url,
output_dir, _format)
except Exception as e:
logger.warning(f"Warning 插入截图失败跳过处理task_id={task_id},错误信息:{e}")
if _format and 'link' in _format:
try:
markdown = replace_content_markers(markdown, video_id=audio.video_id, platform=platform)
except Exception as e:
logger.warning(f"Warning 插入链接失败跳过处理task_id={task_id},错误信息:{e}")
# 注意:截图失败不终止整体流程
# -------- 5. 保存数据库记录 --------
self.update_task_status(task_id, TaskStatus.SAVING)
self.save_meta(video_id=audio.video_id, platform=platform, task_id=task_id)
# -------- 6. 完成 --------
self.update_task_status(task_id, TaskStatus.SUCCESS)
logger.info(f"succeed 笔记生成成功task_id={task_id}")
# TODO :改为前端一键清除缓存
# if platform != 'local':
# transcription_finished.send({
# "file_path": audio.file_path,
# })
return NoteResult(
markdown=markdown,
transcript=transcript,
audio_meta=audio
)
except Exception as e:
logger.error(f"Error 笔记生成流程异常终止task_id={task_id},错误信息:{e}")
self.update_task_status(task_id, TaskStatus.FAILED, message=str(e))
# 返回结构化错误信息给前端(可以用于日志 + 显示 + 错误定位)
raise HTTPException(
status_code=500,
detail={
"code": StatusCode.FAIL,
"msg": f"笔记生成流程异常终止task_id={task_id}",
"error": str(e)
}
)
@staticmethod
def delete_note(video_id: str, platform: str) -> int:
return delete_task_by_video(video_id, platform)