From d8b6b4ef8d444105a405df406bea3affb5c0af74 Mon Sep 17 00:00:00 2001 From: Syngnat Date: Wed, 11 Mar 2026 14:36:36 +0800 Subject: [PATCH] =?UTF-8?q?=F0=9F=94=A7=20fix(release,ssh):=20=E4=BF=AE?= =?UTF-8?q?=E5=A4=8D=20SSH=20=E8=AF=AF=E5=88=A4=E8=BF=9E=E6=8E=A5=E6=88=90?= =?UTF-8?q?=E5=8A=9F=E5=B9=B6=E7=BA=A0=E6=AD=A3=20DMG=20=E6=89=93=E5=8C=85?= =?UTF-8?q?=E7=BB=93=E6=9E=84?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - SSH 缓存 key 纳入认证指纹(password/keyPath),避免改错凭证仍复用旧连接/端口转发 - MySQL/MariaDB/Doris:SSH 隧道建立失败直接返回错误,不再回退直连导致测试误判成功 - 新增最小单测覆盖 SSH cache key 与 UseSSH 异常路径 - build-release.sh:create-dmg 使用 staging 目录作为 source,避免 DMG 根目录变成 Contents - refs #213 --- build-release.sh | 154 +++++++++++++++++------------ internal/db/diros_impl.go | 23 +++-- internal/db/mariadb_impl.go | 22 +++-- internal/db/mysql_impl.go | 22 +++-- internal/db/mysql_ssh_test.go | 26 +++++ internal/ssh/ssh.go | 94 +++++++++++++----- internal/ssh/ssh_cache_key_test.go | 46 +++++++++ 7 files changed, 269 insertions(+), 118 deletions(-) create mode 100644 internal/db/mysql_ssh_test.go create mode 100644 internal/ssh/ssh_cache_key_test.go diff --git a/build-release.sh b/build-release.sh index d8b3a72..a36f835 100755 --- a/build-release.sh +++ b/build-release.sh @@ -42,39 +42,50 @@ if [ $? -eq 0 ]; then # 移动 .app 到 dist mv "$APP_SRC" "$DIST_DIR/$APP_DEST_NAME" - # Ad-hoc 代码签名(无 Apple Developer 账号时防止 Gatekeeper 报已损坏) - echo " 🔏 正在对 .app 进行 ad-hoc 签名 (arm64)..." - codesign --force --deep --sign - "$DIST_DIR/$APP_DEST_NAME" + # Ad-hoc 代码签名(无 Apple Developer 账号时防止 Gatekeeper 报已损坏) + echo " 🔏 正在对 .app 进行 ad-hoc 签名 (arm64)..." + codesign --force --deep --sign - "$DIST_DIR/$APP_DEST_NAME" - # 创建 DMG - if command -v create-dmg &> /dev/null; then - echo " 📦 正在打包 DMG (arm64)..." - # 移除已存在的 DMG (以防万一) - rm -f "$DIST_DIR/$DMG_NAME" + # 创建 DMG + if command -v create-dmg &> /dev/null; then + echo " 📦 正在打包 DMG (arm64)..." + # 移除已存在的 DMG (以防万一) + rm -f "$DIST_DIR/$DMG_NAME" + # create-dmg 的 source 需要是“包含 .app 的目录”,不能直接传 .app 路径。 + STAGE_DIR=$(mktemp -d "$DIST_DIR/.dmg-stage-${APP_NAME}-${VERSION}-arm64.XXXXXX") + if [ -z "$STAGE_DIR" ] || [ ! -d "$STAGE_DIR" ]; then + echo -e "${RED} ❌ 创建 DMG 临时目录失败,跳过 DMG 打包。${NC}" + else + if command -v ditto &> /dev/null; then + ditto "$DIST_DIR/$APP_DEST_NAME" "$STAGE_DIR/$APP_DEST_NAME" + else + cp -R "$DIST_DIR/$APP_DEST_NAME" "$STAGE_DIR/$APP_DEST_NAME" + fi - # --sandbox-safe 会跳过 Finder 的 AppleScript 排版,避免打包过程中弹出/打开挂载窗口(CI/本地静默打包更友好)。 - CREATE_DMG_ARGS=(--volname "${APP_NAME} ${VERSION}" --format UDZO --sandbox-safe) - if [ -n "$MAC_VOLICON_PATH" ]; then - CREATE_DMG_ARGS+=(--volicon "$MAC_VOLICON_PATH") + # --sandbox-safe 会跳过 Finder 的 AppleScript 排版,避免打包过程中弹出/打开挂载窗口(CI/本地静默打包更友好)。 + CREATE_DMG_ARGS=(--volname "${APP_NAME} ${VERSION}" --format UDZO --sandbox-safe) + if [ -n "$MAC_VOLICON_PATH" ]; then + CREATE_DMG_ARGS+=(--volicon "$MAC_VOLICON_PATH") else echo -e "${YELLOW} ⚠️ 未找到 macOS 卷图标 (build/darwin/icon.icns),跳过 --volicon。${NC}" fi - create-dmg "${CREATE_DMG_ARGS[@]}" \ - --window-pos 200 120 \ - --window-size 800 400 \ - --icon-size 100 \ - --icon "$APP_DEST_NAME" 200 190 \ - --hide-extension "$APP_DEST_NAME" \ - --app-drop-link 600 185 \ - "$DIST_DIR/$DMG_NAME" \ - "$DIST_DIR/$APP_DEST_NAME" + create-dmg "${CREATE_DMG_ARGS[@]}" \ + --window-pos 200 120 \ + --window-size 800 400 \ + --icon-size 100 \ + --icon "$APP_DEST_NAME" 200 190 \ + --hide-extension "$APP_DEST_NAME" \ + --app-drop-link 600 185 \ + "$DIST_DIR/$DMG_NAME" \ + "$STAGE_DIR" - CREATE_DMG_EXIT_CODE=$? - - if [ $CREATE_DMG_EXIT_CODE -ne 0 ]; then - echo -e "${RED} ❌ create-dmg 执行失败 (exit=$CREATE_DMG_EXIT_CODE),保留 .app 以便排查。${NC}" - else + CREATE_DMG_EXIT_CODE=$? + rm -rf "$STAGE_DIR" + + if [ $CREATE_DMG_EXIT_CODE -ne 0 ]; then + echo -e "${RED} ❌ create-dmg 执行失败 (exit=$CREATE_DMG_EXIT_CODE),保留 .app 以便排查。${NC}" + else # create-dmg 可能会在失败时遗留 rw.*.dmg 中间产物;不要直接当作最终产物使用 if [ ! -f "$DIST_DIR/$DMG_NAME" ]; then RW_FILE=$(find "$DIST_DIR" -maxdepth 1 -name "rw.*.dmg" -print -quit) @@ -108,14 +119,15 @@ if [ $? -eq 0 ]; then fi fi - if [ ! -f "$DIST_DIR/$DMG_NAME" ]; then - echo -e "${RED} ❌ DMG 生成失败,请检查 create-dmg 输出。${NC}" - fi - else - echo -e "${YELLOW} ⚠️ 未找到 create-dmg 工具,跳过 DMG 打包,仅保留 .app。${NC}" - echo " 安装命令: brew install create-dmg" - fi -else + if [ ! -f "$DIST_DIR/$DMG_NAME" ]; then + echo -e "${RED} ❌ DMG 生成失败,请检查 create-dmg 输出。${NC}" + fi + fi + else + echo -e "${YELLOW} ⚠️ 未找到 create-dmg 工具,跳过 DMG 打包,仅保留 .app。${NC}" + echo " 安装命令: brew install create-dmg" + fi + else echo -e "${RED} ❌ macOS arm64 构建失败。${NC}" fi @@ -129,37 +141,48 @@ if [ $? -eq 0 ]; then mv "$APP_SRC" "$DIST_DIR/$APP_DEST_NAME" - # Ad-hoc 代码签名 - echo " 🔏 正在对 .app 进行 ad-hoc 签名 (amd64)..." - codesign --force --deep --sign - "$DIST_DIR/$APP_DEST_NAME" + # Ad-hoc 代码签名 + echo " 🔏 正在对 .app 进行 ad-hoc 签名 (amd64)..." + codesign --force --deep --sign - "$DIST_DIR/$APP_DEST_NAME" - if command -v create-dmg &> /dev/null; then - echo " 📦 正在打包 DMG (amd64)..." - rm -f "$DIST_DIR/$DMG_NAME" + if command -v create-dmg &> /dev/null; then + echo " 📦 正在打包 DMG (amd64)..." + rm -f "$DIST_DIR/$DMG_NAME" + # create-dmg 的 source 需要是“包含 .app 的目录”,不能直接传 .app 路径。 + STAGE_DIR=$(mktemp -d "$DIST_DIR/.dmg-stage-${APP_NAME}-${VERSION}-amd64.XXXXXX") + if [ -z "$STAGE_DIR" ] || [ ! -d "$STAGE_DIR" ]; then + echo -e "${RED} ❌ 创建 DMG 临时目录失败,跳过 DMG 打包。${NC}" + else + if command -v ditto &> /dev/null; then + ditto "$DIST_DIR/$APP_DEST_NAME" "$STAGE_DIR/$APP_DEST_NAME" + else + cp -R "$DIST_DIR/$APP_DEST_NAME" "$STAGE_DIR/$APP_DEST_NAME" + fi - # --sandbox-safe 会跳过 Finder 的 AppleScript 排版,避免打包过程中弹出/打开挂载窗口(CI/本地静默打包更友好)。 - CREATE_DMG_ARGS=(--volname "${APP_NAME} ${VERSION}" --format UDZO --sandbox-safe) - if [ -n "$MAC_VOLICON_PATH" ]; then - CREATE_DMG_ARGS+=(--volicon "$MAC_VOLICON_PATH") + # --sandbox-safe 会跳过 Finder 的 AppleScript 排版,避免打包过程中弹出/打开挂载窗口(CI/本地静默打包更友好)。 + CREATE_DMG_ARGS=(--volname "${APP_NAME} ${VERSION}" --format UDZO --sandbox-safe) + if [ -n "$MAC_VOLICON_PATH" ]; then + CREATE_DMG_ARGS+=(--volicon "$MAC_VOLICON_PATH") else echo -e "${YELLOW} ⚠️ 未找到 macOS 卷图标 (build/darwin/icon.icns),跳过 --volicon。${NC}" fi - create-dmg "${CREATE_DMG_ARGS[@]}" \ - --window-pos 200 120 \ - --window-size 800 400 \ - --icon-size 100 \ - --icon "$APP_DEST_NAME" 200 190 \ - --hide-extension "$APP_DEST_NAME" \ - --app-drop-link 600 185 \ - "$DIST_DIR/$DMG_NAME" \ - "$DIST_DIR/$APP_DEST_NAME" + create-dmg "${CREATE_DMG_ARGS[@]}" \ + --window-pos 200 120 \ + --window-size 800 400 \ + --icon-size 100 \ + --icon "$APP_DEST_NAME" 200 190 \ + --hide-extension "$APP_DEST_NAME" \ + --app-drop-link 600 185 \ + "$DIST_DIR/$DMG_NAME" \ + "$STAGE_DIR" - CREATE_DMG_EXIT_CODE=$? + CREATE_DMG_EXIT_CODE=$? + rm -rf "$STAGE_DIR" - if [ $CREATE_DMG_EXIT_CODE -ne 0 ]; then - echo -e "${RED} ❌ create-dmg 执行失败 (exit=$CREATE_DMG_EXIT_CODE),保留 .app 以便排查。${NC}" - else + if [ $CREATE_DMG_EXIT_CODE -ne 0 ]; then + echo -e "${RED} ❌ create-dmg 执行失败 (exit=$CREATE_DMG_EXIT_CODE),保留 .app 以便排查。${NC}" + else if [ ! -f "$DIST_DIR/$DMG_NAME" ]; then RW_FILE=$(find "$DIST_DIR" -maxdepth 1 -name "rw.*.dmg" -print -quit) if [ -n "$RW_FILE" ]; then @@ -190,14 +213,15 @@ if [ $? -eq 0 ]; then fi fi - if [ ! -f "$DIST_DIR/$DMG_NAME" ]; then - echo -e "${RED} ❌ DMG 生成失败。${NC}" - fi - else - echo -e "${YELLOW} ⚠️ 未找到 create-dmg 工具。${NC}" - fi -else - echo -e "${RED} ❌ macOS amd64 构建失败。${NC}" + if [ ! -f "$DIST_DIR/$DMG_NAME" ]; then + echo -e "${RED} ❌ DMG 生成失败。${NC}" + fi + fi + else + echo -e "${YELLOW} ⚠️ 未找到 create-dmg 工具。${NC}" + fi + else + echo -e "${RED} ❌ macOS amd64 构建失败。${NC}" fi # --- Windows AMD64 构建 --- diff --git a/internal/db/diros_impl.go b/internal/db/diros_impl.go index 07bed73..773b7fa 100644 --- a/internal/db/diros_impl.go +++ b/internal/db/diros_impl.go @@ -9,7 +9,6 @@ import ( "strings" "GoNavi-Wails/internal/connection" - "GoNavi-Wails/internal/logger" "GoNavi-Wails/internal/ssh" "GoNavi-Wails/internal/utils" @@ -135,26 +134,26 @@ func collectDirosAddresses(config connection.ConnectionConfig) []string { return result } -func (d *DirosDB) getDSN(config connection.ConnectionConfig) string { +func (d *DirosDB) getDSN(config connection.ConnectionConfig) (string, error) { database := config.Database protocol := "tcp" address := normalizeMySQLAddress(config.Host, config.Port) if config.UseSSH { netName, err := ssh.RegisterSSHNetwork(config.SSH) - if err == nil { - protocol = netName - address = normalizeMySQLAddress(config.Host, config.Port) - } else { - logger.Warnf("注册 Doris SSH 网络失败,将尝试直连:地址=%s:%d 用户=%s,原因:%v", config.Host, config.Port, config.User, err) + if err != nil { + return "", fmt.Errorf("创建 SSH 隧道失败:%w", err) } + protocol = netName } timeout := getConnectTimeoutSeconds(config) tlsMode := resolveMySQLTLSMode(config) - return fmt.Sprintf("%s:%s@%s(%s)/%s?charset=utf8mb4&parseTime=True&loc=Local&timeout=%ds&tls=%s", - config.User, config.Password, protocol, address, database, timeout, url.QueryEscape(tlsMode)) + return fmt.Sprintf( + "%s:%s@%s(%s)/%s?charset=utf8mb4&parseTime=True&loc=Local&timeout=%ds&tls=%s", + config.User, config.Password, protocol, address, database, timeout, url.QueryEscape(tlsMode), + ), nil } func resolveDirosCredential(config connection.ConnectionConfig, addressIndex int) (string, string) { @@ -192,7 +191,11 @@ func (d *DirosDB) Connect(config connection.ConnectionConfig) error { candidateConfig.Port = port candidateConfig.User, candidateConfig.Password = resolveDirosCredential(runConfig, index) - dsn := d.getDSN(candidateConfig) + dsn, err := d.getDSN(candidateConfig) + if err != nil { + errorDetails = append(errorDetails, fmt.Sprintf("%s 生成连接串失败: %v", address, err)) + continue + } db, err := sql.Open(dirosDriverName, dsn) if err != nil { errorDetails = append(errorDetails, fmt.Sprintf("%s 打开失败: %v", address, err)) diff --git a/internal/db/mariadb_impl.go b/internal/db/mariadb_impl.go index 6a36400..65b9cc3 100644 --- a/internal/db/mariadb_impl.go +++ b/internal/db/mariadb_impl.go @@ -11,7 +11,6 @@ import ( "time" "GoNavi-Wails/internal/connection" - "GoNavi-Wails/internal/logger" "GoNavi-Wails/internal/ssh" "GoNavi-Wails/internal/utils" @@ -25,30 +24,33 @@ type MariaDB struct { pingTimeout time.Duration } -func (m *MariaDB) getDSN(config connection.ConnectionConfig) string { +func (m *MariaDB) getDSN(config connection.ConnectionConfig) (string, error) { database := config.Database protocol := "tcp" address := fmt.Sprintf("%s:%d", config.Host, config.Port) if config.UseSSH { netName, err := ssh.RegisterSSHNetwork(config.SSH) - if err == nil { - protocol = netName - address = fmt.Sprintf("%s:%d", config.Host, config.Port) - } else { - logger.Warnf("注册 SSH 网络失败,将尝试直连:地址=%s:%d 用户=%s,原因:%v", config.Host, config.Port, config.User, err) + if err != nil { + return "", fmt.Errorf("创建 SSH 隧道失败:%w", err) } + protocol = netName } timeout := getConnectTimeoutSeconds(config) tlsMode := resolveMySQLTLSMode(config) - return fmt.Sprintf("%s:%s@%s(%s)/%s?charset=utf8mb4&parseTime=True&loc=Local&timeout=%ds&tls=%s", - config.User, config.Password, protocol, address, database, timeout, url.QueryEscape(tlsMode)) + return fmt.Sprintf( + "%s:%s@%s(%s)/%s?charset=utf8mb4&parseTime=True&loc=Local&timeout=%ds&tls=%s", + config.User, config.Password, protocol, address, database, timeout, url.QueryEscape(tlsMode), + ), nil } func (m *MariaDB) Connect(config connection.ConnectionConfig) error { - dsn := m.getDSN(config) + dsn, err := m.getDSN(config) + if err != nil { + return err + } db, err := sql.Open("mysql", dsn) if err != nil { return fmt.Errorf("打开数据库连接失败:%w", err) diff --git a/internal/db/mysql_impl.go b/internal/db/mysql_impl.go index 5095f1c..32b63cc 100644 --- a/internal/db/mysql_impl.go +++ b/internal/db/mysql_impl.go @@ -169,26 +169,26 @@ func collectMySQLAddresses(config connection.ConnectionConfig) []string { return result } -func (m *MySQLDB) getDSN(config connection.ConnectionConfig) string { +func (m *MySQLDB) getDSN(config connection.ConnectionConfig) (string, error) { database := config.Database protocol := "tcp" address := normalizeMySQLAddress(config.Host, config.Port) if config.UseSSH { netName, err := ssh.RegisterSSHNetwork(config.SSH) - if err == nil { - protocol = netName - address = normalizeMySQLAddress(config.Host, config.Port) - } else { - logger.Warnf("注册 SSH 网络失败,将尝试直连:地址=%s:%d 用户=%s,原因:%v", config.Host, config.Port, config.User, err) + if err != nil { + return "", fmt.Errorf("创建 SSH 隧道失败:%w", err) } + protocol = netName } timeout := getConnectTimeoutSeconds(config) tlsMode := resolveMySQLTLSMode(config) - return fmt.Sprintf("%s:%s@%s(%s)/%s?charset=utf8mb4&parseTime=True&loc=Local&timeout=%ds&tls=%s", - config.User, config.Password, protocol, address, database, timeout, url.QueryEscape(tlsMode)) + return fmt.Sprintf( + "%s:%s@%s(%s)/%s?charset=utf8mb4&parseTime=True&loc=Local&timeout=%ds&tls=%s", + config.User, config.Password, protocol, address, database, timeout, url.QueryEscape(tlsMode), + ), nil } func resolveMySQLCredential(config connection.ConnectionConfig, addressIndex int) (string, string) { @@ -226,7 +226,11 @@ func (m *MySQLDB) Connect(config connection.ConnectionConfig) error { candidateConfig.Port = port candidateConfig.User, candidateConfig.Password = resolveMySQLCredential(runConfig, index) - dsn := m.getDSN(candidateConfig) + dsn, err := m.getDSN(candidateConfig) + if err != nil { + errorDetails = append(errorDetails, fmt.Sprintf("%s 生成连接串失败: %v", address, err)) + continue + } db, err := sql.Open("mysql", dsn) if err != nil { errorDetails = append(errorDetails, fmt.Sprintf("%s 打开失败: %v", address, err)) diff --git a/internal/db/mysql_ssh_test.go b/internal/db/mysql_ssh_test.go new file mode 100644 index 0000000..673639c --- /dev/null +++ b/internal/db/mysql_ssh_test.go @@ -0,0 +1,26 @@ +package db + +import ( + "testing" + + "GoNavi-Wails/internal/connection" +) + +func TestMySQLDSN_UseSSH_ShouldFailWhenSSHInvalid(t *testing.T) { + m := &MySQLDB{} + _, err := m.getDSN(connection.ConnectionConfig{ + Host: "127.0.0.1", + Port: 3306, + User: "root", + UseSSH: true, + SSH: connection.SSHConfig{ + Host: "127.0.0.1", + Port: 0, // invalid port, should fail immediately + User: "bad", + Password: "bad", + }, + }) + if err == nil { + t.Fatalf("expected error when UseSSH=true and SSH config invalid") + } +} diff --git a/internal/ssh/ssh.go b/internal/ssh/ssh.go index 51ad364..15feab3 100644 --- a/internal/ssh/ssh.go +++ b/internal/ssh/ssh.go @@ -2,10 +2,13 @@ package ssh import ( "context" + "crypto/sha256" + "encoding/hex" "fmt" "io" "net" "os" + "strconv" "sync" "time" @@ -69,7 +72,7 @@ func connectSSH(config connection.SSHConfig) (*ssh.Client, error) { } } } - + if config.Password != "" { authMethods = append(authMethods, ssh.Password(config.Password)) } @@ -105,7 +108,7 @@ func RegisterSSHNetwork(sshConfig connection.SSHConfig) (string, error) { // Generate unique network name netName := fmt.Sprintf("ssh_%s_%d", sshConfig.Host, time.Now().UnixNano()) logger.Infof("注册 SSH 网络:%s(地址=%s:%d 用户=%s)", netName, sshConfig.Host, sshConfig.Port, sshConfig.User) - + mysql.RegisterDialContext(netName, func(ctx context.Context, addr string) (net.Conn, error) { return dialContext(ctx, client, "tcp", addr) }) @@ -115,12 +118,58 @@ func RegisterSSHNetwork(sshConfig connection.SSHConfig) (string, error) { // sshClientCache stores SSH clients to avoid creating multiple connections var ( - sshClientCache = make(map[string]*ssh.Client) + sshClientCache = make(map[sshClientCacheKey]*ssh.Client) sshClientCacheMu sync.RWMutex - localForwarders = make(map[string]*LocalForwarder) + localForwarders = make(map[forwarderCacheKey]*LocalForwarder) forwarderMu sync.RWMutex ) +type sshClientCacheKey struct { + host string + port int + user string + auth string +} + +type forwarderCacheKey struct { + ssh sshClientCacheKey + remoteHost string + remotePort int +} + +func sshAuthFingerprint(config connection.SSHConfig) string { + hasher := sha256.New() + _, _ = hasher.Write([]byte(config.Password)) + _, _ = hasher.Write([]byte{0}) + _, _ = hasher.Write([]byte(config.KeyPath)) + if config.KeyPath != "" { + if st, err := os.Stat(config.KeyPath); err == nil { + _, _ = hasher.Write([]byte{0}) + _, _ = hasher.Write([]byte(st.ModTime().UTC().Format(time.RFC3339Nano))) + _, _ = hasher.Write([]byte{0}) + _, _ = hasher.Write([]byte(strconv.FormatInt(st.Size(), 10))) + } else { + _, _ = hasher.Write([]byte{0}) + _, _ = hasher.Write([]byte("stat_err")) + } + } + sum := hasher.Sum(nil) + return hex.EncodeToString(sum[:8]) +} + +func newSSHClientCacheKey(config connection.SSHConfig) sshClientCacheKey { + return sshClientCacheKey{ + host: config.Host, + port: config.Port, + user: config.User, + auth: sshAuthFingerprint(config), + } +} + +func formatSSHClientKeyForLog(key sshClientCacheKey) string { + return fmt.Sprintf("%s:%d 用户=%s", key.host, key.port, key.user) +} + // LocalForwarder represents a local port forwarder through SSH type LocalForwarder struct { LocalAddr string @@ -249,9 +298,13 @@ func (f *LocalForwarder) IsClosed() bool { // GetOrCreateLocalForwarder returns a cached forwarder or creates a new one func GetOrCreateLocalForwarder(sshConfig connection.SSHConfig, remoteHost string, remotePort int) (*LocalForwarder, error) { - key := fmt.Sprintf("%s:%d:%s->%s:%d", - sshConfig.Host, sshConfig.Port, sshConfig.User, - remoteHost, remotePort) + key := forwarderCacheKey{ + ssh: newSSHClientCacheKey(sshConfig), + remoteHost: remoteHost, + remotePort: remotePort, + } + logKey := fmt.Sprintf("%s:%d:%s->%s:%d", + sshConfig.Host, sshConfig.Port, sshConfig.User, remoteHost, remotePort) forwarderMu.RLock() forwarder, exists := localForwarders[key] @@ -259,7 +312,7 @@ func GetOrCreateLocalForwarder(sshConfig connection.SSHConfig, remoteHost string // Check if exists and is still valid if exists && forwarder != nil && !forwarder.IsClosed() { - logger.Infof("复用已有端口转发:%s", key) + logger.Infof("复用已有端口转发:%s", logKey) return forwarder, nil } @@ -287,24 +340,18 @@ func CloseAllForwarders() { forwarderMu.Lock() defer forwarderMu.Unlock() - for key, forwarder := range localForwarders { + for _, forwarder := range localForwarders { if forwarder != nil { _ = forwarder.Close() - logger.Infof("已关闭端口转发:%s", key) + logger.Infof("已关闭端口转发:本地 %s -> 远程 %s", forwarder.LocalAddr, forwarder.RemoteAddr) } } - localForwarders = make(map[string]*LocalForwarder) -} - - -// getSSHClientCacheKey generates a unique cache key for SSH config -func getSSHClientCacheKey(config connection.SSHConfig) string { - return fmt.Sprintf("%s:%d:%s", config.Host, config.Port, config.User) + localForwarders = make(map[forwarderCacheKey]*LocalForwarder) } // GetOrCreateSSHClient returns a cached SSH client or creates a new one func GetOrCreateSSHClient(config connection.SSHConfig) (*ssh.Client, error) { - key := getSSHClientCacheKey(config) + key := newSSHClientCacheKey(config) sshClientCacheMu.RLock() client, exists := sshClientCache[key] @@ -315,11 +362,11 @@ func GetOrCreateSSHClient(config connection.SSHConfig) (*ssh.Client, error) { session, err := client.NewSession() if err == nil { session.Close() - logger.Infof("复用已有 SSH 连接:%s", key) + logger.Infof("复用已有 SSH 连接:%s", formatSSHClientKeyForLog(key)) return client, nil } // Connection is dead, remove from cache - logger.Warnf("SSH 连接已断开,重新建立:%s (错误: %v)", key, err) + logger.Warnf("SSH 连接已断开,重新建立:%s (错误: %v)", formatSSHClientKeyForLog(key), err) sshClientCacheMu.Lock() delete(sshClientCache, key) sshClientCacheMu.Unlock() @@ -338,7 +385,7 @@ func GetOrCreateSSHClient(config connection.SSHConfig) (*ssh.Client, error) { sshClientCache[key] = client sshClientCacheMu.Unlock() - logger.Infof("已缓存 SSH 连接:%s", key) + logger.Infof("已缓存 SSH 连接:%s", formatSSHClientKeyForLog(key)) return client, nil } @@ -367,9 +414,8 @@ func CloseAllSSHClients() { for key, client := range sshClientCache { if client != nil { _ = client.Close() - logger.Infof("已关闭 SSH 连接:%s", key) + logger.Infof("已关闭 SSH 连接:%s", formatSSHClientKeyForLog(key)) } } - sshClientCache = make(map[string]*ssh.Client) + sshClientCache = make(map[sshClientCacheKey]*ssh.Client) } - diff --git a/internal/ssh/ssh_cache_key_test.go b/internal/ssh/ssh_cache_key_test.go new file mode 100644 index 0000000..16322c2 --- /dev/null +++ b/internal/ssh/ssh_cache_key_test.go @@ -0,0 +1,46 @@ +package ssh + +import ( + "testing" + + "GoNavi-Wails/internal/connection" +) + +func TestNewSSHClientCacheKey_DiffPassword(t *testing.T) { + a := newSSHClientCacheKey(connection.SSHConfig{ + Host: "127.0.0.1", + Port: 22, + User: "root", + Password: "a", + }) + b := newSSHClientCacheKey(connection.SSHConfig{ + Host: "127.0.0.1", + Port: 22, + User: "root", + Password: "b", + }) + if a == b { + t.Fatalf("expected different cache key when password differs") + } + if a.host != b.host || a.port != b.port || a.user != b.user { + t.Fatalf("expected host/port/user to stay identical") + } +} + +func TestNewSSHClientCacheKey_DiffKeyPath(t *testing.T) { + a := newSSHClientCacheKey(connection.SSHConfig{ + Host: "127.0.0.1", + Port: 22, + User: "root", + KeyPath: "/tmp/a.key", + }) + b := newSSHClientCacheKey(connection.SSHConfig{ + Host: "127.0.0.1", + Port: 22, + User: "root", + KeyPath: "/tmp/b.key", + }) + if a == b { + t.Fatalf("expected different cache key when keyPath differs") + } +}