diff --git a/frontend/src/components/QueryEditor.external-sql-save.test.tsx b/frontend/src/components/QueryEditor.external-sql-save.test.tsx
index 5547a5f..a21a02b 100644
--- a/frontend/src/components/QueryEditor.external-sql-save.test.tsx
+++ b/frontend/src/components/QueryEditor.external-sql-save.test.tsx
@@ -2182,6 +2182,81 @@ describe('QueryEditor external SQL save', () => {
expect(textContent(renderer!.root)).not.toContain('事务待提交');
});
+ it('runs SQL editor WITH DML through a pending managed transaction', async () => {
+ const sql = 'WITH target AS (SELECT id FROM users WHERE active = 1) UPDATE users SET synced = 1 WHERE id IN (SELECT id FROM target)';
+ backendApp.DBQueryMultiTransactional.mockResolvedValueOnce({
+ success: true,
+ transactionId: 'tx-with-dml',
+ transactionPending: true,
+ data: [
+ { columns: ['affectedRows'], rows: [{ affectedRows: 2 }], statementIndex: 1 },
+ ],
+ });
+
+ let renderer!: ReactTestRenderer;
+ await act(async () => {
+ renderer = create();
+ });
+
+ await act(async () => {
+ await findButton(renderer!, '运行').props.onClick();
+ });
+ await act(async () => {
+ await Promise.resolve();
+ await Promise.resolve();
+ });
+
+ expect(backendApp.DBQueryMultiTransactional).toHaveBeenCalledWith(
+ expect.anything(),
+ 'main',
+ expect.stringContaining('WITH target AS'),
+ 'query-1',
+ );
+ expect(backendApp.DBQueryMulti).not.toHaveBeenCalled();
+ expect(textContent(renderer!.root)).toContain('事务待提交');
+
+ await act(async () => {
+ await findExactButton(renderer!, '提交').props.onClick();
+ });
+ await act(async () => {
+ await Promise.resolve();
+ await Promise.resolve();
+ });
+
+ expect(backendApp.DBCommitTransaction).toHaveBeenCalledWith('tx-with-dml');
+ });
+
+ it('keeps SQL editor WITH SELECT on the regular query path', async () => {
+ const sql = 'WITH target AS (SELECT id FROM users WHERE active = 1) SELECT * FROM target';
+ backendApp.DBQueryMulti.mockResolvedValueOnce({
+ success: true,
+ data: [
+ { columns: ['id'], rows: [{ id: 1 }], statementIndex: 1 },
+ ],
+ });
+
+ let renderer!: ReactTestRenderer;
+ await act(async () => {
+ renderer = create();
+ });
+
+ await act(async () => {
+ await findButton(renderer!, '运行').props.onClick();
+ });
+ await act(async () => {
+ await Promise.resolve();
+ await Promise.resolve();
+ });
+
+ expect(backendApp.DBQueryMulti).toHaveBeenCalledWith(
+ expect.anything(),
+ 'main',
+ expect.stringContaining('WITH target AS'),
+ 'query-1',
+ );
+ expect(backendApp.DBQueryMultiTransactional).not.toHaveBeenCalled();
+ });
+
it('auto commits SQL editor DML transactions after the configured delay', async () => {
vi.useFakeTimers();
storeState.sqlEditorTransactionOptions = {
@@ -3446,6 +3521,9 @@ describe('QueryEditor external SQL save', () => {
expect(source).toContain('gn-v2-query-toolbar-max-rows-select');
expect(source).toContain('gn-v2-query-toolbar-transaction-mode-select');
expect(source).toContain('gn-v2-query-toolbar-transaction-delay-select');
+ expect(source).toContain('这里仅选择该事务执行成功后的提交方式');
+ expect(source).toContain("label: '提交:手动'");
+ expect(source).toContain("label: '提交:自动'");
expect(source).toContain('gn-v2-query-toolbar-action-group');
expect(source).toContain('style={isV2Ui ? undefined : { width: 150 }}');
expect(source).toContain('style={isV2Ui ? undefined : { width: 200 }}');
diff --git a/frontend/src/components/QueryEditor.tsx b/frontend/src/components/QueryEditor.tsx
index 236f9e1..ac28ab7 100644
--- a/frontend/src/components/QueryEditor.tsx
+++ b/frontend/src/components/QueryEditor.tsx
@@ -18,6 +18,7 @@ import { applyQueryAutoLimit } from '../utils/queryAutoLimit';
import { extractQueryResultTableRef, type QueryResultTableRef } from '../utils/queryResultTable';
import { quoteIdentPart } from '../utils/sql';
import { formatSqlExecutionError } from '../utils/sqlErrorSemantics';
+import { shouldUseSqlEditorManagedTransaction } from '../utils/sqlEditorTransaction';
import { findSqlStatementRanges, resolveCurrentSqlStatementRange, resolveExecutableSql } from '../utils/sqlStatementSelection';
import { isMacLikePlatform } from '../utils/appearance';
import { splitSidebarQualifiedName } from '../utils/sidebarLocate';
@@ -750,9 +751,6 @@ const areSqlStatementListsEqual = (left: string[], right: string[]): boolean =>
&& left.every((statement, index) => normalizeExecutedSqlKey(statement) === normalizeExecutedSqlKey(right[index]))
);
-const SQL_EDITOR_DML_KEYWORDS = new Set(['insert', 'update', 'delete', 'replace', 'merge', 'upsert']);
-const SQL_EDITOR_READ_KEYWORDS = new Set(['select', 'with', 'show', 'describe', 'desc', 'explain', 'pragma', 'values']);
-const SQL_EDITOR_TRANSACTION_CONTROL_KEYWORDS = new Set(['begin', 'commit', 'rollback', 'savepoint', 'release']);
const SQL_EDITOR_AUTO_COMMIT_DELAY_OPTIONS = [
{ value: 3000, label: '3 秒' },
{ value: 5000, label: '5 秒' },
@@ -760,50 +758,6 @@ const SQL_EDITOR_AUTO_COMMIT_DELAY_OPTIONS = [
{ value: 30000, label: '30 秒' },
];
-const resolveLeadingSqlKeyword = (statement: string): string => {
- let text = String(statement || '').trim();
- while (text) {
- if (text.startsWith('--') || text.startsWith('#')) {
- const lineBreak = text.indexOf('\n');
- if (lineBreak < 0) return '';
- text = text.slice(lineBreak + 1).trimStart();
- continue;
- }
- if (text.startsWith('/*')) {
- const blockEnd = text.indexOf('*/');
- if (blockEnd < 0) return '';
- text = text.slice(blockEnd + 2).trimStart();
- continue;
- }
- break;
- }
- const match = text.match(/^([A-Za-z0-9_]+)/);
- return match?.[1]?.toLowerCase() || '';
-};
-
-const isSqlEditorTransactionControlStatement = (statement: string): boolean => {
- const keyword = resolveLeadingSqlKeyword(statement);
- if (SQL_EDITOR_TRANSACTION_CONTROL_KEYWORDS.has(keyword)) return true;
- return keyword === 'start' && /\btransaction\b/i.test(statement);
-};
-
-const shouldUseSqlEditorManagedTransaction = (statements: string[]): boolean => {
- let hasManagedWrite = false;
- for (const statement of statements) {
- const trimmed = String(statement || '').trim();
- if (!trimmed) continue;
- if (isSqlEditorTransactionControlStatement(trimmed)) return false;
- const keyword = resolveLeadingSqlKeyword(trimmed);
- if (SQL_EDITOR_READ_KEYWORDS.has(keyword)) continue;
- if (SQL_EDITOR_DML_KEYWORDS.has(keyword)) {
- hasManagedWrite = true;
- continue;
- }
- return false;
- }
- return hasManagedWrite;
-};
-
type PendingSqlEditorTransaction = {
id: string;
commitMode: 'manual' | 'auto';
@@ -5369,15 +5323,15 @@ const QueryEditor: React.FC<{ tab: TabData; isActive?: boolean }> = ({ tab, isAc
]}
/>
-
+
diff --git a/frontend/src/utils/sqlEditorTransaction.test.ts b/frontend/src/utils/sqlEditorTransaction.test.ts
new file mode 100644
index 0000000..5d8189c
--- /dev/null
+++ b/frontend/src/utils/sqlEditorTransaction.test.ts
@@ -0,0 +1,41 @@
+import { describe, expect, it } from 'vitest';
+
+import {
+ resolveSqlEditorOperationKeyword,
+ shouldUseSqlEditorManagedTransaction,
+} from './sqlEditorTransaction';
+
+describe('sqlEditorTransaction', () => {
+ it('keeps regular DML in a managed transaction', () => {
+ expect(shouldUseSqlEditorManagedTransaction(['UPDATE users SET name = "n" WHERE id = 1'])).toBe(true);
+ expect(shouldUseSqlEditorManagedTransaction(['INSERT INTO users(id) VALUES (1)'])).toBe(true);
+ expect(shouldUseSqlEditorManagedTransaction(['DELETE FROM users WHERE id = 1'])).toBe(true);
+ });
+
+ it('classifies WITH statements by their top-level operation', () => {
+ expect(resolveSqlEditorOperationKeyword('WITH target AS (SELECT id FROM users) SELECT * FROM target')).toBe('select');
+ expect(resolveSqlEditorOperationKeyword('WITH target AS (SELECT id FROM users) UPDATE users SET synced = 1')).toBe('update');
+ expect(resolveSqlEditorOperationKeyword('WITH target AS (SELECT id FROM users) DELETE FROM users WHERE id IN (SELECT id FROM target)')).toBe('delete');
+ });
+
+ it('uses managed transactions for WITH DML but not WITH SELECT', () => {
+ expect(shouldUseSqlEditorManagedTransaction([
+ 'WITH target AS (SELECT id FROM users) UPDATE users SET synced = 1 WHERE id IN (SELECT id FROM target)',
+ ])).toBe(true);
+ expect(shouldUseSqlEditorManagedTransaction([
+ 'WITH target AS (SELECT id FROM users) SELECT * FROM target',
+ ])).toBe(false);
+ });
+
+ it('does not wrap user-authored explicit transactions', () => {
+ expect(shouldUseSqlEditorManagedTransaction([
+ 'BEGIN',
+ 'UPDATE users SET name = "n" WHERE id = 1',
+ 'COMMIT',
+ ])).toBe(false);
+ expect(shouldUseSqlEditorManagedTransaction([
+ 'START TRANSACTION',
+ 'DELETE FROM users WHERE id = 1',
+ ])).toBe(false);
+ });
+});
diff --git a/frontend/src/utils/sqlEditorTransaction.ts b/frontend/src/utils/sqlEditorTransaction.ts
new file mode 100644
index 0000000..c2f5201
--- /dev/null
+++ b/frontend/src/utils/sqlEditorTransaction.ts
@@ -0,0 +1,246 @@
+const SQL_EDITOR_DML_KEYWORDS = new Set(['insert', 'update', 'delete', 'replace', 'merge', 'upsert']);
+const SQL_EDITOR_READ_KEYWORDS = new Set(['select', 'with', 'show', 'describe', 'desc', 'explain', 'pragma', 'values']);
+const SQL_EDITOR_TRANSACTION_CONTROL_KEYWORDS = new Set(['begin', 'commit', 'rollback', 'savepoint', 'release']);
+
+const isSqlEditorKeywordChar = (char: string | undefined): boolean => !!char && /[A-Za-z0-9_]/.test(char);
+
+const skipSqlEditorTrivia = (text: string, start: number): number => {
+ let pos = start;
+ while (pos < text.length) {
+ const char = text[pos];
+ if (/\s/.test(char || '')) {
+ pos++;
+ continue;
+ }
+ if (text.startsWith('--', pos) || text.startsWith('#', pos)) {
+ const nextLine = text.indexOf('\n', pos);
+ if (nextLine < 0) return text.length;
+ pos = nextLine + 1;
+ continue;
+ }
+ if (text.startsWith('/*', pos)) {
+ const blockEnd = text.indexOf('*/', pos + 2);
+ if (blockEnd < 0) return text.length;
+ pos = blockEnd + 2;
+ continue;
+ }
+ return pos;
+ }
+ return pos;
+};
+
+const readSqlEditorKeyword = (text: string, start: number): { keyword: string; end: number } => {
+ const pos = skipSqlEditorTrivia(text, start);
+ if (!isSqlEditorKeywordChar(text[pos])) {
+ return { keyword: '', end: pos };
+ }
+ let end = pos + 1;
+ while (isSqlEditorKeywordChar(text[end])) {
+ end++;
+ }
+ return { keyword: text.slice(pos, end).toLowerCase(), end };
+};
+
+const skipSqlEditorDelimited = (text: string, start: number, delimiter: string): number => {
+ let pos = start + 1;
+ while (pos < text.length) {
+ if (text[pos] === delimiter) {
+ if (text[pos + 1] === delimiter) {
+ pos += 2;
+ continue;
+ }
+ return pos + 1;
+ }
+ pos++;
+ }
+ return text.length;
+};
+
+const resolveSqlEditorDollarQuoteTag = (text: string, start: number): string => {
+ if (text[start] !== '$') return '';
+ let end = start + 1;
+ while (isSqlEditorKeywordChar(text[end])) {
+ end++;
+ }
+ return text[end] === '$' ? text.slice(start, end + 1) : '';
+};
+
+const skipSqlEditorQuotedOrComment = (text: string, start: number): number | null => {
+ if (text.startsWith('--', start) || text.startsWith('#', start)) {
+ const nextLine = text.indexOf('\n', start);
+ return nextLine < 0 ? text.length : nextLine + 1;
+ }
+ if (text.startsWith('/*', start)) {
+ const blockEnd = text.indexOf('*/', start + 2);
+ return blockEnd < 0 ? text.length : blockEnd + 2;
+ }
+ const char = text[start];
+ if (char === '\'' || char === '"' || char === '`') {
+ return skipSqlEditorDelimited(text, start, char);
+ }
+ if (char === '[') {
+ const bracketEnd = text.indexOf(']', start + 1);
+ return bracketEnd < 0 ? text.length : bracketEnd + 1;
+ }
+ const dollarTag = resolveSqlEditorDollarQuoteTag(text, start);
+ if (dollarTag) {
+ const dollarEnd = text.indexOf(dollarTag, start + dollarTag.length);
+ return dollarEnd < 0 ? text.length : dollarEnd + dollarTag.length;
+ }
+ return null;
+};
+
+const skipBalancedSqlEditorParens = (text: string, start: number): number => {
+ if (text[start] !== '(') return -1;
+ let depth = 0;
+ let pos = start;
+ while (pos < text.length) {
+ const skipped = skipSqlEditorQuotedOrComment(text, pos);
+ if (skipped !== null) {
+ pos = skipped;
+ continue;
+ }
+ if (text[pos] === '(') {
+ depth++;
+ pos++;
+ continue;
+ }
+ if (text[pos] === ')') {
+ depth--;
+ pos++;
+ if (depth === 0) return pos;
+ continue;
+ }
+ pos++;
+ }
+ return -1;
+};
+
+const skipSqlEditorIdentifierToken = (text: string, start: number): number => {
+ if (start >= text.length) return -1;
+ const char = text[start];
+ if (char === '"' || char === '`') return skipSqlEditorDelimited(text, start, char);
+ if (char === '[') {
+ const bracketEnd = text.indexOf(']', start + 1);
+ return bracketEnd < 0 ? text.length : bracketEnd + 1;
+ }
+ if (!isSqlEditorKeywordChar(char)) return -1;
+ let end = start + 1;
+ while (isSqlEditorKeywordChar(text[end])) {
+ end++;
+ }
+ return end;
+};
+
+const findTopLevelSqlEditorKeyword = (text: string, start: number, keyword: string): number => {
+ let depth = 0;
+ let pos = start;
+ while (pos < text.length) {
+ const skipped = skipSqlEditorQuotedOrComment(text, pos);
+ if (skipped !== null) {
+ pos = skipped;
+ continue;
+ }
+ if (text[pos] === '(') {
+ depth++;
+ pos++;
+ continue;
+ }
+ if (text[pos] === ')') {
+ if (depth > 0) depth--;
+ pos++;
+ continue;
+ }
+ if (depth === 0 && isSqlEditorKeywordChar(text[pos])) {
+ let end = pos + 1;
+ while (isSqlEditorKeywordChar(text[end])) {
+ end++;
+ }
+ if (text.slice(pos, end).toLowerCase() === keyword) {
+ return end;
+ }
+ pos = end;
+ continue;
+ }
+ pos++;
+ }
+ return -1;
+};
+
+const resolveSqlEditorKeywordAfterWith = (text: string, start: number): string => {
+ let pos = skipSqlEditorTrivia(text, start);
+ const recursive = readSqlEditorKeyword(text, pos);
+ if (recursive.keyword === 'recursive') {
+ pos = recursive.end;
+ }
+
+ while (pos < text.length) {
+ pos = skipSqlEditorTrivia(text, pos);
+ const identifierEnd = skipSqlEditorIdentifierToken(text, pos);
+ if (identifierEnd < 0) return '';
+ pos = skipSqlEditorTrivia(text, identifierEnd);
+ if (text[pos] === '(') {
+ const columnsEnd = skipBalancedSqlEditorParens(text, pos);
+ if (columnsEnd < 0) return '';
+ pos = skipSqlEditorTrivia(text, columnsEnd);
+ }
+
+ const asEnd = findTopLevelSqlEditorKeyword(text, pos, 'as');
+ if (asEnd < 0) return '';
+ pos = skipSqlEditorTrivia(text, asEnd);
+ const materialized = readSqlEditorKeyword(text, pos);
+ if (materialized.keyword === 'not') {
+ const next = readSqlEditorKeyword(text, materialized.end);
+ if (next.keyword === 'materialized') {
+ pos = next.end;
+ }
+ } else if (materialized.keyword === 'materialized') {
+ pos = materialized.end;
+ }
+
+ pos = skipSqlEditorTrivia(text, pos);
+ if (text[pos] !== '(') return '';
+ const cteEnd = skipBalancedSqlEditorParens(text, pos);
+ if (cteEnd < 0) return '';
+ pos = skipSqlEditorTrivia(text, cteEnd);
+ if (text[pos] === ',') {
+ pos++;
+ continue;
+ }
+
+ return readSqlEditorKeyword(text, pos).keyword;
+ }
+ return '';
+};
+
+export const resolveSqlEditorOperationKeyword = (statement: string): string => {
+ const text = String(statement || '');
+ const leading = readSqlEditorKeyword(text, 0);
+ if (leading.keyword !== 'with') {
+ return leading.keyword;
+ }
+ return resolveSqlEditorKeywordAfterWith(text, leading.end) || leading.keyword;
+};
+
+const isSqlEditorTransactionControlStatement = (statement: string): boolean => {
+ const keyword = readSqlEditorKeyword(String(statement || ''), 0).keyword;
+ if (SQL_EDITOR_TRANSACTION_CONTROL_KEYWORDS.has(keyword)) return true;
+ return keyword === 'start' && /\btransaction\b/i.test(statement);
+};
+
+export const shouldUseSqlEditorManagedTransaction = (statements: string[]): boolean => {
+ let hasManagedWrite = false;
+ for (const statement of statements) {
+ const trimmed = String(statement || '').trim();
+ if (!trimmed) continue;
+ if (isSqlEditorTransactionControlStatement(trimmed)) return false;
+ const keyword = resolveSqlEditorOperationKeyword(trimmed);
+ if (SQL_EDITOR_READ_KEYWORDS.has(keyword)) continue;
+ if (SQL_EDITOR_DML_KEYWORDS.has(keyword)) {
+ hasManagedWrite = true;
+ continue;
+ }
+ return false;
+ }
+ return hasManagedWrite;
+};
diff --git a/internal/app/methods_db_multi_test.go b/internal/app/methods_db_multi_test.go
index 37f9957..4906346 100644
--- a/internal/app/methods_db_multi_test.go
+++ b/internal/app/methods_db_multi_test.go
@@ -584,6 +584,46 @@ func TestDBQueryMultiTransactionalKeepsDMLTransactionOpenUntilCommit(t *testing.
}
}
+func TestDBQueryMultiTransactionalTreatsWithDMLAsManagedWrite(t *testing.T) {
+ originalNewDatabaseFunc := newDatabaseFunc
+ t.Cleanup(func() {
+ newDatabaseFunc = originalNewDatabaseFunc
+ })
+
+ stmt := "WITH target AS (SELECT id FROM users WHERE active = 1) UPDATE users SET synced = 1 WHERE id IN (SELECT id FROM target)"
+ fakeDB := &fakeBatchWriteDB{
+ execAffected: map[string]int64{
+ stmt: 2,
+ },
+ }
+ newDatabaseFunc = func(dbType string) (db.Database, error) {
+ return fakeDB, nil
+ }
+
+ app := NewAppWithSecretStore(secretstore.NewUnavailableStore("test"))
+ config := connection.ConnectionConfig{Type: "postgres", Host: "127.0.0.1", Port: 5432, User: "postgres"}
+
+ result := app.DBQueryMultiTransactional(config, "main", stmt, "with-dml-query")
+ if !result.Success {
+ t.Fatalf("expected transactional WITH DML success, got failure: %s", result.Message)
+ }
+ if result.TransactionID == "" || !result.TransactionPending {
+ t.Fatalf("expected pending transaction metadata, got id=%q pending=%v", result.TransactionID, result.TransactionPending)
+ }
+ if fakeDB.session == nil || fakeDB.session.closed {
+ t.Fatal("expected WITH DML transaction session to stay open")
+ }
+ wantExecs := []string{"BEGIN", stmt}
+ if len(fakeDB.execQueries) != len(wantExecs) {
+ t.Fatalf("expected exec queries %#v, got %#v", wantExecs, fakeDB.execQueries)
+ }
+ for i, want := range wantExecs {
+ if fakeDB.execQueries[i] != want {
+ t.Fatalf("expected exec query %d = %q, got %q", i, want, fakeDB.execQueries[i])
+ }
+ }
+}
+
func TestDBQueryMultiTransactionalRollsBackAndClosesOnDMLFailure(t *testing.T) {
originalNewDatabaseFunc := newDatabaseFunc
t.Cleanup(func() {
diff --git a/internal/app/sql_sanitize.go b/internal/app/sql_sanitize.go
index 06ae248..fb102dc 100644
--- a/internal/app/sql_sanitize.go
+++ b/internal/app/sql_sanitize.go
@@ -52,12 +52,282 @@ func leadingSQLKeyword(query string) string {
return strings.ToLower(text)
}
+func sqlDataOperationKeyword(query string) string {
+ keyword, keywordEnd := nextSQLKeyword(query, 0)
+ if keyword != "with" {
+ return keyword
+ }
+ if withKeyword, ok := sqlKeywordAfterLeadingWith(query, keywordEnd); ok {
+ return withKeyword
+ }
+ return keyword
+}
+
+func nextSQLKeyword(text string, start int) (string, int) {
+ pos := skipSQLTrivia(text, start)
+ if pos >= len(text) || !isSQLKeywordByte(text[pos]) {
+ return "", pos
+ }
+ end := pos + 1
+ for end < len(text) && isSQLKeywordByte(text[end]) {
+ end++
+ }
+ return strings.ToLower(text[pos:end]), end
+}
+
+func skipSQLTrivia(text string, start int) int {
+ pos := start
+ for pos < len(text) {
+ switch {
+ case text[pos] == ' ' || text[pos] == '\t' || text[pos] == '\r' || text[pos] == '\n' || text[pos] == '\f':
+ pos++
+ case strings.HasPrefix(text[pos:], "--"):
+ next := strings.IndexByte(text[pos:], '\n')
+ if next < 0 {
+ return len(text)
+ }
+ pos += next + 1
+ case strings.HasPrefix(text[pos:], "#"):
+ next := strings.IndexByte(text[pos:], '\n')
+ if next < 0 {
+ return len(text)
+ }
+ pos += next + 1
+ case strings.HasPrefix(text[pos:], "/*"):
+ end := strings.Index(text[pos+2:], "*/")
+ if end < 0 {
+ return len(text)
+ }
+ pos += end + 4
+ default:
+ return pos
+ }
+ }
+ return pos
+}
+
+func sqlKeywordAfterLeadingWith(text string, start int) (string, bool) {
+ pos := skipSQLTrivia(text, start)
+ if keyword, end := nextSQLKeyword(text, pos); keyword == "recursive" {
+ pos = end
+ }
+
+ for {
+ pos = skipSQLTrivia(text, pos)
+ next, ok := skipSQLIdentifierToken(text, pos)
+ if !ok {
+ return "", false
+ }
+ pos = skipSQLTrivia(text, next)
+ if pos < len(text) && text[pos] == '(' {
+ next = skipBalancedSQLParens(text, pos)
+ if next < 0 {
+ return "", false
+ }
+ pos = skipSQLTrivia(text, next)
+ }
+
+ asEnd := findTopLevelSQLKeyword(text, pos, "as")
+ if asEnd < 0 {
+ return "", false
+ }
+ pos = skipSQLTrivia(text, asEnd)
+ if keyword, end := nextSQLKeyword(text, pos); keyword == "not" {
+ if nextKeyword, nextEnd := nextSQLKeyword(text, end); nextKeyword == "materialized" {
+ pos = nextEnd
+ }
+ } else if keyword == "materialized" {
+ pos = end
+ }
+
+ pos = skipSQLTrivia(text, pos)
+ if pos >= len(text) || text[pos] != '(' {
+ return "", false
+ }
+ next = skipBalancedSQLParens(text, pos)
+ if next < 0 {
+ return "", false
+ }
+ pos = skipSQLTrivia(text, next)
+ if pos < len(text) && text[pos] == ',' {
+ pos++
+ continue
+ }
+
+ keyword, _ := nextSQLKeyword(text, pos)
+ return keyword, keyword != ""
+ }
+}
+
+func findTopLevelSQLKeyword(text string, start int, want string) int {
+ depth := 0
+ for pos := start; pos < len(text); {
+ if next, ok := skipSQLQuotedOrComment(text, pos); ok {
+ pos = next
+ continue
+ }
+ switch text[pos] {
+ case '(':
+ depth++
+ pos++
+ case ')':
+ if depth > 0 {
+ depth--
+ }
+ pos++
+ default:
+ if depth == 0 && isSQLKeywordByte(text[pos]) {
+ end := pos + 1
+ for end < len(text) && isSQLKeywordByte(text[end]) {
+ end++
+ }
+ if strings.EqualFold(text[pos:end], want) {
+ return end
+ }
+ pos = end
+ continue
+ }
+ pos++
+ }
+ }
+ return -1
+}
+
+func skipSQLIdentifierToken(text string, start int) (int, bool) {
+ if start >= len(text) {
+ return start, false
+ }
+ switch text[start] {
+ case '"', '`':
+ next := skipSQLDelimited(text, start, text[start])
+ return next, next > start
+ case '[':
+ next := strings.IndexByte(text[start+1:], ']')
+ if next < 0 {
+ return len(text), true
+ }
+ return start + next + 2, true
+ default:
+ if !isSQLKeywordByte(text[start]) {
+ return start, false
+ }
+ end := start + 1
+ for end < len(text) && isSQLKeywordByte(text[end]) {
+ end++
+ }
+ return end, true
+ }
+}
+
+func skipBalancedSQLParens(text string, start int) int {
+ if start >= len(text) || text[start] != '(' {
+ return -1
+ }
+ depth := 0
+ for pos := start; pos < len(text); {
+ if next, ok := skipSQLQuotedOrComment(text, pos); ok {
+ pos = next
+ continue
+ }
+ switch text[pos] {
+ case '(':
+ depth++
+ pos++
+ case ')':
+ depth--
+ pos++
+ if depth == 0 {
+ return pos
+ }
+ default:
+ pos++
+ }
+ }
+ return -1
+}
+
+func skipSQLQuotedOrComment(text string, start int) (int, bool) {
+ if start >= len(text) {
+ return start, false
+ }
+ switch {
+ case strings.HasPrefix(text[start:], "--"):
+ next := strings.IndexByte(text[start:], '\n')
+ if next < 0 {
+ return len(text), true
+ }
+ return start + next + 1, true
+ case strings.HasPrefix(text[start:], "#"):
+ next := strings.IndexByte(text[start:], '\n')
+ if next < 0 {
+ return len(text), true
+ }
+ return start + next + 1, true
+ case strings.HasPrefix(text[start:], "/*"):
+ end := strings.Index(text[start+2:], "*/")
+ if end < 0 {
+ return len(text), true
+ }
+ return start + end + 4, true
+ case text[start] == '\'' || text[start] == '"' || text[start] == '`':
+ return skipSQLDelimited(text, start, text[start]), true
+ case text[start] == '[':
+ next := strings.IndexByte(text[start+1:], ']')
+ if next < 0 {
+ return len(text), true
+ }
+ return start + next + 2, true
+ default:
+ if tag, ok := sqlDollarQuoteTag(text, start); ok {
+ end := strings.Index(text[start+len(tag):], tag)
+ if end < 0 {
+ return len(text), true
+ }
+ return start + len(tag) + end + len(tag), true
+ }
+ return start, false
+ }
+}
+
+func skipSQLDelimited(text string, start int, delimiter byte) int {
+ pos := start + 1
+ for pos < len(text) {
+ if text[pos] == delimiter {
+ if pos+1 < len(text) && text[pos+1] == delimiter {
+ pos += 2
+ continue
+ }
+ return pos + 1
+ }
+ pos++
+ }
+ return len(text)
+}
+
+func sqlDollarQuoteTag(text string, start int) (string, bool) {
+ if start >= len(text) || text[start] != '$' {
+ return "", false
+ }
+ end := start + 1
+ for end < len(text) && (isSQLKeywordByte(text[end]) || text[end] == '-') {
+ end++
+ }
+ if end < len(text) && text[end] == '$' {
+ return text[start : end+1], true
+ }
+ return "", false
+}
+
+func isSQLKeywordByte(ch byte) bool {
+ return (ch >= 'a' && ch <= 'z') || (ch >= 'A' && ch <= 'Z') || (ch >= '0' && ch <= '9') || ch == '_'
+}
+
func isReadOnlySQLQuery(dbType string, query string) bool {
if strings.ToLower(strings.TrimSpace(dbType)) == "mongodb" && strings.HasPrefix(strings.TrimSpace(query), "{") {
return true
}
- switch leadingSQLKeyword(query) {
+ switch sqlDataOperationKeyword(query) {
case "select", "with", "show", "describe", "desc", "explain", "pragma", "values":
return true
default:
@@ -70,7 +340,7 @@ func isBatchableWriteSQLStatement(dbType string, query string) bool {
return false
}
- switch leadingSQLKeyword(query) {
+ switch sqlDataOperationKeyword(query) {
case "insert", "update", "delete", "replace", "merge", "upsert":
return true
default:
diff --git a/internal/app/sql_sanitize_test.go b/internal/app/sql_sanitize_test.go
index 8825e89..31a77c8 100644
--- a/internal/app/sql_sanitize_test.go
+++ b/internal/app/sql_sanitize_test.go
@@ -69,10 +69,25 @@ func TestIsReadOnlySQLQuery_DoesNotTreatExecAsReadOnly(t *testing.T) {
}
}
+func TestIsReadOnlySQLQuery_ClassifiesWithByTopLevelOperation(t *testing.T) {
+ readQuery := "WITH target AS (SELECT id FROM users WHERE active = 1) SELECT * FROM target"
+ if !isReadOnlySQLQuery("postgres", readQuery) {
+ t.Fatal("WITH ... SELECT should stay read-only")
+ }
+
+ writeQuery := "WITH target AS (SELECT id FROM users WHERE active = 1) UPDATE users SET synced = 1 WHERE id IN (SELECT id FROM target)"
+ if isReadOnlySQLQuery("postgres", writeQuery) {
+ t.Fatal("WITH ... UPDATE should not be treated as read-only")
+ }
+}
+
func TestIsBatchableWriteSQLStatement_OnlyMatchesRealWriteStatements(t *testing.T) {
if !isBatchableWriteSQLStatement("mysql", "INSERT INTO demo(id) VALUES (1)") {
t.Fatal("expected INSERT to be treated as batchable write")
}
+ if !isBatchableWriteSQLStatement("postgres", "WITH target AS (SELECT id FROM users) DELETE FROM users WHERE id IN (SELECT id FROM target)") {
+ t.Fatal("expected WITH ... DELETE to be treated as batchable write")
+ }
if isBatchableWriteSQLStatement("sqlserver", "EXEC sp_who2") {
t.Fatal("EXEC should not be treated as batchable write")
}