From 2449184ad3c11f0a0dd651f101f11f5aa6b5928a Mon Sep 17 00:00:00 2001 From: tianqijiuyun-latiao <69459608+tianqijiuyun-latiao@users.noreply.github.com> Date: Thu, 5 Mar 2026 21:10:36 +0800 Subject: [PATCH] =?UTF-8?q?fix(kingbase-transaction):=20=E4=BF=AE=E5=A4=8D?= =?UTF-8?q?=E9=87=91=E4=BB=93=E4=BA=8B=E5=8A=A1=E6=8F=90=E4=BA=A4=E9=87=8D?= =?UTF-8?q?=E5=A4=8D=E5=BC=95=E5=8F=B7=E5=AF=BC=E8=87=B4=E8=AF=AD=E6=B3=95?= =?UTF-8?q?=E9=94=99=E8=AF=AF?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit refs #176 --- internal/db/kingbase_impl.go | 91 ++++++++++++++++++++++++------- internal/db/kingbase_impl_test.go | 74 +++++++++++++++++++++++++ 2 files changed, 144 insertions(+), 21 deletions(-) create mode 100644 internal/db/kingbase_impl_test.go diff --git a/internal/db/kingbase_impl.go b/internal/db/kingbase_impl.go index f1357a8..6dfd2e5 100644 --- a/internal/db/kingbase_impl.go +++ b/internal/db/kingbase_impl.go @@ -623,28 +623,16 @@ func (k *KingbaseDB) ApplyChanges(tableName string, changes connection.ChangeSet } defer tx.Rollback() - quoteIdent := func(name string) string { - n := strings.TrimSpace(name) - n = strings.Trim(n, "\"") - n = strings.ReplaceAll(n, "\"", "\"\"") - if n == "" { - return "\"\"" - } - return `"` + n + `"` - } - - schema := "" - table := strings.TrimSpace(tableName) - if parts := strings.SplitN(table, ".", 2); len(parts) == 2 { - schema = strings.TrimSpace(parts[0]) - table = strings.TrimSpace(parts[1]) + schema, table := splitKingbaseQualifiedTable(tableName) + if table == "" { + return fmt.Errorf("table name required") } qualifiedTable := "" if schema != "" { - qualifiedTable = fmt.Sprintf("%s.%s", quoteIdent(schema), quoteIdent(table)) + qualifiedTable = fmt.Sprintf("%s.%s", quoteKingbaseIdent(schema), quoteKingbaseIdent(table)) } else { - qualifiedTable = quoteIdent(table) + qualifiedTable = quoteKingbaseIdent(table) } // 1. Deletes @@ -654,7 +642,7 @@ func (k *KingbaseDB) ApplyChanges(tableName string, changes connection.ChangeSet idx := 0 for k, v := range pk { idx++ - wheres = append(wheres, fmt.Sprintf("%s = $%d", quoteIdent(k), idx)) + wheres = append(wheres, fmt.Sprintf("%s = $%d", quoteKingbaseIdent(k), idx)) args = append(args, v) } if len(wheres) == 0 { @@ -674,7 +662,7 @@ func (k *KingbaseDB) ApplyChanges(tableName string, changes connection.ChangeSet for k, v := range update.Values { idx++ - sets = append(sets, fmt.Sprintf("%s = $%d", quoteIdent(k), idx)) + sets = append(sets, fmt.Sprintf("%s = $%d", quoteKingbaseIdent(k), idx)) args = append(args, v) } @@ -685,7 +673,7 @@ func (k *KingbaseDB) ApplyChanges(tableName string, changes connection.ChangeSet var wheres []string for k, v := range update.Keys { idx++ - wheres = append(wheres, fmt.Sprintf("%s = $%d", quoteIdent(k), idx)) + wheres = append(wheres, fmt.Sprintf("%s = $%d", quoteKingbaseIdent(k), idx)) args = append(args, v) } @@ -708,7 +696,7 @@ func (k *KingbaseDB) ApplyChanges(tableName string, changes connection.ChangeSet for k, v := range row { idx++ - cols = append(cols, quoteIdent(k)) + cols = append(cols, quoteKingbaseIdent(k)) placeholders = append(placeholders, fmt.Sprintf("$%d", idx)) args = append(args, v) } @@ -726,6 +714,67 @@ func (k *KingbaseDB) ApplyChanges(tableName string, changes connection.ChangeSet return tx.Commit() } +func normalizeKingbaseIdentifier(raw string) string { + value := strings.TrimSpace(raw) + if value == "" { + return "" + } + + // 兼容 JSON/字符串转义后传入的标识符:\"schema\" -> "schema" + value = strings.ReplaceAll(value, `\"`, `"`) + value = strings.TrimSpace(value) + + // 兼容异常多重包裹引号(例如 ""schema""、""""schema"""")。 + // strings.Trim 会移除两端连续引号,迭代后可收敛到纯标识符。 + for i := 0; i < 4; i++ { + next := strings.TrimSpace(strings.Trim(value, `"`)) + if next == value { + break + } + value = next + } + + // 兼容其他方言可能残留的引用形式 + if len(value) >= 2 && strings.HasPrefix(value, "`") && strings.HasSuffix(value, "`") { + value = strings.TrimSpace(strings.Trim(value, "`")) + } + if len(value) >= 2 && strings.HasPrefix(value, "[") && strings.HasSuffix(value, "]") { + value = strings.TrimSpace(value[1 : len(value)-1]) + } + + return value +} + +func quoteKingbaseIdent(name string) string { + n := normalizeKingbaseIdentifier(name) + n = strings.ReplaceAll(n, `"`, `""`) + if n == "" { + return "\"\"" + } + return `"` + n + `"` +} + +func splitKingbaseQualifiedTable(tableName string) (schema string, table string) { + raw := strings.TrimSpace(tableName) + if raw == "" { + return "", "" + } + + if parts := strings.SplitN(raw, ".", 2); len(parts) == 2 { + schema = normalizeKingbaseIdentifier(parts[0]) + table = normalizeKingbaseIdentifier(parts[1]) + if table == "" { + return "", normalizeKingbaseIdentifier(raw) + } + if schema == "" { + return "", table + } + return schema, table + } + + return "", normalizeKingbaseIdentifier(raw) +} + func (k *KingbaseDB) GetAllColumns(dbName string) ([]connection.ColumnDefinitionWithTable, error) { // dbName 在本项目语义里是“数据库”,schema 由 table_schema 决定;这里返回全部用户 schema 的列用于查询提示。 query := ` diff --git a/internal/db/kingbase_impl_test.go b/internal/db/kingbase_impl_test.go new file mode 100644 index 0000000..eca6eaa --- /dev/null +++ b/internal/db/kingbase_impl_test.go @@ -0,0 +1,74 @@ +//go:build gonavi_full_drivers || gonavi_kingbase_driver + +package db + +import "testing" + +func TestNormalizeKingbaseIdentifier(t *testing.T) { + tests := []struct { + name string + in string + want string + }{ + {name: "plain", in: "ldf_server", want: "ldf_server"}, + {name: "quoted", in: `"ldf_server"`, want: "ldf_server"}, + {name: "double quoted", in: `""ldf_server""`, want: "ldf_server"}, + {name: "quad quoted", in: `""""ldf_server""""`, want: "ldf_server"}, + {name: "escaped quoted", in: `\"ldf_server\"`, want: "ldf_server"}, + {name: "backtick quoted", in: "`ldf_server`", want: "ldf_server"}, + {name: "bracket quoted", in: "[ldf_server]", want: "ldf_server"}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if got := normalizeKingbaseIdentifier(tt.in); got != tt.want { + t.Fatalf("normalizeKingbaseIdentifier(%q) = %q, want %q", tt.in, got, tt.want) + } + }) + } +} + +func TestQuoteKingbaseIdent(t *testing.T) { + tests := []struct { + name string + in string + want string + }{ + {name: "plain", in: "ldf_server", want: `"ldf_server"`}, + {name: "double quoted", in: `""ldf_server""`, want: `"ldf_server"`}, + {name: "escaped quoted", in: `\"ldf_server\"`, want: `"ldf_server"`}, + {name: "with embedded quote", in: `ab"cd`, want: `"ab""cd"`}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if got := quoteKingbaseIdent(tt.in); got != tt.want { + t.Fatalf("quoteKingbaseIdent(%q) = %q, want %q", tt.in, got, tt.want) + } + }) + } +} + +func TestSplitKingbaseQualifiedTable(t *testing.T) { + tests := []struct { + name string + in string + wantSchema string + wantTable string + }{ + {name: "plain qualified", in: "ldf_server.t_user", wantSchema: "ldf_server", wantTable: "t_user"}, + {name: "double quoted qualified", in: `""ldf_server"".""t_user""`, wantSchema: "ldf_server", wantTable: "t_user"}, + {name: "escaped qualified", in: `\"ldf_server\".\"t_user\"`, wantSchema: "ldf_server", wantTable: "t_user"}, + {name: "bracket qualified", in: "[ldf_server].[t_user]", wantSchema: "ldf_server", wantTable: "t_user"}, + {name: "table only", in: `""t_user""`, wantSchema: "", wantTable: "t_user"}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + gotSchema, gotTable := splitKingbaseQualifiedTable(tt.in) + if gotSchema != tt.wantSchema || gotTable != tt.wantTable { + t.Fatalf("splitKingbaseQualifiedTable(%q) = (%q, %q), want (%q, %q)", tt.in, gotSchema, gotTable, tt.wantSchema, tt.wantTable) + } + }) + } +}