From a611c1c04bc89827aa60942e186df2695e7facd2 Mon Sep 17 00:00:00 2001 From: Syngnat Date: Mon, 15 Jun 2026 16:13:15 +0800 Subject: [PATCH] =?UTF-8?q?=F0=9F=90=9B=20fix(oceanbase):=20=E4=BF=AE?= =?UTF-8?q?=E5=A4=8D=20Oracle=20=E7=A7=9F=E6=88=B7=E8=B7=B3=E6=9D=BF?= =?UTF-8?q?=E6=9C=BA=E8=BF=9E=E6=8E=A5=E9=A2=84=E6=8E=A2=E6=B5=8B=E5=A4=B1?= =?UTF-8?q?=E8=B4=A5?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - 修复 OceanBase Oracle 预探测未走 SSH 隧道导致内网 IP 被本机直连误判不可达的问题 - 预探测阶段复用完整连接配置,支持通过 SSH 跳板机访问目标地址 - 区分本机 TCP 不可达与 SSH 跳板机访问失败,优化错误提示 - 保留 OBClient 与 TNS 双路径路由逻辑,避免协议判断回退 - 补充 OceanBase Oracle SSH 预探测与网络失败回归测试 --- internal/db/oceanbase_impl.go | 90 ++++++++++++++------ internal/db/oceanbase_impl_test.go | 131 ++++++++++++++++++++++++++++- internal/ssh/ssh.go | 16 ++++ 3 files changed, 210 insertions(+), 27 deletions(-) diff --git a/internal/db/oceanbase_impl.go b/internal/db/oceanbase_impl.go index 4ec7573..3beb278 100644 --- a/internal/db/oceanbase_impl.go +++ b/internal/db/oceanbase_impl.go @@ -16,12 +16,12 @@ // - 协议=MySQL:走 go-sql-driver/mysql,连 MySQL 租户。OB 服务端在 Oracle 租户上返回 // "Error 1235 (0A000): Oracle tenant for current client driver is not supported" // 时,错误信息提示用户切换到 Oracle 协议。 -// - 协议=Oracle:先做 mysql wire 端口预探测(probeOceanBaseMySQLWireHandshake): -// * 端口是 OB MySQL wire → 走 mysql wire + OBClient capability 注入路径 -// (ensureOceanBaseOBClientAttributes + ensureOceanBaseOracleANSIQuotes), -// 元数据查询通过 OracleDB wrapper 复用 Oracle 方言 SQL, -// ApplyChanges 用 applyOracleChangesMySQLWire("?" 占位符 + 双引号引用)。 -// * 端口非 OB MySQL wire → 走 sijms/go-ora 连接 OBProxy 的 Oracle listener。 +// - 协议=Oracle:先做 mysql wire 端口预探测(probeOceanBaseMySQLWireHandshake)。 +// 识别为 OB MySQL wire 时,走 mysql wire + OBClient capability 注入路径 +// (ensureOceanBaseOBClientAttributes + ensureOceanBaseOracleANSIQuotes); +// 元数据查询通过 OracleDB wrapper 复用 Oracle 方言 SQL,ApplyChanges 用 +// applyOracleChangesMySQLWire("?" 占位符 + 双引号引用)。 +// 端口非 OB MySQL wire 时,走 sijms/go-ora 连接 OBProxy 的 Oracle listener。 // // OBClient capability attribute 候选清单(基于 OceanBase 公开 connector-j 资料 + // 社区经验,**未在本仓库联调验证 Navicat 用的具体组合**): @@ -465,6 +465,37 @@ func annotateOceanBaseOracleConnectError(err error) error { return fmt.Errorf("%w(OceanBase Oracle 协议连接失败)", err) } +type oceanBaseMySQLWireProbeResult struct { + isOBMySQLWire bool + probeSucceeded bool + tcpReachable bool + err error +} + +var oceanBaseProbeDialContext = defaultOceanBaseProbeDialContext + +func defaultOceanBaseProbeDialContext(ctx context.Context, config connection.ConnectionConfig, address string) (net.Conn, error) { + if config.UseSSH { + return ssh.DialContextThroughSSH(ctx, config.SSH, "tcp", address) + } + var dialer net.Dialer + return dialer.DialContext(ctx, "tcp", address) +} + +func formatOceanBaseOracleNetworkProbeError(config connection.ConnectionConfig, err error) error { + address := normalizeMySQLAddress(config.Host, config.Port) + if config.UseSSH { + if err == nil { + return fmt.Errorf("OceanBase Oracle 连接失败:通过 SSH 跳板机访问目标地址 %s 失败。该错误发生在协议选择之前,和 OBClient/TNS 路径无关;请确认跳板机能访问该内网地址,并检查 SSH 配置、远端防火墙以及 OBProxy/OBServer 监听端口", address) + } + return fmt.Errorf("OceanBase Oracle 连接失败:通过 SSH 跳板机访问目标地址 %s 失败:%w。该错误发生在协议选择之前,和 OBClient/TNS 路径无关;请确认跳板机能访问该内网地址,并检查 SSH 配置、远端防火墙以及 OBProxy/OBServer 监听端口", address, err) + } + if err == nil { + return fmt.Errorf("OceanBase Oracle 连接失败:目标地址 %s TCP 不可达。该错误发生在协议选择之前,和 OBClient/TNS 路径无关;请确认客户端机器能访问该地址,并检查 VPN/内网路由、防火墙以及 OBProxy/OBServer 监听端口", address) + } + return fmt.Errorf("OceanBase Oracle 连接失败:目标地址 %s TCP 不可达:%w。该错误发生在协议选择之前,和 OBClient/TNS 路径无关;请确认客户端机器能访问该地址,并检查 VPN/内网路由、防火墙以及 OBProxy/OBServer 监听端口", address, err) +} + // probeOceanBaseMySQLWireHandshake 通过读取目标端口的 MySQL initial handshake packet // 判断该端口背后是否是 OceanBase 的 MySQL wire 协议端口。 // @@ -475,57 +506,65 @@ func annotateOceanBaseOracleConnectError(err error) error { // 4. server_version 是从 payload[1] 开始的 null-terminated 字符串 // 5. server_version 中包含 "oceanbase" / "ob" 关键字时判定为 OB MySQL wire // -// 返回值:(isOBMySQLWire, probeSucceeded)。probeSucceeded=false 表示建连/读包失败, -// 上层应该兜底执行真实连接尝试(OBClient 优先于 TNS)。 +// 返回值:(isOBMySQLWire, probeSucceeded)。probeSucceeded=false 表示建连或完整握手包读取失败。 +// Connect 使用 probeOceanBaseMySQLWireHandshakeDetail 区分 TCP 不可达与协议探测失败。 // // 容忍度设计: // - protocol_version 不严限(OB 自定义版本号也接受) // - payload 上限 64KB(OB 4.x 的 handshake 可能携带额外的能力位信息) // - 短超时(2s):探测只为方向选择,主流程的真实超时由 Connect 控制 func probeOceanBaseMySQLWireHandshake(host string, port int, timeout time.Duration) (bool, bool) { + result := probeOceanBaseMySQLWireHandshakeDetail(connection.ConnectionConfig{Host: host, Port: port}, timeout) + return result.isOBMySQLWire, result.probeSucceeded +} + +func probeOceanBaseMySQLWireHandshakeDetail(config connection.ConnectionConfig, timeout time.Duration) oceanBaseMySQLWireProbeResult { if timeout <= 0 { timeout = 2 * time.Second } - addr := normalizeMySQLAddress(host, port) - dialer := net.Dialer{Timeout: timeout} - conn, err := dialer.Dial("tcp", addr) + addr := normalizeMySQLAddress(config.Host, config.Port) + ctx, cancel := context.WithTimeout(context.Background(), timeout) + defer cancel() + conn, err := oceanBaseProbeDialContext(ctx, config, addr) if err != nil { - return false, false + return oceanBaseMySQLWireProbeResult{err: err} } defer conn.Close() _ = conn.SetDeadline(time.Now().Add(timeout)) header := make([]byte, 4) if _, err := io.ReadFull(conn, header); err != nil { - return false, false + // TCP 已经连通但服务端没有主动发送 MySQL handshake,通常是 Oracle TNS listener + // 或其它非 MySQL wire 协议端口。此时不能归因为网络不可达。 + return oceanBaseMySQLWireProbeResult{probeSucceeded: true, tcpReachable: true, err: err} } payloadLen := int(header[0]) | int(header[1])<<8 | int(header[2])<<16 // 放宽上限:OB 4.x handshake 可能携带额外 capability info。仍要约束以避免读取异常长度 if payloadLen < 1 || payloadLen > 65536 { - return false, true + return oceanBaseMySQLWireProbeResult{probeSucceeded: true, tcpReachable: true} } payload := make([]byte, payloadLen) if _, err := io.ReadFull(conn, payload); err != nil { - return false, false + return oceanBaseMySQLWireProbeResult{tcpReachable: true, err: err} } // 不再严格检查 protocol_version。OB 自定义版本号也认作 MySQL wire 候选—— // 只要 server_version 字符串含 OceanBase/OBProxy 关键字就足以做方向选择。 nullIdx := bytes.IndexByte(payload[1:], 0) if nullIdx < 0 { - return false, true + return oceanBaseMySQLWireProbeResult{probeSucceeded: true, tcpReachable: true} } serverVersion := strings.ToLower(string(payload[1 : 1+nullIdx])) if serverVersion == "" { - return false, true + return oceanBaseMySQLWireProbeResult{probeSucceeded: true, tcpReachable: true} } if strings.Contains(serverVersion, "oceanbase") || strings.Contains(serverVersion, "obproxy") { - return true, true + return oceanBaseMySQLWireProbeResult{isOBMySQLWire: true, probeSucceeded: true, tcpReachable: true} } if strings.Contains(serverVersion, "-ob") { - return true, true + return oceanBaseMySQLWireProbeResult{isOBMySQLWire: true, probeSucceeded: true, tcpReachable: true} } - return false, true + return oceanBaseMySQLWireProbeResult{probeSucceeded: true, tcpReachable: true} } // connectOracleViaTNS 走 sijms/go-ora,连 OBProxy 暴露的 Oracle listener 端口(标准 TNS)。 @@ -645,18 +684,21 @@ func (o *OceanBaseDB) Connect(config connection.ConnectionConfig) error { if probeTimeout > 3*time.Second { probeTimeout = 3 * time.Second } - isOBMySQLWire, probed := probeOceanBaseMySQLWireHandshake(runConfig.Host, runConfig.Port, probeTimeout) + probeResult := probeOceanBaseMySQLWireHandshakeDetail(runConfig, probeTimeout) switch { - case probed && isOBMySQLWire: + case probeResult.probeSucceeded && probeResult.isOBMySQLWire: // 明确识别为 OB MySQL wire 端口:直接走 OBClient capability 路径 logger.Infof("OceanBase 协议=Oracle 预探测:%s:%d 是 OB MySQL wire 端口,走 OBClient capability 注入路径连接 Oracle 租户", runConfig.Host, runConfig.Port) return o.connectOracleViaOBClient(runConfig) - case probed: + case probeResult.probeSucceeded: // 探测成功但 server_version 不含 OceanBase 标识:可能是真正的 Oracle TNS 端口 logger.Infof("OceanBase 协议=Oracle 预探测:%s:%d 不是 OB MySQL wire,走标准 Oracle TNS 协议(OBProxy Oracle listener)", runConfig.Host, runConfig.Port) return o.connectOracleViaTNS(runConfig) + case !probeResult.tcpReachable && probeResult.err != nil: + logger.Warnf("OceanBase 协议=Oracle 预探测建连失败:%s:%d,跳过 OBClient/TNS 重复尝试:%v", runConfig.Host, runConfig.Port, probeResult.err) + return formatOceanBaseOracleNetworkProbeError(runConfig, probeResult.err) default: - // 探测失败(建连或读 handshake 失败):可能是网络不通、防火墙阻断、或某些 OB 版本不主动发 handshake。 + // 探测失败但 TCP 已建连:可能是异常截断的握手包,或某些 OB 版本不主动发完整 handshake。 // 不能盲选 TNS——用户填 60014/2881 这类端口大概率仍是 OB MySQL wire。 // 串行尝试两条真实路径:先 OBClient(命中概率更高),失败再 TNS,合并错误信息。 logger.Warnf("OceanBase 协议=Oracle 预探测失败:%s:%d,串行尝试 OBClient capability 与 TNS 两条路径", runConfig.Host, runConfig.Port) diff --git a/internal/db/oceanbase_impl_test.go b/internal/db/oceanbase_impl_test.go index 3dc8b3f..97ddb9d 100644 --- a/internal/db/oceanbase_impl_test.go +++ b/internal/db/oceanbase_impl_test.go @@ -3,6 +3,7 @@ package db import ( + "context" "errors" "net" "net/url" @@ -376,7 +377,8 @@ func TestProbeOceanBaseMySQLWireDetectsOceanBaseHandshake(t *testing.T) { func TestProbeOceanBaseMySQLWireHandshakeReturnsFalseOnUnreachable(t *testing.T) { t.Parallel() - // 用一个不可达端口(监听后立即关闭),探测应返回 probed=false 让上层继续走 go-ora 路径 + // 用一个不可达端口(监听后立即关闭),探测应返回 probed=false, + // 上层会直接给出网络不可达诊断,避免 OBClient/TNS 两条路径重复超时。 ln, err := net.Listen("tcp", "127.0.0.1:0") if err != nil { t.Fatalf("listen failed: %v", err) @@ -390,7 +392,131 @@ func TestProbeOceanBaseMySQLWireHandshakeReturnsFalseOnUnreachable(t *testing.T) t.Fatal("expected unreachable port not flagged as OB") } if probed { - t.Fatal("expected probed=false on unreachable port so upper layer falls back to go-ora") + t.Fatal("expected probed=false on unreachable port so upper layer can return network diagnosis") + } +} + +func TestOceanBaseOracleConnectStopsOnProbeDialFailure(t *testing.T) { + t.Parallel() + + ln, err := net.Listen("tcp", "127.0.0.1:0") + if err != nil { + t.Fatalf("listen failed: %v", err) + } + host, portStr, _ := net.SplitHostPort(ln.Addr().String()) + port, _ := strconv.Atoi(portStr) + _ = ln.Close() + + ob := &OceanBaseDB{} + err = ob.Connect(connection.ConnectionConfig{ + Type: "oceanbase", + Host: host, + Port: port, + User: "SBDEV@SERVICE:srv_yhcs", + Password: "secret", + Database: "sbdev", + OceanBaseProtocol: oceanBaseProtocolOracle, + Timeout: 1, + }) + if err == nil { + t.Fatal("expected connect error for unreachable OceanBase Oracle endpoint") + } + got := err.Error() + if !strings.Contains(got, "TCP 不可达") { + t.Fatalf("expected direct TCP unreachable diagnosis, got %q", got) + } + if !strings.Contains(got, "和 OBClient/TNS 路径无关") { + t.Fatalf("expected error to explain protocol paths are irrelevant, got %q", got) + } + if strings.Contains(got, "两条连接路径均失败") { + t.Fatalf("expected no dual-path failure after probe dial failure, got %q", got) + } +} + +func TestOceanBaseOracleConnectProbeDialFailureMentionsSSHWhenEnabled(t *testing.T) { + originalDial := oceanBaseProbeDialContext + t.Cleanup(func() { oceanBaseProbeDialContext = originalDial }) + + var seenConfig connection.ConnectionConfig + var seenAddress string + oceanBaseProbeDialContext = func(ctx context.Context, config connection.ConnectionConfig, address string) (net.Conn, error) { + seenConfig = config + seenAddress = address + return nil, errors.New("remote dial denied") + } + + ob := &OceanBaseDB{} + err := ob.Connect(connection.ConnectionConfig{ + Type: "oceanbase", + Host: "172.22.39.20", + Port: 12883, + User: "SBDEV@SERVICE:srv_yhcs", + Password: "secret", + Database: "sbdev", + OceanBaseProtocol: oceanBaseProtocolOracle, + Timeout: 1, + UseSSH: true, + SSH: connection.SSHConfig{ + Host: "jump.example.com", + Port: 22, + User: "ops", + Password: "jump-secret", + }, + }) + if err == nil { + t.Fatal("expected connect error for SSH probe dial failure") + } + got := err.Error() + if !seenConfig.UseSSH { + t.Fatalf("expected probe dialer to receive UseSSH=true, got %+v", seenConfig) + } + if seenAddress != "172.22.39.20:12883" { + t.Fatalf("expected probe target to remain remote inner address, got %q", seenAddress) + } + if !strings.Contains(got, "通过 SSH 跳板机访问目标地址 172.22.39.20:12883 失败") { + t.Fatalf("expected SSH-specific network diagnosis, got %q", got) + } + if strings.Contains(got, "VPN/内网路由") { + t.Fatalf("expected SSH diagnosis not direct-client VPN hint, got %q", got) + } +} + +func TestProbeOceanBaseMySQLWireHandshakeUsesSSHConfiguredDialer(t *testing.T) { + originalDial := oceanBaseProbeDialContext + t.Cleanup(func() { oceanBaseProbeDialContext = originalDial }) + + var seenConfig connection.ConnectionConfig + var seenAddress string + oceanBaseProbeDialContext = func(ctx context.Context, config connection.ConnectionConfig, address string) (net.Conn, error) { + seenConfig = config + seenAddress = address + clientConn, serverConn := net.Pipe() + go func() { + defer serverConn.Close() + _, _ = serverConn.Write(buildMySQLHandshakePacket("5.7.25-OceanBase-v4.2.1.0")) + }() + return clientConn, nil + } + + result := probeOceanBaseMySQLWireHandshakeDetail(connection.ConnectionConfig{ + Host: "172.22.39.20", + Port: 12883, + UseSSH: true, + SSH: connection.SSHConfig{ + Host: "jump.example.com", + Port: 22, + User: "ops", + }, + }, time.Second) + + if !result.probeSucceeded || !result.isOBMySQLWire { + t.Fatalf("expected SSH-routed probe to detect OceanBase handshake, got %+v", result) + } + if !seenConfig.UseSSH { + t.Fatalf("expected probe dialer to receive SSH config, got %+v", seenConfig) + } + if seenAddress != "172.22.39.20:12883" { + t.Fatalf("expected remote target address through SSH, got %q", seenAddress) } } @@ -590,4 +716,3 @@ func TestFormatOceanBaseMySQLAttemptErrorHintsOracleProtocol(t *testing.T) { t.Fatalf("expected hint to mention OBProxy Oracle protocol port, got %q", got) } } - diff --git a/internal/ssh/ssh.go b/internal/ssh/ssh.go index 15feab3..c145e5a 100644 --- a/internal/ssh/ssh.go +++ b/internal/ssh/ssh.go @@ -116,6 +116,22 @@ func RegisterSSHNetwork(sshConfig connection.SSHConfig) (string, error) { return netName, nil } +// DialContextThroughSSH creates a context-aware connection through an SSH tunnel. +func DialContextThroughSSH(ctx context.Context, config connection.SSHConfig, network, address string) (net.Conn, error) { + client, err := GetOrCreateSSHClient(config) + if err != nil { + return nil, fmt.Errorf("建立 SSH 连接失败:%w", err) + } + + conn, err := dialContext(ctx, client, network, address) + if err != nil { + return nil, fmt.Errorf("通过 SSH 隧道连接到 %s 失败:%w", address, err) + } + + logger.Infof("已通过 SSH 隧道连接到:%s", address) + return conn, nil +} + // sshClientCache stores SSH clients to avoid creating multiple connections var ( sshClientCache = make(map[sshClientCacheKey]*ssh.Client)