Merge remote-tracking branch 'origin/master' into dev

# Conflicts:
#	backend/app/transcriber/transcriber_provider.py
This commit is contained in:
思诺特
2025-04-26 11:19:34 +08:00
7 changed files with 210 additions and 20 deletions

View File

@@ -0,0 +1,88 @@
import mlx_whisper
from pathlib import Path
import os
import platform
from huggingface_hub import snapshot_download
from app.decorators.timeit import timeit
from app.models.transcriber_model import TranscriptSegment, TranscriptResult
from app.transcriber.base import Transcriber
from app.utils.logger import get_logger
from app.utils.path_helper import get_model_dir
from events import transcription_finished
logger = get_logger(__name__)
class MLXWhisperTranscriber(Transcriber):
def __init__(
self,
model_size: str = "base"
):
# 检查平台
if platform.system() != "Darwin":
raise RuntimeError("MLX Whisper 仅支持 Apple 平台")
# 检查环境变量
if os.environ.get("TRANSCRIBER_TYPE") != "mlx-whisper":
raise RuntimeError("必须设置环境变量 TRANSCRIBER_TYPE=mlx-whisper 才能使用 MLX Whisper")
self.model_size = model_size
self.model_name = f"mlx-community/whisper-{model_size}"
self.model_path = None
# 设置模型路径
model_dir = get_model_dir("mlx-whisper")
self.model_path = os.path.join(model_dir, self.model_name)
# 检查并下载模型
if not Path(self.model_path).exists():
logger.info(f"模型 {self.model_name} 不存在,开始下载...")
snapshot_download(
self.model_name,
local_dir=self.model_path,
local_dir_use_symlinks=False,
)
logger.info("模型下载完成")
logger.info(f"初始化 MLX Whisper 转录器,模型:{self.model_name}")
@timeit
def transcript(self, file_path: str) -> TranscriptResult:
try:
# 使用 MLX Whisper 进行转录
result = mlx_whisper.transcribe(
file_path,
path_or_hf_repo=f"{self.model_name}"
)
# 转换为标准格式
segments = []
full_text = ""
for segment in result["segments"]:
text = segment["text"].strip()
full_text += text + " "
segments.append(TranscriptSegment(
start=segment["start"],
end=segment["end"],
text=text
))
transcript_result = TranscriptResult(
language=result.get("language", "unknown"),
full_text=full_text.strip(),
segments=segments,
raw=result
)
self.on_finish(file_path, transcript_result)
return transcript_result
except Exception as e:
logger.error(f"MLX Whisper 转写失败:{e}")
raise e
def on_finish(self, video_path: str, result: TranscriptResult) -> None:
logger.info("MLX Whisper 转写完成")
transcription_finished.send({
"file_path": video_path,
})

View File

@@ -1,19 +1,113 @@
import os
import platform
from app.transcriber.whisper import WhisperTranscriber
from app.transcriber.bcut import BcutTranscriber
from app.transcriber.kuaishou import KuaishouTranscriber
from app.utils.logger import get_logger
logger = get_logger(__name__)
logger.info('实例化transcriber')
# TODO:后面需要加入逻辑选择
_transcriber = None
# 只在Apple平台且设置了环境变量时才导入MLX Whisper
if platform.system() == "Darwin" and os.environ.get("TRANSCRIBER_TYPE") == "mlx-whisper":
try:
from app.transcriber.mlx_whisper_transcriber import MLXWhisperTranscriber
MLX_WHISPER_AVAILABLE = True
logger.info("MLX Whisper 可用,已导入")
except ImportError:
MLX_WHISPER_AVAILABLE = False
logger.warning("MLX Whisper 导入失败,可能未安装或平台不支持")
else:
MLX_WHISPER_AVAILABLE = False
def get_transcriber(model_size="base", device="cuda"):
global _transcriber
logger.info('初始化转录服务提供器')
if _transcriber is None:
logger.info('不存在 transcriber 开始实例化transcriber。')
# 维护各种转录器的单例实例
_transcribers = {
'bcut': None,
'kuaishou': None,
'mlx-whisper': None,
'fast-whisper':None
}
def get_whisper_transcriber(model_size="base", device="cuda"):
"""获取 Whisper 转录器实例"""
if _transcribers['fast-whisper'] is None:
logger.info(f'创建 Whisper 转录器实例,参数:{model_size}, {device}')
try:
_transcriber = WhisperTranscriber(model_size=model_size, device=device)
logger.info(f'实例化transcriber成功。参数{model_size}, {device} ')
_transcribers['whisper'] = WhisperTranscriber(model_size=model_size, device=device)
logger.info('Whisper 转录器创建成功')
except Exception as e:
logger.error(f"实例化transcriber失败请检查是否安装whisper。{e}")
return _transcriber
logger.error(f"Whisper 转录器创建失败: {e}")
raise
return _transcribers['whisper']
def get_bcut_transcriber():
"""获取 Bcut 转录器实例"""
if _transcribers['bcut'] is None:
logger.info('创建 Bcut 转录器实例')
try:
_transcribers['bcut'] = BcutTranscriber()
logger.info('Bcut 转录器创建成功')
except Exception as e:
logger.error(f"Bcut 转录器创建失败: {e}")
raise
return _transcribers['bcut']
def get_kuaishou_transcriber():
"""获取快手转录器实例"""
if _transcribers['kuaishou'] is None:
logger.info('创建快手转录器实例')
try:
_transcribers['kuaishou'] = KuaishouTranscriber()
logger.info('快手转录器创建成功')
except Exception as e:
logger.error(f"快手转录器创建失败: {e}")
raise
return _transcribers['kuaishou']
def get_mlx_whisper_transcriber(model_size="base"):
"""获取 MLX Whisper 转录器实例"""
if not MLX_WHISPER_AVAILABLE:
logger.warning("MLX Whisper 不可用请确保在Apple平台且已安装mlx_whisper")
raise ImportError("MLX Whisper 不可用请确保在Apple平台且已安装mlx_whisper")
if _transcribers['mlx-whisper'] is None:
logger.info(f'创建 MLX Whisper 转录器实例,参数:{model_size}')
try:
_transcribers['mlx-whisper'] = MLXWhisperTranscriber(model_size=model_size)
logger.info('MLX Whisper 转录器创建成功')
except Exception as e:
logger.error(f"MLX Whisper 转录器创建失败: {e}")
raise
return _transcribers['mlx-whisper']
def get_transcriber(transcriber_type="fast-whisper", model_size="base", device="cuda"):
"""
获取指定类型的转录器实例
参数:
transcriber_type: 转录器类型,支持 "fast-whisper", "bcut", "kuaishou", "mlx-whisper"(仅Apple平台)
model_size: 模型大小whisper 和 mlx-whisper 特有参数
device: 设备类型whisper 特有参数
返回:
对应类型的转录器实例
"""
logger.info(f'获取转录器,类型: {transcriber_type}')
if transcriber_type == "fast-whisper":
whisper_model_size = os.environ.get("WHISPER_MODEL_SIZE",model_size)
return get_whisper_transcriber(whisper_model_size, device=device)
elif transcriber_type == "mlx-whisper":
whisper_model_size = os.environ.get("WHISPER_MODEL_SIZE",model_size)
if not MLX_WHISPER_AVAILABLE:
logger.warning("MLX Whisper 不可用,回退到 fast-whisper")
return get_whisper_transcriber(whisper_model_size, device=device)
return get_mlx_whisper_transcriber(whisper_model_size)
elif transcriber_type == "bcut":
return get_bcut_transcriber()
elif transcriber_type == "kuaishou":
return get_kuaishou_transcriber()
else:
logger.warning(f'未知转录器类型 "{transcriber_type}",使用默认 whisper')
whisper_model_size = os.environ.get("WHISPER_MODEL_SIZE",model_size)
return get_whisper_transcriber(whisper_model_size, device)