mirror of
https://github.com/Syngnat/GoNavi.git
synced 2026-05-11 17:19:45 +08:00
🐛 fix(redis): 修正 hash 详情读取依赖 HGETALL
- 为 hash 读取增加 HGETALL 权限受限时的 HSCAN 降级路径 - RedisGetValue 与 GetHash 统一复用 fallback 并保留长度元数据 - 补充普通用户权限受限与非权限错误回归测试 Fixes #380
This commit is contained in:
@@ -699,12 +699,12 @@ func (r *RedisClientImpl) GetValue(key string) (*RedisValue, error) {
|
||||
result.Length = int64(len(val))
|
||||
|
||||
case "hash":
|
||||
val, err := r.client.HGetAll(ctx, physicalKey).Result()
|
||||
val, length, err := r.readHashEntries(ctx, physicalKey)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
result.Value = val
|
||||
result.Length = int64(len(val))
|
||||
result.Length = length
|
||||
|
||||
case "list":
|
||||
length, err := r.client.LLen(ctx, physicalKey).Result()
|
||||
@@ -819,7 +819,74 @@ func (r *RedisClientImpl) GetHash(key string) (map[string]string, error) {
|
||||
}
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
|
||||
defer cancel()
|
||||
return r.client.HGetAll(ctx, r.toPhysicalKey(key)).Result()
|
||||
values, _, err := r.readHashEntries(ctx, r.toPhysicalKey(key))
|
||||
return values, err
|
||||
}
|
||||
|
||||
func (r *RedisClientImpl) readHashEntries(ctx context.Context, physicalKey string) (map[string]string, int64, error) {
|
||||
return readRedisHashEntriesWithFallback(
|
||||
func() (map[string]string, error) {
|
||||
return r.client.HGetAll(ctx, physicalKey).Result()
|
||||
},
|
||||
func() (int64, error) {
|
||||
return r.client.HLen(ctx, physicalKey).Result()
|
||||
},
|
||||
func(cursor uint64, count int64) ([]string, uint64, error) {
|
||||
return r.client.HScan(ctx, physicalKey, cursor, "*", count).Result()
|
||||
},
|
||||
)
|
||||
}
|
||||
|
||||
func readRedisHashEntriesWithFallback(
|
||||
readAll func() (map[string]string, error),
|
||||
readLength func() (int64, error),
|
||||
scan func(cursor uint64, count int64) ([]string, uint64, error),
|
||||
) (map[string]string, int64, error) {
|
||||
values, err := readAll()
|
||||
if err == nil {
|
||||
return values, int64(len(values)), nil
|
||||
}
|
||||
if !shouldFallbackRedisHashScan(err) {
|
||||
return nil, 0, err
|
||||
}
|
||||
|
||||
entries := make(map[string]string)
|
||||
var cursor uint64
|
||||
for round := 0; round < redisScanMaxRounds; round++ {
|
||||
pairs, nextCursor, scanErr := scan(cursor, redisScanMinStepCount)
|
||||
if scanErr != nil {
|
||||
return nil, 0, scanErr
|
||||
}
|
||||
if len(pairs)%2 != 0 {
|
||||
return nil, 0, fmt.Errorf("Redis HSCAN 返回结果格式异常")
|
||||
}
|
||||
for i := 0; i < len(pairs); i += 2 {
|
||||
entries[pairs[i]] = pairs[i+1]
|
||||
}
|
||||
cursor = nextCursor
|
||||
if cursor == 0 {
|
||||
length, lengthErr := readLength()
|
||||
if lengthErr == nil {
|
||||
return entries, length, nil
|
||||
}
|
||||
return entries, int64(len(entries)), nil
|
||||
}
|
||||
}
|
||||
|
||||
return nil, 0, fmt.Errorf("Redis HSCAN 超出安全轮次,无法完整读取 hash")
|
||||
}
|
||||
|
||||
func shouldFallbackRedisHashScan(err error) bool {
|
||||
if err == nil {
|
||||
return false
|
||||
}
|
||||
message := strings.ToLower(strings.TrimSpace(err.Error()))
|
||||
if !strings.Contains(message, "hgetall") {
|
||||
return false
|
||||
}
|
||||
return strings.Contains(message, "not support for normal user") ||
|
||||
strings.Contains(message, "noperm") ||
|
||||
strings.Contains(message, "permission")
|
||||
}
|
||||
|
||||
// SetHashField sets a field in a hash
|
||||
|
||||
@@ -119,3 +119,60 @@ func TestNormalizeRedisGetValueError(t *testing.T) {
|
||||
t.Fatal("expected nil for supported existing key")
|
||||
}
|
||||
}
|
||||
|
||||
func TestReadRedisHashEntriesWithFallbackUsesHScanWhenHGetAllForbidden(t *testing.T) {
|
||||
scanCalls := 0
|
||||
values, length, err := readRedisHashEntriesWithFallback(
|
||||
func() (map[string]string, error) {
|
||||
return nil, errors.New("ERR command 'HGETALL' not support for normal user")
|
||||
},
|
||||
func() (int64, error) {
|
||||
return 2, nil
|
||||
},
|
||||
func(cursor uint64, count int64) ([]string, uint64, error) {
|
||||
scanCalls++
|
||||
if cursor != 0 {
|
||||
t.Fatalf("expected first scan cursor to be 0, got %d", cursor)
|
||||
}
|
||||
if count <= 0 {
|
||||
t.Fatalf("expected positive scan count, got %d", count)
|
||||
}
|
||||
return []string{"field-a", "value-a", "field-b", "value-b"}, 0, nil
|
||||
},
|
||||
)
|
||||
if err != nil {
|
||||
t.Fatalf("readRedisHashEntriesWithFallback() unexpected error: %v", err)
|
||||
}
|
||||
if scanCalls != 1 {
|
||||
t.Fatalf("expected exactly one HSCAN fallback, got %d", scanCalls)
|
||||
}
|
||||
if length != 2 {
|
||||
t.Fatalf("expected hash length 2, got %d", length)
|
||||
}
|
||||
if got := values["field-a"]; got != "value-a" {
|
||||
t.Fatalf("expected field-a=value-a, got %q", got)
|
||||
}
|
||||
if got := values["field-b"]; got != "value-b" {
|
||||
t.Fatalf("expected field-b=value-b, got %q", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestReadRedisHashEntriesWithFallbackReturnsOriginalErrorForNonPermissionFailure(t *testing.T) {
|
||||
expectedErr := errors.New("ERR wrong type")
|
||||
_, _, err := readRedisHashEntriesWithFallback(
|
||||
func() (map[string]string, error) {
|
||||
return nil, expectedErr
|
||||
},
|
||||
func() (int64, error) {
|
||||
t.Fatal("expected HLEN not to run for non-permission failure")
|
||||
return 0, nil
|
||||
},
|
||||
func(cursor uint64, count int64) ([]string, uint64, error) {
|
||||
t.Fatal("expected HSCAN not to run for non-permission failure")
|
||||
return nil, 0, nil
|
||||
},
|
||||
)
|
||||
if !errors.Is(err, expectedErr) {
|
||||
t.Fatalf("expected original error %v, got %v", expectedErr, err)
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user