diff --git a/frontend/src/utils/sqlStatementSelection.test.ts b/frontend/src/utils/sqlStatementSelection.test.ts index c0764ae..7c3a6a9 100644 --- a/frontend/src/utils/sqlStatementSelection.test.ts +++ b/frontend/src/utils/sqlStatementSelection.test.ts @@ -175,6 +175,63 @@ describe('sqlStatementSelection', () => { }); }); + it('keeps Oracle CREATE PROCEDURE cursor CASE expressions as one executable statement', () => { + const sql = [ + 'CREATE OR REPLACE PROCEDURE proc_accept_to_add(', + ' p_acceptno IN t_accept_h.acceptno%TYPE', + ') IS', + ' CURSOR cur_store_same(p_ind s_sys_ini.inipara%TYPE) IS', + ' SELECT si.compid, si.batid, si.wareid', + ' FROM t_store_i si', + ' ORDER BY CASE', + " WHEN p_ind = '1' THEN", + " to_char(si.invalidate - to_date('19700101', 'yyyymmdd'))", + " WHEN p_ind = '2' THEN", + " lpad(to_char(floor(si.wareqty)), 10, '0')", + ' ELSE', + ' to_char(si.batid)', + ' END,si.batid;', + 'BEGIN', + ' NULL;', + 'END;', + '/', + 'SELECT 1 FROM dual;', + ].join('\n'); + + const ranges = findSqlStatementRanges(sql).map((range) => range.text); + + expect(ranges).toEqual([ + [ + 'CREATE OR REPLACE PROCEDURE proc_accept_to_add(', + ' p_acceptno IN t_accept_h.acceptno%TYPE', + ') IS', + ' CURSOR cur_store_same(p_ind s_sys_ini.inipara%TYPE) IS', + ' SELECT si.compid, si.batid, si.wareid', + ' FROM t_store_i si', + ' ORDER BY CASE', + " WHEN p_ind = '1' THEN", + " to_char(si.invalidate - to_date('19700101', 'yyyymmdd'))", + " WHEN p_ind = '2' THEN", + " lpad(to_char(floor(si.wareqty)), 10, '0')", + ' ELSE', + ' to_char(si.batid)', + ' END,si.batid;', + 'BEGIN', + ' NULL;', + 'END;', + ].join('\n'), + 'SELECT 1 FROM dual', + ]); + expect(resolveExecutableSql(sql, sql.indexOf('ORDER BY CASE'))).toEqual({ + sql: ranges[0], + source: 'statement', + }); + expect(resolveExecutableSql(sql, sql.indexOf('NULL'))).toEqual({ + sql: ranges[0], + source: 'statement', + }); + }); + it('skips SQL*Plus slash delimiter comments after named Oracle procedure endings', () => { const sql = [ '-- 修改函数/存储过程:H2.cproc_tzhssr_order2sale_A1', diff --git a/frontend/src/utils/sqlStatementSelection.ts b/frontend/src/utils/sqlStatementSelection.ts index 2788ce0..bb353bc 100644 --- a/frontend/src/utils/sqlStatementSelection.ts +++ b/frontend/src/utils/sqlStatementSelection.ts @@ -205,6 +205,8 @@ export const findSqlStatementRanges = (sql: string): SqlStatementRange[] => { let dollarTag: string | null = null; let plsqlDepth = 0; let plsqlDeclareBeginSkips = 0; + let plsqlCaseDepth = 0; + let skipNextPlsqlCaseEndToken = false; let justClosedPLSQLBlock = false; const push = (end: number) => { @@ -309,6 +311,16 @@ export const findSqlStatementRanges = (sql: string): SqlStatementRange[] => { tokenEnd++; } const token = text.slice(index, tokenEnd).toLowerCase(); + if (token === 'case' && plsqlDepth > 0) { + if (skipNextPlsqlCaseEndToken) { + skipNextPlsqlCaseEndToken = false; + } else { + plsqlCaseDepth++; + justClosedPLSQLBlock = false; + } + } else if (token !== 'case') { + skipNextPlsqlCaseEndToken = false; + } if (token === 'begin' && plsqlDeclareBeginSkips > 0) { plsqlDeclareBeginSkips--; justClosedPLSQLBlock = false; @@ -325,11 +337,20 @@ export const findSqlStatementRanges = (sql: string): SqlStatementRange[] => { plsqlDeclareBeginSkips++; } justClosedPLSQLBlock = false; + } else if (token === 'end' && plsqlDepth > 0 && plsqlCaseDepth > 0) { + plsqlCaseDepth--; + if (nextSqlSignificantToken(text, tokenEnd) === 'case') { + skipNextPlsqlCaseEndToken = true; + } + justClosedPLSQLBlock = false; } else if (token === 'end' && plsqlDepth > 0 && !isPlsqlControlEnd(text, tokenEnd)) { plsqlDepth--; if (plsqlDeclareBeginSkips > plsqlDepth) { plsqlDeclareBeginSkips = plsqlDepth; } + if (plsqlCaseDepth > plsqlDepth) { + plsqlCaseDepth = plsqlDepth; + } justClosedPLSQLBlock = plsqlDepth === 0; } index = tokenEnd - 1; diff --git a/internal/app/methods_db_multi_test.go b/internal/app/methods_db_multi_test.go index 3b52ddd..218059e 100644 --- a/internal/app/methods_db_multi_test.go +++ b/internal/app/methods_db_multi_test.go @@ -436,6 +436,57 @@ END;` } } +func TestDBQueryMultiKeepsOracleCreateProcedureCursorCaseExpressionAsSingleStatement(t *testing.T) { + originalNewDatabaseFunc := newDatabaseFunc + t.Cleanup(func() { + newDatabaseFunc = originalNewDatabaseFunc + }) + + fakeDB := &fakeBatchWriteDB{} + newDatabaseFunc = func(dbType string) (db.Database, error) { + return fakeDB, nil + } + + app := NewAppWithSecretStore(secretstore.NewUnavailableStore("test")) + config := connection.ConnectionConfig{ + Type: "oracle", + Host: "127.0.0.1", + Port: 1521, + User: "app", + } + query := `CREATE OR REPLACE PROCEDURE proc_accept_to_add( + p_acceptno IN t_accept_h.acceptno%TYPE +) IS + CURSOR cur_store_same(p_ind s_sys_ini.inipara%TYPE) IS + SELECT si.compid, si.batid, si.wareid + FROM t_store_i si + ORDER BY CASE + WHEN p_ind = '1' THEN + to_char(si.invalidate - to_date('19700101', 'yyyymmdd')) + WHEN p_ind = '2' THEN + lpad(to_char(floor(si.wareqty)), 10, '0') + ELSE + to_char(si.batid) + END,si.batid; +BEGIN + NULL; +END;` + + result := app.DBQueryMulti(config, "ORCLPDB1", query, "oracle-create-procedure-cursor-case-test") + if !result.Success { + t.Fatalf("expected DBQueryMulti success, got failure: %s", result.Message) + } + if fakeDB.batchCalls != 0 { + t.Fatalf("expected CREATE PROCEDURE to skip batch path, got batchCalls=%d", fakeDB.batchCalls) + } + if fakeDB.execCalls != 1 || len(fakeDB.execQueries) != 1 { + t.Fatalf("expected one sequential exec call, got execCalls=%d queries=%#v", fakeDB.execCalls, fakeDB.execQueries) + } + if fakeDB.execQueries[0] != query { + t.Fatalf("expected CREATE PROCEDURE to stay intact, got %q", fakeDB.execQueries[0]) + } +} + func TestDBQueryMultiSkipsOracleSqlPlusSlashDelimiter(t *testing.T) { originalNewDatabaseFunc := newDatabaseFunc t.Cleanup(func() { diff --git a/internal/app/methods_file_sql_execution_test.go b/internal/app/methods_file_sql_execution_test.go index 5a12789..2d7bd45 100644 --- a/internal/app/methods_file_sql_execution_test.go +++ b/internal/app/methods_file_sql_execution_test.go @@ -436,6 +436,66 @@ func TestStreamSQLFileKeepsOracleCreateProcedureTogether(t *testing.T) { } } +func TestStreamSQLFileKeepsOracleCreateProcedureCursorCaseExpressionTogether(t *testing.T) { + input := strings.Join([]string{ + "CREATE OR REPLACE PROCEDURE proc_accept_to_add(", + " p_acceptno IN t_accept_h.acceptno%TYPE", + ") IS", + " CURSOR cur_store_same(p_ind s_sys_ini.inipara%TYPE) IS", + " SELECT si.compid, si.batid, si.wareid", + " FROM t_store_i si", + " ORDER BY CASE", + " WHEN p_ind = '1' THEN", + " to_char(si.invalidate - to_date('19700101', 'yyyymmdd'))", + " WHEN p_ind = '2' THEN", + " lpad(to_char(floor(si.wareqty)), 10, '0')", + " ELSE", + " to_char(si.batid)", + " END,si.batid;", + "BEGIN", + " NULL;", + "END;", + "/", + "SELECT 1 FROM dual;", + }, "\n") + var statements []string + + count, err := streamSQLFile(&chunkedReader{data: []byte(input), step: 4}, func(index int, stmt string) error { + statements = append(statements, stmt) + return nil + }) + if err != nil { + t.Fatalf("streamSQLFile returned error: %v", err) + } + if count != 2 || len(statements) != 2 { + t.Fatalf("expected 2 statements, got count=%d statements=%#v", count, statements) + } + if statements[0] != strings.Join([]string{ + "CREATE OR REPLACE PROCEDURE proc_accept_to_add(", + " p_acceptno IN t_accept_h.acceptno%TYPE", + ") IS", + " CURSOR cur_store_same(p_ind s_sys_ini.inipara%TYPE) IS", + " SELECT si.compid, si.batid, si.wareid", + " FROM t_store_i si", + " ORDER BY CASE", + " WHEN p_ind = '1' THEN", + " to_char(si.invalidate - to_date('19700101', 'yyyymmdd'))", + " WHEN p_ind = '2' THEN", + " lpad(to_char(floor(si.wareqty)), 10, '0')", + " ELSE", + " to_char(si.batid)", + " END,si.batid;", + "BEGIN", + " NULL;", + "END;", + }, "\n") { + t.Fatalf("unexpected create procedure statement: %q", statements[0]) + } + if statements[1] != "SELECT 1 FROM dual" { + t.Fatalf("unexpected second statement: %q", statements[1]) + } +} + func TestStreamSQLFileSkipsOracleSqlPlusSlashDelimiter(t *testing.T) { input := strings.Join([]string{ "CREATE OR REPLACE PROCEDURE proc_tally2accept(", diff --git a/internal/app/sql_split.go b/internal/app/sql_split.go index c08002d..6b36a1d 100644 --- a/internal/app/sql_split.go +++ b/internal/app/sql_split.go @@ -21,6 +21,8 @@ func splitSQLStatements(sql string) []string { var dollarTag string // postgres/kingbase: $$...$$ or $tag$...$tag$ plsqlDepth := 0 plsqlDeclareBeginSkips := 0 + plsqlCaseDepth := 0 + skipNextPLSQLCaseEndToken := false justClosedPLSQLBlock := false push := func() { @@ -119,6 +121,16 @@ func splitSQLStatements(sql string) []string { tokenEnd++ } token := strings.ToLower(text[tokenStart:tokenEnd]) + if token == "case" && plsqlDepth > 0 { + if skipNextPLSQLCaseEndToken { + skipNextPLSQLCaseEndToken = false + } else { + plsqlCaseDepth++ + justClosedPLSQLBlock = false + } + } else if token != "case" { + skipNextPLSQLCaseEndToken = false + } if token == "begin" && plsqlDeclareBeginSkips > 0 { plsqlDeclareBeginSkips-- justClosedPLSQLBlock = false @@ -135,11 +147,20 @@ func splitSQLStatements(sql string) []string { plsqlDeclareBeginSkips++ } justClosedPLSQLBlock = false + } else if token == "end" && plsqlDepth > 0 && plsqlCaseDepth > 0 { + plsqlCaseDepth-- + if nextSQLSignificantToken(text, tokenEnd) == "case" { + skipNextPLSQLCaseEndToken = true + } + justClosedPLSQLBlock = false } else if token == "end" && plsqlDepth > 0 && !isPLSQLControlEnd(text, tokenEnd) { plsqlDepth-- if plsqlDeclareBeginSkips > plsqlDepth { plsqlDeclareBeginSkips = plsqlDepth } + if plsqlCaseDepth > plsqlDepth { + plsqlCaseDepth = plsqlDepth + } justClosedPLSQLBlock = plsqlDepth == 0 } cur.WriteString(text[tokenStart:tokenEnd]) diff --git a/internal/app/sql_split_stream.go b/internal/app/sql_split_stream.go index d11f5ff..5edf7d4 100644 --- a/internal/app/sql_split_stream.go +++ b/internal/app/sql_split_stream.go @@ -20,6 +20,8 @@ type sqlStreamSplitter struct { dollarTag string plsqlDepth int declareSkips int + plsqlCaseDepth int + skipCaseEnd bool closedPLSQL bool } @@ -136,6 +138,16 @@ func (s *sqlStreamSplitter) Feed(chunk []byte) []string { s.pending = text[tokenStart:] break } + if token == "case" && s.plsqlDepth > 0 { + if s.skipCaseEnd { + s.skipCaseEnd = false + } else { + s.plsqlCaseDepth++ + s.closedPLSQL = false + } + } else if token != "case" { + s.skipCaseEnd = false + } if token == "begin" && s.declareSkips > 0 { s.declareSkips-- s.closedPLSQL = false @@ -152,11 +164,20 @@ func (s *sqlStreamSplitter) Feed(chunk []byte) []string { s.declareSkips++ } s.closedPLSQL = false + } else if token == "end" && s.plsqlDepth > 0 && s.plsqlCaseDepth > 0 { + s.plsqlCaseDepth-- + if nextSQLSignificantToken(text, tokenEnd) == "case" { + s.skipCaseEnd = true + } + s.closedPLSQL = false } else if token == "end" && s.plsqlDepth > 0 && !isPLSQLControlEnd(text, tokenEnd) { s.plsqlDepth-- if s.declareSkips > s.plsqlDepth { s.declareSkips = s.plsqlDepth } + if s.plsqlCaseDepth > s.plsqlDepth { + s.plsqlCaseDepth = s.plsqlDepth + } s.closedPLSQL = s.plsqlDepth == 0 } s.cur.WriteString(text[tokenStart:tokenEnd]) diff --git a/internal/app/sql_split_test.go b/internal/app/sql_split_test.go index d812cac..e7eb630 100644 --- a/internal/app/sql_split_test.go +++ b/internal/app/sql_split_test.go @@ -251,6 +251,52 @@ END;`, } } +func TestSplitSQLStatements_OracleCreateProcedureKeepsCursorCaseExpression(t *testing.T) { + input := `CREATE OR REPLACE PROCEDURE proc_accept_to_add( + p_acceptno IN t_accept_h.acceptno%TYPE +) IS + CURSOR cur_store_same(p_ind s_sys_ini.inipara%TYPE) IS + SELECT si.compid, si.batid, si.wareid + FROM t_store_i si + ORDER BY CASE + WHEN p_ind = '1' THEN + to_char(si.invalidate - to_date('19700101', 'yyyymmdd')) + WHEN p_ind = '2' THEN + lpad(to_char(floor(si.wareqty)), 10, '0') + ELSE + to_char(si.batid) + END,si.batid; +BEGIN + NULL; +END; +/ +SELECT 1 FROM dual;` + got := splitSQLStatements(input) + want := []string{ + `CREATE OR REPLACE PROCEDURE proc_accept_to_add( + p_acceptno IN t_accept_h.acceptno%TYPE +) IS + CURSOR cur_store_same(p_ind s_sys_ini.inipara%TYPE) IS + SELECT si.compid, si.batid, si.wareid + FROM t_store_i si + ORDER BY CASE + WHEN p_ind = '1' THEN + to_char(si.invalidate - to_date('19700101', 'yyyymmdd')) + WHEN p_ind = '2' THEN + lpad(to_char(floor(si.wareqty)), 10, '0') + ELSE + to_char(si.batid) + END,si.batid; +BEGIN + NULL; +END;`, + "SELECT 1 FROM dual", + } + if !reflect.DeepEqual(got, want) { + t.Errorf("splitSQLStatements(%q) = %#v, want %#v", input, got, want) + } +} + func TestSplitSQLStatements_OracleCreateProcedureSkipsCommentedSqlPlusSlashDelimiter(t *testing.T) { input := `-- 修改函数/存储过程:H2.cproc_tzhssr_order2sale_A1 -- 请确认语法兼容当前数据库后执行