From e3515b9eb26bf1583a70150e5232e3266a37c0f7 Mon Sep 17 00:00:00 2001 From: Syngnat Date: Mon, 18 May 2026 10:28:18 +0800 Subject: [PATCH] =?UTF-8?q?=F0=9F=90=9B=20fix(windows):=20=E4=BF=AE?= =?UTF-8?q?=E5=A4=8D=E9=97=AA=E9=80=80=E4=B8=8E=E9=A9=B1=E5=8A=A8=E4=BB=A3?= =?UTF-8?q?=E7=90=86=E5=AE=89=E8=A3=85=E5=A4=B1=E8=B4=A5?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - 修复 WebView2 zoom factor 跨线程调用风险,切回窗口线程执行并增加 recover 与超时保护 - 完善 Redis 命令结果 JSON-safe 兜底,避免复杂返回值格式化触发程序崩溃 - 调整 Windows driver-agent 校验逻辑,仅读取 PE Machine 字段判断架构兼容性 - 避免 COFF string table EOF 被误判为无效 Windows 可执行文件,修复驱动在线安装和本地导入失败 - 补充窗口缩放、Redis 返回值和驱动代理 PE 校验回归测试 --- internal/app/window_zoom_windows.go | 120 +++++++++++++++--- internal/app/window_zoom_windows_test.go | 33 ++++- internal/db/driver_agent_binary_check.go | 87 ++++++++++++- internal/db/driver_agent_binary_check_test.go | 90 +++++++++++++ internal/redis/redis_impl.go | 65 +++++++++- internal/redis/redis_impl_test.go | 63 +++++++++ 6 files changed, 425 insertions(+), 33 deletions(-) create mode 100644 internal/db/driver_agent_binary_check_test.go diff --git a/internal/app/window_zoom_windows.go b/internal/app/window_zoom_windows.go index 0a3b88c..21973ef 100644 --- a/internal/app/window_zoom_windows.go +++ b/internal/app/window_zoom_windows.go @@ -6,9 +6,12 @@ import ( "context" "fmt" "reflect" + "time" "unsafe" ) +const resetWebViewZoomInvokeTimeout = 2 * time.Second + // resetWebViewZoomFactor 通过 WebView2 ICoreWebView2Controller::put_ZoomFactor 把 WebView2 // 内部 zoom factor 重置为 1.0。这是 Windows 任务栏恢复后字体度量异常变大的根因解: // 字体度量缓存在 WebView2 D2D/DirectWrite 层,Chromium layout invalidation(CSS zoom hack) @@ -17,9 +20,10 @@ import ( // 实现路径: // 1. Wails 在 ctx 里以 key "frontend" 注入了 *desktop/windows.Frontend // 2. Frontend.chromium 是 unexported 字段 *edge.Chromium -// 3. Chromium.PutZoomFactor(float64) 是 exported 方法(封装了 controller.put_ZoomFactor) +// 3. Frontend.mainWindow 是 unexported 字段 *windows.Window,可用 Invoke 切回窗口线程 +// 4. Chromium.PutZoomFactor(float64) 是 exported 方法(封装了 controller.put_ZoomFactor) // -// 用反射 + unsafe.Pointer 解锁 unexported 字段后 MethodByName("PutZoomFactor").Call。 +// 用反射 + unsafe.Pointer 解锁 unexported 字段后,通过 mainWindow.Invoke 调 PutZoomFactor。 // 不需要 import wails 内部包,也不需要 fork wails。 // // 失败时返回错误(不 panic),让调用方决定是否回退到 toggle 路径。 @@ -35,38 +39,112 @@ func resetWebViewZoomFactor(ctx context.Context, factor float64) (err error) { if ctx == nil { return fmt.Errorf("ctx is nil") } + frontendValue, err := resolveWailsFrontendValue(ctx) + if err != nil { + return err + } + chromiumValue, err := accessibleWailsFrontendField(frontendValue, "chromium") + if err != nil { + return err + } + mainWindowValue, err := accessibleWailsFrontendField(frontendValue, "mainWindow") + if err != nil { + return err + } + + putZoomFactor := chromiumValue.MethodByName("PutZoomFactor") + if !putZoomFactor.IsValid() { + return fmt.Errorf("PutZoomFactor method not found on chromium (go-webview2 version may have changed)") + } + if putZoomFactor.Type().NumIn() != 1 || putZoomFactor.Type().In(0).Kind() != reflect.Float64 || putZoomFactor.Type().NumOut() != 0 { + return fmt.Errorf("PutZoomFactor signature changed: expected func(float64), got %v", putZoomFactor.Type()) + } + + invoke := mainWindowValue.MethodByName("Invoke") + if !invoke.IsValid() { + return fmt.Errorf("mainWindow.Invoke method not found (wails version may have changed)") + } + if invoke.Type().NumIn() != 1 || invoke.Type().In(0).Kind() != reflect.Func || invoke.Type().In(0).NumIn() != 0 || invoke.Type().In(0).NumOut() != 0 || invoke.Type().NumOut() != 0 { + return fmt.Errorf("mainWindow.Invoke signature changed: expected func(func()), got %v", invoke.Type()) + } + + done := make(chan error, 1) + if err := safeCallInvoke(invoke, func() { + done <- safeCallPutZoomFactor(putZoomFactor, factor) + }); err != nil { + return err + } + + select { + case err := <-done: + return err + case <-time.After(resetWebViewZoomInvokeTimeout): + return fmt.Errorf("timed out waiting for mainWindow.Invoke to reset WebView2 zoom factor") + } +} + +func resolveWailsFrontendValue(ctx context.Context) (reflect.Value, error) { frontendIface := ctx.Value("frontend") if frontendIface == nil { - return fmt.Errorf("wails frontend not found in ctx (key=\"frontend\")") + return reflect.Value{}, fmt.Errorf("wails frontend not found in ctx (key=\"frontend\")") } frontendValue := reflect.ValueOf(frontendIface) if frontendValue.Kind() == reflect.Ptr { + if frontendValue.IsNil() { + return reflect.Value{}, fmt.Errorf("wails frontend is nil") + } frontendValue = frontendValue.Elem() } if !frontendValue.IsValid() || frontendValue.Kind() != reflect.Struct { - return fmt.Errorf("wails frontend has unexpected kind %v", frontendValue.Kind()) + return reflect.Value{}, fmt.Errorf("wails frontend has unexpected kind %v", frontendValue.Kind()) + } + if !frontendValue.CanAddr() { + return reflect.Value{}, fmt.Errorf("wails frontend is not addressable") + } + return frontendValue, nil +} + +func accessibleWailsFrontendField(frontendValue reflect.Value, fieldName string) (reflect.Value, error) { + field := frontendValue.FieldByName(fieldName) + if !field.IsValid() { + return reflect.Value{}, fmt.Errorf("wails Frontend.%s field not found (wails version may have changed)", fieldName) + } + if !field.CanAddr() { + return reflect.Value{}, fmt.Errorf("wails Frontend.%s field is not addressable", fieldName) + } + if isNilReflectValue(field) { + return reflect.Value{}, fmt.Errorf("wails Frontend.%s is nil (WebView2 not yet initialised)", fieldName) } - chromiumField := frontendValue.FieldByName("chromium") - if !chromiumField.IsValid() { - return fmt.Errorf("wails Frontend.chromium field not found (wails version may have changed)") - } - if chromiumField.IsNil() { - return fmt.Errorf("wails Frontend.chromium is nil (WebView2 not yet initialised)") - } + return reflect.NewAt(field.Type(), unsafe.Pointer(field.UnsafeAddr())).Elem(), nil +} - // 用 NewAt + unsafe.Pointer 解锁 unexported 字段访问限制 - accessible := reflect.NewAt(chromiumField.Type(), unsafe.Pointer(chromiumField.UnsafeAddr())).Elem() - method := accessible.MethodByName("PutZoomFactor") - if !method.IsValid() { - return fmt.Errorf("PutZoomFactor method not found on chromium (go-webview2 version may have changed)") - } - if method.Type().NumIn() != 1 || method.Type().In(0).Kind() != reflect.Float64 { - return fmt.Errorf("PutZoomFactor signature changed: expected func(float64), got %v", method.Type()) +func isNilReflectValue(value reflect.Value) bool { + switch value.Kind() { + case reflect.Chan, reflect.Func, reflect.Interface, reflect.Map, reflect.Ptr, reflect.Slice: + return value.IsNil() + default: + return false } +} - // PutZoomFactor 内部已经 swallow error 并通过 errorCallback 报告——这里不会 panic - method.Call([]reflect.Value{reflect.ValueOf(factor)}) +func safeCallInvoke(invoke reflect.Value, fn func()) (err error) { + defer func() { + if value := recover(); value != nil { + err = fmt.Errorf("mainWindow.Invoke panicked while resetting WebView2 zoom factor: %v", value) + } + }() + invoke.Call([]reflect.Value{reflect.ValueOf(fn)}) + return nil +} + +func safeCallPutZoomFactor(putZoomFactor reflect.Value, factor float64) (err error) { + defer func() { + if value := recover(); value != nil { + err = fmt.Errorf("PutZoomFactor panicked while resetting WebView2 zoom factor: %v", value) + } + }() + putZoomFactor.Call([]reflect.Value{reflect.ValueOf(factor)}) return nil } diff --git a/internal/app/window_zoom_windows_test.go b/internal/app/window_zoom_windows_test.go index 2cfb928..83cb4b3 100644 --- a/internal/app/window_zoom_windows_test.go +++ b/internal/app/window_zoom_windows_test.go @@ -21,11 +21,21 @@ func (f *fakeChromium) PutZoomFactor(factor float64) { f.last.Store(factor) } +type fakeWindow struct { + invoked atomic.Int32 +} + +func (f *fakeWindow) Invoke(fn func()) { + f.invoked.Add(1) + fn() +} + // fakeFrontend 模仿 wails 的 internal/frontend/desktop/windows.Frontend: -// unexported 字段 chromium 是 *fakeChromium 类型(exported method PutZoomFactor)。 +// unexported 字段 chromium/mainWindow 分别模仿 *edge.Chromium 和 *windows.Window。 // 反射代码不依赖具体类型名,只检查 method signature。 type fakeFrontend struct { - chromium *fakeChromium + chromium *fakeChromium + mainWindow *fakeWindow } type panicChromium struct{} @@ -48,11 +58,15 @@ func stringContextKey(key string) any { func TestResetWebViewZoomFactorCallsPutZoomFactor(t *testing.T) { chromium := &fakeChromium{} - ctx := context.WithValue(context.Background(), stringContextKey("frontend"), &fakeFrontend{chromium: chromium}) + window := &fakeWindow{} + ctx := context.WithValue(context.Background(), stringContextKey("frontend"), &fakeFrontend{chromium: chromium, mainWindow: window}) if err := resetWebViewZoomFactor(ctx, 1.0); err != nil { t.Fatalf("expected reset to succeed against fake frontend, got %v", err) } + if got := window.invoked.Load(); got != 1 { + t.Fatalf("expected reset to run through mainWindow.Invoke exactly once, got %d", got) + } if got := chromium.called.Load(); got != 1 { t.Fatalf("expected PutZoomFactor called exactly once, got %d", got) } @@ -76,7 +90,7 @@ func TestResetWebViewZoomFactorErrorsWhenChromiumFieldMissing(t *testing.T) { } func TestResetWebViewZoomFactorErrorsWhenChromiumNil(t *testing.T) { - ctx := context.WithValue(context.Background(), stringContextKey("frontend"), &fakeFrontend{chromium: nil}) + ctx := context.WithValue(context.Background(), stringContextKey("frontend"), &fakeFrontend{chromium: nil, mainWindow: &fakeWindow{}}) err := resetWebViewZoomFactor(ctx, 1.0) if err == nil { t.Fatal("expected error when chromium is nil, got nil") @@ -86,6 +100,17 @@ func TestResetWebViewZoomFactorErrorsWhenChromiumNil(t *testing.T) { } } +func TestResetWebViewZoomFactorErrorsWhenMainWindowNil(t *testing.T) { + ctx := context.WithValue(context.Background(), stringContextKey("frontend"), &fakeFrontend{chromium: &fakeChromium{}, mainWindow: nil}) + err := resetWebViewZoomFactor(ctx, 1.0) + if err == nil { + t.Fatal("expected error when mainWindow is nil, got nil") + } + if !strings.Contains(err.Error(), "mainWindow") { + t.Fatalf("expected error to mention mainWindow, got %v", err) + } +} + func TestResetWebViewZoomFactorErrorsWhenFrontendMissing(t *testing.T) { err := resetWebViewZoomFactor(context.Background(), 1.0) if err == nil { diff --git a/internal/db/driver_agent_binary_check.go b/internal/db/driver_agent_binary_check.go index 762c720..2184bb7 100644 --- a/internal/db/driver_agent_binary_check.go +++ b/internal/db/driver_agent_binary_check.go @@ -1,8 +1,10 @@ package db import ( - "debug/pe" + "encoding/binary" "fmt" + "io" + "os" "runtime" "strings" ) @@ -11,6 +13,11 @@ const ( peMachineI386 uint16 = 0x014c peMachineAmd64 uint16 = 0x8664 peMachineArm64 uint16 = 0xaa64 + + peDOSHeaderMinSize = 0x40 + peHeaderOffsetAddr = 0x3c + peSignatureSize = 4 + peCOFFHeaderSize = 20 ) func windowsMachineLabel(machine uint16) string { @@ -40,23 +47,89 @@ func expectedWindowsMachineForGoArch(goarch string) (uint16, string, bool) { } func validateWindowsExecutableMachine(pathText string) error { - file, err := pe.Open(pathText) + return validateWindowsExecutableMachineForArch(pathText, runtime.GOARCH) +} + +func validateWindowsExecutableMachineForArch(pathText string, goarch string) error { + machine, err := readWindowsExecutableMachine(pathText) if err != nil { return fmt.Errorf("无法识别为有效的 Windows 可执行文件:%w", err) } - defer file.Close() - expectedMachine, expectedLabel, ok := expectedWindowsMachineForGoArch(runtime.GOARCH) + expectedMachine, expectedLabel, ok := expectedWindowsMachineForGoArch(goarch) if !ok { return nil } - actualMachine := file.FileHeader.Machine - if actualMachine != expectedMachine { - return fmt.Errorf("可执行文件架构不兼容(文件=%s,当前进程=%s)", windowsMachineLabel(actualMachine), expectedLabel) + if machine != expectedMachine { + return fmt.Errorf("可执行文件架构不兼容(文件=%s,当前进程=%s)", windowsMachineLabel(machine), expectedLabel) } return nil } +func readWindowsExecutableMachine(pathText string) (uint16, error) { + file, err := os.Open(pathText) + if err != nil { + return 0, err + } + defer file.Close() + + info, statErr := file.Stat() + if statErr != nil { + return 0, statErr + } + if info.IsDir() { + return 0, fmt.Errorf("路径是目录") + } + if info.Size() < peDOSHeaderMinSize { + return 0, fmt.Errorf("文件头不完整") + } + + var dosMagic [2]byte + if err := readWindowsPEBytes(file, 0, dosMagic[:]); err != nil { + return 0, fmt.Errorf("读取 DOS 头失败:%w", err) + } + if dosMagic[0] != 'M' || dosMagic[1] != 'Z' { + return 0, fmt.Errorf("缺少 MZ 头") + } + + var offsetBytes [4]byte + if err := readWindowsPEBytes(file, peHeaderOffsetAddr, offsetBytes[:]); err != nil { + return 0, fmt.Errorf("读取 PE 头偏移失败:%w", err) + } + peOffset := int64(binary.LittleEndian.Uint32(offsetBytes[:])) + if peOffset < peDOSHeaderMinSize { + return 0, fmt.Errorf("PE 头偏移异常") + } + if peOffset+peSignatureSize+peCOFFHeaderSize > info.Size() { + return 0, fmt.Errorf("PE 头不完整") + } + + var signature [4]byte + if err := readWindowsPEBytes(file, peOffset, signature[:]); err != nil { + return 0, fmt.Errorf("读取 PE 签名失败:%w", err) + } + if signature[0] != 'P' || signature[1] != 'E' || signature[2] != 0 || signature[3] != 0 { + return 0, fmt.Errorf("缺少 PE 签名") + } + + var machineBytes [2]byte + if err := readWindowsPEBytes(file, peOffset+peSignatureSize, machineBytes[:]); err != nil { + return 0, fmt.Errorf("读取 PE 架构失败:%w", err) + } + return binary.LittleEndian.Uint16(machineBytes[:]), nil +} + +func readWindowsPEBytes(reader io.ReaderAt, offset int64, target []byte) error { + if len(target) == 0 { + return nil + } + _, err := reader.ReadAt(target, offset) + if err == io.EOF { + return io.ErrUnexpectedEOF + } + return err +} + // ValidateOptionalDriverAgentExecutable 校验可选驱动代理二进制是否可在当前进程中执行。 // 当前主要用于 Windows 下的 PE 架构兼容性校验,避免升级后复用到错误架构的旧代理。 func ValidateOptionalDriverAgentExecutable(driverType string, executablePath string) error { diff --git a/internal/db/driver_agent_binary_check_test.go b/internal/db/driver_agent_binary_check_test.go new file mode 100644 index 0000000..90c8b4e --- /dev/null +++ b/internal/db/driver_agent_binary_check_test.go @@ -0,0 +1,90 @@ +package db + +import ( + "debug/pe" + "encoding/binary" + "os" + "path/filepath" + "strings" + "testing" +) + +func TestValidateWindowsExecutableMachineIgnoresCOFFStringTableEOF(t *testing.T) { + path := filepath.Join(t.TempDir(), "oceanbase-driver-agent-windows-amd64.exe") + writeMinimalWindowsPEWithBrokenStringTable(t, path, peMachineAmd64) + + if file, err := pe.Open(path); err == nil { + _ = file.Close() + t.Fatal("fixture should reproduce debug/pe string table failure") + } else if !strings.Contains(err.Error(), "string table") { + t.Fatalf("fixture should fail in debug/pe string table parsing, got %v", err) + } + + if err := validateWindowsExecutableMachineForArch(path, "amd64"); err != nil { + t.Fatalf("valid machine header should pass without reading optional string table: %v", err) + } +} + +func TestValidateWindowsExecutableMachineRejectsMachineMismatch(t *testing.T) { + path := filepath.Join(t.TempDir(), "sqlserver-driver-agent-windows-arm64.exe") + writeMinimalWindowsPE(t, path, peMachineArm64) + + err := validateWindowsExecutableMachineForArch(path, "amd64") + if err == nil { + t.Fatal("expected machine mismatch to be rejected") + } + if !strings.Contains(err.Error(), "windows-arm64") || !strings.Contains(err.Error(), "windows-amd64") { + t.Fatalf("expected architecture labels in error, got %v", err) + } +} + +func TestValidateWindowsExecutableMachineRejectsNonPEFile(t *testing.T) { + path := filepath.Join(t.TempDir(), "oceanbase-driver-agent-windows-amd64.exe") + if err := os.WriteFile(path, []byte("not a windows executable"), 0o644); err != nil { + t.Fatalf("write fixture failed: %v", err) + } + + err := validateWindowsExecutableMachineForArch(path, "amd64") + if err == nil { + t.Fatal("expected non-PE file to be rejected") + } + if !strings.Contains(err.Error(), "无法识别为有效的 Windows 可执行文件") { + t.Fatalf("expected executable validation error, got %v", err) + } +} + +func writeMinimalWindowsPE(t *testing.T, path string, machine uint16) { + t.Helper() + + const peOffset = 0x80 + content := make([]byte, peOffset+4+20) + content[0] = 'M' + content[1] = 'Z' + binary.LittleEndian.PutUint32(content[peHeaderOffsetAddr:], peOffset) + copy(content[peOffset:], []byte{'P', 'E', 0, 0}) + binary.LittleEndian.PutUint16(content[peOffset+4:], machine) + + if err := os.WriteFile(path, content, 0o644); err != nil { + t.Fatalf("write PE fixture failed: %v", err) + } +} + +func writeMinimalWindowsPEWithBrokenStringTable(t *testing.T, path string, machine uint16) { + t.Helper() + + const peOffset = 0x80 + const symbolTableOffset = peOffset + 4 + 20 + content := make([]byte, symbolTableOffset+18) + content[0] = 'M' + content[1] = 'Z' + binary.LittleEndian.PutUint32(content[peHeaderOffsetAddr:], peOffset) + copy(content[peOffset:], []byte{'P', 'E', 0, 0}) + coffHeader := content[peOffset+4:] + binary.LittleEndian.PutUint16(coffHeader[0:], machine) + binary.LittleEndian.PutUint32(coffHeader[8:], symbolTableOffset) + binary.LittleEndian.PutUint32(coffHeader[12:], 1) + + if err := os.WriteFile(path, content, 0o644); err != nil { + t.Fatalf("write PE fixture failed: %v", err) + } +} diff --git a/internal/redis/redis_impl.go b/internal/redis/redis_impl.go index e45cac9..d48e6c4 100644 --- a/internal/redis/redis_impl.go +++ b/internal/redis/redis_impl.go @@ -5,8 +5,11 @@ import ( "crypto/tls" "errors" "fmt" + "math" + "math/big" "net" "net/url" + "reflect" "strconv" "strings" "sync" @@ -1310,6 +1313,9 @@ func (r *RedisClientImpl) ExecuteCommand(args []string) (interface{}, error) { // 如果让原值穿透到 Wails RPC,json.Marshal 会失败,Wails runtime 在 Windows 上会直接 panic // 让进程退出——用户感知为 GoNavi 闪退(issue: HGETALL 闪退)。 // 平展成 [k1, v1, k2, v2, ...] 交错形式与 RESP2 array 输出一致,前端按 array 渲染。 +// +// 这里同时把 RESP3 的 NaN/Inf 浮点、大整数、error 以及其他 map/slice 形态统一收敛为 +// JSON-safe 结构,避免 Redis 命令面板再把不可序列化的值透传给 Wails。 func formatCommandResult(result interface{}) interface{} { switch v := result.(type) { case []interface{}: @@ -1325,10 +1331,67 @@ func formatCommandResult(result interface{}) interface{} { flattened = append(flattened, formatCommandResult(val)) } return flattened + case map[string]interface{}: + formatted := make(map[string]interface{}, len(v)) + for key, val := range v { + formatted[key] = formatCommandResult(val) + } + return formatted case []byte: return string(v) - default: + case error: + return v.Error() + case *big.Int: + if v == nil { + return nil + } + return v.String() + case float64: + if math.IsNaN(v) || math.IsInf(v, 0) { + return fmt.Sprint(v) + } return v + case float32: + f := float64(v) + if math.IsNaN(f) || math.IsInf(f, 0) { + return fmt.Sprint(v) + } + return v + default: + return formatCommandResultByReflection(v) + } +} + +func formatCommandResultByReflection(result interface{}) interface{} { + value := reflect.ValueOf(result) + if !value.IsValid() { + return nil + } + switch value.Kind() { + case reflect.Map: + if value.Type().Key().Kind() == reflect.String { + formatted := make(map[string]interface{}, value.Len()) + iter := value.MapRange() + for iter.Next() { + formatted[iter.Key().String()] = formatCommandResult(iter.Value().Interface()) + } + return formatted + } + flattened := make([]interface{}, 0, value.Len()*2) + iter := value.MapRange() + for iter.Next() { + flattened = append(flattened, formatCommandResult(iter.Key().Interface())) + flattened = append(flattened, formatCommandResult(iter.Value().Interface())) + } + return flattened + case reflect.Slice, reflect.Array: + formatted := make([]interface{}, value.Len()) + for i := 0; i < value.Len(); i++ { + formatted[i] = formatCommandResult(value.Index(i).Interface()) + } + return formatted + default: + return result } } diff --git a/internal/redis/redis_impl_test.go b/internal/redis/redis_impl_test.go index 3b48889..2821be3 100644 --- a/internal/redis/redis_impl_test.go +++ b/internal/redis/redis_impl_test.go @@ -3,6 +3,8 @@ package redis import ( "encoding/json" "errors" + "math" + "math/big" "sort" "testing" ) @@ -83,6 +85,67 @@ func TestFormatCommandResultPreservesScalarAndByteSlice(t *testing.T) { } } +func TestFormatCommandResultRecursivelyFormatsStringKeyMapValues(t *testing.T) { + input := map[string]interface{}{ + "nestedMap": map[interface{}]interface{}{"k": "v"}, + "bytes": []byte("ok"), + } + + got := formatCommandResult(input) + formatted, ok := got.(map[string]interface{}) + if !ok { + t.Fatalf("expected map[string]interface{}, got %T (%#v)", got, got) + } + if formatted["bytes"] != "ok" { + t.Fatalf("expected []byte value converted to string, got %#v", formatted["bytes"]) + } + if _, ok := formatted["nestedMap"].([]interface{}); !ok { + t.Fatalf("expected nested RESP3 map to be flattened, got %T", formatted["nestedMap"]) + } + if _, err := json.Marshal(formatted); err != nil { + t.Fatalf("formatted string-key map must be JSON-marshalable, got error: %v", err) + } +} + +func TestFormatCommandResultFormatsJSONUnsupportedScalars(t *testing.T) { + input := []interface{}{ + math.Inf(1), + math.Inf(-1), + math.NaN(), + big.NewInt(1234567890123456789), + errors.New("redis nested error"), + } + + got := formatCommandResult(input) + arr, ok := got.([]interface{}) + if !ok || len(arr) != len(input) { + t.Fatalf("expected formatted array of length %d, got %#v", len(input), got) + } + for i, item := range arr { + if _, ok := item.(string); !ok { + t.Fatalf("expected item %d to be string after formatting, got %T (%#v)", i, item, item) + } + } + if _, err := json.Marshal(arr); err != nil { + t.Fatalf("formatted unsupported scalars must be JSON-marshalable, got error: %v", err) + } +} + +func TestFormatCommandResultFormatsGenericMapsAndSlices(t *testing.T) { + input := map[int][]byte{ + 1: []byte("one"), + } + + got := formatCommandResult(input) + arr, ok := got.([]interface{}) + if !ok || len(arr) != 2 { + t.Fatalf("expected generic non-string map to be flattened into 2 elements, got %#v", got) + } + if _, err := json.Marshal(arr); err != nil { + t.Fatalf("formatted generic map must be JSON-marshalable, got error: %v", err) + } +} + func TestSanitizeRedisPassword(t *testing.T) { tests := []struct { name string