From 5ab50db51c9571eb5dad8580f31fed39dbb309da Mon Sep 17 00:00:00 2001 From: Syngnat Date: Tue, 26 May 2026 08:27:15 +0800 Subject: [PATCH] =?UTF-8?q?=E2=9A=A1=EF=B8=8F=20perf(sync):=20=E4=BC=98?= =?UTF-8?q?=E5=8C=96=E5=A4=A7=E8=A1=A8=E5=90=8C=E6=AD=A5=E5=88=86=E9=A1=B5?= =?UTF-8?q?=E4=B8=8E=E6=89=B9=E9=87=8F=E5=86=99=E5=85=A5?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - 同步分析和预览改为分页扫描差异,避免一次性加载源表和目标表 - 直接导入与源查询同步支持分页读取和分批提交,降低低内存机器 OOM 风险 - 各数据库 ApplyChanges 统一使用参数化批量 INSERT,减少大表同步 SQL 超时 - MySQL 批量写入按行数和参数数量拆分,兼容超宽表场景 - 补充批量插入、分页差异和源查询同步回归测试 --- internal/db/batch_insert.go | 293 ++++++++++++ internal/db/batch_insert_test.go | 231 ++++++++++ internal/db/clickhouse_impl.go | 28 +- internal/db/custom_impl.go | 46 +- internal/db/duckdb_impl.go | 29 +- internal/db/highgo_impl.go | 34 +- internal/db/kingbase_impl.go | 34 +- internal/db/mariadb_impl.go | 37 +- internal/db/mongodb_impl.go | 38 +- internal/db/mongodb_impl_v1.go | 38 +- internal/db/mysql_impl.go | 71 ++- internal/db/oracle_applychanges_test.go | 53 +++ internal/db/postgres_impl.go | 34 +- internal/db/sqlite_impl.go | 31 +- internal/db/sqlserver_impl.go | 38 +- internal/db/tdengine_impl.go | 29 +- internal/db/vastbase_impl.go | 34 +- internal/sync/analyze.go | 45 +- internal/sync/diff_paging.go | 587 ++++++++++++++++++++++++ internal/sync/direct_import_paging.go | 304 ++++++++++++ internal/sync/preview.go | 140 +++++- internal/sync/schema_migration_test.go | 2 + internal/sync/source_query_paging.go | 286 ++++++++++++ internal/sync/source_query_sync.go | 162 ++++++- internal/sync/source_query_sync_test.go | 324 ++++++++++++- internal/sync/sql_helpers_test.go | 95 +++- internal/sync/sync_engine.go | 122 ++++- 27 files changed, 2846 insertions(+), 319 deletions(-) create mode 100644 internal/db/batch_insert.go create mode 100644 internal/db/batch_insert_test.go create mode 100644 internal/sync/diff_paging.go create mode 100644 internal/sync/direct_import_paging.go create mode 100644 internal/sync/source_query_paging.go diff --git a/internal/db/batch_insert.go b/internal/db/batch_insert.go new file mode 100644 index 0000000..3bdd9cf --- /dev/null +++ b/internal/db/batch_insert.go @@ -0,0 +1,293 @@ +package db + +import ( + "database/sql" + "fmt" + "sort" + "strings" +) + +const ( + defaultBatchInsertRows = 1000 + defaultBatchInsertArgs = 60000 + sqlServerBatchInsertArgs = 2000 + sqliteBatchInsertArgs = 900 +) + +type preparedInsertRow struct { + columns []string + values []interface{} +} + +type parameterizedInsertConfig struct { + Table string + Rows []map[string]interface{} + QuoteColumn func(string) string + Placeholder func(int) string + Value func(string, interface{}) (interface{}, bool) + Arg func(int, string, interface{}) interface{} + Exec func(string, ...interface{}) (sql.Result, error) + MaxRows int + MaxArgs int + EmptyInsertSQL func(string) string + RequireAffected bool +} + +func execParameterizedInsertBatches(config parameterizedInsertConfig) error { + if len(config.Rows) == 0 { + return nil + } + if strings.TrimSpace(config.Table) == "" { + return fmt.Errorf("表名不能为空") + } + if config.QuoteColumn == nil { + return fmt.Errorf("列名引用函数不能为空") + } + if config.Placeholder == nil { + return fmt.Errorf("占位符函数不能为空") + } + if config.Exec == nil { + return fmt.Errorf("执行函数不能为空") + } + if config.Value == nil { + config.Value = func(_ string, value interface{}) (interface{}, bool) { return value, false } + } + if config.Arg == nil { + config.Arg = func(_ int, _ string, value interface{}) interface{} { return value } + } + + groups, order := groupPreparedInsertRows(config.Rows, config.Value) + for _, key := range order { + rows := groups[key] + if len(rows) == 0 { + continue + } + columnCount := len(rows[0].columns) + if columnCount == 0 { + if config.EmptyInsertSQL == nil { + continue + } + for range rows { + res, err := config.Exec(config.EmptyInsertSQL(config.Table)) + if err != nil { + return fmt.Errorf("插入失败:%v", err) + } + if config.RequireAffected { + if err := requireInsertAffected(res); err != nil { + return err + } + } + } + continue + } + + batchSize := batchInsertRowLimit(columnCount, config.MaxRows, config.MaxArgs) + for start := 0; start < len(rows); start += batchSize { + end := start + batchSize + if end > len(rows) { + end = len(rows) + } + if err := execParameterizedInsertBatch(config, rows[start:end]); err != nil { + return err + } + } + } + return nil +} + +func groupPreparedInsertRows(rows []map[string]interface{}, valueFunc func(string, interface{}) (interface{}, bool)) (map[string][]preparedInsertRow, []string) { + groups := make(map[string][]preparedInsertRow) + order := make([]string, 0) + for _, row := range rows { + prepared := prepareInsertRow(row, valueFunc) + key := strings.Join(prepared.columns, "\x00") + if _, ok := groups[key]; !ok { + order = append(order, key) + } + groups[key] = append(groups[key], prepared) + } + return groups, order +} + +func prepareInsertRow(row map[string]interface{}, valueFunc func(string, interface{}) (interface{}, bool)) preparedInsertRow { + columns := make([]string, 0, len(row)) + valuesByColumn := make(map[string]interface{}, len(row)) + for key, value := range row { + column := strings.TrimSpace(key) + if column == "" { + continue + } + nextValue, omit := valueFunc(column, value) + if omit { + continue + } + columns = append(columns, column) + valuesByColumn[column] = nextValue + } + sort.Strings(columns) + + values := make([]interface{}, 0, len(columns)) + for _, column := range columns { + values = append(values, valuesByColumn[column]) + } + return preparedInsertRow{columns: columns, values: values} +} + +func execParameterizedInsertBatch(config parameterizedInsertConfig, rows []preparedInsertRow) error { + if len(rows) == 0 || len(rows[0].columns) == 0 { + return nil + } + + quotedColumns := make([]string, 0, len(rows[0].columns)) + for _, column := range rows[0].columns { + quotedColumns = append(quotedColumns, config.QuoteColumn(column)) + } + + argIndex := 0 + valueGroups := make([]string, 0, len(rows)) + args := make([]interface{}, 0, len(rows)*len(rows[0].columns)) + for _, row := range rows { + placeholders := make([]string, 0, len(row.columns)) + for idx, column := range row.columns { + argIndex++ + placeholders = append(placeholders, config.Placeholder(argIndex)) + args = append(args, config.Arg(argIndex, column, row.values[idx])) + } + valueGroups = append(valueGroups, "("+strings.Join(placeholders, ", ")+")") + } + + query := fmt.Sprintf("INSERT INTO %s (%s) VALUES %s", + config.Table, + strings.Join(quotedColumns, ", "), + strings.Join(valueGroups, ", "), + ) + res, err := config.Exec(query, args...) + if err != nil { + return fmt.Errorf("插入失败:%v", err) + } + if config.RequireAffected { + if err := requireInsertAffected(res); err != nil { + return err + } + } + return nil +} + +func requireInsertAffected(result sql.Result) error { + if result == nil { + return nil + } + if affected, err := result.RowsAffected(); err == nil && affected == 0 { + return fmt.Errorf("插入未生效:未影响任何行") + } + return nil +} + +func batchInsertRowLimit(columnCount, maxRows, maxArgs int) int { + if maxRows <= 0 { + maxRows = defaultBatchInsertRows + } + if maxArgs <= 0 { + maxArgs = defaultBatchInsertArgs + } + if columnCount <= 0 { + return 1 + } + limitByArgs := maxArgs / columnCount + if limitByArgs < 1 { + return 1 + } + if limitByArgs < maxRows { + return limitByArgs + } + return maxRows +} + +type literalInsertConfig struct { + Table string + Rows []map[string]interface{} + QuoteColumn func(string) string + Literal func(interface{}) string + Exec func(string) (sql.Result, error) + RowSeparator string + MaxRows int + RequireAffected bool +} + +func execLiteralInsertBatches(config literalInsertConfig) error { + if len(config.Rows) == 0 { + return nil + } + if strings.TrimSpace(config.Table) == "" { + return fmt.Errorf("表名不能为空") + } + if config.QuoteColumn == nil { + return fmt.Errorf("列名引用函数不能为空") + } + if config.Literal == nil { + return fmt.Errorf("字面量函数不能为空") + } + if config.Exec == nil { + return fmt.Errorf("执行函数不能为空") + } + if config.RowSeparator == "" { + config.RowSeparator = ", " + } + if config.MaxRows <= 0 { + config.MaxRows = defaultBatchInsertRows + } + + groups, order := groupPreparedInsertRows(config.Rows, func(_ string, value interface{}) (interface{}, bool) { return value, false }) + for _, key := range order { + rows := groups[key] + if len(rows) == 0 || len(rows[0].columns) == 0 { + continue + } + for start := 0; start < len(rows); start += config.MaxRows { + end := start + config.MaxRows + if end > len(rows) { + end = len(rows) + } + if err := execLiteralInsertBatch(config, rows[start:end]); err != nil { + return err + } + } + } + return nil +} + +func execLiteralInsertBatch(config literalInsertConfig, rows []preparedInsertRow) error { + if len(rows) == 0 || len(rows[0].columns) == 0 { + return nil + } + + quotedColumns := make([]string, 0, len(rows[0].columns)) + for _, column := range rows[0].columns { + quotedColumns = append(quotedColumns, config.QuoteColumn(column)) + } + + valueGroups := make([]string, 0, len(rows)) + for _, row := range rows { + values := make([]string, 0, len(row.values)) + for _, value := range row.values { + values = append(values, config.Literal(value)) + } + valueGroups = append(valueGroups, "("+strings.Join(values, ", ")+")") + } + + query := fmt.Sprintf("INSERT INTO %s (%s) VALUES %s", + config.Table, + strings.Join(quotedColumns, ", "), + strings.Join(valueGroups, config.RowSeparator), + ) + res, err := config.Exec(query) + if err != nil { + return fmt.Errorf("插入失败:%v; sql=%s", err, query) + } + if config.RequireAffected { + if err := requireInsertAffected(res); err != nil { + return err + } + } + return nil +} diff --git a/internal/db/batch_insert_test.go b/internal/db/batch_insert_test.go new file mode 100644 index 0000000..44b96b1 --- /dev/null +++ b/internal/db/batch_insert_test.go @@ -0,0 +1,231 @@ +package db + +import ( + "database/sql" + "database/sql/driver" + "fmt" + "strings" + "testing" +) + +func TestExecParameterizedInsertBatchesGroupsRowsByColumnSet(t *testing.T) { + t.Parallel() + + var queries []string + var args [][]interface{} + err := execParameterizedInsertBatches(parameterizedInsertConfig{ + Table: "\"users\"", + Rows: []map[string]interface{}{ + {"id": 1, "name": "Alice"}, + {"name": "Bob", "id": 2}, + {"id": 3}, + }, + QuoteColumn: func(column string) string { return `"` + column + `"` }, + Placeholder: func(idx int) string { + return fmt.Sprintf("$%d", idx) + }, + Exec: func(query string, values ...interface{}) (sql.Result, error) { + queries = append(queries, query) + args = append(args, append([]interface{}(nil), values...)) + return driver.RowsAffected(1), nil + }, + }) + if err != nil { + t.Fatalf("execParameterizedInsertBatches() error = %v", err) + } + + if len(queries) != 2 { + t.Fatalf("expected 2 insert statements, got %d: %v", len(queries), queries) + } + if queries[0] != `INSERT INTO "users" ("id", "name") VALUES ($1, $2), ($3, $4)` { + t.Fatalf("unexpected first query: %s", queries[0]) + } + if queries[1] != `INSERT INTO "users" ("id") VALUES ($1)` { + t.Fatalf("unexpected second query: %s", queries[1]) + } + if got := fmt.Sprint(args[0]); got != "[1 Alice 2 Bob]" { + t.Fatalf("unexpected first args: %s", got) + } +} + +func TestExecParameterizedInsertBatchesUsesNamedSQLServerArgs(t *testing.T) { + t.Parallel() + + var query string + var args []interface{} + err := execParameterizedInsertBatches(parameterizedInsertConfig{ + Table: "[dbo].[users]", + Rows: []map[string]interface{}{{"id": 1, "name": "Alice"}, {"id": 2, "name": "Bob"}}, + QuoteColumn: func(column string) string { return "[" + column + "]" }, + Placeholder: func(idx int) string { + return fmt.Sprintf("@p%d", idx) + }, + Arg: func(idx int, _ string, value interface{}) interface{} { + return sql.Named(fmt.Sprintf("p%d", idx), value) + }, + Exec: func(q string, values ...interface{}) (sql.Result, error) { + query = q + args = append([]interface{}(nil), values...) + return driver.RowsAffected(1), nil + }, + }) + if err != nil { + t.Fatalf("execParameterizedInsertBatches() error = %v", err) + } + + if query != `INSERT INTO [dbo].[users] ([id], [name]) VALUES (@p1, @p2), (@p3, @p4)` { + t.Fatalf("unexpected query: %s", query) + } + if len(args) != 4 { + t.Fatalf("expected 4 args, got %d", len(args)) + } + first, ok := args[0].(sql.NamedArg) + if !ok || first.Name != "p1" || first.Value != 1 { + t.Fatalf("unexpected first arg: %#v", args[0]) + } +} + +func TestExecLiteralInsertBatchesBuildsMultiRowValues(t *testing.T) { + t.Parallel() + + var query string + err := execLiteralInsertBatches(literalInsertConfig{ + Table: "`metrics`", + Rows: []map[string]interface{}{{"ts": 1, "value": "a"}, {"ts": 2, "value": "b"}}, + QuoteColumn: func(column string) string { return "`" + column + "`" }, + Literal: func(value interface{}) string { + return fmt.Sprintf("'%v'", value) + }, + Exec: func(q string) (sql.Result, error) { + query = q + return driver.RowsAffected(1), nil + }, + }) + if err != nil { + t.Fatalf("execLiteralInsertBatches() error = %v", err) + } + + if query != "INSERT INTO `metrics` (`ts`, `value`) VALUES ('1', 'a'), ('2', 'b')" { + t.Fatalf("unexpected query: %s", query) + } +} + +func TestBatchInsertRowLimitRespectsArgumentLimit(t *testing.T) { + t.Parallel() + + if got := batchInsertRowLimit(2, 1000, 60000); got != 1000 { + t.Fatalf("2 columns limit=%d, want 1000", got) + } + if got := batchInsertRowLimit(100, 1000, 60000); got != 600 { + t.Fatalf("100 columns limit=%d, want 600", got) + } + if got := batchInsertRowLimit(70000, 1000, 60000); got != 1 { + t.Fatalf("wide table limit=%d, want 1", got) + } +} + +func TestExecParameterizedInsertBatchesSplitsByArgumentLimit(t *testing.T) { + t.Parallel() + + var queries []string + rows := []map[string]interface{}{ + {"a": 1, "b": 2}, + {"a": 3, "b": 4}, + {"a": 5, "b": 6}, + } + err := execParameterizedInsertBatches(parameterizedInsertConfig{ + Table: "`t`", + Rows: rows, + QuoteColumn: func(column string) string { return "`" + column + "`" }, + Placeholder: func(int) string { + return "?" + }, + Exec: func(query string, _ ...interface{}) (sql.Result, error) { + queries = append(queries, query) + return driver.RowsAffected(1), nil + }, + MaxRows: 1000, + MaxArgs: 4, + }) + if err != nil { + t.Fatalf("execParameterizedInsertBatches() error = %v", err) + } + if len(queries) != 2 { + t.Fatalf("expected 2 queries, got %d: %v", len(queries), queries) + } + if strings.Count(queries[0], "(?, ?)") != 2 || strings.Count(queries[1], "(?, ?)") != 1 { + t.Fatalf("unexpected split queries: %v", queries) + } +} + +func TestExecParameterizedInsertBatchesOmitsColumnsPerRow(t *testing.T) { + t.Parallel() + + var queries []string + err := execParameterizedInsertBatches(parameterizedInsertConfig{ + Table: "`events`", + Rows: []map[string]interface{}{ + {"id": 1, "created_at": ""}, + {"id": 2, "created_at": "2026-05-25 10:00:00"}, + }, + QuoteColumn: func(column string) string { return "`" + column + "`" }, + Placeholder: func(int) string { + return "?" + }, + Value: func(column string, value interface{}) (interface{}, bool) { + return value, column == "created_at" && value == "" + }, + Exec: func(query string, _ ...interface{}) (sql.Result, error) { + queries = append(queries, query) + return driver.RowsAffected(1), nil + }, + }) + if err != nil { + t.Fatalf("execParameterizedInsertBatches() error = %v", err) + } + + if len(queries) != 2 { + t.Fatalf("expected rows with different effective columns to split into 2 statements, got %d: %v", len(queries), queries) + } + if queries[0] != "INSERT INTO `events` (`id`) VALUES (?)" { + t.Fatalf("unexpected omitted-column query: %s", queries[0]) + } + if queries[1] != "INSERT INTO `events` (`created_at`, `id`) VALUES (?, ?)" { + t.Fatalf("unexpected full-column query: %s", queries[1]) + } +} + +func TestExecParameterizedInsertBatchesRunsEmptyInsertSQLWhenAllColumnsOmitted(t *testing.T) { + t.Parallel() + + var queries []string + err := execParameterizedInsertBatches(parameterizedInsertConfig{ + Table: "`events`", + Rows: []map[string]interface{}{{"created_at": ""}, {"created_at": ""}}, + QuoteColumn: func(column string) string { return "`" + column + "`" }, + Placeholder: func(int) string { return "?" }, + Value: func(_ string, value interface{}) (interface{}, bool) { + return value, true + }, + EmptyInsertSQL: func(table string) string { + return fmt.Sprintf("INSERT INTO %s () VALUES ()", table) + }, + Exec: func(query string, _ ...interface{}) (sql.Result, error) { + queries = append(queries, query) + return driver.RowsAffected(1), nil + }, + RequireAffected: true, + }) + if err != nil { + t.Fatalf("execParameterizedInsertBatches() error = %v", err) + } + + if len(queries) != 2 { + t.Fatalf("expected 2 empty insert statements, got %d: %v", len(queries), queries) + } + for _, query := range queries { + if query != "INSERT INTO `events` () VALUES ()" { + t.Fatalf("unexpected empty insert query: %s", query) + } + } +} diff --git a/internal/db/clickhouse_impl.go b/internal/db/clickhouse_impl.go index 2174346..ff12e72 100644 --- a/internal/db/clickhouse_impl.go +++ b/internal/db/clickhouse_impl.go @@ -1225,21 +1225,27 @@ func (c *ClickHouseDB) ApplyChanges(tableName string, changes connection.ChangeS } } - for _, row := range changes.Inserts { - query, err := buildClickHouseInsertSQL(qualifiedTable, row) - if err != nil { - return err - } - if query == "" { - continue - } - if _, err := c.conn.Exec(query); err != nil { - return fmt.Errorf("插入失败:%v; sql=%s", err, query) - } + if err := execClickHouseInsertBatches(c.conn, qualifiedTable, changes.Inserts); err != nil { + return err } return nil } +func execClickHouseInsertBatches(conn *sql.DB, qualifiedTable string, rows []map[string]interface{}) error { + if conn == nil { + return fmt.Errorf("连接未打开") + } + return execLiteralInsertBatches(literalInsertConfig{ + Table: qualifiedTable, + Rows: rows, + QuoteColumn: quoteClickHouseIdentifier, + Literal: clickHouseLiteral, + Exec: func(query string) (sql.Result, error) { + return conn.Exec(query) + }, + }) +} + func buildClickHouseInsertSQL(qualifiedTable string, row map[string]interface{}) (string, error) { if len(row) == 0 { return "", nil diff --git a/internal/db/custom_impl.go b/internal/db/custom_impl.go index 58badeb..20b2eed 100644 --- a/internal/db/custom_impl.go +++ b/internal/db/custom_impl.go @@ -324,6 +324,8 @@ func (c *CustomDB) ApplyChanges(tableName string, changes connection.ChangeSet) isKingbase := strings.Contains(driver, "kingbase") isPostgres := strings.Contains(driver, "postgres") || isKingbase || strings.Contains(driver, "pg") isOracle := strings.Contains(driver, "oracle") || strings.Contains(driver, "ora") || strings.Contains(driver, "dm") || strings.Contains(driver, "dameng") + isSQLServer := strings.Contains(driver, "sqlserver") || strings.Contains(driver, "mssql") + isSQLite := strings.Contains(driver, "sqlite") || strings.Contains(driver, "duckdb") quoteIdent := func(name string) string { n := strings.TrimSpace(name) @@ -425,33 +427,33 @@ func (c *CustomDB) ApplyChanges(tableName string, changes connection.ChangeSet) } } - // 3. Inserts - for _, row := range changes.Inserts { - var cols []string - var placeholders []string - var args []interface{} - idx := 0 - - for k, v := range row { - idx++ - cols = append(cols, quoteIdent(k)) - placeholders = append(placeholders, placeholder(idx)) - args = append(args, v) - } - - if len(cols) == 0 { - continue - } - - query := fmt.Sprintf("INSERT INTO %s (%s) VALUES (%s)", qualifiedTable, strings.Join(cols, ", "), strings.Join(placeholders, ", ")) - if _, err := tx.Exec(query, args...); err != nil { - return fmt.Errorf("插入失败:%v", err) - } + if err := execParameterizedInsertBatches(parameterizedInsertConfig{ + Table: qualifiedTable, + Rows: changes.Inserts, + QuoteColumn: quoteIdent, + Placeholder: placeholder, + Exec: func(query string, args ...interface{}) (sql.Result, error) { + return tx.Exec(query, args...) + }, + MaxArgs: customInsertMaxArgs(isSQLServer, isSQLite), + }); err != nil { + return err } return tx.Commit() } +func customInsertMaxArgs(isSQLServer, isSQLite bool) int { + switch { + case isSQLServer: + return sqlServerBatchInsertArgs + case isSQLite: + return sqliteBatchInsertArgs + default: + return 0 + } +} + func (c *CustomDB) GetAllColumns(dbName string) ([]connection.ColumnDefinitionWithTable, error) { return nil, fmt.Errorf("not implemented for custom") } diff --git a/internal/db/duckdb_impl.go b/internal/db/duckdb_impl.go index 151c3e0..e0e36b9 100644 --- a/internal/db/duckdb_impl.go +++ b/internal/db/duckdb_impl.go @@ -398,24 +398,17 @@ func (d *DuckDB) ApplyChanges(tableName string, changes connection.ChangeSet) er } } - for _, row := range changes.Inserts { - var cols []string - var placeholders []string - var args []interface{} - - for k, v := range row { - cols = append(cols, quoteIdent(k)) - placeholders = append(placeholders, "?") - args = append(args, v) - } - if len(cols) == 0 { - continue - } - - query := fmt.Sprintf("INSERT INTO %s (%s) VALUES (%s)", qualifiedTable, strings.Join(cols, ", "), strings.Join(placeholders, ", ")) - if _, err := tx.Exec(query, args...); err != nil { - return fmt.Errorf("插入失败:%v", err) - } + if err := execParameterizedInsertBatches(parameterizedInsertConfig{ + Table: qualifiedTable, + Rows: changes.Inserts, + QuoteColumn: quoteIdent, + Placeholder: func(int) string { return "?" }, + Exec: func(query string, args ...interface{}) (sql.Result, error) { + return tx.Exec(query, args...) + }, + MaxArgs: sqliteBatchInsertArgs, + }); err != nil { + return err } return tx.Commit() diff --git a/internal/db/highgo_impl.go b/internal/db/highgo_impl.go index 7dd6145..492d389 100644 --- a/internal/db/highgo_impl.go +++ b/internal/db/highgo_impl.go @@ -649,28 +649,18 @@ func (h *HighGoDB) ApplyChanges(tableName string, changes connection.ChangeSet) } } - // 3. Inserts - for _, row := range changes.Inserts { - var cols []string - var placeholders []string - var args []interface{} - idx := 0 - - for k, v := range row { - idx++ - cols = append(cols, quoteIdent(k)) - placeholders = append(placeholders, fmt.Sprintf("$%d", idx)) - args = append(args, v) - } - - if len(cols) == 0 { - continue - } - - query := fmt.Sprintf("INSERT INTO %s (%s) VALUES (%s)", qualifiedTable, strings.Join(cols, ", "), strings.Join(placeholders, ", ")) - if _, err := tx.Exec(query, args...); err != nil { - return fmt.Errorf("插入失败:%v", err) - } + if err := execParameterizedInsertBatches(parameterizedInsertConfig{ + Table: qualifiedTable, + Rows: changes.Inserts, + QuoteColumn: quoteIdent, + Placeholder: func(idx int) string { + return fmt.Sprintf("$%d", idx) + }, + Exec: func(query string, args ...interface{}) (sql.Result, error) { + return tx.Exec(query, args...) + }, + }); err != nil { + return err } return tx.Commit() diff --git a/internal/db/kingbase_impl.go b/internal/db/kingbase_impl.go index 3addd15..6f508ad 100644 --- a/internal/db/kingbase_impl.go +++ b/internal/db/kingbase_impl.go @@ -917,28 +917,18 @@ func (k *KingbaseDB) ApplyChanges(tableName string, changes connection.ChangeSet } } - // 3. Inserts - for _, row := range changes.Inserts { - var cols []string - var placeholders []string - var args []interface{} - idx := 0 - - for k, v := range row { - idx++ - cols = append(cols, quoteKingbaseIdent(k)) - placeholders = append(placeholders, fmt.Sprintf("$%d", idx)) - args = append(args, v) - } - - if len(cols) == 0 { - continue - } - - query := fmt.Sprintf("INSERT INTO %s (%s) VALUES (%s)", qualifiedTable, strings.Join(cols, ", "), strings.Join(placeholders, ", ")) - if _, err := tx.Exec(query, args...); err != nil { - return fmt.Errorf("插入失败:%v; sql=%s", err, query) - } + if err := execParameterizedInsertBatches(parameterizedInsertConfig{ + Table: qualifiedTable, + Rows: changes.Inserts, + QuoteColumn: quoteKingbaseIdent, + Placeholder: func(idx int) string { + return fmt.Sprintf("$%d", idx) + }, + Exec: func(query string, args ...interface{}) (sql.Result, error) { + return tx.Exec(query, args...) + }, + }); err != nil { + return err } return tx.Commit() diff --git a/internal/db/mariadb_impl.go b/internal/db/mariadb_impl.go index e9df19b..410cbbe 100644 --- a/internal/db/mariadb_impl.go +++ b/internal/db/mariadb_impl.go @@ -419,26 +419,23 @@ func (m *MariaDB) ApplyChanges(tableName string, changes connection.ChangeSet) e } } - // 3. Inserts - for _, row := range changes.Inserts { - var cols []string - var placeholders []string - var args []interface{} - - for k, v := range row { - cols = append(cols, fmt.Sprintf("`%s`", k)) - placeholders = append(placeholders, "?") - args = append(args, normalizeMySQLComplexValue(normalizeMySQLDateTimeValue(v))) - } - - if len(cols) == 0 { - continue - } - - query := fmt.Sprintf("INSERT INTO `%s` (%s) VALUES (%s)", tableName, strings.Join(cols, ", "), strings.Join(placeholders, ", ")) - if _, err := tx.Exec(query, args...); err != nil { - return fmt.Errorf("插入失败:%v", err) - } + if err := execParameterizedInsertBatches(parameterizedInsertConfig{ + Table: fmt.Sprintf("`%s`", escapeMySQLBacktickIdent(tableName)), + Rows: changes.Inserts, + QuoteColumn: func(column string) string { + return fmt.Sprintf("`%s`", escapeMySQLBacktickIdent(column)) + }, + Placeholder: func(int) string { return "?" }, + Value: func(_ string, value interface{}) (interface{}, bool) { + return normalizeMySQLComplexValue(normalizeMySQLDateTimeValue(value)), false + }, + Exec: func(query string, args ...interface{}) (sql.Result, error) { + return tx.Exec(query, args...) + }, + MaxRows: defaultMySQLInsertBatchSize, + MaxArgs: maxMySQLInsertBatchArgs, + }); err != nil { + return err } return tx.Commit() diff --git a/internal/db/mongodb_impl.go b/internal/db/mongodb_impl.go index 82daf6c..e6e2e19 100644 --- a/internal/db/mongodb_impl.go +++ b/internal/db/mongodb_impl.go @@ -1314,15 +1314,37 @@ func (m *MongoDB) ApplyChanges(tableName string, changes connection.ChangeSet) e } } - // Process inserts - for _, row := range changes.Inserts { - doc := copyMongoChangeDocument(row) - if len(doc) > 0 { - if _, err := collection.InsertOne(ctx, doc); err != nil { - return fmt.Errorf("插入失败:%v", err) - } - } + if err := insertMongoDocuments(ctx, collection, changes.Inserts); err != nil { + return err } return nil } + +func insertMongoDocuments(ctx context.Context, collection *mongo.Collection, rows []map[string]interface{}) error { + if len(rows) == 0 { + return nil + } + + docs := make([]interface{}, 0, len(rows)) + for _, row := range rows { + doc := copyMongoChangeDocument(row) + if len(doc) > 0 { + docs = append(docs, doc) + } + } + if len(docs) == 0 { + return nil + } + + for start := 0; start < len(docs); start += defaultBatchInsertRows { + end := start + defaultBatchInsertRows + if end > len(docs) { + end = len(docs) + } + if _, err := collection.InsertMany(ctx, docs[start:end]); err != nil { + return fmt.Errorf("插入失败:%v", err) + } + } + return nil +} diff --git a/internal/db/mongodb_impl_v1.go b/internal/db/mongodb_impl_v1.go index 2b610b0..ed49ed6 100644 --- a/internal/db/mongodb_impl_v1.go +++ b/internal/db/mongodb_impl_v1.go @@ -1317,15 +1317,37 @@ func (m *MongoDBV1) ApplyChanges(tableName string, changes connection.ChangeSet) } } - // Process inserts - for _, row := range changes.Inserts { - doc := copyMongoChangeDocument(row) - if len(doc) > 0 { - if _, err := collection.InsertOne(ctx, doc); err != nil { - return fmt.Errorf("插入失败:%v", err) - } - } + if err := insertMongoV1Documents(ctx, collection, changes.Inserts); err != nil { + return err } return nil } + +func insertMongoV1Documents(ctx context.Context, collection *mongo.Collection, rows []map[string]interface{}) error { + if len(rows) == 0 { + return nil + } + + docs := make([]interface{}, 0, len(rows)) + for _, row := range rows { + doc := copyMongoChangeDocument(row) + if len(doc) > 0 { + docs = append(docs, doc) + } + } + if len(docs) == 0 { + return nil + } + + for start := 0; start < len(docs); start += defaultBatchInsertRows { + end := start + defaultBatchInsertRows + if end > len(docs) { + end = len(docs) + } + if _, err := collection.InsertMany(ctx, docs[start:end]); err != nil { + return fmt.Errorf("插入失败:%v", err) + } + } + return nil +} diff --git a/internal/db/mysql_impl.go b/internal/db/mysql_impl.go index d859ebd..ccbca70 100644 --- a/internal/db/mysql_impl.go +++ b/internal/db/mysql_impl.go @@ -25,7 +25,11 @@ type MySQLDB struct { pingTimeout time.Duration } -const defaultMySQLPort = 3306 +const ( + defaultMySQLPort = 3306 + defaultMySQLInsertBatchSize = 1000 + maxMySQLInsertBatchArgs = 60000 +) func parseMySQLCompatibleURI(raw string, allowedSchemes ...string) (*url.URL, bool) { return parseConnectionURI(raw, allowedSchemes...) @@ -1020,47 +1024,40 @@ func (m *MySQLDB) ApplyChanges(tableName string, changes connection.ChangeSet) e } } - // 3. Inserts - for _, row := range changes.Inserts { - var cols []string - var placeholders []string - var args []interface{} - - for k, v := range row { - normalizedValue, omit := normalizeMySQLValueForInsert(k, v, columnTypeMap) - if omit { - continue - } - cols = append(cols, fmt.Sprintf("`%s`", k)) - placeholders = append(placeholders, "?") - args = append(args, normalizedValue) - } - - if len(cols) == 0 { - query := fmt.Sprintf("INSERT INTO `%s` () VALUES ()", tableName) - res, err := tx.Exec(query) - if err != nil { - return fmt.Errorf("插入失败:%v", err) - } - if affected, err := res.RowsAffected(); err == nil && affected == 0 { - return fmt.Errorf("插入未生效:未影响任何行") - } - continue - } - - query := fmt.Sprintf("INSERT INTO `%s` (%s) VALUES (%s)", tableName, strings.Join(cols, ", "), strings.Join(placeholders, ", ")) - res, err := tx.Exec(query, args...) - if err != nil { - return fmt.Errorf("插入失败:%v", err) - } - if affected, err := res.RowsAffected(); err == nil && affected == 0 { - return fmt.Errorf("插入未生效:未影响任何行") - } + if err := m.applyInsertChanges(tx, tableName, changes.Inserts, columnTypeMap); err != nil { + return err } return tx.Commit() } +func (m *MySQLDB) applyInsertChanges(tx *sql.Tx, tableName string, rows []map[string]interface{}, columnTypeMap map[string]string) error { + return execParameterizedInsertBatches(parameterizedInsertConfig{ + Table: fmt.Sprintf("`%s`", escapeMySQLBacktickIdent(tableName)), + Rows: rows, + QuoteColumn: func(column string) string { + return fmt.Sprintf("`%s`", escapeMySQLBacktickIdent(column)) + }, + Placeholder: func(int) string { return "?" }, + Value: func(column string, value interface{}) (interface{}, bool) { + return normalizeMySQLValueForInsert(column, value, columnTypeMap) + }, + Exec: func(query string, args ...interface{}) (sql.Result, error) { + return tx.Exec(query, args...) + }, + MaxRows: defaultMySQLInsertBatchSize, + MaxArgs: maxMySQLInsertBatchArgs, + RequireAffected: true, + EmptyInsertSQL: func(table string) string { + return fmt.Sprintf("INSERT INTO %s () VALUES ()", table) + }, + }) +} + +func escapeMySQLBacktickIdent(ident string) string { + return strings.ReplaceAll(strings.TrimSpace(ident), "`", "``") +} + func normalizeMySQLComplexValue(value interface{}) interface{} { switch v := value.(type) { case map[string]interface{}, []interface{}: diff --git a/internal/db/oracle_applychanges_test.go b/internal/db/oracle_applychanges_test.go index 23891bc..0f6ac58 100644 --- a/internal/db/oracle_applychanges_test.go +++ b/internal/db/oracle_applychanges_test.go @@ -366,6 +366,59 @@ func TestMySQLApplyChangesReturnsErrorWhenUpdateAffectsMultipleRows(t *testing.T } } +func TestMySQLApplyChangesBatchesLargeInsertRows(t *testing.T) { + t.Parallel() + + dbConn, state := openOracleRecordingDB(t) + state.rowsAffected = 1000 + mysqlDB := &MySQLDB{conn: dbConn} + + rows := make([]map[string]interface{}, 1201) + for i := range rows { + rows[i] = map[string]interface{}{ + "id": i + 1, + "name": fmt.Sprintf("name-%d", i+1), + } + } + + if err := mysqlDB.ApplyChanges("users", connection.ChangeSet{Inserts: rows}); err != nil { + t.Fatalf("ApplyChanges() unexpected error: %v", err) + } + + executions := state.snapshotExecQueries() + if len(executions) != 2 { + t.Fatalf("期望 1201 行插入拆成 2 条批量 INSERT,实际 %d 条:%v", len(executions), executions) + } + for _, query := range executions { + if !strings.HasPrefix(query, "INSERT INTO `users` (`id`, `name`) VALUES ") { + t.Fatalf("批量 INSERT 语句格式不正确: %s", query) + } + if got := strings.Count(query, "(?, ?)"); got == 0 || got > defaultMySQLInsertBatchSize { + t.Fatalf("批量 INSERT values 数量异常,got=%d query=%s", got, query) + } + } + if got := strings.Count(executions[0], "(?, ?)"); got != defaultMySQLInsertBatchSize { + t.Fatalf("第一批 values=%d, want %d", got, defaultMySQLInsertBatchSize) + } + if got := strings.Count(executions[1], "(?, ?)"); got != 201 { + t.Fatalf("第二批 values=%d, want 201", got) + } +} + +func TestMySQLInsertBatchSizeRespectsArgumentLimit(t *testing.T) { + t.Parallel() + + if got := batchInsertRowLimit(2, defaultMySQLInsertBatchSize, maxMySQLInsertBatchArgs); got != defaultMySQLInsertBatchSize { + t.Fatalf("2 列批大小=%d, want %d", got, defaultMySQLInsertBatchSize) + } + if got := batchInsertRowLimit(100, defaultMySQLInsertBatchSize, maxMySQLInsertBatchArgs); got != 600 { + t.Fatalf("100 列批大小=%d, want 600", got) + } + if got := batchInsertRowLimit(70000, defaultMySQLInsertBatchSize, maxMySQLInsertBatchArgs); got != 1 { + t.Fatalf("超宽表批大小=%d, want 1", got) + } +} + func TestPostgresApplyChangesReturnsErrorWhenDeleteAffectsMultipleRows(t *testing.T) { t.Parallel() diff --git a/internal/db/postgres_impl.go b/internal/db/postgres_impl.go index ad418e4..fefce41 100644 --- a/internal/db/postgres_impl.go +++ b/internal/db/postgres_impl.go @@ -864,28 +864,18 @@ func (p *PostgresDB) ApplyChanges(tableName string, changes connection.ChangeSet } } - // 3. Inserts - for _, row := range changes.Inserts { - var cols []string - var placeholders []string - var args []interface{} - idx := 0 - - for k, v := range row { - idx++ - cols = append(cols, quoteIdent(k)) - placeholders = append(placeholders, fmt.Sprintf("$%d", idx)) - args = append(args, v) - } - - if len(cols) == 0 { - continue - } - - query := fmt.Sprintf("INSERT INTO %s (%s) VALUES (%s)", qualifiedTable, strings.Join(cols, ", "), strings.Join(placeholders, ", ")) - if _, err := tx.Exec(query, args...); err != nil { - return fmt.Errorf("插入失败:%v", err) - } + if err := execParameterizedInsertBatches(parameterizedInsertConfig{ + Table: qualifiedTable, + Rows: changes.Inserts, + QuoteColumn: quoteIdent, + Placeholder: func(idx int) string { + return fmt.Sprintf("$%d", idx) + }, + Exec: func(query string, args ...interface{}) (sql.Result, error) { + return tx.Exec(query, args...) + }, + }); err != nil { + return err } return tx.Commit() diff --git a/internal/db/sqlite_impl.go b/internal/db/sqlite_impl.go index 55ef3aa..acfb53d 100644 --- a/internal/db/sqlite_impl.go +++ b/internal/db/sqlite_impl.go @@ -690,26 +690,17 @@ func (s *SQLiteDB) ApplyChanges(tableName string, changes connection.ChangeSet) } } - // 3. Inserts - for _, row := range changes.Inserts { - var cols []string - var placeholders []string - var args []interface{} - - for k, v := range row { - cols = append(cols, quoteIdent(k)) - placeholders = append(placeholders, "?") - args = append(args, v) - } - - if len(cols) == 0 { - continue - } - - query := fmt.Sprintf("INSERT INTO %s (%s) VALUES (%s)", qualifiedTable, strings.Join(cols, ", "), strings.Join(placeholders, ", ")) - if _, err := tx.Exec(query, args...); err != nil { - return fmt.Errorf("插入失败:%v", err) - } + if err := execParameterizedInsertBatches(parameterizedInsertConfig{ + Table: qualifiedTable, + Rows: changes.Inserts, + QuoteColumn: quoteIdent, + Placeholder: func(int) string { return "?" }, + Exec: func(query string, args ...interface{}) (sql.Result, error) { + return tx.Exec(query, args...) + }, + MaxArgs: sqliteBatchInsertArgs, + }); err != nil { + return err } return tx.Commit() diff --git a/internal/db/sqlserver_impl.go b/internal/db/sqlserver_impl.go index 6ed6389..ffa927a 100644 --- a/internal/db/sqlserver_impl.go +++ b/internal/db/sqlserver_impl.go @@ -664,28 +664,22 @@ func (s *SqlServerDB) ApplyChanges(tableName string, changes connection.ChangeSe } } - // 3. Inserts - for _, row := range changes.Inserts { - var cols []string - var placeholders []string - var args []interface{} - idx := 0 - - for k, v := range row { - idx++ - cols = append(cols, quoteIdent(k)) - placeholders = append(placeholders, fmt.Sprintf("@p%d", idx)) - args = append(args, sql.Named(fmt.Sprintf("p%d", idx), v)) - } - - if len(cols) == 0 { - continue - } - - query := fmt.Sprintf("INSERT INTO %s (%s) VALUES (%s)", qualifiedTable, strings.Join(cols, ", "), strings.Join(placeholders, ", ")) - if _, err := tx.Exec(query, args...); err != nil { - return fmt.Errorf("插入失败:%v", err) - } + if err := execParameterizedInsertBatches(parameterizedInsertConfig{ + Table: qualifiedTable, + Rows: changes.Inserts, + QuoteColumn: quoteIdent, + Placeholder: func(idx int) string { + return fmt.Sprintf("@p%d", idx) + }, + Arg: func(idx int, _ string, value interface{}) interface{} { + return sql.Named(fmt.Sprintf("p%d", idx), value) + }, + Exec: func(query string, args ...interface{}) (sql.Result, error) { + return tx.Exec(query, args...) + }, + MaxArgs: sqlServerBatchInsertArgs, + }); err != nil { + return err } return tx.Commit() diff --git a/internal/db/tdengine_impl.go b/internal/db/tdengine_impl.go index 059f601..785d307 100644 --- a/internal/db/tdengine_impl.go +++ b/internal/db/tdengine_impl.go @@ -406,19 +406,24 @@ func (t *TDengineDB) ApplyChanges(tableName string, changes connection.ChangeSet } qualifiedTable := quoteTDengineTable("", tableName) - for _, row := range changes.Inserts { - query, err := buildTDengineInsertSQL(qualifiedTable, row) - if err != nil { - return err - } - if query == "" { - continue - } - if _, err := t.conn.Exec(query); err != nil { - return fmt.Errorf("插入失败:%v; sql=%s", err, query) - } + return execTDengineInsertBatches(t.conn, qualifiedTable, changes.Inserts) +} + +func execTDengineInsertBatches(conn *sql.DB, qualifiedTable string, rows []map[string]interface{}) error { + if conn == nil { + return fmt.Errorf("连接未打开") } - return nil + return execLiteralInsertBatches(literalInsertConfig{ + Table: qualifiedTable, + Rows: rows, + QuoteColumn: func(column string) string { + return fmt.Sprintf("`%s`", escapeBacktickIdent(column)) + }, + Literal: tdengineLiteral, + Exec: func(query string) (sql.Result, error) { + return conn.Exec(query) + }, + }) } func buildTDengineInsertSQL(qualifiedTable string, row map[string]interface{}) (string, error) { diff --git a/internal/db/vastbase_impl.go b/internal/db/vastbase_impl.go index 5db6f2d..7bb5521 100644 --- a/internal/db/vastbase_impl.go +++ b/internal/db/vastbase_impl.go @@ -648,28 +648,18 @@ func (v *VastbaseDB) ApplyChanges(tableName string, changes connection.ChangeSet } } - // 3. Inserts - for _, row := range changes.Inserts { - var cols []string - var placeholders []string - var args []interface{} - idx := 0 - - for k, val := range row { - idx++ - cols = append(cols, quoteIdent(k)) - placeholders = append(placeholders, fmt.Sprintf("$%d", idx)) - args = append(args, val) - } - - if len(cols) == 0 { - continue - } - - query := fmt.Sprintf("INSERT INTO %s (%s) VALUES (%s)", qualifiedTable, strings.Join(cols, ", "), strings.Join(placeholders, ", ")) - if _, err := tx.Exec(query, args...); err != nil { - return fmt.Errorf("插入失败:%v", err) - } + if err := execParameterizedInsertBatches(parameterizedInsertConfig{ + Table: qualifiedTable, + Rows: changes.Inserts, + QuoteColumn: quoteIdent, + Placeholder: func(idx int) string { + return fmt.Sprintf("$%d", idx) + }, + Exec: func(query string, args ...interface{}) (sql.Result, error) { + return tx.Exec(query, args...) + }, + }); err != nil { + return err } return tx.Commit() diff --git a/internal/sync/analyze.go b/internal/sync/analyze.go index 4bd62b0..acb9c3a 100644 --- a/internal/sync/analyze.go +++ b/internal/sync/analyze.go @@ -101,7 +101,7 @@ func (s *SyncEngine) Analyze(config SyncConfig) SyncAnalyzeResult { HasSchema: syncSchema, } - plan, cols, _, err := buildSchemaMigrationPlan(config, tableName, sourceDB, targetDB) + plan, cols, targetCols, err := buildSchemaMigrationPlan(config, tableName, sourceDB, targetDB) if err != nil { summary.Message = err.Error() result.Tables = append(result.Tables, summary) @@ -140,16 +140,27 @@ func (s *SyncEngine) Analyze(config SyncConfig) SyncAnalyzeResult { } } - sourceRows, _, err := sourceDB.Query(fmt.Sprintf("SELECT * FROM %s", quoteQualifiedIdentByType(config.SourceConfig.Type, plan.SourceQueryTable))) + sourceType := resolveMigrationDBType(config.SourceConfig) + targetType := resolveMigrationDBType(config.TargetConfig) + sourceCount, counted, err := countTableRowsForSync(sourceDB, sourceType, plan.SourceQueryTable) if err != nil { summary.Message = "读取源表失败: " + err.Error() result.Tables = append(result.Tables, summary) return } + if !counted { + sourceRows, _, err := sourceDB.Query(fmt.Sprintf("SELECT * FROM %s", quoteQualifiedIdentByType(sourceType, plan.SourceQueryTable))) + if err != nil { + summary.Message = "读取源表失败: " + err.Error() + result.Tables = append(result.Tables, summary) + return + } + sourceCount = len(sourceRows) + } if !plan.TargetTableExists && plan.AutoCreate { summary.CanSync = true - summary.Inserts = len(sourceRows) + summary.Inserts = sourceCount summary.Message = firstNonEmpty(plan.PlannedAction, "目标表不存在,执行时将自动建表并导入全部源数据") result.Tables = append(result.Tables, summary) return @@ -157,7 +168,7 @@ func (s *SyncEngine) Analyze(config SyncConfig) SyncAnalyzeResult { if tableMode != "insert_update" { summary.CanSync = true - summary.Inserts = len(sourceRows) + summary.Inserts = sourceCount summary.Message = firstNonEmpty(plan.PlannedAction, "当前模式无需差异对比,将按源表数据执行导入") result.Tables = append(result.Tables, summary) return @@ -175,6 +186,32 @@ func (s *SyncEngine) Analyze(config SyncConfig) SyncAnalyzeResult { } summary.PKColumn = pkCols[0] + targetColSet := buildTargetColumnSet(targetCols) + handled, counts, scanErr := scanTableDiffInPages(sourceDB, targetDB, sourceType, targetType, plan, cols, targetCols, summary.PKColumn, targetColSet, true, nil) + if handled { + if scanErr != nil { + summary.Message = scanErr.Error() + result.Tables = append(result.Tables, summary) + return + } + summary.CanSync = true + summary.Inserts = counts.Inserts + summary.Updates = counts.Updates + summary.Deletes = counts.Deletes + summary.Same = counts.Same + if strings.TrimSpace(summary.Message) == "" { + summary.Message = firstNonEmpty(plan.PlannedAction, "差异分析完成") + } + result.Tables = append(result.Tables, summary) + return + } + + sourceRows, _, err := sourceDB.Query(fmt.Sprintf("SELECT * FROM %s", quoteQualifiedIdentByType(sourceType, plan.SourceQueryTable))) + if err != nil { + summary.Message = "读取源表失败: " + err.Error() + result.Tables = append(result.Tables, summary) + return + } targetRows, _, err := targetDB.Query(fmt.Sprintf("SELECT * FROM %s", quoteQualifiedIdentByType(config.TargetConfig.Type, plan.TargetQueryTable))) if err != nil { summary.Message = "读取目标表失败: " + err.Error() diff --git a/internal/sync/diff_paging.go b/internal/sync/diff_paging.go new file mode 100644 index 0000000..8c319cc --- /dev/null +++ b/internal/sync/diff_paging.go @@ -0,0 +1,587 @@ +package sync + +import ( + "GoNavi-Wails/internal/connection" + "GoNavi-Wails/internal/db" + "fmt" + "reflect" + "sort" + "strconv" + "strings" + "time" +) + +type pagedDiffCounts struct { + Inserts int + Updates int + Deletes int + Same int +} + +type pagedUpdateDiff struct { + UpdateRow connection.UpdateRow + Source map[string]interface{} + Target map[string]interface{} + ChangedColumns []string +} + +type pagedDiffPage struct { + Inserts []map[string]interface{} + Updates []pagedUpdateDiff + Deletes []map[string]interface{} + Same int +} + +func (s *SyncEngine) tryApplyDiffInPages(config SyncConfig, res *SyncResult, tableIndex, totalTables int, tableName string, sourceDB db.Database, targetDB db.Database, plan SchemaMigrationPlan, sourceCols, targetCols []connection.ColumnDefinition, opts TableOptions, sourceType, targetType, applyTableName, pkCol string) (bool, pagedDiffCounts, error) { + if normalizeSyncMode(config.Mode) != "insert_update" || !plan.TargetTableExists { + return false, pagedDiffCounts{}, nil + } + if !supportsPagedDiffSelect(sourceType) || !supportsPagedDiffSelect(targetType) { + return false, pagedDiffCounts{}, nil + } + if opts.Delete && (!supportsPagedDiffKeysetSelect(targetType) || !supportsPagedDiffPKLookup(sourceType)) { + return false, pagedDiffCounts{}, nil + } + + applier, ok := targetDB.(db.BatchApplier) + if !ok { + return true, pagedDiffCounts{}, fmt.Errorf("目标驱动不支持应用数据变更 (ApplyChanges)") + } + + targetColSet, err := s.prepareDirectImportTargetColumnSet(config, res, targetDB, plan, sourceType, targetType, sourceCols, targetCols) + if err != nil { + return true, pagedDiffCounts{}, err + } + + s.appendLog(config.JobID, res, "info", fmt.Sprintf(" -> 启用分页差异同步:按主键 %s 每批读取 %d 行", pkCol, defaultSyncReadPageSize)) + s.progress(config.JobID, tableIndex, totalTables, tableName, "分页对比数据") + + applied := pagedDiffCounts{} + handled, _, err := scanTableDiffInPages(sourceDB, targetDB, sourceType, targetType, plan, sourceCols, targetCols, pkCol, targetColSet, opts.Delete, func(page pagedDiffPage) error { + changeSet := connection.ChangeSet{ + Inserts: filterRowsByPKSelection(pkCol, page.Inserts, opts.Insert, opts.SelectedInsertPKs), + Updates: filterPagedUpdatesByPKSelection(pkCol, page.Updates, opts.Update, opts.SelectedUpdatePKs), + Deletes: filterRowsByPKSelection(pkCol, page.Deletes, opts.Delete, opts.SelectedDeletePKs), + } + if len(targetColSet) > 0 { + changeSet.Inserts = filterInsertRows(changeSet.Inserts, targetColSet) + changeSet.Updates = filterUpdateRows(changeSet.Updates, targetColSet) + } + if len(changeSet.Inserts) == 0 && len(changeSet.Updates) == 0 && len(changeSet.Deletes) == 0 { + return nil + } + if err := s.applyChangesInBatches(config.JobID, res, applyTableName, applier, changeSet); err != nil { + return err + } + applied.Inserts += len(changeSet.Inserts) + applied.Updates += len(changeSet.Updates) + applied.Deletes += len(changeSet.Deletes) + return nil + }) + if err != nil { + return true, applied, err + } + return handled, applied, nil +} + +func scanTableDiffInPages(sourceDB db.Database, targetDB db.Database, sourceType, targetType string, plan SchemaMigrationPlan, sourceCols, targetCols []connection.ColumnDefinition, pkCol string, targetColSet map[string]struct{}, includeDeletes bool, consume func(page pagedDiffPage) error) (bool, pagedDiffCounts, error) { + if !supportsPagedDiffSelect(sourceType) || !supportsPagedDiffPKLookup(targetType) { + return false, pagedDiffCounts{}, nil + } + if includeDeletes && (!supportsPagedDiffKeysetSelect(targetType) || !supportsPagedDiffPKLookup(sourceType)) { + return false, pagedDiffCounts{}, nil + } + + sourceReadCols := diffReadableColumns(sourceCols, targetColSet, pkCol) + if len(sourceReadCols) == 0 { + return false, pagedDiffCounts{}, nil + } + targetLookupCols := diffLookupColumns(sourceReadCols, targetCols, targetColSet, pkCol) + if len(targetLookupCols) == 0 { + return false, pagedDiffCounts{}, nil + } + + totals := pagedDiffCounts{} + for offset := 0; ; offset += defaultSyncReadPageSize { + query := buildPagedSourceTableQuery(sourceType, plan.SourceQueryTable, sourceReadCols, pkCol, defaultSyncReadPageSize, offset) + if strings.TrimSpace(query) == "" { + return false, pagedDiffCounts{}, nil + } + sourceRows, _, err := sourceDB.Query(query) + if err != nil { + return true, totals, fmt.Errorf("分页读取源表失败(offset=%d): %w", offset, err) + } + if len(sourceRows) == 0 { + break + } + + pkValues := collectPKValues(sourceRows, pkCol) + targetRows := make([]map[string]interface{}, 0) + if len(pkValues) > 0 { + targetQuery := buildPKInSelectQuery(targetType, plan.TargetQueryTable, targetLookupCols, pkCol, pkValues) + if strings.TrimSpace(targetQuery) == "" { + return false, pagedDiffCounts{}, nil + } + targetRows, _, err = targetDB.Query(targetQuery) + if err != nil { + return true, totals, fmt.Errorf("按主键读取目标表失败(offset=%d): %w", offset, err) + } + } + + page := diffSourcePageByPK(pkCol, sourceRows, targetRows) + totals.Inserts += len(page.Inserts) + totals.Updates += len(page.Updates) + totals.Same += page.Same + if consume != nil { + if err := consume(page); err != nil { + return true, totals, err + } + } + if len(sourceRows) < defaultSyncReadPageSize { + break + } + } + + if includeDeletes { + lastPK, hasLastPK := interface{}(nil), false + targetPKCols := []connection.ColumnDefinition{{Name: pkCol}} + for { + query := buildKeysetPagedTableQuery(targetType, plan.TargetQueryTable, targetPKCols, pkCol, lastPK, hasLastPK, defaultSyncReadPageSize) + if strings.TrimSpace(query) == "" { + return false, pagedDiffCounts{}, nil + } + targetRows, _, err := targetDB.Query(query) + if err != nil { + return true, totals, fmt.Errorf("分页读取目标主键失败: %w", err) + } + if len(targetRows) == 0 { + break + } + + nextLastPK, ok := lastValidPKValue(targetRows, pkCol) + if !ok { + break + } + lastPK, hasLastPK = nextLastPK, true + + pkValues := collectPKValues(targetRows, pkCol) + sourcePKRows := make([]map[string]interface{}, 0) + if len(pkValues) > 0 { + sourceQuery := buildPKInSelectQuery(sourceType, plan.SourceQueryTable, targetPKCols, pkCol, pkValues) + if strings.TrimSpace(sourceQuery) == "" { + return false, pagedDiffCounts{}, nil + } + sourcePKRows, _, err = sourceDB.Query(sourceQuery) + if err != nil { + return true, totals, fmt.Errorf("按主键反查源表失败: %w", err) + } + } + + sourcePKSet := buildPKSet(sourcePKRows, pkCol) + deletes := make([]map[string]interface{}, 0) + for _, row := range targetRows { + pkKey, ok := pkValueKey(row[pkCol]) + if !ok { + continue + } + if _, exists := sourcePKSet[pkKey]; exists { + continue + } + deletes = append(deletes, map[string]interface{}{pkCol: row[pkCol]}) + } + if len(deletes) > 0 { + totals.Deletes += len(deletes) + if consume != nil { + if err := consume(pagedDiffPage{Deletes: deletes}); err != nil { + return true, totals, err + } + } + } + if len(targetRows) < defaultSyncReadPageSize { + break + } + } + } + + return true, totals, nil +} + +func diffSourcePageByPK(pkCol string, sourceRows, targetRows []map[string]interface{}) pagedDiffPage { + targetMap := make(map[string]map[string]interface{}, len(targetRows)) + for _, row := range targetRows { + pkKey, ok := pkValueKey(row[pkCol]) + if !ok { + continue + } + targetMap[pkKey] = row + } + + page := pagedDiffPage{ + Inserts: make([]map[string]interface{}, 0), + Updates: make([]pagedUpdateDiff, 0), + } + for _, sourceRow := range sourceRows { + pkKey, ok := pkValueKey(sourceRow[pkCol]) + if !ok { + continue + } + targetRow, exists := targetMap[pkKey] + if !exists { + page.Inserts = append(page.Inserts, sourceRow) + continue + } + + changes := make(map[string]interface{}) + changedColumns := make([]string, 0) + for key, value := range sourceRow { + if fmt.Sprintf("%v", value) == fmt.Sprintf("%v", targetRow[key]) { + continue + } + changes[key] = value + changedColumns = append(changedColumns, key) + } + if len(changes) == 0 { + page.Same++ + continue + } + sort.Strings(changedColumns) + page.Updates = append(page.Updates, pagedUpdateDiff{ + UpdateRow: connection.UpdateRow{ + Keys: map[string]interface{}{pkCol: sourceRow[pkCol]}, + Values: changes, + }, + Source: sourceRow, + Target: targetRow, + ChangedColumns: changedColumns, + }) + } + return page +} + +func filterPagedUpdatesByPKSelection(pkCol string, updates []pagedUpdateDiff, enabled bool, selectedPKs []string) []connection.UpdateRow { + if !enabled { + return nil + } + if len(updates) == 0 { + return nil + } + out := make([]connection.UpdateRow, 0, len(updates)) + for _, update := range updates { + out = append(out, update.UpdateRow) + } + return filterUpdatesByPKSelection(pkCol, out, true, selectedPKs) +} + +func diffReadableColumns(sourceCols []connection.ColumnDefinition, allowedLower map[string]struct{}, pkCol string) []connection.ColumnDefinition { + out := make([]connection.ColumnDefinition, 0, len(sourceCols)) + seen := map[string]struct{}{} + add := func(col connection.ColumnDefinition) { + name := strings.TrimSpace(col.Name) + lower := strings.ToLower(name) + if name == "" { + return + } + if _, ok := seen[lower]; ok { + return + } + seen[lower] = struct{}{} + out = append(out, col) + } + for _, col := range sourceCols { + name := strings.TrimSpace(col.Name) + lower := strings.ToLower(name) + if name == "" { + continue + } + if strings.EqualFold(name, pkCol) { + add(col) + continue + } + if len(allowedLower) > 0 { + if _, ok := allowedLower[lower]; !ok { + continue + } + } + add(col) + } + if _, ok := seen[strings.ToLower(strings.TrimSpace(pkCol))]; !ok && strings.TrimSpace(pkCol) != "" { + add(connection.ColumnDefinition{Name: pkCol}) + } + return out +} + +func diffLookupColumns(sourceReadCols, targetCols []connection.ColumnDefinition, allowedLower map[string]struct{}, pkCol string) []connection.ColumnDefinition { + targetByLower := make(map[string]connection.ColumnDefinition, len(targetCols)) + for _, col := range targetCols { + name := strings.TrimSpace(col.Name) + if name != "" { + targetByLower[strings.ToLower(name)] = col + } + } + + out := make([]connection.ColumnDefinition, 0, len(sourceReadCols)) + seen := map[string]struct{}{} + for _, sourceCol := range sourceReadCols { + name := strings.TrimSpace(sourceCol.Name) + lower := strings.ToLower(name) + if name == "" { + continue + } + if _, ok := seen[lower]; ok { + continue + } + if !strings.EqualFold(name, pkCol) && len(allowedLower) > 0 { + if _, ok := allowedLower[lower]; !ok { + continue + } + } + if targetCol, ok := targetByLower[lower]; ok { + out = append(out, targetCol) + } else { + out = append(out, connection.ColumnDefinition{Name: name}) + } + seen[lower] = struct{}{} + } + if _, ok := seen[strings.ToLower(strings.TrimSpace(pkCol))]; !ok && strings.TrimSpace(pkCol) != "" { + out = append(out, connection.ColumnDefinition{Name: pkCol}) + } + return out +} + +func collectPKValues(rows []map[string]interface{}, pkCol string) []interface{} { + values := make([]interface{}, 0, len(rows)) + seen := make(map[string]struct{}, len(rows)) + for _, row := range rows { + key, ok := pkValueKey(row[pkCol]) + if !ok { + continue + } + if _, exists := seen[key]; exists { + continue + } + seen[key] = struct{}{} + values = append(values, row[pkCol]) + } + return values +} + +func buildPKSet(rows []map[string]interface{}, pkCol string) map[string]struct{} { + set := make(map[string]struct{}, len(rows)) + for _, row := range rows { + key, ok := pkValueKey(row[pkCol]) + if ok { + set[key] = struct{}{} + } + } + return set +} + +func lastValidPKValue(rows []map[string]interface{}, pkCol string) (interface{}, bool) { + for i := len(rows) - 1; i >= 0; i-- { + if _, ok := pkValueKey(rows[i][pkCol]); ok { + return rows[i][pkCol], true + } + } + return nil, false +} + +func pkValueKey(value interface{}) (string, bool) { + if value == nil { + return "", false + } + key := strings.TrimSpace(fmt.Sprintf("%v", value)) + if key == "" || key == "" { + return "", false + } + return key, true +} + +func buildPKInSelectQuery(dbType, queryTable string, cols []connection.ColumnDefinition, pkCol string, pkValues []interface{}) string { + if len(pkValues) == 0 { + return "" + } + selectList := buildColumnSelectListForSync(dbType, cols) + if strings.TrimSpace(selectList) == "" { + selectList = "*" + } + literals := make([]string, 0, len(pkValues)) + for _, value := range pkValues { + literal, ok := formatSyncSQLLiteral(value) + if !ok { + continue + } + literals = append(literals, literal) + } + if len(literals) == 0 { + return "" + } + return fmt.Sprintf("SELECT %s FROM %s WHERE %s IN (%s)", + selectList, + quoteQualifiedIdentByType(dbType, queryTable), + quoteIdentByType(dbType, pkCol), + strings.Join(literals, ", ")) +} + +func buildKeysetPagedTableQuery(dbType, queryTable string, cols []connection.ColumnDefinition, orderCol string, lastValue interface{}, hasLastValue bool, limit int) string { + selectList := buildColumnSelectListForSync(dbType, cols) + if strings.TrimSpace(selectList) == "" { + selectList = "*" + } + safeLimit := limit + if safeLimit <= 0 { + safeLimit = defaultSyncReadPageSize + } + where := "" + if hasLastValue { + literal, ok := formatSyncSQLLiteral(lastValue) + if !ok { + return "" + } + where = fmt.Sprintf(" WHERE %s > %s", quoteIdentByType(dbType, orderCol), literal) + } + orderBy := fmt.Sprintf(" ORDER BY %s ASC", quoteIdentByType(dbType, orderCol)) + if normalizeMigrationDBType(dbType) == "sqlserver" { + return fmt.Sprintf("SELECT TOP (%d) %s FROM %s%s%s", safeLimit, selectList, quoteQualifiedIdentByType(dbType, queryTable), where, orderBy) + } + return fmt.Sprintf("SELECT %s FROM %s%s%s LIMIT %d", selectList, quoteQualifiedIdentByType(dbType, queryTable), where, orderBy, safeLimit) +} + +func countTableRowsForSync(database db.Database, dbType, queryTable string) (int, bool, error) { + query := fmt.Sprintf("SELECT COUNT(*) AS __gonavi_count__ FROM %s", quoteQualifiedIdentByType(dbType, queryTable)) + rows, _, err := database.Query(query) + if err != nil { + return 0, true, err + } + if len(rows) == 0 { + return 0, false, nil + } + for _, value := range rows[0] { + count, ok := intFromSyncValue(value) + if ok { + return count, true, nil + } + } + return 0, false, nil +} + +func intFromSyncValue(value interface{}) (int, bool) { + if value == nil { + return 0, false + } + switch v := value.(type) { + case int: + return v, true + case int8: + return int(v), true + case int16: + return int(v), true + case int32: + return int(v), true + case int64: + return int(v), true + case uint: + return int(v), true + case uint8: + return int(v), true + case uint16: + return int(v), true + case uint32: + return int(v), true + case uint64: + return int(v), true + case float32: + return int(v), true + case float64: + return int(v), true + case []byte: + i, err := strconv.Atoi(strings.TrimSpace(string(v))) + return i, err == nil + case string: + i, err := strconv.Atoi(strings.TrimSpace(v)) + return i, err == nil + default: + rv := reflect.ValueOf(value) + switch rv.Kind() { + case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: + i := rv.Int() + if i > int64(^uint(0)>>1) { + return 0, false + } + return int(i), true + case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64: + u := rv.Uint() + if u > uint64(^uint(0)>>1) { + return 0, false + } + return int(u), true + } + } + return 0, false +} + +func buildColumnSelectListForSync(dbType string, cols []connection.ColumnDefinition) string { + quoted := make([]string, 0, len(cols)) + seen := map[string]struct{}{} + for _, col := range cols { + name := strings.TrimSpace(col.Name) + lower := strings.ToLower(name) + if name == "" { + continue + } + if _, ok := seen[lower]; ok { + continue + } + seen[lower] = struct{}{} + quoted = append(quoted, quoteIdentByType(dbType, name)) + } + return strings.Join(quoted, ", ") +} + +func formatSyncSQLLiteral(value interface{}) (string, bool) { + if value == nil { + return "", false + } + switch v := value.(type) { + case time.Time: + return quoteSyncSQLString(v.Format("2006-01-02 15:04:05.999999999")), true + case []byte: + return quoteSyncSQLString(string(v)), true + case string: + if strings.TrimSpace(v) == "" { + return "", false + } + return quoteSyncSQLString(v), true + case bool: + if v { + return "1", true + } + return "0", true + default: + text := strings.TrimSpace(fmt.Sprintf("%v", value)) + if text == "" || text == "" { + return "", false + } + switch value.(type) { + case int, int8, int16, int32, int64, uint, uint8, uint16, uint32, uint64, float32, float64: + return text, true + default: + return quoteSyncSQLString(text), true + } + } +} + +func quoteSyncSQLString(value string) string { + return "'" + strings.ReplaceAll(value, "'", "''") + "'" +} + +func supportsPagedDiffSelect(dbType string) bool { + return supportsDirectImportPagination(dbType) +} + +func supportsPagedDiffPKLookup(dbType string) bool { + return supportsDirectImportPagination(dbType) +} + +func supportsPagedDiffKeysetSelect(dbType string) bool { + return supportsDirectImportPagination(dbType) +} diff --git a/internal/sync/direct_import_paging.go b/internal/sync/direct_import_paging.go new file mode 100644 index 0000000..6142c98 --- /dev/null +++ b/internal/sync/direct_import_paging.go @@ -0,0 +1,304 @@ +package sync + +import ( + "GoNavi-Wails/internal/connection" + "GoNavi-Wails/internal/db" + "fmt" + "strings" +) + +const defaultSyncReadPageSize = defaultSyncApplyBatchSize + +func (s *SyncEngine) tryApplyDirectImportInPages(config SyncConfig, res *SyncResult, tableIndex, totalTables int, tableName string, sourceDB db.Database, targetDB db.Database, plan SchemaMigrationPlan, sourceCols, targetCols []connection.ColumnDefinition, opts TableOptions, sourceType, targetType, applyTableName string) (bool, int, error) { + tableMode := normalizeSyncMode(config.Mode) + if tableMode == "insert_update" && plan.TargetTableExists { + return false, 0, nil + } + if tableMode == "full_overwrite" && plan.TargetTableExists && isSamePhysicalSyncTable(config, plan, sourceType, targetType) { + return false, 0, nil + } + if !opts.Insert { + return false, 0, nil + } + + pkCol, ok := directImportPaginationPK(sourceType, sourceCols) + if !ok && !supportsDirectImportPagination(sourceType) { + return false, 0, nil + } + if !ok && len(opts.SelectedInsertPKs) > 0 { + return false, 0, nil + } + + firstPageQuery := buildPagedSourceTableQuery(sourceType, plan.SourceQueryTable, sourceCols, pkCol, defaultSyncReadPageSize, 0) + if strings.TrimSpace(firstPageQuery) == "" { + return false, 0, nil + } + + applier, ok := targetDB.(db.BatchApplier) + if !ok { + return true, 0, fmt.Errorf("目标驱动不支持应用数据变更 (ApplyChanges)") + } + + if strings.TrimSpace(pkCol) != "" { + s.appendLog(config.JobID, res, "info", fmt.Sprintf(" -> 启用分页流式导入:按主键 %s 每批读取 %d 行", pkCol, defaultSyncReadPageSize)) + } else { + s.appendLog(config.JobID, res, "info", fmt.Sprintf(" -> 启用分页流式导入:每批读取 %d 行", defaultSyncReadPageSize)) + } + s.progress(config.JobID, tableIndex, totalTables, tableName, "分页读取源表数据") + firstRows, _, err := sourceDB.Query(firstPageQuery) + if err != nil { + return true, 0, fmt.Errorf("分页读取源表失败: %w", err) + } + + if tableMode == "full_overwrite" && plan.TargetTableExists { + s.appendLog(config.JobID, res, "warn", fmt.Sprintf(" -> 全量覆盖模式:即将清空目标表 %s", tableName)) + s.progress(config.JobID, tableIndex, totalTables, tableName, "清空目标表") + clearSQL := buildClearTargetTableSQL(targetType, plan.TargetQueryTable) + if _, err := targetDB.Exec(clearSQL); err != nil { + return true, 0, fmt.Errorf("清空目标表失败: %w", err) + } + } + + targetColSet, err := s.prepareDirectImportTargetColumnSet(config, res, targetDB, plan, sourceType, targetType, sourceCols, targetCols) + if err != nil { + return true, 0, err + } + + inserted, err := s.applyDirectImportPage(config.JobID, res, applyTableName, applier, targetColSet, pkCol, opts, firstRows) + if err != nil { + return true, inserted, err + } + if len(firstRows) < defaultSyncReadPageSize { + return true, inserted, nil + } + + for offset := defaultSyncReadPageSize; ; offset += defaultSyncReadPageSize { + s.progress(config.JobID, tableIndex, totalTables, tableName, fmt.Sprintf("分页读取源表数据(%d+)", offset)) + query := buildPagedSourceTableQuery(sourceType, plan.SourceQueryTable, sourceCols, pkCol, defaultSyncReadPageSize, offset) + rows, _, err := sourceDB.Query(query) + if err != nil { + return true, inserted, fmt.Errorf("分页读取源表失败(offset=%d): %w", offset, err) + } + if len(rows) == 0 { + return true, inserted, nil + } + applied, err := s.applyDirectImportPage(config.JobID, res, applyTableName, applier, targetColSet, pkCol, opts, rows) + inserted += applied + if err != nil { + return true, inserted, err + } + if len(rows) < defaultSyncReadPageSize { + return true, inserted, nil + } + } +} + +func (s *SyncEngine) prepareDirectImportTargetColumnSet(config SyncConfig, res *SyncResult, targetDB db.Database, plan SchemaMigrationPlan, sourceType, targetType string, sourceCols, targetCols []connection.ColumnDefinition) (map[string]struct{}, error) { + targetColsResolved := targetCols + if len(targetColsResolved) == 0 { + cols, err := targetDB.GetColumns(plan.TargetSchema, plan.TargetTable) + if err != nil { + s.appendLog(config.JobID, res, "warn", fmt.Sprintf(" -> 获取目标表字段失败,已跳过字段一致性检查: %v", err)) + return nil, nil + } + targetColsResolved = cols + } + if len(targetColsResolved) == 0 { + return nil, nil + } + + targetColSet := buildTargetColumnSet(targetColsResolved) + missing := missingSourceColumns(sourceCols, targetColSet) + if len(missing) == 0 { + return targetColSet, nil + } + + if config.AutoAddColumns && supportsAutoAddColumnsForPair(sourceType, targetType) { + s.appendLog(config.JobID, res, "warn", fmt.Sprintf(" -> 目标表缺少字段 %d 个,开始自动补齐: %s", len(missing), strings.Join(missing, ", "))) + added := 0 + sourceColsByLower := make(map[string]connection.ColumnDefinition, len(sourceCols)) + for _, col := range sourceCols { + key := strings.ToLower(strings.TrimSpace(col.Name)) + if key != "" { + sourceColsByLower[key] = col + } + } + for _, colName := range missing { + srcCol, ok := sourceColsByLower[strings.ToLower(strings.TrimSpace(colName))] + if !ok { + continue + } + alterSQL, err := buildAddColumnSQLForPair(sourceType, targetType, plan.TargetQueryTable, srcCol) + if err != nil { + s.appendLog(config.JobID, res, "error", fmt.Sprintf(" -> 自动补字段失败:字段=%s 错误=%v", colName, err)) + continue + } + if _, err := targetDB.Exec(alterSQL); err != nil { + s.appendLog(config.JobID, res, "error", fmt.Sprintf(" -> 自动补字段失败:字段=%s 错误=%v", colName, err)) + continue + } + added++ + targetColSet[strings.ToLower(strings.TrimSpace(colName))] = struct{}{} + } + s.appendLog(config.JobID, res, "info", fmt.Sprintf(" -> 自动补字段完成:成功=%d 失败=%d", added, len(missing)-added)) + return targetColSet, nil + } + + s.appendLog(config.JobID, res, "warn", fmt.Sprintf(" -> 目标表缺少字段 %d 个(未开启自动补齐),将自动忽略:%s", len(missing), strings.Join(missing, ", "))) + return targetColSet, nil +} + +func missingSourceColumns(sourceCols []connection.ColumnDefinition, targetColSet map[string]struct{}) []string { + missing := make([]string, 0) + seen := make(map[string]struct{}, len(sourceCols)) + for _, col := range sourceCols { + name := strings.TrimSpace(col.Name) + lower := strings.ToLower(name) + if name == "" { + continue + } + if _, ok := seen[lower]; ok { + continue + } + seen[lower] = struct{}{} + if _, ok := targetColSet[lower]; !ok { + missing = append(missing, name) + } + } + return missing +} + +func (s *SyncEngine) applyDirectImportPage(jobID string, res *SyncResult, tableName string, applier db.BatchApplier, targetColSet map[string]struct{}, pkCol string, opts TableOptions, rows []map[string]interface{}) (int, error) { + if len(rows) == 0 { + return 0, nil + } + rows = filterRowsByPKSelection(pkCol, rows, opts.Insert, opts.SelectedInsertPKs) + if len(rows) == 0 { + return 0, nil + } + if len(targetColSet) > 0 { + rows = filterInsertRows(rows, targetColSet) + } + if len(rows) == 0 { + return 0, nil + } + changeSet := connection.ChangeSet{Inserts: rows} + if err := s.applyChangesInBatches(jobID, res, tableName, applier, changeSet); err != nil { + return 0, err + } + return len(rows), nil +} + +func directImportPaginationPK(sourceType string, sourceCols []connection.ColumnDefinition) (string, bool) { + if !supportsDirectImportPagination(sourceType) { + return "", false + } + pkCols := make([]string, 0, 2) + for _, col := range sourceCols { + if col.Key == "PRI" || col.Key == "PK" { + pkCols = append(pkCols, col.Name) + } + } + if len(pkCols) != 1 || strings.TrimSpace(pkCols[0]) == "" { + return "", false + } + return pkCols[0], true +} + +func supportsDirectImportPagination(dbType string) bool { + switch normalizeMigrationDBType(dbType) { + case "mysql", "mariadb", "postgres", "kingbase", "highgo", "vastbase", "opengauss", "sqlserver", "sqlite", "duckdb", "clickhouse", "tdengine", "starrocks", "diros": + return true + default: + return false + } +} + +func buildPagedSourceTableQuery(dbType, queryTable string, cols []connection.ColumnDefinition, orderCol string, limit, offset int) string { + selectList := buildSourceColumnSelectList(dbType, cols) + if strings.TrimSpace(selectList) == "" { + return "" + } + pageSelectList := selectList + if normalizeMigrationDBType(dbType) == "sqlserver" { + pageSelectList = buildSQLServerPageSelectList(cols) + } + baseSQL := fmt.Sprintf("SELECT %s FROM %s", selectList, quoteQualifiedIdentByType(dbType, queryTable)) + orderBy := "" + if strings.TrimSpace(orderCol) != "" { + orderBy = fmt.Sprintf(" ORDER BY %s ASC", quoteIdentByType(dbType, orderCol)) + } + return buildPaginatedSelectSQLForSync(dbType, baseSQL, pageSelectList, orderBy, limit, offset) +} + +func buildSourceColumnSelectList(dbType string, cols []connection.ColumnDefinition) string { + quoted := make([]string, 0, len(cols)) + for _, col := range cols { + name := strings.TrimSpace(col.Name) + if name == "" { + continue + } + quoted = append(quoted, quoteIdentByType(dbType, name)) + } + return strings.Join(quoted, ", ") +} + +func buildSQLServerPageSelectList(cols []connection.ColumnDefinition) string { + quoted := make([]string, 0, len(cols)) + for _, col := range cols { + name := strings.TrimSpace(col.Name) + if name == "" { + continue + } + quoted = append(quoted, fmt.Sprintf("[__gonavi_page_result__].%s", quoteIdentByType("sqlserver", name))) + } + return strings.Join(quoted, ", ") +} + +func buildPaginatedSelectSQLForSync(dbType, baseSQL, selectList, orderBySQL string, limit, offset int) string { + safeLimit := limit + if safeLimit <= 0 { + safeLimit = defaultSyncReadPageSize + } + safeOffset := offset + if safeOffset < 0 { + safeOffset = 0 + } + base := strings.TrimSpace(baseSQL) + orderBy := strings.TrimSpace(orderBySQL) + + switch normalizeMigrationDBType(dbType) { + case "sqlserver": + upperBound := safeOffset + safeLimit + if orderBy == "" { + orderBy = "ORDER BY (SELECT NULL)" + } + return fmt.Sprintf("SELECT %s FROM (SELECT [__gonavi_page__].*, ROW_NUMBER() OVER (%s) AS [__gonavi_rn__] FROM (%s) AS [__gonavi_page__]) AS [__gonavi_page_result__] WHERE [__gonavi_rn__] > %d AND [__gonavi_rn__] <= %d ORDER BY [__gonavi_rn__]", selectList, orderBy, base, safeOffset, upperBound) + default: + return fmt.Sprintf("%s %s LIMIT %d OFFSET %d", base, orderBy, safeLimit, safeOffset) + } +} + +func buildClearTargetTableSQL(targetType, targetQueryTable string) string { + quotedTable := quoteQualifiedIdentByType(targetType, targetQueryTable) + if normalizeMigrationDBType(targetType) == "mysql" { + return fmt.Sprintf("TRUNCATE TABLE %s", quotedTable) + } + return fmt.Sprintf("DELETE FROM %s", quotedTable) +} + +func isSamePhysicalSyncTable(config SyncConfig, plan SchemaMigrationPlan, sourceType, targetType string) bool { + if normalizeMigrationDBType(sourceType) != normalizeMigrationDBType(targetType) { + return false + } + if !strings.EqualFold(strings.TrimSpace(plan.SourceQueryTable), strings.TrimSpace(plan.TargetQueryTable)) { + return false + } + source := config.SourceConfig + target := config.TargetConfig + return strings.EqualFold(strings.TrimSpace(source.Host), strings.TrimSpace(target.Host)) && + source.Port == target.Port && + strings.EqualFold(strings.TrimSpace(source.Database), strings.TrimSpace(target.Database)) && + strings.EqualFold(strings.TrimSpace(source.Driver), strings.TrimSpace(target.Driver)) && + strings.EqualFold(strings.TrimSpace(source.DSN), strings.TrimSpace(target.DSN)) +} diff --git a/internal/sync/preview.go b/internal/sync/preview.go index 7b21a9d..7d4a432 100644 --- a/internal/sync/preview.go +++ b/internal/sync/preview.go @@ -104,31 +104,8 @@ func (s *SyncEngine) Preview(config SyncConfig, tableName string, limit int) (Ta } pkCol := pkCols[0] - sourceRows, _, err := sourceDB.Query(fmt.Sprintf("SELECT * FROM %s", quoteQualifiedIdentByType(resolveMigrationDBType(config.SourceConfig), plan.SourceQueryTable))) - if err != nil { - return TableDiffPreview{}, fmt.Errorf("读取源表失败: %w", err) - } - - targetRows := make([]map[string]interface{}, 0) - if plan.TargetTableExists { - targetRows, _, err = targetDB.Query(fmt.Sprintf("SELECT * FROM %s", quoteQualifiedIdentByType(resolveMigrationDBType(config.TargetConfig), plan.TargetQueryTable))) - if err != nil { - return TableDiffPreview{}, fmt.Errorf("读取目标表失败: %w", err) - } - } - - targetMap := make(map[string]map[string]interface{}, len(targetRows)) - for _, row := range targetRows { - if row[pkCol] == nil { - continue - } - pkVal := strings.TrimSpace(fmt.Sprintf("%v", row[pkCol])) - if pkVal == "" || pkVal == "" { - continue - } - targetMap[pkVal] = row - } - + sourceType := resolveMigrationDBType(config.SourceConfig) + targetType := resolveMigrationDBType(config.TargetConfig) out := TableDiffPreview{ Table: tableName, PKColumn: pkCol, @@ -152,6 +129,119 @@ func (s *SyncEngine) Preview(config SyncConfig, tableName string, limit int) (Ta out.ColumnTypes[name] = typ } + tableMode := normalizeSyncMode(config.Mode) + targetColSet := map[string]struct{}{} + if plan.TargetTableExists { + targetCols, err := targetDB.GetColumns(plan.TargetSchema, plan.TargetTable) + if err == nil { + targetColSet = buildTargetColumnSet(targetCols) + } + } + + if !plan.TargetTableExists || tableMode != "insert_update" { + sourceCount, counted, err := countTableRowsForSync(sourceDB, sourceType, plan.SourceQueryTable) + if err != nil { + return TableDiffPreview{}, fmt.Errorf("读取源表数量失败: %w", err) + } + query := buildPagedSourceTableQuery(sourceType, plan.SourceQueryTable, cols, pkCol, limit, 0) + if strings.TrimSpace(query) == "" { + return TableDiffPreview{}, fmt.Errorf("当前数据源不支持分页预览") + } + sourceRows, _, err := sourceDB.Query(query) + if err != nil { + return TableDiffPreview{}, fmt.Errorf("读取源表失败: %w", err) + } + if !counted { + sourceCount = len(sourceRows) + } + out.TotalInserts = sourceCount + for _, row := range sourceRows { + if len(out.Inserts) >= limit { + break + } + pkVal := strings.TrimSpace(fmt.Sprintf("%v", row[pkCol])) + if pkVal == "" || pkVal == "" { + continue + } + out.Inserts = append(out.Inserts, PreviewRow{PK: pkVal, Row: row}) + } + return out, nil + } + + handled, _, err := scanTableDiffInPages(sourceDB, targetDB, sourceType, targetType, plan, cols, nil, pkCol, targetColSet, true, func(page pagedDiffPage) error { + out.TotalInserts += len(page.Inserts) + out.TotalUpdates += len(page.Updates) + out.TotalDeletes += len(page.Deletes) + + for _, row := range page.Inserts { + if len(out.Inserts) >= limit { + break + } + pkVal := strings.TrimSpace(fmt.Sprintf("%v", row[pkCol])) + if pkVal == "" || pkVal == "" { + continue + } + out.Inserts = append(out.Inserts, PreviewRow{PK: pkVal, Row: row}) + } + for _, update := range page.Updates { + if len(out.Updates) >= limit { + break + } + pkVal := strings.TrimSpace(fmt.Sprintf("%v", update.UpdateRow.Keys[pkCol])) + if pkVal == "" || pkVal == "" { + continue + } + out.Updates = append(out.Updates, PreviewUpdateRow{ + PK: pkVal, + ChangedColumns: append([]string(nil), update.ChangedColumns...), + Source: update.Source, + Target: update.Target, + }) + } + for _, row := range page.Deletes { + if len(out.Deletes) >= limit { + break + } + pkVal := strings.TrimSpace(fmt.Sprintf("%v", row[pkCol])) + if pkVal == "" || pkVal == "" { + continue + } + out.Deletes = append(out.Deletes, PreviewRow{PK: pkVal, Row: row}) + } + return nil + }) + if handled { + if err != nil { + return TableDiffPreview{}, err + } + return out, nil + } + + sourceRows, _, err := sourceDB.Query(fmt.Sprintf("SELECT * FROM %s", quoteQualifiedIdentByType(sourceType, plan.SourceQueryTable))) + if err != nil { + return TableDiffPreview{}, fmt.Errorf("读取源表失败: %w", err) + } + + targetRows := make([]map[string]interface{}, 0) + if plan.TargetTableExists { + targetRows, _, err = targetDB.Query(fmt.Sprintf("SELECT * FROM %s", quoteQualifiedIdentByType(targetType, plan.TargetQueryTable))) + if err != nil { + return TableDiffPreview{}, fmt.Errorf("读取目标表失败: %w", err) + } + } + + targetMap := make(map[string]map[string]interface{}, len(targetRows)) + for _, row := range targetRows { + if row[pkCol] == nil { + continue + } + pkVal := strings.TrimSpace(fmt.Sprintf("%v", row[pkCol])) + if pkVal == "" || pkVal == "" { + continue + } + targetMap[pkVal] = row + } + sourcePKSet := make(map[string]struct{}, len(sourceRows)) for _, sRow := range sourceRows { if sRow[pkCol] == nil { diff --git a/internal/sync/schema_migration_test.go b/internal/sync/schema_migration_test.go index 5c216c4..1539fce 100644 --- a/internal/sync/schema_migration_test.go +++ b/internal/sync/schema_migration_test.go @@ -14,12 +14,14 @@ type fakeMigrationDB struct { tables map[string][]string queryData map[string][]map[string]interface{} queryCols map[string][]string + queryLog []string } func (f *fakeMigrationDB) Connect(config connection.ConnectionConfig) error { return nil } func (f *fakeMigrationDB) Close() error { return nil } func (f *fakeMigrationDB) Ping() error { return nil } func (f *fakeMigrationDB) Query(query string) ([]map[string]interface{}, []string, error) { + f.queryLog = append(f.queryLog, query) if rows, ok := f.queryData[query]; ok { return rows, f.queryCols[query], nil } diff --git a/internal/sync/source_query_paging.go b/internal/sync/source_query_paging.go new file mode 100644 index 0000000..1faa6c1 --- /dev/null +++ b/internal/sync/source_query_paging.go @@ -0,0 +1,286 @@ +package sync + +import ( + "GoNavi-Wails/internal/connection" + "GoNavi-Wails/internal/db" + "fmt" + "strings" +) + +func (s *SyncEngine) tryApplySourceQueryInPages(config SyncConfig, res *SyncResult, tableName string, sourceDB db.Database, targetDB db.Database, ctx sourceQuerySyncContext, opts TableOptions, tableMode string, applyTableName string) (bool, pagedDiffCounts, error) { + sourceType := resolveMigrationDBType(config.SourceConfig) + if !supportsPagedSourceQuery(sourceType) || !supportsPagedDiffPKLookup(ctx.TargetType) { + return false, pagedDiffCounts{}, nil + } + if strings.TrimSpace(buildSourceQueryPageSQL(sourceType, config.SourceQuery, ctx.PKColumn, defaultSyncReadPageSize, 0)) == "" { + return false, pagedDiffCounts{}, nil + } + + applier, ok := targetDB.(db.BatchApplier) + if !ok { + return true, pagedDiffCounts{}, fmt.Errorf("目标驱动不支持应用数据变更 (ApplyChanges)") + } + targetColSet := buildTargetColumnSet(ctx.TargetCols) + counts := pagedDiffCounts{} + + if tableMode == "insert_update" { + includeDeletes := opts.Delete + handled, _, err := scanSourceQueryDiffInPages(sourceDB, targetDB, sourceType, ctx.TargetType, strings.TrimSpace(config.SourceQuery), ctx.TargetQueryTable, ctx.TargetCols, ctx.PKColumn, includeDeletes, func(page pagedDiffPage) error { + changeSet := connection.ChangeSet{ + Inserts: filterRowsByPKSelection(ctx.PKColumn, page.Inserts, opts.Insert, opts.SelectedInsertPKs), + Updates: filterPagedUpdatesByPKSelection(ctx.PKColumn, page.Updates, opts.Update, opts.SelectedUpdatePKs), + Deletes: filterRowsByPKSelection(ctx.PKColumn, page.Deletes, opts.Delete, opts.SelectedDeletePKs), + } + changeSet.Inserts = filterInsertRows(changeSet.Inserts, targetColSet) + changeSet.Updates = filterUpdateRows(changeSet.Updates, targetColSet) + if len(changeSet.Inserts) == 0 && len(changeSet.Updates) == 0 && len(changeSet.Deletes) == 0 { + return nil + } + if err := s.applyChangesInBatches(config.JobID, res, applyTableName, applier, changeSet); err != nil { + return err + } + counts.Inserts += len(changeSet.Inserts) + counts.Updates += len(changeSet.Updates) + counts.Deletes += len(changeSet.Deletes) + return nil + }) + if err != nil { + return true, counts, err + } + return handled, counts, nil + } + + if tableMode == "full_overwrite" { + clearSQL := buildClearTargetTableSQL(ctx.TargetType, ctx.TargetQueryTable) + if _, err := targetDB.Exec(clearSQL); err != nil { + return true, counts, fmt.Errorf("清空目标表失败: %w", err) + } + } + if !opts.Insert { + return true, counts, nil + } + + for offset := 0; ; offset += defaultSyncReadPageSize { + query := buildSourceQueryPageSQL(sourceType, config.SourceQuery, ctx.PKColumn, defaultSyncReadPageSize, offset) + rows, _, err := sourceDB.Query(query) + if err != nil { + return true, counts, fmt.Errorf("分页读取源查询失败(offset=%d): %w", offset, err) + } + if len(rows) == 0 { + return true, counts, nil + } + pageSize := len(rows) + insertRows := filterRowsByPKSelection(ctx.PKColumn, rows, opts.Insert, opts.SelectedInsertPKs) + insertRows = filterInsertRows(insertRows, targetColSet) + if len(insertRows) > 0 { + if err := s.applyChangesInBatches(config.JobID, res, applyTableName, applier, connection.ChangeSet{Inserts: insertRows}); err != nil { + return true, counts, err + } + counts.Inserts += len(insertRows) + } + if pageSize < defaultSyncReadPageSize { + return true, counts, nil + } + } +} + +func scanSourceQueryDiffInPages(sourceDB db.Database, targetDB db.Database, sourceType, targetType, sourceQuery, targetQueryTable string, targetCols []connection.ColumnDefinition, pkCol string, includeDeletes bool, consume func(page pagedDiffPage) error) (bool, pagedDiffCounts, error) { + if !supportsPagedSourceQuery(sourceType) || !supportsPagedDiffPKLookup(targetType) { + return false, pagedDiffCounts{}, nil + } + if includeDeletes && (!supportsPagedDiffKeysetSelect(targetType) || !supportsPagedSourceQueryPKLookup(sourceType)) { + return false, pagedDiffCounts{}, nil + } + + sourcePageQuery := buildSourceQueryPageSQL(sourceType, sourceQuery, pkCol, defaultSyncReadPageSize, 0) + if strings.TrimSpace(sourcePageQuery) == "" { + return false, pagedDiffCounts{}, nil + } + targetLookupCols := diffLookupColumns(targetCols, targetCols, buildTargetColumnSet(targetCols), pkCol) + if len(targetLookupCols) == 0 { + targetLookupCols = []connection.ColumnDefinition{{Name: pkCol}} + } + + totals := pagedDiffCounts{} + for offset := 0; ; offset += defaultSyncReadPageSize { + query := buildSourceQueryPageSQL(sourceType, sourceQuery, pkCol, defaultSyncReadPageSize, offset) + sourceRows, _, err := sourceDB.Query(query) + if err != nil { + return true, totals, fmt.Errorf("分页读取源查询失败(offset=%d): %w", offset, err) + } + if len(sourceRows) == 0 { + break + } + + pkValues := collectPKValues(sourceRows, pkCol) + targetRows := make([]map[string]interface{}, 0) + if len(pkValues) > 0 { + targetQuery := buildPKInSelectQuery(targetType, targetQueryTable, targetLookupCols, pkCol, pkValues) + if strings.TrimSpace(targetQuery) == "" { + return false, pagedDiffCounts{}, nil + } + targetRows, _, err = targetDB.Query(targetQuery) + if err != nil { + return true, totals, fmt.Errorf("按主键读取目标表失败(offset=%d): %w", offset, err) + } + } + + page := diffSourcePageByPK(pkCol, sourceRows, targetRows) + totals.Inserts += len(page.Inserts) + totals.Updates += len(page.Updates) + totals.Same += page.Same + if consume != nil { + if err := consume(page); err != nil { + return true, totals, err + } + } + if len(sourceRows) < defaultSyncReadPageSize { + break + } + } + + if includeDeletes { + lastPK, hasLastPK := interface{}(nil), false + targetPKCols := []connection.ColumnDefinition{{Name: pkCol}} + for { + query := buildKeysetPagedTableQuery(targetType, targetQueryTable, targetPKCols, pkCol, lastPK, hasLastPK, defaultSyncReadPageSize) + targetRows, _, err := targetDB.Query(query) + if err != nil { + return true, totals, fmt.Errorf("分页读取目标主键失败: %w", err) + } + if len(targetRows) == 0 { + break + } + + nextLastPK, ok := lastValidPKValue(targetRows, pkCol) + if !ok { + break + } + lastPK, hasLastPK = nextLastPK, true + + pkValues := collectPKValues(targetRows, pkCol) + sourcePKRows := make([]map[string]interface{}, 0) + if len(pkValues) > 0 { + sourceQuery := buildSourceQueryPKInSelectSQL(sourceType, sourceQuery, []connection.ColumnDefinition{{Name: pkCol}}, pkCol, pkValues) + if strings.TrimSpace(sourceQuery) == "" { + return false, pagedDiffCounts{}, nil + } + sourcePKRows, _, err = sourceDB.Query(sourceQuery) + if err != nil { + return true, totals, fmt.Errorf("按主键反查源查询失败: %w", err) + } + } + + sourcePKSet := buildPKSet(sourcePKRows, pkCol) + deletes := make([]map[string]interface{}, 0) + for _, row := range targetRows { + pkKey, ok := pkValueKey(row[pkCol]) + if !ok { + continue + } + if _, exists := sourcePKSet[pkKey]; exists { + continue + } + deletes = append(deletes, map[string]interface{}{pkCol: row[pkCol]}) + } + if len(deletes) > 0 { + totals.Deletes += len(deletes) + if consume != nil { + if err := consume(pagedDiffPage{Deletes: deletes}); err != nil { + return true, totals, err + } + } + } + if len(targetRows) < defaultSyncReadPageSize { + break + } + } + } + return true, totals, nil +} + +func buildSourceQueryPageSQL(dbType, sourceQuery, orderCol string, limit, offset int) string { + subquery, ok := normalizeSourceQueryForPaging(sourceQuery) + if !ok { + return "" + } + baseSQL := fmt.Sprintf("SELECT * FROM (%s) AS __gonavi_source_query__", subquery) + orderBy := "" + if strings.TrimSpace(orderCol) != "" { + orderBy = fmt.Sprintf(" ORDER BY %s ASC", quoteIdentByType(dbType, orderCol)) + } + return buildPaginatedSelectSQLForSync(dbType, baseSQL, "*", orderBy, limit, offset) +} + +func buildSourceQueryPKInSelectSQL(dbType, sourceQuery string, cols []connection.ColumnDefinition, pkCol string, pkValues []interface{}) string { + subquery, ok := normalizeSourceQueryForPaging(sourceQuery) + if !ok || len(pkValues) == 0 { + return "" + } + selectList := buildColumnSelectListForSync(dbType, cols) + if strings.TrimSpace(selectList) == "" { + selectList = "*" + } + literals := make([]string, 0, len(pkValues)) + for _, value := range pkValues { + literal, ok := formatSyncSQLLiteral(value) + if ok { + literals = append(literals, literal) + } + } + if len(literals) == 0 { + return "" + } + return fmt.Sprintf("SELECT %s FROM (%s) AS __gonavi_source_query__ WHERE %s IN (%s)", + selectList, + subquery, + quoteIdentByType(dbType, pkCol), + strings.Join(literals, ", ")) +} + +func countSourceQueryRowsForSync(database db.Database, dbType, sourceQuery string) (int, bool, error) { + subquery, ok := normalizeSourceQueryForPaging(sourceQuery) + if !ok { + return 0, false, nil + } + query := fmt.Sprintf("SELECT COUNT(*) AS __gonavi_count__ FROM (%s) AS __gonavi_source_query__", subquery) + rows, _, err := database.Query(query) + if err != nil { + return 0, true, err + } + if len(rows) == 0 { + return 0, false, nil + } + for _, value := range rows[0] { + count, ok := intFromSyncValue(value) + if ok { + return count, true, nil + } + } + return 0, false, nil +} + +func normalizeSourceQueryForPaging(query string) (string, bool) { + trimmed := strings.TrimSpace(query) + if trimmed == "" { + return "", false + } + trimmed = strings.TrimSuffix(trimmed, ";") + trimmed = strings.TrimSpace(trimmed) + lower := strings.ToLower(trimmed) + if !(strings.HasPrefix(lower, "select ") || strings.HasPrefix(lower, "with ")) { + return "", false + } + if strings.Contains(trimmed, ";") { + return "", false + } + return trimmed, true +} + +func supportsPagedSourceQuery(dbType string) bool { + return supportsDirectImportPagination(dbType) +} + +func supportsPagedSourceQueryPKLookup(dbType string) bool { + return supportsDirectImportPagination(dbType) +} diff --git a/internal/sync/source_query_sync.go b/internal/sync/source_query_sync.go index d299080..4b36cc1 100644 --- a/internal/sync/source_query_sync.go +++ b/internal/sync/source_query_sync.go @@ -68,7 +68,7 @@ func resolveSinglePKColumn(cols []connection.ColumnDefinition) (string, error) { return pkCols[0], nil } -func loadSourceQuerySyncContext(config SyncConfig, sourceDB db.Database, targetDB db.Database, needTargetRows bool, requirePK bool) (sourceQuerySyncContext, error) { +func loadSourceQuerySyncContext(config SyncConfig, sourceDB db.Database, targetDB db.Database, needSourceRows bool, needTargetRows bool, requirePK bool) (sourceQuerySyncContext, error) { tableName, err := validateSourceQuerySyncConfig(config) if err != nil { return sourceQuerySyncContext{}, err @@ -83,11 +83,6 @@ func loadSourceQuerySyncContext(config SyncConfig, sourceDB db.Database, targetD return sourceQuerySyncContext{}, fmt.Errorf("目标表 %s 不存在或未读取到字段定义", tableName) } - sourceRows, _, err := sourceDB.Query(strings.TrimSpace(config.SourceQuery)) - if err != nil { - return sourceQuerySyncContext{}, fmt.Errorf("执行源查询失败: %w", err) - } - ctx := sourceQuerySyncContext{ TableName: tableName, TargetSchema: targetSchema, @@ -95,10 +90,18 @@ func loadSourceQuerySyncContext(config SyncConfig, sourceDB db.Database, targetD TargetQueryTable: targetQueryTable, TargetType: targetType, TargetCols: targetCols, - SourceRows: sourceRows, + SourceRows: make([]map[string]interface{}, 0), TargetRows: make([]map[string]interface{}, 0), } + if needSourceRows { + sourceRows, _, err := sourceDB.Query(strings.TrimSpace(config.SourceQuery)) + if err != nil { + return sourceQuerySyncContext{}, fmt.Errorf("执行源查询失败: %w", err) + } + ctx.SourceRows = sourceRows + } + if requirePK { pkColumn, err := resolveSinglePKColumn(targetCols) if err != nil { @@ -226,7 +229,40 @@ func (s *SyncEngine) analyzeSourceQuery(config SyncConfig) SyncAnalyzeResult { Table: tableName, CanSync: false, } - ctx, err := loadSourceQuerySyncContext(config, sourceDB, targetDB, true, true) + ctx, err := loadSourceQuerySyncContext(config, sourceDB, targetDB, false, false, true) + if err != nil { + summary.Message = err.Error() + result.Tables = append(result.Tables, summary) + result.Message = "已完成 1 个目标表的差异分析" + s.progress(config.JobID, totalTables, totalTables, tableName, "差异分析完成") + return result + } + + sourceType := resolveMigrationDBType(config.SourceConfig) + handled, counts, scanErr := scanSourceQueryDiffInPages(sourceDB, targetDB, sourceType, ctx.TargetType, strings.TrimSpace(config.SourceQuery), ctx.TargetQueryTable, ctx.TargetCols, ctx.PKColumn, true, nil) + if handled { + if scanErr != nil { + summary.Message = scanErr.Error() + result.Tables = append(result.Tables, summary) + result.Message = "已完成 1 个目标表的差异分析" + s.progress(config.JobID, totalTables, totalTables, tableName, "差异分析完成") + return result + } + summary.CanSync = true + summary.PKColumn = ctx.PKColumn + summary.Inserts = counts.Inserts + summary.Updates = counts.Updates + summary.Deletes = counts.Deletes + summary.Same = counts.Same + summary.TargetTableExists = true + summary.Message = "SQL 结果集差异分析完成" + result.Tables = append(result.Tables, summary) + result.Message = "已完成 1 个目标表的差异分析" + s.progress(config.JobID, totalTables, totalTables, tableName, "差异分析完成") + return result + } + + ctx, err = loadSourceQuerySyncContext(config, sourceDB, targetDB, true, true, true) if err != nil { summary.Message = err.Error() result.Tables = append(result.Tables, summary) @@ -270,13 +306,83 @@ func (s *SyncEngine) previewSourceQuery(config SyncConfig, limit int) (TableDiff } defer targetDB.Close() - ctx, err := loadSourceQuerySyncContext(config, sourceDB, targetDB, true, true) + ctx, err := loadSourceQuerySyncContext(config, sourceDB, targetDB, false, false, true) + if err != nil { + return TableDiffPreview{}, err + } + + sourceType := resolveMigrationDBType(config.SourceConfig) + out := TableDiffPreview{ + Table: ctx.TableName, + PKColumn: ctx.PKColumn, + ColumnTypes: make(map[string]string, len(ctx.TargetCols)), + SchemaSummary: "SQL 结果集同步预览", + Inserts: make([]PreviewRow, 0, limit), + Updates: make([]PreviewUpdateRow, 0, limit), + Deletes: make([]PreviewRow, 0, limit), + } + for _, col := range ctx.TargetCols { + name := strings.ToLower(strings.TrimSpace(col.Name)) + typ := strings.TrimSpace(col.Type) + if name == "" || typ == "" { + continue + } + out.ColumnTypes[name] = typ + } + + handled, _, scanErr := scanSourceQueryDiffInPages(sourceDB, targetDB, sourceType, ctx.TargetType, strings.TrimSpace(config.SourceQuery), ctx.TargetQueryTable, ctx.TargetCols, ctx.PKColumn, true, func(page pagedDiffPage) error { + out.TotalInserts += len(page.Inserts) + out.TotalUpdates += len(page.Updates) + out.TotalDeletes += len(page.Deletes) + for _, row := range page.Inserts { + if len(out.Inserts) >= limit { + break + } + pk := strings.TrimSpace(fmt.Sprintf("%v", row[ctx.PKColumn])) + if pk != "" && pk != "" { + out.Inserts = append(out.Inserts, PreviewRow{PK: pk, Row: row}) + } + } + for _, update := range page.Updates { + if len(out.Updates) >= limit { + break + } + pk := strings.TrimSpace(fmt.Sprintf("%v", update.UpdateRow.Keys[ctx.PKColumn])) + if pk == "" || pk == "" { + continue + } + out.Updates = append(out.Updates, PreviewUpdateRow{ + PK: pk, + ChangedColumns: append([]string(nil), update.ChangedColumns...), + Source: update.Source, + Target: update.Target, + }) + } + for _, row := range page.Deletes { + if len(out.Deletes) >= limit { + break + } + pk := strings.TrimSpace(fmt.Sprintf("%v", row[ctx.PKColumn])) + if pk != "" && pk != "" { + out.Deletes = append(out.Deletes, PreviewRow{PK: pk, Row: row}) + } + } + return nil + }) + if handled { + if scanErr != nil { + return TableDiffPreview{}, scanErr + } + return out, nil + } + + ctx, err = loadSourceQuerySyncContext(config, sourceDB, targetDB, true, true, true) if err != nil { return TableDiffPreview{}, err } inserts, updates, deletes, _ := diffRowsByPK(ctx.PKColumn, ctx.SourceRows, ctx.TargetRows) - out := TableDiffPreview{ + out = TableDiffPreview{ Table: ctx.TableName, PKColumn: ctx.PKColumn, ColumnTypes: make(map[string]string, len(ctx.TargetCols)), @@ -389,7 +495,7 @@ func (s *SyncEngine) runSourceQuerySync(config SyncConfig) SyncResult { needTargetRows := tableMode == "insert_update" requirePK := tableMode == "insert_update" - ctx, err := loadSourceQuerySyncContext(config, sourceDB, targetDB, needTargetRows, requirePK) + ctx, err := loadSourceQuerySyncContext(config, sourceDB, targetDB, false, false, requirePK) if err != nil { return s.fail(config.JobID, totalTables, result, err.Error()) } @@ -397,6 +503,33 @@ func (s *SyncEngine) runSourceQuerySync(config SyncConfig) SyncResult { inserts := make([]map[string]interface{}, 0) updates := make([]connection.UpdateRow, 0) deletes := make([]map[string]interface{}, 0) + applyTableName := ctx.TargetTable + switch ctx.TargetType { + case "postgres", "kingbase", "highgo", "vastbase", "opengauss", "sqlserver": + applyTableName = ctx.TargetQueryTable + } + + if handled, counts, err := s.tryApplySourceQueryInPages(config, &result, tableName, sourceDB, targetDB, ctx, opts, tableMode, applyTableName); handled { + if err != nil { + return s.fail(config.JobID, totalTables, result, "分页同步 SQL 结果集失败: "+err.Error()) + } + result.TablesSynced++ + result.RowsInserted += counts.Inserts + result.RowsUpdated += counts.Updates + result.RowsDeleted += counts.Deletes + if counts.Inserts == 0 && counts.Updates == 0 && counts.Deletes == 0 { + s.appendLog(config.JobID, &result, "info", "SQL 结果集与目标表一致,无需应用变更") + } else { + s.appendLog(config.JobID, &result, "info", fmt.Sprintf("SQL 结果集分页同步完成:插入=%d 更新=%d 删除=%d", counts.Inserts, counts.Updates, counts.Deletes)) + } + s.progress(config.JobID, totalTables, totalTables, tableName, "同步完成") + return result + } + + ctx, err = loadSourceQuerySyncContext(config, sourceDB, targetDB, true, needTargetRows, requirePK) + if err != nil { + return s.fail(config.JobID, totalTables, result, err.Error()) + } if tableMode == "insert_update" { inserts, updates, deletes, _ = diffRowsByPK(ctx.PKColumn, ctx.SourceRows, ctx.TargetRows) inserts = filterRowsByPKSelection(ctx.PKColumn, inserts, opts.Insert, opts.SelectedInsertPKs) @@ -431,16 +564,11 @@ func (s *SyncEngine) runSourceQuerySync(config SyncConfig) SyncResult { return result } - applyTableName := ctx.TargetTable - switch ctx.TargetType { - case "postgres", "kingbase", "highgo", "vastbase", "opengauss", "sqlserver": - applyTableName = ctx.TargetQueryTable - } applier, ok := targetDB.(db.BatchApplier) if !ok { return s.fail(config.JobID, totalTables, result, "目标驱动不支持应用数据变更 (ApplyChanges)") } - if err := applier.ApplyChanges(applyTableName, changeSet); err != nil { + if err := s.applyChangesInBatches(config.JobID, &result, applyTableName, applier, changeSet); err != nil { return s.fail(config.JobID, totalTables, result, "应用 SQL 结果集变更失败: "+err.Error()) } diff --git a/internal/sync/source_query_sync_test.go b/internal/sync/source_query_sync_test.go index 59ad071..1ecbf4d 100644 --- a/internal/sync/source_query_sync_test.go +++ b/internal/sync/source_query_sync_test.go @@ -4,6 +4,7 @@ import ( "GoNavi-Wails/internal/connection" "GoNavi-Wails/internal/db" "reflect" + "strings" "testing" ) @@ -11,11 +12,15 @@ type fakeQuerySyncTargetDB struct { fakeMigrationDB appliedTable string appliedChanges connection.ChangeSet + appliedBatches []connection.ChangeSet } func (f *fakeQuerySyncTargetDB) ApplyChanges(tableName string, changes connection.ChangeSet) error { f.appliedTable = tableName - f.appliedChanges = changes + f.appliedChanges.Inserts = append(f.appliedChanges.Inserts, changes.Inserts...) + f.appliedChanges.Updates = append(f.appliedChanges.Updates, changes.Updates...) + f.appliedChanges.Deletes = append(f.appliedChanges.Deletes, changes.Deletes...) + f.appliedBatches = append(f.appliedBatches, changes) return nil } @@ -30,10 +35,13 @@ func TestAnalyze_SourceQueryUsesQueryResultAsSourceDataset(t *testing.T) { }, }, queryData: map[string][]map[string]interface{}{ - "SELECT id, name FROM active_users": { + "SELECT * FROM (SELECT id, name FROM active_users) AS __gonavi_source_query__ ORDER BY `id` ASC LIMIT 1000 OFFSET 0": { {"id": 1, "name": "Alice New"}, {"id": 2, "name": "Bob"}, }, + "SELECT `id` FROM (SELECT id, name FROM active_users) AS __gonavi_source_query__ WHERE `id` IN (1, 3)": { + {"id": 1}, + }, }, } targetDB := &fakeQuerySyncTargetDB{ @@ -45,8 +53,11 @@ func TestAnalyze_SourceQueryUsesQueryResultAsSourceDataset(t *testing.T) { }, }, queryData: map[string][]map[string]interface{}{ - "SELECT * FROM `app`.`users`": { + "SELECT `id`, `name` FROM `app`.`users` WHERE `id` IN (1, 2)": { {"id": 1, "name": "Alice Old"}, + }, + "SELECT `id` FROM `app`.`users` ORDER BY `id` ASC LIMIT 1000": { + {"id": 1}, {"id": 3, "name": "Carol"}, }, }, @@ -101,10 +112,13 @@ func TestRunSync_SourceQueryAppliesDiffAgainstTargetTable(t *testing.T) { }, }, queryData: map[string][]map[string]interface{}{ - "SELECT id, name FROM active_users": { + "SELECT * FROM (SELECT id, name FROM active_users) AS __gonavi_source_query__ ORDER BY `id` ASC LIMIT 1000 OFFSET 0": { {"id": 1, "name": "Alice New"}, {"id": 2, "name": "Bob"}, }, + "SELECT `id` FROM (SELECT id, name FROM active_users) AS __gonavi_source_query__ WHERE `id` IN (1, 3)": { + {"id": 1}, + }, }, } targetDB := &fakeQuerySyncTargetDB{ @@ -116,8 +130,11 @@ func TestRunSync_SourceQueryAppliesDiffAgainstTargetTable(t *testing.T) { }, }, queryData: map[string][]map[string]interface{}{ - "SELECT * FROM `app`.`users`": { + "SELECT `id`, `name` FROM `app`.`users` WHERE `id` IN (1, 2)": { {"id": 1, "name": "Alice Old"}, + }, + "SELECT `id` FROM `app`.`users` ORDER BY `id` ASC LIMIT 1000": { + {"id": 1}, {"id": 3, "name": "Carol"}, }, }, @@ -175,3 +192,300 @@ func TestRunSync_SourceQueryAppliesDiffAgainstTargetTable(t *testing.T) { t.Fatalf("unexpected deletes: got=%v want=%v", targetDB.appliedChanges.Deletes, wantDeletes) } } + +func TestRunSync_SourceQueryInsertUpdateUsesPagedQueries(t *testing.T) { + columns := []connection.ColumnDefinition{ + {Name: "id", Type: "bigint", Nullable: "NO", Key: "PRI"}, + {Name: "name", Type: "varchar(64)", Nullable: "YES"}, + } + sourceDB := &fakeMigrationDB{ + queryData: map[string][]map[string]interface{}{ + "SELECT * FROM (SELECT id, name FROM active_users) AS __gonavi_source_query__ ORDER BY `id` ASC LIMIT 1000 OFFSET 0": { + {"id": 1, "name": "Alice New"}, + {"id": 2, "name": "Bob"}, + }, + "SELECT `id` FROM (SELECT id, name FROM active_users) AS __gonavi_source_query__ WHERE `id` IN (1, 3)": { + {"id": 1}, + }, + }, + } + targetDB := &fakeQuerySyncTargetDB{ + fakeMigrationDB: fakeMigrationDB{ + columns: map[string][]connection.ColumnDefinition{ + "app.users": columns, + }, + queryData: map[string][]map[string]interface{}{ + "SELECT `id`, `name` FROM `app`.`users` WHERE `id` IN (1, 2)": { + {"id": 1, "name": "Alice Old"}, + }, + "SELECT `id` FROM `app`.`users` ORDER BY `id` ASC LIMIT 1000": { + {"id": 1}, + {"id": 3}, + }, + }, + }, + } + + oldFactory := newSyncDatabase + defer func() { newSyncDatabase = oldFactory }() + callCount := 0 + newSyncDatabase = func(dbType string) (db.Database, error) { + callCount++ + if callCount == 1 { + return sourceDB, nil + } + return targetDB, nil + } + + engine := NewSyncEngine(Reporter{}) + result := engine.RunSync(SyncConfig{ + SourceConfig: connection.ConnectionConfig{Type: "mysql", Database: "app"}, + TargetConfig: connection.ConnectionConfig{Type: "mysql", Database: "app"}, + Tables: []string{"users"}, + Mode: "insert_update", + SourceQuery: "SELECT id, name FROM active_users", + TableOptions: map[string]TableOptions{ + "users": {Insert: true, Update: true, Delete: true}, + }, + }) + + if !result.Success { + t.Fatalf("RunSync 返回失败: %+v", result) + } + if result.RowsInserted != 1 || result.RowsUpdated != 1 || result.RowsDeleted != 1 { + t.Fatalf("unexpected sync result: %+v", result) + } + for _, query := range sourceDB.queryLog { + if query == "SELECT id, name FROM active_users" { + t.Fatalf("SQL 结果集分页同步不应全量执行原始查询,实际查询=%s", query) + } + } +} + +func TestRunSync_BatchesLargeTableChanges(t *testing.T) { + sourceRows := make([]map[string]interface{}, 2501) + for i := range sourceRows { + sourceRows[i] = map[string]interface{}{ + "id": i + 1, + "name": "event", + } + } + + columns := []connection.ColumnDefinition{ + {Name: "id", Type: "bigint", Nullable: "NO", Key: "PRI"}, + {Name: "name", Type: "varchar(64)", Nullable: "YES"}, + } + sourceDB := &fakeMigrationDB{ + columns: map[string][]connection.ColumnDefinition{ + "app.events": columns, + }, + queryData: map[string][]map[string]interface{}{ + "SELECT `id`, `name` FROM `app`.`events` ORDER BY `id` ASC LIMIT 1000 OFFSET 0": sourceRows[:1000], + "SELECT `id`, `name` FROM `app`.`events` ORDER BY `id` ASC LIMIT 1000 OFFSET 1000": sourceRows[1000:2000], + "SELECT `id`, `name` FROM `app`.`events` ORDER BY `id` ASC LIMIT 1000 OFFSET 2000": sourceRows[2000:], + }, + } + targetDB := &fakeQuerySyncTargetDB{ + fakeMigrationDB: fakeMigrationDB{ + columns: map[string][]connection.ColumnDefinition{ + "app.events": columns, + }, + }, + } + + oldFactory := newSyncDatabase + defer func() { newSyncDatabase = oldFactory }() + callCount := 0 + newSyncDatabase = func(dbType string) (db.Database, error) { + callCount++ + if callCount == 1 { + return sourceDB, nil + } + return targetDB, nil + } + + engine := NewSyncEngine(Reporter{}) + result := engine.RunSync(SyncConfig{ + SourceConfig: connection.ConnectionConfig{Type: "mysql", Database: "app"}, + TargetConfig: connection.ConnectionConfig{Type: "mysql", Database: "app"}, + Tables: []string{"events"}, + Mode: "insert_only", + }) + + if !result.Success { + t.Fatalf("RunSync 返回失败: %+v", result) + } + if result.RowsInserted != len(sourceRows) { + t.Fatalf("RowsInserted=%d, want %d", result.RowsInserted, len(sourceRows)) + } + for _, query := range sourceDB.queryLog { + if strings.HasPrefix(query, "SELECT * FROM") { + t.Fatalf("期望分页流式导入不再全量读取源表,实际查询=%s", query) + } + } + if len(targetDB.appliedBatches) != 3 { + t.Fatalf("期望大表拆成 3 批提交,实际 %d 批", len(targetDB.appliedBatches)) + } + wantBatchSizes := []int{1000, 1000, 501} + for idx, want := range wantBatchSizes { + if got := len(targetDB.appliedBatches[idx].Inserts); got != want { + t.Fatalf("batch %d inserts=%d, want %d", idx+1, got, want) + } + } +} + +func TestRunSync_DirectImportPagingKeepsSelectedPKFilter(t *testing.T) { + sourceRows := []map[string]interface{}{ + {"id": 1, "name": "event-1"}, + {"id": 2, "name": "event-2"}, + {"id": 3, "name": "event-3"}, + } + columns := []connection.ColumnDefinition{ + {Name: "id", Type: "bigint", Nullable: "NO", Key: "PRI"}, + {Name: "name", Type: "varchar(64)", Nullable: "YES"}, + } + sourceDB := &fakeMigrationDB{ + columns: map[string][]connection.ColumnDefinition{ + "app.events": columns, + }, + queryData: map[string][]map[string]interface{}{ + "SELECT `id`, `name` FROM `app`.`events` ORDER BY `id` ASC LIMIT 1000 OFFSET 0": sourceRows, + }, + } + targetDB := &fakeQuerySyncTargetDB{ + fakeMigrationDB: fakeMigrationDB{ + columns: map[string][]connection.ColumnDefinition{ + "app.events": columns, + }, + }, + } + + oldFactory := newSyncDatabase + defer func() { newSyncDatabase = oldFactory }() + callCount := 0 + newSyncDatabase = func(dbType string) (db.Database, error) { + callCount++ + if callCount == 1 { + return sourceDB, nil + } + return targetDB, nil + } + + engine := NewSyncEngine(Reporter{}) + result := engine.RunSync(SyncConfig{ + SourceConfig: connection.ConnectionConfig{Type: "mysql", Database: "app"}, + TargetConfig: connection.ConnectionConfig{Type: "mysql", Database: "app"}, + Tables: []string{"events"}, + Mode: "insert_only", + TableOptions: map[string]TableOptions{ + "events": { + Insert: true, + SelectedInsertPKs: []string{"2"}, + }, + }, + }) + + if !result.Success { + t.Fatalf("RunSync 返回失败: %+v", result) + } + if result.RowsInserted != 1 { + t.Fatalf("RowsInserted=%d, want 1", result.RowsInserted) + } + if len(targetDB.appliedBatches) != 1 || len(targetDB.appliedBatches[0].Inserts) != 1 { + t.Fatalf("expected one selected insert batch, got %+v", targetDB.appliedBatches) + } + if got := targetDB.appliedBatches[0].Inserts[0]["id"]; got != 2 { + t.Fatalf("selected insert id=%v, want 2", got) + } +} + +func TestRunSync_InsertUpdateDiffUsesPagedPKLookups(t *testing.T) { + sourceRows := []map[string]interface{}{ + {"id": 1, "name": "one-new"}, + {"id": 2, "name": "two"}, + {"id": 3, "name": "three"}, + } + columns := []connection.ColumnDefinition{ + {Name: "id", Type: "bigint", Nullable: "NO", Key: "PRI"}, + {Name: "name", Type: "varchar(64)", Nullable: "YES"}, + } + sourceDB := &fakeMigrationDB{ + columns: map[string][]connection.ColumnDefinition{ + "app.events": columns, + }, + queryData: map[string][]map[string]interface{}{ + "SELECT `id`, `name` FROM `app`.`events` ORDER BY `id` ASC LIMIT 1000 OFFSET 0": sourceRows, + "SELECT `id` FROM `app`.`events` WHERE `id` IN (1, 4)": { + {"id": 1}, + }, + }, + } + targetDB := &fakeQuerySyncTargetDB{ + fakeMigrationDB: fakeMigrationDB{ + columns: map[string][]connection.ColumnDefinition{ + "app.events": columns, + }, + queryData: map[string][]map[string]interface{}{ + "SELECT `id`, `name` FROM `app`.`events` WHERE `id` IN (1, 2, 3)": { + {"id": 1, "name": "one-old"}, + {"id": 2, "name": "two"}, + }, + "SELECT `id` FROM `app`.`events` ORDER BY `id` ASC LIMIT 1000": { + {"id": 1}, + {"id": 4}, + }, + }, + }, + } + + oldFactory := newSyncDatabase + defer func() { newSyncDatabase = oldFactory }() + callCount := 0 + newSyncDatabase = func(dbType string) (db.Database, error) { + callCount++ + if callCount == 1 { + return sourceDB, nil + } + return targetDB, nil + } + + engine := NewSyncEngine(Reporter{}) + result := engine.RunSync(SyncConfig{ + SourceConfig: connection.ConnectionConfig{Type: "mysql", Database: "app"}, + TargetConfig: connection.ConnectionConfig{Type: "mysql", Database: "app"}, + Tables: []string{"events"}, + Mode: "insert_update", + TableOptions: map[string]TableOptions{ + "events": {Insert: true, Update: true, Delete: true}, + }, + }) + + if !result.Success { + t.Fatalf("RunSync 返回失败: %+v", result) + } + if result.RowsInserted != 1 || result.RowsUpdated != 1 || result.RowsDeleted != 1 { + t.Fatalf("unexpected sync result: %+v", result) + } + if len(targetDB.appliedBatches) != 2 { + t.Fatalf("expected source diff batch and delete batch, got %d", len(targetDB.appliedBatches)) + } + firstBatch := targetDB.appliedBatches[0] + if !reflect.DeepEqual(firstBatch.Inserts, []map[string]interface{}{{"id": 3, "name": "three"}}) { + t.Fatalf("unexpected inserts: %+v", firstBatch.Inserts) + } + wantUpdates := []connection.UpdateRow{{ + Keys: map[string]interface{}{"id": 1}, + Values: map[string]interface{}{"name": "one-new"}, + }} + if !reflect.DeepEqual(firstBatch.Updates, wantUpdates) { + t.Fatalf("unexpected updates: %+v", firstBatch.Updates) + } + if !reflect.DeepEqual(targetDB.appliedBatches[1].Deletes, []map[string]interface{}{{"id": 4}}) { + t.Fatalf("unexpected deletes: %+v", targetDB.appliedBatches[1].Deletes) + } + for _, query := range append(sourceDB.queryLog, targetDB.queryLog...) { + if strings.HasPrefix(query, "SELECT * FROM") { + t.Fatalf("分页差异同步不应全量读取表,实际查询=%s", query) + } + } +} diff --git a/internal/sync/sql_helpers_test.go b/internal/sync/sql_helpers_test.go index 0f63fb0..bd7c0aa 100644 --- a/internal/sync/sql_helpers_test.go +++ b/internal/sync/sql_helpers_test.go @@ -1,6 +1,9 @@ package sync -import "testing" +import ( + "GoNavi-Wails/internal/connection" + "testing" +) func TestQuoteQualifiedIdentByType_KingbaseLeavesLowercaseQualifiedTableUnquoted(t *testing.T) { t.Parallel() @@ -58,3 +61,93 @@ func TestNormalizeMigrationDBType_KingbaseAliases(t *testing.T) { } } } + +func TestBuildPagedSourceTableQuery_MySQLUsesStablePKPagination(t *testing.T) { + t.Parallel() + + query := buildPagedSourceTableQuery("mysql", "app.events", []connection.ColumnDefinition{ + {Name: "id"}, + {Name: "name"}, + }, "id", 1000, 2000) + + want := "SELECT `id`, `name` FROM `app`.`events` ORDER BY `id` ASC LIMIT 1000 OFFSET 2000" + if query != want { + t.Fatalf("unexpected paged query:\n got: %s\nwant: %s", query, want) + } +} + +func TestBuildPagedSourceTableQuery_SQLServerUsesOuterAliasColumns(t *testing.T) { + t.Parallel() + + query := buildPagedSourceTableQuery("sqlserver", "dbo.events", []connection.ColumnDefinition{ + {Name: "id"}, + {Name: "name"}, + }, "id", 1000, 2000) + + want := "SELECT [__gonavi_page_result__].[id], [__gonavi_page_result__].[name] FROM (SELECT [__gonavi_page__].*, ROW_NUMBER() OVER (ORDER BY [id] ASC) AS [__gonavi_rn__] FROM (SELECT [id], [name] FROM [dbo].[events]) AS [__gonavi_page__]) AS [__gonavi_page_result__] WHERE [__gonavi_rn__] > 2000 AND [__gonavi_rn__] <= 3000 ORDER BY [__gonavi_rn__]" + if query != want { + t.Fatalf("unexpected paged query:\n got: %s\nwant: %s", query, want) + } +} + +func TestIsSamePhysicalSyncTableDetectsFullOverwriteSelfTarget(t *testing.T) { + t.Parallel() + + cfg := SyncConfig{ + SourceConfig: connection.ConnectionConfig{Type: "mysql", Host: "127.0.0.1", Port: 3306, Database: "app"}, + TargetConfig: connection.ConnectionConfig{Type: "mysql", Host: "127.0.0.1", Port: 3306, Database: "app"}, + } + plan := SchemaMigrationPlan{SourceQueryTable: "app.events", TargetQueryTable: "app.events"} + if !isSamePhysicalSyncTable(cfg, plan, "mysql", "mysql") { + t.Fatal("expected identical connection/table to be detected") + } + + cfg.TargetConfig.Database = "archive" + if isSamePhysicalSyncTable(cfg, plan, "mysql", "mysql") { + t.Fatal("different database should not be treated as same physical table") + } +} + +func TestBuildPKInSelectQueryEscapesStringLiterals(t *testing.T) { + t.Parallel() + + query := buildPKInSelectQuery("mysql", "app.users", []connection.ColumnDefinition{ + {Name: "id"}, + {Name: "name"}, + }, "id", []interface{}{"a'1", "b2"}) + + want := "SELECT `id`, `name` FROM `app`.`users` WHERE `id` IN ('a''1', 'b2')" + if query != want { + t.Fatalf("unexpected PK IN query:\n got: %s\nwant: %s", query, want) + } +} + +func TestBuildKeysetPagedTableQueryUsesLastPK(t *testing.T) { + t.Parallel() + + query := buildKeysetPagedTableQuery("mysql", "app.users", []connection.ColumnDefinition{{Name: "id"}}, "id", 100, true, 50) + + want := "SELECT `id` FROM `app`.`users` WHERE `id` > 100 ORDER BY `id` ASC LIMIT 50" + if query != want { + t.Fatalf("unexpected keyset query:\n got: %s\nwant: %s", query, want) + } +} + +func TestBuildSourceQueryPageSQLWrapsSelect(t *testing.T) { + t.Parallel() + + query := buildSourceQueryPageSQL("mysql", "SELECT id, name FROM active_users;", "id", 1000, 2000) + + want := "SELECT * FROM (SELECT id, name FROM active_users) AS __gonavi_source_query__ ORDER BY `id` ASC LIMIT 1000 OFFSET 2000" + if query != want { + t.Fatalf("unexpected source query page SQL:\n got: %s\nwant: %s", query, want) + } +} + +func TestNormalizeSourceQueryForPagingRejectsMultiStatement(t *testing.T) { + t.Parallel() + + if _, ok := normalizeSourceQueryForPaging("SELECT * FROM users; DELETE FROM users"); ok { + t.Fatal("expected multi-statement source query to be rejected for pagination") + } +} diff --git a/internal/sync/sync_engine.go b/internal/sync/sync_engine.go index 7bb4ec2..ff8795f 100644 --- a/internal/sync/sync_engine.go +++ b/internal/sync/sync_engine.go @@ -5,11 +5,14 @@ import ( "GoNavi-Wails/internal/db" "GoNavi-Wails/internal/logger" "fmt" + "math" "sort" "strings" "time" ) +const defaultSyncApplyBatchSize = 1000 + // SyncConfig defines the parameters for a synchronization task type SyncConfig struct { SourceConfig connection.ConnectionConfig `json:"sourceConfig"` @@ -251,6 +254,54 @@ func (s *SyncEngine) RunSync(config SyncConfig) SyncResult { return } + if handled, inserted, err := s.tryApplyDirectImportInPages(config, &result, i, totalTables, tableName, sourceDB, targetDB, plan, cols, targetCols, opts, sourceType, targetType, applyTableName); handled { + if err != nil { + logger.Error(err, "分页流式导入失败:表=%s", tableName) + s.appendLog(config.JobID, &result, "error", fmt.Sprintf(" -> 分页流式导入失败: %v", err)) + return + } + result.RowsInserted += inserted + if inserted > 0 { + s.appendLog(config.JobID, &result, "info", fmt.Sprintf(" -> 分页流式导入完成:插入=%d 行", inserted)) + } else { + s.appendLog(config.JobID, &result, "info", " -> 源表无可导入数据") + } + if len(plan.PostDataSQL) > 0 { + s.progress(config.JobID, i, totalTables, tableName, "创建索引") + if err := executeSQLStatements(targetDB.Exec, plan.PostDataSQL); err != nil { + s.appendLog(config.JobID, &result, "error", fmt.Sprintf("创建索引失败:表=%s 错误=%v", tableName, err)) + return + } + } + result.TablesSynced++ + return + } + + if handled, counts, err := s.tryApplyDiffInPages(config, &result, i, totalTables, tableName, sourceDB, targetDB, plan, cols, targetCols, opts, sourceType, targetType, applyTableName, pkCol); handled { + if err != nil { + logger.Error(err, "分页差异同步失败:表=%s", tableName) + s.appendLog(config.JobID, &result, "error", fmt.Sprintf(" -> 分页差异同步失败: %v", err)) + return + } + result.RowsInserted += counts.Inserts + result.RowsUpdated += counts.Updates + result.RowsDeleted += counts.Deletes + if counts.Inserts > 0 || counts.Updates > 0 || counts.Deletes > 0 { + s.appendLog(config.JobID, &result, "info", fmt.Sprintf(" -> 分页差异同步完成:插入=%d 更新=%d 删除=%d", counts.Inserts, counts.Updates, counts.Deletes)) + } else { + s.appendLog(config.JobID, &result, "info", " -> 数据一致,无需变更.") + } + if len(plan.PostDataSQL) > 0 { + s.progress(config.JobID, i, totalTables, tableName, "创建索引") + if err := executeSQLStatements(targetDB.Exec, plan.PostDataSQL); err != nil { + s.appendLog(config.JobID, &result, "error", fmt.Sprintf("创建索引失败:表=%s 错误=%v", tableName, err)) + return + } + } + result.TablesSynced++ + return + } + s.progress(config.JobID, i, totalTables, tableName, "读取源表数据") sourceRows, _, err := sourceDB.Query(fmt.Sprintf("SELECT * FROM %s", quoteQualifiedIdentByType(sourceType, sourceQueryTable))) if err != nil { @@ -401,7 +452,7 @@ func (s *SyncEngine) RunSync(config SyncConfig) SyncResult { if len(changeSet.Inserts) > 0 || len(changeSet.Updates) > 0 || len(changeSet.Deletes) > 0 { s.appendLog(config.JobID, &result, "info", fmt.Sprintf(" -> 需插入: %d 行, 需更新: %d 行, 需删除: %d 行", len(changeSet.Inserts), len(changeSet.Updates), len(changeSet.Deletes))) if applier, ok := targetDB.(db.BatchApplier); ok { - if err := applier.ApplyChanges(applyTableName, changeSet); err != nil { + if err := s.applyChangesInBatches(config.JobID, &result, applyTableName, applier, changeSet); err != nil { s.appendLog(config.JobID, &result, "error", fmt.Sprintf(" -> 应用变更失败: %v", err)) return } @@ -497,6 +548,75 @@ func (s *SyncEngine) fail(jobID string, totalTables int, res SyncResult, msg str return res } +func (s *SyncEngine) applyChangesInBatches(jobID string, res *SyncResult, tableName string, applier db.BatchApplier, changes connection.ChangeSet) error { + batches := splitChangeSetBatches(changes, defaultSyncApplyBatchSize) + if len(batches) == 0 { + return nil + } + if len(batches) > 1 { + s.appendLog(jobID, res, "info", fmt.Sprintf(" -> 大批量变更将拆分为 %d 批提交(每批最多 %d 行)", len(batches), defaultSyncApplyBatchSize)) + } + for idx, batch := range batches { + if len(batches) > 1 { + s.appendLog(jobID, res, "info", fmt.Sprintf(" -> 提交批次 %d/%d:插入=%d 更新=%d 删除=%d", + idx+1, len(batches), len(batch.Inserts), len(batch.Updates), len(batch.Deletes))) + } + if err := applier.ApplyChanges(tableName, batch); err != nil { + if len(batches) > 1 { + return fmt.Errorf("批次 %d/%d 失败: %w", idx+1, len(batches), err) + } + return err + } + } + return nil +} + +func splitChangeSetBatches(changes connection.ChangeSet, batchSize int) []connection.ChangeSet { + if batchSize <= 0 { + batchSize = defaultSyncApplyBatchSize + } + total := len(changes.Deletes) + len(changes.Updates) + len(changes.Inserts) + if total == 0 { + return nil + } + + batches := make([]connection.ChangeSet, 0, int(math.Ceil(float64(total)/float64(batchSize)))) + current := connection.ChangeSet{LocatorStrategy: changes.LocatorStrategy} + currentSize := 0 + flush := func() { + if currentSize == 0 { + return + } + batches = append(batches, current) + current = connection.ChangeSet{LocatorStrategy: changes.LocatorStrategy} + currentSize = 0 + } + + for _, row := range changes.Deletes { + if currentSize >= batchSize { + flush() + } + current.Deletes = append(current.Deletes, row) + currentSize++ + } + for _, row := range changes.Updates { + if currentSize >= batchSize { + flush() + } + current.Updates = append(current.Updates, row) + currentSize++ + } + for _, row := range changes.Inserts { + if currentSize >= batchSize { + flush() + } + current.Inserts = append(current.Inserts, row) + currentSize++ + } + flush() + return batches +} + func (s *SyncEngine) execDDLStatements(jobID string, res *SyncResult, database db.Database, tableName string, stage string, statements []string) error { for _, statement := range statements { sqlText := strings.TrimSpace(statement)