Files
MyGoNavi/cmd/optional-driver-agent/main.go
Syngnat c8fe90cbee ️ perf(import-export): 降低 OceanBase 导出链路内存占用
- 为 optional driver-agent 补齐 streamQuery 分片协议,避免大结果集整批缓冲到内存
- 在 OceanBase 整表导出和查询结果导出前强校验 driver-agent revision,旧版代理直接拦截并提示重装
- 为 driver-agent 增加大查询和流式导出完成后的 GC/FreeOSMemory 回收逻辑
- 补充导出前校验、流式分片消费和 agent 内存回收的定向测试
- 更新 driver-agent revisions 以匹配新的流式导出协议
2026-06-18 11:32:08 +08:00

731 lines
21 KiB
Go
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
package main
import (
"bufio"
"context"
"encoding/json"
"fmt"
"os"
"reflect"
"runtime"
"runtime/debug"
"strconv"
"strings"
"sync/atomic"
"time"
"GoNavi-Wails/internal/connection"
"GoNavi-Wails/internal/db"
)
type agentRequest struct {
ID int64 `json:"id"`
Method string `json:"method"`
SessionID string `json:"sessionId,omitempty"`
Config *connection.ConnectionConfig `json:"config,omitempty"`
Query string `json:"query,omitempty"`
TimeoutMs int64 `json:"timeoutMs,omitempty"`
DBName string `json:"dbName,omitempty"`
TableName string `json:"tableName,omitempty"`
Changes *connection.ChangeSet `json:"changes,omitempty"`
}
type agentResponse struct {
ID int64 `json:"id"`
Success bool `json:"success"`
Error string `json:"error,omitempty"`
Data interface{} `json:"data,omitempty"`
Fields []string `json:"fields,omitempty"`
ChunkType string `json:"chunkType,omitempty"`
RowsAffected int64 `json:"rowsAffected,omitempty"`
}
const (
agentMethodConnect = "connect"
agentMethodClose = "close"
agentMethodMetadata = "metadata"
agentMethodPing = "ping"
agentMethodOpenSession = "openSession"
agentMethodCloseSession = "closeSession"
agentMethodQuery = "query"
agentMethodStreamQuery = "streamQuery"
agentMethodExec = "exec"
agentMethodGetDatabases = "getDatabases"
agentMethodGetTables = "getTables"
agentMethodGetCreateStmt = "getCreateStatement"
agentMethodGetColumns = "getColumns"
agentMethodGetAllColumns = "getAllColumns"
agentMethodGetIndexes = "getIndexes"
agentMethodGetForeignKey = "getForeignKeys"
agentMethodGetTriggers = "getTriggers"
agentMethodApplyChanges = "applyChanges"
)
const legacyClickHouseDefaultTimeout = 2 * time.Hour
const (
agentChunkColumns = "columns"
agentChunkRows = "rows"
agentChunkDone = "done"
agentStreamBatchSize = 256
agentMemoryTrimRowsThreshold = 100000
agentMemoryTrimMinInterval = 3 * time.Second
)
var (
agentDriverType string
agentDatabaseFactory func() db.Database
agentMemoryTrimRunning atomic.Bool
agentMemoryTrimLastAt atomic.Int64
runAgentMemoryTrimAsync = func(fn func()) {
go fn()
}
agentMemoryTrimFn = func() {
runtime.GC()
debug.FreeOSMemory()
}
)
type agentRuntime struct {
inst db.Database
sessions map[string]db.StatementExecer
nextSessionID int64
}
func main() {
if agentDatabaseFactory == nil || strings.TrimSpace(agentDriverType) == "" {
fmt.Fprintf(os.Stderr, "未配置驱动代理 provider请使用 gonavi_<driver>_driver 标签构建\n")
os.Exit(2)
}
scanner := bufio.NewScanner(os.Stdin)
scanner.Buffer(make([]byte, 0, 16<<10), 8<<20)
writer := bufio.NewWriter(os.Stdout)
defer writer.Flush()
runtimeState := &agentRuntime{
sessions: make(map[string]db.StatementExecer),
}
for scanner.Scan() {
line := strings.TrimSpace(scanner.Text())
if line == "" {
continue
}
var req agentRequest
if err := json.Unmarshal([]byte(line), &req); err != nil {
_ = writeResponse(writer, agentResponse{
ID: req.ID,
Success: false,
Error: fmt.Sprintf("解析请求失败:%v", err),
})
continue
}
if strings.TrimSpace(req.Method) == agentMethodStreamQuery {
if err := handleStreamRequest(runtimeState, req, writer); err != nil {
fmt.Fprintf(os.Stderr, "写入流式响应失败:%v\n", err)
break
}
continue
}
resp := handleRequest(runtimeState, req)
if err := writeResponse(writer, resp); err != nil {
fmt.Fprintf(os.Stderr, "写入响应失败:%v\n", err)
break
}
if strings.TrimSpace(req.Method) == agentMethodQuery {
maybeReleaseAgentMemory("query-response", countAgentResponseRows(resp.Data))
}
}
runtimeState.close()
if err := scanner.Err(); err != nil {
fmt.Fprintf(os.Stderr, "读取请求失败:%v\n", err)
}
}
func handleRequest(runtimeState *agentRuntime, req agentRequest) agentResponse {
resp := agentResponse{ID: req.ID, Success: true}
method := strings.TrimSpace(req.Method)
switch method {
case agentMethodConnect:
if req.Config == nil {
return fail(resp, "连接配置为空")
}
runtimeState.close()
next := agentDatabaseFactory()
if next == nil {
return fail(resp, "驱动代理初始化失败")
}
if err := next.Connect(*req.Config); err != nil {
return fail(resp, err.Error())
}
runtimeState.inst = next
return resp
case agentMethodClose:
if runtimeState.inst != nil {
if err := runtimeState.close(); err != nil {
return fail(resp, err.Error())
}
}
return resp
case agentMethodMetadata:
resp.Data = map[string]string{
"driverType": strings.TrimSpace(agentDriverType),
"agentRevision": db.OptionalDriverAgentRevision(agentDriverType),
"protocolSchema": "json-lines-v1",
}
return resp
case agentMethodOpenSession:
if runtimeState.inst == nil {
return fail(resp, "connection not open")
}
provider, ok := runtimeState.inst.(db.SessionExecerProvider)
if !ok {
return fail(resp, fmt.Sprintf("当前数据源(%s不支持 SQL 编辑器托管事务", strings.TrimSpace(agentDriverType)))
}
openCtx := context.Background()
var cancel context.CancelFunc
if req.TimeoutMs > 0 {
openCtx, cancel = context.WithTimeout(context.Background(), time.Duration(req.TimeoutMs)*time.Millisecond)
defer cancel()
}
session, err := provider.OpenSessionExecer(openCtx)
if err != nil {
return fail(resp, err.Error())
}
sessionID := runtimeState.nextID()
runtimeState.sessions[sessionID] = session
resp.Data = sessionID
return resp
case agentMethodCloseSession:
if err := runtimeState.closeSession(req.SessionID); err != nil {
return fail(resp, err.Error())
}
return resp
}
if runtimeState.inst == nil {
return fail(resp, "connection not open")
}
if session, ok, err := runtimeState.session(req.SessionID); err != nil {
return fail(resp, err.Error())
} else if ok {
switch method {
case agentMethodQuery:
data, fields, err := queryStatementWithOptionalTimeout(session, req.Query, req.TimeoutMs)
if err != nil {
return fail(resp, err.Error())
}
resp.Data = data
resp.Fields = fields
case agentMethodExec:
affected, err := execStatementWithOptionalTimeout(session, req.Query, req.TimeoutMs)
if err != nil {
return fail(resp, err.Error())
}
resp.RowsAffected = affected
default:
return fail(resp, "当前事务会话不支持该方法")
}
return resp
}
switch method {
case agentMethodPing:
if err := runtimeState.inst.Ping(); err != nil {
return fail(resp, err.Error())
}
case agentMethodQuery:
data, fields, err := queryWithOptionalTimeout(runtimeState.inst, req.Query, req.TimeoutMs)
if err != nil {
return fail(resp, err.Error())
}
resp.Data = data
resp.Fields = fields
case agentMethodExec:
affected, err := execWithOptionalTimeout(runtimeState.inst, req.Query, req.TimeoutMs)
if err != nil {
return fail(resp, err.Error())
}
resp.RowsAffected = affected
case agentMethodGetDatabases:
data, err := runtimeState.inst.GetDatabases()
if err != nil {
return fail(resp, err.Error())
}
resp.Data = data
case agentMethodGetTables:
data, err := runtimeState.inst.GetTables(req.DBName)
if err != nil {
return fail(resp, err.Error())
}
resp.Data = data
case agentMethodGetCreateStmt:
data, err := runtimeState.inst.GetCreateStatement(req.DBName, req.TableName)
if err != nil {
return fail(resp, err.Error())
}
resp.Data = data
case agentMethodGetColumns:
data, err := runtimeState.inst.GetColumns(req.DBName, req.TableName)
if err != nil {
return fail(resp, err.Error())
}
resp.Data = data
case agentMethodGetAllColumns:
data, err := runtimeState.inst.GetAllColumns(req.DBName)
if err != nil {
return fail(resp, err.Error())
}
resp.Data = data
case agentMethodGetIndexes:
data, err := runtimeState.inst.GetIndexes(req.DBName, req.TableName)
if err != nil {
return fail(resp, err.Error())
}
resp.Data = data
case agentMethodGetForeignKey:
data, err := runtimeState.inst.GetForeignKeys(req.DBName, req.TableName)
if err != nil {
return fail(resp, err.Error())
}
resp.Data = data
case agentMethodGetTriggers:
data, err := runtimeState.inst.GetTriggers(req.DBName, req.TableName)
if err != nil {
return fail(resp, err.Error())
}
resp.Data = data
case agentMethodApplyChanges:
if req.Changes == nil {
return fail(resp, "变更集为空")
}
applier, ok := runtimeState.inst.(interface {
ApplyChanges(tableName string, changes connection.ChangeSet) error
})
if !ok {
return fail(resp, "当前驱动不支持 ApplyChanges")
}
if err := applier.ApplyChanges(req.TableName, *req.Changes); err != nil {
return fail(resp, err.Error())
}
default:
return fail(resp, "不支持的方法")
}
return resp
}
type agentStreamResponseWriter struct {
writer *bufio.Writer
requestID int64
columns []string
rows [][]interface{}
rowCount int64
}
func newAgentStreamResponseWriter(writer *bufio.Writer, requestID int64) *agentStreamResponseWriter {
return &agentStreamResponseWriter{
writer: writer,
requestID: requestID,
}
}
func (w *agentStreamResponseWriter) SetColumns(columns []string) error {
w.columns = append([]string(nil), columns...)
return writeResponse(w.writer, agentResponse{
ID: w.requestID,
Success: true,
ChunkType: agentChunkColumns,
Fields: w.columns,
})
}
func (w *agentStreamResponseWriter) ConsumeRow(row map[string]interface{}) error {
if len(w.columns) == 0 {
return fmt.Errorf("流式查询缺少列定义")
}
values := make([]interface{}, len(w.columns))
for idx, column := range w.columns {
values[idx] = row[column]
}
return w.ConsumeRowValues(values)
}
func (w *agentStreamResponseWriter) ConsumeRowValues(values []interface{}) error {
row := append([]interface{}(nil), values...)
w.rows = append(w.rows, row)
w.rowCount++
if len(w.rows) < agentStreamBatchSize {
return nil
}
return w.flushRows()
}
func (w *agentStreamResponseWriter) flushRows() error {
if len(w.rows) == 0 {
return nil
}
rows := w.rows
w.rows = nil
return writeResponse(w.writer, agentResponse{
ID: w.requestID,
Success: true,
ChunkType: agentChunkRows,
Data: rows,
})
}
func (w *agentStreamResponseWriter) finish() error {
return w.flushRows()
}
func handleStreamRequest(runtimeState *agentRuntime, req agentRequest, writer *bufio.Writer) error {
resp := agentResponse{ID: req.ID, Success: true}
if runtimeState.inst == nil {
return writeResponse(writer, fail(resp, "connection not open"))
}
streamWriter := newAgentStreamResponseWriter(writer, req.ID)
if session, ok, err := runtimeState.session(req.SessionID); err != nil {
return writeResponse(writer, fail(resp, err.Error()))
} else if ok {
if err := streamStatementWithOptionalTimeout(session, req.Query, req.TimeoutMs, streamWriter); err != nil {
_ = streamWriter.finish()
return writeResponse(writer, fail(resp, err.Error()))
}
if err := streamWriter.finish(); err != nil {
return err
}
if err := writeResponse(writer, agentResponse{ID: req.ID, Success: true, ChunkType: agentChunkDone}); err != nil {
return err
}
maybeReleaseAgentMemory("stream-query-session", streamWriter.rowCount)
return nil
}
if err := streamDatabaseWithOptionalTimeout(runtimeState.inst, req.Query, req.TimeoutMs, streamWriter); err != nil {
_ = streamWriter.finish()
return writeResponse(writer, fail(resp, err.Error()))
}
if err := streamWriter.finish(); err != nil {
return err
}
if err := writeResponse(writer, agentResponse{ID: req.ID, Success: true, ChunkType: agentChunkDone}); err != nil {
return err
}
maybeReleaseAgentMemory("stream-query-db", streamWriter.rowCount)
return nil
}
func (r *agentRuntime) nextID() string {
r.ensureSessionMap()
r.nextSessionID++
return "session-" + strconv.FormatInt(r.nextSessionID, 10)
}
func (r *agentRuntime) session(sessionID string) (db.StatementExecer, bool, error) {
r.ensureSessionMap()
sessionID = strings.TrimSpace(sessionID)
if sessionID == "" {
return nil, false, nil
}
session, ok := r.sessions[sessionID]
if !ok || session == nil {
return nil, false, fmt.Errorf("事务会话不存在或已结束")
}
return session, true, nil
}
func (r *agentRuntime) closeSession(sessionID string) error {
r.ensureSessionMap()
sessionID = strings.TrimSpace(sessionID)
if sessionID == "" {
return fmt.Errorf("事务会话 ID 不能为空")
}
session, ok := r.sessions[sessionID]
if ok {
delete(r.sessions, sessionID)
}
if !ok || session == nil {
return fmt.Errorf("事务会话不存在或已结束")
}
return session.Close()
}
func (r *agentRuntime) close() error {
var closeErr error
r.ensureSessionMap()
for sessionID, session := range r.sessions {
delete(r.sessions, sessionID)
if session != nil {
if err := session.Close(); err != nil && closeErr == nil {
closeErr = err
}
}
}
if r.inst != nil {
if err := r.inst.Close(); err != nil && closeErr == nil {
closeErr = err
}
r.inst = nil
}
return closeErr
}
func (r *agentRuntime) ensureSessionMap() {
if r.sessions == nil {
r.sessions = make(map[string]db.StatementExecer)
}
}
func writeResponse(writer *bufio.Writer, resp agentResponse) error {
// 对响应数据做统一 JSON 安全归一化:
// 将 map[any]any如 duckdb.Map递归转换为 map[string]any避免序列化失败导致代理进程退出。
safeResp := resp
safeResp.Data = normalizeAgentResponseData(resp.Data)
payload, err := json.Marshal(safeResp)
if err != nil {
return err
}
payload = append(payload, '\n')
if _, err := writer.Write(payload); err != nil {
return err
}
return writer.Flush()
}
func fail(resp agentResponse, errText string) agentResponse {
resp.Success = false
resp.Error = strings.TrimSpace(errText)
return resp
}
func normalizeAgentResponseData(v interface{}) interface{} {
if v == nil {
return nil
}
rv := reflect.ValueOf(v)
switch rv.Kind() {
case reflect.Pointer, reflect.Interface:
if rv.IsNil() {
return nil
}
return normalizeAgentResponseData(rv.Elem().Interface())
case reflect.Map:
if rv.IsNil() {
return nil
}
out := make(map[string]interface{}, rv.Len())
iter := rv.MapRange()
for iter.Next() {
out[fmt.Sprint(iter.Key().Interface())] = normalizeAgentResponseData(iter.Value().Interface())
}
return out
case reflect.Slice:
if rv.IsNil() {
return nil
}
// 保持 []byte 原样,避免改变现有二进制列的 JSON 编码行为base64
if rv.Type().Elem().Kind() == reflect.Uint8 {
return v
}
size := rv.Len()
items := make([]interface{}, size)
for i := 0; i < size; i++ {
items[i] = normalizeAgentResponseData(rv.Index(i).Interface())
}
return items
case reflect.Array:
size := rv.Len()
items := make([]interface{}, size)
for i := 0; i < size; i++ {
items[i] = normalizeAgentResponseData(rv.Index(i).Interface())
}
return items
default:
return v
}
}
type agentQueryRunner interface {
Query(string) ([]map[string]interface{}, []string, error)
}
type agentQueryContextRunner interface {
QueryContext(context.Context, string) ([]map[string]interface{}, []string, error)
}
type agentExecRunner interface {
Exec(string) (int64, error)
}
type agentExecContextRunner interface {
ExecContext(context.Context, string) (int64, error)
}
func queryWithOptionalTimeout(inst agentQueryRunner, query string, timeoutMs int64) ([]map[string]interface{}, []string, error) {
effectiveTimeoutMs := timeoutMs
if effectiveTimeoutMs <= 0 && strings.EqualFold(strings.TrimSpace(agentDriverType), "clickhouse") {
effectiveTimeoutMs = int64(legacyClickHouseDefaultTimeout / time.Millisecond)
}
if effectiveTimeoutMs <= 0 {
return inst.Query(query)
}
if q, ok := inst.(agentQueryContextRunner); ok {
ctx, cancel := context.WithTimeout(context.Background(), time.Duration(effectiveTimeoutMs)*time.Millisecond)
defer cancel()
return q.QueryContext(ctx, query)
}
return inst.Query(query)
}
func queryStatementWithOptionalTimeout(inst db.StatementExecer, query string, timeoutMs int64) ([]map[string]interface{}, []string, error) {
queryRunner, ok := inst.(agentQueryRunner)
if !ok {
return nil, nil, fmt.Errorf("当前事务会话不支持查询语句")
}
return queryWithOptionalTimeout(queryRunner, query, timeoutMs)
}
func streamWithOptionalTimeout(inst db.StreamQueryExecer, query string, timeoutMs int64, consumer db.QueryStreamConsumer) error {
effectiveTimeoutMs := timeoutMs
if effectiveTimeoutMs <= 0 && strings.EqualFold(strings.TrimSpace(agentDriverType), "clickhouse") {
effectiveTimeoutMs = int64(legacyClickHouseDefaultTimeout / time.Millisecond)
}
if effectiveTimeoutMs <= 0 {
return inst.StreamQuery(query, consumer)
}
ctx, cancel := context.WithTimeout(context.Background(), time.Duration(effectiveTimeoutMs)*time.Millisecond)
defer cancel()
return inst.StreamQueryContext(ctx, query, consumer)
}
func streamBufferedQueryResult(fields []string, data []map[string]interface{}, consumer db.QueryStreamConsumer) error {
if err := consumer.SetColumns(fields); err != nil {
return err
}
if valueConsumer, ok := consumer.(db.QueryStreamValueConsumer); ok {
for _, row := range data {
values := make([]interface{}, len(fields))
for idx, field := range fields {
values[idx] = row[field]
}
if err := valueConsumer.ConsumeRowValues(values); err != nil {
return err
}
}
return nil
}
for _, row := range data {
if err := consumer.ConsumeRow(row); err != nil {
return err
}
}
return nil
}
func streamStatementWithOptionalTimeout(inst db.StatementExecer, query string, timeoutMs int64, consumer db.QueryStreamConsumer) error {
if streamer, ok := inst.(db.StreamQueryExecer); ok {
return streamWithOptionalTimeout(streamer, query, timeoutMs, consumer)
}
data, fields, err := queryStatementWithOptionalTimeout(inst, query, timeoutMs)
if err != nil {
return err
}
return streamBufferedQueryResult(fields, data, consumer)
}
func streamDatabaseWithOptionalTimeout(inst db.Database, query string, timeoutMs int64, consumer db.QueryStreamConsumer) error {
if streamer, ok := inst.(db.StreamQueryExecer); ok {
return streamWithOptionalTimeout(streamer, query, timeoutMs, consumer)
}
if provider, ok := inst.(db.SessionExecerProvider); ok {
openCtx := context.Background()
var cancel context.CancelFunc
effectiveTimeoutMs := timeoutMs
if effectiveTimeoutMs <= 0 && strings.EqualFold(strings.TrimSpace(agentDriverType), "clickhouse") {
effectiveTimeoutMs = int64(legacyClickHouseDefaultTimeout / time.Millisecond)
}
if effectiveTimeoutMs > 0 {
openCtx, cancel = context.WithTimeout(context.Background(), time.Duration(effectiveTimeoutMs)*time.Millisecond)
defer cancel()
}
session, err := provider.OpenSessionExecer(openCtx)
if err == nil {
defer session.Close()
return streamStatementWithOptionalTimeout(session, query, timeoutMs, consumer)
}
}
data, fields, err := queryWithOptionalTimeout(inst, query, timeoutMs)
if err != nil {
return err
}
return streamBufferedQueryResult(fields, data, consumer)
}
func execWithOptionalTimeout(inst agentExecRunner, query string, timeoutMs int64) (int64, error) {
effectiveTimeoutMs := timeoutMs
if effectiveTimeoutMs <= 0 && strings.EqualFold(strings.TrimSpace(agentDriverType), "clickhouse") {
effectiveTimeoutMs = int64(legacyClickHouseDefaultTimeout / time.Millisecond)
}
if effectiveTimeoutMs <= 0 {
return inst.Exec(query)
}
if e, ok := inst.(agentExecContextRunner); ok {
ctx, cancel := context.WithTimeout(context.Background(), time.Duration(effectiveTimeoutMs)*time.Millisecond)
defer cancel()
return e.ExecContext(ctx, query)
}
return inst.Exec(query)
}
func execStatementWithOptionalTimeout(inst db.StatementExecer, query string, timeoutMs int64) (int64, error) {
return execWithOptionalTimeout(inst, query, timeoutMs)
}
func countAgentResponseRows(data interface{}) int64 {
rows, ok := data.([]map[string]interface{})
if !ok {
return 0
}
return int64(len(rows))
}
func maybeReleaseAgentMemory(reason string, rows int64) {
if rows < agentMemoryTrimRowsThreshold {
return
}
if !agentMemoryTrimRunning.CompareAndSwap(false, true) {
return
}
runAgentMemoryTrimAsync(func() {
defer agentMemoryTrimRunning.Store(false)
if delay := nextAgentMemoryTrimDelay(); delay > 0 {
time.Sleep(delay)
}
agentMemoryTrimFn()
agentMemoryTrimLastAt.Store(time.Now().UnixNano())
})
}
func nextAgentMemoryTrimDelay() time.Duration {
lastUnixNano := agentMemoryTrimLastAt.Load()
if lastUnixNano <= 0 {
return 0
}
elapsed := time.Since(time.Unix(0, lastUnixNano))
if elapsed >= agentMemoryTrimMinInterval {
return 0
}
return agentMemoryTrimMinInterval - elapsed
}