Files
MyGoNavi/internal/app/methods_db_transaction.go
Syngnat e4672062f8 🐛 fix(query-editor): 支持 Oracle SQL 编辑器托管事务
- 新增 driver transaction 执行器,支持不适合文本 BEGIN 的数据库

- Oracle SQL 编辑器 DML 托管事务改用 database/sql Tx 提交和回滚

- 补充 Oracle 托管事务提交和失败回滚回归测试
2026-06-11 15:45:13 +08:00

369 lines
12 KiB
Go
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
package app
import (
"context"
"fmt"
"strings"
"time"
"GoNavi-Wails/internal/connection"
"GoNavi-Wails/internal/db"
"GoNavi-Wails/internal/logger"
"github.com/google/uuid"
)
// sqlEditorTransactionFinishTimeout bounds the context used when finishing
// (committing or rolling back) a managed SQL editor transaction.
const sqlEditorTransactionFinishTimeout = time.Second * 30
// DBQueryMultiTransactional executes SQL editor DML in a managed transaction.
// The transaction stays open until DBCommitTransaction or DBRollbackTransaction
// is called by the SQL editor UI.
//
// Queries that do not qualify (see shouldUseManagedSQLTransaction) are passed
// to DBQueryMulti unchanged. Otherwise the statements run on a dedicated
// session: a driver transaction when the database implements
// db.TransactionExecerProvider, or a session driven by textual
// BEGIN/COMMIT/ROLLBACK (from sqlFileBatchTransactionSQL) when it implements
// db.SessionExecerProvider. On success the session is registered in
// a.sqlTransactions and its ID is returned with TransactionPending set. On any
// failure the transaction is rolled back (when possible) and the session is
// closed before returning.
func (a *App) DBQueryMultiTransactional(config connection.ConnectionConfig, dbName string, query string, queryID string) connection.QueryResult {
	runConfig := normalizeRunConfig(config, dbName)
	if queryID == "" {
		queryID = generateQueryID()
	}
	query = sanitizeSQLForPgLike(resolveDDLDBType(config), query)
	if !shouldUseManagedSQLTransaction(runConfig.Type, query) {
		return a.DBQueryMulti(config, dbName, query, queryID)
	}
	beginSQL, commitSQL, rollbackSQL, hasTextTransaction := sqlFileBatchTransactionSQL(runConfig.Type)
	dbInst, err := a.getDatabase(runConfig)
	if err != nil {
		logger.Error(err, "DBQueryMultiTransactional 获取连接失败:%s", formatConnSummary(runConfig))
		return connection.QueryResult{Success: false, Message: err.Error(), QueryID: queryID}
	}

	// ctx bounds the statements executed by this call and is registered in
	// runningQueries so the UI can cancel them. It is cancelled when this
	// function returns, while the transaction itself stays open.
	// NOTE(review): if a driver binds its transaction's lifetime to the ctx
	// passed to OpenTransactionExecer (database/sql BeginTx does), that cancel
	// would roll the pending transaction back — confirm the provider detaches it.
	ctx, cancel := newQueryExecutionContext(runConfig)
	defer cancel()
	a.queryMu.Lock()
	a.runningQueries[queryID] = queryContext{
		cancel:  cancel,
		started: time.Now(),
	}
	a.queryMu.Unlock()
	defer func() {
		a.queryMu.Lock()
		delete(a.runningQueries, queryID)
		a.queryMu.Unlock()
	}()

	var (
		sessionExecer        db.StatementExecer
		transactor           db.TransactionExecer
		startTextTransaction bool
	)
	if provider, ok := dbInst.(db.TransactionExecerProvider); ok {
		// Driver-level transaction: commit/rollback go through the transactor.
		transactionExecer, err := provider.OpenTransactionExecer(ctx)
		if err != nil {
			logger.Error(err, "DBQueryMultiTransactional 打开驱动事务失败:%s SQL片段=%q", formatConnSummary(runConfig), sqlSnippet(query))
			return connection.QueryResult{Success: false, Message: err.Error(), QueryID: queryID}
		}
		sessionExecer = transactionExecer
		transactor = transactionExecer
	} else {
		// Text transaction: needs both transaction SQL for this dialect and a
		// dedicated session so BEGIN/COMMIT hit the same connection.
		provider, ok := dbInst.(db.SessionExecerProvider)
		if !hasTextTransaction || !ok {
			return connection.QueryResult{
				Success: false,
				Message: fmt.Sprintf("当前数据源(%s)不支持 SQL 编辑器托管事务", runConfig.Type),
				QueryID: queryID,
			}
		}
		sessionExecer, err = provider.OpenSessionExecer(ctx)
		if err != nil {
			logger.Error(err, "DBQueryMultiTransactional 打开事务会话失败:%s SQL片段=%q", formatConnSummary(runConfig), sqlSnippet(query))
			return connection.QueryResult{Success: false, Message: err.Error(), QueryID: queryID}
		}
		startTextTransaction = true
	}

	// The session is closed on every early return; ownership passes to the
	// transaction registry only once the statements have succeeded.
	closeSession := true
	defer func() {
		if closeSession {
			if err := sessionExecer.Close(); err != nil {
				logger.Warnf("DBQueryMultiTransactional 关闭事务会话失败:%v", err)
			}
		}
	}()
	if startTextTransaction {
		if _, err := sessionExecer.ExecContext(ctx, beginSQL); err != nil {
			logger.Error(err, "DBQueryMultiTransactional 开启事务失败:%s SQL片段=%q", formatConnSummary(runConfig), sqlSnippet(query))
			return connection.QueryResult{Success: false, Message: err.Error(), QueryID: queryID}
		}
	}

	statements := splitSQLStatements(query)
	resultSets, err := executeManagedSQLTransactionStatements(ctx, sessionExecer, runConfig, statements)
	if err != nil {
		var rollbackErr error
		if transactor != nil {
			rollbackErr = transactor.Rollback()
		} else if strings.TrimSpace(rollbackSQL) != "" {
			// ctx may already be cancelled or expired (possibly why the
			// statement failed), so roll back on a fresh but bounded context
			// instead of one that could block forever on a hung session.
			rollbackCtx, rollbackCancel := context.WithTimeout(context.Background(), sqlEditorTransactionFinishTimeout)
			_, rollbackErr = sessionExecer.ExecContext(rollbackCtx, rollbackSQL)
			rollbackCancel()
		}
		if rollbackErr != nil {
			logger.Error(rollbackErr, "DBQueryMultiTransactional 执行失败后回滚失败:%s SQL片段=%q", formatConnSummary(runConfig), sqlSnippet(query))
			err = fmt.Errorf("%v; 回滚失败: %w", err, rollbackErr)
		}
		logger.Error(err, "DBQueryMultiTransactional 执行失败:%s SQL片段=%q", formatConnSummary(runConfig), sqlSnippet(query))
		return connection.QueryResult{Success: false, Message: err.Error(), QueryID: queryID}
	}

	transactionID := "sql-editor-" + uuid.NewString()
	a.sqlTransactionMu.Lock()
	if a.sqlTransactions == nil {
		a.sqlTransactions = make(map[string]*managedSQLTransaction)
	}
	a.sqlTransactions[transactionID] = &managedSQLTransaction{
		id:          transactionID,
		execer:      sessionExecer,
		transactor:  transactor,
		dbType:      runConfig.Type,
		commitSQL:   commitSQL,
		rollbackSQL: rollbackSQL,
		createdAt:   time.Now(),
	}
	a.sqlTransactionMu.Unlock()
	// The registry now owns the session; finishManagedSQLTransaction closes it.
	closeSession = false
	return connection.QueryResult{
		Success:            true,
		Data:               resultSets,
		QueryID:            queryID,
		TransactionID:      transactionID,
		TransactionPending: true,
	}
}
// executeManagedSQLTransactionStatements runs statements one at a time on
// session and collects the result sets each one produces, tagging them with
// the statement's 1-based position in statements (StatementIndex).
//
// Statements classified as read-only, or flagged by shouldTryQueryResultFirst,
// are first run through the richest query interface session implements
// (multi-result with messages, multi-result, single result with messages,
// single result). A failing read-only statement aborts the batch; a failing
// query attempt for any other statement falls back to ExecContext, whose
// affected-row count is reported as a one-column "affectedRows" result.
//
// Blank statements are skipped but still consume a position. On error no
// result sets are returned; on success the slice is never nil.
func executeManagedSQLTransactionStatements(ctx context.Context, session db.StatementExecer, runConfig connection.ConnectionConfig, statements []string) ([]connection.ResultSetData, error) {
	multiMsgExecer, _ := session.(db.StatementMultiResultQueryMessageExecer)
	multiExecer, _ := session.(db.StatementMultiResultQueryExecer)
	singleMsgExecer, _ := session.(db.StatementQueryMessageExecer)
	singleExecer, _ := session.(db.StatementQueryExecer)

	// fromMulti normalizes multi-result output: a statement that produced only
	// messages becomes one empty result carrying a copy of them, and nil
	// Rows/Columns are replaced with empty slices.
	fromMulti := func(sets []connection.ResultSetData, msgs []string, position int) []connection.ResultSetData {
		if len(sets) == 0 && len(msgs) > 0 {
			sets = []connection.ResultSetData{{
				Rows:     []map[string]interface{}{},
				Columns:  []string{},
				Messages: append([]string(nil), msgs...),
			}}
		}
		out := make([]connection.ResultSetData, 0, len(sets))
		for _, set := range sets {
			if set.Rows == nil {
				set.Rows = []map[string]interface{}{}
			}
			if set.Columns == nil {
				set.Columns = []string{}
			}
			set.StatementIndex = position
			out = append(out, set)
		}
		return out
	}

	// fromSingle wraps one tabular result, replacing nil rows/columns with
	// empty slices.
	fromSingle := func(rows []map[string]interface{}, cols []string, msgs []string, position int) []connection.ResultSetData {
		if rows == nil {
			rows = make([]map[string]interface{}, 0)
		}
		if cols == nil {
			cols = []string{}
		}
		return []connection.ResultSetData{{
			Rows:           rows,
			Columns:        cols,
			Messages:       msgs,
			StatementIndex: position,
		}}
	}

	// queryStatement dispatches stmt to the first query interface the session
	// supports, in order of preference.
	queryStatement := func(stmt string, position int) ([]connection.ResultSetData, error) {
		switch {
		case multiMsgExecer != nil:
			sets, msgs, err := multiMsgExecer.QueryMultiContextWithMessages(ctx, stmt)
			if err != nil {
				return nil, err
			}
			return fromMulti(sets, msgs, position), nil
		case multiExecer != nil:
			sets, err := multiExecer.QueryMultiContext(ctx, stmt)
			if err != nil {
				return nil, err
			}
			return fromMulti(sets, nil, position), nil
		case singleMsgExecer != nil:
			rows, cols, msgs, err := singleMsgExecer.QueryContextWithMessages(ctx, stmt)
			if err != nil {
				return nil, err
			}
			return fromSingle(rows, cols, msgs, position), nil
		case singleExecer != nil:
			rows, cols, err := singleExecer.QueryContext(ctx, stmt)
			if err != nil {
				return nil, err
			}
			return fromSingle(rows, cols, nil, position), nil
		default:
			return nil, fmt.Errorf("当前事务会话不支持查询语句")
		}
	}

	var collected []connection.ResultSetData
	for i, raw := range statements {
		stmt := strings.TrimSpace(raw)
		if stmt == "" {
			continue
		}
		position := i + 1
		readOnly := isReadOnlySQLQuery(runConfig.Type, stmt)
		queryFirst := shouldTryQueryResultFirst(runConfig.Type, stmt)
		if readOnly || queryFirst {
			sets, err := queryStatement(stmt, position)
			if err == nil {
				collected = append(collected, sets...)
				continue
			}
			if readOnly {
				return nil, fmt.Errorf("第 %d 条语句执行失败: %w", position, err)
			}
			// Not a pure read: fall through and execute it as a write.
		}
		affected, err := session.ExecContext(ctx, stmt)
		if err != nil {
			return nil, fmt.Errorf("第 %d 条语句执行失败: %w", position, err)
		}
		collected = append(collected, connection.ResultSetData{
			Rows:           []map[string]interface{}{{"affectedRows": affected}},
			Columns:        []string{"affectedRows"},
			StatementIndex: position,
		})
	}
	if collected == nil {
		return []connection.ResultSetData{}, nil
	}
	return collected, nil
}
// shouldUseManagedSQLTransaction reports whether query should run inside an
// editor-managed transaction. That is the case only when it contains at least
// one statement accepted by isBatchableWriteSQLStatement, every other
// non-empty statement is read-only per isReadOnlySQLQuery, and no statement is
// explicit transaction control.
func shouldUseManagedSQLTransaction(dbType string, query string) bool {
	sawWrite := false
	for _, raw := range splitSQLStatements(query) {
		stmt := strings.TrimSpace(raw)
		switch {
		case stmt == "":
			// Blank fragments are ignored.
		case isSQLTransactionControlStatement(stmt):
			// The user is managing the transaction explicitly.
			return false
		case isReadOnlySQLQuery(dbType, stmt):
			// Reads are allowed alongside the writes.
		case isBatchableWriteSQLStatement(dbType, stmt):
			sawWrite = true
		default:
			// Anything else (e.g. DDL) disqualifies the whole query.
			return false
		}
	}
	return sawWrite
}
// isSQLTransactionControlStatement reports whether stmt explicitly controls a
// transaction. The recognized forms are:
//   - BEGIN, COMMIT, ROLLBACK, SAVEPOINT and RELEASE [SAVEPOINT];
//   - PostgreSQL's END and ABORT, synonyms for COMMIT and ROLLBACK;
//   - START ... TRANSACTION, detected by a case-insensitive substring match
//     on "transaction".
//
// Such statements must not run inside an editor-managed transaction, because
// they would end or nest the transaction the editor owns.
//
// NOTE(review): this assumes leadingSQLKeyword returns the first keyword in
// lower case, since the case labels depend on it — confirm. "begin" also
// matches PL/SQL anonymous blocks (BEGIN ... END;).
func isSQLTransactionControlStatement(stmt string) bool {
	switch leadingSQLKeyword(stmt) {
	case "begin", "commit", "rollback", "savepoint", "release", "end", "abort":
		return true
	case "start":
		return strings.Contains(strings.ToLower(stmt), "transaction")
	default:
		return false
	}
}
// DBCommitTransaction commits the pending SQL editor transaction registered
// under transactionID and closes its session.
func (a *App) DBCommitTransaction(transactionID string) connection.QueryResult {
	const commit = true
	return a.finishManagedSQLTransaction(transactionID, commit)
}
// DBRollbackTransaction rolls back the pending SQL editor transaction
// registered under transactionID and closes its session.
func (a *App) DBRollbackTransaction(transactionID string) connection.QueryResult {
	const commit = false
	return a.finishManagedSQLTransaction(transactionID, commit)
}
// finishManagedSQLTransaction commits (commit == true) or rolls back
// (commit == false) the pending SQL editor transaction registered under
// transactionID, then closes its session.
//
// The entry is removed from a.sqlTransactions before it is finished, so an ID
// can be finished at most once. Unknown or already-finished IDs report
// "事务不存在或已结束". Driver transactions use the transactor's
// Commit/Rollback. Text-transaction sessions execute the stored COMMIT or
// ROLLBACK SQL under a sqlEditorTransactionFinishTimeout deadline. The session
// is closed whether or not finishing succeeded.
func (a *App) finishManagedSQLTransaction(transactionID string, commit bool) connection.QueryResult {
	transactionID = strings.TrimSpace(transactionID)
	if transactionID == "" {
		return connection.QueryResult{Success: false, Message: "事务 ID 不能为空"}
	}

	// Claim the entry under the lock so two concurrent finish requests for the
	// same ID cannot both act on it.
	a.sqlTransactionMu.Lock()
	tx, ok := a.sqlTransactions[transactionID]
	if ok {
		delete(a.sqlTransactions, transactionID)
	}
	a.sqlTransactionMu.Unlock()
	if !ok || tx == nil || tx.execer == nil {
		return connection.QueryResult{Success: false, Message: "事务不存在或已结束"}
	}

	action := "回滚"
	sqlText := tx.rollbackSQL
	if commit {
		action = "提交"
		sqlText = tx.commitSQL
	}

	ctx, cancel := context.WithTimeout(context.Background(), sqlEditorTransactionFinishTimeout)
	defer cancel()
	var execErr error
	if tx.transactor != nil {
		if commit {
			execErr = tx.transactor.Commit()
		} else {
			execErr = tx.transactor.Rollback()
		}
	} else if strings.TrimSpace(sqlText) != "" {
		_, execErr = tx.execer.ExecContext(ctx, sqlText)
		if execErr != nil && commit && strings.TrimSpace(tx.rollbackSQL) != "" {
			// A failed textual COMMIT (e.g. one that hit the deadline) may leave
			// the transaction open on the session. Roll back best-effort on a
			// fresh deadline so the session is not closed with uncommitted work.
			rollbackCtx, rollbackCancel := context.WithTimeout(context.Background(), sqlEditorTransactionFinishTimeout)
			if _, rollbackErr := tx.execer.ExecContext(rollbackCtx, tx.rollbackSQL); rollbackErr != nil {
				logger.Warnf("SQL 编辑器事务提交失败后回滚失败 id=%s dbType=%s err=%v", transactionID, tx.dbType, rollbackErr)
			}
			rollbackCancel()
		}
	}

	closeErr := tx.execer.Close()
	if execErr != nil {
		if closeErr != nil {
			// The finish error is what gets reported; still log the close
			// failure so it is not lost.
			logger.Warnf("SQL 编辑器事务%s失败后关闭会话失败 id=%s dbType=%s err=%v", action, transactionID, tx.dbType, closeErr)
		}
		logger.Error(execErr, "SQL 编辑器事务%s失败id=%s dbType=%s", action, transactionID, tx.dbType)
		return connection.QueryResult{Success: false, Message: fmt.Sprintf("事务%s失败: %v", action, execErr)}
	}
	if closeErr != nil {
		logger.Error(closeErr, "SQL 编辑器事务%s后关闭会话失败id=%s dbType=%s", action, transactionID, tx.dbType)
		return connection.QueryResult{Success: false, Message: fmt.Sprintf("事务%s成功但关闭会话失败: %v", action, closeErr)}
	}
	if commit {
		return connection.QueryResult{Success: true, Message: "事务已提交"}
	}
	return connection.QueryResult{Success: true, Message: "事务已回滚"}
}
// rollbackPendingSQLTransactionsOnShutdown empties the SQL editor transaction
// registry, then rolls back and closes every transaction that was still
// pending. Rollback and close failures are logged as warnings and do not stop
// the remaining transactions. Each textual rollback runs under its own
// sqlEditorTransactionFinishTimeout deadline.
func (a *App) rollbackPendingSQLTransactionsOnShutdown() {
	// Detach all entries under the lock; the slow rollback work happens
	// outside it.
	a.sqlTransactionMu.Lock()
	pending := make([]*managedSQLTransaction, 0, len(a.sqlTransactions))
	for id, tx := range a.sqlTransactions {
		delete(a.sqlTransactions, id)
		if tx == nil {
			continue
		}
		pending = append(pending, tx)
	}
	a.sqlTransactionMu.Unlock()

	// rollbackOne prefers the driver transactor and otherwise replays the
	// stored rollback SQL on the session.
	rollbackOne := func(tx *managedSQLTransaction) {
		ctx, cancel := context.WithTimeout(context.Background(), sqlEditorTransactionFinishTimeout)
		defer cancel()
		var err error
		switch {
		case tx.transactor != nil:
			err = tx.transactor.Rollback()
		case tx.execer != nil && strings.TrimSpace(tx.rollbackSQL) != "":
			_, err = tx.execer.ExecContext(ctx, tx.rollbackSQL)
		}
		if err != nil {
			logger.Warnf("关闭应用时回滚 SQL 编辑器事务失败id=%s dbType=%s err=%v", tx.id, tx.dbType, err)
		}
	}

	for _, tx := range pending {
		rollbackOne(tx)
		if tx.execer == nil {
			continue
		}
		if err := tx.execer.Close(); err != nil {
			logger.Warnf("关闭应用时关闭 SQL 编辑器事务会话失败id=%s dbType=%s err=%v", tx.id, tx.dbType, err)
		}
	}
}