mirror of
https://github.com/JefferyHcool/BiliNote.git
synced 2026-05-06 20:42:52 +08:00
Merge remote-tracking branch 'origin/master' into dev
# Conflicts: # backend/app/transcriber/transcriber_provider.py
This commit is contained in:
@@ -18,7 +18,7 @@ from app.models.notes_model import AudioDownloadResult
|
||||
from app.enmus.note_enums import DownloadQuality
|
||||
from app.models.transcriber_model import TranscriptResult
|
||||
from app.transcriber.base import Transcriber
|
||||
from app.transcriber.transcriber_provider import get_transcriber
|
||||
from app.transcriber.transcriber_provider import get_transcriber,_transcribers
|
||||
from app.transcriber.whisper import WhisperTranscriber
|
||||
import re
|
||||
|
||||
@@ -43,7 +43,7 @@ class NoteGenerator:
|
||||
def __init__(self):
|
||||
self.model_size: str = 'base'
|
||||
self.device: Union[str, None] = None
|
||||
self.transcriber_type = 'fast-whisper'
|
||||
self.transcriber_type = os.getenv('TRANSCRIBER_TYPE','fast-whisper')
|
||||
self.transcriber = self.get_transcriber()
|
||||
# TODO 需要更换为可调节
|
||||
|
||||
@@ -86,9 +86,9 @@ class NoteGenerator:
|
||||
:param transcriber: 选择的转义器
|
||||
:return:
|
||||
'''
|
||||
if self.transcriber_type == 'fast-whisper':
|
||||
logger.info("使用Whisper")
|
||||
return get_transcriber()
|
||||
if self.transcriber_type in _transcribers.keys():
|
||||
logger.info(f"使用{self.transcriber_type}转义器")
|
||||
return get_transcriber(transcriber_type=self.transcriber_type)
|
||||
else:
|
||||
logger.warning("不支持的转义器")
|
||||
raise ValueError(f"不支持的转义器:{self.transcriber}")
|
||||
|
||||
88
backend/app/transcriber/mlx_whisper_transcriber.py
Normal file
88
backend/app/transcriber/mlx_whisper_transcriber.py
Normal file
@@ -0,0 +1,88 @@
|
||||
import mlx_whisper
|
||||
from pathlib import Path
|
||||
import os
|
||||
import platform
|
||||
from huggingface_hub import snapshot_download
|
||||
|
||||
from app.decorators.timeit import timeit
|
||||
from app.models.transcriber_model import TranscriptSegment, TranscriptResult
|
||||
from app.transcriber.base import Transcriber
|
||||
from app.utils.logger import get_logger
|
||||
from app.utils.path_helper import get_model_dir
|
||||
from events import transcription_finished
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
class MLXWhisperTranscriber(Transcriber):
|
||||
def __init__(
|
||||
self,
|
||||
model_size: str = "base"
|
||||
):
|
||||
# 检查平台
|
||||
if platform.system() != "Darwin":
|
||||
raise RuntimeError("MLX Whisper 仅支持 Apple 平台")
|
||||
|
||||
# 检查环境变量
|
||||
if os.environ.get("TRANSCRIBER_TYPE") != "mlx-whisper":
|
||||
raise RuntimeError("必须设置环境变量 TRANSCRIBER_TYPE=mlx-whisper 才能使用 MLX Whisper")
|
||||
|
||||
self.model_size = model_size
|
||||
self.model_name = f"mlx-community/whisper-{model_size}"
|
||||
self.model_path = None
|
||||
|
||||
# 设置模型路径
|
||||
model_dir = get_model_dir("mlx-whisper")
|
||||
self.model_path = os.path.join(model_dir, self.model_name)
|
||||
# 检查并下载模型
|
||||
if not Path(self.model_path).exists():
|
||||
logger.info(f"模型 {self.model_name} 不存在,开始下载...")
|
||||
snapshot_download(
|
||||
self.model_name,
|
||||
local_dir=self.model_path,
|
||||
local_dir_use_symlinks=False,
|
||||
)
|
||||
logger.info("模型下载完成")
|
||||
|
||||
logger.info(f"初始化 MLX Whisper 转录器,模型:{self.model_name}")
|
||||
|
||||
@timeit
|
||||
def transcript(self, file_path: str) -> TranscriptResult:
|
||||
try:
|
||||
# 使用 MLX Whisper 进行转录
|
||||
result = mlx_whisper.transcribe(
|
||||
file_path,
|
||||
path_or_hf_repo=f"{self.model_name}"
|
||||
)
|
||||
|
||||
# 转换为标准格式
|
||||
segments = []
|
||||
full_text = ""
|
||||
|
||||
for segment in result["segments"]:
|
||||
text = segment["text"].strip()
|
||||
full_text += text + " "
|
||||
segments.append(TranscriptSegment(
|
||||
start=segment["start"],
|
||||
end=segment["end"],
|
||||
text=text
|
||||
))
|
||||
|
||||
transcript_result = TranscriptResult(
|
||||
language=result.get("language", "unknown"),
|
||||
full_text=full_text.strip(),
|
||||
segments=segments,
|
||||
raw=result
|
||||
)
|
||||
|
||||
self.on_finish(file_path, transcript_result)
|
||||
return transcript_result
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"MLX Whisper 转写失败:{e}")
|
||||
raise e
|
||||
|
||||
def on_finish(self, video_path: str, result: TranscriptResult) -> None:
|
||||
logger.info("MLX Whisper 转写完成")
|
||||
transcription_finished.send({
|
||||
"file_path": video_path,
|
||||
})
|
||||
@@ -1,19 +1,113 @@
|
||||
import os
|
||||
import platform
|
||||
|
||||
from app.transcriber.whisper import WhisperTranscriber
|
||||
from app.transcriber.bcut import BcutTranscriber
|
||||
from app.transcriber.kuaishou import KuaishouTranscriber
|
||||
from app.utils.logger import get_logger
|
||||
logger = get_logger(__name__)
|
||||
|
||||
logger.info('实例化transcriber')
|
||||
# TODO:后面需要加入逻辑选择
|
||||
_transcriber = None
|
||||
# 只在Apple平台且设置了环境变量时才导入MLX Whisper
|
||||
if platform.system() == "Darwin" and os.environ.get("TRANSCRIBER_TYPE") == "mlx-whisper":
|
||||
try:
|
||||
from app.transcriber.mlx_whisper_transcriber import MLXWhisperTranscriber
|
||||
MLX_WHISPER_AVAILABLE = True
|
||||
logger.info("MLX Whisper 可用,已导入")
|
||||
except ImportError:
|
||||
MLX_WHISPER_AVAILABLE = False
|
||||
logger.warning("MLX Whisper 导入失败,可能未安装或平台不支持")
|
||||
else:
|
||||
MLX_WHISPER_AVAILABLE = False
|
||||
|
||||
def get_transcriber(model_size="base", device="cuda"):
|
||||
global _transcriber
|
||||
logger.info('初始化转录服务提供器')
|
||||
|
||||
if _transcriber is None:
|
||||
logger.info('不存在 transcriber ,开始实例化transcriber。')
|
||||
# 维护各种转录器的单例实例
|
||||
_transcribers = {
|
||||
'bcut': None,
|
||||
'kuaishou': None,
|
||||
'mlx-whisper': None,
|
||||
'fast-whisper':None
|
||||
}
|
||||
|
||||
def get_whisper_transcriber(model_size="base", device="cuda"):
|
||||
"""获取 Whisper 转录器实例"""
|
||||
if _transcribers['fast-whisper'] is None:
|
||||
logger.info(f'创建 Whisper 转录器实例,参数:{model_size}, {device}')
|
||||
try:
|
||||
_transcriber = WhisperTranscriber(model_size=model_size, device=device)
|
||||
logger.info(f'实例化transcriber成功。参数:{model_size}, {device} ')
|
||||
_transcribers['whisper'] = WhisperTranscriber(model_size=model_size, device=device)
|
||||
logger.info('Whisper 转录器创建成功')
|
||||
except Exception as e:
|
||||
logger.error(f"实例化transcriber失败,请检查是否安装whisper。{e}")
|
||||
return _transcriber
|
||||
logger.error(f"Whisper 转录器创建失败: {e}")
|
||||
raise
|
||||
return _transcribers['whisper']
|
||||
|
||||
def get_bcut_transcriber():
|
||||
"""获取 Bcut 转录器实例"""
|
||||
if _transcribers['bcut'] is None:
|
||||
logger.info('创建 Bcut 转录器实例')
|
||||
try:
|
||||
_transcribers['bcut'] = BcutTranscriber()
|
||||
logger.info('Bcut 转录器创建成功')
|
||||
except Exception as e:
|
||||
logger.error(f"Bcut 转录器创建失败: {e}")
|
||||
raise
|
||||
return _transcribers['bcut']
|
||||
|
||||
def get_kuaishou_transcriber():
|
||||
"""获取快手转录器实例"""
|
||||
if _transcribers['kuaishou'] is None:
|
||||
logger.info('创建快手转录器实例')
|
||||
try:
|
||||
_transcribers['kuaishou'] = KuaishouTranscriber()
|
||||
logger.info('快手转录器创建成功')
|
||||
except Exception as e:
|
||||
logger.error(f"快手转录器创建失败: {e}")
|
||||
raise
|
||||
return _transcribers['kuaishou']
|
||||
|
||||
def get_mlx_whisper_transcriber(model_size="base"):
|
||||
"""获取 MLX Whisper 转录器实例"""
|
||||
if not MLX_WHISPER_AVAILABLE:
|
||||
logger.warning("MLX Whisper 不可用,请确保在Apple平台且已安装mlx_whisper")
|
||||
raise ImportError("MLX Whisper 不可用,请确保在Apple平台且已安装mlx_whisper")
|
||||
|
||||
if _transcribers['mlx-whisper'] is None:
|
||||
logger.info(f'创建 MLX Whisper 转录器实例,参数:{model_size}')
|
||||
try:
|
||||
_transcribers['mlx-whisper'] = MLXWhisperTranscriber(model_size=model_size)
|
||||
logger.info('MLX Whisper 转录器创建成功')
|
||||
except Exception as e:
|
||||
logger.error(f"MLX Whisper 转录器创建失败: {e}")
|
||||
raise
|
||||
return _transcribers['mlx-whisper']
|
||||
|
||||
def get_transcriber(transcriber_type="fast-whisper", model_size="base", device="cuda"):
|
||||
"""
|
||||
获取指定类型的转录器实例
|
||||
|
||||
参数:
|
||||
transcriber_type: 转录器类型,支持 "fast-whisper", "bcut", "kuaishou", "mlx-whisper"(仅Apple平台)
|
||||
model_size: 模型大小,whisper 和 mlx-whisper 特有参数
|
||||
device: 设备类型,whisper 特有参数
|
||||
|
||||
返回:
|
||||
对应类型的转录器实例
|
||||
"""
|
||||
logger.info(f'获取转录器,类型: {transcriber_type}')
|
||||
if transcriber_type == "fast-whisper":
|
||||
whisper_model_size = os.environ.get("WHISPER_MODEL_SIZE",model_size)
|
||||
return get_whisper_transcriber(whisper_model_size, device=device)
|
||||
elif transcriber_type == "mlx-whisper":
|
||||
whisper_model_size = os.environ.get("WHISPER_MODEL_SIZE",model_size)
|
||||
if not MLX_WHISPER_AVAILABLE:
|
||||
logger.warning("MLX Whisper 不可用,回退到 fast-whisper")
|
||||
return get_whisper_transcriber(whisper_model_size, device=device)
|
||||
return get_mlx_whisper_transcriber(whisper_model_size)
|
||||
elif transcriber_type == "bcut":
|
||||
return get_bcut_transcriber()
|
||||
elif transcriber_type == "kuaishou":
|
||||
return get_kuaishou_transcriber()
|
||||
else:
|
||||
logger.warning(f'未知转录器类型 "{transcriber_type}",使用默认 whisper')
|
||||
whisper_model_size = os.environ.get("WHISPER_MODEL_SIZE",model_size)
|
||||
return get_whisper_transcriber(whisper_model_size, device)
|
||||
Reference in New Issue
Block a user