Files
MyGoNavi/internal/app/methods_db_transaction.go
Syngnat fce50b513c 🐛 fix(sql-editor): 修复 Oracle 事务结束并补充 Redis 拓扑提示
- SQL 编辑器:Oracle 托管事务优先使用 transaction provider 完成提交和回滚

- Redis:拆分 Key 浏览工具栏并展示 Cluster/Sentinel 拓扑上下文

- 测试:补充 Oracle 事务结束和 Redis 拓扑头部回归用例
2026-06-12 08:48:08 +08:00

422 lines
14 KiB
Go
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
package app
import (
"context"
"fmt"
"strings"
"time"
"GoNavi-Wails/internal/connection"
"GoNavi-Wails/internal/db"
"GoNavi-Wails/internal/logger"
"github.com/google/uuid"
)
// sqlEditorTransactionFinishTimeout bounds the context used to execute the
// textual COMMIT/ROLLBACK statements that end a pending SQL editor transaction
// (explicit commit/rollback and shutdown cleanup). Half a minute.
const sqlEditorTransactionFinishTimeout = time.Minute / 2
// DBQueryMultiTransactional executes SQL editor DML in a managed transaction.
// The transaction stays open until DBCommitTransaction or DBRollbackTransaction
// is called by the SQL editor UI.
//
// Queries that shouldUseManagedSQLTransaction rejects are delegated unchanged to
// DBQueryMulti. Otherwise the transaction is opened in one of three ways, in
// priority order:
//   - a driver transaction, when the database implements db.TransactionExecerProvider;
//   - a dedicated session relying on the database's implicit transaction, when
//     sqlEditorImplicitTransactionSQL reports one for this database type;
//   - a dedicated session started with the textual BEGIN statement returned by
//     sqlFileBatchTransactionSQL.
//
// On success the open transaction is registered under a fresh TransactionID and
// the result is marked TransactionPending. On failure the transaction is rolled
// back (textual rollbacks are bounded by sqlEditorTransactionFinishTimeout), its
// session is closed, and the error is returned in the result message.
func (a *App) DBQueryMultiTransactional(config connection.ConnectionConfig, dbName string, query string, queryID string) connection.QueryResult {
	runConfig := normalizeRunConfig(config, dbName)
	transactionDBType := resolveDDLDBType(runConfig)
	transactionConfig := runConfig
	transactionConfig.Type = transactionDBType
	if queryID == "" {
		queryID = generateQueryID()
	}
	query = sanitizeSQLForPgLike(transactionDBType, query)
	if !shouldUseManagedSQLTransaction(transactionDBType, query) {
		return a.DBQueryMulti(config, dbName, query, queryID)
	}
	// unsupported is the result returned whenever this data source offers no
	// way to keep a transaction open across RPCs.
	unsupported := func() connection.QueryResult {
		return connection.QueryResult{
			Success: false,
			Message: fmt.Sprintf("当前数据源(%s)不支持 SQL 编辑器托管事务", transactionDBType),
			QueryID: queryID,
		}
	}
	beginSQL, commitSQL, rollbackSQL, hasTextTransaction := sqlFileBatchTransactionSQL(transactionDBType)
	implicitTextTransaction := false
	if implicitCommitSQL, implicitRollbackSQL, ok := sqlEditorImplicitTransactionSQL(transactionDBType); ok {
		// The database starts the transaction implicitly, so no BEGIN is sent;
		// only the statements that end it are needed.
		commitSQL = implicitCommitSQL
		rollbackSQL = implicitRollbackSQL
		hasTextTransaction = true
		implicitTextTransaction = true
	}
	dbInst, err := a.getDatabase(runConfig)
	if err != nil {
		logger.Error(err, "DBQueryMultiTransactional 获取连接失败:%s", formatConnSummary(runConfig))
		return connection.QueryResult{Success: false, Message: err.Error(), QueryID: queryID}
	}
	// ctx scopes this RPC only. It is registered in runningQueries so the
	// execution can be cancelled, and it is cancelled when this function returns.
	ctx, cancel := newQueryExecutionContext(runConfig)
	defer cancel()
	a.queryMu.Lock()
	a.runningQueries[queryID] = queryContext{
		cancel:  cancel,
		started: time.Now(),
	}
	a.queryMu.Unlock()
	defer func() {
		a.queryMu.Lock()
		delete(a.runningQueries, queryID)
		a.queryMu.Unlock()
	}()
	var (
		sessionExecer        db.StatementExecer
		transactor           db.TransactionExecer
		transactionCancel    context.CancelFunc
		startTextTransaction bool
	)
	if provider, ok := dbInst.(db.TransactionExecerProvider); ok {
		// database/sql rolls back a BeginTx transaction when its context is cancelled.
		// SQL editor transactions must outlive the execution RPC and be ended only by
		// explicit commit, rollback, or shutdown cleanup.
		transactionContext := context.Background()
		transactionContext, transactionCancel = context.WithCancel(transactionContext)
		transactionExecer, err := provider.OpenTransactionExecer(transactionContext)
		if err != nil {
			transactionCancel()
			logger.Error(err, "DBQueryMultiTransactional 打开驱动事务失败:%s SQL片段=%q", formatConnSummary(runConfig), sqlSnippet(query))
			return connection.QueryResult{Success: false, Message: err.Error(), QueryID: queryID}
		}
		sessionExecer = transactionExecer
		transactor = transactionExecer
	} else if implicitTextTransaction {
		provider, ok := dbInst.(db.SessionExecerProvider)
		if !ok {
			return unsupported()
		}
		sessionExecer, err = provider.OpenSessionExecer(ctx)
		if err != nil {
			logger.Error(err, "DBQueryMultiTransactional 打开隐式事务会话失败:%s SQL片段=%q", formatConnSummary(runConfig), sqlSnippet(query))
			return connection.QueryResult{Success: false, Message: err.Error(), QueryID: queryID}
		}
	} else {
		if !hasTextTransaction {
			return unsupported()
		}
		provider, ok := dbInst.(db.SessionExecerProvider)
		if !ok {
			return unsupported()
		}
		sessionExecer, err = provider.OpenSessionExecer(ctx)
		if err != nil {
			logger.Error(err, "DBQueryMultiTransactional 打开事务会话失败:%s SQL片段=%q", formatConnSummary(runConfig), sqlSnippet(query))
			return connection.QueryResult{Success: false, Message: err.Error(), QueryID: queryID}
		}
		startTextTransaction = true
	}
	// Until the transaction is handed to the registry below, every early return
	// must release the session and the transaction context.
	closeSession := true
	defer func() {
		if closeSession {
			if err := sessionExecer.Close(); err != nil {
				logger.Warnf("DBQueryMultiTransactional 关闭事务会话失败:%v", err)
			}
			if transactionCancel != nil {
				transactionCancel()
			}
		}
	}()
	if startTextTransaction {
		if _, err := sessionExecer.ExecContext(ctx, beginSQL); err != nil {
			logger.Error(err, "DBQueryMultiTransactional 开启事务失败:%s SQL片段=%q", formatConnSummary(runConfig), sqlSnippet(query))
			return connection.QueryResult{Success: false, Message: err.Error(), QueryID: queryID}
		}
	}
	statements := splitSQLStatements(query)
	resultSets, err := executeManagedSQLTransactionStatements(ctx, sessionExecer, transactionConfig, statements)
	if err != nil {
		// Roll back on a fresh context: ctx may already be cancelled or expired.
		// Bound the wait like the explicit commit/rollback path does, so a hung
		// connection cannot block this RPC forever.
		rollbackCtx, rollbackCancel := context.WithTimeout(context.Background(), sqlEditorTransactionFinishTimeout)
		var rollbackErr error
		if transactor != nil {
			rollbackErr = transactor.Rollback()
		} else if strings.TrimSpace(rollbackSQL) != "" {
			_, rollbackErr = sessionExecer.ExecContext(rollbackCtx, rollbackSQL)
		}
		rollbackCancel()
		if rollbackErr != nil {
			logger.Error(rollbackErr, "DBQueryMultiTransactional 执行失败后回滚失败:%s SQL片段=%q", formatConnSummary(runConfig), sqlSnippet(query))
			err = fmt.Errorf("%v; 回滚失败: %w", err, rollbackErr)
		}
		logger.Error(err, "DBQueryMultiTransactional 执行失败:%s SQL片段=%q", formatConnSummary(runConfig), sqlSnippet(query))
		return connection.QueryResult{Success: false, Message: err.Error(), QueryID: queryID}
	}
	transactionID := "sql-editor-" + uuid.NewString()
	a.sqlTransactionMu.Lock()
	if a.sqlTransactions == nil {
		a.sqlTransactions = make(map[string]*managedSQLTransaction)
	}
	a.sqlTransactions[transactionID] = &managedSQLTransaction{
		id:          transactionID,
		execer:      sessionExecer,
		transactor:  transactor,
		cancel:      transactionCancel,
		dbType:      transactionDBType,
		commitSQL:   commitSQL,
		rollbackSQL: rollbackSQL,
		createdAt:   time.Now(),
	}
	a.sqlTransactionMu.Unlock()
	// Ownership of the session and transaction context now belongs to the
	// registry; finishManagedSQLTransaction or shutdown cleanup releases them.
	closeSession = false
	return connection.QueryResult{
		Success:            true,
		Data:               resultSets,
		QueryID:            queryID,
		TransactionID:      transactionID,
		TransactionPending: true,
	}
}
// executeManagedSQLTransactionStatements runs statements, in order, on the
// session holding an SQL editor managed transaction and returns the collected
// result sets (never nil on success). Blank statements are skipped.
//
// Read-only statements, and statements for which shouldTryQueryResultFirst
// reports true, are first run as queries through the richest query interface the
// session implements. If the query fails, a read-only statement aborts the
// batch; any other statement falls back to ExecContext and contributes a single
// "affectedRows" result set. Every result set carries the 1-based StatementIndex
// of the statement that produced it. The first failing statement aborts the
// batch with an error naming its position.
func executeManagedSQLTransactionStatements(ctx context.Context, session db.StatementExecer, runConfig connection.ConnectionConfig, statements []string) ([]connection.ResultSetData, error) {
	multiMessageQuerier, _ := session.(db.StatementMultiResultQueryMessageExecer)
	multiQuerier, _ := session.(db.StatementMultiResultQueryExecer)
	messageQuerier, _ := session.(db.StatementQueryMessageExecer)
	querier, _ := session.(db.StatementQueryExecer)

	// queryResultSets runs stmt as a query and normalizes whatever the session
	// returned into result sets tagged with statementIndex. Multi-result
	// interfaces take precedence over single-result ones, and message-aware
	// variants over plain ones.
	queryResultSets := func(stmt string, statementIndex int) ([]connection.ResultSetData, error) {
		if multiMessageQuerier != nil || multiQuerier != nil {
			var (
				sets     []connection.ResultSetData
				messages []string
				err      error
			)
			if multiMessageQuerier != nil {
				sets, messages, err = multiMessageQuerier.QueryMultiContextWithMessages(ctx, stmt)
			} else {
				sets, err = multiQuerier.QueryMultiContext(ctx, stmt)
			}
			if err != nil {
				return nil, err
			}
			// Surface server messages even when no result set came back.
			if len(sets) == 0 && len(messages) > 0 {
				sets = []connection.ResultSetData{{
					Rows:     []map[string]interface{}{},
					Columns:  []string{},
					Messages: append([]string(nil), messages...),
				}}
			}
			normalized := make([]connection.ResultSetData, 0, len(sets))
			for _, set := range sets {
				if set.Rows == nil {
					set.Rows = []map[string]interface{}{}
				}
				if set.Columns == nil {
					set.Columns = []string{}
				}
				set.StatementIndex = statementIndex
				normalized = append(normalized, set)
			}
			return normalized, nil
		}

		var (
			rows     []map[string]interface{}
			columns  []string
			messages []string
			err      error
		)
		switch {
		case messageQuerier != nil:
			rows, columns, messages, err = messageQuerier.QueryContextWithMessages(ctx, stmt)
		case querier != nil:
			rows, columns, err = querier.QueryContext(ctx, stmt)
		default:
			err = fmt.Errorf("当前事务会话不支持查询语句")
		}
		if err != nil {
			return nil, err
		}
		// Empty (non-nil) slices keep the UI payload shape stable.
		if rows == nil {
			rows = []map[string]interface{}{}
		}
		if columns == nil {
			columns = []string{}
		}
		return []connection.ResultSetData{{
			Rows:           rows,
			Columns:        columns,
			Messages:       messages,
			StatementIndex: statementIndex,
		}}, nil
	}

	var resultSets []connection.ResultSetData
	for idx, raw := range statements {
		stmt := strings.TrimSpace(raw)
		if stmt == "" {
			continue
		}
		statementIndex := idx + 1
		readOnly := isReadOnlySQLQuery(runConfig.Type, stmt)
		queryFirst := shouldTryQueryResultFirst(runConfig.Type, stmt)
		if readOnly || queryFirst {
			sets, queryErr := queryResultSets(stmt, statementIndex)
			if queryErr == nil {
				resultSets = append(resultSets, sets...)
				continue
			}
			if readOnly {
				return nil, fmt.Errorf("第 %d 条语句执行失败: %w", statementIndex, queryErr)
			}
			// The query attempt was speculative for a non-read statement, so fall
			// back to executing it. NOTE(review): on databases that abort the whole
			// transaction after any failed statement (e.g. PostgreSQL), this Exec
			// would fail too. Confirm shouldTryQueryResultFirst never applies there.
		}
		affected, err := session.ExecContext(ctx, stmt)
		if err != nil {
			return nil, fmt.Errorf("第 %d 条语句执行失败: %w", statementIndex, err)
		}
		resultSets = append(resultSets, connection.ResultSetData{
			Rows:           []map[string]interface{}{{"affectedRows": affected}},
			Columns:        []string{"affectedRows"},
			StatementIndex: statementIndex,
		})
	}
	if resultSets == nil {
		return []connection.ResultSetData{}, nil
	}
	return resultSets, nil
}
// shouldUseManagedSQLTransaction reports whether query should run inside an SQL
// editor managed transaction. It requires at least one batchable write; every
// other non-empty statement must be read-only. A transaction control statement,
// or any statement that is neither read-only nor a batchable write, disqualifies
// the whole query.
func shouldUseManagedSQLTransaction(dbType string, query string) bool {
	writes := 0
	for _, raw := range splitSQLStatements(query) {
		stmt := strings.TrimSpace(raw)
		switch {
		case stmt == "":
			// Blank fragment between separators; ignore it.
		case isSQLTransactionControlStatement(stmt):
			// The user manages the transaction explicitly.
			return false
		case isReadOnlySQLQuery(dbType, stmt):
			// Reads may accompany writes inside the transaction.
		case isBatchableWriteSQLStatement(dbType, stmt):
			writes++
		default:
			return false
		}
	}
	return writes > 0
}
// sqlEditorImplicitTransactionSQL returns the statements that end an SQL editor
// transaction for databases that start transactions implicitly. ok is false,
// with empty statements, for every other database type. dbType is matched
// case-insensitively after trimming surrounding whitespace.
func sqlEditorImplicitTransactionSQL(dbType string) (commitSQL string, rollbackSQL string, ok bool) {
	if normalized := strings.ToLower(strings.TrimSpace(dbType)); normalized != "oracle" {
		return "", "", false
	}
	// Oracle starts a transaction implicitly on the first DML statement.
	// Keeping SQL editor DML on one physical connection avoids the database/sql
	// Tx context lifecycle ending the transaction before the UI commits it.
	return "COMMIT", "ROLLBACK", true
}
// isSQLTransactionControlStatement reports whether stmt explicitly controls a
// transaction. That covers statements led by BEGIN, COMMIT, ROLLBACK, SAVEPOINT
// or RELEASE, and START statements that mention "transaction" anywhere (e.g.
// START TRANSACTION). Keyword comparison assumes leadingSQLKeyword returns a
// lowercased keyword. NOTE(review): BEGIN presumably also matches PL/SQL
// anonymous blocks.
func isSQLTransactionControlStatement(stmt string) bool {
	keyword := leadingSQLKeyword(stmt)
	if keyword == "start" {
		return strings.Contains(strings.ToLower(stmt), "transaction")
	}
	switch keyword {
	case "begin", "commit", "rollback", "savepoint", "release":
		return true
	}
	return false
}
// DBCommitTransaction commits the pending SQL editor transaction identified by
// transactionID (the TransactionID returned by DBQueryMultiTransactional) and
// closes its session.
func (a *App) DBCommitTransaction(transactionID string) connection.QueryResult {
	const commit = true
	return a.finishManagedSQLTransaction(transactionID, commit)
}
// DBRollbackTransaction rolls back the pending SQL editor transaction
// identified by transactionID (the TransactionID returned by
// DBQueryMultiTransactional) and closes its session.
func (a *App) DBRollbackTransaction(transactionID string) connection.QueryResult {
	const commit = false
	return a.finishManagedSQLTransaction(transactionID, commit)
}
// finishManagedSQLTransaction ends the pending SQL editor transaction
// identified by transactionID. It commits when commit is true and rolls back
// otherwise.
//
// The registry entry is removed before any database work, so a transaction can
// be finished at most once. The session is closed whether or not the
// commit/rollback succeeded, and the transaction context is cancelled on
// return. Driver transactions end through their transactor. Text-based
// transactions run their stored COMMIT/ROLLBACK statement, bounded by
// sqlEditorTransactionFinishTimeout.
func (a *App) finishManagedSQLTransaction(transactionID string, commit bool) connection.QueryResult {
	id := strings.TrimSpace(transactionID)
	if id == "" {
		return connection.QueryResult{Success: false, Message: "事务 ID 不能为空"}
	}

	a.sqlTransactionMu.Lock()
	pending, found := a.sqlTransactions[id]
	if found {
		delete(a.sqlTransactions, id)
	}
	a.sqlTransactionMu.Unlock()
	if !found || pending == nil || pending.execer == nil {
		return connection.QueryResult{Success: false, Message: "事务不存在或已结束"}
	}
	if pending.cancel != nil {
		defer pending.cancel()
	}

	action, statement, doneMessage := "回滚", pending.rollbackSQL, "事务已回滚"
	if commit {
		action, statement, doneMessage = "提交", pending.commitSQL, "事务已提交"
	}

	finishCtx, stopFinish := context.WithTimeout(context.Background(), sqlEditorTransactionFinishTimeout)
	defer stopFinish()

	var endErr error
	switch {
	case pending.transactor != nil && commit:
		endErr = pending.transactor.Commit()
	case pending.transactor != nil:
		endErr = pending.transactor.Rollback()
	case strings.TrimSpace(statement) != "":
		_, endErr = pending.execer.ExecContext(finishCtx, statement)
	}
	// Always release the session, even when ending the transaction failed.
	closeErr := pending.execer.Close()

	if endErr != nil {
		logger.Error(endErr, "SQL 编辑器事务%s失败id=%s dbType=%s", action, id, pending.dbType)
		return connection.QueryResult{Success: false, Message: fmt.Sprintf("事务%s失败: %v", action, endErr)}
	}
	if closeErr != nil {
		logger.Error(closeErr, "SQL 编辑器事务%s后关闭会话失败id=%s dbType=%s", action, id, pending.dbType)
		return connection.QueryResult{Success: false, Message: fmt.Sprintf("事务%s成功但关闭会话失败: %v", action, closeErr)}
	}
	return connection.QueryResult{Success: true, Message: doneMessage}
}
// rollbackPendingSQLTransactionsOnShutdown rolls back every SQL editor
// transaction still pending when the application shuts down, then cancels its
// transaction context and closes its session. Failures are logged as warnings
// and do not stop the remaining transactions from being cleaned up.
func (a *App) rollbackPendingSQLTransactionsOnShutdown() {
	// Detach every registered transaction under the lock, then do the
	// (potentially slow) database work after releasing it.
	a.sqlTransactionMu.Lock()
	pending := make([]*managedSQLTransaction, 0, len(a.sqlTransactions))
	for id, tx := range a.sqlTransactions {
		delete(a.sqlTransactions, id)
		if tx == nil {
			continue
		}
		pending = append(pending, tx)
	}
	a.sqlTransactionMu.Unlock()

	// rollbackOne ends a single detached transaction and releases its resources.
	rollbackOne := func(tx *managedSQLTransaction) {
		// The timeout only bounds the textual ROLLBACK path; the transactor's
		// Rollback takes no context.
		rollbackCtx, stopRollback := context.WithTimeout(context.Background(), sqlEditorTransactionFinishTimeout)
		var rollbackErr error
		switch {
		case tx.transactor != nil:
			rollbackErr = tx.transactor.Rollback()
		case tx.execer != nil && strings.TrimSpace(tx.rollbackSQL) != "":
			_, rollbackErr = tx.execer.ExecContext(rollbackCtx, tx.rollbackSQL)
		}
		if rollbackErr != nil {
			logger.Warnf("关闭应用时回滚 SQL 编辑器事务失败id=%s dbType=%s err=%v", tx.id, tx.dbType, rollbackErr)
		}
		stopRollback()

		if tx.cancel != nil {
			tx.cancel()
		}
		if tx.execer == nil {
			return
		}
		if err := tx.execer.Close(); err != nil {
			logger.Warnf("关闭应用时关闭 SQL 编辑器事务会话失败id=%s dbType=%s err=%v", tx.id, tx.dbType, err)
		}
	}

	for _, tx := range pending {
		rollbackOne(tx)
	}
}