feat(NoteForm): 增加文件上传状态反馈

This commit is contained in:
JefferyHcool
2025-06-19 14:54:51 +08:00
parent 2b0fb8f4ad
commit d92cc4a977
24 changed files with 777 additions and 374 deletions

View File

@@ -24,6 +24,7 @@
"@radix-ui/react-tabs": "^1.1.9",
"@radix-ui/react-tooltip": "^1.1.8",
"@tailwindcss/vite": "^4.1.3",
"@tauri-apps/plugin-shell": "~2.2.2",
"@uiw/react-markdown-preview": "^5.1.3",
"antd": "^5.24.8",
"axios": "^1.8.4",
@@ -65,6 +66,7 @@
"devDependencies": {
"@eslint/js": "^9.21.0",
"@tailwindcss/postcss": "^4.1.3",
"@tauri-apps/cli": "^2.5.0",
"@types/node": "^22.14.0",
"@types/react": "^19.0.10",
"@types/react-dom": "^19.0.4",

View File

@@ -3,7 +3,6 @@ import { createRoot } from 'react-dom/client'
import './index.css'
import App from './App.tsx'
import RootLayout from './layouts/RootLayout.tsx'
createRoot(document.getElementById('root')!).render(
<StrictMode>
<RootLayout>

View File

@@ -38,6 +38,7 @@ import { Input } from '@/components/ui/input.tsx'
import { Textarea } from '@/components/ui/textarea.tsx'
import { noteStyles, noteFormats, videoPlatforms } from '@/constant/note.ts'
import { fetchModels } from '@/services/model.ts'
import { useNavigate } from 'react-router-dom'
/* -------------------- 校验 Schema -------------------- */
const formSchema = z
@@ -119,6 +120,8 @@ const CheckboxGroup = ({
/* -------------------- 主组件 -------------------- */
const NoteForm = () => {
const navigate = useNavigate();
/* ---- 全局状态 ---- */
const { addPendingTask, currentTaskId, setCurrentTask, getCurrentTask, retryTask } =
useTaskStore()
@@ -144,6 +147,9 @@ const NoteForm = () => {
const videoUnderstandingEnabled = useWatch({ control: form.control, name: 'video_understanding' })
const editing = currentTask && currentTask.id
const goModelAdd = () => {
navigate("/settings/model");
};
/* ---- 副作用 ---- */
useEffect(() => {
loadEnabledModels()
@@ -186,8 +192,8 @@ const NoteForm = () => {
const formData = new FormData()
formData.append('file', file)
try {
const { data } = await uploadFile(formData)
if (data.code === 0) cb(data.data.url)
const data = await uploadFile(formData)
cb(data.url)
} catch (err) {
console.error('上传失败:', err)
message.error('上传失败,请重试')
@@ -348,38 +354,50 @@ const NoteForm = () => {
/>
<div className="grid grid-cols-2 gap-2">
{/* 模型选择 */}
<FormField
className="w-full"
control={form.control}
name="model_name"
render={({ field }) => (
<FormItem>
<SectionHeader title="模型选择" tip="不同模型效果不同,建议自行测试" />
<Select
onOpenChange={()=>{
loadEnabledModels()
}}
value={field.value}
onValueChange={field.onChange}
defaultValue={field.value}
>
<FormControl>
<SelectTrigger className="w-full min-w-0 truncate">
<SelectValue />
</SelectTrigger>
</FormControl>
<SelectContent>
{modelList.map(m => (
<SelectItem key={m.id} value={m.model_name}>
{m.model_name}
</SelectItem>
))}
</SelectContent>
</Select>
<FormMessage />
</FormItem>
)}
/>
{
modelList.length>0?( <FormField
className="w-full"
control={form.control}
name="model_name"
render={({ field }) => (
<FormItem>
<SectionHeader title="模型选择" tip="不同模型效果不同,建议自行测试" />
<Select
onOpenChange={()=>{
loadEnabledModels()
}}
value={field.value}
onValueChange={field.onChange}
defaultValue={field.value}
>
<FormControl>
<SelectTrigger className="w-full min-w-0 truncate">
<SelectValue />
</SelectTrigger>
</FormControl>
<SelectContent>
{modelList.map(m => (
<SelectItem key={m.id} value={m.model_name}>
{m.model_name}
</SelectItem>
))}
</SelectContent>
</Select>
<FormMessage />
</FormItem>
)}
/>): (
<FormItem>
<SectionHeader title="模型选择" tip="不同模型效果不同,建议自行测试" />
<Button type={'button'} variant={
'outline'
} onClick={()=>{goModelAdd()}}></Button>
<FormMessage />
</FormItem>
)
}
{/* 笔记风格 */}
<FormField
className="w-full"

View File

@@ -49,13 +49,9 @@ export const delete_task = async ({ video_id, platform }) => {
}
const res = await request.post('/delete_task', data)
if (res.data.code === 0) {
toast.success('任务已成功删除')
return res.data
} else {
toast.error(res.data.message || '删除失败')
throw new Error(res.data.message || '删除失败')
}
return res
} catch (e) {
toast.error('请求异常,删除任务失败')
console.error('❌ 删除任务失败:', e)

View File

@@ -4,8 +4,8 @@ from .routers import note, provider, model, config
def create_app() -> FastAPI:
app = FastAPI(title="BiliNote")
def create_app(lifespan) -> FastAPI:
app = FastAPI(title="BiliNote",lifespan=lifespan)
app.include_router(note.router, prefix="/api")
app.include_router(provider.router, prefix="/api")
app.include_router(model.router,prefix="/api")

36
backend/app/db/engine.py Normal file
View File

@@ -0,0 +1,36 @@
import os
from sqlalchemy import create_engine
from sqlalchemy.orm import sessionmaker, declarative_base
from dotenv import load_dotenv
load_dotenv()
# 默认 SQLite如果想换 PostgreSQL 或 MySQL可以直接改 .env
DATABASE_URL = os.getenv("DATABASE_URL", "sqlite:///bili_note.db")
# SQLite 需要特定连接参数,其他数据库不需要
engine_args = {}
if DATABASE_URL.startswith("sqlite"):
engine_args["connect_args"] = {"check_same_thread": False}
engine = create_engine(
DATABASE_URL,
echo=os.getenv("SQLALCHEMY_ECHO", "false").lower() == "true",
**engine_args
)
SessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine)
Base = declarative_base()
def get_engine():
return engine
def get_db():
db = SessionLocal()
try:
yield db
finally:
db.close()

View File

@@ -0,0 +1,9 @@
from app.db.models.models import Model
from app.db.models.providers import Provider
from app.db.models.video_tasks import VideoTask
from app.db.engine import get_engine, Base
def init_db():
engine = get_engine()
Base.metadata.create_all(bind=engine)

View File

@@ -1,67 +1,67 @@
from app.db.sqlite_client import get_connection
def init_model_table():
conn = get_connection()
cursor = conn.cursor()
cursor.execute("""
CREATE TABLE IF NOT EXISTS models (
id INTEGER PRIMARY KEY AUTOINCREMENT,
provider_id INTEGER NOT NULL,
model_name TEXT NOT NULL,
created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP
)
""")
conn.commit()
conn.close()
from app.db.engine import get_db
from app.db.models.models import Model
def get_model_by_provider_and_name(provider_id: int, model_name: str):
conn = get_connection()
cursor = conn.execute(
"SELECT * FROM models WHERE provider_id = ? AND model_name = ?",
(provider_id, model_name)
)
row = cursor.fetchone()
return row
# 插入模型
db = next(get_db())
try:
model = db.query(Model).filter_by(provider_id=provider_id, model_name=model_name).first()
if model:
return {
"id": model.id,
"provider_id": model.provider_id,
"model_name": model.model_name,
"created_at": model.created_at,
}
return None
finally:
db.close()
def insert_model(provider_id: int, model_name: str):
conn = get_connection()
cursor = conn.cursor()
cursor.execute("""
INSERT INTO models (provider_id, model_name)
VALUES (?, ?)
""", (provider_id, model_name))
conn.commit()
conn.close()
db = next(get_db())
try:
model = Model(provider_id=provider_id, model_name=model_name)
db.add(model)
db.commit()
db.refresh(model)
return {
"id": model.id,
"provider_id": model.provider_id,
"model_name": model.model_name,
"created_at": model.created_at,
}
finally:
db.close()
# 根据provider查模型
def get_models_by_provider(provider_id: int):
conn = get_connection()
cursor = conn.cursor()
cursor.execute("""
SELECT id, model_name FROM models
WHERE provider_id = ?
""", (provider_id,))
rows = cursor.fetchall()
conn.close()
return [{"id": row[0], "model_name": row[1]} for row in rows]
db = next(get_db())
try:
models = db.query(Model).filter_by(provider_id=provider_id).all()
return [{"id": m.id, "model_name": m.model_name} for m in models]
finally:
db.close()
# 删除某个模型
def delete_model(model_id: int):
conn = get_connection()
cursor = conn.cursor()
cursor.execute("""
DELETE FROM models WHERE id = ?
""", (model_id,))
conn.commit()
conn.close()
db = next(get_db())
try:
model = db.query(Model).filter_by(id=model_id).first()
if model:
db.delete(model)
db.commit()
finally:
db.close()
def get_all_models():
conn = get_connection()
cursor = conn.cursor()
cursor.execute("""
SELECT id, provider_id, model_name FROM models
""")
rows = cursor.fetchall()
conn.close()
return [{"id": row[0], "provider_id": row[1], "model_name": row[2]} for row in rows]
db = next(get_db())
try:
models = db.query(Model).all()
return [
{"id": m.id, "provider_id": m.provider_id, "model_name": m.model_name}
for m in models
]
finally:
db.close()

View File

View File

@@ -0,0 +1,12 @@
from sqlalchemy import Column, Integer, String, DateTime, func, ForeignKey
from app.db.engine import Base
class Model(Base):
__tablename__ = "models"
id = Column(Integer, primary_key=True, autoincrement=True)
provider_id = Column(Integer, nullable=False)
model_name = Column(String, nullable=False)
created_at = Column(DateTime, server_default=func.now())

View File

@@ -0,0 +1,17 @@
from sqlalchemy import Column, String, Integer, DateTime, func
from sqlalchemy.orm import declarative_base
from app.db.engine import Base
class Provider(Base):
__tablename__ = "providers"
id = Column(String, primary_key=True)
name = Column(String, nullable=False)
logo = Column(String, nullable=False)
type = Column(String, nullable=False)
api_key = Column(String, nullable=False)
base_url = Column(String, nullable=False)
enabled = Column(Integer, default=1)
created_at = Column(DateTime, server_default=func.now())

View File

@@ -0,0 +1,14 @@
from sqlalchemy import Column, Integer, String, DateTime, func
from sqlalchemy.orm import declarative_base
from app.db.engine import Base
class VideoTask(Base):
__tablename__ = "video_tasks"
id = Column(Integer, primary_key=True, autoincrement=True)
video_id = Column(String, nullable=False)
platform = Column(String, nullable=False)
task_id = Column(String, unique=True, nullable=False)
created_at = Column(DateTime, server_default=func.now())

View File

@@ -1,14 +1,13 @@
import json
import os
import sys
from app.db.sqlite_client import get_connection
from app.db.models.providers import Provider
from app.utils.logger import get_logger
from app.db.engine import get_engine, Base, get_db
logger = get_logger(__name__)
def get_builtin_providers_path():
if getattr(sys, 'frozen', False):
base_path = sys._MEIPASS
@@ -16,213 +15,115 @@ def get_builtin_providers_path():
base_path = os.path.dirname(__file__)
return os.path.join(base_path, 'builtin_providers.json')
def seed_default_providers():
conn = get_connection()
if conn is None:
logger.error("Failed to connect to database.")
return
cursor = conn.cursor()
# 检查已有数据
cursor.execute("SELECT COUNT(*) FROM providers")
count = cursor.fetchone()[0]
if count > 0:
logger.info("Providers already exist, skipping seed.")
conn.close()
return
json_path = get_builtin_providers_path()
db = next(get_db())
try:
with open(json_path, 'r', encoding='utf-8') as f:
providers = json.load(f)
except Exception as e:
logger.error(f"Failed to read builtin_providers.json: {e}")
conn.close()
return
if db.query(Provider).count() > 0:
logger.info("Providers already exist, skipping seed.")
return
json_path = get_builtin_providers_path()
try:
with open(json_path, 'r', encoding='utf-8') as f:
providers = json.load(f)
except Exception as e:
logger.error(f"Failed to read builtin_providers.json: {e}")
return
try:
for p in providers:
cursor.execute("""
INSERT INTO providers (id, name, api_key, base_url, logo, type, enabled)
VALUES (?, ?, ?, ?, ?, ?, ?)
""", (
p['id'],
p['name'],
p['api_key'],
p['base_url'],
p['logo'],
p['type'],
p.get('enabled', 1)
db.add(Provider(
id=p['id'],
name=p['name'],
api_key=p['api_key'],
base_url=p['base_url'],
logo=p['logo'],
type=p['type'],
enabled=p.get('enabled', 1)
))
conn.commit()
db.commit()
logger.info("Default providers seeded successfully.")
except Exception as e:
logger.error(f"Failed to seed default providers: {e}")
finally:
conn.close()
def init_provider_table():
conn = get_connection()
if conn is None:
logger.error("Failed to connect to the database.")
return
cursor = conn.cursor()
cursor.execute("""
CREATE TABLE IF NOT EXISTS providers (
id TEXT PRIMARY KEY,
name TEXT NOT NULL,
logo TEXT NOT NULL,
type TEXT NOT NULL,
api_key TEXT NOT NULL,
base_url TEXT NOT NULL,
enabled INTEGER DEFAULT 1,
created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP
)
""")
db.close()
def insert_provider(id: str, name: str, api_key: str, base_url: str, logo: str, type_: str, enabled: int = 1):
db = next(get_db())
try:
conn.commit()
conn.close()
logger.info("provider table created successfully.")
seed_default_providers()
except Exception as e:
logger.error(f"Failed to create provider table: {e}")
def insert_provider(id: str, name: str, api_key: str, base_url: str, logo: str, type_: str,enabled:int=1):
conn = get_connection()
if conn is None:
logger.error("Failed to connect to the database.")
return
cursor = conn.cursor()
cursor.execute("""
INSERT INTO providers (id, name, api_key, base_url, logo, type, enabled)
VALUES (?, ?, ?, ?, ?, ?, ?)
""", (id, name, api_key, base_url, logo, type_, enabled))
try:
conn.commit()
conn.close()
provider = Provider(id=id, name=name, api_key=api_key, base_url=base_url, logo=logo, type=type_, enabled=enabled)
db.add(provider)
db.commit()
logger.info(f"Provider inserted successfully. id: {id}, name: {name}, type: {type_}")
return id
except Exception as e:
logger.error(f"Failed to insert provider: {e}")
return None
finally:
db.close()
def get_enabled_providers():
conn = get_connection()
if conn is None:
logger.error("Failed to connect to the database.")
return
cursor = conn.cursor()
cursor.execute("SELECT * FROM providers WHERE enabled = 1")
db = next(get_db())
try:
rows = cursor.fetchall()
conn.close()
if rows is None:
logger.info("No providers found")
return None
logger.info(f"Providers found: {rows}")
return rows
except Exception as e:
logger.error(f"Failed to get enabled providers: {e}")
return db.query(Provider).filter_by(enabled=1).all()
finally:
db.close()
def get_provider_by_name(name: str):
conn = get_connection()
if conn is None:
logger.error("Failed to connect to the database.")
return
cursor = conn.cursor()
cursor.execute("SELECT * FROM providers WHERE name = ?", (name,))
db = next(get_db())
try:
row = cursor.fetchone()
conn.close()
if row is None:
logger.info(f"Provider not found: {name}")
return None
logger.info(f"Provider found: {row[0]}")
return db.query(Provider).filter_by(name=name).first()
finally:
db.close()
return row
except Exception as e:
logger.error(f"Failed to get provider by name: {e}")
def get_provider_by_id(id: int):
conn = get_connection()
if conn is None:
logger.error("Failed to connect to the database.")
return
cursor = conn.cursor()
cursor.execute("SELECT * FROM providers WHERE id = ?", (id,))
def get_provider_by_id(id: str):
db = next(get_db())
try:
row = cursor.fetchone()
conn.close()
if row is None:
logger.info(f"Provider not found: {id}")
return None
logger.info(f"Provider found: {row[0]}")
return row
except Exception as e:
logger.error(f"Failed to get provider by id: {e}")
return db.query(Provider).filter_by(id=id).first()
finally:
db.close()
def get_all_providers():
conn = get_connection()
if conn is None:
logger.error("Failed to connect to the database.")
return
cursor = conn.cursor()
cursor.execute("SELECT * FROM providers")
db = next(get_db())
try:
rows = cursor.fetchall()
conn.close()
if rows is None:
logger.info("No providers found")
return None
logger.info(f"Providers found total {len(rows) }")
return rows
except Exception as e:
logger.error(f"Failed to get all providers: {e}")
return db.query(Provider).all()
finally:
db.close()
def update_provider(id: str, **kwargs):
conn = get_connection()
if conn is None:
logger.error("Failed to connect to the database.")
return
fields = []
values = []
for key, value in kwargs.items():
fields.append(f"{key} = ?")
values.append(value)
if not fields:
logger.warning("No fields provided for update.")
return
sql = f"""
UPDATE providers
SET {', '.join(fields)}
WHERE id = ?
"""
values.append(id) # id 最后加
cursor = conn.cursor()
db = next(get_db())
try:
cursor.execute(sql, values)
conn.commit()
conn.close()
logger.info(f"Provider updated successfully. id: {id}, updated_fields: {fields}")
provider = db.query(Provider).filter_by(id=id).first()
if not provider:
logger.warning(f"Provider {id} not found for update.")
return
for key, value in kwargs.items():
if hasattr(provider, key):
setattr(provider, key, value)
db.commit()
logger.info(f"Provider updated successfully. id: {id}, updated_fields: {list(kwargs.keys())}")
except Exception as e:
logger.error(f"Failed to update provider: {e}")
finally:
db.close()
def delete_provider(id: int):
conn = get_connection()
if conn is None:
logger.error("Failed to connect to the database.")
return
cursor = conn.cursor()
cursor.execute("DELETE FROM providers WHERE id = ?", (id,))
def delete_provider(id: str):
db = next(get_db())
try:
conn.commit()
conn.close()
logger.info(f"Provider deleted successfully. id: {id}")
provider = db.query(Provider).filter_by(id=id).first()
if provider:
db.delete(provider)
db.commit()
logger.info(f"Provider deleted successfully. id: {id}")
except Exception as e:
logger.error(f"Failed to delete provider: {e}")
logger.error(f"Failed to delete provider: {e}")
finally:
db.close()

View File

@@ -1,78 +1,61 @@
from .sqlite_client import get_connection
from app.db.models.video_tasks import VideoTask
from app.db.engine import get_db
from app.utils.logger import get_logger
logger = get_logger(__name__)
def init_video_task_table():
conn = get_connection()
if conn is None:
logger.error("Failed to connect to the database.")
return
cursor = conn.cursor()
cursor.execute("""
CREATE TABLE IF NOT EXISTS video_tasks (
id INTEGER PRIMARY KEY AUTOINCREMENT,
video_id TEXT NOT NULL,
platform TEXT NOT NULL,
task_id TEXT NOT NULL UNIQUE,
created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP
)
""")
try:
conn.commit()
conn.close()
logger.info("video_tasks table created successfully.")
except Exception as e:
logger.error(f"Failed to create video_tasks table: {e}")
# 插入任务
def insert_video_task(video_id: str, platform: str, task_id: str):
db = next(get_db())
try:
conn = get_connection()
cursor = conn.cursor()
cursor.execute("""
INSERT INTO video_tasks (video_id, platform, task_id)
VALUES (?, ?, ?)
""", (video_id, platform, task_id))
conn.commit()
conn.close()
logger.info(f"Video task inserted successfully."
f"video_id: {video_id}"
f"platform: {platform}"
f"task_id: {task_id}")
task = VideoTask(video_id=video_id, platform=platform, task_id=task_id)
db.add(task)
db.commit()
db.refresh(task)
logger.info(f"Video task inserted successfully. video_id: {video_id}, platform: {platform}, task_id: {task_id}")
except Exception as e:
logger.error(f"Failed to insert video task: {e}")
finally:
db.close()
# 查询任务(最新一条)
def get_task_by_video(video_id: str, platform: str):
db = next(get_db())
try:
conn = get_connection()
cursor = conn.cursor()
cursor.execute("""
SELECT task_id FROM video_tasks
WHERE video_id = ? AND platform = ?
ORDER BY created_at DESC
LIMIT 1
""", (video_id, platform))
result = cursor.fetchone()
conn.close()
if result is None:
task = (
db.query(VideoTask)
.filter_by(video_id=video_id, platform=platform)
.order_by(VideoTask.created_at.desc())
.first()
)
if task:
logger.info(f"Task found for video_id: {video_id} and platform: {platform}")
return task.task_id
else:
logger.info(f"No task found for video_id: {video_id} and platform: {platform}")
logger.info(f"Task found for video_id: {video_id} and platform: {platform}")
return result[0] if result else None
return None
except Exception as e:
logger.error(f"Failed to get task by video: {e}")
finally:
db.close()
# 删除任务
def delete_task_by_video(video_id: str, platform: str):
db = next(get_db())
try:
conn = get_connection()
cursor = conn.cursor()
cursor.execute("""
DELETE FROM video_tasks
WHERE video_id = ? AND platform = ?
""", (video_id, platform))
conn.commit()
conn.close()
logger.info(f"Task deleted for video_id: {video_id} and platform: {platform}")
tasks = (
db.query(VideoTask)
.filter_by(video_id=video_id, platform=platform)
.all()
)
for task in tasks:
db.delete(task)
db.commit()
logger.info(f"Task(s) deleted for video_id: {video_id} and platform: {platform}")
except Exception as e:
logger.error(f"Failed to delete task by video: {e}")
logger.error(f"Failed to delete task by video: {e}")
finally:
db.close()

View File

@@ -0,0 +1,25 @@
from typing import Union, Optional
import requests
from app.downloaders.base import Downloader
from app.enmus.note_enums import DownloadQuality
from app.models.audio_model import AudioDownloadResult
url='https://www.xiaoyuzhoufm.com/_next/data/5Pvt_oGntgdyBD_XgwBaB/podcast/62382c1103bea1ebfffa1c00.json?id=62382c1103bea1ebfffa1c00'
header ={
'user-agent':'Mozilla/5.0 (Macintosh; Intel Mac OS X 10_15_7) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/137.0.0.0 Safari/537.36'
}
response = requests.get(url, headers=header)
print(response.json())
class Xiaoyuzhoufm_download(Downloader):
def download(
self,
video_url: str,
output_dir: Union[str, None] = None,
quality: DownloadQuality = "fast",
need_video:Optional[bool]=False
) -> AudioDownloadResult:
pass

View File

@@ -109,8 +109,8 @@ def run_note_task(task_id: str, video_url: str, platform: str, quality: Download
@router.post('/delete_task')
def delete_task(data: RecordRequest):
try:
NoteGenerator().delete_note(video_id=data.video_id, platform=data.platform)
# TODO: 待持久化完成
# NoteGenerator().delete_note(video_id=data.video_id, platform=data.platform)
return R.success(msg='删除成功')
except Exception as e:
return R.error(msg=e)

View File

@@ -436,7 +436,7 @@ class NoteGenerator:
logger.info(f"转写并缓存成功 ({transcript_cache_file})")
return transcript
except Exception as exc:
logger.error(f"音频转写失败:{e}")
logger.error(f"音频转写失败:{exc}")
self._handle_exception(task_id, exc)
raise

View File

@@ -1,8 +1,9 @@
from fastapi.encoders import jsonable_encoder
from kombu import uuid
from app.db.models.providers import Provider
from app.db.provider_dao import (
insert_provider,
init_provider_table,
get_all_providers,
get_provider_by_name,
get_provider_by_id,
@@ -16,32 +17,51 @@ from app.models.model_config import ModelConfig
class ProviderService:
@staticmethod
def serialize_provider(row: tuple) -> dict:
def serialize_provider(row: Provider) -> dict:
if not row:
return None
row = ProviderService.provider_to_dict(row)
return {
"id": row[0],
"name": row[1],
"logo": row[2],
"type": row[3],
"api_key": row[4],
"base_url": row[5],
"enabled": row[6],
"created_at": row[7],
"id": row.get("id"),
"name": row.get("name"),
"logo": row.get("logo"),
"type":row.get("type"),
"enabled": row.get("enabled"),
"base_url": row.get("base_url"),
"api_key": row.get("api_key"),
"created_at": jsonable_encoder(row.get("created_at")),
# "name": row[1],
# "logo": row[2],
# "type": row[3],
# "api_key": row[4],
# "base_url": row[5],
# "enabled": row[6],
# "created_at": row[7],
}
@staticmethod
def serialize_provider_safe(row: tuple) -> dict:
def serialize_provider_safe(row: Provider) -> dict:
if not row:
return None
row = ProviderService.provider_to_dict(row)
return {
"id": row[0],
"name": row[1],
"logo": row[2],
"type": row[3],
"api_key": ProviderService.mask_key(row[4]),
"base_url": row[5],
"enabled": row[6],
"created_at": row[7],
"id": row.get("id"),
"name": row.get("name"),
"logo": row.get("logo"),
"type":row.get("type"),
"enabled": row.get("enabled"),
"base_url": row.get("base_url"),
"api_key": ProviderService.mask_key(row.get("api_key")),
"created_at": jsonable_encoder(row.get("created_at")),
# "id": row[0],
# "name": row[1],
# "logo": row[2],
# "type": row[3],
# "api_key": ProviderService.mask_key(row[4]),
# "base_url": row[5],
# "enabled": row[6],
# "created_at": row[7],
}
@staticmethod
def mask_key(key: str) -> str:
@@ -56,15 +76,30 @@ class ProviderService:
return insert_provider(id, name, api_key, base_url, logo, type_, enabled)
except Exception as e:
print('创建模式失败',e)
@staticmethod
def provider_to_dict(p: Provider):
return {
"id": p.id,
"name": p.name,
"logo": p.logo,
"type": p.type,
"api_key": p.api_key,
"base_url": p.base_url,
"enabled": p.enabled,
"created_at": p.created_at,
}
@staticmethod
def get_all_providers():
rows = get_all_providers()
if rows is None:
return []
return [ProviderService.serialize_provider(row) for row in rows] if rows else []
@staticmethod
def get_all_providers_safe():
rows = get_all_providers()
return [ProviderService.serialize_provider(row) for row in rows] if rows else []
return [ProviderService.serialize_provider(row) for row in rows] if (rows) else []
@staticmethod
def get_provider_by_name(name: str):
row = get_provider_by_name(name)

View File

@@ -6,15 +6,31 @@ from app.models.transcriber_model import TranscriptResult, TranscriptSegment
from app.services.provider import ProviderService
from app.transcriber.base import Transcriber
from openai import OpenAI
import ffmpeg
import tempfile
from dotenv import load_dotenv
load_dotenv()
MAX_SIZE_MB = 18
MAX_SIZE_BYTES = MAX_SIZE_MB * 1024 * 1024
def compress_audio(input_path: str, target_bitrate='64k') -> str:
output_fd, output_path = tempfile.mkstemp(suffix=".mp3") # 临时输出文件
os.close(output_fd) # 关闭文件描述符ffmpeg 会用路径操作
ffmpeg.input(input_path).output(output_path, audio_bitrate=target_bitrate).run(quiet=True, overwrite_output=True)
return output_path
class GroqTranscriber(Transcriber, ABC):
@timeit
def transcript(self, file_path: str) -> TranscriptResult:
file_size = os.path.getsize(file_path)
if file_size > MAX_SIZE_BYTES:
print(f"文件超过 {MAX_SIZE_MB}MB开始压缩当前 {round(file_size / (1024 * 1024), 2)}MB...")
file_path = compress_audio(file_path)
print(f"压缩完成,临时路径:{file_path}")
provider = ProviderService.get_provider_by_id('groq')
if not provider:
raise Exception("Groq 供应商未配置,请配置以后使用。")
client = OpenAI(

285
backend/app/utils/export.py Normal file
View File

@@ -0,0 +1,285 @@
import os
import re
from urllib.parse import quote
from markdown_pdf import MarkdownPdf, Section
from dotenv import load_dotenv
load_dotenv()
# 项目根路径(无论你在哪里运行)
BASE_DIR = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
# 从 .env 获取 DATA_DIR相对于 BASE_DIR 解析
DATA_DIR_NAME = os.getenv("DATA_DIR", "data")
DATA_DIR = os.path.join(BASE_DIR, DATA_DIR_NAME)
SAVE_PATH = os.path.join(DATA_DIR, "note_output")
IMAGE_BASE_URL = os.getenv("IMAGE_BASE_URL")
STATIC_BASE = os.path.join(BASE_DIR, IMAGE_BASE_URL)
class ExportUtils:
def __init__(self, **kwargs):
# 确认SAVE_PATH存在
print(f"保存路径: {SAVE_PATH}")
print(f"静态文件路径: {STATIC_BASE}")
if not os.path.exists(SAVE_PATH):
os.makedirs(SAVE_PATH)
def _embed_image_as_base64(self, img_path: str) -> str:
"""
将图片转换为 base64 格式嵌入
"""
import base64
import mimetypes
try:
# 获取 MIME 类型
mime_type, _ = mimetypes.guess_type(img_path)
if not mime_type:
# 根据扩展名推断
ext = os.path.splitext(img_path)[1].lower()
mime_map = {
'.png': 'image/png',
'.jpg': 'image/jpeg',
'.jpeg': 'image/jpeg',
'.gif': 'image/gif',
'.bmp': 'image/bmp',
'.webp': 'image/webp',
'.svg': 'image/svg+xml'
}
mime_type = mime_map.get(ext, 'image/png')
# 读取图片文件并转换为 base64
with open(img_path, 'rb') as f:
img_data = f.read()
base64_data = base64.b64encode(img_data).decode('utf-8')
return f"data:{mime_type};base64,{base64_data}"
except Exception as e:
print(f"图片 base64 编码失败 {img_path}: {str(e)}")
return None
def _get_normalized_path(self, path: str) -> str:
"""
获取规范化的绝对路径
"""
return os.path.normpath(os.path.abspath(path))
def _replace_static_paths_with_absolute(self, content: str) -> str:
"""
将 Markdown 中的图片路径替换为 base64 内嵌格式
这样可以确保图片在 PDF 中正确显示
"""
def repl(match):
# 捕获 alt 文本和路径
alt_text = match.group(1) if match.group(1) else ""
img_path = match.group(2).strip()
print(f"处理图片路径: {img_path}")
# 处理 /static/ 开头的路径
if img_path.startswith("/static/"):
# 构建绝对路径
relative_path = img_path.lstrip("/") # 移除开头的 /
abs_path = os.path.join(BASE_DIR, relative_path)
abs_path = self._get_normalized_path(abs_path)
# 检查文件是否存在并转换为 base64
if os.path.exists(abs_path):
base64_uri = self._embed_image_as_base64(abs_path)
if base64_uri:
print(f"图片转换为 base64 成功: {img_path}")
return f"![{alt_text}]({base64_uri})"
else:
print(f"图片 base64 转换失败: {abs_path}")
return f"![{alt_text}](图片转换失败: {img_path})"
else:
print(f"警告:图片文件不存在 {abs_path}")
return f"![{alt_text}](图片不存在: {img_path})"
# 处理相对路径(相对于 STATIC_BASE
elif not img_path.startswith(('http://', 'https://', 'data:')):
# 尝试多个可能的路径
possible_paths = [
os.path.join(STATIC_BASE, img_path),
os.path.abspath(img_path),
os.path.join(BASE_DIR, img_path)
]
for abs_path in possible_paths:
abs_path = self._get_normalized_path(abs_path)
if os.path.exists(abs_path):
base64_uri = self._embed_image_as_base64(abs_path)
if base64_uri:
print(f"相对路径图片转换为 base64 成功: {img_path}")
return f"![{alt_text}]({base64_uri})"
break
print(f"警告:图片文件未找到 {img_path}")
return f"![{alt_text}](图片未找到: {img_path})"
# HTTP/HTTPS 和 data: 路径保持不变
elif img_path.startswith(('http://', 'https://', 'data:')):
print(f"网络图片或 data URI 保持不变: {img_path[:50]}...")
return match.group(0)
# 其他情况保持不变
return match.group(0)
# 使用更精确的正则表达式匹配图片语法
# 匹配 ![alt text](path) 格式
pattern = r'!\[([^\]]*)\]\(([^)]+)\)'
result = re.sub(pattern, repl, content)
print("图片路径处理完成")
return result
def _to_pdf(self, content: str, title: str):
"""
将 Markdown 内容转换为 PDF
"""
try:
# 创建 PDF 对象,启用优化
pdf = MarkdownPdf(
optimize=True,
# 添加一些可能有助于图片显示的配置
# toc=False,
# paper_size='A4',
# margin=dict(top='1cm', bottom='1cm', left='1cm', right='1cm')
)
# 添加内容段落
pdf.add_section(Section(content))
# 保存 PDF
save_path = os.path.join(SAVE_PATH, f"{title}.pdf")
pdf.save(save_path)
print(f"PDF 导出成功: {save_path}")
return save_path
except Exception as e:
print(f"PDF 导出失败: {str(e)}")
print("尝试使用基本配置...")
try:
# 尝试最基本的配置
pdf = MarkdownPdf()
pdf.add_section(Section(content))
save_path = os.path.join(SAVE_PATH, f"{title}.pdf")
pdf.save(save_path)
print(f"基本配置 PDF 导出成功: {save_path}")
return save_path
except Exception as e2:
print(f"基本配置也失败: {str(e2)}")
raise e2
def export(self, output_format: str, title: str, content: str) -> str:
"""
导出内容为指定格式
支持格式pdf, html, word/docx, image/png
"""
content = content.strip()
# 处理图片路径
print("开始处理图片路径...")
content = self._replace_static_paths_with_absolute(content)
output_format = output_format.lower()
try:
if output_format == "pdf":
save_path = self._to_pdf(content, title)
elif output_format == "html":
save_path = self._to_html(content, title)
elif output_format in ["word", "docx"]:
save_path = self._to_word(content, title)
elif output_format in ["image", "png"]:
save_path = self._to_image(content, title)
else:
supported_formats = ["pdf", "html", "word/docx", "image/png"]
raise ValueError(f"不支持的导出格式: {output_format}. 支持的格式: {', '.join(supported_formats)}")
print(f"导出完成: {save_path}")
return save_path
except Exception as e:
print(f"导出失败: {str(e)}")
raise e
def get_supported_formats(self):
"""
返回支持的导出格式列表
"""
return {
"pdf": "PDF 文档",
"html": "HTML 网页",
"word": "Word 文档 (.docx)",
"docx": "Word 文档 (.docx)",
"image": "PNG 图片",
"png": "PNG 图片"
}
def debug_paths(self):
"""
调试方法:打印重要路径信息
"""
print("=== 路径调试信息 ===")
print(f"BASE_DIR: {BASE_DIR}")
print(f"DATA_DIR: {DATA_DIR}")
print(f"SAVE_PATH: {SAVE_PATH}")
print(f"STATIC_BASE: {STATIC_BASE}")
print(f"IMAGE_BASE_URL: {IMAGE_BASE_URL}")
print("==================")
if __name__ == '__main__':
ExportUtils().export("pdf",title='测试',content='''# 视频笔记Facial Recognition Forces My Coworkers to Do Their Dishes
## 简介
该视频展示了团队如何利用面部识别技术来监控和激励同事清洗餐具。通过结合硬件和软件团队开发了一个“Dish Watcher”系统旨在识别并提醒那些未清洁餐具的人。
## 背景
- 团队面临的问题是同事们不愿意清洗餐具。
- 为解决这一问题,团队决定在不告知的情况下使用技术来监控厨房区域。
## 实验设计
1\. **设备安装**
- 使用Raspberry Pi和隐藏摄像头来捕捉厨房水槽的活动。
- 摄像头只在有人在水槽附近活动时录制,以节省存储空间。
2\. **软件开发**
- 使用Cursor AI和Meta的项目来分析视频。
- 系统能识别人员特征如发型、服装并将结果发送到Discord服务器以提醒团队。
3\. **面部识别**
- 通过视频流实时分析来判断是否有人留下了脏餐具。
- 系统能识别并记录下未清洗餐具的人的详细特征。
![](/static/screenshots/screenshot_000_a61be29d-06ae-42ee-ac38-2d0b1db394f3.jpg)* 展示了堆积的脏餐具,问题的严重性可见一斑。
## 实验过程
- 系统成功捕获了少数“罪犯”并通过Discord进行了通知。
- 计划将摄像头隐藏在厨房的画作后,使其更加隐蔽。
![](/static/screenshots/screenshot_001_e9d1c7ad-509e-4c7d-a718-a09193e97724.jpg)* SAM 介绍了项目的背景。
## 结果
- 实验初期,系统有效地识别了不清洗餐具的同事。
- 由于摄像头的存在,同事们开始自觉清洗餐具,长时间未发现新的“罪犯”。
## 思考与改进
- 团队意识到仅仅通过惩罚来改变行为可能效果有限,考虑奖励来激励清洗餐具。
- 系统将改进为奖励机制,记录并表扬那些清洗餐具的人。
## 总结
这次实验展示了技术在工作场所行为管理中的应用潜力。通过实验,团队不仅解决了餐具清洗的问题,还对如何更有效地激励员工有了更深的认识。
![](/static/screenshots/screenshot_002_f1ca0c20-c657-417f-be78-7958bf0e7a4b.jpg)* 展示了系统对某位同事洗碗的实时面部识别。
## 结论
- 应用技术可以有效改善工作环境中的小问题。
- 积极的激励比惩罚更能驱动行为改变。
通过这次实验,团队不仅解决了餐具堆积的问题,还为未来更复杂的行为管理系统奠定了基础。 ''',)

View File

@@ -19,6 +19,6 @@ class ResponseWrapper:
def error(msg="error", code=500, data=None):
return JSONResponse(content={
"code": code,
"msg": msg,
"msg": str(msg),
"data": data
})

39
backend/build.sh Executable file
View File

@@ -0,0 +1,39 @@
#!/usr/bin/env bash
set -e
# uncomment this for debugging
# set -x
# 切到项目根(假设脚本放在 script/ 目录)
cd "$(dirname "$0")/.."
echo "当前工作目录:$(pwd)"
# 清理旧的构建
echo "清理旧的构建..."
rm -rf backend/dist backend/build ./BillNote_frontend/src-tauri/bin/*
echo "清理完成。"
TARGET_TRIPLE=$(rustc -Vv | grep host | cut -f2 -d' ')
echo "Detected target triple: $TARGET_TRIPLE"
# PyInstaller onedir 模式,直接输出到 Tauri 的 bin 目录
echo "开始 PyInstaller 打包..."
pyinstaller \
--name BiliNoteBackend \
--paths backend \
--distpath ./BillNote_frontend/src-tauri/bin \
--workpath backend/build \
--specpath backend \
--hidden-import uvicorn \
--hidden-import fastapi \
--hidden-import starlette \
--add-data "app/db/builtin_providers.json:."\
--add-data "../.env:." \
"$(pwd)/backend/main.py" # 确保这里没有额外的空格,并使用绝对路径
mv \
./BillNote_frontend/src-tauri/bin/BiliNoteBackend/BiliNoteBackend\
./BillNote_frontend/src-tauri/bin/BiliNoteBackend/BiliNoteBackend-$TARGET_TRIPLE
echo "PyInstaller 打包完成:"
ls -l ./BillNote_frontend/src-tauri/bin/BiliNoteBackend # 这里会列出 onedir 模式下的目录内容
echo "请检查 src-tauri/bin/BiliNoteBackend 目录,以确认打包内容。"

View File

@@ -1,15 +1,19 @@
import os
from contextlib import asynccontextmanager
import uvicorn
from fastapi import FastAPI
from starlette.middleware.cors import CORSMiddleware
from starlette.staticfiles import StaticFiles
from dotenv import load_dotenv
from app.db.init_db import init_db
from app.db.provider_dao import seed_default_providers
from app.exceptions.exception_handlers import register_exception_handlers
from app.db.model_dao import init_model_table
from app.db.provider_dao import init_provider_table
# from app.db.model_dao import init_model_table
# from app.db.provider_dao import init_provider_table
from app.utils.logger import get_logger
from app import create_app
from app.db.video_task_dao import init_video_task_table
from app.transcriber.transcriber_provider import get_transcriber
from events import register_handler
from ffmpeg_helper import ensure_ffmpeg_or_raise
@@ -32,21 +36,33 @@ if not os.path.exists(uploads_dir):
if not os.path.exists(out_dir):
os.makedirs(out_dir)
app = create_app()
@asynccontextmanager
async def lifespan(app: FastAPI):
register_handler()
ensure_ffmpeg_or_raise()
init_db()
get_transcriber(transcriber_type=os.getenv("TRANSCRIBER_TYPE", "fast-whisper"))
seed_default_providers()
yield
app = create_app(lifespan=lifespan)
app.add_middleware(
CORSMiddleware,
allow_origins=["tauri://localhost"], # ✅ 加上 Tauri 的 origin
allow_credentials=True,
allow_methods=["*"],
allow_headers=["*"],
)
register_exception_handlers(app)
app.mount(static_path, StaticFiles(directory=static_dir), name="static")
app.mount("/uploads", StaticFiles(directory=uploads_dir), name="uploads")
@app.on_event("startup")
async def startup_event():
register_handler()
ensure_ffmpeg_or_raise()
get_transcriber(transcriber_type=os.getenv("TRANSCRIBER_TYPE","fast-whisper"))
init_video_task_table()
init_provider_table()
init_model_table()
if __name__ == "__main__":

Binary file not shown.