feat(backend): 添加 Groq供应商支持并优化笔记生成流程- 在 builtin_providers.json 中添加 Groq 供应商信息

- 实现 GroqTranscriber 类以支持 Groq 语音转录服务
- 新增异常处理中间件以提高系统稳定性
- 优化笔记生成流程,增加错误处理和日志记录
- 添加思维导图功能和相关组件
-重构 Markdown 查看器以支持切换视图模式
This commit is contained in:
黄建武
2025-05-12 14:59:06 +08:00
parent b2034c0865
commit 6ff8b4d90f
16 changed files with 743 additions and 352 deletions

View File

View File

@@ -0,0 +1,38 @@
# app/core/exception_handlers.py
from fastapi import Request, HTTPException
from fastapi.exceptions import RequestValidationError
from fastapi.responses import JSONResponse
from app.utils.logger import get_logger
from app.utils.response import ResponseWrapper
from app.utils.status_code import StatusCode
logger = get_logger(__name__)
def register_exception_handlers(app):
@app.exception_handler(RequestValidationError)
async def validation_exception_handler(request: Request, exc: RequestValidationError):
errors = []
for err in exc.errors():
loc = err.get("loc", [])
field = loc[-1] if loc else "body"
msg = err.get("msg", "参数不合法")
errors.append({"field": field, "error": msg})
return JSONResponse(
status_code=400,
content=ResponseWrapper.error(msg="参数验证失败", code=StatusCode.PARAM_ERROR, data=errors)
)
@app.exception_handler(HTTPException)
async def http_exception_handler(request: Request, exc: HTTPException):
return JSONResponse(
status_code=exc.status_code,
content=ResponseWrapper.error(msg=str(exc.detail), code=StatusCode.FAIL)
)
@app.exception_handler(Exception)
async def global_exception_handler(request: Request, exc: Exception):
logger.exception(f"服务器内部错误: {exc}")
return JSONResponse(
status_code=500,
content=ResponseWrapper.error(msg="服务器内部错误", code=StatusCode.FAIL, data=str(exc))
)

View File

@@ -38,5 +38,13 @@
"logo": "Gemini",
"api_key": "",
"base_url": "https://generativelanguage.googleapis.com/v1beta/openai/"
},
{
"id": "groq",
"name": "Groq",
"type": "built-in",
"logo": "Groq",
"api_key": "",
"base_url": "https://api.groq.com/openai/v1"
}
]

View File

@@ -13,8 +13,10 @@ class OpenAICompatibleProvider:
@staticmethod
def test_connection(api_key: str, base_url: str) -> bool:
print(api_key)
try:
client = OpenAI(api_key=api_key, base_url=base_url)
client.models.list()
return True
except Exception as e:

View File

@@ -1,6 +1,8 @@
import json
from dataclasses import asdict
from fastapi import HTTPException
from app.downloaders.local_downloader import LocalDownloader
from app.enmus.task_status_enums import TaskStatus
import os
@@ -33,6 +35,7 @@ from app.transcriber.whisper import WhisperTranscriber
import re
from app.utils.note_helper import replace_content_markers
from app.utils.status_code import StatusCode
from app.utils.video_helper import generate_screenshot
# from app.services.whisperer import transcribe_audio
@@ -143,7 +146,15 @@ class NoteGenerator:
return new_markdown
except Exception as e:
logger.error(f"截图生成失败:{e}")
raise e
raise HTTPException(
status_code=500,
detail={
"code": StatusCode.DOWNLOAD_ERROR,
"msg": f"截图生成失败",
"error": str(e)
}
)
@staticmethod
def delete_note(video_id: str, platform: str):
@@ -226,8 +237,16 @@ class NoteGenerator:
save_quality=90,
).run()
except Exception as e:
logger.error(f" 下载视频失败task_id={task_id},错误信息:{e}")
logger.error(f"Error 下载视频失败task_id={task_id},错误信息:{e}")
self.update_task_status(task_id, TaskStatus.FAILED, message=f"下载音频失败:{e}")
raise HTTPException(
status_code=500,
detail={
"code": StatusCode.DOWNLOAD_ERROR,
"msg": f"下载视频失败task_id={task_id}",
"error": str(e)
}
)
# 没有音频缓存就下载音频(可能同时也带上视频)
if audio is None:
@@ -241,9 +260,17 @@ class NoteGenerator:
json.dump(asdict(audio), f, ensure_ascii=False, indent=2)
logger.info(f"音频下载并缓存成功task_id={task_id}")
except Exception as e:
logger.error(f" 下载音频失败task_id={task_id},错误信息:{e}")
logger.error(f"Error 下载音频失败task_id={task_id},错误信息:{e}")
self.update_task_status(task_id, TaskStatus.FAILED, message=f"下载音频失败:{e}")
raise e
raise HTTPException(
status_code=500,
detail={
"code": StatusCode.DOWNLOAD_ERROR,
"msg": f"下载音频失败task_id={task_id}",
"error": str(e)
}
)
# -------- 2. 转写文字 --------
try:
@@ -259,7 +286,7 @@ class NoteGenerator:
segments=[TranscriptSegment(**seg) for seg in transcript_data["segments"]]
)
except (json.JSONDecodeError, KeyError) as e:
logger.warning(f"⚠️ 读取转录缓存失败重新转录task_id={task_id},错误信息:{e}")
logger.warning(f"Warning 读取转录缓存失败重新转录task_id={task_id},错误信息:{e}")
transcript: TranscriptResult = self.transcriber.transcript(file_path=audio.file_path)
with open(transcript_cache_path, "w", encoding="utf-8") as f:
json.dump(asdict(transcript), f, ensure_ascii=False, indent=2)
@@ -269,9 +296,16 @@ class NoteGenerator:
json.dump(asdict(transcript), f, ensure_ascii=False, indent=2)
logger.info(f"文字转写并缓存成功task_id={task_id}")
except Exception as e:
logger.error(f" 转写文字失败task_id={task_id},错误信息:{e}")
logger.error(f"Error 转写文字失败task_id={task_id},错误信息:{e}")
self.update_task_status(task_id, TaskStatus.FAILED, message=f"转写文字失败:{e}")
raise e
raise HTTPException(
status_code=500,
detail={
"code": StatusCode.GENERATE_ERROR, # =1003
"msg": f"转写文字失败task_id={task_id}",
"error": str(e)
}
)
# -------- 3. 总结内容 --------
try:
@@ -298,9 +332,16 @@ class NoteGenerator:
f.write(markdown)
logger.info(f"GPT总结并缓存成功task_id={task_id}")
except Exception as e:
logger.error(f" 总结内容失败task_id={task_id},错误信息:{e}")
logger.error(f"Error 总结内容失败task_id={task_id},错误信息:{e}")
self.update_task_status(task_id, TaskStatus.FAILED, message=f"总结内容失败:{e}")
raise e
raise HTTPException(
status_code=500,
detail={
"code": StatusCode.GENERATE_ERROR, # =1003
"msg": f"总结内容失败task_id={task_id}",
"error": str(e)
}
)
# -------- 4. 插入截图 --------
if _format and 'screenshot' in _format:
@@ -308,12 +349,12 @@ class NoteGenerator:
markdown = self.insert_screenshots_into_markdown(markdown, self.video_path, image_base_url,
output_dir, _format)
except Exception as e:
logger.warning(f"⚠️ 插入截图失败跳过处理task_id={task_id},错误信息:{e}")
logger.warning(f"Warning 插入截图失败跳过处理task_id={task_id},错误信息:{e}")
if _format and 'link' in _format:
try:
markdown = replace_content_markers(markdown, video_id=audio.video_id, platform=platform)
except Exception as e:
logger.warning(f"⚠️ 插入链接失败跳过处理task_id={task_id},错误信息:{e}")
logger.warning(f"Warning 插入链接失败跳过处理task_id={task_id},错误信息:{e}")
# 注意:截图失败不终止整体流程
# -------- 5. 保存数据库记录 --------
@@ -322,7 +363,7 @@ class NoteGenerator:
# -------- 6. 完成 --------
self.update_task_status(task_id, TaskStatus.SUCCESS)
logger.info(f" 笔记生成成功task_id={task_id}")
logger.info(f"succeed 笔记生成成功task_id={task_id}")
# TODO :改为前端一键清除缓存
# if platform != 'local':
# transcription_finished.send({
@@ -335,6 +376,15 @@ class NoteGenerator:
)
except Exception as e:
logger.error(f" 笔记生成流程异常终止task_id={task_id},错误信息:{e}")
logger.error(f"Error 笔记生成流程异常终止task_id={task_id},错误信息:{e}")
self.update_task_status(task_id, TaskStatus.FAILED, message=str(e))
raise f'❌ 笔记生成流程异常终止task_id={task_id},错误信息:{e}'
# 返回结构化错误信息给前端(可以用于日志 + 显示 + 错误定位)
raise HTTPException(
status_code=500,
detail={
"code": StatusCode.FAIL,
"msg": f"笔记生成流程异常终止task_id={task_id}",
"error": str(e)
}
)

View File

@@ -0,0 +1,52 @@
from abc import ABC
import os
from app.decorators.timeit import timeit
from app.models.transcriber_model import TranscriptResult, TranscriptSegment
from app.services.provider import ProviderService
from app.transcriber.base import Transcriber
from openai import OpenAI
from dotenv import load_dotenv
load_dotenv()
class GroqTranscriber(Transcriber, ABC):
@timeit
def transcript(self, file_path: str) -> TranscriptResult:
provider = ProviderService.get_provider_by_id('groq')
if not provider:
raise Exception("Groq 供应商未配置,请配置以后使用。")
client = OpenAI(
api_key=provider.get('api_key'),
base_url=provider.get('base_url')
)
filename = file_path
with open(filename, "rb") as file:
transcription = client.audio.transcriptions.create(
file=(filename, file.read()),
model=os.getenv('GROQ_TRANSCRIBER_MODEL'),
response_format="verbose_json",
)
print(transcription.text)
print(transcription)
segments = []
full_text = ""
for seg in transcription.segments:
text = seg.text.strip()
full_text += text + " "
segments.append(TranscriptSegment(
start=seg.start,
end=seg.end,
text=text
))
result = TranscriptResult(
language=transcription.language,
full_text=full_text.strip(),
segments=segments,
raw=transcription.to_dict()
)
return result

View File

@@ -1,113 +1,115 @@
import os
import platform
from enum import Enum
from app.transcriber.groq import GroqTranscriber
from app.transcriber.whisper import WhisperTranscriber
from app.transcriber.bcut import BcutTranscriber
from app.transcriber.kuaishou import KuaishouTranscriber
from app.utils.logger import get_logger
logger = get_logger(__name__)
# 只在Apple平台且设置了环境变量时才导入MLX Whisper
class TranscriberType(str, Enum):
FAST_WHISPER = "fast-whisper"
MLX_WHISPER = "mlx-whisper"
BCUT = "bcut"
KUAISHOU = "kuaishou"
GROQ = "groq"
# 仅在 Apple 平台启用 MLX Whisper
MLX_WHISPER_AVAILABLE = False
if platform.system() == "Darwin" and os.environ.get("TRANSCRIBER_TYPE") == "mlx-whisper":
try:
from app.transcriber.mlx_whisper_transcriber import MLXWhisperTranscriber
MLX_WHISPER_AVAILABLE = True
logger.info("MLX Whisper 可用,已导入")
except ImportError:
MLX_WHISPER_AVAILABLE = False
logger.warning("MLX Whisper 导入失败,可能未安装或平台不支持")
else:
MLX_WHISPER_AVAILABLE = False
logger.info('初始化转录服务提供器')
# 维护各种转录器单例实例
# 转录器单例缓存
_transcribers = {
'bcut': None,
'kuaishou': None,
'mlx-whisper': None,
'fast-whisper':None
TranscriberType.FAST_WHISPER: None,
TranscriberType.MLX_WHISPER: None,
TranscriberType.BCUT: None,
TranscriberType.KUAISHOU: None,
TranscriberType.GROQ: None,
}
def get_whisper_transcriber(model_size="base", device="cuda"):
"""获取 Whisper 转录器实例"""
if _transcribers['fast-whisper'] is None:
logger.info(f'创建 Whisper 转录器实例,参数:{model_size}, {device}')
# 公共实例初始化函数
def _init_transcriber(key: TranscriberType, cls, *args, **kwargs):
if _transcribers[key] is None:
logger.info(f'创建 {cls.__name__} 实例: {key}')
try:
_transcribers['whisper'] = WhisperTranscriber(model_size=model_size, device=device)
logger.info('Whisper 转录器创建成功')
_transcribers[key] = cls(*args, **kwargs)
logger.info(f'{cls.__name__} 创建成功')
except Exception as e:
logger.error(f"Whisper 转录器创建失败: {e}")
logger.error(f"{cls.__name__} 创建失败: {e}")
raise
return _transcribers['whisper']
return _transcribers[key]
# 各类型获取方法
def get_groq_transcriber():
return _init_transcriber(TranscriberType.GROQ, GroqTranscriber)
def get_whisper_transcriber(model_size="base", device="cuda"):
return _init_transcriber(TranscriberType.FAST_WHISPER, WhisperTranscriber, model_size=model_size, device=device)
def get_bcut_transcriber():
"""获取 Bcut 转录器实例"""
if _transcribers['bcut'] is None:
logger.info('创建 Bcut 转录器实例')
try:
_transcribers['bcut'] = BcutTranscriber()
logger.info('Bcut 转录器创建成功')
except Exception as e:
logger.error(f"Bcut 转录器创建失败: {e}")
raise
return _transcribers['bcut']
return _init_transcriber(TranscriberType.BCUT, BcutTranscriber)
def get_kuaishou_transcriber():
"""获取快手转录器实例"""
if _transcribers['kuaishou'] is None:
logger.info('创建快手转录器实例')
try:
_transcribers['kuaishou'] = KuaishouTranscriber()
logger.info('快手转录器创建成功')
except Exception as e:
logger.error(f"快手转录器创建失败: {e}")
raise
return _transcribers['kuaishou']
return _init_transcriber(TranscriberType.KUAISHOU, KuaishouTranscriber)
def get_mlx_whisper_transcriber(model_size="base"):
"""获取 MLX Whisper 转录器实例"""
if not MLX_WHISPER_AVAILABLE:
logger.warning("MLX Whisper 不可用请确保在Apple平台且已安装mlx_whisper")
raise ImportError("MLX Whisper 不可用请确保在Apple平台且已安装mlx_whisper")
if _transcribers['mlx-whisper'] is None:
logger.info(f'创建 MLX Whisper 转录器实例,参数:{model_size}')
try:
_transcribers['mlx-whisper'] = MLXWhisperTranscriber(model_size=model_size)
logger.info('MLX Whisper 转录器创建成功')
except Exception as e:
logger.error(f"MLX Whisper 转录器创建失败: {e}")
raise
return _transcribers['mlx-whisper']
logger.warning("MLX Whisper 不可用,请确保在 Apple 平台且已安装 mlx_whisper")
raise ImportError("MLX Whisper 不可用")
return _init_transcriber(TranscriberType.MLX_WHISPER, MLXWhisperTranscriber, model_size=model_size)
# 通用入口
def get_transcriber(transcriber_type="fast-whisper", model_size="base", device="cuda"):
"""
获取指定类型的转录器实例
参数:
transcriber_type: 转录器类型,支持 "fast-whisper", "bcut", "kuaishou", "mlx-whisper"(仅Apple平台)
model_size: 模型大小,whisper 和 mlx-whisper 特有参数
device: 设备类型whisper 特有参数
transcriber_type: 支持 "fast-whisper", "mlx-whisper", "bcut", "kuaishou", "groq"
model_size: 模型大小,适用于 whisper
device: 设备类型(如 cuda / cpuwhisper 使用
返回:
对应类型的转录器实例
"""
logger.info(f'获取转录器类型: {transcriber_type}')
if transcriber_type == "fast-whisper":
whisper_model_size = os.environ.get("WHISPER_MODEL_SIZE",model_size)
logger.info(f'请求转录器类型: {transcriber_type}')
try:
transcriber_enum = TranscriberType(transcriber_type)
except ValueError:
logger.warning(f'未知转录器类型 "{transcriber_type}",默认使用 fast-whisper')
transcriber_enum = TranscriberType.FAST_WHISPER
whisper_model_size = os.environ.get("WHISPER_MODEL_SIZE", model_size)
if transcriber_enum == TranscriberType.FAST_WHISPER:
return get_whisper_transcriber(whisper_model_size, device=device)
elif transcriber_type == "mlx-whisper":
whisper_model_size = os.environ.get("WHISPER_MODEL_SIZE",model_size)
elif transcriber_enum == TranscriberType.MLX_WHISPER:
if not MLX_WHISPER_AVAILABLE:
logger.warning("MLX Whisper 不可用,回退到 fast-whisper")
return get_whisper_transcriber(whisper_model_size, device=device)
return get_mlx_whisper_transcriber(whisper_model_size)
elif transcriber_type == "bcut":
elif transcriber_enum == TranscriberType.BCUT:
return get_bcut_transcriber()
elif transcriber_type == "kuaishou":
elif transcriber_enum == TranscriberType.KUAISHOU:
return get_kuaishou_transcriber()
else:
logger.warning(f'未知转录器类型 "{transcriber_type}",使用默认 whisper')
whisper_model_size = os.environ.get("WHISPER_MODEL_SIZE",model_size)
return get_whisper_transcriber(whisper_model_size, device)
elif transcriber_enum == TranscriberType.GROQ:
return get_groq_transcriber()
# fallback
logger.warning(f'未识别转录器类型 "{transcriber_type}",使用 fast-whisper 作为默认')
return get_whisper_transcriber(whisper_model_size, device=device)

View File

@@ -4,6 +4,7 @@ import uvicorn
from starlette.staticfiles import StaticFiles
from dotenv import load_dotenv
from app.core.exception_handlers import register_exception_handlers
from app.db.model_dao import init_model_table
from app.db.provider_dao import init_provider_table
from app.utils.logger import get_logger
@@ -34,12 +35,12 @@ if not os.path.exists(out_dir):
app = create_app()
app.mount(static_path, StaticFiles(directory=static_dir), name="static")
app.mount("/uploads", StaticFiles(directory=uploads_dir), name="uploads")
async def startup_event():
register_handler()
@app.on_event("startup")
async def startup_event():
register_exception_handlers(app)
register_handler()
ensure_ffmpeg_or_raise()
register_handler()
get_transcriber(transcriber_type=os.getenv("TRANSCRIBER_TYPE","fast-whisper"))
init_video_task_table()
init_provider_table()

Binary file not shown.