From 369de195728adfc3070ddf60f7c7186aaa424282 Mon Sep 17 00:00:00 2001 From: SurfRid3r Date: Sun, 20 Apr 2025 00:37:48 +0800 Subject: [PATCH] =?UTF-8?q?feat(transcriber):=20=E6=B7=BB=E5=8A=A0=20MLX?= =?UTF-8?q?=20Whisper=20=E8=BD=AC=E5=BD=95=E5=99=A8=E6=94=AF=E6=8C=81?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - 新增 MLXWhisperTranscriber 类,支持在 Apple 平台上进行转录。 - 更新 transcriber_provider.py,动态导入 MLX Whisper 转录器并添加相应的环境变量检查。 - 修改 .env.example 文件,更新 TRANSCRIBER_TYPE 配置说明以包含 mlx-whisper 选项。 --- .env.example | 2 +- .../transcriber/mlx_whisper_transcriber.py | 82 +++++++++++++++++++ .../app/transcriber/transcriber_provider.py | 42 +++++++++- 3 files changed, 122 insertions(+), 4 deletions(-) create mode 100644 backend/app/transcriber/mlx_whisper_transcriber.py diff --git a/.env.example b/.env.example index fe158f5..5824a49 100644 --- a/.env.example +++ b/.env.example @@ -28,5 +28,5 @@ MODEl_PROVIDER= #如果不是openai 请修改 deepseek/qwen FFMPEG_BIN_PATH= # transcriber 相关配置 -TRANSCRIBER_TYPE=fast-whisper # fast-whisper/bcut/kuaishou +TRANSCRIBER_TYPE=fast-whisper # fast-whisper/bcut/kuaishou/mlx-whisper(仅Apple平台) WHISPER_MODEL_SIZE=base \ No newline at end of file diff --git a/backend/app/transcriber/mlx_whisper_transcriber.py b/backend/app/transcriber/mlx_whisper_transcriber.py new file mode 100644 index 0000000..6764748 --- /dev/null +++ b/backend/app/transcriber/mlx_whisper_transcriber.py @@ -0,0 +1,82 @@ +import mlx_whisper +from pathlib import Path +import os +import platform + +from app.decorators.timeit import timeit +from app.models.transcriber_model import TranscriptSegment, TranscriptResult +from app.transcriber.base import Transcriber +from app.utils.logger import get_logger +from app.utils.path_helper import get_model_dir +from events import transcription_finished + +logger = get_logger(__name__) + +class MLXWhisperTranscriber(Transcriber): + def __init__( + self, + model_size: str = "base", + device: str = None, # MLX 会自动选择最佳设备 + ): + # 检查平台 + if platform.system() != "Darwin": + raise RuntimeError("MLX Whisper 仅支持 Apple 平台") + + # 检查环境变量 + if os.environ.get("TRANSCRIBER_TYPE") != "mlx-whisper": + raise RuntimeError("必须设置环境变量 TRANSCRIBER_TYPE=mlx-whisper 才能使用 MLX Whisper") + + self.model_size = model_size + self.model_name = f"whisper-{model_size}-mlx" + self.model_path = None + + # 设置模型路径 + model_dir = get_model_dir("mlx-whisper") + self.model_path = os.path.join(model_dir, self.model_name) + + # 确保模型目录存在 + Path(model_dir).mkdir(parents=True, exist_ok=True) + + logger.info(f"初始化 MLX Whisper 转录器,模型:{self.model_name}") + + @timeit + def transcript(self, file_path: str) -> TranscriptResult: + try: + # 使用 MLX Whisper 进行转录 + result = mlx_whisper.transcribe( + file_path, + path_or_hf_repo=f"mlx-community/{self.model_name}" + ) + + # 转换为标准格式 + segments = [] + full_text = "" + + for segment in result["segments"]: + text = segment["text"].strip() + full_text += text + " " + segments.append(TranscriptSegment( + start=segment["start"], + end=segment["end"], + text=text + )) + + transcript_result = TranscriptResult( + language=result.get("language", "unknown"), + full_text=full_text.strip(), + segments=segments, + raw=result + ) + + self.on_finish(file_path, transcript_result) + return transcript_result + + except Exception as e: + logger.error(f"MLX Whisper 转写失败:{e}") + raise e + + def on_finish(self, video_path: str, result: TranscriptResult) -> None: + logger.info("MLX Whisper 转写完成") + transcription_finished.send({ + "file_path": video_path, + }) \ No newline at end of file diff --git a/backend/app/transcriber/transcriber_provider.py b/backend/app/transcriber/transcriber_provider.py index a59af19..347672b 100644 --- a/backend/app/transcriber/transcriber_provider.py +++ b/backend/app/transcriber/transcriber_provider.py @@ -1,4 +1,5 @@ import os +import platform from app.transcriber.whisper import WhisperTranscriber from app.transcriber.bcut import BcutTranscriber @@ -6,13 +7,26 @@ from app.transcriber.kuaishou import KuaishouTranscriber from app.utils.logger import get_logger logger = get_logger(__name__) +# 只在Apple平台且设置了环境变量时才导入MLX Whisper +if platform.system() == "Darwin" and os.environ.get("TRANSCRIBER_TYPE") == "mlx-whisper": + try: + from app.transcriber.mlx_whisper_transcriber import MLXWhisperTranscriber + MLX_WHISPER_AVAILABLE = True + logger.info("MLX Whisper 可用,已导入") + except ImportError: + MLX_WHISPER_AVAILABLE = False + logger.warning("MLX Whisper 导入失败,可能未安装或平台不支持") +else: + MLX_WHISPER_AVAILABLE = False + logger.info('初始化转录服务提供器') # 维护各种转录器的单例实例 _transcribers = { 'whisper': None, 'bcut': None, - 'kuaishou': None + 'kuaishou': None, + 'mlx-whisper': None } def get_whisper_transcriber(model_size="base", device="cuda"): @@ -51,13 +65,29 @@ def get_kuaishou_transcriber(): raise return _transcribers['kuaishou'] +def get_mlx_whisper_transcriber(model_size="base"): + """获取 MLX Whisper 转录器实例""" + if not MLX_WHISPER_AVAILABLE: + logger.warning("MLX Whisper 不可用,请确保在Apple平台且已安装mlx_whisper") + raise ImportError("MLX Whisper 不可用,请确保在Apple平台且已安装mlx_whisper") + + if _transcribers['mlx-whisper'] is None: + logger.info(f'创建 MLX Whisper 转录器实例,参数:{model_size}') + try: + _transcribers['mlx-whisper'] = MLXWhisperTranscriber(model_size=model_size) + logger.info('MLX Whisper 转录器创建成功') + except Exception as e: + logger.error(f"MLX Whisper 转录器创建失败: {e}") + raise + return _transcribers['mlx-whisper'] + def get_transcriber(transcriber_type="fast-whisper", model_size="base", device="cuda"): """ 获取指定类型的转录器实例 参数: - transcriber_type: 转录器类型,支持 "fast-whisper", "bcut", "kuaishou" - model_size: 模型大小,whisper 特有参数 + transcriber_type: 转录器类型,支持 "fast-whisper", "bcut", "kuaishou", "mlx-whisper"(仅Apple平台) + model_size: 模型大小,whisper 和 mlx-whisper 特有参数 device: 设备类型,whisper 特有参数 返回: @@ -67,6 +97,12 @@ def get_transcriber(transcriber_type="fast-whisper", model_size="base", device=" if transcriber_type == "fast-whisper": whisper_model_size = os.environ.get("WHISPER_MODEL_SIZE",model_size) return get_whisper_transcriber(whisper_model_size, device=device) + elif transcriber_type == "mlx-whisper": + if not MLX_WHISPER_AVAILABLE: + logger.warning("MLX Whisper 不可用,回退到 fast-whisper") + whisper_model_size = os.environ.get("WHISPER_MODEL_SIZE",model_size) + return get_whisper_transcriber(whisper_model_size, device=device) + return get_mlx_whisper_transcriber(whisper_model_size) elif transcriber_type == "bcut": return get_bcut_transcriber() elif transcriber_type == "kuaishou":