mirror of
https://github.com/Syngnat/GoNavi.git
synced 2026-05-06 20:03:05 +08:00
- 前端连接表单新增额外连接参数入口,支持 URI query 格式录入与解析回填 - MySQL 兼容驱动支持 JDBC 常见参数映射,修复 UTF-8 字符集与 serverTimezone 兼容问题 - 扩展 Oracle、PostgreSQL 兼容、SQL Server、ClickHouse、MongoDB、达梦、TDengine 参数合并 - 按不同驱动通道处理 DSN、URI、Options 与 Settings,避免统一透传导致连接异常 - 修复编辑已保存连接时解析无认证 URI 会清空已有账号密码的问题 - 补充连接参数透传、缓存隔离、DSN 合并与 URI 回填回归测试
840 lines
22 KiB
Go
840 lines
22 KiB
Go
package db
|
||
|
||
import (
|
||
"context"
|
||
"database/sql"
|
||
"fmt"
|
||
"net"
|
||
"net/url"
|
||
"strconv"
|
||
"strings"
|
||
"time"
|
||
|
||
"GoNavi-Wails/internal/connection"
|
||
"GoNavi-Wails/internal/logger"
|
||
"GoNavi-Wails/internal/ssh"
|
||
"GoNavi-Wails/internal/utils"
|
||
|
||
_ "github.com/lib/pq"
|
||
)
|
||
|
||
type PostgresDB struct {
|
||
conn *sql.DB
|
||
pingTimeout time.Duration
|
||
forwarder *ssh.LocalForwarder // Store SSH tunnel forwarder
|
||
}
|
||
|
||
func resolvePostgresConnectDatabases(config connection.ConnectionConfig) []string {
|
||
explicit := strings.TrimSpace(config.Database)
|
||
if explicit != "" {
|
||
return []string{explicit}
|
||
}
|
||
|
||
candidates := []string{"postgres", "template1", strings.TrimSpace(config.User)}
|
||
seen := make(map[string]struct{}, len(candidates))
|
||
result := make([]string, 0, len(candidates))
|
||
for _, name := range candidates {
|
||
trimmed := strings.TrimSpace(name)
|
||
if trimmed == "" {
|
||
continue
|
||
}
|
||
normalized := strings.ToLower(trimmed)
|
||
if _, exists := seen[normalized]; exists {
|
||
continue
|
||
}
|
||
seen[normalized] = struct{}{}
|
||
result = append(result, trimmed)
|
||
}
|
||
return result
|
||
}
|
||
|
||
func (p *PostgresDB) getDSN(config connection.ConnectionConfig) string {
|
||
// postgres://user:password@host:port/dbname?sslmode=disable
|
||
dbname := config.Database
|
||
if dbname == "" {
|
||
dbname = "postgres" // Default DB
|
||
}
|
||
|
||
u := &url.URL{
|
||
Scheme: "postgres",
|
||
Host: net.JoinHostPort(config.Host, strconv.Itoa(config.Port)),
|
||
Path: "/" + dbname,
|
||
}
|
||
u.User = url.UserPassword(config.User, config.Password)
|
||
q := url.Values{}
|
||
q.Set("sslmode", resolvePostgresSSLMode(config))
|
||
q.Set("connect_timeout", strconv.Itoa(getConnectTimeoutSeconds(config)))
|
||
mergeConnectionParamsFromConfig(q, config, "postgres", "postgresql")
|
||
u.RawQuery = q.Encode()
|
||
|
||
return u.String()
|
||
}
|
||
|
||
func (p *PostgresDB) Connect(config connection.ConnectionConfig) error {
|
||
if supported, reason := DriverRuntimeSupportStatus("postgres"); !supported {
|
||
if strings.TrimSpace(reason) == "" {
|
||
reason = "PostgreSQL 纯 Go 驱动未启用,请先在驱动管理中安装启用"
|
||
}
|
||
return fmt.Errorf("%s", reason)
|
||
}
|
||
|
||
runConfig := config
|
||
p.pingTimeout = getConnectTimeout(config)
|
||
|
||
cleanupOnFailure := true
|
||
defer func() {
|
||
if !cleanupOnFailure {
|
||
return
|
||
}
|
||
if p.conn != nil {
|
||
_ = p.conn.Close()
|
||
p.conn = nil
|
||
}
|
||
if p.forwarder != nil {
|
||
_ = p.forwarder.Close()
|
||
p.forwarder = nil
|
||
}
|
||
}()
|
||
|
||
if config.UseSSH {
|
||
// Create SSH tunnel with local port forwarding
|
||
logger.Infof("PostgreSQL 使用 SSH 连接:地址=%s:%d 用户=%s", config.Host, config.Port, config.User)
|
||
|
||
forwarder, err := ssh.GetOrCreateLocalForwarder(config.SSH, config.Host, config.Port)
|
||
if err != nil {
|
||
return fmt.Errorf("创建 SSH 隧道失败:%w", err)
|
||
}
|
||
p.forwarder = forwarder
|
||
|
||
// Parse local address
|
||
host, portStr, err := net.SplitHostPort(forwarder.LocalAddr)
|
||
if err != nil {
|
||
return fmt.Errorf("解析本地转发地址失败:%w", err)
|
||
}
|
||
|
||
port, err := strconv.Atoi(portStr)
|
||
if err != nil {
|
||
return fmt.Errorf("解析本地端口失败:%w", err)
|
||
}
|
||
|
||
// Create a modified config pointing to local forwarder
|
||
localConfig := config
|
||
localConfig.Host = host
|
||
localConfig.Port = port
|
||
localConfig.UseSSH = false // Disable SSH flag for DSN generation
|
||
|
||
runConfig = localConfig
|
||
logger.Infof("PostgreSQL 通过本地端口转发连接:%s -> %s:%d", forwarder.LocalAddr, config.Host, config.Port)
|
||
}
|
||
|
||
sslAttempts := []connection.ConnectionConfig{runConfig}
|
||
if shouldTrySSLPreferredFallback(runConfig) {
|
||
sslAttempts = append(sslAttempts, withSSLDisabled(runConfig))
|
||
}
|
||
|
||
var failures []string
|
||
for sslIndex, sslConfig := range sslAttempts {
|
||
sslLabel := "SSL"
|
||
if sslIndex > 0 {
|
||
sslLabel = "明文回退"
|
||
}
|
||
|
||
attemptDBs := resolvePostgresConnectDatabases(sslConfig)
|
||
for _, dbName := range attemptDBs {
|
||
attemptConfig := sslConfig
|
||
attemptConfig.Database = dbName
|
||
dsn := p.getDSN(attemptConfig)
|
||
|
||
dbConn, err := sql.Open("postgres", dsn)
|
||
if err != nil {
|
||
failures = append(failures, fmt.Sprintf("%s 数据库=%s 打开连接失败: %v", sslLabel, dbName, err))
|
||
continue
|
||
}
|
||
p.conn = dbConn
|
||
|
||
// Force verification
|
||
if err := p.Ping(); err != nil {
|
||
failures = append(failures, fmt.Sprintf("%s 数据库=%s 验证失败: %v", sslLabel, dbName, err))
|
||
_ = dbConn.Close()
|
||
p.conn = nil
|
||
continue
|
||
}
|
||
|
||
if sslIndex > 0 {
|
||
logger.Warnf("PostgreSQL SSL 优先连接失败,已回退至明文连接")
|
||
}
|
||
if strings.TrimSpace(config.Database) == "" && !strings.EqualFold(dbName, "postgres") {
|
||
logger.Infof("PostgreSQL 自动选择连接数据库:%s", dbName)
|
||
}
|
||
|
||
// 设置 search_path,使所有用户 schema 下的表可以不带 schema 前缀访问
|
||
p.ensureSearchPath(dsn)
|
||
|
||
cleanupOnFailure = false
|
||
return nil
|
||
}
|
||
}
|
||
|
||
if len(failures) == 0 {
|
||
return fmt.Errorf("连接建立后验证失败:未找到可用的连接数据库")
|
||
}
|
||
return fmt.Errorf("连接建立后验证失败:%s", strings.Join(failures, ";"))
|
||
}
|
||
|
||
func (p *PostgresDB) Close() error {
|
||
// Close SSH forwarder first if exists
|
||
if p.forwarder != nil {
|
||
if err := p.forwarder.Close(); err != nil {
|
||
logger.Warnf("关闭 PostgreSQL SSH 端口转发失败:%v", err)
|
||
}
|
||
p.forwarder = nil
|
||
}
|
||
|
||
// Then close database connection
|
||
if p.conn != nil {
|
||
return p.conn.Close()
|
||
}
|
||
return nil
|
||
}
|
||
|
||
func (p *PostgresDB) Ping() error {
|
||
if p.conn == nil {
|
||
return fmt.Errorf("连接未打开")
|
||
}
|
||
timeout := p.pingTimeout
|
||
if timeout <= 0 {
|
||
timeout = 5 * time.Second
|
||
}
|
||
ctx, cancel := utils.ContextWithTimeout(timeout)
|
||
defer cancel()
|
||
return p.conn.PingContext(ctx)
|
||
}
|
||
|
||
func (p *PostgresDB) QueryContext(ctx context.Context, query string) ([]map[string]interface{}, []string, error) {
|
||
if p.conn == nil {
|
||
return nil, nil, fmt.Errorf("连接未打开")
|
||
}
|
||
|
||
rows, err := p.conn.QueryContext(ctx, query)
|
||
if err != nil {
|
||
return nil, nil, err
|
||
}
|
||
defer rows.Close()
|
||
|
||
return scanRows(rows)
|
||
}
|
||
|
||
func (p *PostgresDB) Query(query string) ([]map[string]interface{}, []string, error) {
|
||
if p.conn == nil {
|
||
return nil, nil, fmt.Errorf("连接未打开")
|
||
}
|
||
|
||
rows, err := p.conn.Query(query)
|
||
if err != nil {
|
||
return nil, nil, err
|
||
}
|
||
defer rows.Close()
|
||
return scanRows(rows)
|
||
}
|
||
|
||
func (p *PostgresDB) ExecBatchContext(ctx context.Context, query string) (int64, error) {
|
||
if p.conn == nil {
|
||
return 0, fmt.Errorf("连接未打开")
|
||
}
|
||
res, err := p.conn.ExecContext(ctx, query)
|
||
if err != nil {
|
||
return 0, err
|
||
}
|
||
return res.RowsAffected()
|
||
}
|
||
|
||
func (p *PostgresDB) ExecContext(ctx context.Context, query string) (int64, error) {
|
||
if p.conn == nil {
|
||
return 0, fmt.Errorf("连接未打开")
|
||
}
|
||
res, err := p.conn.ExecContext(ctx, query)
|
||
if err != nil {
|
||
return 0, err
|
||
}
|
||
return res.RowsAffected()
|
||
}
|
||
|
||
func (p *PostgresDB) Exec(query string) (int64, error) {
|
||
if p.conn == nil {
|
||
return 0, fmt.Errorf("连接未打开")
|
||
}
|
||
res, err := p.conn.Exec(query)
|
||
if err != nil {
|
||
return 0, err
|
||
}
|
||
return res.RowsAffected()
|
||
}
|
||
|
||
func (p *PostgresDB) GetDatabases() ([]string, error) {
|
||
data, _, err := p.Query("SELECT datname FROM pg_database WHERE datistemplate = false")
|
||
if err != nil {
|
||
return nil, err
|
||
}
|
||
var dbs []string
|
||
for _, row := range data {
|
||
if val, ok := row["datname"]; ok {
|
||
dbs = append(dbs, fmt.Sprintf("%v", val))
|
||
}
|
||
}
|
||
return dbs, nil
|
||
}
|
||
|
||
func (p *PostgresDB) GetTables(dbName string) ([]string, error) {
|
||
query := "SELECT schemaname, tablename FROM pg_catalog.pg_tables WHERE schemaname != 'information_schema' AND schemaname NOT LIKE 'pg_%' ORDER BY schemaname, tablename"
|
||
data, _, err := p.Query(query)
|
||
if err != nil {
|
||
return nil, err
|
||
}
|
||
|
||
var tables []string
|
||
for _, row := range data {
|
||
schema, okSchema := row["schemaname"]
|
||
name, okName := row["tablename"]
|
||
if okSchema && okName {
|
||
tables = append(tables, fmt.Sprintf("%v.%v", schema, name))
|
||
continue
|
||
}
|
||
if okName {
|
||
tables = append(tables, fmt.Sprintf("%v", name))
|
||
}
|
||
}
|
||
return tables, nil
|
||
}
|
||
|
||
func (p *PostgresDB) GetCreateStatement(dbName, tableName string) (string, error) {
|
||
return fmt.Sprintf("-- SHOW CREATE TABLE not fully supported for PostgreSQL in this MVP.\n-- Table: %s", tableName), nil
|
||
}
|
||
|
||
func (p *PostgresDB) GetColumns(dbName, tableName string) ([]connection.ColumnDefinition, error) {
|
||
schema := strings.TrimSpace(dbName)
|
||
if schema == "" {
|
||
schema = "public"
|
||
}
|
||
table := strings.TrimSpace(tableName)
|
||
if table == "" {
|
||
return nil, fmt.Errorf("表名不能为空")
|
||
}
|
||
|
||
esc := func(s string) string { return strings.ReplaceAll(s, "'", "''") }
|
||
|
||
query := fmt.Sprintf(`
|
||
SELECT
|
||
a.attname AS column_name,
|
||
pg_catalog.format_type(a.atttypid, a.atttypmod) AS data_type,
|
||
CASE WHEN a.attnotnull THEN 'NO' ELSE 'YES' END AS is_nullable,
|
||
pg_get_expr(ad.adbin, ad.adrelid) AS column_default,
|
||
col_description(a.attrelid, a.attnum) AS comment,
|
||
CASE WHEN pk.attname IS NOT NULL THEN 'PRI' ELSE '' END AS column_key
|
||
FROM pg_class c
|
||
JOIN pg_namespace n ON n.oid = c.relnamespace
|
||
JOIN pg_attribute a ON a.attrelid = c.oid
|
||
LEFT JOIN pg_attrdef ad ON ad.adrelid = c.oid AND ad.adnum = a.attnum
|
||
LEFT JOIN (
|
||
SELECT i.indrelid, a3.attname
|
||
FROM pg_index i
|
||
JOIN pg_attribute a3 ON a3.attrelid = i.indrelid AND a3.attnum = ANY(i.indkey)
|
||
WHERE i.indisprimary
|
||
) pk ON pk.indrelid = c.oid AND pk.attname = a.attname
|
||
WHERE c.relkind IN ('r', 'p')
|
||
AND n.nspname = '%s'
|
||
AND c.relname = '%s'
|
||
AND a.attnum > 0
|
||
AND NOT a.attisdropped
|
||
ORDER BY a.attnum`, esc(schema), esc(table))
|
||
|
||
data, _, err := p.Query(query)
|
||
if err != nil {
|
||
return nil, err
|
||
}
|
||
|
||
var columns []connection.ColumnDefinition
|
||
for _, row := range data {
|
||
col := connection.ColumnDefinition{
|
||
Name: fmt.Sprintf("%v", row["column_name"]),
|
||
Type: fmt.Sprintf("%v", row["data_type"]),
|
||
Nullable: fmt.Sprintf("%v", row["is_nullable"]),
|
||
Key: fmt.Sprintf("%v", row["column_key"]),
|
||
Extra: "",
|
||
Comment: "",
|
||
}
|
||
|
||
if v, ok := row["comment"]; ok && v != nil {
|
||
col.Comment = fmt.Sprintf("%v", v)
|
||
}
|
||
|
||
if v, ok := row["column_default"]; ok && v != nil {
|
||
def := fmt.Sprintf("%v", v)
|
||
col.Default = &def
|
||
if strings.HasPrefix(strings.ToLower(strings.TrimSpace(def)), "nextval(") {
|
||
col.Extra = "auto_increment"
|
||
}
|
||
}
|
||
|
||
columns = append(columns, col)
|
||
}
|
||
return columns, nil
|
||
}
|
||
|
||
func (p *PostgresDB) GetIndexes(dbName, tableName string) ([]connection.IndexDefinition, error) {
|
||
schema := strings.TrimSpace(dbName)
|
||
if schema == "" {
|
||
schema = "public"
|
||
}
|
||
table := strings.TrimSpace(tableName)
|
||
if table == "" {
|
||
return nil, fmt.Errorf("表名不能为空")
|
||
}
|
||
|
||
esc := func(s string) string { return strings.ReplaceAll(s, "'", "''") }
|
||
|
||
query := fmt.Sprintf(`
|
||
SELECT
|
||
i.relname AS index_name,
|
||
a.attname AS column_name,
|
||
ix.indisunique AS is_unique,
|
||
x.ordinality AS seq_in_index,
|
||
am.amname AS index_type
|
||
FROM pg_class t
|
||
JOIN pg_namespace n ON n.oid = t.relnamespace
|
||
JOIN pg_index ix ON t.oid = ix.indrelid
|
||
JOIN pg_class i ON i.oid = ix.indexrelid
|
||
JOIN pg_am am ON i.relam = am.oid
|
||
JOIN unnest(ix.indkey) WITH ORDINALITY AS x(attnum, ordinality) ON TRUE
|
||
JOIN pg_attribute a ON a.attrelid = t.oid AND a.attnum = x.attnum
|
||
WHERE t.relkind IN ('r', 'p')
|
||
AND t.relname = '%s'
|
||
AND n.nspname = '%s'
|
||
AND ix.indisvalid
|
||
AND ix.indpred IS NULL
|
||
AND x.ordinality <= ix.indnkeyatts
|
||
AND NOT EXISTS (
|
||
SELECT 1 FROM unnest(ix.indkey) AS expr_key(attnum) WHERE expr_key.attnum <= 0
|
||
)
|
||
ORDER BY i.relname, x.ordinality`, esc(table), esc(schema))
|
||
|
||
data, _, err := p.Query(query)
|
||
if err != nil {
|
||
return nil, err
|
||
}
|
||
|
||
parseBool := func(v interface{}) bool {
|
||
switch val := v.(type) {
|
||
case bool:
|
||
return val
|
||
case string:
|
||
s := strings.ToLower(strings.TrimSpace(val))
|
||
return s == "t" || s == "true" || s == "1" || s == "y" || s == "yes"
|
||
default:
|
||
s := strings.ToLower(strings.TrimSpace(fmt.Sprintf("%v", v)))
|
||
return s == "t" || s == "true" || s == "1" || s == "y" || s == "yes"
|
||
}
|
||
}
|
||
|
||
parseInt := func(v interface{}) int {
|
||
switch val := v.(type) {
|
||
case int:
|
||
return val
|
||
case int64:
|
||
return int(val)
|
||
case float64:
|
||
return int(val)
|
||
case string:
|
||
// best effort
|
||
var n int
|
||
_, _ = fmt.Sscanf(strings.TrimSpace(val), "%d", &n)
|
||
return n
|
||
default:
|
||
var n int
|
||
_, _ = fmt.Sscanf(strings.TrimSpace(fmt.Sprintf("%v", v)), "%d", &n)
|
||
return n
|
||
}
|
||
}
|
||
|
||
var indexes []connection.IndexDefinition
|
||
for _, row := range data {
|
||
isUnique := false
|
||
if v, ok := row["is_unique"]; ok && v != nil {
|
||
isUnique = parseBool(v)
|
||
}
|
||
|
||
nonUnique := 1
|
||
if isUnique {
|
||
nonUnique = 0
|
||
}
|
||
|
||
seq := 0
|
||
if v, ok := row["seq_in_index"]; ok && v != nil {
|
||
seq = parseInt(v)
|
||
}
|
||
|
||
indexType := ""
|
||
if v, ok := row["index_type"]; ok && v != nil {
|
||
indexType = strings.ToUpper(fmt.Sprintf("%v", v))
|
||
}
|
||
if indexType == "" {
|
||
indexType = "BTREE"
|
||
}
|
||
|
||
idx := connection.IndexDefinition{
|
||
Name: fmt.Sprintf("%v", row["index_name"]),
|
||
ColumnName: fmt.Sprintf("%v", row["column_name"]),
|
||
NonUnique: nonUnique,
|
||
SeqInIndex: seq,
|
||
IndexType: indexType,
|
||
}
|
||
indexes = append(indexes, idx)
|
||
}
|
||
return indexes, nil
|
||
}
|
||
|
||
func (p *PostgresDB) GetForeignKeys(dbName, tableName string) ([]connection.ForeignKeyDefinition, error) {
|
||
schema := strings.TrimSpace(dbName)
|
||
if schema == "" {
|
||
schema = "public"
|
||
}
|
||
table := strings.TrimSpace(tableName)
|
||
if table == "" {
|
||
return nil, fmt.Errorf("表名不能为空")
|
||
}
|
||
|
||
esc := func(s string) string { return strings.ReplaceAll(s, "'", "''") }
|
||
|
||
query := fmt.Sprintf(`
|
||
SELECT
|
||
tc.constraint_name AS constraint_name,
|
||
kcu.column_name AS column_name,
|
||
ccu.table_schema AS foreign_table_schema,
|
||
ccu.table_name AS foreign_table_name,
|
||
ccu.column_name AS foreign_column_name
|
||
FROM information_schema.table_constraints AS tc
|
||
JOIN information_schema.key_column_usage AS kcu
|
||
ON tc.constraint_name = kcu.constraint_name
|
||
AND tc.table_schema = kcu.table_schema
|
||
JOIN information_schema.constraint_column_usage AS ccu
|
||
ON ccu.constraint_name = tc.constraint_name
|
||
AND ccu.table_schema = tc.table_schema
|
||
WHERE tc.constraint_type = 'FOREIGN KEY'
|
||
AND tc.table_name = '%s'
|
||
AND tc.table_schema = '%s'
|
||
ORDER BY tc.constraint_name, kcu.ordinal_position`, esc(table), esc(schema))
|
||
|
||
data, _, err := p.Query(query)
|
||
if err != nil {
|
||
return nil, err
|
||
}
|
||
|
||
var fks []connection.ForeignKeyDefinition
|
||
for _, row := range data {
|
||
refSchema := ""
|
||
if v, ok := row["foreign_table_schema"]; ok && v != nil {
|
||
refSchema = fmt.Sprintf("%v", v)
|
||
}
|
||
refTable := fmt.Sprintf("%v", row["foreign_table_name"])
|
||
refTableName := refTable
|
||
if strings.TrimSpace(refSchema) != "" {
|
||
refTableName = fmt.Sprintf("%s.%s", refSchema, refTable)
|
||
}
|
||
|
||
fk := connection.ForeignKeyDefinition{
|
||
Name: fmt.Sprintf("%v", row["constraint_name"]),
|
||
ColumnName: fmt.Sprintf("%v", row["column_name"]),
|
||
RefTableName: refTableName,
|
||
RefColumnName: fmt.Sprintf("%v", row["foreign_column_name"]),
|
||
ConstraintName: fmt.Sprintf("%v", row["constraint_name"]),
|
||
}
|
||
fks = append(fks, fk)
|
||
}
|
||
return fks, nil
|
||
}
|
||
|
||
func (p *PostgresDB) GetTriggers(dbName, tableName string) ([]connection.TriggerDefinition, error) {
|
||
schema := strings.TrimSpace(dbName)
|
||
if schema == "" {
|
||
schema = "public"
|
||
}
|
||
table := strings.TrimSpace(tableName)
|
||
if table == "" {
|
||
return nil, fmt.Errorf("表名不能为空")
|
||
}
|
||
|
||
esc := func(s string) string { return strings.ReplaceAll(s, "'", "''") }
|
||
|
||
query := fmt.Sprintf(`
|
||
SELECT trigger_name, action_timing, event_manipulation, action_statement
|
||
FROM information_schema.triggers
|
||
WHERE event_object_table = '%s'
|
||
AND event_object_schema = '%s'
|
||
ORDER BY trigger_name, event_manipulation`, esc(table), esc(schema))
|
||
|
||
data, _, err := p.Query(query)
|
||
if err != nil {
|
||
return nil, err
|
||
}
|
||
|
||
var triggers []connection.TriggerDefinition
|
||
for _, row := range data {
|
||
trig := connection.TriggerDefinition{
|
||
Name: fmt.Sprintf("%v", row["trigger_name"]),
|
||
Timing: fmt.Sprintf("%v", row["action_timing"]),
|
||
Event: fmt.Sprintf("%v", row["event_manipulation"]),
|
||
Statement: fmt.Sprintf("%v", row["action_statement"]),
|
||
}
|
||
triggers = append(triggers, trig)
|
||
}
|
||
return triggers, nil
|
||
}
|
||
|
||
func (p *PostgresDB) GetAllColumns(dbName string) ([]connection.ColumnDefinitionWithTable, error) {
|
||
query := `
|
||
SELECT table_schema, table_name, column_name, data_type
|
||
FROM information_schema.columns
|
||
WHERE table_schema NOT IN ('pg_catalog', 'information_schema')
|
||
AND table_schema NOT LIKE 'pg_%'
|
||
ORDER BY table_schema, table_name, ordinal_position`
|
||
|
||
data, _, err := p.Query(query)
|
||
if err != nil {
|
||
return nil, err
|
||
}
|
||
|
||
var cols []connection.ColumnDefinitionWithTable
|
||
for _, row := range data {
|
||
schema := fmt.Sprintf("%v", row["table_schema"])
|
||
table := fmt.Sprintf("%v", row["table_name"])
|
||
tableName := table
|
||
if strings.TrimSpace(schema) != "" {
|
||
tableName = fmt.Sprintf("%s.%s", schema, table)
|
||
}
|
||
|
||
col := connection.ColumnDefinitionWithTable{
|
||
TableName: tableName,
|
||
Name: fmt.Sprintf("%v", row["column_name"]),
|
||
Type: fmt.Sprintf("%v", row["data_type"]),
|
||
}
|
||
cols = append(cols, col)
|
||
}
|
||
return cols, nil
|
||
}
|
||
|
||
// ensureSearchPath 查询当前数据库中所有用户 schema,通过重建连接池将 search_path 写入 DSN。
|
||
// 仅使用 SET search_path 只对连接池中的单个连接生效,后续查询可能拿到未设置的连接。
|
||
// 将 search_path 写入 DSN (lib/pq 支持任意 PostgreSQL runtime parameter),
|
||
// 使连接池中每个连接建立时自动携带 search_path,与金仓行为一致。
|
||
func (p *PostgresDB) ensureSearchPath(baseDSN string) {
|
||
if p.conn == nil {
|
||
return
|
||
}
|
||
|
||
rawSchemas := p.queryUserSchemas()
|
||
if len(rawSchemas) == 0 {
|
||
return
|
||
}
|
||
|
||
// 构建 search_path SQL 片段(带双引号转义),用于 SET 兜底
|
||
searchPathSQL, normalizedSchemas := buildKingbaseSearchPathCommon(rawSchemas)
|
||
if strings.TrimSpace(searchPathSQL) == "" {
|
||
return
|
||
}
|
||
|
||
// 策略 1:将 search_path 写入 DSN,重建连接池
|
||
// lib/pq 支持在 URL 查参数中设置任意 PostgreSQL runtime parameter,
|
||
// 如 ?search_path=ce,public,每个新连接建立时会自动 SET search_path。
|
||
searchPathDSNVal := strings.Join(normalizedSchemas, ",")
|
||
u, parseErr := url.Parse(baseDSN)
|
||
if parseErr == nil {
|
||
q := u.Query()
|
||
q.Set("search_path", searchPathDSNVal)
|
||
u.RawQuery = q.Encode()
|
||
newDSN := u.String()
|
||
|
||
newDB, err := sql.Open("postgres", newDSN)
|
||
if err == nil {
|
||
newDB.SetConnMaxLifetime(5 * time.Minute)
|
||
oldConn := p.conn
|
||
p.conn = newDB
|
||
if err := p.Ping(); err == nil {
|
||
_ = oldConn.Close()
|
||
logger.Infof("PostgreSQL 已通过 DSN 配置 search_path:%s", searchPathDSNVal)
|
||
return
|
||
}
|
||
// DSN 方式失败,回滚
|
||
_ = newDB.Close()
|
||
p.conn = oldConn
|
||
logger.Warnf("PostgreSQL DSN search_path 验证失败,回退至 SET 方式")
|
||
}
|
||
}
|
||
|
||
// 策略 2 兜底:通过 SET search_path 设置(仅影响单个连接,但聊胜于无)
|
||
timeout := p.pingTimeout
|
||
if timeout <= 0 {
|
||
timeout = 5 * time.Second
|
||
}
|
||
ctx, cancel := utils.ContextWithTimeout(timeout)
|
||
defer cancel()
|
||
|
||
if _, err := p.conn.ExecContext(ctx, fmt.Sprintf("SET search_path TO %s", searchPathSQL)); err != nil {
|
||
logger.Warnf("PostgreSQL 设置 search_path 失败:%v", err)
|
||
return
|
||
}
|
||
logger.Infof("PostgreSQL 已通过 SET 设置 search_path:%s", searchPathSQL)
|
||
}
|
||
|
||
// queryUserSchemas 查询当前数据库中所有用户 schema。
|
||
func (p *PostgresDB) queryUserSchemas() []string {
|
||
if p.conn == nil {
|
||
return nil
|
||
}
|
||
|
||
query := `SELECT nspname FROM pg_namespace
|
||
WHERE nspname NOT IN ('pg_catalog', 'information_schema')
|
||
AND nspname NOT LIKE 'pg_%'
|
||
ORDER BY nspname`
|
||
|
||
rows, err := p.conn.Query(query)
|
||
if err != nil {
|
||
logger.Warnf("PostgreSQL 查询用户 schema 失败:%v", err)
|
||
return nil
|
||
}
|
||
defer rows.Close()
|
||
|
||
var schemas []string
|
||
for rows.Next() {
|
||
var name string
|
||
if err := rows.Scan(&name); err != nil {
|
||
continue
|
||
}
|
||
name = strings.TrimSpace(name)
|
||
if name != "" {
|
||
schemas = append(schemas, name)
|
||
}
|
||
}
|
||
return schemas
|
||
}
|
||
|
||
func (p *PostgresDB) ApplyChanges(tableName string, changes connection.ChangeSet) error {
|
||
if p.conn == nil {
|
||
return fmt.Errorf("连接未打开")
|
||
}
|
||
|
||
tx, err := p.conn.Begin()
|
||
if err != nil {
|
||
return err
|
||
}
|
||
defer tx.Rollback()
|
||
|
||
quoteIdent := func(name string) string {
|
||
n := strings.TrimSpace(name)
|
||
n = strings.Trim(n, "\"")
|
||
n = strings.ReplaceAll(n, "\"", "\"\"")
|
||
if n == "" {
|
||
return "\"\""
|
||
}
|
||
return `"` + n + `"`
|
||
}
|
||
|
||
schema := ""
|
||
table := strings.TrimSpace(tableName)
|
||
if parts := strings.SplitN(table, ".", 2); len(parts) == 2 {
|
||
schema = strings.TrimSpace(parts[0])
|
||
table = strings.TrimSpace(parts[1])
|
||
}
|
||
|
||
qualifiedTable := ""
|
||
if schema != "" {
|
||
qualifiedTable = fmt.Sprintf("%s.%s", quoteIdent(schema), quoteIdent(table))
|
||
} else {
|
||
qualifiedTable = quoteIdent(table)
|
||
}
|
||
|
||
// 1. Deletes
|
||
for _, pk := range changes.Deletes {
|
||
var wheres []string
|
||
var args []interface{}
|
||
idx := 0
|
||
for k, v := range pk {
|
||
idx++
|
||
wheres = append(wheres, fmt.Sprintf("%s = $%d", quoteIdent(k), idx))
|
||
args = append(args, v)
|
||
}
|
||
if len(wheres) == 0 {
|
||
continue
|
||
}
|
||
query := fmt.Sprintf("DELETE FROM %s WHERE %s", qualifiedTable, strings.Join(wheres, " AND "))
|
||
res, err := tx.Exec(query, args...)
|
||
if err != nil {
|
||
return fmt.Errorf("删除失败:%v", err)
|
||
}
|
||
if err := requireSingleRowAffected(res, "删除"); err != nil {
|
||
return err
|
||
}
|
||
}
|
||
|
||
// 2. Updates
|
||
for _, update := range changes.Updates {
|
||
var sets []string
|
||
var args []interface{}
|
||
idx := 0
|
||
|
||
for k, v := range update.Values {
|
||
idx++
|
||
sets = append(sets, fmt.Sprintf("%s = $%d", quoteIdent(k), idx))
|
||
args = append(args, v)
|
||
}
|
||
|
||
if len(sets) == 0 {
|
||
continue
|
||
}
|
||
|
||
var wheres []string
|
||
for k, v := range update.Keys {
|
||
idx++
|
||
wheres = append(wheres, fmt.Sprintf("%s = $%d", quoteIdent(k), idx))
|
||
args = append(args, v)
|
||
}
|
||
|
||
if len(wheres) == 0 {
|
||
return fmt.Errorf("更新操作需要主键条件")
|
||
}
|
||
|
||
query := fmt.Sprintf("UPDATE %s SET %s WHERE %s", qualifiedTable, strings.Join(sets, ", "), strings.Join(wheres, " AND "))
|
||
res, err := tx.Exec(query, args...)
|
||
if err != nil {
|
||
return fmt.Errorf("更新失败:%v", err)
|
||
}
|
||
if err := requireSingleRowAffected(res, "更新"); err != nil {
|
||
return err
|
||
}
|
||
}
|
||
|
||
// 3. Inserts
|
||
for _, row := range changes.Inserts {
|
||
var cols []string
|
||
var placeholders []string
|
||
var args []interface{}
|
||
idx := 0
|
||
|
||
for k, v := range row {
|
||
idx++
|
||
cols = append(cols, quoteIdent(k))
|
||
placeholders = append(placeholders, fmt.Sprintf("$%d", idx))
|
||
args = append(args, v)
|
||
}
|
||
|
||
if len(cols) == 0 {
|
||
continue
|
||
}
|
||
|
||
query := fmt.Sprintf("INSERT INTO %s (%s) VALUES (%s)", qualifiedTable, strings.Join(cols, ", "), strings.Join(placeholders, ", "))
|
||
if _, err := tx.Exec(query, args...); err != nil {
|
||
return fmt.Errorf("插入失败:%v", err)
|
||
}
|
||
}
|
||
|
||
return tx.Commit()
|
||
}
|