mirror of
https://github.com/JefferyHcool/BiliNote.git
synced 2026-05-06 20:42:52 +08:00
feat(transcriber): 添加 MLX Whisper 转录器支持
- 新增 MLXWhisperTranscriber 类,支持在 Apple 平台上进行转录。 - 更新 transcriber_provider.py,动态导入 MLX Whisper 转录器并添加相应的环境变量检查。 - 修改 .env.example 文件,更新 TRANSCRIBER_TYPE 配置说明以包含 mlx-whisper 选项。
This commit is contained in:
@@ -28,5 +28,5 @@ MODEl_PROVIDER= #如果不是openai 请修改 deepseek/qwen
|
||||
FFMPEG_BIN_PATH=
|
||||
|
||||
# transcriber 相关配置
|
||||
TRANSCRIBER_TYPE=fast-whisper # fast-whisper/bcut/kuaishou
|
||||
TRANSCRIBER_TYPE=fast-whisper # fast-whisper/bcut/kuaishou/mlx-whisper(仅Apple平台)
|
||||
WHISPER_MODEL_SIZE=base
|
||||
82
backend/app/transcriber/mlx_whisper_transcriber.py
Normal file
82
backend/app/transcriber/mlx_whisper_transcriber.py
Normal file
@@ -0,0 +1,82 @@
|
||||
import mlx_whisper
|
||||
from pathlib import Path
|
||||
import os
|
||||
import platform
|
||||
|
||||
from app.decorators.timeit import timeit
|
||||
from app.models.transcriber_model import TranscriptSegment, TranscriptResult
|
||||
from app.transcriber.base import Transcriber
|
||||
from app.utils.logger import get_logger
|
||||
from app.utils.path_helper import get_model_dir
|
||||
from events import transcription_finished
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
class MLXWhisperTranscriber(Transcriber):
    """Transcriber backed by MLX Whisper.

    Only usable on Apple platforms (macOS / Darwin) and only when the
    environment explicitly opts in via ``TRANSCRIBER_TYPE=mlx-whisper``;
    construction raises ``RuntimeError`` otherwise.
    """

    def __init__(
        self,
        model_size: str = "base",
        device: str = None,  # intentionally unused: MLX picks the best device itself
    ):
        """Initialize the MLX Whisper transcriber.

        Args:
            model_size: Whisper model size (e.g. "base"); used to build the
                model name ``whisper-{model_size}-mlx``.
            device: Ignored. Kept only for signature compatibility with the
                other transcribers — MLX selects the device automatically.

        Raises:
            RuntimeError: if not running on macOS, or if the environment
                variable ``TRANSCRIBER_TYPE`` is not set to ``mlx-whisper``.
        """
        # MLX only runs on Apple platforms.
        if platform.system() != "Darwin":
            raise RuntimeError("MLX Whisper 仅支持 Apple 平台")

        # Guard against accidental construction when a different transcriber
        # backend is configured.
        if os.environ.get("TRANSCRIBER_TYPE") != "mlx-whisper":
            raise RuntimeError("必须设置环境变量 TRANSCRIBER_TYPE=mlx-whisper 才能使用 MLX Whisper")

        self.model_size = model_size
        self.model_name = f"whisper-{model_size}-mlx"

        # Local model cache directory; created up-front so downloads have a home.
        model_dir = get_model_dir("mlx-whisper")
        self.model_path = os.path.join(model_dir, self.model_name)
        Path(model_dir).mkdir(parents=True, exist_ok=True)

        logger.info(f"初始化 MLX Whisper 转录器,模型:{self.model_name}")

    @timeit
    def transcript(self, file_path: str) -> TranscriptResult:
        """Transcribe a media file and return the normalized result.

        Args:
            file_path: Path to the audio/video file to transcribe.

        Returns:
            TranscriptResult with detected language, full text, per-segment
            timings, and the raw MLX result.

        Raises:
            Exception: any failure from mlx_whisper is logged and re-raised.
        """
        try:
            # NOTE(review): model is resolved via the mlx-community HF repo;
            # self.model_path computed in __init__ is not consulted here —
            # confirm whether a local-path lookup was intended.
            result = mlx_whisper.transcribe(
                file_path,
                path_or_hf_repo=f"mlx-community/{self.model_name}"
            )

            # Normalize raw segments into the project's standard format.
            segments = []
            texts = []
            for segment in result["segments"]:
                text = segment["text"].strip()
                texts.append(text)
                segments.append(TranscriptSegment(
                    start=segment["start"],
                    end=segment["end"],
                    text=text
                ))

            transcript_result = TranscriptResult(
                language=result.get("language", "unknown"),
                # join avoids quadratic string concatenation in the loop
                full_text=" ".join(texts).strip(),
                segments=segments,
                raw=result
            )

            self.on_finish(file_path, transcript_result)
            return transcript_result

        except Exception as e:
            logger.error(f"MLX Whisper 转写失败:{e}")
            raise  # bare raise preserves the original traceback

    def on_finish(self, video_path: str, result: TranscriptResult) -> None:
        """Notify listeners that transcription finished for ``video_path``.

        ``result`` is accepted for interface symmetry but not forwarded in
        the event payload.
        """
        logger.info("MLX Whisper 转写完成")
        transcription_finished.send({
            "file_path": video_path,
        })
|
||||
@@ -1,4 +1,5 @@
|
||||
import os
|
||||
import platform
|
||||
|
||||
from app.transcriber.whisper import WhisperTranscriber
|
||||
from app.transcriber.bcut import BcutTranscriber
|
||||
@@ -6,13 +7,26 @@ from app.transcriber.kuaishou import KuaishouTranscriber
|
||||
from app.utils.logger import get_logger
|
||||
logger = get_logger(__name__)

# Import MLX Whisper only when explicitly enabled via TRANSCRIBER_TYPE and
# the process is running on an Apple platform; otherwise it stays unavailable.
MLX_WHISPER_AVAILABLE = False
if os.environ.get("TRANSCRIBER_TYPE") == "mlx-whisper" and platform.system() == "Darwin":
    try:
        from app.transcriber.mlx_whisper_transcriber import MLXWhisperTranscriber
    except ImportError:
        logger.warning("MLX Whisper 导入失败,可能未安装或平台不支持")
    else:
        MLX_WHISPER_AVAILABLE = True
        logger.info("MLX Whisper 可用,已导入")

logger.info('初始化转录服务提供器')
||||
|
||||
# 维护各种转录器的单例实例
|
||||
_transcribers = {
|
||||
'whisper': None,
|
||||
'bcut': None,
|
||||
'kuaishou': None
|
||||
'kuaishou': None,
|
||||
'mlx-whisper': None
|
||||
}
|
||||
|
||||
def get_whisper_transcriber(model_size="base", device="cuda"):
|
||||
@@ -51,13 +65,29 @@ def get_kuaishou_transcriber():
|
||||
raise
|
||||
return _transcribers['kuaishou']
|
||||
|
||||
def get_mlx_whisper_transcriber(model_size="base"):
    """Return the singleton MLX Whisper transcriber, creating it on first use.

    Raises ImportError when MLX Whisper is unavailable (non-Apple platform
    or the mlx_whisper package is not installed).
    """
    # Guard clause: the backend must have been imported at module load.
    if not MLX_WHISPER_AVAILABLE:
        logger.warning("MLX Whisper 不可用,请确保在Apple平台且已安装mlx_whisper")
        raise ImportError("MLX Whisper 不可用,请确保在Apple平台且已安装mlx_whisper")

    instance = _transcribers['mlx-whisper']
    if instance is None:
        logger.info(f'创建 MLX Whisper 转录器实例,参数:{model_size}')
        try:
            instance = MLXWhisperTranscriber(model_size=model_size)
            logger.info('MLX Whisper 转录器创建成功')
        except Exception as e:
            logger.error(f"MLX Whisper 转录器创建失败: {e}")
            raise
        _transcribers['mlx-whisper'] = instance
    return instance
|
||||
|
||||
def get_transcriber(transcriber_type="fast-whisper", model_size="base", device="cuda"):
|
||||
"""
|
||||
获取指定类型的转录器实例
|
||||
|
||||
参数:
|
||||
transcriber_type: 转录器类型,支持 "fast-whisper", "bcut", "kuaishou"
|
||||
model_size: 模型大小,whisper 特有参数
|
||||
transcriber_type: 转录器类型,支持 "fast-whisper", "bcut", "kuaishou", "mlx-whisper"(仅Apple平台)
|
||||
model_size: 模型大小,whisper 和 mlx-whisper 特有参数
|
||||
device: 设备类型,whisper 特有参数
|
||||
|
||||
返回:
|
||||
@@ -67,6 +97,12 @@ def get_transcriber(transcriber_type="fast-whisper", model_size="base", device="
|
||||
if transcriber_type == "fast-whisper":
|
||||
whisper_model_size = os.environ.get("WHISPER_MODEL_SIZE",model_size)
|
||||
return get_whisper_transcriber(whisper_model_size, device=device)
|
||||
elif transcriber_type == "mlx-whisper":
|
||||
if not MLX_WHISPER_AVAILABLE:
|
||||
logger.warning("MLX Whisper 不可用,回退到 fast-whisper")
|
||||
whisper_model_size = os.environ.get("WHISPER_MODEL_SIZE",model_size)
|
||||
return get_whisper_transcriber(whisper_model_size, device=device)
|
||||
return get_mlx_whisper_transcriber(whisper_model_size)
|
||||
elif transcriber_type == "bcut":
|
||||
return get_bcut_transcriber()
|
||||
elif transcriber_type == "kuaishou":
|
||||
|
||||
Reference in New Issue
Block a user