Merge pull request #64 from Syngnat/release/0.2.6

♻️ refactor(database/ssh): SSH隧道架构重构与多数据源适配
This commit is contained in:
Syngnat
2026-02-04 14:41:43 +08:00
committed by GitHub
15 changed files with 879 additions and 325 deletions

View File

@@ -264,8 +264,8 @@ const ConnectionModal: React.FC<{ open: boolean; onClose: () => void; initialVal
{useSSH && (
<div style={{ padding: '12px', background: '#f5f5f5', borderRadius: 6, marginTop: 12 }}>
<div style={{ display: 'flex', gap: 16 }}>
<Form.Item name="sshHost" label="SSH 主机" rules={[{ required: useSSH, message: '请输入SSH主机' }]} style={{ flex: 1 }}>
<Input placeholder="ssh.example.com" />
<Form.Item name="sshHost" label="SSH 主机 (域名或IP)" rules={[{ required: useSSH, message: '请输入SSH主机' }]} style={{ flex: 1 }}>
<Input placeholder="例如: ssh.example.com 或 192.168.1.100" />
</Form.Item>
<Form.Item name="sshPort" label="端口" rules={[{ required: useSSH, message: '请输入SSH端口' }]} style={{ width: 100 }}>
<InputNumber style={{ width: '100%' }} />

View File

@@ -141,9 +141,6 @@ func (a *App) getDatabase(config connection.ConnectionConfig) (db.Database, erro
if len(shortKey) > 12 {
shortKey = shortKey[:12]
}
if config.UseSSH && config.Type != "mysql" {
logger.Warnf("当前仅 MySQL 支持内置 SSH 直连,其他类型请使用本地端口转发:%s", formatConnSummary(config))
}
logger.Infof("获取数据库连接:%s 缓存Key=%s", formatConnSummary(config), shortKey)
a.mu.Lock()

View File

@@ -8,14 +8,26 @@ import (
func sanitizeSQLForPgLike(dbType string, query string) string {
switch strings.ToLower(strings.TrimSpace(dbType)) {
case "postgres", "kingbase":
return fixBrokenDoubleDoubleQuotedIdent(query)
// 有些情况下会出现多层重复引用(例如 """"schema"""" 或 ""schema"""),单次修复不一定收敛。
// 这里做有限次数的迭代,直到输出不再变化。
out := query
for i := 0; i < 3; i++ {
fixed := fixBrokenDoubleDoubleQuotedIdent(out)
if fixed == out {
break
}
out = fixed
}
return out
default:
return query
}
}
// fixBrokenDoubleDoubleQuotedIdent fixes accidental identifiers like:
// SELECT * FROM ""schema"".""table""
//
// SELECT * FROM ""schema"".""table""
//
// which can be produced when a quoted identifier gets wrapped by quotes again.
//
// It is intentionally conservative:
@@ -124,20 +136,17 @@ func fixBrokenDoubleDoubleQuotedIdent(query string) string {
}
}
// Fix: ""ident"" -> "ident" (only when it looks like a plain identifier)
if ch == '"' && next == '"' {
prevIsQuote := i > 0 && query[i-1] == '"'
nextIsQuote := i+2 < len(query) && query[i+2] == '"'
if !prevIsQuote && !nextIsQuote {
if ch == '"' {
// Fix: ""ident"" -> "ident" (only when it looks like a plain identifier)
// Also handle variants like ""ident""" / """"ident"""" (extra quotes at either side).
if next == '"' {
if replacement, advance, ok := tryFixDoubleDoubleQuotedIdent(query, i); ok {
b.WriteString(replacement)
i = advance - 1
continue
}
}
}
if ch == '"' {
b.WriteByte(ch)
inDoubleIdent = true
continue
@@ -150,7 +159,8 @@ func fixBrokenDoubleDoubleQuotedIdent(query string) string {
}
func tryFixDoubleDoubleQuotedIdent(query string, start int) (replacement string, advance int, ok bool) {
// start points at the first quote of `""...""`
// start points at the first quote of a broken identifier, usually like:
// ""ident"" / ""ident""" / """"ident""""
if start < 0 || start+1 >= len(query) {
return "", 0, false
}
@@ -160,24 +170,31 @@ func tryFixDoubleDoubleQuotedIdent(query string, start int) (replacement string,
if start > 0 && query[start-1] == '"' {
return "", 0, false
}
if start+2 < len(query) && query[start+2] == '"' {
runLen := 0
for start+runLen < len(query) && query[start+runLen] == '"' {
runLen++
}
if runLen < 2 || runLen%2 == 1 {
// Odd run (e.g. """...) can be a valid quoted identifier with escaped quotes.
return "", 0, false
}
contentStart := start + 2
contentStart := start + runLen
j := contentStart
for j+1 < len(query) {
if query[j] == '"' && query[j+1] == '"' {
// ensure closing pair is not part of a triple quote
if j+2 < len(query) && query[j+2] == '"' {
j++
continue
for j < len(query) {
if query[j] == '"' {
endRunLen := 0
for j+endRunLen < len(query) && query[j+endRunLen] == '"' {
endRunLen++
}
content := strings.TrimSpace(query[contentStart:j])
if looksLikeIdentifierContent(content) {
return `"` + content + `"`, j + 2, true
if endRunLen >= 2 {
content := strings.TrimSpace(query[contentStart:j])
if looksLikeIdentifierContent(content) {
return `"` + content + `"`, j + endRunLen, true
}
return "", 0, false
}
return "", 0, false
}
// Fast abort: identifier-like content should not span lines.
if query[j] == '\n' || query[j] == '\r' {

View File

@@ -11,6 +11,24 @@ func TestSanitizeSQLForPgLike_FixesBrokenDoubleDoubleQuotes(t *testing.T) {
}
}
func TestSanitizeSQLForPgLike_FixesBrokenDoubleDoubleQuotes_WithExtraQuotes(t *testing.T) {
in := `SELECT * FROM ""ldf_server""".""t_user"" LIMIT 1`
out := sanitizeSQLForPgLike("kingbase", in)
want := `SELECT * FROM "ldf_server"."t_user" LIMIT 1`
if out != want {
t.Fatalf("unexpected sanitize output:\nIN: %s\nOUT: %s\nWANT: %s", in, out, want)
}
}
func TestSanitizeSQLForPgLike_FixesBrokenDoubleDoubleQuotes_WithQuadQuotes(t *testing.T) {
in := `SELECT * FROM """"ldf_server"""".""t_user"" LIMIT 1`
out := sanitizeSQLForPgLike("kingbase", in)
want := `SELECT * FROM "ldf_server"."t_user" LIMIT 1`
if out != want {
t.Fatalf("unexpected sanitize output:\nIN: %s\nOUT: %s\nWANT: %s", in, out, want)
}
}
func TestSanitizeSQLForPgLike_DoesNotTouchEscapedQuotesInsideIdentifier(t *testing.T) {
in := `SELECT "a""b" FROM "t""x"`
out := sanitizeSQLForPgLike("postgres", in)

View File

@@ -82,33 +82,7 @@ func (c *CustomDB) Query(query string) ([]map[string]interface{}, []string, erro
return nil, nil, err
}
defer rows.Close()
columns, err := rows.Columns()
if err != nil {
return nil, nil, err
}
var resultData []map[string]interface{}
for rows.Next() {
values := make([]interface{}, len(columns))
valuePtrs := make([]interface{}, len(columns))
for i := range columns {
valuePtrs[i] = &values[i]
}
if err := rows.Scan(valuePtrs...); err != nil {
continue
}
entry := make(map[string]interface{})
for i, col := range columns {
entry[col] = normalizeQueryValue(values[i])
}
resultData = append(resultData, entry)
}
return resultData, columns, nil
return scanRows(rows)
}
func (c *CustomDB) ExecContext(ctx context.Context, query string) (int64, error) {

View File

@@ -11,6 +11,7 @@ import (
"time"
"GoNavi-Wails/internal/connection"
"GoNavi-Wails/internal/logger"
"GoNavi-Wails/internal/ssh"
"GoNavi-Wails/internal/utils"
@@ -20,6 +21,7 @@ import (
type DamengDB struct {
conn *sql.DB
pingTimeout time.Duration
forwarder *ssh.LocalForwarder // Store SSH tunnel forwarder
}
func (d *DamengDB) getDSN(config connection.ConnectionConfig) string {
@@ -27,16 +29,6 @@ func (d *DamengDB) getDSN(config connection.ConnectionConfig) string {
// or dm://user:password@host:port
address := net.JoinHostPort(config.Host, strconv.Itoa(config.Port))
if config.UseSSH {
// SSH logic similar to others, assumes port forwarding
_, err := ssh.RegisterSSHNetwork(config.SSH)
if err == nil {
// DM driver likely uses standard net.Dial, so we might need a local listener
// or assume port forwarding is handled externally or implicitly via "tcp" override if driver allows.
// Similar to Oracle, we skip complex custom dialer injection for now.
}
}
escapedPassword := url.PathEscape(config.Password)
q := url.Values{}
if config.Database != "" {
@@ -56,7 +48,42 @@ func (d *DamengDB) getDSN(config connection.ConnectionConfig) string {
}
func (d *DamengDB) Connect(config connection.ConnectionConfig) error {
dsn := d.getDSN(config)
var dsn string
var err error
if config.UseSSH {
// Create SSH tunnel with local port forwarding
logger.Infof("达梦数据库使用 SSH 连接:地址=%s:%d 用户=%s", config.Host, config.Port, config.User)
forwarder, err := ssh.GetOrCreateLocalForwarder(config.SSH, config.Host, config.Port)
if err != nil {
return fmt.Errorf("创建 SSH 隧道失败:%w", err)
}
d.forwarder = forwarder
// Parse local address
host, portStr, err := net.SplitHostPort(forwarder.LocalAddr)
if err != nil {
return fmt.Errorf("解析本地转发地址失败:%w", err)
}
port, err := strconv.Atoi(portStr)
if err != nil {
return fmt.Errorf("解析本地端口失败:%w", err)
}
// Create a modified config pointing to local forwarder
localConfig := config
localConfig.Host = host
localConfig.Port = port
localConfig.UseSSH = false
dsn = d.getDSN(localConfig)
logger.Infof("达梦数据库通过本地端口转发连接:%s -> %s:%d", forwarder.LocalAddr, config.Host, config.Port)
} else {
dsn = d.getDSN(config)
}
db, err := sql.Open("dm", dsn)
if err != nil {
return fmt.Errorf("打开数据库连接失败:%w", err)
@@ -70,6 +97,15 @@ func (d *DamengDB) Connect(config connection.ConnectionConfig) error {
}
func (d *DamengDB) Close() error {
// Close SSH forwarder first if exists
if d.forwarder != nil {
if err := d.forwarder.Close(); err != nil {
logger.Warnf("关闭达梦数据库 SSH 端口转发失败:%v", err)
}
d.forwarder = nil
}
// Then close database connection
if d.conn != nil {
return d.conn.Close()
}
@@ -113,33 +149,7 @@ func (d *DamengDB) Query(query string) ([]map[string]interface{}, []string, erro
return nil, nil, err
}
defer rows.Close()
columns, err := rows.Columns()
if err != nil {
return nil, nil, err
}
var resultData []map[string]interface{}
for rows.Next() {
values := make([]interface{}, len(columns))
valuePtrs := make([]interface{}, len(columns))
for i := range columns {
valuePtrs[i] = &values[i]
}
if err := rows.Scan(valuePtrs...); err != nil {
continue
}
entry := make(map[string]interface{})
for i, col := range columns {
entry[col] = normalizeQueryValue(values[i])
}
resultData = append(resultData, entry)
}
return resultData, columns, nil
return scanRows(rows)
}
func (d *DamengDB) ExecContext(ctx context.Context, query string) (int64, error) {

View File

@@ -4,10 +4,13 @@ import (
"context"
"database/sql"
"fmt"
"net"
"strconv"
"strings"
"time"
"GoNavi-Wails/internal/connection"
"GoNavi-Wails/internal/logger"
"GoNavi-Wails/internal/ssh"
"GoNavi-Wails/internal/utils"
@@ -17,6 +20,7 @@ import (
type KingbaseDB struct {
conn *sql.DB
pingTimeout time.Duration
forwarder *ssh.LocalForwarder // Store SSH tunnel forwarder
}
func quoteConnValue(v string) string {
@@ -58,20 +62,6 @@ func (k *KingbaseDB) getDSN(config connection.ConnectionConfig) string {
address := config.Host
port := config.Port
if config.UseSSH {
netName, err := ssh.RegisterSSHNetwork(config.SSH)
if err == nil {
// Kingbase/Postgres lib/pq allows custom dialer via "host" if using unix socket,
// but for custom network it's harder.
// Ideally we use a local forwarder.
// For now, we assume standard TCP or handle SSH externally.
// If we implement the net.Dial override for "kingbase" driver (which might use lib/pq internally),
// we might need to check if it supports "cloudsql" style or similar custom dialers.
// Similar to others, skipping SSH deep integration here for now.
_ = netName
}
}
// Construct DSN
dsn := fmt.Sprintf("host=%s port=%d user=%s password=%s dbname=%s sslmode=disable connect_timeout=%d",
quoteConnValue(address),
@@ -86,7 +76,42 @@ func (k *KingbaseDB) getDSN(config connection.ConnectionConfig) string {
}
func (k *KingbaseDB) Connect(config connection.ConnectionConfig) error {
dsn := k.getDSN(config)
var dsn string
var err error
if config.UseSSH {
// Create SSH tunnel with local port forwarding
logger.Infof("人大金仓使用 SSH 连接:地址=%s:%d 用户=%s", config.Host, config.Port, config.User)
forwarder, err := ssh.GetOrCreateLocalForwarder(config.SSH, config.Host, config.Port)
if err != nil {
return fmt.Errorf("创建 SSH 隧道失败:%w", err)
}
k.forwarder = forwarder
// Parse local address
host, portStr, err := net.SplitHostPort(forwarder.LocalAddr)
if err != nil {
return fmt.Errorf("解析本地转发地址失败:%w", err)
}
port, err := strconv.Atoi(portStr)
if err != nil {
return fmt.Errorf("解析本地端口失败:%w", err)
}
// Create a modified config pointing to local forwarder
localConfig := config
localConfig.Host = host
localConfig.Port = port
localConfig.UseSSH = false
dsn = k.getDSN(localConfig)
logger.Infof("人大金仓通过本地端口转发连接:%s -> %s:%d", forwarder.LocalAddr, config.Host, config.Port)
} else {
dsn = k.getDSN(config)
}
// Open using "kingbase" driver
db, err := sql.Open("kingbase", dsn)
if err != nil {
@@ -101,6 +126,15 @@ func (k *KingbaseDB) Connect(config connection.ConnectionConfig) error {
}
func (k *KingbaseDB) Close() error {
// Close SSH forwarder first if exists
if k.forwarder != nil {
if err := k.forwarder.Close(); err != nil {
logger.Warnf("关闭人大金仓 SSH 端口转发失败:%v", err)
}
k.forwarder = nil
}
// Then close database connection
if k.conn != nil {
return k.conn.Close()
}
@@ -144,33 +178,7 @@ func (k *KingbaseDB) Query(query string) ([]map[string]interface{}, []string, er
return nil, nil, err
}
defer rows.Close()
columns, err := rows.Columns()
if err != nil {
return nil, nil, err
}
var resultData []map[string]interface{}
for rows.Next() {
values := make([]interface{}, len(columns))
valuePtrs := make([]interface{}, len(columns))
for i := range columns {
valuePtrs[i] = &values[i]
}
if err := rows.Scan(valuePtrs...); err != nil {
continue
}
entry := make(map[string]interface{})
for i, col := range columns {
entry[col] = normalizeQueryValue(values[i])
}
resultData = append(resultData, entry)
}
return resultData, columns, nil
return scanRows(rows)
}
func (k *KingbaseDB) ExecContext(ctx context.Context, query string) (int64, error) {
@@ -249,15 +257,84 @@ func (k *KingbaseDB) GetCreateStatement(dbName, tableName string) (string, error
}
func (k *KingbaseDB) GetColumns(dbName, tableName string) ([]connection.ColumnDefinition, error) {
schema := "public"
if dbName != "" {
schema = dbName
// 解析 schema.table 格式
schema := strings.TrimSpace(dbName)
table := strings.TrimSpace(tableName)
// 如果 tableName 包含 schema (格式: schema.table)
if parts := strings.SplitN(table, ".", 2); len(parts) == 2 {
parsedSchema := strings.TrimSpace(parts[0])
parsedTable := strings.TrimSpace(parts[1])
if parsedSchema != "" && parsedTable != "" {
schema = parsedSchema
table = parsedTable
}
}
query := fmt.Sprintf(`SELECT column_name, data_type, is_nullable, column_default
FROM information_schema.columns
WHERE table_schema = '%s' AND table_name = '%s'
ORDER BY ordinal_position`, schema, tableName)
// 如果仍然没有 schema,使用 current_schema()
// 这样可以自动匹配当前连接的 search_path
if schema == "" {
return k.getColumnsWithCurrentSchema(table)
}
if table == "" {
return nil, fmt.Errorf("table name required")
}
// 转义函数:处理单引号,移除双引号
esc := func(s string) string {
// 移除前后的双引号(如果存在)
s = strings.Trim(s, "\"")
// 转义单引号
return strings.ReplaceAll(s, "'", "''")
}
query := fmt.Sprintf(`SELECT column_name, data_type, is_nullable, column_default
FROM information_schema.columns
WHERE table_schema = '%s' AND table_name = '%s'
ORDER BY ordinal_position`, esc(schema), esc(table))
data, _, err := k.Query(query)
if err != nil {
return nil, err
}
var columns []connection.ColumnDefinition
for _, row := range data {
col := connection.ColumnDefinition{
Name: fmt.Sprintf("%v", row["column_name"]),
Type: fmt.Sprintf("%v", row["data_type"]),
Nullable: fmt.Sprintf("%v", row["is_nullable"]),
}
if row["column_default"] != nil {
def := fmt.Sprintf("%v", row["column_default"])
col.Default = &def
}
columns = append(columns, col)
}
return columns, nil
}
// getColumnsWithCurrentSchema 使用 current_schema() 查询当前schema的表
func (k *KingbaseDB) getColumnsWithCurrentSchema(tableName string) ([]connection.ColumnDefinition, error) {
table := strings.TrimSpace(tableName)
if table == "" {
return nil, fmt.Errorf("table name required")
}
// 转义函数
esc := func(s string) string {
s = strings.Trim(s, "\"")
return strings.ReplaceAll(s, "'", "''")
}
// 使用 current_schema() 获取当前schema
query := fmt.Sprintf(`SELECT column_name, data_type, is_nullable, column_default
FROM information_schema.columns
WHERE table_schema = current_schema() AND table_name = '%s'
ORDER BY ordinal_position`, esc(table))
data, _, err := k.Query(query)
if err != nil {
@@ -283,32 +360,76 @@ func (k *KingbaseDB) GetColumns(dbName, tableName string) ([]connection.ColumnDe
}
func (k *KingbaseDB) GetIndexes(dbName, tableName string) ([]connection.IndexDefinition, error) {
// Postgres/Kingbase index query
query := fmt.Sprintf(`
SELECT
i.relname as index_name,
a.attname as column_name,
ix.indisunique as is_unique
FROM
pg_class t,
pg_class i,
pg_index ix,
pg_attribute a,
pg_namespace n
WHERE
t.oid = ix.indrelid
AND i.oid = ix.indexrelid
AND a.attrelid = t.oid
AND a.attnum = ANY(ix.indkey)
AND t.relkind = 'r'
AND t.relname = '%s'
AND n.oid = t.relnamespace
AND n.nspname = '%s'
`, tableName, "public") // Default to public if dbName (schema) not clear.
// 解析 schema.table 格式
schema := strings.TrimSpace(dbName)
table := strings.TrimSpace(tableName)
if dbName != "" {
// Update query to use dbName as schema
query = strings.Replace(query, "'public'", fmt.Sprintf("'%s'", dbName), 1)
// 如果 tableName 包含 schema (格式: schema.table)
if parts := strings.SplitN(table, ".", 2); len(parts) == 2 {
parsedSchema := strings.TrimSpace(parts[0])
parsedTable := strings.TrimSpace(parts[1])
if parsedSchema != "" && parsedTable != "" {
schema = parsedSchema
table = parsedTable
}
}
if table == "" {
return nil, fmt.Errorf("table name required")
}
// 转义函数:处理单引号,移除双引号
esc := func(s string) string {
s = strings.Trim(s, "\"")
return strings.ReplaceAll(s, "'", "''")
}
// 构建查询如果没有指定schema,使用current_schema()
var query string
if schema != "" {
query = fmt.Sprintf(`
SELECT
i.relname as index_name,
a.attname as column_name,
ix.indisunique as is_unique
FROM
pg_class t,
pg_class i,
pg_index ix,
pg_attribute a,
pg_namespace n
WHERE
t.oid = ix.indrelid
AND i.oid = ix.indexrelid
AND a.attrelid = t.oid
AND a.attnum = ANY(ix.indkey)
AND t.relkind = 'r'
AND t.relname = '%s'
AND n.oid = t.relnamespace
AND n.nspname = '%s'
`, esc(table), esc(schema))
} else {
query = fmt.Sprintf(`
SELECT
i.relname as index_name,
a.attname as column_name,
ix.indisunique as is_unique
FROM
pg_class t,
pg_class i,
pg_index ix,
pg_attribute a,
pg_namespace n
WHERE
t.oid = ix.indrelid
AND i.oid = ix.indexrelid
AND a.attrelid = t.oid
AND a.attnum = ANY(ix.indkey)
AND t.relkind = 'r'
AND t.relname = '%s'
AND n.oid = t.relnamespace
AND n.nspname = current_schema()
`, esc(table))
}
data, _, err := k.Query(query)
@@ -337,27 +458,67 @@ func (k *KingbaseDB) GetIndexes(dbName, tableName string) ([]connection.IndexDef
}
func (k *KingbaseDB) GetForeignKeys(dbName, tableName string) ([]connection.ForeignKeyDefinition, error) {
schema := "public"
if dbName != "" {
schema = dbName
// 解析 schema.table 格式
schema := strings.TrimSpace(dbName)
table := strings.TrimSpace(tableName)
// 如果 tableName 包含 schema (格式: schema.table)
if parts := strings.SplitN(table, ".", 2); len(parts) == 2 {
parsedSchema := strings.TrimSpace(parts[0])
parsedTable := strings.TrimSpace(parts[1])
if parsedSchema != "" && parsedTable != "" {
schema = parsedSchema
table = parsedTable
}
}
query := fmt.Sprintf(`
SELECT
tc.constraint_name,
kcu.column_name,
ccu.table_name AS foreign_table_name,
ccu.column_name AS foreign_column_name
FROM
information_schema.table_constraints AS tc
JOIN information_schema.key_column_usage AS kcu
ON tc.constraint_name = kcu.constraint_name
AND tc.table_schema = kcu.table_schema
JOIN information_schema.constraint_column_usage AS ccu
ON ccu.constraint_name = tc.constraint_name
AND ccu.table_schema = tc.table_schema
WHERE tc.constraint_type = 'FOREIGN KEY' AND tc.table_name='%s' AND tc.table_schema='%s'`,
tableName, schema)
if table == "" {
return nil, fmt.Errorf("table name required")
}
// 转义函数:处理单引号,移除双引号
esc := func(s string) string {
s = strings.Trim(s, "\"")
return strings.ReplaceAll(s, "'", "''")
}
// 构建查询如果没有指定schema,使用current_schema()
var query string
if schema != "" {
query = fmt.Sprintf(`
SELECT
tc.constraint_name,
kcu.column_name,
ccu.table_name AS foreign_table_name,
ccu.column_name AS foreign_column_name
FROM
information_schema.table_constraints AS tc
JOIN information_schema.key_column_usage AS kcu
ON tc.constraint_name = kcu.constraint_name
AND tc.table_schema = kcu.table_schema
JOIN information_schema.constraint_column_usage AS ccu
ON ccu.constraint_name = tc.constraint_name
AND ccu.table_schema = tc.table_schema
WHERE tc.constraint_type = 'FOREIGN KEY' AND tc.table_name='%s' AND tc.table_schema='%s'`,
esc(table), esc(schema))
} else {
query = fmt.Sprintf(`
SELECT
tc.constraint_name,
kcu.column_name,
ccu.table_name AS foreign_table_name,
ccu.column_name AS foreign_column_name
FROM
information_schema.table_constraints AS tc
JOIN information_schema.key_column_usage AS kcu
ON tc.constraint_name = kcu.constraint_name
AND tc.table_schema = kcu.table_schema
JOIN information_schema.constraint_column_usage AS ccu
ON ccu.constraint_name = tc.constraint_name
AND ccu.table_schema = tc.table_schema
WHERE tc.constraint_type = 'FOREIGN KEY' AND tc.table_name='%s' AND tc.table_schema=current_schema()`,
esc(table))
}
data, _, err := k.Query(query)
if err != nil {
@@ -379,9 +540,43 @@ func (k *KingbaseDB) GetForeignKeys(dbName, tableName string) ([]connection.Fore
}
func (k *KingbaseDB) GetTriggers(dbName, tableName string) ([]connection.TriggerDefinition, error) {
query := fmt.Sprintf(`SELECT trigger_name, action_timing, event_manipulation
FROM information_schema.triggers
WHERE event_object_table = '%s'`, tableName)
// 解析 schema.table 格式
schema := strings.TrimSpace(dbName)
table := strings.TrimSpace(tableName)
// 如果 tableName 包含 schema (格式: schema.table)
if parts := strings.SplitN(table, ".", 2); len(parts) == 2 {
parsedSchema := strings.TrimSpace(parts[0])
parsedTable := strings.TrimSpace(parts[1])
if parsedSchema != "" && parsedTable != "" {
schema = parsedSchema
table = parsedTable
}
}
if table == "" {
return nil, fmt.Errorf("table name required")
}
// 转义函数:处理单引号,移除双引号
esc := func(s string) string {
s = strings.Trim(s, "\"")
return strings.ReplaceAll(s, "'", "''")
}
// 构建查询如果指定了schema,也加上schema条件
var query string
if schema != "" {
query = fmt.Sprintf(`SELECT trigger_name, action_timing, event_manipulation
FROM information_schema.triggers
WHERE event_object_table = '%s' AND event_object_schema = '%s'`,
esc(table), esc(schema))
} else {
query = fmt.Sprintf(`SELECT trigger_name, action_timing, event_manipulation
FROM information_schema.triggers
WHERE event_object_table = '%s' AND event_object_schema = current_schema()`,
esc(table))
}
data, _, err := k.Query(query)
if err != nil {

View File

@@ -101,33 +101,7 @@ func (m *MySQLDB) Query(query string) ([]map[string]interface{}, []string, error
return nil, nil, err
}
defer rows.Close()
columns, err := rows.Columns()
if err != nil {
return nil, nil, err
}
var resultData []map[string]interface{}
for rows.Next() {
values := make([]interface{}, len(columns))
valuePtrs := make([]interface{}, len(columns))
for i := range columns {
valuePtrs[i] = &values[i]
}
if err := rows.Scan(valuePtrs...); err != nil {
continue
}
entry := make(map[string]interface{})
for i, col := range columns {
entry[col] = normalizeQueryValue(values[i])
}
resultData = append(resultData, entry)
}
return resultData, columns, nil
return scanRows(rows)
}
func (m *MySQLDB) ExecContext(ctx context.Context, query string) (int64, error) {

View File

@@ -11,6 +11,7 @@ import (
"time"
"GoNavi-Wails/internal/connection"
"GoNavi-Wails/internal/logger"
"GoNavi-Wails/internal/ssh"
"GoNavi-Wails/internal/utils"
@@ -20,6 +21,7 @@ import (
type OracleDB struct {
conn *sql.DB
pingTimeout time.Duration
forwarder *ssh.LocalForwarder // Store SSH tunnel forwarder
}
func (o *OracleDB) getDSN(config connection.ConnectionConfig) string {
@@ -29,28 +31,6 @@ func (o *OracleDB) getDSN(config connection.ConnectionConfig) string {
database = config.User // Default to user service/schema if empty?
}
if config.UseSSH {
_, err := ssh.RegisterSSHNetwork(config.SSH)
if err == nil {
// Oracle driver might not support custom dialer via DSN easily without extra config
// But go-ora v2 supports some advanced options.
// For simplicity, we assume standard TCP or we might need a workaround for SSH.
// go-ora v2 is pure Go, so we can potentially use a custom dialer if we manually open.
// But for now, let's just use the address.
// SSH tunneling via net.Dialer override is complex in sql.Open("oracle", ...).
// We might need to forward a local port if using SSH.
// Since ssh.RegisterSSHNetwork creates a custom network "ssh-via-...",
// we need to see if go-ora supports custom networks.
// Checking go-ora docs (simulated): It supports "unix" and "tcp".
// We might need to map the custom network to a local proxy.
// For now, we will assume direct connection or handle SSH separately later.
// We'll leave the protocol implementation as is in MySQL for now, hoping go-ora uses standard net.Dial.
// Note: go-ora connection string: oracle://user:pass@host:port/service
// It parses host/port. It doesn't easily take a custom "network" parameter in URL.
// We will proceed with standard TCP string.
}
}
u := &url.URL{
Scheme: "oracle",
Host: net.JoinHostPort(config.Host, strconv.Itoa(config.Port)),
@@ -62,7 +42,42 @@ func (o *OracleDB) getDSN(config connection.ConnectionConfig) string {
}
func (o *OracleDB) Connect(config connection.ConnectionConfig) error {
dsn := o.getDSN(config)
var dsn string
var err error
if config.UseSSH {
// Create SSH tunnel with local port forwarding
logger.Infof("Oracle 使用 SSH 连接:地址=%s:%d 用户=%s", config.Host, config.Port, config.User)
forwarder, err := ssh.GetOrCreateLocalForwarder(config.SSH, config.Host, config.Port)
if err != nil {
return fmt.Errorf("创建 SSH 隧道失败:%w", err)
}
o.forwarder = forwarder
// Parse local address
host, portStr, err := net.SplitHostPort(forwarder.LocalAddr)
if err != nil {
return fmt.Errorf("解析本地转发地址失败:%w", err)
}
port, err := strconv.Atoi(portStr)
if err != nil {
return fmt.Errorf("解析本地端口失败:%w", err)
}
// Create a modified config pointing to local forwarder
localConfig := config
localConfig.Host = host
localConfig.Port = port
localConfig.UseSSH = false
dsn = o.getDSN(localConfig)
logger.Infof("Oracle 通过本地端口转发连接:%s -> %s:%d", forwarder.LocalAddr, config.Host, config.Port)
} else {
dsn = o.getDSN(config)
}
db, err := sql.Open("oracle", dsn)
if err != nil {
return fmt.Errorf("打开数据库连接失败:%w", err)
@@ -76,6 +91,15 @@ func (o *OracleDB) Connect(config connection.ConnectionConfig) error {
}
func (o *OracleDB) Close() error {
// Close SSH forwarder first if exists
if o.forwarder != nil {
if err := o.forwarder.Close(); err != nil {
logger.Warnf("关闭 Oracle SSH 端口转发失败:%v", err)
}
o.forwarder = nil
}
// Then close database connection
if o.conn != nil {
return o.conn.Close()
}
@@ -119,33 +143,7 @@ func (o *OracleDB) Query(query string) ([]map[string]interface{}, []string, erro
return nil, nil, err
}
defer rows.Close()
columns, err := rows.Columns()
if err != nil {
return nil, nil, err
}
var resultData []map[string]interface{}
for rows.Next() {
values := make([]interface{}, len(columns))
valuePtrs := make([]interface{}, len(columns))
for i := range columns {
valuePtrs[i] = &values[i]
}
if err := rows.Scan(valuePtrs...); err != nil {
continue
}
entry := make(map[string]interface{})
for i, col := range columns {
entry[col] = normalizeQueryValue(values[i])
}
resultData = append(resultData, entry)
}
return resultData, columns, nil
return scanRows(rows)
}
func (o *OracleDB) ExecContext(ctx context.Context, query string) (int64, error) {

View File

@@ -11,16 +11,21 @@ import (
"time"
"GoNavi-Wails/internal/connection"
"GoNavi-Wails/internal/logger"
"GoNavi-Wails/internal/ssh"
"GoNavi-Wails/internal/utils"
_ "github.com/lib/pq"
)
type PostgresDB struct {
conn *sql.DB
pingTimeout time.Duration
forwarder *ssh.LocalForwarder // Store SSH tunnel forwarder
}
func (p *PostgresDB) getDSN(config connection.ConnectionConfig) string {
// postgres://user:password@host:port/dbname?sslmode=disable
dbname := config.Database
@@ -43,7 +48,42 @@ func (p *PostgresDB) getDSN(config connection.ConnectionConfig) string {
}
func (p *PostgresDB) Connect(config connection.ConnectionConfig) error {
dsn := p.getDSN(config)
var dsn string
var err error
if config.UseSSH {
// Create SSH tunnel with local port forwarding
logger.Infof("PostgreSQL 使用 SSH 连接:地址=%s:%d 用户=%s", config.Host, config.Port, config.User)
forwarder, err := ssh.GetOrCreateLocalForwarder(config.SSH, config.Host, config.Port)
if err != nil {
return fmt.Errorf("创建 SSH 隧道失败:%w", err)
}
p.forwarder = forwarder
// Parse local address
host, portStr, err := net.SplitHostPort(forwarder.LocalAddr)
if err != nil {
return fmt.Errorf("解析本地转发地址失败:%w", err)
}
port, err := strconv.Atoi(portStr)
if err != nil {
return fmt.Errorf("解析本地端口失败:%w", err)
}
// Create a modified config pointing to local forwarder
localConfig := config
localConfig.Host = host
localConfig.Port = port
localConfig.UseSSH = false // Disable SSH flag for DSN generation
dsn = p.getDSN(localConfig)
logger.Infof("PostgreSQL 通过本地端口转发连接:%s -> %s:%d", forwarder.LocalAddr, config.Host, config.Port)
} else {
dsn = p.getDSN(config)
}
db, err := sql.Open("postgres", dsn)
if err != nil {
return fmt.Errorf("打开数据库连接失败:%w", err)
@@ -58,7 +98,17 @@ func (p *PostgresDB) Connect(config connection.ConnectionConfig) error {
return nil
}
func (p *PostgresDB) Close() error {
// Close SSH forwarder first if exists
if p.forwarder != nil {
if err := p.forwarder.Close(); err != nil {
logger.Warnf("关闭 PostgreSQL SSH 端口转发失败:%v", err)
}
p.forwarder = nil
}
// Then close database connection
if p.conn != nil {
return p.conn.Close()
}
@@ -102,33 +152,7 @@ func (p *PostgresDB) Query(query string) ([]map[string]interface{}, []string, er
return nil, nil, err
}
defer rows.Close()
columns, err := rows.Columns()
if err != nil {
return nil, nil, err
}
var resultData []map[string]interface{}
for rows.Next() {
values := make([]interface{}, len(columns))
valuePtrs := make([]interface{}, len(columns))
for i := range columns {
valuePtrs[i] = &values[i]
}
if err := rows.Scan(valuePtrs...); err != nil {
continue
}
entry := make(map[string]interface{})
for i, col := range columns {
entry[col] = normalizeQueryValue(values[i])
}
resultData = append(resultData, entry)
}
return resultData, columns, nil
return scanRows(rows)
}
func (p *PostgresDB) ExecContext(ctx context.Context, query string) (int64, error) {

View File

@@ -2,6 +2,8 @@ package db
import (
"encoding/hex"
"fmt"
"strings"
"unicode"
"unicode/utf8"
)
@@ -9,13 +11,17 @@ import (
// normalizeQueryValue normalizes driver-returned values for UI/JSON transport.
// 当前主要处理 []byte如果是可读文本则转为 string否则转为十六进制字符串避免前端出现“空白值”。
func normalizeQueryValue(v interface{}) interface{} {
return normalizeQueryValueWithDBType(v, "")
}
func normalizeQueryValueWithDBType(v interface{}, databaseTypeName string) interface{} {
if b, ok := v.([]byte); ok {
return bytesToReadableString(b)
return bytesToDisplayValue(b, databaseTypeName)
}
return v
}
func bytesToReadableString(b []byte) interface{} {
func bytesToDisplayValue(b []byte, databaseTypeName string) interface{} {
if b == nil {
return nil
}
@@ -23,6 +29,18 @@ func bytesToReadableString(b []byte) interface{} {
return ""
}
dbType := strings.ToUpper(strings.TrimSpace(databaseTypeName))
if isBitLikeDBType(dbType) {
if u, ok := bytesToUint64(b); ok {
// JS number precision is limited; keep large bitmasks as string.
const maxSafeInteger = 9007199254740991 // 2^53 - 1
if u <= maxSafeInteger {
return int64(u)
}
return fmt.Sprintf("%d", u)
}
}
if utf8.Valid(b) {
s := string(b)
if isMostlyPrintable(s) {
@@ -30,9 +48,47 @@ func bytesToReadableString(b []byte) interface{} {
}
}
// Fallback: some drivers return BIT(1) as []byte{0} / []byte{1} without type info.
if dbType == "" && len(b) == 1 && (b[0] == 0 || b[0] == 1) {
return int64(b[0])
}
return bytesToReadableString(b)
}
func bytesToReadableString(b []byte) interface{} {
if b == nil {
return nil
}
if len(b) == 0 {
return ""
}
return "0x" + hex.EncodeToString(b)
}
func isBitLikeDBType(typeName string) bool {
if typeName == "" {
return false
}
switch typeName {
case "BIT", "VARBIT":
return true
default:
}
return strings.HasPrefix(typeName, "BIT")
}
func bytesToUint64(b []byte) (uint64, bool) {
if len(b) == 0 || len(b) > 8 {
return 0, false
}
var u uint64
for _, v := range b {
u = (u << 8) | uint64(v)
}
return u, true
}
func isMostlyPrintable(s string) bool {
if s == "" {
return true

View File

@@ -0,0 +1,44 @@
package db
import "testing"
func TestNormalizeQueryValueWithDBType_BitBytes(t *testing.T) {
v := normalizeQueryValueWithDBType([]byte{0x00}, "BIT")
if v != int64(0) {
t.Fatalf("BIT 0x00 期望为 0实际=%v(%T)", v, v)
}
v = normalizeQueryValueWithDBType([]byte{0x01}, "bit")
if v != int64(1) {
t.Fatalf("BIT 0x01 期望为 1实际=%v(%T)", v, v)
}
v = normalizeQueryValueWithDBType([]byte{0x01, 0x02}, "BIT VARYING")
if v != int64(258) {
t.Fatalf("BIT 0x0102 期望为 258实际=%v(%T)", v, v)
}
}
func TestNormalizeQueryValueWithDBType_BitLargeAsString(t *testing.T) {
v := normalizeQueryValueWithDBType([]byte{0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff}, "BIT")
if s, ok := v.(string); !ok || s != "18446744073709551615" {
t.Fatalf("BIT 0xffffffffffffffff 期望为 string(18446744073709551615),实际=%v(%T)", v, v)
}
}
func TestNormalizeQueryValueWithDBType_ByteFallbacks(t *testing.T) {
v := normalizeQueryValueWithDBType([]byte("abc"), "")
if v != "abc" {
t.Fatalf("文本 []byte 期望返回 string实际=%v(%T)", v, v)
}
v = normalizeQueryValueWithDBType([]byte{0x00}, "")
if v != int64(0) {
t.Fatalf("未知类型 0x00 期望返回 0实际=%v(%T)", v, v)
}
v = normalizeQueryValueWithDBType([]byte{0xff}, "")
if v != "0xff" {
t.Fatalf("未知类型 0xff 期望返回 0xff实际=%v(%T)", v, v)
}
}

View File

@@ -1,6 +1,8 @@
package db
import "database/sql"
import (
"database/sql"
)
func scanRows(rows *sql.Rows) ([]map[string]interface{}, []string, error) {
columns, err := rows.Columns()
@@ -8,6 +10,11 @@ func scanRows(rows *sql.Rows) ([]map[string]interface{}, []string, error) {
return nil, nil, err
}
colTypes, err := rows.ColumnTypes()
if err != nil || len(colTypes) != len(columns) {
colTypes = nil
}
resultData := make([]map[string]interface{}, 0)
for rows.Next() {
@@ -23,7 +30,11 @@ func scanRows(rows *sql.Rows) ([]map[string]interface{}, []string, error) {
entry := make(map[string]interface{}, len(columns))
for i, col := range columns {
entry[col] = normalizeQueryValue(values[i])
dbTypeName := ""
if colTypes != nil && i < len(colTypes) && colTypes[i] != nil {
dbTypeName = colTypes[i].DatabaseTypeName()
}
entry[col] = normalizeQueryValueWithDBType(values[i], dbTypeName)
}
resultData = append(resultData, entry)
}
@@ -33,4 +44,3 @@ func scanRows(rows *sql.Rows) ([]map[string]interface{}, []string, error) {
}
return resultData, columns, nil
}

View File

@@ -78,33 +78,7 @@ func (s *SQLiteDB) Query(query string) ([]map[string]interface{}, []string, erro
return nil, nil, err
}
defer rows.Close()
columns, err := rows.Columns()
if err != nil {
return nil, nil, err
}
var resultData []map[string]interface{}
for rows.Next() {
values := make([]interface{}, len(columns))
valuePtrs := make([]interface{}, len(columns))
for i := range columns {
valuePtrs[i] = &values[i]
}
if err := rows.Scan(valuePtrs...); err != nil {
continue
}
entry := make(map[string]interface{})
for i, col := range columns {
entry[col] = normalizeQueryValue(values[i])
}
resultData = append(resultData, entry)
}
return resultData, columns, nil
return scanRows(rows)
}
func (s *SQLiteDB) ExecContext(ctx context.Context, query string) (int64, error) {

View File

@@ -3,8 +3,10 @@ package ssh
import (
"context"
"fmt"
"io"
"net"
"os"
"sync"
"time"
"GoNavi-Wails/internal/connection"
@@ -110,3 +112,264 @@ func RegisterSSHNetwork(sshConfig connection.SSHConfig) (string, error) {
return netName, nil
}
// sshClientCache stores SSH clients to avoid creating multiple connections
var (
sshClientCache = make(map[string]*ssh.Client)
sshClientCacheMu sync.RWMutex
localForwarders = make(map[string]*LocalForwarder)
forwarderMu sync.RWMutex
)
// LocalForwarder represents a local port forwarder through SSH
type LocalForwarder struct {
LocalAddr string
RemoteAddr string
SSHClient *ssh.Client
listener net.Listener
closeChan chan struct{}
closeOnce sync.Once // 防止重复关闭
closed bool // 关闭状态标记
closedMu sync.RWMutex
}
// NewLocalForwarder creates a new local port forwarder
// It listens on a random local port and forwards all connections through SSH tunnel
func NewLocalForwarder(sshConfig connection.SSHConfig, remoteHost string, remotePort int) (*LocalForwarder, error) {
client, err := GetOrCreateSSHClient(sshConfig)
if err != nil {
return nil, fmt.Errorf("建立 SSH 连接失败:%w", err)
}
// Listen on localhost with a random port
listener, err := net.Listen("tcp", "127.0.0.1:0")
if err != nil {
return nil, fmt.Errorf("创建本地监听器失败:%w", err)
}
localAddr := listener.Addr().String()
remoteAddr := fmt.Sprintf("%s:%d", remoteHost, remotePort)
forwarder := &LocalForwarder{
LocalAddr: localAddr,
RemoteAddr: remoteAddr,
SSHClient: client,
listener: listener,
closeChan: make(chan struct{}),
}
// Start forwarding in background
go forwarder.forward()
logger.Infof("已创建 SSH 端口转发:本地 %s -> 远程 %s", localAddr, remoteAddr)
return forwarder, nil
}
// forward handles the port forwarding
func (f *LocalForwarder) forward() {
for {
localConn, err := f.listener.Accept()
if err != nil {
// Check if we're shutting down
select {
case <-f.closeChan:
return
default:
logger.Warnf("接受本地连接失败:%v", err)
// listener可能已关闭,退出循环
return
}
}
go f.handleConnection(localConn)
}
}
// handleConnection handles a single connection
func (f *LocalForwarder) handleConnection(localConn net.Conn) {
defer localConn.Close()
// Connect to remote through SSH with timeout
remoteConn, err := f.SSHClient.Dial("tcp", f.RemoteAddr)
if err != nil {
logger.Warnf("通过 SSH 连接到远程 %s 失败:%v", f.RemoteAddr, err)
return
}
defer remoteConn.Close()
// Bidirectional copy with error channel
errc := make(chan error, 2)
// Copy from local to remote
go func() {
_, err := io.Copy(remoteConn, localConn)
if err != nil {
logger.Warnf("本地->远程数据复制错误:%v", err)
}
errc <- err
}()
// Copy from remote to local
go func() {
_, err := io.Copy(localConn, remoteConn)
if err != nil {
logger.Warnf("远程->本地数据复制错误:%v", err)
}
errc <- err
}()
// Wait for BOTH goroutines to complete
<-errc
<-errc
}
// Close closes the forwarder (thread-safe, can be called multiple times)
func (f *LocalForwarder) Close() error {
var err error
f.closeOnce.Do(func() {
f.closedMu.Lock()
f.closed = true
f.closedMu.Unlock()
close(f.closeChan)
err = f.listener.Close()
if err != nil {
logger.Warnf("关闭端口转发监听器失败:%v", err)
}
})
return err
}
// IsClosed returns whether the forwarder is closed
func (f *LocalForwarder) IsClosed() bool {
f.closedMu.RLock()
defer f.closedMu.RUnlock()
return f.closed
}
// GetOrCreateLocalForwarder returns a cached forwarder or creates a new one
func GetOrCreateLocalForwarder(sshConfig connection.SSHConfig, remoteHost string, remotePort int) (*LocalForwarder, error) {
key := fmt.Sprintf("%s:%d:%s->%s:%d",
sshConfig.Host, sshConfig.Port, sshConfig.User,
remoteHost, remotePort)
forwarderMu.RLock()
forwarder, exists := localForwarders[key]
forwarderMu.RUnlock()
// Check if exists and is still valid
if exists && forwarder != nil && !forwarder.IsClosed() {
logger.Infof("复用已有端口转发:%s", key)
return forwarder, nil
}
// Remove stale forwarder from cache
if exists {
forwarderMu.Lock()
delete(localForwarders, key)
forwarderMu.Unlock()
}
forwarder, err := NewLocalForwarder(sshConfig, remoteHost, remotePort)
if err != nil {
return nil, err
}
forwarderMu.Lock()
localForwarders[key] = forwarder
forwarderMu.Unlock()
return forwarder, nil
}
// CloseAllForwarders closes all local forwarders
func CloseAllForwarders() {
forwarderMu.Lock()
defer forwarderMu.Unlock()
for key, forwarder := range localForwarders {
if forwarder != nil {
_ = forwarder.Close()
logger.Infof("已关闭端口转发:%s", key)
}
}
localForwarders = make(map[string]*LocalForwarder)
}
// getSSHClientCacheKey generates a unique cache key for SSH config
func getSSHClientCacheKey(config connection.SSHConfig) string {
return fmt.Sprintf("%s:%d:%s", config.Host, config.Port, config.User)
}
// GetOrCreateSSHClient returns a cached SSH client or creates a new one
func GetOrCreateSSHClient(config connection.SSHConfig) (*ssh.Client, error) {
key := getSSHClientCacheKey(config)
sshClientCacheMu.RLock()
client, exists := sshClientCache[key]
sshClientCacheMu.RUnlock()
if exists && client != nil {
// Test if connection is still alive by creating a test session
session, err := client.NewSession()
if err == nil {
session.Close()
logger.Infof("复用已有 SSH 连接:%s", key)
return client, nil
}
// Connection is dead, remove from cache
logger.Warnf("SSH 连接已断开,重新建立:%s (错误: %v)", key, err)
sshClientCacheMu.Lock()
delete(sshClientCache, key)
sshClientCacheMu.Unlock()
// Try to close the dead client
_ = client.Close()
}
// Create new SSH client
client, err := connectSSH(config)
if err != nil {
return nil, err
}
// Cache the client
sshClientCacheMu.Lock()
sshClientCache[key] = client
sshClientCacheMu.Unlock()
logger.Infof("已缓存 SSH 连接:%s", key)
return client, nil
}
// DialThroughSSH creates a connection through SSH tunnel
// This is a generic dialer that can be used by any database driver
func DialThroughSSH(config connection.SSHConfig, network, address string) (net.Conn, error) {
client, err := GetOrCreateSSHClient(config)
if err != nil {
return nil, fmt.Errorf("建立 SSH 连接失败:%w", err)
}
conn, err := client.Dial(network, address)
if err != nil {
return nil, fmt.Errorf("通过 SSH 隧道连接到 %s 失败:%w", address, err)
}
logger.Infof("已通过 SSH 隧道连接到:%s", address)
return conn, nil
}
// CloseAllSSHClients closes all cached SSH clients
func CloseAllSSHClients() {
sshClientCacheMu.Lock()
defer sshClientCacheMu.Unlock()
for key, client := range sshClientCache {
if client != nil {
_ = client.Close()
logger.Infof("已关闭 SSH 连接:%s", key)
}
}
sshClientCache = make(map[string]*ssh.Client)
}