mirror of
https://github.com/JefferyHcool/BiliNote.git
synced 2026-06-12 19:20:00 +08:00
feat(transcriber): 可配置 whisper 模型 + 名称映射(自定义 HF repo / 本地路径)
此前 fast-whisper 把「size → Systran/faster-whisper-{size}」的约定隐式散落在
加载/下载/检测三处,用户想用命名不符该约定的模型(社区微调版、或自己下到本地
的模型)接不上。本功能把映射显式化 + 可配置(对齐已有的 MLX_MODEL_MAP 模式)。
后端:
- 新增 app/transcriber/whisper_models.py 注册表:内置映射 + 用户自定义
(config/whisper_models.json 持久化,Docker 下随 config 卷保留);resolve
优先级 自定义 > 内置 > 直通(含 / 的 repo_id / 已存在本地目录)。
- whisper.py / config.py 的加载、下载、完整性检测统一走 resolve;HF cache 目录从
任意 repo_id 推导(models--{org}--{name})不再写死 Systran;本地路径跳过下载,
_purge_cache 绝不删用户本地模型。
- 新增 /whisper_models 增删查 API;/transcriber_config 返回内置+自定义列表;
下载校验放开到「已登记/可解析」的模型。
前端:transcriber.tsx 新增「自定义模型」卡片(增删 + 下载状态),模型下拉自动含自定义。
Docker:自定义 HF 模型下到 /app/backend/models(v2.3.3 models 卷已持久化);本地模型
走挂载目录 + 配置路径,UI 已提示挂载。
测试:tests/test_whisper_models.py 13 个单测全过;并在 v2.3.3 镜像真实后端环境做了
import 链 + resolve + 真实模型检测的集成冒烟,均通过。
Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
This commit is contained in:
@@ -3,6 +3,11 @@ from faster_whisper import WhisperModel
|
||||
from app.decorators.timeit import timeit
|
||||
from app.models.transcriber_model import TranscriptSegment, TranscriptResult
|
||||
from app.transcriber.base import Transcriber
|
||||
from app.transcriber.whisper_models import (
|
||||
resolve_whisper_model,
|
||||
is_local_target,
|
||||
hf_cache_dirname,
|
||||
)
|
||||
from app.utils.env_checker import is_cuda_available, is_torch_installed
|
||||
from app.utils.logger import get_logger
|
||||
from app.utils.path_helper import get_model_dir
|
||||
@@ -55,8 +60,12 @@ class WhisperTranscriber(Transcriber):
|
||||
self.model = self._build_model(model_size, model_dir)
|
||||
|
||||
def _build_model(self, model_size: str, model_dir: str) -> WhisperModel:
|
||||
# resolve 把模型名映射成可加载标识:内置 size→Systran repo_id、自定义映射、
|
||||
# 直通的 repo_id 或本地路径。faster-whisper 对本地目录走 os.path.isdir 分支,
|
||||
# 对 repo_id 走 download_model(cache_dir=download_root),两者都吃 model_size_or_path。
|
||||
target = resolve_whisper_model(model_size)
|
||||
return WhisperModel(
|
||||
model_size_or_path=model_size, # 传 size name,让 faster-whisper 自己映射到 Systran/faster-whisper-*
|
||||
model_size_or_path=target,
|
||||
device=self.device,
|
||||
compute_type=self.compute_type,
|
||||
download_root=model_dir,
|
||||
@@ -64,14 +73,23 @@ class WhisperTranscriber(Transcriber):
|
||||
|
||||
@staticmethod
|
||||
def _purge_cache(model_dir: str, model_size: str) -> None:
|
||||
"""删掉 HF cache 里这个 size 对应的 snapshot 目录,强制下次重新下载。
|
||||
"""加载失败时清掉对应 HF cache 的 snapshot 目录,强制下次重下。
|
||||
|
||||
HF cache 布局:<model_dir>/models--Systran--faster-whisper-{size}/
|
||||
没找到也不报错——可能用户改了 endpoint 或者 cache 布局变了。
|
||||
关键:本地路径模型**绝不删**——那是用户自己的文件,删了就是数据丢失;
|
||||
只清 HF cache 布局 <model_dir>/models--{org}--{name}/(含历史 modelscope 目录)。
|
||||
"""
|
||||
try:
|
||||
target = resolve_whisper_model(model_size)
|
||||
except Exception:
|
||||
target = model_size
|
||||
if is_local_target(target):
|
||||
logger.warning(
|
||||
f"模型 {model_size} 指向本地路径 {target},加载失败不清理用户文件,请检查该目录是否完整"
|
||||
)
|
||||
return
|
||||
candidates = [
|
||||
Path(model_dir) / f"models--Systran--faster-whisper-{model_size}",
|
||||
Path(model_dir) / f"whisper-{model_size}", # 历史 modelscope 目录,顺手清掉
|
||||
Path(model_dir) / hf_cache_dirname(target), # HF cache: models--org--name
|
||||
Path(model_dir) / f"whisper-{model_size}", # 历史 modelscope 目录,顺手清掉
|
||||
]
|
||||
for path in candidates:
|
||||
if path.exists():
|
||||
|
||||
156
backend/app/transcriber/whisper_models.py
Normal file
156
backend/app/transcriber/whisper_models.py
Normal file
@@ -0,0 +1,156 @@
|
||||
"""fast-whisper 模型名 → 可加载标识(HF repo_id 或本地路径)的映射注册表。
|
||||
|
||||
背景:faster-whisper 加载时 `WhisperModel(model_size_or_path=...)` 接受三种入参——
|
||||
内置 size 名、HF repo_id(含 "/")、或本地模型目录(`os.path.isdir` 命中则直接用)。
|
||||
此前后端把「size → Systran/faster-whisper-{size}」这层约定**隐式**散落在加载/下载/
|
||||
检测三处,用户想用命名不符合该约定的模型(比如社区微调版、或自己下到本地的模型)就接不上。
|
||||
|
||||
本模块把映射**显式化 + 可配置**(对齐 mlx_whisper_transcriber.MLX_MODEL_MAP 的模式):
|
||||
- 内置:size → Systran/faster-whisper-{size}
|
||||
- 自定义:用户在 config/whisper_models.json 登记 {名称: "<repo_id 或本地路径>"}
|
||||
(JSON 持久化;Docker 下随 config 卷持久化)
|
||||
|
||||
解析优先级(resolve):自定义 > 内置 > 直通(含 "/" 当 repo_id;已存在目录当本地路径)。
|
||||
加载 / 下载 / 完整性检测三处统一调用 resolve,路径不再各写各的。
|
||||
"""
|
||||
import json
|
||||
import os
|
||||
from pathlib import Path
|
||||
from typing import Dict, List
|
||||
|
||||
from app.utils.logger import get_logger
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
# 内置模型:size → faster-whisper 兼容的 HF repo_id(CTranslate2 转换版,Systran 官方维护)。
|
||||
BUILTIN_WHISPER_MODELS: Dict[str, str] = {
|
||||
"tiny": "Systran/faster-whisper-tiny",
|
||||
"base": "Systran/faster-whisper-base",
|
||||
"small": "Systran/faster-whisper-small",
|
||||
"medium": "Systran/faster-whisper-medium",
|
||||
"large-v1": "Systran/faster-whisper-large-v1",
|
||||
"large-v2": "Systran/faster-whisper-large-v2",
|
||||
"large-v3": "Systran/faster-whisper-large-v3",
|
||||
"large-v3-turbo": "Systran/faster-whisper-large-v3-turbo",
|
||||
}
|
||||
|
||||
# 前端下拉默认展示的内置档位(保持与历史 WHISPER_MODEL_SIZES 一致,不把 8 个全列出来)
|
||||
DEFAULT_VISIBLE_BUILTINS: List[str] = ["tiny", "base", "small", "medium", "large-v3", "large-v3-turbo"]
|
||||
|
||||
|
||||
def is_local_target(target: str) -> bool:
|
||||
"""判断解析出的 target 是本地路径而非 HF repo_id。
|
||||
|
||||
HF repo_id 形如 'Org/Name'(恰一个斜杠、无前导斜杠、非已存在目录)。
|
||||
本地路径:绝对路径 / 以 . 或 ~ 开头 / 已存在的目录。
|
||||
"""
|
||||
if not target:
|
||||
return False
|
||||
if os.path.isabs(target) or target.startswith(".") or target.startswith("~"):
|
||||
return True
|
||||
return os.path.isdir(target)
|
||||
|
||||
|
||||
def hf_cache_dirname(repo_id: str) -> str:
|
||||
"""huggingface_hub snapshot 的本地缓存目录名:Org/Name → models--Org--Name。"""
|
||||
return "models--" + repo_id.replace("/", "--")
|
||||
|
||||
|
||||
class WhisperModelRegistry:
|
||||
"""内置 + 用户自定义的 whisper 模型映射,自定义部分持久化到 JSON。"""
|
||||
|
||||
def __init__(self, filepath: str = "config/whisper_models.json"):
|
||||
self.path = Path(filepath)
|
||||
self.path.parent.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
# ---- 持久化 ----
|
||||
def _read_custom(self) -> Dict[str, str]:
|
||||
if not self.path.exists():
|
||||
return {}
|
||||
try:
|
||||
with self.path.open("r", encoding="utf-8") as f:
|
||||
data = json.load(f) or {}
|
||||
except Exception as e:
|
||||
logger.warning(f"读取自定义 whisper 模型配置失败,按空处理: {e}")
|
||||
return {}
|
||||
out: Dict[str, str] = {}
|
||||
for name, val in data.items():
|
||||
if isinstance(val, str) and val.strip():
|
||||
out[name] = val.strip()
|
||||
elif isinstance(val, dict) and isinstance(val.get("target"), str):
|
||||
out[name] = val["target"].strip()
|
||||
return out
|
||||
|
||||
def _write_custom(self, data: Dict[str, str]) -> None:
|
||||
with self.path.open("w", encoding="utf-8") as f:
|
||||
json.dump(data, f, ensure_ascii=False, indent=2)
|
||||
|
||||
# ---- 查询 ----
|
||||
def get_custom_models(self) -> Dict[str, str]:
|
||||
return self._read_custom()
|
||||
|
||||
def visible_model_names(self) -> List[str]:
|
||||
"""给前端下拉 / 下载状态用:默认可见内置档位 + 全部自定义名称。"""
|
||||
names = list(DEFAULT_VISIBLE_BUILTINS)
|
||||
for name in self._read_custom():
|
||||
if name not in names:
|
||||
names.append(name)
|
||||
return names
|
||||
|
||||
def is_known(self, name: str) -> bool:
|
||||
try:
|
||||
self.resolve(name)
|
||||
return True
|
||||
except ValueError:
|
||||
return False
|
||||
|
||||
def resolve(self, name: str) -> str:
|
||||
"""模型名 → 可加载标识(HF repo_id 或本地路径)。
|
||||
|
||||
优先级:自定义映射 > 内置映射 > 直通(含 "/" 的 repo_id / 已存在的本地目录)。
|
||||
无法识别时抛 ValueError。
|
||||
"""
|
||||
name = (name or "").strip()
|
||||
custom = self._read_custom()
|
||||
if name in custom:
|
||||
return custom[name]
|
||||
if name in BUILTIN_WHISPER_MODELS:
|
||||
return BUILTIN_WHISPER_MODELS[name]
|
||||
# 直通:用户直接把 repo_id(含 "/")或本地已存在目录当 model_size 传进来
|
||||
if "/" in name or os.path.isdir(name):
|
||||
return name
|
||||
raise ValueError(
|
||||
f"未知 whisper 模型 '{name}'。内置可选: {', '.join(BUILTIN_WHISPER_MODELS)};"
|
||||
"或在「音频转写配置」添加自定义模型(HF repo_id 或本地路径)。"
|
||||
)
|
||||
|
||||
# ---- 增删 ----
|
||||
def add_custom_model(self, name: str, target: str) -> Dict[str, str]:
|
||||
name = (name or "").strip()
|
||||
target = (target or "").strip()
|
||||
if not name or not target:
|
||||
raise ValueError("模型名称与目标(HF repo_id 或本地路径)都不能为空")
|
||||
if name in BUILTIN_WHISPER_MODELS:
|
||||
raise ValueError(f"'{name}' 与内置模型重名,请换一个名称")
|
||||
data = self._read_custom()
|
||||
data[name] = target
|
||||
self._write_custom(data)
|
||||
return data
|
||||
|
||||
def remove_custom_model(self, name: str) -> Dict[str, str]:
|
||||
data = self._read_custom()
|
||||
data.pop((name or "").strip(), None)
|
||||
self._write_custom(data)
|
||||
return data
|
||||
|
||||
|
||||
# 模块级单例
|
||||
_registry = WhisperModelRegistry()
|
||||
|
||||
|
||||
def get_registry() -> WhisperModelRegistry:
|
||||
return _registry
|
||||
|
||||
|
||||
def resolve_whisper_model(name: str) -> str:
|
||||
return _registry.resolve(name)
|
||||
Reference in New Issue
Block a user