🐛 fix(oceanbase): 增强 Oracle 协议连接校验与诊断

- 运行时校验可选 driver-agent revision,避免旧代理继续被复用
- OceanBase agent revision 纳入 oracle_impl.go 指纹并重新生成
- OceanBase Oracle 保留 URI 中的 Oracle 连接参数
- Oracle DSN 默认写入连接和读取超时,并输出脱敏诊断摘要
- 补充 revision、Oracle DSN、OceanBase Oracle 参数提升测试
This commit is contained in:
Syngnat
2026-05-14 10:30:17 +08:00
parent 6456658576
commit 527ecd37e1
9 changed files with 253 additions and 6 deletions

View File

@@ -543,6 +543,9 @@ func (a *App) openDatabaseIsolated(config connection.ConnectionConfig) (db.Datab
}
return nil, withLogHint{err: fmt.Errorf("%s", reason), logPath: logger.Path()}
}
if revisionErr := verifyRuntimeOptionalDriverAgentRevision(effectiveConfig); revisionErr != nil {
return nil, withLogHint{err: revisionErr, logPath: logger.Path()}
}
dbInst, err := newDatabaseFunc(effectiveConfig.Type)
if err != nil {
@@ -655,6 +658,9 @@ func (a *App) getDatabaseWithPing(config connection.ConnectionConfig, forcePing
formatConnSummary(effectiveConfig), shortKey, formatConnectFailureCooldown(remaining), normalizeErrorMessage(failure.err))
return nil, withLogHint{err: fmt.Errorf("%s", message), logPath: logger.Path()}
}
if revisionErr := verifyRuntimeOptionalDriverAgentRevision(effectiveConfig); revisionErr != nil {
return nil, withLogHint{err: revisionErr, logPath: logger.Path()}
}
initialKey := key
dbInst, connectedConfig, err := a.connectDatabaseWithStartupRetry(resolvedConfig)
@@ -744,6 +750,32 @@ func formatConnectFailureCooldown(remaining time.Duration) time.Duration {
return remaining.Truncate(time.Second)
}
func verifyRuntimeOptionalDriverAgentRevision(config connection.ConnectionConfig) error {
driverType := normalizeDriverType(config.Type)
if !db.IsOptionalGoDriver(driverType) {
return nil
}
executablePath, err := db.ResolveOptionalDriverAgentExecutablePath("", driverType)
if err != nil {
return err
}
pkg, packageMetaExists := readInstalledDriverPackage("", driverType)
selectedVersion := ""
if packageMetaExists {
selectedVersion = strings.TrimSpace(pkg.Version)
}
agentRevision, err := verifyInstalledOptionalDriverAgentRevision(driverType, executablePath, selectedVersion)
if err != nil {
return err
}
if expectedRevision := strings.TrimSpace(db.OptionalDriverAgentRevision(driverType)); expectedRevision != "" {
displayName := resolveDriverDisplayName(driverDefinition{Type: driverType})
logger.Infof("%s driver-agent revision 校验通过:已安装=%s 当前需要=%s version=%s path=%s",
displayName, strings.TrimSpace(agentRevision), expectedRevision, selectedVersion, executablePath)
}
return nil
}
func shortenCacheKey(key string) string {
if len(key) > 12 {
return key[:12]

View File

@@ -53,6 +53,7 @@ var _ db.Database = (*fakeCreateDatabaseDB)(nil)
func TestResolveDDLDBType_SQLServerAliases(t *testing.T) {
tests := []connection.ConnectionConfig{
{Type: "sqlserver"},
{Type: "mssql"},
{Type: "sql_server"},
{Type: "custom", Driver: "mssql"},
@@ -95,7 +96,8 @@ func TestCreateDatabase_SQLServerUsesBracketIdentifiers(t *testing.T) {
app := NewAppWithSecretStore(secretstore.NewUnavailableStore("test"))
result := app.CreateDatabase(connection.ConnectionConfig{
Type: "sqlserver",
Type: "custom",
Driver: "mssql",
Database: "master",
}, "lg")

View File

@@ -79,6 +79,49 @@ func TestVerifyInstalledOptionalDriverAgentRevisionRejectsProbeFailure(t *testin
}
}
func TestVerifyRuntimeOptionalDriverAgentRevisionRejectsStaleOceanBaseAgent(t *testing.T) {
originalProbe := optionalDriverAgentMetadataProbe
t.Cleanup(func() {
optionalDriverAgentMetadataProbe = originalProbe
})
optionalDriverAgentMetadataProbe = func(driverType string, executablePath string) (db.OptionalDriverAgentMetadata, error) {
return db.OptionalDriverAgentMetadata{
DriverType: driverType,
AgentRevision: "src-stale-agent",
}, nil
}
err := verifyRuntimeOptionalDriverAgentRevision(connection.ConnectionConfig{Type: "oceanbase"})
if err == nil {
t.Fatal("expected stale OceanBase agent revision to be rejected")
}
if !strings.Contains(err.Error(), "revision 不匹配") {
t.Fatalf("expected revision mismatch error, got %v", err)
}
}
func TestVerifyRuntimeOptionalDriverAgentRevisionSkipsCustomDriver(t *testing.T) {
originalProbe := optionalDriverAgentMetadataProbe
t.Cleanup(func() {
optionalDriverAgentMetadataProbe = originalProbe
})
calls := 0
optionalDriverAgentMetadataProbe = func(driverType string, executablePath string) (db.OptionalDriverAgentMetadata, error) {
calls++
return db.OptionalDriverAgentMetadata{}, nil
}
if err := verifyRuntimeOptionalDriverAgentRevision(connection.ConnectionConfig{
Type: "custom",
Driver: "oceanbase",
}); err != nil {
t.Fatalf("custom driver should skip optional agent runtime revision check: %v", err)
}
if calls != 0 {
t.Fatalf("custom driver should not probe optional agent metadata, got %d calls", calls)
}
}
func optionalDriverAgentRevisionTestDrivers(t *testing.T) []string {
t.Helper()
drivers := []string{

View File

@@ -5,7 +5,7 @@ package db
func init() {
optionalDriverAgentRevisions = map[string]string{
"mariadb": "src-1a1cc64f8f92d92b",
"oceanbase": "src-ac051813e2451265",
"oceanbase": "src-5bcb757b1b85d41e",
"diros": "src-bcc78fa43671ade5",
"sphinx": "src-404765c2fda68c5f",
"sqlserver": "src-d9fba1eca0a27c49",

View File

@@ -248,6 +248,33 @@ func withoutOceanBaseProtocolParams(config connection.ConnectionConfig) connecti
return next
}
func promoteOceanBaseOracleURIParams(config connection.ConnectionConfig) connection.ConnectionConfig {
uriParams := connectionParamsFromURI(config.URI, "oceanbase", "mysql")
if len(uriParams) == 0 {
return config
}
for _, key := range []string{"protocol", "oceanBaseProtocol", "oceanbaseProtocol", "tenantMode", "compatMode", "mode"} {
uriParams.Del(key)
}
if len(uriParams) == 0 {
return config
}
merged := url.Values{}
mergeConnectionParamValuesWithAllowlist(merged, uriParams, oracleConnectionParamNames)
mergeConnectionParamValuesWithAllowlist(merged, connectionParamsFromText(config.ConnectionParams), oracleConnectionParamNames)
config.ConnectionParams = merged.Encode()
return config
}
func prepareOceanBaseOracleConfig(config connection.ConnectionConfig) connection.ConnectionConfig {
runConfig := withoutOceanBaseProtocolParams(applyOceanBaseURI(config))
runConfig = promoteOceanBaseOracleURIParams(runConfig)
runConfig.Type = "oracle"
// OracleDB 不解析 oceanbase:// URI。连接要素已落到结构化字段和 ConnectionParams。
runConfig.URI = ""
return runConfig
}
func isOceanBaseOracleTenantMySQLDriverError(err error) bool {
if err == nil {
return false
@@ -264,8 +291,7 @@ func formatOceanBaseMySQLAttemptError(address string, err error) string {
}
func (o *OceanBaseDB) connectOracle(config connection.ConnectionConfig) error {
runConfig := withoutOceanBaseProtocolParams(applyOceanBaseURI(config))
runConfig.Type = "oracle"
runConfig := prepareOceanBaseOracleConfig(config)
if strings.TrimSpace(runConfig.Database) == "" {
return fmt.Errorf("OceanBase Oracle 协议需要填写服务名Service Name请在连接配置中填写租户监听的服务名")
}

View File

@@ -148,6 +148,36 @@ func TestWithoutOceanBaseProtocolParamsStripsDriverMeta(t *testing.T) {
}
}
func TestPrepareOceanBaseOracleConfigPromotesURIParams(t *testing.T) {
t.Parallel()
config := prepareOceanBaseOracleConfig(connection.ConnectionConfig{
Type: "oceanbase",
URI: "oceanbase://sys%40oracle001:pass@127.0.0.1:2881/ORCL?protocol=oracle&CONNECT_TIMEOUT=12&DBA_PRIVILEGE=SYSDBA",
ConnectionParams: "protocol=oracle&READ_TIMEOUT=7",
})
if config.Type != "oracle" {
t.Fatalf("expected routed type oracle, got %q", config.Type)
}
if config.URI != "" {
t.Fatalf("expected routed Oracle config to clear oceanbase URI, got %q", config.URI)
}
params := connectionParamsFromText(config.ConnectionParams)
if got := params.Get("CONNECT TIMEOUT"); got != "12" {
t.Fatalf("expected URI CONNECT_TIMEOUT promoted, got %q in %q", got, config.ConnectionParams)
}
if got := params.Get("READ TIMEOUT"); got != "7" {
t.Fatalf("expected explicit READ_TIMEOUT kept, got %q in %q", got, config.ConnectionParams)
}
if got := params.Get("DBA PRIVILEGE"); got != "SYSDBA" {
t.Fatalf("expected URI DBA_PRIVILEGE promoted, got %q in %q", got, config.ConnectionParams)
}
if strings.Contains(config.ConnectionParams, "protocol=") {
t.Fatalf("expected OceanBase protocol param stripped, got %q", config.ConnectionParams)
}
}
func TestOceanBaseOracleRequiresServiceName(t *testing.T) {
t.Parallel()

View File

@@ -1,7 +1,9 @@
package db
import (
"errors"
"net/url"
"strings"
"testing"
"GoNavi-Wails/internal/connection"
@@ -31,6 +33,31 @@ func TestOracleGetDSNIncludesQueryPerformanceOptions(t *testing.T) {
}
}
func TestOracleGetDSNIncludesTimeoutDefaults(t *testing.T) {
t.Parallel()
dsn := (&OracleDB{}).getDSN(connection.ConnectionConfig{
Host: "db.example.com",
Port: 1521,
User: "scott",
Password: "tiger",
Database: "ORCLPDB1",
Timeout: 12,
})
parsed, err := url.Parse(dsn)
if err != nil {
t.Fatalf("解析 Oracle DSN 失败: %v", err)
}
query := parsed.Query()
if got := query.Get("CONNECT TIMEOUT"); got != "12" {
t.Fatalf("CONNECT TIMEOUT = %q, want 12", got)
}
if got := query.Get("READ TIMEOUT"); got != "12" {
t.Fatalf("READ TIMEOUT = %q, want 12", got)
}
}
func TestOracleGetDSNMergesConnectionParams(t *testing.T) {
t.Parallel()
@@ -40,7 +67,7 @@ func TestOracleGetDSNMergesConnectionParams(t *testing.T) {
User: "scott",
Password: "tiger",
Database: "ORCLPDB1",
ConnectionParams: "PREFETCH_ROWS=5000&TRACE FILE=/tmp/go-ora.trc&connect_timeout=10&FAILOVER=3&unknown=bad",
ConnectionParams: "PREFETCH_ROWS=5000&TRACE FILE=/tmp/go-ora.trc&connect_timeout=10&read_timeout=7&FAILOVER=3&unknown=bad",
})
parsed, err := url.Parse(dsn)
@@ -57,6 +84,9 @@ func TestOracleGetDSNMergesConnectionParams(t *testing.T) {
if got := query.Get("CONNECT TIMEOUT"); got != "10" {
t.Fatalf("CONNECT TIMEOUT = %q, want 10", got)
}
if got := query.Get("READ TIMEOUT"); got != "7" {
t.Fatalf("READ TIMEOUT = %q, want 7", got)
}
if got := query.Get("FAILOVER"); got != "" {
t.Fatalf("FAILOVER should be filtered because go-ora no longer supports it, got %q", got)
}
@@ -64,3 +94,35 @@ func TestOracleGetDSNMergesConnectionParams(t *testing.T) {
t.Fatalf("unknown should be filtered, got %q", got)
}
}
func TestOracleDSNLogSummaryDoesNotExposePassword(t *testing.T) {
t.Parallel()
dsn := (&OracleDB{}).getDSN(connection.ConnectionConfig{
Host: "db.example.com",
Port: 1521,
User: "sys@tenant",
Password: "top-secret",
Database: "ORCLPDB1",
ConnectionParams: "DBA_PRIVILEGE=SYSDBA&AUTH_TYPE=NORMAL",
})
got := oracleDSNLogSummary(connection.ConnectionConfig{Database: "ORCLPDB1"}, dsn)
if strings.Contains(got, "top-secret") || strings.Contains(got, "sys@tenant") {
t.Fatalf("summary should not expose credentials, got %q", got)
}
for _, want := range []string{"服务名=ORCLPDB1", "DBA_PRIVILEGE=SYSDBA", "AUTH_TYPE=NORMAL"} {
if !strings.Contains(got, want) {
t.Fatalf("expected summary to contain %q, got %q", want, got)
}
}
}
func TestAnnotateOracleValidationErrorAddsClosedConnectionHint(t *testing.T) {
t.Parallel()
err := annotateOracleValidationError(errors.New("read tcp 127.0.0.1:1->127.0.0.1:2: use of closed network connection"))
if err == nil || !strings.Contains(err.Error(), "Service Name") {
t.Fatalf("expected closed connection hint, got %v", err)
}
}

View File

@@ -48,6 +48,9 @@ func (o *OracleDB) getDSN(config connection.ConnectionConfig) string {
q.Set("PREFETCH_ROWS", "10000")
// LOB 数据延迟加载,避免大 LOB 列影响普通查询性能
q.Set("LOB FETCH", "POST")
timeoutSeconds := strconv.Itoa(getConnectTimeoutSeconds(config))
q.Set("CONNECT TIMEOUT", timeoutSeconds)
q.Set("READ TIMEOUT", timeoutSeconds)
mergeConnectionParamsFromConfigWithAllowlist(q, config, oracleConnectionParamNames, "oracle")
if encoded := q.Encode(); encoded != "" {
u.RawQuery = encoded
@@ -55,6 +58,53 @@ func (o *OracleDB) getDSN(config connection.ConnectionConfig) string {
return u.String()
}
func oracleQueryValue(values url.Values, key string) string {
return strings.TrimSpace(values.Get(key))
}
func oracleQueryValueOrDefault(values url.Values, key string) string {
value := oracleQueryValue(values, key)
if value == "" {
return "未配置"
}
return value
}
func oracleDSNLogSummary(config connection.ConnectionConfig, dsn string) string {
serviceName := strings.TrimSpace(config.Database)
params := url.Values{}
if parsed, err := url.Parse(dsn); err == nil && parsed != nil {
if pathService, unescapeErr := url.PathUnescape(strings.TrimPrefix(parsed.EscapedPath(), "/")); unescapeErr == nil && strings.TrimSpace(pathService) != "" {
serviceName = strings.TrimSpace(pathService)
}
params = parsed.Query()
}
if serviceName == "" {
serviceName = "(未配置)"
}
return fmt.Sprintf("服务名=%s CONNECT_TIMEOUT=%s READ_TIMEOUT=%s SSL=%s SSL_VERIFY=%s AUTH_TYPE=%s DBA_PRIVILEGE=%s SID=%s",
serviceName,
oracleQueryValueOrDefault(params, "CONNECT TIMEOUT"),
oracleQueryValueOrDefault(params, "READ TIMEOUT"),
oracleQueryValueOrDefault(params, "SSL"),
oracleQueryValueOrDefault(params, "SSL VERIFY"),
oracleQueryValueOrDefault(params, "AUTH TYPE"),
oracleQueryValueOrDefault(params, "DBA PRIVILEGE"),
oracleQueryValueOrDefault(params, "SID"),
)
}
func annotateOracleValidationError(err error) error {
if err == nil {
return nil
}
message := strings.ToLower(err.Error())
if !strings.Contains(message, "use of closed network connection") {
return err
}
return fmt.Errorf("%wOracle 连接在验证阶段被服务端关闭或被驱动超时中断;请检查监听端口是否为 Oracle 协议端口、Service Name 是否正确、认证参数如 DBA_PRIVILEGE/AUTH_TYPE 是否匹配)", err)
}
func (o *OracleDB) Connect(config connection.ConnectionConfig) error {
runConfig := config
serviceName := strings.TrimSpace(config.Database)
@@ -101,6 +151,7 @@ func (o *OracleDB) Connect(config connection.ConnectionConfig) error {
var failures []string
for idx, attempt := range attempts {
dsn := o.getDSN(attempt)
logger.Infof("Oracle 连接参数摘要:地址=%s:%d 用户=%s %s", attempt.Host, attempt.Port, attempt.User, oracleDSNLogSummary(attempt, dsn))
db, err := sql.Open("oracle", dsn)
if err != nil {
failures = append(failures, fmt.Sprintf("第%d次连接打开失败: %v", idx+1, err))
@@ -111,7 +162,7 @@ func (o *OracleDB) Connect(config connection.ConnectionConfig) error {
if err := o.Ping(); err != nil {
_ = db.Close()
o.conn = nil
failures = append(failures, fmt.Sprintf("第%d次连接验证失败: %v", idx+1, err))
failures = append(failures, fmt.Sprintf("第%d次连接验证失败: %v", idx+1, annotateOracleValidationError(err)))
continue
}
if idx > 0 {

View File

@@ -93,6 +93,7 @@ internal/db/timeout.go)
case "$driver:$identity" in
mariadb:internal/db/mariadb_impl.go|\
oceanbase:internal/db/oceanbase_impl.go|\
oceanbase:internal/db/oracle_impl.go|\
oceanbase:internal/db/mysql_impl.go|\
diros:internal/db/diros_impl.go|\
diros:internal/db/mysql_impl.go|\