mirror of
https://github.com/JefferyHcool/BiliNote.git
synced 2026-06-11 18:49:59 +08:00
Merge branch 'release/2.4.0'
This commit is contained in:
@@ -3,6 +3,7 @@ import type { Settings, TaskRecord } from '~/logic/types'
|
||||
import { DEFAULT_SETTINGS, MAX_TASKS, SETTINGS_KEY, TASKS_KEY } from '~/logic/constants'
|
||||
import { detectPlatform } from '~/logic/platform'
|
||||
import { fetchBilibiliSubtitle } from '~/logic/bilibili-subtitle'
|
||||
import { normalizeVideoTitle } from '~/logic/task-display'
|
||||
|
||||
// only on dev mode
|
||||
if (import.meta.hot) {
|
||||
@@ -58,6 +59,7 @@ async function upsertTask(record: TaskRecord) {
|
||||
|
||||
async function startTask(url: string, title?: string): Promise<{ ok: boolean, taskId?: string, error?: string }> {
|
||||
const platform = detectPlatform(url)
|
||||
const displayTitle = normalizeVideoTitle(title)
|
||||
if (!platform)
|
||||
return { ok: false, error: '当前链接不是支持的视频平台' }
|
||||
|
||||
@@ -107,7 +109,7 @@ async function startTask(url: string, title?: string): Promise<{ ok: boolean, ta
|
||||
message: '已提交',
|
||||
createdAt: Date.now(),
|
||||
updatedAt: Date.now(),
|
||||
title,
|
||||
title: displayTitle,
|
||||
})
|
||||
return { ok: true, taskId: body.data.task_id }
|
||||
}
|
||||
|
||||
@@ -1,32 +1,181 @@
|
||||
<script setup lang="ts">
|
||||
import { onMounted, ref, watch } from 'vue'
|
||||
import { nextTick, onMounted, onUnmounted, ref, watch } from 'vue'
|
||||
import { Transformer } from 'markmap-lib'
|
||||
import { Markmap } from 'markmap-view'
|
||||
import { absolutizeMarkdownImages, stripSourceLink } from '~/logic/api'
|
||||
|
||||
const props = defineProps<{ markdown: string }>()
|
||||
|
||||
const wrapRef = ref<HTMLDivElement | null>(null)
|
||||
const svgRef = ref<SVGSVGElement | null>(null)
|
||||
let mm: Markmap | null = null
|
||||
let resizeObserver: ResizeObserver | null = null
|
||||
const transformer = new Transformer()
|
||||
const MIN_EXPORT_FONT_PX = 256
|
||||
const MIN_EXPORT_WIDTH = 12800
|
||||
const MAX_EXPORT_SCALE = 24
|
||||
const MAX_CANVAS_SIDE = 32767
|
||||
|
||||
function render() {
|
||||
if (!svgRef.value)
|
||||
return
|
||||
const md = absolutizeMarkdownImages(stripSourceLink(props.markdown || ''))
|
||||
const { root } = transformer.transform(md)
|
||||
if (!mm)
|
||||
mm = Markmap.create(svgRef.value, undefined, root)
|
||||
else
|
||||
mm.setData(root).then(() => mm?.fit())
|
||||
function canvasToBlob(canvas: HTMLCanvasElement): Promise<Blob> {
|
||||
return new Promise((resolve, reject) => {
|
||||
canvas.toBlob((blob) => {
|
||||
if (blob)
|
||||
resolve(blob)
|
||||
else
|
||||
reject(new Error('导出思维导图图片失败'))
|
||||
}, 'image/png')
|
||||
})
|
||||
}
|
||||
|
||||
onMounted(render)
|
||||
function createSvgElement<K extends keyof SVGElementTagNameMap>(tag: K): SVGElementTagNameMap[K] {
|
||||
return document.createElementNS('http://www.w3.org/2000/svg', tag)
|
||||
}
|
||||
|
||||
function sanitizeSvgForCanvas(svg: SVGSVGElement): SVGSVGElement {
|
||||
const cloned = svg.cloneNode(true) as SVGSVGElement
|
||||
|
||||
cloned.querySelectorAll('image').forEach(el => el.remove())
|
||||
cloned.querySelectorAll('foreignObject').forEach((foreignObject) => {
|
||||
const textContent = foreignObject.textContent?.replace(/\s+/g, ' ').trim()
|
||||
if (!textContent) {
|
||||
foreignObject.remove()
|
||||
return
|
||||
}
|
||||
|
||||
const x = Number(foreignObject.getAttribute('x') || 0)
|
||||
const y = Number(foreignObject.getAttribute('y') || 0)
|
||||
const height = Number(foreignObject.getAttribute('height') || 20)
|
||||
const text = createSvgElement('text')
|
||||
text.setAttribute('x', String(x))
|
||||
text.setAttribute('y', String(y + height / 2))
|
||||
text.setAttribute('dominant-baseline', 'middle')
|
||||
text.setAttribute('font-size', '14')
|
||||
text.setAttribute('font-family', 'Arial, "Microsoft YaHei", sans-serif')
|
||||
text.setAttribute('fill', '#333')
|
||||
text.textContent = textContent
|
||||
foreignObject.replaceWith(text)
|
||||
})
|
||||
|
||||
return cloned
|
||||
}
|
||||
|
||||
function getExportFontSize(svg: SVGSVGElement): number {
|
||||
const text = svg.querySelector('text, foreignObject')
|
||||
if (!text)
|
||||
return 14
|
||||
|
||||
const fontSize = Number.parseFloat(getComputedStyle(text).fontSize || '')
|
||||
if (Number.isFinite(fontSize) && fontSize > 0)
|
||||
return fontSize
|
||||
|
||||
const attrSize = Number.parseFloat(text.getAttribute('font-size') || '')
|
||||
return Number.isFinite(attrSize) && attrSize > 0 ? attrSize : 14
|
||||
}
|
||||
|
||||
function stripMindmapNoise(md: string): string {
|
||||
return absolutizeMarkdownImages(stripSourceLink(md || ''))
|
||||
// 笔记里的截图/封面图片在思维导图中会被当作超大 SVG foreignObject,
|
||||
// 容易把导图挤成截图里那种“只剩半框/一条竖线”的效果。导图只保留文字层级。
|
||||
.replace(/!\[[^\]]*\]\([^)]*\)/g, '')
|
||||
.replace(/<img\b[^>]*>/gi, '')
|
||||
}
|
||||
|
||||
async function fit() {
|
||||
await nextTick()
|
||||
requestAnimationFrame(() => mm?.fit())
|
||||
}
|
||||
|
||||
async function render() {
|
||||
if (!svgRef.value)
|
||||
return
|
||||
const { root } = transformer.transform(stripMindmapNoise(props.markdown))
|
||||
if (!mm)
|
||||
mm = Markmap.create(svgRef.value, { autoFit: true }, root)
|
||||
else
|
||||
await mm.setData(root)
|
||||
await fit()
|
||||
}
|
||||
|
||||
async function toPngBlob(): Promise<Blob> {
|
||||
await fit()
|
||||
await nextTick()
|
||||
if (!svgRef.value)
|
||||
throw new Error('思维导图尚未渲染完成')
|
||||
|
||||
const svg = svgRef.value
|
||||
const bbox = svg.getBBox()
|
||||
const padding = 48
|
||||
const x = Math.floor(bbox.x - padding)
|
||||
const y = Math.floor(bbox.y - padding)
|
||||
const width = Math.max(Math.ceil(bbox.width + padding * 2), 1)
|
||||
const height = Math.max(Math.ceil(bbox.height + padding * 2), 1)
|
||||
const cloned = sanitizeSvgForCanvas(svg)
|
||||
|
||||
cloned.setAttribute('xmlns', 'http://www.w3.org/2000/svg')
|
||||
cloned.setAttribute('width', String(width))
|
||||
cloned.setAttribute('height', String(height))
|
||||
cloned.setAttribute('viewBox', `${x} ${y} ${width} ${height}`)
|
||||
cloned.insertAdjacentHTML('afterbegin', `<rect width="100%" height="100%" fill="#fff"/>`)
|
||||
|
||||
const svgText = new XMLSerializer().serializeToString(cloned)
|
||||
const url = URL.createObjectURL(new Blob([svgText], { type: 'image/svg+xml;charset=utf-8' }))
|
||||
|
||||
try {
|
||||
const img = new Image()
|
||||
img.decoding = 'async'
|
||||
img.src = url
|
||||
await img.decode()
|
||||
|
||||
// 不写死某个导出宽度:按导图内容和文字字号动态反推 PNG 倍率。
|
||||
// 目标是让导出的正文至少有 MIN_EXPORT_FONT_PX 像素高,小图自动放大,
|
||||
// 大图则按内容尺寸导出;同时限制最大边长,避免复杂导图撑爆内存。
|
||||
const fontScale = MIN_EXPORT_FONT_PX / getExportFontSize(svg)
|
||||
const widthScale = MIN_EXPORT_WIDTH / width
|
||||
const rawScale = Math.max(window.devicePixelRatio || 1, fontScale, widthScale)
|
||||
const sideLimitScale = Math.min(MAX_CANVAS_SIDE / width, MAX_CANVAS_SIDE / height)
|
||||
const scale = Math.max(1, Math.min(rawScale, MAX_EXPORT_SCALE, sideLimitScale))
|
||||
const canvas = document.createElement('canvas')
|
||||
canvas.width = Math.ceil(width * scale)
|
||||
canvas.height = Math.ceil(height * scale)
|
||||
const ctx = canvas.getContext('2d')
|
||||
if (!ctx)
|
||||
throw new Error('当前浏览器不支持 Canvas 导出')
|
||||
ctx.fillStyle = '#fff'
|
||||
ctx.fillRect(0, 0, canvas.width, canvas.height)
|
||||
ctx.scale(scale, scale)
|
||||
ctx.drawImage(img, 0, 0, width, height)
|
||||
return await canvasToBlob(canvas)
|
||||
}
|
||||
finally {
|
||||
URL.revokeObjectURL(url)
|
||||
}
|
||||
}
|
||||
|
||||
defineExpose({
|
||||
toPngBlob,
|
||||
})
|
||||
|
||||
onMounted(() => {
|
||||
render()
|
||||
if (wrapRef.value) {
|
||||
resizeObserver = new ResizeObserver(() => fit())
|
||||
resizeObserver.observe(wrapRef.value)
|
||||
}
|
||||
})
|
||||
|
||||
onUnmounted(() => {
|
||||
resizeObserver?.disconnect()
|
||||
resizeObserver = null
|
||||
mm?.destroy()
|
||||
mm = null
|
||||
})
|
||||
|
||||
watch(() => props.markdown, render)
|
||||
|
||||
</script>
|
||||
|
||||
<template>
|
||||
<div class="w-full h-full bg-white rounded border overflow-hidden">
|
||||
<svg ref="svgRef" class="w-full h-full" />
|
||||
<div ref="wrapRef" class="w-full h-full min-h-[360px] bg-white rounded border overflow-hidden">
|
||||
<svg ref="svgRef" class="w-full h-full min-h-[360px]" />
|
||||
</div>
|
||||
</template>
|
||||
|
||||
21
BillNote_extension/src/logic/task-display.ts
Normal file
21
BillNote_extension/src/logic/task-display.ts
Normal file
@@ -0,0 +1,21 @@
|
||||
import type { TaskRecord } from './types'
|
||||
|
||||
const SITE_SUFFIX_RE = /\s*[-_—–||]\s*(哔哩哔哩|bilibili|youtube|抖音|douyin|快手|kuaishou)\s*$/i
|
||||
|
||||
export function normalizeVideoTitle(title: string | undefined | null): string | undefined {
|
||||
const value = title?.trim()
|
||||
if (!value)
|
||||
return undefined
|
||||
return value
|
||||
.replace(SITE_SUFFIX_RE, '')
|
||||
.trim() || value
|
||||
}
|
||||
|
||||
export function getTaskDisplayTitle(task: TaskRecord | undefined | null, fallbackTitle?: string): string {
|
||||
if (!task)
|
||||
return normalizeVideoTitle(fallbackTitle) || ''
|
||||
return normalizeVideoTitle((task.result?.audio_meta as { title?: string } | undefined)?.title)
|
||||
|| normalizeVideoTitle(task.title)
|
||||
|| normalizeVideoTitle(fallbackTitle)
|
||||
|| task.videoUrl
|
||||
}
|
||||
@@ -5,6 +5,7 @@ import { settings, settingsReady, tasks, tasksReady, upsertTask } from '~/logic/
|
||||
import { generateNote, getTaskStatus, resolveImageUrl } from '~/logic/api'
|
||||
import { fetchBilibiliSubtitle } from '~/logic/bilibili-subtitle'
|
||||
import { NOTE_FORMATS, NOTE_STYLES, type NoteFormat, type TaskRecord } from '~/logic/types'
|
||||
import { getTaskDisplayTitle, normalizeVideoTitle } from '~/logic/task-display'
|
||||
|
||||
const tabUrl = ref<string>('')
|
||||
const tabTitle = ref<string>('')
|
||||
@@ -43,7 +44,7 @@ async function poll(taskId: string) {
|
||||
createdAt: activeTask.value?.createdAt ?? Date.now(),
|
||||
updatedAt: Date.now(),
|
||||
result: res.result ?? activeTask.value?.result,
|
||||
title: activeTask.value?.title,
|
||||
title: activeTask.value?.title || normalizeVideoTitle(tabTitle.value),
|
||||
})
|
||||
if (res.status !== 'SUCCESS' && res.status !== 'FAILED')
|
||||
pollTimer = setTimeout(() => poll(taskId), 3000)
|
||||
@@ -95,7 +96,7 @@ async function start() {
|
||||
message: '已提交',
|
||||
createdAt: Date.now(),
|
||||
updatedAt: Date.now(),
|
||||
title: tabTitle.value || undefined,
|
||||
title: normalizeVideoTitle(tabTitle.value),
|
||||
})
|
||||
poll(task_id)
|
||||
// 提交后顺手把侧边栏拉起来,免得用户来回切窗口
|
||||
@@ -144,10 +145,7 @@ function selectTask(id: string) {
|
||||
}
|
||||
|
||||
const activeCover = computed(() => activeTask.value?.result?.audio_meta?.cover_url as string | undefined)
|
||||
const activeTitle = computed(() =>
|
||||
(activeTask.value?.result?.audio_meta?.title as string | undefined)
|
||||
|| activeTask.value?.title
|
||||
|| tabTitle.value)
|
||||
const activeTitle = computed(() => getTaskDisplayTitle(activeTask.value, tabTitle.value))
|
||||
|
||||
function fmtTime(ts?: number) {
|
||||
if (!ts)
|
||||
@@ -182,8 +180,8 @@ onUnmounted(() => {
|
||||
<button class="text-xs text-gray-500 hover:text-gray-800" @click="openOptions">设置</button>
|
||||
</header>
|
||||
|
||||
<div class="text-xs text-gray-500 truncate" :title="tabUrl">
|
||||
{{ tabUrl || '当前没有打开的标签页' }}
|
||||
<div class="text-xs text-gray-500 truncate" :title="normalizeVideoTitle(tabTitle) || tabUrl">
|
||||
{{ normalizeVideoTitle(tabTitle) || tabUrl || '当前没有打开的标签页' }}
|
||||
</div>
|
||||
|
||||
<div v-if="!supported" class="text-xs text-amber-700 bg-amber-50 p-2 rounded">
|
||||
@@ -336,8 +334,8 @@ onUnmounted(() => {
|
||||
:class="{ 'bg-blue-50': t.taskId === activeTaskId }"
|
||||
@click="selectTask(t.taskId)"
|
||||
>
|
||||
<span class="truncate flex-1" :title="t.title || t.videoUrl">
|
||||
{{ (t.result?.audio_meta as { title?: string } | undefined)?.title || t.title || t.videoUrl }}
|
||||
<span class="truncate flex-1" :title="getTaskDisplayTitle(t)">
|
||||
{{ getTaskDisplayTitle(t) }}
|
||||
</span>
|
||||
<span class="text-gray-500 shrink-0">{{ t.status }}</span>
|
||||
</li>
|
||||
|
||||
@@ -3,14 +3,17 @@ import { computed, onMounted, onUnmounted, ref } from 'vue'
|
||||
import { getTaskStatus, resolveImageUrl } from '~/logic/api'
|
||||
import { tasks, tasksReady, settingsReady, upsertTask } from '~/logic/storage'
|
||||
import type { TaskRecord } from '~/logic/types'
|
||||
import { getTaskDisplayTitle } from '~/logic/task-display'
|
||||
|
||||
type ViewMode = 'markdown' | 'mindmap' | 'chat'
|
||||
|
||||
const activeTaskId = ref<string>('')
|
||||
const activeTask = computed<TaskRecord | undefined>(() => tasks.value?.find(t => t.taskId === activeTaskId.value))
|
||||
const errorMsg = ref('')
|
||||
const successMsg = ref('')
|
||||
const viewMode = ref<ViewMode>('markdown')
|
||||
const showHistory = ref(false)
|
||||
const mindMapRef = ref<{ toPngBlob: () => Promise<Blob> } | null>(null)
|
||||
|
||||
const isDone = computed(() => activeTask.value?.status === 'SUCCESS')
|
||||
const isFailed = computed(() => activeTask.value?.status === 'FAILED')
|
||||
@@ -41,7 +44,7 @@ async function poll(taskId: string) {
|
||||
message: res.message,
|
||||
result: res.result ?? cur.result,
|
||||
updatedAt: Date.now(),
|
||||
title: cur.title,
|
||||
title: cur.title || getTaskDisplayTitle(cur),
|
||||
})
|
||||
}
|
||||
if (res.status !== 'SUCCESS' && res.status !== 'FAILED')
|
||||
@@ -75,11 +78,19 @@ async function copyMarkdown() {
|
||||
await navigator.clipboard.writeText(md)
|
||||
}
|
||||
|
||||
function safeFilename(name: string): string {
|
||||
return (name || 'bilinote')
|
||||
.replace(/[\\/:*?"<>|]/g, '_')
|
||||
.replace(/\s+/g, ' ')
|
||||
.trim()
|
||||
.slice(0, 120) || 'bilinote'
|
||||
}
|
||||
|
||||
function downloadMarkdown() {
|
||||
const md = activeTask.value?.result?.markdown
|
||||
if (!md)
|
||||
return
|
||||
const title = (activeTask.value?.result?.audio_meta as { title?: string } | undefined)?.title || 'bilinote'
|
||||
const title = safeFilename(getTaskDisplayTitle(activeTask.value))
|
||||
const blob = new Blob([md], { type: 'text/markdown;charset=utf-8' })
|
||||
const url = URL.createObjectURL(blob)
|
||||
const a = document.createElement('a')
|
||||
@@ -89,11 +100,44 @@ function downloadMarkdown() {
|
||||
URL.revokeObjectURL(url)
|
||||
}
|
||||
|
||||
const activeTitle = computed(() =>
|
||||
(activeTask.value?.result?.audio_meta as { title?: string } | undefined)?.title
|
||||
|| activeTask.value?.title
|
||||
|| activeTask.value?.videoUrl
|
||||
|| '')
|
||||
async function copyMindMapImage() {
|
||||
try {
|
||||
errorMsg.value = ''
|
||||
successMsg.value = ''
|
||||
const blob = await mindMapRef.value?.toPngBlob()
|
||||
if (!blob)
|
||||
return
|
||||
await navigator.clipboard.write([
|
||||
new ClipboardItem({ [blob.type]: blob }),
|
||||
])
|
||||
successMsg.value = '思维导图图片已复制'
|
||||
setTimeout(() => { successMsg.value = '' }, 2000)
|
||||
}
|
||||
catch (e) {
|
||||
errorMsg.value = (e as Error).message || '复制思维导图图片失败'
|
||||
}
|
||||
}
|
||||
|
||||
async function downloadMindMapImage() {
|
||||
try {
|
||||
errorMsg.value = ''
|
||||
successMsg.value = ''
|
||||
const blob = await mindMapRef.value?.toPngBlob()
|
||||
if (!blob)
|
||||
return
|
||||
const url = URL.createObjectURL(blob)
|
||||
const a = document.createElement('a')
|
||||
a.href = url
|
||||
a.download = `${safeFilename(getTaskDisplayTitle(activeTask.value))}.png`
|
||||
a.click()
|
||||
URL.revokeObjectURL(url)
|
||||
}
|
||||
catch (e) {
|
||||
errorMsg.value = (e as Error).message || '下载思维导图图片失败'
|
||||
}
|
||||
}
|
||||
|
||||
const activeTitle = computed(() => getTaskDisplayTitle(activeTask.value))
|
||||
|
||||
const activeCover = computed(() =>
|
||||
(activeTask.value?.result?.audio_meta as { cover_url?: string } | undefined)?.cover_url)
|
||||
@@ -144,8 +188,8 @@ onUnmounted(() => {
|
||||
:class="{ 'bg-white border': t.taskId === activeTaskId }"
|
||||
@click="selectTask(t.taskId)"
|
||||
>
|
||||
<span class="truncate flex-1" :title="t.title || t.videoUrl">
|
||||
{{ (t.result?.audio_meta as { title?: string } | undefined)?.title || t.title || t.videoUrl }}
|
||||
<span class="truncate flex-1" :title="getTaskDisplayTitle(t)">
|
||||
{{ getTaskDisplayTitle(t) }}
|
||||
</span>
|
||||
<span class="text-gray-400 shrink-0">{{ STAGE_LABELS[t.status] || t.status }}</span>
|
||||
</li>
|
||||
@@ -155,6 +199,9 @@ onUnmounted(() => {
|
||||
<div v-if="errorMsg" class="text-xs text-red-600 px-3 py-1 break-words bg-red-50 shrink-0">
|
||||
{{ errorMsg }}
|
||||
</div>
|
||||
<div v-if="successMsg" class="text-xs text-green-700 px-3 py-1 break-words bg-green-50 shrink-0">
|
||||
{{ successMsg }}
|
||||
</div>
|
||||
|
||||
<section v-if="!activeTask" class="flex-1 flex items-center justify-center text-gray-400 text-xs px-4 text-center">
|
||||
还没有任务。在视频页点悬浮按钮、在 popup 提交,或右键菜单选「用 BiliNote 总结」。
|
||||
@@ -228,6 +275,18 @@ onUnmounted(() => {
|
||||
title="下载 .md"
|
||||
@click="downloadMarkdown"
|
||||
>下载</button>
|
||||
<button
|
||||
v-if="viewMode === 'mindmap'"
|
||||
class="text-gray-500 hover:text-gray-800 px-1.5 py-1 rounded hover:bg-gray-100"
|
||||
title="复制思维导图图片"
|
||||
@click="copyMindMapImage"
|
||||
>复制</button>
|
||||
<button
|
||||
v-if="viewMode === 'mindmap'"
|
||||
class="text-gray-500 hover:text-gray-800 px-1.5 py-1 rounded hover:bg-gray-100"
|
||||
title="下载思维导图 PNG"
|
||||
@click="downloadMindMapImage"
|
||||
>下载</button>
|
||||
</div>
|
||||
|
||||
<!-- 内容区:占满剩余空间 -->
|
||||
@@ -240,6 +299,7 @@ onUnmounted(() => {
|
||||
/>
|
||||
<MindMap
|
||||
v-else-if="isDone && activeTask.result?.markdown && viewMode === 'mindmap'"
|
||||
ref="mindMapRef"
|
||||
:markdown="activeTask.result.markdown"
|
||||
class="h-full"
|
||||
/>
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
{
|
||||
"$schema": "../node_modules/@tauri-apps/cli/config.schema.json",
|
||||
"productName": "BiliNote",
|
||||
"version": "2.3.4",
|
||||
"version": "2.4.0",
|
||||
"identifier": "com.jefferyhuang.bilinote",
|
||||
"build": {
|
||||
"frontendDist": "../dist",
|
||||
|
||||
@@ -5,6 +5,171 @@ import { Toolbar } from 'markmap-toolbar'
|
||||
import 'markmap-toolbar/dist/style.css'
|
||||
import JSZip from 'jszip'
|
||||
|
||||
const MIN_EXPORT_FONT_PX = 256
|
||||
const MIN_EXPORT_WIDTH = 12800
|
||||
const WEB_EXPORT_SCALE_FACTOR = 0.34
|
||||
const MAX_EXPORT_SCALE = 24
|
||||
const MAX_CANVAS_SIDE = 32767
|
||||
const MAX_CANVAS_PIXELS = 268000000
|
||||
|
||||
function canvasToBlob(canvas: HTMLCanvasElement): Promise<Blob> {
|
||||
return new Promise((resolve, reject) => {
|
||||
canvas.toBlob((blob) => {
|
||||
if (blob) {
|
||||
resolve(blob)
|
||||
} else {
|
||||
reject(new Error('无法创建PNG图片'))
|
||||
}
|
||||
}, 'image/png')
|
||||
})
|
||||
}
|
||||
|
||||
function createSvgElement<K extends keyof SVGElementTagNameMap>(tag: K): SVGElementTagNameMap[K] {
|
||||
return document.createElementNS('http://www.w3.org/2000/svg', tag)
|
||||
}
|
||||
|
||||
function sanitizeSvgForCanvas(svg: SVGSVGElement): SVGSVGElement {
|
||||
const cloned = svg.cloneNode(true) as SVGSVGElement
|
||||
|
||||
// markmap 会在 SVG 的顶层 <g> 上写入当前预览视口的 pan/zoom transform。
|
||||
// 导出时我们按内容 bbox 裁剪,如果保留这个视口 transform,会产生双重偏移,
|
||||
// 导致图片内容跑到角落并留下大片空白。这里只移除顶层视口 transform,
|
||||
// 保留内部节点自身的布局 transform。
|
||||
cloned.querySelector(':scope > g')?.removeAttribute('transform')
|
||||
|
||||
cloned.querySelectorAll('image').forEach(el => el.remove())
|
||||
cloned.querySelectorAll('foreignObject').forEach((foreignObject) => {
|
||||
const textContent = foreignObject.textContent?.replace(/\s+/g, ' ').trim()
|
||||
if (!textContent) {
|
||||
foreignObject.remove()
|
||||
return
|
||||
}
|
||||
|
||||
const x = Number(foreignObject.getAttribute('x') || 0)
|
||||
const y = Number(foreignObject.getAttribute('y') || 0)
|
||||
const height = Number(foreignObject.getAttribute('height') || 20)
|
||||
const text = createSvgElement('text')
|
||||
text.setAttribute('x', String(x))
|
||||
text.setAttribute('y', String(y + height / 2))
|
||||
text.setAttribute('dominant-baseline', 'middle')
|
||||
text.setAttribute('font-size', '14')
|
||||
text.setAttribute('font-family', 'Arial, "Microsoft YaHei", sans-serif')
|
||||
text.setAttribute('fill', '#333')
|
||||
text.textContent = textContent
|
||||
foreignObject.replaceWith(text)
|
||||
})
|
||||
|
||||
return cloned
|
||||
}
|
||||
|
||||
function getExportFontSize(svg: SVGSVGElement): number {
|
||||
const text = svg.querySelector('text, foreignObject')
|
||||
if (!text) return 14
|
||||
|
||||
const fontSize = Number.parseFloat(getComputedStyle(text).fontSize || '')
|
||||
if (Number.isFinite(fontSize) && fontSize > 0) return fontSize
|
||||
|
||||
const attrSize = Number.parseFloat(text.getAttribute('font-size') || '')
|
||||
return Number.isFinite(attrSize) && attrSize > 0 ? attrSize : 14
|
||||
}
|
||||
|
||||
function getMindmapBounds(svg: SVGSVGElement) {
|
||||
const target = svg.querySelector('g') || svg
|
||||
const bbox = target.getBBox()
|
||||
const padding = 50
|
||||
return {
|
||||
x: Math.floor(bbox.x - padding),
|
||||
y: Math.floor(bbox.y - padding),
|
||||
width: Math.max(Math.ceil(bbox.width + padding * 2), 1),
|
||||
height: Math.max(Math.ceil(bbox.height + padding * 2), 1),
|
||||
}
|
||||
}
|
||||
|
||||
function stripMindmapImages(markdown: string) {
|
||||
return (markdown || '')
|
||||
// 思维导图只保留文字结构,图片节点会让预览排版和 PNG 导出效果都很差。
|
||||
.replace(/!\[[^\]]*\]\([^)]*\)/g, '')
|
||||
.replace(/<img\b[^>]*>/gi, '')
|
||||
}
|
||||
|
||||
function transformMindmap(markdown: string) {
|
||||
return transformer.transform(stripMindmapImages(markdown))
|
||||
}
|
||||
|
||||
function createExportSvg(svgEl: SVGSVGElement) {
|
||||
const bounds = getMindmapBounds(svgEl)
|
||||
const clonedSvg = sanitizeSvgForCanvas(svgEl)
|
||||
|
||||
clonedSvg.setAttribute('xmlns', 'http://www.w3.org/2000/svg')
|
||||
clonedSvg.setAttribute('xmlns:xlink', 'http://www.w3.org/1999/xlink')
|
||||
clonedSvg.setAttribute('width', String(bounds.width))
|
||||
clonedSvg.setAttribute('height', String(bounds.height))
|
||||
clonedSvg.setAttribute('viewBox', `${bounds.x} ${bounds.y} ${bounds.width} ${bounds.height}`)
|
||||
clonedSvg.setAttribute('preserveAspectRatio', 'xMidYMid meet')
|
||||
|
||||
const bgRect = document.createElementNS('http://www.w3.org/2000/svg', 'rect')
|
||||
bgRect.setAttribute('x', String(bounds.x))
|
||||
bgRect.setAttribute('y', String(bounds.y))
|
||||
bgRect.setAttribute('width', String(bounds.width))
|
||||
bgRect.setAttribute('height', String(bounds.height))
|
||||
bgRect.setAttribute('fill', 'white')
|
||||
const firstG = clonedSvg.querySelector('g')
|
||||
clonedSvg.insertBefore(bgRect, firstG || clonedSvg.firstChild)
|
||||
|
||||
return { clonedSvg, ...bounds }
|
||||
}
|
||||
|
||||
async function exportSvgToPngBlob(svgEl: SVGSVGElement): Promise<Blob> {
|
||||
const { clonedSvg, width, height } = createExportSvg(svgEl)
|
||||
const svgData = new XMLSerializer().serializeToString(clonedSvg)
|
||||
const svgUrl = URL.createObjectURL(new Blob([svgData], { type: 'image/svg+xml;charset=utf-8' }))
|
||||
|
||||
try {
|
||||
const img = new Image()
|
||||
img.decoding = 'async'
|
||||
img.src = svgUrl
|
||||
await img.decode()
|
||||
|
||||
// 按导图内容尺寸和字号动态反推 PNG 倍率,而不是按预览容器或固定倍率导出。
|
||||
const fontScale = MIN_EXPORT_FONT_PX / getExportFontSize(svgEl)
|
||||
const widthScale = MIN_EXPORT_WIDTH / width
|
||||
const rawScale = Math.max(window.devicePixelRatio || 1, fontScale, widthScale)
|
||||
const sideLimitScale = Math.min(MAX_CANVAS_SIDE / width, MAX_CANVAS_SIDE / height)
|
||||
const pixelLimitScale = Math.sqrt(MAX_CANVAS_PIXELS / (width * height))
|
||||
const baseScale = Math.min(rawScale, MAX_EXPORT_SCALE, sideLimitScale, pixelLimitScale)
|
||||
const scale = Math.max(1, baseScale * WEB_EXPORT_SCALE_FACTOR)
|
||||
|
||||
let currentScale = scale
|
||||
let lastError: unknown
|
||||
while (currentScale >= 1) {
|
||||
try {
|
||||
const canvas = document.createElement('canvas')
|
||||
canvas.width = Math.ceil(width * currentScale)
|
||||
canvas.height = Math.ceil(height * currentScale)
|
||||
|
||||
const ctx = canvas.getContext('2d')
|
||||
if (!ctx) {
|
||||
throw new Error('无法获取Canvas上下文')
|
||||
}
|
||||
|
||||
ctx.fillStyle = '#FFFFFF'
|
||||
ctx.fillRect(0, 0, canvas.width, canvas.height)
|
||||
ctx.setTransform(currentScale, 0, 0, currentScale, 0, 0)
|
||||
ctx.drawImage(img, 0, 0, width, height)
|
||||
ctx.setTransform(1, 0, 0, 1, 0, 0)
|
||||
|
||||
return await canvasToBlob(canvas)
|
||||
} catch (error) {
|
||||
lastError = error
|
||||
currentScale = Math.floor(currentScale / 2)
|
||||
}
|
||||
}
|
||||
throw lastError || new Error('导出PNG失败')
|
||||
} finally {
|
||||
URL.revokeObjectURL(svgUrl)
|
||||
}
|
||||
}
|
||||
|
||||
export interface MarkmapEditorProps {
|
||||
/** 要渲染的 Markdown 文本 */
|
||||
value: string
|
||||
@@ -34,6 +199,13 @@ export default function MarkmapEditor({
|
||||
|
||||
// 用于跟踪是否处于全屏状态
|
||||
const [isFullscreen, setIsFullscreen] = useState(false)
|
||||
const [pngAction, setPngAction] = useState<'idle' | 'exporting' | 'copying'>('idle')
|
||||
const [pngMessage, setPngMessage] = useState('')
|
||||
|
||||
const showPngMessage = (message: string) => {
|
||||
setPngMessage(message)
|
||||
window.setTimeout(() => setPngMessage(''), 2500)
|
||||
}
|
||||
|
||||
// 监听全屏状态变化
|
||||
useEffect(() => {
|
||||
@@ -64,7 +236,7 @@ export default function MarkmapEditor({
|
||||
// 导出HTML思维导图
|
||||
const exportHtml = () => {
|
||||
try {
|
||||
const { root } = transformer.transform(value)
|
||||
const { root } = transformMindmap(value)
|
||||
const data = JSON.stringify(root)
|
||||
|
||||
// 创建HTML内容
|
||||
@@ -202,7 +374,7 @@ export default function MarkmapEditor({
|
||||
// 导出XMind格式思维导图
|
||||
const exportXMind = async () => {
|
||||
try {
|
||||
const { root } = transformer.transform(value);
|
||||
const { root } = transformMindmap(value);
|
||||
|
||||
// 生成唯一ID
|
||||
const generateId = () => Math.random().toString(36).substring(2, 15);
|
||||
@@ -311,100 +483,44 @@ export default function MarkmapEditor({
|
||||
try {
|
||||
if (!svgRef.current || !mmRef.current) return;
|
||||
|
||||
const svgEl = svgRef.current;
|
||||
const mm = mmRef.current;
|
||||
|
||||
// 先调用fit()确保显示完整的思维导图内容
|
||||
await mm.fit();
|
||||
// 等待渲染完成
|
||||
await new Promise(resolve => setTimeout(resolve, 100));
|
||||
|
||||
// 获取SVG实际尺寸
|
||||
const svgWidth = svgEl.width.baseVal.value || svgEl.clientWidth || 800;
|
||||
const svgHeight = svgEl.height.baseVal.value || svgEl.clientHeight || 600;
|
||||
|
||||
// 设置足够大的缩放比例以确保高清输出
|
||||
const scale = 3;
|
||||
|
||||
// 克隆SVG以避免修改原始SVG
|
||||
const clonedSvg = svgEl.cloneNode(true) as SVGSVGElement;
|
||||
|
||||
// 设置SVG的背景为白色
|
||||
const style = document.createElementNS('http://www.w3.org/2000/svg', 'style');
|
||||
style.textContent = 'svg { background-color: white; }';
|
||||
clonedSvg.insertBefore(style, clonedSvg.firstChild);
|
||||
|
||||
// 确保SVG有正确的命名空间
|
||||
clonedSvg.setAttribute('xmlns', 'http://www.w3.org/2000/svg');
|
||||
clonedSvg.setAttribute('width', svgWidth.toString());
|
||||
clonedSvg.setAttribute('height', svgHeight.toString());
|
||||
|
||||
// 将SVG转换为Data URI (避免使用Blob URL来解决跨域问题)
|
||||
const svgData = new XMLSerializer().serializeToString(clonedSvg);
|
||||
const svgBase64 = btoa(unescape(encodeURIComponent(svgData)));
|
||||
const dataUri = `data:image/svg+xml;base64,${svgBase64}`;
|
||||
|
||||
// 创建Canvas
|
||||
const canvas = document.createElement('canvas');
|
||||
canvas.width = svgWidth * scale;
|
||||
canvas.height = svgHeight * scale;
|
||||
|
||||
// 获取上下文并设置白色背景
|
||||
const ctx = canvas.getContext('2d');
|
||||
if (!ctx) {
|
||||
throw new Error('无法获取Canvas上下文');
|
||||
}
|
||||
|
||||
// 设置白色背景
|
||||
ctx.fillStyle = '#FFFFFF';
|
||||
ctx.fillRect(0, 0, canvas.width, canvas.height);
|
||||
|
||||
// 创建Image对象
|
||||
const img = new Image();
|
||||
|
||||
// 当图片加载完成后,在Canvas上绘制并导出
|
||||
img.onload = () => {
|
||||
try {
|
||||
// 应用缩放
|
||||
ctx.setTransform(scale, 0, 0, scale, 0, 0);
|
||||
|
||||
// 绘制SVG
|
||||
ctx.drawImage(img, 0, 0);
|
||||
|
||||
// 重置变换
|
||||
ctx.setTransform(1, 0, 0, 1, 0, 0);
|
||||
|
||||
// 将Canvas转换为PNG Blob
|
||||
canvas.toBlob((blob) => {
|
||||
if (blob) {
|
||||
// 创建下载链接
|
||||
const url = URL.createObjectURL(blob);
|
||||
const a = document.createElement('a');
|
||||
a.href = url;
|
||||
a.download = `${title || 'mindmap'}.png`;
|
||||
document.body.appendChild(a);
|
||||
a.click();
|
||||
document.body.removeChild(a);
|
||||
URL.revokeObjectURL(url);
|
||||
} else {
|
||||
console.error('无法创建Blob对象');
|
||||
}
|
||||
}, 'image/png');
|
||||
} catch (err) {
|
||||
console.error('Canvas处理失败:', err);
|
||||
}
|
||||
};
|
||||
|
||||
// 设置图片加载错误处理
|
||||
img.onerror = (error) => {
|
||||
console.error('导出PNG失败(图片加载错误):', error);
|
||||
};
|
||||
|
||||
// 开始加载SVG图像 (使用Data URI而不是Blob URL)
|
||||
img.src = dataUri;
|
||||
|
||||
setPngAction('exporting');
|
||||
setPngMessage('正在生成高清 PNG…');
|
||||
const blob = await exportSvgToPngBlob(svgRef.current);
|
||||
const url = URL.createObjectURL(blob);
|
||||
const a = document.createElement('a');
|
||||
a.href = url;
|
||||
a.download = `${title || 'mindmap'}.png`;
|
||||
document.body.appendChild(a);
|
||||
a.click();
|
||||
document.body.removeChild(a);
|
||||
URL.revokeObjectURL(url);
|
||||
showPngMessage('PNG 已开始下载');
|
||||
} catch (error) {
|
||||
console.error('导出PNG失败:', error);
|
||||
showPngMessage('导出 PNG 失败,请查看控制台');
|
||||
} finally {
|
||||
setPngAction('idle');
|
||||
}
|
||||
};
|
||||
|
||||
// 复制PNG思维导图
|
||||
const copyPng = async () => {
|
||||
try {
|
||||
if (!svgRef.current || !mmRef.current) return;
|
||||
|
||||
setPngAction('copying');
|
||||
setPngMessage('正在复制高清 PNG…');
|
||||
await navigator.clipboard.write([
|
||||
new ClipboardItem({
|
||||
'image/png': exportSvgToPngBlob(svgRef.current),
|
||||
}),
|
||||
]);
|
||||
showPngMessage('PNG 已复制');
|
||||
} catch (error) {
|
||||
console.error('复制PNG失败:', error);
|
||||
showPngMessage('复制 PNG 失败,请查看控制台');
|
||||
} finally {
|
||||
setPngAction('idle');
|
||||
}
|
||||
};
|
||||
|
||||
@@ -428,7 +544,7 @@ export default function MarkmapEditor({
|
||||
useEffect(() => {
|
||||
const mm = mmRef.current
|
||||
if (!mm) return
|
||||
const { root } = transformer.transform(value)
|
||||
const { root } = transformMindmap(value)
|
||||
mm.setData(root).then(() => mm.fit())
|
||||
}, [value])
|
||||
|
||||
@@ -459,8 +575,17 @@ export default function MarkmapEditor({
|
||||
onClick={exportPng}
|
||||
className="rounded p-1 hover:bg-gray-200"
|
||||
title="导出PNG图片"
|
||||
disabled={pngAction !== 'idle'}
|
||||
>
|
||||
🖼️
|
||||
{pngAction === 'exporting' ? '⏳' : '🖼️'}
|
||||
</button>
|
||||
<button
|
||||
onClick={copyPng}
|
||||
className="rounded p-1 hover:bg-gray-200"
|
||||
title="复制PNG图片"
|
||||
disabled={pngAction !== 'idle'}
|
||||
>
|
||||
{pngAction === 'copying' ? '⏳' : '📋'}
|
||||
</button>
|
||||
<button
|
||||
onClick={exportHtml}
|
||||
@@ -483,6 +608,11 @@ export default function MarkmapEditor({
|
||||
</button>
|
||||
)}
|
||||
</div>
|
||||
{pngMessage && (
|
||||
<div className="absolute top-11 right-2 z-20 rounded bg-white/95 px-2 py-1 text-xs text-gray-600 shadow">
|
||||
{pngMessage}
|
||||
</div>
|
||||
)}
|
||||
|
||||
{/* 如果需要编辑区,就自己加一个 <textarea> 并把 handleChange 绑上 */}
|
||||
{/* <textarea value={value} onChange={handleChange} className="mb-2 p-2 border rounded" /> */}
|
||||
|
||||
@@ -10,13 +10,16 @@ import {
|
||||
SelectValue,
|
||||
} from '@/components/ui/select'
|
||||
import { Alert, AlertDescription } from '@/components/ui/alert'
|
||||
import { AudioLines, AlertTriangle, CheckCircle2, Download, Loader2, Save, XCircle } from 'lucide-react'
|
||||
import { Input } from '@/components/ui/input'
|
||||
import { AudioLines, AlertTriangle, CheckCircle2, Download, Loader2, Save, XCircle, Plus, Trash2, Boxes } from 'lucide-react'
|
||||
import { toast } from 'react-hot-toast'
|
||||
import {
|
||||
getTranscriberConfig,
|
||||
updateTranscriberConfig,
|
||||
getModelsStatus,
|
||||
downloadModel,
|
||||
addWhisperModel,
|
||||
deleteWhisperModel,
|
||||
TranscriberConfig,
|
||||
ModelStatus,
|
||||
} from '@/services/transcriber'
|
||||
@@ -33,6 +36,19 @@ export default function Transcriber() {
|
||||
const [modelStatuses, setModelStatuses] = useState<ModelStatus[]>([])
|
||||
const [mlxModelStatuses, setMlxModelStatuses] = useState<ModelStatus[]>([])
|
||||
const [mlxAvailable, setMlxAvailable] = useState(false)
|
||||
// 自定义模型表单
|
||||
const [newModelName, setNewModelName] = useState('')
|
||||
const [newModelTarget, setNewModelTarget] = useState('')
|
||||
const [addingModel, setAddingModel] = useState(false)
|
||||
|
||||
// 重新拉取配置(不重置用户当前的选择),用于增删自定义模型后刷新下拉与列表
|
||||
const reloadConfig = useCallback(async () => {
|
||||
try {
|
||||
setConfig(await getTranscriberConfig())
|
||||
} catch {
|
||||
// 静默
|
||||
}
|
||||
}, [])
|
||||
|
||||
const fetchModelsStatus = useCallback(async () => {
|
||||
try {
|
||||
@@ -123,6 +139,41 @@ export default function Transcriber() {
|
||||
}
|
||||
}
|
||||
|
||||
const handleAddCustomModel = async () => {
|
||||
const name = newModelName.trim()
|
||||
const target = newModelTarget.trim()
|
||||
if (!name || !target) {
|
||||
toast.error('请填写模型名称和 HF repo_id / 本地路径')
|
||||
return
|
||||
}
|
||||
setAddingModel(true)
|
||||
try {
|
||||
await addWhisperModel({ name, target })
|
||||
toast.success(`已添加自定义模型 ${name}`)
|
||||
setNewModelName('')
|
||||
setNewModelTarget('')
|
||||
await reloadConfig()
|
||||
await fetchModelsStatus()
|
||||
} catch {
|
||||
// 后端的具体错误(如重名)已由请求拦截器 toast,这里不重复提示
|
||||
} finally {
|
||||
setAddingModel(false)
|
||||
}
|
||||
}
|
||||
|
||||
const handleDeleteCustomModel = async (name: string) => {
|
||||
try {
|
||||
await deleteWhisperModel(name)
|
||||
toast.success(`已删除自定义模型 ${name}`)
|
||||
// 删的正好是当前选中的,回退到 tiny,避免选中一个不存在的名称
|
||||
if (selectedModelSize === name) setSelectedModelSize('tiny')
|
||||
await reloadConfig()
|
||||
await fetchModelsStatus()
|
||||
} catch {
|
||||
// 拦截器已提示
|
||||
}
|
||||
}
|
||||
|
||||
if (loading) {
|
||||
return (
|
||||
<div className="flex h-64 items-center justify-center">
|
||||
@@ -272,6 +323,97 @@ export default function Transcriber() {
|
||||
</CardContent>
|
||||
</Card>
|
||||
)}
|
||||
|
||||
{/* 自定义 Whisper 模型(仅 fast-whisper:名称不符合内置 Systran 约定的模型在此登记映射) */}
|
||||
{selectedType === 'fast-whisper' && (
|
||||
<Card>
|
||||
<CardHeader>
|
||||
<CardTitle className="flex items-center gap-2 text-lg">
|
||||
<Boxes className="h-5 w-5" />
|
||||
自定义模型
|
||||
<span className="text-sm font-normal text-neutral-400">
|
||||
登记名称不符合内置约定的模型
|
||||
</span>
|
||||
</CardTitle>
|
||||
</CardHeader>
|
||||
<CardContent className="space-y-4">
|
||||
<Alert className="text-sm">
|
||||
<AlertDescription>
|
||||
填 <strong>HF repo_id</strong>(如{' '}
|
||||
<code className="rounded bg-neutral-100 px-1">Systran/faster-whisper-large-v3</code>
|
||||
,会自动下载)或<strong>本地模型目录</strong>(如{' '}
|
||||
<code className="rounded bg-neutral-100 px-1">/app/backend/models/my-whisper</code>
|
||||
,目录内需含 <code className="rounded bg-neutral-100 px-1">model.bin</code>,下载会跳过)。
|
||||
添加后即可在上方「模型大小」下拉中选用。Docker 部署请把模型目录挂载进容器(见 README 的{' '}
|
||||
<code className="rounded bg-neutral-100 px-1">models</code> 卷)。
|
||||
</AlertDescription>
|
||||
</Alert>
|
||||
|
||||
{config.whisper_custom_models &&
|
||||
Object.keys(config.whisper_custom_models).length > 0 ? (
|
||||
<div className="space-y-2">
|
||||
{Object.entries(config.whisper_custom_models).map(([name, target]) => {
|
||||
const status = modelStatuses.find(m => m.model_size === name)
|
||||
return (
|
||||
<div
|
||||
key={name}
|
||||
className="flex items-center justify-between gap-3 rounded-md border px-4 py-2.5"
|
||||
>
|
||||
<div className="min-w-0">
|
||||
<div className="flex items-center gap-2 font-medium">
|
||||
{name}
|
||||
{status?.downloaded && (
|
||||
<CheckCircle2 className="h-3.5 w-3.5 text-green-500" />
|
||||
)}
|
||||
{status?.downloading && (
|
||||
<Loader2 className="h-3.5 w-3.5 animate-spin text-neutral-400" />
|
||||
)}
|
||||
</div>
|
||||
<div className="truncate text-xs text-neutral-400" title={target}>
|
||||
{target}
|
||||
</div>
|
||||
</div>
|
||||
<Button
|
||||
size="sm"
|
||||
variant="ghost"
|
||||
className="text-red-500 hover:text-red-600"
|
||||
onClick={() => handleDeleteCustomModel(name)}
|
||||
>
|
||||
<Trash2 className="h-4 w-4" />
|
||||
</Button>
|
||||
</div>
|
||||
)
|
||||
})}
|
||||
</div>
|
||||
) : (
|
||||
<p className="text-sm text-neutral-400">还没有自定义模型</p>
|
||||
)}
|
||||
|
||||
<div className="flex flex-col gap-2 sm:flex-row sm:items-center">
|
||||
<Input
|
||||
placeholder="模型名称(自定义,如 my-large-v3)"
|
||||
value={newModelName}
|
||||
onChange={e => setNewModelName(e.target.value)}
|
||||
className="sm:max-w-[220px]"
|
||||
/>
|
||||
<Input
|
||||
placeholder="HF repo_id 或本地路径"
|
||||
value={newModelTarget}
|
||||
onChange={e => setNewModelTarget(e.target.value)}
|
||||
className="flex-1"
|
||||
/>
|
||||
<Button onClick={handleAddCustomModel} disabled={addingModel}>
|
||||
{addingModel ? (
|
||||
<Loader2 className="mr-1 h-4 w-4 animate-spin" />
|
||||
) : (
|
||||
<Plus className="mr-1 h-4 w-4" />
|
||||
)}
|
||||
添加
|
||||
</Button>
|
||||
</div>
|
||||
</CardContent>
|
||||
</Card>
|
||||
)}
|
||||
</div>
|
||||
)
|
||||
}
|
||||
|
||||
@@ -5,6 +5,10 @@ export interface TranscriberConfig {
|
||||
whisper_model_size: string
|
||||
available_types: { value: string; label: string }[]
|
||||
whisper_model_sizes: string[]
|
||||
/** 内置模型映射:size → HF repo_id */
|
||||
whisper_builtin_models?: Record<string, string>
|
||||
/** 用户自定义模型映射:名称 → HF repo_id 或本地路径 */
|
||||
whisper_custom_models?: Record<string, string>
|
||||
mlx_whisper_available: boolean
|
||||
}
|
||||
|
||||
@@ -41,3 +45,23 @@ export const downloadModel = async (data: {
|
||||
}) => {
|
||||
return await request.post('/transcriber_download', data)
|
||||
}
|
||||
|
||||
export interface WhisperModelsResponse {
|
||||
builtin: Record<string, string>
|
||||
custom: Record<string, string>
|
||||
}
|
||||
|
||||
/** 列出内置 + 自定义 whisper 模型映射 */
|
||||
export const listWhisperModels = async (): Promise<WhisperModelsResponse> => {
|
||||
return await request.get('/whisper_models')
|
||||
}
|
||||
|
||||
/** 新增自定义模型映射(名称 → HF repo_id 或本地路径) */
|
||||
export const addWhisperModel = async (data: { name: string; target: string }) => {
|
||||
return await request.post('/whisper_models', data)
|
||||
}
|
||||
|
||||
/** 删除自定义模型映射(不会删除已下载的模型文件) */
|
||||
export const deleteWhisperModel = async (name: string) => {
|
||||
return await request.delete(`/whisper_models/${encodeURIComponent(name)}`)
|
||||
}
|
||||
|
||||
21
CHANGELOG.md
21
CHANGELOG.md
@@ -2,6 +2,27 @@
|
||||
|
||||
本项目所有重要变更记录于此。格式参考 [Keep a Changelog](https://keepachangelog.com/zh-CN/1.1.0/),遵循 [语义化版本](https://semver.org/lang/zh-CN/)。
|
||||
|
||||
## [2.4.0] - 2026-06-07
|
||||
|
||||
### Added
|
||||
|
||||
- **可配置 Whisper 模型**:转写设置支持自定义 Whisper 模型与名称映射,可指定自定义 HuggingFace repo 或本地路径(新增 `backend/app/transcriber/whisper_models.py` + 测试),转写设置页可选择 / 配置模型。
|
||||
- **关注公众号获取交流群**:关于页群二维码改为公众号二维码,关注公众号后回复「交流群」即可获取最新群二维码,避免群码过期失效。
|
||||
|
||||
### Fixed
|
||||
|
||||
- **浏览器扩展**:修复标题显示异常,优化脑图(Markmap)导出。
|
||||
|
||||
### Docs
|
||||
|
||||
- **GPU/CUDA 部署**:README 补全 GPU/CUDA 部署说明。
|
||||
|
||||
## [2.3.4] - 2026-05-27
|
||||
|
||||
### Added
|
||||
|
||||
- **一对一搭建服务二维码**:新增「BiliNote AI 笔记系统一对一搭建服务」二维码(README + 关于页),扫码加微信、备注「搭建服务」即可咨询。
|
||||
|
||||
## [2.3.3] - 2026-05-22
|
||||
|
||||
### Fixed
|
||||
|
||||
44
README.md
44
README.md
@@ -3,7 +3,7 @@
|
||||
<p align="center">
|
||||
<img src="./doc/icon.svg" alt="BiliNote Banner" width="50" height="50" />
|
||||
</p>
|
||||
<h1 align="center" > BiliNote v2.3.4</h1>
|
||||
<h1 align="center" > BiliNote v2.4.0</h1>
|
||||
</div>
|
||||
|
||||
<p align="center"><i>AI 视频笔记生成工具 让 AI 为你的视频做笔记</i></p>
|
||||
@@ -303,10 +303,42 @@ sudo apt install ffmpeg
|
||||
>
|
||||
> Docker 部署已内置 FFmpeg,无需额外安装。
|
||||
|
||||
### 🚀 CUDA 加速(可选)
|
||||
若你希望更快地执行音频转写任务,可使用具备 NVIDIA GPU 的机器,并启用 fast-whisper + CUDA 加速版本:
|
||||
### 🚀 CUDA / GPU 加速(可选)
|
||||
|
||||
具体 `fast-whisper` 配置方法,请参考:[fast-whisper 项目地址](http://github.com/SYSTRAN/faster-whisper#requirements)
|
||||
本地 **Faster Whisper** 转写可用 NVIDIA GPU 加速(在线引擎 Groq / 必剪 / 快手 与 GPU 无关)。仓库已自带 GPU 镜像与编排,**无需改代码、无需手动配置 device**——后端会自动检测 CUDA,可用就走 GPU,否则回退 CPU。
|
||||
|
||||
**1. 宿主机前提**
|
||||
|
||||
- NVIDIA 显卡 + 较新驱动(CUDA ≥ 12.4),宿主机 `nvidia-smi` 能正常输出;
|
||||
- 安装 [NVIDIA Container Toolkit](https://docs.nvidia.com/datacenter/cloud-native/container-toolkit/latest/install-guide.html)(最易漏的一步,没它 Docker 进不去 GPU)。装完验证:
|
||||
```bash
|
||||
docker run --rm --gpus all nvidia/cuda:12.4.1-base-ubuntu22.04 nvidia-smi
|
||||
```
|
||||
能列出显卡即 OK。
|
||||
|
||||
**2. 切换到 GPU 编排**(在源码目录里)
|
||||
|
||||
CPU 与 GPU 两套 compose 用了相同的容器名,先停掉当前栈再起 GPU 栈:
|
||||
|
||||
```bash
|
||||
docker-compose down # 停掉当前(CPU)栈
|
||||
docker-compose -f docker-compose.gpu.yml up --build -d # 用 GPU 栈重建
|
||||
```
|
||||
|
||||
- GPU 栈用 `backend/Dockerfile.gpu`(CUDA 12.4.1 + cuDNN 基础镜像,并额外装 torch 用于 CUDA 检测),compose 已声明 `deploy...devices: nvidia` 自动透传 GPU。
|
||||
- **数据不丢**:两套 compose 都把 `./backend` 整目录绑挂进容器,数据库 / 配置 / 已下载模型都保留。
|
||||
- 首次构建较大较慢(CUDA 基础镜像数 GB + torch),耐心等。
|
||||
|
||||
**3. 启用并确认**
|
||||
|
||||
- 「设置 → 音频转写配置」转写引擎选 **Faster Whisper(本地)**,GPU 下可放心选大模型(如 `large-v3`)。
|
||||
- 确认真的走了 GPU:`docker logs bilinote-backend | grep -i cuda` 看到 `CUDA 可用,使用 GPU`;或转写时宿主机 `nvidia-smi` 能看到 python 进程占显存。
|
||||
|
||||
**国内镜像**:GPU compose 支持 `BASE_REGISTRY` / `APT_MIRROR` / `PIP_INDEX` 这几个 build-arg(注意 `BASE_REGISTRY` 选的源必须支持 `nvidia/cuda` 命名空间,否则拉不到 CUDA 基础镜像)。
|
||||
|
||||
**起来了但没走 GPU?** 依次排查:① 宿主机 `nvidia-smi` 是否正常 → ② NVIDIA Container Toolkit 是否装好(上面 `--gpus all` 测试是否通过)→ ③ `docker logs bilinote-backend` 是否有 CUDA / cuDNN 报错(驱动 CUDA 版本需 ≥ 12.4)。
|
||||
|
||||
`fast-whisper` 本身的 GPU 依赖说明可参考:[faster-whisper 项目](https://github.com/SYSTRAN/faster-whisper#requirements)
|
||||
|
||||
### 🐳 使用 Docker 一键部署
|
||||
|
||||
@@ -338,8 +370,8 @@ docker run -d -p 80:80 \
|
||||
# 标准部署
|
||||
docker-compose up -d
|
||||
|
||||
# GPU 加速部署(需要 NVIDIA GPU)
|
||||
docker-compose -f docker-compose.gpu.yml up -d
|
||||
# GPU 加速部署(需要 NVIDIA GPU + NVIDIA Container Toolkit,详见上方「CUDA / GPU 加速」)
|
||||
docker-compose -f docker-compose.gpu.yml up --build -d
|
||||
```
|
||||
|
||||
## 🧠 TODO
|
||||
|
||||
@@ -61,16 +61,53 @@ WHISPER_MODEL_SIZES = ["tiny", "base", "small", "medium", "large-v3", "large-v3-
|
||||
@router.get("/transcriber_config")
|
||||
def get_transcriber_config():
|
||||
from app.transcriber.transcriber_provider import MLX_WHISPER_AVAILABLE
|
||||
from app.transcriber.whisper_models import get_registry, BUILTIN_WHISPER_MODELS
|
||||
|
||||
registry = get_registry()
|
||||
config = transcriber_config_manager.get_config()
|
||||
return R.success(data={
|
||||
**config,
|
||||
"available_types": AVAILABLE_TRANSCRIBER_TYPES,
|
||||
"whisper_model_sizes": WHISPER_MODEL_SIZES,
|
||||
# 内置可见档位 + 用户自定义模型,供前端下拉
|
||||
"whisper_model_sizes": registry.visible_model_names(),
|
||||
"whisper_builtin_models": BUILTIN_WHISPER_MODELS,
|
||||
"whisper_custom_models": registry.get_custom_models(),
|
||||
"mlx_whisper_available": MLX_WHISPER_AVAILABLE,
|
||||
})
|
||||
|
||||
|
||||
class WhisperCustomModelRequest(BaseModel):
|
||||
name: str
|
||||
target: str # HF repo_id(如 Systran/faster-whisper-large-v3)或本地模型目录路径
|
||||
|
||||
|
||||
@router.get("/whisper_models")
|
||||
def list_whisper_models():
|
||||
"""列出内置 + 用户自定义的 whisper 模型映射。"""
|
||||
from app.transcriber.whisper_models import get_registry, BUILTIN_WHISPER_MODELS
|
||||
reg = get_registry()
|
||||
return R.success(data={"builtin": BUILTIN_WHISPER_MODELS, "custom": reg.get_custom_models()})
|
||||
|
||||
|
||||
@router.post("/whisper_models")
|
||||
def add_whisper_model(data: WhisperCustomModelRequest):
|
||||
"""新增自定义 whisper 模型映射(名称 → HF repo_id 或本地路径)。"""
|
||||
from app.transcriber.whisper_models import get_registry
|
||||
try:
|
||||
custom = get_registry().add_custom_model(data.name, data.target)
|
||||
except ValueError as e:
|
||||
return R.error(msg=str(e))
|
||||
return R.success(data={"custom": custom}, msg="已添加自定义模型")
|
||||
|
||||
|
||||
@router.delete("/whisper_models/{name}")
|
||||
def delete_whisper_model(name: str):
|
||||
"""删除自定义 whisper 模型映射(不会删除已下载的模型文件)。"""
|
||||
from app.transcriber.whisper_models import get_registry
|
||||
custom = get_registry().remove_custom_model(name)
|
||||
return R.success(data={"custom": custom}, msg="已删除自定义模型")
|
||||
|
||||
|
||||
@router.post("/transcriber_config")
|
||||
def update_transcriber_config(data: TranscriberConfigRequest):
|
||||
config = transcriber_config_manager.update_config(
|
||||
@@ -119,14 +156,27 @@ _downloading: dict[str, str] = {} # model_size -> status ("downloading" | "done
|
||||
def _check_whisper_model_exists(model_size: str, subdir: str = "whisper") -> bool:
|
||||
"""检查指定 whisper 模型是否已下载完整到本地。
|
||||
|
||||
faster-whisper 把模型缓存在 HF cache 布局下:
|
||||
<model_dir>/models--Systran--faster-whisper-{size}/snapshots/<hash>/model.bin
|
||||
必须能在某个 snapshot 目录里找到 model.bin 才算完成。
|
||||
(历史 modelscope 布局 <model_dir>/whisper-{size}/model.bin 也兼容识别。)
|
||||
先把模型名 resolve 成可加载标识,再按类型判定:
|
||||
- 本地路径模型 → 直接看该目录下有没有 model.bin
|
||||
- HF repo_id → 看 HF cache 布局
|
||||
<model_dir>/models--{org}--{name}/snapshots/<hash>/model.bin
|
||||
(历史 modelscope 布局 <model_dir>/whisper-{size}/model.bin 也兼容识别)
|
||||
"""
|
||||
from app.transcriber.whisper_models import (
|
||||
resolve_whisper_model,
|
||||
is_local_target,
|
||||
hf_cache_dirname,
|
||||
)
|
||||
try:
|
||||
target = resolve_whisper_model(model_size)
|
||||
except Exception:
|
||||
return False
|
||||
if is_local_target(target):
|
||||
return (Path(target) / "model.bin").exists()
|
||||
|
||||
model_dir = Path(get_model_dir(subdir))
|
||||
# HF cache 布局
|
||||
hf_repo_dir = model_dir / f"models--Systran--faster-whisper-{model_size}" / "snapshots"
|
||||
# HF cache 布局(适配任意 org/repo,不再写死 Systran)
|
||||
hf_repo_dir = model_dir / hf_cache_dirname(target) / "snapshots"
|
||||
if hf_repo_dir.exists():
|
||||
for snapshot in hf_repo_dir.iterdir():
|
||||
if (snapshot / "model.bin").exists():
|
||||
@@ -157,9 +207,10 @@ def _check_mlx_whisper_model_exists(model_size: str) -> bool:
|
||||
|
||||
@router.get("/transcriber_models_status")
|
||||
def get_transcriber_models_status():
|
||||
"""返回所有 whisper 模型的下载状态。"""
|
||||
"""返回所有 whisper 模型的下载状态(含用户自定义模型)。"""
|
||||
from app.transcriber.whisper_models import get_registry
|
||||
statuses = []
|
||||
for size in WHISPER_MODEL_SIZES:
|
||||
for size in get_registry().visible_model_names():
|
||||
downloaded = _check_whisper_model_exists(size, "whisper")
|
||||
download_status = _downloading.get(size)
|
||||
statuses.append({
|
||||
@@ -198,13 +249,15 @@ class ModelDownloadRequest(BaseModel):
|
||||
|
||||
|
||||
def _do_download_whisper(model_size: str):
|
||||
"""后台下载 faster-whisper 模型。
|
||||
"""后台下载 faster-whisper 模型(支持内置 size / 自定义 repo_id / 本地路径)。
|
||||
|
||||
直接走 huggingface_hub.snapshot_download,把模型放到 HF cache 布局里——
|
||||
这样 faster-whisper 加载时(WhisperModel(model_size_or_path=size_name,
|
||||
download_root=model_dir))能直接命中缓存,跟加载路径完全对齐。
|
||||
模型名先 resolve:
|
||||
- 本地路径模型:无需下载,目录里有 model.bin 即 done,否则 failed;
|
||||
- HF repo_id:snapshot_download 到 HF cache 布局(cache_dir=model_dir),
|
||||
与加载逻辑 WhisperModel(download_root=model_dir) 完全对齐。
|
||||
"""
|
||||
from huggingface_hub import snapshot_download
|
||||
from app.transcriber.whisper_models import resolve_whisper_model, is_local_target
|
||||
|
||||
try:
|
||||
_downloading[model_size] = "downloading"
|
||||
@@ -214,12 +267,21 @@ def _do_download_whisper(model_size: str):
|
||||
if _check_whisper_model_exists(model_size, "whisper"):
|
||||
_downloading[model_size] = "done"
|
||||
return
|
||||
repo_id = f"Systran/faster-whisper-{model_size}"
|
||||
logger.info(f"开始下载 whisper 模型: {repo_id}")
|
||||
|
||||
target = resolve_whisper_model(model_size)
|
||||
if is_local_target(target):
|
||||
# 本地模型不下载,只校验 model.bin 是否就位
|
||||
ok = (Path(target) / "model.bin").exists()
|
||||
_downloading[model_size] = "done" if ok else "failed"
|
||||
if not ok:
|
||||
logger.warning(f"本地模型 {model_size} 路径 {target} 下没有 model.bin,无法使用")
|
||||
return
|
||||
|
||||
logger.info(f"开始下载 whisper 模型: {model_size} ← {target}")
|
||||
# 跟 faster-whisper utils.py 用同样的 allow_patterns,避免多下无关文件;
|
||||
# 不传 local_dir 让它走 HF 默认 cache 布局(与加载逻辑对齐)
|
||||
snapshot_download(
|
||||
repo_id,
|
||||
target,
|
||||
cache_dir=model_dir,
|
||||
allow_patterns=[
|
||||
"config.json",
|
||||
@@ -268,11 +330,11 @@ def _do_download_mlx_whisper(model_size: str):
|
||||
|
||||
@router.post("/transcriber_download")
|
||||
def download_transcriber_model(data: ModelDownloadRequest, background_tasks: BackgroundTasks):
|
||||
"""触发后台下载指定的 whisper 模型。"""
|
||||
if data.model_size not in WHISPER_MODEL_SIZES:
|
||||
return R.error(msg=f"不支持的模型大小: {data.model_size}")
|
||||
|
||||
"""触发后台下载指定的 whisper 模型(fast-whisper 支持内置档位 + 自定义模型)。"""
|
||||
if data.transcriber_type == "mlx-whisper":
|
||||
# mlx 只认内置档位(mlx-community 的固定映射)
|
||||
if data.model_size not in WHISPER_MODEL_SIZES:
|
||||
return R.error(msg=f"MLX 不支持的模型大小: {data.model_size}")
|
||||
if platform.system() != "Darwin":
|
||||
return R.error(msg="MLX Whisper 仅支持 macOS")
|
||||
key = f"mlx-{data.model_size}"
|
||||
@@ -280,6 +342,10 @@ def download_transcriber_model(data: ModelDownloadRequest, background_tasks: Bac
|
||||
return R.success(msg="模型正在下载中")
|
||||
background_tasks.add_task(_do_download_mlx_whisper, data.model_size)
|
||||
else:
|
||||
# fast-whisper:内置档位 / 自定义 repo_id / 本地路径都允许
|
||||
from app.transcriber.whisper_models import get_registry
|
||||
if not get_registry().is_known(data.model_size):
|
||||
return R.error(msg=f"不支持的模型: {data.model_size}(请先在自定义模型中登记)")
|
||||
if _downloading.get(data.model_size) == "downloading":
|
||||
return R.success(msg="模型正在下载中")
|
||||
background_tasks.add_task(_do_download_whisper, data.model_size)
|
||||
|
||||
@@ -3,6 +3,11 @@ from faster_whisper import WhisperModel
|
||||
from app.decorators.timeit import timeit
|
||||
from app.models.transcriber_model import TranscriptSegment, TranscriptResult
|
||||
from app.transcriber.base import Transcriber
|
||||
from app.transcriber.whisper_models import (
|
||||
resolve_whisper_model,
|
||||
is_local_target,
|
||||
hf_cache_dirname,
|
||||
)
|
||||
from app.utils.env_checker import is_cuda_available, is_torch_installed
|
||||
from app.utils.logger import get_logger
|
||||
from app.utils.path_helper import get_model_dir
|
||||
@@ -55,8 +60,12 @@ class WhisperTranscriber(Transcriber):
|
||||
self.model = self._build_model(model_size, model_dir)
|
||||
|
||||
def _build_model(self, model_size: str, model_dir: str) -> WhisperModel:
|
||||
# resolve 把模型名映射成可加载标识:内置 size→Systran repo_id、自定义映射、
|
||||
# 直通的 repo_id 或本地路径。faster-whisper 对本地目录走 os.path.isdir 分支,
|
||||
# 对 repo_id 走 download_model(cache_dir=download_root),两者都吃 model_size_or_path。
|
||||
target = resolve_whisper_model(model_size)
|
||||
return WhisperModel(
|
||||
model_size_or_path=model_size, # 传 size name,让 faster-whisper 自己映射到 Systran/faster-whisper-*
|
||||
model_size_or_path=target,
|
||||
device=self.device,
|
||||
compute_type=self.compute_type,
|
||||
download_root=model_dir,
|
||||
@@ -64,14 +73,23 @@ class WhisperTranscriber(Transcriber):
|
||||
|
||||
@staticmethod
|
||||
def _purge_cache(model_dir: str, model_size: str) -> None:
|
||||
"""删掉 HF cache 里这个 size 对应的 snapshot 目录,强制下次重新下载。
|
||||
"""加载失败时清掉对应 HF cache 的 snapshot 目录,强制下次重下。
|
||||
|
||||
HF cache 布局:<model_dir>/models--Systran--faster-whisper-{size}/
|
||||
没找到也不报错——可能用户改了 endpoint 或者 cache 布局变了。
|
||||
关键:本地路径模型**绝不删**——那是用户自己的文件,删了就是数据丢失;
|
||||
只清 HF cache 布局 <model_dir>/models--{org}--{name}/(含历史 modelscope 目录)。
|
||||
"""
|
||||
try:
|
||||
target = resolve_whisper_model(model_size)
|
||||
except Exception:
|
||||
target = model_size
|
||||
if is_local_target(target):
|
||||
logger.warning(
|
||||
f"模型 {model_size} 指向本地路径 {target},加载失败不清理用户文件,请检查该目录是否完整"
|
||||
)
|
||||
return
|
||||
candidates = [
|
||||
Path(model_dir) / f"models--Systran--faster-whisper-{model_size}",
|
||||
Path(model_dir) / f"whisper-{model_size}", # 历史 modelscope 目录,顺手清掉
|
||||
Path(model_dir) / hf_cache_dirname(target), # HF cache: models--org--name
|
||||
Path(model_dir) / f"whisper-{model_size}", # 历史 modelscope 目录,顺手清掉
|
||||
]
|
||||
for path in candidates:
|
||||
if path.exists():
|
||||
|
||||
156
backend/app/transcriber/whisper_models.py
Normal file
156
backend/app/transcriber/whisper_models.py
Normal file
@@ -0,0 +1,156 @@
|
||||
"""fast-whisper 模型名 → 可加载标识(HF repo_id 或本地路径)的映射注册表。
|
||||
|
||||
背景:faster-whisper 加载时 `WhisperModel(model_size_or_path=...)` 接受三种入参——
|
||||
内置 size 名、HF repo_id(含 "/")、或本地模型目录(`os.path.isdir` 命中则直接用)。
|
||||
此前后端把「size → Systran/faster-whisper-{size}」这层约定**隐式**散落在加载/下载/
|
||||
检测三处,用户想用命名不符合该约定的模型(比如社区微调版、或自己下到本地的模型)就接不上。
|
||||
|
||||
本模块把映射**显式化 + 可配置**(对齐 mlx_whisper_transcriber.MLX_MODEL_MAP 的模式):
|
||||
- 内置:size → Systran/faster-whisper-{size}
|
||||
- 自定义:用户在 config/whisper_models.json 登记 {名称: "<repo_id 或本地路径>"}
|
||||
(JSON 持久化;Docker 下随 config 卷持久化)
|
||||
|
||||
解析优先级(resolve):自定义 > 内置 > 直通(含 "/" 当 repo_id;已存在目录当本地路径)。
|
||||
加载 / 下载 / 完整性检测三处统一调用 resolve,路径不再各写各的。
|
||||
"""
|
||||
import json
|
||||
import os
|
||||
from pathlib import Path
|
||||
from typing import Dict, List
|
||||
|
||||
from app.utils.logger import get_logger
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
# 内置模型:size → faster-whisper 兼容的 HF repo_id(CTranslate2 转换版,Systran 官方维护)。
|
||||
BUILTIN_WHISPER_MODELS: Dict[str, str] = {
|
||||
"tiny": "Systran/faster-whisper-tiny",
|
||||
"base": "Systran/faster-whisper-base",
|
||||
"small": "Systran/faster-whisper-small",
|
||||
"medium": "Systran/faster-whisper-medium",
|
||||
"large-v1": "Systran/faster-whisper-large-v1",
|
||||
"large-v2": "Systran/faster-whisper-large-v2",
|
||||
"large-v3": "Systran/faster-whisper-large-v3",
|
||||
"large-v3-turbo": "Systran/faster-whisper-large-v3-turbo",
|
||||
}
|
||||
|
||||
# 前端下拉默认展示的内置档位(保持与历史 WHISPER_MODEL_SIZES 一致,不把 8 个全列出来)
|
||||
DEFAULT_VISIBLE_BUILTINS: List[str] = ["tiny", "base", "small", "medium", "large-v3", "large-v3-turbo"]
|
||||
|
||||
|
||||
def is_local_target(target: str) -> bool:
|
||||
"""判断解析出的 target 是本地路径而非 HF repo_id。
|
||||
|
||||
HF repo_id 形如 'Org/Name'(恰一个斜杠、无前导斜杠、非已存在目录)。
|
||||
本地路径:绝对路径 / 以 . 或 ~ 开头 / 已存在的目录。
|
||||
"""
|
||||
if not target:
|
||||
return False
|
||||
if os.path.isabs(target) or target.startswith(".") or target.startswith("~"):
|
||||
return True
|
||||
return os.path.isdir(target)
|
||||
|
||||
|
||||
def hf_cache_dirname(repo_id: str) -> str:
|
||||
"""huggingface_hub snapshot 的本地缓存目录名:Org/Name → models--Org--Name。"""
|
||||
return "models--" + repo_id.replace("/", "--")
|
||||
|
||||
|
||||
class WhisperModelRegistry:
|
||||
"""内置 + 用户自定义的 whisper 模型映射,自定义部分持久化到 JSON。"""
|
||||
|
||||
def __init__(self, filepath: str = "config/whisper_models.json"):
|
||||
self.path = Path(filepath)
|
||||
self.path.parent.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
# ---- 持久化 ----
|
||||
def _read_custom(self) -> Dict[str, str]:
|
||||
if not self.path.exists():
|
||||
return {}
|
||||
try:
|
||||
with self.path.open("r", encoding="utf-8") as f:
|
||||
data = json.load(f) or {}
|
||||
except Exception as e:
|
||||
logger.warning(f"读取自定义 whisper 模型配置失败,按空处理: {e}")
|
||||
return {}
|
||||
out: Dict[str, str] = {}
|
||||
for name, val in data.items():
|
||||
if isinstance(val, str) and val.strip():
|
||||
out[name] = val.strip()
|
||||
elif isinstance(val, dict) and isinstance(val.get("target"), str):
|
||||
out[name] = val["target"].strip()
|
||||
return out
|
||||
|
||||
def _write_custom(self, data: Dict[str, str]) -> None:
|
||||
with self.path.open("w", encoding="utf-8") as f:
|
||||
json.dump(data, f, ensure_ascii=False, indent=2)
|
||||
|
||||
# ---- 查询 ----
|
||||
def get_custom_models(self) -> Dict[str, str]:
|
||||
return self._read_custom()
|
||||
|
||||
def visible_model_names(self) -> List[str]:
|
||||
"""给前端下拉 / 下载状态用:默认可见内置档位 + 全部自定义名称。"""
|
||||
names = list(DEFAULT_VISIBLE_BUILTINS)
|
||||
for name in self._read_custom():
|
||||
if name not in names:
|
||||
names.append(name)
|
||||
return names
|
||||
|
||||
def is_known(self, name: str) -> bool:
|
||||
try:
|
||||
self.resolve(name)
|
||||
return True
|
||||
except ValueError:
|
||||
return False
|
||||
|
||||
def resolve(self, name: str) -> str:
|
||||
"""模型名 → 可加载标识(HF repo_id 或本地路径)。
|
||||
|
||||
优先级:自定义映射 > 内置映射 > 直通(含 "/" 的 repo_id / 已存在的本地目录)。
|
||||
无法识别时抛 ValueError。
|
||||
"""
|
||||
name = (name or "").strip()
|
||||
custom = self._read_custom()
|
||||
if name in custom:
|
||||
return custom[name]
|
||||
if name in BUILTIN_WHISPER_MODELS:
|
||||
return BUILTIN_WHISPER_MODELS[name]
|
||||
# 直通:用户直接把 repo_id(含 "/")或本地已存在目录当 model_size 传进来
|
||||
if "/" in name or os.path.isdir(name):
|
||||
return name
|
||||
raise ValueError(
|
||||
f"未知 whisper 模型 '{name}'。内置可选: {', '.join(BUILTIN_WHISPER_MODELS)};"
|
||||
"或在「音频转写配置」添加自定义模型(HF repo_id 或本地路径)。"
|
||||
)
|
||||
|
||||
# ---- 增删 ----
|
||||
def add_custom_model(self, name: str, target: str) -> Dict[str, str]:
|
||||
name = (name or "").strip()
|
||||
target = (target or "").strip()
|
||||
if not name or not target:
|
||||
raise ValueError("模型名称与目标(HF repo_id 或本地路径)都不能为空")
|
||||
if name in BUILTIN_WHISPER_MODELS:
|
||||
raise ValueError(f"'{name}' 与内置模型重名,请换一个名称")
|
||||
data = self._read_custom()
|
||||
data[name] = target
|
||||
self._write_custom(data)
|
||||
return data
|
||||
|
||||
def remove_custom_model(self, name: str) -> Dict[str, str]:
|
||||
data = self._read_custom()
|
||||
data.pop((name or "").strip(), None)
|
||||
self._write_custom(data)
|
||||
return data
|
||||
|
||||
|
||||
# 模块级单例
|
||||
_registry = WhisperModelRegistry()
|
||||
|
||||
|
||||
def get_registry() -> WhisperModelRegistry:
|
||||
return _registry
|
||||
|
||||
|
||||
def resolve_whisper_model(name: str) -> str:
|
||||
return _registry.resolve(name)
|
||||
132
backend/tests/test_whisper_models.py
Normal file
132
backend/tests/test_whisper_models.py
Normal file
@@ -0,0 +1,132 @@
|
||||
"""Unit tests for app.transcriber.whisper_models(whisper 模型名→标识 的映射注册表)。
|
||||
|
||||
直接按文件路径加载被测模块,并桩掉 app.utils.logger,避免触发 app/__init__.py
|
||||
(会 import faster_whisper / ctranslate2 等重依赖),使本测试无需安装转写依赖即可运行。
|
||||
"""
|
||||
import importlib.util
|
||||
import logging
|
||||
import os
|
||||
import pathlib
|
||||
import sys
|
||||
import tempfile
|
||||
import types
|
||||
import unittest
|
||||
|
||||
ROOT = pathlib.Path(__file__).resolve().parents[1]
|
||||
MODULE_PATH = ROOT / "app" / "transcriber" / "whisper_models.py"
|
||||
|
||||
|
||||
def _load_module():
|
||||
if "app" not in sys.modules:
|
||||
app_pkg = types.ModuleType("app")
|
||||
app_pkg.__path__ = [] # 标记为 package
|
||||
sys.modules["app"] = app_pkg
|
||||
if "app.utils" not in sys.modules:
|
||||
utils_pkg = types.ModuleType("app.utils")
|
||||
utils_pkg.__path__ = []
|
||||
sys.modules["app.utils"] = utils_pkg
|
||||
if "app.utils.logger" not in sys.modules:
|
||||
logger_mod = types.ModuleType("app.utils.logger")
|
||||
logger_mod.get_logger = lambda name=None: logging.getLogger(name or "test")
|
||||
sys.modules["app.utils.logger"] = logger_mod
|
||||
spec = importlib.util.spec_from_file_location("whisper_models_under_test", MODULE_PATH)
|
||||
assert spec and spec.loader
|
||||
mod = importlib.util.module_from_spec(spec)
|
||||
spec.loader.exec_module(mod)
|
||||
return mod
|
||||
|
||||
|
||||
wm = _load_module()
|
||||
|
||||
|
||||
class TestResolve(unittest.TestCase):
|
||||
def setUp(self):
|
||||
self.tmp = tempfile.TemporaryDirectory()
|
||||
self.cfg = os.path.join(self.tmp.name, "whisper_models.json")
|
||||
self.reg = wm.WhisperModelRegistry(filepath=self.cfg)
|
||||
|
||||
def tearDown(self):
|
||||
self.tmp.cleanup()
|
||||
|
||||
def test_builtin_resolves_to_systran(self):
|
||||
self.assertEqual(self.reg.resolve("tiny"), "Systran/faster-whisper-tiny")
|
||||
self.assertEqual(self.reg.resolve("large-v3-turbo"), "Systran/faster-whisper-large-v3-turbo")
|
||||
|
||||
def test_passthrough_repo_id(self):
|
||||
# 用户直接把 HF repo_id 当 model_size 传进来(含 "/")
|
||||
self.assertEqual(self.reg.resolve("SomeOrg/my-whisper-ct2"), "SomeOrg/my-whisper-ct2")
|
||||
|
||||
def test_unknown_raises(self):
|
||||
with self.assertRaises(ValueError):
|
||||
self.reg.resolve("definitely-not-a-model")
|
||||
|
||||
def test_custom_overrides_and_persists(self):
|
||||
self.reg.add_custom_model("myhf", "someorg/whisper-ct2")
|
||||
self.assertEqual(self.reg.resolve("myhf"), "someorg/whisper-ct2")
|
||||
# 新实例读同一文件 → 确认持久化(Docker 下随 config 卷保留)
|
||||
reg2 = wm.WhisperModelRegistry(filepath=self.cfg)
|
||||
self.assertEqual(reg2.resolve("myhf"), "someorg/whisper-ct2")
|
||||
|
||||
def test_custom_can_override_builtin_key_resolution(self):
|
||||
# 自定义优先级高于内置:把 "tiny" 强行指到别的 repo(resolve 层允许;add 层禁止重名)
|
||||
self.reg._write_custom({"tiny": "Other/tiny-ct2"})
|
||||
self.assertEqual(self.reg.resolve("tiny"), "Other/tiny-ct2")
|
||||
|
||||
def test_local_path_resolution_and_detection(self):
|
||||
model_dir = os.path.join(self.tmp.name, "mymodel")
|
||||
os.makedirs(model_dir)
|
||||
self.reg.add_custom_model("local1", model_dir)
|
||||
self.assertEqual(self.reg.resolve("local1"), model_dir)
|
||||
self.assertTrue(wm.is_local_target(self.reg.resolve("local1")))
|
||||
|
||||
def test_bare_existing_dir_passthrough(self):
|
||||
# 没登记、但直接传一个已存在目录 → 直通为本地路径
|
||||
model_dir = os.path.join(self.tmp.name, "bare")
|
||||
os.makedirs(model_dir)
|
||||
self.assertEqual(self.reg.resolve(model_dir), model_dir)
|
||||
|
||||
def test_add_rejects_builtin_collision_and_empty(self):
|
||||
with self.assertRaises(ValueError):
|
||||
self.reg.add_custom_model("tiny", "x/y") # 与内置重名
|
||||
with self.assertRaises(ValueError):
|
||||
self.reg.add_custom_model("", "x/y")
|
||||
with self.assertRaises(ValueError):
|
||||
self.reg.add_custom_model("ok", "")
|
||||
|
||||
def test_remove(self):
|
||||
self.reg.add_custom_model("tmpm", "a/b")
|
||||
self.assertIn("tmpm", self.reg.get_custom_models())
|
||||
self.reg.remove_custom_model("tmpm")
|
||||
self.assertNotIn("tmpm", self.reg.get_custom_models())
|
||||
|
||||
def test_visible_includes_builtin_and_custom(self):
|
||||
self.reg.add_custom_model("zzz", "a/b")
|
||||
names = self.reg.visible_model_names()
|
||||
self.assertIn("tiny", names)
|
||||
self.assertIn("large-v3", names)
|
||||
self.assertIn("zzz", names)
|
||||
|
||||
def test_is_known(self):
|
||||
self.assertTrue(self.reg.is_known("base"))
|
||||
self.assertTrue(self.reg.is_known("Org/Name"))
|
||||
self.assertFalse(self.reg.is_known("nope-not-real"))
|
||||
|
||||
|
||||
class TestHelpers(unittest.TestCase):
|
||||
def test_hf_cache_dirname(self):
|
||||
self.assertEqual(
|
||||
wm.hf_cache_dirname("Systran/faster-whisper-tiny"),
|
||||
"models--Systran--faster-whisper-tiny",
|
||||
)
|
||||
self.assertEqual(wm.hf_cache_dirname("Org/Name"), "models--Org--Name")
|
||||
|
||||
def test_is_local_target(self):
|
||||
self.assertTrue(wm.is_local_target("/abs/path"))
|
||||
self.assertTrue(wm.is_local_target("./rel"))
|
||||
self.assertTrue(wm.is_local_target("~/home/model"))
|
||||
self.assertFalse(wm.is_local_target("Org/Name")) # repo_id 不是本地路径
|
||||
self.assertFalse(wm.is_local_target(""))
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
Reference in New Issue
Block a user