🐛 fix(oracle): 修复 SQL 编辑器事务提交失败

- Oracle DML 托管事务改用固定连接隐式事务

- 提交和回滚通过 COMMIT/ROLLBACK 结束事务

- 覆盖提交、回滚和执行失败回滚场景
This commit is contained in:
Syngnat
2026-06-12 01:42:14 +08:00
parent 4cac8ef3c9
commit fb73769063
3 changed files with 91 additions and 46 deletions

View File

@@ -621,7 +621,7 @@ func TestDBQueryMultiTransactionalKeepsDMLTransactionOpenUntilCommit(t *testing.
}
}
func TestDBQueryMultiTransactionalUsesDriverTransactionForOracle(t *testing.T) {
func TestDBQueryMultiTransactionalUsesImplicitSessionTransactionForOracle(t *testing.T) {
originalNewDatabaseFunc := newDatabaseFunc
t.Cleanup(func() {
newDatabaseFunc = originalNewDatabaseFunc
@@ -629,12 +629,10 @@ func TestDBQueryMultiTransactionalUsesDriverTransactionForOracle(t *testing.T) {
firstStmt := "UPDATE users SET name = 'new' WHERE id = 1"
secondStmt := "DELETE FROM audit_logs WHERE user_id = 1"
fakeDB := &fakeTransactionalDB{
fakeBatchWriteDB: fakeBatchWriteDB{
execAffected: map[string]int64{
firstStmt: 1,
secondStmt: 3,
},
fakeDB := &fakeBatchWriteDB{
execAffected: map[string]int64{
firstStmt: 1,
secondStmt: 3,
},
}
newDatabaseFunc = func(dbType string) (db.Database, error) {
@@ -651,15 +649,15 @@ func TestDBQueryMultiTransactionalUsesDriverTransactionForOracle(t *testing.T) {
if result.TransactionID == "" || !result.TransactionPending {
t.Fatalf("expected pending transaction metadata, got id=%q pending=%v", result.TransactionID, result.TransactionPending)
}
if fakeDB.txSession == nil {
t.Fatal("expected Oracle transactional query to open a driver transaction")
if fakeDB.session == nil {
t.Fatal("expected Oracle transactional query to open a pinned session")
}
if fakeDB.txSession.closed {
if fakeDB.session.closed {
t.Fatal("expected Oracle transaction session to stay open before commit")
}
wantExecs := []string{firstStmt, secondStmt}
if len(fakeDB.execQueries) != len(wantExecs) {
t.Fatalf("expected driver transaction exec queries %#v, got %#v", wantExecs, fakeDB.execQueries)
t.Fatalf("expected implicit transaction exec queries %#v, got %#v", wantExecs, fakeDB.execQueries)
}
for i, want := range wantExecs {
if fakeDB.execQueries[i] != want {
@@ -671,51 +669,51 @@ func TestDBQueryMultiTransactionalUsesDriverTransactionForOracle(t *testing.T) {
if !commitResult.Success {
t.Fatalf("expected Oracle commit success, got failure: %s", commitResult.Message)
}
if fakeDB.txSession.commitCalls != 1 || fakeDB.txSession.rollbackCalls != 0 {
t.Fatalf("expected driver commit only, got commits=%d rollbacks=%d", fakeDB.txSession.commitCalls, fakeDB.txSession.rollbackCalls)
}
if !fakeDB.txSession.closed {
if !fakeDB.session.closed {
t.Fatal("expected Oracle transaction session to close after commit")
}
wantExecs = append(wantExecs, "COMMIT")
if len(fakeDB.execQueries) != len(wantExecs) {
t.Fatalf("expected no textual BEGIN/COMMIT for Oracle, got %#v", fakeDB.execQueries)
t.Fatalf("expected Oracle implicit transaction COMMIT on pinned session, got %#v", fakeDB.execQueries)
}
for i, want := range wantExecs {
if fakeDB.execQueries[i] != want {
t.Fatalf("expected exec query %d = %q, got %q", i, want, fakeDB.execQueries[i])
}
}
}
func TestDBQueryMultiTransactionalOracleDriverTransactionOutlivesAppContextCancellation(t *testing.T) {
func TestDBQueryMultiTransactionalOracleImplicitSessionOutlivesAppContextCancellation(t *testing.T) {
originalNewDatabaseFunc := newDatabaseFunc
t.Cleanup(func() {
newDatabaseFunc = originalNewDatabaseFunc
})
for _, tt := range []struct {
name string
finish func(*App, string) connection.QueryResult
wantCommits int
wantRollbacks int
name string
finish func(*App, string) connection.QueryResult
wantFinalSQL string
}{
{
name: "commit",
finish: func(app *App, transactionID string) connection.QueryResult {
return app.DBCommitTransaction(transactionID)
},
wantCommits: 1,
wantFinalSQL: "COMMIT",
},
{
name: "rollback",
finish: func(app *App, transactionID string) connection.QueryResult {
return app.DBRollbackTransaction(transactionID)
},
wantRollbacks: 1,
wantFinalSQL: "ROLLBACK",
},
} {
t.Run(tt.name, func(t *testing.T) {
stmt := "UPDATE users SET name = 'new' WHERE id = 1"
fakeDB := &fakeTransactionalDB{
fakeBatchWriteDB: fakeBatchWriteDB{
execAffected: map[string]int64{
stmt: 1,
},
fakeDB := &fakeBatchWriteDB{
execAffected: map[string]int64{
stmt: 1,
},
}
newDatabaseFunc = func(dbType string) (db.Database, error) {
@@ -740,15 +738,17 @@ func TestDBQueryMultiTransactionalOracleDriverTransactionOutlivesAppContextCance
if !finishResult.Success {
t.Fatalf("expected Oracle transaction %s success after app context cancellation, got failure: %s", tt.name, finishResult.Message)
}
if fakeDB.txSession.commitCalls != tt.wantCommits || fakeDB.txSession.rollbackCalls != tt.wantRollbacks {
t.Fatalf("expected commits=%d rollbacks=%d, got commits=%d rollbacks=%d",
tt.wantCommits, tt.wantRollbacks, fakeDB.txSession.commitCalls, fakeDB.txSession.rollbackCalls)
if fakeDB.session == nil || !fakeDB.session.closed {
t.Fatal("expected Oracle transaction session to close after finish")
}
if len(fakeDB.execQueries) != 2 || fakeDB.execQueries[0] != stmt || fakeDB.execQueries[1] != tt.wantFinalSQL {
t.Fatalf("expected Oracle implicit transaction to finish with %s, got %#v", tt.wantFinalSQL, fakeDB.execQueries)
}
})
}
}
func TestDBQueryMultiTransactionalRollsBackOracleDriverTransactionOnDMLFailure(t *testing.T) {
func TestDBQueryMultiTransactionalRollsBackOracleImplicitSessionOnDMLFailure(t *testing.T) {
originalNewDatabaseFunc := newDatabaseFunc
t.Cleanup(func() {
newDatabaseFunc = originalNewDatabaseFunc
@@ -756,11 +756,9 @@ func TestDBQueryMultiTransactionalRollsBackOracleDriverTransactionOnDMLFailure(t
firstStmt := "UPDATE users SET name = 'new' WHERE id = 1"
secondStmt := "DELETE FROM audit_logs WHERE user_id = 1"
fakeDB := &fakeTransactionalDB{
fakeBatchWriteDB: fakeBatchWriteDB{
execErr: map[string]error{
secondStmt: errors.New("delete failed"),
},
fakeDB := &fakeBatchWriteDB{
execErr: map[string]error{
secondStmt: errors.New("delete failed"),
},
}
newDatabaseFunc = func(dbType string) (db.Database, error) {
@@ -777,18 +775,20 @@ func TestDBQueryMultiTransactionalRollsBackOracleDriverTransactionOnDMLFailure(t
if result.TransactionID != "" || result.TransactionPending {
t.Fatalf("expected failed transaction not to be exposed, got id=%q pending=%v", result.TransactionID, result.TransactionPending)
}
if fakeDB.txSession == nil {
t.Fatal("expected Oracle transactional query to open a driver transaction")
if fakeDB.session == nil {
t.Fatal("expected Oracle transactional query to open a pinned session")
}
if fakeDB.txSession.commitCalls != 0 || fakeDB.txSession.rollbackCalls != 1 {
t.Fatalf("expected driver rollback only, got commits=%d rollbacks=%d", fakeDB.txSession.commitCalls, fakeDB.txSession.rollbackCalls)
}
if !fakeDB.txSession.closed {
if !fakeDB.session.closed {
t.Fatal("expected failed Oracle transaction session to close")
}
wantExecs := []string{firstStmt, secondStmt}
wantExecs := []string{firstStmt, secondStmt, "ROLLBACK"}
if len(fakeDB.execQueries) != len(wantExecs) {
t.Fatalf("expected no textual BEGIN/ROLLBACK for Oracle, got %#v", fakeDB.execQueries)
t.Fatalf("expected Oracle implicit transaction rollback, got %#v", fakeDB.execQueries)
}
for i, want := range wantExecs {
if fakeDB.execQueries[i] != want {
t.Fatalf("expected exec query %d = %q, got %q", i, want, fakeDB.execQueries[i])
}
}
}

View File

@@ -30,6 +30,13 @@ func (a *App) DBQueryMultiTransactional(config connection.ConnectionConfig, dbNa
}
beginSQL, commitSQL, rollbackSQL, hasTextTransaction := sqlFileBatchTransactionSQL(runConfig.Type)
implicitTextTransaction := false
if implicitCommitSQL, implicitRollbackSQL, ok := sqlEditorImplicitTransactionSQL(runConfig.Type); ok {
commitSQL = implicitCommitSQL
rollbackSQL = implicitRollbackSQL
hasTextTransaction = true
implicitTextTransaction = true
}
dbInst, err := a.getDatabase(runConfig)
if err != nil {
@@ -58,7 +65,21 @@ func (a *App) DBQueryMultiTransactional(config connection.ConnectionConfig, dbNa
transactionCancel context.CancelFunc
startTextTransaction bool
)
if provider, ok := dbInst.(db.TransactionExecerProvider); ok {
if implicitTextTransaction {
provider, ok := dbInst.(db.SessionExecerProvider)
if !ok {
return connection.QueryResult{
Success: false,
Message: fmt.Sprintf("当前数据源(%s不支持 SQL 编辑器托管事务", runConfig.Type),
QueryID: queryID,
}
}
sessionExecer, err = provider.OpenSessionExecer(ctx)
if err != nil {
logger.Error(err, "DBQueryMultiTransactional 打开隐式事务会话失败:%s SQL片段=%q", formatConnSummary(runConfig), sqlSnippet(query))
return connection.QueryResult{Success: false, Message: err.Error(), QueryID: queryID}
}
} else if provider, ok := dbInst.(db.TransactionExecerProvider); ok {
// database/sql rolls back a BeginTx transaction when its context is cancelled.
// SQL editor transactions must outlive the execution RPC and be ended only by
// explicit commit, rollback, or shutdown cleanup.
@@ -276,6 +297,18 @@ func shouldUseManagedSQLTransaction(dbType string, query string) bool {
return hasManagedWrite
}
func sqlEditorImplicitTransactionSQL(dbType string) (commitSQL string, rollbackSQL string, ok bool) {
switch strings.ToLower(strings.TrimSpace(dbType)) {
case "oracle":
// Oracle starts a transaction implicitly on the first DML statement.
// Keeping SQL editor DML on one physical connection avoids database/sql
// Tx context lifecycle ending the transaction before the UI commits it.
return "COMMIT", "ROLLBACK", true
default:
return "", "", false
}
}
func isSQLTransactionControlStatement(stmt string) bool {
switch leadingSQLKeyword(stmt) {
case "begin", "commit", "rollback", "savepoint", "release":

View File

@@ -24,6 +24,7 @@ type OracleDB struct {
forwarder *ssh.LocalForwarder // Store SSH tunnel forwarder
}
var _ SessionExecerProvider = (*OracleDB)(nil)
var _ TransactionExecerProvider = (*OracleDB)(nil)
func (o *OracleDB) getDSN(config connection.ConnectionConfig) string {
@@ -264,6 +265,17 @@ func (o *OracleDB) OpenTransactionExecer(ctx context.Context) (TransactionExecer
return NewSQLTxStatementExecer(tx), nil
}
func (o *OracleDB) OpenSessionExecer(ctx context.Context) (StatementExecer, error) {
if o.conn == nil {
return nil, fmt.Errorf("连接未打开")
}
conn, err := o.conn.Conn(ctx)
if err != nil {
return nil, err
}
return NewSQLConnStatementExecer(conn), nil
}
func (o *OracleDB) GetDatabases() ([]string, error) {
// Oracle treats Users/Schemas as "Databases" in this context
data, _, err := o.Query("SELECT username FROM all_users ORDER BY username")