Files
MyGoNavi/internal/app/methods_db_transaction.go
Syngnat 8a0dc3a7d3 🐛 fix(transaction): 修复 Oracle 托管事务提交回滚失败
- Oracle 托管事务改为固定物理连接执行 COMMIT/ROLLBACK

- SQL 编辑器事务按归一化方言判断 Oracle 兼容协议

- 补充 Oracle 与 OceanBase Oracle 事务回归测试
2026-06-12 02:51:01 +08:00

422 lines
14 KiB
Go
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
package app
import (
"context"
"fmt"
"strings"
"time"
"GoNavi-Wails/internal/connection"
"GoNavi-Wails/internal/db"
"GoNavi-Wails/internal/logger"
"github.com/google/uuid"
)
// sqlEditorTransactionFinishTimeout bounds every text COMMIT/ROLLBACK issued
// for a SQL editor managed transaction, so a stalled server cannot block the
// caller indefinitely.
const sqlEditorTransactionFinishTimeout = 30 * time.Second

// DBQueryMultiTransactional executes SQL editor DML in a managed transaction.
// The transaction stays open until DBCommitTransaction or DBRollbackTransaction
// is called by the SQL editor UI.
//
// Scripts that shouldUseManagedSQLTransaction rejects are delegated to
// DBQueryMulti unchanged. Otherwise every statement runs on one session:
//
//   - dialects listed by sqlEditorImplicitTransactionSQL (Oracle) use a plain
//     session; the server opens the transaction implicitly and the stored
//     COMMIT/ROLLBACK text ends it on the same physical connection;
//   - databases implementing db.TransactionExecerProvider use a driver
//     transaction whose context is detached from this RPC;
//   - anything else uses a plain session wrapped in the BEGIN/COMMIT/ROLLBACK
//     text from sqlFileBatchTransactionSQL, or is reported as unsupported.
//
// If a statement fails, the transaction is rolled back before returning and no
// TransactionID is issued. On success the session is registered in
// a.sqlTransactions and the result carries TransactionID and
// TransactionPending=true.
func (a *App) DBQueryMultiTransactional(config connection.ConnectionConfig, dbName string, query string, queryID string) connection.QueryResult {
	runConfig := normalizeRunConfig(config, dbName)
	transactionDBType := resolveDDLDBType(runConfig)
	transactionConfig := runConfig
	transactionConfig.Type = transactionDBType
	if queryID == "" {
		queryID = generateQueryID()
	}
	query = sanitizeSQLForPgLike(transactionDBType, query)
	if !shouldUseManagedSQLTransaction(transactionDBType, query) {
		return a.DBQueryMulti(config, dbName, query, queryID)
	}

	// unsupported is the single result used whenever the data source cannot
	// host a managed transaction.
	unsupported := func() connection.QueryResult {
		return connection.QueryResult{
			Success: false,
			Message: fmt.Sprintf("当前数据源(%s)不支持 SQL 编辑器托管事务", transactionDBType),
			QueryID: queryID,
		}
	}

	beginSQL, commitSQL, rollbackSQL, hasTextTransaction := sqlFileBatchTransactionSQL(transactionDBType)
	implicitTextTransaction := false
	if implicitCommitSQL, implicitRollbackSQL, ok := sqlEditorImplicitTransactionSQL(transactionDBType); ok {
		commitSQL = implicitCommitSQL
		rollbackSQL = implicitRollbackSQL
		hasTextTransaction = true
		implicitTextTransaction = true
	}

	dbInst, err := a.getDatabase(runConfig)
	if err != nil {
		logger.Error(err, "DBQueryMultiTransactional 获取连接失败:%s", formatConnSummary(runConfig))
		return connection.QueryResult{Success: false, Message: err.Error(), QueryID: queryID}
	}

	ctx, cancel := newQueryExecutionContext(runConfig)
	defer cancel()
	a.queryMu.Lock()
	a.runningQueries[queryID] = queryContext{
		cancel:  cancel,
		started: time.Now(),
	}
	a.queryMu.Unlock()
	defer func() {
		a.queryMu.Lock()
		delete(a.runningQueries, queryID)
		a.queryMu.Unlock()
	}()

	var (
		sessionExecer        db.StatementExecer
		transactor           db.TransactionExecer
		transactionCancel    context.CancelFunc
		startTextTransaction bool
	)
	if implicitTextTransaction {
		provider, ok := dbInst.(db.SessionExecerProvider)
		if !ok {
			return unsupported()
		}
		sessionExecer, err = provider.OpenSessionExecer(ctx)
		if err != nil {
			logger.Error(err, "DBQueryMultiTransactional 打开隐式事务会话失败:%s SQL片段=%q", formatConnSummary(runConfig), sqlSnippet(query))
			return connection.QueryResult{Success: false, Message: err.Error(), QueryID: queryID}
		}
	} else if provider, ok := dbInst.(db.TransactionExecerProvider); ok {
		// database/sql rolls back a BeginTx transaction when its context is cancelled.
		// SQL editor transactions must outlive the execution RPC and be ended only by
		// explicit commit, rollback, or shutdown cleanup.
		var transactionContext context.Context
		transactionContext, transactionCancel = context.WithCancel(context.Background())
		transactionExecer, openErr := provider.OpenTransactionExecer(transactionContext)
		if openErr != nil {
			transactionCancel()
			logger.Error(openErr, "DBQueryMultiTransactional 打开驱动事务失败:%s SQL片段=%q", formatConnSummary(runConfig), sqlSnippet(query))
			return connection.QueryResult{Success: false, Message: openErr.Error(), QueryID: queryID}
		}
		sessionExecer = transactionExecer
		transactor = transactionExecer
	} else {
		if !hasTextTransaction {
			return unsupported()
		}
		provider, ok := dbInst.(db.SessionExecerProvider)
		if !ok {
			return unsupported()
		}
		sessionExecer, err = provider.OpenSessionExecer(ctx)
		if err != nil {
			logger.Error(err, "DBQueryMultiTransactional 打开事务会话失败:%s SQL片段=%q", formatConnSummary(runConfig), sqlSnippet(query))
			return connection.QueryResult{Success: false, Message: err.Error(), QueryID: queryID}
		}
		startTextTransaction = true
	}

	// Until the transaction is handed over to a.sqlTransactions, every exit
	// path must close the session and release the transaction context.
	closeSession := true
	defer func() {
		if !closeSession {
			return
		}
		if closeErr := sessionExecer.Close(); closeErr != nil {
			logger.Warnf("DBQueryMultiTransactional 关闭事务会话失败:%v", closeErr)
		}
		if transactionCancel != nil {
			transactionCancel()
		}
	}()

	if startTextTransaction {
		if _, beginErr := sessionExecer.ExecContext(ctx, beginSQL); beginErr != nil {
			logger.Error(beginErr, "DBQueryMultiTransactional 开启事务失败:%s SQL片段=%q", formatConnSummary(runConfig), sqlSnippet(query))
			return connection.QueryResult{Success: false, Message: beginErr.Error(), QueryID: queryID}
		}
	}

	statements := splitSQLStatements(query)
	resultSets, err := executeManagedSQLTransactionStatements(ctx, sessionExecer, transactionConfig, statements)
	if err != nil {
		// ctx may already be cancelled (user cancel or execution timeout), so the
		// rollback gets its own context; it is bounded so a stalled server cannot
		// hang this RPC forever.
		rollbackCtx, rollbackCancel := context.WithTimeout(context.Background(), sqlEditorTransactionFinishTimeout)
		var rollbackErr error
		if transactor != nil {
			rollbackErr = transactor.Rollback()
		} else if strings.TrimSpace(rollbackSQL) != "" {
			_, rollbackErr = sessionExecer.ExecContext(rollbackCtx, rollbackSQL)
		}
		rollbackCancel()
		if rollbackErr != nil {
			logger.Error(rollbackErr, "DBQueryMultiTransactional 执行失败后回滚失败:%s SQL片段=%q", formatConnSummary(runConfig), sqlSnippet(query))
			err = fmt.Errorf("%v; 回滚失败: %w", err, rollbackErr)
		}
		logger.Error(err, "DBQueryMultiTransactional 执行失败:%s SQL片段=%q", formatConnSummary(runConfig), sqlSnippet(query))
		return connection.QueryResult{Success: false, Message: err.Error(), QueryID: queryID}
	}

	transactionID := "sql-editor-" + uuid.NewString()
	a.sqlTransactionMu.Lock()
	if a.sqlTransactions == nil {
		a.sqlTransactions = make(map[string]*managedSQLTransaction)
	}
	a.sqlTransactions[transactionID] = &managedSQLTransaction{
		id:          transactionID,
		execer:      sessionExecer,
		transactor:  transactor,
		cancel:      transactionCancel,
		dbType:      transactionDBType,
		commitSQL:   commitSQL,
		rollbackSQL: rollbackSQL,
		createdAt:   time.Now(),
	}
	a.sqlTransactionMu.Unlock()
	// Ownership of the session (and transactionCancel) now belongs to the
	// registered entry; commit/rollback/shutdown will release them.
	closeSession = false
	return connection.QueryResult{
		Success:            true,
		Data:               resultSets,
		QueryID:            queryID,
		TransactionID:      transactionID,
		TransactionPending: true,
	}
}
// executeManagedSQLTransactionStatements runs statements, in order, on the
// session that owns the managed transaction and returns the produced result
// sets, each tagged with the 1-based StatementIndex of its statement.
//
// Statements flagged by isReadOnlySQLQuery or shouldTryQueryResultFirst are
// first sent through the richest query interface the session implements
// (multi-result with messages, then multi-result, then single with messages,
// then single). A failing read-only statement aborts the run; any other
// statement whose query attempt fails is retried with ExecContext and reported
// as a single "affectedRows" row. Blank statements are skipped. The first
// execution error aborts the run and is returned, wrapped with the statement
// number, together with nil result sets. A successful run never returns nil.
func executeManagedSQLTransactionStatements(ctx context.Context, session db.StatementExecer, runConfig connection.ConnectionConfig, statements []string) ([]connection.ResultSetData, error) {
	multiWithMessages, _ := session.(db.StatementMultiResultQueryMessageExecer)
	multi, _ := session.(db.StatementMultiResultQueryExecer)
	singleWithMessages, _ := session.(db.StatementQueryMessageExecer)
	single, _ := session.(db.StatementQueryExecer)

	// fromMulti normalizes multi-result output. When the driver produced no
	// result set but did produce messages, an empty set carrying a copy of
	// those messages is synthesized so they still reach the UI.
	fromMulti := func(sets []connection.ResultSetData, messages []string, ordinal int) []connection.ResultSetData {
		if len(sets) == 0 && len(messages) > 0 {
			sets = []connection.ResultSetData{{
				Rows:     []map[string]interface{}{},
				Columns:  []string{},
				Messages: append([]string(nil), messages...),
			}}
		}
		normalized := make([]connection.ResultSetData, 0, len(sets))
		for _, set := range sets {
			if set.Rows == nil {
				set.Rows = []map[string]interface{}{}
			}
			if set.Columns == nil {
				set.Columns = []string{}
			}
			set.StatementIndex = ordinal
			normalized = append(normalized, set)
		}
		return normalized
	}

	// fromSingle wraps one row set, replacing nil rows/columns with empty ones.
	fromSingle := func(rows []map[string]interface{}, columns []string, messages []string, ordinal int) []connection.ResultSetData {
		if rows == nil {
			rows = make([]map[string]interface{}, 0)
		}
		if columns == nil {
			columns = []string{}
		}
		return []connection.ResultSetData{{
			Rows:           rows,
			Columns:        columns,
			Messages:       messages,
			StatementIndex: ordinal,
		}}
	}

	// queryStatement dispatches stmt to the most capable query interface.
	queryStatement := func(stmt string, ordinal int) ([]connection.ResultSetData, error) {
		switch {
		case multiWithMessages != nil:
			sets, messages, err := multiWithMessages.QueryMultiContextWithMessages(ctx, stmt)
			if err != nil {
				return nil, err
			}
			return fromMulti(sets, messages, ordinal), nil
		case multi != nil:
			sets, err := multi.QueryMultiContext(ctx, stmt)
			if err != nil {
				return nil, err
			}
			return fromMulti(sets, nil, ordinal), nil
		case singleWithMessages != nil:
			rows, columns, messages, err := singleWithMessages.QueryContextWithMessages(ctx, stmt)
			if err != nil {
				return nil, err
			}
			return fromSingle(rows, columns, messages, ordinal), nil
		case single != nil:
			rows, columns, err := single.QueryContext(ctx, stmt)
			if err != nil {
				return nil, err
			}
			return fromSingle(rows, columns, nil, ordinal), nil
		default:
			return nil, fmt.Errorf("当前事务会话不支持查询语句")
		}
	}

	var collected []connection.ResultSetData
	for position, raw := range statements {
		stmt := strings.TrimSpace(raw)
		if stmt == "" {
			continue
		}
		ordinal := position + 1
		readOnly := isReadOnlySQLQuery(runConfig.Type, stmt)
		queryFirst := shouldTryQueryResultFirst(runConfig.Type, stmt)
		if readOnly || queryFirst {
			produced, queryErr := queryStatement(stmt, ordinal)
			if queryErr == nil {
				collected = append(collected, produced...)
				continue
			}
			if readOnly {
				return nil, fmt.Errorf("第 %d 条语句执行失败: %w", ordinal, queryErr)
			}
			// Not a pure read: fall back to executing it as a write.
		}
		affected, execErr := session.ExecContext(ctx, stmt)
		if execErr != nil {
			return nil, fmt.Errorf("第 %d 条语句执行失败: %w", ordinal, execErr)
		}
		collected = append(collected, connection.ResultSetData{
			Rows:           []map[string]interface{}{{"affectedRows": affected}},
			Columns:        []string{"affectedRows"},
			StatementIndex: ordinal,
		})
	}
	if collected == nil {
		collected = []connection.ResultSetData{}
	}
	return collected, nil
}
// shouldUseManagedSQLTransaction reports whether query should run inside a SQL
// editor managed transaction. That is the case only when every non-blank
// statement is either read-only or a batchable write, at least one of them is a
// write, and none is an explicit transaction-control statement (which would
// conflict with the managed transaction).
func shouldUseManagedSQLTransaction(dbType string, query string) bool {
	sawWrite := false
	for _, raw := range splitSQLStatements(query) {
		stmt := strings.TrimSpace(raw)
		switch {
		case stmt == "":
			// Blank fragments between separators are ignored.
		case isSQLTransactionControlStatement(stmt):
			return false
		case isReadOnlySQLQuery(dbType, stmt):
			// Reads may ride along inside the transaction.
		case isBatchableWriteSQLStatement(dbType, stmt):
			sawWrite = true
		default:
			return false
		}
	}
	return sawWrite
}
// sqlEditorImplicitTransactionSQL returns the COMMIT/ROLLBACK text for dialects
// whose SQL editor transactions are implicit (started by the server on the
// first DML rather than by an explicit BEGIN). dbType is matched after
// trimming whitespace and ignoring case; only "oracle" qualifies today, and ok
// is false for every other dialect.
//
// NOTE(review): callers pass the dialect after resolveDDLDBType; an
// Oracle-compatible mode only takes this path if it normalizes to "oracle"
// there — confirm for each compatible data source.
func sqlEditorImplicitTransactionSQL(dbType string) (commitSQL string, rollbackSQL string, ok bool) {
	dialect := strings.ToLower(strings.TrimSpace(dbType))
	if dialect != "oracle" {
		return "", "", false
	}
	// Oracle starts a transaction implicitly on the first DML statement.
	// Keeping SQL editor DML on one physical connection avoids database/sql
	// Tx context lifecycle ending the transaction before the UI commits it.
	commitSQL, rollbackSQL, ok = "COMMIT", "ROLLBACK", true
	return commitSQL, rollbackSQL, ok
}
// isSQLTransactionControlStatement reports whether stmt explicitly controls a
// transaction: it starts with BEGIN, COMMIT, ROLLBACK, SAVEPOINT or RELEASE, or
// it starts with START and mentions "transaction" anywhere (case-insensitive).
func isSQLTransactionControlStatement(stmt string) bool {
	keyword := leadingSQLKeyword(stmt)
	if keyword == "start" {
		// A bare START is not enough; only START ... TRANSACTION counts.
		return strings.Contains(strings.ToLower(stmt), "transaction")
	}
	return keyword == "begin" ||
		keyword == "commit" ||
		keyword == "rollback" ||
		keyword == "savepoint" ||
		keyword == "release"
}
// DBCommitTransaction commits the pending SQL editor transaction identified by
// transactionID (the TransactionID returned by DBQueryMultiTransactional) and
// closes its session.
func (a *App) DBCommitTransaction(transactionID string) connection.QueryResult {
	const commit = true
	return a.finishManagedSQLTransaction(transactionID, commit)
}
// DBRollbackTransaction rolls back the pending SQL editor transaction
// identified by transactionID (the TransactionID returned by
// DBQueryMultiTransactional) and closes its session.
func (a *App) DBRollbackTransaction(transactionID string) connection.QueryResult {
	const commit = false
	return a.finishManagedSQLTransaction(transactionID, commit)
}
// finishManagedSQLTransaction commits (commit == true) or rolls back the SQL
// editor transaction registered under transactionID, then closes its session.
//
// The entry is removed from a.sqlTransactions before any database work, so a
// transaction is finished at most once: a repeated call for the same ID reports
// "事务不存在或已结束". Driver transactions use Commit/Rollback; text
// transactions execute the stored COMMIT/ROLLBACK SQL, bounded by
// sqlEditorTransactionFinishTimeout. The session is closed in every case. A
// close failure makes the call fail even if the commit/rollback succeeded, and
// it is still logged when the commit/rollback itself also failed.
func (a *App) finishManagedSQLTransaction(transactionID string, commit bool) connection.QueryResult {
	transactionID = strings.TrimSpace(transactionID)
	if transactionID == "" {
		return connection.QueryResult{Success: false, Message: "事务 ID 不能为空"}
	}

	a.sqlTransactionMu.Lock()
	tx, ok := a.sqlTransactions[transactionID]
	if ok {
		delete(a.sqlTransactions, transactionID)
	}
	a.sqlTransactionMu.Unlock()

	// The entry is already unregistered, so release its context however this
	// call ends — including the malformed-entry case just below.
	if tx != nil && tx.cancel != nil {
		defer tx.cancel()
	}
	if !ok || tx == nil || tx.execer == nil {
		return connection.QueryResult{Success: false, Message: "事务不存在或已结束"}
	}

	action := "回滚"
	sqlText := tx.rollbackSQL
	if commit {
		action = "提交"
		sqlText = tx.commitSQL
	}

	ctx, cancel := context.WithTimeout(context.Background(), sqlEditorTransactionFinishTimeout)
	defer cancel()

	var execErr error
	switch {
	case tx.transactor != nil && commit:
		execErr = tx.transactor.Commit()
	case tx.transactor != nil:
		execErr = tx.transactor.Rollback()
	case strings.TrimSpace(sqlText) != "":
		_, execErr = tx.execer.ExecContext(ctx, sqlText)
	}

	closeErr := tx.execer.Close()
	if execErr != nil {
		if closeErr != nil {
			// The returned message only carries execErr; keep the close failure
			// in the log so a leaked session remains traceable.
			logger.Error(closeErr, "SQL 编辑器事务%s后关闭会话失败id=%s dbType=%s", action, transactionID, tx.dbType)
		}
		logger.Error(execErr, "SQL 编辑器事务%s失败id=%s dbType=%s", action, transactionID, tx.dbType)
		return connection.QueryResult{Success: false, Message: fmt.Sprintf("事务%s失败: %v", action, execErr)}
	}
	if closeErr != nil {
		logger.Error(closeErr, "SQL 编辑器事务%s后关闭会话失败id=%s dbType=%s", action, transactionID, tx.dbType)
		return connection.QueryResult{Success: false, Message: fmt.Sprintf("事务%s成功但关闭会话失败: %v", action, closeErr)}
	}
	if commit {
		return connection.QueryResult{Success: true, Message: "事务已提交"}
	}
	return connection.QueryResult{Success: true, Message: "事务已回滚"}
}
// rollbackPendingSQLTransactionsOnShutdown rolls back and closes every SQL
// editor transaction that is still pending, leaving a.sqlTransactions empty.
// It is best-effort: failures are logged as warnings and never stop the
// remaining transactions from being processed. Transactions are handled one
// at a time, and each text ROLLBACK is bounded by
// sqlEditorTransactionFinishTimeout.
func (a *App) rollbackPendingSQLTransactionsOnShutdown() {
	// Detach every entry under the lock so the slow database work below runs
	// without holding sqlTransactionMu.
	a.sqlTransactionMu.Lock()
	drained := make([]*managedSQLTransaction, 0, len(a.sqlTransactions))
	for key, pending := range a.sqlTransactions {
		delete(a.sqlTransactions, key)
		if pending == nil {
			continue
		}
		drained = append(drained, pending)
	}
	a.sqlTransactionMu.Unlock()

	abandon := func(pending *managedSQLTransaction) {
		ctx, cancel := context.WithTimeout(context.Background(), sqlEditorTransactionFinishTimeout)
		var rollbackErr error
		switch {
		case pending.transactor != nil:
			rollbackErr = pending.transactor.Rollback()
		case strings.TrimSpace(pending.rollbackSQL) != "" && pending.execer != nil:
			_, rollbackErr = pending.execer.ExecContext(ctx, pending.rollbackSQL)
		}
		if rollbackErr != nil {
			logger.Warnf("关闭应用时回滚 SQL 编辑器事务失败id=%s dbType=%s err=%v", pending.id, pending.dbType, rollbackErr)
		}
		cancel()

		if pending.cancel != nil {
			pending.cancel()
		}
		if pending.execer == nil {
			return
		}
		if closeErr := pending.execer.Close(); closeErr != nil {
			logger.Warnf("关闭应用时关闭 SQL 编辑器事务会话失败id=%s dbType=%s err=%v", pending.id, pending.dbType, closeErr)
		}
	}

	for _, pending := range drained {
		abandon(pending)
	}
}