Files
MyGoNavi/internal/db/kingbase_impl.go
2026-04-11 21:53:52 +08:00

1004 lines
27 KiB
Go
Raw Permalink Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
//go:build gonavi_full_drivers || gonavi_kingbase_driver
package db
import (
"context"
"database/sql"
"fmt"
"net"
"regexp"
"strconv"
"strings"
"time"
"GoNavi-Wails/internal/connection"
"GoNavi-Wails/internal/logger"
"GoNavi-Wails/internal/ssh"
"GoNavi-Wails/internal/utils"
_ "gitea.com/kingbase/gokb" // Registers "kingbase" driver
)
type KingbaseDB struct {
conn *sql.DB
pingTimeout time.Duration
forwarder *ssh.LocalForwarder // Store SSH tunnel forwarder
}
func quoteConnValue(v string) string {
if v == "" {
return "''"
}
needsQuote := false
for _, r := range v {
switch r {
case ' ', '\t', '\n', '\r', '\v', '\f', '\'', '\\':
needsQuote = true
}
if needsQuote {
break
}
}
if !needsQuote {
return v
}
var b strings.Builder
b.Grow(len(v) + 2)
b.WriteByte('\'')
for _, r := range v {
if r == '\\' || r == '\'' {
b.WriteByte('\\')
}
b.WriteRune(r)
}
b.WriteByte('\'')
return b.String()
}
func (k *KingbaseDB) getDSN(config connection.ConnectionConfig) string {
// Kingbase DSN usually similar to Postgres:
// host=localhost port=54321 user=system password=... dbname=TEST sslmode=disable
address := config.Host
port := config.Port
// Construct DSN
dsn := fmt.Sprintf("host=%s port=%d user=%s password=%s dbname=%s sslmode=%s connect_timeout=%d",
quoteConnValue(address),
port,
quoteConnValue(config.User),
quoteConnValue(config.Password),
quoteConnValue(config.Database),
quoteConnValue(resolvePostgresSSLMode(config)),
getConnectTimeoutSeconds(config),
)
return dsn
}
func (k *KingbaseDB) Connect(config connection.ConnectionConfig) error {
runConfig := config
if config.UseSSH {
// Create SSH tunnel with local port forwarding
logger.Infof("人大金仓使用 SSH 连接:地址=%s:%d 用户=%s", config.Host, config.Port, config.User)
forwarder, err := ssh.GetOrCreateLocalForwarder(config.SSH, config.Host, config.Port)
if err != nil {
return fmt.Errorf("创建 SSH 隧道失败:%w", err)
}
k.forwarder = forwarder
// Parse local address
host, portStr, err := net.SplitHostPort(forwarder.LocalAddr)
if err != nil {
return fmt.Errorf("解析本地转发地址失败:%w", err)
}
port, err := strconv.Atoi(portStr)
if err != nil {
return fmt.Errorf("解析本地端口失败:%w", err)
}
// Create a modified config pointing to local forwarder
localConfig := config
localConfig.Host = host
localConfig.Port = port
localConfig.UseSSH = false
runConfig = localConfig
logger.Infof("人大金仓通过本地端口转发连接:%s -> %s:%d", forwarder.LocalAddr, config.Host, config.Port)
}
attempts := []connection.ConnectionConfig{runConfig}
if shouldTrySSLPreferredFallback(runConfig) {
attempts = append(attempts, withSSLDisabled(runConfig))
}
var failures []string
for idx, attempt := range attempts {
dsn := k.getDSN(attempt)
db, err := sql.Open("kingbase", dsn)
if err != nil {
failures = append(failures, fmt.Sprintf("第%d次连接打开失败: %v", idx+1, err))
continue
}
k.conn = db
k.pingTimeout = getConnectTimeout(attempt)
if err := k.Ping(); err != nil {
_ = db.Close()
k.conn = nil
failures = append(failures, fmt.Sprintf("第%d次连接验证失败: %v", idx+1, err))
continue
}
if idx > 0 {
logger.Warnf("人大金仓 SSL 优先连接失败,已回退至明文连接")
}
// 获取 schema 列表以重构带有 search_path 的连接池
searchPathStr := k.getSearchPathStr()
if searchPathStr != "" {
// 将 search_path 参数拼入 DSN
finalDSN := dsn + " search_path=" + quoteConnValue(searchPathStr)
if finalDB, err := sql.Open("kingbase", finalDSN); err == nil {
k.pingTimeout = getConnectTimeout(attempt)
finalDB.SetConnMaxLifetime(5 * time.Minute)
// 临时将 k.conn 指向 finalDB 来做 ping 测试
oldConn := k.conn
k.conn = finalDB
if err := k.Ping(); err == nil {
// 成功使用带 search_path 的连接池
_ = oldConn.Close()
logger.Infof("人大金仓已配置连接级 search_path%s", searchPathStr)
} else {
_ = finalDB.Close()
k.conn = oldConn
}
}
}
if searchPathStr != "" {
timeout := k.pingTimeout
if timeout <= 0 {
timeout = 5 * time.Second
}
ctx, cancel := utils.ContextWithTimeout(timeout)
defer cancel()
if _, err := k.conn.ExecContext(ctx, fmt.Sprintf("SET search_path TO %s", searchPathStr)); err != nil {
logger.Warnf("人大金仓显式设置 search_path 失败:%v", err)
} else {
logger.Infof("人大金仓已设置默认 search_path%s", searchPathStr)
}
}
return nil
}
return fmt.Errorf("连接建立后验证失败:%s", strings.Join(failures, ""))
}
// getSearchPathStr 查询当前数据库中所有用户 schema配置 DSN 的 search_path。
// KingBase 默认 search_path 为 "$user", public对于自定义 schema 下的表不可见。
func (k *KingbaseDB) getSearchPathStr() string {
if k.conn == nil {
return ""
}
query := `SELECT nspname FROM pg_namespace
WHERE nspname NOT IN ('pg_catalog', 'information_schema')
AND nspname NOT LIKE 'pg_%'
ORDER BY nspname`
rows, err := k.conn.Query(query)
if err != nil {
logger.Warnf("人大金仓查询用户 schema 失败,跳过 search_path 设置:%v", err)
return ""
}
defer rows.Close()
var rawSchemas []string
for rows.Next() {
var name string
if err := rows.Scan(&name); err != nil {
continue
}
name = strings.TrimSpace(name)
if name != "" {
rawSchemas = append(rawSchemas, name)
}
}
searchPath, _ := buildKingbaseSearchPathCommon(rawSchemas)
return searchPath
}
func (k *KingbaseDB) Close() error {
// Close SSH forwarder first if exists
if k.forwarder != nil {
if err := k.forwarder.Close(); err != nil {
logger.Warnf("关闭人大金仓 SSH 端口转发失败:%v", err)
}
k.forwarder = nil
}
// Then close database connection
if k.conn != nil {
return k.conn.Close()
}
return nil
}
func (k *KingbaseDB) Ping() error {
if k.conn == nil {
return fmt.Errorf("连接未打开")
}
timeout := k.pingTimeout
if timeout <= 0 {
timeout = 5 * time.Second
}
ctx, cancel := utils.ContextWithTimeout(timeout)
defer cancel()
return k.conn.PingContext(ctx)
}
func (k *KingbaseDB) QueryContext(ctx context.Context, query string) ([]map[string]interface{}, []string, error) {
if k.conn == nil {
return nil, nil, fmt.Errorf("连接未打开")
}
rows, err := k.conn.QueryContext(ctx, query)
if err != nil {
return nil, nil, err
}
defer rows.Close()
return scanRows(rows)
}
func (k *KingbaseDB) Query(query string) ([]map[string]interface{}, []string, error) {
if k.conn == nil {
return nil, nil, fmt.Errorf("连接未打开")
}
rows, err := k.conn.Query(query)
if err != nil {
return nil, nil, err
}
defer rows.Close()
return scanRows(rows)
}
func (k *KingbaseDB) ExecContext(ctx context.Context, query string) (int64, error) {
if k.conn == nil {
return 0, fmt.Errorf("连接未打开")
}
res, err := k.conn.ExecContext(ctx, query)
if err != nil {
return 0, err
}
return res.RowsAffected()
}
func (k *KingbaseDB) ExecBatchContext(ctx context.Context, query string) (int64, error) {
if k.conn == nil {
return 0, fmt.Errorf("连接未打开")
}
res, err := k.conn.ExecContext(ctx, query)
if err != nil {
return 0, err
}
return res.RowsAffected()
}
func (k *KingbaseDB) Exec(query string) (int64, error) {
if k.conn == nil {
return 0, fmt.Errorf("连接未打开")
}
res, err := k.conn.Exec(query)
if err != nil {
return 0, err
}
return res.RowsAffected()
}
func (k *KingbaseDB) GetDatabases() ([]string, error) {
data, _, err := k.Query("SELECT datname FROM pg_database WHERE datistemplate = false")
if err == nil {
dbs := collectKingbaseNames(data, "datname", "database")
if len(dbs) > 0 {
return dbs, nil
}
}
fallbackData, _, fallbackErr := k.Query("SELECT current_database() AS datname")
if fallbackErr != nil {
if err != nil {
return nil, err
}
return nil, fallbackErr
}
dbs := collectKingbaseNames(fallbackData, "datname", "database", "current_database", "currentDatabase")
if len(dbs) > 0 {
return dbs, nil
}
if err != nil {
return nil, err
}
return nil, fmt.Errorf("未获取到可见数据库列表")
}
func collectKingbaseNames(rows []map[string]interface{}, keys ...string) []string {
result := make([]string, 0, len(rows))
seen := make(map[string]struct{}, len(rows))
for _, row := range rows {
name := strings.TrimSpace(getKingbaseNameFromRow(row, keys...))
if name == "" {
continue
}
if _, exists := seen[name]; exists {
continue
}
seen[name] = struct{}{}
result = append(result, name)
}
return result
}
func getKingbaseNameFromRow(row map[string]interface{}, keys ...string) string {
if len(row) == 0 {
return ""
}
for _, key := range keys {
if value, ok := row[key]; ok {
return fmt.Sprintf("%v", value)
}
}
for existingKey, value := range row {
for _, key := range keys {
if strings.EqualFold(existingKey, key) {
return fmt.Sprintf("%v", value)
}
}
}
for _, value := range row {
return fmt.Sprintf("%v", value)
}
return ""
}
func (k *KingbaseDB) GetTables(dbName string) ([]string, error) {
// Kingbase: tables are scoped by the current DB connection; include schema to avoid search_path issues.
query := `
SELECT table_schema AS schemaname, table_name AS tablename
FROM information_schema.tables
WHERE table_type = 'BASE TABLE'
AND table_schema NOT IN ('pg_catalog', 'information_schema')
AND table_schema NOT LIKE 'pg_%'
ORDER BY table_schema, table_name`
data, _, err := k.Query(query)
if err != nil {
return nil, err
}
var tables []string
for _, row := range data {
schema, okSchema := row["schemaname"]
name, okName := row["tablename"]
if okSchema && okName {
tables = append(tables, fmt.Sprintf("%v.%v", schema, name))
continue
}
if val, ok := row["table_name"]; ok {
tables = append(tables, fmt.Sprintf("%v", val))
}
}
return tables, nil
}
func (k *KingbaseDB) GetCreateStatement(dbName, tableName string) (string, error) {
// Kingbase doesn't have "SHOW CREATE TABLE".
// We can try pg_dump logic or use a query to reconstruction.
// A simple approach is just returning basic info or "Not Supported".
// Or we can query information_schema to build it.
return "SHOW CREATE TABLE not directly supported in Kingbase/Postgres via SQL", nil
}
func (k *KingbaseDB) GetColumns(dbName, tableName string) ([]connection.ColumnDefinition, error) {
// 解析 schema.table 格式
schema := strings.TrimSpace(dbName)
table := strings.TrimSpace(tableName)
// 如果 tableName 包含 schema (格式: schema.table)
if parts := strings.SplitN(table, ".", 2); len(parts) == 2 {
parsedSchema := strings.TrimSpace(parts[0])
parsedTable := strings.TrimSpace(parts[1])
if parsedSchema != "" && parsedTable != "" {
schema = parsedSchema
table = parsedTable
}
}
// 如果仍然没有 schema,使用 current_schema()
// 这样可以自动匹配当前连接的 search_path
if schema == "" {
return k.getColumnsWithCurrentSchema(table)
}
if table == "" {
return nil, fmt.Errorf("表名不能为空")
}
// 转义函数:处理单引号,移除双引号
esc := func(s string) string {
// 移除前后的双引号(如果存在)
s = strings.Trim(s, "\"")
// 转义单引号
return strings.ReplaceAll(s, "'", "''")
}
query := fmt.Sprintf(`
SELECT
a.attname AS column_name,
pg_catalog.format_type(a.atttypid, a.atttypmod) AS data_type,
CASE WHEN a.attnotnull THEN 'NO' ELSE 'YES' END AS is_nullable,
pg_get_expr(ad.adbin, ad.adrelid) AS column_default,
col_description(a.attrelid, a.attnum) AS comment,
CASE WHEN pk.attname IS NOT NULL THEN 'PRI' ELSE '' END AS column_key
FROM pg_class c
JOIN pg_namespace n ON n.oid = c.relnamespace
JOIN pg_attribute a ON a.attrelid = c.oid
LEFT JOIN pg_attrdef ad ON ad.adrelid = c.oid AND ad.adnum = a.attnum
LEFT JOIN (
SELECT i.indrelid, a3.attname
FROM pg_index i
JOIN pg_attribute a3 ON a3.attrelid = i.indrelid AND a3.attnum = ANY(i.indkey)
WHERE i.indisprimary
) pk ON pk.indrelid = c.oid AND pk.attname = a.attname
WHERE c.relkind IN ('r', 'p')
AND n.nspname = '%s'
AND c.relname = '%s'
AND a.attnum > 0
AND NOT a.attisdropped
ORDER BY a.attnum`, esc(schema), esc(table))
data, _, err := k.Query(query)
if err != nil {
return nil, err
}
var columns []connection.ColumnDefinition
for _, row := range data {
col := connection.ColumnDefinition{
Name: fmt.Sprintf("%v", row["column_name"]),
Type: fmt.Sprintf("%v", row["data_type"]),
Nullable: fmt.Sprintf("%v", row["is_nullable"]),
Key: fmt.Sprintf("%v", row["column_key"]),
Extra: "",
Comment: "",
}
if row["column_default"] != nil {
def := fmt.Sprintf("%v", row["column_default"])
col.Default = &def
if strings.HasPrefix(strings.ToLower(strings.TrimSpace(def)), "nextval(") {
col.Extra = "auto_increment"
}
}
if v, ok := row["comment"]; ok && v != nil {
col.Comment = fmt.Sprintf("%v", v)
}
columns = append(columns, col)
}
return columns, nil
}
// getColumnsWithCurrentSchema 使用 current_schema() 查询当前schema的表
func (k *KingbaseDB) getColumnsWithCurrentSchema(tableName string) ([]connection.ColumnDefinition, error) {
table := strings.TrimSpace(tableName)
if table == "" {
return nil, fmt.Errorf("表名不能为空")
}
// 转义函数
esc := func(s string) string {
s = strings.Trim(s, "\"")
return strings.ReplaceAll(s, "'", "''")
}
// 使用 current_schema() 获取当前schema
query := fmt.Sprintf(`
SELECT
a.attname AS column_name,
pg_catalog.format_type(a.atttypid, a.atttypmod) AS data_type,
CASE WHEN a.attnotnull THEN 'NO' ELSE 'YES' END AS is_nullable,
pg_get_expr(ad.adbin, ad.adrelid) AS column_default,
col_description(a.attrelid, a.attnum) AS comment,
CASE WHEN pk.attname IS NOT NULL THEN 'PRI' ELSE '' END AS column_key
FROM pg_class c
JOIN pg_namespace n ON n.oid = c.relnamespace
JOIN pg_attribute a ON a.attrelid = c.oid
LEFT JOIN pg_attrdef ad ON ad.adrelid = c.oid AND ad.adnum = a.attnum
LEFT JOIN (
SELECT i.indrelid, a3.attname
FROM pg_index i
JOIN pg_attribute a3 ON a3.attrelid = i.indrelid AND a3.attnum = ANY(i.indkey)
WHERE i.indisprimary
) pk ON pk.indrelid = c.oid AND pk.attname = a.attname
WHERE c.relkind IN ('r', 'p')
AND n.nspname = current_schema()
AND c.relname = '%s'
AND a.attnum > 0
AND NOT a.attisdropped
ORDER BY a.attnum`, esc(table))
data, _, err := k.Query(query)
if err != nil {
return nil, err
}
var columns []connection.ColumnDefinition
for _, row := range data {
col := connection.ColumnDefinition{
Name: fmt.Sprintf("%v", row["column_name"]),
Type: fmt.Sprintf("%v", row["data_type"]),
Nullable: fmt.Sprintf("%v", row["is_nullable"]),
Key: fmt.Sprintf("%v", row["column_key"]),
Extra: "",
Comment: "",
}
if row["column_default"] != nil {
def := fmt.Sprintf("%v", row["column_default"])
col.Default = &def
if strings.HasPrefix(strings.ToLower(strings.TrimSpace(def)), "nextval(") {
col.Extra = "auto_increment"
}
}
if v, ok := row["comment"]; ok && v != nil {
col.Comment = fmt.Sprintf("%v", v)
}
columns = append(columns, col)
}
return columns, nil
}
func (k *KingbaseDB) GetIndexes(dbName, tableName string) ([]connection.IndexDefinition, error) {
// 解析 schema.table 格式
schema := strings.TrimSpace(dbName)
table := strings.TrimSpace(tableName)
// 如果 tableName 包含 schema (格式: schema.table)
if parts := strings.SplitN(table, ".", 2); len(parts) == 2 {
parsedSchema := strings.TrimSpace(parts[0])
parsedTable := strings.TrimSpace(parts[1])
if parsedSchema != "" && parsedTable != "" {
schema = parsedSchema
table = parsedTable
}
}
if table == "" {
return nil, fmt.Errorf("表名不能为空")
}
// 转义函数:处理单引号,移除双引号
esc := func(s string) string {
s = strings.Trim(s, "\"")
return strings.ReplaceAll(s, "'", "''")
}
// 构建查询如果没有指定schema,使用current_schema()
var query string
if schema != "" {
query = fmt.Sprintf(`
SELECT
i.relname as index_name,
a.attname as column_name,
ix.indisunique as is_unique
FROM
pg_class t,
pg_class i,
pg_index ix,
pg_attribute a,
pg_namespace n
WHERE
t.oid = ix.indrelid
AND i.oid = ix.indexrelid
AND a.attrelid = t.oid
AND a.attnum = ANY(ix.indkey)
AND t.relkind = 'r'
AND t.relname = '%s'
AND n.oid = t.relnamespace
AND n.nspname = '%s'
`, esc(table), esc(schema))
} else {
query = fmt.Sprintf(`
SELECT
i.relname as index_name,
a.attname as column_name,
ix.indisunique as is_unique
FROM
pg_class t,
pg_class i,
pg_index ix,
pg_attribute a,
pg_namespace n
WHERE
t.oid = ix.indrelid
AND i.oid = ix.indexrelid
AND a.attrelid = t.oid
AND a.attnum = ANY(ix.indkey)
AND t.relkind = 'r'
AND t.relname = '%s'
AND n.oid = t.relnamespace
AND n.nspname = current_schema()
`, esc(table))
}
data, _, err := k.Query(query)
if err != nil {
return nil, err
}
var indexes []connection.IndexDefinition
for _, row := range data {
nonUnique := 1
if val, ok := row["is_unique"]; ok {
if b, ok := val.(bool); ok && b {
nonUnique = 0
}
}
idx := connection.IndexDefinition{
Name: fmt.Sprintf("%v", row["index_name"]),
ColumnName: fmt.Sprintf("%v", row["column_name"]),
NonUnique: nonUnique,
IndexType: "BTREE", // Default
}
indexes = append(indexes, idx)
}
return indexes, nil
}
func (k *KingbaseDB) GetForeignKeys(dbName, tableName string) ([]connection.ForeignKeyDefinition, error) {
// 解析 schema.table 格式
schema := strings.TrimSpace(dbName)
table := strings.TrimSpace(tableName)
// 如果 tableName 包含 schema (格式: schema.table)
if parts := strings.SplitN(table, ".", 2); len(parts) == 2 {
parsedSchema := strings.TrimSpace(parts[0])
parsedTable := strings.TrimSpace(parts[1])
if parsedSchema != "" && parsedTable != "" {
schema = parsedSchema
table = parsedTable
}
}
if table == "" {
return nil, fmt.Errorf("表名不能为空")
}
// 转义函数:处理单引号,移除双引号
esc := func(s string) string {
s = strings.Trim(s, "\"")
return strings.ReplaceAll(s, "'", "''")
}
// 构建查询如果没有指定schema,使用current_schema()
var query string
if schema != "" {
query = fmt.Sprintf(`
SELECT
tc.constraint_name,
kcu.column_name,
ccu.table_name AS foreign_table_name,
ccu.column_name AS foreign_column_name
FROM
information_schema.table_constraints AS tc
JOIN information_schema.key_column_usage AS kcu
ON tc.constraint_name = kcu.constraint_name
AND tc.table_schema = kcu.table_schema
JOIN information_schema.constraint_column_usage AS ccu
ON ccu.constraint_name = tc.constraint_name
AND ccu.table_schema = tc.table_schema
WHERE tc.constraint_type = 'FOREIGN KEY' AND tc.table_name='%s' AND tc.table_schema='%s'`,
esc(table), esc(schema))
} else {
query = fmt.Sprintf(`
SELECT
tc.constraint_name,
kcu.column_name,
ccu.table_name AS foreign_table_name,
ccu.column_name AS foreign_column_name
FROM
information_schema.table_constraints AS tc
JOIN information_schema.key_column_usage AS kcu
ON tc.constraint_name = kcu.constraint_name
AND tc.table_schema = kcu.table_schema
JOIN information_schema.constraint_column_usage AS ccu
ON ccu.constraint_name = tc.constraint_name
AND ccu.table_schema = tc.table_schema
WHERE tc.constraint_type = 'FOREIGN KEY' AND tc.table_name='%s' AND tc.table_schema=current_schema()`,
esc(table))
}
data, _, err := k.Query(query)
if err != nil {
return nil, err
}
var fks []connection.ForeignKeyDefinition
for _, row := range data {
fk := connection.ForeignKeyDefinition{
Name: fmt.Sprintf("%v", row["constraint_name"]),
ColumnName: fmt.Sprintf("%v", row["column_name"]),
RefTableName: fmt.Sprintf("%v", row["foreign_table_name"]),
RefColumnName: fmt.Sprintf("%v", row["foreign_column_name"]),
ConstraintName: fmt.Sprintf("%v", row["constraint_name"]),
}
fks = append(fks, fk)
}
return fks, nil
}
func (k *KingbaseDB) GetTriggers(dbName, tableName string) ([]connection.TriggerDefinition, error) {
// 解析 schema.table 格式
schema := strings.TrimSpace(dbName)
table := strings.TrimSpace(tableName)
// 如果 tableName 包含 schema (格式: schema.table)
if parts := strings.SplitN(table, ".", 2); len(parts) == 2 {
parsedSchema := strings.TrimSpace(parts[0])
parsedTable := strings.TrimSpace(parts[1])
if parsedSchema != "" && parsedTable != "" {
schema = parsedSchema
table = parsedTable
}
}
if table == "" {
return nil, fmt.Errorf("表名不能为空")
}
// 转义函数:处理单引号,移除双引号
esc := func(s string) string {
s = strings.Trim(s, "\"")
return strings.ReplaceAll(s, "'", "''")
}
// 构建查询如果指定了schema,也加上schema条件
var query string
if schema != "" {
query = fmt.Sprintf(`SELECT trigger_name, action_timing, event_manipulation
FROM information_schema.triggers
WHERE event_object_table = '%s' AND event_object_schema = '%s'`,
esc(table), esc(schema))
} else {
query = fmt.Sprintf(`SELECT trigger_name, action_timing, event_manipulation
FROM information_schema.triggers
WHERE event_object_table = '%s' AND event_object_schema = current_schema()`,
esc(table))
}
data, _, err := k.Query(query)
if err != nil {
return nil, err
}
var triggers []connection.TriggerDefinition
for _, row := range data {
trig := connection.TriggerDefinition{
Name: fmt.Sprintf("%v", row["trigger_name"]),
Timing: fmt.Sprintf("%v", row["action_timing"]),
Event: fmt.Sprintf("%v", row["event_manipulation"]),
Statement: "SOURCE HIDDEN",
}
triggers = append(triggers, trig)
}
return triggers, nil
}
func (k *KingbaseDB) ApplyChanges(tableName string, changes connection.ChangeSet) error {
if k.conn == nil {
return fmt.Errorf("连接未打开")
}
tx, err := k.conn.Begin()
if err != nil {
return err
}
defer tx.Rollback()
schema, table := splitKingbaseQualifiedTable(tableName)
if table == "" {
return fmt.Errorf("表名不能为空")
}
qualifiedTable := ""
if schema != "" {
qualifiedTable = fmt.Sprintf("%s.%s", quoteKingbaseIdent(schema), quoteKingbaseIdent(table))
} else {
qualifiedTable = quoteKingbaseIdent(table)
}
// 1. Deletes
for _, pk := range changes.Deletes {
var wheres []string
var args []interface{}
idx := 0
for k, v := range pk {
idx++
wheres = append(wheres, fmt.Sprintf("%s = $%d", quoteKingbaseIdent(k), idx))
args = append(args, v)
}
if len(wheres) == 0 {
continue
}
query := fmt.Sprintf("DELETE FROM %s WHERE %s", qualifiedTable, strings.Join(wheres, " AND "))
if _, err := tx.Exec(query, args...); err != nil {
return fmt.Errorf("delete error: %v; sql=%s", err, query)
}
}
// 2. Updates
for _, update := range changes.Updates {
var sets []string
var args []interface{}
idx := 0
for k, v := range update.Values {
idx++
sets = append(sets, fmt.Sprintf("%s = $%d", quoteKingbaseIdent(k), idx))
args = append(args, v)
}
if len(sets) == 0 {
continue
}
var wheres []string
for k, v := range update.Keys {
idx++
wheres = append(wheres, fmt.Sprintf("%s = $%d", quoteKingbaseIdent(k), idx))
args = append(args, v)
}
if len(wheres) == 0 {
return fmt.Errorf("更新操作需要主键条件")
}
query := fmt.Sprintf("UPDATE %s SET %s WHERE %s", qualifiedTable, strings.Join(sets, ", "), strings.Join(wheres, " AND "))
if _, err := tx.Exec(query, args...); err != nil {
return fmt.Errorf("update error: %v; sql=%s", err, query)
}
}
// 3. Inserts
for _, row := range changes.Inserts {
var cols []string
var placeholders []string
var args []interface{}
idx := 0
for k, v := range row {
idx++
cols = append(cols, quoteKingbaseIdent(k))
placeholders = append(placeholders, fmt.Sprintf("$%d", idx))
args = append(args, v)
}
if len(cols) == 0 {
continue
}
query := fmt.Sprintf("INSERT INTO %s (%s) VALUES (%s)", qualifiedTable, strings.Join(cols, ", "), strings.Join(placeholders, ", "))
if _, err := tx.Exec(query, args...); err != nil {
return fmt.Errorf("插入失败:%v; sql=%s", err, query)
}
}
return tx.Commit()
}
func normalizeKingbaseIdentifier(raw string) string {
return normalizeKingbaseIdentCommon(raw)
}
// kingbaseIdentNeedsQuote 判断标识符是否需要双引号包裹。
// 与前端 sql.ts 中 needsQuote 逻辑保持一致。
func kingbaseIdentNeedsQuote(ident string) bool {
if ident == "" {
return false
}
// 不是合法裸标识符格式(必须以字母或下划线开头,仅含字母、数字、下划线)
if matched, _ := regexp.MatchString(`^[a-zA-Z_][a-zA-Z0-9_]*$`, ident); !matched {
return true
}
// 包含大写字母时需要引号保护KingbaseES/PostgreSQL 默认将未加引号的标识符折叠为小写)
for _, r := range ident {
if r >= 'A' && r <= 'Z' {
return true
}
}
// 是 SQL 保留字
return isKingbaseReservedWord(ident)
}
// isKingbaseReservedWord 检查是否为常见 SQL 保留字(简化版,与前端保持一致)。
func isKingbaseReservedWord(ident string) bool {
switch strings.ToLower(ident) {
case "select", "from", "where", "table", "index", "user", "order", "group", "by",
"limit", "offset", "and", "or", "not", "null", "true", "false", "key",
"primary", "foreign", "references", "default", "constraint",
"create", "drop", "alter", "insert", "update", "delete", "set", "values", "into",
"join", "left", "right", "inner", "outer", "on", "as", "is", "in", "like",
"between", "case", "when", "then", "else", "end", "having", "distinct",
"all", "any", "exists", "union", "except", "intersect",
"column", "check", "unique", "with", "grant", "revoke", "trigger",
"begin", "commit", "rollback", "schema", "database", "view", "function",
"procedure", "sequence", "type", "domain", "role", "session", "current",
"authorization", "cross", "full", "natural", "some", "cast", "fetch",
"for", "to", "do", "if", "return", "returns", "declare", "cursor", "server", "owner":
return true
}
return false
}
func quoteKingbaseIdent(name string) string {
n := normalizeKingbaseIdentifier(name)
if n == "" {
return "\"\""
}
if !kingbaseIdentNeedsQuote(n) {
return n
}
n = strings.ReplaceAll(n, `"`, `""`)
return `"` + n + `"`
}
func splitKingbaseQualifiedTable(tableName string) (schema string, table string) {
return splitKingbaseQualifiedNameCommon(tableName)
}
func (k *KingbaseDB) GetAllColumns(dbName string) ([]connection.ColumnDefinitionWithTable, error) {
// dbName 在本项目语义里是“数据库”schema 由 table_schema 决定;这里返回全部用户 schema 的列用于查询提示。
query := `
SELECT table_schema, table_name, column_name, data_type
FROM information_schema.columns
WHERE table_schema NOT IN ('pg_catalog', 'information_schema')
AND table_schema NOT LIKE 'pg_%'
ORDER BY table_schema, table_name, ordinal_position`
data, _, err := k.Query(query)
if err != nil {
return nil, err
}
var cols []connection.ColumnDefinitionWithTable
for _, row := range data {
schema := fmt.Sprintf("%v", row["table_schema"])
table := fmt.Sprintf("%v", row["table_name"])
tableName := table
if strings.TrimSpace(schema) != "" {
tableName = fmt.Sprintf("%s.%s", schema, table)
}
col := connection.ColumnDefinitionWithTable{
TableName: tableName,
Name: fmt.Sprintf("%v", row["column_name"]),
Type: fmt.Sprintf("%v", row["data_type"]),
}
cols = append(cols, col)
}
return cols, nil
}