diff --git a/internal/db/oracle_applychanges_test.go b/internal/db/oracle_applychanges_test.go index b06041c..ce7024e 100644 --- a/internal/db/oracle_applychanges_test.go +++ b/internal/db/oracle_applychanges_test.go @@ -259,6 +259,72 @@ func TestOracleOpenTransactionExecerUsesPinnedSessionTransactionSQL(t *testing.T } } +func TestOracleApplyChangesUsesPinnedSessionTransactionSQL(t *testing.T) { + t.Parallel() + + dbConn, state := openOracleRecordingDB(t) + oracleDB := &OracleDB{conn: dbConn} + + changes := connection.ChangeSet{ + Updates: []connection.UpdateRow{{ + Keys: map[string]interface{}{ + "ID": 7, + }, + Values: map[string]interface{}{ + "NAME": "new-name", + }, + }}, + } + + if err := oracleDB.ApplyChanges("MYCIMLED.EDC_LOG", changes); err != nil { + t.Fatalf("ApplyChanges() unexpected error: %v", err) + } + if got := state.snapshotBeginCalls(); got != 0 { + t.Fatalf("expected Oracle ApplyChanges not to call database/sql Begin, got %d", got) + } + wantExecs := []string{ + `UPDATE "MYCIMLED"."EDC_LOG" SET "NAME" = :1 WHERE "ID" = :2`, + "COMMIT", + } + if got := state.snapshotExecQueries(); !reflect.DeepEqual(got, wantExecs) { + t.Fatalf("expected Oracle ApplyChanges pinned-session execs %#v, got %#v", wantExecs, got) + } +} + +func TestOracleApplyChangesRollsBackPinnedSessionOnError(t *testing.T) { + t.Parallel() + + dbConn, state := openOracleRecordingDB(t) + state.rowsAffected = 0 + oracleDB := &OracleDB{conn: dbConn} + + changes := connection.ChangeSet{ + Updates: []connection.UpdateRow{{ + Keys: map[string]interface{}{ + "ID": 7, + }, + Values: map[string]interface{}{ + "NAME": "new-name", + }, + }}, + } + + err := oracleDB.ApplyChanges("MYCIMLED.EDC_LOG", changes) + if err == nil { + t.Fatal("expected ApplyChanges to return update row-count error") + } + if got := state.snapshotBeginCalls(); got != 0 { + t.Fatalf("expected Oracle ApplyChanges not to call database/sql Begin, got %d", got) + } + wantExecs := []string{ + `UPDATE "MYCIMLED"."EDC_LOG" SET "NAME" = :1 WHERE "ID" = :2`, + "ROLLBACK", + } + if got := state.snapshotExecQueries(); !reflect.DeepEqual(got, wantExecs) { + t.Fatalf("expected Oracle ApplyChanges pinned-session rollback execs %#v, got %#v", wantExecs, got) + } +} + func TestOracleApplyChangesReturnsErrorWhenUpdateMatchesNoRows(t *testing.T) { t.Parallel() @@ -357,8 +423,8 @@ func TestOracleApplyChangesNormalizesTemporalStringsForUpdate(t *testing.T) { } executions := state.snapshotExecArgs() - if len(executions) != 1 { - t.Fatalf("期望执行 1 条更新,实际 %d 条", len(executions)) + if len(executions) == 0 { + t.Fatal("期望至少执行 1 条更新,实际没有执行") } args := executions[0] if len(args) != 2 { @@ -395,8 +461,8 @@ func TestOracleApplyChangesUsesUnquotedRowIDLocator(t *testing.T) { } executions := state.snapshotExecQueries() - if len(executions) != 1 { - t.Fatalf("期望执行 1 条更新,实际 %d 条", len(executions)) + if len(executions) == 0 { + t.Fatal("期望至少执行 1 条更新,实际没有执行") } query := executions[0] if !strings.Contains(query, "ROWID = :2") { diff --git a/internal/db/oracle_impl.go b/internal/db/oracle_impl.go index 459bfa0..409e3fa 100644 --- a/internal/db/oracle_impl.go +++ b/internal/db/oracle_impl.go @@ -823,7 +823,7 @@ func parseOracleTemporalString(raw string) (time.Time, bool) { return time.Time{}, false } -func (o *OracleDB) ApplyChanges(tableName string, changes connection.ChangeSet) error { +func (o *OracleDB) ApplyChanges(tableName string, changes connection.ChangeSet) (err error) { if o.conn == nil { return fmt.Errorf("连接未打开") } @@ -833,11 +833,26 @@ func (o *OracleDB) ApplyChanges(tableName string, changes connection.ChangeSet) return err } - tx, err := o.conn.Begin() + ctx := context.Background() + conn, err := o.conn.Conn(ctx) if err != nil { return err } - defer tx.Rollback() + defer func() { + if closeErr := conn.Close(); closeErr != nil && err == nil { + err = closeErr + } + }() + + transactionFinished := false + defer func() { + if transactionFinished { + return + } + if _, rollbackErr := conn.ExecContext(context.Background(), "ROLLBACK"); rollbackErr != nil { + logger.Warnf("Oracle 表格编辑事务回滚失败:table=%s err=%v", tableName, rollbackErr) + } + }() quoteIdent := func(name string) string { n := strings.TrimSpace(name) @@ -888,7 +903,7 @@ func (o *OracleDB) ApplyChanges(tableName string, changes connection.ChangeSet) continue } query := fmt.Sprintf("DELETE FROM %s WHERE %s", qualifiedTable, strings.Join(wheres, " AND ")) - res, err := tx.Exec(query, args...) + res, err := conn.ExecContext(ctx, query, args...) if err != nil { return fmt.Errorf("删除失败:%v", err) } @@ -921,7 +936,7 @@ func (o *OracleDB) ApplyChanges(tableName string, changes connection.ChangeSet) } query := fmt.Sprintf("UPDATE %s SET %s WHERE %s", qualifiedTable, strings.Join(sets, ", "), strings.Join(wheres, " AND ")) - res, err := tx.Exec(query, args...) + res, err := conn.ExecContext(ctx, query, args...) if err != nil { return fmt.Errorf("更新失败:%v", err) } @@ -949,7 +964,7 @@ func (o *OracleDB) ApplyChanges(tableName string, changes connection.ChangeSet) } query := fmt.Sprintf("INSERT INTO %s (%s) VALUES (%s)", qualifiedTable, strings.Join(cols, ", "), strings.Join(placeholders, ", ")) - res, err := tx.Exec(query, args...) + res, err := conn.ExecContext(ctx, query, args...) if err != nil { return fmt.Errorf("插入失败:%v", err) } @@ -958,7 +973,11 @@ func (o *OracleDB) ApplyChanges(tableName string, changes connection.ChangeSet) } } - return tx.Commit() + if _, err := conn.ExecContext(ctx, "COMMIT"); err != nil { + return fmt.Errorf("事务提交失败:%v", err) + } + transactionFinished = true + return nil } func (o *OracleDB) GetAllColumns(dbName string) ([]connection.ColumnDefinitionWithTable, error) {