From 58d992f28faf0bdeeea07f0f9fdbeee21eab472a Mon Sep 17 00:00:00 2001 From: huangjianwu Date: Fri, 22 May 2026 15:09:06 +0800 Subject: [PATCH 1/4] =?UTF-8?q?feat(transcriber):=20=E5=8F=AF=E9=85=8D?= =?UTF-8?q?=E7=BD=AE=20whisper=20=E6=A8=A1=E5=9E=8B=20+=20=E5=90=8D?= =?UTF-8?q?=E7=A7=B0=E6=98=A0=E5=B0=84=EF=BC=88=E8=87=AA=E5=AE=9A=E4=B9=89?= =?UTF-8?q?=20HF=20repo=20/=20=E6=9C=AC=E5=9C=B0=E8=B7=AF=E5=BE=84?= =?UTF-8?q?=EF=BC=89?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit 此前 fast-whisper 把「size → Systran/faster-whisper-{size}」的约定隐式散落在 加载/下载/检测三处,用户想用命名不符该约定的模型(社区微调版、或自己下到本地 的模型)接不上。本功能把映射显式化 + 可配置(对齐已有的 MLX_MODEL_MAP 模式)。 后端: - 新增 app/transcriber/whisper_models.py 注册表:内置映射 + 用户自定义 (config/whisper_models.json 持久化,Docker 下随 config 卷保留);resolve 优先级 自定义 > 内置 > 直通(含 / 的 repo_id / 已存在本地目录)。 - whisper.py / config.py 的加载、下载、完整性检测统一走 resolve;HF cache 目录从 任意 repo_id 推导(models--{org}--{name})不再写死 Systran;本地路径跳过下载, _purge_cache 绝不删用户本地模型。 - 新增 /whisper_models 增删查 API;/transcriber_config 返回内置+自定义列表; 下载校验放开到「已登记/可解析」的模型。 前端:transcriber.tsx 新增「自定义模型」卡片(增删 + 下载状态),模型下拉自动含自定义。 Docker:自定义 HF 模型下到 /app/backend/models(v2.3.3 models 卷已持久化);本地模型 走挂载目录 + 配置路径,UI 已提示挂载。 测试:tests/test_whisper_models.py 13 个单测全过;并在 v2.3.3 镜像真实后端环境做了 import 链 + resolve + 真实模型检测的集成冒烟,均通过。 Co-Authored-By: Claude Opus 4.7 (1M context) --- .../src/pages/SettingPage/transcriber.tsx | 144 +++++++++++++++- BillNote_frontend/src/services/transcriber.ts | 24 +++ backend/app/routers/config.py | 106 +++++++++--- backend/app/transcriber/whisper.py | 30 +++- backend/app/transcriber/whisper_models.py | 156 ++++++++++++++++++ backend/tests/test_whisper_models.py | 132 +++++++++++++++ 6 files changed, 565 insertions(+), 27 deletions(-) create mode 100644 backend/app/transcriber/whisper_models.py create mode 100644 backend/tests/test_whisper_models.py diff --git a/BillNote_frontend/src/pages/SettingPage/transcriber.tsx b/BillNote_frontend/src/pages/SettingPage/transcriber.tsx index 17f298c..74a5669 100644 --- a/BillNote_frontend/src/pages/SettingPage/transcriber.tsx +++ b/BillNote_frontend/src/pages/SettingPage/transcriber.tsx @@ -10,13 +10,16 @@ import { SelectValue, } from '@/components/ui/select' import { Alert, AlertDescription } from '@/components/ui/alert' -import { AudioLines, AlertTriangle, CheckCircle2, Download, Loader2, Save, XCircle } from 'lucide-react' +import { Input } from '@/components/ui/input' +import { AudioLines, AlertTriangle, CheckCircle2, Download, Loader2, Save, XCircle, Plus, Trash2, Boxes } from 'lucide-react' import { toast } from 'react-hot-toast' import { getTranscriberConfig, updateTranscriberConfig, getModelsStatus, downloadModel, + addWhisperModel, + deleteWhisperModel, TranscriberConfig, ModelStatus, } from '@/services/transcriber' @@ -33,6 +36,19 @@ export default function Transcriber() { const [modelStatuses, setModelStatuses] = useState([]) const [mlxModelStatuses, setMlxModelStatuses] = useState([]) const [mlxAvailable, setMlxAvailable] = useState(false) + // 自定义模型表单 + const [newModelName, setNewModelName] = useState('') + const [newModelTarget, setNewModelTarget] = useState('') + const [addingModel, setAddingModel] = useState(false) + + // 重新拉取配置(不重置用户当前的选择),用于增删自定义模型后刷新下拉与列表 + const reloadConfig = useCallback(async () => { + try { + setConfig(await getTranscriberConfig()) + } catch { + // 静默 + } + }, []) const fetchModelsStatus = useCallback(async () => { try { @@ -123,6 +139,41 @@ export default function Transcriber() { } } + const handleAddCustomModel = async () => { + const name = newModelName.trim() + const target = newModelTarget.trim() + if (!name || !target) { + toast.error('请填写模型名称和 HF repo_id / 本地路径') + return + } + setAddingModel(true) + try { + await addWhisperModel({ name, target }) + toast.success(`已添加自定义模型 ${name}`) + setNewModelName('') + setNewModelTarget('') + await reloadConfig() + await fetchModelsStatus() + } catch { + // 后端的具体错误(如重名)已由请求拦截器 toast,这里不重复提示 + } finally { + setAddingModel(false) + } + } + + const handleDeleteCustomModel = async (name: string) => { + try { + await deleteWhisperModel(name) + toast.success(`已删除自定义模型 ${name}`) + // 删的正好是当前选中的,回退到 tiny,避免选中一个不存在的名称 + if (selectedModelSize === name) setSelectedModelSize('tiny') + await reloadConfig() + await fetchModelsStatus() + } catch { + // 拦截器已提示 + } + } + if (loading) { return (
@@ -272,6 +323,97 @@ export default function Transcriber() { )} + + {/* 自定义 Whisper 模型(仅 fast-whisper:名称不符合内置 Systran 约定的模型在此登记映射) */} + {selectedType === 'fast-whisper' && ( + + + + + 自定义模型 + + 登记名称不符合内置约定的模型 + + + + + + + 填 HF repo_id(如{' '} + Systran/faster-whisper-large-v3 + ,会自动下载)或本地模型目录(如{' '} + /app/backend/models/my-whisper + ,目录内需含 model.bin,下载会跳过)。 + 添加后即可在上方「模型大小」下拉中选用。Docker 部署请把模型目录挂载进容器(见 README 的{' '} + models 卷)。 + + + + {config.whisper_custom_models && + Object.keys(config.whisper_custom_models).length > 0 ? ( +
+ {Object.entries(config.whisper_custom_models).map(([name, target]) => { + const status = modelStatuses.find(m => m.model_size === name) + return ( +
+
+
+ {name} + {status?.downloaded && ( + + )} + {status?.downloading && ( + + )} +
+
+ {target} +
+
+ +
+ ) + })} +
+ ) : ( +

还没有自定义模型

+ )} + +
+ setNewModelName(e.target.value)} + className="sm:max-w-[220px]" + /> + setNewModelTarget(e.target.value)} + className="flex-1" + /> + +
+
+
+ )}
) } diff --git a/BillNote_frontend/src/services/transcriber.ts b/BillNote_frontend/src/services/transcriber.ts index 8407927..70f3f72 100644 --- a/BillNote_frontend/src/services/transcriber.ts +++ b/BillNote_frontend/src/services/transcriber.ts @@ -5,6 +5,10 @@ export interface TranscriberConfig { whisper_model_size: string available_types: { value: string; label: string }[] whisper_model_sizes: string[] + /** 内置模型映射:size → HF repo_id */ + whisper_builtin_models?: Record + /** 用户自定义模型映射:名称 → HF repo_id 或本地路径 */ + whisper_custom_models?: Record mlx_whisper_available: boolean } @@ -41,3 +45,23 @@ export const downloadModel = async (data: { }) => { return await request.post('/transcriber_download', data) } + +export interface WhisperModelsResponse { + builtin: Record + custom: Record +} + +/** 列出内置 + 自定义 whisper 模型映射 */ +export const listWhisperModels = async (): Promise => { + return await request.get('/whisper_models') +} + +/** 新增自定义模型映射(名称 → HF repo_id 或本地路径) */ +export const addWhisperModel = async (data: { name: string; target: string }) => { + return await request.post('/whisper_models', data) +} + +/** 删除自定义模型映射(不会删除已下载的模型文件) */ +export const deleteWhisperModel = async (name: string) => { + return await request.delete(`/whisper_models/${encodeURIComponent(name)}`) +} diff --git a/backend/app/routers/config.py b/backend/app/routers/config.py index c816fc5..e0be457 100644 --- a/backend/app/routers/config.py +++ b/backend/app/routers/config.py @@ -61,16 +61,53 @@ WHISPER_MODEL_SIZES = ["tiny", "base", "small", "medium", "large-v3", "large-v3- @router.get("/transcriber_config") def get_transcriber_config(): from app.transcriber.transcriber_provider import MLX_WHISPER_AVAILABLE + from app.transcriber.whisper_models import get_registry, BUILTIN_WHISPER_MODELS + registry = get_registry() config = transcriber_config_manager.get_config() return R.success(data={ **config, "available_types": AVAILABLE_TRANSCRIBER_TYPES, - "whisper_model_sizes": WHISPER_MODEL_SIZES, + # 内置可见档位 + 用户自定义模型,供前端下拉 + "whisper_model_sizes": registry.visible_model_names(), + "whisper_builtin_models": BUILTIN_WHISPER_MODELS, + "whisper_custom_models": registry.get_custom_models(), "mlx_whisper_available": MLX_WHISPER_AVAILABLE, }) +class WhisperCustomModelRequest(BaseModel): + name: str + target: str # HF repo_id(如 Systran/faster-whisper-large-v3)或本地模型目录路径 + + +@router.get("/whisper_models") +def list_whisper_models(): + """列出内置 + 用户自定义的 whisper 模型映射。""" + from app.transcriber.whisper_models import get_registry, BUILTIN_WHISPER_MODELS + reg = get_registry() + return R.success(data={"builtin": BUILTIN_WHISPER_MODELS, "custom": reg.get_custom_models()}) + + +@router.post("/whisper_models") +def add_whisper_model(data: WhisperCustomModelRequest): + """新增自定义 whisper 模型映射(名称 → HF repo_id 或本地路径)。""" + from app.transcriber.whisper_models import get_registry + try: + custom = get_registry().add_custom_model(data.name, data.target) + except ValueError as e: + return R.error(msg=str(e)) + return R.success(data={"custom": custom}, msg="已添加自定义模型") + + +@router.delete("/whisper_models/{name}") +def delete_whisper_model(name: str): + """删除自定义 whisper 模型映射(不会删除已下载的模型文件)。""" + from app.transcriber.whisper_models import get_registry + custom = get_registry().remove_custom_model(name) + return R.success(data={"custom": custom}, msg="已删除自定义模型") + + @router.post("/transcriber_config") def update_transcriber_config(data: TranscriberConfigRequest): config = transcriber_config_manager.update_config( @@ -119,14 +156,27 @@ _downloading: dict[str, str] = {} # model_size -> status ("downloading" | "done def _check_whisper_model_exists(model_size: str, subdir: str = "whisper") -> bool: """检查指定 whisper 模型是否已下载完整到本地。 - faster-whisper 把模型缓存在 HF cache 布局下: - /models--Systran--faster-whisper-{size}/snapshots//model.bin - 必须能在某个 snapshot 目录里找到 model.bin 才算完成。 - (历史 modelscope 布局 /whisper-{size}/model.bin 也兼容识别。) + 先把模型名 resolve 成可加载标识,再按类型判定: + - 本地路径模型 → 直接看该目录下有没有 model.bin + - HF repo_id → 看 HF cache 布局 + /models--{org}--{name}/snapshots//model.bin + (历史 modelscope 布局 /whisper-{size}/model.bin 也兼容识别) """ + from app.transcriber.whisper_models import ( + resolve_whisper_model, + is_local_target, + hf_cache_dirname, + ) + try: + target = resolve_whisper_model(model_size) + except Exception: + return False + if is_local_target(target): + return (Path(target) / "model.bin").exists() + model_dir = Path(get_model_dir(subdir)) - # HF cache 布局 - hf_repo_dir = model_dir / f"models--Systran--faster-whisper-{model_size}" / "snapshots" + # HF cache 布局(适配任意 org/repo,不再写死 Systran) + hf_repo_dir = model_dir / hf_cache_dirname(target) / "snapshots" if hf_repo_dir.exists(): for snapshot in hf_repo_dir.iterdir(): if (snapshot / "model.bin").exists(): @@ -157,9 +207,10 @@ def _check_mlx_whisper_model_exists(model_size: str) -> bool: @router.get("/transcriber_models_status") def get_transcriber_models_status(): - """返回所有 whisper 模型的下载状态。""" + """返回所有 whisper 模型的下载状态(含用户自定义模型)。""" + from app.transcriber.whisper_models import get_registry statuses = [] - for size in WHISPER_MODEL_SIZES: + for size in get_registry().visible_model_names(): downloaded = _check_whisper_model_exists(size, "whisper") download_status = _downloading.get(size) statuses.append({ @@ -198,13 +249,15 @@ class ModelDownloadRequest(BaseModel): def _do_download_whisper(model_size: str): - """后台下载 faster-whisper 模型。 + """后台下载 faster-whisper 模型(支持内置 size / 自定义 repo_id / 本地路径)。 - 直接走 huggingface_hub.snapshot_download,把模型放到 HF cache 布局里—— - 这样 faster-whisper 加载时(WhisperModel(model_size_or_path=size_name, - download_root=model_dir))能直接命中缓存,跟加载路径完全对齐。 + 模型名先 resolve: + - 本地路径模型:无需下载,目录里有 model.bin 即 done,否则 failed; + - HF repo_id:snapshot_download 到 HF cache 布局(cache_dir=model_dir), + 与加载逻辑 WhisperModel(download_root=model_dir) 完全对齐。 """ from huggingface_hub import snapshot_download + from app.transcriber.whisper_models import resolve_whisper_model, is_local_target try: _downloading[model_size] = "downloading" @@ -214,12 +267,21 @@ def _do_download_whisper(model_size: str): if _check_whisper_model_exists(model_size, "whisper"): _downloading[model_size] = "done" return - repo_id = f"Systran/faster-whisper-{model_size}" - logger.info(f"开始下载 whisper 模型: {repo_id}") + + target = resolve_whisper_model(model_size) + if is_local_target(target): + # 本地模型不下载,只校验 model.bin 是否就位 + ok = (Path(target) / "model.bin").exists() + _downloading[model_size] = "done" if ok else "failed" + if not ok: + logger.warning(f"本地模型 {model_size} 路径 {target} 下没有 model.bin,无法使用") + return + + logger.info(f"开始下载 whisper 模型: {model_size} ← {target}") # 跟 faster-whisper utils.py 用同样的 allow_patterns,避免多下无关文件; # 不传 local_dir 让它走 HF 默认 cache 布局(与加载逻辑对齐) snapshot_download( - repo_id, + target, cache_dir=model_dir, allow_patterns=[ "config.json", @@ -268,11 +330,11 @@ def _do_download_mlx_whisper(model_size: str): @router.post("/transcriber_download") def download_transcriber_model(data: ModelDownloadRequest, background_tasks: BackgroundTasks): - """触发后台下载指定的 whisper 模型。""" - if data.model_size not in WHISPER_MODEL_SIZES: - return R.error(msg=f"不支持的模型大小: {data.model_size}") - + """触发后台下载指定的 whisper 模型(fast-whisper 支持内置档位 + 自定义模型)。""" if data.transcriber_type == "mlx-whisper": + # mlx 只认内置档位(mlx-community 的固定映射) + if data.model_size not in WHISPER_MODEL_SIZES: + return R.error(msg=f"MLX 不支持的模型大小: {data.model_size}") if platform.system() != "Darwin": return R.error(msg="MLX Whisper 仅支持 macOS") key = f"mlx-{data.model_size}" @@ -280,6 +342,10 @@ def download_transcriber_model(data: ModelDownloadRequest, background_tasks: Bac return R.success(msg="模型正在下载中") background_tasks.add_task(_do_download_mlx_whisper, data.model_size) else: + # fast-whisper:内置档位 / 自定义 repo_id / 本地路径都允许 + from app.transcriber.whisper_models import get_registry + if not get_registry().is_known(data.model_size): + return R.error(msg=f"不支持的模型: {data.model_size}(请先在自定义模型中登记)") if _downloading.get(data.model_size) == "downloading": return R.success(msg="模型正在下载中") background_tasks.add_task(_do_download_whisper, data.model_size) diff --git a/backend/app/transcriber/whisper.py b/backend/app/transcriber/whisper.py index 5308579..3d254a2 100644 --- a/backend/app/transcriber/whisper.py +++ b/backend/app/transcriber/whisper.py @@ -3,6 +3,11 @@ from faster_whisper import WhisperModel from app.decorators.timeit import timeit from app.models.transcriber_model import TranscriptSegment, TranscriptResult from app.transcriber.base import Transcriber +from app.transcriber.whisper_models import ( + resolve_whisper_model, + is_local_target, + hf_cache_dirname, +) from app.utils.env_checker import is_cuda_available, is_torch_installed from app.utils.logger import get_logger from app.utils.path_helper import get_model_dir @@ -55,8 +60,12 @@ class WhisperTranscriber(Transcriber): self.model = self._build_model(model_size, model_dir) def _build_model(self, model_size: str, model_dir: str) -> WhisperModel: + # resolve 把模型名映射成可加载标识:内置 size→Systran repo_id、自定义映射、 + # 直通的 repo_id 或本地路径。faster-whisper 对本地目录走 os.path.isdir 分支, + # 对 repo_id 走 download_model(cache_dir=download_root),两者都吃 model_size_or_path。 + target = resolve_whisper_model(model_size) return WhisperModel( - model_size_or_path=model_size, # 传 size name,让 faster-whisper 自己映射到 Systran/faster-whisper-* + model_size_or_path=target, device=self.device, compute_type=self.compute_type, download_root=model_dir, @@ -64,14 +73,23 @@ class WhisperTranscriber(Transcriber): @staticmethod def _purge_cache(model_dir: str, model_size: str) -> None: - """删掉 HF cache 里这个 size 对应的 snapshot 目录,强制下次重新下载。 + """加载失败时清掉对应 HF cache 的 snapshot 目录,强制下次重下。 - HF cache 布局:/models--Systran--faster-whisper-{size}/ - 没找到也不报错——可能用户改了 endpoint 或者 cache 布局变了。 + 关键:本地路径模型**绝不删**——那是用户自己的文件,删了就是数据丢失; + 只清 HF cache 布局 /models--{org}--{name}/(含历史 modelscope 目录)。 """ + try: + target = resolve_whisper_model(model_size) + except Exception: + target = model_size + if is_local_target(target): + logger.warning( + f"模型 {model_size} 指向本地路径 {target},加载失败不清理用户文件,请检查该目录是否完整" + ) + return candidates = [ - Path(model_dir) / f"models--Systran--faster-whisper-{model_size}", - Path(model_dir) / f"whisper-{model_size}", # 历史 modelscope 目录,顺手清掉 + Path(model_dir) / hf_cache_dirname(target), # HF cache: models--org--name + Path(model_dir) / f"whisper-{model_size}", # 历史 modelscope 目录,顺手清掉 ] for path in candidates: if path.exists(): diff --git a/backend/app/transcriber/whisper_models.py b/backend/app/transcriber/whisper_models.py new file mode 100644 index 0000000..71dde1c --- /dev/null +++ b/backend/app/transcriber/whisper_models.py @@ -0,0 +1,156 @@ +"""fast-whisper 模型名 → 可加载标识(HF repo_id 或本地路径)的映射注册表。 + +背景:faster-whisper 加载时 `WhisperModel(model_size_or_path=...)` 接受三种入参—— +内置 size 名、HF repo_id(含 "/")、或本地模型目录(`os.path.isdir` 命中则直接用)。 +此前后端把「size → Systran/faster-whisper-{size}」这层约定**隐式**散落在加载/下载/ +检测三处,用户想用命名不符合该约定的模型(比如社区微调版、或自己下到本地的模型)就接不上。 + +本模块把映射**显式化 + 可配置**(对齐 mlx_whisper_transcriber.MLX_MODEL_MAP 的模式): + - 内置:size → Systran/faster-whisper-{size} + - 自定义:用户在 config/whisper_models.json 登记 {名称: ""} + (JSON 持久化;Docker 下随 config 卷持久化) + +解析优先级(resolve):自定义 > 内置 > 直通(含 "/" 当 repo_id;已存在目录当本地路径)。 +加载 / 下载 / 完整性检测三处统一调用 resolve,路径不再各写各的。 +""" +import json +import os +from pathlib import Path +from typing import Dict, List + +from app.utils.logger import get_logger + +logger = get_logger(__name__) + +# 内置模型:size → faster-whisper 兼容的 HF repo_id(CTranslate2 转换版,Systran 官方维护)。 +BUILTIN_WHISPER_MODELS: Dict[str, str] = { + "tiny": "Systran/faster-whisper-tiny", + "base": "Systran/faster-whisper-base", + "small": "Systran/faster-whisper-small", + "medium": "Systran/faster-whisper-medium", + "large-v1": "Systran/faster-whisper-large-v1", + "large-v2": "Systran/faster-whisper-large-v2", + "large-v3": "Systran/faster-whisper-large-v3", + "large-v3-turbo": "Systran/faster-whisper-large-v3-turbo", +} + +# 前端下拉默认展示的内置档位(保持与历史 WHISPER_MODEL_SIZES 一致,不把 8 个全列出来) +DEFAULT_VISIBLE_BUILTINS: List[str] = ["tiny", "base", "small", "medium", "large-v3", "large-v3-turbo"] + + +def is_local_target(target: str) -> bool: + """判断解析出的 target 是本地路径而非 HF repo_id。 + + HF repo_id 形如 'Org/Name'(恰一个斜杠、无前导斜杠、非已存在目录)。 + 本地路径:绝对路径 / 以 . 或 ~ 开头 / 已存在的目录。 + """ + if not target: + return False + if os.path.isabs(target) or target.startswith(".") or target.startswith("~"): + return True + return os.path.isdir(target) + + +def hf_cache_dirname(repo_id: str) -> str: + """huggingface_hub snapshot 的本地缓存目录名:Org/Name → models--Org--Name。""" + return "models--" + repo_id.replace("/", "--") + + +class WhisperModelRegistry: + """内置 + 用户自定义的 whisper 模型映射,自定义部分持久化到 JSON。""" + + def __init__(self, filepath: str = "config/whisper_models.json"): + self.path = Path(filepath) + self.path.parent.mkdir(parents=True, exist_ok=True) + + # ---- 持久化 ---- + def _read_custom(self) -> Dict[str, str]: + if not self.path.exists(): + return {} + try: + with self.path.open("r", encoding="utf-8") as f: + data = json.load(f) or {} + except Exception as e: + logger.warning(f"读取自定义 whisper 模型配置失败,按空处理: {e}") + return {} + out: Dict[str, str] = {} + for name, val in data.items(): + if isinstance(val, str) and val.strip(): + out[name] = val.strip() + elif isinstance(val, dict) and isinstance(val.get("target"), str): + out[name] = val["target"].strip() + return out + + def _write_custom(self, data: Dict[str, str]) -> None: + with self.path.open("w", encoding="utf-8") as f: + json.dump(data, f, ensure_ascii=False, indent=2) + + # ---- 查询 ---- + def get_custom_models(self) -> Dict[str, str]: + return self._read_custom() + + def visible_model_names(self) -> List[str]: + """给前端下拉 / 下载状态用:默认可见内置档位 + 全部自定义名称。""" + names = list(DEFAULT_VISIBLE_BUILTINS) + for name in self._read_custom(): + if name not in names: + names.append(name) + return names + + def is_known(self, name: str) -> bool: + try: + self.resolve(name) + return True + except ValueError: + return False + + def resolve(self, name: str) -> str: + """模型名 → 可加载标识(HF repo_id 或本地路径)。 + + 优先级:自定义映射 > 内置映射 > 直通(含 "/" 的 repo_id / 已存在的本地目录)。 + 无法识别时抛 ValueError。 + """ + name = (name or "").strip() + custom = self._read_custom() + if name in custom: + return custom[name] + if name in BUILTIN_WHISPER_MODELS: + return BUILTIN_WHISPER_MODELS[name] + # 直通:用户直接把 repo_id(含 "/")或本地已存在目录当 model_size 传进来 + if "/" in name or os.path.isdir(name): + return name + raise ValueError( + f"未知 whisper 模型 '{name}'。内置可选: {', '.join(BUILTIN_WHISPER_MODELS)};" + "或在「音频转写配置」添加自定义模型(HF repo_id 或本地路径)。" + ) + + # ---- 增删 ---- + def add_custom_model(self, name: str, target: str) -> Dict[str, str]: + name = (name or "").strip() + target = (target or "").strip() + if not name or not target: + raise ValueError("模型名称与目标(HF repo_id 或本地路径)都不能为空") + if name in BUILTIN_WHISPER_MODELS: + raise ValueError(f"'{name}' 与内置模型重名,请换一个名称") + data = self._read_custom() + data[name] = target + self._write_custom(data) + return data + + def remove_custom_model(self, name: str) -> Dict[str, str]: + data = self._read_custom() + data.pop((name or "").strip(), None) + self._write_custom(data) + return data + + +# 模块级单例 +_registry = WhisperModelRegistry() + + +def get_registry() -> WhisperModelRegistry: + return _registry + + +def resolve_whisper_model(name: str) -> str: + return _registry.resolve(name) diff --git a/backend/tests/test_whisper_models.py b/backend/tests/test_whisper_models.py new file mode 100644 index 0000000..0f323c1 --- /dev/null +++ b/backend/tests/test_whisper_models.py @@ -0,0 +1,132 @@ +"""Unit tests for app.transcriber.whisper_models(whisper 模型名→标识 的映射注册表)。 + +直接按文件路径加载被测模块,并桩掉 app.utils.logger,避免触发 app/__init__.py +(会 import faster_whisper / ctranslate2 等重依赖),使本测试无需安装转写依赖即可运行。 +""" +import importlib.util +import logging +import os +import pathlib +import sys +import tempfile +import types +import unittest + +ROOT = pathlib.Path(__file__).resolve().parents[1] +MODULE_PATH = ROOT / "app" / "transcriber" / "whisper_models.py" + + +def _load_module(): + if "app" not in sys.modules: + app_pkg = types.ModuleType("app") + app_pkg.__path__ = [] # 标记为 package + sys.modules["app"] = app_pkg + if "app.utils" not in sys.modules: + utils_pkg = types.ModuleType("app.utils") + utils_pkg.__path__ = [] + sys.modules["app.utils"] = utils_pkg + if "app.utils.logger" not in sys.modules: + logger_mod = types.ModuleType("app.utils.logger") + logger_mod.get_logger = lambda name=None: logging.getLogger(name or "test") + sys.modules["app.utils.logger"] = logger_mod + spec = importlib.util.spec_from_file_location("whisper_models_under_test", MODULE_PATH) + assert spec and spec.loader + mod = importlib.util.module_from_spec(spec) + spec.loader.exec_module(mod) + return mod + + +wm = _load_module() + + +class TestResolve(unittest.TestCase): + def setUp(self): + self.tmp = tempfile.TemporaryDirectory() + self.cfg = os.path.join(self.tmp.name, "whisper_models.json") + self.reg = wm.WhisperModelRegistry(filepath=self.cfg) + + def tearDown(self): + self.tmp.cleanup() + + def test_builtin_resolves_to_systran(self): + self.assertEqual(self.reg.resolve("tiny"), "Systran/faster-whisper-tiny") + self.assertEqual(self.reg.resolve("large-v3-turbo"), "Systran/faster-whisper-large-v3-turbo") + + def test_passthrough_repo_id(self): + # 用户直接把 HF repo_id 当 model_size 传进来(含 "/") + self.assertEqual(self.reg.resolve("SomeOrg/my-whisper-ct2"), "SomeOrg/my-whisper-ct2") + + def test_unknown_raises(self): + with self.assertRaises(ValueError): + self.reg.resolve("definitely-not-a-model") + + def test_custom_overrides_and_persists(self): + self.reg.add_custom_model("myhf", "someorg/whisper-ct2") + self.assertEqual(self.reg.resolve("myhf"), "someorg/whisper-ct2") + # 新实例读同一文件 → 确认持久化(Docker 下随 config 卷保留) + reg2 = wm.WhisperModelRegistry(filepath=self.cfg) + self.assertEqual(reg2.resolve("myhf"), "someorg/whisper-ct2") + + def test_custom_can_override_builtin_key_resolution(self): + # 自定义优先级高于内置:把 "tiny" 强行指到别的 repo(resolve 层允许;add 层禁止重名) + self.reg._write_custom({"tiny": "Other/tiny-ct2"}) + self.assertEqual(self.reg.resolve("tiny"), "Other/tiny-ct2") + + def test_local_path_resolution_and_detection(self): + model_dir = os.path.join(self.tmp.name, "mymodel") + os.makedirs(model_dir) + self.reg.add_custom_model("local1", model_dir) + self.assertEqual(self.reg.resolve("local1"), model_dir) + self.assertTrue(wm.is_local_target(self.reg.resolve("local1"))) + + def test_bare_existing_dir_passthrough(self): + # 没登记、但直接传一个已存在目录 → 直通为本地路径 + model_dir = os.path.join(self.tmp.name, "bare") + os.makedirs(model_dir) + self.assertEqual(self.reg.resolve(model_dir), model_dir) + + def test_add_rejects_builtin_collision_and_empty(self): + with self.assertRaises(ValueError): + self.reg.add_custom_model("tiny", "x/y") # 与内置重名 + with self.assertRaises(ValueError): + self.reg.add_custom_model("", "x/y") + with self.assertRaises(ValueError): + self.reg.add_custom_model("ok", "") + + def test_remove(self): + self.reg.add_custom_model("tmpm", "a/b") + self.assertIn("tmpm", self.reg.get_custom_models()) + self.reg.remove_custom_model("tmpm") + self.assertNotIn("tmpm", self.reg.get_custom_models()) + + def test_visible_includes_builtin_and_custom(self): + self.reg.add_custom_model("zzz", "a/b") + names = self.reg.visible_model_names() + self.assertIn("tiny", names) + self.assertIn("large-v3", names) + self.assertIn("zzz", names) + + def test_is_known(self): + self.assertTrue(self.reg.is_known("base")) + self.assertTrue(self.reg.is_known("Org/Name")) + self.assertFalse(self.reg.is_known("nope-not-real")) + + +class TestHelpers(unittest.TestCase): + def test_hf_cache_dirname(self): + self.assertEqual( + wm.hf_cache_dirname("Systran/faster-whisper-tiny"), + "models--Systran--faster-whisper-tiny", + ) + self.assertEqual(wm.hf_cache_dirname("Org/Name"), "models--Org--Name") + + def test_is_local_target(self): + self.assertTrue(wm.is_local_target("/abs/path")) + self.assertTrue(wm.is_local_target("./rel")) + self.assertTrue(wm.is_local_target("~/home/model")) + self.assertFalse(wm.is_local_target("Org/Name")) # repo_id 不是本地路径 + self.assertFalse(wm.is_local_target("")) + + +if __name__ == "__main__": + unittest.main() From 64a0400792a8b2543b9f61ea4117dc21a242ca7c Mon Sep 17 00:00:00 2001 From: huangjianwu Date: Fri, 22 May 2026 21:51:47 +0800 Subject: [PATCH 2/4] =?UTF-8?q?docs(readme):=20=E8=A1=A5=E5=85=A8=20GPU/CU?= =?UTF-8?q?DA=20=E9=83=A8=E7=BD=B2=E8=AF=B4=E6=98=8E?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit 原 CUDA 段落只有一行链接。补上实操步骤与常见坑: - 宿主机前提:NVIDIA 驱动 + NVIDIA Container Toolkit + --gpus all 验证命令 - 切换:先 docker-compose down(两套 compose 容器名相同)再 -f docker-compose.gpu.yml up --build -d - 数据不丢(两套 compose 都绑挂 ./backend);首次构建大而慢 - 只有本地 Faster Whisper 吃 GPU(在线引擎无关);device 自动检测无需配置 - 确认走 GPU 的方法 + 没走 GPU 的排查清单 + 国内镜像 build-arg 提示 另:把 compose 块里 GPU 那行 up -d 改成 up --build -d(首次需构建)。 Co-Authored-By: Claude Opus 4.7 (1M context) --- README.md | 42 +++++++++++++++++++++++++++++++++++++----- 1 file changed, 37 insertions(+), 5 deletions(-) diff --git a/README.md b/README.md index 21c300e..47357c9 100644 --- a/README.md +++ b/README.md @@ -293,10 +293,42 @@ sudo apt install ffmpeg > > Docker 部署已内置 FFmpeg,无需额外安装。 -### 🚀 CUDA 加速(可选) -若你希望更快地执行音频转写任务,可使用具备 NVIDIA GPU 的机器,并启用 fast-whisper + CUDA 加速版本: +### 🚀 CUDA / GPU 加速(可选) -具体 `fast-whisper` 配置方法,请参考:[fast-whisper 项目地址](http://github.com/SYSTRAN/faster-whisper#requirements) +本地 **Faster Whisper** 转写可用 NVIDIA GPU 加速(在线引擎 Groq / 必剪 / 快手 与 GPU 无关)。仓库已自带 GPU 镜像与编排,**无需改代码、无需手动配置 device**——后端会自动检测 CUDA,可用就走 GPU,否则回退 CPU。 + +**1. 宿主机前提** + +- NVIDIA 显卡 + 较新驱动(CUDA ≥ 12.4),宿主机 `nvidia-smi` 能正常输出; +- 安装 [NVIDIA Container Toolkit](https://docs.nvidia.com/datacenter/cloud-native/container-toolkit/latest/install-guide.html)(最易漏的一步,没它 Docker 进不去 GPU)。装完验证: + ```bash + docker run --rm --gpus all nvidia/cuda:12.4.1-base-ubuntu22.04 nvidia-smi + ``` + 能列出显卡即 OK。 + +**2. 切换到 GPU 编排**(在源码目录里) + +CPU 与 GPU 两套 compose 用了相同的容器名,先停掉当前栈再起 GPU 栈: + +```bash +docker-compose down # 停掉当前(CPU)栈 +docker-compose -f docker-compose.gpu.yml up --build -d # 用 GPU 栈重建 +``` + +- GPU 栈用 `backend/Dockerfile.gpu`(CUDA 12.4.1 + cuDNN 基础镜像,并额外装 torch 用于 CUDA 检测),compose 已声明 `deploy...devices: nvidia` 自动透传 GPU。 +- **数据不丢**:两套 compose 都把 `./backend` 整目录绑挂进容器,数据库 / 配置 / 已下载模型都保留。 +- 首次构建较大较慢(CUDA 基础镜像数 GB + torch),耐心等。 + +**3. 启用并确认** + +- 「设置 → 音频转写配置」转写引擎选 **Faster Whisper(本地)**,GPU 下可放心选大模型(如 `large-v3`)。 +- 确认真的走了 GPU:`docker logs bilinote-backend | grep -i cuda` 看到 `CUDA 可用,使用 GPU`;或转写时宿主机 `nvidia-smi` 能看到 python 进程占显存。 + +**国内镜像**:GPU compose 支持 `BASE_REGISTRY` / `APT_MIRROR` / `PIP_INDEX` 这几个 build-arg(注意 `BASE_REGISTRY` 选的源必须支持 `nvidia/cuda` 命名空间,否则拉不到 CUDA 基础镜像)。 + +**起来了但没走 GPU?** 依次排查:① 宿主机 `nvidia-smi` 是否正常 → ② NVIDIA Container Toolkit 是否装好(上面 `--gpus all` 测试是否通过)→ ③ `docker logs bilinote-backend` 是否有 CUDA / cuDNN 报错(驱动 CUDA 版本需 ≥ 12.4)。 + +`fast-whisper` 本身的 GPU 依赖说明可参考:[faster-whisper 项目](https://github.com/SYSTRAN/faster-whisper#requirements) ### 🐳 使用 Docker 一键部署 @@ -328,8 +360,8 @@ docker run -d -p 80:80 \ # 标准部署 docker-compose up -d -# GPU 加速部署(需要 NVIDIA GPU) -docker-compose -f docker-compose.gpu.yml up -d +# GPU 加速部署(需要 NVIDIA GPU + NVIDIA Container Toolkit,详见上方「CUDA / GPU 加速」) +docker-compose -f docker-compose.gpu.yml up --build -d ``` ## 🧠 TODO From e78b687096c7946225b5a814cf47bbb0c9dab4c3 Mon Sep 17 00:00:00 2001 From: techotaku39 Date: Tue, 26 May 2026 21:15:39 +0800 Subject: [PATCH 3/4] fix(extension): improve title display and mindmap export --- BillNote_extension/src/background/main.ts | 4 +- BillNote_extension/src/components/MindMap.vue | 175 +++++++++- BillNote_extension/src/logic/task-display.ts | 21 ++ BillNote_extension/src/popup/Popup.vue | 18 +- .../src/sidepanel/Sidepanel.vue | 78 ++++- .../HomePage/components/MarkmapComponent.tsx | 322 ++++++++++++------ 6 files changed, 489 insertions(+), 129 deletions(-) create mode 100644 BillNote_extension/src/logic/task-display.ts diff --git a/BillNote_extension/src/background/main.ts b/BillNote_extension/src/background/main.ts index 7883e31..6b4d134 100644 --- a/BillNote_extension/src/background/main.ts +++ b/BillNote_extension/src/background/main.ts @@ -3,6 +3,7 @@ import type { Settings, TaskRecord } from '~/logic/types' import { DEFAULT_SETTINGS, MAX_TASKS, SETTINGS_KEY, TASKS_KEY } from '~/logic/constants' import { detectPlatform } from '~/logic/platform' import { fetchBilibiliSubtitle } from '~/logic/bilibili-subtitle' +import { normalizeVideoTitle } from '~/logic/task-display' // only on dev mode if (import.meta.hot) { @@ -58,6 +59,7 @@ async function upsertTask(record: TaskRecord) { async function startTask(url: string, title?: string): Promise<{ ok: boolean, taskId?: string, error?: string }> { const platform = detectPlatform(url) + const displayTitle = normalizeVideoTitle(title) if (!platform) return { ok: false, error: '当前链接不是支持的视频平台' } @@ -107,7 +109,7 @@ async function startTask(url: string, title?: string): Promise<{ ok: boolean, ta message: '已提交', createdAt: Date.now(), updatedAt: Date.now(), - title, + title: displayTitle, }) return { ok: true, taskId: body.data.task_id } } diff --git a/BillNote_extension/src/components/MindMap.vue b/BillNote_extension/src/components/MindMap.vue index 0da1b21..2155c88 100644 --- a/BillNote_extension/src/components/MindMap.vue +++ b/BillNote_extension/src/components/MindMap.vue @@ -1,32 +1,181 @@ diff --git a/BillNote_extension/src/logic/task-display.ts b/BillNote_extension/src/logic/task-display.ts new file mode 100644 index 0000000..cc4c55d --- /dev/null +++ b/BillNote_extension/src/logic/task-display.ts @@ -0,0 +1,21 @@ +import type { TaskRecord } from './types' + +const SITE_SUFFIX_RE = /\s*[-_—–||]\s*(哔哩哔哩|bilibili|youtube|抖音|douyin|快手|kuaishou)\s*$/i + +export function normalizeVideoTitle(title: string | undefined | null): string | undefined { + const value = title?.trim() + if (!value) + return undefined + return value + .replace(SITE_SUFFIX_RE, '') + .trim() || value +} + +export function getTaskDisplayTitle(task: TaskRecord | undefined | null, fallbackTitle?: string): string { + if (!task) + return normalizeVideoTitle(fallbackTitle) || '' + return normalizeVideoTitle((task.result?.audio_meta as { title?: string } | undefined)?.title) + || normalizeVideoTitle(task.title) + || normalizeVideoTitle(fallbackTitle) + || task.videoUrl +} diff --git a/BillNote_extension/src/popup/Popup.vue b/BillNote_extension/src/popup/Popup.vue index 83494b2..43de2fd 100644 --- a/BillNote_extension/src/popup/Popup.vue +++ b/BillNote_extension/src/popup/Popup.vue @@ -5,6 +5,7 @@ import { settings, settingsReady, tasks, tasksReady, upsertTask } from '~/logic/ import { generateNote, getTaskStatus, resolveImageUrl } from '~/logic/api' import { fetchBilibiliSubtitle } from '~/logic/bilibili-subtitle' import { NOTE_FORMATS, NOTE_STYLES, type NoteFormat, type TaskRecord } from '~/logic/types' +import { getTaskDisplayTitle, normalizeVideoTitle } from '~/logic/task-display' const tabUrl = ref('') const tabTitle = ref('') @@ -43,7 +44,7 @@ async function poll(taskId: string) { createdAt: activeTask.value?.createdAt ?? Date.now(), updatedAt: Date.now(), result: res.result ?? activeTask.value?.result, - title: activeTask.value?.title, + title: activeTask.value?.title || normalizeVideoTitle(tabTitle.value), }) if (res.status !== 'SUCCESS' && res.status !== 'FAILED') pollTimer = setTimeout(() => poll(taskId), 3000) @@ -95,7 +96,7 @@ async function start() { message: '已提交', createdAt: Date.now(), updatedAt: Date.now(), - title: tabTitle.value || undefined, + title: normalizeVideoTitle(tabTitle.value), }) poll(task_id) // 提交后顺手把侧边栏拉起来,免得用户来回切窗口 @@ -144,10 +145,7 @@ function selectTask(id: string) { } const activeCover = computed(() => activeTask.value?.result?.audio_meta?.cover_url as string | undefined) -const activeTitle = computed(() => - (activeTask.value?.result?.audio_meta?.title as string | undefined) - || activeTask.value?.title - || tabTitle.value) +const activeTitle = computed(() => getTaskDisplayTitle(activeTask.value, tabTitle.value)) function fmtTime(ts?: number) { if (!ts) @@ -182,8 +180,8 @@ onUnmounted(() => { -
- {{ tabUrl || '当前没有打开的标签页' }} +
+ {{ normalizeVideoTitle(tabTitle) || tabUrl || '当前没有打开的标签页' }}
@@ -336,8 +334,8 @@ onUnmounted(() => { :class="{ 'bg-blue-50': t.taskId === activeTaskId }" @click="selectTask(t.taskId)" > - - {{ (t.result?.audio_meta as { title?: string } | undefined)?.title || t.title || t.videoUrl }} + + {{ getTaskDisplayTitle(t) }} {{ t.status }} diff --git a/BillNote_extension/src/sidepanel/Sidepanel.vue b/BillNote_extension/src/sidepanel/Sidepanel.vue index 155625b..04f514e 100644 --- a/BillNote_extension/src/sidepanel/Sidepanel.vue +++ b/BillNote_extension/src/sidepanel/Sidepanel.vue @@ -3,14 +3,17 @@ import { computed, onMounted, onUnmounted, ref } from 'vue' import { getTaskStatus, resolveImageUrl } from '~/logic/api' import { tasks, tasksReady, settingsReady, upsertTask } from '~/logic/storage' import type { TaskRecord } from '~/logic/types' +import { getTaskDisplayTitle } from '~/logic/task-display' type ViewMode = 'markdown' | 'mindmap' | 'chat' const activeTaskId = ref('') const activeTask = computed(() => tasks.value?.find(t => t.taskId === activeTaskId.value)) const errorMsg = ref('') +const successMsg = ref('') const viewMode = ref('markdown') const showHistory = ref(false) +const mindMapRef = ref<{ toPngBlob: () => Promise } | null>(null) const isDone = computed(() => activeTask.value?.status === 'SUCCESS') const isFailed = computed(() => activeTask.value?.status === 'FAILED') @@ -41,7 +44,7 @@ async function poll(taskId: string) { message: res.message, result: res.result ?? cur.result, updatedAt: Date.now(), - title: cur.title, + title: cur.title || getTaskDisplayTitle(cur), }) } if (res.status !== 'SUCCESS' && res.status !== 'FAILED') @@ -75,11 +78,19 @@ async function copyMarkdown() { await navigator.clipboard.writeText(md) } +function safeFilename(name: string): string { + return (name || 'bilinote') + .replace(/[\\/:*?"<>|]/g, '_') + .replace(/\s+/g, ' ') + .trim() + .slice(0, 120) || 'bilinote' +} + function downloadMarkdown() { const md = activeTask.value?.result?.markdown if (!md) return - const title = (activeTask.value?.result?.audio_meta as { title?: string } | undefined)?.title || 'bilinote' + const title = safeFilename(getTaskDisplayTitle(activeTask.value)) const blob = new Blob([md], { type: 'text/markdown;charset=utf-8' }) const url = URL.createObjectURL(blob) const a = document.createElement('a') @@ -89,11 +100,44 @@ function downloadMarkdown() { URL.revokeObjectURL(url) } -const activeTitle = computed(() => - (activeTask.value?.result?.audio_meta as { title?: string } | undefined)?.title - || activeTask.value?.title - || activeTask.value?.videoUrl - || '') +async function copyMindMapImage() { + try { + errorMsg.value = '' + successMsg.value = '' + const blob = await mindMapRef.value?.toPngBlob() + if (!blob) + return + await navigator.clipboard.write([ + new ClipboardItem({ [blob.type]: blob }), + ]) + successMsg.value = '思维导图图片已复制' + setTimeout(() => { successMsg.value = '' }, 2000) + } + catch (e) { + errorMsg.value = (e as Error).message || '复制思维导图图片失败' + } +} + +async function downloadMindMapImage() { + try { + errorMsg.value = '' + successMsg.value = '' + const blob = await mindMapRef.value?.toPngBlob() + if (!blob) + return + const url = URL.createObjectURL(blob) + const a = document.createElement('a') + a.href = url + a.download = `${safeFilename(getTaskDisplayTitle(activeTask.value))}.png` + a.click() + URL.revokeObjectURL(url) + } + catch (e) { + errorMsg.value = (e as Error).message || '下载思维导图图片失败' + } +} + +const activeTitle = computed(() => getTaskDisplayTitle(activeTask.value)) const activeCover = computed(() => (activeTask.value?.result?.audio_meta as { cover_url?: string } | undefined)?.cover_url) @@ -144,8 +188,8 @@ onUnmounted(() => { :class="{ 'bg-white border': t.taskId === activeTaskId }" @click="selectTask(t.taskId)" > - - {{ (t.result?.audio_meta as { title?: string } | undefined)?.title || t.title || t.videoUrl }} + + {{ getTaskDisplayTitle(t) }} {{ STAGE_LABELS[t.status] || t.status }} @@ -155,6 +199,9 @@ onUnmounted(() => {
{{ errorMsg }}
+
+ {{ successMsg }} +
还没有任务。在视频页点悬浮按钮、在 popup 提交,或右键菜单选「用 BiliNote 总结」。 @@ -228,6 +275,18 @@ onUnmounted(() => { title="下载 .md" @click="downloadMarkdown" >下载 + +
@@ -240,6 +299,7 @@ onUnmounted(() => { /> diff --git a/BillNote_frontend/src/pages/HomePage/components/MarkmapComponent.tsx b/BillNote_frontend/src/pages/HomePage/components/MarkmapComponent.tsx index 3e77add..21a7e8d 100644 --- a/BillNote_frontend/src/pages/HomePage/components/MarkmapComponent.tsx +++ b/BillNote_frontend/src/pages/HomePage/components/MarkmapComponent.tsx @@ -5,6 +5,171 @@ import { Toolbar } from 'markmap-toolbar' import 'markmap-toolbar/dist/style.css' import JSZip from 'jszip' +const MIN_EXPORT_FONT_PX = 256 +const MIN_EXPORT_WIDTH = 12800 +const WEB_EXPORT_SCALE_FACTOR = 0.34 +const MAX_EXPORT_SCALE = 24 +const MAX_CANVAS_SIDE = 32767 +const MAX_CANVAS_PIXELS = 268000000 + +function canvasToBlob(canvas: HTMLCanvasElement): Promise { + return new Promise((resolve, reject) => { + canvas.toBlob((blob) => { + if (blob) { + resolve(blob) + } else { + reject(new Error('无法创建PNG图片')) + } + }, 'image/png') + }) +} + +function createSvgElement(tag: K): SVGElementTagNameMap[K] { + return document.createElementNS('http://www.w3.org/2000/svg', tag) +} + +function sanitizeSvgForCanvas(svg: SVGSVGElement): SVGSVGElement { + const cloned = svg.cloneNode(true) as SVGSVGElement + + // markmap 会在 SVG 的顶层 上写入当前预览视口的 pan/zoom transform。 + // 导出时我们按内容 bbox 裁剪,如果保留这个视口 transform,会产生双重偏移, + // 导致图片内容跑到角落并留下大片空白。这里只移除顶层视口 transform, + // 保留内部节点自身的布局 transform。 + cloned.querySelector(':scope > g')?.removeAttribute('transform') + + cloned.querySelectorAll('image').forEach(el => el.remove()) + cloned.querySelectorAll('foreignObject').forEach((foreignObject) => { + const textContent = foreignObject.textContent?.replace(/\s+/g, ' ').trim() + if (!textContent) { + foreignObject.remove() + return + } + + const x = Number(foreignObject.getAttribute('x') || 0) + const y = Number(foreignObject.getAttribute('y') || 0) + const height = Number(foreignObject.getAttribute('height') || 20) + const text = createSvgElement('text') + text.setAttribute('x', String(x)) + text.setAttribute('y', String(y + height / 2)) + text.setAttribute('dominant-baseline', 'middle') + text.setAttribute('font-size', '14') + text.setAttribute('font-family', 'Arial, "Microsoft YaHei", sans-serif') + text.setAttribute('fill', '#333') + text.textContent = textContent + foreignObject.replaceWith(text) + }) + + return cloned +} + +function getExportFontSize(svg: SVGSVGElement): number { + const text = svg.querySelector('text, foreignObject') + if (!text) return 14 + + const fontSize = Number.parseFloat(getComputedStyle(text).fontSize || '') + if (Number.isFinite(fontSize) && fontSize > 0) return fontSize + + const attrSize = Number.parseFloat(text.getAttribute('font-size') || '') + return Number.isFinite(attrSize) && attrSize > 0 ? attrSize : 14 +} + +function getMindmapBounds(svg: SVGSVGElement) { + const target = svg.querySelector('g') || svg + const bbox = target.getBBox() + const padding = 50 + return { + x: Math.floor(bbox.x - padding), + y: Math.floor(bbox.y - padding), + width: Math.max(Math.ceil(bbox.width + padding * 2), 1), + height: Math.max(Math.ceil(bbox.height + padding * 2), 1), + } +} + +function stripMindmapImages(markdown: string) { + return (markdown || '') + // 思维导图只保留文字结构,图片节点会让预览排版和 PNG 导出效果都很差。 + .replace(/!\[[^\]]*\]\([^)]*\)/g, '') + .replace(/]*>/gi, '') +} + +function transformMindmap(markdown: string) { + return transformer.transform(stripMindmapImages(markdown)) +} + +function createExportSvg(svgEl: SVGSVGElement) { + const bounds = getMindmapBounds(svgEl) + const clonedSvg = sanitizeSvgForCanvas(svgEl) + + clonedSvg.setAttribute('xmlns', 'http://www.w3.org/2000/svg') + clonedSvg.setAttribute('xmlns:xlink', 'http://www.w3.org/1999/xlink') + clonedSvg.setAttribute('width', String(bounds.width)) + clonedSvg.setAttribute('height', String(bounds.height)) + clonedSvg.setAttribute('viewBox', `${bounds.x} ${bounds.y} ${bounds.width} ${bounds.height}`) + clonedSvg.setAttribute('preserveAspectRatio', 'xMidYMid meet') + + const bgRect = document.createElementNS('http://www.w3.org/2000/svg', 'rect') + bgRect.setAttribute('x', String(bounds.x)) + bgRect.setAttribute('y', String(bounds.y)) + bgRect.setAttribute('width', String(bounds.width)) + bgRect.setAttribute('height', String(bounds.height)) + bgRect.setAttribute('fill', 'white') + const firstG = clonedSvg.querySelector('g') + clonedSvg.insertBefore(bgRect, firstG || clonedSvg.firstChild) + + return { clonedSvg, ...bounds } +} + +async function exportSvgToPngBlob(svgEl: SVGSVGElement): Promise { + const { clonedSvg, width, height } = createExportSvg(svgEl) + const svgData = new XMLSerializer().serializeToString(clonedSvg) + const svgUrl = URL.createObjectURL(new Blob([svgData], { type: 'image/svg+xml;charset=utf-8' })) + + try { + const img = new Image() + img.decoding = 'async' + img.src = svgUrl + await img.decode() + + // 按导图内容尺寸和字号动态反推 PNG 倍率,而不是按预览容器或固定倍率导出。 + const fontScale = MIN_EXPORT_FONT_PX / getExportFontSize(svgEl) + const widthScale = MIN_EXPORT_WIDTH / width + const rawScale = Math.max(window.devicePixelRatio || 1, fontScale, widthScale) + const sideLimitScale = Math.min(MAX_CANVAS_SIDE / width, MAX_CANVAS_SIDE / height) + const pixelLimitScale = Math.sqrt(MAX_CANVAS_PIXELS / (width * height)) + const baseScale = Math.min(rawScale, MAX_EXPORT_SCALE, sideLimitScale, pixelLimitScale) + const scale = Math.max(1, baseScale * WEB_EXPORT_SCALE_FACTOR) + + let currentScale = scale + let lastError: unknown + while (currentScale >= 1) { + try { + const canvas = document.createElement('canvas') + canvas.width = Math.ceil(width * currentScale) + canvas.height = Math.ceil(height * currentScale) + + const ctx = canvas.getContext('2d') + if (!ctx) { + throw new Error('无法获取Canvas上下文') + } + + ctx.fillStyle = '#FFFFFF' + ctx.fillRect(0, 0, canvas.width, canvas.height) + ctx.setTransform(currentScale, 0, 0, currentScale, 0, 0) + ctx.drawImage(img, 0, 0, width, height) + ctx.setTransform(1, 0, 0, 1, 0, 0) + + return await canvasToBlob(canvas) + } catch (error) { + lastError = error + currentScale = Math.floor(currentScale / 2) + } + } + throw lastError || new Error('导出PNG失败') + } finally { + URL.revokeObjectURL(svgUrl) + } +} + export interface MarkmapEditorProps { /** 要渲染的 Markdown 文本 */ value: string @@ -34,6 +199,13 @@ export default function MarkmapEditor({ // 用于跟踪是否处于全屏状态 const [isFullscreen, setIsFullscreen] = useState(false) + const [pngAction, setPngAction] = useState<'idle' | 'exporting' | 'copying'>('idle') + const [pngMessage, setPngMessage] = useState('') + + const showPngMessage = (message: string) => { + setPngMessage(message) + window.setTimeout(() => setPngMessage(''), 2500) + } // 监听全屏状态变化 useEffect(() => { @@ -64,7 +236,7 @@ export default function MarkmapEditor({ // 导出HTML思维导图 const exportHtml = () => { try { - const { root } = transformer.transform(value) + const { root } = transformMindmap(value) const data = JSON.stringify(root) // 创建HTML内容 @@ -202,7 +374,7 @@ export default function MarkmapEditor({ // 导出XMind格式思维导图 const exportXMind = async () => { try { - const { root } = transformer.transform(value); + const { root } = transformMindmap(value); // 生成唯一ID const generateId = () => Math.random().toString(36).substring(2, 15); @@ -311,100 +483,44 @@ export default function MarkmapEditor({ try { if (!svgRef.current || !mmRef.current) return; - const svgEl = svgRef.current; - const mm = mmRef.current; - - // 先调用fit()确保显示完整的思维导图内容 - await mm.fit(); - // 等待渲染完成 - await new Promise(resolve => setTimeout(resolve, 100)); - - // 获取SVG实际尺寸 - const svgWidth = svgEl.width.baseVal.value || svgEl.clientWidth || 800; - const svgHeight = svgEl.height.baseVal.value || svgEl.clientHeight || 600; - - // 设置足够大的缩放比例以确保高清输出 - const scale = 3; - - // 克隆SVG以避免修改原始SVG - const clonedSvg = svgEl.cloneNode(true) as SVGSVGElement; - - // 设置SVG的背景为白色 - const style = document.createElementNS('http://www.w3.org/2000/svg', 'style'); - style.textContent = 'svg { background-color: white; }'; - clonedSvg.insertBefore(style, clonedSvg.firstChild); - - // 确保SVG有正确的命名空间 - clonedSvg.setAttribute('xmlns', 'http://www.w3.org/2000/svg'); - clonedSvg.setAttribute('width', svgWidth.toString()); - clonedSvg.setAttribute('height', svgHeight.toString()); - - // 将SVG转换为Data URI (避免使用Blob URL来解决跨域问题) - const svgData = new XMLSerializer().serializeToString(clonedSvg); - const svgBase64 = btoa(unescape(encodeURIComponent(svgData))); - const dataUri = `data:image/svg+xml;base64,${svgBase64}`; - - // 创建Canvas - const canvas = document.createElement('canvas'); - canvas.width = svgWidth * scale; - canvas.height = svgHeight * scale; - - // 获取上下文并设置白色背景 - const ctx = canvas.getContext('2d'); - if (!ctx) { - throw new Error('无法获取Canvas上下文'); - } - - // 设置白色背景 - ctx.fillStyle = '#FFFFFF'; - ctx.fillRect(0, 0, canvas.width, canvas.height); - - // 创建Image对象 - const img = new Image(); - - // 当图片加载完成后,在Canvas上绘制并导出 - img.onload = () => { - try { - // 应用缩放 - ctx.setTransform(scale, 0, 0, scale, 0, 0); - - // 绘制SVG - ctx.drawImage(img, 0, 0); - - // 重置变换 - ctx.setTransform(1, 0, 0, 1, 0, 0); - - // 将Canvas转换为PNG Blob - canvas.toBlob((blob) => { - if (blob) { - // 创建下载链接 - const url = URL.createObjectURL(blob); - const a = document.createElement('a'); - a.href = url; - a.download = `${title || 'mindmap'}.png`; - document.body.appendChild(a); - a.click(); - document.body.removeChild(a); - URL.revokeObjectURL(url); - } else { - console.error('无法创建Blob对象'); - } - }, 'image/png'); - } catch (err) { - console.error('Canvas处理失败:', err); - } - }; - - // 设置图片加载错误处理 - img.onerror = (error) => { - console.error('导出PNG失败(图片加载错误):', error); - }; - - // 开始加载SVG图像 (使用Data URI而不是Blob URL) - img.src = dataUri; - + setPngAction('exporting'); + setPngMessage('正在生成高清 PNG…'); + const blob = await exportSvgToPngBlob(svgRef.current); + const url = URL.createObjectURL(blob); + const a = document.createElement('a'); + a.href = url; + a.download = `${title || 'mindmap'}.png`; + document.body.appendChild(a); + a.click(); + document.body.removeChild(a); + URL.revokeObjectURL(url); + showPngMessage('PNG 已开始下载'); } catch (error) { console.error('导出PNG失败:', error); + showPngMessage('导出 PNG 失败,请查看控制台'); + } finally { + setPngAction('idle'); + } + }; + + // 复制PNG思维导图 + const copyPng = async () => { + try { + if (!svgRef.current || !mmRef.current) return; + + setPngAction('copying'); + setPngMessage('正在复制高清 PNG…'); + await navigator.clipboard.write([ + new ClipboardItem({ + 'image/png': exportSvgToPngBlob(svgRef.current), + }), + ]); + showPngMessage('PNG 已复制'); + } catch (error) { + console.error('复制PNG失败:', error); + showPngMessage('复制 PNG 失败,请查看控制台'); + } finally { + setPngAction('idle'); } }; @@ -428,7 +544,7 @@ export default function MarkmapEditor({ useEffect(() => { const mm = mmRef.current if (!mm) return - const { root } = transformer.transform(value) + const { root } = transformMindmap(value) mm.setData(root).then(() => mm.fit()) }, [value]) @@ -459,8 +575,17 @@ export default function MarkmapEditor({ onClick={exportPng} className="rounded p-1 hover:bg-gray-200" title="导出PNG图片" + disabled={pngAction !== 'idle'} > - 🖼️ + {pngAction === 'exporting' ? '⏳' : '🖼️'} + +
+ {pngMessage && ( +
+ {pngMessage} +
+ )} {/* 如果需要编辑区,就自己加一个