Files
MyGoNavi/internal/db/oracle_applychanges_test.go
Syngnat 56b3112a07 🐛 fix(oracle): 修复表结构注释读取与保存报错
- 补齐 Oracle 表字段注释元数据读取

- 在表结构 DDL 中追加表和字段注释信息

- 规范表设计器 Oracle DDL 执行前的分号处理

Refs #482
2026-05-23 17:41:46 +08:00

390 lines
10 KiB
Go
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
package db
import (
"context"
"database/sql"
"database/sql/driver"
"fmt"
"io"
"strings"
"sync"
"testing"
"time"
"GoNavi-Wails/internal/connection"
)
// oracleRecordingDriverName is the name under which the recording driver is
// registered with database/sql and later passed to sql.Open.
const oracleRecordingDriverName = "gonavi_oracle_recording"
// Package-level registry that maps each recording DSN to its state.
var (
// registerOracleRecordingDriverOnce ensures sql.Register runs only once.
registerOracleRecordingDriverOnce sync.Once
// oracleRecordingDriverMu guards oracleRecordingDriverSeq and
// oracleRecordingDriverStates.
oracleRecordingDriverMu sync.Mutex
// oracleRecordingDriverSeq is incremented per opened DB to build a unique DSN.
oracleRecordingDriverSeq int
// oracleRecordingDriverStates is keyed by DSN; the driver's Open looks the
// state up by the name it receives.
oracleRecordingDriverStates = map[string]*oracleRecordingState{}
)
// oracleRecordingState is the per-DSN record of what a recording connection
// was asked to do, plus the canned answers it should give back.
// mu serialises access from the recording connection and the snapshot helpers.
type oracleRecordingState struct {
	mu           sync.Mutex
	execQueries  []string                              // SQL text of each ExecContext call, in call order
	execArgs     [][]driver.NamedValue                 // bound arguments of each ExecContext call, parallel to execQueries
	queries      []string                              // SQL text of each QueryContext call, in call order
	rowsAffected int64                                 // value reported as rows affected by every ExecContext call
	queryResults map[string]oracleRecordingQueryResult // canned result sets keyed by the exact query text
	queryError   error                                 // when non-nil, returned by every QueryContext call
}

// oracleRecordingQueryResult is one canned result set: column names plus rows.
type oracleRecordingQueryResult struct {
	columns []string
	rows    [][]driver.Value
}

// snapshotExecQueries returns a copy of the recorded exec statements, so the
// caller can inspect it without holding the lock.
func (s *oracleRecordingState) snapshotExecQueries() []string {
	s.mu.Lock()
	defer s.mu.Unlock()

	var statements []string
	statements = append(statements, s.execQueries...)
	return statements
}

// snapshotExecArgs returns a copy of the recorded exec arguments. Both the
// outer slice and each per-call argument slice are copied.
func (s *oracleRecordingState) snapshotExecArgs() [][]driver.NamedValue {
	s.mu.Lock()
	defer s.mu.Unlock()

	calls := make([][]driver.NamedValue, 0, len(s.execArgs))
	for _, recorded := range s.execArgs {
		var bound []driver.NamedValue
		calls = append(calls, append(bound, recorded...))
	}
	return calls
}

// snapshotQueries returns a copy of the recorded query statements.
func (s *oracleRecordingState) snapshotQueries() []string {
	s.mu.Lock()
	defer s.mu.Unlock()

	var statements []string
	statements = append(statements, s.queries...)
	return statements
}
type oracleRecordingDriver struct{}
func (oracleRecordingDriver) Open(name string) (driver.Conn, error) {
oracleRecordingDriverMu.Lock()
state := oracleRecordingDriverStates[name]
oracleRecordingDriverMu.Unlock()
if state == nil {
return nil, fmt.Errorf("recording state not found: %s", name)
}
return &oracleRecordingConn{state: state}, nil
}
// oracleRecordingConn is the connection handed out by oracleRecordingDriver.
// It records every statement into the shared state instead of talking to a
// database.
type oracleRecordingConn struct {
	state *oracleRecordingState
}

// Prepare always returns an error; statements are expected to go through
// ExecContext and QueryContext instead.
func (c *oracleRecordingConn) Prepare(query string) (driver.Stmt, error) {
	return nil, fmt.Errorf("prepare not supported in oracle recording driver: %s", query)
}

// Close is a no-op: the connection owns no resources.
func (c *oracleRecordingConn) Close() error {
	return nil
}

// Begin returns a transaction whose Commit and Rollback do nothing.
func (c *oracleRecordingConn) Begin() (driver.Tx, error) {
	return oracleRecordingTx{}, nil
}

// ExecContext records the statement text and a copy of its arguments, then
// reports the configured rowsAffected. The context is ignored.
func (c *oracleRecordingConn) ExecContext(_ context.Context, query string, args []driver.NamedValue) (driver.Result, error) {
	// Copy the arguments so later mutation by the caller cannot change the record.
	var bound []driver.NamedValue
	bound = append(bound, args...)

	st := c.state
	st.mu.Lock()
	defer st.mu.Unlock()

	st.execQueries = append(st.execQueries, query)
	st.execArgs = append(st.execArgs, bound)
	return driver.RowsAffected(st.rowsAffected), nil
}

// QueryContext records the query text and then answers, in priority order:
// the configured queryError, a canned result registered for this exact query
// text, built-in column metadata when the query mentions "tab_columns"
// (case-insensitive), or an empty result set. The context and arguments are
// ignored.
func (c *oracleRecordingConn) QueryContext(_ context.Context, query string, _ []driver.NamedValue) (driver.Rows, error) {
	var (
		canned   oracleRecordingQueryResult
		hasCanned bool
		failure  error
	)
	// Only the bookkeeping and lookups happen under the lock; result rows are
	// built after it is released.
	func() {
		c.state.mu.Lock()
		defer c.state.mu.Unlock()

		c.state.queries = append(c.state.queries, query)
		if failure = c.state.queryError; failure != nil {
			return
		}
		canned, hasCanned = c.state.queryResults[query]
	}()

	switch {
	case failure != nil:
		return nil, failure
	case hasCanned:
		rows := &oracleRecordingRows{
			columns: append([]string(nil), canned.columns...),
			rows:    cloneOracleRecordingRows(canned.rows),
		}
		return rows, nil
	case strings.Contains(strings.ToLower(query), "tab_columns"):
		// Default column metadata: one TIMESTAMP column with a comment and
		// one DATE column without.
		rows := &oracleRecordingRows{
			columns: []string{"COLUMN_NAME", "DATA_TYPE", "NULLABLE", "DATA_DEFAULT", "COLUMN_KEY", "COMMENT"},
			rows: [][]driver.Value{
				{"UPDATED_AT", "TIMESTAMP", "YES", nil, "", "更新时间"},
				{"CREATED_AT", "DATE", "NO", nil, "", nil},
			},
		}
		return rows, nil
	}
	return &oracleRecordingRows{}, nil
}
// cloneOracleRecordingRows copies the outer slice and every row slice of src,
// so the result shares no backing array with it. The individual values are
// copied as-is. The returned outer slice is never nil; an empty source row
// yields a nil row.
func cloneOracleRecordingRows(src [][]driver.Value) [][]driver.Value {
	cloned := make([][]driver.Value, 0, len(src))
	for _, original := range src {
		var dup []driver.Value
		cloned = append(cloned, append(dup, original...))
	}
	return cloned
}
// Compile-time checks that *oracleRecordingConn implements the context-aware
// exec and query interfaces from database/sql/driver.
var _ driver.ExecerContext = (*oracleRecordingConn)(nil)
var _ driver.QueryerContext = (*oracleRecordingConn)(nil)
// oracleRecordingTx is a transaction stub: it records nothing and both of its
// methods always succeed.
type oracleRecordingTx struct{}

// Commit does nothing and reports success.
func (oracleRecordingTx) Commit() error {
	return nil
}

// Rollback does nothing and reports success.
func (oracleRecordingTx) Rollback() error {
	return nil
}
// oracleRecordingRows is an in-memory result set served by the recording
// connection.
type oracleRecordingRows struct {
	columns []string
	rows    [][]driver.Value
	index   int // position of the row the next call to Next will return
}

// Columns returns a copy of the column names, so callers cannot modify the
// result set's own slice.
func (r *oracleRecordingRows) Columns() []string {
	return append([]string(nil), r.columns...)
}

// Close is a no-op: the rows hold no external resources.
func (r *oracleRecordingRows) Close() error { return nil }

// Next copies the current row into dest and advances the cursor. It returns
// io.EOF once every row has been consumed.
//
// Values beyond len(dest) are dropped. When the stored row is narrower than
// dest, the remaining slots are set to nil: the caller may reuse dest across
// calls, and leaving them untouched would leak the previous row's values
// into this one.
func (r *oracleRecordingRows) Next(dest []driver.Value) error {
	if r.index >= len(r.rows) {
		return io.EOF
	}
	filled := copy(dest, r.rows[r.index])
	for i := filled; i < len(dest); i++ {
		dest[i] = nil
	}
	r.index++
	return nil
}
// openOracleRecordingDB opens a *sql.DB backed by the recording driver and
// returns it together with the state that captures its traffic.
//
// Each call registers a fresh state under a unique DSN, so parallel tests do
// not share recordings. The state starts with rowsAffected set to 1 and an
// empty canned-result map. A cleanup closes the DB and removes the state from
// the registry when the test ends.
func openOracleRecordingDB(t *testing.T) (*sql.DB, *oracleRecordingState) {
	t.Helper()

	registerOracleRecordingDriverOnce.Do(func() {
		sql.Register(oracleRecordingDriverName, oracleRecordingDriver{})
	})

	recorder := &oracleRecordingState{
		rowsAffected: 1,
		queryResults: make(map[string]oracleRecordingQueryResult),
	}

	// Allocate the DSN and publish the state under a single lock hold.
	dataSource := func() string {
		oracleRecordingDriverMu.Lock()
		defer oracleRecordingDriverMu.Unlock()

		oracleRecordingDriverSeq++
		key := fmt.Sprintf("oracle-recording-%d", oracleRecordingDriverSeq)
		oracleRecordingDriverStates[key] = recorder
		return key
	}()

	handle, err := sql.Open(oracleRecordingDriverName, dataSource)
	if err != nil {
		t.Fatalf("打开 recording db 失败: %v", err)
	}

	t.Cleanup(func() {
		_ = handle.Close() // best effort: the recording connection's Close never fails

		oracleRecordingDriverMu.Lock()
		defer oracleRecordingDriverMu.Unlock()
		delete(oracleRecordingDriverStates, dataSource)
	})

	return handle, recorder
}
// TestOracleApplyChangesReturnsErrorWhenUpdateMatchesNoRows checks that
// OracleDB.ApplyChanges returns an error containing "更新未生效" when the driver
// reports zero affected rows for an update.
func TestOracleApplyChangesReturnsErrorWhenUpdateMatchesNoRows(t *testing.T) {
	t.Parallel()

	handle, recorder := openOracleRecordingDB(t)
	recorder.rowsAffected = 0 // every exec now reports that nothing matched

	row := connection.UpdateRow{
		Keys:   map[string]interface{}{"ID": 7},
		Values: map[string]interface{}{"NAME": "new-name"},
	}
	target := &OracleDB{conn: handle}

	err := target.ApplyChanges("MYCIMLED.EDC_LOG", connection.ChangeSet{
		Updates: []connection.UpdateRow{row},
	})
	switch {
	case err == nil:
		t.Fatal("期望更新未匹配到行时返回错误,实际为 nil")
	case !strings.Contains(err.Error(), "更新未生效"):
		t.Fatalf("错误信息应提示更新未生效,实际=%v", err)
	}
}
// TestOracleApplyChangesReturnsErrorWhenUpdateAffectsMultipleRows checks that
// OracleDB.ApplyChanges returns an error containing "影响了 2 行" when the driver
// reports two affected rows for a single update.
func TestOracleApplyChangesReturnsErrorWhenUpdateAffectsMultipleRows(t *testing.T) {
	t.Parallel()

	handle, recorder := openOracleRecordingDB(t)
	recorder.rowsAffected = 2

	target := &OracleDB{conn: handle}
	pending := connection.ChangeSet{
		Updates: []connection.UpdateRow{
			{
				Keys:   map[string]interface{}{"ID": 7},
				Values: map[string]interface{}{"NAME": "new-name"},
			},
		},
	}

	err := target.ApplyChanges("MYCIMLED.EDC_LOG", pending)
	if err == nil {
		t.Fatal("期望更新影响多行时返回错误,实际为 nil")
	}
	if msg := err.Error(); !strings.Contains(msg, "影响了 2 行") {
		t.Fatalf("错误信息应提示影响多行,实际=%v", err)
	}
}
// TestOracleApplyChangesReturnsErrorWhenDeleteAffectsMultipleRows checks that
// OracleDB.ApplyChanges returns an error containing "影响了 2 行" when the driver
// reports two affected rows for a single delete.
func TestOracleApplyChangesReturnsErrorWhenDeleteAffectsMultipleRows(t *testing.T) {
	t.Parallel()

	handle, recorder := openOracleRecordingDB(t)
	recorder.rowsAffected = 2

	// The delete is located by a non-key column.
	staleRow := map[string]interface{}{"STATUS": "stale"}
	target := &OracleDB{conn: handle}

	err := target.ApplyChanges("MYCIMLED.EDC_LOG", connection.ChangeSet{
		Deletes: []map[string]interface{}{staleRow},
	})
	switch {
	case err == nil:
		t.Fatal("期望删除影响多行时返回错误,实际为 nil")
	case !strings.Contains(err.Error(), "影响了 2 行"):
		t.Fatalf("错误信息应提示影响多行,实际=%v", err)
	}
}
// TestOracleApplyChangesNormalizesTemporalStringsForUpdate checks that
// RFC 3339 strings supplied for the UPDATED_AT value and the CREATED_AT key
// are bound as time.Time in the single UPDATE that gets executed.
//
// NOTE(review): the test configures no canned results, so ApplyChanges
// presumably learns the column types from the recording driver's built-in
// "tab_columns" rows — confirm against the OracleDB implementation.
func TestOracleApplyChangesNormalizesTemporalStringsForUpdate(t *testing.T) {
	t.Parallel()

	handle, recorder := openOracleRecordingDB(t)
	target := &OracleDB{conn: handle}

	row := connection.UpdateRow{
		Keys:   map[string]interface{}{"CREATED_AT": "2026-03-05T10:30:00Z"},
		Values: map[string]interface{}{"UPDATED_AT": "2026-04-01T12:13:14.123456789Z"},
	}
	pending := connection.ChangeSet{Updates: []connection.UpdateRow{row}}

	if err := target.ApplyChanges("EVENTS", pending); err != nil {
		t.Fatalf("ApplyChanges 返回错误: %v", err)
	}

	recorded := recorder.snapshotExecArgs()
	if count := len(recorded); count != 1 {
		t.Fatalf("期望执行 1 条更新,实际 %d 条", count)
	}

	bound := recorded[0]
	if count := len(bound); count != 2 {
		t.Fatalf("期望 2 个绑定参数,实际 %d 个: %#v", count, bound)
	}

	// The first bound argument is the SET value, the second the WHERE key.
	setValue, keyValue := bound[0].Value, bound[1].Value
	if _, isTime := setValue.(time.Time); !isTime {
		t.Fatalf("更新时间字段应绑定为 time.Time实际=%#v(%T)", setValue, setValue)
	}
	if _, isTime := keyValue.(time.Time); !isTime {
		t.Fatalf("日期主键字段应绑定为 time.Time实际=%#v(%T)", keyValue, keyValue)
	}
}
// TestOracleApplyChangesUsesUnquotedRowIDLocator checks that, with the
// "oracle-rowid" locator strategy, the generated UPDATE compares the bare
// ROWID pseudo-column (`ROWID = :2`) and never the quoted form `"ROWID" =`.
func TestOracleApplyChangesUsesUnquotedRowIDLocator(t *testing.T) {
	t.Parallel()

	handle, recorder := openOracleRecordingDB(t)
	target := &OracleDB{conn: handle}

	row := connection.UpdateRow{
		Keys:   map[string]interface{}{"ROWID": "AAAA"},
		Values: map[string]interface{}{"NAME": "new-name"},
	}
	pending := connection.ChangeSet{
		LocatorStrategy: "oracle-rowid",
		Updates:         []connection.UpdateRow{row},
	}

	if err := target.ApplyChanges("MYCIMLED.EDC_LOG", pending); err != nil {
		t.Fatalf("ApplyChanges 返回错误: %v", err)
	}

	statements := recorder.snapshotExecQueries()
	if count := len(statements); count != 1 {
		t.Fatalf("期望执行 1 条更新,实际 %d 条", count)
	}

	statement := statements[0]
	switch {
	case !strings.Contains(statement, "ROWID = :2"):
		t.Fatalf("ROWID 定位条件不正确: %s", statement)
	case strings.Contains(statement, "\"ROWID\" ="):
		t.Fatalf("ROWID 不应被当作普通列引用: %s", statement)
	}
}
// TestMySQLApplyChangesReturnsErrorWhenUpdateAffectsMultipleRows checks that
// MySQLDB.ApplyChanges returns an error containing "影响了 2 行" when the driver
// reports two affected rows for a single update. It reuses the recording
// driver as the underlying connection.
func TestMySQLApplyChangesReturnsErrorWhenUpdateAffectsMultipleRows(t *testing.T) {
	t.Parallel()

	handle, recorder := openOracleRecordingDB(t)
	recorder.rowsAffected = 2

	row := connection.UpdateRow{
		Keys:   map[string]interface{}{"id": 7},
		Values: map[string]interface{}{"name": "new-name"},
	}
	target := &MySQLDB{conn: handle}

	err := target.ApplyChanges("users", connection.ChangeSet{
		Updates: []connection.UpdateRow{row},
	})
	switch {
	case err == nil:
		t.Fatal("期望 MySQL 更新影响多行时返回错误,实际为 nil")
	case !strings.Contains(err.Error(), "影响了 2 行"):
		t.Fatalf("错误信息应提示影响多行,实际=%v", err)
	}
}
// TestPostgresApplyChangesReturnsErrorWhenDeleteAffectsMultipleRows checks
// that PostgresDB.ApplyChanges returns an error containing "影响了 2 行" when the
// driver reports two affected rows for a single delete. It reuses the
// recording driver as the underlying connection.
func TestPostgresApplyChangesReturnsErrorWhenDeleteAffectsMultipleRows(t *testing.T) {
	t.Parallel()

	handle, recorder := openOracleRecordingDB(t)
	recorder.rowsAffected = 2

	target := &PostgresDB{conn: handle}
	pending := connection.ChangeSet{
		Deletes: []map[string]interface{}{
			{"id": 7},
		},
	}

	err := target.ApplyChanges("public.users", pending)
	if err == nil {
		t.Fatal("期望 PostgreSQL 删除影响多行时返回错误,实际为 nil")
	}
	if msg := err.Error(); !strings.Contains(msg, "影响了 2 行") {
		t.Fatalf("错误信息应提示影响多行,实际=%v", err)
	}
}