diff --git a/internal/app/app.go b/internal/app/app.go index 886d0c1..c06865f 100644 --- a/internal/app/app.go +++ b/internal/app/app.go @@ -58,6 +58,7 @@ type queryContext struct { type managedSQLTransaction struct { id string execer db.StatementExecer + transactor db.TransactionExecer dbType string commitSQL string rollbackSQL string diff --git a/internal/app/methods_db_multi_test.go b/internal/app/methods_db_multi_test.go index b44fa6e..a6ac3b4 100644 --- a/internal/app/methods_db_multi_test.go +++ b/internal/app/methods_db_multi_test.go @@ -219,6 +219,34 @@ func (s *fakeBatchWriteSession) Close() error { return nil } +type fakeTransactionalDB struct { + fakeBatchWriteDB + txSession *fakeTransactionSession +} + +func (f *fakeTransactionalDB) OpenTransactionExecer(ctx context.Context) (db.TransactionExecer, error) { + f.txSession = &fakeTransactionSession{ + fakeBatchWriteSession: fakeBatchWriteSession{parent: &f.fakeBatchWriteDB}, + } + return f.txSession, nil +} + +type fakeTransactionSession struct { + fakeBatchWriteSession + commitCalls int + rollbackCalls int +} + +func (s *fakeTransactionSession) Commit() error { + s.commitCalls++ + return nil +} + +func (s *fakeTransactionSession) Rollback() error { + s.rollbackCalls++ + return nil +} + func cloneResultSets(input []connection.ResultSetData) []connection.ResultSetData { if len(input) == 0 { return nil @@ -584,6 +612,111 @@ func TestDBQueryMultiTransactionalKeepsDMLTransactionOpenUntilCommit(t *testing. } } +func TestDBQueryMultiTransactionalUsesDriverTransactionForOracle(t *testing.T) { + originalNewDatabaseFunc := newDatabaseFunc + t.Cleanup(func() { + newDatabaseFunc = originalNewDatabaseFunc + }) + + firstStmt := "UPDATE users SET name = 'new' WHERE id = 1" + secondStmt := "DELETE FROM audit_logs WHERE user_id = 1" + fakeDB := &fakeTransactionalDB{ + fakeBatchWriteDB: fakeBatchWriteDB{ + execAffected: map[string]int64{ + firstStmt: 1, + secondStmt: 3, + }, + }, + } + newDatabaseFunc = func(dbType string) (db.Database, error) { + return fakeDB, nil + } + + app := NewAppWithSecretStore(secretstore.NewUnavailableStore("test")) + config := connection.ConnectionConfig{Type: "oracle", Host: "127.0.0.1", Port: 1521, User: "app"} + + result := app.DBQueryMultiTransactional(config, "ORCLPDB1", firstStmt+";\n"+secondStmt+";", "oracle-tx-query") + if !result.Success { + t.Fatalf("expected Oracle transactional query success, got failure: %s", result.Message) + } + if result.TransactionID == "" || !result.TransactionPending { + t.Fatalf("expected pending transaction metadata, got id=%q pending=%v", result.TransactionID, result.TransactionPending) + } + if fakeDB.txSession == nil { + t.Fatal("expected Oracle transactional query to open a driver transaction") + } + if fakeDB.txSession.closed { + t.Fatal("expected Oracle transaction session to stay open before commit") + } + wantExecs := []string{firstStmt, secondStmt} + if len(fakeDB.execQueries) != len(wantExecs) { + t.Fatalf("expected driver transaction exec queries %#v, got %#v", wantExecs, fakeDB.execQueries) + } + for i, want := range wantExecs { + if fakeDB.execQueries[i] != want { + t.Fatalf("expected exec query %d = %q, got %q", i, want, fakeDB.execQueries[i]) + } + } + + commitResult := app.DBCommitTransaction(result.TransactionID) + if !commitResult.Success { + t.Fatalf("expected Oracle commit success, got failure: %s", commitResult.Message) + } + if fakeDB.txSession.commitCalls != 1 || fakeDB.txSession.rollbackCalls != 0 { + t.Fatalf("expected driver commit only, got commits=%d rollbacks=%d", fakeDB.txSession.commitCalls, fakeDB.txSession.rollbackCalls) + } + if !fakeDB.txSession.closed { + t.Fatal("expected Oracle transaction session to close after commit") + } + if len(fakeDB.execQueries) != len(wantExecs) { + t.Fatalf("expected no textual BEGIN/COMMIT for Oracle, got %#v", fakeDB.execQueries) + } +} + +func TestDBQueryMultiTransactionalRollsBackOracleDriverTransactionOnDMLFailure(t *testing.T) { + originalNewDatabaseFunc := newDatabaseFunc + t.Cleanup(func() { + newDatabaseFunc = originalNewDatabaseFunc + }) + + firstStmt := "UPDATE users SET name = 'new' WHERE id = 1" + secondStmt := "DELETE FROM audit_logs WHERE user_id = 1" + fakeDB := &fakeTransactionalDB{ + fakeBatchWriteDB: fakeBatchWriteDB{ + execErr: map[string]error{ + secondStmt: errors.New("delete failed"), + }, + }, + } + newDatabaseFunc = func(dbType string) (db.Database, error) { + return fakeDB, nil + } + + app := NewAppWithSecretStore(secretstore.NewUnavailableStore("test")) + config := connection.ConnectionConfig{Type: "oracle", Host: "127.0.0.1", Port: 1521, User: "app"} + + result := app.DBQueryMultiTransactional(config, "ORCLPDB1", firstStmt+";\n"+secondStmt+";", "oracle-tx-failure") + if result.Success { + t.Fatal("expected Oracle transactional query failure") + } + if result.TransactionID != "" || result.TransactionPending { + t.Fatalf("expected failed transaction not to be exposed, got id=%q pending=%v", result.TransactionID, result.TransactionPending) + } + if fakeDB.txSession == nil { + t.Fatal("expected Oracle transactional query to open a driver transaction") + } + if fakeDB.txSession.commitCalls != 0 || fakeDB.txSession.rollbackCalls != 1 { + t.Fatalf("expected driver rollback only, got commits=%d rollbacks=%d", fakeDB.txSession.commitCalls, fakeDB.txSession.rollbackCalls) + } + if !fakeDB.txSession.closed { + t.Fatal("expected failed Oracle transaction session to close") + } + wantExecs := []string{firstStmt, secondStmt} + if len(fakeDB.execQueries) != len(wantExecs) { + t.Fatalf("expected no textual BEGIN/ROLLBACK for Oracle, got %#v", fakeDB.execQueries) + } +} + func TestDBQueryMultiTransactionalTreatsWithDMLAsManagedWrite(t *testing.T) { originalNewDatabaseFunc := newDatabaseFunc t.Cleanup(func() { diff --git a/internal/app/methods_db_transaction.go b/internal/app/methods_db_transaction.go index 2c6765a..4adcd5c 100644 --- a/internal/app/methods_db_transaction.go +++ b/internal/app/methods_db_transaction.go @@ -29,14 +29,7 @@ func (a *App) DBQueryMultiTransactional(config connection.ConnectionConfig, dbNa return a.DBQueryMulti(config, dbName, query, queryID) } - beginSQL, commitSQL, rollbackSQL, ok := sqlFileBatchTransactionSQL(runConfig.Type) - if !ok { - return connection.QueryResult{ - Success: false, - Message: fmt.Sprintf("当前数据源(%s)不支持 SQL 编辑器托管事务", runConfig.Type), - QueryID: queryID, - } - } + beginSQL, commitSQL, rollbackSQL, hasTextTransaction := sqlFileBatchTransactionSQL(runConfig.Type) dbInst, err := a.getDatabase(runConfig) if err != nil { @@ -44,15 +37,6 @@ func (a *App) DBQueryMultiTransactional(config connection.ConnectionConfig, dbNa return connection.QueryResult{Success: false, Message: err.Error(), QueryID: queryID} } - provider, ok := dbInst.(db.SessionExecerProvider) - if !ok { - return connection.QueryResult{ - Success: false, - Message: fmt.Sprintf("当前数据源(%s)不支持 SQL 编辑器托管事务", runConfig.Type), - QueryID: queryID, - } - } - ctx, cancel := newQueryExecutionContext(runConfig) defer cancel() @@ -68,10 +52,41 @@ func (a *App) DBQueryMultiTransactional(config connection.ConnectionConfig, dbNa a.queryMu.Unlock() }() - sessionExecer, err := provider.OpenSessionExecer(ctx) - if err != nil { - logger.Error(err, "DBQueryMultiTransactional 打开事务会话失败:%s SQL片段=%q", formatConnSummary(runConfig), sqlSnippet(query)) - return connection.QueryResult{Success: false, Message: err.Error(), QueryID: queryID} + var ( + sessionExecer db.StatementExecer + transactor db.TransactionExecer + startTextTransaction bool + ) + if provider, ok := dbInst.(db.TransactionExecerProvider); ok { + transactionExecer, err := provider.OpenTransactionExecer(ctx) + if err != nil { + logger.Error(err, "DBQueryMultiTransactional 打开驱动事务失败:%s SQL片段=%q", formatConnSummary(runConfig), sqlSnippet(query)) + return connection.QueryResult{Success: false, Message: err.Error(), QueryID: queryID} + } + sessionExecer = transactionExecer + transactor = transactionExecer + } else { + if !hasTextTransaction { + return connection.QueryResult{ + Success: false, + Message: fmt.Sprintf("当前数据源(%s)不支持 SQL 编辑器托管事务", runConfig.Type), + QueryID: queryID, + } + } + provider, ok := dbInst.(db.SessionExecerProvider) + if !ok { + return connection.QueryResult{ + Success: false, + Message: fmt.Sprintf("当前数据源(%s)不支持 SQL 编辑器托管事务", runConfig.Type), + QueryID: queryID, + } + } + sessionExecer, err = provider.OpenSessionExecer(ctx) + if err != nil { + logger.Error(err, "DBQueryMultiTransactional 打开事务会话失败:%s SQL片段=%q", formatConnSummary(runConfig), sqlSnippet(query)) + return connection.QueryResult{Success: false, Message: err.Error(), QueryID: queryID} + } + startTextTransaction = true } closeSession := true @@ -83,15 +98,23 @@ func (a *App) DBQueryMultiTransactional(config connection.ConnectionConfig, dbNa } }() - if _, err := sessionExecer.ExecContext(ctx, beginSQL); err != nil { - logger.Error(err, "DBQueryMultiTransactional 开启事务失败:%s SQL片段=%q", formatConnSummary(runConfig), sqlSnippet(query)) - return connection.QueryResult{Success: false, Message: err.Error(), QueryID: queryID} + if startTextTransaction { + if _, err := sessionExecer.ExecContext(ctx, beginSQL); err != nil { + logger.Error(err, "DBQueryMultiTransactional 开启事务失败:%s SQL片段=%q", formatConnSummary(runConfig), sqlSnippet(query)) + return connection.QueryResult{Success: false, Message: err.Error(), QueryID: queryID} + } } statements := splitSQLStatements(query) resultSets, err := executeManagedSQLTransactionStatements(ctx, sessionExecer, runConfig, statements) if err != nil { - if _, rollbackErr := sessionExecer.ExecContext(context.Background(), rollbackSQL); rollbackErr != nil { + var rollbackErr error + if transactor != nil { + rollbackErr = transactor.Rollback() + } else if strings.TrimSpace(rollbackSQL) != "" { + _, rollbackErr = sessionExecer.ExecContext(context.Background(), rollbackSQL) + } + if rollbackErr != nil { logger.Error(rollbackErr, "DBQueryMultiTransactional 执行失败后回滚失败:%s SQL片段=%q", formatConnSummary(runConfig), sqlSnippet(query)) err = fmt.Errorf("%v;回滚失败: %w", err, rollbackErr) } @@ -107,6 +130,7 @@ func (a *App) DBQueryMultiTransactional(config connection.ConnectionConfig, dbNa a.sqlTransactions[transactionID] = &managedSQLTransaction{ id: transactionID, execer: sessionExecer, + transactor: transactor, dbType: runConfig.Type, commitSQL: commitSQL, rollbackSQL: rollbackSQL, @@ -287,7 +311,13 @@ func (a *App) finishManagedSQLTransaction(transactionID string, commit bool) con defer cancel() var execErr error - if strings.TrimSpace(sqlText) != "" { + if tx.transactor != nil { + if commit { + execErr = tx.transactor.Commit() + } else { + execErr = tx.transactor.Rollback() + } + } else if strings.TrimSpace(sqlText) != "" { _, execErr = tx.execer.ExecContext(ctx, sqlText) } closeErr := tx.execer.Close() @@ -319,7 +349,11 @@ func (a *App) rollbackPendingSQLTransactionsOnShutdown() { for _, tx := range pending { ctx, cancel := context.WithTimeout(context.Background(), sqlEditorTransactionFinishTimeout) - if strings.TrimSpace(tx.rollbackSQL) != "" && tx.execer != nil { + if tx.transactor != nil { + if err := tx.transactor.Rollback(); err != nil { + logger.Warnf("关闭应用时回滚 SQL 编辑器事务失败:id=%s dbType=%s err=%v", tx.id, tx.dbType, err) + } + } else if strings.TrimSpace(tx.rollbackSQL) != "" && tx.execer != nil { if _, err := tx.execer.ExecContext(ctx, tx.rollbackSQL); err != nil { logger.Warnf("关闭应用时回滚 SQL 编辑器事务失败:id=%s dbType=%s err=%v", tx.id, tx.dbType, err) } diff --git a/internal/db/database.go b/internal/db/database.go index e1cd184..7ca6c1c 100644 --- a/internal/db/database.go +++ b/internal/db/database.go @@ -6,6 +6,7 @@ import ( "database/sql" "fmt" "strings" + "sync" ) // Database 定义了统一的数据源访问接口。 @@ -118,6 +119,21 @@ type SessionExecerProvider interface { OpenSessionExecer(ctx context.Context) (StatementExecer, error) } +// TransactionExecer is a single transaction handle backed by the database +// driver. It is required for dialects where textual BEGIN/COMMIT is not a +// valid transaction-control statement, such as Oracle. +type TransactionExecer interface { + StatementExecer + Commit() error + Rollback() error +} + +// TransactionExecerProvider is implemented by drivers that can expose a real +// database/sql transaction for long-running SQL editor managed transactions. +type TransactionExecerProvider interface { + OpenTransactionExecer(ctx context.Context) (TransactionExecer, error) +} + type sqlConnStatementExecer struct { conn *sql.Conn } @@ -184,6 +200,109 @@ func (e *sqlConnStatementExecer) Close() error { return e.conn.Close() } +type sqlTxStatementExecer struct { + mu sync.Mutex + tx *sql.Tx + done bool +} + +func NewSQLTxStatementExecer(tx *sql.Tx) TransactionExecer { + return &sqlTxStatementExecer{tx: tx} +} + +func (e *sqlTxStatementExecer) activeTx() (*sql.Tx, error) { + if e == nil || e.tx == nil { + return nil, fmt.Errorf("事务未打开") + } + e.mu.Lock() + defer e.mu.Unlock() + if e.done { + return nil, fmt.Errorf("事务已结束") + } + return e.tx, nil +} + +func (e *sqlTxStatementExecer) ExecContext(ctx context.Context, query string) (int64, error) { + tx, err := e.activeTx() + if err != nil { + return 0, err + } + res, err := tx.ExecContext(ctx, query) + if err != nil { + return 0, err + } + return res.RowsAffected() +} + +func (e *sqlTxStatementExecer) Exec(query string) (int64, error) { + return e.ExecContext(context.Background(), query) +} + +func (e *sqlTxStatementExecer) QueryContext(ctx context.Context, query string) ([]map[string]interface{}, []string, error) { + tx, err := e.activeTx() + if err != nil { + return nil, nil, err + } + rows, err := tx.QueryContext(ctx, query) + if err != nil { + return nil, nil, err + } + defer rows.Close() + return scanRows(rows) +} + +func (e *sqlTxStatementExecer) Query(query string) ([]map[string]interface{}, []string, error) { + return e.QueryContext(context.Background(), query) +} + +func (e *sqlTxStatementExecer) QueryMultiContext(ctx context.Context, query string) ([]connection.ResultSetData, error) { + tx, err := e.activeTx() + if err != nil { + return nil, err + } + rows, err := tx.QueryContext(ctx, query) + if err != nil { + return nil, err + } + defer rows.Close() + return scanMultiRows(rows) +} + +func (e *sqlTxStatementExecer) QueryMulti(query string) ([]connection.ResultSetData, error) { + return e.QueryMultiContext(context.Background(), query) +} + +func (e *sqlTxStatementExecer) finish(action func(*sql.Tx) error) error { + if e == nil || e.tx == nil { + return nil + } + e.mu.Lock() + if e.done { + e.mu.Unlock() + return nil + } + tx := e.tx + e.done = true + e.mu.Unlock() + return action(tx) +} + +func (e *sqlTxStatementExecer) Commit() error { + return e.finish(func(tx *sql.Tx) error { + return tx.Commit() + }) +} + +func (e *sqlTxStatementExecer) Rollback() error { + return e.finish(func(tx *sql.Tx) error { + return tx.Rollback() + }) +} + +func (e *sqlTxStatementExecer) Close() error { + return e.Rollback() +} + // BatchApplier 定义了批量变更提交接口。 // 支持批量编辑的驱动实现此接口,用于一次性提交前端 DataGrid 中的增删改操作。 type BatchApplier interface { diff --git a/internal/db/oracle_impl.go b/internal/db/oracle_impl.go index 00a8dbf..2cccfb6 100644 --- a/internal/db/oracle_impl.go +++ b/internal/db/oracle_impl.go @@ -24,6 +24,8 @@ type OracleDB struct { forwarder *ssh.LocalForwarder // Store SSH tunnel forwarder } +var _ TransactionExecerProvider = (*OracleDB)(nil) + func (o *OracleDB) getDSN(config connection.ConnectionConfig) string { // oracle://user:pass@host:port/service_name database := strings.TrimSpace(config.Database) @@ -251,6 +253,17 @@ func (o *OracleDB) Exec(query string) (int64, error) { return res.RowsAffected() } +func (o *OracleDB) OpenTransactionExecer(ctx context.Context) (TransactionExecer, error) { + if o.conn == nil { + return nil, fmt.Errorf("连接未打开") + } + tx, err := o.conn.BeginTx(ctx, nil) + if err != nil { + return nil, err + } + return NewSQLTxStatementExecer(tx), nil +} + func (o *OracleDB) GetDatabases() ([]string, error) { // Oracle treats Users/Schemas as "Databases" in this context data, _, err := o.Query("SELECT username FROM all_users ORDER BY username")