From 938bc539666e4785188daf33192f0ec329aec863 Mon Sep 17 00:00:00 2001 From: Syngnat Date: Tue, 16 Jun 2026 09:25:16 +0800 Subject: [PATCH] =?UTF-8?q?=F0=9F=90=9B=20fix(mysql):=20=E4=BF=AE=E5=A4=8D?= =?UTF-8?q?=20DATE=20=E5=AD=97=E6=AE=B5=E6=98=BE=E7=A4=BA=E4=B8=BA=20datet?= =?UTF-8?q?ime?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - 查询扫描链路透传数据库方言,区分 MySQL 与 Oracle DATE 语义 - MySQL/MariaDB/自定义 mysql 驱动的 DATE/NEWDATE 只展示 YYYY-MM-DD - 保留 DATETIME/TIMESTAMP 和 Oracle DATE 的时间信息 - 补充值规整与扫描链路回归测试 Close #565 --- internal/db/custom_impl.go | 11 +++- internal/db/mariadb_impl.go | 8 +-- internal/db/mysql_impl.go | 8 +-- internal/db/query_value.go | 26 +++++++-- internal/db/query_value_test.go | 28 ++++++++++ internal/db/scan_rows.go | 12 ++++- internal/db/scan_rows_test.go | 95 +++++++++++++++++++++++++++++++-- 7 files changed, 170 insertions(+), 18 deletions(-) diff --git a/internal/db/custom_impl.go b/internal/db/custom_impl.go index 8a06122..0ad1ba8 100644 --- a/internal/db/custom_impl.go +++ b/internal/db/custom_impl.go @@ -95,7 +95,7 @@ func (c *CustomDB) QueryContext(ctx context.Context, query string) ([]map[string } defer rows.Close() - return scanRows(rows) + return scanRowsForDialect(rows, c.scanDialect()) } func (c *CustomDB) Query(query string) ([]map[string]interface{}, []string, error) { @@ -108,7 +108,14 @@ func (c *CustomDB) Query(query string) ([]map[string]interface{}, []string, erro return nil, nil, err } defer rows.Close() - return scanRows(rows) + return scanRowsForDialect(rows, c.scanDialect()) +} + +func (c *CustomDB) scanDialect() string { + if strings.EqualFold(strings.TrimSpace(c.driver), "mysql") { + return "mysql" + } + return "" } func (c *CustomDB) ExecContext(ctx context.Context, query string) (int64, error) { diff --git a/internal/db/mariadb_impl.go b/internal/db/mariadb_impl.go index 06f4465..508fddb 100644 --- a/internal/db/mariadb_impl.go +++ b/internal/db/mariadb_impl.go @@ -87,7 +87,7 @@ func (m *MariaDB) QueryMulti(query string) ([]connection.ResultSetData, error) { return nil, err } defer rows.Close() - return scanMultiRows(rows) + return scanMultiRowsForDialect(rows, "mariadb") } func (m *MariaDB) QueryMultiContext(ctx context.Context, query string) ([]connection.ResultSetData, error) { @@ -99,7 +99,7 @@ func (m *MariaDB) QueryMultiContext(ctx context.Context, query string) ([]connec return nil, err } defer rows.Close() - return scanMultiRows(rows) + return scanMultiRowsForDialect(rows, "mariadb") } func (m *MariaDB) QueryContext(ctx context.Context, query string) ([]map[string]interface{}, []string, error) { @@ -113,7 +113,7 @@ func (m *MariaDB) QueryContext(ctx context.Context, query string) ([]map[string] } defer rows.Close() - return scanRows(rows) + return scanRowsForDialect(rows, "mariadb") } func (m *MariaDB) Query(query string) ([]map[string]interface{}, []string, error) { @@ -126,7 +126,7 @@ func (m *MariaDB) Query(query string) ([]map[string]interface{}, []string, error return nil, nil, err } defer rows.Close() - return scanRows(rows) + return scanRowsForDialect(rows, "mariadb") } func (m *MariaDB) ExecBatchContext(ctx context.Context, query string) (int64, error) { diff --git a/internal/db/mysql_impl.go b/internal/db/mysql_impl.go index ce3124e..2b4adf6 100644 --- a/internal/db/mysql_impl.go +++ b/internal/db/mysql_impl.go @@ -907,7 +907,7 @@ func (m *MySQLDB) QueryMulti(query string) ([]connection.ResultSetData, error) { return nil, err } defer rows.Close() - return scanMultiRows(rows) + return scanMultiRowsForDialect(rows, "mysql") } func (m *MySQLDB) QueryMultiContext(ctx context.Context, query string) ([]connection.ResultSetData, error) { @@ -919,7 +919,7 @@ func (m *MySQLDB) QueryMultiContext(ctx context.Context, query string) ([]connec return nil, err } defer rows.Close() - return scanMultiRows(rows) + return scanMultiRowsForDialect(rows, "mysql") } func (m *MySQLDB) QueryContext(ctx context.Context, query string) ([]map[string]interface{}, []string, error) { @@ -933,7 +933,7 @@ func (m *MySQLDB) QueryContext(ctx context.Context, query string) ([]map[string] } defer rows.Close() - return scanRows(rows) + return scanRowsForDialect(rows, "mysql") } func (m *MySQLDB) Query(query string) ([]map[string]interface{}, []string, error) { @@ -946,7 +946,7 @@ func (m *MySQLDB) Query(query string) ([]map[string]interface{}, []string, error return nil, nil, err } defer rows.Close() - return scanRows(rows) + return scanRowsForDialect(rows, "mysql") } func (m *MySQLDB) ExecBatchContext(ctx context.Context, query string) (int64, error) { diff --git a/internal/db/query_value.go b/internal/db/query_value.go index 75554e1..def345e 100644 --- a/internal/db/query_value.go +++ b/internal/db/query_value.go @@ -33,8 +33,12 @@ func normalizeQueryValue(v interface{}) interface{} { } func normalizeQueryValueWithDBType(v interface{}, databaseTypeName string) interface{} { + return normalizeQueryValueWithDBTypeAndDialect(v, databaseTypeName, "") +} + +func normalizeQueryValueWithDBTypeAndDialect(v interface{}, databaseTypeName, dialect string) interface{} { if tm, ok := v.(time.Time); ok { - return normalizeTemporalValueForDisplay(tm, databaseTypeName) + return normalizeTemporalValueForDisplay(tm, databaseTypeName, dialect) } if b, ok := v.([]byte); ok { return bytesToDisplayValue(b, databaseTypeName) @@ -42,15 +46,31 @@ func normalizeQueryValueWithDBType(v interface{}, databaseTypeName string) inter return normalizeCompositeQueryValue(v) } -func normalizeTemporalValueForDisplay(value time.Time, databaseTypeName string) interface{} { +func normalizeTemporalValueForDisplay(value time.Time, databaseTypeName, dialect string) interface{} { if value.IsZero() { if zeroValue, ok := zeroTemporalDisplayValue(databaseTypeName); ok { return zeroValue } } + if shouldDisplayTemporalValueAsDateOnly(databaseTypeName, dialect) { + return value.Format("2006-01-02") + } return value.Format(time.RFC3339Nano) } +func shouldDisplayTemporalValueAsDateOnly(databaseTypeName, dialect string) bool { + typeName := strings.ToUpper(strings.TrimSpace(databaseTypeName)) + if typeName != "DATE" && typeName != "NEWDATE" { + return false + } + switch strings.ToLower(strings.TrimSpace(dialect)) { + case "mysql", "mariadb", "goldendb", "greatdb", "gdb", "diros", "doris", "starrocks", "sphinx": + return true + default: + return false + } +} + func zeroTemporalDisplayValue(databaseTypeName string) (string, bool) { typeName := strings.ToUpper(strings.TrimSpace(databaseTypeName)) if typeName == "" { @@ -125,7 +145,7 @@ func normalizeCompositeQueryValue(v interface{}) interface{} { // 部分驱动(如 Kingbase)会返回复杂结构体值,直接透传会导致前端渲染和比较开销激增。 // 统一降级为可读字符串,避免对象深层序列化触发 UI 卡顿。 if tm, ok := v.(time.Time); ok { - return normalizeTemporalValueForDisplay(tm, "") + return normalizeTemporalValueForDisplay(tm, "", "") } if stringer, ok := v.(fmt.Stringer); ok { return stringer.String() diff --git a/internal/db/query_value_test.go b/internal/db/query_value_test.go index d02dda3..7cce552 100644 --- a/internal/db/query_value_test.go +++ b/internal/db/query_value_test.go @@ -220,6 +220,34 @@ func TestNormalizeQueryValueWithDBType_TimeStructToRFC3339(t *testing.T) { } } +func TestNormalizeQueryValueWithDBTypeAndDialect_MySQLDateOnly(t *testing.T) { + input := time.Date(2025, 10, 1, 0, 0, 0, 0, time.Local) + + got := normalizeQueryValueWithDBTypeAndDialect(input, "DATE", "mysql") + if got != "2025-10-01" { + t.Fatalf("MySQL DATE 应只展示日期,实际=%v(%T)", got, got) + } + + got = normalizeQueryValueWithDBTypeAndDialect(input, "NEWDATE", "mysql") + if got != "2025-10-01" { + t.Fatalf("MySQL NEWDATE 应只展示日期,实际=%v(%T)", got, got) + } +} + +func TestNormalizeQueryValueWithDBTypeAndDialect_DatetimeKeepsTime(t *testing.T) { + input := time.Date(2025, 10, 1, 13, 14, 15, 0, time.UTC) + + got := normalizeQueryValueWithDBTypeAndDialect(input, "DATETIME", "mysql") + if got != "2025-10-01T13:14:15Z" { + t.Fatalf("MySQL DATETIME 应保留时间,实际=%v(%T)", got, got) + } + + got = normalizeQueryValueWithDBTypeAndDialect(input, "DATE", "oracle") + if got != "2025-10-01T13:14:15Z" { + t.Fatalf("Oracle DATE 应保留时间语义,实际=%v(%T)", got, got) + } +} + func TestNormalizeQueryValueWithDBType_ZeroTemporalValues(t *testing.T) { zero := time.Time{} cases := []struct { diff --git a/internal/db/scan_rows.go b/internal/db/scan_rows.go index e1959ef..cac7064 100644 --- a/internal/db/scan_rows.go +++ b/internal/db/scan_rows.go @@ -8,6 +8,10 @@ import ( ) func scanRows(rows *sql.Rows) ([]map[string]interface{}, []string, error) { + return scanRowsForDialect(rows, "") +} + +func scanRowsForDialect(rows *sql.Rows, dialect string) ([]map[string]interface{}, []string, error) { columns, err := rows.Columns() if err != nil { return nil, nil, err @@ -38,7 +42,7 @@ func scanRows(rows *sql.Rows) ([]map[string]interface{}, []string, error) { if colTypes != nil && i < len(colTypes) && colTypes[i] != nil { dbTypeName = colTypes[i].DatabaseTypeName() } - entry[col] = normalizeQueryValueWithDBType(values[i], dbTypeName) + entry[col] = normalizeQueryValueWithDBTypeAndDialect(values[i], dbTypeName, dialect) } resultData = append(resultData, entry) } @@ -92,9 +96,13 @@ func ensureUniqueQueryColumnNames(columns []string) []string { // scanMultiRows 遍历 sql.Rows 中的所有结果集,将每个结果集作为 ResultSetData 返回。 // 利用 rows.NextResultSet() 支持一次 query 返回多个结果集的场景。 func scanMultiRows(rows *sql.Rows) ([]connection.ResultSetData, error) { + return scanMultiRowsForDialect(rows, "") +} + +func scanMultiRowsForDialect(rows *sql.Rows, dialect string) ([]connection.ResultSetData, error) { var results []connection.ResultSetData for { - data, cols, err := scanRows(rows) + data, cols, err := scanRowsForDialect(rows, dialect) if err != nil { return results, err } diff --git a/internal/db/scan_rows_test.go b/internal/db/scan_rows_test.go index 91c5a62..cd683ac 100644 --- a/internal/db/scan_rows_test.go +++ b/internal/db/scan_rows_test.go @@ -8,6 +8,7 @@ import ( "reflect" "sync" "testing" + "time" ) const scanRowsDuplicateDriverName = "gonavi-scan-rows-duplicate" @@ -27,6 +28,18 @@ func (scanRowsDuplicateConn) Close() error { return func (scanRowsDuplicateConn) Begin() (driver.Tx, error) { return nil, driver.ErrSkip } func (scanRowsDuplicateConn) QueryContext(_ context.Context, query string, args []driver.NamedValue) (driver.Rows, error) { + if query == "SELECT date_columns" { + return &scanRowsDuplicateRows{ + columns: []string{"ship_date", "created_at"}, + columnTypes: []string{"DATE", "DATETIME"}, + rows: [][]driver.Value{ + { + time.Date(2025, 10, 1, 0, 0, 0, 0, time.UTC), + time.Date(2025, 10, 1, 13, 14, 15, 0, time.UTC), + }, + }, + }, nil + } return &scanRowsDuplicateRows{ columns: []string{"id", "id", "name"}, rows: [][]driver.Value{ @@ -38,13 +51,20 @@ func (scanRowsDuplicateConn) QueryContext(_ context.Context, query string, args var _ driver.QueryerContext = (*scanRowsDuplicateConn)(nil) type scanRowsDuplicateRows struct { - columns []string - rows [][]driver.Value - index int + columns []string + columnTypes []string + rows [][]driver.Value + index int } func (r *scanRowsDuplicateRows) Columns() []string { return append([]string(nil), r.columns...) } func (r *scanRowsDuplicateRows) Close() error { return nil } +func (r *scanRowsDuplicateRows) ColumnTypeDatabaseTypeName(index int) string { + if index < 0 || index >= len(r.columnTypes) { + return "" + } + return r.columnTypes[index] +} func (r *scanRowsDuplicateRows) Next(dest []driver.Value) error { if r.index >= len(r.rows) { @@ -95,3 +115,72 @@ func TestScanRowsRenamesDuplicateColumns(t *testing.T) { t.Fatalf("unexpected row data: %#v", data[0]) } } + +func TestScanRowsForMySQLDialectFormatsDateOnly(t *testing.T) { + t.Parallel() + + registerScanRowsDuplicateDriverOnce.Do(func() { + sql.Register(scanRowsDuplicateDriverName, scanRowsDuplicateDriver{}) + }) + + dbConn, err := sql.Open(scanRowsDuplicateDriverName, "") + if err != nil { + t.Fatalf("open date scan rows db failed: %v", err) + } + defer dbConn.Close() + + rows, err := dbConn.QueryContext(context.Background(), "SELECT date_columns") + if err != nil { + t.Fatalf("query date scan rows db failed: %v", err) + } + defer rows.Close() + + data, columns, err := scanRowsForDialect(rows, "mysql") + if err != nil { + t.Fatalf("scanRowsForDialect returned error: %v", err) + } + + if !reflect.DeepEqual(columns, []string{"ship_date", "created_at"}) { + t.Fatalf("unexpected columns: %v", columns) + } + if len(data) != 1 { + t.Fatalf("expected one row, got=%d", len(data)) + } + if data[0]["ship_date"] != "2025-10-01" { + t.Fatalf("MySQL DATE 应展示为日期,实际=%v(%T)", data[0]["ship_date"], data[0]["ship_date"]) + } + if data[0]["created_at"] != "2025-10-01T13:14:15Z" { + t.Fatalf("MySQL DATETIME 应保留时间,实际=%v(%T)", data[0]["created_at"], data[0]["created_at"]) + } +} + +func TestScanRowsForOracleDialectKeepsDateTime(t *testing.T) { + t.Parallel() + + registerScanRowsDuplicateDriverOnce.Do(func() { + sql.Register(scanRowsDuplicateDriverName, scanRowsDuplicateDriver{}) + }) + + dbConn, err := sql.Open(scanRowsDuplicateDriverName, "") + if err != nil { + t.Fatalf("open date scan rows db failed: %v", err) + } + defer dbConn.Close() + + rows, err := dbConn.QueryContext(context.Background(), "SELECT date_columns") + if err != nil { + t.Fatalf("query date scan rows db failed: %v", err) + } + defer rows.Close() + + data, _, err := scanRowsForDialect(rows, "oracle") + if err != nil { + t.Fatalf("scanRowsForDialect returned error: %v", err) + } + if len(data) != 1 { + t.Fatalf("expected one row, got=%d", len(data)) + } + if data[0]["ship_date"] != "2025-10-01T00:00:00Z" { + t.Fatalf("Oracle DATE 应保留 datetime 语义,实际=%v(%T)", data[0]["ship_date"], data[0]["ship_date"]) + } +}