package app

import (
	"context"
	"fmt"
	"strings"
	"time"

	"GoNavi-Wails/internal/connection"
	"GoNavi-Wails/internal/db"
	"GoNavi-Wails/internal/logger"

	"github.com/google/uuid"
)

// sqlEditorTransactionFinishTimeout bounds the text COMMIT/ROLLBACK statements
// that end a managed SQL editor transaction, so a stalled connection cannot
// block the caller forever.
const sqlEditorTransactionFinishTimeout = 30 * time.Second

// DBQueryMultiTransactional executes SQL editor DML in a managed transaction.
// The transaction stays open until DBCommitTransaction or DBRollbackTransaction
// is called by the SQL editor UI.
//
// When shouldUseManagedSQLTransaction rejects the script, the call is delegated
// to DBQueryMulti unchanged. Otherwise the statements run on a dedicated
// session: a driver-level transaction when the database implements
// db.TransactionExecerProvider, or a plain session wrapped in textual
// begin/commit/rollback SQL when it implements db.SessionExecerProvider.
//
// On success the open transaction is stored in a.sqlTransactions and the result
// carries its TransactionID with TransactionPending set. On any failure the
// transaction is rolled back, the session is closed, and a result with
// Success=false is returned.
func (a *App) DBQueryMultiTransactional(config connection.ConnectionConfig, dbName string, query string, queryID string) connection.QueryResult {
	runConfig := normalizeRunConfig(config, dbName)
	if queryID == "" {
		queryID = generateQueryID()
	}
	query = sanitizeSQLForPgLike(resolveDDLDBType(config), query)
	if !shouldUseManagedSQLTransaction(runConfig.Type, query) {
		return a.DBQueryMulti(config, dbName, query, queryID)
	}

	beginSQL, commitSQL, rollbackSQL, hasTextTransaction := sqlFileBatchTransactionSQL(runConfig.Type)

	dbInst, err := a.getDatabase(runConfig)
	if err != nil {
		logger.Error(err, "DBQueryMultiTransactional 获取连接失败:%s", formatConnSummary(runConfig))
		return connection.QueryResult{Success: false, Message: err.Error(), QueryID: queryID}
	}

	// ctx only covers the statements executed by this call; it is cancelled on
	// return and can be cancelled earlier through a.runningQueries.
	ctx, cancel := newQueryExecutionContext(runConfig)
	defer cancel()

	a.queryMu.Lock()
	a.runningQueries[queryID] = queryContext{
		cancel:  cancel,
		started: time.Now(),
	}
	a.queryMu.Unlock()
	defer func() {
		a.queryMu.Lock()
		delete(a.runningQueries, queryID)
		a.queryMu.Unlock()
	}()

	var (
		sessionExecer        db.StatementExecer
		transactor           db.TransactionExecer
		transactionCancel    context.CancelFunc
		startTextTransaction bool
	)

	if provider, ok := dbInst.(db.TransactionExecerProvider); ok {
		// The driver transaction must outlive this call, so it gets its own
		// context (derived from the app context when one is set) instead of the
		// per-query ctx that is cancelled on return.
		transactionContext := context.Background()
		if a.ctx != nil {
			transactionContext = a.ctx
		}
		transactionContext, transactionCancel = context.WithCancel(transactionContext)

		transactionExecer, openErr := provider.OpenTransactionExecer(transactionContext)
		if openErr != nil {
			transactionCancel()
			logger.Error(openErr, "DBQueryMultiTransactional 打开驱动事务失败:%s SQL片段=%q",
				formatConnSummary(runConfig), sqlSnippet(query))
			return connection.QueryResult{Success: false, Message: openErr.Error(), QueryID: queryID}
		}
		sessionExecer = transactionExecer
		transactor = transactionExecer
	} else {
		// Fallback: a plain session driven by textual transaction statements.
		// Both the SQL text and a session provider are required.
		sessionProvider, isSessionProvider := dbInst.(db.SessionExecerProvider)
		if !hasTextTransaction || !isSessionProvider {
			return connection.QueryResult{
				Success: false,
				Message: fmt.Sprintf("当前数据源(%s)不支持 SQL 编辑器托管事务", runConfig.Type),
				QueryID: queryID,
			}
		}
		// NOTE(review): the session is opened with the per-query ctx, which is
		// cancelled when this function returns — confirm the provider does not
		// tie the session lifetime to that context.
		sessionExecer, err = sessionProvider.OpenSessionExecer(ctx)
		if err != nil {
			logger.Error(err, "DBQueryMultiTransactional 打开事务会话失败:%s SQL片段=%q", formatConnSummary(runConfig), sqlSnippet(query))
			return connection.QueryResult{Success: false, Message: err.Error(), QueryID: queryID}
		}
		startTextTransaction = true
	}

	// Until the transaction is handed over to a.sqlTransactions, every return
	// path must release the session and the transaction context.
	closeSession := true
	defer func() {
		if !closeSession {
			return
		}
		if closeErr := sessionExecer.Close(); closeErr != nil {
			logger.Warnf("DBQueryMultiTransactional 关闭事务会话失败:%v", closeErr)
		}
		if transactionCancel != nil {
			transactionCancel()
		}
	}()

	if startTextTransaction {
		if _, err := sessionExecer.ExecContext(ctx, beginSQL); err != nil {
			logger.Error(err, "DBQueryMultiTransactional 开启事务失败:%s SQL片段=%q", formatConnSummary(runConfig), sqlSnippet(query))
			return connection.QueryResult{Success: false, Message: err.Error(), QueryID: queryID}
		}
	}

	statements := splitSQLStatements(query)
	resultSets, err := executeManagedSQLTransactionStatements(ctx, sessionExecer, runConfig, statements)
	if err != nil {
		var rollbackErr error
		if transactor != nil {
			rollbackErr = transactor.Rollback()
		} else if strings.TrimSpace(rollbackSQL) != "" {
			// ctx may already be cancelled (that can be why execution failed),
			// so roll back on a fresh context. It is bounded by the same timeout
			// as the explicit commit/rollback path so a dead connection cannot
			// hang this call.
			rollbackCtx, rollbackCancel := context.WithTimeout(context.Background(), sqlEditorTransactionFinishTimeout)
			_, rollbackErr = sessionExecer.ExecContext(rollbackCtx, rollbackSQL)
			rollbackCancel()
		}
		if rollbackErr != nil {
			logger.Error(rollbackErr, "DBQueryMultiTransactional 执行失败后回滚失败:%s SQL片段=%q", formatConnSummary(runConfig), sqlSnippet(query))
			err = fmt.Errorf("%v;回滚失败: %w", err, rollbackErr)
		}
		logger.Error(err, "DBQueryMultiTransactional 执行失败:%s SQL片段=%q", formatConnSummary(runConfig), sqlSnippet(query))
		return connection.QueryResult{Success: false, Message: err.Error(), QueryID: queryID}
	}

	transactionID := "sql-editor-" + uuid.NewString()
	a.sqlTransactionMu.Lock()
	if a.sqlTransactions == nil {
		a.sqlTransactions = make(map[string]*managedSQLTransaction)
	}
	a.sqlTransactions[transactionID] = &managedSQLTransaction{
		id:          transactionID,
		execer:      sessionExecer,
		transactor:  transactor,
		cancel:      transactionCancel,
		dbType:      runConfig.Type,
		commitSQL:   commitSQL,
		rollbackSQL: rollbackSQL,
		createdAt:   time.Now(),
	}
	a.sqlTransactionMu.Unlock()

	// Ownership of the session has moved to the registry; the deferred cleanup
	// must leave it open.
	closeSession = false

	return connection.QueryResult{
		Success:            true,
		Data:               resultSets,
		QueryID:            queryID,
		TransactionID:      transactionID,
		TransactionPending: true,
	}
}

// executeManagedSQLTransactionStatements runs statements in order on session
// and returns one or more result sets per executed statement, each tagged with
// the statement's 1-based position in statements (blank entries are skipped but
// still counted).
//
// Statements classified as read-only, or that shouldTryQueryResultFirst selects,
// are run through the richest query interface the session implements
// (multi-result with messages, multi-result, single result with messages,
// single result). Everything else, and a non-read statement whose query attempt
// failed, is run with ExecContext and reported as a single "affectedRows" row.
//
// The first failing statement aborts the run and its error is returned wrapped
// with the statement position. The returned slice is never nil on success.
func executeManagedSQLTransactionStatements(ctx context.Context, session db.StatementExecer, runConfig connection.ConnectionConfig, statements []string) ([]connection.ResultSetData, error) {
	var resultSets []connection.ResultSetData

	// Optional capabilities of the session; a nil value means "not supported".
	sessionQueryTarget, _ := session.(db.StatementQueryExecer)
	sessionQueryMessageTarget, _ := session.(db.StatementQueryMessageExecer)
	sessionMultiQueryTarget, _ := session.(db.StatementMultiResultQueryExecer)
	sessionMultiQueryMessageTarget, _ := session.(db.StatementMultiResultQueryMessageExecer)

	for idx, stmt := range statements {
		stmt = strings.TrimSpace(stmt)
		if stmt == "" {
			continue
		}

		isReadStmt := isReadOnlySQLQuery(runConfig.Type, stmt)
		tryQueryStmtFirst := shouldTryQueryResultFirst(runConfig.Type, stmt)

		if isReadStmt || tryQueryStmtFirst {
			var (
				data             []map[string]interface{}
				columns          []string
				messages         []string
				statementResults []connection.ResultSetData
				usedMultiResult  bool
				err              error
			)

			switch {
			case sessionMultiQueryMessageTarget != nil:
				statementResults, messages, err = sessionMultiQueryMessageTarget.QueryMultiContextWithMessages(ctx, stmt)
				usedMultiResult = true
			case sessionMultiQueryTarget != nil:
				statementResults, err = sessionMultiQueryTarget.QueryMultiContext(ctx, stmt)
				usedMultiResult = true
			case sessionQueryMessageTarget != nil:
				data, columns, messages, err = sessionQueryMessageTarget.QueryContextWithMessages(ctx, stmt)
			case sessionQueryTarget != nil:
				data, columns, err = sessionQueryTarget.QueryContext(ctx, stmt)
			default:
				err = fmt.Errorf("当前事务会话不支持查询语句")
			}

			if err == nil {
				if usedMultiResult {
					// Messages without any result set still need a carrier so
					// they are not dropped.
					if len(statementResults) == 0 && len(messages) > 0 {
						statementResults = []connection.ResultSetData{{
							Rows:     []map[string]interface{}{},
							Columns:  []string{},
							Messages: append([]string(nil), messages...),
						}}
					}
					for _, statementResult := range statementResults {
						// Replace nil slices with empty ones before handing the
						// result set back.
						if statementResult.Rows == nil {
							statementResult.Rows = []map[string]interface{}{}
						}
						if statementResult.Columns == nil {
							statementResult.Columns = []string{}
						}
						statementResult.StatementIndex = idx + 1
						resultSets = append(resultSets, statementResult)
					}
					continue
				}

				if data == nil {
					data = make([]map[string]interface{}, 0)
				}
				if columns == nil {
					columns = []string{}
				}
				resultSets = append(resultSets, connection.ResultSetData{
					Rows:           data,
					Columns:        columns,
					Messages:       messages,
					StatementIndex: idx + 1,
				})
				continue
			}

			if isReadStmt {
				return nil, fmt.Errorf("第 %d 条语句执行失败: %w", idx+1, err)
			}
			// A non-read statement that was only tried as a query falls through
			// to ExecContext below.
			// NOTE(review): on databases that mark a transaction as failed after
			// any statement error, this retry presumably cannot succeed — verify
			// against the supported drivers.
		}

		affected, err := session.ExecContext(ctx, stmt)
		if err != nil {
			return nil, fmt.Errorf("第 %d 条语句执行失败: %w", idx+1, err)
		}
		resultSets = append(resultSets, connection.ResultSetData{
			Rows:           []map[string]interface{}{{"affectedRows": affected}},
			Columns:        []string{"affectedRows"},
			StatementIndex: idx + 1,
		})
	}

	if resultSets == nil {
		resultSets = []connection.ResultSetData{}
	}
	return resultSets, nil
}

// shouldUseManagedSQLTransaction reports whether query should run inside a
// managed SQL editor transaction. It returns true only when the script contains
// at least one batchable write statement and every other non-blank statement is
// either read-only or a batchable write. Any explicit transaction-control
// statement, or any statement of another kind, makes it return false.
func shouldUseManagedSQLTransaction(dbType string, query string) bool {
	hasManagedWrite := false
	for _, stmt := range splitSQLStatements(query) {
		stmt = strings.TrimSpace(stmt)
		switch {
		case stmt == "":
			// Blank fragments carry no information.
		case isSQLTransactionControlStatement(stmt):
			// The script manages its own transaction.
			return false
		case isReadOnlySQLQuery(dbType, stmt):
			// Reads are allowed but do not justify a transaction on their own.
		case isBatchableWriteSQLStatement(dbType, stmt):
			hasManagedWrite = true
		default:
			return false
		}
	}
	return hasManagedWrite
}

// isSQLTransactionControlStatement reports whether stmt starts with a
// transaction-control keyword: begin, commit, rollback, savepoint, release, or
// start when the statement also contains "transaction" (case-insensitive).
func isSQLTransactionControlStatement(stmt string) bool {
	switch leadingSQLKeyword(stmt) {
	case "begin", "commit", "rollback", "savepoint", "release":
		return true
	case "start":
		return strings.Contains(strings.ToLower(stmt), "transaction")
	default:
		return false
	}
}

// DBCommitTransaction commits the pending SQL editor transaction identified by
// transactionID and closes its session.
func (a *App) DBCommitTransaction(transactionID string) connection.QueryResult {
	return a.finishManagedSQLTransaction(transactionID, true)
}

// DBRollbackTransaction rolls back the pending SQL editor transaction
// identified by transactionID and closes its session.
func (a *App) DBRollbackTransaction(transactionID string) connection.QueryResult {
	return a.finishManagedSQLTransaction(transactionID, false)
}

// finishManagedSQLTransaction removes the transaction from a.sqlTransactions,
// commits it (commit=true) or rolls it back (commit=false), and closes its
// session. The entry is removed before the commit/rollback is attempted, so a
// transaction can be finished at most once: a second call with the same ID
// reports that the transaction does not exist.
//
// The returned result has Success=false when the ID is blank or unknown, when
// the commit/rollback fails, or when it succeeds but closing the session fails.
func (a *App) finishManagedSQLTransaction(transactionID string, commit bool) connection.QueryResult {
	transactionID = strings.TrimSpace(transactionID)
	if transactionID == "" {
		return connection.QueryResult{Success: false, Message: "事务 ID 不能为空"}
	}

	a.sqlTransactionMu.Lock()
	tx, ok := a.sqlTransactions[transactionID]
	if ok {
		delete(a.sqlTransactions, transactionID)
	}
	a.sqlTransactionMu.Unlock()

	// The entry is gone from the registry, so this is the last chance to
	// release its context — including when the entry turns out to be unusable.
	// The deferred call runs after the session has been closed.
	if ok && tx != nil && tx.cancel != nil {
		defer tx.cancel()
	}
	if !ok || tx == nil || tx.execer == nil {
		return connection.QueryResult{Success: false, Message: "事务不存在或已结束"}
	}

	action := "回滚"
	sqlText := tx.rollbackSQL
	if commit {
		action = "提交"
		sqlText = tx.commitSQL
	}

	var execErr error
	switch {
	case tx.transactor != nil:
		if commit {
			execErr = tx.transactor.Commit()
		} else {
			execErr = tx.transactor.Rollback()
		}
	case strings.TrimSpace(sqlText) != "":
		ctx, cancel := context.WithTimeout(context.Background(), sqlEditorTransactionFinishTimeout)
		_, execErr = tx.execer.ExecContext(ctx, sqlText)
		cancel()
	}

	// The session is closed regardless of the commit/rollback outcome.
	closeErr := tx.execer.Close()

	if execErr != nil {
		logger.Error(execErr, "SQL 编辑器事务%s失败:id=%s dbType=%s", action, transactionID, tx.dbType)
		return connection.QueryResult{Success: false, Message: fmt.Sprintf("事务%s失败: %v", action, execErr)}
	}
	if closeErr != nil {
		logger.Error(closeErr, "SQL 编辑器事务%s后关闭会话失败:id=%s dbType=%s", action, transactionID,
			tx.dbType)
		return connection.QueryResult{Success: false, Message: fmt.Sprintf("事务%s成功,但关闭会话失败: %v", action, closeErr)}
	}
	if commit {
		return connection.QueryResult{Success: true, Message: "事务已提交"}
	}
	return connection.QueryResult{Success: true, Message: "事务已回滚"}
}

// rollbackPendingSQLTransactionsOnShutdown empties a.sqlTransactions and, for
// every transaction that was still pending, rolls it back, cancels its context
// and closes its session. Failures are logged as warnings and do not stop the
// remaining transactions from being processed.
func (a *App) rollbackPendingSQLTransactionsOnShutdown() {
	// Detach everything under the lock; the rollbacks themselves run unlocked.
	a.sqlTransactionMu.Lock()
	pending := make([]*managedSQLTransaction, 0, len(a.sqlTransactions))
	for id, tx := range a.sqlTransactions {
		if tx != nil {
			pending = append(pending, tx)
		}
		delete(a.sqlTransactions, id)
	}
	a.sqlTransactionMu.Unlock()

	for _, tx := range pending {
		ctx, cancel := context.WithTimeout(context.Background(), sqlEditorTransactionFinishTimeout)
		if tx.transactor != nil {
			if err := tx.transactor.Rollback(); err != nil {
				logger.Warnf("关闭应用时回滚 SQL 编辑器事务失败:id=%s dbType=%s err=%v", tx.id, tx.dbType, err)
			}
		} else if strings.TrimSpace(tx.rollbackSQL) != "" && tx.execer != nil {
			if _, err := tx.execer.ExecContext(ctx, tx.rollbackSQL); err != nil {
				logger.Warnf("关闭应用时回滚 SQL 编辑器事务失败:id=%s dbType=%s err=%v", tx.id, tx.dbType, err)
			}
		}
		// Released explicitly rather than deferred: a defer inside the loop
		// would only run once the whole function returns.
		cancel()

		if tx.cancel != nil {
			tx.cancel()
		}
		if tx.execer != nil {
			if err := tx.execer.Close(); err != nil {
				logger.Warnf("关闭应用时关闭 SQL 编辑器事务会话失败:id=%s dbType=%s err=%v", tx.id, tx.dbType, err)
			}
		}
	}
}