Files
MyGoNavi/internal/ssh/ssh.go
杨国锋 99c21f4fd4 🐛 fix(connection): 修复多数据源连接测试成功但实际失败,closes #23
- 前端改用通用 DB API,避免强制走 MySQL 接口导致 PostgreSQL 等连接异常
  - 后端统一各数据源 timeout(Ping 超时 + 连接参数注入)
  - DSN 生成兼容特殊字符密码(Postgres/Oracle/达梦/金仓)
  - 增加文件日志与错误链输出,连接失败提示日志路径便于排障
2026-02-03 12:23:37 +08:00

113 lines
3.0 KiB
Go
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
package ssh
import (
"context"
"fmt"
"net"
"os"
"time"
"GoNavi-Wails/internal/connection"
"GoNavi-Wails/internal/logger"
"github.com/go-sql-driver/mysql"
"golang.org/x/crypto/ssh"
)
// ViaSSHDialer dials outbound connections through an SSH client so that
// traffic to the target address is proxied over the SSH connection.
type ViaSSHDialer struct {
// sshClient is the SSH connection that Dial tunnels through.
sshClient *ssh.Client
}
// Dial opens a TCP connection to addr through the dialer's SSH client.
// The work is delegated to dialContext, which returns ctx's error if ctx
// ends before the connection is established.
func (d *ViaSSHDialer) Dial(ctx context.Context, addr string) (net.Conn, error) {
	conn, err := dialContext(ctx, d.sshClient, "tcp", addr)
	return conn, err
}
// dialContext opens a connection to addr on the given network through the SSH
// client, giving up when ctx ends.
//
// client.Dial is called without a context, so it runs in a helper goroutine
// and this function waits for whichever comes first: the dial result or the
// end of ctx. If ctx ends first, ctx.Err() is returned and a second goroutine
// waits for the in-flight dial so that a connection which arrives late is
// closed instead of leaked.
//
// It returns an error without dialing when client is nil or ctx has already
// ended.
func dialContext(ctx context.Context, client *ssh.Client, network, addr string) (net.Conn, error) {
	// A nil client would otherwise panic inside the helper goroutine, where
	// the caller cannot recover it and the whole process would go down.
	if client == nil {
		return nil, fmt.Errorf("ssh client is nil")
	}
	// Do not open a remote channel that would be discarded immediately.
	if err := ctx.Err(); err != nil {
		return nil, err
	}

	type result struct {
		conn net.Conn
		err  error
	}

	// Buffered so the dialing goroutine can always deliver its result and
	// exit, even when nobody is receiving any more.
	ch := make(chan result, 1)
	go func() {
		c, err := client.Dial(network, addr)
		ch <- result{conn: c, err: err}
	}()

	select {
	case <-ctx.Done():
		// The dial is still in flight; reap its result in the background and
		// close the connection if it ends up succeeding.
		go func() {
			if r := <-ch; r.conn != nil {
				_ = r.conn.Close() // best effort: the connection is being discarded
			}
		}()
		return nil, ctx.Err()
	case r := <-ch:
		return r.conn, r.err
	}
}
// connectSSH opens an SSH client connection to the server described by config.
//
// Authentication methods are collected in this order:
//   - public key, when config.KeyPath is set and the file can be read and
//     parsed; a key that cannot be read or parsed is logged and skipped rather
//     than treated as fatal, so a configured password can still be used;
//   - password, when config.Password is non-empty.
//
// If no method could be collected a warning is logged and the dial is still
// attempted. The connection attempt uses a 5 second timeout.
//
// It returns the connected client, or the error reported by ssh.Dial.
func connectSSH(config connection.SSHConfig) (*ssh.Client, error) {
	logger.Infof("开始建立 SSH 连接:地址=%s:%d 用户=%s", config.Host, config.Port, config.User)

	var authMethods []ssh.AuthMethod
	if config.KeyPath != "" {
		key, err := os.ReadFile(config.KeyPath)
		if err != nil {
			logger.Warnf("读取 SSH 私钥失败:路径=%s原因%v", config.KeyPath, err)
		} else {
			signer, err := ssh.ParsePrivateKey(key)
			if err != nil {
				logger.Warnf("解析 SSH 私钥失败:路径=%s原因%v", config.KeyPath, err)
			} else {
				authMethods = append(authMethods, ssh.PublicKeys(signer))
			}
		}
	}
	if config.Password != "" {
		authMethods = append(authMethods, ssh.Password(config.Password))
	}
	if len(authMethods) == 0 {
		logger.Warnf("SSH 未配置认证方式(密码或私钥)")
	}

	sshConfig := &ssh.ClientConfig{
		User: config.User,
		Auth: authMethods,
		// NOTE(security): the server's host key is NOT verified, so this
		// connection can be intercepted by a man-in-the-middle. Use strict
		// checking (e.g. a known_hosts based callback) in production!
		HostKeyCallback: ssh.InsecureIgnoreHostKey(),
		Timeout:         5 * time.Second,
	}

	// Build the address with net.JoinHostPort so that IPv6 literals get the
	// brackets they need ("::1" -> "[::1]:22"); plain "%s:%d" formatting would
	// produce an unparsable "::1:22". Brackets the user already typed are
	// stripped first so they are not doubled.
	host := config.Host
	if n := len(host); n >= 2 && host[0] == '[' && host[n-1] == ']' {
		host = host[1 : n-1]
	}
	addr := net.JoinHostPort(host, fmt.Sprint(config.Port))

	client, err := ssh.Dial("tcp", addr, sshConfig)
	if err != nil {
		logger.Error(err, "SSH 连接建立失败:地址=%s 用户=%s", addr, config.User)
		return nil, err
	}
	logger.Infof("SSH 连接建立成功:地址=%s 用户=%s", addr, config.User)
	return client, nil
}
// RegisterSSHNetwork connects to the SSH server described by sshConfig and
// registers a MySQL dial function, under a freshly generated network name,
// that opens TCP connections through that SSH client.
//
// It returns the network name to use in the DSN, or the error from connectSSH
// when the SSH connection cannot be established (nothing is registered then).
//
// NOTE(review): the SSH client is captured by the registered dial function and
// is never closed or deregistered by this function, so every successful call
// keeps one SSH connection alive — confirm callers are expected to live with
// that for the lifetime of the process.
func RegisterSSHNetwork(sshConfig connection.SSHConfig) (string, error) {
	client, err := connectSSH(sshConfig)
	if err != nil {
		return "", err
	}

	// Generate a unique network name. The timestamp alone is not enough: two
	// registrations for the same host can read the same clock value when they
	// complete close together, and the second RegisterDialContext call would
	// then silently replace the first dialer. The client's address is appended
	// because no two live clients share one, which makes the name unique.
	netName := fmt.Sprintf("ssh_%s_%d_%p", sshConfig.Host, time.Now().UnixNano(), client)
	logger.Infof("注册 SSH 网络:%s地址=%s:%d 用户=%s", netName, sshConfig.Host, sshConfig.Port, sshConfig.User)

	mysql.RegisterDialContext(netName, func(ctx context.Context, addr string) (net.Conn, error) {
		return dialContext(ctx, client, "tcp", addr)
	})
	return netName, nil
}