mirror of
https://github.com/Syngnat/GoNavi.git
synced 2026-05-31 12:39:41 +08:00
@@ -623,28 +623,16 @@ func (k *KingbaseDB) ApplyChanges(tableName string, changes connection.ChangeSet
|
||||
}
|
||||
defer tx.Rollback()
|
||||
|
||||
quoteIdent := func(name string) string {
|
||||
n := strings.TrimSpace(name)
|
||||
n = strings.Trim(n, "\"")
|
||||
n = strings.ReplaceAll(n, "\"", "\"\"")
|
||||
if n == "" {
|
||||
return "\"\""
|
||||
}
|
||||
return `"` + n + `"`
|
||||
}
|
||||
|
||||
schema := ""
|
||||
table := strings.TrimSpace(tableName)
|
||||
if parts := strings.SplitN(table, ".", 2); len(parts) == 2 {
|
||||
schema = strings.TrimSpace(parts[0])
|
||||
table = strings.TrimSpace(parts[1])
|
||||
schema, table := splitKingbaseQualifiedTable(tableName)
|
||||
if table == "" {
|
||||
return fmt.Errorf("table name required")
|
||||
}
|
||||
|
||||
qualifiedTable := ""
|
||||
if schema != "" {
|
||||
qualifiedTable = fmt.Sprintf("%s.%s", quoteIdent(schema), quoteIdent(table))
|
||||
qualifiedTable = fmt.Sprintf("%s.%s", quoteKingbaseIdent(schema), quoteKingbaseIdent(table))
|
||||
} else {
|
||||
qualifiedTable = quoteIdent(table)
|
||||
qualifiedTable = quoteKingbaseIdent(table)
|
||||
}
|
||||
|
||||
// 1. Deletes
|
||||
@@ -654,7 +642,7 @@ func (k *KingbaseDB) ApplyChanges(tableName string, changes connection.ChangeSet
|
||||
idx := 0
|
||||
for k, v := range pk {
|
||||
idx++
|
||||
wheres = append(wheres, fmt.Sprintf("%s = $%d", quoteIdent(k), idx))
|
||||
wheres = append(wheres, fmt.Sprintf("%s = $%d", quoteKingbaseIdent(k), idx))
|
||||
args = append(args, v)
|
||||
}
|
||||
if len(wheres) == 0 {
|
||||
@@ -674,7 +662,7 @@ func (k *KingbaseDB) ApplyChanges(tableName string, changes connection.ChangeSet
|
||||
|
||||
for k, v := range update.Values {
|
||||
idx++
|
||||
sets = append(sets, fmt.Sprintf("%s = $%d", quoteIdent(k), idx))
|
||||
sets = append(sets, fmt.Sprintf("%s = $%d", quoteKingbaseIdent(k), idx))
|
||||
args = append(args, v)
|
||||
}
|
||||
|
||||
@@ -685,7 +673,7 @@ func (k *KingbaseDB) ApplyChanges(tableName string, changes connection.ChangeSet
|
||||
var wheres []string
|
||||
for k, v := range update.Keys {
|
||||
idx++
|
||||
wheres = append(wheres, fmt.Sprintf("%s = $%d", quoteIdent(k), idx))
|
||||
wheres = append(wheres, fmt.Sprintf("%s = $%d", quoteKingbaseIdent(k), idx))
|
||||
args = append(args, v)
|
||||
}
|
||||
|
||||
@@ -708,7 +696,7 @@ func (k *KingbaseDB) ApplyChanges(tableName string, changes connection.ChangeSet
|
||||
|
||||
for k, v := range row {
|
||||
idx++
|
||||
cols = append(cols, quoteIdent(k))
|
||||
cols = append(cols, quoteKingbaseIdent(k))
|
||||
placeholders = append(placeholders, fmt.Sprintf("$%d", idx))
|
||||
args = append(args, v)
|
||||
}
|
||||
@@ -726,6 +714,67 @@ func (k *KingbaseDB) ApplyChanges(tableName string, changes connection.ChangeSet
|
||||
return tx.Commit()
|
||||
}
|
||||
|
||||
func normalizeKingbaseIdentifier(raw string) string {
|
||||
value := strings.TrimSpace(raw)
|
||||
if value == "" {
|
||||
return ""
|
||||
}
|
||||
|
||||
// 兼容 JSON/字符串转义后传入的标识符:\"schema\" -> "schema"
|
||||
value = strings.ReplaceAll(value, `\"`, `"`)
|
||||
value = strings.TrimSpace(value)
|
||||
|
||||
// 兼容异常多重包裹引号(例如 ""schema""、""""schema"""")。
|
||||
// strings.Trim 会移除两端连续引号,迭代后可收敛到纯标识符。
|
||||
for i := 0; i < 4; i++ {
|
||||
next := strings.TrimSpace(strings.Trim(value, `"`))
|
||||
if next == value {
|
||||
break
|
||||
}
|
||||
value = next
|
||||
}
|
||||
|
||||
// 兼容其他方言可能残留的引用形式
|
||||
if len(value) >= 2 && strings.HasPrefix(value, "`") && strings.HasSuffix(value, "`") {
|
||||
value = strings.TrimSpace(strings.Trim(value, "`"))
|
||||
}
|
||||
if len(value) >= 2 && strings.HasPrefix(value, "[") && strings.HasSuffix(value, "]") {
|
||||
value = strings.TrimSpace(value[1 : len(value)-1])
|
||||
}
|
||||
|
||||
return value
|
||||
}
|
||||
|
||||
func quoteKingbaseIdent(name string) string {
|
||||
n := normalizeKingbaseIdentifier(name)
|
||||
n = strings.ReplaceAll(n, `"`, `""`)
|
||||
if n == "" {
|
||||
return "\"\""
|
||||
}
|
||||
return `"` + n + `"`
|
||||
}
|
||||
|
||||
func splitKingbaseQualifiedTable(tableName string) (schema string, table string) {
|
||||
raw := strings.TrimSpace(tableName)
|
||||
if raw == "" {
|
||||
return "", ""
|
||||
}
|
||||
|
||||
if parts := strings.SplitN(raw, ".", 2); len(parts) == 2 {
|
||||
schema = normalizeKingbaseIdentifier(parts[0])
|
||||
table = normalizeKingbaseIdentifier(parts[1])
|
||||
if table == "" {
|
||||
return "", normalizeKingbaseIdentifier(raw)
|
||||
}
|
||||
if schema == "" {
|
||||
return "", table
|
||||
}
|
||||
return schema, table
|
||||
}
|
||||
|
||||
return "", normalizeKingbaseIdentifier(raw)
|
||||
}
|
||||
|
||||
func (k *KingbaseDB) GetAllColumns(dbName string) ([]connection.ColumnDefinitionWithTable, error) {
|
||||
// dbName 在本项目语义里是“数据库”,schema 由 table_schema 决定;这里返回全部用户 schema 的列用于查询提示。
|
||||
query := `
|
||||
|
||||
74
internal/db/kingbase_impl_test.go
Normal file
74
internal/db/kingbase_impl_test.go
Normal file
@@ -0,0 +1,74 @@
|
||||
//go:build gonavi_full_drivers || gonavi_kingbase_driver
|
||||
|
||||
package db
|
||||
|
||||
import "testing"
|
||||
|
||||
func TestNormalizeKingbaseIdentifier(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
in string
|
||||
want string
|
||||
}{
|
||||
{name: "plain", in: "ldf_server", want: "ldf_server"},
|
||||
{name: "quoted", in: `"ldf_server"`, want: "ldf_server"},
|
||||
{name: "double quoted", in: `""ldf_server""`, want: "ldf_server"},
|
||||
{name: "quad quoted", in: `""""ldf_server""""`, want: "ldf_server"},
|
||||
{name: "escaped quoted", in: `\"ldf_server\"`, want: "ldf_server"},
|
||||
{name: "backtick quoted", in: "`ldf_server`", want: "ldf_server"},
|
||||
{name: "bracket quoted", in: "[ldf_server]", want: "ldf_server"},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
if got := normalizeKingbaseIdentifier(tt.in); got != tt.want {
|
||||
t.Fatalf("normalizeKingbaseIdentifier(%q) = %q, want %q", tt.in, got, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestQuoteKingbaseIdent(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
in string
|
||||
want string
|
||||
}{
|
||||
{name: "plain", in: "ldf_server", want: `"ldf_server"`},
|
||||
{name: "double quoted", in: `""ldf_server""`, want: `"ldf_server"`},
|
||||
{name: "escaped quoted", in: `\"ldf_server\"`, want: `"ldf_server"`},
|
||||
{name: "with embedded quote", in: `ab"cd`, want: `"ab""cd"`},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
if got := quoteKingbaseIdent(tt.in); got != tt.want {
|
||||
t.Fatalf("quoteKingbaseIdent(%q) = %q, want %q", tt.in, got, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestSplitKingbaseQualifiedTable(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
in string
|
||||
wantSchema string
|
||||
wantTable string
|
||||
}{
|
||||
{name: "plain qualified", in: "ldf_server.t_user", wantSchema: "ldf_server", wantTable: "t_user"},
|
||||
{name: "double quoted qualified", in: `""ldf_server"".""t_user""`, wantSchema: "ldf_server", wantTable: "t_user"},
|
||||
{name: "escaped qualified", in: `\"ldf_server\".\"t_user\"`, wantSchema: "ldf_server", wantTable: "t_user"},
|
||||
{name: "bracket qualified", in: "[ldf_server].[t_user]", wantSchema: "ldf_server", wantTable: "t_user"},
|
||||
{name: "table only", in: `""t_user""`, wantSchema: "", wantTable: "t_user"},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
gotSchema, gotTable := splitKingbaseQualifiedTable(tt.in)
|
||||
if gotSchema != tt.wantSchema || gotTable != tt.wantTable {
|
||||
t.Fatalf("splitKingbaseQualifiedTable(%q) = (%q, %q), want (%q, %q)", tt.in, gotSchema, gotTable, tt.wantSchema, tt.wantTable)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user