diff --git a/.env.example b/.env.example
index fe158f5..5824a49 100644
--- a/.env.example
+++ b/.env.example
@@ -28,5 +28,5 @@ MODEl_PROVIDER= #如果不是openai 请修改 deepseek/qwen
 FFMPEG_BIN_PATH=
 # transcriber 相关配置
-TRANSCRIBER_TYPE=fast-whisper # fast-whisper/bcut/kuaishou
+TRANSCRIBER_TYPE=fast-whisper # fast-whisper/bcut/kuaishou/mlx-whisper(仅Apple平台)
 WHISPER_MODEL_SIZE=base
\ No newline at end of file
diff --git a/backend/app/services/note.py b/backend/app/services/note.py
index eb85fd9..e2dc05d 100644
--- a/backend/app/services/note.py
+++ b/backend/app/services/note.py
@@ -89,9 +89,9 @@ class NoteGenerator:
         :param transcriber: 选择的转义器
         :return:
         '''
-        if self.transcriber_type == 'fast-whisper':
-            logger.info("使用Whisper")
-            return get_transcriber(transcriber_type='fast-whisper')
+        if self.transcriber_type in ('fast-whisper', 'bcut', 'kuaishou', 'mlx-whisper'):
+            logger.info(f"使用{self.transcriber_type}转义器")
+            return get_transcriber(transcriber_type=self.transcriber_type)
         else:
             logger.warning("不支持的转义器")
-            raise ValueError(f"不支持的转义器:{self.transcriber}")
+            raise ValueError(f"不支持的转义器:{self.transcriber_type}")
diff --git a/backend/app/transcriber/mlx_whisper_transcriber.py b/backend/app/transcriber/mlx_whisper_transcriber.py
new file mode 100644
index 0000000..b253acc
--- /dev/null
+++ b/backend/app/transcriber/mlx_whisper_transcriber.py
@@ -0,0 +1,88 @@
+import mlx_whisper
+from pathlib import Path
+import os
+import platform
+from huggingface_hub import snapshot_download
+
+from app.decorators.timeit import timeit
+from app.models.transcriber_model import TranscriptSegment, TranscriptResult
+from app.transcriber.base import Transcriber
+from app.utils.logger import get_logger
+from app.utils.path_helper import get_model_dir
+from events import transcription_finished
+
+logger = get_logger(__name__)
+
+class MLXWhisperTranscriber(Transcriber):
+    def __init__(
+        self,
+        model_size: str = "base"
+    ):
+        # 检查平台
+        if platform.system() != "Darwin":
+            raise RuntimeError("MLX Whisper 仅支持 Apple 平台")
+
+        # 检查环境变量
+        if os.environ.get("TRANSCRIBER_TYPE") != "mlx-whisper":
+            raise RuntimeError("必须设置环境变量 TRANSCRIBER_TYPE=mlx-whisper 才能使用 MLX Whisper")
+
+        self.model_size = model_size
+        self.model_name = f"mlx-community/whisper-{model_size}"
+        self.model_path = None
+
+        # 设置模型路径
+        model_dir = get_model_dir("mlx-whisper")
+        self.model_path = os.path.join(model_dir, self.model_name)
+        # 检查并下载模型
+        if not Path(self.model_path).exists():
+            logger.info(f"模型 {self.model_name} 不存在,开始下载...")
+            snapshot_download(
+                self.model_name,
+                local_dir=self.model_path,
+                local_dir_use_symlinks=False,
+            )
+            logger.info("模型下载完成")
+
+        logger.info(f"初始化 MLX Whisper 转录器,模型:{self.model_name}")
+
+    @timeit
+    def transcript(self, file_path: str) -> TranscriptResult:
+        try:
+            # 使用 MLX Whisper 进行转录
+            result = mlx_whisper.transcribe(
+                file_path,
+                path_or_hf_repo=f"{self.model_name}"
+            )
+
+            # 转换为标准格式
+            segments = []
+            full_text = ""
+
+            for segment in result["segments"]:
+                text = segment["text"].strip()
+                full_text += text + " "
+                segments.append(TranscriptSegment(
+                    start=segment["start"],
+                    end=segment["end"],
+                    text=text
+                ))
+
+            transcript_result = TranscriptResult(
+                language=result.get("language", "unknown"),
+                full_text=full_text.strip(),
+                segments=segments,
+                raw=result
+            )
+
+            self.on_finish(file_path, transcript_result)
+            return transcript_result
+
+        except Exception as e:
+            logger.error(f"MLX Whisper 转写失败:{e}")
+            raise e
+
+    def on_finish(self, video_path: str, result: TranscriptResult) -> None:
+        logger.info("MLX Whisper 转写完成")
+        transcription_finished.send({
+            "file_path": video_path,
+        })
\ No newline at end of file
diff --git a/backend/app/transcriber/transcriber_provider.py b/backend/app/transcriber/transcriber_provider.py
index a59af19..0edcd40 100644
--- a/backend/app/transcriber/transcriber_provider.py
+++ b/backend/app/transcriber/transcriber_provider.py
@@ -1,4 +1,5 @@
 import os
+import platform
 
 from app.transcriber.whisper import WhisperTranscriber
 from app.transcriber.bcut import BcutTranscriber
@@ -6,13 +7,26 @@ from app.transcriber.kuaishou import KuaishouTranscriber
 from app.utils.logger import get_logger
 
 logger = get_logger(__name__)
+# 只在Apple平台且设置了环境变量时才导入MLX Whisper
+if platform.system() == "Darwin" and os.environ.get("TRANSCRIBER_TYPE") == "mlx-whisper":
+    try:
+        from app.transcriber.mlx_whisper_transcriber import MLXWhisperTranscriber
+        MLX_WHISPER_AVAILABLE = True
+        logger.info("MLX Whisper 可用,已导入")
+    except ImportError:
+        MLX_WHISPER_AVAILABLE = False
+        logger.warning("MLX Whisper 导入失败,可能未安装或平台不支持")
+else:
+    MLX_WHISPER_AVAILABLE = False
+
 logger.info('初始化转录服务提供器')
 
 # 维护各种转录器的单例实例
 _transcribers = {
     'whisper': None,
     'bcut': None,
-    'kuaishou': None
+    'kuaishou': None,
+    'mlx-whisper': None
 }
 
 def get_whisper_transcriber(model_size="base", device="cuda"):
@@ -51,13 +65,29 @@
         raise
     return _transcribers['kuaishou']
 
+def get_mlx_whisper_transcriber(model_size="base"):
+    """获取 MLX Whisper 转录器实例"""
+    if not MLX_WHISPER_AVAILABLE:
+        logger.error("MLX Whisper 不可用,请确保在Apple平台且已安装mlx_whisper")
+        raise ImportError("MLX Whisper 不可用,请确保在Apple平台且已安装mlx_whisper")
+
+    if _transcribers['mlx-whisper'] is None:
+        logger.info(f'创建 MLX Whisper 转录器实例,参数:{model_size}')
+        try:
+            _transcribers['mlx-whisper'] = MLXWhisperTranscriber(model_size=model_size)
+            logger.info('MLX Whisper 转录器创建成功')
+        except Exception as e:
+            logger.error(f"MLX Whisper 转录器创建失败: {e}")
+            raise
+    return _transcribers['mlx-whisper']
+
 def get_transcriber(transcriber_type="fast-whisper", model_size="base", device="cuda"):
     """
     获取指定类型的转录器实例
 
     参数:
-        transcriber_type: 转录器类型,支持 "fast-whisper", "bcut", "kuaishou"
-        model_size: 模型大小,whisper 特有参数
+        transcriber_type: 转录器类型,支持 "fast-whisper", "bcut", "kuaishou", "mlx-whisper"(仅Apple平台)
+        model_size: 模型大小,whisper 和 mlx-whisper 特有参数
         device: 设备类型,whisper 特有参数
 
     返回:
@@ -67,6 +97,12 @@ def get_transcriber(transcriber_type="fast-whisper", model_size="base", device="
     if transcriber_type == "fast-whisper":
         whisper_model_size = os.environ.get("WHISPER_MODEL_SIZE",model_size)
         return get_whisper_transcriber(whisper_model_size, device=device)
+    elif transcriber_type == "mlx-whisper":
+        whisper_model_size = os.environ.get("WHISPER_MODEL_SIZE",model_size)
+        if not MLX_WHISPER_AVAILABLE:
+            logger.warning("MLX Whisper 不可用,回退到 fast-whisper")
+            return get_whisper_transcriber(whisper_model_size, device=device)
+        return get_mlx_whisper_transcriber(whisper_model_size)
     elif transcriber_type == "bcut":
         return get_bcut_transcriber()
     elif transcriber_type == "kuaishou":