mirror of
https://github.com/Syngnat/GoNavi.git
synced 2026-06-21 22:14:02 +08:00
🐛 fix(sql-editor): 修复存储过程定义执行截断
This commit is contained in:
@@ -18,7 +18,7 @@ import { applyQueryAutoLimit } from '../utils/queryAutoLimit';
|
||||
import { extractQueryResultTableRef, type QueryResultTableRef } from '../utils/queryResultTable';
|
||||
import { quoteIdentPart } from '../utils/sql';
|
||||
import { formatSqlExecutionError } from '../utils/sqlErrorSemantics';
|
||||
import { resolveCurrentSqlStatementRange, resolveExecutableSql } from '../utils/sqlStatementSelection';
|
||||
import { findSqlStatementRanges, resolveCurrentSqlStatementRange, resolveExecutableSql } from '../utils/sqlStatementSelection';
|
||||
import { isMacLikePlatform } from '../utils/appearance';
|
||||
import { splitSidebarQualifiedName } from '../utils/sidebarLocate';
|
||||
import { normalizeSidebarViewName } from '../utils/sidebarMetadata';
|
||||
@@ -741,65 +741,6 @@ const areSqlStatementListsEqual = (left: string[], right: string[]): boolean =>
|
||||
&& left.every((statement, index) => normalizeExecutedSqlKey(statement) === normalizeExecutedSqlKey(right[index]))
|
||||
);
|
||||
|
||||
const isSqlIdentifierStart = (ch: string): boolean => /^[A-Za-z_]$/.test(ch);
|
||||
|
||||
const isSqlIdentifierPart = (ch: string): boolean => /^[A-Za-z0-9_$#]$/.test(ch);
|
||||
|
||||
const skipSqlWhitespaceAndComments = (text: string, position: number): number => {
|
||||
let index = position;
|
||||
while (index < text.length) {
|
||||
const ch = text[index];
|
||||
const next = index + 1 < text.length ? text[index + 1] : '';
|
||||
if (ch === ' ' || ch === '\t' || ch === '\n' || ch === '\r' || ch === '\f') {
|
||||
index += 1;
|
||||
continue;
|
||||
}
|
||||
if (ch === '-' && next === '-') {
|
||||
index += 2;
|
||||
while (index < text.length && text[index] !== '\n') index += 1;
|
||||
continue;
|
||||
}
|
||||
if (ch === '/' && next === '*') {
|
||||
index += 2;
|
||||
while (index + 1 < text.length && !(text[index] === '*' && text[index + 1] === '/')) {
|
||||
index += 1;
|
||||
}
|
||||
if (index + 1 < text.length) index += 2;
|
||||
continue;
|
||||
}
|
||||
break;
|
||||
}
|
||||
return index;
|
||||
};
|
||||
|
||||
const nextSqlSignificantToken = (text: string, position: number): string => {
|
||||
const index = skipSqlWhitespaceAndComments(text, position);
|
||||
if (index >= text.length || !isSqlIdentifierStart(text[index])) return '';
|
||||
let end = index + 1;
|
||||
while (end < text.length && isSqlIdentifierPart(text[end])) end += 1;
|
||||
return text.slice(index, end).toLowerCase();
|
||||
};
|
||||
|
||||
const nextSqlSignificantChar = (text: string, position: number): string => {
|
||||
const index = skipSqlWhitespaceAndComments(text, position);
|
||||
return index >= text.length ? '' : text[index];
|
||||
};
|
||||
|
||||
const shouldEnterPlsqlBeginBlock = (text: string, tokenEnd: number): boolean => {
|
||||
const nextChar = nextSqlSignificantChar(text, tokenEnd);
|
||||
if (!nextChar || nextChar === ';') return false;
|
||||
return !['transaction', 'work', 'isolation', 'read', 'write'].includes(nextSqlSignificantToken(text, tokenEnd));
|
||||
};
|
||||
|
||||
const shouldEnterPlsqlDeclareBlock = (text: string, tokenEnd: number): boolean => {
|
||||
const nextToken = nextSqlSignificantToken(text, tokenEnd);
|
||||
return Boolean(nextToken);
|
||||
};
|
||||
|
||||
const isPlsqlControlEnd = (text: string, tokenEnd: number): boolean => (
|
||||
['if', 'loop', 'case'].includes(nextSqlSignificantToken(text, tokenEnd))
|
||||
);
|
||||
|
||||
const normalizeEditorPosition = (position: any): { lineNumber: number; column: number } | null => {
|
||||
if (!position) return null;
|
||||
const lineNumber = Number(position.positionLineNumber ?? position.lineNumber ?? position.endLineNumber ?? position.startLineNumber ?? position.selectionStartLineNumber);
|
||||
@@ -3834,169 +3775,7 @@ const QueryEditor: React.FC<{ tab: TabData; isActive?: boolean }> = ({ tab, isAc
|
||||
];
|
||||
|
||||
const splitSQLStatements = (sql: string): string[] => {
|
||||
const text = (sql || '').replace(/\r\n/g, '\n');
|
||||
const statements: string[] = [];
|
||||
|
||||
let cur = '';
|
||||
let inSingle = false;
|
||||
let inDouble = false;
|
||||
let inBacktick = false;
|
||||
let escaped = false;
|
||||
let inLineComment = false;
|
||||
let inBlockComment = false;
|
||||
let dollarTag: string | null = null; // postgres/kingbase: $$...$$ or $tag$...$tag$
|
||||
let plsqlDepth = 0;
|
||||
let plsqlDeclareBeginSkips = 0;
|
||||
let justClosedPLSQLBlock = false;
|
||||
|
||||
const push = () => {
|
||||
const s = cur.trim();
|
||||
if (s) statements.push(s);
|
||||
cur = '';
|
||||
};
|
||||
|
||||
const isWS = (ch: string) => ch === ' ' || ch === '\t' || ch === '\n' || ch === '\r';
|
||||
|
||||
for (let i = 0; i < text.length; i++) {
|
||||
const ch = text[i];
|
||||
const next = i + 1 < text.length ? text[i + 1] : '';
|
||||
const prev = i > 0 ? text[i - 1] : '';
|
||||
const next2 = i + 2 < text.length ? text[i + 2] : '';
|
||||
|
||||
if (!inSingle && !inDouble && !inBacktick) {
|
||||
if (inLineComment) {
|
||||
cur += ch;
|
||||
if (ch === '\n') inLineComment = false;
|
||||
continue;
|
||||
}
|
||||
|
||||
if (inBlockComment) {
|
||||
cur += ch;
|
||||
if (ch === '*' && next === '/') {
|
||||
cur += next;
|
||||
i++;
|
||||
inBlockComment = false;
|
||||
}
|
||||
continue;
|
||||
}
|
||||
|
||||
// Start comments
|
||||
if (ch === '/' && next === '*') {
|
||||
cur += ch + next;
|
||||
i++;
|
||||
inBlockComment = true;
|
||||
continue;
|
||||
}
|
||||
if (ch === '#') {
|
||||
cur += ch;
|
||||
inLineComment = true;
|
||||
continue;
|
||||
}
|
||||
if (ch === '-' && next === '-' && (i === 0 || isWS(prev)) && (next2 === '' || isWS(next2))) {
|
||||
cur += ch + next;
|
||||
i++;
|
||||
inLineComment = true;
|
||||
continue;
|
||||
}
|
||||
|
||||
// Dollar-quoted strings (PG/Kingbase)
|
||||
if (dollarTag) {
|
||||
if (text.startsWith(dollarTag, i)) {
|
||||
cur += dollarTag;
|
||||
i += dollarTag.length - 1;
|
||||
dollarTag = null;
|
||||
} else {
|
||||
cur += ch;
|
||||
}
|
||||
continue;
|
||||
}
|
||||
if (ch === '$') {
|
||||
const m = text.slice(i).match(/^\$[A-Za-z0-9_]*\$/);
|
||||
if (m && m[0]) {
|
||||
dollarTag = m[0];
|
||||
cur += dollarTag;
|
||||
i += dollarTag.length - 1;
|
||||
continue;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if (escaped) {
|
||||
cur += ch;
|
||||
escaped = false;
|
||||
continue;
|
||||
}
|
||||
|
||||
if ((inSingle || inDouble) && ch === '\\') {
|
||||
cur += ch;
|
||||
escaped = true;
|
||||
continue;
|
||||
}
|
||||
|
||||
if (!inDouble && !inBacktick && ch === '\'') {
|
||||
inSingle = !inSingle;
|
||||
cur += ch;
|
||||
continue;
|
||||
}
|
||||
if (!inSingle && !inBacktick && ch === '"') {
|
||||
inDouble = !inDouble;
|
||||
cur += ch;
|
||||
continue;
|
||||
}
|
||||
if (!inSingle && !inDouble && ch === '`') {
|
||||
inBacktick = !inBacktick;
|
||||
cur += ch;
|
||||
continue;
|
||||
}
|
||||
|
||||
if (!inSingle && !inDouble && !inBacktick && !dollarTag && isSqlIdentifierStart(ch)) {
|
||||
let end = i + 1;
|
||||
while (end < text.length && isSqlIdentifierPart(text[end])) {
|
||||
end += 1;
|
||||
}
|
||||
const token = text.slice(i, end).toLowerCase();
|
||||
if (token === 'begin' && plsqlDeclareBeginSkips > 0) {
|
||||
plsqlDeclareBeginSkips -= 1;
|
||||
justClosedPLSQLBlock = false;
|
||||
} else if (token === 'begin' && shouldEnterPlsqlBeginBlock(text, end)) {
|
||||
plsqlDepth += 1;
|
||||
justClosedPLSQLBlock = false;
|
||||
} else if (token === 'declare' && shouldEnterPlsqlDeclareBlock(text, end)) {
|
||||
plsqlDepth += 1;
|
||||
plsqlDeclareBeginSkips += 1;
|
||||
justClosedPLSQLBlock = false;
|
||||
} else if (token === 'end' && plsqlDepth > 0 && !isPlsqlControlEnd(text, end)) {
|
||||
plsqlDepth -= 1;
|
||||
if (plsqlDeclareBeginSkips > plsqlDepth) {
|
||||
plsqlDeclareBeginSkips = plsqlDepth;
|
||||
}
|
||||
justClosedPLSQLBlock = plsqlDepth === 0;
|
||||
}
|
||||
cur += text.slice(i, end);
|
||||
i = end - 1;
|
||||
continue;
|
||||
}
|
||||
|
||||
if (!inSingle && !inDouble && !inBacktick && !dollarTag && (ch === ';' || ch === ';')) {
|
||||
if (plsqlDepth > 0) {
|
||||
cur += ch;
|
||||
continue;
|
||||
}
|
||||
if (justClosedPLSQLBlock) {
|
||||
cur += ch;
|
||||
push();
|
||||
justClosedPLSQLBlock = false;
|
||||
continue;
|
||||
}
|
||||
push();
|
||||
continue;
|
||||
}
|
||||
|
||||
cur += ch;
|
||||
}
|
||||
|
||||
push();
|
||||
return statements;
|
||||
return findSqlStatementRanges(sql).map((range) => range.text);
|
||||
};
|
||||
|
||||
const getSelectedSQL = (): string => {
|
||||
|
||||
@@ -86,6 +86,72 @@ describe('sqlStatementSelection', () => {
|
||||
});
|
||||
});
|
||||
|
||||
it('keeps Oracle CREATE PROCEDURE definitions as one executable statement', () => {
|
||||
const sql = [
|
||||
'CREATE OR REPLACE PROCEDURE proc_tally2accept(',
|
||||
' p_tallyacceptno IN t_tally_accept_h.acceptno%TYPE,',
|
||||
' out_acceptno OUT t_accept_h.acceptno%TYPE',
|
||||
') IS',
|
||||
' v_busno t_tally_accept_h.busno%TYPE;',
|
||||
' v_count PLS_INTEGER;',
|
||||
'BEGIN',
|
||||
" SELECT COUNT(*) INTO v_count FROM t_tally_accept_h WHERE acceptno = p_tallyacceptno;",
|
||||
' IF v_count > 0 THEN',
|
||||
' out_acceptno := p_tallyacceptno;',
|
||||
' END IF;',
|
||||
'END;',
|
||||
'SELECT 1 FROM dual;',
|
||||
].join('\n');
|
||||
|
||||
const ranges = findSqlStatementRanges(sql).map((range) => range.text);
|
||||
|
||||
expect(ranges).toEqual([
|
||||
[
|
||||
'CREATE OR REPLACE PROCEDURE proc_tally2accept(',
|
||||
' p_tallyacceptno IN t_tally_accept_h.acceptno%TYPE,',
|
||||
' out_acceptno OUT t_accept_h.acceptno%TYPE',
|
||||
') IS',
|
||||
' v_busno t_tally_accept_h.busno%TYPE;',
|
||||
' v_count PLS_INTEGER;',
|
||||
'BEGIN',
|
||||
" SELECT COUNT(*) INTO v_count FROM t_tally_accept_h WHERE acceptno = p_tallyacceptno;",
|
||||
' IF v_count > 0 THEN',
|
||||
' out_acceptno := p_tallyacceptno;',
|
||||
' END IF;',
|
||||
'END;',
|
||||
].join('\n'),
|
||||
'SELECT 1 FROM dual',
|
||||
]);
|
||||
expect(resolveExecutableSql(sql, sql.indexOf('v_busno'))).toEqual({
|
||||
sql: ranges[0],
|
||||
source: 'statement',
|
||||
});
|
||||
});
|
||||
|
||||
it('keeps PostgreSQL dollar-quoted CREATE FUNCTION definitions as one executable statement', () => {
|
||||
const sql = [
|
||||
'CREATE OR REPLACE FUNCTION refresh_stats() RETURNS void AS $$',
|
||||
'BEGIN',
|
||||
' PERFORM refresh_now();',
|
||||
'END;',
|
||||
'$$ LANGUAGE plpgsql;',
|
||||
'SELECT 2;',
|
||||
].join('\n');
|
||||
|
||||
const ranges = findSqlStatementRanges(sql).map((range) => range.text);
|
||||
|
||||
expect(ranges).toEqual([
|
||||
[
|
||||
'CREATE OR REPLACE FUNCTION refresh_stats() RETURNS void AS $$',
|
||||
'BEGIN',
|
||||
' PERFORM refresh_now();',
|
||||
'END;',
|
||||
'$$ LANGUAGE plpgsql',
|
||||
].join('\n'),
|
||||
'SELECT 2',
|
||||
]);
|
||||
});
|
||||
|
||||
it('still splits transaction BEGIN statements', () => {
|
||||
const sql = 'BEGIN; UPDATE accounts SET balance = balance - 1 WHERE id = 1; COMMIT;';
|
||||
|
||||
|
||||
@@ -67,6 +67,49 @@ const shouldEnterPlsqlBeginBlock = (text: string, tokenEnd: number): boolean =>
|
||||
|
||||
const shouldEnterPlsqlDeclareBlock = (text: string, tokenEnd: number): boolean => Boolean(nextSqlSignificantToken(text, tokenEnd));
|
||||
|
||||
const nextSqlSignificantTokenSpan = (text: string, position: number): { token: string; end: number } => {
|
||||
const index = skipSqlWhitespaceAndComments(text, position);
|
||||
if (index >= text.length || !isSqlIdentifierStart(text[index])) {
|
||||
return { token: '', end: index };
|
||||
}
|
||||
let end = index + 1;
|
||||
while (end < text.length && isSqlIdentifierPart(text[end])) end += 1;
|
||||
return { token: text.slice(index, end).toLowerCase(), end };
|
||||
};
|
||||
|
||||
const isCreateRoutineHeaderPrefix = (text: string): boolean => {
|
||||
let current = nextSqlSignificantTokenSpan(text, 0);
|
||||
if (current.token !== 'create') return false;
|
||||
|
||||
current = nextSqlSignificantTokenSpan(text, current.end);
|
||||
if (current.token === 'or') {
|
||||
current = nextSqlSignificantTokenSpan(text, current.end);
|
||||
if (current.token !== 'replace') return false;
|
||||
current = nextSqlSignificantTokenSpan(text, current.end);
|
||||
}
|
||||
|
||||
while (['editionable', 'noneditionable'].includes(current.token)) {
|
||||
current = nextSqlSignificantTokenSpan(text, current.end);
|
||||
}
|
||||
|
||||
return current.token === 'procedure' || current.token === 'function';
|
||||
};
|
||||
|
||||
const shouldEnterPlsqlCreateRoutineBlock = (
|
||||
text: string,
|
||||
statementStart: number,
|
||||
token: string,
|
||||
tokenEnd: number,
|
||||
): boolean => {
|
||||
if (token !== 'is' && token !== 'as') return false;
|
||||
const nextChar = nextSqlSignificantChar(text, tokenEnd);
|
||||
if (!nextChar) return false;
|
||||
if (token === 'as' && (nextChar === '$' || nextChar === "'" || nextChar === '"')) {
|
||||
return false;
|
||||
}
|
||||
return isCreateRoutineHeaderPrefix(text.slice(statementStart, tokenEnd - token.length));
|
||||
};
|
||||
|
||||
const isPlsqlControlEnd = (text: string, tokenEnd: number): boolean => (
|
||||
['if', 'loop', 'case'].includes(nextSqlSignificantToken(text, tokenEnd))
|
||||
);
|
||||
@@ -209,6 +252,10 @@ export const findSqlStatementRanges = (sql: string): SqlStatementRange[] => {
|
||||
plsqlDepth++;
|
||||
plsqlDeclareBeginSkips++;
|
||||
justClosedPLSQLBlock = false;
|
||||
} else if (plsqlDepth === 0 && shouldEnterPlsqlCreateRoutineBlock(text, statementStart, token, tokenEnd)) {
|
||||
plsqlDepth++;
|
||||
plsqlDeclareBeginSkips++;
|
||||
justClosedPLSQLBlock = false;
|
||||
} else if (token === 'end' && plsqlDepth > 0 && !isPlsqlControlEnd(text, tokenEnd)) {
|
||||
plsqlDepth--;
|
||||
if (plsqlDeclareBeginSkips > plsqlDepth) {
|
||||
|
||||
@@ -268,6 +268,52 @@ END;`
|
||||
}
|
||||
}
|
||||
|
||||
func TestDBQueryMultiKeepsOracleCreateProcedureAsSingleStatement(t *testing.T) {
|
||||
originalNewDatabaseFunc := newDatabaseFunc
|
||||
t.Cleanup(func() {
|
||||
newDatabaseFunc = originalNewDatabaseFunc
|
||||
})
|
||||
|
||||
fakeDB := &fakeBatchWriteDB{}
|
||||
newDatabaseFunc = func(dbType string) (db.Database, error) {
|
||||
return fakeDB, nil
|
||||
}
|
||||
|
||||
app := NewAppWithSecretStore(secretstore.NewUnavailableStore("test"))
|
||||
config := connection.ConnectionConfig{
|
||||
Type: "oracle",
|
||||
Host: "127.0.0.1",
|
||||
Port: 1521,
|
||||
User: "app",
|
||||
}
|
||||
query := `CREATE OR REPLACE PROCEDURE proc_tally2accept(
|
||||
p_tallyacceptno IN t_tally_accept_h.acceptno%TYPE,
|
||||
out_acceptno OUT t_accept_h.acceptno%TYPE
|
||||
) IS
|
||||
v_busno t_tally_accept_h.busno%TYPE;
|
||||
v_count PLS_INTEGER;
|
||||
BEGIN
|
||||
SELECT COUNT(*) INTO v_count FROM t_tally_accept_h WHERE acceptno = p_tallyacceptno;
|
||||
IF v_count > 0 THEN
|
||||
out_acceptno := p_tallyacceptno;
|
||||
END IF;
|
||||
END;`
|
||||
|
||||
result := app.DBQueryMulti(config, "ORCLPDB1", query, "oracle-create-procedure-test")
|
||||
if !result.Success {
|
||||
t.Fatalf("expected DBQueryMulti success, got failure: %s", result.Message)
|
||||
}
|
||||
if fakeDB.batchCalls != 0 {
|
||||
t.Fatalf("expected CREATE PROCEDURE to skip batch path, got batchCalls=%d", fakeDB.batchCalls)
|
||||
}
|
||||
if fakeDB.execCalls != 1 || len(fakeDB.execQueries) != 1 {
|
||||
t.Fatalf("expected one sequential exec call, got execCalls=%d queries=%#v", fakeDB.execCalls, fakeDB.execQueries)
|
||||
}
|
||||
if fakeDB.execQueries[0] != query {
|
||||
t.Fatalf("expected CREATE PROCEDURE to stay intact, got %q", fakeDB.execQueries[0])
|
||||
}
|
||||
}
|
||||
|
||||
var _ db.BatchWriteExecer = (*fakeBatchWriteDB)(nil)
|
||||
var _ db.SessionExecerProvider = (*fakeBatchWriteDB)(nil)
|
||||
var _ db.QueryMessageExecer = (*fakeBatchWriteDB)(nil)
|
||||
|
||||
@@ -357,3 +357,52 @@ func TestStreamSQLFileKeepsOracleAnonymousBlockTogether(t *testing.T) {
|
||||
t.Fatalf("unexpected second statement: %q", statements[1])
|
||||
}
|
||||
}
|
||||
|
||||
func TestStreamSQLFileKeepsOracleCreateProcedureTogether(t *testing.T) {
|
||||
input := strings.Join([]string{
|
||||
"CREATE OR REPLACE PROCEDURE proc_tally2accept(",
|
||||
" p_tallyacceptno IN t_tally_accept_h.acceptno%TYPE,",
|
||||
" out_acceptno OUT t_accept_h.acceptno%TYPE",
|
||||
") IS",
|
||||
" v_busno t_tally_accept_h.busno%TYPE;",
|
||||
" v_count PLS_INTEGER;",
|
||||
"BEGIN",
|
||||
" SELECT COUNT(*) INTO v_count FROM t_tally_accept_h WHERE acceptno = p_tallyacceptno;",
|
||||
" IF v_count > 0 THEN",
|
||||
" out_acceptno := p_tallyacceptno;",
|
||||
" END IF;",
|
||||
"END;",
|
||||
"SELECT 1 FROM dual;",
|
||||
}, "\n")
|
||||
var statements []string
|
||||
|
||||
count, err := streamSQLFile(&chunkedReader{data: []byte(input), step: 5}, func(index int, stmt string) error {
|
||||
statements = append(statements, stmt)
|
||||
return nil
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("streamSQLFile returned error: %v", err)
|
||||
}
|
||||
if count != 2 || len(statements) != 2 {
|
||||
t.Fatalf("expected 2 statements, got count=%d statements=%#v", count, statements)
|
||||
}
|
||||
if statements[0] != strings.Join([]string{
|
||||
"CREATE OR REPLACE PROCEDURE proc_tally2accept(",
|
||||
" p_tallyacceptno IN t_tally_accept_h.acceptno%TYPE,",
|
||||
" out_acceptno OUT t_accept_h.acceptno%TYPE",
|
||||
") IS",
|
||||
" v_busno t_tally_accept_h.busno%TYPE;",
|
||||
" v_count PLS_INTEGER;",
|
||||
"BEGIN",
|
||||
" SELECT COUNT(*) INTO v_count FROM t_tally_accept_h WHERE acceptno = p_tallyacceptno;",
|
||||
" IF v_count > 0 THEN",
|
||||
" out_acceptno := p_tallyacceptno;",
|
||||
" END IF;",
|
||||
"END;",
|
||||
}, "\n") {
|
||||
t.Fatalf("unexpected create procedure statement: %q", statements[0])
|
||||
}
|
||||
if statements[1] != "SELECT 1 FROM dual" {
|
||||
t.Fatalf("unexpected second statement: %q", statements[1])
|
||||
}
|
||||
}
|
||||
|
||||
@@ -6,7 +6,7 @@ import "strings"
|
||||
// 正确处理单引号/双引号/反引号字符串、行注释(-- / #)、块注释(/* */)、
|
||||
// PostgreSQL/Kingbase 的 $$...$$ dollar-quoting,以及 Oracle PL/SQL 匿名块,
|
||||
// 避免在这些上下文中错误拆分。
|
||||
// 同时支持 SQL 标准的转义单引号(两个连续单引号 '' 表示字面量引号)。
|
||||
// 同时支持 SQL 标准的转义单引号(两个连续单引号 ” 表示字面量引号)。
|
||||
func splitSQLStatements(sql string) []string {
|
||||
text := strings.ReplaceAll(sql, "\r\n", "\n")
|
||||
var statements []string
|
||||
@@ -129,6 +129,10 @@ func splitSQLStatements(sql string) []string {
|
||||
plsqlDepth++
|
||||
plsqlDeclareBeginSkips++
|
||||
justClosedPLSQLBlock = false
|
||||
} else if plsqlDepth == 0 && shouldEnterPLSQLCreateRoutineBlock(text, cur.String(), token, tokenEnd) {
|
||||
plsqlDepth++
|
||||
plsqlDeclareBeginSkips++
|
||||
justClosedPLSQLBlock = false
|
||||
} else if token == "end" && plsqlDepth > 0 && !isPLSQLControlEnd(text, tokenEnd) {
|
||||
plsqlDepth--
|
||||
if plsqlDeclareBeginSkips > plsqlDepth {
|
||||
@@ -297,10 +301,10 @@ func isPLSQLBlockStatement(stmt string) bool {
|
||||
if token == "declare" {
|
||||
return shouldEnterPLSQLDeclareBlock(text, len("declare"))
|
||||
}
|
||||
if token != "begin" {
|
||||
return false
|
||||
if token == "begin" {
|
||||
return shouldEnterPLSQLBlock(text, len("begin"))
|
||||
}
|
||||
return shouldEnterPLSQLBlock(text, len("begin"))
|
||||
return isCreateRoutineHeaderPrefix(text)
|
||||
}
|
||||
|
||||
func shouldEnterPLSQLDeclareBlock(text string, tokenEnd int) bool {
|
||||
@@ -316,6 +320,54 @@ func isPLSQLControlEnd(text string, tokenEnd int) bool {
|
||||
}
|
||||
}
|
||||
|
||||
func isCreateRoutineHeaderPrefix(text string) bool {
|
||||
currentToken, currentEnd := nextSQLSignificantTokenSpan(text, 0)
|
||||
if currentToken != "create" {
|
||||
return false
|
||||
}
|
||||
|
||||
currentToken, currentEnd = nextSQLSignificantTokenSpan(text, currentEnd)
|
||||
if currentToken == "or" {
|
||||
currentToken, currentEnd = nextSQLSignificantTokenSpan(text, currentEnd)
|
||||
if currentToken != "replace" {
|
||||
return false
|
||||
}
|
||||
currentToken, currentEnd = nextSQLSignificantTokenSpan(text, currentEnd)
|
||||
}
|
||||
|
||||
for currentToken == "editionable" || currentToken == "noneditionable" {
|
||||
currentToken, currentEnd = nextSQLSignificantTokenSpan(text, currentEnd)
|
||||
}
|
||||
|
||||
return currentToken == "procedure" || currentToken == "function"
|
||||
}
|
||||
|
||||
func nextSQLSignificantTokenSpan(text string, pos int) (string, int) {
|
||||
i := skipSQLWhitespaceAndComments(text, pos)
|
||||
if i >= len(text) || !isSQLIdentifierStart(text[i]) {
|
||||
return "", i
|
||||
}
|
||||
end := i + 1
|
||||
for end < len(text) && isSQLIdentifierPart(text[end]) {
|
||||
end++
|
||||
}
|
||||
return strings.ToLower(text[i:end]), end
|
||||
}
|
||||
|
||||
func shouldEnterPLSQLCreateRoutineBlock(text string, currentStatementPrefix string, token string, tokenEnd int) bool {
|
||||
if token != "is" && token != "as" {
|
||||
return false
|
||||
}
|
||||
nextChar := nextSQLSignificantByte(text, tokenEnd)
|
||||
if nextChar == 0 {
|
||||
return false
|
||||
}
|
||||
if token == "as" && (nextChar == '$' || nextChar == '\'' || nextChar == '"') {
|
||||
return false
|
||||
}
|
||||
return isCreateRoutineHeaderPrefix(currentStatementPrefix)
|
||||
}
|
||||
|
||||
// parseSQLDollarTag 解析 PostgreSQL/Kingbase 的 dollar-quoting 标签。
|
||||
func parseSQLDollarTag(s string) string {
|
||||
if len(s) < 2 || s[0] != '$' {
|
||||
|
||||
@@ -146,6 +146,10 @@ func (s *sqlStreamSplitter) Feed(chunk []byte) []string {
|
||||
s.plsqlDepth++
|
||||
s.declareSkips++
|
||||
s.closedPLSQL = false
|
||||
} else if s.plsqlDepth == 0 && shouldEnterPLSQLCreateRoutineBlock(text, s.cur.String(), token, tokenEnd) {
|
||||
s.plsqlDepth++
|
||||
s.declareSkips++
|
||||
s.closedPLSQL = false
|
||||
} else if token == "end" && s.plsqlDepth > 0 && !isPLSQLControlEnd(text, tokenEnd) {
|
||||
s.plsqlDepth--
|
||||
if s.declareSkips > s.plsqlDepth {
|
||||
@@ -289,7 +293,7 @@ func isIncompleteSQLDollarTag(s string) bool {
|
||||
|
||||
func shouldDeferPLSQLKeywordInStream(text string, tokenStart int, tokenEnd int, token string) bool {
|
||||
switch token {
|
||||
case "begin", "declare", "end":
|
||||
case "begin", "declare", "end", "create", "or", "replace", "editionable", "noneditionable", "procedure", "function", "is", "as":
|
||||
default:
|
||||
return false
|
||||
}
|
||||
@@ -314,7 +318,7 @@ func shouldDeferPLSQLKeywordPrefixInStream(text string, tokenStart int, tokenEnd
|
||||
if tokenEnd < len(text) {
|
||||
return false
|
||||
}
|
||||
for _, keyword := range []string{"begin", "declare", "end"} {
|
||||
for _, keyword := range []string{"begin", "declare", "end", "create", "or", "replace", "editionable", "noneditionable", "procedure", "function", "is", "as"} {
|
||||
if strings.HasPrefix(keyword, token) && token != keyword {
|
||||
if tokenStart > 0 && isSQLIdentifierPart(text[tokenStart-1]) {
|
||||
return false
|
||||
|
||||
@@ -66,6 +66,27 @@ func TestSplitSQLStatements_DollarQuoting(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestSplitSQLStatements_PostgresCreateFunctionDollarQuoting(t *testing.T) {
|
||||
input := `CREATE OR REPLACE FUNCTION refresh_stats() RETURNS void AS $$
|
||||
BEGIN
|
||||
PERFORM refresh_now();
|
||||
END;
|
||||
$$ LANGUAGE plpgsql;
|
||||
SELECT 2;`
|
||||
got := splitSQLStatements(input)
|
||||
want := []string{
|
||||
`CREATE OR REPLACE FUNCTION refresh_stats() RETURNS void AS $$
|
||||
BEGIN
|
||||
PERFORM refresh_now();
|
||||
END;
|
||||
$$ LANGUAGE plpgsql`,
|
||||
"SELECT 2",
|
||||
}
|
||||
if !reflect.DeepEqual(got, want) {
|
||||
t.Errorf("splitSQLStatements(%q) = %#v, want %#v", input, got, want)
|
||||
}
|
||||
}
|
||||
|
||||
func TestSplitSQLStatements_FullWidthSemicolon(t *testing.T) {
|
||||
input := "SELECT 1;SELECT 2"
|
||||
got := splitSQLStatements(input)
|
||||
@@ -155,6 +176,41 @@ END;`,
|
||||
}
|
||||
}
|
||||
|
||||
func TestSplitSQLStatements_OracleCreateProcedureBlock(t *testing.T) {
|
||||
input := `CREATE OR REPLACE PROCEDURE proc_tally2accept(
|
||||
p_tallyacceptno IN t_tally_accept_h.acceptno%TYPE,
|
||||
out_acceptno OUT t_accept_h.acceptno%TYPE
|
||||
) IS
|
||||
v_busno t_tally_accept_h.busno%TYPE;
|
||||
v_count PLS_INTEGER;
|
||||
BEGIN
|
||||
SELECT COUNT(*) INTO v_count FROM t_tally_accept_h WHERE acceptno = p_tallyacceptno;
|
||||
IF v_count > 0 THEN
|
||||
out_acceptno := p_tallyacceptno;
|
||||
END IF;
|
||||
END;
|
||||
SELECT 1 FROM dual;`
|
||||
got := splitSQLStatements(input)
|
||||
want := []string{
|
||||
`CREATE OR REPLACE PROCEDURE proc_tally2accept(
|
||||
p_tallyacceptno IN t_tally_accept_h.acceptno%TYPE,
|
||||
out_acceptno OUT t_accept_h.acceptno%TYPE
|
||||
) IS
|
||||
v_busno t_tally_accept_h.busno%TYPE;
|
||||
v_count PLS_INTEGER;
|
||||
BEGIN
|
||||
SELECT COUNT(*) INTO v_count FROM t_tally_accept_h WHERE acceptno = p_tallyacceptno;
|
||||
IF v_count > 0 THEN
|
||||
out_acceptno := p_tallyacceptno;
|
||||
END IF;
|
||||
END;`,
|
||||
"SELECT 1 FROM dual",
|
||||
}
|
||||
if !reflect.DeepEqual(got, want) {
|
||||
t.Errorf("splitSQLStatements(%q) = %#v, want %#v", input, got, want)
|
||||
}
|
||||
}
|
||||
|
||||
func TestSplitSQLStatements_TransactionBeginStillSplits(t *testing.T) {
|
||||
input := "BEGIN; UPDATE accounts SET balance = balance - 1 WHERE id = 1; COMMIT;"
|
||||
got := splitSQLStatements(input)
|
||||
|
||||
Reference in New Issue
Block a user