From 8f7c79070051540690ec9173e2b318a7939a1466 Mon Sep 17 00:00:00 2001 From: Syngnat Date: Thu, 4 Jun 2026 15:46:09 +0800 Subject: [PATCH] =?UTF-8?q?=F0=9F=90=9B=20fix(sql-editor):=20=E4=BF=AE?= =?UTF-8?q?=E5=A4=8D=E5=AD=98=E5=82=A8=E8=BF=87=E7=A8=8B=E5=AE=9A=E4=B9=89?= =?UTF-8?q?=E6=89=A7=E8=A1=8C=E6=88=AA=E6=96=AD?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- frontend/src/components/QueryEditor.tsx | 225 +----------------- .../src/utils/sqlStatementSelection.test.ts | 66 +++++ frontend/src/utils/sqlStatementSelection.ts | 47 ++++ internal/app/methods_db_multi_test.go | 46 ++++ .../app/methods_file_sql_execution_test.go | 49 ++++ internal/app/sql_split.go | 60 ++++- internal/app/sql_split_stream.go | 8 +- internal/app/sql_split_test.go | 56 +++++ 8 files changed, 328 insertions(+), 229 deletions(-) diff --git a/frontend/src/components/QueryEditor.tsx b/frontend/src/components/QueryEditor.tsx index 8449fe5..cc992b2 100644 --- a/frontend/src/components/QueryEditor.tsx +++ b/frontend/src/components/QueryEditor.tsx @@ -18,7 +18,7 @@ import { applyQueryAutoLimit } from '../utils/queryAutoLimit'; import { extractQueryResultTableRef, type QueryResultTableRef } from '../utils/queryResultTable'; import { quoteIdentPart } from '../utils/sql'; import { formatSqlExecutionError } from '../utils/sqlErrorSemantics'; -import { resolveCurrentSqlStatementRange, resolveExecutableSql } from '../utils/sqlStatementSelection'; +import { findSqlStatementRanges, resolveCurrentSqlStatementRange, resolveExecutableSql } from '../utils/sqlStatementSelection'; import { isMacLikePlatform } from '../utils/appearance'; import { splitSidebarQualifiedName } from '../utils/sidebarLocate'; import { normalizeSidebarViewName } from '../utils/sidebarMetadata'; @@ -741,65 +741,6 @@ const areSqlStatementListsEqual = (left: string[], right: string[]): boolean => && left.every((statement, index) => normalizeExecutedSqlKey(statement) === normalizeExecutedSqlKey(right[index])) ); -const isSqlIdentifierStart = (ch: string): boolean => /^[A-Za-z_]$/.test(ch); - -const isSqlIdentifierPart = (ch: string): boolean => /^[A-Za-z0-9_$#]$/.test(ch); - -const skipSqlWhitespaceAndComments = (text: string, position: number): number => { - let index = position; - while (index < text.length) { - const ch = text[index]; - const next = index + 1 < text.length ? text[index + 1] : ''; - if (ch === ' ' || ch === '\t' || ch === '\n' || ch === '\r' || ch === '\f') { - index += 1; - continue; - } - if (ch === '-' && next === '-') { - index += 2; - while (index < text.length && text[index] !== '\n') index += 1; - continue; - } - if (ch === '/' && next === '*') { - index += 2; - while (index + 1 < text.length && !(text[index] === '*' && text[index + 1] === '/')) { - index += 1; - } - if (index + 1 < text.length) index += 2; - continue; - } - break; - } - return index; -}; - -const nextSqlSignificantToken = (text: string, position: number): string => { - const index = skipSqlWhitespaceAndComments(text, position); - if (index >= text.length || !isSqlIdentifierStart(text[index])) return ''; - let end = index + 1; - while (end < text.length && isSqlIdentifierPart(text[end])) end += 1; - return text.slice(index, end).toLowerCase(); -}; - -const nextSqlSignificantChar = (text: string, position: number): string => { - const index = skipSqlWhitespaceAndComments(text, position); - return index >= text.length ? '' : text[index]; -}; - -const shouldEnterPlsqlBeginBlock = (text: string, tokenEnd: number): boolean => { - const nextChar = nextSqlSignificantChar(text, tokenEnd); - if (!nextChar || nextChar === ';') return false; - return !['transaction', 'work', 'isolation', 'read', 'write'].includes(nextSqlSignificantToken(text, tokenEnd)); -}; - -const shouldEnterPlsqlDeclareBlock = (text: string, tokenEnd: number): boolean => { - const nextToken = nextSqlSignificantToken(text, tokenEnd); - return Boolean(nextToken); -}; - -const isPlsqlControlEnd = (text: string, tokenEnd: number): boolean => ( - ['if', 'loop', 'case'].includes(nextSqlSignificantToken(text, tokenEnd)) -); - const normalizeEditorPosition = (position: any): { lineNumber: number; column: number } | null => { if (!position) return null; const lineNumber = Number(position.positionLineNumber ?? position.lineNumber ?? position.endLineNumber ?? position.startLineNumber ?? position.selectionStartLineNumber); @@ -3834,169 +3775,7 @@ const QueryEditor: React.FC<{ tab: TabData; isActive?: boolean }> = ({ tab, isAc ]; const splitSQLStatements = (sql: string): string[] => { - const text = (sql || '').replace(/\r\n/g, '\n'); - const statements: string[] = []; - - let cur = ''; - let inSingle = false; - let inDouble = false; - let inBacktick = false; - let escaped = false; - let inLineComment = false; - let inBlockComment = false; - let dollarTag: string | null = null; // postgres/kingbase: $$...$$ or $tag$...$tag$ - let plsqlDepth = 0; - let plsqlDeclareBeginSkips = 0; - let justClosedPLSQLBlock = false; - - const push = () => { - const s = cur.trim(); - if (s) statements.push(s); - cur = ''; - }; - - const isWS = (ch: string) => ch === ' ' || ch === '\t' || ch === '\n' || ch === '\r'; - - for (let i = 0; i < text.length; i++) { - const ch = text[i]; - const next = i + 1 < text.length ? text[i + 1] : ''; - const prev = i > 0 ? text[i - 1] : ''; - const next2 = i + 2 < text.length ? text[i + 2] : ''; - - if (!inSingle && !inDouble && !inBacktick) { - if (inLineComment) { - cur += ch; - if (ch === '\n') inLineComment = false; - continue; - } - - if (inBlockComment) { - cur += ch; - if (ch === '*' && next === '/') { - cur += next; - i++; - inBlockComment = false; - } - continue; - } - - // Start comments - if (ch === '/' && next === '*') { - cur += ch + next; - i++; - inBlockComment = true; - continue; - } - if (ch === '#') { - cur += ch; - inLineComment = true; - continue; - } - if (ch === '-' && next === '-' && (i === 0 || isWS(prev)) && (next2 === '' || isWS(next2))) { - cur += ch + next; - i++; - inLineComment = true; - continue; - } - - // Dollar-quoted strings (PG/Kingbase) - if (dollarTag) { - if (text.startsWith(dollarTag, i)) { - cur += dollarTag; - i += dollarTag.length - 1; - dollarTag = null; - } else { - cur += ch; - } - continue; - } - if (ch === '$') { - const m = text.slice(i).match(/^\$[A-Za-z0-9_]*\$/); - if (m && m[0]) { - dollarTag = m[0]; - cur += dollarTag; - i += dollarTag.length - 1; - continue; - } - } - } - - if (escaped) { - cur += ch; - escaped = false; - continue; - } - - if ((inSingle || inDouble) && ch === '\\') { - cur += ch; - escaped = true; - continue; - } - - if (!inDouble && !inBacktick && ch === '\'') { - inSingle = !inSingle; - cur += ch; - continue; - } - if (!inSingle && !inBacktick && ch === '"') { - inDouble = !inDouble; - cur += ch; - continue; - } - if (!inSingle && !inDouble && ch === '`') { - inBacktick = !inBacktick; - cur += ch; - continue; - } - - if (!inSingle && !inDouble && !inBacktick && !dollarTag && isSqlIdentifierStart(ch)) { - let end = i + 1; - while (end < text.length && isSqlIdentifierPart(text[end])) { - end += 1; - } - const token = text.slice(i, end).toLowerCase(); - if (token === 'begin' && plsqlDeclareBeginSkips > 0) { - plsqlDeclareBeginSkips -= 1; - justClosedPLSQLBlock = false; - } else if (token === 'begin' && shouldEnterPlsqlBeginBlock(text, end)) { - plsqlDepth += 1; - justClosedPLSQLBlock = false; - } else if (token === 'declare' && shouldEnterPlsqlDeclareBlock(text, end)) { - plsqlDepth += 1; - plsqlDeclareBeginSkips += 1; - justClosedPLSQLBlock = false; - } else if (token === 'end' && plsqlDepth > 0 && !isPlsqlControlEnd(text, end)) { - plsqlDepth -= 1; - if (plsqlDeclareBeginSkips > plsqlDepth) { - plsqlDeclareBeginSkips = plsqlDepth; - } - justClosedPLSQLBlock = plsqlDepth === 0; - } - cur += text.slice(i, end); - i = end - 1; - continue; - } - - if (!inSingle && !inDouble && !inBacktick && !dollarTag && (ch === ';' || ch === ';')) { - if (plsqlDepth > 0) { - cur += ch; - continue; - } - if (justClosedPLSQLBlock) { - cur += ch; - push(); - justClosedPLSQLBlock = false; - continue; - } - push(); - continue; - } - - cur += ch; - } - - push(); - return statements; + return findSqlStatementRanges(sql).map((range) => range.text); }; const getSelectedSQL = (): string => { diff --git a/frontend/src/utils/sqlStatementSelection.test.ts b/frontend/src/utils/sqlStatementSelection.test.ts index a21f749..fb7788a 100644 --- a/frontend/src/utils/sqlStatementSelection.test.ts +++ b/frontend/src/utils/sqlStatementSelection.test.ts @@ -86,6 +86,72 @@ describe('sqlStatementSelection', () => { }); }); + it('keeps Oracle CREATE PROCEDURE definitions as one executable statement', () => { + const sql = [ + 'CREATE OR REPLACE PROCEDURE proc_tally2accept(', + ' p_tallyacceptno IN t_tally_accept_h.acceptno%TYPE,', + ' out_acceptno OUT t_accept_h.acceptno%TYPE', + ') IS', + ' v_busno t_tally_accept_h.busno%TYPE;', + ' v_count PLS_INTEGER;', + 'BEGIN', + " SELECT COUNT(*) INTO v_count FROM t_tally_accept_h WHERE acceptno = p_tallyacceptno;", + ' IF v_count > 0 THEN', + ' out_acceptno := p_tallyacceptno;', + ' END IF;', + 'END;', + 'SELECT 1 FROM dual;', + ].join('\n'); + + const ranges = findSqlStatementRanges(sql).map((range) => range.text); + + expect(ranges).toEqual([ + [ + 'CREATE OR REPLACE PROCEDURE proc_tally2accept(', + ' p_tallyacceptno IN t_tally_accept_h.acceptno%TYPE,', + ' out_acceptno OUT t_accept_h.acceptno%TYPE', + ') IS', + ' v_busno t_tally_accept_h.busno%TYPE;', + ' v_count PLS_INTEGER;', + 'BEGIN', + " SELECT COUNT(*) INTO v_count FROM t_tally_accept_h WHERE acceptno = p_tallyacceptno;", + ' IF v_count > 0 THEN', + ' out_acceptno := p_tallyacceptno;', + ' END IF;', + 'END;', + ].join('\n'), + 'SELECT 1 FROM dual', + ]); + expect(resolveExecutableSql(sql, sql.indexOf('v_busno'))).toEqual({ + sql: ranges[0], + source: 'statement', + }); + }); + + it('keeps PostgreSQL dollar-quoted CREATE FUNCTION definitions as one executable statement', () => { + const sql = [ + 'CREATE OR REPLACE FUNCTION refresh_stats() RETURNS void AS $$', + 'BEGIN', + ' PERFORM refresh_now();', + 'END;', + '$$ LANGUAGE plpgsql;', + 'SELECT 2;', + ].join('\n'); + + const ranges = findSqlStatementRanges(sql).map((range) => range.text); + + expect(ranges).toEqual([ + [ + 'CREATE OR REPLACE FUNCTION refresh_stats() RETURNS void AS $$', + 'BEGIN', + ' PERFORM refresh_now();', + 'END;', + '$$ LANGUAGE plpgsql', + ].join('\n'), + 'SELECT 2', + ]); + }); + it('still splits transaction BEGIN statements', () => { const sql = 'BEGIN; UPDATE accounts SET balance = balance - 1 WHERE id = 1; COMMIT;'; diff --git a/frontend/src/utils/sqlStatementSelection.ts b/frontend/src/utils/sqlStatementSelection.ts index 8a3dfe8..81a2d35 100644 --- a/frontend/src/utils/sqlStatementSelection.ts +++ b/frontend/src/utils/sqlStatementSelection.ts @@ -67,6 +67,49 @@ const shouldEnterPlsqlBeginBlock = (text: string, tokenEnd: number): boolean => const shouldEnterPlsqlDeclareBlock = (text: string, tokenEnd: number): boolean => Boolean(nextSqlSignificantToken(text, tokenEnd)); +const nextSqlSignificantTokenSpan = (text: string, position: number): { token: string; end: number } => { + const index = skipSqlWhitespaceAndComments(text, position); + if (index >= text.length || !isSqlIdentifierStart(text[index])) { + return { token: '', end: index }; + } + let end = index + 1; + while (end < text.length && isSqlIdentifierPart(text[end])) end += 1; + return { token: text.slice(index, end).toLowerCase(), end }; +}; + +const isCreateRoutineHeaderPrefix = (text: string): boolean => { + let current = nextSqlSignificantTokenSpan(text, 0); + if (current.token !== 'create') return false; + + current = nextSqlSignificantTokenSpan(text, current.end); + if (current.token === 'or') { + current = nextSqlSignificantTokenSpan(text, current.end); + if (current.token !== 'replace') return false; + current = nextSqlSignificantTokenSpan(text, current.end); + } + + while (['editionable', 'noneditionable'].includes(current.token)) { + current = nextSqlSignificantTokenSpan(text, current.end); + } + + return current.token === 'procedure' || current.token === 'function'; +}; + +const shouldEnterPlsqlCreateRoutineBlock = ( + text: string, + statementStart: number, + token: string, + tokenEnd: number, +): boolean => { + if (token !== 'is' && token !== 'as') return false; + const nextChar = nextSqlSignificantChar(text, tokenEnd); + if (!nextChar) return false; + if (token === 'as' && (nextChar === '$' || nextChar === "'" || nextChar === '"')) { + return false; + } + return isCreateRoutineHeaderPrefix(text.slice(statementStart, tokenEnd - token.length)); +}; + const isPlsqlControlEnd = (text: string, tokenEnd: number): boolean => ( ['if', 'loop', 'case'].includes(nextSqlSignificantToken(text, tokenEnd)) ); @@ -209,6 +252,10 @@ export const findSqlStatementRanges = (sql: string): SqlStatementRange[] => { plsqlDepth++; plsqlDeclareBeginSkips++; justClosedPLSQLBlock = false; + } else if (plsqlDepth === 0 && shouldEnterPlsqlCreateRoutineBlock(text, statementStart, token, tokenEnd)) { + plsqlDepth++; + plsqlDeclareBeginSkips++; + justClosedPLSQLBlock = false; } else if (token === 'end' && plsqlDepth > 0 && !isPlsqlControlEnd(text, tokenEnd)) { plsqlDepth--; if (plsqlDeclareBeginSkips > plsqlDepth) { diff --git a/internal/app/methods_db_multi_test.go b/internal/app/methods_db_multi_test.go index 4250198..1d56bed 100644 --- a/internal/app/methods_db_multi_test.go +++ b/internal/app/methods_db_multi_test.go @@ -268,6 +268,52 @@ END;` } } +func TestDBQueryMultiKeepsOracleCreateProcedureAsSingleStatement(t *testing.T) { + originalNewDatabaseFunc := newDatabaseFunc + t.Cleanup(func() { + newDatabaseFunc = originalNewDatabaseFunc + }) + + fakeDB := &fakeBatchWriteDB{} + newDatabaseFunc = func(dbType string) (db.Database, error) { + return fakeDB, nil + } + + app := NewAppWithSecretStore(secretstore.NewUnavailableStore("test")) + config := connection.ConnectionConfig{ + Type: "oracle", + Host: "127.0.0.1", + Port: 1521, + User: "app", + } + query := `CREATE OR REPLACE PROCEDURE proc_tally2accept( + p_tallyacceptno IN t_tally_accept_h.acceptno%TYPE, + out_acceptno OUT t_accept_h.acceptno%TYPE +) IS + v_busno t_tally_accept_h.busno%TYPE; + v_count PLS_INTEGER; +BEGIN + SELECT COUNT(*) INTO v_count FROM t_tally_accept_h WHERE acceptno = p_tallyacceptno; + IF v_count > 0 THEN + out_acceptno := p_tallyacceptno; + END IF; +END;` + + result := app.DBQueryMulti(config, "ORCLPDB1", query, "oracle-create-procedure-test") + if !result.Success { + t.Fatalf("expected DBQueryMulti success, got failure: %s", result.Message) + } + if fakeDB.batchCalls != 0 { + t.Fatalf("expected CREATE PROCEDURE to skip batch path, got batchCalls=%d", fakeDB.batchCalls) + } + if fakeDB.execCalls != 1 || len(fakeDB.execQueries) != 1 { + t.Fatalf("expected one sequential exec call, got execCalls=%d queries=%#v", fakeDB.execCalls, fakeDB.execQueries) + } + if fakeDB.execQueries[0] != query { + t.Fatalf("expected CREATE PROCEDURE to stay intact, got %q", fakeDB.execQueries[0]) + } +} + var _ db.BatchWriteExecer = (*fakeBatchWriteDB)(nil) var _ db.SessionExecerProvider = (*fakeBatchWriteDB)(nil) var _ db.QueryMessageExecer = (*fakeBatchWriteDB)(nil) diff --git a/internal/app/methods_file_sql_execution_test.go b/internal/app/methods_file_sql_execution_test.go index 0a6ca6a..73eb206 100644 --- a/internal/app/methods_file_sql_execution_test.go +++ b/internal/app/methods_file_sql_execution_test.go @@ -357,3 +357,52 @@ func TestStreamSQLFileKeepsOracleAnonymousBlockTogether(t *testing.T) { t.Fatalf("unexpected second statement: %q", statements[1]) } } + +func TestStreamSQLFileKeepsOracleCreateProcedureTogether(t *testing.T) { + input := strings.Join([]string{ + "CREATE OR REPLACE PROCEDURE proc_tally2accept(", + " p_tallyacceptno IN t_tally_accept_h.acceptno%TYPE,", + " out_acceptno OUT t_accept_h.acceptno%TYPE", + ") IS", + " v_busno t_tally_accept_h.busno%TYPE;", + " v_count PLS_INTEGER;", + "BEGIN", + " SELECT COUNT(*) INTO v_count FROM t_tally_accept_h WHERE acceptno = p_tallyacceptno;", + " IF v_count > 0 THEN", + " out_acceptno := p_tallyacceptno;", + " END IF;", + "END;", + "SELECT 1 FROM dual;", + }, "\n") + var statements []string + + count, err := streamSQLFile(&chunkedReader{data: []byte(input), step: 5}, func(index int, stmt string) error { + statements = append(statements, stmt) + return nil + }) + if err != nil { + t.Fatalf("streamSQLFile returned error: %v", err) + } + if count != 2 || len(statements) != 2 { + t.Fatalf("expected 2 statements, got count=%d statements=%#v", count, statements) + } + if statements[0] != strings.Join([]string{ + "CREATE OR REPLACE PROCEDURE proc_tally2accept(", + " p_tallyacceptno IN t_tally_accept_h.acceptno%TYPE,", + " out_acceptno OUT t_accept_h.acceptno%TYPE", + ") IS", + " v_busno t_tally_accept_h.busno%TYPE;", + " v_count PLS_INTEGER;", + "BEGIN", + " SELECT COUNT(*) INTO v_count FROM t_tally_accept_h WHERE acceptno = p_tallyacceptno;", + " IF v_count > 0 THEN", + " out_acceptno := p_tallyacceptno;", + " END IF;", + "END;", + }, "\n") { + t.Fatalf("unexpected create procedure statement: %q", statements[0]) + } + if statements[1] != "SELECT 1 FROM dual" { + t.Fatalf("unexpected second statement: %q", statements[1]) + } +} diff --git a/internal/app/sql_split.go b/internal/app/sql_split.go index bbd9829..3287af8 100644 --- a/internal/app/sql_split.go +++ b/internal/app/sql_split.go @@ -6,7 +6,7 @@ import "strings" // 正确处理单引号/双引号/反引号字符串、行注释(-- / #)、块注释(/* */)、 // PostgreSQL/Kingbase 的 $$...$$ dollar-quoting,以及 Oracle PL/SQL 匿名块, // 避免在这些上下文中错误拆分。 -// 同时支持 SQL 标准的转义单引号(两个连续单引号 '' 表示字面量引号)。 +// 同时支持 SQL 标准的转义单引号(两个连续单引号 ” 表示字面量引号)。 func splitSQLStatements(sql string) []string { text := strings.ReplaceAll(sql, "\r\n", "\n") var statements []string @@ -129,6 +129,10 @@ func splitSQLStatements(sql string) []string { plsqlDepth++ plsqlDeclareBeginSkips++ justClosedPLSQLBlock = false + } else if plsqlDepth == 0 && shouldEnterPLSQLCreateRoutineBlock(text, cur.String(), token, tokenEnd) { + plsqlDepth++ + plsqlDeclareBeginSkips++ + justClosedPLSQLBlock = false } else if token == "end" && plsqlDepth > 0 && !isPLSQLControlEnd(text, tokenEnd) { plsqlDepth-- if plsqlDeclareBeginSkips > plsqlDepth { @@ -297,10 +301,10 @@ func isPLSQLBlockStatement(stmt string) bool { if token == "declare" { return shouldEnterPLSQLDeclareBlock(text, len("declare")) } - if token != "begin" { - return false + if token == "begin" { + return shouldEnterPLSQLBlock(text, len("begin")) } - return shouldEnterPLSQLBlock(text, len("begin")) + return isCreateRoutineHeaderPrefix(text) } func shouldEnterPLSQLDeclareBlock(text string, tokenEnd int) bool { @@ -316,6 +320,54 @@ func isPLSQLControlEnd(text string, tokenEnd int) bool { } } +func isCreateRoutineHeaderPrefix(text string) bool { + currentToken, currentEnd := nextSQLSignificantTokenSpan(text, 0) + if currentToken != "create" { + return false + } + + currentToken, currentEnd = nextSQLSignificantTokenSpan(text, currentEnd) + if currentToken == "or" { + currentToken, currentEnd = nextSQLSignificantTokenSpan(text, currentEnd) + if currentToken != "replace" { + return false + } + currentToken, currentEnd = nextSQLSignificantTokenSpan(text, currentEnd) + } + + for currentToken == "editionable" || currentToken == "noneditionable" { + currentToken, currentEnd = nextSQLSignificantTokenSpan(text, currentEnd) + } + + return currentToken == "procedure" || currentToken == "function" +} + +func nextSQLSignificantTokenSpan(text string, pos int) (string, int) { + i := skipSQLWhitespaceAndComments(text, pos) + if i >= len(text) || !isSQLIdentifierStart(text[i]) { + return "", i + } + end := i + 1 + for end < len(text) && isSQLIdentifierPart(text[end]) { + end++ + } + return strings.ToLower(text[i:end]), end +} + +func shouldEnterPLSQLCreateRoutineBlock(text string, currentStatementPrefix string, token string, tokenEnd int) bool { + if token != "is" && token != "as" { + return false + } + nextChar := nextSQLSignificantByte(text, tokenEnd) + if nextChar == 0 { + return false + } + if token == "as" && (nextChar == '$' || nextChar == '\'' || nextChar == '"') { + return false + } + return isCreateRoutineHeaderPrefix(currentStatementPrefix) +} + // parseSQLDollarTag 解析 PostgreSQL/Kingbase 的 dollar-quoting 标签。 func parseSQLDollarTag(s string) string { if len(s) < 2 || s[0] != '$' { diff --git a/internal/app/sql_split_stream.go b/internal/app/sql_split_stream.go index cfc0a6a..95c3d01 100644 --- a/internal/app/sql_split_stream.go +++ b/internal/app/sql_split_stream.go @@ -146,6 +146,10 @@ func (s *sqlStreamSplitter) Feed(chunk []byte) []string { s.plsqlDepth++ s.declareSkips++ s.closedPLSQL = false + } else if s.plsqlDepth == 0 && shouldEnterPLSQLCreateRoutineBlock(text, s.cur.String(), token, tokenEnd) { + s.plsqlDepth++ + s.declareSkips++ + s.closedPLSQL = false } else if token == "end" && s.plsqlDepth > 0 && !isPLSQLControlEnd(text, tokenEnd) { s.plsqlDepth-- if s.declareSkips > s.plsqlDepth { @@ -289,7 +293,7 @@ func isIncompleteSQLDollarTag(s string) bool { func shouldDeferPLSQLKeywordInStream(text string, tokenStart int, tokenEnd int, token string) bool { switch token { - case "begin", "declare", "end": + case "begin", "declare", "end", "create", "or", "replace", "editionable", "noneditionable", "procedure", "function", "is", "as": default: return false } @@ -314,7 +318,7 @@ func shouldDeferPLSQLKeywordPrefixInStream(text string, tokenStart int, tokenEnd if tokenEnd < len(text) { return false } - for _, keyword := range []string{"begin", "declare", "end"} { + for _, keyword := range []string{"begin", "declare", "end", "create", "or", "replace", "editionable", "noneditionable", "procedure", "function", "is", "as"} { if strings.HasPrefix(keyword, token) && token != keyword { if tokenStart > 0 && isSQLIdentifierPart(text[tokenStart-1]) { return false diff --git a/internal/app/sql_split_test.go b/internal/app/sql_split_test.go index b4e4586..0503402 100644 --- a/internal/app/sql_split_test.go +++ b/internal/app/sql_split_test.go @@ -66,6 +66,27 @@ func TestSplitSQLStatements_DollarQuoting(t *testing.T) { } } +func TestSplitSQLStatements_PostgresCreateFunctionDollarQuoting(t *testing.T) { + input := `CREATE OR REPLACE FUNCTION refresh_stats() RETURNS void AS $$ +BEGIN + PERFORM refresh_now(); +END; +$$ LANGUAGE plpgsql; +SELECT 2;` + got := splitSQLStatements(input) + want := []string{ + `CREATE OR REPLACE FUNCTION refresh_stats() RETURNS void AS $$ +BEGIN + PERFORM refresh_now(); +END; +$$ LANGUAGE plpgsql`, + "SELECT 2", + } + if !reflect.DeepEqual(got, want) { + t.Errorf("splitSQLStatements(%q) = %#v, want %#v", input, got, want) + } +} + func TestSplitSQLStatements_FullWidthSemicolon(t *testing.T) { input := "SELECT 1;SELECT 2" got := splitSQLStatements(input) @@ -155,6 +176,41 @@ END;`, } } +func TestSplitSQLStatements_OracleCreateProcedureBlock(t *testing.T) { + input := `CREATE OR REPLACE PROCEDURE proc_tally2accept( + p_tallyacceptno IN t_tally_accept_h.acceptno%TYPE, + out_acceptno OUT t_accept_h.acceptno%TYPE +) IS + v_busno t_tally_accept_h.busno%TYPE; + v_count PLS_INTEGER; +BEGIN + SELECT COUNT(*) INTO v_count FROM t_tally_accept_h WHERE acceptno = p_tallyacceptno; + IF v_count > 0 THEN + out_acceptno := p_tallyacceptno; + END IF; +END; +SELECT 1 FROM dual;` + got := splitSQLStatements(input) + want := []string{ + `CREATE OR REPLACE PROCEDURE proc_tally2accept( + p_tallyacceptno IN t_tally_accept_h.acceptno%TYPE, + out_acceptno OUT t_accept_h.acceptno%TYPE +) IS + v_busno t_tally_accept_h.busno%TYPE; + v_count PLS_INTEGER; +BEGIN + SELECT COUNT(*) INTO v_count FROM t_tally_accept_h WHERE acceptno = p_tallyacceptno; + IF v_count > 0 THEN + out_acceptno := p_tallyacceptno; + END IF; +END;`, + "SELECT 1 FROM dual", + } + if !reflect.DeepEqual(got, want) { + t.Errorf("splitSQLStatements(%q) = %#v, want %#v", input, got, want) + } +} + func TestSplitSQLStatements_TransactionBeginStillSplits(t *testing.T) { input := "BEGIN; UPDATE accounts SET balance = balance - 1 WHERE id = 1; COMMIT;" got := splitSQLStatements(input)