🐛 fix(export): 修复 PostgreSQL 布尔字段备份类型错误

- 导出修复:PostgreSQL 系列 bool 字段 INSERT 输出 true/false
- 兼容处理:支持 bool、boolean、pg_catalog.bool 类型识别
- 回归覆盖:补充备份 SQL 布尔字段导出测试
Refs #444
This commit is contained in:
Syngnat
2026-05-15 22:23:41 +08:00
parent b707c74203
commit 71fca7fb86
2 changed files with 186 additions and 1 deletions

View File

@@ -960,11 +960,129 @@ func normalizeImportTemporalValue(dbType, columnType, raw string) string {
return parsed.Format("2006-01-02 15:04:05")
}
func isPgLikeBooleanDBType(dbType string) bool {
switch strings.ToLower(strings.TrimSpace(dbType)) {
case "postgres", "postgresql", "pg", "pq", "pgx", "kingbase", "kingbase8", "kingbasees", "kingbasev8", "highgo", "vastbase", "opengauss", "open_gauss", "open-gauss":
return true
default:
return false
}
}
func isBooleanColumnType(columnType string) bool {
typ := strings.ToLower(strings.TrimSpace(columnType))
if typ == "" {
return false
}
typ = strings.ReplaceAll(typ, `"`, "")
if idx := strings.IndexAny(typ, " ("); idx >= 0 {
typ = typ[:idx]
}
typ = strings.TrimPrefix(typ, "pg_catalog.")
return typ == "bool" || typ == "boolean"
}
func booleanSQLLiteral(v bool) string {
if v {
return "true"
}
return "false"
}
func formatSignedBooleanSQLValue(v int64) (string, bool) {
switch v {
case 0:
return "false", true
case 1:
return "true", true
default:
return "", false
}
}
func formatUnsignedBooleanSQLValue(v uint64) (string, bool) {
switch v {
case 0:
return "false", true
case 1:
return "true", true
default:
return "", false
}
}
func formatFloatBooleanSQLValue(v float64) (string, bool) {
if v == 0 {
return "false", true
}
if v == 1 {
return "true", true
}
return "", false
}
func formatBooleanStringSQLValue(raw string) (string, bool) {
switch strings.ToLower(strings.TrimSpace(raw)) {
case "true", "t", "1", "yes", "y", "on":
return "true", true
case "false", "f", "0", "no", "n", "off":
return "false", true
default:
return "", false
}
}
func formatPostgresBooleanSQLValue(value interface{}) (string, bool) {
switch val := value.(type) {
case bool:
return booleanSQLLiteral(val), true
case int:
return formatSignedBooleanSQLValue(int64(val))
case int8:
return formatSignedBooleanSQLValue(int64(val))
case int16:
return formatSignedBooleanSQLValue(int64(val))
case int32:
return formatSignedBooleanSQLValue(int64(val))
case int64:
return formatSignedBooleanSQLValue(val)
case uint:
return formatUnsignedBooleanSQLValue(uint64(val))
case uint8:
return formatUnsignedBooleanSQLValue(uint64(val))
case uint16:
return formatUnsignedBooleanSQLValue(uint64(val))
case uint32:
return formatUnsignedBooleanSQLValue(uint64(val))
case uint64:
return formatUnsignedBooleanSQLValue(val)
case float32:
return formatFloatBooleanSQLValue(float64(val))
case float64:
return formatFloatBooleanSQLValue(val)
case []byte:
if len(val) == 1 && (val[0] == 0 || val[0] == 1) {
return booleanSQLLiteral(val[0] == 1), true
}
return formatBooleanStringSQLValue(string(val))
case string:
return formatBooleanStringSQLValue(val)
default:
return "", false
}
}
func formatImportSQLValue(dbType, columnType string, value interface{}) string {
if value == nil {
return "NULL"
}
if isPgLikeBooleanDBType(dbType) && isBooleanColumnType(columnType) {
if literal, ok := formatPostgresBooleanSQLValue(value); ok {
return literal
}
}
if isTemporalColumnType(dbType, columnType) {
normalized := normalizeImportTemporalValue(dbType, columnType, fmt.Sprintf("%v", value))
escaped := strings.ReplaceAll(normalized, "'", "''")

View File

@@ -1,6 +1,7 @@
package app
import (
"bufio"
"bytes"
"context"
"encoding/json"
@@ -16,6 +17,7 @@ type fakeExportQueryDB struct {
data []map[string]interface{}
cols []string
err error
defs []connection.ColumnDefinition
lastQuery string
lastContextTimeout time.Duration
@@ -46,7 +48,7 @@ func (f *fakeExportQueryDB) GetCreateStatement(dbName, tableName string) (string
return "", nil
}
func (f *fakeExportQueryDB) GetColumns(dbName, tableName string) ([]connection.ColumnDefinition, error) {
return nil, nil
return f.defs, nil
}
func (f *fakeExportQueryDB) GetAllColumns(dbName string) ([]connection.ColumnDefinitionWithTable, error) {
return nil, nil
@@ -364,3 +366,68 @@ func TestFormatImportSQLValue_LeavesTextLiteralUntouched(t *testing.T) {
t.Fatalf("文本字段不应被归一化want=%q got=%q", "'2026-01-21T18:32:26+08:00'", got)
}
}
func TestFormatImportSQLValue_PostgresBooleanColumnUsesBooleanLiteral(t *testing.T) {
cases := []struct {
name string
dbType string
columnType string
value interface{}
want string
}{
{name: "postgres bool true", dbType: "postgres", columnType: "boolean", value: true, want: "true"},
{name: "postgres bool false", dbType: "postgres", columnType: "bool", value: false, want: "false"},
{name: "pg catalog bool string", dbType: "postgres", columnType: "pg_catalog.bool", value: "t", want: "true"},
{name: "highgo boolean bytes", dbType: "highgo", columnType: "boolean", value: []byte("0"), want: "false"},
{name: "mysql keeps numeric bool", dbType: "mysql", columnType: "tinyint(1)", value: true, want: "1"},
}
for _, tc := range cases {
t.Run(tc.name, func(t *testing.T) {
got := formatImportSQLValue(tc.dbType, tc.columnType, tc.value)
if got != tc.want {
t.Fatalf("布尔字面量异常want=%q got=%q", tc.want, got)
}
})
}
}
func TestDumpTableSQL_PostgresBooleanBackupUsesBooleanLiterals(t *testing.T) {
fake := &fakeExportQueryDB{
data: []map[string]interface{}{
{"active": true, "archived": false},
},
cols: []string{"active", "archived"},
defs: []connection.ColumnDefinition{
{Name: "active", Type: "boolean"},
{Name: "archived", Type: "bool"},
},
}
var buf bytes.Buffer
writer := bufio.NewWriter(&buf)
err := dumpTableSQL(
writer,
fake,
connection.ConnectionConfig{Type: "postgres"},
"public",
"orders",
false,
true,
map[string]string{},
)
if err != nil {
t.Fatalf("dumpTableSQL 返回错误: %v", err)
}
if err := writer.Flush(); err != nil {
t.Fatalf("flush 导出 SQL 失败: %v", err)
}
content := buf.String()
if !strings.Contains(content, `INSERT INTO "public"."orders" ("active", "archived") VALUES (true, false);`) {
t.Fatalf("PostgreSQL bool 备份应使用 true/false 字面量content=%s", content)
}
if strings.Contains(content, "VALUES (1, 0)") {
t.Fatalf("PostgreSQL bool 备份不应输出数字布尔值content=%s", content)
}
}