From 682017ba960c67352b896eedd76696008adbea0b Mon Sep 17 00:00:00 2001 From: Syngnat Date: Mon, 15 Jun 2026 17:56:34 +0800 Subject: [PATCH] =?UTF-8?q?=F0=9F=90=9B=20fix(oceanbase):=20=E4=BF=AE?= =?UTF-8?q?=E5=A4=8D=20Oracle=20=E7=A7=9F=E6=88=B7=20SSH=20=E9=A2=84?= =?UTF-8?q?=E6=8E=A2=E6=B5=8B=E8=B6=85=E6=97=B6?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - 拆分 OceanBase Oracle 预探测的拨号超时与握手读取超时 - SSH 跳板机场景下使用完整连接超时,避免内网目标被误判不可达 - 保留 MySQL handshake 短读取超时,避免 TNS 端口测试连接变慢 - 补充 SSH 预探测超时与短读取行为回归测试 --- internal/db/oceanbase_impl.go | 37 ++++++++----- internal/db/oceanbase_impl_test.go | 84 ++++++++++++++++++++++++++++++ 2 files changed, 109 insertions(+), 12 deletions(-) diff --git a/internal/db/oceanbase_impl.go b/internal/db/oceanbase_impl.go index 3beb278..9234f2b 100644 --- a/internal/db/oceanbase_impl.go +++ b/internal/db/oceanbase_impl.go @@ -59,10 +59,11 @@ import ( ) const ( - oceanbaseDriverName = "oceanbase" - defaultOceanBasePort = 2881 - oceanBaseProtocolMySQL = "mysql" - oceanBaseProtocolOracle = "oracle" + oceanbaseDriverName = "oceanbase" + defaultOceanBasePort = 2881 + oceanBaseProtocolMySQL = "mysql" + oceanBaseProtocolOracle = "oracle" + oceanBaseOracleProbeReadTimeout = 3 * time.Second ) // OceanBaseDB 支持 OceanBase MySQL/Oracle 两种租户协议。 @@ -519,18 +520,25 @@ func probeOceanBaseMySQLWireHandshake(host string, port int, timeout time.Durati } func probeOceanBaseMySQLWireHandshakeDetail(config connection.ConnectionConfig, timeout time.Duration) oceanBaseMySQLWireProbeResult { - if timeout <= 0 { - timeout = 2 * time.Second + return probeOceanBaseMySQLWireHandshakeDetailWithTimeouts(config, timeout, timeout) +} + +func probeOceanBaseMySQLWireHandshakeDetailWithTimeouts(config connection.ConnectionConfig, dialTimeout time.Duration, readTimeout time.Duration) oceanBaseMySQLWireProbeResult { + if dialTimeout <= 0 { + dialTimeout = 2 * time.Second + } + if readTimeout <= 0 { + readTimeout = dialTimeout } addr := normalizeMySQLAddress(config.Host, config.Port) - ctx, cancel := context.WithTimeout(context.Background(), timeout) + ctx, cancel := context.WithTimeout(context.Background(), dialTimeout) defer cancel() conn, err := oceanBaseProbeDialContext(ctx, config, addr) if err != nil { return oceanBaseMySQLWireProbeResult{err: err} } defer conn.Close() - _ = conn.SetDeadline(time.Now().Add(timeout)) + _ = conn.SetDeadline(time.Now().Add(readTimeout)) header := make([]byte, 4) if _, err := io.ReadFull(conn, header); err != nil { @@ -680,11 +688,16 @@ func (o *OceanBaseDB) Connect(config connection.ConnectionConfig) error { if protocol == oceanBaseProtocolOracle { // 预探测目标端口的实际协议,决定走哪条 Oracle 连接路径。 - probeTimeout := getConnectTimeout(runConfig) - if probeTimeout > 3*time.Second { - probeTimeout = 3 * time.Second + // SSH 跳板机到内网目标的 direct-tcpip 拨号可能慢于 3 秒;只收紧握手读取超时,避免误判内网目标不可达。 + probeDialTimeout := getConnectTimeout(runConfig) + if !runConfig.UseSSH && probeDialTimeout > oceanBaseOracleProbeReadTimeout { + probeDialTimeout = oceanBaseOracleProbeReadTimeout } - probeResult := probeOceanBaseMySQLWireHandshakeDetail(runConfig, probeTimeout) + probeReadTimeout := oceanBaseOracleProbeReadTimeout + if probeReadTimeout > probeDialTimeout { + probeReadTimeout = probeDialTimeout + } + probeResult := probeOceanBaseMySQLWireHandshakeDetailWithTimeouts(runConfig, probeDialTimeout, probeReadTimeout) switch { case probeResult.probeSucceeded && probeResult.isOBMySQLWire: // 明确识别为 OB MySQL wire 端口:直接走 OBClient capability 路径 diff --git a/internal/db/oceanbase_impl_test.go b/internal/db/oceanbase_impl_test.go index 97ddb9d..baa6e7b 100644 --- a/internal/db/oceanbase_impl_test.go +++ b/internal/db/oceanbase_impl_test.go @@ -520,6 +520,90 @@ func TestProbeOceanBaseMySQLWireHandshakeUsesSSHConfiguredDialer(t *testing.T) { } } +func TestOceanBaseOracleConnectUsesFullSSHTimeoutForProbeDial(t *testing.T) { + originalDial := oceanBaseProbeDialContext + t.Cleanup(func() { oceanBaseProbeDialContext = originalDial }) + + var observedDialTimeout time.Duration + oceanBaseProbeDialContext = func(ctx context.Context, config connection.ConnectionConfig, address string) (net.Conn, error) { + if deadline, ok := ctx.Deadline(); ok { + observedDialTimeout = time.Until(deadline) + } + return nil, errors.New("remote dial denied") + } + + ob := &OceanBaseDB{} + err := ob.Connect(connection.ConnectionConfig{ + Type: "oceanbase", + Host: "172.22.39.20", + Port: 12883, + User: "SBDEV@SERVICE:srv_yhcs", + Password: "secret", + Database: "srv_yhcs", + OceanBaseProtocol: oceanBaseProtocolOracle, + Timeout: 12, + UseSSH: true, + SSH: connection.SSHConfig{ + Host: "jump.example.com", + Port: 22, + User: "ops", + }, + }) + if err == nil { + t.Fatal("expected connect error from mocked probe dialer") + } + if observedDialTimeout < 10*time.Second { + t.Fatalf("expected SSH probe dial to use the full configured timeout, got about %s", observedDialTimeout) + } +} + +func TestProbeOceanBaseMySQLWireHandshakeSplitsDialAndReadTimeout(t *testing.T) { + originalDial := oceanBaseProbeDialContext + t.Cleanup(func() { oceanBaseProbeDialContext = originalDial }) + + var observedDialTimeout time.Duration + var serverConn net.Conn + oceanBaseProbeDialContext = func(ctx context.Context, config connection.ConnectionConfig, address string) (net.Conn, error) { + if deadline, ok := ctx.Deadline(); ok { + observedDialTimeout = time.Until(deadline) + } + clientConn, remoteConn := net.Pipe() + serverConn = remoteConn + return clientConn, nil + } + t.Cleanup(func() { + if serverConn != nil { + _ = serverConn.Close() + } + }) + + started := time.Now() + result := probeOceanBaseMySQLWireHandshakeDetailWithTimeouts(connection.ConnectionConfig{ + Host: "172.22.39.20", + Port: 12883, + UseSSH: true, + SSH: connection.SSHConfig{ + Host: "jump.example.com", + Port: 22, + User: "ops", + }, + }, 12*time.Second, 50*time.Millisecond) + elapsed := time.Since(started) + + if observedDialTimeout < 10*time.Second { + t.Fatalf("expected probe dial context to keep the long dial timeout, got about %s", observedDialTimeout) + } + if elapsed > 500*time.Millisecond { + t.Fatalf("expected handshake read to use short timeout, elapsed=%s", elapsed) + } + if !result.probeSucceeded || !result.tcpReachable { + t.Fatalf("expected short read timeout to be treated as reachable non-MySQL-wire probe, got %+v", result) + } + if result.err == nil { + t.Fatalf("expected read timeout error to be recorded for diagnostics") + } +} + // probe 放宽 protocol_version 检查后,普通 MySQL/MariaDB(server_version 不含 OB 关键字) // 应仍判定为非 OB MySQL wire(由 regular_mysql_is_not_flagged / mariadb_is_not_flagged 子用例 // 覆盖)。原 IgnoresNonMySQLProtocol 测试因 probe 不再严格区分 mysql vs 非 mysql 而失效,已删除。