From 012d54a946d6fdc255edb85899675fd976f11382 Mon Sep 17 00:00:00 2001 From: Awuqing <3184394176@qq.com> Date: Sun, 19 Apr 2026 16:26:41 +0800 Subject: [PATCH] =?UTF-8?q?=E5=8A=9F=E8=83=BD:=20NodeService=20=E6=96=B0?= =?UTF-8?q?=E5=A2=9E=20BatchCreate=20=E4=B8=8E=20RotateToken?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- server/internal/service/node_service.go | 109 +++++++++++++ server/internal/service/node_service_test.go | 159 +++++++++++++++++++ 2 files changed, 268 insertions(+) create mode 100644 server/internal/service/node_service_test.go diff --git a/server/internal/service/node_service.go b/server/internal/service/node_service.go index 6dee41a..34b0ec4 100644 --- a/server/internal/service/node_service.go +++ b/server/internal/service/node_service.go @@ -373,6 +373,115 @@ func detectLocalIP() string { return "" } +// NodeCreateResult 批量创建结果。注意:不暴露 agent token,token 获取走 install-token 流程。 +type NodeCreateResult struct { + ID uint `json:"id"` + Name string `json:"name"` +} + +// BatchCreate 批量创建远程节点。 +// 校验:1-50 项、每项 1-128 字符、批次内去重、与已有节点名去重。 +// 返回 NodeCreateResult 列表(不含 token,调用方应再调 install-tokens 接口)。 +func (s *NodeService) BatchCreate(ctx context.Context, names []string) ([]NodeCreateResult, error) { + cleaned, err := validateBatchNames(names) + if err != nil { + return nil, err + } + existing, err := s.repo.List(ctx) + if err != nil { + return nil, err + } + existingSet := make(map[string]bool, len(existing)) + for _, n := range existing { + existingSet[n.Name] = true + } + for _, name := range cleaned { + if existingSet[name] { + return nil, apperror.BadRequest("NODE_DUPLICATE_NAME", + fmt.Sprintf("节点名「%s」已存在", name), nil) + } + } + + results := make([]NodeCreateResult, 0, len(cleaned)) + for _, name := range cleaned { + tok, err := generateToken() + if err != nil { + return nil, fmt.Errorf("generate token: %w", err) + } + node := &model.Node{ + Name: name, + Token: tok, + Status: model.NodeStatusOffline, + IsLocal: false, + LastSeen: time.Now().UTC(), + } + if err := s.repo.Create(ctx, node); err != nil { + return nil, err + } + results = append(results, NodeCreateResult{ID: node.ID, Name: node.Name}) + } + return results, nil +} + +// RotateToken 轮换指定节点的 agent token。 +// 旧 token 复制到 prev_token,24h 内新旧 token 均可认证。 +func (s *NodeService) RotateToken(ctx context.Context, id uint) (string, error) { + node, err := s.repo.FindByID(ctx, id) + if err != nil { + return "", err + } + if node == nil { + return "", apperror.New(http.StatusNotFound, "NODE_NOT_FOUND", "节点不存在", nil) + } + if node.IsLocal { + return "", apperror.BadRequest("NODE_ROTATE_LOCAL", "本机节点无需轮换 Token", nil) + } + newTok, err := generateToken() + if err != nil { + return "", fmt.Errorf("generate: %w", err) + } + expires := time.Now().UTC().Add(24 * time.Hour) + node.PrevToken = node.Token + node.PrevTokenExpires = &expires + node.Token = newTok + if err := s.repo.Update(ctx, node); err != nil { + return "", err + } + return newTok, nil +} + +// validateBatchNames 校验并去重批次内名称(空白行忽略)。 +func validateBatchNames(names []string) ([]string, error) { + if len(names) == 0 { + return nil, apperror.BadRequest("NODE_BATCH_EMPTY", "节点名列表不能为空", nil) + } + if len(names) > 50 { + return nil, apperror.BadRequest("NODE_BATCH_TOO_MANY", "单次最多创建 50 个节点", nil) + } + seen := make(map[string]bool, len(names)) + out := make([]string, 0, len(names)) + for _, raw := range names { + name := strings.TrimSpace(raw) + if name == "" { + continue + } + if len(name) > 128 { + return nil, apperror.BadRequest("NODE_NAME_TOO_LONG", + fmt.Sprintf("节点名「%s」超过 128 字符", name), nil) + } + if seen[name] { + return nil, apperror.BadRequest("NODE_DUPLICATE_NAME", + fmt.Sprintf("批次内重复节点名「%s」", name), nil) + } + seen[name] = true + out = append(out, name) + } + if len(out) == 0 { + return nil, apperror.BadRequest("NODE_BATCH_EMPTY", "去除空白后列表为空", nil) + } + return out, nil +} + func generateToken() (string, error) { b := make([]byte, 32) if _, err := rand.Read(b); err != nil { diff --git a/server/internal/service/node_service_test.go b/server/internal/service/node_service_test.go new file mode 100644 index 0000000..51cca16 --- /dev/null +++ b/server/internal/service/node_service_test.go @@ -0,0 +1,159 @@ +package service + +import ( + "context" + "path/filepath" + "testing" + "time" + + "backupx/server/internal/model" + "backupx/server/internal/repository" + "github.com/glebarez/sqlite" + "gorm.io/gorm" + gormlogger "gorm.io/gorm/logger" +) + +func openNodeServiceDB(t *testing.T) *gorm.DB { + t.Helper() + db, err := gorm.Open(sqlite.Open(filepath.Join(t.TempDir(), "ns.db")), + &gorm.Config{Logger: gormlogger.Default.LogMode(gormlogger.Silent)}) + if err != nil { + t.Fatalf("open: %v", err) + } + if err := db.AutoMigrate(&model.Node{}); err != nil { + t.Fatalf("migrate: %v", err) + } + return db +} + +func TestBatchCreateNodes(t *testing.T) { + db := openNodeServiceDB(t) + svc := NewNodeService(repository.NewNodeRepository(db), "test") + ctx := context.Background() + + items, err := svc.BatchCreate(ctx, []string{"a", "b", "c"}) + if err != nil { + t.Fatalf("batch: %v", err) + } + if len(items) != 3 { + t.Fatalf("expected 3, got %d", len(items)) + } + for _, it := range items { + if it.ID == 0 || it.Name == "" { + t.Errorf("invalid item %+v", it) + } + } +} + +func TestBatchCreateRejectsDuplicatesAgainstDB(t *testing.T) { + db := openNodeServiceDB(t) + svc := NewNodeService(repository.NewNodeRepository(db), "test") + ctx := context.Background() + + if _, err := svc.Create(ctx, NodeCreateInput{Name: "a"}); err != nil { + t.Fatalf("create: %v", err) + } + _, err := svc.BatchCreate(ctx, []string{"a", "b"}) + if err == nil { + t.Fatalf("expected error on duplicate with existing") + } +} + +func TestBatchCreateRejectsIntraBatchDuplicates(t *testing.T) { + db := openNodeServiceDB(t) + svc := NewNodeService(repository.NewNodeRepository(db), "test") + _, err := svc.BatchCreate(context.Background(), []string{"x", "x"}) + if err == nil { + t.Fatalf("expected error on intra-batch duplicate") + } +} + +func TestBatchCreateLimitEnforced(t *testing.T) { + db := openNodeServiceDB(t) + svc := NewNodeService(repository.NewNodeRepository(db), "test") + names := make([]string, 51) + for i := range names { + names[i] = "n" + string(rune('A'+i)) + } + _, err := svc.BatchCreate(context.Background(), names) + if err == nil { + t.Fatalf("expected error on >50 batch") + } +} + +func TestBatchCreateSkipsEmptyLines(t *testing.T) { + db := openNodeServiceDB(t) + svc := NewNodeService(repository.NewNodeRepository(db), "test") + items, err := svc.BatchCreate(context.Background(), []string{"a", " ", "", "b"}) + if err != nil { + t.Fatalf("batch: %v", err) + } + if len(items) != 2 { + t.Fatalf("expected 2 (a,b), got %d", len(items)) + } +} + +func TestRotateToken(t *testing.T) { + db := openNodeServiceDB(t) + repo := repository.NewNodeRepository(db) + svc := NewNodeService(repo, "test") + ctx := context.Background() + + _, err := svc.Create(ctx, NodeCreateInput{Name: "rot"}) + if err != nil { + t.Fatalf("create: %v", err) + } + var node model.Node + db.First(&node, "name = ?", "rot") + oldTok := node.Token + + newTok, err := svc.RotateToken(ctx, node.ID) + if err != nil { + t.Fatalf("rotate: %v", err) + } + if newTok == oldTok || len(newTok) != 64 { + t.Fatalf("invalid new token: %s", newTok) + } + + // 旧 token 仍可查(24h 内) + found, _ := repo.FindByToken(ctx, oldTok) + if found == nil || found.ID != node.ID { + t.Fatalf("old token should still work via prev_token fallback") + } + found2, _ := repo.FindByToken(ctx, newTok) + if found2 == nil || found2.ID != node.ID { + t.Fatalf("new token should work") + } + + db.First(&node, node.ID) + if node.PrevTokenExpires == nil { + t.Fatalf("prev_token_expires not set") + } + diff := node.PrevTokenExpires.Sub(time.Now().UTC()) + if diff < 23*time.Hour || diff > 25*time.Hour { + t.Fatalf("prev_token_expires out of range: %v", diff) + } +} + +func TestRotateTokenRejectsLocal(t *testing.T) { + db := openNodeServiceDB(t) + repo := repository.NewNodeRepository(db) + svc := NewNodeService(repo, "test") + ctx := context.Background() + + if err := svc.EnsureLocalNode(ctx); err != nil { + t.Fatalf("ensure local: %v", err) + } + local, _ := repo.FindLocal(ctx) + if _, err := svc.RotateToken(ctx, local.ID); err == nil { + t.Fatalf("expected error rotating local node") + } +} + +func TestRotateTokenNotFound(t *testing.T) { + db := openNodeServiceDB(t) + svc := NewNodeService(repository.NewNodeRepository(db), "test") + if _, err := svc.RotateToken(context.Background(), 9999); err == nil { + t.Fatalf("expected not found error") + } +}