From fb737690634d567321035d462986bd4f6b10da4e Mon Sep 17 00:00:00 2001 From: Syngnat Date: Fri, 12 Jun 2026 01:42:14 +0800 Subject: [PATCH] =?UTF-8?q?=F0=9F=90=9B=20fix(oracle):=20=E4=BF=AE?= =?UTF-8?q?=E5=A4=8D=20SQL=20=E7=BC=96=E8=BE=91=E5=99=A8=E4=BA=8B=E5=8A=A1?= =?UTF-8?q?=E6=8F=90=E4=BA=A4=E5=A4=B1=E8=B4=A5?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - Oracle DML 托管事务改用固定连接隐式事务 - 提交和回滚通过 COMMIT/ROLLBACK 结束事务 - 覆盖提交、回滚和执行失败回滚场景 --- internal/app/methods_db_multi_test.go | 90 +++++++++++++------------- internal/app/methods_db_transaction.go | 35 +++++++++- internal/db/oracle_impl.go | 12 ++++ 3 files changed, 91 insertions(+), 46 deletions(-) diff --git a/internal/app/methods_db_multi_test.go b/internal/app/methods_db_multi_test.go index c763291..9b7e164 100644 --- a/internal/app/methods_db_multi_test.go +++ b/internal/app/methods_db_multi_test.go @@ -621,7 +621,7 @@ func TestDBQueryMultiTransactionalKeepsDMLTransactionOpenUntilCommit(t *testing. } } -func TestDBQueryMultiTransactionalUsesDriverTransactionForOracle(t *testing.T) { +func TestDBQueryMultiTransactionalUsesImplicitSessionTransactionForOracle(t *testing.T) { originalNewDatabaseFunc := newDatabaseFunc t.Cleanup(func() { newDatabaseFunc = originalNewDatabaseFunc @@ -629,12 +629,10 @@ func TestDBQueryMultiTransactionalUsesDriverTransactionForOracle(t *testing.T) { firstStmt := "UPDATE users SET name = 'new' WHERE id = 1" secondStmt := "DELETE FROM audit_logs WHERE user_id = 1" - fakeDB := &fakeTransactionalDB{ - fakeBatchWriteDB: fakeBatchWriteDB{ - execAffected: map[string]int64{ - firstStmt: 1, - secondStmt: 3, - }, + fakeDB := &fakeBatchWriteDB{ + execAffected: map[string]int64{ + firstStmt: 1, + secondStmt: 3, }, } newDatabaseFunc = func(dbType string) (db.Database, error) { @@ -651,15 +649,15 @@ func TestDBQueryMultiTransactionalUsesDriverTransactionForOracle(t *testing.T) { if result.TransactionID == "" || !result.TransactionPending { t.Fatalf("expected pending transaction metadata, got id=%q pending=%v", result.TransactionID, result.TransactionPending) } - if fakeDB.txSession == nil { - t.Fatal("expected Oracle transactional query to open a driver transaction") + if fakeDB.session == nil { + t.Fatal("expected Oracle transactional query to open a pinned session") } - if fakeDB.txSession.closed { + if fakeDB.session.closed { t.Fatal("expected Oracle transaction session to stay open before commit") } wantExecs := []string{firstStmt, secondStmt} if len(fakeDB.execQueries) != len(wantExecs) { - t.Fatalf("expected driver transaction exec queries %#v, got %#v", wantExecs, fakeDB.execQueries) + t.Fatalf("expected implicit transaction exec queries %#v, got %#v", wantExecs, fakeDB.execQueries) } for i, want := range wantExecs { if fakeDB.execQueries[i] != want { @@ -671,51 +669,51 @@ func TestDBQueryMultiTransactionalUsesDriverTransactionForOracle(t *testing.T) { if !commitResult.Success { t.Fatalf("expected Oracle commit success, got failure: %s", commitResult.Message) } - if fakeDB.txSession.commitCalls != 1 || fakeDB.txSession.rollbackCalls != 0 { - t.Fatalf("expected driver commit only, got commits=%d rollbacks=%d", fakeDB.txSession.commitCalls, fakeDB.txSession.rollbackCalls) - } - if !fakeDB.txSession.closed { + if !fakeDB.session.closed { t.Fatal("expected Oracle transaction session to close after commit") } + wantExecs = append(wantExecs, "COMMIT") if len(fakeDB.execQueries) != len(wantExecs) { - t.Fatalf("expected no textual BEGIN/COMMIT for Oracle, got %#v", fakeDB.execQueries) + t.Fatalf("expected Oracle implicit transaction COMMIT on pinned session, got %#v", fakeDB.execQueries) + } + for i, want := range wantExecs { + if fakeDB.execQueries[i] != want { + t.Fatalf("expected exec query %d = %q, got %q", i, want, fakeDB.execQueries[i]) + } } } -func TestDBQueryMultiTransactionalOracleDriverTransactionOutlivesAppContextCancellation(t *testing.T) { +func TestDBQueryMultiTransactionalOracleImplicitSessionOutlivesAppContextCancellation(t *testing.T) { originalNewDatabaseFunc := newDatabaseFunc t.Cleanup(func() { newDatabaseFunc = originalNewDatabaseFunc }) for _, tt := range []struct { - name string - finish func(*App, string) connection.QueryResult - wantCommits int - wantRollbacks int + name string + finish func(*App, string) connection.QueryResult + wantFinalSQL string }{ { name: "commit", finish: func(app *App, transactionID string) connection.QueryResult { return app.DBCommitTransaction(transactionID) }, - wantCommits: 1, + wantFinalSQL: "COMMIT", }, { name: "rollback", finish: func(app *App, transactionID string) connection.QueryResult { return app.DBRollbackTransaction(transactionID) }, - wantRollbacks: 1, + wantFinalSQL: "ROLLBACK", }, } { t.Run(tt.name, func(t *testing.T) { stmt := "UPDATE users SET name = 'new' WHERE id = 1" - fakeDB := &fakeTransactionalDB{ - fakeBatchWriteDB: fakeBatchWriteDB{ - execAffected: map[string]int64{ - stmt: 1, - }, + fakeDB := &fakeBatchWriteDB{ + execAffected: map[string]int64{ + stmt: 1, }, } newDatabaseFunc = func(dbType string) (db.Database, error) { @@ -740,15 +738,17 @@ func TestDBQueryMultiTransactionalOracleDriverTransactionOutlivesAppContextCance if !finishResult.Success { t.Fatalf("expected Oracle transaction %s success after app context cancellation, got failure: %s", tt.name, finishResult.Message) } - if fakeDB.txSession.commitCalls != tt.wantCommits || fakeDB.txSession.rollbackCalls != tt.wantRollbacks { - t.Fatalf("expected commits=%d rollbacks=%d, got commits=%d rollbacks=%d", - tt.wantCommits, tt.wantRollbacks, fakeDB.txSession.commitCalls, fakeDB.txSession.rollbackCalls) + if fakeDB.session == nil || !fakeDB.session.closed { + t.Fatal("expected Oracle transaction session to close after finish") + } + if len(fakeDB.execQueries) != 2 || fakeDB.execQueries[0] != stmt || fakeDB.execQueries[1] != tt.wantFinalSQL { + t.Fatalf("expected Oracle implicit transaction to finish with %s, got %#v", tt.wantFinalSQL, fakeDB.execQueries) } }) } } -func TestDBQueryMultiTransactionalRollsBackOracleDriverTransactionOnDMLFailure(t *testing.T) { +func TestDBQueryMultiTransactionalRollsBackOracleImplicitSessionOnDMLFailure(t *testing.T) { originalNewDatabaseFunc := newDatabaseFunc t.Cleanup(func() { newDatabaseFunc = originalNewDatabaseFunc @@ -756,11 +756,9 @@ func TestDBQueryMultiTransactionalRollsBackOracleDriverTransactionOnDMLFailure(t firstStmt := "UPDATE users SET name = 'new' WHERE id = 1" secondStmt := "DELETE FROM audit_logs WHERE user_id = 1" - fakeDB := &fakeTransactionalDB{ - fakeBatchWriteDB: fakeBatchWriteDB{ - execErr: map[string]error{ - secondStmt: errors.New("delete failed"), - }, + fakeDB := &fakeBatchWriteDB{ + execErr: map[string]error{ + secondStmt: errors.New("delete failed"), }, } newDatabaseFunc = func(dbType string) (db.Database, error) { @@ -777,18 +775,20 @@ func TestDBQueryMultiTransactionalRollsBackOracleDriverTransactionOnDMLFailure(t if result.TransactionID != "" || result.TransactionPending { t.Fatalf("expected failed transaction not to be exposed, got id=%q pending=%v", result.TransactionID, result.TransactionPending) } - if fakeDB.txSession == nil { - t.Fatal("expected Oracle transactional query to open a driver transaction") + if fakeDB.session == nil { + t.Fatal("expected Oracle transactional query to open a pinned session") } - if fakeDB.txSession.commitCalls != 0 || fakeDB.txSession.rollbackCalls != 1 { - t.Fatalf("expected driver rollback only, got commits=%d rollbacks=%d", fakeDB.txSession.commitCalls, fakeDB.txSession.rollbackCalls) - } - if !fakeDB.txSession.closed { + if !fakeDB.session.closed { t.Fatal("expected failed Oracle transaction session to close") } - wantExecs := []string{firstStmt, secondStmt} + wantExecs := []string{firstStmt, secondStmt, "ROLLBACK"} if len(fakeDB.execQueries) != len(wantExecs) { - t.Fatalf("expected no textual BEGIN/ROLLBACK for Oracle, got %#v", fakeDB.execQueries) + t.Fatalf("expected Oracle implicit transaction rollback, got %#v", fakeDB.execQueries) + } + for i, want := range wantExecs { + if fakeDB.execQueries[i] != want { + t.Fatalf("expected exec query %d = %q, got %q", i, want, fakeDB.execQueries[i]) + } } } diff --git a/internal/app/methods_db_transaction.go b/internal/app/methods_db_transaction.go index d5752c8..3001e3a 100644 --- a/internal/app/methods_db_transaction.go +++ b/internal/app/methods_db_transaction.go @@ -30,6 +30,13 @@ func (a *App) DBQueryMultiTransactional(config connection.ConnectionConfig, dbNa } beginSQL, commitSQL, rollbackSQL, hasTextTransaction := sqlFileBatchTransactionSQL(runConfig.Type) + implicitTextTransaction := false + if implicitCommitSQL, implicitRollbackSQL, ok := sqlEditorImplicitTransactionSQL(runConfig.Type); ok { + commitSQL = implicitCommitSQL + rollbackSQL = implicitRollbackSQL + hasTextTransaction = true + implicitTextTransaction = true + } dbInst, err := a.getDatabase(runConfig) if err != nil { @@ -58,7 +65,21 @@ func (a *App) DBQueryMultiTransactional(config connection.ConnectionConfig, dbNa transactionCancel context.CancelFunc startTextTransaction bool ) - if provider, ok := dbInst.(db.TransactionExecerProvider); ok { + if implicitTextTransaction { + provider, ok := dbInst.(db.SessionExecerProvider) + if !ok { + return connection.QueryResult{ + Success: false, + Message: fmt.Sprintf("当前数据源(%s)不支持 SQL 编辑器托管事务", runConfig.Type), + QueryID: queryID, + } + } + sessionExecer, err = provider.OpenSessionExecer(ctx) + if err != nil { + logger.Error(err, "DBQueryMultiTransactional 打开隐式事务会话失败:%s SQL片段=%q", formatConnSummary(runConfig), sqlSnippet(query)) + return connection.QueryResult{Success: false, Message: err.Error(), QueryID: queryID} + } + } else if provider, ok := dbInst.(db.TransactionExecerProvider); ok { // database/sql rolls back a BeginTx transaction when its context is cancelled. // SQL editor transactions must outlive the execution RPC and be ended only by // explicit commit, rollback, or shutdown cleanup. @@ -276,6 +297,18 @@ func shouldUseManagedSQLTransaction(dbType string, query string) bool { return hasManagedWrite } +func sqlEditorImplicitTransactionSQL(dbType string) (commitSQL string, rollbackSQL string, ok bool) { + switch strings.ToLower(strings.TrimSpace(dbType)) { + case "oracle": + // Oracle starts a transaction implicitly on the first DML statement. + // Keeping SQL editor DML on one physical connection avoids database/sql + // Tx context lifecycle ending the transaction before the UI commits it. + return "COMMIT", "ROLLBACK", true + default: + return "", "", false + } +} + func isSQLTransactionControlStatement(stmt string) bool { switch leadingSQLKeyword(stmt) { case "begin", "commit", "rollback", "savepoint", "release": diff --git a/internal/db/oracle_impl.go b/internal/db/oracle_impl.go index 2cccfb6..411c7be 100644 --- a/internal/db/oracle_impl.go +++ b/internal/db/oracle_impl.go @@ -24,6 +24,7 @@ type OracleDB struct { forwarder *ssh.LocalForwarder // Store SSH tunnel forwarder } +var _ SessionExecerProvider = (*OracleDB)(nil) var _ TransactionExecerProvider = (*OracleDB)(nil) func (o *OracleDB) getDSN(config connection.ConnectionConfig) string { @@ -264,6 +265,17 @@ func (o *OracleDB) OpenTransactionExecer(ctx context.Context) (TransactionExecer return NewSQLTxStatementExecer(tx), nil } +func (o *OracleDB) OpenSessionExecer(ctx context.Context) (StatementExecer, error) { + if o.conn == nil { + return nil, fmt.Errorf("连接未打开") + } + conn, err := o.conn.Conn(ctx) + if err != nil { + return nil, err + } + return NewSQLConnStatementExecer(conn), nil +} + func (o *OracleDB) GetDatabases() ([]string, error) { // Oracle treats Users/Schemas as "Databases" in this context data, _, err := o.Query("SELECT username FROM all_users ORDER BY username")