diff --git a/internal/app/app.go b/internal/app/app.go index c06865f..f351f71 100644 --- a/internal/app/app.go +++ b/internal/app/app.go @@ -59,6 +59,7 @@ type managedSQLTransaction struct { id string execer db.StatementExecer transactor db.TransactionExecer + cancel context.CancelFunc dbType string commitSQL string rollbackSQL string diff --git a/internal/app/methods_db_multi_test.go b/internal/app/methods_db_multi_test.go index a6ac3b4..e29fda4 100644 --- a/internal/app/methods_db_multi_test.go +++ b/internal/app/methods_db_multi_test.go @@ -2,6 +2,7 @@ package app import ( "context" + "database/sql" "errors" "testing" @@ -227,22 +228,30 @@ type fakeTransactionalDB struct { func (f *fakeTransactionalDB) OpenTransactionExecer(ctx context.Context) (db.TransactionExecer, error) { f.txSession = &fakeTransactionSession{ fakeBatchWriteSession: fakeBatchWriteSession{parent: &f.fakeBatchWriteDB}, + beginCtx: ctx, } return f.txSession, nil } type fakeTransactionSession struct { fakeBatchWriteSession + beginCtx context.Context commitCalls int rollbackCalls int } func (s *fakeTransactionSession) Commit() error { + if s.beginCtx != nil && s.beginCtx.Err() != nil { + return sql.ErrTxDone + } s.commitCalls++ return nil } func (s *fakeTransactionSession) Rollback() error { + if s.beginCtx != nil && s.beginCtx.Err() != nil { + return sql.ErrTxDone + } s.rollbackCalls++ return nil } diff --git a/internal/app/methods_db_transaction.go b/internal/app/methods_db_transaction.go index 4adcd5c..dba4436 100644 --- a/internal/app/methods_db_transaction.go +++ b/internal/app/methods_db_transaction.go @@ -55,11 +55,18 @@ func (a *App) DBQueryMultiTransactional(config connection.ConnectionConfig, dbNa var ( sessionExecer db.StatementExecer transactor db.TransactionExecer + transactionCancel context.CancelFunc startTextTransaction bool ) if provider, ok := dbInst.(db.TransactionExecerProvider); ok { - transactionExecer, err := provider.OpenTransactionExecer(ctx) + transactionContext := context.Background() + if a.ctx != nil { + transactionContext = a.ctx + } + transactionContext, transactionCancel = context.WithCancel(transactionContext) + transactionExecer, err := provider.OpenTransactionExecer(transactionContext) if err != nil { + transactionCancel() logger.Error(err, "DBQueryMultiTransactional 打开驱动事务失败:%s SQL片段=%q", formatConnSummary(runConfig), sqlSnippet(query)) return connection.QueryResult{Success: false, Message: err.Error(), QueryID: queryID} } @@ -95,6 +102,9 @@ func (a *App) DBQueryMultiTransactional(config connection.ConnectionConfig, dbNa if err := sessionExecer.Close(); err != nil { logger.Warnf("DBQueryMultiTransactional 关闭事务会话失败:%v", err) } + if transactionCancel != nil { + transactionCancel() + } } }() @@ -131,6 +141,7 @@ func (a *App) DBQueryMultiTransactional(config connection.ConnectionConfig, dbNa id: transactionID, execer: sessionExecer, transactor: transactor, + cancel: transactionCancel, dbType: runConfig.Type, commitSQL: commitSQL, rollbackSQL: rollbackSQL, @@ -299,6 +310,9 @@ func (a *App) finishManagedSQLTransaction(transactionID string, commit bool) con if !ok || tx == nil || tx.execer == nil { return connection.QueryResult{Success: false, Message: "事务不存在或已结束"} } + if tx.cancel != nil { + defer tx.cancel() + } action := "回滚" sqlText := tx.rollbackSQL @@ -359,6 +373,9 @@ func (a *App) rollbackPendingSQLTransactionsOnShutdown() { } } cancel() + if tx.cancel != nil { + tx.cancel() + } if tx.execer != nil { if err := tx.execer.Close(); err != nil { logger.Warnf("关闭应用时关闭 SQL 编辑器事务会话失败:id=%s dbType=%s err=%v", tx.id, tx.dbType, err)