Files
MyGoNavi/internal/app/sql_split.go

387 lines
8.8 KiB
Go
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
package app
import "strings"
// splitSQLStatements splits SQL text into individual statements on semicolons.
//
// No split happens inside single-quoted, double-quoted or backtick-quoted
// text, line comments (-- and #), block comments (/* */), PostgreSQL/Kingbase
// dollar-quoted bodies ($$...$$ or $tag$...$tag$), or Oracle PL/SQL blocks
// (DECLARE/BEGIN blocks and CREATE PROCEDURE/FUNCTION bodies). Inside single
// quotes the SQL-standard doubled quote ('') is recognised, and MySQL-style
// backslash escapes are recognised inside single and double quotes.
//
// Both the ASCII semicolon and the full-width semicolon (U+FF1B, UTF-8
// EF BC 9B) act as separators. A plain statement is returned without its
// separator; a PL/SQL block keeps the separator that terminates it, copied
// exactly as it appeared in the input.
//
// "\r\n" is normalised to "\n" before scanning. Returned statements are
// whitespace-trimmed and empty ones are dropped.
func splitSQLStatements(sql string) []string {
	text := strings.ReplaceAll(sql, "\r\n", "\n")

	var statements []string
	var cur strings.Builder

	inSingle := false
	inDouble := false
	inBacktick := false
	escaped := false
	inLineComment := false
	inBlockComment := false
	var dollarTag string // open dollar-quote delimiter ("$$" or "$tag$"); empty when not inside one

	plsqlDepth := 0               // number of PL/SQL blocks currently open
	plsqlDeclareBeginSkips := 0   // BEGIN keywords that belong to an already-counted DECLARE / routine header
	justClosedPLSQLBlock := false // the outermost block just ended; the next separator terminates it

	// push emits the accumulated statement (if non-blank) and resets the buffer.
	push := func() {
		s := strings.TrimSpace(cur.String())
		if s != "" {
			statements = append(statements, s)
		}
		cur.Reset()
	}

	for i := 0; i < len(text); i++ {
		ch := text[i]
		next := byte(0)
		if i+1 < len(text) {
			next = text[i+1]
		}

		// Inside a line comment: copy through to (and including) the newline.
		if inLineComment {
			if ch == '\n' {
				inLineComment = false
			}
			cur.WriteByte(ch)
			continue
		}

		// Inside a block comment: copy through to the closing "*/".
		if inBlockComment {
			cur.WriteByte(ch)
			if ch == '*' && next == '/' {
				cur.WriteByte('/')
				i++
				inBlockComment = false
			}
			continue
		}

		// Inside a dollar-quoted body: copy until the same delimiter reappears.
		if dollarTag != "" {
			if strings.HasPrefix(text[i:], dollarTag) {
				cur.WriteString(dollarTag)
				i += len(dollarTag) - 1
				dollarTag = ""
			} else {
				cur.WriteByte(ch)
			}
			continue
		}

		// Backslash escapes (MySQL style): the byte after '\' is copied verbatim.
		if escaped {
			escaped = false
			cur.WriteByte(ch)
			continue
		}
		if (inSingle || inDouble) && ch == '\\' {
			escaped = true
			cur.WriteByte(ch)
			continue
		}

		// Quote open/close.
		if !inDouble && !inBacktick && ch == '\'' {
			if inSingle && next == '\'' {
				// SQL-standard escape: '' is a literal quote, so stay inside the string.
				cur.WriteByte(ch)
				cur.WriteByte(next)
				i++
				continue
			}
			inSingle = !inSingle
			cur.WriteByte(ch)
			continue
		}
		if !inSingle && !inBacktick && ch == '"' {
			inDouble = !inDouble
			cur.WriteByte(ch)
			continue
		}
		if !inSingle && !inDouble && ch == '`' {
			inBacktick = !inBacktick
			cur.WriteByte(ch)
			continue
		}

		// Nothing inside quotes or backticks is interpreted.
		if inSingle || inDouble || inBacktick {
			cur.WriteByte(ch)
			continue
		}

		// Identifier / keyword token: consumed as a whole so PL/SQL block
		// keywords can be tracked.
		if isSQLIdentifierStart(ch) {
			tokenStart := i
			tokenEnd := i + 1
			for tokenEnd < len(text) && isSQLIdentifierPart(text[tokenEnd]) {
				tokenEnd++
			}
			token := strings.ToLower(text[tokenStart:tokenEnd])

			switch {
			case token == "begin" && plsqlDeclareBeginSkips > 0:
				// This BEGIN belongs to a DECLARE / routine header that already
				// opened the block, so it must not add another nesting level.
				plsqlDeclareBeginSkips--
				justClosedPLSQLBlock = false
			case token == "begin" && shouldEnterPLSQLBlock(text, tokenEnd):
				plsqlDepth++
				justClosedPLSQLBlock = false
			case token == "declare" && shouldEnterPLSQLDeclareBlock(text, tokenEnd):
				plsqlDepth++
				plsqlDeclareBeginSkips++
				justClosedPLSQLBlock = false
			case plsqlDepth == 0 && shouldEnterPLSQLCreateRoutineBlock(text, cur.String(), token, tokenEnd):
				plsqlDepth++
				plsqlDeclareBeginSkips++
				justClosedPLSQLBlock = false
			case token == "end" && plsqlDepth > 0 && !isPLSQLControlEnd(text, tokenEnd):
				plsqlDepth--
				if plsqlDeclareBeginSkips > plsqlDepth {
					plsqlDeclareBeginSkips = plsqlDepth
				}
				justClosedPLSQLBlock = plsqlDepth == 0
			}

			cur.WriteString(text[tokenStart:tokenEnd])
			i = tokenEnd - 1
			continue
		}

		// Line comment start. Only the first byte is written here; the rest is
		// copied by the inLineComment branch on the following iterations.
		if ch == '-' && next == '-' {
			inLineComment = true
			cur.WriteByte(ch)
			continue
		}
		if ch == '#' {
			inLineComment = true
			cur.WriteByte(ch)
			continue
		}

		// Block comment start.
		if ch == '/' && next == '*' {
			inBlockComment = true
			cur.WriteString("/*")
			i++
			continue
		}

		// Dollar-quote start.
		if ch == '$' {
			if tag := parseSQLDollarTag(text[i:]); tag != "" {
				dollarTag = tag
				cur.WriteString(tag)
				i += len(tag) - 1
				continue
			}
		}

		// Statement separator: ASCII ';' or the full-width semicolon, whose
		// UTF-8 encoding is the three bytes EF BC 9B.
		sepLen := 0
		if ch == ';' {
			sepLen = 1
		} else if ch == 0xEF && i+2 < len(text) && text[i+1] == 0xBC && text[i+2] == 0x9B {
			sepLen = 3
		}
		if sepLen > 0 {
			// Copy the separator from the input rather than from a literal so
			// the bytes written are always exactly the bytes that were read.
			sep := text[i : i+sepLen]
			i += sepLen - 1
			switch {
			case plsqlDepth > 0:
				// Separators inside a PL/SQL block are part of the block body.
				cur.WriteString(sep)
			case justClosedPLSQLBlock:
				// The separator after the final END belongs to the block.
				cur.WriteString(sep)
				push()
				justClosedPLSQLBlock = false
			default:
				push()
			}
			continue
		}

		cur.WriteByte(ch)
	}

	push()
	return statements
}
// isSQLIdentifierStart reports whether ch can begin an unquoted identifier or
// keyword token: an ASCII letter or an underscore.
func isSQLIdentifierStart(ch byte) bool {
	switch {
	case ch == '_':
		return true
	case 'a' <= ch && ch <= 'z':
		return true
	case 'A' <= ch && ch <= 'Z':
		return true
	default:
		return false
	}
}
// isSQLIdentifierPart reports whether ch can continue an identifier token:
// anything accepted by isSQLIdentifierStart, an ASCII digit, '$' or '#'.
func isSQLIdentifierPart(ch byte) bool {
	switch ch {
	case '$', '#':
		return true
	}
	if '0' <= ch && ch <= '9' {
		return true
	}
	return isSQLIdentifierStart(ch)
}
// skipSQLWhitespaceAndComments returns the offset of the first byte at or
// after pos that is neither whitespace (space, tab, newline, carriage return,
// form feed) nor part of a "--" line comment or a "/* */" block comment.
//
// It returns len(text) when nothing significant remains. That includes the
// case where a block comment is never closed: the unterminated comment is
// treated as running to the end of the text, so callers never see the tail of
// a comment as a significant byte. A pos beyond len(text) is returned as is.
func skipSQLWhitespaceAndComments(text string, pos int) int {
	i := pos
	for i < len(text) {
		rest := text[i:]
		switch {
		case rest[0] == ' ' || rest[0] == '\t' || rest[0] == '\n' || rest[0] == '\r' || rest[0] == '\f':
			i++
		case strings.HasPrefix(rest, "--"):
			nl := strings.IndexByte(rest, '\n')
			if nl < 0 {
				return len(text)
			}
			// Stop on the newline; the whitespace case consumes it next.
			i += nl
		case strings.HasPrefix(rest, "/*"):
			// Search after the opener so that "/*/" is not taken as closed.
			closeAt := strings.Index(rest[2:], "*/")
			if closeAt < 0 {
				return len(text)
			}
			i += 2 + closeAt + 2
		default:
			return i
		}
	}
	return i
}
// nextSQLSignificantToken returns, lower-cased, the identifier token found at
// the first position at or after pos that is not whitespace or a comment. It
// returns "" when the text ends first or when that position does not start an
// identifier.
func nextSQLSignificantToken(text string, pos int) string {
	start := skipSQLWhitespaceAndComments(text, pos)
	if start >= len(text) {
		return ""
	}
	if !isSQLIdentifierStart(text[start]) {
		return ""
	}
	stop := start + 1
	for ; stop < len(text); stop++ {
		if !isSQLIdentifierPart(text[stop]) {
			break
		}
	}
	return strings.ToLower(text[start:stop])
}
// nextSQLSignificantByte returns the first byte at or after pos that is not
// whitespace or part of a comment, or 0 when the text ends first.
func nextSQLSignificantByte(text string, pos int) byte {
	if at := skipSQLWhitespaceAndComments(text, pos); at < len(text) {
		return text[at]
	}
	return 0
}
// shouldEnterPLSQLBlock decides whether a BEGIN keyword ending at tokenEnd
// opens a PL/SQL block rather than starting a transaction.
//
// It returns false when BEGIN is followed by nothing, by ';', or by a keyword
// that can only continue a transaction-start statement:
//   - TRANSACTION / WORK
//   - PostgreSQL transaction modes: ISOLATION, READ, WRITE, NOT, DEFERRABLE
//   - SQLite transaction types: DEFERRED, IMMEDIATE, EXCLUSIVE
//
// Anything else after BEGIN is taken to be the first statement of a block.
func shouldEnterPLSQLBlock(text string, tokenEnd int) bool {
	switch nextSQLSignificantByte(text, tokenEnd) {
	case 0, ';':
		return false
	}
	switch nextSQLSignificantToken(text, tokenEnd) {
	case "transaction", "work",
		"isolation", "read", "write", "not", "deferrable",
		"deferred", "immediate", "exclusive":
		return false
	default:
		return true
	}
}
func isPLSQLBlockStatement(stmt string) bool {
text := strings.TrimSpace(stmt)
if text == "" {
return false
}
if strings.HasSuffix(text, "/") {
text = strings.TrimSpace(strings.TrimSuffix(text, "/"))
}
token := nextSQLSignificantToken(text, 0)
if token == "declare" {
return shouldEnterPLSQLDeclareBlock(text, len("declare"))
}
if token == "begin" {
return shouldEnterPLSQLBlock(text, len("begin"))
}
return isCreateRoutineHeaderPrefix(text)
}
// shouldEnterPLSQLDeclareBlock reports whether a DECLARE keyword ending at
// tokenEnd opens a block, which is the case when the next significant
// position after it starts an identifier token.
func shouldEnterPLSQLDeclareBlock(text string, tokenEnd int) bool {
	following := nextSQLSignificantToken(text, tokenEnd)
	return len(following) > 0
}
// isPLSQLControlEnd reports whether an END keyword ending at tokenEnd closes a
// control structure (END IF, END LOOP, END CASE) rather than a block.
func isPLSQLControlEnd(text string, tokenEnd int) bool {
	following := nextSQLSignificantToken(text, tokenEnd)
	return following == "if" || following == "loop" || following == "case"
}
// isCreateRoutineHeaderPrefix reports whether text begins with a routine
// definition header of the form
//
//	CREATE [OR REPLACE] [EDITIONABLE | NONEDITIONABLE]... (PROCEDURE | FUNCTION)
//
// Keywords are matched case-insensitively; whitespace and comments between
// them are skipped.
func isCreateRoutineHeaderPrefix(text string) bool {
	pos := 0
	// advance reads the next keyword and moves pos past it.
	advance := func() string {
		var word string
		word, pos = nextSQLSignificantTokenSpan(text, pos)
		return word
	}

	if advance() != "create" {
		return false
	}
	word := advance()
	if word == "or" {
		if advance() != "replace" {
			return false
		}
		word = advance()
	}
	for word == "editionable" || word == "noneditionable" {
		word = advance()
	}
	switch word {
	case "procedure", "function":
		return true
	}
	return false
}
// nextSQLSignificantTokenSpan returns the lower-cased identifier token found
// at the first position at or after pos that is not whitespace or a comment,
// together with the offset just past that token. When no identifier starts
// there it returns "" and the offset at which skipping stopped.
func nextSQLSignificantTokenSpan(text string, pos int) (string, int) {
	start := skipSQLWhitespaceAndComments(text, pos)
	if start >= len(text) {
		return "", start
	}
	if !isSQLIdentifierStart(text[start]) {
		return "", start
	}
	stop := start + 1
	for ; stop < len(text); stop++ {
		if !isSQLIdentifierPart(text[stop]) {
			break
		}
	}
	return strings.ToLower(text[start:stop]), stop
}
// shouldEnterPLSQLCreateRoutineBlock reports whether an IS/AS keyword ending
// at tokenEnd starts the body of a CREATE PROCEDURE/FUNCTION statement.
//
// token is the lower-cased keyword just read and currentStatementPrefix is the
// statement text accumulated before it. The result is false when token is not
// "is" or "as", when nothing significant follows, or when "as" is followed by
// '$', a single quote or a double quote. Otherwise the result is whether
// currentStatementPrefix starts with a routine header.
func shouldEnterPLSQLCreateRoutineBlock(text string, currentStatementPrefix string, token string, tokenEnd int) bool {
	switch token {
	case "is", "as":
	default:
		return false
	}

	switch nextSQLSignificantByte(text, tokenEnd) {
	case 0:
		return false
	case '$', '\'', '"':
		if token == "as" {
			return false
		}
	}
	return isCreateRoutineHeaderPrefix(currentStatementPrefix)
}
// parseSQLDollarTag parses a PostgreSQL/Kingbase dollar-quote delimiter at the
// start of s and returns it including both '$' characters ("$$" or "$tag$").
// It returns "" when s does not start with a valid delimiter.
//
// The tag between the dollars follows the PostgreSQL scanner's rules: it may
// contain ASCII letters, underscores, digits and non-ASCII bytes (>= 0x80),
// but it must not start with a digit. Rejecting a leading digit keeps
// positional parameters such as "$1" from being taken as a delimiter.
func parseSQLDollarTag(s string) string {
	if len(s) < 2 || s[0] != '$' {
		return ""
	}
	for i := 1; i < len(s); i++ {
		c := s[i]
		switch {
		case c == '$':
			return s[:i+1]
		case c == '_', 'a' <= c && c <= 'z', 'A' <= c && c <= 'Z', c >= 0x80:
			// Valid at any position in the tag.
		case '0' <= c && c <= '9':
			if i == 1 {
				return ""
			}
		default:
			return ""
		}
	}
	return ""
}