Files
BiliNote/backend/app/transcriber/mlx_whisper_transcriber.py
huangjianwu be5e1637fa fix(mlx-whisper): 修正 huggingface 仓库 ID 命名
mlx-community 上 Whisper 仓库的命名实际是 'whisper-{size}-mlx'(large-v3-turbo 例外,无 -mlx 后缀)。
之前 hardcode 拼成 'mlx-community/whisper-{size}' 在 HF 上不存在,下载会 404:

  Repository Not Found for url:
    https://huggingface.co/api/models/mlx-community/whisper-small/revision/main.

修复:
- 在 mlx_whisper_transcriber.py 加 MLX_MODEL_MAP(已用 huggingface API 核对过命名)+ resolve_mlx_repo_id() 帮助函数
- routers/config.py 的 _do_download_mlx_whisper 与 _check ... 路径生成都改用同一份映射表
- 给 transcriber_models_status 的每条 mlx 状态加 available 字段,避免后续若有不支持的 size 时静默失败

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
2026-05-07 11:59:02 +08:00

113 lines
4.1 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
import mlx_whisper
from pathlib import Path
import os
import platform
from huggingface_hub import snapshot_download
from app.decorators.timeit import timeit
from app.models.transcriber_model import TranscriptSegment, TranscriptResult
from app.transcriber.base import Transcriber
from app.utils.logger import get_logger
from app.utils.path_helper import get_model_dir
from events import transcription_finished
# Module-level logger obtained from the project's logging helper.
logger = get_logger(__name__)
# Whisper repositories on mlx-community are not named consistently: regular
# sizes use 'whisper-{size}-mlx', while large-v3-turbo has no '-mlx' suffix.
# Naively building 'mlx-community/whisper-{size}' therefore returns a 404.
# Names were cross-checked against
# https://huggingface.co/api/models?author=mlx-community&search=whisper
MLX_MODEL_MAP = {
    "tiny": "mlx-community/whisper-tiny-mlx",
    "base": "mlx-community/whisper-base-mlx",
    "small": "mlx-community/whisper-small-mlx",
    "medium": "mlx-community/whisper-medium-mlx",
    "large-v1": "mlx-community/whisper-large-v1-mlx",
    "large-v2": "mlx-community/whisper-large-v2-mlx",
    "large-v3": "mlx-community/whisper-large-v3-mlx",
    "large-v3-turbo": "mlx-community/whisper-large-v3-turbo",
}


def resolve_mlx_repo_id(model_size: str) -> str:
    """Return the Hugging Face repo ID of the MLX Whisper model for *model_size*.

    Args:
        model_size: A key of ``MLX_MODEL_MAP``, e.g. ``"base"`` or
            ``"large-v3-turbo"``.

    Returns:
        The ``mlx-community/...`` repository ID for that size.

    Raises:
        ValueError: If *model_size* is not a supported size; the message
            lists every supported size.
    """
    try:
        return MLX_MODEL_MAP[model_size]
    except KeyError:
        supported = ", ".join(MLX_MODEL_MAP)
        # Separate the offending value from the list of options (previously
        # the two message fragments ran together), and hide the internal
        # KeyError so callers only see the ValueError.
        raise ValueError(
            f"不支持的 MLX Whisper 模型大小: {model_size}; 可选: {supported}"
        ) from None
class MLXWhisperTranscriber(Transcriber):
    """Transcriber backed by Apple's MLX port of Whisper (``mlx_whisper``).

    On construction, the model snapshot for ``model_size`` is downloaded from
    Hugging Face into the project's ``mlx-whisper`` model directory if that
    directory does not exist yet. Transcription then loads the model from
    that local copy.
    """

    def __init__(
        self,
        model_size: str = "base"
    ):
        """Validate the runtime environment and make sure the model is on disk.

        Args:
            model_size: A size key accepted by ``resolve_mlx_repo_id``
                (e.g. ``"base"``, ``"large-v3-turbo"``).

        Raises:
            RuntimeError: If not running on macOS (``Darwin``), or if the
                ``TRANSCRIBER_TYPE`` environment variable is not
                ``"mlx-whisper"``.
            ValueError: If ``model_size`` is not a supported size.
        """
        # MLX only runs on Apple platforms.
        if platform.system() != "Darwin":
            raise RuntimeError("MLX Whisper 仅支持 Apple 平台")
        # Require an explicit opt-in through the environment.
        if os.environ.get("TRANSCRIBER_TYPE") != "mlx-whisper":
            raise RuntimeError("必须设置环境变量 TRANSCRIBER_TYPE=mlx-whisper 才能使用 MLX Whisper")

        self.model_size = model_size
        # Hugging Face repo ID, e.g. "mlx-community/whisper-base-mlx".
        self.model_name = resolve_mlx_repo_id(model_size)
        # Local snapshot directory: <mlx-whisper model dir>/<org>/<repo>.
        model_dir = get_model_dir("mlx-whisper")
        self.model_path = os.path.join(model_dir, self.model_name)

        # NOTE(review): only the directory's existence is checked, so a
        # previously interrupted download is not retried — confirm whether a
        # completeness check is wanted.
        if not Path(self.model_path).exists():
            logger.info(f"模型 {self.model_name} 不存在,开始下载...")
            # local_dir_use_symlinks is deprecated (and ignored) in recent
            # huggingface_hub releases; kept for older installed versions.
            snapshot_download(
                self.model_name,
                local_dir=self.model_path,
                local_dir_use_symlinks=False,
            )
            logger.info("模型下载完成")
        logger.info(f"初始化 MLX Whisper 转录器,模型:{self.model_name}")

    @timeit
    def transcript(self, file_path: str) -> TranscriptResult:
        """Transcribe the media file at *file_path* with MLX Whisper.

        Args:
            file_path: Path of the file handed to ``mlx_whisper.transcribe``.

        Returns:
            A ``TranscriptResult`` containing:
                - the detected language (``"unknown"`` if absent);
                - the space-joined, stripped segment texts;
                - one ``TranscriptSegment`` per segment;
                - the raw ``mlx_whisper`` result.

        Raises:
            Exception: Any error raised while transcribing or converting the
                result is logged and re-raised unchanged.
        """
        try:
            # Load the model from the local snapshot managed in __init__.
            # mlx_whisper treats a path that does not exist locally as a Hub
            # repo ID and downloads it into the default Hugging Face cache.
            # Passing the repo ID here therefore ignored the local copy.
            result = mlx_whisper.transcribe(
                file_path,
                path_or_hf_repo=self.model_path,
            )

            # Convert to the project's standard format.
            segments = []
            texts = []
            for segment in result["segments"]:
                text = segment["text"].strip()
                texts.append(text)
                segments.append(TranscriptSegment(
                    start=segment["start"],
                    end=segment["end"],
                    text=text,
                ))

            # NOTE(review): on_finish() is not invoked here; presumably the
            # caller fires transcription_finished — confirm.
            return TranscriptResult(
                language=result.get("language", "unknown"),
                full_text=" ".join(texts).strip(),
                segments=segments,
                raw=result,
            )
        except Exception as e:
            logger.error(f"MLX Whisper 转写失败:{e}")
            raise

    def on_finish(self, video_path: str, result: TranscriptResult) -> None:
        """Log completion and emit the ``transcription_finished`` event.

        Args:
            video_path: Path of the transcribed file; sent as ``file_path``
                in the event payload.
            result: The transcription result (not included in the payload).
        """
        logger.info("MLX Whisper 转写完成")
        transcription_finished.send({
            "file_path": video_path,
        })