Files
MyGoNavi/internal/ssh/ssh.go
Syngnat d8b6b4ef8d 🔧 fix(release,ssh): 修复 SSH 误判连接成功并纠正 DMG 打包结构
- SSH 缓存 key 纳入认证指纹(password/keyPath),避免改错凭证仍复用旧连接/端口转发
- MySQL/MariaDB/Doris:SSH 隧道建立失败直接返回错误,不再回退直连导致测试误判成功
- 新增最小单测覆盖 SSH cache key 与 UseSSH 异常路径
- build-release.sh:create-dmg 使用 staging 目录作为 source,避免 DMG 根目录变成 Contents
- refs #213
2026-03-11 14:36:36 +08:00

422 lines
11 KiB
Go
Raw Permalink Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
package ssh
import (
"context"
"crypto/sha256"
"encoding/hex"
"fmt"
"io"
"net"
"os"
"strconv"
"sync"
"time"
"GoNavi-Wails/internal/connection"
"GoNavi-Wails/internal/logger"
"github.com/go-sql-driver/mysql"
"golang.org/x/crypto/ssh"
)
// ViaSSHDialer dials network connections through an established SSH client,
// so traffic to the target address travels over that SSH connection.
// NOTE(review): the previous comment described this as registering a custom
// MySQL network; the type itself only holds the client — registration is done
// elsewhere.
type ViaSSHDialer struct {
	// sshClient is the SSH connection that Dial opens connections through.
	sshClient *ssh.Client
}
// Dial opens a TCP connection to addr through the dialer's SSH client.
// It returns ctx.Err() when ctx is done before the connection is established.
func (d *ViaSSHDialer) Dial(ctx context.Context, addr string) (net.Conn, error) {
	const network = "tcp"
	conn, err := dialContext(ctx, d.sshClient, network, addr)
	return conn, err
}
// dialContext dials addr on the given network through client while honouring
// ctx. client.Dial itself takes no context, so the dial runs in its own
// goroutine and this function returns as soon as either the dial finishes or
// ctx is done.
//
// When ctx wins the race, ctx.Err() is returned and a background goroutine
// waits for the in-flight dial so that a connection established late is closed
// instead of leaked.
func dialContext(ctx context.Context, client *ssh.Client, network, addr string) (net.Conn, error) {
	type dialOutcome struct {
		conn net.Conn
		err  error
	}

	// Buffered so the dialing goroutine can always deliver its outcome and exit,
	// even if nobody is waiting any more.
	outcomes := make(chan dialOutcome, 1)
	go func() {
		var out dialOutcome
		out.conn, out.err = client.Dial(network, addr)
		outcomes <- out
	}()

	select {
	case out := <-outcomes:
		return out.conn, out.err
	case <-ctx.Done():
	}

	// The caller has given up: reap the pending dial in the background.
	go func() {
		if late := <-outcomes; late.conn != nil {
			_ = late.conn.Close()
		}
	}()
	return nil, ctx.Err()
}
// connectSSH dials the SSH server described by config and returns the
// connected client.
//
// Authentication methods are offered in this order: the private key at
// config.KeyPath (when it can be read and parsed), then config.Password. A key
// that cannot be loaded is logged and skipped rather than treated as fatal, so
// password authentication can still be attempted. If no method is configured a
// warning is logged and the dial is attempted anyway.
//
// The dial address is built with net.JoinHostPort so that IPv6 literal hosts
// (for example "::1") are bracketed and remain parsable.
//
// SECURITY: the server host key is not verified (ssh.InsecureIgnoreHostKey),
// so the connection is exposed to man-in-the-middle attacks.
//
// It returns the error from ssh.Dial unchanged when the connection fails.
func connectSSH(config connection.SSHConfig) (*ssh.Client, error) {
	logger.Infof("开始建立 SSH 连接:地址=%s:%d 用户=%s", config.Host, config.Port, config.User)

	var authMethods []ssh.AuthMethod
	if config.KeyPath != "" {
		key, err := os.ReadFile(config.KeyPath)
		if err != nil {
			logger.Warnf("读取 SSH 私钥失败:路径=%s原因%v", config.KeyPath, err)
		} else if signer, err := ssh.ParsePrivateKey(key); err != nil {
			logger.Warnf("解析 SSH 私钥失败:路径=%s原因%v", config.KeyPath, err)
		} else {
			authMethods = append(authMethods, ssh.PublicKeys(signer))
		}
	}
	if config.Password != "" {
		authMethods = append(authMethods, ssh.Password(config.Password))
	}
	if len(authMethods) == 0 {
		logger.Warnf("SSH 未配置认证方式(密码或私钥)")
	}

	sshConfig := &ssh.ClientConfig{
		User:            config.User,
		Auth:            authMethods,
		HostKeyCallback: ssh.InsecureIgnoreHostKey(), // Use strict checking in production!
		Timeout:         5 * time.Second,
	}

	// JoinHostPort brackets IPv6 literals; "%s:%d" would yield "::1:22".
	addr := net.JoinHostPort(config.Host, strconv.Itoa(config.Port))
	client, err := ssh.Dial("tcp", addr, sshConfig)
	if err != nil {
		logger.Error(err, "SSH 连接建立失败:地址=%s 用户=%s", addr, config.User)
		return nil, err
	}
	logger.Infof("SSH 连接建立成功:地址=%s 用户=%s", addr, config.User)
	return client, nil
}
// RegisterSSHNetwork opens a new SSH connection for sshConfig and registers a
// MySQL dial function under a freshly generated network name. Connections the
// MySQL driver opens on that network are dialed over TCP through the SSH
// connection.
//
// It returns the network name to place in the DSN, or the connection error.
// The name combines the SSH host with the current time in nanoseconds, so each
// call produces its own registration backed by its own SSH connection; this
// function does not go through the shared client cache.
func RegisterSSHNetwork(sshConfig connection.SSHConfig) (string, error) {
	client, err := connectSSH(sshConfig)
	if err != nil {
		return "", err
	}

	netName := fmt.Sprintf("ssh_%s_%d", sshConfig.Host, time.Now().UnixNano())
	logger.Infof("注册 SSH 网络:%s地址=%s:%d 用户=%s", netName, sshConfig.Host, sshConfig.Port, sshConfig.User)

	dialViaTunnel := func(ctx context.Context, addr string) (net.Conn, error) {
		return dialContext(ctx, client, "tcp", addr)
	}
	mysql.RegisterDialContext(netName, dialViaTunnel)

	return netName, nil
}
// Package-level caches that let callers share SSH connections and local port
// forwarders instead of opening a new one for every request.
var (
	// sshClientCacheMu guards sshClientCache.
	sshClientCacheMu sync.RWMutex
	// sshClientCache maps a connection identity to its live SSH client.
	sshClientCache = map[sshClientCacheKey]*ssh.Client{}

	// forwarderMu guards localForwarders.
	forwarderMu sync.RWMutex
	// localForwarders maps an SSH identity plus remote endpoint to its forwarder.
	localForwarders = map[forwarderCacheKey]*LocalForwarder{}
)
// sshClientCacheKey identifies one cached SSH client. Two configurations share
// a client only when every field matches.
type sshClientCacheKey struct {
	host string // SSH server host
	port int    // SSH server port
	user string // SSH login user
	auth string // credential fingerprint produced by sshAuthFingerprint
}
// forwarderCacheKey identifies one cached local forwarder: the SSH connection
// identity plus the remote endpoint the forwarder targets.
type forwarderCacheKey struct {
	ssh        sshClientCacheKey // identity of the SSH connection used for the tunnel
	remoteHost string            // host dialed through the tunnel
	remotePort int               // port dialed through the tunnel
}
// sshAuthFingerprint derives a short, stable fingerprint of the credentials in
// config so that a changed password or key yields a different cache key.
//
// The hashed material is, NUL-separated: the password, the key path and, when a
// key path is set, either the key file's modification time (UTC, RFC 3339 with
// nanoseconds) and size, or the marker "stat_err" when the file cannot be
// stat'ed. The result is the first 8 bytes of the SHA-256 digest, hex encoded
// (16 characters).
func sshAuthFingerprint(config connection.SSHConfig) string {
	const sep = "\x00"

	material := config.Password + sep + config.KeyPath
	if config.KeyPath != "" {
		info, statErr := os.Stat(config.KeyPath)
		if statErr != nil {
			material += sep + "stat_err"
		} else {
			// Modification time and size stand in for the key contents, so
			// replacing the file at the same path changes the fingerprint.
			material += sep + info.ModTime().UTC().Format(time.RFC3339Nano)
			material += sep + strconv.FormatInt(info.Size(), 10)
		}
	}

	digest := sha256.Sum256([]byte(material))
	return hex.EncodeToString(digest[:8])
}
// newSSHClientCacheKey builds the cache key for config from its host, port and
// user plus the credential fingerprint computed by sshAuthFingerprint.
func newSSHClientCacheKey(config connection.SSHConfig) sshClientCacheKey {
	var key sshClientCacheKey
	key.host, key.port, key.user = config.Host, config.Port, config.User
	key.auth = sshAuthFingerprint(config)
	return key
}
// formatSSHClientKeyForLog renders key for log messages as "host:port" followed
// by the user. The auth fingerprint is deliberately left out of the output.
func formatSSHClientKeyForLog(key sshClientCacheKey) string {
	host, port, user := key.host, key.port, key.user
	return fmt.Sprintf("%s:%d 用户=%s", host, port, user)
}
// LocalForwarder represents a local port forwarder through SSH: it accepts
// connections on a local listener and relays each one to RemoteAddr over the
// SSH client.
type LocalForwarder struct {
	LocalAddr  string      // address of the local listener
	RemoteAddr string      // "host:port" dialed through the SSH client
	SSHClient  *ssh.Client // SSH connection used to reach RemoteAddr
	listener   net.Listener
	closeChan  chan struct{} // closed by Close to signal the accept loop that shutdown is intentional
	closeOnce  sync.Once     // guards against closing more than once
	closed     bool          // closed-state flag; guarded by closedMu
	closedMu   sync.RWMutex
}
// NewLocalForwarder creates a local port forwarder for remoteHost:remotePort.
//
// It obtains a (possibly cached) SSH client for sshConfig, listens on a random
// port on 127.0.0.1, and starts a background goroutine that relays every
// accepted connection to the remote address through the SSH client. The chosen
// local address is available in the returned forwarder's LocalAddr.
//
// The remote address is built with net.JoinHostPort so that IPv6 literal hosts
// are bracketed; an unbracketed "host:port" with extra colons cannot be split
// back into host and port when it is dialed.
//
// It returns a wrapped error if the SSH connection or the local listener
// cannot be established. The caller should Close the forwarder when done.
func NewLocalForwarder(sshConfig connection.SSHConfig, remoteHost string, remotePort int) (*LocalForwarder, error) {
	client, err := GetOrCreateSSHClient(sshConfig)
	if err != nil {
		return nil, fmt.Errorf("建立 SSH 连接失败:%w", err)
	}

	// Port 0 lets the OS choose a free local port.
	listener, err := net.Listen("tcp", "127.0.0.1:0")
	if err != nil {
		return nil, fmt.Errorf("创建本地监听器失败:%w", err)
	}

	forwarder := &LocalForwarder{
		LocalAddr:  listener.Addr().String(),
		RemoteAddr: net.JoinHostPort(remoteHost, strconv.Itoa(remotePort)),
		SSHClient:  client,
		listener:   listener,
		closeChan:  make(chan struct{}),
	}

	// Accept and relay connections in the background until Close is called.
	go forwarder.forward()

	logger.Infof("已创建 SSH 端口转发:本地 %s -> 远程 %s", forwarder.LocalAddr, forwarder.RemoteAddr)
	return forwarder, nil
}
// forward is the accept loop of the forwarder. Every accepted local connection
// is handed to handleConnection in its own goroutine. The loop ends on the
// first Accept error: silently when the forwarder is being closed, and with a
// warning otherwise (the listener may already be closed).
func (f *LocalForwarder) forward() {
	for {
		conn, acceptErr := f.listener.Accept()
		if acceptErr == nil {
			go f.handleConnection(conn)
			continue
		}

		select {
		case <-f.closeChan:
			// Shutdown was requested; the error is expected.
		default:
			logger.Warnf("接受本地连接失败:%v", acceptErr)
		}
		return
	}
}
// handleConnection relays one accepted local connection to f.RemoteAddr
// through the SSH client and returns once both directions have finished.
//
// End-of-stream is propagated: when one direction reaches EOF, the write side
// of its destination is half-closed (or fully closed if half-close is not
// supported) so the peer observes the shutdown and the opposite direction can
// finish. If a copy fails, both connections are closed to unblock the other
// direction. Without this, a peer that closed its side would leave the
// opposite copy blocked — holding the goroutine, the SSH channel and the
// remote connection — until the far end timed out on its own.
func (f *LocalForwarder) handleConnection(localConn net.Conn) {
	defer localConn.Close()

	// Open the tunnelled connection. Client.Dial accepts no deadline or
	// context, so this call is not bounded by a timeout.
	remoteConn, err := f.SSHClient.Dial("tcp", f.RemoteAddr)
	if err != nil {
		logger.Warnf("通过 SSH 连接到远程 %s 失败:%v", f.RemoteAddr, err)
		return
	}
	defer remoteConn.Close()

	// relay copies src into dst, then signals end-of-stream to dst.
	// Close errors are ignored: the connection may already be torn down by the
	// peer or by the opposite direction, and there is nothing left to recover.
	relay := func(dst, src net.Conn) error {
		if _, err := io.Copy(dst, src); err != nil {
			_ = dst.Close()
			_ = src.Close()
			return err
		}
		if hc, ok := dst.(interface{ CloseWrite() error }); ok {
			_ = hc.CloseWrite()
		} else {
			_ = dst.Close()
		}
		return nil
	}

	// Buffered for both senders so neither goroutine can block on delivery.
	errc := make(chan error, 2)

	// Copy from local to remote.
	go func() {
		err := relay(remoteConn, localConn)
		if err != nil {
			logger.Warnf("本地->远程数据复制错误:%v", err)
		}
		errc <- err
	}()

	// Copy from remote to local.
	go func() {
		err := relay(localConn, remoteConn)
		if err != nil {
			logger.Warnf("远程->本地数据复制错误:%v", err)
		}
		errc <- err
	}()

	// Wait for BOTH directions before the deferred closes run.
	<-errc
	<-errc
}
// Close shuts the forwarder down: it marks it closed, signals the accept loop
// through closeChan and closes the local listener. It is safe to call from
// multiple goroutines and more than once; only the first call does any work
// and returns the listener's close error, later calls return nil.
func (f *LocalForwarder) Close() error {
	var closeErr error

	shutdown := func() {
		f.closedMu.Lock()
		f.closed = true
		f.closedMu.Unlock()

		// Signal before closing the listener so the accept loop treats the
		// resulting Accept error as an intentional shutdown.
		close(f.closeChan)

		if closeErr = f.listener.Close(); closeErr != nil {
			logger.Warnf("关闭端口转发监听器失败:%v", closeErr)
		}
	}
	f.closeOnce.Do(shutdown)

	return closeErr
}
// IsClosed reports whether Close has been called on the forwarder.
func (f *LocalForwarder) IsClosed() bool {
	f.closedMu.RLock()
	state := f.closed
	f.closedMu.RUnlock()
	return state
}
// GetOrCreateLocalForwarder returns the cached forwarder for the given SSH
// configuration and remote endpoint, creating and caching a new one when none
// exists or the cached one has been closed.
//
// It is safe for concurrent use. A new forwarder is created without holding
// the cache lock (creation may dial SSH), so two callers can race; the cache is
// therefore re-checked before inserting, and the caller that loses the race
// closes its own forwarder and returns the one already cached. This keeps every
// live forwarder reachable by CloseAllForwarders.
//
// NOTE(review): a cached forwarder is considered valid as long as it has not
// been closed; the health of its underlying SSH connection is not checked here.
//
// It returns the error from NewLocalForwarder when creation fails.
func GetOrCreateLocalForwarder(sshConfig connection.SSHConfig, remoteHost string, remotePort int) (*LocalForwarder, error) {
	key := forwarderCacheKey{
		ssh:        newSSHClientCacheKey(sshConfig),
		remoteHost: remoteHost,
		remotePort: remotePort,
	}
	logKey := fmt.Sprintf("%s:%d:%s->%s:%d",
		sshConfig.Host, sshConfig.Port, sshConfig.User, remoteHost, remotePort)

	forwarderMu.RLock()
	cached, exists := localForwarders[key]
	forwarderMu.RUnlock()

	// Fast path: reuse a forwarder that is still open.
	if exists && cached != nil && !cached.IsClosed() {
		logger.Infof("复用已有端口转发:%s", logKey)
		return cached, nil
	}

	// Drop the stale entry, but only if it is still the one we inspected;
	// another goroutine may have replaced it with a fresh forwarder meanwhile.
	if exists {
		forwarderMu.Lock()
		if localForwarders[key] == cached {
			delete(localForwarders, key)
		}
		forwarderMu.Unlock()
	}

	created, err := NewLocalForwarder(sshConfig, remoteHost, remotePort)
	if err != nil {
		return nil, err
	}

	forwarderMu.Lock()
	if current, ok := localForwarders[key]; ok && current != nil && !current.IsClosed() {
		forwarderMu.Unlock()
		// Lost the race: discard ours so it does not leak its listener and
		// accept goroutine. The close error is already logged by Close.
		_ = created.Close()
		logger.Infof("复用已有端口转发:%s", logKey)
		return current, nil
	}
	localForwarders[key] = created
	forwarderMu.Unlock()

	return created, nil
}
// CloseAllForwarders closes every cached local forwarder and resets the cache
// to an empty map. The cache lock is held for the whole operation.
func CloseAllForwarders() {
	forwarderMu.Lock()
	defer forwarderMu.Unlock()

	for _, fwd := range localForwarders {
		if fwd == nil {
			continue
		}
		// Close logs its own failure; shutdown continues regardless.
		_ = fwd.Close()
		logger.Infof("已关闭端口转发:本地 %s -> 远程 %s", fwd.LocalAddr, fwd.RemoteAddr)
	}

	localForwarders = make(map[forwarderCacheKey]*LocalForwarder)
}
// GetOrCreateSSHClient returns the cached SSH client for config, or connects
// and caches a new one.
//
// A cached client is probed by opening (and immediately closing) a session; if
// that fails the client is treated as dead, evicted and closed, and a new
// connection is made.
//
// It is safe for concurrent use. Connecting happens without holding the cache
// lock, so two callers can race; the cache is re-checked before inserting, and
// the caller that loses the race closes its own client and returns the cached
// one. This keeps every live client reachable by CloseAllSSHClients.
//
// It returns the error from connectSSH when the connection fails.
func GetOrCreateSSHClient(config connection.SSHConfig) (*ssh.Client, error) {
	key := newSSHClientCacheKey(config)

	sshClientCacheMu.RLock()
	cached, exists := sshClientCache[key]
	sshClientCacheMu.RUnlock()

	if exists && cached != nil {
		// Test if connection is still alive by creating a test session.
		session, err := cached.NewSession()
		if err == nil {
			_ = session.Close() // probe only; its close error is irrelevant
			logger.Infof("复用已有 SSH 连接:%s", formatSSHClientKeyForLog(key))
			return cached, nil
		}

		logger.Warnf("SSH 连接已断开,重新建立:%s (错误: %v)", formatSSHClientKeyForLog(key), err)

		// Evict the dead client, but only if it is still the cached one;
		// another goroutine may already have replaced it.
		sshClientCacheMu.Lock()
		if sshClientCache[key] == cached {
			delete(sshClientCache, key)
		}
		sshClientCacheMu.Unlock()

		// Best effort: the connection is already considered broken.
		_ = cached.Close()
	}

	created, err := connectSSH(config)
	if err != nil {
		return nil, err
	}

	sshClientCacheMu.Lock()
	if current, ok := sshClientCache[key]; ok && current != nil {
		sshClientCacheMu.Unlock()
		// Lost the race: drop our connection instead of overwriting (and
		// thereby orphaning) the one that is already cached.
		_ = created.Close()
		logger.Infof("复用已有 SSH 连接:%s", formatSSHClientKeyForLog(key))
		return current, nil
	}
	sshClientCache[key] = created
	sshClientCacheMu.Unlock()

	logger.Infof("已缓存 SSH 连接:%s", formatSSHClientKeyForLog(key))
	return created, nil
}
// DialThroughSSH opens a connection to address on the given network through
// the (possibly cached) SSH client for config. It is a generic dialer usable
// by any database driver.
//
// It returns a wrapped error when either the SSH connection or the tunnelled
// dial fails.
func DialThroughSSH(config connection.SSHConfig, network, address string) (net.Conn, error) {
	client, clientErr := GetOrCreateSSHClient(config)
	if clientErr != nil {
		return nil, fmt.Errorf("建立 SSH 连接失败:%w", clientErr)
	}

	tunnelled, dialErr := client.Dial(network, address)
	if dialErr != nil {
		return nil, fmt.Errorf("通过 SSH 隧道连接到 %s 失败:%w", address, dialErr)
	}

	logger.Infof("已通过 SSH 隧道连接到:%s", address)
	return tunnelled, nil
}
// CloseAllSSHClients closes every cached SSH client and resets the cache to an
// empty map. The cache lock is held for the whole operation.
func CloseAllSSHClients() {
	sshClientCacheMu.Lock()
	defer sshClientCacheMu.Unlock()

	for cacheKey, cachedClient := range sshClientCache {
		if cachedClient == nil {
			continue
		}
		// Best effort: a failed close must not stop the remaining clients
		// from being closed.
		_ = cachedClient.Close()
		logger.Infof("已关闭 SSH 连接:%s", formatSSHClientKeyForLog(cacheKey))
	}

	sshClientCache = make(map[sshClientCacheKey]*ssh.Client)
}