mirror of
https://github.com/Syngnat/GoNavi.git
synced 2026-05-22 17:00:21 +08:00
🐛 fix(oceanbase): 增强 Oracle 协议连接校验与诊断
- 运行时校验可选 driver-agent revision,避免旧代理继续被复用 - OceanBase agent revision 纳入 oracle_impl.go 指纹并重新生成 - OceanBase Oracle 保留 URI 中的 Oracle 连接参数 - Oracle DSN 默认写入连接和读取超时,并输出脱敏诊断摘要 - 补充 revision、Oracle DSN、OceanBase Oracle 参数提升测试
This commit is contained in:
@@ -543,6 +543,9 @@ func (a *App) openDatabaseIsolated(config connection.ConnectionConfig) (db.Datab
|
||||
}
|
||||
return nil, withLogHint{err: fmt.Errorf("%s", reason), logPath: logger.Path()}
|
||||
}
|
||||
if revisionErr := verifyRuntimeOptionalDriverAgentRevision(effectiveConfig); revisionErr != nil {
|
||||
return nil, withLogHint{err: revisionErr, logPath: logger.Path()}
|
||||
}
|
||||
|
||||
dbInst, err := newDatabaseFunc(effectiveConfig.Type)
|
||||
if err != nil {
|
||||
@@ -655,6 +658,9 @@ func (a *App) getDatabaseWithPing(config connection.ConnectionConfig, forcePing
|
||||
formatConnSummary(effectiveConfig), shortKey, formatConnectFailureCooldown(remaining), normalizeErrorMessage(failure.err))
|
||||
return nil, withLogHint{err: fmt.Errorf("%s", message), logPath: logger.Path()}
|
||||
}
|
||||
if revisionErr := verifyRuntimeOptionalDriverAgentRevision(effectiveConfig); revisionErr != nil {
|
||||
return nil, withLogHint{err: revisionErr, logPath: logger.Path()}
|
||||
}
|
||||
|
||||
initialKey := key
|
||||
dbInst, connectedConfig, err := a.connectDatabaseWithStartupRetry(resolvedConfig)
|
||||
@@ -744,6 +750,32 @@ func formatConnectFailureCooldown(remaining time.Duration) time.Duration {
|
||||
return remaining.Truncate(time.Second)
|
||||
}
|
||||
|
||||
func verifyRuntimeOptionalDriverAgentRevision(config connection.ConnectionConfig) error {
|
||||
driverType := normalizeDriverType(config.Type)
|
||||
if !db.IsOptionalGoDriver(driverType) {
|
||||
return nil
|
||||
}
|
||||
executablePath, err := db.ResolveOptionalDriverAgentExecutablePath("", driverType)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
pkg, packageMetaExists := readInstalledDriverPackage("", driverType)
|
||||
selectedVersion := ""
|
||||
if packageMetaExists {
|
||||
selectedVersion = strings.TrimSpace(pkg.Version)
|
||||
}
|
||||
agentRevision, err := verifyInstalledOptionalDriverAgentRevision(driverType, executablePath, selectedVersion)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if expectedRevision := strings.TrimSpace(db.OptionalDriverAgentRevision(driverType)); expectedRevision != "" {
|
||||
displayName := resolveDriverDisplayName(driverDefinition{Type: driverType})
|
||||
logger.Infof("%s driver-agent revision 校验通过:已安装=%s 当前需要=%s version=%s path=%s",
|
||||
displayName, strings.TrimSpace(agentRevision), expectedRevision, selectedVersion, executablePath)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func shortenCacheKey(key string) string {
|
||||
if len(key) > 12 {
|
||||
return key[:12]
|
||||
|
||||
@@ -53,6 +53,7 @@ var _ db.Database = (*fakeCreateDatabaseDB)(nil)
|
||||
|
||||
func TestResolveDDLDBType_SQLServerAliases(t *testing.T) {
|
||||
tests := []connection.ConnectionConfig{
|
||||
{Type: "sqlserver"},
|
||||
{Type: "mssql"},
|
||||
{Type: "sql_server"},
|
||||
{Type: "custom", Driver: "mssql"},
|
||||
@@ -95,7 +96,8 @@ func TestCreateDatabase_SQLServerUsesBracketIdentifiers(t *testing.T) {
|
||||
|
||||
app := NewAppWithSecretStore(secretstore.NewUnavailableStore("test"))
|
||||
result := app.CreateDatabase(connection.ConnectionConfig{
|
||||
Type: "sqlserver",
|
||||
Type: "custom",
|
||||
Driver: "mssql",
|
||||
Database: "master",
|
||||
}, "lg")
|
||||
|
||||
|
||||
@@ -79,6 +79,49 @@ func TestVerifyInstalledOptionalDriverAgentRevisionRejectsProbeFailure(t *testin
|
||||
}
|
||||
}
|
||||
|
||||
func TestVerifyRuntimeOptionalDriverAgentRevisionRejectsStaleOceanBaseAgent(t *testing.T) {
|
||||
originalProbe := optionalDriverAgentMetadataProbe
|
||||
t.Cleanup(func() {
|
||||
optionalDriverAgentMetadataProbe = originalProbe
|
||||
})
|
||||
optionalDriverAgentMetadataProbe = func(driverType string, executablePath string) (db.OptionalDriverAgentMetadata, error) {
|
||||
return db.OptionalDriverAgentMetadata{
|
||||
DriverType: driverType,
|
||||
AgentRevision: "src-stale-agent",
|
||||
}, nil
|
||||
}
|
||||
|
||||
err := verifyRuntimeOptionalDriverAgentRevision(connection.ConnectionConfig{Type: "oceanbase"})
|
||||
if err == nil {
|
||||
t.Fatal("expected stale OceanBase agent revision to be rejected")
|
||||
}
|
||||
if !strings.Contains(err.Error(), "revision 不匹配") {
|
||||
t.Fatalf("expected revision mismatch error, got %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestVerifyRuntimeOptionalDriverAgentRevisionSkipsCustomDriver(t *testing.T) {
|
||||
originalProbe := optionalDriverAgentMetadataProbe
|
||||
t.Cleanup(func() {
|
||||
optionalDriverAgentMetadataProbe = originalProbe
|
||||
})
|
||||
calls := 0
|
||||
optionalDriverAgentMetadataProbe = func(driverType string, executablePath string) (db.OptionalDriverAgentMetadata, error) {
|
||||
calls++
|
||||
return db.OptionalDriverAgentMetadata{}, nil
|
||||
}
|
||||
|
||||
if err := verifyRuntimeOptionalDriverAgentRevision(connection.ConnectionConfig{
|
||||
Type: "custom",
|
||||
Driver: "oceanbase",
|
||||
}); err != nil {
|
||||
t.Fatalf("custom driver should skip optional agent runtime revision check: %v", err)
|
||||
}
|
||||
if calls != 0 {
|
||||
t.Fatalf("custom driver should not probe optional agent metadata, got %d calls", calls)
|
||||
}
|
||||
}
|
||||
|
||||
func optionalDriverAgentRevisionTestDrivers(t *testing.T) []string {
|
||||
t.Helper()
|
||||
drivers := []string{
|
||||
|
||||
@@ -5,7 +5,7 @@ package db
|
||||
func init() {
|
||||
optionalDriverAgentRevisions = map[string]string{
|
||||
"mariadb": "src-1a1cc64f8f92d92b",
|
||||
"oceanbase": "src-ac051813e2451265",
|
||||
"oceanbase": "src-5bcb757b1b85d41e",
|
||||
"diros": "src-bcc78fa43671ade5",
|
||||
"sphinx": "src-404765c2fda68c5f",
|
||||
"sqlserver": "src-d9fba1eca0a27c49",
|
||||
|
||||
@@ -248,6 +248,33 @@ func withoutOceanBaseProtocolParams(config connection.ConnectionConfig) connecti
|
||||
return next
|
||||
}
|
||||
|
||||
func promoteOceanBaseOracleURIParams(config connection.ConnectionConfig) connection.ConnectionConfig {
|
||||
uriParams := connectionParamsFromURI(config.URI, "oceanbase", "mysql")
|
||||
if len(uriParams) == 0 {
|
||||
return config
|
||||
}
|
||||
for _, key := range []string{"protocol", "oceanBaseProtocol", "oceanbaseProtocol", "tenantMode", "compatMode", "mode"} {
|
||||
uriParams.Del(key)
|
||||
}
|
||||
if len(uriParams) == 0 {
|
||||
return config
|
||||
}
|
||||
merged := url.Values{}
|
||||
mergeConnectionParamValuesWithAllowlist(merged, uriParams, oracleConnectionParamNames)
|
||||
mergeConnectionParamValuesWithAllowlist(merged, connectionParamsFromText(config.ConnectionParams), oracleConnectionParamNames)
|
||||
config.ConnectionParams = merged.Encode()
|
||||
return config
|
||||
}
|
||||
|
||||
func prepareOceanBaseOracleConfig(config connection.ConnectionConfig) connection.ConnectionConfig {
|
||||
runConfig := withoutOceanBaseProtocolParams(applyOceanBaseURI(config))
|
||||
runConfig = promoteOceanBaseOracleURIParams(runConfig)
|
||||
runConfig.Type = "oracle"
|
||||
// OracleDB 不解析 oceanbase:// URI。连接要素已落到结构化字段和 ConnectionParams。
|
||||
runConfig.URI = ""
|
||||
return runConfig
|
||||
}
|
||||
|
||||
func isOceanBaseOracleTenantMySQLDriverError(err error) bool {
|
||||
if err == nil {
|
||||
return false
|
||||
@@ -264,8 +291,7 @@ func formatOceanBaseMySQLAttemptError(address string, err error) string {
|
||||
}
|
||||
|
||||
func (o *OceanBaseDB) connectOracle(config connection.ConnectionConfig) error {
|
||||
runConfig := withoutOceanBaseProtocolParams(applyOceanBaseURI(config))
|
||||
runConfig.Type = "oracle"
|
||||
runConfig := prepareOceanBaseOracleConfig(config)
|
||||
if strings.TrimSpace(runConfig.Database) == "" {
|
||||
return fmt.Errorf("OceanBase Oracle 协议需要填写服务名(Service Name),请在连接配置中填写租户监听的服务名")
|
||||
}
|
||||
|
||||
@@ -148,6 +148,36 @@ func TestWithoutOceanBaseProtocolParamsStripsDriverMeta(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestPrepareOceanBaseOracleConfigPromotesURIParams(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
config := prepareOceanBaseOracleConfig(connection.ConnectionConfig{
|
||||
Type: "oceanbase",
|
||||
URI: "oceanbase://sys%40oracle001:pass@127.0.0.1:2881/ORCL?protocol=oracle&CONNECT_TIMEOUT=12&DBA_PRIVILEGE=SYSDBA",
|
||||
ConnectionParams: "protocol=oracle&READ_TIMEOUT=7",
|
||||
})
|
||||
|
||||
if config.Type != "oracle" {
|
||||
t.Fatalf("expected routed type oracle, got %q", config.Type)
|
||||
}
|
||||
if config.URI != "" {
|
||||
t.Fatalf("expected routed Oracle config to clear oceanbase URI, got %q", config.URI)
|
||||
}
|
||||
params := connectionParamsFromText(config.ConnectionParams)
|
||||
if got := params.Get("CONNECT TIMEOUT"); got != "12" {
|
||||
t.Fatalf("expected URI CONNECT_TIMEOUT promoted, got %q in %q", got, config.ConnectionParams)
|
||||
}
|
||||
if got := params.Get("READ TIMEOUT"); got != "7" {
|
||||
t.Fatalf("expected explicit READ_TIMEOUT kept, got %q in %q", got, config.ConnectionParams)
|
||||
}
|
||||
if got := params.Get("DBA PRIVILEGE"); got != "SYSDBA" {
|
||||
t.Fatalf("expected URI DBA_PRIVILEGE promoted, got %q in %q", got, config.ConnectionParams)
|
||||
}
|
||||
if strings.Contains(config.ConnectionParams, "protocol=") {
|
||||
t.Fatalf("expected OceanBase protocol param stripped, got %q", config.ConnectionParams)
|
||||
}
|
||||
}
|
||||
|
||||
func TestOceanBaseOracleRequiresServiceName(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
|
||||
@@ -1,7 +1,9 @@
|
||||
package db
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"net/url"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"GoNavi-Wails/internal/connection"
|
||||
@@ -31,6 +33,31 @@ func TestOracleGetDSNIncludesQueryPerformanceOptions(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestOracleGetDSNIncludesTimeoutDefaults(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
dsn := (&OracleDB{}).getDSN(connection.ConnectionConfig{
|
||||
Host: "db.example.com",
|
||||
Port: 1521,
|
||||
User: "scott",
|
||||
Password: "tiger",
|
||||
Database: "ORCLPDB1",
|
||||
Timeout: 12,
|
||||
})
|
||||
|
||||
parsed, err := url.Parse(dsn)
|
||||
if err != nil {
|
||||
t.Fatalf("解析 Oracle DSN 失败: %v", err)
|
||||
}
|
||||
query := parsed.Query()
|
||||
if got := query.Get("CONNECT TIMEOUT"); got != "12" {
|
||||
t.Fatalf("CONNECT TIMEOUT = %q, want 12", got)
|
||||
}
|
||||
if got := query.Get("READ TIMEOUT"); got != "12" {
|
||||
t.Fatalf("READ TIMEOUT = %q, want 12", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestOracleGetDSNMergesConnectionParams(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
@@ -40,7 +67,7 @@ func TestOracleGetDSNMergesConnectionParams(t *testing.T) {
|
||||
User: "scott",
|
||||
Password: "tiger",
|
||||
Database: "ORCLPDB1",
|
||||
ConnectionParams: "PREFETCH_ROWS=5000&TRACE FILE=/tmp/go-ora.trc&connect_timeout=10&FAILOVER=3&unknown=bad",
|
||||
ConnectionParams: "PREFETCH_ROWS=5000&TRACE FILE=/tmp/go-ora.trc&connect_timeout=10&read_timeout=7&FAILOVER=3&unknown=bad",
|
||||
})
|
||||
|
||||
parsed, err := url.Parse(dsn)
|
||||
@@ -57,6 +84,9 @@ func TestOracleGetDSNMergesConnectionParams(t *testing.T) {
|
||||
if got := query.Get("CONNECT TIMEOUT"); got != "10" {
|
||||
t.Fatalf("CONNECT TIMEOUT = %q, want 10", got)
|
||||
}
|
||||
if got := query.Get("READ TIMEOUT"); got != "7" {
|
||||
t.Fatalf("READ TIMEOUT = %q, want 7", got)
|
||||
}
|
||||
if got := query.Get("FAILOVER"); got != "" {
|
||||
t.Fatalf("FAILOVER should be filtered because go-ora no longer supports it, got %q", got)
|
||||
}
|
||||
@@ -64,3 +94,35 @@ func TestOracleGetDSNMergesConnectionParams(t *testing.T) {
|
||||
t.Fatalf("unknown should be filtered, got %q", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestOracleDSNLogSummaryDoesNotExposePassword(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
dsn := (&OracleDB{}).getDSN(connection.ConnectionConfig{
|
||||
Host: "db.example.com",
|
||||
Port: 1521,
|
||||
User: "sys@tenant",
|
||||
Password: "top-secret",
|
||||
Database: "ORCLPDB1",
|
||||
ConnectionParams: "DBA_PRIVILEGE=SYSDBA&AUTH_TYPE=NORMAL",
|
||||
})
|
||||
|
||||
got := oracleDSNLogSummary(connection.ConnectionConfig{Database: "ORCLPDB1"}, dsn)
|
||||
if strings.Contains(got, "top-secret") || strings.Contains(got, "sys@tenant") {
|
||||
t.Fatalf("summary should not expose credentials, got %q", got)
|
||||
}
|
||||
for _, want := range []string{"服务名=ORCLPDB1", "DBA_PRIVILEGE=SYSDBA", "AUTH_TYPE=NORMAL"} {
|
||||
if !strings.Contains(got, want) {
|
||||
t.Fatalf("expected summary to contain %q, got %q", want, got)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestAnnotateOracleValidationErrorAddsClosedConnectionHint(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
err := annotateOracleValidationError(errors.New("read tcp 127.0.0.1:1->127.0.0.1:2: use of closed network connection"))
|
||||
if err == nil || !strings.Contains(err.Error(), "Service Name") {
|
||||
t.Fatalf("expected closed connection hint, got %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -48,6 +48,9 @@ func (o *OracleDB) getDSN(config connection.ConnectionConfig) string {
|
||||
q.Set("PREFETCH_ROWS", "10000")
|
||||
// LOB 数据延迟加载,避免大 LOB 列影响普通查询性能
|
||||
q.Set("LOB FETCH", "POST")
|
||||
timeoutSeconds := strconv.Itoa(getConnectTimeoutSeconds(config))
|
||||
q.Set("CONNECT TIMEOUT", timeoutSeconds)
|
||||
q.Set("READ TIMEOUT", timeoutSeconds)
|
||||
mergeConnectionParamsFromConfigWithAllowlist(q, config, oracleConnectionParamNames, "oracle")
|
||||
if encoded := q.Encode(); encoded != "" {
|
||||
u.RawQuery = encoded
|
||||
@@ -55,6 +58,53 @@ func (o *OracleDB) getDSN(config connection.ConnectionConfig) string {
|
||||
return u.String()
|
||||
}
|
||||
|
||||
func oracleQueryValue(values url.Values, key string) string {
|
||||
return strings.TrimSpace(values.Get(key))
|
||||
}
|
||||
|
||||
func oracleQueryValueOrDefault(values url.Values, key string) string {
|
||||
value := oracleQueryValue(values, key)
|
||||
if value == "" {
|
||||
return "未配置"
|
||||
}
|
||||
return value
|
||||
}
|
||||
|
||||
func oracleDSNLogSummary(config connection.ConnectionConfig, dsn string) string {
|
||||
serviceName := strings.TrimSpace(config.Database)
|
||||
params := url.Values{}
|
||||
if parsed, err := url.Parse(dsn); err == nil && parsed != nil {
|
||||
if pathService, unescapeErr := url.PathUnescape(strings.TrimPrefix(parsed.EscapedPath(), "/")); unescapeErr == nil && strings.TrimSpace(pathService) != "" {
|
||||
serviceName = strings.TrimSpace(pathService)
|
||||
}
|
||||
params = parsed.Query()
|
||||
}
|
||||
if serviceName == "" {
|
||||
serviceName = "(未配置)"
|
||||
}
|
||||
return fmt.Sprintf("服务名=%s CONNECT_TIMEOUT=%s READ_TIMEOUT=%s SSL=%s SSL_VERIFY=%s AUTH_TYPE=%s DBA_PRIVILEGE=%s SID=%s",
|
||||
serviceName,
|
||||
oracleQueryValueOrDefault(params, "CONNECT TIMEOUT"),
|
||||
oracleQueryValueOrDefault(params, "READ TIMEOUT"),
|
||||
oracleQueryValueOrDefault(params, "SSL"),
|
||||
oracleQueryValueOrDefault(params, "SSL VERIFY"),
|
||||
oracleQueryValueOrDefault(params, "AUTH TYPE"),
|
||||
oracleQueryValueOrDefault(params, "DBA PRIVILEGE"),
|
||||
oracleQueryValueOrDefault(params, "SID"),
|
||||
)
|
||||
}
|
||||
|
||||
func annotateOracleValidationError(err error) error {
|
||||
if err == nil {
|
||||
return nil
|
||||
}
|
||||
message := strings.ToLower(err.Error())
|
||||
if !strings.Contains(message, "use of closed network connection") {
|
||||
return err
|
||||
}
|
||||
return fmt.Errorf("%w(Oracle 连接在验证阶段被服务端关闭或被驱动超时中断;请检查监听端口是否为 Oracle 协议端口、Service Name 是否正确、认证参数如 DBA_PRIVILEGE/AUTH_TYPE 是否匹配)", err)
|
||||
}
|
||||
|
||||
func (o *OracleDB) Connect(config connection.ConnectionConfig) error {
|
||||
runConfig := config
|
||||
serviceName := strings.TrimSpace(config.Database)
|
||||
@@ -101,6 +151,7 @@ func (o *OracleDB) Connect(config connection.ConnectionConfig) error {
|
||||
var failures []string
|
||||
for idx, attempt := range attempts {
|
||||
dsn := o.getDSN(attempt)
|
||||
logger.Infof("Oracle 连接参数摘要:地址=%s:%d 用户=%s %s", attempt.Host, attempt.Port, attempt.User, oracleDSNLogSummary(attempt, dsn))
|
||||
db, err := sql.Open("oracle", dsn)
|
||||
if err != nil {
|
||||
failures = append(failures, fmt.Sprintf("第%d次连接打开失败: %v", idx+1, err))
|
||||
@@ -111,7 +162,7 @@ func (o *OracleDB) Connect(config connection.ConnectionConfig) error {
|
||||
if err := o.Ping(); err != nil {
|
||||
_ = db.Close()
|
||||
o.conn = nil
|
||||
failures = append(failures, fmt.Sprintf("第%d次连接验证失败: %v", idx+1, err))
|
||||
failures = append(failures, fmt.Sprintf("第%d次连接验证失败: %v", idx+1, annotateOracleValidationError(err)))
|
||||
continue
|
||||
}
|
||||
if idx > 0 {
|
||||
|
||||
@@ -93,6 +93,7 @@ internal/db/timeout.go)
|
||||
case "$driver:$identity" in
|
||||
mariadb:internal/db/mariadb_impl.go|\
|
||||
oceanbase:internal/db/oceanbase_impl.go|\
|
||||
oceanbase:internal/db/oracle_impl.go|\
|
||||
oceanbase:internal/db/mysql_impl.go|\
|
||||
diros:internal/db/diros_impl.go|\
|
||||
diros:internal/db/mysql_impl.go|\
|
||||
|
||||
Reference in New Issue
Block a user