mirror of
https://github.com/JefferyHcool/BiliNote.git
synced 2026-05-31 05:11:10 +08:00
feat(backend): 添加 Groq供应商支持并优化笔记生成流程- 在 builtin_providers.json 中添加 Groq 供应商信息
- 实现 GroqTranscriber 类以支持 Groq 语音转录服务 - 新增异常处理中间件以提高系统稳定性 - 优化笔记生成流程,增加错误处理和日志记录 - 添加思维导图功能和相关组件 -重构 Markdown 查看器以支持切换视图模式
This commit is contained in:
0
backend/app/core/__init__.py
Normal file
0
backend/app/core/__init__.py
Normal file
38
backend/app/core/exception_handlers.py
Normal file
38
backend/app/core/exception_handlers.py
Normal file
@@ -0,0 +1,38 @@
|
||||
# app/core/exception_handlers.py
|
||||
from fastapi import Request, HTTPException
|
||||
from fastapi.exceptions import RequestValidationError
|
||||
from fastapi.responses import JSONResponse
|
||||
|
||||
from app.utils.logger import get_logger
|
||||
from app.utils.response import ResponseWrapper
|
||||
from app.utils.status_code import StatusCode
|
||||
logger = get_logger(__name__)
|
||||
|
||||
def register_exception_handlers(app):
|
||||
@app.exception_handler(RequestValidationError)
|
||||
async def validation_exception_handler(request: Request, exc: RequestValidationError):
|
||||
errors = []
|
||||
for err in exc.errors():
|
||||
loc = err.get("loc", [])
|
||||
field = loc[-1] if loc else "body"
|
||||
msg = err.get("msg", "参数不合法")
|
||||
errors.append({"field": field, "error": msg})
|
||||
return JSONResponse(
|
||||
status_code=400,
|
||||
content=ResponseWrapper.error(msg="参数验证失败", code=StatusCode.PARAM_ERROR, data=errors)
|
||||
)
|
||||
|
||||
@app.exception_handler(HTTPException)
|
||||
async def http_exception_handler(request: Request, exc: HTTPException):
|
||||
return JSONResponse(
|
||||
status_code=exc.status_code,
|
||||
content=ResponseWrapper.error(msg=str(exc.detail), code=StatusCode.FAIL)
|
||||
)
|
||||
|
||||
@app.exception_handler(Exception)
|
||||
async def global_exception_handler(request: Request, exc: Exception):
|
||||
logger.exception(f"服务器内部错误: {exc}")
|
||||
return JSONResponse(
|
||||
status_code=500,
|
||||
content=ResponseWrapper.error(msg="服务器内部错误", code=StatusCode.FAIL, data=str(exc))
|
||||
)
|
||||
@@ -38,5 +38,13 @@
|
||||
"logo": "Gemini",
|
||||
"api_key": "",
|
||||
"base_url": "https://generativelanguage.googleapis.com/v1beta/openai/"
|
||||
},
|
||||
{
|
||||
"id": "groq",
|
||||
"name": "Groq",
|
||||
"type": "built-in",
|
||||
"logo": "Groq",
|
||||
"api_key": "",
|
||||
"base_url": "https://api.groq.com/openai/v1"
|
||||
}
|
||||
]
|
||||
|
||||
@@ -13,8 +13,10 @@ class OpenAICompatibleProvider:
|
||||
|
||||
@staticmethod
|
||||
def test_connection(api_key: str, base_url: str) -> bool:
|
||||
print(api_key)
|
||||
try:
|
||||
client = OpenAI(api_key=api_key, base_url=base_url)
|
||||
|
||||
client.models.list()
|
||||
return True
|
||||
except Exception as e:
|
||||
|
||||
@@ -1,6 +1,8 @@
|
||||
import json
|
||||
from dataclasses import asdict
|
||||
|
||||
from fastapi import HTTPException
|
||||
|
||||
from app.downloaders.local_downloader import LocalDownloader
|
||||
from app.enmus.task_status_enums import TaskStatus
|
||||
import os
|
||||
@@ -33,6 +35,7 @@ from app.transcriber.whisper import WhisperTranscriber
|
||||
import re
|
||||
|
||||
from app.utils.note_helper import replace_content_markers
|
||||
from app.utils.status_code import StatusCode
|
||||
from app.utils.video_helper import generate_screenshot
|
||||
|
||||
# from app.services.whisperer import transcribe_audio
|
||||
@@ -143,7 +146,15 @@ class NoteGenerator:
|
||||
return new_markdown
|
||||
except Exception as e:
|
||||
logger.error(f"截图生成失败:{e}")
|
||||
raise e
|
||||
raise HTTPException(
|
||||
status_code=500,
|
||||
detail={
|
||||
"code": StatusCode.DOWNLOAD_ERROR,
|
||||
"msg": f"截图生成失败",
|
||||
"error": str(e)
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
@staticmethod
|
||||
def delete_note(video_id: str, platform: str):
|
||||
@@ -226,8 +237,16 @@ class NoteGenerator:
|
||||
save_quality=90,
|
||||
).run()
|
||||
except Exception as e:
|
||||
logger.error(f"❌ 下载视频失败,task_id={task_id},错误信息:{e}")
|
||||
logger.error(f"Error 下载视频失败,task_id={task_id},错误信息:{e}")
|
||||
self.update_task_status(task_id, TaskStatus.FAILED, message=f"下载音频失败:{e}")
|
||||
raise HTTPException(
|
||||
status_code=500,
|
||||
detail={
|
||||
"code": StatusCode.DOWNLOAD_ERROR,
|
||||
"msg": f"下载视频失败,task_id={task_id}",
|
||||
"error": str(e)
|
||||
}
|
||||
)
|
||||
|
||||
# 没有音频缓存就下载音频(可能同时也带上视频)
|
||||
if audio is None:
|
||||
@@ -241,9 +260,17 @@ class NoteGenerator:
|
||||
json.dump(asdict(audio), f, ensure_ascii=False, indent=2)
|
||||
logger.info(f"音频下载并缓存成功,task_id={task_id}")
|
||||
except Exception as e:
|
||||
logger.error(f"❌ 下载音频失败,task_id={task_id},错误信息:{e}")
|
||||
logger.error(f"Error 下载音频失败,task_id={task_id},错误信息:{e}")
|
||||
self.update_task_status(task_id, TaskStatus.FAILED, message=f"下载音频失败:{e}")
|
||||
raise e
|
||||
|
||||
raise HTTPException(
|
||||
status_code=500,
|
||||
detail={
|
||||
"code": StatusCode.DOWNLOAD_ERROR,
|
||||
"msg": f"下载音频失败,task_id={task_id}",
|
||||
"error": str(e)
|
||||
}
|
||||
)
|
||||
|
||||
# -------- 2. 转写文字 --------
|
||||
try:
|
||||
@@ -259,7 +286,7 @@ class NoteGenerator:
|
||||
segments=[TranscriptSegment(**seg) for seg in transcript_data["segments"]]
|
||||
)
|
||||
except (json.JSONDecodeError, KeyError) as e:
|
||||
logger.warning(f"⚠️ 读取转录缓存失败,重新转录,task_id={task_id},错误信息:{e}")
|
||||
logger.warning(f"Warning 读取转录缓存失败,重新转录,task_id={task_id},错误信息:{e}")
|
||||
transcript: TranscriptResult = self.transcriber.transcript(file_path=audio.file_path)
|
||||
with open(transcript_cache_path, "w", encoding="utf-8") as f:
|
||||
json.dump(asdict(transcript), f, ensure_ascii=False, indent=2)
|
||||
@@ -269,9 +296,16 @@ class NoteGenerator:
|
||||
json.dump(asdict(transcript), f, ensure_ascii=False, indent=2)
|
||||
logger.info(f"文字转写并缓存成功,task_id={task_id}")
|
||||
except Exception as e:
|
||||
logger.error(f"❌ 转写文字失败,task_id={task_id},错误信息:{e}")
|
||||
logger.error(f"Error 转写文字失败,task_id={task_id},错误信息:{e}")
|
||||
self.update_task_status(task_id, TaskStatus.FAILED, message=f"转写文字失败:{e}")
|
||||
raise e
|
||||
raise HTTPException(
|
||||
status_code=500,
|
||||
detail={
|
||||
"code": StatusCode.GENERATE_ERROR, # =1003
|
||||
"msg": f"转写文字失败,task_id={task_id}",
|
||||
"error": str(e)
|
||||
}
|
||||
)
|
||||
|
||||
# -------- 3. 总结内容 --------
|
||||
try:
|
||||
@@ -298,9 +332,16 @@ class NoteGenerator:
|
||||
f.write(markdown)
|
||||
logger.info(f"GPT总结并缓存成功,task_id={task_id}")
|
||||
except Exception as e:
|
||||
logger.error(f"❌ 总结内容失败,task_id={task_id},错误信息:{e}")
|
||||
logger.error(f"Error 总结内容失败,task_id={task_id},错误信息:{e}")
|
||||
self.update_task_status(task_id, TaskStatus.FAILED, message=f"总结内容失败:{e}")
|
||||
raise e
|
||||
raise HTTPException(
|
||||
status_code=500,
|
||||
detail={
|
||||
"code": StatusCode.GENERATE_ERROR, # =1003
|
||||
"msg": f"总结内容失败,task_id={task_id}",
|
||||
"error": str(e)
|
||||
}
|
||||
)
|
||||
|
||||
# -------- 4. 插入截图 --------
|
||||
if _format and 'screenshot' in _format:
|
||||
@@ -308,12 +349,12 @@ class NoteGenerator:
|
||||
markdown = self.insert_screenshots_into_markdown(markdown, self.video_path, image_base_url,
|
||||
output_dir, _format)
|
||||
except Exception as e:
|
||||
logger.warning(f"⚠️ 插入截图失败,跳过处理,task_id={task_id},错误信息:{e}")
|
||||
logger.warning(f"Warning 插入截图失败,跳过处理,task_id={task_id},错误信息:{e}")
|
||||
if _format and 'link' in _format:
|
||||
try:
|
||||
markdown = replace_content_markers(markdown, video_id=audio.video_id, platform=platform)
|
||||
except Exception as e:
|
||||
logger.warning(f"⚠️ 插入链接失败,跳过处理,task_id={task_id},错误信息:{e}")
|
||||
logger.warning(f"Warning 插入链接失败,跳过处理,task_id={task_id},错误信息:{e}")
|
||||
# 注意:截图失败不终止整体流程
|
||||
|
||||
# -------- 5. 保存数据库记录 --------
|
||||
@@ -322,7 +363,7 @@ class NoteGenerator:
|
||||
|
||||
# -------- 6. 完成 --------
|
||||
self.update_task_status(task_id, TaskStatus.SUCCESS)
|
||||
logger.info(f"✅ 笔记生成成功,task_id={task_id}")
|
||||
logger.info(f"succeed 笔记生成成功,task_id={task_id}")
|
||||
# TODO :改为前端一键清除缓存
|
||||
# if platform != 'local':
|
||||
# transcription_finished.send({
|
||||
@@ -335,6 +376,15 @@ class NoteGenerator:
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"❌ 笔记生成流程异常终止,task_id={task_id},错误信息:{e}")
|
||||
logger.error(f"Error 笔记生成流程异常终止,task_id={task_id},错误信息:{e}")
|
||||
self.update_task_status(task_id, TaskStatus.FAILED, message=str(e))
|
||||
raise f'❌ 笔记生成流程异常终止,task_id={task_id},错误信息:{e}'
|
||||
|
||||
# 返回结构化错误信息给前端(可以用于日志 + 显示 + 错误定位)
|
||||
raise HTTPException(
|
||||
status_code=500,
|
||||
detail={
|
||||
"code": StatusCode.FAIL,
|
||||
"msg": f"笔记生成流程异常终止,task_id={task_id}",
|
||||
"error": str(e)
|
||||
}
|
||||
)
|
||||
|
||||
52
backend/app/transcriber/groq.py
Normal file
52
backend/app/transcriber/groq.py
Normal file
@@ -0,0 +1,52 @@
|
||||
from abc import ABC
|
||||
import os
|
||||
|
||||
from app.decorators.timeit import timeit
|
||||
from app.models.transcriber_model import TranscriptResult, TranscriptSegment
|
||||
from app.services.provider import ProviderService
|
||||
from app.transcriber.base import Transcriber
|
||||
from openai import OpenAI
|
||||
from dotenv import load_dotenv
|
||||
load_dotenv()
|
||||
|
||||
class GroqTranscriber(Transcriber, ABC):
|
||||
|
||||
|
||||
@timeit
|
||||
def transcript(self, file_path: str) -> TranscriptResult:
|
||||
provider = ProviderService.get_provider_by_id('groq')
|
||||
if not provider:
|
||||
raise Exception("Groq 供应商未配置,请配置以后使用。")
|
||||
client = OpenAI(
|
||||
api_key=provider.get('api_key'),
|
||||
base_url=provider.get('base_url')
|
||||
)
|
||||
filename = file_path
|
||||
|
||||
with open(filename, "rb") as file:
|
||||
transcription = client.audio.transcriptions.create(
|
||||
file=(filename, file.read()),
|
||||
model=os.getenv('GROQ_TRANSCRIBER_MODEL'),
|
||||
response_format="verbose_json",
|
||||
)
|
||||
print(transcription.text)
|
||||
print(transcription)
|
||||
segments = []
|
||||
full_text = ""
|
||||
|
||||
for seg in transcription.segments:
|
||||
text = seg.text.strip()
|
||||
full_text += text + " "
|
||||
segments.append(TranscriptSegment(
|
||||
start=seg.start,
|
||||
end=seg.end,
|
||||
text=text
|
||||
))
|
||||
|
||||
result = TranscriptResult(
|
||||
language=transcription.language,
|
||||
full_text=full_text.strip(),
|
||||
segments=segments,
|
||||
raw=transcription.to_dict()
|
||||
)
|
||||
return result
|
||||
@@ -1,113 +1,115 @@
|
||||
import os
|
||||
import platform
|
||||
from enum import Enum
|
||||
|
||||
from app.transcriber.groq import GroqTranscriber
|
||||
from app.transcriber.whisper import WhisperTranscriber
|
||||
from app.transcriber.bcut import BcutTranscriber
|
||||
from app.transcriber.kuaishou import KuaishouTranscriber
|
||||
from app.utils.logger import get_logger
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
# 只在Apple平台且设置了环境变量时才导入MLX Whisper
|
||||
class TranscriberType(str, Enum):
|
||||
FAST_WHISPER = "fast-whisper"
|
||||
MLX_WHISPER = "mlx-whisper"
|
||||
BCUT = "bcut"
|
||||
KUAISHOU = "kuaishou"
|
||||
GROQ = "groq"
|
||||
|
||||
# 仅在 Apple 平台启用 MLX Whisper
|
||||
MLX_WHISPER_AVAILABLE = False
|
||||
if platform.system() == "Darwin" and os.environ.get("TRANSCRIBER_TYPE") == "mlx-whisper":
|
||||
try:
|
||||
from app.transcriber.mlx_whisper_transcriber import MLXWhisperTranscriber
|
||||
MLX_WHISPER_AVAILABLE = True
|
||||
logger.info("MLX Whisper 可用,已导入")
|
||||
except ImportError:
|
||||
MLX_WHISPER_AVAILABLE = False
|
||||
logger.warning("MLX Whisper 导入失败,可能未安装或平台不支持")
|
||||
else:
|
||||
MLX_WHISPER_AVAILABLE = False
|
||||
|
||||
logger.info('初始化转录服务提供器')
|
||||
|
||||
# 维护各种转录器的单例实例
|
||||
# 转录器单例缓存
|
||||
_transcribers = {
|
||||
'bcut': None,
|
||||
'kuaishou': None,
|
||||
'mlx-whisper': None,
|
||||
'fast-whisper':None
|
||||
TranscriberType.FAST_WHISPER: None,
|
||||
TranscriberType.MLX_WHISPER: None,
|
||||
TranscriberType.BCUT: None,
|
||||
TranscriberType.KUAISHOU: None,
|
||||
TranscriberType.GROQ: None,
|
||||
}
|
||||
|
||||
def get_whisper_transcriber(model_size="base", device="cuda"):
|
||||
"""获取 Whisper 转录器实例"""
|
||||
if _transcribers['fast-whisper'] is None:
|
||||
logger.info(f'创建 Whisper 转录器实例,参数:{model_size}, {device}')
|
||||
# 公共实例初始化函数
|
||||
def _init_transcriber(key: TranscriberType, cls, *args, **kwargs):
|
||||
if _transcribers[key] is None:
|
||||
logger.info(f'创建 {cls.__name__} 实例: {key}')
|
||||
try:
|
||||
_transcribers['whisper'] = WhisperTranscriber(model_size=model_size, device=device)
|
||||
logger.info('Whisper 转录器创建成功')
|
||||
_transcribers[key] = cls(*args, **kwargs)
|
||||
logger.info(f'{cls.__name__} 创建成功')
|
||||
except Exception as e:
|
||||
logger.error(f"Whisper 转录器创建失败: {e}")
|
||||
logger.error(f"{cls.__name__} 创建失败: {e}")
|
||||
raise
|
||||
return _transcribers['whisper']
|
||||
return _transcribers[key]
|
||||
|
||||
# 各类型获取方法
|
||||
def get_groq_transcriber():
|
||||
return _init_transcriber(TranscriberType.GROQ, GroqTranscriber)
|
||||
|
||||
def get_whisper_transcriber(model_size="base", device="cuda"):
|
||||
return _init_transcriber(TranscriberType.FAST_WHISPER, WhisperTranscriber, model_size=model_size, device=device)
|
||||
|
||||
def get_bcut_transcriber():
|
||||
"""获取 Bcut 转录器实例"""
|
||||
if _transcribers['bcut'] is None:
|
||||
logger.info('创建 Bcut 转录器实例')
|
||||
try:
|
||||
_transcribers['bcut'] = BcutTranscriber()
|
||||
logger.info('Bcut 转录器创建成功')
|
||||
except Exception as e:
|
||||
logger.error(f"Bcut 转录器创建失败: {e}")
|
||||
raise
|
||||
return _transcribers['bcut']
|
||||
return _init_transcriber(TranscriberType.BCUT, BcutTranscriber)
|
||||
|
||||
def get_kuaishou_transcriber():
|
||||
"""获取快手转录器实例"""
|
||||
if _transcribers['kuaishou'] is None:
|
||||
logger.info('创建快手转录器实例')
|
||||
try:
|
||||
_transcribers['kuaishou'] = KuaishouTranscriber()
|
||||
logger.info('快手转录器创建成功')
|
||||
except Exception as e:
|
||||
logger.error(f"快手转录器创建失败: {e}")
|
||||
raise
|
||||
return _transcribers['kuaishou']
|
||||
return _init_transcriber(TranscriberType.KUAISHOU, KuaishouTranscriber)
|
||||
|
||||
def get_mlx_whisper_transcriber(model_size="base"):
|
||||
"""获取 MLX Whisper 转录器实例"""
|
||||
if not MLX_WHISPER_AVAILABLE:
|
||||
logger.warning("MLX Whisper 不可用,请确保在Apple平台且已安装mlx_whisper")
|
||||
raise ImportError("MLX Whisper 不可用,请确保在Apple平台且已安装mlx_whisper")
|
||||
|
||||
if _transcribers['mlx-whisper'] is None:
|
||||
logger.info(f'创建 MLX Whisper 转录器实例,参数:{model_size}')
|
||||
try:
|
||||
_transcribers['mlx-whisper'] = MLXWhisperTranscriber(model_size=model_size)
|
||||
logger.info('MLX Whisper 转录器创建成功')
|
||||
except Exception as e:
|
||||
logger.error(f"MLX Whisper 转录器创建失败: {e}")
|
||||
raise
|
||||
return _transcribers['mlx-whisper']
|
||||
logger.warning("MLX Whisper 不可用,请确保在 Apple 平台且已安装 mlx_whisper")
|
||||
raise ImportError("MLX Whisper 不可用")
|
||||
return _init_transcriber(TranscriberType.MLX_WHISPER, MLXWhisperTranscriber, model_size=model_size)
|
||||
|
||||
# 通用入口
|
||||
def get_transcriber(transcriber_type="fast-whisper", model_size="base", device="cuda"):
|
||||
"""
|
||||
获取指定类型的转录器实例
|
||||
|
||||
|
||||
参数:
|
||||
transcriber_type: 转录器类型,支持 "fast-whisper", "bcut", "kuaishou", "mlx-whisper"(仅Apple平台)
|
||||
model_size: 模型大小,whisper 和 mlx-whisper 特有参数
|
||||
device: 设备类型,whisper 特有参数
|
||||
|
||||
transcriber_type: 支持 "fast-whisper", "mlx-whisper", "bcut", "kuaishou", "groq"
|
||||
model_size: 模型大小,适用于 whisper 类
|
||||
device: 设备类型(如 cuda / cpu),仅 whisper 使用
|
||||
|
||||
返回:
|
||||
对应类型的转录器实例
|
||||
"""
|
||||
logger.info(f'获取转录器,类型: {transcriber_type}')
|
||||
if transcriber_type == "fast-whisper":
|
||||
whisper_model_size = os.environ.get("WHISPER_MODEL_SIZE",model_size)
|
||||
logger.info(f'请求转录器类型: {transcriber_type}')
|
||||
|
||||
try:
|
||||
transcriber_enum = TranscriberType(transcriber_type)
|
||||
except ValueError:
|
||||
logger.warning(f'未知转录器类型 "{transcriber_type}",默认使用 fast-whisper')
|
||||
transcriber_enum = TranscriberType.FAST_WHISPER
|
||||
|
||||
whisper_model_size = os.environ.get("WHISPER_MODEL_SIZE", model_size)
|
||||
|
||||
if transcriber_enum == TranscriberType.FAST_WHISPER:
|
||||
return get_whisper_transcriber(whisper_model_size, device=device)
|
||||
elif transcriber_type == "mlx-whisper":
|
||||
whisper_model_size = os.environ.get("WHISPER_MODEL_SIZE",model_size)
|
||||
|
||||
elif transcriber_enum == TranscriberType.MLX_WHISPER:
|
||||
if not MLX_WHISPER_AVAILABLE:
|
||||
logger.warning("MLX Whisper 不可用,回退到 fast-whisper")
|
||||
return get_whisper_transcriber(whisper_model_size, device=device)
|
||||
return get_mlx_whisper_transcriber(whisper_model_size)
|
||||
elif transcriber_type == "bcut":
|
||||
|
||||
elif transcriber_enum == TranscriberType.BCUT:
|
||||
return get_bcut_transcriber()
|
||||
elif transcriber_type == "kuaishou":
|
||||
|
||||
elif transcriber_enum == TranscriberType.KUAISHOU:
|
||||
return get_kuaishou_transcriber()
|
||||
else:
|
||||
logger.warning(f'未知转录器类型 "{transcriber_type}",使用默认 whisper')
|
||||
whisper_model_size = os.environ.get("WHISPER_MODEL_SIZE",model_size)
|
||||
return get_whisper_transcriber(whisper_model_size, device)
|
||||
|
||||
elif transcriber_enum == TranscriberType.GROQ:
|
||||
return get_groq_transcriber()
|
||||
|
||||
# fallback
|
||||
logger.warning(f'未识别转录器类型 "{transcriber_type}",使用 fast-whisper 作为默认')
|
||||
return get_whisper_transcriber(whisper_model_size, device=device)
|
||||
|
||||
@@ -4,6 +4,7 @@ import uvicorn
|
||||
from starlette.staticfiles import StaticFiles
|
||||
from dotenv import load_dotenv
|
||||
|
||||
from app.core.exception_handlers import register_exception_handlers
|
||||
from app.db.model_dao import init_model_table
|
||||
from app.db.provider_dao import init_provider_table
|
||||
from app.utils.logger import get_logger
|
||||
@@ -34,12 +35,12 @@ if not os.path.exists(out_dir):
|
||||
app = create_app()
|
||||
app.mount(static_path, StaticFiles(directory=static_dir), name="static")
|
||||
app.mount("/uploads", StaticFiles(directory=uploads_dir), name="uploads")
|
||||
async def startup_event():
|
||||
register_handler()
|
||||
@app.on_event("startup")
|
||||
async def startup_event():
|
||||
register_exception_handlers(app)
|
||||
register_handler()
|
||||
ensure_ffmpeg_or_raise()
|
||||
register_handler()
|
||||
get_transcriber(transcriber_type=os.getenv("TRANSCRIBER_TYPE","fast-whisper"))
|
||||
init_video_task_table()
|
||||
init_provider_table()
|
||||
|
||||
Binary file not shown.
Reference in New Issue
Block a user