️ perf(import): 重构导入链路并支持流式批量写入

- 后端新增流式导入流水线,避免预览和导入阶段整文件驻留内存\n- 导入执行优先复用 BatchApplier 按批提交,并在批量失败时回退单行定位错误\n- 导入进度事件兼容未预扫总行数场景,沿用预览总数稳定展示进度\n- 补充导入预览、批量回退和前端进度展示的最小回归测试
This commit is contained in:
Syngnat
2026-06-17 17:26:57 +08:00
parent 4e31d47936
commit e67285fde1
5 changed files with 1113 additions and 398 deletions

View File

@@ -0,0 +1,511 @@
package app
import (
"bufio"
"encoding/csv"
"encoding/json"
"fmt"
"io"
"os"
"sort"
"strings"
"GoNavi-Wails/internal/connection"
"GoNavi-Wails/internal/db"
"github.com/xuri/excelize/v2"
)
const (
defaultImportPreviewLimit = 5
defaultImportApplyBatchSize = 1000
)
type importFileConsumer interface {
SetColumns(columns []string) error
ConsumeRow(row map[string]interface{}) error
}
type importPreviewData struct {
Columns []string
TotalRows int
PreviewRows []map[string]interface{}
}
type importProgressState struct {
Current int `json:"current"`
Total int `json:"total,omitempty"`
Success int `json:"success"`
Errors int `json:"errors"`
TotalRowsKnown bool `json:"totalRowsKnown,omitempty"`
}
type importExecutionResult struct {
Success int
Failed int
Total int
ErrorLogs []string
}
type importPreviewCollector struct {
columns []string
totalRows int
previewRows []map[string]interface{}
previewLimit int
}
func newImportPreviewCollector(limit int) *importPreviewCollector {
if limit <= 0 {
limit = defaultImportPreviewLimit
}
return &importPreviewCollector{previewLimit: limit}
}
func (c *importPreviewCollector) SetColumns(columns []string) error {
c.columns = append([]string(nil), columns...)
return nil
}
func (c *importPreviewCollector) ConsumeRow(row map[string]interface{}) error {
c.totalRows++
if len(c.previewRows) < c.previewLimit {
c.previewRows = append(c.previewRows, cloneImportRow(row))
}
return nil
}
func (c *importPreviewCollector) Result() importPreviewData {
return importPreviewData{
Columns: append([]string(nil), c.columns...),
TotalRows: c.totalRows,
PreviewRows: cloneImportRows(c.previewRows),
}
}
type importCollectConsumer struct {
columns []string
rows []map[string]interface{}
}
func (c *importCollectConsumer) SetColumns(columns []string) error {
c.columns = append([]string(nil), columns...)
return nil
}
func (c *importCollectConsumer) ConsumeRow(row map[string]interface{}) error {
c.rows = append(c.rows, cloneImportRow(row))
return nil
}
type importRowWriter interface {
SetColumns(columns []string)
ApplyBatch(rows []map[string]interface{}) error
ApplyOne(row map[string]interface{}) error
BatchEnabled() bool
}
type importDatabaseRowWriter struct {
dbInst db.Database
applier db.BatchApplier
dbType string
tableName string
columns []string
columnTypeMap map[string]string
}
func newImportDatabaseRowWriter(dbInst db.Database, dbType, tableName string, columnTypeMap map[string]string) *importDatabaseRowWriter {
writer := &importDatabaseRowWriter{
dbInst: dbInst,
dbType: dbType,
tableName: tableName,
columnTypeMap: columnTypeMap,
}
if applier, ok := dbInst.(db.BatchApplier); ok {
writer.applier = applier
}
return writer
}
func (w *importDatabaseRowWriter) SetColumns(columns []string) {
w.columns = append([]string(nil), columns...)
}
func (w *importDatabaseRowWriter) BatchEnabled() bool {
return w.applier != nil
}
func (w *importDatabaseRowWriter) ApplyBatch(rows []map[string]interface{}) error {
if w.applier == nil {
return fmt.Errorf("当前数据库类型不支持批量提交")
}
return w.applier.ApplyChanges(w.tableName, connection.ChangeSet{Inserts: cloneImportRows(rows)})
}
func (w *importDatabaseRowWriter) ApplyOne(row map[string]interface{}) error {
if w.applier != nil {
return w.applier.ApplyChanges(w.tableName, connection.ChangeSet{Inserts: []map[string]interface{}{cloneImportRow(row)}})
}
query, err := buildImportInsertQuery(w.dbType, w.tableName, w.columns, row, w.columnTypeMap)
if err != nil {
return err
}
_, err = w.dbInst.Exec(query)
return err
}
type importBatchConsumer struct {
writer importRowWriter
batchSize int
totalRows int
totalRowsKnown bool
report func(importProgressState)
batch []map[string]interface{}
batchStartRow int
currentRow int
successCount int
errorLogs []string
}
func newImportBatchConsumer(writer importRowWriter, batchSize int, totalRows int, totalRowsKnown bool, report func(importProgressState)) *importBatchConsumer {
if batchSize <= 0 {
batchSize = defaultImportApplyBatchSize
}
return &importBatchConsumer{
writer: writer,
batchSize: batchSize,
totalRows: totalRows,
totalRowsKnown: totalRowsKnown,
report: report,
}
}
func (c *importBatchConsumer) SetColumns(columns []string) error {
if c.writer != nil {
c.writer.SetColumns(columns)
}
return nil
}
func (c *importBatchConsumer) ConsumeRow(row map[string]interface{}) error {
c.currentRow++
if len(c.batch) == 0 {
c.batchStartRow = c.currentRow
}
c.batch = append(c.batch, cloneImportRow(row))
if len(c.batch) >= c.batchSize {
return c.flush()
}
return nil
}
func (c *importBatchConsumer) Flush() error {
return c.flush()
}
func (c *importBatchConsumer) Result() importExecutionResult {
return importExecutionResult{
Success: c.successCount,
Failed: len(c.errorLogs),
Total: c.currentRow,
ErrorLogs: append([]string(nil), c.errorLogs...),
}
}
func (c *importBatchConsumer) flush() error {
if len(c.batch) == 0 {
return nil
}
rows := c.batch
startRow := c.batchStartRow
c.batch = nil
c.batchStartRow = 0
if c.writer != nil && c.writer.BatchEnabled() {
if err := c.writer.ApplyBatch(rows); err == nil {
c.successCount += len(rows)
c.emitProgress(startRow + len(rows) - 1)
return nil
}
}
for idx, row := range rows {
if c.writer != nil {
if err := c.writer.ApplyOne(row); err != nil {
c.errorLogs = append(c.errorLogs, fmt.Sprintf("Row %d: %s", startRow+idx, err.Error()))
} else {
c.successCount++
}
}
c.emitProgress(startRow + idx)
}
return nil
}
func (c *importBatchConsumer) emitProgress(current int) {
if c.report == nil {
return
}
c.report(importProgressState{
Current: current,
Total: c.totalRows,
Success: c.successCount,
Errors: len(c.errorLogs),
TotalRowsKnown: c.totalRowsKnown,
})
}
func buildImportPreview(filePath string, previewLimit int) (importPreviewData, error) {
collector := newImportPreviewCollector(previewLimit)
if err := streamImportFile(filePath, collector); err != nil {
return importPreviewData{}, err
}
return collector.Result(), nil
}
func parseImportFile(filePath string) ([]map[string]interface{}, []string, error) {
collector := &importCollectConsumer{}
if err := streamImportFile(filePath, collector); err != nil {
return nil, nil, err
}
return collector.rows, collector.columns, nil
}
func streamImportFile(filePath string, consumer importFileConsumer) error {
lower := strings.ToLower(filePath)
switch {
case strings.HasSuffix(lower, ".json"):
return streamJSONImportFile(filePath, consumer)
case strings.HasSuffix(lower, ".csv"):
return streamCSVImportFile(filePath, consumer)
case strings.HasSuffix(lower, ".xlsx"), strings.HasSuffix(lower, ".xls"):
return streamExcelImportFile(filePath, consumer)
default:
return fmt.Errorf("Unsupported file format")
}
}
func streamJSONImportFile(filePath string, consumer importFileConsumer) error {
f, err := os.Open(filePath)
if err != nil {
return err
}
defer f.Close()
decoder := json.NewDecoder(bufio.NewReader(f))
token, err := decoder.Token()
if err != nil {
return fmt.Errorf("JSON Parse Error: %w", err)
}
delim, ok := token.(json.Delim)
if !ok || delim != '[' {
return fmt.Errorf("JSON Parse Error: root array expected")
}
var columns []string
for decoder.More() {
var raw map[string]interface{}
if err := decoder.Decode(&raw); err != nil {
return fmt.Errorf("JSON Parse Error: %w", err)
}
if columns == nil {
columns = importJSONColumns(raw)
if err := consumer.SetColumns(columns); err != nil {
return err
}
}
if err := consumer.ConsumeRow(normalizeImportMapRow(columns, raw)); err != nil {
return err
}
}
if _, err := decoder.Token(); err != nil {
return fmt.Errorf("JSON Parse Error: %w", err)
}
return nil
}
func streamCSVImportFile(filePath string, consumer importFileConsumer) error {
f, err := os.Open(filePath)
if err != nil {
return err
}
defer f.Close()
reader := csv.NewReader(bufio.NewReader(f))
reader.ReuseRecord = true
header, err := reader.Read()
if err != nil {
if err == io.EOF {
return fmt.Errorf("CSV empty or missing header")
}
return fmt.Errorf("CSV Parse Error: %w", err)
}
columns := cloneImportColumns(header)
if !hasImportUsableColumns(columns) {
return fmt.Errorf("CSV empty or missing header")
}
if err := consumer.SetColumns(columns); err != nil {
return err
}
for {
record, err := reader.Read()
if err != nil {
if err == io.EOF {
return nil
}
return fmt.Errorf("CSV Parse Error: %w", err)
}
if err := consumer.ConsumeRow(buildImportRowFromValues(columns, record)); err != nil {
return err
}
}
}
func streamExcelImportFile(filePath string, consumer importFileConsumer) error {
workbook, err := excelize.OpenFile(filePath)
if err != nil {
return fmt.Errorf("Excel Parse Error: %w", err)
}
defer workbook.Close()
sheetName := workbook.GetSheetName(0)
if sheetName == "" {
return fmt.Errorf("Excel file has no sheets")
}
rows, err := workbook.Rows(sheetName)
if err != nil {
return fmt.Errorf("Excel Read Error: %w", err)
}
defer rows.Close()
if !rows.Next() {
if err := rows.Error(); err != nil {
return fmt.Errorf("Excel Read Error: %w", err)
}
return fmt.Errorf("Excel empty or missing header")
}
header, err := rows.Columns()
if err != nil {
return fmt.Errorf("Excel Read Error: %w", err)
}
columns := cloneImportColumns(header)
if !hasImportUsableColumns(columns) {
return fmt.Errorf("Excel empty or missing header")
}
if err := consumer.SetColumns(columns); err != nil {
return err
}
for rows.Next() {
record, err := rows.Columns()
if err != nil {
return fmt.Errorf("Excel Read Error: %w", err)
}
if err := consumer.ConsumeRow(buildImportRowFromValues(columns, record)); err != nil {
return err
}
}
if err := rows.Error(); err != nil {
return fmt.Errorf("Excel Read Error: %w", err)
}
return nil
}
func buildImportInsertQuery(dbType, tableName string, columns []string, row map[string]interface{}, columnTypeMap map[string]string) (string, error) {
quotedCols := make([]string, 0, len(columns))
values := make([]string, 0, len(columns))
for _, column := range columns {
if strings.TrimSpace(column) == "" {
continue
}
quotedCols = append(quotedCols, quoteIdentByType(dbType, column))
colType := columnTypeMap[normalizeColumnName(column)]
values = append(values, formatImportSQLValue(dbType, colType, row[column]))
}
if len(quotedCols) == 0 {
return "", fmt.Errorf("导入文件缺少有效列头")
}
return fmt.Sprintf("INSERT INTO %s (%s) VALUES (%s)",
quoteQualifiedIdentByType(dbType, tableName),
strings.Join(quotedCols, ", "),
strings.Join(values, ", ")), nil
}
func importJSONColumns(row map[string]interface{}) []string {
columns := make([]string, 0, len(row))
for key := range row {
if strings.TrimSpace(key) == "" {
continue
}
columns = append(columns, key)
}
sort.Strings(columns)
return columns
}
func cloneImportColumns(raw []string) []string {
return append([]string(nil), raw...)
}
func hasImportUsableColumns(columns []string) bool {
for _, column := range columns {
if strings.TrimSpace(column) != "" {
return true
}
}
return false
}
func buildImportRowFromValues(columns []string, values []string) map[string]interface{} {
row := make(map[string]interface{}, len(columns))
for idx, column := range columns {
if strings.TrimSpace(column) == "" {
continue
}
if idx >= len(values) {
row[column] = nil
continue
}
if values[idx] == "NULL" {
row[column] = nil
continue
}
row[column] = values[idx]
}
return row
}
func normalizeImportMapRow(columns []string, raw map[string]interface{}) map[string]interface{} {
row := make(map[string]interface{}, len(columns))
for _, column := range columns {
if value, ok := raw[column]; ok {
row[column] = value
continue
}
row[column] = nil
}
return row
}
func cloneImportRow(row map[string]interface{}) map[string]interface{} {
if row == nil {
return nil
}
cloned := make(map[string]interface{}, len(row))
for key, value := range row {
cloned[key] = value
}
return cloned
}
func cloneImportRows(rows []map[string]interface{}) []map[string]interface{} {
if len(rows) == 0 {
return nil
}
cloned := make([]map[string]interface{}, 0, len(rows))
for _, row := range rows {
cloned = append(cloned, cloneImportRow(row))
}
return cloned
}

View File

@@ -1668,21 +1668,15 @@ func (a *App) PreviewImportFile(filePath string) connection.QueryResult {
return connection.QueryResult{Success: false, Message: "文件路径不能为空"}
}
rows, columns, err := parseImportFile(filePath)
preview, err := buildImportPreview(filePath, defaultImportPreviewLimit)
if err != nil {
return connection.QueryResult{Success: false, Message: err.Error()}
}
totalRows := len(rows)
previewRows := rows
if len(rows) > 5 {
previewRows = rows[:5]
}
result := map[string]interface{}{
"columns": columns,
"totalRows": totalRows,
"previewRows": previewRows,
"columns": preview.Columns,
"totalRows": preview.TotalRows,
"previewRows": preview.PreviewRows,
"filePath": filePath,
}
@@ -1712,98 +1706,6 @@ func (a *App) ImportData(config connection.ConnectionConfig, dbName, tableName s
return connection.QueryResult{Success: true, Data: map[string]interface{}{"filePath": selection}}
}
// parseImportFile 解析导入文件,返回数据行和列名
func parseImportFile(filePath string) ([]map[string]interface{}, []string, error) {
var rows []map[string]interface{}
var columns []string
lower := strings.ToLower(filePath)
if strings.HasSuffix(lower, ".json") {
f, err := os.Open(filePath)
if err != nil {
return nil, nil, err
}
defer f.Close()
decoder := json.NewDecoder(f)
if err := decoder.Decode(&rows); err != nil {
return nil, nil, fmt.Errorf("JSON Parse Error: %w", err)
}
if len(rows) > 0 {
for k := range rows[0] {
columns = append(columns, k)
}
}
} else if strings.HasSuffix(lower, ".csv") {
f, err := os.Open(filePath)
if err != nil {
return nil, nil, err
}
defer f.Close()
reader := csv.NewReader(f)
records, err := reader.ReadAll()
if err != nil {
return nil, nil, fmt.Errorf("CSV Parse Error: %w", err)
}
if len(records) < 2 {
return nil, nil, fmt.Errorf("CSV empty or missing header")
}
columns = records[0]
for _, record := range records[1:] {
row := make(map[string]interface{})
for i, val := range record {
if i < len(columns) {
if val == "NULL" {
row[columns[i]] = nil
} else {
row[columns[i]] = val
}
}
}
rows = append(rows, row)
}
} else if strings.HasSuffix(lower, ".xlsx") || strings.HasSuffix(lower, ".xls") {
xlsx, err := excelize.OpenFile(filePath)
if err != nil {
return nil, nil, fmt.Errorf("Excel Parse Error: %w", err)
}
defer xlsx.Close()
sheetName := xlsx.GetSheetName(0)
if sheetName == "" {
return nil, nil, fmt.Errorf("Excel file has no sheets")
}
xlRows, err := xlsx.GetRows(sheetName)
if err != nil {
return nil, nil, fmt.Errorf("Excel Read Error: %w", err)
}
if len(xlRows) < 2 {
return nil, nil, fmt.Errorf("Excel empty or missing header")
}
columns = xlRows[0]
for _, record := range xlRows[1:] {
row := make(map[string]interface{})
for i, val := range record {
if i < len(columns) && columns[i] != "" {
if val == "NULL" {
row[columns[i]] = nil
} else {
row[columns[i]] = val
}
}
}
if len(row) > 0 {
rows = append(rows, row)
}
}
} else {
return nil, nil, fmt.Errorf("Unsupported file format")
}
return rows, columns, nil
}
func normalizeColumnName(name string) string {
return strings.ToLower(strings.TrimSpace(name))
}
@@ -2125,15 +2027,6 @@ func formatImportSQLValue(dbType, columnType string, value interface{}) string {
// ImportDataWithProgress 执行导入并发送进度事件
func (a *App) ImportDataWithProgress(config connection.ConnectionConfig, dbName, tableName, filePath string) connection.QueryResult {
rows, columns, err := parseImportFile(filePath)
if err != nil {
return connection.QueryResult{Success: false, Message: err.Error()}
}
if len(rows) == 0 {
return connection.QueryResult{Success: true, Message: "无可导入数据"}
}
runConfig := normalizeRunConfig(config, dbName)
dbInst, err := a.getDatabase(runConfig)
if err != nil {
@@ -2147,55 +2040,31 @@ func (a *App) ImportDataWithProgress(config connection.ConnectionConfig, dbName,
columnTypeMap = buildImportColumnTypeMap(defs)
}
totalRows := len(rows)
successCount := 0
var errorLogs []string
quotedCols := make([]string, len(columns))
for i, c := range columns {
quotedCols[i] = quoteIdentByType(dbType, c)
writer := newImportDatabaseRowWriter(dbInst, dbType, tableName, columnTypeMap)
consumer := newImportBatchConsumer(writer, defaultImportApplyBatchSize, 0, false, func(state importProgressState) {
runtime.EventsEmit(a.ctx, "import:progress", state)
})
if err := streamImportFile(filePath, consumer); err != nil {
return connection.QueryResult{Success: false, Message: err.Error()}
}
if err := consumer.Flush(); err != nil {
return connection.QueryResult{Success: false, Message: err.Error()}
}
for idx, row := range rows {
var values []string
for _, col := range columns {
val := row[col]
colType := columnTypeMap[normalizeColumnName(col)]
values = append(values, formatImportSQLValue(dbType, colType, val))
}
query := fmt.Sprintf("INSERT INTO %s (%s) VALUES (%s)",
quoteQualifiedIdentByType(dbType, tableName),
strings.Join(quotedCols, ", "),
strings.Join(values, ", "))
_, err := dbInst.Exec(query)
if err != nil {
errorLogs = append(errorLogs, fmt.Sprintf("Row %d: %s", idx+1, err.Error()))
} else {
successCount++
}
// 每 10 行发送一次进度事件
if (idx+1)%10 == 0 || idx == totalRows-1 {
runtime.EventsEmit(a.ctx, "import:progress", map[string]interface{}{
"current": idx + 1,
"total": totalRows,
"success": successCount,
"errors": len(errorLogs),
})
}
resultData := consumer.Result()
if resultData.Total == 0 {
return connection.QueryResult{Success: true, Message: "无可导入数据"}
}
result := map[string]interface{}{
"success": successCount,
"failed": len(errorLogs),
"total": totalRows,
"errorLogs": errorLogs,
"errorSummary": fmt.Sprintf("Imported: %d, Failed: %d", successCount, len(errorLogs)),
"success": resultData.Success,
"failed": resultData.Failed,
"total": resultData.Total,
"errorLogs": resultData.ErrorLogs,
"errorSummary": fmt.Sprintf("Imported: %d, Failed: %d", resultData.Success, resultData.Failed),
}
return connection.QueryResult{Success: true, Data: result, Message: fmt.Sprintf("Imported: %d, Failed: %d", successCount, len(errorLogs))}
return connection.QueryResult{Success: true, Data: result, Message: fmt.Sprintf("Imported: %d, Failed: %d", resultData.Success, resultData.Failed)}
}
func (a *App) ApplyChanges(config connection.ConnectionConfig, dbName, tableName string, changes connection.ChangeSet) connection.QueryResult {

View File

@@ -2,8 +2,11 @@ package app
import (
"errors"
"fmt"
"os"
"path/filepath"
"reflect"
"strings"
"testing"
)
@@ -31,3 +34,146 @@ func TestReadImportedConnectionConfigFileRejectsOversizedFiles(t *testing.T) {
})
}
}
func TestBuildImportPreviewCSVStreamKeepsFirstFiveRows(t *testing.T) {
dir := t.TempDir()
path := filepath.Join(dir, "users.csv")
var builder strings.Builder
builder.WriteString("id,name\n")
for i := 1; i <= 7; i++ {
builder.WriteString(fmt.Sprintf("%d,user_%d\n", i, i))
}
if err := os.WriteFile(path, []byte(builder.String()), 0o600); err != nil {
t.Fatalf("write csv: %v", err)
}
preview, err := buildImportPreview(path, 5)
if err != nil {
t.Fatalf("buildImportPreview returned error: %v", err)
}
if !reflect.DeepEqual(preview.Columns, []string{"id", "name"}) {
t.Fatalf("unexpected columns: %#v", preview.Columns)
}
if preview.TotalRows != 7 {
t.Fatalf("expected 7 rows, got %d", preview.TotalRows)
}
if len(preview.PreviewRows) != 5 {
t.Fatalf("expected 5 preview rows, got %d", len(preview.PreviewRows))
}
if got := preview.PreviewRows[0]["name"]; got != "user_1" {
t.Fatalf("expected first preview row name user_1, got %#v", got)
}
if got := preview.PreviewRows[4]["id"]; got != "5" {
t.Fatalf("expected fifth preview row id 5, got %#v", got)
}
}
func TestBuildImportRowFromValuesPreservesPositionsWhenHeaderContainsBlankColumns(t *testing.T) {
row := buildImportRowFromValues([]string{"id", "", "name"}, []string{"1", "ignored", "alice"})
if got := row["id"]; got != "1" {
t.Fatalf("expected id to stay aligned, got %#v", got)
}
if got := row["name"]; got != "alice" {
t.Fatalf("expected name to stay aligned, got %#v", got)
}
if _, ok := row[""]; ok {
t.Fatal("blank header column should not be written into row map")
}
}
type fakeImportRowWriter struct {
columns []string
batchCalls int
singleCalls int
batchSizes []int
batchErr error
singleErrByRowID map[interface{}]error
}
func (w *fakeImportRowWriter) SetColumns(columns []string) {
w.columns = append([]string(nil), columns...)
}
func (w *fakeImportRowWriter) ApplyBatch(rows []map[string]interface{}) error {
w.batchCalls++
w.batchSizes = append(w.batchSizes, len(rows))
return w.batchErr
}
func (w *fakeImportRowWriter) ApplyOne(row map[string]interface{}) error {
w.singleCalls++
if err, ok := w.singleErrByRowID[row["id"]]; ok {
return err
}
return nil
}
func (w *fakeImportRowWriter) BatchEnabled() bool {
return true
}
func TestImportBatchConsumerUsesBatchWriterInConfiguredBatches(t *testing.T) {
writer := &fakeImportRowWriter{}
consumer := newImportBatchConsumer(writer, 1000, 1201, true, nil)
if err := consumer.SetColumns([]string{"id"}); err != nil {
t.Fatalf("SetColumns returned error: %v", err)
}
for i := 1; i <= 1201; i++ {
if err := consumer.ConsumeRow(map[string]interface{}{"id": i}); err != nil {
t.Fatalf("ConsumeRow(%d) returned error: %v", i, err)
}
}
if err := consumer.Flush(); err != nil {
t.Fatalf("Flush returned error: %v", err)
}
if writer.batchCalls != 2 {
t.Fatalf("expected 2 batch calls, got %d", writer.batchCalls)
}
if !reflect.DeepEqual(writer.batchSizes, []int{1000, 201}) {
t.Fatalf("unexpected batch sizes: %#v", writer.batchSizes)
}
result := consumer.Result()
if result.Success != 1201 || result.Failed != 0 || result.Total != 1201 {
t.Fatalf("unexpected result: %#v", result)
}
if writer.singleCalls != 0 {
t.Fatalf("expected no single-row fallback, got %d calls", writer.singleCalls)
}
}
func TestImportBatchConsumerFallsBackToSingleRowsWhenBatchFails(t *testing.T) {
writer := &fakeImportRowWriter{
batchErr: fmt.Errorf("batch failed"),
singleErrByRowID: map[interface{}]error{
2: fmt.Errorf("duplicate key"),
},
}
consumer := newImportBatchConsumer(writer, 1000, 3, true, nil)
if err := consumer.SetColumns([]string{"id"}); err != nil {
t.Fatalf("SetColumns returned error: %v", err)
}
for i := 1; i <= 3; i++ {
if err := consumer.ConsumeRow(map[string]interface{}{"id": i}); err != nil {
t.Fatalf("ConsumeRow(%d) returned error: %v", i, err)
}
}
if err := consumer.Flush(); err != nil {
t.Fatalf("Flush returned error: %v", err)
}
result := consumer.Result()
if result.Success != 2 || result.Failed != 1 || result.Total != 3 {
t.Fatalf("unexpected result: %#v", result)
}
if writer.batchCalls != 1 {
t.Fatalf("expected 1 batch call, got %d", writer.batchCalls)
}
if writer.singleCalls != 3 {
t.Fatalf("expected 3 single-row fallback calls, got %d", writer.singleCalls)
}
if len(result.ErrorLogs) != 1 || result.ErrorLogs[0] != "Row 2: duplicate key" {
t.Fatalf("unexpected error logs: %#v", result.ErrorLogs)
}
}