🐛 fix(oceanbase): 修复 Oracle 租户连接误走 go-ora

- 连接层改为通过 OceanBase MySQL 兼容协议建立 Oracle 租户连接
- 保留 Oracle 元数据包装,避免表结构和 schema 查询退回 MySQL 方言
- 修复 Oracle 租户数据编辑在 MySQL wire 下的占位符格式
- 更新 OceanBase driver-agent revision,确保 dev 包触发驱动刷新
This commit is contained in:
Syngnat
2026-05-14 11:47:32 +08:00
parent 527ecd37e1
commit 17331ddbaa
3 changed files with 202 additions and 76 deletions

View File

@@ -5,7 +5,7 @@ package db
func init() {
optionalDriverAgentRevisions = map[string]string{
"mariadb": "src-1a1cc64f8f92d92b",
"oceanbase": "src-5bcb757b1b85d41e",
"oceanbase": "src-f6f19676bb5102d1",
"diros": "src-bcc78fa43671ade5",
"sphinx": "src-404765c2fda68c5f",
"sqlserver": "src-d9fba1eca0a27c49",

View File

@@ -8,6 +8,7 @@ import (
"fmt"
"net/url"
"strings"
"time"
"GoNavi-Wails/internal/connection"
"GoNavi-Wails/internal/logger"
@@ -248,33 +249,6 @@ func withoutOceanBaseProtocolParams(config connection.ConnectionConfig) connecti
return next
}
func promoteOceanBaseOracleURIParams(config connection.ConnectionConfig) connection.ConnectionConfig {
uriParams := connectionParamsFromURI(config.URI, "oceanbase", "mysql")
if len(uriParams) == 0 {
return config
}
for _, key := range []string{"protocol", "oceanBaseProtocol", "oceanbaseProtocol", "tenantMode", "compatMode", "mode"} {
uriParams.Del(key)
}
if len(uriParams) == 0 {
return config
}
merged := url.Values{}
mergeConnectionParamValuesWithAllowlist(merged, uriParams, oracleConnectionParamNames)
mergeConnectionParamValuesWithAllowlist(merged, connectionParamsFromText(config.ConnectionParams), oracleConnectionParamNames)
config.ConnectionParams = merged.Encode()
return config
}
func prepareOceanBaseOracleConfig(config connection.ConnectionConfig) connection.ConnectionConfig {
runConfig := withoutOceanBaseProtocolParams(applyOceanBaseURI(config))
runConfig = promoteOceanBaseOracleURIParams(runConfig)
runConfig.Type = "oracle"
// OracleDB 不解析 oceanbase:// URI。连接要素已落到结构化字段和 ConnectionParams。
runConfig.URI = ""
return runConfig
}
func isOceanBaseOracleTenantMySQLDriverError(err error) bool {
if err == nil {
return false
@@ -290,22 +264,30 @@ func formatOceanBaseMySQLAttemptError(address string, err error) string {
return fmt.Sprintf("%s 验证失败: %v", address, err)
}
func (o *OceanBaseDB) connectOracle(config connection.ConnectionConfig) error {
runConfig := prepareOceanBaseOracleConfig(config)
if strings.TrimSpace(runConfig.Database) == "" {
return fmt.Errorf("OceanBase Oracle 协议需要填写服务名Service Name请在连接配置中填写租户监听的服务名")
func formatOceanBaseAttemptError(address string, protocol string, err error) string {
if protocol == oceanBaseProtocolMySQL {
return formatOceanBaseMySQLAttemptError(address, err)
}
oracleDB := &OracleDB{}
if err := oracleDB.Connect(runConfig); err != nil {
return fmt.Errorf("OceanBase Oracle 协议连接失败:%w", err)
return fmt.Sprintf("%s 验证失败: %v", address, err)
}
func (o *OceanBaseDB) bindConnectedDatabase(db *sql.DB, timeout time.Duration, protocol string) {
o.oracle = nil
o.conn = nil
o.pingTimeout = 0
if protocol == oceanBaseProtocolOracle {
o.oracle = &OracleDB{conn: db, pingTimeout: timeout}
o.protocol = oceanBaseProtocolOracle
return
}
o.oracle = oracleDB
o.protocol = oceanBaseProtocolOracle
return nil
o.conn = db
o.pingTimeout = timeout
o.protocol = oceanBaseProtocolMySQL
}
func (o *OceanBaseDB) Connect(config connection.ConnectionConfig) error {
o.oracle = nil
o.conn = nil
o.protocol = oceanBaseProtocolMySQL
appliedConfig := applyOceanBaseURI(config)
protocol, err := resolveOceanBaseProtocol(appliedConfig)
@@ -314,8 +296,7 @@ func (o *OceanBaseDB) Connect(config connection.ConnectionConfig) error {
}
runConfig := withoutOceanBaseProtocolParams(appliedConfig)
if protocol == oceanBaseProtocolOracle {
logger.Infof("OceanBase 使用 Oracle 协议连接:地址=%s:%d 用户=%s", runConfig.Host, runConfig.Port, runConfig.User)
return o.connectOracle(runConfig)
logger.Infof("OceanBase 使用 Oracle 租户模式连接:地址=%s:%d 用户=%s(连接层使用 OceanBase MySQL 兼容协议)", runConfig.Host, runConfig.Port, runConfig.User)
}
addresses := collectOceanBaseAddresses(runConfig)
@@ -351,13 +332,11 @@ func (o *OceanBaseDB) Connect(config connection.ConnectionConfig) error {
cancel()
if pingErr != nil {
_ = db.Close()
errorDetails = append(errorDetails, formatOceanBaseMySQLAttemptError(address, pingErr))
errorDetails = append(errorDetails, formatOceanBaseAttemptError(address, protocol, pingErr))
continue
}
o.conn = db
o.pingTimeout = timeout
o.protocol = oceanBaseProtocolMySQL
o.bindConnectedDatabase(db, timeout, protocol)
return nil
}
@@ -475,8 +454,137 @@ func (o *OceanBaseDB) GetTriggers(dbName, tableName string) ([]connection.Trigge
}
func (o *OceanBaseDB) ApplyChanges(tableName string, changes connection.ChangeSet) error {
if o.protocol == oceanBaseProtocolOracle && o.oracle != nil {
return o.applyOracleChangesMySQLWire(tableName, changes)
}
if applier, ok := o.activeDatabase().(BatchApplier); ok {
return applier.ApplyChanges(tableName, changes)
}
return fmt.Errorf("当前 OceanBase %s 协议不支持 ApplyChanges", o.protocol)
}
func (o *OceanBaseDB) applyOracleChangesMySQLWire(tableName string, changes connection.ChangeSet) error {
if o.oracle == nil || o.oracle.conn == nil {
return fmt.Errorf("连接未打开")
}
columnTypeMap := o.oracle.loadColumnTypeMap(tableName)
tx, err := o.oracle.conn.Begin()
if err != nil {
return err
}
defer tx.Rollback()
quoteIdent := func(name string) string {
n := strings.TrimSpace(name)
n = strings.Trim(n, "\"")
n = strings.ReplaceAll(n, "\"", "\"\"")
if n == "" {
return "\"\""
}
return `"` + n + `"`
}
schema := ""
table := strings.TrimSpace(tableName)
if parts := strings.SplitN(table, ".", 2); len(parts) == 2 {
schema = strings.TrimSpace(parts[0])
table = strings.TrimSpace(parts[1])
}
qualifiedTable := ""
if schema != "" {
qualifiedTable = fmt.Sprintf("%s.%s", quoteIdent(schema), quoteIdent(table))
} else {
qualifiedTable = quoteIdent(table)
}
isOracleRowIDLocator := strings.EqualFold(strings.TrimSpace(changes.LocatorStrategy), "oracle-rowid")
buildWhere := func(keys map[string]interface{}) ([]string, []interface{}) {
var wheres []string
var args []interface{}
for k, v := range keys {
if isOracleRowIDLocator && strings.EqualFold(strings.TrimSpace(k), "ROWID") {
wheres = append(wheres, "ROWID = ?")
args = append(args, v)
continue
}
wheres = append(wheres, fmt.Sprintf("%s = ?", quoteIdent(k)))
args = append(args, normalizeOracleValueForWrite(k, v, columnTypeMap))
}
return wheres, args
}
for _, pk := range changes.Deletes {
wheres, args := buildWhere(pk)
if len(wheres) == 0 {
continue
}
query := fmt.Sprintf("DELETE FROM %s WHERE %s", qualifiedTable, strings.Join(wheres, " AND "))
res, err := tx.Exec(query, args...)
if err != nil {
return fmt.Errorf("删除失败:%v", err)
}
if err := requireSingleRowAffected(res, "删除"); err != nil {
return err
}
}
for _, update := range changes.Updates {
var sets []string
var args []interface{}
for k, v := range update.Values {
sets = append(sets, fmt.Sprintf("%s = ?", quoteIdent(k)))
args = append(args, normalizeOracleValueForWrite(k, v, columnTypeMap))
}
if len(sets) == 0 {
continue
}
wheres, whereArgs := buildWhere(update.Keys)
args = append(args, whereArgs...)
if len(wheres) == 0 {
return fmt.Errorf("更新操作需要主键条件")
}
query := fmt.Sprintf("UPDATE %s SET %s WHERE %s", qualifiedTable, strings.Join(sets, ", "), strings.Join(wheres, " AND "))
res, err := tx.Exec(query, args...)
if err != nil {
return fmt.Errorf("更新失败:%v", err)
}
if err := requireSingleRowAffected(res, "更新"); err != nil {
return err
}
}
for _, row := range changes.Inserts {
var cols []string
var placeholders []string
var args []interface{}
for k, v := range row {
cols = append(cols, quoteIdent(k))
placeholders = append(placeholders, "?")
args = append(args, normalizeOracleValueForWrite(k, v, columnTypeMap))
}
if len(cols) == 0 {
continue
}
query := fmt.Sprintf("INSERT INTO %s (%s) VALUES (%s)", qualifiedTable, strings.Join(cols, ", "), strings.Join(placeholders, ", "))
res, err := tx.Exec(query, args...)
if err != nil {
return fmt.Errorf("插入失败:%v", err)
}
if affected, err := res.RowsAffected(); err == nil && affected == 0 {
return fmt.Errorf("插入未生效:未影响任何行")
}
}
return tx.Commit()
}

View File

@@ -3,6 +3,7 @@
package db
import (
"database/sql/driver"
"errors"
"strings"
"testing"
@@ -148,51 +149,68 @@ func TestWithoutOceanBaseProtocolParamsStripsDriverMeta(t *testing.T) {
}
}
func TestPrepareOceanBaseOracleConfigPromotesURIParams(t *testing.T) {
func TestOceanBaseOracleProtocolUsesMySQLWireConnection(t *testing.T) {
t.Parallel()
config := prepareOceanBaseOracleConfig(connection.ConnectionConfig{
Type: "oceanbase",
URI: "oceanbase://sys%40oracle001:pass@127.0.0.1:2881/ORCL?protocol=oracle&CONNECT_TIMEOUT=12&DBA_PRIVILEGE=SYSDBA",
ConnectionParams: "protocol=oracle&READ_TIMEOUT=7",
})
dbConn, state := openOracleRecordingDB(t)
state.queryResults["SELECT username FROM all_users ORDER BY username"] = oracleRecordingQueryResult{
columns: []string{"USERNAME"},
rows: [][]driver.Value{{"SYS"}},
}
if config.Type != "oracle" {
t.Fatalf("expected routed type oracle, got %q", config.Type)
oceanbaseDB := &OceanBaseDB{}
oceanbaseDB.bindConnectedDatabase(dbConn, 0, oceanBaseProtocolOracle)
if oceanbaseDB.oracle == nil {
t.Fatal("expected Oracle metadata wrapper for OceanBase Oracle tenant")
}
if config.URI != "" {
t.Fatalf("expected routed Oracle config to clear oceanbase URI, got %q", config.URI)
if oceanbaseDB.conn != nil {
t.Fatal("expected MySQLDB connection slot to stay empty for Oracle tenant wrapper")
}
params := connectionParamsFromText(config.ConnectionParams)
if got := params.Get("CONNECT TIMEOUT"); got != "12" {
t.Fatalf("expected URI CONNECT_TIMEOUT promoted, got %q in %q", got, config.ConnectionParams)
if oceanbaseDB.protocol != oceanBaseProtocolOracle {
t.Fatalf("expected protocol oracle, got %q", oceanbaseDB.protocol)
}
if got := params.Get("READ TIMEOUT"); got != "7" {
t.Fatalf("expected explicit READ_TIMEOUT kept, got %q in %q", got, config.ConnectionParams)
databases, err := oceanbaseDB.GetDatabases()
if err != nil {
t.Fatalf("GetDatabases() unexpected error: %v", err)
}
if got := params.Get("DBA PRIVILEGE"); got != "SYSDBA" {
t.Fatalf("expected URI DBA_PRIVILEGE promoted, got %q in %q", got, config.ConnectionParams)
}
if strings.Contains(config.ConnectionParams, "protocol=") {
t.Fatalf("expected OceanBase protocol param stripped, got %q", config.ConnectionParams)
if len(databases) != 1 || databases[0] != "SYS" {
t.Fatalf("GetDatabases() = %#v, want [SYS]", databases)
}
}
func TestOceanBaseOracleRequiresServiceName(t *testing.T) {
func TestOceanBaseOracleApplyChangesUsesMySQLWirePlaceholders(t *testing.T) {
t.Parallel()
err := (&OceanBaseDB{}).Connect(connection.ConnectionConfig{
Type: "oceanbase",
Host: "127.0.0.1",
Port: 2881,
User: "sys@oracle001",
ConnectionParams: "protocol=oracle",
})
if err == nil {
t.Fatal("expected missing service name error")
dbConn, state := openOracleRecordingDB(t)
oceanbaseDB := &OceanBaseDB{}
oceanbaseDB.bindConnectedDatabase(dbConn, 0, oceanBaseProtocolOracle)
changes := connection.ChangeSet{
Updates: []connection.UpdateRow{{
Keys: map[string]interface{}{
"ID": 7,
},
Values: map[string]interface{}{
"NAME": "new-name",
},
}},
}
if !strings.Contains(err.Error(), "服务名") {
t.Fatalf("expected service name hint, got %v", err)
if err := oceanbaseDB.ApplyChanges("APP.USERS", changes); err != nil {
t.Fatalf("ApplyChanges() unexpected error: %v", err)
}
queries := state.snapshotExecQueries()
if len(queries) != 1 {
t.Fatalf("expected one exec query, got %#v", queries)
}
if strings.Contains(queries[0], ":1") {
t.Fatalf("expected MySQL wire placeholder style, got %q", queries[0])
}
if !strings.Contains(queries[0], `"NAME" = ?`) || !strings.Contains(queries[0], `"ID" = ?`) {
t.Fatalf("expected question mark placeholders, got %q", queries[0])
}
}