mirror of
https://github.com/Syngnat/GoNavi.git
synced 2026-06-28 17:31:32 +08:00
🐛 fix(query-editor): 支持 Oracle SQL 编辑器托管事务
- 新增 driver transaction 执行器,支持不适合文本 BEGIN 的数据库 - Oracle SQL 编辑器 DML 托管事务改用 database/sql Tx 提交和回滚 - 补充 Oracle 托管事务提交和失败回滚回归测试
This commit is contained in:
@@ -58,6 +58,7 @@ type queryContext struct {
|
||||
type managedSQLTransaction struct {
|
||||
id string
|
||||
execer db.StatementExecer
|
||||
transactor db.TransactionExecer
|
||||
dbType string
|
||||
commitSQL string
|
||||
rollbackSQL string
|
||||
|
||||
@@ -219,6 +219,34 @@ func (s *fakeBatchWriteSession) Close() error {
|
||||
return nil
|
||||
}
|
||||
|
||||
type fakeTransactionalDB struct {
|
||||
fakeBatchWriteDB
|
||||
txSession *fakeTransactionSession
|
||||
}
|
||||
|
||||
func (f *fakeTransactionalDB) OpenTransactionExecer(ctx context.Context) (db.TransactionExecer, error) {
|
||||
f.txSession = &fakeTransactionSession{
|
||||
fakeBatchWriteSession: fakeBatchWriteSession{parent: &f.fakeBatchWriteDB},
|
||||
}
|
||||
return f.txSession, nil
|
||||
}
|
||||
|
||||
type fakeTransactionSession struct {
|
||||
fakeBatchWriteSession
|
||||
commitCalls int
|
||||
rollbackCalls int
|
||||
}
|
||||
|
||||
func (s *fakeTransactionSession) Commit() error {
|
||||
s.commitCalls++
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *fakeTransactionSession) Rollback() error {
|
||||
s.rollbackCalls++
|
||||
return nil
|
||||
}
|
||||
|
||||
func cloneResultSets(input []connection.ResultSetData) []connection.ResultSetData {
|
||||
if len(input) == 0 {
|
||||
return nil
|
||||
@@ -584,6 +612,111 @@ func TestDBQueryMultiTransactionalKeepsDMLTransactionOpenUntilCommit(t *testing.
|
||||
}
|
||||
}
|
||||
|
||||
func TestDBQueryMultiTransactionalUsesDriverTransactionForOracle(t *testing.T) {
|
||||
originalNewDatabaseFunc := newDatabaseFunc
|
||||
t.Cleanup(func() {
|
||||
newDatabaseFunc = originalNewDatabaseFunc
|
||||
})
|
||||
|
||||
firstStmt := "UPDATE users SET name = 'new' WHERE id = 1"
|
||||
secondStmt := "DELETE FROM audit_logs WHERE user_id = 1"
|
||||
fakeDB := &fakeTransactionalDB{
|
||||
fakeBatchWriteDB: fakeBatchWriteDB{
|
||||
execAffected: map[string]int64{
|
||||
firstStmt: 1,
|
||||
secondStmt: 3,
|
||||
},
|
||||
},
|
||||
}
|
||||
newDatabaseFunc = func(dbType string) (db.Database, error) {
|
||||
return fakeDB, nil
|
||||
}
|
||||
|
||||
app := NewAppWithSecretStore(secretstore.NewUnavailableStore("test"))
|
||||
config := connection.ConnectionConfig{Type: "oracle", Host: "127.0.0.1", Port: 1521, User: "app"}
|
||||
|
||||
result := app.DBQueryMultiTransactional(config, "ORCLPDB1", firstStmt+";\n"+secondStmt+";", "oracle-tx-query")
|
||||
if !result.Success {
|
||||
t.Fatalf("expected Oracle transactional query success, got failure: %s", result.Message)
|
||||
}
|
||||
if result.TransactionID == "" || !result.TransactionPending {
|
||||
t.Fatalf("expected pending transaction metadata, got id=%q pending=%v", result.TransactionID, result.TransactionPending)
|
||||
}
|
||||
if fakeDB.txSession == nil {
|
||||
t.Fatal("expected Oracle transactional query to open a driver transaction")
|
||||
}
|
||||
if fakeDB.txSession.closed {
|
||||
t.Fatal("expected Oracle transaction session to stay open before commit")
|
||||
}
|
||||
wantExecs := []string{firstStmt, secondStmt}
|
||||
if len(fakeDB.execQueries) != len(wantExecs) {
|
||||
t.Fatalf("expected driver transaction exec queries %#v, got %#v", wantExecs, fakeDB.execQueries)
|
||||
}
|
||||
for i, want := range wantExecs {
|
||||
if fakeDB.execQueries[i] != want {
|
||||
t.Fatalf("expected exec query %d = %q, got %q", i, want, fakeDB.execQueries[i])
|
||||
}
|
||||
}
|
||||
|
||||
commitResult := app.DBCommitTransaction(result.TransactionID)
|
||||
if !commitResult.Success {
|
||||
t.Fatalf("expected Oracle commit success, got failure: %s", commitResult.Message)
|
||||
}
|
||||
if fakeDB.txSession.commitCalls != 1 || fakeDB.txSession.rollbackCalls != 0 {
|
||||
t.Fatalf("expected driver commit only, got commits=%d rollbacks=%d", fakeDB.txSession.commitCalls, fakeDB.txSession.rollbackCalls)
|
||||
}
|
||||
if !fakeDB.txSession.closed {
|
||||
t.Fatal("expected Oracle transaction session to close after commit")
|
||||
}
|
||||
if len(fakeDB.execQueries) != len(wantExecs) {
|
||||
t.Fatalf("expected no textual BEGIN/COMMIT for Oracle, got %#v", fakeDB.execQueries)
|
||||
}
|
||||
}
|
||||
|
||||
func TestDBQueryMultiTransactionalRollsBackOracleDriverTransactionOnDMLFailure(t *testing.T) {
|
||||
originalNewDatabaseFunc := newDatabaseFunc
|
||||
t.Cleanup(func() {
|
||||
newDatabaseFunc = originalNewDatabaseFunc
|
||||
})
|
||||
|
||||
firstStmt := "UPDATE users SET name = 'new' WHERE id = 1"
|
||||
secondStmt := "DELETE FROM audit_logs WHERE user_id = 1"
|
||||
fakeDB := &fakeTransactionalDB{
|
||||
fakeBatchWriteDB: fakeBatchWriteDB{
|
||||
execErr: map[string]error{
|
||||
secondStmt: errors.New("delete failed"),
|
||||
},
|
||||
},
|
||||
}
|
||||
newDatabaseFunc = func(dbType string) (db.Database, error) {
|
||||
return fakeDB, nil
|
||||
}
|
||||
|
||||
app := NewAppWithSecretStore(secretstore.NewUnavailableStore("test"))
|
||||
config := connection.ConnectionConfig{Type: "oracle", Host: "127.0.0.1", Port: 1521, User: "app"}
|
||||
|
||||
result := app.DBQueryMultiTransactional(config, "ORCLPDB1", firstStmt+";\n"+secondStmt+";", "oracle-tx-failure")
|
||||
if result.Success {
|
||||
t.Fatal("expected Oracle transactional query failure")
|
||||
}
|
||||
if result.TransactionID != "" || result.TransactionPending {
|
||||
t.Fatalf("expected failed transaction not to be exposed, got id=%q pending=%v", result.TransactionID, result.TransactionPending)
|
||||
}
|
||||
if fakeDB.txSession == nil {
|
||||
t.Fatal("expected Oracle transactional query to open a driver transaction")
|
||||
}
|
||||
if fakeDB.txSession.commitCalls != 0 || fakeDB.txSession.rollbackCalls != 1 {
|
||||
t.Fatalf("expected driver rollback only, got commits=%d rollbacks=%d", fakeDB.txSession.commitCalls, fakeDB.txSession.rollbackCalls)
|
||||
}
|
||||
if !fakeDB.txSession.closed {
|
||||
t.Fatal("expected failed Oracle transaction session to close")
|
||||
}
|
||||
wantExecs := []string{firstStmt, secondStmt}
|
||||
if len(fakeDB.execQueries) != len(wantExecs) {
|
||||
t.Fatalf("expected no textual BEGIN/ROLLBACK for Oracle, got %#v", fakeDB.execQueries)
|
||||
}
|
||||
}
|
||||
|
||||
func TestDBQueryMultiTransactionalTreatsWithDMLAsManagedWrite(t *testing.T) {
|
||||
originalNewDatabaseFunc := newDatabaseFunc
|
||||
t.Cleanup(func() {
|
||||
|
||||
@@ -29,14 +29,7 @@ func (a *App) DBQueryMultiTransactional(config connection.ConnectionConfig, dbNa
|
||||
return a.DBQueryMulti(config, dbName, query, queryID)
|
||||
}
|
||||
|
||||
beginSQL, commitSQL, rollbackSQL, ok := sqlFileBatchTransactionSQL(runConfig.Type)
|
||||
if !ok {
|
||||
return connection.QueryResult{
|
||||
Success: false,
|
||||
Message: fmt.Sprintf("当前数据源(%s)不支持 SQL 编辑器托管事务", runConfig.Type),
|
||||
QueryID: queryID,
|
||||
}
|
||||
}
|
||||
beginSQL, commitSQL, rollbackSQL, hasTextTransaction := sqlFileBatchTransactionSQL(runConfig.Type)
|
||||
|
||||
dbInst, err := a.getDatabase(runConfig)
|
||||
if err != nil {
|
||||
@@ -44,15 +37,6 @@ func (a *App) DBQueryMultiTransactional(config connection.ConnectionConfig, dbNa
|
||||
return connection.QueryResult{Success: false, Message: err.Error(), QueryID: queryID}
|
||||
}
|
||||
|
||||
provider, ok := dbInst.(db.SessionExecerProvider)
|
||||
if !ok {
|
||||
return connection.QueryResult{
|
||||
Success: false,
|
||||
Message: fmt.Sprintf("当前数据源(%s)不支持 SQL 编辑器托管事务", runConfig.Type),
|
||||
QueryID: queryID,
|
||||
}
|
||||
}
|
||||
|
||||
ctx, cancel := newQueryExecutionContext(runConfig)
|
||||
defer cancel()
|
||||
|
||||
@@ -68,10 +52,41 @@ func (a *App) DBQueryMultiTransactional(config connection.ConnectionConfig, dbNa
|
||||
a.queryMu.Unlock()
|
||||
}()
|
||||
|
||||
sessionExecer, err := provider.OpenSessionExecer(ctx)
|
||||
if err != nil {
|
||||
logger.Error(err, "DBQueryMultiTransactional 打开事务会话失败:%s SQL片段=%q", formatConnSummary(runConfig), sqlSnippet(query))
|
||||
return connection.QueryResult{Success: false, Message: err.Error(), QueryID: queryID}
|
||||
var (
|
||||
sessionExecer db.StatementExecer
|
||||
transactor db.TransactionExecer
|
||||
startTextTransaction bool
|
||||
)
|
||||
if provider, ok := dbInst.(db.TransactionExecerProvider); ok {
|
||||
transactionExecer, err := provider.OpenTransactionExecer(ctx)
|
||||
if err != nil {
|
||||
logger.Error(err, "DBQueryMultiTransactional 打开驱动事务失败:%s SQL片段=%q", formatConnSummary(runConfig), sqlSnippet(query))
|
||||
return connection.QueryResult{Success: false, Message: err.Error(), QueryID: queryID}
|
||||
}
|
||||
sessionExecer = transactionExecer
|
||||
transactor = transactionExecer
|
||||
} else {
|
||||
if !hasTextTransaction {
|
||||
return connection.QueryResult{
|
||||
Success: false,
|
||||
Message: fmt.Sprintf("当前数据源(%s)不支持 SQL 编辑器托管事务", runConfig.Type),
|
||||
QueryID: queryID,
|
||||
}
|
||||
}
|
||||
provider, ok := dbInst.(db.SessionExecerProvider)
|
||||
if !ok {
|
||||
return connection.QueryResult{
|
||||
Success: false,
|
||||
Message: fmt.Sprintf("当前数据源(%s)不支持 SQL 编辑器托管事务", runConfig.Type),
|
||||
QueryID: queryID,
|
||||
}
|
||||
}
|
||||
sessionExecer, err = provider.OpenSessionExecer(ctx)
|
||||
if err != nil {
|
||||
logger.Error(err, "DBQueryMultiTransactional 打开事务会话失败:%s SQL片段=%q", formatConnSummary(runConfig), sqlSnippet(query))
|
||||
return connection.QueryResult{Success: false, Message: err.Error(), QueryID: queryID}
|
||||
}
|
||||
startTextTransaction = true
|
||||
}
|
||||
|
||||
closeSession := true
|
||||
@@ -83,15 +98,23 @@ func (a *App) DBQueryMultiTransactional(config connection.ConnectionConfig, dbNa
|
||||
}
|
||||
}()
|
||||
|
||||
if _, err := sessionExecer.ExecContext(ctx, beginSQL); err != nil {
|
||||
logger.Error(err, "DBQueryMultiTransactional 开启事务失败:%s SQL片段=%q", formatConnSummary(runConfig), sqlSnippet(query))
|
||||
return connection.QueryResult{Success: false, Message: err.Error(), QueryID: queryID}
|
||||
if startTextTransaction {
|
||||
if _, err := sessionExecer.ExecContext(ctx, beginSQL); err != nil {
|
||||
logger.Error(err, "DBQueryMultiTransactional 开启事务失败:%s SQL片段=%q", formatConnSummary(runConfig), sqlSnippet(query))
|
||||
return connection.QueryResult{Success: false, Message: err.Error(), QueryID: queryID}
|
||||
}
|
||||
}
|
||||
|
||||
statements := splitSQLStatements(query)
|
||||
resultSets, err := executeManagedSQLTransactionStatements(ctx, sessionExecer, runConfig, statements)
|
||||
if err != nil {
|
||||
if _, rollbackErr := sessionExecer.ExecContext(context.Background(), rollbackSQL); rollbackErr != nil {
|
||||
var rollbackErr error
|
||||
if transactor != nil {
|
||||
rollbackErr = transactor.Rollback()
|
||||
} else if strings.TrimSpace(rollbackSQL) != "" {
|
||||
_, rollbackErr = sessionExecer.ExecContext(context.Background(), rollbackSQL)
|
||||
}
|
||||
if rollbackErr != nil {
|
||||
logger.Error(rollbackErr, "DBQueryMultiTransactional 执行失败后回滚失败:%s SQL片段=%q", formatConnSummary(runConfig), sqlSnippet(query))
|
||||
err = fmt.Errorf("%v;回滚失败: %w", err, rollbackErr)
|
||||
}
|
||||
@@ -107,6 +130,7 @@ func (a *App) DBQueryMultiTransactional(config connection.ConnectionConfig, dbNa
|
||||
a.sqlTransactions[transactionID] = &managedSQLTransaction{
|
||||
id: transactionID,
|
||||
execer: sessionExecer,
|
||||
transactor: transactor,
|
||||
dbType: runConfig.Type,
|
||||
commitSQL: commitSQL,
|
||||
rollbackSQL: rollbackSQL,
|
||||
@@ -287,7 +311,13 @@ func (a *App) finishManagedSQLTransaction(transactionID string, commit bool) con
|
||||
defer cancel()
|
||||
|
||||
var execErr error
|
||||
if strings.TrimSpace(sqlText) != "" {
|
||||
if tx.transactor != nil {
|
||||
if commit {
|
||||
execErr = tx.transactor.Commit()
|
||||
} else {
|
||||
execErr = tx.transactor.Rollback()
|
||||
}
|
||||
} else if strings.TrimSpace(sqlText) != "" {
|
||||
_, execErr = tx.execer.ExecContext(ctx, sqlText)
|
||||
}
|
||||
closeErr := tx.execer.Close()
|
||||
@@ -319,7 +349,11 @@ func (a *App) rollbackPendingSQLTransactionsOnShutdown() {
|
||||
|
||||
for _, tx := range pending {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), sqlEditorTransactionFinishTimeout)
|
||||
if strings.TrimSpace(tx.rollbackSQL) != "" && tx.execer != nil {
|
||||
if tx.transactor != nil {
|
||||
if err := tx.transactor.Rollback(); err != nil {
|
||||
logger.Warnf("关闭应用时回滚 SQL 编辑器事务失败:id=%s dbType=%s err=%v", tx.id, tx.dbType, err)
|
||||
}
|
||||
} else if strings.TrimSpace(tx.rollbackSQL) != "" && tx.execer != nil {
|
||||
if _, err := tx.execer.ExecContext(ctx, tx.rollbackSQL); err != nil {
|
||||
logger.Warnf("关闭应用时回滚 SQL 编辑器事务失败:id=%s dbType=%s err=%v", tx.id, tx.dbType, err)
|
||||
}
|
||||
|
||||
@@ -6,6 +6,7 @@ import (
|
||||
"database/sql"
|
||||
"fmt"
|
||||
"strings"
|
||||
"sync"
|
||||
)
|
||||
|
||||
// Database 定义了统一的数据源访问接口。
|
||||
@@ -118,6 +119,21 @@ type SessionExecerProvider interface {
|
||||
OpenSessionExecer(ctx context.Context) (StatementExecer, error)
|
||||
}
|
||||
|
||||
// TransactionExecer is a single transaction handle backed by the database
|
||||
// driver. It is required for dialects where textual BEGIN/COMMIT is not a
|
||||
// valid transaction-control statement, such as Oracle.
|
||||
type TransactionExecer interface {
|
||||
StatementExecer
|
||||
Commit() error
|
||||
Rollback() error
|
||||
}
|
||||
|
||||
// TransactionExecerProvider is implemented by drivers that can expose a real
|
||||
// database/sql transaction for long-running SQL editor managed transactions.
|
||||
type TransactionExecerProvider interface {
|
||||
OpenTransactionExecer(ctx context.Context) (TransactionExecer, error)
|
||||
}
|
||||
|
||||
type sqlConnStatementExecer struct {
|
||||
conn *sql.Conn
|
||||
}
|
||||
@@ -184,6 +200,109 @@ func (e *sqlConnStatementExecer) Close() error {
|
||||
return e.conn.Close()
|
||||
}
|
||||
|
||||
type sqlTxStatementExecer struct {
|
||||
mu sync.Mutex
|
||||
tx *sql.Tx
|
||||
done bool
|
||||
}
|
||||
|
||||
func NewSQLTxStatementExecer(tx *sql.Tx) TransactionExecer {
|
||||
return &sqlTxStatementExecer{tx: tx}
|
||||
}
|
||||
|
||||
func (e *sqlTxStatementExecer) activeTx() (*sql.Tx, error) {
|
||||
if e == nil || e.tx == nil {
|
||||
return nil, fmt.Errorf("事务未打开")
|
||||
}
|
||||
e.mu.Lock()
|
||||
defer e.mu.Unlock()
|
||||
if e.done {
|
||||
return nil, fmt.Errorf("事务已结束")
|
||||
}
|
||||
return e.tx, nil
|
||||
}
|
||||
|
||||
func (e *sqlTxStatementExecer) ExecContext(ctx context.Context, query string) (int64, error) {
|
||||
tx, err := e.activeTx()
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
res, err := tx.ExecContext(ctx, query)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
return res.RowsAffected()
|
||||
}
|
||||
|
||||
func (e *sqlTxStatementExecer) Exec(query string) (int64, error) {
|
||||
return e.ExecContext(context.Background(), query)
|
||||
}
|
||||
|
||||
func (e *sqlTxStatementExecer) QueryContext(ctx context.Context, query string) ([]map[string]interface{}, []string, error) {
|
||||
tx, err := e.activeTx()
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
rows, err := tx.QueryContext(ctx, query)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
defer rows.Close()
|
||||
return scanRows(rows)
|
||||
}
|
||||
|
||||
func (e *sqlTxStatementExecer) Query(query string) ([]map[string]interface{}, []string, error) {
|
||||
return e.QueryContext(context.Background(), query)
|
||||
}
|
||||
|
||||
func (e *sqlTxStatementExecer) QueryMultiContext(ctx context.Context, query string) ([]connection.ResultSetData, error) {
|
||||
tx, err := e.activeTx()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
rows, err := tx.QueryContext(ctx, query)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer rows.Close()
|
||||
return scanMultiRows(rows)
|
||||
}
|
||||
|
||||
func (e *sqlTxStatementExecer) QueryMulti(query string) ([]connection.ResultSetData, error) {
|
||||
return e.QueryMultiContext(context.Background(), query)
|
||||
}
|
||||
|
||||
func (e *sqlTxStatementExecer) finish(action func(*sql.Tx) error) error {
|
||||
if e == nil || e.tx == nil {
|
||||
return nil
|
||||
}
|
||||
e.mu.Lock()
|
||||
if e.done {
|
||||
e.mu.Unlock()
|
||||
return nil
|
||||
}
|
||||
tx := e.tx
|
||||
e.done = true
|
||||
e.mu.Unlock()
|
||||
return action(tx)
|
||||
}
|
||||
|
||||
func (e *sqlTxStatementExecer) Commit() error {
|
||||
return e.finish(func(tx *sql.Tx) error {
|
||||
return tx.Commit()
|
||||
})
|
||||
}
|
||||
|
||||
func (e *sqlTxStatementExecer) Rollback() error {
|
||||
return e.finish(func(tx *sql.Tx) error {
|
||||
return tx.Rollback()
|
||||
})
|
||||
}
|
||||
|
||||
func (e *sqlTxStatementExecer) Close() error {
|
||||
return e.Rollback()
|
||||
}
|
||||
|
||||
// BatchApplier 定义了批量变更提交接口。
|
||||
// 支持批量编辑的驱动实现此接口,用于一次性提交前端 DataGrid 中的增删改操作。
|
||||
type BatchApplier interface {
|
||||
|
||||
@@ -24,6 +24,8 @@ type OracleDB struct {
|
||||
forwarder *ssh.LocalForwarder // Store SSH tunnel forwarder
|
||||
}
|
||||
|
||||
var _ TransactionExecerProvider = (*OracleDB)(nil)
|
||||
|
||||
func (o *OracleDB) getDSN(config connection.ConnectionConfig) string {
|
||||
// oracle://user:pass@host:port/service_name
|
||||
database := strings.TrimSpace(config.Database)
|
||||
@@ -251,6 +253,17 @@ func (o *OracleDB) Exec(query string) (int64, error) {
|
||||
return res.RowsAffected()
|
||||
}
|
||||
|
||||
func (o *OracleDB) OpenTransactionExecer(ctx context.Context) (TransactionExecer, error) {
|
||||
if o.conn == nil {
|
||||
return nil, fmt.Errorf("连接未打开")
|
||||
}
|
||||
tx, err := o.conn.BeginTx(ctx, nil)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return NewSQLTxStatementExecer(tx), nil
|
||||
}
|
||||
|
||||
func (o *OracleDB) GetDatabases() ([]string, error) {
|
||||
// Oracle treats Users/Schemas as "Databases" in this context
|
||||
data, _, err := o.Query("SELECT username FROM all_users ORDER BY username")
|
||||
|
||||
Reference in New Issue
Block a user