From e67285fde1357d5f902ca6d2ec4a765ec288a1da Mon Sep 17 00:00:00 2001 From: Syngnat Date: Wed, 17 Jun 2026 17:26:57 +0800 Subject: [PATCH] =?UTF-8?q?=E2=9A=A1=EF=B8=8F=20perf(import):=20=E9=87=8D?= =?UTF-8?q?=E6=9E=84=E5=AF=BC=E5=85=A5=E9=93=BE=E8=B7=AF=E5=B9=B6=E6=94=AF?= =?UTF-8?q?=E6=8C=81=E6=B5=81=E5=BC=8F=E6=89=B9=E9=87=8F=E5=86=99=E5=85=A5?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - 后端新增流式导入流水线,避免预览和导入阶段整文件驻留内存\n- 导入执行优先复用 BatchApplier 按批提交,并在批量失败时回退单行定位错误\n- 导入进度事件兼容未预扫总行数场景,沿用预览总数稳定展示进度\n- 补充导入预览、批量回退和前端进度展示的最小回归测试 --- .../components/ImportPreviewModal.test.tsx | 109 +++- .../src/components/ImportPreviewModal.tsx | 570 +++++++++++------- internal/app/import_pipeline.go | 511 ++++++++++++++++ internal/app/methods_file.go | 175 +----- internal/app/methods_file_import_test.go | 146 +++++ 5 files changed, 1113 insertions(+), 398 deletions(-) create mode 100644 internal/app/import_pipeline.go diff --git a/frontend/src/components/ImportPreviewModal.test.tsx b/frontend/src/components/ImportPreviewModal.test.tsx index ec943f9..ae04525 100644 --- a/frontend/src/components/ImportPreviewModal.test.tsx +++ b/frontend/src/components/ImportPreviewModal.test.tsx @@ -9,7 +9,11 @@ import ImportPreviewModal from "./ImportPreviewModal"; const mocks = vi.hoisted(() => ({ previewImportFile: vi.fn(), importDataWithProgress: vi.fn(), - eventsOn: vi.fn(() => vi.fn()), + progressHandler: null as ((data: any) => void) | null, + eventsOn: vi.fn((_event: string, handler: (data: any) => void) => { + mocks.progressHandler = handler; + return vi.fn(); + }), eventsOff: vi.fn(), storeState: { connections: [ @@ -29,7 +33,8 @@ const mocks = vi.hoisted(() => ({ })); vi.mock("../store", () => ({ - useStore: (selector: (state: typeof mocks.storeState) => unknown) => selector(mocks.storeState), + useStore: (selector: (state: typeof mocks.storeState) => unknown) => + selector(mocks.storeState), })); vi.mock("../i18n/runtime", () => ({ @@ -59,12 +64,25 @@ vi.mock("antd", async () => { footer?: React.ReactNode; open?: boolean; title?: React.ReactNode; - }) => (open ? React.createElement("section", null, title, children, footer) : null); - const Table = ({ columns, dataSource }: { columns?: any[]; dataSource?: any[] }) => + }) => + open ? React.createElement("section", null, title, children, footer) : null; + const Table = ({ + columns, + dataSource, + }: { + columns?: any[]; + dataSource?: any[]; + }) => React.createElement( "div", null, - columns?.map((column) => React.createElement("span", { key: column.key || column.dataIndex }, column.title)), + columns?.map((column) => + React.createElement( + "span", + { key: column.key || column.dataIndex }, + column.title, + ), + ), dataSource?.map((row, index) => React.createElement( "div", @@ -78,12 +96,24 @@ vi.mock("antd", async () => { return { Modal, Table, - Alert: ({ message, description }: { message?: React.ReactNode; description?: React.ReactNode }) => - React.createElement("div", null, message, description), - Progress: ({ percent }: { percent: number }) => React.createElement("div", null, `${percent}%`), - Button: ({ children, onClick }: { children?: React.ReactNode; onClick?: () => void }) => - React.createElement("button", { onClick }, children), - Space: ({ children }: { children?: React.ReactNode }) => React.createElement("div", null, children), + Alert: ({ + message, + description, + }: { + message?: React.ReactNode; + description?: React.ReactNode; + }) => React.createElement("div", null, message, description), + Progress: ({ percent }: { percent: number }) => + React.createElement("div", null, `${percent}%`), + Button: ({ + children, + onClick, + }: { + children?: React.ReactNode; + onClick?: () => void; + }) => React.createElement("button", { onClick }, children), + Space: ({ children }: { children?: React.ReactNode }) => + React.createElement("div", null, children), }; }); @@ -99,7 +129,8 @@ vi.mock("@ant-design/icons", async () => { const textContent = (node: any): string => { if (node === null || node === undefined) return ""; if (typeof node === "string" || typeof node === "number") return String(node); - if (Array.isArray(node)) return node.map((item) => textContent(item)).join(""); + if (Array.isArray(node)) + return node.map((item) => textContent(item)).join(""); return textContent(node.children || []); }; @@ -146,7 +177,9 @@ describe("ImportPreviewModal i18n", () => { expect(renderedText).toContain("Import data preview"); expect(renderedText).toContain("12 rows and 2 fields"); - expect(renderedText).toContain("The first 5 rows are shown below. Start the import after confirming the data."); + expect(renderedText).toContain( + "The first 5 rows are shown below. Start the import after confirming the data.", + ); expect(renderedText).toContain("Field list:"); expect(renderedText).toContain("Data preview (first 5 rows):"); expect(renderedText).toContain("Cancel"); @@ -156,7 +189,10 @@ describe("ImportPreviewModal i18n", () => { }); it("does not keep migrated Chinese UI literals in ImportPreviewModal source", () => { - const source = readFileSync(new URL("./ImportPreviewModal.tsx", import.meta.url), "utf8"); + const source = readFileSync( + new URL("./ImportPreviewModal.tsx", import.meta.url), + "utf8", + ); expect(source).not.toContain("导入数据预览"); expect(source).not.toContain("开始导入"); @@ -166,4 +202,49 @@ describe("ImportPreviewModal i18n", () => { expect(source).not.toContain("正在导入数据..."); expect(source).not.toContain("错误日志:"); }); + + it("keeps preview total when progress events omit total rows", async () => { + let resolveImport!: (value: any) => void; + mocks.importDataWithProgress.mockImplementation( + () => + new Promise((resolve) => { + resolveImport = resolve; + }), + ); + + const renderer = await renderImportPreview(); + const button = renderer.root + .findAllByType("button") + .find((node) => textContent(node.props.children) === "Start import"); + expect(button).toBeDefined(); + + await act(async () => { + button?.props.onClick(); + await Promise.resolve(); + }); + + expect(mocks.progressHandler).toBeTypeOf("function"); + + await act(async () => { + mocks.progressHandler?.({ + current: 3, + total: 0, + success: 3, + errors: 0, + totalRowsKnown: false, + }); + await Promise.resolve(); + }); + + expect(textContent(renderer.toJSON())).toContain("Processed 3 / 12 rows"); + expect(textContent(renderer.toJSON())).toContain("25%"); + + await act(async () => { + resolveImport({ + success: true, + data: { success: 3, failed: 0, total: 12, errorLogs: [] }, + }); + await Promise.resolve(); + }); + }); }); diff --git a/frontend/src/components/ImportPreviewModal.tsx b/frontend/src/components/ImportPreviewModal.tsx index 3d8317e..340bfec 100644 --- a/frontend/src/components/ImportPreviewModal.tsx +++ b/frontend/src/components/ImportPreviewModal.tsx @@ -1,254 +1,362 @@ -import React, { useState, useEffect } from 'react'; -import { Modal, Table, Alert, Progress, Button, Space } from 'antd'; -import { CheckCircleOutlined, CloseCircleOutlined } from '@ant-design/icons'; -import { PreviewImportFile, ImportDataWithProgress } from '../../wailsjs/go/app/App'; -import { EventsOn, EventsOff } from '../../wailsjs/runtime/runtime'; -import { useStore } from '../store'; -import { t as defaultTranslate } from '../i18n'; -import { useOptionalI18n } from '../i18n/provider'; -import { buildRpcConnectionConfig } from '../utils/connectionRpcConfig'; +import React, { useState, useEffect } from "react"; +import { Modal, Table, Alert, Progress, Button, Space } from "antd"; +import { CheckCircleOutlined, CloseCircleOutlined } from "@ant-design/icons"; +import { + PreviewImportFile, + ImportDataWithProgress, +} from "../../wailsjs/go/app/App"; +import { EventsOn, EventsOff } from "../../wailsjs/runtime/runtime"; +import { useStore } from "../store"; +import { t as defaultTranslate } from "../i18n"; +import { useOptionalI18n } from "../i18n/provider"; +import { buildRpcConnectionConfig } from "../utils/connectionRpcConfig"; interface ImportPreviewModalProps { - visible: boolean; - filePath: string; - connectionId: string; - dbName: string; - tableName: string; - onClose: () => void; - onSuccess: () => void; + visible: boolean; + filePath: string; + connectionId: string; + dbName: string; + tableName: string; + onClose: () => void; + onSuccess: () => void; } interface PreviewData { - columns: string[]; - totalRows: number; - previewRows: any[]; + columns: string[]; + totalRows: number; + previewRows: any[]; } interface ImportProgress { - current: number; - total: number; - success: number; - errors: number; + current: number; + total: number; + success: number; + errors: number; + totalRowsKnown?: boolean; } const ImportPreviewModal: React.FC = ({ - visible, - filePath, - connectionId, - dbName, - tableName, - onClose, - onSuccess + visible, + filePath, + connectionId, + dbName, + tableName, + onClose, + onSuccess, }) => { - const i18n = useOptionalI18n(); - const t = i18n?.t ?? defaultTranslate; - const connections = useStore(state => state.connections); - const [loading, setLoading] = useState(true); - const [previewData, setPreviewData] = useState(null); - const [error, setError] = useState(null); - const [importing, setImporting] = useState(false); - const [progress, setProgress] = useState(null); - const [importResult, setImportResult] = useState(null); + const i18n = useOptionalI18n(); + const t = i18n?.t ?? defaultTranslate; + const connections = useStore((state) => state.connections); + const [loading, setLoading] = useState(true); + const [previewData, setPreviewData] = useState(null); + const [error, setError] = useState(null); + const [importing, setImporting] = useState(false); + const [progress, setProgress] = useState(null); + const [importResult, setImportResult] = useState(null); - useEffect(() => { - if (visible && filePath) { - loadPreview(); - } - }, [visible, filePath]); + useEffect(() => { + if (visible && filePath) { + loadPreview(); + } + }, [visible, filePath]); - useEffect(() => { - if (importing) { - const unsubscribe = EventsOn('import:progress', (data: ImportProgress) => { - setProgress(data); - }); - return () => { - EventsOff('import:progress'); + useEffect(() => { + if (importing) { + const unsubscribe = EventsOn( + "import:progress", + (data: ImportProgress) => { + setProgress((prev) => { + const fallbackTotal = prev?.total || previewData?.totalRows || 0; + const nextTotal = + typeof data.total === "number" && data.total > 0 + ? data.total + : fallbackTotal; + return { + current: data.current ?? prev?.current ?? 0, + total: nextTotal, + success: data.success ?? prev?.success ?? 0, + errors: data.errors ?? prev?.errors ?? 0, + totalRowsKnown: data.totalRowsKnown ?? nextTotal > 0, }; + }); + }, + ); + return () => { + unsubscribe?.(); + EventsOff("import:progress"); + }; + } + }, [importing, previewData?.totalRows]); + + const loadPreview = async () => { + setLoading(true); + setError(null); + try { + const res = await PreviewImportFile(filePath); + if (res.success && res.data) { + setPreviewData({ + columns: res.data.columns || [], + totalRows: res.data.totalRows || 0, + previewRows: res.data.previewRows || [], + }); + } else { + setError(res.message || t("import_preview.error.preview_failed")); + } + } catch (e: any) { + setError( + t("import_preview.error.preview_failed_detail", { + detail: String(e?.message || e), + }), + ); + } finally { + setLoading(false); + } + }; + + const handleImport = async () => { + if (!previewData) return; + + setImporting(true); + setProgress({ + current: 0, + total: previewData.totalRows, + success: 0, + errors: 0, + }); + setImportResult(null); + + try { + const conn = connections.find((c) => c.id === connectionId); + if (!conn) { + setError(t("import_preview.error.connection_config_not_found")); + setImporting(false); + return; + } + + const config = { + ...conn.config, + port: Number(conn.config.port), + password: conn.config.password || "", + database: conn.config.database || "", + useSSH: conn.config.useSSH || false, + ssh: conn.config.ssh || { + host: "", + port: 22, + user: "", + password: "", + keyPath: "", + }, + }; + + const res = await ImportDataWithProgress( + buildRpcConnectionConfig(config) as any, + dbName, + tableName, + filePath, + ); + + if (res.success && res.data) { + setImportResult(res.data); + if (res.data.failed === 0) { + onSuccess(); } - }, [importing]); + } else { + setError(res.message || t("import_preview.error.import_failed")); + } + } catch (e: any) { + setError( + t("import_preview.error.import_failed_detail", { + detail: String(e?.message || e), + }), + ); + } finally { + setImporting(false); + } + }; - const loadPreview = async () => { - setLoading(true); - setError(null); - try { - const res = await PreviewImportFile(filePath); - if (res.success && res.data) { - setPreviewData({ - columns: res.data.columns || [], - totalRows: res.data.totalRows || 0, - previewRows: res.data.previewRows || [] - }); - } else { - setError(res.message || t('import_preview.error.preview_failed')); - } - } catch (e: any) { - setError(t('import_preview.error.preview_failed_detail', { detail: String(e?.message || e) })); - } finally { - setLoading(false); - } - }; - - const handleImport = async () => { - if (!previewData) return; - - setImporting(true); - setProgress({ current: 0, total: previewData.totalRows, success: 0, errors: 0 }); - setImportResult(null); - - try { - const conn = connections.find(c => c.id === connectionId); - if (!conn) { - setError(t('import_preview.error.connection_config_not_found')); - setImporting(false); - return; - } - - const config = { - ...conn.config, - port: Number(conn.config.port), - password: conn.config.password || '', - database: conn.config.database || '', - useSSH: conn.config.useSSH || false, - ssh: conn.config.ssh || { host: '', port: 22, user: '', password: '', keyPath: '' } - }; - - const res = await ImportDataWithProgress(buildRpcConnectionConfig(config) as any, dbName, tableName, filePath); - - if (res.success && res.data) { - setImportResult(res.data); - if (res.data.failed === 0) { - onSuccess(); - } - } else { - setError(res.message || t('import_preview.error.import_failed')); - } - } catch (e: any) { - setError(t('import_preview.error.import_failed_detail', { detail: String(e?.message || e) })); - } finally { - setImporting(false); - } - }; - - const columns = previewData?.columns.map(col => ({ - title: col, - dataIndex: col, - key: col, - ellipsis: true, - width: 150 + const columns = + previewData?.columns.map((col) => ({ + title: col, + dataIndex: col, + key: col, + ellipsis: true, + width: 150, })) || []; - const progressPercent = progress ? Math.round((progress.current / progress.total) * 100) : 0; + const progressPercent = + progress && progress.total > 0 + ? Math.round((progress.current / progress.total) * 100) + : 0; - return ( - - - - ) : importing ? null : ( - - - - - ) + return ( + + + + ) : importing ? null : ( + + + + + ) + } + > + {error && ( + + )} + + {loading && ( +
+ {t("import_preview.status.loading_preview")} +
+ )} + + {!loading && previewData && !importing && !importResult && ( + <> + +
+ {t("import_preview.preview.field_list")} +
+
+ {previewData.columns.join(", ")} +
+
+ {t("import_preview.preview.table_title")} +
+ + + )} + + {importing && progress && ( +
+
+ {t("import_preview.status.importing")} +
+ +
+ {t("import_preview.progress.processed_rows", { + current: progress.current, + total: progress.total, + })} + + {" "} + {t("import_preview.progress.success_count", { + count: progress.success, + })} + + {progress.errors > 0 && ( + + {" "} + {t("import_preview.progress.error_count", { + count: progress.errors, + })} + + )} +
+
+ )} + + {importResult && ( +
+ +
+ {t("import_preview.result.success_rows", { + count: importResult.success, + })} +
+ {importResult.failed > 0 && ( +
+ {t("import_preview.result.failed_rows", { + count: importResult.failed, + })} +
+ )} +
} - > - {error && } - - {loading &&
{t('import_preview.status.loading_preview')}
} - - {!loading && previewData && !importing && !importResult && ( - <> - -
{t('import_preview.preview.field_list')}
-
- {previewData.columns.join(', ')} -
-
{t('import_preview.preview.table_title')}
-
- - )} - - {importing && progress && ( -
-
- {t('import_preview.status.importing')} -
- -
- {t('import_preview.progress.processed_rows', { current: progress.current, total: progress.total })} - - {t('import_preview.progress.success_count', { count: progress.success })} - - {progress.errors > 0 && ( - - {t('import_preview.progress.error_count', { count: progress.errors })} - - )} -
-
- )} - - {importResult && ( -
- -
{t('import_preview.result.success_rows', { count: importResult.success })}
- {importResult.failed > 0 &&
{t('import_preview.result.failed_rows', { count: importResult.failed })}
} -
- } - showIcon - style={{ marginBottom: 16 }} - /> - {importResult.errorLogs && importResult.errorLogs.length > 0 && ( - <> -
{t('import_preview.result.error_logs')}
-
- {importResult.errorLogs.map((log: string, idx: number) => ( -
{log}
- ))} -
- - )} - - )} - - ); + showIcon + style={{ marginBottom: 16 }} + /> + {importResult.errorLogs && importResult.errorLogs.length > 0 && ( + <> +
+ {t("import_preview.result.error_logs")} +
+
+ {importResult.errorLogs.map((log: string, idx: number) => ( +
+ {log} +
+ ))} +
+ + )} + + )} + + ); }; export default ImportPreviewModal; diff --git a/internal/app/import_pipeline.go b/internal/app/import_pipeline.go new file mode 100644 index 0000000..f1d28ee --- /dev/null +++ b/internal/app/import_pipeline.go @@ -0,0 +1,511 @@ +package app + +import ( + "bufio" + "encoding/csv" + "encoding/json" + "fmt" + "io" + "os" + "sort" + "strings" + + "GoNavi-Wails/internal/connection" + "GoNavi-Wails/internal/db" + "github.com/xuri/excelize/v2" +) + +const ( + defaultImportPreviewLimit = 5 + defaultImportApplyBatchSize = 1000 +) + +type importFileConsumer interface { + SetColumns(columns []string) error + ConsumeRow(row map[string]interface{}) error +} + +type importPreviewData struct { + Columns []string + TotalRows int + PreviewRows []map[string]interface{} +} + +type importProgressState struct { + Current int `json:"current"` + Total int `json:"total,omitempty"` + Success int `json:"success"` + Errors int `json:"errors"` + TotalRowsKnown bool `json:"totalRowsKnown,omitempty"` +} + +type importExecutionResult struct { + Success int + Failed int + Total int + ErrorLogs []string +} + +type importPreviewCollector struct { + columns []string + totalRows int + previewRows []map[string]interface{} + previewLimit int +} + +func newImportPreviewCollector(limit int) *importPreviewCollector { + if limit <= 0 { + limit = defaultImportPreviewLimit + } + return &importPreviewCollector{previewLimit: limit} +} + +func (c *importPreviewCollector) SetColumns(columns []string) error { + c.columns = append([]string(nil), columns...) + return nil +} + +func (c *importPreviewCollector) ConsumeRow(row map[string]interface{}) error { + c.totalRows++ + if len(c.previewRows) < c.previewLimit { + c.previewRows = append(c.previewRows, cloneImportRow(row)) + } + return nil +} + +func (c *importPreviewCollector) Result() importPreviewData { + return importPreviewData{ + Columns: append([]string(nil), c.columns...), + TotalRows: c.totalRows, + PreviewRows: cloneImportRows(c.previewRows), + } +} + +type importCollectConsumer struct { + columns []string + rows []map[string]interface{} +} + +func (c *importCollectConsumer) SetColumns(columns []string) error { + c.columns = append([]string(nil), columns...) + return nil +} + +func (c *importCollectConsumer) ConsumeRow(row map[string]interface{}) error { + c.rows = append(c.rows, cloneImportRow(row)) + return nil +} + +type importRowWriter interface { + SetColumns(columns []string) + ApplyBatch(rows []map[string]interface{}) error + ApplyOne(row map[string]interface{}) error + BatchEnabled() bool +} + +type importDatabaseRowWriter struct { + dbInst db.Database + applier db.BatchApplier + dbType string + tableName string + columns []string + columnTypeMap map[string]string +} + +func newImportDatabaseRowWriter(dbInst db.Database, dbType, tableName string, columnTypeMap map[string]string) *importDatabaseRowWriter { + writer := &importDatabaseRowWriter{ + dbInst: dbInst, + dbType: dbType, + tableName: tableName, + columnTypeMap: columnTypeMap, + } + if applier, ok := dbInst.(db.BatchApplier); ok { + writer.applier = applier + } + return writer +} + +func (w *importDatabaseRowWriter) SetColumns(columns []string) { + w.columns = append([]string(nil), columns...) +} + +func (w *importDatabaseRowWriter) BatchEnabled() bool { + return w.applier != nil +} + +func (w *importDatabaseRowWriter) ApplyBatch(rows []map[string]interface{}) error { + if w.applier == nil { + return fmt.Errorf("当前数据库类型不支持批量提交") + } + return w.applier.ApplyChanges(w.tableName, connection.ChangeSet{Inserts: cloneImportRows(rows)}) +} + +func (w *importDatabaseRowWriter) ApplyOne(row map[string]interface{}) error { + if w.applier != nil { + return w.applier.ApplyChanges(w.tableName, connection.ChangeSet{Inserts: []map[string]interface{}{cloneImportRow(row)}}) + } + query, err := buildImportInsertQuery(w.dbType, w.tableName, w.columns, row, w.columnTypeMap) + if err != nil { + return err + } + _, err = w.dbInst.Exec(query) + return err +} + +type importBatchConsumer struct { + writer importRowWriter + batchSize int + totalRows int + totalRowsKnown bool + report func(importProgressState) + batch []map[string]interface{} + batchStartRow int + currentRow int + successCount int + errorLogs []string +} + +func newImportBatchConsumer(writer importRowWriter, batchSize int, totalRows int, totalRowsKnown bool, report func(importProgressState)) *importBatchConsumer { + if batchSize <= 0 { + batchSize = defaultImportApplyBatchSize + } + return &importBatchConsumer{ + writer: writer, + batchSize: batchSize, + totalRows: totalRows, + totalRowsKnown: totalRowsKnown, + report: report, + } +} + +func (c *importBatchConsumer) SetColumns(columns []string) error { + if c.writer != nil { + c.writer.SetColumns(columns) + } + return nil +} + +func (c *importBatchConsumer) ConsumeRow(row map[string]interface{}) error { + c.currentRow++ + if len(c.batch) == 0 { + c.batchStartRow = c.currentRow + } + c.batch = append(c.batch, cloneImportRow(row)) + if len(c.batch) >= c.batchSize { + return c.flush() + } + return nil +} + +func (c *importBatchConsumer) Flush() error { + return c.flush() +} + +func (c *importBatchConsumer) Result() importExecutionResult { + return importExecutionResult{ + Success: c.successCount, + Failed: len(c.errorLogs), + Total: c.currentRow, + ErrorLogs: append([]string(nil), c.errorLogs...), + } +} + +func (c *importBatchConsumer) flush() error { + if len(c.batch) == 0 { + return nil + } + rows := c.batch + startRow := c.batchStartRow + c.batch = nil + c.batchStartRow = 0 + + if c.writer != nil && c.writer.BatchEnabled() { + if err := c.writer.ApplyBatch(rows); err == nil { + c.successCount += len(rows) + c.emitProgress(startRow + len(rows) - 1) + return nil + } + } + + for idx, row := range rows { + if c.writer != nil { + if err := c.writer.ApplyOne(row); err != nil { + c.errorLogs = append(c.errorLogs, fmt.Sprintf("Row %d: %s", startRow+idx, err.Error())) + } else { + c.successCount++ + } + } + c.emitProgress(startRow + idx) + } + return nil +} + +func (c *importBatchConsumer) emitProgress(current int) { + if c.report == nil { + return + } + c.report(importProgressState{ + Current: current, + Total: c.totalRows, + Success: c.successCount, + Errors: len(c.errorLogs), + TotalRowsKnown: c.totalRowsKnown, + }) +} + +func buildImportPreview(filePath string, previewLimit int) (importPreviewData, error) { + collector := newImportPreviewCollector(previewLimit) + if err := streamImportFile(filePath, collector); err != nil { + return importPreviewData{}, err + } + return collector.Result(), nil +} + +func parseImportFile(filePath string) ([]map[string]interface{}, []string, error) { + collector := &importCollectConsumer{} + if err := streamImportFile(filePath, collector); err != nil { + return nil, nil, err + } + return collector.rows, collector.columns, nil +} + +func streamImportFile(filePath string, consumer importFileConsumer) error { + lower := strings.ToLower(filePath) + switch { + case strings.HasSuffix(lower, ".json"): + return streamJSONImportFile(filePath, consumer) + case strings.HasSuffix(lower, ".csv"): + return streamCSVImportFile(filePath, consumer) + case strings.HasSuffix(lower, ".xlsx"), strings.HasSuffix(lower, ".xls"): + return streamExcelImportFile(filePath, consumer) + default: + return fmt.Errorf("Unsupported file format") + } +} + +func streamJSONImportFile(filePath string, consumer importFileConsumer) error { + f, err := os.Open(filePath) + if err != nil { + return err + } + defer f.Close() + + decoder := json.NewDecoder(bufio.NewReader(f)) + token, err := decoder.Token() + if err != nil { + return fmt.Errorf("JSON Parse Error: %w", err) + } + delim, ok := token.(json.Delim) + if !ok || delim != '[' { + return fmt.Errorf("JSON Parse Error: root array expected") + } + + var columns []string + for decoder.More() { + var raw map[string]interface{} + if err := decoder.Decode(&raw); err != nil { + return fmt.Errorf("JSON Parse Error: %w", err) + } + if columns == nil { + columns = importJSONColumns(raw) + if err := consumer.SetColumns(columns); err != nil { + return err + } + } + if err := consumer.ConsumeRow(normalizeImportMapRow(columns, raw)); err != nil { + return err + } + } + if _, err := decoder.Token(); err != nil { + return fmt.Errorf("JSON Parse Error: %w", err) + } + return nil +} + +func streamCSVImportFile(filePath string, consumer importFileConsumer) error { + f, err := os.Open(filePath) + if err != nil { + return err + } + defer f.Close() + + reader := csv.NewReader(bufio.NewReader(f)) + reader.ReuseRecord = true + + header, err := reader.Read() + if err != nil { + if err == io.EOF { + return fmt.Errorf("CSV empty or missing header") + } + return fmt.Errorf("CSV Parse Error: %w", err) + } + columns := cloneImportColumns(header) + if !hasImportUsableColumns(columns) { + return fmt.Errorf("CSV empty or missing header") + } + if err := consumer.SetColumns(columns); err != nil { + return err + } + + for { + record, err := reader.Read() + if err != nil { + if err == io.EOF { + return nil + } + return fmt.Errorf("CSV Parse Error: %w", err) + } + if err := consumer.ConsumeRow(buildImportRowFromValues(columns, record)); err != nil { + return err + } + } +} + +func streamExcelImportFile(filePath string, consumer importFileConsumer) error { + workbook, err := excelize.OpenFile(filePath) + if err != nil { + return fmt.Errorf("Excel Parse Error: %w", err) + } + defer workbook.Close() + + sheetName := workbook.GetSheetName(0) + if sheetName == "" { + return fmt.Errorf("Excel file has no sheets") + } + + rows, err := workbook.Rows(sheetName) + if err != nil { + return fmt.Errorf("Excel Read Error: %w", err) + } + defer rows.Close() + + if !rows.Next() { + if err := rows.Error(); err != nil { + return fmt.Errorf("Excel Read Error: %w", err) + } + return fmt.Errorf("Excel empty or missing header") + } + header, err := rows.Columns() + if err != nil { + return fmt.Errorf("Excel Read Error: %w", err) + } + columns := cloneImportColumns(header) + if !hasImportUsableColumns(columns) { + return fmt.Errorf("Excel empty or missing header") + } + if err := consumer.SetColumns(columns); err != nil { + return err + } + + for rows.Next() { + record, err := rows.Columns() + if err != nil { + return fmt.Errorf("Excel Read Error: %w", err) + } + if err := consumer.ConsumeRow(buildImportRowFromValues(columns, record)); err != nil { + return err + } + } + if err := rows.Error(); err != nil { + return fmt.Errorf("Excel Read Error: %w", err) + } + return nil +} + +func buildImportInsertQuery(dbType, tableName string, columns []string, row map[string]interface{}, columnTypeMap map[string]string) (string, error) { + quotedCols := make([]string, 0, len(columns)) + values := make([]string, 0, len(columns)) + for _, column := range columns { + if strings.TrimSpace(column) == "" { + continue + } + quotedCols = append(quotedCols, quoteIdentByType(dbType, column)) + colType := columnTypeMap[normalizeColumnName(column)] + values = append(values, formatImportSQLValue(dbType, colType, row[column])) + } + if len(quotedCols) == 0 { + return "", fmt.Errorf("导入文件缺少有效列头") + } + return fmt.Sprintf("INSERT INTO %s (%s) VALUES (%s)", + quoteQualifiedIdentByType(dbType, tableName), + strings.Join(quotedCols, ", "), + strings.Join(values, ", ")), nil +} + +func importJSONColumns(row map[string]interface{}) []string { + columns := make([]string, 0, len(row)) + for key := range row { + if strings.TrimSpace(key) == "" { + continue + } + columns = append(columns, key) + } + sort.Strings(columns) + return columns +} + +func cloneImportColumns(raw []string) []string { + return append([]string(nil), raw...) +} + +func hasImportUsableColumns(columns []string) bool { + for _, column := range columns { + if strings.TrimSpace(column) != "" { + return true + } + } + return false +} + +func buildImportRowFromValues(columns []string, values []string) map[string]interface{} { + row := make(map[string]interface{}, len(columns)) + for idx, column := range columns { + if strings.TrimSpace(column) == "" { + continue + } + if idx >= len(values) { + row[column] = nil + continue + } + if values[idx] == "NULL" { + row[column] = nil + continue + } + row[column] = values[idx] + } + return row +} + +func normalizeImportMapRow(columns []string, raw map[string]interface{}) map[string]interface{} { + row := make(map[string]interface{}, len(columns)) + for _, column := range columns { + if value, ok := raw[column]; ok { + row[column] = value + continue + } + row[column] = nil + } + return row +} + +func cloneImportRow(row map[string]interface{}) map[string]interface{} { + if row == nil { + return nil + } + cloned := make(map[string]interface{}, len(row)) + for key, value := range row { + cloned[key] = value + } + return cloned +} + +func cloneImportRows(rows []map[string]interface{}) []map[string]interface{} { + if len(rows) == 0 { + return nil + } + cloned := make([]map[string]interface{}, 0, len(rows)) + for _, row := range rows { + cloned = append(cloned, cloneImportRow(row)) + } + return cloned +} diff --git a/internal/app/methods_file.go b/internal/app/methods_file.go index 4df9f1a..2bb306b 100644 --- a/internal/app/methods_file.go +++ b/internal/app/methods_file.go @@ -1668,21 +1668,15 @@ func (a *App) PreviewImportFile(filePath string) connection.QueryResult { return connection.QueryResult{Success: false, Message: "文件路径不能为空"} } - rows, columns, err := parseImportFile(filePath) + preview, err := buildImportPreview(filePath, defaultImportPreviewLimit) if err != nil { return connection.QueryResult{Success: false, Message: err.Error()} } - totalRows := len(rows) - previewRows := rows - if len(rows) > 5 { - previewRows = rows[:5] - } - result := map[string]interface{}{ - "columns": columns, - "totalRows": totalRows, - "previewRows": previewRows, + "columns": preview.Columns, + "totalRows": preview.TotalRows, + "previewRows": preview.PreviewRows, "filePath": filePath, } @@ -1712,98 +1706,6 @@ func (a *App) ImportData(config connection.ConnectionConfig, dbName, tableName s return connection.QueryResult{Success: true, Data: map[string]interface{}{"filePath": selection}} } -// parseImportFile 解析导入文件,返回数据行和列名 -func parseImportFile(filePath string) ([]map[string]interface{}, []string, error) { - var rows []map[string]interface{} - var columns []string - lower := strings.ToLower(filePath) - - if strings.HasSuffix(lower, ".json") { - f, err := os.Open(filePath) - if err != nil { - return nil, nil, err - } - defer f.Close() - decoder := json.NewDecoder(f) - if err := decoder.Decode(&rows); err != nil { - return nil, nil, fmt.Errorf("JSON Parse Error: %w", err) - } - if len(rows) > 0 { - for k := range rows[0] { - columns = append(columns, k) - } - } - } else if strings.HasSuffix(lower, ".csv") { - f, err := os.Open(filePath) - if err != nil { - return nil, nil, err - } - defer f.Close() - reader := csv.NewReader(f) - records, err := reader.ReadAll() - if err != nil { - return nil, nil, fmt.Errorf("CSV Parse Error: %w", err) - } - if len(records) < 2 { - return nil, nil, fmt.Errorf("CSV empty or missing header") - } - columns = records[0] - for _, record := range records[1:] { - row := make(map[string]interface{}) - for i, val := range record { - if i < len(columns) { - if val == "NULL" { - row[columns[i]] = nil - } else { - row[columns[i]] = val - } - } - } - rows = append(rows, row) - } - } else if strings.HasSuffix(lower, ".xlsx") || strings.HasSuffix(lower, ".xls") { - xlsx, err := excelize.OpenFile(filePath) - if err != nil { - return nil, nil, fmt.Errorf("Excel Parse Error: %w", err) - } - defer xlsx.Close() - - sheetName := xlsx.GetSheetName(0) - if sheetName == "" { - return nil, nil, fmt.Errorf("Excel file has no sheets") - } - - xlRows, err := xlsx.GetRows(sheetName) - if err != nil { - return nil, nil, fmt.Errorf("Excel Read Error: %w", err) - } - if len(xlRows) < 2 { - return nil, nil, fmt.Errorf("Excel empty or missing header") - } - - columns = xlRows[0] - for _, record := range xlRows[1:] { - row := make(map[string]interface{}) - for i, val := range record { - if i < len(columns) && columns[i] != "" { - if val == "NULL" { - row[columns[i]] = nil - } else { - row[columns[i]] = val - } - } - } - if len(row) > 0 { - rows = append(rows, row) - } - } - } else { - return nil, nil, fmt.Errorf("Unsupported file format") - } - - return rows, columns, nil -} - func normalizeColumnName(name string) string { return strings.ToLower(strings.TrimSpace(name)) } @@ -2125,15 +2027,6 @@ func formatImportSQLValue(dbType, columnType string, value interface{}) string { // ImportDataWithProgress 执行导入并发送进度事件 func (a *App) ImportDataWithProgress(config connection.ConnectionConfig, dbName, tableName, filePath string) connection.QueryResult { - rows, columns, err := parseImportFile(filePath) - if err != nil { - return connection.QueryResult{Success: false, Message: err.Error()} - } - - if len(rows) == 0 { - return connection.QueryResult{Success: true, Message: "无可导入数据"} - } - runConfig := normalizeRunConfig(config, dbName) dbInst, err := a.getDatabase(runConfig) if err != nil { @@ -2147,55 +2040,31 @@ func (a *App) ImportDataWithProgress(config connection.ConnectionConfig, dbName, columnTypeMap = buildImportColumnTypeMap(defs) } - totalRows := len(rows) - successCount := 0 - var errorLogs []string - - quotedCols := make([]string, len(columns)) - for i, c := range columns { - quotedCols[i] = quoteIdentByType(dbType, c) + writer := newImportDatabaseRowWriter(dbInst, dbType, tableName, columnTypeMap) + consumer := newImportBatchConsumer(writer, defaultImportApplyBatchSize, 0, false, func(state importProgressState) { + runtime.EventsEmit(a.ctx, "import:progress", state) + }) + if err := streamImportFile(filePath, consumer); err != nil { + return connection.QueryResult{Success: false, Message: err.Error()} + } + if err := consumer.Flush(); err != nil { + return connection.QueryResult{Success: false, Message: err.Error()} } - for idx, row := range rows { - var values []string - for _, col := range columns { - val := row[col] - colType := columnTypeMap[normalizeColumnName(col)] - values = append(values, formatImportSQLValue(dbType, colType, val)) - } - - query := fmt.Sprintf("INSERT INTO %s (%s) VALUES (%s)", - quoteQualifiedIdentByType(dbType, tableName), - strings.Join(quotedCols, ", "), - strings.Join(values, ", ")) - - _, err := dbInst.Exec(query) - if err != nil { - errorLogs = append(errorLogs, fmt.Sprintf("Row %d: %s", idx+1, err.Error())) - } else { - successCount++ - } - - // 每 10 行发送一次进度事件 - if (idx+1)%10 == 0 || idx == totalRows-1 { - runtime.EventsEmit(a.ctx, "import:progress", map[string]interface{}{ - "current": idx + 1, - "total": totalRows, - "success": successCount, - "errors": len(errorLogs), - }) - } + resultData := consumer.Result() + if resultData.Total == 0 { + return connection.QueryResult{Success: true, Message: "无可导入数据"} } result := map[string]interface{}{ - "success": successCount, - "failed": len(errorLogs), - "total": totalRows, - "errorLogs": errorLogs, - "errorSummary": fmt.Sprintf("Imported: %d, Failed: %d", successCount, len(errorLogs)), + "success": resultData.Success, + "failed": resultData.Failed, + "total": resultData.Total, + "errorLogs": resultData.ErrorLogs, + "errorSummary": fmt.Sprintf("Imported: %d, Failed: %d", resultData.Success, resultData.Failed), } - return connection.QueryResult{Success: true, Data: result, Message: fmt.Sprintf("Imported: %d, Failed: %d", successCount, len(errorLogs))} + return connection.QueryResult{Success: true, Data: result, Message: fmt.Sprintf("Imported: %d, Failed: %d", resultData.Success, resultData.Failed)} } func (a *App) ApplyChanges(config connection.ConnectionConfig, dbName, tableName string, changes connection.ChangeSet) connection.QueryResult { diff --git a/internal/app/methods_file_import_test.go b/internal/app/methods_file_import_test.go index d2b13bc..b23f278 100644 --- a/internal/app/methods_file_import_test.go +++ b/internal/app/methods_file_import_test.go @@ -2,8 +2,11 @@ package app import ( "errors" + "fmt" "os" "path/filepath" + "reflect" + "strings" "testing" ) @@ -31,3 +34,146 @@ func TestReadImportedConnectionConfigFileRejectsOversizedFiles(t *testing.T) { }) } } + +func TestBuildImportPreviewCSVStreamKeepsFirstFiveRows(t *testing.T) { + dir := t.TempDir() + path := filepath.Join(dir, "users.csv") + var builder strings.Builder + builder.WriteString("id,name\n") + for i := 1; i <= 7; i++ { + builder.WriteString(fmt.Sprintf("%d,user_%d\n", i, i)) + } + if err := os.WriteFile(path, []byte(builder.String()), 0o600); err != nil { + t.Fatalf("write csv: %v", err) + } + + preview, err := buildImportPreview(path, 5) + if err != nil { + t.Fatalf("buildImportPreview returned error: %v", err) + } + + if !reflect.DeepEqual(preview.Columns, []string{"id", "name"}) { + t.Fatalf("unexpected columns: %#v", preview.Columns) + } + if preview.TotalRows != 7 { + t.Fatalf("expected 7 rows, got %d", preview.TotalRows) + } + if len(preview.PreviewRows) != 5 { + t.Fatalf("expected 5 preview rows, got %d", len(preview.PreviewRows)) + } + if got := preview.PreviewRows[0]["name"]; got != "user_1" { + t.Fatalf("expected first preview row name user_1, got %#v", got) + } + if got := preview.PreviewRows[4]["id"]; got != "5" { + t.Fatalf("expected fifth preview row id 5, got %#v", got) + } +} + +func TestBuildImportRowFromValuesPreservesPositionsWhenHeaderContainsBlankColumns(t *testing.T) { + row := buildImportRowFromValues([]string{"id", "", "name"}, []string{"1", "ignored", "alice"}) + if got := row["id"]; got != "1" { + t.Fatalf("expected id to stay aligned, got %#v", got) + } + if got := row["name"]; got != "alice" { + t.Fatalf("expected name to stay aligned, got %#v", got) + } + if _, ok := row[""]; ok { + t.Fatal("blank header column should not be written into row map") + } +} + +type fakeImportRowWriter struct { + columns []string + batchCalls int + singleCalls int + batchSizes []int + batchErr error + singleErrByRowID map[interface{}]error +} + +func (w *fakeImportRowWriter) SetColumns(columns []string) { + w.columns = append([]string(nil), columns...) +} + +func (w *fakeImportRowWriter) ApplyBatch(rows []map[string]interface{}) error { + w.batchCalls++ + w.batchSizes = append(w.batchSizes, len(rows)) + return w.batchErr +} + +func (w *fakeImportRowWriter) ApplyOne(row map[string]interface{}) error { + w.singleCalls++ + if err, ok := w.singleErrByRowID[row["id"]]; ok { + return err + } + return nil +} + +func (w *fakeImportRowWriter) BatchEnabled() bool { + return true +} + +func TestImportBatchConsumerUsesBatchWriterInConfiguredBatches(t *testing.T) { + writer := &fakeImportRowWriter{} + consumer := newImportBatchConsumer(writer, 1000, 1201, true, nil) + if err := consumer.SetColumns([]string{"id"}); err != nil { + t.Fatalf("SetColumns returned error: %v", err) + } + for i := 1; i <= 1201; i++ { + if err := consumer.ConsumeRow(map[string]interface{}{"id": i}); err != nil { + t.Fatalf("ConsumeRow(%d) returned error: %v", i, err) + } + } + if err := consumer.Flush(); err != nil { + t.Fatalf("Flush returned error: %v", err) + } + + if writer.batchCalls != 2 { + t.Fatalf("expected 2 batch calls, got %d", writer.batchCalls) + } + if !reflect.DeepEqual(writer.batchSizes, []int{1000, 201}) { + t.Fatalf("unexpected batch sizes: %#v", writer.batchSizes) + } + result := consumer.Result() + if result.Success != 1201 || result.Failed != 0 || result.Total != 1201 { + t.Fatalf("unexpected result: %#v", result) + } + if writer.singleCalls != 0 { + t.Fatalf("expected no single-row fallback, got %d calls", writer.singleCalls) + } +} + +func TestImportBatchConsumerFallsBackToSingleRowsWhenBatchFails(t *testing.T) { + writer := &fakeImportRowWriter{ + batchErr: fmt.Errorf("batch failed"), + singleErrByRowID: map[interface{}]error{ + 2: fmt.Errorf("duplicate key"), + }, + } + consumer := newImportBatchConsumer(writer, 1000, 3, true, nil) + if err := consumer.SetColumns([]string{"id"}); err != nil { + t.Fatalf("SetColumns returned error: %v", err) + } + for i := 1; i <= 3; i++ { + if err := consumer.ConsumeRow(map[string]interface{}{"id": i}); err != nil { + t.Fatalf("ConsumeRow(%d) returned error: %v", i, err) + } + } + if err := consumer.Flush(); err != nil { + t.Fatalf("Flush returned error: %v", err) + } + + result := consumer.Result() + if result.Success != 2 || result.Failed != 1 || result.Total != 3 { + t.Fatalf("unexpected result: %#v", result) + } + if writer.batchCalls != 1 { + t.Fatalf("expected 1 batch call, got %d", writer.batchCalls) + } + if writer.singleCalls != 3 { + t.Fatalf("expected 3 single-row fallback calls, got %d", writer.singleCalls) + } + if len(result.ErrorLogs) != 1 || result.ErrorLogs[0] != "Row 2: duplicate key" { + t.Fatalf("unexpected error logs: %#v", result.ErrorLogs) + } +}