diff --git a/internal/app/methods_db_multi_test.go b/internal/app/methods_db_multi_test.go index 9b7e164..8ede51d 100644 --- a/internal/app/methods_db_multi_test.go +++ b/internal/app/methods_db_multi_test.go @@ -683,6 +683,68 @@ func TestDBQueryMultiTransactionalUsesImplicitSessionTransactionForOracle(t *tes } } +func TestDBQueryMultiTransactionalUsesOracleImplicitSessionForOceanBaseOracleProtocol(t *testing.T) { + originalNewDatabaseFunc := newDatabaseFunc + originalVerifyDriverAgentRevisionFunc := verifyDriverAgentRevisionFunc + t.Cleanup(func() { + newDatabaseFunc = originalNewDatabaseFunc + verifyDriverAgentRevisionFunc = originalVerifyDriverAgentRevisionFunc + }) + + stmt := "UPDATE USERS SET NAME = 'new' WHERE ID = 1" + fakeDB := &fakeBatchWriteDB{ + execAffected: map[string]int64{ + stmt: 1, + }, + } + newDatabaseFunc = func(dbType string) (db.Database, error) { + return fakeDB, nil + } + verifyDriverAgentRevisionFunc = func(config connection.ConnectionConfig) error { + return nil + } + + app := NewAppWithSecretStore(secretstore.NewUnavailableStore("test")) + config := connection.ConnectionConfig{ + Type: "oceanbase", + Host: "127.0.0.1", + Port: 2881, + User: "app", + OceanBaseProtocol: "oracle", + } + + result := app.DBQueryMultiTransactional(config, "APP", stmt, "ob-oracle-tx-query") + if !result.Success { + t.Fatalf("expected OceanBase Oracle transactional query success, got failure: %s", result.Message) + } + if result.TransactionID == "" || !result.TransactionPending { + t.Fatalf("expected pending transaction metadata, got id=%q pending=%v", result.TransactionID, result.TransactionPending) + } + if fakeDB.session == nil { + t.Fatal("expected OceanBase Oracle transactional query to open a pinned Oracle-style session") + } + if fakeDB.session.closed { + t.Fatal("expected OceanBase Oracle transaction session to stay open before commit") + } + if len(fakeDB.execQueries) != 1 || fakeDB.execQueries[0] != stmt { + t.Fatalf("expected OceanBase Oracle to skip START TRANSACTION and execute only DML before commit, got %#v", fakeDB.execQueries) + } + + commitResult := app.DBCommitTransaction(result.TransactionID) + if !commitResult.Success { + t.Fatalf("expected OceanBase Oracle commit success, got failure: %s", commitResult.Message) + } + wantExecs := []string{stmt, "COMMIT"} + if len(fakeDB.execQueries) != len(wantExecs) { + t.Fatalf("expected OceanBase Oracle implicit transaction COMMIT on pinned session, got %#v", fakeDB.execQueries) + } + for i, want := range wantExecs { + if fakeDB.execQueries[i] != want { + t.Fatalf("expected exec query %d = %q, got %q", i, want, fakeDB.execQueries[i]) + } + } +} + func TestDBQueryMultiTransactionalOracleImplicitSessionOutlivesAppContextCancellation(t *testing.T) { originalNewDatabaseFunc := newDatabaseFunc t.Cleanup(func() { diff --git a/internal/app/methods_db_transaction.go b/internal/app/methods_db_transaction.go index 3001e3a..ba6c5d6 100644 --- a/internal/app/methods_db_transaction.go +++ b/internal/app/methods_db_transaction.go @@ -19,19 +19,22 @@ const sqlEditorTransactionFinishTimeout = 30 * time.Second // is called by the SQL editor UI. func (a *App) DBQueryMultiTransactional(config connection.ConnectionConfig, dbName string, query string, queryID string) connection.QueryResult { runConfig := normalizeRunConfig(config, dbName) + transactionDBType := resolveDDLDBType(runConfig) + transactionConfig := runConfig + transactionConfig.Type = transactionDBType if queryID == "" { queryID = generateQueryID() } - query = sanitizeSQLForPgLike(resolveDDLDBType(config), query) - if !shouldUseManagedSQLTransaction(runConfig.Type, query) { + query = sanitizeSQLForPgLike(transactionDBType, query) + if !shouldUseManagedSQLTransaction(transactionDBType, query) { return a.DBQueryMulti(config, dbName, query, queryID) } - beginSQL, commitSQL, rollbackSQL, hasTextTransaction := sqlFileBatchTransactionSQL(runConfig.Type) + beginSQL, commitSQL, rollbackSQL, hasTextTransaction := sqlFileBatchTransactionSQL(transactionDBType) implicitTextTransaction := false - if implicitCommitSQL, implicitRollbackSQL, ok := sqlEditorImplicitTransactionSQL(runConfig.Type); ok { + if implicitCommitSQL, implicitRollbackSQL, ok := sqlEditorImplicitTransactionSQL(transactionDBType); ok { commitSQL = implicitCommitSQL rollbackSQL = implicitRollbackSQL hasTextTransaction = true @@ -70,7 +73,7 @@ func (a *App) DBQueryMultiTransactional(config connection.ConnectionConfig, dbNa if !ok { return connection.QueryResult{ Success: false, - Message: fmt.Sprintf("当前数据源(%s)不支持 SQL 编辑器托管事务", runConfig.Type), + Message: fmt.Sprintf("当前数据源(%s)不支持 SQL 编辑器托管事务", transactionDBType), QueryID: queryID, } } @@ -97,7 +100,7 @@ func (a *App) DBQueryMultiTransactional(config connection.ConnectionConfig, dbNa if !hasTextTransaction { return connection.QueryResult{ Success: false, - Message: fmt.Sprintf("当前数据源(%s)不支持 SQL 编辑器托管事务", runConfig.Type), + Message: fmt.Sprintf("当前数据源(%s)不支持 SQL 编辑器托管事务", transactionDBType), QueryID: queryID, } } @@ -105,7 +108,7 @@ func (a *App) DBQueryMultiTransactional(config connection.ConnectionConfig, dbNa if !ok { return connection.QueryResult{ Success: false, - Message: fmt.Sprintf("当前数据源(%s)不支持 SQL 编辑器托管事务", runConfig.Type), + Message: fmt.Sprintf("当前数据源(%s)不支持 SQL 编辑器托管事务", transactionDBType), QueryID: queryID, } } @@ -137,7 +140,7 @@ func (a *App) DBQueryMultiTransactional(config connection.ConnectionConfig, dbNa } statements := splitSQLStatements(query) - resultSets, err := executeManagedSQLTransactionStatements(ctx, sessionExecer, runConfig, statements) + resultSets, err := executeManagedSQLTransactionStatements(ctx, sessionExecer, transactionConfig, statements) if err != nil { var rollbackErr error if transactor != nil { @@ -163,7 +166,7 @@ func (a *App) DBQueryMultiTransactional(config connection.ConnectionConfig, dbNa execer: sessionExecer, transactor: transactor, cancel: transactionCancel, - dbType: runConfig.Type, + dbType: transactionDBType, commitSQL: commitSQL, rollbackSQL: rollbackSQL, createdAt: time.Now(), diff --git a/internal/db/database.go b/internal/db/database.go index 7ca6c1c..1272599 100644 --- a/internal/db/database.go +++ b/internal/db/database.go @@ -128,8 +128,8 @@ type TransactionExecer interface { Rollback() error } -// TransactionExecerProvider is implemented by drivers that can expose a real -// database/sql transaction for long-running SQL editor managed transactions. +// TransactionExecerProvider is implemented by drivers that can expose a +// long-running SQL editor managed transaction. type TransactionExecerProvider interface { OpenTransactionExecer(ctx context.Context) (TransactionExecer, error) } @@ -200,6 +200,141 @@ func (e *sqlConnStatementExecer) Close() error { return e.conn.Close() } +type sqlConnTransactionExecer struct { + mu sync.Mutex + conn *sql.Conn + done bool + commitSQL string + rollbackSQL string +} + +func NewSQLConnTransactionExecer(conn *sql.Conn, commitSQL string, rollbackSQL string) TransactionExecer { + return &sqlConnTransactionExecer{ + conn: conn, + commitSQL: strings.TrimSpace(commitSQL), + rollbackSQL: strings.TrimSpace(rollbackSQL), + } +} + +func (e *sqlConnTransactionExecer) activeConn() (*sql.Conn, error) { + if e == nil { + return nil, fmt.Errorf("连接未打开") + } + e.mu.Lock() + defer e.mu.Unlock() + if e.conn == nil { + return nil, fmt.Errorf("连接未打开") + } + if e.done { + return nil, fmt.Errorf("事务已结束") + } + return e.conn, nil +} + +func (e *sqlConnTransactionExecer) ExecContext(ctx context.Context, query string) (int64, error) { + conn, err := e.activeConn() + if err != nil { + return 0, err + } + res, err := conn.ExecContext(ctx, query) + if err != nil { + return 0, err + } + return res.RowsAffected() +} + +func (e *sqlConnTransactionExecer) Exec(query string) (int64, error) { + return e.ExecContext(context.Background(), query) +} + +func (e *sqlConnTransactionExecer) QueryContext(ctx context.Context, query string) ([]map[string]interface{}, []string, error) { + conn, err := e.activeConn() + if err != nil { + return nil, nil, err + } + rows, err := conn.QueryContext(ctx, query) + if err != nil { + return nil, nil, err + } + defer rows.Close() + return scanRows(rows) +} + +func (e *sqlConnTransactionExecer) Query(query string) ([]map[string]interface{}, []string, error) { + return e.QueryContext(context.Background(), query) +} + +func (e *sqlConnTransactionExecer) QueryMultiContext(ctx context.Context, query string) ([]connection.ResultSetData, error) { + conn, err := e.activeConn() + if err != nil { + return nil, err + } + rows, err := conn.QueryContext(ctx, query) + if err != nil { + return nil, err + } + defer rows.Close() + return scanMultiRows(rows) +} + +func (e *sqlConnTransactionExecer) QueryMulti(query string) ([]connection.ResultSetData, error) { + return e.QueryMultiContext(context.Background(), query) +} + +func (e *sqlConnTransactionExecer) finish(sqlText string) error { + if e == nil { + return nil + } + e.mu.Lock() + if e.conn == nil || e.done { + e.mu.Unlock() + return nil + } + conn := e.conn + e.done = true + e.mu.Unlock() + if strings.TrimSpace(sqlText) == "" { + return nil + } + _, err := conn.ExecContext(context.Background(), sqlText) + return err +} + +func (e *sqlConnTransactionExecer) Commit() error { + return e.finish(e.commitSQL) +} + +func (e *sqlConnTransactionExecer) Rollback() error { + return e.finish(e.rollbackSQL) +} + +func (e *sqlConnTransactionExecer) Close() error { + if e == nil { + return nil + } + e.mu.Lock() + if e.conn == nil { + e.mu.Unlock() + return nil + } + conn := e.conn + shouldRollback := !e.done && e.rollbackSQL != "" + rollbackSQL := e.rollbackSQL + e.conn = nil + e.done = true + e.mu.Unlock() + + var rollbackErr error + if shouldRollback { + _, rollbackErr = conn.ExecContext(context.Background(), rollbackSQL) + } + closeErr := conn.Close() + if rollbackErr != nil { + return rollbackErr + } + return closeErr +} + type sqlTxStatementExecer struct { mu sync.Mutex tx *sql.Tx diff --git a/internal/db/oracle_applychanges_test.go b/internal/db/oracle_applychanges_test.go index 0f6ac58..b06041c 100644 --- a/internal/db/oracle_applychanges_test.go +++ b/internal/db/oracle_applychanges_test.go @@ -6,6 +6,7 @@ import ( "database/sql/driver" "fmt" "io" + "reflect" "strings" "sync" "testing" @@ -28,6 +29,7 @@ type oracleRecordingState struct { execQueries []string execArgs [][]driver.NamedValue queries []string + beginCalls int rowsAffected int64 queryResults map[string]oracleRecordingQueryResult queryError error @@ -61,6 +63,12 @@ func (s *oracleRecordingState) snapshotQueries() []string { return append([]string(nil), s.queries...) } +func (s *oracleRecordingState) snapshotBeginCalls() int { + s.mu.Lock() + defer s.mu.Unlock() + return s.beginCalls +} + type oracleRecordingDriver struct{} func (oracleRecordingDriver) Open(name string) (driver.Conn, error) { @@ -83,7 +91,12 @@ func (c *oracleRecordingConn) Prepare(query string) (driver.Stmt, error) { func (c *oracleRecordingConn) Close() error { return nil } -func (c *oracleRecordingConn) Begin() (driver.Tx, error) { return oracleRecordingTx{}, nil } +func (c *oracleRecordingConn) Begin() (driver.Tx, error) { + c.state.mu.Lock() + c.state.beginCalls++ + c.state.mu.Unlock() + return oracleRecordingTx{}, nil +} func (c *oracleRecordingConn) ExecContext(_ context.Context, query string, args []driver.NamedValue) (driver.Result, error) { c.state.mu.Lock() @@ -191,6 +204,61 @@ func openOracleRecordingDB(t *testing.T) (*sql.DB, *oracleRecordingState) { return dbConn, state } +func TestOracleOpenTransactionExecerUsesPinnedSessionTransactionSQL(t *testing.T) { + t.Parallel() + + for _, tt := range []struct { + name string + finish func(TransactionExecer) error + wantFinalSQL string + }{ + { + name: "commit", + finish: func(tx TransactionExecer) error { + return tx.Commit() + }, + wantFinalSQL: "COMMIT", + }, + { + name: "rollback", + finish: func(tx TransactionExecer) error { + return tx.Rollback() + }, + wantFinalSQL: "ROLLBACK", + }, + } { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + + dbConn, state := openOracleRecordingDB(t) + oracleDB := &OracleDB{conn: dbConn} + stmt := "UPDATE USERS SET NAME = 'new' WHERE ID = 1" + + tx, err := oracleDB.OpenTransactionExecer(context.Background()) + if err != nil { + t.Fatalf("OpenTransactionExecer returned error: %v", err) + } + if _, err := tx.ExecContext(context.Background(), stmt); err != nil { + t.Fatalf("ExecContext returned error: %v", err) + } + if err := tt.finish(tx); err != nil { + t.Fatalf("finish returned error: %v", err) + } + if err := tx.Close(); err != nil { + t.Fatalf("Close returned error: %v", err) + } + + if got := state.snapshotBeginCalls(); got != 0 { + t.Fatalf("expected Oracle transaction execer not to call database/sql Begin, got %d", got) + } + wantExecs := []string{stmt, tt.wantFinalSQL} + if got := state.snapshotExecQueries(); !reflect.DeepEqual(got, wantExecs) { + t.Fatalf("expected exec queries %#v, got %#v", wantExecs, got) + } + }) + } +} + func TestOracleApplyChangesReturnsErrorWhenUpdateMatchesNoRows(t *testing.T) { t.Parallel() diff --git a/internal/db/oracle_impl.go b/internal/db/oracle_impl.go index 411c7be..459bfa0 100644 --- a/internal/db/oracle_impl.go +++ b/internal/db/oracle_impl.go @@ -258,11 +258,11 @@ func (o *OracleDB) OpenTransactionExecer(ctx context.Context) (TransactionExecer if o.conn == nil { return nil, fmt.Errorf("连接未打开") } - tx, err := o.conn.BeginTx(ctx, nil) + conn, err := o.conn.Conn(ctx) if err != nil { return nil, err } - return NewSQLTxStatementExecer(tx), nil + return NewSQLConnTransactionExecer(conn, "COMMIT", "ROLLBACK"), nil } func (o *OracleDB) OpenSessionExecer(ctx context.Context) (StatementExecer, error) {