diff --git a/internal/app/export_options.go b/internal/app/export_options.go new file mode 100644 index 0000000..963b639 --- /dev/null +++ b/internal/app/export_options.go @@ -0,0 +1,50 @@ +package app + +import "strings" + +const ( + maxXLSXRowsPerSheet = 1048575 + defaultXLSXRowsPerSheet = maxXLSXRowsPerSheet +) + +type ExportFileOptions struct { + Format string `json:"format"` + XLSXMaxRowsPerSheet int `json:"xlsxMaxRowsPerSheet,omitempty"` + JobID string `json:"jobId,omitempty"` + TotalRowsHint int64 `json:"totalRowsHint,omitempty"` + TotalRowsKnown bool `json:"totalRowsKnown,omitempty"` +} + +func normalizeExportFileOptions(format string, options ExportFileOptions) ExportFileOptions { + resolvedFormat := strings.ToLower(strings.TrimSpace(format)) + if explicitFormat := strings.ToLower(strings.TrimSpace(options.Format)); explicitFormat != "" { + resolvedFormat = explicitFormat + } + return ExportFileOptions{ + Format: resolvedFormat, + XLSXMaxRowsPerSheet: normalizeXLSXRowsPerSheet(options.XLSXMaxRowsPerSheet), + JobID: strings.TrimSpace(options.JobID), + TotalRowsHint: normalizeExportTotalRowsHint(options.TotalRowsHint, options.TotalRowsKnown), + TotalRowsKnown: options.TotalRowsKnown, + } +} + +func normalizeXLSXRowsPerSheet(value int) int { + if value <= 0 { + return defaultXLSXRowsPerSheet + } + if value > maxXLSXRowsPerSheet { + return maxXLSXRowsPerSheet + } + return value +} + +func normalizeExportTotalRowsHint(value int64, known bool) int64 { + if !known { + return 0 + } + if value < 0 { + return 0 + } + return value +} diff --git a/internal/app/methods_file.go b/internal/app/methods_file.go index c8390ef..3a78af6 100644 --- a/internal/app/methods_file.go +++ b/internal/app/methods_file.go @@ -36,6 +36,9 @@ const sqlFileBatchMaxStatements = 1000 const sqlFileBatchMaxBytes = 4 * 1024 * 1024 const sqlFileProgressStatementInterval = 100 const sqlFileProgressTimeInterval = time.Second +const exportProgressEvent = "export:progress" +const exportProgressRowInterval int64 = 1000 +const exportProgressTimeInterval = 500 * time.Millisecond const defaultAppLogTailLineLimit = 80 const maxAppLogTailLineLimit = 200 const appLogTailReadWindowBytes int64 = 256 * 1024 @@ -89,6 +92,31 @@ type SQLDirectoryEntry struct { Children []SQLDirectoryEntry `json:"children,omitempty"` } +type exportProgressPayload struct { + JobID string `json:"jobId"` + Status string `json:"status"` + Stage string `json:"stage"` + Current int64 `json:"current"` + Total int64 `json:"total,omitempty"` + TotalRowsKnown bool `json:"totalRowsKnown,omitempty"` + Format string `json:"format,omitempty"` + TargetName string `json:"targetName,omitempty"` + FilePath string `json:"filePath,omitempty"` + Message string `json:"message,omitempty"` +} + +type exportProgressReporter struct { + app *App + jobID string + format string + targetName string + filePath string + totalRows int64 + totalRowsKnown bool + lastRows int64 + lastEmittedAt time.Time +} + type appLogTailSnapshot struct { LogPath string `json:"logPath"` Keyword string `json:"keyword,omitempty"` @@ -125,6 +153,73 @@ func normalizeSQLDirectoryName(rawName string) (string, error) { return name, nil } +func newExportProgressReporter(a *App, options ExportFileOptions, targetName string, filePath string) *exportProgressReporter { + jobID := strings.TrimSpace(options.JobID) + if a == nil || a.ctx == nil || jobID == "" { + return nil + } + return &exportProgressReporter{ + app: a, + jobID: jobID, + format: strings.ToLower(strings.TrimSpace(options.Format)), + targetName: strings.TrimSpace(targetName), + filePath: strings.TrimSpace(filePath), + totalRows: normalizeExportTotalRowsHint(options.TotalRowsHint, options.TotalRowsKnown), + totalRowsKnown: options.TotalRowsKnown, + } +} + +func (r *exportProgressReporter) emit(status string, stage string, current int64, message string, force bool) { + if r == nil || r.app == nil || r.app.ctx == nil || r.jobID == "" { + return + } + now := time.Now() + if !force && status == "running" { + if current-r.lastRows < exportProgressRowInterval && (!r.lastEmittedAt.IsZero() && now.Sub(r.lastEmittedAt) < exportProgressTimeInterval) { + return + } + } + payload := exportProgressPayload{ + JobID: r.jobID, + Status: strings.TrimSpace(status), + Stage: strings.TrimSpace(stage), + Current: current, + Total: r.totalRows, + TotalRowsKnown: r.totalRowsKnown, + Format: r.format, + TargetName: r.targetName, + FilePath: r.filePath, + Message: strings.TrimSpace(message), + } + runtime.EventsEmit(r.app.ctx, exportProgressEvent, payload) + r.lastRows = current + r.lastEmittedAt = now +} + +func (r *exportProgressReporter) Start(stage string) { + r.emit("start", stage, 0, "", true) +} + +func (r *exportProgressReporter) Rows(current int64, stage string) { + r.emit("running", stage, current, "", false) +} + +func (r *exportProgressReporter) ForceRunning(current int64, stage string) { + r.emit("running", stage, current, "", true) +} + +func (r *exportProgressReporter) Finalizing(current int64) { + r.emit("finalizing", "正在完成文件写入", current, "", true) +} + +func (r *exportProgressReporter) Done(current int64) { + r.emit("done", "导出完成", current, "", true) +} + +func (r *exportProgressReporter) Error(current int64, message string) { + r.emit("error", "导出失败", current, message, true) +} + func normalizeSQLDirectoryPath(directoryPath string) (string, error) { target := strings.TrimSpace(directoryPath) if target == "" { @@ -1730,6 +1825,55 @@ func parseTemporalString(raw string) (time.Time, bool) { return time.Time{}, false } +func looksLikeTemporalText(raw string) bool { + text := strings.TrimSpace(raw) + if text == "" { + return false + } + + if len(text) >= 10 && + isDigit(text[0]) && + isDigit(text[1]) && + isDigit(text[2]) && + isDigit(text[3]) && + text[4] == '-' && + isDigit(text[5]) && + isDigit(text[6]) && + text[7] == '-' && + isDigit(text[8]) && + isDigit(text[9]) { + return true + } + + if len(text) >= 8 && + isDigit(text[0]) && + isDigit(text[1]) && + text[2] == ':' && + isDigit(text[3]) && + isDigit(text[4]) && + text[5] == ':' && + isDigit(text[6]) && + isDigit(text[7]) { + return true + } + + return false +} + +func isDigit(ch byte) bool { + return ch >= '0' && ch <= '9' +} + +func normalizeExportTemporalText(text string) string { + if !looksLikeTemporalText(text) { + return text + } + if parsed, ok := parseTemporalString(text); ok { + return parsed.Format("2006-01-02 15:04:05") + } + return text +} + func normalizeImportTemporalValue(dbType, columnType, raw string) string { text := strings.TrimSpace(raw) if text == "" { @@ -2019,6 +2163,12 @@ func (a *App) PreviewChanges(config connection.ConnectionConfig, dbName, tableNa } func (a *App) ExportTable(config connection.ConnectionConfig, dbName string, tableName string, format string) connection.QueryResult { + return a.ExportTableWithOptions(config, dbName, tableName, ExportFileOptions{Format: format}) +} + +func (a *App) ExportTableWithOptions(config connection.ConnectionConfig, dbName string, tableName string, options ExportFileOptions) connection.QueryResult { + options = normalizeExportFileOptions("", options) + format := options.Format filename, err := runtime.SaveFileDialog(a.ctx, runtime.SaveDialogOptions{ Title: fmt.Sprintf("Export %s", tableName), DefaultFilename: fmt.Sprintf("%s.%s", tableName, format), @@ -2028,17 +2178,21 @@ func (a *App) ExportTable(config connection.ConnectionConfig, dbName string, tab return connection.QueryResult{Success: false, Message: "已取消"} } + reporter := newExportProgressReporter(a, options, tableName, filename) + reporter.Start("正在准备导出") runConfig := normalizeRunConfig(config, dbName) dbInst, err := a.getDatabase(runConfig) if err != nil { + reporter.Error(0, err.Error()) return connection.QueryResult{Success: false, Message: err.Error()} } - format = strings.ToLower(format) if format == "sql" { + reporter.Start("正在导出 SQL 文件") f, err := os.Create(filename) if err != nil { + reporter.Error(0, err.Error()) return connection.QueryResult{Success: false, Message: err.Error()} } defer f.Close() @@ -2047,35 +2201,39 @@ func (a *App) ExportTable(config connection.ConnectionConfig, dbName string, tab defer w.Flush() if err := writeSQLHeader(w, runConfig, dbName); err != nil { + reporter.Error(0, err.Error()) return connection.QueryResult{Success: false, Message: err.Error()} } viewLookup := listViewNameLookup(dbInst, runConfig, dbName) if err := dumpTableSQL(w, dbInst, runConfig, dbName, tableName, true, true, viewLookup); err != nil { + reporter.Error(0, err.Error()) return connection.QueryResult{Success: false, Message: err.Error()} } if err := writeSQLFooter(w, runConfig); err != nil { + reporter.Error(0, err.Error()) return connection.QueryResult{Success: false, Message: err.Error()} } + reporter.Finalizing(0) + reporter.Done(0) return connection.QueryResult{Success: true, Message: "导出完成"} } dbType := resolveDDLDBType(config) query := fmt.Sprintf("SELECT * FROM %s", quoteQualifiedIdentByType(dbType, tableName)) - data, columns, err := queryDataForExport(dbInst, runConfig, query) - if err != nil { - return connection.QueryResult{Success: false, Message: err.Error()} - } - f, err := os.Create(filename) if err != nil { + reporter.Error(0, err.Error()) return connection.QueryResult{Success: false, Message: err.Error()} } defer f.Close() - if err := writeRowsToFile(f, data, columns, format); err != nil { + rowCount, _, err := exportQueryResultToFile(f, dbInst, runConfig, query, options, reporter) + if err != nil { + reporter.Error(rowCount, "写入失败:"+err.Error()) return connection.QueryResult{Success: false, Message: "写入失败:" + err.Error()} } + reporter.Done(rowCount) return connection.QueryResult{Success: true, Message: "导出完成"} } @@ -3181,45 +3339,44 @@ func dumpTableSQL( qualified := qualifyTable(schemaName, pureTableName) dbType := resolveDDLDBType(config) selectSQL := fmt.Sprintf("SELECT * FROM %s", quoteQualifiedIdentByType(dbType, qualified)) - data, columns, err := queryDataForExport(dbInst, config, selectSQL) - if err != nil { - return err - } columnTypeMap := map[string]string{} if defs, colErr := dbInst.GetColumns(schemaName, pureTableName); colErr == nil { columnTypeMap = buildImportColumnTypeMap(defs) } - if len(data) == 0 { + insertConsumer := &sqlInsertExportConsumer{ + w: w, + dbType: dbType, + quotedTable: quoteQualifiedIdentByType(dbType, qualified), + columnTypeMap: columnTypeMap, + } + if err := streamQueryDataForExport(dbInst, config, selectSQL, insertConsumer); err != nil { + return err + } + if insertConsumer.rowCount == 0 { if _, err := w.WriteString("-- (0 rows)\n"); err != nil { return err } return nil } - quotedCols := make([]string, 0, len(columns)) - for _, c := range columns { - quotedCols = append(quotedCols, quoteIdentByType(dbType, c)) - } - quotedTable := quoteQualifiedIdentByType(dbType, qualified) - - for _, row := range data { - values := make([]string, 0, len(columns)) - for _, c := range columns { - values = append(values, formatImportSQLValue(dbType, columnTypeMap[normalizeColumnName(c)], row[c])) - } - if _, err := w.WriteString(fmt.Sprintf("INSERT INTO %s (%s) VALUES (%s);\n", quotedTable, strings.Join(quotedCols, ", "), strings.Join(values, ", "))); err != nil { - return err - } - } - return nil } // ExportData exports provided data to a file func (a *App) ExportData(data []map[string]interface{}, columns []string, defaultName string, format string) connection.QueryResult { + return a.ExportDataWithOptions(data, columns, defaultName, ExportFileOptions{Format: format}) +} + +func (a *App) ExportDataWithOptions(data []map[string]interface{}, columns []string, defaultName string, options ExportFileOptions) connection.QueryResult { if defaultName == "" { defaultName = "export" } + options = normalizeExportFileOptions("", options) + if !options.TotalRowsKnown { + options.TotalRowsKnown = true + options.TotalRowsHint = int64(len(data)) + } + format := options.Format logger.Infof("ExportData 开始:rows=%d cols=%d format=%s defaultName=%s", len(data), len(columns), strings.ToLower(strings.TrimSpace(format)), strings.TrimSpace(defaultName)) filename, err := runtime.SaveFileDialog(a.ctx, runtime.SaveDialogOptions{ Title: "Export Data", @@ -3231,24 +3388,34 @@ func (a *App) ExportData(data []map[string]interface{}, columns []string, defaul return connection.QueryResult{Success: false, Message: "已取消"} } logger.Infof("ExportData 选定文件:%s", filename) + reporter := newExportProgressReporter(a, options, defaultName, filename) + reporter.Start("正在准备导出") f, err := os.Create(filename) if err != nil { + reporter.Error(0, err.Error()) return connection.QueryResult{Success: false, Message: err.Error()} } defer f.Close() - if err := writeRowsToFile(f, data, columns, format); err != nil { + writtenRows, err := writeRowsToFileWithReporter(f, data, columns, options, reporter) + if err != nil { logger.Warnf("ExportData 写入失败:file=%s err=%v", filename, err) + reporter.Error(writtenRows, "写入失败:"+err.Error()) return connection.QueryResult{Success: false, Message: "写入失败:" + err.Error()} } logger.Infof("ExportData 完成:file=%s rows=%d", filename, len(data)) + reporter.Done(writtenRows) return connection.QueryResult{Success: true, Message: "导出完成"} } // ExportQuery exports by executing the provided SELECT query on backend side. // This avoids frontend IPC payload limits when exporting very large/long-text columns (e.g. base64). func (a *App) ExportQuery(config connection.ConnectionConfig, dbName string, query string, defaultName string, format string) connection.QueryResult { + return a.ExportQueryWithOptions(config, dbName, query, defaultName, ExportFileOptions{Format: format}) +} + +func (a *App) ExportQueryWithOptions(config connection.ConnectionConfig, dbName string, query string, defaultName string, options ExportFileOptions) connection.QueryResult { query = strings.TrimSpace(query) if query == "" { return connection.QueryResult{Success: false, Message: "查询语句不能为空"} @@ -3257,6 +3424,8 @@ func (a *App) ExportQuery(config connection.ConnectionConfig, dbName string, que if defaultName == "" { defaultName = "export" } + options = normalizeExportFileOptions("", options) + format := options.Format filename, err := runtime.SaveFileDialog(a.ctx, runtime.SaveDialogOptions{ Title: "Export Query Result", @@ -3267,36 +3436,38 @@ func (a *App) ExportQuery(config connection.ConnectionConfig, dbName string, que return connection.QueryResult{Success: false, Message: "已取消"} } logger.Infof("ExportQuery 开始:type=%s db=%s format=%s file=%s sql=%q", strings.TrimSpace(config.Type), strings.TrimSpace(dbName), strings.ToLower(strings.TrimSpace(format)), filename, sqlSnippet(query)) + reporter := newExportProgressReporter(a, options, defaultName, filename) + reporter.Start("正在准备导出") runConfig := normalizeRunConfig(config, dbName) dbInst, err := a.getDatabase(runConfig) if err != nil { + reporter.Error(0, err.Error()) return connection.QueryResult{Success: false, Message: err.Error()} } query = sanitizeSQLForPgLike(resolveDDLDBType(config), query) if !looksLikeSelectOrWith(query) { + reporter.Error(0, "仅支持 SELECT/WITH 查询导出") return connection.QueryResult{Success: false, Message: "仅支持 SELECT/WITH 查询导出"} } - data, columns, err := queryDataForExport(dbInst, runConfig, query) - if err != nil { - logger.Warnf("ExportQuery 查询失败:type=%s db=%s err=%v sql=%q", strings.TrimSpace(config.Type), strings.TrimSpace(dbName), err, sqlSnippet(query)) - return connection.QueryResult{Success: false, Message: err.Error()} - } - f, err := os.Create(filename) if err != nil { + reporter.Error(0, err.Error()) return connection.QueryResult{Success: false, Message: err.Error()} } defer f.Close() - if err := writeRowsToFile(f, data, columns, format); err != nil { - logger.Warnf("ExportQuery 写入失败:file=%s err=%v", filename, err) - return connection.QueryResult{Success: false, Message: "写入失败:" + err.Error()} + rowCount, columns, err := exportQueryResultToFile(f, dbInst, runConfig, query, options, reporter) + if err != nil { + logger.Warnf("ExportQuery 查询失败:type=%s db=%s err=%v sql=%q", strings.TrimSpace(config.Type), strings.TrimSpace(dbName), err, sqlSnippet(query)) + reporter.Error(rowCount, err.Error()) + return connection.QueryResult{Success: false, Message: err.Error()} } - logger.Infof("ExportQuery 完成:file=%s rows=%d cols=%d", filename, len(data), len(columns)) + logger.Infof("ExportQuery 完成:file=%s rows=%d cols=%d", filename, rowCount, len(columns)) + reporter.Done(rowCount) return connection.QueryResult{Success: true, Message: "导出完成"} } @@ -3341,130 +3512,756 @@ func getExportQueryTimeout(config connection.ConnectionConfig) time.Duration { return timeout } -func writeRowsToFile(f *os.File, data []map[string]interface{}, columns []string, format string) error { - format = strings.ToLower(strings.TrimSpace(format)) - if f == nil { - return fmt.Errorf("file required") - } +type exportFileWriter interface { + db.QueryStreamConsumer + Close() error +} - // xlsx 使用 excelize 写入真正的 Excel 格式 - if format == "xlsx" { - return writeRowsToXlsx(f.Name(), data, columns) - } +type exportValueStreamConsumer interface { + ConsumeRowValues(values []interface{}) error +} - // html 使用内嵌 CSS 输出可直接浏览器预览的独立页面 - if format == "html" { - return writeRowsToHTML(f, data, columns) - } +type countingExportConsumer struct { + delegate db.QueryStreamConsumer + columns []string + rowCount int64 + reporter *exportProgressReporter +} - // 如果列名为空但数据不为空,从所有数据行提取所有键 - if len(columns) == 0 && len(data) > 0 { - keySet := make(map[string]bool) - for _, row := range data { - for key := range row { - keySet[key] = true - } - } - // 排序以确保输出一致 - for key := range keySet { - columns = append(columns, key) - } - sort.Strings(columns) - } - - var csvWriter *csv.Writer - var jsonEncoder *json.Encoder - isJsonFirstRow := true - - switch format { - case "csv": - if _, err := f.Write([]byte{0xEF, 0xBB, 0xBF}); err != nil { +func (c *countingExportConsumer) SetColumns(columns []string) error { + c.columns = append([]string(nil), columns...) + if c.delegate != nil { + if err := c.delegate.SetColumns(columns); err != nil { return err } - csvWriter = csv.NewWriter(f) - if err := csvWriter.Write(columns); err != nil { - return err - } - case "json": - if _, err := f.WriteString("[\n"); err != nil { - return err - } - jsonEncoder = json.NewEncoder(f) - jsonEncoder.SetIndent(" ", " ") - case "md": - if _, err := fmt.Fprintf(f, "| %s |\n", strings.Join(columns, " | ")); err != nil { - return err - } - seps := make([]string, len(columns)) - for i := range seps { - seps[i] = "---" - } - if _, err := fmt.Fprintf(f, "| %s |\n", strings.Join(seps, " | ")); err != nil { - return err - } - default: - return fmt.Errorf("unsupported format: %s", format) } + if c.reporter != nil { + c.reporter.ForceRunning(c.rowCount, "正在写入文件") + } + return nil +} - for _, rowMap := range data { - record := make([]string, len(columns)) - for i, col := range columns { - val := rowMap[col] - if val == nil { - record[i] = "NULL" - continue - } - - s := formatExportCellText(val) - if format == "md" { - s = strings.ReplaceAll(s, "|", "\\|") - s = strings.ReplaceAll(s, "\n", "
") - } - record[i] = s +func (c *countingExportConsumer) ConsumeRow(row map[string]interface{}) error { + if c.delegate != nil { + if err := c.delegate.ConsumeRow(row); err != nil { + return err } + } + c.rowCount++ + if c.reporter != nil { + c.reporter.Rows(c.rowCount, "正在写入文件") + } + return nil +} - switch format { - case "csv": - if err := csvWriter.Write(record); err != nil { +func (c *countingExportConsumer) ConsumeRowValues(values []interface{}) error { + if c.delegate != nil { + if valueConsumer, ok := c.delegate.(exportValueStreamConsumer); ok { + if err := valueConsumer.ConsumeRowValues(values); err != nil { return err } - case "json": - if !isJsonFirstRow { - if _, err := f.WriteString(",\n"); err != nil { - return err + } else { + row := make(map[string]interface{}, len(c.columns)) + for i, column := range c.columns { + if i < len(values) { + row[column] = values[i] + } else { + row[column] = nil } } - exportedRow := make(map[string]interface{}, len(columns)) - for _, col := range columns { - exportedRow[col] = normalizeExportJSONValue(rowMap[col]) - } - if err := jsonEncoder.Encode(exportedRow); err != nil { - return err - } - isJsonFirstRow = false - case "md": - if _, err := fmt.Fprintf(f, "| %s |\n", strings.Join(record, " | ")); err != nil { + if err := c.delegate.ConsumeRow(row); err != nil { return err } } } - - if format == "csv" { - csvWriter.Flush() - if err := csvWriter.Error(); err != nil { - return err - } + c.rowCount++ + if c.reporter != nil { + c.reporter.Rows(c.rowCount, "正在写入文件") } - - if format == "json" { - if _, err := f.WriteString("\n]"); err != nil { - return err - } - } - return nil } +type csvExportFileWriter struct { + writer *csv.Writer + columns []string + record []string +} + +func newCSVExportFileWriter(f *os.File) (*csvExportFileWriter, error) { + if _, err := f.Write([]byte{0xEF, 0xBB, 0xBF}); err != nil { + return nil, err + } + return &csvExportFileWriter{writer: csv.NewWriter(f)}, nil +} + +func (w *csvExportFileWriter) SetColumns(columns []string) error { + w.columns = append([]string(nil), columns...) + w.record = make([]string, len(columns)) + return w.writer.Write(columns) +} + +func (w *csvExportFileWriter) ConsumeRow(row map[string]interface{}) error { + return w.writer.Write(fillExportRecordFromRow(w.record, row, w.columns, false)) +} + +func (w *csvExportFileWriter) ConsumeRowValues(values []interface{}) error { + return w.writer.Write(fillExportRecordFromValues(w.record, values, false)) +} + +func (w *csvExportFileWriter) Close() error { + w.writer.Flush() + return w.writer.Error() +} + +type jsonExportFileWriter struct { + file *os.File + encoder *json.Encoder + columns []string + rowBuf map[string]interface{} + first bool +} + +func newJSONExportFileWriter(f *os.File) (*jsonExportFileWriter, error) { + if _, err := f.WriteString("[\n"); err != nil { + return nil, err + } + encoder := json.NewEncoder(f) + encoder.SetIndent(" ", " ") + return &jsonExportFileWriter{file: f, encoder: encoder, first: true}, nil +} + +func (w *jsonExportFileWriter) SetColumns(columns []string) error { + w.columns = append([]string(nil), columns...) + w.rowBuf = make(map[string]interface{}, len(columns)) + return nil +} + +func (w *jsonExportFileWriter) ConsumeRow(row map[string]interface{}) error { + for _, col := range w.columns { + w.rowBuf[col] = normalizeExportJSONValue(row[col]) + } + return w.writeCurrentRow() +} + +func (w *jsonExportFileWriter) ConsumeRowValues(values []interface{}) error { + for i, col := range w.columns { + if i < len(values) { + w.rowBuf[col] = normalizeExportJSONValue(values[i]) + } else { + w.rowBuf[col] = nil + } + } + return w.writeCurrentRow() +} + +func (w *jsonExportFileWriter) writeCurrentRow() error { + if !w.first { + if _, err := w.file.WriteString(",\n"); err != nil { + return err + } + } + if err := w.encoder.Encode(w.rowBuf); err != nil { + return err + } + w.first = false + return nil +} + +func (w *jsonExportFileWriter) Close() error { + _, err := w.file.WriteString("\n]") + return err +} + +type markdownExportFileWriter struct { + file *os.File + columns []string + record []string +} + +func (w *markdownExportFileWriter) SetColumns(columns []string) error { + w.columns = append([]string(nil), columns...) + w.record = make([]string, len(columns)) + if _, err := fmt.Fprintf(w.file, "| %s |\n", strings.Join(columns, " | ")); err != nil { + return err + } + seps := make([]string, len(columns)) + for i := range seps { + seps[i] = "---" + } + _, err := fmt.Fprintf(w.file, "| %s |\n", strings.Join(seps, " | ")) + return err +} + +func (w *markdownExportFileWriter) ConsumeRow(row map[string]interface{}) error { + _, err := fmt.Fprintf(w.file, "| %s |\n", strings.Join(fillExportRecordFromRow(w.record, row, w.columns, true), " | ")) + return err +} + +func (w *markdownExportFileWriter) ConsumeRowValues(values []interface{}) error { + _, err := fmt.Fprintf(w.file, "| %s |\n", strings.Join(fillExportRecordFromValues(w.record, values, true), " | ")) + return err +} + +func (w *markdownExportFileWriter) Close() error { + return nil +} + +type htmlExportFileWriter struct { + writer *bufio.Writer + columns []string + rowCount int64 +} + +func newHTMLExportFileWriter(f *os.File) *htmlExportFileWriter { + return &htmlExportFileWriter{writer: bufio.NewWriterSize(f, 1024*256)} +} + +func (w *htmlExportFileWriter) SetColumns(columns []string) error { + w.columns = append([]string(nil), columns...) + if _, err := w.writer.WriteString(` + + + + + GoNavi Export + + + +
+
+

GoNavi Data Export

+
`); err != nil { + return err + } + + if _, err := fmt.Fprintf(w.writer, "Columns: %d · Generated: %s", len(columns), time.Now().Format("2006-01-02 15:04:05")); err != nil { + return err + } + + if _, err := w.writer.WriteString(`
+
+
+ + `); err != nil { + return err + } + + for _, col := range columns { + if _, err := fmt.Fprintf(w.writer, "", html.EscapeString(col)); err != nil { + return err + } + } + + _, err := w.writer.WriteString(``) + return err +} + +func (w *htmlExportFileWriter) ConsumeRow(row map[string]interface{}) error { + if _, err := w.writer.WriteString(""); err != nil { + return err + } + for _, col := range w.columns { + if _, err := fmt.Fprintf(w.writer, "", formatExportHTMLCell(row[col])); err != nil { + return err + } + } + if _, err := w.writer.WriteString(""); err != nil { + return err + } + w.rowCount++ + return nil +} + +func (w *htmlExportFileWriter) ConsumeRowValues(values []interface{}) error { + if _, err := w.writer.WriteString(""); err != nil { + return err + } + for i := range w.columns { + var value interface{} + if i < len(values) { + value = values[i] + } + if _, err := fmt.Fprintf(w.writer, "", formatExportHTMLCell(value)); err != nil { + return err + } + } + if _, err := w.writer.WriteString(""); err != nil { + return err + } + w.rowCount++ + return nil +} + +func (w *htmlExportFileWriter) Close() error { + if w.rowCount == 0 { + colspan := len(w.columns) + if colspan <= 0 { + colspan = 1 + } + if _, err := fmt.Fprintf(w.writer, ``, colspan); err != nil { + return err + } + } + if _, err := w.writer.WriteString(`
%s
%s
%s
(0 rows)
+
+
+ +`); err != nil { + return err + } + return w.writer.Flush() +} + +type xlsxExportFileWriter struct { + filename string + workbook *excelize.File + stream *excelize.StreamWriter + sheet string + columns []string + header []interface{} + rowBuf []interface{} + nextRow int + sheetNo int + rowCount int + maxRows int +} + +func newXLSXExportFileWriter(filename string, maxRowsPerSheet int) (*xlsxExportFileWriter, error) { + workbook := excelize.NewFile() + sheet := workbook.GetSheetName(workbook.GetActiveSheetIndex()) + stream, err := workbook.NewStreamWriter(sheet) + if err != nil { + _ = workbook.Close() + return nil, err + } + return &xlsxExportFileWriter{ + filename: filename, + workbook: workbook, + stream: stream, + sheet: sheet, + sheetNo: 1, + nextRow: 2, + maxRows: normalizeXLSXRowsPerSheet(maxRowsPerSheet), + }, nil +} + +func (w *xlsxExportFileWriter) SetColumns(columns []string) error { + w.columns = append([]string(nil), columns...) + w.rowCount = 0 + w.nextRow = 2 + w.header = make([]interface{}, len(columns)) + w.rowBuf = make([]interface{}, len(columns)) + for i, col := range columns { + w.header[i] = col + } + return w.stream.SetRow("A1", w.header) +} + +func (w *xlsxExportFileWriter) rotateSheet() error { + if err := w.stream.Flush(); err != nil { + return err + } + w.sheetNo++ + w.sheet = fmt.Sprintf("Sheet%d", w.sheetNo) + if _, err := w.workbook.NewSheet(w.sheet); err != nil { + return err + } + stream, err := w.workbook.NewStreamWriter(w.sheet) + if err != nil { + return err + } + w.stream = stream + w.rowCount = 0 + w.nextRow = 2 + return w.stream.SetRow("A1", w.header) +} + +func (w *xlsxExportFileWriter) ConsumeRow(row map[string]interface{}) error { + if w.rowCount >= w.maxRows { + if err := w.rotateSheet(); err != nil { + return err + } + } + values := w.rowBuf + for i, col := range w.columns { + val := row[col] + if val == nil { + values[i] = "NULL" + continue + } + values[i] = formatExportCellText(val) + } + cell := "A" + strconv.Itoa(w.nextRow) + w.nextRow++ + w.rowCount++ + return w.stream.SetRow(cell, values) +} + +func (w *xlsxExportFileWriter) ConsumeRowValues(values []interface{}) error { + if w.rowCount >= w.maxRows { + if err := w.rotateSheet(); err != nil { + return err + } + } + rowBuf := w.rowBuf + for i := range w.columns { + var value interface{} + if i < len(values) { + value = values[i] + } + if value == nil { + rowBuf[i] = "NULL" + continue + } + rowBuf[i] = formatExportCellText(value) + } + cell := "A" + strconv.Itoa(w.nextRow) + w.nextRow++ + w.rowCount++ + return w.stream.SetRow(cell, rowBuf) +} + +func (w *xlsxExportFileWriter) Close() error { + if err := w.stream.Flush(); err != nil { + _ = w.workbook.Close() + return err + } + saveErr := w.workbook.SaveAs(w.filename) + closeErr := w.workbook.Close() + if saveErr != nil { + return saveErr + } + return closeErr +} + +type sqlInsertExportConsumer struct { + w *bufio.Writer + dbType string + quotedTable string + columnTypeMap map[string]string + columns []string + quotedCols []string + columnTypes []string + valueBuf []string + rowCount int64 +} + +func (c *sqlInsertExportConsumer) SetColumns(columns []string) error { + c.columns = append([]string(nil), columns...) + c.quotedCols = make([]string, 0, len(columns)) + c.columnTypes = make([]string, len(columns)) + c.valueBuf = make([]string, len(columns)) + for _, column := range columns { + c.quotedCols = append(c.quotedCols, quoteIdentByType(c.dbType, column)) + } + for i, column := range columns { + c.columnTypes[i] = c.columnTypeMap[normalizeColumnName(column)] + } + return nil +} + +func (c *sqlInsertExportConsumer) ConsumeRow(row map[string]interface{}) error { + values := make([]string, 0, len(c.columns)) + for _, column := range c.columns { + values = append(values, formatImportSQLValue(c.dbType, c.columnTypeMap[normalizeColumnName(column)], row[column])) + } + if _, err := c.w.WriteString(fmt.Sprintf("INSERT INTO %s (%s) VALUES (%s);\n", c.quotedTable, strings.Join(c.quotedCols, ", "), strings.Join(values, ", "))); err != nil { + return err + } + c.rowCount++ + return nil +} + +func (c *sqlInsertExportConsumer) ConsumeRowValues(values []interface{}) error { + for i := range c.columns { + var value interface{} + if i < len(values) { + value = values[i] + } + c.valueBuf[i] = formatImportSQLValue(c.dbType, c.columnTypes[i], value) + } + if _, err := c.w.WriteString(fmt.Sprintf("INSERT INTO %s (%s) VALUES (%s);\n", c.quotedTable, strings.Join(c.quotedCols, ", "), strings.Join(c.valueBuf, ", "))); err != nil { + return err + } + c.rowCount++ + return nil +} + +func resolveExportColumns(columns []string, data []map[string]interface{}) []string { + if len(columns) > 0 || len(data) == 0 { + return columns + } + keySet := make(map[string]bool) + for _, row := range data { + for key := range row { + keySet[key] = true + } + } + derived := make([]string, 0, len(keySet)) + for key := range keySet { + derived = append(derived, key) + } + sort.Strings(derived) + return derived +} + +func newExportFileWriter(f *os.File, options ExportFileOptions) (exportFileWriter, error) { + options = normalizeExportFileOptions("", options) + switch options.Format { + case "csv": + return newCSVExportFileWriter(f) + case "json": + return newJSONExportFileWriter(f) + case "md": + return &markdownExportFileWriter{file: f}, nil + case "html": + return newHTMLExportFileWriter(f), nil + case "xlsx": + filename := f.Name() + if err := f.Close(); err != nil { + return nil, err + } + return newXLSXExportFileWriter(filename, options.XLSXMaxRowsPerSheet) + default: + return nil, fmt.Errorf("unsupported format: %s", options.Format) + } +} + +func streamQueryDataForExport(dbInst db.Database, config connection.ConnectionConfig, query string, consumer db.QueryStreamConsumer) error { + if consumer == nil { + return fmt.Errorf("export consumer required") + } + + timeout := getExportQueryTimeout(config) + ctx, cancel := utils.ContextWithTimeout(timeout) + defer cancel() + + if streamer, ok := dbInst.(db.StreamQueryExecer); ok { + return streamer.StreamQueryContext(ctx, query, consumer) + } + + if provider, ok := dbInst.(db.SessionExecerProvider); ok { + session, err := provider.OpenSessionExecer(ctx) + if err != nil { + logger.Warnf("导出流式会话打开失败,回退到缓冲导出:type=%s err=%v", strings.TrimSpace(config.Type), err) + } else { + defer session.Close() + if streamer, ok := session.(db.StreamQueryExecer); ok { + return streamer.StreamQueryContext(ctx, query, consumer) + } + } + } + + logger.Warnf("导出流式查询不可用,回退到缓冲导出:type=%s", strings.TrimSpace(config.Type)) + data, columns, err := queryDataForExport(dbInst, config, query) + if err != nil { + return err + } + columns = resolveExportColumns(columns, data) + if err := consumer.SetColumns(columns); err != nil { + return err + } + for _, row := range data { + if err := consumer.ConsumeRow(row); err != nil { + return err + } + } + return nil +} + +func exportQueryResultToFile(f *os.File, dbInst db.Database, config connection.ConnectionConfig, query string, options ExportFileOptions, reporter *exportProgressReporter) (int64, []string, error) { + writer, err := newExportFileWriter(f, options) + if err != nil { + return 0, nil, err + } + + if reporter != nil { + reporter.Start("正在查询数据") + } + consumer := &countingExportConsumer{delegate: writer, reporter: reporter} + streamErr := streamQueryDataForExport(dbInst, config, query, consumer) + if reporter != nil && streamErr == nil { + reporter.Finalizing(consumer.rowCount) + } + closeErr := writer.Close() + if streamErr != nil { + return consumer.rowCount, consumer.columns, streamErr + } + if closeErr != nil { + return consumer.rowCount, consumer.columns, closeErr + } + return consumer.rowCount, consumer.columns, nil +} + +func fillExportRecordFromValues(record []string, values []interface{}, markdown bool) []string { + if len(record) != len(values) { + record = make([]string, len(values)) + } + for i, val := range values { + record[i] = formatExportRecordValue(val, markdown) + } + return record +} + +func fillExportRecordFromRow(record []string, row map[string]interface{}, columns []string, markdown bool) []string { + if len(record) != len(columns) { + record = make([]string, len(columns)) + } + for i, col := range columns { + record[i] = formatExportRecordValue(row[col], markdown) + } + return record +} + +func formatExportRecordValue(val interface{}, markdown bool) string { + if val == nil { + return "NULL" + } + text := formatExportCellText(val) + if markdown { + text = strings.ReplaceAll(text, "|", "\\|") + text = strings.ReplaceAll(text, "\n", "
") + } + return text +} + +func writeRowsToFile(f *os.File, data []map[string]interface{}, columns []string, options ExportFileOptions) error { + _, err := writeRowsToFileWithReporter(f, data, columns, options, nil) + return err +} + +func writeRowsToFileWithReporter(f *os.File, data []map[string]interface{}, columns []string, options ExportFileOptions, reporter *exportProgressReporter) (int64, error) { + if f == nil { + return 0, fmt.Errorf("file required") + } + columns = resolveExportColumns(columns, data) + writer, err := newExportFileWriter(f, options) + if err != nil { + return 0, err + } + if err := writer.SetColumns(columns); err != nil { + _ = writer.Close() + return 0, err + } + if reporter != nil { + reporter.ForceRunning(0, "正在写入文件") + } + for index, row := range data { + if err := writer.ConsumeRow(row); err != nil { + _ = writer.Close() + return int64(index), err + } + if reporter != nil { + reporter.Rows(int64(index+1), "正在写入文件") + } + } + if reporter != nil { + reporter.Finalizing(int64(len(data))) + } + if err := writer.Close(); err != nil { + return int64(len(data)), err + } + return int64(len(data)), nil +} + func formatExportHTMLCell(val interface{}) string { text := formatExportCellText(val) escaped := html.EscapeString(text) @@ -3677,13 +4474,11 @@ func formatExportCellText(val interface{}) string { return "NULL" } return text + case string: + return normalizeExportTemporalText(v) default: text := fmt.Sprintf("%v", val) - // 字符串型日期时间值(如 RFC3339 "2026-03-10T17:01:55+08:00")统一格式化为 yyyy-MM-dd HH:mm:ss - if parsed, ok := parseTemporalString(text); ok { - return parsed.Format("2006-01-02 15:04:05") - } - return text + return normalizeExportTemporalText(text) } } @@ -3701,10 +4496,7 @@ func normalizeExportJSONValue(val interface{}) interface{} { } return v.Format("2006-01-02 15:04:05") case string: - if parsed, ok := parseTemporalString(v); ok { - return parsed.Format("2006-01-02 15:04:05") - } - return v + return normalizeExportTemporalText(v) case float32: f := float64(v) if math.IsNaN(f) || math.IsInf(f, 0) { diff --git a/internal/app/methods_file_export_test.go b/internal/app/methods_file_export_test.go index 7aa5176..903ec17 100644 --- a/internal/app/methods_file_export_test.go +++ b/internal/app/methods_file_export_test.go @@ -11,6 +11,8 @@ import ( "time" "GoNavi-Wails/internal/connection" + "GoNavi-Wails/internal/db" + "github.com/xuri/excelize/v2" ) type fakeExportQueryDB struct { @@ -24,6 +26,23 @@ type fakeExportQueryDB struct { hasContextDeadline bool } +type fakeStreamExportDB struct { + fakeExportQueryDB + streamData []map[string]interface{} + streamCols []string + streamHits int + queryHits int +} + +type fakeValueStreamExportDB struct { + fakeExportQueryDB + streamCols []string + streamValues [][]interface{} + streamHits int + queryHits int + valueHits int +} + func (f *fakeExportQueryDB) Connect(config connection.ConnectionConfig) error { return nil } func (f *fakeExportQueryDB) Close() error { return nil } func (f *fakeExportQueryDB) Ping() error { return nil } @@ -63,6 +82,77 @@ func (f *fakeExportQueryDB) GetTriggers(dbName, tableName string) ([]connection. return nil, nil } +func (f *fakeStreamExportDB) Query(query string) ([]map[string]interface{}, []string, error) { + f.queryHits++ + return f.fakeExportQueryDB.Query(query) +} + +func (f *fakeStreamExportDB) QueryContext(ctx context.Context, query string) ([]map[string]interface{}, []string, error) { + f.queryHits++ + return f.fakeExportQueryDB.QueryContext(ctx, query) +} + +func (f *fakeStreamExportDB) StreamQuery(query string, consumer db.QueryStreamConsumer) error { + return f.StreamQueryContext(context.Background(), query, consumer) +} + +func (f *fakeStreamExportDB) StreamQueryContext(_ context.Context, query string, consumer db.QueryStreamConsumer) error { + f.streamHits++ + f.lastQuery = query + if err := consumer.SetColumns(f.streamCols); err != nil { + return err + } + for _, row := range f.streamData { + if err := consumer.ConsumeRow(row); err != nil { + return err + } + } + return nil +} + +func (f *fakeValueStreamExportDB) Query(query string) ([]map[string]interface{}, []string, error) { + f.queryHits++ + return f.fakeExportQueryDB.Query(query) +} + +func (f *fakeValueStreamExportDB) QueryContext(ctx context.Context, query string) ([]map[string]interface{}, []string, error) { + f.queryHits++ + return f.fakeExportQueryDB.QueryContext(ctx, query) +} + +func (f *fakeValueStreamExportDB) StreamQuery(query string, consumer db.QueryStreamConsumer) error { + return f.StreamQueryContext(context.Background(), query, consumer) +} + +func (f *fakeValueStreamExportDB) StreamQueryContext(_ context.Context, query string, consumer db.QueryStreamConsumer) error { + f.streamHits++ + f.lastQuery = query + if err := consumer.SetColumns(f.streamCols); err != nil { + return err + } + if valueConsumer, ok := consumer.(db.QueryStreamValueConsumer); ok { + for _, row := range f.streamValues { + f.valueHits++ + if err := valueConsumer.ConsumeRowValues(row); err != nil { + return err + } + } + return nil + } + for _, row := range f.streamValues { + entry := make(map[string]interface{}, len(f.streamCols)) + for idx, column := range f.streamCols { + if idx < len(row) { + entry[column] = row[idx] + } + } + if err := consumer.ConsumeRow(entry); err != nil { + return err + } + } + return nil +} + func TestFormatExportCellText_FloatNoScientificNotation(t *testing.T) { got := formatExportCellText(1.445663e+06) if strings.Contains(strings.ToLower(got), "e+") || strings.Contains(strings.ToLower(got), "e-") { @@ -86,7 +176,7 @@ func TestWriteRowsToFile_Markdown_NumberKeepPlainText(t *testing.T) { } columns := []string{"id"} - if err := writeRowsToFile(f, data, columns, "md"); err != nil { + if err := writeRowsToFile(f, data, columns, ExportFileOptions{Format: "md"}); err != nil { t.Fatalf("写入 md 失败: %v", err) } @@ -116,7 +206,7 @@ func TestWriteRowsToFile_JSON_NumberKeepPlainText(t *testing.T) { } columns := []string{"id"} - if err := writeRowsToFile(f, data, columns, "json"); err != nil { + if err := writeRowsToFile(f, data, columns, ExportFileOptions{Format: "json"}); err != nil { t.Fatalf("写入 json 失败: %v", err) } @@ -166,6 +256,24 @@ func TestFormatExportCellText_TimeValue_KeepWallClock(t *testing.T) { } } +func TestFormatExportCellText_StringRFC3339_KeepWallClock(t *testing.T) { + originalLocal := time.Local + time.Local = time.FixedZone("UTC+8", 8*60*60) + defer func() { time.Local = originalLocal }() + + got := formatExportCellText("2026-04-07T10:44:32Z") + if got != "2026-04-07 10:44:32" { + t.Fatalf("字符串时间导出应保持原始钟表时间,want=%q got=%q", "2026-04-07 10:44:32", got) + } +} + +func TestFormatExportCellText_PlainString_Untouched(t *testing.T) { + got := formatExportCellText("plain export payload without timezone marker") + if got != "plain export payload without timezone marker" { + t.Fatalf("普通字符串不应被改写,got=%q", got) + } +} + func TestParseTemporalString_LocalDateTime_NoTimezoneShift(t *testing.T) { originalLocal := time.Local time.Local = time.FixedZone("UTC+8", 8*60*60) @@ -248,6 +356,105 @@ func TestQueryDataForExport_UsesLargerConfiguredTimeout(t *testing.T) { } } +func TestExportQueryResultToFile_UsesStreamQueryPath(t *testing.T) { + f, err := os.CreateTemp("", "gonavi-export-stream-*.csv") + if err != nil { + t.Fatalf("创建临时文件失败: %v", err) + } + defer os.Remove(f.Name()) + defer f.Close() + + fake := &fakeStreamExportDB{ + fakeExportQueryDB: fakeExportQueryDB{ + err: context.DeadlineExceeded, + data: []map[string]interface{}{{"id": 999}}, + cols: []string{"id"}, + }, + streamCols: []string{"id", "name"}, + streamData: []map[string]interface{}{ + {"id": 1, "name": "alice"}, + {"id": 2, "name": "bob"}, + }, + } + + rowCount, columns, err := exportQueryResultToFile( + f, + fake, + connection.ConnectionConfig{Type: "mysql", Timeout: 10}, + "SELECT id, name FROM users", + ExportFileOptions{Format: "csv"}, + nil, + ) + if err != nil { + t.Fatalf("exportQueryResultToFile 返回错误: %v", err) + } + if fake.streamHits != 1 { + t.Fatalf("应优先使用流式查询,streamHits=%d", fake.streamHits) + } + if fake.queryHits != 0 { + t.Fatalf("不应回退到缓冲查询,queryHits=%d", fake.queryHits) + } + if rowCount != 2 { + t.Fatalf("导出行数异常,want=2 got=%d", rowCount) + } + if len(columns) != 2 || columns[0] != "id" || columns[1] != "name" { + t.Fatalf("导出列异常,got=%v", columns) + } + + contentBytes, err := os.ReadFile(f.Name()) + if err != nil { + t.Fatalf("读取导出文件失败: %v", err) + } + content := string(contentBytes) + if !strings.Contains(content, "alice") || !strings.Contains(content, "bob") { + t.Fatalf("流式导出内容异常: %s", content) + } +} + +func TestExportQueryResultToFile_UsesValueStreamPathWhenAvailable(t *testing.T) { + f, err := os.CreateTemp("", "gonavi-export-stream-values-*.csv") + if err != nil { + t.Fatalf("创建临时文件失败: %v", err) + } + defer os.Remove(f.Name()) + defer f.Close() + + fake := &fakeValueStreamExportDB{ + streamCols: []string{"id", "name"}, + streamValues: [][]interface{}{ + {1, "alice"}, + {2, "bob"}, + }, + } + + rowCount, columns, err := exportQueryResultToFile( + f, + fake, + connection.ConnectionConfig{Type: "mysql", Timeout: 10}, + "SELECT id, name FROM users", + ExportFileOptions{Format: "csv"}, + nil, + ) + if err != nil { + t.Fatalf("exportQueryResultToFile 返回错误: %v", err) + } + if fake.streamHits != 1 { + t.Fatalf("应优先使用流式查询,streamHits=%d", fake.streamHits) + } + if fake.valueHits != 2 { + t.Fatalf("应走值数组流式路径,valueHits=%d", fake.valueHits) + } + if fake.queryHits != 0 { + t.Fatalf("不应回退到缓冲查询,queryHits=%d", fake.queryHits) + } + if rowCount != 2 { + t.Fatalf("导出行数异常,want=2 got=%d", rowCount) + } + if len(columns) != 2 || columns[0] != "id" || columns[1] != "name" { + t.Fatalf("导出列异常,got=%v", columns) + } +} + func TestGetExportQueryTimeout_ClickHouseUsesLongerMinimum(t *testing.T) { timeout := getExportQueryTimeout(connection.ConnectionConfig{ Type: "clickhouse", @@ -300,7 +507,7 @@ func TestWriteRowsToFile_HTML_EscapeAndStyle(t *testing.T) { } columns := []string{"name", "note", "nullable"} - if err := writeRowsToFile(f, data, columns, "html"); err != nil { + if err := writeRowsToFile(f, data, columns, ExportFileOptions{Format: "html"}); err != nil { t.Fatalf("写入 html 失败: %v", err) } @@ -343,7 +550,7 @@ func TestWriteRowsToFile_HTML_EscapeHeader(t *testing.T) { columnName := "name" data := []map[string]interface{}{{columnName: "ok"}} - if err := writeRowsToFile(f, data, []string{columnName}, "html"); err != nil { + if err := writeRowsToFile(f, data, []string{columnName}, ExportFileOptions{Format: "html"}); err != nil { t.Fatalf("写入 html 失败: %v", err) } contentBytes, _ := os.ReadFile(f.Name()) @@ -353,6 +560,175 @@ func TestWriteRowsToFile_HTML_EscapeHeader(t *testing.T) { } } +func TestWriteRowsToFile_XLSX_SplitsByMaxRowsPerSheet(t *testing.T) { + f, err := os.CreateTemp("", "gonavi-export-*.xlsx") + if err != nil { + t.Fatalf("创建临时文件失败: %v", err) + } + defer os.Remove(f.Name()) + defer f.Close() + + data := []map[string]interface{}{ + {"id": 1, "name": "alice"}, + {"id": 2, "name": "bob"}, + {"id": 3, "name": "carol"}, + } + columns := []string{"id", "name"} + + if err := writeRowsToFile(f, data, columns, ExportFileOptions{ + Format: "xlsx", + XLSXMaxRowsPerSheet: 2, + }); err != nil { + t.Fatalf("写入 xlsx 失败: %v", err) + } + + workbook, err := excelize.OpenFile(f.Name()) + if err != nil { + t.Fatalf("打开 xlsx 失败: %v", err) + } + defer workbook.Close() + + sheets := workbook.GetSheetList() + if len(sheets) != 2 { + t.Fatalf("sheet 数量异常,want=2 got=%d (%v)", len(sheets), sheets) + } + + rows1, err := workbook.GetRows("Sheet1") + if err != nil { + t.Fatalf("读取 Sheet1 失败: %v", err) + } + if len(rows1) != 3 { + t.Fatalf("Sheet1 行数异常,want=3 got=%d", len(rows1)) + } + + rows2, err := workbook.GetRows("Sheet2") + if err != nil { + t.Fatalf("读取 Sheet2 失败: %v", err) + } + if len(rows2) != 2 { + t.Fatalf("Sheet2 行数异常,want=2 got=%d", len(rows2)) + } + if rows2[1][1] != "carol" { + t.Fatalf("Sheet2 数据异常,want=%q got=%q", "carol", rows2[1][1]) + } +} + +func benchmarkExportRows(rowCount int) ([]map[string]interface{}, []string) { + columns := []string{"id", "name", "note", "created_at", "status"} + rows := make([]map[string]interface{}, rowCount) + for i := 0; i < rowCount; i++ { + rows[i] = map[string]interface{}{ + "id": i + 1, + "name": "benchmark-user", + "note": "plain export payload without timezone marker", + "created_at": "2026-06-17 12:34:56", + "status": "enabled", + } + } + return rows, columns +} + +func benchmarkExportRowValues(rowCount int) ([][]interface{}, []string) { + columns := []string{"id", "name", "note", "created_at", "status"} + rows := make([][]interface{}, rowCount) + for i := 0; i < rowCount; i++ { + rows[i] = []interface{}{ + i + 1, + "benchmark-user", + "plain export payload without timezone marker", + "2026-06-17 12:34:56", + "enabled", + } + } + return rows, columns +} + +func BenchmarkFormatExportCellText_PlainString(b *testing.B) { + b.ReportAllocs() + for i := 0; i < b.N; i++ { + _ = formatExportCellText("plain export payload without timezone marker") + } +} + +func BenchmarkWriteRowsToFile_XLSX_20000Rows(b *testing.B) { + rows, columns := benchmarkExportRows(20000) + b.ReportAllocs() + for i := 0; i < b.N; i++ { + f, err := os.CreateTemp("", "gonavi-export-bench-*.xlsx") + if err != nil { + b.Fatalf("创建临时文件失败: %v", err) + } + name := f.Name() + if err := writeRowsToFile(f, rows, columns, ExportFileOptions{Format: "xlsx"}); err != nil { + _ = os.Remove(name) + b.Fatalf("写入 xlsx 失败: %v", err) + } + if err := os.Remove(name); err != nil { + b.Fatalf("删除临时文件失败: %v", err) + } + } +} + +func BenchmarkExportQueryResultToFile_XLSX_StreamMap_20000Rows(b *testing.B) { + rows, columns := benchmarkExportRows(20000) + streamDB := &fakeStreamExportDB{ + streamCols: columns, + streamData: rows, + } + b.ReportAllocs() + for i := 0; i < b.N; i++ { + f, err := os.CreateTemp("", "gonavi-export-stream-map-*.xlsx") + if err != nil { + b.Fatalf("创建临时文件失败: %v", err) + } + name := f.Name() + if _, _, err := exportQueryResultToFile( + f, + streamDB, + connection.ConnectionConfig{Type: "mysql", Timeout: 10}, + "SELECT * FROM users", + ExportFileOptions{Format: "xlsx"}, + nil, + ); err != nil { + _ = os.Remove(name) + b.Fatalf("流式 map 导出失败: %v", err) + } + if err := os.Remove(name); err != nil { + b.Fatalf("删除临时文件失败: %v", err) + } + } +} + +func BenchmarkExportQueryResultToFile_XLSX_StreamValues_20000Rows(b *testing.B) { + rows, columns := benchmarkExportRowValues(20000) + streamDB := &fakeValueStreamExportDB{ + streamCols: columns, + streamValues: rows, + } + b.ReportAllocs() + for i := 0; i < b.N; i++ { + f, err := os.CreateTemp("", "gonavi-export-stream-values-*.xlsx") + if err != nil { + b.Fatalf("创建临时文件失败: %v", err) + } + name := f.Name() + if _, _, err := exportQueryResultToFile( + f, + streamDB, + connection.ConnectionConfig{Type: "mysql", Timeout: 10}, + "SELECT * FROM users", + ExportFileOptions{Format: "xlsx"}, + nil, + ); err != nil { + _ = os.Remove(name) + b.Fatalf("流式值数组导出失败: %v", err) + } + if err := os.Remove(name); err != nil { + b.Fatalf("删除临时文件失败: %v", err) + } + } +} + func TestFormatImportSQLValue_NormalizesTimestampWithoutTimezone(t *testing.T) { got := formatImportSQLValue("postgres", "timestamp without time zone", "2026-01-21T18:32:26+08:00") if got != "'2026-01-21 18:32:26'" { diff --git a/internal/db/clickhouse_impl.go b/internal/db/clickhouse_impl.go index ff12e72..0dac365 100644 --- a/internal/db/clickhouse_impl.go +++ b/internal/db/clickhouse_impl.go @@ -777,6 +777,22 @@ func (c *ClickHouseDB) Query(query string) ([]map[string]interface{}, []string, return scanRows(rows) } +func (c *ClickHouseDB) StreamQueryContext(ctx context.Context, query string, consumer QueryStreamConsumer) error { + if c.conn == nil { + return fmt.Errorf("连接未打开") + } + rows, err := c.conn.QueryContext(ctx, query) + if err != nil { + return err + } + defer rows.Close() + return streamRows(rows, consumer) +} + +func (c *ClickHouseDB) StreamQuery(query string, consumer QueryStreamConsumer) error { + return c.StreamQueryContext(context.Background(), query, consumer) +} + func (c *ClickHouseDB) ExecContext(ctx context.Context, query string) (int64, error) { if c.conn == nil { return 0, fmt.Errorf("连接未打开") diff --git a/internal/db/custom_impl.go b/internal/db/custom_impl.go index 0ad1ba8..0f2d027 100644 --- a/internal/db/custom_impl.go +++ b/internal/db/custom_impl.go @@ -111,6 +111,24 @@ func (c *CustomDB) Query(query string) ([]map[string]interface{}, []string, erro return scanRowsForDialect(rows, c.scanDialect()) } +func (c *CustomDB) StreamQueryContext(ctx context.Context, query string, consumer QueryStreamConsumer) error { + if c.conn == nil { + return fmt.Errorf("连接未打开") + } + + rows, err := c.conn.QueryContext(ctx, query) + if err != nil { + return err + } + defer rows.Close() + + return streamRowsForDialect(rows, c.scanDialect(), consumer) +} + +func (c *CustomDB) StreamQuery(query string, consumer QueryStreamConsumer) error { + return c.StreamQueryContext(context.Background(), query, consumer) +} + func (c *CustomDB) scanDialect() string { if strings.EqualFold(strings.TrimSpace(c.driver), "mysql") { return "mysql" diff --git a/internal/db/dameng_impl.go b/internal/db/dameng_impl.go index d9ce83b..7b05601 100644 --- a/internal/db/dameng_impl.go +++ b/internal/db/dameng_impl.go @@ -182,6 +182,24 @@ func (d *DamengDB) Query(query string) ([]map[string]interface{}, []string, erro return scanRows(rows) } +func (d *DamengDB) StreamQueryContext(ctx context.Context, query string, consumer QueryStreamConsumer) error { + if d.conn == nil { + return fmt.Errorf("连接未打开") + } + + rows, err := d.conn.QueryContext(ctx, query) + if err != nil { + return err + } + defer rows.Close() + + return streamRows(rows, consumer) +} + +func (d *DamengDB) StreamQuery(query string, consumer QueryStreamConsumer) error { + return d.StreamQueryContext(context.Background(), query, consumer) +} + func (d *DamengDB) ExecContext(ctx context.Context, query string) (int64, error) { if d.conn == nil { return 0, fmt.Errorf("连接未打开") diff --git a/internal/db/database.go b/internal/db/database.go index a0dd411..8f69b02 100644 --- a/internal/db/database.go +++ b/internal/db/database.go @@ -76,6 +76,28 @@ type StatementQueryExecer interface { QueryContext(ctx context.Context, query string) ([]map[string]interface{}, []string, error) } +// QueryStreamConsumer receives query metadata and rows incrementally. +// Implementations can stream rows directly to files to avoid buffering entire result sets in memory. +type QueryStreamConsumer interface { + SetColumns(columns []string) error + ConsumeRow(row map[string]interface{}) error +} + +// QueryStreamValueConsumer is an optional fast path for stream consumers that +// can consume normalized row values in column order without requiring a +// map[string]interface{} allocation per row. +type QueryStreamValueConsumer interface { + SetColumns(columns []string) error + ConsumeRowValues(values []interface{}) error +} + +// StreamQueryExecer is an optional interface for drivers or pinned sessions that can +// stream query rows incrementally instead of materializing []map rows in memory. +type StreamQueryExecer interface { + StreamQuery(query string, consumer QueryStreamConsumer) error + StreamQueryContext(ctx context.Context, query string, consumer QueryStreamConsumer) error +} + // StatementQueryMessageExecer can run queries on a pinned session and return // extra server messages/notices alongside rows. type StatementQueryMessageExecer interface { @@ -178,6 +200,22 @@ func (e *sqlConnStatementExecer) Query(query string) ([]map[string]interface{}, return e.QueryContext(context.Background(), query) } +func (e *sqlConnStatementExecer) StreamQueryContext(ctx context.Context, query string, consumer QueryStreamConsumer) error { + if e == nil || e.conn == nil { + return fmt.Errorf("连接未打开") + } + rows, err := e.conn.QueryContext(ctx, query) + if err != nil { + return err + } + defer rows.Close() + return streamRowsForDialect(rows, e.scanDialect, consumer) +} + +func (e *sqlConnStatementExecer) StreamQuery(query string, consumer QueryStreamConsumer) error { + return e.StreamQueryContext(context.Background(), query, consumer) +} + func (e *sqlConnStatementExecer) QueryMultiContext(ctx context.Context, query string) ([]connection.ResultSetData, error) { if e == nil || e.conn == nil { return nil, fmt.Errorf("连接未打开") @@ -275,6 +313,23 @@ func (e *sqlConnTransactionExecer) Query(query string) ([]map[string]interface{} return e.QueryContext(context.Background(), query) } +func (e *sqlConnTransactionExecer) StreamQueryContext(ctx context.Context, query string, consumer QueryStreamConsumer) error { + conn, err := e.activeConn() + if err != nil { + return err + } + rows, err := conn.QueryContext(ctx, query) + if err != nil { + return err + } + defer rows.Close() + return streamRowsForDialect(rows, e.scanDialect, consumer) +} + +func (e *sqlConnTransactionExecer) StreamQuery(query string, consumer QueryStreamConsumer) error { + return e.StreamQueryContext(context.Background(), query, consumer) +} + func (e *sqlConnTransactionExecer) QueryMultiContext(ctx context.Context, query string) ([]connection.ResultSetData, error) { conn, err := e.activeConn() if err != nil { @@ -401,6 +456,23 @@ func (e *sqlTxStatementExecer) Query(query string) ([]map[string]interface{}, [] return e.QueryContext(context.Background(), query) } +func (e *sqlTxStatementExecer) StreamQueryContext(ctx context.Context, query string, consumer QueryStreamConsumer) error { + tx, err := e.activeTx() + if err != nil { + return err + } + rows, err := tx.QueryContext(ctx, query) + if err != nil { + return err + } + defer rows.Close() + return streamRows(rows, consumer) +} + +func (e *sqlTxStatementExecer) StreamQuery(query string, consumer QueryStreamConsumer) error { + return e.StreamQueryContext(context.Background(), query, consumer) +} + func (e *sqlTxStatementExecer) QueryMultiContext(ctx context.Context, query string) ([]connection.ResultSetData, error) { tx, err := e.activeTx() if err != nil { diff --git a/internal/db/scan_rows.go b/internal/db/scan_rows.go index cac7064..3277fcf 100644 --- a/internal/db/scan_rows.go +++ b/internal/db/scan_rows.go @@ -11,6 +11,19 @@ func scanRows(rows *sql.Rows) ([]map[string]interface{}, []string, error) { return scanRowsForDialect(rows, "") } +func streamRows(rows *sql.Rows, consumer QueryStreamConsumer) error { + return streamRowsForDialect(rows, "", consumer) +} + +type queryRowScanner struct { + columns []string + dbTypeNames []string + dialect string + values []interface{} + normalized []interface{} + valuePtrs []interface{} +} + func scanRowsForDialect(rows *sql.Rows, dialect string) ([]map[string]interface{}, []string, error) { columns, err := rows.Columns() if err != nil { @@ -23,27 +36,14 @@ func scanRowsForDialect(rows *sql.Rows, dialect string) ([]map[string]interface{ colTypes = nil } + scanner := newQueryRowScanner(columns, colTypes, dialect) resultData := make([]map[string]interface{}, 0) for rows.Next() { - values := make([]interface{}, len(columns)) - valuePtrs := make([]interface{}, len(columns)) - for i := range columns { - valuePtrs[i] = &values[i] - } - - if err := rows.Scan(valuePtrs...); err != nil { + entry, err := scanner.scanCurrentRow(rows) + if err != nil { continue } - - entry := make(map[string]interface{}, len(columns)) - for i, col := range columns { - dbTypeName := "" - if colTypes != nil && i < len(colTypes) && colTypes[i] != nil { - dbTypeName = colTypes[i].DatabaseTypeName() - } - entry[col] = normalizeQueryValueWithDBTypeAndDialect(values[i], dbTypeName, dialect) - } resultData = append(resultData, entry) } @@ -53,6 +53,95 @@ func scanRowsForDialect(rows *sql.Rows, dialect string) ([]map[string]interface{ return resultData, columns, nil } +func streamRowsForDialect(rows *sql.Rows, dialect string, consumer QueryStreamConsumer) error { + if consumer == nil { + return fmt.Errorf("query stream consumer required") + } + + columns, err := rows.Columns() + if err != nil { + return err + } + columns = ensureUniqueQueryColumnNames(columns) + + colTypes, err := rows.ColumnTypes() + if err != nil || len(colTypes) != len(columns) { + colTypes = nil + } + + scanner := newQueryRowScanner(columns, colTypes, dialect) + if err := consumer.SetColumns(columns); err != nil { + return err + } + valueConsumer, useValueConsumer := consumer.(QueryStreamValueConsumer) + + for rows.Next() { + if useValueConsumer { + values, err := scanner.scanCurrentRowValues(rows) + if err != nil { + continue + } + if err := valueConsumer.ConsumeRowValues(values); err != nil { + return err + } + continue + } + entry, err := scanner.scanCurrentRow(rows) + if err != nil { + continue + } + if err := consumer.ConsumeRow(entry); err != nil { + return err + } + } + + return rows.Err() +} + +func newQueryRowScanner(columns []string, colTypes []*sql.ColumnType, dialect string) *queryRowScanner { + values := make([]interface{}, len(columns)) + valuePtrs := make([]interface{}, len(columns)) + for i := range columns { + valuePtrs[i] = &values[i] + } + dbTypeNames := make([]string, len(columns)) + for i := range columns { + if colTypes != nil && i < len(colTypes) && colTypes[i] != nil { + dbTypeNames[i] = colTypes[i].DatabaseTypeName() + } + } + return &queryRowScanner{ + columns: columns, + dbTypeNames: dbTypeNames, + dialect: dialect, + values: values, + normalized: make([]interface{}, len(columns)), + valuePtrs: valuePtrs, + } +} + +func (s *queryRowScanner) scanCurrentRowValues(rows *sql.Rows) ([]interface{}, error) { + if err := rows.Scan(s.valuePtrs...); err != nil { + return nil, err + } + for i := range s.columns { + s.normalized[i] = normalizeQueryValueWithDBTypeAndDialect(s.values[i], s.dbTypeNames[i], s.dialect) + } + return s.normalized, nil +} + +func (s *queryRowScanner) scanCurrentRow(rows *sql.Rows) (map[string]interface{}, error) { + normalized, err := s.scanCurrentRowValues(rows) + if err != nil { + return nil, err + } + entry := make(map[string]interface{}, len(s.columns)) + for i, col := range s.columns { + entry[col] = normalized[i] + } + return entry, nil +} + func ensureUniqueQueryColumnNames(columns []string) []string { if len(columns) == 0 { return columns diff --git a/internal/db/sqlserver_impl.go b/internal/db/sqlserver_impl.go index 34839d7..429a488 100644 --- a/internal/db/sqlserver_impl.go +++ b/internal/db/sqlserver_impl.go @@ -385,6 +385,23 @@ func (e *sqlServerSessionExecer) QueryContext(ctx context.Context, query string) return rows, columns, err } +func (e *sqlServerSessionExecer) StreamQueryContext(ctx context.Context, query string, consumer QueryStreamConsumer) error { + if e == nil || e.conn == nil { + return fmt.Errorf("连接未打开") + } + retmsg := &sqlexp.ReturnMessage{} + rows, err := e.conn.QueryContext(ctx, query, retmsg) + if err != nil { + return err + } + defer rows.Close() + return streamRows(rows, consumer) +} + +func (e *sqlServerSessionExecer) StreamQuery(query string, consumer QueryStreamConsumer) error { + return e.StreamQueryContext(context.Background(), query, consumer) +} + func (e *sqlServerSessionExecer) QueryWithMessages(query string) ([]map[string]interface{}, []string, []string, error) { return e.QueryContextWithMessages(context.Background(), query) } diff --git a/internal/db/tdengine_impl.go b/internal/db/tdengine_impl.go index 69782e8..02270cb 100644 --- a/internal/db/tdengine_impl.go +++ b/internal/db/tdengine_impl.go @@ -168,6 +168,24 @@ func (t *TDengineDB) Query(query string) ([]map[string]interface{}, []string, er return scanRows(rows) } +func (t *TDengineDB) StreamQueryContext(ctx context.Context, query string, consumer QueryStreamConsumer) error { + if t.conn == nil { + return fmt.Errorf("连接未打开") + } + + rows, err := t.conn.QueryContext(ctx, query) + if err != nil { + return err + } + defer rows.Close() + + return streamRows(rows, consumer) +} + +func (t *TDengineDB) StreamQuery(query string, consumer QueryStreamConsumer) error { + return t.StreamQueryContext(context.Background(), query, consumer) +} + func (t *TDengineDB) ExecContext(ctx context.Context, query string) (int64, error) { if t.conn == nil { return 0, fmt.Errorf("连接未打开")