mirror of
https://github.com/Syngnat/GoNavi.git
synced 2026-05-22 08:50:17 +08:00
🐛 fix(export): 修复 PostgreSQL 布尔字段备份类型错误
- 导出修复:PostgreSQL 系列 bool 字段 INSERT 输出 true/false - 兼容处理:支持 bool、boolean、pg_catalog.bool 类型识别 - 回归覆盖:补充备份 SQL 布尔字段导出测试 Refs #444
This commit is contained in:
@@ -960,11 +960,129 @@ func normalizeImportTemporalValue(dbType, columnType, raw string) string {
|
||||
return parsed.Format("2006-01-02 15:04:05")
|
||||
}
|
||||
|
||||
func isPgLikeBooleanDBType(dbType string) bool {
|
||||
switch strings.ToLower(strings.TrimSpace(dbType)) {
|
||||
case "postgres", "postgresql", "pg", "pq", "pgx", "kingbase", "kingbase8", "kingbasees", "kingbasev8", "highgo", "vastbase", "opengauss", "open_gauss", "open-gauss":
|
||||
return true
|
||||
default:
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
func isBooleanColumnType(columnType string) bool {
|
||||
typ := strings.ToLower(strings.TrimSpace(columnType))
|
||||
if typ == "" {
|
||||
return false
|
||||
}
|
||||
typ = strings.ReplaceAll(typ, `"`, "")
|
||||
if idx := strings.IndexAny(typ, " ("); idx >= 0 {
|
||||
typ = typ[:idx]
|
||||
}
|
||||
typ = strings.TrimPrefix(typ, "pg_catalog.")
|
||||
return typ == "bool" || typ == "boolean"
|
||||
}
|
||||
|
||||
func booleanSQLLiteral(v bool) string {
|
||||
if v {
|
||||
return "true"
|
||||
}
|
||||
return "false"
|
||||
}
|
||||
|
||||
func formatSignedBooleanSQLValue(v int64) (string, bool) {
|
||||
switch v {
|
||||
case 0:
|
||||
return "false", true
|
||||
case 1:
|
||||
return "true", true
|
||||
default:
|
||||
return "", false
|
||||
}
|
||||
}
|
||||
|
||||
func formatUnsignedBooleanSQLValue(v uint64) (string, bool) {
|
||||
switch v {
|
||||
case 0:
|
||||
return "false", true
|
||||
case 1:
|
||||
return "true", true
|
||||
default:
|
||||
return "", false
|
||||
}
|
||||
}
|
||||
|
||||
func formatFloatBooleanSQLValue(v float64) (string, bool) {
|
||||
if v == 0 {
|
||||
return "false", true
|
||||
}
|
||||
if v == 1 {
|
||||
return "true", true
|
||||
}
|
||||
return "", false
|
||||
}
|
||||
|
||||
func formatBooleanStringSQLValue(raw string) (string, bool) {
|
||||
switch strings.ToLower(strings.TrimSpace(raw)) {
|
||||
case "true", "t", "1", "yes", "y", "on":
|
||||
return "true", true
|
||||
case "false", "f", "0", "no", "n", "off":
|
||||
return "false", true
|
||||
default:
|
||||
return "", false
|
||||
}
|
||||
}
|
||||
|
||||
func formatPostgresBooleanSQLValue(value interface{}) (string, bool) {
|
||||
switch val := value.(type) {
|
||||
case bool:
|
||||
return booleanSQLLiteral(val), true
|
||||
case int:
|
||||
return formatSignedBooleanSQLValue(int64(val))
|
||||
case int8:
|
||||
return formatSignedBooleanSQLValue(int64(val))
|
||||
case int16:
|
||||
return formatSignedBooleanSQLValue(int64(val))
|
||||
case int32:
|
||||
return formatSignedBooleanSQLValue(int64(val))
|
||||
case int64:
|
||||
return formatSignedBooleanSQLValue(val)
|
||||
case uint:
|
||||
return formatUnsignedBooleanSQLValue(uint64(val))
|
||||
case uint8:
|
||||
return formatUnsignedBooleanSQLValue(uint64(val))
|
||||
case uint16:
|
||||
return formatUnsignedBooleanSQLValue(uint64(val))
|
||||
case uint32:
|
||||
return formatUnsignedBooleanSQLValue(uint64(val))
|
||||
case uint64:
|
||||
return formatUnsignedBooleanSQLValue(val)
|
||||
case float32:
|
||||
return formatFloatBooleanSQLValue(float64(val))
|
||||
case float64:
|
||||
return formatFloatBooleanSQLValue(val)
|
||||
case []byte:
|
||||
if len(val) == 1 && (val[0] == 0 || val[0] == 1) {
|
||||
return booleanSQLLiteral(val[0] == 1), true
|
||||
}
|
||||
return formatBooleanStringSQLValue(string(val))
|
||||
case string:
|
||||
return formatBooleanStringSQLValue(val)
|
||||
default:
|
||||
return "", false
|
||||
}
|
||||
}
|
||||
|
||||
func formatImportSQLValue(dbType, columnType string, value interface{}) string {
|
||||
if value == nil {
|
||||
return "NULL"
|
||||
}
|
||||
|
||||
if isPgLikeBooleanDBType(dbType) && isBooleanColumnType(columnType) {
|
||||
if literal, ok := formatPostgresBooleanSQLValue(value); ok {
|
||||
return literal
|
||||
}
|
||||
}
|
||||
|
||||
if isTemporalColumnType(dbType, columnType) {
|
||||
normalized := normalizeImportTemporalValue(dbType, columnType, fmt.Sprintf("%v", value))
|
||||
escaped := strings.ReplaceAll(normalized, "'", "''")
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
package app
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/json"
|
||||
@@ -16,6 +17,7 @@ type fakeExportQueryDB struct {
|
||||
data []map[string]interface{}
|
||||
cols []string
|
||||
err error
|
||||
defs []connection.ColumnDefinition
|
||||
|
||||
lastQuery string
|
||||
lastContextTimeout time.Duration
|
||||
@@ -46,7 +48,7 @@ func (f *fakeExportQueryDB) GetCreateStatement(dbName, tableName string) (string
|
||||
return "", nil
|
||||
}
|
||||
func (f *fakeExportQueryDB) GetColumns(dbName, tableName string) ([]connection.ColumnDefinition, error) {
|
||||
return nil, nil
|
||||
return f.defs, nil
|
||||
}
|
||||
func (f *fakeExportQueryDB) GetAllColumns(dbName string) ([]connection.ColumnDefinitionWithTable, error) {
|
||||
return nil, nil
|
||||
@@ -364,3 +366,68 @@ func TestFormatImportSQLValue_LeavesTextLiteralUntouched(t *testing.T) {
|
||||
t.Fatalf("文本字段不应被归一化,want=%q got=%q", "'2026-01-21T18:32:26+08:00'", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestFormatImportSQLValue_PostgresBooleanColumnUsesBooleanLiteral(t *testing.T) {
|
||||
cases := []struct {
|
||||
name string
|
||||
dbType string
|
||||
columnType string
|
||||
value interface{}
|
||||
want string
|
||||
}{
|
||||
{name: "postgres bool true", dbType: "postgres", columnType: "boolean", value: true, want: "true"},
|
||||
{name: "postgres bool false", dbType: "postgres", columnType: "bool", value: false, want: "false"},
|
||||
{name: "pg catalog bool string", dbType: "postgres", columnType: "pg_catalog.bool", value: "t", want: "true"},
|
||||
{name: "highgo boolean bytes", dbType: "highgo", columnType: "boolean", value: []byte("0"), want: "false"},
|
||||
{name: "mysql keeps numeric bool", dbType: "mysql", columnType: "tinyint(1)", value: true, want: "1"},
|
||||
}
|
||||
|
||||
for _, tc := range cases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
got := formatImportSQLValue(tc.dbType, tc.columnType, tc.value)
|
||||
if got != tc.want {
|
||||
t.Fatalf("布尔字面量异常,want=%q got=%q", tc.want, got)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestDumpTableSQL_PostgresBooleanBackupUsesBooleanLiterals(t *testing.T) {
|
||||
fake := &fakeExportQueryDB{
|
||||
data: []map[string]interface{}{
|
||||
{"active": true, "archived": false},
|
||||
},
|
||||
cols: []string{"active", "archived"},
|
||||
defs: []connection.ColumnDefinition{
|
||||
{Name: "active", Type: "boolean"},
|
||||
{Name: "archived", Type: "bool"},
|
||||
},
|
||||
}
|
||||
var buf bytes.Buffer
|
||||
writer := bufio.NewWriter(&buf)
|
||||
|
||||
err := dumpTableSQL(
|
||||
writer,
|
||||
fake,
|
||||
connection.ConnectionConfig{Type: "postgres"},
|
||||
"public",
|
||||
"orders",
|
||||
false,
|
||||
true,
|
||||
map[string]string{},
|
||||
)
|
||||
if err != nil {
|
||||
t.Fatalf("dumpTableSQL 返回错误: %v", err)
|
||||
}
|
||||
if err := writer.Flush(); err != nil {
|
||||
t.Fatalf("flush 导出 SQL 失败: %v", err)
|
||||
}
|
||||
|
||||
content := buf.String()
|
||||
if !strings.Contains(content, `INSERT INTO "public"."orders" ("active", "archived") VALUES (true, false);`) {
|
||||
t.Fatalf("PostgreSQL bool 备份应使用 true/false 字面量,content=%s", content)
|
||||
}
|
||||
if strings.Contains(content, "VALUES (1, 0)") {
|
||||
t.Fatalf("PostgreSQL bool 备份不应输出数字布尔值,content=%s", content)
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user