From 495a985ae1b13649d6bf967bcc410099a9f6cfe3 Mon Sep 17 00:00:00 2001 From: Syngnat Date: Tue, 23 Jun 2026 08:48:42 +0800 Subject: [PATCH] =?UTF-8?q?=F0=9F=90=9B=20fix(sqlserver):=20=E4=BF=AE?= =?UTF-8?q?=E5=A4=8D=E5=8F=AF=E9=80=89=E9=A9=B1=E5=8A=A8=E6=9F=A5=E8=AF=A2?= =?UTF-8?q?=E6=B6=88=E6=81=AF=E9=80=8F=E4=BC=A0=E7=BC=BA=E5=A4=B1?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - 为 optional-driver-agent 的 query 和 queryMulti 响应补充 messages 字段 - 在可选驱动 DB 客户端透传 SQL Server 查询提示信息与多结果集 - 补充 agent 与数据库层回归测试并更新 driver agent revision --- cmd/optional-driver-agent/main.go | 153 +++++++++++++- cmd/optional-driver-agent/main_test.go | 103 ++++++++++ internal/db/driver_agent_revisions_gen.go | 42 ++-- internal/db/optional_driver_agent_impl.go | 186 +++++++++++++----- .../db/optional_driver_agent_impl_test.go | 65 ++++++ 5 files changed, 472 insertions(+), 77 deletions(-) diff --git a/cmd/optional-driver-agent/main.go b/cmd/optional-driver-agent/main.go index 063e6ca..03d45e6 100644 --- a/cmd/optional-driver-agent/main.go +++ b/cmd/optional-driver-agent/main.go @@ -36,6 +36,7 @@ type agentResponse struct { Error string `json:"error,omitempty"` Data interface{} `json:"data,omitempty"` Fields []string `json:"fields,omitempty"` + Messages []string `json:"messages,omitempty"` ChunkType string `json:"chunkType,omitempty"` RowsAffected int64 `json:"rowsAffected,omitempty"` } @@ -48,6 +49,7 @@ const ( agentMethodOpenSession = "openSession" agentMethodCloseSession = "closeSession" agentMethodQuery = "query" + agentMethodQueryMulti = "queryMulti" agentMethodStreamQuery = "streamQuery" agentMethodExec = "exec" agentMethodGetDatabases = "getDatabases" @@ -64,9 +66,9 @@ const ( const legacyClickHouseDefaultTimeout = 2 * time.Hour const ( - agentChunkColumns = "columns" - agentChunkRows = "rows" - agentChunkDone = "done" + agentChunkColumns = "columns" + agentChunkRows = "rows" + agentChunkDone = "done" // agentStreamBatchSize 控制 driver-agent 向主进程发送 row chunk 的批次大小。 // 调小到 64:单批 JSON 编码 + 主进程解码的瞬时内存峰值降为原来的 1/4, // 代价是 IPC 次数变为 4 倍,但每批仅一次 stdin/stdout 行读写,整体影响可忽略。 @@ -236,12 +238,23 @@ func handleRequest(runtimeState *agentRuntime, req agentRequest) agentResponse { } else if ok { switch method { case agentMethodQuery: - data, fields, err := queryStatementWithOptionalTimeout(session, req.Query, req.TimeoutMs) + data, fields, messages, err := queryStatementWithMessagesOptionalTimeout(session, req.Query, req.TimeoutMs) if err != nil { return fail(resp, err.Error()) } resp.Data = data resp.Fields = fields + resp.Messages = messages + case agentMethodQueryMulti: + data, messages, supported, err := queryMultiStatementWithMessagesOptionalTimeout(session, req.Query, req.TimeoutMs) + if err != nil { + return fail(resp, err.Error()) + } + if !supported { + return fail(resp, "当前事务会话不支持多结果集查询") + } + resp.Data = data + resp.Messages = messages case agentMethodExec: affected, err := execStatementWithOptionalTimeout(session, req.Query, req.TimeoutMs) if err != nil { @@ -260,12 +273,23 @@ func handleRequest(runtimeState *agentRuntime, req agentRequest) agentResponse { return fail(resp, err.Error()) } case agentMethodQuery: - data, fields, err := queryWithOptionalTimeout(runtimeState.inst, req.Query, req.TimeoutMs) + data, fields, messages, err := queryWithMessagesOptionalTimeout(runtimeState.inst, req.Query, req.TimeoutMs) if err != nil { return fail(resp, err.Error()) } resp.Data = data resp.Fields = fields + resp.Messages = messages + case agentMethodQueryMulti: + data, messages, supported, err := queryMultiWithMessagesOptionalTimeout(runtimeState.inst, req.Query, req.TimeoutMs) + if err != nil { + return fail(resp, err.Error()) + } + if !supported { + return fail(resp, "当前驱动不支持原生多结果集查询") + } + resp.Data = data + resp.Messages = messages case agentMethodExec: affected, err := execWithOptionalTimeout(runtimeState.inst, req.Query, req.TimeoutMs) if err != nil { @@ -581,6 +605,30 @@ type agentQueryContextRunner interface { QueryContext(context.Context, string) ([]map[string]interface{}, []string, error) } +type agentQueryMessageRunner interface { + QueryWithMessages(query string) ([]map[string]interface{}, []string, []string, error) +} + +type agentQueryMessageContextRunner interface { + QueryContextWithMessages(context.Context, string) ([]map[string]interface{}, []string, []string, error) +} + +type agentMultiResultMessageRunner interface { + QueryMultiWithMessages(query string) ([]connection.ResultSetData, []string, error) +} + +type agentMultiResultMessageContextRunner interface { + QueryMultiContextWithMessages(context.Context, string) ([]connection.ResultSetData, []string, error) +} + +type agentMultiResultRunner interface { + QueryMulti(query string) ([]connection.ResultSetData, error) +} + +type agentMultiResultContextRunner interface { + QueryMultiContext(context.Context, string) ([]connection.ResultSetData, error) +} + type agentExecRunner interface { Exec(string) (int64, error) } @@ -589,20 +637,39 @@ type agentExecContextRunner interface { ExecContext(context.Context, string) (int64, error) } -func queryWithOptionalTimeout(inst agentQueryRunner, query string, timeoutMs int64) ([]map[string]interface{}, []string, error) { +func queryWithMessagesOptionalTimeout(inst agentQueryRunner, query string, timeoutMs int64) ([]map[string]interface{}, []string, []string, error) { effectiveTimeoutMs := timeoutMs if effectiveTimeoutMs <= 0 && strings.EqualFold(strings.TrimSpace(agentDriverType), "clickhouse") { effectiveTimeoutMs = int64(legacyClickHouseDefaultTimeout / time.Millisecond) } if effectiveTimeoutMs <= 0 { - return inst.Query(query) + if q, ok := inst.(agentQueryMessageRunner); ok { + return q.QueryWithMessages(query) + } + data, fields, err := inst.Query(query) + return data, fields, nil, err + } + if q, ok := inst.(agentQueryMessageContextRunner); ok { + ctx, cancel := context.WithTimeout(context.Background(), time.Duration(effectiveTimeoutMs)*time.Millisecond) + defer cancel() + return q.QueryContextWithMessages(ctx, query) } if q, ok := inst.(agentQueryContextRunner); ok { ctx, cancel := context.WithTimeout(context.Background(), time.Duration(effectiveTimeoutMs)*time.Millisecond) defer cancel() - return q.QueryContext(ctx, query) + data, fields, err := q.QueryContext(ctx, query) + return data, fields, nil, err } - return inst.Query(query) + if q, ok := inst.(agentQueryMessageRunner); ok { + return q.QueryWithMessages(query) + } + data, fields, err := inst.Query(query) + return data, fields, nil, err +} + +func queryWithOptionalTimeout(inst agentQueryRunner, query string, timeoutMs int64) ([]map[string]interface{}, []string, error) { + data, fields, _, err := queryWithMessagesOptionalTimeout(inst, query, timeoutMs) + return data, fields, err } func queryStatementWithOptionalTimeout(inst db.StatementExecer, query string, timeoutMs int64) ([]map[string]interface{}, []string, error) { @@ -613,6 +680,74 @@ func queryStatementWithOptionalTimeout(inst db.StatementExecer, query string, ti return queryWithOptionalTimeout(queryRunner, query, timeoutMs) } +func queryStatementWithMessagesOptionalTimeout(inst db.StatementExecer, query string, timeoutMs int64) ([]map[string]interface{}, []string, []string, error) { + queryRunner, ok := inst.(agentQueryRunner) + if !ok { + return nil, nil, nil, fmt.Errorf("当前事务会话不支持查询语句") + } + return queryWithMessagesOptionalTimeout(queryRunner, query, timeoutMs) +} + +func queryMultiWithMessagesOptionalTimeout(inst db.Database, query string, timeoutMs int64) ([]connection.ResultSetData, []string, bool, error) { + effectiveTimeoutMs := timeoutMs + if effectiveTimeoutMs <= 0 && strings.EqualFold(strings.TrimSpace(agentDriverType), "clickhouse") { + effectiveTimeoutMs = int64(legacyClickHouseDefaultTimeout / time.Millisecond) + } + if effectiveTimeoutMs > 0 { + if q, ok := inst.(agentMultiResultMessageContextRunner); ok { + ctx, cancel := context.WithTimeout(context.Background(), time.Duration(effectiveTimeoutMs)*time.Millisecond) + defer cancel() + data, messages, err := q.QueryMultiContextWithMessages(ctx, query) + return data, messages, true, err + } + if q, ok := inst.(agentMultiResultContextRunner); ok { + ctx, cancel := context.WithTimeout(context.Background(), time.Duration(effectiveTimeoutMs)*time.Millisecond) + defer cancel() + data, err := q.QueryMultiContext(ctx, query) + return data, nil, true, err + } + } + if q, ok := inst.(agentMultiResultMessageRunner); ok { + data, messages, err := q.QueryMultiWithMessages(query) + return data, messages, true, err + } + if q, ok := inst.(agentMultiResultRunner); ok { + data, err := q.QueryMulti(query) + return data, nil, true, err + } + return nil, nil, false, nil +} + +func queryMultiStatementWithMessagesOptionalTimeout(inst db.StatementExecer, query string, timeoutMs int64) ([]connection.ResultSetData, []string, bool, error) { + effectiveTimeoutMs := timeoutMs + if effectiveTimeoutMs <= 0 && strings.EqualFold(strings.TrimSpace(agentDriverType), "clickhouse") { + effectiveTimeoutMs = int64(legacyClickHouseDefaultTimeout / time.Millisecond) + } + if effectiveTimeoutMs > 0 { + if q, ok := inst.(agentMultiResultMessageContextRunner); ok { + ctx, cancel := context.WithTimeout(context.Background(), time.Duration(effectiveTimeoutMs)*time.Millisecond) + defer cancel() + data, messages, err := q.QueryMultiContextWithMessages(ctx, query) + return data, messages, true, err + } + if q, ok := inst.(agentMultiResultContextRunner); ok { + ctx, cancel := context.WithTimeout(context.Background(), time.Duration(effectiveTimeoutMs)*time.Millisecond) + defer cancel() + data, err := q.QueryMultiContext(ctx, query) + return data, nil, true, err + } + } + if q, ok := inst.(agentMultiResultMessageRunner); ok { + data, messages, err := q.QueryMultiWithMessages(query) + return data, messages, true, err + } + if q, ok := inst.(agentMultiResultRunner); ok { + data, err := q.QueryMulti(query) + return data, nil, true, err + } + return nil, nil, false, nil +} + func streamWithOptionalTimeout(inst db.StreamQueryExecer, query string, timeoutMs int64, consumer db.QueryStreamConsumer) error { effectiveTimeoutMs := timeoutMs if effectiveTimeoutMs <= 0 && strings.EqualFold(strings.TrimSpace(agentDriverType), "clickhouse") { diff --git a/cmd/optional-driver-agent/main_test.go b/cmd/optional-driver-agent/main_test.go index eac28d6..5db34e9 100644 --- a/cmd/optional-driver-agent/main_test.go +++ b/cmd/optional-driver-agent/main_test.go @@ -101,6 +101,9 @@ type fakeAgentTimeoutDB struct { execCalled bool execContextCalled bool deadlineSet bool + queryMessages []string + multiResults []connection.ResultSetData + multiMessages []string } func (f *fakeAgentTimeoutDB) Connect(config connection.ConnectionConfig) error { return nil } @@ -117,6 +120,14 @@ func (f *fakeAgentTimeoutDB) QueryContext(ctx context.Context, query string) ([] } return []map[string]interface{}{{"ok": 1}}, []string{"ok"}, nil } +func (f *fakeAgentTimeoutDB) QueryWithMessages(query string) ([]map[string]interface{}, []string, []string, error) { + data, fields, err := f.QueryContext(context.Background(), query) + return data, fields, append([]string(nil), f.queryMessages...), err +} +func (f *fakeAgentTimeoutDB) QueryContextWithMessages(ctx context.Context, query string) ([]map[string]interface{}, []string, []string, error) { + data, fields, err := f.QueryContext(ctx, query) + return data, fields, append([]string(nil), f.queryMessages...), err +} func (f *fakeAgentTimeoutDB) Exec(query string) (int64, error) { f.execCalled = true return 0, errors.New("exec should not be called") @@ -150,6 +161,15 @@ func (f *fakeAgentTimeoutDB) GetForeignKeys(dbName, tableName string) ([]connect func (f *fakeAgentTimeoutDB) GetTriggers(dbName, tableName string) ([]connection.TriggerDefinition, error) { return nil, nil } +func (f *fakeAgentTimeoutDB) QueryMultiWithMessages(query string) ([]connection.ResultSetData, []string, error) { + return append([]connection.ResultSetData(nil), f.multiResults...), append([]string(nil), f.multiMessages...), nil +} +func (f *fakeAgentTimeoutDB) QueryMultiContextWithMessages(ctx context.Context, query string) ([]connection.ResultSetData, []string, error) { + if _, ok := ctx.Deadline(); ok { + f.deadlineSet = true + } + return f.QueryMultiWithMessages(query) +} type fakeAgentSessionDB struct { fakeAgentTimeoutDB @@ -165,6 +185,7 @@ type fakeAgentStatementSession struct { queryCalls int execCalls int closed bool + messages []string } func (f *fakeAgentStatementSession) Query(query string) ([]map[string]interface{}, []string, error) { @@ -175,6 +196,14 @@ func (f *fakeAgentStatementSession) QueryContext(ctx context.Context, query stri f.queryCalls++ return []map[string]interface{}{{"session_ok": 1}}, []string{"session_ok"}, nil } +func (f *fakeAgentStatementSession) QueryWithMessages(query string) ([]map[string]interface{}, []string, []string, error) { + data, fields, err := f.QueryContext(context.Background(), query) + return data, fields, append([]string(nil), f.messages...), err +} +func (f *fakeAgentStatementSession) QueryContextWithMessages(ctx context.Context, query string) ([]map[string]interface{}, []string, []string, error) { + data, fields, err := f.QueryContext(ctx, query) + return data, fields, append([]string(nil), f.messages...), err +} func (f *fakeAgentStatementSession) Exec(query string) (int64, error) { return f.ExecContext(context.Background(), query) @@ -297,6 +326,77 @@ func TestQueryWithOptionalTimeout_ClickHouseLegacyModeUsesQueryContext(t *testin } } +func TestHandleRequest_QueryIncludesServerMessages(t *testing.T) { + old := agentDriverType + defer func() { agentDriverType = old }() + agentDriverType = "sqlserver" + + fake := &fakeAgentTimeoutDB{ + queryMessages: []string{"PRINT sql line 1", "PRINT sql line 2"}, + } + runtimeState := &agentRuntime{inst: fake, sessions: make(map[string]db.StatementExecer)} + + resp := handleRequest(runtimeState, agentRequest{ + ID: 11, + Method: agentMethodQuery, + Query: "exec dbo.p_get_select", + TimeoutMs: int64((2 * time.Second).Milliseconds()), + }) + if !resp.Success { + t.Fatalf("query request failed: %s", resp.Error) + } + if len(resp.Messages) != 2 || resp.Messages[0] != "PRINT sql line 1" { + t.Fatalf("expected query messages to be preserved, got %#v", resp.Messages) + } +} + +func TestHandleRequest_QueryMultiIncludesResultSetsAndMessages(t *testing.T) { + old := agentDriverType + defer func() { agentDriverType = old }() + agentDriverType = "sqlserver" + + fake := &fakeAgentTimeoutDB{ + multiResults: []connection.ResultSetData{ + { + StatementIndex: 1, + Rows: []map[string]interface{}{{"name": "master"}}, + Columns: []string{"name"}, + }, + { + StatementIndex: 1, + Rows: []map[string]interface{}{}, + Columns: []string{}, + Messages: []string{"PRINT generated sql"}, + }, + }, + multiMessages: []string{"batch top-level message"}, + } + runtimeState := &agentRuntime{inst: fake, sessions: make(map[string]db.StatementExecer)} + + resp := handleRequest(runtimeState, agentRequest{ + ID: 12, + Method: agentMethodQueryMulti, + Query: "exec dbo.p_get_select", + TimeoutMs: int64((2 * time.Second).Milliseconds()), + }) + if !resp.Success { + t.Fatalf("queryMulti request failed: %s", resp.Error) + } + if len(resp.Messages) != 1 || resp.Messages[0] != "batch top-level message" { + t.Fatalf("expected top-level messages to be preserved, got %#v", resp.Messages) + } + resultSets, ok := resp.Data.([]connection.ResultSetData) + if !ok { + t.Fatalf("expected []connection.ResultSetData, got %T", resp.Data) + } + if len(resultSets) != 2 { + t.Fatalf("expected 2 result sets, got %#v", resultSets) + } + if len(resultSets[1].Messages) != 1 || resultSets[1].Messages[0] != "PRINT generated sql" { + t.Fatalf("expected message-only result set to be preserved, got %#v", resultSets[1]) + } +} + func TestHandleRequest_UsesPinnedSessionForSessionScopedQueryAndExec(t *testing.T) { old := agentDriverType defer func() { agentDriverType = old }() @@ -329,6 +429,9 @@ func TestHandleRequest_UsesPinnedSessionForSessionScopedQueryAndExec(t *testing. if !queryResp.Success { t.Fatalf("session query failed: %s", queryResp.Error) } + if len(queryResp.Messages) != 0 { + t.Fatalf("expected empty default session messages, got %#v", queryResp.Messages) + } if fake.queryCalled || fake.queryContextCalled { t.Fatalf("expected session query to bypass database-level query path, got Query=%v QueryContext=%v", fake.queryCalled, fake.queryContextCalled) } diff --git a/internal/db/driver_agent_revisions_gen.go b/internal/db/driver_agent_revisions_gen.go index e05b8c9..12e7049 100644 --- a/internal/db/driver_agent_revisions_gen.go +++ b/internal/db/driver_agent_revisions_gen.go @@ -4,26 +4,26 @@ package db func init() { optionalDriverAgentRevisions = map[string]string{ - "mariadb": "src-b23e2ce1581a5064", - "oceanbase": "src-5067dbdf0ca7b9c4", - "diros": "src-db43faca6bf15d9b", - "starrocks": "src-01e9f06c0fab09d5", - "sphinx": "src-38ee5cae952cc809", - "sqlserver": "src-7a87f6deb816f110", - "sqlite": "src-d3d439cd788880e2", - "duckdb": "src-b11506b8706bfb73", - "dameng": "src-1638124bfd7fce09", - "kingbase": "src-fb3a404cf4eb1bd9", - "highgo": "src-72fe51afa884f6bc", - "vastbase": "src-3d48607603bfd8b7", - "opengauss": "src-709acf442f016e30", - "gaussdb": "src-f6beccc924d71031", - "iris": "src-9ebf5b970a73b341", - "mongodb": "src-367d11cd04e982c1", - "tdengine": "src-3c13c42f18ba01e1", - "iotdb": "src-5ba9da13c6a272f9", - "clickhouse": "src-99c8babfefdf142c", - "elasticsearch": "src-36b2e2b5f49db9d1", - "trino": "src-d264ceca132c185c", + "mariadb": "src-cc133d2524ceb634", + "oceanbase": "src-ac17327184366ff0", + "diros": "src-7d4fe439271d0c56", + "starrocks": "src-ce9ee22641a32f46", + "sphinx": "src-08f5ae54efb3d9df", + "sqlserver": "src-33b3b2c6dad5b3e6", + "sqlite": "src-96dfa25b3042b2d5", + "duckdb": "src-8804eb2cdbc89433", + "dameng": "src-016e77082aea6718", + "kingbase": "src-17728b2ebda94dc9", + "highgo": "src-da2e8a9d2e661d3b", + "vastbase": "src-da186ac367206c16", + "opengauss": "src-54dc852e4c502947", + "gaussdb": "src-3bbbffc6991dc8ae", + "iris": "src-e798713e492e9a09", + "mongodb": "src-2610395b35c2e708", + "tdengine": "src-779b9b537f08856f", + "iotdb": "src-7edea4aba8d4869e", + "clickhouse": "src-0197342ca5afa8b5", + "elasticsearch": "src-08e8e80cb17a409a", + "trino": "src-ba947f211ce7b19f", } } diff --git a/internal/db/optional_driver_agent_impl.go b/internal/db/optional_driver_agent_impl.go index 3ec52b8..07b0762 100644 --- a/internal/db/optional_driver_agent_impl.go +++ b/internal/db/optional_driver_agent_impl.go @@ -29,6 +29,7 @@ const ( optionalAgentMethodOpenSession = "openSession" optionalAgentMethodCloseSession = "closeSession" optionalAgentMethodQuery = "query" + optionalAgentMethodQueryMulti = "queryMulti" optionalAgentMethodStreamQuery = "streamQuery" optionalAgentMethodExec = "exec" optionalAgentMethodGetDatabases = "getDatabases" @@ -75,6 +76,7 @@ type optionalAgentResponse struct { Error string `json:"error,omitempty"` Data json.RawMessage `json:"data,omitempty"` Fields []string `json:"fields,omitempty"` + Messages []string `json:"messages,omitempty"` ChunkType string `json:"chunkType,omitempty"` RowsAffected int64 `json:"rowsAffected,omitempty"` } @@ -106,7 +108,7 @@ func ProbeOptionalDriverAgentMetadata(driverType string, executablePath string) }() var metadata OptionalDriverAgentMetadata - if err := client.callWithTimeout(optionalAgentRequest{Method: optionalAgentMethodMetadata}, &metadata, nil, nil, optionalAgentMetadataProbeTimeout); err != nil { + if err := client.callWithTimeout(optionalAgentRequest{Method: optionalAgentMethodMetadata}, &metadata, nil, nil, nil, optionalAgentMetadataProbeTimeout); err != nil { return OptionalDriverAgentMetadata{}, err } metadata.DriverType = normalizeRuntimeDriverType(metadata.DriverType) @@ -208,7 +210,7 @@ func (c *optionalDriverAgentClient) stderrText() string { return strings.TrimSpace(c.stderr.String()) } -func (c *optionalDriverAgentClient) call(req optionalAgentRequest, out interface{}, fields *[]string, rowsAffected *int64) error { +func (c *optionalDriverAgentClient) call(req optionalAgentRequest, out interface{}, fields *[]string, messages *[]string, rowsAffected *int64) error { c.mu.Lock() defer c.mu.Unlock() @@ -252,6 +254,9 @@ func (c *optionalDriverAgentClient) call(req optionalAgentRequest, out interface if fields != nil { *fields = resp.Fields } + if messages != nil { + *messages = append((*messages)[:0], resp.Messages...) + } if rowsAffected != nil { *rowsAffected = resp.RowsAffected } @@ -263,14 +268,14 @@ func (c *optionalDriverAgentClient) call(req optionalAgentRequest, out interface return nil } -func (c *optionalDriverAgentClient) callWithTimeout(req optionalAgentRequest, out interface{}, fields *[]string, rowsAffected *int64, timeout time.Duration) error { +func (c *optionalDriverAgentClient) callWithTimeout(req optionalAgentRequest, out interface{}, fields *[]string, messages *[]string, rowsAffected *int64, timeout time.Duration) error { if timeout <= 0 { - return c.call(req, out, fields, rowsAffected) + return c.call(req, out, fields, messages, rowsAffected) } errCh := make(chan error, 1) go func() { - errCh <- c.call(req, out, fields, rowsAffected) + errCh <- c.call(req, out, fields, messages, rowsAffected) }() timer := time.NewTimer(timeout) @@ -469,7 +474,7 @@ func (d *OptionalDriverAgentDB) Connect(config connection.ConnectionConfig) erro if err := client.call(optionalAgentRequest{ Method: optionalAgentMethodConnect, Config: &config, - }, nil, nil, nil); err != nil { + }, nil, nil, nil, nil); err != nil { _ = client.close() return err } @@ -482,7 +487,7 @@ func (d *OptionalDriverAgentDB) Close() error { if d.client == nil { return nil } - _ = d.client.call(optionalAgentRequest{Method: optionalAgentMethodClose}, nil, nil, nil) + _ = d.client.call(optionalAgentRequest{Method: optionalAgentMethodClose}, nil, nil, nil, nil) err := d.client.close() d.client = nil return err @@ -493,10 +498,87 @@ func (d *OptionalDriverAgentDB) Ping() error { if err != nil { return err } - return client.call(optionalAgentRequest{Method: optionalAgentMethodPing}, nil, nil, nil) + return client.call(optionalAgentRequest{Method: optionalAgentMethodPing}, nil, nil, nil, nil) } func (d *OptionalDriverAgentDB) QueryContext(ctx context.Context, query string) ([]map[string]interface{}, []string, error) { + data, fields, _, err := d.QueryContextWithMessages(ctx, query) + return data, fields, err +} + +func (d *OptionalDriverAgentDB) QueryContextWithMessages(ctx context.Context, query string) ([]map[string]interface{}, []string, []string, error) { + if err := ctx.Err(); err != nil { + return nil, nil, nil, err + } + client, err := d.requireClient() + if err != nil { + return nil, nil, nil, err + } + var data []map[string]interface{} + var fields []string + var messages []string + if err := client.call(optionalAgentRequest{ + Method: optionalAgentMethodQuery, + Query: query, + TimeoutMs: timeoutMsFromContext(ctx), + }, &data, &fields, &messages, nil); err != nil { + return nil, nil, nil, err + } + return data, fields, messages, nil +} + +func (d *OptionalDriverAgentDB) Query(query string) ([]map[string]interface{}, []string, error) { + data, fields, _, err := d.QueryWithMessages(query) + return data, fields, err +} + +func (d *OptionalDriverAgentDB) QueryWithMessages(query string) ([]map[string]interface{}, []string, []string, error) { + client, err := d.requireClient() + if err != nil { + return nil, nil, nil, err + } + var data []map[string]interface{} + var fields []string + var messages []string + if err := client.call(optionalAgentRequest{ + Method: optionalAgentMethodQuery, + Query: query, + }, &data, &fields, &messages, nil); err != nil { + return nil, nil, nil, err + } + return data, fields, messages, nil +} + +func (d *OptionalDriverAgentDB) QueryMulti(query string) ([]connection.ResultSetData, error) { + results, _, err := d.QueryMultiWithMessages(query) + return results, err +} + +func (d *OptionalDriverAgentDB) QueryMultiWithMessages(query string) ([]connection.ResultSetData, []string, error) { + client, err := d.requireClient() + if err != nil { + return nil, nil, err + } + var results []connection.ResultSetData + var messages []string + if err := client.call(optionalAgentRequest{ + Method: optionalAgentMethodQueryMulti, + Query: query, + }, &results, nil, &messages, nil); err != nil { + if isOptionalAgentMultiResultUnsupportedError(err) { + return nil, nil, nil + } + return nil, nil, err + } + return results, messages, nil +} + +func (d *OptionalDriverAgentDB) QueryMultiContext(ctx context.Context, query string) ([]connection.ResultSetData, error) { + results, _, err := d.QueryMultiContextWithMessages(ctx, query) + return results, err +} + +func (d *OptionalDriverAgentDB) QueryMultiContextWithMessages(ctx context.Context, query string) ([]connection.ResultSetData, []string, error) { if err := ctx.Err(); err != nil { return nil, nil, err } @@ -504,32 +586,19 @@ func (d *OptionalDriverAgentDB) QueryContext(ctx context.Context, query string) if err != nil { return nil, nil, err } - var data []map[string]interface{} - var fields []string + var results []connection.ResultSetData + var messages []string if err := client.call(optionalAgentRequest{ - Method: optionalAgentMethodQuery, + Method: optionalAgentMethodQueryMulti, Query: query, TimeoutMs: timeoutMsFromContext(ctx), - }, &data, &fields, nil); err != nil { + }, &results, nil, &messages, nil); err != nil { + if isOptionalAgentMultiResultUnsupportedError(err) { + return nil, nil, nil + } return nil, nil, err } - return data, fields, nil -} - -func (d *OptionalDriverAgentDB) Query(query string) ([]map[string]interface{}, []string, error) { - client, err := d.requireClient() - if err != nil { - return nil, nil, err - } - var data []map[string]interface{} - var fields []string - if err := client.call(optionalAgentRequest{ - Method: optionalAgentMethodQuery, - Query: query, - }, &data, &fields, nil); err != nil { - return nil, nil, err - } - return data, fields, nil + return results, messages, nil } func (d *OptionalDriverAgentDB) StreamQuery(query string, consumer QueryStreamConsumer) error { @@ -581,7 +650,7 @@ func (d *OptionalDriverAgentDB) ExecContext(ctx context.Context, query string) ( Method: optionalAgentMethodExec, Query: query, TimeoutMs: timeoutMsFromContext(ctx), - }, nil, nil, &affected); err != nil { + }, nil, nil, nil, &affected); err != nil { return 0, err } return affected, nil @@ -596,7 +665,7 @@ func (d *OptionalDriverAgentDB) Exec(query string) (int64, error) { if err := client.call(optionalAgentRequest{ Method: optionalAgentMethodExec, Query: query, - }, nil, nil, &affected); err != nil { + }, nil, nil, nil, &affected); err != nil { return 0, err } return affected, nil @@ -611,7 +680,7 @@ func (d *OptionalDriverAgentDB) OpenSessionExecer(ctx context.Context) (Statemen if err := client.call(optionalAgentRequest{ Method: optionalAgentMethodOpenSession, TimeoutMs: timeoutMsFromContext(ctx), - }, &sessionID, nil, nil); err != nil { + }, &sessionID, nil, nil, nil); err != nil { return nil, err } sessionID = strings.TrimSpace(sessionID) @@ -629,6 +698,10 @@ func (s *optionalDriverAgentSession) Query(query string) ([]map[string]interface return s.QueryContext(context.Background(), query) } +func (s *optionalDriverAgentSession) QueryWithMessages(query string) ([]map[string]interface{}, []string, []string, error) { + return s.QueryContextWithMessages(context.Background(), query) +} + func (s *optionalDriverAgentSession) StreamQuery(query string, consumer QueryStreamConsumer) error { return s.StreamQueryContext(context.Background(), query, consumer) } @@ -663,20 +736,26 @@ func (s *optionalDriverAgentSession) StreamQueryContext(ctx context.Context, que } func (s *optionalDriverAgentSession) QueryContext(ctx context.Context, query string) ([]map[string]interface{}, []string, error) { + data, fields, _, err := s.QueryContextWithMessages(ctx, query) + return data, fields, err +} + +func (s *optionalDriverAgentSession) QueryContextWithMessages(ctx context.Context, query string) ([]map[string]interface{}, []string, []string, error) { if err := s.ensureOpen(); err != nil { - return nil, nil, err + return nil, nil, nil, err } var data []map[string]interface{} var fields []string + var messages []string if err := s.client.call(optionalAgentRequest{ Method: optionalAgentMethodQuery, SessionID: s.sessionID, Query: query, TimeoutMs: timeoutMsFromContext(ctx), - }, &data, &fields, nil); err != nil { - return nil, nil, err + }, &data, &fields, &messages, nil); err != nil { + return nil, nil, nil, err } - return data, fields, nil + return data, fields, messages, nil } func (s *optionalDriverAgentSession) Exec(query string) (int64, error) { @@ -693,7 +772,7 @@ func (s *optionalDriverAgentSession) ExecContext(ctx context.Context, query stri SessionID: s.sessionID, Query: query, TimeoutMs: timeoutMsFromContext(ctx), - }, nil, nil, &affected); err != nil { + }, nil, nil, nil, &affected); err != nil { return 0, err } return affected, nil @@ -714,7 +793,7 @@ func (s *optionalDriverAgentSession) Close() error { return s.client.call(optionalAgentRequest{ Method: optionalAgentMethodCloseSession, SessionID: sessionID, - }, nil, nil, nil) + }, nil, nil, nil, nil) } func (s *optionalDriverAgentSession) ensureOpen() error { @@ -740,6 +819,19 @@ func isOptionalAgentStreamUnsupportedError(err error) bool { return strings.Contains(text, "不支持的方法") || strings.Contains(text, "不支持流式查询") } +func isOptionalAgentMultiResultUnsupportedError(err error) bool { + if err == nil { + return false + } + text := strings.TrimSpace(err.Error()) + if text == "" { + return false + } + return strings.Contains(text, "不支持的方法") || + strings.Contains(text, "不支持原生多结果集查询") || + strings.Contains(text, "不支持多结果集查询") +} + func (d *OptionalDriverAgentDB) GetDatabases() ([]string, error) { client, err := d.requireClient() if err != nil { @@ -748,7 +840,7 @@ func (d *OptionalDriverAgentDB) GetDatabases() ([]string, error) { var dbs []string if err := client.call(optionalAgentRequest{ Method: optionalAgentMethodGetDatabases, - }, &dbs, nil, nil); err != nil { + }, &dbs, nil, nil, nil); err != nil { return nil, err } return dbs, nil @@ -763,7 +855,7 @@ func (d *OptionalDriverAgentDB) GetTables(dbName string) ([]string, error) { if err := client.call(optionalAgentRequest{ Method: optionalAgentMethodGetTables, DBName: dbName, - }, &tables, nil, nil); err != nil { + }, &tables, nil, nil, nil); err != nil { return nil, err } return tables, nil @@ -779,7 +871,7 @@ func (d *OptionalDriverAgentDB) GetCreateStatement(dbName, tableName string) (st Method: optionalAgentMethodGetCreateStmt, DBName: dbName, TableName: tableName, - }, &sqlText, nil, nil); err != nil { + }, &sqlText, nil, nil, nil); err != nil { return "", err } return sqlText, nil @@ -795,7 +887,7 @@ func (d *OptionalDriverAgentDB) GetColumns(dbName, tableName string) ([]connecti Method: optionalAgentMethodGetColumns, DBName: dbName, TableName: tableName, - }, &columns, nil, nil); err != nil { + }, &columns, nil, nil, nil); err != nil { return nil, err } return columns, nil @@ -810,7 +902,7 @@ func (d *OptionalDriverAgentDB) GetAllColumns(dbName string) ([]connection.Colum if err := client.call(optionalAgentRequest{ Method: optionalAgentMethodGetAllColumns, DBName: dbName, - }, &columns, nil, nil); err != nil { + }, &columns, nil, nil, nil); err != nil { return nil, err } return columns, nil @@ -826,7 +918,7 @@ func (d *OptionalDriverAgentDB) GetIndexes(dbName, tableName string) ([]connecti Method: optionalAgentMethodGetIndexes, DBName: dbName, TableName: tableName, - }, &indexes, nil, nil); err != nil { + }, &indexes, nil, nil, nil); err != nil { return nil, err } return indexes, nil @@ -842,7 +934,7 @@ func (d *OptionalDriverAgentDB) GetForeignKeys(dbName, tableName string) ([]conn Method: optionalAgentMethodGetForeignKeys, DBName: dbName, TableName: tableName, - }, &keys, nil, nil); err != nil { + }, &keys, nil, nil, nil); err != nil { return nil, err } return keys, nil @@ -858,7 +950,7 @@ func (d *OptionalDriverAgentDB) GetTriggers(dbName, tableName string) ([]connect Method: optionalAgentMethodGetTriggers, DBName: dbName, TableName: tableName, - }, &triggers, nil, nil); err != nil { + }, &triggers, nil, nil, nil); err != nil { return nil, err } return triggers, nil @@ -883,7 +975,7 @@ func (d *OptionalDriverAgentDB) ApplyChanges(tableName string, changes connectio Method: optionalAgentMethodApplyChanges, TableName: tableName, Changes: &changes, - }, nil, nil, nil) + }, nil, nil, nil, nil) } func (d *OptionalDriverAgentDB) requireClient() (*optionalDriverAgentClient, error) { diff --git a/internal/db/optional_driver_agent_impl_test.go b/internal/db/optional_driver_agent_impl_test.go index 90600db..e1612ac 100644 --- a/internal/db/optional_driver_agent_impl_test.go +++ b/internal/db/optional_driver_agent_impl_test.go @@ -136,3 +136,68 @@ func TestOptionalDriverAgentClientCallStreamQueryConsumesChunks(t *testing.T) { t.Fatalf("请求未使用 streamQuery 方法: %s", stdin.String()) } } + +func TestOptionalDriverAgentDBQueryWithMessagesParsesAgentMessages(t *testing.T) { + var stdin optionalAgentTestWriteCloser + stdout := `{"id":1,"success":true,"data":[{"sql_text":"select 1"}],"fields":["sql_text"],"messages":["PRINT sql line 1","PRINT sql line 2"]}` + "\n" + + dbInst := &OptionalDriverAgentDB{ + driverType: "sqlserver", + client: &optionalDriverAgentClient{ + stdin: &stdin, + reader: bufio.NewReader(strings.NewReader(stdout)), + driver: "sqlserver", + }, + } + + rows, fields, messages, err := dbInst.QueryWithMessages("exec dbo.p_get_select") + if err != nil { + t.Fatalf("QueryWithMessages 返回错误: %v", err) + } + if len(rows) != 1 || rows[0]["sql_text"] != "select 1" { + t.Fatalf("查询结果异常: %#v", rows) + } + if len(fields) != 1 || fields[0] != "sql_text" { + t.Fatalf("字段异常: %#v", fields) + } + if len(messages) != 2 || messages[0] != "PRINT sql line 1" { + t.Fatalf("消息异常: %#v", messages) + } + if !strings.Contains(stdin.String(), `"method":"query"`) { + t.Fatalf("请求未使用 query 方法: %s", stdin.String()) + } +} + +func TestOptionalDriverAgentDBQueryMultiWithMessagesParsesResultSets(t *testing.T) { + var stdin optionalAgentTestWriteCloser + stdout := `{"id":1,"success":true,"data":[{"statementIndex":1,"rows":[{"name":"master"}],"columns":["name"]},{"statementIndex":1,"rows":[],"columns":[],"messages":["PRINT generated sql"]}],"messages":["batch top-level message"]}` + "\n" + + dbInst := &OptionalDriverAgentDB{ + driverType: "sqlserver", + client: &optionalDriverAgentClient{ + stdin: &stdin, + reader: bufio.NewReader(strings.NewReader(stdout)), + driver: "sqlserver", + }, + } + + resultSets, messages, err := dbInst.QueryMultiWithMessages("exec dbo.p_get_select") + if err != nil { + t.Fatalf("QueryMultiWithMessages 返回错误: %v", err) + } + if len(resultSets) != 2 { + t.Fatalf("结果集数量异常: %#v", resultSets) + } + if got := resultSets[0].Rows[0]["name"]; got != "master" { + t.Fatalf("首个结果集异常,got=%v", got) + } + if len(resultSets[1].Messages) != 1 || resultSets[1].Messages[0] != "PRINT generated sql" { + t.Fatalf("消息结果集异常: %#v", resultSets[1]) + } + if len(messages) != 1 || messages[0] != "batch top-level message" { + t.Fatalf("顶层消息异常: %#v", messages) + } + if !strings.Contains(stdin.String(), `"method":"queryMulti"`) { + t.Fatalf("请求未使用 queryMulti 方法: %s", stdin.String()) + } +}