🐛 fix(query-editor): 支持 Oracle SQL 编辑器托管事务

- 新增 driver transaction 执行器,支持不适合文本 BEGIN 的数据库

- Oracle SQL 编辑器 DML 托管事务改用 database/sql Tx 提交和回滚

- 补充 Oracle 托管事务提交和失败回滚回归测试
This commit is contained in:
Syngnat
2026-06-11 15:45:13 +08:00
parent 06583abad9
commit e4672062f8
5 changed files with 327 additions and 27 deletions

View File

@@ -58,6 +58,7 @@ type queryContext struct {
type managedSQLTransaction struct {
id string
execer db.StatementExecer
transactor db.TransactionExecer
dbType string
commitSQL string
rollbackSQL string

View File

@@ -219,6 +219,34 @@ func (s *fakeBatchWriteSession) Close() error {
return nil
}
type fakeTransactionalDB struct {
fakeBatchWriteDB
txSession *fakeTransactionSession
}
func (f *fakeTransactionalDB) OpenTransactionExecer(ctx context.Context) (db.TransactionExecer, error) {
f.txSession = &fakeTransactionSession{
fakeBatchWriteSession: fakeBatchWriteSession{parent: &f.fakeBatchWriteDB},
}
return f.txSession, nil
}
type fakeTransactionSession struct {
fakeBatchWriteSession
commitCalls int
rollbackCalls int
}
func (s *fakeTransactionSession) Commit() error {
s.commitCalls++
return nil
}
func (s *fakeTransactionSession) Rollback() error {
s.rollbackCalls++
return nil
}
func cloneResultSets(input []connection.ResultSetData) []connection.ResultSetData {
if len(input) == 0 {
return nil
@@ -584,6 +612,111 @@ func TestDBQueryMultiTransactionalKeepsDMLTransactionOpenUntilCommit(t *testing.
}
}
func TestDBQueryMultiTransactionalUsesDriverTransactionForOracle(t *testing.T) {
originalNewDatabaseFunc := newDatabaseFunc
t.Cleanup(func() {
newDatabaseFunc = originalNewDatabaseFunc
})
firstStmt := "UPDATE users SET name = 'new' WHERE id = 1"
secondStmt := "DELETE FROM audit_logs WHERE user_id = 1"
fakeDB := &fakeTransactionalDB{
fakeBatchWriteDB: fakeBatchWriteDB{
execAffected: map[string]int64{
firstStmt: 1,
secondStmt: 3,
},
},
}
newDatabaseFunc = func(dbType string) (db.Database, error) {
return fakeDB, nil
}
app := NewAppWithSecretStore(secretstore.NewUnavailableStore("test"))
config := connection.ConnectionConfig{Type: "oracle", Host: "127.0.0.1", Port: 1521, User: "app"}
result := app.DBQueryMultiTransactional(config, "ORCLPDB1", firstStmt+";\n"+secondStmt+";", "oracle-tx-query")
if !result.Success {
t.Fatalf("expected Oracle transactional query success, got failure: %s", result.Message)
}
if result.TransactionID == "" || !result.TransactionPending {
t.Fatalf("expected pending transaction metadata, got id=%q pending=%v", result.TransactionID, result.TransactionPending)
}
if fakeDB.txSession == nil {
t.Fatal("expected Oracle transactional query to open a driver transaction")
}
if fakeDB.txSession.closed {
t.Fatal("expected Oracle transaction session to stay open before commit")
}
wantExecs := []string{firstStmt, secondStmt}
if len(fakeDB.execQueries) != len(wantExecs) {
t.Fatalf("expected driver transaction exec queries %#v, got %#v", wantExecs, fakeDB.execQueries)
}
for i, want := range wantExecs {
if fakeDB.execQueries[i] != want {
t.Fatalf("expected exec query %d = %q, got %q", i, want, fakeDB.execQueries[i])
}
}
commitResult := app.DBCommitTransaction(result.TransactionID)
if !commitResult.Success {
t.Fatalf("expected Oracle commit success, got failure: %s", commitResult.Message)
}
if fakeDB.txSession.commitCalls != 1 || fakeDB.txSession.rollbackCalls != 0 {
t.Fatalf("expected driver commit only, got commits=%d rollbacks=%d", fakeDB.txSession.commitCalls, fakeDB.txSession.rollbackCalls)
}
if !fakeDB.txSession.closed {
t.Fatal("expected Oracle transaction session to close after commit")
}
if len(fakeDB.execQueries) != len(wantExecs) {
t.Fatalf("expected no textual BEGIN/COMMIT for Oracle, got %#v", fakeDB.execQueries)
}
}
func TestDBQueryMultiTransactionalRollsBackOracleDriverTransactionOnDMLFailure(t *testing.T) {
originalNewDatabaseFunc := newDatabaseFunc
t.Cleanup(func() {
newDatabaseFunc = originalNewDatabaseFunc
})
firstStmt := "UPDATE users SET name = 'new' WHERE id = 1"
secondStmt := "DELETE FROM audit_logs WHERE user_id = 1"
fakeDB := &fakeTransactionalDB{
fakeBatchWriteDB: fakeBatchWriteDB{
execErr: map[string]error{
secondStmt: errors.New("delete failed"),
},
},
}
newDatabaseFunc = func(dbType string) (db.Database, error) {
return fakeDB, nil
}
app := NewAppWithSecretStore(secretstore.NewUnavailableStore("test"))
config := connection.ConnectionConfig{Type: "oracle", Host: "127.0.0.1", Port: 1521, User: "app"}
result := app.DBQueryMultiTransactional(config, "ORCLPDB1", firstStmt+";\n"+secondStmt+";", "oracle-tx-failure")
if result.Success {
t.Fatal("expected Oracle transactional query failure")
}
if result.TransactionID != "" || result.TransactionPending {
t.Fatalf("expected failed transaction not to be exposed, got id=%q pending=%v", result.TransactionID, result.TransactionPending)
}
if fakeDB.txSession == nil {
t.Fatal("expected Oracle transactional query to open a driver transaction")
}
if fakeDB.txSession.commitCalls != 0 || fakeDB.txSession.rollbackCalls != 1 {
t.Fatalf("expected driver rollback only, got commits=%d rollbacks=%d", fakeDB.txSession.commitCalls, fakeDB.txSession.rollbackCalls)
}
if !fakeDB.txSession.closed {
t.Fatal("expected failed Oracle transaction session to close")
}
wantExecs := []string{firstStmt, secondStmt}
if len(fakeDB.execQueries) != len(wantExecs) {
t.Fatalf("expected no textual BEGIN/ROLLBACK for Oracle, got %#v", fakeDB.execQueries)
}
}
func TestDBQueryMultiTransactionalTreatsWithDMLAsManagedWrite(t *testing.T) {
originalNewDatabaseFunc := newDatabaseFunc
t.Cleanup(func() {

View File

@@ -29,14 +29,7 @@ func (a *App) DBQueryMultiTransactional(config connection.ConnectionConfig, dbNa
return a.DBQueryMulti(config, dbName, query, queryID)
}
beginSQL, commitSQL, rollbackSQL, ok := sqlFileBatchTransactionSQL(runConfig.Type)
if !ok {
return connection.QueryResult{
Success: false,
Message: fmt.Sprintf("当前数据源(%s不支持 SQL 编辑器托管事务", runConfig.Type),
QueryID: queryID,
}
}
beginSQL, commitSQL, rollbackSQL, hasTextTransaction := sqlFileBatchTransactionSQL(runConfig.Type)
dbInst, err := a.getDatabase(runConfig)
if err != nil {
@@ -44,15 +37,6 @@ func (a *App) DBQueryMultiTransactional(config connection.ConnectionConfig, dbNa
return connection.QueryResult{Success: false, Message: err.Error(), QueryID: queryID}
}
provider, ok := dbInst.(db.SessionExecerProvider)
if !ok {
return connection.QueryResult{
Success: false,
Message: fmt.Sprintf("当前数据源(%s不支持 SQL 编辑器托管事务", runConfig.Type),
QueryID: queryID,
}
}
ctx, cancel := newQueryExecutionContext(runConfig)
defer cancel()
@@ -68,10 +52,41 @@ func (a *App) DBQueryMultiTransactional(config connection.ConnectionConfig, dbNa
a.queryMu.Unlock()
}()
sessionExecer, err := provider.OpenSessionExecer(ctx)
if err != nil {
logger.Error(err, "DBQueryMultiTransactional 打开事务会话失败:%s SQL片段=%q", formatConnSummary(runConfig), sqlSnippet(query))
return connection.QueryResult{Success: false, Message: err.Error(), QueryID: queryID}
var (
sessionExecer db.StatementExecer
transactor db.TransactionExecer
startTextTransaction bool
)
if provider, ok := dbInst.(db.TransactionExecerProvider); ok {
transactionExecer, err := provider.OpenTransactionExecer(ctx)
if err != nil {
logger.Error(err, "DBQueryMultiTransactional 打开驱动事务失败:%s SQL片段=%q", formatConnSummary(runConfig), sqlSnippet(query))
return connection.QueryResult{Success: false, Message: err.Error(), QueryID: queryID}
}
sessionExecer = transactionExecer
transactor = transactionExecer
} else {
if !hasTextTransaction {
return connection.QueryResult{
Success: false,
Message: fmt.Sprintf("当前数据源(%s不支持 SQL 编辑器托管事务", runConfig.Type),
QueryID: queryID,
}
}
provider, ok := dbInst.(db.SessionExecerProvider)
if !ok {
return connection.QueryResult{
Success: false,
Message: fmt.Sprintf("当前数据源(%s不支持 SQL 编辑器托管事务", runConfig.Type),
QueryID: queryID,
}
}
sessionExecer, err = provider.OpenSessionExecer(ctx)
if err != nil {
logger.Error(err, "DBQueryMultiTransactional 打开事务会话失败:%s SQL片段=%q", formatConnSummary(runConfig), sqlSnippet(query))
return connection.QueryResult{Success: false, Message: err.Error(), QueryID: queryID}
}
startTextTransaction = true
}
closeSession := true
@@ -83,15 +98,23 @@ func (a *App) DBQueryMultiTransactional(config connection.ConnectionConfig, dbNa
}
}()
if _, err := sessionExecer.ExecContext(ctx, beginSQL); err != nil {
logger.Error(err, "DBQueryMultiTransactional 开启事务失败:%s SQL片段=%q", formatConnSummary(runConfig), sqlSnippet(query))
return connection.QueryResult{Success: false, Message: err.Error(), QueryID: queryID}
if startTextTransaction {
if _, err := sessionExecer.ExecContext(ctx, beginSQL); err != nil {
logger.Error(err, "DBQueryMultiTransactional 开启事务失败:%s SQL片段=%q", formatConnSummary(runConfig), sqlSnippet(query))
return connection.QueryResult{Success: false, Message: err.Error(), QueryID: queryID}
}
}
statements := splitSQLStatements(query)
resultSets, err := executeManagedSQLTransactionStatements(ctx, sessionExecer, runConfig, statements)
if err != nil {
if _, rollbackErr := sessionExecer.ExecContext(context.Background(), rollbackSQL); rollbackErr != nil {
var rollbackErr error
if transactor != nil {
rollbackErr = transactor.Rollback()
} else if strings.TrimSpace(rollbackSQL) != "" {
_, rollbackErr = sessionExecer.ExecContext(context.Background(), rollbackSQL)
}
if rollbackErr != nil {
logger.Error(rollbackErr, "DBQueryMultiTransactional 执行失败后回滚失败:%s SQL片段=%q", formatConnSummary(runConfig), sqlSnippet(query))
err = fmt.Errorf("%v回滚失败: %w", err, rollbackErr)
}
@@ -107,6 +130,7 @@ func (a *App) DBQueryMultiTransactional(config connection.ConnectionConfig, dbNa
a.sqlTransactions[transactionID] = &managedSQLTransaction{
id: transactionID,
execer: sessionExecer,
transactor: transactor,
dbType: runConfig.Type,
commitSQL: commitSQL,
rollbackSQL: rollbackSQL,
@@ -287,7 +311,13 @@ func (a *App) finishManagedSQLTransaction(transactionID string, commit bool) con
defer cancel()
var execErr error
if strings.TrimSpace(sqlText) != "" {
if tx.transactor != nil {
if commit {
execErr = tx.transactor.Commit()
} else {
execErr = tx.transactor.Rollback()
}
} else if strings.TrimSpace(sqlText) != "" {
_, execErr = tx.execer.ExecContext(ctx, sqlText)
}
closeErr := tx.execer.Close()
@@ -319,7 +349,11 @@ func (a *App) rollbackPendingSQLTransactionsOnShutdown() {
for _, tx := range pending {
ctx, cancel := context.WithTimeout(context.Background(), sqlEditorTransactionFinishTimeout)
if strings.TrimSpace(tx.rollbackSQL) != "" && tx.execer != nil {
if tx.transactor != nil {
if err := tx.transactor.Rollback(); err != nil {
logger.Warnf("关闭应用时回滚 SQL 编辑器事务失败id=%s dbType=%s err=%v", tx.id, tx.dbType, err)
}
} else if strings.TrimSpace(tx.rollbackSQL) != "" && tx.execer != nil {
if _, err := tx.execer.ExecContext(ctx, tx.rollbackSQL); err != nil {
logger.Warnf("关闭应用时回滚 SQL 编辑器事务失败id=%s dbType=%s err=%v", tx.id, tx.dbType, err)
}

View File

@@ -6,6 +6,7 @@ import (
"database/sql"
"fmt"
"strings"
"sync"
)
// Database 定义了统一的数据源访问接口。
@@ -118,6 +119,21 @@ type SessionExecerProvider interface {
OpenSessionExecer(ctx context.Context) (StatementExecer, error)
}
// TransactionExecer is a single transaction handle backed by the database
// driver. It is required for dialects where textual BEGIN/COMMIT is not a
// valid transaction-control statement, such as Oracle.
type TransactionExecer interface {
StatementExecer
Commit() error
Rollback() error
}
// TransactionExecerProvider is implemented by drivers that can expose a real
// database/sql transaction for long-running SQL editor managed transactions.
type TransactionExecerProvider interface {
OpenTransactionExecer(ctx context.Context) (TransactionExecer, error)
}
type sqlConnStatementExecer struct {
conn *sql.Conn
}
@@ -184,6 +200,109 @@ func (e *sqlConnStatementExecer) Close() error {
return e.conn.Close()
}
type sqlTxStatementExecer struct {
mu sync.Mutex
tx *sql.Tx
done bool
}
func NewSQLTxStatementExecer(tx *sql.Tx) TransactionExecer {
return &sqlTxStatementExecer{tx: tx}
}
func (e *sqlTxStatementExecer) activeTx() (*sql.Tx, error) {
if e == nil || e.tx == nil {
return nil, fmt.Errorf("事务未打开")
}
e.mu.Lock()
defer e.mu.Unlock()
if e.done {
return nil, fmt.Errorf("事务已结束")
}
return e.tx, nil
}
func (e *sqlTxStatementExecer) ExecContext(ctx context.Context, query string) (int64, error) {
tx, err := e.activeTx()
if err != nil {
return 0, err
}
res, err := tx.ExecContext(ctx, query)
if err != nil {
return 0, err
}
return res.RowsAffected()
}
func (e *sqlTxStatementExecer) Exec(query string) (int64, error) {
return e.ExecContext(context.Background(), query)
}
func (e *sqlTxStatementExecer) QueryContext(ctx context.Context, query string) ([]map[string]interface{}, []string, error) {
tx, err := e.activeTx()
if err != nil {
return nil, nil, err
}
rows, err := tx.QueryContext(ctx, query)
if err != nil {
return nil, nil, err
}
defer rows.Close()
return scanRows(rows)
}
func (e *sqlTxStatementExecer) Query(query string) ([]map[string]interface{}, []string, error) {
return e.QueryContext(context.Background(), query)
}
func (e *sqlTxStatementExecer) QueryMultiContext(ctx context.Context, query string) ([]connection.ResultSetData, error) {
tx, err := e.activeTx()
if err != nil {
return nil, err
}
rows, err := tx.QueryContext(ctx, query)
if err != nil {
return nil, err
}
defer rows.Close()
return scanMultiRows(rows)
}
func (e *sqlTxStatementExecer) QueryMulti(query string) ([]connection.ResultSetData, error) {
return e.QueryMultiContext(context.Background(), query)
}
func (e *sqlTxStatementExecer) finish(action func(*sql.Tx) error) error {
if e == nil || e.tx == nil {
return nil
}
e.mu.Lock()
if e.done {
e.mu.Unlock()
return nil
}
tx := e.tx
e.done = true
e.mu.Unlock()
return action(tx)
}
func (e *sqlTxStatementExecer) Commit() error {
return e.finish(func(tx *sql.Tx) error {
return tx.Commit()
})
}
func (e *sqlTxStatementExecer) Rollback() error {
return e.finish(func(tx *sql.Tx) error {
return tx.Rollback()
})
}
func (e *sqlTxStatementExecer) Close() error {
return e.Rollback()
}
// BatchApplier 定义了批量变更提交接口。
// 支持批量编辑的驱动实现此接口,用于一次性提交前端 DataGrid 中的增删改操作。
type BatchApplier interface {

View File

@@ -24,6 +24,8 @@ type OracleDB struct {
forwarder *ssh.LocalForwarder // Store SSH tunnel forwarder
}
var _ TransactionExecerProvider = (*OracleDB)(nil)
func (o *OracleDB) getDSN(config connection.ConnectionConfig) string {
// oracle://user:pass@host:port/service_name
database := strings.TrimSpace(config.Database)
@@ -251,6 +253,17 @@ func (o *OracleDB) Exec(query string) (int64, error) {
return res.RowsAffected()
}
func (o *OracleDB) OpenTransactionExecer(ctx context.Context) (TransactionExecer, error) {
if o.conn == nil {
return nil, fmt.Errorf("连接未打开")
}
tx, err := o.conn.BeginTx(ctx, nil)
if err != nil {
return nil, err
}
return NewSQLTxStatementExecer(tx), nil
}
func (o *OracleDB) GetDatabases() ([]string, error) {
// Oracle treats Users/Schemas as "Databases" in this context
data, _, err := o.Query("SELECT username FROM all_users ORDER BY username")