From 71e5de0cdc9ff7d2db2f84a57c66c56ecf1c9f6f Mon Sep 17 00:00:00 2001 From: Syngnat Date: Wed, 4 Feb 2026 14:35:31 +0800 Subject: [PATCH] =?UTF-8?q?=E2=99=BB=EF=B8=8F=20refactor(database/ssh):=20?= =?UTF-8?q?SSH=E9=9A=A7=E9=81=93=E6=9E=B6=E6=9E=84=E9=87=8D=E6=9E=84?= =?UTF-8?q?=E4=B8=8E=E5=A4=9A=E6=95=B0=E6=8D=AE=E6=BA=90=E9=80=82=E9=85=8D?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - 架构升级:从driver专属拨号器改为通用本地端口转发模式 - 并发安全:sync.Once保护Close操作,RWMutex保护状态访问,双向errc等待 - 连接池化:GetOrCreateLocalForwarder/GetOrCreateSSHClient实现缓存复用 - SQL安全:kingbase_impl.go引入esc函数,防止双引号注入(""ldf_server""问题) - Schema动态化:三级fallback(schema.table解析→dbName参数→current_schema()) - 代码复用:scanRows统一行扫描逻辑,normalizeQueryValueWithDBType增强类型处理 Close #40 --- frontend/src/components/ConnectionModal.tsx | 4 +- internal/app/app.go | 3 - internal/app/sql_sanitize.go | 61 +-- internal/app/sql_sanitize_test.go | 18 + internal/db/custom_impl.go | 28 +- internal/db/dameng_impl.go | 86 +++-- internal/db/kingbase_impl.go | 387 +++++++++++++++----- internal/db/mysql_impl.go | 28 +- internal/db/oracle_impl.go | 98 +++-- internal/db/postgres_impl.go | 80 ++-- internal/db/query_value.go | 60 ++- internal/db/query_value_test.go | 44 +++ internal/db/scan_rows.go | 16 +- internal/db/sqlite_impl.go | 28 +- internal/ssh/ssh.go | 263 +++++++++++++ 15 files changed, 879 insertions(+), 325 deletions(-) create mode 100644 internal/db/query_value_test.go diff --git a/frontend/src/components/ConnectionModal.tsx b/frontend/src/components/ConnectionModal.tsx index 8fa0bb9..e06a834 100644 --- a/frontend/src/components/ConnectionModal.tsx +++ b/frontend/src/components/ConnectionModal.tsx @@ -264,8 +264,8 @@ const ConnectionModal: React.FC<{ open: boolean; onClose: () => void; initialVal {useSSH && (
- - + + diff --git a/internal/app/app.go b/internal/app/app.go index 622bfc2..c61169e 100644 --- a/internal/app/app.go +++ b/internal/app/app.go @@ -141,9 +141,6 @@ func (a *App) getDatabase(config connection.ConnectionConfig) (db.Database, erro if len(shortKey) > 12 { shortKey = shortKey[:12] } - if config.UseSSH && config.Type != "mysql" { - logger.Warnf("当前仅 MySQL 支持内置 SSH 直连,其他类型请使用本地端口转发:%s", formatConnSummary(config)) - } logger.Infof("获取数据库连接:%s 缓存Key=%s", formatConnSummary(config), shortKey) a.mu.Lock() diff --git a/internal/app/sql_sanitize.go b/internal/app/sql_sanitize.go index 9ba89e3..4e37ed5 100644 --- a/internal/app/sql_sanitize.go +++ b/internal/app/sql_sanitize.go @@ -8,14 +8,26 @@ import ( func sanitizeSQLForPgLike(dbType string, query string) string { switch strings.ToLower(strings.TrimSpace(dbType)) { case "postgres", "kingbase": - return fixBrokenDoubleDoubleQuotedIdent(query) + // 有些情况下会出现多层重复引用(例如 """"schema"""" 或 ""schema"""),单次修复不一定收敛。 + // 这里做有限次数的迭代,直到输出不再变化。 + out := query + for i := 0; i < 3; i++ { + fixed := fixBrokenDoubleDoubleQuotedIdent(out) + if fixed == out { + break + } + out = fixed + } + return out default: return query } } // fixBrokenDoubleDoubleQuotedIdent fixes accidental identifiers like: -// SELECT * FROM ""schema"".""table"" +// +// SELECT * FROM ""schema"".""table"" +// // which can be produced when a quoted identifier gets wrapped by quotes again. // // It is intentionally conservative: @@ -124,20 +136,17 @@ func fixBrokenDoubleDoubleQuotedIdent(query string) string { } } - // Fix: ""ident"" -> "ident" (only when it looks like a plain identifier) - if ch == '"' && next == '"' { - prevIsQuote := i > 0 && query[i-1] == '"' - nextIsQuote := i+2 < len(query) && query[i+2] == '"' - if !prevIsQuote && !nextIsQuote { + if ch == '"' { + // Fix: ""ident"" -> "ident" (only when it looks like a plain identifier) + // Also handle variants like ""ident""" / """"ident"""" (extra quotes at either side). + if next == '"' { if replacement, advance, ok := tryFixDoubleDoubleQuotedIdent(query, i); ok { b.WriteString(replacement) i = advance - 1 continue } } - } - if ch == '"' { b.WriteByte(ch) inDoubleIdent = true continue @@ -150,7 +159,8 @@ func fixBrokenDoubleDoubleQuotedIdent(query string) string { } func tryFixDoubleDoubleQuotedIdent(query string, start int) (replacement string, advance int, ok bool) { - // start points at the first quote of `""...""` + // start points at the first quote of a broken identifier, usually like: + // ""ident"" / ""ident""" / """"ident"""" if start < 0 || start+1 >= len(query) { return "", 0, false } @@ -160,24 +170,31 @@ func tryFixDoubleDoubleQuotedIdent(query string, start int) (replacement string, if start > 0 && query[start-1] == '"' { return "", 0, false } - if start+2 < len(query) && query[start+2] == '"' { + + runLen := 0 + for start+runLen < len(query) && query[start+runLen] == '"' { + runLen++ + } + if runLen < 2 || runLen%2 == 1 { + // Odd run (e.g. """...) can be a valid quoted identifier with escaped quotes. return "", 0, false } - contentStart := start + 2 + contentStart := start + runLen j := contentStart - for j+1 < len(query) { - if query[j] == '"' && query[j+1] == '"' { - // ensure closing pair is not part of a triple quote - if j+2 < len(query) && query[j+2] == '"' { - j++ - continue + for j < len(query) { + if query[j] == '"' { + endRunLen := 0 + for j+endRunLen < len(query) && query[j+endRunLen] == '"' { + endRunLen++ } - content := strings.TrimSpace(query[contentStart:j]) - if looksLikeIdentifierContent(content) { - return `"` + content + `"`, j + 2, true + if endRunLen >= 2 { + content := strings.TrimSpace(query[contentStart:j]) + if looksLikeIdentifierContent(content) { + return `"` + content + `"`, j + endRunLen, true + } + return "", 0, false } - return "", 0, false } // Fast abort: identifier-like content should not span lines. if query[j] == '\n' || query[j] == '\r' { diff --git a/internal/app/sql_sanitize_test.go b/internal/app/sql_sanitize_test.go index fbee1f6..426ddda 100644 --- a/internal/app/sql_sanitize_test.go +++ b/internal/app/sql_sanitize_test.go @@ -11,6 +11,24 @@ func TestSanitizeSQLForPgLike_FixesBrokenDoubleDoubleQuotes(t *testing.T) { } } +func TestSanitizeSQLForPgLike_FixesBrokenDoubleDoubleQuotes_WithExtraQuotes(t *testing.T) { + in := `SELECT * FROM ""ldf_server""".""t_user"" LIMIT 1` + out := sanitizeSQLForPgLike("kingbase", in) + want := `SELECT * FROM "ldf_server"."t_user" LIMIT 1` + if out != want { + t.Fatalf("unexpected sanitize output:\nIN: %s\nOUT: %s\nWANT: %s", in, out, want) + } +} + +func TestSanitizeSQLForPgLike_FixesBrokenDoubleDoubleQuotes_WithQuadQuotes(t *testing.T) { + in := `SELECT * FROM """"ldf_server"""".""t_user"" LIMIT 1` + out := sanitizeSQLForPgLike("kingbase", in) + want := `SELECT * FROM "ldf_server"."t_user" LIMIT 1` + if out != want { + t.Fatalf("unexpected sanitize output:\nIN: %s\nOUT: %s\nWANT: %s", in, out, want) + } +} + func TestSanitizeSQLForPgLike_DoesNotTouchEscapedQuotesInsideIdentifier(t *testing.T) { in := `SELECT "a""b" FROM "t""x"` out := sanitizeSQLForPgLike("postgres", in) diff --git a/internal/db/custom_impl.go b/internal/db/custom_impl.go index 495ff95..bc9f7e9 100644 --- a/internal/db/custom_impl.go +++ b/internal/db/custom_impl.go @@ -82,33 +82,7 @@ func (c *CustomDB) Query(query string) ([]map[string]interface{}, []string, erro return nil, nil, err } defer rows.Close() - - columns, err := rows.Columns() - if err != nil { - return nil, nil, err - } - - var resultData []map[string]interface{} - - for rows.Next() { - values := make([]interface{}, len(columns)) - valuePtrs := make([]interface{}, len(columns)) - for i := range columns { - valuePtrs[i] = &values[i] - } - - if err := rows.Scan(valuePtrs...); err != nil { - continue - } - - entry := make(map[string]interface{}) - for i, col := range columns { - entry[col] = normalizeQueryValue(values[i]) - } - resultData = append(resultData, entry) - } - - return resultData, columns, nil + return scanRows(rows) } func (c *CustomDB) ExecContext(ctx context.Context, query string) (int64, error) { diff --git a/internal/db/dameng_impl.go b/internal/db/dameng_impl.go index ce19473..52a263f 100644 --- a/internal/db/dameng_impl.go +++ b/internal/db/dameng_impl.go @@ -11,6 +11,7 @@ import ( "time" "GoNavi-Wails/internal/connection" + "GoNavi-Wails/internal/logger" "GoNavi-Wails/internal/ssh" "GoNavi-Wails/internal/utils" @@ -20,6 +21,7 @@ import ( type DamengDB struct { conn *sql.DB pingTimeout time.Duration + forwarder *ssh.LocalForwarder // Store SSH tunnel forwarder } func (d *DamengDB) getDSN(config connection.ConnectionConfig) string { @@ -27,16 +29,6 @@ func (d *DamengDB) getDSN(config connection.ConnectionConfig) string { // or dm://user:password@host:port address := net.JoinHostPort(config.Host, strconv.Itoa(config.Port)) - if config.UseSSH { - // SSH logic similar to others, assumes port forwarding - _, err := ssh.RegisterSSHNetwork(config.SSH) - if err == nil { - // DM driver likely uses standard net.Dial, so we might need a local listener - // or assume port forwarding is handled externally or implicitly via "tcp" override if driver allows. - // Similar to Oracle, we skip complex custom dialer injection for now. - } - } - escapedPassword := url.PathEscape(config.Password) q := url.Values{} if config.Database != "" { @@ -56,7 +48,42 @@ func (d *DamengDB) getDSN(config connection.ConnectionConfig) string { } func (d *DamengDB) Connect(config connection.ConnectionConfig) error { - dsn := d.getDSN(config) + var dsn string + var err error + + if config.UseSSH { + // Create SSH tunnel with local port forwarding + logger.Infof("达梦数据库使用 SSH 连接:地址=%s:%d 用户=%s", config.Host, config.Port, config.User) + + forwarder, err := ssh.GetOrCreateLocalForwarder(config.SSH, config.Host, config.Port) + if err != nil { + return fmt.Errorf("创建 SSH 隧道失败:%w", err) + } + d.forwarder = forwarder + + // Parse local address + host, portStr, err := net.SplitHostPort(forwarder.LocalAddr) + if err != nil { + return fmt.Errorf("解析本地转发地址失败:%w", err) + } + + port, err := strconv.Atoi(portStr) + if err != nil { + return fmt.Errorf("解析本地端口失败:%w", err) + } + + // Create a modified config pointing to local forwarder + localConfig := config + localConfig.Host = host + localConfig.Port = port + localConfig.UseSSH = false + + dsn = d.getDSN(localConfig) + logger.Infof("达梦数据库通过本地端口转发连接:%s -> %s:%d", forwarder.LocalAddr, config.Host, config.Port) + } else { + dsn = d.getDSN(config) + } + db, err := sql.Open("dm", dsn) if err != nil { return fmt.Errorf("打开数据库连接失败:%w", err) @@ -70,6 +97,15 @@ func (d *DamengDB) Connect(config connection.ConnectionConfig) error { } func (d *DamengDB) Close() error { + // Close SSH forwarder first if exists + if d.forwarder != nil { + if err := d.forwarder.Close(); err != nil { + logger.Warnf("关闭达梦数据库 SSH 端口转发失败:%v", err) + } + d.forwarder = nil + } + + // Then close database connection if d.conn != nil { return d.conn.Close() } @@ -113,33 +149,7 @@ func (d *DamengDB) Query(query string) ([]map[string]interface{}, []string, erro return nil, nil, err } defer rows.Close() - - columns, err := rows.Columns() - if err != nil { - return nil, nil, err - } - - var resultData []map[string]interface{} - - for rows.Next() { - values := make([]interface{}, len(columns)) - valuePtrs := make([]interface{}, len(columns)) - for i := range columns { - valuePtrs[i] = &values[i] - } - - if err := rows.Scan(valuePtrs...); err != nil { - continue - } - - entry := make(map[string]interface{}) - for i, col := range columns { - entry[col] = normalizeQueryValue(values[i]) - } - resultData = append(resultData, entry) - } - - return resultData, columns, nil + return scanRows(rows) } func (d *DamengDB) ExecContext(ctx context.Context, query string) (int64, error) { diff --git a/internal/db/kingbase_impl.go b/internal/db/kingbase_impl.go index 245fcbb..f9e423a 100644 --- a/internal/db/kingbase_impl.go +++ b/internal/db/kingbase_impl.go @@ -4,10 +4,13 @@ import ( "context" "database/sql" "fmt" + "net" + "strconv" "strings" "time" "GoNavi-Wails/internal/connection" + "GoNavi-Wails/internal/logger" "GoNavi-Wails/internal/ssh" "GoNavi-Wails/internal/utils" @@ -17,6 +20,7 @@ import ( type KingbaseDB struct { conn *sql.DB pingTimeout time.Duration + forwarder *ssh.LocalForwarder // Store SSH tunnel forwarder } func quoteConnValue(v string) string { @@ -58,20 +62,6 @@ func (k *KingbaseDB) getDSN(config connection.ConnectionConfig) string { address := config.Host port := config.Port - if config.UseSSH { - netName, err := ssh.RegisterSSHNetwork(config.SSH) - if err == nil { - // Kingbase/Postgres lib/pq allows custom dialer via "host" if using unix socket, - // but for custom network it's harder. - // Ideally we use a local forwarder. - // For now, we assume standard TCP or handle SSH externally. - // If we implement the net.Dial override for "kingbase" driver (which might use lib/pq internally), - // we might need to check if it supports "cloudsql" style or similar custom dialers. - // Similar to others, skipping SSH deep integration here for now. - _ = netName - } - } - // Construct DSN dsn := fmt.Sprintf("host=%s port=%d user=%s password=%s dbname=%s sslmode=disable connect_timeout=%d", quoteConnValue(address), @@ -86,7 +76,42 @@ func (k *KingbaseDB) getDSN(config connection.ConnectionConfig) string { } func (k *KingbaseDB) Connect(config connection.ConnectionConfig) error { - dsn := k.getDSN(config) + var dsn string + var err error + + if config.UseSSH { + // Create SSH tunnel with local port forwarding + logger.Infof("人大金仓使用 SSH 连接:地址=%s:%d 用户=%s", config.Host, config.Port, config.User) + + forwarder, err := ssh.GetOrCreateLocalForwarder(config.SSH, config.Host, config.Port) + if err != nil { + return fmt.Errorf("创建 SSH 隧道失败:%w", err) + } + k.forwarder = forwarder + + // Parse local address + host, portStr, err := net.SplitHostPort(forwarder.LocalAddr) + if err != nil { + return fmt.Errorf("解析本地转发地址失败:%w", err) + } + + port, err := strconv.Atoi(portStr) + if err != nil { + return fmt.Errorf("解析本地端口失败:%w", err) + } + + // Create a modified config pointing to local forwarder + localConfig := config + localConfig.Host = host + localConfig.Port = port + localConfig.UseSSH = false + + dsn = k.getDSN(localConfig) + logger.Infof("人大金仓通过本地端口转发连接:%s -> %s:%d", forwarder.LocalAddr, config.Host, config.Port) + } else { + dsn = k.getDSN(config) + } + // Open using "kingbase" driver db, err := sql.Open("kingbase", dsn) if err != nil { @@ -101,6 +126,15 @@ func (k *KingbaseDB) Connect(config connection.ConnectionConfig) error { } func (k *KingbaseDB) Close() error { + // Close SSH forwarder first if exists + if k.forwarder != nil { + if err := k.forwarder.Close(); err != nil { + logger.Warnf("关闭人大金仓 SSH 端口转发失败:%v", err) + } + k.forwarder = nil + } + + // Then close database connection if k.conn != nil { return k.conn.Close() } @@ -144,33 +178,7 @@ func (k *KingbaseDB) Query(query string) ([]map[string]interface{}, []string, er return nil, nil, err } defer rows.Close() - - columns, err := rows.Columns() - if err != nil { - return nil, nil, err - } - - var resultData []map[string]interface{} - - for rows.Next() { - values := make([]interface{}, len(columns)) - valuePtrs := make([]interface{}, len(columns)) - for i := range columns { - valuePtrs[i] = &values[i] - } - - if err := rows.Scan(valuePtrs...); err != nil { - continue - } - - entry := make(map[string]interface{}) - for i, col := range columns { - entry[col] = normalizeQueryValue(values[i]) - } - resultData = append(resultData, entry) - } - - return resultData, columns, nil + return scanRows(rows) } func (k *KingbaseDB) ExecContext(ctx context.Context, query string) (int64, error) { @@ -249,15 +257,84 @@ func (k *KingbaseDB) GetCreateStatement(dbName, tableName string) (string, error } func (k *KingbaseDB) GetColumns(dbName, tableName string) ([]connection.ColumnDefinition, error) { - schema := "public" - if dbName != "" { - schema = dbName + // 解析 schema.table 格式 + schema := strings.TrimSpace(dbName) + table := strings.TrimSpace(tableName) + + // 如果 tableName 包含 schema (格式: schema.table) + if parts := strings.SplitN(table, ".", 2); len(parts) == 2 { + parsedSchema := strings.TrimSpace(parts[0]) + parsedTable := strings.TrimSpace(parts[1]) + if parsedSchema != "" && parsedTable != "" { + schema = parsedSchema + table = parsedTable + } } - query := fmt.Sprintf(`SELECT column_name, data_type, is_nullable, column_default - FROM information_schema.columns - WHERE table_schema = '%s' AND table_name = '%s' - ORDER BY ordinal_position`, schema, tableName) + // 如果仍然没有 schema,使用 current_schema() + // 这样可以自动匹配当前连接的 search_path + if schema == "" { + return k.getColumnsWithCurrentSchema(table) + } + + if table == "" { + return nil, fmt.Errorf("table name required") + } + + // 转义函数:处理单引号,移除双引号 + esc := func(s string) string { + // 移除前后的双引号(如果存在) + s = strings.Trim(s, "\"") + // 转义单引号 + return strings.ReplaceAll(s, "'", "''") + } + + query := fmt.Sprintf(`SELECT column_name, data_type, is_nullable, column_default + FROM information_schema.columns + WHERE table_schema = '%s' AND table_name = '%s' + ORDER BY ordinal_position`, esc(schema), esc(table)) + + data, _, err := k.Query(query) + if err != nil { + return nil, err + } + + var columns []connection.ColumnDefinition + for _, row := range data { + col := connection.ColumnDefinition{ + Name: fmt.Sprintf("%v", row["column_name"]), + Type: fmt.Sprintf("%v", row["data_type"]), + Nullable: fmt.Sprintf("%v", row["is_nullable"]), + } + + if row["column_default"] != nil { + def := fmt.Sprintf("%v", row["column_default"]) + col.Default = &def + } + + columns = append(columns, col) + } + return columns, nil +} + +// getColumnsWithCurrentSchema 使用 current_schema() 查询当前schema的表 +func (k *KingbaseDB) getColumnsWithCurrentSchema(tableName string) ([]connection.ColumnDefinition, error) { + table := strings.TrimSpace(tableName) + if table == "" { + return nil, fmt.Errorf("table name required") + } + + // 转义函数 + esc := func(s string) string { + s = strings.Trim(s, "\"") + return strings.ReplaceAll(s, "'", "''") + } + + // 使用 current_schema() 获取当前schema + query := fmt.Sprintf(`SELECT column_name, data_type, is_nullable, column_default + FROM information_schema.columns + WHERE table_schema = current_schema() AND table_name = '%s' + ORDER BY ordinal_position`, esc(table)) data, _, err := k.Query(query) if err != nil { @@ -283,32 +360,76 @@ func (k *KingbaseDB) GetColumns(dbName, tableName string) ([]connection.ColumnDe } func (k *KingbaseDB) GetIndexes(dbName, tableName string) ([]connection.IndexDefinition, error) { - // Postgres/Kingbase index query - query := fmt.Sprintf(` - SELECT - i.relname as index_name, - a.attname as column_name, - ix.indisunique as is_unique - FROM - pg_class t, - pg_class i, - pg_index ix, - pg_attribute a, - pg_namespace n - WHERE - t.oid = ix.indrelid - AND i.oid = ix.indexrelid - AND a.attrelid = t.oid - AND a.attnum = ANY(ix.indkey) - AND t.relkind = 'r' - AND t.relname = '%s' - AND n.oid = t.relnamespace - AND n.nspname = '%s' - `, tableName, "public") // Default to public if dbName (schema) not clear. + // 解析 schema.table 格式 + schema := strings.TrimSpace(dbName) + table := strings.TrimSpace(tableName) - if dbName != "" { - // Update query to use dbName as schema - query = strings.Replace(query, "'public'", fmt.Sprintf("'%s'", dbName), 1) + // 如果 tableName 包含 schema (格式: schema.table) + if parts := strings.SplitN(table, ".", 2); len(parts) == 2 { + parsedSchema := strings.TrimSpace(parts[0]) + parsedTable := strings.TrimSpace(parts[1]) + if parsedSchema != "" && parsedTable != "" { + schema = parsedSchema + table = parsedTable + } + } + + if table == "" { + return nil, fmt.Errorf("table name required") + } + + // 转义函数:处理单引号,移除双引号 + esc := func(s string) string { + s = strings.Trim(s, "\"") + return strings.ReplaceAll(s, "'", "''") + } + + // 构建查询:如果没有指定schema,使用current_schema() + var query string + if schema != "" { + query = fmt.Sprintf(` + SELECT + i.relname as index_name, + a.attname as column_name, + ix.indisunique as is_unique + FROM + pg_class t, + pg_class i, + pg_index ix, + pg_attribute a, + pg_namespace n + WHERE + t.oid = ix.indrelid + AND i.oid = ix.indexrelid + AND a.attrelid = t.oid + AND a.attnum = ANY(ix.indkey) + AND t.relkind = 'r' + AND t.relname = '%s' + AND n.oid = t.relnamespace + AND n.nspname = '%s' + `, esc(table), esc(schema)) + } else { + query = fmt.Sprintf(` + SELECT + i.relname as index_name, + a.attname as column_name, + ix.indisunique as is_unique + FROM + pg_class t, + pg_class i, + pg_index ix, + pg_attribute a, + pg_namespace n + WHERE + t.oid = ix.indrelid + AND i.oid = ix.indexrelid + AND a.attrelid = t.oid + AND a.attnum = ANY(ix.indkey) + AND t.relkind = 'r' + AND t.relname = '%s' + AND n.oid = t.relnamespace + AND n.nspname = current_schema() + `, esc(table)) } data, _, err := k.Query(query) @@ -337,27 +458,67 @@ func (k *KingbaseDB) GetIndexes(dbName, tableName string) ([]connection.IndexDef } func (k *KingbaseDB) GetForeignKeys(dbName, tableName string) ([]connection.ForeignKeyDefinition, error) { - schema := "public" - if dbName != "" { - schema = dbName + // 解析 schema.table 格式 + schema := strings.TrimSpace(dbName) + table := strings.TrimSpace(tableName) + + // 如果 tableName 包含 schema (格式: schema.table) + if parts := strings.SplitN(table, ".", 2); len(parts) == 2 { + parsedSchema := strings.TrimSpace(parts[0]) + parsedTable := strings.TrimSpace(parts[1]) + if parsedSchema != "" && parsedTable != "" { + schema = parsedSchema + table = parsedTable + } } - query := fmt.Sprintf(` - SELECT - tc.constraint_name, - kcu.column_name, - ccu.table_name AS foreign_table_name, - ccu.column_name AS foreign_column_name - FROM - information_schema.table_constraints AS tc - JOIN information_schema.key_column_usage AS kcu - ON tc.constraint_name = kcu.constraint_name - AND tc.table_schema = kcu.table_schema - JOIN information_schema.constraint_column_usage AS ccu - ON ccu.constraint_name = tc.constraint_name - AND ccu.table_schema = tc.table_schema - WHERE tc.constraint_type = 'FOREIGN KEY' AND tc.table_name='%s' AND tc.table_schema='%s'`, - tableName, schema) + if table == "" { + return nil, fmt.Errorf("table name required") + } + + // 转义函数:处理单引号,移除双引号 + esc := func(s string) string { + s = strings.Trim(s, "\"") + return strings.ReplaceAll(s, "'", "''") + } + + // 构建查询:如果没有指定schema,使用current_schema() + var query string + if schema != "" { + query = fmt.Sprintf(` + SELECT + tc.constraint_name, + kcu.column_name, + ccu.table_name AS foreign_table_name, + ccu.column_name AS foreign_column_name + FROM + information_schema.table_constraints AS tc + JOIN information_schema.key_column_usage AS kcu + ON tc.constraint_name = kcu.constraint_name + AND tc.table_schema = kcu.table_schema + JOIN information_schema.constraint_column_usage AS ccu + ON ccu.constraint_name = tc.constraint_name + AND ccu.table_schema = tc.table_schema + WHERE tc.constraint_type = 'FOREIGN KEY' AND tc.table_name='%s' AND tc.table_schema='%s'`, + esc(table), esc(schema)) + } else { + query = fmt.Sprintf(` + SELECT + tc.constraint_name, + kcu.column_name, + ccu.table_name AS foreign_table_name, + ccu.column_name AS foreign_column_name + FROM + information_schema.table_constraints AS tc + JOIN information_schema.key_column_usage AS kcu + ON tc.constraint_name = kcu.constraint_name + AND tc.table_schema = kcu.table_schema + JOIN information_schema.constraint_column_usage AS ccu + ON ccu.constraint_name = tc.constraint_name + AND ccu.table_schema = tc.table_schema + WHERE tc.constraint_type = 'FOREIGN KEY' AND tc.table_name='%s' AND tc.table_schema=current_schema()`, + esc(table)) + } data, _, err := k.Query(query) if err != nil { @@ -379,9 +540,43 @@ func (k *KingbaseDB) GetForeignKeys(dbName, tableName string) ([]connection.Fore } func (k *KingbaseDB) GetTriggers(dbName, tableName string) ([]connection.TriggerDefinition, error) { - query := fmt.Sprintf(`SELECT trigger_name, action_timing, event_manipulation - FROM information_schema.triggers - WHERE event_object_table = '%s'`, tableName) + // 解析 schema.table 格式 + schema := strings.TrimSpace(dbName) + table := strings.TrimSpace(tableName) + + // 如果 tableName 包含 schema (格式: schema.table) + if parts := strings.SplitN(table, ".", 2); len(parts) == 2 { + parsedSchema := strings.TrimSpace(parts[0]) + parsedTable := strings.TrimSpace(parts[1]) + if parsedSchema != "" && parsedTable != "" { + schema = parsedSchema + table = parsedTable + } + } + + if table == "" { + return nil, fmt.Errorf("table name required") + } + + // 转义函数:处理单引号,移除双引号 + esc := func(s string) string { + s = strings.Trim(s, "\"") + return strings.ReplaceAll(s, "'", "''") + } + + // 构建查询:如果指定了schema,也加上schema条件 + var query string + if schema != "" { + query = fmt.Sprintf(`SELECT trigger_name, action_timing, event_manipulation + FROM information_schema.triggers + WHERE event_object_table = '%s' AND event_object_schema = '%s'`, + esc(table), esc(schema)) + } else { + query = fmt.Sprintf(`SELECT trigger_name, action_timing, event_manipulation + FROM information_schema.triggers + WHERE event_object_table = '%s' AND event_object_schema = current_schema()`, + esc(table)) + } data, _, err := k.Query(query) if err != nil { diff --git a/internal/db/mysql_impl.go b/internal/db/mysql_impl.go index 2577d6a..5de70cf 100644 --- a/internal/db/mysql_impl.go +++ b/internal/db/mysql_impl.go @@ -101,33 +101,7 @@ func (m *MySQLDB) Query(query string) ([]map[string]interface{}, []string, error return nil, nil, err } defer rows.Close() - - columns, err := rows.Columns() - if err != nil { - return nil, nil, err - } - - var resultData []map[string]interface{} - - for rows.Next() { - values := make([]interface{}, len(columns)) - valuePtrs := make([]interface{}, len(columns)) - for i := range columns { - valuePtrs[i] = &values[i] - } - - if err := rows.Scan(valuePtrs...); err != nil { - continue - } - - entry := make(map[string]interface{}) - for i, col := range columns { - entry[col] = normalizeQueryValue(values[i]) - } - resultData = append(resultData, entry) - } - - return resultData, columns, nil + return scanRows(rows) } func (m *MySQLDB) ExecContext(ctx context.Context, query string) (int64, error) { diff --git a/internal/db/oracle_impl.go b/internal/db/oracle_impl.go index 454f460..f07c376 100644 --- a/internal/db/oracle_impl.go +++ b/internal/db/oracle_impl.go @@ -11,6 +11,7 @@ import ( "time" "GoNavi-Wails/internal/connection" + "GoNavi-Wails/internal/logger" "GoNavi-Wails/internal/ssh" "GoNavi-Wails/internal/utils" @@ -20,6 +21,7 @@ import ( type OracleDB struct { conn *sql.DB pingTimeout time.Duration + forwarder *ssh.LocalForwarder // Store SSH tunnel forwarder } func (o *OracleDB) getDSN(config connection.ConnectionConfig) string { @@ -29,28 +31,6 @@ func (o *OracleDB) getDSN(config connection.ConnectionConfig) string { database = config.User // Default to user service/schema if empty? } - if config.UseSSH { - _, err := ssh.RegisterSSHNetwork(config.SSH) - if err == nil { - // Oracle driver might not support custom dialer via DSN easily without extra config - // But go-ora v2 supports some advanced options. - // For simplicity, we assume standard TCP or we might need a workaround for SSH. - // go-ora v2 is pure Go, so we can potentially use a custom dialer if we manually open. - // But for now, let's just use the address. - // SSH tunneling via net.Dialer override is complex in sql.Open("oracle", ...). - // We might need to forward a local port if using SSH. - // Since ssh.RegisterSSHNetwork creates a custom network "ssh-via-...", - // we need to see if go-ora supports custom networks. - // Checking go-ora docs (simulated): It supports "unix" and "tcp". - // We might need to map the custom network to a local proxy. - // For now, we will assume direct connection or handle SSH separately later. - // We'll leave the protocol implementation as is in MySQL for now, hoping go-ora uses standard net.Dial. - // Note: go-ora connection string: oracle://user:pass@host:port/service - // It parses host/port. It doesn't easily take a custom "network" parameter in URL. - // We will proceed with standard TCP string. - } - } - u := &url.URL{ Scheme: "oracle", Host: net.JoinHostPort(config.Host, strconv.Itoa(config.Port)), @@ -62,7 +42,42 @@ func (o *OracleDB) getDSN(config connection.ConnectionConfig) string { } func (o *OracleDB) Connect(config connection.ConnectionConfig) error { - dsn := o.getDSN(config) + var dsn string + var err error + + if config.UseSSH { + // Create SSH tunnel with local port forwarding + logger.Infof("Oracle 使用 SSH 连接:地址=%s:%d 用户=%s", config.Host, config.Port, config.User) + + forwarder, err := ssh.GetOrCreateLocalForwarder(config.SSH, config.Host, config.Port) + if err != nil { + return fmt.Errorf("创建 SSH 隧道失败:%w", err) + } + o.forwarder = forwarder + + // Parse local address + host, portStr, err := net.SplitHostPort(forwarder.LocalAddr) + if err != nil { + return fmt.Errorf("解析本地转发地址失败:%w", err) + } + + port, err := strconv.Atoi(portStr) + if err != nil { + return fmt.Errorf("解析本地端口失败:%w", err) + } + + // Create a modified config pointing to local forwarder + localConfig := config + localConfig.Host = host + localConfig.Port = port + localConfig.UseSSH = false + + dsn = o.getDSN(localConfig) + logger.Infof("Oracle 通过本地端口转发连接:%s -> %s:%d", forwarder.LocalAddr, config.Host, config.Port) + } else { + dsn = o.getDSN(config) + } + db, err := sql.Open("oracle", dsn) if err != nil { return fmt.Errorf("打开数据库连接失败:%w", err) @@ -76,6 +91,15 @@ func (o *OracleDB) Connect(config connection.ConnectionConfig) error { } func (o *OracleDB) Close() error { + // Close SSH forwarder first if exists + if o.forwarder != nil { + if err := o.forwarder.Close(); err != nil { + logger.Warnf("关闭 Oracle SSH 端口转发失败:%v", err) + } + o.forwarder = nil + } + + // Then close database connection if o.conn != nil { return o.conn.Close() } @@ -119,33 +143,7 @@ func (o *OracleDB) Query(query string) ([]map[string]interface{}, []string, erro return nil, nil, err } defer rows.Close() - - columns, err := rows.Columns() - if err != nil { - return nil, nil, err - } - - var resultData []map[string]interface{} - - for rows.Next() { - values := make([]interface{}, len(columns)) - valuePtrs := make([]interface{}, len(columns)) - for i := range columns { - valuePtrs[i] = &values[i] - } - - if err := rows.Scan(valuePtrs...); err != nil { - continue - } - - entry := make(map[string]interface{}) - for i, col := range columns { - entry[col] = normalizeQueryValue(values[i]) - } - resultData = append(resultData, entry) - } - - return resultData, columns, nil + return scanRows(rows) } func (o *OracleDB) ExecContext(ctx context.Context, query string) (int64, error) { diff --git a/internal/db/postgres_impl.go b/internal/db/postgres_impl.go index 9ade7f8..26fed51 100644 --- a/internal/db/postgres_impl.go +++ b/internal/db/postgres_impl.go @@ -11,16 +11,21 @@ import ( "time" "GoNavi-Wails/internal/connection" + "GoNavi-Wails/internal/logger" + "GoNavi-Wails/internal/ssh" "GoNavi-Wails/internal/utils" _ "github.com/lib/pq" ) + type PostgresDB struct { conn *sql.DB pingTimeout time.Duration + forwarder *ssh.LocalForwarder // Store SSH tunnel forwarder } + func (p *PostgresDB) getDSN(config connection.ConnectionConfig) string { // postgres://user:password@host:port/dbname?sslmode=disable dbname := config.Database @@ -43,7 +48,42 @@ func (p *PostgresDB) getDSN(config connection.ConnectionConfig) string { } func (p *PostgresDB) Connect(config connection.ConnectionConfig) error { - dsn := p.getDSN(config) + var dsn string + var err error + + if config.UseSSH { + // Create SSH tunnel with local port forwarding + logger.Infof("PostgreSQL 使用 SSH 连接:地址=%s:%d 用户=%s", config.Host, config.Port, config.User) + + forwarder, err := ssh.GetOrCreateLocalForwarder(config.SSH, config.Host, config.Port) + if err != nil { + return fmt.Errorf("创建 SSH 隧道失败:%w", err) + } + p.forwarder = forwarder + + // Parse local address + host, portStr, err := net.SplitHostPort(forwarder.LocalAddr) + if err != nil { + return fmt.Errorf("解析本地转发地址失败:%w", err) + } + + port, err := strconv.Atoi(portStr) + if err != nil { + return fmt.Errorf("解析本地端口失败:%w", err) + } + + // Create a modified config pointing to local forwarder + localConfig := config + localConfig.Host = host + localConfig.Port = port + localConfig.UseSSH = false // Disable SSH flag for DSN generation + + dsn = p.getDSN(localConfig) + logger.Infof("PostgreSQL 通过本地端口转发连接:%s -> %s:%d", forwarder.LocalAddr, config.Host, config.Port) + } else { + dsn = p.getDSN(config) + } + db, err := sql.Open("postgres", dsn) if err != nil { return fmt.Errorf("打开数据库连接失败:%w", err) @@ -58,7 +98,17 @@ func (p *PostgresDB) Connect(config connection.ConnectionConfig) error { return nil } + func (p *PostgresDB) Close() error { + // Close SSH forwarder first if exists + if p.forwarder != nil { + if err := p.forwarder.Close(); err != nil { + logger.Warnf("关闭 PostgreSQL SSH 端口转发失败:%v", err) + } + p.forwarder = nil + } + + // Then close database connection if p.conn != nil { return p.conn.Close() } @@ -102,33 +152,7 @@ func (p *PostgresDB) Query(query string) ([]map[string]interface{}, []string, er return nil, nil, err } defer rows.Close() - - columns, err := rows.Columns() - if err != nil { - return nil, nil, err - } - - var resultData []map[string]interface{} - - for rows.Next() { - values := make([]interface{}, len(columns)) - valuePtrs := make([]interface{}, len(columns)) - for i := range columns { - valuePtrs[i] = &values[i] - } - - if err := rows.Scan(valuePtrs...); err != nil { - continue - } - - entry := make(map[string]interface{}) - for i, col := range columns { - entry[col] = normalizeQueryValue(values[i]) - } - resultData = append(resultData, entry) - } - - return resultData, columns, nil + return scanRows(rows) } func (p *PostgresDB) ExecContext(ctx context.Context, query string) (int64, error) { diff --git a/internal/db/query_value.go b/internal/db/query_value.go index 764ccbf..d4dde25 100644 --- a/internal/db/query_value.go +++ b/internal/db/query_value.go @@ -2,6 +2,8 @@ package db import ( "encoding/hex" + "fmt" + "strings" "unicode" "unicode/utf8" ) @@ -9,13 +11,17 @@ import ( // normalizeQueryValue normalizes driver-returned values for UI/JSON transport. // 当前主要处理 []byte:如果是可读文本则转为 string,否则转为十六进制字符串,避免前端出现“空白值”。 func normalizeQueryValue(v interface{}) interface{} { + return normalizeQueryValueWithDBType(v, "") +} + +func normalizeQueryValueWithDBType(v interface{}, databaseTypeName string) interface{} { if b, ok := v.([]byte); ok { - return bytesToReadableString(b) + return bytesToDisplayValue(b, databaseTypeName) } return v } -func bytesToReadableString(b []byte) interface{} { +func bytesToDisplayValue(b []byte, databaseTypeName string) interface{} { if b == nil { return nil } @@ -23,6 +29,18 @@ func bytesToReadableString(b []byte) interface{} { return "" } + dbType := strings.ToUpper(strings.TrimSpace(databaseTypeName)) + if isBitLikeDBType(dbType) { + if u, ok := bytesToUint64(b); ok { + // JS number precision is limited; keep large bitmasks as string. + const maxSafeInteger = 9007199254740991 // 2^53 - 1 + if u <= maxSafeInteger { + return int64(u) + } + return fmt.Sprintf("%d", u) + } + } + if utf8.Valid(b) { s := string(b) if isMostlyPrintable(s) { @@ -30,9 +48,47 @@ func bytesToReadableString(b []byte) interface{} { } } + // Fallback: some drivers return BIT(1) as []byte{0} / []byte{1} without type info. + if dbType == "" && len(b) == 1 && (b[0] == 0 || b[0] == 1) { + return int64(b[0]) + } + + return bytesToReadableString(b) +} + +func bytesToReadableString(b []byte) interface{} { + if b == nil { + return nil + } + if len(b) == 0 { + return "" + } return "0x" + hex.EncodeToString(b) } +func isBitLikeDBType(typeName string) bool { + if typeName == "" { + return false + } + switch typeName { + case "BIT", "VARBIT": + return true + default: + } + return strings.HasPrefix(typeName, "BIT") +} + +func bytesToUint64(b []byte) (uint64, bool) { + if len(b) == 0 || len(b) > 8 { + return 0, false + } + var u uint64 + for _, v := range b { + u = (u << 8) | uint64(v) + } + return u, true +} + func isMostlyPrintable(s string) bool { if s == "" { return true diff --git a/internal/db/query_value_test.go b/internal/db/query_value_test.go new file mode 100644 index 0000000..1b2c140 --- /dev/null +++ b/internal/db/query_value_test.go @@ -0,0 +1,44 @@ +package db + +import "testing" + +func TestNormalizeQueryValueWithDBType_BitBytes(t *testing.T) { + v := normalizeQueryValueWithDBType([]byte{0x00}, "BIT") + if v != int64(0) { + t.Fatalf("BIT 0x00 期望为 0,实际=%v(%T)", v, v) + } + + v = normalizeQueryValueWithDBType([]byte{0x01}, "bit") + if v != int64(1) { + t.Fatalf("BIT 0x01 期望为 1,实际=%v(%T)", v, v) + } + + v = normalizeQueryValueWithDBType([]byte{0x01, 0x02}, "BIT VARYING") + if v != int64(258) { + t.Fatalf("BIT 0x0102 期望为 258,实际=%v(%T)", v, v) + } +} + +func TestNormalizeQueryValueWithDBType_BitLargeAsString(t *testing.T) { + v := normalizeQueryValueWithDBType([]byte{0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff}, "BIT") + if s, ok := v.(string); !ok || s != "18446744073709551615" { + t.Fatalf("BIT 0xffffffffffffffff 期望为 string(18446744073709551615),实际=%v(%T)", v, v) + } +} + +func TestNormalizeQueryValueWithDBType_ByteFallbacks(t *testing.T) { + v := normalizeQueryValueWithDBType([]byte("abc"), "") + if v != "abc" { + t.Fatalf("文本 []byte 期望返回 string,实际=%v(%T)", v, v) + } + + v = normalizeQueryValueWithDBType([]byte{0x00}, "") + if v != int64(0) { + t.Fatalf("未知类型 0x00 期望返回 0,实际=%v(%T)", v, v) + } + + v = normalizeQueryValueWithDBType([]byte{0xff}, "") + if v != "0xff" { + t.Fatalf("未知类型 0xff 期望返回 0xff,实际=%v(%T)", v, v) + } +} diff --git a/internal/db/scan_rows.go b/internal/db/scan_rows.go index d77bab0..d629e0a 100644 --- a/internal/db/scan_rows.go +++ b/internal/db/scan_rows.go @@ -1,6 +1,8 @@ package db -import "database/sql" +import ( + "database/sql" +) func scanRows(rows *sql.Rows) ([]map[string]interface{}, []string, error) { columns, err := rows.Columns() @@ -8,6 +10,11 @@ func scanRows(rows *sql.Rows) ([]map[string]interface{}, []string, error) { return nil, nil, err } + colTypes, err := rows.ColumnTypes() + if err != nil || len(colTypes) != len(columns) { + colTypes = nil + } + resultData := make([]map[string]interface{}, 0) for rows.Next() { @@ -23,7 +30,11 @@ func scanRows(rows *sql.Rows) ([]map[string]interface{}, []string, error) { entry := make(map[string]interface{}, len(columns)) for i, col := range columns { - entry[col] = normalizeQueryValue(values[i]) + dbTypeName := "" + if colTypes != nil && i < len(colTypes) && colTypes[i] != nil { + dbTypeName = colTypes[i].DatabaseTypeName() + } + entry[col] = normalizeQueryValueWithDBType(values[i], dbTypeName) } resultData = append(resultData, entry) } @@ -33,4 +44,3 @@ func scanRows(rows *sql.Rows) ([]map[string]interface{}, []string, error) { } return resultData, columns, nil } - diff --git a/internal/db/sqlite_impl.go b/internal/db/sqlite_impl.go index 57a8eec..3fe8dc6 100644 --- a/internal/db/sqlite_impl.go +++ b/internal/db/sqlite_impl.go @@ -78,33 +78,7 @@ func (s *SQLiteDB) Query(query string) ([]map[string]interface{}, []string, erro return nil, nil, err } defer rows.Close() - - columns, err := rows.Columns() - if err != nil { - return nil, nil, err - } - - var resultData []map[string]interface{} - - for rows.Next() { - values := make([]interface{}, len(columns)) - valuePtrs := make([]interface{}, len(columns)) - for i := range columns { - valuePtrs[i] = &values[i] - } - - if err := rows.Scan(valuePtrs...); err != nil { - continue - } - - entry := make(map[string]interface{}) - for i, col := range columns { - entry[col] = normalizeQueryValue(values[i]) - } - resultData = append(resultData, entry) - } - - return resultData, columns, nil + return scanRows(rows) } func (s *SQLiteDB) ExecContext(ctx context.Context, query string) (int64, error) { diff --git a/internal/ssh/ssh.go b/internal/ssh/ssh.go index 2d6fe64..51ad364 100644 --- a/internal/ssh/ssh.go +++ b/internal/ssh/ssh.go @@ -3,8 +3,10 @@ package ssh import ( "context" "fmt" + "io" "net" "os" + "sync" "time" "GoNavi-Wails/internal/connection" @@ -110,3 +112,264 @@ func RegisterSSHNetwork(sshConfig connection.SSHConfig) (string, error) { return netName, nil } + +// sshClientCache stores SSH clients to avoid creating multiple connections +var ( + sshClientCache = make(map[string]*ssh.Client) + sshClientCacheMu sync.RWMutex + localForwarders = make(map[string]*LocalForwarder) + forwarderMu sync.RWMutex +) + +// LocalForwarder represents a local port forwarder through SSH +type LocalForwarder struct { + LocalAddr string + RemoteAddr string + SSHClient *ssh.Client + listener net.Listener + closeChan chan struct{} + closeOnce sync.Once // 防止重复关闭 + closed bool // 关闭状态标记 + closedMu sync.RWMutex +} + +// NewLocalForwarder creates a new local port forwarder +// It listens on a random local port and forwards all connections through SSH tunnel +func NewLocalForwarder(sshConfig connection.SSHConfig, remoteHost string, remotePort int) (*LocalForwarder, error) { + client, err := GetOrCreateSSHClient(sshConfig) + if err != nil { + return nil, fmt.Errorf("建立 SSH 连接失败:%w", err) + } + + // Listen on localhost with a random port + listener, err := net.Listen("tcp", "127.0.0.1:0") + if err != nil { + return nil, fmt.Errorf("创建本地监听器失败:%w", err) + } + + localAddr := listener.Addr().String() + remoteAddr := fmt.Sprintf("%s:%d", remoteHost, remotePort) + + forwarder := &LocalForwarder{ + LocalAddr: localAddr, + RemoteAddr: remoteAddr, + SSHClient: client, + listener: listener, + closeChan: make(chan struct{}), + } + + // Start forwarding in background + go forwarder.forward() + + logger.Infof("已创建 SSH 端口转发:本地 %s -> 远程 %s", localAddr, remoteAddr) + return forwarder, nil +} + +// forward handles the port forwarding +func (f *LocalForwarder) forward() { + for { + localConn, err := f.listener.Accept() + if err != nil { + // Check if we're shutting down + select { + case <-f.closeChan: + return + default: + logger.Warnf("接受本地连接失败:%v", err) + // listener可能已关闭,退出循环 + return + } + } + + go f.handleConnection(localConn) + } +} + +// handleConnection handles a single connection +func (f *LocalForwarder) handleConnection(localConn net.Conn) { + defer localConn.Close() + + // Connect to remote through SSH with timeout + remoteConn, err := f.SSHClient.Dial("tcp", f.RemoteAddr) + if err != nil { + logger.Warnf("通过 SSH 连接到远程 %s 失败:%v", f.RemoteAddr, err) + return + } + defer remoteConn.Close() + + // Bidirectional copy with error channel + errc := make(chan error, 2) + + // Copy from local to remote + go func() { + _, err := io.Copy(remoteConn, localConn) + if err != nil { + logger.Warnf("本地->远程数据复制错误:%v", err) + } + errc <- err + }() + + // Copy from remote to local + go func() { + _, err := io.Copy(localConn, remoteConn) + if err != nil { + logger.Warnf("远程->本地数据复制错误:%v", err) + } + errc <- err + }() + + // Wait for BOTH goroutines to complete + <-errc + <-errc +} + +// Close closes the forwarder (thread-safe, can be called multiple times) +func (f *LocalForwarder) Close() error { + var err error + f.closeOnce.Do(func() { + f.closedMu.Lock() + f.closed = true + f.closedMu.Unlock() + + close(f.closeChan) + err = f.listener.Close() + if err != nil { + logger.Warnf("关闭端口转发监听器失败:%v", err) + } + }) + return err +} + +// IsClosed returns whether the forwarder is closed +func (f *LocalForwarder) IsClosed() bool { + f.closedMu.RLock() + defer f.closedMu.RUnlock() + return f.closed +} + +// GetOrCreateLocalForwarder returns a cached forwarder or creates a new one +func GetOrCreateLocalForwarder(sshConfig connection.SSHConfig, remoteHost string, remotePort int) (*LocalForwarder, error) { + key := fmt.Sprintf("%s:%d:%s->%s:%d", + sshConfig.Host, sshConfig.Port, sshConfig.User, + remoteHost, remotePort) + + forwarderMu.RLock() + forwarder, exists := localForwarders[key] + forwarderMu.RUnlock() + + // Check if exists and is still valid + if exists && forwarder != nil && !forwarder.IsClosed() { + logger.Infof("复用已有端口转发:%s", key) + return forwarder, nil + } + + // Remove stale forwarder from cache + if exists { + forwarderMu.Lock() + delete(localForwarders, key) + forwarderMu.Unlock() + } + + forwarder, err := NewLocalForwarder(sshConfig, remoteHost, remotePort) + if err != nil { + return nil, err + } + + forwarderMu.Lock() + localForwarders[key] = forwarder + forwarderMu.Unlock() + + return forwarder, nil +} + +// CloseAllForwarders closes all local forwarders +func CloseAllForwarders() { + forwarderMu.Lock() + defer forwarderMu.Unlock() + + for key, forwarder := range localForwarders { + if forwarder != nil { + _ = forwarder.Close() + logger.Infof("已关闭端口转发:%s", key) + } + } + localForwarders = make(map[string]*LocalForwarder) +} + + +// getSSHClientCacheKey generates a unique cache key for SSH config +func getSSHClientCacheKey(config connection.SSHConfig) string { + return fmt.Sprintf("%s:%d:%s", config.Host, config.Port, config.User) +} + +// GetOrCreateSSHClient returns a cached SSH client or creates a new one +func GetOrCreateSSHClient(config connection.SSHConfig) (*ssh.Client, error) { + key := getSSHClientCacheKey(config) + + sshClientCacheMu.RLock() + client, exists := sshClientCache[key] + sshClientCacheMu.RUnlock() + + if exists && client != nil { + // Test if connection is still alive by creating a test session + session, err := client.NewSession() + if err == nil { + session.Close() + logger.Infof("复用已有 SSH 连接:%s", key) + return client, nil + } + // Connection is dead, remove from cache + logger.Warnf("SSH 连接已断开,重新建立:%s (错误: %v)", key, err) + sshClientCacheMu.Lock() + delete(sshClientCache, key) + sshClientCacheMu.Unlock() + // Try to close the dead client + _ = client.Close() + } + + // Create new SSH client + client, err := connectSSH(config) + if err != nil { + return nil, err + } + + // Cache the client + sshClientCacheMu.Lock() + sshClientCache[key] = client + sshClientCacheMu.Unlock() + + logger.Infof("已缓存 SSH 连接:%s", key) + return client, nil +} + +// DialThroughSSH creates a connection through SSH tunnel +// This is a generic dialer that can be used by any database driver +func DialThroughSSH(config connection.SSHConfig, network, address string) (net.Conn, error) { + client, err := GetOrCreateSSHClient(config) + if err != nil { + return nil, fmt.Errorf("建立 SSH 连接失败:%w", err) + } + + conn, err := client.Dial(network, address) + if err != nil { + return nil, fmt.Errorf("通过 SSH 隧道连接到 %s 失败:%w", address, err) + } + + logger.Infof("已通过 SSH 隧道连接到:%s", address) + return conn, nil +} + +// CloseAllSSHClients closes all cached SSH clients +func CloseAllSSHClients() { + sshClientCacheMu.Lock() + defer sshClientCacheMu.Unlock() + + for key, client := range sshClientCache { + if client != nil { + _ = client.Close() + logger.Infof("已关闭 SSH 连接:%s", key) + } + } + sshClientCache = make(map[string]*ssh.Client) +} +