Files
BackupX/server/internal/repository/agent_command_repository_test.go
2026-05-13 14:24:45 +08:00

262 lines
9.4 KiB
Go
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
package repository
import (
"context"
"testing"
"time"
"backupx/server/internal/model"
"github.com/glebarez/sqlite"
"gorm.io/gorm"
gormlogger "gorm.io/gorm/logger"
)
func newTestDB(t *testing.T) *gorm.DB {
t.Helper()
db, err := gorm.Open(sqlite.Open(":memory:"), &gorm.Config{Logger: gormlogger.Default.LogMode(gormlogger.Silent)})
if err != nil {
t.Fatalf("open: %v", err)
}
if err := db.AutoMigrate(&model.AgentCommand{}); err != nil {
t.Fatalf("migrate: %v", err)
}
return db
}
func TestAgentCommandRepository_ClaimPending(t *testing.T) {
db := newTestDB(t)
repo := NewAgentCommandRepository(db)
ctx := context.Background()
// 插入两条 pending 命令
c1 := &model.AgentCommand{NodeID: 5, Type: "run_task", Status: model.AgentCommandStatusPending, Payload: `{"taskId":1}`}
c2 := &model.AgentCommand{NodeID: 5, Type: "list_dir", Status: model.AgentCommandStatusPending, Payload: `{"path":"/"}`}
c3 := &model.AgentCommand{NodeID: 99, Type: "run_task", Status: model.AgentCommandStatusPending}
for _, c := range []*model.AgentCommand{c1, c2, c3} {
if err := repo.Create(ctx, c); err != nil {
t.Fatal(err)
}
}
// 第一次 Claim 应拿到 c1
claimed, err := repo.ClaimPending(ctx, 5)
if err != nil {
t.Fatalf("claim: %v", err)
}
if claimed == nil || claimed.ID != c1.ID || claimed.Status != model.AgentCommandStatusDispatched {
t.Fatalf("expected c1 dispatched: %+v", claimed)
}
// 第二次应拿到 c2
claimed2, err := repo.ClaimPending(ctx, 5)
if err != nil || claimed2 == nil || claimed2.ID != c2.ID {
t.Fatalf("expected c2: %+v %v", claimed2, err)
}
// 第三次无 pending返回 nil
claimed3, err := repo.ClaimPending(ctx, 5)
if err != nil || claimed3 != nil {
t.Fatalf("expected nil, got %+v", claimed3)
}
// 不同 node 的命令不应被抢到
other, err := repo.ClaimPending(ctx, 5)
if err != nil || other != nil {
t.Fatalf("expected nil: %+v", other)
}
}
func TestAgentCommandRepository_Update(t *testing.T) {
db := newTestDB(t)
repo := NewAgentCommandRepository(db)
ctx := context.Background()
cmd := &model.AgentCommand{NodeID: 1, Type: "run_task", Status: model.AgentCommandStatusPending}
_ = repo.Create(ctx, cmd)
cmd.Status = model.AgentCommandStatusSucceeded
cmd.Result = `{"ok":true}`
now := time.Now().UTC()
cmd.CompletedAt = &now
if err := repo.Update(ctx, cmd); err != nil {
t.Fatal(err)
}
got, err := repo.FindByID(ctx, cmd.ID)
if err != nil || got == nil {
t.Fatal(err)
}
if got.Status != model.AgentCommandStatusSucceeded || got.Result != `{"ok":true}` {
t.Errorf("mismatch: %+v", got)
}
}
func TestAgentCommandRepository_CompleteDispatchedOnlyUpdatesDispatchedCommand(t *testing.T) {
db := newTestDB(t)
repo := NewAgentCommandRepository(db)
ctx := context.Background()
dispatched := &model.AgentCommand{NodeID: 1, Type: "run_task", Status: model.AgentCommandStatusDispatched}
timeout := &model.AgentCommand{NodeID: 1, Type: "run_task", Status: model.AgentCommandStatusTimeout, ErrorMessage: "timeout"}
if err := repo.Create(ctx, dispatched); err != nil {
t.Fatalf("Create dispatched returned error: %v", err)
}
if err := repo.Create(ctx, timeout); err != nil {
t.Fatalf("Create timeout returned error: %v", err)
}
now := time.Now().UTC()
dispatched.Status = model.AgentCommandStatusSucceeded
dispatched.Result = `{"ok":true}`
dispatched.CompletedAt = &now
updated, err := repo.CompleteDispatched(ctx, dispatched)
if err != nil {
t.Fatalf("CompleteDispatched returned error: %v", err)
}
if !updated {
t.Fatal("expected dispatched command to be updated")
}
timeout.Status = model.AgentCommandStatusSucceeded
timeout.Result = `{"late":true}`
timeout.CompletedAt = &now
updated, err = repo.CompleteDispatched(ctx, timeout)
if err != nil {
t.Fatalf("CompleteDispatched terminal returned error: %v", err)
}
if updated {
t.Fatal("expected terminal command not to be updated")
}
gotTimeout, err := repo.FindByID(ctx, timeout.ID)
if err != nil {
t.Fatalf("FindByID timeout returned error: %v", err)
}
if gotTimeout.Status != model.AgentCommandStatusTimeout || gotTimeout.Result != "" {
t.Fatalf("expected timeout command unchanged, got %#v", gotTimeout)
}
}
func TestAgentCommandRepository_TimeoutActiveDoesNotOverwriteTerminalCommand(t *testing.T) {
db := newTestDB(t)
repo := NewAgentCommandRepository(db)
ctx := context.Background()
succeeded := &model.AgentCommand{NodeID: 1, Type: "run_task", Status: model.AgentCommandStatusSucceeded, Result: `{"ok":true}`}
if err := repo.Create(ctx, succeeded); err != nil {
t.Fatalf("Create succeeded returned error: %v", err)
}
now := time.Now().UTC()
succeeded.ErrorMessage = "timeout"
succeeded.CompletedAt = &now
updated, err := repo.TimeoutActive(ctx, succeeded)
if err != nil {
t.Fatalf("TimeoutActive returned error: %v", err)
}
if updated {
t.Fatal("expected terminal command not to be timed out")
}
got, err := repo.FindByID(ctx, succeeded.ID)
if err != nil {
t.Fatalf("FindByID returned error: %v", err)
}
if got.Status != model.AgentCommandStatusSucceeded || got.ErrorMessage != "" || got.Result != `{"ok":true}` {
t.Fatalf("expected succeeded command unchanged, got %#v", got)
}
}
func TestAgentCommandRepository_MarkStaleTimeout(t *testing.T) {
db := newTestDB(t)
repo := NewAgentCommandRepository(db)
ctx := context.Background()
old := time.Now().Add(-time.Hour)
recent := time.Now()
// 两条 dispatched一条旧、一条新
oldCmd := &model.AgentCommand{NodeID: 1, Type: "run_task", Status: model.AgentCommandStatusDispatched, DispatchedAt: &old}
newCmd := &model.AgentCommand{NodeID: 1, Type: "run_task", Status: model.AgentCommandStatusDispatched, DispatchedAt: &recent}
_ = repo.Create(ctx, oldCmd)
_ = repo.Create(ctx, newCmd)
n, err := repo.MarkStaleTimeout(ctx, time.Now().Add(-30*time.Minute))
if err != nil {
t.Fatal(err)
}
if n != 1 {
t.Errorf("expected 1 row, got %d", n)
}
oldGot, _ := repo.FindByID(ctx, oldCmd.ID)
newGot, _ := repo.FindByID(ctx, newCmd.ID)
if oldGot.Status != model.AgentCommandStatusTimeout {
t.Errorf("old should be timeout: %+v", oldGot)
}
if newGot.Status != model.AgentCommandStatusDispatched {
t.Errorf("new should stay dispatched: %+v", newGot)
}
}
func TestAgentCommandRepository_ListStaleActiveIncludesPendingAndDispatched(t *testing.T) {
db := newTestDB(t)
repo := NewAgentCommandRepository(db)
ctx := context.Background()
old := time.Now().Add(-time.Hour)
recent := time.Now()
oldPending := &model.AgentCommand{NodeID: 1, Type: "run_task", Status: model.AgentCommandStatusPending, CreatedAt: old}
oldDispatched := &model.AgentCommand{NodeID: 1, Type: "restore_record", Status: model.AgentCommandStatusDispatched, DispatchedAt: &old}
recentPending := &model.AgentCommand{NodeID: 1, Type: "run_task", Status: model.AgentCommandStatusPending, CreatedAt: recent}
succeeded := &model.AgentCommand{NodeID: 1, Type: "run_task", Status: model.AgentCommandStatusSucceeded, CreatedAt: old}
for _, cmd := range []*model.AgentCommand{oldPending, oldDispatched, recentPending, succeeded} {
if err := repo.Create(ctx, cmd); err != nil {
t.Fatalf("Create returned error: %v", err)
}
}
items, err := repo.ListStaleActive(ctx, time.Now().Add(-30*time.Minute))
if err != nil {
t.Fatalf("ListStaleActive returned error: %v", err)
}
if len(items) != 2 {
t.Fatalf("expected 2 stale active commands, got %#v", items)
}
if items[0].ID != oldPending.ID || items[1].ID != oldDispatched.ID {
t.Fatalf("unexpected stale active order/items: %#v", items)
}
}
func TestAgentCommandRepository_NodeQueueSummaries(t *testing.T) {
db := newTestDB(t)
repo := NewAgentCommandRepository(db)
ctx := context.Background()
old := time.Now().UTC().Add(-20 * time.Minute)
recent := time.Now().UTC().Add(-2 * time.Minute)
dispatchedAt := time.Now().UTC().Add(-5 * time.Minute)
completedAt := time.Now().UTC().Add(-1 * time.Minute)
commands := []*model.AgentCommand{
{NodeID: 1, Type: model.AgentCommandTypeRunTask, Status: model.AgentCommandStatusPending, CreatedAt: old},
{NodeID: 1, Type: model.AgentCommandTypeRestoreRecord, Status: model.AgentCommandStatusPending, CreatedAt: recent},
{NodeID: 1, Type: model.AgentCommandTypeRunTask, Status: model.AgentCommandStatusDispatched, DispatchedAt: &dispatchedAt},
{NodeID: 1, Type: model.AgentCommandTypeRunTask, Status: model.AgentCommandStatusFailed, ErrorMessage: "boom", CompletedAt: &completedAt},
{NodeID: 1, Type: model.AgentCommandTypeRunTask, Status: model.AgentCommandStatusTimeout, ErrorMessage: "late", CompletedAt: &recent},
{NodeID: 2, Type: model.AgentCommandTypeRunTask, Status: model.AgentCommandStatusPending, CreatedAt: old},
}
for _, cmd := range commands {
if err := repo.Create(ctx, cmd); err != nil {
t.Fatalf("Create returned error: %v", err)
}
}
summaries, err := repo.NodeQueueSummaries(ctx)
if err != nil {
t.Fatalf("NodeQueueSummaries returned error: %v", err)
}
nodeOne := summaries[1]
if nodeOne.Pending != 2 || nodeOne.Dispatched != 1 || nodeOne.Running != 1 || nodeOne.Depth != 3 {
t.Fatalf("unexpected node 1 summary: %#v", nodeOne)
}
if nodeOne.Timeouts != 1 || nodeOne.LastError != "boom" {
t.Fatalf("expected terminal timeout and latest error in summary, got %#v", nodeOne)
}
if nodeOne.OldestActiveAt == nil || !nodeOne.OldestActiveAt.Equal(old) {
t.Fatalf("expected oldest active at %s, got %#v", old, nodeOne.OldestActiveAt)
}
if nodeTwo := summaries[2]; nodeTwo.Pending != 1 || nodeTwo.Depth != 1 || nodeTwo.Timeouts != 0 || nodeTwo.LastError != "" {
t.Fatalf("unexpected node 2 summary: %#v", nodeTwo)
}
}