🐛 fix(editor): 修正 SQL 编辑器 DML 事务提交语义

- SQL 编辑器 DML 固定进入托管事务

- 区分 WITH SELECT 和 WITH DML 的事务判定

- 调整提交方式文案并补充前后端回归测试
This commit is contained in:
Syngnat
2026-06-10 18:05:46 +08:00
parent 61d71cf1d0
commit d8da8d6abf
7 changed files with 696 additions and 52 deletions

View File

@@ -2182,6 +2182,81 @@ describe('QueryEditor external SQL save', () => {
expect(textContent(renderer!.root)).not.toContain('事务待提交');
});
it('runs SQL editor WITH DML through a pending managed transaction', async () => {
const sql = 'WITH target AS (SELECT id FROM users WHERE active = 1) UPDATE users SET synced = 1 WHERE id IN (SELECT id FROM target)';
backendApp.DBQueryMultiTransactional.mockResolvedValueOnce({
success: true,
transactionId: 'tx-with-dml',
transactionPending: true,
data: [
{ columns: ['affectedRows'], rows: [{ affectedRows: 2 }], statementIndex: 1 },
],
});
let renderer!: ReactTestRenderer;
await act(async () => {
renderer = create(<QueryEditor tab={createTab({ query: sql })} />);
});
await act(async () => {
await findButton(renderer!, '运行').props.onClick();
});
await act(async () => {
await Promise.resolve();
await Promise.resolve();
});
expect(backendApp.DBQueryMultiTransactional).toHaveBeenCalledWith(
expect.anything(),
'main',
expect.stringContaining('WITH target AS'),
'query-1',
);
expect(backendApp.DBQueryMulti).not.toHaveBeenCalled();
expect(textContent(renderer!.root)).toContain('事务待提交');
await act(async () => {
await findExactButton(renderer!, '提交').props.onClick();
});
await act(async () => {
await Promise.resolve();
await Promise.resolve();
});
expect(backendApp.DBCommitTransaction).toHaveBeenCalledWith('tx-with-dml');
});
it('keeps SQL editor WITH SELECT on the regular query path', async () => {
const sql = 'WITH target AS (SELECT id FROM users WHERE active = 1) SELECT * FROM target';
backendApp.DBQueryMulti.mockResolvedValueOnce({
success: true,
data: [
{ columns: ['id'], rows: [{ id: 1 }], statementIndex: 1 },
],
});
let renderer!: ReactTestRenderer;
await act(async () => {
renderer = create(<QueryEditor tab={createTab({ query: sql })} />);
});
await act(async () => {
await findButton(renderer!, '运行').props.onClick();
});
await act(async () => {
await Promise.resolve();
await Promise.resolve();
});
expect(backendApp.DBQueryMulti).toHaveBeenCalledWith(
expect.anything(),
'main',
expect.stringContaining('WITH target AS'),
'query-1',
);
expect(backendApp.DBQueryMultiTransactional).not.toHaveBeenCalled();
});
it('auto commits SQL editor DML transactions after the configured delay', async () => {
vi.useFakeTimers();
storeState.sqlEditorTransactionOptions = {
@@ -3446,6 +3521,9 @@ describe('QueryEditor external SQL save', () => {
expect(source).toContain('gn-v2-query-toolbar-max-rows-select');
expect(source).toContain('gn-v2-query-toolbar-transaction-mode-select');
expect(source).toContain('gn-v2-query-toolbar-transaction-delay-select');
expect(source).toContain('这里仅选择该事务执行成功后的提交方式');
expect(source).toContain("label: '提交:手动'");
expect(source).toContain("label: '提交:自动'");
expect(source).toContain('gn-v2-query-toolbar-action-group');
expect(source).toContain('style={isV2Ui ? undefined : { width: 150 }}');
expect(source).toContain('style={isV2Ui ? undefined : { width: 200 }}');

View File

@@ -18,6 +18,7 @@ import { applyQueryAutoLimit } from '../utils/queryAutoLimit';
import { extractQueryResultTableRef, type QueryResultTableRef } from '../utils/queryResultTable';
import { quoteIdentPart } from '../utils/sql';
import { formatSqlExecutionError } from '../utils/sqlErrorSemantics';
import { shouldUseSqlEditorManagedTransaction } from '../utils/sqlEditorTransaction';
import { findSqlStatementRanges, resolveCurrentSqlStatementRange, resolveExecutableSql } from '../utils/sqlStatementSelection';
import { isMacLikePlatform } from '../utils/appearance';
import { splitSidebarQualifiedName } from '../utils/sidebarLocate';
@@ -750,9 +751,6 @@ const areSqlStatementListsEqual = (left: string[], right: string[]): boolean =>
&& left.every((statement, index) => normalizeExecutedSqlKey(statement) === normalizeExecutedSqlKey(right[index]))
);
const SQL_EDITOR_DML_KEYWORDS = new Set(['insert', 'update', 'delete', 'replace', 'merge', 'upsert']);
const SQL_EDITOR_READ_KEYWORDS = new Set(['select', 'with', 'show', 'describe', 'desc', 'explain', 'pragma', 'values']);
const SQL_EDITOR_TRANSACTION_CONTROL_KEYWORDS = new Set(['begin', 'commit', 'rollback', 'savepoint', 'release']);
const SQL_EDITOR_AUTO_COMMIT_DELAY_OPTIONS = [
{ value: 3000, label: '3 秒' },
{ value: 5000, label: '5 秒' },
@@ -760,50 +758,6 @@ const SQL_EDITOR_AUTO_COMMIT_DELAY_OPTIONS = [
{ value: 30000, label: '30 秒' },
];
const resolveLeadingSqlKeyword = (statement: string): string => {
let text = String(statement || '').trim();
while (text) {
if (text.startsWith('--') || text.startsWith('#')) {
const lineBreak = text.indexOf('\n');
if (lineBreak < 0) return '';
text = text.slice(lineBreak + 1).trimStart();
continue;
}
if (text.startsWith('/*')) {
const blockEnd = text.indexOf('*/');
if (blockEnd < 0) return '';
text = text.slice(blockEnd + 2).trimStart();
continue;
}
break;
}
const match = text.match(/^([A-Za-z0-9_]+)/);
return match?.[1]?.toLowerCase() || '';
};
const isSqlEditorTransactionControlStatement = (statement: string): boolean => {
const keyword = resolveLeadingSqlKeyword(statement);
if (SQL_EDITOR_TRANSACTION_CONTROL_KEYWORDS.has(keyword)) return true;
return keyword === 'start' && /\btransaction\b/i.test(statement);
};
const shouldUseSqlEditorManagedTransaction = (statements: string[]): boolean => {
let hasManagedWrite = false;
for (const statement of statements) {
const trimmed = String(statement || '').trim();
if (!trimmed) continue;
if (isSqlEditorTransactionControlStatement(trimmed)) return false;
const keyword = resolveLeadingSqlKeyword(trimmed);
if (SQL_EDITOR_READ_KEYWORDS.has(keyword)) continue;
if (SQL_EDITOR_DML_KEYWORDS.has(keyword)) {
hasManagedWrite = true;
continue;
}
return false;
}
return hasManagedWrite;
};
type PendingSqlEditorTransaction = {
id: string;
commitMode: 'manual' | 'auto';
@@ -5369,15 +5323,15 @@ const QueryEditor: React.FC<{ tab: TabData; isActive?: boolean }> = ({ tab, isAc
]}
/>
</Tooltip>
<Tooltip title="SQL 编辑器直接执行 INSERT/UPDATE/DELETE 等增删改语句时启用事务;手动提交更安全,自动提交会在执行成功后按所选时间提交。">
<Tooltip title="SQL 编辑器执行 INSERT/UPDATE/DELETE 等 DML 时始终启用事务;这里仅选择该事务执行成功后的提交方式。">
<Select
className={isV2Ui ? 'gn-v2-query-toolbar-select gn-v2-query-toolbar-transaction-mode-select' : undefined}
style={isV2Ui ? undefined : { width: 128 }}
value={sqlEditorCommitMode}
onChange={(mode) => setSqlEditorTransactionOptions({ commitMode: mode === 'auto' ? 'auto' : 'manual' })}
options={[
{ label: '事务:手动', value: 'manual' },
{ label: '事务:自动', value: 'auto' },
{ label: '提交:手动', value: 'manual' },
{ label: '提交:自动', value: 'auto' },
]}
/>
</Tooltip>

View File

@@ -0,0 +1,41 @@
import { describe, expect, it } from 'vitest';
import {
resolveSqlEditorOperationKeyword,
shouldUseSqlEditorManagedTransaction,
} from './sqlEditorTransaction';
describe('sqlEditorTransaction', () => {
it('keeps regular DML in a managed transaction', () => {
expect(shouldUseSqlEditorManagedTransaction(['UPDATE users SET name = "n" WHERE id = 1'])).toBe(true);
expect(shouldUseSqlEditorManagedTransaction(['INSERT INTO users(id) VALUES (1)'])).toBe(true);
expect(shouldUseSqlEditorManagedTransaction(['DELETE FROM users WHERE id = 1'])).toBe(true);
});
it('classifies WITH statements by their top-level operation', () => {
expect(resolveSqlEditorOperationKeyword('WITH target AS (SELECT id FROM users) SELECT * FROM target')).toBe('select');
expect(resolveSqlEditorOperationKeyword('WITH target AS (SELECT id FROM users) UPDATE users SET synced = 1')).toBe('update');
expect(resolveSqlEditorOperationKeyword('WITH target AS (SELECT id FROM users) DELETE FROM users WHERE id IN (SELECT id FROM target)')).toBe('delete');
});
it('uses managed transactions for WITH DML but not WITH SELECT', () => {
expect(shouldUseSqlEditorManagedTransaction([
'WITH target AS (SELECT id FROM users) UPDATE users SET synced = 1 WHERE id IN (SELECT id FROM target)',
])).toBe(true);
expect(shouldUseSqlEditorManagedTransaction([
'WITH target AS (SELECT id FROM users) SELECT * FROM target',
])).toBe(false);
});
it('does not wrap user-authored explicit transactions', () => {
expect(shouldUseSqlEditorManagedTransaction([
'BEGIN',
'UPDATE users SET name = "n" WHERE id = 1',
'COMMIT',
])).toBe(false);
expect(shouldUseSqlEditorManagedTransaction([
'START TRANSACTION',
'DELETE FROM users WHERE id = 1',
])).toBe(false);
});
});

View File

@@ -0,0 +1,246 @@
const SQL_EDITOR_DML_KEYWORDS = new Set(['insert', 'update', 'delete', 'replace', 'merge', 'upsert']);
const SQL_EDITOR_READ_KEYWORDS = new Set(['select', 'with', 'show', 'describe', 'desc', 'explain', 'pragma', 'values']);
const SQL_EDITOR_TRANSACTION_CONTROL_KEYWORDS = new Set(['begin', 'commit', 'rollback', 'savepoint', 'release']);
const isSqlEditorKeywordChar = (char: string | undefined): boolean => !!char && /[A-Za-z0-9_]/.test(char);
const skipSqlEditorTrivia = (text: string, start: number): number => {
let pos = start;
while (pos < text.length) {
const char = text[pos];
if (/\s/.test(char || '')) {
pos++;
continue;
}
if (text.startsWith('--', pos) || text.startsWith('#', pos)) {
const nextLine = text.indexOf('\n', pos);
if (nextLine < 0) return text.length;
pos = nextLine + 1;
continue;
}
if (text.startsWith('/*', pos)) {
const blockEnd = text.indexOf('*/', pos + 2);
if (blockEnd < 0) return text.length;
pos = blockEnd + 2;
continue;
}
return pos;
}
return pos;
};
const readSqlEditorKeyword = (text: string, start: number): { keyword: string; end: number } => {
const pos = skipSqlEditorTrivia(text, start);
if (!isSqlEditorKeywordChar(text[pos])) {
return { keyword: '', end: pos };
}
let end = pos + 1;
while (isSqlEditorKeywordChar(text[end])) {
end++;
}
return { keyword: text.slice(pos, end).toLowerCase(), end };
};
const skipSqlEditorDelimited = (text: string, start: number, delimiter: string): number => {
let pos = start + 1;
while (pos < text.length) {
if (text[pos] === delimiter) {
if (text[pos + 1] === delimiter) {
pos += 2;
continue;
}
return pos + 1;
}
pos++;
}
return text.length;
};
const resolveSqlEditorDollarQuoteTag = (text: string, start: number): string => {
if (text[start] !== '$') return '';
let end = start + 1;
while (isSqlEditorKeywordChar(text[end])) {
end++;
}
return text[end] === '$' ? text.slice(start, end + 1) : '';
};
const skipSqlEditorQuotedOrComment = (text: string, start: number): number | null => {
if (text.startsWith('--', start) || text.startsWith('#', start)) {
const nextLine = text.indexOf('\n', start);
return nextLine < 0 ? text.length : nextLine + 1;
}
if (text.startsWith('/*', start)) {
const blockEnd = text.indexOf('*/', start + 2);
return blockEnd < 0 ? text.length : blockEnd + 2;
}
const char = text[start];
if (char === '\'' || char === '"' || char === '`') {
return skipSqlEditorDelimited(text, start, char);
}
if (char === '[') {
const bracketEnd = text.indexOf(']', start + 1);
return bracketEnd < 0 ? text.length : bracketEnd + 1;
}
const dollarTag = resolveSqlEditorDollarQuoteTag(text, start);
if (dollarTag) {
const dollarEnd = text.indexOf(dollarTag, start + dollarTag.length);
return dollarEnd < 0 ? text.length : dollarEnd + dollarTag.length;
}
return null;
};
const skipBalancedSqlEditorParens = (text: string, start: number): number => {
if (text[start] !== '(') return -1;
let depth = 0;
let pos = start;
while (pos < text.length) {
const skipped = skipSqlEditorQuotedOrComment(text, pos);
if (skipped !== null) {
pos = skipped;
continue;
}
if (text[pos] === '(') {
depth++;
pos++;
continue;
}
if (text[pos] === ')') {
depth--;
pos++;
if (depth === 0) return pos;
continue;
}
pos++;
}
return -1;
};
const skipSqlEditorIdentifierToken = (text: string, start: number): number => {
if (start >= text.length) return -1;
const char = text[start];
if (char === '"' || char === '`') return skipSqlEditorDelimited(text, start, char);
if (char === '[') {
const bracketEnd = text.indexOf(']', start + 1);
return bracketEnd < 0 ? text.length : bracketEnd + 1;
}
if (!isSqlEditorKeywordChar(char)) return -1;
let end = start + 1;
while (isSqlEditorKeywordChar(text[end])) {
end++;
}
return end;
};
const findTopLevelSqlEditorKeyword = (text: string, start: number, keyword: string): number => {
let depth = 0;
let pos = start;
while (pos < text.length) {
const skipped = skipSqlEditorQuotedOrComment(text, pos);
if (skipped !== null) {
pos = skipped;
continue;
}
if (text[pos] === '(') {
depth++;
pos++;
continue;
}
if (text[pos] === ')') {
if (depth > 0) depth--;
pos++;
continue;
}
if (depth === 0 && isSqlEditorKeywordChar(text[pos])) {
let end = pos + 1;
while (isSqlEditorKeywordChar(text[end])) {
end++;
}
if (text.slice(pos, end).toLowerCase() === keyword) {
return end;
}
pos = end;
continue;
}
pos++;
}
return -1;
};
const resolveSqlEditorKeywordAfterWith = (text: string, start: number): string => {
let pos = skipSqlEditorTrivia(text, start);
const recursive = readSqlEditorKeyword(text, pos);
if (recursive.keyword === 'recursive') {
pos = recursive.end;
}
while (pos < text.length) {
pos = skipSqlEditorTrivia(text, pos);
const identifierEnd = skipSqlEditorIdentifierToken(text, pos);
if (identifierEnd < 0) return '';
pos = skipSqlEditorTrivia(text, identifierEnd);
if (text[pos] === '(') {
const columnsEnd = skipBalancedSqlEditorParens(text, pos);
if (columnsEnd < 0) return '';
pos = skipSqlEditorTrivia(text, columnsEnd);
}
const asEnd = findTopLevelSqlEditorKeyword(text, pos, 'as');
if (asEnd < 0) return '';
pos = skipSqlEditorTrivia(text, asEnd);
const materialized = readSqlEditorKeyword(text, pos);
if (materialized.keyword === 'not') {
const next = readSqlEditorKeyword(text, materialized.end);
if (next.keyword === 'materialized') {
pos = next.end;
}
} else if (materialized.keyword === 'materialized') {
pos = materialized.end;
}
pos = skipSqlEditorTrivia(text, pos);
if (text[pos] !== '(') return '';
const cteEnd = skipBalancedSqlEditorParens(text, pos);
if (cteEnd < 0) return '';
pos = skipSqlEditorTrivia(text, cteEnd);
if (text[pos] === ',') {
pos++;
continue;
}
return readSqlEditorKeyword(text, pos).keyword;
}
return '';
};
export const resolveSqlEditorOperationKeyword = (statement: string): string => {
const text = String(statement || '');
const leading = readSqlEditorKeyword(text, 0);
if (leading.keyword !== 'with') {
return leading.keyword;
}
return resolveSqlEditorKeywordAfterWith(text, leading.end) || leading.keyword;
};
const isSqlEditorTransactionControlStatement = (statement: string): boolean => {
const keyword = readSqlEditorKeyword(String(statement || ''), 0).keyword;
if (SQL_EDITOR_TRANSACTION_CONTROL_KEYWORDS.has(keyword)) return true;
return keyword === 'start' && /\btransaction\b/i.test(statement);
};
export const shouldUseSqlEditorManagedTransaction = (statements: string[]): boolean => {
let hasManagedWrite = false;
for (const statement of statements) {
const trimmed = String(statement || '').trim();
if (!trimmed) continue;
if (isSqlEditorTransactionControlStatement(trimmed)) return false;
const keyword = resolveSqlEditorOperationKeyword(trimmed);
if (SQL_EDITOR_READ_KEYWORDS.has(keyword)) continue;
if (SQL_EDITOR_DML_KEYWORDS.has(keyword)) {
hasManagedWrite = true;
continue;
}
return false;
}
return hasManagedWrite;
};

View File

@@ -584,6 +584,46 @@ func TestDBQueryMultiTransactionalKeepsDMLTransactionOpenUntilCommit(t *testing.
}
}
func TestDBQueryMultiTransactionalTreatsWithDMLAsManagedWrite(t *testing.T) {
originalNewDatabaseFunc := newDatabaseFunc
t.Cleanup(func() {
newDatabaseFunc = originalNewDatabaseFunc
})
stmt := "WITH target AS (SELECT id FROM users WHERE active = 1) UPDATE users SET synced = 1 WHERE id IN (SELECT id FROM target)"
fakeDB := &fakeBatchWriteDB{
execAffected: map[string]int64{
stmt: 2,
},
}
newDatabaseFunc = func(dbType string) (db.Database, error) {
return fakeDB, nil
}
app := NewAppWithSecretStore(secretstore.NewUnavailableStore("test"))
config := connection.ConnectionConfig{Type: "postgres", Host: "127.0.0.1", Port: 5432, User: "postgres"}
result := app.DBQueryMultiTransactional(config, "main", stmt, "with-dml-query")
if !result.Success {
t.Fatalf("expected transactional WITH DML success, got failure: %s", result.Message)
}
if result.TransactionID == "" || !result.TransactionPending {
t.Fatalf("expected pending transaction metadata, got id=%q pending=%v", result.TransactionID, result.TransactionPending)
}
if fakeDB.session == nil || fakeDB.session.closed {
t.Fatal("expected WITH DML transaction session to stay open")
}
wantExecs := []string{"BEGIN", stmt}
if len(fakeDB.execQueries) != len(wantExecs) {
t.Fatalf("expected exec queries %#v, got %#v", wantExecs, fakeDB.execQueries)
}
for i, want := range wantExecs {
if fakeDB.execQueries[i] != want {
t.Fatalf("expected exec query %d = %q, got %q", i, want, fakeDB.execQueries[i])
}
}
}
func TestDBQueryMultiTransactionalRollsBackAndClosesOnDMLFailure(t *testing.T) {
originalNewDatabaseFunc := newDatabaseFunc
t.Cleanup(func() {

View File

@@ -52,12 +52,282 @@ func leadingSQLKeyword(query string) string {
return strings.ToLower(text)
}
func sqlDataOperationKeyword(query string) string {
keyword, keywordEnd := nextSQLKeyword(query, 0)
if keyword != "with" {
return keyword
}
if withKeyword, ok := sqlKeywordAfterLeadingWith(query, keywordEnd); ok {
return withKeyword
}
return keyword
}
func nextSQLKeyword(text string, start int) (string, int) {
pos := skipSQLTrivia(text, start)
if pos >= len(text) || !isSQLKeywordByte(text[pos]) {
return "", pos
}
end := pos + 1
for end < len(text) && isSQLKeywordByte(text[end]) {
end++
}
return strings.ToLower(text[pos:end]), end
}
func skipSQLTrivia(text string, start int) int {
pos := start
for pos < len(text) {
switch {
case text[pos] == ' ' || text[pos] == '\t' || text[pos] == '\r' || text[pos] == '\n' || text[pos] == '\f':
pos++
case strings.HasPrefix(text[pos:], "--"):
next := strings.IndexByte(text[pos:], '\n')
if next < 0 {
return len(text)
}
pos += next + 1
case strings.HasPrefix(text[pos:], "#"):
next := strings.IndexByte(text[pos:], '\n')
if next < 0 {
return len(text)
}
pos += next + 1
case strings.HasPrefix(text[pos:], "/*"):
end := strings.Index(text[pos+2:], "*/")
if end < 0 {
return len(text)
}
pos += end + 4
default:
return pos
}
}
return pos
}
func sqlKeywordAfterLeadingWith(text string, start int) (string, bool) {
pos := skipSQLTrivia(text, start)
if keyword, end := nextSQLKeyword(text, pos); keyword == "recursive" {
pos = end
}
for {
pos = skipSQLTrivia(text, pos)
next, ok := skipSQLIdentifierToken(text, pos)
if !ok {
return "", false
}
pos = skipSQLTrivia(text, next)
if pos < len(text) && text[pos] == '(' {
next = skipBalancedSQLParens(text, pos)
if next < 0 {
return "", false
}
pos = skipSQLTrivia(text, next)
}
asEnd := findTopLevelSQLKeyword(text, pos, "as")
if asEnd < 0 {
return "", false
}
pos = skipSQLTrivia(text, asEnd)
if keyword, end := nextSQLKeyword(text, pos); keyword == "not" {
if nextKeyword, nextEnd := nextSQLKeyword(text, end); nextKeyword == "materialized" {
pos = nextEnd
}
} else if keyword == "materialized" {
pos = end
}
pos = skipSQLTrivia(text, pos)
if pos >= len(text) || text[pos] != '(' {
return "", false
}
next = skipBalancedSQLParens(text, pos)
if next < 0 {
return "", false
}
pos = skipSQLTrivia(text, next)
if pos < len(text) && text[pos] == ',' {
pos++
continue
}
keyword, _ := nextSQLKeyword(text, pos)
return keyword, keyword != ""
}
}
func findTopLevelSQLKeyword(text string, start int, want string) int {
depth := 0
for pos := start; pos < len(text); {
if next, ok := skipSQLQuotedOrComment(text, pos); ok {
pos = next
continue
}
switch text[pos] {
case '(':
depth++
pos++
case ')':
if depth > 0 {
depth--
}
pos++
default:
if depth == 0 && isSQLKeywordByte(text[pos]) {
end := pos + 1
for end < len(text) && isSQLKeywordByte(text[end]) {
end++
}
if strings.EqualFold(text[pos:end], want) {
return end
}
pos = end
continue
}
pos++
}
}
return -1
}
func skipSQLIdentifierToken(text string, start int) (int, bool) {
if start >= len(text) {
return start, false
}
switch text[start] {
case '"', '`':
next := skipSQLDelimited(text, start, text[start])
return next, next > start
case '[':
next := strings.IndexByte(text[start+1:], ']')
if next < 0 {
return len(text), true
}
return start + next + 2, true
default:
if !isSQLKeywordByte(text[start]) {
return start, false
}
end := start + 1
for end < len(text) && isSQLKeywordByte(text[end]) {
end++
}
return end, true
}
}
func skipBalancedSQLParens(text string, start int) int {
if start >= len(text) || text[start] != '(' {
return -1
}
depth := 0
for pos := start; pos < len(text); {
if next, ok := skipSQLQuotedOrComment(text, pos); ok {
pos = next
continue
}
switch text[pos] {
case '(':
depth++
pos++
case ')':
depth--
pos++
if depth == 0 {
return pos
}
default:
pos++
}
}
return -1
}
func skipSQLQuotedOrComment(text string, start int) (int, bool) {
if start >= len(text) {
return start, false
}
switch {
case strings.HasPrefix(text[start:], "--"):
next := strings.IndexByte(text[start:], '\n')
if next < 0 {
return len(text), true
}
return start + next + 1, true
case strings.HasPrefix(text[start:], "#"):
next := strings.IndexByte(text[start:], '\n')
if next < 0 {
return len(text), true
}
return start + next + 1, true
case strings.HasPrefix(text[start:], "/*"):
end := strings.Index(text[start+2:], "*/")
if end < 0 {
return len(text), true
}
return start + end + 4, true
case text[start] == '\'' || text[start] == '"' || text[start] == '`':
return skipSQLDelimited(text, start, text[start]), true
case text[start] == '[':
next := strings.IndexByte(text[start+1:], ']')
if next < 0 {
return len(text), true
}
return start + next + 2, true
default:
if tag, ok := sqlDollarQuoteTag(text, start); ok {
end := strings.Index(text[start+len(tag):], tag)
if end < 0 {
return len(text), true
}
return start + len(tag) + end + len(tag), true
}
return start, false
}
}
func skipSQLDelimited(text string, start int, delimiter byte) int {
pos := start + 1
for pos < len(text) {
if text[pos] == delimiter {
if pos+1 < len(text) && text[pos+1] == delimiter {
pos += 2
continue
}
return pos + 1
}
pos++
}
return len(text)
}
func sqlDollarQuoteTag(text string, start int) (string, bool) {
if start >= len(text) || text[start] != '$' {
return "", false
}
end := start + 1
for end < len(text) && (isSQLKeywordByte(text[end]) || text[end] == '-') {
end++
}
if end < len(text) && text[end] == '$' {
return text[start : end+1], true
}
return "", false
}
func isSQLKeywordByte(ch byte) bool {
return (ch >= 'a' && ch <= 'z') || (ch >= 'A' && ch <= 'Z') || (ch >= '0' && ch <= '9') || ch == '_'
}
func isReadOnlySQLQuery(dbType string, query string) bool {
if strings.ToLower(strings.TrimSpace(dbType)) == "mongodb" && strings.HasPrefix(strings.TrimSpace(query), "{") {
return true
}
switch leadingSQLKeyword(query) {
switch sqlDataOperationKeyword(query) {
case "select", "with", "show", "describe", "desc", "explain", "pragma", "values":
return true
default:
@@ -70,7 +340,7 @@ func isBatchableWriteSQLStatement(dbType string, query string) bool {
return false
}
switch leadingSQLKeyword(query) {
switch sqlDataOperationKeyword(query) {
case "insert", "update", "delete", "replace", "merge", "upsert":
return true
default:

View File

@@ -69,10 +69,25 @@ func TestIsReadOnlySQLQuery_DoesNotTreatExecAsReadOnly(t *testing.T) {
}
}
func TestIsReadOnlySQLQuery_ClassifiesWithByTopLevelOperation(t *testing.T) {
readQuery := "WITH target AS (SELECT id FROM users WHERE active = 1) SELECT * FROM target"
if !isReadOnlySQLQuery("postgres", readQuery) {
t.Fatal("WITH ... SELECT should stay read-only")
}
writeQuery := "WITH target AS (SELECT id FROM users WHERE active = 1) UPDATE users SET synced = 1 WHERE id IN (SELECT id FROM target)"
if isReadOnlySQLQuery("postgres", writeQuery) {
t.Fatal("WITH ... UPDATE should not be treated as read-only")
}
}
func TestIsBatchableWriteSQLStatement_OnlyMatchesRealWriteStatements(t *testing.T) {
if !isBatchableWriteSQLStatement("mysql", "INSERT INTO demo(id) VALUES (1)") {
t.Fatal("expected INSERT to be treated as batchable write")
}
if !isBatchableWriteSQLStatement("postgres", "WITH target AS (SELECT id FROM users) DELETE FROM users WHERE id IN (SELECT id FROM target)") {
t.Fatal("expected WITH ... DELETE to be treated as batchable write")
}
if isBatchableWriteSQLStatement("sqlserver", "EXEC sp_who2") {
t.Fatal("EXEC should not be treated as batchable write")
}