mirror of
https://github.com/Syngnat/GoNavi.git
synced 2026-06-14 10:29:52 +08:00
🐛 fix(editor): 修正 SQL 编辑器 DML 事务提交语义
- SQL 编辑器 DML 固定进入托管事务 - 区分 WITH SELECT 和 WITH DML 的事务判定 - 调整提交方式文案并补充前后端回归测试
This commit is contained in:
@@ -2182,6 +2182,81 @@ describe('QueryEditor external SQL save', () => {
|
||||
expect(textContent(renderer!.root)).not.toContain('事务待提交');
|
||||
});
|
||||
|
||||
it('runs SQL editor WITH DML through a pending managed transaction', async () => {
|
||||
const sql = 'WITH target AS (SELECT id FROM users WHERE active = 1) UPDATE users SET synced = 1 WHERE id IN (SELECT id FROM target)';
|
||||
backendApp.DBQueryMultiTransactional.mockResolvedValueOnce({
|
||||
success: true,
|
||||
transactionId: 'tx-with-dml',
|
||||
transactionPending: true,
|
||||
data: [
|
||||
{ columns: ['affectedRows'], rows: [{ affectedRows: 2 }], statementIndex: 1 },
|
||||
],
|
||||
});
|
||||
|
||||
let renderer!: ReactTestRenderer;
|
||||
await act(async () => {
|
||||
renderer = create(<QueryEditor tab={createTab({ query: sql })} />);
|
||||
});
|
||||
|
||||
await act(async () => {
|
||||
await findButton(renderer!, '运行').props.onClick();
|
||||
});
|
||||
await act(async () => {
|
||||
await Promise.resolve();
|
||||
await Promise.resolve();
|
||||
});
|
||||
|
||||
expect(backendApp.DBQueryMultiTransactional).toHaveBeenCalledWith(
|
||||
expect.anything(),
|
||||
'main',
|
||||
expect.stringContaining('WITH target AS'),
|
||||
'query-1',
|
||||
);
|
||||
expect(backendApp.DBQueryMulti).not.toHaveBeenCalled();
|
||||
expect(textContent(renderer!.root)).toContain('事务待提交');
|
||||
|
||||
await act(async () => {
|
||||
await findExactButton(renderer!, '提交').props.onClick();
|
||||
});
|
||||
await act(async () => {
|
||||
await Promise.resolve();
|
||||
await Promise.resolve();
|
||||
});
|
||||
|
||||
expect(backendApp.DBCommitTransaction).toHaveBeenCalledWith('tx-with-dml');
|
||||
});
|
||||
|
||||
it('keeps SQL editor WITH SELECT on the regular query path', async () => {
|
||||
const sql = 'WITH target AS (SELECT id FROM users WHERE active = 1) SELECT * FROM target';
|
||||
backendApp.DBQueryMulti.mockResolvedValueOnce({
|
||||
success: true,
|
||||
data: [
|
||||
{ columns: ['id'], rows: [{ id: 1 }], statementIndex: 1 },
|
||||
],
|
||||
});
|
||||
|
||||
let renderer!: ReactTestRenderer;
|
||||
await act(async () => {
|
||||
renderer = create(<QueryEditor tab={createTab({ query: sql })} />);
|
||||
});
|
||||
|
||||
await act(async () => {
|
||||
await findButton(renderer!, '运行').props.onClick();
|
||||
});
|
||||
await act(async () => {
|
||||
await Promise.resolve();
|
||||
await Promise.resolve();
|
||||
});
|
||||
|
||||
expect(backendApp.DBQueryMulti).toHaveBeenCalledWith(
|
||||
expect.anything(),
|
||||
'main',
|
||||
expect.stringContaining('WITH target AS'),
|
||||
'query-1',
|
||||
);
|
||||
expect(backendApp.DBQueryMultiTransactional).not.toHaveBeenCalled();
|
||||
});
|
||||
|
||||
it('auto commits SQL editor DML transactions after the configured delay', async () => {
|
||||
vi.useFakeTimers();
|
||||
storeState.sqlEditorTransactionOptions = {
|
||||
@@ -3446,6 +3521,9 @@ describe('QueryEditor external SQL save', () => {
|
||||
expect(source).toContain('gn-v2-query-toolbar-max-rows-select');
|
||||
expect(source).toContain('gn-v2-query-toolbar-transaction-mode-select');
|
||||
expect(source).toContain('gn-v2-query-toolbar-transaction-delay-select');
|
||||
expect(source).toContain('这里仅选择该事务执行成功后的提交方式');
|
||||
expect(source).toContain("label: '提交:手动'");
|
||||
expect(source).toContain("label: '提交:自动'");
|
||||
expect(source).toContain('gn-v2-query-toolbar-action-group');
|
||||
expect(source).toContain('style={isV2Ui ? undefined : { width: 150 }}');
|
||||
expect(source).toContain('style={isV2Ui ? undefined : { width: 200 }}');
|
||||
|
||||
@@ -18,6 +18,7 @@ import { applyQueryAutoLimit } from '../utils/queryAutoLimit';
|
||||
import { extractQueryResultTableRef, type QueryResultTableRef } from '../utils/queryResultTable';
|
||||
import { quoteIdentPart } from '../utils/sql';
|
||||
import { formatSqlExecutionError } from '../utils/sqlErrorSemantics';
|
||||
import { shouldUseSqlEditorManagedTransaction } from '../utils/sqlEditorTransaction';
|
||||
import { findSqlStatementRanges, resolveCurrentSqlStatementRange, resolveExecutableSql } from '../utils/sqlStatementSelection';
|
||||
import { isMacLikePlatform } from '../utils/appearance';
|
||||
import { splitSidebarQualifiedName } from '../utils/sidebarLocate';
|
||||
@@ -750,9 +751,6 @@ const areSqlStatementListsEqual = (left: string[], right: string[]): boolean =>
|
||||
&& left.every((statement, index) => normalizeExecutedSqlKey(statement) === normalizeExecutedSqlKey(right[index]))
|
||||
);
|
||||
|
||||
const SQL_EDITOR_DML_KEYWORDS = new Set(['insert', 'update', 'delete', 'replace', 'merge', 'upsert']);
|
||||
const SQL_EDITOR_READ_KEYWORDS = new Set(['select', 'with', 'show', 'describe', 'desc', 'explain', 'pragma', 'values']);
|
||||
const SQL_EDITOR_TRANSACTION_CONTROL_KEYWORDS = new Set(['begin', 'commit', 'rollback', 'savepoint', 'release']);
|
||||
const SQL_EDITOR_AUTO_COMMIT_DELAY_OPTIONS = [
|
||||
{ value: 3000, label: '3 秒' },
|
||||
{ value: 5000, label: '5 秒' },
|
||||
@@ -760,50 +758,6 @@ const SQL_EDITOR_AUTO_COMMIT_DELAY_OPTIONS = [
|
||||
{ value: 30000, label: '30 秒' },
|
||||
];
|
||||
|
||||
const resolveLeadingSqlKeyword = (statement: string): string => {
|
||||
let text = String(statement || '').trim();
|
||||
while (text) {
|
||||
if (text.startsWith('--') || text.startsWith('#')) {
|
||||
const lineBreak = text.indexOf('\n');
|
||||
if (lineBreak < 0) return '';
|
||||
text = text.slice(lineBreak + 1).trimStart();
|
||||
continue;
|
||||
}
|
||||
if (text.startsWith('/*')) {
|
||||
const blockEnd = text.indexOf('*/');
|
||||
if (blockEnd < 0) return '';
|
||||
text = text.slice(blockEnd + 2).trimStart();
|
||||
continue;
|
||||
}
|
||||
break;
|
||||
}
|
||||
const match = text.match(/^([A-Za-z0-9_]+)/);
|
||||
return match?.[1]?.toLowerCase() || '';
|
||||
};
|
||||
|
||||
const isSqlEditorTransactionControlStatement = (statement: string): boolean => {
|
||||
const keyword = resolveLeadingSqlKeyword(statement);
|
||||
if (SQL_EDITOR_TRANSACTION_CONTROL_KEYWORDS.has(keyword)) return true;
|
||||
return keyword === 'start' && /\btransaction\b/i.test(statement);
|
||||
};
|
||||
|
||||
const shouldUseSqlEditorManagedTransaction = (statements: string[]): boolean => {
|
||||
let hasManagedWrite = false;
|
||||
for (const statement of statements) {
|
||||
const trimmed = String(statement || '').trim();
|
||||
if (!trimmed) continue;
|
||||
if (isSqlEditorTransactionControlStatement(trimmed)) return false;
|
||||
const keyword = resolveLeadingSqlKeyword(trimmed);
|
||||
if (SQL_EDITOR_READ_KEYWORDS.has(keyword)) continue;
|
||||
if (SQL_EDITOR_DML_KEYWORDS.has(keyword)) {
|
||||
hasManagedWrite = true;
|
||||
continue;
|
||||
}
|
||||
return false;
|
||||
}
|
||||
return hasManagedWrite;
|
||||
};
|
||||
|
||||
type PendingSqlEditorTransaction = {
|
||||
id: string;
|
||||
commitMode: 'manual' | 'auto';
|
||||
@@ -5369,15 +5323,15 @@ const QueryEditor: React.FC<{ tab: TabData; isActive?: boolean }> = ({ tab, isAc
|
||||
]}
|
||||
/>
|
||||
</Tooltip>
|
||||
<Tooltip title="SQL 编辑器直接执行 INSERT/UPDATE/DELETE 等增删改语句时启用事务;手动提交更安全,自动提交会在执行成功后按所选时间提交。">
|
||||
<Tooltip title="SQL 编辑器执行 INSERT/UPDATE/DELETE 等 DML 时始终启用事务;这里仅选择该事务执行成功后的提交方式。">
|
||||
<Select
|
||||
className={isV2Ui ? 'gn-v2-query-toolbar-select gn-v2-query-toolbar-transaction-mode-select' : undefined}
|
||||
style={isV2Ui ? undefined : { width: 128 }}
|
||||
value={sqlEditorCommitMode}
|
||||
onChange={(mode) => setSqlEditorTransactionOptions({ commitMode: mode === 'auto' ? 'auto' : 'manual' })}
|
||||
options={[
|
||||
{ label: '事务:手动', value: 'manual' },
|
||||
{ label: '事务:自动', value: 'auto' },
|
||||
{ label: '提交:手动', value: 'manual' },
|
||||
{ label: '提交:自动', value: 'auto' },
|
||||
]}
|
||||
/>
|
||||
</Tooltip>
|
||||
|
||||
41
frontend/src/utils/sqlEditorTransaction.test.ts
Normal file
41
frontend/src/utils/sqlEditorTransaction.test.ts
Normal file
@@ -0,0 +1,41 @@
|
||||
import { describe, expect, it } from 'vitest';
|
||||
|
||||
import {
|
||||
resolveSqlEditorOperationKeyword,
|
||||
shouldUseSqlEditorManagedTransaction,
|
||||
} from './sqlEditorTransaction';
|
||||
|
||||
describe('sqlEditorTransaction', () => {
|
||||
it('keeps regular DML in a managed transaction', () => {
|
||||
expect(shouldUseSqlEditorManagedTransaction(['UPDATE users SET name = "n" WHERE id = 1'])).toBe(true);
|
||||
expect(shouldUseSqlEditorManagedTransaction(['INSERT INTO users(id) VALUES (1)'])).toBe(true);
|
||||
expect(shouldUseSqlEditorManagedTransaction(['DELETE FROM users WHERE id = 1'])).toBe(true);
|
||||
});
|
||||
|
||||
it('classifies WITH statements by their top-level operation', () => {
|
||||
expect(resolveSqlEditorOperationKeyword('WITH target AS (SELECT id FROM users) SELECT * FROM target')).toBe('select');
|
||||
expect(resolveSqlEditorOperationKeyword('WITH target AS (SELECT id FROM users) UPDATE users SET synced = 1')).toBe('update');
|
||||
expect(resolveSqlEditorOperationKeyword('WITH target AS (SELECT id FROM users) DELETE FROM users WHERE id IN (SELECT id FROM target)')).toBe('delete');
|
||||
});
|
||||
|
||||
it('uses managed transactions for WITH DML but not WITH SELECT', () => {
|
||||
expect(shouldUseSqlEditorManagedTransaction([
|
||||
'WITH target AS (SELECT id FROM users) UPDATE users SET synced = 1 WHERE id IN (SELECT id FROM target)',
|
||||
])).toBe(true);
|
||||
expect(shouldUseSqlEditorManagedTransaction([
|
||||
'WITH target AS (SELECT id FROM users) SELECT * FROM target',
|
||||
])).toBe(false);
|
||||
});
|
||||
|
||||
it('does not wrap user-authored explicit transactions', () => {
|
||||
expect(shouldUseSqlEditorManagedTransaction([
|
||||
'BEGIN',
|
||||
'UPDATE users SET name = "n" WHERE id = 1',
|
||||
'COMMIT',
|
||||
])).toBe(false);
|
||||
expect(shouldUseSqlEditorManagedTransaction([
|
||||
'START TRANSACTION',
|
||||
'DELETE FROM users WHERE id = 1',
|
||||
])).toBe(false);
|
||||
});
|
||||
});
|
||||
246
frontend/src/utils/sqlEditorTransaction.ts
Normal file
246
frontend/src/utils/sqlEditorTransaction.ts
Normal file
@@ -0,0 +1,246 @@
|
||||
const SQL_EDITOR_DML_KEYWORDS = new Set(['insert', 'update', 'delete', 'replace', 'merge', 'upsert']);
|
||||
const SQL_EDITOR_READ_KEYWORDS = new Set(['select', 'with', 'show', 'describe', 'desc', 'explain', 'pragma', 'values']);
|
||||
const SQL_EDITOR_TRANSACTION_CONTROL_KEYWORDS = new Set(['begin', 'commit', 'rollback', 'savepoint', 'release']);
|
||||
|
||||
const isSqlEditorKeywordChar = (char: string | undefined): boolean => !!char && /[A-Za-z0-9_]/.test(char);
|
||||
|
||||
const skipSqlEditorTrivia = (text: string, start: number): number => {
|
||||
let pos = start;
|
||||
while (pos < text.length) {
|
||||
const char = text[pos];
|
||||
if (/\s/.test(char || '')) {
|
||||
pos++;
|
||||
continue;
|
||||
}
|
||||
if (text.startsWith('--', pos) || text.startsWith('#', pos)) {
|
||||
const nextLine = text.indexOf('\n', pos);
|
||||
if (nextLine < 0) return text.length;
|
||||
pos = nextLine + 1;
|
||||
continue;
|
||||
}
|
||||
if (text.startsWith('/*', pos)) {
|
||||
const blockEnd = text.indexOf('*/', pos + 2);
|
||||
if (blockEnd < 0) return text.length;
|
||||
pos = blockEnd + 2;
|
||||
continue;
|
||||
}
|
||||
return pos;
|
||||
}
|
||||
return pos;
|
||||
};
|
||||
|
||||
const readSqlEditorKeyword = (text: string, start: number): { keyword: string; end: number } => {
|
||||
const pos = skipSqlEditorTrivia(text, start);
|
||||
if (!isSqlEditorKeywordChar(text[pos])) {
|
||||
return { keyword: '', end: pos };
|
||||
}
|
||||
let end = pos + 1;
|
||||
while (isSqlEditorKeywordChar(text[end])) {
|
||||
end++;
|
||||
}
|
||||
return { keyword: text.slice(pos, end).toLowerCase(), end };
|
||||
};
|
||||
|
||||
const skipSqlEditorDelimited = (text: string, start: number, delimiter: string): number => {
|
||||
let pos = start + 1;
|
||||
while (pos < text.length) {
|
||||
if (text[pos] === delimiter) {
|
||||
if (text[pos + 1] === delimiter) {
|
||||
pos += 2;
|
||||
continue;
|
||||
}
|
||||
return pos + 1;
|
||||
}
|
||||
pos++;
|
||||
}
|
||||
return text.length;
|
||||
};
|
||||
|
||||
const resolveSqlEditorDollarQuoteTag = (text: string, start: number): string => {
|
||||
if (text[start] !== '$') return '';
|
||||
let end = start + 1;
|
||||
while (isSqlEditorKeywordChar(text[end])) {
|
||||
end++;
|
||||
}
|
||||
return text[end] === '$' ? text.slice(start, end + 1) : '';
|
||||
};
|
||||
|
||||
const skipSqlEditorQuotedOrComment = (text: string, start: number): number | null => {
|
||||
if (text.startsWith('--', start) || text.startsWith('#', start)) {
|
||||
const nextLine = text.indexOf('\n', start);
|
||||
return nextLine < 0 ? text.length : nextLine + 1;
|
||||
}
|
||||
if (text.startsWith('/*', start)) {
|
||||
const blockEnd = text.indexOf('*/', start + 2);
|
||||
return blockEnd < 0 ? text.length : blockEnd + 2;
|
||||
}
|
||||
const char = text[start];
|
||||
if (char === '\'' || char === '"' || char === '`') {
|
||||
return skipSqlEditorDelimited(text, start, char);
|
||||
}
|
||||
if (char === '[') {
|
||||
const bracketEnd = text.indexOf(']', start + 1);
|
||||
return bracketEnd < 0 ? text.length : bracketEnd + 1;
|
||||
}
|
||||
const dollarTag = resolveSqlEditorDollarQuoteTag(text, start);
|
||||
if (dollarTag) {
|
||||
const dollarEnd = text.indexOf(dollarTag, start + dollarTag.length);
|
||||
return dollarEnd < 0 ? text.length : dollarEnd + dollarTag.length;
|
||||
}
|
||||
return null;
|
||||
};
|
||||
|
||||
const skipBalancedSqlEditorParens = (text: string, start: number): number => {
|
||||
if (text[start] !== '(') return -1;
|
||||
let depth = 0;
|
||||
let pos = start;
|
||||
while (pos < text.length) {
|
||||
const skipped = skipSqlEditorQuotedOrComment(text, pos);
|
||||
if (skipped !== null) {
|
||||
pos = skipped;
|
||||
continue;
|
||||
}
|
||||
if (text[pos] === '(') {
|
||||
depth++;
|
||||
pos++;
|
||||
continue;
|
||||
}
|
||||
if (text[pos] === ')') {
|
||||
depth--;
|
||||
pos++;
|
||||
if (depth === 0) return pos;
|
||||
continue;
|
||||
}
|
||||
pos++;
|
||||
}
|
||||
return -1;
|
||||
};
|
||||
|
||||
const skipSqlEditorIdentifierToken = (text: string, start: number): number => {
|
||||
if (start >= text.length) return -1;
|
||||
const char = text[start];
|
||||
if (char === '"' || char === '`') return skipSqlEditorDelimited(text, start, char);
|
||||
if (char === '[') {
|
||||
const bracketEnd = text.indexOf(']', start + 1);
|
||||
return bracketEnd < 0 ? text.length : bracketEnd + 1;
|
||||
}
|
||||
if (!isSqlEditorKeywordChar(char)) return -1;
|
||||
let end = start + 1;
|
||||
while (isSqlEditorKeywordChar(text[end])) {
|
||||
end++;
|
||||
}
|
||||
return end;
|
||||
};
|
||||
|
||||
const findTopLevelSqlEditorKeyword = (text: string, start: number, keyword: string): number => {
|
||||
let depth = 0;
|
||||
let pos = start;
|
||||
while (pos < text.length) {
|
||||
const skipped = skipSqlEditorQuotedOrComment(text, pos);
|
||||
if (skipped !== null) {
|
||||
pos = skipped;
|
||||
continue;
|
||||
}
|
||||
if (text[pos] === '(') {
|
||||
depth++;
|
||||
pos++;
|
||||
continue;
|
||||
}
|
||||
if (text[pos] === ')') {
|
||||
if (depth > 0) depth--;
|
||||
pos++;
|
||||
continue;
|
||||
}
|
||||
if (depth === 0 && isSqlEditorKeywordChar(text[pos])) {
|
||||
let end = pos + 1;
|
||||
while (isSqlEditorKeywordChar(text[end])) {
|
||||
end++;
|
||||
}
|
||||
if (text.slice(pos, end).toLowerCase() === keyword) {
|
||||
return end;
|
||||
}
|
||||
pos = end;
|
||||
continue;
|
||||
}
|
||||
pos++;
|
||||
}
|
||||
return -1;
|
||||
};
|
||||
|
||||
const resolveSqlEditorKeywordAfterWith = (text: string, start: number): string => {
|
||||
let pos = skipSqlEditorTrivia(text, start);
|
||||
const recursive = readSqlEditorKeyword(text, pos);
|
||||
if (recursive.keyword === 'recursive') {
|
||||
pos = recursive.end;
|
||||
}
|
||||
|
||||
while (pos < text.length) {
|
||||
pos = skipSqlEditorTrivia(text, pos);
|
||||
const identifierEnd = skipSqlEditorIdentifierToken(text, pos);
|
||||
if (identifierEnd < 0) return '';
|
||||
pos = skipSqlEditorTrivia(text, identifierEnd);
|
||||
if (text[pos] === '(') {
|
||||
const columnsEnd = skipBalancedSqlEditorParens(text, pos);
|
||||
if (columnsEnd < 0) return '';
|
||||
pos = skipSqlEditorTrivia(text, columnsEnd);
|
||||
}
|
||||
|
||||
const asEnd = findTopLevelSqlEditorKeyword(text, pos, 'as');
|
||||
if (asEnd < 0) return '';
|
||||
pos = skipSqlEditorTrivia(text, asEnd);
|
||||
const materialized = readSqlEditorKeyword(text, pos);
|
||||
if (materialized.keyword === 'not') {
|
||||
const next = readSqlEditorKeyword(text, materialized.end);
|
||||
if (next.keyword === 'materialized') {
|
||||
pos = next.end;
|
||||
}
|
||||
} else if (materialized.keyword === 'materialized') {
|
||||
pos = materialized.end;
|
||||
}
|
||||
|
||||
pos = skipSqlEditorTrivia(text, pos);
|
||||
if (text[pos] !== '(') return '';
|
||||
const cteEnd = skipBalancedSqlEditorParens(text, pos);
|
||||
if (cteEnd < 0) return '';
|
||||
pos = skipSqlEditorTrivia(text, cteEnd);
|
||||
if (text[pos] === ',') {
|
||||
pos++;
|
||||
continue;
|
||||
}
|
||||
|
||||
return readSqlEditorKeyword(text, pos).keyword;
|
||||
}
|
||||
return '';
|
||||
};
|
||||
|
||||
export const resolveSqlEditorOperationKeyword = (statement: string): string => {
|
||||
const text = String(statement || '');
|
||||
const leading = readSqlEditorKeyword(text, 0);
|
||||
if (leading.keyword !== 'with') {
|
||||
return leading.keyword;
|
||||
}
|
||||
return resolveSqlEditorKeywordAfterWith(text, leading.end) || leading.keyword;
|
||||
};
|
||||
|
||||
const isSqlEditorTransactionControlStatement = (statement: string): boolean => {
|
||||
const keyword = readSqlEditorKeyword(String(statement || ''), 0).keyword;
|
||||
if (SQL_EDITOR_TRANSACTION_CONTROL_KEYWORDS.has(keyword)) return true;
|
||||
return keyword === 'start' && /\btransaction\b/i.test(statement);
|
||||
};
|
||||
|
||||
export const shouldUseSqlEditorManagedTransaction = (statements: string[]): boolean => {
|
||||
let hasManagedWrite = false;
|
||||
for (const statement of statements) {
|
||||
const trimmed = String(statement || '').trim();
|
||||
if (!trimmed) continue;
|
||||
if (isSqlEditorTransactionControlStatement(trimmed)) return false;
|
||||
const keyword = resolveSqlEditorOperationKeyword(trimmed);
|
||||
if (SQL_EDITOR_READ_KEYWORDS.has(keyword)) continue;
|
||||
if (SQL_EDITOR_DML_KEYWORDS.has(keyword)) {
|
||||
hasManagedWrite = true;
|
||||
continue;
|
||||
}
|
||||
return false;
|
||||
}
|
||||
return hasManagedWrite;
|
||||
};
|
||||
@@ -584,6 +584,46 @@ func TestDBQueryMultiTransactionalKeepsDMLTransactionOpenUntilCommit(t *testing.
|
||||
}
|
||||
}
|
||||
|
||||
func TestDBQueryMultiTransactionalTreatsWithDMLAsManagedWrite(t *testing.T) {
|
||||
originalNewDatabaseFunc := newDatabaseFunc
|
||||
t.Cleanup(func() {
|
||||
newDatabaseFunc = originalNewDatabaseFunc
|
||||
})
|
||||
|
||||
stmt := "WITH target AS (SELECT id FROM users WHERE active = 1) UPDATE users SET synced = 1 WHERE id IN (SELECT id FROM target)"
|
||||
fakeDB := &fakeBatchWriteDB{
|
||||
execAffected: map[string]int64{
|
||||
stmt: 2,
|
||||
},
|
||||
}
|
||||
newDatabaseFunc = func(dbType string) (db.Database, error) {
|
||||
return fakeDB, nil
|
||||
}
|
||||
|
||||
app := NewAppWithSecretStore(secretstore.NewUnavailableStore("test"))
|
||||
config := connection.ConnectionConfig{Type: "postgres", Host: "127.0.0.1", Port: 5432, User: "postgres"}
|
||||
|
||||
result := app.DBQueryMultiTransactional(config, "main", stmt, "with-dml-query")
|
||||
if !result.Success {
|
||||
t.Fatalf("expected transactional WITH DML success, got failure: %s", result.Message)
|
||||
}
|
||||
if result.TransactionID == "" || !result.TransactionPending {
|
||||
t.Fatalf("expected pending transaction metadata, got id=%q pending=%v", result.TransactionID, result.TransactionPending)
|
||||
}
|
||||
if fakeDB.session == nil || fakeDB.session.closed {
|
||||
t.Fatal("expected WITH DML transaction session to stay open")
|
||||
}
|
||||
wantExecs := []string{"BEGIN", stmt}
|
||||
if len(fakeDB.execQueries) != len(wantExecs) {
|
||||
t.Fatalf("expected exec queries %#v, got %#v", wantExecs, fakeDB.execQueries)
|
||||
}
|
||||
for i, want := range wantExecs {
|
||||
if fakeDB.execQueries[i] != want {
|
||||
t.Fatalf("expected exec query %d = %q, got %q", i, want, fakeDB.execQueries[i])
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestDBQueryMultiTransactionalRollsBackAndClosesOnDMLFailure(t *testing.T) {
|
||||
originalNewDatabaseFunc := newDatabaseFunc
|
||||
t.Cleanup(func() {
|
||||
|
||||
@@ -52,12 +52,282 @@ func leadingSQLKeyword(query string) string {
|
||||
return strings.ToLower(text)
|
||||
}
|
||||
|
||||
func sqlDataOperationKeyword(query string) string {
|
||||
keyword, keywordEnd := nextSQLKeyword(query, 0)
|
||||
if keyword != "with" {
|
||||
return keyword
|
||||
}
|
||||
if withKeyword, ok := sqlKeywordAfterLeadingWith(query, keywordEnd); ok {
|
||||
return withKeyword
|
||||
}
|
||||
return keyword
|
||||
}
|
||||
|
||||
func nextSQLKeyword(text string, start int) (string, int) {
|
||||
pos := skipSQLTrivia(text, start)
|
||||
if pos >= len(text) || !isSQLKeywordByte(text[pos]) {
|
||||
return "", pos
|
||||
}
|
||||
end := pos + 1
|
||||
for end < len(text) && isSQLKeywordByte(text[end]) {
|
||||
end++
|
||||
}
|
||||
return strings.ToLower(text[pos:end]), end
|
||||
}
|
||||
|
||||
func skipSQLTrivia(text string, start int) int {
|
||||
pos := start
|
||||
for pos < len(text) {
|
||||
switch {
|
||||
case text[pos] == ' ' || text[pos] == '\t' || text[pos] == '\r' || text[pos] == '\n' || text[pos] == '\f':
|
||||
pos++
|
||||
case strings.HasPrefix(text[pos:], "--"):
|
||||
next := strings.IndexByte(text[pos:], '\n')
|
||||
if next < 0 {
|
||||
return len(text)
|
||||
}
|
||||
pos += next + 1
|
||||
case strings.HasPrefix(text[pos:], "#"):
|
||||
next := strings.IndexByte(text[pos:], '\n')
|
||||
if next < 0 {
|
||||
return len(text)
|
||||
}
|
||||
pos += next + 1
|
||||
case strings.HasPrefix(text[pos:], "/*"):
|
||||
end := strings.Index(text[pos+2:], "*/")
|
||||
if end < 0 {
|
||||
return len(text)
|
||||
}
|
||||
pos += end + 4
|
||||
default:
|
||||
return pos
|
||||
}
|
||||
}
|
||||
return pos
|
||||
}
|
||||
|
||||
func sqlKeywordAfterLeadingWith(text string, start int) (string, bool) {
|
||||
pos := skipSQLTrivia(text, start)
|
||||
if keyword, end := nextSQLKeyword(text, pos); keyword == "recursive" {
|
||||
pos = end
|
||||
}
|
||||
|
||||
for {
|
||||
pos = skipSQLTrivia(text, pos)
|
||||
next, ok := skipSQLIdentifierToken(text, pos)
|
||||
if !ok {
|
||||
return "", false
|
||||
}
|
||||
pos = skipSQLTrivia(text, next)
|
||||
if pos < len(text) && text[pos] == '(' {
|
||||
next = skipBalancedSQLParens(text, pos)
|
||||
if next < 0 {
|
||||
return "", false
|
||||
}
|
||||
pos = skipSQLTrivia(text, next)
|
||||
}
|
||||
|
||||
asEnd := findTopLevelSQLKeyword(text, pos, "as")
|
||||
if asEnd < 0 {
|
||||
return "", false
|
||||
}
|
||||
pos = skipSQLTrivia(text, asEnd)
|
||||
if keyword, end := nextSQLKeyword(text, pos); keyword == "not" {
|
||||
if nextKeyword, nextEnd := nextSQLKeyword(text, end); nextKeyword == "materialized" {
|
||||
pos = nextEnd
|
||||
}
|
||||
} else if keyword == "materialized" {
|
||||
pos = end
|
||||
}
|
||||
|
||||
pos = skipSQLTrivia(text, pos)
|
||||
if pos >= len(text) || text[pos] != '(' {
|
||||
return "", false
|
||||
}
|
||||
next = skipBalancedSQLParens(text, pos)
|
||||
if next < 0 {
|
||||
return "", false
|
||||
}
|
||||
pos = skipSQLTrivia(text, next)
|
||||
if pos < len(text) && text[pos] == ',' {
|
||||
pos++
|
||||
continue
|
||||
}
|
||||
|
||||
keyword, _ := nextSQLKeyword(text, pos)
|
||||
return keyword, keyword != ""
|
||||
}
|
||||
}
|
||||
|
||||
func findTopLevelSQLKeyword(text string, start int, want string) int {
|
||||
depth := 0
|
||||
for pos := start; pos < len(text); {
|
||||
if next, ok := skipSQLQuotedOrComment(text, pos); ok {
|
||||
pos = next
|
||||
continue
|
||||
}
|
||||
switch text[pos] {
|
||||
case '(':
|
||||
depth++
|
||||
pos++
|
||||
case ')':
|
||||
if depth > 0 {
|
||||
depth--
|
||||
}
|
||||
pos++
|
||||
default:
|
||||
if depth == 0 && isSQLKeywordByte(text[pos]) {
|
||||
end := pos + 1
|
||||
for end < len(text) && isSQLKeywordByte(text[end]) {
|
||||
end++
|
||||
}
|
||||
if strings.EqualFold(text[pos:end], want) {
|
||||
return end
|
||||
}
|
||||
pos = end
|
||||
continue
|
||||
}
|
||||
pos++
|
||||
}
|
||||
}
|
||||
return -1
|
||||
}
|
||||
|
||||
func skipSQLIdentifierToken(text string, start int) (int, bool) {
|
||||
if start >= len(text) {
|
||||
return start, false
|
||||
}
|
||||
switch text[start] {
|
||||
case '"', '`':
|
||||
next := skipSQLDelimited(text, start, text[start])
|
||||
return next, next > start
|
||||
case '[':
|
||||
next := strings.IndexByte(text[start+1:], ']')
|
||||
if next < 0 {
|
||||
return len(text), true
|
||||
}
|
||||
return start + next + 2, true
|
||||
default:
|
||||
if !isSQLKeywordByte(text[start]) {
|
||||
return start, false
|
||||
}
|
||||
end := start + 1
|
||||
for end < len(text) && isSQLKeywordByte(text[end]) {
|
||||
end++
|
||||
}
|
||||
return end, true
|
||||
}
|
||||
}
|
||||
|
||||
func skipBalancedSQLParens(text string, start int) int {
|
||||
if start >= len(text) || text[start] != '(' {
|
||||
return -1
|
||||
}
|
||||
depth := 0
|
||||
for pos := start; pos < len(text); {
|
||||
if next, ok := skipSQLQuotedOrComment(text, pos); ok {
|
||||
pos = next
|
||||
continue
|
||||
}
|
||||
switch text[pos] {
|
||||
case '(':
|
||||
depth++
|
||||
pos++
|
||||
case ')':
|
||||
depth--
|
||||
pos++
|
||||
if depth == 0 {
|
||||
return pos
|
||||
}
|
||||
default:
|
||||
pos++
|
||||
}
|
||||
}
|
||||
return -1
|
||||
}
|
||||
|
||||
func skipSQLQuotedOrComment(text string, start int) (int, bool) {
|
||||
if start >= len(text) {
|
||||
return start, false
|
||||
}
|
||||
switch {
|
||||
case strings.HasPrefix(text[start:], "--"):
|
||||
next := strings.IndexByte(text[start:], '\n')
|
||||
if next < 0 {
|
||||
return len(text), true
|
||||
}
|
||||
return start + next + 1, true
|
||||
case strings.HasPrefix(text[start:], "#"):
|
||||
next := strings.IndexByte(text[start:], '\n')
|
||||
if next < 0 {
|
||||
return len(text), true
|
||||
}
|
||||
return start + next + 1, true
|
||||
case strings.HasPrefix(text[start:], "/*"):
|
||||
end := strings.Index(text[start+2:], "*/")
|
||||
if end < 0 {
|
||||
return len(text), true
|
||||
}
|
||||
return start + end + 4, true
|
||||
case text[start] == '\'' || text[start] == '"' || text[start] == '`':
|
||||
return skipSQLDelimited(text, start, text[start]), true
|
||||
case text[start] == '[':
|
||||
next := strings.IndexByte(text[start+1:], ']')
|
||||
if next < 0 {
|
||||
return len(text), true
|
||||
}
|
||||
return start + next + 2, true
|
||||
default:
|
||||
if tag, ok := sqlDollarQuoteTag(text, start); ok {
|
||||
end := strings.Index(text[start+len(tag):], tag)
|
||||
if end < 0 {
|
||||
return len(text), true
|
||||
}
|
||||
return start + len(tag) + end + len(tag), true
|
||||
}
|
||||
return start, false
|
||||
}
|
||||
}
|
||||
|
||||
func skipSQLDelimited(text string, start int, delimiter byte) int {
|
||||
pos := start + 1
|
||||
for pos < len(text) {
|
||||
if text[pos] == delimiter {
|
||||
if pos+1 < len(text) && text[pos+1] == delimiter {
|
||||
pos += 2
|
||||
continue
|
||||
}
|
||||
return pos + 1
|
||||
}
|
||||
pos++
|
||||
}
|
||||
return len(text)
|
||||
}
|
||||
|
||||
func sqlDollarQuoteTag(text string, start int) (string, bool) {
|
||||
if start >= len(text) || text[start] != '$' {
|
||||
return "", false
|
||||
}
|
||||
end := start + 1
|
||||
for end < len(text) && (isSQLKeywordByte(text[end]) || text[end] == '-') {
|
||||
end++
|
||||
}
|
||||
if end < len(text) && text[end] == '$' {
|
||||
return text[start : end+1], true
|
||||
}
|
||||
return "", false
|
||||
}
|
||||
|
||||
func isSQLKeywordByte(ch byte) bool {
|
||||
return (ch >= 'a' && ch <= 'z') || (ch >= 'A' && ch <= 'Z') || (ch >= '0' && ch <= '9') || ch == '_'
|
||||
}
|
||||
|
||||
func isReadOnlySQLQuery(dbType string, query string) bool {
|
||||
if strings.ToLower(strings.TrimSpace(dbType)) == "mongodb" && strings.HasPrefix(strings.TrimSpace(query), "{") {
|
||||
return true
|
||||
}
|
||||
|
||||
switch leadingSQLKeyword(query) {
|
||||
switch sqlDataOperationKeyword(query) {
|
||||
case "select", "with", "show", "describe", "desc", "explain", "pragma", "values":
|
||||
return true
|
||||
default:
|
||||
@@ -70,7 +340,7 @@ func isBatchableWriteSQLStatement(dbType string, query string) bool {
|
||||
return false
|
||||
}
|
||||
|
||||
switch leadingSQLKeyword(query) {
|
||||
switch sqlDataOperationKeyword(query) {
|
||||
case "insert", "update", "delete", "replace", "merge", "upsert":
|
||||
return true
|
||||
default:
|
||||
|
||||
@@ -69,10 +69,25 @@ func TestIsReadOnlySQLQuery_DoesNotTreatExecAsReadOnly(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestIsReadOnlySQLQuery_ClassifiesWithByTopLevelOperation(t *testing.T) {
|
||||
readQuery := "WITH target AS (SELECT id FROM users WHERE active = 1) SELECT * FROM target"
|
||||
if !isReadOnlySQLQuery("postgres", readQuery) {
|
||||
t.Fatal("WITH ... SELECT should stay read-only")
|
||||
}
|
||||
|
||||
writeQuery := "WITH target AS (SELECT id FROM users WHERE active = 1) UPDATE users SET synced = 1 WHERE id IN (SELECT id FROM target)"
|
||||
if isReadOnlySQLQuery("postgres", writeQuery) {
|
||||
t.Fatal("WITH ... UPDATE should not be treated as read-only")
|
||||
}
|
||||
}
|
||||
|
||||
func TestIsBatchableWriteSQLStatement_OnlyMatchesRealWriteStatements(t *testing.T) {
|
||||
if !isBatchableWriteSQLStatement("mysql", "INSERT INTO demo(id) VALUES (1)") {
|
||||
t.Fatal("expected INSERT to be treated as batchable write")
|
||||
}
|
||||
if !isBatchableWriteSQLStatement("postgres", "WITH target AS (SELECT id FROM users) DELETE FROM users WHERE id IN (SELECT id FROM target)") {
|
||||
t.Fatal("expected WITH ... DELETE to be treated as batchable write")
|
||||
}
|
||||
if isBatchableWriteSQLStatement("sqlserver", "EXEC sp_who2") {
|
||||
t.Fatal("EXEC should not be treated as batchable write")
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user