feat(transcriber): 更新转录器支持和模型下载逻辑

- 修改 NoteGenerator 类以支持动态选择转录器类型。
- 更新 MLXWhisperTranscriber 类,添加模型下载逻辑,确保模型不存在时自动下载。
- 在 transcriber_provider.py 中优化 MLX Whisper 的环境变量处理,确保在不可用时回退到 fast-whisper。
This commit is contained in:
SurfRid3r
2025-04-20 22:50:07 +08:00
parent 369de19572
commit a567788448
3 changed files with 18 additions and 12 deletions

View File

@@ -18,7 +18,7 @@ from app.models.notes_model import AudioDownloadResult
from app.enmus.note_enums import DownloadQuality
from app.models.transcriber_model import TranscriptResult
from app.transcriber.base import Transcriber
from app.transcriber.transcriber_provider import get_transcriber
from app.transcriber.transcriber_provider import get_transcriber,_transcribers
from app.transcriber.whisper import WhisperTranscriber
import re
@@ -89,9 +89,9 @@ class NoteGenerator:
:param transcriber: 选择的转义器
:return:
'''
if self.transcriber_type == 'fast-whisper':
logger.info("使用Whisper")
return get_transcriber(transcriber_type='fast-whisper')
if self.transcriber_type in _transcribers.keys():
logger.info(f"使用{self.transcriber_type}转义器")
return get_transcriber(transcriber_type=self.transcriber_type)
else:
logger.warning("不支持的转义器")
raise ValueError(f"不支持的转义器:{self.transcriber}")

View File

@@ -2,6 +2,7 @@ import mlx_whisper
from pathlib import Path
import os
import platform
from huggingface_hub import snapshot_download
from app.decorators.timeit import timeit
from app.models.transcriber_model import TranscriptSegment, TranscriptResult
@@ -15,8 +16,7 @@ logger = get_logger(__name__)
class MLXWhisperTranscriber(Transcriber):
def __init__(
self,
model_size: str = "base",
device: str = None, # MLX 会自动选择最佳设备
model_size: str = "base"
):
# 检查平台
if platform.system() != "Darwin":
@@ -27,15 +27,21 @@ class MLXWhisperTranscriber(Transcriber):
raise RuntimeError("必须设置环境变量 TRANSCRIBER_TYPE=mlx-whisper 才能使用 MLX Whisper")
self.model_size = model_size
self.model_name = f"whisper-{model_size}-mlx"
self.model_name = f"mlx-community/whisper-{model_size}"
self.model_path = None
# 设置模型路径
model_dir = get_model_dir("mlx-whisper")
self.model_path = os.path.join(model_dir, self.model_name)
# 确保模型目录存在
Path(model_dir).mkdir(parents=True, exist_ok=True)
# 检查并下载模型
if not Path(self.model_path).exists():
logger.info(f"模型 {self.model_name} 不存在,开始下载...")
snapshot_download(
self.model_name,
local_dir=self.model_path,
local_dir_use_symlinks=False,
)
logger.info("模型下载完成")
logger.info(f"初始化 MLX Whisper 转录器,模型:{self.model_name}")
@@ -45,7 +51,7 @@ class MLXWhisperTranscriber(Transcriber):
# 使用 MLX Whisper 进行转录
result = mlx_whisper.transcribe(
file_path,
path_or_hf_repo=f"mlx-community/{self.model_name}"
path_or_hf_repo=f"{self.model_name}"
)
# 转换为标准格式

View File

@@ -98,9 +98,9 @@ def get_transcriber(transcriber_type="fast-whisper", model_size="base", device="
whisper_model_size = os.environ.get("WHISPER_MODEL_SIZE",model_size)
return get_whisper_transcriber(whisper_model_size, device=device)
elif transcriber_type == "mlx-whisper":
whisper_model_size = os.environ.get("WHISPER_MODEL_SIZE",model_size)
if not MLX_WHISPER_AVAILABLE:
logger.warning("MLX Whisper 不可用,回退到 fast-whisper")
whisper_model_size = os.environ.get("WHISPER_MODEL_SIZE",model_size)
return get_whisper_transcriber(whisper_model_size, device=device)
return get_mlx_whisper_transcriber(whisper_model_size)
elif transcriber_type == "bcut":