From da94c38df38ebe48e23349e6be4c6bbbe63f9f8e Mon Sep 17 00:00:00 2001 From: Awuqing <3184394176@qq.com> Date: Sun, 19 Apr 2026 16:24:10 +0800 Subject: [PATCH] =?UTF-8?q?=E5=8A=9F=E8=83=BD:=20=E6=96=B0=E5=A2=9E=20Inst?= =?UTF-8?q?allTokenService=20=E5=90=AB=E8=BE=93=E5=85=A5=E6=A0=A1=E9=AA=8C?= =?UTF-8?q?=E3=80=81=E9=99=90=E6=B5=81=E3=80=81GC?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .../internal/service/install_token_service.go | 187 ++++++++++++++++++ .../service/install_token_service_test.go | 156 +++++++++++++++ 2 files changed, 343 insertions(+) create mode 100644 server/internal/service/install_token_service.go create mode 100644 server/internal/service/install_token_service_test.go diff --git a/server/internal/service/install_token_service.go b/server/internal/service/install_token_service.go new file mode 100644 index 0000000..c697807 --- /dev/null +++ b/server/internal/service/install_token_service.go @@ -0,0 +1,187 @@ +package service + +import ( + "context" + "crypto/rand" + "encoding/hex" + "fmt" + "strings" + "time" + + "backupx/server/internal/apperror" + "backupx/server/internal/model" + "backupx/server/internal/repository" +) + +// InstallTokenService 负责一次性安装令牌的创建/消费/校验。 +type InstallTokenService struct { + repo repository.AgentInstallTokenRepository + nodeRepo repository.NodeRepository +} + +func NewInstallTokenService(repo repository.AgentInstallTokenRepository, nodeRepo repository.NodeRepository) *InstallTokenService { + return &InstallTokenService{repo: repo, nodeRepo: nodeRepo} +} + +// InstallTokenInput 生成一次性安装令牌的输入。 +type InstallTokenInput struct { + NodeID uint + Mode string + Arch string + AgentVersion string + DownloadSrc string + TTLSeconds int + CreatedByID uint +} + +// InstallTokenOutput 生成结果。 +type InstallTokenOutput struct { + Token string + ExpiresAt time.Time + Node *model.Node + Record *model.AgentInstallToken +} + +// ConsumedInstallToken 消费成功后返回给 handler 的组合体。 +type ConsumedInstallToken struct { + Record *model.AgentInstallToken + Node *model.Node +} + +// 校验与限流常量。 +const ( + InstallTokenMinTTL = 300 // 5 分钟 + InstallTokenMaxTTL = 86400 // 24 小时 + InstallTokenRateWindow = 60 * time.Second + InstallTokenRatePerWin = 5 +) + +var ( + validInstallModes = map[string]bool{model.InstallModeSystemd: true, model.InstallModeDocker: true, model.InstallModeForeground: true} + validInstallArches = map[string]bool{model.InstallArchAmd64: true, model.InstallArchArm64: true, model.InstallArchAuto: true} + validInstallSources = map[string]bool{model.InstallSourceGitHub: true, model.InstallSourceGhproxy: true} +) + +// Create 生成一次性安装令牌。 +func (s *InstallTokenService) Create(ctx context.Context, in InstallTokenInput) (*InstallTokenOutput, error) { + if err := s.validate(in); err != nil { + return nil, err + } + node, err := s.nodeRepo.FindByID(ctx, in.NodeID) + if err != nil { + return nil, err + } + if node == nil { + return nil, apperror.New(404, "NODE_NOT_FOUND", "节点不存在", nil) + } + + since := time.Now().UTC().Add(-InstallTokenRateWindow) + count, err := s.repo.CountCreatedSince(ctx, in.NodeID, since) + if err != nil { + return nil, err + } + if count >= InstallTokenRatePerWin { + return nil, apperror.TooManyRequests("INSTALL_TOKEN_RATE_LIMITED", + fmt.Sprintf("每 %d 秒最多生成 %d 次", int(InstallTokenRateWindow.Seconds()), InstallTokenRatePerWin), nil) + } + + token, err := generateInstallToken() + if err != nil { + return nil, fmt.Errorf("generate token: %w", err) + } + expiresAt := time.Now().UTC().Add(time.Duration(in.TTLSeconds) * time.Second) + record := &model.AgentInstallToken{ + Token: token, + NodeID: in.NodeID, + Mode: in.Mode, + Arch: in.Arch, + AgentVer: in.AgentVersion, + DownloadSrc: in.DownloadSrc, + ExpiresAt: expiresAt, + CreatedByID: in.CreatedByID, + } + if err := s.repo.Create(ctx, record); err != nil { + return nil, err + } + return &InstallTokenOutput{Token: token, ExpiresAt: expiresAt, Node: node, Record: record}, nil +} + +// Consume 原子消费令牌。未命中/已过期/已消费均返回 (nil, nil)。 +func (s *InstallTokenService) Consume(ctx context.Context, token string) (*ConsumedInstallToken, error) { + if strings.TrimSpace(token) == "" { + return nil, nil + } + record, err := s.repo.ConsumeByToken(ctx, token) + if err != nil { + return nil, err + } + if record == nil { + return nil, nil + } + node, err := s.nodeRepo.FindByID(ctx, record.NodeID) + if err != nil { + return nil, err + } + if node == nil { + return nil, apperror.New(404, "NODE_NOT_FOUND", "节点已被删除", nil) + } + return &ConsumedInstallToken{Record: record, Node: node}, nil +} + +// Peek 只读查询(不消费),供 compose 端点预检 Mode。 +func (s *InstallTokenService) Peek(ctx context.Context, token string) (*model.AgentInstallToken, error) { + if strings.TrimSpace(token) == "" { + return nil, nil + } + return s.repo.FindByToken(ctx, token) +} + +// StartGC 启动后台 GC,按 interval 扫描并删 ExpiresAt < now-7d 的记录。 +func (s *InstallTokenService) StartGC(ctx context.Context, interval time.Duration) { + if interval <= 0 { + interval = time.Hour + } + go func() { + ticker := time.NewTicker(interval) + defer ticker.Stop() + for { + select { + case <-ctx.Done(): + return + case <-ticker.C: + _, _ = s.repo.DeleteExpiredBefore(ctx, time.Now().UTC().Add(-7*24*time.Hour)) + } + } + }() +} + +func (s *InstallTokenService) validate(in InstallTokenInput) error { + if in.NodeID == 0 { + return apperror.BadRequest("INSTALL_TOKEN_INVALID", "nodeId 必填", nil) + } + if !validInstallModes[in.Mode] { + return apperror.BadRequest("INSTALL_TOKEN_INVALID", "mode 非法", nil) + } + if !validInstallArches[in.Arch] { + return apperror.BadRequest("INSTALL_TOKEN_INVALID", "arch 非法", nil) + } + if !validInstallSources[in.DownloadSrc] { + return apperror.BadRequest("INSTALL_TOKEN_INVALID", "downloadSrc 非法", nil) + } + if strings.TrimSpace(in.AgentVersion) == "" { + return apperror.BadRequest("INSTALL_TOKEN_INVALID", "agentVersion 必填", nil) + } + if in.TTLSeconds < InstallTokenMinTTL || in.TTLSeconds > InstallTokenMaxTTL { + return apperror.BadRequest("INSTALL_TOKEN_INVALID", + fmt.Sprintf("ttlSeconds 需在 %d-%d", InstallTokenMinTTL, InstallTokenMaxTTL), nil) + } + return nil +} + +func generateInstallToken() (string, error) { + b := make([]byte, 32) + if _, err := rand.Read(b); err != nil { + return "", err + } + return hex.EncodeToString(b), nil +} diff --git a/server/internal/service/install_token_service_test.go b/server/internal/service/install_token_service_test.go new file mode 100644 index 0000000..0552202 --- /dev/null +++ b/server/internal/service/install_token_service_test.go @@ -0,0 +1,156 @@ +package service + +import ( + "context" + "path/filepath" + "testing" + "time" + + "backupx/server/internal/model" + "backupx/server/internal/repository" + "github.com/glebarez/sqlite" + "gorm.io/gorm" + gormlogger "gorm.io/gorm/logger" +) + +func openInstallTokenTestDB(t *testing.T) *gorm.DB { + t.Helper() + db, err := gorm.Open(sqlite.Open(filepath.Join(t.TempDir(), "it.db")), + &gorm.Config{Logger: gormlogger.Default.LogMode(gormlogger.Silent)}) + if err != nil { + t.Fatalf("open: %v", err) + } + if err := db.AutoMigrate(&model.AgentInstallToken{}, &model.Node{}); err != nil { + t.Fatalf("migrate: %v", err) + } + return db +} + +func TestInstallTokenServiceCreateAndConsume(t *testing.T) { + db := openInstallTokenTestDB(t) + repo := repository.NewAgentInstallTokenRepository(db) + nodeRepo := repository.NewNodeRepository(db) + + node := &model.Node{Name: "n1", Token: "agent-token"} + if err := nodeRepo.Create(context.Background(), node); err != nil { + t.Fatalf("create node: %v", err) + } + + svc := NewInstallTokenService(repo, nodeRepo) + created, err := svc.Create(context.Background(), InstallTokenInput{ + NodeID: node.ID, + Mode: model.InstallModeSystemd, + Arch: model.InstallArchAuto, + AgentVersion: "v1.7.0", + DownloadSrc: model.InstallSourceGitHub, + TTLSeconds: 900, + CreatedByID: 1, + }) + if err != nil { + t.Fatalf("create: %v", err) + } + if created.Token == "" || created.ExpiresAt.Before(time.Now().UTC()) { + t.Fatalf("invalid token: %+v", created) + } + + consumed, err := svc.Consume(context.Background(), created.Token) + if err != nil { + t.Fatalf("consume: %v", err) + } + if consumed == nil || consumed.Node.ID != node.ID { + t.Fatalf("expected consumed token for node, got %+v", consumed) + } + + again, err := svc.Consume(context.Background(), created.Token) + if err != nil { + t.Fatalf("second consume err: %v", err) + } + if again != nil { + t.Fatalf("expected nil on second consume") + } +} + +func TestInstallTokenServicePeekDoesNotConsume(t *testing.T) { + db := openInstallTokenTestDB(t) + repo := repository.NewAgentInstallTokenRepository(db) + nodeRepo := repository.NewNodeRepository(db) + node := &model.Node{Name: "n2", Token: "tok2"} + _ = nodeRepo.Create(context.Background(), node) + + svc := NewInstallTokenService(repo, nodeRepo) + out, err := svc.Create(context.Background(), InstallTokenInput{ + NodeID: node.ID, Mode: "docker", Arch: "auto", + AgentVersion: "v1", DownloadSrc: "github", TTLSeconds: 300, CreatedByID: 1, + }) + if err != nil { + t.Fatalf("create: %v", err) + } + + // Peek 两次都应成功(不消费) + for i := 0; i < 2; i++ { + rec, err := svc.Peek(context.Background(), out.Token) + if err != nil { + t.Fatalf("peek %d: %v", i, err) + } + if rec == nil || rec.Mode != "docker" { + t.Fatalf("peek %d bad: %+v", i, rec) + } + } + + // 之后仍可消费 + consumed, _ := svc.Consume(context.Background(), out.Token) + if consumed == nil { + t.Fatalf("consume after peek failed") + } +} + +func TestInstallTokenServiceValidatesInput(t *testing.T) { + db := openInstallTokenTestDB(t) + nodeRepo := repository.NewNodeRepository(db) + node := &model.Node{Name: "valid", Token: "t"} + _ = nodeRepo.Create(context.Background(), node) + + svc := NewInstallTokenService(repository.NewAgentInstallTokenRepository(db), nodeRepo) + cases := []struct { + name string + in InstallTokenInput + }{ + {"bad mode", InstallTokenInput{NodeID: node.ID, Mode: "xxx", Arch: "auto", AgentVersion: "v1", DownloadSrc: "github", TTLSeconds: 300, CreatedByID: 1}}, + {"bad arch", InstallTokenInput{NodeID: node.ID, Mode: "systemd", Arch: "risc", AgentVersion: "v1", DownloadSrc: "github", TTLSeconds: 300, CreatedByID: 1}}, + {"bad source", InstallTokenInput{NodeID: node.ID, Mode: "systemd", Arch: "auto", AgentVersion: "v1", DownloadSrc: "bogus", TTLSeconds: 300, CreatedByID: 1}}, + {"bad ttl low", InstallTokenInput{NodeID: node.ID, Mode: "systemd", Arch: "auto", AgentVersion: "v1", DownloadSrc: "github", TTLSeconds: 10, CreatedByID: 1}}, + {"bad ttl high", InstallTokenInput{NodeID: node.ID, Mode: "systemd", Arch: "auto", AgentVersion: "v1", DownloadSrc: "github", TTLSeconds: 999999, CreatedByID: 1}}, + {"missing version", InstallTokenInput{NodeID: node.ID, Mode: "systemd", Arch: "auto", AgentVersion: "", DownloadSrc: "github", TTLSeconds: 300, CreatedByID: 1}}, + {"missing node id", InstallTokenInput{NodeID: 0, Mode: "systemd", Arch: "auto", AgentVersion: "v1", DownloadSrc: "github", TTLSeconds: 300, CreatedByID: 1}}, + {"node not exists", InstallTokenInput{NodeID: 999, Mode: "systemd", Arch: "auto", AgentVersion: "v1", DownloadSrc: "github", TTLSeconds: 300, CreatedByID: 1}}, + } + for _, tc := range cases { + if _, err := svc.Create(context.Background(), tc.in); err == nil { + t.Errorf("%s: expected validation error", tc.name) + } + } +} + +func TestInstallTokenServiceRateLimit(t *testing.T) { + db := openInstallTokenTestDB(t) + nodeRepo := repository.NewNodeRepository(db) + node := &model.Node{Name: "rl", Token: "rl"} + _ = nodeRepo.Create(context.Background(), node) + + svc := NewInstallTokenService(repository.NewAgentInstallTokenRepository(db), nodeRepo) + base := InstallTokenInput{ + NodeID: node.ID, Mode: "systemd", Arch: "auto", + AgentVersion: "v1", DownloadSrc: "github", TTLSeconds: 300, CreatedByID: 1, + } + // 前 5 次成功 + for i := 0; i < 5; i++ { + if _, err := svc.Create(context.Background(), base); err != nil { + t.Fatalf("iter %d: %v", i, err) + } + } + // 第 6 次应被限流 + _, err := svc.Create(context.Background(), base) + if err == nil { + t.Fatalf("expected rate limit error") + } +}