diff --git a/internal/app/methods_db.go b/internal/app/methods_db.go index e777a8f..688f03b 100644 --- a/internal/app/methods_db.go +++ b/internal/app/methods_db.go @@ -998,6 +998,7 @@ func (a *App) DBQueryMulti(config connection.ConnectionConfig, dbName string, qu }() runConfig := normalizeRunConfig(config, dbName) + resolvedDBType := resolveDDLDBType(runConfig) buildStatementExecutionFailedMessage := func(index int, err error, previousSuccessCount int) string { message := a.appText("db.backend.error.multi_statement_execution_failed", map[string]any{ "index": index, @@ -1059,7 +1060,7 @@ func (a *App) DBQueryMulti(config connection.ConnectionConfig, dbName string, qu break } } - useNativeMultiResult := shouldUseNativeMultiResultBatch(runConfig.Type, statements, allReadOnly) + useNativeMultiResult := shouldUseNativeMultiResultBatch(resolvedDBType, statements, allReadOnly) runMultiQuery := func(inst db.Database) ([]connection.ResultSetData, []string, error) { if !useNativeMultiResult { @@ -1226,7 +1227,7 @@ func (a *App) DBQueryMulti(config connection.ConnectionConfig, dbName string, qu isReadStmt := isReadOnlySQLQuery(runConfig.Type, stmt) tryQueryStmtFirst := shouldTryQueryResultFirst(runConfig.Type, stmt) if isReadStmt || tryQueryStmtFirst { - preferPlainReadQuery := isReadStmt && shouldPreferPlainReadQueryResult(runConfig.Type) + preferPlainReadQuery := isReadStmt && shouldPreferPlainReadQueryResult(resolvedDBType) var ( data []map[string]interface{} columns []string @@ -1417,8 +1418,13 @@ func shouldUseNativeMultiResultBatch(dbType string, statements []string, allRead } func shouldPreferPlainReadQueryResult(dbType string) bool { - switch resolveDDLDBType(connection.ConnectionConfig{Type: dbType}) { - case "postgres", "kingbase", "highgo", "vastbase", "opengauss", "gaussdb": + switch strings.ToLower(strings.TrimSpace(dbType)) { + case "postgres", "postgresql", + "kingbase", "kingbase8", "kingbasees", "kingbasev8", + "highgo", "vastbase", + "opengauss", "open_gauss", "open-gauss", + "gaussdb", "gauss_db", "gauss-db", + "dameng", "dm", "dm8": return true default: return false diff --git a/internal/app/methods_db_multi_test.go b/internal/app/methods_db_multi_test.go index 0273b5c..4462b2e 100644 --- a/internal/app/methods_db_multi_test.go +++ b/internal/app/methods_db_multi_test.go @@ -1878,6 +1878,67 @@ func TestDBQueryMultiPrefersPlainQueryForKingbaseReadResults(t *testing.T) { } } +func TestDBQueryMultiPrefersPlainQueryForDamengReadResults(t *testing.T) { + originalNewDatabaseFunc := newDatabaseFunc + t.Cleanup(func() { + newDatabaseFunc = originalNewDatabaseFunc + }) + + query := "SELECT * FROM PUB_TIMER" + nativeEmptyRowsResult := []connection.ResultSetData{{ + Rows: []map[string]interface{}{}, + Columns: []string{"ID", "NAME"}, + }} + baseDB := &fakeBatchWriteDB{ + queryMap: map[string][]map[string]interface{}{ + query: { + {"ID": 1, "NAME": "timer_a"}, + }, + }, + fieldMap: map[string][]string{ + query: {"ID", "NAME"}, + }, + multiResult: map[string][]connection.ResultSetData{ + query: nativeEmptyRowsResult, + }, + queryErr: map[string]error{}, + } + fakeDB := &fakeNativeMultiResultDB{fakeBatchWriteDB: baseDB} + newDatabaseFunc = func(dbType string) (db.Database, error) { + return fakeDB, nil + } + + app := NewAppWithSecretStore(secretstore.NewUnavailableStore("test")) + config := connection.ConnectionConfig{Type: "custom", Driver: "dm8", Host: "127.0.0.1", Port: 5236, User: "SYSDBA"} + + result := app.DBQueryMulti(config, "SYSDBA", query, "dameng-plain-query-result-test") + if !result.Success { + t.Fatalf("expected DBQueryMulti success, got failure: %s", result.Message) + } + if fakeDB.multiCalls != 0 { + t.Fatalf("expected dameng read query to skip top-level native multi-result path, got %d calls", fakeDB.multiCalls) + } + if baseDB.session == nil { + t.Fatal("expected DBQueryMulti to open a pinned session for dameng read query") + } + if baseDB.session.queryCalls != 1 { + t.Fatalf("expected dameng read query to use plain session query once, got %d calls", baseDB.session.queryCalls) + } + resultSets, ok := result.Data.([]connection.ResultSetData) + if !ok { + t.Fatalf("expected []connection.ResultSetData, got %T", result.Data) + } + if len(resultSets) != 1 { + t.Fatalf("expected one result set, got %#v", resultSets) + } + if !reflect.DeepEqual(resultSets[0].Columns, []string{"ID", "NAME"}) { + t.Fatalf("expected plain query columns, got %#v", resultSets[0].Columns) + } + if got := resultSets[0].Rows[0]["NAME"]; got != "timer_a" { + t.Fatalf("expected plain query SELECT result NAME=timer_a, got %#v", got) + } +} + func TestDBQueryMultiUsesPinnedSessionForSequentialFallback(t *testing.T) { originalNewDatabaseFunc := newDatabaseFunc t.Cleanup(func() { @@ -2265,3 +2326,48 @@ func TestDBQueryMultiTransactionalTreatsSelectIntoAsManagedWrite(t *testing.T) { } } } + +func TestExecuteManagedSQLTransactionStatementsPrefersPlainQueryForDamengReadResults(t *testing.T) { + query := "SELECT * FROM PUB_TIMER" + baseDB := &fakeBatchWriteDB{ + queryMap: map[string][]map[string]interface{}{ + query: { + {"ID": 1, "NAME": "timer_a"}, + }, + }, + fieldMap: map[string][]string{ + query: {"ID", "NAME"}, + }, + multiResult: map[string][]connection.ResultSetData{ + query: {{ + Rows: []map[string]interface{}{}, + Columns: []string{"ID", "NAME"}, + }}, + }, + queryErr: map[string]error{}, + } + session := &fakeBatchWriteSession{parent: baseDB} + + results, err := executeManagedSQLTransactionStatements( + context.Background(), + session, + connection.ConnectionConfig{Type: "custom", Driver: "dm8"}, + []string{query}, + nil, + ) + if err != nil { + t.Fatalf("expected executeManagedSQLTransactionStatements success, got %v", err) + } + if session.queryCalls != 1 { + t.Fatalf("expected dameng managed read query to use plain query once, got %d calls", session.queryCalls) + } + if len(results) != 1 { + t.Fatalf("expected one result set, got %#v", results) + } + if !reflect.DeepEqual(results[0].Columns, []string{"ID", "NAME"}) { + t.Fatalf("expected plain query columns, got %#v", results[0].Columns) + } + if got := results[0].Rows[0]["NAME"]; got != "timer_a" { + t.Fatalf("expected plain query SELECT result NAME=timer_a, got %#v", got) + } +} diff --git a/internal/app/methods_db_transaction.go b/internal/app/methods_db_transaction.go index eb1befc..0bd885e 100644 --- a/internal/app/methods_db_transaction.go +++ b/internal/app/methods_db_transaction.go @@ -265,6 +265,7 @@ func executeManagedSQLTransactionStatements(ctx context.Context, session db.Stat if text == nil { text = defaultDBBackendText } + resolvedDBType := resolveDDLDBType(runConfig) buildStatementExecutionFailedError := func(index int, err error) error { return fmt.Errorf("%s", text("db.backend.error.multi_statement_execution_failed", map[string]any{ "index": index, @@ -298,7 +299,15 @@ func executeManagedSQLTransactionStatements(ctx context.Context, session db.Stat usedMultiResult bool err error ) - if sessionMultiQueryMessageTarget != nil { + if isReadStmt && shouldPreferPlainReadQueryResult(resolvedDBType) { + if sessionQueryMessageTarget != nil { + data, columns, messages, err = sessionQueryMessageTarget.QueryContextWithMessages(ctx, stmt) + } else if sessionQueryTarget != nil { + data, columns, err = sessionQueryTarget.QueryContext(ctx, stmt) + } else { + err = buildTransactionQueryUnsupportedError() + } + } else if sessionMultiQueryMessageTarget != nil { statementResults, messages, err = sessionMultiQueryMessageTarget.QueryMultiContextWithMessages(ctx, stmt) usedMultiResult = true } else if sessionMultiQueryTarget != nil {