diff --git a/server/internal/repository/node_repository.go b/server/internal/repository/node_repository.go index b497b26..c98c481 100644 --- a/server/internal/repository/node_repository.go +++ b/server/internal/repository/node_repository.go @@ -49,7 +49,20 @@ func (r *GormNodeRepository) FindByID(ctx context.Context, id uint) (*model.Node func (r *GormNodeRepository) FindByToken(ctx context.Context, token string) (*model.Node, error) { var item model.Node - if err := r.db.WithContext(ctx).Where("token = ?", token).First(&item).Error; err != nil { + // 主 token 查询 + err := r.db.WithContext(ctx).Where("token = ?", token).First(&item).Error + if err == nil { + return &item, nil + } + if !errors.Is(err, gorm.ErrRecordNotFound) { + return nil, err + } + // 回退:prev_token 且未过期 + now := time.Now().UTC() + err = r.db.WithContext(ctx). + Where("prev_token = ? AND prev_token_expires IS NOT NULL AND prev_token_expires > ?", token, now). + First(&item).Error + if err != nil { if errors.Is(err, gorm.ErrRecordNotFound) { return nil, nil } diff --git a/server/internal/repository/node_repository_test.go b/server/internal/repository/node_repository_test.go new file mode 100644 index 0000000..9c7ed2a --- /dev/null +++ b/server/internal/repository/node_repository_test.go @@ -0,0 +1,76 @@ +package repository + +import ( + "context" + "path/filepath" + "testing" + "time" + + "backupx/server/internal/model" + "github.com/glebarez/sqlite" + "gorm.io/gorm" + gormlogger "gorm.io/gorm/logger" +) + +func openTestNodeDB(t *testing.T) *gorm.DB { + t.Helper() + path := filepath.Join(t.TempDir(), "nodes.db") + db, err := gorm.Open(sqlite.Open(path), &gorm.Config{Logger: gormlogger.Default.LogMode(gormlogger.Silent)}) + if err != nil { + t.Fatalf("open sqlite: %v", err) + } + if err := db.AutoMigrate(&model.Node{}); err != nil { + t.Fatalf("migrate: %v", err) + } + return db +} + +func TestFindByTokenFallsBackToPrevToken(t *testing.T) { + db := openTestNodeDB(t) + repo := NewNodeRepository(db) + ctx := context.Background() + + future := time.Now().UTC().Add(24 * time.Hour) + node := &model.Node{ + Name: "test", Token: "new-token", + PrevToken: "old-token", PrevTokenExpires: &future, + } + if err := repo.Create(ctx, node); err != nil { + t.Fatalf("create: %v", err) + } + + // 新 token 能查到 + got, err := repo.FindByToken(ctx, "new-token") + if err != nil || got == nil || got.ID != node.ID { + t.Fatalf("new token lookup failed: err=%v got=%v", err, got) + } + + // 旧 token 也能查到(未过期) + got, err = repo.FindByToken(ctx, "old-token") + if err != nil || got == nil || got.ID != node.ID { + t.Fatalf("prev_token lookup failed: err=%v got=%v", err, got) + } +} + +func TestFindByTokenRejectsExpiredPrevToken(t *testing.T) { + db := openTestNodeDB(t) + repo := NewNodeRepository(db) + ctx := context.Background() + + past := time.Now().UTC().Add(-1 * time.Hour) + node := &model.Node{ + Name: "test", Token: "new-token", + PrevToken: "stale", PrevTokenExpires: &past, + } + if err := repo.Create(ctx, node); err != nil { + t.Fatalf("create: %v", err) + } + + got, err := repo.FindByToken(ctx, "stale") + if err != nil { + t.Fatalf("err=%v", err) + } + if got != nil { + t.Fatalf("expected stale prev_token rejected, got %v", got) + } +}