diff --git a/internal/app/app.go b/internal/app/app.go index 5b37396..5c2de4f 100644 --- a/internal/app/app.go +++ b/internal/app/app.go @@ -543,6 +543,9 @@ func (a *App) openDatabaseIsolated(config connection.ConnectionConfig) (db.Datab } return nil, withLogHint{err: fmt.Errorf("%s", reason), logPath: logger.Path()} } + if revisionErr := verifyRuntimeOptionalDriverAgentRevision(effectiveConfig); revisionErr != nil { + return nil, withLogHint{err: revisionErr, logPath: logger.Path()} + } dbInst, err := newDatabaseFunc(effectiveConfig.Type) if err != nil { @@ -655,6 +658,9 @@ func (a *App) getDatabaseWithPing(config connection.ConnectionConfig, forcePing formatConnSummary(effectiveConfig), shortKey, formatConnectFailureCooldown(remaining), normalizeErrorMessage(failure.err)) return nil, withLogHint{err: fmt.Errorf("%s", message), logPath: logger.Path()} } + if revisionErr := verifyRuntimeOptionalDriverAgentRevision(effectiveConfig); revisionErr != nil { + return nil, withLogHint{err: revisionErr, logPath: logger.Path()} + } initialKey := key dbInst, connectedConfig, err := a.connectDatabaseWithStartupRetry(resolvedConfig) @@ -744,6 +750,32 @@ func formatConnectFailureCooldown(remaining time.Duration) time.Duration { return remaining.Truncate(time.Second) } +func verifyRuntimeOptionalDriverAgentRevision(config connection.ConnectionConfig) error { + driverType := normalizeDriverType(config.Type) + if !db.IsOptionalGoDriver(driverType) { + return nil + } + executablePath, err := db.ResolveOptionalDriverAgentExecutablePath("", driverType) + if err != nil { + return err + } + pkg, packageMetaExists := readInstalledDriverPackage("", driverType) + selectedVersion := "" + if packageMetaExists { + selectedVersion = strings.TrimSpace(pkg.Version) + } + agentRevision, err := verifyInstalledOptionalDriverAgentRevision(driverType, executablePath, selectedVersion) + if err != nil { + return err + } + if expectedRevision := strings.TrimSpace(db.OptionalDriverAgentRevision(driverType)); expectedRevision != "" { + displayName := resolveDriverDisplayName(driverDefinition{Type: driverType}) + logger.Infof("%s driver-agent revision 校验通过:已安装=%s 当前需要=%s version=%s path=%s", + displayName, strings.TrimSpace(agentRevision), expectedRevision, selectedVersion, executablePath) + } + return nil +} + func shortenCacheKey(key string) string { if len(key) > 12 { return key[:12] diff --git a/internal/app/methods_db_create_test.go b/internal/app/methods_db_create_test.go index cad9ab1..f7c4768 100644 --- a/internal/app/methods_db_create_test.go +++ b/internal/app/methods_db_create_test.go @@ -53,6 +53,7 @@ var _ db.Database = (*fakeCreateDatabaseDB)(nil) func TestResolveDDLDBType_SQLServerAliases(t *testing.T) { tests := []connection.ConnectionConfig{ + {Type: "sqlserver"}, {Type: "mssql"}, {Type: "sql_server"}, {Type: "custom", Driver: "mssql"}, @@ -95,7 +96,8 @@ func TestCreateDatabase_SQLServerUsesBracketIdentifiers(t *testing.T) { app := NewAppWithSecretStore(secretstore.NewUnavailableStore("test")) result := app.CreateDatabase(connection.ConnectionConfig{ - Type: "sqlserver", + Type: "custom", + Driver: "mssql", Database: "master", }, "lg") diff --git a/internal/app/methods_driver_agent_revision_test.go b/internal/app/methods_driver_agent_revision_test.go index 7c48833..22423f8 100644 --- a/internal/app/methods_driver_agent_revision_test.go +++ b/internal/app/methods_driver_agent_revision_test.go @@ -79,6 +79,49 @@ func TestVerifyInstalledOptionalDriverAgentRevisionRejectsProbeFailure(t *testin } } +func TestVerifyRuntimeOptionalDriverAgentRevisionRejectsStaleOceanBaseAgent(t *testing.T) { + originalProbe := optionalDriverAgentMetadataProbe + t.Cleanup(func() { + optionalDriverAgentMetadataProbe = originalProbe + }) + optionalDriverAgentMetadataProbe = func(driverType string, executablePath string) (db.OptionalDriverAgentMetadata, error) { + return db.OptionalDriverAgentMetadata{ + DriverType: driverType, + AgentRevision: "src-stale-agent", + }, nil + } + + err := verifyRuntimeOptionalDriverAgentRevision(connection.ConnectionConfig{Type: "oceanbase"}) + if err == nil { + t.Fatal("expected stale OceanBase agent revision to be rejected") + } + if !strings.Contains(err.Error(), "revision 不匹配") { + t.Fatalf("expected revision mismatch error, got %v", err) + } +} + +func TestVerifyRuntimeOptionalDriverAgentRevisionSkipsCustomDriver(t *testing.T) { + originalProbe := optionalDriverAgentMetadataProbe + t.Cleanup(func() { + optionalDriverAgentMetadataProbe = originalProbe + }) + calls := 0 + optionalDriverAgentMetadataProbe = func(driverType string, executablePath string) (db.OptionalDriverAgentMetadata, error) { + calls++ + return db.OptionalDriverAgentMetadata{}, nil + } + + if err := verifyRuntimeOptionalDriverAgentRevision(connection.ConnectionConfig{ + Type: "custom", + Driver: "oceanbase", + }); err != nil { + t.Fatalf("custom driver should skip optional agent runtime revision check: %v", err) + } + if calls != 0 { + t.Fatalf("custom driver should not probe optional agent metadata, got %d calls", calls) + } +} + func optionalDriverAgentRevisionTestDrivers(t *testing.T) []string { t.Helper() drivers := []string{ diff --git a/internal/db/driver_agent_revisions_gen.go b/internal/db/driver_agent_revisions_gen.go index 8044be9..8908843 100644 --- a/internal/db/driver_agent_revisions_gen.go +++ b/internal/db/driver_agent_revisions_gen.go @@ -5,7 +5,7 @@ package db func init() { optionalDriverAgentRevisions = map[string]string{ "mariadb": "src-1a1cc64f8f92d92b", - "oceanbase": "src-ac051813e2451265", + "oceanbase": "src-5bcb757b1b85d41e", "diros": "src-bcc78fa43671ade5", "sphinx": "src-404765c2fda68c5f", "sqlserver": "src-d9fba1eca0a27c49", diff --git a/internal/db/oceanbase_impl.go b/internal/db/oceanbase_impl.go index 947c2aa..98525f0 100644 --- a/internal/db/oceanbase_impl.go +++ b/internal/db/oceanbase_impl.go @@ -248,6 +248,33 @@ func withoutOceanBaseProtocolParams(config connection.ConnectionConfig) connecti return next } +func promoteOceanBaseOracleURIParams(config connection.ConnectionConfig) connection.ConnectionConfig { + uriParams := connectionParamsFromURI(config.URI, "oceanbase", "mysql") + if len(uriParams) == 0 { + return config + } + for _, key := range []string{"protocol", "oceanBaseProtocol", "oceanbaseProtocol", "tenantMode", "compatMode", "mode"} { + uriParams.Del(key) + } + if len(uriParams) == 0 { + return config + } + merged := url.Values{} + mergeConnectionParamValuesWithAllowlist(merged, uriParams, oracleConnectionParamNames) + mergeConnectionParamValuesWithAllowlist(merged, connectionParamsFromText(config.ConnectionParams), oracleConnectionParamNames) + config.ConnectionParams = merged.Encode() + return config +} + +func prepareOceanBaseOracleConfig(config connection.ConnectionConfig) connection.ConnectionConfig { + runConfig := withoutOceanBaseProtocolParams(applyOceanBaseURI(config)) + runConfig = promoteOceanBaseOracleURIParams(runConfig) + runConfig.Type = "oracle" + // OracleDB 不解析 oceanbase:// URI。连接要素已落到结构化字段和 ConnectionParams。 + runConfig.URI = "" + return runConfig +} + func isOceanBaseOracleTenantMySQLDriverError(err error) bool { if err == nil { return false @@ -264,8 +291,7 @@ func formatOceanBaseMySQLAttemptError(address string, err error) string { } func (o *OceanBaseDB) connectOracle(config connection.ConnectionConfig) error { - runConfig := withoutOceanBaseProtocolParams(applyOceanBaseURI(config)) - runConfig.Type = "oracle" + runConfig := prepareOceanBaseOracleConfig(config) if strings.TrimSpace(runConfig.Database) == "" { return fmt.Errorf("OceanBase Oracle 协议需要填写服务名(Service Name),请在连接配置中填写租户监听的服务名") } diff --git a/internal/db/oceanbase_impl_test.go b/internal/db/oceanbase_impl_test.go index 9dd3cd7..ff262fb 100644 --- a/internal/db/oceanbase_impl_test.go +++ b/internal/db/oceanbase_impl_test.go @@ -148,6 +148,36 @@ func TestWithoutOceanBaseProtocolParamsStripsDriverMeta(t *testing.T) { } } +func TestPrepareOceanBaseOracleConfigPromotesURIParams(t *testing.T) { + t.Parallel() + + config := prepareOceanBaseOracleConfig(connection.ConnectionConfig{ + Type: "oceanbase", + URI: "oceanbase://sys%40oracle001:pass@127.0.0.1:2881/ORCL?protocol=oracle&CONNECT_TIMEOUT=12&DBA_PRIVILEGE=SYSDBA", + ConnectionParams: "protocol=oracle&READ_TIMEOUT=7", + }) + + if config.Type != "oracle" { + t.Fatalf("expected routed type oracle, got %q", config.Type) + } + if config.URI != "" { + t.Fatalf("expected routed Oracle config to clear oceanbase URI, got %q", config.URI) + } + params := connectionParamsFromText(config.ConnectionParams) + if got := params.Get("CONNECT TIMEOUT"); got != "12" { + t.Fatalf("expected URI CONNECT_TIMEOUT promoted, got %q in %q", got, config.ConnectionParams) + } + if got := params.Get("READ TIMEOUT"); got != "7" { + t.Fatalf("expected explicit READ_TIMEOUT kept, got %q in %q", got, config.ConnectionParams) + } + if got := params.Get("DBA PRIVILEGE"); got != "SYSDBA" { + t.Fatalf("expected URI DBA_PRIVILEGE promoted, got %q in %q", got, config.ConnectionParams) + } + if strings.Contains(config.ConnectionParams, "protocol=") { + t.Fatalf("expected OceanBase protocol param stripped, got %q", config.ConnectionParams) + } +} + func TestOceanBaseOracleRequiresServiceName(t *testing.T) { t.Parallel() diff --git a/internal/db/oracle_dsn_test.go b/internal/db/oracle_dsn_test.go index c02c5b5..89e0014 100644 --- a/internal/db/oracle_dsn_test.go +++ b/internal/db/oracle_dsn_test.go @@ -1,7 +1,9 @@ package db import ( + "errors" "net/url" + "strings" "testing" "GoNavi-Wails/internal/connection" @@ -31,6 +33,31 @@ func TestOracleGetDSNIncludesQueryPerformanceOptions(t *testing.T) { } } +func TestOracleGetDSNIncludesTimeoutDefaults(t *testing.T) { + t.Parallel() + + dsn := (&OracleDB{}).getDSN(connection.ConnectionConfig{ + Host: "db.example.com", + Port: 1521, + User: "scott", + Password: "tiger", + Database: "ORCLPDB1", + Timeout: 12, + }) + + parsed, err := url.Parse(dsn) + if err != nil { + t.Fatalf("解析 Oracle DSN 失败: %v", err) + } + query := parsed.Query() + if got := query.Get("CONNECT TIMEOUT"); got != "12" { + t.Fatalf("CONNECT TIMEOUT = %q, want 12", got) + } + if got := query.Get("READ TIMEOUT"); got != "12" { + t.Fatalf("READ TIMEOUT = %q, want 12", got) + } +} + func TestOracleGetDSNMergesConnectionParams(t *testing.T) { t.Parallel() @@ -40,7 +67,7 @@ func TestOracleGetDSNMergesConnectionParams(t *testing.T) { User: "scott", Password: "tiger", Database: "ORCLPDB1", - ConnectionParams: "PREFETCH_ROWS=5000&TRACE FILE=/tmp/go-ora.trc&connect_timeout=10&FAILOVER=3&unknown=bad", + ConnectionParams: "PREFETCH_ROWS=5000&TRACE FILE=/tmp/go-ora.trc&connect_timeout=10&read_timeout=7&FAILOVER=3&unknown=bad", }) parsed, err := url.Parse(dsn) @@ -57,6 +84,9 @@ func TestOracleGetDSNMergesConnectionParams(t *testing.T) { if got := query.Get("CONNECT TIMEOUT"); got != "10" { t.Fatalf("CONNECT TIMEOUT = %q, want 10", got) } + if got := query.Get("READ TIMEOUT"); got != "7" { + t.Fatalf("READ TIMEOUT = %q, want 7", got) + } if got := query.Get("FAILOVER"); got != "" { t.Fatalf("FAILOVER should be filtered because go-ora no longer supports it, got %q", got) } @@ -64,3 +94,35 @@ func TestOracleGetDSNMergesConnectionParams(t *testing.T) { t.Fatalf("unknown should be filtered, got %q", got) } } + +func TestOracleDSNLogSummaryDoesNotExposePassword(t *testing.T) { + t.Parallel() + + dsn := (&OracleDB{}).getDSN(connection.ConnectionConfig{ + Host: "db.example.com", + Port: 1521, + User: "sys@tenant", + Password: "top-secret", + Database: "ORCLPDB1", + ConnectionParams: "DBA_PRIVILEGE=SYSDBA&AUTH_TYPE=NORMAL", + }) + + got := oracleDSNLogSummary(connection.ConnectionConfig{Database: "ORCLPDB1"}, dsn) + if strings.Contains(got, "top-secret") || strings.Contains(got, "sys@tenant") { + t.Fatalf("summary should not expose credentials, got %q", got) + } + for _, want := range []string{"服务名=ORCLPDB1", "DBA_PRIVILEGE=SYSDBA", "AUTH_TYPE=NORMAL"} { + if !strings.Contains(got, want) { + t.Fatalf("expected summary to contain %q, got %q", want, got) + } + } +} + +func TestAnnotateOracleValidationErrorAddsClosedConnectionHint(t *testing.T) { + t.Parallel() + + err := annotateOracleValidationError(errors.New("read tcp 127.0.0.1:1->127.0.0.1:2: use of closed network connection")) + if err == nil || !strings.Contains(err.Error(), "Service Name") { + t.Fatalf("expected closed connection hint, got %v", err) + } +} diff --git a/internal/db/oracle_impl.go b/internal/db/oracle_impl.go index a7ca991..b2b7551 100644 --- a/internal/db/oracle_impl.go +++ b/internal/db/oracle_impl.go @@ -48,6 +48,9 @@ func (o *OracleDB) getDSN(config connection.ConnectionConfig) string { q.Set("PREFETCH_ROWS", "10000") // LOB 数据延迟加载,避免大 LOB 列影响普通查询性能 q.Set("LOB FETCH", "POST") + timeoutSeconds := strconv.Itoa(getConnectTimeoutSeconds(config)) + q.Set("CONNECT TIMEOUT", timeoutSeconds) + q.Set("READ TIMEOUT", timeoutSeconds) mergeConnectionParamsFromConfigWithAllowlist(q, config, oracleConnectionParamNames, "oracle") if encoded := q.Encode(); encoded != "" { u.RawQuery = encoded @@ -55,6 +58,53 @@ func (o *OracleDB) getDSN(config connection.ConnectionConfig) string { return u.String() } +func oracleQueryValue(values url.Values, key string) string { + return strings.TrimSpace(values.Get(key)) +} + +func oracleQueryValueOrDefault(values url.Values, key string) string { + value := oracleQueryValue(values, key) + if value == "" { + return "未配置" + } + return value +} + +func oracleDSNLogSummary(config connection.ConnectionConfig, dsn string) string { + serviceName := strings.TrimSpace(config.Database) + params := url.Values{} + if parsed, err := url.Parse(dsn); err == nil && parsed != nil { + if pathService, unescapeErr := url.PathUnescape(strings.TrimPrefix(parsed.EscapedPath(), "/")); unescapeErr == nil && strings.TrimSpace(pathService) != "" { + serviceName = strings.TrimSpace(pathService) + } + params = parsed.Query() + } + if serviceName == "" { + serviceName = "(未配置)" + } + return fmt.Sprintf("服务名=%s CONNECT_TIMEOUT=%s READ_TIMEOUT=%s SSL=%s SSL_VERIFY=%s AUTH_TYPE=%s DBA_PRIVILEGE=%s SID=%s", + serviceName, + oracleQueryValueOrDefault(params, "CONNECT TIMEOUT"), + oracleQueryValueOrDefault(params, "READ TIMEOUT"), + oracleQueryValueOrDefault(params, "SSL"), + oracleQueryValueOrDefault(params, "SSL VERIFY"), + oracleQueryValueOrDefault(params, "AUTH TYPE"), + oracleQueryValueOrDefault(params, "DBA PRIVILEGE"), + oracleQueryValueOrDefault(params, "SID"), + ) +} + +func annotateOracleValidationError(err error) error { + if err == nil { + return nil + } + message := strings.ToLower(err.Error()) + if !strings.Contains(message, "use of closed network connection") { + return err + } + return fmt.Errorf("%w(Oracle 连接在验证阶段被服务端关闭或被驱动超时中断;请检查监听端口是否为 Oracle 协议端口、Service Name 是否正确、认证参数如 DBA_PRIVILEGE/AUTH_TYPE 是否匹配)", err) +} + func (o *OracleDB) Connect(config connection.ConnectionConfig) error { runConfig := config serviceName := strings.TrimSpace(config.Database) @@ -101,6 +151,7 @@ func (o *OracleDB) Connect(config connection.ConnectionConfig) error { var failures []string for idx, attempt := range attempts { dsn := o.getDSN(attempt) + logger.Infof("Oracle 连接参数摘要:地址=%s:%d 用户=%s %s", attempt.Host, attempt.Port, attempt.User, oracleDSNLogSummary(attempt, dsn)) db, err := sql.Open("oracle", dsn) if err != nil { failures = append(failures, fmt.Sprintf("第%d次连接打开失败: %v", idx+1, err)) @@ -111,7 +162,7 @@ func (o *OracleDB) Connect(config connection.ConnectionConfig) error { if err := o.Ping(); err != nil { _ = db.Close() o.conn = nil - failures = append(failures, fmt.Sprintf("第%d次连接验证失败: %v", idx+1, err)) + failures = append(failures, fmt.Sprintf("第%d次连接验证失败: %v", idx+1, annotateOracleValidationError(err))) continue } if idx > 0 { diff --git a/tools/generate-driver-agent-revisions.sh b/tools/generate-driver-agent-revisions.sh index acecf56..7ca1b0c 100755 --- a/tools/generate-driver-agent-revisions.sh +++ b/tools/generate-driver-agent-revisions.sh @@ -93,6 +93,7 @@ internal/db/timeout.go) case "$driver:$identity" in mariadb:internal/db/mariadb_impl.go|\ oceanbase:internal/db/oceanbase_impl.go|\ +oceanbase:internal/db/oracle_impl.go|\ oceanbase:internal/db/mysql_impl.go|\ diros:internal/db/diros_impl.go|\ diros:internal/db/mysql_impl.go|\