Files
MyGoNavi/internal/app/sql_split_stream.go
Syngnat 872b089b15 ️ perf(sql-import): 优化 SQL 文件流式导入性能
- 使用批量执行减少大 SQL 文件导入的数据库往返

- 引入独立导入会话,保留导入过程中的会话状态

- 批量失败时回滚并降级逐条执行,避免中断后续导入

- 补充 SQL 文件导入与流式拆分回归测试

Refs #487
2026-05-23 12:58:38 +08:00

260 lines
5.3 KiB
Go

package app
import (
"io"
"strings"
)
// sqlStreamSplitter 是一个流式 SQL 语句拆分器,适用于处理大文件。
// 调用方通过 Feed(chunk) 逐块喂入数据,通过 Flush() 获取最后一条残余语句。
// 内部维护与 splitSQLStatements 完全一致的状态机逻辑。
type sqlStreamSplitter struct {
cur strings.Builder
pending string
inSingle bool
inDouble bool
inBacktick bool
escaped bool
inLineComment bool
inBlockComment bool
dollarTag string
}
// Feed 将一个 chunk 喂入拆分器,返回在此 chunk 中完成的 SQL 语句列表。
func (s *sqlStreamSplitter) Feed(chunk []byte) []string {
var statements []string
text := s.pending + string(chunk)
s.pending = ""
for i := 0; i < len(text); i++ {
ch := text[i]
next := byte(0)
if i+1 < len(text) {
next = text[i+1]
}
// 行注释
if s.inLineComment {
if ch == '\n' {
s.inLineComment = false
}
s.cur.WriteByte(ch)
continue
}
// 块注释
if s.inBlockComment {
if ch == '*' && i+1 >= len(text) {
s.pending = text[i:]
break
}
s.cur.WriteByte(ch)
if ch == '*' && next == '/' {
s.cur.WriteByte('/')
i++
s.inBlockComment = false
}
continue
}
// Dollar-quoting
if s.dollarTag != "" {
if strings.HasPrefix(text[i:], s.dollarTag) {
s.cur.WriteString(s.dollarTag)
i += len(s.dollarTag) - 1
s.dollarTag = ""
} else if ch == '$' && len(text[i:]) < len(s.dollarTag) && strings.HasPrefix(s.dollarTag, text[i:]) {
s.pending = text[i:]
break
} else {
s.cur.WriteByte(ch)
}
continue
}
// 转义字符
if s.escaped {
s.escaped = false
s.cur.WriteByte(ch)
continue
}
if (s.inSingle || s.inDouble) && ch == '\\' {
s.escaped = true
s.cur.WriteByte(ch)
continue
}
// 字符串开闭
if !s.inDouble && !s.inBacktick && ch == '\'' {
if s.inSingle && i+1 >= len(text) {
s.pending = text[i:]
break
}
if s.inSingle && next == '\'' {
// SQL 标准转义:两个连续单引号
s.cur.WriteByte(ch)
s.cur.WriteByte(next)
i++
continue
}
s.inSingle = !s.inSingle
s.cur.WriteByte(ch)
continue
}
if !s.inSingle && !s.inBacktick && ch == '"' {
s.inDouble = !s.inDouble
s.cur.WriteByte(ch)
continue
}
if !s.inSingle && !s.inDouble && ch == '`' {
s.inBacktick = !s.inBacktick
s.cur.WriteByte(ch)
continue
}
// 在引号/反引号内部不做任何判断
if s.inSingle || s.inDouble || s.inBacktick {
s.cur.WriteByte(ch)
continue
}
// 行注释开始
if ch == '-' && i+1 >= len(text) {
s.pending = text[i:]
break
}
if ch == '-' && next == '-' {
s.inLineComment = true
s.cur.WriteByte(ch)
continue
}
if ch == '#' {
s.inLineComment = true
s.cur.WriteByte(ch)
continue
}
// 块注释开始
if ch == '/' && i+1 >= len(text) {
s.pending = text[i:]
break
}
if ch == '/' && next == '*' {
s.inBlockComment = true
s.cur.WriteString("/*")
i++
continue
}
// Dollar-quoting 开始
if ch == '$' {
if tag := parseSQLDollarTag(text[i:]); tag != "" {
s.dollarTag = tag
s.cur.WriteString(tag)
i += len(tag) - 1
continue
}
if isIncompleteSQLDollarTag(text[i:]) {
s.pending = text[i:]
break
}
}
// 分号分隔
if ch == ';' {
stmt := strings.TrimSpace(s.cur.String())
if stmt != "" {
statements = append(statements, stmt)
}
s.cur.Reset()
continue
}
// 全角分号
if ch == 0xEF && i+2 >= len(text) {
s.pending = text[i:]
break
}
if ch == 0xEF && i+2 < len(text) && text[i+1] == 0xBC && text[i+2] == 0x9B {
stmt := strings.TrimSpace(s.cur.String())
if stmt != "" {
statements = append(statements, stmt)
}
s.cur.Reset()
i += 2
continue
}
s.cur.WriteByte(ch)
}
return statements
}
// Flush 返回缓冲区中剩余的不完整语句(文件结束时调用)。
func (s *sqlStreamSplitter) Flush() string {
if s.pending != "" {
s.cur.WriteString(s.pending)
s.pending = ""
}
stmt := strings.TrimSpace(s.cur.String())
s.cur.Reset()
return stmt
}
func isIncompleteSQLDollarTag(s string) bool {
if len(s) == 0 || s[0] != '$' {
return false
}
for i := 1; i < len(s); i++ {
c := s[i]
if c == '$' {
return false
}
if !((c >= 'a' && c <= 'z') || (c >= 'A' && c <= 'Z') || (c >= '0' && c <= '9') || c == '_') {
return false
}
}
return true
}
// streamSQLFile 从 reader 中流式读取 SQL 并逐条回调。
// onStatement 返回 error 时停止读取并返回该 error。
// 返回总处理语句数和可能的错误。
func streamSQLFile(reader io.Reader, onStatement func(index int, stmt string) error) (int, error) {
splitter := &sqlStreamSplitter{}
buffer := make([]byte, 256*1024)
count := 0
for {
n, err := reader.Read(buffer)
if n > 0 {
stmts := splitter.Feed(buffer[:n])
for _, stmt := range stmts {
if err := onStatement(count, stmt); err != nil {
return count, err
}
count++
}
}
if err == io.EOF {
break
}
if err != nil {
return count, err
}
if n == 0 {
continue
}
}
// 处理文件末尾不以分号结尾的最后一条语句
if last := splitter.Flush(); last != "" {
if err := onStatement(count, last); err != nil {
return count, err
}
count++
}
return count, nil
}