Files
MyGoNavi/internal/app/methods_file_export_test.go
Syngnat c8fe90cbee ️ perf(import-export): 降低 OceanBase 导出链路内存占用
- 为 optional driver-agent 补齐 streamQuery 分片协议,避免大结果集整批缓冲到内存
- 在 OceanBase 整表导出和查询结果导出前强校验 driver-agent revision,旧版代理直接拦截并提示重装
- 为 driver-agent 增加大查询和流式导出完成后的 GC/FreeOSMemory 回收逻辑
- 补充导出前校验、流式分片消费和 agent 内存回收的定向测试
- 更新 driver-agent revisions 以匹配新的流式导出协议
2026-06-18 11:32:08 +08:00

1175 lines
35 KiB
Go
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
package app
import (
"bufio"
"bytes"
"context"
"encoding/json"
"fmt"
"io"
"os"
"strings"
"testing"
"time"
"GoNavi-Wails/internal/connection"
"GoNavi-Wails/internal/db"
"github.com/xuri/excelize/v2"
)
type fakeExportQueryDB struct {
data []map[string]interface{}
cols []string
err error
defs []connection.ColumnDefinition
lastQuery string
lastContextTimeout time.Duration
hasContextDeadline bool
}
type fakeStreamExportDB struct {
fakeExportQueryDB
streamData []map[string]interface{}
streamCols []string
streamHits int
queryHits int
}
type fakeValueStreamExportDB struct {
fakeExportQueryDB
streamCols []string
streamValues [][]interface{}
streamHits int
queryHits int
valueHits int
}
type fakeGeneratedValueStreamExportDB struct {
streamCols []string
rowCount int
streamHits int
valueHits int
}
func (f *fakeExportQueryDB) Connect(config connection.ConnectionConfig) error { return nil }
func (f *fakeExportQueryDB) Close() error { return nil }
func (f *fakeExportQueryDB) Ping() error { return nil }
func (f *fakeExportQueryDB) Query(query string) ([]map[string]interface{}, []string, error) {
f.lastQuery = query
return f.data, f.cols, f.err
}
func (f *fakeExportQueryDB) QueryContext(ctx context.Context, query string) ([]map[string]interface{}, []string, error) {
f.lastQuery = query
if deadline, ok := ctx.Deadline(); ok {
f.hasContextDeadline = true
f.lastContextTimeout = time.Until(deadline)
}
return f.data, f.cols, f.err
}
func (f *fakeExportQueryDB) Exec(query string) (int64, error) { return 0, nil }
func (f *fakeExportQueryDB) GetDatabases() ([]string, error) { return nil, nil }
func (f *fakeExportQueryDB) GetTables(dbName string) ([]string, error) {
return nil, nil
}
func (f *fakeExportQueryDB) GetCreateStatement(dbName, tableName string) (string, error) {
return "", nil
}
func (f *fakeExportQueryDB) GetColumns(dbName, tableName string) ([]connection.ColumnDefinition, error) {
return f.defs, nil
}
func (f *fakeExportQueryDB) GetAllColumns(dbName string) ([]connection.ColumnDefinitionWithTable, error) {
return nil, nil
}
func (f *fakeExportQueryDB) GetIndexes(dbName, tableName string) ([]connection.IndexDefinition, error) {
return nil, nil
}
func (f *fakeExportQueryDB) GetForeignKeys(dbName, tableName string) ([]connection.ForeignKeyDefinition, error) {
return nil, nil
}
func (f *fakeExportQueryDB) GetTriggers(dbName, tableName string) ([]connection.TriggerDefinition, error) {
return nil, nil
}
func (f *fakeStreamExportDB) Query(query string) ([]map[string]interface{}, []string, error) {
f.queryHits++
return f.fakeExportQueryDB.Query(query)
}
func (f *fakeStreamExportDB) QueryContext(ctx context.Context, query string) ([]map[string]interface{}, []string, error) {
f.queryHits++
return f.fakeExportQueryDB.QueryContext(ctx, query)
}
func (f *fakeStreamExportDB) StreamQuery(query string, consumer db.QueryStreamConsumer) error {
return f.StreamQueryContext(context.Background(), query, consumer)
}
func (f *fakeStreamExportDB) StreamQueryContext(_ context.Context, query string, consumer db.QueryStreamConsumer) error {
f.streamHits++
f.lastQuery = query
if err := consumer.SetColumns(f.streamCols); err != nil {
return err
}
for _, row := range f.streamData {
if err := consumer.ConsumeRow(row); err != nil {
return err
}
}
return nil
}
func (f *fakeValueStreamExportDB) Query(query string) ([]map[string]interface{}, []string, error) {
f.queryHits++
return f.fakeExportQueryDB.Query(query)
}
func (f *fakeValueStreamExportDB) QueryContext(ctx context.Context, query string) ([]map[string]interface{}, []string, error) {
f.queryHits++
return f.fakeExportQueryDB.QueryContext(ctx, query)
}
func (f *fakeValueStreamExportDB) StreamQuery(query string, consumer db.QueryStreamConsumer) error {
return f.StreamQueryContext(context.Background(), query, consumer)
}
func (f *fakeValueStreamExportDB) StreamQueryContext(_ context.Context, query string, consumer db.QueryStreamConsumer) error {
f.streamHits++
f.lastQuery = query
if err := consumer.SetColumns(f.streamCols); err != nil {
return err
}
if valueConsumer, ok := consumer.(db.QueryStreamValueConsumer); ok {
for _, row := range f.streamValues {
f.valueHits++
if err := valueConsumer.ConsumeRowValues(row); err != nil {
return err
}
}
return nil
}
for _, row := range f.streamValues {
entry := make(map[string]interface{}, len(f.streamCols))
for idx, column := range f.streamCols {
if idx < len(row) {
entry[column] = row[idx]
}
}
if err := consumer.ConsumeRow(entry); err != nil {
return err
}
}
return nil
}
func (f *fakeGeneratedValueStreamExportDB) Connect(config connection.ConnectionConfig) error {
return nil
}
func (f *fakeGeneratedValueStreamExportDB) Close() error { return nil }
func (f *fakeGeneratedValueStreamExportDB) Ping() error { return nil }
func (f *fakeGeneratedValueStreamExportDB) Query(query string) ([]map[string]interface{}, []string, error) {
return nil, nil, context.DeadlineExceeded
}
func (f *fakeGeneratedValueStreamExportDB) QueryContext(ctx context.Context, query string) ([]map[string]interface{}, []string, error) {
return nil, nil, context.DeadlineExceeded
}
func (f *fakeGeneratedValueStreamExportDB) Exec(query string) (int64, error) { return 0, nil }
func (f *fakeGeneratedValueStreamExportDB) GetDatabases() ([]string, error) { return nil, nil }
func (f *fakeGeneratedValueStreamExportDB) GetTables(dbName string) ([]string, error) {
return nil, nil
}
func (f *fakeGeneratedValueStreamExportDB) GetCreateStatement(dbName, tableName string) (string, error) {
return "", nil
}
func (f *fakeGeneratedValueStreamExportDB) GetColumns(dbName, tableName string) ([]connection.ColumnDefinition, error) {
return nil, nil
}
func (f *fakeGeneratedValueStreamExportDB) GetAllColumns(dbName string) ([]connection.ColumnDefinitionWithTable, error) {
return nil, nil
}
func (f *fakeGeneratedValueStreamExportDB) GetIndexes(dbName, tableName string) ([]connection.IndexDefinition, error) {
return nil, nil
}
func (f *fakeGeneratedValueStreamExportDB) GetForeignKeys(dbName, tableName string) ([]connection.ForeignKeyDefinition, error) {
return nil, nil
}
func (f *fakeGeneratedValueStreamExportDB) GetTriggers(dbName, tableName string) ([]connection.TriggerDefinition, error) {
return nil, nil
}
func (f *fakeGeneratedValueStreamExportDB) StreamQuery(query string, consumer db.QueryStreamConsumer) error {
return f.StreamQueryContext(context.Background(), query, consumer)
}
func (f *fakeGeneratedValueStreamExportDB) StreamQueryContext(_ context.Context, query string, consumer db.QueryStreamConsumer) error {
f.streamHits++
if err := consumer.SetColumns(f.streamCols); err != nil {
return err
}
valueConsumer, ok := consumer.(db.QueryStreamValueConsumer)
if !ok {
return fmt.Errorf("value stream consumer required")
}
for i := 0; i < f.rowCount; i++ {
f.valueHits++
if err := valueConsumer.ConsumeRowValues([]interface{}{
i + 1,
"benchmark-user",
"plain export payload without timezone marker",
"2026-06-17 12:34:56",
"enabled",
}); err != nil {
return err
}
}
return nil
}
func TestFormatExportCellText_FloatNoScientificNotation(t *testing.T) {
got := formatExportCellText(1.445663e+06)
if strings.Contains(strings.ToLower(got), "e+") || strings.Contains(strings.ToLower(got), "e-") {
t.Fatalf("不应输出科学计数法got=%q", got)
}
if got != "1445663" {
t.Fatalf("浮点整值导出异常want=%q got=%q", "1445663", got)
}
}
func TestWriteRowsToFile_Markdown_NumberKeepPlainText(t *testing.T) {
f, err := os.CreateTemp("", "gonavi-export-*.md")
if err != nil {
t.Fatalf("创建临时文件失败: %v", err)
}
defer os.Remove(f.Name())
defer f.Close()
data := []map[string]interface{}{
{"id": 1.445663e+06},
}
columns := []string{"id"}
if err := writeRowsToFile(f, data, columns, ExportFileOptions{Format: "md"}); err != nil {
t.Fatalf("写入 md 失败: %v", err)
}
contentBytes, err := os.ReadFile(f.Name())
if err != nil {
t.Fatalf("读取 md 失败: %v", err)
}
content := string(contentBytes)
if strings.Contains(strings.ToLower(content), "e+") || strings.Contains(strings.ToLower(content), "e-") {
t.Fatalf("md 导出包含科学计数法: %s", content)
}
if !strings.Contains(content, "| 1445663 |") {
t.Fatalf("md 导出未保留整数字面量content=%s", content)
}
}
func TestWriteRowsToFile_JSON_NumberKeepPlainText(t *testing.T) {
f, err := os.CreateTemp("", "gonavi-export-*.json")
if err != nil {
t.Fatalf("创建临时文件失败: %v", err)
}
defer os.Remove(f.Name())
defer f.Close()
data := []map[string]interface{}{
{"id": 1.445663e+06},
}
columns := []string{"id"}
if err := writeRowsToFile(f, data, columns, ExportFileOptions{Format: "json"}); err != nil {
t.Fatalf("写入 json 失败: %v", err)
}
contentBytes, err := os.ReadFile(f.Name())
if err != nil {
t.Fatalf("读取 json 失败: %v", err)
}
content := string(contentBytes)
if strings.Contains(strings.ToLower(content), "e+") || strings.Contains(strings.ToLower(content), "e-") {
t.Fatalf("json 导出包含科学计数法: %s", content)
}
var decoded []map[string]json.Number
decoder := json.NewDecoder(bytes.NewReader(contentBytes))
decoder.UseNumber()
if err := decoder.Decode(&decoded); err != nil {
t.Fatalf("解析导出 json 失败: %v", err)
}
if len(decoded) != 1 {
t.Fatalf("导出行数异常got=%d", len(decoded))
}
if decoded[0]["id"].String() != "1445663" {
t.Fatalf("json 数值格式异常want=1445663 got=%s", decoded[0]["id"].String())
}
}
func TestNormalizeExportJSONValue_LocalDateTimeString_NoTimezoneShift(t *testing.T) {
originalLocal := time.Local
time.Local = time.FixedZone("UTC+8", 8*60*60)
defer func() { time.Local = originalLocal }()
got := normalizeExportJSONValue("2026-04-07 18:44:32")
if got != "2026-04-07 18:44:32" {
t.Fatalf("本地无时区字符串不应发生时区偏移want=%q got=%v", "2026-04-07 18:44:32", got)
}
}
func TestFormatExportCellText_TimeValue_KeepWallClock(t *testing.T) {
originalLocal := time.Local
time.Local = time.FixedZone("UTC+8", 8*60*60)
defer func() { time.Local = originalLocal }()
utc := time.Date(2026, 4, 7, 10, 44, 32, 0, time.UTC)
got := formatExportCellText(utc)
if got != "2026-04-07 10:44:32" {
t.Fatalf("time.Time 导出应保持原始钟表时间want=%q got=%q", "2026-04-07 10:44:32", got)
}
}
func TestFormatExportCellText_StringRFC3339_KeepWallClock(t *testing.T) {
originalLocal := time.Local
time.Local = time.FixedZone("UTC+8", 8*60*60)
defer func() { time.Local = originalLocal }()
got := formatExportCellText("2026-04-07T10:44:32Z")
if got != "2026-04-07 10:44:32" {
t.Fatalf("字符串时间导出应保持原始钟表时间want=%q got=%q", "2026-04-07 10:44:32", got)
}
}
func TestFormatExportCellText_PlainString_Untouched(t *testing.T) {
got := formatExportCellText("plain export payload without timezone marker")
if got != "plain export payload without timezone marker" {
t.Fatalf("普通字符串不应被改写got=%q", got)
}
}
func TestParseTemporalString_LocalDateTime_NoTimezoneShift(t *testing.T) {
originalLocal := time.Local
time.Local = time.FixedZone("UTC+8", 8*60*60)
defer func() { time.Local = originalLocal }()
parsed, ok := parseTemporalString("2026-04-07 18:44:32")
if !ok {
t.Fatal("parseTemporalString 应成功解析本地日期时间")
}
if parsed.Local().Format("2006-01-02 15:04:05") != "2026-04-07 18:44:32" {
t.Fatalf("无时区时间解析后不应发生偏移got=%q", parsed.Local().Format("2006-01-02 15:04:05"))
}
}
func TestParseTemporalString_RFC3339_KeepWallClock(t *testing.T) {
originalLocal := time.Local
time.Local = time.FixedZone("UTC+8", 8*60*60)
defer func() { time.Local = originalLocal }()
parsed, ok := parseTemporalString("2026-04-07T10:44:32Z")
if !ok {
t.Fatal("parseTemporalString 应成功解析 RFC3339")
}
if parsed.Format("2006-01-02 15:04:05") != "2026-04-07 10:44:32" {
t.Fatalf("RFC3339 解析后应保持原始钟表时间got=%q", parsed.Format("2006-01-02 15:04:05"))
}
}
func TestNormalizeExportJSONValue_TimeValue_KeepWallClock(t *testing.T) {
originalLocal := time.Local
time.Local = time.FixedZone("UTC+8", 8*60*60)
defer func() { time.Local = originalLocal }()
utc := time.Date(2026, 4, 7, 18, 44, 32, 0, time.UTC)
got := normalizeExportJSONValue(utc)
if got != "2026-04-07 18:44:32" {
t.Fatalf("JSON 导出 time.Time 应保持原始钟表时间want=%q got=%v", "2026-04-07 18:44:32", got)
}
}
func TestQueryDataForExport_UsesMinimumTimeout(t *testing.T) {
fake := &fakeExportQueryDB{
data: []map[string]interface{}{{"v": 1}},
cols: []string{"v"},
}
_, _, err := queryDataForExport(fake, connection.ConnectionConfig{Timeout: 10}, "SELECT 1")
if err != nil {
t.Fatalf("queryDataForExport 返回错误: %v", err)
}
if !fake.hasContextDeadline {
t.Fatal("queryDataForExport 应设置 context deadline")
}
if fake.lastQuery != "SELECT 1" {
t.Fatalf("queryDataForExport 查询语句异常want=%q got=%q", "SELECT 1", fake.lastQuery)
}
lowerBound := minExportQueryTimeout - 5*time.Second
upperBound := minExportQueryTimeout + 5*time.Second
if fake.lastContextTimeout < lowerBound || fake.lastContextTimeout > upperBound {
t.Fatalf("导出最小超时异常want≈%s got=%s", minExportQueryTimeout, fake.lastContextTimeout)
}
}
func TestQueryDataForExport_UsesLargerConfiguredTimeout(t *testing.T) {
fake := &fakeExportQueryDB{
data: []map[string]interface{}{{"v": 1}},
cols: []string{"v"},
}
_, _, err := queryDataForExport(fake, connection.ConnectionConfig{Timeout: 900}, "SELECT 1")
if err != nil {
t.Fatalf("queryDataForExport 返回错误: %v", err)
}
if !fake.hasContextDeadline {
t.Fatal("queryDataForExport 应设置 context deadline")
}
expected := 900 * time.Second
lowerBound := expected - 5*time.Second
upperBound := expected + 5*time.Second
if fake.lastContextTimeout < lowerBound || fake.lastContextTimeout > upperBound {
t.Fatalf("导出配置超时异常want≈%s got=%s", expected, fake.lastContextTimeout)
}
}
func TestResolveExportTotalRowsFromRows_PrefersNamedTotalColumn(t *testing.T) {
total, ok := resolveExportTotalRowsFromRows([]map[string]interface{}{
{"COUNT": "96000", "other": 1},
})
if !ok {
t.Fatal("应成功解析导出总行数")
}
if total != 96000 {
t.Fatalf("解析导出总行数错误want=%d got=%d", 96000, total)
}
}
func TestTryResolveExportTableTotalRows_UsesCountQuery(t *testing.T) {
fake := &fakeExportQueryDB{
data: []map[string]interface{}{{"total": int64(128000)}},
cols: []string{"total"},
}
total, ok := tryResolveExportTableTotalRows(
fake,
connection.ConnectionConfig{Type: "mysql", Timeout: 10},
"SYS.test",
)
if !ok {
t.Fatal("应成功解析整表导出总行数")
}
if total != 128000 {
t.Fatalf("整表导出总行数错误want=%d got=%d", 128000, total)
}
if fake.lastQuery != "SELECT COUNT(*) AS total FROM `SYS`.`test`" {
t.Fatalf("整表导出统计 SQL 错误got=%q", fake.lastQuery)
}
}
func TestVerifyOptionalDriverAgentReadyForExport_RejectsStaleAgent(t *testing.T) {
originalProbe := optionalDriverAgentMetadataProbe
originalResolvePath := resolveOptionalDriverAgentExecutablePathFunc
t.Cleanup(func() {
optionalDriverAgentMetadataProbe = originalProbe
resolveOptionalDriverAgentExecutablePathFunc = originalResolvePath
})
resolveOptionalDriverAgentExecutablePathFunc = func(downloadDir string, driverType string) (string, error) {
return "/tmp/oceanbase-driver-agent", nil
}
optionalDriverAgentMetadataProbe = func(driverType string, executablePath string) (db.OptionalDriverAgentMetadata, error) {
return db.OptionalDriverAgentMetadata{
DriverType: driverType,
AgentRevision: "src-stale-agent",
}, nil
}
err := verifyOptionalDriverAgentReadyForExport(connection.ConnectionConfig{Type: "oceanbase"})
if err == nil {
t.Fatal("预期旧版 OceanBase driver-agent 被导出前校验拦截")
}
if !strings.Contains(err.Error(), "流式协议") {
t.Fatalf("错误信息应说明需要流式协议got=%q", err.Error())
}
}
func TestVerifyOptionalDriverAgentReadyForExport_SkipsBuiltInDriver(t *testing.T) {
originalResolvePath := resolveOptionalDriverAgentExecutablePathFunc
t.Cleanup(func() {
resolveOptionalDriverAgentExecutablePathFunc = originalResolvePath
})
resolveOptionalDriverAgentExecutablePathFunc = func(downloadDir string, driverType string) (string, error) {
t.Fatalf("内置驱动导出不应探测 optional driver-agent 路径")
return "", nil
}
if err := verifyOptionalDriverAgentReadyForExport(connection.ConnectionConfig{Type: "mysql"}); err != nil {
t.Fatalf("内置驱动导出不应被 optional driver-agent 校验阻断: %v", err)
}
}
func TestExportQueryResultToFile_UsesStreamQueryPath(t *testing.T) {
f, err := os.CreateTemp("", "gonavi-export-stream-*.csv")
if err != nil {
t.Fatalf("创建临时文件失败: %v", err)
}
defer os.Remove(f.Name())
defer f.Close()
fake := &fakeStreamExportDB{
fakeExportQueryDB: fakeExportQueryDB{
err: context.DeadlineExceeded,
data: []map[string]interface{}{{"id": 999}},
cols: []string{"id"},
},
streamCols: []string{"id", "name"},
streamData: []map[string]interface{}{
{"id": 1, "name": "alice"},
{"id": 2, "name": "bob"},
},
}
rowCount, columns, err := exportQueryResultToFile(
f,
fake,
connection.ConnectionConfig{Type: "mysql", Timeout: 10},
"SELECT id, name FROM users",
ExportFileOptions{Format: "csv"},
nil,
)
if err != nil {
t.Fatalf("exportQueryResultToFile 返回错误: %v", err)
}
if fake.streamHits != 1 {
t.Fatalf("应优先使用流式查询streamHits=%d", fake.streamHits)
}
if fake.queryHits != 0 {
t.Fatalf("不应回退到缓冲查询queryHits=%d", fake.queryHits)
}
if rowCount != 2 {
t.Fatalf("导出行数异常want=2 got=%d", rowCount)
}
if len(columns) != 2 || columns[0] != "id" || columns[1] != "name" {
t.Fatalf("导出列异常got=%v", columns)
}
contentBytes, err := os.ReadFile(f.Name())
if err != nil {
t.Fatalf("读取导出文件失败: %v", err)
}
content := string(contentBytes)
if !strings.Contains(content, "alice") || !strings.Contains(content, "bob") {
t.Fatalf("流式导出内容异常: %s", content)
}
}
func TestExportQueryResultToFile_UsesValueStreamPathWhenAvailable(t *testing.T) {
f, err := os.CreateTemp("", "gonavi-export-stream-values-*.csv")
if err != nil {
t.Fatalf("创建临时文件失败: %v", err)
}
defer os.Remove(f.Name())
defer f.Close()
fake := &fakeValueStreamExportDB{
streamCols: []string{"id", "name"},
streamValues: [][]interface{}{
{1, "alice"},
{2, "bob"},
},
}
rowCount, columns, err := exportQueryResultToFile(
f,
fake,
connection.ConnectionConfig{Type: "mysql", Timeout: 10},
"SELECT id, name FROM users",
ExportFileOptions{Format: "csv"},
nil,
)
if err != nil {
t.Fatalf("exportQueryResultToFile 返回错误: %v", err)
}
if fake.streamHits != 1 {
t.Fatalf("应优先使用流式查询streamHits=%d", fake.streamHits)
}
if fake.valueHits != 2 {
t.Fatalf("应走值数组流式路径valueHits=%d", fake.valueHits)
}
if fake.queryHits != 0 {
t.Fatalf("不应回退到缓冲查询queryHits=%d", fake.queryHits)
}
if rowCount != 2 {
t.Fatalf("导出行数异常want=2 got=%d", rowCount)
}
if len(columns) != 2 || columns[0] != "id" || columns[1] != "name" {
t.Fatalf("导出列异常got=%v", columns)
}
}
func TestGetExportQueryTimeout_ClickHouseUsesLongerMinimum(t *testing.T) {
timeout := getExportQueryTimeout(connection.ConnectionConfig{
Type: "clickhouse",
Timeout: 30,
})
if timeout != minClickHouseExportQueryTimeout {
t.Fatalf("clickhouse 导出超时下限异常want=%s got=%s", minClickHouseExportQueryTimeout, timeout)
}
}
func TestGetExportQueryTimeout_CustomClickHouseUsesLongerMinimum(t *testing.T) {
timeout := getExportQueryTimeout(connection.ConnectionConfig{
Type: "custom",
Driver: "clickhouse",
Timeout: 30,
})
if timeout != minClickHouseExportQueryTimeout {
t.Fatalf("custom clickhouse 导出超时下限异常want=%s got=%s", minClickHouseExportQueryTimeout, timeout)
}
}
func TestLooksLikeSelectOrWith_AllowsInnerJoinQueryAfterLeadingComments(t *testing.T) {
query := `
-- query result export
/* generated by query editor */
SELECT o.id, c.name
FROM orders o
INNER JOIN customers c ON c.id = o.customer_id
`
if !looksLikeSelectOrWith(query) {
t.Fatalf("带前置注释的 INNER JOIN 查询应允许导出query=%q", query)
}
}
func TestWriteRowsToFile_HTML_EscapeAndStyle(t *testing.T) {
f, err := os.CreateTemp("", "gonavi-export-*.html")
if err != nil {
t.Fatalf("创建临时文件失败: %v", err)
}
defer os.Remove(f.Name())
defer f.Close()
data := []map[string]interface{}{
{
"name": "<script>alert(1)</script>",
"note": "line1\nline2",
"nullable": nil,
},
}
columns := []string{"name", "note", "nullable"}
if err := writeRowsToFile(f, data, columns, ExportFileOptions{Format: "html"}); err != nil {
t.Fatalf("写入 html 失败: %v", err)
}
contentBytes, err := os.ReadFile(f.Name())
if err != nil {
t.Fatalf("读取 html 失败: %v", err)
}
content := string(contentBytes)
if !strings.Contains(content, "<!DOCTYPE html>") {
t.Fatalf("html 导出缺少 doctype: %s", content)
}
if !strings.Contains(content, "position: sticky") {
t.Fatalf("html 导出缺少表头吸顶样式: %s", content)
}
if !strings.Contains(content, "tbody tr:nth-child(even)") {
t.Fatalf("html 导出缺少斑马纹样式: %s", content)
}
if !strings.Contains(content, "&lt;script&gt;alert(1)&lt;/script&gt;") {
t.Fatalf("html 导出未进行 XSS 转义: %s", content)
}
if strings.Contains(content, "<script>alert(1)</script>") {
t.Fatalf("html 导出包含未转义脚本: %s", content)
}
if !strings.Contains(content, "line1<br>line2") {
t.Fatalf("html 导出换行未转为 <br>: %s", content)
}
if !strings.Contains(content, "<td>NULL</td>") {
t.Fatalf("html 导出空值显示异常: %s", content)
}
}
func TestWriteRowsToFile_HTML_EscapeHeader(t *testing.T) {
f, err := os.CreateTemp("", "gonavi-export-*.html")
if err != nil {
t.Fatalf("创建临时文件失败: %v", err)
}
defer os.Remove(f.Name())
defer f.Close()
columnName := "<b>name</b>"
data := []map[string]interface{}{{columnName: "ok"}}
if err := writeRowsToFile(f, data, []string{columnName}, ExportFileOptions{Format: "html"}); err != nil {
t.Fatalf("写入 html 失败: %v", err)
}
contentBytes, _ := os.ReadFile(f.Name())
content := string(contentBytes)
if !strings.Contains(content, "<th>&lt;b&gt;name&lt;/b&gt;</th>") || strings.Contains(content, "<th><b>name</b></th>") {
t.Fatalf("html 表头未正确转义: %s", content)
}
}
func TestWriteRowsToFile_XLSX_SplitsByMaxRowsPerSheet(t *testing.T) {
f, err := os.CreateTemp("", "gonavi-export-*.xlsx")
if err != nil {
t.Fatalf("创建临时文件失败: %v", err)
}
defer os.Remove(f.Name())
defer f.Close()
data := []map[string]interface{}{
{"id": 1, "name": "alice"},
{"id": 2, "name": "bob"},
{"id": 3, "name": "carol"},
}
columns := []string{"id", "name"}
if err := writeRowsToFile(f, data, columns, ExportFileOptions{
Format: "xlsx",
XLSXMaxRowsPerSheet: 2,
}); err != nil {
t.Fatalf("写入 xlsx 失败: %v", err)
}
workbook, err := excelize.OpenFile(f.Name())
if err != nil {
t.Fatalf("打开 xlsx 失败: %v", err)
}
defer workbook.Close()
sheets := workbook.GetSheetList()
if len(sheets) != 2 {
t.Fatalf("sheet 数量异常want=2 got=%d (%v)", len(sheets), sheets)
}
rows1, err := workbook.GetRows("Sheet1")
if err != nil {
t.Fatalf("读取 Sheet1 失败: %v", err)
}
if len(rows1) != 3 {
t.Fatalf("Sheet1 行数异常want=3 got=%d", len(rows1))
}
rows2, err := workbook.GetRows("Sheet2")
if err != nil {
t.Fatalf("读取 Sheet2 失败: %v", err)
}
if len(rows2) != 2 {
t.Fatalf("Sheet2 行数异常want=2 got=%d", len(rows2))
}
if rows2[1][1] != "carol" {
t.Fatalf("Sheet2 数据异常want=%q got=%q", "carol", rows2[1][1])
}
}
func benchmarkExportRows(rowCount int) ([]map[string]interface{}, []string) {
columns := []string{"id", "name", "note", "created_at", "status"}
rows := make([]map[string]interface{}, rowCount)
for i := 0; i < rowCount; i++ {
rows[i] = map[string]interface{}{
"id": i + 1,
"name": "benchmark-user",
"note": "plain export payload without timezone marker",
"created_at": "2026-06-17 12:34:56",
"status": "enabled",
}
}
return rows, columns
}
func benchmarkExportRowValues(rowCount int) ([][]interface{}, []string) {
columns := []string{"id", "name", "note", "created_at", "status"}
rows := make([][]interface{}, rowCount)
for i := 0; i < rowCount; i++ {
rows[i] = []interface{}{
i + 1,
"benchmark-user",
"plain export payload without timezone marker",
"2026-06-17 12:34:56",
"enabled",
}
}
return rows, columns
}
func BenchmarkFormatExportCellText_PlainString(b *testing.B) {
b.ReportAllocs()
for i := 0; i < b.N; i++ {
_ = formatExportCellText("plain export payload without timezone marker")
}
}
func BenchmarkWriteRowsToFile_XLSX_20000Rows(b *testing.B) {
rows, columns := benchmarkExportRows(20000)
b.ReportAllocs()
for i := 0; i < b.N; i++ {
f, err := os.CreateTemp("", "gonavi-export-bench-*.xlsx")
if err != nil {
b.Fatalf("创建临时文件失败: %v", err)
}
name := f.Name()
if err := writeRowsToFile(f, rows, columns, ExportFileOptions{Format: "xlsx"}); err != nil {
_ = os.Remove(name)
b.Fatalf("写入 xlsx 失败: %v", err)
}
if err := os.Remove(name); err != nil {
b.Fatalf("删除临时文件失败: %v", err)
}
}
}
func BenchmarkExportQueryResultToFile_XLSX_StreamMap_20000Rows(b *testing.B) {
rows, columns := benchmarkExportRows(20000)
streamDB := &fakeStreamExportDB{
streamCols: columns,
streamData: rows,
}
b.ReportAllocs()
for i := 0; i < b.N; i++ {
f, err := os.CreateTemp("", "gonavi-export-stream-map-*.xlsx")
if err != nil {
b.Fatalf("创建临时文件失败: %v", err)
}
name := f.Name()
if _, _, err := exportQueryResultToFile(
f,
streamDB,
connection.ConnectionConfig{Type: "mysql", Timeout: 10},
"SELECT * FROM users",
ExportFileOptions{Format: "xlsx"},
nil,
); err != nil {
_ = os.Remove(name)
b.Fatalf("流式 map 导出失败: %v", err)
}
if err := os.Remove(name); err != nil {
b.Fatalf("删除临时文件失败: %v", err)
}
}
}
func BenchmarkExportQueryResultToFile_XLSX_StreamValues_20000Rows(b *testing.B) {
rows, columns := benchmarkExportRowValues(20000)
streamDB := &fakeValueStreamExportDB{
streamCols: columns,
streamValues: rows,
}
b.ReportAllocs()
for i := 0; i < b.N; i++ {
f, err := os.CreateTemp("", "gonavi-export-stream-values-*.xlsx")
if err != nil {
b.Fatalf("创建临时文件失败: %v", err)
}
name := f.Name()
if _, _, err := exportQueryResultToFile(
f,
streamDB,
connection.ConnectionConfig{Type: "mysql", Timeout: 10},
"SELECT * FROM users",
ExportFileOptions{Format: "xlsx"},
nil,
); err != nil {
_ = os.Remove(name)
b.Fatalf("流式值数组导出失败: %v", err)
}
if err := os.Remove(name); err != nil {
b.Fatalf("删除临时文件失败: %v", err)
}
}
}
func BenchmarkExportQueryResultToFile_XLSX_StreamGenerated_50000Rows(b *testing.B) {
streamDB := &fakeGeneratedValueStreamExportDB{
streamCols: []string{"id", "name", "note", "created_at", "status"},
rowCount: 50000,
}
b.ReportAllocs()
for i := 0; i < b.N; i++ {
f, err := os.CreateTemp("", "gonavi-export-stream-generated-*.xlsx")
if err != nil {
b.Fatalf("创建临时文件失败: %v", err)
}
name := f.Name()
if _, _, err := exportQueryResultToFile(
f,
streamDB,
connection.ConnectionConfig{Type: "mysql", Timeout: 10},
"SELECT * FROM users",
ExportFileOptions{Format: "xlsx"},
nil,
); err != nil {
_ = os.Remove(name)
b.Fatalf("流式生成导出失败: %v", err)
}
if err := os.Remove(name); err != nil {
b.Fatalf("删除临时文件失败: %v", err)
}
}
}
func BenchmarkDumpTableSQL_SQLBackup_StreamMap_20000Rows(b *testing.B) {
rows, columns := benchmarkExportRows(20000)
streamDB := &fakeStreamExportDB{
streamCols: columns,
streamData: rows,
}
b.ReportAllocs()
for i := 0; i < b.N; i++ {
writer := bufio.NewWriterSize(io.Discard, 1024*1024)
if err := dumpTableSQL(
writer,
streamDB,
connection.ConnectionConfig{Type: "mysql"},
"app",
"users",
false,
true,
map[string]string{},
); err != nil {
b.Fatalf("SQL 备份导出失败: %v", err)
}
if err := writer.Flush(); err != nil {
b.Fatalf("flush SQL 备份失败: %v", err)
}
}
}
func BenchmarkDumpTableSQL_SQLBackup_StreamValues_20000Rows(b *testing.B) {
rows, columns := benchmarkExportRowValues(20000)
streamDB := &fakeValueStreamExportDB{
streamCols: columns,
streamValues: rows,
}
b.ReportAllocs()
for i := 0; i < b.N; i++ {
writer := bufio.NewWriterSize(io.Discard, 1024*1024)
if err := dumpTableSQL(
writer,
streamDB,
connection.ConnectionConfig{Type: "mysql"},
"app",
"users",
false,
true,
map[string]string{},
); err != nil {
b.Fatalf("SQL 备份导出失败: %v", err)
}
if err := writer.Flush(); err != nil {
b.Fatalf("flush SQL 备份失败: %v", err)
}
}
}
func TestFormatImportSQLValue_NormalizesTimestampWithoutTimezone(t *testing.T) {
got := formatImportSQLValue("postgres", "timestamp without time zone", "2026-01-21T18:32:26+08:00")
if got != "'2026-01-21 18:32:26'" {
t.Fatalf("时间字面量归一化异常want=%q got=%q", "'2026-01-21 18:32:26'", got)
}
}
func TestFormatImportSQLValue_LeavesTextLiteralUntouched(t *testing.T) {
got := formatImportSQLValue("postgres", "text", "2026-01-21T18:32:26+08:00")
if got != "'2026-01-21T18:32:26+08:00'" {
t.Fatalf("文本字段不应被归一化want=%q got=%q", "'2026-01-21T18:32:26+08:00'", got)
}
}
func TestFormatImportSQLValue_PostgresBooleanColumnUsesBooleanLiteral(t *testing.T) {
cases := []struct {
name string
dbType string
columnType string
value interface{}
want string
}{
{name: "postgres bool true", dbType: "postgres", columnType: "boolean", value: true, want: "true"},
{name: "postgres bool false", dbType: "postgres", columnType: "bool", value: false, want: "false"},
{name: "pg catalog bool string", dbType: "postgres", columnType: "pg_catalog.bool", value: "t", want: "true"},
{name: "highgo boolean bytes", dbType: "highgo", columnType: "boolean", value: []byte("0"), want: "false"},
{name: "mysql keeps numeric bool", dbType: "mysql", columnType: "tinyint(1)", value: true, want: "1"},
}
for _, tc := range cases {
t.Run(tc.name, func(t *testing.T) {
got := formatImportSQLValue(tc.dbType, tc.columnType, tc.value)
if got != tc.want {
t.Fatalf("布尔字面量异常want=%q got=%q", tc.want, got)
}
})
}
}
func TestDumpTableSQL_PostgresBooleanBackupUsesBooleanLiterals(t *testing.T) {
fake := &fakeExportQueryDB{
data: []map[string]interface{}{
{"active": true, "archived": false},
},
cols: []string{"active", "archived"},
defs: []connection.ColumnDefinition{
{Name: "active", Type: "boolean"},
{Name: "archived", Type: "bool"},
},
}
var buf bytes.Buffer
writer := bufio.NewWriter(&buf)
err := dumpTableSQL(
writer,
fake,
connection.ConnectionConfig{Type: "postgres"},
"public",
"orders",
false,
true,
map[string]string{},
)
if err != nil {
t.Fatalf("dumpTableSQL 返回错误: %v", err)
}
if err := writer.Flush(); err != nil {
t.Fatalf("flush 导出 SQL 失败: %v", err)
}
content := buf.String()
if !strings.Contains(content, `INSERT INTO "public"."orders" ("active", "archived") VALUES (true, false);`) {
t.Fatalf("PostgreSQL bool 备份应使用 true/false 字面量content=%s", content)
}
if strings.Contains(content, "VALUES (1, 0)") {
t.Fatalf("PostgreSQL bool 备份不应输出数字布尔值content=%s", content)
}
}
func TestDumpTableSQL_MySQLBackupBatchesRowsIntoMultiValueInsert(t *testing.T) {
fake := &fakeValueStreamExportDB{
streamCols: []string{"id", "name"},
streamValues: [][]interface{}{
{1, "alice"},
{2, "bob"},
{3, "carol"},
},
}
var buf bytes.Buffer
writer := bufio.NewWriter(&buf)
err := dumpTableSQL(
writer,
fake,
connection.ConnectionConfig{Type: "mysql"},
"app",
"users",
false,
true,
map[string]string{},
)
if err != nil {
t.Fatalf("dumpTableSQL 返回错误: %v", err)
}
if err := writer.Flush(); err != nil {
t.Fatalf("flush 导出 SQL 失败: %v", err)
}
content := buf.String()
if strings.Count(content, "INSERT INTO `app`.`users`") != 1 {
t.Fatalf("MySQL 备份应合并为单条批量 INSERTcontent=%s", content)
}
if !strings.Contains(content, "VALUES (1, 'alice'),\n(2, 'bob'),\n(3, 'carol');") {
t.Fatalf("MySQL 批量 INSERT 内容异常content=%s", content)
}
}
func TestDumpTableSQL_OracleBackupBatchesRowsIntoInsertAll(t *testing.T) {
fake := &fakeValueStreamExportDB{
streamCols: []string{"id", "name"},
streamValues: [][]interface{}{
{1, "alice"},
{2, "bob"},
},
}
var buf bytes.Buffer
writer := bufio.NewWriter(&buf)
err := dumpTableSQL(
writer,
fake,
connection.ConnectionConfig{Type: "oracle"},
"APP",
"USERS",
false,
true,
map[string]string{},
)
if err != nil {
t.Fatalf("dumpTableSQL 返回错误: %v", err)
}
if err := writer.Flush(); err != nil {
t.Fatalf("flush 导出 SQL 失败: %v", err)
}
content := buf.String()
if strings.Count(content, "INSERT ALL") != 1 {
t.Fatalf("Oracle 备份应合并为单条 INSERT ALLcontent=%s", content)
}
if !strings.Contains(content, "INTO \"APP\".\"USERS\" (\"id\", \"name\") VALUES (1, 'alice')\n INTO \"APP\".\"USERS\" (\"id\", \"name\") VALUES (2, 'bob')\nSELECT 1 FROM DUAL;") {
t.Fatalf("Oracle INSERT ALL 内容异常content=%s", content)
}
}
func TestFilterExportObjectsBySchema_PostgresQualifiedObjectsOnly(t *testing.T) {
got := filterExportObjectsBySchema(
connection.ConnectionConfig{Type: "postgres"},
"app_db",
[]string{"public.users", "sales.orders", "sales.v_orders", "analytics.events"},
"sales",
)
want := []string{"sales.orders", "sales.v_orders"}
if len(got) != len(want) {
t.Fatalf("filtered objects length mismatch, want=%d got=%d (%v)", len(want), len(got), got)
}
for i := range want {
if got[i] != want[i] {
t.Fatalf("filtered objects mismatch at %d, want=%q got=%q", i, want[i], got[i])
}
}
}
func TestFilterExportViewLookupBySchema_PostgresQualifiedViewsOnly(t *testing.T) {
got := filterExportViewLookupBySchema(
connection.ConnectionConfig{Type: "postgres"},
"app_db",
map[string]string{
"public.v_users": "public.v_users",
"sales.v_orders": "sales.v_orders",
"sales.v_summary": "sales.v_summary",
},
"sales",
)
if len(got) != 2 {
t.Fatalf("filtered views length mismatch, want=2 got=%d (%v)", len(got), got)
}
if got["sales.v_orders"] != "sales.v_orders" {
t.Fatalf("expected sales.v_orders to be retained, got=%q", got["sales.v_orders"])
}
if got["sales.v_summary"] != "sales.v_summary" {
t.Fatalf("expected sales.v_summary to be retained, got=%q", got["sales.v_summary"])
}
if _, ok := got["public.v_users"]; ok {
t.Fatalf("expected public.v_users to be filtered out, got=%v", got)
}
}