mirror of
https://github.com/Syngnat/GoNavi.git
synced 2026-06-28 01:11:31 +08:00
🐛 fix(oracle): 修复查询结果编辑提交日期格式报错
- 参数处理:提交事务前加载 Oracle 表字段类型,用于识别 DATE 和 TIMESTAMP 字段 - 更新修复:UPDATE 的 SET 值和 WHERE 条件统一转换日期时间参数 - 场景覆盖:修复新建查询结果编辑后提交事务触发 ORA-01861 的问题 - 类型绑定:将 Oracle 日期时间字符串解析为 time.Time,避免依赖数据库会话日期格式 - 兼容处理:支持 RFC3339、带时区和常见本地日期时间格式 - 测试覆盖:新增 Oracle ApplyChanges recording driver 回归测试 Refs #419
This commit is contained in:
183
internal/db/oracle_applychanges_test.go
Normal file
183
internal/db/oracle_applychanges_test.go
Normal file
@@ -0,0 +1,183 @@
|
||||
package db
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"database/sql/driver"
|
||||
"fmt"
|
||||
"io"
|
||||
"strings"
|
||||
"sync"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"GoNavi-Wails/internal/connection"
|
||||
)
|
||||
|
||||
const oracleRecordingDriverName = "gonavi_oracle_recording"
|
||||
|
||||
var (
|
||||
registerOracleRecordingDriverOnce sync.Once
|
||||
oracleRecordingDriverMu sync.Mutex
|
||||
oracleRecordingDriverSeq int
|
||||
oracleRecordingDriverStates = map[string]*oracleRecordingState{}
|
||||
)
|
||||
|
||||
type oracleRecordingState struct {
|
||||
mu sync.Mutex
|
||||
execArgs [][]driver.NamedValue
|
||||
}
|
||||
|
||||
func (s *oracleRecordingState) snapshotExecArgs() [][]driver.NamedValue {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
|
||||
result := make([][]driver.NamedValue, len(s.execArgs))
|
||||
for i, args := range s.execArgs {
|
||||
result[i] = append([]driver.NamedValue(nil), args...)
|
||||
}
|
||||
return result
|
||||
}
|
||||
|
||||
type oracleRecordingDriver struct{}
|
||||
|
||||
func (oracleRecordingDriver) Open(name string) (driver.Conn, error) {
|
||||
oracleRecordingDriverMu.Lock()
|
||||
state := oracleRecordingDriverStates[name]
|
||||
oracleRecordingDriverMu.Unlock()
|
||||
if state == nil {
|
||||
return nil, fmt.Errorf("recording state not found: %s", name)
|
||||
}
|
||||
return &oracleRecordingConn{state: state}, nil
|
||||
}
|
||||
|
||||
type oracleRecordingConn struct {
|
||||
state *oracleRecordingState
|
||||
}
|
||||
|
||||
func (c *oracleRecordingConn) Prepare(query string) (driver.Stmt, error) {
|
||||
return nil, fmt.Errorf("prepare not supported in oracle recording driver: %s", query)
|
||||
}
|
||||
|
||||
func (c *oracleRecordingConn) Close() error { return nil }
|
||||
|
||||
func (c *oracleRecordingConn) Begin() (driver.Tx, error) { return oracleRecordingTx{}, nil }
|
||||
|
||||
func (c *oracleRecordingConn) ExecContext(_ context.Context, _ string, args []driver.NamedValue) (driver.Result, error) {
|
||||
c.state.mu.Lock()
|
||||
defer c.state.mu.Unlock()
|
||||
c.state.execArgs = append(c.state.execArgs, append([]driver.NamedValue(nil), args...))
|
||||
return driver.RowsAffected(1), nil
|
||||
}
|
||||
|
||||
func (c *oracleRecordingConn) QueryContext(_ context.Context, query string, _ []driver.NamedValue) (driver.Rows, error) {
|
||||
if strings.Contains(strings.ToLower(query), "tab_columns") {
|
||||
return &oracleRecordingRows{
|
||||
columns: []string{"COLUMN_NAME", "DATA_TYPE", "NULLABLE", "DATA_DEFAULT"},
|
||||
rows: [][]driver.Value{
|
||||
{"UPDATED_AT", "TIMESTAMP", "YES", nil},
|
||||
{"CREATED_AT", "DATE", "NO", nil},
|
||||
},
|
||||
}, nil
|
||||
}
|
||||
return &oracleRecordingRows{}, nil
|
||||
}
|
||||
|
||||
var _ driver.ExecerContext = (*oracleRecordingConn)(nil)
|
||||
var _ driver.QueryerContext = (*oracleRecordingConn)(nil)
|
||||
|
||||
type oracleRecordingTx struct{}
|
||||
|
||||
func (oracleRecordingTx) Commit() error { return nil }
|
||||
func (oracleRecordingTx) Rollback() error { return nil }
|
||||
|
||||
type oracleRecordingRows struct {
|
||||
columns []string
|
||||
rows [][]driver.Value
|
||||
index int
|
||||
}
|
||||
|
||||
func (r *oracleRecordingRows) Columns() []string {
|
||||
return append([]string(nil), r.columns...)
|
||||
}
|
||||
|
||||
func (r *oracleRecordingRows) Close() error { return nil }
|
||||
|
||||
func (r *oracleRecordingRows) Next(dest []driver.Value) error {
|
||||
if r.index >= len(r.rows) {
|
||||
return io.EOF
|
||||
}
|
||||
row := r.rows[r.index]
|
||||
for idx := range dest {
|
||||
if idx < len(row) {
|
||||
dest[idx] = row[idx]
|
||||
}
|
||||
}
|
||||
r.index++
|
||||
return nil
|
||||
}
|
||||
|
||||
func openOracleRecordingDB(t *testing.T) (*sql.DB, *oracleRecordingState) {
|
||||
t.Helper()
|
||||
registerOracleRecordingDriverOnce.Do(func() {
|
||||
sql.Register(oracleRecordingDriverName, oracleRecordingDriver{})
|
||||
})
|
||||
|
||||
oracleRecordingDriverMu.Lock()
|
||||
oracleRecordingDriverSeq++
|
||||
dsn := fmt.Sprintf("oracle-recording-%d", oracleRecordingDriverSeq)
|
||||
state := &oracleRecordingState{}
|
||||
oracleRecordingDriverStates[dsn] = state
|
||||
oracleRecordingDriverMu.Unlock()
|
||||
|
||||
dbConn, err := sql.Open(oracleRecordingDriverName, dsn)
|
||||
if err != nil {
|
||||
t.Fatalf("打开 recording db 失败: %v", err)
|
||||
}
|
||||
|
||||
t.Cleanup(func() {
|
||||
_ = dbConn.Close()
|
||||
oracleRecordingDriverMu.Lock()
|
||||
delete(oracleRecordingDriverStates, dsn)
|
||||
oracleRecordingDriverMu.Unlock()
|
||||
})
|
||||
|
||||
return dbConn, state
|
||||
}
|
||||
|
||||
func TestOracleApplyChangesNormalizesTemporalStringsForUpdate(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
dbConn, state := openOracleRecordingDB(t)
|
||||
oracleDB := &OracleDB{conn: dbConn}
|
||||
|
||||
changes := connection.ChangeSet{
|
||||
Updates: []connection.UpdateRow{{
|
||||
Keys: map[string]interface{}{
|
||||
"CREATED_AT": "2026-03-05T10:30:00Z",
|
||||
},
|
||||
Values: map[string]interface{}{
|
||||
"UPDATED_AT": "2026-04-01T12:13:14.123456789Z",
|
||||
},
|
||||
}},
|
||||
}
|
||||
|
||||
if err := oracleDB.ApplyChanges("EVENTS", changes); err != nil {
|
||||
t.Fatalf("ApplyChanges 返回错误: %v", err)
|
||||
}
|
||||
|
||||
executions := state.snapshotExecArgs()
|
||||
if len(executions) != 1 {
|
||||
t.Fatalf("期望执行 1 条更新,实际 %d 条", len(executions))
|
||||
}
|
||||
args := executions[0]
|
||||
if len(args) != 2 {
|
||||
t.Fatalf("期望 2 个绑定参数,实际 %d 个: %#v", len(args), args)
|
||||
}
|
||||
if _, ok := args[0].Value.(time.Time); !ok {
|
||||
t.Fatalf("更新时间字段应绑定为 time.Time,实际=%#v(%T)", args[0].Value, args[0].Value)
|
||||
}
|
||||
if _, ok := args[1].Value.(time.Time); !ok {
|
||||
t.Fatalf("日期主键字段应绑定为 time.Time,实际=%#v(%T)", args[1].Value, args[1].Value)
|
||||
}
|
||||
}
|
||||
@@ -389,11 +389,118 @@ func (o *OracleDB) GetTriggers(dbName, tableName string) ([]connection.TriggerDe
|
||||
return triggers, nil
|
||||
}
|
||||
|
||||
func splitOracleQualifiedTableName(raw string) (string, string) {
|
||||
table := strings.TrimSpace(raw)
|
||||
schema := ""
|
||||
if parts := strings.SplitN(table, ".", 2); len(parts) == 2 {
|
||||
schema = strings.Trim(strings.TrimSpace(parts[0]), "\"")
|
||||
table = strings.TrimSpace(parts[1])
|
||||
}
|
||||
table = strings.Trim(strings.TrimSpace(table), "\"")
|
||||
return schema, table
|
||||
}
|
||||
|
||||
func (o *OracleDB) loadColumnTypeMap(tableName string) map[string]string {
|
||||
result := map[string]string{}
|
||||
schema, table := splitOracleQualifiedTableName(tableName)
|
||||
if table == "" {
|
||||
return result
|
||||
}
|
||||
|
||||
columns, err := o.GetColumns(schema, table)
|
||||
if err != nil {
|
||||
logger.Warnf("加载 Oracle 列元数据失败(不影响提交):表=%s err=%v", tableName, err)
|
||||
return result
|
||||
}
|
||||
|
||||
for _, col := range columns {
|
||||
name := strings.ToLower(strings.TrimSpace(col.Name))
|
||||
if name == "" {
|
||||
continue
|
||||
}
|
||||
result[name] = strings.TrimSpace(col.Type)
|
||||
}
|
||||
return result
|
||||
}
|
||||
|
||||
func normalizeOracleValueForWrite(columnName string, value interface{}, columnTypeMap map[string]string) interface{} {
|
||||
columnType := columnTypeMap[strings.ToLower(strings.TrimSpace(columnName))]
|
||||
if !isOracleTemporalColumnType(columnType) {
|
||||
return value
|
||||
}
|
||||
if value == nil {
|
||||
return nil
|
||||
}
|
||||
text, ok := value.(string)
|
||||
if !ok {
|
||||
return value
|
||||
}
|
||||
raw := strings.TrimSpace(text)
|
||||
if raw == "" {
|
||||
return nil
|
||||
}
|
||||
if parsed, ok := parseOracleTemporalString(raw); ok {
|
||||
return parsed
|
||||
}
|
||||
return value
|
||||
}
|
||||
|
||||
func isOracleTemporalColumnType(columnType string) bool {
|
||||
typ := strings.ToUpper(strings.TrimSpace(columnType))
|
||||
return strings.Contains(typ, "DATE") || strings.Contains(typ, "TIMESTAMP")
|
||||
}
|
||||
|
||||
func parseOracleTemporalString(raw string) (time.Time, bool) {
|
||||
text := strings.TrimSpace(raw)
|
||||
if text == "" {
|
||||
return time.Time{}, false
|
||||
}
|
||||
text = strings.ReplaceAll(text, "+ ", "+")
|
||||
text = strings.ReplaceAll(text, "- ", "-")
|
||||
|
||||
candidates := []string{text}
|
||||
if len(text) >= 19 && text[10] == ' ' && (strings.HasSuffix(text, "Z") || hasTimezoneOffset(text)) {
|
||||
candidates = append(candidates, strings.Replace(text, " ", "T", 1))
|
||||
}
|
||||
|
||||
layoutsWithZone := []string{
|
||||
"2006-01-02 15:04:05.999999999 -0700 MST",
|
||||
"2006-01-02 15:04:05 -0700 MST",
|
||||
"2006-01-02 15:04:05.999999999 -0700",
|
||||
"2006-01-02 15:04:05 -0700",
|
||||
time.RFC3339Nano,
|
||||
time.RFC3339,
|
||||
}
|
||||
for _, candidate := range candidates {
|
||||
for _, layout := range layoutsWithZone {
|
||||
if parsed, err := time.Parse(layout, candidate); err == nil {
|
||||
return parsed, true
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
layoutsWithoutZone := []string{
|
||||
"2006-01-02T15:04:05.999999999",
|
||||
"2006-01-02T15:04:05",
|
||||
"2006-01-02 15:04:05.999999999",
|
||||
"2006-01-02 15:04:05",
|
||||
"2006-01-02",
|
||||
}
|
||||
for _, layout := range layoutsWithoutZone {
|
||||
if parsed, err := time.ParseInLocation(layout, text, time.Local); err == nil {
|
||||
return parsed, true
|
||||
}
|
||||
}
|
||||
return time.Time{}, false
|
||||
}
|
||||
|
||||
func (o *OracleDB) ApplyChanges(tableName string, changes connection.ChangeSet) error {
|
||||
if o.conn == nil {
|
||||
return fmt.Errorf("连接未打开")
|
||||
}
|
||||
|
||||
columnTypeMap := o.loadColumnTypeMap(tableName)
|
||||
|
||||
tx, err := o.conn.Begin()
|
||||
if err != nil {
|
||||
return err
|
||||
@@ -432,7 +539,7 @@ func (o *OracleDB) ApplyChanges(tableName string, changes connection.ChangeSet)
|
||||
for k, v := range pk {
|
||||
idx++
|
||||
wheres = append(wheres, fmt.Sprintf("%s = :%d", quoteIdent(k), idx))
|
||||
args = append(args, v)
|
||||
args = append(args, normalizeOracleValueForWrite(k, v, columnTypeMap))
|
||||
}
|
||||
if len(wheres) == 0 {
|
||||
continue
|
||||
@@ -452,7 +559,7 @@ func (o *OracleDB) ApplyChanges(tableName string, changes connection.ChangeSet)
|
||||
for k, v := range update.Values {
|
||||
idx++
|
||||
sets = append(sets, fmt.Sprintf("%s = :%d", quoteIdent(k), idx))
|
||||
args = append(args, v)
|
||||
args = append(args, normalizeOracleValueForWrite(k, v, columnTypeMap))
|
||||
}
|
||||
|
||||
if len(sets) == 0 {
|
||||
@@ -463,7 +570,7 @@ func (o *OracleDB) ApplyChanges(tableName string, changes connection.ChangeSet)
|
||||
for k, v := range update.Keys {
|
||||
idx++
|
||||
wheres = append(wheres, fmt.Sprintf("%s = :%d", quoteIdent(k), idx))
|
||||
args = append(args, v)
|
||||
args = append(args, normalizeOracleValueForWrite(k, v, columnTypeMap))
|
||||
}
|
||||
|
||||
if len(wheres) == 0 {
|
||||
@@ -487,7 +594,7 @@ func (o *OracleDB) ApplyChanges(tableName string, changes connection.ChangeSet)
|
||||
idx++
|
||||
cols = append(cols, quoteIdent(k))
|
||||
placeholders = append(placeholders, fmt.Sprintf(":%d", idx))
|
||||
args = append(args, v)
|
||||
args = append(args, normalizeOracleValueForWrite(k, v, columnTypeMap))
|
||||
}
|
||||
|
||||
if len(cols) == 0 {
|
||||
|
||||
Reference in New Issue
Block a user