Merge pull request #2 from SurfRid3r/dev_surfrid3r_mlxwhisper

feat: 引入 MLX‑Whisper 转录器,提升 macOS 平台 Whisper 转录性能
This commit is contained in:
SurfRid3r
2025-04-20 23:03:50 +08:00
committed by GitHub
4 changed files with 132 additions and 8 deletions

View File

@@ -28,5 +28,5 @@ MODEl_PROVIDER= #如果不是openai 请修改 deepseek/qwen
FFMPEG_BIN_PATH=
# transcriber 相关配置
TRANSCRIBER_TYPE=fast-whisper # fast-whisper/bcut/kuaishou
TRANSCRIBER_TYPE=fast-whisper # fast-whisper/bcut/kuaishou/mlx-whisper(仅Apple平台)
WHISPER_MODEL_SIZE=base

View File

@@ -18,7 +18,7 @@ from app.models.notes_model import AudioDownloadResult
from app.enmus.note_enums import DownloadQuality
from app.models.transcriber_model import TranscriptResult
from app.transcriber.base import Transcriber
from app.transcriber.transcriber_provider import get_transcriber
from app.transcriber.transcriber_provider import get_transcriber,_transcribers
from app.transcriber.whisper import WhisperTranscriber
import re
@@ -89,9 +89,9 @@ class NoteGenerator:
:param transcriber: 选择的转义器
:return:
'''
if self.transcriber_type == 'fast-whisper':
logger.info("使用Whisper")
return get_transcriber(transcriber_type='fast-whisper')
if self.transcriber_type in _transcribers.keys():
logger.info(f"使用{self.transcriber_type}转义器")
return get_transcriber(transcriber_type=self.transcriber_type)
else:
logger.warning("不支持的转义器")
raise ValueError(f"不支持的转义器:{self.transcriber}")

View File

@@ -0,0 +1,88 @@
import mlx_whisper
from pathlib import Path
import os
import platform
from huggingface_hub import snapshot_download
from app.decorators.timeit import timeit
from app.models.transcriber_model import TranscriptSegment, TranscriptResult
from app.transcriber.base import Transcriber
from app.utils.logger import get_logger
from app.utils.path_helper import get_model_dir
from events import transcription_finished
logger = get_logger(__name__)
class MLXWhisperTranscriber(Transcriber):
    """Transcriber backed by MLX Whisper (Apple platforms only).

    On first construction the requested whisper model is downloaded from
    the Hugging Face Hub into the project's model directory; subsequent
    runs reuse that local copy.
    """

    def __init__(
            self,
            model_size: str = "base"
    ):
        """
        :param model_size: whisper model size (e.g. "base"), used to pick
            the mlx-community/whisper-<size> repository
        :raises RuntimeError: when not running on macOS, or when the
            TRANSCRIBER_TYPE environment variable is not "mlx-whisper"
        """
        # MLX only runs on Apple hardware — fail fast elsewhere.
        if platform.system() != "Darwin":
            raise RuntimeError("MLX Whisper 仅支持 Apple 平台")

        # Guard against accidental construction when a different
        # transcriber backend was configured.
        if os.environ.get("TRANSCRIBER_TYPE") != "mlx-whisper":
            raise RuntimeError("必须设置环境变量 TRANSCRIBER_TYPE=mlx-whisper 才能使用 MLX Whisper")

        self.model_size = model_size
        self.model_name = f"mlx-community/whisper-{model_size}"

        # Local cache location for the downloaded model snapshot.
        model_dir = get_model_dir("mlx-whisper")
        self.model_path = os.path.join(model_dir, self.model_name)

        # Download the model once; later runs find the directory and skip.
        if not Path(self.model_path).exists():
            logger.info(f"模型 {self.model_name} 不存在,开始下载...")
            snapshot_download(
                self.model_name,
                local_dir=self.model_path,
                local_dir_use_symlinks=False,
            )
            logger.info("模型下载完成")

        logger.info(f"初始化 MLX Whisper 转录器,模型:{self.model_name}")

    @timeit
    def transcript(self, file_path: str) -> TranscriptResult:
        """Transcribe a media file and return the normalized result.

        :param file_path: path of the audio/video file to transcribe
        :return: TranscriptResult with language, full text and segments
        :raises Exception: re-raises any failure from mlx_whisper
        """
        try:
            # Fix: point mlx_whisper at the locally downloaded snapshot.
            # Passing the repo id here made mlx_whisper re-download the
            # model into its own cache, so the snapshot taken in
            # __init__ was never used.
            result = mlx_whisper.transcribe(
                file_path,
                path_or_hf_repo=self.model_path
            )

            # Convert raw mlx_whisper output into the shared schema.
            segments = []
            texts = []
            for segment in result["segments"]:
                text = segment["text"].strip()
                texts.append(text)
                segments.append(TranscriptSegment(
                    start=segment["start"],
                    end=segment["end"],
                    text=text
                ))

            transcript_result = TranscriptResult(
                language=result.get("language", "unknown"),
                full_text=" ".join(texts),
                segments=segments,
                raw=result
            )

            self.on_finish(file_path, transcript_result)
            return transcript_result
        except Exception as e:
            logger.error(f"MLX Whisper 转写失败:{e}")
            # Preserve the original traceback for the caller.
            raise

    def on_finish(self, video_path: str, result: TranscriptResult) -> None:
        """Emit the transcription_finished event after a successful run."""
        logger.info("MLX Whisper 转写完成")
        transcription_finished.send({
            "file_path": video_path,
        })

View File

@@ -1,4 +1,5 @@
import os
import platform
from app.transcriber.whisper import WhisperTranscriber
from app.transcriber.bcut import BcutTranscriber
@@ -6,13 +7,26 @@ from app.transcriber.kuaishou import KuaishouTranscriber
from app.utils.logger import get_logger
logger = get_logger(__name__)
# Import MLX Whisper only on Apple platforms AND only when it is the
# configured transcriber, so other platforms/configs never attempt the
# (possibly missing) third-party import.
if platform.system() == "Darwin" and os.environ.get("TRANSCRIBER_TYPE") == "mlx-whisper":
    try:
        from app.transcriber.mlx_whisper_transcriber import MLXWhisperTranscriber
        # Checked by get_mlx_whisper_transcriber() before instantiating.
        MLX_WHISPER_AVAILABLE = True
        logger.info("MLX Whisper 可用,已导入")
    except ImportError:
        # Package not installed (or import failed) — mark unavailable
        # instead of crashing module import.
        MLX_WHISPER_AVAILABLE = False
        logger.warning("MLX Whisper 导入失败,可能未安装或平台不支持")
else:
    MLX_WHISPER_AVAILABLE = False
logger.info('初始化转录服务提供器')
# 维护各种转录器的单例实例
_transcribers = {
'whisper': None,
'bcut': None,
'kuaishou': None
'kuaishou': None,
'mlx-whisper': None
}
def get_whisper_transcriber(model_size="base", device="cuda"):
@@ -51,13 +65,29 @@ def get_kuaishou_transcriber():
raise
return _transcribers['kuaishou']
def get_mlx_whisper_transcriber(model_size="base"):
    """Return the singleton MLX Whisper transcriber, creating it on first use.

    :param model_size: whisper model size forwarded to MLXWhisperTranscriber
    :raises ImportError: when the MLX backend could not be imported
    """
    # Fail fast when the optional MLX backend is unavailable.
    if not MLX_WHISPER_AVAILABLE:
        logger.warning("MLX Whisper 不可用请确保在Apple平台且已安装mlx_whisper")
        raise ImportError("MLX Whisper 不可用请确保在Apple平台且已安装mlx_whisper")

    # Lazily construct and cache the shared instance.
    instance = _transcribers['mlx-whisper']
    if instance is None:
        logger.info(f'创建 MLX Whisper 转录器实例,参数:{model_size}')
        try:
            instance = MLXWhisperTranscriber(model_size=model_size)
        except Exception as e:
            logger.error(f"MLX Whisper 转录器创建失败: {e}")
            raise
        logger.info('MLX Whisper 转录器创建成功')
        _transcribers['mlx-whisper'] = instance
    return instance
def get_transcriber(transcriber_type="fast-whisper", model_size="base", device="cuda"):
"""
获取指定类型的转录器实例
参数:
transcriber_type: 转录器类型,支持 "fast-whisper", "bcut", "kuaishou"
model_size: 模型大小whisper 特有参数
transcriber_type: 转录器类型,支持 "fast-whisper", "bcut", "kuaishou", "mlx-whisper"(仅Apple平台)
model_size: 模型大小whisper 和 mlx-whisper 特有参数
device: 设备类型whisper 特有参数
返回:
@@ -67,6 +97,12 @@ def get_transcriber(transcriber_type="fast-whisper", model_size="base", device="
if transcriber_type == "fast-whisper":
whisper_model_size = os.environ.get("WHISPER_MODEL_SIZE",model_size)
return get_whisper_transcriber(whisper_model_size, device=device)
elif transcriber_type == "mlx-whisper":
whisper_model_size = os.environ.get("WHISPER_MODEL_SIZE",model_size)
if not MLX_WHISPER_AVAILABLE:
logger.warning("MLX Whisper 不可用,回退到 fast-whisper")
return get_whisper_transcriber(whisper_model_size, device=device)
return get_mlx_whisper_transcriber(whisper_model_size)
elif transcriber_type == "bcut":
return get_bcut_transcriber()
elif transcriber_type == "kuaishou":