🐛 fix(sql-editor): 修复事务执行会话与工具栏布局交互

This commit is contained in:
Syngnat
2026-06-14 12:40:31 +08:00
parent 7a85c30752
commit 8d5a24992a
7 changed files with 500 additions and 45 deletions

View File

@@ -7,6 +7,7 @@ import (
"fmt"
"os"
"reflect"
"strconv"
"strings"
"time"
@@ -17,6 +18,7 @@ import (
type agentRequest struct {
ID int64 `json:"id"`
Method string `json:"method"`
SessionID string `json:"sessionId,omitempty"`
Config *connection.ConnectionConfig `json:"config,omitempty"`
Query string `json:"query,omitempty"`
TimeoutMs int64 `json:"timeoutMs,omitempty"`
@@ -39,6 +41,8 @@ const (
agentMethodClose = "close"
agentMethodMetadata = "metadata"
agentMethodPing = "ping"
agentMethodOpenSession = "openSession"
agentMethodCloseSession = "closeSession"
agentMethodQuery = "query"
agentMethodExec = "exec"
agentMethodGetDatabases = "getDatabases"
@@ -59,6 +63,12 @@ var (
agentDatabaseFactory func() db.Database
)
type agentRuntime struct {
inst db.Database
sessions map[string]db.StatementExecer
nextSessionID int64
}
func main() {
if agentDatabaseFactory == nil || strings.TrimSpace(agentDriverType) == "" {
fmt.Fprintf(os.Stderr, "未配置驱动代理 provider请使用 gonavi_<driver>_driver 标签构建\n")
@@ -70,7 +80,9 @@ func main() {
writer := bufio.NewWriter(os.Stdout)
defer writer.Flush()
var inst db.Database
runtimeState := &agentRuntime{
sessions: make(map[string]db.StatementExecer),
}
for scanner.Scan() {
line := strings.TrimSpace(scanner.Text())
if line == "" {
@@ -87,23 +99,21 @@ func main() {
continue
}
resp := handleRequest(&inst, req)
resp := handleRequest(runtimeState, req)
if err := writeResponse(writer, resp); err != nil {
fmt.Fprintf(os.Stderr, "写入响应失败:%v\n", err)
break
}
}
if inst != nil {
_ = inst.Close()
}
runtimeState.close()
if err := scanner.Err(); err != nil {
fmt.Fprintf(os.Stderr, "读取请求失败:%v\n", err)
}
}
func handleRequest(inst *db.Database, req agentRequest) agentResponse {
func handleRequest(runtimeState *agentRuntime, req agentRequest) agentResponse {
resp := agentResponse{ID: req.ID, Success: true}
method := strings.TrimSpace(req.Method)
@@ -112,9 +122,7 @@ func handleRequest(inst *db.Database, req agentRequest) agentResponse {
if req.Config == nil {
return fail(resp, "连接配置为空")
}
if *inst != nil {
_ = (*inst).Close()
}
runtimeState.close()
next := agentDatabaseFactory()
if next == nil {
return fail(resp, "驱动代理初始化失败")
@@ -122,14 +130,13 @@ func handleRequest(inst *db.Database, req agentRequest) agentResponse {
if err := next.Connect(*req.Config); err != nil {
return fail(resp, err.Error())
}
*inst = next
runtimeState.inst = next
return resp
case agentMethodClose:
if *inst != nil {
if err := (*inst).Close(); err != nil {
if runtimeState.inst != nil {
if err := runtimeState.close(); err != nil {
return fail(resp, err.Error())
}
*inst = nil
}
return resp
case agentMethodMetadata:
@@ -139,74 +146,124 @@ func handleRequest(inst *db.Database, req agentRequest) agentResponse {
"protocolSchema": "json-lines-v1",
}
return resp
case agentMethodOpenSession:
if runtimeState.inst == nil {
return fail(resp, "connection not open")
}
provider, ok := runtimeState.inst.(db.SessionExecerProvider)
if !ok {
return fail(resp, fmt.Sprintf("当前数据源(%s不支持 SQL 编辑器托管事务", strings.TrimSpace(agentDriverType)))
}
openCtx := context.Background()
var cancel context.CancelFunc
if req.TimeoutMs > 0 {
openCtx, cancel = context.WithTimeout(context.Background(), time.Duration(req.TimeoutMs)*time.Millisecond)
defer cancel()
}
session, err := provider.OpenSessionExecer(openCtx)
if err != nil {
return fail(resp, err.Error())
}
sessionID := runtimeState.nextID()
runtimeState.sessions[sessionID] = session
resp.Data = sessionID
return resp
case agentMethodCloseSession:
if err := runtimeState.closeSession(req.SessionID); err != nil {
return fail(resp, err.Error())
}
return resp
}
if *inst == nil {
if runtimeState.inst == nil {
return fail(resp, "connection not open")
}
if session, ok, err := runtimeState.session(req.SessionID); err != nil {
return fail(resp, err.Error())
} else if ok {
switch method {
case agentMethodQuery:
data, fields, err := queryStatementWithOptionalTimeout(session, req.Query, req.TimeoutMs)
if err != nil {
return fail(resp, err.Error())
}
resp.Data = data
resp.Fields = fields
case agentMethodExec:
affected, err := execStatementWithOptionalTimeout(session, req.Query, req.TimeoutMs)
if err != nil {
return fail(resp, err.Error())
}
resp.RowsAffected = affected
default:
return fail(resp, "当前事务会话不支持该方法")
}
return resp
}
switch method {
case agentMethodPing:
if err := (*inst).Ping(); err != nil {
if err := runtimeState.inst.Ping(); err != nil {
return fail(resp, err.Error())
}
case agentMethodQuery:
data, fields, err := queryWithOptionalTimeout(*inst, req.Query, req.TimeoutMs)
data, fields, err := queryWithOptionalTimeout(runtimeState.inst, req.Query, req.TimeoutMs)
if err != nil {
return fail(resp, err.Error())
}
resp.Data = data
resp.Fields = fields
case agentMethodExec:
affected, err := execWithOptionalTimeout(*inst, req.Query, req.TimeoutMs)
affected, err := execWithOptionalTimeout(runtimeState.inst, req.Query, req.TimeoutMs)
if err != nil {
return fail(resp, err.Error())
}
resp.RowsAffected = affected
case agentMethodGetDatabases:
data, err := (*inst).GetDatabases()
data, err := runtimeState.inst.GetDatabases()
if err != nil {
return fail(resp, err.Error())
}
resp.Data = data
case agentMethodGetTables:
data, err := (*inst).GetTables(req.DBName)
data, err := runtimeState.inst.GetTables(req.DBName)
if err != nil {
return fail(resp, err.Error())
}
resp.Data = data
case agentMethodGetCreateStmt:
data, err := (*inst).GetCreateStatement(req.DBName, req.TableName)
data, err := runtimeState.inst.GetCreateStatement(req.DBName, req.TableName)
if err != nil {
return fail(resp, err.Error())
}
resp.Data = data
case agentMethodGetColumns:
data, err := (*inst).GetColumns(req.DBName, req.TableName)
data, err := runtimeState.inst.GetColumns(req.DBName, req.TableName)
if err != nil {
return fail(resp, err.Error())
}
resp.Data = data
case agentMethodGetAllColumns:
data, err := (*inst).GetAllColumns(req.DBName)
data, err := runtimeState.inst.GetAllColumns(req.DBName)
if err != nil {
return fail(resp, err.Error())
}
resp.Data = data
case agentMethodGetIndexes:
data, err := (*inst).GetIndexes(req.DBName, req.TableName)
data, err := runtimeState.inst.GetIndexes(req.DBName, req.TableName)
if err != nil {
return fail(resp, err.Error())
}
resp.Data = data
case agentMethodGetForeignKey:
data, err := (*inst).GetForeignKeys(req.DBName, req.TableName)
data, err := runtimeState.inst.GetForeignKeys(req.DBName, req.TableName)
if err != nil {
return fail(resp, err.Error())
}
resp.Data = data
case agentMethodGetTriggers:
data, err := (*inst).GetTriggers(req.DBName, req.TableName)
data, err := runtimeState.inst.GetTriggers(req.DBName, req.TableName)
if err != nil {
return fail(resp, err.Error())
}
@@ -215,7 +272,7 @@ func handleRequest(inst *db.Database, req agentRequest) agentResponse {
if req.Changes == nil {
return fail(resp, "变更集为空")
}
applier, ok := (*inst).(interface {
applier, ok := runtimeState.inst.(interface {
ApplyChanges(tableName string, changes connection.ChangeSet) error
})
if !ok {
@@ -231,6 +288,67 @@ func handleRequest(inst *db.Database, req agentRequest) agentResponse {
return resp
}
func (r *agentRuntime) nextID() string {
r.ensureSessionMap()
r.nextSessionID++
return "session-" + strconv.FormatInt(r.nextSessionID, 10)
}
func (r *agentRuntime) session(sessionID string) (db.StatementExecer, bool, error) {
r.ensureSessionMap()
sessionID = strings.TrimSpace(sessionID)
if sessionID == "" {
return nil, false, nil
}
session, ok := r.sessions[sessionID]
if !ok || session == nil {
return nil, false, fmt.Errorf("事务会话不存在或已结束")
}
return session, true, nil
}
func (r *agentRuntime) closeSession(sessionID string) error {
r.ensureSessionMap()
sessionID = strings.TrimSpace(sessionID)
if sessionID == "" {
return fmt.Errorf("事务会话 ID 不能为空")
}
session, ok := r.sessions[sessionID]
if ok {
delete(r.sessions, sessionID)
}
if !ok || session == nil {
return fmt.Errorf("事务会话不存在或已结束")
}
return session.Close()
}
func (r *agentRuntime) close() error {
var closeErr error
r.ensureSessionMap()
for sessionID, session := range r.sessions {
delete(r.sessions, sessionID)
if session != nil {
if err := session.Close(); err != nil && closeErr == nil {
closeErr = err
}
}
}
if r.inst != nil {
if err := r.inst.Close(); err != nil && closeErr == nil {
closeErr = err
}
r.inst = nil
}
return closeErr
}
func (r *agentRuntime) ensureSessionMap() {
if r.sessions == nil {
r.sessions = make(map[string]db.StatementExecer)
}
}
func writeResponse(writer *bufio.Writer, resp agentResponse) error {
// 对响应数据做统一 JSON 安全归一化:
// 将 map[any]any如 duckdb.Map递归转换为 map[string]any避免序列化失败导致代理进程退出。
@@ -301,7 +419,23 @@ func normalizeAgentResponseData(v interface{}) interface{} {
}
}
func queryWithOptionalTimeout(inst db.Database, query string, timeoutMs int64) ([]map[string]interface{}, []string, error) {
type agentQueryRunner interface {
Query(string) ([]map[string]interface{}, []string, error)
}
type agentQueryContextRunner interface {
QueryContext(context.Context, string) ([]map[string]interface{}, []string, error)
}
type agentExecRunner interface {
Exec(string) (int64, error)
}
type agentExecContextRunner interface {
ExecContext(context.Context, string) (int64, error)
}
func queryWithOptionalTimeout(inst agentQueryRunner, query string, timeoutMs int64) ([]map[string]interface{}, []string, error) {
effectiveTimeoutMs := timeoutMs
if effectiveTimeoutMs <= 0 && strings.EqualFold(strings.TrimSpace(agentDriverType), "clickhouse") {
effectiveTimeoutMs = int64(legacyClickHouseDefaultTimeout / time.Millisecond)
@@ -309,9 +443,7 @@ func queryWithOptionalTimeout(inst db.Database, query string, timeoutMs int64) (
if effectiveTimeoutMs <= 0 {
return inst.Query(query)
}
if q, ok := inst.(interface {
QueryContext(context.Context, string) ([]map[string]interface{}, []string, error)
}); ok {
if q, ok := inst.(agentQueryContextRunner); ok {
ctx, cancel := context.WithTimeout(context.Background(), time.Duration(effectiveTimeoutMs)*time.Millisecond)
defer cancel()
return q.QueryContext(ctx, query)
@@ -319,7 +451,15 @@ func queryWithOptionalTimeout(inst db.Database, query string, timeoutMs int64) (
return inst.Query(query)
}
func execWithOptionalTimeout(inst db.Database, query string, timeoutMs int64) (int64, error) {
func queryStatementWithOptionalTimeout(inst db.StatementExecer, query string, timeoutMs int64) ([]map[string]interface{}, []string, error) {
queryRunner, ok := inst.(agentQueryRunner)
if !ok {
return nil, nil, fmt.Errorf("当前事务会话不支持查询语句")
}
return queryWithOptionalTimeout(queryRunner, query, timeoutMs)
}
func execWithOptionalTimeout(inst agentExecRunner, query string, timeoutMs int64) (int64, error) {
effectiveTimeoutMs := timeoutMs
if effectiveTimeoutMs <= 0 && strings.EqualFold(strings.TrimSpace(agentDriverType), "clickhouse") {
effectiveTimeoutMs = int64(legacyClickHouseDefaultTimeout / time.Millisecond)
@@ -327,12 +467,14 @@ func execWithOptionalTimeout(inst db.Database, query string, timeoutMs int64) (i
if effectiveTimeoutMs <= 0 {
return inst.Exec(query)
}
if e, ok := inst.(interface {
ExecContext(context.Context, string) (int64, error)
}); ok {
if e, ok := inst.(agentExecContextRunner); ok {
ctx, cancel := context.WithTimeout(context.Background(), time.Duration(effectiveTimeoutMs)*time.Millisecond)
defer cancel()
return e.ExecContext(ctx, query)
}
return inst.Exec(query)
}
func execStatementWithOptionalTimeout(inst db.StatementExecer, query string, timeoutMs int64) (int64, error) {
return execWithOptionalTimeout(inst, query, timeoutMs)
}

View File

@@ -6,6 +6,7 @@ import (
"context"
"encoding/json"
"errors"
"strings"
"testing"
"time"
@@ -77,8 +78,8 @@ func TestHandleRequestMetadataReportsAgentRevision(t *testing.T) {
agentDriverType = "clickhouse"
agentDatabaseFactory = func() db.Database { return nil }
var inst db.Database
resp := handleRequest(&inst, agentRequest{ID: 7, Method: agentMethodMetadata})
runtimeState := &agentRuntime{sessions: make(map[string]db.StatementExecer)}
resp := handleRequest(runtimeState, agentRequest{ID: 7, Method: agentMethodMetadata})
if !resp.Success {
t.Fatalf("metadata request failed: %s", resp.Error)
}
@@ -150,6 +151,45 @@ func (f *fakeAgentTimeoutDB) GetTriggers(dbName, tableName string) ([]connection
return nil, nil
}
type fakeAgentSessionDB struct {
fakeAgentTimeoutDB
session *fakeAgentStatementSession
}
func (f *fakeAgentSessionDB) OpenSessionExecer(ctx context.Context) (db.StatementExecer, error) {
f.session = &fakeAgentStatementSession{}
return f.session, nil
}
type fakeAgentStatementSession struct {
queryCalls int
execCalls int
closed bool
}
func (f *fakeAgentStatementSession) Query(query string) ([]map[string]interface{}, []string, error) {
return f.QueryContext(context.Background(), query)
}
func (f *fakeAgentStatementSession) QueryContext(ctx context.Context, query string) ([]map[string]interface{}, []string, error) {
f.queryCalls++
return []map[string]interface{}{{"session_ok": 1}}, []string{"session_ok"}, nil
}
func (f *fakeAgentStatementSession) Exec(query string) (int64, error) {
return f.ExecContext(context.Background(), query)
}
func (f *fakeAgentStatementSession) ExecContext(ctx context.Context, query string) (int64, error) {
f.execCalls++
return 9, nil
}
func (f *fakeAgentStatementSession) Close() error {
f.closed = true
return nil
}
func TestQueryWithOptionalTimeout_UsesQueryContext(t *testing.T) {
fake := &fakeAgentTimeoutDB{}
data, fields, err := queryWithOptionalTimeout(fake, "SELECT 1", int64((2 * time.Second).Milliseconds()))
@@ -198,3 +238,71 @@ func TestQueryWithOptionalTimeout_ClickHouseLegacyModeUsesQueryContext(t *testin
t.Fatalf("clickhouse legacy query 调用路径异常QueryContext=%v Query=%v", fake.queryContextCalled, fake.queryCalled)
}
}
func TestHandleRequest_UsesPinnedSessionForSessionScopedQueryAndExec(t *testing.T) {
old := agentDriverType
defer func() { agentDriverType = old }()
agentDriverType = "sqlserver"
fake := &fakeAgentSessionDB{}
runtimeState := &agentRuntime{
inst: fake,
sessions: make(map[string]db.StatementExecer),
}
openResp := handleRequest(runtimeState, agentRequest{ID: 1, Method: agentMethodOpenSession})
if !openResp.Success {
t.Fatalf("openSession failed: %s", openResp.Error)
}
sessionID, ok := openResp.Data.(string)
if !ok || strings.TrimSpace(sessionID) == "" {
t.Fatalf("unexpected session id payload: %#v", openResp.Data)
}
if fake.session == nil {
t.Fatal("expected OpenSessionExecer to create a pinned session")
}
queryResp := handleRequest(runtimeState, agentRequest{
ID: 2,
Method: agentMethodQuery,
SessionID: sessionID,
Query: "SELECT 1",
})
if !queryResp.Success {
t.Fatalf("session query failed: %s", queryResp.Error)
}
if fake.queryCalled || fake.queryContextCalled {
t.Fatalf("expected session query to bypass database-level query path, got Query=%v QueryContext=%v", fake.queryCalled, fake.queryContextCalled)
}
if fake.session.queryCalls != 1 {
t.Fatalf("expected pinned session queryCalls=1, got %d", fake.session.queryCalls)
}
execResp := handleRequest(runtimeState, agentRequest{
ID: 3,
Method: agentMethodExec,
SessionID: sessionID,
Query: "UPDATE t SET v = 1",
})
if !execResp.Success {
t.Fatalf("session exec failed: %s", execResp.Error)
}
if fake.execCalled || fake.execContextCalled {
t.Fatalf("expected session exec to bypass database-level exec path, got Exec=%v ExecContext=%v", fake.execCalled, fake.execContextCalled)
}
if fake.session.execCalls != 1 {
t.Fatalf("expected pinned session execCalls=1, got %d", fake.session.execCalls)
}
closeResp := handleRequest(runtimeState, agentRequest{
ID: 4,
Method: agentMethodCloseSession,
SessionID: sessionID,
})
if !closeResp.Success {
t.Fatalf("closeSession failed: %s", closeResp.Error)
}
if !fake.session.closed {
t.Fatal("expected pinned session to close")
}
}