diff --git a/frontend/package.json.md5 b/frontend/package.json.md5 index 7396e24..bed8925 100755 --- a/frontend/package.json.md5 +++ b/frontend/package.json.md5 @@ -1 +1 @@ -d0464f9da25e9356e61652e638c99ffe \ No newline at end of file +0295a42fd931778d85157816d79d29e5 \ No newline at end of file diff --git a/frontend/src/components/QueryEditor.external-sql-save.test.tsx b/frontend/src/components/QueryEditor.external-sql-save.test.tsx index 2d4df89..a38f955 100644 --- a/frontend/src/components/QueryEditor.external-sql-save.test.tsx +++ b/frontend/src/components/QueryEditor.external-sql-save.test.tsx @@ -1676,6 +1676,42 @@ describe('QueryEditor external SQL save', () => { renderer?.unmount(); }); + it('keeps Oracle anonymous PL/SQL blocks intact when running from the editor', async () => { + storeState.connections[0].config.type = 'oracle'; + storeState.connections[0].config.database = 'ORCLPDB1'; + backendApp.DBQueryMulti.mockResolvedValueOnce({ + success: true, + data: [{ columns: ['affectedRows'], rows: [{ affectedRows: 1 }] }], + }); + const plsql = [ + 'BEGIN', + " INSERT INTO tmp_disable_trigger (table_name) VALUES ('t_memcard_reg');", + " UPDATE t_memcard_reg SET CARDLEVEL = 1 WHERE MEMCARDNO = '8032277312';", + " DELETE FROM tmp_disable_trigger WHERE table_name = 't_memcard_reg';", + 'END;', + ].join('\n'); + + let renderer!: ReactTestRenderer; + await act(async () => { + renderer = create(); + }); + + await act(async () => { + await findButton(renderer!, '运行').props.onClick(); + }); + await act(async () => { + await Promise.resolve(); + await Promise.resolve(); + }); + + expect(backendApp.DBQueryMulti).toHaveBeenCalledWith(expect.anything(), 'ORCLPDB1', plsql, 'query-1'); + expect(storeState.addSqlLog).toHaveBeenCalledWith(expect.objectContaining({ + sql: plsql, + status: 'success', + })); + renderer?.unmount(); + }); + it('keeps non-Oracle query results read-only when no safe locator exists', async () => { backendApp.DBQueryMulti.mockResolvedValueOnce({ success: true, @@ -1710,6 +1746,46 @@ describe('QueryEditor external SQL save', () => { expect(messageApi.warning).toHaveBeenCalledWith('查询结果保持只读:main.users 未检测到主键或可用唯一索引,无法安全提交修改。'); }); + it('keeps MySQL information_schema routine results read-only without a locator warning', async () => { + const sql = [ + 'SELECT ROUTINE_SCHEMA, ROUTINE_NAME, DEFINER, SECURITY_TYPE', + 'FROM information_schema.ROUTINES', + "WHERE ROUTINE_SCHEMA = 'mkefu_location_dev_local'", + " AND ROUTINE_NAME = 'init_orgi'", + ].join('\n'); + backendApp.DBQueryMulti.mockResolvedValueOnce({ + success: true, + data: [{ + columns: ['ROUTINE_SCHEMA', 'ROUTINE_NAME', 'DEFINER', 'SECURITY_TYPE'], + rows: [{ + ROUTINE_SCHEMA: 'mkefu_location_dev_local', + ROUTINE_NAME: 'init_orgi', + DEFINER: 'root@%', + SECURITY_TYPE: 'DEFINER', + }], + }], + }); + + let renderer: ReactTestRenderer; + await act(async () => { + renderer = create(); + }); + + await act(async () => { + await findButton(renderer!, '运行').props.onClick(); + }); + await act(async () => { + await Promise.resolve(); + await Promise.resolve(); + }); + + expect(dataGridState.latestProps?.tableName).toBe('ROUTINES'); + expect(dataGridState.latestProps?.readOnly).toBe(true); + expect(backendApp.DBGetColumns).not.toHaveBeenCalled(); + expect(backendApp.DBGetIndexes).not.toHaveBeenCalled(); + expect(messageApi.warning).not.toHaveBeenCalled(); + }); + it('runs the SQL statement at the cursor instead of the whole editor when nothing is selected', async () => { backendApp.DBQueryMulti.mockResolvedValueOnce({ success: true, diff --git a/frontend/src/components/QueryEditor.tsx b/frontend/src/components/QueryEditor.tsx index b9fa1f3..d58337f 100644 --- a/frontend/src/components/QueryEditor.tsx +++ b/frontend/src/components/QueryEditor.tsx @@ -262,6 +262,33 @@ const stripQueryIdentifierQuotes = (part: string): string => { return text; }; +const MYSQL_SYSTEM_METADATA_SCHEMAS = new Set(['information_schema', 'performance_schema', 'mysql', 'sys']); +const POSTGRES_SYSTEM_METADATA_SCHEMAS = new Set(['information_schema', 'pg_catalog']); +const SQLITE_SYSTEM_METADATA_TABLES = new Set(['sqlite_master', 'sqlite_schema', 'sqlite_temp_master', 'sqlite_temp_schema']); + +const isSystemMetadataQueryResult = (tableRef: QueryResultTableRef, dbType: string): boolean => { + const normalizedDbType = String(dbType || '').trim().toLowerCase(); + const metadataDbName = stripQueryIdentifierQuotes(tableRef.metadataDbName).toLowerCase(); + const metadataTableName = stripQueryIdentifierQuotes(tableRef.metadataTableName).toLowerCase(); + + if (['mysql', 'mariadb', 'oceanbase', 'diros', 'starrocks', 'sphinx', 'tidb'].includes(normalizedDbType)) { + return MYSQL_SYSTEM_METADATA_SCHEMAS.has(metadataDbName); + } + if (['postgres', 'kingbase', 'highgo', 'vastbase', 'opengauss'].includes(normalizedDbType)) { + return POSTGRES_SYSTEM_METADATA_SCHEMAS.has(metadataDbName); + } + if (normalizedDbType === 'sqlite' || normalizedDbType === 'duckdb') { + return SQLITE_SYSTEM_METADATA_TABLES.has(metadataTableName) || metadataDbName === 'information_schema'; + } + if (normalizedDbType === 'sqlserver') { + return metadataDbName === 'information_schema' || metadataDbName === 'sys'; + } + if (normalizedDbType === 'clickhouse') { + return metadataDbName === 'system' || metadataDbName === 'information_schema'; + } + return false; +}; + const splitTopLevelComma = (text: string): string[] => { const parts: string[] = []; let current = ''; @@ -658,6 +685,65 @@ const areSqlStatementListsEqual = (left: string[], right: string[]): boolean => && left.every((statement, index) => normalizeExecutedSqlKey(statement) === normalizeExecutedSqlKey(right[index])) ); +const isSqlIdentifierStart = (ch: string): boolean => /^[A-Za-z_]$/.test(ch); + +const isSqlIdentifierPart = (ch: string): boolean => /^[A-Za-z0-9_$#]$/.test(ch); + +const skipSqlWhitespaceAndComments = (text: string, position: number): number => { + let index = position; + while (index < text.length) { + const ch = text[index]; + const next = index + 1 < text.length ? text[index + 1] : ''; + if (ch === ' ' || ch === '\t' || ch === '\n' || ch === '\r' || ch === '\f') { + index += 1; + continue; + } + if (ch === '-' && next === '-') { + index += 2; + while (index < text.length && text[index] !== '\n') index += 1; + continue; + } + if (ch === '/' && next === '*') { + index += 2; + while (index + 1 < text.length && !(text[index] === '*' && text[index + 1] === '/')) { + index += 1; + } + if (index + 1 < text.length) index += 2; + continue; + } + break; + } + return index; +}; + +const nextSqlSignificantToken = (text: string, position: number): string => { + const index = skipSqlWhitespaceAndComments(text, position); + if (index >= text.length || !isSqlIdentifierStart(text[index])) return ''; + let end = index + 1; + while (end < text.length && isSqlIdentifierPart(text[end])) end += 1; + return text.slice(index, end).toLowerCase(); +}; + +const nextSqlSignificantChar = (text: string, position: number): string => { + const index = skipSqlWhitespaceAndComments(text, position); + return index >= text.length ? '' : text[index]; +}; + +const shouldEnterPlsqlBeginBlock = (text: string, tokenEnd: number): boolean => { + const nextChar = nextSqlSignificantChar(text, tokenEnd); + if (!nextChar || nextChar === ';') return false; + return !['transaction', 'work', 'isolation', 'read', 'write'].includes(nextSqlSignificantToken(text, tokenEnd)); +}; + +const shouldEnterPlsqlDeclareBlock = (text: string, tokenEnd: number): boolean => { + const nextToken = nextSqlSignificantToken(text, tokenEnd); + return Boolean(nextToken); +}; + +const isPlsqlControlEnd = (text: string, tokenEnd: number): boolean => ( + ['if', 'loop', 'case'].includes(nextSqlSignificantToken(text, tokenEnd)) +); + const normalizeEditorPosition = (position: any): { lineNumber: number; column: number } | null => { if (!position) return null; const lineNumber = Number(position.positionLineNumber ?? position.lineNumber ?? position.endLineNumber ?? position.startLineNumber ?? position.selectionStartLineNumber); @@ -1563,6 +1649,10 @@ const resolveQueryLocatorPlan = async ({ const tableRef = extractQueryResultTableRef(statement, dbType, currentDb); if (!tableRef) return plan; plan.tableRef = tableRef; + if (isSystemMetadataQueryResult(tableRef, dbType)) { + plan.editLocator = buildQueryReadOnlyLocator('系统元数据查询结果保持只读。'); + return plan; + } const selectInfo = parseSimpleSelectInfo(statement); if (!selectInfo) { @@ -3604,6 +3694,9 @@ const QueryEditor: React.FC<{ tab: TabData; isActive?: boolean }> = ({ tab, isAc let inLineComment = false; let inBlockComment = false; let dollarTag: string | null = null; // postgres/kingbase: $$...$$ or $tag$...$tag$ + let plsqlDepth = 0; + let plsqlDeclareBeginSkips = 0; + let justClosedPLSQLBlock = false; const push = () => { const s = cur.trim(); @@ -3705,7 +3798,45 @@ const QueryEditor: React.FC<{ tab: TabData; isActive?: boolean }> = ({ tab, isAc continue; } + if (!inSingle && !inDouble && !inBacktick && !dollarTag && isSqlIdentifierStart(ch)) { + let end = i + 1; + while (end < text.length && isSqlIdentifierPart(text[end])) { + end += 1; + } + const token = text.slice(i, end).toLowerCase(); + if (token === 'begin' && plsqlDeclareBeginSkips > 0) { + plsqlDeclareBeginSkips -= 1; + justClosedPLSQLBlock = false; + } else if (token === 'begin' && shouldEnterPlsqlBeginBlock(text, end)) { + plsqlDepth += 1; + justClosedPLSQLBlock = false; + } else if (token === 'declare' && shouldEnterPlsqlDeclareBlock(text, end)) { + plsqlDepth += 1; + plsqlDeclareBeginSkips += 1; + justClosedPLSQLBlock = false; + } else if (token === 'end' && plsqlDepth > 0 && !isPlsqlControlEnd(text, end)) { + plsqlDepth -= 1; + if (plsqlDeclareBeginSkips > plsqlDepth) { + plsqlDeclareBeginSkips = plsqlDepth; + } + justClosedPLSQLBlock = plsqlDepth === 0; + } + cur += text.slice(i, end); + i = end - 1; + continue; + } + if (!inSingle && !inDouble && !inBacktick && !dollarTag && (ch === ';' || ch === ';')) { + if (plsqlDepth > 0) { + cur += ch; + continue; + } + if (justClosedPLSQLBlock) { + cur += ch; + push(); + justClosedPLSQLBlock = false; + continue; + } push(); continue; } diff --git a/frontend/src/utils/sqlStatementSelection.test.ts b/frontend/src/utils/sqlStatementSelection.test.ts index 95a940e..a21f749 100644 --- a/frontend/src/utils/sqlStatementSelection.test.ts +++ b/frontend/src/utils/sqlStatementSelection.test.ts @@ -28,6 +28,74 @@ describe('sqlStatementSelection', () => { ]); }); + it('keeps Oracle anonymous PL/SQL blocks as one executable statement', () => { + const plsql = [ + 'BEGIN', + " INSERT INTO tmp_disable_trigger (table_name) VALUES ('t_memcard_reg');", + " UPDATE t_memcard_reg SET CARDLEVEL = 1 WHERE MEMCARDNO = '8032277312';", + " DELETE FROM tmp_disable_trigger WHERE table_name = 't_memcard_reg';", + 'END;', + 'SELECT 1 FROM dual;', + ].join('\n'); + + const ranges = findSqlStatementRanges(plsql).map((range) => range.text); + + expect(ranges).toEqual([ + [ + 'BEGIN', + " INSERT INTO tmp_disable_trigger (table_name) VALUES ('t_memcard_reg');", + " UPDATE t_memcard_reg SET CARDLEVEL = 1 WHERE MEMCARDNO = '8032277312';", + " DELETE FROM tmp_disable_trigger WHERE table_name = 't_memcard_reg';", + 'END;', + ].join('\n'), + 'SELECT 1 FROM dual', + ]); + expect(resolveExecutableSql(plsql, plsql.indexOf('UPDATE'))).toEqual({ + sql: ranges[0], + source: 'statement', + }); + }); + + it('keeps Oracle DECLARE blocks as one executable statement', () => { + const sql = [ + 'DECLARE', + ' v_count NUMBER;', + 'BEGIN', + ' SELECT COUNT(*) INTO v_count FROM t_memcard_reg;', + " UPDATE t_memcard_reg SET CARDLEVEL = v_count WHERE MEMCARDNO = '8032277312';", + 'END;', + 'SELECT 1 FROM dual;', + ].join('\n'); + + const ranges = findSqlStatementRanges(sql).map((range) => range.text); + + expect(ranges).toEqual([ + [ + 'DECLARE', + ' v_count NUMBER;', + 'BEGIN', + ' SELECT COUNT(*) INTO v_count FROM t_memcard_reg;', + " UPDATE t_memcard_reg SET CARDLEVEL = v_count WHERE MEMCARDNO = '8032277312';", + 'END;', + ].join('\n'), + 'SELECT 1 FROM dual', + ]); + expect(resolveExecutableSql(sql, sql.indexOf('UPDATE'))).toEqual({ + sql: ranges[0], + source: 'statement', + }); + }); + + it('still splits transaction BEGIN statements', () => { + const sql = 'BEGIN; UPDATE accounts SET balance = balance - 1 WHERE id = 1; COMMIT;'; + + expect(findSqlStatementRanges(sql).map((range) => range.text)).toEqual([ + 'BEGIN', + 'UPDATE accounts SET balance = balance - 1 WHERE id = 1', + 'COMMIT', + ]); + }); + it('selects the next statement when the cursor is on whitespace before it', () => { const sql = 'select 1;\n\n select 2;'; const range = resolveCurrentSqlStatementRange(sql, sql.indexOf(' select 2')); diff --git a/frontend/src/utils/sqlStatementSelection.ts b/frontend/src/utils/sqlStatementSelection.ts index 181dc51..8a3dfe8 100644 --- a/frontend/src/utils/sqlStatementSelection.ts +++ b/frontend/src/utils/sqlStatementSelection.ts @@ -12,7 +12,63 @@ export interface SqlExecutionSelection { } const isWhitespace = (ch: string): boolean => ( - ch === ' ' || ch === '\t' || ch === '\n' || ch === '\r' + ch === ' ' || ch === '\t' || ch === '\n' || ch === '\r' || ch === '\f' +); + +const isSqlIdentifierStart = (ch: string): boolean => /^[A-Za-z_]$/.test(ch); + +const isSqlIdentifierPart = (ch: string): boolean => /^[A-Za-z0-9_$#]$/.test(ch); + +const skipSqlWhitespaceAndComments = (text: string, position: number): number => { + let index = position; + while (index < text.length) { + const ch = text[index]; + const next = index + 1 < text.length ? text[index + 1] : ''; + if (isWhitespace(ch)) { + index += 1; + continue; + } + if (ch === '-' && next === '-') { + index += 2; + while (index < text.length && text[index] !== '\n') index += 1; + continue; + } + if (ch === '/' && next === '*') { + index += 2; + while (index + 1 < text.length && !(text[index] === '*' && text[index + 1] === '/')) { + index += 1; + } + if (index + 1 < text.length) index += 2; + continue; + } + break; + } + return index; +}; + +const nextSqlSignificantToken = (text: string, position: number): string => { + const index = skipSqlWhitespaceAndComments(text, position); + if (index >= text.length || !isSqlIdentifierStart(text[index])) return ''; + let end = index + 1; + while (end < text.length && isSqlIdentifierPart(text[end])) end += 1; + return text.slice(index, end).toLowerCase(); +}; + +const nextSqlSignificantChar = (text: string, position: number): string => { + const index = skipSqlWhitespaceAndComments(text, position); + return index >= text.length ? '' : text[index]; +}; + +const shouldEnterPlsqlBeginBlock = (text: string, tokenEnd: number): boolean => { + const nextChar = nextSqlSignificantChar(text, tokenEnd); + if (!nextChar || nextChar === ';') return false; + return !['transaction', 'work', 'isolation', 'read', 'write'].includes(nextSqlSignificantToken(text, tokenEnd)); +}; + +const shouldEnterPlsqlDeclareBlock = (text: string, tokenEnd: number): boolean => Boolean(nextSqlSignificantToken(text, tokenEnd)); + +const isPlsqlControlEnd = (text: string, tokenEnd: number): boolean => ( + ['if', 'loop', 'case'].includes(nextSqlSignificantToken(text, tokenEnd)) ); const trimStatementRange = (sql: string, start: number, end: number): SqlStatementRange | null => { @@ -49,6 +105,9 @@ export const findSqlStatementRanges = (sql: string): SqlStatementRange[] => { let inLineComment = false; let inBlockComment = false; let dollarTag: string | null = null; + let plsqlDepth = 0; + let plsqlDeclareBeginSkips = 0; + let justClosedPLSQLBlock = false; const push = (end: number) => { const range = trimStatementRange(text, statementStart, end); @@ -134,9 +193,41 @@ export const findSqlStatementRanges = (sql: string): SqlStatementRange[] => { continue; } + if (!inSingle && !inDouble && !inBacktick && !dollarTag && isSqlIdentifierStart(ch)) { + let tokenEnd = index + 1; + while (tokenEnd < text.length && isSqlIdentifierPart(text[tokenEnd])) { + tokenEnd++; + } + const token = text.slice(index, tokenEnd).toLowerCase(); + if (token === 'begin' && plsqlDeclareBeginSkips > 0) { + plsqlDeclareBeginSkips--; + justClosedPLSQLBlock = false; + } else if (token === 'begin' && shouldEnterPlsqlBeginBlock(text, tokenEnd)) { + plsqlDepth++; + justClosedPLSQLBlock = false; + } else if (token === 'declare' && shouldEnterPlsqlDeclareBlock(text, tokenEnd)) { + plsqlDepth++; + plsqlDeclareBeginSkips++; + justClosedPLSQLBlock = false; + } else if (token === 'end' && plsqlDepth > 0 && !isPlsqlControlEnd(text, tokenEnd)) { + plsqlDepth--; + if (plsqlDeclareBeginSkips > plsqlDepth) { + plsqlDeclareBeginSkips = plsqlDepth; + } + justClosedPLSQLBlock = plsqlDepth === 0; + } + index = tokenEnd - 1; + continue; + } + if (!inSingle && !inDouble && !inBacktick && (ch === ';' || ch === ';')) { - push(index); + if (plsqlDepth > 0) { + continue; + } + push(justClosedPLSQLBlock ? index + 1 : index); statementStart = index + 1; + justClosedPLSQLBlock = false; + continue; } } diff --git a/internal/app/methods_db.go b/internal/app/methods_db.go index 1cf99b5..c6625d3 100644 --- a/internal/app/methods_db.go +++ b/internal/app/methods_db.go @@ -770,13 +770,16 @@ func (a *App) DBQueryMulti(config connection.ConnectionConfig, dbName string, qu // 适用于 MySQL/MariaDB/Doris/PostgreSQL/SQLite/DuckDB 等支持多语句 Exec 的驱动 if !allReadOnly { allWrite := true + containsPLSQLBlock := false for _, stmt := range statements { if strings.TrimSpace(stmt) != "" && isReadOnlySQLQuery(runConfig.Type, stmt) { allWrite = false - break + } + if isPLSQLBlockStatement(stmt) { + containsPLSQLBlock = true } } - if allWrite { + if allWrite && !containsPLSQLBlock { if batcher, ok := dbInst.(db.BatchWriteExecer); ok { affected, batchErr := batcher.ExecBatchContext(ctx, query) if batchErr != nil && shouldRefreshCachedConnection(batchErr) { diff --git a/internal/app/methods_db_multi_test.go b/internal/app/methods_db_multi_test.go index 1f0af6d..ff30389 100644 --- a/internal/app/methods_db_multi_test.go +++ b/internal/app/methods_db_multi_test.go @@ -12,6 +12,7 @@ import ( type fakeBatchWriteDB struct { batchCalls int execCalls int + execQueries []string lastQuery string } @@ -33,6 +34,7 @@ func (f *fakeBatchWriteDB) Query(query string) ([]map[string]interface{}, []stri func (f *fakeBatchWriteDB) Exec(query string) (int64, error) { f.execCalls++ + f.execQueries = append(f.execQueries, query) return 1, nil } @@ -70,6 +72,7 @@ func (f *fakeBatchWriteDB) GetTriggers(dbName, tableName string) ([]connection.T func (f *fakeBatchWriteDB) ExecContext(ctx context.Context, query string) (int64, error) { f.execCalls++ + f.execQueries = append(f.execQueries, query) return 1, nil } @@ -79,6 +82,45 @@ func (f *fakeBatchWriteDB) ExecBatchContext(ctx context.Context, query string) ( return 500, nil } +func TestDBQueryMultiKeepsOracleAnonymousBlockAsSingleStatement(t *testing.T) { + originalNewDatabaseFunc := newDatabaseFunc + t.Cleanup(func() { + newDatabaseFunc = originalNewDatabaseFunc + }) + + fakeDB := &fakeBatchWriteDB{} + newDatabaseFunc = func(dbType string) (db.Database, error) { + return fakeDB, nil + } + + app := NewAppWithSecretStore(secretstore.NewUnavailableStore("test")) + config := connection.ConnectionConfig{ + Type: "oracle", + Host: "127.0.0.1", + Port: 1521, + User: "app", + } + query := `BEGIN + INSERT INTO tmp_disable_trigger (table_name) VALUES ('t_memcard_reg'); + UPDATE t_memcard_reg SET CARDLEVEL = 1 WHERE MEMCARDNO = '8032277312'; + DELETE FROM tmp_disable_trigger WHERE table_name = 't_memcard_reg'; +END;` + + result := app.DBQueryMulti(config, "ORCLPDB1", query, "oracle-plsql-test") + if !result.Success { + t.Fatalf("expected DBQueryMulti success, got failure: %s", result.Message) + } + if fakeDB.batchCalls != 0 { + t.Fatalf("expected PL/SQL block to skip batch path, got batchCalls=%d", fakeDB.batchCalls) + } + if fakeDB.execCalls != 1 || len(fakeDB.execQueries) != 1 { + t.Fatalf("expected one sequential exec call, got execCalls=%d queries=%#v", fakeDB.execCalls, fakeDB.execQueries) + } + if fakeDB.execQueries[0] != query { + t.Fatalf("expected PL/SQL block to stay intact, got %q", fakeDB.execQueries[0]) + } +} + var _ db.BatchWriteExecer = (*fakeBatchWriteDB)(nil) func TestDBQueryMultiUsesBatchWriteExecerForAllWriteStatements(t *testing.T) { diff --git a/internal/app/methods_file.go b/internal/app/methods_file.go index 19dbd22..0accc15 100644 --- a/internal/app/methods_file.go +++ b/internal/app/methods_file.go @@ -662,6 +662,9 @@ func isSQLFileBatchableWriteStatement(dbType string, stmt string) bool { if isReadOnlySQLQuery(dbType, stmt) { return false } + if isPLSQLBlockStatement(stmt) { + return false + } switch leadingSQLKeyword(stmt) { case "insert", "update", "delete", "replace", "merge", "upsert": return true diff --git a/internal/app/methods_file_sql_execution_test.go b/internal/app/methods_file_sql_execution_test.go index 928010f..0a6ca6a 100644 --- a/internal/app/methods_file_sql_execution_test.go +++ b/internal/app/methods_file_sql_execution_test.go @@ -322,3 +322,38 @@ func TestStreamSQLFileHandlesSplitTokenBoundaries(t *testing.T) { t.Fatalf("unexpected full-width semicolon statement: %q", statements[2]) } } + +func TestStreamSQLFileKeepsOracleAnonymousBlockTogether(t *testing.T) { + input := strings.Join([]string{ + "BEGIN", + " INSERT INTO tmp_disable_trigger (table_name) VALUES ('t_memcard_reg');", + " UPDATE t_memcard_reg SET CARDLEVEL = 1 WHERE MEMCARDNO = '8032277312';", + " DELETE FROM tmp_disable_trigger WHERE table_name = 't_memcard_reg';", + "END;", + "SELECT 1 FROM dual;", + }, "\n") + var statements []string + + count, err := streamSQLFile(&chunkedReader{data: []byte(input), step: 3}, func(index int, stmt string) error { + statements = append(statements, stmt) + return nil + }) + if err != nil { + t.Fatalf("streamSQLFile returned error: %v", err) + } + if count != 2 || len(statements) != 2 { + t.Fatalf("expected 2 statements, got count=%d statements=%#v", count, statements) + } + if statements[0] != strings.Join([]string{ + "BEGIN", + " INSERT INTO tmp_disable_trigger (table_name) VALUES ('t_memcard_reg');", + " UPDATE t_memcard_reg SET CARDLEVEL = 1 WHERE MEMCARDNO = '8032277312';", + " DELETE FROM tmp_disable_trigger WHERE table_name = 't_memcard_reg';", + "END;", + }, "\n") { + t.Fatalf("unexpected anonymous block statement: %q", statements[0]) + } + if statements[1] != "SELECT 1 FROM dual" { + t.Fatalf("unexpected second statement: %q", statements[1]) + } +} diff --git a/internal/app/sql_split.go b/internal/app/sql_split.go index f73cebd..bbd9829 100644 --- a/internal/app/sql_split.go +++ b/internal/app/sql_split.go @@ -3,8 +3,9 @@ package app import "strings" // splitSQLStatements 按分号拆分 SQL 文本为独立语句。 -// 正确处理单引号/双引号/反引号字符串、行注释(-- / #)、块注释(/* */)和 -// PostgreSQL/Kingbase 的 $$...$$ dollar-quoting,避免在这些上下文中错误拆分。 +// 正确处理单引号/双引号/反引号字符串、行注释(-- / #)、块注释(/* */)、 +// PostgreSQL/Kingbase 的 $$...$$ dollar-quoting,以及 Oracle PL/SQL 匿名块, +// 避免在这些上下文中错误拆分。 // 同时支持 SQL 标准的转义单引号(两个连续单引号 '' 表示字面量引号)。 func splitSQLStatements(sql string) []string { text := strings.ReplaceAll(sql, "\r\n", "\n") @@ -18,6 +19,9 @@ func splitSQLStatements(sql string) []string { inLineComment := false inBlockComment := false var dollarTag string // postgres/kingbase: $$...$$ or $tag$...$tag$ + plsqlDepth := 0 + plsqlDeclareBeginSkips := 0 + justClosedPLSQLBlock := false push := func() { s := strings.TrimSpace(cur.String()) @@ -108,6 +112,35 @@ func splitSQLStatements(sql string) []string { continue } + if isSQLIdentifierStart(ch) { + tokenStart := i + tokenEnd := i + 1 + for tokenEnd < len(text) && isSQLIdentifierPart(text[tokenEnd]) { + tokenEnd++ + } + token := strings.ToLower(text[tokenStart:tokenEnd]) + if token == "begin" && plsqlDeclareBeginSkips > 0 { + plsqlDeclareBeginSkips-- + justClosedPLSQLBlock = false + } else if token == "begin" && shouldEnterPLSQLBlock(text, tokenEnd) { + plsqlDepth++ + justClosedPLSQLBlock = false + } else if token == "declare" && shouldEnterPLSQLDeclareBlock(text, tokenEnd) { + plsqlDepth++ + plsqlDeclareBeginSkips++ + justClosedPLSQLBlock = false + } else if token == "end" && plsqlDepth > 0 && !isPLSQLControlEnd(text, tokenEnd) { + plsqlDepth-- + if plsqlDeclareBeginSkips > plsqlDepth { + plsqlDeclareBeginSkips = plsqlDepth + } + justClosedPLSQLBlock = plsqlDepth == 0 + } + cur.WriteString(text[tokenStart:tokenEnd]) + i = tokenEnd - 1 + continue + } + // 行注释开始 if ch == '-' && next == '-' { inLineComment = true @@ -140,11 +173,33 @@ func splitSQLStatements(sql string) []string { // 分号分隔(支持全角分号";") if ch == ';' { + if plsqlDepth > 0 { + cur.WriteByte(ch) + continue + } + if justClosedPLSQLBlock { + cur.WriteByte(ch) + push() + justClosedPLSQLBlock = false + continue + } push() continue } // 全角分号 UTF-8 序列: 0xEF 0xBC 0x9B if ch == 0xEF && i+2 < len(text) && text[i+1] == 0xBC && text[i+2] == 0x9B { + if plsqlDepth > 0 { + cur.WriteString(";") + i += 2 + continue + } + if justClosedPLSQLBlock { + cur.WriteString(";") + push() + justClosedPLSQLBlock = false + i += 2 + continue + } push() i += 2 continue @@ -157,6 +212,110 @@ func splitSQLStatements(sql string) []string { return statements } +func isSQLIdentifierStart(ch byte) bool { + return (ch >= 'a' && ch <= 'z') || (ch >= 'A' && ch <= 'Z') || ch == '_' +} + +func isSQLIdentifierPart(ch byte) bool { + return isSQLIdentifierStart(ch) || (ch >= '0' && ch <= '9') || ch == '$' || ch == '#' +} + +func skipSQLWhitespaceAndComments(text string, pos int) int { + i := pos + for i < len(text) { + switch text[i] { + case ' ', '\t', '\n', '\r', '\f': + i++ + continue + case '-': + if i+1 < len(text) && text[i+1] == '-' { + i += 2 + for i < len(text) && text[i] != '\n' { + i++ + } + continue + } + case '/': + if i+1 < len(text) && text[i+1] == '*' { + i += 2 + for i+1 < len(text) && !(text[i] == '*' && text[i+1] == '/') { + i++ + } + if i+1 < len(text) { + i += 2 + } + continue + } + } + break + } + return i +} + +func nextSQLSignificantToken(text string, pos int) string { + i := skipSQLWhitespaceAndComments(text, pos) + if i >= len(text) || !isSQLIdentifierStart(text[i]) { + return "" + } + end := i + 1 + for end < len(text) && isSQLIdentifierPart(text[end]) { + end++ + } + return strings.ToLower(text[i:end]) +} + +func nextSQLSignificantByte(text string, pos int) byte { + i := skipSQLWhitespaceAndComments(text, pos) + if i >= len(text) { + return 0 + } + return text[i] +} + +func shouldEnterPLSQLBlock(text string, tokenEnd int) bool { + switch nextSQLSignificantByte(text, tokenEnd) { + case 0, ';': + return false + } + switch nextSQLSignificantToken(text, tokenEnd) { + case "transaction", "work", "isolation", "read", "write": + return false + default: + return true + } +} + +func isPLSQLBlockStatement(stmt string) bool { + text := strings.TrimSpace(stmt) + if text == "" { + return false + } + if strings.HasSuffix(text, "/") { + text = strings.TrimSpace(strings.TrimSuffix(text, "/")) + } + token := nextSQLSignificantToken(text, 0) + if token == "declare" { + return shouldEnterPLSQLDeclareBlock(text, len("declare")) + } + if token != "begin" { + return false + } + return shouldEnterPLSQLBlock(text, len("begin")) +} + +func shouldEnterPLSQLDeclareBlock(text string, tokenEnd int) bool { + return nextSQLSignificantToken(text, tokenEnd) != "" +} + +func isPLSQLControlEnd(text string, tokenEnd int) bool { + switch nextSQLSignificantToken(text, tokenEnd) { + case "if", "loop", "case": + return true + default: + return false + } +} + // parseSQLDollarTag 解析 PostgreSQL/Kingbase 的 dollar-quoting 标签。 func parseSQLDollarTag(s string) string { if len(s) < 2 || s[0] != '$' { diff --git a/internal/app/sql_split_stream.go b/internal/app/sql_split_stream.go index 7368494..cfc0a6a 100644 --- a/internal/app/sql_split_stream.go +++ b/internal/app/sql_split_stream.go @@ -18,6 +18,9 @@ type sqlStreamSplitter struct { inLineComment bool inBlockComment bool dollarTag string + plsqlDepth int + declareSkips int + closedPLSQL bool } // Feed 将一个 chunk 喂入拆分器,返回在此 chunk 中完成的 SQL 语句列表。 @@ -118,6 +121,43 @@ func (s *sqlStreamSplitter) Feed(chunk []byte) []string { continue } + if isSQLIdentifierStart(ch) { + tokenStart := i + tokenEnd := i + 1 + for tokenEnd < len(text) && isSQLIdentifierPart(text[tokenEnd]) { + tokenEnd++ + } + token := strings.ToLower(text[tokenStart:tokenEnd]) + if shouldDeferPLSQLKeywordPrefixInStream(text, tokenStart, tokenEnd, token) { + s.pending = text[tokenStart:] + break + } + if shouldDeferPLSQLKeywordInStream(text, tokenStart, tokenEnd, token) { + s.pending = text[tokenStart:] + break + } + if token == "begin" && s.declareSkips > 0 { + s.declareSkips-- + s.closedPLSQL = false + } else if token == "begin" && shouldEnterPLSQLBlock(text, tokenEnd) { + s.plsqlDepth++ + s.closedPLSQL = false + } else if token == "declare" && shouldEnterPLSQLDeclareBlock(text, tokenEnd) { + s.plsqlDepth++ + s.declareSkips++ + s.closedPLSQL = false + } else if token == "end" && s.plsqlDepth > 0 && !isPLSQLControlEnd(text, tokenEnd) { + s.plsqlDepth-- + if s.declareSkips > s.plsqlDepth { + s.declareSkips = s.plsqlDepth + } + s.closedPLSQL = s.plsqlDepth == 0 + } + s.cur.WriteString(text[tokenStart:tokenEnd]) + i = tokenEnd - 1 + continue + } + // 行注释开始 if ch == '-' && i+1 >= len(text) { s.pending = text[i:] @@ -162,6 +202,20 @@ func (s *sqlStreamSplitter) Feed(chunk []byte) []string { // 分号分隔 if ch == ';' { + if s.plsqlDepth > 0 { + s.cur.WriteByte(ch) + continue + } + if s.closedPLSQL { + s.cur.WriteByte(ch) + stmt := strings.TrimSpace(s.cur.String()) + if stmt != "" { + statements = append(statements, stmt) + } + s.cur.Reset() + s.closedPLSQL = false + continue + } stmt := strings.TrimSpace(s.cur.String()) if stmt != "" { statements = append(statements, stmt) @@ -175,6 +229,22 @@ func (s *sqlStreamSplitter) Feed(chunk []byte) []string { break } if ch == 0xEF && i+2 < len(text) && text[i+1] == 0xBC && text[i+2] == 0x9B { + if s.plsqlDepth > 0 { + s.cur.WriteString(";") + i += 2 + continue + } + if s.closedPLSQL { + s.cur.WriteString(";") + stmt := strings.TrimSpace(s.cur.String()) + if stmt != "" { + statements = append(statements, stmt) + } + s.cur.Reset() + s.closedPLSQL = false + i += 2 + continue + } stmt := strings.TrimSpace(s.cur.String()) if stmt != "" { statements = append(statements, stmt) @@ -217,6 +287,44 @@ func isIncompleteSQLDollarTag(s string) bool { return true } +func shouldDeferPLSQLKeywordInStream(text string, tokenStart int, tokenEnd int, token string) bool { + switch token { + case "begin", "declare", "end": + default: + return false + } + if tokenEnd >= len(text) { + return true + } + next := skipSQLWhitespaceAndComments(text, tokenEnd) + if next >= len(text) { + return true + } + if isSQLIdentifierStart(text[next]) { + nextEnd := next + 1 + for nextEnd < len(text) && isSQLIdentifierPart(text[nextEnd]) { + nextEnd++ + } + return nextEnd >= len(text) + } + return false +} + +func shouldDeferPLSQLKeywordPrefixInStream(text string, tokenStart int, tokenEnd int, token string) bool { + if tokenEnd < len(text) { + return false + } + for _, keyword := range []string{"begin", "declare", "end"} { + if strings.HasPrefix(keyword, token) && token != keyword { + if tokenStart > 0 && isSQLIdentifierPart(text[tokenStart-1]) { + return false + } + return true + } + } + return false +} + // streamSQLFile 从 reader 中流式读取 SQL 并逐条回调。 // onStatement 返回 error 时停止读取并返回该 error。 // 返回总处理语句数和可能的错误。 diff --git a/internal/app/sql_split_test.go b/internal/app/sql_split_test.go index 7c9c3d3..b4e4586 100644 --- a/internal/app/sql_split_test.go +++ b/internal/app/sql_split_test.go @@ -111,3 +111,59 @@ func TestSplitSQLStatements_SQLEscapedQuoteMultiple(t *testing.T) { } } +func TestSplitSQLStatements_OracleAnonymousBlock(t *testing.T) { + input := `BEGIN + INSERT INTO tmp_disable_trigger (table_name) VALUES ('t_memcard_reg'); + UPDATE t_memcard_reg SET CARDLEVEL = 1 WHERE MEMCARDNO = '8032277312'; + DELETE FROM tmp_disable_trigger WHERE table_name = 't_memcard_reg'; +END; +SELECT 1 FROM dual;` + got := splitSQLStatements(input) + want := []string{ + `BEGIN + INSERT INTO tmp_disable_trigger (table_name) VALUES ('t_memcard_reg'); + UPDATE t_memcard_reg SET CARDLEVEL = 1 WHERE MEMCARDNO = '8032277312'; + DELETE FROM tmp_disable_trigger WHERE table_name = 't_memcard_reg'; +END;`, + "SELECT 1 FROM dual", + } + if !reflect.DeepEqual(got, want) { + t.Errorf("splitSQLStatements(%q) = %#v, want %#v", input, got, want) + } +} + +func TestSplitSQLStatements_OracleDeclareBlock(t *testing.T) { + input := `DECLARE + v_count NUMBER; +BEGIN + SELECT COUNT(*) INTO v_count FROM t_memcard_reg; + UPDATE t_memcard_reg SET CARDLEVEL = v_count WHERE MEMCARDNO = '8032277312'; +END; +SELECT 1 FROM dual;` + got := splitSQLStatements(input) + want := []string{ + `DECLARE + v_count NUMBER; +BEGIN + SELECT COUNT(*) INTO v_count FROM t_memcard_reg; + UPDATE t_memcard_reg SET CARDLEVEL = v_count WHERE MEMCARDNO = '8032277312'; +END;`, + "SELECT 1 FROM dual", + } + if !reflect.DeepEqual(got, want) { + t.Errorf("splitSQLStatements(%q) = %#v, want %#v", input, got, want) + } +} + +func TestSplitSQLStatements_TransactionBeginStillSplits(t *testing.T) { + input := "BEGIN; UPDATE accounts SET balance = balance - 1 WHERE id = 1; COMMIT;" + got := splitSQLStatements(input) + want := []string{ + "BEGIN", + "UPDATE accounts SET balance = balance - 1 WHERE id = 1", + "COMMIT", + } + if !reflect.DeepEqual(got, want) { + t.Errorf("splitSQLStatements(%q) = %#v, want %#v", input, got, want) + } +}