🐛 fix(transaction): 修复 Oracle 托管事务提交回滚失败

- Oracle 托管事务改为固定物理连接执行 COMMIT/ROLLBACK

- SQL 编辑器事务按归一化方言判断 Oracle 兼容协议

- 补充 Oracle 与 OceanBase Oracle 事务回归测试
This commit is contained in:
Syngnat
2026-06-12 02:51:01 +08:00
parent 453e13c88d
commit 8a0dc3a7d3
5 changed files with 282 additions and 14 deletions

View File

@@ -683,6 +683,68 @@ func TestDBQueryMultiTransactionalUsesImplicitSessionTransactionForOracle(t *tes
}
}
func TestDBQueryMultiTransactionalUsesOracleImplicitSessionForOceanBaseOracleProtocol(t *testing.T) {
originalNewDatabaseFunc := newDatabaseFunc
originalVerifyDriverAgentRevisionFunc := verifyDriverAgentRevisionFunc
t.Cleanup(func() {
newDatabaseFunc = originalNewDatabaseFunc
verifyDriverAgentRevisionFunc = originalVerifyDriverAgentRevisionFunc
})
stmt := "UPDATE USERS SET NAME = 'new' WHERE ID = 1"
fakeDB := &fakeBatchWriteDB{
execAffected: map[string]int64{
stmt: 1,
},
}
newDatabaseFunc = func(dbType string) (db.Database, error) {
return fakeDB, nil
}
verifyDriverAgentRevisionFunc = func(config connection.ConnectionConfig) error {
return nil
}
app := NewAppWithSecretStore(secretstore.NewUnavailableStore("test"))
config := connection.ConnectionConfig{
Type: "oceanbase",
Host: "127.0.0.1",
Port: 2881,
User: "app",
OceanBaseProtocol: "oracle",
}
result := app.DBQueryMultiTransactional(config, "APP", stmt, "ob-oracle-tx-query")
if !result.Success {
t.Fatalf("expected OceanBase Oracle transactional query success, got failure: %s", result.Message)
}
if result.TransactionID == "" || !result.TransactionPending {
t.Fatalf("expected pending transaction metadata, got id=%q pending=%v", result.TransactionID, result.TransactionPending)
}
if fakeDB.session == nil {
t.Fatal("expected OceanBase Oracle transactional query to open a pinned Oracle-style session")
}
if fakeDB.session.closed {
t.Fatal("expected OceanBase Oracle transaction session to stay open before commit")
}
if len(fakeDB.execQueries) != 1 || fakeDB.execQueries[0] != stmt {
t.Fatalf("expected OceanBase Oracle to skip START TRANSACTION and execute only DML before commit, got %#v", fakeDB.execQueries)
}
commitResult := app.DBCommitTransaction(result.TransactionID)
if !commitResult.Success {
t.Fatalf("expected OceanBase Oracle commit success, got failure: %s", commitResult.Message)
}
wantExecs := []string{stmt, "COMMIT"}
if len(fakeDB.execQueries) != len(wantExecs) {
t.Fatalf("expected OceanBase Oracle implicit transaction COMMIT on pinned session, got %#v", fakeDB.execQueries)
}
for i, want := range wantExecs {
if fakeDB.execQueries[i] != want {
t.Fatalf("expected exec query %d = %q, got %q", i, want, fakeDB.execQueries[i])
}
}
}
func TestDBQueryMultiTransactionalOracleImplicitSessionOutlivesAppContextCancellation(t *testing.T) {
originalNewDatabaseFunc := newDatabaseFunc
t.Cleanup(func() {

View File

@@ -19,19 +19,22 @@ const sqlEditorTransactionFinishTimeout = 30 * time.Second
// is called by the SQL editor UI.
func (a *App) DBQueryMultiTransactional(config connection.ConnectionConfig, dbName string, query string, queryID string) connection.QueryResult {
runConfig := normalizeRunConfig(config, dbName)
transactionDBType := resolveDDLDBType(runConfig)
transactionConfig := runConfig
transactionConfig.Type = transactionDBType
if queryID == "" {
queryID = generateQueryID()
}
query = sanitizeSQLForPgLike(resolveDDLDBType(config), query)
if !shouldUseManagedSQLTransaction(runConfig.Type, query) {
query = sanitizeSQLForPgLike(transactionDBType, query)
if !shouldUseManagedSQLTransaction(transactionDBType, query) {
return a.DBQueryMulti(config, dbName, query, queryID)
}
beginSQL, commitSQL, rollbackSQL, hasTextTransaction := sqlFileBatchTransactionSQL(runConfig.Type)
beginSQL, commitSQL, rollbackSQL, hasTextTransaction := sqlFileBatchTransactionSQL(transactionDBType)
implicitTextTransaction := false
if implicitCommitSQL, implicitRollbackSQL, ok := sqlEditorImplicitTransactionSQL(runConfig.Type); ok {
if implicitCommitSQL, implicitRollbackSQL, ok := sqlEditorImplicitTransactionSQL(transactionDBType); ok {
commitSQL = implicitCommitSQL
rollbackSQL = implicitRollbackSQL
hasTextTransaction = true
@@ -70,7 +73,7 @@ func (a *App) DBQueryMultiTransactional(config connection.ConnectionConfig, dbNa
if !ok {
return connection.QueryResult{
Success: false,
Message: fmt.Sprintf("当前数据源(%s不支持 SQL 编辑器托管事务", runConfig.Type),
Message: fmt.Sprintf("当前数据源(%s不支持 SQL 编辑器托管事务", transactionDBType),
QueryID: queryID,
}
}
@@ -97,7 +100,7 @@ func (a *App) DBQueryMultiTransactional(config connection.ConnectionConfig, dbNa
if !hasTextTransaction {
return connection.QueryResult{
Success: false,
Message: fmt.Sprintf("当前数据源(%s不支持 SQL 编辑器托管事务", runConfig.Type),
Message: fmt.Sprintf("当前数据源(%s不支持 SQL 编辑器托管事务", transactionDBType),
QueryID: queryID,
}
}
@@ -105,7 +108,7 @@ func (a *App) DBQueryMultiTransactional(config connection.ConnectionConfig, dbNa
if !ok {
return connection.QueryResult{
Success: false,
Message: fmt.Sprintf("当前数据源(%s不支持 SQL 编辑器托管事务", runConfig.Type),
Message: fmt.Sprintf("当前数据源(%s不支持 SQL 编辑器托管事务", transactionDBType),
QueryID: queryID,
}
}
@@ -137,7 +140,7 @@ func (a *App) DBQueryMultiTransactional(config connection.ConnectionConfig, dbNa
}
statements := splitSQLStatements(query)
resultSets, err := executeManagedSQLTransactionStatements(ctx, sessionExecer, runConfig, statements)
resultSets, err := executeManagedSQLTransactionStatements(ctx, sessionExecer, transactionConfig, statements)
if err != nil {
var rollbackErr error
if transactor != nil {
@@ -163,7 +166,7 @@ func (a *App) DBQueryMultiTransactional(config connection.ConnectionConfig, dbNa
execer: sessionExecer,
transactor: transactor,
cancel: transactionCancel,
dbType: runConfig.Type,
dbType: transactionDBType,
commitSQL: commitSQL,
rollbackSQL: rollbackSQL,
createdAt: time.Now(),

View File

@@ -128,8 +128,8 @@ type TransactionExecer interface {
Rollback() error
}
// TransactionExecerProvider is implemented by drivers that can expose a real
// database/sql transaction for long-running SQL editor managed transactions.
// TransactionExecerProvider is implemented by drivers that can expose a
// long-running SQL editor managed transaction.
type TransactionExecerProvider interface {
OpenTransactionExecer(ctx context.Context) (TransactionExecer, error)
}
@@ -200,6 +200,141 @@ func (e *sqlConnStatementExecer) Close() error {
return e.conn.Close()
}
type sqlConnTransactionExecer struct {
mu sync.Mutex
conn *sql.Conn
done bool
commitSQL string
rollbackSQL string
}
func NewSQLConnTransactionExecer(conn *sql.Conn, commitSQL string, rollbackSQL string) TransactionExecer {
return &sqlConnTransactionExecer{
conn: conn,
commitSQL: strings.TrimSpace(commitSQL),
rollbackSQL: strings.TrimSpace(rollbackSQL),
}
}
func (e *sqlConnTransactionExecer) activeConn() (*sql.Conn, error) {
if e == nil {
return nil, fmt.Errorf("连接未打开")
}
e.mu.Lock()
defer e.mu.Unlock()
if e.conn == nil {
return nil, fmt.Errorf("连接未打开")
}
if e.done {
return nil, fmt.Errorf("事务已结束")
}
return e.conn, nil
}
func (e *sqlConnTransactionExecer) ExecContext(ctx context.Context, query string) (int64, error) {
conn, err := e.activeConn()
if err != nil {
return 0, err
}
res, err := conn.ExecContext(ctx, query)
if err != nil {
return 0, err
}
return res.RowsAffected()
}
func (e *sqlConnTransactionExecer) Exec(query string) (int64, error) {
return e.ExecContext(context.Background(), query)
}
func (e *sqlConnTransactionExecer) QueryContext(ctx context.Context, query string) ([]map[string]interface{}, []string, error) {
conn, err := e.activeConn()
if err != nil {
return nil, nil, err
}
rows, err := conn.QueryContext(ctx, query)
if err != nil {
return nil, nil, err
}
defer rows.Close()
return scanRows(rows)
}
func (e *sqlConnTransactionExecer) Query(query string) ([]map[string]interface{}, []string, error) {
return e.QueryContext(context.Background(), query)
}
func (e *sqlConnTransactionExecer) QueryMultiContext(ctx context.Context, query string) ([]connection.ResultSetData, error) {
conn, err := e.activeConn()
if err != nil {
return nil, err
}
rows, err := conn.QueryContext(ctx, query)
if err != nil {
return nil, err
}
defer rows.Close()
return scanMultiRows(rows)
}
func (e *sqlConnTransactionExecer) QueryMulti(query string) ([]connection.ResultSetData, error) {
return e.QueryMultiContext(context.Background(), query)
}
func (e *sqlConnTransactionExecer) finish(sqlText string) error {
if e == nil {
return nil
}
e.mu.Lock()
if e.conn == nil || e.done {
e.mu.Unlock()
return nil
}
conn := e.conn
e.done = true
e.mu.Unlock()
if strings.TrimSpace(sqlText) == "" {
return nil
}
_, err := conn.ExecContext(context.Background(), sqlText)
return err
}
func (e *sqlConnTransactionExecer) Commit() error {
return e.finish(e.commitSQL)
}
func (e *sqlConnTransactionExecer) Rollback() error {
return e.finish(e.rollbackSQL)
}
func (e *sqlConnTransactionExecer) Close() error {
if e == nil {
return nil
}
e.mu.Lock()
if e.conn == nil {
e.mu.Unlock()
return nil
}
conn := e.conn
shouldRollback := !e.done && e.rollbackSQL != ""
rollbackSQL := e.rollbackSQL
e.conn = nil
e.done = true
e.mu.Unlock()
var rollbackErr error
if shouldRollback {
_, rollbackErr = conn.ExecContext(context.Background(), rollbackSQL)
}
closeErr := conn.Close()
if rollbackErr != nil {
return rollbackErr
}
return closeErr
}
type sqlTxStatementExecer struct {
mu sync.Mutex
tx *sql.Tx

View File

@@ -6,6 +6,7 @@ import (
"database/sql/driver"
"fmt"
"io"
"reflect"
"strings"
"sync"
"testing"
@@ -28,6 +29,7 @@ type oracleRecordingState struct {
execQueries []string
execArgs [][]driver.NamedValue
queries []string
beginCalls int
rowsAffected int64
queryResults map[string]oracleRecordingQueryResult
queryError error
@@ -61,6 +63,12 @@ func (s *oracleRecordingState) snapshotQueries() []string {
return append([]string(nil), s.queries...)
}
func (s *oracleRecordingState) snapshotBeginCalls() int {
s.mu.Lock()
defer s.mu.Unlock()
return s.beginCalls
}
type oracleRecordingDriver struct{}
func (oracleRecordingDriver) Open(name string) (driver.Conn, error) {
@@ -83,7 +91,12 @@ func (c *oracleRecordingConn) Prepare(query string) (driver.Stmt, error) {
func (c *oracleRecordingConn) Close() error { return nil }
func (c *oracleRecordingConn) Begin() (driver.Tx, error) { return oracleRecordingTx{}, nil }
func (c *oracleRecordingConn) Begin() (driver.Tx, error) {
c.state.mu.Lock()
c.state.beginCalls++
c.state.mu.Unlock()
return oracleRecordingTx{}, nil
}
func (c *oracleRecordingConn) ExecContext(_ context.Context, query string, args []driver.NamedValue) (driver.Result, error) {
c.state.mu.Lock()
@@ -191,6 +204,61 @@ func openOracleRecordingDB(t *testing.T) (*sql.DB, *oracleRecordingState) {
return dbConn, state
}
func TestOracleOpenTransactionExecerUsesPinnedSessionTransactionSQL(t *testing.T) {
t.Parallel()
for _, tt := range []struct {
name string
finish func(TransactionExecer) error
wantFinalSQL string
}{
{
name: "commit",
finish: func(tx TransactionExecer) error {
return tx.Commit()
},
wantFinalSQL: "COMMIT",
},
{
name: "rollback",
finish: func(tx TransactionExecer) error {
return tx.Rollback()
},
wantFinalSQL: "ROLLBACK",
},
} {
t.Run(tt.name, func(t *testing.T) {
t.Parallel()
dbConn, state := openOracleRecordingDB(t)
oracleDB := &OracleDB{conn: dbConn}
stmt := "UPDATE USERS SET NAME = 'new' WHERE ID = 1"
tx, err := oracleDB.OpenTransactionExecer(context.Background())
if err != nil {
t.Fatalf("OpenTransactionExecer returned error: %v", err)
}
if _, err := tx.ExecContext(context.Background(), stmt); err != nil {
t.Fatalf("ExecContext returned error: %v", err)
}
if err := tt.finish(tx); err != nil {
t.Fatalf("finish returned error: %v", err)
}
if err := tx.Close(); err != nil {
t.Fatalf("Close returned error: %v", err)
}
if got := state.snapshotBeginCalls(); got != 0 {
t.Fatalf("expected Oracle transaction execer not to call database/sql Begin, got %d", got)
}
wantExecs := []string{stmt, tt.wantFinalSQL}
if got := state.snapshotExecQueries(); !reflect.DeepEqual(got, wantExecs) {
t.Fatalf("expected exec queries %#v, got %#v", wantExecs, got)
}
})
}
}
func TestOracleApplyChangesReturnsErrorWhenUpdateMatchesNoRows(t *testing.T) {
t.Parallel()

View File

@@ -258,11 +258,11 @@ func (o *OracleDB) OpenTransactionExecer(ctx context.Context) (TransactionExecer
if o.conn == nil {
return nil, fmt.Errorf("连接未打开")
}
tx, err := o.conn.BeginTx(ctx, nil)
conn, err := o.conn.Conn(ctx)
if err != nil {
return nil, err
}
return NewSQLTxStatementExecer(tx), nil
return NewSQLConnTransactionExecer(conn, "COMMIT", "ROLLBACK"), nil
}
func (o *OracleDB) OpenSessionExecer(ctx context.Context) (StatementExecer, error) {