mirror of
https://github.com/Awuqing/BackupX.git
synced 2026-06-05 09:49:37 +08:00
first commit
This commit is contained in:
188
server/internal/app/app.go
Normal file
188
server/internal/app/app.go
Normal file
@@ -0,0 +1,188 @@
|
||||
package app
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
stdhttp "net/http"
|
||||
"time"
|
||||
|
||||
"backupx/server/internal/backup"
|
||||
backupretention "backupx/server/internal/backup/retention"
|
||||
"backupx/server/internal/config"
|
||||
"backupx/server/internal/database"
|
||||
aphttp "backupx/server/internal/http"
|
||||
"backupx/server/internal/logger"
|
||||
"backupx/server/internal/notify"
|
||||
"backupx/server/internal/repository"
|
||||
"backupx/server/internal/scheduler"
|
||||
"backupx/server/internal/security"
|
||||
"backupx/server/internal/service"
|
||||
"backupx/server/internal/storage"
|
||||
"backupx/server/internal/storage/codec"
|
||||
"backupx/server/internal/storage/googledrive"
|
||||
"backupx/server/internal/storage/localdisk"
|
||||
storageAliyun "backupx/server/internal/storage/aliyun"
|
||||
storageTencent "backupx/server/internal/storage/tencent"
|
||||
storageQiniu "backupx/server/internal/storage/qiniu"
|
||||
storageS3 "backupx/server/internal/storage/s3"
|
||||
storageWebDAV "backupx/server/internal/storage/webdav"
|
||||
"go.uber.org/zap"
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
|
||||
type Application struct {
|
||||
cfg config.Config
|
||||
version string
|
||||
logger *zap.Logger
|
||||
db *gorm.DB
|
||||
httpServer *stdhttp.Server
|
||||
scheduler *scheduler.Service
|
||||
}
|
||||
|
||||
func New(ctx context.Context, cfg config.Config, version string) (*Application, error) {
|
||||
appLogger, err := logger.New(cfg.Log)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("init logger: %w", err)
|
||||
}
|
||||
|
||||
db, err := database.Open(cfg.Database, appLogger)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("init database: %w", err)
|
||||
}
|
||||
|
||||
userRepo := repository.NewUserRepository(db)
|
||||
systemConfigRepo := repository.NewSystemConfigRepository(db)
|
||||
storageTargetRepo := repository.NewStorageTargetRepository(db)
|
||||
backupTaskRepo := repository.NewBackupTaskRepository(db)
|
||||
backupRecordRepo := repository.NewBackupRecordRepository(db)
|
||||
notificationRepo := repository.NewNotificationRepository(db)
|
||||
oauthSessionRepo := repository.NewOAuthSessionRepository(db)
|
||||
resolvedSecurity, err := service.ResolveSecurity(ctx, cfg.Security, systemConfigRepo)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("resolve security config: %w", err)
|
||||
}
|
||||
|
||||
jwtManager := security.NewJWTManager(resolvedSecurity.JWTSecret, config.MustJWTDuration(cfg.Security))
|
||||
rateLimiter := security.NewLoginRateLimiter(5, time.Minute)
|
||||
authService := service.NewAuthService(userRepo, systemConfigRepo, jwtManager, rateLimiter)
|
||||
systemService := service.NewSystemService(cfg, version, time.Now().UTC())
|
||||
configCipher := codec.NewConfigCipher(resolvedSecurity.EncryptionKey)
|
||||
storageRegistry := storage.NewRegistry(
|
||||
localdisk.NewFactory(),
|
||||
storageS3.NewFactory(),
|
||||
storageWebDAV.NewFactory(),
|
||||
googledrive.NewFactory(),
|
||||
storageAliyun.NewFactory(),
|
||||
storageTencent.NewFactory(),
|
||||
storageQiniu.NewFactory(),
|
||||
)
|
||||
storageTargetService := service.NewStorageTargetService(storageTargetRepo, oauthSessionRepo, storageRegistry, configCipher)
|
||||
storageTargetService.SetBackupTaskRepository(backupTaskRepo)
|
||||
storageTargetService.SetBackupRecordRepository(backupRecordRepo)
|
||||
backupTaskService := service.NewBackupTaskService(backupTaskRepo, storageTargetRepo, configCipher)
|
||||
backupRunnerRegistry := backup.NewRegistry(backup.NewFileRunner(), backup.NewSQLiteRunner(), backup.NewMySQLRunner(nil), backup.NewPostgreSQLRunner(nil))
|
||||
logHub := backup.NewLogHub()
|
||||
retentionService := backupretention.NewService(backupRecordRepo)
|
||||
notifyRegistry := notify.NewRegistry(notify.NewEmailNotifier(), notify.NewWebhookNotifier(), notify.NewTelegramNotifier())
|
||||
notificationService := service.NewNotificationService(notificationRepo, notifyRegistry, configCipher)
|
||||
backupExecutionService := service.NewBackupExecutionService(backupTaskRepo, backupRecordRepo, storageTargetRepo, storageRegistry, backupRunnerRegistry, logHub, retentionService, configCipher, notificationService, cfg.Backup.TempDir, cfg.Backup.MaxConcurrent)
|
||||
schedulerService := scheduler.NewService(backupTaskRepo, backupExecutionService, appLogger)
|
||||
backupTaskService.SetScheduler(schedulerService)
|
||||
backupRecordService := service.NewBackupRecordService(backupRecordRepo, backupExecutionService, logHub)
|
||||
dashboardService := service.NewDashboardService(backupTaskRepo, backupRecordRepo, storageTargetRepo)
|
||||
settingsService := service.NewSettingsService(systemConfigRepo)
|
||||
|
||||
// Cluster: Node management
|
||||
nodeRepo := repository.NewNodeRepository(db)
|
||||
nodeService := service.NewNodeService(nodeRepo)
|
||||
if err := nodeService.EnsureLocalNode(ctx); err != nil {
|
||||
appLogger.Warn("failed to ensure local node", zap.Error(err))
|
||||
}
|
||||
|
||||
router := aphttp.NewRouter(aphttp.RouterDependencies{
|
||||
Config: cfg,
|
||||
Version: version,
|
||||
Logger: appLogger,
|
||||
AuthService: authService,
|
||||
SystemService: systemService,
|
||||
StorageTargetService: storageTargetService,
|
||||
BackupTaskService: backupTaskService,
|
||||
BackupExecutionService: backupExecutionService,
|
||||
BackupRecordService: backupRecordService,
|
||||
NotificationService: notificationService,
|
||||
DashboardService: dashboardService,
|
||||
SettingsService: settingsService,
|
||||
NodeService: nodeService,
|
||||
JWTManager: jwtManager,
|
||||
UserRepository: userRepo,
|
||||
SystemConfigRepo: systemConfigRepo,
|
||||
})
|
||||
|
||||
httpServer := &stdhttp.Server{
|
||||
Addr: cfg.Address(),
|
||||
Handler: router,
|
||||
ReadHeaderTimeout: 10 * time.Second,
|
||||
}
|
||||
|
||||
return &Application{
|
||||
cfg: cfg,
|
||||
version: version,
|
||||
logger: appLogger,
|
||||
db: db,
|
||||
httpServer: httpServer,
|
||||
scheduler: schedulerService,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (a *Application) Run(ctx context.Context) error {
|
||||
if a.scheduler != nil {
|
||||
if err := a.scheduler.Start(context.Background()); err != nil {
|
||||
return fmt.Errorf("start scheduler: %w", err)
|
||||
}
|
||||
}
|
||||
errCh := make(chan error, 1)
|
||||
go func() {
|
||||
a.logger.Info("http server listening", zap.String("addr", a.cfg.Address()), zap.String("version", a.version))
|
||||
if err := a.httpServer.ListenAndServe(); err != nil && !errors.Is(err, stdhttp.ErrServerClosed) {
|
||||
errCh <- err
|
||||
return
|
||||
}
|
||||
errCh <- nil
|
||||
}()
|
||||
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
shutdownCtx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
|
||||
defer cancel()
|
||||
a.logger.Info("shutdown signal received")
|
||||
if err := a.httpServer.Shutdown(shutdownCtx); err != nil {
|
||||
return fmt.Errorf("shutdown http server: %w", err)
|
||||
}
|
||||
if a.scheduler != nil {
|
||||
if err := a.scheduler.Stop(context.Background()); err != nil {
|
||||
return fmt.Errorf("stop scheduler: %w", err)
|
||||
}
|
||||
}
|
||||
return nil
|
||||
case err := <-errCh:
|
||||
if err != nil {
|
||||
return fmt.Errorf("serve http: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
func (a *Application) Close() {
|
||||
if a.logger != nil {
|
||||
_ = a.logger.Sync()
|
||||
}
|
||||
}
|
||||
|
||||
func (a *Application) Logger() *zap.Logger {
|
||||
return a.logger
|
||||
}
|
||||
|
||||
func ErrorField(err error) zap.Field {
|
||||
return zap.Error(err)
|
||||
}
|
||||
55
server/internal/apperror/error.go
Normal file
55
server/internal/apperror/error.go
Normal file
@@ -0,0 +1,55 @@
|
||||
package apperror
|
||||
|
||||
import "net/http"
|
||||
|
||||
type AppError struct {
|
||||
Status int
|
||||
Code string
|
||||
Message string
|
||||
Err error
|
||||
}
|
||||
|
||||
func (e *AppError) Error() string {
|
||||
if e == nil {
|
||||
return ""
|
||||
}
|
||||
if e.Err != nil {
|
||||
return e.Err.Error()
|
||||
}
|
||||
return e.Message
|
||||
}
|
||||
|
||||
func (e *AppError) Unwrap() error {
|
||||
if e == nil {
|
||||
return nil
|
||||
}
|
||||
return e.Err
|
||||
}
|
||||
|
||||
func New(status int, code, message string, err error) *AppError {
|
||||
return &AppError{Status: status, Code: code, Message: message, Err: err}
|
||||
}
|
||||
|
||||
func BadRequest(code, message string, err error) *AppError {
|
||||
return New(http.StatusBadRequest, code, message, err)
|
||||
}
|
||||
|
||||
func Unauthorized(code, message string, err error) *AppError {
|
||||
return New(http.StatusUnauthorized, code, message, err)
|
||||
}
|
||||
|
||||
func Forbidden(code, message string, err error) *AppError {
|
||||
return New(http.StatusForbidden, code, message, err)
|
||||
}
|
||||
|
||||
func Conflict(code, message string, err error) *AppError {
|
||||
return New(http.StatusConflict, code, message, err)
|
||||
}
|
||||
|
||||
func TooManyRequests(code, message string, err error) *AppError {
|
||||
return New(http.StatusTooManyRequests, code, message, err)
|
||||
}
|
||||
|
||||
func Internal(code, message string, err error) *AppError {
|
||||
return New(http.StatusInternalServerError, code, message, err)
|
||||
}
|
||||
189
server/internal/backup/archive.go
Normal file
189
server/internal/backup/archive.go
Normal file
@@ -0,0 +1,189 @@
|
||||
package backup
|
||||
|
||||
import (
|
||||
"archive/tar"
|
||||
"compress/gzip"
|
||||
"context"
|
||||
"fmt"
|
||||
"io"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
)
|
||||
|
||||
func CreateTarGz(ctx context.Context, sourcePath string, excludePatterns []string, destinationPath string, logger LogWriter) (int64, error) {
|
||||
sourcePath = filepath.Clean(strings.TrimSpace(sourcePath))
|
||||
if sourcePath == "" {
|
||||
return 0, fmt.Errorf("source path is required")
|
||||
}
|
||||
if err := os.MkdirAll(filepath.Dir(destinationPath), 0o755); err != nil {
|
||||
return 0, fmt.Errorf("create destination directory: %w", err)
|
||||
}
|
||||
file, err := os.Create(destinationPath)
|
||||
if err != nil {
|
||||
return 0, fmt.Errorf("create archive file: %w", err)
|
||||
}
|
||||
defer file.Close()
|
||||
gzipWriter, err := gzip.NewWriterLevel(file, gzip.DefaultCompression)
|
||||
if err != nil {
|
||||
return 0, fmt.Errorf("create gzip writer: %w", err)
|
||||
}
|
||||
defer gzipWriter.Close()
|
||||
tarWriter := tar.NewWriter(gzipWriter)
|
||||
defer tarWriter.Close()
|
||||
|
||||
baseParent := filepath.Dir(sourcePath)
|
||||
walkErr := filepath.Walk(sourcePath, func(path string, info os.FileInfo, walkErr error) error {
|
||||
if walkErr != nil {
|
||||
return walkErr
|
||||
}
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return ctx.Err()
|
||||
default:
|
||||
}
|
||||
rel, err := filepath.Rel(baseParent, path)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
rel = filepath.ToSlash(rel)
|
||||
if shouldExcludeArchive(rel, excludePatterns) {
|
||||
if info.IsDir() {
|
||||
return filepath.SkipDir
|
||||
}
|
||||
if logger != nil {
|
||||
logger.WriteLine(fmt.Sprintf("跳过排除路径:%s", rel))
|
||||
}
|
||||
return nil
|
||||
}
|
||||
header, err := tar.FileInfoHeader(info, "")
|
||||
if err != nil {
|
||||
return fmt.Errorf("build tar header: %w", err)
|
||||
}
|
||||
header.Name = rel
|
||||
if info.IsDir() && !strings.HasSuffix(header.Name, "/") {
|
||||
header.Name += "/"
|
||||
}
|
||||
if err := tarWriter.WriteHeader(header); err != nil {
|
||||
return fmt.Errorf("write tar header: %w", err)
|
||||
}
|
||||
if info.IsDir() {
|
||||
return nil
|
||||
}
|
||||
input, err := os.Open(path)
|
||||
if err != nil {
|
||||
return fmt.Errorf("open source file: %w", err)
|
||||
}
|
||||
defer input.Close()
|
||||
if _, err := io.Copy(tarWriter, input); err != nil {
|
||||
return fmt.Errorf("write tar body: %w", err)
|
||||
}
|
||||
return nil
|
||||
})
|
||||
if walkErr != nil {
|
||||
return 0, walkErr
|
||||
}
|
||||
if err := tarWriter.Close(); err != nil {
|
||||
return 0, fmt.Errorf("close tar writer: %w", err)
|
||||
}
|
||||
if err := gzipWriter.Close(); err != nil {
|
||||
return 0, fmt.Errorf("close gzip writer: %w", err)
|
||||
}
|
||||
if err := file.Close(); err != nil {
|
||||
return 0, fmt.Errorf("close archive file: %w", err)
|
||||
}
|
||||
info, err := os.Stat(destinationPath)
|
||||
if err != nil {
|
||||
return 0, fmt.Errorf("stat archive file: %w", err)
|
||||
}
|
||||
return info.Size(), nil
|
||||
}
|
||||
|
||||
func ExtractTarGz(ctx context.Context, archivePath string, destinationDir string, logger LogWriter) error {
|
||||
archivePath = filepath.Clean(archivePath)
|
||||
destinationDir = filepath.Clean(destinationDir)
|
||||
file, err := os.Open(archivePath)
|
||||
if err != nil {
|
||||
return fmt.Errorf("open archive file: %w", err)
|
||||
}
|
||||
defer file.Close()
|
||||
gzipReader, err := gzip.NewReader(file)
|
||||
if err != nil {
|
||||
return fmt.Errorf("open gzip reader: %w", err)
|
||||
}
|
||||
defer gzipReader.Close()
|
||||
tarReader := tar.NewReader(gzipReader)
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return ctx.Err()
|
||||
default:
|
||||
}
|
||||
header, err := tarReader.Next()
|
||||
if err == io.EOF {
|
||||
return nil
|
||||
}
|
||||
if err != nil {
|
||||
return fmt.Errorf("read tar entry: %w", err)
|
||||
}
|
||||
targetPath, err := secureJoin(destinationDir, header.Name)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
switch header.Typeflag {
|
||||
case tar.TypeDir:
|
||||
if err := os.MkdirAll(targetPath, 0o755); err != nil {
|
||||
return fmt.Errorf("create restore directory: %w", err)
|
||||
}
|
||||
case tar.TypeReg:
|
||||
if err := os.MkdirAll(filepath.Dir(targetPath), 0o755); err != nil {
|
||||
return fmt.Errorf("create restore parent directory: %w", err)
|
||||
}
|
||||
output, err := os.Create(targetPath)
|
||||
if err != nil {
|
||||
return fmt.Errorf("create restore file: %w", err)
|
||||
}
|
||||
if _, err := io.Copy(output, tarReader); err != nil {
|
||||
output.Close()
|
||||
return fmt.Errorf("write restore file: %w", err)
|
||||
}
|
||||
if err := output.Close(); err != nil {
|
||||
return fmt.Errorf("close restore file: %w", err)
|
||||
}
|
||||
if logger != nil {
|
||||
logger.WriteLine(fmt.Sprintf("已恢复文件:%s", targetPath))
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func shouldExcludeArchive(rel string, patterns []string) bool {
|
||||
rel = filepath.ToSlash(strings.TrimSpace(rel))
|
||||
base := filepath.Base(rel)
|
||||
for _, pattern := range patterns {
|
||||
trimmed := strings.TrimSpace(pattern)
|
||||
if trimmed == "" {
|
||||
continue
|
||||
}
|
||||
if matched, _ := filepath.Match(trimmed, rel); matched {
|
||||
return true
|
||||
}
|
||||
if matched, _ := filepath.Match(trimmed, base); matched {
|
||||
return true
|
||||
}
|
||||
if strings.Contains(rel, trimmed) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func secureJoin(root string, relative string) (string, error) {
|
||||
root = filepath.Clean(root)
|
||||
target := filepath.Clean(filepath.Join(root, filepath.FromSlash(relative)))
|
||||
rootWithSep := root + string(filepath.Separator)
|
||||
if target != root && !strings.HasPrefix(target, rootWithSep) {
|
||||
return "", fmt.Errorf("archive entry escapes destination: %s", relative)
|
||||
}
|
||||
return target, nil
|
||||
}
|
||||
41
server/internal/backup/command.go
Normal file
41
server/internal/backup/command.go
Normal file
@@ -0,0 +1,41 @@
|
||||
package backup
|
||||
|
||||
import (
|
||||
"context"
|
||||
"io"
|
||||
"os"
|
||||
"os/exec"
|
||||
)
|
||||
|
||||
type CommandOptions struct {
|
||||
Stdin io.Reader
|
||||
Stdout io.Writer
|
||||
Stderr io.Writer
|
||||
Env []string
|
||||
}
|
||||
|
||||
type CommandExecutor interface {
|
||||
LookPath(file string) (string, error)
|
||||
Run(ctx context.Context, name string, args []string, options CommandOptions) error
|
||||
}
|
||||
|
||||
type OSCommandExecutor struct{}
|
||||
|
||||
func NewOSCommandExecutor() *OSCommandExecutor {
|
||||
return &OSCommandExecutor{}
|
||||
}
|
||||
|
||||
func (e *OSCommandExecutor) LookPath(file string) (string, error) {
|
||||
return exec.LookPath(file)
|
||||
}
|
||||
|
||||
func (e *OSCommandExecutor) Run(ctx context.Context, name string, args []string, options CommandOptions) error {
|
||||
cmd := exec.CommandContext(ctx, name, args...)
|
||||
cmd.Stdin = options.Stdin
|
||||
cmd.Stdout = options.Stdout
|
||||
cmd.Stderr = options.Stderr
|
||||
if len(options.Env) > 0 {
|
||||
cmd.Env = append(os.Environ(), options.Env...)
|
||||
}
|
||||
return cmd.Run()
|
||||
}
|
||||
37
server/internal/backup/command_executor.go
Normal file
37
server/internal/backup/command_executor.go
Normal file
@@ -0,0 +1,37 @@
|
||||
//go:build ignore
|
||||
|
||||
package backup
|
||||
|
||||
import (
|
||||
"context"
|
||||
"io"
|
||||
"os"
|
||||
"os/exec"
|
||||
)
|
||||
|
||||
type CommandExecutor interface {
|
||||
LookPath(file string) (string, error)
|
||||
Run(ctx context.Context, name string, args []string, env map[string]string, stdin io.Reader, stdout io.Writer, stderr io.Writer) error
|
||||
}
|
||||
|
||||
type OSCommandExecutor struct{}
|
||||
|
||||
func NewOSCommandExecutor() *OSCommandExecutor {
|
||||
return &OSCommandExecutor{}
|
||||
}
|
||||
|
||||
func (e *OSCommandExecutor) LookPath(file string) (string, error) {
|
||||
return exec.LookPath(file)
|
||||
}
|
||||
|
||||
func (e *OSCommandExecutor) Run(ctx context.Context, name string, args []string, env map[string]string, stdin io.Reader, stdout io.Writer, stderr io.Writer) error {
|
||||
command := exec.CommandContext(ctx, name, args...)
|
||||
command.Stdin = stdin
|
||||
command.Stdout = stdout
|
||||
command.Stderr = stderr
|
||||
command.Env = os.Environ()
|
||||
for key, value := range env {
|
||||
command.Env = append(command.Env, key+"="+value)
|
||||
}
|
||||
return command.Run()
|
||||
}
|
||||
16
server/internal/backup/database_names.go
Normal file
16
server/internal/backup/database_names.go
Normal file
@@ -0,0 +1,16 @@
|
||||
package backup
|
||||
|
||||
import "strings"
|
||||
|
||||
func normalizeDatabaseNames(items []string) []string {
|
||||
result := make([]string, 0, len(items))
|
||||
for _, item := range items {
|
||||
for _, part := range strings.Split(item, ",") {
|
||||
trimmed := strings.TrimSpace(part)
|
||||
if trimmed != "" {
|
||||
result = append(result, trimmed)
|
||||
}
|
||||
}
|
||||
}
|
||||
return result
|
||||
}
|
||||
106
server/internal/backup/database_runners_test.go
Normal file
106
server/internal/backup/database_runners_test.go
Normal file
@@ -0,0 +1,106 @@
|
||||
package backup
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"errors"
|
||||
"io"
|
||||
"os"
|
||||
"testing"
|
||||
)
|
||||
|
||||
type fakeCommandExecutor struct {
|
||||
lastName string
|
||||
lastArgs []string
|
||||
env []string
|
||||
lookupErr error
|
||||
runFunc func(name string, args []string, options CommandOptions) error
|
||||
}
|
||||
|
||||
func (f *fakeCommandExecutor) LookPath(string) (string, error) {
|
||||
if f.lookupErr != nil {
|
||||
return "", f.lookupErr
|
||||
}
|
||||
return "/usr/bin/fake", nil
|
||||
}
|
||||
|
||||
func (f *fakeCommandExecutor) Run(_ context.Context, name string, args []string, options CommandOptions) error {
|
||||
f.lastName = name
|
||||
f.lastArgs = append([]string{}, args...)
|
||||
f.env = append([]string{}, options.Env...)
|
||||
if f.runFunc != nil {
|
||||
return f.runFunc(name, args, options)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func TestMySQLRunnerUsesExpectedCommands(t *testing.T) {
|
||||
executor := &fakeCommandExecutor{runFunc: func(name string, args []string, options CommandOptions) error {
|
||||
if options.Stdout != nil {
|
||||
_, _ = io.WriteString(options.Stdout, "mysql dump")
|
||||
}
|
||||
return nil
|
||||
}}
|
||||
runner := NewMySQLRunner(executor)
|
||||
result, err := runner.Run(context.Background(), TaskSpec{Name: "mysql", Type: "mysql", Database: DatabaseSpec{Host: "127.0.0.1", Port: 3306, User: "root", Password: "secret", Names: []string{"app, audit"}}}, NopLogWriter{})
|
||||
if err != nil {
|
||||
t.Fatalf("Run returned error: %v", err)
|
||||
}
|
||||
if executor.lastName != "mysqldump" {
|
||||
t.Fatalf("expected mysqldump, got %s", executor.lastName)
|
||||
}
|
||||
if len(executor.lastArgs) == 0 || executor.lastArgs[len(executor.lastArgs)-2] != "app" || executor.lastArgs[len(executor.lastArgs)-1] != "audit" {
|
||||
t.Fatalf("unexpected mysql args: %#v", executor.lastArgs)
|
||||
}
|
||||
if _, err := os.Stat(result.ArtifactPath); err != nil {
|
||||
t.Fatalf("artifact file missing: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestPostgreSQLRunnerRestoreUsesPsql(t *testing.T) {
|
||||
executor := &fakeCommandExecutor{}
|
||||
runner := NewPostgreSQLRunner(executor)
|
||||
artifact := filepathJoinTempFile(t, "restore.sql", "select 1;")
|
||||
if err := runner.Restore(context.Background(), TaskSpec{Name: "postgres", Type: "postgresql", Database: DatabaseSpec{Host: "127.0.0.1", Port: 5432, User: "postgres", Password: "secret"}}, artifact, NopLogWriter{}); err != nil {
|
||||
t.Fatalf("Restore returned error: %v", err)
|
||||
}
|
||||
if executor.lastName != "psql" {
|
||||
t.Fatalf("expected psql, got %s", executor.lastName)
|
||||
}
|
||||
}
|
||||
|
||||
func TestMySQLRunnerReturnsLookupError(t *testing.T) {
|
||||
runner := NewMySQLRunner(&fakeCommandExecutor{lookupErr: errors.New("missing")})
|
||||
_, err := runner.Run(context.Background(), TaskSpec{Name: "mysql", Type: "mysql", Database: DatabaseSpec{Host: "127.0.0.1", Port: 3306, User: "root", Password: "secret", Names: []string{"app"}}}, NopLogWriter{})
|
||||
if err == nil {
|
||||
t.Fatal("expected error when mysqldump is missing")
|
||||
}
|
||||
}
|
||||
|
||||
func filepathJoinTempFile(t *testing.T, name string, content string) string {
|
||||
t.Helper()
|
||||
filePath := t.TempDir() + "/" + name
|
||||
if err := os.WriteFile(filePath, []byte(content), 0o644); err != nil {
|
||||
t.Fatalf("WriteFile returned error: %v", err)
|
||||
}
|
||||
return filePath
|
||||
}
|
||||
|
||||
func TestPostgreSQLRunnerRunAppendsMultipleDatabaseDumps(t *testing.T) {
|
||||
executor := &fakeCommandExecutor{runFunc: func(name string, args []string, options CommandOptions) error {
|
||||
_, _ = io.Copy(options.Stdout, bytes.NewBufferString(args[len(args)-1]))
|
||||
return nil
|
||||
}}
|
||||
runner := NewPostgreSQLRunner(executor)
|
||||
result, err := runner.Run(context.Background(), TaskSpec{Name: "pg", Type: "postgresql", Database: DatabaseSpec{Host: "127.0.0.1", Port: 5432, User: "postgres", Password: "secret", Names: []string{"app", "audit"}}}, NopLogWriter{})
|
||||
if err != nil {
|
||||
t.Fatalf("Run returned error: %v", err)
|
||||
}
|
||||
content, err := os.ReadFile(result.ArtifactPath)
|
||||
if err != nil {
|
||||
t.Fatalf("ReadFile returned error: %v", err)
|
||||
}
|
||||
if !bytes.Contains(content, []byte("app")) || !bytes.Contains(content, []byte("audit")) {
|
||||
t.Fatalf("unexpected pg dump content: %s", string(content))
|
||||
}
|
||||
}
|
||||
191
server/internal/backup/file_runner.go
Normal file
191
server/internal/backup/file_runner.go
Normal file
@@ -0,0 +1,191 @@
|
||||
package backup
|
||||
|
||||
import (
|
||||
"archive/tar"
|
||||
"context"
|
||||
"fmt"
|
||||
"io"
|
||||
"os"
|
||||
"path"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
)
|
||||
|
||||
type FileRunner struct{}
|
||||
|
||||
func NewFileRunner() *FileRunner {
|
||||
return &FileRunner{}
|
||||
}
|
||||
|
||||
func (r *FileRunner) Type() string {
|
||||
return "file"
|
||||
}
|
||||
|
||||
func (r *FileRunner) Run(_ context.Context, task TaskSpec, writer LogWriter) (*RunResult, error) {
|
||||
sourcePath := filepath.Clean(strings.TrimSpace(task.SourcePath))
|
||||
if sourcePath == "" {
|
||||
return nil, fmt.Errorf("source path is required")
|
||||
}
|
||||
info, err := os.Stat(sourcePath)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("stat source path: %w", err)
|
||||
}
|
||||
tempDir, artifactPath, err := createTempArtifact(task.TempDir, task.Name, "tar")
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
artifactFile, err := os.Create(artifactPath)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("create tar artifact: %w", err)
|
||||
}
|
||||
defer artifactFile.Close()
|
||||
tw := tar.NewWriter(artifactFile)
|
||||
defer tw.Close()
|
||||
baseParent := filepath.Dir(sourcePath)
|
||||
excludes := normalizeExcludePatterns(task.ExcludePatterns)
|
||||
writer.WriteLine(fmt.Sprintf("开始打包文件备份:%s", sourcePath))
|
||||
fileCount := 0
|
||||
dirCount := 0
|
||||
walkErr := filepath.Walk(sourcePath, func(currentPath string, currentInfo os.FileInfo, walkErr error) error {
|
||||
if walkErr != nil {
|
||||
writer.WriteLine(fmt.Sprintf("⚠ 无法访问 %s: %v", currentPath, walkErr))
|
||||
return nil
|
||||
}
|
||||
relPath, err := filepath.Rel(baseParent, currentPath)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
archiveName := filepath.ToSlash(relPath)
|
||||
if shouldExcludeEntry(archiveName, currentInfo.IsDir(), excludes) {
|
||||
if currentInfo.IsDir() {
|
||||
writer.WriteLine(fmt.Sprintf("跳过排除目录 %s", archiveName))
|
||||
return filepath.SkipDir
|
||||
}
|
||||
return nil
|
||||
}
|
||||
if currentPath == sourcePath && currentInfo.IsDir() {
|
||||
return nil
|
||||
}
|
||||
|
||||
if currentInfo.IsDir() {
|
||||
dirCount++
|
||||
writer.WriteLine(fmt.Sprintf("📁 进入目录 %s", archiveName))
|
||||
}
|
||||
|
||||
header, err := tar.FileInfoHeader(currentInfo, "")
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
header.Name = archiveName
|
||||
if err := tw.WriteHeader(header); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if currentInfo.Mode().IsRegular() {
|
||||
file, err := os.Open(currentPath)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer file.Close()
|
||||
if _, err := io.CopyN(tw, file, currentInfo.Size()); err != nil && err != io.EOF {
|
||||
return err
|
||||
}
|
||||
fileCount++
|
||||
if fileCount%100 == 0 {
|
||||
writer.WriteLine(fmt.Sprintf("已打包 %d 个文件...", fileCount))
|
||||
}
|
||||
}
|
||||
return nil
|
||||
})
|
||||
if walkErr != nil {
|
||||
return nil, fmt.Errorf("walk source path: %w", walkErr)
|
||||
}
|
||||
if info.IsDir() {
|
||||
writer.WriteLine(fmt.Sprintf("目录打包完成(%d 个目录,%d 个文件)", dirCount, fileCount))
|
||||
} else {
|
||||
writer.WriteLine("文件打包完成")
|
||||
}
|
||||
return &RunResult{ArtifactPath: artifactPath, FileName: filepath.Base(artifactPath), TempDir: tempDir}, nil
|
||||
}
|
||||
|
||||
func (r *FileRunner) Restore(_ context.Context, task TaskSpec, artifactPath string, writer LogWriter) error {
|
||||
artifactFile, err := os.Open(artifactPath)
|
||||
if err != nil {
|
||||
return fmt.Errorf("open tar artifact: %w", err)
|
||||
}
|
||||
defer artifactFile.Close()
|
||||
targetParent := filepath.Dir(filepath.Clean(strings.TrimSpace(task.SourcePath)))
|
||||
if err := os.MkdirAll(targetParent, 0o755); err != nil {
|
||||
return fmt.Errorf("create restore parent: %w", err)
|
||||
}
|
||||
tr := tar.NewReader(artifactFile)
|
||||
for {
|
||||
header, err := tr.Next()
|
||||
if err == io.EOF {
|
||||
break
|
||||
}
|
||||
if err != nil {
|
||||
return fmt.Errorf("read tar entry: %w", err)
|
||||
}
|
||||
cleanName := path.Clean(strings.TrimSpace(header.Name))
|
||||
if cleanName == "." || cleanName == "" {
|
||||
continue
|
||||
}
|
||||
targetPath := filepath.Clean(filepath.Join(targetParent, filepath.FromSlash(cleanName)))
|
||||
parentWithSep := filepath.Clean(targetParent) + string(filepath.Separator)
|
||||
if targetPath != filepath.Clean(targetParent) && !strings.HasPrefix(targetPath, parentWithSep) {
|
||||
return fmt.Errorf("tar entry escapes restore path")
|
||||
}
|
||||
switch header.Typeflag {
|
||||
case tar.TypeDir:
|
||||
if err := os.MkdirAll(targetPath, os.FileMode(header.Mode)); err != nil {
|
||||
return fmt.Errorf("create restore dir: %w", err)
|
||||
}
|
||||
case tar.TypeReg, tar.TypeRegA:
|
||||
if err := os.MkdirAll(filepath.Dir(targetPath), 0o755); err != nil {
|
||||
return fmt.Errorf("create restore parent dir: %w", err)
|
||||
}
|
||||
file, err := os.OpenFile(targetPath, os.O_CREATE|os.O_TRUNC|os.O_WRONLY, os.FileMode(header.Mode))
|
||||
if err != nil {
|
||||
return fmt.Errorf("create restore file: %w", err)
|
||||
}
|
||||
if _, err := io.Copy(file, tr); err != nil {
|
||||
file.Close()
|
||||
return fmt.Errorf("write restore file: %w", err)
|
||||
}
|
||||
if err := file.Close(); err != nil {
|
||||
return fmt.Errorf("close restore file: %w", err)
|
||||
}
|
||||
}
|
||||
}
|
||||
writer.WriteLine("文件恢复完成")
|
||||
return nil
|
||||
}
|
||||
|
||||
func normalizeExcludePatterns(items []string) []string {
|
||||
result := make([]string, 0, len(items))
|
||||
for _, item := range items {
|
||||
trimmed := strings.TrimSpace(item)
|
||||
if trimmed != "" {
|
||||
result = append(result, filepath.ToSlash(trimmed))
|
||||
}
|
||||
}
|
||||
return result
|
||||
}
|
||||
|
||||
func shouldExcludeEntry(relPath string, isDir bool, patterns []string) bool {
|
||||
relPath = filepath.ToSlash(relPath)
|
||||
base := path.Base(relPath)
|
||||
for _, pattern := range patterns {
|
||||
if matched, _ := path.Match(pattern, relPath); matched {
|
||||
return true
|
||||
}
|
||||
if matched, _ := path.Match(pattern, base); matched {
|
||||
return true
|
||||
}
|
||||
if isDir && strings.TrimSuffix(pattern, "/") == base {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
69
server/internal/backup/file_runner_test.go
Normal file
69
server/internal/backup/file_runner_test.go
Normal file
@@ -0,0 +1,69 @@
|
||||
package backup
|
||||
|
||||
import (
|
||||
"archive/tar"
|
||||
"context"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"testing"
|
||||
)
|
||||
|
||||
type bufferWriter struct{ lines []string }
|
||||
|
||||
func (w *bufferWriter) WriteLine(message string) { w.lines = append(w.lines, message) }
|
||||
|
||||
func TestFileRunnerRunAndRestore(t *testing.T) {
|
||||
tempDir := t.TempDir()
|
||||
sourceDir := filepath.Join(tempDir, "site")
|
||||
if err := os.MkdirAll(filepath.Join(sourceDir, "node_modules"), 0o755); err != nil {
|
||||
t.Fatalf("MkdirAll returned error: %v", err)
|
||||
}
|
||||
if err := os.WriteFile(filepath.Join(sourceDir, "index.html"), []byte("ok"), 0o644); err != nil {
|
||||
t.Fatalf("WriteFile returned error: %v", err)
|
||||
}
|
||||
if err := os.WriteFile(filepath.Join(sourceDir, "app.log"), []byte("skip"), 0o644); err != nil {
|
||||
t.Fatalf("WriteFile returned error: %v", err)
|
||||
}
|
||||
if err := os.WriteFile(filepath.Join(sourceDir, "node_modules", "pkg.json"), []byte("skip-dir"), 0o644); err != nil {
|
||||
t.Fatalf("WriteFile returned error: %v", err)
|
||||
}
|
||||
runner := NewFileRunner()
|
||||
writer := &bufferWriter{}
|
||||
result, err := runner.Run(context.Background(), TaskSpec{Name: "site files", Type: "file", SourcePath: sourceDir, ExcludePatterns: []string{"*.log", "node_modules"}}, writer)
|
||||
if err != nil {
|
||||
t.Fatalf("Run returned error: %v", err)
|
||||
}
|
||||
archiveFile, err := os.Open(result.ArtifactPath)
|
||||
if err != nil {
|
||||
t.Fatalf("Open returned error: %v", err)
|
||||
}
|
||||
defer archiveFile.Close()
|
||||
reader := tar.NewReader(archiveFile)
|
||||
entries := map[string]bool{}
|
||||
for {
|
||||
header, err := reader.Next()
|
||||
if err != nil {
|
||||
break
|
||||
}
|
||||
entries[header.Name] = true
|
||||
}
|
||||
if !entries["site/index.html"] {
|
||||
t.Fatalf("expected site/index.html in archive, got %#v", entries)
|
||||
}
|
||||
if entries["site/app.log"] || entries["site/node_modules/pkg.json"] {
|
||||
t.Fatalf("unexpected excluded entries: %#v", entries)
|
||||
}
|
||||
if err := os.RemoveAll(sourceDir); err != nil {
|
||||
t.Fatalf("RemoveAll returned error: %v", err)
|
||||
}
|
||||
if err := runner.Restore(context.Background(), TaskSpec{Name: "site files", Type: "file", SourcePath: sourceDir}, result.ArtifactPath, writer); err != nil {
|
||||
t.Fatalf("Restore returned error: %v", err)
|
||||
}
|
||||
content, err := os.ReadFile(filepath.Join(sourceDir, "index.html"))
|
||||
if err != nil {
|
||||
t.Fatalf("ReadFile returned error: %v", err)
|
||||
}
|
||||
if string(content) != "ok" {
|
||||
t.Fatalf("unexpected restored content: %s", string(content))
|
||||
}
|
||||
}
|
||||
41
server/internal/backup/helpers.go
Normal file
41
server/internal/backup/helpers.go
Normal file
@@ -0,0 +1,41 @@
|
||||
package backup
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
"time"
|
||||
)
|
||||
|
||||
func createTempArtifact(baseDir, taskName string, extension string) (string, string, error) {
|
||||
tempDir, err := os.MkdirTemp(baseDir, "backupx-run-*")
|
||||
if err != nil {
|
||||
return "", "", fmt.Errorf("create temp dir: %w", err)
|
||||
}
|
||||
base := sanitizeFileName(taskName)
|
||||
if base == "" {
|
||||
base = "backup"
|
||||
}
|
||||
fileName := fmt.Sprintf("%s_%s.%s", base, time.Now().UTC().Format("20060102T150405"), strings.TrimPrefix(extension, "."))
|
||||
return tempDir, filepath.Join(tempDir, fileName), nil
|
||||
}
|
||||
|
||||
func sanitizeFileName(value string) string {
|
||||
builder := strings.Builder{}
|
||||
for _, char := range strings.TrimSpace(value) {
|
||||
switch {
|
||||
case char >= 'a' && char <= 'z':
|
||||
builder.WriteRune(char)
|
||||
case char >= 'A' && char <= 'Z':
|
||||
builder.WriteRune(char + ('a' - 'A'))
|
||||
case char >= '0' && char <= '9':
|
||||
builder.WriteRune(char)
|
||||
case char == '-' || char == '_':
|
||||
builder.WriteRune(char)
|
||||
case char == ' ' || char == '.':
|
||||
builder.WriteRune('_')
|
||||
}
|
||||
}
|
||||
return strings.Trim(builder.String(), "_")
|
||||
}
|
||||
110
server/internal/backup/log_hub.go
Normal file
110
server/internal/backup/log_hub.go
Normal file
@@ -0,0 +1,110 @@
|
||||
package backup
|
||||
|
||||
import (
|
||||
"sync"
|
||||
"time"
|
||||
)
|
||||
|
||||
type LogHub struct {
|
||||
mu sync.RWMutex
|
||||
streams map[uint]*logStreamState
|
||||
}
|
||||
|
||||
type logStreamState struct {
|
||||
nextSequence int64
|
||||
events []LogEvent
|
||||
subscribers map[int]chan LogEvent
|
||||
nextSubID int
|
||||
completed bool
|
||||
status string
|
||||
}
|
||||
|
||||
func NewLogHub() *LogHub {
|
||||
return &LogHub{streams: make(map[uint]*logStreamState)}
|
||||
}
|
||||
|
||||
func (h *LogHub) Append(recordID uint, level, message string) LogEvent {
|
||||
h.mu.Lock()
|
||||
defer h.mu.Unlock()
|
||||
state := h.ensureState(recordID)
|
||||
state.nextSequence++
|
||||
event := LogEvent{RecordID: recordID, Sequence: state.nextSequence, Level: level, Message: message, Timestamp: time.Now().UTC(), Status: state.status}
|
||||
state.events = append(state.events, event)
|
||||
for _, subscriber := range state.subscribers {
|
||||
select {
|
||||
case subscriber <- event:
|
||||
default:
|
||||
}
|
||||
}
|
||||
return event
|
||||
}
|
||||
|
||||
func (h *LogHub) Snapshot(recordID uint) []LogEvent {
|
||||
h.mu.RLock()
|
||||
defer h.mu.RUnlock()
|
||||
state, ok := h.streams[recordID]
|
||||
if !ok {
|
||||
return nil
|
||||
}
|
||||
result := make([]LogEvent, len(state.events))
|
||||
copy(result, state.events)
|
||||
return result
|
||||
}
|
||||
|
||||
func (h *LogHub) Subscribe(recordID uint, buffer int) (<-chan LogEvent, func()) {
|
||||
if buffer <= 0 {
|
||||
buffer = 32
|
||||
}
|
||||
h.mu.Lock()
|
||||
defer h.mu.Unlock()
|
||||
state := h.ensureState(recordID)
|
||||
state.nextSubID++
|
||||
id := state.nextSubID
|
||||
channel := make(chan LogEvent, buffer)
|
||||
state.subscribers[id] = channel
|
||||
for _, event := range state.events {
|
||||
channel <- event
|
||||
}
|
||||
cancel := func() {
|
||||
h.mu.Lock()
|
||||
defer h.mu.Unlock()
|
||||
stream, ok := h.streams[recordID]
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
subscriber, ok := stream.subscribers[id]
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
delete(stream.subscribers, id)
|
||||
close(subscriber)
|
||||
}
|
||||
return channel, cancel
|
||||
}
|
||||
|
||||
func (h *LogHub) Complete(recordID uint, status string) {
|
||||
h.mu.Lock()
|
||||
defer h.mu.Unlock()
|
||||
state := h.ensureState(recordID)
|
||||
state.completed = true
|
||||
state.status = status
|
||||
state.nextSequence++
|
||||
event := LogEvent{RecordID: recordID, Sequence: state.nextSequence, Level: "info", Message: "stream completed", Timestamp: time.Now().UTC(), Completed: true, Status: status}
|
||||
state.events = append(state.events, event)
|
||||
for _, subscriber := range state.subscribers {
|
||||
select {
|
||||
case subscriber <- event:
|
||||
default:
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (h *LogHub) ensureState(recordID uint) *logStreamState {
|
||||
state, ok := h.streams[recordID]
|
||||
if ok {
|
||||
return state
|
||||
}
|
||||
state = &logStreamState{subscribers: make(map[int]chan LogEvent), status: "running"}
|
||||
h.streams[recordID] = state
|
||||
return state
|
||||
}
|
||||
26
server/internal/backup/log_hub_test.go
Normal file
26
server/internal/backup/log_hub_test.go
Normal file
@@ -0,0 +1,26 @@
|
||||
package backup
|
||||
|
||||
import "testing"
|
||||
|
||||
func TestLogHubAppendSubscribeAndComplete(t *testing.T) {
|
||||
hub := NewLogHub()
|
||||
channel, cancel := hub.Subscribe(1, 4)
|
||||
defer cancel()
|
||||
first := hub.Append(1, "info", "started")
|
||||
if first.Sequence != 1 || first.Message != "started" {
|
||||
t.Fatalf("unexpected first event: %#v", first)
|
||||
}
|
||||
snapshot := hub.Snapshot(1)
|
||||
if len(snapshot) != 1 {
|
||||
t.Fatalf("expected snapshot size 1, got %d", len(snapshot))
|
||||
}
|
||||
event := <-channel
|
||||
if event.Message != "started" {
|
||||
t.Fatalf("unexpected streamed event: %#v", event)
|
||||
}
|
||||
hub.Complete(1, "success")
|
||||
completeEvent := <-channel
|
||||
if !completeEvent.Completed || completeEvent.Status != "success" {
|
||||
t.Fatalf("unexpected completion event: %#v", completeEvent)
|
||||
}
|
||||
}
|
||||
56
server/internal/backup/logger.go
Normal file
56
server/internal/backup/logger.go
Normal file
@@ -0,0 +1,56 @@
|
||||
package backup
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"strings"
|
||||
"sync"
|
||||
)
|
||||
|
||||
type ExecutionLogger struct {
|
||||
recordID uint
|
||||
hub *LogHub
|
||||
mu sync.Mutex
|
||||
buffer strings.Builder
|
||||
}
|
||||
|
||||
func NewExecutionLogger(recordID uint, hub *LogHub) *ExecutionLogger {
|
||||
return &ExecutionLogger{recordID: recordID, hub: hub}
|
||||
}
|
||||
|
||||
func (l *ExecutionLogger) Write(level, message string) {
|
||||
trimmed := strings.TrimSpace(message)
|
||||
if trimmed == "" {
|
||||
return
|
||||
}
|
||||
l.mu.Lock()
|
||||
defer l.mu.Unlock()
|
||||
if l.buffer.Len() > 0 {
|
||||
l.buffer.WriteByte('\n')
|
||||
}
|
||||
l.buffer.WriteString(trimmed)
|
||||
if l.hub != nil {
|
||||
l.hub.Append(l.recordID, level, trimmed)
|
||||
}
|
||||
}
|
||||
|
||||
func (l *ExecutionLogger) Infof(format string, args ...any) {
|
||||
l.Write("info", fmt.Sprintf(format, args...))
|
||||
}
|
||||
|
||||
func (l *ExecutionLogger) Errorf(format string, args ...any) {
|
||||
l.Write("error", fmt.Sprintf(format, args...))
|
||||
}
|
||||
|
||||
func (l *ExecutionLogger) Warnf(format string, args ...any) {
|
||||
l.Write("warn", fmt.Sprintf(format, args...))
|
||||
}
|
||||
|
||||
func (l *ExecutionLogger) WriteLine(message string) {
|
||||
l.Infof("%s", message)
|
||||
}
|
||||
|
||||
func (l *ExecutionLogger) String() string {
|
||||
l.mu.Lock()
|
||||
defer l.mu.Unlock()
|
||||
return l.buffer.String()
|
||||
}
|
||||
163
server/internal/backup/mysql_runner.go
Normal file
163
server/internal/backup/mysql_runner.go
Normal file
@@ -0,0 +1,163 @@
|
||||
package backup
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"bytes"
|
||||
"context"
|
||||
"fmt"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
)
|
||||
|
||||
type MySQLRunner struct {
|
||||
executor CommandExecutor
|
||||
}
|
||||
|
||||
func NewMySQLRunner(executor CommandExecutor) *MySQLRunner {
|
||||
if executor == nil {
|
||||
executor = NewOSCommandExecutor()
|
||||
}
|
||||
return &MySQLRunner{executor: executor}
|
||||
}
|
||||
|
||||
func (r *MySQLRunner) Type() string {
|
||||
return "mysql"
|
||||
}
|
||||
|
||||
func (r *MySQLRunner) Run(ctx context.Context, task TaskSpec, writer LogWriter) (*RunResult, error) {
|
||||
if _, err := r.executor.LookPath("mysqldump"); err != nil {
|
||||
return nil, fmt.Errorf("未找到 mysqldump 命令 (请确保服务器已安装 mysql-client 或 mariadb-client)")
|
||||
}
|
||||
startedAt := task.StartedAt
|
||||
if startedAt.IsZero() {
|
||||
startedAt = time.Now().UTC()
|
||||
}
|
||||
tempDir, err := CreateTaskTempDir(task.Name, startedAt)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
fileName := BuildArtifactName(task.Name, startedAt, "sql")
|
||||
artifactPath := filepath.Join(tempDir, fileName)
|
||||
file, err := os.Create(artifactPath)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("create mysql dump file: %w", err)
|
||||
}
|
||||
defer file.Close()
|
||||
dbNames := normalizeDatabaseNames(task.Database.Names)
|
||||
if len(dbNames) == 0 {
|
||||
return nil, fmt.Errorf("mysql database names are required")
|
||||
}
|
||||
args := []string{
|
||||
"--host", task.Database.Host,
|
||||
"--port", strconv.Itoa(task.Database.Port),
|
||||
"--user", task.Database.User,
|
||||
"--single-transaction",
|
||||
"--quick",
|
||||
"--routines",
|
||||
"--triggers",
|
||||
"--events",
|
||||
"--no-tablespaces",
|
||||
"--net-buffer-length=32768",
|
||||
"--databases",
|
||||
}
|
||||
args = append(args, dbNames...)
|
||||
|
||||
writer.WriteLine(fmt.Sprintf("连接到 MySQL: %s:%d", task.Database.Host, task.Database.Port))
|
||||
writer.WriteLine(fmt.Sprintf("备份数据库: %s", strings.Join(dbNames, ", ")))
|
||||
|
||||
stderrWriter := newLogLineWriter(writer, "mysqldump")
|
||||
writer.WriteLine("开始执行 mysqldump")
|
||||
if err := r.executor.Run(ctx, "mysqldump", args, CommandOptions{Stdout: file, Stderr: stderrWriter, Env: mysqlEnv(task.Database.Password)}); err != nil {
|
||||
return nil, fmt.Errorf("run mysqldump: %w: %s", err, stderrWriter.collected())
|
||||
}
|
||||
info, err := file.Stat()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("stat mysql dump file: %w", err)
|
||||
}
|
||||
writer.WriteLine(fmt.Sprintf("MySQL 导出完成(文件大小: %s)", formatFileSize(info.Size())))
|
||||
return &RunResult{ArtifactPath: artifactPath, FileName: fileName, TempDir: tempDir, Size: info.Size(), StorageKey: BuildStorageKey("mysql", startedAt, fileName)}, nil
|
||||
}
|
||||
|
||||
func (r *MySQLRunner) Restore(ctx context.Context, task TaskSpec, artifactPath string, writer LogWriter) error {
|
||||
if _, err := r.executor.LookPath("mysql"); err != nil {
|
||||
return fmt.Errorf("未找到 mysql 命令 (请确保服务器已安装 mysql-client 或 mariadb-client)")
|
||||
}
|
||||
input, err := os.Open(filepath.Clean(artifactPath))
|
||||
if err != nil {
|
||||
return fmt.Errorf("open mysql restore file: %w", err)
|
||||
}
|
||||
defer input.Close()
|
||||
stderr := &bytes.Buffer{}
|
||||
args := []string{"--host", task.Database.Host, "--port", strconv.Itoa(task.Database.Port), "--user", task.Database.User}
|
||||
writer.WriteLine("开始执行 mysql 恢复")
|
||||
if err := r.executor.Run(ctx, "mysql", args, CommandOptions{Stdin: input, Stderr: stderr, Env: mysqlEnv(task.Database.Password)}); err != nil {
|
||||
return fmt.Errorf("run mysql restore: %w: %s", err, strings.TrimSpace(stderr.String()))
|
||||
}
|
||||
writer.WriteLine("MySQL 恢复完成")
|
||||
return nil
|
||||
}
|
||||
|
||||
func mysqlEnv(password string) []string {
|
||||
if strings.TrimSpace(password) == "" {
|
||||
return nil
|
||||
}
|
||||
return []string{"MYSQL_PWD=" + password}
|
||||
}
|
||||
|
||||
// logLineWriter streams each line of output to a LogWriter in real-time.
|
||||
type logLineWriter struct {
|
||||
writer LogWriter
|
||||
prefix string
|
||||
buf bytes.Buffer
|
||||
}
|
||||
|
||||
func newLogLineWriter(w LogWriter, prefix string) *logLineWriter {
|
||||
return &logLineWriter{writer: w, prefix: prefix}
|
||||
}
|
||||
|
||||
func (w *logLineWriter) Write(p []byte) (int, error) {
|
||||
n := len(p)
|
||||
w.buf.Write(p)
|
||||
scanner := bufio.NewScanner(strings.NewReader(w.buf.String()))
|
||||
var remaining string
|
||||
for scanner.Scan() {
|
||||
line := strings.TrimSpace(scanner.Text())
|
||||
if line != "" {
|
||||
w.writer.WriteLine(fmt.Sprintf("[%s] %s", w.prefix, line))
|
||||
}
|
||||
}
|
||||
// Keep any partial last line (no newline yet)
|
||||
lastNl := bytes.LastIndexByte(p, '\n')
|
||||
if lastNl >= 0 {
|
||||
remaining = w.buf.String()[w.buf.Len()-(len(p)-lastNl-1):]
|
||||
w.buf.Reset()
|
||||
w.buf.WriteString(remaining)
|
||||
}
|
||||
return n, nil
|
||||
}
|
||||
|
||||
func (w *logLineWriter) collected() string {
|
||||
return strings.TrimSpace(w.buf.String())
|
||||
}
|
||||
|
||||
func formatFileSize(size int64) string {
|
||||
const (
|
||||
KB = 1024
|
||||
MB = KB * 1024
|
||||
GB = MB * 1024
|
||||
)
|
||||
switch {
|
||||
case size >= GB:
|
||||
return fmt.Sprintf("%.2f GB", float64(size)/float64(GB))
|
||||
case size >= MB:
|
||||
return fmt.Sprintf("%.2f MB", float64(size)/float64(MB))
|
||||
case size >= KB:
|
||||
return fmt.Sprintf("%.2f KB", float64(size)/float64(KB))
|
||||
default:
|
||||
return fmt.Sprintf("%d B", size)
|
||||
}
|
||||
}
|
||||
|
||||
171
server/internal/backup/postgres_runner.go
Normal file
171
server/internal/backup/postgres_runner.go
Normal file
@@ -0,0 +1,171 @@
|
||||
//go:build ignore
|
||||
|
||||
package backup
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"fmt"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
)
|
||||
|
||||
type PostgreSQLRunner struct {
|
||||
executor CommandExecutor
|
||||
}
|
||||
|
||||
func NewPostgreSQLRunner(executor CommandExecutor) *PostgreSQLRunner {
|
||||
if executor == nil {
|
||||
executor = NewOSCommandExecutor()
|
||||
}
|
||||
return &PostgreSQLRunner{executor: executor}
|
||||
}
|
||||
|
||||
func (r *PostgreSQLRunner) Type() string {
|
||||
return "postgresql"
|
||||
}
|
||||
|
||||
func (r *PostgreSQLRunner) Run(ctx context.Context, spec TaskSpec, logger LogSink) (*Result, error) {
|
||||
if _, err := r.executor.LookPath("pg_dump"); err != nil {
|
||||
return nil, fmt.Errorf("pg_dump is required: %w", err)
|
||||
}
|
||||
databases := splitDatabaseNames(spec.DBName)
|
||||
if len(databases) == 0 {
|
||||
return nil, fmt.Errorf("postgresql database name is required")
|
||||
}
|
||||
tempDir, err := CreateTaskTempDir(spec.TaskName, spec.StartedAt)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if len(databases) == 1 {
|
||||
return r.dumpSingleDatabase(ctx, spec, databases[0], tempDir, logger)
|
||||
}
|
||||
multiDumpDir := filepath.Join(tempDir, "postgres-dumps")
|
||||
if err := os.MkdirAll(multiDumpDir, 0o755); err != nil {
|
||||
return nil, fmt.Errorf("create postgres multi dump directory: %w", err)
|
||||
}
|
||||
for _, databaseName := range databases {
|
||||
if _, err := r.dumpDatabaseToFile(ctx, spec, databaseName, filepath.Join(multiDumpDir, sanitizeDumpName(databaseName)+".sql"), logger); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
fileName := BuildArtifactName(spec.TaskName, spec.StartedAt, "tar.gz")
|
||||
artifactPath := filepath.Join(tempDir, fileName)
|
||||
size, err := CreateTarGz(ctx, multiDumpDir, nil, artifactPath, logger)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &Result{ArtifactPath: artifactPath, FileName: fileName, Size: size, StorageKey: BuildStorageKey("postgresql", spec.StartedAt, fileName)}, nil
|
||||
}
|
||||
|
||||
func (r *PostgreSQLRunner) Restore(ctx context.Context, spec TaskSpec, artifactPath string, logger LogSink) error {
|
||||
if _, err := r.executor.LookPath("psql"); err != nil {
|
||||
return fmt.Errorf("psql is required: %w", err)
|
||||
}
|
||||
databases := splitDatabaseNames(spec.DBName)
|
||||
if len(databases) == 0 {
|
||||
return fmt.Errorf("postgresql database name is required")
|
||||
}
|
||||
if strings.HasSuffix(strings.ToLower(artifactPath), ".tar.gz") {
|
||||
restoreDir, err := CreateTaskTempDir(spec.TaskName+"-restore", spec.StartedAt)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if err := ExtractTarGz(ctx, artifactPath, restoreDir, logger); err != nil {
|
||||
return err
|
||||
}
|
||||
for _, databaseName := range databases {
|
||||
filePath := filepath.Join(restoreDir, filepath.Base(restoreDir), sanitizeDumpName(databaseName)+".sql")
|
||||
if _, err := os.Stat(filePath); err != nil {
|
||||
fallback := filepath.Join(restoreDir, "postgres-dumps", sanitizeDumpName(databaseName)+".sql")
|
||||
filePath = fallback
|
||||
}
|
||||
if err := r.restoreDatabaseFromFile(ctx, spec, databaseName, filePath, logger); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
return r.restoreDatabaseFromFile(ctx, spec, databases[0], artifactPath, logger)
|
||||
}
|
||||
|
||||
func (r *PostgreSQLRunner) dumpSingleDatabase(ctx context.Context, spec TaskSpec, databaseName string, tempDir string, logger LogSink) (*Result, error) {
|
||||
fileName := BuildArtifactName(spec.TaskName, spec.StartedAt, "sql")
|
||||
artifactPath := filepath.Join(tempDir, fileName)
|
||||
size, err := r.dumpDatabaseToFile(ctx, spec, databaseName, artifactPath, logger)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &Result{ArtifactPath: artifactPath, FileName: fileName, Size: size, StorageKey: BuildStorageKey("postgresql", spec.StartedAt, fileName)}, nil
|
||||
}
|
||||
|
||||
func (r *PostgreSQLRunner) dumpDatabaseToFile(ctx context.Context, spec TaskSpec, databaseName string, artifactPath string, logger LogSink) (int64, error) {
|
||||
output, err := os.Create(filepath.Clean(artifactPath))
|
||||
if err != nil {
|
||||
return 0, fmt.Errorf("create postgres dump file: %w", err)
|
||||
}
|
||||
defer output.Close()
|
||||
stderr := &bytes.Buffer{}
|
||||
args := []string{"-h", spec.DBHost, "-p", fmt.Sprintf("%d", spec.DBPort), "-U", spec.DBUser, "-d", databaseName, "--no-owner", "--no-privileges"}
|
||||
if logger != nil {
|
||||
logger.Infof("开始执行 pg_dump:%s", databaseName)
|
||||
}
|
||||
if err := r.executor.Run(ctx, "pg_dump", args, postgresEnv(spec.DBPassword), nil, output, stderr); err != nil {
|
||||
return 0, fmt.Errorf("run pg_dump: %w: %s", err, strings.TrimSpace(stderr.String()))
|
||||
}
|
||||
info, err := output.Stat()
|
||||
if err != nil {
|
||||
return 0, fmt.Errorf("stat postgres dump file: %w", err)
|
||||
}
|
||||
return info.Size(), nil
|
||||
}
|
||||
|
||||
func (r *PostgreSQLRunner) restoreDatabaseFromFile(ctx context.Context, spec TaskSpec, databaseName string, artifactPath string, logger LogSink) error {
|
||||
input, err := os.Open(filepath.Clean(artifactPath))
|
||||
if err != nil {
|
||||
return fmt.Errorf("open postgres restore file: %w", err)
|
||||
}
|
||||
defer input.Close()
|
||||
stderr := &bytes.Buffer{}
|
||||
args := []string{"-h", spec.DBHost, "-p", fmt.Sprintf("%d", spec.DBPort), "-U", spec.DBUser, "-d", databaseName}
|
||||
if logger != nil {
|
||||
logger.Infof("开始执行 psql 恢复:%s", databaseName)
|
||||
}
|
||||
if err := r.executor.Run(ctx, "psql", args, postgresEnv(spec.DBPassword), input, nil, stderr); err != nil {
|
||||
return fmt.Errorf("run psql restore: %w: %s", err, strings.TrimSpace(stderr.String()))
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func postgresEnv(password string) map[string]string {
|
||||
if strings.TrimSpace(password) == "" {
|
||||
return nil
|
||||
}
|
||||
return map[string]string{"PGPASSWORD": password}
|
||||
}
|
||||
|
||||
func splitDatabaseNames(value string) []string {
|
||||
parts := strings.Split(value, ",")
|
||||
result := make([]string, 0, len(parts))
|
||||
for _, part := range parts {
|
||||
trimmed := strings.TrimSpace(part)
|
||||
if trimmed == "" {
|
||||
continue
|
||||
}
|
||||
result = append(result, trimmed)
|
||||
}
|
||||
return result
|
||||
}
|
||||
|
||||
func sanitizeDumpName(value string) string {
|
||||
trimmed := strings.TrimSpace(strings.ToLower(value))
|
||||
trimmed = strings.ReplaceAll(trimmed, " ", "-")
|
||||
trimmed = strings.ReplaceAll(trimmed, "/", "-")
|
||||
trimmed = strings.ReplaceAll(trimmed, "\\", "-")
|
||||
trimmed = strings.Trim(trimmed, "-._")
|
||||
if trimmed == "" {
|
||||
return "database"
|
||||
}
|
||||
return trimmed
|
||||
}
|
||||
80
server/internal/backup/postgresql_runner.go
Normal file
80
server/internal/backup/postgresql_runner.go
Normal file
@@ -0,0 +1,80 @@
|
||||
package backup
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strconv"
|
||||
"strings"
|
||||
)
|
||||
|
||||
type PostgreSQLRunner struct {
|
||||
executor CommandExecutor
|
||||
}
|
||||
|
||||
func NewPostgreSQLRunner(executor CommandExecutor) *PostgreSQLRunner {
|
||||
if executor == nil {
|
||||
executor = NewOSCommandExecutor()
|
||||
}
|
||||
return &PostgreSQLRunner{executor: executor}
|
||||
}
|
||||
|
||||
func (r *PostgreSQLRunner) Type() string {
|
||||
return "postgresql"
|
||||
}
|
||||
|
||||
func (r *PostgreSQLRunner) Run(ctx context.Context, task TaskSpec, writer LogWriter) (*RunResult, error) {
|
||||
if _, err := r.executor.LookPath("pg_dump"); err != nil {
|
||||
return nil, fmt.Errorf("未找到 pg_dump 命令 (请确保服务器已安装 postgresql-client)")
|
||||
}
|
||||
tempDir, artifactPath, err := createTempArtifact(task.TempDir, task.Name, "sql")
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
file, err := os.Create(artifactPath)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("create postgresql dump file: %w", err)
|
||||
}
|
||||
defer file.Close()
|
||||
dbNames := normalizeDatabaseNames(task.Database.Names)
|
||||
if len(dbNames) == 0 {
|
||||
return nil, fmt.Errorf("postgresql database names are required")
|
||||
}
|
||||
writer.WriteLine(fmt.Sprintf("连接到 PostgreSQL: %s:%d", task.Database.Host, task.Database.Port))
|
||||
writer.WriteLine(fmt.Sprintf("备份数据库: %s", strings.Join(dbNames, ", ")))
|
||||
stderrWriter := newLogLineWriter(writer, "pg_dump")
|
||||
for index, name := range dbNames {
|
||||
args := []string{"--clean", "--if-exists", "--create", "--format=plain", "-h", task.Database.Host, "-p", strconv.Itoa(task.Database.Port), "-U", task.Database.User, "--dbname", name}
|
||||
writer.WriteLine(fmt.Sprintf("开始导出数据库 [%d/%d]: %s", index+1, len(dbNames), name))
|
||||
if err := r.executor.Run(ctx, "pg_dump", args, CommandOptions{Stdout: file, Stderr: stderrWriter, Env: append(os.Environ(), "PGPASSWORD="+task.Database.Password)}); err != nil {
|
||||
return nil, fmt.Errorf("run pg_dump for %s: %w", name, err)
|
||||
}
|
||||
writer.WriteLine(fmt.Sprintf("数据库 %s 导出完成", name))
|
||||
if index < len(dbNames)-1 {
|
||||
if _, err := file.WriteString("\n\n"); err != nil {
|
||||
return nil, fmt.Errorf("write dump separator: %w", err)
|
||||
}
|
||||
}
|
||||
}
|
||||
info, _ := file.Stat()
|
||||
sizeStr := "未知"
|
||||
if info != nil {
|
||||
sizeStr = formatFileSize(info.Size())
|
||||
}
|
||||
writer.WriteLine(fmt.Sprintf("PostgreSQL 导出完成(文件大小: %s)", sizeStr))
|
||||
return &RunResult{ArtifactPath: artifactPath, FileName: filepath.Base(artifactPath), TempDir: tempDir}, nil
|
||||
}
|
||||
|
||||
func (r *PostgreSQLRunner) Restore(ctx context.Context, task TaskSpec, artifactPath string, writer LogWriter) error {
|
||||
if _, err := r.executor.LookPath("psql"); err != nil {
|
||||
return fmt.Errorf("未找到 psql 命令 (请确保服务器已安装 postgresql-client)")
|
||||
}
|
||||
writer.WriteLine("开始执行 psql 恢复")
|
||||
args := []string{"-h", task.Database.Host, "-p", strconv.Itoa(task.Database.Port), "-U", task.Database.User, "-d", "postgres", "-f", artifactPath}
|
||||
if err := r.executor.Run(ctx, "psql", args, CommandOptions{Env: append(os.Environ(), "PGPASSWORD="+task.Database.Password)}); err != nil {
|
||||
return fmt.Errorf("run psql restore: %w", err)
|
||||
}
|
||||
writer.WriteLine("PostgreSQL 恢复完成")
|
||||
return nil
|
||||
}
|
||||
62
server/internal/backup/registry.go
Normal file
62
server/internal/backup/registry.go
Normal file
@@ -0,0 +1,62 @@
|
||||
package backup
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"sort"
|
||||
"strings"
|
||||
"sync"
|
||||
)
|
||||
|
||||
type Registry struct {
|
||||
mu sync.RWMutex
|
||||
runners map[string]BackupRunner
|
||||
}
|
||||
|
||||
func NewRegistry(runners ...BackupRunner) *Registry {
|
||||
registry := &Registry{runners: make(map[string]BackupRunner)}
|
||||
for _, runner := range runners {
|
||||
registry.Register(runner)
|
||||
}
|
||||
return registry
|
||||
}
|
||||
|
||||
func (r *Registry) Register(runner BackupRunner) {
|
||||
if runner == nil {
|
||||
return
|
||||
}
|
||||
r.mu.Lock()
|
||||
defer r.mu.Unlock()
|
||||
if r.runners == nil {
|
||||
r.runners = make(map[string]BackupRunner)
|
||||
}
|
||||
r.runners[normalizeTaskType(runner.Type())] = runner
|
||||
}
|
||||
|
||||
func (r *Registry) Runner(taskType string) (BackupRunner, error) {
|
||||
r.mu.RLock()
|
||||
defer r.mu.RUnlock()
|
||||
runner, ok := r.runners[normalizeTaskType(taskType)]
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("unsupported backup task type: %s", taskType)
|
||||
}
|
||||
return runner, nil
|
||||
}
|
||||
|
||||
func (r *Registry) Types() []string {
|
||||
r.mu.RLock()
|
||||
defer r.mu.RUnlock()
|
||||
items := make([]string, 0, len(r.runners))
|
||||
for key := range r.runners {
|
||||
items = append(items, key)
|
||||
}
|
||||
sort.Strings(items)
|
||||
return items
|
||||
}
|
||||
|
||||
func normalizeTaskType(value string) string {
|
||||
normalized := strings.TrimSpace(strings.ToLower(value))
|
||||
if normalized == "pgsql" {
|
||||
return "postgresql"
|
||||
}
|
||||
return normalized
|
||||
}
|
||||
23
server/internal/backup/registry_test.go
Normal file
23
server/internal/backup/registry_test.go
Normal file
@@ -0,0 +1,23 @@
|
||||
package backup
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
)
|
||||
|
||||
type stubRunner struct{ taskType string }
|
||||
|
||||
func (r stubRunner) Type() string { return r.taskType }
|
||||
func (r stubRunner) Run(context.Context, TaskSpec, LogWriter) (*RunResult, error) { return nil, nil }
|
||||
func (r stubRunner) Restore(context.Context, TaskSpec, string, LogWriter) error { return nil }
|
||||
|
||||
func TestRegistryResolvesNormalizedType(t *testing.T) {
|
||||
registry := NewRegistry(stubRunner{taskType: "postgresql"})
|
||||
runner, err := registry.Runner("pgsql")
|
||||
if err != nil {
|
||||
t.Fatalf("Runner returned error: %v", err)
|
||||
}
|
||||
if runner.Type() != "postgresql" {
|
||||
t.Fatalf("unexpected runner type: %s", runner.Type())
|
||||
}
|
||||
}
|
||||
82
server/internal/backup/retention/service.go
Normal file
82
server/internal/backup/retention/service.go
Normal file
@@ -0,0 +1,82 @@
|
||||
package retention
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"backupx/server/internal/model"
|
||||
"backupx/server/internal/repository"
|
||||
"backupx/server/internal/storage"
|
||||
)
|
||||
|
||||
type CleanupResult struct {
|
||||
DeletedRecords int
|
||||
DeletedObjects int
|
||||
Warnings []string
|
||||
}
|
||||
|
||||
type Service struct {
|
||||
records repository.BackupRecordRepository
|
||||
now func() time.Time
|
||||
}
|
||||
|
||||
func NewService(records repository.BackupRecordRepository) *Service {
|
||||
return &Service{records: records, now: func() time.Time { return time.Now().UTC() }}
|
||||
}
|
||||
|
||||
func (s *Service) Cleanup(ctx context.Context, task *model.BackupTask, provider storage.StorageProvider) (*CleanupResult, error) {
|
||||
if task == nil {
|
||||
return nil, fmt.Errorf("backup task is required")
|
||||
}
|
||||
records, err := s.records.ListSuccessfulByTask(ctx, task.ID)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("list successful records: %w", err)
|
||||
}
|
||||
candidates := selectRecordsToDelete(records, task.RetentionDays, task.MaxBackups, s.now())
|
||||
result := &CleanupResult{}
|
||||
for _, record := range candidates {
|
||||
if strings.TrimSpace(record.StoragePath) != "" {
|
||||
if provider == nil {
|
||||
result.Warnings = append(result.Warnings, fmt.Sprintf("record %d missing storage provider for cleanup", record.ID))
|
||||
continue
|
||||
}
|
||||
if err := provider.Delete(ctx, record.StoragePath); err != nil {
|
||||
result.Warnings = append(result.Warnings, fmt.Sprintf("delete storage object %s failed: %v", record.StoragePath, err))
|
||||
continue
|
||||
}
|
||||
result.DeletedObjects++
|
||||
}
|
||||
if err := s.records.Delete(ctx, record.ID); err != nil {
|
||||
result.Warnings = append(result.Warnings, fmt.Sprintf("delete backup record %d failed: %v", record.ID, err))
|
||||
continue
|
||||
}
|
||||
result.DeletedRecords++
|
||||
}
|
||||
return result, nil
|
||||
}
|
||||
|
||||
func selectRecordsToDelete(records []model.BackupRecord, retentionDays int, maxBackups int, now time.Time) []model.BackupRecord {
|
||||
selected := make(map[uint]model.BackupRecord)
|
||||
if maxBackups > 0 && len(records) > maxBackups {
|
||||
for _, record := range records[maxBackups:] {
|
||||
selected[record.ID] = record
|
||||
}
|
||||
}
|
||||
if retentionDays > 0 {
|
||||
cutoff := now.AddDate(0, 0, -retentionDays)
|
||||
for _, record := range records {
|
||||
if record.CompletedAt != nil && record.CompletedAt.Before(cutoff) {
|
||||
selected[record.ID] = record
|
||||
}
|
||||
}
|
||||
}
|
||||
result := make([]model.BackupRecord, 0, len(selected))
|
||||
for _, record := range records {
|
||||
if selectedRecord, ok := selected[record.ID]; ok {
|
||||
result = append(result, selectedRecord)
|
||||
}
|
||||
}
|
||||
return result
|
||||
}
|
||||
115
server/internal/backup/retention/service_test.go
Normal file
115
server/internal/backup/retention/service_test.go
Normal file
@@ -0,0 +1,115 @@
|
||||
package retention
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"io"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"backupx/server/internal/model"
|
||||
"backupx/server/internal/repository"
|
||||
"backupx/server/internal/storage"
|
||||
)
|
||||
|
||||
type fakeRecordRepository struct {
|
||||
records []model.BackupRecord
|
||||
deleted []uint
|
||||
deleteErrs map[uint]error
|
||||
}
|
||||
|
||||
func (r *fakeRecordRepository) List(context.Context, repository.BackupRecordListOptions) ([]model.BackupRecord, error) {
|
||||
return nil, nil
|
||||
}
|
||||
func (r *fakeRecordRepository) FindByID(context.Context, uint) (*model.BackupRecord, error) {
|
||||
return nil, nil
|
||||
}
|
||||
func (r *fakeRecordRepository) Create(context.Context, *model.BackupRecord) error { return nil }
|
||||
func (r *fakeRecordRepository) Update(context.Context, *model.BackupRecord) error { return nil }
|
||||
func (r *fakeRecordRepository) Delete(_ context.Context, id uint) error {
|
||||
if err := r.deleteErrs[id]; err != nil {
|
||||
return err
|
||||
}
|
||||
r.deleted = append(r.deleted, id)
|
||||
return nil
|
||||
}
|
||||
func (r *fakeRecordRepository) ListRecent(context.Context, int) ([]model.BackupRecord, error) {
|
||||
return nil, nil
|
||||
}
|
||||
func (r *fakeRecordRepository) ListSuccessfulByTask(_ context.Context, _ uint) ([]model.BackupRecord, error) {
|
||||
return r.records, nil
|
||||
}
|
||||
func (r *fakeRecordRepository) Count(context.Context) (int64, error) { return 0, nil }
|
||||
func (r *fakeRecordRepository) CountSince(context.Context, time.Time) (int64, error) { return 0, nil }
|
||||
func (r *fakeRecordRepository) CountSuccessSince(context.Context, time.Time) (int64, error) {
|
||||
return 0, nil
|
||||
}
|
||||
func (r *fakeRecordRepository) SumFileSize(context.Context) (int64, error) { return 0, nil }
|
||||
func (r *fakeRecordRepository) TimelineSince(context.Context, time.Time) ([]repository.BackupTimelinePoint, error) {
|
||||
return nil, nil
|
||||
}
|
||||
func (r *fakeRecordRepository) StorageUsage(context.Context) ([]repository.BackupStorageUsageItem, error) {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
type fakeProvider struct {
|
||||
deleted []string
|
||||
failKey string
|
||||
}
|
||||
|
||||
func (p *fakeProvider) Type() string { return storage.ProviderTypeLocalDisk }
|
||||
func (p *fakeProvider) TestConnection(context.Context) error { return nil }
|
||||
func (p *fakeProvider) Upload(context.Context, string, io.Reader, int64, map[string]string) error {
|
||||
return nil
|
||||
}
|
||||
func (p *fakeProvider) Download(context.Context, string) (io.ReadCloser, error) { return nil, nil }
|
||||
func (p *fakeProvider) Delete(_ context.Context, objectKey string) error {
|
||||
if objectKey == p.failKey {
|
||||
return fmt.Errorf("delete failed")
|
||||
}
|
||||
p.deleted = append(p.deleted, objectKey)
|
||||
return nil
|
||||
}
|
||||
func (p *fakeProvider) List(context.Context, string) ([]storage.ObjectInfo, error) { return nil, nil }
|
||||
|
||||
func TestSelectRecordsToDelete(t *testing.T) {
|
||||
now := time.Date(2026, 3, 7, 16, 0, 0, 0, time.UTC)
|
||||
completedNew := now.Add(-24 * time.Hour)
|
||||
completedOld := now.Add(-15 * 24 * time.Hour)
|
||||
records := []model.BackupRecord{
|
||||
{ID: 3, CompletedAt: &completedNew},
|
||||
{ID: 2, CompletedAt: &completedNew},
|
||||
{ID: 1, CompletedAt: &completedOld},
|
||||
}
|
||||
selected := selectRecordsToDelete(records, 7, 2, now)
|
||||
if len(selected) != 1 || selected[0].ID != 1 {
|
||||
t.Fatalf("unexpected selected records: %#v", selected)
|
||||
}
|
||||
}
|
||||
|
||||
func TestCleanupDeletesExpiredRecords(t *testing.T) {
|
||||
now := time.Date(2026, 3, 7, 16, 0, 0, 0, time.UTC)
|
||||
completedNew := now.Add(-24 * time.Hour)
|
||||
completedOld := now.Add(-15 * 24 * time.Hour)
|
||||
repo := &fakeRecordRepository{records: []model.BackupRecord{
|
||||
{ID: 3, TaskID: 1, StoragePath: "records/3", CompletedAt: &completedNew},
|
||||
{ID: 2, TaskID: 1, StoragePath: "records/2", CompletedAt: &completedNew},
|
||||
{ID: 1, TaskID: 1, StoragePath: "records/1", CompletedAt: &completedOld},
|
||||
}}
|
||||
provider := &fakeProvider{}
|
||||
service := NewService(repo)
|
||||
service.now = func() time.Time { return now }
|
||||
result, err := service.Cleanup(context.Background(), &model.BackupTask{ID: 1, RetentionDays: 7, MaxBackups: 2}, provider)
|
||||
if err != nil {
|
||||
t.Fatalf("Cleanup returned error: %v", err)
|
||||
}
|
||||
if result.DeletedRecords != 1 || result.DeletedObjects != 1 {
|
||||
t.Fatalf("unexpected cleanup result: %#v", result)
|
||||
}
|
||||
if len(repo.deleted) != 1 || repo.deleted[0] != 1 {
|
||||
t.Fatalf("unexpected deleted records: %#v", repo.deleted)
|
||||
}
|
||||
if len(provider.deleted) != 1 || provider.deleted[0] != "records/1" {
|
||||
t.Fatalf("unexpected deleted objects: %#v", provider.deleted)
|
||||
}
|
||||
}
|
||||
74
server/internal/backup/sqlite_runner.go
Normal file
74
server/internal/backup/sqlite_runner.go
Normal file
@@ -0,0 +1,74 @@
|
||||
package backup
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"io"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
)
|
||||
|
||||
type SQLiteRunner struct{}
|
||||
|
||||
func NewSQLiteRunner() *SQLiteRunner {
|
||||
return &SQLiteRunner{}
|
||||
}
|
||||
|
||||
func (r *SQLiteRunner) Type() string {
|
||||
return "sqlite"
|
||||
}
|
||||
|
||||
func (r *SQLiteRunner) Run(_ context.Context, task TaskSpec, writer LogWriter) (*RunResult, error) {
|
||||
dbPath := filepath.Clean(strings.TrimSpace(task.Database.Path))
|
||||
if dbPath == "" {
|
||||
return nil, fmt.Errorf("sqlite database path is required")
|
||||
}
|
||||
if _, err := os.Stat(dbPath); err != nil {
|
||||
return nil, fmt.Errorf("stat sqlite database: %w", err)
|
||||
}
|
||||
tempDir, artifactPath, err := createTempArtifact(task.TempDir, task.Name, strings.TrimPrefix(filepath.Ext(dbPath), "."))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if filepath.Ext(artifactPath) == "." || filepath.Ext(artifactPath) == "" {
|
||||
artifactPath += ".sqlite"
|
||||
}
|
||||
if err := copyFile(dbPath, artifactPath); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
writer.WriteLine("SQLite 备份文件已复制")
|
||||
return &RunResult{ArtifactPath: artifactPath, FileName: filepath.Base(artifactPath), TempDir: tempDir}, nil
|
||||
}
|
||||
|
||||
func (r *SQLiteRunner) Restore(_ context.Context, task TaskSpec, artifactPath string, writer LogWriter) error {
|
||||
dbPath := filepath.Clean(strings.TrimSpace(task.Database.Path))
|
||||
if dbPath == "" {
|
||||
return fmt.Errorf("sqlite database path is required")
|
||||
}
|
||||
if err := copyFile(artifactPath, dbPath); err != nil {
|
||||
return err
|
||||
}
|
||||
writer.WriteLine("SQLite 数据库已恢复")
|
||||
return nil
|
||||
}
|
||||
|
||||
func copyFile(sourcePath string, targetPath string) error {
|
||||
source, err := os.Open(sourcePath)
|
||||
if err != nil {
|
||||
return fmt.Errorf("open source file: %w", err)
|
||||
}
|
||||
defer source.Close()
|
||||
if err := os.MkdirAll(filepath.Dir(targetPath), 0o755); err != nil {
|
||||
return fmt.Errorf("create target directory: %w", err)
|
||||
}
|
||||
target, err := os.Create(targetPath)
|
||||
if err != nil {
|
||||
return fmt.Errorf("create target file: %w", err)
|
||||
}
|
||||
defer target.Close()
|
||||
if _, err := io.Copy(target, source); err != nil {
|
||||
return fmt.Errorf("copy file content: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
34
server/internal/backup/sqlite_runner_test.go
Normal file
34
server/internal/backup/sqlite_runner_test.go
Normal file
@@ -0,0 +1,34 @@
|
||||
package backup
|
||||
|
||||
import (
|
||||
"context"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestSQLiteRunnerRunAndRestore(t *testing.T) {
|
||||
tempDir := t.TempDir()
|
||||
dbPath := filepath.Join(tempDir, "data.db")
|
||||
if err := os.WriteFile(dbPath, []byte("sqlite-data"), 0o644); err != nil {
|
||||
t.Fatalf("WriteFile returned error: %v", err)
|
||||
}
|
||||
runner := NewSQLiteRunner()
|
||||
result, err := runner.Run(context.Background(), TaskSpec{Name: "sqlite backup", Type: "sqlite", Database: DatabaseSpec{Path: dbPath}}, NopLogWriter{})
|
||||
if err != nil {
|
||||
t.Fatalf("Run returned error: %v", err)
|
||||
}
|
||||
if err := os.WriteFile(dbPath, []byte("mutated"), 0o644); err != nil {
|
||||
t.Fatalf("WriteFile returned error: %v", err)
|
||||
}
|
||||
if err := runner.Restore(context.Background(), TaskSpec{Name: "sqlite backup", Type: "sqlite", Database: DatabaseSpec{Path: dbPath}}, result.ArtifactPath, NopLogWriter{}); err != nil {
|
||||
t.Fatalf("Restore returned error: %v", err)
|
||||
}
|
||||
content, err := os.ReadFile(dbPath)
|
||||
if err != nil {
|
||||
t.Fatalf("ReadFile returned error: %v", err)
|
||||
}
|
||||
if string(content) != "sqlite-data" {
|
||||
t.Fatalf("unexpected restored content: %s", string(content))
|
||||
}
|
||||
}
|
||||
64
server/internal/backup/temp_files.go
Normal file
64
server/internal/backup/temp_files.go
Normal file
@@ -0,0 +1,64 @@
|
||||
package backup
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"regexp"
|
||||
"strings"
|
||||
"time"
|
||||
)
|
||||
|
||||
var fileNameCleaner = regexp.MustCompile(`[^a-zA-Z0-9._-]+`)
|
||||
|
||||
func EnsureTempRoot() (string, error) {
|
||||
root := filepath.Join(os.TempDir(), "backupx")
|
||||
if err := os.MkdirAll(root, 0o755); err != nil {
|
||||
return "", fmt.Errorf("create backup temp root: %w", err)
|
||||
}
|
||||
return root, nil
|
||||
}
|
||||
|
||||
func CreateTaskTempDir(taskName string, startedAt time.Time) (string, error) {
|
||||
root, err := EnsureTempRoot()
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
name := sanitizeTaskName(taskName)
|
||||
if name == "" {
|
||||
name = "backup"
|
||||
}
|
||||
path := filepath.Join(root, fmt.Sprintf("%s_%s", name, startedAt.UTC().Format("20060102_150405")))
|
||||
if err := os.MkdirAll(path, 0o755); err != nil {
|
||||
return "", fmt.Errorf("create task temp dir: %w", err)
|
||||
}
|
||||
return path, nil
|
||||
}
|
||||
|
||||
func BuildArtifactName(taskName string, startedAt time.Time, extension string) string {
|
||||
name := sanitizeTaskName(taskName)
|
||||
if name == "" {
|
||||
name = "backup"
|
||||
}
|
||||
ext := strings.TrimSpace(extension)
|
||||
if ext != "" && !strings.HasPrefix(ext, ".") {
|
||||
ext = "." + ext
|
||||
}
|
||||
return fmt.Sprintf("%s_%s%s", name, startedAt.UTC().Format("20060102_150405"), ext)
|
||||
}
|
||||
|
||||
func BuildStorageKey(backupType string, startedAt time.Time, fileName string) string {
|
||||
typeName := strings.TrimSpace(strings.ToLower(backupType))
|
||||
if typeName == "" {
|
||||
typeName = "file"
|
||||
}
|
||||
return filepath.ToSlash(filepath.Join("BackupX", typeName, startedAt.UTC().Format("060102"), fileName))
|
||||
}
|
||||
|
||||
func sanitizeTaskName(value string) string {
|
||||
trimmed := strings.TrimSpace(strings.ToLower(value))
|
||||
trimmed = strings.ReplaceAll(trimmed, " ", "-")
|
||||
trimmed = fileNameCleaner.ReplaceAllString(trimmed, "-")
|
||||
trimmed = strings.Trim(trimmed, "-._")
|
||||
return trimmed
|
||||
}
|
||||
73
server/internal/backup/types.go
Normal file
73
server/internal/backup/types.go
Normal file
@@ -0,0 +1,73 @@
|
||||
package backup
|
||||
|
||||
import (
|
||||
"context"
|
||||
"time"
|
||||
)
|
||||
|
||||
type DatabaseSpec struct {
|
||||
Host string
|
||||
Port int
|
||||
User string
|
||||
Password string
|
||||
Names []string
|
||||
Path string
|
||||
}
|
||||
|
||||
type TaskSpec struct {
|
||||
ID uint
|
||||
Name string
|
||||
Type string
|
||||
SourcePath string
|
||||
ExcludePatterns []string
|
||||
Database DatabaseSpec
|
||||
StorageTargetID uint
|
||||
StorageTargetType string
|
||||
Compression string
|
||||
Encrypt bool
|
||||
RetentionDays int
|
||||
MaxBackups int
|
||||
StartedAt time.Time
|
||||
TempDir string
|
||||
}
|
||||
|
||||
type RunResult struct {
|
||||
ArtifactPath string
|
||||
FileName string
|
||||
TempDir string
|
||||
Size int64
|
||||
StorageKey string
|
||||
}
|
||||
|
||||
type LogEvent struct {
|
||||
RecordID uint `json:"recordId"`
|
||||
Sequence int64 `json:"sequence"`
|
||||
Level string `json:"level"`
|
||||
Message string `json:"message"`
|
||||
Timestamp time.Time `json:"timestamp"`
|
||||
Completed bool `json:"completed"`
|
||||
Status string `json:"status"`
|
||||
}
|
||||
|
||||
type LogWriter interface {
|
||||
WriteLine(message string)
|
||||
}
|
||||
|
||||
type LogSink interface {
|
||||
Infof(format string, args ...any)
|
||||
Warnf(format string, args ...any)
|
||||
Errorf(format string, args ...any)
|
||||
}
|
||||
|
||||
type NopLogWriter struct{}
|
||||
|
||||
func (NopLogWriter) WriteLine(string) {}
|
||||
func (NopLogWriter) Infof(string, ...any) {}
|
||||
func (NopLogWriter) Warnf(string, ...any) {}
|
||||
func (NopLogWriter) Errorf(string, ...any) {}
|
||||
|
||||
type BackupRunner interface {
|
||||
Type() string
|
||||
Run(ctx context.Context, task TaskSpec, writer LogWriter) (*RunResult, error)
|
||||
Restore(ctx context.Context, task TaskSpec, artifactPath string, writer LogWriter) error
|
||||
}
|
||||
143
server/internal/config/config.go
Normal file
143
server/internal/config/config.go
Normal file
@@ -0,0 +1,143 @@
|
||||
package config
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/spf13/viper"
|
||||
)
|
||||
|
||||
type Config struct {
|
||||
Server ServerConfig `mapstructure:"server"`
|
||||
Database DatabaseConfig `mapstructure:"database"`
|
||||
Security SecurityConfig `mapstructure:"security"`
|
||||
Backup BackupConfig `mapstructure:"backup"`
|
||||
Log LogConfig `mapstructure:"log"`
|
||||
}
|
||||
|
||||
type ServerConfig struct {
|
||||
Host string `mapstructure:"host"`
|
||||
Port int `mapstructure:"port"`
|
||||
Mode string `mapstructure:"mode"`
|
||||
}
|
||||
|
||||
type DatabaseConfig struct {
|
||||
Path string `mapstructure:"path"`
|
||||
}
|
||||
|
||||
type SecurityConfig struct {
|
||||
JWTSecret string `mapstructure:"jwt_secret"`
|
||||
JWTExpire string `mapstructure:"jwt_expire"`
|
||||
EncryptionKey string `mapstructure:"encryption_key"`
|
||||
}
|
||||
|
||||
type BackupConfig struct {
|
||||
TempDir string `mapstructure:"temp_dir"`
|
||||
MaxConcurrent int `mapstructure:"max_concurrent"`
|
||||
}
|
||||
|
||||
type LogConfig struct {
|
||||
Level string `mapstructure:"level"`
|
||||
File string `mapstructure:"file"`
|
||||
MaxSize int `mapstructure:"max_size"`
|
||||
MaxBackups int `mapstructure:"max_backups"`
|
||||
MaxAge int `mapstructure:"max_age"`
|
||||
}
|
||||
|
||||
func Load(configPath string) (Config, error) {
|
||||
v := viper.New()
|
||||
applyDefaults(v)
|
||||
v.SetConfigType("yaml")
|
||||
v.SetEnvPrefix("BACKUPX")
|
||||
v.SetEnvKeyReplacer(strings.NewReplacer(".", "_"))
|
||||
v.AutomaticEnv()
|
||||
|
||||
if configPath != "" {
|
||||
v.SetConfigFile(configPath)
|
||||
if err := v.ReadInConfig(); err != nil {
|
||||
return Config{}, fmt.Errorf("read config: %w", err)
|
||||
}
|
||||
} else {
|
||||
v.SetConfigName("config")
|
||||
v.AddConfigPath(".")
|
||||
v.AddConfigPath("./server")
|
||||
v.AddConfigPath("/etc/backupx")
|
||||
if err := v.ReadInConfig(); err != nil {
|
||||
if _, ok := err.(viper.ConfigFileNotFoundError); !ok {
|
||||
return Config{}, fmt.Errorf("read config: %w", err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
var cfg Config
|
||||
if err := v.Unmarshal(&cfg); err != nil {
|
||||
return Config{}, fmt.Errorf("decode config: %w", err)
|
||||
}
|
||||
|
||||
if cfg.Server.Host == "" {
|
||||
cfg.Server.Host = "0.0.0.0"
|
||||
}
|
||||
if cfg.Server.Port == 0 {
|
||||
cfg.Server.Port = 8340
|
||||
}
|
||||
if cfg.Server.Mode == "" {
|
||||
cfg.Server.Mode = "release"
|
||||
}
|
||||
if cfg.Database.Path == "" {
|
||||
cfg.Database.Path = "./data/backupx.db"
|
||||
}
|
||||
if cfg.Security.JWTExpire == "" {
|
||||
cfg.Security.JWTExpire = "24h"
|
||||
}
|
||||
if cfg.Backup.TempDir == "" {
|
||||
cfg.Backup.TempDir = "/tmp/backupx"
|
||||
}
|
||||
if cfg.Backup.MaxConcurrent <= 0 {
|
||||
cfg.Backup.MaxConcurrent = 2
|
||||
}
|
||||
if cfg.Log.Level == "" {
|
||||
cfg.Log.Level = "info"
|
||||
}
|
||||
if cfg.Log.File == "" {
|
||||
cfg.Log.File = "./data/backupx.log"
|
||||
}
|
||||
if cfg.Log.MaxSize <= 0 {
|
||||
cfg.Log.MaxSize = 100
|
||||
}
|
||||
if cfg.Log.MaxBackups <= 0 {
|
||||
cfg.Log.MaxBackups = 3
|
||||
}
|
||||
if cfg.Log.MaxAge <= 0 {
|
||||
cfg.Log.MaxAge = 30
|
||||
}
|
||||
|
||||
return cfg, nil
|
||||
}
|
||||
|
||||
func MustJWTDuration(cfg SecurityConfig) time.Duration {
|
||||
duration, err := time.ParseDuration(cfg.JWTExpire)
|
||||
if err != nil {
|
||||
return 24 * time.Hour
|
||||
}
|
||||
return duration
|
||||
}
|
||||
|
||||
func (c Config) Address() string {
|
||||
return fmt.Sprintf("%s:%d", c.Server.Host, c.Server.Port)
|
||||
}
|
||||
|
||||
func applyDefaults(v *viper.Viper) {
|
||||
v.SetDefault("server.host", "0.0.0.0")
|
||||
v.SetDefault("server.port", 8340)
|
||||
v.SetDefault("server.mode", "release")
|
||||
v.SetDefault("database.path", "./data/backupx.db")
|
||||
v.SetDefault("security.jwt_expire", "24h")
|
||||
v.SetDefault("backup.temp_dir", "/tmp/backupx")
|
||||
v.SetDefault("backup.max_concurrent", 2)
|
||||
v.SetDefault("log.level", "info")
|
||||
v.SetDefault("log.file", "./data/backupx.log")
|
||||
v.SetDefault("log.max_size", 100)
|
||||
v.SetDefault("log.max_backups", 3)
|
||||
v.SetDefault("log.max_age", 30)
|
||||
}
|
||||
20
server/internal/config/config_test.go
Normal file
20
server/internal/config/config_test.go
Normal file
@@ -0,0 +1,20 @@
|
||||
package config
|
||||
|
||||
import "testing"
|
||||
|
||||
func TestLoadUsesDefaultsWithoutConfigFile(t *testing.T) {
|
||||
cfg, err := Load("")
|
||||
if err != nil {
|
||||
t.Fatalf("Load returned error: %v", err)
|
||||
}
|
||||
|
||||
if cfg.Server.Host != "0.0.0.0" {
|
||||
t.Fatalf("expected default host, got %s", cfg.Server.Host)
|
||||
}
|
||||
if cfg.Server.Port != 8340 {
|
||||
t.Fatalf("expected default port, got %d", cfg.Server.Port)
|
||||
}
|
||||
if cfg.Database.Path != "./data/backupx.db" {
|
||||
t.Fatalf("expected default database path, got %s", cfg.Database.Path)
|
||||
}
|
||||
}
|
||||
32
server/internal/database/database.go
Normal file
32
server/internal/database/database.go
Normal file
@@ -0,0 +1,32 @@
|
||||
package database
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"os"
|
||||
"path/filepath"
|
||||
|
||||
"backupx/server/internal/config"
|
||||
"backupx/server/internal/model"
|
||||
"github.com/glebarez/sqlite"
|
||||
"go.uber.org/zap"
|
||||
"gorm.io/gorm"
|
||||
gormlogger "gorm.io/gorm/logger"
|
||||
)
|
||||
|
||||
func Open(cfg config.DatabaseConfig, logger *zap.Logger) (*gorm.DB, error) {
|
||||
if err := os.MkdirAll(filepath.Dir(cfg.Path), 0o755); err != nil {
|
||||
return nil, fmt.Errorf("create database dir: %w", err)
|
||||
}
|
||||
|
||||
db, err := gorm.Open(sqlite.Open(cfg.Path), &gorm.Config{Logger: gormlogger.Default.LogMode(gormlogger.Silent)})
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("open sqlite: %w", err)
|
||||
}
|
||||
|
||||
if err := db.AutoMigrate(&model.User{}, &model.SystemConfig{}, &model.StorageTarget{}, &model.OAuthSession{}, &model.BackupTask{}, &model.BackupRecord{}, &model.Notification{}, &model.Node{}); err != nil {
|
||||
return nil, fmt.Errorf("migrate schema: %w", err)
|
||||
}
|
||||
|
||||
logger.Info("database initialized", zap.String("path", cfg.Path))
|
||||
return db, nil
|
||||
}
|
||||
91
server/internal/http/auth_handler.go
Normal file
91
server/internal/http/auth_handler.go
Normal file
@@ -0,0 +1,91 @@
|
||||
package http
|
||||
|
||||
import (
|
||||
"backupx/server/internal/apperror"
|
||||
"backupx/server/internal/service"
|
||||
"backupx/server/pkg/response"
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
type AuthHandler struct {
|
||||
authService *service.AuthService
|
||||
}
|
||||
|
||||
func NewAuthHandler(authService *service.AuthService) *AuthHandler {
|
||||
return &AuthHandler{authService: authService}
|
||||
}
|
||||
|
||||
func (h *AuthHandler) SetupStatus(c *gin.Context) {
|
||||
initialized, err := h.authService.SetupStatus(c.Request.Context())
|
||||
if err != nil {
|
||||
response.Error(c, err)
|
||||
return
|
||||
}
|
||||
response.Success(c, gin.H{"initialized": initialized})
|
||||
}
|
||||
|
||||
func (h *AuthHandler) Setup(c *gin.Context) {
|
||||
var input service.SetupInput
|
||||
if err := c.ShouldBindJSON(&input); err != nil {
|
||||
response.Error(c, apperror.BadRequest("AUTH_SETUP_INVALID", "初始化参数不合法", err))
|
||||
return
|
||||
}
|
||||
payload, err := h.authService.Setup(c.Request.Context(), input)
|
||||
if err != nil {
|
||||
response.Error(c, err)
|
||||
return
|
||||
}
|
||||
response.Success(c, payload)
|
||||
}
|
||||
|
||||
func (h *AuthHandler) Login(c *gin.Context) {
|
||||
var input service.LoginInput
|
||||
if err := c.ShouldBindJSON(&input); err != nil {
|
||||
response.Error(c, apperror.BadRequest("AUTH_LOGIN_INVALID", "登录参数不合法", err))
|
||||
return
|
||||
}
|
||||
payload, err := h.authService.Login(c.Request.Context(), input, ClientKey(c))
|
||||
if err != nil {
|
||||
response.Error(c, err)
|
||||
return
|
||||
}
|
||||
response.Success(c, payload)
|
||||
}
|
||||
|
||||
func (h *AuthHandler) Profile(c *gin.Context) {
|
||||
subjectValue, _ := c.Get(contextUserSubjectKey)
|
||||
subject, err := service.SubjectFromContextValue(subjectValue)
|
||||
if err != nil {
|
||||
response.Error(c, apperror.Unauthorized("AUTH_INVALID_SUBJECT", "无效登录态", err))
|
||||
return
|
||||
}
|
||||
user, err := h.authService.GetCurrentUser(c.Request.Context(), subject)
|
||||
if err != nil {
|
||||
response.Error(c, err)
|
||||
return
|
||||
}
|
||||
response.Success(c, user)
|
||||
}
|
||||
|
||||
func (h *AuthHandler) ChangePassword(c *gin.Context) {
|
||||
subjectValue, _ := c.Get(contextUserSubjectKey)
|
||||
subject, err := service.SubjectFromContextValue(subjectValue)
|
||||
if err != nil {
|
||||
response.Error(c, apperror.Unauthorized("AUTH_INVALID_SUBJECT", "无效登录态", err))
|
||||
return
|
||||
}
|
||||
var input service.ChangePasswordInput
|
||||
if err := c.ShouldBindJSON(&input); err != nil {
|
||||
response.Error(c, apperror.BadRequest("AUTH_PASSWORD_INVALID", "参数不合法", err))
|
||||
return
|
||||
}
|
||||
if err := h.authService.ChangePassword(c.Request.Context(), subject, input); err != nil {
|
||||
response.Error(c, err)
|
||||
return
|
||||
}
|
||||
response.Success(c, gin.H{"changed": true})
|
||||
}
|
||||
|
||||
func (h *AuthHandler) Logout(c *gin.Context) {
|
||||
response.Success(c, gin.H{"loggedOut": true})
|
||||
}
|
||||
189
server/internal/http/backup_record_handler.go
Normal file
189
server/internal/http/backup_record_handler.go
Normal file
@@ -0,0 +1,189 @@
|
||||
package http
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"backupx/server/internal/apperror"
|
||||
"backupx/server/internal/backup"
|
||||
"backupx/server/internal/service"
|
||||
"backupx/server/pkg/response"
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
type BackupRecordHandler struct {
|
||||
service *service.BackupRecordService
|
||||
}
|
||||
|
||||
func NewBackupRecordHandler(recordService *service.BackupRecordService) *BackupRecordHandler {
|
||||
return &BackupRecordHandler{service: recordService}
|
||||
}
|
||||
|
||||
func (h *BackupRecordHandler) List(c *gin.Context) {
|
||||
filter, err := buildRecordFilter(c)
|
||||
if err != nil {
|
||||
response.Error(c, err)
|
||||
return
|
||||
}
|
||||
items, err := h.service.List(c.Request.Context(), filter)
|
||||
if err != nil {
|
||||
response.Error(c, err)
|
||||
return
|
||||
}
|
||||
response.Success(c, items)
|
||||
}
|
||||
|
||||
func (h *BackupRecordHandler) Get(c *gin.Context) {
|
||||
id, ok := parseUintParam(c, "id")
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
item, err := h.service.Get(c.Request.Context(), id)
|
||||
if err != nil {
|
||||
response.Error(c, err)
|
||||
return
|
||||
}
|
||||
response.Success(c, item)
|
||||
}
|
||||
|
||||
func (h *BackupRecordHandler) StreamLogs(c *gin.Context) {
|
||||
id, ok := parseUintParam(c, "id")
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
detail, err := h.service.Get(c.Request.Context(), id)
|
||||
if err != nil {
|
||||
response.Error(c, err)
|
||||
return
|
||||
}
|
||||
events := detail.LogEvents
|
||||
completed := detail.Status != "running"
|
||||
channel, cancel, err := h.service.SubscribeLogs(c.Request.Context(), id, 64)
|
||||
if err != nil {
|
||||
response.Error(c, err)
|
||||
return
|
||||
}
|
||||
defer cancel()
|
||||
c.Writer.Header().Set("Content-Type", "text/event-stream")
|
||||
c.Writer.Header().Set("Cache-Control", "no-cache")
|
||||
c.Writer.Header().Set("Connection", "keep-alive")
|
||||
flusher, ok := c.Writer.(interface{ Flush() })
|
||||
if !ok {
|
||||
response.Error(c, apperror.Internal("BACKUP_RECORD_STREAM_UNSUPPORTED", "当前连接不支持日志流", nil))
|
||||
return
|
||||
}
|
||||
for _, event := range events {
|
||||
if err := writeSSEEvent(c.Writer, event); err != nil {
|
||||
return
|
||||
}
|
||||
flusher.Flush()
|
||||
}
|
||||
if completed {
|
||||
return
|
||||
}
|
||||
for {
|
||||
select {
|
||||
case <-c.Request.Context().Done():
|
||||
return
|
||||
case event, ok := <-channel:
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
if err := writeSSEEvent(c.Writer, event); err != nil {
|
||||
return
|
||||
}
|
||||
flusher.Flush()
|
||||
if event.Completed {
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (h *BackupRecordHandler) Download(c *gin.Context) {
|
||||
id, ok := parseUintParam(c, "id")
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
result, err := h.service.Download(c.Request.Context(), id)
|
||||
if err != nil {
|
||||
response.Error(c, err)
|
||||
return
|
||||
}
|
||||
defer result.Reader.Close()
|
||||
c.Header("Content-Disposition", fmt.Sprintf("attachment; filename=%q", result.FileName))
|
||||
c.Header("Content-Type", "application/octet-stream")
|
||||
_, _ = io.Copy(c.Writer, result.Reader)
|
||||
}
|
||||
|
||||
func (h *BackupRecordHandler) Restore(c *gin.Context) {
|
||||
id, ok := parseUintParam(c, "id")
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
if err := h.service.Restore(c.Request.Context(), id); err != nil {
|
||||
response.Error(c, err)
|
||||
return
|
||||
}
|
||||
response.Success(c, gin.H{"restored": true})
|
||||
}
|
||||
|
||||
func (h *BackupRecordHandler) Delete(c *gin.Context) {
|
||||
id, ok := parseUintParam(c, "id")
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
if err := h.service.Delete(c.Request.Context(), id); err != nil {
|
||||
response.Error(c, err)
|
||||
return
|
||||
}
|
||||
response.Success(c, gin.H{"deleted": true})
|
||||
}
|
||||
|
||||
func buildRecordFilter(c *gin.Context) (service.BackupRecordListInput, error) {
|
||||
var filter service.BackupRecordListInput
|
||||
if taskIDValue := strings.TrimSpace(c.Query("taskId")); taskIDValue != "" {
|
||||
parsed, ok := parseUintString(taskIDValue)
|
||||
if !ok {
|
||||
return filter, apperror.BadRequest("BACKUP_RECORD_FILTER_INVALID", "taskId 不合法", nil)
|
||||
}
|
||||
filter.TaskID = &parsed
|
||||
}
|
||||
filter.Status = strings.TrimSpace(c.Query("status"))
|
||||
if dateFrom := strings.TrimSpace(c.Query("dateFrom")); dateFrom != "" {
|
||||
parsed, err := time.Parse(time.RFC3339, dateFrom)
|
||||
if err != nil {
|
||||
return filter, apperror.BadRequest("BACKUP_RECORD_FILTER_INVALID", "dateFrom 必须为 RFC3339 时间格式", err)
|
||||
}
|
||||
filter.DateFrom = &parsed
|
||||
}
|
||||
if dateTo := strings.TrimSpace(c.Query("dateTo")); dateTo != "" {
|
||||
parsed, err := time.Parse(time.RFC3339, dateTo)
|
||||
if err != nil {
|
||||
return filter, apperror.BadRequest("BACKUP_RECORD_FILTER_INVALID", "dateTo 必须为 RFC3339 时间格式", err)
|
||||
}
|
||||
filter.DateTo = &parsed
|
||||
}
|
||||
return filter, nil
|
||||
}
|
||||
|
||||
func writeSSEEvent(writer io.Writer, event backup.LogEvent) error {
|
||||
payload, err := json.Marshal(event)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
_, err = fmt.Fprintf(writer, "event: log\ndata: %s\n\n", payload)
|
||||
return err
|
||||
}
|
||||
|
||||
func parseUintString(value string) (uint, bool) {
|
||||
parsed, err := strconv.ParseUint(strings.TrimSpace(value), 10, 64)
|
||||
if err != nil {
|
||||
return 0, false
|
||||
}
|
||||
return uint(parsed), true
|
||||
}
|
||||
28
server/internal/http/backup_run_handler.go
Normal file
28
server/internal/http/backup_run_handler.go
Normal file
@@ -0,0 +1,28 @@
|
||||
package http
|
||||
|
||||
import (
|
||||
"backupx/server/internal/service"
|
||||
"backupx/server/pkg/response"
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
type BackupRunHandler struct {
|
||||
service *service.BackupExecutionService
|
||||
}
|
||||
|
||||
func NewBackupRunHandler(executionService *service.BackupExecutionService) *BackupRunHandler {
|
||||
return &BackupRunHandler{service: executionService}
|
||||
}
|
||||
|
||||
func (h *BackupRunHandler) Run(c *gin.Context) {
|
||||
id, ok := parseUintParam(c, "id")
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
record, err := h.service.RunTaskByID(c.Request.Context(), id)
|
||||
if err != nil {
|
||||
response.Error(c, err)
|
||||
return
|
||||
}
|
||||
response.Success(c, record)
|
||||
}
|
||||
109
server/internal/http/backup_task_handler.go
Normal file
109
server/internal/http/backup_task_handler.go
Normal file
@@ -0,0 +1,109 @@
|
||||
package http
|
||||
|
||||
import (
|
||||
"backupx/server/internal/apperror"
|
||||
"backupx/server/internal/service"
|
||||
"backupx/server/pkg/response"
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
type BackupTaskHandler struct {
|
||||
service *service.BackupTaskService
|
||||
}
|
||||
|
||||
func NewBackupTaskHandler(taskService *service.BackupTaskService) *BackupTaskHandler {
|
||||
return &BackupTaskHandler{service: taskService}
|
||||
}
|
||||
|
||||
func (h *BackupTaskHandler) List(c *gin.Context) {
|
||||
items, err := h.service.List(c.Request.Context())
|
||||
if err != nil {
|
||||
response.Error(c, err)
|
||||
return
|
||||
}
|
||||
response.Success(c, items)
|
||||
}
|
||||
|
||||
func (h *BackupTaskHandler) Get(c *gin.Context) {
|
||||
id, ok := parseUintParam(c, "id")
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
item, err := h.service.Get(c.Request.Context(), id)
|
||||
if err != nil {
|
||||
response.Error(c, err)
|
||||
return
|
||||
}
|
||||
response.Success(c, item)
|
||||
}
|
||||
|
||||
func (h *BackupTaskHandler) Create(c *gin.Context) {
|
||||
var input service.BackupTaskUpsertInput
|
||||
if err := c.ShouldBindJSON(&input); err != nil {
|
||||
response.Error(c, apperror.BadRequest("BACKUP_TASK_INVALID", "备份任务参数不合法", err))
|
||||
return
|
||||
}
|
||||
item, err := h.service.Create(c.Request.Context(), input)
|
||||
if err != nil {
|
||||
response.Error(c, err)
|
||||
return
|
||||
}
|
||||
response.Success(c, item)
|
||||
}
|
||||
|
||||
func (h *BackupTaskHandler) Update(c *gin.Context) {
|
||||
id, ok := parseUintParam(c, "id")
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
var input service.BackupTaskUpsertInput
|
||||
if err := c.ShouldBindJSON(&input); err != nil {
|
||||
response.Error(c, apperror.BadRequest("BACKUP_TASK_INVALID", "备份任务参数不合法", err))
|
||||
return
|
||||
}
|
||||
item, err := h.service.Update(c.Request.Context(), id, input)
|
||||
if err != nil {
|
||||
response.Error(c, err)
|
||||
return
|
||||
}
|
||||
response.Success(c, item)
|
||||
}
|
||||
|
||||
func (h *BackupTaskHandler) Delete(c *gin.Context) {
|
||||
id, ok := parseUintParam(c, "id")
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
if err := h.service.Delete(c.Request.Context(), id); err != nil {
|
||||
response.Error(c, err)
|
||||
return
|
||||
}
|
||||
response.Success(c, gin.H{"deleted": true})
|
||||
}
|
||||
|
||||
func (h *BackupTaskHandler) Toggle(c *gin.Context) {
|
||||
id, ok := parseUintParam(c, "id")
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
var input service.BackupTaskToggleInput
|
||||
if err := c.ShouldBindJSON(&input); err != nil && err.Error() != "EOF" {
|
||||
response.Error(c, apperror.BadRequest("BACKUP_TASK_TOGGLE_INVALID", "备份任务启停参数不合法", err))
|
||||
return
|
||||
}
|
||||
current, err := h.service.Get(c.Request.Context(), id)
|
||||
if err != nil {
|
||||
response.Error(c, err)
|
||||
return
|
||||
}
|
||||
enabled := !current.Enabled
|
||||
if input.Enabled != nil {
|
||||
enabled = *input.Enabled
|
||||
}
|
||||
item, err := h.service.Toggle(c.Request.Context(), id, enabled)
|
||||
if err != nil {
|
||||
response.Error(c, err)
|
||||
return
|
||||
}
|
||||
response.Success(c, item)
|
||||
}
|
||||
3
server/internal/http/context.go
Normal file
3
server/internal/http/context.go
Normal file
@@ -0,0 +1,3 @@
|
||||
package http
|
||||
|
||||
const contextUserSubjectKey = "userSubject"
|
||||
46
server/internal/http/dashboard_handler.go
Normal file
46
server/internal/http/dashboard_handler.go
Normal file
@@ -0,0 +1,46 @@
|
||||
package http
|
||||
|
||||
import (
|
||||
"strconv"
|
||||
"strings"
|
||||
|
||||
"backupx/server/internal/apperror"
|
||||
"backupx/server/internal/service"
|
||||
"backupx/server/pkg/response"
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
type DashboardHandler struct {
|
||||
service *service.DashboardService
|
||||
}
|
||||
|
||||
func NewDashboardHandler(dashboardService *service.DashboardService) *DashboardHandler {
|
||||
return &DashboardHandler{service: dashboardService}
|
||||
}
|
||||
|
||||
func (h *DashboardHandler) Stats(c *gin.Context) {
|
||||
payload, err := h.service.Stats(c.Request.Context())
|
||||
if err != nil {
|
||||
response.Error(c, err)
|
||||
return
|
||||
}
|
||||
response.Success(c, payload)
|
||||
}
|
||||
|
||||
func (h *DashboardHandler) Timeline(c *gin.Context) {
|
||||
days := 30
|
||||
if value := strings.TrimSpace(c.Query("days")); value != "" {
|
||||
parsed, err := strconv.Atoi(value)
|
||||
if err != nil {
|
||||
response.Error(c, apperror.BadRequest("DASHBOARD_TIMELINE_INVALID", "days 必须为整数", err))
|
||||
return
|
||||
}
|
||||
days = parsed
|
||||
}
|
||||
payload, err := h.service.Timeline(c.Request.Context(), days)
|
||||
if err != nil {
|
||||
response.Error(c, err)
|
||||
return
|
||||
}
|
||||
response.Success(c, payload)
|
||||
}
|
||||
57
server/internal/http/middleware.go
Normal file
57
server/internal/http/middleware.go
Normal file
@@ -0,0 +1,57 @@
|
||||
package http
|
||||
|
||||
import (
|
||||
stdhttp "net/http"
|
||||
"strings"
|
||||
|
||||
"backupx/server/internal/apperror"
|
||||
"backupx/server/internal/security"
|
||||
"backupx/server/pkg/response"
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
// CORSMiddleware handles Cross-Origin Resource Sharing for the API.
|
||||
func CORSMiddleware() gin.HandlerFunc {
|
||||
return func(c *gin.Context) {
|
||||
c.Header("Access-Control-Allow-Origin", "*")
|
||||
c.Header("Access-Control-Allow-Methods", "GET, POST, PUT, DELETE, OPTIONS")
|
||||
c.Header("Access-Control-Allow-Headers", "Origin, Content-Type, Accept, Authorization")
|
||||
c.Header("Access-Control-Max-Age", "86400")
|
||||
|
||||
if c.Request.Method == stdhttp.MethodOptions {
|
||||
c.AbortWithStatus(stdhttp.StatusNoContent)
|
||||
return
|
||||
}
|
||||
c.Next()
|
||||
}
|
||||
}
|
||||
|
||||
func AuthMiddleware(jwtManager *security.JWTManager) gin.HandlerFunc {
|
||||
return func(c *gin.Context) {
|
||||
header := strings.TrimSpace(c.GetHeader("Authorization"))
|
||||
if !strings.HasPrefix(header, "Bearer ") {
|
||||
response.Error(c, apperror.Unauthorized("AUTH_REQUIRED", "请先登录", nil))
|
||||
c.Abort()
|
||||
return
|
||||
}
|
||||
|
||||
tokenString := strings.TrimSpace(strings.TrimPrefix(header, "Bearer "))
|
||||
claims, err := jwtManager.Parse(tokenString)
|
||||
if err != nil {
|
||||
response.Error(c, apperror.Unauthorized("AUTH_INVALID_TOKEN", "登录状态已失效,请重新登录", err))
|
||||
c.Abort()
|
||||
return
|
||||
}
|
||||
|
||||
c.Set(contextUserSubjectKey, claims.Subject)
|
||||
c.Next()
|
||||
}
|
||||
}
|
||||
|
||||
func ClientKey(c *gin.Context) string {
|
||||
ip := strings.TrimSpace(c.ClientIP())
|
||||
if ip == "" {
|
||||
return "unknown"
|
||||
}
|
||||
return ip
|
||||
}
|
||||
101
server/internal/http/node_handler.go
Normal file
101
server/internal/http/node_handler.go
Normal file
@@ -0,0 +1,101 @@
|
||||
package http
|
||||
|
||||
import (
|
||||
stdhttp "net/http"
|
||||
"strconv"
|
||||
|
||||
"backupx/server/internal/service"
|
||||
"backupx/server/pkg/response"
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
type NodeHandler struct {
|
||||
service *service.NodeService
|
||||
}
|
||||
|
||||
func NewNodeHandler(service *service.NodeService) *NodeHandler {
|
||||
return &NodeHandler{service: service}
|
||||
}
|
||||
|
||||
func (h *NodeHandler) List(c *gin.Context) {
|
||||
items, err := h.service.List(c.Request.Context())
|
||||
if err != nil {
|
||||
response.Error(c, err)
|
||||
return
|
||||
}
|
||||
response.Success(c, items)
|
||||
}
|
||||
|
||||
func (h *NodeHandler) Get(c *gin.Context) {
|
||||
id, err := strconv.ParseUint(c.Param("id"), 10, 32)
|
||||
if err != nil {
|
||||
response.Error(c, err)
|
||||
return
|
||||
}
|
||||
item, err := h.service.Get(c.Request.Context(), uint(id))
|
||||
if err != nil {
|
||||
response.Error(c, err)
|
||||
return
|
||||
}
|
||||
response.Success(c, item)
|
||||
}
|
||||
|
||||
func (h *NodeHandler) Create(c *gin.Context) {
|
||||
var input service.NodeCreateInput
|
||||
if err := c.ShouldBindJSON(&input); err != nil {
|
||||
c.JSON(stdhttp.StatusBadRequest, gin.H{"code": "INVALID_INPUT", "message": err.Error()})
|
||||
return
|
||||
}
|
||||
token, err := h.service.Create(c.Request.Context(), input)
|
||||
if err != nil {
|
||||
response.Error(c, err)
|
||||
return
|
||||
}
|
||||
response.Success(c, gin.H{"token": token})
|
||||
}
|
||||
|
||||
func (h *NodeHandler) Delete(c *gin.Context) {
|
||||
id, err := strconv.ParseUint(c.Param("id"), 10, 32)
|
||||
if err != nil {
|
||||
response.Error(c, err)
|
||||
return
|
||||
}
|
||||
if err := h.service.Delete(c.Request.Context(), uint(id)); err != nil {
|
||||
response.Error(c, err)
|
||||
return
|
||||
}
|
||||
response.Success(c, nil)
|
||||
}
|
||||
|
||||
func (h *NodeHandler) ListDirectory(c *gin.Context) {
|
||||
id, err := strconv.ParseUint(c.Param("id"), 10, 32)
|
||||
if err != nil {
|
||||
response.Error(c, err)
|
||||
return
|
||||
}
|
||||
path := c.DefaultQuery("path", "/")
|
||||
entries, err := h.service.ListDirectory(c.Request.Context(), uint(id), path)
|
||||
if err != nil {
|
||||
response.Error(c, err)
|
||||
return
|
||||
}
|
||||
response.Success(c, entries)
|
||||
}
|
||||
|
||||
func (h *NodeHandler) Heartbeat(c *gin.Context) {
|
||||
var input struct {
|
||||
Token string `json:"token" binding:"required"`
|
||||
Hostname string `json:"hostname"`
|
||||
IPAddress string `json:"ipAddress"`
|
||||
AgentVersion string `json:"agentVersion"`
|
||||
}
|
||||
if err := c.ShouldBindJSON(&input); err != nil {
|
||||
c.JSON(stdhttp.StatusBadRequest, gin.H{"code": "INVALID_INPUT", "message": err.Error()})
|
||||
return
|
||||
}
|
||||
if err := h.service.Heartbeat(c.Request.Context(), input.Token, input.Hostname, input.IPAddress, input.AgentVersion); err != nil {
|
||||
response.Error(c, err)
|
||||
return
|
||||
}
|
||||
response.Success(c, gin.H{"status": "ok"})
|
||||
}
|
||||
107
server/internal/http/notification_handler.go
Normal file
107
server/internal/http/notification_handler.go
Normal file
@@ -0,0 +1,107 @@
|
||||
package http
|
||||
|
||||
import (
|
||||
"backupx/server/internal/apperror"
|
||||
"backupx/server/internal/service"
|
||||
"backupx/server/pkg/response"
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
type NotificationHandler struct {
|
||||
service *service.NotificationService
|
||||
}
|
||||
|
||||
func NewNotificationHandler(notificationService *service.NotificationService) *NotificationHandler {
|
||||
return &NotificationHandler{service: notificationService}
|
||||
}
|
||||
|
||||
func (h *NotificationHandler) List(c *gin.Context) {
|
||||
items, err := h.service.List(c.Request.Context())
|
||||
if err != nil {
|
||||
response.Error(c, err)
|
||||
return
|
||||
}
|
||||
response.Success(c, items)
|
||||
}
|
||||
|
||||
func (h *NotificationHandler) Get(c *gin.Context) {
|
||||
id, ok := parseUintParam(c, "id")
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
item, err := h.service.Get(c.Request.Context(), id)
|
||||
if err != nil {
|
||||
response.Error(c, err)
|
||||
return
|
||||
}
|
||||
response.Success(c, item)
|
||||
}
|
||||
|
||||
func (h *NotificationHandler) Create(c *gin.Context) {
|
||||
var input service.NotificationUpsertInput
|
||||
if err := c.ShouldBindJSON(&input); err != nil {
|
||||
response.Error(c, apperror.BadRequest("NOTIFICATION_INVALID", "通知配置参数不合法", err))
|
||||
return
|
||||
}
|
||||
item, err := h.service.Create(c.Request.Context(), input)
|
||||
if err != nil {
|
||||
response.Error(c, err)
|
||||
return
|
||||
}
|
||||
response.Success(c, item)
|
||||
}
|
||||
|
||||
func (h *NotificationHandler) Update(c *gin.Context) {
|
||||
id, ok := parseUintParam(c, "id")
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
var input service.NotificationUpsertInput
|
||||
if err := c.ShouldBindJSON(&input); err != nil {
|
||||
response.Error(c, apperror.BadRequest("NOTIFICATION_INVALID", "通知配置参数不合法", err))
|
||||
return
|
||||
}
|
||||
item, err := h.service.Update(c.Request.Context(), id, input)
|
||||
if err != nil {
|
||||
response.Error(c, err)
|
||||
return
|
||||
}
|
||||
response.Success(c, item)
|
||||
}
|
||||
|
||||
func (h *NotificationHandler) Delete(c *gin.Context) {
|
||||
id, ok := parseUintParam(c, "id")
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
if err := h.service.Delete(c.Request.Context(), id); err != nil {
|
||||
response.Error(c, err)
|
||||
return
|
||||
}
|
||||
response.Success(c, gin.H{"deleted": true})
|
||||
}
|
||||
|
||||
func (h *NotificationHandler) Test(c *gin.Context) {
|
||||
var input service.NotificationUpsertInput
|
||||
if err := c.ShouldBindJSON(&input); err != nil {
|
||||
response.Error(c, apperror.BadRequest("NOTIFICATION_INVALID", "通知配置参数不合法", err))
|
||||
return
|
||||
}
|
||||
if err := h.service.Test(c.Request.Context(), input); err != nil {
|
||||
response.Error(c, err)
|
||||
return
|
||||
}
|
||||
response.Success(c, gin.H{"success": true})
|
||||
}
|
||||
|
||||
func (h *NotificationHandler) TestSaved(c *gin.Context) {
|
||||
id, ok := parseUintParam(c, "id")
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
if err := h.service.TestSaved(c.Request.Context(), id); err != nil {
|
||||
response.Error(c, err)
|
||||
return
|
||||
}
|
||||
response.Success(c, gin.H{"success": true})
|
||||
}
|
||||
152
server/internal/http/router.go
Normal file
152
server/internal/http/router.go
Normal file
@@ -0,0 +1,152 @@
|
||||
package http
|
||||
|
||||
import (
|
||||
"errors"
|
||||
stdhttp "net/http"
|
||||
|
||||
"backupx/server/internal/apperror"
|
||||
"backupx/server/internal/config"
|
||||
"backupx/server/internal/repository"
|
||||
"backupx/server/internal/security"
|
||||
"backupx/server/internal/service"
|
||||
"backupx/server/pkg/response"
|
||||
"github.com/gin-gonic/gin"
|
||||
"go.uber.org/zap"
|
||||
)
|
||||
|
||||
type RouterDependencies struct {
|
||||
Config config.Config
|
||||
Version string
|
||||
Logger *zap.Logger
|
||||
AuthService *service.AuthService
|
||||
SystemService *service.SystemService
|
||||
StorageTargetService *service.StorageTargetService
|
||||
BackupTaskService *service.BackupTaskService
|
||||
BackupExecutionService *service.BackupExecutionService
|
||||
BackupRecordService *service.BackupRecordService
|
||||
NotificationService *service.NotificationService
|
||||
DashboardService *service.DashboardService
|
||||
SettingsService *service.SettingsService
|
||||
NodeService *service.NodeService
|
||||
JWTManager *security.JWTManager
|
||||
UserRepository repository.UserRepository
|
||||
SystemConfigRepo repository.SystemConfigRepository
|
||||
}
|
||||
|
||||
func NewRouter(deps RouterDependencies) *gin.Engine {
|
||||
gin.SetMode(deps.Config.Server.Mode)
|
||||
engine := gin.New()
|
||||
engine.Use(gin.Recovery())
|
||||
engine.Use(CORSMiddleware())
|
||||
engine.Use(requestLogger(deps.Logger))
|
||||
|
||||
authHandler := NewAuthHandler(deps.AuthService)
|
||||
systemHandler := NewSystemHandler(deps.SystemService)
|
||||
storageTargetHandler := NewStorageTargetHandler(deps.StorageTargetService)
|
||||
backupTaskHandler := NewBackupTaskHandler(deps.BackupTaskService)
|
||||
backupRunHandler := NewBackupRunHandler(deps.BackupExecutionService)
|
||||
backupRecordHandler := NewBackupRecordHandler(deps.BackupRecordService)
|
||||
notificationHandler := NewNotificationHandler(deps.NotificationService)
|
||||
dashboardHandler := NewDashboardHandler(deps.DashboardService)
|
||||
settingsHandler := NewSettingsHandler(deps.SettingsService)
|
||||
|
||||
api := engine.Group("/api")
|
||||
{
|
||||
auth := api.Group("/auth")
|
||||
{
|
||||
auth.GET("/setup/status", authHandler.SetupStatus)
|
||||
auth.POST("/setup", authHandler.Setup)
|
||||
auth.POST("/login", authHandler.Login)
|
||||
auth.POST("/logout", AuthMiddleware(deps.JWTManager), authHandler.Logout)
|
||||
auth.GET("/profile", AuthMiddleware(deps.JWTManager), authHandler.Profile)
|
||||
auth.PUT("/password", AuthMiddleware(deps.JWTManager), authHandler.ChangePassword)
|
||||
}
|
||||
|
||||
system := api.Group("/system")
|
||||
system.Use(AuthMiddleware(deps.JWTManager))
|
||||
system.GET("/info", systemHandler.Info)
|
||||
|
||||
storageTargets := api.Group("/storage-targets")
|
||||
storageTargets.Use(AuthMiddleware(deps.JWTManager))
|
||||
storageTargets.GET("", storageTargetHandler.List)
|
||||
storageTargets.GET("/:id", storageTargetHandler.Get)
|
||||
storageTargets.POST("", storageTargetHandler.Create)
|
||||
storageTargets.PUT("/:id", storageTargetHandler.Update)
|
||||
storageTargets.DELETE("/:id", storageTargetHandler.Delete)
|
||||
storageTargets.POST("/test", storageTargetHandler.TestConnection)
|
||||
storageTargets.POST("/:id/test", storageTargetHandler.TestSavedConnection)
|
||||
storageTargets.GET("/:id/usage", storageTargetHandler.GetUsage)
|
||||
storageTargets.POST("/google-drive/auth-url", storageTargetHandler.StartGoogleDriveOAuth)
|
||||
storageTargets.POST("/google-drive/complete", storageTargetHandler.CompleteGoogleDriveOAuth)
|
||||
storageTargets.GET("/google-drive/callback", storageTargetHandler.HandleGoogleDriveCallback)
|
||||
storageTargets.GET("/:id/google-drive/profile", storageTargetHandler.GoogleDriveProfile)
|
||||
|
||||
backupTasks := api.Group("/backup/tasks")
|
||||
backupTasks.Use(AuthMiddleware(deps.JWTManager))
|
||||
backupTasks.GET("", backupTaskHandler.List)
|
||||
backupTasks.GET("/:id", backupTaskHandler.Get)
|
||||
backupTasks.POST("", backupTaskHandler.Create)
|
||||
backupTasks.PUT("/:id", backupTaskHandler.Update)
|
||||
backupTasks.DELETE("/:id", backupTaskHandler.Delete)
|
||||
backupTasks.PUT("/:id/toggle", backupTaskHandler.Toggle)
|
||||
backupTasks.POST("/:id/run", backupRunHandler.Run)
|
||||
|
||||
backupRecords := api.Group("/backup/records")
|
||||
backupRecords.Use(AuthMiddleware(deps.JWTManager))
|
||||
backupRecords.GET("", backupRecordHandler.List)
|
||||
backupRecords.GET("/:id", backupRecordHandler.Get)
|
||||
backupRecords.GET("/:id/logs/stream", backupRecordHandler.StreamLogs)
|
||||
backupRecords.GET("/:id/download", backupRecordHandler.Download)
|
||||
backupRecords.POST("/:id/restore", backupRecordHandler.Restore)
|
||||
backupRecords.DELETE("/:id", backupRecordHandler.Delete)
|
||||
dashboard := api.Group("/dashboard")
|
||||
dashboard.Use(AuthMiddleware(deps.JWTManager))
|
||||
dashboard.GET("/stats", dashboardHandler.Stats)
|
||||
dashboard.GET("/timeline", dashboardHandler.Timeline)
|
||||
|
||||
notifications := api.Group("/notifications")
|
||||
notifications.Use(AuthMiddleware(deps.JWTManager))
|
||||
notifications.GET("", notificationHandler.List)
|
||||
notifications.GET("/:id", notificationHandler.Get)
|
||||
notifications.POST("", notificationHandler.Create)
|
||||
notifications.PUT("/:id", notificationHandler.Update)
|
||||
notifications.DELETE("/:id", notificationHandler.Delete)
|
||||
notifications.POST("/test", notificationHandler.Test)
|
||||
notifications.POST("/:id/test", notificationHandler.TestSaved)
|
||||
|
||||
settings := api.Group("/settings")
|
||||
settings.Use(AuthMiddleware(deps.JWTManager))
|
||||
settings.GET("", settingsHandler.Get)
|
||||
settings.PUT("", settingsHandler.Update)
|
||||
|
||||
nodeHandler := NewNodeHandler(deps.NodeService)
|
||||
nodes := api.Group("/nodes")
|
||||
nodes.Use(AuthMiddleware(deps.JWTManager))
|
||||
nodes.GET("", nodeHandler.List)
|
||||
nodes.GET("/:id", nodeHandler.Get)
|
||||
nodes.POST("", nodeHandler.Create)
|
||||
nodes.DELETE("/:id", nodeHandler.Delete)
|
||||
nodes.GET("/:id/fs/list", nodeHandler.ListDirectory)
|
||||
|
||||
// Agent heartbeat (public, token-authenticated)
|
||||
api.POST("/agent/heartbeat", nodeHandler.Heartbeat)
|
||||
}
|
||||
|
||||
engine.NoRoute(func(c *gin.Context) {
|
||||
response.Error(c, apperror.New(stdhttp.StatusNotFound, "NOT_FOUND", "接口不存在", errors.New("route not found")))
|
||||
})
|
||||
|
||||
return engine
|
||||
}
|
||||
|
||||
func requestLogger(logger *zap.Logger) gin.HandlerFunc {
|
||||
return func(c *gin.Context) {
|
||||
c.Next()
|
||||
logger.Info("http request",
|
||||
zap.String("method", c.Request.Method),
|
||||
zap.String("path", c.Request.URL.Path),
|
||||
zap.Int("status", c.Writer.Status()),
|
||||
zap.String("client_ip", c.ClientIP()),
|
||||
)
|
||||
}
|
||||
}
|
||||
94
server/internal/http/router_test.go
Normal file
94
server/internal/http/router_test.go
Normal file
@@ -0,0 +1,94 @@
|
||||
package http
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"path/filepath"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"backupx/server/internal/config"
|
||||
"backupx/server/internal/database"
|
||||
"backupx/server/internal/logger"
|
||||
"backupx/server/internal/repository"
|
||||
"backupx/server/internal/security"
|
||||
"backupx/server/internal/service"
|
||||
)
|
||||
|
||||
func TestSetupLoginAndProfileFlow(t *testing.T) {
|
||||
tempDir := t.TempDir()
|
||||
cfg := config.Config{
|
||||
Server: config.ServerConfig{Host: "127.0.0.1", Port: 8340, Mode: "test"},
|
||||
Database: config.DatabaseConfig{Path: filepath.Join(tempDir, "backupx.db")},
|
||||
Security: config.SecurityConfig{JWTExpire: "24h"},
|
||||
Log: config.LogConfig{Level: "error"},
|
||||
}
|
||||
|
||||
log, err := logger.New(cfg.Log)
|
||||
if err != nil {
|
||||
t.Fatalf("logger.New error: %v", err)
|
||||
}
|
||||
db, err := database.Open(cfg.Database, log)
|
||||
if err != nil {
|
||||
t.Fatalf("database.Open error: %v", err)
|
||||
}
|
||||
|
||||
userRepo := repository.NewUserRepository(db)
|
||||
systemConfigRepo := repository.NewSystemConfigRepository(db)
|
||||
resolved, err := service.ResolveSecurity(context.Background(), cfg.Security, systemConfigRepo)
|
||||
if err != nil {
|
||||
t.Fatalf("ResolveSecurity error: %v", err)
|
||||
}
|
||||
jwtManager := security.NewJWTManager(resolved.JWTSecret, time.Hour)
|
||||
authService := service.NewAuthService(userRepo, systemConfigRepo, jwtManager, security.NewLoginRateLimiter(5, time.Minute))
|
||||
systemService := service.NewSystemService(cfg, "test", time.Now().UTC())
|
||||
|
||||
router := NewRouter(RouterDependencies{
|
||||
Config: cfg,
|
||||
Version: "test",
|
||||
Logger: log,
|
||||
AuthService: authService,
|
||||
SystemService: systemService,
|
||||
JWTManager: jwtManager,
|
||||
UserRepository: userRepo,
|
||||
SystemConfigRepo: systemConfigRepo,
|
||||
})
|
||||
|
||||
setupBody, _ := json.Marshal(map[string]string{
|
||||
"username": "admin",
|
||||
"password": "password-123",
|
||||
"displayName": "Admin",
|
||||
})
|
||||
setupRequest := httptest.NewRequest(http.MethodPost, "/api/auth/setup", bytes.NewBuffer(setupBody))
|
||||
setupRequest.Header.Set("Content-Type", "application/json")
|
||||
setupRecorder := httptest.NewRecorder()
|
||||
router.ServeHTTP(setupRecorder, setupRequest)
|
||||
|
||||
if setupRecorder.Code != http.StatusOK {
|
||||
t.Fatalf("expected setup 200, got %d", setupRecorder.Code)
|
||||
}
|
||||
|
||||
var setupResponse struct {
|
||||
Data struct {
|
||||
Token string `json:"token"`
|
||||
} `json:"data"`
|
||||
}
|
||||
if err := json.Unmarshal(setupRecorder.Body.Bytes(), &setupResponse); err != nil {
|
||||
t.Fatalf("unmarshal setup response: %v", err)
|
||||
}
|
||||
if setupResponse.Data.Token == "" {
|
||||
t.Fatalf("expected token in setup response")
|
||||
}
|
||||
|
||||
profileRequest := httptest.NewRequest(http.MethodGet, "/api/auth/profile", nil)
|
||||
profileRequest.Header.Set("Authorization", "Bearer "+setupResponse.Data.Token)
|
||||
profileRecorder := httptest.NewRecorder()
|
||||
router.ServeHTTP(profileRecorder, profileRequest)
|
||||
|
||||
if profileRecorder.Code != http.StatusOK {
|
||||
t.Fatalf("expected profile 200, got %d", profileRecorder.Code)
|
||||
}
|
||||
}
|
||||
39
server/internal/http/settings_handler.go
Normal file
39
server/internal/http/settings_handler.go
Normal file
@@ -0,0 +1,39 @@
|
||||
package http
|
||||
|
||||
import (
|
||||
"backupx/server/internal/apperror"
|
||||
"backupx/server/internal/service"
|
||||
"backupx/server/pkg/response"
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
type SettingsHandler struct {
|
||||
settingsService *service.SettingsService
|
||||
}
|
||||
|
||||
func NewSettingsHandler(settingsService *service.SettingsService) *SettingsHandler {
|
||||
return &SettingsHandler{settingsService: settingsService}
|
||||
}
|
||||
|
||||
func (h *SettingsHandler) Get(c *gin.Context) {
|
||||
settings, err := h.settingsService.GetAll(c.Request.Context())
|
||||
if err != nil {
|
||||
response.Error(c, err)
|
||||
return
|
||||
}
|
||||
response.Success(c, settings)
|
||||
}
|
||||
|
||||
func (h *SettingsHandler) Update(c *gin.Context) {
|
||||
var input map[string]string
|
||||
if err := c.ShouldBindJSON(&input); err != nil {
|
||||
response.Error(c, apperror.BadRequest("SETTINGS_INVALID", "设置参数不合法", err))
|
||||
return
|
||||
}
|
||||
settings, err := h.settingsService.Update(c.Request.Context(), input)
|
||||
if err != nil {
|
||||
response.Error(c, err)
|
||||
return
|
||||
}
|
||||
response.Success(c, settings)
|
||||
}
|
||||
244
server/internal/http/storage_target_handler.go
Normal file
244
server/internal/http/storage_target_handler.go
Normal file
@@ -0,0 +1,244 @@
|
||||
package http
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"strconv"
|
||||
"strings"
|
||||
|
||||
"backupx/server/internal/apperror"
|
||||
"backupx/server/internal/service"
|
||||
"backupx/server/pkg/response"
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
type StorageTargetHandler struct {
|
||||
service *service.StorageTargetService
|
||||
}
|
||||
|
||||
type storageTargetGoogleDriveAuthRequest struct {
|
||||
TargetID *uint `json:"targetId"`
|
||||
Name string `json:"name"`
|
||||
Type string `json:"type"`
|
||||
Description string `json:"description"`
|
||||
Enabled bool `json:"enabled"`
|
||||
Config map[string]any `json:"config"`
|
||||
ClientID string `json:"clientId"`
|
||||
ClientSecret string `json:"clientSecret"`
|
||||
FolderID string `json:"folderId"`
|
||||
}
|
||||
|
||||
func NewStorageTargetHandler(service *service.StorageTargetService) *StorageTargetHandler {
|
||||
return &StorageTargetHandler{service: service}
|
||||
}
|
||||
|
||||
func (h *StorageTargetHandler) List(c *gin.Context) {
|
||||
items, err := h.service.List(c.Request.Context())
|
||||
if err != nil {
|
||||
response.Error(c, err)
|
||||
return
|
||||
}
|
||||
response.Success(c, items)
|
||||
}
|
||||
|
||||
func (h *StorageTargetHandler) Get(c *gin.Context) {
|
||||
id, ok := parseUintParam(c, "id")
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
item, err := h.service.Get(c.Request.Context(), id)
|
||||
if err != nil {
|
||||
response.Error(c, err)
|
||||
return
|
||||
}
|
||||
response.Success(c, item)
|
||||
}
|
||||
|
||||
func (h *StorageTargetHandler) Create(c *gin.Context) {
|
||||
var input service.StorageTargetUpsertInput
|
||||
if err := c.ShouldBindJSON(&input); err != nil {
|
||||
response.Error(c, apperror.BadRequest("STORAGE_TARGET_INVALID", "存储目标参数不合法", err))
|
||||
return
|
||||
}
|
||||
item, err := h.service.Create(c.Request.Context(), input)
|
||||
if err != nil {
|
||||
response.Error(c, err)
|
||||
return
|
||||
}
|
||||
response.Success(c, item)
|
||||
}
|
||||
|
||||
func (h *StorageTargetHandler) Update(c *gin.Context) {
|
||||
id, ok := parseUintParam(c, "id")
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
var input service.StorageTargetUpsertInput
|
||||
if err := c.ShouldBindJSON(&input); err != nil {
|
||||
response.Error(c, apperror.BadRequest("STORAGE_TARGET_INVALID", "存储目标参数不合法", err))
|
||||
return
|
||||
}
|
||||
item, err := h.service.Update(c.Request.Context(), id, input)
|
||||
if err != nil {
|
||||
response.Error(c, err)
|
||||
return
|
||||
}
|
||||
response.Success(c, item)
|
||||
}
|
||||
|
||||
func (h *StorageTargetHandler) Delete(c *gin.Context) {
|
||||
id, ok := parseUintParam(c, "id")
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
if err := h.service.Delete(c.Request.Context(), id); err != nil {
|
||||
response.Error(c, err)
|
||||
return
|
||||
}
|
||||
response.Success(c, gin.H{"deleted": true})
|
||||
}
|
||||
|
||||
func (h *StorageTargetHandler) TestConnection(c *gin.Context) {
|
||||
var payload service.StorageTargetUpsertInput
|
||||
if err := c.ShouldBindJSON(&payload); err != nil {
|
||||
response.Error(c, apperror.BadRequest("STORAGE_TARGET_TEST_INVALID", "测试连接参数不合法", err))
|
||||
return
|
||||
}
|
||||
if err := h.service.TestConnection(c.Request.Context(), service.StorageTargetTestInput{Payload: payload}); err != nil {
|
||||
response.Error(c, err)
|
||||
return
|
||||
}
|
||||
response.Success(c, gin.H{"success": true, "message": "连接成功"})
|
||||
}
|
||||
|
||||
func (h *StorageTargetHandler) TestSavedConnection(c *gin.Context) {
|
||||
id, ok := parseUintParam(c, "id")
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
if err := h.service.TestConnection(c.Request.Context(), service.StorageTargetTestInput{TargetID: &id}); err != nil {
|
||||
response.Error(c, err)
|
||||
return
|
||||
}
|
||||
response.Success(c, gin.H{"success": true, "message": "连接成功"})
|
||||
}
|
||||
|
||||
func (h *StorageTargetHandler) StartGoogleDriveOAuth(c *gin.Context) {
|
||||
var request storageTargetGoogleDriveAuthRequest
|
||||
if err := c.ShouldBindJSON(&request); err != nil {
|
||||
response.Error(c, apperror.BadRequest("STORAGE_GOOGLE_OAUTH_INVALID", "Google Drive 授权参数不合法", err))
|
||||
return
|
||||
}
|
||||
input := service.GoogleDriveAuthStartInput{
|
||||
TargetID: request.TargetID,
|
||||
Name: strings.TrimSpace(request.Name),
|
||||
Description: strings.TrimSpace(request.Description),
|
||||
Enabled: request.Enabled,
|
||||
ClientID: firstNonEmpty(asString(request.Config["clientId"]), request.ClientID),
|
||||
ClientSecret: firstNonEmpty(asString(request.Config["clientSecret"]), request.ClientSecret),
|
||||
FolderID: firstNonEmpty(asString(request.Config["folderId"]), request.FolderID),
|
||||
}
|
||||
result, err := h.service.StartGoogleDriveOAuth(c.Request.Context(), input, requestOrigin(c))
|
||||
if err != nil {
|
||||
response.Error(c, err)
|
||||
return
|
||||
}
|
||||
response.Success(c, gin.H{"authUrl": result.AuthorizationURL})
|
||||
}
|
||||
|
||||
func (h *StorageTargetHandler) CompleteGoogleDriveOAuth(c *gin.Context) {
|
||||
var input service.GoogleDriveAuthCompleteInput
|
||||
if err := c.ShouldBindJSON(&input); err != nil {
|
||||
response.Error(c, apperror.BadRequest("STORAGE_GOOGLE_OAUTH_INVALID", "Google Drive 回调参数不合法", err))
|
||||
return
|
||||
}
|
||||
item, err := h.service.CompleteGoogleDriveOAuth(c.Request.Context(), input)
|
||||
if err != nil {
|
||||
response.Error(c, err)
|
||||
return
|
||||
}
|
||||
response.Success(c, item)
|
||||
}
|
||||
|
||||
func (h *StorageTargetHandler) HandleGoogleDriveCallback(c *gin.Context) {
|
||||
if queryError := strings.TrimSpace(c.Query("error")); queryError != "" {
|
||||
response.Success(c, gin.H{"success": false, "message": queryError})
|
||||
return
|
||||
}
|
||||
input := service.GoogleDriveAuthCompleteInput{State: strings.TrimSpace(c.Query("state")), Code: strings.TrimSpace(c.Query("code"))}
|
||||
if input.State == "" || input.Code == "" {
|
||||
response.Error(c, apperror.BadRequest("STORAGE_GOOGLE_OAUTH_INVALID", "Google Drive 回调参数不合法", nil))
|
||||
return
|
||||
}
|
||||
item, err := h.service.CompleteGoogleDriveOAuth(c.Request.Context(), input)
|
||||
if err != nil {
|
||||
response.Error(c, err)
|
||||
return
|
||||
}
|
||||
response.Success(c, gin.H{"success": true, "message": "Google Drive 授权成功", "target": item})
|
||||
}
|
||||
|
||||
func (h *StorageTargetHandler) GoogleDriveProfile(c *gin.Context) {
|
||||
id, ok := parseUintParam(c, "id")
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
profile, err := h.service.GoogleDriveProfile(c.Request.Context(), id)
|
||||
if err != nil {
|
||||
response.Error(c, err)
|
||||
return
|
||||
}
|
||||
response.Success(c, profile)
|
||||
}
|
||||
|
||||
func parseUintParam(c *gin.Context, key string) (uint, bool) {
|
||||
value := strings.TrimSpace(c.Param(key))
|
||||
parsed, err := strconv.ParseUint(value, 10, 64)
|
||||
if err != nil {
|
||||
response.Error(c, apperror.BadRequest("INVALID_ID", fmt.Sprintf("参数 %s 不合法", key), err))
|
||||
return 0, false
|
||||
}
|
||||
return uint(parsed), true
|
||||
}
|
||||
|
||||
func requestOrigin(c *gin.Context) string {
|
||||
origin := strings.TrimSpace(c.GetHeader("Origin"))
|
||||
if origin != "" {
|
||||
return origin
|
||||
}
|
||||
scheme := strings.TrimSpace(c.GetHeader("X-Forwarded-Proto"))
|
||||
if scheme == "" {
|
||||
if c.Request.TLS != nil {
|
||||
scheme = "https"
|
||||
} else {
|
||||
scheme = "http"
|
||||
}
|
||||
}
|
||||
return fmt.Sprintf("%s://%s", scheme, c.Request.Host)
|
||||
}
|
||||
|
||||
func asString(value any) string {
|
||||
text, _ := value.(string)
|
||||
return strings.TrimSpace(text)
|
||||
}
|
||||
|
||||
func firstNonEmpty(values ...string) string {
|
||||
for _, value := range values {
|
||||
if strings.TrimSpace(value) != "" {
|
||||
return strings.TrimSpace(value)
|
||||
}
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
func (h *StorageTargetHandler) GetUsage(c *gin.Context) {
|
||||
id, ok := parseUintParam(c, "id")
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
usage, err := h.service.GetUsage(c.Request.Context(), id)
|
||||
if err != nil {
|
||||
response.Error(c, err)
|
||||
return
|
||||
}
|
||||
response.Success(c, usage)
|
||||
}
|
||||
19
server/internal/http/system_handler.go
Normal file
19
server/internal/http/system_handler.go
Normal file
@@ -0,0 +1,19 @@
|
||||
package http
|
||||
|
||||
import (
|
||||
"backupx/server/internal/service"
|
||||
"backupx/server/pkg/response"
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
type SystemHandler struct {
|
||||
systemService *service.SystemService
|
||||
}
|
||||
|
||||
func NewSystemHandler(systemService *service.SystemService) *SystemHandler {
|
||||
return &SystemHandler{systemService: systemService}
|
||||
}
|
||||
|
||||
func (h *SystemHandler) Info(c *gin.Context) {
|
||||
response.Success(c, h.systemService.GetInfo(c.Request.Context()))
|
||||
}
|
||||
98
server/internal/httpapi/auth_handler.go
Normal file
98
server/internal/httpapi/auth_handler.go
Normal file
@@ -0,0 +1,98 @@
|
||||
//go:build ignore
|
||||
|
||||
package httpapi
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
|
||||
"backupx/server/internal/service"
|
||||
"backupx/server/pkg/response"
|
||||
"github.com/gin-gonic/gin"
|
||||
"go.uber.org/zap"
|
||||
)
|
||||
|
||||
type authHandler struct {
|
||||
service *service.AuthService
|
||||
logger *zap.Logger
|
||||
}
|
||||
|
||||
type setupRequest struct {
|
||||
Username string `json:"username" binding:"required,min=3,max=64"`
|
||||
Password string `json:"password" binding:"required,min=8,max=128"`
|
||||
DisplayName string `json:"displayName" binding:"required,min=1,max=128"`
|
||||
}
|
||||
|
||||
type loginRequest struct {
|
||||
Username string `json:"username" binding:"required,min=3,max=64"`
|
||||
Password string `json:"password" binding:"required,min=8,max=128"`
|
||||
}
|
||||
|
||||
func newAuthHandler(service *service.AuthService, logger *zap.Logger) *authHandler {
|
||||
return &authHandler{service: service, logger: logger}
|
||||
}
|
||||
|
||||
func (h *authHandler) registerRoutes(router gin.IRouter, protected gin.IRouter) {
|
||||
router.GET("/auth/setup/status", h.getSetupStatus)
|
||||
router.POST("/auth/setup", h.setup)
|
||||
router.POST("/auth/login", h.login)
|
||||
protected.GET("/auth/profile", h.profile)
|
||||
}
|
||||
|
||||
func (h *authHandler) getSetupStatus(c *gin.Context) {
|
||||
initialized, err := h.service.GetSetupStatus(c.Request.Context())
|
||||
if err != nil {
|
||||
writeError(c, h.logger, err)
|
||||
return
|
||||
}
|
||||
response.Success(c, gin.H{"initialized": initialized})
|
||||
}
|
||||
|
||||
func (h *authHandler) setup(c *gin.Context) {
|
||||
payload, err := bindJSON[setupRequest](c, h.logger)
|
||||
if err != nil {
|
||||
writeError(c, h.logger, err)
|
||||
return
|
||||
}
|
||||
result, err := h.service.Setup(c.Request.Context(), service.SetupInput{
|
||||
Username: payload.Username,
|
||||
Password: payload.Password,
|
||||
DisplayName: payload.DisplayName,
|
||||
})
|
||||
if err != nil {
|
||||
writeError(c, h.logger, err)
|
||||
return
|
||||
}
|
||||
c.JSON(http.StatusCreated, response.Envelope{Code: "OK", Message: "success", Data: result})
|
||||
}
|
||||
|
||||
func (h *authHandler) login(c *gin.Context) {
|
||||
payload, err := bindJSON[loginRequest](c, h.logger)
|
||||
if err != nil {
|
||||
writeError(c, h.logger, err)
|
||||
return
|
||||
}
|
||||
result, err := h.service.Login(c.Request.Context(), service.LoginInput{
|
||||
Username: payload.Username,
|
||||
Password: payload.Password,
|
||||
RemoteAddr: c.ClientIP(),
|
||||
})
|
||||
if err != nil {
|
||||
writeError(c, h.logger, err)
|
||||
return
|
||||
}
|
||||
response.Success(c, result)
|
||||
}
|
||||
|
||||
func (h *authHandler) profile(c *gin.Context) {
|
||||
userID, err := getUserID(c)
|
||||
if err != nil {
|
||||
response.Error(c, http.StatusUnauthorized, "AUTH_UNAUTHORIZED", "认证信息无效")
|
||||
return
|
||||
}
|
||||
result, err := h.service.GetCurrentUser(c.Request.Context(), userID)
|
||||
if err != nil {
|
||||
writeError(c, h.logger, err)
|
||||
return
|
||||
}
|
||||
response.Success(c, result)
|
||||
}
|
||||
23
server/internal/httpapi/context.go
Normal file
23
server/internal/httpapi/context.go
Normal file
@@ -0,0 +1,23 @@
|
||||
//go:build ignore
|
||||
|
||||
package httpapi
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
const claimsContextKey = "authClaims"
|
||||
|
||||
func getUserID(c *gin.Context) (uint, error) {
|
||||
value, ok := c.Get(claimsContextKey)
|
||||
if !ok {
|
||||
return 0, fmt.Errorf("missing auth claims")
|
||||
}
|
||||
claims, ok := value.(AuthClaims)
|
||||
if !ok {
|
||||
return 0, fmt.Errorf("invalid auth claims")
|
||||
}
|
||||
return claims.UserID, nil
|
||||
}
|
||||
92
server/internal/httpapi/middleware.go
Normal file
92
server/internal/httpapi/middleware.go
Normal file
@@ -0,0 +1,92 @@
|
||||
//go:build ignore
|
||||
|
||||
package httpapi
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"strings"
|
||||
|
||||
"backupx/server/internal/apperror"
|
||||
"backupx/server/internal/security"
|
||||
"backupx/server/pkg/response"
|
||||
"github.com/gin-gonic/gin"
|
||||
"go.uber.org/zap"
|
||||
)
|
||||
|
||||
type AuthClaims struct {
|
||||
UserID uint
|
||||
Username string
|
||||
Role string
|
||||
}
|
||||
|
||||
func Recovery(logger *zap.Logger) gin.HandlerFunc {
|
||||
return func(c *gin.Context) {
|
||||
defer func() {
|
||||
if recovered := recover(); recovered != nil {
|
||||
logger.Error("panic recovered", zap.Any("panic", recovered), zap.String("path", c.Request.URL.Path))
|
||||
response.Error(c, http.StatusInternalServerError, "INTERNAL_ERROR", "服务器内部错误")
|
||||
c.Abort()
|
||||
}
|
||||
}()
|
||||
c.Next()
|
||||
}
|
||||
}
|
||||
|
||||
func RequestLogger(logger *zap.Logger) gin.HandlerFunc {
|
||||
return func(c *gin.Context) {
|
||||
c.Next()
|
||||
logger.Info("http request",
|
||||
zap.String("method", c.Request.Method),
|
||||
zap.String("path", c.Request.URL.Path),
|
||||
zap.Int("status", c.Writer.Status()),
|
||||
zap.String("client_ip", c.ClientIP()),
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
func AuthMiddleware(jwtManager *security.JWTManager) gin.HandlerFunc {
|
||||
return func(c *gin.Context) {
|
||||
authorization := strings.TrimSpace(c.GetHeader("Authorization"))
|
||||
if authorization == "" || !strings.HasPrefix(strings.ToLower(authorization), "bearer ") {
|
||||
response.Error(c, http.StatusUnauthorized, "AUTH_UNAUTHORIZED", "缺少有效的认证令牌")
|
||||
c.Abort()
|
||||
return
|
||||
}
|
||||
tokenValue := strings.TrimSpace(strings.TrimPrefix(authorization, "Bearer"))
|
||||
if tokenValue == authorization {
|
||||
tokenValue = strings.TrimSpace(strings.TrimPrefix(authorization, "bearer"))
|
||||
}
|
||||
claims, err := jwtManager.Parse(tokenValue)
|
||||
if err != nil {
|
||||
response.Error(c, http.StatusUnauthorized, "AUTH_UNAUTHORIZED", "认证令牌无效或已过期")
|
||||
c.Abort()
|
||||
return
|
||||
}
|
||||
c.Set(claimsContextKey, AuthClaims{UserID: claims.UserID, Username: claims.Username, Role: claims.Role})
|
||||
c.Next()
|
||||
}
|
||||
}
|
||||
|
||||
func writeError(c *gin.Context, logger *zap.Logger, err error) {
|
||||
var appErr *apperror.AppError
|
||||
if errors.As(err, &appErr) {
|
||||
if appErr.Err != nil {
|
||||
logger.Warn("request failed", zap.String("code", appErr.Code), zap.Error(appErr.Err))
|
||||
}
|
||||
response.Error(c, appErr.Status, appErr.Code, appErr.Message)
|
||||
return
|
||||
}
|
||||
logger.Error("unexpected error", zap.Error(err))
|
||||
response.Error(c, http.StatusInternalServerError, "INTERNAL_ERROR", "服务器内部错误")
|
||||
}
|
||||
|
||||
func bindJSON[T any](c *gin.Context, logger *zap.Logger) (*T, error) {
|
||||
var payload T
|
||||
if err := c.ShouldBindJSON(&payload); err != nil {
|
||||
logger.Warn("bind json failed", zap.Error(err))
|
||||
return nil, apperror.Wrap(http.StatusBadRequest, "INVALID_REQUEST", fmt.Sprintf("请求参数错误: %v", err), err)
|
||||
}
|
||||
return &payload, nil
|
||||
}
|
||||
38
server/internal/httpapi/router.go
Normal file
38
server/internal/httpapi/router.go
Normal file
@@ -0,0 +1,38 @@
|
||||
//go:build ignore
|
||||
|
||||
package httpapi
|
||||
|
||||
import (
|
||||
"backupx/server/internal/security"
|
||||
"backupx/server/internal/service"
|
||||
"github.com/gin-gonic/gin"
|
||||
"go.uber.org/zap"
|
||||
)
|
||||
|
||||
type Dependencies struct {
|
||||
Logger *zap.Logger
|
||||
AuthService *service.AuthService
|
||||
SystemService *service.SystemService
|
||||
JWTManager *security.JWTManager
|
||||
Mode string
|
||||
}
|
||||
|
||||
func NewRouter(deps Dependencies) *gin.Engine {
|
||||
gin.SetMode(deps.Mode)
|
||||
router := gin.New()
|
||||
router.Use(Recovery(deps.Logger), RequestLogger(deps.Logger))
|
||||
|
||||
api := router.Group("/api")
|
||||
authHandler := newAuthHandler(deps.AuthService, deps.Logger)
|
||||
systemHandler := newSystemHandler(deps.SystemService)
|
||||
protected := api.Group("")
|
||||
protected.Use(AuthMiddleware(deps.JWTManager))
|
||||
|
||||
authHandler.registerRoutes(api, protected)
|
||||
systemHandler.registerRoutes(protected)
|
||||
api.GET("/healthz", func(c *gin.Context) {
|
||||
c.JSON(200, gin.H{"status": "ok"})
|
||||
})
|
||||
|
||||
return router
|
||||
}
|
||||
96
server/internal/httpapi/router_test.go
Normal file
96
server/internal/httpapi/router_test.go
Normal file
@@ -0,0 +1,96 @@
|
||||
//go:build ignore
|
||||
|
||||
package httpapi
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/json"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"path/filepath"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"backupx/server/internal/config"
|
||||
"backupx/server/internal/database"
|
||||
"backupx/server/internal/logger"
|
||||
"backupx/server/internal/repository"
|
||||
"backupx/server/internal/security"
|
||||
"backupx/server/internal/service"
|
||||
)
|
||||
|
||||
func TestSetupLoginProfileAndSystemInfo(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
cfg := config.Config{
|
||||
Server: config.ServerConfig{Mode: "test"},
|
||||
Database: config.DatabaseConfig{Path: filepath.Join(tmpDir, "backupx.db")},
|
||||
Security: config.SecurityConfig{JWTSecret: "test-jwt-secret", JWTExpire: "1h", EncryptionKey: "test-encryption-key"},
|
||||
Log: config.LogConfig{Level: "error"},
|
||||
}
|
||||
log, err := logger.New(cfg.Log)
|
||||
if err != nil {
|
||||
t.Fatalf("logger.New() error = %v", err)
|
||||
}
|
||||
db, err := database.Open(cfg.Database, log)
|
||||
if err != nil {
|
||||
t.Fatalf("database.Open() error = %v", err)
|
||||
}
|
||||
jwtManager := security.NewJWTManager(cfg.Security.JWTSecret, time.Hour)
|
||||
authService := service.NewAuthService(repository.NewUserRepository(db), jwtManager, security.NewLoginLimiter(5, time.Minute))
|
||||
systemService := service.NewSystemService(cfg, "test", time.Now().Add(-time.Minute))
|
||||
router := NewRouter(Dependencies{Logger: log, AuthService: authService, SystemService: systemService, JWTManager: jwtManager, Mode: "test"})
|
||||
|
||||
setupBody := map[string]string{"username": "admin", "password": "super-secret", "displayName": "管理员"}
|
||||
setupResp := performJSONRequest(t, router, http.MethodPost, "/api/auth/setup", setupBody, "")
|
||||
if setupResp.Code != http.StatusCreated {
|
||||
t.Fatalf("unexpected setup status: %d body=%s", setupResp.Code, setupResp.Body.String())
|
||||
}
|
||||
var setupPayload struct {
|
||||
Code string `json:"code"`
|
||||
Data struct {
|
||||
Token string `json:"token"`
|
||||
} `json:"data"`
|
||||
}
|
||||
if err := json.Unmarshal(setupResp.Body.Bytes(), &setupPayload); err != nil {
|
||||
t.Fatalf("decode setup response: %v", err)
|
||||
}
|
||||
if setupPayload.Data.Token == "" {
|
||||
t.Fatal("expected token in setup response")
|
||||
}
|
||||
|
||||
profileResp := performJSONRequest(t, router, http.MethodGet, "/api/auth/profile", nil, setupPayload.Data.Token)
|
||||
if profileResp.Code != http.StatusOK {
|
||||
t.Fatalf("unexpected profile status: %d body=%s", profileResp.Code, profileResp.Body.String())
|
||||
}
|
||||
|
||||
loginBody := map[string]string{"username": "admin", "password": "super-secret"}
|
||||
loginResp := performJSONRequest(t, router, http.MethodPost, "/api/auth/login", loginBody, "")
|
||||
if loginResp.Code != http.StatusOK {
|
||||
t.Fatalf("unexpected login status: %d body=%s", loginResp.Code, loginResp.Body.String())
|
||||
}
|
||||
|
||||
systemResp := performJSONRequest(t, router, http.MethodGet, "/api/system/info", nil, setupPayload.Data.Token)
|
||||
if systemResp.Code != http.StatusOK {
|
||||
t.Fatalf("unexpected system info status: %d body=%s", systemResp.Code, systemResp.Body.String())
|
||||
}
|
||||
}
|
||||
|
||||
func performJSONRequest(t *testing.T, handler http.Handler, method string, path string, payload any, token string) *httptest.ResponseRecorder {
|
||||
t.Helper()
|
||||
var body []byte
|
||||
if payload != nil {
|
||||
encoded, err := json.Marshal(payload)
|
||||
if err != nil {
|
||||
t.Fatalf("json.Marshal() error = %v", err)
|
||||
}
|
||||
body = encoded
|
||||
}
|
||||
request := httptest.NewRequest(method, path, bytes.NewReader(body))
|
||||
request.Header.Set("Content-Type", "application/json")
|
||||
if token != "" {
|
||||
request.Header.Set("Authorization", "Bearer "+token)
|
||||
}
|
||||
response := httptest.NewRecorder()
|
||||
handler.ServeHTTP(response, request)
|
||||
return response
|
||||
}
|
||||
25
server/internal/httpapi/system_handler.go
Normal file
25
server/internal/httpapi/system_handler.go
Normal file
@@ -0,0 +1,25 @@
|
||||
//go:build ignore
|
||||
|
||||
package httpapi
|
||||
|
||||
import (
|
||||
"backupx/server/internal/service"
|
||||
"backupx/server/pkg/response"
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
type systemHandler struct {
|
||||
service *service.SystemService
|
||||
}
|
||||
|
||||
func newSystemHandler(service *service.SystemService) *systemHandler {
|
||||
return &systemHandler{service: service}
|
||||
}
|
||||
|
||||
func (h *systemHandler) registerRoutes(protected gin.IRouter) {
|
||||
protected.GET("/system/info", h.info)
|
||||
}
|
||||
|
||||
func (h *systemHandler) info(c *gin.Context) {
|
||||
response.Success(c, h.service.GetInfo())
|
||||
}
|
||||
53
server/internal/logger/logger.go
Normal file
53
server/internal/logger/logger.go
Normal file
@@ -0,0 +1,53 @@
|
||||
package logger
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
|
||||
"backupx/server/internal/config"
|
||||
"github.com/natefinch/lumberjack"
|
||||
"go.uber.org/zap"
|
||||
"go.uber.org/zap/zapcore"
|
||||
)
|
||||
|
||||
func New(cfg config.LogConfig) (*zap.Logger, error) {
|
||||
level := parseLevel(cfg.Level)
|
||||
encoderCfg := zap.NewProductionEncoderConfig()
|
||||
encoderCfg.TimeKey = "time"
|
||||
encoderCfg.EncodeTime = zapcore.ISO8601TimeEncoder
|
||||
encoder := zapcore.NewJSONEncoder(encoderCfg)
|
||||
|
||||
writers := []zapcore.WriteSyncer{zapcore.AddSync(os.Stdout)}
|
||||
if cfg.File != "" {
|
||||
if err := os.MkdirAll(filepath.Dir(cfg.File), 0o755); err != nil {
|
||||
return nil, fmt.Errorf("create log dir: %w", err)
|
||||
}
|
||||
rotator := &lumberjack.Logger{
|
||||
Filename: cfg.File,
|
||||
MaxSize: cfg.MaxSize,
|
||||
MaxBackups: cfg.MaxBackups,
|
||||
MaxAge: cfg.MaxAge,
|
||||
LocalTime: false,
|
||||
Compress: true,
|
||||
}
|
||||
writers = append(writers, zapcore.AddSync(rotator))
|
||||
}
|
||||
|
||||
core := zapcore.NewCore(encoder, zapcore.NewMultiWriteSyncer(writers...), level)
|
||||
return zap.New(core, zap.AddCaller(), zap.AddStacktrace(zapcore.ErrorLevel)), nil
|
||||
}
|
||||
|
||||
func parseLevel(value string) zapcore.Level {
|
||||
switch strings.ToLower(strings.TrimSpace(value)) {
|
||||
case "debug":
|
||||
return zapcore.DebugLevel
|
||||
case "warn":
|
||||
return zapcore.WarnLevel
|
||||
case "error":
|
||||
return zapcore.ErrorLevel
|
||||
default:
|
||||
return zapcore.InfoLevel
|
||||
}
|
||||
}
|
||||
32
server/internal/model/backup_record.go
Normal file
32
server/internal/model/backup_record.go
Normal file
@@ -0,0 +1,32 @@
|
||||
package model
|
||||
|
||||
import "time"
|
||||
|
||||
const (
|
||||
BackupRecordStatusRunning = "running"
|
||||
BackupRecordStatusSuccess = "success"
|
||||
BackupRecordStatusFailed = "failed"
|
||||
)
|
||||
|
||||
type BackupRecord struct {
|
||||
ID uint `gorm:"primaryKey" json:"id"`
|
||||
TaskID uint `gorm:"column:task_id;index;not null" json:"taskId"`
|
||||
Task BackupTask `json:"task,omitempty"`
|
||||
StorageTargetID uint `gorm:"column:storage_target_id;index;not null" json:"storageTargetId"`
|
||||
StorageTarget StorageTarget `json:"storageTarget,omitempty"`
|
||||
Status string `gorm:"size:20;index;not null" json:"status"`
|
||||
FileName string `gorm:"column:file_name;size:255" json:"fileName"`
|
||||
FileSize int64 `gorm:"column:file_size;not null;default:0" json:"fileSize"`
|
||||
StoragePath string `gorm:"column:storage_path;size:500" json:"storagePath"`
|
||||
DurationSeconds int `gorm:"column:duration_seconds;not null;default:0" json:"durationSeconds"`
|
||||
ErrorMessage string `gorm:"column:error_message;size:2000" json:"errorMessage"`
|
||||
LogContent string `gorm:"column:log_content;type:text" json:"logContent"`
|
||||
StartedAt time.Time `gorm:"column:started_at;index;not null" json:"startedAt"`
|
||||
CompletedAt *time.Time `gorm:"column:completed_at;index" json:"completedAt,omitempty"`
|
||||
CreatedAt time.Time `json:"createdAt"`
|
||||
UpdatedAt time.Time `json:"updatedAt"`
|
||||
}
|
||||
|
||||
func (BackupRecord) TableName() string {
|
||||
return "backup_records"
|
||||
}
|
||||
50
server/internal/model/backup_task.go
Normal file
50
server/internal/model/backup_task.go
Normal file
@@ -0,0 +1,50 @@
|
||||
package model
|
||||
|
||||
import "time"
|
||||
|
||||
const (
|
||||
BackupTaskTypeFile = "file"
|
||||
BackupTaskTypeMySQL = "mysql"
|
||||
BackupTaskTypeSQLite = "sqlite"
|
||||
BackupTaskTypePostgreSQL = "postgresql"
|
||||
)
|
||||
|
||||
const (
|
||||
BackupTaskStatusIdle = "idle"
|
||||
BackupTaskStatusRunning = "running"
|
||||
BackupTaskStatusSuccess = "success"
|
||||
BackupTaskStatusFailed = "failed"
|
||||
)
|
||||
|
||||
type BackupTask struct {
|
||||
ID uint `gorm:"primaryKey" json:"id"`
|
||||
Name string `gorm:"size:100;uniqueIndex;not null" json:"name"`
|
||||
Type string `gorm:"size:20;index;not null" json:"type"`
|
||||
Enabled bool `gorm:"not null;default:true" json:"enabled"`
|
||||
CronExpr string `gorm:"column:cron_expr;size:64" json:"cronExpr"`
|
||||
SourcePath string `gorm:"column:source_path;size:500" json:"sourcePath"`
|
||||
ExcludePatterns string `gorm:"column:exclude_patterns;type:text" json:"excludePatterns"`
|
||||
DBHost string `gorm:"column:db_host;size:255" json:"dbHost"`
|
||||
DBPort int `gorm:"column:db_port" json:"dbPort"`
|
||||
DBUser string `gorm:"column:db_user;size:100" json:"dbUser"`
|
||||
DBPasswordCiphertext string `gorm:"column:db_password_ciphertext;type:text" json:"-"`
|
||||
DBName string `gorm:"column:db_name;size:255" json:"dbName"`
|
||||
DBPath string `gorm:"column:db_path;size:500" json:"dbPath"`
|
||||
StorageTargetID uint `gorm:"column:storage_target_id;index;not null" json:"storageTargetId"`
|
||||
StorageTarget StorageTarget `json:"storageTarget,omitempty"`
|
||||
NodeID uint `gorm:"column:node_id;index;default:0" json:"nodeId"`
|
||||
Node Node `json:"node,omitempty"`
|
||||
Tags string `gorm:"column:tags;size:500" json:"tags"`
|
||||
RetentionDays int `gorm:"column:retention_days;not null;default:30" json:"retentionDays"`
|
||||
Compression string `gorm:"size:10;not null;default:'gzip'" json:"compression"`
|
||||
Encrypt bool `gorm:"not null;default:false" json:"encrypt"`
|
||||
MaxBackups int `gorm:"column:max_backups;not null;default:10" json:"maxBackups"`
|
||||
LastRunAt *time.Time `gorm:"column:last_run_at" json:"lastRunAt,omitempty"`
|
||||
LastStatus string `gorm:"column:last_status;size:20;not null;default:'idle'" json:"lastStatus"`
|
||||
CreatedAt time.Time `json:"createdAt"`
|
||||
UpdatedAt time.Time `json:"updatedAt"`
|
||||
}
|
||||
|
||||
func (BackupTask) TableName() string {
|
||||
return "backup_tasks"
|
||||
}
|
||||
30
server/internal/model/node.go
Normal file
30
server/internal/model/node.go
Normal file
@@ -0,0 +1,30 @@
|
||||
package model
|
||||
|
||||
import "time"
|
||||
|
||||
const (
|
||||
NodeStatusOnline = "online"
|
||||
NodeStatusOffline = "offline"
|
||||
)
|
||||
|
||||
// Node represents a managed server node in the cluster.
|
||||
// The default "local" node is auto-created for single-machine backward compatibility.
|
||||
type Node struct {
|
||||
ID uint `gorm:"primaryKey" json:"id"`
|
||||
Name string `gorm:"size:128;uniqueIndex;not null" json:"name"`
|
||||
Hostname string `gorm:"size:255" json:"hostname"`
|
||||
IPAddress string `gorm:"column:ip_address;size:64" json:"ipAddress"`
|
||||
Token string `gorm:"size:128;uniqueIndex;not null" json:"-"`
|
||||
Status string `gorm:"size:20;not null;default:'offline'" json:"status"`
|
||||
IsLocal bool `gorm:"not null;default:false" json:"isLocal"`
|
||||
OS string `gorm:"size:64" json:"os"`
|
||||
Arch string `gorm:"size:32" json:"arch"`
|
||||
AgentVer string `gorm:"column:agent_version;size:32" json:"agentVersion"`
|
||||
LastSeen time.Time `gorm:"column:last_seen" json:"lastSeen"`
|
||||
CreatedAt time.Time `json:"createdAt"`
|
||||
UpdatedAt time.Time `json:"updatedAt"`
|
||||
}
|
||||
|
||||
func (Node) TableName() string {
|
||||
return "nodes"
|
||||
}
|
||||
19
server/internal/model/notification.go
Normal file
19
server/internal/model/notification.go
Normal file
@@ -0,0 +1,19 @@
|
||||
package model
|
||||
|
||||
import "time"
|
||||
|
||||
type Notification struct {
|
||||
ID uint `gorm:"primaryKey" json:"id"`
|
||||
Type string `gorm:"size:20;index;not null" json:"type"`
|
||||
Name string `gorm:"size:100;uniqueIndex;not null" json:"name"`
|
||||
ConfigCiphertext string `gorm:"column:config_ciphertext;type:text;not null" json:"-"`
|
||||
Enabled bool `gorm:"not null;default:true" json:"enabled"`
|
||||
OnSuccess bool `gorm:"column:on_success;not null;default:false" json:"onSuccess"`
|
||||
OnFailure bool `gorm:"column:on_failure;not null;default:true" json:"onFailure"`
|
||||
CreatedAt time.Time `json:"createdAt"`
|
||||
UpdatedAt time.Time `json:"updatedAt"`
|
||||
}
|
||||
|
||||
func (Notification) TableName() string {
|
||||
return "notifications"
|
||||
}
|
||||
19
server/internal/model/oauth_session.go
Normal file
19
server/internal/model/oauth_session.go
Normal file
@@ -0,0 +1,19 @@
|
||||
package model
|
||||
|
||||
import "time"
|
||||
|
||||
type OAuthSession struct {
|
||||
ID uint `gorm:"primaryKey" json:"id"`
|
||||
ProviderType string `gorm:"column:provider_type;size:32;index;not null" json:"providerType"`
|
||||
State string `gorm:"size:255;uniqueIndex;not null" json:"state"`
|
||||
PayloadCiphertext string `gorm:"column:payload_ciphertext;type:text;not null" json:"-"`
|
||||
TargetID *uint `gorm:"column:target_id" json:"targetId,omitempty"`
|
||||
ExpiresAt time.Time `gorm:"column:expires_at;index;not null" json:"expiresAt"`
|
||||
UsedAt *time.Time `gorm:"column:used_at" json:"usedAt,omitempty"`
|
||||
CreatedAt time.Time `json:"createdAt"`
|
||||
UpdatedAt time.Time `json:"updatedAt"`
|
||||
}
|
||||
|
||||
func (OAuthSession) TableName() string {
|
||||
return "oauth_sessions"
|
||||
}
|
||||
22
server/internal/model/storage_target.go
Normal file
22
server/internal/model/storage_target.go
Normal file
@@ -0,0 +1,22 @@
|
||||
package model
|
||||
|
||||
import "time"
|
||||
|
||||
type StorageTarget struct {
|
||||
ID uint `gorm:"primaryKey" json:"id"`
|
||||
Name string `gorm:"size:128;uniqueIndex;not null" json:"name"`
|
||||
Type string `gorm:"size:32;index;not null" json:"type"`
|
||||
Description string `gorm:"size:255" json:"description"`
|
||||
Enabled bool `gorm:"not null;default:true" json:"enabled"`
|
||||
ConfigCiphertext string `gorm:"column:config_ciphertext;type:text;not null" json:"-"`
|
||||
ConfigVersion int `gorm:"not null;default:1" json:"configVersion"`
|
||||
LastTestedAt *time.Time `gorm:"column:last_tested_at" json:"lastTestedAt,omitempty"`
|
||||
LastTestStatus string `gorm:"column:last_test_status;size:32;not null;default:'unknown'" json:"lastTestStatus"`
|
||||
LastTestMessage string `gorm:"column:last_test_message;size:512" json:"lastTestMessage"`
|
||||
CreatedAt time.Time `json:"createdAt"`
|
||||
UpdatedAt time.Time `json:"updatedAt"`
|
||||
}
|
||||
|
||||
func (StorageTarget) TableName() string {
|
||||
return "storage_targets"
|
||||
}
|
||||
16
server/internal/model/system_config.go
Normal file
16
server/internal/model/system_config.go
Normal file
@@ -0,0 +1,16 @@
|
||||
package model
|
||||
|
||||
import "time"
|
||||
|
||||
type SystemConfig struct {
|
||||
ID uint `gorm:"primaryKey" json:"id"`
|
||||
Key string `gorm:"size:128;uniqueIndex;not null" json:"key"`
|
||||
Value string `gorm:"type:text;not null" json:"value"`
|
||||
Encrypted bool `gorm:"not null;default:false" json:"encrypted"`
|
||||
CreatedAt time.Time `json:"createdAt"`
|
||||
UpdatedAt time.Time `json:"updatedAt"`
|
||||
}
|
||||
|
||||
func (SystemConfig) TableName() string {
|
||||
return "system_configs"
|
||||
}
|
||||
18
server/internal/model/user.go
Normal file
18
server/internal/model/user.go
Normal file
@@ -0,0 +1,18 @@
|
||||
package model
|
||||
|
||||
import "time"
|
||||
|
||||
type User struct {
|
||||
ID uint `gorm:"primaryKey" json:"id"`
|
||||
Username string `gorm:"size:64;uniqueIndex;not null" json:"username"`
|
||||
PasswordHash string `gorm:"column:password_hash;not null" json:"-"`
|
||||
DisplayName string `gorm:"size:128;not null" json:"displayName"`
|
||||
Email string `gorm:"size:255" json:"email"`
|
||||
Role string `gorm:"size:32;not null;default:admin" json:"role"`
|
||||
CreatedAt time.Time `json:"createdAt"`
|
||||
UpdatedAt time.Time `json:"updatedAt"`
|
||||
}
|
||||
|
||||
func (User) TableName() string {
|
||||
return "users"
|
||||
}
|
||||
88
server/internal/notify/email.go
Normal file
88
server/internal/notify/email.go
Normal file
@@ -0,0 +1,88 @@
|
||||
package notify
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/tls"
|
||||
"fmt"
|
||||
"net/smtp"
|
||||
"strconv"
|
||||
"strings"
|
||||
)
|
||||
|
||||
type EmailNotifier struct{}
|
||||
|
||||
func NewEmailNotifier() *EmailNotifier { return &EmailNotifier{} }
|
||||
func (n *EmailNotifier) Type() string { return "email" }
|
||||
func (n *EmailNotifier) SensitiveFields() []string { return []string{"password"} }
|
||||
|
||||
func (n *EmailNotifier) Validate(config map[string]any) error {
|
||||
host := strings.TrimSpace(asString(config["host"]))
|
||||
port := asInt(config["port"])
|
||||
from := strings.TrimSpace(asString(config["from"]))
|
||||
to := strings.TrimSpace(asString(config["to"]))
|
||||
if host == "" || port <= 0 || from == "" || to == "" {
|
||||
return fmt.Errorf("email host/port/from/to are required")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (n *EmailNotifier) Send(_ context.Context, config map[string]any, message Message) error {
|
||||
if err := n.Validate(config); err != nil {
|
||||
return err
|
||||
}
|
||||
host := strings.TrimSpace(asString(config["host"]))
|
||||
port := asInt(config["port"])
|
||||
username := strings.TrimSpace(asString(config["username"]))
|
||||
password := strings.TrimSpace(asString(config["password"]))
|
||||
from := strings.TrimSpace(asString(config["from"]))
|
||||
toList := splitCommaValues(asString(config["to"]))
|
||||
address := host + ":" + strconv.Itoa(port)
|
||||
headers := []string{"From: " + from, "To: " + strings.Join(toList, ", "), "Subject: " + message.Title, "MIME-Version: 1.0", "Content-Type: text/plain; charset=UTF-8", "", message.Body}
|
||||
var auth smtp.Auth
|
||||
if username != "" {
|
||||
auth = smtp.PlainAuth("", username, password, host)
|
||||
}
|
||||
|
||||
rawMessage := []byte(strings.Join(headers, "\r\n"))
|
||||
|
||||
if port == 465 {
|
||||
tlsConfig := &tls.Config{ServerName: host}
|
||||
conn, err := tls.Dial("tcp", address, tlsConfig)
|
||||
if err != nil {
|
||||
return fmt.Errorf("dial tls for smtp port 465 failed: %w", err)
|
||||
}
|
||||
client, err := smtp.NewClient(conn, host)
|
||||
if err != nil {
|
||||
return fmt.Errorf("create smtp client over tls failed: %w", err)
|
||||
}
|
||||
defer client.Close()
|
||||
if auth != nil {
|
||||
if ok, _ := client.Extension("AUTH"); ok {
|
||||
if err = client.Auth(auth); err != nil {
|
||||
return fmt.Errorf("smtp auth failed: %w", err)
|
||||
}
|
||||
}
|
||||
}
|
||||
if err = client.Mail(from); err != nil {
|
||||
return fmt.Errorf("smtp mail from failed: %w", err)
|
||||
}
|
||||
for _, toAddr := range toList {
|
||||
if err = client.Rcpt(toAddr); err != nil {
|
||||
return fmt.Errorf("smtp rcpt failed for %s: %w", toAddr, err)
|
||||
}
|
||||
}
|
||||
writer, err := client.Data()
|
||||
if err != nil {
|
||||
return fmt.Errorf("smtp data failed: %w", err)
|
||||
}
|
||||
if _, err = writer.Write(rawMessage); err != nil {
|
||||
return fmt.Errorf("smtp write message failed: %w", err)
|
||||
}
|
||||
if err = writer.Close(); err != nil {
|
||||
return fmt.Errorf("smtp data close failed: %w", err)
|
||||
}
|
||||
return client.Quit()
|
||||
}
|
||||
|
||||
return smtp.SendMail(address, auth, from, toList, rawMessage)
|
||||
}
|
||||
49
server/internal/notify/helpers.go
Normal file
49
server/internal/notify/helpers.go
Normal file
@@ -0,0 +1,49 @@
|
||||
package notify
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"strconv"
|
||||
"strings"
|
||||
)
|
||||
|
||||
func asString(value any) string {
|
||||
text, _ := value.(string)
|
||||
return strings.TrimSpace(text)
|
||||
}
|
||||
|
||||
func asInt(value any) int {
|
||||
switch actual := value.(type) {
|
||||
case int:
|
||||
return actual
|
||||
case int64:
|
||||
return int(actual)
|
||||
case float64:
|
||||
return int(actual)
|
||||
case string:
|
||||
parsed, _ := strconv.Atoi(strings.TrimSpace(actual))
|
||||
return parsed
|
||||
default:
|
||||
return 0
|
||||
}
|
||||
}
|
||||
|
||||
func splitCommaValues(value string) []string {
|
||||
items := strings.Split(value, ",")
|
||||
result := make([]string, 0, len(items))
|
||||
for _, item := range items {
|
||||
trimmed := strings.TrimSpace(item)
|
||||
if trimmed != "" {
|
||||
result = append(result, trimmed)
|
||||
}
|
||||
}
|
||||
return result
|
||||
}
|
||||
|
||||
func validateRequiredConfig(config map[string]any, fields ...string) error {
|
||||
for _, field := range fields {
|
||||
if strings.TrimSpace(asString(config[field])) == "" {
|
||||
return fmt.Errorf("%s is required", field)
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
75
server/internal/notify/registry.go
Normal file
75
server/internal/notify/registry.go
Normal file
@@ -0,0 +1,75 @@
|
||||
package notify
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"sort"
|
||||
"sync"
|
||||
)
|
||||
|
||||
type Registry struct {
|
||||
mu sync.RWMutex
|
||||
notifiers map[string]Notifier
|
||||
}
|
||||
|
||||
func NewRegistry(notifiers ...Notifier) *Registry {
|
||||
registry := &Registry{notifiers: make(map[string]Notifier)}
|
||||
for _, notifier := range notifiers {
|
||||
registry.Register(notifier)
|
||||
}
|
||||
return registry
|
||||
}
|
||||
|
||||
func (r *Registry) Register(notifier Notifier) {
|
||||
if notifier == nil {
|
||||
return
|
||||
}
|
||||
r.mu.Lock()
|
||||
defer r.mu.Unlock()
|
||||
if r.notifiers == nil {
|
||||
r.notifiers = make(map[string]Notifier)
|
||||
}
|
||||
r.notifiers[notifier.Type()] = notifier
|
||||
}
|
||||
|
||||
func (r *Registry) Types() []string {
|
||||
r.mu.RLock()
|
||||
defer r.mu.RUnlock()
|
||||
items := make([]string, 0, len(r.notifiers))
|
||||
for key := range r.notifiers {
|
||||
items = append(items, key)
|
||||
}
|
||||
sort.Strings(items)
|
||||
return items
|
||||
}
|
||||
|
||||
func (r *Registry) SensitiveFields(notificationType string) []string {
|
||||
notifier, ok := r.Notifier(notificationType)
|
||||
if !ok {
|
||||
return nil
|
||||
}
|
||||
return notifier.SensitiveFields()
|
||||
}
|
||||
|
||||
func (r *Registry) Validate(notificationType string, config map[string]any) error {
|
||||
notifier, ok := r.Notifier(notificationType)
|
||||
if !ok {
|
||||
return fmt.Errorf("unsupported notification type: %s", notificationType)
|
||||
}
|
||||
return notifier.Validate(config)
|
||||
}
|
||||
|
||||
func (r *Registry) Send(ctx context.Context, notificationType string, config map[string]any, message Message) error {
|
||||
notifier, ok := r.Notifier(notificationType)
|
||||
if !ok {
|
||||
return fmt.Errorf("unsupported notification type: %s", notificationType)
|
||||
}
|
||||
return notifier.Send(ctx, config, message)
|
||||
}
|
||||
|
||||
func (r *Registry) Notifier(notificationType string) (Notifier, bool) {
|
||||
r.mu.RLock()
|
||||
defer r.mu.RUnlock()
|
||||
notifier, ok := r.notifiers[notificationType]
|
||||
return notifier, ok
|
||||
}
|
||||
54
server/internal/notify/telegram.go
Normal file
54
server/internal/notify/telegram.go
Normal file
@@ -0,0 +1,54 @@
|
||||
package notify
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"strings"
|
||||
"time"
|
||||
)
|
||||
|
||||
type TelegramNotifier struct {
|
||||
client *http.Client
|
||||
}
|
||||
|
||||
func NewTelegramNotifier() *TelegramNotifier {
|
||||
return &TelegramNotifier{client: &http.Client{Timeout: 10 * time.Second}}
|
||||
}
|
||||
func (n *TelegramNotifier) Type() string { return "telegram" }
|
||||
func (n *TelegramNotifier) SensitiveFields() []string { return []string{"botToken"} }
|
||||
|
||||
func (n *TelegramNotifier) Validate(config map[string]any) error {
|
||||
if strings.TrimSpace(asString(config["botToken"])) == "" || strings.TrimSpace(asString(config["chatId"])) == "" {
|
||||
return fmt.Errorf("telegram botToken/chatId are required")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (n *TelegramNotifier) Send(ctx context.Context, config map[string]any, message Message) error {
|
||||
if err := n.Validate(config); err != nil {
|
||||
return err
|
||||
}
|
||||
botToken := strings.TrimSpace(asString(config["botToken"]))
|
||||
chatID := strings.TrimSpace(asString(config["chatId"]))
|
||||
payload, err := json.Marshal(map[string]any{"chat_id": chatID, "text": message.Title + "\n\n" + message.Body})
|
||||
if err != nil {
|
||||
return fmt.Errorf("marshal telegram payload: %w", err)
|
||||
}
|
||||
request, err := http.NewRequestWithContext(ctx, http.MethodPost, "https://api.telegram.org/bot"+botToken+"/sendMessage", bytes.NewReader(payload))
|
||||
if err != nil {
|
||||
return fmt.Errorf("create telegram request: %w", err)
|
||||
}
|
||||
request.Header.Set("Content-Type", "application/json")
|
||||
response, err := n.client.Do(request)
|
||||
if err != nil {
|
||||
return fmt.Errorf("send telegram request: %w", err)
|
||||
}
|
||||
defer response.Body.Close()
|
||||
if response.StatusCode >= http.StatusBadRequest {
|
||||
return fmt.Errorf("telegram response status: %s", response.Status)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
16
server/internal/notify/types.go
Normal file
16
server/internal/notify/types.go
Normal file
@@ -0,0 +1,16 @@
|
||||
package notify
|
||||
|
||||
import "context"
|
||||
|
||||
type Message struct {
|
||||
Title string `json:"title"`
|
||||
Body string `json:"body"`
|
||||
Fields map[string]any `json:"fields,omitempty"`
|
||||
}
|
||||
|
||||
type Notifier interface {
|
||||
Type() string
|
||||
SensitiveFields() []string
|
||||
Validate(config map[string]any) error
|
||||
Send(ctx context.Context, config map[string]any, message Message) error
|
||||
}
|
||||
55
server/internal/notify/webhook.go
Normal file
55
server/internal/notify/webhook.go
Normal file
@@ -0,0 +1,55 @@
|
||||
package notify
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"strings"
|
||||
"time"
|
||||
)
|
||||
|
||||
type WebhookNotifier struct {
|
||||
client *http.Client
|
||||
}
|
||||
|
||||
func NewWebhookNotifier() *WebhookNotifier {
|
||||
return &WebhookNotifier{client: &http.Client{Timeout: 10 * time.Second}}
|
||||
}
|
||||
func (n *WebhookNotifier) Type() string { return "webhook" }
|
||||
func (n *WebhookNotifier) SensitiveFields() []string { return []string{"secret"} }
|
||||
|
||||
func (n *WebhookNotifier) Validate(config map[string]any) error {
|
||||
if strings.TrimSpace(asString(config["url"])) == "" {
|
||||
return fmt.Errorf("webhook url is required")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (n *WebhookNotifier) Send(ctx context.Context, config map[string]any, message Message) error {
|
||||
if err := n.Validate(config); err != nil {
|
||||
return err
|
||||
}
|
||||
body, err := json.Marshal(map[string]any{"title": message.Title, "body": message.Body, "fields": message.Fields})
|
||||
if err != nil {
|
||||
return fmt.Errorf("marshal webhook payload: %w", err)
|
||||
}
|
||||
request, err := http.NewRequestWithContext(ctx, http.MethodPost, strings.TrimSpace(asString(config["url"])), bytes.NewReader(body))
|
||||
if err != nil {
|
||||
return fmt.Errorf("create webhook request: %w", err)
|
||||
}
|
||||
request.Header.Set("Content-Type", "application/json")
|
||||
if secret := strings.TrimSpace(asString(config["secret"])); secret != "" {
|
||||
request.Header.Set("X-BackupX-Secret", secret)
|
||||
}
|
||||
response, err := n.client.Do(request)
|
||||
if err != nil {
|
||||
return fmt.Errorf("send webhook request: %w", err)
|
||||
}
|
||||
defer response.Body.Close()
|
||||
if response.StatusCode >= http.StatusBadRequest {
|
||||
return fmt.Errorf("webhook response status: %s", response.Status)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
183
server/internal/repository/backup_record_repository.go
Normal file
183
server/internal/repository/backup_record_repository.go
Normal file
@@ -0,0 +1,183 @@
|
||||
package repository
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"time"
|
||||
|
||||
"backupx/server/internal/model"
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
|
||||
type BackupRecordListOptions struct {
|
||||
TaskID *uint
|
||||
Status string
|
||||
DateFrom *time.Time
|
||||
DateTo *time.Time
|
||||
Limit int
|
||||
Offset int
|
||||
}
|
||||
|
||||
type BackupTimelinePoint struct {
|
||||
Date string `json:"date"`
|
||||
Total int64 `json:"total"`
|
||||
Success int64 `json:"success"`
|
||||
Failed int64 `json:"failed"`
|
||||
}
|
||||
|
||||
type BackupStorageUsageItem struct {
|
||||
StorageTargetID uint `json:"storageTargetId"`
|
||||
TotalSize int64 `json:"totalSize"`
|
||||
}
|
||||
|
||||
type BackupRecordRepository interface {
|
||||
List(context.Context, BackupRecordListOptions) ([]model.BackupRecord, error)
|
||||
FindByID(context.Context, uint) (*model.BackupRecord, error)
|
||||
Create(context.Context, *model.BackupRecord) error
|
||||
Update(context.Context, *model.BackupRecord) error
|
||||
Delete(context.Context, uint) error
|
||||
ListRecent(context.Context, int) ([]model.BackupRecord, error)
|
||||
ListSuccessfulByTask(context.Context, uint) ([]model.BackupRecord, error)
|
||||
Count(context.Context) (int64, error)
|
||||
CountSince(context.Context, time.Time) (int64, error)
|
||||
CountSuccessSince(context.Context, time.Time) (int64, error)
|
||||
SumFileSize(context.Context) (int64, error)
|
||||
TimelineSince(context.Context, time.Time) ([]BackupTimelinePoint, error)
|
||||
StorageUsage(context.Context) ([]BackupStorageUsageItem, error)
|
||||
}
|
||||
|
||||
type GormBackupRecordRepository struct {
|
||||
db *gorm.DB
|
||||
}
|
||||
|
||||
func NewBackupRecordRepository(db *gorm.DB) *GormBackupRecordRepository {
|
||||
return &GormBackupRecordRepository{db: db}
|
||||
}
|
||||
|
||||
func (r *GormBackupRecordRepository) List(ctx context.Context, options BackupRecordListOptions) ([]model.BackupRecord, error) {
|
||||
query := r.db.WithContext(ctx).Model(&model.BackupRecord{}).Preload("Task").Preload("Task.StorageTarget").Order("started_at desc")
|
||||
if options.TaskID != nil {
|
||||
query = query.Where("task_id = ?", *options.TaskID)
|
||||
}
|
||||
if options.Status != "" {
|
||||
query = query.Where("status = ?", options.Status)
|
||||
}
|
||||
if options.DateFrom != nil {
|
||||
query = query.Where("started_at >= ?", options.DateFrom.UTC())
|
||||
}
|
||||
if options.DateTo != nil {
|
||||
query = query.Where("started_at <= ?", options.DateTo.UTC())
|
||||
}
|
||||
if options.Limit > 0 {
|
||||
query = query.Limit(options.Limit)
|
||||
}
|
||||
if options.Offset > 0 {
|
||||
query = query.Offset(options.Offset)
|
||||
}
|
||||
var items []model.BackupRecord
|
||||
if err := query.Find(&items).Error; err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return items, nil
|
||||
}
|
||||
|
||||
func (r *GormBackupRecordRepository) FindByID(ctx context.Context, id uint) (*model.BackupRecord, error) {
|
||||
var item model.BackupRecord
|
||||
if err := r.db.WithContext(ctx).Preload("Task").Preload("Task.StorageTarget").First(&item, id).Error; err != nil {
|
||||
if errors.Is(err, gorm.ErrRecordNotFound) {
|
||||
return nil, nil
|
||||
}
|
||||
return nil, err
|
||||
}
|
||||
return &item, nil
|
||||
}
|
||||
|
||||
func (r *GormBackupRecordRepository) Create(ctx context.Context, item *model.BackupRecord) error {
|
||||
return r.db.WithContext(ctx).Create(item).Error
|
||||
}
|
||||
|
||||
func (r *GormBackupRecordRepository) Update(ctx context.Context, item *model.BackupRecord) error {
|
||||
return r.db.WithContext(ctx).Save(item).Error
|
||||
}
|
||||
|
||||
func (r *GormBackupRecordRepository) Delete(ctx context.Context, id uint) error {
|
||||
return r.db.WithContext(ctx).Delete(&model.BackupRecord{}, id).Error
|
||||
}
|
||||
|
||||
func (r *GormBackupRecordRepository) ListRecent(ctx context.Context, limit int) ([]model.BackupRecord, error) {
|
||||
if limit <= 0 {
|
||||
limit = 10
|
||||
}
|
||||
var items []model.BackupRecord
|
||||
if err := r.db.WithContext(ctx).Preload("Task").Preload("Task.StorageTarget").Order("started_at desc").Limit(limit).Find(&items).Error; err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return items, nil
|
||||
}
|
||||
|
||||
func (r *GormBackupRecordRepository) ListSuccessfulByTask(ctx context.Context, taskID uint) ([]model.BackupRecord, error) {
|
||||
var items []model.BackupRecord
|
||||
if err := r.db.WithContext(ctx).Where("task_id = ? AND status = ?", taskID, "success").Order("completed_at desc, id desc").Find(&items).Error; err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return items, nil
|
||||
}
|
||||
|
||||
func (r *GormBackupRecordRepository) Count(ctx context.Context) (int64, error) {
|
||||
var count int64
|
||||
if err := r.db.WithContext(ctx).Model(&model.BackupRecord{}).Count(&count).Error; err != nil {
|
||||
return 0, err
|
||||
}
|
||||
return count, nil
|
||||
}
|
||||
|
||||
func (r *GormBackupRecordRepository) CountSince(ctx context.Context, since time.Time) (int64, error) {
|
||||
var count int64
|
||||
if err := r.db.WithContext(ctx).Model(&model.BackupRecord{}).Where("started_at >= ?", since.UTC()).Count(&count).Error; err != nil {
|
||||
return 0, err
|
||||
}
|
||||
return count, nil
|
||||
}
|
||||
|
||||
func (r *GormBackupRecordRepository) CountSuccessSince(ctx context.Context, since time.Time) (int64, error) {
|
||||
var count int64
|
||||
if err := r.db.WithContext(ctx).Model(&model.BackupRecord{}).Where("started_at >= ? AND status = ?", since.UTC(), "success").Count(&count).Error; err != nil {
|
||||
return 0, err
|
||||
}
|
||||
return count, nil
|
||||
}
|
||||
|
||||
func (r *GormBackupRecordRepository) SumFileSize(ctx context.Context) (int64, error) {
|
||||
var sum int64
|
||||
if err := r.db.WithContext(ctx).Model(&model.BackupRecord{}).Select("COALESCE(SUM(file_size), 0)").Scan(&sum).Error; err != nil {
|
||||
return 0, err
|
||||
}
|
||||
return sum, nil
|
||||
}
|
||||
|
||||
func (r *GormBackupRecordRepository) TimelineSince(ctx context.Context, since time.Time) ([]BackupTimelinePoint, error) {
|
||||
var items []BackupTimelinePoint
|
||||
query := `
|
||||
SELECT
|
||||
strftime('%Y-%m-%d', started_at) AS date,
|
||||
COUNT(*) AS total,
|
||||
SUM(CASE WHEN status = 'success' THEN 1 ELSE 0 END) AS success,
|
||||
SUM(CASE WHEN status = 'failed' THEN 1 ELSE 0 END) AS failed
|
||||
FROM backup_records
|
||||
WHERE started_at >= ?
|
||||
GROUP BY strftime('%Y-%m-%d', started_at)
|
||||
ORDER BY date ASC
|
||||
`
|
||||
if err := r.db.WithContext(ctx).Raw(query, since.UTC()).Scan(&items).Error; err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return items, nil
|
||||
}
|
||||
|
||||
func (r *GormBackupRecordRepository) StorageUsage(ctx context.Context) ([]BackupStorageUsageItem, error) {
|
||||
var items []BackupStorageUsageItem
|
||||
if err := r.db.WithContext(ctx).Model(&model.BackupRecord{}).Select("storage_target_id, COALESCE(SUM(file_size), 0) AS total_size").Group("storage_target_id").Order("storage_target_id asc").Scan(&items).Error; err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return items, nil
|
||||
}
|
||||
115
server/internal/repository/backup_record_repository_test.go
Normal file
115
server/internal/repository/backup_record_repository_test.go
Normal file
@@ -0,0 +1,115 @@
|
||||
package repository
|
||||
|
||||
import (
|
||||
"context"
|
||||
"path/filepath"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"backupx/server/internal/config"
|
||||
"backupx/server/internal/database"
|
||||
"backupx/server/internal/logger"
|
||||
"backupx/server/internal/model"
|
||||
)
|
||||
|
||||
func newBackupRecordTestRepository(t *testing.T) *GormBackupRecordRepository {
|
||||
t.Helper()
|
||||
log, err := logger.New(config.LogConfig{Level: "error"})
|
||||
if err != nil {
|
||||
t.Fatalf("logger.New returned error: %v", err)
|
||||
}
|
||||
db, err := database.Open(config.DatabaseConfig{Path: filepath.Join(t.TempDir(), "backupx.db")}, log)
|
||||
if err != nil {
|
||||
t.Fatalf("database.Open returned error: %v", err)
|
||||
}
|
||||
storageTarget := &model.StorageTarget{Name: "local", Type: "local_disk", Enabled: true, ConfigCiphertext: "{}", ConfigVersion: 1, LastTestStatus: "unknown"}
|
||||
if err := db.Create(storageTarget).Error; err != nil {
|
||||
t.Fatalf("seed storage target error: %v", err)
|
||||
}
|
||||
task := &model.BackupTask{Name: "website", Type: "file", Enabled: true, SourcePath: "/srv/www/site", StorageTargetID: storageTarget.ID, RetentionDays: 30, Compression: "gzip", MaxBackups: 10, LastStatus: "idle"}
|
||||
if err := db.Create(task).Error; err != nil {
|
||||
t.Fatalf("seed backup task error: %v", err)
|
||||
}
|
||||
return NewBackupRecordRepository(db)
|
||||
}
|
||||
|
||||
func TestBackupRecordRepositoryQueries(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
repo := newBackupRecordTestRepository(t)
|
||||
now := time.Now().UTC()
|
||||
completedAt := now.Add(2 * time.Minute)
|
||||
record := &model.BackupRecord{
|
||||
TaskID: 1,
|
||||
StorageTargetID: 1,
|
||||
Status: "success",
|
||||
FileName: "website.tar.gz",
|
||||
FileSize: 1024,
|
||||
StoragePath: "tasks/1/website.tar.gz",
|
||||
DurationSeconds: 120,
|
||||
LogContent: "done",
|
||||
StartedAt: now,
|
||||
CompletedAt: &completedAt,
|
||||
}
|
||||
if err := repo.Create(ctx, record); err != nil {
|
||||
t.Fatalf("Create returned error: %v", err)
|
||||
}
|
||||
stored, err := repo.FindByID(ctx, record.ID)
|
||||
if err != nil {
|
||||
t.Fatalf("FindByID returned error: %v", err)
|
||||
}
|
||||
if stored == nil || stored.FileName != "website.tar.gz" {
|
||||
t.Fatalf("unexpected stored record: %#v", stored)
|
||||
}
|
||||
listed, err := repo.List(ctx, BackupRecordListOptions{TaskID: &record.TaskID, Status: "success"})
|
||||
if err != nil {
|
||||
t.Fatalf("List returned error: %v", err)
|
||||
}
|
||||
if len(listed) != 1 {
|
||||
t.Fatalf("expected one listed record, got %d", len(listed))
|
||||
}
|
||||
recent, err := repo.ListRecent(ctx, 5)
|
||||
if err != nil {
|
||||
t.Fatalf("ListRecent returned error: %v", err)
|
||||
}
|
||||
if len(recent) != 1 {
|
||||
t.Fatalf("expected one recent record, got %d", len(recent))
|
||||
}
|
||||
total, err := repo.Count(ctx)
|
||||
if err != nil {
|
||||
t.Fatalf("Count returned error: %v", err)
|
||||
}
|
||||
if total != 1 {
|
||||
t.Fatalf("expected total count 1, got %d", total)
|
||||
}
|
||||
successCount, err := repo.CountSuccessSince(ctx, now.Add(-time.Hour))
|
||||
if err != nil {
|
||||
t.Fatalf("CountSuccessSince returned error: %v", err)
|
||||
}
|
||||
if successCount != 1 {
|
||||
t.Fatalf("expected success count 1, got %d", successCount)
|
||||
}
|
||||
sum, err := repo.SumFileSize(ctx)
|
||||
if err != nil {
|
||||
t.Fatalf("SumFileSize returned error: %v", err)
|
||||
}
|
||||
if sum != 1024 {
|
||||
t.Fatalf("expected file size sum 1024, got %d", sum)
|
||||
}
|
||||
timeline, err := repo.TimelineSince(ctx, now.Add(-time.Hour))
|
||||
if err != nil {
|
||||
t.Fatalf("TimelineSince returned error: %v", err)
|
||||
}
|
||||
if len(timeline) != 1 || timeline[0].Success != 1 {
|
||||
t.Fatalf("unexpected timeline: %#v", timeline)
|
||||
}
|
||||
usage, err := repo.StorageUsage(ctx)
|
||||
if err != nil {
|
||||
t.Fatalf("StorageUsage returned error: %v", err)
|
||||
}
|
||||
if len(usage) != 1 || usage[0].TotalSize != 1024 {
|
||||
t.Fatalf("unexpected usage: %#v", usage)
|
||||
}
|
||||
if err := repo.Delete(ctx, record.ID); err != nil {
|
||||
t.Fatalf("Delete returned error: %v", err)
|
||||
}
|
||||
}
|
||||
116
server/internal/repository/backup_task_repository.go
Normal file
116
server/internal/repository/backup_task_repository.go
Normal file
@@ -0,0 +1,116 @@
|
||||
package repository
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
|
||||
"backupx/server/internal/model"
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
|
||||
type BackupTaskListOptions struct {
|
||||
Type string
|
||||
Enabled *bool
|
||||
}
|
||||
|
||||
type BackupTaskRepository interface {
|
||||
List(context.Context, BackupTaskListOptions) ([]model.BackupTask, error)
|
||||
FindByID(context.Context, uint) (*model.BackupTask, error)
|
||||
FindByName(context.Context, string) (*model.BackupTask, error)
|
||||
ListSchedulable(context.Context) ([]model.BackupTask, error)
|
||||
Count(context.Context) (int64, error)
|
||||
CountEnabled(context.Context) (int64, error)
|
||||
CountByStorageTargetID(context.Context, uint) (int64, error)
|
||||
Create(context.Context, *model.BackupTask) error
|
||||
Update(context.Context, *model.BackupTask) error
|
||||
Delete(context.Context, uint) error
|
||||
}
|
||||
|
||||
type GormBackupTaskRepository struct {
|
||||
db *gorm.DB
|
||||
}
|
||||
|
||||
func NewBackupTaskRepository(db *gorm.DB) *GormBackupTaskRepository {
|
||||
return &GormBackupTaskRepository{db: db}
|
||||
}
|
||||
|
||||
func (r *GormBackupTaskRepository) List(ctx context.Context, options BackupTaskListOptions) ([]model.BackupTask, error) {
|
||||
query := r.db.WithContext(ctx).Model(&model.BackupTask{}).Preload("StorageTarget").Order("updated_at desc")
|
||||
if options.Type != "" {
|
||||
query = query.Where("type = ?", options.Type)
|
||||
}
|
||||
if options.Enabled != nil {
|
||||
query = query.Where("enabled = ?", *options.Enabled)
|
||||
}
|
||||
var items []model.BackupTask
|
||||
if err := query.Find(&items).Error; err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return items, nil
|
||||
}
|
||||
|
||||
func (r *GormBackupTaskRepository) FindByID(ctx context.Context, id uint) (*model.BackupTask, error) {
|
||||
var item model.BackupTask
|
||||
if err := r.db.WithContext(ctx).Preload("StorageTarget").First(&item, id).Error; err != nil {
|
||||
if errors.Is(err, gorm.ErrRecordNotFound) {
|
||||
return nil, nil
|
||||
}
|
||||
return nil, err
|
||||
}
|
||||
return &item, nil
|
||||
}
|
||||
|
||||
func (r *GormBackupTaskRepository) FindByName(ctx context.Context, name string) (*model.BackupTask, error) {
|
||||
var item model.BackupTask
|
||||
if err := r.db.WithContext(ctx).Where("name = ?", name).First(&item).Error; err != nil {
|
||||
if errors.Is(err, gorm.ErrRecordNotFound) {
|
||||
return nil, nil
|
||||
}
|
||||
return nil, err
|
||||
}
|
||||
return &item, nil
|
||||
}
|
||||
|
||||
func (r *GormBackupTaskRepository) ListSchedulable(ctx context.Context) ([]model.BackupTask, error) {
|
||||
var items []model.BackupTask
|
||||
if err := r.db.WithContext(ctx).Preload("StorageTarget").Where("enabled = ? AND cron_expr <> ''", true).Order("id asc").Find(&items).Error; err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return items, nil
|
||||
}
|
||||
|
||||
func (r *GormBackupTaskRepository) Count(ctx context.Context) (int64, error) {
|
||||
var count int64
|
||||
if err := r.db.WithContext(ctx).Model(&model.BackupTask{}).Count(&count).Error; err != nil {
|
||||
return 0, err
|
||||
}
|
||||
return count, nil
|
||||
}
|
||||
|
||||
func (r *GormBackupTaskRepository) CountEnabled(ctx context.Context) (int64, error) {
|
||||
var count int64
|
||||
if err := r.db.WithContext(ctx).Model(&model.BackupTask{}).Where("enabled = ?", true).Count(&count).Error; err != nil {
|
||||
return 0, err
|
||||
}
|
||||
return count, nil
|
||||
}
|
||||
|
||||
func (r *GormBackupTaskRepository) CountByStorageTargetID(ctx context.Context, storageTargetID uint) (int64, error) {
|
||||
var count int64
|
||||
if err := r.db.WithContext(ctx).Model(&model.BackupTask{}).Where("storage_target_id = ?", storageTargetID).Count(&count).Error; err != nil {
|
||||
return 0, err
|
||||
}
|
||||
return count, nil
|
||||
}
|
||||
|
||||
func (r *GormBackupTaskRepository) Create(ctx context.Context, item *model.BackupTask) error {
|
||||
return r.db.WithContext(ctx).Create(item).Error
|
||||
}
|
||||
|
||||
func (r *GormBackupTaskRepository) Update(ctx context.Context, item *model.BackupTask) error {
|
||||
return r.db.WithContext(ctx).Save(item).Error
|
||||
}
|
||||
|
||||
func (r *GormBackupTaskRepository) Delete(ctx context.Context, id uint) error {
|
||||
return r.db.WithContext(ctx).Delete(&model.BackupTask{}, id).Error
|
||||
}
|
||||
94
server/internal/repository/backup_task_repository_test.go
Normal file
94
server/internal/repository/backup_task_repository_test.go
Normal file
@@ -0,0 +1,94 @@
|
||||
package repository
|
||||
|
||||
import (
|
||||
"context"
|
||||
"path/filepath"
|
||||
"testing"
|
||||
|
||||
"backupx/server/internal/config"
|
||||
"backupx/server/internal/database"
|
||||
"backupx/server/internal/logger"
|
||||
"backupx/server/internal/model"
|
||||
)
|
||||
|
||||
func newBackupTaskTestRepository(t *testing.T) *GormBackupTaskRepository {
|
||||
t.Helper()
|
||||
log, err := logger.New(config.LogConfig{Level: "error"})
|
||||
if err != nil {
|
||||
t.Fatalf("logger.New returned error: %v", err)
|
||||
}
|
||||
db, err := database.Open(config.DatabaseConfig{Path: filepath.Join(t.TempDir(), "backupx.db")}, log)
|
||||
if err != nil {
|
||||
t.Fatalf("database.Open returned error: %v", err)
|
||||
}
|
||||
if err := db.Create(&model.StorageTarget{Name: "local", Type: "local_disk", Enabled: true, ConfigCiphertext: "{}", ConfigVersion: 1, LastTestStatus: "unknown"}).Error; err != nil {
|
||||
t.Fatalf("seed storage target error: %v", err)
|
||||
}
|
||||
return NewBackupTaskRepository(db)
|
||||
}
|
||||
|
||||
func TestBackupTaskRepositoryCRUD(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
repo := newBackupTaskTestRepository(t)
|
||||
task := &model.BackupTask{
|
||||
Name: "website",
|
||||
Type: "file",
|
||||
Enabled: true,
|
||||
SourcePath: "/srv/www/site",
|
||||
StorageTargetID: 1,
|
||||
RetentionDays: 30,
|
||||
Compression: "gzip",
|
||||
MaxBackups: 10,
|
||||
LastStatus: "idle",
|
||||
}
|
||||
if err := repo.Create(ctx, task); err != nil {
|
||||
t.Fatalf("Create returned error: %v", err)
|
||||
}
|
||||
stored, err := repo.FindByID(ctx, task.ID)
|
||||
if err != nil {
|
||||
t.Fatalf("FindByID returned error: %v", err)
|
||||
}
|
||||
if stored == nil || stored.Name != "website" {
|
||||
t.Fatalf("unexpected stored task: %#v", stored)
|
||||
}
|
||||
stored.Enabled = false
|
||||
stored.CronExpr = "0 3 * * *"
|
||||
if err := repo.Update(ctx, stored); err != nil {
|
||||
t.Fatalf("Update returned error: %v", err)
|
||||
}
|
||||
schedulable, err := repo.ListSchedulable(ctx)
|
||||
if err != nil {
|
||||
t.Fatalf("ListSchedulable returned error: %v", err)
|
||||
}
|
||||
if len(schedulable) != 0 {
|
||||
t.Fatalf("expected disabled task not schedulable, got %d", len(schedulable))
|
||||
}
|
||||
stored.Enabled = true
|
||||
if err := repo.Update(ctx, stored); err != nil {
|
||||
t.Fatalf("Update returned error: %v", err)
|
||||
}
|
||||
schedulable, err = repo.ListSchedulable(ctx)
|
||||
if err != nil {
|
||||
t.Fatalf("ListSchedulable returned error: %v", err)
|
||||
}
|
||||
if len(schedulable) != 1 {
|
||||
t.Fatalf("expected one schedulable task, got %d", len(schedulable))
|
||||
}
|
||||
count, err := repo.CountByStorageTargetID(ctx, 1)
|
||||
if err != nil {
|
||||
t.Fatalf("CountByStorageTargetID returned error: %v", err)
|
||||
}
|
||||
if count != 1 {
|
||||
t.Fatalf("expected referenced task count 1, got %d", count)
|
||||
}
|
||||
if err := repo.Delete(ctx, task.ID); err != nil {
|
||||
t.Fatalf("Delete returned error: %v", err)
|
||||
}
|
||||
deleted, err := repo.FindByID(ctx, task.ID)
|
||||
if err != nil {
|
||||
t.Fatalf("FindByID after delete returned error: %v", err)
|
||||
}
|
||||
if deleted != nil {
|
||||
t.Fatalf("expected task deleted, got %#v", deleted)
|
||||
}
|
||||
}
|
||||
80
server/internal/repository/node_repository.go
Normal file
80
server/internal/repository/node_repository.go
Normal file
@@ -0,0 +1,80 @@
|
||||
package repository
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
|
||||
"backupx/server/internal/model"
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
|
||||
type NodeRepository interface {
|
||||
List(context.Context) ([]model.Node, error)
|
||||
FindByID(context.Context, uint) (*model.Node, error)
|
||||
FindByToken(context.Context, string) (*model.Node, error)
|
||||
FindLocal(context.Context) (*model.Node, error)
|
||||
Create(context.Context, *model.Node) error
|
||||
Update(context.Context, *model.Node) error
|
||||
Delete(context.Context, uint) error
|
||||
}
|
||||
|
||||
type GormNodeRepository struct {
|
||||
db *gorm.DB
|
||||
}
|
||||
|
||||
func NewNodeRepository(db *gorm.DB) *GormNodeRepository {
|
||||
return &GormNodeRepository{db: db}
|
||||
}
|
||||
|
||||
func (r *GormNodeRepository) List(ctx context.Context) ([]model.Node, error) {
|
||||
var items []model.Node
|
||||
if err := r.db.WithContext(ctx).Order("is_local desc, updated_at desc").Find(&items).Error; err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return items, nil
|
||||
}
|
||||
|
||||
func (r *GormNodeRepository) FindByID(ctx context.Context, id uint) (*model.Node, error) {
|
||||
var item model.Node
|
||||
if err := r.db.WithContext(ctx).First(&item, id).Error; err != nil {
|
||||
if errors.Is(err, gorm.ErrRecordNotFound) {
|
||||
return nil, nil
|
||||
}
|
||||
return nil, err
|
||||
}
|
||||
return &item, nil
|
||||
}
|
||||
|
||||
func (r *GormNodeRepository) FindByToken(ctx context.Context, token string) (*model.Node, error) {
|
||||
var item model.Node
|
||||
if err := r.db.WithContext(ctx).Where("token = ?", token).First(&item).Error; err != nil {
|
||||
if errors.Is(err, gorm.ErrRecordNotFound) {
|
||||
return nil, nil
|
||||
}
|
||||
return nil, err
|
||||
}
|
||||
return &item, nil
|
||||
}
|
||||
|
||||
func (r *GormNodeRepository) FindLocal(ctx context.Context) (*model.Node, error) {
|
||||
var item model.Node
|
||||
if err := r.db.WithContext(ctx).Where("is_local = ?", true).First(&item).Error; err != nil {
|
||||
if errors.Is(err, gorm.ErrRecordNotFound) {
|
||||
return nil, nil
|
||||
}
|
||||
return nil, err
|
||||
}
|
||||
return &item, nil
|
||||
}
|
||||
|
||||
func (r *GormNodeRepository) Create(ctx context.Context, item *model.Node) error {
|
||||
return r.db.WithContext(ctx).Create(item).Error
|
||||
}
|
||||
|
||||
func (r *GormNodeRepository) Update(ctx context.Context, item *model.Node) error {
|
||||
return r.db.WithContext(ctx).Save(item).Error
|
||||
}
|
||||
|
||||
func (r *GormNodeRepository) Delete(ctx context.Context, id uint) error {
|
||||
return r.db.WithContext(ctx).Delete(&model.Node{}, id).Error
|
||||
}
|
||||
83
server/internal/repository/notification_repository.go
Normal file
83
server/internal/repository/notification_repository.go
Normal file
@@ -0,0 +1,83 @@
|
||||
package repository
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
|
||||
"backupx/server/internal/model"
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
|
||||
type NotificationRepository interface {
|
||||
List(context.Context) ([]model.Notification, error)
|
||||
ListEnabledForEvent(context.Context, bool) ([]model.Notification, error)
|
||||
FindByID(context.Context, uint) (*model.Notification, error)
|
||||
FindByName(context.Context, string) (*model.Notification, error)
|
||||
Create(context.Context, *model.Notification) error
|
||||
Update(context.Context, *model.Notification) error
|
||||
Delete(context.Context, uint) error
|
||||
}
|
||||
|
||||
type GormNotificationRepository struct {
|
||||
db *gorm.DB
|
||||
}
|
||||
|
||||
func NewNotificationRepository(db *gorm.DB) *GormNotificationRepository {
|
||||
return &GormNotificationRepository{db: db}
|
||||
}
|
||||
|
||||
func (r *GormNotificationRepository) List(ctx context.Context) ([]model.Notification, error) {
|
||||
var items []model.Notification
|
||||
if err := r.db.WithContext(ctx).Order("updated_at desc").Find(&items).Error; err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return items, nil
|
||||
}
|
||||
|
||||
func (r *GormNotificationRepository) ListEnabledForEvent(ctx context.Context, success bool) ([]model.Notification, error) {
|
||||
query := r.db.WithContext(ctx).Model(&model.Notification{}).Where("enabled = ?", true)
|
||||
if success {
|
||||
query = query.Where("on_success = ?", true)
|
||||
} else {
|
||||
query = query.Where("on_failure = ?", true)
|
||||
}
|
||||
var items []model.Notification
|
||||
if err := query.Order("updated_at desc").Find(&items).Error; err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return items, nil
|
||||
}
|
||||
|
||||
func (r *GormNotificationRepository) FindByID(ctx context.Context, id uint) (*model.Notification, error) {
|
||||
var item model.Notification
|
||||
if err := r.db.WithContext(ctx).First(&item, id).Error; err != nil {
|
||||
if errors.Is(err, gorm.ErrRecordNotFound) {
|
||||
return nil, nil
|
||||
}
|
||||
return nil, err
|
||||
}
|
||||
return &item, nil
|
||||
}
|
||||
|
||||
func (r *GormNotificationRepository) FindByName(ctx context.Context, name string) (*model.Notification, error) {
|
||||
var item model.Notification
|
||||
if err := r.db.WithContext(ctx).Where("name = ?", name).First(&item).Error; err != nil {
|
||||
if errors.Is(err, gorm.ErrRecordNotFound) {
|
||||
return nil, nil
|
||||
}
|
||||
return nil, err
|
||||
}
|
||||
return &item, nil
|
||||
}
|
||||
|
||||
func (r *GormNotificationRepository) Create(ctx context.Context, item *model.Notification) error {
|
||||
return r.db.WithContext(ctx).Create(item).Error
|
||||
}
|
||||
|
||||
func (r *GormNotificationRepository) Update(ctx context.Context, item *model.Notification) error {
|
||||
return r.db.WithContext(ctx).Save(item).Error
|
||||
}
|
||||
|
||||
func (r *GormNotificationRepository) Delete(ctx context.Context, id uint) error {
|
||||
return r.db.WithContext(ctx).Delete(&model.Notification{}, id).Error
|
||||
}
|
||||
69
server/internal/repository/notification_repository_test.go
Normal file
69
server/internal/repository/notification_repository_test.go
Normal file
@@ -0,0 +1,69 @@
|
||||
package repository
|
||||
|
||||
import (
|
||||
"context"
|
||||
"path/filepath"
|
||||
"testing"
|
||||
|
||||
"backupx/server/internal/config"
|
||||
"backupx/server/internal/database"
|
||||
"backupx/server/internal/logger"
|
||||
"backupx/server/internal/model"
|
||||
)
|
||||
|
||||
func newNotificationTestRepository(t *testing.T) *GormNotificationRepository {
|
||||
t.Helper()
|
||||
log, err := logger.New(config.LogConfig{Level: "error"})
|
||||
if err != nil {
|
||||
t.Fatalf("logger.New returned error: %v", err)
|
||||
}
|
||||
db, err := database.Open(config.DatabaseConfig{Path: filepath.Join(t.TempDir(), "backupx.db")}, log)
|
||||
if err != nil {
|
||||
t.Fatalf("database.Open returned error: %v", err)
|
||||
}
|
||||
return NewNotificationRepository(db)
|
||||
}
|
||||
|
||||
func TestNotificationRepositoryCRUD(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
repo := newNotificationTestRepository(t)
|
||||
item := &model.Notification{
|
||||
Type: "webhook",
|
||||
Name: "ops-webhook",
|
||||
ConfigCiphertext: "ciphertext",
|
||||
Enabled: true,
|
||||
OnSuccess: false,
|
||||
OnFailure: true,
|
||||
}
|
||||
if err := repo.Create(ctx, item); err != nil {
|
||||
t.Fatalf("Create returned error: %v", err)
|
||||
}
|
||||
stored, err := repo.FindByName(ctx, "ops-webhook")
|
||||
if err != nil {
|
||||
t.Fatalf("FindByName returned error: %v", err)
|
||||
}
|
||||
if stored == nil || stored.Name != "ops-webhook" {
|
||||
t.Fatalf("unexpected notification: %#v", stored)
|
||||
}
|
||||
enabledForFailure, err := repo.ListEnabledForEvent(ctx, false)
|
||||
if err != nil {
|
||||
t.Fatalf("ListEnabledForEvent returned error: %v", err)
|
||||
}
|
||||
if len(enabledForFailure) != 1 {
|
||||
t.Fatalf("expected one failure notification, got %d", len(enabledForFailure))
|
||||
}
|
||||
stored.OnSuccess = true
|
||||
if err := repo.Update(ctx, stored); err != nil {
|
||||
t.Fatalf("Update returned error: %v", err)
|
||||
}
|
||||
enabledForSuccess, err := repo.ListEnabledForEvent(ctx, true)
|
||||
if err != nil {
|
||||
t.Fatalf("ListEnabledForEvent returned error: %v", err)
|
||||
}
|
||||
if len(enabledForSuccess) != 1 {
|
||||
t.Fatalf("expected one success notification, got %d", len(enabledForSuccess))
|
||||
}
|
||||
if err := repo.Delete(ctx, item.ID); err != nil {
|
||||
t.Fatalf("Delete returned error: %v", err)
|
||||
}
|
||||
}
|
||||
48
server/internal/repository/oauth_session_repository.go
Normal file
48
server/internal/repository/oauth_session_repository.go
Normal file
@@ -0,0 +1,48 @@
|
||||
package repository
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"time"
|
||||
|
||||
"backupx/server/internal/model"
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
|
||||
type OAuthSessionRepository interface {
|
||||
Create(context.Context, *model.OAuthSession) error
|
||||
Update(context.Context, *model.OAuthSession) error
|
||||
FindByState(context.Context, string) (*model.OAuthSession, error)
|
||||
DeleteExpired(context.Context, time.Time) error
|
||||
}
|
||||
|
||||
type GormOAuthSessionRepository struct {
|
||||
db *gorm.DB
|
||||
}
|
||||
|
||||
func NewOAuthSessionRepository(db *gorm.DB) *GormOAuthSessionRepository {
|
||||
return &GormOAuthSessionRepository{db: db}
|
||||
}
|
||||
|
||||
func (r *GormOAuthSessionRepository) Create(ctx context.Context, item *model.OAuthSession) error {
|
||||
return r.db.WithContext(ctx).Create(item).Error
|
||||
}
|
||||
|
||||
func (r *GormOAuthSessionRepository) Update(ctx context.Context, item *model.OAuthSession) error {
|
||||
return r.db.WithContext(ctx).Save(item).Error
|
||||
}
|
||||
|
||||
func (r *GormOAuthSessionRepository) FindByState(ctx context.Context, state string) (*model.OAuthSession, error) {
|
||||
var item model.OAuthSession
|
||||
if err := r.db.WithContext(ctx).Where("state = ?", state).First(&item).Error; err != nil {
|
||||
if errors.Is(err, gorm.ErrRecordNotFound) {
|
||||
return nil, nil
|
||||
}
|
||||
return nil, err
|
||||
}
|
||||
return &item, nil
|
||||
}
|
||||
|
||||
func (r *GormOAuthSessionRepository) DeleteExpired(ctx context.Context, before time.Time) error {
|
||||
return r.db.WithContext(ctx).Where("expires_at <= ?", before).Delete(&model.OAuthSession{}).Error
|
||||
}
|
||||
73
server/internal/repository/oauth_session_repository_test.go
Normal file
73
server/internal/repository/oauth_session_repository_test.go
Normal file
@@ -0,0 +1,73 @@
|
||||
package repository
|
||||
|
||||
import (
|
||||
"context"
|
||||
"path/filepath"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"backupx/server/internal/config"
|
||||
"backupx/server/internal/database"
|
||||
"backupx/server/internal/logger"
|
||||
"backupx/server/internal/model"
|
||||
)
|
||||
|
||||
func newOAuthSessionTestRepository(t *testing.T) *GormOAuthSessionRepository {
|
||||
t.Helper()
|
||||
log, err := logger.New(config.LogConfig{Level: "error"})
|
||||
if err != nil {
|
||||
t.Fatalf("logger.New returned error: %v", err)
|
||||
}
|
||||
db, err := database.Open(config.DatabaseConfig{Path: filepath.Join(t.TempDir(), "backupx.db")}, log)
|
||||
if err != nil {
|
||||
t.Fatalf("database.Open returned error: %v", err)
|
||||
}
|
||||
return NewOAuthSessionRepository(db)
|
||||
}
|
||||
|
||||
func TestOAuthSessionRepositoryCRUDAndDeleteExpired(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
repo := newOAuthSessionTestRepository(t)
|
||||
expiresAt := time.Now().UTC().Add(5 * time.Minute)
|
||||
session := &model.OAuthSession{
|
||||
ProviderType: "google_drive",
|
||||
State: "oauth-state",
|
||||
PayloadCiphertext: "ciphertext",
|
||||
ExpiresAt: expiresAt,
|
||||
}
|
||||
if err := repo.Create(ctx, session); err != nil {
|
||||
t.Fatalf("Create returned error: %v", err)
|
||||
}
|
||||
stored, err := repo.FindByState(ctx, "oauth-state")
|
||||
if err != nil {
|
||||
t.Fatalf("FindByState returned error: %v", err)
|
||||
}
|
||||
if stored == nil || stored.State != "oauth-state" {
|
||||
t.Fatalf("unexpected stored session: %#v", stored)
|
||||
}
|
||||
now := time.Now().UTC()
|
||||
stored.UsedAt = &now
|
||||
if err := repo.Update(ctx, stored); err != nil {
|
||||
t.Fatalf("Update returned error: %v", err)
|
||||
}
|
||||
if err := repo.DeleteExpired(ctx, time.Now().UTC().Add(-time.Minute)); err != nil {
|
||||
t.Fatalf("DeleteExpired returned error: %v", err)
|
||||
}
|
||||
stillThere, err := repo.FindByState(ctx, "oauth-state")
|
||||
if err != nil {
|
||||
t.Fatalf("FindByState after DeleteExpired returned error: %v", err)
|
||||
}
|
||||
if stillThere == nil {
|
||||
t.Fatalf("expected unexpired session to remain")
|
||||
}
|
||||
if err := repo.DeleteExpired(ctx, time.Now().UTC().Add(10*time.Minute)); err != nil {
|
||||
t.Fatalf("DeleteExpired returned error: %v", err)
|
||||
}
|
||||
deleted, err := repo.FindByState(ctx, "oauth-state")
|
||||
if err != nil {
|
||||
t.Fatalf("FindByState after expiration delete returned error: %v", err)
|
||||
}
|
||||
if deleted != nil {
|
||||
t.Fatalf("expected session to be deleted, got %#v", deleted)
|
||||
}
|
||||
}
|
||||
68
server/internal/repository/storage_target_repository.go
Normal file
68
server/internal/repository/storage_target_repository.go
Normal file
@@ -0,0 +1,68 @@
|
||||
package repository
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
|
||||
"backupx/server/internal/model"
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
|
||||
type StorageTargetRepository interface {
|
||||
List(context.Context) ([]model.StorageTarget, error)
|
||||
FindByID(context.Context, uint) (*model.StorageTarget, error)
|
||||
FindByName(context.Context, string) (*model.StorageTarget, error)
|
||||
Create(context.Context, *model.StorageTarget) error
|
||||
Update(context.Context, *model.StorageTarget) error
|
||||
Delete(context.Context, uint) error
|
||||
}
|
||||
|
||||
type GormStorageTargetRepository struct {
|
||||
db *gorm.DB
|
||||
}
|
||||
|
||||
func NewStorageTargetRepository(db *gorm.DB) *GormStorageTargetRepository {
|
||||
return &GormStorageTargetRepository{db: db}
|
||||
}
|
||||
|
||||
func (r *GormStorageTargetRepository) List(ctx context.Context) ([]model.StorageTarget, error) {
|
||||
var items []model.StorageTarget
|
||||
if err := r.db.WithContext(ctx).Order("updated_at desc").Find(&items).Error; err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return items, nil
|
||||
}
|
||||
|
||||
func (r *GormStorageTargetRepository) FindByID(ctx context.Context, id uint) (*model.StorageTarget, error) {
|
||||
var item model.StorageTarget
|
||||
if err := r.db.WithContext(ctx).First(&item, id).Error; err != nil {
|
||||
if errors.Is(err, gorm.ErrRecordNotFound) {
|
||||
return nil, nil
|
||||
}
|
||||
return nil, err
|
||||
}
|
||||
return &item, nil
|
||||
}
|
||||
|
||||
func (r *GormStorageTargetRepository) FindByName(ctx context.Context, name string) (*model.StorageTarget, error) {
|
||||
var item model.StorageTarget
|
||||
if err := r.db.WithContext(ctx).Where("name = ?", name).First(&item).Error; err != nil {
|
||||
if errors.Is(err, gorm.ErrRecordNotFound) {
|
||||
return nil, nil
|
||||
}
|
||||
return nil, err
|
||||
}
|
||||
return &item, nil
|
||||
}
|
||||
|
||||
func (r *GormStorageTargetRepository) Create(ctx context.Context, item *model.StorageTarget) error {
|
||||
return r.db.WithContext(ctx).Create(item).Error
|
||||
}
|
||||
|
||||
func (r *GormStorageTargetRepository) Update(ctx context.Context, item *model.StorageTarget) error {
|
||||
return r.db.WithContext(ctx).Save(item).Error
|
||||
}
|
||||
|
||||
func (r *GormStorageTargetRepository) Delete(ctx context.Context, id uint) error {
|
||||
return r.db.WithContext(ctx).Delete(&model.StorageTarget{}, id).Error
|
||||
}
|
||||
81
server/internal/repository/storage_target_repository_test.go
Normal file
81
server/internal/repository/storage_target_repository_test.go
Normal file
@@ -0,0 +1,81 @@
|
||||
package repository
|
||||
|
||||
import (
|
||||
"context"
|
||||
"path/filepath"
|
||||
"testing"
|
||||
|
||||
"backupx/server/internal/config"
|
||||
"backupx/server/internal/database"
|
||||
"backupx/server/internal/logger"
|
||||
"backupx/server/internal/model"
|
||||
)
|
||||
|
||||
func openTestDB(t *testing.T) context.Context {
|
||||
t.Helper()
|
||||
return context.Background()
|
||||
}
|
||||
|
||||
func newStorageTestRepository(t *testing.T) *GormStorageTargetRepository {
|
||||
t.Helper()
|
||||
log, err := logger.New(config.LogConfig{Level: "error"})
|
||||
if err != nil {
|
||||
t.Fatalf("logger.New returned error: %v", err)
|
||||
}
|
||||
db, err := database.Open(config.DatabaseConfig{Path: filepath.Join(t.TempDir(), "backupx.db")}, log)
|
||||
if err != nil {
|
||||
t.Fatalf("database.Open returned error: %v", err)
|
||||
}
|
||||
return NewStorageTargetRepository(db)
|
||||
}
|
||||
|
||||
func TestStorageTargetRepositoryCRUD(t *testing.T) {
|
||||
ctx := openTestDB(t)
|
||||
repo := newStorageTestRepository(t)
|
||||
item := &model.StorageTarget{
|
||||
Name: "local",
|
||||
Type: "local_disk",
|
||||
Enabled: true,
|
||||
ConfigCiphertext: "ciphertext",
|
||||
ConfigVersion: 1,
|
||||
LastTestStatus: "unknown",
|
||||
}
|
||||
if err := repo.Create(ctx, item); err != nil {
|
||||
t.Fatalf("Create returned error: %v", err)
|
||||
}
|
||||
stored, err := repo.FindByID(ctx, item.ID)
|
||||
if err != nil {
|
||||
t.Fatalf("FindByID returned error: %v", err)
|
||||
}
|
||||
if stored == nil || stored.Name != "local" {
|
||||
t.Fatalf("unexpected stored target: %#v", stored)
|
||||
}
|
||||
byName, err := repo.FindByName(ctx, "local")
|
||||
if err != nil {
|
||||
t.Fatalf("FindByName returned error: %v", err)
|
||||
}
|
||||
if byName == nil || byName.ID != item.ID {
|
||||
t.Fatalf("expected target lookup by name to match, got %#v", byName)
|
||||
}
|
||||
stored.Description = "updated"
|
||||
if err := repo.Update(ctx, stored); err != nil {
|
||||
t.Fatalf("Update returned error: %v", err)
|
||||
}
|
||||
items, err := repo.List(ctx)
|
||||
if err != nil {
|
||||
t.Fatalf("List returned error: %v", err)
|
||||
}
|
||||
if len(items) != 1 || items[0].Description != "updated" {
|
||||
t.Fatalf("unexpected list result: %#v", items)
|
||||
}
|
||||
if err := repo.Delete(ctx, item.ID); err != nil {
|
||||
t.Fatalf("Delete returned error: %v", err)
|
||||
}
|
||||
deleted, err := repo.FindByID(ctx, item.ID)
|
||||
if err != nil {
|
||||
t.Fatalf("FindByID after delete returned error: %v", err)
|
||||
}
|
||||
if deleted != nil {
|
||||
t.Fatalf("expected target to be deleted, got %#v", deleted)
|
||||
}
|
||||
}
|
||||
50
server/internal/repository/system_config_repository.go
Normal file
50
server/internal/repository/system_config_repository.go
Normal file
@@ -0,0 +1,50 @@
|
||||
package repository
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
|
||||
"backupx/server/internal/model"
|
||||
"gorm.io/gorm"
|
||||
"gorm.io/gorm/clause"
|
||||
)
|
||||
|
||||
type SystemConfigRepository interface {
|
||||
GetByKey(context.Context, string) (*model.SystemConfig, error)
|
||||
List(context.Context) ([]model.SystemConfig, error)
|
||||
Upsert(context.Context, *model.SystemConfig) error
|
||||
}
|
||||
|
||||
type GormSystemConfigRepository struct {
|
||||
db *gorm.DB
|
||||
}
|
||||
|
||||
func NewSystemConfigRepository(db *gorm.DB) *GormSystemConfigRepository {
|
||||
return &GormSystemConfigRepository{db: db}
|
||||
}
|
||||
|
||||
func (r *GormSystemConfigRepository) GetByKey(ctx context.Context, key string) (*model.SystemConfig, error) {
|
||||
var item model.SystemConfig
|
||||
if err := r.db.WithContext(ctx).Where("key = ?", key).First(&item).Error; err != nil {
|
||||
if errors.Is(err, gorm.ErrRecordNotFound) {
|
||||
return nil, nil
|
||||
}
|
||||
return nil, err
|
||||
}
|
||||
return &item, nil
|
||||
}
|
||||
|
||||
func (r *GormSystemConfigRepository) List(ctx context.Context) ([]model.SystemConfig, error) {
|
||||
var items []model.SystemConfig
|
||||
if err := r.db.WithContext(ctx).Order("key ASC").Find(&items).Error; err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return items, nil
|
||||
}
|
||||
|
||||
func (r *GormSystemConfigRepository) Upsert(ctx context.Context, item *model.SystemConfig) error {
|
||||
return r.db.WithContext(ctx).Clauses(clause.OnConflict{
|
||||
Columns: []clause.Column{{Name: "key"}},
|
||||
DoUpdates: clause.AssignmentColumns([]string{"value", "encrypted", "updated_at"}),
|
||||
}).Create(item).Error
|
||||
}
|
||||
63
server/internal/repository/user_repository.go
Normal file
63
server/internal/repository/user_repository.go
Normal file
@@ -0,0 +1,63 @@
|
||||
package repository
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
|
||||
"backupx/server/internal/model"
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
|
||||
type UserRepository interface {
|
||||
Count(context.Context) (int64, error)
|
||||
Create(context.Context, *model.User) error
|
||||
Update(context.Context, *model.User) error
|
||||
FindByUsername(context.Context, string) (*model.User, error)
|
||||
FindByID(context.Context, uint) (*model.User, error)
|
||||
}
|
||||
|
||||
type GormUserRepository struct {
|
||||
db *gorm.DB
|
||||
}
|
||||
|
||||
func NewUserRepository(db *gorm.DB) *GormUserRepository {
|
||||
return &GormUserRepository{db: db}
|
||||
}
|
||||
|
||||
func (r *GormUserRepository) Count(ctx context.Context) (int64, error) {
|
||||
var count int64
|
||||
if err := r.db.WithContext(ctx).Model(&model.User{}).Count(&count).Error; err != nil {
|
||||
return 0, err
|
||||
}
|
||||
return count, nil
|
||||
}
|
||||
|
||||
func (r *GormUserRepository) Create(ctx context.Context, user *model.User) error {
|
||||
return r.db.WithContext(ctx).Create(user).Error
|
||||
}
|
||||
|
||||
func (r *GormUserRepository) Update(ctx context.Context, user *model.User) error {
|
||||
return r.db.WithContext(ctx).Save(user).Error
|
||||
}
|
||||
|
||||
func (r *GormUserRepository) FindByUsername(ctx context.Context, username string) (*model.User, error) {
|
||||
var user model.User
|
||||
if err := r.db.WithContext(ctx).Where("username = ?", username).First(&user).Error; err != nil {
|
||||
if errors.Is(err, gorm.ErrRecordNotFound) {
|
||||
return nil, nil
|
||||
}
|
||||
return nil, err
|
||||
}
|
||||
return &user, nil
|
||||
}
|
||||
|
||||
func (r *GormUserRepository) FindByID(ctx context.Context, id uint) (*model.User, error) {
|
||||
var user model.User
|
||||
if err := r.db.WithContext(ctx).First(&user, id).Error; err != nil {
|
||||
if errors.Is(err, gorm.ErrRecordNotFound) {
|
||||
return nil, nil
|
||||
}
|
||||
return nil, err
|
||||
}
|
||||
return &user, nil
|
||||
}
|
||||
109
server/internal/scheduler/service.go
Normal file
109
server/internal/scheduler/service.go
Normal file
@@ -0,0 +1,109 @@
|
||||
package scheduler
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"backupx/server/internal/model"
|
||||
"backupx/server/internal/repository"
|
||||
servicepkg "backupx/server/internal/service"
|
||||
"github.com/robfig/cron/v3"
|
||||
"go.uber.org/zap"
|
||||
)
|
||||
|
||||
type TaskRunner interface {
|
||||
RunTaskByID(context.Context, uint) (*servicepkg.BackupRecordDetail, error)
|
||||
}
|
||||
|
||||
type Service struct {
|
||||
mu sync.Mutex
|
||||
cron *cron.Cron
|
||||
tasks repository.BackupTaskRepository
|
||||
runner TaskRunner
|
||||
logger *zap.Logger
|
||||
entries map[uint]cron.EntryID
|
||||
}
|
||||
|
||||
func NewService(tasks repository.BackupTaskRepository, runner TaskRunner, logger *zap.Logger) *Service {
|
||||
parser := cron.NewParser(cron.SecondOptional | cron.Minute | cron.Hour | cron.Dom | cron.Month | cron.Dow | cron.Descriptor)
|
||||
return &Service{cron: cron.New(cron.WithParser(parser), cron.WithLocation(time.UTC)), tasks: tasks, runner: runner, logger: logger, entries: make(map[uint]cron.EntryID)}
|
||||
}
|
||||
|
||||
func (s *Service) Start(ctx context.Context) error {
|
||||
if err := s.Reload(ctx); err != nil {
|
||||
return err
|
||||
}
|
||||
s.cron.Start()
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *Service) Stop(ctx context.Context) error {
|
||||
stopCtx := s.cron.Stop()
|
||||
select {
|
||||
case <-stopCtx.Done():
|
||||
return nil
|
||||
case <-ctx.Done():
|
||||
return ctx.Err()
|
||||
}
|
||||
}
|
||||
|
||||
func (s *Service) Reload(ctx context.Context) error {
|
||||
items, err := s.tasks.ListSchedulable(ctx)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
for taskID, entryID := range s.entries {
|
||||
s.cron.Remove(entryID)
|
||||
delete(s.entries, taskID)
|
||||
}
|
||||
for _, item := range items {
|
||||
item := item
|
||||
if err := s.syncTaskLocked(&item); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *Service) SyncTask(_ context.Context, task *model.BackupTask) error {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
return s.syncTaskLocked(task)
|
||||
}
|
||||
|
||||
func (s *Service) RemoveTask(_ context.Context, taskID uint) error {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
if entryID, ok := s.entries[taskID]; ok {
|
||||
s.cron.Remove(entryID)
|
||||
delete(s.entries, taskID)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *Service) syncTaskLocked(task *model.BackupTask) error {
|
||||
if task == nil {
|
||||
return fmt.Errorf("task is required")
|
||||
}
|
||||
if entryID, ok := s.entries[task.ID]; ok {
|
||||
s.cron.Remove(entryID)
|
||||
delete(s.entries, task.ID)
|
||||
}
|
||||
if !task.Enabled || task.CronExpr == "" {
|
||||
return nil
|
||||
}
|
||||
entryID, err := s.cron.AddFunc(task.CronExpr, func() {
|
||||
if _, runErr := s.runner.RunTaskByID(context.Background(), task.ID); runErr != nil && s.logger != nil {
|
||||
s.logger.Warn("scheduled backup run failed", zap.Uint("task_id", task.ID), zap.Error(runErr))
|
||||
}
|
||||
})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
s.entries[task.ID] = entryID
|
||||
return nil
|
||||
}
|
||||
58
server/internal/scheduler/service_test.go
Normal file
58
server/internal/scheduler/service_test.go
Normal file
@@ -0,0 +1,58 @@
|
||||
package scheduler
|
||||
|
||||
import (
|
||||
"backupx/server/internal/repository"
|
||||
servicepkg "backupx/server/internal/service"
|
||||
"context"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"backupx/server/internal/model"
|
||||
)
|
||||
|
||||
type fakeTaskRepository struct {
|
||||
items []model.BackupTask
|
||||
}
|
||||
|
||||
func (r *fakeTaskRepository) List(context.Context, repository.BackupTaskListOptions) ([]model.BackupTask, error) {
|
||||
return nil, nil
|
||||
}
|
||||
func (r *fakeTaskRepository) FindByID(context.Context, uint) (*model.BackupTask, error) {
|
||||
return nil, nil
|
||||
}
|
||||
func (r *fakeTaskRepository) FindByName(context.Context, string) (*model.BackupTask, error) {
|
||||
return nil, nil
|
||||
}
|
||||
func (r *fakeTaskRepository) ListSchedulable(context.Context) ([]model.BackupTask, error) {
|
||||
return r.items, nil
|
||||
}
|
||||
func (r *fakeTaskRepository) Count(context.Context) (int64, error) { return 0, nil }
|
||||
func (r *fakeTaskRepository) CountEnabled(context.Context) (int64, error) { return 0, nil }
|
||||
func (r *fakeTaskRepository) CountByStorageTargetID(context.Context, uint) (int64, error) {
|
||||
return 0, nil
|
||||
}
|
||||
func (r *fakeTaskRepository) Create(context.Context, *model.BackupTask) error { return nil }
|
||||
func (r *fakeTaskRepository) Update(context.Context, *model.BackupTask) error { return nil }
|
||||
func (r *fakeTaskRepository) Delete(context.Context, uint) error { return nil }
|
||||
|
||||
type fakeRunner struct{ taskIDs []uint }
|
||||
|
||||
func (r *fakeRunner) RunTaskByID(_ context.Context, id uint) (*servicepkg.BackupRecordDetail, error) {
|
||||
r.taskIDs = append(r.taskIDs, id)
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
func TestServiceSyncTaskAndTrigger(t *testing.T) {
|
||||
repo := &fakeTaskRepository{}
|
||||
runner := &fakeRunner{}
|
||||
service := NewService(repo, runner, nil)
|
||||
if err := service.SyncTask(context.Background(), &model.BackupTask{ID: 1, Enabled: true, CronExpr: "*/1 * * * * *"}); err != nil {
|
||||
t.Fatalf("SyncTask returned error: %v", err)
|
||||
}
|
||||
service.cron.Start()
|
||||
defer service.cron.Stop()
|
||||
time.Sleep(1100 * time.Millisecond)
|
||||
if len(runner.taskIDs) == 0 {
|
||||
t.Fatalf("expected scheduled runner to be triggered")
|
||||
}
|
||||
}
|
||||
60
server/internal/security/jwt.go
Normal file
60
server/internal/security/jwt.go
Normal file
@@ -0,0 +1,60 @@
|
||||
//go:build ignore
|
||||
|
||||
package security
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
"backupx/server/internal/model"
|
||||
"github.com/golang-jwt/jwt/v5"
|
||||
)
|
||||
|
||||
type Claims struct {
|
||||
UserID uint `json:"userId"`
|
||||
Username string `json:"username"`
|
||||
Role string `json:"role"`
|
||||
jwt.RegisteredClaims
|
||||
}
|
||||
|
||||
type JWTManager struct {
|
||||
secret []byte
|
||||
duration time.Duration
|
||||
}
|
||||
|
||||
func NewJWTManager(secret string, duration time.Duration) *JWTManager {
|
||||
return &JWTManager{secret: []byte(secret), duration: duration}
|
||||
}
|
||||
|
||||
func (m *JWTManager) IssueToken(user *model.User) (string, error) {
|
||||
now := time.Now().UTC()
|
||||
claims := Claims{
|
||||
UserID: user.ID,
|
||||
Username: user.Username,
|
||||
Role: user.Role,
|
||||
RegisteredClaims: jwt.RegisteredClaims{
|
||||
Subject: fmt.Sprintf("%d", user.ID),
|
||||
IssuedAt: jwt.NewNumericDate(now),
|
||||
ExpiresAt: jwt.NewNumericDate(now.Add(m.duration)),
|
||||
},
|
||||
}
|
||||
token := jwt.NewWithClaims(jwt.SigningMethodHS256, claims)
|
||||
return token.SignedString(m.secret)
|
||||
}
|
||||
|
||||
func (m *JWTManager) Parse(tokenValue string) (*Claims, error) {
|
||||
token, err := jwt.ParseWithClaims(tokenValue, &Claims{}, func(token *jwt.Token) (any, error) {
|
||||
if token.Method != jwt.SigningMethodHS256 {
|
||||
return nil, fmt.Errorf("unexpected signing method")
|
||||
}
|
||||
return m.secret, nil
|
||||
})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
claims, ok := token.Claims.(*Claims)
|
||||
if !ok || !token.Valid {
|
||||
return nil, fmt.Errorf("invalid token")
|
||||
}
|
||||
return claims, nil
|
||||
}
|
||||
25
server/internal/security/jwt_test.go
Normal file
25
server/internal/security/jwt_test.go
Normal file
@@ -0,0 +1,25 @@
|
||||
//go:build ignore
|
||||
|
||||
package security
|
||||
|
||||
import (
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"backupx/server/internal/model"
|
||||
)
|
||||
|
||||
func TestJWTManagerIssueAndParse(t *testing.T) {
|
||||
manager := NewJWTManager("test-secret", time.Hour)
|
||||
token, err := manager.IssueToken(&model.User{ID: 7, Username: "admin", Role: "admin"})
|
||||
if err != nil {
|
||||
t.Fatalf("IssueToken() error = %v", err)
|
||||
}
|
||||
claims, err := manager.Parse(token)
|
||||
if err != nil {
|
||||
t.Fatalf("Parse() error = %v", err)
|
||||
}
|
||||
if claims.UserID != 7 || claims.Username != "admin" {
|
||||
t.Fatalf("unexpected claims: %+v", claims)
|
||||
}
|
||||
}
|
||||
54
server/internal/security/limiter.go
Normal file
54
server/internal/security/limiter.go
Normal file
@@ -0,0 +1,54 @@
|
||||
//go:build ignore
|
||||
|
||||
package security
|
||||
|
||||
import (
|
||||
"sync"
|
||||
"time"
|
||||
)
|
||||
|
||||
type limiterEntry struct {
|
||||
Count int
|
||||
ResetAt time.Time
|
||||
}
|
||||
|
||||
type LoginLimiter struct {
|
||||
mu sync.Mutex
|
||||
window time.Duration
|
||||
max int
|
||||
records map[string]limiterEntry
|
||||
}
|
||||
|
||||
func NewLoginLimiter(max int, window time.Duration) *LoginLimiter {
|
||||
return &LoginLimiter{window: window, max: max, records: make(map[string]limiterEntry)}
|
||||
}
|
||||
|
||||
func (l *LoginLimiter) Allow(key string) bool {
|
||||
l.mu.Lock()
|
||||
defer l.mu.Unlock()
|
||||
entry, ok := l.records[key]
|
||||
if !ok || time.Now().After(entry.ResetAt) {
|
||||
delete(l.records, key)
|
||||
return true
|
||||
}
|
||||
return entry.Count < l.max
|
||||
}
|
||||
|
||||
func (l *LoginLimiter) RegisterFailure(key string) {
|
||||
l.mu.Lock()
|
||||
defer l.mu.Unlock()
|
||||
now := time.Now()
|
||||
entry, ok := l.records[key]
|
||||
if !ok || now.After(entry.ResetAt) {
|
||||
l.records[key] = limiterEntry{Count: 1, ResetAt: now.Add(l.window)}
|
||||
return
|
||||
}
|
||||
entry.Count++
|
||||
l.records[key] = entry
|
||||
}
|
||||
|
||||
func (l *LoginLimiter) Reset(key string) {
|
||||
l.mu.Lock()
|
||||
defer l.mu.Unlock()
|
||||
delete(l.records, key)
|
||||
}
|
||||
17
server/internal/security/password.go
Normal file
17
server/internal/security/password.go
Normal file
@@ -0,0 +1,17 @@
|
||||
package security
|
||||
|
||||
import "golang.org/x/crypto/bcrypt"
|
||||
|
||||
const PasswordCost = 12
|
||||
|
||||
func HashPassword(password string) (string, error) {
|
||||
hashed, err := bcrypt.GenerateFromPassword([]byte(password), PasswordCost)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
return string(hashed), nil
|
||||
}
|
||||
|
||||
func ComparePassword(hashedPassword, plainPassword string) error {
|
||||
return bcrypt.CompareHashAndPassword([]byte(hashedPassword), []byte(plainPassword))
|
||||
}
|
||||
16
server/internal/security/password_test.go
Normal file
16
server/internal/security/password_test.go
Normal file
@@ -0,0 +1,16 @@
|
||||
package security
|
||||
|
||||
import "testing"
|
||||
|
||||
func TestHashAndComparePassword(t *testing.T) {
|
||||
hash, err := HashPassword("super-secret-password")
|
||||
if err != nil {
|
||||
t.Fatalf("HashPassword returned error: %v", err)
|
||||
}
|
||||
if hash == "super-secret-password" {
|
||||
t.Fatalf("expected hashed password to differ from plain text")
|
||||
}
|
||||
if err := ComparePassword(hash, "super-secret-password"); err != nil {
|
||||
t.Fatalf("ComparePassword returned error: %v", err)
|
||||
}
|
||||
}
|
||||
50
server/internal/security/rate_limiter.go
Normal file
50
server/internal/security/rate_limiter.go
Normal file
@@ -0,0 +1,50 @@
|
||||
package security
|
||||
|
||||
import (
|
||||
"sync"
|
||||
"time"
|
||||
)
|
||||
|
||||
type rateEntry struct {
|
||||
count int
|
||||
windowEnd time.Time
|
||||
}
|
||||
|
||||
type LoginRateLimiter struct {
|
||||
limit int
|
||||
window time.Duration
|
||||
mu sync.Mutex
|
||||
items map[string]rateEntry
|
||||
}
|
||||
|
||||
func NewLoginRateLimiter(limit int, window time.Duration) *LoginRateLimiter {
|
||||
return &LoginRateLimiter{
|
||||
limit: limit,
|
||||
window: window,
|
||||
items: make(map[string]rateEntry),
|
||||
}
|
||||
}
|
||||
|
||||
func (r *LoginRateLimiter) Allow(key string) bool {
|
||||
now := time.Now().UTC()
|
||||
r.mu.Lock()
|
||||
defer r.mu.Unlock()
|
||||
|
||||
entry, ok := r.items[key]
|
||||
if !ok || now.After(entry.windowEnd) {
|
||||
r.items[key] = rateEntry{count: 0, windowEnd: now.Add(r.window)}
|
||||
entry = r.items[key]
|
||||
}
|
||||
if entry.count >= r.limit {
|
||||
return false
|
||||
}
|
||||
entry.count++
|
||||
r.items[key] = entry
|
||||
return true
|
||||
}
|
||||
|
||||
func (r *LoginRateLimiter) Reset(key string) {
|
||||
r.mu.Lock()
|
||||
defer r.mu.Unlock()
|
||||
delete(r.items, key)
|
||||
}
|
||||
14
server/internal/security/secret.go
Normal file
14
server/internal/security/secret.go
Normal file
@@ -0,0 +1,14 @@
|
||||
package security
|
||||
|
||||
import (
|
||||
"crypto/rand"
|
||||
"encoding/base64"
|
||||
)
|
||||
|
||||
func GenerateSecret(bytesLength int) (string, error) {
|
||||
buffer := make([]byte, bytesLength)
|
||||
if _, err := rand.Read(buffer); err != nil {
|
||||
return "", err
|
||||
}
|
||||
return base64.RawURLEncoding.EncodeToString(buffer), nil
|
||||
}
|
||||
93
server/internal/security/secret_store.go
Normal file
93
server/internal/security/secret_store.go
Normal file
@@ -0,0 +1,93 @@
|
||||
//go:build ignore
|
||||
|
||||
package security
|
||||
|
||||
import (
|
||||
"crypto/rand"
|
||||
"encoding/hex"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"os"
|
||||
"path/filepath"
|
||||
|
||||
"backupx/server/internal/config"
|
||||
)
|
||||
|
||||
type PersistedSecrets struct {
|
||||
JWTSecret string `json:"jwtSecret"`
|
||||
EncryptionKey string `json:"encryptionKey"`
|
||||
}
|
||||
|
||||
func EnsureSecrets(cfg *config.Config) error {
|
||||
if cfg.Security.JWTSecret != "" && cfg.Security.EncryptionKey != "" {
|
||||
return nil
|
||||
}
|
||||
|
||||
storePath := filepath.Join(filepath.Dir(cfg.Database.Path), "backupx.secrets.json")
|
||||
current, err := loadSecrets(storePath)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if current == nil {
|
||||
current = &PersistedSecrets{}
|
||||
}
|
||||
if current.JWTSecret == "" {
|
||||
current.JWTSecret, err = randomHex(32)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
if current.EncryptionKey == "" {
|
||||
current.EncryptionKey, err = randomHex(32)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
if err := saveSecrets(storePath, current); err != nil {
|
||||
return err
|
||||
}
|
||||
if cfg.Security.JWTSecret == "" {
|
||||
cfg.Security.JWTSecret = current.JWTSecret
|
||||
}
|
||||
if cfg.Security.EncryptionKey == "" {
|
||||
cfg.Security.EncryptionKey = current.EncryptionKey
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func loadSecrets(path string) (*PersistedSecrets, error) {
|
||||
content, err := os.ReadFile(path)
|
||||
if err != nil {
|
||||
if os.IsNotExist(err) {
|
||||
return nil, nil
|
||||
}
|
||||
return nil, fmt.Errorf("read secrets: %w", err)
|
||||
}
|
||||
var secrets PersistedSecrets
|
||||
if err := json.Unmarshal(content, &secrets); err != nil {
|
||||
return nil, fmt.Errorf("decode secrets: %w", err)
|
||||
}
|
||||
return &secrets, nil
|
||||
}
|
||||
|
||||
func saveSecrets(path string, secrets *PersistedSecrets) error {
|
||||
if err := os.MkdirAll(filepath.Dir(path), 0o700); err != nil {
|
||||
return fmt.Errorf("create secrets dir: %w", err)
|
||||
}
|
||||
content, err := json.MarshalIndent(secrets, "", " ")
|
||||
if err != nil {
|
||||
return fmt.Errorf("encode secrets: %w", err)
|
||||
}
|
||||
if err := os.WriteFile(path, content, 0o600); err != nil {
|
||||
return fmt.Errorf("write secrets: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func randomHex(size int) (string, error) {
|
||||
bytes := make([]byte, size)
|
||||
if _, err := rand.Read(bytes); err != nil {
|
||||
return "", fmt.Errorf("generate random secret: %w", err)
|
||||
}
|
||||
return hex.EncodeToString(bytes), nil
|
||||
}
|
||||
57
server/internal/security/token.go
Normal file
57
server/internal/security/token.go
Normal file
@@ -0,0 +1,57 @@
|
||||
package security
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"strconv"
|
||||
"time"
|
||||
|
||||
"backupx/server/internal/model"
|
||||
"github.com/golang-jwt/jwt/v5"
|
||||
)
|
||||
|
||||
type Claims struct {
|
||||
Username string `json:"username"`
|
||||
Role string `json:"role"`
|
||||
jwt.RegisteredClaims
|
||||
}
|
||||
|
||||
type JWTManager struct {
|
||||
secret []byte
|
||||
expiry time.Duration
|
||||
}
|
||||
|
||||
func NewJWTManager(secret string, expiry time.Duration) *JWTManager {
|
||||
return &JWTManager{secret: []byte(secret), expiry: expiry}
|
||||
}
|
||||
|
||||
func (m *JWTManager) Generate(user *model.User) (string, error) {
|
||||
now := time.Now().UTC()
|
||||
claims := Claims{
|
||||
Username: user.Username,
|
||||
Role: user.Role,
|
||||
RegisteredClaims: jwt.RegisteredClaims{
|
||||
Subject: strconv.FormatUint(uint64(user.ID), 10),
|
||||
IssuedAt: jwt.NewNumericDate(now),
|
||||
ExpiresAt: jwt.NewNumericDate(now.Add(m.expiry)),
|
||||
},
|
||||
}
|
||||
token := jwt.NewWithClaims(jwt.SigningMethodHS256, claims)
|
||||
return token.SignedString(m.secret)
|
||||
}
|
||||
|
||||
func (m *JWTManager) Parse(tokenString string) (*Claims, error) {
|
||||
token, err := jwt.ParseWithClaims(tokenString, &Claims{}, func(token *jwt.Token) (any, error) {
|
||||
if token.Method != jwt.SigningMethodHS256 {
|
||||
return nil, fmt.Errorf("unexpected signing method: %s", token.Method.Alg())
|
||||
}
|
||||
return m.secret, nil
|
||||
})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
claims, ok := token.Claims.(*Claims)
|
||||
if !ok || !token.Valid {
|
||||
return nil, fmt.Errorf("invalid token claims")
|
||||
}
|
||||
return claims, nil
|
||||
}
|
||||
30
server/internal/security/token_test.go
Normal file
30
server/internal/security/token_test.go
Normal file
@@ -0,0 +1,30 @@
|
||||
package security
|
||||
|
||||
import (
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"backupx/server/internal/model"
|
||||
)
|
||||
|
||||
func TestJWTManagerGenerateAndParse(t *testing.T) {
|
||||
manager := NewJWTManager("test-secret", time.Hour)
|
||||
user := &model.User{ID: 7, Username: "admin", Role: "admin"}
|
||||
|
||||
token, err := manager.Generate(user)
|
||||
if err != nil {
|
||||
t.Fatalf("Generate returned error: %v", err)
|
||||
}
|
||||
|
||||
claims, err := manager.Parse(token)
|
||||
if err != nil {
|
||||
t.Fatalf("Parse returned error: %v", err)
|
||||
}
|
||||
|
||||
if claims.Subject != "7" {
|
||||
t.Fatalf("expected subject 7, got %s", claims.Subject)
|
||||
}
|
||||
if claims.Username != "admin" {
|
||||
t.Fatalf("expected username admin, got %s", claims.Username)
|
||||
}
|
||||
}
|
||||
194
server/internal/service/auth_service.go
Normal file
194
server/internal/service/auth_service.go
Normal file
@@ -0,0 +1,194 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"strconv"
|
||||
"strings"
|
||||
|
||||
"backupx/server/internal/apperror"
|
||||
"backupx/server/internal/model"
|
||||
"backupx/server/internal/repository"
|
||||
"backupx/server/internal/security"
|
||||
)
|
||||
|
||||
type SetupInput struct {
|
||||
Username string `json:"username" binding:"required,min=3,max=64"`
|
||||
Password string `json:"password" binding:"required,min=8,max=128"`
|
||||
DisplayName string `json:"displayName" binding:"required,min=1,max=128"`
|
||||
}
|
||||
|
||||
type LoginInput struct {
|
||||
Username string `json:"username" binding:"required,min=3,max=64"`
|
||||
Password string `json:"password" binding:"required,min=8,max=128"`
|
||||
}
|
||||
|
||||
type AuthPayload struct {
|
||||
Token string `json:"token"`
|
||||
User *UserOutput `json:"user"`
|
||||
}
|
||||
|
||||
type UserOutput struct {
|
||||
ID uint `json:"id"`
|
||||
Username string `json:"username"`
|
||||
DisplayName string `json:"displayName"`
|
||||
Role string `json:"role"`
|
||||
}
|
||||
|
||||
type AuthService struct {
|
||||
users repository.UserRepository
|
||||
configs repository.SystemConfigRepository
|
||||
jwtManager *security.JWTManager
|
||||
rateLimiter *security.LoginRateLimiter
|
||||
}
|
||||
|
||||
func NewAuthService(
|
||||
users repository.UserRepository,
|
||||
configs repository.SystemConfigRepository,
|
||||
jwtManager *security.JWTManager,
|
||||
rateLimiter *security.LoginRateLimiter,
|
||||
) *AuthService {
|
||||
return &AuthService{users: users, configs: configs, jwtManager: jwtManager, rateLimiter: rateLimiter}
|
||||
}
|
||||
|
||||
func (s *AuthService) SetupStatus(ctx context.Context) (bool, error) {
|
||||
count, err := s.users.Count(ctx)
|
||||
if err != nil {
|
||||
return false, apperror.Internal("AUTH_STATUS_FAILED", "无法检查初始化状态", err)
|
||||
}
|
||||
return count > 0, nil
|
||||
}
|
||||
|
||||
func (s *AuthService) Setup(ctx context.Context, input SetupInput) (*AuthPayload, error) {
|
||||
initialized, err := s.SetupStatus(ctx)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if initialized {
|
||||
return nil, apperror.Conflict("AUTH_SETUP_DISABLED", "系统已初始化,请直接登录", nil)
|
||||
}
|
||||
|
||||
existing, err := s.users.FindByUsername(ctx, strings.TrimSpace(input.Username))
|
||||
if err != nil {
|
||||
return nil, apperror.Internal("AUTH_LOOKUP_FAILED", "无法检查账户状态", err)
|
||||
}
|
||||
if existing != nil {
|
||||
return nil, apperror.Conflict("AUTH_USERNAME_EXISTS", "用户名已存在", nil)
|
||||
}
|
||||
|
||||
hash, err := security.HashPassword(input.Password)
|
||||
if err != nil {
|
||||
return nil, apperror.Internal("AUTH_HASH_FAILED", "无法处理密码", err)
|
||||
}
|
||||
|
||||
user := &model.User{
|
||||
Username: strings.TrimSpace(input.Username),
|
||||
PasswordHash: hash,
|
||||
DisplayName: strings.TrimSpace(input.DisplayName),
|
||||
Role: "admin",
|
||||
}
|
||||
if err := s.users.Create(ctx, user); err != nil {
|
||||
return nil, apperror.Internal("AUTH_CREATE_USER_FAILED", "无法创建管理员账户", err)
|
||||
}
|
||||
|
||||
token, err := s.jwtManager.Generate(user)
|
||||
if err != nil {
|
||||
return nil, apperror.Internal("AUTH_TOKEN_FAILED", "无法生成访问令牌", err)
|
||||
}
|
||||
|
||||
return &AuthPayload{Token: token, User: ToUserOutput(user)}, nil
|
||||
}
|
||||
|
||||
func (s *AuthService) Login(ctx context.Context, input LoginInput, clientKey string) (*AuthPayload, error) {
|
||||
if clientKey == "" {
|
||||
clientKey = "unknown"
|
||||
}
|
||||
if !s.rateLimiter.Allow(clientKey) {
|
||||
return nil, apperror.TooManyRequests("AUTH_RATE_LIMITED", "登录尝试过于频繁,请稍后再试", nil)
|
||||
}
|
||||
|
||||
user, err := s.users.FindByUsername(ctx, strings.TrimSpace(input.Username))
|
||||
if err != nil {
|
||||
return nil, apperror.Internal("AUTH_LOOKUP_FAILED", "无法执行登录校验", err)
|
||||
}
|
||||
if user == nil {
|
||||
return nil, apperror.Unauthorized("AUTH_INVALID_CREDENTIALS", "用户名或密码错误", nil)
|
||||
}
|
||||
if err := security.ComparePassword(user.PasswordHash, input.Password); err != nil {
|
||||
return nil, apperror.Unauthorized("AUTH_INVALID_CREDENTIALS", "用户名或密码错误", err)
|
||||
}
|
||||
|
||||
s.rateLimiter.Reset(clientKey)
|
||||
token, err := s.jwtManager.Generate(user)
|
||||
if err != nil {
|
||||
return nil, apperror.Internal("AUTH_TOKEN_FAILED", "无法生成访问令牌", err)
|
||||
}
|
||||
return &AuthPayload{Token: token, User: ToUserOutput(user)}, nil
|
||||
}
|
||||
|
||||
func (s *AuthService) GetCurrentUser(ctx context.Context, subject string) (*UserOutput, error) {
|
||||
userID, err := strconv.ParseUint(subject, 10, 64)
|
||||
if err != nil {
|
||||
return nil, apperror.Unauthorized("AUTH_INVALID_SUBJECT", "无效用户身份", err)
|
||||
}
|
||||
user, err := s.users.FindByID(ctx, uint(userID))
|
||||
if err != nil {
|
||||
return nil, apperror.Internal("AUTH_LOOKUP_FAILED", "无法获取当前用户", err)
|
||||
}
|
||||
if user == nil {
|
||||
return nil, apperror.Unauthorized("AUTH_USER_NOT_FOUND", "当前用户不存在", errors.New("user not found"))
|
||||
}
|
||||
return ToUserOutput(user), nil
|
||||
}
|
||||
|
||||
type ChangePasswordInput struct {
|
||||
OldPassword string `json:"oldPassword" binding:"required,min=8,max=128"`
|
||||
NewPassword string `json:"newPassword" binding:"required,min=8,max=128"`
|
||||
}
|
||||
|
||||
func (s *AuthService) ChangePassword(ctx context.Context, subject string, input ChangePasswordInput) error {
|
||||
userID, err := strconv.ParseUint(subject, 10, 64)
|
||||
if err != nil {
|
||||
return apperror.Unauthorized("AUTH_INVALID_SUBJECT", "无效用户身份", err)
|
||||
}
|
||||
user, err := s.users.FindByID(ctx, uint(userID))
|
||||
if err != nil {
|
||||
return apperror.Internal("AUTH_LOOKUP_FAILED", "无法获取当前用户", err)
|
||||
}
|
||||
if user == nil {
|
||||
return apperror.Unauthorized("AUTH_USER_NOT_FOUND", "当前用户不存在", errors.New("user not found"))
|
||||
}
|
||||
if err := security.ComparePassword(user.PasswordHash, input.OldPassword); err != nil {
|
||||
return apperror.BadRequest("AUTH_WRONG_PASSWORD", "旧密码不正确", err)
|
||||
}
|
||||
hash, err := security.HashPassword(input.NewPassword)
|
||||
if err != nil {
|
||||
return apperror.Internal("AUTH_HASH_FAILED", "无法处理密码", err)
|
||||
}
|
||||
user.PasswordHash = hash
|
||||
if err := s.users.Update(ctx, user); err != nil {
|
||||
return apperror.Internal("AUTH_UPDATE_FAILED", "密码修改失败", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func ToUserOutput(user *model.User) *UserOutput {
|
||||
if user == nil {
|
||||
return nil
|
||||
}
|
||||
return &UserOutput{
|
||||
ID: user.ID,
|
||||
Username: user.Username,
|
||||
DisplayName: user.DisplayName,
|
||||
Role: user.Role,
|
||||
}
|
||||
}
|
||||
|
||||
func SubjectFromContextValue(value any) (string, error) {
|
||||
subject, ok := value.(string)
|
||||
if !ok || strings.TrimSpace(subject) == "" {
|
||||
return "", fmt.Errorf("invalid subject context")
|
||||
}
|
||||
return subject, nil
|
||||
}
|
||||
162
server/internal/service/auth_service_test.go
Normal file
162
server/internal/service/auth_service_test.go
Normal file
@@ -0,0 +1,162 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"backupx/server/internal/model"
|
||||
"backupx/server/internal/security"
|
||||
)
|
||||
|
||||
type fakeUserRepository struct {
|
||||
users []*model.User
|
||||
}
|
||||
|
||||
func (r *fakeUserRepository) Count(context.Context) (int64, error) {
|
||||
return int64(len(r.users)), nil
|
||||
}
|
||||
|
||||
func (r *fakeUserRepository) Create(_ context.Context, user *model.User) error {
|
||||
user.ID = uint(len(r.users) + 1)
|
||||
r.users = append(r.users, user)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *fakeUserRepository) FindByUsername(_ context.Context, username string) (*model.User, error) {
|
||||
for _, user := range r.users {
|
||||
if user.Username == username {
|
||||
return user, nil
|
||||
}
|
||||
}
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
func (r *fakeUserRepository) FindByID(_ context.Context, id uint) (*model.User, error) {
|
||||
for _, user := range r.users {
|
||||
if user.ID == id {
|
||||
return user, nil
|
||||
}
|
||||
}
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
func (r *fakeUserRepository) Update(_ context.Context, user *model.User) error {
|
||||
for i, u := range r.users {
|
||||
if u.ID == user.ID {
|
||||
r.users[i] = user
|
||||
return nil
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
type fakeSystemConfigRepository struct{}
|
||||
|
||||
func (r *fakeSystemConfigRepository) GetByKey(context.Context, string) (*model.SystemConfig, error) {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
func (r *fakeSystemConfigRepository) List(context.Context) ([]model.SystemConfig, error) {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
func (r *fakeSystemConfigRepository) Upsert(context.Context, *model.SystemConfig) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func TestAuthServiceSetupAndLogin(t *testing.T) {
|
||||
users := &fakeUserRepository{}
|
||||
service := NewAuthService(
|
||||
users,
|
||||
&fakeSystemConfigRepository{},
|
||||
security.NewJWTManager("test-secret", time.Hour),
|
||||
security.NewLoginRateLimiter(5, time.Minute),
|
||||
)
|
||||
|
||||
setupResult, err := service.Setup(context.Background(), SetupInput{
|
||||
Username: "admin",
|
||||
Password: "password-123",
|
||||
DisplayName: "Admin",
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("Setup returned error: %v", err)
|
||||
}
|
||||
if setupResult.User.Username != "admin" {
|
||||
t.Fatalf("expected username admin, got %s", setupResult.User.Username)
|
||||
}
|
||||
|
||||
loginResult, err := service.Login(context.Background(), LoginInput{
|
||||
Username: "admin",
|
||||
Password: "password-123",
|
||||
}, "127.0.0.1")
|
||||
if err != nil {
|
||||
t.Fatalf("Login returned error: %v", err)
|
||||
}
|
||||
if loginResult.Token == "" {
|
||||
t.Fatalf("expected non-empty token")
|
||||
}
|
||||
}
|
||||
|
||||
func newTestAuthService() (*AuthService, *fakeUserRepository) {
|
||||
users := &fakeUserRepository{}
|
||||
svc := NewAuthService(
|
||||
users,
|
||||
&fakeSystemConfigRepository{},
|
||||
security.NewJWTManager("test-secret", time.Hour),
|
||||
security.NewLoginRateLimiter(5, time.Minute),
|
||||
)
|
||||
return svc, users
|
||||
}
|
||||
|
||||
func TestChangePassword(t *testing.T) {
|
||||
svc, _ := newTestAuthService()
|
||||
_, err := svc.Setup(context.Background(), SetupInput{
|
||||
Username: "admin", Password: "password-123", DisplayName: "Admin",
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("Setup: %v", err)
|
||||
}
|
||||
|
||||
err = svc.ChangePassword(context.Background(), "1", ChangePasswordInput{
|
||||
OldPassword: "password-123",
|
||||
NewPassword: "new-password-456",
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("ChangePassword: %v", err)
|
||||
}
|
||||
|
||||
// Old password should no longer work
|
||||
_, err = svc.Login(context.Background(), LoginInput{
|
||||
Username: "admin", Password: "password-123",
|
||||
}, "127.0.0.1")
|
||||
if err == nil {
|
||||
t.Fatalf("expected login with old password to fail")
|
||||
}
|
||||
|
||||
// New password should work
|
||||
_, err = svc.Login(context.Background(), LoginInput{
|
||||
Username: "admin", Password: "new-password-456",
|
||||
}, "127.0.0.1")
|
||||
if err != nil {
|
||||
t.Fatalf("login with new password: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestChangePasswordWrongOld(t *testing.T) {
|
||||
svc, _ := newTestAuthService()
|
||||
_, err := svc.Setup(context.Background(), SetupInput{
|
||||
Username: "admin", Password: "password-123", DisplayName: "Admin",
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("Setup: %v", err)
|
||||
}
|
||||
|
||||
err = svc.ChangePassword(context.Background(), "1", ChangePasswordInput{
|
||||
OldPassword: "wrong-password",
|
||||
NewPassword: "new-password-456",
|
||||
})
|
||||
if err == nil {
|
||||
t.Fatalf("expected ChangePassword with wrong old password to fail")
|
||||
}
|
||||
}
|
||||
487
server/internal/service/backup_execution_service.go
Normal file
487
server/internal/service/backup_execution_service.go
Normal file
@@ -0,0 +1,487 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"backupx/server/internal/apperror"
|
||||
"backupx/server/internal/backup"
|
||||
backupretention "backupx/server/internal/backup/retention"
|
||||
"backupx/server/internal/model"
|
||||
"backupx/server/internal/repository"
|
||||
"backupx/server/internal/storage"
|
||||
"backupx/server/internal/storage/codec"
|
||||
"backupx/server/pkg/compress"
|
||||
backupcrypto "backupx/server/pkg/crypto"
|
||||
)
|
||||
|
||||
type BackupExecutionNotification struct {
|
||||
Task *model.BackupTask
|
||||
Record *model.BackupRecord
|
||||
Error error
|
||||
}
|
||||
|
||||
type BackupResultNotifier interface {
|
||||
NotifyBackupResult(context.Context, BackupExecutionNotification) error
|
||||
}
|
||||
|
||||
type noopBackupNotifier struct{}
|
||||
|
||||
func (noopBackupNotifier) NotifyBackupResult(context.Context, BackupExecutionNotification) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
type DownloadedArtifact struct {
|
||||
FileName string
|
||||
Reader io.ReadCloser
|
||||
}
|
||||
|
||||
type BackupExecutionService struct {
|
||||
tasks repository.BackupTaskRepository
|
||||
records repository.BackupRecordRepository
|
||||
targets repository.StorageTargetRepository
|
||||
storageRegistry *storage.Registry
|
||||
runnerRegistry *backup.Registry
|
||||
logHub *backup.LogHub
|
||||
retention *backupretention.Service
|
||||
cipher *codec.ConfigCipher
|
||||
notifier BackupResultNotifier
|
||||
async func(func())
|
||||
now func() time.Time
|
||||
tempDir string
|
||||
semaphore chan struct{}
|
||||
}
|
||||
|
||||
func NewBackupExecutionService(
|
||||
tasks repository.BackupTaskRepository,
|
||||
records repository.BackupRecordRepository,
|
||||
targets repository.StorageTargetRepository,
|
||||
storageRegistry *storage.Registry,
|
||||
runnerRegistry *backup.Registry,
|
||||
logHub *backup.LogHub,
|
||||
retention *backupretention.Service,
|
||||
cipher *codec.ConfigCipher,
|
||||
notifier BackupResultNotifier,
|
||||
tempDir string,
|
||||
maxConcurrent int,
|
||||
) *BackupExecutionService {
|
||||
if notifier == nil {
|
||||
notifier = noopBackupNotifier{}
|
||||
}
|
||||
if tempDir == "" {
|
||||
tempDir = "/tmp/backupx"
|
||||
}
|
||||
if maxConcurrent <= 0 {
|
||||
maxConcurrent = 2
|
||||
}
|
||||
return &BackupExecutionService{
|
||||
tasks: tasks,
|
||||
records: records,
|
||||
targets: targets,
|
||||
storageRegistry: storageRegistry,
|
||||
runnerRegistry: runnerRegistry,
|
||||
logHub: logHub,
|
||||
retention: retention,
|
||||
cipher: cipher,
|
||||
notifier: notifier,
|
||||
async: func(job func()) {
|
||||
go job()
|
||||
},
|
||||
now: func() time.Time { return time.Now().UTC() },
|
||||
tempDir: tempDir,
|
||||
semaphore: make(chan struct{}, maxConcurrent),
|
||||
}
|
||||
}
|
||||
|
||||
func (s *BackupExecutionService) RunTaskByID(ctx context.Context, id uint) (*BackupRecordDetail, error) {
|
||||
return s.startTask(ctx, id, true)
|
||||
}
|
||||
|
||||
func (s *BackupExecutionService) RunTaskByIDSync(ctx context.Context, id uint) (*BackupRecordDetail, error) {
|
||||
return s.startTask(ctx, id, false)
|
||||
}
|
||||
|
||||
func (s *BackupExecutionService) DownloadRecord(ctx context.Context, recordID uint) (*DownloadedArtifact, error) {
|
||||
record, provider, err := s.loadRecordProvider(ctx, recordID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
reader, err := provider.Download(ctx, record.StoragePath)
|
||||
if err != nil {
|
||||
return nil, apperror.Internal("BACKUP_RECORD_DOWNLOAD_FAILED", "无法下载备份文件", err)
|
||||
}
|
||||
fileName := record.FileName
|
||||
if strings.TrimSpace(fileName) == "" {
|
||||
fileName = filepath.Base(record.StoragePath)
|
||||
}
|
||||
return &DownloadedArtifact{FileName: fileName, Reader: reader}, nil
|
||||
}
|
||||
|
||||
func (s *BackupExecutionService) RestoreRecord(ctx context.Context, recordID uint) error {
|
||||
record, provider, err := s.loadRecordProvider(ctx, recordID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
task, err := s.tasks.FindByID(ctx, record.TaskID)
|
||||
if err != nil {
|
||||
return apperror.Internal("BACKUP_TASK_GET_FAILED", "无法获取关联备份任务", err)
|
||||
}
|
||||
if task == nil {
|
||||
return apperror.New(404, "BACKUP_TASK_NOT_FOUND", "关联的备份任务不存在,无法执行恢复", fmt.Errorf("backup task %d not found", record.TaskID))
|
||||
}
|
||||
tempDir, err := os.MkdirTemp("", "backupx-restore-*")
|
||||
if err != nil {
|
||||
return apperror.Internal("BACKUP_RECORD_RESTORE_FAILED", "无法创建恢复目录", err)
|
||||
}
|
||||
defer os.RemoveAll(tempDir)
|
||||
artifactPath := filepath.Join(tempDir, filepath.Base(record.FileName))
|
||||
if strings.TrimSpace(filepath.Base(record.FileName)) == "" {
|
||||
artifactPath = filepath.Join(tempDir, filepath.Base(record.StoragePath))
|
||||
}
|
||||
reader, err := provider.Download(ctx, record.StoragePath)
|
||||
if err != nil {
|
||||
return apperror.Internal("BACKUP_RECORD_RESTORE_FAILED", "无法下载备份文件", err)
|
||||
}
|
||||
if err := writeReaderToFile(artifactPath, reader); err != nil {
|
||||
return apperror.Internal("BACKUP_RECORD_RESTORE_FAILED", "无法写入恢复文件", err)
|
||||
}
|
||||
preparedPath, err := s.prepareArtifactForRestore(artifactPath)
|
||||
if err != nil {
|
||||
return apperror.Internal("BACKUP_RECORD_RESTORE_FAILED", "无法准备恢复文件", err)
|
||||
}
|
||||
spec, err := s.buildTaskSpec(task, record.StartedAt)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
runner, err := s.runnerRegistry.Runner(spec.Type)
|
||||
if err != nil {
|
||||
return apperror.BadRequest("BACKUP_TASK_INVALID", "不支持的备份任务类型", err)
|
||||
}
|
||||
if err := runner.Restore(ctx, spec, preparedPath, backup.NopLogWriter{}); err != nil {
|
||||
return apperror.Internal("BACKUP_RECORD_RESTORE_FAILED", "恢复备份失败", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *BackupExecutionService) DeleteRecord(ctx context.Context, recordID uint) error {
|
||||
record, provider, err := s.loadRecordProvider(ctx, recordID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if strings.TrimSpace(record.StoragePath) != "" {
|
||||
if err := provider.Delete(ctx, record.StoragePath); err != nil {
|
||||
return apperror.Internal("BACKUP_RECORD_DELETE_FAILED", "无法删除备份文件", err)
|
||||
}
|
||||
}
|
||||
if err := s.records.Delete(ctx, recordID); err != nil {
|
||||
return apperror.Internal("BACKUP_RECORD_DELETE_FAILED", "无法删除备份记录", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *BackupExecutionService) startTask(ctx context.Context, id uint, async bool) (*BackupRecordDetail, error) {
|
||||
task, err := s.tasks.FindByID(ctx, id)
|
||||
if err != nil {
|
||||
return nil, apperror.Internal("BACKUP_TASK_GET_FAILED", "无法获取备份任务详情", err)
|
||||
}
|
||||
if task == nil {
|
||||
return nil, apperror.New(404, "BACKUP_TASK_NOT_FOUND", "备份任务不存在", fmt.Errorf("backup task %d not found", id))
|
||||
}
|
||||
startedAt := s.now()
|
||||
record := &model.BackupRecord{TaskID: task.ID, StorageTargetID: task.StorageTargetID, Status: "running", StartedAt: startedAt}
|
||||
if err := s.records.Create(ctx, record); err != nil {
|
||||
return nil, apperror.Internal("BACKUP_RECORD_CREATE_FAILED", "无法创建备份记录", err)
|
||||
}
|
||||
task.LastRunAt = &startedAt
|
||||
task.LastStatus = "running"
|
||||
if err := s.tasks.Update(ctx, task); err != nil {
|
||||
return nil, apperror.Internal("BACKUP_TASK_UPDATE_FAILED", "无法更新任务状态", err)
|
||||
}
|
||||
run := func() {
|
||||
s.executeTask(context.Background(), task, record.ID, startedAt)
|
||||
}
|
||||
if async {
|
||||
s.async(run)
|
||||
} else {
|
||||
run()
|
||||
}
|
||||
return s.getRecordDetail(ctx, record.ID)
|
||||
}
|
||||
|
||||
func (s *BackupExecutionService) executeTask(ctx context.Context, task *model.BackupTask, recordID uint, startedAt time.Time) {
|
||||
s.semaphore <- struct{}{}
|
||||
defer func() { <-s.semaphore }()
|
||||
|
||||
logger := backup.NewExecutionLogger(recordID, s.logHub)
|
||||
status := "failed"
|
||||
errMessage := ""
|
||||
var fileName string
|
||||
var fileSize int64
|
||||
var storagePath string
|
||||
completeRecord := func() {
|
||||
if finalizeErr := s.finalizeRecord(ctx, task, recordID, startedAt, status, errMessage, logger.String(), fileName, fileSize, storagePath); finalizeErr != nil {
|
||||
logger.Errorf("写回备份记录失败:%v", finalizeErr)
|
||||
}
|
||||
if err := s.notifier.NotifyBackupResult(ctx, BackupExecutionNotification{Task: task, Record: &model.BackupRecord{ID: recordID, TaskID: task.ID, Status: status, FileName: fileName, FileSize: fileSize, StoragePath: storagePath, ErrorMessage: errMessage, StartedAt: startedAt}, Error: buildOptionalError(errMessage)}); err != nil {
|
||||
logger.Warnf("发送备份通知失败:%v", err)
|
||||
}
|
||||
s.logHub.Complete(recordID, status)
|
||||
}
|
||||
defer completeRecord()
|
||||
|
||||
spec, err := s.buildTaskSpec(task, startedAt)
|
||||
if err != nil {
|
||||
errMessage = err.Error()
|
||||
logger.Errorf("构建任务运行时配置失败:%v", err)
|
||||
return
|
||||
}
|
||||
provider, err := s.resolveProvider(ctx, task.StorageTargetID)
|
||||
if err != nil {
|
||||
errMessage = err.Error()
|
||||
logger.Errorf("创建存储客户端失败:%v", err)
|
||||
return
|
||||
}
|
||||
runner, err := s.runnerRegistry.Runner(spec.Type)
|
||||
if err != nil {
|
||||
errMessage = err.Error()
|
||||
logger.Errorf("获取备份执行器失败:%v", err)
|
||||
return
|
||||
}
|
||||
result, err := runner.Run(ctx, spec, logger)
|
||||
if err != nil {
|
||||
errMessage = err.Error()
|
||||
logger.Errorf("执行备份失败:%v", err)
|
||||
return
|
||||
}
|
||||
defer os.RemoveAll(result.TempDir)
|
||||
finalPath := result.ArtifactPath
|
||||
if strings.EqualFold(task.Compression, "gzip") && !strings.HasSuffix(strings.ToLower(finalPath), ".gz") {
|
||||
logger.Infof("开始压缩备份文件")
|
||||
compressedPath, compressErr := compress.GzipFile(finalPath)
|
||||
if compressErr != nil {
|
||||
errMessage = compressErr.Error()
|
||||
logger.Errorf("压缩备份文件失败:%v", compressErr)
|
||||
return
|
||||
}
|
||||
finalPath = compressedPath
|
||||
}
|
||||
if task.Encrypt {
|
||||
logger.Infof("开始加密备份文件")
|
||||
encryptedPath, encryptErr := backupcrypto.EncryptFile(s.cipher.Key(), finalPath)
|
||||
if encryptErr != nil {
|
||||
errMessage = encryptErr.Error()
|
||||
logger.Errorf("加密备份文件失败:%v", encryptErr)
|
||||
return
|
||||
}
|
||||
finalPath = encryptedPath
|
||||
}
|
||||
info, err := os.Stat(finalPath)
|
||||
if err != nil {
|
||||
errMessage = err.Error()
|
||||
logger.Errorf("获取备份文件信息失败:%v", err)
|
||||
return
|
||||
}
|
||||
fileSize = info.Size()
|
||||
fileName = filepath.Base(finalPath)
|
||||
storagePath = backup.BuildStorageKey(task.Type, startedAt, fileName)
|
||||
artifact, err := os.Open(finalPath)
|
||||
if err != nil {
|
||||
errMessage = err.Error()
|
||||
logger.Errorf("打开备份文件失败:%v", err)
|
||||
return
|
||||
}
|
||||
defer artifact.Close()
|
||||
logger.Infof("开始上传备份到存储目标")
|
||||
if err := provider.Upload(ctx, storagePath, artifact, fileSize, map[string]string{"taskId": fmt.Sprintf("%d", task.ID), "recordId": fmt.Sprintf("%d", recordID)}); err != nil {
|
||||
errMessage = err.Error()
|
||||
logger.Errorf("上传备份文件失败:%v", err)
|
||||
return
|
||||
}
|
||||
if s.retention != nil {
|
||||
cleanupResult, cleanupErr := s.retention.Cleanup(ctx, task, provider)
|
||||
if cleanupErr != nil {
|
||||
logger.Warnf("执行保留策略失败:%v", cleanupErr)
|
||||
} else {
|
||||
for _, warning := range cleanupResult.Warnings {
|
||||
logger.Warnf("保留策略警告:%s", warning)
|
||||
}
|
||||
}
|
||||
}
|
||||
status = "success"
|
||||
logger.Infof("备份执行完成")
|
||||
}
|
||||
|
||||
func (s *BackupExecutionService) finalizeRecord(ctx context.Context, task *model.BackupTask, recordID uint, startedAt time.Time, status string, errorMessage string, logContent string, fileName string, fileSize int64, storagePath string) error {
|
||||
record, err := s.records.FindByID(ctx, recordID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if record == nil {
|
||||
return fmt.Errorf("backup record %d not found", recordID)
|
||||
}
|
||||
completedAt := s.now()
|
||||
record.Status = status
|
||||
record.FileName = fileName
|
||||
record.FileSize = fileSize
|
||||
record.StoragePath = storagePath
|
||||
record.DurationSeconds = int(completedAt.Sub(startedAt).Seconds())
|
||||
record.ErrorMessage = strings.TrimSpace(errorMessage)
|
||||
record.LogContent = strings.TrimSpace(logContent)
|
||||
record.CompletedAt = &completedAt
|
||||
if err := s.records.Update(ctx, record); err != nil {
|
||||
return err
|
||||
}
|
||||
task.LastRunAt = &startedAt
|
||||
task.LastStatus = status
|
||||
return s.tasks.Update(ctx, task)
|
||||
}
|
||||
|
||||
func (s *BackupExecutionService) resolveProvider(ctx context.Context, targetID uint) (storage.StorageProvider, error) {
|
||||
target, err := s.targets.FindByID(ctx, targetID)
|
||||
if err != nil {
|
||||
return nil, apperror.Internal("BACKUP_STORAGE_TARGET_GET_FAILED", "无法获取存储目标详情", err)
|
||||
}
|
||||
if target == nil {
|
||||
return nil, apperror.BadRequest("BACKUP_STORAGE_TARGET_INVALID", "关联的存储目标不存在", nil)
|
||||
}
|
||||
configMap := map[string]any{}
|
||||
if err := s.cipher.DecryptJSON(target.ConfigCiphertext, &configMap); err != nil {
|
||||
return nil, apperror.Internal("BACKUP_STORAGE_TARGET_DECRYPT_FAILED", "无法解密存储目标配置", err)
|
||||
}
|
||||
provider, err := s.storageRegistry.Create(ctx, target.Type, configMap)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return provider, nil
|
||||
}
|
||||
|
||||
func (s *BackupExecutionService) buildTaskSpec(task *model.BackupTask, startedAt time.Time) (backup.TaskSpec, error) {
|
||||
excludePatterns := []string{}
|
||||
if strings.TrimSpace(task.ExcludePatterns) != "" {
|
||||
if err := json.Unmarshal([]byte(task.ExcludePatterns), &excludePatterns); err != nil {
|
||||
return backup.TaskSpec{}, apperror.Internal("BACKUP_TASK_DECODE_FAILED", "无法解析排除规则", err)
|
||||
}
|
||||
}
|
||||
password := ""
|
||||
if strings.TrimSpace(task.DBPasswordCiphertext) != "" {
|
||||
plain, err := s.cipher.Decrypt(task.DBPasswordCiphertext)
|
||||
if err != nil {
|
||||
return backup.TaskSpec{}, apperror.Internal("BACKUP_TASK_DECRYPT_FAILED", "无法解密数据库密码", err)
|
||||
}
|
||||
password = string(plain)
|
||||
}
|
||||
return backup.TaskSpec{
|
||||
ID: task.ID,
|
||||
Name: task.Name,
|
||||
Type: task.Type,
|
||||
SourcePath: task.SourcePath,
|
||||
ExcludePatterns: excludePatterns,
|
||||
StorageTargetID: task.StorageTargetID,
|
||||
StorageTargetType: "",
|
||||
Compression: task.Compression,
|
||||
Encrypt: task.Encrypt,
|
||||
RetentionDays: task.RetentionDays,
|
||||
MaxBackups: task.MaxBackups,
|
||||
StartedAt: startedAt,
|
||||
TempDir: s.tempDir,
|
||||
Database: backup.DatabaseSpec{
|
||||
Host: task.DBHost,
|
||||
Port: task.DBPort,
|
||||
User: task.DBUser,
|
||||
Password: password,
|
||||
Names: []string{task.DBName},
|
||||
Path: task.DBPath,
|
||||
},
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (s *BackupExecutionService) loadRecordProvider(ctx context.Context, recordID uint) (*model.BackupRecord, storage.StorageProvider, error) {
|
||||
record, err := s.records.FindByID(ctx, recordID)
|
||||
if err != nil {
|
||||
return nil, nil, apperror.Internal("BACKUP_RECORD_GET_FAILED", "无法获取备份记录详情", err)
|
||||
}
|
||||
if record == nil {
|
||||
return nil, nil, apperror.New(404, "BACKUP_RECORD_NOT_FOUND", "备份记录不存在", fmt.Errorf("backup record %d not found", recordID))
|
||||
}
|
||||
provider, err := s.resolveProvider(ctx, record.StorageTargetID)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
return record, provider, nil
|
||||
}
|
||||
|
||||
func (s *BackupExecutionService) prepareArtifactForRestore(artifactPath string) (string, error) {
|
||||
currentPath := artifactPath
|
||||
if strings.HasSuffix(strings.ToLower(currentPath), ".enc") {
|
||||
decryptedPath, err := backupcrypto.DecryptFile(s.cipher.Key(), currentPath)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
currentPath = decryptedPath
|
||||
}
|
||||
if strings.HasSuffix(strings.ToLower(currentPath), ".gz") {
|
||||
decompressedPath, err := compress.GunzipFile(currentPath)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
currentPath = decompressedPath
|
||||
}
|
||||
return currentPath, nil
|
||||
}
|
||||
|
||||
func (s *BackupExecutionService) getRecordDetail(ctx context.Context, recordID uint) (*BackupRecordDetail, error) {
|
||||
record, err := s.records.FindByID(ctx, recordID)
|
||||
if err != nil {
|
||||
return nil, apperror.Internal("BACKUP_RECORD_GET_FAILED", "无法获取备份记录详情", err)
|
||||
}
|
||||
if record == nil {
|
||||
return nil, apperror.New(404, "BACKUP_RECORD_NOT_FOUND", "备份记录不存在", fmt.Errorf("backup record %d not found", recordID))
|
||||
}
|
||||
return toBackupRecordDetail(record, s.logHub), nil
|
||||
}
|
||||
|
||||
func writeReaderToFile(targetPath string, reader io.ReadCloser) error {
|
||||
defer reader.Close()
|
||||
if err := os.MkdirAll(filepath.Dir(targetPath), 0o755); err != nil {
|
||||
return err
|
||||
}
|
||||
file, err := os.Create(targetPath)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer file.Close()
|
||||
_, err = io.Copy(file, reader)
|
||||
return err
|
||||
}
|
||||
|
||||
func buildOptionalError(message string) error {
|
||||
if strings.TrimSpace(message) == "" {
|
||||
return nil
|
||||
}
|
||||
return fmt.Errorf("%s", message)
|
||||
}
|
||||
|
||||
func buildStorageProviderFromRepos(ctx context.Context, storageTargetID uint, storageTargets repository.StorageTargetRepository, storageRegistry *storage.Registry, cipher *codec.ConfigCipher) (storage.StorageProvider, *model.StorageTarget, error) {
|
||||
target, err := storageTargets.FindByID(ctx, storageTargetID)
|
||||
if err != nil {
|
||||
return nil, nil, apperror.Internal("BACKUP_STORAGE_TARGET_LOOKUP_FAILED", "无法读取存储目标", err)
|
||||
}
|
||||
if target == nil {
|
||||
return nil, nil, apperror.BadRequest("BACKUP_STORAGE_TARGET_INVALID", "存储目标不存在", nil)
|
||||
}
|
||||
var configMap map[string]any
|
||||
if err := cipher.DecryptJSON(target.ConfigCiphertext, &configMap); err != nil {
|
||||
return nil, nil, apperror.Internal("BACKUP_STORAGE_TARGET_DECRYPT_FAILED", "无法解密存储目标配置", err)
|
||||
}
|
||||
provider, err := storageRegistry.Create(ctx, storage.ParseProviderType(target.Type), configMap)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
return provider, target, nil
|
||||
}
|
||||
103
server/internal/service/backup_execution_service_test.go
Normal file
103
server/internal/service/backup_execution_service_test.go
Normal file
@@ -0,0 +1,103 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"testing"
|
||||
|
||||
"backupx/server/internal/backup"
|
||||
backupretention "backupx/server/internal/backup/retention"
|
||||
"backupx/server/internal/config"
|
||||
"backupx/server/internal/database"
|
||||
"backupx/server/internal/logger"
|
||||
"backupx/server/internal/model"
|
||||
"backupx/server/internal/repository"
|
||||
"backupx/server/internal/storage"
|
||||
"backupx/server/internal/storage/codec"
|
||||
"backupx/server/internal/storage/localdisk"
|
||||
)
|
||||
|
||||
func newExecutionTestServices(t *testing.T) (*BackupExecutionService, *BackupRecordService, repository.BackupTaskRepository, repository.StorageTargetRepository, repository.BackupRecordRepository, string, string) {
|
||||
t.Helper()
|
||||
baseDir := t.TempDir()
|
||||
storageDir := filepath.Join(baseDir, "storage")
|
||||
sourceDir := filepath.Join(baseDir, "source")
|
||||
if err := os.MkdirAll(sourceDir, 0o755); err != nil {
|
||||
t.Fatalf("MkdirAll returned error: %v", err)
|
||||
}
|
||||
if err := os.WriteFile(filepath.Join(sourceDir, "index.html"), []byte("hello"), 0o644); err != nil {
|
||||
t.Fatalf("WriteFile returned error: %v", err)
|
||||
}
|
||||
log, err := logger.New(config.LogConfig{Level: "error"})
|
||||
if err != nil {
|
||||
t.Fatalf("logger.New returned error: %v", err)
|
||||
}
|
||||
db, err := database.Open(config.DatabaseConfig{Path: filepath.Join(baseDir, "backupx.db")}, log)
|
||||
if err != nil {
|
||||
t.Fatalf("database.Open returned error: %v", err)
|
||||
}
|
||||
cipher := codec.NewConfigCipher("execution-secret")
|
||||
tasks := repository.NewBackupTaskRepository(db)
|
||||
targets := repository.NewStorageTargetRepository(db)
|
||||
records := repository.NewBackupRecordRepository(db)
|
||||
configCiphertext, err := cipher.EncryptJSON(map[string]any{"basePath": storageDir})
|
||||
if err != nil {
|
||||
t.Fatalf("EncryptJSON returned error: %v", err)
|
||||
}
|
||||
if err := targets.Create(context.Background(), &model.StorageTarget{Name: "local", Type: string(storage.ProviderTypeLocalDisk), Enabled: true, ConfigCiphertext: configCiphertext, ConfigVersion: 1, LastTestStatus: "unknown"}); err != nil {
|
||||
t.Fatalf("Create storage target returned error: %v", err)
|
||||
}
|
||||
if err := tasks.Create(context.Background(), &model.BackupTask{Name: "site-files", Type: "file", Enabled: true, SourcePath: sourceDir, StorageTargetID: 1, RetentionDays: 30, Compression: "gzip", MaxBackups: 10, LastStatus: "idle"}); err != nil {
|
||||
t.Fatalf("Create backup task returned error: %v", err)
|
||||
}
|
||||
logHub := backup.NewLogHub()
|
||||
runnerRegistry := backup.NewRegistry(backup.NewFileRunner(), backup.NewMySQLRunner(nil), backup.NewSQLiteRunner(), backup.NewPostgreSQLRunner(nil))
|
||||
storageRegistry := storage.NewRegistry(localdisk.NewFactory())
|
||||
retentionService := backupretention.NewService(records)
|
||||
executionService := NewBackupExecutionService(tasks, records, targets, storageRegistry, runnerRegistry, logHub, retentionService, cipher, nil, "", 2)
|
||||
recordService := NewBackupRecordService(records, executionService, logHub)
|
||||
return executionService, recordService, tasks, targets, records, sourceDir, storageDir
|
||||
}
|
||||
|
||||
func TestBackupExecutionServiceRunTaskByIDSync(t *testing.T) {
|
||||
executionService, _, _, _, records, _, storageDir := newExecutionTestServices(t)
|
||||
detail, err := executionService.RunTaskByIDSync(context.Background(), 1)
|
||||
if err != nil {
|
||||
t.Fatalf("RunTaskByIDSync returned error: %v", err)
|
||||
}
|
||||
if detail.Status != "success" || detail.StoragePath == "" {
|
||||
t.Fatalf("unexpected record detail: %#v", detail)
|
||||
}
|
||||
stored, err := records.FindByID(context.Background(), detail.ID)
|
||||
if err != nil {
|
||||
t.Fatalf("FindByID returned error: %v", err)
|
||||
}
|
||||
if stored == nil || stored.Status != "success" {
|
||||
t.Fatalf("unexpected stored record: %#v", stored)
|
||||
}
|
||||
if _, err := os.Stat(filepath.Join(storageDir, filepath.FromSlash(detail.StoragePath))); err != nil {
|
||||
t.Fatalf("expected artifact in local storage: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestBackupRecordServiceRestore(t *testing.T) {
|
||||
executionService, recordService, _, _, _, sourceDir, _ := newExecutionTestServices(t)
|
||||
detail, err := executionService.RunTaskByIDSync(context.Background(), 1)
|
||||
if err != nil {
|
||||
t.Fatalf("RunTaskByIDSync returned error: %v", err)
|
||||
}
|
||||
if err := os.RemoveAll(sourceDir); err != nil {
|
||||
t.Fatalf("RemoveAll returned error: %v", err)
|
||||
}
|
||||
if err := recordService.Restore(context.Background(), detail.ID); err != nil {
|
||||
t.Fatalf("Restore returned error: %v", err)
|
||||
}
|
||||
content, err := os.ReadFile(filepath.Join(sourceDir, "index.html"))
|
||||
if err != nil {
|
||||
t.Fatalf("ReadFile returned error: %v", err)
|
||||
}
|
||||
if string(content) != "hello" {
|
||||
t.Fatalf("unexpected restored content: %s", string(content))
|
||||
}
|
||||
}
|
||||
134
server/internal/service/backup_record_service.go
Normal file
134
server/internal/service/backup_record_service.go
Normal file
@@ -0,0 +1,134 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"backupx/server/internal/apperror"
|
||||
"backupx/server/internal/backup"
|
||||
"backupx/server/internal/model"
|
||||
"backupx/server/internal/repository"
|
||||
)
|
||||
|
||||
type BackupRecordListInput struct {
|
||||
TaskID *uint
|
||||
Status string
|
||||
DateFrom *time.Time
|
||||
DateTo *time.Time
|
||||
Limit int
|
||||
Offset int
|
||||
}
|
||||
|
||||
type BackupRecordSummary struct {
|
||||
ID uint `json:"id"`
|
||||
TaskID uint `json:"taskId"`
|
||||
TaskName string `json:"taskName"`
|
||||
StorageTargetID uint `json:"storageTargetId"`
|
||||
StorageTargetName string `json:"storageTargetName"`
|
||||
Status string `json:"status"`
|
||||
FileName string `json:"fileName"`
|
||||
FileSize int64 `json:"fileSize"`
|
||||
StoragePath string `json:"storagePath"`
|
||||
DurationSeconds int `json:"durationSeconds"`
|
||||
ErrorMessage string `json:"errorMessage"`
|
||||
StartedAt time.Time `json:"startedAt"`
|
||||
CompletedAt *time.Time `json:"completedAt,omitempty"`
|
||||
}
|
||||
|
||||
type BackupRecordDetail struct {
|
||||
BackupRecordSummary
|
||||
LogContent string `json:"logContent"`
|
||||
LogEvents []backup.LogEvent `json:"logEvents,omitempty"`
|
||||
}
|
||||
|
||||
type BackupRecordService struct {
|
||||
records repository.BackupRecordRepository
|
||||
execution *BackupExecutionService
|
||||
logHub *backup.LogHub
|
||||
}
|
||||
|
||||
func NewBackupRecordService(records repository.BackupRecordRepository, execution *BackupExecutionService, logHub *backup.LogHub) *BackupRecordService {
|
||||
return &BackupRecordService{records: records, execution: execution, logHub: logHub}
|
||||
}
|
||||
|
||||
func (s *BackupRecordService) List(ctx context.Context, input BackupRecordListInput) ([]BackupRecordSummary, error) {
|
||||
items, err := s.records.List(ctx, repository.BackupRecordListOptions{TaskID: input.TaskID, Status: strings.TrimSpace(input.Status), DateFrom: input.DateFrom, DateTo: input.DateTo, Limit: input.Limit, Offset: input.Offset})
|
||||
if err != nil {
|
||||
return nil, apperror.Internal("BACKUP_RECORD_LIST_FAILED", "无法获取备份记录列表", err)
|
||||
}
|
||||
result := make([]BackupRecordSummary, 0, len(items))
|
||||
for _, item := range items {
|
||||
result = append(result, toBackupRecordSummary(&item))
|
||||
}
|
||||
return result, nil
|
||||
}
|
||||
|
||||
func (s *BackupRecordService) Get(ctx context.Context, id uint) (*BackupRecordDetail, error) {
|
||||
item, err := s.records.FindByID(ctx, id)
|
||||
if err != nil {
|
||||
return nil, apperror.Internal("BACKUP_RECORD_GET_FAILED", "无法获取备份记录详情", err)
|
||||
}
|
||||
if item == nil {
|
||||
return nil, apperror.New(404, "BACKUP_RECORD_NOT_FOUND", "备份记录不存在", err)
|
||||
}
|
||||
return toBackupRecordDetail(item, s.logHub), nil
|
||||
}
|
||||
|
||||
func (s *BackupRecordService) SubscribeLogs(ctx context.Context, id uint, buffer int) (<-chan backup.LogEvent, func(), error) {
|
||||
item, err := s.records.FindByID(ctx, id)
|
||||
if err != nil {
|
||||
return nil, nil, apperror.Internal("BACKUP_RECORD_GET_FAILED", "无法获取备份记录详情", err)
|
||||
}
|
||||
if item == nil {
|
||||
return nil, nil, apperror.New(404, "BACKUP_RECORD_NOT_FOUND", "备份记录不存在", err)
|
||||
}
|
||||
channel, cancel := s.logHub.Subscribe(id, buffer)
|
||||
return channel, cancel, nil
|
||||
}
|
||||
|
||||
func (s *BackupRecordService) Download(ctx context.Context, id uint) (*DownloadedArtifact, error) {
|
||||
return s.execution.DownloadRecord(ctx, id)
|
||||
}
|
||||
|
||||
func (s *BackupRecordService) Restore(ctx context.Context, id uint) error {
|
||||
return s.execution.RestoreRecord(ctx, id)
|
||||
}
|
||||
|
||||
func (s *BackupRecordService) Delete(ctx context.Context, id uint) error {
|
||||
return s.execution.DeleteRecord(ctx, id)
|
||||
}
|
||||
|
||||
func toBackupRecordSummary(item *model.BackupRecord) BackupRecordSummary {
|
||||
return BackupRecordSummary{
|
||||
ID: item.ID,
|
||||
TaskID: item.TaskID,
|
||||
TaskName: item.Task.Name,
|
||||
StorageTargetID: item.StorageTargetID,
|
||||
StorageTargetName: item.StorageTarget.Name,
|
||||
Status: item.Status,
|
||||
FileName: item.FileName,
|
||||
FileSize: item.FileSize,
|
||||
StoragePath: item.StoragePath,
|
||||
DurationSeconds: item.DurationSeconds,
|
||||
ErrorMessage: item.ErrorMessage,
|
||||
StartedAt: item.StartedAt,
|
||||
CompletedAt: item.CompletedAt,
|
||||
}
|
||||
}
|
||||
|
||||
func toBackupRecordDetail(item *model.BackupRecord, logHub *backup.LogHub) *BackupRecordDetail {
|
||||
detail := &BackupRecordDetail{BackupRecordSummary: toBackupRecordSummary(item), LogContent: item.LogContent}
|
||||
if item.Status == "running" && logHub != nil {
|
||||
events := logHub.Snapshot(item.ID)
|
||||
detail.LogEvents = events
|
||||
if len(events) > 0 {
|
||||
lines := make([]string, 0, len(events))
|
||||
for _, event := range events {
|
||||
lines = append(lines, event.Message)
|
||||
}
|
||||
detail.LogContent = strings.Join(lines, "\n")
|
||||
}
|
||||
}
|
||||
return detail
|
||||
}
|
||||
417
server/internal/service/backup_task_service.go
Normal file
417
server/internal/service/backup_task_service.go
Normal file
@@ -0,0 +1,417 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"backupx/server/internal/apperror"
|
||||
"backupx/server/internal/model"
|
||||
"backupx/server/internal/repository"
|
||||
"backupx/server/internal/storage/codec"
|
||||
)
|
||||
|
||||
const backupTaskMaskedValue = "********"
|
||||
|
||||
type BackupTaskUpsertInput struct {
|
||||
Name string `json:"name" binding:"required,min=1,max=100"`
|
||||
Type string `json:"type" binding:"required,oneof=file mysql sqlite postgresql pgsql"`
|
||||
Enabled bool `json:"enabled"`
|
||||
CronExpr string `json:"cronExpr" binding:"max=64"`
|
||||
SourcePath string `json:"sourcePath" binding:"max=500"`
|
||||
ExcludePatterns []string `json:"excludePatterns"`
|
||||
DBHost string `json:"dbHost" binding:"max=255"`
|
||||
DBPort int `json:"dbPort"`
|
||||
DBUser string `json:"dbUser" binding:"max=100"`
|
||||
DBPassword string `json:"dbPassword" binding:"max=255"`
|
||||
DBName string `json:"dbName" binding:"max=255"`
|
||||
DBPath string `json:"dbPath" binding:"max=500"`
|
||||
StorageTargetID uint `json:"storageTargetId" binding:"required"`
|
||||
RetentionDays int `json:"retentionDays"`
|
||||
Compression string `json:"compression" binding:"omitempty,oneof=gzip none"`
|
||||
Encrypt bool `json:"encrypt"`
|
||||
MaxBackups int `json:"maxBackups"`
|
||||
}
|
||||
|
||||
type BackupTaskToggleInput struct {
|
||||
Enabled *bool `json:"enabled"`
|
||||
}
|
||||
|
||||
type BackupTaskSummary struct {
|
||||
ID uint `json:"id"`
|
||||
Name string `json:"name"`
|
||||
Type string `json:"type"`
|
||||
Enabled bool `json:"enabled"`
|
||||
CronExpr string `json:"cronExpr"`
|
||||
StorageTargetID uint `json:"storageTargetId"`
|
||||
StorageTargetName string `json:"storageTargetName"`
|
||||
RetentionDays int `json:"retentionDays"`
|
||||
Compression string `json:"compression"`
|
||||
Encrypt bool `json:"encrypt"`
|
||||
MaxBackups int `json:"maxBackups"`
|
||||
LastRunAt *time.Time `json:"lastRunAt,omitempty"`
|
||||
LastStatus string `json:"lastStatus"`
|
||||
UpdatedAt time.Time `json:"updatedAt"`
|
||||
}
|
||||
|
||||
type BackupTaskDetail struct {
|
||||
BackupTaskSummary
|
||||
SourcePath string `json:"sourcePath"`
|
||||
ExcludePatterns []string `json:"excludePatterns"`
|
||||
DBHost string `json:"dbHost"`
|
||||
DBPort int `json:"dbPort"`
|
||||
DBUser string `json:"dbUser"`
|
||||
DBName string `json:"dbName"`
|
||||
DBPath string `json:"dbPath"`
|
||||
MaskedFields []string `json:"maskedFields,omitempty"`
|
||||
CreatedAt time.Time `json:"createdAt"`
|
||||
}
|
||||
|
||||
type BackupTaskScheduler interface {
|
||||
SyncTask(ctx context.Context, task *model.BackupTask) error
|
||||
RemoveTask(ctx context.Context, taskID uint) error
|
||||
}
|
||||
|
||||
type BackupTaskService struct {
|
||||
tasks repository.BackupTaskRepository
|
||||
targets repository.StorageTargetRepository
|
||||
cipher *codec.ConfigCipher
|
||||
scheduler BackupTaskScheduler
|
||||
}
|
||||
|
||||
func NewBackupTaskService(
|
||||
tasks repository.BackupTaskRepository,
|
||||
targets repository.StorageTargetRepository,
|
||||
cipher *codec.ConfigCipher,
|
||||
) *BackupTaskService {
|
||||
return &BackupTaskService{tasks: tasks, targets: targets, cipher: cipher}
|
||||
}
|
||||
|
||||
func (s *BackupTaskService) SetScheduler(scheduler BackupTaskScheduler) {
|
||||
s.scheduler = scheduler
|
||||
}
|
||||
|
||||
func (s *BackupTaskService) List(ctx context.Context) ([]BackupTaskSummary, error) {
|
||||
items, err := s.tasks.List(ctx, repository.BackupTaskListOptions{})
|
||||
if err != nil {
|
||||
return nil, apperror.Internal("BACKUP_TASK_LIST_FAILED", "无法获取备份任务列表", err)
|
||||
}
|
||||
result := make([]BackupTaskSummary, 0, len(items))
|
||||
for _, item := range items {
|
||||
result = append(result, toBackupTaskSummary(&item))
|
||||
}
|
||||
return result, nil
|
||||
}
|
||||
|
||||
func (s *BackupTaskService) Get(ctx context.Context, id uint) (*BackupTaskDetail, error) {
|
||||
item, err := s.tasks.FindByID(ctx, id)
|
||||
if err != nil {
|
||||
return nil, apperror.Internal("BACKUP_TASK_GET_FAILED", "无法获取备份任务详情", err)
|
||||
}
|
||||
if item == nil {
|
||||
return nil, apperror.New(http.StatusNotFound, "BACKUP_TASK_NOT_FOUND", "备份任务不存在", fmt.Errorf("backup task %d not found", id))
|
||||
}
|
||||
return s.toDetail(item)
|
||||
}
|
||||
|
||||
func (s *BackupTaskService) Create(ctx context.Context, input BackupTaskUpsertInput) (*BackupTaskDetail, error) {
|
||||
input.Type = normalizeBackupTaskType(input.Type)
|
||||
if err := s.validateInput(ctx, nil, input); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
existing, err := s.tasks.FindByName(ctx, strings.TrimSpace(input.Name))
|
||||
if err != nil {
|
||||
return nil, apperror.Internal("BACKUP_TASK_LOOKUP_FAILED", "无法检查备份任务名称", err)
|
||||
}
|
||||
if existing != nil {
|
||||
return nil, apperror.Conflict("BACKUP_TASK_NAME_EXISTS", "备份任务名称已存在", nil)
|
||||
}
|
||||
item, err := s.buildTask(nil, input)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if err := s.tasks.Create(ctx, item); err != nil {
|
||||
return nil, apperror.Internal("BACKUP_TASK_CREATE_FAILED", "无法创建备份任务", err)
|
||||
}
|
||||
if s.scheduler != nil {
|
||||
if err := s.scheduler.SyncTask(ctx, item); err != nil {
|
||||
return nil, apperror.Internal("BACKUP_TASK_SCHEDULE_FAILED", "无法同步备份任务调度", err)
|
||||
}
|
||||
}
|
||||
return s.Get(ctx, item.ID)
|
||||
}
|
||||
|
||||
func (s *BackupTaskService) Update(ctx context.Context, id uint, input BackupTaskUpsertInput) (*BackupTaskDetail, error) {
|
||||
input.Type = normalizeBackupTaskType(input.Type)
|
||||
existing, err := s.tasks.FindByID(ctx, id)
|
||||
if err != nil {
|
||||
return nil, apperror.Internal("BACKUP_TASK_GET_FAILED", "无法获取备份任务详情", err)
|
||||
}
|
||||
if existing == nil {
|
||||
return nil, apperror.New(http.StatusNotFound, "BACKUP_TASK_NOT_FOUND", "备份任务不存在", fmt.Errorf("backup task %d not found", id))
|
||||
}
|
||||
if err := s.validateInput(ctx, existing, input); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
sameName, err := s.tasks.FindByName(ctx, strings.TrimSpace(input.Name))
|
||||
if err != nil {
|
||||
return nil, apperror.Internal("BACKUP_TASK_LOOKUP_FAILED", "无法检查备份任务名称", err)
|
||||
}
|
||||
if sameName != nil && sameName.ID != existing.ID {
|
||||
return nil, apperror.Conflict("BACKUP_TASK_NAME_EXISTS", "备份任务名称已存在", nil)
|
||||
}
|
||||
item, err := s.buildTask(existing, input)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
item.ID = existing.ID
|
||||
item.CreatedAt = existing.CreatedAt
|
||||
if err := s.tasks.Update(ctx, item); err != nil {
|
||||
return nil, apperror.Internal("BACKUP_TASK_UPDATE_FAILED", "无法更新备份任务", err)
|
||||
}
|
||||
if s.scheduler != nil {
|
||||
if err := s.scheduler.SyncTask(ctx, item); err != nil {
|
||||
return nil, apperror.Internal("BACKUP_TASK_SCHEDULE_FAILED", "无法同步备份任务调度", err)
|
||||
}
|
||||
}
|
||||
return s.Get(ctx, item.ID)
|
||||
}
|
||||
|
||||
func (s *BackupTaskService) Delete(ctx context.Context, id uint) error {
|
||||
existing, err := s.tasks.FindByID(ctx, id)
|
||||
if err != nil {
|
||||
return apperror.Internal("BACKUP_TASK_GET_FAILED", "无法获取备份任务详情", err)
|
||||
}
|
||||
if existing == nil {
|
||||
return apperror.New(http.StatusNotFound, "BACKUP_TASK_NOT_FOUND", "备份任务不存在", fmt.Errorf("backup task %d not found", id))
|
||||
}
|
||||
if s.scheduler != nil {
|
||||
if err := s.scheduler.RemoveTask(ctx, id); err != nil {
|
||||
return apperror.Internal("BACKUP_TASK_SCHEDULE_FAILED", "无法移除备份任务调度", err)
|
||||
}
|
||||
}
|
||||
if err := s.tasks.Delete(ctx, id); err != nil {
|
||||
return apperror.Internal("BACKUP_TASK_DELETE_FAILED", "无法删除备份任务", err)
|
||||
}
|
||||
if s.scheduler != nil {
|
||||
_ = s.scheduler.RemoveTask(ctx, id)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *BackupTaskService) Toggle(ctx context.Context, id uint, enabled bool) (*BackupTaskSummary, error) {
|
||||
item, err := s.tasks.FindByID(ctx, id)
|
||||
if err != nil {
|
||||
return nil, apperror.Internal("BACKUP_TASK_GET_FAILED", "无法获取备份任务详情", err)
|
||||
}
|
||||
if item == nil {
|
||||
return nil, apperror.New(http.StatusNotFound, "BACKUP_TASK_NOT_FOUND", "备份任务不存在", fmt.Errorf("backup task %d not found", id))
|
||||
}
|
||||
item.Enabled = enabled
|
||||
if err := s.tasks.Update(ctx, item); err != nil {
|
||||
return nil, apperror.Internal("BACKUP_TASK_UPDATE_FAILED", "无法更新备份任务状态", err)
|
||||
}
|
||||
if s.scheduler != nil {
|
||||
if err := s.scheduler.SyncTask(ctx, item); err != nil {
|
||||
return nil, apperror.Internal("BACKUP_TASK_SCHEDULE_FAILED", "无法同步备份任务调度", err)
|
||||
}
|
||||
}
|
||||
returnPtr, err := s.tasks.FindByID(ctx, id)
|
||||
if err != nil {
|
||||
return nil, apperror.Internal("BACKUP_TASK_GET_FAILED", "无法获取备份任务详情", err)
|
||||
}
|
||||
returnValue := toBackupTaskSummary(returnPtr)
|
||||
return &returnValue, nil
|
||||
}
|
||||
|
||||
func (s *BackupTaskService) validateInput(ctx context.Context, existing *model.BackupTask, input BackupTaskUpsertInput) error {
|
||||
if strings.TrimSpace(input.Name) == "" {
|
||||
return apperror.BadRequest("BACKUP_TASK_INVALID", "任务名称不能为空", nil)
|
||||
}
|
||||
if input.StorageTargetID == 0 {
|
||||
return apperror.BadRequest("BACKUP_TASK_INVALID", "请选择存储目标", nil)
|
||||
}
|
||||
target, err := s.targets.FindByID(ctx, input.StorageTargetID)
|
||||
if err != nil {
|
||||
return apperror.Internal("BACKUP_TASK_STORAGE_LOOKUP_FAILED", "无法检查存储目标", err)
|
||||
}
|
||||
if target == nil {
|
||||
return apperror.BadRequest("BACKUP_STORAGE_TARGET_INVALID", "关联的存储目标不存在", nil)
|
||||
}
|
||||
if input.RetentionDays < 0 {
|
||||
return apperror.BadRequest("BACKUP_TASK_INVALID", "保留天数不能小于 0", nil)
|
||||
}
|
||||
if input.MaxBackups < 0 {
|
||||
return apperror.BadRequest("BACKUP_TASK_INVALID", "最大保留份数不能小于 0", nil)
|
||||
}
|
||||
if input.Compression == "" {
|
||||
input.Compression = "gzip"
|
||||
}
|
||||
if strings.TrimSpace(input.CronExpr) != "" && len(strings.Fields(strings.TrimSpace(input.CronExpr))) < 5 {
|
||||
return apperror.BadRequest("BACKUP_TASK_INVALID", "Cron 表达式格式不正确", nil)
|
||||
}
|
||||
passwordRequired := existing == nil || existing.DBPasswordCiphertext == ""
|
||||
return validateTaskTypeSpecificFields(input, passwordRequired)
|
||||
}
|
||||
|
||||
func validateTaskTypeSpecificFields(input BackupTaskUpsertInput, passwordRequired bool) error {
|
||||
switch normalizeBackupTaskType(input.Type) {
|
||||
case "file":
|
||||
if strings.TrimSpace(input.SourcePath) == "" {
|
||||
return apperror.BadRequest("BACKUP_TASK_INVALID", "文件备份必须填写源路径", nil)
|
||||
}
|
||||
case "mysql", "postgresql":
|
||||
if strings.TrimSpace(input.DBHost) == "" {
|
||||
return apperror.BadRequest("BACKUP_TASK_INVALID", "数据库主机不能为空", nil)
|
||||
}
|
||||
if input.DBPort <= 0 {
|
||||
return apperror.BadRequest("BACKUP_TASK_INVALID", "数据库端口必须大于 0", nil)
|
||||
}
|
||||
if strings.TrimSpace(input.DBUser) == "" {
|
||||
return apperror.BadRequest("BACKUP_TASK_INVALID", "数据库用户名不能为空", nil)
|
||||
}
|
||||
if passwordRequired && strings.TrimSpace(input.DBPassword) == "" {
|
||||
return apperror.BadRequest("BACKUP_TASK_INVALID", "数据库密码不能为空", nil)
|
||||
}
|
||||
if strings.TrimSpace(input.DBName) == "" {
|
||||
return apperror.BadRequest("BACKUP_TASK_INVALID", "数据库名称不能为空", nil)
|
||||
}
|
||||
case "sqlite":
|
||||
if strings.TrimSpace(input.DBPath) == "" {
|
||||
return apperror.BadRequest("BACKUP_TASK_INVALID", "SQLite 备份必须填写数据库文件路径", nil)
|
||||
}
|
||||
default:
|
||||
return apperror.BadRequest("BACKUP_TASK_INVALID", "不支持的备份任务类型", nil)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *BackupTaskService) buildTask(existing *model.BackupTask, input BackupTaskUpsertInput) (*model.BackupTask, error) {
|
||||
excludePatterns, err := encodeExcludePatterns(input.ExcludePatterns)
|
||||
if err != nil {
|
||||
return nil, apperror.BadRequest("BACKUP_TASK_INVALID", "排除规则格式不合法", err)
|
||||
}
|
||||
passwordCiphertext := ""
|
||||
if existing != nil {
|
||||
passwordCiphertext = existing.DBPasswordCiphertext
|
||||
}
|
||||
if text := strings.TrimSpace(input.DBPassword); text != "" && text != backupTaskMaskedValue {
|
||||
ciphertext, encryptErr := s.cipher.Encrypt([]byte(text))
|
||||
if encryptErr != nil {
|
||||
return nil, apperror.Internal("BACKUP_TASK_ENCRYPT_FAILED", "无法保存数据库密码", encryptErr)
|
||||
}
|
||||
passwordCiphertext = ciphertext
|
||||
}
|
||||
compression := strings.TrimSpace(input.Compression)
|
||||
if compression == "" {
|
||||
compression = "gzip"
|
||||
}
|
||||
maxBackups := input.MaxBackups
|
||||
if maxBackups == 0 {
|
||||
maxBackups = 10
|
||||
}
|
||||
item := &model.BackupTask{
|
||||
Name: strings.TrimSpace(input.Name),
|
||||
Type: normalizeBackupTaskType(input.Type),
|
||||
Enabled: input.Enabled,
|
||||
CronExpr: strings.TrimSpace(input.CronExpr),
|
||||
SourcePath: strings.TrimSpace(input.SourcePath),
|
||||
ExcludePatterns: excludePatterns,
|
||||
DBHost: strings.TrimSpace(input.DBHost),
|
||||
DBPort: input.DBPort,
|
||||
DBUser: strings.TrimSpace(input.DBUser),
|
||||
DBPasswordCiphertext: passwordCiphertext,
|
||||
DBName: strings.TrimSpace(input.DBName),
|
||||
DBPath: strings.TrimSpace(input.DBPath),
|
||||
StorageTargetID: input.StorageTargetID,
|
||||
RetentionDays: input.RetentionDays,
|
||||
Compression: compression,
|
||||
Encrypt: input.Encrypt,
|
||||
MaxBackups: maxBackups,
|
||||
LastStatus: "idle",
|
||||
}
|
||||
if existing != nil {
|
||||
item.LastRunAt = existing.LastRunAt
|
||||
item.LastStatus = existing.LastStatus
|
||||
item.CreatedAt = existing.CreatedAt
|
||||
}
|
||||
return item, nil
|
||||
}
|
||||
|
||||
func (s *BackupTaskService) toDetail(item *model.BackupTask) (*BackupTaskDetail, error) {
|
||||
excludePatterns, err := decodeExcludePatterns(item.ExcludePatterns)
|
||||
if err != nil {
|
||||
return nil, apperror.Internal("BACKUP_TASK_DECODE_FAILED", "无法解析备份任务配置", err)
|
||||
}
|
||||
detail := &BackupTaskDetail{
|
||||
BackupTaskSummary: toBackupTaskSummary(item),
|
||||
SourcePath: item.SourcePath,
|
||||
ExcludePatterns: excludePatterns,
|
||||
DBHost: item.DBHost,
|
||||
DBPort: item.DBPort,
|
||||
DBUser: item.DBUser,
|
||||
DBName: item.DBName,
|
||||
DBPath: item.DBPath,
|
||||
CreatedAt: item.CreatedAt,
|
||||
}
|
||||
if item.DBPasswordCiphertext != "" {
|
||||
detail.MaskedFields = []string{"dbPassword"}
|
||||
}
|
||||
return detail, nil
|
||||
}
|
||||
|
||||
func toBackupTaskSummary(item *model.BackupTask) BackupTaskSummary {
|
||||
storageTargetName := ""
|
||||
if item != nil {
|
||||
storageTargetName = item.StorageTarget.Name
|
||||
}
|
||||
return BackupTaskSummary{
|
||||
ID: item.ID,
|
||||
Name: item.Name,
|
||||
Type: normalizeBackupTaskType(item.Type),
|
||||
Enabled: item.Enabled,
|
||||
CronExpr: item.CronExpr,
|
||||
StorageTargetID: item.StorageTargetID,
|
||||
StorageTargetName: storageTargetName,
|
||||
RetentionDays: item.RetentionDays,
|
||||
Compression: item.Compression,
|
||||
Encrypt: item.Encrypt,
|
||||
MaxBackups: item.MaxBackups,
|
||||
LastRunAt: item.LastRunAt,
|
||||
LastStatus: item.LastStatus,
|
||||
UpdatedAt: item.UpdatedAt,
|
||||
}
|
||||
}
|
||||
|
||||
func encodeExcludePatterns(value []string) (string, error) {
|
||||
if len(value) == 0 {
|
||||
return "[]", nil
|
||||
}
|
||||
encoded, err := json.Marshal(value)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
return string(encoded), nil
|
||||
}
|
||||
|
||||
func decodeExcludePatterns(value string) ([]string, error) {
|
||||
if strings.TrimSpace(value) == "" {
|
||||
return []string{}, nil
|
||||
}
|
||||
var items []string
|
||||
if err := json.Unmarshal([]byte(value), &items); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return items, nil
|
||||
}
|
||||
|
||||
func normalizeBackupTaskType(value string) string {
|
||||
normalized := strings.TrimSpace(strings.ToLower(value))
|
||||
if normalized == "pgsql" {
|
||||
return "postgresql"
|
||||
}
|
||||
return normalized
|
||||
}
|
||||
119
server/internal/service/backup_task_service_test.go
Normal file
119
server/internal/service/backup_task_service_test.go
Normal file
@@ -0,0 +1,119 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
"path/filepath"
|
||||
"testing"
|
||||
|
||||
"backupx/server/internal/config"
|
||||
"backupx/server/internal/database"
|
||||
"backupx/server/internal/logger"
|
||||
"backupx/server/internal/model"
|
||||
"backupx/server/internal/repository"
|
||||
"backupx/server/internal/storage/codec"
|
||||
)
|
||||
|
||||
func newBackupTaskServiceForTest(t *testing.T) (*BackupTaskService, repository.StorageTargetRepository, repository.BackupTaskRepository) {
|
||||
t.Helper()
|
||||
log, err := logger.New(config.LogConfig{Level: "error"})
|
||||
if err != nil {
|
||||
t.Fatalf("logger.New returned error: %v", err)
|
||||
}
|
||||
db, err := database.Open(config.DatabaseConfig{Path: filepath.Join(t.TempDir(), "backupx.db")}, log)
|
||||
if err != nil {
|
||||
t.Fatalf("database.Open returned error: %v", err)
|
||||
}
|
||||
targets := repository.NewStorageTargetRepository(db)
|
||||
tasks := repository.NewBackupTaskRepository(db)
|
||||
service := NewBackupTaskService(tasks, targets, codec.NewConfigCipher("task-service-secret"))
|
||||
return service, targets, tasks
|
||||
}
|
||||
|
||||
func TestBackupTaskServiceCreateAndGet(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
service, targets, _ := newBackupTaskServiceForTest(t)
|
||||
if err := targets.Create(ctx, &model.StorageTarget{Name: "local", Type: "local_disk", Enabled: true, ConfigCiphertext: "ciphertext", ConfigVersion: 1, LastTestStatus: "unknown"}); err != nil {
|
||||
t.Fatalf("seed storage target error: %v", err)
|
||||
}
|
||||
created, err := service.Create(ctx, BackupTaskUpsertInput{
|
||||
Name: "site-files",
|
||||
Type: "file",
|
||||
Enabled: true,
|
||||
SourcePath: "/srv/site",
|
||||
ExcludePatterns: []string{"*.log", "node_modules"},
|
||||
StorageTargetID: 1,
|
||||
RetentionDays: 30,
|
||||
Compression: "gzip",
|
||||
MaxBackups: 10,
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("Create returned error: %v", err)
|
||||
}
|
||||
if created.Name != "site-files" || len(created.ExcludePatterns) != 2 {
|
||||
t.Fatalf("unexpected created task: %#v", created)
|
||||
}
|
||||
loaded, err := service.Get(ctx, created.ID)
|
||||
if err != nil {
|
||||
t.Fatalf("Get returned error: %v", err)
|
||||
}
|
||||
if loaded.StorageTargetName != "local" {
|
||||
t.Fatalf("expected storage target name local, got %s", loaded.StorageTargetName)
|
||||
}
|
||||
}
|
||||
|
||||
func TestBackupTaskServiceKeepsMaskedPasswordOnUpdate(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
service, targets, tasks := newBackupTaskServiceForTest(t)
|
||||
if err := targets.Create(ctx, &model.StorageTarget{Name: "local", Type: "local_disk", Enabled: true, ConfigCiphertext: "ciphertext", ConfigVersion: 1, LastTestStatus: "unknown"}); err != nil {
|
||||
t.Fatalf("seed storage target error: %v", err)
|
||||
}
|
||||
created, err := service.Create(ctx, BackupTaskUpsertInput{
|
||||
Name: "mysql-prod",
|
||||
Type: "mysql",
|
||||
Enabled: true,
|
||||
DBHost: "127.0.0.1",
|
||||
DBPort: 3306,
|
||||
DBUser: "root",
|
||||
DBPassword: "secret",
|
||||
DBName: "app",
|
||||
StorageTargetID: 1,
|
||||
RetentionDays: 7,
|
||||
Compression: "gzip",
|
||||
MaxBackups: 5,
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("Create returned error: %v", err)
|
||||
}
|
||||
stored, err := tasks.FindByID(ctx, created.ID)
|
||||
if err != nil {
|
||||
t.Fatalf("FindByID returned error: %v", err)
|
||||
}
|
||||
originalCiphertext := stored.DBPasswordCiphertext
|
||||
updated, err := service.Update(ctx, created.ID, BackupTaskUpsertInput{
|
||||
Name: created.Name,
|
||||
Type: created.Type,
|
||||
Enabled: true,
|
||||
DBHost: "127.0.0.1",
|
||||
DBPort: 3306,
|
||||
DBUser: "root",
|
||||
DBPassword: "",
|
||||
DBName: "app_updated",
|
||||
StorageTargetID: 1,
|
||||
RetentionDays: 7,
|
||||
Compression: "gzip",
|
||||
MaxBackups: 5,
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("Update returned error: %v", err)
|
||||
}
|
||||
if len(updated.MaskedFields) != 1 || updated.MaskedFields[0] != "dbPassword" {
|
||||
t.Fatalf("expected masked dbPassword field, got %#v", updated.MaskedFields)
|
||||
}
|
||||
reloaded, err := tasks.FindByID(ctx, created.ID)
|
||||
if err != nil {
|
||||
t.Fatalf("FindByID returned error: %v", err)
|
||||
}
|
||||
if reloaded.DBPasswordCiphertext != originalCiphertext {
|
||||
t.Fatalf("expected ciphertext unchanged")
|
||||
}
|
||||
}
|
||||
109
server/internal/service/dashboard_notification_service_test.go
Normal file
109
server/internal/service/dashboard_notification_service_test.go
Normal file
@@ -0,0 +1,109 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
"path/filepath"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"backupx/server/internal/config"
|
||||
"backupx/server/internal/database"
|
||||
"backupx/server/internal/logger"
|
||||
"backupx/server/internal/model"
|
||||
"backupx/server/internal/notify"
|
||||
"backupx/server/internal/repository"
|
||||
"backupx/server/internal/storage/codec"
|
||||
)
|
||||
|
||||
type fakeNotifier struct {
|
||||
typeName string
|
||||
messages []notify.Message
|
||||
lastConfig map[string]any
|
||||
}
|
||||
|
||||
func (n *fakeNotifier) Type() string { return n.typeName }
|
||||
func (n *fakeNotifier) SensitiveFields() []string { return []string{"secret"} }
|
||||
func (n *fakeNotifier) Validate(config map[string]any) error {
|
||||
if config["url"] == nil {
|
||||
return nil
|
||||
}
|
||||
return nil
|
||||
}
|
||||
func (n *fakeNotifier) Send(_ context.Context, config map[string]any, message notify.Message) error {
|
||||
n.lastConfig = config
|
||||
n.messages = append(n.messages, message)
|
||||
return nil
|
||||
}
|
||||
|
||||
func newDashboardNotificationTestDeps(t *testing.T) (*DashboardService, *NotificationService, *fakeNotifier, repository.BackupTaskRepository, repository.BackupRecordRepository, repository.NotificationRepository) {
|
||||
t.Helper()
|
||||
log, err := logger.New(config.LogConfig{Level: "error"})
|
||||
if err != nil {
|
||||
t.Fatalf("logger.New returned error: %v", err)
|
||||
}
|
||||
db, err := database.Open(config.DatabaseConfig{Path: filepath.Join(t.TempDir(), "backupx.db")}, log)
|
||||
if err != nil {
|
||||
t.Fatalf("database.Open returned error: %v", err)
|
||||
}
|
||||
tasks := repository.NewBackupTaskRepository(db)
|
||||
records := repository.NewBackupRecordRepository(db)
|
||||
targets := repository.NewStorageTargetRepository(db)
|
||||
notifications := repository.NewNotificationRepository(db)
|
||||
if err := targets.Create(context.Background(), &model.StorageTarget{Name: "local", Type: "local_disk", Enabled: true, ConfigCiphertext: "ciphertext", ConfigVersion: 1, LastTestStatus: "unknown"}); err != nil {
|
||||
t.Fatalf("Create storage target returned error: %v", err)
|
||||
}
|
||||
fake := &fakeNotifier{typeName: "webhook"}
|
||||
registry := notify.NewRegistry(fake)
|
||||
cipher := codec.NewConfigCipher("notify-secret")
|
||||
dashboardService := NewDashboardService(tasks, records, targets)
|
||||
notificationService := NewNotificationService(notifications, registry, cipher)
|
||||
return dashboardService, notificationService, fake, tasks, records, notifications
|
||||
}
|
||||
|
||||
func TestDashboardServiceStats(t *testing.T) {
|
||||
dashboardService, _, _, tasks, records, _ := newDashboardNotificationTestDeps(t)
|
||||
ctx := context.Background()
|
||||
if err := tasks.Create(ctx, &model.BackupTask{Name: "site", Type: "file", Enabled: true, SourcePath: "/srv/site", StorageTargetID: 1, RetentionDays: 30, Compression: "gzip", MaxBackups: 10, LastStatus: "success"}); err != nil {
|
||||
t.Fatalf("Create task returned error: %v", err)
|
||||
}
|
||||
startedAt := time.Now().UTC().Add(-time.Hour)
|
||||
completedAt := time.Now().UTC()
|
||||
if err := records.Create(ctx, &model.BackupRecord{TaskID: 1, StorageTargetID: 1, Status: "success", FileName: "site.tar.gz", FileSize: 2048, StoragePath: "site/2026/03/07/site.tar.gz", DurationSeconds: 30, StartedAt: startedAt, CompletedAt: &completedAt}); err != nil {
|
||||
t.Fatalf("Create record returned error: %v", err)
|
||||
}
|
||||
stats, err := dashboardService.Stats(ctx)
|
||||
if err != nil {
|
||||
t.Fatalf("Stats returned error: %v", err)
|
||||
}
|
||||
if stats.TotalTasks != 1 || stats.TotalRecords != 1 || stats.TotalBackupBytes != 2048 {
|
||||
t.Fatalf("unexpected stats: %#v", stats)
|
||||
}
|
||||
if len(stats.RecentRecords) != 1 || len(stats.StorageUsage) != 1 {
|
||||
t.Fatalf("expected recent records and storage usage, got %#v", stats)
|
||||
}
|
||||
}
|
||||
|
||||
func TestNotificationServiceCreateAndDispatch(t *testing.T) {
|
||||
_, notificationService, fake, _, _, notifications := newDashboardNotificationTestDeps(t)
|
||||
ctx := context.Background()
|
||||
created, err := notificationService.Create(ctx, NotificationUpsertInput{Name: "ops", Type: "webhook", Enabled: true, OnSuccess: true, OnFailure: true, Config: map[string]any{"url": "https://example.invalid", "secret": "top-secret"}})
|
||||
if err != nil {
|
||||
t.Fatalf("Create returned error: %v", err)
|
||||
}
|
||||
if len(created.MaskedFields) != 1 || created.MaskedFields[0] != "secret" {
|
||||
t.Fatalf("unexpected masked fields: %#v", created.MaskedFields)
|
||||
}
|
||||
item, err := notifications.FindByID(ctx, created.ID)
|
||||
if err != nil {
|
||||
t.Fatalf("FindByID returned error: %v", err)
|
||||
}
|
||||
if item == nil || item.ConfigCiphertext == "" {
|
||||
t.Fatalf("expected encrypted notification config")
|
||||
}
|
||||
if err := notificationService.NotifyBackupResult(ctx, BackupExecutionNotification{Task: &model.BackupTask{Name: "site"}, Record: &model.BackupRecord{ID: 1, Status: "success", StartedAt: time.Now().UTC()}, Error: nil}); err != nil {
|
||||
t.Fatalf("NotifyBackupResult returned error: %v", err)
|
||||
}
|
||||
if len(fake.messages) != 1 {
|
||||
t.Fatalf("expected one notification message, got %d", len(fake.messages))
|
||||
}
|
||||
}
|
||||
109
server/internal/service/dashboard_service.go
Normal file
109
server/internal/service/dashboard_service.go
Normal file
@@ -0,0 +1,109 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
"time"
|
||||
|
||||
"backupx/server/internal/apperror"
|
||||
"backupx/server/internal/repository"
|
||||
)
|
||||
|
||||
type DashboardStorageUsageItem struct {
|
||||
StorageTargetID uint `json:"storageTargetId"`
|
||||
TargetName string `json:"targetName"`
|
||||
TotalSize int64 `json:"totalSize"`
|
||||
}
|
||||
|
||||
type DashboardStats struct {
|
||||
TotalTasks int64 `json:"totalTasks"`
|
||||
EnabledTasks int64 `json:"enabledTasks"`
|
||||
TotalRecords int64 `json:"totalRecords"`
|
||||
SuccessRate float64 `json:"successRate"`
|
||||
TotalBackupBytes int64 `json:"totalBackupBytes"`
|
||||
LastBackupAt *time.Time `json:"lastBackupAt,omitempty"`
|
||||
RecentRecords []BackupRecordSummary `json:"recentRecords"`
|
||||
StorageUsage []DashboardStorageUsageItem `json:"storageUsage"`
|
||||
}
|
||||
|
||||
type DashboardService struct {
|
||||
tasks repository.BackupTaskRepository
|
||||
records repository.BackupRecordRepository
|
||||
targets repository.StorageTargetRepository
|
||||
}
|
||||
|
||||
func NewDashboardService(tasks repository.BackupTaskRepository, records repository.BackupRecordRepository, targets repository.StorageTargetRepository) *DashboardService {
|
||||
return &DashboardService{tasks: tasks, records: records, targets: targets}
|
||||
}
|
||||
|
||||
func (s *DashboardService) Stats(ctx context.Context) (*DashboardStats, error) {
|
||||
totalTasks, err := s.tasks.Count(ctx)
|
||||
if err != nil {
|
||||
return nil, apperror.Internal("DASHBOARD_STATS_FAILED", "无法统计备份任务数量", err)
|
||||
}
|
||||
enabledTasks, err := s.tasks.CountEnabled(ctx)
|
||||
if err != nil {
|
||||
return nil, apperror.Internal("DASHBOARD_STATS_FAILED", "无法统计启用任务数量", err)
|
||||
}
|
||||
totalRecords, err := s.records.Count(ctx)
|
||||
if err != nil {
|
||||
return nil, apperror.Internal("DASHBOARD_STATS_FAILED", "无法统计备份记录数量", err)
|
||||
}
|
||||
since := time.Now().UTC().AddDate(0, 0, -30)
|
||||
recentRecordsCount, err := s.records.CountSince(ctx, since)
|
||||
if err != nil {
|
||||
return nil, apperror.Internal("DASHBOARD_STATS_FAILED", "无法统计最近记录数量", err)
|
||||
}
|
||||
successRecordsCount, err := s.records.CountSuccessSince(ctx, since)
|
||||
if err != nil {
|
||||
return nil, apperror.Internal("DASHBOARD_STATS_FAILED", "无法统计最近成功记录数量", err)
|
||||
}
|
||||
totalBackupBytes, err := s.records.SumFileSize(ctx)
|
||||
if err != nil {
|
||||
return nil, apperror.Internal("DASHBOARD_STATS_FAILED", "无法统计备份总量", err)
|
||||
}
|
||||
recentRecords, err := s.records.ListRecent(ctx, 10)
|
||||
if err != nil {
|
||||
return nil, apperror.Internal("DASHBOARD_STATS_FAILED", "无法获取最近备份记录", err)
|
||||
}
|
||||
targetList, err := s.targets.List(ctx)
|
||||
if err != nil {
|
||||
return nil, apperror.Internal("DASHBOARD_STATS_FAILED", "无法获取存储目标信息", err)
|
||||
}
|
||||
targetNames := make(map[uint]string, len(targetList))
|
||||
for _, item := range targetList {
|
||||
targetNames[item.ID] = item.Name
|
||||
}
|
||||
usageItems, err := s.records.StorageUsage(ctx)
|
||||
if err != nil {
|
||||
return nil, apperror.Internal("DASHBOARD_STATS_FAILED", "无法统计存储使用量", err)
|
||||
}
|
||||
storageUsage := make([]DashboardStorageUsageItem, 0, len(usageItems))
|
||||
for _, item := range usageItems {
|
||||
storageUsage = append(storageUsage, DashboardStorageUsageItem{StorageTargetID: item.StorageTargetID, TargetName: targetNames[item.StorageTargetID], TotalSize: item.TotalSize})
|
||||
}
|
||||
result := &DashboardStats{TotalTasks: totalTasks, EnabledTasks: enabledTasks, TotalRecords: totalRecords, TotalBackupBytes: totalBackupBytes, RecentRecords: make([]BackupRecordSummary, 0, len(recentRecords)), StorageUsage: storageUsage}
|
||||
if recentRecordsCount > 0 {
|
||||
result.SuccessRate = float64(successRecordsCount) / float64(recentRecordsCount)
|
||||
}
|
||||
if len(recentRecords) > 0 {
|
||||
result.LastBackupAt = &recentRecords[0].StartedAt
|
||||
}
|
||||
for _, item := range recentRecords {
|
||||
result.RecentRecords = append(result.RecentRecords, toBackupRecordSummary(&item))
|
||||
}
|
||||
return result, nil
|
||||
}
|
||||
|
||||
func (s *DashboardService) Timeline(ctx context.Context, days int) ([]repository.BackupTimelinePoint, error) {
|
||||
if days <= 0 {
|
||||
days = 30
|
||||
}
|
||||
items, err := s.records.TimelineSince(ctx, time.Now().UTC().AddDate(0, 0, -days))
|
||||
if err != nil {
|
||||
return nil, apperror.Internal("DASHBOARD_TIMELINE_FAILED", "无法获取备份时间线", err)
|
||||
}
|
||||
if items == nil {
|
||||
items = []repository.BackupTimelinePoint{}
|
||||
}
|
||||
return items, nil
|
||||
}
|
||||
123
server/internal/service/google_drive_oauth_service.go
Normal file
123
server/internal/service/google_drive_oauth_service.go
Normal file
@@ -0,0 +1,123 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"golang.org/x/oauth2"
|
||||
"golang.org/x/oauth2/google"
|
||||
"google.golang.org/api/drive/v3"
|
||||
|
||||
"backupx/server/internal/model"
|
||||
"backupx/server/internal/repository"
|
||||
"backupx/server/internal/security"
|
||||
"backupx/server/internal/storage"
|
||||
"backupx/server/internal/storage/codec"
|
||||
)
|
||||
|
||||
type GoogleDriveOAuthResult struct {
|
||||
TargetID *uint
|
||||
Config storage.GoogleDriveConfig
|
||||
State string
|
||||
}
|
||||
|
||||
type GoogleDriveOAuthService struct {
|
||||
sessions repository.OAuthSessionRepository
|
||||
cipher *codec.Cipher
|
||||
now func() time.Time
|
||||
generateState func() (string, error)
|
||||
exchangeCode func(context.Context, *oauth2.Config, string) (*oauth2.Token, error)
|
||||
ttl time.Duration
|
||||
}
|
||||
|
||||
type googleDriveOAuthPayload struct {
|
||||
TargetID *uint `json:"targetId,omitempty"`
|
||||
Config storage.GoogleDriveConfig `json:"config"`
|
||||
}
|
||||
|
||||
func NewGoogleDriveOAuthService(sessions repository.OAuthSessionRepository, cipher *codec.Cipher) *GoogleDriveOAuthService {
|
||||
return &GoogleDriveOAuthService{
|
||||
sessions: sessions,
|
||||
cipher: cipher,
|
||||
now: func() time.Time { return time.Now().UTC() },
|
||||
generateState: func() (string, error) {
|
||||
return security.GenerateSecret(24)
|
||||
},
|
||||
exchangeCode: func(ctx context.Context, config *oauth2.Config, code string) (*oauth2.Token, error) {
|
||||
return config.Exchange(ctx, code)
|
||||
},
|
||||
ttl: 10 * time.Minute,
|
||||
}
|
||||
}
|
||||
|
||||
func (s *GoogleDriveOAuthService) Start(ctx context.Context, targetID *uint, cfg storage.GoogleDriveConfig) (string, string, error) {
|
||||
if strings.TrimSpace(cfg.ClientID) == "" || strings.TrimSpace(cfg.ClientSecret) == "" {
|
||||
return "", "", fmt.Errorf("google drive client credentials are required")
|
||||
}
|
||||
if strings.TrimSpace(cfg.RedirectURL) == "" {
|
||||
return "", "", fmt.Errorf("google drive redirect url is required")
|
||||
}
|
||||
state, err := s.generateState()
|
||||
if err != nil {
|
||||
return "", "", fmt.Errorf("generate oauth state: %w", err)
|
||||
}
|
||||
payload := googleDriveOAuthPayload{TargetID: targetID, Config: cfg}
|
||||
ciphertext, err := s.cipher.EncryptValue(payload)
|
||||
if err != nil {
|
||||
return "", "", fmt.Errorf("encrypt oauth payload: %w", err)
|
||||
}
|
||||
now := s.now()
|
||||
session := &model.OAuthSession{ProviderType: string(storage.ProviderTypeGoogleDrive), State: state, PayloadCiphertext: ciphertext, TargetID: targetID, ExpiresAt: now.Add(s.ttl)}
|
||||
if err := s.sessions.Create(ctx, session); err != nil {
|
||||
return "", "", fmt.Errorf("create oauth session: %w", err)
|
||||
}
|
||||
oauthConfig := s.oauthConfig(cfg)
|
||||
url := oauthConfig.AuthCodeURL(state, oauth2.AccessTypeOffline, oauth2.ApprovalForce)
|
||||
return url, state, nil
|
||||
}
|
||||
|
||||
func (s *GoogleDriveOAuthService) Complete(ctx context.Context, state string, code string) (*GoogleDriveOAuthResult, error) {
|
||||
session, err := s.sessions.FindByState(ctx, state)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("find oauth session: %w", err)
|
||||
}
|
||||
if session == nil {
|
||||
return nil, fmt.Errorf("google drive oauth state not found")
|
||||
}
|
||||
now := s.now()
|
||||
if session.UsedAt != nil {
|
||||
return nil, fmt.Errorf("google drive oauth state already used")
|
||||
}
|
||||
if now.After(session.ExpiresAt) {
|
||||
return nil, fmt.Errorf("google drive oauth state expired")
|
||||
}
|
||||
var payload googleDriveOAuthPayload
|
||||
if err := s.cipher.DecryptValue(session.PayloadCiphertext, &payload); err != nil {
|
||||
return nil, fmt.Errorf("decrypt oauth session payload: %w", err)
|
||||
}
|
||||
token, err := s.exchangeCode(ctx, s.oauthConfig(payload.Config), code)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("exchange google drive oauth code: %w", err)
|
||||
}
|
||||
if strings.TrimSpace(token.RefreshToken) == "" {
|
||||
return nil, fmt.Errorf("google drive oauth response missing refresh token")
|
||||
}
|
||||
payload.Config.RefreshToken = token.RefreshToken
|
||||
session.UsedAt = &now
|
||||
if err := s.sessions.Update(ctx, session); err != nil {
|
||||
return nil, fmt.Errorf("mark oauth session used: %w", err)
|
||||
}
|
||||
return &GoogleDriveOAuthResult{TargetID: payload.TargetID, Config: payload.Config, State: state}, nil
|
||||
}
|
||||
|
||||
func (s *GoogleDriveOAuthService) oauthConfig(cfg storage.GoogleDriveConfig) *oauth2.Config {
|
||||
return &oauth2.Config{
|
||||
ClientID: cfg.ClientID,
|
||||
ClientSecret: cfg.ClientSecret,
|
||||
RedirectURL: cfg.RedirectURL,
|
||||
Endpoint: google.Endpoint,
|
||||
Scopes: []string{drive.DriveScope},
|
||||
}
|
||||
}
|
||||
61
server/internal/service/google_drive_oauth_service_test.go
Normal file
61
server/internal/service/google_drive_oauth_service_test.go
Normal file
@@ -0,0 +1,61 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"golang.org/x/oauth2"
|
||||
|
||||
"backupx/server/internal/config"
|
||||
"backupx/server/internal/database"
|
||||
"backupx/server/internal/logger"
|
||||
"backupx/server/internal/repository"
|
||||
"backupx/server/internal/storage"
|
||||
"backupx/server/internal/storage/codec"
|
||||
)
|
||||
|
||||
func TestGoogleDriveOAuthServiceStartAndComplete(t *testing.T) {
|
||||
tempDir := t.TempDir()
|
||||
log, err := logger.New(config.LogConfig{Level: "error"})
|
||||
if err != nil {
|
||||
t.Fatalf("logger.New returned error: %v", err)
|
||||
}
|
||||
db, err := database.Open(config.DatabaseConfig{Path: filepath.Join(tempDir, "backupx.db")}, log)
|
||||
if err != nil {
|
||||
t.Fatalf("database.Open returned error: %v", err)
|
||||
}
|
||||
sessions := repository.NewOAuthSessionRepository(db)
|
||||
service := NewGoogleDriveOAuthService(sessions, codec.New("encryption-secret"))
|
||||
service.now = func() time.Time { return time.Date(2026, 3, 7, 0, 0, 0, 0, time.UTC) }
|
||||
service.generateState = func() (string, error) { return "oauth-state", nil }
|
||||
service.exchangeCode = func(context.Context, *oauth2.Config, string) (*oauth2.Token, error) {
|
||||
return &oauth2.Token{RefreshToken: "refresh-token"}, nil
|
||||
}
|
||||
|
||||
url, state, err := service.Start(context.Background(), nil, storage.GoogleDriveConfig{
|
||||
ClientID: "client-id",
|
||||
ClientSecret: "client-secret",
|
||||
RedirectURL: "http://localhost:8340/api/storage-targets/google-drive/callback",
|
||||
FolderID: "folder-id",
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("Start returned error: %v", err)
|
||||
}
|
||||
if state != "oauth-state" {
|
||||
t.Fatalf("expected deterministic state, got %s", state)
|
||||
}
|
||||
if !strings.Contains(url, "oauth-state") {
|
||||
t.Fatalf("expected auth url to contain state, got %s", url)
|
||||
}
|
||||
|
||||
result, err := service.Complete(context.Background(), state, "auth-code")
|
||||
if err != nil {
|
||||
t.Fatalf("Complete returned error: %v", err)
|
||||
}
|
||||
if result.Config.RefreshToken != "refresh-token" {
|
||||
t.Fatalf("expected refresh token to be persisted")
|
||||
}
|
||||
}
|
||||
234
server/internal/service/node_service.go
Normal file
234
server/internal/service/node_service.go
Normal file
@@ -0,0 +1,234 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/rand"
|
||||
"encoding/hex"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"runtime"
|
||||
"sort"
|
||||
"time"
|
||||
|
||||
"backupx/server/internal/apperror"
|
||||
"backupx/server/internal/model"
|
||||
"backupx/server/internal/repository"
|
||||
)
|
||||
|
||||
// NodeSummary is the API response for node listings.
|
||||
type NodeSummary struct {
|
||||
ID uint `json:"id"`
|
||||
Name string `json:"name"`
|
||||
Hostname string `json:"hostname"`
|
||||
IPAddress string `json:"ipAddress"`
|
||||
Status string `json:"status"`
|
||||
IsLocal bool `json:"isLocal"`
|
||||
OS string `json:"os"`
|
||||
Arch string `json:"arch"`
|
||||
AgentVersion string `json:"agentVersion"`
|
||||
LastSeen time.Time `json:"lastSeen"`
|
||||
CreatedAt time.Time `json:"createdAt"`
|
||||
}
|
||||
|
||||
// NodeCreateInput is the input for creating a new remote node.
|
||||
type NodeCreateInput struct {
|
||||
Name string `json:"name" binding:"required"`
|
||||
}
|
||||
|
||||
// NodeService manages the cluster nodes.
|
||||
type NodeService struct {
|
||||
repo repository.NodeRepository
|
||||
}
|
||||
|
||||
func NewNodeService(repo repository.NodeRepository) *NodeService {
|
||||
return &NodeService{repo: repo}
|
||||
}
|
||||
|
||||
// EnsureLocalNode creates the default "local" node if it does not exist.
|
||||
func (s *NodeService) EnsureLocalNode(ctx context.Context) error {
|
||||
existing, err := s.repo.FindLocal(ctx)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if existing != nil {
|
||||
existing.Status = model.NodeStatusOnline
|
||||
existing.LastSeen = time.Now().UTC()
|
||||
hostname, _ := os.Hostname()
|
||||
existing.Hostname = hostname
|
||||
existing.OS = runtime.GOOS
|
||||
existing.Arch = runtime.GOARCH
|
||||
return s.repo.Update(ctx, existing)
|
||||
}
|
||||
hostname, _ := os.Hostname()
|
||||
token, _ := generateToken()
|
||||
node := &model.Node{
|
||||
Name: "本机 (Local)",
|
||||
Hostname: hostname,
|
||||
Token: token,
|
||||
Status: model.NodeStatusOnline,
|
||||
IsLocal: true,
|
||||
OS: runtime.GOOS,
|
||||
Arch: runtime.GOARCH,
|
||||
LastSeen: time.Now().UTC(),
|
||||
}
|
||||
return s.repo.Create(ctx, node)
|
||||
}
|
||||
|
||||
func (s *NodeService) List(ctx context.Context) ([]NodeSummary, error) {
|
||||
nodes, err := s.repo.List(ctx)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
result := make([]NodeSummary, len(nodes))
|
||||
for i, n := range nodes {
|
||||
result[i] = NodeSummary{
|
||||
ID: n.ID,
|
||||
Name: n.Name,
|
||||
Hostname: n.Hostname,
|
||||
IPAddress: n.IPAddress,
|
||||
Status: n.Status,
|
||||
IsLocal: n.IsLocal,
|
||||
OS: n.OS,
|
||||
Arch: n.Arch,
|
||||
AgentVersion: n.AgentVer,
|
||||
LastSeen: n.LastSeen,
|
||||
CreatedAt: n.CreatedAt,
|
||||
}
|
||||
}
|
||||
return result, nil
|
||||
}
|
||||
|
||||
func (s *NodeService) Get(ctx context.Context, id uint) (*NodeSummary, error) {
|
||||
node, err := s.repo.FindByID(ctx, id)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if node == nil {
|
||||
return nil, apperror.New(http.StatusNotFound, "NODE_NOT_FOUND", "节点不存在", nil)
|
||||
}
|
||||
return &NodeSummary{
|
||||
ID: node.ID,
|
||||
Name: node.Name,
|
||||
Hostname: node.Hostname,
|
||||
IPAddress: node.IPAddress,
|
||||
Status: node.Status,
|
||||
IsLocal: node.IsLocal,
|
||||
OS: node.OS,
|
||||
Arch: node.Arch,
|
||||
AgentVersion: node.AgentVer,
|
||||
LastSeen: node.LastSeen,
|
||||
CreatedAt: node.CreatedAt,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// Create registers a new remote node and returns its authentication token.
|
||||
func (s *NodeService) Create(ctx context.Context, input NodeCreateInput) (string, error) {
|
||||
token, err := generateToken()
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("generate token: %w", err)
|
||||
}
|
||||
node := &model.Node{
|
||||
Name: input.Name,
|
||||
Token: token,
|
||||
Status: model.NodeStatusOffline,
|
||||
IsLocal: false,
|
||||
LastSeen: time.Now().UTC(),
|
||||
}
|
||||
if err := s.repo.Create(ctx, node); err != nil {
|
||||
return "", err
|
||||
}
|
||||
return token, nil
|
||||
}
|
||||
|
||||
func (s *NodeService) Delete(ctx context.Context, id uint) error {
|
||||
node, err := s.repo.FindByID(ctx, id)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if node == nil {
|
||||
return apperror.New(http.StatusNotFound, "NODE_NOT_FOUND", "节点不存在", nil)
|
||||
}
|
||||
if node.IsLocal {
|
||||
return apperror.BadRequest("NODE_DELETE_LOCAL", "无法删除本机节点", nil)
|
||||
}
|
||||
return s.repo.Delete(ctx, id)
|
||||
}
|
||||
|
||||
// ListDirectory lists the contents of a directory on the local node.
|
||||
func (s *NodeService) ListDirectory(ctx context.Context, nodeID uint, path string) ([]DirEntry, error) {
|
||||
node, err := s.repo.FindByID(ctx, nodeID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if node == nil {
|
||||
return nil, apperror.New(http.StatusNotFound, "NODE_NOT_FOUND", "节点不存在", nil)
|
||||
}
|
||||
if !node.IsLocal {
|
||||
return nil, apperror.BadRequest("NODE_REMOTE_FS_NOT_SUPPORTED", "远程节点的目录浏览需要 Agent 在线连接(即将支持)", nil)
|
||||
}
|
||||
|
||||
cleanPath := filepath.Clean(path)
|
||||
entries, err := os.ReadDir(cleanPath)
|
||||
if err != nil {
|
||||
return nil, apperror.BadRequest("NODE_FS_READ_ERROR", fmt.Sprintf("无法读取目录: %s", err.Error()), err)
|
||||
}
|
||||
|
||||
result := make([]DirEntry, 0, len(entries))
|
||||
for _, entry := range entries {
|
||||
info, _ := entry.Info()
|
||||
size := int64(0)
|
||||
if info != nil {
|
||||
size = info.Size()
|
||||
}
|
||||
result = append(result, DirEntry{
|
||||
Name: entry.Name(),
|
||||
Path: filepath.Join(cleanPath, entry.Name()),
|
||||
IsDir: entry.IsDir(),
|
||||
Size: size,
|
||||
})
|
||||
}
|
||||
sort.Slice(result, func(i, j int) bool {
|
||||
if result[i].IsDir != result[j].IsDir {
|
||||
return result[i].IsDir
|
||||
}
|
||||
return result[i].Name < result[j].Name
|
||||
})
|
||||
return result, nil
|
||||
}
|
||||
|
||||
// Heartbeat updates the node status when an agent reports in.
|
||||
func (s *NodeService) Heartbeat(ctx context.Context, token string, hostname string, ip string, agentVer string) error {
|
||||
node, err := s.repo.FindByToken(ctx, token)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if node == nil {
|
||||
return apperror.Unauthorized("NODE_INVALID_TOKEN", "无效的节点认证令牌", nil)
|
||||
}
|
||||
node.Status = model.NodeStatusOnline
|
||||
node.Hostname = hostname
|
||||
node.IPAddress = ip
|
||||
node.AgentVer = agentVer
|
||||
node.OS = runtime.GOOS
|
||||
node.Arch = runtime.GOARCH
|
||||
node.LastSeen = time.Now().UTC()
|
||||
return s.repo.Update(ctx, node)
|
||||
}
|
||||
|
||||
// DirEntry represents a file or directory in a node's file system.
|
||||
type DirEntry struct {
|
||||
Name string `json:"name"`
|
||||
Path string `json:"path"`
|
||||
IsDir bool `json:"isDir"`
|
||||
Size int64 `json:"size"`
|
||||
}
|
||||
|
||||
func generateToken() (string, error) {
|
||||
b := make([]byte, 32)
|
||||
if _, err := rand.Read(b); err != nil {
|
||||
return "", err
|
||||
}
|
||||
return hex.EncodeToString(b), nil
|
||||
}
|
||||
251
server/internal/service/notification_service.go
Normal file
251
server/internal/service/notification_service.go
Normal file
@@ -0,0 +1,251 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"backupx/server/internal/apperror"
|
||||
"backupx/server/internal/model"
|
||||
"backupx/server/internal/notify"
|
||||
"backupx/server/internal/repository"
|
||||
"backupx/server/internal/storage/codec"
|
||||
)
|
||||
|
||||
type NotificationUpsertInput struct {
|
||||
Name string `json:"name" binding:"required,min=1,max=100"`
|
||||
Type string `json:"type" binding:"required,oneof=email webhook telegram"`
|
||||
Enabled bool `json:"enabled"`
|
||||
OnSuccess bool `json:"onSuccess"`
|
||||
OnFailure bool `json:"onFailure"`
|
||||
Config map[string]any `json:"config" binding:"required"`
|
||||
}
|
||||
|
||||
type NotificationSummary struct {
|
||||
ID uint `json:"id"`
|
||||
Name string `json:"name"`
|
||||
Type string `json:"type"`
|
||||
Enabled bool `json:"enabled"`
|
||||
OnSuccess bool `json:"onSuccess"`
|
||||
OnFailure bool `json:"onFailure"`
|
||||
UpdatedAt time.Time `json:"updatedAt"`
|
||||
}
|
||||
|
||||
type NotificationDetail struct {
|
||||
NotificationSummary
|
||||
Config map[string]any `json:"config"`
|
||||
MaskedFields []string `json:"maskedFields,omitempty"`
|
||||
}
|
||||
|
||||
type NotificationService struct {
|
||||
notifications repository.NotificationRepository
|
||||
registry *notify.Registry
|
||||
cipher *codec.ConfigCipher
|
||||
}
|
||||
|
||||
func NewNotificationService(notifications repository.NotificationRepository, registry *notify.Registry, cipher *codec.ConfigCipher) *NotificationService {
|
||||
return &NotificationService{notifications: notifications, registry: registry, cipher: cipher}
|
||||
}
|
||||
|
||||
func (s *NotificationService) List(ctx context.Context) ([]NotificationSummary, error) {
|
||||
items, err := s.notifications.List(ctx)
|
||||
if err != nil {
|
||||
return nil, apperror.Internal("NOTIFICATION_LIST_FAILED", "无法获取通知配置列表", err)
|
||||
}
|
||||
result := make([]NotificationSummary, 0, len(items))
|
||||
for _, item := range items {
|
||||
result = append(result, toNotificationSummary(&item))
|
||||
}
|
||||
return result, nil
|
||||
}
|
||||
|
||||
func (s *NotificationService) Get(ctx context.Context, id uint) (*NotificationDetail, error) {
|
||||
item, err := s.notifications.FindByID(ctx, id)
|
||||
if err != nil {
|
||||
return nil, apperror.Internal("NOTIFICATION_GET_FAILED", "无法获取通知配置详情", err)
|
||||
}
|
||||
if item == nil {
|
||||
return nil, apperror.New(http.StatusNotFound, "NOTIFICATION_NOT_FOUND", "通知配置不存在", fmt.Errorf("notification %d not found", id))
|
||||
}
|
||||
return s.toDetail(item)
|
||||
}
|
||||
|
||||
func (s *NotificationService) Create(ctx context.Context, input NotificationUpsertInput) (*NotificationDetail, error) {
|
||||
if err := s.validateInput(ctx, 0, input); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
item, err := s.buildNotification(nil, input)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if err := s.notifications.Create(ctx, item); err != nil {
|
||||
return nil, apperror.Internal("NOTIFICATION_CREATE_FAILED", "无法创建通知配置", err)
|
||||
}
|
||||
return s.Get(ctx, item.ID)
|
||||
}
|
||||
|
||||
func (s *NotificationService) Update(ctx context.Context, id uint, input NotificationUpsertInput) (*NotificationDetail, error) {
|
||||
existing, err := s.notifications.FindByID(ctx, id)
|
||||
if err != nil {
|
||||
return nil, apperror.Internal("NOTIFICATION_GET_FAILED", "无法获取通知配置详情", err)
|
||||
}
|
||||
if existing == nil {
|
||||
return nil, apperror.New(http.StatusNotFound, "NOTIFICATION_NOT_FOUND", "通知配置不存在", fmt.Errorf("notification %d not found", id))
|
||||
}
|
||||
if err := s.validateInput(ctx, existing.ID, input); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
item, err := s.buildNotification(existing, input)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
item.ID = existing.ID
|
||||
item.CreatedAt = existing.CreatedAt
|
||||
if err := s.notifications.Update(ctx, item); err != nil {
|
||||
return nil, apperror.Internal("NOTIFICATION_UPDATE_FAILED", "无法更新通知配置", err)
|
||||
}
|
||||
return s.Get(ctx, id)
|
||||
}
|
||||
|
||||
func (s *NotificationService) Delete(ctx context.Context, id uint) error {
|
||||
item, err := s.notifications.FindByID(ctx, id)
|
||||
if err != nil {
|
||||
return apperror.Internal("NOTIFICATION_GET_FAILED", "无法获取通知配置详情", err)
|
||||
}
|
||||
if item == nil {
|
||||
return apperror.New(http.StatusNotFound, "NOTIFICATION_NOT_FOUND", "通知配置不存在", fmt.Errorf("notification %d not found", id))
|
||||
}
|
||||
if err := s.notifications.Delete(ctx, id); err != nil {
|
||||
return apperror.Internal("NOTIFICATION_DELETE_FAILED", "无法删除通知配置", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *NotificationService) Test(ctx context.Context, input NotificationUpsertInput) error {
|
||||
if err := s.registry.Validate(strings.TrimSpace(input.Type), input.Config); err != nil {
|
||||
return apperror.BadRequest("NOTIFICATION_INVALID", "通知配置不合法", err)
|
||||
}
|
||||
message := notify.Message{Title: "BackupX 通知测试", Body: "这是一条来自 BackupX 的测试通知。", Fields: map[string]any{"type": input.Type, "timestamp": time.Now().UTC().Format(time.RFC3339)}}
|
||||
if err := s.registry.Send(ctx, input.Type, input.Config, message); err != nil {
|
||||
return apperror.BadRequest("NOTIFICATION_TEST_FAILED", "发送测试通知失败", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *NotificationService) TestSaved(ctx context.Context, id uint) error {
|
||||
item, err := s.notifications.FindByID(ctx, id)
|
||||
if err != nil {
|
||||
return apperror.Internal("NOTIFICATION_GET_FAILED", "无法获取通知配置", err)
|
||||
}
|
||||
if item == nil {
|
||||
return apperror.New(http.StatusNotFound, "NOTIFICATION_NOT_FOUND", "通知配置不存在", fmt.Errorf("notification %d not found", id))
|
||||
}
|
||||
configMap := map[string]any{}
|
||||
if err := s.cipher.DecryptJSON(item.ConfigCiphertext, &configMap); err != nil {
|
||||
return apperror.Internal("NOTIFICATION_DECRYPT_FAILED", "无法读取通知配置", err)
|
||||
}
|
||||
message := notify.Message{Title: "BackupX 通知测试", Body: "这是一条来自 BackupX 的测试通知。", Fields: map[string]any{"type": item.Type, "timestamp": time.Now().UTC().Format(time.RFC3339)}}
|
||||
if err := s.registry.Send(ctx, item.Type, configMap, message); err != nil {
|
||||
return apperror.BadRequest("NOTIFICATION_TEST_FAILED", "发送测试通知失败", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *NotificationService) NotifyBackupResult(ctx context.Context, event BackupExecutionNotification) error {
|
||||
success := event.Error == nil && event.Record != nil && event.Record.Status == "success"
|
||||
items, err := s.notifications.ListEnabledForEvent(ctx, success)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
message := buildNotificationMessage(event)
|
||||
var joined error
|
||||
for _, item := range items {
|
||||
configMap := map[string]any{}
|
||||
if err := s.cipher.DecryptJSON(item.ConfigCiphertext, &configMap); err != nil {
|
||||
joined = errors.Join(joined, fmt.Errorf("decrypt notification %d config: %w", item.ID, err))
|
||||
continue
|
||||
}
|
||||
if err := s.registry.Send(ctx, item.Type, configMap, message); err != nil {
|
||||
joined = errors.Join(joined, fmt.Errorf("send notification %s failed: %w", item.Name, err))
|
||||
}
|
||||
}
|
||||
return joined
|
||||
}
|
||||
|
||||
func (s *NotificationService) validateInput(ctx context.Context, currentID uint, input NotificationUpsertInput) error {
|
||||
existing, err := s.notifications.FindByName(ctx, strings.TrimSpace(input.Name))
|
||||
if err != nil {
|
||||
return apperror.Internal("NOTIFICATION_LOOKUP_FAILED", "无法检查通知配置名称", err)
|
||||
}
|
||||
if existing != nil && existing.ID != currentID {
|
||||
return apperror.Conflict("NOTIFICATION_NAME_EXISTS", "通知配置名称已存在", nil)
|
||||
}
|
||||
if err := s.registry.Validate(strings.TrimSpace(input.Type), input.Config); err != nil {
|
||||
return apperror.BadRequest("NOTIFICATION_INVALID", "通知配置不合法", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *NotificationService) buildNotification(existing *model.Notification, input NotificationUpsertInput) (*model.Notification, error) {
|
||||
configMap := input.Config
|
||||
if existing != nil {
|
||||
currentConfig := map[string]any{}
|
||||
if err := s.cipher.DecryptJSON(existing.ConfigCiphertext, ¤tConfig); err != nil {
|
||||
return nil, apperror.Internal("NOTIFICATION_DECRYPT_FAILED", "无法读取现有通知配置", err)
|
||||
}
|
||||
configMap = codec.MergeMaskedConfig(input.Config, currentConfig, s.registry.SensitiveFields(input.Type))
|
||||
}
|
||||
ciphertext, err := s.cipher.EncryptJSON(configMap)
|
||||
if err != nil {
|
||||
return nil, apperror.Internal("NOTIFICATION_ENCRYPT_FAILED", "无法保存通知配置", err)
|
||||
}
|
||||
item := &model.Notification{Name: strings.TrimSpace(input.Name), Type: strings.TrimSpace(input.Type), ConfigCiphertext: ciphertext, Enabled: input.Enabled, OnSuccess: input.OnSuccess, OnFailure: input.OnFailure}
|
||||
return item, nil
|
||||
}
|
||||
|
||||
func (s *NotificationService) toDetail(item *model.Notification) (*NotificationDetail, error) {
|
||||
configMap := map[string]any{}
|
||||
if err := s.cipher.DecryptJSON(item.ConfigCiphertext, &configMap); err != nil {
|
||||
return nil, apperror.Internal("NOTIFICATION_DECRYPT_FAILED", "无法读取通知配置", err)
|
||||
}
|
||||
sensitiveFields := s.registry.SensitiveFields(item.Type)
|
||||
return &NotificationDetail{NotificationSummary: toNotificationSummary(item), Config: codec.MaskConfig(configMap, sensitiveFields), MaskedFields: sensitiveFields}, nil
|
||||
}
|
||||
|
||||
func toNotificationSummary(item *model.Notification) NotificationSummary {
|
||||
return NotificationSummary{ID: item.ID, Name: item.Name, Type: item.Type, Enabled: item.Enabled, OnSuccess: item.OnSuccess, OnFailure: item.OnFailure, UpdatedAt: item.UpdatedAt}
|
||||
}
|
||||
|
||||
func buildNotificationMessage(event BackupExecutionNotification) notify.Message {
|
||||
statusText := "失败"
|
||||
if event.Error == nil && event.Record != nil && event.Record.Status == "success" {
|
||||
statusText = "成功"
|
||||
}
|
||||
taskName := "未知任务"
|
||||
if event.Task != nil {
|
||||
taskName = event.Task.Name
|
||||
}
|
||||
body := fmt.Sprintf("任务:%s\n状态:%s", taskName, statusText)
|
||||
fields := map[string]any{"taskName": taskName, "status": statusText}
|
||||
if event.Record != nil {
|
||||
body += fmt.Sprintf("\n开始时间:%s\n耗时:%d 秒", event.Record.StartedAt.Format(time.RFC3339), event.Record.DurationSeconds)
|
||||
fields["recordId"] = event.Record.ID
|
||||
fields["durationSeconds"] = event.Record.DurationSeconds
|
||||
if event.Record.FileName != "" {
|
||||
body += fmt.Sprintf("\n文件:%s", event.Record.FileName)
|
||||
fields["fileName"] = event.Record.FileName
|
||||
}
|
||||
if event.Record.FileSize > 0 {
|
||||
body += fmt.Sprintf("\n大小:%d", event.Record.FileSize)
|
||||
fields["fileSize"] = event.Record.FileSize
|
||||
}
|
||||
if event.Record.ErrorMessage != "" {
|
||||
body += fmt.Sprintf("\n错误:%s", event.Record.ErrorMessage)
|
||||
fields["error"] = event.Record.ErrorMessage
|
||||
}
|
||||
}
|
||||
return notify.Message{Title: "BackupX 备份" + statusText + "通知", Body: body, Fields: fields}
|
||||
}
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user