mirror of
https://github.com/JefferyHcool/BiliNote.git
synced 2026-05-06 20:42:52 +08:00
feat(transcriber): 更新转录器支持和模型下载逻辑
- 修改 NoteGenerator 类以支持动态选择转录器类型。
- 更新 MLXWhisperTranscriber 类,添加模型下载逻辑,确保模型不存在时自动下载。
- 在 transcriber_provider.py 中优化 MLX Whisper 的环境变量处理,确保在不可用时回退到 fast-whisper。
This commit is contained in:
@@ -18,7 +18,7 @@ from app.models.notes_model import AudioDownloadResult
|
||||
from app.enmus.note_enums import DownloadQuality
|
||||
from app.models.transcriber_model import TranscriptResult
|
||||
from app.transcriber.base import Transcriber
|
||||
from app.transcriber.transcriber_provider import get_transcriber
|
||||
from app.transcriber.transcriber_provider import get_transcriber,_transcribers
|
||||
from app.transcriber.whisper import WhisperTranscriber
|
||||
import re
|
||||
|
||||
@@ -89,9 +89,9 @@ class NoteGenerator:
|
||||
:param transcriber: 选择的转义器
|
||||
:return:
|
||||
'''
|
||||
if self.transcriber_type == 'fast-whisper':
|
||||
logger.info("使用Whisper")
|
||||
return get_transcriber(transcriber_type='fast-whisper')
|
||||
if self.transcriber_type in _transcribers.keys():
|
||||
logger.info(f"使用{self.transcriber_type}转义器")
|
||||
return get_transcriber(transcriber_type=self.transcriber_type)
|
||||
else:
|
||||
logger.warning("不支持的转义器")
|
||||
raise ValueError(f"不支持的转义器:{self.transcriber}")
|
||||
|
||||
@@ -2,6 +2,7 @@ import mlx_whisper
|
||||
from pathlib import Path
|
||||
import os
|
||||
import platform
|
||||
from huggingface_hub import snapshot_download
|
||||
|
||||
from app.decorators.timeit import timeit
|
||||
from app.models.transcriber_model import TranscriptSegment, TranscriptResult
|
||||
@@ -15,8 +16,7 @@ logger = get_logger(__name__)
|
||||
class MLXWhisperTranscriber(Transcriber):
|
||||
def __init__(
|
||||
self,
|
||||
model_size: str = "base",
|
||||
device: str = None, # MLX 会自动选择最佳设备
|
||||
model_size: str = "base"
|
||||
):
|
||||
# 检查平台
|
||||
if platform.system() != "Darwin":
|
||||
@@ -27,15 +27,21 @@ class MLXWhisperTranscriber(Transcriber):
|
||||
raise RuntimeError("必须设置环境变量 TRANSCRIBER_TYPE=mlx-whisper 才能使用 MLX Whisper")
|
||||
|
||||
self.model_size = model_size
|
||||
self.model_name = f"whisper-{model_size}-mlx"
|
||||
self.model_name = f"mlx-community/whisper-{model_size}"
|
||||
self.model_path = None
|
||||
|
||||
# 设置模型路径
|
||||
model_dir = get_model_dir("mlx-whisper")
|
||||
self.model_path = os.path.join(model_dir, self.model_name)
|
||||
|
||||
# 确保模型目录存在
|
||||
Path(model_dir).mkdir(parents=True, exist_ok=True)
|
||||
# 检查并下载模型
|
||||
if not Path(self.model_path).exists():
|
||||
logger.info(f"模型 {self.model_name} 不存在,开始下载...")
|
||||
snapshot_download(
|
||||
self.model_name,
|
||||
local_dir=self.model_path,
|
||||
local_dir_use_symlinks=False,
|
||||
)
|
||||
logger.info("模型下载完成")
|
||||
|
||||
logger.info(f"初始化 MLX Whisper 转录器,模型:{self.model_name}")
|
||||
|
||||
@@ -45,7 +51,7 @@ class MLXWhisperTranscriber(Transcriber):
|
||||
# 使用 MLX Whisper 进行转录
|
||||
result = mlx_whisper.transcribe(
|
||||
file_path,
|
||||
path_or_hf_repo=f"mlx-community/{self.model_name}"
|
||||
path_or_hf_repo=f"{self.model_name}"
|
||||
)
|
||||
|
||||
# 转换为标准格式
|
||||
|
||||
@@ -98,9 +98,9 @@ def get_transcriber(transcriber_type="fast-whisper", model_size="base", device="
|
||||
whisper_model_size = os.environ.get("WHISPER_MODEL_SIZE",model_size)
|
||||
return get_whisper_transcriber(whisper_model_size, device=device)
|
||||
elif transcriber_type == "mlx-whisper":
|
||||
whisper_model_size = os.environ.get("WHISPER_MODEL_SIZE",model_size)
|
||||
if not MLX_WHISPER_AVAILABLE:
|
||||
logger.warning("MLX Whisper 不可用,回退到 fast-whisper")
|
||||
whisper_model_size = os.environ.get("WHISPER_MODEL_SIZE",model_size)
|
||||
return get_whisper_transcriber(whisper_model_size, device=device)
|
||||
return get_mlx_whisper_transcriber(whisper_model_size)
|
||||
elif transcriber_type == "bcut":
|
||||
|
||||
Reference in New Issue
Block a user