Files
MyGoNavi/internal/db/oracle_applychanges_test.go
Syngnat f94a0429d5 🐛 fix(oceanbase): 解决 Oracle 租户 MySQL wire 下双引号被误解析与列元数据静默失败
- DSN 注入 sql_mode='ANSI_QUOTES':让元数据查询的 AS "OWNER" 与 ApplyChanges 的 "schema"."table" 在 MySQL wire 上被识别为标识符
- sql_mode 加入 mysql driver 参数白名单,避免被 mergeMySQLConnectionParam 过滤丢弃
- 加载 Oracle 列元数据失败不再静默,改为返回带 ALL_TAB_COLUMNS 诊断提示的明确错误
- 修复 stripOceanBaseConnectionParamsForCache 未剥离 # 片段导致与 resolveOceanBaseProtocolParam 行为不一致
- 锁定 mysql ParseDSN 对 sys@oracle001#cluster:p@ss 类租户凭据切分的 invariant,防止未来误加 url.QueryEscape
- 同步 OceanBase agent revision,强制旧 driver-agent 被运行时校验拒绝
2026-05-14 12:09:19 +08:00

382 lines
9.9 KiB
Go
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
package db
import (
"context"
"database/sql"
"database/sql/driver"
"fmt"
"io"
"strings"
"sync"
"testing"
"time"
"GoNavi-Wails/internal/connection"
)
// oracleRecordingDriverName is the name under which the in-memory recording
// driver is registered with database/sql and later passed to sql.Open.
const oracleRecordingDriverName = "gonavi_oracle_recording"
// Package-level registry shared by every test that opens a recording database.
var (
// registerOracleRecordingDriverOnce ensures sql.Register is called only once.
registerOracleRecordingDriverOnce sync.Once
// oracleRecordingDriverMu guards oracleRecordingDriverSeq and
// oracleRecordingDriverStates.
oracleRecordingDriverMu sync.Mutex
// oracleRecordingDriverSeq is incremented for each opened database and is
// embedded in the generated DSN to make it unique.
oracleRecordingDriverSeq int
// oracleRecordingDriverStates maps a generated DSN to the state that the
// driver's Open method attaches to connections for that DSN.
oracleRecordingDriverStates = map[string]*oracleRecordingState{}
)
// oracleRecordingState holds both the scripted behaviour and the recorded
// calls for one recording DSN. All fields are accessed under mu by the
// driver's connection methods.
type oracleRecordingState struct {
	mu sync.Mutex
	// execQueries is the SQL text of every ExecContext call, in call order.
	execQueries []string
	// execArgs holds a copy of the bound arguments of every ExecContext call,
	// index-aligned with execQueries.
	execArgs [][]driver.NamedValue
	// rowsAffected is the value every ExecContext call reports back.
	rowsAffected int64
	// queryResults maps an exact query text to the rows QueryContext returns.
	queryResults map[string]oracleRecordingQueryResult
	// queryError, when non-nil, is returned by every QueryContext call.
	queryError error
}

// oracleRecordingQueryResult is one scripted result set: column names plus the
// row values.
type oracleRecordingQueryResult struct {
	columns []string
	rows    [][]driver.Value
}

// snapshotExecQueries returns a copy of the recorded statement texts so the
// caller can inspect them without holding the lock. It returns nil when
// nothing has been recorded.
func (s *oracleRecordingState) snapshotExecQueries() []string {
	s.mu.Lock()
	defer s.mu.Unlock()

	if len(s.execQueries) == 0 {
		return nil
	}
	queries := make([]string, len(s.execQueries))
	copy(queries, s.execQueries)
	return queries
}

// snapshotExecArgs returns a copy of the recorded argument lists, one entry per
// recorded statement. The outer slice is always non-nil; each inner slice is a
// fresh copy (nil when the statement had no arguments).
func (s *oracleRecordingState) snapshotExecArgs() [][]driver.NamedValue {
	s.mu.Lock()
	defer s.mu.Unlock()

	snapshot := make([][]driver.NamedValue, 0, len(s.execArgs))
	for _, recorded := range s.execArgs {
		snapshot = append(snapshot, append([]driver.NamedValue(nil), recorded...))
	}
	return snapshot
}
// oracleRecordingDriver is the driver.Driver registered under
// oracleRecordingDriverName. It carries no state of its own; per-database
// state is looked up by DSN.
type oracleRecordingDriver struct{}

// Open returns a connection bound to the state registered for the DSN name.
// It fails when no state has been registered under that name.
func (oracleRecordingDriver) Open(name string) (driver.Conn, error) {
	oracleRecordingDriverMu.Lock()
	state, found := oracleRecordingDriverStates[name]
	oracleRecordingDriverMu.Unlock()

	if !found || state == nil {
		return nil, fmt.Errorf("recording state not found: %s", name)
	}
	conn := &oracleRecordingConn{state: state}
	return conn, nil
}
// oracleRecordingConn is a driver.Conn whose Exec and Query calls are served
// from, and recorded into, the shared state.
type oracleRecordingConn struct {
	state *oracleRecordingState
}

// Prepare always fails: the recording driver does not support prepared
// statements.
func (c *oracleRecordingConn) Prepare(query string) (driver.Stmt, error) {
	err := fmt.Errorf("prepare not supported in oracle recording driver: %s", query)
	return nil, err
}

// Close is a no-op.
func (c *oracleRecordingConn) Close() error {
	return nil
}

// Begin hands back a transaction whose Commit and Rollback do nothing.
func (c *oracleRecordingConn) Begin() (driver.Tx, error) {
	var tx oracleRecordingTx
	return tx, nil
}
// ExecContext records the statement text and a private copy of its arguments,
// then reports the state's configured rowsAffected as the result. It never
// returns an error.
func (c *oracleRecordingConn) ExecContext(_ context.Context, query string, args []driver.NamedValue) (driver.Result, error) {
	// Copy the arguments so later changes to the caller's slice cannot alter
	// what was recorded.
	recordedArgs := append([]driver.NamedValue(nil), args...)

	state := c.state
	state.mu.Lock()
	defer state.mu.Unlock()

	state.execQueries = append(state.execQueries, query)
	state.execArgs = append(state.execArgs, recordedArgs)
	return driver.RowsAffected(state.rowsAffected), nil
}
// QueryContext serves a query from the scripted state. In order of precedence:
//  1. a configured queryError is returned as-is;
//  2. a result registered under the exact query text is returned as a copy;
//  3. any query whose lower-cased text contains "tab_columns" gets a built-in
//     column-metadata result describing UPDATED_AT and CREATED_AT;
//  4. everything else yields an empty result set.
//
// The bound arguments are ignored.
func (c *oracleRecordingConn) QueryContext(_ context.Context, query string, _ []driver.NamedValue) (driver.Rows, error) {
	// Read the scripted outcome under the lock; copying happens afterwards.
	scripted, hit, err := func() (oracleRecordingQueryResult, bool, error) {
		c.state.mu.Lock()
		defer c.state.mu.Unlock()

		if c.state.queryError != nil {
			return oracleRecordingQueryResult{}, false, c.state.queryError
		}
		res, ok := c.state.queryResults[query]
		return res, ok, nil
	}()

	switch {
	case err != nil:
		return nil, err
	case hit:
		rows := &oracleRecordingRows{
			columns: append([]string(nil), scripted.columns...),
			rows:    cloneOracleRecordingRows(scripted.rows),
		}
		return rows, nil
	case strings.Contains(strings.ToLower(query), "tab_columns"):
		// Default column metadata used when a test did not script one.
		rows := &oracleRecordingRows{
			columns: []string{"COLUMN_NAME", "DATA_TYPE", "NULLABLE", "DATA_DEFAULT"},
			rows: [][]driver.Value{
				{"UPDATED_AT", "TIMESTAMP", "YES", nil},
				{"CREATED_AT", "DATE", "NO", nil},
			},
		}
		return rows, nil
	default:
		return &oracleRecordingRows{}, nil
	}
}
// cloneOracleRecordingRows copies the outer slice and every row slice of src so
// the result shares no slice storage with src. The returned outer slice is
// always non-nil; an empty or nil source row becomes a nil row.
func cloneOracleRecordingRows(src [][]driver.Value) [][]driver.Value {
	cloned := make([][]driver.Value, len(src))
	for i := range src {
		cloned[i] = append([]driver.Value(nil), src[i]...)
	}
	return cloned
}
// Compile-time assertions that *oracleRecordingConn implements
// driver.ExecerContext and driver.QueryerContext.
var _ driver.ExecerContext = (*oracleRecordingConn)(nil)
var _ driver.QueryerContext = (*oracleRecordingConn)(nil)
// oracleRecordingTx is the transaction handed out by the recording
// connection's Begin. It holds no state and both outcomes are no-ops.
type oracleRecordingTx struct{}

// Commit does nothing and always succeeds.
func (oracleRecordingTx) Commit() error {
	return nil
}

// Rollback does nothing and always succeeds.
func (oracleRecordingTx) Rollback() error {
	return nil
}
// oracleRecordingRows is an in-memory driver.Rows over a fixed set of scripted
// rows.
type oracleRecordingRows struct {
	columns []string
	rows    [][]driver.Value
	// index is the position of the next row that Next will return.
	index int
}

// Columns returns a copy of the column names so callers cannot modify the
// scripted ones.
func (r *oracleRecordingRows) Columns() []string {
	return append([]string(nil), r.columns...)
}

// Close is a no-op.
func (r *oracleRecordingRows) Close() error { return nil }

// Next copies the next scripted row into dest and advances the cursor. It
// returns io.EOF once every row has been consumed.
//
// When the scripted row has fewer values than dest has slots, the remaining
// slots are set to nil. The caller may reuse dest between calls, so leaving
// them untouched would leak values from the previous row into a shorter row.
// Values beyond len(dest) are ignored.
func (r *oracleRecordingRows) Next(dest []driver.Value) error {
	if r.index >= len(r.rows) {
		return io.EOF
	}
	row := r.rows[r.index]
	r.index++

	filled := copy(dest, row)
	for i := filled; i < len(dest); i++ {
		dest[i] = nil
	}
	return nil
}
// openOracleRecordingDB opens a *sql.DB backed by the recording driver and
// returns it together with the state that scripts and records its calls.
//
// The driver is registered on first use. Each call gets a fresh state
// (rowsAffected preset to 1, no scripted query results) under a unique DSN.
// A cleanup registered on t closes the database and removes the state from
// the registry.
func openOracleRecordingDB(t *testing.T) (*sql.DB, *oracleRecordingState) {
	t.Helper()

	registerOracleRecordingDriverOnce.Do(func() {
		sql.Register(oracleRecordingDriverName, oracleRecordingDriver{})
	})

	// Allocate the DSN and publish the state in one critical section so the
	// driver's Open can find it.
	var (
		dsn   string
		state *oracleRecordingState
	)
	func() {
		oracleRecordingDriverMu.Lock()
		defer oracleRecordingDriverMu.Unlock()

		oracleRecordingDriverSeq++
		dsn = fmt.Sprintf("oracle-recording-%d", oracleRecordingDriverSeq)
		state = &oracleRecordingState{
			rowsAffected: 1,
			queryResults: map[string]oracleRecordingQueryResult{},
		}
		oracleRecordingDriverStates[dsn] = state
	}()

	dbConn, err := sql.Open(oracleRecordingDriverName, dsn)
	if err != nil {
		t.Fatalf("打开 recording db 失败: %v", err)
	}

	t.Cleanup(func() {
		// The close error is irrelevant for an in-memory driver.
		_ = dbConn.Close()

		oracleRecordingDriverMu.Lock()
		defer oracleRecordingDriverMu.Unlock()
		delete(oracleRecordingDriverStates, dsn)
	})

	return dbConn, state
}
// TestOracleApplyChangesReturnsErrorWhenUpdateMatchesNoRows verifies that
// OracleDB.ApplyChanges returns an error containing "更新未生效" when the
// driver reports zero affected rows for an update.
func TestOracleApplyChangesReturnsErrorWhenUpdateMatchesNoRows(t *testing.T) {
	t.Parallel()

	sqlDB, recorder := openOracleRecordingDB(t)
	recorder.rowsAffected = 0 // every Exec reports that nothing matched

	update := connection.UpdateRow{
		Keys:   map[string]interface{}{"ID": 7},
		Values: map[string]interface{}{"NAME": "new-name"},
	}
	target := &OracleDB{conn: sqlDB}
	err := target.ApplyChanges("MYCIMLED.EDC_LOG", connection.ChangeSet{
		Updates: []connection.UpdateRow{update},
	})

	switch {
	case err == nil:
		t.Fatal("期望更新未匹配到行时返回错误,实际为 nil")
	case !strings.Contains(err.Error(), "更新未生效"):
		t.Fatalf("错误信息应提示更新未生效,实际=%v", err)
	}
}
// TestOracleApplyChangesReturnsErrorWhenUpdateAffectsMultipleRows verifies that
// OracleDB.ApplyChanges returns an error containing "影响了 2 行" when the
// driver reports two affected rows for a single-row update.
func TestOracleApplyChangesReturnsErrorWhenUpdateAffectsMultipleRows(t *testing.T) {
	t.Parallel()

	sqlDB, recorder := openOracleRecordingDB(t)
	recorder.rowsAffected = 2 // every Exec reports two rows touched

	update := connection.UpdateRow{
		Keys:   map[string]interface{}{"ID": 7},
		Values: map[string]interface{}{"NAME": "new-name"},
	}
	target := &OracleDB{conn: sqlDB}
	err := target.ApplyChanges("MYCIMLED.EDC_LOG", connection.ChangeSet{
		Updates: []connection.UpdateRow{update},
	})

	switch {
	case err == nil:
		t.Fatal("期望更新影响多行时返回错误,实际为 nil")
	case !strings.Contains(err.Error(), "影响了 2 行"):
		t.Fatalf("错误信息应提示影响多行,实际=%v", err)
	}
}
// TestOracleApplyChangesReturnsErrorWhenDeleteAffectsMultipleRows verifies that
// OracleDB.ApplyChanges returns an error containing "影响了 2 行" when the
// driver reports two affected rows for a single delete entry.
func TestOracleApplyChangesReturnsErrorWhenDeleteAffectsMultipleRows(t *testing.T) {
	t.Parallel()

	sqlDB, recorder := openOracleRecordingDB(t)
	recorder.rowsAffected = 2 // every Exec reports two rows touched

	staleRow := map[string]interface{}{"STATUS": "stale"}
	target := &OracleDB{conn: sqlDB}
	err := target.ApplyChanges("MYCIMLED.EDC_LOG", connection.ChangeSet{
		Deletes: []map[string]interface{}{staleRow},
	})

	switch {
	case err == nil:
		t.Fatal("期望删除影响多行时返回错误,实际为 nil")
	case !strings.Contains(err.Error(), "影响了 2 行"):
		t.Fatalf("错误信息应提示影响多行,实际=%v", err)
	}
}
// TestOracleApplyChangesNormalizesTemporalStringsForUpdate verifies that
// RFC 3339 style strings supplied for UPDATED_AT (value) and CREATED_AT (key)
// reach the driver as time.Time bind arguments, in that order.
//
// The recording driver's built-in column metadata reports UPDATED_AT as
// TIMESTAMP and CREATED_AT as DATE; presumably ApplyChanges consults it to
// decide which values to convert.
func TestOracleApplyChangesNormalizesTemporalStringsForUpdate(t *testing.T) {
	t.Parallel()

	sqlDB, recorder := openOracleRecordingDB(t)

	update := connection.UpdateRow{
		Keys:   map[string]interface{}{"CREATED_AT": "2026-03-05T10:30:00Z"},
		Values: map[string]interface{}{"UPDATED_AT": "2026-04-01T12:13:14.123456789Z"},
	}
	target := &OracleDB{conn: sqlDB}
	applyErr := target.ApplyChanges("EVENTS", connection.ChangeSet{
		Updates: []connection.UpdateRow{update},
	})
	if applyErr != nil {
		t.Fatalf("ApplyChanges 返回错误: %v", applyErr)
	}

	recorded := recorder.snapshotExecArgs()
	if count := len(recorded); count != 1 {
		t.Fatalf("期望执行 1 条更新,实际 %d 条", count)
	}

	bound := recorded[0]
	if count := len(bound); count != 2 {
		t.Fatalf("期望 2 个绑定参数,实际 %d 个: %#v", count, bound)
	}

	// First bind is the SET value, second is the WHERE key.
	setValue, whereValue := bound[0].Value, bound[1].Value
	if _, isTime := setValue.(time.Time); !isTime {
		t.Fatalf("更新时间字段应绑定为 time.Time实际=%#v(%T)", setValue, setValue)
	}
	if _, isTime := whereValue.(time.Time); !isTime {
		t.Fatalf("日期主键字段应绑定为 time.Time实际=%#v(%T)", whereValue, whereValue)
	}
}
// TestOracleApplyChangesUsesUnquotedRowIDLocator verifies that with the
// "oracle-rowid" locator strategy the single generated statement contains
// `ROWID = :2` and does not contain the quoted form `"ROWID" =`.
func TestOracleApplyChangesUsesUnquotedRowIDLocator(t *testing.T) {
	t.Parallel()

	sqlDB, recorder := openOracleRecordingDB(t)

	update := connection.UpdateRow{
		Keys:   map[string]interface{}{"ROWID": "AAAA"},
		Values: map[string]interface{}{"NAME": "new-name"},
	}
	target := &OracleDB{conn: sqlDB}
	applyErr := target.ApplyChanges("MYCIMLED.EDC_LOG", connection.ChangeSet{
		LocatorStrategy: "oracle-rowid",
		Updates:         []connection.UpdateRow{update},
	})
	if applyErr != nil {
		t.Fatalf("ApplyChanges 返回错误: %v", applyErr)
	}

	statements := recorder.snapshotExecQueries()
	if count := len(statements); count != 1 {
		t.Fatalf("期望执行 1 条更新,实际 %d 条", count)
	}

	stmt := statements[0]
	if !strings.Contains(stmt, "ROWID = :2") {
		t.Fatalf("ROWID 定位条件不正确: %s", stmt)
	}
	if strings.Contains(stmt, "\"ROWID\" =") {
		t.Fatalf("ROWID 不应被当作普通列引用: %s", stmt)
	}
}
// TestMySQLApplyChangesReturnsErrorWhenUpdateAffectsMultipleRows verifies that
// MySQLDB.ApplyChanges returns an error containing "影响了 2 行" when the
// driver reports two affected rows for a single-row update. It reuses the
// recording driver as the backing connection.
func TestMySQLApplyChangesReturnsErrorWhenUpdateAffectsMultipleRows(t *testing.T) {
	t.Parallel()

	sqlDB, recorder := openOracleRecordingDB(t)
	recorder.rowsAffected = 2 // every Exec reports two rows touched

	update := connection.UpdateRow{
		Keys:   map[string]interface{}{"id": 7},
		Values: map[string]interface{}{"name": "new-name"},
	}
	target := &MySQLDB{conn: sqlDB}
	err := target.ApplyChanges("users", connection.ChangeSet{
		Updates: []connection.UpdateRow{update},
	})

	switch {
	case err == nil:
		t.Fatal("期望 MySQL 更新影响多行时返回错误,实际为 nil")
	case !strings.Contains(err.Error(), "影响了 2 行"):
		t.Fatalf("错误信息应提示影响多行,实际=%v", err)
	}
}
// TestPostgresApplyChangesReturnsErrorWhenDeleteAffectsMultipleRows verifies
// that PostgresDB.ApplyChanges returns an error containing "影响了 2 行" when
// the driver reports two affected rows for a single delete entry. It reuses
// the recording driver as the backing connection.
func TestPostgresApplyChangesReturnsErrorWhenDeleteAffectsMultipleRows(t *testing.T) {
	t.Parallel()

	sqlDB, recorder := openOracleRecordingDB(t)
	recorder.rowsAffected = 2 // every Exec reports two rows touched

	doomedRow := map[string]interface{}{"id": 7}
	target := &PostgresDB{conn: sqlDB}
	err := target.ApplyChanges("public.users", connection.ChangeSet{
		Deletes: []map[string]interface{}{doomedRow},
	})

	switch {
	case err == nil:
		t.Fatal("期望 PostgreSQL 删除影响多行时返回错误,实际为 nil")
	case !strings.Contains(err.Error(), "影响了 2 行"):
		t.Fatalf("错误信息应提示影响多行,实际=%v", err)
	}
}