Files
Syngnat 495a985ae1 🐛 fix(sqlserver): 修复可选驱动查询消息透传缺失
- 为 optional-driver-agent 的 query 和 queryMulti 响应补充 messages 字段
- 在可选驱动 DB 客户端透传 SQL Server 查询提示信息与多结果集
- 补充 agent 与数据库层回归测试并更新 driver agent revision
2026-06-23 08:48:42 +08:00

602 lines
19 KiB
Go
Raw Permalink Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
package main
import (
"bufio"
"bytes"
"context"
"encoding/json"
"errors"
"strings"
"testing"
"time"
"GoNavi-Wails/internal/connection"
"GoNavi-Wails/internal/db"
)
type duckMapLike map[any]any
func TestWriteResponse_NormalizesMapAnyAny(t *testing.T) {
resp := agentResponse{
ID: 1,
Success: true,
Data: []map[string]interface{}{
{
"id": int64(7),
"meta": duckMapLike{"k": "v", 2: "two"},
},
},
}
var out bytes.Buffer
writer := bufio.NewWriter(&out)
if err := writeResponse(writer, resp); err != nil {
t.Fatalf("writeResponse 返回错误: %v", err)
}
var decoded struct {
Data []map[string]interface{} `json:"data"`
}
if err := json.Unmarshal(bytes.TrimSpace(out.Bytes()), &decoded); err != nil {
t.Fatalf("解码响应失败: %v", err)
}
if len(decoded.Data) != 1 {
t.Fatalf("期望 1 行数据,实际 %d", len(decoded.Data))
}
meta, ok := decoded.Data[0]["meta"].(map[string]interface{})
if !ok {
t.Fatalf("meta 字段类型异常: %T", decoded.Data[0]["meta"])
}
if meta["k"] != "v" {
t.Fatalf("字符串 key 转换异常: %v", meta["k"])
}
if meta["2"] != "two" {
t.Fatalf("数字 key 未字符串化: %v", meta["2"])
}
}
func TestNormalizeAgentResponseData_KeepByteSlice(t *testing.T) {
raw := []byte{0x61, 0x62, 0x63}
normalized := normalizeAgentResponseData(raw)
out, ok := normalized.([]byte)
if !ok {
t.Fatalf("期望 []byte实际 %T", normalized)
}
if !bytes.Equal(out, raw) {
t.Fatalf("[]byte 内容被意外改写: %v", out)
}
}
func TestHandleRequestMetadataReportsAgentRevision(t *testing.T) {
previousDriverType := agentDriverType
previousFactory := agentDatabaseFactory
t.Cleanup(func() {
agentDriverType = previousDriverType
agentDatabaseFactory = previousFactory
})
agentDriverType = "clickhouse"
agentDatabaseFactory = func() db.Database { return nil }
runtimeState := &agentRuntime{sessions: make(map[string]db.StatementExecer)}
resp := handleRequest(runtimeState, agentRequest{ID: 7, Method: agentMethodMetadata})
if !resp.Success {
t.Fatalf("metadata request failed: %s", resp.Error)
}
data, ok := resp.Data.(map[string]string)
if !ok {
t.Fatalf("metadata response data type = %T", resp.Data)
}
if data["driverType"] != "clickhouse" {
t.Fatalf("unexpected driver type: %q", data["driverType"])
}
if data["agentRevision"] != db.OptionalDriverAgentRevision("clickhouse") {
t.Fatalf("unexpected agent revision: %q", data["agentRevision"])
}
}
type fakeAgentTimeoutDB struct {
queryCalled bool
queryContextCalled bool
execCalled bool
execContextCalled bool
deadlineSet bool
queryMessages []string
multiResults []connection.ResultSetData
multiMessages []string
}
func (f *fakeAgentTimeoutDB) Connect(config connection.ConnectionConfig) error { return nil }
func (f *fakeAgentTimeoutDB) Close() error { return nil }
func (f *fakeAgentTimeoutDB) Ping() error { return nil }
func (f *fakeAgentTimeoutDB) Query(query string) ([]map[string]interface{}, []string, error) {
f.queryCalled = true
return nil, nil, errors.New("query should not be called")
}
func (f *fakeAgentTimeoutDB) QueryContext(ctx context.Context, query string) ([]map[string]interface{}, []string, error) {
f.queryContextCalled = true
if _, ok := ctx.Deadline(); ok {
f.deadlineSet = true
}
return []map[string]interface{}{{"ok": 1}}, []string{"ok"}, nil
}
func (f *fakeAgentTimeoutDB) QueryWithMessages(query string) ([]map[string]interface{}, []string, []string, error) {
data, fields, err := f.QueryContext(context.Background(), query)
return data, fields, append([]string(nil), f.queryMessages...), err
}
func (f *fakeAgentTimeoutDB) QueryContextWithMessages(ctx context.Context, query string) ([]map[string]interface{}, []string, []string, error) {
data, fields, err := f.QueryContext(ctx, query)
return data, fields, append([]string(nil), f.queryMessages...), err
}
func (f *fakeAgentTimeoutDB) Exec(query string) (int64, error) {
f.execCalled = true
return 0, errors.New("exec should not be called")
}
func (f *fakeAgentTimeoutDB) ExecContext(ctx context.Context, query string) (int64, error) {
f.execContextCalled = true
if _, ok := ctx.Deadline(); ok {
f.deadlineSet = true
}
return 3, nil
}
func (f *fakeAgentTimeoutDB) GetDatabases() ([]string, error) { return nil, nil }
func (f *fakeAgentTimeoutDB) GetTables(dbName string) ([]string, error) {
return nil, nil
}
func (f *fakeAgentTimeoutDB) GetCreateStatement(dbName, tableName string) (string, error) {
return "", nil
}
func (f *fakeAgentTimeoutDB) GetColumns(dbName, tableName string) ([]connection.ColumnDefinition, error) {
return nil, nil
}
func (f *fakeAgentTimeoutDB) GetAllColumns(dbName string) ([]connection.ColumnDefinitionWithTable, error) {
return nil, nil
}
func (f *fakeAgentTimeoutDB) GetIndexes(dbName, tableName string) ([]connection.IndexDefinition, error) {
return nil, nil
}
func (f *fakeAgentTimeoutDB) GetForeignKeys(dbName, tableName string) ([]connection.ForeignKeyDefinition, error) {
return nil, nil
}
func (f *fakeAgentTimeoutDB) GetTriggers(dbName, tableName string) ([]connection.TriggerDefinition, error) {
return nil, nil
}
func (f *fakeAgentTimeoutDB) QueryMultiWithMessages(query string) ([]connection.ResultSetData, []string, error) {
return append([]connection.ResultSetData(nil), f.multiResults...), append([]string(nil), f.multiMessages...), nil
}
func (f *fakeAgentTimeoutDB) QueryMultiContextWithMessages(ctx context.Context, query string) ([]connection.ResultSetData, []string, error) {
if _, ok := ctx.Deadline(); ok {
f.deadlineSet = true
}
return f.QueryMultiWithMessages(query)
}
type fakeAgentSessionDB struct {
fakeAgentTimeoutDB
session *fakeAgentStatementSession
}
func (f *fakeAgentSessionDB) OpenSessionExecer(ctx context.Context) (db.StatementExecer, error) {
f.session = &fakeAgentStatementSession{}
return f.session, nil
}
type fakeAgentStatementSession struct {
queryCalls int
execCalls int
closed bool
messages []string
}
func (f *fakeAgentStatementSession) Query(query string) ([]map[string]interface{}, []string, error) {
return f.QueryContext(context.Background(), query)
}
func (f *fakeAgentStatementSession) QueryContext(ctx context.Context, query string) ([]map[string]interface{}, []string, error) {
f.queryCalls++
return []map[string]interface{}{{"session_ok": 1}}, []string{"session_ok"}, nil
}
func (f *fakeAgentStatementSession) QueryWithMessages(query string) ([]map[string]interface{}, []string, []string, error) {
data, fields, err := f.QueryContext(context.Background(), query)
return data, fields, append([]string(nil), f.messages...), err
}
func (f *fakeAgentStatementSession) QueryContextWithMessages(ctx context.Context, query string) ([]map[string]interface{}, []string, []string, error) {
data, fields, err := f.QueryContext(ctx, query)
return data, fields, append([]string(nil), f.messages...), err
}
func (f *fakeAgentStatementSession) Exec(query string) (int64, error) {
return f.ExecContext(context.Background(), query)
}
func (f *fakeAgentStatementSession) ExecContext(ctx context.Context, query string) (int64, error) {
f.execCalls++
return 9, nil
}
func (f *fakeAgentStatementSession) Close() error {
f.closed = true
return nil
}
type fakeAgentStreamSession struct {
closed bool
streamCalls int
deadlineSet bool
}
func (f *fakeAgentStreamSession) Exec(query string) (int64, error) {
return 0, nil
}
func (f *fakeAgentStreamSession) ExecContext(ctx context.Context, query string) (int64, error) {
return 0, nil
}
func (f *fakeAgentStreamSession) Close() error {
f.closed = true
return nil
}
func (f *fakeAgentStreamSession) StreamQuery(query string, consumer db.QueryStreamConsumer) error {
return f.StreamQueryContext(context.Background(), query, consumer)
}
func (f *fakeAgentStreamSession) StreamQueryContext(ctx context.Context, query string, consumer db.QueryStreamConsumer) error {
f.streamCalls++
if _, ok := ctx.Deadline(); ok {
f.deadlineSet = true
}
if err := consumer.SetColumns([]string{"id", "name"}); err != nil {
return err
}
if valueConsumer, ok := consumer.(db.QueryStreamValueConsumer); ok {
if err := valueConsumer.ConsumeRowValues([]interface{}{1, "alice"}); err != nil {
return err
}
if err := valueConsumer.ConsumeRowValues([]interface{}{2, "bob"}); err != nil {
return err
}
return nil
}
if err := consumer.ConsumeRow(map[string]interface{}{"id": 1, "name": "alice"}); err != nil {
return err
}
return consumer.ConsumeRow(map[string]interface{}{"id": 2, "name": "bob"})
}
type fakeAgentSessionStreamDB struct {
fakeAgentTimeoutDB
session *fakeAgentStreamSession
openCalls int
}
func (f *fakeAgentSessionStreamDB) OpenSessionExecer(ctx context.Context) (db.StatementExecer, error) {
f.openCalls++
f.session = &fakeAgentStreamSession{}
return f.session, nil
}
func TestQueryWithOptionalTimeout_UsesQueryContext(t *testing.T) {
fake := &fakeAgentTimeoutDB{}
data, fields, err := queryWithOptionalTimeout(fake, "SELECT 1", int64((2 * time.Second).Milliseconds()))
if err != nil {
t.Fatalf("queryWithOptionalTimeout 返回错误: %v", err)
}
if !fake.queryContextCalled || fake.queryCalled {
t.Fatalf("query 调用路径异常QueryContext=%v Query=%v", fake.queryContextCalled, fake.queryCalled)
}
if !fake.deadlineSet {
t.Fatal("queryWithOptionalTimeout 未设置 deadline")
}
if len(data) != 1 || len(fields) != 1 || fields[0] != "ok" {
t.Fatalf("queryWithOptionalTimeout 返回数据异常: data=%v fields=%v", data, fields)
}
}
func TestExecWithOptionalTimeout_UsesExecContext(t *testing.T) {
fake := &fakeAgentTimeoutDB{}
affected, err := execWithOptionalTimeout(fake, "DELETE FROM t", int64((2 * time.Second).Milliseconds()))
if err != nil {
t.Fatalf("execWithOptionalTimeout 返回错误: %v", err)
}
if !fake.execContextCalled || fake.execCalled {
t.Fatalf("exec 调用路径异常ExecContext=%v Exec=%v", fake.execContextCalled, fake.execCalled)
}
if !fake.deadlineSet {
t.Fatal("execWithOptionalTimeout 未设置 deadline")
}
if affected != 3 {
t.Fatalf("受影响行数异常want=3 got=%d", affected)
}
}
func TestQueryWithOptionalTimeout_ClickHouseLegacyModeUsesQueryContext(t *testing.T) {
old := agentDriverType
agentDriverType = "clickhouse"
defer func() { agentDriverType = old }()
fake := &fakeAgentTimeoutDB{}
_, _, err := queryWithOptionalTimeout(fake, "SELECT 1", 0)
if err != nil {
t.Fatalf("queryWithOptionalTimeout 返回错误: %v", err)
}
if !fake.queryContextCalled || fake.queryCalled {
t.Fatalf("clickhouse legacy query 调用路径异常QueryContext=%v Query=%v", fake.queryContextCalled, fake.queryCalled)
}
}
func TestHandleRequest_QueryIncludesServerMessages(t *testing.T) {
old := agentDriverType
defer func() { agentDriverType = old }()
agentDriverType = "sqlserver"
fake := &fakeAgentTimeoutDB{
queryMessages: []string{"PRINT sql line 1", "PRINT sql line 2"},
}
runtimeState := &agentRuntime{inst: fake, sessions: make(map[string]db.StatementExecer)}
resp := handleRequest(runtimeState, agentRequest{
ID: 11,
Method: agentMethodQuery,
Query: "exec dbo.p_get_select",
TimeoutMs: int64((2 * time.Second).Milliseconds()),
})
if !resp.Success {
t.Fatalf("query request failed: %s", resp.Error)
}
if len(resp.Messages) != 2 || resp.Messages[0] != "PRINT sql line 1" {
t.Fatalf("expected query messages to be preserved, got %#v", resp.Messages)
}
}
func TestHandleRequest_QueryMultiIncludesResultSetsAndMessages(t *testing.T) {
old := agentDriverType
defer func() { agentDriverType = old }()
agentDriverType = "sqlserver"
fake := &fakeAgentTimeoutDB{
multiResults: []connection.ResultSetData{
{
StatementIndex: 1,
Rows: []map[string]interface{}{{"name": "master"}},
Columns: []string{"name"},
},
{
StatementIndex: 1,
Rows: []map[string]interface{}{},
Columns: []string{},
Messages: []string{"PRINT generated sql"},
},
},
multiMessages: []string{"batch top-level message"},
}
runtimeState := &agentRuntime{inst: fake, sessions: make(map[string]db.StatementExecer)}
resp := handleRequest(runtimeState, agentRequest{
ID: 12,
Method: agentMethodQueryMulti,
Query: "exec dbo.p_get_select",
TimeoutMs: int64((2 * time.Second).Milliseconds()),
})
if !resp.Success {
t.Fatalf("queryMulti request failed: %s", resp.Error)
}
if len(resp.Messages) != 1 || resp.Messages[0] != "batch top-level message" {
t.Fatalf("expected top-level messages to be preserved, got %#v", resp.Messages)
}
resultSets, ok := resp.Data.([]connection.ResultSetData)
if !ok {
t.Fatalf("expected []connection.ResultSetData, got %T", resp.Data)
}
if len(resultSets) != 2 {
t.Fatalf("expected 2 result sets, got %#v", resultSets)
}
if len(resultSets[1].Messages) != 1 || resultSets[1].Messages[0] != "PRINT generated sql" {
t.Fatalf("expected message-only result set to be preserved, got %#v", resultSets[1])
}
}
func TestHandleRequest_UsesPinnedSessionForSessionScopedQueryAndExec(t *testing.T) {
old := agentDriverType
defer func() { agentDriverType = old }()
agentDriverType = "sqlserver"
fake := &fakeAgentSessionDB{}
runtimeState := &agentRuntime{
inst: fake,
sessions: make(map[string]db.StatementExecer),
}
openResp := handleRequest(runtimeState, agentRequest{ID: 1, Method: agentMethodOpenSession})
if !openResp.Success {
t.Fatalf("openSession failed: %s", openResp.Error)
}
sessionID, ok := openResp.Data.(string)
if !ok || strings.TrimSpace(sessionID) == "" {
t.Fatalf("unexpected session id payload: %#v", openResp.Data)
}
if fake.session == nil {
t.Fatal("expected OpenSessionExecer to create a pinned session")
}
queryResp := handleRequest(runtimeState, agentRequest{
ID: 2,
Method: agentMethodQuery,
SessionID: sessionID,
Query: "SELECT 1",
})
if !queryResp.Success {
t.Fatalf("session query failed: %s", queryResp.Error)
}
if len(queryResp.Messages) != 0 {
t.Fatalf("expected empty default session messages, got %#v", queryResp.Messages)
}
if fake.queryCalled || fake.queryContextCalled {
t.Fatalf("expected session query to bypass database-level query path, got Query=%v QueryContext=%v", fake.queryCalled, fake.queryContextCalled)
}
if fake.session.queryCalls != 1 {
t.Fatalf("expected pinned session queryCalls=1, got %d", fake.session.queryCalls)
}
execResp := handleRequest(runtimeState, agentRequest{
ID: 3,
Method: agentMethodExec,
SessionID: sessionID,
Query: "UPDATE t SET v = 1",
})
if !execResp.Success {
t.Fatalf("session exec failed: %s", execResp.Error)
}
if fake.execCalled || fake.execContextCalled {
t.Fatalf("expected session exec to bypass database-level exec path, got Exec=%v ExecContext=%v", fake.execCalled, fake.execContextCalled)
}
if fake.session.execCalls != 1 {
t.Fatalf("expected pinned session execCalls=1, got %d", fake.session.execCalls)
}
closeResp := handleRequest(runtimeState, agentRequest{
ID: 4,
Method: agentMethodCloseSession,
SessionID: sessionID,
})
if !closeResp.Success {
t.Fatalf("closeSession failed: %s", closeResp.Error)
}
if !fake.session.closed {
t.Fatal("expected pinned session to close")
}
}
func TestHandleStreamRequest_UsesSessionStreamerAndWritesChunks(t *testing.T) {
old := agentDriverType
originalAsync := runAgentMemoryTrimAsync
originalTrim := agentMemoryTrimFn
originalLastAt := agentMemoryTrimLastAt.Load()
defer func() { agentDriverType = old }()
defer func() {
runAgentMemoryTrimAsync = originalAsync
agentMemoryTrimFn = originalTrim
agentMemoryTrimRunning.Store(false)
agentMemoryTrimLastAt.Store(originalLastAt)
}()
agentDriverType = "oceanbase"
agentMemoryTrimRunning.Store(false)
agentMemoryTrimLastAt.Store(0)
fake := &fakeAgentSessionStreamDB{}
runtimeState := &agentRuntime{
inst: fake,
sessions: make(map[string]db.StatementExecer),
}
trimmed := 0
runAgentMemoryTrimAsync = func(fn func()) {
fn()
}
agentMemoryTrimFn = func() {
trimmed++
}
var out bytes.Buffer
writer := bufio.NewWriter(&out)
if err := handleStreamRequest(runtimeState, agentRequest{
ID: 9,
Method: agentMethodStreamQuery,
Query: "SELECT * FROM person_info",
TimeoutMs: int64((2 * time.Second).Milliseconds()),
}, writer); err != nil {
t.Fatalf("handleStreamRequest 返回错误: %v", err)
}
if fake.openCalls != 1 {
t.Fatalf("expected OpenSessionExecer called once, got %d", fake.openCalls)
}
if fake.session == nil || fake.session.streamCalls != 1 {
t.Fatalf("expected session streamer used once, session=%#v", fake.session)
}
if !fake.session.deadlineSet {
t.Fatal("expected stream query context deadline to be set")
}
if !fake.session.closed {
t.Fatal("expected session to close after streaming")
}
if fake.queryCalled || fake.queryContextCalled {
t.Fatalf("unexpected fallback query path, Query=%v QueryContext=%v", fake.queryCalled, fake.queryContextCalled)
}
lines := strings.Split(strings.TrimSpace(out.String()), "\n")
if len(lines) != 3 {
t.Fatalf("expected 3 stream responses, got %d: %q", len(lines), out.String())
}
var columnsResp struct {
Success bool `json:"success"`
ChunkType string `json:"chunkType"`
Fields []string `json:"fields"`
}
if err := json.Unmarshal([]byte(lines[0]), &columnsResp); err != nil {
t.Fatalf("decode columns response failed: %v", err)
}
if !columnsResp.Success || columnsResp.ChunkType != agentChunkColumns || len(columnsResp.Fields) != 2 {
t.Fatalf("unexpected columns response: %#v", columnsResp)
}
var rowsResp struct {
Success bool `json:"success"`
ChunkType string `json:"chunkType"`
Data [][]interface{} `json:"data"`
}
if err := json.Unmarshal([]byte(lines[1]), &rowsResp); err != nil {
t.Fatalf("decode rows response failed: %v", err)
}
if !rowsResp.Success || rowsResp.ChunkType != agentChunkRows || len(rowsResp.Data) != 2 {
t.Fatalf("unexpected rows response: %#v", rowsResp)
}
if got := rowsResp.Data[1][1]; got != "bob" {
t.Fatalf("unexpected streamed row payload: %v", rowsResp.Data)
}
var doneResp struct {
Success bool `json:"success"`
ChunkType string `json:"chunkType"`
}
if err := json.Unmarshal([]byte(lines[2]), &doneResp); err != nil {
t.Fatalf("decode done response failed: %v", err)
}
if !doneResp.Success || doneResp.ChunkType != agentChunkDone {
t.Fatalf("unexpected done response: %#v", doneResp)
}
if trimmed != 0 {
t.Fatalf("小流式任务不应触发内存回收got=%d", trimmed)
}
}
func TestMaybeReleaseAgentMemory_TriggersTrimForLargeJobs(t *testing.T) {
originalAsync := runAgentMemoryTrimAsync
originalTrim := agentMemoryTrimFn
originalLastAt := agentMemoryTrimLastAt.Load()
t.Cleanup(func() {
runAgentMemoryTrimAsync = originalAsync
agentMemoryTrimFn = originalTrim
agentMemoryTrimRunning.Store(false)
agentMemoryTrimLastAt.Store(originalLastAt)
})
agentMemoryTrimRunning.Store(false)
agentMemoryTrimLastAt.Store(0)
triggered := 0
runAgentMemoryTrimAsync = func(fn func()) {
fn()
}
agentMemoryTrimFn = func() {
triggered++
}
maybeReleaseAgentMemory("test-large-query", agentMemoryTrimRowsThreshold)
if triggered != 1 {
t.Fatalf("大查询完成后应触发一次内存回收got=%d", triggered)
}
}