Merge pull request #4 from JefferyHcool/dev_jeff

feat(transcriber): 更新 whisper模型加载方式
This commit is contained in:
Jianwu Huang
2025-04-15 10:28:31 +08:00
committed by GitHub
2 changed files with 20 additions and 5 deletions

View File

@@ -4,14 +4,19 @@ from app.decorators.timeit import timeit
from app.models.transcriber_model import TranscriptSegment, TranscriptResult
from app.transcriber.base import Transcriber
from app.utils.env_checker import is_cuda_available, is_torch_installed
from app.utils.logger import get_logger
from app.utils.path_helper import get_model_dir
from events import transcription_finished
from pathlib import Path
import os
from tqdm import tqdm
from huggingface_hub import snapshot_download
'''
Size of the model to use (tiny, tiny.en, base, base.en, small, small.en, distil-small.en, medium, medium.en, distil-medium.en, large-v1, large-v2, large-v3, large, distil-large-v2, distil-large-v3, large-v3-turbo, or turbo
'''
logger=get_logger(__name__)
class WhisperTranscriber(Transcriber):
# TODO:修改为可配置
@@ -31,15 +36,25 @@ class WhisperTranscriber(Transcriber):
self.compute_type = compute_type or ("float16" if self.device == "cuda" else "int8")
model_path = get_model_dir("whisper")
model_dir = get_model_dir("whisper")
model_path = os.path.join(model_dir, f"whisper-{model_size}")
if not Path(model_path).exists():
logger.info(f"模型 whisper-{model_size} 不存在,开始下载...")
repo_id = f"guillaumekln/faster-whisper-{model_size}"
snapshot_download(
repo_id,
local_dir=model_path,
local_dir_use_symlinks=False,
)
logger.info("模型下载完成")
self.model = WhisperModel(
model_size,
device=self.device,
# compute_type="int8", # 或 "float16"
compute_type=self.compute_type,
cpu_threads=cpu_threads,
download_root=model_path
download_root=model_dir
)
@staticmethod
def is_torch_installed() -> bool:
try:

Binary file not shown.