//go:build gonavi_full_drivers || gonavi_kingbase_driver package db import ( "context" "database/sql" "fmt" "net" "net/url" "regexp" "sort" "strconv" "strings" "time" "GoNavi-Wails/internal/connection" "GoNavi-Wails/internal/logger" "GoNavi-Wails/internal/ssh" "GoNavi-Wails/internal/utils" _ "gitea.com/kingbase/gokb" // Registers "kingbase" driver ) type KingbaseDB struct { conn *sql.DB pingTimeout time.Duration forwarder *ssh.LocalForwarder // Store SSH tunnel forwarder } func quoteConnValue(v string) string { if v == "" { return "''" } needsQuote := false for _, r := range v { switch r { case ' ', '\t', '\n', '\r', '\v', '\f', '\'', '\\': needsQuote = true } if needsQuote { break } } if !needsQuote { return v } var b strings.Builder b.Grow(len(v) + 2) b.WriteByte('\'') for _, r := range v { if r == '\\' || r == '\'' { b.WriteByte('\\') } b.WriteRune(r) } b.WriteByte('\'') return b.String() } func (k *KingbaseDB) getDSN(config connection.ConnectionConfig) string { // Kingbase DSN usually similar to Postgres: // host=localhost port=54321 user=system password=... dbname=TEST sslmode=disable params := url.Values{} params.Set("host", config.Host) params.Set("port", strconv.Itoa(config.Port)) params.Set("user", config.User) params.Set("password", config.Password) params.Set("dbname", config.Database) params.Set("sslmode", resolvePostgresSSLMode(config)) params.Set("connect_timeout", strconv.Itoa(getConnectTimeoutSeconds(config))) mergeConnectionParamsFromConfig(params, config, "kingbase") preferred := []string{"host", "port", "user", "password", "dbname", "sslmode", "connect_timeout"} seen := make(map[string]struct{}, len(params)) parts := make([]string, 0, len(params)) for _, key := range preferred { if values, ok := params[key]; ok && len(values) > 0 { parts = append(parts, fmt.Sprintf("%s=%s", key, quoteConnValue(values[len(values)-1]))) seen[key] = struct{}{} } } extraKeys := make([]string, 0, len(params)) for key := range params { if _, ok := seen[key]; ok || !isSafeConnectionParamKey(key) { continue } extraKeys = append(extraKeys, key) } sort.Strings(extraKeys) for _, key := range extraKeys { values := params[key] if len(values) == 0 { continue } parts = append(parts, fmt.Sprintf("%s=%s", key, quoteConnValue(values[len(values)-1]))) } return strings.Join(parts, " ") } func (k *KingbaseDB) Connect(config connection.ConnectionConfig) error { runConfig := config if config.UseSSH { // Create SSH tunnel with local port forwarding logger.Infof("人大金仓使用 SSH 连接:地址=%s:%d 用户=%s", config.Host, config.Port, config.User) forwarder, err := ssh.GetOrCreateLocalForwarder(config.SSH, config.Host, config.Port) if err != nil { return fmt.Errorf("创建 SSH 隧道失败:%w", err) } k.forwarder = forwarder // Parse local address host, portStr, err := net.SplitHostPort(forwarder.LocalAddr) if err != nil { return fmt.Errorf("解析本地转发地址失败:%w", err) } port, err := strconv.Atoi(portStr) if err != nil { return fmt.Errorf("解析本地端口失败:%w", err) } // Create a modified config pointing to local forwarder localConfig := config localConfig.Host = host localConfig.Port = port localConfig.UseSSH = false runConfig = localConfig logger.Infof("人大金仓通过本地端口转发连接:%s -> %s:%d", forwarder.LocalAddr, config.Host, config.Port) } attempts := []connection.ConnectionConfig{runConfig} if shouldTrySSLPreferredFallback(runConfig) { attempts = append(attempts, withSSLDisabled(runConfig)) } var failures []string for idx, attempt := range attempts { dsn := k.getDSN(attempt) db, err := sql.Open("kingbase", dsn) if err != nil { failures = append(failures, fmt.Sprintf("第%d次连接打开失败: %v", idx+1, err)) continue } k.conn = db k.pingTimeout = getConnectTimeout(attempt) if err := k.Ping(); err != nil { _ = db.Close() k.conn = nil failures = append(failures, fmt.Sprintf("第%d次连接验证失败: %v", idx+1, err)) continue } if idx > 0 { logger.Warnf("人大金仓 SSL 优先连接失败,已回退至明文连接") } // 获取 schema 列表以重构带有 search_path 的连接池 searchPathStr := k.getSearchPathStr() if searchPathStr != "" { // 将 search_path 参数拼入 DSN finalDSN := dsn + " search_path=" + quoteConnValue(searchPathStr) if finalDB, err := sql.Open("kingbase", finalDSN); err == nil { k.pingTimeout = getConnectTimeout(attempt) finalDB.SetConnMaxLifetime(5 * time.Minute) // 临时将 k.conn 指向 finalDB 来做 ping 测试 oldConn := k.conn k.conn = finalDB if err := k.Ping(); err == nil { // 成功使用带 search_path 的连接池 _ = oldConn.Close() logger.Infof("人大金仓已配置连接级 search_path:%s", searchPathStr) } else { _ = finalDB.Close() k.conn = oldConn } } } if searchPathStr != "" { timeout := k.pingTimeout if timeout <= 0 { timeout = 5 * time.Second } ctx, cancel := utils.ContextWithTimeout(timeout) defer cancel() if _, err := k.conn.ExecContext(ctx, fmt.Sprintf("SET search_path TO %s", searchPathStr)); err != nil { logger.Warnf("人大金仓显式设置 search_path 失败:%v", err) } else { logger.Infof("人大金仓已设置默认 search_path:%s", searchPathStr) } } return nil } return fmt.Errorf("连接建立后验证失败:%s", strings.Join(failures, ";")) } // getSearchPathStr 查询当前数据库中所有用户 schema,配置 DSN 的 search_path。 // KingBase 默认 search_path 为 "$user", public,对于自定义 schema 下的表不可见。 func (k *KingbaseDB) getSearchPathStr() string { if k.conn == nil { return "" } query := `SELECT nspname FROM pg_namespace WHERE nspname NOT IN ('pg_catalog', 'information_schema') AND nspname NOT LIKE 'pg|_%' ESCAPE '|' ORDER BY nspname` rows, err := k.conn.Query(query) if err != nil { logger.Warnf("人大金仓查询用户 schema 失败,跳过 search_path 设置:%v", err) return "" } defer rows.Close() var rawSchemas []string for rows.Next() { var name string if err := rows.Scan(&name); err != nil { continue } name = strings.TrimSpace(name) if name != "" { rawSchemas = append(rawSchemas, name) } } searchPath, _ := buildKingbaseSearchPathCommon(rawSchemas) return searchPath } func (k *KingbaseDB) Close() error { // Close SSH forwarder first if exists if k.forwarder != nil { if err := k.forwarder.Close(); err != nil { logger.Warnf("关闭人大金仓 SSH 端口转发失败:%v", err) } k.forwarder = nil } // Then close database connection if k.conn != nil { return k.conn.Close() } return nil } func (k *KingbaseDB) Ping() error { if k.conn == nil { return fmt.Errorf("连接未打开") } timeout := k.pingTimeout if timeout <= 0 { timeout = 5 * time.Second } ctx, cancel := utils.ContextWithTimeout(timeout) defer cancel() return k.conn.PingContext(ctx) } func (k *KingbaseDB) QueryContext(ctx context.Context, query string) ([]map[string]interface{}, []string, error) { if k.conn == nil { return nil, nil, fmt.Errorf("连接未打开") } rows, err := k.conn.QueryContext(ctx, query) if err != nil { return nil, nil, err } defer rows.Close() return scanRows(rows) } func (k *KingbaseDB) Query(query string) ([]map[string]interface{}, []string, error) { if k.conn == nil { return nil, nil, fmt.Errorf("连接未打开") } rows, err := k.conn.Query(query) if err != nil { return nil, nil, err } defer rows.Close() return scanRows(rows) } func (k *KingbaseDB) ExecContext(ctx context.Context, query string) (int64, error) { if k.conn == nil { return 0, fmt.Errorf("连接未打开") } res, err := k.conn.ExecContext(ctx, query) if err != nil { return 0, err } return res.RowsAffected() } func (k *KingbaseDB) ExecBatchContext(ctx context.Context, query string) (int64, error) { if k.conn == nil { return 0, fmt.Errorf("连接未打开") } res, err := k.conn.ExecContext(ctx, query) if err != nil { return 0, err } return res.RowsAffected() } func (k *KingbaseDB) Exec(query string) (int64, error) { if k.conn == nil { return 0, fmt.Errorf("连接未打开") } res, err := k.conn.Exec(query) if err != nil { return 0, err } return res.RowsAffected() } func (k *KingbaseDB) GetDatabases() ([]string, error) { data, _, err := k.Query("SELECT datname FROM pg_database WHERE datistemplate = false") if err == nil { dbs := collectKingbaseNames(data, "datname", "database") if len(dbs) > 0 { return dbs, nil } } fallbackData, _, fallbackErr := k.Query("SELECT current_database() AS datname") if fallbackErr != nil { if err != nil { return nil, err } return nil, fallbackErr } dbs := collectKingbaseNames(fallbackData, "datname", "database", "current_database", "currentDatabase") if len(dbs) > 0 { return dbs, nil } if err != nil { return nil, err } return nil, fmt.Errorf("未获取到可见数据库列表") } func collectKingbaseNames(rows []map[string]interface{}, keys ...string) []string { result := make([]string, 0, len(rows)) seen := make(map[string]struct{}, len(rows)) for _, row := range rows { name := strings.TrimSpace(getKingbaseNameFromRow(row, keys...)) if name == "" { continue } if _, exists := seen[name]; exists { continue } seen[name] = struct{}{} result = append(result, name) } return result } func getKingbaseNameFromRow(row map[string]interface{}, keys ...string) string { if len(row) == 0 { return "" } for _, key := range keys { if value, ok := row[key]; ok { return fmt.Sprintf("%v", value) } } for existingKey, value := range row { for _, key := range keys { if strings.EqualFold(existingKey, key) { return fmt.Sprintf("%v", value) } } } for _, value := range row { return fmt.Sprintf("%v", value) } return "" } func (k *KingbaseDB) GetTables(dbName string) ([]string, error) { // Kingbase: tables are scoped by the current DB connection; include schema to avoid search_path issues. query := ` SELECT table_schema AS schemaname, table_name AS tablename FROM information_schema.tables WHERE table_type = 'BASE TABLE' AND table_schema NOT IN ('pg_catalog', 'information_schema') AND table_schema NOT LIKE 'pg|_%' ESCAPE '|' ORDER BY table_schema, table_name` data, _, err := k.Query(query) if err != nil { return nil, err } var tables []string for _, row := range data { schema, okSchema := row["schemaname"] name, okName := row["tablename"] if okSchema && okName { tables = append(tables, fmt.Sprintf("%v.%v", schema, name)) continue } if val, ok := row["table_name"]; ok { tables = append(tables, fmt.Sprintf("%v", val)) } } return tables, nil } func (k *KingbaseDB) GetCreateStatement(dbName, tableName string) (string, error) { // Kingbase doesn't have "SHOW CREATE TABLE". // We can try pg_dump logic or use a query to reconstruction. // A simple approach is just returning basic info or "Not Supported". // Or we can query information_schema to build it. return "SHOW CREATE TABLE not directly supported in Kingbase/Postgres via SQL", nil } func (k *KingbaseDB) GetColumns(dbName, tableName string) ([]connection.ColumnDefinition, error) { // 解析 schema.table 格式 schema := strings.TrimSpace(dbName) table := strings.TrimSpace(tableName) // 如果 tableName 包含 schema (格式: schema.table) if parts := strings.SplitN(table, ".", 2); len(parts) == 2 { parsedSchema := strings.TrimSpace(parts[0]) parsedTable := strings.TrimSpace(parts[1]) if parsedSchema != "" && parsedTable != "" { schema = parsedSchema table = parsedTable } } // 如果仍然没有 schema,使用 current_schema() // 这样可以自动匹配当前连接的 search_path if schema == "" { return k.getColumnsWithCurrentSchema(table) } if table == "" { return nil, fmt.Errorf("表名不能为空") } // 转义函数:处理单引号,移除双引号 esc := func(s string) string { // 移除前后的双引号(如果存在) s = strings.Trim(s, "\"") // 转义单引号 return strings.ReplaceAll(s, "'", "''") } query := fmt.Sprintf(` SELECT a.attname AS column_name, pg_catalog.format_type(a.atttypid, a.atttypmod) AS data_type, CASE WHEN a.attnotnull THEN 'NO' ELSE 'YES' END AS is_nullable, pg_get_expr(ad.adbin, ad.adrelid) AS column_default, col_description(a.attrelid, a.attnum) AS comment, CASE WHEN pk.attname IS NOT NULL THEN 'PRI' ELSE '' END AS column_key FROM pg_class c JOIN pg_namespace n ON n.oid = c.relnamespace JOIN pg_attribute a ON a.attrelid = c.oid LEFT JOIN pg_attrdef ad ON ad.adrelid = c.oid AND ad.adnum = a.attnum LEFT JOIN ( SELECT i.indrelid, a3.attname FROM pg_index i JOIN pg_attribute a3 ON a3.attrelid = i.indrelid AND a3.attnum = ANY(i.indkey) WHERE i.indisprimary ) pk ON pk.indrelid = c.oid AND pk.attname = a.attname WHERE c.relkind IN ('r', 'p') AND n.nspname = '%s' AND c.relname = '%s' AND a.attnum > 0 AND NOT a.attisdropped ORDER BY a.attnum`, esc(schema), esc(table)) data, _, err := k.Query(query) if err != nil { return nil, err } var columns []connection.ColumnDefinition for _, row := range data { col := connection.ColumnDefinition{ Name: fmt.Sprintf("%v", row["column_name"]), Type: fmt.Sprintf("%v", row["data_type"]), Nullable: fmt.Sprintf("%v", row["is_nullable"]), Key: fmt.Sprintf("%v", row["column_key"]), Extra: "", Comment: "", } if row["column_default"] != nil { def := fmt.Sprintf("%v", row["column_default"]) col.Default = &def if strings.HasPrefix(strings.ToLower(strings.TrimSpace(def)), "nextval(") { col.Extra = "auto_increment" } } if v, ok := row["comment"]; ok && v != nil { col.Comment = fmt.Sprintf("%v", v) } columns = append(columns, col) } return columns, nil } // getColumnsWithCurrentSchema 使用 current_schema() 查询当前schema的表 func (k *KingbaseDB) getColumnsWithCurrentSchema(tableName string) ([]connection.ColumnDefinition, error) { table := strings.TrimSpace(tableName) if table == "" { return nil, fmt.Errorf("表名不能为空") } // 转义函数 esc := func(s string) string { s = strings.Trim(s, "\"") return strings.ReplaceAll(s, "'", "''") } // 使用 current_schema() 获取当前schema query := fmt.Sprintf(` SELECT a.attname AS column_name, pg_catalog.format_type(a.atttypid, a.atttypmod) AS data_type, CASE WHEN a.attnotnull THEN 'NO' ELSE 'YES' END AS is_nullable, pg_get_expr(ad.adbin, ad.adrelid) AS column_default, col_description(a.attrelid, a.attnum) AS comment, CASE WHEN pk.attname IS NOT NULL THEN 'PRI' ELSE '' END AS column_key FROM pg_class c JOIN pg_namespace n ON n.oid = c.relnamespace JOIN pg_attribute a ON a.attrelid = c.oid LEFT JOIN pg_attrdef ad ON ad.adrelid = c.oid AND ad.adnum = a.attnum LEFT JOIN ( SELECT i.indrelid, a3.attname FROM pg_index i JOIN pg_attribute a3 ON a3.attrelid = i.indrelid AND a3.attnum = ANY(i.indkey) WHERE i.indisprimary ) pk ON pk.indrelid = c.oid AND pk.attname = a.attname WHERE c.relkind IN ('r', 'p') AND n.nspname = current_schema() AND c.relname = '%s' AND a.attnum > 0 AND NOT a.attisdropped ORDER BY a.attnum`, esc(table)) data, _, err := k.Query(query) if err != nil { return nil, err } var columns []connection.ColumnDefinition for _, row := range data { col := connection.ColumnDefinition{ Name: fmt.Sprintf("%v", row["column_name"]), Type: fmt.Sprintf("%v", row["data_type"]), Nullable: fmt.Sprintf("%v", row["is_nullable"]), Key: fmt.Sprintf("%v", row["column_key"]), Extra: "", Comment: "", } if row["column_default"] != nil { def := fmt.Sprintf("%v", row["column_default"]) col.Default = &def if strings.HasPrefix(strings.ToLower(strings.TrimSpace(def)), "nextval(") { col.Extra = "auto_increment" } } if v, ok := row["comment"]; ok && v != nil { col.Comment = fmt.Sprintf("%v", v) } columns = append(columns, col) } return columns, nil } func (k *KingbaseDB) GetIndexes(dbName, tableName string) ([]connection.IndexDefinition, error) { // 解析 schema.table 格式 schema := strings.TrimSpace(dbName) table := strings.TrimSpace(tableName) // 如果 tableName 包含 schema (格式: schema.table) if parts := strings.SplitN(table, ".", 2); len(parts) == 2 { parsedSchema := strings.TrimSpace(parts[0]) parsedTable := strings.TrimSpace(parts[1]) if parsedSchema != "" && parsedTable != "" { schema = parsedSchema table = parsedTable } } if table == "" { return nil, fmt.Errorf("表名不能为空") } // 转义函数:处理单引号,移除双引号 esc := func(s string) string { s = strings.Trim(s, "\"") return strings.ReplaceAll(s, "'", "''") } // 构建查询:如果没有指定schema,使用current_schema() var query string if schema != "" { query = fmt.Sprintf(` SELECT i.relname as index_name, a.attname as column_name, ix.indisunique as is_unique FROM pg_class t, pg_class i, pg_index ix, pg_attribute a, pg_namespace n WHERE t.oid = ix.indrelid AND i.oid = ix.indexrelid AND a.attrelid = t.oid AND a.attnum = ANY(ix.indkey) AND t.relkind = 'r' AND t.relname = '%s' AND n.oid = t.relnamespace AND n.nspname = '%s' `, esc(table), esc(schema)) } else { query = fmt.Sprintf(` SELECT i.relname as index_name, a.attname as column_name, ix.indisunique as is_unique FROM pg_class t, pg_class i, pg_index ix, pg_attribute a, pg_namespace n WHERE t.oid = ix.indrelid AND i.oid = ix.indexrelid AND a.attrelid = t.oid AND a.attnum = ANY(ix.indkey) AND t.relkind = 'r' AND t.relname = '%s' AND n.oid = t.relnamespace AND n.nspname = current_schema() `, esc(table)) } data, _, err := k.Query(query) if err != nil { return nil, err } var indexes []connection.IndexDefinition for _, row := range data { nonUnique := 1 if val, ok := row["is_unique"]; ok { if b, ok := val.(bool); ok && b { nonUnique = 0 } } idx := connection.IndexDefinition{ Name: fmt.Sprintf("%v", row["index_name"]), ColumnName: fmt.Sprintf("%v", row["column_name"]), NonUnique: nonUnique, IndexType: "BTREE", // Default } indexes = append(indexes, idx) } return indexes, nil } func (k *KingbaseDB) GetForeignKeys(dbName, tableName string) ([]connection.ForeignKeyDefinition, error) { // 解析 schema.table 格式 schema := strings.TrimSpace(dbName) table := strings.TrimSpace(tableName) // 如果 tableName 包含 schema (格式: schema.table) if parts := strings.SplitN(table, ".", 2); len(parts) == 2 { parsedSchema := strings.TrimSpace(parts[0]) parsedTable := strings.TrimSpace(parts[1]) if parsedSchema != "" && parsedTable != "" { schema = parsedSchema table = parsedTable } } if table == "" { return nil, fmt.Errorf("表名不能为空") } // 转义函数:处理单引号,移除双引号 esc := func(s string) string { s = strings.Trim(s, "\"") return strings.ReplaceAll(s, "'", "''") } // 构建查询:如果没有指定schema,使用current_schema() var query string if schema != "" { query = fmt.Sprintf(` SELECT tc.constraint_name, kcu.column_name, ccu.table_name AS foreign_table_name, ccu.column_name AS foreign_column_name FROM information_schema.table_constraints AS tc JOIN information_schema.key_column_usage AS kcu ON tc.constraint_name = kcu.constraint_name AND tc.table_schema = kcu.table_schema JOIN information_schema.constraint_column_usage AS ccu ON ccu.constraint_name = tc.constraint_name AND ccu.table_schema = tc.table_schema WHERE tc.constraint_type = 'FOREIGN KEY' AND tc.table_name='%s' AND tc.table_schema='%s'`, esc(table), esc(schema)) } else { query = fmt.Sprintf(` SELECT tc.constraint_name, kcu.column_name, ccu.table_name AS foreign_table_name, ccu.column_name AS foreign_column_name FROM information_schema.table_constraints AS tc JOIN information_schema.key_column_usage AS kcu ON tc.constraint_name = kcu.constraint_name AND tc.table_schema = kcu.table_schema JOIN information_schema.constraint_column_usage AS ccu ON ccu.constraint_name = tc.constraint_name AND ccu.table_schema = tc.table_schema WHERE tc.constraint_type = 'FOREIGN KEY' AND tc.table_name='%s' AND tc.table_schema=current_schema()`, esc(table)) } data, _, err := k.Query(query) if err != nil { return nil, err } var fks []connection.ForeignKeyDefinition for _, row := range data { fk := connection.ForeignKeyDefinition{ Name: fmt.Sprintf("%v", row["constraint_name"]), ColumnName: fmt.Sprintf("%v", row["column_name"]), RefTableName: fmt.Sprintf("%v", row["foreign_table_name"]), RefColumnName: fmt.Sprintf("%v", row["foreign_column_name"]), ConstraintName: fmt.Sprintf("%v", row["constraint_name"]), } fks = append(fks, fk) } return fks, nil } func (k *KingbaseDB) GetTriggers(dbName, tableName string) ([]connection.TriggerDefinition, error) { // 解析 schema.table 格式 schema := strings.TrimSpace(dbName) table := strings.TrimSpace(tableName) // 如果 tableName 包含 schema (格式: schema.table) if parts := strings.SplitN(table, ".", 2); len(parts) == 2 { parsedSchema := strings.TrimSpace(parts[0]) parsedTable := strings.TrimSpace(parts[1]) if parsedSchema != "" && parsedTable != "" { schema = parsedSchema table = parsedTable } } if table == "" { return nil, fmt.Errorf("表名不能为空") } // 转义函数:处理单引号,移除双引号 esc := func(s string) string { s = strings.Trim(s, "\"") return strings.ReplaceAll(s, "'", "''") } // 构建查询:如果指定了schema,也加上schema条件 var query string if schema != "" { query = fmt.Sprintf(`SELECT trigger_name, action_timing, event_manipulation FROM information_schema.triggers WHERE event_object_table = '%s' AND event_object_schema = '%s'`, esc(table), esc(schema)) } else { query = fmt.Sprintf(`SELECT trigger_name, action_timing, event_manipulation FROM information_schema.triggers WHERE event_object_table = '%s' AND event_object_schema = current_schema()`, esc(table)) } data, _, err := k.Query(query) if err != nil { return nil, err } var triggers []connection.TriggerDefinition for _, row := range data { trig := connection.TriggerDefinition{ Name: fmt.Sprintf("%v", row["trigger_name"]), Timing: fmt.Sprintf("%v", row["action_timing"]), Event: fmt.Sprintf("%v", row["event_manipulation"]), Statement: "SOURCE HIDDEN", } triggers = append(triggers, trig) } return triggers, nil } func (k *KingbaseDB) ApplyChanges(tableName string, changes connection.ChangeSet) error { if k.conn == nil { return fmt.Errorf("连接未打开") } tx, err := k.conn.Begin() if err != nil { return err } defer tx.Rollback() schema, table := splitKingbaseQualifiedTable(tableName) if table == "" { return fmt.Errorf("表名不能为空") } qualifiedTable := "" if schema != "" { qualifiedTable = fmt.Sprintf("%s.%s", quoteKingbaseIdent(schema), quoteKingbaseIdent(table)) } else { qualifiedTable = quoteKingbaseIdent(table) } // 1. Deletes for _, pk := range changes.Deletes { var wheres []string var args []interface{} idx := 0 for k, v := range pk { idx++ wheres = append(wheres, fmt.Sprintf("%s = $%d", quoteKingbaseIdent(k), idx)) args = append(args, v) } if len(wheres) == 0 { continue } query := fmt.Sprintf("DELETE FROM %s WHERE %s", qualifiedTable, strings.Join(wheres, " AND ")) if _, err := tx.Exec(query, args...); err != nil { return fmt.Errorf("delete error: %v; sql=%s", err, query) } } // 2. Updates for _, update := range changes.Updates { var sets []string var args []interface{} idx := 0 for k, v := range update.Values { idx++ sets = append(sets, fmt.Sprintf("%s = $%d", quoteKingbaseIdent(k), idx)) args = append(args, v) } if len(sets) == 0 { continue } var wheres []string for k, v := range update.Keys { idx++ wheres = append(wheres, fmt.Sprintf("%s = $%d", quoteKingbaseIdent(k), idx)) args = append(args, v) } if len(wheres) == 0 { return fmt.Errorf("更新操作需要主键条件") } query := fmt.Sprintf("UPDATE %s SET %s WHERE %s", qualifiedTable, strings.Join(sets, ", "), strings.Join(wheres, " AND ")) if _, err := tx.Exec(query, args...); err != nil { return fmt.Errorf("update error: %v; sql=%s", err, query) } } // 3. Inserts for _, row := range changes.Inserts { var cols []string var placeholders []string var args []interface{} idx := 0 for k, v := range row { idx++ cols = append(cols, quoteKingbaseIdent(k)) placeholders = append(placeholders, fmt.Sprintf("$%d", idx)) args = append(args, v) } if len(cols) == 0 { continue } query := fmt.Sprintf("INSERT INTO %s (%s) VALUES (%s)", qualifiedTable, strings.Join(cols, ", "), strings.Join(placeholders, ", ")) if _, err := tx.Exec(query, args...); err != nil { return fmt.Errorf("插入失败:%v; sql=%s", err, query) } } return tx.Commit() } func normalizeKingbaseIdentifier(raw string) string { return normalizeKingbaseIdentCommon(raw) } // kingbaseIdentNeedsQuote 判断标识符是否需要双引号包裹。 // 与前端 sql.ts 中 needsQuote 逻辑保持一致。 func kingbaseIdentNeedsQuote(ident string) bool { if ident == "" { return false } // 不是合法裸标识符格式(必须以字母或下划线开头,仅含字母、数字、下划线) if matched, _ := regexp.MatchString(`^[a-zA-Z_][a-zA-Z0-9_]*$`, ident); !matched { return true } // 包含大写字母时需要引号保护(KingbaseES/PostgreSQL 默认将未加引号的标识符折叠为小写) for _, r := range ident { if r >= 'A' && r <= 'Z' { return true } } // 是 SQL 保留字 return isKingbaseReservedWord(ident) } // isKingbaseReservedWord 检查是否为常见 SQL 保留字(简化版,与前端保持一致)。 func isKingbaseReservedWord(ident string) bool { switch strings.ToLower(ident) { case "select", "from", "where", "table", "index", "user", "order", "group", "by", "limit", "offset", "and", "or", "not", "null", "true", "false", "key", "primary", "foreign", "references", "default", "constraint", "create", "drop", "alter", "insert", "update", "delete", "set", "values", "into", "join", "left", "right", "inner", "outer", "on", "as", "is", "in", "like", "between", "case", "when", "then", "else", "end", "having", "distinct", "all", "any", "exists", "union", "except", "intersect", "column", "check", "unique", "with", "grant", "revoke", "trigger", "begin", "commit", "rollback", "schema", "database", "view", "function", "procedure", "sequence", "type", "domain", "role", "session", "current", "authorization", "cross", "full", "natural", "some", "cast", "fetch", "for", "to", "do", "if", "return", "returns", "declare", "cursor", "server", "owner": return true } return false } func quoteKingbaseIdent(name string) string { n := normalizeKingbaseIdentifier(name) if n == "" { return "\"\"" } if !kingbaseIdentNeedsQuote(n) { return n } n = strings.ReplaceAll(n, `"`, `""`) return `"` + n + `"` } func splitKingbaseQualifiedTable(tableName string) (schema string, table string) { return splitKingbaseQualifiedNameCommon(tableName) } func (k *KingbaseDB) GetAllColumns(dbName string) ([]connection.ColumnDefinitionWithTable, error) { // dbName 在本项目语义里是“数据库”,schema 由 table_schema 决定;这里返回全部用户 schema 的列用于查询提示。 query := ` SELECT table_schema, table_name, column_name, data_type FROM information_schema.columns WHERE table_schema NOT IN ('pg_catalog', 'information_schema') AND table_schema NOT LIKE 'pg|_%' ESCAPE '|' ORDER BY table_schema, table_name, ordinal_position` data, _, err := k.Query(query) if err != nil { return nil, err } var cols []connection.ColumnDefinitionWithTable for _, row := range data { schema := fmt.Sprintf("%v", row["table_schema"]) table := fmt.Sprintf("%v", row["table_name"]) tableName := table if strings.TrimSpace(schema) != "" { tableName = fmt.Sprintf("%s.%s", schema, table) } col := connection.ColumnDefinitionWithTable{ TableName: tableName, Name: fmt.Sprintf("%v", row["column_name"]), Type: fmt.Sprintf("%v", row["data_type"]), } cols = append(cols, col) } return cols, nil }