mirror of
https://github.com/Syngnat/GoNavi.git
synced 2026-07-01 09:51:23 +08:00
🐛 fix(query-editor): 修正 SQL 编辑器 DML 事务识别
- 统一前后端 DML 与数据修改 CTE 的受管事务判断 - 保留数据修改 CTE 返回行并补充事务回归测试 - 明确 SQL 编辑器事务提交策略文案
This commit is contained in:
@@ -1007,6 +1007,9 @@ func (a *App) DBQueryMulti(config connection.ConnectionConfig, dbName string, qu
|
||||
|
||||
func shouldTryQueryResultFirst(dbType string, query string) bool {
|
||||
isSQLServer := strings.EqualFold(strings.TrimSpace(dbType), "sqlserver")
|
||||
if keyword, withHasWrite := sqlDataOperationInfo(query); withHasWrite && keyword == "select" {
|
||||
return true
|
||||
}
|
||||
keyword := leadingSQLKeyword(query)
|
||||
switch keyword {
|
||||
case "exec", "execute", "call":
|
||||
|
||||
@@ -624,6 +624,59 @@ func TestDBQueryMultiTransactionalTreatsWithDMLAsManagedWrite(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestDBQueryMultiTransactionalTreatsDataChangingCTEAsManagedWrite(t *testing.T) {
|
||||
originalNewDatabaseFunc := newDatabaseFunc
|
||||
t.Cleanup(func() {
|
||||
newDatabaseFunc = originalNewDatabaseFunc
|
||||
})
|
||||
|
||||
stmt := "WITH moved AS (DELETE FROM audit_logs WHERE created_at < NOW() RETURNING id) SELECT * FROM moved"
|
||||
fakeDB := &fakeBatchWriteDB{
|
||||
queryMap: map[string][]map[string]interface{}{
|
||||
stmt: {{"id": 41}, {"id": 42}},
|
||||
},
|
||||
fieldMap: map[string][]string{
|
||||
stmt: {"id"},
|
||||
},
|
||||
}
|
||||
newDatabaseFunc = func(dbType string) (db.Database, error) {
|
||||
return fakeDB, nil
|
||||
}
|
||||
|
||||
app := NewAppWithSecretStore(secretstore.NewUnavailableStore("test"))
|
||||
config := connection.ConnectionConfig{Type: "postgres", Host: "127.0.0.1", Port: 5432, User: "postgres"}
|
||||
|
||||
result := app.DBQueryMultiTransactional(config, "main", stmt, "cte-write-query")
|
||||
if !result.Success {
|
||||
t.Fatalf("expected transactional data-changing CTE success, got failure: %s", result.Message)
|
||||
}
|
||||
if result.TransactionID == "" || !result.TransactionPending {
|
||||
t.Fatalf("expected pending transaction metadata, got id=%q pending=%v", result.TransactionID, result.TransactionPending)
|
||||
}
|
||||
if fakeDB.session == nil || fakeDB.session.closed {
|
||||
t.Fatal("expected data-changing CTE transaction session to stay open")
|
||||
}
|
||||
wantExecs := []string{"BEGIN"}
|
||||
if len(fakeDB.execQueries) != len(wantExecs) {
|
||||
t.Fatalf("expected exec queries %#v, got %#v", wantExecs, fakeDB.execQueries)
|
||||
}
|
||||
for i, want := range wantExecs {
|
||||
if fakeDB.execQueries[i] != want {
|
||||
t.Fatalf("expected exec query %d = %q, got %q", i, want, fakeDB.execQueries[i])
|
||||
}
|
||||
}
|
||||
if fakeDB.session.queryCalls == 0 {
|
||||
t.Fatal("expected data-changing CTE SELECT to query returned rows inside the transaction")
|
||||
}
|
||||
resultSets, ok := result.Data.([]connection.ResultSetData)
|
||||
if !ok {
|
||||
t.Fatalf("expected []connection.ResultSetData, got %T", result.Data)
|
||||
}
|
||||
if len(resultSets) != 1 || len(resultSets[0].Rows) != 2 {
|
||||
t.Fatalf("expected returned rows from data-changing CTE, got %#v", resultSets)
|
||||
}
|
||||
}
|
||||
|
||||
func TestDBQueryMultiTransactionalRollsBackAndClosesOnDMLFailure(t *testing.T) {
|
||||
originalNewDatabaseFunc := newDatabaseFunc
|
||||
t.Cleanup(func() {
|
||||
|
||||
@@ -53,14 +53,19 @@ func leadingSQLKeyword(query string) string {
|
||||
}
|
||||
|
||||
func sqlDataOperationKeyword(query string) string {
|
||||
keyword, _ := sqlDataOperationInfo(query)
|
||||
return keyword
|
||||
}
|
||||
|
||||
func sqlDataOperationInfo(query string) (keyword string, withHasWrite bool) {
|
||||
keyword, keywordEnd := nextSQLKeyword(query, 0)
|
||||
if keyword != "with" {
|
||||
return keyword
|
||||
return keyword, false
|
||||
}
|
||||
if withKeyword, ok := sqlKeywordAfterLeadingWith(query, keywordEnd); ok {
|
||||
return withKeyword
|
||||
if withKeyword, hasWrite, ok := sqlKeywordAfterLeadingWith(query, keywordEnd); ok {
|
||||
return withKeyword, hasWrite
|
||||
}
|
||||
return keyword
|
||||
return keyword, false
|
||||
}
|
||||
|
||||
func nextSQLKeyword(text string, start int) (string, int) {
|
||||
@@ -106,8 +111,9 @@ func skipSQLTrivia(text string, start int) int {
|
||||
return pos
|
||||
}
|
||||
|
||||
func sqlKeywordAfterLeadingWith(text string, start int) (string, bool) {
|
||||
func sqlKeywordAfterLeadingWith(text string, start int) (string, bool, bool) {
|
||||
pos := skipSQLTrivia(text, start)
|
||||
hasWriteCTE := false
|
||||
if keyword, end := nextSQLKeyword(text, pos); keyword == "recursive" {
|
||||
pos = end
|
||||
}
|
||||
@@ -116,20 +122,20 @@ func sqlKeywordAfterLeadingWith(text string, start int) (string, bool) {
|
||||
pos = skipSQLTrivia(text, pos)
|
||||
next, ok := skipSQLIdentifierToken(text, pos)
|
||||
if !ok {
|
||||
return "", false
|
||||
return "", hasWriteCTE, false
|
||||
}
|
||||
pos = skipSQLTrivia(text, next)
|
||||
if pos < len(text) && text[pos] == '(' {
|
||||
next = skipBalancedSQLParens(text, pos)
|
||||
if next < 0 {
|
||||
return "", false
|
||||
return "", hasWriteCTE, false
|
||||
}
|
||||
pos = skipSQLTrivia(text, next)
|
||||
}
|
||||
|
||||
asEnd := findTopLevelSQLKeyword(text, pos, "as")
|
||||
if asEnd < 0 {
|
||||
return "", false
|
||||
return "", hasWriteCTE, false
|
||||
}
|
||||
pos = skipSQLTrivia(text, asEnd)
|
||||
if keyword, end := nextSQLKeyword(text, pos); keyword == "not" {
|
||||
@@ -142,11 +148,19 @@ func sqlKeywordAfterLeadingWith(text string, start int) (string, bool) {
|
||||
|
||||
pos = skipSQLTrivia(text, pos)
|
||||
if pos >= len(text) || text[pos] != '(' {
|
||||
return "", false
|
||||
return "", hasWriteCTE, false
|
||||
}
|
||||
cteBodyStart := pos + 1
|
||||
next = skipBalancedSQLParens(text, pos)
|
||||
if next < 0 {
|
||||
return "", false
|
||||
return "", hasWriteCTE, false
|
||||
}
|
||||
cteBodyEnd := next - 1
|
||||
if cteBodyEnd >= cteBodyStart {
|
||||
bodyKeyword, bodyHasWrite := sqlDataOperationInfo(text[cteBodyStart:cteBodyEnd])
|
||||
if bodyHasWrite || isSQLDataWriteKeyword(bodyKeyword) {
|
||||
hasWriteCTE = true
|
||||
}
|
||||
}
|
||||
pos = skipSQLTrivia(text, next)
|
||||
if pos < len(text) && text[pos] == ',' {
|
||||
@@ -155,7 +169,7 @@ func sqlKeywordAfterLeadingWith(text string, start int) (string, bool) {
|
||||
}
|
||||
|
||||
keyword, _ := nextSQLKeyword(text, pos)
|
||||
return keyword, keyword != ""
|
||||
return keyword, hasWriteCTE, keyword != ""
|
||||
}
|
||||
}
|
||||
|
||||
@@ -327,7 +341,11 @@ func isReadOnlySQLQuery(dbType string, query string) bool {
|
||||
return true
|
||||
}
|
||||
|
||||
switch sqlDataOperationKeyword(query) {
|
||||
keyword, withHasWrite := sqlDataOperationInfo(query)
|
||||
if withHasWrite {
|
||||
return false
|
||||
}
|
||||
switch keyword {
|
||||
case "select", "with", "show", "describe", "desc", "explain", "pragma", "values":
|
||||
return true
|
||||
default:
|
||||
@@ -340,7 +358,15 @@ func isBatchableWriteSQLStatement(dbType string, query string) bool {
|
||||
return false
|
||||
}
|
||||
|
||||
switch sqlDataOperationKeyword(query) {
|
||||
keyword, withHasWrite := sqlDataOperationInfo(query)
|
||||
if withHasWrite {
|
||||
return true
|
||||
}
|
||||
return isSQLDataWriteKeyword(keyword)
|
||||
}
|
||||
|
||||
func isSQLDataWriteKeyword(keyword string) bool {
|
||||
switch keyword {
|
||||
case "insert", "update", "delete", "replace", "merge", "upsert":
|
||||
return true
|
||||
default:
|
||||
|
||||
@@ -79,6 +79,11 @@ func TestIsReadOnlySQLQuery_ClassifiesWithByTopLevelOperation(t *testing.T) {
|
||||
if isReadOnlySQLQuery("postgres", writeQuery) {
|
||||
t.Fatal("WITH ... UPDATE should not be treated as read-only")
|
||||
}
|
||||
|
||||
writeCTEQuery := "WITH moved AS (DELETE FROM audit_logs WHERE created_at < NOW() RETURNING id) SELECT * FROM moved"
|
||||
if isReadOnlySQLQuery("postgres", writeCTEQuery) {
|
||||
t.Fatal("data-changing CTE should not be treated as read-only")
|
||||
}
|
||||
}
|
||||
|
||||
func TestIsBatchableWriteSQLStatement_OnlyMatchesRealWriteStatements(t *testing.T) {
|
||||
@@ -88,6 +93,9 @@ func TestIsBatchableWriteSQLStatement_OnlyMatchesRealWriteStatements(t *testing.
|
||||
if !isBatchableWriteSQLStatement("postgres", "WITH target AS (SELECT id FROM users) DELETE FROM users WHERE id IN (SELECT id FROM target)") {
|
||||
t.Fatal("expected WITH ... DELETE to be treated as batchable write")
|
||||
}
|
||||
if !isBatchableWriteSQLStatement("postgres", "WITH moved AS (DELETE FROM audit_logs WHERE created_at < NOW() RETURNING id) SELECT * FROM moved") {
|
||||
t.Fatal("expected data-changing CTE to be treated as batchable write")
|
||||
}
|
||||
if isBatchableWriteSQLStatement("sqlserver", "EXEC sp_who2") {
|
||||
t.Fatal("EXEC should not be treated as batchable write")
|
||||
}
|
||||
@@ -116,3 +124,10 @@ func TestShouldTryQueryResultFirst_TreatsSQLServerSystemCommandsAsQueryFirst(t *
|
||||
t.Fatal("non-SQLServer system procedure name should not force query-first")
|
||||
}
|
||||
}
|
||||
|
||||
func TestShouldTryQueryResultFirst_TreatsDataChangingCTESelectAsQueryFirst(t *testing.T) {
|
||||
query := "WITH moved AS (DELETE FROM audit_logs WHERE created_at < NOW() RETURNING id) SELECT * FROM moved"
|
||||
if !shouldTryQueryResultFirst("postgres", query) {
|
||||
t.Fatal("data-changing CTE ending in SELECT should try query-first to preserve returned rows")
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user