From 71fca7fb86d77f6a56b13f1fd757c314400086b0 Mon Sep 17 00:00:00 2001 From: Syngnat Date: Fri, 15 May 2026 22:23:41 +0800 Subject: [PATCH] =?UTF-8?q?=F0=9F=90=9B=20fix(export):=20=E4=BF=AE?= =?UTF-8?q?=E5=A4=8D=20PostgreSQL=20=E5=B8=83=E5=B0=94=E5=AD=97=E6=AE=B5?= =?UTF-8?q?=E5=A4=87=E4=BB=BD=E7=B1=BB=E5=9E=8B=E9=94=99=E8=AF=AF?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - 导出修复:PostgreSQL 系列 bool 字段 INSERT 输出 true/false - 兼容处理:支持 bool、boolean、pg_catalog.bool 类型识别 - 回归覆盖:补充备份 SQL 布尔字段导出测试 Refs #444 --- internal/app/methods_file.go | 118 +++++++++++++++++++++++ internal/app/methods_file_export_test.go | 69 ++++++++++++- 2 files changed, 186 insertions(+), 1 deletion(-) diff --git a/internal/app/methods_file.go b/internal/app/methods_file.go index 9ba929e..71c29f4 100644 --- a/internal/app/methods_file.go +++ b/internal/app/methods_file.go @@ -960,11 +960,129 @@ func normalizeImportTemporalValue(dbType, columnType, raw string) string { return parsed.Format("2006-01-02 15:04:05") } +func isPgLikeBooleanDBType(dbType string) bool { + switch strings.ToLower(strings.TrimSpace(dbType)) { + case "postgres", "postgresql", "pg", "pq", "pgx", "kingbase", "kingbase8", "kingbasees", "kingbasev8", "highgo", "vastbase", "opengauss", "open_gauss", "open-gauss": + return true + default: + return false + } +} + +func isBooleanColumnType(columnType string) bool { + typ := strings.ToLower(strings.TrimSpace(columnType)) + if typ == "" { + return false + } + typ = strings.ReplaceAll(typ, `"`, "") + if idx := strings.IndexAny(typ, " ("); idx >= 0 { + typ = typ[:idx] + } + typ = strings.TrimPrefix(typ, "pg_catalog.") + return typ == "bool" || typ == "boolean" +} + +func booleanSQLLiteral(v bool) string { + if v { + return "true" + } + return "false" +} + +func formatSignedBooleanSQLValue(v int64) (string, bool) { + switch v { + case 0: + return "false", true + case 1: + return "true", true + default: + return "", false + } +} + +func formatUnsignedBooleanSQLValue(v uint64) (string, bool) { + switch v { + case 0: + return "false", true + case 1: + return "true", true + default: + return "", false + } +} + +func formatFloatBooleanSQLValue(v float64) (string, bool) { + if v == 0 { + return "false", true + } + if v == 1 { + return "true", true + } + return "", false +} + +func formatBooleanStringSQLValue(raw string) (string, bool) { + switch strings.ToLower(strings.TrimSpace(raw)) { + case "true", "t", "1", "yes", "y", "on": + return "true", true + case "false", "f", "0", "no", "n", "off": + return "false", true + default: + return "", false + } +} + +func formatPostgresBooleanSQLValue(value interface{}) (string, bool) { + switch val := value.(type) { + case bool: + return booleanSQLLiteral(val), true + case int: + return formatSignedBooleanSQLValue(int64(val)) + case int8: + return formatSignedBooleanSQLValue(int64(val)) + case int16: + return formatSignedBooleanSQLValue(int64(val)) + case int32: + return formatSignedBooleanSQLValue(int64(val)) + case int64: + return formatSignedBooleanSQLValue(val) + case uint: + return formatUnsignedBooleanSQLValue(uint64(val)) + case uint8: + return formatUnsignedBooleanSQLValue(uint64(val)) + case uint16: + return formatUnsignedBooleanSQLValue(uint64(val)) + case uint32: + return formatUnsignedBooleanSQLValue(uint64(val)) + case uint64: + return formatUnsignedBooleanSQLValue(val) + case float32: + return formatFloatBooleanSQLValue(float64(val)) + case float64: + return formatFloatBooleanSQLValue(val) + case []byte: + if len(val) == 1 && (val[0] == 0 || val[0] == 1) { + return booleanSQLLiteral(val[0] == 1), true + } + return formatBooleanStringSQLValue(string(val)) + case string: + return formatBooleanStringSQLValue(val) + default: + return "", false + } +} + func formatImportSQLValue(dbType, columnType string, value interface{}) string { if value == nil { return "NULL" } + if isPgLikeBooleanDBType(dbType) && isBooleanColumnType(columnType) { + if literal, ok := formatPostgresBooleanSQLValue(value); ok { + return literal + } + } + if isTemporalColumnType(dbType, columnType) { normalized := normalizeImportTemporalValue(dbType, columnType, fmt.Sprintf("%v", value)) escaped := strings.ReplaceAll(normalized, "'", "''") diff --git a/internal/app/methods_file_export_test.go b/internal/app/methods_file_export_test.go index 7bd4c73..0d9e709 100644 --- a/internal/app/methods_file_export_test.go +++ b/internal/app/methods_file_export_test.go @@ -1,6 +1,7 @@ package app import ( + "bufio" "bytes" "context" "encoding/json" @@ -16,6 +17,7 @@ type fakeExportQueryDB struct { data []map[string]interface{} cols []string err error + defs []connection.ColumnDefinition lastQuery string lastContextTimeout time.Duration @@ -46,7 +48,7 @@ func (f *fakeExportQueryDB) GetCreateStatement(dbName, tableName string) (string return "", nil } func (f *fakeExportQueryDB) GetColumns(dbName, tableName string) ([]connection.ColumnDefinition, error) { - return nil, nil + return f.defs, nil } func (f *fakeExportQueryDB) GetAllColumns(dbName string) ([]connection.ColumnDefinitionWithTable, error) { return nil, nil @@ -364,3 +366,68 @@ func TestFormatImportSQLValue_LeavesTextLiteralUntouched(t *testing.T) { t.Fatalf("文本字段不应被归一化,want=%q got=%q", "'2026-01-21T18:32:26+08:00'", got) } } + +func TestFormatImportSQLValue_PostgresBooleanColumnUsesBooleanLiteral(t *testing.T) { + cases := []struct { + name string + dbType string + columnType string + value interface{} + want string + }{ + {name: "postgres bool true", dbType: "postgres", columnType: "boolean", value: true, want: "true"}, + {name: "postgres bool false", dbType: "postgres", columnType: "bool", value: false, want: "false"}, + {name: "pg catalog bool string", dbType: "postgres", columnType: "pg_catalog.bool", value: "t", want: "true"}, + {name: "highgo boolean bytes", dbType: "highgo", columnType: "boolean", value: []byte("0"), want: "false"}, + {name: "mysql keeps numeric bool", dbType: "mysql", columnType: "tinyint(1)", value: true, want: "1"}, + } + + for _, tc := range cases { + t.Run(tc.name, func(t *testing.T) { + got := formatImportSQLValue(tc.dbType, tc.columnType, tc.value) + if got != tc.want { + t.Fatalf("布尔字面量异常,want=%q got=%q", tc.want, got) + } + }) + } +} + +func TestDumpTableSQL_PostgresBooleanBackupUsesBooleanLiterals(t *testing.T) { + fake := &fakeExportQueryDB{ + data: []map[string]interface{}{ + {"active": true, "archived": false}, + }, + cols: []string{"active", "archived"}, + defs: []connection.ColumnDefinition{ + {Name: "active", Type: "boolean"}, + {Name: "archived", Type: "bool"}, + }, + } + var buf bytes.Buffer + writer := bufio.NewWriter(&buf) + + err := dumpTableSQL( + writer, + fake, + connection.ConnectionConfig{Type: "postgres"}, + "public", + "orders", + false, + true, + map[string]string{}, + ) + if err != nil { + t.Fatalf("dumpTableSQL 返回错误: %v", err) + } + if err := writer.Flush(); err != nil { + t.Fatalf("flush 导出 SQL 失败: %v", err) + } + + content := buf.String() + if !strings.Contains(content, `INSERT INTO "public"."orders" ("active", "archived") VALUES (true, false);`) { + t.Fatalf("PostgreSQL bool 备份应使用 true/false 字面量,content=%s", content) + } + if strings.Contains(content, "VALUES (1, 0)") { + t.Fatalf("PostgreSQL bool 备份不应输出数字布尔值,content=%s", content) + } +}