mirror of
https://github.com/JefferyHcool/BiliNote.git
synced 2026-06-02 22:31:33 +08:00
Merge pull request #4 from JefferyHcool/dev_jeff
feat(transcriber): 更新 whisper模型加载方式
This commit is contained in:
@@ -4,14 +4,19 @@ from app.decorators.timeit import timeit
|
|||||||
from app.models.transcriber_model import TranscriptSegment, TranscriptResult
|
from app.models.transcriber_model import TranscriptSegment, TranscriptResult
|
||||||
from app.transcriber.base import Transcriber
|
from app.transcriber.base import Transcriber
|
||||||
from app.utils.env_checker import is_cuda_available, is_torch_installed
|
from app.utils.env_checker import is_cuda_available, is_torch_installed
|
||||||
|
from app.utils.logger import get_logger
|
||||||
from app.utils.path_helper import get_model_dir
|
from app.utils.path_helper import get_model_dir
|
||||||
|
|
||||||
from events import transcription_finished
|
from events import transcription_finished
|
||||||
|
from pathlib import Path
|
||||||
|
import os
|
||||||
|
from tqdm import tqdm
|
||||||
|
from huggingface_hub import snapshot_download
|
||||||
|
|
||||||
'''
|
'''
|
||||||
Size of the model to use (tiny, tiny.en, base, base.en, small, small.en, distil-small.en, medium, medium.en, distil-medium.en, large-v1, large-v2, large-v3, large, distil-large-v2, distil-large-v3, large-v3-turbo, or turbo)
|
Size of the model to use (tiny, tiny.en, base, base.en, small, small.en, distil-small.en, medium, medium.en, distil-medium.en, large-v1, large-v2, large-v3, large, distil-large-v2, distil-large-v3, large-v3-turbo, or turbo)
|
||||||
'''
|
'''
|
||||||
|
logger=get_logger(__name__)
|
||||||
|
|
||||||
class WhisperTranscriber(Transcriber):
|
class WhisperTranscriber(Transcriber):
|
||||||
# TODO:修改为可配置
|
# TODO:修改为可配置
|
||||||
@@ -31,15 +36,25 @@ class WhisperTranscriber(Transcriber):
|
|||||||
|
|
||||||
self.compute_type = compute_type or ("float16" if self.device == "cuda" else "int8")
|
self.compute_type = compute_type or ("float16" if self.device == "cuda" else "int8")
|
||||||
|
|
||||||
model_path = get_model_dir("whisper")
|
model_dir = get_model_dir("whisper")
|
||||||
|
model_path = os.path.join(model_dir, f"whisper-{model_size}")
|
||||||
|
if not Path(model_path).exists():
|
||||||
|
logger.info(f"模型 whisper-{model_size} 不存在,开始下载...")
|
||||||
|
repo_id = f"guillaumekln/faster-whisper-{model_size}"
|
||||||
|
snapshot_download(
|
||||||
|
repo_id,
|
||||||
|
local_dir=model_path,
|
||||||
|
local_dir_use_symlinks=False,
|
||||||
|
)
|
||||||
|
logger.info("模型下载完成")
|
||||||
|
|
||||||
self.model = WhisperModel(
|
self.model = WhisperModel(
|
||||||
model_size,
|
model_size,
|
||||||
device=self.device,
|
device=self.device,
|
||||||
# compute_type="int8", # 或 "float16"
|
compute_type=self.compute_type,
|
||||||
cpu_threads=cpu_threads,
|
cpu_threads=cpu_threads,
|
||||||
download_root=model_path
|
download_root=model_dir
|
||||||
)
|
)
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def is_torch_installed() -> bool:
|
def is_torch_installed() -> bool:
|
||||||
try:
|
try:
|
||||||
|
|||||||
Binary file not shown.
Reference in New Issue
Block a user