mirror of
https://github.com/JefferyHcool/BiliNote.git
synced 2026-06-12 03:00:09 +08:00
feat(transcriber): 可配置 whisper 模型 + 名称映射(自定义 HF repo / 本地路径)
此前 fast-whisper 把「size → Systran/faster-whisper-{size}」的约定隐式散落在
加载/下载/检测三处,用户想用命名不符该约定的模型(社区微调版、或自己下到本地
的模型)接不上。本功能把映射显式化 + 可配置(对齐已有的 MLX_MODEL_MAP 模式)。
后端:
- 新增 app/transcriber/whisper_models.py 注册表:内置映射 + 用户自定义
(config/whisper_models.json 持久化,Docker 下随 config 卷保留);resolve
优先级 自定义 > 内置 > 直通(含 / 的 repo_id / 已存在本地目录)。
- whisper.py / config.py 的加载、下载、完整性检测统一走 resolve;HF cache 目录从
任意 repo_id 推导(models--{org}--{name})不再写死 Systran;本地路径跳过下载,
_purge_cache 绝不删用户本地模型。
- 新增 /whisper_models 增删查 API;/transcriber_config 返回内置+自定义列表;
下载校验放开到「已登记/可解析」的模型。
前端:transcriber.tsx 新增「自定义模型」卡片(增删 + 下载状态),模型下拉自动含自定义。
Docker:自定义 HF 模型下到 /app/backend/models(v2.3.3 models 卷已持久化);本地模型
走挂载目录 + 配置路径,UI 已提示挂载。
测试:tests/test_whisper_models.py 13 个单测全过;并在 v2.3.3 镜像真实后端环境做了
import 链 + resolve + 真实模型检测的集成冒烟,均通过。
Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
This commit is contained in:
@@ -61,16 +61,53 @@ WHISPER_MODEL_SIZES = ["tiny", "base", "small", "medium", "large-v3", "large-v3-
|
||||
@router.get("/transcriber_config")
|
||||
def get_transcriber_config():
|
||||
from app.transcriber.transcriber_provider import MLX_WHISPER_AVAILABLE
|
||||
from app.transcriber.whisper_models import get_registry, BUILTIN_WHISPER_MODELS
|
||||
|
||||
registry = get_registry()
|
||||
config = transcriber_config_manager.get_config()
|
||||
return R.success(data={
|
||||
**config,
|
||||
"available_types": AVAILABLE_TRANSCRIBER_TYPES,
|
||||
"whisper_model_sizes": WHISPER_MODEL_SIZES,
|
||||
# 内置可见档位 + 用户自定义模型,供前端下拉
|
||||
"whisper_model_sizes": registry.visible_model_names(),
|
||||
"whisper_builtin_models": BUILTIN_WHISPER_MODELS,
|
||||
"whisper_custom_models": registry.get_custom_models(),
|
||||
"mlx_whisper_available": MLX_WHISPER_AVAILABLE,
|
||||
})
|
||||
|
||||
|
||||
class WhisperCustomModelRequest(BaseModel):
|
||||
name: str
|
||||
target: str # HF repo_id(如 Systran/faster-whisper-large-v3)或本地模型目录路径
|
||||
|
||||
|
||||
@router.get("/whisper_models")
|
||||
def list_whisper_models():
|
||||
"""列出内置 + 用户自定义的 whisper 模型映射。"""
|
||||
from app.transcriber.whisper_models import get_registry, BUILTIN_WHISPER_MODELS
|
||||
reg = get_registry()
|
||||
return R.success(data={"builtin": BUILTIN_WHISPER_MODELS, "custom": reg.get_custom_models()})
|
||||
|
||||
|
||||
@router.post("/whisper_models")
|
||||
def add_whisper_model(data: WhisperCustomModelRequest):
|
||||
"""新增自定义 whisper 模型映射(名称 → HF repo_id 或本地路径)。"""
|
||||
from app.transcriber.whisper_models import get_registry
|
||||
try:
|
||||
custom = get_registry().add_custom_model(data.name, data.target)
|
||||
except ValueError as e:
|
||||
return R.error(msg=str(e))
|
||||
return R.success(data={"custom": custom}, msg="已添加自定义模型")
|
||||
|
||||
|
||||
@router.delete("/whisper_models/{name}")
|
||||
def delete_whisper_model(name: str):
|
||||
"""删除自定义 whisper 模型映射(不会删除已下载的模型文件)。"""
|
||||
from app.transcriber.whisper_models import get_registry
|
||||
custom = get_registry().remove_custom_model(name)
|
||||
return R.success(data={"custom": custom}, msg="已删除自定义模型")
|
||||
|
||||
|
||||
@router.post("/transcriber_config")
|
||||
def update_transcriber_config(data: TranscriberConfigRequest):
|
||||
config = transcriber_config_manager.update_config(
|
||||
@@ -119,14 +156,27 @@ _downloading: dict[str, str] = {} # model_size -> status ("downloading" | "done
|
||||
def _check_whisper_model_exists(model_size: str, subdir: str = "whisper") -> bool:
|
||||
"""检查指定 whisper 模型是否已下载完整到本地。
|
||||
|
||||
faster-whisper 把模型缓存在 HF cache 布局下:
|
||||
<model_dir>/models--Systran--faster-whisper-{size}/snapshots/<hash>/model.bin
|
||||
必须能在某个 snapshot 目录里找到 model.bin 才算完成。
|
||||
(历史 modelscope 布局 <model_dir>/whisper-{size}/model.bin 也兼容识别。)
|
||||
先把模型名 resolve 成可加载标识,再按类型判定:
|
||||
- 本地路径模型 → 直接看该目录下有没有 model.bin
|
||||
- HF repo_id → 看 HF cache 布局
|
||||
<model_dir>/models--{org}--{name}/snapshots/<hash>/model.bin
|
||||
(历史 modelscope 布局 <model_dir>/whisper-{size}/model.bin 也兼容识别)
|
||||
"""
|
||||
from app.transcriber.whisper_models import (
|
||||
resolve_whisper_model,
|
||||
is_local_target,
|
||||
hf_cache_dirname,
|
||||
)
|
||||
try:
|
||||
target = resolve_whisper_model(model_size)
|
||||
except Exception:
|
||||
return False
|
||||
if is_local_target(target):
|
||||
return (Path(target) / "model.bin").exists()
|
||||
|
||||
model_dir = Path(get_model_dir(subdir))
|
||||
# HF cache 布局
|
||||
hf_repo_dir = model_dir / f"models--Systran--faster-whisper-{model_size}" / "snapshots"
|
||||
# HF cache 布局(适配任意 org/repo,不再写死 Systran)
|
||||
hf_repo_dir = model_dir / hf_cache_dirname(target) / "snapshots"
|
||||
if hf_repo_dir.exists():
|
||||
for snapshot in hf_repo_dir.iterdir():
|
||||
if (snapshot / "model.bin").exists():
|
||||
@@ -157,9 +207,10 @@ def _check_mlx_whisper_model_exists(model_size: str) -> bool:
|
||||
|
||||
@router.get("/transcriber_models_status")
|
||||
def get_transcriber_models_status():
|
||||
"""返回所有 whisper 模型的下载状态。"""
|
||||
"""返回所有 whisper 模型的下载状态(含用户自定义模型)。"""
|
||||
from app.transcriber.whisper_models import get_registry
|
||||
statuses = []
|
||||
for size in WHISPER_MODEL_SIZES:
|
||||
for size in get_registry().visible_model_names():
|
||||
downloaded = _check_whisper_model_exists(size, "whisper")
|
||||
download_status = _downloading.get(size)
|
||||
statuses.append({
|
||||
@@ -198,13 +249,15 @@ class ModelDownloadRequest(BaseModel):
|
||||
|
||||
|
||||
def _do_download_whisper(model_size: str):
|
||||
"""后台下载 faster-whisper 模型。
|
||||
"""后台下载 faster-whisper 模型(支持内置 size / 自定义 repo_id / 本地路径)。
|
||||
|
||||
直接走 huggingface_hub.snapshot_download,把模型放到 HF cache 布局里——
|
||||
这样 faster-whisper 加载时(WhisperModel(model_size_or_path=size_name,
|
||||
download_root=model_dir))能直接命中缓存,跟加载路径完全对齐。
|
||||
模型名先 resolve:
|
||||
- 本地路径模型:无需下载,目录里有 model.bin 即 done,否则 failed;
|
||||
- HF repo_id:snapshot_download 到 HF cache 布局(cache_dir=model_dir),
|
||||
与加载逻辑 WhisperModel(download_root=model_dir) 完全对齐。
|
||||
"""
|
||||
from huggingface_hub import snapshot_download
|
||||
from app.transcriber.whisper_models import resolve_whisper_model, is_local_target
|
||||
|
||||
try:
|
||||
_downloading[model_size] = "downloading"
|
||||
@@ -214,12 +267,21 @@ def _do_download_whisper(model_size: str):
|
||||
if _check_whisper_model_exists(model_size, "whisper"):
|
||||
_downloading[model_size] = "done"
|
||||
return
|
||||
repo_id = f"Systran/faster-whisper-{model_size}"
|
||||
logger.info(f"开始下载 whisper 模型: {repo_id}")
|
||||
|
||||
target = resolve_whisper_model(model_size)
|
||||
if is_local_target(target):
|
||||
# 本地模型不下载,只校验 model.bin 是否就位
|
||||
ok = (Path(target) / "model.bin").exists()
|
||||
_downloading[model_size] = "done" if ok else "failed"
|
||||
if not ok:
|
||||
logger.warning(f"本地模型 {model_size} 路径 {target} 下没有 model.bin,无法使用")
|
||||
return
|
||||
|
||||
logger.info(f"开始下载 whisper 模型: {model_size} ← {target}")
|
||||
# 跟 faster-whisper utils.py 用同样的 allow_patterns,避免多下无关文件;
|
||||
# 不传 local_dir 让它走 HF 默认 cache 布局(与加载逻辑对齐)
|
||||
snapshot_download(
|
||||
repo_id,
|
||||
target,
|
||||
cache_dir=model_dir,
|
||||
allow_patterns=[
|
||||
"config.json",
|
||||
@@ -268,11 +330,11 @@ def _do_download_mlx_whisper(model_size: str):
|
||||
|
||||
@router.post("/transcriber_download")
|
||||
def download_transcriber_model(data: ModelDownloadRequest, background_tasks: BackgroundTasks):
|
||||
"""触发后台下载指定的 whisper 模型。"""
|
||||
if data.model_size not in WHISPER_MODEL_SIZES:
|
||||
return R.error(msg=f"不支持的模型大小: {data.model_size}")
|
||||
|
||||
"""触发后台下载指定的 whisper 模型(fast-whisper 支持内置档位 + 自定义模型)。"""
|
||||
if data.transcriber_type == "mlx-whisper":
|
||||
# mlx 只认内置档位(mlx-community 的固定映射)
|
||||
if data.model_size not in WHISPER_MODEL_SIZES:
|
||||
return R.error(msg=f"MLX 不支持的模型大小: {data.model_size}")
|
||||
if platform.system() != "Darwin":
|
||||
return R.error(msg="MLX Whisper 仅支持 macOS")
|
||||
key = f"mlx-{data.model_size}"
|
||||
@@ -280,6 +342,10 @@ def download_transcriber_model(data: ModelDownloadRequest, background_tasks: Bac
|
||||
return R.success(msg="模型正在下载中")
|
||||
background_tasks.add_task(_do_download_mlx_whisper, data.model_size)
|
||||
else:
|
||||
# fast-whisper:内置档位 / 自定义 repo_id / 本地路径都允许
|
||||
from app.transcriber.whisper_models import get_registry
|
||||
if not get_registry().is_known(data.model_size):
|
||||
return R.error(msg=f"不支持的模型: {data.model_size}(请先在自定义模型中登记)")
|
||||
if _downloading.get(data.model_size) == "downloading":
|
||||
return R.success(msg="模型正在下载中")
|
||||
background_tasks.add_task(_do_download_whisper, data.model_size)
|
||||
|
||||
Reference in New Issue
Block a user