🐛 fix(sqlserver): 修复可选驱动查询消息透传缺失

- 为 optional-driver-agent 的 query 和 queryMulti 响应补充 messages 字段
- 在可选驱动 DB 客户端透传 SQL Server 查询提示信息与多结果集
- 补充 agent 与数据库层回归测试并更新 driver agent revision
This commit is contained in:
Syngnat
2026-06-23 08:48:42 +08:00
parent 8f1e6cf379
commit 495a985ae1
5 changed files with 472 additions and 77 deletions

View File

@@ -36,6 +36,7 @@ type agentResponse struct {
Error string `json:"error,omitempty"`
Data interface{} `json:"data,omitempty"`
Fields []string `json:"fields,omitempty"`
Messages []string `json:"messages,omitempty"`
ChunkType string `json:"chunkType,omitempty"`
RowsAffected int64 `json:"rowsAffected,omitempty"`
}
@@ -48,6 +49,7 @@ const (
agentMethodOpenSession = "openSession"
agentMethodCloseSession = "closeSession"
agentMethodQuery = "query"
agentMethodQueryMulti = "queryMulti"
agentMethodStreamQuery = "streamQuery"
agentMethodExec = "exec"
agentMethodGetDatabases = "getDatabases"
@@ -64,9 +66,9 @@ const (
const legacyClickHouseDefaultTimeout = 2 * time.Hour
const (
agentChunkColumns = "columns"
agentChunkRows = "rows"
agentChunkDone = "done"
agentChunkColumns = "columns"
agentChunkRows = "rows"
agentChunkDone = "done"
// agentStreamBatchSize 控制 driver-agent 向主进程发送 row chunk 的批次大小。
// 调小到 64单批 JSON 编码 + 主进程解码的瞬时内存峰值降为原来的 1/4
// 代价是 IPC 次数变为 4 倍,但每批仅一次 stdin/stdout 行读写,整体影响可忽略。
@@ -236,12 +238,23 @@ func handleRequest(runtimeState *agentRuntime, req agentRequest) agentResponse {
} else if ok {
switch method {
case agentMethodQuery:
data, fields, err := queryStatementWithOptionalTimeout(session, req.Query, req.TimeoutMs)
data, fields, messages, err := queryStatementWithMessagesOptionalTimeout(session, req.Query, req.TimeoutMs)
if err != nil {
return fail(resp, err.Error())
}
resp.Data = data
resp.Fields = fields
resp.Messages = messages
case agentMethodQueryMulti:
data, messages, supported, err := queryMultiStatementWithMessagesOptionalTimeout(session, req.Query, req.TimeoutMs)
if err != nil {
return fail(resp, err.Error())
}
if !supported {
return fail(resp, "当前事务会话不支持多结果集查询")
}
resp.Data = data
resp.Messages = messages
case agentMethodExec:
affected, err := execStatementWithOptionalTimeout(session, req.Query, req.TimeoutMs)
if err != nil {
@@ -260,12 +273,23 @@ func handleRequest(runtimeState *agentRuntime, req agentRequest) agentResponse {
return fail(resp, err.Error())
}
case agentMethodQuery:
data, fields, err := queryWithOptionalTimeout(runtimeState.inst, req.Query, req.TimeoutMs)
data, fields, messages, err := queryWithMessagesOptionalTimeout(runtimeState.inst, req.Query, req.TimeoutMs)
if err != nil {
return fail(resp, err.Error())
}
resp.Data = data
resp.Fields = fields
resp.Messages = messages
case agentMethodQueryMulti:
data, messages, supported, err := queryMultiWithMessagesOptionalTimeout(runtimeState.inst, req.Query, req.TimeoutMs)
if err != nil {
return fail(resp, err.Error())
}
if !supported {
return fail(resp, "当前驱动不支持原生多结果集查询")
}
resp.Data = data
resp.Messages = messages
case agentMethodExec:
affected, err := execWithOptionalTimeout(runtimeState.inst, req.Query, req.TimeoutMs)
if err != nil {
@@ -581,6 +605,30 @@ type agentQueryContextRunner interface {
QueryContext(context.Context, string) ([]map[string]interface{}, []string, error)
}
type agentQueryMessageRunner interface {
QueryWithMessages(query string) ([]map[string]interface{}, []string, []string, error)
}
type agentQueryMessageContextRunner interface {
QueryContextWithMessages(context.Context, string) ([]map[string]interface{}, []string, []string, error)
}
type agentMultiResultMessageRunner interface {
QueryMultiWithMessages(query string) ([]connection.ResultSetData, []string, error)
}
type agentMultiResultMessageContextRunner interface {
QueryMultiContextWithMessages(context.Context, string) ([]connection.ResultSetData, []string, error)
}
type agentMultiResultRunner interface {
QueryMulti(query string) ([]connection.ResultSetData, error)
}
type agentMultiResultContextRunner interface {
QueryMultiContext(context.Context, string) ([]connection.ResultSetData, error)
}
type agentExecRunner interface {
Exec(string) (int64, error)
}
@@ -589,20 +637,39 @@ type agentExecContextRunner interface {
ExecContext(context.Context, string) (int64, error)
}
func queryWithOptionalTimeout(inst agentQueryRunner, query string, timeoutMs int64) ([]map[string]interface{}, []string, error) {
func queryWithMessagesOptionalTimeout(inst agentQueryRunner, query string, timeoutMs int64) ([]map[string]interface{}, []string, []string, error) {
effectiveTimeoutMs := timeoutMs
if effectiveTimeoutMs <= 0 && strings.EqualFold(strings.TrimSpace(agentDriverType), "clickhouse") {
effectiveTimeoutMs = int64(legacyClickHouseDefaultTimeout / time.Millisecond)
}
if effectiveTimeoutMs <= 0 {
return inst.Query(query)
if q, ok := inst.(agentQueryMessageRunner); ok {
return q.QueryWithMessages(query)
}
data, fields, err := inst.Query(query)
return data, fields, nil, err
}
if q, ok := inst.(agentQueryMessageContextRunner); ok {
ctx, cancel := context.WithTimeout(context.Background(), time.Duration(effectiveTimeoutMs)*time.Millisecond)
defer cancel()
return q.QueryContextWithMessages(ctx, query)
}
if q, ok := inst.(agentQueryContextRunner); ok {
ctx, cancel := context.WithTimeout(context.Background(), time.Duration(effectiveTimeoutMs)*time.Millisecond)
defer cancel()
return q.QueryContext(ctx, query)
data, fields, err := q.QueryContext(ctx, query)
return data, fields, nil, err
}
return inst.Query(query)
if q, ok := inst.(agentQueryMessageRunner); ok {
return q.QueryWithMessages(query)
}
data, fields, err := inst.Query(query)
return data, fields, nil, err
}
func queryWithOptionalTimeout(inst agentQueryRunner, query string, timeoutMs int64) ([]map[string]interface{}, []string, error) {
data, fields, _, err := queryWithMessagesOptionalTimeout(inst, query, timeoutMs)
return data, fields, err
}
func queryStatementWithOptionalTimeout(inst db.StatementExecer, query string, timeoutMs int64) ([]map[string]interface{}, []string, error) {
@@ -613,6 +680,74 @@ func queryStatementWithOptionalTimeout(inst db.StatementExecer, query string, ti
return queryWithOptionalTimeout(queryRunner, query, timeoutMs)
}
func queryStatementWithMessagesOptionalTimeout(inst db.StatementExecer, query string, timeoutMs int64) ([]map[string]interface{}, []string, []string, error) {
queryRunner, ok := inst.(agentQueryRunner)
if !ok {
return nil, nil, nil, fmt.Errorf("当前事务会话不支持查询语句")
}
return queryWithMessagesOptionalTimeout(queryRunner, query, timeoutMs)
}
func queryMultiWithMessagesOptionalTimeout(inst db.Database, query string, timeoutMs int64) ([]connection.ResultSetData, []string, bool, error) {
effectiveTimeoutMs := timeoutMs
if effectiveTimeoutMs <= 0 && strings.EqualFold(strings.TrimSpace(agentDriverType), "clickhouse") {
effectiveTimeoutMs = int64(legacyClickHouseDefaultTimeout / time.Millisecond)
}
if effectiveTimeoutMs > 0 {
if q, ok := inst.(agentMultiResultMessageContextRunner); ok {
ctx, cancel := context.WithTimeout(context.Background(), time.Duration(effectiveTimeoutMs)*time.Millisecond)
defer cancel()
data, messages, err := q.QueryMultiContextWithMessages(ctx, query)
return data, messages, true, err
}
if q, ok := inst.(agentMultiResultContextRunner); ok {
ctx, cancel := context.WithTimeout(context.Background(), time.Duration(effectiveTimeoutMs)*time.Millisecond)
defer cancel()
data, err := q.QueryMultiContext(ctx, query)
return data, nil, true, err
}
}
if q, ok := inst.(agentMultiResultMessageRunner); ok {
data, messages, err := q.QueryMultiWithMessages(query)
return data, messages, true, err
}
if q, ok := inst.(agentMultiResultRunner); ok {
data, err := q.QueryMulti(query)
return data, nil, true, err
}
return nil, nil, false, nil
}
func queryMultiStatementWithMessagesOptionalTimeout(inst db.StatementExecer, query string, timeoutMs int64) ([]connection.ResultSetData, []string, bool, error) {
effectiveTimeoutMs := timeoutMs
if effectiveTimeoutMs <= 0 && strings.EqualFold(strings.TrimSpace(agentDriverType), "clickhouse") {
effectiveTimeoutMs = int64(legacyClickHouseDefaultTimeout / time.Millisecond)
}
if effectiveTimeoutMs > 0 {
if q, ok := inst.(agentMultiResultMessageContextRunner); ok {
ctx, cancel := context.WithTimeout(context.Background(), time.Duration(effectiveTimeoutMs)*time.Millisecond)
defer cancel()
data, messages, err := q.QueryMultiContextWithMessages(ctx, query)
return data, messages, true, err
}
if q, ok := inst.(agentMultiResultContextRunner); ok {
ctx, cancel := context.WithTimeout(context.Background(), time.Duration(effectiveTimeoutMs)*time.Millisecond)
defer cancel()
data, err := q.QueryMultiContext(ctx, query)
return data, nil, true, err
}
}
if q, ok := inst.(agentMultiResultMessageRunner); ok {
data, messages, err := q.QueryMultiWithMessages(query)
return data, messages, true, err
}
if q, ok := inst.(agentMultiResultRunner); ok {
data, err := q.QueryMulti(query)
return data, nil, true, err
}
return nil, nil, false, nil
}
func streamWithOptionalTimeout(inst db.StreamQueryExecer, query string, timeoutMs int64, consumer db.QueryStreamConsumer) error {
effectiveTimeoutMs := timeoutMs
if effectiveTimeoutMs <= 0 && strings.EqualFold(strings.TrimSpace(agentDriverType), "clickhouse") {

View File

@@ -101,6 +101,9 @@ type fakeAgentTimeoutDB struct {
execCalled bool
execContextCalled bool
deadlineSet bool
queryMessages []string
multiResults []connection.ResultSetData
multiMessages []string
}
func (f *fakeAgentTimeoutDB) Connect(config connection.ConnectionConfig) error { return nil }
@@ -117,6 +120,14 @@ func (f *fakeAgentTimeoutDB) QueryContext(ctx context.Context, query string) ([]
}
return []map[string]interface{}{{"ok": 1}}, []string{"ok"}, nil
}
func (f *fakeAgentTimeoutDB) QueryWithMessages(query string) ([]map[string]interface{}, []string, []string, error) {
data, fields, err := f.QueryContext(context.Background(), query)
return data, fields, append([]string(nil), f.queryMessages...), err
}
func (f *fakeAgentTimeoutDB) QueryContextWithMessages(ctx context.Context, query string) ([]map[string]interface{}, []string, []string, error) {
data, fields, err := f.QueryContext(ctx, query)
return data, fields, append([]string(nil), f.queryMessages...), err
}
func (f *fakeAgentTimeoutDB) Exec(query string) (int64, error) {
f.execCalled = true
return 0, errors.New("exec should not be called")
@@ -150,6 +161,15 @@ func (f *fakeAgentTimeoutDB) GetForeignKeys(dbName, tableName string) ([]connect
func (f *fakeAgentTimeoutDB) GetTriggers(dbName, tableName string) ([]connection.TriggerDefinition, error) {
return nil, nil
}
func (f *fakeAgentTimeoutDB) QueryMultiWithMessages(query string) ([]connection.ResultSetData, []string, error) {
return append([]connection.ResultSetData(nil), f.multiResults...), append([]string(nil), f.multiMessages...), nil
}
func (f *fakeAgentTimeoutDB) QueryMultiContextWithMessages(ctx context.Context, query string) ([]connection.ResultSetData, []string, error) {
if _, ok := ctx.Deadline(); ok {
f.deadlineSet = true
}
return f.QueryMultiWithMessages(query)
}
type fakeAgentSessionDB struct {
fakeAgentTimeoutDB
@@ -165,6 +185,7 @@ type fakeAgentStatementSession struct {
queryCalls int
execCalls int
closed bool
messages []string
}
func (f *fakeAgentStatementSession) Query(query string) ([]map[string]interface{}, []string, error) {
@@ -175,6 +196,14 @@ func (f *fakeAgentStatementSession) QueryContext(ctx context.Context, query stri
f.queryCalls++
return []map[string]interface{}{{"session_ok": 1}}, []string{"session_ok"}, nil
}
func (f *fakeAgentStatementSession) QueryWithMessages(query string) ([]map[string]interface{}, []string, []string, error) {
data, fields, err := f.QueryContext(context.Background(), query)
return data, fields, append([]string(nil), f.messages...), err
}
func (f *fakeAgentStatementSession) QueryContextWithMessages(ctx context.Context, query string) ([]map[string]interface{}, []string, []string, error) {
data, fields, err := f.QueryContext(ctx, query)
return data, fields, append([]string(nil), f.messages...), err
}
func (f *fakeAgentStatementSession) Exec(query string) (int64, error) {
return f.ExecContext(context.Background(), query)
@@ -297,6 +326,77 @@ func TestQueryWithOptionalTimeout_ClickHouseLegacyModeUsesQueryContext(t *testin
}
}
func TestHandleRequest_QueryIncludesServerMessages(t *testing.T) {
old := agentDriverType
defer func() { agentDriverType = old }()
agentDriverType = "sqlserver"
fake := &fakeAgentTimeoutDB{
queryMessages: []string{"PRINT sql line 1", "PRINT sql line 2"},
}
runtimeState := &agentRuntime{inst: fake, sessions: make(map[string]db.StatementExecer)}
resp := handleRequest(runtimeState, agentRequest{
ID: 11,
Method: agentMethodQuery,
Query: "exec dbo.p_get_select",
TimeoutMs: int64((2 * time.Second).Milliseconds()),
})
if !resp.Success {
t.Fatalf("query request failed: %s", resp.Error)
}
if len(resp.Messages) != 2 || resp.Messages[0] != "PRINT sql line 1" {
t.Fatalf("expected query messages to be preserved, got %#v", resp.Messages)
}
}
func TestHandleRequest_QueryMultiIncludesResultSetsAndMessages(t *testing.T) {
old := agentDriverType
defer func() { agentDriverType = old }()
agentDriverType = "sqlserver"
fake := &fakeAgentTimeoutDB{
multiResults: []connection.ResultSetData{
{
StatementIndex: 1,
Rows: []map[string]interface{}{{"name": "master"}},
Columns: []string{"name"},
},
{
StatementIndex: 1,
Rows: []map[string]interface{}{},
Columns: []string{},
Messages: []string{"PRINT generated sql"},
},
},
multiMessages: []string{"batch top-level message"},
}
runtimeState := &agentRuntime{inst: fake, sessions: make(map[string]db.StatementExecer)}
resp := handleRequest(runtimeState, agentRequest{
ID: 12,
Method: agentMethodQueryMulti,
Query: "exec dbo.p_get_select",
TimeoutMs: int64((2 * time.Second).Milliseconds()),
})
if !resp.Success {
t.Fatalf("queryMulti request failed: %s", resp.Error)
}
if len(resp.Messages) != 1 || resp.Messages[0] != "batch top-level message" {
t.Fatalf("expected top-level messages to be preserved, got %#v", resp.Messages)
}
resultSets, ok := resp.Data.([]connection.ResultSetData)
if !ok {
t.Fatalf("expected []connection.ResultSetData, got %T", resp.Data)
}
if len(resultSets) != 2 {
t.Fatalf("expected 2 result sets, got %#v", resultSets)
}
if len(resultSets[1].Messages) != 1 || resultSets[1].Messages[0] != "PRINT generated sql" {
t.Fatalf("expected message-only result set to be preserved, got %#v", resultSets[1])
}
}
func TestHandleRequest_UsesPinnedSessionForSessionScopedQueryAndExec(t *testing.T) {
old := agentDriverType
defer func() { agentDriverType = old }()
@@ -329,6 +429,9 @@ func TestHandleRequest_UsesPinnedSessionForSessionScopedQueryAndExec(t *testing.
if !queryResp.Success {
t.Fatalf("session query failed: %s", queryResp.Error)
}
if len(queryResp.Messages) != 0 {
t.Fatalf("expected empty default session messages, got %#v", queryResp.Messages)
}
if fake.queryCalled || fake.queryContextCalled {
t.Fatalf("expected session query to bypass database-level query path, got Query=%v QueryContext=%v", fake.queryCalled, fake.queryContextCalled)
}