mirror of
https://github.com/JefferyHcool/BiliNote.git
synced 2026-06-24 01:03:42 +08:00
Merge branch 'hotfix/2.4.3'
This commit is contained in:
@@ -1,7 +1,7 @@
|
||||
{
|
||||
"$schema": "../node_modules/@tauri-apps/cli/config.schema.json",
|
||||
"productName": "BiliNote",
|
||||
"version": "2.4.2",
|
||||
"version": "2.4.3",
|
||||
"identifier": "com.jefferyhuang.bilinote",
|
||||
"build": {
|
||||
"frontendDist": "../dist",
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
import { useState, useEffect, useCallback } from 'react'
|
||||
import { useState, useEffect, useCallback, useRef } from 'react'
|
||||
import { Card, CardContent, CardHeader, CardTitle } from '@/components/ui/card'
|
||||
import { Button } from '@/components/ui/button'
|
||||
import { Badge } from '@/components/ui/badge'
|
||||
@@ -40,6 +40,9 @@ export default function Transcriber() {
|
||||
const [newModelName, setNewModelName] = useState('')
|
||||
const [newModelTarget, setNewModelTarget] = useState('')
|
||||
const [addingModel, setAddingModel] = useState(false)
|
||||
// 已提示过的下载失败 key(whisper 用 model_size,mlx 用 mlx-{size})。
|
||||
// null 表示尚未首次加载——首次加载只建立基线、不对历史失败弹窗。
|
||||
const prevFailedRef = useRef<Set<string> | null>(null)
|
||||
|
||||
// 重新拉取配置(不重置用户当前的选择),用于增删自定义模型后刷新下拉与列表
|
||||
const reloadConfig = useCallback(async () => {
|
||||
@@ -56,6 +59,23 @@ export default function Transcriber() {
|
||||
setModelStatuses(data.whisper)
|
||||
setMlxModelStatuses(data.mlx_whisper)
|
||||
setMlxAvailable(data.mlx_available)
|
||||
|
||||
// 下载失败主动提示:只对「本次新出现的失败」弹一次,避免轮询期间反复弹窗
|
||||
const failedNow = new Map<string, ModelStatus>()
|
||||
data.whisper.forEach(m => m.failed && failedNow.set(m.model_size, m))
|
||||
data.mlx_whisper.forEach(m => m.failed && failedNow.set(`mlx-${m.model_size}`, m))
|
||||
if (prevFailedRef.current === null) {
|
||||
// 首次加载:建立基线,不对进入页面前就已失败的项弹窗(仍会在列表里红字展示)
|
||||
prevFailedRef.current = new Set(failedNow.keys())
|
||||
} else {
|
||||
failedNow.forEach((m, key) => {
|
||||
if (!prevFailedRef.current!.has(key)) {
|
||||
const detail = m.error ? `:${m.error.slice(0, 120)}` : ''
|
||||
toast.error(`模型 ${m.model_size} 下载失败${detail}`, { duration: 6000 })
|
||||
}
|
||||
})
|
||||
prevFailedRef.current = new Set(failedNow.keys())
|
||||
}
|
||||
} catch {
|
||||
// 静默失败,不阻塞主流程
|
||||
}
|
||||
@@ -290,32 +310,44 @@ export default function Transcriber() {
|
||||
{currentModels.map(model => (
|
||||
<div
|
||||
key={model.model_size}
|
||||
className="flex items-center justify-between rounded-md border px-4 py-3"
|
||||
className="rounded-md border px-4 py-3"
|
||||
>
|
||||
<div className="flex items-center gap-3">
|
||||
<span className="font-medium">{model.model_size}</span>
|
||||
{model.downloaded ? (
|
||||
<Badge variant="default" className="bg-green-500 hover:bg-green-600">
|
||||
已下载
|
||||
</Badge>
|
||||
) : model.downloading ? (
|
||||
<Badge variant="secondary" className="flex items-center gap-1">
|
||||
<Loader2 className="h-3 w-3 animate-spin" />
|
||||
下载中
|
||||
</Badge>
|
||||
) : (
|
||||
<Badge variant="outline">未下载</Badge>
|
||||
<div className="flex items-center justify-between">
|
||||
<div className="flex items-center gap-3">
|
||||
<span className="font-medium">{model.model_size}</span>
|
||||
{model.downloaded ? (
|
||||
<Badge variant="default" className="bg-green-500 hover:bg-green-600">
|
||||
已下载
|
||||
</Badge>
|
||||
) : model.downloading ? (
|
||||
<Badge variant="secondary" className="flex items-center gap-1">
|
||||
<Loader2 className="h-3 w-3 animate-spin" />
|
||||
下载中
|
||||
</Badge>
|
||||
) : model.failed ? (
|
||||
<Badge variant="destructive" className="flex items-center gap-1" title={model.error}>
|
||||
<XCircle className="h-3 w-3" />
|
||||
下载失败
|
||||
</Badge>
|
||||
) : (
|
||||
<Badge variant="outline">未下载</Badge>
|
||||
)}
|
||||
</div>
|
||||
{!model.downloaded && !model.downloading && (
|
||||
<Button
|
||||
size="sm"
|
||||
variant="outline"
|
||||
onClick={() => handleDownload(model.model_size, selectedType)}
|
||||
>
|
||||
<Download className="mr-1 h-4 w-4" />
|
||||
{model.failed ? '重试' : '下载'}
|
||||
</Button>
|
||||
)}
|
||||
</div>
|
||||
{!model.downloaded && !model.downloading && (
|
||||
<Button
|
||||
size="sm"
|
||||
variant="outline"
|
||||
onClick={() => handleDownload(model.model_size, selectedType)}
|
||||
>
|
||||
<Download className="mr-1 h-4 w-4" />
|
||||
下载
|
||||
</Button>
|
||||
{model.failed && model.error && (
|
||||
<p className="mt-2 break-all text-xs text-red-500" title={model.error}>
|
||||
{model.error}
|
||||
</p>
|
||||
)}
|
||||
</div>
|
||||
))}
|
||||
@@ -368,10 +400,18 @@ export default function Transcriber() {
|
||||
{status?.downloading && (
|
||||
<Loader2 className="h-3.5 w-3.5 animate-spin text-neutral-400" />
|
||||
)}
|
||||
{status?.failed && (
|
||||
<XCircle className="h-3.5 w-3.5 text-red-500" />
|
||||
)}
|
||||
</div>
|
||||
<div className="truncate text-xs text-neutral-400" title={target}>
|
||||
{target}
|
||||
</div>
|
||||
{status?.failed && status?.error && (
|
||||
<div className="truncate text-xs text-red-500" title={status.error}>
|
||||
{status.error}
|
||||
</div>
|
||||
)}
|
||||
</div>
|
||||
<Button
|
||||
size="sm"
|
||||
|
||||
@@ -16,6 +16,10 @@ export interface ModelStatus {
|
||||
model_size: string
|
||||
downloaded: boolean
|
||||
downloading: boolean
|
||||
/** 后台下载失败(仓库 404、网络中断、本地路径缺 model.bin 等)。后端从此字段透传 */
|
||||
failed?: boolean
|
||||
/** 下载失败时的原因(仅 failed 时存在),用于前端提示 */
|
||||
error?: string
|
||||
}
|
||||
|
||||
export interface ModelsStatusResponse {
|
||||
|
||||
@@ -2,6 +2,13 @@
|
||||
|
||||
本项目所有重要变更记录于此。格式参考 [Keep a Changelog](https://keepachangelog.com/zh-CN/1.1.0/),遵循 [语义化版本](https://semver.org/lang/zh-CN/)。
|
||||
|
||||
## [2.4.3] - 2026-06-23
|
||||
|
||||
### Fixed
|
||||
|
||||
- **Whisper `large-v3-turbo` 模型无法下载**(#402):内置映射指向的 `Systran/faster-whisper-large-v3-turbo` 仓库已从 HuggingFace 下架(返回 401/404),点击下载会静默失败、状态一直显示「未下载」。改用社区维护的 CT2 转换版 `deepdml/faster-whisper-large-v3-turbo-ct2`(直链可达、含 `model.bin`,与 faster-whisper 的 `large-v3-turbo` 等价)。
|
||||
- **模型下载失败时前端无任何提示**(#402 衍生):`/transcriber_models_status` 此前只回传 `downloading`/`downloaded`,后台下载失败状态被丢弃。现新增 `model_download_state` 统一维护下载状态与失败原因,状态接口新增 `failed` 字段并透传 `error`;前端模型列表展示「下载失败」徽标 + 错误详情,按钮变为「重试」,并对新出现的失败弹出提示。
|
||||
|
||||
## [2.4.2] - 2026-06-17
|
||||
|
||||
### Fixed
|
||||
|
||||
@@ -3,7 +3,7 @@
|
||||
<p align="center">
|
||||
<img src="./doc/icon.svg" alt="BiliNote Banner" width="50" height="50" />
|
||||
</p>
|
||||
<h1 align="center" > BiliNote v2.4.2</h1>
|
||||
<h1 align="center" > BiliNote v2.4.3</h1>
|
||||
</div>
|
||||
|
||||
<p align="center"><i>AI 视频笔记生成工具 让 AI 为你的视频做笔记</i></p>
|
||||
|
||||
@@ -11,6 +11,7 @@ from app.utils.path_helper import get_model_dir
|
||||
|
||||
from app.services.cookie_manager import CookieConfigManager
|
||||
from app.services.transcriber_config_manager import TranscriberConfigManager
|
||||
from app.transcriber import model_download_state as dl_state
|
||||
from ffmpeg_helper import ensure_ffmpeg_or_raise
|
||||
|
||||
logger = get_logger(__name__)
|
||||
@@ -148,9 +149,9 @@ def update_proxy_config(data: ProxyConfigRequest):
|
||||
|
||||
|
||||
# ---- Whisper 模型下载状态 & 下载触发 ----
|
||||
|
||||
# 用于跟踪正在进行的下载任务
|
||||
_downloading: dict[str, str] = {} # model_size -> status ("downloading" | "done" | "failed")
|
||||
# 下载状态(downloading / done / failed + 失败原因)统一交给 model_download_state 维护,
|
||||
# 「触发下载」与「查询状态」共享同一份进程内内存态。失败原因会随状态接口透传给前端,
|
||||
# 修复 issue #402 衍生问题:原先只回传 downloading/downloaded,下载失败时前端无任何提示。
|
||||
|
||||
|
||||
def _check_whisper_model_exists(model_size: str, subdir: str = "whisper") -> bool:
|
||||
@@ -212,12 +213,7 @@ def get_transcriber_models_status():
|
||||
statuses = []
|
||||
for size in get_registry().visible_model_names():
|
||||
downloaded = _check_whisper_model_exists(size, "whisper")
|
||||
download_status = _downloading.get(size)
|
||||
statuses.append({
|
||||
"model_size": size,
|
||||
"downloaded": downloaded,
|
||||
"downloading": download_status == "downloading",
|
||||
})
|
||||
statuses.append(dl_state.status_row(size, downloaded))
|
||||
|
||||
# 也检查 mlx-whisper(仅 macOS)
|
||||
mlx_available = platform.system() == "Darwin"
|
||||
@@ -225,16 +221,12 @@ def get_transcriber_models_status():
|
||||
if mlx_available:
|
||||
from app.transcriber.mlx_whisper_transcriber import MLX_MODEL_MAP
|
||||
for size in WHISPER_MODEL_SIZES:
|
||||
mlx_key = f"mlx-{size}"
|
||||
repo_id = MLX_MODEL_MAP.get(size)
|
||||
# 用 config.json 判定,和 _check_mlx_whisper_model_exists / 加载逻辑保持一致
|
||||
downloaded = _check_mlx_whisper_model_exists(size)
|
||||
mlx_statuses.append({
|
||||
"model_size": size,
|
||||
"downloaded": downloaded,
|
||||
"downloading": _downloading.get(mlx_key) == "downloading",
|
||||
"available": repo_id is not None,
|
||||
})
|
||||
row = dl_state.status_row(size, downloaded, key=f"mlx-{size}")
|
||||
row["available"] = repo_id is not None
|
||||
mlx_statuses.append(row)
|
||||
|
||||
return R.success(data={
|
||||
"whisper": statuses,
|
||||
@@ -260,21 +252,24 @@ def _do_download_whisper(model_size: str):
|
||||
from app.transcriber.whisper_models import resolve_whisper_model, is_local_target
|
||||
|
||||
try:
|
||||
_downloading[model_size] = "downloading"
|
||||
dl_state.mark_downloading(model_size)
|
||||
model_dir = get_model_dir("whisper")
|
||||
|
||||
# 已经下好就不重复下
|
||||
if _check_whisper_model_exists(model_size, "whisper"):
|
||||
_downloading[model_size] = "done"
|
||||
dl_state.mark_done(model_size)
|
||||
return
|
||||
|
||||
target = resolve_whisper_model(model_size)
|
||||
if is_local_target(target):
|
||||
# 本地模型不下载,只校验 model.bin 是否就位
|
||||
ok = (Path(target) / "model.bin").exists()
|
||||
_downloading[model_size] = "done" if ok else "failed"
|
||||
if not ok:
|
||||
logger.warning(f"本地模型 {model_size} 路径 {target} 下没有 model.bin,无法使用")
|
||||
if ok:
|
||||
dl_state.mark_done(model_size)
|
||||
else:
|
||||
msg = f"本地模型路径 {target} 下没有 model.bin,无法使用"
|
||||
logger.warning(f"本地模型 {model_size}:{msg}")
|
||||
dl_state.mark_failed(model_size, msg)
|
||||
return
|
||||
|
||||
logger.info(f"开始下载 whisper 模型: {model_size} ← {target}")
|
||||
@@ -292,17 +287,17 @@ def _do_download_whisper(model_size: str):
|
||||
],
|
||||
)
|
||||
logger.info(f"whisper 模型下载完成: {model_size}")
|
||||
_downloading[model_size] = "done"
|
||||
dl_state.mark_done(model_size)
|
||||
except Exception as e:
|
||||
logger.error(f"whisper 模型下载失败: {model_size}, {e}")
|
||||
_downloading[model_size] = "failed"
|
||||
dl_state.mark_failed(model_size, str(e))
|
||||
|
||||
|
||||
def _do_download_mlx_whisper(model_size: str):
|
||||
"""后台下载 mlx-whisper 模型。"""
|
||||
key = f"mlx-{model_size}"
|
||||
try:
|
||||
_downloading[key] = "downloading"
|
||||
dl_state.mark_downloading(key)
|
||||
from huggingface_hub import snapshot_download as hf_download
|
||||
from app.transcriber.mlx_whisper_transcriber import resolve_mlx_repo_id
|
||||
|
||||
@@ -310,22 +305,22 @@ def _do_download_mlx_whisper(model_size: str):
|
||||
repo_id = resolve_mlx_repo_id(model_size)
|
||||
except ValueError as e:
|
||||
logger.error(str(e))
|
||||
_downloading[key] = "failed"
|
||||
dl_state.mark_failed(key, str(e))
|
||||
return
|
||||
|
||||
model_dir = get_model_dir("mlx-whisper")
|
||||
model_path = os.path.join(model_dir, repo_id)
|
||||
# 用 config.json 判定而非目录存在:半成品目录不能算「已下载」
|
||||
if (Path(model_path) / "config.json").exists():
|
||||
_downloading[key] = "done"
|
||||
dl_state.mark_done(key)
|
||||
return
|
||||
logger.info(f"开始下载 mlx-whisper 模型: {model_size} ← {repo_id}")
|
||||
hf_download(repo_id, local_dir=model_path, local_dir_use_symlinks=False)
|
||||
logger.info(f"mlx-whisper 模型下载完成: {model_size}")
|
||||
_downloading[key] = "done"
|
||||
dl_state.mark_done(key)
|
||||
except Exception as e:
|
||||
logger.error(f"mlx-whisper 模型下载失败: {model_size}, {e}")
|
||||
_downloading[key] = "failed"
|
||||
dl_state.mark_failed(key, str(e))
|
||||
|
||||
|
||||
@router.post("/transcriber_download")
|
||||
@@ -338,7 +333,7 @@ def download_transcriber_model(data: ModelDownloadRequest, background_tasks: Bac
|
||||
if platform.system() != "Darwin":
|
||||
return R.error(msg="MLX Whisper 仅支持 macOS")
|
||||
key = f"mlx-{data.model_size}"
|
||||
if _downloading.get(key) == "downloading":
|
||||
if dl_state.is_downloading(key):
|
||||
return R.success(msg="模型正在下载中")
|
||||
background_tasks.add_task(_do_download_mlx_whisper, data.model_size)
|
||||
else:
|
||||
@@ -346,7 +341,7 @@ def download_transcriber_model(data: ModelDownloadRequest, background_tasks: Bac
|
||||
from app.transcriber.whisper_models import get_registry
|
||||
if not get_registry().is_known(data.model_size):
|
||||
return R.error(msg=f"不支持的模型: {data.model_size}(请先在自定义模型中登记)")
|
||||
if _downloading.get(data.model_size) == "downloading":
|
||||
if dl_state.is_downloading(data.model_size):
|
||||
return R.success(msg="模型正在下载中")
|
||||
background_tasks.add_task(_do_download_whisper, data.model_size)
|
||||
|
||||
|
||||
75
backend/app/transcriber/model_download_state.py
Normal file
75
backend/app/transcriber/model_download_state.py
Normal file
@@ -0,0 +1,75 @@
|
||||
"""whisper / mlx 模型后台下载状态跟踪(含失败原因)。
|
||||
|
||||
routers.config 的「触发下载」与「查询状态」共享这份进程内内存态:
|
||||
- key:fast-whisper 直接用 model_size;mlx 用 "mlx-{size}" 前缀(与历史一致)
|
||||
- 状态:downloading / done / failed;failed 时另存最近一次错误原因
|
||||
|
||||
为什么抽成独立的轻量模块(仅依赖 logger):
|
||||
1) 把原先散落在 config.py 多处的字符串状态赋值收敛到一处,避免拼写漂移;
|
||||
2) 失败原因能透传到 /transcriber_models_status → 前端,修复「下载失败前端无任何
|
||||
提示、状态一直显示未下载」(issue #402 的衍生问题:原先状态接口只回传
|
||||
downloading/downloaded 两个布尔,failed 态被直接丢弃);
|
||||
3) 不引入 faster_whisper / ctranslate2 等重依赖,可被单测隔离加载。
|
||||
"""
|
||||
from typing import Dict, Optional
|
||||
|
||||
from app.utils.logger import get_logger
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
DOWNLOADING = "downloading"
|
||||
DONE = "done"
|
||||
FAILED = "failed"
|
||||
|
||||
# key -> 状态字符串;key -> 最近一次失败原因(仅 failed 时有意义)
|
||||
_status: Dict[str, str] = {}
|
||||
_errors: Dict[str, str] = {}
|
||||
|
||||
|
||||
def mark_downloading(key: str) -> None:
|
||||
_status[key] = DOWNLOADING
|
||||
_errors.pop(key, None) # 重新开始下载,清掉上一次的失败原因
|
||||
|
||||
|
||||
def mark_done(key: str) -> None:
|
||||
_status[key] = DONE
|
||||
_errors.pop(key, None)
|
||||
|
||||
|
||||
def mark_failed(key: str, error: str = "") -> None:
|
||||
_status[key] = FAILED
|
||||
if error:
|
||||
_errors[key] = error
|
||||
|
||||
|
||||
def get_status(key: str) -> Optional[str]:
|
||||
return _status.get(key)
|
||||
|
||||
|
||||
def is_downloading(key: str) -> bool:
|
||||
return _status.get(key) == DOWNLOADING
|
||||
|
||||
|
||||
def get_error(key: str) -> Optional[str]:
|
||||
return _errors.get(key)
|
||||
|
||||
|
||||
def status_row(name: str, downloaded: bool, key: Optional[str] = None) -> dict:
|
||||
"""构造单个模型给前端的状态行:downloaded / downloading / failed (+error)。
|
||||
|
||||
key 默认用 name;mlx 传 "mlx-{size}"。已下载成功(downloaded=True)的模型
|
||||
一律不回传 failed/error——避免「先失败后又下好」时残留旧的错误状态。
|
||||
"""
|
||||
k = key if key is not None else name
|
||||
st = _status.get(k)
|
||||
row: dict = {
|
||||
"model_size": name,
|
||||
"downloaded": downloaded,
|
||||
"downloading": st == DOWNLOADING,
|
||||
"failed": (not downloaded) and st == FAILED,
|
||||
}
|
||||
if row["failed"]:
|
||||
err = _errors.get(k)
|
||||
if err:
|
||||
row["error"] = err
|
||||
return row
|
||||
@@ -6,7 +6,8 @@
|
||||
检测三处,用户想用命名不符合该约定的模型(比如社区微调版、或自己下到本地的模型)就接不上。
|
||||
|
||||
本模块把映射**显式化 + 可配置**(对齐 mlx_whisper_transcriber.MLX_MODEL_MAP 的模式):
|
||||
- 内置:size → Systran/faster-whisper-{size}
|
||||
- 内置:size → faster-whisper 兼容的 CT2 repo_id(多数为 Systran/faster-whisper-{size};
|
||||
turbo 用社区维护版,见 BUILTIN_WHISPER_MODELS)
|
||||
- 自定义:用户在 config/whisper_models.json 登记 {名称: "<repo_id 或本地路径>"}
|
||||
(JSON 持久化;Docker 下随 config 卷持久化)
|
||||
|
||||
@@ -22,7 +23,8 @@ from app.utils.logger import get_logger
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
# 内置模型:size → faster-whisper 兼容的 HF repo_id(CTranslate2 转换版,Systran 官方维护)。
|
||||
# 内置模型:size → faster-whisper 兼容的 HF repo_id(CTranslate2 转换版)。
|
||||
# 多数档位用 Systran 官方维护的转换版;turbo 例外见下。
|
||||
BUILTIN_WHISPER_MODELS: Dict[str, str] = {
|
||||
"tiny": "Systran/faster-whisper-tiny",
|
||||
"base": "Systran/faster-whisper-base",
|
||||
@@ -31,7 +33,10 @@ BUILTIN_WHISPER_MODELS: Dict[str, str] = {
|
||||
"large-v1": "Systran/faster-whisper-large-v1",
|
||||
"large-v2": "Systran/faster-whisper-large-v2",
|
||||
"large-v3": "Systran/faster-whisper-large-v3",
|
||||
"large-v3-turbo": "Systran/faster-whisper-large-v3-turbo",
|
||||
# issue #402:Systran 没有 turbo 的 CT2 转换版(Systran/faster-whisper-large-v3-turbo
|
||||
# 在 HF 上 401/404),点下载会静默失败、状态一直「未下载」。改用社区维护的 CT2 转换版
|
||||
# (deepdml,直链可达、含 model.bin,与 faster-whisper 的 large-v3-turbo 等价)。
|
||||
"large-v3-turbo": "deepdml/faster-whisper-large-v3-turbo-ct2",
|
||||
}
|
||||
|
||||
# 前端下拉默认展示的内置档位(保持与历史 WHISPER_MODEL_SIZES 一致,不把 8 个全列出来)
|
||||
|
||||
113
backend/tests/test_model_download_state.py
Normal file
113
backend/tests/test_model_download_state.py
Normal file
@@ -0,0 +1,113 @@
|
||||
"""Unit tests for app.transcriber.model_download_state(模型下载状态 + 失败原因跟踪)。
|
||||
|
||||
与 test_whisper_models 一样按文件路径隔离加载,并桩掉 app.utils.logger,
|
||||
避免触发 app/__init__.py(会 import faster_whisper 等重依赖)。
|
||||
"""
|
||||
import importlib.util
|
||||
import logging
|
||||
import pathlib
|
||||
import sys
|
||||
import types
|
||||
import unittest
|
||||
|
||||
ROOT = pathlib.Path(__file__).resolve().parents[1]
|
||||
MODULE_PATH = ROOT / "app" / "transcriber" / "model_download_state.py"
|
||||
|
||||
|
||||
def _load_module():
|
||||
if "app" not in sys.modules:
|
||||
app_pkg = types.ModuleType("app")
|
||||
app_pkg.__path__ = []
|
||||
sys.modules["app"] = app_pkg
|
||||
if "app.utils" not in sys.modules:
|
||||
utils_pkg = types.ModuleType("app.utils")
|
||||
utils_pkg.__path__ = []
|
||||
sys.modules["app.utils"] = utils_pkg
|
||||
if "app.utils.logger" not in sys.modules:
|
||||
logger_mod = types.ModuleType("app.utils.logger")
|
||||
logger_mod.get_logger = lambda name=None: logging.getLogger(name or "test")
|
||||
sys.modules["app.utils.logger"] = logger_mod
|
||||
spec = importlib.util.spec_from_file_location("model_download_state_under_test", MODULE_PATH)
|
||||
assert spec and spec.loader
|
||||
mod = importlib.util.module_from_spec(spec)
|
||||
spec.loader.exec_module(mod)
|
||||
return mod
|
||||
|
||||
|
||||
ds = _load_module()
|
||||
|
||||
|
||||
class TestDownloadState(unittest.TestCase):
|
||||
def setUp(self):
|
||||
# 模块级单例,测试间互相隔离
|
||||
ds._status.clear()
|
||||
ds._errors.clear()
|
||||
|
||||
def test_unknown_key_defaults(self):
|
||||
row = ds.status_row("tiny", downloaded=False)
|
||||
self.assertEqual(
|
||||
row,
|
||||
{"model_size": "tiny", "downloaded": False, "downloading": False, "failed": False},
|
||||
)
|
||||
self.assertNotIn("error", row)
|
||||
self.assertFalse(ds.is_downloading("tiny"))
|
||||
|
||||
def test_downloading(self):
|
||||
ds.mark_downloading("tiny")
|
||||
self.assertTrue(ds.is_downloading("tiny"))
|
||||
row = ds.status_row("tiny", downloaded=False)
|
||||
self.assertTrue(row["downloading"])
|
||||
self.assertFalse(row["failed"])
|
||||
|
||||
def test_failed_surfaces_error(self):
|
||||
ds.mark_failed("tiny", "401 Repository Not Found")
|
||||
row = ds.status_row("tiny", downloaded=False)
|
||||
self.assertTrue(row["failed"])
|
||||
self.assertFalse(row["downloading"])
|
||||
self.assertEqual(row["error"], "401 Repository Not Found")
|
||||
self.assertEqual(ds.get_error("tiny"), "401 Repository Not Found")
|
||||
|
||||
def test_failed_without_message_has_no_error_field(self):
|
||||
ds.mark_failed("tiny")
|
||||
row = ds.status_row("tiny", downloaded=False)
|
||||
self.assertTrue(row["failed"])
|
||||
self.assertNotIn("error", row)
|
||||
|
||||
def test_downloaded_overrides_failed(self):
|
||||
# 先失败后又下好:downloaded=True 时不应再回传 failed/error
|
||||
ds.mark_failed("tiny", "boom")
|
||||
row = ds.status_row("tiny", downloaded=True)
|
||||
self.assertFalse(row["failed"])
|
||||
self.assertTrue(row["downloaded"])
|
||||
self.assertNotIn("error", row)
|
||||
|
||||
def test_mark_done_clears_error(self):
|
||||
ds.mark_failed("tiny", "boom")
|
||||
ds.mark_done("tiny")
|
||||
self.assertIsNone(ds.get_error("tiny"))
|
||||
row = ds.status_row("tiny", downloaded=True)
|
||||
self.assertFalse(row["failed"])
|
||||
|
||||
def test_redownload_clears_previous_error(self):
|
||||
ds.mark_failed("tiny", "boom")
|
||||
ds.mark_downloading("tiny") # 重新开始下载
|
||||
self.assertIsNone(ds.get_error("tiny"))
|
||||
row = ds.status_row("tiny", downloaded=False)
|
||||
self.assertTrue(row["downloading"])
|
||||
self.assertFalse(row["failed"])
|
||||
self.assertNotIn("error", row)
|
||||
|
||||
def test_mlx_key_is_independent(self):
|
||||
# mlx 用 "mlx-{size}" 前缀,与 fast-whisper 的同名档位互不影响
|
||||
ds.mark_failed("mlx-tiny", "mlx boom")
|
||||
ds.mark_downloading("tiny")
|
||||
whisper_row = ds.status_row("tiny", downloaded=False)
|
||||
mlx_row = ds.status_row("tiny", downloaded=False, key="mlx-tiny")
|
||||
self.assertTrue(whisper_row["downloading"])
|
||||
self.assertFalse(whisper_row["failed"])
|
||||
self.assertTrue(mlx_row["failed"])
|
||||
self.assertEqual(mlx_row["error"], "mlx boom")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
@@ -50,7 +50,19 @@ class TestResolve(unittest.TestCase):
|
||||
|
||||
def test_builtin_resolves_to_systran(self):
|
||||
self.assertEqual(self.reg.resolve("tiny"), "Systran/faster-whisper-tiny")
|
||||
self.assertEqual(self.reg.resolve("large-v3-turbo"), "Systran/faster-whisper-large-v3-turbo")
|
||||
|
||||
def test_large_v3_turbo_resolves_to_live_repo(self):
|
||||
# 回归 issue #402:Systran 从未发布 turbo 的 CT2 转换版,
|
||||
# 原映射 Systran/faster-whisper-large-v3-turbo 在 HF 上 401/404,
|
||||
# 导致下载静默失败、状态一直「未下载」。改用社区维护的 CT2 转换版。
|
||||
self.assertEqual(
|
||||
self.reg.resolve("large-v3-turbo"),
|
||||
"deepdml/faster-whisper-large-v3-turbo-ct2",
|
||||
)
|
||||
self.assertNotEqual(
|
||||
self.reg.resolve("large-v3-turbo"),
|
||||
"Systran/faster-whisper-large-v3-turbo",
|
||||
)
|
||||
|
||||
def test_passthrough_repo_id(self):
|
||||
# 用户直接把 HF repo_id 当 model_size 传进来(含 "/")
|
||||
|
||||
Reference in New Issue
Block a user