diff --git a/frontend/src/components/useExportProgressRunner.test.tsx b/frontend/src/components/useExportProgressRunner.test.tsx index 1d17a8e..02ea5c3 100644 --- a/frontend/src/components/useExportProgressRunner.test.tsx +++ b/frontend/src/components/useExportProgressRunner.test.tsx @@ -177,4 +177,67 @@ describe('useExportProgressRunner', () => { expect(runner?.state.status).toBe('done'); expect(runner?.state.totalRowsKnown).toBe(false); }); + + it('switches to exact progress when backend start events later provide total rows', async () => { + renderRunner(); + + let resolveRun!: (value: { success: boolean; message: string }) => void; + const pendingRun = new Promise<{ success: boolean; message: string }>((resolve) => { + resolveRun = resolve; + }); + + let runPromise: Promise<{ success: boolean; message: string } | null> | null = null; + await act(async () => { + runPromise = runner?.runExportWithProgress({ + title: '导出 SYS.test', + targetName: 'SYS.test', + format: 'xlsx', + run: async () => pendingRun, + }) || null; + await Promise.resolve(); + }); + + expect(runner?.state.totalRowsKnown).toBe(false); + expect(runner?.state.total).toBe(0); + + const jobId = runner?.state.jobId || ''; + now = 4_000; + act(() => { + runtimeApi.emitExportProgress({ + jobId, + status: 'start', + stage: '正在准备导出', + total: 96000, + totalRowsKnown: true, + filePath: '/Users/yangguofeng/Desktop/SYS.test.xlsx', + }); + }); + + expect(runner?.state.totalRowsKnown).toBe(true); + expect(runner?.state.total).toBe(96000); + expect(runner?.state.filePath).toBe('/Users/yangguofeng/Desktop/SYS.test.xlsx'); + + act(() => { + runtimeApi.emitExportProgress({ + jobId, + status: 'running', + stage: '正在写入文件', + current: 24000, + }); + }); + + expect(runner?.state.current).toBe(24000); + expect(runner?.state.total).toBe(96000); + expect(runner?.state.totalRowsKnown).toBe(true); + + now = 8_000; + await act(async () => { + resolveRun({ success: true, message: '导出完成' }); + await runPromise; + }); + + expect(runner?.state.status).toBe('done'); + expect(runner?.state.total).toBe(96000); + expect(runner?.state.totalRowsKnown).toBe(true); + }); }); diff --git a/internal/app/methods_file.go b/internal/app/methods_file.go index 6bcc206..3640844 100644 --- a/internal/app/methods_file.go +++ b/internal/app/methods_file.go @@ -230,6 +230,118 @@ func (r *exportProgressReporter) Error(current int64, message string) { r.emit("error", "导出失败", current, message, true) } +func resolveExportTotalRowValue(value interface{}) (int64, bool) { + switch v := value.(type) { + case int: + if v < 0 { + return 0, false + } + return int64(v), true + case int8: + if v < 0 { + return 0, false + } + return int64(v), true + case int16: + if v < 0 { + return 0, false + } + return int64(v), true + case int32: + if v < 0 { + return 0, false + } + return int64(v), true + case int64: + if v < 0 { + return 0, false + } + return v, true + case uint: + if uint64(v) > math.MaxInt64 { + return 0, false + } + return int64(v), true + case uint8: + return int64(v), true + case uint16: + return int64(v), true + case uint32: + return int64(v), true + case uint64: + if v > math.MaxInt64 { + return 0, false + } + return int64(v), true + case float32: + if !isFiniteFloat64(float64(v)) || v < 0 { + return 0, false + } + return int64(v), true + case float64: + if !isFiniteFloat64(v) || v < 0 { + return 0, false + } + return int64(v), true + case json.Number: + if i, err := v.Int64(); err == nil && i >= 0 { + return i, true + } + if f, err := v.Float64(); err == nil && isFiniteFloat64(f) && f >= 0 { + return int64(f), true + } + case []byte: + return resolveExportTotalRowValue(string(v)) + case string: + text := strings.TrimSpace(v) + if text == "" { + return 0, false + } + if i, err := strconv.ParseInt(text, 10, 64); err == nil && i >= 0 { + return i, true + } + if f, err := strconv.ParseFloat(text, 64); err == nil && isFiniteFloat64(f) && f >= 0 { + return int64(f), true + } + } + return 0, false +} + +func isFiniteFloat64(value float64) bool { + return !math.IsNaN(value) && !math.IsInf(value, 0) +} + +func resolveExportTotalRowsFromRows(rows []map[string]interface{}) (int64, bool) { + if len(rows) == 0 || rows[0] == nil { + return 0, false + } + row := rows[0] + preferredKeys := []string{"total", "TOTAL", "count", "COUNT", "cnt", "CNT", "table_rows", "TABLE_ROWS"} + for _, key := range preferredKeys { + if value, ok := row[key]; ok { + if total, ok := resolveExportTotalRowValue(value); ok { + return total, true + } + } + } + for _, value := range row { + if total, ok := resolveExportTotalRowValue(value); ok { + return total, true + } + } + return 0, false +} + +func tryResolveExportTableTotalRows(dbInst db.Database, config connection.ConnectionConfig, tableName string) (int64, bool) { + dbType := resolveDDLDBType(config) + query := fmt.Sprintf("SELECT COUNT(*) AS total FROM %s", quoteQualifiedIdentByType(dbType, tableName)) + rows, _, err := queryDataForExport(dbInst, config, query) + if err != nil { + return 0, false + } + return resolveExportTotalRowsFromRows(rows) +} + var exportFileNameSanitizer = strings.NewReplacer( "/", "_", "\\", "_", @@ -2156,6 +2268,18 @@ func (a *App) ExportTableWithOptions(config connection.ConnectionConfig, dbName return connection.QueryResult{Success: false, Message: err.Error()} } + if format != "sql" && !options.TotalRowsKnown { + if totalRows, ok := tryResolveExportTableTotalRows(dbInst, runConfig, tableName); ok { + options.TotalRowsHint = totalRows + options.TotalRowsKnown = true + if reporter != nil { + reporter.totalRows = totalRows + reporter.totalRowsKnown = true + reporter.Start("正在准备导出") + } + } + } + if format == "sql" { reporter.Start("正在导出 SQL 文件") f, err := os.Create(filename) diff --git a/internal/app/methods_file_export_test.go b/internal/app/methods_file_export_test.go index 831f704..0bdf07f 100644 --- a/internal/app/methods_file_export_test.go +++ b/internal/app/methods_file_export_test.go @@ -441,6 +441,40 @@ func TestQueryDataForExport_UsesLargerConfiguredTimeout(t *testing.T) { } } +func TestResolveExportTotalRowsFromRows_PrefersNamedTotalColumn(t *testing.T) { + total, ok := resolveExportTotalRowsFromRows([]map[string]interface{}{ + {"COUNT": "96000", "other": 1}, + }) + if !ok { + t.Fatal("应成功解析导出总行数") + } + if total != 96000 { + t.Fatalf("解析导出总行数错误,want=%d got=%d", 96000, total) + } +} + +func TestTryResolveExportTableTotalRows_UsesCountQuery(t *testing.T) { + fake := &fakeExportQueryDB{ + data: []map[string]interface{}{{"total": int64(128000)}}, + cols: []string{"total"}, + } + + total, ok := tryResolveExportTableTotalRows( + fake, + connection.ConnectionConfig{Type: "mysql", Timeout: 10}, + "SYS.test", + ) + if !ok { + t.Fatal("应成功解析整表导出总行数") + } + if total != 128000 { + t.Fatalf("整表导出总行数错误,want=%d got=%d", 128000, total) + } + if fake.lastQuery != "SELECT COUNT(*) AS total FROM `SYS`.`test`" { + t.Fatalf("整表导出统计 SQL 错误,got=%q", fake.lastQuery) + } +} + func TestExportQueryResultToFile_UsesStreamQueryPath(t *testing.T) { f, err := os.CreateTemp("", "gonavi-export-stream-*.csv") if err != nil {