mirror of
https://github.com/Syngnat/GoNavi.git
synced 2026-07-02 06:31:21 +08:00
🐛 fix(sqlserver): 修复可选驱动查询消息透传缺失
- 为 optional-driver-agent 的 query 和 queryMulti 响应补充 messages 字段 - 在可选驱动 DB 客户端透传 SQL Server 查询提示信息与多结果集 - 补充 agent 与数据库层回归测试并更新 driver agent revision
This commit is contained in:
@@ -36,6 +36,7 @@ type agentResponse struct {
|
||||
Error string `json:"error,omitempty"`
|
||||
Data interface{} `json:"data,omitempty"`
|
||||
Fields []string `json:"fields,omitempty"`
|
||||
Messages []string `json:"messages,omitempty"`
|
||||
ChunkType string `json:"chunkType,omitempty"`
|
||||
RowsAffected int64 `json:"rowsAffected,omitempty"`
|
||||
}
|
||||
@@ -48,6 +49,7 @@ const (
|
||||
agentMethodOpenSession = "openSession"
|
||||
agentMethodCloseSession = "closeSession"
|
||||
agentMethodQuery = "query"
|
||||
agentMethodQueryMulti = "queryMulti"
|
||||
agentMethodStreamQuery = "streamQuery"
|
||||
agentMethodExec = "exec"
|
||||
agentMethodGetDatabases = "getDatabases"
|
||||
@@ -64,9 +66,9 @@ const (
|
||||
const legacyClickHouseDefaultTimeout = 2 * time.Hour
|
||||
|
||||
const (
|
||||
agentChunkColumns = "columns"
|
||||
agentChunkRows = "rows"
|
||||
agentChunkDone = "done"
|
||||
agentChunkColumns = "columns"
|
||||
agentChunkRows = "rows"
|
||||
agentChunkDone = "done"
|
||||
// agentStreamBatchSize 控制 driver-agent 向主进程发送 row chunk 的批次大小。
|
||||
// 调小到 64:单批 JSON 编码 + 主进程解码的瞬时内存峰值降为原来的 1/4,
|
||||
// 代价是 IPC 次数变为 4 倍,但每批仅一次 stdin/stdout 行读写,整体影响可忽略。
|
||||
@@ -236,12 +238,23 @@ func handleRequest(runtimeState *agentRuntime, req agentRequest) agentResponse {
|
||||
} else if ok {
|
||||
switch method {
|
||||
case agentMethodQuery:
|
||||
data, fields, err := queryStatementWithOptionalTimeout(session, req.Query, req.TimeoutMs)
|
||||
data, fields, messages, err := queryStatementWithMessagesOptionalTimeout(session, req.Query, req.TimeoutMs)
|
||||
if err != nil {
|
||||
return fail(resp, err.Error())
|
||||
}
|
||||
resp.Data = data
|
||||
resp.Fields = fields
|
||||
resp.Messages = messages
|
||||
case agentMethodQueryMulti:
|
||||
data, messages, supported, err := queryMultiStatementWithMessagesOptionalTimeout(session, req.Query, req.TimeoutMs)
|
||||
if err != nil {
|
||||
return fail(resp, err.Error())
|
||||
}
|
||||
if !supported {
|
||||
return fail(resp, "当前事务会话不支持多结果集查询")
|
||||
}
|
||||
resp.Data = data
|
||||
resp.Messages = messages
|
||||
case agentMethodExec:
|
||||
affected, err := execStatementWithOptionalTimeout(session, req.Query, req.TimeoutMs)
|
||||
if err != nil {
|
||||
@@ -260,12 +273,23 @@ func handleRequest(runtimeState *agentRuntime, req agentRequest) agentResponse {
|
||||
return fail(resp, err.Error())
|
||||
}
|
||||
case agentMethodQuery:
|
||||
data, fields, err := queryWithOptionalTimeout(runtimeState.inst, req.Query, req.TimeoutMs)
|
||||
data, fields, messages, err := queryWithMessagesOptionalTimeout(runtimeState.inst, req.Query, req.TimeoutMs)
|
||||
if err != nil {
|
||||
return fail(resp, err.Error())
|
||||
}
|
||||
resp.Data = data
|
||||
resp.Fields = fields
|
||||
resp.Messages = messages
|
||||
case agentMethodQueryMulti:
|
||||
data, messages, supported, err := queryMultiWithMessagesOptionalTimeout(runtimeState.inst, req.Query, req.TimeoutMs)
|
||||
if err != nil {
|
||||
return fail(resp, err.Error())
|
||||
}
|
||||
if !supported {
|
||||
return fail(resp, "当前驱动不支持原生多结果集查询")
|
||||
}
|
||||
resp.Data = data
|
||||
resp.Messages = messages
|
||||
case agentMethodExec:
|
||||
affected, err := execWithOptionalTimeout(runtimeState.inst, req.Query, req.TimeoutMs)
|
||||
if err != nil {
|
||||
@@ -581,6 +605,30 @@ type agentQueryContextRunner interface {
|
||||
QueryContext(context.Context, string) ([]map[string]interface{}, []string, error)
|
||||
}
|
||||
|
||||
type agentQueryMessageRunner interface {
|
||||
QueryWithMessages(query string) ([]map[string]interface{}, []string, []string, error)
|
||||
}
|
||||
|
||||
type agentQueryMessageContextRunner interface {
|
||||
QueryContextWithMessages(context.Context, string) ([]map[string]interface{}, []string, []string, error)
|
||||
}
|
||||
|
||||
type agentMultiResultMessageRunner interface {
|
||||
QueryMultiWithMessages(query string) ([]connection.ResultSetData, []string, error)
|
||||
}
|
||||
|
||||
type agentMultiResultMessageContextRunner interface {
|
||||
QueryMultiContextWithMessages(context.Context, string) ([]connection.ResultSetData, []string, error)
|
||||
}
|
||||
|
||||
type agentMultiResultRunner interface {
|
||||
QueryMulti(query string) ([]connection.ResultSetData, error)
|
||||
}
|
||||
|
||||
type agentMultiResultContextRunner interface {
|
||||
QueryMultiContext(context.Context, string) ([]connection.ResultSetData, error)
|
||||
}
|
||||
|
||||
type agentExecRunner interface {
|
||||
Exec(string) (int64, error)
|
||||
}
|
||||
@@ -589,20 +637,39 @@ type agentExecContextRunner interface {
|
||||
ExecContext(context.Context, string) (int64, error)
|
||||
}
|
||||
|
||||
func queryWithOptionalTimeout(inst agentQueryRunner, query string, timeoutMs int64) ([]map[string]interface{}, []string, error) {
|
||||
func queryWithMessagesOptionalTimeout(inst agentQueryRunner, query string, timeoutMs int64) ([]map[string]interface{}, []string, []string, error) {
|
||||
effectiveTimeoutMs := timeoutMs
|
||||
if effectiveTimeoutMs <= 0 && strings.EqualFold(strings.TrimSpace(agentDriverType), "clickhouse") {
|
||||
effectiveTimeoutMs = int64(legacyClickHouseDefaultTimeout / time.Millisecond)
|
||||
}
|
||||
if effectiveTimeoutMs <= 0 {
|
||||
return inst.Query(query)
|
||||
if q, ok := inst.(agentQueryMessageRunner); ok {
|
||||
return q.QueryWithMessages(query)
|
||||
}
|
||||
data, fields, err := inst.Query(query)
|
||||
return data, fields, nil, err
|
||||
}
|
||||
if q, ok := inst.(agentQueryMessageContextRunner); ok {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), time.Duration(effectiveTimeoutMs)*time.Millisecond)
|
||||
defer cancel()
|
||||
return q.QueryContextWithMessages(ctx, query)
|
||||
}
|
||||
if q, ok := inst.(agentQueryContextRunner); ok {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), time.Duration(effectiveTimeoutMs)*time.Millisecond)
|
||||
defer cancel()
|
||||
return q.QueryContext(ctx, query)
|
||||
data, fields, err := q.QueryContext(ctx, query)
|
||||
return data, fields, nil, err
|
||||
}
|
||||
return inst.Query(query)
|
||||
if q, ok := inst.(agentQueryMessageRunner); ok {
|
||||
return q.QueryWithMessages(query)
|
||||
}
|
||||
data, fields, err := inst.Query(query)
|
||||
return data, fields, nil, err
|
||||
}
|
||||
|
||||
func queryWithOptionalTimeout(inst agentQueryRunner, query string, timeoutMs int64) ([]map[string]interface{}, []string, error) {
|
||||
data, fields, _, err := queryWithMessagesOptionalTimeout(inst, query, timeoutMs)
|
||||
return data, fields, err
|
||||
}
|
||||
|
||||
func queryStatementWithOptionalTimeout(inst db.StatementExecer, query string, timeoutMs int64) ([]map[string]interface{}, []string, error) {
|
||||
@@ -613,6 +680,74 @@ func queryStatementWithOptionalTimeout(inst db.StatementExecer, query string, ti
|
||||
return queryWithOptionalTimeout(queryRunner, query, timeoutMs)
|
||||
}
|
||||
|
||||
func queryStatementWithMessagesOptionalTimeout(inst db.StatementExecer, query string, timeoutMs int64) ([]map[string]interface{}, []string, []string, error) {
|
||||
queryRunner, ok := inst.(agentQueryRunner)
|
||||
if !ok {
|
||||
return nil, nil, nil, fmt.Errorf("当前事务会话不支持查询语句")
|
||||
}
|
||||
return queryWithMessagesOptionalTimeout(queryRunner, query, timeoutMs)
|
||||
}
|
||||
|
||||
func queryMultiWithMessagesOptionalTimeout(inst db.Database, query string, timeoutMs int64) ([]connection.ResultSetData, []string, bool, error) {
|
||||
effectiveTimeoutMs := timeoutMs
|
||||
if effectiveTimeoutMs <= 0 && strings.EqualFold(strings.TrimSpace(agentDriverType), "clickhouse") {
|
||||
effectiveTimeoutMs = int64(legacyClickHouseDefaultTimeout / time.Millisecond)
|
||||
}
|
||||
if effectiveTimeoutMs > 0 {
|
||||
if q, ok := inst.(agentMultiResultMessageContextRunner); ok {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), time.Duration(effectiveTimeoutMs)*time.Millisecond)
|
||||
defer cancel()
|
||||
data, messages, err := q.QueryMultiContextWithMessages(ctx, query)
|
||||
return data, messages, true, err
|
||||
}
|
||||
if q, ok := inst.(agentMultiResultContextRunner); ok {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), time.Duration(effectiveTimeoutMs)*time.Millisecond)
|
||||
defer cancel()
|
||||
data, err := q.QueryMultiContext(ctx, query)
|
||||
return data, nil, true, err
|
||||
}
|
||||
}
|
||||
if q, ok := inst.(agentMultiResultMessageRunner); ok {
|
||||
data, messages, err := q.QueryMultiWithMessages(query)
|
||||
return data, messages, true, err
|
||||
}
|
||||
if q, ok := inst.(agentMultiResultRunner); ok {
|
||||
data, err := q.QueryMulti(query)
|
||||
return data, nil, true, err
|
||||
}
|
||||
return nil, nil, false, nil
|
||||
}
|
||||
|
||||
func queryMultiStatementWithMessagesOptionalTimeout(inst db.StatementExecer, query string, timeoutMs int64) ([]connection.ResultSetData, []string, bool, error) {
|
||||
effectiveTimeoutMs := timeoutMs
|
||||
if effectiveTimeoutMs <= 0 && strings.EqualFold(strings.TrimSpace(agentDriverType), "clickhouse") {
|
||||
effectiveTimeoutMs = int64(legacyClickHouseDefaultTimeout / time.Millisecond)
|
||||
}
|
||||
if effectiveTimeoutMs > 0 {
|
||||
if q, ok := inst.(agentMultiResultMessageContextRunner); ok {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), time.Duration(effectiveTimeoutMs)*time.Millisecond)
|
||||
defer cancel()
|
||||
data, messages, err := q.QueryMultiContextWithMessages(ctx, query)
|
||||
return data, messages, true, err
|
||||
}
|
||||
if q, ok := inst.(agentMultiResultContextRunner); ok {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), time.Duration(effectiveTimeoutMs)*time.Millisecond)
|
||||
defer cancel()
|
||||
data, err := q.QueryMultiContext(ctx, query)
|
||||
return data, nil, true, err
|
||||
}
|
||||
}
|
||||
if q, ok := inst.(agentMultiResultMessageRunner); ok {
|
||||
data, messages, err := q.QueryMultiWithMessages(query)
|
||||
return data, messages, true, err
|
||||
}
|
||||
if q, ok := inst.(agentMultiResultRunner); ok {
|
||||
data, err := q.QueryMulti(query)
|
||||
return data, nil, true, err
|
||||
}
|
||||
return nil, nil, false, nil
|
||||
}
|
||||
|
||||
func streamWithOptionalTimeout(inst db.StreamQueryExecer, query string, timeoutMs int64, consumer db.QueryStreamConsumer) error {
|
||||
effectiveTimeoutMs := timeoutMs
|
||||
if effectiveTimeoutMs <= 0 && strings.EqualFold(strings.TrimSpace(agentDriverType), "clickhouse") {
|
||||
|
||||
@@ -101,6 +101,9 @@ type fakeAgentTimeoutDB struct {
|
||||
execCalled bool
|
||||
execContextCalled bool
|
||||
deadlineSet bool
|
||||
queryMessages []string
|
||||
multiResults []connection.ResultSetData
|
||||
multiMessages []string
|
||||
}
|
||||
|
||||
func (f *fakeAgentTimeoutDB) Connect(config connection.ConnectionConfig) error { return nil }
|
||||
@@ -117,6 +120,14 @@ func (f *fakeAgentTimeoutDB) QueryContext(ctx context.Context, query string) ([]
|
||||
}
|
||||
return []map[string]interface{}{{"ok": 1}}, []string{"ok"}, nil
|
||||
}
|
||||
func (f *fakeAgentTimeoutDB) QueryWithMessages(query string) ([]map[string]interface{}, []string, []string, error) {
|
||||
data, fields, err := f.QueryContext(context.Background(), query)
|
||||
return data, fields, append([]string(nil), f.queryMessages...), err
|
||||
}
|
||||
func (f *fakeAgentTimeoutDB) QueryContextWithMessages(ctx context.Context, query string) ([]map[string]interface{}, []string, []string, error) {
|
||||
data, fields, err := f.QueryContext(ctx, query)
|
||||
return data, fields, append([]string(nil), f.queryMessages...), err
|
||||
}
|
||||
func (f *fakeAgentTimeoutDB) Exec(query string) (int64, error) {
|
||||
f.execCalled = true
|
||||
return 0, errors.New("exec should not be called")
|
||||
@@ -150,6 +161,15 @@ func (f *fakeAgentTimeoutDB) GetForeignKeys(dbName, tableName string) ([]connect
|
||||
func (f *fakeAgentTimeoutDB) GetTriggers(dbName, tableName string) ([]connection.TriggerDefinition, error) {
|
||||
return nil, nil
|
||||
}
|
||||
func (f *fakeAgentTimeoutDB) QueryMultiWithMessages(query string) ([]connection.ResultSetData, []string, error) {
|
||||
return append([]connection.ResultSetData(nil), f.multiResults...), append([]string(nil), f.multiMessages...), nil
|
||||
}
|
||||
func (f *fakeAgentTimeoutDB) QueryMultiContextWithMessages(ctx context.Context, query string) ([]connection.ResultSetData, []string, error) {
|
||||
if _, ok := ctx.Deadline(); ok {
|
||||
f.deadlineSet = true
|
||||
}
|
||||
return f.QueryMultiWithMessages(query)
|
||||
}
|
||||
|
||||
type fakeAgentSessionDB struct {
|
||||
fakeAgentTimeoutDB
|
||||
@@ -165,6 +185,7 @@ type fakeAgentStatementSession struct {
|
||||
queryCalls int
|
||||
execCalls int
|
||||
closed bool
|
||||
messages []string
|
||||
}
|
||||
|
||||
func (f *fakeAgentStatementSession) Query(query string) ([]map[string]interface{}, []string, error) {
|
||||
@@ -175,6 +196,14 @@ func (f *fakeAgentStatementSession) QueryContext(ctx context.Context, query stri
|
||||
f.queryCalls++
|
||||
return []map[string]interface{}{{"session_ok": 1}}, []string{"session_ok"}, nil
|
||||
}
|
||||
func (f *fakeAgentStatementSession) QueryWithMessages(query string) ([]map[string]interface{}, []string, []string, error) {
|
||||
data, fields, err := f.QueryContext(context.Background(), query)
|
||||
return data, fields, append([]string(nil), f.messages...), err
|
||||
}
|
||||
func (f *fakeAgentStatementSession) QueryContextWithMessages(ctx context.Context, query string) ([]map[string]interface{}, []string, []string, error) {
|
||||
data, fields, err := f.QueryContext(ctx, query)
|
||||
return data, fields, append([]string(nil), f.messages...), err
|
||||
}
|
||||
|
||||
func (f *fakeAgentStatementSession) Exec(query string) (int64, error) {
|
||||
return f.ExecContext(context.Background(), query)
|
||||
@@ -297,6 +326,77 @@ func TestQueryWithOptionalTimeout_ClickHouseLegacyModeUsesQueryContext(t *testin
|
||||
}
|
||||
}
|
||||
|
||||
func TestHandleRequest_QueryIncludesServerMessages(t *testing.T) {
|
||||
old := agentDriverType
|
||||
defer func() { agentDriverType = old }()
|
||||
agentDriverType = "sqlserver"
|
||||
|
||||
fake := &fakeAgentTimeoutDB{
|
||||
queryMessages: []string{"PRINT sql line 1", "PRINT sql line 2"},
|
||||
}
|
||||
runtimeState := &agentRuntime{inst: fake, sessions: make(map[string]db.StatementExecer)}
|
||||
|
||||
resp := handleRequest(runtimeState, agentRequest{
|
||||
ID: 11,
|
||||
Method: agentMethodQuery,
|
||||
Query: "exec dbo.p_get_select",
|
||||
TimeoutMs: int64((2 * time.Second).Milliseconds()),
|
||||
})
|
||||
if !resp.Success {
|
||||
t.Fatalf("query request failed: %s", resp.Error)
|
||||
}
|
||||
if len(resp.Messages) != 2 || resp.Messages[0] != "PRINT sql line 1" {
|
||||
t.Fatalf("expected query messages to be preserved, got %#v", resp.Messages)
|
||||
}
|
||||
}
|
||||
|
||||
func TestHandleRequest_QueryMultiIncludesResultSetsAndMessages(t *testing.T) {
|
||||
old := agentDriverType
|
||||
defer func() { agentDriverType = old }()
|
||||
agentDriverType = "sqlserver"
|
||||
|
||||
fake := &fakeAgentTimeoutDB{
|
||||
multiResults: []connection.ResultSetData{
|
||||
{
|
||||
StatementIndex: 1,
|
||||
Rows: []map[string]interface{}{{"name": "master"}},
|
||||
Columns: []string{"name"},
|
||||
},
|
||||
{
|
||||
StatementIndex: 1,
|
||||
Rows: []map[string]interface{}{},
|
||||
Columns: []string{},
|
||||
Messages: []string{"PRINT generated sql"},
|
||||
},
|
||||
},
|
||||
multiMessages: []string{"batch top-level message"},
|
||||
}
|
||||
runtimeState := &agentRuntime{inst: fake, sessions: make(map[string]db.StatementExecer)}
|
||||
|
||||
resp := handleRequest(runtimeState, agentRequest{
|
||||
ID: 12,
|
||||
Method: agentMethodQueryMulti,
|
||||
Query: "exec dbo.p_get_select",
|
||||
TimeoutMs: int64((2 * time.Second).Milliseconds()),
|
||||
})
|
||||
if !resp.Success {
|
||||
t.Fatalf("queryMulti request failed: %s", resp.Error)
|
||||
}
|
||||
if len(resp.Messages) != 1 || resp.Messages[0] != "batch top-level message" {
|
||||
t.Fatalf("expected top-level messages to be preserved, got %#v", resp.Messages)
|
||||
}
|
||||
resultSets, ok := resp.Data.([]connection.ResultSetData)
|
||||
if !ok {
|
||||
t.Fatalf("expected []connection.ResultSetData, got %T", resp.Data)
|
||||
}
|
||||
if len(resultSets) != 2 {
|
||||
t.Fatalf("expected 2 result sets, got %#v", resultSets)
|
||||
}
|
||||
if len(resultSets[1].Messages) != 1 || resultSets[1].Messages[0] != "PRINT generated sql" {
|
||||
t.Fatalf("expected message-only result set to be preserved, got %#v", resultSets[1])
|
||||
}
|
||||
}
|
||||
|
||||
func TestHandleRequest_UsesPinnedSessionForSessionScopedQueryAndExec(t *testing.T) {
|
||||
old := agentDriverType
|
||||
defer func() { agentDriverType = old }()
|
||||
@@ -329,6 +429,9 @@ func TestHandleRequest_UsesPinnedSessionForSessionScopedQueryAndExec(t *testing.
|
||||
if !queryResp.Success {
|
||||
t.Fatalf("session query failed: %s", queryResp.Error)
|
||||
}
|
||||
if len(queryResp.Messages) != 0 {
|
||||
t.Fatalf("expected empty default session messages, got %#v", queryResp.Messages)
|
||||
}
|
||||
if fake.queryCalled || fake.queryContextCalled {
|
||||
t.Fatalf("expected session query to bypass database-level query path, got Query=%v QueryContext=%v", fake.queryCalled, fake.queryContextCalled)
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user