🐛 fix(oceanbase): 修复 Oracle 租户 SSH 预探测超时

- 拆分 OceanBase Oracle 预探测的拨号超时与握手读取超时
- SSH 跳板机场景下使用完整连接超时,避免内网目标被误判不可达
- 保留 MySQL handshake 短读取超时,避免 TNS 端口测试连接变慢
- 补充 SSH 预探测超时与短读取行为回归测试
This commit is contained in:
Syngnat
2026-06-15 17:56:34 +08:00
parent 891c8c1200
commit 682017ba96
2 changed files with 109 additions and 12 deletions

View File

@@ -59,10 +59,11 @@ import (
)
const (
oceanbaseDriverName = "oceanbase"
defaultOceanBasePort = 2881
oceanBaseProtocolMySQL = "mysql"
oceanBaseProtocolOracle = "oracle"
oceanbaseDriverName = "oceanbase"
defaultOceanBasePort = 2881
oceanBaseProtocolMySQL = "mysql"
oceanBaseProtocolOracle = "oracle"
oceanBaseOracleProbeReadTimeout = 3 * time.Second
)
// OceanBaseDB 支持 OceanBase MySQL/Oracle 两种租户协议。
@@ -519,18 +520,25 @@ func probeOceanBaseMySQLWireHandshake(host string, port int, timeout time.Durati
}
func probeOceanBaseMySQLWireHandshakeDetail(config connection.ConnectionConfig, timeout time.Duration) oceanBaseMySQLWireProbeResult {
if timeout <= 0 {
timeout = 2 * time.Second
return probeOceanBaseMySQLWireHandshakeDetailWithTimeouts(config, timeout, timeout)
}
func probeOceanBaseMySQLWireHandshakeDetailWithTimeouts(config connection.ConnectionConfig, dialTimeout time.Duration, readTimeout time.Duration) oceanBaseMySQLWireProbeResult {
if dialTimeout <= 0 {
dialTimeout = 2 * time.Second
}
if readTimeout <= 0 {
readTimeout = dialTimeout
}
addr := normalizeMySQLAddress(config.Host, config.Port)
ctx, cancel := context.WithTimeout(context.Background(), timeout)
ctx, cancel := context.WithTimeout(context.Background(), dialTimeout)
defer cancel()
conn, err := oceanBaseProbeDialContext(ctx, config, addr)
if err != nil {
return oceanBaseMySQLWireProbeResult{err: err}
}
defer conn.Close()
_ = conn.SetDeadline(time.Now().Add(timeout))
_ = conn.SetDeadline(time.Now().Add(readTimeout))
header := make([]byte, 4)
if _, err := io.ReadFull(conn, header); err != nil {
@@ -680,11 +688,16 @@ func (o *OceanBaseDB) Connect(config connection.ConnectionConfig) error {
if protocol == oceanBaseProtocolOracle {
// 预探测目标端口的实际协议,决定走哪条 Oracle 连接路径。
probeTimeout := getConnectTimeout(runConfig)
if probeTimeout > 3*time.Second {
probeTimeout = 3 * time.Second
// SSH 跳板机到内网目标的 direct-tcpip 拨号可能慢于 3 秒;只收紧握手读取超时,避免误判内网目标不可达。
probeDialTimeout := getConnectTimeout(runConfig)
if !runConfig.UseSSH && probeDialTimeout > oceanBaseOracleProbeReadTimeout {
probeDialTimeout = oceanBaseOracleProbeReadTimeout
}
probeResult := probeOceanBaseMySQLWireHandshakeDetail(runConfig, probeTimeout)
probeReadTimeout := oceanBaseOracleProbeReadTimeout
if probeReadTimeout > probeDialTimeout {
probeReadTimeout = probeDialTimeout
}
probeResult := probeOceanBaseMySQLWireHandshakeDetailWithTimeouts(runConfig, probeDialTimeout, probeReadTimeout)
switch {
case probeResult.probeSucceeded && probeResult.isOBMySQLWire:
// 明确识别为 OB MySQL wire 端口:直接走 OBClient capability 路径

View File

@@ -520,6 +520,90 @@ func TestProbeOceanBaseMySQLWireHandshakeUsesSSHConfiguredDialer(t *testing.T) {
}
}
func TestOceanBaseOracleConnectUsesFullSSHTimeoutForProbeDial(t *testing.T) {
originalDial := oceanBaseProbeDialContext
t.Cleanup(func() { oceanBaseProbeDialContext = originalDial })
var observedDialTimeout time.Duration
oceanBaseProbeDialContext = func(ctx context.Context, config connection.ConnectionConfig, address string) (net.Conn, error) {
if deadline, ok := ctx.Deadline(); ok {
observedDialTimeout = time.Until(deadline)
}
return nil, errors.New("remote dial denied")
}
ob := &OceanBaseDB{}
err := ob.Connect(connection.ConnectionConfig{
Type: "oceanbase",
Host: "172.22.39.20",
Port: 12883,
User: "SBDEV@SERVICE:srv_yhcs",
Password: "secret",
Database: "srv_yhcs",
OceanBaseProtocol: oceanBaseProtocolOracle,
Timeout: 12,
UseSSH: true,
SSH: connection.SSHConfig{
Host: "jump.example.com",
Port: 22,
User: "ops",
},
})
if err == nil {
t.Fatal("expected connect error from mocked probe dialer")
}
if observedDialTimeout < 10*time.Second {
t.Fatalf("expected SSH probe dial to use the full configured timeout, got about %s", observedDialTimeout)
}
}
func TestProbeOceanBaseMySQLWireHandshakeSplitsDialAndReadTimeout(t *testing.T) {
originalDial := oceanBaseProbeDialContext
t.Cleanup(func() { oceanBaseProbeDialContext = originalDial })
var observedDialTimeout time.Duration
var serverConn net.Conn
oceanBaseProbeDialContext = func(ctx context.Context, config connection.ConnectionConfig, address string) (net.Conn, error) {
if deadline, ok := ctx.Deadline(); ok {
observedDialTimeout = time.Until(deadline)
}
clientConn, remoteConn := net.Pipe()
serverConn = remoteConn
return clientConn, nil
}
t.Cleanup(func() {
if serverConn != nil {
_ = serverConn.Close()
}
})
started := time.Now()
result := probeOceanBaseMySQLWireHandshakeDetailWithTimeouts(connection.ConnectionConfig{
Host: "172.22.39.20",
Port: 12883,
UseSSH: true,
SSH: connection.SSHConfig{
Host: "jump.example.com",
Port: 22,
User: "ops",
},
}, 12*time.Second, 50*time.Millisecond)
elapsed := time.Since(started)
if observedDialTimeout < 10*time.Second {
t.Fatalf("expected probe dial context to keep the long dial timeout, got about %s", observedDialTimeout)
}
if elapsed > 500*time.Millisecond {
t.Fatalf("expected handshake read to use short timeout, elapsed=%s", elapsed)
}
if !result.probeSucceeded || !result.tcpReachable {
t.Fatalf("expected short read timeout to be treated as reachable non-MySQL-wire probe, got %+v", result)
}
if result.err == nil {
t.Fatalf("expected read timeout error to be recorded for diagnostics")
}
}
// probe 放宽 protocol_version 检查后,普通 MySQL/MariaDBserver_version 不含 OB 关键字)
// 应仍判定为非 OB MySQL wire由 regular_mysql_is_not_flagged / mariadb_is_not_flagged 子用例
// 覆盖)。原 IgnoresNonMySQLProtocol 测试因 probe 不再严格区分 mysql vs 非 mysql 而失效,已删除。