BiliNote/backend/app/transcriber/whisper.py
黄建武 51fb59e3e1 feat(transcriber): use ModelScope instead of Hugging Face to download models
- Add the modelscope dependency to requirements.txt
- Rework the model download logic in whisper.py to use ModelScope's snapshot_download function
- Update the MODEL_MAP dict to map each model size to its corresponding ModelScope repository
- Adjust the model path to point directly at the ModelScope download location
2025-05-08 14:42:43 +08:00
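
The heart of the change is a single call to ModelScope's snapshot_download. A minimal sketch follows, using a repo id taken from the MODEL_MAP in the file below; the target directory is illustrative:

    from modelscope import snapshot_download

    # Fetches the faster-whisper weights from ModelScope and returns the local path.
    model_path = snapshot_download(
        "pengzhendong/faster-whisper-base",       # repo id from MODEL_MAP
        local_dir="models/whisper/whisper-base",  # illustrative target directory
    )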

130 lines · 4.1 KiB · Python

import os
from pathlib import Path
from typing import Optional

from faster_whisper import WhisperModel
from modelscope import snapshot_download

from app.decorators.timeit import timeit
from app.models.transcriber_model import TranscriptSegment, TranscriptResult
from app.transcriber.base import Transcriber
from app.utils.env_checker import is_cuda_available, is_torch_installed
from app.utils.logger import get_logger
from app.utils.path_helper import get_model_dir
from events import transcription_finished

'''
Size of the model to use: tiny, tiny.en, base, base.en, small, small.en,
distil-small.en, medium, medium.en, distil-medium.en, large-v1, large-v2,
large-v3, large, distil-large-v2, distil-large-v3, large-v3-turbo, or turbo.
'''

logger = get_logger(__name__)

# ModelScope repositories for each supported model size. Sizes listed in the
# docstring above but missing here (e.g. the .en and distil variants) have no
# mapped mirror and are rejected in __init__.
MODEL_MAP = {
    "tiny": "pengzhendong/faster-whisper-tiny",
    "base": "pengzhendong/faster-whisper-base",
    "small": "pengzhendong/faster-whisper-small",
    "medium": "pengzhendong/faster-whisper-medium",
    "large-v1": "pengzhendong/faster-whisper-large-v1",
    "large-v2": "pengzhendong/faster-whisper-large-v2",
    "large-v3": "pengzhendong/faster-whisper-large-v3",
    "large-v3-turbo": "pengzhendong/faster-whisper-large-v3-turbo",
}


class WhisperTranscriber(Transcriber):
    # TODO: make these defaults configurable
    def __init__(
        self,
        model_size: str = "base",
        device: str = "cpu",
        compute_type: Optional[str] = None,
        cpu_threads: int = 1,
    ):
        # Fall back to CPU unless CUDA was requested and is actually available.
        if device == "cpu" or device is None:
            self.device = "cpu"
        else:
            self.device = "cuda" if self.is_cuda() else "cpu"
            if device == "cuda" and self.device == "cpu":
                logger.warning("CUDA is not available; falling back to CPU")
        # float16 is the usual choice on GPU; int8 keeps CPU inference cheap.
        self.compute_type = compute_type or ("float16" if self.device == "cuda" else "int8")

        model_dir = get_model_dir("whisper")
        model_path = os.path.join(model_dir, f"whisper-{model_size}")
        if not Path(model_path).exists():
            if model_size not in MODEL_MAP:
                raise ValueError(
                    f"Unsupported model size {model_size!r}; expected one of {sorted(MODEL_MAP)}"
                )
            logger.info(f"Model whisper-{model_size} not found locally, downloading...")
            # Download from ModelScope instead of Hugging Face.
            model_path = snapshot_download(
                MODEL_MAP[model_size],
                local_dir=model_path,
            )
            logger.info("Model download finished")

        self.model = WhisperModel(
            model_size_or_path=model_path,
            device=self.device,
            compute_type=self.compute_type,
            cpu_threads=cpu_threads,
            download_root=model_dir,
        )

    @staticmethod
    def is_torch_installed() -> bool:
        # Local duplicate of app.utils.env_checker.is_torch_installed.
        try:
            import torch  # noqa: F401
            return True
        except ImportError:
            return False

    @staticmethod
    def is_cuda() -> bool:
        try:
            if is_cuda_available():
                print("✅ CUDA available, using GPU")
                return True
            elif is_torch_installed():
                print("⚠️ torch is installed but CUDA is unavailable, using CPU")
                return False
            else:
                print("❌ torch is not installed, please install it first")
                return False
        except ImportError:
            return False

    @timeit
    def transcript(self, file_path: str) -> TranscriptResult:
        try:
            segments_raw, info = self.model.transcribe(file_path)
            segments = []
            full_text = ""
            for seg in segments_raw:
                text = seg.text.strip()
                full_text += text + " "
                segments.append(TranscriptSegment(
                    start=seg.start,
                    end=seg.end,
                    text=text,
                ))
            result = TranscriptResult(
                language=info.language,
                full_text=full_text.strip(),
                segments=segments,
                raw=info,
            )
            # self.on_finish(file_path, result)
            return result
        except Exception as e:
            # Log and re-raise so callers never receive an implicit None.
            logger.error(f"Transcription failed: {e}")
            raise

    def on_finish(self, video_path: str, result: TranscriptResult) -> None:
        print("Transcription finished")
        transcription_finished.send({
            "file_path": video_path,
        })
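
A minimal usage sketch, not part of the original file: the audio path is illustrative, and the constructor arguments match the defaults defined above.

    if __name__ == "__main__":
        transcriber = WhisperTranscriber(model_size="base", device="cpu")
        result = transcriber.transcript("audio.mp3")  # illustrative input path
        print(result.language)
        print(result.full_text)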