package app import "strings" // splitSQLStatements 按分号拆分 SQL 文本为独立语句。 // 正确处理单引号/双引号/反引号字符串、行注释(-- / #)、块注释(/* */)、 // PostgreSQL/Kingbase 的 $$...$$ dollar-quoting,以及 Oracle PL/SQL 匿名块, // 避免在这些上下文中错误拆分。 // 同时支持 SQL 标准的转义单引号(两个连续单引号 ” 表示字面量引号)。 func splitSQLStatements(sql string) []string { text := strings.ReplaceAll(sql, "\r\n", "\n") var statements []string var cur strings.Builder inSingle := false inDouble := false inBacktick := false escaped := false inLineComment := false inBlockComment := false var dollarTag string // postgres/kingbase: $$...$$ or $tag$...$tag$ plsqlDepth := 0 plsqlDeclareBeginSkips := 0 justClosedPLSQLBlock := false push := func() { s := strings.TrimSpace(cur.String()) if s != "" { statements = append(statements, s) } cur.Reset() } for i := 0; i < len(text); i++ { ch := text[i] next := byte(0) if i+1 < len(text) { next = text[i+1] } // 行注释 if inLineComment { if ch == '\n' { inLineComment = false } cur.WriteByte(ch) continue } // 块注释 if inBlockComment { cur.WriteByte(ch) if ch == '*' && next == '/' { cur.WriteByte('/') i++ inBlockComment = false } continue } // Dollar-quoting if dollarTag != "" { if strings.HasPrefix(text[i:], dollarTag) { cur.WriteString(dollarTag) i += len(dollarTag) - 1 dollarTag = "" } else { cur.WriteByte(ch) } continue } // 转义字符(反斜杠转义,MySQL 风格) if escaped { escaped = false cur.WriteByte(ch) continue } if (inSingle || inDouble) && ch == '\\' { escaped = true cur.WriteByte(ch) continue } // 字符串开闭 if !inDouble && !inBacktick && ch == '\'' { if inSingle && next == '\'' { // SQL 标准转义:两个连续单引号 '' 表示字面量引号,保持在引号内 cur.WriteByte(ch) cur.WriteByte(next) i++ continue } inSingle = !inSingle cur.WriteByte(ch) continue } if !inSingle && !inBacktick && ch == '"' { inDouble = !inDouble cur.WriteByte(ch) continue } if !inSingle && !inDouble && ch == '`' { inBacktick = !inBacktick cur.WriteByte(ch) continue } // 在引号/反引号内部不做任何判断 if inSingle || inDouble || inBacktick { cur.WriteByte(ch) continue } if isSQLIdentifierStart(ch) { tokenStart := i tokenEnd := i + 1 for tokenEnd < len(text) && isSQLIdentifierPart(text[tokenEnd]) { tokenEnd++ } token := strings.ToLower(text[tokenStart:tokenEnd]) if token == "begin" && plsqlDeclareBeginSkips > 0 { plsqlDeclareBeginSkips-- justClosedPLSQLBlock = false } else if token == "begin" && shouldEnterPLSQLBlock(text, tokenEnd) { plsqlDepth++ justClosedPLSQLBlock = false } else if token == "declare" && shouldEnterPLSQLDeclareBlock(text, tokenEnd) { plsqlDepth++ plsqlDeclareBeginSkips++ justClosedPLSQLBlock = false } else if plsqlDepth == 0 && shouldEnterPLSQLCreateRoutineBlock(text, cur.String(), token, tokenEnd) { plsqlDepth++ plsqlDeclareBeginSkips++ justClosedPLSQLBlock = false } else if token == "end" && plsqlDepth > 0 && !isPLSQLControlEnd(text, tokenEnd) { plsqlDepth-- if plsqlDeclareBeginSkips > plsqlDepth { plsqlDeclareBeginSkips = plsqlDepth } justClosedPLSQLBlock = plsqlDepth == 0 } cur.WriteString(text[tokenStart:tokenEnd]) i = tokenEnd - 1 continue } // 行注释开始 if ch == '-' && next == '-' { inLineComment = true cur.WriteByte(ch) continue } if ch == '#' { inLineComment = true cur.WriteByte(ch) continue } // 块注释开始 if ch == '/' && next == '*' { inBlockComment = true cur.WriteString("/*") i++ continue } // Dollar-quoting 开始 if ch == '$' { if tag := parseSQLDollarTag(text[i:]); tag != "" { dollarTag = tag cur.WriteString(tag) i += len(tag) - 1 continue } } // 分号分隔(支持全角分号";") if ch == ';' { if plsqlDepth > 0 { cur.WriteByte(ch) continue } if justClosedPLSQLBlock { cur.WriteByte(ch) push() justClosedPLSQLBlock = false continue } push() continue } // 全角分号 UTF-8 序列: 0xEF 0xBC 0x9B if ch == 0xEF && i+2 < len(text) && text[i+1] == 0xBC && text[i+2] == 0x9B { if plsqlDepth > 0 { cur.WriteString(";") i += 2 continue } if justClosedPLSQLBlock { cur.WriteString(";") push() justClosedPLSQLBlock = false i += 2 continue } push() i += 2 continue } cur.WriteByte(ch) } push() return statements } func isSQLIdentifierStart(ch byte) bool { return (ch >= 'a' && ch <= 'z') || (ch >= 'A' && ch <= 'Z') || ch == '_' } func isSQLIdentifierPart(ch byte) bool { return isSQLIdentifierStart(ch) || (ch >= '0' && ch <= '9') || ch == '$' || ch == '#' } func skipSQLWhitespaceAndComments(text string, pos int) int { i := pos for i < len(text) { switch text[i] { case ' ', '\t', '\n', '\r', '\f': i++ continue case '-': if i+1 < len(text) && text[i+1] == '-' { i += 2 for i < len(text) && text[i] != '\n' { i++ } continue } case '/': if i+1 < len(text) && text[i+1] == '*' { i += 2 for i+1 < len(text) && !(text[i] == '*' && text[i+1] == '/') { i++ } if i+1 < len(text) { i += 2 } continue } } break } return i } func nextSQLSignificantToken(text string, pos int) string { i := skipSQLWhitespaceAndComments(text, pos) if i >= len(text) || !isSQLIdentifierStart(text[i]) { return "" } end := i + 1 for end < len(text) && isSQLIdentifierPart(text[end]) { end++ } return strings.ToLower(text[i:end]) } func nextSQLSignificantByte(text string, pos int) byte { i := skipSQLWhitespaceAndComments(text, pos) if i >= len(text) { return 0 } return text[i] } func shouldEnterPLSQLBlock(text string, tokenEnd int) bool { switch nextSQLSignificantByte(text, tokenEnd) { case 0, ';': return false } switch nextSQLSignificantToken(text, tokenEnd) { case "transaction", "work", "isolation", "read", "write": return false default: return true } } func isPLSQLBlockStatement(stmt string) bool { text := strings.TrimSpace(stmt) if text == "" { return false } if strings.HasSuffix(text, "/") { text = strings.TrimSpace(strings.TrimSuffix(text, "/")) } token := nextSQLSignificantToken(text, 0) if token == "declare" { return shouldEnterPLSQLDeclareBlock(text, len("declare")) } if token == "begin" { return shouldEnterPLSQLBlock(text, len("begin")) } return isCreateRoutineHeaderPrefix(text) } func shouldEnterPLSQLDeclareBlock(text string, tokenEnd int) bool { return nextSQLSignificantToken(text, tokenEnd) != "" } func isPLSQLControlEnd(text string, tokenEnd int) bool { switch nextSQLSignificantToken(text, tokenEnd) { case "if", "loop", "case": return true default: return false } } func isCreateRoutineHeaderPrefix(text string) bool { currentToken, currentEnd := nextSQLSignificantTokenSpan(text, 0) if currentToken != "create" { return false } currentToken, currentEnd = nextSQLSignificantTokenSpan(text, currentEnd) if currentToken == "or" { currentToken, currentEnd = nextSQLSignificantTokenSpan(text, currentEnd) if currentToken != "replace" { return false } currentToken, currentEnd = nextSQLSignificantTokenSpan(text, currentEnd) } for currentToken == "editionable" || currentToken == "noneditionable" { currentToken, currentEnd = nextSQLSignificantTokenSpan(text, currentEnd) } return currentToken == "procedure" || currentToken == "function" } func nextSQLSignificantTokenSpan(text string, pos int) (string, int) { i := skipSQLWhitespaceAndComments(text, pos) if i >= len(text) || !isSQLIdentifierStart(text[i]) { return "", i } end := i + 1 for end < len(text) && isSQLIdentifierPart(text[end]) { end++ } return strings.ToLower(text[i:end]), end } func shouldEnterPLSQLCreateRoutineBlock(text string, currentStatementPrefix string, token string, tokenEnd int) bool { if token != "is" && token != "as" { return false } nextChar := nextSQLSignificantByte(text, tokenEnd) if nextChar == 0 { return false } if token == "as" && (nextChar == '$' || nextChar == '\'' || nextChar == '"') { return false } return isCreateRoutineHeaderPrefix(currentStatementPrefix) } // parseSQLDollarTag 解析 PostgreSQL/Kingbase 的 dollar-quoting 标签。 func parseSQLDollarTag(s string) string { if len(s) < 2 || s[0] != '$' { return "" } for i := 1; i < len(s); i++ { c := s[i] if c == '$' { return s[:i+1] } if !((c >= 'a' && c <= 'z') || (c >= 'A' && c <= 'Z') || (c >= '0' && c <= '9') || c == '_') { return "" } } return "" }