mirror of
https://github.com/JefferyHcool/BiliNote.git
synced 2026-05-06 20:42:52 +08:00
3
.gitignore
vendored
3
.gitignore
vendored
@@ -320,4 +320,5 @@ cython_debug/
|
||||
/backend/uploads/*
|
||||
/backend/.idea/*
|
||||
/backend/config/*
|
||||
/BiliNote_frontend/.idea/*
|
||||
/BiliNote_frontend/.idea/*
|
||||
/BiliNote_frontend/src-tauri/bin/
|
||||
1
BillNote_frontend/.gitignore
vendored
1
BillNote_frontend/.gitignore
vendored
@@ -23,3 +23,4 @@ dist-ssr
|
||||
*.sln
|
||||
*.sw?
|
||||
/pnpm-lock.yaml
|
||||
/src-tauri/bin/
|
||||
|
||||
@@ -24,6 +24,7 @@
|
||||
"@radix-ui/react-tabs": "^1.1.9",
|
||||
"@radix-ui/react-tooltip": "^1.1.8",
|
||||
"@tailwindcss/vite": "^4.1.3",
|
||||
"@tauri-apps/plugin-shell": "~2.2.2",
|
||||
"@uiw/react-markdown-preview": "^5.1.3",
|
||||
"antd": "^5.24.8",
|
||||
"axios": "^1.8.4",
|
||||
@@ -65,6 +66,7 @@
|
||||
"devDependencies": {
|
||||
"@eslint/js": "^9.21.0",
|
||||
"@tailwindcss/postcss": "^4.1.3",
|
||||
"@tauri-apps/cli": "^2.5.0",
|
||||
"@types/node": "^22.14.0",
|
||||
"@types/react": "^19.0.10",
|
||||
"@types/react-dom": "^19.0.4",
|
||||
|
||||
4
BillNote_frontend/src-tauri/.gitignore
vendored
Normal file
4
BillNote_frontend/src-tauri/.gitignore
vendored
Normal file
@@ -0,0 +1,4 @@
|
||||
# Generated by Cargo
|
||||
# will have compiled files and executables
|
||||
/target/
|
||||
/gen/schemas
|
||||
5027
BillNote_frontend/src-tauri/Cargo.lock
generated
Normal file
5027
BillNote_frontend/src-tauri/Cargo.lock
generated
Normal file
File diff suppressed because it is too large
Load Diff
31
BillNote_frontend/src-tauri/Cargo.toml
Normal file
31
BillNote_frontend/src-tauri/Cargo.toml
Normal file
@@ -0,0 +1,31 @@
|
||||
[package]
|
||||
name = "app"
|
||||
version = "0.1.0"
|
||||
description = "A Tauri App"
|
||||
authors = ["you"]
|
||||
license = ""
|
||||
repository = ""
|
||||
edition = "2021"
|
||||
rust-version = "1.77.2"
|
||||
|
||||
# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html
|
||||
|
||||
[lib]
|
||||
name = "app_lib"
|
||||
crate-type = ["staticlib", "cdylib", "rlib"]
|
||||
|
||||
[build-dependencies]
|
||||
tauri-build = { version = "2.2.0", features = [] }
|
||||
|
||||
[dependencies]
|
||||
serde_json = "1.0"
|
||||
serde = { version = "1.0", features = ["derive"] }
|
||||
log = "0.4"
|
||||
tauri = { version = "2.5.0", features = ["devtools"] }
|
||||
tauri-plugin-log = "2.0.0-rc"
|
||||
tauri-plugin-shell = "2"
|
||||
|
||||
[package.metadata.tauri.bundle.macOS]
|
||||
frameworks = ["bin/BiliNoteBackend/_internal/"]
|
||||
|
||||
|
||||
3
BillNote_frontend/src-tauri/build.rs
Normal file
3
BillNote_frontend/src-tauri/build.rs
Normal file
@@ -0,0 +1,3 @@
|
||||
fn main() {
|
||||
tauri_build::build()
|
||||
}
|
||||
19
BillNote_frontend/src-tauri/capabilities/default.json
Normal file
19
BillNote_frontend/src-tauri/capabilities/default.json
Normal file
@@ -0,0 +1,19 @@
|
||||
{
|
||||
"$schema": "../gen/schemas/desktop-schema.json",
|
||||
"identifier": "default",
|
||||
"description": "enables the default permissions",
|
||||
"windows": ["main"],
|
||||
"permissions": [
|
||||
"core:default",
|
||||
{
|
||||
"identifier": "shell:allow-execute",
|
||||
"allow": [
|
||||
{
|
||||
"name": "BiliNoteBackend",
|
||||
"sidecar": true
|
||||
}
|
||||
]
|
||||
},
|
||||
"shell:allow-open"
|
||||
]
|
||||
}
|
||||
BIN
BillNote_frontend/src-tauri/icons/icon.ico
Normal file
BIN
BillNote_frontend/src-tauri/icons/icon.ico
Normal file
Binary file not shown.
|
After Width: | Height: | Size: 78 KiB |
BIN
BillNote_frontend/src-tauri/icons/icon.png
Normal file
BIN
BillNote_frontend/src-tauri/icons/icon.png
Normal file
Binary file not shown.
|
After Width: | Height: | Size: 15 KiB |
285
BillNote_frontend/src-tauri/src/lib.rs
Normal file
285
BillNote_frontend/src-tauri/src/lib.rs
Normal file
@@ -0,0 +1,285 @@
|
||||
use tauri::{Manager, Emitter};
|
||||
use tauri_plugin_shell::ShellExt;
|
||||
use tauri_plugin_shell::process::CommandEvent;
|
||||
use std::env;
|
||||
use std::collections::HashMap;
|
||||
|
||||
#[cfg_attr(mobile, tauri::mobile_entry_point)]
|
||||
pub fn run() {
|
||||
tauri::Builder::default()
|
||||
.plugin(tauri_plugin_shell::init())
|
||||
.setup(|app| {
|
||||
if cfg!(debug_assertions) {
|
||||
app.handle().plugin(
|
||||
tauri_plugin_log::Builder::default()
|
||||
.level(log::LevelFilter::Info)
|
||||
.build(),
|
||||
)?;
|
||||
}
|
||||
|
||||
let exe_path = env::current_exe().expect("无法获取当前可执行文件路径");
|
||||
let sidecar_dir = exe_path.parent().expect("无法获取可执行文件的父目录");
|
||||
|
||||
// 收集所有系统环境变量
|
||||
let mut all_env_vars = HashMap::new();
|
||||
for (key, value) in env::vars() {
|
||||
all_env_vars.insert(key, value);
|
||||
}
|
||||
|
||||
// 增强 PATH 环境变量,添加常见的二进制路径
|
||||
let current_path = all_env_vars.get("PATH").cloned().unwrap_or_default();
|
||||
let additional_paths = get_additional_binary_paths();
|
||||
let enhanced_path = enhance_path_variable(¤t_path, &additional_paths);
|
||||
all_env_vars.insert("PATH".to_string(), enhanced_path);
|
||||
|
||||
// 打印一些关键环境变量用于调试
|
||||
println!("Enhanced PATH: {}", all_env_vars.get("PATH").unwrap_or(&"Not found".to_string()));
|
||||
println!("Total environment variables: {}", all_env_vars.len());
|
||||
|
||||
// 检查 ffmpeg 是否在 PATH 中可用
|
||||
check_ffmpeg_availability();
|
||||
|
||||
// 启动 Python 后端侧车
|
||||
let mut sidecar_command = app.shell().sidecar("BiliNoteBackend").unwrap();
|
||||
|
||||
// 设置所有环境变量到 sidecar
|
||||
for (key, value) in &all_env_vars {
|
||||
sidecar_command = sidecar_command.env(key, value);
|
||||
}
|
||||
|
||||
let (mut rx, _child) = sidecar_command
|
||||
.current_dir(sidecar_dir)
|
||||
.spawn()
|
||||
.expect("Failed to spawn sidecar");
|
||||
|
||||
// 获取主窗口句柄用于发送事件
|
||||
let window = app.get_webview_window("main").unwrap();
|
||||
|
||||
tauri::async_runtime::spawn(async move {
|
||||
// 读取诸如 stdout 之类的事件
|
||||
while let Some(event) = rx.recv().await {
|
||||
match event {
|
||||
CommandEvent::Stdout(line) => {
|
||||
let output = String::from_utf8_lossy(&line);
|
||||
println!("Backend stdout: {}", output);
|
||||
|
||||
// 发送到前端
|
||||
window
|
||||
.emit("backend-message", Some(format!("'{}'", output)))
|
||||
.expect("failed to emit event");
|
||||
}
|
||||
CommandEvent::Stderr(line) => {
|
||||
let error = String::from_utf8_lossy(&line);
|
||||
eprintln!("Backend stderr: {}", error);
|
||||
|
||||
window
|
||||
.emit("backend-error", Some(format!("'{}'", error)))
|
||||
.expect("failed to emit event");
|
||||
}
|
||||
CommandEvent::Terminated(payload) => {
|
||||
println!("Backend terminated with code: {:?}", payload.code);
|
||||
window
|
||||
.emit("backend-terminated", Some(payload.code))
|
||||
.expect("failed to emit event");
|
||||
break;
|
||||
}
|
||||
_ => {
|
||||
println!("Backend event: {:?}", event);
|
||||
}
|
||||
}
|
||||
}
|
||||
});
|
||||
|
||||
Ok(())
|
||||
})
|
||||
.invoke_handler(tauri::generate_handler![
|
||||
get_system_env_vars,
|
||||
find_executable_path,
|
||||
run_command_with_env,
|
||||
test_ffmpeg_access
|
||||
])
|
||||
.run(tauri::generate_context!())
|
||||
.expect("error while running tauri application");
|
||||
}
|
||||
|
||||
// 获取额外的二进制路径
|
||||
fn get_additional_binary_paths() -> Vec<String> {
|
||||
if cfg!(target_os = "windows") {
|
||||
vec![
|
||||
"C:\\ffmpeg\\bin".to_string(),
|
||||
"C:\\Program Files\\ffmpeg\\bin".to_string(),
|
||||
"C:\\Program Files (x86)\\ffmpeg\\bin".to_string(),
|
||||
"C:\\tools\\ffmpeg\\bin".to_string(),
|
||||
"C:\\ProgramData\\chocolatey\\bin".to_string(),
|
||||
]
|
||||
} else if cfg!(target_os = "macos") {
|
||||
vec![
|
||||
"/usr/local/bin".to_string(),
|
||||
"/opt/homebrew/bin".to_string(),
|
||||
"/usr/bin".to_string(),
|
||||
"/bin".to_string(),
|
||||
"/opt/local/bin".to_string(), // MacPorts
|
||||
]
|
||||
} else {
|
||||
vec![
|
||||
"/usr/local/bin".to_string(),
|
||||
"/usr/bin".to_string(),
|
||||
"/bin".to_string(),
|
||||
"/snap/bin".to_string(),
|
||||
"/opt/bin".to_string(),
|
||||
"/usr/local/sbin".to_string(),
|
||||
]
|
||||
}
|
||||
}
|
||||
|
||||
// 增强 PATH 环境变量
|
||||
fn enhance_path_variable(current_path: &str, additional_paths: &[String]) -> String {
|
||||
let path_separator = if cfg!(target_os = "windows") { ";" } else { ":" };
|
||||
|
||||
let mut paths: Vec<String> = additional_paths.to_vec();
|
||||
|
||||
// 添加当前 PATH
|
||||
if !current_path.is_empty() {
|
||||
paths.push(current_path.to_string());
|
||||
}
|
||||
|
||||
paths.join(path_separator)
|
||||
}
|
||||
|
||||
// 检查 ffmpeg 可用性
|
||||
fn check_ffmpeg_availability() {
|
||||
use std::process::Command;
|
||||
|
||||
match Command::new("ffmpeg").arg("-version").output() {
|
||||
Ok(output) => {
|
||||
if output.status.success() {
|
||||
println!("✓ FFmpeg is available in PATH");
|
||||
let version_info = String::from_utf8_lossy(&output.stdout);
|
||||
let first_line = version_info.lines().next().unwrap_or("Unknown version");
|
||||
println!("FFmpeg version: {}", first_line);
|
||||
} else {
|
||||
println!("✗ FFmpeg found but returned error");
|
||||
}
|
||||
}
|
||||
Err(e) => {
|
||||
println!("✗ FFmpeg not found in PATH: {}", e);
|
||||
|
||||
// 尝试在常见路径中查找
|
||||
let common_paths = get_additional_binary_paths();
|
||||
for path in common_paths {
|
||||
let ffmpeg_path = if cfg!(target_os = "windows") {
|
||||
format!("{}\\ffmpeg.exe", path)
|
||||
} else {
|
||||
format!("{}/ffmpeg", path)
|
||||
};
|
||||
|
||||
if std::path::Path::new(&ffmpeg_path).exists() {
|
||||
println!("✓ Found FFmpeg at: {}", ffmpeg_path);
|
||||
return;
|
||||
}
|
||||
}
|
||||
println!("✗ FFmpeg not found in common installation paths");
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Tauri 命令:获取系统环境变量
|
||||
#[tauri::command]
|
||||
fn get_system_env_vars() -> HashMap<String, String> {
|
||||
env::vars().collect()
|
||||
}
|
||||
|
||||
// Tauri 命令:查找可执行文件路径
|
||||
#[tauri::command]
|
||||
fn find_executable_path(executable_name: String) -> Option<String> {
|
||||
use std::process::Command;
|
||||
|
||||
// 首先尝试直接执行
|
||||
if Command::new(&executable_name).arg("--version").output().is_ok() {
|
||||
return Some(executable_name);
|
||||
}
|
||||
|
||||
// 使用 which/where 命令查找
|
||||
let which_cmd = if cfg!(target_os = "windows") { "where" } else { "which" };
|
||||
|
||||
if let Ok(output) = Command::new(which_cmd).arg(&executable_name).output() {
|
||||
if output.status.success() {
|
||||
let path = String::from_utf8_lossy(&output.stdout).trim().to_string();
|
||||
if !path.is_empty() {
|
||||
return Some(path);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// 在常见路径中搜索
|
||||
let common_paths = get_additional_binary_paths();
|
||||
for base_path in common_paths {
|
||||
let executable_path = if cfg!(target_os = "windows") {
|
||||
format!("{}\\{}.exe", base_path, executable_name)
|
||||
} else {
|
||||
format!("{}/{}", base_path, executable_name)
|
||||
};
|
||||
|
||||
if std::path::Path::new(&executable_path).exists() {
|
||||
return Some(executable_path);
|
||||
}
|
||||
}
|
||||
|
||||
None
|
||||
}
|
||||
|
||||
// Tauri 命令:使用完整环境变量运行命令
|
||||
#[tauri::command]
|
||||
async fn run_command_with_env(
|
||||
program: String,
|
||||
args: Vec<String>
|
||||
) -> Result<String, String> {
|
||||
use std::process::Command;
|
||||
|
||||
let mut cmd = Command::new(&program);
|
||||
cmd.args(&args);
|
||||
|
||||
// 设置所有环境变量
|
||||
for (key, value) in env::vars() {
|
||||
cmd.env(key, value);
|
||||
}
|
||||
|
||||
// 增强 PATH
|
||||
let current_path = env::var("PATH").unwrap_or_default();
|
||||
let additional_paths = get_additional_binary_paths();
|
||||
let enhanced_path = enhance_path_variable(¤t_path, &additional_paths);
|
||||
cmd.env("PATH", enhanced_path);
|
||||
|
||||
match cmd.output() {
|
||||
Ok(output) => {
|
||||
if output.status.success() {
|
||||
Ok(String::from_utf8_lossy(&output.stdout).to_string())
|
||||
} else {
|
||||
Err(String::from_utf8_lossy(&output.stderr).to_string())
|
||||
}
|
||||
}
|
||||
Err(e) => Err(format!("Failed to execute {}: {}", program, e))
|
||||
}
|
||||
}
|
||||
|
||||
// Tauri 命令:测试 ffmpeg 访问
|
||||
#[tauri::command]
|
||||
async fn test_ffmpeg_access() -> Result<String, String> {
|
||||
run_command_with_env("ffmpeg".to_string(), vec!["-version".to_string()]).await
|
||||
}
|
||||
|
||||
// 可选:添加一个函数来动态更新 sidecar 的环境变量
|
||||
#[tauri::command]
|
||||
async fn update_sidecar_environment(
|
||||
app_handle: tauri::AppHandle,
|
||||
additional_env_vars: HashMap<String, String>
|
||||
) -> Result<(), String> {
|
||||
// 这个函数可以用来在运行时更新环境变量
|
||||
// 注意:这需要重启 sidecar 才能生效
|
||||
|
||||
for (key, value) in additional_env_vars {
|
||||
env::set_var(key, value);
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
6
BillNote_frontend/src-tauri/src/main.rs
Normal file
6
BillNote_frontend/src-tauri/src/main.rs
Normal file
@@ -0,0 +1,6 @@
|
||||
// Prevents additional console window on Windows in release, DO NOT REMOVE!!
|
||||
#![cfg_attr(not(debug_assertions), windows_subsystem = "windows")]
|
||||
|
||||
fn main() {
|
||||
app_lib::run();
|
||||
}
|
||||
46
BillNote_frontend/src-tauri/tauri.conf.json
Normal file
46
BillNote_frontend/src-tauri/tauri.conf.json
Normal file
@@ -0,0 +1,46 @@
|
||||
{
|
||||
"$schema": "../node_modules/@tauri-apps/cli/config.schema.json",
|
||||
"productName": "BiliNote",
|
||||
"version": "0.1.0",
|
||||
"identifier": "com.jefferyhuang.bilinote",
|
||||
"build": {
|
||||
"frontendDist": "../dist",
|
||||
"devUrl": "http://localhost:3015",
|
||||
"beforeDevCommand": "pnpm dev",
|
||||
"beforeBuildCommand": "pnpm build"
|
||||
},
|
||||
"app": {
|
||||
"windows": [
|
||||
{
|
||||
"title": "BiliNote",
|
||||
"width": 1400,
|
||||
"height": 900,
|
||||
"resizable": true,
|
||||
"fullscreen": false,
|
||||
"devtools": true
|
||||
}
|
||||
],
|
||||
"security": {
|
||||
"csp": null
|
||||
}
|
||||
},
|
||||
"bundle": {
|
||||
"externalBin": [
|
||||
"bin/BiliNoteBackend/BiliNoteBackend"
|
||||
],
|
||||
"resources": {
|
||||
"bin/BiliNoteBackend/_internal":"_internal"
|
||||
},
|
||||
"macOS":{
|
||||
"files": {
|
||||
"Frameworks": "bin/BiliNoteBackend/_internal"
|
||||
}
|
||||
},
|
||||
"active": true,
|
||||
"targets": "all",
|
||||
"icon": [
|
||||
"icons/icon.ico",
|
||||
"icons/icon.png"
|
||||
]
|
||||
}
|
||||
}
|
||||
@@ -3,7 +3,6 @@ import { createRoot } from 'react-dom/client'
|
||||
import './index.css'
|
||||
import App from './App.tsx'
|
||||
import RootLayout from './layouts/RootLayout.tsx'
|
||||
|
||||
createRoot(document.getElementById('root')!).render(
|
||||
<StrictMode>
|
||||
<RootLayout>
|
||||
|
||||
@@ -38,6 +38,7 @@ import { Input } from '@/components/ui/input.tsx'
|
||||
import { Textarea } from '@/components/ui/textarea.tsx'
|
||||
import { noteStyles, noteFormats, videoPlatforms } from '@/constant/note.ts'
|
||||
import { fetchModels } from '@/services/model.ts'
|
||||
import { useNavigate } from 'react-router-dom'
|
||||
|
||||
/* -------------------- 校验 Schema -------------------- */
|
||||
const formSchema = z
|
||||
@@ -119,7 +120,7 @@ const CheckboxGroup = ({
|
||||
|
||||
/* -------------------- 主组件 -------------------- */
|
||||
const NoteForm = () => {
|
||||
|
||||
const navigate = useNavigate();
|
||||
const [isUploading, setIsUploading] = useState(false)
|
||||
const [uploadSuccess, setUploadSuccess] = useState(false)
|
||||
/* ---- 全局状态 ---- */
|
||||
@@ -147,6 +148,9 @@ const NoteForm = () => {
|
||||
const videoUnderstandingEnabled = useWatch({ control: form.control, name: 'video_understanding' })
|
||||
const editing = currentTask && currentTask.id
|
||||
|
||||
const goModelAdd = () => {
|
||||
navigate("/settings/model");
|
||||
};
|
||||
/* ---- 副作用 ---- */
|
||||
useEffect(() => {
|
||||
loadEnabledModels()
|
||||
@@ -192,6 +196,7 @@ const NoteForm = () => {
|
||||
setUploadSuccess(false)
|
||||
|
||||
try {
|
||||
|
||||
const data = await uploadFile(formData)
|
||||
cb(data.url)
|
||||
setUploadSuccess(true)
|
||||
@@ -363,38 +368,50 @@ const NoteForm = () => {
|
||||
/>
|
||||
<div className="grid grid-cols-2 gap-2">
|
||||
{/* 模型选择 */}
|
||||
<FormField
|
||||
className="w-full"
|
||||
control={form.control}
|
||||
name="model_name"
|
||||
render={({ field }) => (
|
||||
<FormItem>
|
||||
<SectionHeader title="模型选择" tip="不同模型效果不同,建议自行测试" />
|
||||
<Select
|
||||
onOpenChange={()=>{
|
||||
loadEnabledModels()
|
||||
}}
|
||||
value={field.value}
|
||||
onValueChange={field.onChange}
|
||||
defaultValue={field.value}
|
||||
>
|
||||
<FormControl>
|
||||
<SelectTrigger className="w-full min-w-0 truncate">
|
||||
<SelectValue />
|
||||
</SelectTrigger>
|
||||
</FormControl>
|
||||
<SelectContent>
|
||||
{modelList.map(m => (
|
||||
<SelectItem key={m.id} value={m.model_name}>
|
||||
{m.model_name}
|
||||
</SelectItem>
|
||||
))}
|
||||
</SelectContent>
|
||||
</Select>
|
||||
<FormMessage />
|
||||
</FormItem>
|
||||
)}
|
||||
/>
|
||||
{
|
||||
|
||||
modelList.length>0?( <FormField
|
||||
className="w-full"
|
||||
control={form.control}
|
||||
name="model_name"
|
||||
render={({ field }) => (
|
||||
<FormItem>
|
||||
<SectionHeader title="模型选择" tip="不同模型效果不同,建议自行测试" />
|
||||
<Select
|
||||
onOpenChange={()=>{
|
||||
loadEnabledModels()
|
||||
}}
|
||||
value={field.value}
|
||||
onValueChange={field.onChange}
|
||||
defaultValue={field.value}
|
||||
>
|
||||
<FormControl>
|
||||
<SelectTrigger className="w-full min-w-0 truncate">
|
||||
<SelectValue />
|
||||
</SelectTrigger>
|
||||
</FormControl>
|
||||
<SelectContent>
|
||||
{modelList.map(m => (
|
||||
<SelectItem key={m.id} value={m.model_name}>
|
||||
{m.model_name}
|
||||
</SelectItem>
|
||||
))}
|
||||
</SelectContent>
|
||||
</Select>
|
||||
<FormMessage />
|
||||
</FormItem>
|
||||
)}
|
||||
/>): (
|
||||
<FormItem>
|
||||
<SectionHeader title="模型选择" tip="不同模型效果不同,建议自行测试" />
|
||||
<Button type={'button'} variant={
|
||||
'outline'
|
||||
} onClick={()=>{goModelAdd()}}>请先添加模型</Button>
|
||||
<FormMessage />
|
||||
</FormItem>
|
||||
)
|
||||
}
|
||||
|
||||
{/* 笔记风格 */}
|
||||
<FormField
|
||||
className="w-full"
|
||||
|
||||
@@ -49,13 +49,9 @@ export const delete_task = async ({ video_id, platform }) => {
|
||||
}
|
||||
const res = await request.post('/delete_task', data)
|
||||
|
||||
if (res.data.code === 0) {
|
||||
|
||||
toast.success('任务已成功删除')
|
||||
return res.data
|
||||
} else {
|
||||
toast.error(res.data.message || '删除失败')
|
||||
throw new Error(res.data.message || '删除失败')
|
||||
}
|
||||
return res
|
||||
} catch (e) {
|
||||
toast.error('请求异常,删除任务失败')
|
||||
console.error('❌ 删除任务失败:', e)
|
||||
|
||||
@@ -4,8 +4,8 @@ from .routers import note, provider, model, config
|
||||
|
||||
|
||||
|
||||
def create_app() -> FastAPI:
|
||||
app = FastAPI(title="BiliNote")
|
||||
def create_app(lifespan) -> FastAPI:
|
||||
app = FastAPI(title="BiliNote",lifespan=lifespan)
|
||||
app.include_router(note.router, prefix="/api")
|
||||
app.include_router(provider.router, prefix="/api")
|
||||
app.include_router(model.router,prefix="/api")
|
||||
|
||||
36
backend/app/db/engine.py
Normal file
36
backend/app/db/engine.py
Normal file
@@ -0,0 +1,36 @@
|
||||
import os
|
||||
from sqlalchemy import create_engine
|
||||
from sqlalchemy.orm import sessionmaker, declarative_base
|
||||
from dotenv import load_dotenv
|
||||
|
||||
load_dotenv()
|
||||
|
||||
# 默认 SQLite,如果想换 PostgreSQL 或 MySQL,可以直接改 .env
|
||||
DATABASE_URL = os.getenv("DATABASE_URL", "sqlite:///bili_note.db")
|
||||
|
||||
# SQLite 需要特定连接参数,其他数据库不需要
|
||||
engine_args = {}
|
||||
if DATABASE_URL.startswith("sqlite"):
|
||||
engine_args["connect_args"] = {"check_same_thread": False}
|
||||
|
||||
engine = create_engine(
|
||||
DATABASE_URL,
|
||||
echo=os.getenv("SQLALCHEMY_ECHO", "false").lower() == "true",
|
||||
**engine_args
|
||||
)
|
||||
|
||||
SessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine)
|
||||
|
||||
Base = declarative_base()
|
||||
|
||||
|
||||
def get_engine():
|
||||
return engine
|
||||
|
||||
|
||||
def get_db():
|
||||
db = SessionLocal()
|
||||
try:
|
||||
yield db
|
||||
finally:
|
||||
db.close()
|
||||
9
backend/app/db/init_db.py
Normal file
9
backend/app/db/init_db.py
Normal file
@@ -0,0 +1,9 @@
|
||||
from app.db.models.models import Model
|
||||
from app.db.models.providers import Provider
|
||||
from app.db.models.video_tasks import VideoTask
|
||||
from app.db.engine import get_engine, Base
|
||||
|
||||
def init_db():
|
||||
engine = get_engine()
|
||||
|
||||
Base.metadata.create_all(bind=engine)
|
||||
@@ -1,67 +1,67 @@
|
||||
from app.db.sqlite_client import get_connection
|
||||
|
||||
def init_model_table():
|
||||
conn = get_connection()
|
||||
cursor = conn.cursor()
|
||||
cursor.execute("""
|
||||
CREATE TABLE IF NOT EXISTS models (
|
||||
id INTEGER PRIMARY KEY AUTOINCREMENT,
|
||||
provider_id INTEGER NOT NULL,
|
||||
model_name TEXT NOT NULL,
|
||||
created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP
|
||||
)
|
||||
""")
|
||||
conn.commit()
|
||||
conn.close()
|
||||
from app.db.engine import get_db
|
||||
from app.db.models.models import Model
|
||||
|
||||
|
||||
def get_model_by_provider_and_name(provider_id: int, model_name: str):
|
||||
conn = get_connection()
|
||||
cursor = conn.execute(
|
||||
"SELECT * FROM models WHERE provider_id = ? AND model_name = ?",
|
||||
(provider_id, model_name)
|
||||
)
|
||||
row = cursor.fetchone()
|
||||
return row
|
||||
# 插入模型
|
||||
db = next(get_db())
|
||||
try:
|
||||
model = db.query(Model).filter_by(provider_id=provider_id, model_name=model_name).first()
|
||||
if model:
|
||||
return {
|
||||
"id": model.id,
|
||||
"provider_id": model.provider_id,
|
||||
"model_name": model.model_name,
|
||||
"created_at": model.created_at,
|
||||
}
|
||||
return None
|
||||
finally:
|
||||
db.close()
|
||||
|
||||
|
||||
def insert_model(provider_id: int, model_name: str):
|
||||
conn = get_connection()
|
||||
cursor = conn.cursor()
|
||||
cursor.execute("""
|
||||
INSERT INTO models (provider_id, model_name)
|
||||
VALUES (?, ?)
|
||||
""", (provider_id, model_name))
|
||||
conn.commit()
|
||||
conn.close()
|
||||
db = next(get_db())
|
||||
try:
|
||||
model = Model(provider_id=provider_id, model_name=model_name)
|
||||
db.add(model)
|
||||
db.commit()
|
||||
db.refresh(model)
|
||||
return {
|
||||
"id": model.id,
|
||||
"provider_id": model.provider_id,
|
||||
"model_name": model.model_name,
|
||||
"created_at": model.created_at,
|
||||
}
|
||||
finally:
|
||||
db.close()
|
||||
|
||||
|
||||
# 根据provider查模型
|
||||
def get_models_by_provider(provider_id: int):
|
||||
conn = get_connection()
|
||||
cursor = conn.cursor()
|
||||
cursor.execute("""
|
||||
SELECT id, model_name FROM models
|
||||
WHERE provider_id = ?
|
||||
""", (provider_id,))
|
||||
rows = cursor.fetchall()
|
||||
conn.close()
|
||||
return [{"id": row[0], "model_name": row[1]} for row in rows]
|
||||
db = next(get_db())
|
||||
try:
|
||||
models = db.query(Model).filter_by(provider_id=provider_id).all()
|
||||
return [{"id": m.id, "model_name": m.model_name} for m in models]
|
||||
finally:
|
||||
db.close()
|
||||
|
||||
|
||||
# 删除某个模型
|
||||
def delete_model(model_id: int):
|
||||
conn = get_connection()
|
||||
cursor = conn.cursor()
|
||||
cursor.execute("""
|
||||
DELETE FROM models WHERE id = ?
|
||||
""", (model_id,))
|
||||
conn.commit()
|
||||
conn.close()
|
||||
db = next(get_db())
|
||||
try:
|
||||
model = db.query(Model).filter_by(id=model_id).first()
|
||||
if model:
|
||||
db.delete(model)
|
||||
db.commit()
|
||||
finally:
|
||||
db.close()
|
||||
|
||||
|
||||
def get_all_models():
|
||||
conn = get_connection()
|
||||
cursor = conn.cursor()
|
||||
cursor.execute("""
|
||||
SELECT id, provider_id, model_name FROM models
|
||||
""")
|
||||
rows = cursor.fetchall()
|
||||
conn.close()
|
||||
return [{"id": row[0], "provider_id": row[1], "model_name": row[2]} for row in rows]
|
||||
db = next(get_db())
|
||||
try:
|
||||
models = db.query(Model).all()
|
||||
return [
|
||||
{"id": m.id, "provider_id": m.provider_id, "model_name": m.model_name}
|
||||
for m in models
|
||||
]
|
||||
finally:
|
||||
db.close()
|
||||
0
backend/app/db/models/__init__.py
Normal file
0
backend/app/db/models/__init__.py
Normal file
12
backend/app/db/models/models.py
Normal file
12
backend/app/db/models/models.py
Normal file
@@ -0,0 +1,12 @@
|
||||
from sqlalchemy import Column, Integer, String, DateTime, func, ForeignKey
|
||||
|
||||
from app.db.engine import Base
|
||||
|
||||
|
||||
class Model(Base):
|
||||
__tablename__ = "models"
|
||||
|
||||
id = Column(Integer, primary_key=True, autoincrement=True)
|
||||
provider_id = Column(Integer, nullable=False)
|
||||
model_name = Column(String, nullable=False)
|
||||
created_at = Column(DateTime, server_default=func.now())
|
||||
17
backend/app/db/models/providers.py
Normal file
17
backend/app/db/models/providers.py
Normal file
@@ -0,0 +1,17 @@
|
||||
from sqlalchemy import Column, String, Integer, DateTime, func
|
||||
from sqlalchemy.orm import declarative_base
|
||||
|
||||
from app.db.engine import Base
|
||||
|
||||
|
||||
class Provider(Base):
|
||||
__tablename__ = "providers"
|
||||
|
||||
id = Column(String, primary_key=True)
|
||||
name = Column(String, nullable=False)
|
||||
logo = Column(String, nullable=False)
|
||||
type = Column(String, nullable=False)
|
||||
api_key = Column(String, nullable=False)
|
||||
base_url = Column(String, nullable=False)
|
||||
enabled = Column(Integer, default=1)
|
||||
created_at = Column(DateTime, server_default=func.now())
|
||||
14
backend/app/db/models/video_tasks.py
Normal file
14
backend/app/db/models/video_tasks.py
Normal file
@@ -0,0 +1,14 @@
|
||||
from sqlalchemy import Column, Integer, String, DateTime, func
|
||||
from sqlalchemy.orm import declarative_base
|
||||
|
||||
from app.db.engine import Base
|
||||
|
||||
|
||||
class VideoTask(Base):
|
||||
__tablename__ = "video_tasks"
|
||||
|
||||
id = Column(Integer, primary_key=True, autoincrement=True)
|
||||
video_id = Column(String, nullable=False)
|
||||
platform = Column(String, nullable=False)
|
||||
task_id = Column(String, unique=True, nullable=False)
|
||||
created_at = Column(DateTime, server_default=func.now())
|
||||
@@ -1,14 +1,13 @@
|
||||
import json
|
||||
import os
|
||||
import sys
|
||||
|
||||
from app.db.sqlite_client import get_connection
|
||||
from app.db.models.providers import Provider
|
||||
from app.utils.logger import get_logger
|
||||
from app.db.engine import get_engine, Base, get_db
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
|
||||
def get_builtin_providers_path():
|
||||
if getattr(sys, 'frozen', False):
|
||||
base_path = sys._MEIPASS
|
||||
@@ -16,213 +15,115 @@ def get_builtin_providers_path():
|
||||
base_path = os.path.dirname(__file__)
|
||||
return os.path.join(base_path, 'builtin_providers.json')
|
||||
|
||||
|
||||
def seed_default_providers():
|
||||
conn = get_connection()
|
||||
if conn is None:
|
||||
logger.error("Failed to connect to database.")
|
||||
return
|
||||
|
||||
cursor = conn.cursor()
|
||||
|
||||
# 检查已有数据
|
||||
cursor.execute("SELECT COUNT(*) FROM providers")
|
||||
count = cursor.fetchone()[0]
|
||||
if count > 0:
|
||||
logger.info("Providers already exist, skipping seed.")
|
||||
conn.close()
|
||||
return
|
||||
|
||||
json_path = get_builtin_providers_path()
|
||||
db = next(get_db())
|
||||
try:
|
||||
with open(json_path, 'r', encoding='utf-8') as f:
|
||||
providers = json.load(f)
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to read builtin_providers.json: {e}")
|
||||
conn.close()
|
||||
return
|
||||
if db.query(Provider).count() > 0:
|
||||
logger.info("Providers already exist, skipping seed.")
|
||||
return
|
||||
|
||||
json_path = get_builtin_providers_path()
|
||||
try:
|
||||
with open(json_path, 'r', encoding='utf-8') as f:
|
||||
providers = json.load(f)
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to read builtin_providers.json: {e}")
|
||||
return
|
||||
|
||||
try:
|
||||
for p in providers:
|
||||
cursor.execute("""
|
||||
INSERT INTO providers (id, name, api_key, base_url, logo, type, enabled)
|
||||
VALUES (?, ?, ?, ?, ?, ?, ?)
|
||||
""", (
|
||||
p['id'],
|
||||
p['name'],
|
||||
p['api_key'],
|
||||
p['base_url'],
|
||||
p['logo'],
|
||||
p['type'],
|
||||
p.get('enabled', 1)
|
||||
db.add(Provider(
|
||||
id=p['id'],
|
||||
name=p['name'],
|
||||
api_key=p['api_key'],
|
||||
base_url=p['base_url'],
|
||||
logo=p['logo'],
|
||||
type=p['type'],
|
||||
enabled=p.get('enabled', 1)
|
||||
))
|
||||
conn.commit()
|
||||
db.commit()
|
||||
logger.info("Default providers seeded successfully.")
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to seed default providers: {e}")
|
||||
finally:
|
||||
conn.close()
|
||||
def init_provider_table():
|
||||
conn = get_connection()
|
||||
if conn is None:
|
||||
logger.error("Failed to connect to the database.")
|
||||
return
|
||||
cursor = conn.cursor()
|
||||
cursor.execute("""
|
||||
CREATE TABLE IF NOT EXISTS providers (
|
||||
id TEXT PRIMARY KEY,
|
||||
name TEXT NOT NULL,
|
||||
logo TEXT NOT NULL,
|
||||
type TEXT NOT NULL,
|
||||
api_key TEXT NOT NULL,
|
||||
base_url TEXT NOT NULL,
|
||||
enabled INTEGER DEFAULT 1,
|
||||
created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP
|
||||
)
|
||||
""")
|
||||
db.close()
|
||||
|
||||
|
||||
def insert_provider(id: str, name: str, api_key: str, base_url: str, logo: str, type_: str, enabled: int = 1):
|
||||
db = next(get_db())
|
||||
try:
|
||||
conn.commit()
|
||||
conn.close()
|
||||
logger.info("provider table created successfully.")
|
||||
seed_default_providers()
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to create provider table: {e}")
|
||||
def insert_provider(id: str, name: str, api_key: str, base_url: str, logo: str, type_: str,enabled:int=1):
|
||||
conn = get_connection()
|
||||
if conn is None:
|
||||
logger.error("Failed to connect to the database.")
|
||||
return
|
||||
cursor = conn.cursor()
|
||||
cursor.execute("""
|
||||
INSERT INTO providers (id, name, api_key, base_url, logo, type, enabled)
|
||||
VALUES (?, ?, ?, ?, ?, ?, ?)
|
||||
""", (id, name, api_key, base_url, logo, type_, enabled))
|
||||
try:
|
||||
conn.commit()
|
||||
conn.close()
|
||||
provider = Provider(id=id, name=name, api_key=api_key, base_url=base_url, logo=logo, type=type_, enabled=enabled)
|
||||
db.add(provider)
|
||||
db.commit()
|
||||
logger.info(f"Provider inserted successfully. id: {id}, name: {name}, type: {type_}")
|
||||
return id
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to insert provider: {e}")
|
||||
return None
|
||||
finally:
|
||||
db.close()
|
||||
|
||||
|
||||
def get_enabled_providers():
|
||||
conn = get_connection()
|
||||
if conn is None:
|
||||
logger.error("Failed to connect to the database.")
|
||||
return
|
||||
cursor = conn.cursor()
|
||||
cursor.execute("SELECT * FROM providers WHERE enabled = 1")
|
||||
db = next(get_db())
|
||||
try:
|
||||
rows = cursor.fetchall()
|
||||
conn.close()
|
||||
if rows is None:
|
||||
logger.info("No providers found")
|
||||
return None
|
||||
logger.info(f"Providers found: {rows}")
|
||||
return rows
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to get enabled providers: {e}")
|
||||
return db.query(Provider).filter_by(enabled=1).all()
|
||||
finally:
|
||||
db.close()
|
||||
|
||||
|
||||
def get_provider_by_name(name: str):
|
||||
conn = get_connection()
|
||||
if conn is None:
|
||||
logger.error("Failed to connect to the database.")
|
||||
return
|
||||
cursor = conn.cursor()
|
||||
cursor.execute("SELECT * FROM providers WHERE name = ?", (name,))
|
||||
db = next(get_db())
|
||||
try:
|
||||
row = cursor.fetchone()
|
||||
conn.close()
|
||||
if row is None:
|
||||
logger.info(f"Provider not found: {name}")
|
||||
return None
|
||||
logger.info(f"Provider found: {row[0]}")
|
||||
return db.query(Provider).filter_by(name=name).first()
|
||||
finally:
|
||||
db.close()
|
||||
|
||||
return row
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to get provider by name: {e}")
|
||||
|
||||
def get_provider_by_id(id: int):
|
||||
conn = get_connection()
|
||||
if conn is None:
|
||||
logger.error("Failed to connect to the database.")
|
||||
return
|
||||
cursor = conn.cursor()
|
||||
cursor.execute("SELECT * FROM providers WHERE id = ?", (id,))
|
||||
|
||||
def get_provider_by_id(id: str):
|
||||
db = next(get_db())
|
||||
try:
|
||||
row = cursor.fetchone()
|
||||
conn.close()
|
||||
if row is None:
|
||||
logger.info(f"Provider not found: {id}")
|
||||
return None
|
||||
logger.info(f"Provider found: {row[0]}")
|
||||
return row
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to get provider by id: {e}")
|
||||
return db.query(Provider).filter_by(id=id).first()
|
||||
finally:
|
||||
db.close()
|
||||
|
||||
|
||||
def get_all_providers():
|
||||
conn = get_connection()
|
||||
if conn is None:
|
||||
logger.error("Failed to connect to the database.")
|
||||
return
|
||||
cursor = conn.cursor()
|
||||
cursor.execute("SELECT * FROM providers")
|
||||
db = next(get_db())
|
||||
try:
|
||||
rows = cursor.fetchall()
|
||||
conn.close()
|
||||
if rows is None:
|
||||
logger.info("No providers found")
|
||||
return None
|
||||
logger.info(f"Providers found total {len(rows) }")
|
||||
return rows
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to get all providers: {e}")
|
||||
return db.query(Provider).all()
|
||||
finally:
|
||||
db.close()
|
||||
|
||||
|
||||
def update_provider(id: str, **kwargs):
|
||||
conn = get_connection()
|
||||
if conn is None:
|
||||
logger.error("Failed to connect to the database.")
|
||||
return
|
||||
|
||||
fields = []
|
||||
values = []
|
||||
|
||||
for key, value in kwargs.items():
|
||||
fields.append(f"{key} = ?")
|
||||
values.append(value)
|
||||
|
||||
if not fields:
|
||||
logger.warning("No fields provided for update.")
|
||||
return
|
||||
|
||||
sql = f"""
|
||||
UPDATE providers
|
||||
SET {', '.join(fields)}
|
||||
WHERE id = ?
|
||||
"""
|
||||
|
||||
values.append(id) # id 最后加
|
||||
cursor = conn.cursor()
|
||||
|
||||
db = next(get_db())
|
||||
try:
|
||||
cursor.execute(sql, values)
|
||||
conn.commit()
|
||||
conn.close()
|
||||
logger.info(f"Provider updated successfully. id: {id}, updated_fields: {fields}")
|
||||
provider = db.query(Provider).filter_by(id=id).first()
|
||||
if not provider:
|
||||
logger.warning(f"Provider {id} not found for update.")
|
||||
return
|
||||
|
||||
for key, value in kwargs.items():
|
||||
if hasattr(provider, key):
|
||||
setattr(provider, key, value)
|
||||
|
||||
db.commit()
|
||||
logger.info(f"Provider updated successfully. id: {id}, updated_fields: {list(kwargs.keys())}")
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to update provider: {e}")
|
||||
finally:
|
||||
db.close()
|
||||
|
||||
def delete_provider(id: int):
|
||||
conn = get_connection()
|
||||
if conn is None:
|
||||
logger.error("Failed to connect to the database.")
|
||||
return
|
||||
cursor = conn.cursor()
|
||||
cursor.execute("DELETE FROM providers WHERE id = ?", (id,))
|
||||
|
||||
def delete_provider(id: str):
|
||||
db = next(get_db())
|
||||
try:
|
||||
conn.commit()
|
||||
conn.close()
|
||||
logger.info(f"Provider deleted successfully. id: {id}")
|
||||
provider = db.query(Provider).filter_by(id=id).first()
|
||||
if provider:
|
||||
db.delete(provider)
|
||||
db.commit()
|
||||
logger.info(f"Provider deleted successfully. id: {id}")
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to delete provider: {e}")
|
||||
logger.error(f"Failed to delete provider: {e}")
|
||||
finally:
|
||||
db.close()
|
||||
@@ -1,78 +1,61 @@
|
||||
from .sqlite_client import get_connection
|
||||
from app.db.models.video_tasks import VideoTask
|
||||
from app.db.engine import get_db
|
||||
from app.utils.logger import get_logger
|
||||
|
||||
logger = get_logger(__name__)
|
||||
def init_video_task_table():
|
||||
conn = get_connection()
|
||||
if conn is None:
|
||||
logger.error("Failed to connect to the database.")
|
||||
return
|
||||
cursor = conn.cursor()
|
||||
|
||||
cursor.execute("""
|
||||
CREATE TABLE IF NOT EXISTS video_tasks (
|
||||
id INTEGER PRIMARY KEY AUTOINCREMENT,
|
||||
video_id TEXT NOT NULL,
|
||||
platform TEXT NOT NULL,
|
||||
task_id TEXT NOT NULL UNIQUE,
|
||||
created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP
|
||||
)
|
||||
""")
|
||||
try:
|
||||
conn.commit()
|
||||
conn.close()
|
||||
logger.info("video_tasks table created successfully.")
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to create video_tasks table: {e}")
|
||||
|
||||
# 插入任务
|
||||
def insert_video_task(video_id: str, platform: str, task_id: str):
|
||||
db = next(get_db())
|
||||
try:
|
||||
conn = get_connection()
|
||||
cursor = conn.cursor()
|
||||
cursor.execute("""
|
||||
INSERT INTO video_tasks (video_id, platform, task_id)
|
||||
VALUES (?, ?, ?)
|
||||
""", (video_id, platform, task_id))
|
||||
conn.commit()
|
||||
conn.close()
|
||||
logger.info(f"Video task inserted successfully."
|
||||
f"video_id: {video_id}"
|
||||
f"platform: {platform}"
|
||||
f"task_id: {task_id}")
|
||||
task = VideoTask(video_id=video_id, platform=platform, task_id=task_id)
|
||||
db.add(task)
|
||||
db.commit()
|
||||
db.refresh(task)
|
||||
logger.info(f"Video task inserted successfully. video_id: {video_id}, platform: {platform}, task_id: {task_id}")
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to insert video task: {e}")
|
||||
finally:
|
||||
db.close()
|
||||
|
||||
|
||||
# 查询任务(最新一条)
|
||||
def get_task_by_video(video_id: str, platform: str):
|
||||
db = next(get_db())
|
||||
try:
|
||||
conn = get_connection()
|
||||
cursor = conn.cursor()
|
||||
cursor.execute("""
|
||||
SELECT task_id FROM video_tasks
|
||||
WHERE video_id = ? AND platform = ?
|
||||
ORDER BY created_at DESC
|
||||
LIMIT 1
|
||||
""", (video_id, platform))
|
||||
result = cursor.fetchone()
|
||||
conn.close()
|
||||
if result is None:
|
||||
task = (
|
||||
db.query(VideoTask)
|
||||
.filter_by(video_id=video_id, platform=platform)
|
||||
.order_by(VideoTask.created_at.desc())
|
||||
.first()
|
||||
)
|
||||
if task:
|
||||
logger.info(f"Task found for video_id: {video_id} and platform: {platform}")
|
||||
return task.task_id
|
||||
else:
|
||||
logger.info(f"No task found for video_id: {video_id} and platform: {platform}")
|
||||
logger.info(f"Task found for video_id: {video_id} and platform: {platform}")
|
||||
return result[0] if result else None
|
||||
return None
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to get task by video: {e}")
|
||||
finally:
|
||||
db.close()
|
||||
|
||||
|
||||
# 删除任务
|
||||
def delete_task_by_video(video_id: str, platform: str):
|
||||
db = next(get_db())
|
||||
try:
|
||||
conn = get_connection()
|
||||
cursor = conn.cursor()
|
||||
cursor.execute("""
|
||||
DELETE FROM video_tasks
|
||||
WHERE video_id = ? AND platform = ?
|
||||
""", (video_id, platform))
|
||||
|
||||
conn.commit()
|
||||
conn.close()
|
||||
logger.info(f"Task deleted for video_id: {video_id} and platform: {platform}")
|
||||
tasks = (
|
||||
db.query(VideoTask)
|
||||
.filter_by(video_id=video_id, platform=platform)
|
||||
.all()
|
||||
)
|
||||
for task in tasks:
|
||||
db.delete(task)
|
||||
db.commit()
|
||||
logger.info(f"Task(s) deleted for video_id: {video_id} and platform: {platform}")
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to delete task by video: {e}")
|
||||
logger.error(f"Failed to delete task by video: {e}")
|
||||
finally:
|
||||
db.close()
|
||||
25
backend/app/downloaders/xiaoyuzhoufm_download.py
Normal file
25
backend/app/downloaders/xiaoyuzhoufm_download.py
Normal file
@@ -0,0 +1,25 @@
|
||||
from typing import Union, Optional
|
||||
|
||||
import requests
|
||||
|
||||
from app.downloaders.base import Downloader
|
||||
from app.enmus.note_enums import DownloadQuality
|
||||
from app.models.audio_model import AudioDownloadResult
|
||||
|
||||
url='https://www.xiaoyuzhoufm.com/_next/data/5Pvt_oGntgdyBD_XgwBaB/podcast/62382c1103bea1ebfffa1c00.json?id=62382c1103bea1ebfffa1c00'
|
||||
header ={
|
||||
'user-agent':'Mozilla/5.0 (Macintosh; Intel Mac OS X 10_15_7) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/137.0.0.0 Safari/537.36'
|
||||
}
|
||||
|
||||
response = requests.get(url, headers=header)
|
||||
print(response.json())
|
||||
|
||||
class Xiaoyuzhoufm_download(Downloader):
|
||||
def download(
|
||||
self,
|
||||
video_url: str,
|
||||
output_dir: Union[str, None] = None,
|
||||
quality: DownloadQuality = "fast",
|
||||
need_video:Optional[bool]=False
|
||||
) -> AudioDownloadResult:
|
||||
pass
|
||||
@@ -109,8 +109,8 @@ def run_note_task(task_id: str, video_url: str, platform: str, quality: Download
|
||||
@router.post('/delete_task')
|
||||
def delete_task(data: RecordRequest):
|
||||
try:
|
||||
|
||||
NoteGenerator().delete_note(video_id=data.video_id, platform=data.platform)
|
||||
# TODO: 待持久化完成
|
||||
# NoteGenerator().delete_note(video_id=data.video_id, platform=data.platform)
|
||||
return R.success(msg='删除成功')
|
||||
except Exception as e:
|
||||
return R.error(msg=e)
|
||||
|
||||
@@ -1,8 +1,9 @@
|
||||
from fastapi.encoders import jsonable_encoder
|
||||
from kombu import uuid
|
||||
|
||||
from app.db.models.providers import Provider
|
||||
from app.db.provider_dao import (
|
||||
insert_provider,
|
||||
init_provider_table,
|
||||
get_all_providers,
|
||||
get_provider_by_name,
|
||||
get_provider_by_id,
|
||||
@@ -16,32 +17,51 @@ from app.models.model_config import ModelConfig
|
||||
class ProviderService:
|
||||
|
||||
@staticmethod
|
||||
def serialize_provider(row: tuple) -> dict:
|
||||
def serialize_provider(row: Provider) -> dict:
|
||||
if not row:
|
||||
return None
|
||||
row = ProviderService.provider_to_dict(row)
|
||||
return {
|
||||
"id": row[0],
|
||||
"name": row[1],
|
||||
"logo": row[2],
|
||||
"type": row[3],
|
||||
"api_key": row[4],
|
||||
"base_url": row[5],
|
||||
"enabled": row[6],
|
||||
"created_at": row[7],
|
||||
"id": row.get("id"),
|
||||
"name": row.get("name"),
|
||||
"logo": row.get("logo"),
|
||||
"type":row.get("type"),
|
||||
"enabled": row.get("enabled"),
|
||||
"base_url": row.get("base_url"),
|
||||
"api_key": row.get("api_key"),
|
||||
"created_at": jsonable_encoder(row.get("created_at")),
|
||||
# "name": row[1],
|
||||
# "logo": row[2],
|
||||
# "type": row[3],
|
||||
# "api_key": row[4],
|
||||
# "base_url": row[5],
|
||||
# "enabled": row[6],
|
||||
# "created_at": row[7],
|
||||
}
|
||||
@staticmethod
|
||||
def serialize_provider_safe(row: tuple) -> dict:
|
||||
def serialize_provider_safe(row: Provider) -> dict:
|
||||
if not row:
|
||||
return None
|
||||
row = ProviderService.provider_to_dict(row)
|
||||
|
||||
return {
|
||||
"id": row[0],
|
||||
"name": row[1],
|
||||
"logo": row[2],
|
||||
"type": row[3],
|
||||
"api_key": ProviderService.mask_key(row[4]),
|
||||
"base_url": row[5],
|
||||
"enabled": row[6],
|
||||
"created_at": row[7],
|
||||
"id": row.get("id"),
|
||||
"name": row.get("name"),
|
||||
"logo": row.get("logo"),
|
||||
"type":row.get("type"),
|
||||
"enabled": row.get("enabled"),
|
||||
"base_url": row.get("base_url"),
|
||||
"api_key": ProviderService.mask_key(row.get("api_key")),
|
||||
"created_at": jsonable_encoder(row.get("created_at")),
|
||||
|
||||
# "id": row[0],
|
||||
# "name": row[1],
|
||||
# "logo": row[2],
|
||||
# "type": row[3],
|
||||
# "api_key": ProviderService.mask_key(row[4]),
|
||||
# "base_url": row[5],
|
||||
# "enabled": row[6],
|
||||
# "created_at": row[7],
|
||||
}
|
||||
@staticmethod
|
||||
def mask_key(key: str) -> str:
|
||||
@@ -56,15 +76,30 @@ class ProviderService:
|
||||
return insert_provider(id, name, api_key, base_url, logo, type_, enabled)
|
||||
except Exception as e:
|
||||
print('创建模式失败',e)
|
||||
|
||||
@staticmethod
|
||||
def provider_to_dict(p: Provider):
|
||||
return {
|
||||
"id": p.id,
|
||||
"name": p.name,
|
||||
"logo": p.logo,
|
||||
"type": p.type,
|
||||
"api_key": p.api_key,
|
||||
"base_url": p.base_url,
|
||||
"enabled": p.enabled,
|
||||
"created_at": p.created_at,
|
||||
}
|
||||
@staticmethod
|
||||
def get_all_providers():
|
||||
rows = get_all_providers()
|
||||
if rows is None:
|
||||
return []
|
||||
|
||||
return [ProviderService.serialize_provider(row) for row in rows] if rows else []
|
||||
@staticmethod
|
||||
def get_all_providers_safe():
|
||||
rows = get_all_providers()
|
||||
return [ProviderService.serialize_provider(row) for row in rows] if rows else []
|
||||
|
||||
return [ProviderService.serialize_provider(row) for row in rows] if (rows) else []
|
||||
@staticmethod
|
||||
def get_provider_by_name(name: str):
|
||||
row = get_provider_by_name(name)
|
||||
|
||||
@@ -6,15 +6,31 @@ from app.models.transcriber_model import TranscriptResult, TranscriptSegment
|
||||
from app.services.provider import ProviderService
|
||||
from app.transcriber.base import Transcriber
|
||||
from openai import OpenAI
|
||||
import ffmpeg
|
||||
import tempfile
|
||||
from dotenv import load_dotenv
|
||||
load_dotenv()
|
||||
MAX_SIZE_MB = 18
|
||||
MAX_SIZE_BYTES = MAX_SIZE_MB * 1024 * 1024
|
||||
def compress_audio(input_path: str, target_bitrate='64k') -> str:
|
||||
output_fd, output_path = tempfile.mkstemp(suffix=".mp3") # 临时输出文件
|
||||
os.close(output_fd) # 关闭文件描述符,ffmpeg 会用路径操作
|
||||
ffmpeg.input(input_path).output(output_path, audio_bitrate=target_bitrate).run(quiet=True, overwrite_output=True)
|
||||
return output_path
|
||||
|
||||
class GroqTranscriber(Transcriber, ABC):
|
||||
|
||||
|
||||
@timeit
|
||||
def transcript(self, file_path: str) -> TranscriptResult:
|
||||
file_size = os.path.getsize(file_path)
|
||||
if file_size > MAX_SIZE_BYTES:
|
||||
print(f"文件超过 {MAX_SIZE_MB}MB,开始压缩(当前 {round(file_size / (1024 * 1024), 2)}MB)...")
|
||||
file_path = compress_audio(file_path)
|
||||
print(f"压缩完成,临时路径:{file_path}")
|
||||
provider = ProviderService.get_provider_by_id('groq')
|
||||
|
||||
|
||||
if not provider:
|
||||
raise Exception("Groq 供应商未配置,请配置以后使用。")
|
||||
client = OpenAI(
|
||||
|
||||
285
backend/app/utils/export.py
Normal file
285
backend/app/utils/export.py
Normal file
@@ -0,0 +1,285 @@
|
||||
import os
|
||||
import re
|
||||
from urllib.parse import quote
|
||||
from markdown_pdf import MarkdownPdf, Section
|
||||
from dotenv import load_dotenv
|
||||
|
||||
load_dotenv()
|
||||
|
||||
# 项目根路径(无论你在哪里运行)
|
||||
BASE_DIR = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
|
||||
|
||||
# 从 .env 获取 DATA_DIR,相对于 BASE_DIR 解析
|
||||
DATA_DIR_NAME = os.getenv("DATA_DIR", "data")
|
||||
DATA_DIR = os.path.join(BASE_DIR, DATA_DIR_NAME)
|
||||
SAVE_PATH = os.path.join(DATA_DIR, "note_output")
|
||||
IMAGE_BASE_URL = os.getenv("IMAGE_BASE_URL")
|
||||
STATIC_BASE = os.path.join(BASE_DIR, IMAGE_BASE_URL)
|
||||
|
||||
|
||||
class ExportUtils:
|
||||
def __init__(self, **kwargs):
|
||||
# 确认SAVE_PATH存在
|
||||
print(f"保存路径: {SAVE_PATH}")
|
||||
print(f"静态文件路径: {STATIC_BASE}")
|
||||
if not os.path.exists(SAVE_PATH):
|
||||
os.makedirs(SAVE_PATH)
|
||||
|
||||
def _embed_image_as_base64(self, img_path: str) -> str:
|
||||
"""
|
||||
将图片转换为 base64 格式嵌入
|
||||
"""
|
||||
import base64
|
||||
import mimetypes
|
||||
|
||||
try:
|
||||
# 获取 MIME 类型
|
||||
mime_type, _ = mimetypes.guess_type(img_path)
|
||||
if not mime_type:
|
||||
# 根据扩展名推断
|
||||
ext = os.path.splitext(img_path)[1].lower()
|
||||
mime_map = {
|
||||
'.png': 'image/png',
|
||||
'.jpg': 'image/jpeg',
|
||||
'.jpeg': 'image/jpeg',
|
||||
'.gif': 'image/gif',
|
||||
'.bmp': 'image/bmp',
|
||||
'.webp': 'image/webp',
|
||||
'.svg': 'image/svg+xml'
|
||||
}
|
||||
mime_type = mime_map.get(ext, 'image/png')
|
||||
|
||||
# 读取图片文件并转换为 base64
|
||||
with open(img_path, 'rb') as f:
|
||||
img_data = f.read()
|
||||
|
||||
base64_data = base64.b64encode(img_data).decode('utf-8')
|
||||
return f"data:{mime_type};base64,{base64_data}"
|
||||
|
||||
except Exception as e:
|
||||
print(f"图片 base64 编码失败 {img_path}: {str(e)}")
|
||||
return None
|
||||
|
||||
def _get_normalized_path(self, path: str) -> str:
|
||||
"""
|
||||
获取规范化的绝对路径
|
||||
"""
|
||||
return os.path.normpath(os.path.abspath(path))
|
||||
|
||||
def _replace_static_paths_with_absolute(self, content: str) -> str:
|
||||
"""
|
||||
将 Markdown 中的图片路径替换为 base64 内嵌格式
|
||||
这样可以确保图片在 PDF 中正确显示
|
||||
"""
|
||||
|
||||
def repl(match):
|
||||
# 捕获 alt 文本和路径
|
||||
alt_text = match.group(1) if match.group(1) else ""
|
||||
img_path = match.group(2).strip()
|
||||
|
||||
print(f"处理图片路径: {img_path}")
|
||||
|
||||
# 处理 /static/ 开头的路径
|
||||
if img_path.startswith("/static/"):
|
||||
# 构建绝对路径
|
||||
relative_path = img_path.lstrip("/") # 移除开头的 /
|
||||
abs_path = os.path.join(BASE_DIR, relative_path)
|
||||
abs_path = self._get_normalized_path(abs_path)
|
||||
|
||||
# 检查文件是否存在并转换为 base64
|
||||
if os.path.exists(abs_path):
|
||||
base64_uri = self._embed_image_as_base64(abs_path)
|
||||
if base64_uri:
|
||||
print(f"图片转换为 base64 成功: {img_path}")
|
||||
return f""
|
||||
else:
|
||||
print(f"图片 base64 转换失败: {abs_path}")
|
||||
return f""
|
||||
else:
|
||||
print(f"警告:图片文件不存在 {abs_path}")
|
||||
return f""
|
||||
|
||||
# 处理相对路径(相对于 STATIC_BASE)
|
||||
elif not img_path.startswith(('http://', 'https://', 'data:')):
|
||||
# 尝试多个可能的路径
|
||||
possible_paths = [
|
||||
os.path.join(STATIC_BASE, img_path),
|
||||
os.path.abspath(img_path),
|
||||
os.path.join(BASE_DIR, img_path)
|
||||
]
|
||||
|
||||
for abs_path in possible_paths:
|
||||
abs_path = self._get_normalized_path(abs_path)
|
||||
if os.path.exists(abs_path):
|
||||
base64_uri = self._embed_image_as_base64(abs_path)
|
||||
if base64_uri:
|
||||
print(f"相对路径图片转换为 base64 成功: {img_path}")
|
||||
return f""
|
||||
break
|
||||
|
||||
print(f"警告:图片文件未找到 {img_path}")
|
||||
return f""
|
||||
|
||||
# HTTP/HTTPS 和 data: 路径保持不变
|
||||
elif img_path.startswith(('http://', 'https://', 'data:')):
|
||||
print(f"网络图片或 data URI 保持不变: {img_path[:50]}...")
|
||||
return match.group(0)
|
||||
|
||||
# 其他情况保持不变
|
||||
return match.group(0)
|
||||
|
||||
# 使用更精确的正则表达式匹配图片语法
|
||||
# 匹配  格式
|
||||
pattern = r'!\[([^\]]*)\]\(([^)]+)\)'
|
||||
result = re.sub(pattern, repl, content)
|
||||
|
||||
print("图片路径处理完成")
|
||||
return result
|
||||
|
||||
def _to_pdf(self, content: str, title: str):
|
||||
"""
|
||||
将 Markdown 内容转换为 PDF
|
||||
"""
|
||||
try:
|
||||
# 创建 PDF 对象,启用优化
|
||||
pdf = MarkdownPdf(
|
||||
optimize=True,
|
||||
# 添加一些可能有助于图片显示的配置
|
||||
# toc=False,
|
||||
# paper_size='A4',
|
||||
# margin=dict(top='1cm', bottom='1cm', left='1cm', right='1cm')
|
||||
)
|
||||
|
||||
# 添加内容段落
|
||||
pdf.add_section(Section(content))
|
||||
|
||||
# 保存 PDF
|
||||
save_path = os.path.join(SAVE_PATH, f"{title}.pdf")
|
||||
pdf.save(save_path)
|
||||
|
||||
print(f"PDF 导出成功: {save_path}")
|
||||
return save_path
|
||||
|
||||
except Exception as e:
|
||||
print(f"PDF 导出失败: {str(e)}")
|
||||
print("尝试使用基本配置...")
|
||||
try:
|
||||
# 尝试最基本的配置
|
||||
pdf = MarkdownPdf()
|
||||
pdf.add_section(Section(content))
|
||||
save_path = os.path.join(SAVE_PATH, f"{title}.pdf")
|
||||
pdf.save(save_path)
|
||||
print(f"基本配置 PDF 导出成功: {save_path}")
|
||||
return save_path
|
||||
except Exception as e2:
|
||||
print(f"基本配置也失败: {str(e2)}")
|
||||
raise e2
|
||||
|
||||
def export(self, output_format: str, title: str, content: str) -> str:
|
||||
"""
|
||||
导出内容为指定格式
|
||||
支持格式:pdf, html, word/docx, image/png
|
||||
"""
|
||||
content = content.strip()
|
||||
|
||||
# 处理图片路径
|
||||
print("开始处理图片路径...")
|
||||
content = self._replace_static_paths_with_absolute(content)
|
||||
|
||||
output_format = output_format.lower()
|
||||
|
||||
try:
|
||||
if output_format == "pdf":
|
||||
save_path = self._to_pdf(content, title)
|
||||
elif output_format == "html":
|
||||
save_path = self._to_html(content, title)
|
||||
elif output_format in ["word", "docx"]:
|
||||
save_path = self._to_word(content, title)
|
||||
elif output_format in ["image", "png"]:
|
||||
save_path = self._to_image(content, title)
|
||||
else:
|
||||
supported_formats = ["pdf", "html", "word/docx", "image/png"]
|
||||
raise ValueError(f"不支持的导出格式: {output_format}. 支持的格式: {', '.join(supported_formats)}")
|
||||
|
||||
print(f"导出完成: {save_path}")
|
||||
return save_path
|
||||
|
||||
except Exception as e:
|
||||
print(f"导出失败: {str(e)}")
|
||||
raise e
|
||||
|
||||
def get_supported_formats(self):
|
||||
"""
|
||||
返回支持的导出格式列表
|
||||
"""
|
||||
return {
|
||||
"pdf": "PDF 文档",
|
||||
"html": "HTML 网页",
|
||||
"word": "Word 文档 (.docx)",
|
||||
"docx": "Word 文档 (.docx)",
|
||||
"image": "PNG 图片",
|
||||
"png": "PNG 图片"
|
||||
}
|
||||
def debug_paths(self):
|
||||
"""
|
||||
调试方法:打印重要路径信息
|
||||
"""
|
||||
print("=== 路径调试信息 ===")
|
||||
print(f"BASE_DIR: {BASE_DIR}")
|
||||
print(f"DATA_DIR: {DATA_DIR}")
|
||||
print(f"SAVE_PATH: {SAVE_PATH}")
|
||||
print(f"STATIC_BASE: {STATIC_BASE}")
|
||||
print(f"IMAGE_BASE_URL: {IMAGE_BASE_URL}")
|
||||
print("==================")
|
||||
|
||||
if __name__ == '__main__':
|
||||
|
||||
ExportUtils().export("pdf",title='测试',content='''# 视频笔记:Facial Recognition Forces My Coworkers to Do Their Dishes
|
||||
|
||||
## 简介
|
||||
该视频展示了团队如何利用面部识别技术来监控和激励同事清洗餐具。通过结合硬件和软件,团队开发了一个“Dish Watcher”系统,旨在识别并提醒那些未清洁餐具的人。
|
||||
|
||||
## 背景
|
||||
- 团队面临的问题是同事们不愿意清洗餐具。
|
||||
- 为解决这一问题,团队决定在不告知的情况下使用技术来监控厨房区域。
|
||||
|
||||
## 实验设计
|
||||
1\. **设备安装**
|
||||
- 使用Raspberry Pi和隐藏摄像头来捕捉厨房水槽的活动。
|
||||
- 摄像头只在有人在水槽附近活动时录制,以节省存储空间。
|
||||
|
||||
2\. **软件开发**
|
||||
- 使用Cursor AI和Meta的项目来分析视频。
|
||||
- 系统能识别人员特征如发型、服装,并将结果发送到Discord服务器以提醒团队。
|
||||
|
||||
3\. **面部识别**
|
||||
- 通过视频流实时分析来判断是否有人留下了脏餐具。
|
||||
- 系统能识别并记录下未清洗餐具的人的详细特征。
|
||||
|
||||
* 展示了堆积的脏餐具,问题的严重性可见一斑。
|
||||
|
||||
## 实验过程
|
||||
- 系统成功捕获了少数“罪犯”,并通过Discord进行了通知。
|
||||
- 计划将摄像头隐藏在厨房的画作后,使其更加隐蔽。
|
||||
|
||||
* SAM 介绍了项目的背景。
|
||||
|
||||
## 结果
|
||||
- 实验初期,系统有效地识别了不清洗餐具的同事。
|
||||
- 由于摄像头的存在,同事们开始自觉清洗餐具,长时间未发现新的“罪犯”。
|
||||
|
||||
## 思考与改进
|
||||
- 团队意识到仅仅通过惩罚来改变行为可能效果有限,考虑奖励来激励清洗餐具。
|
||||
- 系统将改进为奖励机制,记录并表扬那些清洗餐具的人。
|
||||
|
||||
## 总结
|
||||
这次实验展示了技术在工作场所行为管理中的应用潜力。通过实验,团队不仅解决了餐具清洗的问题,还对如何更有效地激励员工有了更深的认识。
|
||||
|
||||
* 展示了系统对某位同事洗碗的实时面部识别。
|
||||
|
||||
## 结论
|
||||
- 应用技术可以有效改善工作环境中的小问题。
|
||||
- 积极的激励比惩罚更能驱动行为改变。
|
||||
|
||||
通过这次实验,团队不仅解决了餐具堆积的问题,还为未来更复杂的行为管理系统奠定了基础。 ''',)
|
||||
|
||||
@@ -19,6 +19,6 @@ class ResponseWrapper:
|
||||
def error(msg="error", code=500, data=None):
|
||||
return JSONResponse(content={
|
||||
"code": code,
|
||||
"msg": msg,
|
||||
"msg": str(msg),
|
||||
"data": data
|
||||
})
|
||||
39
backend/build.sh
Executable file
39
backend/build.sh
Executable file
@@ -0,0 +1,39 @@
|
||||
#!/usr/bin/env bash
|
||||
set -e
|
||||
# uncomment this for debugging
|
||||
# set -x
|
||||
|
||||
# 切到项目根(假设脚本放在 script/ 目录)
|
||||
cd "$(dirname "$0")/.."
|
||||
|
||||
echo "当前工作目录:$(pwd)"
|
||||
|
||||
# 清理旧的构建
|
||||
echo "清理旧的构建..."
|
||||
rm -rf backend/dist backend/build ./BillNote_frontend/src-tauri/bin/*
|
||||
echo "清理完成。"
|
||||
|
||||
TARGET_TRIPLE=$(rustc -Vv | grep host | cut -f2 -d' ')
|
||||
echo "Detected target triple: $TARGET_TRIPLE"
|
||||
|
||||
# PyInstaller onedir 模式,直接输出到 Tauri 的 bin 目录
|
||||
echo "开始 PyInstaller 打包..."
|
||||
pyinstaller \
|
||||
--name BiliNoteBackend \
|
||||
--paths backend \
|
||||
--distpath ./BillNote_frontend/src-tauri/bin \
|
||||
--workpath backend/build \
|
||||
--specpath backend \
|
||||
--hidden-import uvicorn \
|
||||
--hidden-import fastapi \
|
||||
--hidden-import starlette \
|
||||
--add-data "app/db/builtin_providers.json:."\
|
||||
--add-data "../.env.env.example:.env" \
|
||||
"$(pwd)/backend/main.py" # 确保这里没有额外的空格,并使用绝对路径
|
||||
mv \
|
||||
./BillNote_frontend/src-tauri/bin/BiliNoteBackend/BiliNoteBackend\
|
||||
./BillNote_frontend/src-tauri/bin/BiliNoteBackend/BiliNoteBackend-$TARGET_TRIPLE
|
||||
|
||||
echo "PyInstaller 打包完成:"
|
||||
ls -l ./BillNote_frontend/src-tauri/bin/BiliNoteBackend # 这里会列出 onedir 模式下的目录内容
|
||||
echo "请检查 src-tauri/bin/BiliNoteBackend 目录,以确认打包内容。"
|
||||
@@ -1,15 +1,19 @@
|
||||
import os
|
||||
from contextlib import asynccontextmanager
|
||||
|
||||
import uvicorn
|
||||
from fastapi import FastAPI
|
||||
from starlette.middleware.cors import CORSMiddleware
|
||||
from starlette.staticfiles import StaticFiles
|
||||
from dotenv import load_dotenv
|
||||
|
||||
from app.db.init_db import init_db
|
||||
from app.db.provider_dao import seed_default_providers
|
||||
from app.exceptions.exception_handlers import register_exception_handlers
|
||||
from app.db.model_dao import init_model_table
|
||||
from app.db.provider_dao import init_provider_table
|
||||
# from app.db.model_dao import init_model_table
|
||||
# from app.db.provider_dao import init_provider_table
|
||||
from app.utils.logger import get_logger
|
||||
from app import create_app
|
||||
from app.db.video_task_dao import init_video_task_table
|
||||
from app.transcriber.transcriber_provider import get_transcriber
|
||||
from events import register_handler
|
||||
from ffmpeg_helper import ensure_ffmpeg_or_raise
|
||||
@@ -32,21 +36,33 @@ if not os.path.exists(uploads_dir):
|
||||
if not os.path.exists(out_dir):
|
||||
os.makedirs(out_dir)
|
||||
|
||||
app = create_app()
|
||||
@asynccontextmanager
|
||||
async def lifespan(app: FastAPI):
|
||||
register_handler()
|
||||
ensure_ffmpeg_or_raise()
|
||||
init_db()
|
||||
get_transcriber(transcriber_type=os.getenv("TRANSCRIBER_TYPE", "fast-whisper"))
|
||||
seed_default_providers()
|
||||
yield
|
||||
|
||||
app = create_app(lifespan=lifespan)
|
||||
app.add_middleware(
|
||||
CORSMiddleware,
|
||||
allow_origins=["tauri://localhost"], # ✅ 加上 Tauri 的 origin
|
||||
allow_credentials=True,
|
||||
allow_methods=["*"],
|
||||
allow_headers=["*"],
|
||||
)
|
||||
register_exception_handlers(app)
|
||||
app.mount(static_path, StaticFiles(directory=static_dir), name="static")
|
||||
app.mount("/uploads", StaticFiles(directory=uploads_dir), name="uploads")
|
||||
@app.on_event("startup")
|
||||
async def startup_event():
|
||||
|
||||
|
||||
|
||||
register_handler()
|
||||
ensure_ffmpeg_or_raise()
|
||||
get_transcriber(transcriber_type=os.getenv("TRANSCRIBER_TYPE","fast-whisper"))
|
||||
init_video_task_table()
|
||||
init_provider_table()
|
||||
init_model_table()
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
||||
Binary file not shown.
Reference in New Issue
Block a user