♻️ refactor(database/ssh): SSH隧道架构重构与多数据源适配

- 架构升级:从driver专属拨号器改为通用本地端口转发模式
  - 并发安全:sync.Once保护Close操作,RWMutex保护状态访问,双向errc等待
  - 连接池化:GetOrCreateLocalForwarder/GetOrCreateSSHClient实现缓存复用
  - SQL安全:kingbase_impl.go引入esc函数,防止双引号注入(""ldf_server""问题)
  - Schema动态化:三级fallback(schema.table解析→dbName参数→current_schema())
  - 代码复用:scanRows统一行扫描逻辑,normalizeQueryValueWithDBType增强类型处理
  Close #40
This commit is contained in:
Syngnat
2026-02-04 14:35:31 +08:00
parent d8656c6c9c
commit 71e5de0cdc
15 changed files with 879 additions and 325 deletions

View File

@@ -11,6 +11,7 @@ import (
"time"
"GoNavi-Wails/internal/connection"
"GoNavi-Wails/internal/logger"
"GoNavi-Wails/internal/ssh"
"GoNavi-Wails/internal/utils"
@@ -20,6 +21,7 @@ import (
type OracleDB struct {
conn *sql.DB
pingTimeout time.Duration
forwarder *ssh.LocalForwarder // Store SSH tunnel forwarder
}
func (o *OracleDB) getDSN(config connection.ConnectionConfig) string {
@@ -29,28 +31,6 @@ func (o *OracleDB) getDSN(config connection.ConnectionConfig) string {
database = config.User // Default to user service/schema if empty?
}
if config.UseSSH {
_, err := ssh.RegisterSSHNetwork(config.SSH)
if err == nil {
// Oracle driver might not support custom dialer via DSN easily without extra config
// But go-ora v2 supports some advanced options.
// For simplicity, we assume standard TCP or we might need a workaround for SSH.
// go-ora v2 is pure Go, so we can potentially use a custom dialer if we manually open.
// But for now, let's just use the address.
// SSH tunneling via net.Dialer override is complex in sql.Open("oracle", ...).
// We might need to forward a local port if using SSH.
// Since ssh.RegisterSSHNetwork creates a custom network "ssh-via-...",
// we need to see if go-ora supports custom networks.
// Checking go-ora docs (simulated): It supports "unix" and "tcp".
// We might need to map the custom network to a local proxy.
// For now, we will assume direct connection or handle SSH separately later.
// We'll leave the protocol implementation as is in MySQL for now, hoping go-ora uses standard net.Dial.
// Note: go-ora connection string: oracle://user:pass@host:port/service
// It parses host/port. It doesn't easily take a custom "network" parameter in URL.
// We will proceed with standard TCP string.
}
}
u := &url.URL{
Scheme: "oracle",
Host: net.JoinHostPort(config.Host, strconv.Itoa(config.Port)),
@@ -62,7 +42,42 @@ func (o *OracleDB) getDSN(config connection.ConnectionConfig) string {
}
func (o *OracleDB) Connect(config connection.ConnectionConfig) error {
dsn := o.getDSN(config)
var dsn string
var err error
if config.UseSSH {
// Create SSH tunnel with local port forwarding
logger.Infof("Oracle 使用 SSH 连接:地址=%s:%d 用户=%s", config.Host, config.Port, config.User)
forwarder, err := ssh.GetOrCreateLocalForwarder(config.SSH, config.Host, config.Port)
if err != nil {
return fmt.Errorf("创建 SSH 隧道失败:%w", err)
}
o.forwarder = forwarder
// Parse local address
host, portStr, err := net.SplitHostPort(forwarder.LocalAddr)
if err != nil {
return fmt.Errorf("解析本地转发地址失败:%w", err)
}
port, err := strconv.Atoi(portStr)
if err != nil {
return fmt.Errorf("解析本地端口失败:%w", err)
}
// Create a modified config pointing to local forwarder
localConfig := config
localConfig.Host = host
localConfig.Port = port
localConfig.UseSSH = false
dsn = o.getDSN(localConfig)
logger.Infof("Oracle 通过本地端口转发连接:%s -> %s:%d", forwarder.LocalAddr, config.Host, config.Port)
} else {
dsn = o.getDSN(config)
}
db, err := sql.Open("oracle", dsn)
if err != nil {
return fmt.Errorf("打开数据库连接失败:%w", err)
@@ -76,6 +91,15 @@ func (o *OracleDB) Connect(config connection.ConnectionConfig) error {
}
func (o *OracleDB) Close() error {
// Close SSH forwarder first if exists
if o.forwarder != nil {
if err := o.forwarder.Close(); err != nil {
logger.Warnf("关闭 Oracle SSH 端口转发失败:%v", err)
}
o.forwarder = nil
}
// Then close database connection
if o.conn != nil {
return o.conn.Close()
}
@@ -119,33 +143,7 @@ func (o *OracleDB) Query(query string) ([]map[string]interface{}, []string, erro
return nil, nil, err
}
defer rows.Close()
columns, err := rows.Columns()
if err != nil {
return nil, nil, err
}
var resultData []map[string]interface{}
for rows.Next() {
values := make([]interface{}, len(columns))
valuePtrs := make([]interface{}, len(columns))
for i := range columns {
valuePtrs[i] = &values[i]
}
if err := rows.Scan(valuePtrs...); err != nil {
continue
}
entry := make(map[string]interface{})
for i, col := range columns {
entry[col] = normalizeQueryValue(values[i])
}
resultData = append(resultData, entry)
}
return resultData, columns, nil
return scanRows(rows)
}
func (o *OracleDB) ExecContext(ctx context.Context, query string) (int64, error) {