diff --git a/frontend/src/components/ConnectionModal.tsx b/frontend/src/components/ConnectionModal.tsx
index 8fa0bb9..e06a834 100644
--- a/frontend/src/components/ConnectionModal.tsx
+++ b/frontend/src/components/ConnectionModal.tsx
@@ -264,8 +264,8 @@ const ConnectionModal: React.FC<{ open: boolean; onClose: () => void; initialVal
{useSSH && (
-
-
+
+
diff --git a/internal/app/app.go b/internal/app/app.go
index 622bfc2..c61169e 100644
--- a/internal/app/app.go
+++ b/internal/app/app.go
@@ -141,9 +141,6 @@ func (a *App) getDatabase(config connection.ConnectionConfig) (db.Database, erro
if len(shortKey) > 12 {
shortKey = shortKey[:12]
}
- if config.UseSSH && config.Type != "mysql" {
- logger.Warnf("当前仅 MySQL 支持内置 SSH 直连,其他类型请使用本地端口转发:%s", formatConnSummary(config))
- }
logger.Infof("获取数据库连接:%s 缓存Key=%s", formatConnSummary(config), shortKey)
a.mu.Lock()
diff --git a/internal/app/sql_sanitize.go b/internal/app/sql_sanitize.go
index 9ba89e3..4e37ed5 100644
--- a/internal/app/sql_sanitize.go
+++ b/internal/app/sql_sanitize.go
@@ -8,14 +8,26 @@ import (
func sanitizeSQLForPgLike(dbType string, query string) string {
switch strings.ToLower(strings.TrimSpace(dbType)) {
case "postgres", "kingbase":
- return fixBrokenDoubleDoubleQuotedIdent(query)
+ // 有些情况下会出现多层重复引用(例如 """"schema"""" 或 ""schema"""),单次修复不一定收敛。
+ // 这里做有限次数的迭代,直到输出不再变化。
+ out := query
+ for i := 0; i < 3; i++ {
+ fixed := fixBrokenDoubleDoubleQuotedIdent(out)
+ if fixed == out {
+ break
+ }
+ out = fixed
+ }
+ return out
default:
return query
}
}
// fixBrokenDoubleDoubleQuotedIdent fixes accidental identifiers like:
-// SELECT * FROM ""schema"".""table""
+//
+// SELECT * FROM ""schema"".""table""
+//
// which can be produced when a quoted identifier gets wrapped by quotes again.
//
// It is intentionally conservative:
@@ -124,20 +136,17 @@ func fixBrokenDoubleDoubleQuotedIdent(query string) string {
}
}
- // Fix: ""ident"" -> "ident" (only when it looks like a plain identifier)
- if ch == '"' && next == '"' {
- prevIsQuote := i > 0 && query[i-1] == '"'
- nextIsQuote := i+2 < len(query) && query[i+2] == '"'
- if !prevIsQuote && !nextIsQuote {
+ if ch == '"' {
+ // Fix: ""ident"" -> "ident" (only when it looks like a plain identifier)
+ // Also handle variants like ""ident""" / """"ident"""" (extra quotes at either side).
+ if next == '"' {
if replacement, advance, ok := tryFixDoubleDoubleQuotedIdent(query, i); ok {
b.WriteString(replacement)
i = advance - 1
continue
}
}
- }
- if ch == '"' {
b.WriteByte(ch)
inDoubleIdent = true
continue
@@ -150,7 +159,8 @@ func fixBrokenDoubleDoubleQuotedIdent(query string) string {
}
func tryFixDoubleDoubleQuotedIdent(query string, start int) (replacement string, advance int, ok bool) {
- // start points at the first quote of `""...""`
+ // start points at the first quote of a broken identifier, usually like:
+ // ""ident"" / ""ident""" / """"ident""""
if start < 0 || start+1 >= len(query) {
return "", 0, false
}
@@ -160,24 +170,31 @@ func tryFixDoubleDoubleQuotedIdent(query string, start int) (replacement string,
if start > 0 && query[start-1] == '"' {
return "", 0, false
}
- if start+2 < len(query) && query[start+2] == '"' {
+
+ runLen := 0
+ for start+runLen < len(query) && query[start+runLen] == '"' {
+ runLen++
+ }
+ if runLen < 2 || runLen%2 == 1 {
+ // Odd run (e.g. """...) can be a valid quoted identifier with escaped quotes.
return "", 0, false
}
- contentStart := start + 2
+ contentStart := start + runLen
j := contentStart
- for j+1 < len(query) {
- if query[j] == '"' && query[j+1] == '"' {
- // ensure closing pair is not part of a triple quote
- if j+2 < len(query) && query[j+2] == '"' {
- j++
- continue
+ for j < len(query) {
+ if query[j] == '"' {
+ endRunLen := 0
+ for j+endRunLen < len(query) && query[j+endRunLen] == '"' {
+ endRunLen++
}
- content := strings.TrimSpace(query[contentStart:j])
- if looksLikeIdentifierContent(content) {
- return `"` + content + `"`, j + 2, true
+ if endRunLen >= 2 {
+ content := strings.TrimSpace(query[contentStart:j])
+ if looksLikeIdentifierContent(content) {
+ return `"` + content + `"`, j + endRunLen, true
+ }
+ return "", 0, false
}
- return "", 0, false
}
// Fast abort: identifier-like content should not span lines.
if query[j] == '\n' || query[j] == '\r' {
diff --git a/internal/app/sql_sanitize_test.go b/internal/app/sql_sanitize_test.go
index fbee1f6..426ddda 100644
--- a/internal/app/sql_sanitize_test.go
+++ b/internal/app/sql_sanitize_test.go
@@ -11,6 +11,24 @@ func TestSanitizeSQLForPgLike_FixesBrokenDoubleDoubleQuotes(t *testing.T) {
}
}
+func TestSanitizeSQLForPgLike_FixesBrokenDoubleDoubleQuotes_WithExtraQuotes(t *testing.T) {
+ in := `SELECT * FROM ""ldf_server""".""t_user"" LIMIT 1`
+ out := sanitizeSQLForPgLike("kingbase", in)
+ want := `SELECT * FROM "ldf_server"."t_user" LIMIT 1`
+ if out != want {
+ t.Fatalf("unexpected sanitize output:\nIN: %s\nOUT: %s\nWANT: %s", in, out, want)
+ }
+}
+
+func TestSanitizeSQLForPgLike_FixesBrokenDoubleDoubleQuotes_WithQuadQuotes(t *testing.T) {
+ in := `SELECT * FROM """"ldf_server"""".""t_user"" LIMIT 1`
+ out := sanitizeSQLForPgLike("kingbase", in)
+ want := `SELECT * FROM "ldf_server"."t_user" LIMIT 1`
+ if out != want {
+ t.Fatalf("unexpected sanitize output:\nIN: %s\nOUT: %s\nWANT: %s", in, out, want)
+ }
+}
+
func TestSanitizeSQLForPgLike_DoesNotTouchEscapedQuotesInsideIdentifier(t *testing.T) {
in := `SELECT "a""b" FROM "t""x"`
out := sanitizeSQLForPgLike("postgres", in)
diff --git a/internal/db/custom_impl.go b/internal/db/custom_impl.go
index 495ff95..bc9f7e9 100644
--- a/internal/db/custom_impl.go
+++ b/internal/db/custom_impl.go
@@ -82,33 +82,7 @@ func (c *CustomDB) Query(query string) ([]map[string]interface{}, []string, erro
return nil, nil, err
}
defer rows.Close()
-
- columns, err := rows.Columns()
- if err != nil {
- return nil, nil, err
- }
-
- var resultData []map[string]interface{}
-
- for rows.Next() {
- values := make([]interface{}, len(columns))
- valuePtrs := make([]interface{}, len(columns))
- for i := range columns {
- valuePtrs[i] = &values[i]
- }
-
- if err := rows.Scan(valuePtrs...); err != nil {
- continue
- }
-
- entry := make(map[string]interface{})
- for i, col := range columns {
- entry[col] = normalizeQueryValue(values[i])
- }
- resultData = append(resultData, entry)
- }
-
- return resultData, columns, nil
+ return scanRows(rows)
}
func (c *CustomDB) ExecContext(ctx context.Context, query string) (int64, error) {
diff --git a/internal/db/dameng_impl.go b/internal/db/dameng_impl.go
index ce19473..52a263f 100644
--- a/internal/db/dameng_impl.go
+++ b/internal/db/dameng_impl.go
@@ -11,6 +11,7 @@ import (
"time"
"GoNavi-Wails/internal/connection"
+ "GoNavi-Wails/internal/logger"
"GoNavi-Wails/internal/ssh"
"GoNavi-Wails/internal/utils"
@@ -20,6 +21,7 @@ import (
type DamengDB struct {
conn *sql.DB
pingTimeout time.Duration
+ forwarder *ssh.LocalForwarder // Store SSH tunnel forwarder
}
func (d *DamengDB) getDSN(config connection.ConnectionConfig) string {
@@ -27,16 +29,6 @@ func (d *DamengDB) getDSN(config connection.ConnectionConfig) string {
// or dm://user:password@host:port
address := net.JoinHostPort(config.Host, strconv.Itoa(config.Port))
- if config.UseSSH {
- // SSH logic similar to others, assumes port forwarding
- _, err := ssh.RegisterSSHNetwork(config.SSH)
- if err == nil {
- // DM driver likely uses standard net.Dial, so we might need a local listener
- // or assume port forwarding is handled externally or implicitly via "tcp" override if driver allows.
- // Similar to Oracle, we skip complex custom dialer injection for now.
- }
- }
-
escapedPassword := url.PathEscape(config.Password)
q := url.Values{}
if config.Database != "" {
@@ -56,7 +48,42 @@ func (d *DamengDB) getDSN(config connection.ConnectionConfig) string {
}
func (d *DamengDB) Connect(config connection.ConnectionConfig) error {
- dsn := d.getDSN(config)
+ var dsn string
+ var err error
+
+ if config.UseSSH {
+ // Create SSH tunnel with local port forwarding
+ logger.Infof("达梦数据库使用 SSH 连接:地址=%s:%d 用户=%s", config.Host, config.Port, config.User)
+
+ forwarder, err := ssh.GetOrCreateLocalForwarder(config.SSH, config.Host, config.Port)
+ if err != nil {
+ return fmt.Errorf("创建 SSH 隧道失败:%w", err)
+ }
+ d.forwarder = forwarder
+
+ // Parse local address
+ host, portStr, err := net.SplitHostPort(forwarder.LocalAddr)
+ if err != nil {
+ return fmt.Errorf("解析本地转发地址失败:%w", err)
+ }
+
+ port, err := strconv.Atoi(portStr)
+ if err != nil {
+ return fmt.Errorf("解析本地端口失败:%w", err)
+ }
+
+ // Create a modified config pointing to local forwarder
+ localConfig := config
+ localConfig.Host = host
+ localConfig.Port = port
+ localConfig.UseSSH = false
+
+ dsn = d.getDSN(localConfig)
+ logger.Infof("达梦数据库通过本地端口转发连接:%s -> %s:%d", forwarder.LocalAddr, config.Host, config.Port)
+ } else {
+ dsn = d.getDSN(config)
+ }
+
db, err := sql.Open("dm", dsn)
if err != nil {
return fmt.Errorf("打开数据库连接失败:%w", err)
@@ -70,6 +97,15 @@ func (d *DamengDB) Connect(config connection.ConnectionConfig) error {
}
func (d *DamengDB) Close() error {
+ // Close SSH forwarder first if exists
+ if d.forwarder != nil {
+ if err := d.forwarder.Close(); err != nil {
+ logger.Warnf("关闭达梦数据库 SSH 端口转发失败:%v", err)
+ }
+ d.forwarder = nil
+ }
+
+ // Then close database connection
if d.conn != nil {
return d.conn.Close()
}
@@ -113,33 +149,7 @@ func (d *DamengDB) Query(query string) ([]map[string]interface{}, []string, erro
return nil, nil, err
}
defer rows.Close()
-
- columns, err := rows.Columns()
- if err != nil {
- return nil, nil, err
- }
-
- var resultData []map[string]interface{}
-
- for rows.Next() {
- values := make([]interface{}, len(columns))
- valuePtrs := make([]interface{}, len(columns))
- for i := range columns {
- valuePtrs[i] = &values[i]
- }
-
- if err := rows.Scan(valuePtrs...); err != nil {
- continue
- }
-
- entry := make(map[string]interface{})
- for i, col := range columns {
- entry[col] = normalizeQueryValue(values[i])
- }
- resultData = append(resultData, entry)
- }
-
- return resultData, columns, nil
+ return scanRows(rows)
}
func (d *DamengDB) ExecContext(ctx context.Context, query string) (int64, error) {
diff --git a/internal/db/kingbase_impl.go b/internal/db/kingbase_impl.go
index 245fcbb..f9e423a 100644
--- a/internal/db/kingbase_impl.go
+++ b/internal/db/kingbase_impl.go
@@ -4,10 +4,13 @@ import (
"context"
"database/sql"
"fmt"
+ "net"
+ "strconv"
"strings"
"time"
"GoNavi-Wails/internal/connection"
+ "GoNavi-Wails/internal/logger"
"GoNavi-Wails/internal/ssh"
"GoNavi-Wails/internal/utils"
@@ -17,6 +20,7 @@ import (
type KingbaseDB struct {
conn *sql.DB
pingTimeout time.Duration
+ forwarder *ssh.LocalForwarder // Store SSH tunnel forwarder
}
func quoteConnValue(v string) string {
@@ -58,20 +62,6 @@ func (k *KingbaseDB) getDSN(config connection.ConnectionConfig) string {
address := config.Host
port := config.Port
- if config.UseSSH {
- netName, err := ssh.RegisterSSHNetwork(config.SSH)
- if err == nil {
- // Kingbase/Postgres lib/pq allows custom dialer via "host" if using unix socket,
- // but for custom network it's harder.
- // Ideally we use a local forwarder.
- // For now, we assume standard TCP or handle SSH externally.
- // If we implement the net.Dial override for "kingbase" driver (which might use lib/pq internally),
- // we might need to check if it supports "cloudsql" style or similar custom dialers.
- // Similar to others, skipping SSH deep integration here for now.
- _ = netName
- }
- }
-
// Construct DSN
dsn := fmt.Sprintf("host=%s port=%d user=%s password=%s dbname=%s sslmode=disable connect_timeout=%d",
quoteConnValue(address),
@@ -86,7 +76,42 @@ func (k *KingbaseDB) getDSN(config connection.ConnectionConfig) string {
}
func (k *KingbaseDB) Connect(config connection.ConnectionConfig) error {
- dsn := k.getDSN(config)
+ var dsn string
+ var err error
+
+ if config.UseSSH {
+ // Create SSH tunnel with local port forwarding
+ logger.Infof("人大金仓使用 SSH 连接:地址=%s:%d 用户=%s", config.Host, config.Port, config.User)
+
+ forwarder, err := ssh.GetOrCreateLocalForwarder(config.SSH, config.Host, config.Port)
+ if err != nil {
+ return fmt.Errorf("创建 SSH 隧道失败:%w", err)
+ }
+ k.forwarder = forwarder
+
+ // Parse local address
+ host, portStr, err := net.SplitHostPort(forwarder.LocalAddr)
+ if err != nil {
+ return fmt.Errorf("解析本地转发地址失败:%w", err)
+ }
+
+ port, err := strconv.Atoi(portStr)
+ if err != nil {
+ return fmt.Errorf("解析本地端口失败:%w", err)
+ }
+
+ // Create a modified config pointing to local forwarder
+ localConfig := config
+ localConfig.Host = host
+ localConfig.Port = port
+ localConfig.UseSSH = false
+
+ dsn = k.getDSN(localConfig)
+ logger.Infof("人大金仓通过本地端口转发连接:%s -> %s:%d", forwarder.LocalAddr, config.Host, config.Port)
+ } else {
+ dsn = k.getDSN(config)
+ }
+
// Open using "kingbase" driver
db, err := sql.Open("kingbase", dsn)
if err != nil {
@@ -101,6 +126,15 @@ func (k *KingbaseDB) Connect(config connection.ConnectionConfig) error {
}
func (k *KingbaseDB) Close() error {
+ // Close SSH forwarder first if exists
+ if k.forwarder != nil {
+ if err := k.forwarder.Close(); err != nil {
+ logger.Warnf("关闭人大金仓 SSH 端口转发失败:%v", err)
+ }
+ k.forwarder = nil
+ }
+
+ // Then close database connection
if k.conn != nil {
return k.conn.Close()
}
@@ -144,33 +178,7 @@ func (k *KingbaseDB) Query(query string) ([]map[string]interface{}, []string, er
return nil, nil, err
}
defer rows.Close()
-
- columns, err := rows.Columns()
- if err != nil {
- return nil, nil, err
- }
-
- var resultData []map[string]interface{}
-
- for rows.Next() {
- values := make([]interface{}, len(columns))
- valuePtrs := make([]interface{}, len(columns))
- for i := range columns {
- valuePtrs[i] = &values[i]
- }
-
- if err := rows.Scan(valuePtrs...); err != nil {
- continue
- }
-
- entry := make(map[string]interface{})
- for i, col := range columns {
- entry[col] = normalizeQueryValue(values[i])
- }
- resultData = append(resultData, entry)
- }
-
- return resultData, columns, nil
+ return scanRows(rows)
}
func (k *KingbaseDB) ExecContext(ctx context.Context, query string) (int64, error) {
@@ -249,15 +257,84 @@ func (k *KingbaseDB) GetCreateStatement(dbName, tableName string) (string, error
}
func (k *KingbaseDB) GetColumns(dbName, tableName string) ([]connection.ColumnDefinition, error) {
- schema := "public"
- if dbName != "" {
- schema = dbName
+ // 解析 schema.table 格式
+ schema := strings.TrimSpace(dbName)
+ table := strings.TrimSpace(tableName)
+
+ // 如果 tableName 包含 schema (格式: schema.table)
+ if parts := strings.SplitN(table, ".", 2); len(parts) == 2 {
+ parsedSchema := strings.TrimSpace(parts[0])
+ parsedTable := strings.TrimSpace(parts[1])
+ if parsedSchema != "" && parsedTable != "" {
+ schema = parsedSchema
+ table = parsedTable
+ }
}
- query := fmt.Sprintf(`SELECT column_name, data_type, is_nullable, column_default
- FROM information_schema.columns
- WHERE table_schema = '%s' AND table_name = '%s'
- ORDER BY ordinal_position`, schema, tableName)
+ // 如果仍然没有 schema,使用 current_schema()
+ // 这样可以自动匹配当前连接的 search_path
+ if schema == "" {
+ return k.getColumnsWithCurrentSchema(table)
+ }
+
+ if table == "" {
+ return nil, fmt.Errorf("table name required")
+ }
+
+ // 转义函数:处理单引号,移除双引号
+ esc := func(s string) string {
+ // 移除前后的双引号(如果存在)
+ s = strings.Trim(s, "\"")
+ // 转义单引号
+ return strings.ReplaceAll(s, "'", "''")
+ }
+
+ query := fmt.Sprintf(`SELECT column_name, data_type, is_nullable, column_default
+ FROM information_schema.columns
+ WHERE table_schema = '%s' AND table_name = '%s'
+ ORDER BY ordinal_position`, esc(schema), esc(table))
+
+ data, _, err := k.Query(query)
+ if err != nil {
+ return nil, err
+ }
+
+ var columns []connection.ColumnDefinition
+ for _, row := range data {
+ col := connection.ColumnDefinition{
+ Name: fmt.Sprintf("%v", row["column_name"]),
+ Type: fmt.Sprintf("%v", row["data_type"]),
+ Nullable: fmt.Sprintf("%v", row["is_nullable"]),
+ }
+
+ if row["column_default"] != nil {
+ def := fmt.Sprintf("%v", row["column_default"])
+ col.Default = &def
+ }
+
+ columns = append(columns, col)
+ }
+ return columns, nil
+}
+
+// getColumnsWithCurrentSchema 使用 current_schema() 查询当前schema的表
+func (k *KingbaseDB) getColumnsWithCurrentSchema(tableName string) ([]connection.ColumnDefinition, error) {
+ table := strings.TrimSpace(tableName)
+ if table == "" {
+ return nil, fmt.Errorf("table name required")
+ }
+
+ // 转义函数
+ esc := func(s string) string {
+ s = strings.Trim(s, "\"")
+ return strings.ReplaceAll(s, "'", "''")
+ }
+
+ // 使用 current_schema() 获取当前schema
+ query := fmt.Sprintf(`SELECT column_name, data_type, is_nullable, column_default
+ FROM information_schema.columns
+ WHERE table_schema = current_schema() AND table_name = '%s'
+ ORDER BY ordinal_position`, esc(table))
data, _, err := k.Query(query)
if err != nil {
@@ -283,32 +360,76 @@ func (k *KingbaseDB) GetColumns(dbName, tableName string) ([]connection.ColumnDe
}
func (k *KingbaseDB) GetIndexes(dbName, tableName string) ([]connection.IndexDefinition, error) {
- // Postgres/Kingbase index query
- query := fmt.Sprintf(`
- SELECT
- i.relname as index_name,
- a.attname as column_name,
- ix.indisunique as is_unique
- FROM
- pg_class t,
- pg_class i,
- pg_index ix,
- pg_attribute a,
- pg_namespace n
- WHERE
- t.oid = ix.indrelid
- AND i.oid = ix.indexrelid
- AND a.attrelid = t.oid
- AND a.attnum = ANY(ix.indkey)
- AND t.relkind = 'r'
- AND t.relname = '%s'
- AND n.oid = t.relnamespace
- AND n.nspname = '%s'
- `, tableName, "public") // Default to public if dbName (schema) not clear.
+ // 解析 schema.table 格式
+ schema := strings.TrimSpace(dbName)
+ table := strings.TrimSpace(tableName)
- if dbName != "" {
- // Update query to use dbName as schema
- query = strings.Replace(query, "'public'", fmt.Sprintf("'%s'", dbName), 1)
+ // 如果 tableName 包含 schema (格式: schema.table)
+ if parts := strings.SplitN(table, ".", 2); len(parts) == 2 {
+ parsedSchema := strings.TrimSpace(parts[0])
+ parsedTable := strings.TrimSpace(parts[1])
+ if parsedSchema != "" && parsedTable != "" {
+ schema = parsedSchema
+ table = parsedTable
+ }
+ }
+
+ if table == "" {
+ return nil, fmt.Errorf("table name required")
+ }
+
+ // 转义函数:处理单引号,移除双引号
+ esc := func(s string) string {
+ s = strings.Trim(s, "\"")
+ return strings.ReplaceAll(s, "'", "''")
+ }
+
+ // 构建查询:如果没有指定schema,使用current_schema()
+ var query string
+ if schema != "" {
+ query = fmt.Sprintf(`
+ SELECT
+ i.relname as index_name,
+ a.attname as column_name,
+ ix.indisunique as is_unique
+ FROM
+ pg_class t,
+ pg_class i,
+ pg_index ix,
+ pg_attribute a,
+ pg_namespace n
+ WHERE
+ t.oid = ix.indrelid
+ AND i.oid = ix.indexrelid
+ AND a.attrelid = t.oid
+ AND a.attnum = ANY(ix.indkey)
+ AND t.relkind = 'r'
+ AND t.relname = '%s'
+ AND n.oid = t.relnamespace
+ AND n.nspname = '%s'
+ `, esc(table), esc(schema))
+ } else {
+ query = fmt.Sprintf(`
+ SELECT
+ i.relname as index_name,
+ a.attname as column_name,
+ ix.indisunique as is_unique
+ FROM
+ pg_class t,
+ pg_class i,
+ pg_index ix,
+ pg_attribute a,
+ pg_namespace n
+ WHERE
+ t.oid = ix.indrelid
+ AND i.oid = ix.indexrelid
+ AND a.attrelid = t.oid
+ AND a.attnum = ANY(ix.indkey)
+ AND t.relkind = 'r'
+ AND t.relname = '%s'
+ AND n.oid = t.relnamespace
+ AND n.nspname = current_schema()
+ `, esc(table))
}
data, _, err := k.Query(query)
@@ -337,27 +458,67 @@ func (k *KingbaseDB) GetIndexes(dbName, tableName string) ([]connection.IndexDef
}
func (k *KingbaseDB) GetForeignKeys(dbName, tableName string) ([]connection.ForeignKeyDefinition, error) {
- schema := "public"
- if dbName != "" {
- schema = dbName
+ // 解析 schema.table 格式
+ schema := strings.TrimSpace(dbName)
+ table := strings.TrimSpace(tableName)
+
+ // 如果 tableName 包含 schema (格式: schema.table)
+ if parts := strings.SplitN(table, ".", 2); len(parts) == 2 {
+ parsedSchema := strings.TrimSpace(parts[0])
+ parsedTable := strings.TrimSpace(parts[1])
+ if parsedSchema != "" && parsedTable != "" {
+ schema = parsedSchema
+ table = parsedTable
+ }
}
- query := fmt.Sprintf(`
- SELECT
- tc.constraint_name,
- kcu.column_name,
- ccu.table_name AS foreign_table_name,
- ccu.column_name AS foreign_column_name
- FROM
- information_schema.table_constraints AS tc
- JOIN information_schema.key_column_usage AS kcu
- ON tc.constraint_name = kcu.constraint_name
- AND tc.table_schema = kcu.table_schema
- JOIN information_schema.constraint_column_usage AS ccu
- ON ccu.constraint_name = tc.constraint_name
- AND ccu.table_schema = tc.table_schema
- WHERE tc.constraint_type = 'FOREIGN KEY' AND tc.table_name='%s' AND tc.table_schema='%s'`,
- tableName, schema)
+ if table == "" {
+ return nil, fmt.Errorf("table name required")
+ }
+
+ // 转义函数:处理单引号,移除双引号
+ esc := func(s string) string {
+ s = strings.Trim(s, "\"")
+ return strings.ReplaceAll(s, "'", "''")
+ }
+
+ // 构建查询:如果没有指定schema,使用current_schema()
+ var query string
+ if schema != "" {
+ query = fmt.Sprintf(`
+ SELECT
+ tc.constraint_name,
+ kcu.column_name,
+ ccu.table_name AS foreign_table_name,
+ ccu.column_name AS foreign_column_name
+ FROM
+ information_schema.table_constraints AS tc
+ JOIN information_schema.key_column_usage AS kcu
+ ON tc.constraint_name = kcu.constraint_name
+ AND tc.table_schema = kcu.table_schema
+ JOIN information_schema.constraint_column_usage AS ccu
+ ON ccu.constraint_name = tc.constraint_name
+ AND ccu.table_schema = tc.table_schema
+ WHERE tc.constraint_type = 'FOREIGN KEY' AND tc.table_name='%s' AND tc.table_schema='%s'`,
+ esc(table), esc(schema))
+ } else {
+ query = fmt.Sprintf(`
+ SELECT
+ tc.constraint_name,
+ kcu.column_name,
+ ccu.table_name AS foreign_table_name,
+ ccu.column_name AS foreign_column_name
+ FROM
+ information_schema.table_constraints AS tc
+ JOIN information_schema.key_column_usage AS kcu
+ ON tc.constraint_name = kcu.constraint_name
+ AND tc.table_schema = kcu.table_schema
+ JOIN information_schema.constraint_column_usage AS ccu
+ ON ccu.constraint_name = tc.constraint_name
+ AND ccu.table_schema = tc.table_schema
+ WHERE tc.constraint_type = 'FOREIGN KEY' AND tc.table_name='%s' AND tc.table_schema=current_schema()`,
+ esc(table))
+ }
data, _, err := k.Query(query)
if err != nil {
@@ -379,9 +540,43 @@ func (k *KingbaseDB) GetForeignKeys(dbName, tableName string) ([]connection.Fore
}
func (k *KingbaseDB) GetTriggers(dbName, tableName string) ([]connection.TriggerDefinition, error) {
- query := fmt.Sprintf(`SELECT trigger_name, action_timing, event_manipulation
- FROM information_schema.triggers
- WHERE event_object_table = '%s'`, tableName)
+ // 解析 schema.table 格式
+ schema := strings.TrimSpace(dbName)
+ table := strings.TrimSpace(tableName)
+
+ // 如果 tableName 包含 schema (格式: schema.table)
+ if parts := strings.SplitN(table, ".", 2); len(parts) == 2 {
+ parsedSchema := strings.TrimSpace(parts[0])
+ parsedTable := strings.TrimSpace(parts[1])
+ if parsedSchema != "" && parsedTable != "" {
+ schema = parsedSchema
+ table = parsedTable
+ }
+ }
+
+ if table == "" {
+ return nil, fmt.Errorf("table name required")
+ }
+
+ // 转义函数:处理单引号,移除双引号
+ esc := func(s string) string {
+ s = strings.Trim(s, "\"")
+ return strings.ReplaceAll(s, "'", "''")
+ }
+
+ // 构建查询:如果指定了schema,也加上schema条件
+ var query string
+ if schema != "" {
+ query = fmt.Sprintf(`SELECT trigger_name, action_timing, event_manipulation
+ FROM information_schema.triggers
+ WHERE event_object_table = '%s' AND event_object_schema = '%s'`,
+ esc(table), esc(schema))
+ } else {
+ query = fmt.Sprintf(`SELECT trigger_name, action_timing, event_manipulation
+ FROM information_schema.triggers
+ WHERE event_object_table = '%s' AND event_object_schema = current_schema()`,
+ esc(table))
+ }
data, _, err := k.Query(query)
if err != nil {
diff --git a/internal/db/mysql_impl.go b/internal/db/mysql_impl.go
index 2577d6a..5de70cf 100644
--- a/internal/db/mysql_impl.go
+++ b/internal/db/mysql_impl.go
@@ -101,33 +101,7 @@ func (m *MySQLDB) Query(query string) ([]map[string]interface{}, []string, error
return nil, nil, err
}
defer rows.Close()
-
- columns, err := rows.Columns()
- if err != nil {
- return nil, nil, err
- }
-
- var resultData []map[string]interface{}
-
- for rows.Next() {
- values := make([]interface{}, len(columns))
- valuePtrs := make([]interface{}, len(columns))
- for i := range columns {
- valuePtrs[i] = &values[i]
- }
-
- if err := rows.Scan(valuePtrs...); err != nil {
- continue
- }
-
- entry := make(map[string]interface{})
- for i, col := range columns {
- entry[col] = normalizeQueryValue(values[i])
- }
- resultData = append(resultData, entry)
- }
-
- return resultData, columns, nil
+ return scanRows(rows)
}
func (m *MySQLDB) ExecContext(ctx context.Context, query string) (int64, error) {
diff --git a/internal/db/oracle_impl.go b/internal/db/oracle_impl.go
index 454f460..f07c376 100644
--- a/internal/db/oracle_impl.go
+++ b/internal/db/oracle_impl.go
@@ -11,6 +11,7 @@ import (
"time"
"GoNavi-Wails/internal/connection"
+ "GoNavi-Wails/internal/logger"
"GoNavi-Wails/internal/ssh"
"GoNavi-Wails/internal/utils"
@@ -20,6 +21,7 @@ import (
type OracleDB struct {
conn *sql.DB
pingTimeout time.Duration
+ forwarder *ssh.LocalForwarder // Store SSH tunnel forwarder
}
func (o *OracleDB) getDSN(config connection.ConnectionConfig) string {
@@ -29,28 +31,6 @@ func (o *OracleDB) getDSN(config connection.ConnectionConfig) string {
database = config.User // Default to user service/schema if empty?
}
- if config.UseSSH {
- _, err := ssh.RegisterSSHNetwork(config.SSH)
- if err == nil {
- // Oracle driver might not support custom dialer via DSN easily without extra config
- // But go-ora v2 supports some advanced options.
- // For simplicity, we assume standard TCP or we might need a workaround for SSH.
- // go-ora v2 is pure Go, so we can potentially use a custom dialer if we manually open.
- // But for now, let's just use the address.
- // SSH tunneling via net.Dialer override is complex in sql.Open("oracle", ...).
- // We might need to forward a local port if using SSH.
- // Since ssh.RegisterSSHNetwork creates a custom network "ssh-via-...",
- // we need to see if go-ora supports custom networks.
- // Checking go-ora docs (simulated): It supports "unix" and "tcp".
- // We might need to map the custom network to a local proxy.
- // For now, we will assume direct connection or handle SSH separately later.
- // We'll leave the protocol implementation as is in MySQL for now, hoping go-ora uses standard net.Dial.
- // Note: go-ora connection string: oracle://user:pass@host:port/service
- // It parses host/port. It doesn't easily take a custom "network" parameter in URL.
- // We will proceed with standard TCP string.
- }
- }
-
u := &url.URL{
Scheme: "oracle",
Host: net.JoinHostPort(config.Host, strconv.Itoa(config.Port)),
@@ -62,7 +42,42 @@ func (o *OracleDB) getDSN(config connection.ConnectionConfig) string {
}
func (o *OracleDB) Connect(config connection.ConnectionConfig) error {
- dsn := o.getDSN(config)
+ var dsn string
+ var err error
+
+ if config.UseSSH {
+ // Create SSH tunnel with local port forwarding
+ logger.Infof("Oracle 使用 SSH 连接:地址=%s:%d 用户=%s", config.Host, config.Port, config.User)
+
+ forwarder, err := ssh.GetOrCreateLocalForwarder(config.SSH, config.Host, config.Port)
+ if err != nil {
+ return fmt.Errorf("创建 SSH 隧道失败:%w", err)
+ }
+ o.forwarder = forwarder
+
+ // Parse local address
+ host, portStr, err := net.SplitHostPort(forwarder.LocalAddr)
+ if err != nil {
+ return fmt.Errorf("解析本地转发地址失败:%w", err)
+ }
+
+ port, err := strconv.Atoi(portStr)
+ if err != nil {
+ return fmt.Errorf("解析本地端口失败:%w", err)
+ }
+
+ // Create a modified config pointing to local forwarder
+ localConfig := config
+ localConfig.Host = host
+ localConfig.Port = port
+ localConfig.UseSSH = false
+
+ dsn = o.getDSN(localConfig)
+ logger.Infof("Oracle 通过本地端口转发连接:%s -> %s:%d", forwarder.LocalAddr, config.Host, config.Port)
+ } else {
+ dsn = o.getDSN(config)
+ }
+
db, err := sql.Open("oracle", dsn)
if err != nil {
return fmt.Errorf("打开数据库连接失败:%w", err)
@@ -76,6 +91,15 @@ func (o *OracleDB) Connect(config connection.ConnectionConfig) error {
}
func (o *OracleDB) Close() error {
+ // Close SSH forwarder first if exists
+ if o.forwarder != nil {
+ if err := o.forwarder.Close(); err != nil {
+ logger.Warnf("关闭 Oracle SSH 端口转发失败:%v", err)
+ }
+ o.forwarder = nil
+ }
+
+ // Then close database connection
if o.conn != nil {
return o.conn.Close()
}
@@ -119,33 +143,7 @@ func (o *OracleDB) Query(query string) ([]map[string]interface{}, []string, erro
return nil, nil, err
}
defer rows.Close()
-
- columns, err := rows.Columns()
- if err != nil {
- return nil, nil, err
- }
-
- var resultData []map[string]interface{}
-
- for rows.Next() {
- values := make([]interface{}, len(columns))
- valuePtrs := make([]interface{}, len(columns))
- for i := range columns {
- valuePtrs[i] = &values[i]
- }
-
- if err := rows.Scan(valuePtrs...); err != nil {
- continue
- }
-
- entry := make(map[string]interface{})
- for i, col := range columns {
- entry[col] = normalizeQueryValue(values[i])
- }
- resultData = append(resultData, entry)
- }
-
- return resultData, columns, nil
+ return scanRows(rows)
}
func (o *OracleDB) ExecContext(ctx context.Context, query string) (int64, error) {
diff --git a/internal/db/postgres_impl.go b/internal/db/postgres_impl.go
index 9ade7f8..26fed51 100644
--- a/internal/db/postgres_impl.go
+++ b/internal/db/postgres_impl.go
@@ -11,16 +11,21 @@ import (
"time"
"GoNavi-Wails/internal/connection"
+ "GoNavi-Wails/internal/logger"
+ "GoNavi-Wails/internal/ssh"
"GoNavi-Wails/internal/utils"
_ "github.com/lib/pq"
)
+
type PostgresDB struct {
conn *sql.DB
pingTimeout time.Duration
+ forwarder *ssh.LocalForwarder // Store SSH tunnel forwarder
}
+
func (p *PostgresDB) getDSN(config connection.ConnectionConfig) string {
// postgres://user:password@host:port/dbname?sslmode=disable
dbname := config.Database
@@ -43,7 +48,42 @@ func (p *PostgresDB) getDSN(config connection.ConnectionConfig) string {
}
func (p *PostgresDB) Connect(config connection.ConnectionConfig) error {
- dsn := p.getDSN(config)
+ var dsn string
+ var err error
+
+ if config.UseSSH {
+ // Create SSH tunnel with local port forwarding
+ logger.Infof("PostgreSQL 使用 SSH 连接:地址=%s:%d 用户=%s", config.Host, config.Port, config.User)
+
+ forwarder, err := ssh.GetOrCreateLocalForwarder(config.SSH, config.Host, config.Port)
+ if err != nil {
+ return fmt.Errorf("创建 SSH 隧道失败:%w", err)
+ }
+ p.forwarder = forwarder
+
+ // Parse local address
+ host, portStr, err := net.SplitHostPort(forwarder.LocalAddr)
+ if err != nil {
+ return fmt.Errorf("解析本地转发地址失败:%w", err)
+ }
+
+ port, err := strconv.Atoi(portStr)
+ if err != nil {
+ return fmt.Errorf("解析本地端口失败:%w", err)
+ }
+
+ // Create a modified config pointing to local forwarder
+ localConfig := config
+ localConfig.Host = host
+ localConfig.Port = port
+ localConfig.UseSSH = false // Disable SSH flag for DSN generation
+
+ dsn = p.getDSN(localConfig)
+ logger.Infof("PostgreSQL 通过本地端口转发连接:%s -> %s:%d", forwarder.LocalAddr, config.Host, config.Port)
+ } else {
+ dsn = p.getDSN(config)
+ }
+
db, err := sql.Open("postgres", dsn)
if err != nil {
return fmt.Errorf("打开数据库连接失败:%w", err)
@@ -58,7 +98,17 @@ func (p *PostgresDB) Connect(config connection.ConnectionConfig) error {
return nil
}
+
func (p *PostgresDB) Close() error {
+ // Close SSH forwarder first if exists
+ if p.forwarder != nil {
+ if err := p.forwarder.Close(); err != nil {
+ logger.Warnf("关闭 PostgreSQL SSH 端口转发失败:%v", err)
+ }
+ p.forwarder = nil
+ }
+
+ // Then close database connection
if p.conn != nil {
return p.conn.Close()
}
@@ -102,33 +152,7 @@ func (p *PostgresDB) Query(query string) ([]map[string]interface{}, []string, er
return nil, nil, err
}
defer rows.Close()
-
- columns, err := rows.Columns()
- if err != nil {
- return nil, nil, err
- }
-
- var resultData []map[string]interface{}
-
- for rows.Next() {
- values := make([]interface{}, len(columns))
- valuePtrs := make([]interface{}, len(columns))
- for i := range columns {
- valuePtrs[i] = &values[i]
- }
-
- if err := rows.Scan(valuePtrs...); err != nil {
- continue
- }
-
- entry := make(map[string]interface{})
- for i, col := range columns {
- entry[col] = normalizeQueryValue(values[i])
- }
- resultData = append(resultData, entry)
- }
-
- return resultData, columns, nil
+ return scanRows(rows)
}
func (p *PostgresDB) ExecContext(ctx context.Context, query string) (int64, error) {
diff --git a/internal/db/query_value.go b/internal/db/query_value.go
index 764ccbf..d4dde25 100644
--- a/internal/db/query_value.go
+++ b/internal/db/query_value.go
@@ -2,6 +2,8 @@ package db
import (
"encoding/hex"
+ "fmt"
+ "strings"
"unicode"
"unicode/utf8"
)
@@ -9,13 +11,17 @@ import (
// normalizeQueryValue normalizes driver-returned values for UI/JSON transport.
// 当前主要处理 []byte:如果是可读文本则转为 string,否则转为十六进制字符串,避免前端出现“空白值”。
func normalizeQueryValue(v interface{}) interface{} {
+ return normalizeQueryValueWithDBType(v, "")
+}
+
+func normalizeQueryValueWithDBType(v interface{}, databaseTypeName string) interface{} {
if b, ok := v.([]byte); ok {
- return bytesToReadableString(b)
+ return bytesToDisplayValue(b, databaseTypeName)
}
return v
}
-func bytesToReadableString(b []byte) interface{} {
+func bytesToDisplayValue(b []byte, databaseTypeName string) interface{} {
if b == nil {
return nil
}
@@ -23,6 +29,18 @@ func bytesToReadableString(b []byte) interface{} {
return ""
}
+ dbType := strings.ToUpper(strings.TrimSpace(databaseTypeName))
+ if isBitLikeDBType(dbType) {
+ if u, ok := bytesToUint64(b); ok {
+ // JS number precision is limited; keep large bitmasks as string.
+ const maxSafeInteger = 9007199254740991 // 2^53 - 1
+ if u <= maxSafeInteger {
+ return int64(u)
+ }
+ return fmt.Sprintf("%d", u)
+ }
+ }
+
if utf8.Valid(b) {
s := string(b)
if isMostlyPrintable(s) {
@@ -30,9 +48,47 @@ func bytesToReadableString(b []byte) interface{} {
}
}
+ // Fallback: some drivers return BIT(1) as []byte{0} / []byte{1} without type info.
+ if dbType == "" && len(b) == 1 && (b[0] == 0 || b[0] == 1) {
+ return int64(b[0])
+ }
+
+ return bytesToReadableString(b)
+}
+
+func bytesToReadableString(b []byte) interface{} {
+ if b == nil {
+ return nil
+ }
+ if len(b) == 0 {
+ return ""
+ }
return "0x" + hex.EncodeToString(b)
}
+func isBitLikeDBType(typeName string) bool {
+ if typeName == "" {
+ return false
+ }
+ switch typeName {
+ case "BIT", "VARBIT":
+ return true
+ default:
+ }
+ return strings.HasPrefix(typeName, "BIT")
+}
+
+func bytesToUint64(b []byte) (uint64, bool) {
+ if len(b) == 0 || len(b) > 8 {
+ return 0, false
+ }
+ var u uint64
+ for _, v := range b {
+ u = (u << 8) | uint64(v)
+ }
+ return u, true
+}
+
func isMostlyPrintable(s string) bool {
if s == "" {
return true
diff --git a/internal/db/query_value_test.go b/internal/db/query_value_test.go
new file mode 100644
index 0000000..1b2c140
--- /dev/null
+++ b/internal/db/query_value_test.go
@@ -0,0 +1,44 @@
+package db
+
+import "testing"
+
+func TestNormalizeQueryValueWithDBType_BitBytes(t *testing.T) {
+ v := normalizeQueryValueWithDBType([]byte{0x00}, "BIT")
+ if v != int64(0) {
+ t.Fatalf("BIT 0x00 期望为 0,实际=%v(%T)", v, v)
+ }
+
+ v = normalizeQueryValueWithDBType([]byte{0x01}, "bit")
+ if v != int64(1) {
+ t.Fatalf("BIT 0x01 期望为 1,实际=%v(%T)", v, v)
+ }
+
+ v = normalizeQueryValueWithDBType([]byte{0x01, 0x02}, "BIT VARYING")
+ if v != int64(258) {
+ t.Fatalf("BIT 0x0102 期望为 258,实际=%v(%T)", v, v)
+ }
+}
+
+func TestNormalizeQueryValueWithDBType_BitLargeAsString(t *testing.T) {
+ v := normalizeQueryValueWithDBType([]byte{0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff}, "BIT")
+ if s, ok := v.(string); !ok || s != "18446744073709551615" {
+ t.Fatalf("BIT 0xffffffffffffffff 期望为 string(18446744073709551615),实际=%v(%T)", v, v)
+ }
+}
+
+func TestNormalizeQueryValueWithDBType_ByteFallbacks(t *testing.T) {
+ v := normalizeQueryValueWithDBType([]byte("abc"), "")
+ if v != "abc" {
+ t.Fatalf("文本 []byte 期望返回 string,实际=%v(%T)", v, v)
+ }
+
+ v = normalizeQueryValueWithDBType([]byte{0x00}, "")
+ if v != int64(0) {
+ t.Fatalf("未知类型 0x00 期望返回 0,实际=%v(%T)", v, v)
+ }
+
+ v = normalizeQueryValueWithDBType([]byte{0xff}, "")
+ if v != "0xff" {
+ t.Fatalf("未知类型 0xff 期望返回 0xff,实际=%v(%T)", v, v)
+ }
+}
diff --git a/internal/db/scan_rows.go b/internal/db/scan_rows.go
index d77bab0..d629e0a 100644
--- a/internal/db/scan_rows.go
+++ b/internal/db/scan_rows.go
@@ -1,6 +1,8 @@
package db
-import "database/sql"
+import (
+ "database/sql"
+)
func scanRows(rows *sql.Rows) ([]map[string]interface{}, []string, error) {
columns, err := rows.Columns()
@@ -8,6 +10,11 @@ func scanRows(rows *sql.Rows) ([]map[string]interface{}, []string, error) {
return nil, nil, err
}
+ colTypes, err := rows.ColumnTypes()
+ if err != nil || len(colTypes) != len(columns) {
+ colTypes = nil
+ }
+
resultData := make([]map[string]interface{}, 0)
for rows.Next() {
@@ -23,7 +30,11 @@ func scanRows(rows *sql.Rows) ([]map[string]interface{}, []string, error) {
entry := make(map[string]interface{}, len(columns))
for i, col := range columns {
- entry[col] = normalizeQueryValue(values[i])
+ dbTypeName := ""
+ if colTypes != nil && i < len(colTypes) && colTypes[i] != nil {
+ dbTypeName = colTypes[i].DatabaseTypeName()
+ }
+ entry[col] = normalizeQueryValueWithDBType(values[i], dbTypeName)
}
resultData = append(resultData, entry)
}
@@ -33,4 +44,3 @@ func scanRows(rows *sql.Rows) ([]map[string]interface{}, []string, error) {
}
return resultData, columns, nil
}
-
diff --git a/internal/db/sqlite_impl.go b/internal/db/sqlite_impl.go
index 57a8eec..3fe8dc6 100644
--- a/internal/db/sqlite_impl.go
+++ b/internal/db/sqlite_impl.go
@@ -78,33 +78,7 @@ func (s *SQLiteDB) Query(query string) ([]map[string]interface{}, []string, erro
return nil, nil, err
}
defer rows.Close()
-
- columns, err := rows.Columns()
- if err != nil {
- return nil, nil, err
- }
-
- var resultData []map[string]interface{}
-
- for rows.Next() {
- values := make([]interface{}, len(columns))
- valuePtrs := make([]interface{}, len(columns))
- for i := range columns {
- valuePtrs[i] = &values[i]
- }
-
- if err := rows.Scan(valuePtrs...); err != nil {
- continue
- }
-
- entry := make(map[string]interface{})
- for i, col := range columns {
- entry[col] = normalizeQueryValue(values[i])
- }
- resultData = append(resultData, entry)
- }
-
- return resultData, columns, nil
+ return scanRows(rows)
}
func (s *SQLiteDB) ExecContext(ctx context.Context, query string) (int64, error) {
diff --git a/internal/ssh/ssh.go b/internal/ssh/ssh.go
index 2d6fe64..51ad364 100644
--- a/internal/ssh/ssh.go
+++ b/internal/ssh/ssh.go
@@ -3,8 +3,10 @@ package ssh
import (
"context"
"fmt"
+ "io"
"net"
"os"
+ "sync"
"time"
"GoNavi-Wails/internal/connection"
@@ -110,3 +112,264 @@ func RegisterSSHNetwork(sshConfig connection.SSHConfig) (string, error) {
return netName, nil
}
+
+// sshClientCache stores SSH clients to avoid creating multiple connections
+var (
+ sshClientCache = make(map[string]*ssh.Client)
+ sshClientCacheMu sync.RWMutex
+ localForwarders = make(map[string]*LocalForwarder)
+ forwarderMu sync.RWMutex
+)
+
+// LocalForwarder represents a local port forwarder through SSH
+type LocalForwarder struct {
+ LocalAddr string
+ RemoteAddr string
+ SSHClient *ssh.Client
+ listener net.Listener
+ closeChan chan struct{}
+ closeOnce sync.Once // 防止重复关闭
+ closed bool // 关闭状态标记
+ closedMu sync.RWMutex
+}
+
+// NewLocalForwarder creates a new local port forwarder
+// It listens on a random local port and forwards all connections through SSH tunnel
+func NewLocalForwarder(sshConfig connection.SSHConfig, remoteHost string, remotePort int) (*LocalForwarder, error) {
+ client, err := GetOrCreateSSHClient(sshConfig)
+ if err != nil {
+ return nil, fmt.Errorf("建立 SSH 连接失败:%w", err)
+ }
+
+ // Listen on localhost with a random port
+ listener, err := net.Listen("tcp", "127.0.0.1:0")
+ if err != nil {
+ return nil, fmt.Errorf("创建本地监听器失败:%w", err)
+ }
+
+ localAddr := listener.Addr().String()
+ remoteAddr := fmt.Sprintf("%s:%d", remoteHost, remotePort)
+
+ forwarder := &LocalForwarder{
+ LocalAddr: localAddr,
+ RemoteAddr: remoteAddr,
+ SSHClient: client,
+ listener: listener,
+ closeChan: make(chan struct{}),
+ }
+
+ // Start forwarding in background
+ go forwarder.forward()
+
+ logger.Infof("已创建 SSH 端口转发:本地 %s -> 远程 %s", localAddr, remoteAddr)
+ return forwarder, nil
+}
+
+// forward handles the port forwarding
+func (f *LocalForwarder) forward() {
+ for {
+ localConn, err := f.listener.Accept()
+ if err != nil {
+ // Check if we're shutting down
+ select {
+ case <-f.closeChan:
+ return
+ default:
+ logger.Warnf("接受本地连接失败:%v", err)
+ // listener可能已关闭,退出循环
+ return
+ }
+ }
+
+ go f.handleConnection(localConn)
+ }
+}
+
+// handleConnection handles a single connection
+func (f *LocalForwarder) handleConnection(localConn net.Conn) {
+ defer localConn.Close()
+
+ // Connect to remote through SSH with timeout
+ remoteConn, err := f.SSHClient.Dial("tcp", f.RemoteAddr)
+ if err != nil {
+ logger.Warnf("通过 SSH 连接到远程 %s 失败:%v", f.RemoteAddr, err)
+ return
+ }
+ defer remoteConn.Close()
+
+ // Bidirectional copy with error channel
+ errc := make(chan error, 2)
+
+ // Copy from local to remote
+ go func() {
+ _, err := io.Copy(remoteConn, localConn)
+ if err != nil {
+ logger.Warnf("本地->远程数据复制错误:%v", err)
+ }
+ errc <- err
+ }()
+
+ // Copy from remote to local
+ go func() {
+ _, err := io.Copy(localConn, remoteConn)
+ if err != nil {
+ logger.Warnf("远程->本地数据复制错误:%v", err)
+ }
+ errc <- err
+ }()
+
+ // Wait for BOTH goroutines to complete
+ <-errc
+ <-errc
+}
+
+// Close closes the forwarder (thread-safe, can be called multiple times)
+func (f *LocalForwarder) Close() error {
+ var err error
+ f.closeOnce.Do(func() {
+ f.closedMu.Lock()
+ f.closed = true
+ f.closedMu.Unlock()
+
+ close(f.closeChan)
+ err = f.listener.Close()
+ if err != nil {
+ logger.Warnf("关闭端口转发监听器失败:%v", err)
+ }
+ })
+ return err
+}
+
+// IsClosed returns whether the forwarder is closed
+func (f *LocalForwarder) IsClosed() bool {
+ f.closedMu.RLock()
+ defer f.closedMu.RUnlock()
+ return f.closed
+}
+
+// GetOrCreateLocalForwarder returns a cached forwarder or creates a new one
+func GetOrCreateLocalForwarder(sshConfig connection.SSHConfig, remoteHost string, remotePort int) (*LocalForwarder, error) {
+ key := fmt.Sprintf("%s:%d:%s->%s:%d",
+ sshConfig.Host, sshConfig.Port, sshConfig.User,
+ remoteHost, remotePort)
+
+ forwarderMu.RLock()
+ forwarder, exists := localForwarders[key]
+ forwarderMu.RUnlock()
+
+ // Check if exists and is still valid
+ if exists && forwarder != nil && !forwarder.IsClosed() {
+ logger.Infof("复用已有端口转发:%s", key)
+ return forwarder, nil
+ }
+
+ // Remove stale forwarder from cache
+ if exists {
+ forwarderMu.Lock()
+ delete(localForwarders, key)
+ forwarderMu.Unlock()
+ }
+
+ forwarder, err := NewLocalForwarder(sshConfig, remoteHost, remotePort)
+ if err != nil {
+ return nil, err
+ }
+
+ forwarderMu.Lock()
+ localForwarders[key] = forwarder
+ forwarderMu.Unlock()
+
+ return forwarder, nil
+}
+
+// CloseAllForwarders closes all local forwarders
+func CloseAllForwarders() {
+ forwarderMu.Lock()
+ defer forwarderMu.Unlock()
+
+ for key, forwarder := range localForwarders {
+ if forwarder != nil {
+ _ = forwarder.Close()
+ logger.Infof("已关闭端口转发:%s", key)
+ }
+ }
+ localForwarders = make(map[string]*LocalForwarder)
+}
+
+
+// getSSHClientCacheKey generates a unique cache key for SSH config
+func getSSHClientCacheKey(config connection.SSHConfig) string {
+ return fmt.Sprintf("%s:%d:%s", config.Host, config.Port, config.User)
+}
+
+// GetOrCreateSSHClient returns a cached SSH client or creates a new one
+func GetOrCreateSSHClient(config connection.SSHConfig) (*ssh.Client, error) {
+ key := getSSHClientCacheKey(config)
+
+ sshClientCacheMu.RLock()
+ client, exists := sshClientCache[key]
+ sshClientCacheMu.RUnlock()
+
+ if exists && client != nil {
+ // Test if connection is still alive by creating a test session
+ session, err := client.NewSession()
+ if err == nil {
+ session.Close()
+ logger.Infof("复用已有 SSH 连接:%s", key)
+ return client, nil
+ }
+ // Connection is dead, remove from cache
+ logger.Warnf("SSH 连接已断开,重新建立:%s (错误: %v)", key, err)
+ sshClientCacheMu.Lock()
+ delete(sshClientCache, key)
+ sshClientCacheMu.Unlock()
+ // Try to close the dead client
+ _ = client.Close()
+ }
+
+ // Create new SSH client
+ client, err := connectSSH(config)
+ if err != nil {
+ return nil, err
+ }
+
+ // Cache the client
+ sshClientCacheMu.Lock()
+ sshClientCache[key] = client
+ sshClientCacheMu.Unlock()
+
+ logger.Infof("已缓存 SSH 连接:%s", key)
+ return client, nil
+}
+
+// DialThroughSSH creates a connection through SSH tunnel
+// This is a generic dialer that can be used by any database driver
+func DialThroughSSH(config connection.SSHConfig, network, address string) (net.Conn, error) {
+ client, err := GetOrCreateSSHClient(config)
+ if err != nil {
+ return nil, fmt.Errorf("建立 SSH 连接失败:%w", err)
+ }
+
+ conn, err := client.Dial(network, address)
+ if err != nil {
+ return nil, fmt.Errorf("通过 SSH 隧道连接到 %s 失败:%w", address, err)
+ }
+
+ logger.Infof("已通过 SSH 隧道连接到:%s", address)
+ return conn, nil
+}
+
+// CloseAllSSHClients closes all cached SSH clients
+func CloseAllSSHClients() {
+ sshClientCacheMu.Lock()
+ defer sshClientCacheMu.Unlock()
+
+ for key, client := range sshClientCache {
+ if client != nil {
+ _ = client.Close()
+ logger.Infof("已关闭 SSH 连接:%s", key)
+ }
+ }
+ sshClientCache = make(map[string]*ssh.Client)
+}
+