fix(transcriber): 下载失败时透传错误到前端并提示

issue #402 衍生问题:whisper 模型后台下载失败时,/transcriber_models_status
只回传 downloading/downloaded 两个布尔,failed 态被直接丢弃,于是前端表现为
「点了下载没反应、状态一直未下载、且无任何错误提示」。

后端:新增轻量模块 model_download_state 统一维护下载状态(downloading/done/
failed)与失败原因,config.py 的下载触发与状态查询共享同一份内存态;状态接口
新增 failed 字段,失败时附带 error(仓库 404、网络中断、本地路径缺 model.bin 等)。

前端:模型管理列表新增「下载失败」红色徽标 + 错误详情,按钮在失败后变为「重试」;
自定义模型项同样展示失败图标与原因;并对「本次新出现的失败」弹一次 toast 主动提示。

测试:新增 test_model_download_state 覆盖状态流转(downloading/done/failed、
失败原因透传、downloaded 覆盖 failed、重下清错、mlx key 隔离)。

已用 docker compose 启动整套栈验证:触发本地路径缺失与 HF 仓库 404 两种失败,
/transcriber_models_status 均正确回传 failed:true + error。

Refs #402

Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
This commit is contained in:
huangjianwu
2026-06-23 10:19:38 +08:00
parent 3841719d5a
commit 4a87c5b93b
5 changed files with 281 additions and 54 deletions

View File

@@ -11,6 +11,7 @@ from app.utils.path_helper import get_model_dir
from app.services.cookie_manager import CookieConfigManager
from app.services.transcriber_config_manager import TranscriberConfigManager
from app.transcriber import model_download_state as dl_state
from ffmpeg_helper import ensure_ffmpeg_or_raise
logger = get_logger(__name__)
@@ -148,9 +149,9 @@ def update_proxy_config(data: ProxyConfigRequest):
# ---- Whisper 模型下载状态 & 下载触发 ----
# 用于跟踪正在进行的下载任务
_downloading: dict[str, str] = {} # model_size -> status ("downloading" | "done" | "failed")
# 下载状态(downloading / done / failed + 失败原因)统一交给 model_download_state 维护,
# 「触发下载」与「查询状态」共享同一份进程内内存态。失败原因会随状态接口透传给前端,
# 修复 issue #402 衍生问题:原先只回传 downloading/downloaded,下载失败时前端无任何提示。
def _check_whisper_model_exists(model_size: str, subdir: str = "whisper") -> bool:
@@ -212,12 +213,7 @@ def get_transcriber_models_status():
statuses = []
for size in get_registry().visible_model_names():
downloaded = _check_whisper_model_exists(size, "whisper")
download_status = _downloading.get(size)
statuses.append({
"model_size": size,
"downloaded": downloaded,
"downloading": download_status == "downloading",
})
statuses.append(dl_state.status_row(size, downloaded))
# 也检查 mlx-whisper(仅 macOS)
mlx_available = platform.system() == "Darwin"
@@ -225,16 +221,12 @@ def get_transcriber_models_status():
if mlx_available:
from app.transcriber.mlx_whisper_transcriber import MLX_MODEL_MAP
for size in WHISPER_MODEL_SIZES:
mlx_key = f"mlx-{size}"
repo_id = MLX_MODEL_MAP.get(size)
# 用 config.json 判定,和 _check_mlx_whisper_model_exists / 加载逻辑保持一致
downloaded = _check_mlx_whisper_model_exists(size)
mlx_statuses.append({
"model_size": size,
"downloaded": downloaded,
"downloading": _downloading.get(mlx_key) == "downloading",
"available": repo_id is not None,
})
row = dl_state.status_row(size, downloaded, key=f"mlx-{size}")
row["available"] = repo_id is not None
mlx_statuses.append(row)
return R.success(data={
"whisper": statuses,
@@ -260,21 +252,24 @@ def _do_download_whisper(model_size: str):
from app.transcriber.whisper_models import resolve_whisper_model, is_local_target
try:
_downloading[model_size] = "downloading"
dl_state.mark_downloading(model_size)
model_dir = get_model_dir("whisper")
# 已经下好就不重复下
if _check_whisper_model_exists(model_size, "whisper"):
_downloading[model_size] = "done"
dl_state.mark_done(model_size)
return
target = resolve_whisper_model(model_size)
if is_local_target(target):
# 本地模型不下载,只校验 model.bin 是否就位
ok = (Path(target) / "model.bin").exists()
_downloading[model_size] = "done" if ok else "failed"
if not ok:
logger.warning(f"本地模型 {model_size} 路径 {target} 下没有 model.bin无法使用")
if ok:
dl_state.mark_done(model_size)
else:
msg = f"本地模型路径 {target} 下没有 model.bin无法使用"
logger.warning(f"本地模型 {model_size}{msg}")
dl_state.mark_failed(model_size, msg)
return
logger.info(f"开始下载 whisper 模型: {model_size}{target}")
@@ -292,17 +287,17 @@ def _do_download_whisper(model_size: str):
],
)
logger.info(f"whisper 模型下载完成: {model_size}")
_downloading[model_size] = "done"
dl_state.mark_done(model_size)
except Exception as e:
logger.error(f"whisper 模型下载失败: {model_size}, {e}")
_downloading[model_size] = "failed"
dl_state.mark_failed(model_size, str(e))
def _do_download_mlx_whisper(model_size: str):
"""后台下载 mlx-whisper 模型。"""
key = f"mlx-{model_size}"
try:
_downloading[key] = "downloading"
dl_state.mark_downloading(key)
from huggingface_hub import snapshot_download as hf_download
from app.transcriber.mlx_whisper_transcriber import resolve_mlx_repo_id
@@ -310,22 +305,22 @@ def _do_download_mlx_whisper(model_size: str):
repo_id = resolve_mlx_repo_id(model_size)
except ValueError as e:
logger.error(str(e))
_downloading[key] = "failed"
dl_state.mark_failed(key, str(e))
return
model_dir = get_model_dir("mlx-whisper")
model_path = os.path.join(model_dir, repo_id)
# 用 config.json 判定而非目录存在:半成品目录不能算「已下载」
if (Path(model_path) / "config.json").exists():
_downloading[key] = "done"
dl_state.mark_done(key)
return
logger.info(f"开始下载 mlx-whisper 模型: {model_size}{repo_id}")
hf_download(repo_id, local_dir=model_path, local_dir_use_symlinks=False)
logger.info(f"mlx-whisper 模型下载完成: {model_size}")
_downloading[key] = "done"
dl_state.mark_done(key)
except Exception as e:
logger.error(f"mlx-whisper 模型下载失败: {model_size}, {e}")
_downloading[key] = "failed"
dl_state.mark_failed(key, str(e))
@router.post("/transcriber_download")
@@ -338,7 +333,7 @@ def download_transcriber_model(data: ModelDownloadRequest, background_tasks: Bac
if platform.system() != "Darwin":
return R.error(msg="MLX Whisper 仅支持 macOS")
key = f"mlx-{data.model_size}"
if _downloading.get(key) == "downloading":
if dl_state.is_downloading(key):
return R.success(msg="模型正在下载中")
background_tasks.add_task(_do_download_mlx_whisper, data.model_size)
else:
@@ -346,7 +341,7 @@ def download_transcriber_model(data: ModelDownloadRequest, background_tasks: Bac
from app.transcriber.whisper_models import get_registry
if not get_registry().is_known(data.model_size):
return R.error(msg=f"不支持的模型: {data.model_size}(请先在自定义模型中登记)")
if _downloading.get(data.model_size) == "downloading":
if dl_state.is_downloading(data.model_size):
return R.success(msg="模型正在下载中")
background_tasks.add_task(_do_download_whisper, data.model_size)

View File

@@ -0,0 +1,75 @@
"""whisper / mlx 模型后台下载状态跟踪(含失败原因)。
routers.config 的「触发下载」与「查询状态」共享这份进程内内存态:
- key:fast-whisper 直接用 model_size;mlx 用 "mlx-{size}" 前缀(与历史一致)
- 状态:downloading / done / failed;failed 时另存最近一次错误原因
为什么抽成独立的轻量模块(仅依赖 logger):
1) 把原先散落在 config.py 多处的字符串状态赋值收敛到一处,避免拼写漂移;
2) 失败原因能透传到 /transcriber_models_status → 前端,修复「下载失败前端无任何
提示、状态一直显示未下载」(issue #402 的衍生问题:原先状态接口只回传
downloading/downloaded 两个布尔,failed 态被直接丢弃);
3) 不引入 faster_whisper / ctranslate2 等重依赖,可被单测隔离加载。
"""
from typing import Dict, Optional
from app.utils.logger import get_logger
logger = get_logger(__name__)
# Status strings stored in ``_status``; shared constants keep the spelling in one place.
DOWNLOADING = "downloading"
DONE = "done"
FAILED = "failed"
# _status: key -> status string.
# _errors: key -> most recent failure reason (only meaningful while the status is FAILED).
_status: Dict[str, str] = {}
_errors: Dict[str, str] = {}
def mark_downloading(key: str) -> None:
    """Record that a download for ``key`` has started.

    Any failure reason stored for ``key`` by an earlier attempt is discarded,
    so a retry does not carry the previous error along.
    """
    _errors.pop(key, None)
    _status[key] = DOWNLOADING
def mark_done(key: str) -> None:
    """Record that the download for ``key`` finished successfully.

    A failure reason left behind by an earlier attempt is removed as well.
    """
    _errors.pop(key, None)
    _status[key] = DONE
def mark_failed(key: str, error: str = "") -> None:
    """Record that the download for ``key`` failed.

    Args:
        key: Tracking key of the model.
        error: Human-readable failure reason. When empty, no reason is stored.

    The stored reason always belongs to the most recent failure: if this
    failure carries no message, a reason left over from an earlier failure is
    removed instead of being reported as the cause of this one.
    """
    _status[key] = FAILED
    if error:
        _errors[key] = error
    else:
        # Without this, a previous failure's message would stay attached to
        # the key and be surfaced for a failure it does not describe.
        _errors.pop(key, None)
def get_status(key: str) -> Optional[str]:
    """Return the status string tracked for ``key``, or ``None`` if there is none."""
    try:
        return _status[key]
    except KeyError:
        return None
def is_downloading(key: str) -> bool:
    """Return ``True`` when the status tracked for ``key`` is DOWNLOADING."""
    current = _status.get(key)
    return current == DOWNLOADING
def get_error(key: str) -> Optional[str]:
    """Return the failure reason stored for ``key``, or ``None`` if there is none."""
    try:
        return _errors[key]
    except KeyError:
        return None
def status_row(name: str, downloaded: bool, key: Optional[str] = None) -> dict:
    """Build the status row for one model as sent to the frontend.

    Args:
        name: Value reported in the row's ``model_size`` field.
        downloaded: Whether the model is already present, as determined by the caller.
        key: Tracking key to look up; falls back to ``name`` when ``None``
            (mlx callers pass ``"mlx-{size}"``).

    Returns:
        A dict with ``model_size``, ``downloaded``, ``downloading`` and
        ``failed``; an ``error`` entry is added only when the row is failed
        and a non-empty reason is stored.

    A model that is already downloaded never reports ``failed``/``error``, so
    an old failure does not linger after a later successful download.
    """
    lookup = name if key is None else key
    state = _status.get(lookup)
    has_failed = state == FAILED and not downloaded

    result: dict = {
        "model_size": name,
        "downloaded": downloaded,
        "downloading": state == DOWNLOADING,
        "failed": has_failed,
    }

    reason = _errors.get(lookup) if has_failed else None
    if reason:
        result["error"] = reason
    return result

View File

@@ -0,0 +1,113 @@
"""Unit tests for app.transcriber.model_download_state模型下载状态 + 失败原因跟踪)。
与 test_whisper_models 一样按文件路径隔离加载,并桩掉 app.utils.logger,
避免触发 app/__init__.py(会 import faster_whisper 等重依赖)。
"""
import importlib.util
import logging
import pathlib
import sys
import types
import unittest
# Directory one level above the directory that contains this test file.
ROOT = pathlib.Path(__file__).resolve().parents[1]
# Source file of the module under test; it is loaded by path in _load_module().
MODULE_PATH = ROOT / "app" / "transcriber" / "model_download_state.py"
def _load_module():
    """Load ``model_download_state`` directly from ``MODULE_PATH`` and return it.

    Before executing the file, placeholder entries for ``app``, ``app.utils``
    and ``app.utils.logger`` are registered in ``sys.modules`` (only where
    missing), so the module's logger import resolves to a stub backed by the
    standard ``logging`` package.
    """
    # Empty placeholder packages: an empty __path__ marks them as packages.
    for pkg_name in ("app", "app.utils"):
        if pkg_name not in sys.modules:
            placeholder = types.ModuleType(pkg_name)
            placeholder.__path__ = []
            sys.modules[pkg_name] = placeholder

    if "app.utils.logger" not in sys.modules:
        logger_stub = types.ModuleType("app.utils.logger")
        logger_stub.get_logger = lambda name=None: logging.getLogger(name or "test")
        sys.modules["app.utils.logger"] = logger_stub

    spec = importlib.util.spec_from_file_location("model_download_state_under_test", MODULE_PATH)
    assert spec and spec.loader
    loaded = importlib.util.module_from_spec(spec)
    spec.loader.exec_module(loaded)
    return loaded
# Module under test, loaded once at import time and shared by every test below.
ds = _load_module()
class TestDownloadState(unittest.TestCase):
    """State-transition tests for the download tracker loaded as ``ds``."""

    def setUp(self):
        # The tracker keeps module-level dicts; empty both so tests stay independent.
        for table in (ds._status, ds._errors):
            table.clear()

    def test_unknown_key_defaults(self):
        """An untracked key yields an all-false row without an ``error`` field."""
        expected = {
            "model_size": "tiny",
            "downloaded": False,
            "downloading": False,
            "failed": False,
        }
        result = ds.status_row("tiny", downloaded=False)
        self.assertEqual(result, expected)
        self.assertNotIn("error", result)
        self.assertFalse(ds.is_downloading("tiny"))

    def test_downloading(self):
        """After ``mark_downloading`` the row reports downloading and not failed."""
        ds.mark_downloading("tiny")
        self.assertTrue(ds.is_downloading("tiny"))
        result = ds.status_row("tiny", downloaded=False)
        self.assertTrue(result["downloading"])
        self.assertFalse(result["failed"])

    def test_failed_surfaces_error(self):
        """A failure reason is exposed both in the row and through ``get_error``."""
        reason = "401 Repository Not Found"
        ds.mark_failed("tiny", reason)
        result = ds.status_row("tiny", downloaded=False)
        self.assertTrue(result["failed"])
        self.assertFalse(result["downloading"])
        self.assertEqual(result["error"], reason)
        self.assertEqual(ds.get_error("tiny"), reason)

    def test_failed_without_message_has_no_error_field(self):
        """A failure recorded without a message produces no ``error`` field."""
        ds.mark_failed("tiny")
        result = ds.status_row("tiny", downloaded=False)
        self.assertTrue(result["failed"])
        self.assertNotIn("error", result)

    def test_downloaded_overrides_failed(self):
        """Failed first, downloaded later: ``downloaded=True`` hides failed/error."""
        ds.mark_failed("tiny", "boom")
        result = ds.status_row("tiny", downloaded=True)
        self.assertFalse(result["failed"])
        self.assertTrue(result["downloaded"])
        self.assertNotIn("error", result)

    def test_mark_done_clears_error(self):
        """``mark_done`` removes the stored failure reason."""
        ds.mark_failed("tiny", "boom")
        ds.mark_done("tiny")
        self.assertIsNone(ds.get_error("tiny"))
        result = ds.status_row("tiny", downloaded=True)
        self.assertFalse(result["failed"])

    def test_redownload_clears_previous_error(self):
        """Starting a new download drops the previous failure reason."""
        ds.mark_failed("tiny", "boom")
        ds.mark_downloading("tiny")  # retry
        self.assertIsNone(ds.get_error("tiny"))
        result = ds.status_row("tiny", downloaded=False)
        self.assertTrue(result["downloading"])
        self.assertFalse(result["failed"])
        self.assertNotIn("error", result)

    def test_mlx_key_is_independent(self):
        """The ``mlx-{size}`` key and the plain size key do not affect each other."""
        ds.mark_failed("mlx-tiny", "mlx boom")
        ds.mark_downloading("tiny")
        plain = ds.status_row("tiny", downloaded=False)
        mlx = ds.status_row("tiny", downloaded=False, key="mlx-tiny")
        self.assertTrue(plain["downloading"])
        self.assertFalse(plain["failed"])
        self.assertTrue(mlx["failed"])
        self.assertEqual(mlx["error"], "mlx boom")
# Run the tests through unittest when this file is executed as a script.
if __name__ == "__main__":
    unittest.main()