diff --git a/frontend/package.json.md5 b/frontend/package.json.md5
index 7396e24..bed8925 100755
--- a/frontend/package.json.md5
+++ b/frontend/package.json.md5
@@ -1 +1 @@
-d0464f9da25e9356e61652e638c99ffe
\ No newline at end of file
+0295a42fd931778d85157816d79d29e5
\ No newline at end of file
diff --git a/frontend/src/components/QueryEditor.external-sql-save.test.tsx b/frontend/src/components/QueryEditor.external-sql-save.test.tsx
index 2d4df89..a38f955 100644
--- a/frontend/src/components/QueryEditor.external-sql-save.test.tsx
+++ b/frontend/src/components/QueryEditor.external-sql-save.test.tsx
@@ -1676,6 +1676,42 @@ describe('QueryEditor external SQL save', () => {
renderer?.unmount();
});
+ it('keeps Oracle anonymous PL/SQL blocks intact when running from the editor', async () => {
+ storeState.connections[0].config.type = 'oracle';
+ storeState.connections[0].config.database = 'ORCLPDB1';
+ backendApp.DBQueryMulti.mockResolvedValueOnce({
+ success: true,
+ data: [{ columns: ['affectedRows'], rows: [{ affectedRows: 1 }] }],
+ });
+ const plsql = [
+ 'BEGIN',
+ " INSERT INTO tmp_disable_trigger (table_name) VALUES ('t_memcard_reg');",
+ " UPDATE t_memcard_reg SET CARDLEVEL = 1 WHERE MEMCARDNO = '8032277312';",
+ " DELETE FROM tmp_disable_trigger WHERE table_name = 't_memcard_reg';",
+ 'END;',
+ ].join('\n');
+
+ let renderer!: ReactTestRenderer;
+ await act(async () => {
+ renderer = create();
+ });
+
+ await act(async () => {
+ await findButton(renderer!, '运行').props.onClick();
+ });
+ await act(async () => {
+ await Promise.resolve();
+ await Promise.resolve();
+ });
+
+ expect(backendApp.DBQueryMulti).toHaveBeenCalledWith(expect.anything(), 'ORCLPDB1', plsql, 'query-1');
+ expect(storeState.addSqlLog).toHaveBeenCalledWith(expect.objectContaining({
+ sql: plsql,
+ status: 'success',
+ }));
+ renderer?.unmount();
+ });
+
it('keeps non-Oracle query results read-only when no safe locator exists', async () => {
backendApp.DBQueryMulti.mockResolvedValueOnce({
success: true,
@@ -1710,6 +1746,46 @@ describe('QueryEditor external SQL save', () => {
expect(messageApi.warning).toHaveBeenCalledWith('查询结果保持只读:main.users 未检测到主键或可用唯一索引,无法安全提交修改。');
});
+ it('keeps MySQL information_schema routine results read-only without a locator warning', async () => {
+ const sql = [
+ 'SELECT ROUTINE_SCHEMA, ROUTINE_NAME, DEFINER, SECURITY_TYPE',
+ 'FROM information_schema.ROUTINES',
+ "WHERE ROUTINE_SCHEMA = 'mkefu_location_dev_local'",
+ " AND ROUTINE_NAME = 'init_orgi'",
+ ].join('\n');
+ backendApp.DBQueryMulti.mockResolvedValueOnce({
+ success: true,
+ data: [{
+ columns: ['ROUTINE_SCHEMA', 'ROUTINE_NAME', 'DEFINER', 'SECURITY_TYPE'],
+ rows: [{
+ ROUTINE_SCHEMA: 'mkefu_location_dev_local',
+ ROUTINE_NAME: 'init_orgi',
+ DEFINER: 'root@%',
+ SECURITY_TYPE: 'DEFINER',
+ }],
+ }],
+ });
+
+ let renderer: ReactTestRenderer;
+ await act(async () => {
+ renderer = create();
+ });
+
+ await act(async () => {
+ await findButton(renderer!, '运行').props.onClick();
+ });
+ await act(async () => {
+ await Promise.resolve();
+ await Promise.resolve();
+ });
+
+ expect(dataGridState.latestProps?.tableName).toBe('ROUTINES');
+ expect(dataGridState.latestProps?.readOnly).toBe(true);
+ expect(backendApp.DBGetColumns).not.toHaveBeenCalled();
+ expect(backendApp.DBGetIndexes).not.toHaveBeenCalled();
+ expect(messageApi.warning).not.toHaveBeenCalled();
+ });
+
it('runs the SQL statement at the cursor instead of the whole editor when nothing is selected', async () => {
backendApp.DBQueryMulti.mockResolvedValueOnce({
success: true,
diff --git a/frontend/src/components/QueryEditor.tsx b/frontend/src/components/QueryEditor.tsx
index b9fa1f3..d58337f 100644
--- a/frontend/src/components/QueryEditor.tsx
+++ b/frontend/src/components/QueryEditor.tsx
@@ -262,6 +262,33 @@ const stripQueryIdentifierQuotes = (part: string): string => {
return text;
};
+const MYSQL_SYSTEM_METADATA_SCHEMAS = new Set(['information_schema', 'performance_schema', 'mysql', 'sys']);
+const POSTGRES_SYSTEM_METADATA_SCHEMAS = new Set(['information_schema', 'pg_catalog']);
+const SQLITE_SYSTEM_METADATA_TABLES = new Set(['sqlite_master', 'sqlite_schema', 'sqlite_temp_master', 'sqlite_temp_schema']);
+
+const isSystemMetadataQueryResult = (tableRef: QueryResultTableRef, dbType: string): boolean => {
+ const normalizedDbType = String(dbType || '').trim().toLowerCase();
+ const metadataDbName = stripQueryIdentifierQuotes(tableRef.metadataDbName).toLowerCase();
+ const metadataTableName = stripQueryIdentifierQuotes(tableRef.metadataTableName).toLowerCase();
+
+ if (['mysql', 'mariadb', 'oceanbase', 'diros', 'starrocks', 'sphinx', 'tidb'].includes(normalizedDbType)) {
+ return MYSQL_SYSTEM_METADATA_SCHEMAS.has(metadataDbName);
+ }
+ if (['postgres', 'kingbase', 'highgo', 'vastbase', 'opengauss'].includes(normalizedDbType)) {
+ return POSTGRES_SYSTEM_METADATA_SCHEMAS.has(metadataDbName);
+ }
+ if (normalizedDbType === 'sqlite' || normalizedDbType === 'duckdb') {
+ return SQLITE_SYSTEM_METADATA_TABLES.has(metadataTableName) || metadataDbName === 'information_schema';
+ }
+ if (normalizedDbType === 'sqlserver') {
+ return metadataDbName === 'information_schema' || metadataDbName === 'sys';
+ }
+ if (normalizedDbType === 'clickhouse') {
+ return metadataDbName === 'system' || metadataDbName === 'information_schema';
+ }
+ return false;
+};
+
const splitTopLevelComma = (text: string): string[] => {
const parts: string[] = [];
let current = '';
@@ -658,6 +685,65 @@ const areSqlStatementListsEqual = (left: string[], right: string[]): boolean =>
&& left.every((statement, index) => normalizeExecutedSqlKey(statement) === normalizeExecutedSqlKey(right[index]))
);
+const isSqlIdentifierStart = (ch: string): boolean => /^[A-Za-z_]$/.test(ch);
+
+const isSqlIdentifierPart = (ch: string): boolean => /^[A-Za-z0-9_$#]$/.test(ch);
+
+const skipSqlWhitespaceAndComments = (text: string, position: number): number => {
+ let index = position;
+ while (index < text.length) {
+ const ch = text[index];
+ const next = index + 1 < text.length ? text[index + 1] : '';
+ if (ch === ' ' || ch === '\t' || ch === '\n' || ch === '\r' || ch === '\f') {
+ index += 1;
+ continue;
+ }
+ if (ch === '-' && next === '-') {
+ index += 2;
+ while (index < text.length && text[index] !== '\n') index += 1;
+ continue;
+ }
+ if (ch === '/' && next === '*') {
+ index += 2;
+ while (index + 1 < text.length && !(text[index] === '*' && text[index + 1] === '/')) {
+ index += 1;
+ }
+ if (index + 1 < text.length) index += 2;
+ continue;
+ }
+ break;
+ }
+ return index;
+};
+
+const nextSqlSignificantToken = (text: string, position: number): string => {
+ const index = skipSqlWhitespaceAndComments(text, position);
+ if (index >= text.length || !isSqlIdentifierStart(text[index])) return '';
+ let end = index + 1;
+ while (end < text.length && isSqlIdentifierPart(text[end])) end += 1;
+ return text.slice(index, end).toLowerCase();
+};
+
+const nextSqlSignificantChar = (text: string, position: number): string => {
+ const index = skipSqlWhitespaceAndComments(text, position);
+ return index >= text.length ? '' : text[index];
+};
+
+const shouldEnterPlsqlBeginBlock = (text: string, tokenEnd: number): boolean => {
+ const nextChar = nextSqlSignificantChar(text, tokenEnd);
+ if (!nextChar || nextChar === ';') return false;
+ return !['transaction', 'work', 'isolation', 'read', 'write'].includes(nextSqlSignificantToken(text, tokenEnd));
+};
+
+const shouldEnterPlsqlDeclareBlock = (text: string, tokenEnd: number): boolean => {
+ const nextToken = nextSqlSignificantToken(text, tokenEnd);
+ return Boolean(nextToken);
+};
+
+const isPlsqlControlEnd = (text: string, tokenEnd: number): boolean => (
+ ['if', 'loop', 'case'].includes(nextSqlSignificantToken(text, tokenEnd))
+);
+
const normalizeEditorPosition = (position: any): { lineNumber: number; column: number } | null => {
if (!position) return null;
const lineNumber = Number(position.positionLineNumber ?? position.lineNumber ?? position.endLineNumber ?? position.startLineNumber ?? position.selectionStartLineNumber);
@@ -1563,6 +1649,10 @@ const resolveQueryLocatorPlan = async ({
const tableRef = extractQueryResultTableRef(statement, dbType, currentDb);
if (!tableRef) return plan;
plan.tableRef = tableRef;
+ if (isSystemMetadataQueryResult(tableRef, dbType)) {
+ plan.editLocator = buildQueryReadOnlyLocator('系统元数据查询结果保持只读。');
+ return plan;
+ }
const selectInfo = parseSimpleSelectInfo(statement);
if (!selectInfo) {
@@ -3604,6 +3694,9 @@ const QueryEditor: React.FC<{ tab: TabData; isActive?: boolean }> = ({ tab, isAc
let inLineComment = false;
let inBlockComment = false;
let dollarTag: string | null = null; // postgres/kingbase: $$...$$ or $tag$...$tag$
+ let plsqlDepth = 0;
+ let plsqlDeclareBeginSkips = 0;
+ let justClosedPLSQLBlock = false;
const push = () => {
const s = cur.trim();
@@ -3705,7 +3798,45 @@ const QueryEditor: React.FC<{ tab: TabData; isActive?: boolean }> = ({ tab, isAc
continue;
}
+ if (!inSingle && !inDouble && !inBacktick && !dollarTag && isSqlIdentifierStart(ch)) {
+ let end = i + 1;
+ while (end < text.length && isSqlIdentifierPart(text[end])) {
+ end += 1;
+ }
+ const token = text.slice(i, end).toLowerCase();
+ if (token === 'begin' && plsqlDeclareBeginSkips > 0) {
+ plsqlDeclareBeginSkips -= 1;
+ justClosedPLSQLBlock = false;
+ } else if (token === 'begin' && shouldEnterPlsqlBeginBlock(text, end)) {
+ plsqlDepth += 1;
+ justClosedPLSQLBlock = false;
+ } else if (token === 'declare' && shouldEnterPlsqlDeclareBlock(text, end)) {
+ plsqlDepth += 1;
+ plsqlDeclareBeginSkips += 1;
+ justClosedPLSQLBlock = false;
+ } else if (token === 'end' && plsqlDepth > 0 && !isPlsqlControlEnd(text, end)) {
+ plsqlDepth -= 1;
+ if (plsqlDeclareBeginSkips > plsqlDepth) {
+ plsqlDeclareBeginSkips = plsqlDepth;
+ }
+ justClosedPLSQLBlock = plsqlDepth === 0;
+ }
+ cur += text.slice(i, end);
+ i = end - 1;
+ continue;
+ }
+
if (!inSingle && !inDouble && !inBacktick && !dollarTag && (ch === ';' || ch === ';')) {
+ if (plsqlDepth > 0) {
+ cur += ch;
+ continue;
+ }
+ if (justClosedPLSQLBlock) {
+ cur += ch;
+ push();
+ justClosedPLSQLBlock = false;
+ continue;
+ }
push();
continue;
}
diff --git a/frontend/src/utils/sqlStatementSelection.test.ts b/frontend/src/utils/sqlStatementSelection.test.ts
index 95a940e..a21f749 100644
--- a/frontend/src/utils/sqlStatementSelection.test.ts
+++ b/frontend/src/utils/sqlStatementSelection.test.ts
@@ -28,6 +28,74 @@ describe('sqlStatementSelection', () => {
]);
});
+ it('keeps Oracle anonymous PL/SQL blocks as one executable statement', () => {
+ const plsql = [
+ 'BEGIN',
+ " INSERT INTO tmp_disable_trigger (table_name) VALUES ('t_memcard_reg');",
+ " UPDATE t_memcard_reg SET CARDLEVEL = 1 WHERE MEMCARDNO = '8032277312';",
+ " DELETE FROM tmp_disable_trigger WHERE table_name = 't_memcard_reg';",
+ 'END;',
+ 'SELECT 1 FROM dual;',
+ ].join('\n');
+
+ const ranges = findSqlStatementRanges(plsql).map((range) => range.text);
+
+ expect(ranges).toEqual([
+ [
+ 'BEGIN',
+ " INSERT INTO tmp_disable_trigger (table_name) VALUES ('t_memcard_reg');",
+ " UPDATE t_memcard_reg SET CARDLEVEL = 1 WHERE MEMCARDNO = '8032277312';",
+ " DELETE FROM tmp_disable_trigger WHERE table_name = 't_memcard_reg';",
+ 'END;',
+ ].join('\n'),
+ 'SELECT 1 FROM dual',
+ ]);
+ expect(resolveExecutableSql(plsql, plsql.indexOf('UPDATE'))).toEqual({
+ sql: ranges[0],
+ source: 'statement',
+ });
+ });
+
+ it('keeps Oracle DECLARE blocks as one executable statement', () => {
+ const sql = [
+ 'DECLARE',
+ ' v_count NUMBER;',
+ 'BEGIN',
+ ' SELECT COUNT(*) INTO v_count FROM t_memcard_reg;',
+ " UPDATE t_memcard_reg SET CARDLEVEL = v_count WHERE MEMCARDNO = '8032277312';",
+ 'END;',
+ 'SELECT 1 FROM dual;',
+ ].join('\n');
+
+ const ranges = findSqlStatementRanges(sql).map((range) => range.text);
+
+ expect(ranges).toEqual([
+ [
+ 'DECLARE',
+ ' v_count NUMBER;',
+ 'BEGIN',
+ ' SELECT COUNT(*) INTO v_count FROM t_memcard_reg;',
+ " UPDATE t_memcard_reg SET CARDLEVEL = v_count WHERE MEMCARDNO = '8032277312';",
+ 'END;',
+ ].join('\n'),
+ 'SELECT 1 FROM dual',
+ ]);
+ expect(resolveExecutableSql(sql, sql.indexOf('UPDATE'))).toEqual({
+ sql: ranges[0],
+ source: 'statement',
+ });
+ });
+
+ it('still splits transaction BEGIN statements', () => {
+ const sql = 'BEGIN; UPDATE accounts SET balance = balance - 1 WHERE id = 1; COMMIT;';
+
+ expect(findSqlStatementRanges(sql).map((range) => range.text)).toEqual([
+ 'BEGIN',
+ 'UPDATE accounts SET balance = balance - 1 WHERE id = 1',
+ 'COMMIT',
+ ]);
+ });
+
it('selects the next statement when the cursor is on whitespace before it', () => {
const sql = 'select 1;\n\n select 2;';
const range = resolveCurrentSqlStatementRange(sql, sql.indexOf(' select 2'));
diff --git a/frontend/src/utils/sqlStatementSelection.ts b/frontend/src/utils/sqlStatementSelection.ts
index 181dc51..8a3dfe8 100644
--- a/frontend/src/utils/sqlStatementSelection.ts
+++ b/frontend/src/utils/sqlStatementSelection.ts
@@ -12,7 +12,63 @@ export interface SqlExecutionSelection {
}
const isWhitespace = (ch: string): boolean => (
- ch === ' ' || ch === '\t' || ch === '\n' || ch === '\r'
+ ch === ' ' || ch === '\t' || ch === '\n' || ch === '\r' || ch === '\f'
+);
+
+const isSqlIdentifierStart = (ch: string): boolean => /^[A-Za-z_]$/.test(ch);
+
+const isSqlIdentifierPart = (ch: string): boolean => /^[A-Za-z0-9_$#]$/.test(ch);
+
+const skipSqlWhitespaceAndComments = (text: string, position: number): number => {
+ let index = position;
+ while (index < text.length) {
+ const ch = text[index];
+ const next = index + 1 < text.length ? text[index + 1] : '';
+ if (isWhitespace(ch)) {
+ index += 1;
+ continue;
+ }
+ if (ch === '-' && next === '-') {
+ index += 2;
+ while (index < text.length && text[index] !== '\n') index += 1;
+ continue;
+ }
+ if (ch === '/' && next === '*') {
+ index += 2;
+ while (index + 1 < text.length && !(text[index] === '*' && text[index + 1] === '/')) {
+ index += 1;
+ }
+ if (index + 1 < text.length) index += 2;
+ continue;
+ }
+ break;
+ }
+ return index;
+};
+
+const nextSqlSignificantToken = (text: string, position: number): string => {
+ const index = skipSqlWhitespaceAndComments(text, position);
+ if (index >= text.length || !isSqlIdentifierStart(text[index])) return '';
+ let end = index + 1;
+ while (end < text.length && isSqlIdentifierPart(text[end])) end += 1;
+ return text.slice(index, end).toLowerCase();
+};
+
+const nextSqlSignificantChar = (text: string, position: number): string => {
+ const index = skipSqlWhitespaceAndComments(text, position);
+ return index >= text.length ? '' : text[index];
+};
+
+const shouldEnterPlsqlBeginBlock = (text: string, tokenEnd: number): boolean => {
+ const nextChar = nextSqlSignificantChar(text, tokenEnd);
+ if (!nextChar || nextChar === ';') return false;
+ return !['transaction', 'work', 'isolation', 'read', 'write'].includes(nextSqlSignificantToken(text, tokenEnd));
+};
+
+const shouldEnterPlsqlDeclareBlock = (text: string, tokenEnd: number): boolean => Boolean(nextSqlSignificantToken(text, tokenEnd));
+
+const isPlsqlControlEnd = (text: string, tokenEnd: number): boolean => (
+ ['if', 'loop', 'case'].includes(nextSqlSignificantToken(text, tokenEnd))
);
const trimStatementRange = (sql: string, start: number, end: number): SqlStatementRange | null => {
@@ -49,6 +105,9 @@ export const findSqlStatementRanges = (sql: string): SqlStatementRange[] => {
let inLineComment = false;
let inBlockComment = false;
let dollarTag: string | null = null;
+ let plsqlDepth = 0;
+ let plsqlDeclareBeginSkips = 0;
+ let justClosedPLSQLBlock = false;
const push = (end: number) => {
const range = trimStatementRange(text, statementStart, end);
@@ -134,9 +193,41 @@ export const findSqlStatementRanges = (sql: string): SqlStatementRange[] => {
continue;
}
+ if (!inSingle && !inDouble && !inBacktick && !dollarTag && isSqlIdentifierStart(ch)) {
+ let tokenEnd = index + 1;
+ while (tokenEnd < text.length && isSqlIdentifierPart(text[tokenEnd])) {
+ tokenEnd++;
+ }
+ const token = text.slice(index, tokenEnd).toLowerCase();
+ if (token === 'begin' && plsqlDeclareBeginSkips > 0) {
+ plsqlDeclareBeginSkips--;
+ justClosedPLSQLBlock = false;
+ } else if (token === 'begin' && shouldEnterPlsqlBeginBlock(text, tokenEnd)) {
+ plsqlDepth++;
+ justClosedPLSQLBlock = false;
+ } else if (token === 'declare' && shouldEnterPlsqlDeclareBlock(text, tokenEnd)) {
+ plsqlDepth++;
+ plsqlDeclareBeginSkips++;
+ justClosedPLSQLBlock = false;
+ } else if (token === 'end' && plsqlDepth > 0 && !isPlsqlControlEnd(text, tokenEnd)) {
+ plsqlDepth--;
+ if (plsqlDeclareBeginSkips > plsqlDepth) {
+ plsqlDeclareBeginSkips = plsqlDepth;
+ }
+ justClosedPLSQLBlock = plsqlDepth === 0;
+ }
+ index = tokenEnd - 1;
+ continue;
+ }
+
if (!inSingle && !inDouble && !inBacktick && (ch === ';' || ch === ';')) {
- push(index);
+ if (plsqlDepth > 0) {
+ continue;
+ }
+ push(justClosedPLSQLBlock ? index + 1 : index);
statementStart = index + 1;
+ justClosedPLSQLBlock = false;
+ continue;
}
}
diff --git a/internal/app/methods_db.go b/internal/app/methods_db.go
index 1cf99b5..c6625d3 100644
--- a/internal/app/methods_db.go
+++ b/internal/app/methods_db.go
@@ -770,13 +770,16 @@ func (a *App) DBQueryMulti(config connection.ConnectionConfig, dbName string, qu
// 适用于 MySQL/MariaDB/Doris/PostgreSQL/SQLite/DuckDB 等支持多语句 Exec 的驱动
if !allReadOnly {
allWrite := true
+ containsPLSQLBlock := false
for _, stmt := range statements {
if strings.TrimSpace(stmt) != "" && isReadOnlySQLQuery(runConfig.Type, stmt) {
allWrite = false
- break
+ }
+ if isPLSQLBlockStatement(stmt) {
+ containsPLSQLBlock = true
}
}
- if allWrite {
+ if allWrite && !containsPLSQLBlock {
if batcher, ok := dbInst.(db.BatchWriteExecer); ok {
affected, batchErr := batcher.ExecBatchContext(ctx, query)
if batchErr != nil && shouldRefreshCachedConnection(batchErr) {
diff --git a/internal/app/methods_db_multi_test.go b/internal/app/methods_db_multi_test.go
index 1f0af6d..ff30389 100644
--- a/internal/app/methods_db_multi_test.go
+++ b/internal/app/methods_db_multi_test.go
@@ -12,6 +12,7 @@ import (
type fakeBatchWriteDB struct {
batchCalls int
execCalls int
+ execQueries []string
lastQuery string
}
@@ -33,6 +34,7 @@ func (f *fakeBatchWriteDB) Query(query string) ([]map[string]interface{}, []stri
func (f *fakeBatchWriteDB) Exec(query string) (int64, error) {
f.execCalls++
+ f.execQueries = append(f.execQueries, query)
return 1, nil
}
@@ -70,6 +72,7 @@ func (f *fakeBatchWriteDB) GetTriggers(dbName, tableName string) ([]connection.T
func (f *fakeBatchWriteDB) ExecContext(ctx context.Context, query string) (int64, error) {
f.execCalls++
+ f.execQueries = append(f.execQueries, query)
return 1, nil
}
@@ -79,6 +82,45 @@ func (f *fakeBatchWriteDB) ExecBatchContext(ctx context.Context, query string) (
return 500, nil
}
+func TestDBQueryMultiKeepsOracleAnonymousBlockAsSingleStatement(t *testing.T) {
+ originalNewDatabaseFunc := newDatabaseFunc
+ t.Cleanup(func() {
+ newDatabaseFunc = originalNewDatabaseFunc
+ })
+
+ fakeDB := &fakeBatchWriteDB{}
+ newDatabaseFunc = func(dbType string) (db.Database, error) {
+ return fakeDB, nil
+ }
+
+ app := NewAppWithSecretStore(secretstore.NewUnavailableStore("test"))
+ config := connection.ConnectionConfig{
+ Type: "oracle",
+ Host: "127.0.0.1",
+ Port: 1521,
+ User: "app",
+ }
+ query := `BEGIN
+ INSERT INTO tmp_disable_trigger (table_name) VALUES ('t_memcard_reg');
+ UPDATE t_memcard_reg SET CARDLEVEL = 1 WHERE MEMCARDNO = '8032277312';
+ DELETE FROM tmp_disable_trigger WHERE table_name = 't_memcard_reg';
+END;`
+
+ result := app.DBQueryMulti(config, "ORCLPDB1", query, "oracle-plsql-test")
+ if !result.Success {
+ t.Fatalf("expected DBQueryMulti success, got failure: %s", result.Message)
+ }
+ if fakeDB.batchCalls != 0 {
+ t.Fatalf("expected PL/SQL block to skip batch path, got batchCalls=%d", fakeDB.batchCalls)
+ }
+ if fakeDB.execCalls != 1 || len(fakeDB.execQueries) != 1 {
+ t.Fatalf("expected one sequential exec call, got execCalls=%d queries=%#v", fakeDB.execCalls, fakeDB.execQueries)
+ }
+ if fakeDB.execQueries[0] != query {
+ t.Fatalf("expected PL/SQL block to stay intact, got %q", fakeDB.execQueries[0])
+ }
+}
+
var _ db.BatchWriteExecer = (*fakeBatchWriteDB)(nil)
func TestDBQueryMultiUsesBatchWriteExecerForAllWriteStatements(t *testing.T) {
diff --git a/internal/app/methods_file.go b/internal/app/methods_file.go
index 19dbd22..0accc15 100644
--- a/internal/app/methods_file.go
+++ b/internal/app/methods_file.go
@@ -662,6 +662,9 @@ func isSQLFileBatchableWriteStatement(dbType string, stmt string) bool {
if isReadOnlySQLQuery(dbType, stmt) {
return false
}
+ if isPLSQLBlockStatement(stmt) {
+ return false
+ }
switch leadingSQLKeyword(stmt) {
case "insert", "update", "delete", "replace", "merge", "upsert":
return true
diff --git a/internal/app/methods_file_sql_execution_test.go b/internal/app/methods_file_sql_execution_test.go
index 928010f..0a6ca6a 100644
--- a/internal/app/methods_file_sql_execution_test.go
+++ b/internal/app/methods_file_sql_execution_test.go
@@ -322,3 +322,38 @@ func TestStreamSQLFileHandlesSplitTokenBoundaries(t *testing.T) {
t.Fatalf("unexpected full-width semicolon statement: %q", statements[2])
}
}
+
+func TestStreamSQLFileKeepsOracleAnonymousBlockTogether(t *testing.T) {
+ input := strings.Join([]string{
+ "BEGIN",
+ " INSERT INTO tmp_disable_trigger (table_name) VALUES ('t_memcard_reg');",
+ " UPDATE t_memcard_reg SET CARDLEVEL = 1 WHERE MEMCARDNO = '8032277312';",
+ " DELETE FROM tmp_disable_trigger WHERE table_name = 't_memcard_reg';",
+ "END;",
+ "SELECT 1 FROM dual;",
+ }, "\n")
+ var statements []string
+
+ count, err := streamSQLFile(&chunkedReader{data: []byte(input), step: 3}, func(index int, stmt string) error {
+ statements = append(statements, stmt)
+ return nil
+ })
+ if err != nil {
+ t.Fatalf("streamSQLFile returned error: %v", err)
+ }
+ if count != 2 || len(statements) != 2 {
+ t.Fatalf("expected 2 statements, got count=%d statements=%#v", count, statements)
+ }
+ if statements[0] != strings.Join([]string{
+ "BEGIN",
+ " INSERT INTO tmp_disable_trigger (table_name) VALUES ('t_memcard_reg');",
+ " UPDATE t_memcard_reg SET CARDLEVEL = 1 WHERE MEMCARDNO = '8032277312';",
+ " DELETE FROM tmp_disable_trigger WHERE table_name = 't_memcard_reg';",
+ "END;",
+ }, "\n") {
+ t.Fatalf("unexpected anonymous block statement: %q", statements[0])
+ }
+ if statements[1] != "SELECT 1 FROM dual" {
+ t.Fatalf("unexpected second statement: %q", statements[1])
+ }
+}
diff --git a/internal/app/sql_split.go b/internal/app/sql_split.go
index f73cebd..bbd9829 100644
--- a/internal/app/sql_split.go
+++ b/internal/app/sql_split.go
@@ -3,8 +3,9 @@ package app
import "strings"
// splitSQLStatements 按分号拆分 SQL 文本为独立语句。
-// 正确处理单引号/双引号/反引号字符串、行注释(-- / #)、块注释(/* */)和
-// PostgreSQL/Kingbase 的 $$...$$ dollar-quoting,避免在这些上下文中错误拆分。
+// 正确处理单引号/双引号/反引号字符串、行注释(-- / #)、块注释(/* */)、
+// PostgreSQL/Kingbase 的 $$...$$ dollar-quoting,以及 Oracle PL/SQL 匿名块,
+// 避免在这些上下文中错误拆分。
// 同时支持 SQL 标准的转义单引号(两个连续单引号 '' 表示字面量引号)。
func splitSQLStatements(sql string) []string {
text := strings.ReplaceAll(sql, "\r\n", "\n")
@@ -18,6 +19,9 @@ func splitSQLStatements(sql string) []string {
inLineComment := false
inBlockComment := false
var dollarTag string // postgres/kingbase: $$...$$ or $tag$...$tag$
+ plsqlDepth := 0
+ plsqlDeclareBeginSkips := 0
+ justClosedPLSQLBlock := false
push := func() {
s := strings.TrimSpace(cur.String())
@@ -108,6 +112,35 @@ func splitSQLStatements(sql string) []string {
continue
}
+ if isSQLIdentifierStart(ch) {
+ tokenStart := i
+ tokenEnd := i + 1
+ for tokenEnd < len(text) && isSQLIdentifierPart(text[tokenEnd]) {
+ tokenEnd++
+ }
+ token := strings.ToLower(text[tokenStart:tokenEnd])
+ if token == "begin" && plsqlDeclareBeginSkips > 0 {
+ plsqlDeclareBeginSkips--
+ justClosedPLSQLBlock = false
+ } else if token == "begin" && shouldEnterPLSQLBlock(text, tokenEnd) {
+ plsqlDepth++
+ justClosedPLSQLBlock = false
+ } else if token == "declare" && shouldEnterPLSQLDeclareBlock(text, tokenEnd) {
+ plsqlDepth++
+ plsqlDeclareBeginSkips++
+ justClosedPLSQLBlock = false
+ } else if token == "end" && plsqlDepth > 0 && !isPLSQLControlEnd(text, tokenEnd) {
+ plsqlDepth--
+ if plsqlDeclareBeginSkips > plsqlDepth {
+ plsqlDeclareBeginSkips = plsqlDepth
+ }
+ justClosedPLSQLBlock = plsqlDepth == 0
+ }
+ cur.WriteString(text[tokenStart:tokenEnd])
+ i = tokenEnd - 1
+ continue
+ }
+
// 行注释开始
if ch == '-' && next == '-' {
inLineComment = true
@@ -140,11 +173,33 @@ func splitSQLStatements(sql string) []string {
// 分号分隔(支持全角分号";")
if ch == ';' {
+ if plsqlDepth > 0 {
+ cur.WriteByte(ch)
+ continue
+ }
+ if justClosedPLSQLBlock {
+ cur.WriteByte(ch)
+ push()
+ justClosedPLSQLBlock = false
+ continue
+ }
push()
continue
}
// 全角分号 UTF-8 序列: 0xEF 0xBC 0x9B
if ch == 0xEF && i+2 < len(text) && text[i+1] == 0xBC && text[i+2] == 0x9B {
+ if plsqlDepth > 0 {
+ cur.WriteString(";")
+ i += 2
+ continue
+ }
+ if justClosedPLSQLBlock {
+ cur.WriteString(";")
+ push()
+ justClosedPLSQLBlock = false
+ i += 2
+ continue
+ }
push()
i += 2
continue
@@ -157,6 +212,110 @@ func splitSQLStatements(sql string) []string {
return statements
}
+func isSQLIdentifierStart(ch byte) bool {
+ return (ch >= 'a' && ch <= 'z') || (ch >= 'A' && ch <= 'Z') || ch == '_'
+}
+
+func isSQLIdentifierPart(ch byte) bool {
+ return isSQLIdentifierStart(ch) || (ch >= '0' && ch <= '9') || ch == '$' || ch == '#'
+}
+
+func skipSQLWhitespaceAndComments(text string, pos int) int {
+ i := pos
+ for i < len(text) {
+ switch text[i] {
+ case ' ', '\t', '\n', '\r', '\f':
+ i++
+ continue
+ case '-':
+ if i+1 < len(text) && text[i+1] == '-' {
+ i += 2
+ for i < len(text) && text[i] != '\n' {
+ i++
+ }
+ continue
+ }
+ case '/':
+ if i+1 < len(text) && text[i+1] == '*' {
+ i += 2
+ for i+1 < len(text) && !(text[i] == '*' && text[i+1] == '/') {
+ i++
+ }
+ if i+1 < len(text) {
+ i += 2
+ }
+ continue
+ }
+ }
+ break
+ }
+ return i
+}
+
+func nextSQLSignificantToken(text string, pos int) string {
+ i := skipSQLWhitespaceAndComments(text, pos)
+ if i >= len(text) || !isSQLIdentifierStart(text[i]) {
+ return ""
+ }
+ end := i + 1
+ for end < len(text) && isSQLIdentifierPart(text[end]) {
+ end++
+ }
+ return strings.ToLower(text[i:end])
+}
+
+func nextSQLSignificantByte(text string, pos int) byte {
+ i := skipSQLWhitespaceAndComments(text, pos)
+ if i >= len(text) {
+ return 0
+ }
+ return text[i]
+}
+
+func shouldEnterPLSQLBlock(text string, tokenEnd int) bool {
+ switch nextSQLSignificantByte(text, tokenEnd) {
+ case 0, ';':
+ return false
+ }
+ switch nextSQLSignificantToken(text, tokenEnd) {
+ case "transaction", "work", "isolation", "read", "write":
+ return false
+ default:
+ return true
+ }
+}
+
+func isPLSQLBlockStatement(stmt string) bool {
+ text := strings.TrimSpace(stmt)
+ if text == "" {
+ return false
+ }
+ if strings.HasSuffix(text, "/") {
+ text = strings.TrimSpace(strings.TrimSuffix(text, "/"))
+ }
+ token := nextSQLSignificantToken(text, 0)
+ if token == "declare" {
+ return shouldEnterPLSQLDeclareBlock(text, len("declare"))
+ }
+ if token != "begin" {
+ return false
+ }
+ return shouldEnterPLSQLBlock(text, len("begin"))
+}
+
+func shouldEnterPLSQLDeclareBlock(text string, tokenEnd int) bool {
+ return nextSQLSignificantToken(text, tokenEnd) != ""
+}
+
+func isPLSQLControlEnd(text string, tokenEnd int) bool {
+ switch nextSQLSignificantToken(text, tokenEnd) {
+ case "if", "loop", "case":
+ return true
+ default:
+ return false
+ }
+}
+
// parseSQLDollarTag 解析 PostgreSQL/Kingbase 的 dollar-quoting 标签。
func parseSQLDollarTag(s string) string {
if len(s) < 2 || s[0] != '$' {
diff --git a/internal/app/sql_split_stream.go b/internal/app/sql_split_stream.go
index 7368494..cfc0a6a 100644
--- a/internal/app/sql_split_stream.go
+++ b/internal/app/sql_split_stream.go
@@ -18,6 +18,9 @@ type sqlStreamSplitter struct {
inLineComment bool
inBlockComment bool
dollarTag string
+ plsqlDepth int
+ declareSkips int
+ closedPLSQL bool
}
// Feed 将一个 chunk 喂入拆分器,返回在此 chunk 中完成的 SQL 语句列表。
@@ -118,6 +121,43 @@ func (s *sqlStreamSplitter) Feed(chunk []byte) []string {
continue
}
+ if isSQLIdentifierStart(ch) {
+ tokenStart := i
+ tokenEnd := i + 1
+ for tokenEnd < len(text) && isSQLIdentifierPart(text[tokenEnd]) {
+ tokenEnd++
+ }
+ token := strings.ToLower(text[tokenStart:tokenEnd])
+ if shouldDeferPLSQLKeywordPrefixInStream(text, tokenStart, tokenEnd, token) {
+ s.pending = text[tokenStart:]
+ break
+ }
+ if shouldDeferPLSQLKeywordInStream(text, tokenStart, tokenEnd, token) {
+ s.pending = text[tokenStart:]
+ break
+ }
+ if token == "begin" && s.declareSkips > 0 {
+ s.declareSkips--
+ s.closedPLSQL = false
+ } else if token == "begin" && shouldEnterPLSQLBlock(text, tokenEnd) {
+ s.plsqlDepth++
+ s.closedPLSQL = false
+ } else if token == "declare" && shouldEnterPLSQLDeclareBlock(text, tokenEnd) {
+ s.plsqlDepth++
+ s.declareSkips++
+ s.closedPLSQL = false
+ } else if token == "end" && s.plsqlDepth > 0 && !isPLSQLControlEnd(text, tokenEnd) {
+ s.plsqlDepth--
+ if s.declareSkips > s.plsqlDepth {
+ s.declareSkips = s.plsqlDepth
+ }
+ s.closedPLSQL = s.plsqlDepth == 0
+ }
+ s.cur.WriteString(text[tokenStart:tokenEnd])
+ i = tokenEnd - 1
+ continue
+ }
+
// 行注释开始
if ch == '-' && i+1 >= len(text) {
s.pending = text[i:]
@@ -162,6 +202,20 @@ func (s *sqlStreamSplitter) Feed(chunk []byte) []string {
// 分号分隔
if ch == ';' {
+ if s.plsqlDepth > 0 {
+ s.cur.WriteByte(ch)
+ continue
+ }
+ if s.closedPLSQL {
+ s.cur.WriteByte(ch)
+ stmt := strings.TrimSpace(s.cur.String())
+ if stmt != "" {
+ statements = append(statements, stmt)
+ }
+ s.cur.Reset()
+ s.closedPLSQL = false
+ continue
+ }
stmt := strings.TrimSpace(s.cur.String())
if stmt != "" {
statements = append(statements, stmt)
@@ -175,6 +229,22 @@ func (s *sqlStreamSplitter) Feed(chunk []byte) []string {
break
}
if ch == 0xEF && i+2 < len(text) && text[i+1] == 0xBC && text[i+2] == 0x9B {
+ if s.plsqlDepth > 0 {
+ s.cur.WriteString(";")
+ i += 2
+ continue
+ }
+ if s.closedPLSQL {
+ s.cur.WriteString(";")
+ stmt := strings.TrimSpace(s.cur.String())
+ if stmt != "" {
+ statements = append(statements, stmt)
+ }
+ s.cur.Reset()
+ s.closedPLSQL = false
+ i += 2
+ continue
+ }
stmt := strings.TrimSpace(s.cur.String())
if stmt != "" {
statements = append(statements, stmt)
@@ -217,6 +287,44 @@ func isIncompleteSQLDollarTag(s string) bool {
return true
}
+func shouldDeferPLSQLKeywordInStream(text string, tokenStart int, tokenEnd int, token string) bool {
+ switch token {
+ case "begin", "declare", "end":
+ default:
+ return false
+ }
+ if tokenEnd >= len(text) {
+ return true
+ }
+ next := skipSQLWhitespaceAndComments(text, tokenEnd)
+ if next >= len(text) {
+ return true
+ }
+ if isSQLIdentifierStart(text[next]) {
+ nextEnd := next + 1
+ for nextEnd < len(text) && isSQLIdentifierPart(text[nextEnd]) {
+ nextEnd++
+ }
+ return nextEnd >= len(text)
+ }
+ return false
+}
+
+func shouldDeferPLSQLKeywordPrefixInStream(text string, tokenStart int, tokenEnd int, token string) bool {
+ if tokenEnd < len(text) {
+ return false
+ }
+ for _, keyword := range []string{"begin", "declare", "end"} {
+ if strings.HasPrefix(keyword, token) && token != keyword {
+ if tokenStart > 0 && isSQLIdentifierPart(text[tokenStart-1]) {
+ return false
+ }
+ return true
+ }
+ }
+ return false
+}
+
// streamSQLFile 从 reader 中流式读取 SQL 并逐条回调。
// onStatement 返回 error 时停止读取并返回该 error。
// 返回总处理语句数和可能的错误。
diff --git a/internal/app/sql_split_test.go b/internal/app/sql_split_test.go
index 7c9c3d3..b4e4586 100644
--- a/internal/app/sql_split_test.go
+++ b/internal/app/sql_split_test.go
@@ -111,3 +111,59 @@ func TestSplitSQLStatements_SQLEscapedQuoteMultiple(t *testing.T) {
}
}
+func TestSplitSQLStatements_OracleAnonymousBlock(t *testing.T) {
+ input := `BEGIN
+ INSERT INTO tmp_disable_trigger (table_name) VALUES ('t_memcard_reg');
+ UPDATE t_memcard_reg SET CARDLEVEL = 1 WHERE MEMCARDNO = '8032277312';
+ DELETE FROM tmp_disable_trigger WHERE table_name = 't_memcard_reg';
+END;
+SELECT 1 FROM dual;`
+ got := splitSQLStatements(input)
+ want := []string{
+ `BEGIN
+ INSERT INTO tmp_disable_trigger (table_name) VALUES ('t_memcard_reg');
+ UPDATE t_memcard_reg SET CARDLEVEL = 1 WHERE MEMCARDNO = '8032277312';
+ DELETE FROM tmp_disable_trigger WHERE table_name = 't_memcard_reg';
+END;`,
+ "SELECT 1 FROM dual",
+ }
+ if !reflect.DeepEqual(got, want) {
+ t.Errorf("splitSQLStatements(%q) = %#v, want %#v", input, got, want)
+ }
+}
+
+func TestSplitSQLStatements_OracleDeclareBlock(t *testing.T) {
+ input := `DECLARE
+ v_count NUMBER;
+BEGIN
+ SELECT COUNT(*) INTO v_count FROM t_memcard_reg;
+ UPDATE t_memcard_reg SET CARDLEVEL = v_count WHERE MEMCARDNO = '8032277312';
+END;
+SELECT 1 FROM dual;`
+ got := splitSQLStatements(input)
+ want := []string{
+ `DECLARE
+ v_count NUMBER;
+BEGIN
+ SELECT COUNT(*) INTO v_count FROM t_memcard_reg;
+ UPDATE t_memcard_reg SET CARDLEVEL = v_count WHERE MEMCARDNO = '8032277312';
+END;`,
+ "SELECT 1 FROM dual",
+ }
+ if !reflect.DeepEqual(got, want) {
+ t.Errorf("splitSQLStatements(%q) = %#v, want %#v", input, got, want)
+ }
+}
+
+func TestSplitSQLStatements_TransactionBeginStillSplits(t *testing.T) {
+ input := "BEGIN; UPDATE accounts SET balance = balance - 1 WHERE id = 1; COMMIT;"
+ got := splitSQLStatements(input)
+ want := []string{
+ "BEGIN",
+ "UPDATE accounts SET balance = balance - 1 WHERE id = 1",
+ "COMMIT",
+ }
+ if !reflect.DeepEqual(got, want) {
+ t.Errorf("splitSQLStatements(%q) = %#v, want %#v", input, got, want)
+ }
+}