diff --git a/backend/app/transcriber/whisper.py b/backend/app/transcriber/whisper.py index c06c749..486c774 100644 --- a/backend/app/transcriber/whisper.py +++ b/backend/app/transcriber/whisper.py @@ -4,14 +4,19 @@ from app.decorators.timeit import timeit from app.models.transcriber_model import TranscriptSegment, TranscriptResult from app.transcriber.base import Transcriber from app.utils.env_checker import is_cuda_available, is_torch_installed +from app.utils.logger import get_logger from app.utils.path_helper import get_model_dir from events import transcription_finished +from pathlib import Path +import os +from tqdm import tqdm +from huggingface_hub import snapshot_download ''' Size of the model to use (tiny, tiny.en, base, base.en, small, small.en, distil-small.en, medium, medium.en, distil-medium.en, large-v1, large-v2, large-v3, large, distil-large-v2, distil-large-v3, large-v3-turbo, or turbo ''' - +logger=get_logger(__name__) class WhisperTranscriber(Transcriber): # TODO:修改为可配置 @@ -31,15 +36,25 @@ class WhisperTranscriber(Transcriber): self.compute_type = compute_type or ("float16" if self.device == "cuda" else "int8") - model_path = get_model_dir("whisper") + model_dir = get_model_dir("whisper") + model_path = os.path.join(model_dir, f"whisper-{model_size}") + if not Path(model_path).exists(): + logger.info(f"模型 whisper-{model_size} 不存在,开始下载...") + repo_id = f"guillaumekln/faster-whisper-{model_size}" + snapshot_download( + repo_id, + local_dir=model_path, + local_dir_use_symlinks=False, + ) + logger.info("模型下载完成") + self.model = WhisperModel( model_size, device=self.device, - # compute_type="int8", # 或 "float16" + compute_type=self.compute_type, cpu_threads=cpu_threads, - download_root=model_path + download_root=model_dir ) - @staticmethod def is_torch_installed() -> bool: try: diff --git a/backend/requirements.txt b/backend/requirements.txt index 1295fa7..166bc08 100644 Binary files a/backend/requirements.txt and b/backend/requirements.txt differ