From c8fe90cbee79cebf8193b6020a50b761049766cf Mon Sep 17 00:00:00 2001 From: Syngnat Date: Thu, 18 Jun 2026 11:32:08 +0800 Subject: [PATCH] =?UTF-8?q?=E2=9A=A1=EF=B8=8F=20perf(import-export):=20?= =?UTF-8?q?=E9=99=8D=E4=BD=8E=20OceanBase=20=E5=AF=BC=E5=87=BA=E9=93=BE?= =?UTF-8?q?=E8=B7=AF=E5=86=85=E5=AD=98=E5=8D=A0=E7=94=A8?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - 为 optional driver-agent 补齐 streamQuery 分片协议,避免大结果集整批缓冲到内存 - 在 OceanBase 整表导出和查询结果导出前强校验 driver-agent revision,旧版代理直接拦截并提示重装 - 为 driver-agent 增加大查询和流式导出完成后的 GC/FreeOSMemory 回收逻辑 - 补充导出前校验、流式分片消费和 agent 内存回收的定向测试 - 更新 driver-agent revisions 以匹配新的流式导出协议 --- cmd/optional-driver-agent/main.go | 254 +++++++++++++++++- cmd/optional-driver-agent/main_test.go | 190 +++++++++++++ internal/app/methods_driver.go | 1 + internal/app/methods_file.go | 27 ++ internal/app/methods_file_export_test.go | 43 +++ internal/db/driver_agent_revisions_gen.go | 40 +-- internal/db/optional_driver_agent_impl.go | 197 ++++++++++++++ .../db/optional_driver_agent_impl_test.go | 71 +++++ 8 files changed, 801 insertions(+), 22 deletions(-) diff --git a/cmd/optional-driver-agent/main.go b/cmd/optional-driver-agent/main.go index 4915947..c9081f0 100644 --- a/cmd/optional-driver-agent/main.go +++ b/cmd/optional-driver-agent/main.go @@ -7,8 +7,11 @@ import ( "fmt" "os" "reflect" + "runtime" + "runtime/debug" "strconv" "strings" + "sync/atomic" "time" "GoNavi-Wails/internal/connection" @@ -33,6 +36,7 @@ type agentResponse struct { Error string `json:"error,omitempty"` Data interface{} `json:"data,omitempty"` Fields []string `json:"fields,omitempty"` + ChunkType string `json:"chunkType,omitempty"` RowsAffected int64 `json:"rowsAffected,omitempty"` } @@ -44,6 +48,7 @@ const ( agentMethodOpenSession = "openSession" agentMethodCloseSession = "closeSession" agentMethodQuery = "query" + agentMethodStreamQuery = "streamQuery" agentMethodExec = "exec" agentMethodGetDatabases = "getDatabases" agentMethodGetTables = "getTables" @@ -58,9 +63,27 @@ const ( const legacyClickHouseDefaultTimeout = 2 * time.Hour +const ( + agentChunkColumns = "columns" + agentChunkRows = "rows" + agentChunkDone = "done" + agentStreamBatchSize = 256 + agentMemoryTrimRowsThreshold = 100000 + agentMemoryTrimMinInterval = 3 * time.Second +) + var ( - agentDriverType string - agentDatabaseFactory func() db.Database + agentDriverType string + agentDatabaseFactory func() db.Database + agentMemoryTrimRunning atomic.Bool + agentMemoryTrimLastAt atomic.Int64 + runAgentMemoryTrimAsync = func(fn func()) { + go fn() + } + agentMemoryTrimFn = func() { + runtime.GC() + debug.FreeOSMemory() + } ) type agentRuntime struct { @@ -99,11 +122,22 @@ func main() { continue } + if strings.TrimSpace(req.Method) == agentMethodStreamQuery { + if err := handleStreamRequest(runtimeState, req, writer); err != nil { + fmt.Fprintf(os.Stderr, "写入流式响应失败:%v\n", err) + break + } + continue + } + resp := handleRequest(runtimeState, req) if err := writeResponse(writer, resp); err != nil { fmt.Fprintf(os.Stderr, "写入响应失败:%v\n", err) break } + if strings.TrimSpace(req.Method) == agentMethodQuery { + maybeReleaseAgentMemory("query-response", countAgentResponseRows(resp.Data)) + } } runtimeState.close() @@ -288,6 +322,108 @@ func handleRequest(runtimeState *agentRuntime, req agentRequest) agentResponse { return resp } +type agentStreamResponseWriter struct { + writer *bufio.Writer + requestID int64 + columns []string + rows [][]interface{} + rowCount int64 +} + +func newAgentStreamResponseWriter(writer *bufio.Writer, requestID int64) *agentStreamResponseWriter { + return &agentStreamResponseWriter{ + writer: writer, + requestID: requestID, + } +} + +func (w *agentStreamResponseWriter) SetColumns(columns []string) error { + w.columns = append([]string(nil), columns...) + return writeResponse(w.writer, agentResponse{ + ID: w.requestID, + Success: true, + ChunkType: agentChunkColumns, + Fields: w.columns, + }) +} + +func (w *agentStreamResponseWriter) ConsumeRow(row map[string]interface{}) error { + if len(w.columns) == 0 { + return fmt.Errorf("流式查询缺少列定义") + } + values := make([]interface{}, len(w.columns)) + for idx, column := range w.columns { + values[idx] = row[column] + } + return w.ConsumeRowValues(values) +} + +func (w *agentStreamResponseWriter) ConsumeRowValues(values []interface{}) error { + row := append([]interface{}(nil), values...) + w.rows = append(w.rows, row) + w.rowCount++ + if len(w.rows) < agentStreamBatchSize { + return nil + } + return w.flushRows() +} + +func (w *agentStreamResponseWriter) flushRows() error { + if len(w.rows) == 0 { + return nil + } + rows := w.rows + w.rows = nil + return writeResponse(w.writer, agentResponse{ + ID: w.requestID, + Success: true, + ChunkType: agentChunkRows, + Data: rows, + }) +} + +func (w *agentStreamResponseWriter) finish() error { + return w.flushRows() +} + +func handleStreamRequest(runtimeState *agentRuntime, req agentRequest, writer *bufio.Writer) error { + resp := agentResponse{ID: req.ID, Success: true} + if runtimeState.inst == nil { + return writeResponse(writer, fail(resp, "connection not open")) + } + + streamWriter := newAgentStreamResponseWriter(writer, req.ID) + if session, ok, err := runtimeState.session(req.SessionID); err != nil { + return writeResponse(writer, fail(resp, err.Error())) + } else if ok { + if err := streamStatementWithOptionalTimeout(session, req.Query, req.TimeoutMs, streamWriter); err != nil { + _ = streamWriter.finish() + return writeResponse(writer, fail(resp, err.Error())) + } + if err := streamWriter.finish(); err != nil { + return err + } + if err := writeResponse(writer, agentResponse{ID: req.ID, Success: true, ChunkType: agentChunkDone}); err != nil { + return err + } + maybeReleaseAgentMemory("stream-query-session", streamWriter.rowCount) + return nil + } + + if err := streamDatabaseWithOptionalTimeout(runtimeState.inst, req.Query, req.TimeoutMs, streamWriter); err != nil { + _ = streamWriter.finish() + return writeResponse(writer, fail(resp, err.Error())) + } + if err := streamWriter.finish(); err != nil { + return err + } + if err := writeResponse(writer, agentResponse{ID: req.ID, Success: true, ChunkType: agentChunkDone}); err != nil { + return err + } + maybeReleaseAgentMemory("stream-query-db", streamWriter.rowCount) + return nil +} + func (r *agentRuntime) nextID() string { r.ensureSessionMap() r.nextSessionID++ @@ -459,6 +595,82 @@ func queryStatementWithOptionalTimeout(inst db.StatementExecer, query string, ti return queryWithOptionalTimeout(queryRunner, query, timeoutMs) } +func streamWithOptionalTimeout(inst db.StreamQueryExecer, query string, timeoutMs int64, consumer db.QueryStreamConsumer) error { + effectiveTimeoutMs := timeoutMs + if effectiveTimeoutMs <= 0 && strings.EqualFold(strings.TrimSpace(agentDriverType), "clickhouse") { + effectiveTimeoutMs = int64(legacyClickHouseDefaultTimeout / time.Millisecond) + } + if effectiveTimeoutMs <= 0 { + return inst.StreamQuery(query, consumer) + } + ctx, cancel := context.WithTimeout(context.Background(), time.Duration(effectiveTimeoutMs)*time.Millisecond) + defer cancel() + return inst.StreamQueryContext(ctx, query, consumer) +} + +func streamBufferedQueryResult(fields []string, data []map[string]interface{}, consumer db.QueryStreamConsumer) error { + if err := consumer.SetColumns(fields); err != nil { + return err + } + if valueConsumer, ok := consumer.(db.QueryStreamValueConsumer); ok { + for _, row := range data { + values := make([]interface{}, len(fields)) + for idx, field := range fields { + values[idx] = row[field] + } + if err := valueConsumer.ConsumeRowValues(values); err != nil { + return err + } + } + return nil + } + for _, row := range data { + if err := consumer.ConsumeRow(row); err != nil { + return err + } + } + return nil +} + +func streamStatementWithOptionalTimeout(inst db.StatementExecer, query string, timeoutMs int64, consumer db.QueryStreamConsumer) error { + if streamer, ok := inst.(db.StreamQueryExecer); ok { + return streamWithOptionalTimeout(streamer, query, timeoutMs, consumer) + } + data, fields, err := queryStatementWithOptionalTimeout(inst, query, timeoutMs) + if err != nil { + return err + } + return streamBufferedQueryResult(fields, data, consumer) +} + +func streamDatabaseWithOptionalTimeout(inst db.Database, query string, timeoutMs int64, consumer db.QueryStreamConsumer) error { + if streamer, ok := inst.(db.StreamQueryExecer); ok { + return streamWithOptionalTimeout(streamer, query, timeoutMs, consumer) + } + if provider, ok := inst.(db.SessionExecerProvider); ok { + openCtx := context.Background() + var cancel context.CancelFunc + effectiveTimeoutMs := timeoutMs + if effectiveTimeoutMs <= 0 && strings.EqualFold(strings.TrimSpace(agentDriverType), "clickhouse") { + effectiveTimeoutMs = int64(legacyClickHouseDefaultTimeout / time.Millisecond) + } + if effectiveTimeoutMs > 0 { + openCtx, cancel = context.WithTimeout(context.Background(), time.Duration(effectiveTimeoutMs)*time.Millisecond) + defer cancel() + } + session, err := provider.OpenSessionExecer(openCtx) + if err == nil { + defer session.Close() + return streamStatementWithOptionalTimeout(session, query, timeoutMs, consumer) + } + } + data, fields, err := queryWithOptionalTimeout(inst, query, timeoutMs) + if err != nil { + return err + } + return streamBufferedQueryResult(fields, data, consumer) +} + func execWithOptionalTimeout(inst agentExecRunner, query string, timeoutMs int64) (int64, error) { effectiveTimeoutMs := timeoutMs if effectiveTimeoutMs <= 0 && strings.EqualFold(strings.TrimSpace(agentDriverType), "clickhouse") { @@ -478,3 +690,41 @@ func execWithOptionalTimeout(inst agentExecRunner, query string, timeoutMs int64 func execStatementWithOptionalTimeout(inst db.StatementExecer, query string, timeoutMs int64) (int64, error) { return execWithOptionalTimeout(inst, query, timeoutMs) } + +func countAgentResponseRows(data interface{}) int64 { + rows, ok := data.([]map[string]interface{}) + if !ok { + return 0 + } + return int64(len(rows)) +} + +func maybeReleaseAgentMemory(reason string, rows int64) { + if rows < agentMemoryTrimRowsThreshold { + return + } + if !agentMemoryTrimRunning.CompareAndSwap(false, true) { + return + } + + runAgentMemoryTrimAsync(func() { + defer agentMemoryTrimRunning.Store(false) + if delay := nextAgentMemoryTrimDelay(); delay > 0 { + time.Sleep(delay) + } + agentMemoryTrimFn() + agentMemoryTrimLastAt.Store(time.Now().UnixNano()) + }) +} + +func nextAgentMemoryTrimDelay() time.Duration { + lastUnixNano := agentMemoryTrimLastAt.Load() + if lastUnixNano <= 0 { + return 0 + } + elapsed := time.Since(time.Unix(0, lastUnixNano)) + if elapsed >= agentMemoryTrimMinInterval { + return 0 + } + return agentMemoryTrimMinInterval - elapsed +} diff --git a/cmd/optional-driver-agent/main_test.go b/cmd/optional-driver-agent/main_test.go index da976a0..eac28d6 100644 --- a/cmd/optional-driver-agent/main_test.go +++ b/cmd/optional-driver-agent/main_test.go @@ -190,6 +190,64 @@ func (f *fakeAgentStatementSession) Close() error { return nil } +type fakeAgentStreamSession struct { + closed bool + streamCalls int + deadlineSet bool +} + +func (f *fakeAgentStreamSession) Exec(query string) (int64, error) { + return 0, nil +} + +func (f *fakeAgentStreamSession) ExecContext(ctx context.Context, query string) (int64, error) { + return 0, nil +} + +func (f *fakeAgentStreamSession) Close() error { + f.closed = true + return nil +} + +func (f *fakeAgentStreamSession) StreamQuery(query string, consumer db.QueryStreamConsumer) error { + return f.StreamQueryContext(context.Background(), query, consumer) +} + +func (f *fakeAgentStreamSession) StreamQueryContext(ctx context.Context, query string, consumer db.QueryStreamConsumer) error { + f.streamCalls++ + if _, ok := ctx.Deadline(); ok { + f.deadlineSet = true + } + if err := consumer.SetColumns([]string{"id", "name"}); err != nil { + return err + } + if valueConsumer, ok := consumer.(db.QueryStreamValueConsumer); ok { + if err := valueConsumer.ConsumeRowValues([]interface{}{1, "alice"}); err != nil { + return err + } + if err := valueConsumer.ConsumeRowValues([]interface{}{2, "bob"}); err != nil { + return err + } + return nil + } + if err := consumer.ConsumeRow(map[string]interface{}{"id": 1, "name": "alice"}); err != nil { + return err + } + return consumer.ConsumeRow(map[string]interface{}{"id": 2, "name": "bob"}) +} + +type fakeAgentSessionStreamDB struct { + fakeAgentTimeoutDB + session *fakeAgentStreamSession + openCalls int +} + +func (f *fakeAgentSessionStreamDB) OpenSessionExecer(ctx context.Context) (db.StatementExecer, error) { + f.openCalls++ + f.session = &fakeAgentStreamSession{} + return f.session, nil +} + func TestQueryWithOptionalTimeout_UsesQueryContext(t *testing.T) { fake := &fakeAgentTimeoutDB{} data, fields, err := queryWithOptionalTimeout(fake, "SELECT 1", int64((2 * time.Second).Milliseconds())) @@ -306,3 +364,135 @@ func TestHandleRequest_UsesPinnedSessionForSessionScopedQueryAndExec(t *testing. t.Fatal("expected pinned session to close") } } + +func TestHandleStreamRequest_UsesSessionStreamerAndWritesChunks(t *testing.T) { + old := agentDriverType + originalAsync := runAgentMemoryTrimAsync + originalTrim := agentMemoryTrimFn + originalLastAt := agentMemoryTrimLastAt.Load() + defer func() { agentDriverType = old }() + defer func() { + runAgentMemoryTrimAsync = originalAsync + agentMemoryTrimFn = originalTrim + agentMemoryTrimRunning.Store(false) + agentMemoryTrimLastAt.Store(originalLastAt) + }() + agentDriverType = "oceanbase" + agentMemoryTrimRunning.Store(false) + agentMemoryTrimLastAt.Store(0) + + fake := &fakeAgentSessionStreamDB{} + runtimeState := &agentRuntime{ + inst: fake, + sessions: make(map[string]db.StatementExecer), + } + + trimmed := 0 + runAgentMemoryTrimAsync = func(fn func()) { + fn() + } + agentMemoryTrimFn = func() { + trimmed++ + } + + var out bytes.Buffer + writer := bufio.NewWriter(&out) + if err := handleStreamRequest(runtimeState, agentRequest{ + ID: 9, + Method: agentMethodStreamQuery, + Query: "SELECT * FROM person_info", + TimeoutMs: int64((2 * time.Second).Milliseconds()), + }, writer); err != nil { + t.Fatalf("handleStreamRequest 返回错误: %v", err) + } + + if fake.openCalls != 1 { + t.Fatalf("expected OpenSessionExecer called once, got %d", fake.openCalls) + } + if fake.session == nil || fake.session.streamCalls != 1 { + t.Fatalf("expected session streamer used once, session=%#v", fake.session) + } + if !fake.session.deadlineSet { + t.Fatal("expected stream query context deadline to be set") + } + if !fake.session.closed { + t.Fatal("expected session to close after streaming") + } + if fake.queryCalled || fake.queryContextCalled { + t.Fatalf("unexpected fallback query path, Query=%v QueryContext=%v", fake.queryCalled, fake.queryContextCalled) + } + + lines := strings.Split(strings.TrimSpace(out.String()), "\n") + if len(lines) != 3 { + t.Fatalf("expected 3 stream responses, got %d: %q", len(lines), out.String()) + } + + var columnsResp struct { + Success bool `json:"success"` + ChunkType string `json:"chunkType"` + Fields []string `json:"fields"` + } + if err := json.Unmarshal([]byte(lines[0]), &columnsResp); err != nil { + t.Fatalf("decode columns response failed: %v", err) + } + if !columnsResp.Success || columnsResp.ChunkType != agentChunkColumns || len(columnsResp.Fields) != 2 { + t.Fatalf("unexpected columns response: %#v", columnsResp) + } + + var rowsResp struct { + Success bool `json:"success"` + ChunkType string `json:"chunkType"` + Data [][]interface{} `json:"data"` + } + if err := json.Unmarshal([]byte(lines[1]), &rowsResp); err != nil { + t.Fatalf("decode rows response failed: %v", err) + } + if !rowsResp.Success || rowsResp.ChunkType != agentChunkRows || len(rowsResp.Data) != 2 { + t.Fatalf("unexpected rows response: %#v", rowsResp) + } + if got := rowsResp.Data[1][1]; got != "bob" { + t.Fatalf("unexpected streamed row payload: %v", rowsResp.Data) + } + + var doneResp struct { + Success bool `json:"success"` + ChunkType string `json:"chunkType"` + } + if err := json.Unmarshal([]byte(lines[2]), &doneResp); err != nil { + t.Fatalf("decode done response failed: %v", err) + } + if !doneResp.Success || doneResp.ChunkType != agentChunkDone { + t.Fatalf("unexpected done response: %#v", doneResp) + } + if trimmed != 0 { + t.Fatalf("小流式任务不应触发内存回收,got=%d", trimmed) + } +} + +func TestMaybeReleaseAgentMemory_TriggersTrimForLargeJobs(t *testing.T) { + originalAsync := runAgentMemoryTrimAsync + originalTrim := agentMemoryTrimFn + originalLastAt := agentMemoryTrimLastAt.Load() + t.Cleanup(func() { + runAgentMemoryTrimAsync = originalAsync + agentMemoryTrimFn = originalTrim + agentMemoryTrimRunning.Store(false) + agentMemoryTrimLastAt.Store(originalLastAt) + }) + + agentMemoryTrimRunning.Store(false) + agentMemoryTrimLastAt.Store(0) + triggered := 0 + runAgentMemoryTrimAsync = func(fn func()) { + fn() + } + agentMemoryTrimFn = func() { + triggered++ + } + + maybeReleaseAgentMemory("test-large-query", agentMemoryTrimRowsThreshold) + + if triggered != 1 { + t.Fatalf("大查询完成后应触发一次内存回收,got=%d", triggered) + } +} diff --git a/internal/app/methods_driver.go b/internal/app/methods_driver.go index f2106d6..56deff5 100644 --- a/internal/app/methods_driver.go +++ b/internal/app/methods_driver.go @@ -418,6 +418,7 @@ var ( var optionalDriverSourceBuildTimeout = 8 * time.Minute var validateOptionalDriverAgentExecutableFunc = db.ValidateOptionalDriverAgentExecutable +var resolveOptionalDriverAgentExecutablePathFunc = db.ResolveOptionalDriverAgentExecutablePath type driverVersionWarmupState struct { Running bool diff --git a/internal/app/methods_file.go b/internal/app/methods_file.go index 3640844..ca0b924 100644 --- a/internal/app/methods_file.go +++ b/internal/app/methods_file.go @@ -342,6 +342,23 @@ func tryResolveExportTableTotalRows(dbInst db.Database, config connection.Connec return resolveExportTotalRowsFromRows(rows) } +func verifyOptionalDriverAgentReadyForExport(config connection.ConnectionConfig) error { + driverType := normalizeDriverType(config.Type) + if !db.IsOptionalGoDriver(driverType) { + return nil + } + + executablePath, err := resolveOptionalDriverAgentExecutablePathFunc("", driverType) + if err != nil { + return err + } + if _, err := verifyInstalledOptionalDriverAgentRevision(driverType, executablePath); err != nil { + displayName := resolveDriverDisplayName(driverDefinition{Type: driverType}) + return fmt.Errorf("当前导出依赖最新的 %s driver-agent 流式协议;为避免大结果集回退到高内存缓冲模式,请在驱动管理中重装后重试:%w", displayName, err) + } + return nil +} + var exportFileNameSanitizer = strings.NewReplacer( "/", "_", "\\", "_", @@ -2249,6 +2266,11 @@ func (a *App) ExportTable(config connection.ConnectionConfig, dbName string, tab func (a *App) ExportTableWithOptions(config connection.ConnectionConfig, dbName string, tableName string, options ExportFileOptions) connection.QueryResult { options = normalizeExportFileOptions("", options) format := options.Format + if format != "sql" { + if err := verifyOptionalDriverAgentReadyForExport(config); err != nil { + return connection.QueryResult{Success: false, Message: err.Error()} + } + } filename, err := runtime.SaveFileDialog(a.ctx, runtime.SaveDialogOptions{ Title: fmt.Sprintf("Export %s", tableName), DefaultFilename: fmt.Sprintf("%s.%s", tableName, format), @@ -3656,6 +3678,11 @@ func (a *App) ExportQueryWithOptions(config connection.ConnectionConfig, dbName } options = normalizeExportFileOptions("", options) format := options.Format + if format != "sql" { + if err := verifyOptionalDriverAgentReadyForExport(config); err != nil { + return connection.QueryResult{Success: false, Message: err.Error()} + } + } filename, err := runtime.SaveFileDialog(a.ctx, runtime.SaveDialogOptions{ Title: "Export Query Result", diff --git a/internal/app/methods_file_export_test.go b/internal/app/methods_file_export_test.go index 0bdf07f..2fb3645 100644 --- a/internal/app/methods_file_export_test.go +++ b/internal/app/methods_file_export_test.go @@ -475,6 +475,49 @@ func TestTryResolveExportTableTotalRows_UsesCountQuery(t *testing.T) { } } +func TestVerifyOptionalDriverAgentReadyForExport_RejectsStaleAgent(t *testing.T) { + originalProbe := optionalDriverAgentMetadataProbe + originalResolvePath := resolveOptionalDriverAgentExecutablePathFunc + t.Cleanup(func() { + optionalDriverAgentMetadataProbe = originalProbe + resolveOptionalDriverAgentExecutablePathFunc = originalResolvePath + }) + + resolveOptionalDriverAgentExecutablePathFunc = func(downloadDir string, driverType string) (string, error) { + return "/tmp/oceanbase-driver-agent", nil + } + optionalDriverAgentMetadataProbe = func(driverType string, executablePath string) (db.OptionalDriverAgentMetadata, error) { + return db.OptionalDriverAgentMetadata{ + DriverType: driverType, + AgentRevision: "src-stale-agent", + }, nil + } + + err := verifyOptionalDriverAgentReadyForExport(connection.ConnectionConfig{Type: "oceanbase"}) + if err == nil { + t.Fatal("预期旧版 OceanBase driver-agent 被导出前校验拦截") + } + if !strings.Contains(err.Error(), "流式协议") { + t.Fatalf("错误信息应说明需要流式协议,got=%q", err.Error()) + } +} + +func TestVerifyOptionalDriverAgentReadyForExport_SkipsBuiltInDriver(t *testing.T) { + originalResolvePath := resolveOptionalDriverAgentExecutablePathFunc + t.Cleanup(func() { + resolveOptionalDriverAgentExecutablePathFunc = originalResolvePath + }) + + resolveOptionalDriverAgentExecutablePathFunc = func(downloadDir string, driverType string) (string, error) { + t.Fatalf("内置驱动导出不应探测 optional driver-agent 路径") + return "", nil + } + + if err := verifyOptionalDriverAgentReadyForExport(connection.ConnectionConfig{Type: "mysql"}); err != nil { + t.Fatalf("内置驱动导出不应被 optional driver-agent 校验阻断: %v", err) + } +} + func TestExportQueryResultToFile_UsesStreamQueryPath(t *testing.T) { f, err := os.CreateTemp("", "gonavi-export-stream-*.csv") if err != nil { diff --git a/internal/db/driver_agent_revisions_gen.go b/internal/db/driver_agent_revisions_gen.go index e8e25e6..01d1c81 100644 --- a/internal/db/driver_agent_revisions_gen.go +++ b/internal/db/driver_agent_revisions_gen.go @@ -4,25 +4,25 @@ package db func init() { optionalDriverAgentRevisions = map[string]string{ - "mariadb": "src-0a4176f4b5743323", - "oceanbase": "src-7cb0f2c4dc0510a5", - "diros": "src-cc11b882e28fa5d4", - "starrocks": "src-83a6d81c91c7f5c8", - "sphinx": "src-a70c2cd4d223dac2", - "sqlserver": "src-6d5cf334034bce41", - "sqlite": "src-762863d48f653b89", - "duckdb": "src-df5d60ebb175bbbc", - "dameng": "src-596bebeaa016fc74", - "kingbase": "src-2e5a1337b0405c57", - "highgo": "src-5a29a1d3685eb6b4", - "vastbase": "src-e3cfef65512feb23", - "opengauss": "src-58227ba3bc1ec894", - "gaussdb": "src-1458564993a9d455", - "iris": "src-1b072c57af08bec4", - "mongodb": "src-57fdd8bfebdcd46e", - "tdengine": "src-939715f94df1ec9c", - "iotdb": "src-473c39891f926db2", - "clickhouse": "src-482d62ed565b3e69", - "elasticsearch": "src-2fb00b94d7067c56", + "mariadb": "src-b23e2ce1581a5064", + "oceanbase": "src-5067dbdf0ca7b9c4", + "diros": "src-db43faca6bf15d9b", + "starrocks": "src-01e9f06c0fab09d5", + "sphinx": "src-38ee5cae952cc809", + "sqlserver": "src-7a87f6deb816f110", + "sqlite": "src-d3d439cd788880e2", + "duckdb": "src-b11506b8706bfb73", + "dameng": "src-1638124bfd7fce09", + "kingbase": "src-fb3a404cf4eb1bd9", + "highgo": "src-72fe51afa884f6bc", + "vastbase": "src-3d48607603bfd8b7", + "opengauss": "src-709acf442f016e30", + "gaussdb": "src-f6beccc924d71031", + "iris": "src-9ebf5b970a73b341", + "mongodb": "src-367d11cd04e982c1", + "tdengine": "src-3c13c42f18ba01e1", + "iotdb": "src-5ba9da13c6a272f9", + "clickhouse": "src-99c8babfefdf142c", + "elasticsearch": "src-36b2e2b5f49db9d1", } } diff --git a/internal/db/optional_driver_agent_impl.go b/internal/db/optional_driver_agent_impl.go index 5c65da5..d9362d8 100644 --- a/internal/db/optional_driver_agent_impl.go +++ b/internal/db/optional_driver_agent_impl.go @@ -2,6 +2,7 @@ package db import ( "bufio" + "bytes" "context" "encoding/json" "errors" @@ -28,6 +29,7 @@ const ( optionalAgentMethodOpenSession = "openSession" optionalAgentMethodCloseSession = "closeSession" optionalAgentMethodQuery = "query" + optionalAgentMethodStreamQuery = "streamQuery" optionalAgentMethodExec = "exec" optionalAgentMethodGetDatabases = "getDatabases" optionalAgentMethodGetTables = "getTables" @@ -42,6 +44,12 @@ const ( optionalAgentMetadataProbeTimeout = 5 * time.Second ) +const ( + optionalAgentChunkColumns = "columns" + optionalAgentChunkRows = "rows" + optionalAgentChunkDone = "done" +) + type optionalAgentRequest struct { ID int64 `json:"id"` Method string `json:"method"` @@ -60,6 +68,7 @@ type optionalAgentResponse struct { Error string `json:"error,omitempty"` Data json.RawMessage `json:"data,omitempty"` Fields []string `json:"fields,omitempty"` + ChunkType string `json:"chunkType,omitempty"` RowsAffected int64 `json:"rowsAffected,omitempty"` } @@ -269,6 +278,114 @@ func (c *optionalDriverAgentClient) callWithTimeout(req optionalAgentRequest, ou } } +func (c *optionalDriverAgentClient) callStreamQuery(req optionalAgentRequest, consumer QueryStreamConsumer) error { + if consumer == nil { + return fmt.Errorf("query stream consumer required") + } + + c.mu.Lock() + defer c.mu.Unlock() + + c.nextID++ + req.ID = c.nextID + + payload, err := json.Marshal(req) + if err != nil { + return err + } + payload = append(payload, '\n') + if _, err := c.stdin.Write(payload); err != nil { + stderrText := c.stderrText() + if stderrText == "" { + return fmt.Errorf("调用 %s 驱动代理失败:%w", driverDisplayName(c.driver), err) + } + return fmt.Errorf("调用 %s 驱动代理失败:%w(stderr: %s)", driverDisplayName(c.driver), err, stderrText) + } + + var columns []string + valueConsumer, useValueConsumer := consumer.(QueryStreamValueConsumer) + + for { + line, err := c.reader.ReadBytes('\n') + if err != nil { + stderrText := c.stderrText() + if stderrText == "" { + return fmt.Errorf("读取 %s 驱动代理响应失败:%w", driverDisplayName(c.driver), err) + } + return fmt.Errorf("读取 %s 驱动代理响应失败:%w(stderr: %s)", driverDisplayName(c.driver), err, stderrText) + } + + var resp optionalAgentResponse + if err := json.Unmarshal(line, &resp); err != nil { + return fmt.Errorf("解析 %s 驱动代理响应失败:%w", driverDisplayName(c.driver), err) + } + if !resp.Success { + errText := strings.TrimSpace(resp.Error) + if errText == "" { + errText = fmt.Sprintf("%s 驱动代理返回失败", driverDisplayName(c.driver)) + } + return errors.New(errText) + } + + switch resp.ChunkType { + case optionalAgentChunkColumns: + columns = append(columns[:0], resp.Fields...) + if err := consumer.SetColumns(columns); err != nil { + return err + } + case optionalAgentChunkRows: + if len(columns) == 0 { + return fmt.Errorf("%s 驱动代理流式响应缺少列信息", driverDisplayName(c.driver)) + } + rows, err := decodeOptionalAgentRowValueBatch(resp.Data) + if err != nil { + return fmt.Errorf("解析 %s 驱动代理流式数据失败:%w", driverDisplayName(c.driver), err) + } + for _, row := range rows { + if useValueConsumer { + if err := valueConsumer.ConsumeRowValues(row); err != nil { + return err + } + continue + } + entry := make(map[string]interface{}, len(columns)) + for i, column := range columns { + if i < len(row) { + entry[column] = row[i] + } else { + entry[column] = nil + } + } + if err := consumer.ConsumeRow(entry); err != nil { + return err + } + } + case optionalAgentChunkDone: + return nil + default: + return fmt.Errorf("%s 驱动代理返回未知流式分片类型:%s", driverDisplayName(c.driver), strings.TrimSpace(resp.ChunkType)) + } + } +} + +func decodeOptionalAgentRowValueBatch(data []byte) ([][]interface{}, error) { + if len(data) == 0 { + return nil, nil + } + decoder := json.NewDecoder(bytes.NewReader(data)) + decoder.UseNumber() + var rows [][]interface{} + if err := decoder.Decode(&rows); err != nil { + return nil, err + } + for rowIdx := range rows { + for colIdx := range rows[rowIdx] { + rows[rowIdx][colIdx] = normalizeQueryValue(rows[rowIdx][colIdx]) + } + } + return rows, nil +} + func (c *optionalDriverAgentClient) forceTerminate() { if c.stdin != nil { _ = c.stdin.Close() @@ -397,6 +514,42 @@ func (d *OptionalDriverAgentDB) Query(query string) ([]map[string]interface{}, [ return data, fields, nil } +func (d *OptionalDriverAgentDB) StreamQuery(query string, consumer QueryStreamConsumer) error { + return d.StreamQueryContext(context.Background(), query, consumer) +} + +func (d *OptionalDriverAgentDB) StreamQueryContext(ctx context.Context, query string, consumer QueryStreamConsumer) error { + if err := ctx.Err(); err != nil { + return err + } + client, err := d.requireClient() + if err != nil { + return err + } + err = client.callStreamQuery(optionalAgentRequest{ + Method: optionalAgentMethodStreamQuery, + Query: query, + TimeoutMs: timeoutMsFromContext(ctx), + }, consumer) + if isOptionalAgentStreamUnsupportedError(err) { + logger.Warnf("%s 驱动代理暂不支持流式查询,回退到缓冲模式:err=%v", driverDisplayName(d.driverType), err) + data, columns, queryErr := d.QueryContext(ctx, query) + if queryErr != nil { + return queryErr + } + if err := consumer.SetColumns(columns); err != nil { + return err + } + for _, row := range data { + if err := consumer.ConsumeRow(row); err != nil { + return err + } + } + return nil + } + return err +} + func (d *OptionalDriverAgentDB) ExecContext(ctx context.Context, query string) (int64, error) { if err := ctx.Err(); err != nil { return 0, err @@ -458,6 +611,39 @@ func (s *optionalDriverAgentSession) Query(query string) ([]map[string]interface return s.QueryContext(context.Background(), query) } +func (s *optionalDriverAgentSession) StreamQuery(query string, consumer QueryStreamConsumer) error { + return s.StreamQueryContext(context.Background(), query, consumer) +} + +func (s *optionalDriverAgentSession) StreamQueryContext(ctx context.Context, query string, consumer QueryStreamConsumer) error { + if err := s.ensureOpen(); err != nil { + return err + } + err := s.client.callStreamQuery(optionalAgentRequest{ + Method: optionalAgentMethodStreamQuery, + SessionID: s.sessionID, + Query: query, + TimeoutMs: timeoutMsFromContext(ctx), + }, consumer) + if isOptionalAgentStreamUnsupportedError(err) { + logger.Warnf("%s 驱动代理事务会话暂不支持流式查询,回退到缓冲模式:err=%v", driverDisplayName(s.driver), err) + data, columns, queryErr := s.QueryContext(ctx, query) + if queryErr != nil { + return queryErr + } + if err := consumer.SetColumns(columns); err != nil { + return err + } + for _, row := range data { + if err := consumer.ConsumeRow(row); err != nil { + return err + } + } + return nil + } + return err +} + func (s *optionalDriverAgentSession) QueryContext(ctx context.Context, query string) ([]map[string]interface{}, []string, error) { if err := s.ensureOpen(); err != nil { return nil, nil, err @@ -525,6 +711,17 @@ func (s *optionalDriverAgentSession) ensureOpen() error { return nil } +func isOptionalAgentStreamUnsupportedError(err error) bool { + if err == nil { + return false + } + text := strings.TrimSpace(err.Error()) + if text == "" { + return false + } + return strings.Contains(text, "不支持的方法") || strings.Contains(text, "不支持流式查询") +} + func (d *OptionalDriverAgentDB) GetDatabases() ([]string, error) { client, err := d.requireClient() if err != nil { diff --git a/internal/db/optional_driver_agent_impl_test.go b/internal/db/optional_driver_agent_impl_test.go index a79b03d..90600db 100644 --- a/internal/db/optional_driver_agent_impl_test.go +++ b/internal/db/optional_driver_agent_impl_test.go @@ -1,6 +1,9 @@ package db import ( + "bufio" + "bytes" + "strings" "testing" "GoNavi-Wails/internal/connection" @@ -65,3 +68,71 @@ func TestNormalizeKingbaseAgentChangeSetByColumns(t *testing.T) { t.Fatalf("unexpected update value key \"event name\" after normalization") } } + +type optionalAgentTestWriteCloser struct { + bytes.Buffer +} + +func (w *optionalAgentTestWriteCloser) Close() error { return nil } + +type optionalAgentTestStreamConsumer struct { + columns []string + rows [][]interface{} +} + +func (c *optionalAgentTestStreamConsumer) SetColumns(columns []string) error { + c.columns = append([]string(nil), columns...) + return nil +} + +func (c *optionalAgentTestStreamConsumer) ConsumeRow(row map[string]interface{}) error { + values := make([]interface{}, len(c.columns)) + for idx, column := range c.columns { + values[idx] = row[column] + } + c.rows = append(c.rows, values) + return nil +} + +func (c *optionalAgentTestStreamConsumer) ConsumeRowValues(values []interface{}) error { + c.rows = append(c.rows, append([]interface{}(nil), values...)) + return nil +} + +func TestOptionalDriverAgentClientCallStreamQueryConsumesChunks(t *testing.T) { + var stdin optionalAgentTestWriteCloser + stdout := strings.Join([]string{ + `{"id":1,"success":true,"chunkType":"columns","fields":["id","name"]}`, + `{"id":1,"success":true,"chunkType":"rows","data":[[1,"alice"],[2,"bob"]]}`, + `{"id":1,"success":true,"chunkType":"done"}`, + }, "\n") + "\n" + + client := &optionalDriverAgentClient{ + stdin: &stdin, + reader: bufio.NewReader(strings.NewReader(stdout)), + driver: "oceanbase", + } + consumer := &optionalAgentTestStreamConsumer{} + if err := client.callStreamQuery(optionalAgentRequest{ + Method: optionalAgentMethodStreamQuery, + Query: "SELECT 1", + }, consumer); err != nil { + t.Fatalf("callStreamQuery 返回错误: %v", err) + } + + if len(consumer.columns) != 2 || consumer.columns[0] != "id" || consumer.columns[1] != "name" { + t.Fatalf("流式列定义异常: %#v", consumer.columns) + } + if len(consumer.rows) != 2 { + t.Fatalf("流式行数异常: %#v", consumer.rows) + } + if got := consumer.rows[0][1]; got != "alice" { + t.Fatalf("第 1 行数据异常,want=%q got=%v", "alice", got) + } + if got := consumer.rows[1][0]; got != int64(2) { + t.Fatalf("第 2 行 ID 异常,want=%d got=%v (%T)", 2, got, got) + } + if !strings.Contains(stdin.String(), `"method":"streamQuery"`) { + t.Fatalf("请求未使用 streamQuery 方法: %s", stdin.String()) + } +}