From e56a72eb9fdc87f4b82f8adc5bc0c2df563a58d6 Mon Sep 17 00:00:00 2001 From: Syngnat Date: Fri, 17 Apr 2026 18:07:50 +0800 Subject: [PATCH] =?UTF-8?q?=F0=9F=90=9B=20fix(redis):=20=E4=BF=AE=E6=AD=A3?= =?UTF-8?q?=20hash=20=E8=AF=A6=E6=83=85=E8=AF=BB=E5=8F=96=E4=BE=9D?= =?UTF-8?q?=E8=B5=96=20HGETALL?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - 为 hash 读取增加 HGETALL 权限受限时的 HSCAN 降级路径 - RedisGetValue 与 GetHash 统一复用 fallback 并保留长度元数据 - 补充普通用户权限受限与非权限错误回归测试 Fixes #380 --- internal/redis/redis_impl.go | 73 +++++++++++++++++++++++++++++-- internal/redis/redis_impl_test.go | 57 ++++++++++++++++++++++++ 2 files changed, 127 insertions(+), 3 deletions(-) diff --git a/internal/redis/redis_impl.go b/internal/redis/redis_impl.go index 8ea032c..2313e8b 100644 --- a/internal/redis/redis_impl.go +++ b/internal/redis/redis_impl.go @@ -699,12 +699,12 @@ func (r *RedisClientImpl) GetValue(key string) (*RedisValue, error) { result.Length = int64(len(val)) case "hash": - val, err := r.client.HGetAll(ctx, physicalKey).Result() + val, length, err := r.readHashEntries(ctx, physicalKey) if err != nil { return nil, err } result.Value = val - result.Length = int64(len(val)) + result.Length = length case "list": length, err := r.client.LLen(ctx, physicalKey).Result() @@ -819,7 +819,74 @@ func (r *RedisClientImpl) GetHash(key string) (map[string]string, error) { } ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second) defer cancel() - return r.client.HGetAll(ctx, r.toPhysicalKey(key)).Result() + values, _, err := r.readHashEntries(ctx, r.toPhysicalKey(key)) + return values, err +} + +func (r *RedisClientImpl) readHashEntries(ctx context.Context, physicalKey string) (map[string]string, int64, error) { + return readRedisHashEntriesWithFallback( + func() (map[string]string, error) { + return r.client.HGetAll(ctx, physicalKey).Result() + }, + func() (int64, error) { + return r.client.HLen(ctx, physicalKey).Result() + }, + func(cursor uint64, count int64) ([]string, uint64, error) { + return r.client.HScan(ctx, physicalKey, cursor, "*", count).Result() + }, + ) +} + +func readRedisHashEntriesWithFallback( + readAll func() (map[string]string, error), + readLength func() (int64, error), + scan func(cursor uint64, count int64) ([]string, uint64, error), +) (map[string]string, int64, error) { + values, err := readAll() + if err == nil { + return values, int64(len(values)), nil + } + if !shouldFallbackRedisHashScan(err) { + return nil, 0, err + } + + entries := make(map[string]string) + var cursor uint64 + for round := 0; round < redisScanMaxRounds; round++ { + pairs, nextCursor, scanErr := scan(cursor, redisScanMinStepCount) + if scanErr != nil { + return nil, 0, scanErr + } + if len(pairs)%2 != 0 { + return nil, 0, fmt.Errorf("Redis HSCAN 返回结果格式异常") + } + for i := 0; i < len(pairs); i += 2 { + entries[pairs[i]] = pairs[i+1] + } + cursor = nextCursor + if cursor == 0 { + length, lengthErr := readLength() + if lengthErr == nil { + return entries, length, nil + } + return entries, int64(len(entries)), nil + } + } + + return nil, 0, fmt.Errorf("Redis HSCAN 超出安全轮次,无法完整读取 hash") +} + +func shouldFallbackRedisHashScan(err error) bool { + if err == nil { + return false + } + message := strings.ToLower(strings.TrimSpace(err.Error())) + if !strings.Contains(message, "hgetall") { + return false + } + return strings.Contains(message, "not support for normal user") || + strings.Contains(message, "noperm") || + strings.Contains(message, "permission") } // SetHashField sets a field in a hash diff --git a/internal/redis/redis_impl_test.go b/internal/redis/redis_impl_test.go index dafb991..914e088 100644 --- a/internal/redis/redis_impl_test.go +++ b/internal/redis/redis_impl_test.go @@ -119,3 +119,60 @@ func TestNormalizeRedisGetValueError(t *testing.T) { t.Fatal("expected nil for supported existing key") } } + +func TestReadRedisHashEntriesWithFallbackUsesHScanWhenHGetAllForbidden(t *testing.T) { + scanCalls := 0 + values, length, err := readRedisHashEntriesWithFallback( + func() (map[string]string, error) { + return nil, errors.New("ERR command 'HGETALL' not support for normal user") + }, + func() (int64, error) { + return 2, nil + }, + func(cursor uint64, count int64) ([]string, uint64, error) { + scanCalls++ + if cursor != 0 { + t.Fatalf("expected first scan cursor to be 0, got %d", cursor) + } + if count <= 0 { + t.Fatalf("expected positive scan count, got %d", count) + } + return []string{"field-a", "value-a", "field-b", "value-b"}, 0, nil + }, + ) + if err != nil { + t.Fatalf("readRedisHashEntriesWithFallback() unexpected error: %v", err) + } + if scanCalls != 1 { + t.Fatalf("expected exactly one HSCAN fallback, got %d", scanCalls) + } + if length != 2 { + t.Fatalf("expected hash length 2, got %d", length) + } + if got := values["field-a"]; got != "value-a" { + t.Fatalf("expected field-a=value-a, got %q", got) + } + if got := values["field-b"]; got != "value-b" { + t.Fatalf("expected field-b=value-b, got %q", got) + } +} + +func TestReadRedisHashEntriesWithFallbackReturnsOriginalErrorForNonPermissionFailure(t *testing.T) { + expectedErr := errors.New("ERR wrong type") + _, _, err := readRedisHashEntriesWithFallback( + func() (map[string]string, error) { + return nil, expectedErr + }, + func() (int64, error) { + t.Fatal("expected HLEN not to run for non-permission failure") + return 0, nil + }, + func(cursor uint64, count int64) ([]string, uint64, error) { + t.Fatal("expected HSCAN not to run for non-permission failure") + return nil, 0, nil + }, + ) + if !errors.Is(err, expectedErr) { + t.Fatalf("expected original error %v, got %v", expectedErr, err) + } +}