mirror of
https://github.com/Syngnat/GoNavi.git
synced 2026-06-14 02:19:58 +08:00
🐛 fix(transaction): 修复 Oracle 托管事务提交回滚失败
- Oracle 托管事务改为固定物理连接执行 COMMIT/ROLLBACK - SQL 编辑器事务按归一化方言判断 Oracle 兼容协议 - 补充 Oracle 与 OceanBase Oracle 事务回归测试
This commit is contained in:
@@ -683,6 +683,68 @@ func TestDBQueryMultiTransactionalUsesImplicitSessionTransactionForOracle(t *tes
|
||||
}
|
||||
}
|
||||
|
||||
func TestDBQueryMultiTransactionalUsesOracleImplicitSessionForOceanBaseOracleProtocol(t *testing.T) {
|
||||
originalNewDatabaseFunc := newDatabaseFunc
|
||||
originalVerifyDriverAgentRevisionFunc := verifyDriverAgentRevisionFunc
|
||||
t.Cleanup(func() {
|
||||
newDatabaseFunc = originalNewDatabaseFunc
|
||||
verifyDriverAgentRevisionFunc = originalVerifyDriverAgentRevisionFunc
|
||||
})
|
||||
|
||||
stmt := "UPDATE USERS SET NAME = 'new' WHERE ID = 1"
|
||||
fakeDB := &fakeBatchWriteDB{
|
||||
execAffected: map[string]int64{
|
||||
stmt: 1,
|
||||
},
|
||||
}
|
||||
newDatabaseFunc = func(dbType string) (db.Database, error) {
|
||||
return fakeDB, nil
|
||||
}
|
||||
verifyDriverAgentRevisionFunc = func(config connection.ConnectionConfig) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
app := NewAppWithSecretStore(secretstore.NewUnavailableStore("test"))
|
||||
config := connection.ConnectionConfig{
|
||||
Type: "oceanbase",
|
||||
Host: "127.0.0.1",
|
||||
Port: 2881,
|
||||
User: "app",
|
||||
OceanBaseProtocol: "oracle",
|
||||
}
|
||||
|
||||
result := app.DBQueryMultiTransactional(config, "APP", stmt, "ob-oracle-tx-query")
|
||||
if !result.Success {
|
||||
t.Fatalf("expected OceanBase Oracle transactional query success, got failure: %s", result.Message)
|
||||
}
|
||||
if result.TransactionID == "" || !result.TransactionPending {
|
||||
t.Fatalf("expected pending transaction metadata, got id=%q pending=%v", result.TransactionID, result.TransactionPending)
|
||||
}
|
||||
if fakeDB.session == nil {
|
||||
t.Fatal("expected OceanBase Oracle transactional query to open a pinned Oracle-style session")
|
||||
}
|
||||
if fakeDB.session.closed {
|
||||
t.Fatal("expected OceanBase Oracle transaction session to stay open before commit")
|
||||
}
|
||||
if len(fakeDB.execQueries) != 1 || fakeDB.execQueries[0] != stmt {
|
||||
t.Fatalf("expected OceanBase Oracle to skip START TRANSACTION and execute only DML before commit, got %#v", fakeDB.execQueries)
|
||||
}
|
||||
|
||||
commitResult := app.DBCommitTransaction(result.TransactionID)
|
||||
if !commitResult.Success {
|
||||
t.Fatalf("expected OceanBase Oracle commit success, got failure: %s", commitResult.Message)
|
||||
}
|
||||
wantExecs := []string{stmt, "COMMIT"}
|
||||
if len(fakeDB.execQueries) != len(wantExecs) {
|
||||
t.Fatalf("expected OceanBase Oracle implicit transaction COMMIT on pinned session, got %#v", fakeDB.execQueries)
|
||||
}
|
||||
for i, want := range wantExecs {
|
||||
if fakeDB.execQueries[i] != want {
|
||||
t.Fatalf("expected exec query %d = %q, got %q", i, want, fakeDB.execQueries[i])
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestDBQueryMultiTransactionalOracleImplicitSessionOutlivesAppContextCancellation(t *testing.T) {
|
||||
originalNewDatabaseFunc := newDatabaseFunc
|
||||
t.Cleanup(func() {
|
||||
|
||||
@@ -19,19 +19,22 @@ const sqlEditorTransactionFinishTimeout = 30 * time.Second
|
||||
// is called by the SQL editor UI.
|
||||
func (a *App) DBQueryMultiTransactional(config connection.ConnectionConfig, dbName string, query string, queryID string) connection.QueryResult {
|
||||
runConfig := normalizeRunConfig(config, dbName)
|
||||
transactionDBType := resolveDDLDBType(runConfig)
|
||||
transactionConfig := runConfig
|
||||
transactionConfig.Type = transactionDBType
|
||||
|
||||
if queryID == "" {
|
||||
queryID = generateQueryID()
|
||||
}
|
||||
|
||||
query = sanitizeSQLForPgLike(resolveDDLDBType(config), query)
|
||||
if !shouldUseManagedSQLTransaction(runConfig.Type, query) {
|
||||
query = sanitizeSQLForPgLike(transactionDBType, query)
|
||||
if !shouldUseManagedSQLTransaction(transactionDBType, query) {
|
||||
return a.DBQueryMulti(config, dbName, query, queryID)
|
||||
}
|
||||
|
||||
beginSQL, commitSQL, rollbackSQL, hasTextTransaction := sqlFileBatchTransactionSQL(runConfig.Type)
|
||||
beginSQL, commitSQL, rollbackSQL, hasTextTransaction := sqlFileBatchTransactionSQL(transactionDBType)
|
||||
implicitTextTransaction := false
|
||||
if implicitCommitSQL, implicitRollbackSQL, ok := sqlEditorImplicitTransactionSQL(runConfig.Type); ok {
|
||||
if implicitCommitSQL, implicitRollbackSQL, ok := sqlEditorImplicitTransactionSQL(transactionDBType); ok {
|
||||
commitSQL = implicitCommitSQL
|
||||
rollbackSQL = implicitRollbackSQL
|
||||
hasTextTransaction = true
|
||||
@@ -70,7 +73,7 @@ func (a *App) DBQueryMultiTransactional(config connection.ConnectionConfig, dbNa
|
||||
if !ok {
|
||||
return connection.QueryResult{
|
||||
Success: false,
|
||||
Message: fmt.Sprintf("当前数据源(%s)不支持 SQL 编辑器托管事务", runConfig.Type),
|
||||
Message: fmt.Sprintf("当前数据源(%s)不支持 SQL 编辑器托管事务", transactionDBType),
|
||||
QueryID: queryID,
|
||||
}
|
||||
}
|
||||
@@ -97,7 +100,7 @@ func (a *App) DBQueryMultiTransactional(config connection.ConnectionConfig, dbNa
|
||||
if !hasTextTransaction {
|
||||
return connection.QueryResult{
|
||||
Success: false,
|
||||
Message: fmt.Sprintf("当前数据源(%s)不支持 SQL 编辑器托管事务", runConfig.Type),
|
||||
Message: fmt.Sprintf("当前数据源(%s)不支持 SQL 编辑器托管事务", transactionDBType),
|
||||
QueryID: queryID,
|
||||
}
|
||||
}
|
||||
@@ -105,7 +108,7 @@ func (a *App) DBQueryMultiTransactional(config connection.ConnectionConfig, dbNa
|
||||
if !ok {
|
||||
return connection.QueryResult{
|
||||
Success: false,
|
||||
Message: fmt.Sprintf("当前数据源(%s)不支持 SQL 编辑器托管事务", runConfig.Type),
|
||||
Message: fmt.Sprintf("当前数据源(%s)不支持 SQL 编辑器托管事务", transactionDBType),
|
||||
QueryID: queryID,
|
||||
}
|
||||
}
|
||||
@@ -137,7 +140,7 @@ func (a *App) DBQueryMultiTransactional(config connection.ConnectionConfig, dbNa
|
||||
}
|
||||
|
||||
statements := splitSQLStatements(query)
|
||||
resultSets, err := executeManagedSQLTransactionStatements(ctx, sessionExecer, runConfig, statements)
|
||||
resultSets, err := executeManagedSQLTransactionStatements(ctx, sessionExecer, transactionConfig, statements)
|
||||
if err != nil {
|
||||
var rollbackErr error
|
||||
if transactor != nil {
|
||||
@@ -163,7 +166,7 @@ func (a *App) DBQueryMultiTransactional(config connection.ConnectionConfig, dbNa
|
||||
execer: sessionExecer,
|
||||
transactor: transactor,
|
||||
cancel: transactionCancel,
|
||||
dbType: runConfig.Type,
|
||||
dbType: transactionDBType,
|
||||
commitSQL: commitSQL,
|
||||
rollbackSQL: rollbackSQL,
|
||||
createdAt: time.Now(),
|
||||
|
||||
@@ -128,8 +128,8 @@ type TransactionExecer interface {
|
||||
Rollback() error
|
||||
}
|
||||
|
||||
// TransactionExecerProvider is implemented by drivers that can expose a real
|
||||
// database/sql transaction for long-running SQL editor managed transactions.
|
||||
// TransactionExecerProvider is implemented by drivers that can expose a
|
||||
// long-running SQL editor managed transaction.
|
||||
type TransactionExecerProvider interface {
|
||||
OpenTransactionExecer(ctx context.Context) (TransactionExecer, error)
|
||||
}
|
||||
@@ -200,6 +200,141 @@ func (e *sqlConnStatementExecer) Close() error {
|
||||
return e.conn.Close()
|
||||
}
|
||||
|
||||
type sqlConnTransactionExecer struct {
|
||||
mu sync.Mutex
|
||||
conn *sql.Conn
|
||||
done bool
|
||||
commitSQL string
|
||||
rollbackSQL string
|
||||
}
|
||||
|
||||
func NewSQLConnTransactionExecer(conn *sql.Conn, commitSQL string, rollbackSQL string) TransactionExecer {
|
||||
return &sqlConnTransactionExecer{
|
||||
conn: conn,
|
||||
commitSQL: strings.TrimSpace(commitSQL),
|
||||
rollbackSQL: strings.TrimSpace(rollbackSQL),
|
||||
}
|
||||
}
|
||||
|
||||
func (e *sqlConnTransactionExecer) activeConn() (*sql.Conn, error) {
|
||||
if e == nil {
|
||||
return nil, fmt.Errorf("连接未打开")
|
||||
}
|
||||
e.mu.Lock()
|
||||
defer e.mu.Unlock()
|
||||
if e.conn == nil {
|
||||
return nil, fmt.Errorf("连接未打开")
|
||||
}
|
||||
if e.done {
|
||||
return nil, fmt.Errorf("事务已结束")
|
||||
}
|
||||
return e.conn, nil
|
||||
}
|
||||
|
||||
func (e *sqlConnTransactionExecer) ExecContext(ctx context.Context, query string) (int64, error) {
|
||||
conn, err := e.activeConn()
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
res, err := conn.ExecContext(ctx, query)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
return res.RowsAffected()
|
||||
}
|
||||
|
||||
func (e *sqlConnTransactionExecer) Exec(query string) (int64, error) {
|
||||
return e.ExecContext(context.Background(), query)
|
||||
}
|
||||
|
||||
func (e *sqlConnTransactionExecer) QueryContext(ctx context.Context, query string) ([]map[string]interface{}, []string, error) {
|
||||
conn, err := e.activeConn()
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
rows, err := conn.QueryContext(ctx, query)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
defer rows.Close()
|
||||
return scanRows(rows)
|
||||
}
|
||||
|
||||
func (e *sqlConnTransactionExecer) Query(query string) ([]map[string]interface{}, []string, error) {
|
||||
return e.QueryContext(context.Background(), query)
|
||||
}
|
||||
|
||||
func (e *sqlConnTransactionExecer) QueryMultiContext(ctx context.Context, query string) ([]connection.ResultSetData, error) {
|
||||
conn, err := e.activeConn()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
rows, err := conn.QueryContext(ctx, query)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer rows.Close()
|
||||
return scanMultiRows(rows)
|
||||
}
|
||||
|
||||
func (e *sqlConnTransactionExecer) QueryMulti(query string) ([]connection.ResultSetData, error) {
|
||||
return e.QueryMultiContext(context.Background(), query)
|
||||
}
|
||||
|
||||
func (e *sqlConnTransactionExecer) finish(sqlText string) error {
|
||||
if e == nil {
|
||||
return nil
|
||||
}
|
||||
e.mu.Lock()
|
||||
if e.conn == nil || e.done {
|
||||
e.mu.Unlock()
|
||||
return nil
|
||||
}
|
||||
conn := e.conn
|
||||
e.done = true
|
||||
e.mu.Unlock()
|
||||
if strings.TrimSpace(sqlText) == "" {
|
||||
return nil
|
||||
}
|
||||
_, err := conn.ExecContext(context.Background(), sqlText)
|
||||
return err
|
||||
}
|
||||
|
||||
func (e *sqlConnTransactionExecer) Commit() error {
|
||||
return e.finish(e.commitSQL)
|
||||
}
|
||||
|
||||
func (e *sqlConnTransactionExecer) Rollback() error {
|
||||
return e.finish(e.rollbackSQL)
|
||||
}
|
||||
|
||||
func (e *sqlConnTransactionExecer) Close() error {
|
||||
if e == nil {
|
||||
return nil
|
||||
}
|
||||
e.mu.Lock()
|
||||
if e.conn == nil {
|
||||
e.mu.Unlock()
|
||||
return nil
|
||||
}
|
||||
conn := e.conn
|
||||
shouldRollback := !e.done && e.rollbackSQL != ""
|
||||
rollbackSQL := e.rollbackSQL
|
||||
e.conn = nil
|
||||
e.done = true
|
||||
e.mu.Unlock()
|
||||
|
||||
var rollbackErr error
|
||||
if shouldRollback {
|
||||
_, rollbackErr = conn.ExecContext(context.Background(), rollbackSQL)
|
||||
}
|
||||
closeErr := conn.Close()
|
||||
if rollbackErr != nil {
|
||||
return rollbackErr
|
||||
}
|
||||
return closeErr
|
||||
}
|
||||
|
||||
type sqlTxStatementExecer struct {
|
||||
mu sync.Mutex
|
||||
tx *sql.Tx
|
||||
|
||||
@@ -6,6 +6,7 @@ import (
|
||||
"database/sql/driver"
|
||||
"fmt"
|
||||
"io"
|
||||
"reflect"
|
||||
"strings"
|
||||
"sync"
|
||||
"testing"
|
||||
@@ -28,6 +29,7 @@ type oracleRecordingState struct {
|
||||
execQueries []string
|
||||
execArgs [][]driver.NamedValue
|
||||
queries []string
|
||||
beginCalls int
|
||||
rowsAffected int64
|
||||
queryResults map[string]oracleRecordingQueryResult
|
||||
queryError error
|
||||
@@ -61,6 +63,12 @@ func (s *oracleRecordingState) snapshotQueries() []string {
|
||||
return append([]string(nil), s.queries...)
|
||||
}
|
||||
|
||||
func (s *oracleRecordingState) snapshotBeginCalls() int {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
return s.beginCalls
|
||||
}
|
||||
|
||||
type oracleRecordingDriver struct{}
|
||||
|
||||
func (oracleRecordingDriver) Open(name string) (driver.Conn, error) {
|
||||
@@ -83,7 +91,12 @@ func (c *oracleRecordingConn) Prepare(query string) (driver.Stmt, error) {
|
||||
|
||||
func (c *oracleRecordingConn) Close() error { return nil }
|
||||
|
||||
func (c *oracleRecordingConn) Begin() (driver.Tx, error) { return oracleRecordingTx{}, nil }
|
||||
func (c *oracleRecordingConn) Begin() (driver.Tx, error) {
|
||||
c.state.mu.Lock()
|
||||
c.state.beginCalls++
|
||||
c.state.mu.Unlock()
|
||||
return oracleRecordingTx{}, nil
|
||||
}
|
||||
|
||||
func (c *oracleRecordingConn) ExecContext(_ context.Context, query string, args []driver.NamedValue) (driver.Result, error) {
|
||||
c.state.mu.Lock()
|
||||
@@ -191,6 +204,61 @@ func openOracleRecordingDB(t *testing.T) (*sql.DB, *oracleRecordingState) {
|
||||
return dbConn, state
|
||||
}
|
||||
|
||||
func TestOracleOpenTransactionExecerUsesPinnedSessionTransactionSQL(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
for _, tt := range []struct {
|
||||
name string
|
||||
finish func(TransactionExecer) error
|
||||
wantFinalSQL string
|
||||
}{
|
||||
{
|
||||
name: "commit",
|
||||
finish: func(tx TransactionExecer) error {
|
||||
return tx.Commit()
|
||||
},
|
||||
wantFinalSQL: "COMMIT",
|
||||
},
|
||||
{
|
||||
name: "rollback",
|
||||
finish: func(tx TransactionExecer) error {
|
||||
return tx.Rollback()
|
||||
},
|
||||
wantFinalSQL: "ROLLBACK",
|
||||
},
|
||||
} {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
dbConn, state := openOracleRecordingDB(t)
|
||||
oracleDB := &OracleDB{conn: dbConn}
|
||||
stmt := "UPDATE USERS SET NAME = 'new' WHERE ID = 1"
|
||||
|
||||
tx, err := oracleDB.OpenTransactionExecer(context.Background())
|
||||
if err != nil {
|
||||
t.Fatalf("OpenTransactionExecer returned error: %v", err)
|
||||
}
|
||||
if _, err := tx.ExecContext(context.Background(), stmt); err != nil {
|
||||
t.Fatalf("ExecContext returned error: %v", err)
|
||||
}
|
||||
if err := tt.finish(tx); err != nil {
|
||||
t.Fatalf("finish returned error: %v", err)
|
||||
}
|
||||
if err := tx.Close(); err != nil {
|
||||
t.Fatalf("Close returned error: %v", err)
|
||||
}
|
||||
|
||||
if got := state.snapshotBeginCalls(); got != 0 {
|
||||
t.Fatalf("expected Oracle transaction execer not to call database/sql Begin, got %d", got)
|
||||
}
|
||||
wantExecs := []string{stmt, tt.wantFinalSQL}
|
||||
if got := state.snapshotExecQueries(); !reflect.DeepEqual(got, wantExecs) {
|
||||
t.Fatalf("expected exec queries %#v, got %#v", wantExecs, got)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestOracleApplyChangesReturnsErrorWhenUpdateMatchesNoRows(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
|
||||
@@ -258,11 +258,11 @@ func (o *OracleDB) OpenTransactionExecer(ctx context.Context) (TransactionExecer
|
||||
if o.conn == nil {
|
||||
return nil, fmt.Errorf("连接未打开")
|
||||
}
|
||||
tx, err := o.conn.BeginTx(ctx, nil)
|
||||
conn, err := o.conn.Conn(ctx)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return NewSQLTxStatementExecer(tx), nil
|
||||
return NewSQLConnTransactionExecer(conn, "COMMIT", "ROLLBACK"), nil
|
||||
}
|
||||
|
||||
func (o *OracleDB) OpenSessionExecer(ctx context.Context) (StatementExecer, error) {
|
||||
|
||||
Reference in New Issue
Block a user