From 7a6ffd4dddf8e3d7d739f2f22a8cafd6f1b122de Mon Sep 17 00:00:00 2001 From: Wu Qing <3184394176@qq.com> Date: Sat, 9 May 2026 23:03:25 +0800 Subject: [PATCH] =?UTF-8?q?feat(BackupX):=20=E4=BF=AE=E5=A4=8D=E8=B7=A8?= =?UTF-8?q?=E8=8A=82=E7=82=B9=E5=A4=87=E4=BB=BD=E6=81=A2=E5=A4=8D=E7=BB=88?= =?UTF-8?q?=E6=80=81=E5=A4=84=E7=90=86=20(#60)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * feat(BackupX): 修复集群部署管理逻辑 * feat(BackupX): 修复节点池任务运行归属 * feat(BackupX): 修复跨节点恢复路由 * feat(BackupX): 修复跨节点备份恢复终态处理 * test(BackupX): 稳定安装流HTTP测试 --- .../internal/backup/retention/service_test.go | 3 + server/internal/http/install_flow_test.go | 148 ++++- server/internal/http/node_handler.go | 40 +- .../repository/agent_command_repository.go | 53 ++ .../agent_command_repository_test.go | 100 +++ .../agent_install_token_repository_test.go | 54 ++ .../repository/backup_record_repository.go | 15 + .../repository/backup_task_repository.go | 2 +- .../repository/backup_task_repository_test.go | 46 ++ server/internal/service/agent_service.go | 107 +++- server/internal/service/agent_service_test.go | 589 ++++++++++++++++++ .../service/backup_execution_service.go | 104 +++- .../service/backup_execution_service_test.go | 249 ++++++++ .../internal/service/install_token_service.go | 107 +++- .../service/install_token_service_test.go | 73 +++ server/internal/service/restore_service.go | 26 +- .../internal/service/restore_service_test.go | 191 +++++- web/src/pages/nodes/AgentInstallWizard.tsx | 151 ++--- web/src/pages/nodes/BatchCommandTable.test.ts | 30 + web/src/pages/nodes/BatchCommandTable.tsx | 56 +- web/src/pages/nodes/NodesPage.test.ts | 21 + web/src/pages/nodes/NodesPage.tsx | 82 ++- web/src/pages/nodes/installCommands.test.ts | 14 +- web/src/pages/nodes/installCommands.ts | 38 +- .../pages/nodes/useAgentDeployFlow.test.ts | 90 +++ web/src/pages/nodes/useAgentDeployFlow.ts | 146 +++++ .../nodes/wizard/Step3CommandPreview.tsx | 23 +- 27 files changed, 2311 insertions(+), 247 deletions(-) create mode 100644 server/internal/service/agent_service_test.go create mode 100644 web/src/pages/nodes/BatchCommandTable.test.ts create mode 100644 web/src/pages/nodes/NodesPage.test.ts create mode 100644 web/src/pages/nodes/useAgentDeployFlow.test.ts create mode 100644 web/src/pages/nodes/useAgentDeployFlow.ts diff --git a/server/internal/backup/retention/service_test.go b/server/internal/backup/retention/service_test.go index e1e2014..2c88af6 100644 --- a/server/internal/backup/retention/service_test.go +++ b/server/internal/backup/retention/service_test.go @@ -24,6 +24,9 @@ func (r *fakeRecordRepository) List(context.Context, repository.BackupRecordList func (r *fakeRecordRepository) FindByID(context.Context, uint) (*model.BackupRecord, error) { return nil, nil } +func (r *fakeRecordRepository) FindRunningByTaskAndNode(context.Context, uint, uint) (*model.BackupRecord, error) { + return nil, nil +} func (r *fakeRecordRepository) Create(context.Context, *model.BackupRecord) error { return nil } func (r *fakeRecordRepository) Update(context.Context, *model.BackupRecord) error { return nil } func (r *fakeRecordRepository) Delete(_ context.Context, id uint) error { diff --git a/server/internal/http/install_flow_test.go b/server/internal/http/install_flow_test.go index ecf39c3..005a670 100644 --- a/server/internal/http/install_flow_test.go +++ b/server/internal/http/install_flow_test.go @@ -25,10 +25,14 @@ import ( // setupInstallFlowRouter 构造一个 Node + Agent + InstallToken 全量依赖的 router, // 并返回已登录管理员 JWT。 func setupInstallFlowRouter(t *testing.T) (http.Handler, string) { + return setupInstallFlowRouterWithExternalURL(t, "") +} + +func setupInstallFlowRouterWithExternalURL(t *testing.T, externalURL string) (http.Handler, string) { t.Helper() tempDir := t.TempDir() cfg := config.Config{ - Server: config.ServerConfig{Host: "127.0.0.1", Port: 8340, Mode: "test"}, + Server: config.ServerConfig{Host: "127.0.0.1", Port: 8340, Mode: "test", ExternalURL: externalURL}, Database: config.DatabaseConfig{Path: filepath.Join(tempDir, "backupx.db")}, Security: config.SecurityConfig{JWTExpire: "24h"}, Log: config.LogConfig{Level: "error"}, @@ -68,9 +72,6 @@ func setupInstallFlowRouter(t *testing.T) (http.Handler, string) { installTokenRepo := repository.NewAgentInstallTokenRepository(db) installTokenSvc := service.NewInstallTokenService(installTokenRepo, nodeRepo) - auditLogRepo := repository.NewAuditLogRepository(db) - auditSvc := service.NewAuditService(auditLogRepo) - // 用 cancelable ctx,测试结束时停掉 handler 启动的后台 GC 协程, // 避免 goroutine 持有 map 导致 tempdir 清理失败。 ctx, cancel := context.WithCancel(context.Background()) @@ -85,7 +86,7 @@ func setupInstallFlowRouter(t *testing.T) (http.Handler, string) { SystemService: systemSvc, NodeService: nodeSvc, InstallTokenService: installTokenSvc, - AuditService: auditSvc, + MasterExternalURL: cfg.Server.ExternalURL, JWTManager: jwtMgr, UserRepository: userRepo, SystemConfigRepo: systemConfigRepo, @@ -114,6 +115,73 @@ func setupInstallFlowRouter(t *testing.T) (http.Handler, string) { return router, setupResp.Data.Token } +func TestInstallTokenUsesConfiguredExternalURL(t *testing.T) { + const externalURL = "https://public.example.com/base" + router, jwt := setupInstallFlowRouterWithExternalURL(t, externalURL) + + batchBody, _ := json.Marshal(map[string][]string{"names": {"external-url-node"}}) + batchReq := httptest.NewRequest(http.MethodPost, "/api/nodes/batch", bytes.NewBuffer(batchBody)) + batchReq.Header.Set("Content-Type", "application/json") + batchReq.Header.Set("Authorization", "Bearer "+jwt) + batchRec := httptest.NewRecorder() + router.ServeHTTP(batchRec, batchReq) + if batchRec.Code != 200 { + t.Fatalf("batch create failed: %d %s", batchRec.Code, batchRec.Body.String()) + } + var batchResp struct { + Data []struct { + ID uint `json:"id"` + } `json:"data"` + } + if err := json.Unmarshal(batchRec.Body.Bytes(), &batchResp); err != nil { + t.Fatalf("unmarshal batch: %v", err) + } + if len(batchResp.Data) != 1 { + t.Fatalf("expected 1 node, got %d", len(batchResp.Data)) + } + + genBody, _ := json.Marshal(map[string]any{ + "mode": "systemd", + "arch": "auto", + "agentVersion": "v1.7.0", + "downloadSrc": "github", + "ttlSeconds": 900, + }) + genReq := httptest.NewRequest(http.MethodPost, + "/api/nodes/"+formatUint(batchResp.Data[0].ID)+"/install-tokens", bytes.NewBuffer(genBody)) + genReq.Header.Set("Content-Type", "application/json") + genReq.Header.Set("Authorization", "Bearer "+jwt) + genRec := httptest.NewRecorder() + router.ServeHTTP(genRec, genReq) + if genRec.Code != 200 { + t.Fatalf("install-tokens failed: %d %s", genRec.Code, genRec.Body.String()) + } + var genResp struct { + Data struct { + InstallToken string `json:"installToken"` + URL string `json:"url"` + FallbackURL string `json:"fallbackUrl"` + ScriptBase64 string `json:"scriptBase64"` + } `json:"data"` + } + if err := json.Unmarshal(genRec.Body.Bytes(), &genResp); err != nil { + t.Fatalf("unmarshal gen: %v", err) + } + if genResp.Data.URL != externalURL+"/api/install/"+genResp.Data.InstallToken { + t.Fatalf("url should use external URL, got %q", genResp.Data.URL) + } + if genResp.Data.FallbackURL != externalURL+"/install/"+genResp.Data.InstallToken { + t.Fatalf("fallbackUrl should use external URL, got %q", genResp.Data.FallbackURL) + } + decodedScript, err := base64.StdEncoding.DecodeString(genResp.Data.ScriptBase64) + if err != nil { + t.Fatalf("scriptBase64 should be valid base64: %v", err) + } + if !strings.Contains(string(decodedScript), `MASTER_URL="`+externalURL+`"`) { + t.Fatalf("script should use external MASTER_URL:\n%s", string(decodedScript)) + } +} + func TestOneClickInstallFlow(t *testing.T) { router, jwt := setupInstallFlowRouter(t) @@ -428,6 +496,76 @@ func TestInstallFlowComposeModeMismatch(t *testing.T) { } } +func TestInstallFlowComposeSuccessConsumesToken(t *testing.T) { + router, jwt := setupInstallFlowRouter(t) + + batchBody, _ := json.Marshal(map[string][]string{"names": {"compose-ok"}}) + batchReq := httptest.NewRequest(http.MethodPost, "/api/nodes/batch", bytes.NewBuffer(batchBody)) + batchReq.Header.Set("Content-Type", "application/json") + batchReq.Header.Set("Authorization", "Bearer "+jwt) + batchRec := httptest.NewRecorder() + router.ServeHTTP(batchRec, batchReq) + if batchRec.Code != 200 { + t.Fatalf("batch create failed: %d %s", batchRec.Code, batchRec.Body.String()) + } + var batchResp struct { + Data []struct { + ID uint `json:"id"` + } `json:"data"` + } + if err := json.Unmarshal(batchRec.Body.Bytes(), &batchResp); err != nil { + t.Fatalf("unmarshal batch: %v", err) + } + if len(batchResp.Data) != 1 { + t.Fatalf("expected 1 node, got %d", len(batchResp.Data)) + } + + genBody, _ := json.Marshal(map[string]any{ + "mode": "docker", + "arch": "auto", + "agentVersion": "v1.7.0", + "downloadSrc": "github", + "ttlSeconds": 900, + }) + genReq := httptest.NewRequest(http.MethodPost, + "/api/nodes/"+formatUint(batchResp.Data[0].ID)+"/install-tokens", bytes.NewBuffer(genBody)) + genReq.Header.Set("Content-Type", "application/json") + genReq.Header.Set("Authorization", "Bearer "+jwt) + genRec := httptest.NewRecorder() + router.ServeHTTP(genRec, genReq) + if genRec.Code != 200 { + t.Fatalf("install-tokens failed: %d %s", genRec.Code, genRec.Body.String()) + } + var genResp struct { + Data struct { + InstallToken string `json:"installToken"` + } `json:"data"` + } + if err := json.Unmarshal(genRec.Body.Bytes(), &genResp); err != nil { + t.Fatalf("unmarshal gen: %v", err) + } + if genResp.Data.InstallToken == "" { + t.Fatalf("missing installToken") + } + + composeReq := httptest.NewRequest(http.MethodGet, "/api/install/"+genResp.Data.InstallToken+"/compose.yml", nil) + composeRec := httptest.NewRecorder() + router.ServeHTTP(composeRec, composeReq) + if composeRec.Code != 200 { + t.Fatalf("compose fetch failed: %d %s", composeRec.Code, composeRec.Body.String()) + } + if !strings.Contains(composeRec.Body.String(), "BACKUPX_AGENT_TOKEN") { + t.Fatalf("compose missing token env:\n%s", composeRec.Body.String()) + } + + scriptReq := httptest.NewRequest(http.MethodGet, "/api/install/"+genResp.Data.InstallToken, nil) + scriptRec := httptest.NewRecorder() + router.ServeHTTP(scriptRec, scriptReq) + if scriptRec.Code != http.StatusGone { + t.Fatalf("script after compose should be 410, got %d: %s", scriptRec.Code, scriptRec.Body.String()) + } +} + // formatUint 小工具:uint → 十进制字符串(无需引入 strconv)。 func formatUint(u uint) string { if u == 0 { diff --git a/server/internal/http/node_handler.go b/server/internal/http/node_handler.go index 44ba878..249694b 100644 --- a/server/internal/http/node_handler.go +++ b/server/internal/http/node_handler.go @@ -1,7 +1,6 @@ package http import ( - "encoding/base64" "fmt" stdhttp "net/http" "strconv" @@ -245,14 +244,17 @@ func (h *NodeHandler) CreateInstallToken(c *gin.Context) { input.TTLSeconds = 900 } - out, err := h.installTokenSvc.Create(c.Request.Context(), service.InstallTokenInput{ - NodeID: uint(id), - Mode: input.Mode, - Arch: input.Arch, - AgentVersion: input.AgentVersion, - DownloadSrc: input.DownloadSrc, - TTLSeconds: input.TTLSeconds, - CreatedByID: h.resolveCurrentUserID(c), + out, err := h.installTokenSvc.CreateCommand(c.Request.Context(), service.InstallCommandInput{ + InstallTokenInput: service.InstallTokenInput{ + NodeID: uint(id), + Mode: input.Mode, + Arch: input.Arch, + AgentVersion: input.AgentVersion, + DownloadSrc: input.DownloadSrc, + TTLSeconds: input.TTLSeconds, + CreatedByID: h.resolveCurrentUserID(c), + }, + MasterURL: resolveMasterURL(c, h.externalURL), }) if err != nil { response.Error(c, err) @@ -262,12 +264,6 @@ func (h *NodeHandler) CreateInstallToken(c *gin.Context) { fmt.Sprintf("%d", id), out.Node.Name, fmt.Sprintf("生成 %s/%s install token TTL=%ds", input.Mode, input.Arch, input.TTLSeconds)) - masterURL := resolveMasterURL(c, h.externalURL) - script, err := renderInstallScript(masterURL, out.Node, out.Record) - if err != nil { - response.Error(c, err) - return - } // 使用 /api/install/... 而非 /install/... —— 让反向代理的 /api/ 转发规则 // 自动接管,避免 SPA fallback 把请求当成前端路由返回 index.html(issue #46)。 // 同时返回 /install/... 备用地址,兼容会剥离 /api 前缀的外层反向代理。 @@ -276,15 +272,11 @@ func (h *NodeHandler) CreateInstallToken(c *gin.Context) { body := gin.H{ "installToken": out.Token, "expiresAt": out.ExpiresAt, - "url": masterURL + "/api/install/" + out.Token, - "fallbackUrl": masterURL + "/install/" + out.Token, - "scriptBase64": base64.StdEncoding.EncodeToString([]byte(script)), - "composeUrl": "", - "fallbackComposeUrl": "", - } - if input.Mode == "docker" { - body["composeUrl"] = masterURL + "/api/install/" + out.Token + "/compose.yml" - body["fallbackComposeUrl"] = masterURL + "/install/" + out.Token + "/compose.yml" + "url": out.URL, + "fallbackUrl": out.FallbackURL, + "scriptBase64": out.ScriptBase64, + "composeUrl": out.ComposeURL, + "fallbackComposeUrl": out.FallbackComposeURL, } response.Success(c, body) } diff --git a/server/internal/repository/agent_command_repository.go b/server/internal/repository/agent_command_repository.go index aff682b..c5fbb96 100644 --- a/server/internal/repository/agent_command_repository.go +++ b/server/internal/repository/agent_command_repository.go @@ -17,12 +17,21 @@ type AgentCommandRepository interface { // 并返回领取到的命令。无命令时返回 (nil, nil)。 ClaimPending(ctx context.Context, nodeID uint) (*model.AgentCommand, error) Update(ctx context.Context, cmd *model.AgentCommand) error + // CompleteDispatched 只在命令仍处于 dispatched 时写入终态。 + // 返回 false 表示命令已被超时监控或其它流程终结,调用方不应覆盖。 + CompleteDispatched(ctx context.Context, cmd *model.AgentCommand) (bool, error) // MarkStaleTimeout 把 dispatched 状态但超时未完成的命令标记为 timeout。 // 返回被标记的行数。不返回具体命令(供背景监控简单调用)。 MarkStaleTimeout(ctx context.Context, threshold time.Time) (int64, error) + // TimeoutActive 只在命令仍处于 pending/dispatched 时写入 timeout。 + // 返回 false 表示命令已被 Agent 回写为终态,调用方不应覆盖。 + TimeoutActive(ctx context.Context, cmd *model.AgentCommand) (bool, error) // ListStaleDispatched 列出 dispatched 但已超时、尚未被标记的命令。 // 调用方需要把它们逐一标记 timeout 并联动关联记录状态。 ListStaleDispatched(ctx context.Context, threshold time.Time) ([]model.AgentCommand, error) + // ListStaleActive 列出 pending/dispatched 但已超时、尚未完成的命令。 + // pending 使用 created_at 判定,dispatched 使用 dispatched_at 判定。 + ListStaleActive(ctx context.Context, threshold time.Time) ([]model.AgentCommand, error) // ListPendingByNode 列出某节点下的所有 pending/dispatched 命令。 // 用于删除节点或节点离线时的清理。 ListPendingByNode(ctx context.Context, nodeID uint) ([]model.AgentCommand, error) @@ -94,6 +103,21 @@ func (r *GormAgentCommandRepository) Update(ctx context.Context, cmd *model.Agen return r.db.WithContext(ctx).Save(cmd).Error } +func (r *GormAgentCommandRepository) CompleteDispatched(ctx context.Context, cmd *model.AgentCommand) (bool, error) { + result := r.db.WithContext(ctx).Model(&model.AgentCommand{}). + Where("id = ? AND node_id = ? AND status = ?", cmd.ID, cmd.NodeID, model.AgentCommandStatusDispatched). + Updates(map[string]any{ + "status": cmd.Status, + "error_message": cmd.ErrorMessage, + "result": cmd.Result, + "completed_at": cmd.CompletedAt, + }) + if result.Error != nil { + return false, result.Error + } + return result.RowsAffected > 0, nil +} + func (r *GormAgentCommandRepository) MarkStaleTimeout(ctx context.Context, threshold time.Time) (int64, error) { result := r.db.WithContext(ctx).Model(&model.AgentCommand{}). Where("status = ? AND dispatched_at < ?", model.AgentCommandStatusDispatched, threshold). @@ -107,6 +131,20 @@ func (r *GormAgentCommandRepository) MarkStaleTimeout(ctx context.Context, thres return result.RowsAffected, nil } +func (r *GormAgentCommandRepository) TimeoutActive(ctx context.Context, cmd *model.AgentCommand) (bool, error) { + result := r.db.WithContext(ctx).Model(&model.AgentCommand{}). + Where("id = ? AND status IN ?", cmd.ID, []string{model.AgentCommandStatusPending, model.AgentCommandStatusDispatched}). + Updates(map[string]any{ + "status": model.AgentCommandStatusTimeout, + "error_message": cmd.ErrorMessage, + "completed_at": cmd.CompletedAt, + }) + if result.Error != nil { + return false, result.Error + } + return result.RowsAffected > 0, nil +} + // ListStaleDispatched 列出 dispatched 但 dispatched_at 早于 threshold 的命令。 func (r *GormAgentCommandRepository) ListStaleDispatched(ctx context.Context, threshold time.Time) ([]model.AgentCommand, error) { var items []model.AgentCommand @@ -119,6 +157,21 @@ func (r *GormAgentCommandRepository) ListStaleDispatched(ctx context.Context, th return items, nil } +func (r *GormAgentCommandRepository) ListStaleActive(ctx context.Context, threshold time.Time) ([]model.AgentCommand, error) { + var items []model.AgentCommand + if err := r.db.WithContext(ctx). + Where( + "(status = ? AND created_at < ?) OR (status = ? AND dispatched_at < ?)", + model.AgentCommandStatusPending, threshold, + model.AgentCommandStatusDispatched, threshold, + ). + Order("id asc"). + Find(&items).Error; err != nil { + return nil, err + } + return items, nil +} + // ListPendingByNode 列出某节点下所有待执行(pending 或 dispatched)命令。 func (r *GormAgentCommandRepository) ListPendingByNode(ctx context.Context, nodeID uint) ([]model.AgentCommand, error) { var items []model.AgentCommand diff --git a/server/internal/repository/agent_command_repository_test.go b/server/internal/repository/agent_command_repository_test.go index 7d53e8e..6f68913 100644 --- a/server/internal/repository/agent_command_repository_test.go +++ b/server/internal/repository/agent_command_repository_test.go @@ -90,6 +90,78 @@ func TestAgentCommandRepository_Update(t *testing.T) { } } +func TestAgentCommandRepository_CompleteDispatchedOnlyUpdatesDispatchedCommand(t *testing.T) { + db := newTestDB(t) + repo := NewAgentCommandRepository(db) + ctx := context.Background() + dispatched := &model.AgentCommand{NodeID: 1, Type: "run_task", Status: model.AgentCommandStatusDispatched} + timeout := &model.AgentCommand{NodeID: 1, Type: "run_task", Status: model.AgentCommandStatusTimeout, ErrorMessage: "timeout"} + if err := repo.Create(ctx, dispatched); err != nil { + t.Fatalf("Create dispatched returned error: %v", err) + } + if err := repo.Create(ctx, timeout); err != nil { + t.Fatalf("Create timeout returned error: %v", err) + } + + now := time.Now().UTC() + dispatched.Status = model.AgentCommandStatusSucceeded + dispatched.Result = `{"ok":true}` + dispatched.CompletedAt = &now + updated, err := repo.CompleteDispatched(ctx, dispatched) + if err != nil { + t.Fatalf("CompleteDispatched returned error: %v", err) + } + if !updated { + t.Fatal("expected dispatched command to be updated") + } + + timeout.Status = model.AgentCommandStatusSucceeded + timeout.Result = `{"late":true}` + timeout.CompletedAt = &now + updated, err = repo.CompleteDispatched(ctx, timeout) + if err != nil { + t.Fatalf("CompleteDispatched terminal returned error: %v", err) + } + if updated { + t.Fatal("expected terminal command not to be updated") + } + gotTimeout, err := repo.FindByID(ctx, timeout.ID) + if err != nil { + t.Fatalf("FindByID timeout returned error: %v", err) + } + if gotTimeout.Status != model.AgentCommandStatusTimeout || gotTimeout.Result != "" { + t.Fatalf("expected timeout command unchanged, got %#v", gotTimeout) + } +} + +func TestAgentCommandRepository_TimeoutActiveDoesNotOverwriteTerminalCommand(t *testing.T) { + db := newTestDB(t) + repo := NewAgentCommandRepository(db) + ctx := context.Background() + succeeded := &model.AgentCommand{NodeID: 1, Type: "run_task", Status: model.AgentCommandStatusSucceeded, Result: `{"ok":true}`} + if err := repo.Create(ctx, succeeded); err != nil { + t.Fatalf("Create succeeded returned error: %v", err) + } + + now := time.Now().UTC() + succeeded.ErrorMessage = "timeout" + succeeded.CompletedAt = &now + updated, err := repo.TimeoutActive(ctx, succeeded) + if err != nil { + t.Fatalf("TimeoutActive returned error: %v", err) + } + if updated { + t.Fatal("expected terminal command not to be timed out") + } + got, err := repo.FindByID(ctx, succeeded.ID) + if err != nil { + t.Fatalf("FindByID returned error: %v", err) + } + if got.Status != model.AgentCommandStatusSucceeded || got.ErrorMessage != "" || got.Result != `{"ok":true}` { + t.Fatalf("expected succeeded command unchanged, got %#v", got) + } +} + func TestAgentCommandRepository_MarkStaleTimeout(t *testing.T) { db := newTestDB(t) repo := NewAgentCommandRepository(db) @@ -118,3 +190,31 @@ func TestAgentCommandRepository_MarkStaleTimeout(t *testing.T) { t.Errorf("new should stay dispatched: %+v", newGot) } } + +func TestAgentCommandRepository_ListStaleActiveIncludesPendingAndDispatched(t *testing.T) { + db := newTestDB(t) + repo := NewAgentCommandRepository(db) + ctx := context.Background() + old := time.Now().Add(-time.Hour) + recent := time.Now() + oldPending := &model.AgentCommand{NodeID: 1, Type: "run_task", Status: model.AgentCommandStatusPending, CreatedAt: old} + oldDispatched := &model.AgentCommand{NodeID: 1, Type: "restore_record", Status: model.AgentCommandStatusDispatched, DispatchedAt: &old} + recentPending := &model.AgentCommand{NodeID: 1, Type: "run_task", Status: model.AgentCommandStatusPending, CreatedAt: recent} + succeeded := &model.AgentCommand{NodeID: 1, Type: "run_task", Status: model.AgentCommandStatusSucceeded, CreatedAt: old} + for _, cmd := range []*model.AgentCommand{oldPending, oldDispatched, recentPending, succeeded} { + if err := repo.Create(ctx, cmd); err != nil { + t.Fatalf("Create returned error: %v", err) + } + } + + items, err := repo.ListStaleActive(ctx, time.Now().Add(-30*time.Minute)) + if err != nil { + t.Fatalf("ListStaleActive returned error: %v", err) + } + if len(items) != 2 { + t.Fatalf("expected 2 stale active commands, got %#v", items) + } + if items[0].ID != oldPending.ID || items[1].ID != oldDispatched.ID { + t.Fatalf("unexpected stale active order/items: %#v", items) + } +} diff --git a/server/internal/repository/agent_install_token_repository_test.go b/server/internal/repository/agent_install_token_repository_test.go index e0f9534..b7c6b26 100644 --- a/server/internal/repository/agent_install_token_repository_test.go +++ b/server/internal/repository/agent_install_token_repository_test.go @@ -3,6 +3,7 @@ package repository import ( "context" "path/filepath" + "sync" "testing" "time" @@ -83,6 +84,59 @@ func TestInstallTokenConsumeExpired(t *testing.T) { } } +func TestInstallTokenConsumeConcurrentOnlyOneWins(t *testing.T) { + db := openTestInstallTokenDB(t) + repo := NewAgentInstallTokenRepository(db) + ctx := context.Background() + + tok := &model.AgentInstallToken{ + Token: "concurrent", NodeID: 1, Mode: model.InstallModeSystemd, + Arch: model.InstallArchAuto, AgentVer: "v1.7.0", + DownloadSrc: model.InstallSourceGitHub, + ExpiresAt: time.Now().UTC().Add(15 * time.Minute), + CreatedByID: 1, + } + if err := repo.Create(ctx, tok); err != nil { + t.Fatalf("create: %v", err) + } + + const workers = 8 + var wg sync.WaitGroup + start := make(chan struct{}) + results := make(chan *model.AgentInstallToken, workers) + errs := make(chan error, workers) + for i := 0; i < workers; i++ { + wg.Add(1) + go func() { + defer wg.Done() + <-start + got, err := repo.ConsumeByToken(ctx, "concurrent") + if err != nil { + errs <- err + return + } + results <- got + }() + } + close(start) + wg.Wait() + close(results) + close(errs) + + for err := range errs { + t.Fatalf("consume err: %v", err) + } + success := 0 + for got := range results { + if got != nil { + success++ + } + } + if success != 1 { + t.Fatalf("expected exactly one successful consume, got %d", success) + } +} + func TestInstallTokenGC(t *testing.T) { db := openTestInstallTokenDB(t) repo := NewAgentInstallTokenRepository(db) diff --git a/server/internal/repository/backup_record_repository.go b/server/internal/repository/backup_record_repository.go index 5752501..523df7a 100644 --- a/server/internal/repository/backup_record_repository.go +++ b/server/internal/repository/backup_record_repository.go @@ -33,6 +33,7 @@ type BackupStorageUsageItem struct { type BackupRecordRepository interface { List(context.Context, BackupRecordListOptions) ([]model.BackupRecord, error) FindByID(context.Context, uint) (*model.BackupRecord, error) + FindRunningByTaskAndNode(context.Context, uint, uint) (*model.BackupRecord, error) Create(context.Context, *model.BackupRecord) error Update(context.Context, *model.BackupRecord) error Delete(context.Context, uint) error @@ -93,6 +94,20 @@ func (r *GormBackupRecordRepository) FindByID(ctx context.Context, id uint) (*mo return &item, nil } +func (r *GormBackupRecordRepository) FindRunningByTaskAndNode(ctx context.Context, taskID uint, nodeID uint) (*model.BackupRecord, error) { + var item model.BackupRecord + if err := r.db.WithContext(ctx). + Where("task_id = ? AND node_id = ? AND status = ?", taskID, nodeID, model.BackupRecordStatusRunning). + Order("id desc"). + First(&item).Error; err != nil { + if errors.Is(err, gorm.ErrRecordNotFound) { + return nil, nil + } + return nil, err + } + return &item, nil +} + func (r *GormBackupRecordRepository) Create(ctx context.Context, item *model.BackupRecord) error { return r.db.WithContext(ctx).Create(item).Error } diff --git a/server/internal/repository/backup_task_repository.go b/server/internal/repository/backup_task_repository.go index 6111b50..3923a97 100644 --- a/server/internal/repository/backup_task_repository.go +++ b/server/internal/repository/backup_task_repository.go @@ -226,7 +226,7 @@ func (r *GormBackupTaskRepository) Create(ctx context.Context, item *model.Backu } func (r *GormBackupTaskRepository) Update(ctx context.Context, item *model.BackupTask) error { - if err := r.db.WithContext(ctx).Save(item).Error; err != nil { + if err := r.db.WithContext(ctx).Omit("StorageTarget", "StorageTargets", "Node").Save(item).Error; err != nil { return err } if len(item.StorageTargets) > 0 { diff --git a/server/internal/repository/backup_task_repository_test.go b/server/internal/repository/backup_task_repository_test.go index e29ea18..aea21db 100644 --- a/server/internal/repository/backup_task_repository_test.go +++ b/server/internal/repository/backup_task_repository_test.go @@ -92,3 +92,49 @@ func TestBackupTaskRepositoryCRUD(t *testing.T) { t.Fatalf("expected task deleted, got %#v", deleted) } } + +func TestBackupTaskRepositoryUpdateCanClearNodeIDAfterPreload(t *testing.T) { + ctx := context.Background() + repo := newBackupTaskTestRepository(t) + remoteNode := &model.Node{Name: "edge-1", Token: "edge-token", Status: model.NodeStatusOnline, IsLocal: false} + if err := repo.db.WithContext(ctx).Create(remoteNode).Error; err != nil { + t.Fatalf("create node: %v", err) + } + task := &model.BackupTask{ + Name: "pooled-source", + Type: "file", + Enabled: true, + SourcePath: "/srv/www/site", + StorageTargetID: 1, + NodeID: remoteNode.ID, + RetentionDays: 30, + Compression: "gzip", + MaxBackups: 10, + LastStatus: "idle", + } + if err := repo.Create(ctx, task); err != nil { + t.Fatalf("Create returned error: %v", err) + } + loaded, err := repo.FindByID(ctx, task.ID) + if err != nil { + t.Fatalf("FindByID returned error: %v", err) + } + if loaded == nil || loaded.Node.ID != remoteNode.ID { + t.Fatalf("expected preloaded node %d, got %#v", remoteNode.ID, loaded) + } + loaded.NodeID = 0 + loaded.NodePoolTag = "db" + if err := repo.Update(ctx, loaded); err != nil { + t.Fatalf("Update returned error: %v", err) + } + stored, err := repo.FindByID(ctx, task.ID) + if err != nil { + t.Fatalf("FindByID after update returned error: %v", err) + } + if stored.NodeID != 0 { + t.Fatalf("expected NodeID to be cleared, got %d", stored.NodeID) + } + if stored.NodePoolTag != "db" { + t.Fatalf("expected NodePoolTag db, got %q", stored.NodePoolTag) + } +} diff --git a/server/internal/service/agent_service.go b/server/internal/service/agent_service.go index a751391..b3f551f 100644 --- a/server/internal/service/agent_service.go +++ b/server/internal/service/agent_service.go @@ -118,7 +118,8 @@ func (s *AgentService) SubmitCommandResult(ctx context.Context, node *model.Node cmd.Result = string(result.Result) } cmd.CompletedAt = &now - return s.cmdRepo.Update(ctx, cmd) + _, err = s.cmdRepo.CompleteDispatched(ctx, cmd) + return err } // AgentTaskSpec 给 Agent 返回的任务规格,包含解密后的存储配置,供 Agent 直接执行。 @@ -159,8 +160,8 @@ func (s *AgentService) GetTaskSpec(ctx context.Context, node *model.Node, taskID if task == nil { return nil, apperror.New(404, "BACKUP_TASK_NOT_FOUND", "任务不存在", nil) } - if task.NodeID != node.ID { - return nil, apperror.Unauthorized("BACKUP_TASK_FORBIDDEN", "任务不属于当前节点", nil) + if err := s.ensureTaskSpecAccess(ctx, node, task); err != nil { + return nil, err } // 解密数据库密码(若有) dbPassword := "" @@ -213,6 +214,20 @@ func (s *AgentService) GetTaskSpec(ctx context.Context, node *model.Node, taskID }, nil } +func (s *AgentService) ensureTaskSpecAccess(ctx context.Context, node *model.Node, task *model.BackupTask) error { + if task.NodeID == node.ID { + return nil + } + record, err := s.recordRepo.FindRunningByTaskAndNode(ctx, task.ID, node.ID) + if err != nil { + return err + } + if record == nil { + return apperror.Unauthorized("BACKUP_TASK_FORBIDDEN", "任务不属于当前节点", nil) + } + return nil +} + // AgentRecordUpdate Agent 上报备份记录的最终状态。 type AgentRecordUpdate struct { Status string `json:"status"` // running | success | failed @@ -233,14 +248,16 @@ func (s *AgentService) UpdateRecord(ctx context.Context, node *model.Node, recor if record == nil { return apperror.New(404, "BACKUP_RECORD_NOT_FOUND", "记录不存在", nil) } - // 通过 task.NodeID 判断是否属于当前 agent task, err := s.taskRepo.FindByID(ctx, record.TaskID) if err != nil { return err } - if task == nil || task.NodeID != node.ID { + if task == nil || !recordBelongsToNode(record, task, node.ID) { return apperror.Unauthorized("BACKUP_RECORD_FORBIDDEN", "记录不属于当前节点", nil) } + if isBackupRecordTerminal(record.Status) { + return nil + } if update.Status != "" { record.Status = update.Status } @@ -282,6 +299,17 @@ func (s *AgentService) UpdateRecord(ctx context.Context, node *model.Node, recor return nil } +func recordBelongsToNode(record *model.BackupRecord, task *model.BackupTask, nodeID uint) bool { + if record.NodeID != 0 { + return record.NodeID == nodeID + } + return task.NodeID == nodeID +} + +func isBackupRecordTerminal(status string) bool { + return status == model.BackupRecordStatusSuccess || status == model.BackupRecordStatusFailed +} + // EnqueueCommand Master 端调用:给指定节点插入一条待执行命令。 // 返回命令 ID。 func (s *AgentService) EnqueueCommand(ctx context.Context, nodeID uint, cmdType string, payload any) (uint, error) { @@ -356,25 +384,84 @@ func (s *AgentService) StartCommandTimeoutMonitor(ctx context.Context, interval }() } -// processStaleCommands 扫描已超时的 dispatched 命令并联动关联记录。 -// 流程:先取超时候选 → 对每条联动 backup/restore 记录 → 把命令置为 timeout。 +// processStaleCommands 扫描已超时的 pending/dispatched 命令并联动关联记录。 +// 流程:先取超时候选 → 条件式把命令置为 timeout → 对抢到的命令联动 backup/restore 记录。 // 单条失败不影响后续处理。 func (s *AgentService) processStaleCommands(ctx context.Context, threshold time.Time) { - commands, err := s.cmdRepo.ListStaleDispatched(ctx, threshold) + commands, err := s.cmdRepo.ListStaleActive(ctx, threshold) if err != nil || len(commands) == 0 { return } for i := range commands { cmd := commands[i] - s.failLinkedRecord(ctx, &cmd) + if s.commandStillActive(ctx, &cmd, threshold) { + continue + } now := time.Now().UTC() cmd.Status = model.AgentCommandStatusTimeout cmd.ErrorMessage = "agent did not report result before timeout" cmd.CompletedAt = &now - _ = s.cmdRepo.Update(ctx, &cmd) + timedOut, err := s.cmdRepo.TimeoutActive(ctx, &cmd) + if err != nil || !timedOut { + continue + } + s.failLinkedRecord(ctx, &cmd) } } +// commandStillActive 用关联记录状态、记录更新时间和节点心跳作为长任务续租信号。 +// 仅 run_task / restore_record 允许续租,避免短 RPC 命令被在线节点长期保留。 +func (s *AgentService) commandStillActive(ctx context.Context, cmd *model.AgentCommand, threshold time.Time) bool { + if cmd.Status != model.AgentCommandStatusDispatched { + return false + } + switch cmd.Type { + case model.AgentCommandTypeRunTask: + var payload struct { + RecordID uint `json:"recordId"` + } + if err := json.Unmarshal([]byte(cmd.Payload), &payload); err != nil || payload.RecordID == 0 { + return false + } + record, err := s.recordRepo.FindByID(ctx, payload.RecordID) + if err != nil || record == nil || record.Status != model.BackupRecordStatusRunning { + return false + } + if s.nodeRecentlySeen(ctx, cmd.NodeID, threshold) { + return true + } + return record.UpdatedAt.After(threshold) + case model.AgentCommandTypeRestoreRecord: + if s.restoreRepo == nil { + return false + } + var payload struct { + RestoreRecordID uint `json:"restoreRecordId"` + } + if err := json.Unmarshal([]byte(cmd.Payload), &payload); err != nil || payload.RestoreRecordID == 0 { + return false + } + restore, err := s.restoreRepo.FindByID(ctx, payload.RestoreRecordID) + if err != nil || restore == nil || restore.Status != model.RestoreRecordStatusRunning { + return false + } + if s.nodeRecentlySeen(ctx, cmd.NodeID, threshold) { + return true + } + return restore.UpdatedAt.After(threshold) + default: + return false + } +} + +func (s *AgentService) nodeRecentlySeen(ctx context.Context, nodeID uint, threshold time.Time) bool { + node, err := s.nodeRepo.FindByID(ctx, nodeID) + if err != nil || node == nil { + return false + } + return node.Status == model.NodeStatusOnline && node.LastSeen.After(threshold) +} + // failLinkedRecord 根据命令类型把关联记录标记为 failed。 // 只对仍然处于 running 状态的记录生效,避免覆盖已完成的结果。 func (s *AgentService) failLinkedRecord(ctx context.Context, cmd *model.AgentCommand) { diff --git a/server/internal/service/agent_service_test.go b/server/internal/service/agent_service_test.go new file mode 100644 index 0000000..7f6af2b --- /dev/null +++ b/server/internal/service/agent_service_test.go @@ -0,0 +1,589 @@ +package service + +import ( + "context" + "path/filepath" + "testing" + "time" + + "backupx/server/internal/config" + "backupx/server/internal/database" + "backupx/server/internal/logger" + "backupx/server/internal/model" + "backupx/server/internal/repository" + "backupx/server/internal/storage/codec" + "gorm.io/gorm" +) + +func newAgentServicePoolTestHarness(t *testing.T) (*AgentService, *gorm.DB, repository.BackupRecordRepository, repository.AgentCommandRepository, *model.Node, *model.Node) { + t.Helper() + log, err := logger.New(config.LogConfig{Level: "error"}) + if err != nil { + t.Fatalf("logger.New returned error: %v", err) + } + db, err := database.Open(config.DatabaseConfig{Path: filepath.Join(t.TempDir(), "backupx.db")}, log) + if err != nil { + t.Fatalf("database.Open returned error: %v", err) + } + cipher := codec.NewConfigCipher("agent-service-secret") + nodeRepo := repository.NewNodeRepository(db) + taskRepo := repository.NewBackupTaskRepository(db) + recordRepo := repository.NewBackupRecordRepository(db) + storageRepo := repository.NewStorageTargetRepository(db) + cmdRepo := repository.NewAgentCommandRepository(db) + + owner := &model.Node{Name: "edge-owner", Token: "owner-token", Status: model.NodeStatusOnline, IsLocal: false, LastSeen: time.Now().UTC()} + other := &model.Node{Name: "edge-other", Token: "other-token", Status: model.NodeStatusOnline, IsLocal: false, LastSeen: time.Now().UTC()} + if err := nodeRepo.Create(context.Background(), owner); err != nil { + t.Fatalf("create owner node: %v", err) + } + if err := nodeRepo.Create(context.Background(), other); err != nil { + t.Fatalf("create other node: %v", err) + } + targetConfig, err := cipher.EncryptJSON(map[string]any{"basePath": t.TempDir()}) + if err != nil { + t.Fatalf("EncryptJSON returned error: %v", err) + } + target := &model.StorageTarget{Name: "local", Type: "local_disk", Enabled: true, ConfigCiphertext: targetConfig, ConfigVersion: 1, LastTestStatus: "unknown"} + if err := storageRepo.Create(context.Background(), target); err != nil { + t.Fatalf("create storage target: %v", err) + } + task := &model.BackupTask{ + Name: "pooled-task", + Type: "file", + Enabled: true, + SourcePath: "/srv/data", + StorageTargetID: target.ID, + NodeID: 0, + NodePoolTag: "db", + RetentionDays: 30, + Compression: "gzip", + MaxBackups: 10, + LastStatus: "running", + } + if err := taskRepo.Create(context.Background(), task); err != nil { + t.Fatalf("create task: %v", err) + } + record := &model.BackupRecord{ + TaskID: task.ID, + StorageTargetID: target.ID, + NodeID: owner.ID, + Status: model.BackupRecordStatusRunning, + StartedAt: time.Now().UTC(), + } + if err := recordRepo.Create(context.Background(), record); err != nil { + t.Fatalf("create record: %v", err) + } + return NewAgentService(nodeRepo, taskRepo, recordRepo, storageRepo, cmdRepo, cipher), db, recordRepo, cmdRepo, owner, other +} + +func TestAgentServicePooledTaskUsesRecordNodeForSpecAndRecordUpdates(t *testing.T) { + svc, _, records, _, owner, other := newAgentServicePoolTestHarness(t) + ctx := context.Background() + + spec, err := svc.GetTaskSpec(ctx, owner, 1) + if err != nil { + t.Fatalf("owner GetTaskSpec returned error: %v", err) + } + if spec.TaskID != 1 || len(spec.StorageTargets) != 1 { + t.Fatalf("unexpected spec: %#v", spec) + } + if _, err := svc.GetTaskSpec(ctx, other, 1); err == nil { + t.Fatal("expected non-owner node to be forbidden from pooled task spec") + } + + if err := svc.UpdateRecord(ctx, owner, 1, AgentRecordUpdate{ + Status: model.BackupRecordStatusSuccess, + FileName: "backup.tar.gz", + FileSize: 123, + StoragePath: "tasks/1/backup.tar.gz", + }); err != nil { + t.Fatalf("owner UpdateRecord returned error: %v", err) + } + updated, err := records.FindByID(ctx, 1) + if err != nil { + t.Fatalf("FindByID returned error: %v", err) + } + if updated.Status != model.BackupRecordStatusSuccess || updated.NodeID != owner.ID { + t.Fatalf("unexpected updated record: %#v", updated) + } + if err := svc.UpdateRecord(ctx, other, 1, AgentRecordUpdate{LogAppend: "bad"}); err == nil { + t.Fatal("expected non-owner node to be forbidden from record update") + } +} + +func TestAgentServiceProcessStaleCommandsFailsPendingRunTaskRecord(t *testing.T) { + svc, _, records, commands, owner, _ := newAgentServicePoolTestHarness(t) + ctx := context.Background() + oldCommand := &model.AgentCommand{ + NodeID: owner.ID, + Type: model.AgentCommandTypeRunTask, + Status: model.AgentCommandStatusPending, + Payload: `{"recordId":1}`, + CreatedAt: time.Now().UTC().Add(-time.Hour), + } + if err := commands.Create(ctx, oldCommand); err != nil { + t.Fatalf("Create command returned error: %v", err) + } + + svc.processStaleCommands(ctx, time.Now().UTC().Add(-30*time.Minute)) + + updatedCommand, err := commands.FindByID(ctx, oldCommand.ID) + if err != nil { + t.Fatalf("FindByID command returned error: %v", err) + } + if updatedCommand.Status != model.AgentCommandStatusTimeout { + t.Fatalf("expected command timeout, got %#v", updatedCommand) + } + updatedRecord, err := records.FindByID(ctx, 1) + if err != nil { + t.Fatalf("FindByID record returned error: %v", err) + } + if updatedRecord.Status != model.BackupRecordStatusFailed { + t.Fatalf("expected record failed, got %#v", updatedRecord) + } + if updatedRecord.CompletedAt == nil { + t.Fatal("expected failed record completedAt to be set") + } +} + +func TestAgentServiceProcessStaleCommandsFailsPendingRestoreRecord(t *testing.T) { + svc, db, _, commands, owner, _ := newAgentServicePoolTestHarness(t) + ctx := context.Background() + restoreRepo := repository.NewRestoreRecordRepository(db) + restore := &model.RestoreRecord{ + BackupRecordID: 1, + TaskID: 1, + NodeID: owner.ID, + Status: model.RestoreRecordStatusRunning, + StartedAt: time.Now().UTC().Add(-time.Hour), + } + if err := restoreRepo.Create(ctx, restore); err != nil { + t.Fatalf("Create restore returned error: %v", err) + } + svc.SetRestoreRepository(restoreRepo) + oldCommand := &model.AgentCommand{ + NodeID: owner.ID, + Type: model.AgentCommandTypeRestoreRecord, + Status: model.AgentCommandStatusPending, + Payload: `{"restoreRecordId":1}`, + CreatedAt: time.Now().UTC().Add(-time.Hour), + } + if err := commands.Create(ctx, oldCommand); err != nil { + t.Fatalf("Create command returned error: %v", err) + } + + svc.processStaleCommands(ctx, time.Now().UTC().Add(-30*time.Minute)) + + updatedCommand, err := commands.FindByID(ctx, oldCommand.ID) + if err != nil { + t.Fatalf("FindByID command returned error: %v", err) + } + if updatedCommand.Status != model.AgentCommandStatusTimeout { + t.Fatalf("expected command timeout, got %#v", updatedCommand) + } + updatedRestore, err := restoreRepo.FindByID(ctx, restore.ID) + if err != nil { + t.Fatalf("FindByID restore returned error: %v", err) + } + if updatedRestore.Status != model.RestoreRecordStatusFailed { + t.Fatalf("expected restore failed, got %#v", updatedRestore) + } + if updatedRestore.CompletedAt == nil { + t.Fatal("expected failed restore completedAt to be set") + } +} + +func TestAgentServiceProcessStaleCommandsKeepsActiveDispatchedRunTaskRecord(t *testing.T) { + svc, _, records, commands, owner, _ := newAgentServicePoolTestHarness(t) + ctx := context.Background() + dispatchedAt := time.Now().UTC().Add(-time.Hour) + oldCommand := &model.AgentCommand{ + NodeID: owner.ID, + Type: model.AgentCommandTypeRunTask, + Status: model.AgentCommandStatusDispatched, + Payload: `{"recordId":1}`, + CreatedAt: dispatchedAt, + DispatchedAt: &dispatchedAt, + } + if err := commands.Create(ctx, oldCommand); err != nil { + t.Fatalf("Create command returned error: %v", err) + } + + svc.processStaleCommands(ctx, time.Now().UTC().Add(-30*time.Minute)) + + updatedCommand, err := commands.FindByID(ctx, oldCommand.ID) + if err != nil { + t.Fatalf("FindByID command returned error: %v", err) + } + if updatedCommand.Status != model.AgentCommandStatusDispatched { + t.Fatalf("expected active command to remain dispatched, got %#v", updatedCommand) + } + updatedRecord, err := records.FindByID(ctx, 1) + if err != nil { + t.Fatalf("FindByID record returned error: %v", err) + } + if updatedRecord.Status != model.BackupRecordStatusRunning { + t.Fatalf("expected active record to remain running, got %#v", updatedRecord) + } +} + +func TestAgentServiceProcessStaleCommandsKeepsDispatchedRunTaskWhenNodeHeartbeatIsFresh(t *testing.T) { + svc, db, records, commands, owner, _ := newAgentServicePoolTestHarness(t) + ctx := context.Background() + dispatchedAt := time.Now().UTC().Add(-time.Hour) + if err := setBackupRecordUpdatedAt(db, 1, dispatchedAt); err != nil { + t.Fatalf("set backup record updated_at: %v", err) + } + if err := db.Model(&model.Node{}).Where("id = ?", owner.ID).UpdateColumn("last_seen", time.Now().UTC()).Error; err != nil { + t.Fatalf("set owner last_seen: %v", err) + } + oldCommand := &model.AgentCommand{ + NodeID: owner.ID, + Type: model.AgentCommandTypeRunTask, + Status: model.AgentCommandStatusDispatched, + Payload: `{"recordId":1}`, + CreatedAt: dispatchedAt, + DispatchedAt: &dispatchedAt, + } + if err := commands.Create(ctx, oldCommand); err != nil { + t.Fatalf("Create command returned error: %v", err) + } + + svc.processStaleCommands(ctx, time.Now().UTC().Add(-30*time.Minute)) + + updatedCommand, err := commands.FindByID(ctx, oldCommand.ID) + if err != nil { + t.Fatalf("FindByID command returned error: %v", err) + } + if updatedCommand.Status != model.AgentCommandStatusDispatched { + t.Fatalf("expected command to remain dispatched while node heartbeat is fresh, got %#v", updatedCommand) + } + updatedRecord, err := records.FindByID(ctx, 1) + if err != nil { + t.Fatalf("FindByID record returned error: %v", err) + } + if updatedRecord.Status != model.BackupRecordStatusRunning { + t.Fatalf("expected record to remain running while node heartbeat is fresh, got %#v", updatedRecord) + } +} + +func TestAgentServiceProcessStaleCommandsTimesOutShortCommandEvenWhenNodeHeartbeatIsFresh(t *testing.T) { + svc, db, _, commands, owner, _ := newAgentServicePoolTestHarness(t) + ctx := context.Background() + dispatchedAt := time.Now().UTC().Add(-time.Hour) + if err := db.Model(&model.Node{}).Where("id = ?", owner.ID).UpdateColumn("last_seen", time.Now().UTC()).Error; err != nil { + t.Fatalf("set owner last_seen: %v", err) + } + oldCommand := &model.AgentCommand{ + NodeID: owner.ID, + Type: model.AgentCommandTypeListDir, + Status: model.AgentCommandStatusDispatched, + Payload: `{"path":"/srv"}`, + CreatedAt: dispatchedAt, + DispatchedAt: &dispatchedAt, + } + if err := commands.Create(ctx, oldCommand); err != nil { + t.Fatalf("Create command returned error: %v", err) + } + + svc.processStaleCommands(ctx, time.Now().UTC().Add(-30*time.Minute)) + + updatedCommand, err := commands.FindByID(ctx, oldCommand.ID) + if err != nil { + t.Fatalf("FindByID command returned error: %v", err) + } + if updatedCommand.Status != model.AgentCommandStatusTimeout { + t.Fatalf("expected stale short command timeout, got %#v", updatedCommand) + } +} + +func TestAgentServiceProcessStaleCommandsTimesOutDispatchedRunTaskWhenRecordIsTerminalEvenWithFreshHeartbeat(t *testing.T) { + svc, db, records, commands, owner, _ := newAgentServicePoolTestHarness(t) + ctx := context.Background() + dispatchedAt := time.Now().UTC().Add(-time.Hour) + if err := db.Model(&model.Node{}).Where("id = ?", owner.ID).UpdateColumn("last_seen", time.Now().UTC()).Error; err != nil { + t.Fatalf("set owner last_seen: %v", err) + } + record, err := records.FindByID(ctx, 1) + if err != nil { + t.Fatalf("FindByID record returned error: %v", err) + } + completedAt := time.Now().UTC().Add(-time.Minute) + record.Status = model.BackupRecordStatusFailed + record.CompletedAt = &completedAt + if err := records.Update(ctx, record); err != nil { + t.Fatalf("Update terminal record returned error: %v", err) + } + oldCommand := &model.AgentCommand{ + NodeID: owner.ID, + Type: model.AgentCommandTypeRunTask, + Status: model.AgentCommandStatusDispatched, + Payload: `{"recordId":1}`, + CreatedAt: dispatchedAt, + DispatchedAt: &dispatchedAt, + } + if err := commands.Create(ctx, oldCommand); err != nil { + t.Fatalf("Create command returned error: %v", err) + } + + svc.processStaleCommands(ctx, time.Now().UTC().Add(-30*time.Minute)) + + updatedCommand, err := commands.FindByID(ctx, oldCommand.ID) + if err != nil { + t.Fatalf("FindByID command returned error: %v", err) + } + if updatedCommand.Status != model.AgentCommandStatusTimeout { + t.Fatalf("expected command timeout when linked record is terminal, got %#v", updatedCommand) + } +} + +func TestAgentServiceProcessStaleCommandsTimesOutInactiveDispatchedRunTaskRecord(t *testing.T) { + svc, db, records, commands, owner, _ := newAgentServicePoolTestHarness(t) + ctx := context.Background() + dispatchedAt := time.Now().UTC().Add(-time.Hour) + if err := setBackupRecordUpdatedAt(db, 1, dispatchedAt); err != nil { + t.Fatalf("set backup record updated_at: %v", err) + } + if err := db.Model(&model.Node{}).Where("id = ?", owner.ID).UpdateColumn("last_seen", dispatchedAt).Error; err != nil { + t.Fatalf("set owner last_seen: %v", err) + } + oldCommand := &model.AgentCommand{ + NodeID: owner.ID, + Type: model.AgentCommandTypeRunTask, + Status: model.AgentCommandStatusDispatched, + Payload: `{"recordId":1}`, + CreatedAt: dispatchedAt, + DispatchedAt: &dispatchedAt, + } + if err := commands.Create(ctx, oldCommand); err != nil { + t.Fatalf("Create command returned error: %v", err) + } + + svc.processStaleCommands(ctx, time.Now().UTC().Add(-30*time.Minute)) + + updatedCommand, err := commands.FindByID(ctx, oldCommand.ID) + if err != nil { + t.Fatalf("FindByID command returned error: %v", err) + } + if updatedCommand.Status != model.AgentCommandStatusTimeout { + t.Fatalf("expected inactive command timeout, got %#v", updatedCommand) + } + updatedRecord, err := records.FindByID(ctx, 1) + if err != nil { + t.Fatalf("FindByID record returned error: %v", err) + } + if updatedRecord.Status != model.BackupRecordStatusFailed { + t.Fatalf("expected inactive record failed, got %#v", updatedRecord) + } +} + +func TestAgentServiceProcessStaleCommandsKeepsActiveDispatchedRestoreRecord(t *testing.T) { + svc, db, _, commands, owner, _ := newAgentServicePoolTestHarness(t) + ctx := context.Background() + restoreRepo := repository.NewRestoreRecordRepository(db) + restore := createAgentServiceRestoreRecord(t, restoreRepo, owner.ID) + svc.SetRestoreRepository(restoreRepo) + dispatchedAt := time.Now().UTC().Add(-time.Hour) + oldCommand := &model.AgentCommand{ + NodeID: owner.ID, + Type: model.AgentCommandTypeRestoreRecord, + Status: model.AgentCommandStatusDispatched, + Payload: `{"restoreRecordId":1}`, + CreatedAt: dispatchedAt, + DispatchedAt: &dispatchedAt, + } + if err := commands.Create(ctx, oldCommand); err != nil { + t.Fatalf("Create command returned error: %v", err) + } + + svc.processStaleCommands(ctx, time.Now().UTC().Add(-30*time.Minute)) + + updatedCommand, err := commands.FindByID(ctx, oldCommand.ID) + if err != nil { + t.Fatalf("FindByID command returned error: %v", err) + } + if updatedCommand.Status != model.AgentCommandStatusDispatched { + t.Fatalf("expected active restore command to remain dispatched, got %#v", updatedCommand) + } + updatedRestore, err := restoreRepo.FindByID(ctx, restore.ID) + if err != nil { + t.Fatalf("FindByID restore returned error: %v", err) + } + if updatedRestore.Status != model.RestoreRecordStatusRunning { + t.Fatalf("expected active restore to remain running, got %#v", updatedRestore) + } +} + +func TestAgentServiceProcessStaleCommandsKeepsDispatchedRestoreWhenNodeHeartbeatIsFresh(t *testing.T) { + svc, db, _, commands, owner, _ := newAgentServicePoolTestHarness(t) + ctx := context.Background() + restoreRepo := repository.NewRestoreRecordRepository(db) + restore := createAgentServiceRestoreRecord(t, restoreRepo, owner.ID) + svc.SetRestoreRepository(restoreRepo) + dispatchedAt := time.Now().UTC().Add(-time.Hour) + if err := setRestoreRecordUpdatedAt(db, restore.ID, dispatchedAt); err != nil { + t.Fatalf("set restore record updated_at: %v", err) + } + if err := db.Model(&model.Node{}).Where("id = ?", owner.ID).UpdateColumn("last_seen", time.Now().UTC()).Error; err != nil { + t.Fatalf("set owner last_seen: %v", err) + } + oldCommand := &model.AgentCommand{ + NodeID: owner.ID, + Type: model.AgentCommandTypeRestoreRecord, + Status: model.AgentCommandStatusDispatched, + Payload: `{"restoreRecordId":1}`, + CreatedAt: dispatchedAt, + DispatchedAt: &dispatchedAt, + } + if err := commands.Create(ctx, oldCommand); err != nil { + t.Fatalf("Create command returned error: %v", err) + } + + svc.processStaleCommands(ctx, time.Now().UTC().Add(-30*time.Minute)) + + updatedCommand, err := commands.FindByID(ctx, oldCommand.ID) + if err != nil { + t.Fatalf("FindByID command returned error: %v", err) + } + if updatedCommand.Status != model.AgentCommandStatusDispatched { + t.Fatalf("expected restore command to remain dispatched while node heartbeat is fresh, got %#v", updatedCommand) + } +} + +func TestAgentServiceProcessStaleCommandsTimesOutInactiveDispatchedRestoreRecord(t *testing.T) { + svc, db, _, commands, owner, _ := newAgentServicePoolTestHarness(t) + ctx := context.Background() + restoreRepo := repository.NewRestoreRecordRepository(db) + restore := createAgentServiceRestoreRecord(t, restoreRepo, owner.ID) + svc.SetRestoreRepository(restoreRepo) + dispatchedAt := time.Now().UTC().Add(-time.Hour) + if err := setRestoreRecordUpdatedAt(db, restore.ID, dispatchedAt); err != nil { + t.Fatalf("set restore record updated_at: %v", err) + } + if err := db.Model(&model.Node{}).Where("id = ?", owner.ID).UpdateColumn("last_seen", dispatchedAt).Error; err != nil { + t.Fatalf("set owner last_seen: %v", err) + } + oldCommand := &model.AgentCommand{ + NodeID: owner.ID, + Type: model.AgentCommandTypeRestoreRecord, + Status: model.AgentCommandStatusDispatched, + Payload: `{"restoreRecordId":1}`, + CreatedAt: dispatchedAt, + DispatchedAt: &dispatchedAt, + } + if err := commands.Create(ctx, oldCommand); err != nil { + t.Fatalf("Create command returned error: %v", err) + } + + svc.processStaleCommands(ctx, time.Now().UTC().Add(-30*time.Minute)) + + updatedCommand, err := commands.FindByID(ctx, oldCommand.ID) + if err != nil { + t.Fatalf("FindByID command returned error: %v", err) + } + if updatedCommand.Status != model.AgentCommandStatusTimeout { + t.Fatalf("expected inactive restore command timeout, got %#v", updatedCommand) + } + updatedRestore, err := restoreRepo.FindByID(ctx, restore.ID) + if err != nil { + t.Fatalf("FindByID restore returned error: %v", err) + } + if updatedRestore.Status != model.RestoreRecordStatusFailed { + t.Fatalf("expected inactive restore failed, got %#v", updatedRestore) + } +} + +func TestAgentServiceSubmitCommandResultDoesNotOverwriteTerminalCommand(t *testing.T) { + svc, _, _, commands, owner, _ := newAgentServicePoolTestHarness(t) + ctx := context.Background() + completedAt := time.Now().UTC().Add(-time.Minute) + command := &model.AgentCommand{ + NodeID: owner.ID, + Type: model.AgentCommandTypeRunTask, + Status: model.AgentCommandStatusTimeout, + Payload: `{"recordId":1}`, + ErrorMessage: "timeout", + CompletedAt: &completedAt, + } + if err := commands.Create(ctx, command); err != nil { + t.Fatalf("Create command returned error: %v", err) + } + + if err := svc.SubmitCommandResult(ctx, owner, command.ID, AgentCommandResult{Success: true, Result: []byte(`{"ok":true}`)}); err != nil { + t.Fatalf("SubmitCommandResult returned error: %v", err) + } + + updatedCommand, err := commands.FindByID(ctx, command.ID) + if err != nil { + t.Fatalf("FindByID command returned error: %v", err) + } + if updatedCommand.Status != model.AgentCommandStatusTimeout { + t.Fatalf("expected terminal command status to remain timeout, got %#v", updatedCommand) + } + if updatedCommand.Result != "" { + t.Fatalf("expected terminal command result to remain empty, got %q", updatedCommand.Result) + } +} + +func TestAgentServiceUpdateRecordDoesNotOverwriteTerminalRecord(t *testing.T) { + svc, _, records, _, owner, _ := newAgentServicePoolTestHarness(t) + ctx := context.Background() + record, err := records.FindByID(ctx, 1) + if err != nil { + t.Fatalf("FindByID record returned error: %v", err) + } + completedAt := time.Now().UTC().Add(-time.Minute) + record.Status = model.BackupRecordStatusFailed + record.ErrorMessage = "timeout" + record.CompletedAt = &completedAt + if err := records.Update(ctx, record); err != nil { + t.Fatalf("Update record returned error: %v", err) + } + + if err := svc.UpdateRecord(ctx, owner, record.ID, AgentRecordUpdate{ + Status: model.BackupRecordStatusSuccess, + FileName: "late.tar.gz", + FileSize: 42, + Checksum: "late", + StoragePath: "late/path", + ErrorMessage: "late success", + LogAppend: "late log\n", + }); err != nil { + t.Fatalf("UpdateRecord returned error: %v", err) + } + + updatedRecord, err := records.FindByID(ctx, record.ID) + if err != nil { + t.Fatalf("FindByID updated record returned error: %v", err) + } + if updatedRecord.Status != model.BackupRecordStatusFailed { + t.Fatalf("expected terminal record status to remain failed, got %#v", updatedRecord) + } + if updatedRecord.FileName != "" || updatedRecord.StoragePath != "" || updatedRecord.ErrorMessage != "timeout" { + t.Fatalf("expected terminal record fields to remain unchanged, got %#v", updatedRecord) + } +} + +func createAgentServiceRestoreRecord(t *testing.T, repo repository.RestoreRecordRepository, nodeID uint) *model.RestoreRecord { + t.Helper() + restore := &model.RestoreRecord{ + BackupRecordID: 1, + TaskID: 1, + NodeID: nodeID, + Status: model.RestoreRecordStatusRunning, + StartedAt: time.Now().UTC().Add(-time.Hour), + } + if err := repo.Create(context.Background(), restore); err != nil { + t.Fatalf("Create restore returned error: %v", err) + } + return restore +} + +func setBackupRecordUpdatedAt(db *gorm.DB, id uint, updatedAt time.Time) error { + return db.Model(&model.BackupRecord{}).Where("id = ?", id).UpdateColumn("updated_at", updatedAt).Error +} + +func setRestoreRecordUpdatedAt(db *gorm.DB, id uint, updatedAt time.Time) error { + return db.Model(&model.RestoreRecord{}).Where("id = ?", id).UpdateColumn("updated_at", updatedAt).Error +} diff --git a/server/internal/service/backup_execution_service.go b/server/internal/service/backup_execution_service.go index 56c5a54..52cbcc8 100644 --- a/server/internal/service/backup_execution_service.go +++ b/server/internal/service/backup_execution_service.go @@ -73,28 +73,28 @@ func collectTargetIDs(task *model.BackupTask) []uint { } type BackupExecutionService struct { - tasks repository.BackupTaskRepository - records repository.BackupRecordRepository - targets repository.StorageTargetRepository - nodeRepo repository.NodeRepository - storageRegistry *storage.Registry - runnerRegistry *backup.Registry - logHub *backup.LogHub - retention *backupretention.Service - cipher *codec.ConfigCipher + tasks repository.BackupTaskRepository + records repository.BackupRecordRepository + targets repository.StorageTargetRepository + nodeRepo repository.NodeRepository + storageRegistry *storage.Registry + runnerRegistry *backup.Registry + logHub *backup.LogHub + retention *backupretention.Service + cipher *codec.ConfigCipher notifier BackupResultNotifier agentDispatcher AgentDispatcher replicationHook ReplicationTrigger dependentsResolver DependentsResolver - async func(func()) - now func() time.Time - tempDir string - semaphore chan struct{} + async func(func()) + now func() time.Time + tempDir string + semaphore chan struct{} // nodeSemaphores 节点级并发限制(按 NodeID 映射)。 // 没命中的 NodeID 走全局 semaphore,节点配置 MaxConcurrent>0 时按该节点独立排队。 nodeSemaphores sync.Map - retries int // rclone 底层重试次数 - bandwidthLimit string // rclone 带宽限制(全局默认,节点配置可覆盖) + retries int // rclone 底层重试次数 + bandwidthLimit string // rclone 带宽限制(全局默认,节点配置可覆盖) metrics *metrics.Metrics } @@ -270,11 +270,9 @@ func (s *BackupExecutionService) DeleteRecord(ctx context.Context, recordID uint if record == nil { return apperror.New(404, "BACKUP_RECORD_NOT_FOUND", "备份记录不存在", fmt.Errorf("backup record %d not found", recordID)) } - // 集群场景保护:跨节点 local_disk 文件 Master 无法远程删除,拒绝操作以避免存储泄漏的错觉 - if err := s.validateClusterAccessible(ctx, record); err != nil { + if remote, err := s.deleteRemoteLocalDiskObject(ctx, record); err != nil { return err - } - if strings.TrimSpace(record.StoragePath) != "" { + } else if !remote && strings.TrimSpace(record.StoragePath) != "" { provider, err := s.resolveProvider(ctx, record.StorageTargetID) if err != nil { return err @@ -289,6 +287,40 @@ func (s *BackupExecutionService) DeleteRecord(ctx context.Context, recordID uint return nil } +func (s *BackupExecutionService) deleteRemoteLocalDiskObject(ctx context.Context, record *model.BackupRecord) (bool, error) { + if strings.TrimSpace(record.StoragePath) == "" || s.nodeRepo == nil { + return false, nil + } + node, err := s.nodeRepo.FindByID(ctx, record.NodeID) + if err != nil || node == nil || node.IsLocal { + return false, nil + } + target, err := s.targets.FindByID(ctx, record.StorageTargetID) + if err != nil { + return false, apperror.Internal("BACKUP_STORAGE_TARGET_GET_FAILED", "无法获取存储目标详情", err) + } + if target == nil || !strings.EqualFold(target.Type, "local_disk") { + return false, nil + } + if s.agentDispatcher == nil { + return true, apperror.BadRequest("BACKUP_RECORD_CROSS_NODE_LOCAL_DISK", + fmt.Sprintf("该备份位于节点 %s 的本地磁盘(local_disk),Master 无法跨节点删除。请确保 Agent 在线后再操作。", node.Name), + nil) + } + configMap := map[string]any{} + if err := s.cipher.DecryptJSON(target.ConfigCiphertext, &configMap); err != nil { + return true, apperror.Internal("BACKUP_STORAGE_TARGET_DECRYPT_FAILED", "无法解密存储目标配置", err) + } + if _, err := s.agentDispatcher.EnqueueCommand(ctx, record.NodeID, model.AgentCommandTypeDeleteStorageObject, map[string]any{ + "targetType": target.Type, + "targetConfig": configMap, + "storagePath": record.StoragePath, + }); err != nil { + return true, apperror.Internal("AGENT_COMMAND_ENQUEUE_FAILED", "无法下发远程备份文件删除命令", err) + } + return true, nil +} + // validateClusterAccessible 在跨节点 + local_disk 场景下拒绝 Master 端直接访问。 // 场景说明:远程 Agent 把备份写到其本机磁盘(local_disk basePath)时,Master 的 // provider 指向的是 Master 本机的同名路径,访问会静默取错文件或 404。明确拒绝 @@ -356,8 +388,8 @@ func (s *BackupExecutionService) startTask(ctx context.Context, id uint, async b if err := s.records.Create(ctx, record); err != nil { return nil, apperror.Internal("BACKUP_RECORD_CREATE_FAILED", "无法创建备份记录", err) } - // 用池选出的节点 ID 复写 task 副本,使后续路由/执行沿用 - task.NodeID = resolvedNodeID + runTask := *task + runTask.NodeID = resolvedNodeID task.LastRunAt = &startedAt task.LastStatus = "running" if err := s.tasks.Update(ctx, task); err != nil { @@ -365,27 +397,27 @@ func (s *BackupExecutionService) startTask(ctx context.Context, id uint, async b } // 多节点路由:task.NodeID 指向远程节点时,把执行任务入队给 Agent; // NodeID=0 或本机节点时由 Master 直接执行。 - if remoteNode := s.resolveRemoteNode(ctx, task.NodeID); remoteNode != nil { + if remoteNode := s.resolveRemoteNode(ctx, resolvedNodeID); remoteNode != nil { // 节点离线 → 立即把刚创建的 running 记录标记 failed,返回明确错误 if remoteNode.Status != model.NodeStatusOnline { offlineMsg := fmt.Sprintf("节点 %s 当前离线,无法执行备份任务", remoteNode.Name) - _ = s.finalizeRecord(ctx, task, record.ID, startedAt, model.BackupRecordStatusFailed, - offlineMsg, "", "", 0, "", "") + _ = s.finalizeRecord(ctx, &runTask, record.ID, startedAt, model.BackupRecordStatusFailed, + offlineMsg, "", "", 0, "", "", primaryTargetID) return nil, apperror.BadRequest("NODE_OFFLINE", offlineMsg, nil) } - if _, enqueueErr := s.agentDispatcher.EnqueueCommand(ctx, task.NodeID, model.AgentCommandTypeRunTask, map[string]any{ + if _, enqueueErr := s.agentDispatcher.EnqueueCommand(ctx, resolvedNodeID, model.AgentCommandTypeRunTask, map[string]any{ "taskId": task.ID, "recordId": record.ID, }); enqueueErr != nil { // 入队失败 → 在记录中标记失败,继续返回详情 - _ = s.finalizeRecord(ctx, task, record.ID, startedAt, model.BackupRecordStatusFailed, - "无法下发任务到远程节点: "+enqueueErr.Error(), "", "", 0, "", "") + _ = s.finalizeRecord(ctx, &runTask, record.ID, startedAt, model.BackupRecordStatusFailed, + "无法下发任务到远程节点: "+enqueueErr.Error(), "", "", 0, "", "", primaryTargetID) return nil, apperror.Internal("AGENT_COMMAND_ENQUEUE_FAILED", "无法下发任务到远程节点", enqueueErr) } return s.getRecordDetail(ctx, record.ID) } run := func() { - s.executeTask(context.Background(), task, record.ID, startedAt) + s.executeTask(context.Background(), &runTask, record.ID, startedAt) } if async { s.async(run) @@ -561,9 +593,10 @@ func (s *BackupExecutionService) executeTask(ctx context.Context, task *model.Ba var fileSize int64 var checksum string var storagePath string + selectedStorageTargetID := task.StorageTargetID var uploadResults []StorageUploadResultItem completeRecord := func() { - if finalizeErr := s.finalizeRecord(ctx, task, recordID, startedAt, status, errMessage, logger.String(), fileName, fileSize, checksum, storagePath); finalizeErr != nil { + if finalizeErr := s.finalizeRecord(ctx, task, recordID, startedAt, status, errMessage, logger.String(), fileName, fileSize, checksum, storagePath, selectedStorageTargetID); finalizeErr != nil { logger.Errorf("写回备份记录失败:%v", finalizeErr) } // 采集任务执行结果到 Prometheus(耗时 + 产出字节 + 状态计数) @@ -759,6 +792,9 @@ func (s *BackupExecutionService) executeTask(ctx context.Context, task *model.Ba for _, r := range uploadResults { if r.Status == "success" { anySuccess = true + if selectedStorageTargetID == task.StorageTargetID { + selectedStorageTargetID = r.StorageTargetID + } } else if r.Error != "" { failedMessages = append(failedMessages, fmt.Sprintf("%s: %s", r.StorageTargetName, r.Error)) } @@ -791,7 +827,7 @@ func (s *BackupExecutionService) executeTask(ctx context.Context, task *model.Ba record := &model.BackupRecord{ ID: recordID, TaskID: task.ID, - StorageTargetID: task.StorageTargetID, + StorageTargetID: selectedStorageTargetID, NodeID: task.NodeID, Status: "success", FileName: fileName, @@ -816,7 +852,7 @@ func (s *BackupExecutionService) executeTask(ctx context.Context, task *model.Ba } } -func (s *BackupExecutionService) finalizeRecord(ctx context.Context, task *model.BackupTask, recordID uint, startedAt time.Time, status string, errorMessage string, logContent string, fileName string, fileSize int64, checksum string, storagePath string) error { +func (s *BackupExecutionService) finalizeRecord(ctx context.Context, task *model.BackupTask, recordID uint, startedAt time.Time, status string, errorMessage string, logContent string, fileName string, fileSize int64, checksum string, storagePath string, storageTargetID uint) error { record, err := s.records.FindByID(ctx, recordID) if err != nil { return err @@ -826,6 +862,9 @@ func (s *BackupExecutionService) finalizeRecord(ctx context.Context, task *model } completedAt := s.now() record.Status = status + if storageTargetID > 0 { + record.StorageTargetID = storageTargetID + } record.FileName = fileName record.FileSize = fileSize record.Checksum = checksum @@ -957,6 +996,9 @@ func (s *BackupExecutionService) loadRecordProvider(ctx context.Context, recordI if record == nil { return nil, nil, apperror.New(404, "BACKUP_RECORD_NOT_FOUND", "备份记录不存在", fmt.Errorf("backup record %d not found", recordID)) } + if err := s.validateClusterAccessible(ctx, record); err != nil { + return nil, nil, err + } provider, err := s.resolveProvider(ctx, record.StorageTargetID) if err != nil { return nil, nil, err diff --git a/server/internal/service/backup_execution_service_test.go b/server/internal/service/backup_execution_service_test.go index 9e5d150..ca4eb8b 100644 --- a/server/internal/service/backup_execution_service_test.go +++ b/server/internal/service/backup_execution_service_test.go @@ -2,9 +2,13 @@ package service import ( "context" + "fmt" + "io" "os" "path/filepath" + "strings" "testing" + "time" "backupx/server/internal/backup" backupretention "backupx/server/internal/backup/retention" @@ -18,6 +22,62 @@ import ( storageRclone "backupx/server/internal/storage/rclone" ) +type testStorageFactory struct { + providers map[string]*testStorageProvider +} + +func (f *testStorageFactory) Type() storage.ProviderType { + return "test_storage" +} + +func (f *testStorageFactory) New(_ context.Context, config map[string]any) (storage.StorageProvider, error) { + name, _ := config["name"].(string) + provider := f.providers[name] + if provider == nil { + return nil, fmt.Errorf("unknown provider %q", name) + } + return provider, nil +} + +type testStorageProvider struct { + name string + failUpload bool + objects map[string][]byte +} + +func (p *testStorageProvider) Type() storage.ProviderType { return "test_storage" } +func (p *testStorageProvider) TestConnection(context.Context) error { + return nil +} +func (p *testStorageProvider) Upload(_ context.Context, objectKey string, reader io.Reader, _ int64, _ map[string]string) error { + if p.failUpload { + return fmt.Errorf("upload failed for %s", p.name) + } + data, err := io.ReadAll(reader) + if err != nil { + return err + } + if p.objects == nil { + p.objects = map[string][]byte{} + } + p.objects[objectKey] = data + return nil +} +func (p *testStorageProvider) Download(_ context.Context, objectKey string) (io.ReadCloser, error) { + data, ok := p.objects[objectKey] + if !ok { + return nil, fmt.Errorf("object %s not found", objectKey) + } + return io.NopCloser(strings.NewReader(string(data))), nil +} +func (p *testStorageProvider) Delete(_ context.Context, objectKey string) error { + delete(p.objects, objectKey) + return nil +} +func (p *testStorageProvider) List(context.Context, string) ([]storage.ObjectInfo, error) { + return nil, nil +} + func newExecutionTestServices(t *testing.T) (*BackupExecutionService, *BackupRecordService, repository.BackupTaskRepository, repository.StorageTargetRepository, repository.BackupRecordRepository, string, string) { t.Helper() baseDir := t.TempDir() @@ -85,6 +145,195 @@ func TestBackupExecutionServiceRunTaskByIDSync(t *testing.T) { } } +func TestBackupExecutionServiceNodePoolSelectionDoesNotPersistTaskNodeID(t *testing.T) { + executionService, _, tasks, _, records, _, _ := newExecutionTestServices(t) + ctx := context.Background() + + nodeRepo := &nodeRepoStub{nodes: []model.Node{ + {ID: 10, Name: "edge-a", Token: "edge-a-token", Status: model.NodeStatusOnline, Labels: "prod,db"}, + {ID: 11, Name: "edge-b", Token: "edge-b-token", Status: model.NodeStatusOnline, Labels: "prod,db"}, + }} + dispatcher := &fakeDispatcher{} + executionService.SetClusterDependencies(nodeRepo, dispatcher) + + task, err := tasks.FindByID(ctx, 1) + if err != nil { + t.Fatalf("FindByID returned error: %v", err) + } + task.NodeID = 0 + task.NodePoolTag = "db" + if err := tasks.Update(ctx, task); err != nil { + t.Fatalf("Update task returned error: %v", err) + } + + detail, err := executionService.RunTaskByID(ctx, 1) + if err != nil { + t.Fatalf("RunTaskByID returned error: %v", err) + } + storedTask, err := tasks.FindByID(ctx, 1) + if err != nil { + t.Fatalf("FindByID after run returned error: %v", err) + } + if storedTask.NodeID != 0 { + t.Fatalf("expected pooled task NodeID to remain 0, got %d", storedTask.NodeID) + } + if storedTask.NodePoolTag != "db" { + t.Fatalf("expected pooled task tag to remain db, got %q", storedTask.NodePoolTag) + } + storedRecord, err := records.FindByID(ctx, detail.ID) + if err != nil { + t.Fatalf("FindByID record returned error: %v", err) + } + if storedRecord == nil || storedRecord.NodeID != 10 { + t.Fatalf("expected record to keep selected node 10, got %#v", storedRecord) + } + calls := dispatcher.snapshot() + if len(calls) != 1 || calls[0].NodeID != 10 || calls[0].CmdType != model.AgentCommandTypeRunTask { + t.Fatalf("unexpected dispatcher calls: %#v", calls) + } +} + +func TestBackupExecutionServiceDeleteRecordDispatchesRemoteLocalDiskCleanup(t *testing.T) { + executionService, _, tasks, _, records, _, _ := newExecutionTestServices(t) + ctx := context.Background() + nodeRepo := &nodeRepoStub{nodes: []model.Node{ + {ID: 10, Name: "edge-a", Token: "edge-a-token", Status: model.NodeStatusOnline}, + }} + dispatcher := &fakeDispatcher{} + executionService.SetClusterDependencies(nodeRepo, dispatcher) + + task, err := tasks.FindByID(ctx, 1) + if err != nil { + t.Fatalf("FindByID task returned error: %v", err) + } + completedAt := time.Now().UTC() + record := &model.BackupRecord{ + TaskID: task.ID, + StorageTargetID: task.StorageTargetID, + NodeID: 10, + Status: model.BackupRecordStatusSuccess, + FileName: "remote.tar.gz", + StoragePath: "file/2026/05/09/remote.tar.gz", + StartedAt: completedAt.Add(-time.Second), + CompletedAt: &completedAt, + } + if err := records.Create(ctx, record); err != nil { + t.Fatalf("Create record returned error: %v", err) + } + + if err := executionService.DeleteRecord(ctx, record.ID); err != nil { + t.Fatalf("DeleteRecord returned error: %v", err) + } + deleted, err := records.FindByID(ctx, record.ID) + if err != nil { + t.Fatalf("FindByID record returned error: %v", err) + } + if deleted != nil { + t.Fatalf("expected record deleted, got %#v", deleted) + } + calls := dispatcher.snapshot() + if len(calls) != 1 { + t.Fatalf("expected one dispatcher call, got %#v", calls) + } + if calls[0].NodeID != 10 || calls[0].CmdType != model.AgentCommandTypeDeleteStorageObject { + t.Fatalf("unexpected dispatcher call: %#v", calls[0]) + } + if calls[0].Payload["storagePath"] != record.StoragePath { + t.Fatalf("expected storagePath %q, got %#v", record.StoragePath, calls[0].Payload) + } + if calls[0].Payload["targetType"] != string(storage.ProviderTypeLocalDisk) { + t.Fatalf("expected local_disk targetType, got %#v", calls[0].Payload) + } + if _, ok := calls[0].Payload["targetConfig"].(map[string]any); !ok { + t.Fatalf("expected targetConfig map, got %#v", calls[0].Payload["targetConfig"]) + } +} + +func TestBackupExecutionServiceRestoreRecordRejectsRemoteLocalDisk(t *testing.T) { + executionService, _, tasks, _, records, _, _ := newExecutionTestServices(t) + ctx := context.Background() + executionService.SetClusterDependencies(&nodeRepoStub{nodes: []model.Node{ + {ID: 10, Name: "edge-a", Token: "edge-a-token", Status: model.NodeStatusOnline}, + }}, &fakeDispatcher{}) + task, err := tasks.FindByID(ctx, 1) + if err != nil { + t.Fatalf("FindByID task returned error: %v", err) + } + completedAt := time.Now().UTC() + record := &model.BackupRecord{ + TaskID: task.ID, + StorageTargetID: task.StorageTargetID, + NodeID: 10, + Status: model.BackupRecordStatusSuccess, + FileName: "remote.tar.gz", + StoragePath: "file/2026/05/09/remote.tar.gz", + StartedAt: completedAt.Add(-time.Second), + CompletedAt: &completedAt, + } + if err := records.Create(ctx, record); err != nil { + t.Fatalf("Create record returned error: %v", err) + } + + err = executionService.RestoreRecord(ctx, record.ID) + if err == nil { + t.Fatal("expected remote local_disk restore to be rejected") + } + if !strings.Contains(err.Error(), "Master 无法跨节点访问") { + t.Fatalf("expected cross-node local_disk error, got %v", err) + } +} + +func TestBackupExecutionServiceRecordsFirstSuccessfulStorageTarget(t *testing.T) { + executionService, _, tasks, targets, records, _, _ := newExecutionTestServices(t) + ctx := context.Background() + second := &testStorageProvider{name: "second", objects: map[string][]byte{}} + executionService.storageRegistry = storage.NewRegistry(&testStorageFactory{providers: map[string]*testStorageProvider{ + "second": second, + }}) + cipher := codec.NewConfigCipher("execution-secret") + firstConfig, err := cipher.EncryptJSON(map[string]any{"name": "missing"}) + if err != nil { + t.Fatalf("EncryptJSON first returned error: %v", err) + } + secondConfig, err := cipher.EncryptJSON(map[string]any{"name": "second"}) + if err != nil { + t.Fatalf("EncryptJSON second returned error: %v", err) + } + if err := targets.Create(ctx, &model.StorageTarget{Name: "first", Type: "test_storage", Enabled: true, ConfigCiphertext: firstConfig, ConfigVersion: 1, LastTestStatus: "unknown"}); err != nil { + t.Fatalf("Create first target returned error: %v", err) + } + if err := targets.Create(ctx, &model.StorageTarget{Name: "second", Type: "test_storage", Enabled: true, ConfigCiphertext: secondConfig, ConfigVersion: 1, LastTestStatus: "unknown"}); err != nil { + t.Fatalf("Create second target returned error: %v", err) + } + task, err := tasks.FindByID(ctx, 1) + if err != nil { + t.Fatalf("FindByID task returned error: %v", err) + } + task.StorageTargetID = 2 + task.StorageTargets = []model.StorageTarget{{ID: 2}, {ID: 3}} + if err := tasks.Update(ctx, task); err != nil { + t.Fatalf("Update task returned error: %v", err) + } + + detail, err := executionService.RunTaskByIDSync(ctx, 1) + if err != nil { + t.Fatalf("RunTaskByIDSync returned error: %v", err) + } + if detail.Status != model.BackupRecordStatusSuccess { + t.Fatalf("expected success, got %#v", detail) + } + storedRecord, err := records.FindByID(ctx, detail.ID) + if err != nil { + t.Fatalf("FindByID record returned error: %v", err) + } + if storedRecord.StorageTargetID != 3 { + t.Fatalf("expected record StorageTargetID to point at successful target 3, got %d", storedRecord.StorageTargetID) + } + if _, ok := second.objects[storedRecord.StoragePath]; !ok { + t.Fatalf("expected object in successful provider at %q", storedRecord.StoragePath) + } +} + func TestBackupRecordServiceRestore(t *testing.T) { executionService, recordService, _, _, _, sourceDir, _ := newExecutionTestServices(t) detail, err := executionService.RunTaskByIDSync(context.Background(), 1) diff --git a/server/internal/service/install_token_service.go b/server/internal/service/install_token_service.go index 21ce31d..e2e97d9 100644 --- a/server/internal/service/install_token_service.go +++ b/server/internal/service/install_token_service.go @@ -3,12 +3,14 @@ package service import ( "context" "crypto/rand" + "encoding/base64" "encoding/hex" "fmt" "strings" "time" "backupx/server/internal/apperror" + "backupx/server/internal/installscript" "backupx/server/internal/model" "backupx/server/internal/repository" ) @@ -42,6 +44,25 @@ type InstallTokenOutput struct { Record *model.AgentInstallToken } +// InstallCommandInput 生成可展示安装命令所需的完整业务输入。 +type InstallCommandInput struct { + InstallTokenInput + MasterURL string +} + +// InstallCommandOutput 是 UI 生成安装命令所需的完整业务输出。 +type InstallCommandOutput struct { + Token string + ExpiresAt time.Time + Node *model.Node + Record *model.AgentInstallToken + URL string + FallbackURL string + ComposeURL string + FallbackComposeURL string + ScriptBase64 string +} + // ConsumedInstallToken 消费成功后返回给 handler 的组合体。 type ConsumedInstallToken struct { Record *model.AgentInstallToken @@ -106,6 +127,67 @@ func (s *InstallTokenService) Create(ctx context.Context, in InstallTokenInput) return &InstallTokenOutput{Token: token, ExpiresAt: expiresAt, Node: node, Record: record}, nil } +// CreateCommand 创建 install token,并返回 UI 展示安装命令所需的 URL 与嵌入式脚本。 +func (s *InstallTokenService) CreateCommand(ctx context.Context, in InstallCommandInput) (*InstallCommandOutput, error) { + masterURL := strings.TrimRight(strings.TrimSpace(in.MasterURL), "/") + if masterURL == "" { + return nil, apperror.BadRequest("INSTALL_TOKEN_INVALID", "masterURL 必填", nil) + } + if err := s.validate(in.InstallTokenInput); err != nil { + return nil, err + } + node, err := s.nodeRepo.FindByID(ctx, in.NodeID) + if err != nil { + return nil, err + } + if node == nil { + return nil, apperror.New(404, "NODE_NOT_FOUND", "节点不存在", nil) + } + if _, err := renderInstallCommandScript(masterURL, node, &model.AgentInstallToken{ + Mode: in.Mode, + Arch: in.Arch, + AgentVer: in.AgentVersion, + DownloadSrc: in.DownloadSrc, + }); err != nil { + return nil, err + } + out, err := s.Create(ctx, in.InstallTokenInput) + if err != nil { + return nil, err + } + script, err := renderInstallCommandScript(masterURL, out.Node, out.Record) + if err != nil { + return nil, err + } + result := &InstallCommandOutput{ + Token: out.Token, + ExpiresAt: out.ExpiresAt, + Node: out.Node, + Record: out.Record, + URL: masterURL + "/api/install/" + out.Token, + FallbackURL: masterURL + "/install/" + out.Token, + ScriptBase64: base64.StdEncoding.EncodeToString([]byte(script)), + } + if out.Record.Mode == model.InstallModeDocker { + result.ComposeURL = masterURL + "/api/install/" + out.Token + "/compose.yml" + result.FallbackComposeURL = masterURL + "/install/" + out.Token + "/compose.yml" + } + return result, nil +} + +func renderInstallCommandScript(masterURL string, node *model.Node, record *model.AgentInstallToken) (string, error) { + return installscript.RenderScript(installscript.Context{ + MasterURL: masterURL, + AgentToken: node.Token, + AgentVersion: record.AgentVer, + Mode: record.Mode, + Arch: record.Arch, + DownloadBase: installscript.DownloadBaseFor(record.DownloadSrc), + InstallPrefix: "/opt/backupx-agent", + NodeID: node.ID, + }) +} + // Consume 原子消费令牌。未命中/已过期/已消费均返回 (nil, nil)。 func (s *InstallTokenService) Consume(ctx context.Context, token string) (*ConsumedInstallToken, error) { if strings.TrimSpace(token) == "" { @@ -170,8 +252,8 @@ func (s *InstallTokenService) validate(in InstallTokenInput) error { if !validInstallSources[in.DownloadSrc] { return apperror.BadRequest("INSTALL_TOKEN_INVALID", "downloadSrc 非法", nil) } - if strings.TrimSpace(in.AgentVersion) == "" { - return apperror.BadRequest("INSTALL_TOKEN_INVALID", "agentVersion 必填", nil) + if err := validateInstallAgentVersion(in.AgentVersion); err != nil { + return err } if in.TTLSeconds < InstallTokenMinTTL || in.TTLSeconds > InstallTokenMaxTTL { return apperror.BadRequest("INSTALL_TOKEN_INVALID", @@ -180,6 +262,27 @@ func (s *InstallTokenService) validate(in InstallTokenInput) error { return nil } +func validateInstallAgentVersion(v string) error { + v = strings.TrimSpace(v) + if v == "" { + return apperror.BadRequest("INSTALL_TOKEN_INVALID", "agentVersion 必填", nil) + } + if len(v) > 64 { + return apperror.BadRequest("INSTALL_TOKEN_INVALID", "agentVersion 不能超过 64 字符", nil) + } + for _, c := range v { + switch { + case c >= '0' && c <= '9': + case c >= 'a' && c <= 'z': + case c >= 'A' && c <= 'Z': + case c == '.' || c == '-' || c == '_' || c == '+': + default: + return apperror.BadRequest("INSTALL_TOKEN_INVALID", "agentVersion 包含非法字符", nil) + } + } + return nil +} + func generateInstallToken() (string, error) { b := make([]byte, 32) if _, err := rand.Read(b); err != nil { diff --git a/server/internal/service/install_token_service_test.go b/server/internal/service/install_token_service_test.go index 0552202..5ac0789 100644 --- a/server/internal/service/install_token_service_test.go +++ b/server/internal/service/install_token_service_test.go @@ -131,6 +131,79 @@ func TestInstallTokenServiceValidatesInput(t *testing.T) { } } +func TestInstallTokenServiceRejectsInvalidAgentVersionBeforeCreate(t *testing.T) { + db := openInstallTokenTestDB(t) + nodeRepo := repository.NewNodeRepository(db) + node := &model.Node{Name: "invalid-version", Token: "feedface"} + if err := nodeRepo.Create(context.Background(), node); err != nil { + t.Fatalf("create node: %v", err) + } + tokenRepo := repository.NewAgentInstallTokenRepository(db) + svc := NewInstallTokenService(tokenRepo, nodeRepo) + + _, err := svc.Create(context.Background(), InstallTokenInput{ + NodeID: node.ID, + Mode: model.InstallModeSystemd, + Arch: model.InstallArchAuto, + AgentVersion: "v1 && rm -rf /", + DownloadSrc: model.InstallSourceGitHub, + TTLSeconds: 900, + CreatedByID: 1, + }) + if err == nil { + t.Fatalf("expected invalid version error") + } + count, err := tokenRepo.CountCreatedSince(context.Background(), node.ID, time.Now().UTC().Add(-time.Hour)) + if err != nil { + t.Fatalf("count: %v", err) + } + if count != 0 { + t.Fatalf("invalid request created %d token records", count) + } +} + +func TestInstallTokenServiceCreateCommandBuildsURLsAndScript(t *testing.T) { + db := openInstallTokenTestDB(t) + nodeRepo := repository.NewNodeRepository(db) + node := &model.Node{ + Name: "command-node", + Token: "deadbeefcafebabe0123456789abcdef0123456789abcdef0123456789abcdef", + } + if err := nodeRepo.Create(context.Background(), node); err != nil { + t.Fatalf("create node: %v", err) + } + tokenRepo := repository.NewAgentInstallTokenRepository(db) + svc := NewInstallTokenService(tokenRepo, nodeRepo) + + out, err := svc.CreateCommand(context.Background(), InstallCommandInput{ + InstallTokenInput: InstallTokenInput{ + NodeID: node.ID, + Mode: model.InstallModeDocker, + Arch: model.InstallArchAuto, + AgentVersion: "v1.7.0", + DownloadSrc: model.InstallSourceGitHub, + TTLSeconds: 900, + CreatedByID: 1, + }, + MasterURL: "https://public.example.com/base", + }) + if err != nil { + t.Fatalf("create command: %v", err) + } + if out.Token == "" || out.ScriptBase64 == "" { + t.Fatalf("missing token or script: %+v", out) + } + if out.URL != "https://public.example.com/base/api/install/"+out.Token { + t.Fatalf("bad url: %s", out.URL) + } + if out.FallbackURL != "https://public.example.com/base/install/"+out.Token { + t.Fatalf("bad fallback url: %s", out.FallbackURL) + } + if out.ComposeURL != "https://public.example.com/base/api/install/"+out.Token+"/compose.yml" { + t.Fatalf("bad compose url: %s", out.ComposeURL) + } +} + func TestInstallTokenServiceRateLimit(t *testing.T) { db := openInstallTokenTestDB(t) nodeRepo := repository.NewNodeRepository(db) diff --git a/server/internal/service/restore_service.go b/server/internal/service/restore_service.go index cc3f0c4..f679d9b 100644 --- a/server/internal/service/restore_service.go +++ b/server/internal/service/restore_service.go @@ -141,10 +141,11 @@ func (s *RestoreService) Start(ctx context.Context, backupRecordID uint, trigger } startedAt := s.now() + restoreNodeID := s.resolveRestoreNodeID(record, task) restore := &model.RestoreRecord{ BackupRecordID: backupRecordID, TaskID: record.TaskID, - NodeID: task.NodeID, + NodeID: restoreNodeID, Status: model.RestoreRecordStatusRunning, StartedAt: startedAt, TriggeredBy: strings.TrimSpace(triggeredBy), @@ -154,7 +155,7 @@ func (s *RestoreService) Start(ctx context.Context, backupRecordID uint, trigger } // 远程节点路由 - if remoteNode := s.resolveRemoteNode(ctx, task.NodeID); remoteNode != nil { + if remoteNode := s.resolveRemoteNode(ctx, restoreNodeID); remoteNode != nil { if s.dispatcher == nil { return nil, apperror.Internal("RESTORE_DISPATCH_UNAVAILABLE", "Agent 下发通道未就绪", nil) } @@ -166,14 +167,14 @@ func (s *RestoreService) Start(ctx context.Context, backupRecordID uint, trigger s.logHub.Complete(restore.ID, model.RestoreRecordStatusFailed) return nil, apperror.BadRequest("NODE_OFFLINE", offlineMsg, nil) } - if _, dispatchErr := s.dispatcher.EnqueueCommand(ctx, task.NodeID, model.AgentCommandTypeRestoreRecord, map[string]any{ + if _, dispatchErr := s.dispatcher.EnqueueCommand(ctx, restoreNodeID, model.AgentCommandTypeRestoreRecord, map[string]any{ "restoreRecordId": restore.ID, }); dispatchErr != nil { _ = s.finalize(ctx, restore.ID, model.RestoreRecordStatusFailed, "下发恢复任务到远程节点失败: "+dispatchErr.Error()) return nil, apperror.Internal("AGENT_COMMAND_ENQUEUE_FAILED", "无法下发恢复任务到远程节点", dispatchErr) } - s.logHub.Append(restore.ID, "info", fmt.Sprintf("已下发恢复任务到节点 %s(#%d),等待 Agent 执行", remoteNode.Name, task.NodeID)) + s.logHub.Append(restore.ID, "info", fmt.Sprintf("已下发恢复任务到节点 %s(#%d),等待 Agent 执行", remoteNode.Name, restoreNodeID)) return s.getDetail(ctx, restore.ID) } @@ -185,6 +186,16 @@ func (s *RestoreService) Start(ctx context.Context, backupRecordID uint, trigger return s.getDetail(ctx, restore.ID) } +func (s *RestoreService) resolveRestoreNodeID(record *model.BackupRecord, task *model.BackupTask) uint { + if record != nil && record.NodeID != 0 { + return record.NodeID + } + if task != nil { + return task.NodeID + } + return 0 +} + // isRemoteNode 判断 NodeID 是否指向有效的远程节点。 func (s *RestoreService) isRemoteNode(ctx context.Context, nodeID uint) bool { return s.resolveRemoteNode(ctx, nodeID) != nil @@ -629,6 +640,9 @@ func (s *RestoreService) UpdateAgentRestore(ctx context.Context, node *model.Nod if restore.NodeID != node.ID { return apperror.Unauthorized("RESTORE_RECORD_FORBIDDEN", "恢复记录不属于当前节点", nil) } + if isRestoreRecordTerminal(restore.Status) { + return nil + } // 追加日志到 LogHub + DB if strings.TrimSpace(update.LogAppend) != "" { for _, line := range strings.Split(update.LogAppend, "\n") { @@ -667,6 +681,10 @@ func (s *RestoreService) UpdateAgentRestore(ctx context.Context, node *model.Nod return nil } +func isRestoreRecordTerminal(status string) bool { + return status == model.RestoreRecordStatusSuccess || status == model.RestoreRecordStatusFailed +} + // --- 内部辅助 --- func (s *RestoreService) getDetail(ctx context.Context, restoreID uint) (*RestoreRecordDetail, error) { diff --git a/server/internal/service/restore_service_test.go b/server/internal/service/restore_service_test.go index 4184377..ba23551 100644 --- a/server/internal/service/restore_service_test.go +++ b/server/internal/service/restore_service_test.go @@ -51,15 +51,15 @@ func (f *fakeDispatcher) snapshot() []dispatcherCall { } type restoreTestHarness struct { - service *RestoreService - execution *BackupExecutionService - records repository.BackupRecordRepository - restores repository.RestoreRecordRepository - tasks repository.BackupTaskRepository - nodes repository.NodeRepository - dispatcher *fakeDispatcher - sourceDir string - storageDir string + service *RestoreService + execution *BackupExecutionService + records repository.BackupRecordRepository + restores repository.RestoreRecordRepository + tasks repository.BackupTaskRepository + nodes repository.NodeRepository + dispatcher *fakeDispatcher + sourceDir string + storageDir string } func newRestoreTestHarness(t *testing.T, remoteNode bool) *restoreTestHarness { @@ -228,6 +228,179 @@ func TestRestoreServiceStart_RemoteNodeEnqueuesCommand(t *testing.T) { } } +func TestRestoreServiceStart_UsesBackupRecordNodeForPooledTask(t *testing.T) { + h := newRestoreTestHarness(t, true) + ctx := context.Background() + + task, err := h.tasks.FindByID(ctx, 1) + if err != nil { + t.Fatalf("FindByID task: %v", err) + } + remoteNodeID := task.NodeID + task.NodeID = 0 + task.NodePoolTag = "db" + if err := h.tasks.Update(ctx, task); err != nil { + t.Fatalf("Update task: %v", err) + } + storedTask, err := h.tasks.FindByID(ctx, task.ID) + if err != nil { + t.Fatalf("FindByID stored task: %v", err) + } + if storedTask.NodeID != 0 { + t.Fatalf("expected stored task NodeID to be reset to 0, got %d", storedTask.NodeID) + } + + startedAt := time.Now().UTC() + completedAt := startedAt.Add(time.Second) + backupRecord := &model.BackupRecord{ + TaskID: task.ID, + StorageTargetID: task.StorageTargetID, + NodeID: remoteNodeID, + Status: model.BackupRecordStatusSuccess, + FileName: "pooled.tar.gz", + StoragePath: "file/2026/05/09/pooled.tar.gz", + StartedAt: startedAt, + CompletedAt: &completedAt, + } + if err := h.records.Create(ctx, backupRecord); err != nil { + t.Fatalf("Create backup record: %v", err) + } + + detail, err := h.service.Start(ctx, backupRecord.ID, "tester-pool") + if err != nil { + t.Fatalf("Start: %v", err) + } + if detail.NodeID != remoteNodeID { + t.Fatalf("expected restore node %d, got %d", remoteNodeID, detail.NodeID) + } + calls := h.dispatcher.snapshot() + if len(calls) != 1 { + t.Fatalf("expected exactly 1 dispatcher call, got %d", len(calls)) + } + if calls[0].NodeID != remoteNodeID { + t.Fatalf("expected dispatch to node %d, got %d", remoteNodeID, calls[0].NodeID) + } +} + +func TestRestoreServiceAgentRestoreAccessUsesRestoreRecordNode(t *testing.T) { + h := newRestoreTestHarness(t, true) + ctx := context.Background() + + task, err := h.tasks.FindByID(ctx, 1) + if err != nil { + t.Fatalf("FindByID task: %v", err) + } + owner, err := h.nodes.FindByID(ctx, task.NodeID) + if err != nil { + t.Fatalf("FindByID owner node: %v", err) + } + other := &model.Node{Name: "edge-2", Token: "other-token", Status: model.NodeStatusOnline, IsLocal: false, LastSeen: time.Now().UTC()} + if err := h.nodes.Create(ctx, other); err != nil { + t.Fatalf("Create other node: %v", err) + } + startedAt := time.Now().UTC() + completedAt := startedAt.Add(time.Second) + backupRecord := &model.BackupRecord{ + TaskID: task.ID, + StorageTargetID: task.StorageTargetID, + NodeID: owner.ID, + Status: model.BackupRecordStatusSuccess, + FileName: "remote.tar.gz", + StoragePath: "file/2026/05/09/remote.tar.gz", + StartedAt: startedAt, + CompletedAt: &completedAt, + } + if err := h.records.Create(ctx, backupRecord); err != nil { + t.Fatalf("Create backup record: %v", err) + } + restore := &model.RestoreRecord{ + BackupRecordID: backupRecord.ID, + TaskID: task.ID, + NodeID: owner.ID, + Status: model.RestoreRecordStatusRunning, + StartedAt: startedAt, + TriggeredBy: "agent-test", + } + if err := h.restores.Create(ctx, restore); err != nil { + t.Fatalf("Create restore record: %v", err) + } + + spec, err := h.service.GetAgentRestoreSpec(ctx, owner, restore.ID) + if err != nil { + t.Fatalf("owner GetAgentRestoreSpec returned error: %v", err) + } + if spec.RestoreRecordID != restore.ID || spec.StoragePath != backupRecord.StoragePath { + t.Fatalf("unexpected restore spec: %#v", spec) + } + if _, err := h.service.GetAgentRestoreSpec(ctx, other, restore.ID); err == nil { + t.Fatal("expected non-owner node to be forbidden from restore spec") + } + if err := h.service.UpdateAgentRestore(ctx, owner, restore.ID, AgentRestoreUpdate{ + Status: model.RestoreRecordStatusSuccess, + LogAppend: "done\n", + }); err != nil { + t.Fatalf("owner UpdateAgentRestore returned error: %v", err) + } + updated, err := h.restores.FindByID(ctx, restore.ID) + if err != nil { + t.Fatalf("FindByID restore returned error: %v", err) + } + if updated.Status != model.RestoreRecordStatusSuccess || updated.NodeID != owner.ID { + t.Fatalf("unexpected updated restore record: %#v", updated) + } + if err := h.service.UpdateAgentRestore(ctx, other, restore.ID, AgentRestoreUpdate{LogAppend: "bad\n"}); err == nil { + t.Fatal("expected non-owner node to be forbidden from restore update") + } +} + +func TestRestoreServiceUpdateAgentRestoreDoesNotOverwriteTerminalRecord(t *testing.T) { + h := newRestoreTestHarness(t, true) + ctx := context.Background() + + task, err := h.tasks.FindByID(ctx, 1) + if err != nil { + t.Fatalf("FindByID task: %v", err) + } + owner, err := h.nodes.FindByID(ctx, task.NodeID) + if err != nil { + t.Fatalf("FindByID owner node: %v", err) + } + startedAt := time.Now().UTC().Add(-time.Hour) + completedAt := time.Now().UTC().Add(-time.Minute) + restore := &model.RestoreRecord{ + BackupRecordID: 1, + TaskID: task.ID, + NodeID: owner.ID, + Status: model.RestoreRecordStatusFailed, + ErrorMessage: "timeout", + StartedAt: startedAt, + CompletedAt: &completedAt, + TriggeredBy: "agent-test", + } + if err := h.restores.Create(ctx, restore); err != nil { + t.Fatalf("Create restore record: %v", err) + } + + if err := h.service.UpdateAgentRestore(ctx, owner, restore.ID, AgentRestoreUpdate{ + Status: model.RestoreRecordStatusSuccess, + ErrorMessage: "late success", + LogAppend: "late log\n", + }); err != nil { + t.Fatalf("UpdateAgentRestore returned error: %v", err) + } + + updated, err := h.restores.FindByID(ctx, restore.ID) + if err != nil { + t.Fatalf("FindByID restore returned error: %v", err) + } + if updated.Status != model.RestoreRecordStatusFailed { + t.Fatalf("expected terminal restore status to remain failed, got %#v", updated) + } + if updated.ErrorMessage != "timeout" { + t.Fatalf("expected terminal restore error to remain unchanged, got %q", updated.ErrorMessage) + } +} + func TestRestoreServiceStart_FailsOnNonSuccessBackup(t *testing.T) { h := newRestoreTestHarness(t, false) ctx := context.Background() diff --git a/web/src/pages/nodes/AgentInstallWizard.tsx b/web/src/pages/nodes/AgentInstallWizard.tsx index b418945..c237200 100644 --- a/web/src/pages/nodes/AgentInstallWizard.tsx +++ b/web/src/pages/nodes/AgentInstallWizard.tsx @@ -1,12 +1,11 @@ import React, { useEffect, useRef, useState } from 'react' -import { Modal, Steps, Button, Space, Message, Spin, Progress } from '@arco-design/web-react' +import { Modal, Steps, Button, Space, Message, Spin } from '@arco-design/web-react' import { Step1NodeName, type Mode } from './wizard/Step1NodeName' import { Step2DeployOptions, type DeployOptions } from './wizard/Step2DeployOptions' import { Step3CommandPreview } from './wizard/Step3CommandPreview' import { BatchCommandTable, type BatchCommandRow } from './BatchCommandTable' -import { batchCreateNodes, createInstallToken } from '../../services/nodes' import type { InstallTokenResult } from '../../types/nodes' -import { buildAgentInstallCommand } from './installCommands' +import { useAgentDeployFlow, type AgentDeployRow } from './useAgentDeployFlow' const Step = Steps.Step @@ -25,9 +24,7 @@ export function AgentInstallWizard({ visible, onClose, onSuccess, masterVersion, const [mode, setMode] = useState('single') const [singleName, setSingleName] = useState('') const [batchText, setBatchText] = useState('') - - // 批量进度(已生成 / 总数) - const [batchProgress, setBatchProgress] = useState<{ done: number; total: number } | null>(null) + const deployFlow = useAgentDeployFlow() const [deploy, setDeploy] = useState({ mode: 'systemd', @@ -66,7 +63,6 @@ export function AgentInstallWizard({ visible, onClose, onSuccess, masterVersion, setSingleToken(null) setSingleNodeInfo(null) setBatchRows([]) - setBatchProgress(null) } const handleClose = () => { @@ -102,71 +98,21 @@ export function AgentInstallWizard({ visible, onClose, onSuccess, masterVersion, Message.warning('请填写 Agent 版本号(形如 v1.7.0)') return } - // 步骤 1 的批次内去重在前端先提示一次,再由后端最终校验 - if (mode === 'batch' && !fixedNode) { - const names = parseBatchNames() - const seen = new Set() - const dups: string[] = [] - for (const n of names) { - if (seen.has(n)) dups.push(n) - seen.add(n) - } - if (dups.length > 0) { - Message.warning(`批次内有重复节点名:${Array.from(new Set(dups)).join(', ')}`) - return - } - } setSubmitting(true) try { if (fixedNode) { - const tok = await createInstallToken(fixedNode.id, { - mode: deploy.mode, - arch: deploy.arch, - agentVersion: deploy.agentVersion, - downloadSrc: deploy.downloadSrc, - ttlSeconds: deploy.ttlSeconds, - }) - setSingleNodeInfo(fixedNode) - setSingleToken(tok) + const result = await deployFlow.submitExistingNode(fixedNode, deploy) + applySingleOrTableResult(result.rows, fixedNode) } else if (mode === 'single') { - const created = await batchCreateNodes([singleName.trim()]) - const one = created[0] - const tok = await createInstallToken(one.id, { - mode: deploy.mode, - arch: deploy.arch, - agentVersion: deploy.agentVersion, - downloadSrc: deploy.downloadSrc, - ttlSeconds: deploy.ttlSeconds, - }) - setSingleNodeInfo({ id: one.id, name: one.name }) - setSingleToken(tok) + const result = await deployFlow.submitNewNodes([singleName.trim()], deploy) + applySingleOrTableResult(result.rows) } else { const names = parseBatchNames() - const created = await batchCreateNodes(names) - setBatchProgress({ done: 0, total: created.length }) - // 并发生成 install token(Promise.all),每完成一个递增 done 计数 - let done = 0 - const tokens = await Promise.all( - created.map(async (c) => { - const tok = await createInstallToken(c.id, { - mode: deploy.mode, - arch: deploy.arch, - agentVersion: deploy.agentVersion, - downloadSrc: deploy.downloadSrc, - ttlSeconds: deploy.ttlSeconds, - }) - done += 1 - if (mountedRef.current) setBatchProgress({ done, total: created.length }) - return { c, tok } - }), - ) - const rows: BatchCommandRow[] = tokens.map(({ c, tok }) => ({ - nodeId: c.id, - nodeName: c.name, - command: buildAgentInstallCommand(tok.url, tok.fallbackUrl, tok.scriptBase64), - expiresAt: tok.expiresAt, - })) - if (mountedRef.current) setBatchRows(rows) + const result = await deployFlow.submitNewNodes(names, deploy) + if (mountedRef.current) setBatchRows(toBatchRows(result.rows)) + if (result.status === 'partialFailed') { + Message.warning('部分节点安装命令生成失败,可在结果表中查看') + } } setStep(2) onSuccess() @@ -181,14 +127,12 @@ export function AgentInstallWizard({ visible, onClose, onSuccess, masterVersion, if (!singleNodeInfo) return setSubmitting(true) try { - const tok = await createInstallToken(singleNodeInfo.id, { - mode: deploy.mode, - arch: deploy.arch, - agentVersion: deploy.agentVersion, - downloadSrc: deploy.downloadSrc, - ttlSeconds: deploy.ttlSeconds, - }) - setSingleToken(tok) + const row = await deployFlow.regenerateNode(singleNodeInfo, deploy) + if (row.status === 'ready' && row.installToken) { + setSingleToken(row.installToken) + } else { + Message.error(row.errorMessage || '重新生成失败') + } } catch (e: any) { Message.error(e?.message || '重新生成失败') } finally { @@ -196,6 +140,25 @@ export function AgentInstallWizard({ visible, onClose, onSuccess, masterVersion, } } + const retryBatchNode = async (row: BatchCommandRow) => { + setSubmitting(true) + try { + const next = await deployFlow.regenerateNode({ id: row.nodeId, name: row.nodeName }, deploy) + setBatchRows((rows) => rows.map((item) => ( + item.nodeId === row.nodeId ? toBatchRows([next])[0] : item + ))) + if (next.status === 'ready') { + Message.success(`节点「${row.nodeName}」安装命令已重新生成`) + } else { + Message.error(next.errorMessage || '重试失败') + } + } catch (e: any) { + Message.error(e?.message || '重试失败') + } finally { + setSubmitting(false) + } + } + const previewParams = { mode: deploy.mode, arch: deploy.arch, @@ -225,17 +188,6 @@ export function AgentInstallWizard({ visible, onClose, onSuccess, masterVersion, {submitting && (
- {batchProgress && ( -
-
- 正在生成安装命令 {batchProgress.done} / {batchProgress.total} -
- -
- )}
)} @@ -289,7 +241,7 @@ export function AgentInstallWizard({ visible, onClose, onSuccess, masterVersion, onRegenerate={regenerateSingle} /> )} - {batchRows.length > 0 && } + {batchRows.length > 0 && }
+ + {row.status === 'ready' && ( + + )} + {row.status === 'failed' && onRetryNode && ( + + )} + ), }, ]} @@ -100,9 +123,22 @@ export function BatchCommandTable({ rows }: Props) { />
- +
) } + +function secondsLeft(expiresAt: string) { + if (!expiresAt) { + return 0 + } + const exp = new Date(expiresAt).getTime() + return Math.max(0, Math.floor((exp - Date.now()) / 1000)) +} + +export function getExportableBatchRows(rows: BatchCommandRow[]) { + return rows.filter((row) => row.status === 'ready' && secondsLeft(row.expiresAt) > 0) +} diff --git a/web/src/pages/nodes/NodesPage.test.ts b/web/src/pages/nodes/NodesPage.test.ts new file mode 100644 index 0000000..51a3fa5 --- /dev/null +++ b/web/src/pages/nodes/NodesPage.test.ts @@ -0,0 +1,21 @@ +import { describe, expect, it } from 'vitest' +import type { UserInfo } from '../../services/auth' +import { canManageNodes } from './NodesPage' + +function user(role: string): UserInfo { + return { + id: 1, + username: role, + displayName: role, + role, + } +} + +describe('canManageNodes', () => { + it('allows only admins to manage deployment operations', () => { + expect(canManageNodes(user('admin'))).toBe(true) + expect(canManageNodes(user('operator'))).toBe(false) + expect(canManageNodes(user('viewer'))).toBe(false) + expect(canManageNodes(null)).toBe(false) + }) +}) diff --git a/web/src/pages/nodes/NodesPage.tsx b/web/src/pages/nodes/NodesPage.tsx index c4e3383..b32cd12 100644 --- a/web/src/pages/nodes/NodesPage.tsx +++ b/web/src/pages/nodes/NodesPage.tsx @@ -10,12 +10,21 @@ import type { NodeSummary } from '../../types/nodes' import { listNodes, deleteNode, updateNode, rotateNodeToken } from '../../services/nodes' import { fetchSystemInfo } from '../../services/system' import { AgentInstallWizard } from './AgentInstallWizard' +import { useAuthStore } from '../../stores/auth' +import { isAdmin } from '../../utils/permissions' +import type { UserInfo } from '../../services/auth' const { Text } = Typography +export function canManageNodes(user: UserInfo | null | undefined): boolean { + return isAdmin(user) +} + export default function NodesPage() { const [nodes, setNodes] = useState([]) const [loading, setLoading] = useState(false) + const currentUser = useAuthStore((state) => state.user) + const manageable = canManageNodes(currentUser) const [wizardVisible, setWizardVisible] = useState(false) const [wizardFixedNode, setWizardFixedNode] = useState<{ id: number; name: string } | undefined>() @@ -143,38 +152,43 @@ export default function NodesPage() { }, { title: '操作', width: 180, - render: (_: unknown, record: NodeSummary) => ( - - - } + ) : undefined} /> diff --git a/web/src/pages/nodes/installCommands.test.ts b/web/src/pages/nodes/installCommands.test.ts index 25044af..27a0895 100644 --- a/web/src/pages/nodes/installCommands.test.ts +++ b/web/src/pages/nodes/installCommands.test.ts @@ -1,5 +1,5 @@ import { describe, expect, it } from 'vitest' -import { buildAgentDownloadCommand, buildAgentInstallCommand } from './installCommands' +import { buildAgentDownloadCommand, buildAgentInstallCommand, buildEmbeddedAgentInstallCommand } from './installCommands' describe('install command builders', () => { it('adds script marker validation and fallback install path', () => { @@ -22,16 +22,24 @@ describe('install command builders', () => { expect(cmd).toContain('non-script content') }) - it('prefers embedded script content when available', () => { + it('keeps URL install command as primary even when embedded script is available', () => { const cmd = buildAgentInstallCommand( 'https://master.example.com/api/install/abc', 'https://master.example.com/install/abc', 'IyEvYmluL3NoCg==', ) + expect(cmd).toContain('https://master.example.com/api/install/abc') + expect(cmd).toContain('https://master.example.com/install/abc') + expect(cmd).not.toContain('IyEvYmluL3NoCg==') + }) + + it('builds embedded fallback command explicitly', () => { + const cmd = buildEmbeddedAgentInstallCommand('IyEvYmluL3NoCg==') + expect(cmd).toContain('base64 -d') expect(cmd).toContain('base64 -D') + expect(cmd).toContain('BACKUPX_AGENT_INSTALL_V1') expect(cmd).toContain("'IyEvYmluL3NoCg=='") - expect(cmd).not.toContain('https://master.example.com/api/install/abc') }) }) diff --git a/web/src/pages/nodes/installCommands.ts b/web/src/pages/nodes/installCommands.ts index b3a3064..5203c7c 100644 --- a/web/src/pages/nodes/installCommands.ts +++ b/web/src/pages/nodes/installCommands.ts @@ -12,19 +12,7 @@ function runScriptCommand(path: string) { return `if [ "$(id -u)" -eq 0 ]; then sh ${path}; else sudo sh ${path}; fi` } -export function buildAgentInstallCommand(url: string, fallbackUrl?: string, scriptBase64?: string) { - if (scriptBase64?.trim()) { - const marker = shellQuote(INSTALL_MAGIC_MARKER) - return [ - 'enc=$(mktemp)', - 'tmp=$(mktemp)', - `printf %s ${shellQuote(scriptBase64.trim())} > "$enc"`, - '(base64 -d < "$enc" > "$tmp" 2>/dev/null || base64 -D < "$enc" > "$tmp")', - `{ grep -q ${marker} "$tmp" || { echo 'BackupX embedded installer is invalid.' >&2; head -5 "$tmp" >&2; false; }; }`, - runScriptCommand('"$tmp"'), - ].join(' && ') + '; rc=$?; rm -f "$enc" "$tmp"; test $rc -eq 0' - } - +export function buildAgentInstallCommand(url: string, fallbackUrl?: string, _scriptBase64?: string) { const primary = url.trim() const fallback = (fallbackUrl || legacyInstallUrl(primary)).trim() const urls = fallback && fallback !== primary ? [primary, fallback] : [primary] @@ -41,17 +29,7 @@ export function buildAgentInstallCommand(url: string, fallbackUrl?: string, scri ].join(' && ') + '; rc=$?; rm -f "$tmp"; test $rc -eq 0' } -export function buildAgentDownloadCommand(url: string, fallbackUrl?: string, scriptBase64?: string) { - if (scriptBase64?.trim()) { - const marker = shellQuote(INSTALL_MAGIC_MARKER) - return [ - `printf %s ${shellQuote(scriptBase64.trim())} > /tmp/bx-agent-install.b64`, - '(base64 -d < /tmp/bx-agent-install.b64 > /tmp/bx-agent-install.sh 2>/dev/null || base64 -D < /tmp/bx-agent-install.b64 > /tmp/bx-agent-install.sh)', - `{ grep -q ${marker} /tmp/bx-agent-install.sh || { echo 'BackupX embedded installer is invalid.' >&2; head -5 /tmp/bx-agent-install.sh >&2; false; }; }`, - runScriptCommand('/tmp/bx-agent-install.sh'), - ].join(' && ') - } - +export function buildAgentDownloadCommand(url: string, fallbackUrl?: string, _scriptBase64?: string) { const primary = url.trim() const fallback = (fallbackUrl || legacyInstallUrl(primary)).trim() const marker = shellQuote(INSTALL_MAGIC_MARKER) @@ -65,3 +43,15 @@ export function buildAgentDownloadCommand(url: string, fallbackUrl?: string, scr runScriptCommand('/tmp/bx-agent-install.sh'), ].join(' && ') } + +export function buildEmbeddedAgentInstallCommand(scriptBase64: string) { + const marker = shellQuote(INSTALL_MAGIC_MARKER) + return [ + 'enc=$(mktemp)', + 'tmp=$(mktemp)', + `printf %s ${shellQuote(scriptBase64.trim())} > "$enc"`, + '(base64 -d < "$enc" > "$tmp" 2>/dev/null || base64 -D < "$enc" > "$tmp")', + `{ grep -q ${marker} "$tmp" || { echo 'BackupX embedded installer is invalid.' >&2; head -5 "$tmp" >&2; false; }; }`, + runScriptCommand('"$tmp"'), + ].join(' && ') + '; rc=$?; rm -f "$enc" "$tmp"; test $rc -eq 0' +} diff --git a/web/src/pages/nodes/useAgentDeployFlow.test.ts b/web/src/pages/nodes/useAgentDeployFlow.test.ts new file mode 100644 index 0000000..93b0c69 --- /dev/null +++ b/web/src/pages/nodes/useAgentDeployFlow.test.ts @@ -0,0 +1,90 @@ +import { describe, expect, it } from 'vitest' +import type { InstallTokenInput, InstallTokenResult } from '../../types/nodes' +import { createAgentDeployFlow } from './useAgentDeployFlow' + +function deployOptions(): InstallTokenInput { + return { + mode: 'systemd', + arch: 'auto', + agentVersion: 'v2.3.1', + downloadSrc: 'github', + ttlSeconds: 900, + } +} + +function tokenResult(overrides: Partial = {}): InstallTokenResult { + return { + installToken: 'install-token', + expiresAt: '2099-01-01T00:00:00Z', + url: 'https://master.example.com/api/install/install-token', + fallbackUrl: 'https://master.example.com/install/install-token', + scriptBase64: 'IyEvYmluL3NoCg==', + composeUrl: '', + fallbackComposeUrl: '', + ...overrides, + } +} + +describe('createAgentDeployFlow', () => { + it('creates one node then issues one install token', async () => { + const calls: string[] = [] + const flow = createAgentDeployFlow({ + batchCreateNodes: async (names) => { + calls.push(`batch:${names.join(',')}`) + return [{ id: 7, name: names[0] }] + }, + createInstallToken: async (nodeId) => { + calls.push(`token:${nodeId}`) + return tokenResult() + }, + }) + + const result = await flow.submitNewNodes(['prod-a'], deployOptions()) + + expect(calls).toEqual(['batch:prod-a', 'token:7']) + expect(result.status).toBe('ready') + expect(result.rows).toHaveLength(1) + expect(result.rows[0]).toMatchObject({ + nodeId: 7, + nodeName: 'prod-a', + status: 'ready', + }) + expect(result.rows[0].command).toContain('/api/install/install-token') + expect(result.rows[0].embeddedCommand).toContain('IyEvYmluL3NoCg==') + }) + + it('returns partialFailed when one batch token request fails', async () => { + const flow = createAgentDeployFlow({ + batchCreateNodes: async (names) => names.map((name, index) => ({ id: index + 1, name })), + createInstallToken: async (nodeId) => { + if (nodeId === 2) { + throw new Error('token service unavailable') + } + return tokenResult({ installToken: `tok-${nodeId}`, url: `https://master.example.com/api/install/tok-${nodeId}` }) + }, + }) + + const result = await flow.submitNewNodes(['prod-a', 'prod-b', 'prod-c'], deployOptions()) + + expect(result.status).toBe('partialFailed') + expect(result.rows.map((row) => row.status)).toEqual(['ready', 'failed', 'ready']) + expect(result.rows[1]).toMatchObject({ + nodeId: 2, + nodeName: 'prod-b', + status: 'failed', + errorMessage: 'token service unavailable', + }) + }) + + it('rejects duplicate names before creating nodes', async () => { + const flow = createAgentDeployFlow({ + batchCreateNodes: async () => { + throw new Error('should not call batchCreateNodes') + }, + createInstallToken: async () => tokenResult(), + }) + + await expect(flow.submitNewNodes(['prod-a', ' prod-a '], deployOptions())) + .rejects.toThrow('批次内重复节点名') + }) +}) diff --git a/web/src/pages/nodes/useAgentDeployFlow.ts b/web/src/pages/nodes/useAgentDeployFlow.ts new file mode 100644 index 0000000..8c73c3c --- /dev/null +++ b/web/src/pages/nodes/useAgentDeployFlow.ts @@ -0,0 +1,146 @@ +import { useMemo } from 'react' +import type { BatchCreateResult, InstallTokenInput, InstallTokenResult } from '../../types/nodes' +import { batchCreateNodes, createInstallToken } from '../../services/nodes' +import { + buildAgentInstallCommand, + buildEmbeddedAgentInstallCommand, +} from './installCommands' + +export type DeployRowStatus = 'ready' | 'failed' +export type DeployResultStatus = 'ready' | 'partialFailed' + +export interface AgentDeployNode { + id: number + name: string +} + +export interface AgentDeployRow { + nodeId: number + nodeName: string + status: DeployRowStatus + command: string + expiresAt: string + installToken?: InstallTokenResult + embeddedCommand?: string + errorMessage?: string +} + +export interface AgentDeployResult { + status: DeployResultStatus + rows: AgentDeployRow[] +} + +interface AgentDeployFlowDeps { + batchCreateNodes: (names: string[]) => Promise + createInstallToken: (nodeId: number, input: InstallTokenInput) => Promise +} + +const TOKEN_CONCURRENCY = 4 + +export function createAgentDeployFlow(deps: AgentDeployFlowDeps) { + const issueTokenForNode = async (node: AgentDeployNode, input: InstallTokenInput): Promise => { + try { + const token = await deps.createInstallToken(node.id, input) + return readyRow(node, token) + } catch (error) { + return { + nodeId: node.id, + nodeName: node.name, + status: 'failed', + command: '', + expiresAt: '', + errorMessage: resolveErrorMessage(error), + } + } + } + + return { + async submitNewNodes(names: string[], input: InstallTokenInput): Promise { + const cleanedNames = normalizeNodeNames(names) + const nodes = await deps.batchCreateNodes(cleanedNames) + const rows = await mapWithConcurrency(nodes, TOKEN_CONCURRENCY, (node) => issueTokenForNode(node, input)) + return resultFromRows(rows) + }, + + async submitExistingNode(node: AgentDeployNode, input: InstallTokenInput): Promise { + const row = await issueTokenForNode(node, input) + return resultFromRows([row]) + }, + + async regenerateNode(node: AgentDeployNode, input: InstallTokenInput): Promise { + return issueTokenForNode(node, input) + }, + } +} + +export function useAgentDeployFlow() { + return useMemo(() => createAgentDeployFlow({ batchCreateNodes, createInstallToken }), []) +} + +function readyRow(node: AgentDeployNode, token: InstallTokenResult): AgentDeployRow { + return { + nodeId: node.id, + nodeName: node.name, + status: 'ready', + command: buildAgentInstallCommand(token.url, token.fallbackUrl), + expiresAt: token.expiresAt, + installToken: token, + embeddedCommand: token.scriptBase64 + ? buildEmbeddedAgentInstallCommand(token.scriptBase64) + : undefined, + } +} + +function resultFromRows(rows: AgentDeployRow[]): AgentDeployResult { + return { + status: rows.some((row) => row.status === 'failed') ? 'partialFailed' : 'ready', + rows, + } +} + +function normalizeNodeNames(names: string[]) { + const cleaned = names.map((name) => name.trim()).filter(Boolean) + if (cleaned.length === 0) { + throw new Error('请至少输入一个节点名称') + } + if (cleaned.length > 50) { + throw new Error('单次最多创建 50 个节点') + } + const seen = new Set() + for (const name of cleaned) { + if (seen.has(name)) { + throw new Error(`批次内重复节点名:${name}`) + } + seen.add(name) + } + return cleaned +} + +async function mapWithConcurrency( + items: T[], + concurrency: number, + mapper: (item: T, index: number) => Promise, +): Promise { + const results = new Array(items.length) + let nextIndex = 0 + const workerCount = Math.min(concurrency, items.length) + const workers = Array.from({ length: workerCount }, async () => { + for (;;) { + const index = nextIndex + nextIndex += 1 + if (index >= items.length) { + return + } + results[index] = await mapper(items[index], index) + } + }) + await Promise.all(workers) + return results +} + +function resolveErrorMessage(error: unknown) { + if (error instanceof Error && error.message) { + return error.message + } + return '生成安装命令失败' +} diff --git a/web/src/pages/nodes/wizard/Step3CommandPreview.tsx b/web/src/pages/nodes/wizard/Step3CommandPreview.tsx index c7776e6..9000060 100644 --- a/web/src/pages/nodes/wizard/Step3CommandPreview.tsx +++ b/web/src/pages/nodes/wizard/Step3CommandPreview.tsx @@ -3,7 +3,7 @@ import { Typography, Button, Space, Collapse, Spin, Message, Tag } from '@arco-d import { IconCopy, IconRefresh } from '@arco-design/web-react/icon' import { fetchScriptPreview } from '../../../services/nodes' import type { InstallTokenResult, InstallMode } from '../../../types/nodes' -import { buildAgentDownloadCommand, buildAgentInstallCommand } from '../installCommands' +import { buildAgentDownloadCommand, buildAgentInstallCommand, buildEmbeddedAgentInstallCommand } from '../installCommands' const { Text } = Typography @@ -30,8 +30,9 @@ export function Step3CommandPreview({ nodeId, nodeName, token, mode, previewPara }, [token.expiresAt]) const expired = remaining === 0 - const command = buildAgentInstallCommand(token.url, token.fallbackUrl, token.scriptBase64) - const fallbackCommand = buildAgentDownloadCommand(token.url, token.fallbackUrl, token.scriptBase64) + const command = buildAgentInstallCommand(token.url, token.fallbackUrl) + const fallbackCommand = buildAgentDownloadCommand(token.url, token.fallbackUrl) + const embeddedCommand = token.scriptBase64 ? buildEmbeddedAgentInstallCommand(token.scriptBase64) : null const dockerComposeCmd = mode === 'docker' && token.composeUrl ? `curl -fsSL ${token.composeUrl} -o docker-compose.yml && docker-compose up -d` : null @@ -107,8 +108,22 @@ export function Step3CommandPreview({ nodeId, nodeName, token, mode, previewPara )} + {embeddedCommand && ( +
+ + 代理异常时使用嵌入式备用命令: + + + {embeddedCommand} + +
+ +
+
+ )} + - 安装命令包含节点 token,请仅在目标机执行并妥善保存;公开安装链接会在 TTL 到期或首次消费后作废。 + 主安装命令包含公开 install token,会在 TTL 到期或首次消费后作废;嵌入式备用命令包含完整节点 token,不依赖公开链接消费状态,请仅在目标机执行并妥善保存。 {