diff --git a/internal/db/driver_agent_revisions_gen.go b/internal/db/driver_agent_revisions_gen.go index 8908843..c969347 100644 --- a/internal/db/driver_agent_revisions_gen.go +++ b/internal/db/driver_agent_revisions_gen.go @@ -5,7 +5,7 @@ package db func init() { optionalDriverAgentRevisions = map[string]string{ "mariadb": "src-1a1cc64f8f92d92b", - "oceanbase": "src-5bcb757b1b85d41e", + "oceanbase": "src-f6f19676bb5102d1", "diros": "src-bcc78fa43671ade5", "sphinx": "src-404765c2fda68c5f", "sqlserver": "src-d9fba1eca0a27c49", diff --git a/internal/db/oceanbase_impl.go b/internal/db/oceanbase_impl.go index 98525f0..56af255 100644 --- a/internal/db/oceanbase_impl.go +++ b/internal/db/oceanbase_impl.go @@ -8,6 +8,7 @@ import ( "fmt" "net/url" "strings" + "time" "GoNavi-Wails/internal/connection" "GoNavi-Wails/internal/logger" @@ -248,33 +249,6 @@ func withoutOceanBaseProtocolParams(config connection.ConnectionConfig) connecti return next } -func promoteOceanBaseOracleURIParams(config connection.ConnectionConfig) connection.ConnectionConfig { - uriParams := connectionParamsFromURI(config.URI, "oceanbase", "mysql") - if len(uriParams) == 0 { - return config - } - for _, key := range []string{"protocol", "oceanBaseProtocol", "oceanbaseProtocol", "tenantMode", "compatMode", "mode"} { - uriParams.Del(key) - } - if len(uriParams) == 0 { - return config - } - merged := url.Values{} - mergeConnectionParamValuesWithAllowlist(merged, uriParams, oracleConnectionParamNames) - mergeConnectionParamValuesWithAllowlist(merged, connectionParamsFromText(config.ConnectionParams), oracleConnectionParamNames) - config.ConnectionParams = merged.Encode() - return config -} - -func prepareOceanBaseOracleConfig(config connection.ConnectionConfig) connection.ConnectionConfig { - runConfig := withoutOceanBaseProtocolParams(applyOceanBaseURI(config)) - runConfig = promoteOceanBaseOracleURIParams(runConfig) - runConfig.Type = "oracle" - // OracleDB 不解析 oceanbase:// URI。连接要素已落到结构化字段和 ConnectionParams。 - runConfig.URI = "" - return runConfig -} - func isOceanBaseOracleTenantMySQLDriverError(err error) bool { if err == nil { return false @@ -290,22 +264,30 @@ func formatOceanBaseMySQLAttemptError(address string, err error) string { return fmt.Sprintf("%s 验证失败: %v", address, err) } -func (o *OceanBaseDB) connectOracle(config connection.ConnectionConfig) error { - runConfig := prepareOceanBaseOracleConfig(config) - if strings.TrimSpace(runConfig.Database) == "" { - return fmt.Errorf("OceanBase Oracle 协议需要填写服务名(Service Name),请在连接配置中填写租户监听的服务名") +func formatOceanBaseAttemptError(address string, protocol string, err error) string { + if protocol == oceanBaseProtocolMySQL { + return formatOceanBaseMySQLAttemptError(address, err) } - oracleDB := &OracleDB{} - if err := oracleDB.Connect(runConfig); err != nil { - return fmt.Errorf("OceanBase Oracle 协议连接失败:%w", err) + return fmt.Sprintf("%s 验证失败: %v", address, err) +} + +func (o *OceanBaseDB) bindConnectedDatabase(db *sql.DB, timeout time.Duration, protocol string) { + o.oracle = nil + o.conn = nil + o.pingTimeout = 0 + if protocol == oceanBaseProtocolOracle { + o.oracle = &OracleDB{conn: db, pingTimeout: timeout} + o.protocol = oceanBaseProtocolOracle + return } - o.oracle = oracleDB - o.protocol = oceanBaseProtocolOracle - return nil + o.conn = db + o.pingTimeout = timeout + o.protocol = oceanBaseProtocolMySQL } func (o *OceanBaseDB) Connect(config connection.ConnectionConfig) error { o.oracle = nil + o.conn = nil o.protocol = oceanBaseProtocolMySQL appliedConfig := applyOceanBaseURI(config) protocol, err := resolveOceanBaseProtocol(appliedConfig) @@ -314,8 +296,7 @@ func (o *OceanBaseDB) Connect(config connection.ConnectionConfig) error { } runConfig := withoutOceanBaseProtocolParams(appliedConfig) if protocol == oceanBaseProtocolOracle { - logger.Infof("OceanBase 使用 Oracle 协议连接:地址=%s:%d 用户=%s", runConfig.Host, runConfig.Port, runConfig.User) - return o.connectOracle(runConfig) + logger.Infof("OceanBase 使用 Oracle 租户模式连接:地址=%s:%d 用户=%s(连接层使用 OceanBase MySQL 兼容协议)", runConfig.Host, runConfig.Port, runConfig.User) } addresses := collectOceanBaseAddresses(runConfig) @@ -351,13 +332,11 @@ func (o *OceanBaseDB) Connect(config connection.ConnectionConfig) error { cancel() if pingErr != nil { _ = db.Close() - errorDetails = append(errorDetails, formatOceanBaseMySQLAttemptError(address, pingErr)) + errorDetails = append(errorDetails, formatOceanBaseAttemptError(address, protocol, pingErr)) continue } - o.conn = db - o.pingTimeout = timeout - o.protocol = oceanBaseProtocolMySQL + o.bindConnectedDatabase(db, timeout, protocol) return nil } @@ -475,8 +454,137 @@ func (o *OceanBaseDB) GetTriggers(dbName, tableName string) ([]connection.Trigge } func (o *OceanBaseDB) ApplyChanges(tableName string, changes connection.ChangeSet) error { + if o.protocol == oceanBaseProtocolOracle && o.oracle != nil { + return o.applyOracleChangesMySQLWire(tableName, changes) + } if applier, ok := o.activeDatabase().(BatchApplier); ok { return applier.ApplyChanges(tableName, changes) } return fmt.Errorf("当前 OceanBase %s 协议不支持 ApplyChanges", o.protocol) } + +func (o *OceanBaseDB) applyOracleChangesMySQLWire(tableName string, changes connection.ChangeSet) error { + if o.oracle == nil || o.oracle.conn == nil { + return fmt.Errorf("连接未打开") + } + + columnTypeMap := o.oracle.loadColumnTypeMap(tableName) + + tx, err := o.oracle.conn.Begin() + if err != nil { + return err + } + defer tx.Rollback() + + quoteIdent := func(name string) string { + n := strings.TrimSpace(name) + n = strings.Trim(n, "\"") + n = strings.ReplaceAll(n, "\"", "\"\"") + if n == "" { + return "\"\"" + } + return `"` + n + `"` + } + + schema := "" + table := strings.TrimSpace(tableName) + if parts := strings.SplitN(table, ".", 2); len(parts) == 2 { + schema = strings.TrimSpace(parts[0]) + table = strings.TrimSpace(parts[1]) + } + + qualifiedTable := "" + if schema != "" { + qualifiedTable = fmt.Sprintf("%s.%s", quoteIdent(schema), quoteIdent(table)) + } else { + qualifiedTable = quoteIdent(table) + } + + isOracleRowIDLocator := strings.EqualFold(strings.TrimSpace(changes.LocatorStrategy), "oracle-rowid") + buildWhere := func(keys map[string]interface{}) ([]string, []interface{}) { + var wheres []string + var args []interface{} + for k, v := range keys { + if isOracleRowIDLocator && strings.EqualFold(strings.TrimSpace(k), "ROWID") { + wheres = append(wheres, "ROWID = ?") + args = append(args, v) + continue + } + wheres = append(wheres, fmt.Sprintf("%s = ?", quoteIdent(k))) + args = append(args, normalizeOracleValueForWrite(k, v, columnTypeMap)) + } + return wheres, args + } + + for _, pk := range changes.Deletes { + wheres, args := buildWhere(pk) + if len(wheres) == 0 { + continue + } + query := fmt.Sprintf("DELETE FROM %s WHERE %s", qualifiedTable, strings.Join(wheres, " AND ")) + res, err := tx.Exec(query, args...) + if err != nil { + return fmt.Errorf("删除失败:%v", err) + } + if err := requireSingleRowAffected(res, "删除"); err != nil { + return err + } + } + + for _, update := range changes.Updates { + var sets []string + var args []interface{} + + for k, v := range update.Values { + sets = append(sets, fmt.Sprintf("%s = ?", quoteIdent(k))) + args = append(args, normalizeOracleValueForWrite(k, v, columnTypeMap)) + } + + if len(sets) == 0 { + continue + } + + wheres, whereArgs := buildWhere(update.Keys) + args = append(args, whereArgs...) + + if len(wheres) == 0 { + return fmt.Errorf("更新操作需要主键条件") + } + + query := fmt.Sprintf("UPDATE %s SET %s WHERE %s", qualifiedTable, strings.Join(sets, ", "), strings.Join(wheres, " AND ")) + res, err := tx.Exec(query, args...) + if err != nil { + return fmt.Errorf("更新失败:%v", err) + } + if err := requireSingleRowAffected(res, "更新"); err != nil { + return err + } + } + + for _, row := range changes.Inserts { + var cols []string + var placeholders []string + var args []interface{} + + for k, v := range row { + cols = append(cols, quoteIdent(k)) + placeholders = append(placeholders, "?") + args = append(args, normalizeOracleValueForWrite(k, v, columnTypeMap)) + } + + if len(cols) == 0 { + continue + } + + query := fmt.Sprintf("INSERT INTO %s (%s) VALUES (%s)", qualifiedTable, strings.Join(cols, ", "), strings.Join(placeholders, ", ")) + res, err := tx.Exec(query, args...) + if err != nil { + return fmt.Errorf("插入失败:%v", err) + } + if affected, err := res.RowsAffected(); err == nil && affected == 0 { + return fmt.Errorf("插入未生效:未影响任何行") + } + } + + return tx.Commit() +} diff --git a/internal/db/oceanbase_impl_test.go b/internal/db/oceanbase_impl_test.go index ff262fb..5ff6ba6 100644 --- a/internal/db/oceanbase_impl_test.go +++ b/internal/db/oceanbase_impl_test.go @@ -3,6 +3,7 @@ package db import ( + "database/sql/driver" "errors" "strings" "testing" @@ -148,51 +149,68 @@ func TestWithoutOceanBaseProtocolParamsStripsDriverMeta(t *testing.T) { } } -func TestPrepareOceanBaseOracleConfigPromotesURIParams(t *testing.T) { +func TestOceanBaseOracleProtocolUsesMySQLWireConnection(t *testing.T) { t.Parallel() - config := prepareOceanBaseOracleConfig(connection.ConnectionConfig{ - Type: "oceanbase", - URI: "oceanbase://sys%40oracle001:pass@127.0.0.1:2881/ORCL?protocol=oracle&CONNECT_TIMEOUT=12&DBA_PRIVILEGE=SYSDBA", - ConnectionParams: "protocol=oracle&READ_TIMEOUT=7", - }) + dbConn, state := openOracleRecordingDB(t) + state.queryResults["SELECT username FROM all_users ORDER BY username"] = oracleRecordingQueryResult{ + columns: []string{"USERNAME"}, + rows: [][]driver.Value{{"SYS"}}, + } - if config.Type != "oracle" { - t.Fatalf("expected routed type oracle, got %q", config.Type) + oceanbaseDB := &OceanBaseDB{} + oceanbaseDB.bindConnectedDatabase(dbConn, 0, oceanBaseProtocolOracle) + + if oceanbaseDB.oracle == nil { + t.Fatal("expected Oracle metadata wrapper for OceanBase Oracle tenant") } - if config.URI != "" { - t.Fatalf("expected routed Oracle config to clear oceanbase URI, got %q", config.URI) + if oceanbaseDB.conn != nil { + t.Fatal("expected MySQLDB connection slot to stay empty for Oracle tenant wrapper") } - params := connectionParamsFromText(config.ConnectionParams) - if got := params.Get("CONNECT TIMEOUT"); got != "12" { - t.Fatalf("expected URI CONNECT_TIMEOUT promoted, got %q in %q", got, config.ConnectionParams) + if oceanbaseDB.protocol != oceanBaseProtocolOracle { + t.Fatalf("expected protocol oracle, got %q", oceanbaseDB.protocol) } - if got := params.Get("READ TIMEOUT"); got != "7" { - t.Fatalf("expected explicit READ_TIMEOUT kept, got %q in %q", got, config.ConnectionParams) + + databases, err := oceanbaseDB.GetDatabases() + if err != nil { + t.Fatalf("GetDatabases() unexpected error: %v", err) } - if got := params.Get("DBA PRIVILEGE"); got != "SYSDBA" { - t.Fatalf("expected URI DBA_PRIVILEGE promoted, got %q in %q", got, config.ConnectionParams) - } - if strings.Contains(config.ConnectionParams, "protocol=") { - t.Fatalf("expected OceanBase protocol param stripped, got %q", config.ConnectionParams) + if len(databases) != 1 || databases[0] != "SYS" { + t.Fatalf("GetDatabases() = %#v, want [SYS]", databases) } } -func TestOceanBaseOracleRequiresServiceName(t *testing.T) { +func TestOceanBaseOracleApplyChangesUsesMySQLWirePlaceholders(t *testing.T) { t.Parallel() - err := (&OceanBaseDB{}).Connect(connection.ConnectionConfig{ - Type: "oceanbase", - Host: "127.0.0.1", - Port: 2881, - User: "sys@oracle001", - ConnectionParams: "protocol=oracle", - }) - if err == nil { - t.Fatal("expected missing service name error") + dbConn, state := openOracleRecordingDB(t) + oceanbaseDB := &OceanBaseDB{} + oceanbaseDB.bindConnectedDatabase(dbConn, 0, oceanBaseProtocolOracle) + + changes := connection.ChangeSet{ + Updates: []connection.UpdateRow{{ + Keys: map[string]interface{}{ + "ID": 7, + }, + Values: map[string]interface{}{ + "NAME": "new-name", + }, + }}, } - if !strings.Contains(err.Error(), "服务名") { - t.Fatalf("expected service name hint, got %v", err) + + if err := oceanbaseDB.ApplyChanges("APP.USERS", changes); err != nil { + t.Fatalf("ApplyChanges() unexpected error: %v", err) + } + + queries := state.snapshotExecQueries() + if len(queries) != 1 { + t.Fatalf("expected one exec query, got %#v", queries) + } + if strings.Contains(queries[0], ":1") { + t.Fatalf("expected MySQL wire placeholder style, got %q", queries[0]) + } + if !strings.Contains(queries[0], `"NAME" = ?`) || !strings.Contains(queries[0], `"ID" = ?`) { + t.Fatalf("expected question mark placeholders, got %q", queries[0]) } }