mirror of
https://github.com/JefferyHcool/BiliNote.git
synced 2026-06-11 18:49:59 +08:00
Merge branch 'feat/configurable-whisper-models' into develop
This commit is contained in:
@@ -10,13 +10,16 @@ import {
|
|||||||
SelectValue,
|
SelectValue,
|
||||||
} from '@/components/ui/select'
|
} from '@/components/ui/select'
|
||||||
import { Alert, AlertDescription } from '@/components/ui/alert'
|
import { Alert, AlertDescription } from '@/components/ui/alert'
|
||||||
import { AudioLines, AlertTriangle, CheckCircle2, Download, Loader2, Save, XCircle } from 'lucide-react'
|
import { Input } from '@/components/ui/input'
|
||||||
|
import { AudioLines, AlertTriangle, CheckCircle2, Download, Loader2, Save, XCircle, Plus, Trash2, Boxes } from 'lucide-react'
|
||||||
import { toast } from 'react-hot-toast'
|
import { toast } from 'react-hot-toast'
|
||||||
import {
|
import {
|
||||||
getTranscriberConfig,
|
getTranscriberConfig,
|
||||||
updateTranscriberConfig,
|
updateTranscriberConfig,
|
||||||
getModelsStatus,
|
getModelsStatus,
|
||||||
downloadModel,
|
downloadModel,
|
||||||
|
addWhisperModel,
|
||||||
|
deleteWhisperModel,
|
||||||
TranscriberConfig,
|
TranscriberConfig,
|
||||||
ModelStatus,
|
ModelStatus,
|
||||||
} from '@/services/transcriber'
|
} from '@/services/transcriber'
|
||||||
@@ -33,6 +36,19 @@ export default function Transcriber() {
|
|||||||
const [modelStatuses, setModelStatuses] = useState<ModelStatus[]>([])
|
const [modelStatuses, setModelStatuses] = useState<ModelStatus[]>([])
|
||||||
const [mlxModelStatuses, setMlxModelStatuses] = useState<ModelStatus[]>([])
|
const [mlxModelStatuses, setMlxModelStatuses] = useState<ModelStatus[]>([])
|
||||||
const [mlxAvailable, setMlxAvailable] = useState(false)
|
const [mlxAvailable, setMlxAvailable] = useState(false)
|
||||||
|
// 自定义模型表单
|
||||||
|
const [newModelName, setNewModelName] = useState('')
|
||||||
|
const [newModelTarget, setNewModelTarget] = useState('')
|
||||||
|
const [addingModel, setAddingModel] = useState(false)
|
||||||
|
|
||||||
|
// 重新拉取配置(不重置用户当前的选择),用于增删自定义模型后刷新下拉与列表
|
||||||
|
const reloadConfig = useCallback(async () => {
  // Best-effort refresh of the transcriber config after a custom model is
  // added or removed; only `config` is replaced, no selection state is touched here.
  try {
    const latestConfig = await getTranscriberConfig()
    setConfig(latestConfig)
  } catch {
    // Deliberately silent: a failed refresh leaves the previous config in place.
  }
}, [])
|
||||||
|
|
||||||
const fetchModelsStatus = useCallback(async () => {
|
const fetchModelsStatus = useCallback(async () => {
|
||||||
try {
|
try {
|
||||||
@@ -123,6 +139,41 @@ export default function Transcriber() {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/**
 * Registers a custom whisper model from the two form inputs.
 *
 * Both fields are trimmed and must be non-empty. On success the form is
 * cleared and the config and model statuses are re-fetched, in that order.
 */
const handleAddCustomModel = async () => {
  const trimmedName = newModelName.trim()
  const trimmedTarget = newModelTarget.trim()

  const isIncomplete = trimmedName.length === 0 || trimmedTarget.length === 0
  if (isIncomplete) {
    toast.error('请填写模型名称和 HF repo_id / 本地路径')
    return
  }

  setAddingModel(true)
  try {
    await addWhisperModel({ name: trimmedName, target: trimmedTarget })
    toast.success(`已添加自定义模型 ${trimmedName}`)

    // Reset the form before refreshing what the page displays.
    setNewModelName('')
    setNewModelTarget('')

    await reloadConfig()
    await fetchModelsStatus()
  } catch {
    // Per the original author's note, backend errors (e.g. a duplicate name)
    // are already surfaced by the request interceptor, so no second toast here.
  } finally {
    // Always re-enable the submit button, whatever the outcome.
    setAddingModel(false)
  }
}
|
||||||
|
|
||||||
|
/**
 * Removes a custom whisper model mapping by name, then refreshes the config
 * and the model status list.
 *
 * @param name - registered name of the custom model to delete
 */
const handleDeleteCustomModel = async (name: string) => {
  try {
    await deleteWhisperModel(name)
    toast.success(`已删除自定义模型 ${name}`)

    // If the deleted model was the selected one, fall back to 'tiny' so the
    // selection never points at a name that no longer exists.
    const deletedCurrentSelection = selectedModelSize === name
    if (deletedCurrentSelection) {
      setSelectedModelSize('tiny')
    }

    await reloadConfig()
    await fetchModelsStatus()
  } catch {
    // Per the original author's note, the request interceptor has already
    // shown the error to the user.
  }
}
|
||||||
|
|
||||||
if (loading) {
|
if (loading) {
|
||||||
return (
|
return (
|
||||||
<div className="flex h-64 items-center justify-center">
|
<div className="flex h-64 items-center justify-center">
|
||||||
@@ -272,6 +323,97 @@ export default function Transcriber() {
|
|||||||
</CardContent>
|
</CardContent>
|
||||||
</Card>
|
</Card>
|
||||||
)}
|
)}
|
||||||
|
|
||||||
|
{/* 自定义 Whisper 模型(仅 fast-whisper:名称不符合内置 Systran 约定的模型在此登记映射) */}
|
||||||
|
{selectedType === 'fast-whisper' && (
|
||||||
|
<Card>
|
||||||
|
<CardHeader>
|
||||||
|
<CardTitle className="flex items-center gap-2 text-lg">
|
||||||
|
<Boxes className="h-5 w-5" />
|
||||||
|
自定义模型
|
||||||
|
<span className="text-sm font-normal text-neutral-400">
|
||||||
|
登记名称不符合内置约定的模型
|
||||||
|
</span>
|
||||||
|
</CardTitle>
|
||||||
|
</CardHeader>
|
||||||
|
<CardContent className="space-y-4">
|
||||||
|
<Alert className="text-sm">
|
||||||
|
<AlertDescription>
|
||||||
|
填 <strong>HF repo_id</strong>(如{' '}
|
||||||
|
<code className="rounded bg-neutral-100 px-1">Systran/faster-whisper-large-v3</code>
|
||||||
|
,会自动下载)或<strong>本地模型目录</strong>(如{' '}
|
||||||
|
<code className="rounded bg-neutral-100 px-1">/app/backend/models/my-whisper</code>
|
||||||
|
,目录内需含 <code className="rounded bg-neutral-100 px-1">model.bin</code>,下载会跳过)。
|
||||||
|
添加后即可在上方「模型大小」下拉中选用。Docker 部署请把模型目录挂载进容器(见 README 的{' '}
|
||||||
|
<code className="rounded bg-neutral-100 px-1">models</code> 卷)。
|
||||||
|
</AlertDescription>
|
||||||
|
</Alert>
|
||||||
|
|
||||||
|
{config.whisper_custom_models &&
|
||||||
|
Object.keys(config.whisper_custom_models).length > 0 ? (
|
||||||
|
<div className="space-y-2">
|
||||||
|
{Object.entries(config.whisper_custom_models).map(([name, target]) => {
|
||||||
|
const status = modelStatuses.find(m => m.model_size === name)
|
||||||
|
return (
|
||||||
|
<div
|
||||||
|
key={name}
|
||||||
|
className="flex items-center justify-between gap-3 rounded-md border px-4 py-2.5"
|
||||||
|
>
|
||||||
|
<div className="min-w-0">
|
||||||
|
<div className="flex items-center gap-2 font-medium">
|
||||||
|
{name}
|
||||||
|
{status?.downloaded && (
|
||||||
|
<CheckCircle2 className="h-3.5 w-3.5 text-green-500" />
|
||||||
|
)}
|
||||||
|
{status?.downloading && (
|
||||||
|
<Loader2 className="h-3.5 w-3.5 animate-spin text-neutral-400" />
|
||||||
|
)}
|
||||||
|
</div>
|
||||||
|
<div className="truncate text-xs text-neutral-400" title={target}>
|
||||||
|
{target}
|
||||||
|
</div>
|
||||||
|
</div>
|
||||||
|
<Button
|
||||||
|
size="sm"
|
||||||
|
variant="ghost"
|
||||||
|
className="text-red-500 hover:text-red-600"
|
||||||
|
onClick={() => handleDeleteCustomModel(name)}
|
||||||
|
>
|
||||||
|
<Trash2 className="h-4 w-4" />
|
||||||
|
</Button>
|
||||||
|
</div>
|
||||||
|
)
|
||||||
|
})}
|
||||||
|
</div>
|
||||||
|
) : (
|
||||||
|
<p className="text-sm text-neutral-400">还没有自定义模型</p>
|
||||||
|
)}
|
||||||
|
|
||||||
|
<div className="flex flex-col gap-2 sm:flex-row sm:items-center">
|
||||||
|
<Input
|
||||||
|
placeholder="模型名称(自定义,如 my-large-v3)"
|
||||||
|
value={newModelName}
|
||||||
|
onChange={e => setNewModelName(e.target.value)}
|
||||||
|
className="sm:max-w-[220px]"
|
||||||
|
/>
|
||||||
|
<Input
|
||||||
|
placeholder="HF repo_id 或本地路径"
|
||||||
|
value={newModelTarget}
|
||||||
|
onChange={e => setNewModelTarget(e.target.value)}
|
||||||
|
className="flex-1"
|
||||||
|
/>
|
||||||
|
<Button onClick={handleAddCustomModel} disabled={addingModel}>
|
||||||
|
{addingModel ? (
|
||||||
|
<Loader2 className="mr-1 h-4 w-4 animate-spin" />
|
||||||
|
) : (
|
||||||
|
<Plus className="mr-1 h-4 w-4" />
|
||||||
|
)}
|
||||||
|
添加
|
||||||
|
</Button>
|
||||||
|
</div>
|
||||||
|
</CardContent>
|
||||||
|
</Card>
|
||||||
|
)}
|
||||||
</div>
|
</div>
|
||||||
)
|
)
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -5,6 +5,10 @@ export interface TranscriberConfig {
|
|||||||
whisper_model_size: string
|
whisper_model_size: string
|
||||||
available_types: { value: string; label: string }[]
|
available_types: { value: string; label: string }[]
|
||||||
whisper_model_sizes: string[]
|
whisper_model_sizes: string[]
|
||||||
|
/** 内置模型映射:size → HF repo_id */
|
||||||
|
whisper_builtin_models?: Record<string, string>
|
||||||
|
/** 用户自定义模型映射:名称 → HF repo_id 或本地路径 */
|
||||||
|
whisper_custom_models?: Record<string, string>
|
||||||
mlx_whisper_available: boolean
|
mlx_whisper_available: boolean
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -41,3 +45,23 @@ export const downloadModel = async (data: {
|
|||||||
}) => {
|
}) => {
|
||||||
return await request.post('/transcriber_download', data)
|
return await request.post('/transcriber_download', data)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/** Payload returned by `GET /whisper_models`. */
export interface WhisperModelsResponse {
  /** Built-in mapping: model size name to HF repo_id. */
  builtin: Record<string, string>
  /** User-defined mapping: model name to HF repo_id or local path. */
  custom: Record<string, string>
}
|
||||||
|
|
||||||
|
/** Lists the built-in and custom whisper model mappings. */
export const listWhisperModels = async (): Promise<WhisperModelsResponse> => {
  const models = await request.get('/whisper_models')
  return models
}
|
||||||
|
|
||||||
|
/**
 * Registers a custom model mapping.
 *
 * @param data - `name` is the label to register; `target` is an HF repo_id or a local path
 */
export const addWhisperModel = async (data: { name: string; target: string }) => {
  const response = await request.post('/whisper_models', data)
  return response
}
|
||||||
|
|
||||||
|
/**
 * Deletes a custom model mapping. Per the original note, this does not delete
 * any model files that were already downloaded.
 *
 * @param name - registered model name; URI-encoded before it is placed in the path
 */
export const deleteWhisperModel = async (name: string) => {
  const endpoint = `/whisper_models/${encodeURIComponent(name)}`
  return await request.delete(endpoint)
}
|
||||||
|
|||||||
@@ -61,16 +61,53 @@ WHISPER_MODEL_SIZES = ["tiny", "base", "small", "medium", "large-v3", "large-v3-
|
|||||||
@router.get("/transcriber_config")
|
@router.get("/transcriber_config")
|
||||||
def get_transcriber_config():
|
def get_transcriber_config():
|
||||||
from app.transcriber.transcriber_provider import MLX_WHISPER_AVAILABLE
|
from app.transcriber.transcriber_provider import MLX_WHISPER_AVAILABLE
|
||||||
|
from app.transcriber.whisper_models import get_registry, BUILTIN_WHISPER_MODELS
|
||||||
|
|
||||||
|
registry = get_registry()
|
||||||
config = transcriber_config_manager.get_config()
|
config = transcriber_config_manager.get_config()
|
||||||
return R.success(data={
|
return R.success(data={
|
||||||
**config,
|
**config,
|
||||||
"available_types": AVAILABLE_TRANSCRIBER_TYPES,
|
"available_types": AVAILABLE_TRANSCRIBER_TYPES,
|
||||||
"whisper_model_sizes": WHISPER_MODEL_SIZES,
|
# 内置可见档位 + 用户自定义模型,供前端下拉
|
||||||
|
"whisper_model_sizes": registry.visible_model_names(),
|
||||||
|
"whisper_builtin_models": BUILTIN_WHISPER_MODELS,
|
||||||
|
"whisper_custom_models": registry.get_custom_models(),
|
||||||
"mlx_whisper_available": MLX_WHISPER_AVAILABLE,
|
"mlx_whisper_available": MLX_WHISPER_AVAILABLE,
|
||||||
})
|
})
|
||||||
|
|
||||||
|
|
||||||
|
class WhisperCustomModelRequest(BaseModel):
    """Request body for registering a custom whisper model mapping."""

    # Name under which the model is registered.
    name: str
    # HF repo_id (e.g. Systran/faster-whisper-large-v3) or a local model directory path.
    target: str
|
||||||
|
|
||||||
|
|
||||||
|
@router.get("/whisper_models")
def list_whisper_models():
    """Return the built-in and user-defined whisper model mappings."""
    from app.transcriber.whisper_models import BUILTIN_WHISPER_MODELS, get_registry

    payload = {
        "builtin": BUILTIN_WHISPER_MODELS,
        "custom": get_registry().get_custom_models(),
    }
    return R.success(data=payload)
|
||||||
|
|
||||||
|
|
||||||
|
@router.post("/whisper_models")
def add_whisper_model(data: WhisperCustomModelRequest):
    """Register a custom whisper model mapping (name -> HF repo_id or local path).

    A ValueError raised by the registry (empty field, clash with a built-in
    name) is returned as an error response carrying the exception message.
    """
    from app.transcriber.whisper_models import get_registry

    try:
        updated_custom = get_registry().add_custom_model(data.name, data.target)
    except ValueError as exc:
        return R.error(msg=str(exc))
    else:
        return R.success(data={"custom": updated_custom}, msg="已添加自定义模型")
|
||||||
|
|
||||||
|
|
||||||
|
@router.delete("/whisper_models/{name}")
def delete_whisper_model(name: str):
    """Remove a custom whisper model mapping and return the remaining ones.

    Only the registry entry is dropped; downloaded model files are not touched.
    """
    from app.transcriber.whisper_models import get_registry

    remaining_custom = get_registry().remove_custom_model(name)
    return R.success(data={"custom": remaining_custom}, msg="已删除自定义模型")
|
||||||
|
|
||||||
|
|
||||||
@router.post("/transcriber_config")
|
@router.post("/transcriber_config")
|
||||||
def update_transcriber_config(data: TranscriberConfigRequest):
|
def update_transcriber_config(data: TranscriberConfigRequest):
|
||||||
config = transcriber_config_manager.update_config(
|
config = transcriber_config_manager.update_config(
|
||||||
@@ -119,14 +156,27 @@ _downloading: dict[str, str] = {} # model_size -> status ("downloading" | "done
|
|||||||
def _check_whisper_model_exists(model_size: str, subdir: str = "whisper") -> bool:
|
def _check_whisper_model_exists(model_size: str, subdir: str = "whisper") -> bool:
|
||||||
"""检查指定 whisper 模型是否已下载完整到本地。
|
"""检查指定 whisper 模型是否已下载完整到本地。
|
||||||
|
|
||||||
faster-whisper 把模型缓存在 HF cache 布局下:
|
先把模型名 resolve 成可加载标识,再按类型判定:
|
||||||
<model_dir>/models--Systran--faster-whisper-{size}/snapshots/<hash>/model.bin
|
- 本地路径模型 → 直接看该目录下有没有 model.bin
|
||||||
必须能在某个 snapshot 目录里找到 model.bin 才算完成。
|
- HF repo_id → 看 HF cache 布局
|
||||||
(历史 modelscope 布局 <model_dir>/whisper-{size}/model.bin 也兼容识别。)
|
<model_dir>/models--{org}--{name}/snapshots/<hash>/model.bin
|
||||||
|
(历史 modelscope 布局 <model_dir>/whisper-{size}/model.bin 也兼容识别)
|
||||||
"""
|
"""
|
||||||
|
from app.transcriber.whisper_models import (
|
||||||
|
resolve_whisper_model,
|
||||||
|
is_local_target,
|
||||||
|
hf_cache_dirname,
|
||||||
|
)
|
||||||
|
try:
|
||||||
|
target = resolve_whisper_model(model_size)
|
||||||
|
except Exception:
|
||||||
|
return False
|
||||||
|
if is_local_target(target):
|
||||||
|
return (Path(target) / "model.bin").exists()
|
||||||
|
|
||||||
model_dir = Path(get_model_dir(subdir))
|
model_dir = Path(get_model_dir(subdir))
|
||||||
# HF cache 布局
|
# HF cache 布局(适配任意 org/repo,不再写死 Systran)
|
||||||
hf_repo_dir = model_dir / f"models--Systran--faster-whisper-{model_size}" / "snapshots"
|
hf_repo_dir = model_dir / hf_cache_dirname(target) / "snapshots"
|
||||||
if hf_repo_dir.exists():
|
if hf_repo_dir.exists():
|
||||||
for snapshot in hf_repo_dir.iterdir():
|
for snapshot in hf_repo_dir.iterdir():
|
||||||
if (snapshot / "model.bin").exists():
|
if (snapshot / "model.bin").exists():
|
||||||
@@ -157,9 +207,10 @@ def _check_mlx_whisper_model_exists(model_size: str) -> bool:
|
|||||||
|
|
||||||
@router.get("/transcriber_models_status")
|
@router.get("/transcriber_models_status")
|
||||||
def get_transcriber_models_status():
|
def get_transcriber_models_status():
|
||||||
"""返回所有 whisper 模型的下载状态。"""
|
"""返回所有 whisper 模型的下载状态(含用户自定义模型)。"""
|
||||||
|
from app.transcriber.whisper_models import get_registry
|
||||||
statuses = []
|
statuses = []
|
||||||
for size in WHISPER_MODEL_SIZES:
|
for size in get_registry().visible_model_names():
|
||||||
downloaded = _check_whisper_model_exists(size, "whisper")
|
downloaded = _check_whisper_model_exists(size, "whisper")
|
||||||
download_status = _downloading.get(size)
|
download_status = _downloading.get(size)
|
||||||
statuses.append({
|
statuses.append({
|
||||||
@@ -198,13 +249,15 @@ class ModelDownloadRequest(BaseModel):
|
|||||||
|
|
||||||
|
|
||||||
def _do_download_whisper(model_size: str):
|
def _do_download_whisper(model_size: str):
|
||||||
"""后台下载 faster-whisper 模型。
|
"""后台下载 faster-whisper 模型(支持内置 size / 自定义 repo_id / 本地路径)。
|
||||||
|
|
||||||
直接走 huggingface_hub.snapshot_download,把模型放到 HF cache 布局里——
|
模型名先 resolve:
|
||||||
这样 faster-whisper 加载时(WhisperModel(model_size_or_path=size_name,
|
- 本地路径模型:无需下载,目录里有 model.bin 即 done,否则 failed;
|
||||||
download_root=model_dir))能直接命中缓存,跟加载路径完全对齐。
|
- HF repo_id:snapshot_download 到 HF cache 布局(cache_dir=model_dir),
|
||||||
|
与加载逻辑 WhisperModel(download_root=model_dir) 完全对齐。
|
||||||
"""
|
"""
|
||||||
from huggingface_hub import snapshot_download
|
from huggingface_hub import snapshot_download
|
||||||
|
from app.transcriber.whisper_models import resolve_whisper_model, is_local_target
|
||||||
|
|
||||||
try:
|
try:
|
||||||
_downloading[model_size] = "downloading"
|
_downloading[model_size] = "downloading"
|
||||||
@@ -214,12 +267,21 @@ def _do_download_whisper(model_size: str):
|
|||||||
if _check_whisper_model_exists(model_size, "whisper"):
|
if _check_whisper_model_exists(model_size, "whisper"):
|
||||||
_downloading[model_size] = "done"
|
_downloading[model_size] = "done"
|
||||||
return
|
return
|
||||||
repo_id = f"Systran/faster-whisper-{model_size}"
|
|
||||||
logger.info(f"开始下载 whisper 模型: {repo_id}")
|
target = resolve_whisper_model(model_size)
|
||||||
|
if is_local_target(target):
|
||||||
|
# 本地模型不下载,只校验 model.bin 是否就位
|
||||||
|
ok = (Path(target) / "model.bin").exists()
|
||||||
|
_downloading[model_size] = "done" if ok else "failed"
|
||||||
|
if not ok:
|
||||||
|
logger.warning(f"本地模型 {model_size} 路径 {target} 下没有 model.bin,无法使用")
|
||||||
|
return
|
||||||
|
|
||||||
|
logger.info(f"开始下载 whisper 模型: {model_size} ← {target}")
|
||||||
# 跟 faster-whisper utils.py 用同样的 allow_patterns,避免多下无关文件;
|
# 跟 faster-whisper utils.py 用同样的 allow_patterns,避免多下无关文件;
|
||||||
# 不传 local_dir 让它走 HF 默认 cache 布局(与加载逻辑对齐)
|
# 不传 local_dir 让它走 HF 默认 cache 布局(与加载逻辑对齐)
|
||||||
snapshot_download(
|
snapshot_download(
|
||||||
repo_id,
|
target,
|
||||||
cache_dir=model_dir,
|
cache_dir=model_dir,
|
||||||
allow_patterns=[
|
allow_patterns=[
|
||||||
"config.json",
|
"config.json",
|
||||||
@@ -268,11 +330,11 @@ def _do_download_mlx_whisper(model_size: str):
|
|||||||
|
|
||||||
@router.post("/transcriber_download")
|
@router.post("/transcriber_download")
|
||||||
def download_transcriber_model(data: ModelDownloadRequest, background_tasks: BackgroundTasks):
|
def download_transcriber_model(data: ModelDownloadRequest, background_tasks: BackgroundTasks):
|
||||||
"""触发后台下载指定的 whisper 模型。"""
|
"""触发后台下载指定的 whisper 模型(fast-whisper 支持内置档位 + 自定义模型)。"""
|
||||||
if data.model_size not in WHISPER_MODEL_SIZES:
|
|
||||||
return R.error(msg=f"不支持的模型大小: {data.model_size}")
|
|
||||||
|
|
||||||
if data.transcriber_type == "mlx-whisper":
|
if data.transcriber_type == "mlx-whisper":
|
||||||
|
# mlx 只认内置档位(mlx-community 的固定映射)
|
||||||
|
if data.model_size not in WHISPER_MODEL_SIZES:
|
||||||
|
return R.error(msg=f"MLX 不支持的模型大小: {data.model_size}")
|
||||||
if platform.system() != "Darwin":
|
if platform.system() != "Darwin":
|
||||||
return R.error(msg="MLX Whisper 仅支持 macOS")
|
return R.error(msg="MLX Whisper 仅支持 macOS")
|
||||||
key = f"mlx-{data.model_size}"
|
key = f"mlx-{data.model_size}"
|
||||||
@@ -280,6 +342,10 @@ def download_transcriber_model(data: ModelDownloadRequest, background_tasks: Bac
|
|||||||
return R.success(msg="模型正在下载中")
|
return R.success(msg="模型正在下载中")
|
||||||
background_tasks.add_task(_do_download_mlx_whisper, data.model_size)
|
background_tasks.add_task(_do_download_mlx_whisper, data.model_size)
|
||||||
else:
|
else:
|
||||||
|
# fast-whisper:内置档位 / 自定义 repo_id / 本地路径都允许
|
||||||
|
from app.transcriber.whisper_models import get_registry
|
||||||
|
if not get_registry().is_known(data.model_size):
|
||||||
|
return R.error(msg=f"不支持的模型: {data.model_size}(请先在自定义模型中登记)")
|
||||||
if _downloading.get(data.model_size) == "downloading":
|
if _downloading.get(data.model_size) == "downloading":
|
||||||
return R.success(msg="模型正在下载中")
|
return R.success(msg="模型正在下载中")
|
||||||
background_tasks.add_task(_do_download_whisper, data.model_size)
|
background_tasks.add_task(_do_download_whisper, data.model_size)
|
||||||
|
|||||||
@@ -3,6 +3,11 @@ from faster_whisper import WhisperModel
|
|||||||
from app.decorators.timeit import timeit
|
from app.decorators.timeit import timeit
|
||||||
from app.models.transcriber_model import TranscriptSegment, TranscriptResult
|
from app.models.transcriber_model import TranscriptSegment, TranscriptResult
|
||||||
from app.transcriber.base import Transcriber
|
from app.transcriber.base import Transcriber
|
||||||
|
from app.transcriber.whisper_models import (
|
||||||
|
resolve_whisper_model,
|
||||||
|
is_local_target,
|
||||||
|
hf_cache_dirname,
|
||||||
|
)
|
||||||
from app.utils.env_checker import is_cuda_available, is_torch_installed
|
from app.utils.env_checker import is_cuda_available, is_torch_installed
|
||||||
from app.utils.logger import get_logger
|
from app.utils.logger import get_logger
|
||||||
from app.utils.path_helper import get_model_dir
|
from app.utils.path_helper import get_model_dir
|
||||||
@@ -55,8 +60,12 @@ class WhisperTranscriber(Transcriber):
|
|||||||
self.model = self._build_model(model_size, model_dir)
|
self.model = self._build_model(model_size, model_dir)
|
||||||
|
|
||||||
def _build_model(self, model_size: str, model_dir: str) -> WhisperModel:
|
def _build_model(self, model_size: str, model_dir: str) -> WhisperModel:
|
||||||
|
# resolve 把模型名映射成可加载标识:内置 size→Systran repo_id、自定义映射、
|
||||||
|
# 直通的 repo_id 或本地路径。faster-whisper 对本地目录走 os.path.isdir 分支,
|
||||||
|
# 对 repo_id 走 download_model(cache_dir=download_root),两者都吃 model_size_or_path。
|
||||||
|
target = resolve_whisper_model(model_size)
|
||||||
return WhisperModel(
|
return WhisperModel(
|
||||||
model_size_or_path=model_size, # 传 size name,让 faster-whisper 自己映射到 Systran/faster-whisper-*
|
model_size_or_path=target,
|
||||||
device=self.device,
|
device=self.device,
|
||||||
compute_type=self.compute_type,
|
compute_type=self.compute_type,
|
||||||
download_root=model_dir,
|
download_root=model_dir,
|
||||||
@@ -64,13 +73,22 @@ class WhisperTranscriber(Transcriber):
|
|||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def _purge_cache(model_dir: str, model_size: str) -> None:
|
def _purge_cache(model_dir: str, model_size: str) -> None:
|
||||||
"""删掉 HF cache 里这个 size 对应的 snapshot 目录,强制下次重新下载。
|
"""加载失败时清掉对应 HF cache 的 snapshot 目录,强制下次重下。
|
||||||
|
|
||||||
HF cache 布局:<model_dir>/models--Systran--faster-whisper-{size}/
|
关键:本地路径模型**绝不删**——那是用户自己的文件,删了就是数据丢失;
|
||||||
没找到也不报错——可能用户改了 endpoint 或者 cache 布局变了。
|
只清 HF cache 布局 <model_dir>/models--{org}--{name}/(含历史 modelscope 目录)。
|
||||||
"""
|
"""
|
||||||
|
try:
|
||||||
|
target = resolve_whisper_model(model_size)
|
||||||
|
except Exception:
|
||||||
|
target = model_size
|
||||||
|
if is_local_target(target):
|
||||||
|
logger.warning(
|
||||||
|
f"模型 {model_size} 指向本地路径 {target},加载失败不清理用户文件,请检查该目录是否完整"
|
||||||
|
)
|
||||||
|
return
|
||||||
candidates = [
|
candidates = [
|
||||||
Path(model_dir) / f"models--Systran--faster-whisper-{model_size}",
|
Path(model_dir) / hf_cache_dirname(target), # HF cache: models--org--name
|
||||||
Path(model_dir) / f"whisper-{model_size}", # 历史 modelscope 目录,顺手清掉
|
Path(model_dir) / f"whisper-{model_size}", # 历史 modelscope 目录,顺手清掉
|
||||||
]
|
]
|
||||||
for path in candidates:
|
for path in candidates:
|
||||||
|
|||||||
156
backend/app/transcriber/whisper_models.py
Normal file
156
backend/app/transcriber/whisper_models.py
Normal file
@@ -0,0 +1,156 @@
|
|||||||
|
"""fast-whisper 模型名 → 可加载标识(HF repo_id 或本地路径)的映射注册表。
|
||||||
|
|
||||||
|
背景:faster-whisper 加载时 `WhisperModel(model_size_or_path=...)` 接受三种入参——
|
||||||
|
内置 size 名、HF repo_id(含 "/")、或本地模型目录(`os.path.isdir` 命中则直接用)。
|
||||||
|
此前后端把「size → Systran/faster-whisper-{size}」这层约定**隐式**散落在加载/下载/
|
||||||
|
检测三处,用户想用命名不符合该约定的模型(比如社区微调版、或自己下到本地的模型)就接不上。
|
||||||
|
|
||||||
|
本模块把映射**显式化 + 可配置**(对齐 mlx_whisper_transcriber.MLX_MODEL_MAP 的模式):
|
||||||
|
- 内置:size → Systran/faster-whisper-{size}
|
||||||
|
- 自定义:用户在 config/whisper_models.json 登记 {名称: "<repo_id 或本地路径>"}
|
||||||
|
(JSON 持久化;Docker 下随 config 卷持久化)
|
||||||
|
|
||||||
|
解析优先级(resolve):自定义 > 内置 > 直通(含 "/" 当 repo_id;已存在目录当本地路径)。
|
||||||
|
加载 / 下载 / 完整性检测三处统一调用 resolve,路径不再各写各的。
|
||||||
|
"""
|
||||||
|
import json
|
||||||
|
import os
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import Dict, List
|
||||||
|
|
||||||
|
from app.utils.logger import get_logger
|
||||||
|
|
||||||
|
logger = get_logger(__name__)
|
||||||
|
|
||||||
|
# Built-in models: size name -> faster-whisper compatible HF repo_id, all following
# the "Systran/faster-whisper-{size}" naming convention.
# NOTE(review): faster-whisper's own alias table appears to map "large-v3-turbo" to a
# repo under a different organisation; confirm "Systran/faster-whisper-large-v3-turbo"
# actually exists before relying on this entry.
BUILTIN_WHISPER_MODELS: Dict[str, str] = {
    size: f"Systran/faster-whisper-{size}"
    for size in (
        "tiny",
        "base",
        "small",
        "medium",
        "large-v1",
        "large-v2",
        "large-v3",
        "large-v3-turbo",
    )
}

# Built-in sizes shown by default in the frontend dropdown (a subset of the
# mapping above, not all eight entries).
DEFAULT_VISIBLE_BUILTINS: List[str] = [
    "tiny",
    "base",
    "small",
    "medium",
    "large-v3",
    "large-v3-turbo",
]
|
||||||
|
|
||||||
|
|
||||||
|
def is_local_target(target: str) -> bool:
    """Tell whether a resolved target is a local path rather than an HF repo_id.

    A target counts as local when it is an absolute path, starts with "." or
    "~", or names a directory that currently exists. Anything else (including
    an empty target) is treated as not local.
    """
    if not target:
        return False
    # Syntactic check first, so a path that does not exist yet is still local.
    looks_like_path = os.path.isabs(target) or target.startswith((".", "~"))
    return looks_like_path or os.path.isdir(target)
|
||||||
|
|
||||||
|
|
||||||
|
def hf_cache_dirname(repo_id: str) -> str:
    """Build the cache directory name for a repo_id: Org/Name -> models--Org--Name."""
    segments = repo_id.split("/")
    return "--".join(["models", *segments])
|
||||||
|
|
||||||
|
|
||||||
|
class WhisperModelRegistry:
    """Built-in plus user-defined whisper model mappings.

    The user-defined part is persisted as a JSON object in ``filepath`` and is
    re-read from disk on every query, so separate instances pointing at the
    same file see each other's changes.
    """

    def __init__(self, filepath: str = "config/whisper_models.json"):
        """Remember the config path and make sure its parent directory exists."""
        self.path = Path(filepath)
        self.path.parent.mkdir(parents=True, exist_ok=True)

    # ---- persistence ----
    def _read_custom(self) -> Dict[str, str]:
        """Load the custom mapping from disk, tolerating a missing or bad file.

        Returns:
            ``{name: target}`` with targets stripped. Entries may be stored as
            a plain string or as ``{"target": "..."}``; entries whose target
            is missing, not a string, or blank are skipped. An unreadable
            file, invalid JSON, or a top-level value that is not a JSON object
            yields an empty mapping plus a warning.
        """
        if not self.path.exists():
            return {}
        try:
            with self.path.open("r", encoding="utf-8") as f:
                data = json.load(f) or {}
        except Exception as e:
            logger.warning(f"读取自定义 whisper 模型配置失败,按空处理: {e}")
            return {}
        # Valid JSON need not be an object; a list or string has no .items()
        # and would otherwise crash every caller outside the try above.
        if not isinstance(data, dict):
            logger.warning(
                f"自定义 whisper 模型配置不是 JSON 对象(实际为 {type(data).__name__}),按空处理"
            )
            return {}
        out: Dict[str, str] = {}
        for name, val in data.items():
            # Accept both "name": "target" and "name": {"target": "target"}.
            raw = val.get("target") if isinstance(val, dict) else val
            if isinstance(raw, str) and raw.strip():
                out[name] = raw.strip()
        return out

    def _write_custom(self, data: Dict[str, str]) -> None:
        """Overwrite the config file with ``data`` as indented UTF-8 JSON."""
        with self.path.open("w", encoding="utf-8") as f:
            json.dump(data, f, ensure_ascii=False, indent=2)

    # ---- queries ----
    def get_custom_models(self) -> Dict[str, str]:
        """Return the current custom mapping as read from disk."""
        return self._read_custom()

    def visible_model_names(self) -> List[str]:
        """Return the default visible built-in sizes followed by every custom name.

        Custom names already present in the built-in list are not repeated.
        """
        names = list(DEFAULT_VISIBLE_BUILTINS)
        for name in self._read_custom():
            if name not in names:
                names.append(name)
        return names

    def is_known(self, name: str) -> bool:
        """Return True when ``resolve(name)`` succeeds, False when it raises ValueError."""
        try:
            self.resolve(name)
            return True
        except ValueError:
            return False

    def resolve(self, name: str) -> str:
        """Map a model name to a loadable identifier (HF repo_id or local path).

        Lookup order: custom mapping, then built-in mapping, then pass-through
        for names that contain "/" or name an existing directory.

        Raises:
            ValueError: the name matches none of the above.
        """
        name = (name or "").strip()
        custom = self._read_custom()
        if name in custom:
            return custom[name]
        if name in BUILTIN_WHISPER_MODELS:
            return BUILTIN_WHISPER_MODELS[name]
        # Pass-through: the caller supplied a repo_id or an existing directory directly.
        if "/" in name or os.path.isdir(name):
            return name
        raise ValueError(
            f"未知 whisper 模型 '{name}'。内置可选: {', '.join(BUILTIN_WHISPER_MODELS)};"
            "或在「音频转写配置」添加自定义模型(HF repo_id 或本地路径)。"
        )

    # ---- add / remove ----
    def add_custom_model(self, name: str, target: str) -> Dict[str, str]:
        """Register (or overwrite) a custom mapping and persist it.

        Returns:
            The full custom mapping after the change.

        Raises:
            ValueError: name or target is blank, or the name collides with a
                built-in model name.
        """
        name = (name or "").strip()
        target = (target or "").strip()
        if not name or not target:
            raise ValueError("模型名称与目标(HF repo_id 或本地路径)都不能为空")
        if name in BUILTIN_WHISPER_MODELS:
            raise ValueError(f"'{name}' 与内置模型重名,请换一个名称")
        # NOTE(review): names containing "/" are accepted here, but the
        # DELETE /whisper_models/{name} route takes the name as a single path
        # segment - confirm such names can still be removed through the API.
        data = self._read_custom()
        data[name] = target
        self._write_custom(data)
        return data

    def remove_custom_model(self, name: str) -> Dict[str, str]:
        """Drop a custom mapping if present and persist the result.

        Removing an unknown name is not an error. Returns the remaining mapping.
        """
        data = self._read_custom()
        data.pop((name or "").strip(), None)
        self._write_custom(data)
        return data
|
||||||
|
|
||||||
|
|
||||||
|
# Module-level singleton, created at import time with the default config path.
_registry = WhisperModelRegistry()


def get_registry() -> WhisperModelRegistry:
    """Return the module-level registry instance."""
    return _registry


def resolve_whisper_model(name: str) -> str:
    """Resolve ``name`` through the module-level registry (see ``WhisperModelRegistry.resolve``)."""
    registry = get_registry()
    return registry.resolve(name)
|
||||||
132
backend/tests/test_whisper_models.py
Normal file
132
backend/tests/test_whisper_models.py
Normal file
@@ -0,0 +1,132 @@
|
|||||||
|
"""Unit tests for app.transcriber.whisper_models(whisper 模型名→标识 的映射注册表)。
|
||||||
|
|
||||||
|
直接按文件路径加载被测模块,并桩掉 app.utils.logger,避免触发 app/__init__.py
|
||||||
|
(会 import faster_whisper / ctranslate2 等重依赖),使本测试无需安装转写依赖即可运行。
|
||||||
|
"""
|
||||||
|
import importlib.util
|
||||||
|
import logging
|
||||||
|
import os
|
||||||
|
import pathlib
|
||||||
|
import sys
|
||||||
|
import tempfile
|
||||||
|
import types
|
||||||
|
import unittest
|
||||||
|
|
||||||
|
ROOT = pathlib.Path(__file__).resolve().parents[1]
|
||||||
|
MODULE_PATH = ROOT / "app" / "transcriber" / "whisper_models.py"
|
||||||
|
|
||||||
|
|
||||||
|
def _load_module():
    """Load whisper_models.py straight from its file path, with stubbed dependencies.

    Empty stand-in packages are registered for ``app`` and ``app.utils``, and
    ``app.utils.logger`` gets a ``get_logger`` backed by the stdlib ``logging``
    module, so the module under test imports without the real ``app`` package.
    Entries already present in ``sys.modules`` are left untouched.
    """
    for package_name in ("app", "app.utils"):
        if package_name not in sys.modules:
            stub_pkg = types.ModuleType(package_name)
            stub_pkg.__path__ = []  # an empty __path__ marks the stub as a package
            sys.modules[package_name] = stub_pkg

    if "app.utils.logger" not in sys.modules:
        stub_logger = types.ModuleType("app.utils.logger")
        stub_logger.get_logger = lambda name=None: logging.getLogger(name or "test")
        sys.modules["app.utils.logger"] = stub_logger

    spec = importlib.util.spec_from_file_location("whisper_models_under_test", MODULE_PATH)
    assert spec and spec.loader
    loaded = importlib.util.module_from_spec(spec)
    spec.loader.exec_module(loaded)
    return loaded
|
||||||
|
|
||||||
|
|
||||||
|
wm = _load_module()
|
||||||
|
|
||||||
|
|
||||||
|
class TestResolve(unittest.TestCase):
    """Registry behaviour: resolution order, persistence, add/remove, visibility."""

    def setUp(self):
        # Each test gets its own registry backed by a throwaway config file.
        self.tmp = tempfile.TemporaryDirectory()
        self.addCleanup(self.tmp.cleanup)
        self.cfg = os.path.join(self.tmp.name, "whisper_models.json")
        self.reg = wm.WhisperModelRegistry(filepath=self.cfg)

    def _make_dir(self, leaf):
        """Create and return a fresh directory inside the temp dir."""
        created = os.path.join(self.tmp.name, leaf)
        os.makedirs(created)
        return created

    def test_builtin_resolves_to_systran(self):
        expected = {
            "tiny": "Systran/faster-whisper-tiny",
            "large-v3-turbo": "Systran/faster-whisper-large-v3-turbo",
        }
        for size, repo_id in expected.items():
            self.assertEqual(self.reg.resolve(size), repo_id)

    def test_passthrough_repo_id(self):
        # A name containing "/" is passed through as an HF repo_id.
        repo_id = "SomeOrg/my-whisper-ct2"
        self.assertEqual(self.reg.resolve(repo_id), repo_id)

    def test_unknown_raises(self):
        self.assertRaises(ValueError, self.reg.resolve, "definitely-not-a-model")

    def test_custom_overrides_and_persists(self):
        self.reg.add_custom_model("myhf", "someorg/whisper-ct2")
        self.assertEqual(self.reg.resolve("myhf"), "someorg/whisper-ct2")
        # A second instance on the same file sees the entry, i.e. it was persisted.
        fresh = wm.WhisperModelRegistry(filepath=self.cfg)
        self.assertEqual(fresh.resolve("myhf"), "someorg/whisper-ct2")

    def test_custom_can_override_builtin_key_resolution(self):
        # resolve() prefers custom over built-in; add_custom_model() forbids the
        # clash, so the entry is written directly to the file.
        self.reg._write_custom({"tiny": "Other/tiny-ct2"})
        self.assertEqual(self.reg.resolve("tiny"), "Other/tiny-ct2")

    def test_local_path_resolution_and_detection(self):
        local_dir = self._make_dir("mymodel")
        self.reg.add_custom_model("local1", local_dir)
        resolved = self.reg.resolve("local1")
        self.assertEqual(resolved, local_dir)
        self.assertTrue(wm.is_local_target(resolved))

    def test_bare_existing_dir_passthrough(self):
        # An unregistered but existing directory is passed through as a local path.
        bare_dir = self._make_dir("bare")
        self.assertEqual(self.reg.resolve(bare_dir), bare_dir)

    def test_add_rejects_builtin_collision_and_empty(self):
        rejected = [
            ("tiny", "x/y"),  # clashes with a built-in name
            ("", "x/y"),      # empty name
            ("ok", ""),       # empty target
        ]
        for name, target in rejected:
            with self.assertRaises(ValueError):
                self.reg.add_custom_model(name, target)

    def test_remove(self):
        self.reg.add_custom_model("tmpm", "a/b")
        self.assertIn("tmpm", self.reg.get_custom_models())
        self.reg.remove_custom_model("tmpm")
        self.assertNotIn("tmpm", self.reg.get_custom_models())

    def test_visible_includes_builtin_and_custom(self):
        self.reg.add_custom_model("zzz", "a/b")
        visible = self.reg.visible_model_names()
        for expected_name in ("tiny", "large-v3", "zzz"):
            self.assertIn(expected_name, visible)

    def test_is_known(self):
        self.assertTrue(self.reg.is_known("base"))
        self.assertTrue(self.reg.is_known("Org/Name"))
        self.assertFalse(self.reg.is_known("nope-not-real"))
|
||||||
|
|
||||||
|
|
||||||
|
class TestHelpers(unittest.TestCase):
    """Checks for the module-level helpers hf_cache_dirname and is_local_target."""

    def test_hf_cache_dirname(self):
        cases = {
            "Systran/faster-whisper-tiny": "models--Systran--faster-whisper-tiny",
            "Org/Name": "models--Org--Name",
        }
        for repo_id, dirname in cases.items():
            self.assertEqual(wm.hf_cache_dirname(repo_id), dirname)

    def test_is_local_target(self):
        for local in ("/abs/path", "./rel", "~/home/model"):
            self.assertTrue(wm.is_local_target(local))
        # A repo_id is not a local path, and neither is the empty string.
        for not_local in ("Org/Name", ""):
            self.assertFalse(wm.is_local_target(not_local))
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
unittest.main()
|
||||||
Reference in New Issue
Block a user