🐛 fix(query-editor): 修正 SQL 编辑器 DML 事务识别

- 统一前后端 DML 与数据修改 CTE 的受管事务判断

- 保留数据修改 CTE 返回行并补充事务回归测试

- 明确 SQL 编辑器事务提交策略文案
This commit is contained in:
Syngnat
2026-06-10 19:13:54 +08:00
parent cf8f9be8dc
commit 89639e36bc
9 changed files with 206 additions and 37 deletions

View File

@@ -1007,6 +1007,9 @@ func (a *App) DBQueryMulti(config connection.ConnectionConfig, dbName string, qu
func shouldTryQueryResultFirst(dbType string, query string) bool {
isSQLServer := strings.EqualFold(strings.TrimSpace(dbType), "sqlserver")
if keyword, withHasWrite := sqlDataOperationInfo(query); withHasWrite && keyword == "select" {
return true
}
keyword := leadingSQLKeyword(query)
switch keyword {
case "exec", "execute", "call":

View File

@@ -624,6 +624,59 @@ func TestDBQueryMultiTransactionalTreatsWithDMLAsManagedWrite(t *testing.T) {
}
}
func TestDBQueryMultiTransactionalTreatsDataChangingCTEAsManagedWrite(t *testing.T) {
originalNewDatabaseFunc := newDatabaseFunc
t.Cleanup(func() {
newDatabaseFunc = originalNewDatabaseFunc
})
stmt := "WITH moved AS (DELETE FROM audit_logs WHERE created_at < NOW() RETURNING id) SELECT * FROM moved"
fakeDB := &fakeBatchWriteDB{
queryMap: map[string][]map[string]interface{}{
stmt: {{"id": 41}, {"id": 42}},
},
fieldMap: map[string][]string{
stmt: {"id"},
},
}
newDatabaseFunc = func(dbType string) (db.Database, error) {
return fakeDB, nil
}
app := NewAppWithSecretStore(secretstore.NewUnavailableStore("test"))
config := connection.ConnectionConfig{Type: "postgres", Host: "127.0.0.1", Port: 5432, User: "postgres"}
result := app.DBQueryMultiTransactional(config, "main", stmt, "cte-write-query")
if !result.Success {
t.Fatalf("expected transactional data-changing CTE success, got failure: %s", result.Message)
}
if result.TransactionID == "" || !result.TransactionPending {
t.Fatalf("expected pending transaction metadata, got id=%q pending=%v", result.TransactionID, result.TransactionPending)
}
if fakeDB.session == nil || fakeDB.session.closed {
t.Fatal("expected data-changing CTE transaction session to stay open")
}
wantExecs := []string{"BEGIN"}
if len(fakeDB.execQueries) != len(wantExecs) {
t.Fatalf("expected exec queries %#v, got %#v", wantExecs, fakeDB.execQueries)
}
for i, want := range wantExecs {
if fakeDB.execQueries[i] != want {
t.Fatalf("expected exec query %d = %q, got %q", i, want, fakeDB.execQueries[i])
}
}
if fakeDB.session.queryCalls == 0 {
t.Fatal("expected data-changing CTE SELECT to query returned rows inside the transaction")
}
resultSets, ok := result.Data.([]connection.ResultSetData)
if !ok {
t.Fatalf("expected []connection.ResultSetData, got %T", result.Data)
}
if len(resultSets) != 1 || len(resultSets[0].Rows) != 2 {
t.Fatalf("expected returned rows from data-changing CTE, got %#v", resultSets)
}
}
func TestDBQueryMultiTransactionalRollsBackAndClosesOnDMLFailure(t *testing.T) {
originalNewDatabaseFunc := newDatabaseFunc
t.Cleanup(func() {

View File

@@ -53,14 +53,19 @@ func leadingSQLKeyword(query string) string {
}
func sqlDataOperationKeyword(query string) string {
keyword, _ := sqlDataOperationInfo(query)
return keyword
}
func sqlDataOperationInfo(query string) (keyword string, withHasWrite bool) {
keyword, keywordEnd := nextSQLKeyword(query, 0)
if keyword != "with" {
return keyword
return keyword, false
}
if withKeyword, ok := sqlKeywordAfterLeadingWith(query, keywordEnd); ok {
return withKeyword
if withKeyword, hasWrite, ok := sqlKeywordAfterLeadingWith(query, keywordEnd); ok {
return withKeyword, hasWrite
}
return keyword
return keyword, false
}
func nextSQLKeyword(text string, start int) (string, int) {
@@ -106,8 +111,9 @@ func skipSQLTrivia(text string, start int) int {
return pos
}
func sqlKeywordAfterLeadingWith(text string, start int) (string, bool) {
func sqlKeywordAfterLeadingWith(text string, start int) (string, bool, bool) {
pos := skipSQLTrivia(text, start)
hasWriteCTE := false
if keyword, end := nextSQLKeyword(text, pos); keyword == "recursive" {
pos = end
}
@@ -116,20 +122,20 @@ func sqlKeywordAfterLeadingWith(text string, start int) (string, bool) {
pos = skipSQLTrivia(text, pos)
next, ok := skipSQLIdentifierToken(text, pos)
if !ok {
return "", false
return "", hasWriteCTE, false
}
pos = skipSQLTrivia(text, next)
if pos < len(text) && text[pos] == '(' {
next = skipBalancedSQLParens(text, pos)
if next < 0 {
return "", false
return "", hasWriteCTE, false
}
pos = skipSQLTrivia(text, next)
}
asEnd := findTopLevelSQLKeyword(text, pos, "as")
if asEnd < 0 {
return "", false
return "", hasWriteCTE, false
}
pos = skipSQLTrivia(text, asEnd)
if keyword, end := nextSQLKeyword(text, pos); keyword == "not" {
@@ -142,11 +148,19 @@ func sqlKeywordAfterLeadingWith(text string, start int) (string, bool) {
pos = skipSQLTrivia(text, pos)
if pos >= len(text) || text[pos] != '(' {
return "", false
return "", hasWriteCTE, false
}
cteBodyStart := pos + 1
next = skipBalancedSQLParens(text, pos)
if next < 0 {
return "", false
return "", hasWriteCTE, false
}
cteBodyEnd := next - 1
if cteBodyEnd >= cteBodyStart {
bodyKeyword, bodyHasWrite := sqlDataOperationInfo(text[cteBodyStart:cteBodyEnd])
if bodyHasWrite || isSQLDataWriteKeyword(bodyKeyword) {
hasWriteCTE = true
}
}
pos = skipSQLTrivia(text, next)
if pos < len(text) && text[pos] == ',' {
@@ -155,7 +169,7 @@ func sqlKeywordAfterLeadingWith(text string, start int) (string, bool) {
}
keyword, _ := nextSQLKeyword(text, pos)
return keyword, keyword != ""
return keyword, hasWriteCTE, keyword != ""
}
}
@@ -327,7 +341,11 @@ func isReadOnlySQLQuery(dbType string, query string) bool {
return true
}
switch sqlDataOperationKeyword(query) {
keyword, withHasWrite := sqlDataOperationInfo(query)
if withHasWrite {
return false
}
switch keyword {
case "select", "with", "show", "describe", "desc", "explain", "pragma", "values":
return true
default:
@@ -340,7 +358,15 @@ func isBatchableWriteSQLStatement(dbType string, query string) bool {
return false
}
switch sqlDataOperationKeyword(query) {
keyword, withHasWrite := sqlDataOperationInfo(query)
if withHasWrite {
return true
}
return isSQLDataWriteKeyword(keyword)
}
func isSQLDataWriteKeyword(keyword string) bool {
switch keyword {
case "insert", "update", "delete", "replace", "merge", "upsert":
return true
default:

View File

@@ -79,6 +79,11 @@ func TestIsReadOnlySQLQuery_ClassifiesWithByTopLevelOperation(t *testing.T) {
if isReadOnlySQLQuery("postgres", writeQuery) {
t.Fatal("WITH ... UPDATE should not be treated as read-only")
}
writeCTEQuery := "WITH moved AS (DELETE FROM audit_logs WHERE created_at < NOW() RETURNING id) SELECT * FROM moved"
if isReadOnlySQLQuery("postgres", writeCTEQuery) {
t.Fatal("data-changing CTE should not be treated as read-only")
}
}
func TestIsBatchableWriteSQLStatement_OnlyMatchesRealWriteStatements(t *testing.T) {
@@ -88,6 +93,9 @@ func TestIsBatchableWriteSQLStatement_OnlyMatchesRealWriteStatements(t *testing.
if !isBatchableWriteSQLStatement("postgres", "WITH target AS (SELECT id FROM users) DELETE FROM users WHERE id IN (SELECT id FROM target)") {
t.Fatal("expected WITH ... DELETE to be treated as batchable write")
}
if !isBatchableWriteSQLStatement("postgres", "WITH moved AS (DELETE FROM audit_logs WHERE created_at < NOW() RETURNING id) SELECT * FROM moved") {
t.Fatal("expected data-changing CTE to be treated as batchable write")
}
if isBatchableWriteSQLStatement("sqlserver", "EXEC sp_who2") {
t.Fatal("EXEC should not be treated as batchable write")
}
@@ -116,3 +124,10 @@ func TestShouldTryQueryResultFirst_TreatsSQLServerSystemCommandsAsQueryFirst(t *
t.Fatal("non-SQLServer system procedure name should not force query-first")
}
}
func TestShouldTryQueryResultFirst_TreatsDataChangingCTESelectAsQueryFirst(t *testing.T) {
query := "WITH moved AS (DELETE FROM audit_logs WHERE created_at < NOW() RETURNING id) SELECT * FROM moved"
if !shouldTryQueryResultFirst("postgres", query) {
t.Fatal("data-changing CTE ending in SELECT should try query-first to preserve returned rows")
}
}