🐛 fix(sql-editor): 修复脚本执行拆分与元数据只读提示

- Oracle 匿名块:识别 BEGIN/DECLARE...END 块,避免按内部分号错误拆分
- 执行路径:PL/SQL 块跳过批量写入路径,保持单条语句语义
- SQL 文件:同步修复流式 SQL 文件拆分逻辑
- 查询结果:系统元数据表保持只读但不再弹业务表主键提示
- 测试覆盖:补充前后端拆分、执行和 information_schema 回归用例
This commit is contained in:
Syngnat
2026-06-03 17:11:05 +08:00
parent 4b23c013d9
commit 1ae44941dd
12 changed files with 779 additions and 7 deletions

View File

@@ -1 +1 @@
d0464f9da25e9356e61652e638c99ffe
0295a42fd931778d85157816d79d29e5

View File

@@ -1676,6 +1676,42 @@ describe('QueryEditor external SQL save', () => {
renderer?.unmount();
});
it('keeps Oracle anonymous PL/SQL blocks intact when running from the editor', async () => {
storeState.connections[0].config.type = 'oracle';
storeState.connections[0].config.database = 'ORCLPDB1';
backendApp.DBQueryMulti.mockResolvedValueOnce({
success: true,
data: [{ columns: ['affectedRows'], rows: [{ affectedRows: 1 }] }],
});
const plsql = [
'BEGIN',
" INSERT INTO tmp_disable_trigger (table_name) VALUES ('t_memcard_reg');",
" UPDATE t_memcard_reg SET CARDLEVEL = 1 WHERE MEMCARDNO = '8032277312';",
" DELETE FROM tmp_disable_trigger WHERE table_name = 't_memcard_reg';",
'END;',
].join('\n');
let renderer!: ReactTestRenderer;
await act(async () => {
renderer = create(<QueryEditor tab={createTab({ dbName: 'ORCLPDB1', query: plsql })} />);
});
await act(async () => {
await findButton(renderer!, '运行').props.onClick();
});
await act(async () => {
await Promise.resolve();
await Promise.resolve();
});
expect(backendApp.DBQueryMulti).toHaveBeenCalledWith(expect.anything(), 'ORCLPDB1', plsql, 'query-1');
expect(storeState.addSqlLog).toHaveBeenCalledWith(expect.objectContaining({
sql: plsql,
status: 'success',
}));
renderer?.unmount();
});
it('keeps non-Oracle query results read-only when no safe locator exists', async () => {
backendApp.DBQueryMulti.mockResolvedValueOnce({
success: true,
@@ -1710,6 +1746,46 @@ describe('QueryEditor external SQL save', () => {
expect(messageApi.warning).toHaveBeenCalledWith('查询结果保持只读main.users 未检测到主键或可用唯一索引,无法安全提交修改。');
});
it('keeps MySQL information_schema routine results read-only without a locator warning', async () => {
const sql = [
'SELECT ROUTINE_SCHEMA, ROUTINE_NAME, DEFINER, SECURITY_TYPE',
'FROM information_schema.ROUTINES',
"WHERE ROUTINE_SCHEMA = 'mkefu_location_dev_local'",
" AND ROUTINE_NAME = 'init_orgi'",
].join('\n');
backendApp.DBQueryMulti.mockResolvedValueOnce({
success: true,
data: [{
columns: ['ROUTINE_SCHEMA', 'ROUTINE_NAME', 'DEFINER', 'SECURITY_TYPE'],
rows: [{
ROUTINE_SCHEMA: 'mkefu_location_dev_local',
ROUTINE_NAME: 'init_orgi',
DEFINER: 'root@%',
SECURITY_TYPE: 'DEFINER',
}],
}],
});
let renderer: ReactTestRenderer;
await act(async () => {
renderer = create(<QueryEditor tab={createTab({ dbName: 'mkefu_location_dev_local', query: sql })} />);
});
await act(async () => {
await findButton(renderer!, '运行').props.onClick();
});
await act(async () => {
await Promise.resolve();
await Promise.resolve();
});
expect(dataGridState.latestProps?.tableName).toBe('ROUTINES');
expect(dataGridState.latestProps?.readOnly).toBe(true);
expect(backendApp.DBGetColumns).not.toHaveBeenCalled();
expect(backendApp.DBGetIndexes).not.toHaveBeenCalled();
expect(messageApi.warning).not.toHaveBeenCalled();
});
it('runs the SQL statement at the cursor instead of the whole editor when nothing is selected', async () => {
backendApp.DBQueryMulti.mockResolvedValueOnce({
success: true,

View File

@@ -262,6 +262,33 @@ const stripQueryIdentifierQuotes = (part: string): string => {
return text;
};
const MYSQL_SYSTEM_METADATA_SCHEMAS = new Set(['information_schema', 'performance_schema', 'mysql', 'sys']);
const POSTGRES_SYSTEM_METADATA_SCHEMAS = new Set(['information_schema', 'pg_catalog']);
const SQLITE_SYSTEM_METADATA_TABLES = new Set(['sqlite_master', 'sqlite_schema', 'sqlite_temp_master', 'sqlite_temp_schema']);
const isSystemMetadataQueryResult = (tableRef: QueryResultTableRef, dbType: string): boolean => {
const normalizedDbType = String(dbType || '').trim().toLowerCase();
const metadataDbName = stripQueryIdentifierQuotes(tableRef.metadataDbName).toLowerCase();
const metadataTableName = stripQueryIdentifierQuotes(tableRef.metadataTableName).toLowerCase();
if (['mysql', 'mariadb', 'oceanbase', 'diros', 'starrocks', 'sphinx', 'tidb'].includes(normalizedDbType)) {
return MYSQL_SYSTEM_METADATA_SCHEMAS.has(metadataDbName);
}
if (['postgres', 'kingbase', 'highgo', 'vastbase', 'opengauss'].includes(normalizedDbType)) {
return POSTGRES_SYSTEM_METADATA_SCHEMAS.has(metadataDbName);
}
if (normalizedDbType === 'sqlite' || normalizedDbType === 'duckdb') {
return SQLITE_SYSTEM_METADATA_TABLES.has(metadataTableName) || metadataDbName === 'information_schema';
}
if (normalizedDbType === 'sqlserver') {
return metadataDbName === 'information_schema' || metadataDbName === 'sys';
}
if (normalizedDbType === 'clickhouse') {
return metadataDbName === 'system' || metadataDbName === 'information_schema';
}
return false;
};
const splitTopLevelComma = (text: string): string[] => {
const parts: string[] = [];
let current = '';
@@ -658,6 +685,65 @@ const areSqlStatementListsEqual = (left: string[], right: string[]): boolean =>
&& left.every((statement, index) => normalizeExecutedSqlKey(statement) === normalizeExecutedSqlKey(right[index]))
);
const isSqlIdentifierStart = (ch: string): boolean => /^[A-Za-z_]$/.test(ch);
const isSqlIdentifierPart = (ch: string): boolean => /^[A-Za-z0-9_$#]$/.test(ch);
const skipSqlWhitespaceAndComments = (text: string, position: number): number => {
let index = position;
while (index < text.length) {
const ch = text[index];
const next = index + 1 < text.length ? text[index + 1] : '';
if (ch === ' ' || ch === '\t' || ch === '\n' || ch === '\r' || ch === '\f') {
index += 1;
continue;
}
if (ch === '-' && next === '-') {
index += 2;
while (index < text.length && text[index] !== '\n') index += 1;
continue;
}
if (ch === '/' && next === '*') {
index += 2;
while (index + 1 < text.length && !(text[index] === '*' && text[index + 1] === '/')) {
index += 1;
}
if (index + 1 < text.length) index += 2;
continue;
}
break;
}
return index;
};
const nextSqlSignificantToken = (text: string, position: number): string => {
const index = skipSqlWhitespaceAndComments(text, position);
if (index >= text.length || !isSqlIdentifierStart(text[index])) return '';
let end = index + 1;
while (end < text.length && isSqlIdentifierPart(text[end])) end += 1;
return text.slice(index, end).toLowerCase();
};
const nextSqlSignificantChar = (text: string, position: number): string => {
const index = skipSqlWhitespaceAndComments(text, position);
return index >= text.length ? '' : text[index];
};
const shouldEnterPlsqlBeginBlock = (text: string, tokenEnd: number): boolean => {
const nextChar = nextSqlSignificantChar(text, tokenEnd);
if (!nextChar || nextChar === ';') return false;
return !['transaction', 'work', 'isolation', 'read', 'write'].includes(nextSqlSignificantToken(text, tokenEnd));
};
const shouldEnterPlsqlDeclareBlock = (text: string, tokenEnd: number): boolean => {
const nextToken = nextSqlSignificantToken(text, tokenEnd);
return Boolean(nextToken);
};
const isPlsqlControlEnd = (text: string, tokenEnd: number): boolean => (
['if', 'loop', 'case'].includes(nextSqlSignificantToken(text, tokenEnd))
);
const normalizeEditorPosition = (position: any): { lineNumber: number; column: number } | null => {
if (!position) return null;
const lineNumber = Number(position.positionLineNumber ?? position.lineNumber ?? position.endLineNumber ?? position.startLineNumber ?? position.selectionStartLineNumber);
@@ -1563,6 +1649,10 @@ const resolveQueryLocatorPlan = async ({
const tableRef = extractQueryResultTableRef(statement, dbType, currentDb);
if (!tableRef) return plan;
plan.tableRef = tableRef;
if (isSystemMetadataQueryResult(tableRef, dbType)) {
plan.editLocator = buildQueryReadOnlyLocator('系统元数据查询结果保持只读。');
return plan;
}
const selectInfo = parseSimpleSelectInfo(statement);
if (!selectInfo) {
@@ -3604,6 +3694,9 @@ const QueryEditor: React.FC<{ tab: TabData; isActive?: boolean }> = ({ tab, isAc
let inLineComment = false;
let inBlockComment = false;
let dollarTag: string | null = null; // postgres/kingbase: $$...$$ or $tag$...$tag$
let plsqlDepth = 0;
let plsqlDeclareBeginSkips = 0;
let justClosedPLSQLBlock = false;
const push = () => {
const s = cur.trim();
@@ -3705,7 +3798,45 @@ const QueryEditor: React.FC<{ tab: TabData; isActive?: boolean }> = ({ tab, isAc
continue;
}
if (!inSingle && !inDouble && !inBacktick && !dollarTag && isSqlIdentifierStart(ch)) {
let end = i + 1;
while (end < text.length && isSqlIdentifierPart(text[end])) {
end += 1;
}
const token = text.slice(i, end).toLowerCase();
if (token === 'begin' && plsqlDeclareBeginSkips > 0) {
plsqlDeclareBeginSkips -= 1;
justClosedPLSQLBlock = false;
} else if (token === 'begin' && shouldEnterPlsqlBeginBlock(text, end)) {
plsqlDepth += 1;
justClosedPLSQLBlock = false;
} else if (token === 'declare' && shouldEnterPlsqlDeclareBlock(text, end)) {
plsqlDepth += 1;
plsqlDeclareBeginSkips += 1;
justClosedPLSQLBlock = false;
} else if (token === 'end' && plsqlDepth > 0 && !isPlsqlControlEnd(text, end)) {
plsqlDepth -= 1;
if (plsqlDeclareBeginSkips > plsqlDepth) {
plsqlDeclareBeginSkips = plsqlDepth;
}
justClosedPLSQLBlock = plsqlDepth === 0;
}
cur += text.slice(i, end);
i = end - 1;
continue;
}
if (!inSingle && !inDouble && !inBacktick && !dollarTag && (ch === ';' || ch === '')) {
if (plsqlDepth > 0) {
cur += ch;
continue;
}
if (justClosedPLSQLBlock) {
cur += ch;
push();
justClosedPLSQLBlock = false;
continue;
}
push();
continue;
}

View File

@@ -28,6 +28,74 @@ describe('sqlStatementSelection', () => {
]);
});
it('keeps Oracle anonymous PL/SQL blocks as one executable statement', () => {
const plsql = [
'BEGIN',
" INSERT INTO tmp_disable_trigger (table_name) VALUES ('t_memcard_reg');",
" UPDATE t_memcard_reg SET CARDLEVEL = 1 WHERE MEMCARDNO = '8032277312';",
" DELETE FROM tmp_disable_trigger WHERE table_name = 't_memcard_reg';",
'END;',
'SELECT 1 FROM dual;',
].join('\n');
const ranges = findSqlStatementRanges(plsql).map((range) => range.text);
expect(ranges).toEqual([
[
'BEGIN',
" INSERT INTO tmp_disable_trigger (table_name) VALUES ('t_memcard_reg');",
" UPDATE t_memcard_reg SET CARDLEVEL = 1 WHERE MEMCARDNO = '8032277312';",
" DELETE FROM tmp_disable_trigger WHERE table_name = 't_memcard_reg';",
'END;',
].join('\n'),
'SELECT 1 FROM dual',
]);
expect(resolveExecutableSql(plsql, plsql.indexOf('UPDATE'))).toEqual({
sql: ranges[0],
source: 'statement',
});
});
it('keeps Oracle DECLARE blocks as one executable statement', () => {
const sql = [
'DECLARE',
' v_count NUMBER;',
'BEGIN',
' SELECT COUNT(*) INTO v_count FROM t_memcard_reg;',
" UPDATE t_memcard_reg SET CARDLEVEL = v_count WHERE MEMCARDNO = '8032277312';",
'END;',
'SELECT 1 FROM dual;',
].join('\n');
const ranges = findSqlStatementRanges(sql).map((range) => range.text);
expect(ranges).toEqual([
[
'DECLARE',
' v_count NUMBER;',
'BEGIN',
' SELECT COUNT(*) INTO v_count FROM t_memcard_reg;',
" UPDATE t_memcard_reg SET CARDLEVEL = v_count WHERE MEMCARDNO = '8032277312';",
'END;',
].join('\n'),
'SELECT 1 FROM dual',
]);
expect(resolveExecutableSql(sql, sql.indexOf('UPDATE'))).toEqual({
sql: ranges[0],
source: 'statement',
});
});
it('still splits transaction BEGIN statements', () => {
const sql = 'BEGIN; UPDATE accounts SET balance = balance - 1 WHERE id = 1; COMMIT;';
expect(findSqlStatementRanges(sql).map((range) => range.text)).toEqual([
'BEGIN',
'UPDATE accounts SET balance = balance - 1 WHERE id = 1',
'COMMIT',
]);
});
it('selects the next statement when the cursor is on whitespace before it', () => {
const sql = 'select 1;\n\n select 2;';
const range = resolveCurrentSqlStatementRange(sql, sql.indexOf(' select 2'));

View File

@@ -12,7 +12,63 @@ export interface SqlExecutionSelection {
}
const isWhitespace = (ch: string): boolean => (
ch === ' ' || ch === '\t' || ch === '\n' || ch === '\r'
ch === ' ' || ch === '\t' || ch === '\n' || ch === '\r' || ch === '\f'
);
const isSqlIdentifierStart = (ch: string): boolean => /^[A-Za-z_]$/.test(ch);
const isSqlIdentifierPart = (ch: string): boolean => /^[A-Za-z0-9_$#]$/.test(ch);
const skipSqlWhitespaceAndComments = (text: string, position: number): number => {
let index = position;
while (index < text.length) {
const ch = text[index];
const next = index + 1 < text.length ? text[index + 1] : '';
if (isWhitespace(ch)) {
index += 1;
continue;
}
if (ch === '-' && next === '-') {
index += 2;
while (index < text.length && text[index] !== '\n') index += 1;
continue;
}
if (ch === '/' && next === '*') {
index += 2;
while (index + 1 < text.length && !(text[index] === '*' && text[index + 1] === '/')) {
index += 1;
}
if (index + 1 < text.length) index += 2;
continue;
}
break;
}
return index;
};
const nextSqlSignificantToken = (text: string, position: number): string => {
const index = skipSqlWhitespaceAndComments(text, position);
if (index >= text.length || !isSqlIdentifierStart(text[index])) return '';
let end = index + 1;
while (end < text.length && isSqlIdentifierPart(text[end])) end += 1;
return text.slice(index, end).toLowerCase();
};
const nextSqlSignificantChar = (text: string, position: number): string => {
const index = skipSqlWhitespaceAndComments(text, position);
return index >= text.length ? '' : text[index];
};
const shouldEnterPlsqlBeginBlock = (text: string, tokenEnd: number): boolean => {
const nextChar = nextSqlSignificantChar(text, tokenEnd);
if (!nextChar || nextChar === ';') return false;
return !['transaction', 'work', 'isolation', 'read', 'write'].includes(nextSqlSignificantToken(text, tokenEnd));
};
const shouldEnterPlsqlDeclareBlock = (text: string, tokenEnd: number): boolean => Boolean(nextSqlSignificantToken(text, tokenEnd));
const isPlsqlControlEnd = (text: string, tokenEnd: number): boolean => (
['if', 'loop', 'case'].includes(nextSqlSignificantToken(text, tokenEnd))
);
const trimStatementRange = (sql: string, start: number, end: number): SqlStatementRange | null => {
@@ -49,6 +105,9 @@ export const findSqlStatementRanges = (sql: string): SqlStatementRange[] => {
let inLineComment = false;
let inBlockComment = false;
let dollarTag: string | null = null;
let plsqlDepth = 0;
let plsqlDeclareBeginSkips = 0;
let justClosedPLSQLBlock = false;
const push = (end: number) => {
const range = trimStatementRange(text, statementStart, end);
@@ -134,9 +193,41 @@ export const findSqlStatementRanges = (sql: string): SqlStatementRange[] => {
continue;
}
if (!inSingle && !inDouble && !inBacktick && !dollarTag && isSqlIdentifierStart(ch)) {
let tokenEnd = index + 1;
while (tokenEnd < text.length && isSqlIdentifierPart(text[tokenEnd])) {
tokenEnd++;
}
const token = text.slice(index, tokenEnd).toLowerCase();
if (token === 'begin' && plsqlDeclareBeginSkips > 0) {
plsqlDeclareBeginSkips--;
justClosedPLSQLBlock = false;
} else if (token === 'begin' && shouldEnterPlsqlBeginBlock(text, tokenEnd)) {
plsqlDepth++;
justClosedPLSQLBlock = false;
} else if (token === 'declare' && shouldEnterPlsqlDeclareBlock(text, tokenEnd)) {
plsqlDepth++;
plsqlDeclareBeginSkips++;
justClosedPLSQLBlock = false;
} else if (token === 'end' && plsqlDepth > 0 && !isPlsqlControlEnd(text, tokenEnd)) {
plsqlDepth--;
if (plsqlDeclareBeginSkips > plsqlDepth) {
plsqlDeclareBeginSkips = plsqlDepth;
}
justClosedPLSQLBlock = plsqlDepth === 0;
}
index = tokenEnd - 1;
continue;
}
if (!inSingle && !inDouble && !inBacktick && (ch === ';' || ch === '')) {
push(index);
if (plsqlDepth > 0) {
continue;
}
push(justClosedPLSQLBlock ? index + 1 : index);
statementStart = index + 1;
justClosedPLSQLBlock = false;
continue;
}
}

View File

@@ -770,13 +770,16 @@ func (a *App) DBQueryMulti(config connection.ConnectionConfig, dbName string, qu
// 适用于 MySQL/MariaDB/Doris/PostgreSQL/SQLite/DuckDB 等支持多语句 Exec 的驱动
if !allReadOnly {
allWrite := true
containsPLSQLBlock := false
for _, stmt := range statements {
if strings.TrimSpace(stmt) != "" && isReadOnlySQLQuery(runConfig.Type, stmt) {
allWrite = false
break
}
if isPLSQLBlockStatement(stmt) {
containsPLSQLBlock = true
}
}
if allWrite {
if allWrite && !containsPLSQLBlock {
if batcher, ok := dbInst.(db.BatchWriteExecer); ok {
affected, batchErr := batcher.ExecBatchContext(ctx, query)
if batchErr != nil && shouldRefreshCachedConnection(batchErr) {

View File

@@ -12,6 +12,7 @@ import (
type fakeBatchWriteDB struct {
batchCalls int
execCalls int
execQueries []string
lastQuery string
}
@@ -33,6 +34,7 @@ func (f *fakeBatchWriteDB) Query(query string) ([]map[string]interface{}, []stri
func (f *fakeBatchWriteDB) Exec(query string) (int64, error) {
f.execCalls++
f.execQueries = append(f.execQueries, query)
return 1, nil
}
@@ -70,6 +72,7 @@ func (f *fakeBatchWriteDB) GetTriggers(dbName, tableName string) ([]connection.T
func (f *fakeBatchWriteDB) ExecContext(ctx context.Context, query string) (int64, error) {
f.execCalls++
f.execQueries = append(f.execQueries, query)
return 1, nil
}
@@ -79,6 +82,45 @@ func (f *fakeBatchWriteDB) ExecBatchContext(ctx context.Context, query string) (
return 500, nil
}
func TestDBQueryMultiKeepsOracleAnonymousBlockAsSingleStatement(t *testing.T) {
originalNewDatabaseFunc := newDatabaseFunc
t.Cleanup(func() {
newDatabaseFunc = originalNewDatabaseFunc
})
fakeDB := &fakeBatchWriteDB{}
newDatabaseFunc = func(dbType string) (db.Database, error) {
return fakeDB, nil
}
app := NewAppWithSecretStore(secretstore.NewUnavailableStore("test"))
config := connection.ConnectionConfig{
Type: "oracle",
Host: "127.0.0.1",
Port: 1521,
User: "app",
}
query := `BEGIN
INSERT INTO tmp_disable_trigger (table_name) VALUES ('t_memcard_reg');
UPDATE t_memcard_reg SET CARDLEVEL = 1 WHERE MEMCARDNO = '8032277312';
DELETE FROM tmp_disable_trigger WHERE table_name = 't_memcard_reg';
END;`
result := app.DBQueryMulti(config, "ORCLPDB1", query, "oracle-plsql-test")
if !result.Success {
t.Fatalf("expected DBQueryMulti success, got failure: %s", result.Message)
}
if fakeDB.batchCalls != 0 {
t.Fatalf("expected PL/SQL block to skip batch path, got batchCalls=%d", fakeDB.batchCalls)
}
if fakeDB.execCalls != 1 || len(fakeDB.execQueries) != 1 {
t.Fatalf("expected one sequential exec call, got execCalls=%d queries=%#v", fakeDB.execCalls, fakeDB.execQueries)
}
if fakeDB.execQueries[0] != query {
t.Fatalf("expected PL/SQL block to stay intact, got %q", fakeDB.execQueries[0])
}
}
var _ db.BatchWriteExecer = (*fakeBatchWriteDB)(nil)
func TestDBQueryMultiUsesBatchWriteExecerForAllWriteStatements(t *testing.T) {

View File

@@ -662,6 +662,9 @@ func isSQLFileBatchableWriteStatement(dbType string, stmt string) bool {
if isReadOnlySQLQuery(dbType, stmt) {
return false
}
if isPLSQLBlockStatement(stmt) {
return false
}
switch leadingSQLKeyword(stmt) {
case "insert", "update", "delete", "replace", "merge", "upsert":
return true

View File

@@ -322,3 +322,38 @@ func TestStreamSQLFileHandlesSplitTokenBoundaries(t *testing.T) {
t.Fatalf("unexpected full-width semicolon statement: %q", statements[2])
}
}
func TestStreamSQLFileKeepsOracleAnonymousBlockTogether(t *testing.T) {
input := strings.Join([]string{
"BEGIN",
" INSERT INTO tmp_disable_trigger (table_name) VALUES ('t_memcard_reg');",
" UPDATE t_memcard_reg SET CARDLEVEL = 1 WHERE MEMCARDNO = '8032277312';",
" DELETE FROM tmp_disable_trigger WHERE table_name = 't_memcard_reg';",
"END;",
"SELECT 1 FROM dual;",
}, "\n")
var statements []string
count, err := streamSQLFile(&chunkedReader{data: []byte(input), step: 3}, func(index int, stmt string) error {
statements = append(statements, stmt)
return nil
})
if err != nil {
t.Fatalf("streamSQLFile returned error: %v", err)
}
if count != 2 || len(statements) != 2 {
t.Fatalf("expected 2 statements, got count=%d statements=%#v", count, statements)
}
if statements[0] != strings.Join([]string{
"BEGIN",
" INSERT INTO tmp_disable_trigger (table_name) VALUES ('t_memcard_reg');",
" UPDATE t_memcard_reg SET CARDLEVEL = 1 WHERE MEMCARDNO = '8032277312';",
" DELETE FROM tmp_disable_trigger WHERE table_name = 't_memcard_reg';",
"END;",
}, "\n") {
t.Fatalf("unexpected anonymous block statement: %q", statements[0])
}
if statements[1] != "SELECT 1 FROM dual" {
t.Fatalf("unexpected second statement: %q", statements[1])
}
}

View File

@@ -3,8 +3,9 @@ package app
import "strings"
// splitSQLStatements 按分号拆分 SQL 文本为独立语句。
// 正确处理单引号/双引号/反引号字符串、行注释(-- / #)、块注释(/* */
// PostgreSQL/Kingbase 的 $$...$$ dollar-quoting避免在这些上下文中错误拆分。
// 正确处理单引号/双引号/反引号字符串、行注释(-- / #)、块注释(/* */
// PostgreSQL/Kingbase 的 $$...$$ dollar-quoting以及 Oracle PL/SQL 匿名块,
// 避免在这些上下文中错误拆分。
// 同时支持 SQL 标准的转义单引号(两个连续单引号 '' 表示字面量引号)。
func splitSQLStatements(sql string) []string {
text := strings.ReplaceAll(sql, "\r\n", "\n")
@@ -18,6 +19,9 @@ func splitSQLStatements(sql string) []string {
inLineComment := false
inBlockComment := false
var dollarTag string // postgres/kingbase: $$...$$ or $tag$...$tag$
plsqlDepth := 0
plsqlDeclareBeginSkips := 0
justClosedPLSQLBlock := false
push := func() {
s := strings.TrimSpace(cur.String())
@@ -108,6 +112,35 @@ func splitSQLStatements(sql string) []string {
continue
}
if isSQLIdentifierStart(ch) {
tokenStart := i
tokenEnd := i + 1
for tokenEnd < len(text) && isSQLIdentifierPart(text[tokenEnd]) {
tokenEnd++
}
token := strings.ToLower(text[tokenStart:tokenEnd])
if token == "begin" && plsqlDeclareBeginSkips > 0 {
plsqlDeclareBeginSkips--
justClosedPLSQLBlock = false
} else if token == "begin" && shouldEnterPLSQLBlock(text, tokenEnd) {
plsqlDepth++
justClosedPLSQLBlock = false
} else if token == "declare" && shouldEnterPLSQLDeclareBlock(text, tokenEnd) {
plsqlDepth++
plsqlDeclareBeginSkips++
justClosedPLSQLBlock = false
} else if token == "end" && plsqlDepth > 0 && !isPLSQLControlEnd(text, tokenEnd) {
plsqlDepth--
if plsqlDeclareBeginSkips > plsqlDepth {
plsqlDeclareBeginSkips = plsqlDepth
}
justClosedPLSQLBlock = plsqlDepth == 0
}
cur.WriteString(text[tokenStart:tokenEnd])
i = tokenEnd - 1
continue
}
// 行注释开始
if ch == '-' && next == '-' {
inLineComment = true
@@ -140,11 +173,33 @@ func splitSQLStatements(sql string) []string {
// 分号分隔(支持全角分号""
if ch == ';' {
if plsqlDepth > 0 {
cur.WriteByte(ch)
continue
}
if justClosedPLSQLBlock {
cur.WriteByte(ch)
push()
justClosedPLSQLBlock = false
continue
}
push()
continue
}
// 全角分号 UTF-8 序列: 0xEF 0xBC 0x9B
if ch == 0xEF && i+2 < len(text) && text[i+1] == 0xBC && text[i+2] == 0x9B {
if plsqlDepth > 0 {
cur.WriteString("")
i += 2
continue
}
if justClosedPLSQLBlock {
cur.WriteString("")
push()
justClosedPLSQLBlock = false
i += 2
continue
}
push()
i += 2
continue
@@ -157,6 +212,110 @@ func splitSQLStatements(sql string) []string {
return statements
}
func isSQLIdentifierStart(ch byte) bool {
return (ch >= 'a' && ch <= 'z') || (ch >= 'A' && ch <= 'Z') || ch == '_'
}
func isSQLIdentifierPart(ch byte) bool {
return isSQLIdentifierStart(ch) || (ch >= '0' && ch <= '9') || ch == '$' || ch == '#'
}
func skipSQLWhitespaceAndComments(text string, pos int) int {
i := pos
for i < len(text) {
switch text[i] {
case ' ', '\t', '\n', '\r', '\f':
i++
continue
case '-':
if i+1 < len(text) && text[i+1] == '-' {
i += 2
for i < len(text) && text[i] != '\n' {
i++
}
continue
}
case '/':
if i+1 < len(text) && text[i+1] == '*' {
i += 2
for i+1 < len(text) && !(text[i] == '*' && text[i+1] == '/') {
i++
}
if i+1 < len(text) {
i += 2
}
continue
}
}
break
}
return i
}
func nextSQLSignificantToken(text string, pos int) string {
i := skipSQLWhitespaceAndComments(text, pos)
if i >= len(text) || !isSQLIdentifierStart(text[i]) {
return ""
}
end := i + 1
for end < len(text) && isSQLIdentifierPart(text[end]) {
end++
}
return strings.ToLower(text[i:end])
}
func nextSQLSignificantByte(text string, pos int) byte {
i := skipSQLWhitespaceAndComments(text, pos)
if i >= len(text) {
return 0
}
return text[i]
}
func shouldEnterPLSQLBlock(text string, tokenEnd int) bool {
switch nextSQLSignificantByte(text, tokenEnd) {
case 0, ';':
return false
}
switch nextSQLSignificantToken(text, tokenEnd) {
case "transaction", "work", "isolation", "read", "write":
return false
default:
return true
}
}
func isPLSQLBlockStatement(stmt string) bool {
text := strings.TrimSpace(stmt)
if text == "" {
return false
}
if strings.HasSuffix(text, "/") {
text = strings.TrimSpace(strings.TrimSuffix(text, "/"))
}
token := nextSQLSignificantToken(text, 0)
if token == "declare" {
return shouldEnterPLSQLDeclareBlock(text, len("declare"))
}
if token != "begin" {
return false
}
return shouldEnterPLSQLBlock(text, len("begin"))
}
func shouldEnterPLSQLDeclareBlock(text string, tokenEnd int) bool {
return nextSQLSignificantToken(text, tokenEnd) != ""
}
func isPLSQLControlEnd(text string, tokenEnd int) bool {
switch nextSQLSignificantToken(text, tokenEnd) {
case "if", "loop", "case":
return true
default:
return false
}
}
// parseSQLDollarTag 解析 PostgreSQL/Kingbase 的 dollar-quoting 标签。
func parseSQLDollarTag(s string) string {
if len(s) < 2 || s[0] != '$' {

View File

@@ -18,6 +18,9 @@ type sqlStreamSplitter struct {
inLineComment bool
inBlockComment bool
dollarTag string
plsqlDepth int
declareSkips int
closedPLSQL bool
}
// Feed 将一个 chunk 喂入拆分器,返回在此 chunk 中完成的 SQL 语句列表。
@@ -118,6 +121,43 @@ func (s *sqlStreamSplitter) Feed(chunk []byte) []string {
continue
}
if isSQLIdentifierStart(ch) {
tokenStart := i
tokenEnd := i + 1
for tokenEnd < len(text) && isSQLIdentifierPart(text[tokenEnd]) {
tokenEnd++
}
token := strings.ToLower(text[tokenStart:tokenEnd])
if shouldDeferPLSQLKeywordPrefixInStream(text, tokenStart, tokenEnd, token) {
s.pending = text[tokenStart:]
break
}
if shouldDeferPLSQLKeywordInStream(text, tokenStart, tokenEnd, token) {
s.pending = text[tokenStart:]
break
}
if token == "begin" && s.declareSkips > 0 {
s.declareSkips--
s.closedPLSQL = false
} else if token == "begin" && shouldEnterPLSQLBlock(text, tokenEnd) {
s.plsqlDepth++
s.closedPLSQL = false
} else if token == "declare" && shouldEnterPLSQLDeclareBlock(text, tokenEnd) {
s.plsqlDepth++
s.declareSkips++
s.closedPLSQL = false
} else if token == "end" && s.plsqlDepth > 0 && !isPLSQLControlEnd(text, tokenEnd) {
s.plsqlDepth--
if s.declareSkips > s.plsqlDepth {
s.declareSkips = s.plsqlDepth
}
s.closedPLSQL = s.plsqlDepth == 0
}
s.cur.WriteString(text[tokenStart:tokenEnd])
i = tokenEnd - 1
continue
}
// 行注释开始
if ch == '-' && i+1 >= len(text) {
s.pending = text[i:]
@@ -162,6 +202,20 @@ func (s *sqlStreamSplitter) Feed(chunk []byte) []string {
// 分号分隔
if ch == ';' {
if s.plsqlDepth > 0 {
s.cur.WriteByte(ch)
continue
}
if s.closedPLSQL {
s.cur.WriteByte(ch)
stmt := strings.TrimSpace(s.cur.String())
if stmt != "" {
statements = append(statements, stmt)
}
s.cur.Reset()
s.closedPLSQL = false
continue
}
stmt := strings.TrimSpace(s.cur.String())
if stmt != "" {
statements = append(statements, stmt)
@@ -175,6 +229,22 @@ func (s *sqlStreamSplitter) Feed(chunk []byte) []string {
break
}
if ch == 0xEF && i+2 < len(text) && text[i+1] == 0xBC && text[i+2] == 0x9B {
if s.plsqlDepth > 0 {
s.cur.WriteString("")
i += 2
continue
}
if s.closedPLSQL {
s.cur.WriteString("")
stmt := strings.TrimSpace(s.cur.String())
if stmt != "" {
statements = append(statements, stmt)
}
s.cur.Reset()
s.closedPLSQL = false
i += 2
continue
}
stmt := strings.TrimSpace(s.cur.String())
if stmt != "" {
statements = append(statements, stmt)
@@ -217,6 +287,44 @@ func isIncompleteSQLDollarTag(s string) bool {
return true
}
func shouldDeferPLSQLKeywordInStream(text string, tokenStart int, tokenEnd int, token string) bool {
switch token {
case "begin", "declare", "end":
default:
return false
}
if tokenEnd >= len(text) {
return true
}
next := skipSQLWhitespaceAndComments(text, tokenEnd)
if next >= len(text) {
return true
}
if isSQLIdentifierStart(text[next]) {
nextEnd := next + 1
for nextEnd < len(text) && isSQLIdentifierPart(text[nextEnd]) {
nextEnd++
}
return nextEnd >= len(text)
}
return false
}
func shouldDeferPLSQLKeywordPrefixInStream(text string, tokenStart int, tokenEnd int, token string) bool {
if tokenEnd < len(text) {
return false
}
for _, keyword := range []string{"begin", "declare", "end"} {
if strings.HasPrefix(keyword, token) && token != keyword {
if tokenStart > 0 && isSQLIdentifierPart(text[tokenStart-1]) {
return false
}
return true
}
}
return false
}
// streamSQLFile 从 reader 中流式读取 SQL 并逐条回调。
// onStatement 返回 error 时停止读取并返回该 error。
// 返回总处理语句数和可能的错误。

View File

@@ -111,3 +111,59 @@ func TestSplitSQLStatements_SQLEscapedQuoteMultiple(t *testing.T) {
}
}
func TestSplitSQLStatements_OracleAnonymousBlock(t *testing.T) {
input := `BEGIN
INSERT INTO tmp_disable_trigger (table_name) VALUES ('t_memcard_reg');
UPDATE t_memcard_reg SET CARDLEVEL = 1 WHERE MEMCARDNO = '8032277312';
DELETE FROM tmp_disable_trigger WHERE table_name = 't_memcard_reg';
END;
SELECT 1 FROM dual;`
got := splitSQLStatements(input)
want := []string{
`BEGIN
INSERT INTO tmp_disable_trigger (table_name) VALUES ('t_memcard_reg');
UPDATE t_memcard_reg SET CARDLEVEL = 1 WHERE MEMCARDNO = '8032277312';
DELETE FROM tmp_disable_trigger WHERE table_name = 't_memcard_reg';
END;`,
"SELECT 1 FROM dual",
}
if !reflect.DeepEqual(got, want) {
t.Errorf("splitSQLStatements(%q) = %#v, want %#v", input, got, want)
}
}
func TestSplitSQLStatements_OracleDeclareBlock(t *testing.T) {
input := `DECLARE
v_count NUMBER;
BEGIN
SELECT COUNT(*) INTO v_count FROM t_memcard_reg;
UPDATE t_memcard_reg SET CARDLEVEL = v_count WHERE MEMCARDNO = '8032277312';
END;
SELECT 1 FROM dual;`
got := splitSQLStatements(input)
want := []string{
`DECLARE
v_count NUMBER;
BEGIN
SELECT COUNT(*) INTO v_count FROM t_memcard_reg;
UPDATE t_memcard_reg SET CARDLEVEL = v_count WHERE MEMCARDNO = '8032277312';
END;`,
"SELECT 1 FROM dual",
}
if !reflect.DeepEqual(got, want) {
t.Errorf("splitSQLStatements(%q) = %#v, want %#v", input, got, want)
}
}
func TestSplitSQLStatements_TransactionBeginStillSplits(t *testing.T) {
input := "BEGIN; UPDATE accounts SET balance = balance - 1 WHERE id = 1; COMMIT;"
got := splitSQLStatements(input)
want := []string{
"BEGIN",
"UPDATE accounts SET balance = balance - 1 WHERE id = 1",
"COMMIT",
}
if !reflect.DeepEqual(got, want) {
t.Errorf("splitSQLStatements(%q) = %#v, want %#v", input, got, want)
}
}