diff --git a/server/internal/repository/agent_install_token_repository.go b/server/internal/repository/agent_install_token_repository.go new file mode 100644 index 0000000..58b49f3 --- /dev/null +++ b/server/internal/repository/agent_install_token_repository.go @@ -0,0 +1,87 @@ +package repository + +import ( + "context" + "errors" + "time" + + "backupx/server/internal/model" + "gorm.io/gorm" +) + +// AgentInstallTokenRepository 一次性安装令牌仓储。 +type AgentInstallTokenRepository interface { + Create(ctx context.Context, t *model.AgentInstallToken) error + FindByToken(ctx context.Context, token string) (*model.AgentInstallToken, error) + // ConsumeByToken 原子消费:仅当 token 存在、未过期、未消费时成功,返回消费后的记录。 + // 其它情况(不存在/已过期/已消费)一律返回 (nil, nil)。 + ConsumeByToken(ctx context.Context, token string) (*model.AgentInstallToken, error) + // DeleteExpiredBefore 硬删除 ExpiresAt < threshold 的记录。 + DeleteExpiredBefore(ctx context.Context, threshold time.Time) (int64, error) + // CountCreatedSince 统计 node 在 since 之后创建的数量(用于节点级限流)。 + CountCreatedSince(ctx context.Context, nodeID uint, since time.Time) (int64, error) +} + +type GormAgentInstallTokenRepository struct { + db *gorm.DB +} + +func NewAgentInstallTokenRepository(db *gorm.DB) *GormAgentInstallTokenRepository { + return &GormAgentInstallTokenRepository{db: db} +} + +func (r *GormAgentInstallTokenRepository) Create(ctx context.Context, t *model.AgentInstallToken) error { + return r.db.WithContext(ctx).Create(t).Error +} + +func (r *GormAgentInstallTokenRepository) FindByToken(ctx context.Context, token string) (*model.AgentInstallToken, error) { + var item model.AgentInstallToken + if err := r.db.WithContext(ctx).Where("token = ?", token).First(&item).Error; err != nil { + if errors.Is(err, gorm.ErrRecordNotFound) { + return nil, nil + } + return nil, err + } + return &item, nil +} + +// ConsumeByToken 使用条件 UPDATE + RowsAffected 实现原子消费。 +// SQLite 不支持 SELECT FOR UPDATE,但 UPDATE 本身在 SQLite 中是原子的。 +func (r *GormAgentInstallTokenRepository) ConsumeByToken(ctx context.Context, token string) (*model.AgentInstallToken, error) { + var consumed *model.AgentInstallToken + err := r.db.WithContext(ctx).Transaction(func(tx *gorm.DB) error { + now := time.Now().UTC() + result := tx.Model(&model.AgentInstallToken{}). + Where("token = ? AND consumed_at IS NULL AND expires_at > ?", token, now). + Update("consumed_at", &now) + if result.Error != nil { + return result.Error + } + if result.RowsAffected == 0 { + return nil + } + var item model.AgentInstallToken + if err := tx.Where("token = ?", token).First(&item).Error; err != nil { + return err + } + consumed = &item + return nil + }) + if err != nil { + return nil, err + } + return consumed, nil +} + +func (r *GormAgentInstallTokenRepository) DeleteExpiredBefore(ctx context.Context, threshold time.Time) (int64, error) { + result := r.db.WithContext(ctx).Where("expires_at < ?", threshold).Delete(&model.AgentInstallToken{}) + return result.RowsAffected, result.Error +} + +func (r *GormAgentInstallTokenRepository) CountCreatedSince(ctx context.Context, nodeID uint, since time.Time) (int64, error) { + var n int64 + err := r.db.WithContext(ctx).Model(&model.AgentInstallToken{}). + Where("node_id = ? AND created_at >= ?", nodeID, since). + Count(&n).Error + return n, err +} diff --git a/server/internal/repository/agent_install_token_repository_test.go b/server/internal/repository/agent_install_token_repository_test.go new file mode 100644 index 0000000..e0f9534 --- /dev/null +++ b/server/internal/repository/agent_install_token_repository_test.go @@ -0,0 +1,151 @@ +package repository + +import ( + "context" + "path/filepath" + "testing" + "time" + + "backupx/server/internal/model" + "github.com/glebarez/sqlite" + "gorm.io/gorm" + gormlogger "gorm.io/gorm/logger" +) + +func openTestInstallTokenDB(t *testing.T) *gorm.DB { + t.Helper() + path := filepath.Join(t.TempDir(), "install.db") + db, err := gorm.Open(sqlite.Open(path), &gorm.Config{Logger: gormlogger.Default.LogMode(gormlogger.Silent)}) + if err != nil { + t.Fatalf("open: %v", err) + } + if err := db.AutoMigrate(&model.AgentInstallToken{}); err != nil { + t.Fatalf("migrate: %v", err) + } + return db +} + +func TestInstallTokenConsumeOnce(t *testing.T) { + db := openTestInstallTokenDB(t) + repo := NewAgentInstallTokenRepository(db) + ctx := context.Background() + + tok := &model.AgentInstallToken{ + Token: "abc", NodeID: 1, Mode: model.InstallModeSystemd, + Arch: model.InstallArchAuto, AgentVer: "v1.7.0", + DownloadSrc: model.InstallSourceGitHub, + ExpiresAt: time.Now().UTC().Add(15 * time.Minute), + CreatedByID: 1, + } + if err := repo.Create(ctx, tok); err != nil { + t.Fatalf("create: %v", err) + } + + got, err := repo.ConsumeByToken(ctx, "abc") + if err != nil { + t.Fatalf("consume err: %v", err) + } + if got == nil || got.ConsumedAt == nil { + t.Fatalf("expected consumed token, got %+v", got) + } + + got, err = repo.ConsumeByToken(ctx, "abc") + if err != nil { + t.Fatalf("second consume err: %v", err) + } + if got != nil { + t.Fatalf("expected nil on second consume, got %+v", got) + } +} + +func TestInstallTokenConsumeExpired(t *testing.T) { + db := openTestInstallTokenDB(t) + repo := NewAgentInstallTokenRepository(db) + ctx := context.Background() + + tok := &model.AgentInstallToken{ + Token: "stale", NodeID: 1, Mode: model.InstallModeSystemd, + Arch: model.InstallArchAuto, AgentVer: "v1.7.0", + DownloadSrc: model.InstallSourceGitHub, + ExpiresAt: time.Now().UTC().Add(-time.Minute), + CreatedByID: 1, + } + if err := repo.Create(ctx, tok); err != nil { + t.Fatalf("create: %v", err) + } + + got, err := repo.ConsumeByToken(ctx, "stale") + if err != nil { + t.Fatalf("consume err: %v", err) + } + if got != nil { + t.Fatalf("expected nil on expired, got %+v", got) + } +} + +func TestInstallTokenGC(t *testing.T) { + db := openTestInstallTokenDB(t) + repo := NewAgentInstallTokenRepository(db) + ctx := context.Background() + + old := &model.AgentInstallToken{ + Token: "old", NodeID: 1, Mode: model.InstallModeSystemd, + Arch: model.InstallArchAuto, AgentVer: "v1.7.0", + DownloadSrc: model.InstallSourceGitHub, + ExpiresAt: time.Now().UTC().Add(-8 * 24 * time.Hour), + CreatedByID: 1, + } + if err := repo.Create(ctx, old); err != nil { + t.Fatalf("create old: %v", err) + } + + fresh := &model.AgentInstallToken{ + Token: "fresh", NodeID: 1, Mode: model.InstallModeSystemd, + Arch: model.InstallArchAuto, AgentVer: "v1.7.0", + DownloadSrc: model.InstallSourceGitHub, + ExpiresAt: time.Now().UTC().Add(-1 * time.Hour), + CreatedByID: 1, + } + if err := repo.Create(ctx, fresh); err != nil { + t.Fatalf("create fresh: %v", err) + } + + n, err := repo.DeleteExpiredBefore(ctx, time.Now().UTC().Add(-7*24*time.Hour)) + if err != nil { + t.Fatalf("gc err: %v", err) + } + if n != 1 { + t.Fatalf("expected 1 deleted, got %d", n) + } +} + +func TestInstallTokenCountCreatedSince(t *testing.T) { + db := openTestInstallTokenDB(t) + repo := NewAgentInstallTokenRepository(db) + ctx := context.Background() + + // 同一节点 3 条 + for i := 0; i < 3; i++ { + _ = repo.Create(ctx, &model.AgentInstallToken{ + Token: "t" + string(rune('a'+i)), NodeID: 1, Mode: "systemd", Arch: "auto", + AgentVer: "v1", DownloadSrc: "github", + ExpiresAt: time.Now().UTC().Add(time.Minute), CreatedByID: 1, + }) + } + // 另一节点 2 条(不计入) + for i := 0; i < 2; i++ { + _ = repo.Create(ctx, &model.AgentInstallToken{ + Token: "n2_" + string(rune('a'+i)), NodeID: 2, Mode: "systemd", Arch: "auto", + AgentVer: "v1", DownloadSrc: "github", + ExpiresAt: time.Now().UTC().Add(time.Minute), CreatedByID: 1, + }) + } + + n, err := repo.CountCreatedSince(ctx, 1, time.Now().UTC().Add(-time.Minute)) + if err != nil { + t.Fatalf("count err: %v", err) + } + if n != 3 { + t.Fatalf("expected 3, got %d", n) + } +}