refactor(backend): 重构后端异常处理和模型管理

- 新增自定义异常类 BizException、NoteError 和 ProviderError
- 优化了模型管理相关的逻辑,包括加载、删除和测试连接等功能
- 改进了 Douyin 下载器的错误处理
- 调整了任务重试逻辑和笔记生成的异常处理- 更新了相关组件和页面以适应新的异常处理机制
This commit is contained in:
JefferyHcool
2025-06-06 21:30:23 +08:00
parent df5c0f771a
commit 8b1bc54f2d
34 changed files with 661 additions and 660 deletions

View File

@@ -1,11 +1,14 @@
from fastapi import FastAPI
from .routers import note, provider, model, config
def create_app() -> FastAPI:
app = FastAPI(title="BiliNote")
app.include_router(note.router, prefix="/api")
app.include_router(provider.router, prefix="/api")
app.include_router(model.router,prefix="/api")
app.include_router(config.router, prefix="/api")
return app

View File

@@ -1,38 +0,0 @@
# app/core/exception_handlers.py
from fastapi import Request, HTTPException
from fastapi.exceptions import RequestValidationError
from fastapi.responses import JSONResponse
from app.utils.logger import get_logger
from app.utils.response import ResponseWrapper
from app.utils.status_code import StatusCode
logger = get_logger(__name__)
def register_exception_handlers(app):
@app.exception_handler(RequestValidationError)
async def validation_exception_handler(request: Request, exc: RequestValidationError):
errors = []
for err in exc.errors():
loc = err.get("loc", [])
field = loc[-1] if loc else "body"
msg = err.get("msg", "参数不合法")
errors.append({"field": field, "error": msg})
return JSONResponse(
status_code=400,
content=ResponseWrapper.error(msg="参数验证失败", code=StatusCode.PARAM_ERROR, data=errors)
)
@app.exception_handler(HTTPException)
async def http_exception_handler(request: Request, exc: HTTPException):
return JSONResponse(
status_code=exc.status_code,
content=ResponseWrapper.error(msg=str(exc.detail), code=StatusCode.FAIL)
)
@app.exception_handler(Exception)
async def global_exception_handler(request: Request, exc: Exception):
logger.exception(f"服务器内部错误: {exc}")
return JSONResponse(
status_code=500,
content=ResponseWrapper.error(msg="服务器内部错误", code=StatusCode.FAIL, data=str(exc))
)

View File

@@ -136,7 +136,8 @@ def get_provider_by_name(name: str):
if row is None:
logger.info(f"Provider not found: {name}")
return None
logger.info(f"Provider found: {row}")
logger.info(f"Provider found: {row[0]}")
return row
except Exception as e:
logger.error(f"Failed to get provider by name: {e}")
@@ -155,7 +156,7 @@ def get_provider_by_id(id: int):
if row is None:
logger.info(f"Provider not found: {id}")
return None
logger.info(f"Provider found: {row}")
logger.info(f"Provider found: {row[0]}")
return row
except Exception as e:
logger.error(f"Failed to get provider by id: {e}")
@@ -173,7 +174,7 @@ def get_all_providers():
if rows is None:
logger.info("No providers found")
return None
logger.info(f"Providers found: {rows}")
logger.info(f"Providers found total {len(rows) }")
return rows
except Exception as e:
logger.error(f"Failed to get all providers: {e}")

View File

@@ -145,53 +145,59 @@ class DouyinDownloader(Downloader):
return ""
def gen_real_msToken(self) -> str:
payload = json.dumps(
{
"magic": self.ms_token_config["magic"],
"version": self.ms_token_config["version"],
"dataType": self.ms_token_config["dataType"],
"strData": self.ms_token_config["strData"],
"tspFromClient": get_timestamp(),
try:
payload = json.dumps(
{
"magic": self.ms_token_config["magic"],
"version": self.ms_token_config["version"],
"dataType": self.ms_token_config["dataType"],
"strData": self.ms_token_config["strData"],
"tspFromClient": get_timestamp(),
}
)
headers = {
"User-Agent": self.headers_config["User-Agent"],
"Content-Type": "application/json",
}
)
headers = {
"User-Agent": self.headers_config["User-Agent"],
"Content-Type": "application/json",
}
transport = httpx.HTTPTransport(retries=5)
with httpx.Client(transport=transport) as client:
try:
response = client.post(
self.ms_token_config["url"], content=payload, headers=headers
)
response.raise_for_status()
transport = httpx.HTTPTransport(retries=5)
with httpx.Client(transport=transport) as client:
try:
response = client.post(
self.ms_token_config["url"], content=payload, headers=headers
)
response.raise_for_status()
msToken = str(httpx.Cookies(response.cookies).get("msToken"))
if len(msToken) not in [120, 128]:
raise ValueError("响应内容:{0} Douyin msToken API 的响应内容不符合要求。".format(msToken))
msToken = str(httpx.Cookies(response.cookies).get("msToken"))
if len(msToken) not in [120, 128]:
raise ValueError("响应内容:{0} Douyin msToken API 的响应内容不符合要求。".format(msToken))
return msToken
except Exception as e:
raise ValueError("Douyin msToken API 请求失败:{0}".format(e))
return msToken
except Exception as e:
raise ValueError("Douyin msToken API 请求失败:{0}".format(e))
except Exception as e:
raise ValueError("Douyin msToken API{0}".format(e))
def fetch_video_info(self, video_url: str) -> json:
aweme_id = self.extract_video_id(video_url)
kwargs = self.headers_config
print("kwargs:", kwargs)
base_params = BaseRequestModel().model_dump()
base_params["msToken"] = self.gen_real_msToken()
base_params["aweme_id"] = aweme_id
bogus = ABogus()
ab_value = bogus.get_value(base_params)
a_bogus = quote(ab_value, safe='')
print(base_params)
query_str = urlencode(base_params)
full_url = f"{DOUYIN_DOMAIN}/aweme/v1/web/aweme/detail/?{query_str}&a_bogus={a_bogus}"
print("Request URL:", full_url)
try:
aweme_id = self.extract_video_id(video_url)
kwargs = self.headers_config
print("@kwargs:", kwargs)
base_params = BaseRequestModel().model_dump()
base_params["msToken"] = self.gen_real_msToken()
base_params["aweme_id"] = aweme_id
bogus = ABogus()
ab_value = bogus.get_value(base_params)
a_bogus = quote(ab_value, safe='')
print("@a_bogus:", a_bogus)
print(base_params)
query_str = urlencode(base_params)
full_url = f"{DOUYIN_DOMAIN}/aweme/v1/web/aweme/detail/?{query_str}&a_bogus={a_bogus}"
print("Request URL:", full_url)
response = requests.get(full_url, headers=kwargs)
print("Response JSON:", response.content)
@@ -208,46 +214,49 @@ class DouyinDownloader(Downloader):
quality: DownloadQuality = "fast",
need_video: Optional[bool] = False
) -> AudioDownloadResult:
print(
f"正在下载视频: {video_url},保存路径: {output_dir},质量: {quality}"
)
if output_dir is None:
output_dir = get_data_dir()
if not output_dir:
output_dir = self.cache_data
os.makedirs(output_dir, exist_ok=True)
try:
print(
f"正在下载视频: {video_url},保存路径: {output_dir},质量: {quality}"
)
if output_dir is None:
output_dir = get_data_dir()
if not output_dir:
output_dir = self.cache_data
os.makedirs(output_dir, exist_ok=True)
output_path = os.path.join(output_dir, "%(id)s.%(ext)s")
output_path = os.path.join(output_dir, "%(id)s.%(ext)s")
video_data = self.fetch_video_info(video_url)
output_path = output_path % {
"id": video_data['aweme_detail']['aweme_id'],
"ext": "mp3",
}
url = video_data['aweme_detail']['music']['play_url']['uri']
# 下载音频
audio_data = requests.get(url)
with open(output_path, 'wb') as f:
f.write(audio_data.content)
print(url)
tags = []
for tag in video_data['aweme_detail']['video_tag']:
if tag['tag_name']:
tags.append(tag['tag_name'])
video_data = self.fetch_video_info(video_url)
output_path = output_path % {
"id": video_data['aweme_detail']['aweme_id'],
"ext": "mp3",
}
url = video_data['aweme_detail']['music']['play_url']['uri']
# 下载音频
audio_data = requests.get(url)
with open(output_path, 'wb') as f:
f.write(audio_data.content)
print(url)
tags = []
for tag in video_data['aweme_detail']['video_tag']:
if tag['tag_name']:
tags.append(tag['tag_name'])
return AudioDownloadResult(
file_path=output_path,
title=video_data['aweme_detail']['item_title'],
duration=video_data['aweme_detail']['video']['duration'],
cover_url=video_data['aweme_detail']['video']['cover_original_scale']['url_list'][0] if
video_data['aweme_detail']['video']['cover'] else video_data['video']['big_thumbs']['img_url'],
platform="douyin",
video_id=video_data['aweme_detail']['aweme_id'],
raw_info={
'tags': video_data['aweme_detail']['caption'] + ''.join(tags),
},
video_path=None # ❗音频下载不包含视频路径
)
return AudioDownloadResult(
file_path=output_path,
title=video_data['aweme_detail']['item_title'],
duration=video_data['aweme_detail']['video']['duration'],
cover_url=video_data['aweme_detail']['video']['cover_original_scale']['url_list'][0] if
video_data['aweme_detail']['video']['cover'] else video_data['video']['big_thumbs']['img_url'],
platform="douyin",
video_id=video_data['aweme_detail']['aweme_id'],
raw_info={
'tags': video_data['aweme_detail']['caption'] + ''.join(tags),
},
video_path=None # ❗音频下载不包含视频路径
)
except Exception as e:
raise e
def download_video(self, video_url: str, output_dir: Union[str, None] = None) -> str:

View File

@@ -0,0 +1,21 @@
import enum
class ProviderErrorEnum(enum.Enum):
CONNECTION_TEST_FAILED = (200101, "供应商连接测试失败")
SAVE_FAILED = (200102, "供应商保存失败")
CREATE_FAILED = (200103, "供应商创建失败")
NOT_FOUND = (200104, "供应商不存在/未保存")
WRONG_PARAMETER = (200105, "API / API 地址不正确")
UNKNOW_ERROR = (200106, "未知错误")
def __init__(self, code, message):
self.code = code
self.message = message
class NoteErrorEnum(enum.Enum):
PLATFORM_NOT_SUPPORTED = (300101 ,"选择的平台不受支持")
def __init__(self, code, message):
self.code = code
self.message = message

View File

View File

@@ -0,0 +1,6 @@
# exceptions/biz_exception.py
class BizException(Exception):
def __init__(self, code: int, message: str = "业务异常"):
self.code = code
self.message = message

View File

@@ -0,0 +1,33 @@
# middlewares/exception_handler.py
from fastapi import Request
from fastapi import FastAPI
from app.enmus.exception import NoteErrorEnum
from app.exceptions.biz_exception import BizException
from app.exceptions.note import NoteError
from app.exceptions.provider import ProviderError
from app.utils.logger import get_logger
from app.utils.response import ResponseWrapper as R
import traceback
logger = get_logger(__name__)
def register_exception_handlers(app: FastAPI):
@app.exception_handler(BizException)
async def biz_exception_handler(request: Request, exc: BizException):
logger.error(f"BizException: {exc.code} - {exc.message}")
return R.error(code=exc.code, msg=str(exc.message))
@app.exception_handler(NoteError)
async def note_exception_handler(request: Request, exc: NoteError):
logger.error(f"NoteError: {exc.code} - {exc.message}")
return R.error(code=exc.code, msg=str(exc.message))
@app.exception_handler(ProviderError)
async def provider_exception_handler(request: Request, exc: ProviderError):
logger.error(f"供应商模块错误: {exc.code} - {exc.message}")
return R.error(code=exc.code, msg=str(exc.message))
@app.exception_handler(Exception)
async def general_exception_handler(request: Request, exc: Exception):
logger.error(f"系统异常: {str(exc)}\n{traceback.format_exc()}")
return R.error(code=500000, msg="系统异常")

View File

@@ -0,0 +1,9 @@
# exceptions.py
from app.enmus.exception import ProviderErrorEnum
class NoteError(Exception):
def __init__(self, message: str,code: ProviderErrorEnum) -> None:
super().__init__(message)
self.code=code
self.message = message

View File

@@ -1,5 +1,12 @@
# exceptions.py
class ConnectionTestError(Exception):
def __init__(self, message: str):
from app.enmus.exception import ProviderErrorEnum
class ProviderError(Exception):
def __init__(self, message: str,code: ProviderErrorEnum) -> None:
super().__init__(message)
self.message = message
self.code=code
self.message = message

View File

@@ -2,6 +2,9 @@ from typing import Optional, Union
from openai import OpenAI
from app.utils.logger import get_logger
logging= get_logger(__name__)
class OpenAICompatibleProvider:
def __init__(self, api_key: str, base_url: str, model: Union[str, None]=None):
self.client = OpenAI(api_key=api_key, base_url=base_url)
@@ -13,11 +16,16 @@ class OpenAICompatibleProvider:
@staticmethod
def test_connection(api_key: str, base_url: str) -> bool:
print(api_key)
try:
client = OpenAI(api_key=api_key, base_url=base_url)
client.models.list()
model = client.models.list()
# for segment in model:
# print(segment)
# print(model)
logging.info("连通性测试成功")
return True
except Exception as e:
print(f"Error connecting to OpenAI API: {e}")
logging.info(f"连通性测试失败:{e}")
# print(f"Error connecting to OpenAI API: {e}")
return False

View File

@@ -27,4 +27,6 @@ def get_cookie(platform: str):
@router.post("/update_downloader_cookie")
def update_cookie(data: CookieUpdateRequest):
cookie_manager.set(data.platform, data.cookie)
return {"message": "Cookie updated successfully"}
return R.success(
)

View File

@@ -31,10 +31,9 @@ def delete_model(model_id: int):
return R.error(f"删除模型失败: {e}")
@router.get("/model_list/{provider_id}")
def model_list(provider_id):
try:
return R.success(modelService.get_all_models_by_id(provider_id))
except Exception as e:
return R.error(e)
return R.success(modelService.get_all_models_by_id(provider_id))
@router.post("/models")
def create_model(data: CreateModelRequest):

View File

@@ -2,6 +2,7 @@
import json
import os
import uuid
from pathlib import Path
from typing import Optional
from urllib.parse import urlparse
@@ -10,7 +11,9 @@ from pydantic import BaseModel, validator, field_validator
from dataclasses import asdict
from app.db.video_task_dao import get_task_by_video
from app.enmus.exception import NoteErrorEnum
from app.enmus.note_enums import DownloadQuality
from app.exceptions.note import NoteError
from app.services.note import NoteGenerator, logger
from app.utils.response import ResponseWrapper as R
from app.utils.url_parser import extract_video_id
@@ -54,12 +57,13 @@ class VideoRequest(BaseModel):
if parsed.scheme in ("http", "https"):
# 是网络链接,继续用原有平台校验
if not is_supported_video_url(url):
raise ValueError("暂不支持该视频平台或链接格式无效")
raise NoteError(code=NoteErrorEnum.PLATFORM_NOT_SUPPORTED.code,
message=NoteErrorEnum.PLATFORM_NOT_SUPPORTED.message)
return v
NOTE_OUTPUT_DIR = "note_results"
NOTE_OUTPUT_DIR = os.getenv("NOTE_OUTPUT_DIR", "note_results")
UPLOAD_DIR = "uploads"
@@ -74,30 +78,32 @@ def run_note_task(task_id: str, video_url: str, platform: str, quality: Download
_format: list = None, style: str = None, extras: str = None, video_understanding: bool = False,
video_interval=0, grid_size=[]
):
try:
if not model_name or not provider_id:
raise HTTPException(status_code=400, detail="请选择模型和提供者")
note = NoteGenerator().generate(
video_url=video_url,
platform=platform,
quality=quality,
task_id=task_id,
model_name=model_name,
provider_id=provider_id,
link=link,
_format=_format,
style=style,
extras=extras,
screenshot=screenshot
, video_understanding=video_understanding,
video_interval=video_interval,
grid_size=grid_size
)
logger.info(f"Note generated: {task_id}")
save_note_to_file(task_id, note)
except Exception as e:
save_note_to_file(task_id, {"error": str(e)})
if not model_name or not provider_id:
raise HTTPException(status_code=400, detail="请选择模型和提供者")
note = NoteGenerator().generate(
video_url=video_url,
platform=platform,
quality=quality,
task_id=task_id,
model_name=model_name,
provider_id=provider_id,
link=link,
_format=_format,
style=style,
extras=extras,
screenshot=screenshot
, video_understanding=video_understanding,
video_interval=video_interval,
grid_size=grid_size
)
logger.info(f"Note generated: {task_id}")
if not note or not note.markdown:
logger.warning(f"任务 {task_id} 执行失败,跳过保存")
return
save_note_to_file(task_id, note)
@router.post('/delete_task')
@@ -135,7 +141,6 @@ def generate_note(data: VideoRequest, background_tasks: BackgroundTasks):
# msg='笔记已生成,请勿重复发起',
#
# )
if data.task_id:
# 如果传了task_id说明是重试
task_id = data.task_id

View File

@@ -2,7 +2,7 @@ from typing import Optional
from fastapi import APIRouter
from pydantic import BaseModel
from app.exceptions.provider import ConnectionTestError
from app.exceptions.provider import ProviderError
from app.models.model_config import ModelConfig
from app.services.model import ModelService
from app.utils.response import ResponseWrapper as R
@@ -88,9 +88,5 @@ def update_provider(data: ProviderUpdateRequest):
@router.post('/connect_test')
def gpt_connect_test(data: TestRequest):
try:
ModelService().connect_test(data.id)
return R.success(msg='连接成功')
except Exception as e:
print("捕获到异常类型:", type(e))
return R.error(msg=str(e))
ModelService().connect_test(data.id)
return R.success(msg='连接成功')

View File

@@ -1,12 +1,16 @@
from app.db.model_dao import insert_model, get_all_models, get_model_by_provider_and_name, delete_model
from app.db.provider_dao import get_enabled_providers
from app.exceptions.provider import ConnectionTestError
from app.enmus.exception import ProviderErrorEnum
from app.exceptions.provider import ProviderError
from app.gpt.gpt_factory import GPTFactory
from app.gpt.provider.OpenAI_compatible_provider import OpenAICompatibleProvider
from app.models.model_config import ModelConfig
from app.services.provider import ProviderService
from app.utils.logger import get_logger
logger=get_logger(__name__)
class ModelService:
@staticmethod
@@ -83,37 +87,38 @@ class ModelService:
provider = ProviderService.get_provider_by_id(provider_id)
models = ModelService.get_model_list(provider["id"], verbose=verbose)
model_list={
"models": models
print(type(models))
serializable_models = [m.dict() for m in models.data]
model_list = {
"models": serializable_models
}
logger.info(f"[{provider['name']}] 获取模型成功")
return model_list
except Exception as e:
print(f"[{provider_id}] 获取模型失败: {e}")
# print(f"[{provider_id}] 获取模型失败: {e}")
logger.error(f"[{provider_id}] 获取模型失败: {e}")
return []
@staticmethod
def connect_test(id: str) -> bool:
try:
provider = ProviderService.get_provider_by_id(id)
if provider:
if not provider.get('api_key'):
raise ConnectionTestError(f"供应商信息未找到,请先保存重试")
result = OpenAICompatibleProvider.test_connection(
api_key=provider.get('api_key'),
base_url=provider.get('base_url')
)
if result:
return True
else:
raise ConnectionTestError("请检查API Key 和 API 地址是否正确")
provider = ProviderService.get_provider_by_id(id)
if provider:
if not provider.get('api_key'):
raise ProviderError(code=ProviderErrorEnum.NOT_FOUND.code, message=ProviderErrorEnum.NOT_FOUND.message)
result = OpenAICompatibleProvider.test_connection(
api_key=provider.get('api_key'),
base_url=provider.get('base_url')
)
if result:
return True
else:
raise ProviderError(code=ProviderErrorEnum.WRONG_PARAMETER.code,message=ProviderErrorEnum.WRONG_PARAMETER.message)
raise ProviderError(code=ProviderErrorEnum.NOT_FOUND.code, message=ProviderErrorEnum.NOT_FOUND.message)
raise ConnectionTestError("供应商信息未找到,请先保存重试")
except Exception as e:
# 抛出业务异常,交由 Controller 处理
raise ConnectionTestError(f"{str(e)}") from e
@staticmethod
def delete_model_by_id( model_id: int) -> bool:

View File

@@ -1,75 +1,63 @@
import json
from dataclasses import asdict
from fastapi import HTTPException
from app.downloaders.local_downloader import LocalDownloader
from app.enmus.task_status_enums import TaskStatus
import logging
import os
from typing import Union, Optional
import re
from dataclasses import asdict
from pathlib import Path
from typing import List, Optional, Tuple, Union, Any
from pydantic import HttpUrl
from dotenv import load_dotenv
from app.db.video_task_dao import insert_video_task, delete_task_by_video
from app.downloaders.base import Downloader
from app.downloaders.bilibili_downloader import BilibiliDownloader
from app.downloaders.douyin_downloader import DouyinDownloader
from app.downloaders.youtube_downloader import YoutubeDownloader
from app.services.constant import SUPPORT_PLATFORM_MAP
from app.enmus.task_status_enums import TaskStatus
from app.enmus.exception import NoteErrorEnum, ProviderErrorEnum
from app.exceptions.note import NoteError
from app.exceptions.provider import ProviderError
from app.db.video_task_dao import delete_task_by_video, insert_video_task
from app.gpt.base import GPT
from app.gpt.deepseek_gpt import DeepSeekGPT
from app.gpt.gpt_factory import GPTFactory
from app.gpt.openai_gpt import OpenaiGPT
from app.gpt.qwen_gpt import QwenGPT
from app.models.audio_model import AudioDownloadResult
from app.models.gpt_model import GPTSource
from app.models.model_config import ModelConfig
from app.models.notes_model import NoteResult
from app.models.notes_model import AudioDownloadResult
from app.enmus.note_enums import DownloadQuality
from app.models.transcriber_model import TranscriptResult, TranscriptSegment
from app.services.constant import SUPPORT_PLATFORM_MAP
from app.services.provider import ProviderService
from app.transcriber.base import Transcriber
from app.transcriber.transcriber_provider import get_transcriber, _transcribers
from app.transcriber.whisper import WhisperTranscriber
import re
from app.utils.note_helper import replace_content_markers
from app.utils.status_code import StatusCode
from app.utils.video_helper import generate_screenshot
# from app.services.whisperer import transcribe_audio
# from app.services.gpt import summarize_text
from dotenv import load_dotenv
from app.utils.logger import get_logger
from app.utils.video_reader import VideoReader
from events import transcription_finished
from app.utils.video_helper import generate_screenshot
from app.utils.note_helper import replace_content_markers
from app.enmus.note_enums import DownloadQuality
logger = get_logger(__name__)
# 环境变量
load_dotenv()
api_path = os.getenv("API_BASE_URL", "http://localhost")
BACKEND_PORT = os.getenv("BACKEND_PORT", 8000)
NOTE_OUTPUT_DIR = Path(os.getenv("NOTE_OUTPUT_DIR", "note_results"))
NOTE_OUTPUT_DIR.mkdir(parents=True, exist_ok=True)
IMAGE_BASE_URL = os.getenv("IMAGE_BASE_URL", "/static/screenshots")
IMAGE_OUTPUT_DIR = os.getenv("OUT_DIR", "images")
BACKEND_BASE_URL = f"{api_path}:{BACKEND_PORT}"
output_dir = os.getenv('OUT_DIR')
image_base_url = os.getenv('IMAGE_BASE_URL')
logger.info("starting up")
NOTE_OUTPUT_DIR = "note_results"
logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO)
class NoteGenerator:
class States:
INIT = 'INIT'
PARSING = 'PARSING'
DOWNLOADING = 'DOWNLOADING'
TRANSCRIBING = 'TRANSCRIBING'
SUMMARIZING = 'SUMMARIZING'
SAVING = 'SAVING'
SUCCESS = 'SUCCESS'
FAILED = 'FAILED'
def __init__(self):
self.model_size: str = 'base'
self.device: Union[str, None] = None
self.transcriber_type = os.getenv('TRANSCRIBER_TYPE', 'fast-whisper')
self.transcriber = self.get_transcriber()
self.video_path = None
logger.info("初始化NoteGenerator")
import logging
logger = logging.getLogger(__name__)
self.transcriber_type = os.getenv("TRANSCRIBER_TYPE", "fast-whisper")
self.transcriber: Transcriber = self._init_transcriber()
self.video_img_urls = []
@staticmethod
def update_task_status(task_id: str, status: Union[str, TaskStatus], message: Optional[str] = None):
@@ -81,310 +69,179 @@ class NoteGenerator:
with open(path, "w", encoding="utf-8") as f:
json.dump(content, f, ensure_ascii=False, indent=2)
def get_gpt(self, model_name: str = None, provider_id: str = None) -> GPT:
def generate(
self,
video_url: Union[str, HttpUrl],
platform: str,
quality: DownloadQuality = DownloadQuality.medium,
task_id: Optional[str] = None,
model_name: Optional[str] = None,
provider_id: Optional[str] = None,
link: bool = False,
screenshot: bool = False,
_format: Optional[List[str]] = None,
style: Optional[str] = None,
extras: Optional[str] = None,
output_path: Optional[str] = None,
video_understanding: bool = False,
video_interval: int = 0,
grid_size: Optional[List[int]] = None,
) -> NoteResult | None:
self.task_id = task_id
self._change_state(self.States.INIT)
try:
self._change_state(self.States.PARSING)
downloader = self._get_downloader(platform)
gpt = self._get_gpt(model_name, provider_id)
self.audio_cache_file = NOTE_OUTPUT_DIR / f"{task_id}_audio.json"
self.transcript_cache_file = NOTE_OUTPUT_DIR / f"{task_id}_transcript.json"
self.markdown_cache_file = NOTE_OUTPUT_DIR / f"{task_id}_markdown.md"
self.audio_meta = self._download_audio_video(
downloader, video_url, quality, output_path,
screenshot, video_understanding, video_interval, grid_size or []
)
self.transcript = self._transcribe_audio()
self.markdown = self._summarize_text(
gpt, link, screenshot, _format or [], style, extras
)
self.markdown = self._post_process_markdown(
self.markdown, self.video_path, _format or [], self.audio_meta, platform
)
self._change_state(self.States.SAVING)
self._save_metadata(self.audio_meta.video_id, platform, task_id)
self._change_state(self.States.SUCCESS)
return NoteResult(markdown=self.markdown, transcript=self.transcript, audio_meta=self.audio_meta)
except Exception as e:
logger.exception(f"任务 {self.task_id} 失败: {e}")
self._change_state(self.States.FAILED, str(e))
return None
def _change_state(self, state: str, message: Optional[str] = None):
if not self.task_id:
return
NOTE_OUTPUT_DIR.mkdir(parents=True, exist_ok=True)
status_file = NOTE_OUTPUT_DIR / f"{self.task_id}.status.json"
data = {"status": state}
if message:
data["message"] = message
temp_file = status_file.with_suffix('.tmp')
with temp_file.open('w', encoding='utf-8') as f:
json.dump(data, f, ensure_ascii=False, indent=2)
temp_file.replace(status_file)
def _init_transcriber(self) -> Transcriber:
if self.transcriber_type not in _transcribers:
raise Exception(f"不支持的转写器:{self.transcriber_type}")
return get_transcriber(self.transcriber_type)
def _get_gpt(self, model_name: Optional[str], provider_id: Optional[str]) -> GPT:
provider = ProviderService.get_provider_by_id(provider_id)
if not provider:
logger.error(f"[get_gpt] 未找到对应的模型供应商: provider_id={provider_id}")
raise ValueError(f"未找到对应的模型供应商: provider_id={provider_id}")
gpt = GPTFactory().from_config(
ModelConfig(
api_key=provider.get('api_key'),
base_url=provider.get('base_url'),
model_name=model_name,
provider=provider.get('type'),
name=provider.get('name')
)
raise ProviderError(code=ProviderErrorEnum.NOT_FOUND, message=ProviderErrorEnum.NOT_FOUND.message)
config = ModelConfig(
api_key=provider["api_key"], base_url=provider["base_url"],
model_name=model_name, provider=provider["type"], name=provider["name"]
)
return gpt
return GPTFactory().from_config(config)
def get_downloader(self, platform: str) -> Downloader:
downloader = SUPPORT_PLATFORM_MAP[platform]
if downloader:
logger.info(f"使用{downloader}下载器")
return downloader
else:
logger.warning("不支持的平台")
raise ValueError(f"不支持的平台:{platform}")
def _get_downloader(self, platform: str) -> Downloader:
downloader_cls = SUPPORT_PLATFORM_MAP.get(platform)
if not downloader_cls:
raise NoteError(code=NoteErrorEnum.PLATFORM_NOT_SUPPORTED.code,
message=NoteErrorEnum.PLATFORM_NOT_SUPPORTED.message)
return downloader_cls
def get_transcriber(self) -> Transcriber:
'''
def _download_audio_video(self, downloader, video_url, quality, output_path,
screenshot, video_understanding, video_interval, grid_size):
self._change_state(self.States.DOWNLOADING)
:param transcriber: 选择的转义器
:return:
'''
if self.transcriber_type in _transcribers.keys():
logger.info(f"使用{self.transcriber_type}转义器")
return get_transcriber(transcriber_type=self.transcriber_type)
else:
logger.warning("不支持的转义器")
raise ValueError(f"不支持的转义器:{self.transcriber}")
need_video = screenshot or video_understanding
if need_video:
self.video_path = Path(downloader.download_video(video_url, output_path))
if grid_size:
self.video_img_urls = VideoReader(
video_path=str(self.video_path),
grid_size=tuple(grid_size),
frame_interval=video_interval,
unit_width=1280, unit_height=720,
save_quality=90,
).run()
def save_meta(self, video_id, platform, task_id):
logger.info(f"记录已经生成的数据信息")
insert_video_task(video_id=video_id, platform=platform, task_id=task_id)
if self.audio_cache_file.exists():
with open(self.audio_cache_file, "r", encoding="utf-8") as f:
data = json.load(f)
return AudioDownloadResult(**data)
def insert_screenshots_into_markdown(self, markdown: str, video_path: str, image_base_url: str,
output_dir: str, _format: list) -> str:
"""
扫描 markdown 中的 *Screenshot-xx:xx生成截图并插入 markdown 图片
:param markdown:
:param image_base_url: 最终返回给前端的路径前缀(如 /static/screenshots
"""
matches = self.extract_screenshot_timestamps(markdown)
new_markdown = markdown
audio = downloader.download(
video_url=video_url, quality=quality, output_dir=output_path, need_video=need_video
)
with open(self.audio_cache_file, "w", encoding="utf-8") as f:
json.dump(asdict(audio), f, ensure_ascii=False, indent=2)
return audio
logger.info(f"开始为笔记生成截图")
try:
for idx, (marker, ts) in enumerate(matches):
image_path = generate_screenshot(video_path, output_dir, ts, idx)
image_relative_path = os.path.join(image_base_url, os.path.basename(image_path)).replace("\\", "/")
image_url = f"/static/screenshots/{os.path.basename(image_path)}"
replacement = f"![]({image_url})"
new_markdown = new_markdown.replace(marker, replacement, 1)
def _transcribe_audio(self):
self._change_state(self.States.TRANSCRIBING)
if self.transcript_cache_file.exists():
with open(self.transcript_cache_file, "r", encoding="utf-8") as f:
data = json.load(f)
segments = [TranscriptSegment(**seg) for seg in data.get("segments", [])]
return TranscriptResult(language=data["language"], full_text=data["full_text"], segments=segments)
return new_markdown
except Exception as e:
logger.error(f"截图生成失败:{e}")
raise HTTPException(
status_code=500,
detail={
"code": StatusCode.DOWNLOAD_ERROR,
"msg": f"截图生成失败",
"error": str(e)
}
)
transcript = self.transcriber.transcript(self.audio_meta.file_path)
with open(self.transcript_cache_file, "w", encoding="utf-8") as f:
json.dump(asdict(transcript), f, ensure_ascii=False, indent=2)
return transcript
def _summarize_text(self, gpt, link, screenshot, formats, style, extras):
self._change_state(self.States.SUMMARIZING)
source = GPTSource(
title=self.audio_meta.title,
segment=self.transcript.segments,
tags=self.audio_meta.raw_info.get("tags", []),
screenshot=screenshot,
video_img_urls=self.video_img_urls,
link=link, _format=formats, style=style, extras=extras
)
markdown = gpt.summarize(source)
with open(self.markdown_cache_file, "w", encoding="utf-8") as f:
f.write(markdown)
return markdown
@staticmethod
def delete_note(video_id: str, platform: str):
logger.info(f"删除生成的笔记记录")
return delete_task_by_video(video_id, platform)
def _post_process_markdown(self, markdown, video_path, formats, audio_meta, platform):
if "screenshot" in formats and video_path:
markdown = self._insert_screenshots(markdown, video_path)
if "link" in formats:
markdown = replace_content_markers(markdown, video_id=audio_meta.video_id, platform=platform)
return markdown
import re
def extract_screenshot_timestamps(self, markdown: str) -> list[tuple[str, int]]:
"""
从 Markdown 中提取 Screenshot 时间标记(如 *Screenshot-03:39 或 Screenshot-[03:39]
并返回匹配文本和对应时间戳(秒)
"""
logger.info(f"开始提取截图时间标记")
def _insert_screenshots(self, markdown, video_path):
pattern = r"(?:\*Screenshot-(\d{2}):(\d{2})|Screenshot-\[(\d{2}):(\d{2})\])"
matches = list(re.finditer(pattern, markdown))
results = []
for match in matches:
matches = []
for match in re.finditer(pattern, markdown):
mm = match.group(1) or match.group(3)
ss = match.group(2) or match.group(4)
total_seconds = int(mm) * 60 + int(ss)
results.append((match.group(0), total_seconds))
return results
matches.append((match.group(0), int(mm)*60+int(ss)))
for idx, (marker, ts) in enumerate(matches):
img_path = generate_screenshot(str(video_path), str(IMAGE_OUTPUT_DIR), ts, idx)
filename = Path(img_path).name
img_url = f"{IMAGE_BASE_URL.rstrip('/')}/{filename}"
markdown = markdown.replace(marker, f"![]({img_url})", 1)
return markdown
def generate(
self,
video_url: Union[str, HttpUrl],
platform: str,
quality: DownloadQuality = DownloadQuality.medium,
task_id: Union[str, None] = None,
model_name: str = None,
provider_id: str = None,
link: bool = False,
screenshot: bool = False,
_format: list = None,
style: str = None,
extras: str = None,
path: Union[str, None] = None,
video_understanding: bool = False,
video_interval=0,
grid_size=[]
) -> NoteResult:
def _save_metadata(self, video_id: str, platform: str, task_id: str):
insert_video_task(video_id=video_id, platform=platform, task_id=task_id)
try:
logger.info(f"🎯 开始解析并生成笔记task_id={task_id}")
self.update_task_status(task_id, TaskStatus.PARSING)
downloader = self.get_downloader(platform)
gpt = self.get_gpt(model_name=model_name, provider_id=provider_id)
video_img_urls = []
audio_cache_path = os.path.join(NOTE_OUTPUT_DIR, f"{task_id}_audio.json")
transcript_cache_path = os.path.join(NOTE_OUTPUT_DIR, f"{task_id}_transcript.json")
markdown_cache_path = os.path.join(NOTE_OUTPUT_DIR, f"{task_id}_markdown.md")
# -------- 1. 下载音频 --------
try:
self.update_task_status(task_id, TaskStatus.DOWNLOADING)
# 加载音频缓存(如果存在)
audio = None
if os.path.exists(audio_cache_path):
logger.info(f"检测到已有音频缓存直接读取task_id={task_id}")
with open(audio_cache_path, "r", encoding="utf-8") as f:
audio_data = json.load(f)
audio = AudioDownloadResult(**audio_data)
# 需要视频的情况(截图 or 视频理解)
need_video = 'screenshot' in _format or video_understanding
if need_video:
try:
video_path = downloader.download_video(video_url)
self.video_path = video_path
logger.info(f"成功下载视频文件: {video_path}")
video_img_urls = VideoReader(
video_path=video_path,
grid_size=tuple(grid_size),
frame_interval=video_interval,
unit_width=1280,
unit_height=720,
save_quality=90,
).run()
except Exception as e:
logger.error(f"Error 下载视频失败task_id={task_id},错误信息:{e}")
self.update_task_status(task_id, TaskStatus.FAILED, message=f"下载音频失败:{e}")
raise HTTPException(
status_code=500,
detail={
"code": StatusCode.DOWNLOAD_ERROR,
"msg": f"下载视频失败task_id={task_id}",
"error": str(e)
}
)
# 没有音频缓存就下载音频(可能同时也带上视频)
if audio is None:
audio = downloader.download(
video_url=video_url,
quality=quality,
output_dir=path,
need_video='screenshot' in _format, # 注意这里只为了截图需要
)
with open(audio_cache_path, "w", encoding="utf-8") as f:
json.dump(asdict(audio), f, ensure_ascii=False, indent=2)
logger.info(f"音频下载并缓存成功task_id={task_id}")
except Exception as e:
logger.error(f"Error 下载音频失败task_id={task_id},错误信息:{e}")
self.update_task_status(task_id, TaskStatus.FAILED, message=f"下载音频失败:{e}")
raise HTTPException(
status_code=500,
detail={
"code": StatusCode.DOWNLOAD_ERROR,
"msg": f"下载音频失败task_id={task_id}",
"error": str(e)
}
)
# -------- 2. 转写文字 --------
try:
self.update_task_status(task_id, TaskStatus.TRANSCRIBING)
if os.path.exists(transcript_cache_path):
logger.info(f"检测到已有转写缓存直接读取task_id={task_id}")
try:
with open(transcript_cache_path, "r", encoding="utf-8") as f:
transcript_data = json.load(f)
transcript = TranscriptResult(
language=transcript_data["language"],
full_text=transcript_data["full_text"],
segments=[TranscriptSegment(**seg) for seg in transcript_data["segments"]]
)
except (json.JSONDecodeError, KeyError) as e:
logger.warning(f"Warning 读取转录缓存失败重新转录task_id={task_id},错误信息:{e}")
transcript: TranscriptResult = self.transcriber.transcript(file_path=audio.file_path)
with open(transcript_cache_path, "w", encoding="utf-8") as f:
json.dump(asdict(transcript), f, ensure_ascii=False, indent=2)
else:
transcript: TranscriptResult = self.transcriber.transcript(file_path=audio.file_path)
with open(transcript_cache_path, "w", encoding="utf-8") as f:
json.dump(asdict(transcript), f, ensure_ascii=False, indent=2)
logger.info(f"文字转写并缓存成功task_id={task_id}")
except Exception as e:
logger.error(f"Error 转写文字失败task_id={task_id},错误信息:{e}")
self.update_task_status(task_id, TaskStatus.FAILED, message=f"转写文字失败:{e}")
raise HTTPException(
status_code=500,
detail={
"code": StatusCode.GENERATE_ERROR, # =1003
"msg": f"转写文字失败task_id={task_id}",
"error": str(e)
}
)
# -------- 3. 总结内容 --------
try:
self.update_task_status(task_id, TaskStatus.SUMMARIZING)
# if os.path.exists(markdown_cache_path):
# logger.info(f"检测到已有总结缓存直接读取task_id={task_id}")
# with open(markdown_cache_path, "r", encoding="utf-8") as f:
# markdown = f.read()
# else:
source = GPTSource(
title=audio.title,
segment=transcript.segments,
tags=audio.raw_info.get('tags'),
screenshot=screenshot,
video_img_urls=video_img_urls,
link=link,
_format=_format,
style=style,
extras=extras
)
markdown: str = gpt.summarize(source)
with open(markdown_cache_path, "w", encoding="utf-8") as f:
f.write(markdown)
logger.info(f"GPT总结并缓存成功task_id={task_id}")
except Exception as e:
logger.error(f"Error 总结内容失败task_id={task_id},错误信息:{e}")
self.update_task_status(task_id, TaskStatus.FAILED, message=f"总结内容失败:{e}")
raise HTTPException(
status_code=500,
detail={
"code": StatusCode.GENERATE_ERROR, # =1003
"msg": f"总结内容失败task_id={task_id}",
"error": str(e)
}
)
# -------- 4. 插入截图 --------
if _format and 'screenshot' in _format:
try:
markdown = self.insert_screenshots_into_markdown(markdown, self.video_path, image_base_url,
output_dir, _format)
except Exception as e:
logger.warning(f"Warning 插入截图失败跳过处理task_id={task_id},错误信息:{e}")
if _format and 'link' in _format:
try:
markdown = replace_content_markers(markdown, video_id=audio.video_id, platform=platform)
except Exception as e:
logger.warning(f"Warning 插入链接失败跳过处理task_id={task_id},错误信息:{e}")
# 注意:截图失败不终止整体流程
# -------- 5. 保存数据库记录 --------
self.update_task_status(task_id, TaskStatus.SAVING)
self.save_meta(video_id=audio.video_id, platform=platform, task_id=task_id)
# -------- 6. 完成 --------
self.update_task_status(task_id, TaskStatus.SUCCESS)
logger.info(f"succeed 笔记生成成功task_id={task_id}")
# TODO :改为前端一键清除缓存
# if platform != 'local':
# transcription_finished.send({
# "file_path": audio.file_path,
# })
return NoteResult(
markdown=markdown,
transcript=transcript,
audio_meta=audio
)
except Exception as e:
logger.error(f"Error 笔记生成流程异常终止task_id={task_id},错误信息:{e}")
self.update_task_status(task_id, TaskStatus.FAILED, message=str(e))
# 返回结构化错误信息给前端(可以用于日志 + 显示 + 错误定位)
raise HTTPException(
status_code=500,
detail={
"code": StatusCode.FAIL,
"msg": f"笔记生成流程异常终止task_id={task_id}",
"error": str(e)
}
)
@staticmethod
def delete_note(video_id: str, platform: str) -> int:
return delete_task_by_video(video_id, platform)

View File

@@ -1,18 +1,24 @@
from fastapi.responses import JSONResponse
from app.utils.status_code import StatusCode
from pydantic import BaseModel
from typing import Optional, Any
from fastapi.responses import JSONResponse
class ResponseWrapper:
@staticmethod
def success(data=None, msg="success", code=StatusCode.SUCCESS):
return {
"code": int(code),
def success(data=None, msg="success", code=0):
return JSONResponse(content={
"code": code,
"msg": msg,
"data": data
}
})
@staticmethod
def error(msg="error", code=StatusCode.FAIL, data=None):
return {
"code": int(code),
def error(msg="error", code=500, data=None):
return JSONResponse(content={
"code": code,
"msg": msg,
"data": data
}
})

View File

@@ -4,7 +4,7 @@ import uvicorn
from starlette.staticfiles import StaticFiles
from dotenv import load_dotenv
from app.core.exception_handlers import register_exception_handlers
from app.exceptions.exception_handlers import register_exception_handlers
from app.db.model_dao import init_model_table
from app.db.provider_dao import init_provider_table
from app.utils.logger import get_logger
@@ -33,11 +33,14 @@ if not os.path.exists(out_dir):
os.makedirs(out_dir)
app = create_app()
register_exception_handlers(app)
app.mount(static_path, StaticFiles(directory=static_dir), name="static")
app.mount("/uploads", StaticFiles(directory=uploads_dir), name="uploads")
@app.on_event("startup")
async def startup_event():
register_exception_handlers(app)
register_handler()
ensure_ffmpeg_or_raise()
register_handler()
@@ -46,8 +49,9 @@ async def startup_event():
init_provider_table()
init_model_table()
if __name__ == "__main__":
port = int(os.getenv("BACKEND_PORT", 8000))
host = os.getenv("BACKEND_HOST", "0.0.0.0")
logger.info(f"Starting server on {host}:{port}")
uvicorn.run("main:app", host=host, port=port, reload=False)
uvicorn.run(app, host=host, port=port, reload=False)