From c441d8776f009493705a00ffb92841a32d109d97 Mon Sep 17 00:00:00 2001 From: shiyu Date: Sun, 18 Jan 2026 21:01:59 +0800 Subject: [PATCH] feat: enhance backup functionality with section selection and import mode options --- domain/backup/api.py | 16 +- domain/backup/service.py | 320 +++++++++++++----- domain/backup/types.py | 1 + web/src/api/backup.ts | 12 +- web/src/i18n/locales/en.json | 14 + web/src/i18n/locales/zh.json | 14 + .../pages/SystemSettingsPage/BackupPage.tsx | 63 +++- 7 files changed, 333 insertions(+), 107 deletions(-) diff --git a/domain/backup/api.py b/domain/backup/api.py index 0b0b31f..842e0c7 100644 --- a/domain/backup/api.py +++ b/domain/backup/api.py @@ -1,6 +1,6 @@ import datetime -from fastapi import APIRouter, Depends, File, Request, UploadFile +from fastapi import APIRouter, Depends, File, Form, Query, Request, UploadFile from fastapi.responses import JSONResponse from domain.audit import AuditAction, audit @@ -16,8 +16,10 @@ router = APIRouter( @router.get("/export", summary="导出全站数据") @audit(action=AuditAction.DOWNLOAD, description="导出备份") -async def export_backup(request: Request): - data = await BackupService.export_data() +async def export_backup( + request: Request, sections: list[str] | None = Query(default=None) +): + data = await BackupService.export_data(sections=sections) timestamp = datetime.datetime.now().strftime("%Y-%m-%d_%H-%M-%S") headers = {"Content-Disposition": f"attachment; filename=foxel_backup_{timestamp}.json"} return JSONResponse(content=data.model_dump(), headers=headers) @@ -25,6 +27,10 @@ async def export_backup(request: Request): @router.post("/import", summary="导入数据") @audit(action=AuditAction.UPLOAD, description="导入备份") -async def import_backup(request: Request, file: UploadFile = File(...)): - await BackupService.import_from_bytes(file.filename, await file.read()) +async def import_backup( + request: Request, + file: UploadFile = File(...), + mode: str = Form("replace"), +): + await BackupService.import_from_bytes(file.filename, await file.read(), mode=mode) return {"message": "数据导入成功。"} diff --git a/domain/backup/service.py b/domain/backup/service.py index a953c63..f09d632 100644 --- a/domain/backup/service.py +++ b/domain/backup/service.py @@ -20,18 +20,64 @@ from models.database import ( class BackupService: + ALL_SECTIONS = ( + "storage_adapters", + "user_accounts", + "automation_tasks", + "share_links", + "configurations", + "ai_providers", + "ai_models", + "ai_default_models", + "plugins", + ) + @classmethod - async def export_data(cls) -> BackupData: + async def export_data(cls, sections: list[str] | None = None) -> BackupData: + sections = cls._normalize_sections(sections) + section_set = set(sections) async with in_transaction(): - adapters = await StorageAdapter.all().values() - users = await UserAccount.all().values() - tasks = await AutomationTask.all().values() - shares = await ShareLink.all().values() - configs = await Configuration.all().values() - providers = await AIProvider.all().values() - models = await AIModel.all().values() - default_models = await AIDefaultModel.all().values() - plugins = await Plugin.all().values() + adapters = ( + await StorageAdapter.all().values() + if "storage_adapters" in section_set + else [] + ) + users = ( + await UserAccount.all().values() + if "user_accounts" in section_set + else [] + ) + tasks = ( + await AutomationTask.all().values() + if "automation_tasks" in section_set + else [] + ) + shares = ( + await ShareLink.all().values() + if "share_links" in section_set + else [] + ) + configs = ( + await Configuration.all().values() + if "configurations" in section_set + else [] + ) + providers = ( + await AIProvider.all().values() + if "ai_providers" in section_set + else [] + ) + models = ( + await AIModel.all().values() if "ai_models" in section_set else [] + ) + default_models = ( + await AIDefaultModel.all().values() + if "ai_default_models" in section_set + else [] + ) + plugins = ( + await Plugin.all().values() if "plugins" in section_set else [] + ) share_links = cls._serialize_datetime_fields( shares, ["created_at", "expires_at"] @@ -51,6 +97,7 @@ class BackupService: return BackupData( version=VERSION, + sections=sections, storage_adapters=list(adapters), user_accounts=list(users), automation_tasks=list(tasks), @@ -63,106 +110,195 @@ class BackupService: ) @classmethod - async def import_from_bytes(cls, filename: str, content: bytes) -> None: + async def import_from_bytes( + cls, filename: str, content: bytes, mode: str = "replace" + ) -> None: if not filename.endswith(".json"): raise HTTPException(status_code=400, detail="无效的文件类型, 请上传 .json 文件") try: raw_data = json.loads(content) except Exception: raise HTTPException(status_code=400, detail="无法解析JSON文件") - await cls.import_data(BackupData(**raw_data)) + await cls.import_data(BackupData(**raw_data), mode=mode) @classmethod - async def import_data(cls, payload: BackupData) -> None: + async def import_data(cls, payload: BackupData, mode: str = "replace") -> None: + sections = cls._normalize_sections(payload.sections) + if mode not in {"replace", "merge"}: + raise HTTPException(status_code=400, detail="无效的导入模式") + + share_links = ( + cls._parse_datetime_fields(payload.share_links, ["created_at", "expires_at"]) + if payload.share_links + else [] + ) + ai_providers = ( + cls._parse_datetime_fields(payload.ai_providers, ["created_at", "updated_at"]) + if payload.ai_providers + else [] + ) + ai_models = ( + cls._parse_datetime_fields(payload.ai_models, ["created_at", "updated_at"]) + if payload.ai_models + else [] + ) + ai_default_models = ( + cls._parse_datetime_fields( + payload.ai_default_models, ["created_at", "updated_at"] + ) + if payload.ai_default_models + else [] + ) + plugins = ( + cls._parse_datetime_fields(payload.plugins, ["created_at", "updated_at"]) + if payload.plugins + else [] + ) + async with in_transaction() as conn: - await ShareLink.all().using_db(conn).delete() - await AutomationTask.all().using_db(conn).delete() - await StorageAdapter.all().using_db(conn).delete() - await UserAccount.all().using_db(conn).delete() - await Configuration.all().using_db(conn).delete() - await AIDefaultModel.all().using_db(conn).delete() - await AIModel.all().using_db(conn).delete() - await AIProvider.all().using_db(conn).delete() - await Plugin.all().using_db(conn).delete() + if mode == "replace": + if "share_links" in sections: + await ShareLink.all().using_db(conn).delete() + if "automation_tasks" in sections: + await AutomationTask.all().using_db(conn).delete() + if "storage_adapters" in sections: + await StorageAdapter.all().using_db(conn).delete() + if "user_accounts" in sections: + await UserAccount.all().using_db(conn).delete() + if "configurations" in sections: + await Configuration.all().using_db(conn).delete() + if "ai_default_models" in sections: + await AIDefaultModel.all().using_db(conn).delete() + if "ai_models" in sections: + await AIModel.all().using_db(conn).delete() + if "ai_providers" in sections: + await AIProvider.all().using_db(conn).delete() + if "plugins" in sections: + await Plugin.all().using_db(conn).delete() - if payload.configurations: - await Configuration.bulk_create( - [Configuration(**config) for config in payload.configurations], - using_db=conn, - ) + if "configurations" in sections and payload.configurations: + if mode == "merge": + await cls._merge_records( + Configuration, payload.configurations, conn + ) + else: + await Configuration.bulk_create( + [Configuration(**config) for config in payload.configurations], + using_db=conn, + ) - if payload.user_accounts: - await UserAccount.bulk_create( - [UserAccount(**user) for user in payload.user_accounts], - using_db=conn, - ) + if "user_accounts" in sections and payload.user_accounts: + if mode == "merge": + await cls._merge_records(UserAccount, payload.user_accounts, conn) + else: + await UserAccount.bulk_create( + [UserAccount(**user) for user in payload.user_accounts], + using_db=conn, + ) - if payload.storage_adapters: - await StorageAdapter.bulk_create( - [StorageAdapter(**adapter) for adapter in payload.storage_adapters], - using_db=conn, - ) + if "storage_adapters" in sections and payload.storage_adapters: + if mode == "merge": + await cls._merge_records( + StorageAdapter, payload.storage_adapters, conn + ) + else: + await StorageAdapter.bulk_create( + [StorageAdapter(**adapter) for adapter in payload.storage_adapters], + using_db=conn, + ) - if payload.automation_tasks: - await AutomationTask.bulk_create( - [AutomationTask(**task) for task in payload.automation_tasks], - using_db=conn, - ) + if "automation_tasks" in sections and payload.automation_tasks: + if mode == "merge": + await cls._merge_records( + AutomationTask, payload.automation_tasks, conn + ) + else: + await AutomationTask.bulk_create( + [AutomationTask(**task) for task in payload.automation_tasks], + using_db=conn, + ) - if payload.share_links: - await ShareLink.bulk_create( - [ - ShareLink(**share) - for share in cls._parse_datetime_fields( - payload.share_links, ["created_at", "expires_at"] - ) - ], - using_db=conn, - ) + if "share_links" in sections and share_links: + if mode == "merge": + await cls._merge_records(ShareLink, share_links, conn) + else: + await ShareLink.bulk_create( + [ShareLink(**share) for share in share_links], + using_db=conn, + ) - if payload.ai_providers: - await AIProvider.bulk_create( - [ - AIProvider(**item) - for item in cls._parse_datetime_fields( - payload.ai_providers, ["created_at", "updated_at"] - ) - ], - using_db=conn, - ) + if "ai_providers" in sections and ai_providers: + if mode == "merge": + await cls._merge_records(AIProvider, ai_providers, conn) + else: + await AIProvider.bulk_create( + [AIProvider(**item) for item in ai_providers], + using_db=conn, + ) - if payload.ai_models: - await AIModel.bulk_create( - [ - AIModel(**item) - for item in cls._parse_datetime_fields( - payload.ai_models, ["created_at", "updated_at"] - ) - ], - using_db=conn, - ) + if "ai_models" in sections and ai_models: + if mode == "merge": + await cls._merge_records(AIModel, ai_models, conn) + else: + await AIModel.bulk_create( + [AIModel(**item) for item in ai_models], + using_db=conn, + ) - if payload.ai_default_models: - await AIDefaultModel.bulk_create( - [ - AIDefaultModel(**item) - for item in cls._parse_datetime_fields( - payload.ai_default_models, ["created_at", "updated_at"] - ) - ], - using_db=conn, - ) + if "ai_default_models" in sections and ai_default_models: + if mode == "merge": + await cls._merge_records( + AIDefaultModel, ai_default_models, conn + ) + else: + await AIDefaultModel.bulk_create( + [AIDefaultModel(**item) for item in ai_default_models], + using_db=conn, + ) - if payload.plugins: - await Plugin.bulk_create( - [ - Plugin(**item) - for item in cls._parse_datetime_fields( - payload.plugins, ["created_at", "updated_at"] - ) - ], - using_db=conn, - ) + if "plugins" in sections and plugins: + if mode == "merge": + await cls._merge_records(Plugin, plugins, conn) + else: + await Plugin.bulk_create( + [Plugin(**item) for item in plugins], + using_db=conn, + ) + + @classmethod + def _normalize_sections(cls, sections: list[str] | None) -> list[str]: + if not sections: + return list(cls.ALL_SECTIONS) + normalized = [item for item in sections if item] + invalid = [item for item in normalized if item not in cls.ALL_SECTIONS] + if invalid: + raise HTTPException( + status_code=400, detail=f"无效的备份分区: {', '.join(invalid)}" + ) + result: list[str] = [] + seen = set() + for item in normalized: + if item in seen: + continue + seen.add(item) + result.append(item) + return result + + @staticmethod + async def _merge_records(model, records: list[dict], using_db) -> None: + for record in records: + data = dict(record) + record_id = data.pop("id", None) + if record_id is None: + await model.create(using_db=using_db, **data) + continue + updated = ( + await model.filter(id=record_id) + .using_db(using_db) + .update(**data) + ) + if updated == 0: + await model.create(using_db=using_db, id=record_id, **data) @staticmethod def _serialize_datetime_fields( diff --git a/domain/backup/types.py b/domain/backup/types.py index 272cfde..a7c1aa2 100644 --- a/domain/backup/types.py +++ b/domain/backup/types.py @@ -5,6 +5,7 @@ from pydantic import BaseModel, Field class BackupData(BaseModel): version: str | None = None + sections: list[str] = Field(default_factory=list) storage_adapters: list[dict[str, Any]] = Field(default_factory=list) user_accounts: list[dict[str, Any]] = Field(default_factory=list) automation_tasks: list[dict[str, Any]] = Field(default_factory=list) diff --git a/web/src/api/backup.ts b/web/src/api/backup.ts index 7f7669b..517e13e 100644 --- a/web/src/api/backup.ts +++ b/web/src/api/backup.ts @@ -1,8 +1,11 @@ import request from './client'; export const backupApi = { - export: async () => { - const response = await request('/backup/export', { + export: async (sections?: string[]) => { + const params = new URLSearchParams(); + (sections || []).forEach((section) => params.append('sections', section)); + const query = params.toString(); + const response = await request(`/backup/export${query ? `?${query}` : ''}`, { method: 'GET', rawResponse: true, }) as Response; @@ -27,12 +30,13 @@ export const backupApi = { window.URL.revokeObjectURL(url); }, - import: async (file: File) => { + import: async (file: File, mode: 'replace' | 'merge' = 'replace') => { const formData = new FormData(); formData.append('file', file); + formData.append('mode', mode); return request('/backup/import', { method: 'POST', body: formData, }); }, -}; \ No newline at end of file +}; diff --git a/web/src/i18n/locales/en.json b/web/src/i18n/locales/en.json index ff87d64..bbfbbf5 100644 --- a/web/src/i18n/locales/en.json +++ b/web/src/i18n/locales/en.json @@ -556,10 +556,24 @@ "Export": "Export", "Import": "Import", "Export all data (adapters, users, tasks, shares) into a JSON file.": "Export all data (adapters, users, tasks, shares) into a JSON file.", + "Export selected data into a JSON file.": "Export selected data into a JSON file.", "Keep your backup file safe.": "Keep your backup file safe.", + "Select backup sections": "Select backup sections", + "User Accounts": "User Accounts", + "Share Links": "Share Links", + "Configurations": "Configurations", + "AI Providers": "AI Providers", + "AI Models": "AI Models", + "AI Default Models": "AI Default Models", + "Plugin Data": "Plugins", "Export Backup": "Export Backup", "Restore data from a previously exported JSON file.": "Restore data from a previously exported JSON file.", "Warning: This will clear and overwrite existing data.": "Warning: This will clear and overwrite existing data.", + "Import mode": "Import mode", + "Merge (upsert by ID)": "Merge (upsert by ID)", + "Replace (clear before import)": "Replace (clear before import)", + "Warning: This will clear data in the backup sections before importing.": "Warning: This will clear data in the backup sections before importing.", + "Warning: This will merge data in the backup sections and overwrite existing records with the same ID.": "Warning: This will merge data in the backup sections and overwrite existing records with the same ID.", "Choose File and Restore": "Choose File and Restore", "No files yet here": "No files yet here", "This folder is empty": "This folder is empty", diff --git a/web/src/i18n/locales/zh.json b/web/src/i18n/locales/zh.json index b08f8a0..0b80cd7 100644 --- a/web/src/i18n/locales/zh.json +++ b/web/src/i18n/locales/zh.json @@ -547,10 +547,24 @@ "Export": "导出", "Import": "恢复", "Export all data (adapters, users, tasks, shares) into a JSON file.": "点击按钮将所有数据(包括存储、用户、自动化任务和分享)导出为一个 JSON 文件。", + "Export selected data into a JSON file.": "导出选中的数据为一个 JSON 文件。", "Keep your backup file safe.": "请妥善保管您的备份文件。", + "Select backup sections": "选择备份内容", + "User Accounts": "账号", + "Share Links": "分享列表", + "Configurations": "配置", + "AI Providers": "AI 服务商", + "AI Models": "AI 模型", + "AI Default Models": "AI 默认模型", + "Plugin Data": "插件", "Export Backup": "导出备份", "Restore data from a previously exported JSON file.": "从之前导出的JSON文件恢复数据。", "Warning: This will clear and overwrite existing data.": "警告:此操作将清除并覆盖现有数据。", + "Import mode": "导入方式", + "Merge (upsert by ID)": "增量+覆盖(按 ID)", + "Replace (clear before import)": "清空后导入", + "Warning: This will clear data in the backup sections before importing.": "警告:此操作会先清空备份中包含的分区数据,再导入。", + "Warning: This will merge data in the backup sections and overwrite existing records with the same ID.": "警告:此操作会合并备份中包含的分区数据,并按 ID 覆盖已存在记录。", "Choose File and Restore": "选择文件并恢复", "No files yet here": "这里还没有任何文件", "This folder is empty": "此目录为空", diff --git a/web/src/pages/SystemSettingsPage/BackupPage.tsx b/web/src/pages/SystemSettingsPage/BackupPage.tsx index 93eacbe..e52d30d 100644 --- a/web/src/pages/SystemSettingsPage/BackupPage.tsx +++ b/web/src/pages/SystemSettingsPage/BackupPage.tsx @@ -1,5 +1,5 @@ import { memo, useState } from 'react'; -import { Button, Typography, Upload, message, Modal, Card } from 'antd'; +import { Button, Typography, Upload, message, Modal, Card, Checkbox, Space, Radio } from 'antd'; import PageCard from '../../components/PageCard'; import { UploadOutlined, DownloadOutlined } from '@ant-design/icons'; import { backupApi } from '../../api/backup'; @@ -7,14 +7,40 @@ import { useI18n } from '../../i18n'; const { Paragraph, Text } = Typography; +const BACKUP_SECTIONS = [ + { key: 'user_accounts', labelKey: 'User Accounts' }, + { key: 'storage_adapters', labelKey: 'Storage Adapters' }, + { key: 'automation_tasks', labelKey: 'Automation Tasks' }, + { key: 'share_links', labelKey: 'Share Links' }, + { key: 'configurations', labelKey: 'Configurations' }, + { key: 'ai_providers', labelKey: 'AI Providers' }, + { key: 'ai_models', labelKey: 'AI Models' }, + { key: 'ai_default_models', labelKey: 'AI Default Models' }, + { key: 'plugins', labelKey: 'Plugin Data' }, +] as const; + +type BackupSection = typeof BACKUP_SECTIONS[number]['key']; +const ALL_SECTION_KEYS = BACKUP_SECTIONS.map((section) => section.key) as BackupSection[]; + const BackupPage = memo(function BackupPage() { const [loading, setLoading] = useState(false); + const [selectedSections, setSelectedSections] = useState(ALL_SECTION_KEYS); + const [importMode, setImportMode] = useState<'replace' | 'merge'>('replace'); const { t } = useI18n(); + const importWarning = importMode === 'replace' + ? t('Warning: This will clear data in the backup sections before importing.') + : t('Warning: This will merge data in the backup sections and overwrite existing records with the same ID.'); + const importWarningType = importMode === 'replace' ? 'danger' : 'warning'; + const exportOptions = BACKUP_SECTIONS.map((section) => ({ + label: t(section.labelKey), + value: section.key, + })); + const canExport = selectedSections.length > 0; const handleExport = async () => { setLoading(true); try { - await backupApi.export(); + await backupApi.export(selectedSections); message.success(t('Export started, check your downloads.')); } catch (e: any) { message.error(e.message || t('Export failed')); @@ -29,7 +55,9 @@ const BackupPage = memo(function BackupPage() { content: ( {t('Are you sure to import from this file?')} - {t('Warning: This will overwrite all data including users (with passwords), settings, storages and tasks. Irreversible!')} + + {importWarning} + ), okText: t('Confirm Import'), @@ -38,7 +66,7 @@ const BackupPage = memo(function BackupPage() { onOk: async () => { setLoading(true); try { - const response = await backupApi.import(file); + const response = await backupApi.import(file, importMode); message.success(response.message || t('Import succeeded! The page will refresh.')); setTimeout(() => window.location.reload(), 2000); } catch (e: any) { @@ -57,13 +85,22 @@ const BackupPage = memo(function BackupPage() {
- {t('Export all data (adapters, users, tasks, shares) into a JSON file.')} + {t('Export selected data into a JSON file.')} {t('Keep your backup file safe.')} + + {t('Select backup sections')} + setSelectedSections(values as BackupSection[])} + /> + @@ -71,8 +108,22 @@ const BackupPage = memo(function BackupPage() { {t('Restore data from a previously exported JSON file.')} - {t('Warning: This will clear and overwrite existing data.')} + + {t('Import mode')} + setImportMode(event.target.value)} + > + {t('Merge (upsert by ID)')} + {t('Replace (clear before import)')} + + + {importWarning} + +