From f5f87189df666e23b4d85076804b8d3f886e99e1 Mon Sep 17 00:00:00 2001 From: Syngnat Date: Tue, 28 Apr 2026 13:39:32 +0800 Subject: [PATCH] =?UTF-8?q?=F0=9F=90=9B=20fix(oracle):=20=E4=BF=AE?= =?UTF-8?q?=E5=A4=8D=E6=9F=A5=E8=AF=A2=E7=BB=93=E6=9E=9C=E7=BC=96=E8=BE=91?= =?UTF-8?q?=E6=8F=90=E4=BA=A4=E6=97=A5=E6=9C=9F=E6=A0=BC=E5=BC=8F=E6=8A=A5?= =?UTF-8?q?=E9=94=99?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - 参数处理:提交事务前加载 Oracle 表字段类型,用于识别 DATE 和 TIMESTAMP 字段 - 更新修复:UPDATE 的 SET 值和 WHERE 条件统一转换日期时间参数 - 场景覆盖:修复新建查询结果编辑后提交事务触发 ORA-01861 的问题 - 类型绑定:将 Oracle 日期时间字符串解析为 time.Time,避免依赖数据库会话日期格式 - 兼容处理:支持 RFC3339、带时区和常见本地日期时间格式 - 测试覆盖:新增 Oracle ApplyChanges recording driver 回归测试 Refs #419 --- internal/db/oracle_applychanges_test.go | 183 ++++++++++++++++++++++++ internal/db/oracle_impl.go | 115 ++++++++++++++- 2 files changed, 294 insertions(+), 4 deletions(-) create mode 100644 internal/db/oracle_applychanges_test.go diff --git a/internal/db/oracle_applychanges_test.go b/internal/db/oracle_applychanges_test.go new file mode 100644 index 0000000..88ff7e5 --- /dev/null +++ b/internal/db/oracle_applychanges_test.go @@ -0,0 +1,183 @@ +package db + +import ( + "context" + "database/sql" + "database/sql/driver" + "fmt" + "io" + "strings" + "sync" + "testing" + "time" + + "GoNavi-Wails/internal/connection" +) + +const oracleRecordingDriverName = "gonavi_oracle_recording" + +var ( + registerOracleRecordingDriverOnce sync.Once + oracleRecordingDriverMu sync.Mutex + oracleRecordingDriverSeq int + oracleRecordingDriverStates = map[string]*oracleRecordingState{} +) + +type oracleRecordingState struct { + mu sync.Mutex + execArgs [][]driver.NamedValue +} + +func (s *oracleRecordingState) snapshotExecArgs() [][]driver.NamedValue { + s.mu.Lock() + defer s.mu.Unlock() + + result := make([][]driver.NamedValue, len(s.execArgs)) + for i, args := range s.execArgs { + result[i] = append([]driver.NamedValue(nil), args...) + } + return result +} + +type oracleRecordingDriver struct{} + +func (oracleRecordingDriver) Open(name string) (driver.Conn, error) { + oracleRecordingDriverMu.Lock() + state := oracleRecordingDriverStates[name] + oracleRecordingDriverMu.Unlock() + if state == nil { + return nil, fmt.Errorf("recording state not found: %s", name) + } + return &oracleRecordingConn{state: state}, nil +} + +type oracleRecordingConn struct { + state *oracleRecordingState +} + +func (c *oracleRecordingConn) Prepare(query string) (driver.Stmt, error) { + return nil, fmt.Errorf("prepare not supported in oracle recording driver: %s", query) +} + +func (c *oracleRecordingConn) Close() error { return nil } + +func (c *oracleRecordingConn) Begin() (driver.Tx, error) { return oracleRecordingTx{}, nil } + +func (c *oracleRecordingConn) ExecContext(_ context.Context, _ string, args []driver.NamedValue) (driver.Result, error) { + c.state.mu.Lock() + defer c.state.mu.Unlock() + c.state.execArgs = append(c.state.execArgs, append([]driver.NamedValue(nil), args...)) + return driver.RowsAffected(1), nil +} + +func (c *oracleRecordingConn) QueryContext(_ context.Context, query string, _ []driver.NamedValue) (driver.Rows, error) { + if strings.Contains(strings.ToLower(query), "tab_columns") { + return &oracleRecordingRows{ + columns: []string{"COLUMN_NAME", "DATA_TYPE", "NULLABLE", "DATA_DEFAULT"}, + rows: [][]driver.Value{ + {"UPDATED_AT", "TIMESTAMP", "YES", nil}, + {"CREATED_AT", "DATE", "NO", nil}, + }, + }, nil + } + return &oracleRecordingRows{}, nil +} + +var _ driver.ExecerContext = (*oracleRecordingConn)(nil) +var _ driver.QueryerContext = (*oracleRecordingConn)(nil) + +type oracleRecordingTx struct{} + +func (oracleRecordingTx) Commit() error { return nil } +func (oracleRecordingTx) Rollback() error { return nil } + +type oracleRecordingRows struct { + columns []string + rows [][]driver.Value + index int +} + +func (r *oracleRecordingRows) Columns() []string { + return append([]string(nil), r.columns...) +} + +func (r *oracleRecordingRows) Close() error { return nil } + +func (r *oracleRecordingRows) Next(dest []driver.Value) error { + if r.index >= len(r.rows) { + return io.EOF + } + row := r.rows[r.index] + for idx := range dest { + if idx < len(row) { + dest[idx] = row[idx] + } + } + r.index++ + return nil +} + +func openOracleRecordingDB(t *testing.T) (*sql.DB, *oracleRecordingState) { + t.Helper() + registerOracleRecordingDriverOnce.Do(func() { + sql.Register(oracleRecordingDriverName, oracleRecordingDriver{}) + }) + + oracleRecordingDriverMu.Lock() + oracleRecordingDriverSeq++ + dsn := fmt.Sprintf("oracle-recording-%d", oracleRecordingDriverSeq) + state := &oracleRecordingState{} + oracleRecordingDriverStates[dsn] = state + oracleRecordingDriverMu.Unlock() + + dbConn, err := sql.Open(oracleRecordingDriverName, dsn) + if err != nil { + t.Fatalf("打开 recording db 失败: %v", err) + } + + t.Cleanup(func() { + _ = dbConn.Close() + oracleRecordingDriverMu.Lock() + delete(oracleRecordingDriverStates, dsn) + oracleRecordingDriverMu.Unlock() + }) + + return dbConn, state +} + +func TestOracleApplyChangesNormalizesTemporalStringsForUpdate(t *testing.T) { + t.Parallel() + + dbConn, state := openOracleRecordingDB(t) + oracleDB := &OracleDB{conn: dbConn} + + changes := connection.ChangeSet{ + Updates: []connection.UpdateRow{{ + Keys: map[string]interface{}{ + "CREATED_AT": "2026-03-05T10:30:00Z", + }, + Values: map[string]interface{}{ + "UPDATED_AT": "2026-04-01T12:13:14.123456789Z", + }, + }}, + } + + if err := oracleDB.ApplyChanges("EVENTS", changes); err != nil { + t.Fatalf("ApplyChanges 返回错误: %v", err) + } + + executions := state.snapshotExecArgs() + if len(executions) != 1 { + t.Fatalf("期望执行 1 条更新,实际 %d 条", len(executions)) + } + args := executions[0] + if len(args) != 2 { + t.Fatalf("期望 2 个绑定参数,实际 %d 个: %#v", len(args), args) + } + if _, ok := args[0].Value.(time.Time); !ok { + t.Fatalf("更新时间字段应绑定为 time.Time,实际=%#v(%T)", args[0].Value, args[0].Value) + } + if _, ok := args[1].Value.(time.Time); !ok { + t.Fatalf("日期主键字段应绑定为 time.Time,实际=%#v(%T)", args[1].Value, args[1].Value) + } +} diff --git a/internal/db/oracle_impl.go b/internal/db/oracle_impl.go index e82e356..9efb1b6 100644 --- a/internal/db/oracle_impl.go +++ b/internal/db/oracle_impl.go @@ -389,11 +389,118 @@ func (o *OracleDB) GetTriggers(dbName, tableName string) ([]connection.TriggerDe return triggers, nil } +func splitOracleQualifiedTableName(raw string) (string, string) { + table := strings.TrimSpace(raw) + schema := "" + if parts := strings.SplitN(table, ".", 2); len(parts) == 2 { + schema = strings.Trim(strings.TrimSpace(parts[0]), "\"") + table = strings.TrimSpace(parts[1]) + } + table = strings.Trim(strings.TrimSpace(table), "\"") + return schema, table +} + +func (o *OracleDB) loadColumnTypeMap(tableName string) map[string]string { + result := map[string]string{} + schema, table := splitOracleQualifiedTableName(tableName) + if table == "" { + return result + } + + columns, err := o.GetColumns(schema, table) + if err != nil { + logger.Warnf("加载 Oracle 列元数据失败(不影响提交):表=%s err=%v", tableName, err) + return result + } + + for _, col := range columns { + name := strings.ToLower(strings.TrimSpace(col.Name)) + if name == "" { + continue + } + result[name] = strings.TrimSpace(col.Type) + } + return result +} + +func normalizeOracleValueForWrite(columnName string, value interface{}, columnTypeMap map[string]string) interface{} { + columnType := columnTypeMap[strings.ToLower(strings.TrimSpace(columnName))] + if !isOracleTemporalColumnType(columnType) { + return value + } + if value == nil { + return nil + } + text, ok := value.(string) + if !ok { + return value + } + raw := strings.TrimSpace(text) + if raw == "" { + return nil + } + if parsed, ok := parseOracleTemporalString(raw); ok { + return parsed + } + return value +} + +func isOracleTemporalColumnType(columnType string) bool { + typ := strings.ToUpper(strings.TrimSpace(columnType)) + return strings.Contains(typ, "DATE") || strings.Contains(typ, "TIMESTAMP") +} + +func parseOracleTemporalString(raw string) (time.Time, bool) { + text := strings.TrimSpace(raw) + if text == "" { + return time.Time{}, false + } + text = strings.ReplaceAll(text, "+ ", "+") + text = strings.ReplaceAll(text, "- ", "-") + + candidates := []string{text} + if len(text) >= 19 && text[10] == ' ' && (strings.HasSuffix(text, "Z") || hasTimezoneOffset(text)) { + candidates = append(candidates, strings.Replace(text, " ", "T", 1)) + } + + layoutsWithZone := []string{ + "2006-01-02 15:04:05.999999999 -0700 MST", + "2006-01-02 15:04:05 -0700 MST", + "2006-01-02 15:04:05.999999999 -0700", + "2006-01-02 15:04:05 -0700", + time.RFC3339Nano, + time.RFC3339, + } + for _, candidate := range candidates { + for _, layout := range layoutsWithZone { + if parsed, err := time.Parse(layout, candidate); err == nil { + return parsed, true + } + } + } + + layoutsWithoutZone := []string{ + "2006-01-02T15:04:05.999999999", + "2006-01-02T15:04:05", + "2006-01-02 15:04:05.999999999", + "2006-01-02 15:04:05", + "2006-01-02", + } + for _, layout := range layoutsWithoutZone { + if parsed, err := time.ParseInLocation(layout, text, time.Local); err == nil { + return parsed, true + } + } + return time.Time{}, false +} + func (o *OracleDB) ApplyChanges(tableName string, changes connection.ChangeSet) error { if o.conn == nil { return fmt.Errorf("连接未打开") } + columnTypeMap := o.loadColumnTypeMap(tableName) + tx, err := o.conn.Begin() if err != nil { return err @@ -432,7 +539,7 @@ func (o *OracleDB) ApplyChanges(tableName string, changes connection.ChangeSet) for k, v := range pk { idx++ wheres = append(wheres, fmt.Sprintf("%s = :%d", quoteIdent(k), idx)) - args = append(args, v) + args = append(args, normalizeOracleValueForWrite(k, v, columnTypeMap)) } if len(wheres) == 0 { continue @@ -452,7 +559,7 @@ func (o *OracleDB) ApplyChanges(tableName string, changes connection.ChangeSet) for k, v := range update.Values { idx++ sets = append(sets, fmt.Sprintf("%s = :%d", quoteIdent(k), idx)) - args = append(args, v) + args = append(args, normalizeOracleValueForWrite(k, v, columnTypeMap)) } if len(sets) == 0 { @@ -463,7 +570,7 @@ func (o *OracleDB) ApplyChanges(tableName string, changes connection.ChangeSet) for k, v := range update.Keys { idx++ wheres = append(wheres, fmt.Sprintf("%s = :%d", quoteIdent(k), idx)) - args = append(args, v) + args = append(args, normalizeOracleValueForWrite(k, v, columnTypeMap)) } if len(wheres) == 0 { @@ -487,7 +594,7 @@ func (o *OracleDB) ApplyChanges(tableName string, changes connection.ChangeSet) idx++ cols = append(cols, quoteIdent(k)) placeholders = append(placeholders, fmt.Sprintf(":%d", idx)) - args = append(args, v) + args = append(args, normalizeOracleValueForWrite(k, v, columnTypeMap)) } if len(cols) == 0 {