🐛 fix(oracle): 修复查询结果编辑提交日期格式报错

- 参数处理:提交事务前加载 Oracle 表字段类型,用于识别 DATE 和 TIMESTAMP 字段
- 更新修复:UPDATE 的 SET 值和 WHERE 条件统一转换日期时间参数
- 场景覆盖:修复新建查询结果编辑后提交事务触发 ORA-01861 的问题
- 类型绑定:将 Oracle 日期时间字符串解析为 time.Time,避免依赖数据库会话日期格式
- 兼容处理:支持 RFC3339、带时区和常见本地日期时间格式
- 测试覆盖:新增 Oracle ApplyChanges recording driver 回归测试
Refs #419
This commit is contained in:
Syngnat
2026-04-28 13:39:32 +08:00
parent ef634075ab
commit f5f87189df
2 changed files with 294 additions and 4 deletions

View File

@@ -0,0 +1,183 @@
package db
import (
"context"
"database/sql"
"database/sql/driver"
"fmt"
"io"
"strings"
"sync"
"testing"
"time"
"GoNavi-Wails/internal/connection"
)
const oracleRecordingDriverName = "gonavi_oracle_recording"
var (
registerOracleRecordingDriverOnce sync.Once
oracleRecordingDriverMu sync.Mutex
oracleRecordingDriverSeq int
oracleRecordingDriverStates = map[string]*oracleRecordingState{}
)
type oracleRecordingState struct {
mu sync.Mutex
execArgs [][]driver.NamedValue
}
func (s *oracleRecordingState) snapshotExecArgs() [][]driver.NamedValue {
s.mu.Lock()
defer s.mu.Unlock()
result := make([][]driver.NamedValue, len(s.execArgs))
for i, args := range s.execArgs {
result[i] = append([]driver.NamedValue(nil), args...)
}
return result
}
type oracleRecordingDriver struct{}
func (oracleRecordingDriver) Open(name string) (driver.Conn, error) {
oracleRecordingDriverMu.Lock()
state := oracleRecordingDriverStates[name]
oracleRecordingDriverMu.Unlock()
if state == nil {
return nil, fmt.Errorf("recording state not found: %s", name)
}
return &oracleRecordingConn{state: state}, nil
}
type oracleRecordingConn struct {
state *oracleRecordingState
}
func (c *oracleRecordingConn) Prepare(query string) (driver.Stmt, error) {
return nil, fmt.Errorf("prepare not supported in oracle recording driver: %s", query)
}
func (c *oracleRecordingConn) Close() error { return nil }
func (c *oracleRecordingConn) Begin() (driver.Tx, error) { return oracleRecordingTx{}, nil }
func (c *oracleRecordingConn) ExecContext(_ context.Context, _ string, args []driver.NamedValue) (driver.Result, error) {
c.state.mu.Lock()
defer c.state.mu.Unlock()
c.state.execArgs = append(c.state.execArgs, append([]driver.NamedValue(nil), args...))
return driver.RowsAffected(1), nil
}
func (c *oracleRecordingConn) QueryContext(_ context.Context, query string, _ []driver.NamedValue) (driver.Rows, error) {
if strings.Contains(strings.ToLower(query), "tab_columns") {
return &oracleRecordingRows{
columns: []string{"COLUMN_NAME", "DATA_TYPE", "NULLABLE", "DATA_DEFAULT"},
rows: [][]driver.Value{
{"UPDATED_AT", "TIMESTAMP", "YES", nil},
{"CREATED_AT", "DATE", "NO", nil},
},
}, nil
}
return &oracleRecordingRows{}, nil
}
var _ driver.ExecerContext = (*oracleRecordingConn)(nil)
var _ driver.QueryerContext = (*oracleRecordingConn)(nil)
type oracleRecordingTx struct{}
func (oracleRecordingTx) Commit() error { return nil }
func (oracleRecordingTx) Rollback() error { return nil }
type oracleRecordingRows struct {
columns []string
rows [][]driver.Value
index int
}
func (r *oracleRecordingRows) Columns() []string {
return append([]string(nil), r.columns...)
}
func (r *oracleRecordingRows) Close() error { return nil }
func (r *oracleRecordingRows) Next(dest []driver.Value) error {
if r.index >= len(r.rows) {
return io.EOF
}
row := r.rows[r.index]
for idx := range dest {
if idx < len(row) {
dest[idx] = row[idx]
}
}
r.index++
return nil
}
func openOracleRecordingDB(t *testing.T) (*sql.DB, *oracleRecordingState) {
t.Helper()
registerOracleRecordingDriverOnce.Do(func() {
sql.Register(oracleRecordingDriverName, oracleRecordingDriver{})
})
oracleRecordingDriverMu.Lock()
oracleRecordingDriverSeq++
dsn := fmt.Sprintf("oracle-recording-%d", oracleRecordingDriverSeq)
state := &oracleRecordingState{}
oracleRecordingDriverStates[dsn] = state
oracleRecordingDriverMu.Unlock()
dbConn, err := sql.Open(oracleRecordingDriverName, dsn)
if err != nil {
t.Fatalf("打开 recording db 失败: %v", err)
}
t.Cleanup(func() {
_ = dbConn.Close()
oracleRecordingDriverMu.Lock()
delete(oracleRecordingDriverStates, dsn)
oracleRecordingDriverMu.Unlock()
})
return dbConn, state
}
func TestOracleApplyChangesNormalizesTemporalStringsForUpdate(t *testing.T) {
t.Parallel()
dbConn, state := openOracleRecordingDB(t)
oracleDB := &OracleDB{conn: dbConn}
changes := connection.ChangeSet{
Updates: []connection.UpdateRow{{
Keys: map[string]interface{}{
"CREATED_AT": "2026-03-05T10:30:00Z",
},
Values: map[string]interface{}{
"UPDATED_AT": "2026-04-01T12:13:14.123456789Z",
},
}},
}
if err := oracleDB.ApplyChanges("EVENTS", changes); err != nil {
t.Fatalf("ApplyChanges 返回错误: %v", err)
}
executions := state.snapshotExecArgs()
if len(executions) != 1 {
t.Fatalf("期望执行 1 条更新,实际 %d 条", len(executions))
}
args := executions[0]
if len(args) != 2 {
t.Fatalf("期望 2 个绑定参数,实际 %d 个: %#v", len(args), args)
}
if _, ok := args[0].Value.(time.Time); !ok {
t.Fatalf("更新时间字段应绑定为 time.Time实际=%#v(%T)", args[0].Value, args[0].Value)
}
if _, ok := args[1].Value.(time.Time); !ok {
t.Fatalf("日期主键字段应绑定为 time.Time实际=%#v(%T)", args[1].Value, args[1].Value)
}
}

View File

@@ -389,11 +389,118 @@ func (o *OracleDB) GetTriggers(dbName, tableName string) ([]connection.TriggerDe
return triggers, nil
}
func splitOracleQualifiedTableName(raw string) (string, string) {
table := strings.TrimSpace(raw)
schema := ""
if parts := strings.SplitN(table, ".", 2); len(parts) == 2 {
schema = strings.Trim(strings.TrimSpace(parts[0]), "\"")
table = strings.TrimSpace(parts[1])
}
table = strings.Trim(strings.TrimSpace(table), "\"")
return schema, table
}
func (o *OracleDB) loadColumnTypeMap(tableName string) map[string]string {
result := map[string]string{}
schema, table := splitOracleQualifiedTableName(tableName)
if table == "" {
return result
}
columns, err := o.GetColumns(schema, table)
if err != nil {
logger.Warnf("加载 Oracle 列元数据失败(不影响提交):表=%s err=%v", tableName, err)
return result
}
for _, col := range columns {
name := strings.ToLower(strings.TrimSpace(col.Name))
if name == "" {
continue
}
result[name] = strings.TrimSpace(col.Type)
}
return result
}
func normalizeOracleValueForWrite(columnName string, value interface{}, columnTypeMap map[string]string) interface{} {
columnType := columnTypeMap[strings.ToLower(strings.TrimSpace(columnName))]
if !isOracleTemporalColumnType(columnType) {
return value
}
if value == nil {
return nil
}
text, ok := value.(string)
if !ok {
return value
}
raw := strings.TrimSpace(text)
if raw == "" {
return nil
}
if parsed, ok := parseOracleTemporalString(raw); ok {
return parsed
}
return value
}
func isOracleTemporalColumnType(columnType string) bool {
typ := strings.ToUpper(strings.TrimSpace(columnType))
return strings.Contains(typ, "DATE") || strings.Contains(typ, "TIMESTAMP")
}
func parseOracleTemporalString(raw string) (time.Time, bool) {
text := strings.TrimSpace(raw)
if text == "" {
return time.Time{}, false
}
text = strings.ReplaceAll(text, "+ ", "+")
text = strings.ReplaceAll(text, "- ", "-")
candidates := []string{text}
if len(text) >= 19 && text[10] == ' ' && (strings.HasSuffix(text, "Z") || hasTimezoneOffset(text)) {
candidates = append(candidates, strings.Replace(text, " ", "T", 1))
}
layoutsWithZone := []string{
"2006-01-02 15:04:05.999999999 -0700 MST",
"2006-01-02 15:04:05 -0700 MST",
"2006-01-02 15:04:05.999999999 -0700",
"2006-01-02 15:04:05 -0700",
time.RFC3339Nano,
time.RFC3339,
}
for _, candidate := range candidates {
for _, layout := range layoutsWithZone {
if parsed, err := time.Parse(layout, candidate); err == nil {
return parsed, true
}
}
}
layoutsWithoutZone := []string{
"2006-01-02T15:04:05.999999999",
"2006-01-02T15:04:05",
"2006-01-02 15:04:05.999999999",
"2006-01-02 15:04:05",
"2006-01-02",
}
for _, layout := range layoutsWithoutZone {
if parsed, err := time.ParseInLocation(layout, text, time.Local); err == nil {
return parsed, true
}
}
return time.Time{}, false
}
func (o *OracleDB) ApplyChanges(tableName string, changes connection.ChangeSet) error {
if o.conn == nil {
return fmt.Errorf("连接未打开")
}
columnTypeMap := o.loadColumnTypeMap(tableName)
tx, err := o.conn.Begin()
if err != nil {
return err
@@ -432,7 +539,7 @@ func (o *OracleDB) ApplyChanges(tableName string, changes connection.ChangeSet)
for k, v := range pk {
idx++
wheres = append(wheres, fmt.Sprintf("%s = :%d", quoteIdent(k), idx))
args = append(args, v)
args = append(args, normalizeOracleValueForWrite(k, v, columnTypeMap))
}
if len(wheres) == 0 {
continue
@@ -452,7 +559,7 @@ func (o *OracleDB) ApplyChanges(tableName string, changes connection.ChangeSet)
for k, v := range update.Values {
idx++
sets = append(sets, fmt.Sprintf("%s = :%d", quoteIdent(k), idx))
args = append(args, v)
args = append(args, normalizeOracleValueForWrite(k, v, columnTypeMap))
}
if len(sets) == 0 {
@@ -463,7 +570,7 @@ func (o *OracleDB) ApplyChanges(tableName string, changes connection.ChangeSet)
for k, v := range update.Keys {
idx++
wheres = append(wheres, fmt.Sprintf("%s = :%d", quoteIdent(k), idx))
args = append(args, v)
args = append(args, normalizeOracleValueForWrite(k, v, columnTypeMap))
}
if len(wheres) == 0 {
@@ -487,7 +594,7 @@ func (o *OracleDB) ApplyChanges(tableName string, changes connection.ChangeSet)
idx++
cols = append(cols, quoteIdent(k))
placeholders = append(placeholders, fmt.Sprintf(":%d", idx))
args = append(args, v)
args = append(args, normalizeOracleValueForWrite(k, v, columnTypeMap))
}
if len(cols) == 0 {