package app import ( "strings" "unicode" ) func leadingSQLKeyword(query string) string { text := strings.TrimSpace(query) for len(text) > 0 { trimmed := strings.TrimLeft(text, " \t\r\n") if trimmed == "" { return "" } text = trimmed switch { case strings.HasPrefix(text, "--"): if idx := strings.IndexByte(text, '\n'); idx >= 0 { text = text[idx+1:] continue } return "" case strings.HasPrefix(text, "#"): if idx := strings.IndexByte(text, '\n'); idx >= 0 { text = text[idx+1:] continue } return "" case strings.HasPrefix(text, "/*"): if idx := strings.Index(text, "*/"); idx >= 0 { text = text[idx+2:] continue } return "" } break } if text == "" { return "" } for i, r := range text { if unicode.IsLetter(r) || unicode.IsDigit(r) || r == '_' { continue } if i == 0 { return "" } return strings.ToLower(text[:i]) } return strings.ToLower(text) } func sqlDataOperationKeyword(query string) string { keyword, _ := sqlDataOperationInfo(query) return keyword } func sqlDataOperationInfo(query string) (keyword string, withHasWrite bool) { keyword, keywordEnd := nextSQLKeyword(query, 0) if keyword != "with" { return keyword, false } if withKeyword, hasWrite, ok := sqlKeywordAfterLeadingWith(query, keywordEnd); ok { return withKeyword, hasWrite } return keyword, false } func nextSQLKeyword(text string, start int) (string, int) { pos := skipSQLTrivia(text, start) if pos >= len(text) || !isSQLKeywordByte(text[pos]) { return "", pos } end := pos + 1 for end < len(text) && isSQLKeywordByte(text[end]) { end++ } return strings.ToLower(text[pos:end]), end } func skipSQLTrivia(text string, start int) int { pos := start for pos < len(text) { switch { case text[pos] == ' ' || text[pos] == '\t' || text[pos] == '\r' || text[pos] == '\n' || text[pos] == '\f': pos++ case strings.HasPrefix(text[pos:], "--"): next := strings.IndexByte(text[pos:], '\n') if next < 0 { return len(text) } pos += next + 1 case strings.HasPrefix(text[pos:], "#"): next := strings.IndexByte(text[pos:], '\n') if next < 0 { return len(text) } pos += next + 1 case strings.HasPrefix(text[pos:], "/*"): end := strings.Index(text[pos+2:], "*/") if end < 0 { return len(text) } pos += end + 4 default: return pos } } return pos } func sqlKeywordAfterLeadingWith(text string, start int) (string, bool, bool) { pos := skipSQLTrivia(text, start) hasWriteCTE := false if keyword, end := nextSQLKeyword(text, pos); keyword == "recursive" { pos = end } for { pos = skipSQLTrivia(text, pos) next, ok := skipSQLIdentifierToken(text, pos) if !ok { return "", hasWriteCTE, false } pos = skipSQLTrivia(text, next) if pos < len(text) && text[pos] == '(' { next = skipBalancedSQLParens(text, pos) if next < 0 { return "", hasWriteCTE, false } pos = skipSQLTrivia(text, next) } asEnd := findTopLevelSQLKeyword(text, pos, "as") if asEnd < 0 { return "", hasWriteCTE, false } pos = skipSQLTrivia(text, asEnd) if keyword, end := nextSQLKeyword(text, pos); keyword == "not" { if nextKeyword, nextEnd := nextSQLKeyword(text, end); nextKeyword == "materialized" { pos = nextEnd } } else if keyword == "materialized" { pos = end } pos = skipSQLTrivia(text, pos) if pos >= len(text) || text[pos] != '(' { return "", hasWriteCTE, false } cteBodyStart := pos + 1 next = skipBalancedSQLParens(text, pos) if next < 0 { return "", hasWriteCTE, false } cteBodyEnd := next - 1 if cteBodyEnd >= cteBodyStart { bodyKeyword, bodyHasWrite := sqlDataOperationInfo(text[cteBodyStart:cteBodyEnd]) if bodyHasWrite || isSQLDataWriteKeyword(bodyKeyword) { hasWriteCTE = true } } pos = skipSQLTrivia(text, next) if pos < len(text) && text[pos] == ',' { pos++ continue } keyword, _ := nextSQLKeyword(text, pos) return keyword, hasWriteCTE, keyword != "" } } func findTopLevelSQLKeyword(text string, start int, want string) int { depth := 0 for pos := start; pos < len(text); { if next, ok := skipSQLQuotedOrComment(text, pos); ok { pos = next continue } switch text[pos] { case '(': depth++ pos++ case ')': if depth > 0 { depth-- } pos++ default: if depth == 0 && isSQLKeywordByte(text[pos]) { end := pos + 1 for end < len(text) && isSQLKeywordByte(text[end]) { end++ } if strings.EqualFold(text[pos:end], want) { return end } pos = end continue } pos++ } } return -1 } func skipSQLIdentifierToken(text string, start int) (int, bool) { if start >= len(text) { return start, false } switch text[start] { case '"', '`': next := skipSQLDelimited(text, start, text[start]) return next, next > start case '[': next := strings.IndexByte(text[start+1:], ']') if next < 0 { return len(text), true } return start + next + 2, true default: if !isSQLKeywordByte(text[start]) { return start, false } end := start + 1 for end < len(text) && isSQLKeywordByte(text[end]) { end++ } return end, true } } func skipBalancedSQLParens(text string, start int) int { if start >= len(text) || text[start] != '(' { return -1 } depth := 0 for pos := start; pos < len(text); { if next, ok := skipSQLQuotedOrComment(text, pos); ok { pos = next continue } switch text[pos] { case '(': depth++ pos++ case ')': depth-- pos++ if depth == 0 { return pos } default: pos++ } } return -1 } func skipSQLQuotedOrComment(text string, start int) (int, bool) { if start >= len(text) { return start, false } switch { case strings.HasPrefix(text[start:], "--"): next := strings.IndexByte(text[start:], '\n') if next < 0 { return len(text), true } return start + next + 1, true case strings.HasPrefix(text[start:], "#"): next := strings.IndexByte(text[start:], '\n') if next < 0 { return len(text), true } return start + next + 1, true case strings.HasPrefix(text[start:], "/*"): end := strings.Index(text[start+2:], "*/") if end < 0 { return len(text), true } return start + end + 4, true case text[start] == '\'' || text[start] == '"' || text[start] == '`': return skipSQLDelimited(text, start, text[start]), true case text[start] == '[': next := strings.IndexByte(text[start+1:], ']') if next < 0 { return len(text), true } return start + next + 2, true default: if tag, ok := sqlDollarQuoteTag(text, start); ok { end := strings.Index(text[start+len(tag):], tag) if end < 0 { return len(text), true } return start + len(tag) + end + len(tag), true } return start, false } } func skipSQLDelimited(text string, start int, delimiter byte) int { pos := start + 1 for pos < len(text) { if text[pos] == delimiter { if pos+1 < len(text) && text[pos+1] == delimiter { pos += 2 continue } return pos + 1 } pos++ } return len(text) } func sqlDollarQuoteTag(text string, start int) (string, bool) { if start >= len(text) || text[start] != '$' { return "", false } end := start + 1 for end < len(text) && (isSQLKeywordByte(text[end]) || text[end] == '-') { end++ } if end < len(text) && text[end] == '$' { return text[start : end+1], true } return "", false } func isSQLKeywordByte(ch byte) bool { return (ch >= 'a' && ch <= 'z') || (ch >= 'A' && ch <= 'Z') || (ch >= '0' && ch <= '9') || ch == '_' } func isReadOnlySQLQuery(dbType string, query string) bool { if strings.ToLower(strings.TrimSpace(dbType)) == "mongodb" && strings.HasPrefix(strings.TrimSpace(query), "{") { return true } keyword, withHasWrite := sqlDataOperationInfo(query) if withHasWrite { return false } switch keyword { case "select", "with", "show", "describe", "desc", "explain", "pragma", "values": return true default: return false } } func isBatchableWriteSQLStatement(dbType string, query string) bool { if isReadOnlySQLQuery(dbType, query) { return false } keyword, withHasWrite := sqlDataOperationInfo(query) if withHasWrite { return true } return isSQLDataWriteKeyword(keyword) } func isSQLDataWriteKeyword(keyword string) bool { switch keyword { case "insert", "update", "delete", "replace", "merge", "upsert": return true default: return false } } func sanitizeSQLForPgLike(dbType string, query string) string { normalizedType := strings.ToLower(strings.TrimSpace(dbType)) switch normalizedType { case "postgresql": normalizedType = "postgres" case "kingbase8", "kingbasees", "kingbasev8": normalizedType = "kingbase" } switch normalizedType { case "postgres", "kingbase", "highgo", "vastbase", "opengauss": // 有些情况下会出现多层重复引用(例如 """"schema"""" 或 ""schema"""),单次修复不一定收敛。 // 这里做有限次数的迭代,直到输出不再变化。 out := query for i := 0; i < 3; i++ { fixed := fixBrokenDoubleDoubleQuotedIdent(out) if fixed == out { break } out = fixed } return out default: return query } } // fixBrokenDoubleDoubleQuotedIdent fixes accidental identifiers like: // // SELECT * FROM ""schema"".""table"" // // which can be produced when a quoted identifier gets wrapped by quotes again. // // It is intentionally conservative: // - only runs outside strings/comments/dollar-quoted blocks // - does not touch valid escaped-quote sequences inside quoted identifiers (e.g. "a""b") func fixBrokenDoubleDoubleQuotedIdent(query string) string { if !strings.Contains(query, `""`) { return query } var b strings.Builder b.Grow(len(query)) inSingle := false inDoubleIdent := false inLineComment := false inBlockComment := false dollarTag := "" for i := 0; i < len(query); i++ { ch := query[i] next := byte(0) if i+1 < len(query) { next = query[i+1] } if inLineComment { b.WriteByte(ch) if ch == '\n' { inLineComment = false } continue } if inBlockComment { b.WriteByte(ch) if ch == '*' && next == '/' { b.WriteByte('/') i++ inBlockComment = false } continue } if dollarTag != "" { if strings.HasPrefix(query[i:], dollarTag) { b.WriteString(dollarTag) i += len(dollarTag) - 1 dollarTag = "" continue } b.WriteByte(ch) continue } if inSingle { b.WriteByte(ch) if ch == '\'' { // escaped single quote if next == '\'' { b.WriteByte('\'') i++ continue } inSingle = false } continue } if inDoubleIdent { b.WriteByte(ch) if ch == '"' { // escaped quote inside identifier if next == '"' { b.WriteByte('"') i++ continue } inDoubleIdent = false } continue } // --- Outside of all string/comment blocks --- if ch == '-' && next == '-' { b.WriteByte(ch) b.WriteByte('-') i++ inLineComment = true continue } if ch == '/' && next == '*' { b.WriteByte(ch) b.WriteByte('*') i++ inBlockComment = true continue } if ch == '\'' { b.WriteByte(ch) inSingle = true continue } if ch == '$' { if tag := parseDollarTag(query[i:]); tag != "" { b.WriteString(tag) i += len(tag) - 1 dollarTag = tag continue } } if ch == '"' { // Fix: ""ident"" -> "ident" (only when it looks like a plain identifier) // Also handle variants like ""ident""" / """"ident"""" (extra quotes at either side). if next == '"' { if replacement, advance, ok := tryFixDoubleDoubleQuotedIdent(query, i); ok { b.WriteString(replacement) i = advance - 1 continue } } b.WriteByte(ch) inDoubleIdent = true continue } b.WriteByte(ch) } return b.String() } func tryFixDoubleDoubleQuotedIdent(query string, start int) (replacement string, advance int, ok bool) { // start points at the first quote of a broken identifier, usually like: // ""ident"" / ""ident""" / """"ident"""" if start < 0 || start+1 >= len(query) { return "", 0, false } if query[start] != '"' || query[start+1] != '"' { return "", 0, false } if start > 0 && query[start-1] == '"' { return "", 0, false } runLen := 0 for start+runLen < len(query) && query[start+runLen] == '"' { runLen++ } if runLen < 2 || runLen%2 == 1 { // Odd run (e.g. """...) can be a valid quoted identifier with escaped quotes. return "", 0, false } contentStart := start + runLen j := contentStart for j < len(query) { if query[j] == '"' { endRunLen := 0 for j+endRunLen < len(query) && query[j+endRunLen] == '"' { endRunLen++ } if endRunLen >= 2 { content := strings.TrimSpace(query[contentStart:j]) if looksLikeIdentifierContent(content) { return `"` + content + `"`, j + endRunLen, true } return "", 0, false } } // Fast abort: identifier-like content should not span lines. if query[j] == '\n' || query[j] == '\r' { break } j++ } return "", 0, false } func looksLikeIdentifierContent(s string) bool { if strings.TrimSpace(s) == "" { return false } for _, r := range s { if r == '_' || r == '$' || r == '-' || unicode.IsLetter(r) || unicode.IsDigit(r) { continue } return false } return true } func parseDollarTag(s string) string { // Match: $tag$ where tag is [A-Za-z0-9_]* (can be empty => $$) if len(s) < 2 || s[0] != '$' { return "" } for i := 1; i < len(s); i++ { c := s[i] if c == '$' { return s[:i+1] } if !((c >= 'a' && c <= 'z') || (c >= 'A' && c <= 'Z') || (c >= '0' && c <= '9') || c == '_') { return "" } } return "" }