mirror of
https://github.com/Syngnat/GoNavi.git
synced 2026-05-07 08:22:42 +08:00
1004 lines
27 KiB
Go
1004 lines
27 KiB
Go
//go:build gonavi_full_drivers || gonavi_kingbase_driver
|
||
|
||
package db
|
||
|
||
import (
|
||
"context"
|
||
"database/sql"
|
||
"fmt"
|
||
"net"
|
||
"regexp"
|
||
"strconv"
|
||
"strings"
|
||
"time"
|
||
|
||
"GoNavi-Wails/internal/connection"
|
||
"GoNavi-Wails/internal/logger"
|
||
"GoNavi-Wails/internal/ssh"
|
||
"GoNavi-Wails/internal/utils"
|
||
|
||
_ "gitea.com/kingbase/gokb" // Registers "kingbase" driver
|
||
)
|
||
|
||
type KingbaseDB struct {
|
||
conn *sql.DB
|
||
pingTimeout time.Duration
|
||
forwarder *ssh.LocalForwarder // Store SSH tunnel forwarder
|
||
}
|
||
|
||
func quoteConnValue(v string) string {
|
||
if v == "" {
|
||
return "''"
|
||
}
|
||
|
||
needsQuote := false
|
||
for _, r := range v {
|
||
switch r {
|
||
case ' ', '\t', '\n', '\r', '\v', '\f', '\'', '\\':
|
||
needsQuote = true
|
||
}
|
||
if needsQuote {
|
||
break
|
||
}
|
||
}
|
||
if !needsQuote {
|
||
return v
|
||
}
|
||
|
||
var b strings.Builder
|
||
b.Grow(len(v) + 2)
|
||
b.WriteByte('\'')
|
||
for _, r := range v {
|
||
if r == '\\' || r == '\'' {
|
||
b.WriteByte('\\')
|
||
}
|
||
b.WriteRune(r)
|
||
}
|
||
b.WriteByte('\'')
|
||
return b.String()
|
||
}
|
||
|
||
func (k *KingbaseDB) getDSN(config connection.ConnectionConfig) string {
|
||
// Kingbase DSN usually similar to Postgres:
|
||
// host=localhost port=54321 user=system password=... dbname=TEST sslmode=disable
|
||
|
||
address := config.Host
|
||
port := config.Port
|
||
|
||
// Construct DSN
|
||
dsn := fmt.Sprintf("host=%s port=%d user=%s password=%s dbname=%s sslmode=%s connect_timeout=%d",
|
||
quoteConnValue(address),
|
||
port,
|
||
quoteConnValue(config.User),
|
||
quoteConnValue(config.Password),
|
||
quoteConnValue(config.Database),
|
||
quoteConnValue(resolvePostgresSSLMode(config)),
|
||
getConnectTimeoutSeconds(config),
|
||
)
|
||
|
||
return dsn
|
||
}
|
||
|
||
func (k *KingbaseDB) Connect(config connection.ConnectionConfig) error {
|
||
runConfig := config
|
||
|
||
if config.UseSSH {
|
||
// Create SSH tunnel with local port forwarding
|
||
logger.Infof("人大金仓使用 SSH 连接:地址=%s:%d 用户=%s", config.Host, config.Port, config.User)
|
||
|
||
forwarder, err := ssh.GetOrCreateLocalForwarder(config.SSH, config.Host, config.Port)
|
||
if err != nil {
|
||
return fmt.Errorf("创建 SSH 隧道失败:%w", err)
|
||
}
|
||
k.forwarder = forwarder
|
||
|
||
// Parse local address
|
||
host, portStr, err := net.SplitHostPort(forwarder.LocalAddr)
|
||
if err != nil {
|
||
return fmt.Errorf("解析本地转发地址失败:%w", err)
|
||
}
|
||
|
||
port, err := strconv.Atoi(portStr)
|
||
if err != nil {
|
||
return fmt.Errorf("解析本地端口失败:%w", err)
|
||
}
|
||
|
||
// Create a modified config pointing to local forwarder
|
||
localConfig := config
|
||
localConfig.Host = host
|
||
localConfig.Port = port
|
||
localConfig.UseSSH = false
|
||
|
||
runConfig = localConfig
|
||
logger.Infof("人大金仓通过本地端口转发连接:%s -> %s:%d", forwarder.LocalAddr, config.Host, config.Port)
|
||
}
|
||
|
||
attempts := []connection.ConnectionConfig{runConfig}
|
||
if shouldTrySSLPreferredFallback(runConfig) {
|
||
attempts = append(attempts, withSSLDisabled(runConfig))
|
||
}
|
||
|
||
var failures []string
|
||
for idx, attempt := range attempts {
|
||
dsn := k.getDSN(attempt)
|
||
db, err := sql.Open("kingbase", dsn)
|
||
if err != nil {
|
||
failures = append(failures, fmt.Sprintf("第%d次连接打开失败: %v", idx+1, err))
|
||
continue
|
||
}
|
||
k.conn = db
|
||
k.pingTimeout = getConnectTimeout(attempt)
|
||
if err := k.Ping(); err != nil {
|
||
_ = db.Close()
|
||
k.conn = nil
|
||
failures = append(failures, fmt.Sprintf("第%d次连接验证失败: %v", idx+1, err))
|
||
continue
|
||
}
|
||
if idx > 0 {
|
||
logger.Warnf("人大金仓 SSL 优先连接失败,已回退至明文连接")
|
||
}
|
||
|
||
// 获取 schema 列表以重构带有 search_path 的连接池
|
||
searchPathStr := k.getSearchPathStr()
|
||
if searchPathStr != "" {
|
||
// 将 search_path 参数拼入 DSN
|
||
finalDSN := dsn + " search_path=" + quoteConnValue(searchPathStr)
|
||
if finalDB, err := sql.Open("kingbase", finalDSN); err == nil {
|
||
k.pingTimeout = getConnectTimeout(attempt)
|
||
finalDB.SetConnMaxLifetime(5 * time.Minute)
|
||
|
||
// 临时将 k.conn 指向 finalDB 来做 ping 测试
|
||
oldConn := k.conn
|
||
k.conn = finalDB
|
||
if err := k.Ping(); err == nil {
|
||
// 成功使用带 search_path 的连接池
|
||
_ = oldConn.Close()
|
||
logger.Infof("人大金仓已配置连接级 search_path:%s", searchPathStr)
|
||
} else {
|
||
_ = finalDB.Close()
|
||
k.conn = oldConn
|
||
}
|
||
}
|
||
}
|
||
if searchPathStr != "" {
|
||
timeout := k.pingTimeout
|
||
if timeout <= 0 {
|
||
timeout = 5 * time.Second
|
||
}
|
||
ctx, cancel := utils.ContextWithTimeout(timeout)
|
||
defer cancel()
|
||
if _, err := k.conn.ExecContext(ctx, fmt.Sprintf("SET search_path TO %s", searchPathStr)); err != nil {
|
||
logger.Warnf("人大金仓显式设置 search_path 失败:%v", err)
|
||
} else {
|
||
logger.Infof("人大金仓已设置默认 search_path:%s", searchPathStr)
|
||
}
|
||
}
|
||
|
||
return nil
|
||
}
|
||
return fmt.Errorf("连接建立后验证失败:%s", strings.Join(failures, ";"))
|
||
}
|
||
|
||
// getSearchPathStr 查询当前数据库中所有用户 schema,配置 DSN 的 search_path。
|
||
// KingBase 默认 search_path 为 "$user", public,对于自定义 schema 下的表不可见。
|
||
func (k *KingbaseDB) getSearchPathStr() string {
|
||
if k.conn == nil {
|
||
return ""
|
||
}
|
||
|
||
query := `SELECT nspname FROM pg_namespace
|
||
WHERE nspname NOT IN ('pg_catalog', 'information_schema')
|
||
AND nspname NOT LIKE 'pg_%'
|
||
ORDER BY nspname`
|
||
|
||
rows, err := k.conn.Query(query)
|
||
if err != nil {
|
||
logger.Warnf("人大金仓查询用户 schema 失败,跳过 search_path 设置:%v", err)
|
||
return ""
|
||
}
|
||
defer rows.Close()
|
||
|
||
var rawSchemas []string
|
||
for rows.Next() {
|
||
var name string
|
||
if err := rows.Scan(&name); err != nil {
|
||
continue
|
||
}
|
||
name = strings.TrimSpace(name)
|
||
if name != "" {
|
||
rawSchemas = append(rawSchemas, name)
|
||
}
|
||
}
|
||
|
||
searchPath, _ := buildKingbaseSearchPathCommon(rawSchemas)
|
||
return searchPath
|
||
}
|
||
|
||
func (k *KingbaseDB) Close() error {
|
||
// Close SSH forwarder first if exists
|
||
if k.forwarder != nil {
|
||
if err := k.forwarder.Close(); err != nil {
|
||
logger.Warnf("关闭人大金仓 SSH 端口转发失败:%v", err)
|
||
}
|
||
k.forwarder = nil
|
||
}
|
||
|
||
// Then close database connection
|
||
if k.conn != nil {
|
||
return k.conn.Close()
|
||
}
|
||
return nil
|
||
}
|
||
|
||
func (k *KingbaseDB) Ping() error {
|
||
if k.conn == nil {
|
||
return fmt.Errorf("连接未打开")
|
||
}
|
||
timeout := k.pingTimeout
|
||
if timeout <= 0 {
|
||
timeout = 5 * time.Second
|
||
}
|
||
ctx, cancel := utils.ContextWithTimeout(timeout)
|
||
defer cancel()
|
||
return k.conn.PingContext(ctx)
|
||
}
|
||
|
||
func (k *KingbaseDB) QueryContext(ctx context.Context, query string) ([]map[string]interface{}, []string, error) {
|
||
if k.conn == nil {
|
||
return nil, nil, fmt.Errorf("连接未打开")
|
||
}
|
||
|
||
rows, err := k.conn.QueryContext(ctx, query)
|
||
if err != nil {
|
||
return nil, nil, err
|
||
}
|
||
defer rows.Close()
|
||
|
||
return scanRows(rows)
|
||
}
|
||
|
||
func (k *KingbaseDB) Query(query string) ([]map[string]interface{}, []string, error) {
|
||
if k.conn == nil {
|
||
return nil, nil, fmt.Errorf("连接未打开")
|
||
}
|
||
|
||
rows, err := k.conn.Query(query)
|
||
if err != nil {
|
||
return nil, nil, err
|
||
}
|
||
defer rows.Close()
|
||
return scanRows(rows)
|
||
}
|
||
|
||
func (k *KingbaseDB) ExecContext(ctx context.Context, query string) (int64, error) {
|
||
if k.conn == nil {
|
||
return 0, fmt.Errorf("连接未打开")
|
||
}
|
||
res, err := k.conn.ExecContext(ctx, query)
|
||
if err != nil {
|
||
return 0, err
|
||
}
|
||
return res.RowsAffected()
|
||
}
|
||
|
||
func (k *KingbaseDB) ExecBatchContext(ctx context.Context, query string) (int64, error) {
|
||
if k.conn == nil {
|
||
return 0, fmt.Errorf("连接未打开")
|
||
}
|
||
res, err := k.conn.ExecContext(ctx, query)
|
||
if err != nil {
|
||
return 0, err
|
||
}
|
||
return res.RowsAffected()
|
||
}
|
||
|
||
func (k *KingbaseDB) Exec(query string) (int64, error) {
|
||
if k.conn == nil {
|
||
return 0, fmt.Errorf("连接未打开")
|
||
}
|
||
res, err := k.conn.Exec(query)
|
||
if err != nil {
|
||
return 0, err
|
||
}
|
||
return res.RowsAffected()
|
||
}
|
||
|
||
func (k *KingbaseDB) GetDatabases() ([]string, error) {
|
||
data, _, err := k.Query("SELECT datname FROM pg_database WHERE datistemplate = false")
|
||
if err == nil {
|
||
dbs := collectKingbaseNames(data, "datname", "database")
|
||
if len(dbs) > 0 {
|
||
return dbs, nil
|
||
}
|
||
}
|
||
|
||
fallbackData, _, fallbackErr := k.Query("SELECT current_database() AS datname")
|
||
if fallbackErr != nil {
|
||
if err != nil {
|
||
return nil, err
|
||
}
|
||
return nil, fallbackErr
|
||
}
|
||
|
||
dbs := collectKingbaseNames(fallbackData, "datname", "database", "current_database", "currentDatabase")
|
||
if len(dbs) > 0 {
|
||
return dbs, nil
|
||
}
|
||
|
||
if err != nil {
|
||
return nil, err
|
||
}
|
||
return nil, fmt.Errorf("未获取到可见数据库列表")
|
||
}
|
||
|
||
func collectKingbaseNames(rows []map[string]interface{}, keys ...string) []string {
|
||
result := make([]string, 0, len(rows))
|
||
seen := make(map[string]struct{}, len(rows))
|
||
for _, row := range rows {
|
||
name := strings.TrimSpace(getKingbaseNameFromRow(row, keys...))
|
||
if name == "" {
|
||
continue
|
||
}
|
||
if _, exists := seen[name]; exists {
|
||
continue
|
||
}
|
||
seen[name] = struct{}{}
|
||
result = append(result, name)
|
||
}
|
||
return result
|
||
}
|
||
|
||
func getKingbaseNameFromRow(row map[string]interface{}, keys ...string) string {
|
||
if len(row) == 0 {
|
||
return ""
|
||
}
|
||
for _, key := range keys {
|
||
if value, ok := row[key]; ok {
|
||
return fmt.Sprintf("%v", value)
|
||
}
|
||
}
|
||
for existingKey, value := range row {
|
||
for _, key := range keys {
|
||
if strings.EqualFold(existingKey, key) {
|
||
return fmt.Sprintf("%v", value)
|
||
}
|
||
}
|
||
}
|
||
for _, value := range row {
|
||
return fmt.Sprintf("%v", value)
|
||
}
|
||
return ""
|
||
}
|
||
|
||
func (k *KingbaseDB) GetTables(dbName string) ([]string, error) {
|
||
// Kingbase: tables are scoped by the current DB connection; include schema to avoid search_path issues.
|
||
query := `
|
||
SELECT table_schema AS schemaname, table_name AS tablename
|
||
FROM information_schema.tables
|
||
WHERE table_type = 'BASE TABLE'
|
||
AND table_schema NOT IN ('pg_catalog', 'information_schema')
|
||
AND table_schema NOT LIKE 'pg_%'
|
||
ORDER BY table_schema, table_name`
|
||
|
||
data, _, err := k.Query(query)
|
||
if err != nil {
|
||
return nil, err
|
||
}
|
||
|
||
var tables []string
|
||
for _, row := range data {
|
||
schema, okSchema := row["schemaname"]
|
||
name, okName := row["tablename"]
|
||
if okSchema && okName {
|
||
tables = append(tables, fmt.Sprintf("%v.%v", schema, name))
|
||
continue
|
||
}
|
||
if val, ok := row["table_name"]; ok {
|
||
tables = append(tables, fmt.Sprintf("%v", val))
|
||
}
|
||
}
|
||
return tables, nil
|
||
}
|
||
|
||
func (k *KingbaseDB) GetCreateStatement(dbName, tableName string) (string, error) {
|
||
// Kingbase doesn't have "SHOW CREATE TABLE".
|
||
// We can try pg_dump logic or use a query to reconstruction.
|
||
// A simple approach is just returning basic info or "Not Supported".
|
||
// Or we can query information_schema to build it.
|
||
return "SHOW CREATE TABLE not directly supported in Kingbase/Postgres via SQL", nil
|
||
}
|
||
|
||
func (k *KingbaseDB) GetColumns(dbName, tableName string) ([]connection.ColumnDefinition, error) {
|
||
// 解析 schema.table 格式
|
||
schema := strings.TrimSpace(dbName)
|
||
table := strings.TrimSpace(tableName)
|
||
|
||
// 如果 tableName 包含 schema (格式: schema.table)
|
||
if parts := strings.SplitN(table, ".", 2); len(parts) == 2 {
|
||
parsedSchema := strings.TrimSpace(parts[0])
|
||
parsedTable := strings.TrimSpace(parts[1])
|
||
if parsedSchema != "" && parsedTable != "" {
|
||
schema = parsedSchema
|
||
table = parsedTable
|
||
}
|
||
}
|
||
|
||
// 如果仍然没有 schema,使用 current_schema()
|
||
// 这样可以自动匹配当前连接的 search_path
|
||
if schema == "" {
|
||
return k.getColumnsWithCurrentSchema(table)
|
||
}
|
||
|
||
if table == "" {
|
||
return nil, fmt.Errorf("表名不能为空")
|
||
}
|
||
|
||
// 转义函数:处理单引号,移除双引号
|
||
esc := func(s string) string {
|
||
// 移除前后的双引号(如果存在)
|
||
s = strings.Trim(s, "\"")
|
||
// 转义单引号
|
||
return strings.ReplaceAll(s, "'", "''")
|
||
}
|
||
|
||
query := fmt.Sprintf(`
|
||
SELECT
|
||
a.attname AS column_name,
|
||
pg_catalog.format_type(a.atttypid, a.atttypmod) AS data_type,
|
||
CASE WHEN a.attnotnull THEN 'NO' ELSE 'YES' END AS is_nullable,
|
||
pg_get_expr(ad.adbin, ad.adrelid) AS column_default,
|
||
col_description(a.attrelid, a.attnum) AS comment,
|
||
CASE WHEN pk.attname IS NOT NULL THEN 'PRI' ELSE '' END AS column_key
|
||
FROM pg_class c
|
||
JOIN pg_namespace n ON n.oid = c.relnamespace
|
||
JOIN pg_attribute a ON a.attrelid = c.oid
|
||
LEFT JOIN pg_attrdef ad ON ad.adrelid = c.oid AND ad.adnum = a.attnum
|
||
LEFT JOIN (
|
||
SELECT i.indrelid, a3.attname
|
||
FROM pg_index i
|
||
JOIN pg_attribute a3 ON a3.attrelid = i.indrelid AND a3.attnum = ANY(i.indkey)
|
||
WHERE i.indisprimary
|
||
) pk ON pk.indrelid = c.oid AND pk.attname = a.attname
|
||
WHERE c.relkind IN ('r', 'p')
|
||
AND n.nspname = '%s'
|
||
AND c.relname = '%s'
|
||
AND a.attnum > 0
|
||
AND NOT a.attisdropped
|
||
ORDER BY a.attnum`, esc(schema), esc(table))
|
||
|
||
data, _, err := k.Query(query)
|
||
if err != nil {
|
||
return nil, err
|
||
}
|
||
|
||
var columns []connection.ColumnDefinition
|
||
for _, row := range data {
|
||
col := connection.ColumnDefinition{
|
||
Name: fmt.Sprintf("%v", row["column_name"]),
|
||
Type: fmt.Sprintf("%v", row["data_type"]),
|
||
Nullable: fmt.Sprintf("%v", row["is_nullable"]),
|
||
Key: fmt.Sprintf("%v", row["column_key"]),
|
||
Extra: "",
|
||
Comment: "",
|
||
}
|
||
|
||
if row["column_default"] != nil {
|
||
def := fmt.Sprintf("%v", row["column_default"])
|
||
col.Default = &def
|
||
if strings.HasPrefix(strings.ToLower(strings.TrimSpace(def)), "nextval(") {
|
||
col.Extra = "auto_increment"
|
||
}
|
||
}
|
||
|
||
if v, ok := row["comment"]; ok && v != nil {
|
||
col.Comment = fmt.Sprintf("%v", v)
|
||
}
|
||
|
||
columns = append(columns, col)
|
||
}
|
||
return columns, nil
|
||
}
|
||
|
||
// getColumnsWithCurrentSchema 使用 current_schema() 查询当前schema的表
|
||
func (k *KingbaseDB) getColumnsWithCurrentSchema(tableName string) ([]connection.ColumnDefinition, error) {
|
||
table := strings.TrimSpace(tableName)
|
||
if table == "" {
|
||
return nil, fmt.Errorf("表名不能为空")
|
||
}
|
||
|
||
// 转义函数
|
||
esc := func(s string) string {
|
||
s = strings.Trim(s, "\"")
|
||
return strings.ReplaceAll(s, "'", "''")
|
||
}
|
||
|
||
// 使用 current_schema() 获取当前schema
|
||
query := fmt.Sprintf(`
|
||
SELECT
|
||
a.attname AS column_name,
|
||
pg_catalog.format_type(a.atttypid, a.atttypmod) AS data_type,
|
||
CASE WHEN a.attnotnull THEN 'NO' ELSE 'YES' END AS is_nullable,
|
||
pg_get_expr(ad.adbin, ad.adrelid) AS column_default,
|
||
col_description(a.attrelid, a.attnum) AS comment,
|
||
CASE WHEN pk.attname IS NOT NULL THEN 'PRI' ELSE '' END AS column_key
|
||
FROM pg_class c
|
||
JOIN pg_namespace n ON n.oid = c.relnamespace
|
||
JOIN pg_attribute a ON a.attrelid = c.oid
|
||
LEFT JOIN pg_attrdef ad ON ad.adrelid = c.oid AND ad.adnum = a.attnum
|
||
LEFT JOIN (
|
||
SELECT i.indrelid, a3.attname
|
||
FROM pg_index i
|
||
JOIN pg_attribute a3 ON a3.attrelid = i.indrelid AND a3.attnum = ANY(i.indkey)
|
||
WHERE i.indisprimary
|
||
) pk ON pk.indrelid = c.oid AND pk.attname = a.attname
|
||
WHERE c.relkind IN ('r', 'p')
|
||
AND n.nspname = current_schema()
|
||
AND c.relname = '%s'
|
||
AND a.attnum > 0
|
||
AND NOT a.attisdropped
|
||
ORDER BY a.attnum`, esc(table))
|
||
|
||
data, _, err := k.Query(query)
|
||
if err != nil {
|
||
return nil, err
|
||
}
|
||
|
||
var columns []connection.ColumnDefinition
|
||
for _, row := range data {
|
||
col := connection.ColumnDefinition{
|
||
Name: fmt.Sprintf("%v", row["column_name"]),
|
||
Type: fmt.Sprintf("%v", row["data_type"]),
|
||
Nullable: fmt.Sprintf("%v", row["is_nullable"]),
|
||
Key: fmt.Sprintf("%v", row["column_key"]),
|
||
Extra: "",
|
||
Comment: "",
|
||
}
|
||
|
||
if row["column_default"] != nil {
|
||
def := fmt.Sprintf("%v", row["column_default"])
|
||
col.Default = &def
|
||
if strings.HasPrefix(strings.ToLower(strings.TrimSpace(def)), "nextval(") {
|
||
col.Extra = "auto_increment"
|
||
}
|
||
}
|
||
|
||
if v, ok := row["comment"]; ok && v != nil {
|
||
col.Comment = fmt.Sprintf("%v", v)
|
||
}
|
||
|
||
columns = append(columns, col)
|
||
}
|
||
return columns, nil
|
||
}
|
||
|
||
func (k *KingbaseDB) GetIndexes(dbName, tableName string) ([]connection.IndexDefinition, error) {
|
||
// 解析 schema.table 格式
|
||
schema := strings.TrimSpace(dbName)
|
||
table := strings.TrimSpace(tableName)
|
||
|
||
// 如果 tableName 包含 schema (格式: schema.table)
|
||
if parts := strings.SplitN(table, ".", 2); len(parts) == 2 {
|
||
parsedSchema := strings.TrimSpace(parts[0])
|
||
parsedTable := strings.TrimSpace(parts[1])
|
||
if parsedSchema != "" && parsedTable != "" {
|
||
schema = parsedSchema
|
||
table = parsedTable
|
||
}
|
||
}
|
||
|
||
if table == "" {
|
||
return nil, fmt.Errorf("表名不能为空")
|
||
}
|
||
|
||
// 转义函数:处理单引号,移除双引号
|
||
esc := func(s string) string {
|
||
s = strings.Trim(s, "\"")
|
||
return strings.ReplaceAll(s, "'", "''")
|
||
}
|
||
|
||
// 构建查询:如果没有指定schema,使用current_schema()
|
||
var query string
|
||
if schema != "" {
|
||
query = fmt.Sprintf(`
|
||
SELECT
|
||
i.relname as index_name,
|
||
a.attname as column_name,
|
||
ix.indisunique as is_unique
|
||
FROM
|
||
pg_class t,
|
||
pg_class i,
|
||
pg_index ix,
|
||
pg_attribute a,
|
||
pg_namespace n
|
||
WHERE
|
||
t.oid = ix.indrelid
|
||
AND i.oid = ix.indexrelid
|
||
AND a.attrelid = t.oid
|
||
AND a.attnum = ANY(ix.indkey)
|
||
AND t.relkind = 'r'
|
||
AND t.relname = '%s'
|
||
AND n.oid = t.relnamespace
|
||
AND n.nspname = '%s'
|
||
`, esc(table), esc(schema))
|
||
} else {
|
||
query = fmt.Sprintf(`
|
||
SELECT
|
||
i.relname as index_name,
|
||
a.attname as column_name,
|
||
ix.indisunique as is_unique
|
||
FROM
|
||
pg_class t,
|
||
pg_class i,
|
||
pg_index ix,
|
||
pg_attribute a,
|
||
pg_namespace n
|
||
WHERE
|
||
t.oid = ix.indrelid
|
||
AND i.oid = ix.indexrelid
|
||
AND a.attrelid = t.oid
|
||
AND a.attnum = ANY(ix.indkey)
|
||
AND t.relkind = 'r'
|
||
AND t.relname = '%s'
|
||
AND n.oid = t.relnamespace
|
||
AND n.nspname = current_schema()
|
||
`, esc(table))
|
||
}
|
||
|
||
data, _, err := k.Query(query)
|
||
if err != nil {
|
||
return nil, err
|
||
}
|
||
|
||
var indexes []connection.IndexDefinition
|
||
for _, row := range data {
|
||
nonUnique := 1
|
||
if val, ok := row["is_unique"]; ok {
|
||
if b, ok := val.(bool); ok && b {
|
||
nonUnique = 0
|
||
}
|
||
}
|
||
|
||
idx := connection.IndexDefinition{
|
||
Name: fmt.Sprintf("%v", row["index_name"]),
|
||
ColumnName: fmt.Sprintf("%v", row["column_name"]),
|
||
NonUnique: nonUnique,
|
||
IndexType: "BTREE", // Default
|
||
}
|
||
indexes = append(indexes, idx)
|
||
}
|
||
return indexes, nil
|
||
}
|
||
|
||
func (k *KingbaseDB) GetForeignKeys(dbName, tableName string) ([]connection.ForeignKeyDefinition, error) {
|
||
// 解析 schema.table 格式
|
||
schema := strings.TrimSpace(dbName)
|
||
table := strings.TrimSpace(tableName)
|
||
|
||
// 如果 tableName 包含 schema (格式: schema.table)
|
||
if parts := strings.SplitN(table, ".", 2); len(parts) == 2 {
|
||
parsedSchema := strings.TrimSpace(parts[0])
|
||
parsedTable := strings.TrimSpace(parts[1])
|
||
if parsedSchema != "" && parsedTable != "" {
|
||
schema = parsedSchema
|
||
table = parsedTable
|
||
}
|
||
}
|
||
|
||
if table == "" {
|
||
return nil, fmt.Errorf("表名不能为空")
|
||
}
|
||
|
||
// 转义函数:处理单引号,移除双引号
|
||
esc := func(s string) string {
|
||
s = strings.Trim(s, "\"")
|
||
return strings.ReplaceAll(s, "'", "''")
|
||
}
|
||
|
||
// 构建查询:如果没有指定schema,使用current_schema()
|
||
var query string
|
||
if schema != "" {
|
||
query = fmt.Sprintf(`
|
||
SELECT
|
||
tc.constraint_name,
|
||
kcu.column_name,
|
||
ccu.table_name AS foreign_table_name,
|
||
ccu.column_name AS foreign_column_name
|
||
FROM
|
||
information_schema.table_constraints AS tc
|
||
JOIN information_schema.key_column_usage AS kcu
|
||
ON tc.constraint_name = kcu.constraint_name
|
||
AND tc.table_schema = kcu.table_schema
|
||
JOIN information_schema.constraint_column_usage AS ccu
|
||
ON ccu.constraint_name = tc.constraint_name
|
||
AND ccu.table_schema = tc.table_schema
|
||
WHERE tc.constraint_type = 'FOREIGN KEY' AND tc.table_name='%s' AND tc.table_schema='%s'`,
|
||
esc(table), esc(schema))
|
||
} else {
|
||
query = fmt.Sprintf(`
|
||
SELECT
|
||
tc.constraint_name,
|
||
kcu.column_name,
|
||
ccu.table_name AS foreign_table_name,
|
||
ccu.column_name AS foreign_column_name
|
||
FROM
|
||
information_schema.table_constraints AS tc
|
||
JOIN information_schema.key_column_usage AS kcu
|
||
ON tc.constraint_name = kcu.constraint_name
|
||
AND tc.table_schema = kcu.table_schema
|
||
JOIN information_schema.constraint_column_usage AS ccu
|
||
ON ccu.constraint_name = tc.constraint_name
|
||
AND ccu.table_schema = tc.table_schema
|
||
WHERE tc.constraint_type = 'FOREIGN KEY' AND tc.table_name='%s' AND tc.table_schema=current_schema()`,
|
||
esc(table))
|
||
}
|
||
|
||
data, _, err := k.Query(query)
|
||
if err != nil {
|
||
return nil, err
|
||
}
|
||
|
||
var fks []connection.ForeignKeyDefinition
|
||
for _, row := range data {
|
||
fk := connection.ForeignKeyDefinition{
|
||
Name: fmt.Sprintf("%v", row["constraint_name"]),
|
||
ColumnName: fmt.Sprintf("%v", row["column_name"]),
|
||
RefTableName: fmt.Sprintf("%v", row["foreign_table_name"]),
|
||
RefColumnName: fmt.Sprintf("%v", row["foreign_column_name"]),
|
||
ConstraintName: fmt.Sprintf("%v", row["constraint_name"]),
|
||
}
|
||
fks = append(fks, fk)
|
||
}
|
||
return fks, nil
|
||
}
|
||
|
||
func (k *KingbaseDB) GetTriggers(dbName, tableName string) ([]connection.TriggerDefinition, error) {
|
||
// 解析 schema.table 格式
|
||
schema := strings.TrimSpace(dbName)
|
||
table := strings.TrimSpace(tableName)
|
||
|
||
// 如果 tableName 包含 schema (格式: schema.table)
|
||
if parts := strings.SplitN(table, ".", 2); len(parts) == 2 {
|
||
parsedSchema := strings.TrimSpace(parts[0])
|
||
parsedTable := strings.TrimSpace(parts[1])
|
||
if parsedSchema != "" && parsedTable != "" {
|
||
schema = parsedSchema
|
||
table = parsedTable
|
||
}
|
||
}
|
||
|
||
if table == "" {
|
||
return nil, fmt.Errorf("表名不能为空")
|
||
}
|
||
|
||
// 转义函数:处理单引号,移除双引号
|
||
esc := func(s string) string {
|
||
s = strings.Trim(s, "\"")
|
||
return strings.ReplaceAll(s, "'", "''")
|
||
}
|
||
|
||
// 构建查询:如果指定了schema,也加上schema条件
|
||
var query string
|
||
if schema != "" {
|
||
query = fmt.Sprintf(`SELECT trigger_name, action_timing, event_manipulation
|
||
FROM information_schema.triggers
|
||
WHERE event_object_table = '%s' AND event_object_schema = '%s'`,
|
||
esc(table), esc(schema))
|
||
} else {
|
||
query = fmt.Sprintf(`SELECT trigger_name, action_timing, event_manipulation
|
||
FROM information_schema.triggers
|
||
WHERE event_object_table = '%s' AND event_object_schema = current_schema()`,
|
||
esc(table))
|
||
}
|
||
|
||
data, _, err := k.Query(query)
|
||
if err != nil {
|
||
return nil, err
|
||
}
|
||
|
||
var triggers []connection.TriggerDefinition
|
||
for _, row := range data {
|
||
trig := connection.TriggerDefinition{
|
||
Name: fmt.Sprintf("%v", row["trigger_name"]),
|
||
Timing: fmt.Sprintf("%v", row["action_timing"]),
|
||
Event: fmt.Sprintf("%v", row["event_manipulation"]),
|
||
Statement: "SOURCE HIDDEN",
|
||
}
|
||
triggers = append(triggers, trig)
|
||
}
|
||
return triggers, nil
|
||
}
|
||
|
||
func (k *KingbaseDB) ApplyChanges(tableName string, changes connection.ChangeSet) error {
|
||
if k.conn == nil {
|
||
return fmt.Errorf("连接未打开")
|
||
}
|
||
|
||
tx, err := k.conn.Begin()
|
||
if err != nil {
|
||
return err
|
||
}
|
||
defer tx.Rollback()
|
||
|
||
schema, table := splitKingbaseQualifiedTable(tableName)
|
||
if table == "" {
|
||
return fmt.Errorf("表名不能为空")
|
||
}
|
||
|
||
qualifiedTable := ""
|
||
if schema != "" {
|
||
qualifiedTable = fmt.Sprintf("%s.%s", quoteKingbaseIdent(schema), quoteKingbaseIdent(table))
|
||
} else {
|
||
qualifiedTable = quoteKingbaseIdent(table)
|
||
}
|
||
|
||
// 1. Deletes
|
||
for _, pk := range changes.Deletes {
|
||
var wheres []string
|
||
var args []interface{}
|
||
idx := 0
|
||
for k, v := range pk {
|
||
idx++
|
||
wheres = append(wheres, fmt.Sprintf("%s = $%d", quoteKingbaseIdent(k), idx))
|
||
args = append(args, v)
|
||
}
|
||
if len(wheres) == 0 {
|
||
continue
|
||
}
|
||
query := fmt.Sprintf("DELETE FROM %s WHERE %s", qualifiedTable, strings.Join(wheres, " AND "))
|
||
if _, err := tx.Exec(query, args...); err != nil {
|
||
return fmt.Errorf("delete error: %v; sql=%s", err, query)
|
||
}
|
||
}
|
||
|
||
// 2. Updates
|
||
for _, update := range changes.Updates {
|
||
var sets []string
|
||
var args []interface{}
|
||
idx := 0
|
||
|
||
for k, v := range update.Values {
|
||
idx++
|
||
sets = append(sets, fmt.Sprintf("%s = $%d", quoteKingbaseIdent(k), idx))
|
||
args = append(args, v)
|
||
}
|
||
|
||
if len(sets) == 0 {
|
||
continue
|
||
}
|
||
|
||
var wheres []string
|
||
for k, v := range update.Keys {
|
||
idx++
|
||
wheres = append(wheres, fmt.Sprintf("%s = $%d", quoteKingbaseIdent(k), idx))
|
||
args = append(args, v)
|
||
}
|
||
|
||
if len(wheres) == 0 {
|
||
return fmt.Errorf("更新操作需要主键条件")
|
||
}
|
||
|
||
query := fmt.Sprintf("UPDATE %s SET %s WHERE %s", qualifiedTable, strings.Join(sets, ", "), strings.Join(wheres, " AND "))
|
||
if _, err := tx.Exec(query, args...); err != nil {
|
||
return fmt.Errorf("update error: %v; sql=%s", err, query)
|
||
}
|
||
}
|
||
|
||
// 3. Inserts
|
||
for _, row := range changes.Inserts {
|
||
var cols []string
|
||
var placeholders []string
|
||
var args []interface{}
|
||
idx := 0
|
||
|
||
for k, v := range row {
|
||
idx++
|
||
cols = append(cols, quoteKingbaseIdent(k))
|
||
placeholders = append(placeholders, fmt.Sprintf("$%d", idx))
|
||
args = append(args, v)
|
||
}
|
||
|
||
if len(cols) == 0 {
|
||
continue
|
||
}
|
||
|
||
query := fmt.Sprintf("INSERT INTO %s (%s) VALUES (%s)", qualifiedTable, strings.Join(cols, ", "), strings.Join(placeholders, ", "))
|
||
if _, err := tx.Exec(query, args...); err != nil {
|
||
return fmt.Errorf("插入失败:%v; sql=%s", err, query)
|
||
}
|
||
}
|
||
|
||
return tx.Commit()
|
||
}
|
||
|
||
func normalizeKingbaseIdentifier(raw string) string {
|
||
return normalizeKingbaseIdentCommon(raw)
|
||
}
|
||
|
||
// kingbaseIdentNeedsQuote 判断标识符是否需要双引号包裹。
|
||
// 与前端 sql.ts 中 needsQuote 逻辑保持一致。
|
||
func kingbaseIdentNeedsQuote(ident string) bool {
|
||
if ident == "" {
|
||
return false
|
||
}
|
||
// 不是合法裸标识符格式(必须以字母或下划线开头,仅含字母、数字、下划线)
|
||
if matched, _ := regexp.MatchString(`^[a-zA-Z_][a-zA-Z0-9_]*$`, ident); !matched {
|
||
return true
|
||
}
|
||
// 包含大写字母时需要引号保护(KingbaseES/PostgreSQL 默认将未加引号的标识符折叠为小写)
|
||
for _, r := range ident {
|
||
if r >= 'A' && r <= 'Z' {
|
||
return true
|
||
}
|
||
}
|
||
// 是 SQL 保留字
|
||
return isKingbaseReservedWord(ident)
|
||
}
|
||
|
||
// isKingbaseReservedWord 检查是否为常见 SQL 保留字(简化版,与前端保持一致)。
|
||
func isKingbaseReservedWord(ident string) bool {
|
||
switch strings.ToLower(ident) {
|
||
case "select", "from", "where", "table", "index", "user", "order", "group", "by",
|
||
"limit", "offset", "and", "or", "not", "null", "true", "false", "key",
|
||
"primary", "foreign", "references", "default", "constraint",
|
||
"create", "drop", "alter", "insert", "update", "delete", "set", "values", "into",
|
||
"join", "left", "right", "inner", "outer", "on", "as", "is", "in", "like",
|
||
"between", "case", "when", "then", "else", "end", "having", "distinct",
|
||
"all", "any", "exists", "union", "except", "intersect",
|
||
"column", "check", "unique", "with", "grant", "revoke", "trigger",
|
||
"begin", "commit", "rollback", "schema", "database", "view", "function",
|
||
"procedure", "sequence", "type", "domain", "role", "session", "current",
|
||
"authorization", "cross", "full", "natural", "some", "cast", "fetch",
|
||
"for", "to", "do", "if", "return", "returns", "declare", "cursor", "server", "owner":
|
||
return true
|
||
}
|
||
return false
|
||
}
|
||
|
||
func quoteKingbaseIdent(name string) string {
|
||
n := normalizeKingbaseIdentifier(name)
|
||
if n == "" {
|
||
return "\"\""
|
||
}
|
||
if !kingbaseIdentNeedsQuote(n) {
|
||
return n
|
||
}
|
||
n = strings.ReplaceAll(n, `"`, `""`)
|
||
return `"` + n + `"`
|
||
}
|
||
|
||
func splitKingbaseQualifiedTable(tableName string) (schema string, table string) {
|
||
return splitKingbaseQualifiedNameCommon(tableName)
|
||
}
|
||
|
||
func (k *KingbaseDB) GetAllColumns(dbName string) ([]connection.ColumnDefinitionWithTable, error) {
|
||
// dbName 在本项目语义里是“数据库”,schema 由 table_schema 决定;这里返回全部用户 schema 的列用于查询提示。
|
||
query := `
|
||
SELECT table_schema, table_name, column_name, data_type
|
||
FROM information_schema.columns
|
||
WHERE table_schema NOT IN ('pg_catalog', 'information_schema')
|
||
AND table_schema NOT LIKE 'pg_%'
|
||
ORDER BY table_schema, table_name, ordinal_position`
|
||
|
||
data, _, err := k.Query(query)
|
||
if err != nil {
|
||
return nil, err
|
||
}
|
||
|
||
var cols []connection.ColumnDefinitionWithTable
|
||
for _, row := range data {
|
||
schema := fmt.Sprintf("%v", row["table_schema"])
|
||
table := fmt.Sprintf("%v", row["table_name"])
|
||
tableName := table
|
||
if strings.TrimSpace(schema) != "" {
|
||
tableName = fmt.Sprintf("%s.%s", schema, table)
|
||
}
|
||
col := connection.ColumnDefinitionWithTable{
|
||
TableName: tableName,
|
||
Name: fmt.Sprintf("%v", row["column_name"]),
|
||
Type: fmt.Sprintf("%v", row["data_type"]),
|
||
}
|
||
cols = append(cols, col)
|
||
}
|
||
return cols, nil
|
||
}
|