mirror of
https://github.com/Syngnat/GoNavi.git
synced 2026-05-22 08:50:17 +08:00
🐛 fix(oceanbase): 修复 Oracle 租户连接误走 go-ora
- 连接层改为通过 OceanBase MySQL 兼容协议建立 Oracle 租户连接 - 保留 Oracle 元数据包装,避免表结构和 schema 查询退回 MySQL 方言 - 修复 Oracle 租户数据编辑在 MySQL wire 下的占位符格式 - 更新 OceanBase driver-agent revision,确保 dev 包触发驱动刷新
This commit is contained in:
@@ -5,7 +5,7 @@ package db
|
||||
func init() {
|
||||
optionalDriverAgentRevisions = map[string]string{
|
||||
"mariadb": "src-1a1cc64f8f92d92b",
|
||||
"oceanbase": "src-5bcb757b1b85d41e",
|
||||
"oceanbase": "src-f6f19676bb5102d1",
|
||||
"diros": "src-bcc78fa43671ade5",
|
||||
"sphinx": "src-404765c2fda68c5f",
|
||||
"sqlserver": "src-d9fba1eca0a27c49",
|
||||
|
||||
@@ -8,6 +8,7 @@ import (
|
||||
"fmt"
|
||||
"net/url"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"GoNavi-Wails/internal/connection"
|
||||
"GoNavi-Wails/internal/logger"
|
||||
@@ -248,33 +249,6 @@ func withoutOceanBaseProtocolParams(config connection.ConnectionConfig) connecti
|
||||
return next
|
||||
}
|
||||
|
||||
func promoteOceanBaseOracleURIParams(config connection.ConnectionConfig) connection.ConnectionConfig {
|
||||
uriParams := connectionParamsFromURI(config.URI, "oceanbase", "mysql")
|
||||
if len(uriParams) == 0 {
|
||||
return config
|
||||
}
|
||||
for _, key := range []string{"protocol", "oceanBaseProtocol", "oceanbaseProtocol", "tenantMode", "compatMode", "mode"} {
|
||||
uriParams.Del(key)
|
||||
}
|
||||
if len(uriParams) == 0 {
|
||||
return config
|
||||
}
|
||||
merged := url.Values{}
|
||||
mergeConnectionParamValuesWithAllowlist(merged, uriParams, oracleConnectionParamNames)
|
||||
mergeConnectionParamValuesWithAllowlist(merged, connectionParamsFromText(config.ConnectionParams), oracleConnectionParamNames)
|
||||
config.ConnectionParams = merged.Encode()
|
||||
return config
|
||||
}
|
||||
|
||||
func prepareOceanBaseOracleConfig(config connection.ConnectionConfig) connection.ConnectionConfig {
|
||||
runConfig := withoutOceanBaseProtocolParams(applyOceanBaseURI(config))
|
||||
runConfig = promoteOceanBaseOracleURIParams(runConfig)
|
||||
runConfig.Type = "oracle"
|
||||
// OracleDB 不解析 oceanbase:// URI。连接要素已落到结构化字段和 ConnectionParams。
|
||||
runConfig.URI = ""
|
||||
return runConfig
|
||||
}
|
||||
|
||||
func isOceanBaseOracleTenantMySQLDriverError(err error) bool {
|
||||
if err == nil {
|
||||
return false
|
||||
@@ -290,22 +264,30 @@ func formatOceanBaseMySQLAttemptError(address string, err error) string {
|
||||
return fmt.Sprintf("%s 验证失败: %v", address, err)
|
||||
}
|
||||
|
||||
func (o *OceanBaseDB) connectOracle(config connection.ConnectionConfig) error {
|
||||
runConfig := prepareOceanBaseOracleConfig(config)
|
||||
if strings.TrimSpace(runConfig.Database) == "" {
|
||||
return fmt.Errorf("OceanBase Oracle 协议需要填写服务名(Service Name),请在连接配置中填写租户监听的服务名")
|
||||
func formatOceanBaseAttemptError(address string, protocol string, err error) string {
|
||||
if protocol == oceanBaseProtocolMySQL {
|
||||
return formatOceanBaseMySQLAttemptError(address, err)
|
||||
}
|
||||
oracleDB := &OracleDB{}
|
||||
if err := oracleDB.Connect(runConfig); err != nil {
|
||||
return fmt.Errorf("OceanBase Oracle 协议连接失败:%w", err)
|
||||
return fmt.Sprintf("%s 验证失败: %v", address, err)
|
||||
}
|
||||
|
||||
func (o *OceanBaseDB) bindConnectedDatabase(db *sql.DB, timeout time.Duration, protocol string) {
|
||||
o.oracle = nil
|
||||
o.conn = nil
|
||||
o.pingTimeout = 0
|
||||
if protocol == oceanBaseProtocolOracle {
|
||||
o.oracle = &OracleDB{conn: db, pingTimeout: timeout}
|
||||
o.protocol = oceanBaseProtocolOracle
|
||||
return
|
||||
}
|
||||
o.oracle = oracleDB
|
||||
o.protocol = oceanBaseProtocolOracle
|
||||
return nil
|
||||
o.conn = db
|
||||
o.pingTimeout = timeout
|
||||
o.protocol = oceanBaseProtocolMySQL
|
||||
}
|
||||
|
||||
func (o *OceanBaseDB) Connect(config connection.ConnectionConfig) error {
|
||||
o.oracle = nil
|
||||
o.conn = nil
|
||||
o.protocol = oceanBaseProtocolMySQL
|
||||
appliedConfig := applyOceanBaseURI(config)
|
||||
protocol, err := resolveOceanBaseProtocol(appliedConfig)
|
||||
@@ -314,8 +296,7 @@ func (o *OceanBaseDB) Connect(config connection.ConnectionConfig) error {
|
||||
}
|
||||
runConfig := withoutOceanBaseProtocolParams(appliedConfig)
|
||||
if protocol == oceanBaseProtocolOracle {
|
||||
logger.Infof("OceanBase 使用 Oracle 协议连接:地址=%s:%d 用户=%s", runConfig.Host, runConfig.Port, runConfig.User)
|
||||
return o.connectOracle(runConfig)
|
||||
logger.Infof("OceanBase 使用 Oracle 租户模式连接:地址=%s:%d 用户=%s(连接层使用 OceanBase MySQL 兼容协议)", runConfig.Host, runConfig.Port, runConfig.User)
|
||||
}
|
||||
|
||||
addresses := collectOceanBaseAddresses(runConfig)
|
||||
@@ -351,13 +332,11 @@ func (o *OceanBaseDB) Connect(config connection.ConnectionConfig) error {
|
||||
cancel()
|
||||
if pingErr != nil {
|
||||
_ = db.Close()
|
||||
errorDetails = append(errorDetails, formatOceanBaseMySQLAttemptError(address, pingErr))
|
||||
errorDetails = append(errorDetails, formatOceanBaseAttemptError(address, protocol, pingErr))
|
||||
continue
|
||||
}
|
||||
|
||||
o.conn = db
|
||||
o.pingTimeout = timeout
|
||||
o.protocol = oceanBaseProtocolMySQL
|
||||
o.bindConnectedDatabase(db, timeout, protocol)
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -475,8 +454,137 @@ func (o *OceanBaseDB) GetTriggers(dbName, tableName string) ([]connection.Trigge
|
||||
}
|
||||
|
||||
func (o *OceanBaseDB) ApplyChanges(tableName string, changes connection.ChangeSet) error {
|
||||
if o.protocol == oceanBaseProtocolOracle && o.oracle != nil {
|
||||
return o.applyOracleChangesMySQLWire(tableName, changes)
|
||||
}
|
||||
if applier, ok := o.activeDatabase().(BatchApplier); ok {
|
||||
return applier.ApplyChanges(tableName, changes)
|
||||
}
|
||||
return fmt.Errorf("当前 OceanBase %s 协议不支持 ApplyChanges", o.protocol)
|
||||
}
|
||||
|
||||
func (o *OceanBaseDB) applyOracleChangesMySQLWire(tableName string, changes connection.ChangeSet) error {
|
||||
if o.oracle == nil || o.oracle.conn == nil {
|
||||
return fmt.Errorf("连接未打开")
|
||||
}
|
||||
|
||||
columnTypeMap := o.oracle.loadColumnTypeMap(tableName)
|
||||
|
||||
tx, err := o.oracle.conn.Begin()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer tx.Rollback()
|
||||
|
||||
quoteIdent := func(name string) string {
|
||||
n := strings.TrimSpace(name)
|
||||
n = strings.Trim(n, "\"")
|
||||
n = strings.ReplaceAll(n, "\"", "\"\"")
|
||||
if n == "" {
|
||||
return "\"\""
|
||||
}
|
||||
return `"` + n + `"`
|
||||
}
|
||||
|
||||
schema := ""
|
||||
table := strings.TrimSpace(tableName)
|
||||
if parts := strings.SplitN(table, ".", 2); len(parts) == 2 {
|
||||
schema = strings.TrimSpace(parts[0])
|
||||
table = strings.TrimSpace(parts[1])
|
||||
}
|
||||
|
||||
qualifiedTable := ""
|
||||
if schema != "" {
|
||||
qualifiedTable = fmt.Sprintf("%s.%s", quoteIdent(schema), quoteIdent(table))
|
||||
} else {
|
||||
qualifiedTable = quoteIdent(table)
|
||||
}
|
||||
|
||||
isOracleRowIDLocator := strings.EqualFold(strings.TrimSpace(changes.LocatorStrategy), "oracle-rowid")
|
||||
buildWhere := func(keys map[string]interface{}) ([]string, []interface{}) {
|
||||
var wheres []string
|
||||
var args []interface{}
|
||||
for k, v := range keys {
|
||||
if isOracleRowIDLocator && strings.EqualFold(strings.TrimSpace(k), "ROWID") {
|
||||
wheres = append(wheres, "ROWID = ?")
|
||||
args = append(args, v)
|
||||
continue
|
||||
}
|
||||
wheres = append(wheres, fmt.Sprintf("%s = ?", quoteIdent(k)))
|
||||
args = append(args, normalizeOracleValueForWrite(k, v, columnTypeMap))
|
||||
}
|
||||
return wheres, args
|
||||
}
|
||||
|
||||
for _, pk := range changes.Deletes {
|
||||
wheres, args := buildWhere(pk)
|
||||
if len(wheres) == 0 {
|
||||
continue
|
||||
}
|
||||
query := fmt.Sprintf("DELETE FROM %s WHERE %s", qualifiedTable, strings.Join(wheres, " AND "))
|
||||
res, err := tx.Exec(query, args...)
|
||||
if err != nil {
|
||||
return fmt.Errorf("删除失败:%v", err)
|
||||
}
|
||||
if err := requireSingleRowAffected(res, "删除"); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
for _, update := range changes.Updates {
|
||||
var sets []string
|
||||
var args []interface{}
|
||||
|
||||
for k, v := range update.Values {
|
||||
sets = append(sets, fmt.Sprintf("%s = ?", quoteIdent(k)))
|
||||
args = append(args, normalizeOracleValueForWrite(k, v, columnTypeMap))
|
||||
}
|
||||
|
||||
if len(sets) == 0 {
|
||||
continue
|
||||
}
|
||||
|
||||
wheres, whereArgs := buildWhere(update.Keys)
|
||||
args = append(args, whereArgs...)
|
||||
|
||||
if len(wheres) == 0 {
|
||||
return fmt.Errorf("更新操作需要主键条件")
|
||||
}
|
||||
|
||||
query := fmt.Sprintf("UPDATE %s SET %s WHERE %s", qualifiedTable, strings.Join(sets, ", "), strings.Join(wheres, " AND "))
|
||||
res, err := tx.Exec(query, args...)
|
||||
if err != nil {
|
||||
return fmt.Errorf("更新失败:%v", err)
|
||||
}
|
||||
if err := requireSingleRowAffected(res, "更新"); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
for _, row := range changes.Inserts {
|
||||
var cols []string
|
||||
var placeholders []string
|
||||
var args []interface{}
|
||||
|
||||
for k, v := range row {
|
||||
cols = append(cols, quoteIdent(k))
|
||||
placeholders = append(placeholders, "?")
|
||||
args = append(args, normalizeOracleValueForWrite(k, v, columnTypeMap))
|
||||
}
|
||||
|
||||
if len(cols) == 0 {
|
||||
continue
|
||||
}
|
||||
|
||||
query := fmt.Sprintf("INSERT INTO %s (%s) VALUES (%s)", qualifiedTable, strings.Join(cols, ", "), strings.Join(placeholders, ", "))
|
||||
res, err := tx.Exec(query, args...)
|
||||
if err != nil {
|
||||
return fmt.Errorf("插入失败:%v", err)
|
||||
}
|
||||
if affected, err := res.RowsAffected(); err == nil && affected == 0 {
|
||||
return fmt.Errorf("插入未生效:未影响任何行")
|
||||
}
|
||||
}
|
||||
|
||||
return tx.Commit()
|
||||
}
|
||||
|
||||
@@ -3,6 +3,7 @@
|
||||
package db
|
||||
|
||||
import (
|
||||
"database/sql/driver"
|
||||
"errors"
|
||||
"strings"
|
||||
"testing"
|
||||
@@ -148,51 +149,68 @@ func TestWithoutOceanBaseProtocolParamsStripsDriverMeta(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestPrepareOceanBaseOracleConfigPromotesURIParams(t *testing.T) {
|
||||
func TestOceanBaseOracleProtocolUsesMySQLWireConnection(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
config := prepareOceanBaseOracleConfig(connection.ConnectionConfig{
|
||||
Type: "oceanbase",
|
||||
URI: "oceanbase://sys%40oracle001:pass@127.0.0.1:2881/ORCL?protocol=oracle&CONNECT_TIMEOUT=12&DBA_PRIVILEGE=SYSDBA",
|
||||
ConnectionParams: "protocol=oracle&READ_TIMEOUT=7",
|
||||
})
|
||||
dbConn, state := openOracleRecordingDB(t)
|
||||
state.queryResults["SELECT username FROM all_users ORDER BY username"] = oracleRecordingQueryResult{
|
||||
columns: []string{"USERNAME"},
|
||||
rows: [][]driver.Value{{"SYS"}},
|
||||
}
|
||||
|
||||
if config.Type != "oracle" {
|
||||
t.Fatalf("expected routed type oracle, got %q", config.Type)
|
||||
oceanbaseDB := &OceanBaseDB{}
|
||||
oceanbaseDB.bindConnectedDatabase(dbConn, 0, oceanBaseProtocolOracle)
|
||||
|
||||
if oceanbaseDB.oracle == nil {
|
||||
t.Fatal("expected Oracle metadata wrapper for OceanBase Oracle tenant")
|
||||
}
|
||||
if config.URI != "" {
|
||||
t.Fatalf("expected routed Oracle config to clear oceanbase URI, got %q", config.URI)
|
||||
if oceanbaseDB.conn != nil {
|
||||
t.Fatal("expected MySQLDB connection slot to stay empty for Oracle tenant wrapper")
|
||||
}
|
||||
params := connectionParamsFromText(config.ConnectionParams)
|
||||
if got := params.Get("CONNECT TIMEOUT"); got != "12" {
|
||||
t.Fatalf("expected URI CONNECT_TIMEOUT promoted, got %q in %q", got, config.ConnectionParams)
|
||||
if oceanbaseDB.protocol != oceanBaseProtocolOracle {
|
||||
t.Fatalf("expected protocol oracle, got %q", oceanbaseDB.protocol)
|
||||
}
|
||||
if got := params.Get("READ TIMEOUT"); got != "7" {
|
||||
t.Fatalf("expected explicit READ_TIMEOUT kept, got %q in %q", got, config.ConnectionParams)
|
||||
|
||||
databases, err := oceanbaseDB.GetDatabases()
|
||||
if err != nil {
|
||||
t.Fatalf("GetDatabases() unexpected error: %v", err)
|
||||
}
|
||||
if got := params.Get("DBA PRIVILEGE"); got != "SYSDBA" {
|
||||
t.Fatalf("expected URI DBA_PRIVILEGE promoted, got %q in %q", got, config.ConnectionParams)
|
||||
}
|
||||
if strings.Contains(config.ConnectionParams, "protocol=") {
|
||||
t.Fatalf("expected OceanBase protocol param stripped, got %q", config.ConnectionParams)
|
||||
if len(databases) != 1 || databases[0] != "SYS" {
|
||||
t.Fatalf("GetDatabases() = %#v, want [SYS]", databases)
|
||||
}
|
||||
}
|
||||
|
||||
func TestOceanBaseOracleRequiresServiceName(t *testing.T) {
|
||||
func TestOceanBaseOracleApplyChangesUsesMySQLWirePlaceholders(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
err := (&OceanBaseDB{}).Connect(connection.ConnectionConfig{
|
||||
Type: "oceanbase",
|
||||
Host: "127.0.0.1",
|
||||
Port: 2881,
|
||||
User: "sys@oracle001",
|
||||
ConnectionParams: "protocol=oracle",
|
||||
})
|
||||
if err == nil {
|
||||
t.Fatal("expected missing service name error")
|
||||
dbConn, state := openOracleRecordingDB(t)
|
||||
oceanbaseDB := &OceanBaseDB{}
|
||||
oceanbaseDB.bindConnectedDatabase(dbConn, 0, oceanBaseProtocolOracle)
|
||||
|
||||
changes := connection.ChangeSet{
|
||||
Updates: []connection.UpdateRow{{
|
||||
Keys: map[string]interface{}{
|
||||
"ID": 7,
|
||||
},
|
||||
Values: map[string]interface{}{
|
||||
"NAME": "new-name",
|
||||
},
|
||||
}},
|
||||
}
|
||||
if !strings.Contains(err.Error(), "服务名") {
|
||||
t.Fatalf("expected service name hint, got %v", err)
|
||||
|
||||
if err := oceanbaseDB.ApplyChanges("APP.USERS", changes); err != nil {
|
||||
t.Fatalf("ApplyChanges() unexpected error: %v", err)
|
||||
}
|
||||
|
||||
queries := state.snapshotExecQueries()
|
||||
if len(queries) != 1 {
|
||||
t.Fatalf("expected one exec query, got %#v", queries)
|
||||
}
|
||||
if strings.Contains(queries[0], ":1") {
|
||||
t.Fatalf("expected MySQL wire placeholder style, got %q", queries[0])
|
||||
}
|
||||
if !strings.Contains(queries[0], `"NAME" = ?`) || !strings.Contains(queries[0], `"ID" = ?`) {
|
||||
t.Fatalf("expected question mark placeholders, got %q", queries[0])
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
Reference in New Issue
Block a user