mirror of
https://github.com/Syngnat/GoNavi.git
synced 2026-07-02 05:31:21 +08:00
✨ feat(editor): 支持 SQL 编辑器增删改事务提交
- 为 SQL 编辑器 DML 新增后端托管事务会话和提交回滚接口 - 增加手动提交与自动提交设置,并显示待提交状态 - 补充前后端事务执行、提交、回滚和自动提交测试
This commit is contained in:
@@ -55,6 +55,15 @@ type queryContext struct {
|
||||
started time.Time
|
||||
}
|
||||
|
||||
type managedSQLTransaction struct {
|
||||
id string
|
||||
execer db.StatementExecer
|
||||
dbType string
|
||||
commitSQL string
|
||||
rollbackSQL string
|
||||
createdAt time.Time
|
||||
}
|
||||
|
||||
// App struct
|
||||
type App struct {
|
||||
ctx context.Context
|
||||
@@ -68,6 +77,8 @@ type App struct {
|
||||
configDir string
|
||||
secretStore secretstore.SecretStore
|
||||
runningQueries map[string]queryContext // queryID -> cancelFunc and start time
|
||||
sqlTransactionMu sync.Mutex
|
||||
sqlTransactions map[string]*managedSQLTransaction
|
||||
jvmPreviewTokenMu sync.Mutex
|
||||
jvmPreviewTokens map[string]jvmPreviewConfirmationToken
|
||||
jvmPreviewTokenTTL time.Duration
|
||||
@@ -86,6 +97,7 @@ func NewAppWithSecretStore(store secretstore.SecretStore) *App {
|
||||
dbCache: make(map[string]cachedDatabase),
|
||||
connectFailures: make(map[string]cachedConnectFailure),
|
||||
runningQueries: make(map[string]queryContext),
|
||||
sqlTransactions: make(map[string]*managedSQLTransaction),
|
||||
configDir: resolveAppConfigDir(),
|
||||
secretStore: store,
|
||||
jvmPreviewTokens: make(map[string]jvmPreviewConfirmationToken),
|
||||
@@ -167,6 +179,7 @@ func (a *App) LogWindowDiagnostic(stage string, payload string) {
|
||||
// Shutdown is called when the app terminates
|
||||
func (a *App) Shutdown(ctx context.Context) {
|
||||
logger.Infof("应用开始关闭,准备释放资源")
|
||||
a.rollbackPendingSQLTransactionsOnShutdown()
|
||||
a.mu.Lock()
|
||||
defer a.mu.Unlock()
|
||||
for _, dbInst := range a.dbCache {
|
||||
|
||||
@@ -2,6 +2,7 @@ package app
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"testing"
|
||||
|
||||
"GoNavi-Wails/internal/connection"
|
||||
@@ -21,6 +22,7 @@ type fakeBatchWriteDB struct {
|
||||
messageMap map[string][]string
|
||||
multiResult map[string][]connection.ResultSetData
|
||||
queryErr map[string]error
|
||||
execErr map[string]error
|
||||
execAffected map[string]int64
|
||||
session *fakeBatchWriteSession
|
||||
}
|
||||
@@ -53,6 +55,9 @@ func (f *fakeBatchWriteDB) QueryWithMessages(query string) ([]map[string]interfa
|
||||
func (f *fakeBatchWriteDB) Exec(query string) (int64, error) {
|
||||
f.execCalls++
|
||||
f.execQueries = append(f.execQueries, query)
|
||||
if err := f.execErr[query]; err != nil {
|
||||
return 0, err
|
||||
}
|
||||
if affected, ok := f.execAffected[query]; ok {
|
||||
return affected, nil
|
||||
}
|
||||
@@ -95,6 +100,9 @@ func (f *fakeBatchWriteDB) ExecContext(ctx context.Context, query string) (int64
|
||||
f.lastCtx = ctx
|
||||
f.execCalls++
|
||||
f.execQueries = append(f.execQueries, query)
|
||||
if err := f.execErr[query]; err != nil {
|
||||
return 0, err
|
||||
}
|
||||
if affected, ok := f.execAffected[query]; ok {
|
||||
return affected, nil
|
||||
}
|
||||
@@ -506,6 +514,185 @@ func TestDBQueryMultiPreservesPerStatementResultsForMultipleWriteStatements(t *t
|
||||
}
|
||||
}
|
||||
|
||||
func TestDBQueryMultiTransactionalKeepsDMLTransactionOpenUntilCommit(t *testing.T) {
|
||||
originalNewDatabaseFunc := newDatabaseFunc
|
||||
t.Cleanup(func() {
|
||||
newDatabaseFunc = originalNewDatabaseFunc
|
||||
})
|
||||
|
||||
firstStmt := "UPDATE users SET name = 'new' WHERE id = 1"
|
||||
secondStmt := "DELETE FROM audit_logs WHERE user_id = 1"
|
||||
fakeDB := &fakeBatchWriteDB{
|
||||
execAffected: map[string]int64{
|
||||
firstStmt: 1,
|
||||
secondStmt: 3,
|
||||
},
|
||||
}
|
||||
newDatabaseFunc = func(dbType string) (db.Database, error) {
|
||||
return fakeDB, nil
|
||||
}
|
||||
|
||||
app := NewAppWithSecretStore(secretstore.NewUnavailableStore("test"))
|
||||
config := connection.ConnectionConfig{Type: "mysql", Host: "127.0.0.1", Port: 3306, User: "root"}
|
||||
|
||||
result := app.DBQueryMultiTransactional(config, "main", firstStmt+";\n"+secondStmt+";", "tx-query")
|
||||
if !result.Success {
|
||||
t.Fatalf("expected transactional query success, got failure: %s", result.Message)
|
||||
}
|
||||
if result.TransactionID == "" || !result.TransactionPending {
|
||||
t.Fatalf("expected pending transaction metadata, got id=%q pending=%v", result.TransactionID, result.TransactionPending)
|
||||
}
|
||||
if fakeDB.session == nil {
|
||||
t.Fatal("expected transactional query to open a pinned session")
|
||||
}
|
||||
if fakeDB.session.closed {
|
||||
t.Fatal("expected transaction session to stay open before commit")
|
||||
}
|
||||
wantExecs := []string{"START TRANSACTION", firstStmt, secondStmt}
|
||||
if len(fakeDB.execQueries) != len(wantExecs) {
|
||||
t.Fatalf("expected exec queries %#v, got %#v", wantExecs, fakeDB.execQueries)
|
||||
}
|
||||
for i, want := range wantExecs {
|
||||
if fakeDB.execQueries[i] != want {
|
||||
t.Fatalf("expected exec query %d = %q, got %q", i, want, fakeDB.execQueries[i])
|
||||
}
|
||||
}
|
||||
|
||||
resultSets, ok := result.Data.([]connection.ResultSetData)
|
||||
if !ok {
|
||||
t.Fatalf("expected []connection.ResultSetData, got %T", result.Data)
|
||||
}
|
||||
if len(resultSets) != 2 {
|
||||
t.Fatalf("expected one affectedRows result per DML statement, got %#v", resultSets)
|
||||
}
|
||||
if got := resultSets[0].Rows[0]["affectedRows"]; got != int64(1) {
|
||||
t.Fatalf("expected first affectedRows=1, got %#v", got)
|
||||
}
|
||||
if got := resultSets[1].Rows[0]["affectedRows"]; got != int64(3) {
|
||||
t.Fatalf("expected second affectedRows=3, got %#v", got)
|
||||
}
|
||||
|
||||
commitResult := app.DBCommitTransaction(result.TransactionID)
|
||||
if !commitResult.Success {
|
||||
t.Fatalf("expected commit success, got failure: %s", commitResult.Message)
|
||||
}
|
||||
if !fakeDB.session.closed {
|
||||
t.Fatal("expected transaction session to close after commit")
|
||||
}
|
||||
if got := fakeDB.execQueries[len(fakeDB.execQueries)-1]; got != "COMMIT" {
|
||||
t.Fatalf("expected final exec to be COMMIT, got %q", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestDBQueryMultiTransactionalRollsBackAndClosesOnDMLFailure(t *testing.T) {
|
||||
originalNewDatabaseFunc := newDatabaseFunc
|
||||
t.Cleanup(func() {
|
||||
newDatabaseFunc = originalNewDatabaseFunc
|
||||
})
|
||||
|
||||
firstStmt := "UPDATE users SET name = 'new' WHERE id = 1"
|
||||
secondStmt := "DELETE FROM audit_logs WHERE user_id = 1"
|
||||
fakeDB := &fakeBatchWriteDB{
|
||||
execErr: map[string]error{
|
||||
secondStmt: errors.New("delete failed"),
|
||||
},
|
||||
}
|
||||
newDatabaseFunc = func(dbType string) (db.Database, error) {
|
||||
return fakeDB, nil
|
||||
}
|
||||
|
||||
app := NewAppWithSecretStore(secretstore.NewUnavailableStore("test"))
|
||||
config := connection.ConnectionConfig{Type: "mysql", Host: "127.0.0.1", Port: 3306, User: "root"}
|
||||
|
||||
result := app.DBQueryMultiTransactional(config, "main", firstStmt+";\n"+secondStmt+";", "tx-query")
|
||||
if result.Success {
|
||||
t.Fatal("expected transactional query failure")
|
||||
}
|
||||
if result.TransactionID != "" || result.TransactionPending {
|
||||
t.Fatalf("expected failed transaction not to be exposed, got id=%q pending=%v", result.TransactionID, result.TransactionPending)
|
||||
}
|
||||
if fakeDB.session == nil || !fakeDB.session.closed {
|
||||
t.Fatal("expected failed transaction session to close")
|
||||
}
|
||||
wantExecs := []string{"START TRANSACTION", firstStmt, secondStmt, "ROLLBACK"}
|
||||
if len(fakeDB.execQueries) != len(wantExecs) {
|
||||
t.Fatalf("expected exec queries %#v, got %#v", wantExecs, fakeDB.execQueries)
|
||||
}
|
||||
for i, want := range wantExecs {
|
||||
if fakeDB.execQueries[i] != want {
|
||||
t.Fatalf("expected exec query %d = %q, got %q", i, want, fakeDB.execQueries[i])
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestDBQueryMultiTransactionalSkipsManagedTransactionForReadOnlySQL(t *testing.T) {
|
||||
originalNewDatabaseFunc := newDatabaseFunc
|
||||
t.Cleanup(func() {
|
||||
newDatabaseFunc = originalNewDatabaseFunc
|
||||
})
|
||||
|
||||
query := "SELECT 1 AS value"
|
||||
fakeDB := &fakeBatchWriteDB{
|
||||
queryMap: map[string][]map[string]interface{}{
|
||||
query: {{"value": 1}},
|
||||
},
|
||||
fieldMap: map[string][]string{
|
||||
query: {"value"},
|
||||
},
|
||||
queryErr: map[string]error{},
|
||||
}
|
||||
newDatabaseFunc = func(dbType string) (db.Database, error) {
|
||||
return fakeDB, nil
|
||||
}
|
||||
|
||||
app := NewAppWithSecretStore(secretstore.NewUnavailableStore("test"))
|
||||
config := connection.ConnectionConfig{Type: "mysql", Host: "127.0.0.1", Port: 3306, User: "root"}
|
||||
|
||||
result := app.DBQueryMultiTransactional(config, "main", query, "read-query")
|
||||
if !result.Success {
|
||||
t.Fatalf("expected read-only query success, got failure: %s", result.Message)
|
||||
}
|
||||
if result.TransactionID != "" || result.TransactionPending {
|
||||
t.Fatalf("expected read-only query not to start managed transaction, got id=%q pending=%v", result.TransactionID, result.TransactionPending)
|
||||
}
|
||||
if len(fakeDB.execQueries) != 0 {
|
||||
t.Fatalf("expected no transaction wrapper execs for read-only query, got %#v", fakeDB.execQueries)
|
||||
}
|
||||
}
|
||||
|
||||
func TestDBQueryMultiTransactionalSkipsManagedTransactionForExplicitTransactionSQL(t *testing.T) {
|
||||
originalNewDatabaseFunc := newDatabaseFunc
|
||||
t.Cleanup(func() {
|
||||
newDatabaseFunc = originalNewDatabaseFunc
|
||||
})
|
||||
|
||||
stmt := "UPDATE users SET name = 'new' WHERE id = 1"
|
||||
fakeDB := &fakeBatchWriteDB{}
|
||||
newDatabaseFunc = func(dbType string) (db.Database, error) {
|
||||
return fakeDB, nil
|
||||
}
|
||||
|
||||
app := NewAppWithSecretStore(secretstore.NewUnavailableStore("test"))
|
||||
config := connection.ConnectionConfig{Type: "mysql", Host: "127.0.0.1", Port: 3306, User: "root"}
|
||||
|
||||
result := app.DBQueryMultiTransactional(config, "main", "BEGIN;\n"+stmt+";\nCOMMIT;", "explicit-tx-query")
|
||||
if !result.Success {
|
||||
t.Fatalf("expected explicit transaction SQL success, got failure: %s", result.Message)
|
||||
}
|
||||
if result.TransactionID != "" || result.TransactionPending {
|
||||
t.Fatalf("expected explicit transaction SQL not to be managed, got id=%q pending=%v", result.TransactionID, result.TransactionPending)
|
||||
}
|
||||
if len(fakeDB.execQueries) != 3 {
|
||||
t.Fatalf("expected explicit transaction statements only, got %#v", fakeDB.execQueries)
|
||||
}
|
||||
if fakeDB.execQueries[0] != "BEGIN" || fakeDB.execQueries[1] != stmt || fakeDB.execQueries[2] != "COMMIT" {
|
||||
t.Fatalf("expected explicit transaction statements unchanged, got %#v", fakeDB.execQueries)
|
||||
}
|
||||
if fakeDB.session == nil || !fakeDB.session.closed {
|
||||
t.Fatal("expected normal DBQueryMulti session to close after explicit transaction SQL")
|
||||
}
|
||||
}
|
||||
|
||||
func TestDBQueryMultiPrefersResultSetForExecStoredProcedure(t *testing.T) {
|
||||
originalNewDatabaseFunc := newDatabaseFunc
|
||||
t.Cleanup(func() {
|
||||
|
||||
334
internal/app/methods_db_transaction.go
Normal file
334
internal/app/methods_db_transaction.go
Normal file
@@ -0,0 +1,334 @@
|
||||
package app
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"GoNavi-Wails/internal/connection"
|
||||
"GoNavi-Wails/internal/db"
|
||||
"GoNavi-Wails/internal/logger"
|
||||
"github.com/google/uuid"
|
||||
)
|
||||
|
||||
const sqlEditorTransactionFinishTimeout = 30 * time.Second
|
||||
|
||||
// DBQueryMultiTransactional executes SQL editor DML in a managed transaction.
|
||||
// The transaction stays open until DBCommitTransaction or DBRollbackTransaction
|
||||
// is called by the SQL editor UI.
|
||||
func (a *App) DBQueryMultiTransactional(config connection.ConnectionConfig, dbName string, query string, queryID string) connection.QueryResult {
|
||||
runConfig := normalizeRunConfig(config, dbName)
|
||||
|
||||
if queryID == "" {
|
||||
queryID = generateQueryID()
|
||||
}
|
||||
|
||||
query = sanitizeSQLForPgLike(resolveDDLDBType(config), query)
|
||||
if !shouldUseManagedSQLTransaction(runConfig.Type, query) {
|
||||
return a.DBQueryMulti(config, dbName, query, queryID)
|
||||
}
|
||||
|
||||
beginSQL, commitSQL, rollbackSQL, ok := sqlFileBatchTransactionSQL(runConfig.Type)
|
||||
if !ok {
|
||||
return connection.QueryResult{
|
||||
Success: false,
|
||||
Message: fmt.Sprintf("当前数据源(%s)不支持 SQL 编辑器托管事务", runConfig.Type),
|
||||
QueryID: queryID,
|
||||
}
|
||||
}
|
||||
|
||||
dbInst, err := a.getDatabase(runConfig)
|
||||
if err != nil {
|
||||
logger.Error(err, "DBQueryMultiTransactional 获取连接失败:%s", formatConnSummary(runConfig))
|
||||
return connection.QueryResult{Success: false, Message: err.Error(), QueryID: queryID}
|
||||
}
|
||||
|
||||
provider, ok := dbInst.(db.SessionExecerProvider)
|
||||
if !ok {
|
||||
return connection.QueryResult{
|
||||
Success: false,
|
||||
Message: fmt.Sprintf("当前数据源(%s)不支持 SQL 编辑器托管事务", runConfig.Type),
|
||||
QueryID: queryID,
|
||||
}
|
||||
}
|
||||
|
||||
ctx, cancel := newQueryExecutionContext(runConfig)
|
||||
defer cancel()
|
||||
|
||||
a.queryMu.Lock()
|
||||
a.runningQueries[queryID] = queryContext{
|
||||
cancel: cancel,
|
||||
started: time.Now(),
|
||||
}
|
||||
a.queryMu.Unlock()
|
||||
defer func() {
|
||||
a.queryMu.Lock()
|
||||
delete(a.runningQueries, queryID)
|
||||
a.queryMu.Unlock()
|
||||
}()
|
||||
|
||||
sessionExecer, err := provider.OpenSessionExecer(ctx)
|
||||
if err != nil {
|
||||
logger.Error(err, "DBQueryMultiTransactional 打开事务会话失败:%s SQL片段=%q", formatConnSummary(runConfig), sqlSnippet(query))
|
||||
return connection.QueryResult{Success: false, Message: err.Error(), QueryID: queryID}
|
||||
}
|
||||
|
||||
closeSession := true
|
||||
defer func() {
|
||||
if closeSession {
|
||||
if err := sessionExecer.Close(); err != nil {
|
||||
logger.Warnf("DBQueryMultiTransactional 关闭事务会话失败:%v", err)
|
||||
}
|
||||
}
|
||||
}()
|
||||
|
||||
if _, err := sessionExecer.ExecContext(ctx, beginSQL); err != nil {
|
||||
logger.Error(err, "DBQueryMultiTransactional 开启事务失败:%s SQL片段=%q", formatConnSummary(runConfig), sqlSnippet(query))
|
||||
return connection.QueryResult{Success: false, Message: err.Error(), QueryID: queryID}
|
||||
}
|
||||
|
||||
statements := splitSQLStatements(query)
|
||||
resultSets, err := executeManagedSQLTransactionStatements(ctx, sessionExecer, runConfig, statements)
|
||||
if err != nil {
|
||||
if _, rollbackErr := sessionExecer.ExecContext(context.Background(), rollbackSQL); rollbackErr != nil {
|
||||
logger.Error(rollbackErr, "DBQueryMultiTransactional 执行失败后回滚失败:%s SQL片段=%q", formatConnSummary(runConfig), sqlSnippet(query))
|
||||
err = fmt.Errorf("%v;回滚失败: %w", err, rollbackErr)
|
||||
}
|
||||
logger.Error(err, "DBQueryMultiTransactional 执行失败:%s SQL片段=%q", formatConnSummary(runConfig), sqlSnippet(query))
|
||||
return connection.QueryResult{Success: false, Message: err.Error(), QueryID: queryID}
|
||||
}
|
||||
|
||||
transactionID := "sql-editor-" + uuid.NewString()
|
||||
a.sqlTransactionMu.Lock()
|
||||
if a.sqlTransactions == nil {
|
||||
a.sqlTransactions = make(map[string]*managedSQLTransaction)
|
||||
}
|
||||
a.sqlTransactions[transactionID] = &managedSQLTransaction{
|
||||
id: transactionID,
|
||||
execer: sessionExecer,
|
||||
dbType: runConfig.Type,
|
||||
commitSQL: commitSQL,
|
||||
rollbackSQL: rollbackSQL,
|
||||
createdAt: time.Now(),
|
||||
}
|
||||
a.sqlTransactionMu.Unlock()
|
||||
|
||||
closeSession = false
|
||||
return connection.QueryResult{
|
||||
Success: true,
|
||||
Data: resultSets,
|
||||
QueryID: queryID,
|
||||
TransactionID: transactionID,
|
||||
TransactionPending: true,
|
||||
}
|
||||
}
|
||||
|
||||
func executeManagedSQLTransactionStatements(ctx context.Context, session db.StatementExecer, runConfig connection.ConnectionConfig, statements []string) ([]connection.ResultSetData, error) {
|
||||
var resultSets []connection.ResultSetData
|
||||
sessionQueryTarget, _ := session.(db.StatementQueryExecer)
|
||||
sessionQueryMessageTarget, _ := session.(db.StatementQueryMessageExecer)
|
||||
sessionMultiQueryTarget, _ := session.(db.StatementMultiResultQueryExecer)
|
||||
sessionMultiQueryMessageTarget, _ := session.(db.StatementMultiResultQueryMessageExecer)
|
||||
|
||||
for idx, stmt := range statements {
|
||||
stmt = strings.TrimSpace(stmt)
|
||||
if stmt == "" {
|
||||
continue
|
||||
}
|
||||
|
||||
isReadStmt := isReadOnlySQLQuery(runConfig.Type, stmt)
|
||||
tryQueryStmtFirst := shouldTryQueryResultFirst(runConfig.Type, stmt)
|
||||
if isReadStmt || tryQueryStmtFirst {
|
||||
var (
|
||||
data []map[string]interface{}
|
||||
columns []string
|
||||
messages []string
|
||||
statementResults []connection.ResultSetData
|
||||
usedMultiResult bool
|
||||
err error
|
||||
)
|
||||
if sessionMultiQueryMessageTarget != nil {
|
||||
statementResults, messages, err = sessionMultiQueryMessageTarget.QueryMultiContextWithMessages(ctx, stmt)
|
||||
usedMultiResult = true
|
||||
} else if sessionMultiQueryTarget != nil {
|
||||
statementResults, err = sessionMultiQueryTarget.QueryMultiContext(ctx, stmt)
|
||||
usedMultiResult = true
|
||||
} else if sessionQueryMessageTarget != nil {
|
||||
data, columns, messages, err = sessionQueryMessageTarget.QueryContextWithMessages(ctx, stmt)
|
||||
} else if sessionQueryTarget != nil {
|
||||
data, columns, err = sessionQueryTarget.QueryContext(ctx, stmt)
|
||||
} else {
|
||||
err = fmt.Errorf("当前事务会话不支持查询语句")
|
||||
}
|
||||
if err == nil {
|
||||
if usedMultiResult {
|
||||
if len(statementResults) == 0 && len(messages) > 0 {
|
||||
statementResults = []connection.ResultSetData{{
|
||||
Rows: []map[string]interface{}{},
|
||||
Columns: []string{},
|
||||
Messages: append([]string(nil), messages...),
|
||||
}}
|
||||
}
|
||||
for _, statementResult := range statementResults {
|
||||
if statementResult.Rows == nil {
|
||||
statementResult.Rows = []map[string]interface{}{}
|
||||
}
|
||||
if statementResult.Columns == nil {
|
||||
statementResult.Columns = []string{}
|
||||
}
|
||||
statementResult.StatementIndex = idx + 1
|
||||
resultSets = append(resultSets, statementResult)
|
||||
}
|
||||
continue
|
||||
}
|
||||
if data == nil {
|
||||
data = make([]map[string]interface{}, 0)
|
||||
}
|
||||
if columns == nil {
|
||||
columns = []string{}
|
||||
}
|
||||
resultSets = append(resultSets, connection.ResultSetData{
|
||||
Rows: data,
|
||||
Columns: columns,
|
||||
Messages: messages,
|
||||
StatementIndex: idx + 1,
|
||||
})
|
||||
continue
|
||||
}
|
||||
if isReadStmt {
|
||||
return nil, fmt.Errorf("第 %d 条语句执行失败: %w", idx+1, err)
|
||||
}
|
||||
}
|
||||
|
||||
affected, err := session.ExecContext(ctx, stmt)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("第 %d 条语句执行失败: %w", idx+1, err)
|
||||
}
|
||||
resultSets = append(resultSets, connection.ResultSetData{
|
||||
Rows: []map[string]interface{}{{"affectedRows": affected}},
|
||||
Columns: []string{"affectedRows"},
|
||||
StatementIndex: idx + 1,
|
||||
})
|
||||
}
|
||||
|
||||
if resultSets == nil {
|
||||
resultSets = []connection.ResultSetData{}
|
||||
}
|
||||
return resultSets, nil
|
||||
}
|
||||
|
||||
func shouldUseManagedSQLTransaction(dbType string, query string) bool {
|
||||
statements := splitSQLStatements(query)
|
||||
hasManagedWrite := false
|
||||
for _, stmt := range statements {
|
||||
stmt = strings.TrimSpace(stmt)
|
||||
if stmt == "" {
|
||||
continue
|
||||
}
|
||||
if isSQLTransactionControlStatement(stmt) {
|
||||
return false
|
||||
}
|
||||
if isReadOnlySQLQuery(dbType, stmt) {
|
||||
continue
|
||||
}
|
||||
if isBatchableWriteSQLStatement(dbType, stmt) {
|
||||
hasManagedWrite = true
|
||||
continue
|
||||
}
|
||||
return false
|
||||
}
|
||||
return hasManagedWrite
|
||||
}
|
||||
|
||||
func isSQLTransactionControlStatement(stmt string) bool {
|
||||
switch leadingSQLKeyword(stmt) {
|
||||
case "begin", "commit", "rollback", "savepoint", "release":
|
||||
return true
|
||||
case "start":
|
||||
return strings.Contains(strings.ToLower(stmt), "transaction")
|
||||
default:
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
func (a *App) DBCommitTransaction(transactionID string) connection.QueryResult {
|
||||
return a.finishManagedSQLTransaction(transactionID, true)
|
||||
}
|
||||
|
||||
func (a *App) DBRollbackTransaction(transactionID string) connection.QueryResult {
|
||||
return a.finishManagedSQLTransaction(transactionID, false)
|
||||
}
|
||||
|
||||
func (a *App) finishManagedSQLTransaction(transactionID string, commit bool) connection.QueryResult {
|
||||
transactionID = strings.TrimSpace(transactionID)
|
||||
if transactionID == "" {
|
||||
return connection.QueryResult{Success: false, Message: "事务 ID 不能为空"}
|
||||
}
|
||||
|
||||
a.sqlTransactionMu.Lock()
|
||||
tx, ok := a.sqlTransactions[transactionID]
|
||||
if ok {
|
||||
delete(a.sqlTransactions, transactionID)
|
||||
}
|
||||
a.sqlTransactionMu.Unlock()
|
||||
if !ok || tx == nil || tx.execer == nil {
|
||||
return connection.QueryResult{Success: false, Message: "事务不存在或已结束"}
|
||||
}
|
||||
|
||||
action := "回滚"
|
||||
sqlText := tx.rollbackSQL
|
||||
if commit {
|
||||
action = "提交"
|
||||
sqlText = tx.commitSQL
|
||||
}
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), sqlEditorTransactionFinishTimeout)
|
||||
defer cancel()
|
||||
|
||||
var execErr error
|
||||
if strings.TrimSpace(sqlText) != "" {
|
||||
_, execErr = tx.execer.ExecContext(ctx, sqlText)
|
||||
}
|
||||
closeErr := tx.execer.Close()
|
||||
if execErr != nil {
|
||||
logger.Error(execErr, "SQL 编辑器事务%s失败:id=%s dbType=%s", action, transactionID, tx.dbType)
|
||||
return connection.QueryResult{Success: false, Message: fmt.Sprintf("事务%s失败: %v", action, execErr)}
|
||||
}
|
||||
if closeErr != nil {
|
||||
logger.Error(closeErr, "SQL 编辑器事务%s后关闭会话失败:id=%s dbType=%s", action, transactionID, tx.dbType)
|
||||
return connection.QueryResult{Success: false, Message: fmt.Sprintf("事务%s成功,但关闭会话失败: %v", action, closeErr)}
|
||||
}
|
||||
|
||||
if commit {
|
||||
return connection.QueryResult{Success: true, Message: "事务已提交"}
|
||||
}
|
||||
return connection.QueryResult{Success: true, Message: "事务已回滚"}
|
||||
}
|
||||
|
||||
func (a *App) rollbackPendingSQLTransactionsOnShutdown() {
|
||||
a.sqlTransactionMu.Lock()
|
||||
pending := make([]*managedSQLTransaction, 0, len(a.sqlTransactions))
|
||||
for id, tx := range a.sqlTransactions {
|
||||
if tx != nil {
|
||||
pending = append(pending, tx)
|
||||
}
|
||||
delete(a.sqlTransactions, id)
|
||||
}
|
||||
a.sqlTransactionMu.Unlock()
|
||||
|
||||
for _, tx := range pending {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), sqlEditorTransactionFinishTimeout)
|
||||
if strings.TrimSpace(tx.rollbackSQL) != "" && tx.execer != nil {
|
||||
if _, err := tx.execer.ExecContext(ctx, tx.rollbackSQL); err != nil {
|
||||
logger.Warnf("关闭应用时回滚 SQL 编辑器事务失败:id=%s dbType=%s err=%v", tx.id, tx.dbType, err)
|
||||
}
|
||||
}
|
||||
cancel()
|
||||
if tx.execer != nil {
|
||||
if err := tx.execer.Close(); err != nil {
|
||||
logger.Warnf("关闭应用时关闭 SQL 编辑器事务会话失败:id=%s dbType=%s err=%v", tx.id, tx.dbType, err)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -130,12 +130,14 @@ type ResultSetData struct {
|
||||
|
||||
// QueryResult 是 Wails 绑定方法的统一响应格式,前端通过此结构体接收后端结果。
|
||||
type QueryResult struct {
|
||||
Success bool `json:"success"`
|
||||
Message string `json:"message"`
|
||||
Data interface{} `json:"data"`
|
||||
Fields []string `json:"fields,omitempty"`
|
||||
Messages []string `json:"messages,omitempty"`
|
||||
QueryID string `json:"queryId,omitempty"` // Unique ID for query cancellation
|
||||
Success bool `json:"success"`
|
||||
Message string `json:"message"`
|
||||
Data interface{} `json:"data"`
|
||||
Fields []string `json:"fields,omitempty"`
|
||||
Messages []string `json:"messages,omitempty"`
|
||||
QueryID string `json:"queryId,omitempty"` // Unique ID for query cancellation
|
||||
TransactionID string `json:"transactionId,omitempty"`
|
||||
TransactionPending bool `json:"transactionPending,omitempty"`
|
||||
}
|
||||
|
||||
// ColumnDefinition 描述表的一个列定义。
|
||||
|
||||
Reference in New Issue
Block a user