diff --git a/internal/app/methods_file.go b/internal/app/methods_file.go index 3bfccbd..e0f79c2 100644 --- a/internal/app/methods_file.go +++ b/internal/app/methods_file.go @@ -30,9 +30,53 @@ import ( const minExportQueryTimeout = 5 * time.Minute const minClickHouseExportQueryTimeout = 2 * time.Hour const maxSQLFileSizeBytes int64 = 50 * 1024 * 1024 +const sqlFileBatchMaxStatements = 1000 +const sqlFileBatchMaxBytes = 4 * 1024 * 1024 +const sqlFileProgressStatementInterval = 100 +const sqlFileProgressTimeInterval = time.Second var mysqlCreateViewPrefixPattern = regexp.MustCompile(`(?is)^\s*create\s+(?:algorithm\s*=\s*\w+\s+)?(?:definer\s*=\s*(?:` + "`[^`]+`" + `|\S+)\s*@\s*(?:` + "`[^`]+`" + `|\S+)\s+)?(?:sql\s+security\s+(?:definer|invoker)\s+)?view\s+`) +type sqlFileExecutionProgress struct { + Status string + Executed int + Failed int + Total int + BytesRead int64 + CurrentSQL string + Error string +} + +type sqlFileExecutionOptions struct { + DBType string + BatchMaxStatements int + BatchMaxBytes int + OnProgress func(sqlFileExecutionProgress) +} + +type sqlFileExecutionResult struct { + Executed int + Failed int + Errors []string +} + +type sqlFilePendingStatement struct { + Index int + SQL string +} + +type sqlFileStatementExecer interface { + Exec(query string) (int64, error) +} + +type sqlFileContextStatementExecer interface { + ExecContext(ctx context.Context, query string) (int64, error) +} + +type sqlFileBatchStatementExecer interface { + ExecBatchContext(ctx context.Context, query string) (int64, error) +} + type SQLDirectoryEntry struct { Name string `json:"name"` Path string `json:"path"` @@ -242,6 +286,313 @@ func (a *App) WriteSQLFile(filePath string, content string) connection.QueryResu return writeSQLFileByPath(filePath, content) } +func normalizeSQLFileExecutionOptions(options sqlFileExecutionOptions) sqlFileExecutionOptions { + if options.BatchMaxStatements <= 0 { + options.BatchMaxStatements = sqlFileBatchMaxStatements + } + if options.BatchMaxBytes <= 0 { + options.BatchMaxBytes = sqlFileBatchMaxBytes + } + return options +} + +func appendSQLFileBatchStatement(batch []sqlFilePendingStatement, index int, stmt string) []sqlFilePendingStatement { + return append(batch, sqlFilePendingStatement{ + Index: index, + SQL: stmt, + }) +} + +func joinSQLFileBatchStatements(batch []sqlFilePendingStatement) string { + if len(batch) == 0 { + return "" + } + var builder strings.Builder + for i, item := range batch { + if i > 0 { + builder.WriteString(";\n") + } + builder.WriteString(item.SQL) + } + return builder.String() +} + +func sqlFileStatementSnippet(stmt string, maxLen int) string { + snippet := strings.TrimSpace(stmt) + if maxLen > 0 && len(snippet) > maxLen { + return snippet[:maxLen] + "..." + } + return snippet +} + +func execSQLFileStatement(ctx context.Context, execer sqlFileStatementExecer, stmt string) (int64, error) { + if ctxErr := ctx.Err(); ctxErr != nil { + return 0, ctxErr + } + if e, ok := execer.(sqlFileContextStatementExecer); ok { + return e.ExecContext(ctx, stmt) + } + return execer.Exec(stmt) +} + +func isSQLFileBatchableWriteStatement(dbType string, stmt string) bool { + if isReadOnlySQLQuery(dbType, stmt) { + return false + } + switch leadingSQLKeyword(stmt) { + case "insert", "update", "delete", "replace", "merge", "upsert": + return true + default: + return false + } +} + +func sqlFileBatchTransactionSQL(dbType string) (beginSQL string, commitSQL string, rollbackSQL string, ok bool) { + switch strings.ToLower(strings.TrimSpace(dbType)) { + case "mysql", "mariadb", "diros", "starrocks", "sphinx", "oceanbase": + return "START TRANSACTION", "COMMIT", "ROLLBACK", true + case "sqlserver": + return "BEGIN TRANSACTION", "COMMIT TRANSACTION", "ROLLBACK TRANSACTION", true + case "postgres", "kingbase", "highgo", "vastbase", "opengauss", "sqlite", "duckdb", "iris": + return "BEGIN", "COMMIT", "ROLLBACK", true + default: + return "", "", "", false + } +} + +func updateSQLFileTransactionState(inTransaction bool, stmt string) bool { + switch leadingSQLKeyword(stmt) { + case "begin": + return true + case "start": + return strings.Contains(strings.ToLower(stmt), "transaction") + case "commit": + return false + case "rollback": + lower := strings.ToLower(stmt) + if strings.Contains(lower, " rollback to ") || strings.Contains(lower, "rollback to ") { + return inTransaction + } + return false + default: + return inTransaction + } +} + +func executeSQLFileBatch(ctx context.Context, execer sqlFileStatementExecer, batcher sqlFileBatchStatementExecer, dbType string, batchSQL string, useTransaction bool) (bool, error) { + if !useTransaction { + _, err := batcher.ExecBatchContext(ctx, batchSQL) + return false, err + } + + beginSQL, commitSQL, rollbackSQL, ok := sqlFileBatchTransactionSQL(dbType) + if !ok { + _, err := batcher.ExecBatchContext(ctx, batchSQL) + return false, err + } + + if _, err := execSQLFileStatement(ctx, execer, beginSQL); err != nil { + return true, err + } + if _, err := batcher.ExecBatchContext(ctx, batchSQL); err != nil { + if _, rollbackErr := execSQLFileStatement(ctx, execer, rollbackSQL); rollbackErr != nil { + return false, fmt.Errorf("批量执行失败: %v;回滚失败: %w", err, rollbackErr) + } + return true, err + } + if _, err := execSQLFileStatement(ctx, execer, commitSQL); err != nil { + _, _ = execSQLFileStatement(ctx, execer, rollbackSQL) + return false, err + } + return false, nil +} + +func executeSQLFileStream(ctx context.Context, dbInst db.Database, reader io.Reader, options sqlFileExecutionOptions, bytesRead func() int64) (sqlFileExecutionResult, error) { + options = normalizeSQLFileExecutionOptions(options) + var result sqlFileExecutionResult + var batch []sqlFilePendingStatement + var batchBytes int + var lastProgressAt time.Time + var inUserTransaction bool + var useTransactionalBatch bool + execer := sqlFileStatementExecer(dbInst) + batcher, supportsBatch := dbInst.(sqlFileBatchStatementExecer) + if provider, ok := dbInst.(db.SessionExecerProvider); ok { + sessionExecer, err := provider.OpenSessionExecer(ctx) + if err != nil { + return result, err + } + defer sessionExecer.Close() + execer = sessionExecer + if supportsBatch { + var ok bool + batcher, ok = sessionExecer.(sqlFileBatchStatementExecer) + supportsBatch = ok + } + useTransactionalBatch = supportsBatch + } + + readBytes := func() int64 { + if bytesRead == nil { + return 0 + } + return bytesRead() + } + + emitProgress := func(currentSQL string) { + if options.OnProgress == nil { + return + } + total := result.Executed + result.Failed + options.OnProgress(sqlFileExecutionProgress{ + Status: "running", + Executed: result.Executed, + Failed: result.Failed, + Total: total, + BytesRead: readBytes(), + CurrentSQL: currentSQL, + }) + lastProgressAt = time.Now() + } + + shouldEmitProgress := func() bool { + total := result.Executed + result.Failed + if total <= 10 { + return true + } + if total%sqlFileProgressStatementInterval == 0 { + return true + } + return !lastProgressAt.IsZero() && time.Since(lastProgressAt) >= sqlFileProgressTimeInterval + } + + recordError := func(index int, stmt string, err error) { + result.Failed++ + errLog := fmt.Sprintf("第 %d 条语句执行失败: %v\n SQL: %s", index+1, err, sqlFileStatementSnippet(stmt, 200)) + result.Errors = append(result.Errors, errLog) + logger.Warnf("ExecuteSQLFile %s", errLog) + } + + executeSingle := func(item sqlFilePendingStatement) error { + if _, err := execSQLFileStatement(ctx, execer, item.SQL); err != nil { + if ctx.Err() != nil { + return fmt.Errorf("已取消") + } + recordError(item.Index, item.SQL, err) + } else { + result.Executed++ + } + if shouldEmitProgress() { + emitProgress(sqlFileStatementSnippet(item.SQL, 100)) + } + return nil + } + + executeBatchSequentially := func(items []sqlFilePendingStatement) error { + for _, item := range items { + if err := executeSingle(item); err != nil { + return err + } + } + return nil + } + + flushBatch := func() error { + if len(batch) == 0 { + return nil + } + select { + case <-ctx.Done(): + return fmt.Errorf("已取消") + default: + } + + startIndex := batch[0].Index + batchSQL := joinSQLFileBatchStatements(batch) + canFallback, err := executeSQLFileBatch(ctx, execer, batcher, options.DBType, batchSQL, useTransactionalBatch) + if err != nil { + logger.Warnf("ExecuteSQLFile 批量执行 %d 条语句失败,将降级逐条执行:第 %d 条起: %v", len(batch), startIndex+1, err) + pending := append([]sqlFilePendingStatement(nil), batch...) + batch = batch[:0] + batchBytes = 0 + if !canFallback { + return fmt.Errorf("第 %d 条起的批量语句执行失败: %w", startIndex+1, err) + } + return executeBatchSequentially(pending) + } + result.Executed += len(batch) + if shouldEmitProgress() { + emitProgress(sqlFileStatementSnippet(batch[len(batch)-1].SQL, 100)) + } + batch = batch[:0] + batchBytes = 0 + return nil + } + + _, streamErr := streamSQLFile(reader, func(index int, stmt string) error { + select { + case <-ctx.Done(): + return fmt.Errorf("已取消") + default: + } + + stmt = strings.TrimSpace(stmt) + if stmt == "" { + return nil + } + + if supportsBatch && !inUserTransaction && isSQLFileBatchableWriteStatement(options.DBType, stmt) { + stmtBytes := len(stmt) + if len(batch) > 0 && (len(batch) >= options.BatchMaxStatements || batchBytes+2+stmtBytes > options.BatchMaxBytes) { + if err := flushBatch(); err != nil { + return err + } + } + if stmtBytes > options.BatchMaxBytes { + if err := flushBatch(); err != nil { + return err + } + canFallback, err := executeSQLFileBatch(ctx, execer, batcher, options.DBType, stmt, useTransactionalBatch) + if err != nil { + logger.Warnf("ExecuteSQLFile 超大语句批量执行失败,将降级单条执行:第 %d 条: %v", index+1, err) + if !canFallback { + return fmt.Errorf("第 %d 条语句执行失败: %w", index+1, err) + } + return executeSingle(sqlFilePendingStatement{Index: index, SQL: stmt}) + } + result.Executed++ + if shouldEmitProgress() { + emitProgress(sqlFileStatementSnippet(stmt, 100)) + } + return nil + } + batch = appendSQLFileBatchStatement(batch, index, stmt) + if batchBytes == 0 { + batchBytes = stmtBytes + } else { + batchBytes += 2 + stmtBytes + } + return nil + } + + if err := flushBatch(); err != nil { + return err + } + if err := executeSingle(sqlFilePendingStatement{Index: index, SQL: stmt}); err != nil { + return err + } + inUserTransaction = updateSQLFileTransactionState(inUserTransaction, stmt) + return nil + }) + if streamErr != nil { + return result, streamErr + } + if err := flushBatch(); err != nil { + return result, err + } + return result, nil +} + // ExecuteSQLFile 在后端流式读取并执行大 SQL 文件,通过事件推送进度。 // 前端通过 EventsOn("sqlfile:progress", ...) 监听进度。 func (a *App) ExecuteSQLFile(config connection.ConnectionConfig, dbName string, filePath string, jobID string) connection.QueryResult { @@ -317,48 +668,28 @@ func (a *App) ExecuteSQLFile(config connection.ConnectionConfig, dbName string, // 使用 countingReader 追踪已读取字节数 cr := &countingReader{r: f} - var executedCount int - var failedCount int - var errorLogs []string startTime := time.Now() - - _, streamErr := streamSQLFile(cr, func(index int, stmt string) error { - // 检查是否已取消 - select { - case <-ctx.Done(): - return fmt.Errorf("已取消") - default: - } - - // 执行语句 - _, execErr := dbInst.Exec(stmt) - if execErr != nil { - failedCount++ - snippet := stmt - if len(snippet) > 200 { - snippet = snippet[:200] + "..." - } - errLog := fmt.Sprintf("第 %d 条语句执行失败: %v\n SQL: %s", index+1, execErr, snippet) - errorLogs = append(errorLogs, errLog) - logger.Warnf("ExecuteSQLFile %s", errLog) - } else { - executedCount++ - } - - // 每条语句执行后推送进度(但限频:每 100 条或每秒推一次) - total := executedCount + failedCount - if total%100 == 0 || total <= 10 { - snippet := stmt - if len(snippet) > 100 { - snippet = snippet[:100] + "..." - } - emitProgress("running", executedCount, failedCount, total, cr.n, snippet, "") - } - - return nil + execResult, streamErr := executeSQLFileStream(ctx, dbInst, cr, sqlFileExecutionOptions{ + DBType: resolveDDLDBType(runConfig), + OnProgress: func(progress sqlFileExecutionProgress) { + emitProgress( + progress.Status, + progress.Executed, + progress.Failed, + progress.Total, + progress.BytesRead, + progress.CurrentSQL, + progress.Error, + ) + }, + }, func() int64 { + return cr.n }) duration := time.Since(startTime) + executedCount := execResult.Executed + failedCount := execResult.Failed + errorLogs := execResult.Errors if streamErr != nil && streamErr.Error() == "已取消" { emitProgress("cancelled", executedCount, failedCount, executedCount+failedCount, cr.n, "", "用户取消执行") diff --git a/internal/app/methods_file_sql_execution_test.go b/internal/app/methods_file_sql_execution_test.go new file mode 100644 index 0000000..928010f --- /dev/null +++ b/internal/app/methods_file_sql_execution_test.go @@ -0,0 +1,324 @@ +package app + +import ( + "context" + "errors" + "fmt" + "io" + "strings" + "testing" + + "GoNavi-Wails/internal/connection" + "GoNavi-Wails/internal/db" +) + +type fakeSQLFileBatchDB struct { + batchCalls int + execCalls int + batchQueries []string + execQueries []string + failBatch bool + failExecSQL string + session *fakeSQLFileSessionDB +} + +func (f *fakeSQLFileBatchDB) Connect(config connection.ConnectionConfig) error { + return nil +} + +func (f *fakeSQLFileBatchDB) Close() error { + return nil +} + +func (f *fakeSQLFileBatchDB) Ping() error { + return nil +} + +func (f *fakeSQLFileBatchDB) Query(query string) ([]map[string]interface{}, []string, error) { + return nil, nil, nil +} + +func (f *fakeSQLFileBatchDB) Exec(query string) (int64, error) { + f.execCalls++ + f.execQueries = append(f.execQueries, query) + if f.failExecSQL != "" && strings.Contains(query, f.failExecSQL) { + return 0, errors.New("exec failed") + } + return 1, nil +} + +func (f *fakeSQLFileBatchDB) ExecBatchContext(ctx context.Context, query string) (int64, error) { + f.batchCalls++ + f.batchQueries = append(f.batchQueries, query) + if f.failBatch { + return 0, errors.New("batch failed") + } + return int64(strings.Count(query, "INSERT")), nil +} + +func (f *fakeSQLFileBatchDB) GetDatabases() ([]string, error) { + return nil, nil +} + +func (f *fakeSQLFileBatchDB) GetTables(dbName string) ([]string, error) { + return nil, nil +} + +func (f *fakeSQLFileBatchDB) GetCreateStatement(dbName, tableName string) (string, error) { + return "", nil +} + +func (f *fakeSQLFileBatchDB) GetColumns(dbName, tableName string) ([]connection.ColumnDefinition, error) { + return nil, nil +} + +func (f *fakeSQLFileBatchDB) GetAllColumns(dbName string) ([]connection.ColumnDefinitionWithTable, error) { + return nil, nil +} + +func (f *fakeSQLFileBatchDB) GetIndexes(dbName, tableName string) ([]connection.IndexDefinition, error) { + return nil, nil +} + +func (f *fakeSQLFileBatchDB) GetForeignKeys(dbName, tableName string) ([]connection.ForeignKeyDefinition, error) { + return nil, nil +} + +func (f *fakeSQLFileBatchDB) GetTriggers(dbName, tableName string) ([]connection.TriggerDefinition, error) { + return nil, nil +} + +var _ db.BatchWriteExecer = (*fakeSQLFileBatchDB)(nil) + +func (f *fakeSQLFileBatchDB) OpenSessionExecer(ctx context.Context) (db.StatementExecer, error) { + f.session = &fakeSQLFileSessionDB{parent: f} + return f.session, nil +} + +type fakeSQLFileSessionDB struct { + parent *fakeSQLFileBatchDB + closed bool +} + +func (s *fakeSQLFileSessionDB) Exec(query string) (int64, error) { + return s.ExecContext(context.Background(), query) +} + +func (s *fakeSQLFileSessionDB) ExecContext(ctx context.Context, query string) (int64, error) { + return s.parent.Exec(query) +} + +func (s *fakeSQLFileSessionDB) ExecBatchContext(ctx context.Context, query string) (int64, error) { + return s.parent.ExecBatchContext(ctx, query) +} + +func (s *fakeSQLFileSessionDB) Close() error { + s.closed = true + return nil +} + +func TestExecuteSQLFileStreamBatchesWriteStatements(t *testing.T) { + fakeDB := &fakeSQLFileBatchDB{} + input := strings.Join([]string{ + "INSERT INTO demo(id) VALUES (1);", + "INSERT INTO demo(id) VALUES (2);", + "INSERT INTO demo(id) VALUES (3);", + }, "\n") + + result, err := executeSQLFileStream(context.Background(), fakeDB, strings.NewReader(input), sqlFileExecutionOptions{ + DBType: "mysql", + BatchMaxStatements: 100, + BatchMaxBytes: 1024, + }, nil) + if err != nil { + t.Fatalf("executeSQLFileStream returned error: %v", err) + } + if result.Executed != 3 || result.Failed != 0 { + t.Fatalf("expected 3 executed and 0 failed, got %#v", result) + } + if fakeDB.batchCalls != 1 { + t.Fatalf("expected one batch call, got %d", fakeDB.batchCalls) + } + if fakeDB.execCalls != 2 { + t.Fatalf("expected transaction wrapper exec calls only, got %d", fakeDB.execCalls) + } + if fakeDB.execQueries[0] != "START TRANSACTION" || fakeDB.execQueries[1] != "COMMIT" { + t.Fatalf("expected transaction wrapper around batch, got %#v", fakeDB.execQueries) + } + if fakeDB.session == nil || !fakeDB.session.closed { + t.Fatalf("expected SQL file import to use and close an isolated session") + } + if !strings.Contains(fakeDB.batchQueries[0], "INSERT INTO demo(id) VALUES (1);\nINSERT INTO demo(id) VALUES (2)") { + t.Fatalf("expected batched SQL to join statements, got %q", fakeDB.batchQueries[0]) + } +} + +func TestExecuteSQLFileStreamFlushesBatchBeforeReadStatement(t *testing.T) { + fakeDB := &fakeSQLFileBatchDB{} + input := strings.Join([]string{ + "INSERT INTO demo(id) VALUES (1);", + "INSERT INTO demo(id) VALUES (2);", + "SELECT * FROM demo;", + "INSERT INTO demo(id) VALUES (3);", + }, "\n") + + result, err := executeSQLFileStream(context.Background(), fakeDB, strings.NewReader(input), sqlFileExecutionOptions{ + DBType: "mysql", + BatchMaxStatements: 100, + BatchMaxBytes: 1024, + }, nil) + if err != nil { + t.Fatalf("executeSQLFileStream returned error: %v", err) + } + if result.Executed != 4 || result.Failed != 0 { + t.Fatalf("expected 4 executed and 0 failed, got %#v", result) + } + if fakeDB.batchCalls != 2 { + t.Fatalf("expected two batch calls around read statement, got %d", fakeDB.batchCalls) + } + if fakeDB.execCalls != 5 { + t.Fatalf("expected transaction wrappers plus one read exec call, got %d", fakeDB.execCalls) + } + if fakeDB.execQueries[2] != "SELECT * FROM demo" { + t.Fatalf("expected read statement to execute outside batch, got %#v", fakeDB.execQueries) + } +} + +func TestExecuteSQLFileStreamFallsBackToSequentialWhenBatchFails(t *testing.T) { + fakeDB := &fakeSQLFileBatchDB{failBatch: true, failExecSQL: "VALUES (2)"} + input := strings.Join([]string{ + "INSERT INTO demo(id) VALUES (1);", + "INSERT INTO demo(id) VALUES (2);", + "INSERT INTO demo(id) VALUES (3);", + }, "\n") + + result, err := executeSQLFileStream(context.Background(), fakeDB, strings.NewReader(input), sqlFileExecutionOptions{ + DBType: "mysql", + BatchMaxStatements: 100, + BatchMaxBytes: 1024, + }, nil) + if err != nil { + t.Fatalf("executeSQLFileStream returned error: %v", err) + } + if result.Executed != 2 || result.Failed != 1 { + t.Fatalf("expected 2 executed and 1 failed, got %#v", result) + } + if fakeDB.batchCalls != 1 { + t.Fatalf("expected one failed batch attempt, got %d", fakeDB.batchCalls) + } + if fakeDB.execCalls != 5 { + t.Fatalf("expected transaction wrapper plus 3 sequential exec calls, got %d", fakeDB.execCalls) + } + if fakeDB.execQueries[0] != "START TRANSACTION" || fakeDB.execQueries[1] != "ROLLBACK" { + t.Fatalf("expected failed batch to roll back before sequential fallback, got %#v", fakeDB.execQueries) + } + if len(result.Errors) != 1 || !strings.Contains(result.Errors[0], "第 2 条语句执行失败") { + t.Fatalf("expected per-statement error for second statement, got %#v", result.Errors) + } +} + +func TestExecuteSQLFileStreamDoesNotBatchSessionControlStatements(t *testing.T) { + fakeDB := &fakeSQLFileBatchDB{} + input := strings.Join([]string{ + "SET FOREIGN_KEY_CHECKS=0;", + "INSERT INTO demo(id) VALUES (1);", + "INSERT INTO demo(id) VALUES (2);", + "CREATE TABLE demo2(id INT);", + "INSERT INTO demo2(id) VALUES (3);", + }, "\n") + + result, err := executeSQLFileStream(context.Background(), fakeDB, strings.NewReader(input), sqlFileExecutionOptions{ + DBType: "mysql", + BatchMaxStatements: 100, + BatchMaxBytes: 1024, + }, nil) + if err != nil { + t.Fatalf("executeSQLFileStream returned error: %v", err) + } + if result.Executed != 5 || result.Failed != 0 { + t.Fatalf("expected 5 executed and 0 failed, got %#v", result) + } + if fakeDB.batchCalls != 2 { + t.Fatalf("expected two DML batch calls split by control/DDL statements, got %d", fakeDB.batchCalls) + } + if fakeDB.execCalls != 6 { + t.Fatalf("expected SET, CREATE, and transaction wrappers to execute sequentially, got %d", fakeDB.execCalls) + } + if fakeDB.execQueries[0] != "SET FOREIGN_KEY_CHECKS=0" || fakeDB.execQueries[3] != "CREATE TABLE demo2(id INT)" { + t.Fatalf("unexpected sequential statements: %#v", fakeDB.execQueries) + } +} + +type chunkedReader struct { + data []byte + step int +} + +func (r *chunkedReader) Read(p []byte) (int, error) { + if len(r.data) == 0 { + return 0, io.EOF + } + n := r.step + if n <= 0 || n > len(r.data) { + n = len(r.data) + } + if n > len(p) { + n = len(p) + } + copy(p, r.data[:n]) + r.data = r.data[n:] + return n, nil +} + +func TestStreamSQLFileHandlesLongSingleLineAcrossChunks(t *testing.T) { + longValue := strings.Repeat("x", 5*1024*1024) + input := fmt.Sprintf("INSERT INTO demo(value) VALUES ('%s');SELECT 1;", longValue) + var statements []string + + count, err := streamSQLFile(&chunkedReader{data: []byte(input), step: 257}, func(index int, stmt string) error { + statements = append(statements, stmt) + return nil + }) + if err != nil { + t.Fatalf("streamSQLFile returned error: %v", err) + } + if count != 2 || len(statements) != 2 { + t.Fatalf("expected 2 statements, got count=%d statements=%d", count, len(statements)) + } + if !strings.HasPrefix(statements[0], "INSERT INTO demo(value)") { + t.Fatalf("expected first statement to be insert, got %.80q", statements[0]) + } + if statements[1] != "SELECT 1" { + t.Fatalf("expected second statement SELECT 1, got %q", statements[1]) + } +} + +func TestStreamSQLFileHandlesSplitTokenBoundaries(t *testing.T) { + input := strings.Join([]string{ + "SELECT 1 -- comment; still comment", + "SELECT 'it''s ok';", + "SELECT $tag$hello;world$tag$;", + "SELECT 2;", + }, "\n") + var statements []string + + count, err := streamSQLFile(&chunkedReader{data: []byte(input), step: 1}, func(index int, stmt string) error { + statements = append(statements, stmt) + return nil + }) + if err != nil { + t.Fatalf("streamSQLFile returned error: %v", err) + } + if count != 3 || len(statements) != 3 { + t.Fatalf("expected 3 statements, got count=%d statements=%#v", count, statements) + } + if statements[0] != "SELECT 1 -- comment; still comment\nSELECT 'it''s ok'" { + t.Fatalf("unexpected first statement: %q", statements[0]) + } + if statements[1] != "SELECT $tag$hello;world$tag$" { + t.Fatalf("unexpected dollar-quoted statement: %q", statements[1]) + } + if statements[2] != "SELECT 2" { + t.Fatalf("unexpected full-width semicolon statement: %q", statements[2]) + } +} diff --git a/internal/app/sql_split_stream.go b/internal/app/sql_split_stream.go index 69a206b..7368494 100644 --- a/internal/app/sql_split_stream.go +++ b/internal/app/sql_split_stream.go @@ -1,7 +1,6 @@ package app import ( - "bufio" "io" "strings" ) @@ -10,20 +9,22 @@ import ( // 调用方通过 Feed(chunk) 逐块喂入数据,通过 Flush() 获取最后一条残余语句。 // 内部维护与 splitSQLStatements 完全一致的状态机逻辑。 type sqlStreamSplitter struct { - cur strings.Builder - inSingle bool - inDouble bool - inBacktick bool - escaped bool + cur strings.Builder + pending string + inSingle bool + inDouble bool + inBacktick bool + escaped bool inLineComment bool inBlockComment bool - dollarTag string + dollarTag string } // Feed 将一个 chunk 喂入拆分器,返回在此 chunk 中完成的 SQL 语句列表。 func (s *sqlStreamSplitter) Feed(chunk []byte) []string { var statements []string - text := string(chunk) + text := s.pending + string(chunk) + s.pending = "" for i := 0; i < len(text); i++ { ch := text[i] @@ -43,6 +44,10 @@ func (s *sqlStreamSplitter) Feed(chunk []byte) []string { // 块注释 if s.inBlockComment { + if ch == '*' && i+1 >= len(text) { + s.pending = text[i:] + break + } s.cur.WriteByte(ch) if ch == '*' && next == '/' { s.cur.WriteByte('/') @@ -58,6 +63,9 @@ func (s *sqlStreamSplitter) Feed(chunk []byte) []string { s.cur.WriteString(s.dollarTag) i += len(s.dollarTag) - 1 s.dollarTag = "" + } else if ch == '$' && len(text[i:]) < len(s.dollarTag) && strings.HasPrefix(s.dollarTag, text[i:]) { + s.pending = text[i:] + break } else { s.cur.WriteByte(ch) } @@ -78,6 +86,10 @@ func (s *sqlStreamSplitter) Feed(chunk []byte) []string { // 字符串开闭 if !s.inDouble && !s.inBacktick && ch == '\'' { + if s.inSingle && i+1 >= len(text) { + s.pending = text[i:] + break + } if s.inSingle && next == '\'' { // SQL 标准转义:两个连续单引号 s.cur.WriteByte(ch) @@ -107,6 +119,10 @@ func (s *sqlStreamSplitter) Feed(chunk []byte) []string { } // 行注释开始 + if ch == '-' && i+1 >= len(text) { + s.pending = text[i:] + break + } if ch == '-' && next == '-' { s.inLineComment = true s.cur.WriteByte(ch) @@ -119,6 +135,10 @@ func (s *sqlStreamSplitter) Feed(chunk []byte) []string { } // 块注释开始 + if ch == '/' && i+1 >= len(text) { + s.pending = text[i:] + break + } if ch == '/' && next == '*' { s.inBlockComment = true s.cur.WriteString("/*") @@ -134,6 +154,10 @@ func (s *sqlStreamSplitter) Feed(chunk []byte) []string { i += len(tag) - 1 continue } + if isIncompleteSQLDollarTag(text[i:]) { + s.pending = text[i:] + break + } } // 分号分隔 @@ -146,6 +170,10 @@ func (s *sqlStreamSplitter) Feed(chunk []byte) []string { continue } // 全角分号 + if ch == 0xEF && i+2 >= len(text) { + s.pending = text[i:] + break + } if ch == 0xEF && i+2 < len(text) && text[i+1] == 0xBC && text[i+2] == 0x9B { stmt := strings.TrimSpace(s.cur.String()) if stmt != "" { @@ -164,37 +192,59 @@ func (s *sqlStreamSplitter) Feed(chunk []byte) []string { // Flush 返回缓冲区中剩余的不完整语句(文件结束时调用)。 func (s *sqlStreamSplitter) Flush() string { + if s.pending != "" { + s.cur.WriteString(s.pending) + s.pending = "" + } stmt := strings.TrimSpace(s.cur.String()) s.cur.Reset() return stmt } +func isIncompleteSQLDollarTag(s string) bool { + if len(s) == 0 || s[0] != '$' { + return false + } + for i := 1; i < len(s); i++ { + c := s[i] + if c == '$' { + return false + } + if !((c >= 'a' && c <= 'z') || (c >= 'A' && c <= 'Z') || (c >= '0' && c <= '9') || c == '_') { + return false + } + } + return true +} + // streamSQLFile 从 reader 中流式读取 SQL 并逐条回调。 // onStatement 返回 error 时停止读取并返回该 error。 // 返回总处理语句数和可能的错误。 func streamSQLFile(reader io.Reader, onStatement func(index int, stmt string) error) (int, error) { splitter := &sqlStreamSplitter{} - scanner := bufio.NewScanner(reader) - // 设置最大 token 为 4MB,处理超长单行 - const maxLineSize = 4 * 1024 * 1024 - scanner.Buffer(make([]byte, 0, 64*1024), maxLineSize) + buffer := make([]byte, 256*1024) count := 0 - for scanner.Scan() { - line := scanner.Bytes() - // 保持换行符,因为行注释依赖 \n 来结束 - lineWithNewline := append(line, '\n') - stmts := splitter.Feed(lineWithNewline) - for _, stmt := range stmts { - if err := onStatement(count, stmt); err != nil { - return count, err + for { + n, err := reader.Read(buffer) + if n > 0 { + stmts := splitter.Feed(buffer[:n]) + for _, stmt := range stmts { + if err := onStatement(count, stmt); err != nil { + return count, err + } + count++ } - count++ } - } - - if err := scanner.Err(); err != nil { - return count, err + if err == io.EOF { + break + } + if err != nil { + return count, err + } + if n == 0 { + continue + } } // 处理文件末尾不以分号结尾的最后一条语句 diff --git a/internal/db/database.go b/internal/db/database.go index e4494b7..7472e69 100644 --- a/internal/db/database.go +++ b/internal/db/database.go @@ -58,6 +58,55 @@ type BatchWriteExecer interface { ExecBatchContext(ctx context.Context, query string) (int64, error) } +// StatementExecer is a single-session SQL execution handle. +// It is used by long-running import jobs that must preserve session-scoped +// settings across multiple statements. +type StatementExecer interface { + Exec(query string) (int64, error) + ExecContext(ctx context.Context, query string) (int64, error) + Close() error +} + +// SessionExecerProvider is implemented by database/sql based drivers that can +// pin a long-running job to one physical connection. +type SessionExecerProvider interface { + OpenSessionExecer(ctx context.Context) (StatementExecer, error) +} + +type sqlConnStatementExecer struct { + conn *sql.Conn +} + +func NewSQLConnStatementExecer(conn *sql.Conn) StatementExecer { + return &sqlConnStatementExecer{conn: conn} +} + +func (e *sqlConnStatementExecer) ExecContext(ctx context.Context, query string) (int64, error) { + if e == nil || e.conn == nil { + return 0, fmt.Errorf("连接未打开") + } + res, err := e.conn.ExecContext(ctx, query) + if err != nil { + return 0, err + } + return res.RowsAffected() +} + +func (e *sqlConnStatementExecer) Exec(query string) (int64, error) { + return e.ExecContext(context.Background(), query) +} + +func (e *sqlConnStatementExecer) ExecBatchContext(ctx context.Context, query string) (int64, error) { + return e.ExecContext(ctx, query) +} + +func (e *sqlConnStatementExecer) Close() error { + if e == nil || e.conn == nil { + return nil + } + return e.conn.Close() +} + // BatchApplier 定义了批量变更提交接口。 // 支持批量编辑的驱动实现此接口,用于一次性提交前端 DataGrid 中的增删改操作。 type BatchApplier interface { diff --git a/internal/db/duckdb_impl.go b/internal/db/duckdb_impl.go index eeefdf2..151c3e0 100644 --- a/internal/db/duckdb_impl.go +++ b/internal/db/duckdb_impl.go @@ -101,6 +101,17 @@ func (d *DuckDB) ExecBatchContext(ctx context.Context, query string) (int64, err return res.RowsAffected() } +func (d *DuckDB) OpenSessionExecer(ctx context.Context) (StatementExecer, error) { + if d.conn == nil { + return nil, fmt.Errorf("连接未打开") + } + conn, err := d.conn.Conn(ctx) + if err != nil { + return nil, err + } + return NewSQLConnStatementExecer(conn), nil +} + func (d *DuckDB) ExecContext(ctx context.Context, query string) (int64, error) { if d.conn == nil { return 0, fmt.Errorf("连接未打开") diff --git a/internal/db/highgo_impl.go b/internal/db/highgo_impl.go index a8455f3..ba0b7da 100644 --- a/internal/db/highgo_impl.go +++ b/internal/db/highgo_impl.go @@ -187,6 +187,17 @@ func (h *HighGoDB) ExecBatchContext(ctx context.Context, query string) (int64, e return res.RowsAffected() } +func (h *HighGoDB) OpenSessionExecer(ctx context.Context) (StatementExecer, error) { + if h.conn == nil { + return nil, fmt.Errorf("连接未打开") + } + conn, err := h.conn.Conn(ctx) + if err != nil { + return nil, err + } + return NewSQLConnStatementExecer(conn), nil +} + func (h *HighGoDB) Exec(query string) (int64, error) { if h.conn == nil { return 0, fmt.Errorf("连接未打开") diff --git a/internal/db/iris_impl.go b/internal/db/iris_impl.go index 74746f8..309c931 100644 --- a/internal/db/iris_impl.go +++ b/internal/db/iris_impl.go @@ -239,6 +239,17 @@ func (i *IrisDB) ExecBatchContext(ctx context.Context, query string) (int64, err return i.ExecContext(ctx, query) } +func (i *IrisDB) OpenSessionExecer(ctx context.Context) (StatementExecer, error) { + if i.conn == nil { + return nil, fmt.Errorf("连接未打开") + } + conn, err := i.conn.Conn(ctx) + if err != nil { + return nil, err + } + return NewSQLConnStatementExecer(conn), nil +} + func (i *IrisDB) Exec(query string) (int64, error) { if i.conn == nil { return 0, fmt.Errorf("连接未打开") diff --git a/internal/db/kingbase_impl.go b/internal/db/kingbase_impl.go index 2e4d498..9fa12f1 100644 --- a/internal/db/kingbase_impl.go +++ b/internal/db/kingbase_impl.go @@ -315,6 +315,17 @@ func (k *KingbaseDB) ExecBatchContext(ctx context.Context, query string) (int64, return res.RowsAffected() } +func (k *KingbaseDB) OpenSessionExecer(ctx context.Context) (StatementExecer, error) { + if k.conn == nil { + return nil, fmt.Errorf("连接未打开") + } + conn, err := k.conn.Conn(ctx) + if err != nil { + return nil, err + } + return NewSQLConnStatementExecer(conn), nil +} + func (k *KingbaseDB) Exec(query string) (int64, error) { if k.conn == nil { return 0, fmt.Errorf("连接未打开") diff --git a/internal/db/mariadb_impl.go b/internal/db/mariadb_impl.go index 719caee..ac51f49 100644 --- a/internal/db/mariadb_impl.go +++ b/internal/db/mariadb_impl.go @@ -140,6 +140,17 @@ func (m *MariaDB) ExecBatchContext(ctx context.Context, query string) (int64, er return res.RowsAffected() } +func (m *MariaDB) OpenSessionExecer(ctx context.Context) (StatementExecer, error) { + if m.conn == nil { + return nil, fmt.Errorf("连接未打开") + } + conn, err := m.conn.Conn(ctx) + if err != nil { + return nil, err + } + return NewSQLConnStatementExecer(conn), nil +} + func (m *MariaDB) ExecContext(ctx context.Context, query string) (int64, error) { if m.conn == nil { return 0, fmt.Errorf("连接未打开") diff --git a/internal/db/mysql_impl.go b/internal/db/mysql_impl.go index 895b728..cbae90f 100644 --- a/internal/db/mysql_impl.go +++ b/internal/db/mysql_impl.go @@ -679,6 +679,17 @@ func (m *MySQLDB) ExecBatchContext(ctx context.Context, query string) (int64, er return res.RowsAffected() } +func (m *MySQLDB) OpenSessionExecer(ctx context.Context) (StatementExecer, error) { + if m.conn == nil { + return nil, fmt.Errorf("连接未打开") + } + conn, err := m.conn.Conn(ctx) + if err != nil { + return nil, err + } + return NewSQLConnStatementExecer(conn), nil +} + func (m *MySQLDB) ExecContext(ctx context.Context, query string) (int64, error) { if m.conn == nil { return 0, fmt.Errorf("连接未打开") diff --git a/internal/db/oceanbase_impl.go b/internal/db/oceanbase_impl.go index 68f44fd..4ec7573 100644 --- a/internal/db/oceanbase_impl.go +++ b/internal/db/oceanbase_impl.go @@ -797,6 +797,13 @@ func (o *OceanBaseDB) ExecBatchContext(ctx context.Context, query string) (int64 return o.ExecContext(ctx, query) } +func (o *OceanBaseDB) OpenSessionExecer(ctx context.Context) (StatementExecer, error) { + if p, ok := o.activeDatabase().(SessionExecerProvider); ok { + return p.OpenSessionExecer(ctx) + } + return nil, fmt.Errorf("当前 OceanBase %s 协议不支持独立导入会话", o.protocol) +} + func (o *OceanBaseDB) GetDatabases() ([]string, error) { return o.activeDatabase().GetDatabases() } diff --git a/internal/db/postgres_impl.go b/internal/db/postgres_impl.go index d746f01..327830c 100644 --- a/internal/db/postgres_impl.go +++ b/internal/db/postgres_impl.go @@ -249,6 +249,17 @@ func (p *PostgresDB) ExecBatchContext(ctx context.Context, query string) (int64, return res.RowsAffected() } +func (p *PostgresDB) OpenSessionExecer(ctx context.Context) (StatementExecer, error) { + if p.conn == nil { + return nil, fmt.Errorf("连接未打开") + } + conn, err := p.conn.Conn(ctx) + if err != nil { + return nil, err + } + return NewSQLConnStatementExecer(conn), nil +} + func (p *PostgresDB) ExecContext(ctx context.Context, query string) (int64, error) { if p.conn == nil { return 0, fmt.Errorf("连接未打开") diff --git a/internal/db/sqlite_impl.go b/internal/db/sqlite_impl.go index 7d74610..cfbf997 100644 --- a/internal/db/sqlite_impl.go +++ b/internal/db/sqlite_impl.go @@ -233,6 +233,17 @@ func (s *SQLiteDB) ExecBatchContext(ctx context.Context, query string) (int64, e return res.RowsAffected() } +func (s *SQLiteDB) OpenSessionExecer(ctx context.Context) (StatementExecer, error) { + if s.conn == nil { + return nil, fmt.Errorf("连接未打开") + } + conn, err := s.conn.Conn(ctx) + if err != nil { + return nil, err + } + return NewSQLConnStatementExecer(conn), nil +} + func (s *SQLiteDB) ExecContext(ctx context.Context, query string) (int64, error) { if s.conn == nil { return 0, fmt.Errorf("连接未打开") diff --git a/internal/db/sqlserver_impl.go b/internal/db/sqlserver_impl.go index 2d45384..d682907 100644 --- a/internal/db/sqlserver_impl.go +++ b/internal/db/sqlserver_impl.go @@ -205,6 +205,17 @@ func (s *SqlServerDB) ExecBatchContext(ctx context.Context, query string) (int64 return res.RowsAffected() } +func (s *SqlServerDB) OpenSessionExecer(ctx context.Context) (StatementExecer, error) { + if s.conn == nil { + return nil, fmt.Errorf("连接未打开") + } + conn, err := s.conn.Conn(ctx) + if err != nil { + return nil, err + } + return NewSQLConnStatementExecer(conn), nil +} + func (s *SqlServerDB) Exec(query string) (int64, error) { if s.conn == nil { return 0, fmt.Errorf("连接未打开") diff --git a/internal/db/vastbase_impl.go b/internal/db/vastbase_impl.go index ef0c497..0041f01 100644 --- a/internal/db/vastbase_impl.go +++ b/internal/db/vastbase_impl.go @@ -186,6 +186,17 @@ func (v *VastbaseDB) ExecBatchContext(ctx context.Context, query string) (int64, return res.RowsAffected() } +func (v *VastbaseDB) OpenSessionExecer(ctx context.Context) (StatementExecer, error) { + if v.conn == nil { + return nil, fmt.Errorf("连接未打开") + } + conn, err := v.conn.Conn(ctx) + if err != nil { + return nil, err + } + return NewSQLConnStatementExecer(conn), nil +} + func (v *VastbaseDB) Exec(query string) (int64, error) { if v.conn == nil { return 0, fmt.Errorf("连接未打开")