mirror of
https://github.com/JefferyHcool/BiliNote.git
synced 2026-05-07 05:12:45 +08:00
Merge pull request #4 from JefferyHcool/dev_jeff
feat(transcriber): 更新 whisper模型加载方式
This commit is contained in:
@@ -4,14 +4,19 @@ from app.decorators.timeit import timeit
|
||||
from app.models.transcriber_model import TranscriptSegment, TranscriptResult
|
||||
from app.transcriber.base import Transcriber
|
||||
from app.utils.env_checker import is_cuda_available, is_torch_installed
|
||||
from app.utils.logger import get_logger
|
||||
from app.utils.path_helper import get_model_dir
|
||||
|
||||
from events import transcription_finished
|
||||
from pathlib import Path
|
||||
import os
|
||||
from tqdm import tqdm
|
||||
from huggingface_hub import snapshot_download
|
||||
|
||||
'''
|
||||
Size of the model to use (tiny, tiny.en, base, base.en, small, small.en, distil-small.en, medium, medium.en, distil-medium.en, large-v1, large-v2, large-v3, large, distil-large-v2, distil-large-v3, large-v3-turbo, or turbo
|
||||
'''
|
||||
|
||||
logger=get_logger(__name__)
|
||||
|
||||
class WhisperTranscriber(Transcriber):
|
||||
# TODO:修改为可配置
|
||||
@@ -31,15 +36,25 @@ class WhisperTranscriber(Transcriber):
|
||||
|
||||
self.compute_type = compute_type or ("float16" if self.device == "cuda" else "int8")
|
||||
|
||||
model_path = get_model_dir("whisper")
|
||||
model_dir = get_model_dir("whisper")
|
||||
model_path = os.path.join(model_dir, f"whisper-{model_size}")
|
||||
if not Path(model_path).exists():
|
||||
logger.info(f"模型 whisper-{model_size} 不存在,开始下载...")
|
||||
repo_id = f"guillaumekln/faster-whisper-{model_size}"
|
||||
snapshot_download(
|
||||
repo_id,
|
||||
local_dir=model_path,
|
||||
local_dir_use_symlinks=False,
|
||||
)
|
||||
logger.info("模型下载完成")
|
||||
|
||||
self.model = WhisperModel(
|
||||
model_size,
|
||||
device=self.device,
|
||||
# compute_type="int8", # 或 "float16"
|
||||
compute_type=self.compute_type,
|
||||
cpu_threads=cpu_threads,
|
||||
download_root=model_path
|
||||
download_root=model_dir
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def is_torch_installed() -> bool:
|
||||
try:
|
||||
|
||||
Binary file not shown.
Reference in New Issue
Block a user