Files
MyGoNavi/internal/app/methods_db_transaction.go
Syngnat 61d71cf1d0 feat(editor): 支持 SQL 编辑器增删改事务提交
- 为 SQL 编辑器 DML 新增后端托管事务会话和提交回滚接口

- 增加手动提交与自动提交设置,并显示待提交状态

- 补充前后端事务执行、提交、回滚和自动提交测试
2026-06-10 17:18:34 +08:00

335 lines
11 KiB
Go
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
package app
import (
"context"
"fmt"
"strings"
"time"
"GoNavi-Wails/internal/connection"
"GoNavi-Wails/internal/db"
"GoNavi-Wails/internal/logger"
"github.com/google/uuid"
)
// sqlEditorTransactionFinishTimeout bounds how long a single COMMIT or ROLLBACK
// statement issued for a managed SQL editor transaction may run.
const sqlEditorTransactionFinishTimeout time.Duration = 30 * time.Second
// DBQueryMultiTransactional executes SQL editor DML in a managed transaction.
// The transaction stays open until DBCommitTransaction or DBRollbackTransaction
// is called by the SQL editor UI.
//
// Queries that do not qualify for a managed transaction (as decided by
// shouldUseManagedSQLTransaction) are delegated unchanged to DBQueryMulti.
// On success the result carries a TransactionID and TransactionPending=true,
// and the session is kept in a.sqlTransactions until it is finished. If any
// statement fails after BEGIN, the transaction is rolled back and the session
// is closed before the failure is returned.
func (a *App) DBQueryMultiTransactional(config connection.ConnectionConfig, dbName string, query string, queryID string) connection.QueryResult {
	runConfig := normalizeRunConfig(config, dbName)
	if queryID == "" {
		queryID = generateQueryID()
	}
	query = sanitizeSQLForPgLike(resolveDDLDBType(config), query)
	if !shouldUseManagedSQLTransaction(runConfig.Type, query) {
		return a.DBQueryMulti(config, dbName, query, queryID)
	}

	// unsupported is returned when the data source cannot host a managed
	// transaction (no transaction SQL for its type, or no session support).
	unsupported := func() connection.QueryResult {
		return connection.QueryResult{
			Success: false,
			Message: fmt.Sprintf("当前数据源(%s)不支持 SQL 编辑器托管事务", runConfig.Type),
			QueryID: queryID,
		}
	}
	// failed converts err into a failed result for this query.
	failed := func(err error) connection.QueryResult {
		return connection.QueryResult{Success: false, Message: err.Error(), QueryID: queryID}
	}

	beginSQL, commitSQL, rollbackSQL, ok := sqlFileBatchTransactionSQL(runConfig.Type)
	if !ok {
		return unsupported()
	}

	dbInst, err := a.getDatabase(runConfig)
	if err != nil {
		logger.Error(err, "DBQueryMultiTransactional 获取连接失败:%s", formatConnSummary(runConfig))
		return failed(err)
	}
	provider, ok := dbInst.(db.SessionExecerProvider)
	if !ok {
		return unsupported()
	}

	// Register the run under queryID so it can be cancelled while statements
	// execute; the registration is removed when this call returns, even though
	// the transaction itself may stay open.
	ctx, cancel := newQueryExecutionContext(runConfig)
	defer cancel()
	a.queryMu.Lock()
	a.runningQueries[queryID] = queryContext{
		cancel:  cancel,
		started: time.Now(),
	}
	a.queryMu.Unlock()
	defer func() {
		a.queryMu.Lock()
		delete(a.runningQueries, queryID)
		a.queryMu.Unlock()
	}()

	// NOTE(review): ctx is cancelled when this function returns while the
	// session outlives it; this assumes OpenSessionExecer does not tie the
	// session's lifetime to ctx — confirm against the provider implementations.
	sessionExecer, err := provider.OpenSessionExecer(ctx)
	if err != nil {
		logger.Error(err, "DBQueryMultiTransactional 打开事务会话失败:%s SQL片段=%q", formatConnSummary(runConfig), sqlSnippet(query))
		return failed(err)
	}
	// The session is closed on every early return; ownership moves to the
	// transaction registry (closeSession = false) only on success.
	closeSession := true
	defer func() {
		if !closeSession {
			return
		}
		if err := sessionExecer.Close(); err != nil {
			logger.Warnf("DBQueryMultiTransactional 关闭事务会话失败:%v", err)
		}
	}()

	if _, err := sessionExecer.ExecContext(ctx, beginSQL); err != nil {
		logger.Error(err, "DBQueryMultiTransactional 开启事务失败:%s SQL片段=%q", formatConnSummary(runConfig), sqlSnippet(query))
		return failed(err)
	}

	statements := splitSQLStatements(query)
	resultSets, err := executeManagedSQLTransactionStatements(ctx, sessionExecer, runConfig, statements)
	if err != nil {
		// Roll back on a fresh, bounded context: ctx may already be cancelled
		// or expired (possibly the reason execution failed), and an unresponsive
		// server must not block this call indefinitely.
		rollbackCtx, rollbackCancel := context.WithTimeout(context.Background(), sqlEditorTransactionFinishTimeout)
		_, rollbackErr := sessionExecer.ExecContext(rollbackCtx, rollbackSQL)
		rollbackCancel()
		if rollbackErr != nil {
			logger.Error(rollbackErr, "DBQueryMultiTransactional 执行失败后回滚失败:%s SQL片段=%q", formatConnSummary(runConfig), sqlSnippet(query))
			err = fmt.Errorf("%v; 回滚失败: %w", err, rollbackErr)
		}
		logger.Error(err, "DBQueryMultiTransactional 执行失败:%s SQL片段=%q", formatConnSummary(runConfig), sqlSnippet(query))
		return failed(err)
	}

	transactionID := "sql-editor-" + uuid.NewString()
	a.sqlTransactionMu.Lock()
	if a.sqlTransactions == nil {
		a.sqlTransactions = make(map[string]*managedSQLTransaction)
	}
	a.sqlTransactions[transactionID] = &managedSQLTransaction{
		id:          transactionID,
		execer:      sessionExecer,
		dbType:      runConfig.Type,
		commitSQL:   commitSQL,
		rollbackSQL: rollbackSQL,
		createdAt:   time.Now(),
	}
	a.sqlTransactionMu.Unlock()
	closeSession = false

	return connection.QueryResult{
		Success:            true,
		Data:               resultSets,
		QueryID:            queryID,
		TransactionID:      transactionID,
		TransactionPending: true,
	}
}
// executeManagedSQLTransactionStatements runs statements one at a time on an
// already-open transaction session and collects their result sets, each
// tagged with the statement's 1-based position in statements.
//
// A statement that is read-only, or for which shouldTryQueryResultFirst holds,
// is first sent through the richest query interface the session implements:
// multi-result with messages, then multi-result, then single result with
// messages, then single result. If that query fails, a read-only statement
// aborts the run; any other statement is re-run with ExecContext and reported
// as a single "affectedRows" result.
//
// Blank statements are skipped. On success the returned slice is never nil.
// On the first failing statement it returns nil and an error that names the
// statement's position.
func executeManagedSQLTransactionStatements(ctx context.Context, session db.StatementExecer, runConfig connection.ConnectionConfig, statements []string) ([]connection.ResultSetData, error) {
	multiMsgQuerier, _ := session.(db.StatementMultiResultQueryMessageExecer)
	multiQuerier, _ := session.(db.StatementMultiResultQueryExecer)
	singleMsgQuerier, _ := session.(db.StatementQueryMessageExecer)
	singleQuerier, _ := session.(db.StatementQueryExecer)

	var collected []connection.ResultSetData
	for i := range statements {
		stmt := strings.TrimSpace(statements[i])
		if stmt == "" {
			continue
		}
		position := i + 1
		readOnly := isReadOnlySQLQuery(runConfig.Type, stmt)
		queryFirst := shouldTryQueryResultFirst(runConfig.Type, stmt)

		if readOnly || queryFirst {
			var (
				rows        []map[string]interface{}
				cols        []string
				msgs        []string
				sets        []connection.ResultSetData
				multiResult bool
				queryErr    error
			)
			switch {
			case multiMsgQuerier != nil:
				sets, msgs, queryErr = multiMsgQuerier.QueryMultiContextWithMessages(ctx, stmt)
				multiResult = true
			case multiQuerier != nil:
				sets, queryErr = multiQuerier.QueryMultiContext(ctx, stmt)
				multiResult = true
			case singleMsgQuerier != nil:
				rows, cols, msgs, queryErr = singleMsgQuerier.QueryContextWithMessages(ctx, stmt)
			case singleQuerier != nil:
				rows, cols, queryErr = singleQuerier.QueryContext(ctx, stmt)
			default:
				queryErr = fmt.Errorf("当前事务会话不支持查询语句")
			}

			switch {
			case queryErr == nil && multiResult:
				// A statement that produced only messages still yields one empty
				// result set so the messages are not lost.
				if len(sets) == 0 && len(msgs) > 0 {
					sets = []connection.ResultSetData{{
						Rows:     []map[string]interface{}{},
						Columns:  []string{},
						Messages: append([]string(nil), msgs...),
					}}
				}
				for _, rs := range sets {
					if rs.Rows == nil {
						rs.Rows = []map[string]interface{}{}
					}
					if rs.Columns == nil {
						rs.Columns = []string{}
					}
					rs.StatementIndex = position
					collected = append(collected, rs)
				}
				continue
			case queryErr == nil:
				if rows == nil {
					rows = make([]map[string]interface{}, 0)
				}
				if cols == nil {
					cols = []string{}
				}
				collected = append(collected, connection.ResultSetData{
					Rows:           rows,
					Columns:        cols,
					Messages:       msgs,
					StatementIndex: position,
				})
				continue
			case readOnly:
				return nil, fmt.Errorf("第 %d 条语句执行失败: %w", position, queryErr)
			}
			// The query attempt failed for a statement that is not read-only:
			// fall through and execute it as a plain write instead.
		}

		affected, err := session.ExecContext(ctx, stmt)
		if err != nil {
			return nil, fmt.Errorf("第 %d 条语句执行失败: %w", position, err)
		}
		collected = append(collected, connection.ResultSetData{
			Rows:           []map[string]interface{}{{"affectedRows": affected}},
			Columns:        []string{"affectedRows"},
			StatementIndex: position,
		})
	}

	if collected == nil {
		collected = []connection.ResultSetData{}
	}
	return collected, nil
}
// shouldUseManagedSQLTransaction reports whether query should run inside a
// managed SQL editor transaction. That is the case only when every non-blank
// statement is either read-only or a batchable write, at least one of them is
// a batchable write, and none of them is an explicit transaction-control
// statement.
func shouldUseManagedSQLTransaction(dbType string, query string) bool {
	sawWrite := false
	for _, raw := range splitSQLStatements(query) {
		stmt := strings.TrimSpace(raw)
		switch {
		case stmt == "":
			// Nothing to classify.
		case isSQLTransactionControlStatement(stmt):
			// The user manages the transaction explicitly; stay out of the way.
			return false
		case isReadOnlySQLQuery(dbType, stmt):
			// Reads may share the transaction but do not require one.
		case isBatchableWriteSQLStatement(dbType, stmt):
			sawWrite = true
		default:
			return false
		}
	}
	return sawWrite
}
// isSQLTransactionControlStatement reports whether stmt explicitly controls
// transaction boundaries. Such a statement would end or reshape the managed
// SQL editor transaction without the UI knowing, so a query containing one is
// not run in a managed transaction.
//
// Recognized leading keywords (compared in lower case, as returned by
// leadingSQLKeyword):
//   - begin, commit, rollback, savepoint, release: common across dialects
//   - end, abort: PostgreSQL synonyms for COMMIT and ROLLBACK
//   - save: SQL Server SAVE TRAN[SACTION], the T-SQL form of SAVEPOINT
//   - start: only when the statement mentions TRANSACTION (START TRANSACTION)
//
// Over-matching is safe: it only routes the query to the non-managed path.
func isSQLTransactionControlStatement(stmt string) bool {
	switch leadingSQLKeyword(stmt) {
	case "begin", "commit", "rollback", "savepoint", "release":
		return true
	case "end", "abort":
		return true
	case "save":
		return true
	case "start":
		return strings.Contains(strings.ToLower(stmt), "transaction")
	default:
		return false
	}
}
// DBCommitTransaction commits the pending SQL editor transaction identified
// by transactionID and closes its session.
func (a *App) DBCommitTransaction(transactionID string) connection.QueryResult {
	const commit = true
	return a.finishManagedSQLTransaction(transactionID, commit)
}
// DBRollbackTransaction rolls back the pending SQL editor transaction
// identified by transactionID and closes its session.
func (a *App) DBRollbackTransaction(transactionID string) connection.QueryResult {
	const commit = false
	return a.finishManagedSQLTransaction(transactionID, commit)
}
// finishManagedSQLTransaction ends the managed SQL editor transaction
// registered under transactionID. It sends the stored COMMIT statement
// (commit=true) or ROLLBACK statement (commit=false), each attempt bounded by
// sqlEditorTransactionFinishTimeout, and then closes the session.
//
// The transaction is removed from the registry before the statement is sent,
// so an ID can be finished at most once; later calls report "事务不存在或已结束".
// The returned QueryResult carries a user-facing message describing the
// outcome.
func (a *App) finishManagedSQLTransaction(transactionID string, commit bool) connection.QueryResult {
	transactionID = strings.TrimSpace(transactionID)
	if transactionID == "" {
		return connection.QueryResult{Success: false, Message: "事务 ID 不能为空"}
	}

	a.sqlTransactionMu.Lock()
	tx, ok := a.sqlTransactions[transactionID]
	if ok {
		delete(a.sqlTransactions, transactionID)
	}
	a.sqlTransactionMu.Unlock()
	if !ok || tx == nil || tx.execer == nil {
		return connection.QueryResult{Success: false, Message: "事务不存在或已结束"}
	}

	action := "回滚"
	sqlText := tx.rollbackSQL
	if commit {
		action = "提交"
		sqlText = tx.commitSQL
	}

	var execErr error
	if strings.TrimSpace(sqlText) != "" {
		ctx, cancel := context.WithTimeout(context.Background(), sqlEditorTransactionFinishTimeout)
		_, execErr = tx.execer.ExecContext(ctx, sqlText)
		cancel()
	}

	// A failed COMMIT does not necessarily end the transaction (SQLite, for
	// example, keeps it active after SQLITE_BUSY). Roll back best-effort so the
	// session is not closed while still inside an open transaction.
	if commit && execErr != nil && strings.TrimSpace(tx.rollbackSQL) != "" {
		ctx, cancel := context.WithTimeout(context.Background(), sqlEditorTransactionFinishTimeout)
		if _, rollbackErr := tx.execer.ExecContext(ctx, tx.rollbackSQL); rollbackErr != nil {
			logger.Warnf("SQL 编辑器事务提交失败后回滚失败: id=%s dbType=%s err=%v", transactionID, tx.dbType, rollbackErr)
		}
		cancel()
	}

	closeErr := tx.execer.Close()
	if execErr != nil {
		// The statement failure is what gets reported, but a close failure must
		// not vanish silently either.
		if closeErr != nil {
			logger.Warnf("SQL 编辑器事务%s失败后关闭会话失败: id=%s dbType=%s err=%v", action, transactionID, tx.dbType, closeErr)
		}
		logger.Error(execErr, "SQL 编辑器事务%s失败: id=%s dbType=%s", action, transactionID, tx.dbType)
		return connection.QueryResult{Success: false, Message: fmt.Sprintf("事务%s失败: %v", action, execErr)}
	}
	if closeErr != nil {
		logger.Error(closeErr, "SQL 编辑器事务%s后关闭会话失败: id=%s dbType=%s", action, transactionID, tx.dbType)
		return connection.QueryResult{Success: false, Message: fmt.Sprintf("事务%s成功但关闭会话失败: %v", action, closeErr)}
	}
	if commit {
		return connection.QueryResult{Success: true, Message: "事务已提交"}
	}
	return connection.QueryResult{Success: true, Message: "事务已回滚"}
}
// rollbackPendingSQLTransactionsOnShutdown empties the managed SQL editor
// transaction registry. For every transaction that was still pending, it
// issues the stored rollback statement, bounded by
// sqlEditorTransactionFinishTimeout, and then closes the session.
//
// Failures are only logged as warnings so the remaining transactions are
// still handled. Transactions are processed one after another, so shutdown may
// wait up to one timeout per pending transaction.
func (a *App) rollbackPendingSQLTransactionsOnShutdown() {
	a.sqlTransactionMu.Lock()
	drained := make([]*managedSQLTransaction, 0, len(a.sqlTransactions))
	for id := range a.sqlTransactions {
		if entry := a.sqlTransactions[id]; entry != nil {
			drained = append(drained, entry)
		}
		delete(a.sqlTransactions, id)
	}
	a.sqlTransactionMu.Unlock()

	for _, entry := range drained {
		canRollback := entry.execer != nil && strings.TrimSpace(entry.rollbackSQL) != ""
		ctx, cancel := context.WithTimeout(context.Background(), sqlEditorTransactionFinishTimeout)
		if canRollback {
			if _, err := entry.execer.ExecContext(ctx, entry.rollbackSQL); err != nil {
				logger.Warnf("关闭应用时回滚 SQL 编辑器事务失败id=%s dbType=%s err=%v", entry.id, entry.dbType, err)
			}
		}
		cancel()

		if entry.execer == nil {
			continue
		}
		if err := entry.execer.Close(); err != nil {
			logger.Warnf("关闭应用时关闭 SQL 编辑器事务会话失败id=%s dbType=%s err=%v", entry.id, entry.dbType, err)
		}
	}
}