//go:build gonavi_full_drivers || gonavi_kingbase_driver package db import ( "context" "database/sql" "fmt" "net" "net/url" "sort" "strconv" "strings" "time" "GoNavi-Wails/internal/connection" "GoNavi-Wails/internal/logger" "GoNavi-Wails/internal/ssh" "GoNavi-Wails/internal/utils" _ "gitea.com/kingbase/gokb" // Registers "kingbase" driver ) type KingbaseDB struct { conn *sql.DB pingTimeout time.Duration forwarder *ssh.LocalForwarder // Store SSH tunnel forwarder } func quoteConnValue(v string) string { if v == "" { return "''" } needsQuote := false for _, r := range v { switch r { case ' ', '\t', '\n', '\r', '\v', '\f', '\'', '\\': needsQuote = true } if needsQuote { break } } if !needsQuote { return v } var b strings.Builder b.Grow(len(v) + 2) b.WriteByte('\'') for _, r := range v { if r == '\\' || r == '\'' { b.WriteByte('\\') } b.WriteRune(r) } b.WriteByte('\'') return b.String() } func (k *KingbaseDB) getDSN(config connection.ConnectionConfig) string { // Kingbase DSN usually similar to Postgres: // host=localhost port=54321 user=system password=... dbname=TEST sslmode=disable params := url.Values{} params.Set("host", config.Host) params.Set("port", strconv.Itoa(config.Port)) params.Set("user", config.User) params.Set("password", config.Password) params.Set("dbname", config.Database) params.Set("sslmode", resolvePostgresSSLMode(config)) params.Set("connect_timeout", strconv.Itoa(getConnectTimeoutSeconds(config))) mergeConnectionParamsFromConfigWithAllowlist(params, config, kingbaseConnectionParamNames, "kingbase") preferred := []string{"host", "port", "user", "password", "dbname", "sslmode", "connect_timeout"} seen := make(map[string]struct{}, len(params)) parts := make([]string, 0, len(params)) for _, key := range preferred { if values, ok := params[key]; ok && len(values) > 0 { parts = append(parts, fmt.Sprintf("%s=%s", key, quoteConnValue(values[len(values)-1]))) seen[key] = struct{}{} } } extraKeys := make([]string, 0, len(params)) for key := range params { if _, ok := seen[key]; ok || !isSafeConnectionParamKey(key) { continue } extraKeys = append(extraKeys, key) } sort.Strings(extraKeys) for _, key := range extraKeys { values := params[key] if len(values) == 0 { continue } parts = append(parts, fmt.Sprintf("%s=%s", key, quoteConnValue(values[len(values)-1]))) } return strings.Join(parts, " ") } func (k *KingbaseDB) Connect(config connection.ConnectionConfig) error { runConfig := config if config.UseSSH { // Create SSH tunnel with local port forwarding logger.Infof("人大金仓使用 SSH 连接:地址=%s:%d 用户=%s", config.Host, config.Port, config.User) forwarder, err := ssh.GetOrCreateLocalForwarder(config.SSH, config.Host, config.Port) if err != nil { return fmt.Errorf("创建 SSH 隧道失败:%w", err) } k.forwarder = forwarder // Parse local address host, portStr, err := net.SplitHostPort(forwarder.LocalAddr) if err != nil { return fmt.Errorf("解析本地转发地址失败:%w", err) } port, err := strconv.Atoi(portStr) if err != nil { return fmt.Errorf("解析本地端口失败:%w", err) } // Create a modified config pointing to local forwarder localConfig := config localConfig.Host = host localConfig.Port = port localConfig.UseSSH = false runConfig = localConfig logger.Infof("人大金仓通过本地端口转发连接:%s -> %s:%d", forwarder.LocalAddr, config.Host, config.Port) } attempts := []connection.ConnectionConfig{runConfig} if shouldTrySSLPreferredFallback(runConfig) { attempts = append(attempts, withSSLDisabled(runConfig)) } var failures []string for idx, attempt := range attempts { dsn := k.getDSN(attempt) db, err := sql.Open("kingbase", dsn) if err != nil { failures = append(failures, fmt.Sprintf("第%d次连接打开失败: %v", idx+1, err)) continue } k.conn = db k.pingTimeout = getConnectTimeout(attempt) if err := k.Ping(); err != nil { _ = db.Close() k.conn = nil failures = append(failures, fmt.Sprintf("第%d次连接验证失败: %v", idx+1, err)) continue } if idx > 0 { logger.Warnf("人大金仓 SSL 优先连接失败,已回退至明文连接") } // 获取 schema 列表以重构带有 search_path 的连接池 searchPathStr := k.getSearchPathStr() if searchPathStr != "" { // 将 search_path 参数拼入 DSN finalDSN := dsn + " search_path=" + quoteConnValue(searchPathStr) if finalDB, err := sql.Open("kingbase", finalDSN); err == nil { k.pingTimeout = getConnectTimeout(attempt) finalDB.SetConnMaxLifetime(5 * time.Minute) // 临时将 k.conn 指向 finalDB 来做 ping 测试 oldConn := k.conn k.conn = finalDB if err := k.Ping(); err == nil { // 成功使用带 search_path 的连接池 _ = oldConn.Close() logger.Infof("人大金仓已配置连接级 search_path:%s", searchPathStr) } else { _ = finalDB.Close() k.conn = oldConn } } } if searchPathStr != "" { timeout := k.pingTimeout if timeout <= 0 { timeout = 5 * time.Second } ctx, cancel := utils.ContextWithTimeout(timeout) defer cancel() if _, err := k.conn.ExecContext(ctx, fmt.Sprintf("SET search_path TO %s", searchPathStr)); err != nil { logger.Warnf("人大金仓显式设置 search_path 失败:%v", err) } else { logger.Infof("人大金仓已设置默认 search_path:%s", searchPathStr) } } return nil } return fmt.Errorf("连接建立后验证失败:%s", strings.Join(failures, ";")) } // getSearchPathStr 查询当前数据库中所有用户 schema,配置 DSN 的 search_path。 // KingBase 默认 search_path 为 "$user", public,对于自定义 schema 下的表不可见。 func (k *KingbaseDB) getSearchPathStr() string { if k.conn == nil { return "" } query := `SELECT nspname FROM pg_namespace WHERE nspname NOT IN ('pg_catalog', 'information_schema') AND nspname NOT LIKE 'pg|_%' ESCAPE '|' ORDER BY nspname` rows, err := k.conn.Query(query) if err != nil { logger.Warnf("人大金仓查询用户 schema 失败,跳过 search_path 设置:%v", err) return "" } defer rows.Close() var rawSchemas []string for rows.Next() { var name string if err := rows.Scan(&name); err != nil { continue } name = strings.TrimSpace(name) if name != "" { rawSchemas = append(rawSchemas, name) } } searchPath, _ := buildKingbaseSearchPathCommon(rawSchemas) return searchPath } func (k *KingbaseDB) Close() error { // Close SSH forwarder first if exists if k.forwarder != nil { if err := k.forwarder.Close(); err != nil { logger.Warnf("关闭人大金仓 SSH 端口转发失败:%v", err) } k.forwarder = nil } // Then close database connection if k.conn != nil { return k.conn.Close() } return nil } func (k *KingbaseDB) Ping() error { if k.conn == nil { return fmt.Errorf("连接未打开") } timeout := k.pingTimeout if timeout <= 0 { timeout = 5 * time.Second } ctx, cancel := utils.ContextWithTimeout(timeout) defer cancel() return k.conn.PingContext(ctx) } func (k *KingbaseDB) QueryContext(ctx context.Context, query string) ([]map[string]interface{}, []string, error) { if k.conn == nil { return nil, nil, fmt.Errorf("连接未打开") } rows, err := k.conn.QueryContext(ctx, query) if err != nil { return nil, nil, err } defer rows.Close() return scanRows(rows) } func (k *KingbaseDB) Query(query string) ([]map[string]interface{}, []string, error) { if k.conn == nil { return nil, nil, fmt.Errorf("连接未打开") } rows, err := k.conn.Query(query) if err != nil { return nil, nil, err } defer rows.Close() return scanRows(rows) } func (k *KingbaseDB) ExecContext(ctx context.Context, query string) (int64, error) { if k.conn == nil { return 0, fmt.Errorf("连接未打开") } res, err := k.conn.ExecContext(ctx, query) if err != nil { return 0, err } return res.RowsAffected() } func (k *KingbaseDB) ExecBatchContext(ctx context.Context, query string) (int64, error) { if k.conn == nil { return 0, fmt.Errorf("连接未打开") } res, err := k.conn.ExecContext(ctx, query) if err != nil { return 0, err } return res.RowsAffected() } func (k *KingbaseDB) Exec(query string) (int64, error) { if k.conn == nil { return 0, fmt.Errorf("连接未打开") } res, err := k.conn.Exec(query) if err != nil { return 0, err } return res.RowsAffected() } func (k *KingbaseDB) GetDatabases() ([]string, error) { data, _, err := k.Query("SELECT datname FROM pg_database WHERE datistemplate = false") if err == nil { dbs := collectKingbaseNames(data, "datname", "database") if len(dbs) > 0 { return dbs, nil } } fallbackData, _, fallbackErr := k.Query("SELECT current_database() AS datname") if fallbackErr != nil { if err != nil { return nil, err } return nil, fallbackErr } dbs := collectKingbaseNames(fallbackData, "datname", "database", "current_database", "currentDatabase") if len(dbs) > 0 { return dbs, nil } if err != nil { return nil, err } return nil, fmt.Errorf("未获取到可见数据库列表") } func collectKingbaseNames(rows []map[string]interface{}, keys ...string) []string { result := make([]string, 0, len(rows)) seen := make(map[string]struct{}, len(rows)) for _, row := range rows { name := strings.TrimSpace(getKingbaseNameFromRow(row, keys...)) if name == "" { continue } if _, exists := seen[name]; exists { continue } seen[name] = struct{}{} result = append(result, name) } return result } func getKingbaseNameFromRow(row map[string]interface{}, keys ...string) string { if len(row) == 0 { return "" } for _, key := range keys { if value, ok := row[key]; ok { return fmt.Sprintf("%v", value) } } for existingKey, value := range row { for _, key := range keys { if strings.EqualFold(existingKey, key) { return fmt.Sprintf("%v", value) } } } for _, value := range row { return fmt.Sprintf("%v", value) } return "" } func (k *KingbaseDB) GetTables(dbName string) ([]string, error) { // Kingbase: tables are scoped by the current DB connection; include schema to avoid search_path issues. query := ` SELECT table_schema AS schemaname, table_name AS tablename FROM information_schema.tables WHERE table_type = 'BASE TABLE' AND table_schema NOT IN ('pg_catalog', 'information_schema') AND table_schema NOT LIKE 'pg|_%' ESCAPE '|' ORDER BY table_schema, table_name` data, _, err := k.Query(query) if err != nil { return nil, err } var tables []string for _, row := range data { schema, okSchema := row["schemaname"] name, okName := row["tablename"] if okSchema && okName { tables = append(tables, fmt.Sprintf("%v.%v", schema, name)) continue } if val, ok := row["table_name"]; ok { tables = append(tables, fmt.Sprintf("%v", val)) } } return tables, nil } func (k *KingbaseDB) GetCreateStatement(dbName, tableName string) (string, error) { // Kingbase doesn't have "SHOW CREATE TABLE". // We can try pg_dump logic or use a query to reconstruction. // A simple approach is just returning basic info or "Not Supported". // Or we can query information_schema to build it. return "SHOW CREATE TABLE not directly supported in Kingbase/Postgres via SQL", nil } func (k *KingbaseDB) GetColumns(dbName, tableName string) ([]connection.ColumnDefinition, error) { // 解析 schema.table 格式 schema := strings.TrimSpace(dbName) table := strings.TrimSpace(tableName) // 如果 tableName 包含 schema (格式: schema.table) if parts := strings.SplitN(table, ".", 2); len(parts) == 2 { parsedSchema := strings.TrimSpace(parts[0]) parsedTable := strings.TrimSpace(parts[1]) if parsedSchema != "" && parsedTable != "" { schema = parsedSchema table = parsedTable } } // 如果仍然没有 schema,使用 current_schema() // 这样可以自动匹配当前连接的 search_path if schema == "" { return k.getColumnsWithCurrentSchema(table) } if table == "" { return nil, fmt.Errorf("表名不能为空") } // 转义函数:处理单引号,移除双引号 esc := func(s string) string { // 移除前后的双引号(如果存在) s = strings.Trim(s, "\"") // 转义单引号 return strings.ReplaceAll(s, "'", "''") } query := fmt.Sprintf(` SELECT a.attname AS column_name, pg_catalog.format_type(a.atttypid, a.atttypmod) AS data_type, CASE WHEN a.attnotnull THEN 'NO' ELSE 'YES' END AS is_nullable, pg_get_expr(ad.adbin, ad.adrelid) AS column_default, col_description(a.attrelid, a.attnum) AS comment, CASE WHEN pk.attname IS NOT NULL THEN 'PRI' ELSE '' END AS column_key FROM pg_class c JOIN pg_namespace n ON n.oid = c.relnamespace JOIN pg_attribute a ON a.attrelid = c.oid LEFT JOIN pg_attrdef ad ON ad.adrelid = c.oid AND ad.adnum = a.attnum LEFT JOIN ( SELECT i.indrelid, a3.attname FROM pg_index i JOIN pg_attribute a3 ON a3.attrelid = i.indrelid AND a3.attnum = ANY(i.indkey) WHERE i.indisprimary ) pk ON pk.indrelid = c.oid AND pk.attname = a.attname WHERE c.relkind IN ('r', 'p') AND n.nspname = '%s' AND c.relname = '%s' AND a.attnum > 0 AND NOT a.attisdropped ORDER BY a.attnum`, esc(schema), esc(table)) data, _, err := k.Query(query) if err != nil { return nil, err } var columns []connection.ColumnDefinition for _, row := range data { col := connection.ColumnDefinition{ Name: fmt.Sprintf("%v", row["column_name"]), Type: fmt.Sprintf("%v", row["data_type"]), Nullable: fmt.Sprintf("%v", row["is_nullable"]), Key: fmt.Sprintf("%v", row["column_key"]), Extra: "", Comment: "", } if row["column_default"] != nil { def := fmt.Sprintf("%v", row["column_default"]) col.Default = &def if strings.HasPrefix(strings.ToLower(strings.TrimSpace(def)), "nextval(") { col.Extra = "auto_increment" } } if v, ok := row["comment"]; ok && v != nil { col.Comment = fmt.Sprintf("%v", v) } columns = append(columns, col) } return columns, nil } // getColumnsWithCurrentSchema 使用 current_schema() 查询当前schema的表 func (k *KingbaseDB) getColumnsWithCurrentSchema(tableName string) ([]connection.ColumnDefinition, error) { table := strings.TrimSpace(tableName) if table == "" { return nil, fmt.Errorf("表名不能为空") } // 转义函数 esc := func(s string) string { s = strings.Trim(s, "\"") return strings.ReplaceAll(s, "'", "''") } // 使用 current_schema() 获取当前schema query := fmt.Sprintf(` SELECT a.attname AS column_name, pg_catalog.format_type(a.atttypid, a.atttypmod) AS data_type, CASE WHEN a.attnotnull THEN 'NO' ELSE 'YES' END AS is_nullable, pg_get_expr(ad.adbin, ad.adrelid) AS column_default, col_description(a.attrelid, a.attnum) AS comment, CASE WHEN pk.attname IS NOT NULL THEN 'PRI' ELSE '' END AS column_key FROM pg_class c JOIN pg_namespace n ON n.oid = c.relnamespace JOIN pg_attribute a ON a.attrelid = c.oid LEFT JOIN pg_attrdef ad ON ad.adrelid = c.oid AND ad.adnum = a.attnum LEFT JOIN ( SELECT i.indrelid, a3.attname FROM pg_index i JOIN pg_attribute a3 ON a3.attrelid = i.indrelid AND a3.attnum = ANY(i.indkey) WHERE i.indisprimary ) pk ON pk.indrelid = c.oid AND pk.attname = a.attname WHERE c.relkind IN ('r', 'p') AND n.nspname = current_schema() AND c.relname = '%s' AND a.attnum > 0 AND NOT a.attisdropped ORDER BY a.attnum`, esc(table)) data, _, err := k.Query(query) if err != nil { return nil, err } var columns []connection.ColumnDefinition for _, row := range data { col := connection.ColumnDefinition{ Name: fmt.Sprintf("%v", row["column_name"]), Type: fmt.Sprintf("%v", row["data_type"]), Nullable: fmt.Sprintf("%v", row["is_nullable"]), Key: fmt.Sprintf("%v", row["column_key"]), Extra: "", Comment: "", } if row["column_default"] != nil { def := fmt.Sprintf("%v", row["column_default"]) col.Default = &def if strings.HasPrefix(strings.ToLower(strings.TrimSpace(def)), "nextval(") { col.Extra = "auto_increment" } } if v, ok := row["comment"]; ok && v != nil { col.Comment = fmt.Sprintf("%v", v) } columns = append(columns, col) } return columns, nil } func (k *KingbaseDB) GetIndexes(dbName, tableName string) ([]connection.IndexDefinition, error) { // 解析 schema.table 格式 schema := strings.TrimSpace(dbName) table := strings.TrimSpace(tableName) // 如果 tableName 包含 schema (格式: schema.table) if parts := strings.SplitN(table, ".", 2); len(parts) == 2 { parsedSchema := strings.TrimSpace(parts[0]) parsedTable := strings.TrimSpace(parts[1]) if parsedSchema != "" && parsedTable != "" { schema = parsedSchema table = parsedTable } } if table == "" { return nil, fmt.Errorf("表名不能为空") } // 转义函数:处理单引号,移除双引号 esc := func(s string) string { s = strings.Trim(s, "\"") return strings.ReplaceAll(s, "'", "''") } // 构建查询:如果没有指定schema,使用current_schema() var query string if schema != "" { query = fmt.Sprintf(` SELECT i.relname as index_name, a.attname as column_name, ix.indisunique as is_unique FROM pg_class t, pg_class i, pg_index ix, pg_attribute a, pg_namespace n WHERE t.oid = ix.indrelid AND i.oid = ix.indexrelid AND a.attrelid = t.oid AND a.attnum = ANY(ix.indkey) AND t.relkind = 'r' AND t.relname = '%s' AND n.oid = t.relnamespace AND n.nspname = '%s' `, esc(table), esc(schema)) } else { query = fmt.Sprintf(` SELECT i.relname as index_name, a.attname as column_name, ix.indisunique as is_unique FROM pg_class t, pg_class i, pg_index ix, pg_attribute a, pg_namespace n WHERE t.oid = ix.indrelid AND i.oid = ix.indexrelid AND a.attrelid = t.oid AND a.attnum = ANY(ix.indkey) AND t.relkind = 'r' AND t.relname = '%s' AND n.oid = t.relnamespace AND n.nspname = current_schema() `, esc(table)) } data, _, err := k.Query(query) if err != nil { return nil, err } var indexes []connection.IndexDefinition for _, row := range data { nonUnique := 1 if val, ok := row["is_unique"]; ok { if b, ok := val.(bool); ok && b { nonUnique = 0 } } idx := connection.IndexDefinition{ Name: fmt.Sprintf("%v", row["index_name"]), ColumnName: fmt.Sprintf("%v", row["column_name"]), NonUnique: nonUnique, IndexType: "BTREE", // Default } indexes = append(indexes, idx) } return indexes, nil } func (k *KingbaseDB) GetForeignKeys(dbName, tableName string) ([]connection.ForeignKeyDefinition, error) { // 解析 schema.table 格式 schema := strings.TrimSpace(dbName) table := strings.TrimSpace(tableName) // 如果 tableName 包含 schema (格式: schema.table) if parts := strings.SplitN(table, ".", 2); len(parts) == 2 { parsedSchema := strings.TrimSpace(parts[0]) parsedTable := strings.TrimSpace(parts[1]) if parsedSchema != "" && parsedTable != "" { schema = parsedSchema table = parsedTable } } if table == "" { return nil, fmt.Errorf("表名不能为空") } // 转义函数:处理单引号,移除双引号 esc := func(s string) string { s = strings.Trim(s, "\"") return strings.ReplaceAll(s, "'", "''") } // 构建查询:如果没有指定schema,使用current_schema() var query string if schema != "" { query = fmt.Sprintf(` SELECT tc.constraint_name, kcu.column_name, ccu.table_name AS foreign_table_name, ccu.column_name AS foreign_column_name FROM information_schema.table_constraints AS tc JOIN information_schema.key_column_usage AS kcu ON tc.constraint_name = kcu.constraint_name AND tc.table_schema = kcu.table_schema JOIN information_schema.constraint_column_usage AS ccu ON ccu.constraint_name = tc.constraint_name AND ccu.table_schema = tc.table_schema WHERE tc.constraint_type = 'FOREIGN KEY' AND tc.table_name='%s' AND tc.table_schema='%s'`, esc(table), esc(schema)) } else { query = fmt.Sprintf(` SELECT tc.constraint_name, kcu.column_name, ccu.table_name AS foreign_table_name, ccu.column_name AS foreign_column_name FROM information_schema.table_constraints AS tc JOIN information_schema.key_column_usage AS kcu ON tc.constraint_name = kcu.constraint_name AND tc.table_schema = kcu.table_schema JOIN information_schema.constraint_column_usage AS ccu ON ccu.constraint_name = tc.constraint_name AND ccu.table_schema = tc.table_schema WHERE tc.constraint_type = 'FOREIGN KEY' AND tc.table_name='%s' AND tc.table_schema=current_schema()`, esc(table)) } data, _, err := k.Query(query) if err != nil { return nil, err } var fks []connection.ForeignKeyDefinition for _, row := range data { fk := connection.ForeignKeyDefinition{ Name: fmt.Sprintf("%v", row["constraint_name"]), ColumnName: fmt.Sprintf("%v", row["column_name"]), RefTableName: fmt.Sprintf("%v", row["foreign_table_name"]), RefColumnName: fmt.Sprintf("%v", row["foreign_column_name"]), ConstraintName: fmt.Sprintf("%v", row["constraint_name"]), } fks = append(fks, fk) } return fks, nil } func (k *KingbaseDB) GetTriggers(dbName, tableName string) ([]connection.TriggerDefinition, error) { // 解析 schema.table 格式 schema := strings.TrimSpace(dbName) table := strings.TrimSpace(tableName) // 如果 tableName 包含 schema (格式: schema.table) if parts := strings.SplitN(table, ".", 2); len(parts) == 2 { parsedSchema := strings.TrimSpace(parts[0]) parsedTable := strings.TrimSpace(parts[1]) if parsedSchema != "" && parsedTable != "" { schema = parsedSchema table = parsedTable } } if table == "" { return nil, fmt.Errorf("表名不能为空") } // 转义函数:处理单引号,移除双引号 esc := func(s string) string { s = strings.Trim(s, "\"") return strings.ReplaceAll(s, "'", "''") } // 构建查询:如果指定了schema,也加上schema条件 var query string if schema != "" { query = fmt.Sprintf(`SELECT trigger_name, action_timing, event_manipulation FROM information_schema.triggers WHERE event_object_table = '%s' AND event_object_schema = '%s'`, esc(table), esc(schema)) } else { query = fmt.Sprintf(`SELECT trigger_name, action_timing, event_manipulation FROM information_schema.triggers WHERE event_object_table = '%s' AND event_object_schema = current_schema()`, esc(table)) } data, _, err := k.Query(query) if err != nil { return nil, err } var triggers []connection.TriggerDefinition for _, row := range data { trig := connection.TriggerDefinition{ Name: fmt.Sprintf("%v", row["trigger_name"]), Timing: fmt.Sprintf("%v", row["action_timing"]), Event: fmt.Sprintf("%v", row["event_manipulation"]), Statement: "SOURCE HIDDEN", } triggers = append(triggers, trig) } return triggers, nil } func (k *KingbaseDB) ApplyChanges(tableName string, changes connection.ChangeSet) error { if k.conn == nil { return fmt.Errorf("连接未打开") } tx, err := k.conn.Begin() if err != nil { return err } defer tx.Rollback() schema, table := splitKingbaseQualifiedTable(tableName) if table == "" { return fmt.Errorf("表名不能为空") } qualifiedTable := "" if schema != "" { qualifiedTable = fmt.Sprintf("%s.%s", quoteKingbaseIdent(schema), quoteKingbaseIdent(table)) } else { qualifiedTable = quoteKingbaseIdent(table) } // 1. Deletes for _, pk := range changes.Deletes { var wheres []string var args []interface{} idx := 0 for k, v := range pk { idx++ wheres = append(wheres, fmt.Sprintf("%s = $%d", quoteKingbaseIdent(k), idx)) args = append(args, v) } if len(wheres) == 0 { continue } query := fmt.Sprintf("DELETE FROM %s WHERE %s", qualifiedTable, strings.Join(wheres, " AND ")) if _, err := tx.Exec(query, args...); err != nil { return fmt.Errorf("delete error: %v; sql=%s", err, query) } } // 2. Updates for _, update := range changes.Updates { var sets []string var args []interface{} idx := 0 for k, v := range update.Values { idx++ sets = append(sets, fmt.Sprintf("%s = $%d", quoteKingbaseIdent(k), idx)) args = append(args, v) } if len(sets) == 0 { continue } var wheres []string for k, v := range update.Keys { idx++ wheres = append(wheres, fmt.Sprintf("%s = $%d", quoteKingbaseIdent(k), idx)) args = append(args, v) } if len(wheres) == 0 { return fmt.Errorf("更新操作需要主键条件") } query := fmt.Sprintf("UPDATE %s SET %s WHERE %s", qualifiedTable, strings.Join(sets, ", "), strings.Join(wheres, " AND ")) if _, err := tx.Exec(query, args...); err != nil { return fmt.Errorf("update error: %v; sql=%s", err, query) } } // 3. Inserts for _, row := range changes.Inserts { var cols []string var placeholders []string var args []interface{} idx := 0 for k, v := range row { idx++ cols = append(cols, quoteKingbaseIdent(k)) placeholders = append(placeholders, fmt.Sprintf("$%d", idx)) args = append(args, v) } if len(cols) == 0 { continue } query := fmt.Sprintf("INSERT INTO %s (%s) VALUES (%s)", qualifiedTable, strings.Join(cols, ", "), strings.Join(placeholders, ", ")) if _, err := tx.Exec(query, args...); err != nil { return fmt.Errorf("插入失败:%v; sql=%s", err, query) } } return tx.Commit() } func splitKingbaseQualifiedTable(tableName string) (schema string, table string) { return splitKingbaseQualifiedNameCommon(tableName) } func (k *KingbaseDB) GetAllColumns(dbName string) ([]connection.ColumnDefinitionWithTable, error) { // dbName 在本项目语义里是“数据库”,schema 由 table_schema 决定;这里返回全部用户 schema 的列用于查询提示。 query := ` SELECT table_schema, table_name, column_name, data_type FROM information_schema.columns WHERE table_schema NOT IN ('pg_catalog', 'information_schema') AND table_schema NOT LIKE 'pg|_%' ESCAPE '|' ORDER BY table_schema, table_name, ordinal_position` data, _, err := k.Query(query) if err != nil { return nil, err } var cols []connection.ColumnDefinitionWithTable for _, row := range data { schema := fmt.Sprintf("%v", row["table_schema"]) table := fmt.Sprintf("%v", row["table_name"]) tableName := table if strings.TrimSpace(schema) != "" { tableName = fmt.Sprintf("%s.%s", schema, table) } col := connection.ColumnDefinitionWithTable{ TableName: tableName, Name: fmt.Sprintf("%v", row["column_name"]), Type: fmt.Sprintf("%v", row["data_type"]), } cols = append(cols, col) } return cols, nil }