diff --git a/internal/app/methods_jvm.go b/internal/app/methods_jvm.go index 777d5c8..cef2196 100644 --- a/internal/app/methods_jvm.go +++ b/internal/app/methods_jvm.go @@ -82,6 +82,12 @@ func (a *App) JVMGetValue(cfg connection.ConnectionConfig, resourcePath string) } func (a *App) JVMPreviewChange(cfg connection.ConnectionConfig, req jvm.ChangeRequest) connection.QueryResult { + var err error + req, err = jvm.NormalizeChangeRequest(req) + if err != nil { + return connection.QueryResult{Success: false, Message: err.Error()} + } + normalized, provider, err := resolveJVMProviderForMode(cfg, req.ProviderMode) if err != nil { return connection.QueryResult{Success: false, Message: err.Error()} @@ -96,6 +102,12 @@ func (a *App) JVMPreviewChange(cfg connection.ConnectionConfig, req jvm.ChangeRe } func (a *App) JVMApplyChange(cfg connection.ConnectionConfig, req jvm.ChangeRequest) connection.QueryResult { + var err error + req, err = jvm.NormalizeChangeRequest(req) + if err != nil { + return connection.QueryResult{Success: false, Message: err.Error()} + } + normalized, provider, err := resolveJVMProviderForMode(cfg, req.ProviderMode) if err != nil { return connection.QueryResult{Success: false, Message: err.Error()} diff --git a/internal/app/methods_jvm_test.go b/internal/app/methods_jvm_test.go index 5079b49..5bd21a6 100644 --- a/internal/app/methods_jvm_test.go +++ b/internal/app/methods_jvm_test.go @@ -25,6 +25,8 @@ type fakeJVMProvider struct { previewErr error apply jvm.ApplyResult applyErr error + previewReq *jvm.ChangeRequest + applyReq *jvm.ChangeRequest } func (f fakeJVMProvider) Mode() string { return jvm.ModeJMX } @@ -40,13 +42,19 @@ func (f fakeJVMProvider) ListResources(context.Context, connection.ConnectionCon func (f fakeJVMProvider) GetValue(context.Context, connection.ConnectionConfig, string) (jvm.ValueSnapshot, error) { return f.value, f.valueErr } -func (f fakeJVMProvider) PreviewChange(context.Context, connection.ConnectionConfig, jvm.ChangeRequest) (jvm.ChangePreview, error) { +func (f fakeJVMProvider) PreviewChange(_ context.Context, _ connection.ConnectionConfig, req jvm.ChangeRequest) (jvm.ChangePreview, error) { + if f.previewReq != nil { + *f.previewReq = req + } if !f.previewSet { return jvm.ChangePreview{Allowed: true, Summary: "preview", RiskLevel: "low"}, f.previewErr } return f.preview, f.previewErr } -func (f fakeJVMProvider) ApplyChange(context.Context, connection.ConnectionConfig, jvm.ChangeRequest) (jvm.ApplyResult, error) { +func (f fakeJVMProvider) ApplyChange(_ context.Context, _ connection.ConnectionConfig, req jvm.ChangeRequest) (jvm.ApplyResult, error) { + if f.applyReq != nil { + *f.applyReq = req + } return f.apply, f.applyErr } @@ -578,6 +586,75 @@ func TestJVMApplyChangePersistsAuditSource(t *testing.T) { } } +func TestJVMApplyChangeNormalizesRequestBeforeProviderAndAudit(t *testing.T) { + app := NewAppWithSecretStore(nil) + app.configDir = t.TempDir() + readOnly := false + var previewReq jvm.ChangeRequest + var applyReq jvm.ChangeRequest + restore := swapJVMProviderFactory(func(mode string) (jvm.Provider, error) { + return fakeJVMProvider{ + value: jvm.ValueSnapshot{ + ResourceID: "/cache/orders", + Kind: "entry", + Format: "json", + }, + previewReq: &previewReq, + applyReq: &applyReq, + apply: jvm.ApplyResult{ + Status: "applied", + UpdatedValue: jvm.ValueSnapshot{ + ResourceID: "/cache/orders", + Kind: "entry", + Format: "json", + }, + }, + }, nil + }) + defer restore() + + res := app.JVMApplyChange(connection.ConnectionConfig{ + Type: "jvm", + ID: "conn-orders", + Host: "orders.internal", + JVM: connection.JVMConfig{ + ReadOnly: &readOnly, + PreferredMode: "endpoint", + AllowedModes: []string{"endpoint"}, + }, + }, jvm.ChangeRequest{ + ProviderMode: " endpoint ", + ResourceID: " /cache/orders ", + Action: " put ", + Reason: " repair cache ", + Source: " manual ", + Payload: map[string]any{ + "status": "ready", + }, + }) + if !res.Success { + t.Fatalf("expected success, got %+v", res) + } + if previewReq.ProviderMode != "endpoint" || previewReq.ResourceID != "/cache/orders" || previewReq.Action != "put" || previewReq.Reason != "repair cache" { + t.Fatalf("expected normalized preview request, got %#v", previewReq) + } + if applyReq.ProviderMode != "endpoint" || applyReq.ResourceID != "/cache/orders" || applyReq.Action != "put" || applyReq.Reason != "repair cache" || applyReq.Source != "manual" { + t.Fatalf("expected normalized apply request, got %#v", applyReq) + } + + listRes := app.JVMListAuditRecords("conn-orders", 10) + if !listRes.Success { + t.Fatalf("expected audit list success, got %+v", listRes) + } + records, ok := listRes.Data.([]jvm.AuditRecord) + if !ok || len(records) != 1 { + t.Fatalf("expected one audit record, got %#v", listRes.Data) + } + if records[0].ProviderMode != "endpoint" || records[0].ResourceID != "/cache/orders" || records[0].Action != "put" || records[0].Reason != "repair cache" || records[0].Source != "manual" { + t.Fatalf("expected normalized audit record, got %#v", records[0]) + } +} + func TestJVMPreviewChangeRejectsModeOutsideAllowedModes(t *testing.T) { app := NewAppWithSecretStore(nil) diff --git a/internal/jvm/diagnostic_agent_bridge.go b/internal/jvm/diagnostic_agent_bridge.go index ae94405..32ca610 100644 --- a/internal/jvm/diagnostic_agent_bridge.go +++ b/internal/jvm/diagnostic_agent_bridge.go @@ -164,8 +164,15 @@ func consumeDiagnosticSSE(body io.Reader, sink DiagnosticEventSink) error { return nil } + dataPayload := bytes.Join(stringSliceToBytes(dataLines), []byte("\n")) + if len(bytes.TrimSpace(dataPayload)) == 0 { + eventName = "" + dataLines = dataLines[:0] + return nil + } + var chunk DiagnosticEventChunk - if err := json.Unmarshal([]byte(bytes.Join(stringSliceToBytes(dataLines), []byte("\n"))), &chunk); err != nil { + if err := json.Unmarshal(dataPayload, &chunk); err != nil { return fmt.Errorf("diagnostic sse decode failed: %w", err) } if chunk.Event == "" { diff --git a/internal/jvm/diagnostic_agent_bridge_test.go b/internal/jvm/diagnostic_agent_bridge_test.go index 90cd799..6123201 100644 --- a/internal/jvm/diagnostic_agent_bridge_test.go +++ b/internal/jvm/diagnostic_agent_bridge_test.go @@ -5,6 +5,7 @@ import ( "encoding/json" "net/http" "net/http/httptest" + "strings" "testing" "GoNavi-Wails/internal/connection" @@ -132,3 +133,22 @@ func TestDiagnosticAgentBridgeCancelCommandSendsRequest(t *testing.T) { t.Fatalf("unexpected cancel payload: %#v", cancelPayload) } } + +func TestConsumeDiagnosticSSEToleratesEmptyHeartbeatEvents(t *testing.T) { + input := strings.NewReader(": ping\n\ndata:\n\nevent: chunk\ndata: {\"sessionId\":\"sess-1\",\"commandId\":\"cmd-1\",\"phase\":\"running\",\"content\":\"ok\"}\n\n") + var chunks []DiagnosticEventChunk + + err := consumeDiagnosticSSE(input, func(chunk DiagnosticEventChunk) { + chunks = append(chunks, chunk) + }) + + if err != nil { + t.Fatalf("consumeDiagnosticSSE returned error for heartbeat-only event: %v", err) + } + if len(chunks) != 1 { + t.Fatalf("expected exactly one diagnostic chunk, got %#v", chunks) + } + if chunks[0].Content != "ok" || chunks[0].Event != "chunk" { + t.Fatalf("unexpected diagnostic chunk: %#v", chunks[0]) + } +} diff --git a/internal/jvm/diagnostic_arthas_tunnel.go b/internal/jvm/diagnostic_arthas_tunnel.go index b331931..16ae985 100644 --- a/internal/jvm/diagnostic_arthas_tunnel.go +++ b/internal/jvm/diagnostic_arthas_tunnel.go @@ -9,6 +9,7 @@ import ( "net/http" "net/url" "regexp" + "sort" "strings" "sync" "time" @@ -19,11 +20,13 @@ import ( ) const ( - arthasTunnelDefaultCols = 160 - arthasTunnelDefaultRows = 48 + arthasTunnelDefaultCols = 160 + arthasTunnelDefaultRows = 48 arthasTunnelReadStep = 250 * time.Millisecond arthasTunnelPromptDetectionTail = 96 arthasTunnelInterruptInput = "\u0003" + arthasTunnelSessionTTL = 12 * time.Hour + arthasTunnelMaxSessions = 128 ) var arthasPromptPattern = regexp.MustCompile(`\[arthas@[^\]]+\]\$ `) @@ -143,7 +146,7 @@ func (t *DiagnosticArthasTunnelTransport) ExecuteCommand(ctx context.Context, cf commandCtx, cancel := context.WithTimeout(ctx, runtime.timeout) defer cancel() - activeCommand, err := diagnosticArthasTunnelSessions.beginCommand(req.SessionID, req.CommandID) + activeCommand, err := diagnosticArthasTunnelSessions.beginCommand(req.SessionID, req.CommandID, cfg) if err != nil { return err } @@ -156,6 +159,11 @@ func (t *DiagnosticArthasTunnelTransport) ExecuteCommand(ctx context.Context, cf activeCommand.attachConn(conn) defer conn.Close() + if activeCommand.isCancelRequested() { + t.emitChunk(req, "canceled", "Arthas 命令已取消") + return fmt.Errorf("arthas tunnel command canceled") + } + if err := activeCommand.send(arthasTunnelTTYFrame{ Action: "resize", Cols: arthasTunnelDefaultCols, @@ -462,6 +470,7 @@ func (r *arthasTunnelSessionRegistry) createSession(cfg connection.ConnectionCon sessionID := "arthas-" + uuid.NewString() startedAt := time.Now().UnixMilli() + r.pruneLocked(startedAt) r.sessions[sessionID] = arthasTunnelSessionMeta{ createdAt: startedAt, targetID: strings.TrimSpace(cfg.JVM.Diagnostic.TargetID), @@ -475,13 +484,18 @@ func (r *arthasTunnelSessionRegistry) createSession(cfg connection.ConnectionCon } } -func (r *arthasTunnelSessionRegistry) beginCommand(sessionID string, commandID string) (*arthasTunnelActiveCommand, error) { +func (r *arthasTunnelSessionRegistry) beginCommand(sessionID string, commandID string, cfg connection.ConnectionConfig) (*arthasTunnelActiveCommand, error) { r.mu.Lock() defer r.mu.Unlock() - if _, ok := r.sessions[sessionID]; !ok { + r.pruneLocked(time.Now().UnixMilli()) + meta, ok := r.sessions[sessionID] + if !ok { return nil, errors.New("诊断会话不存在,请重新创建 Arthas Tunnel 会话") } + if !meta.matchesConfig(cfg) { + return nil, errors.New("Arthas Tunnel 会话配置已变化,请重新创建诊断会话") + } if existing := r.active[sessionID]; existing != nil { return nil, errors.New("当前 Arthas Tunnel 会话已有命令在执行,请先等待完成或取消") } @@ -504,6 +518,49 @@ func (r *arthasTunnelSessionRegistry) finishCommand(sessionID string, commandID } } +func (r *arthasTunnelSessionRegistry) pruneLocked(nowMillis int64) { + if len(r.sessions) == 0 { + return + } + + cutoff := nowMillis - int64(arthasTunnelSessionTTL/time.Millisecond) + for sessionID, meta := range r.sessions { + if meta.createdAt > 0 && meta.createdAt < cutoff { + delete(r.sessions, sessionID) + delete(r.active, sessionID) + } + } + + if len(r.sessions) <= arthasTunnelMaxSessions { + return + } + + type sessionAge struct { + sessionID string + createdAt int64 + } + items := make([]sessionAge, 0, len(r.sessions)) + for sessionID, meta := range r.sessions { + items = append(items, sessionAge{sessionID: sessionID, createdAt: meta.createdAt}) + } + sort.Slice(items, func(i, j int) bool { + return items[i].createdAt < items[j].createdAt + }) + for len(r.sessions) > arthasTunnelMaxSessions && len(items) > 0 { + victim := items[0].sessionID + items = items[1:] + if _, active := r.active[victim]; active { + continue + } + delete(r.sessions, victim) + } +} + +func (m arthasTunnelSessionMeta) matchesConfig(cfg connection.ConnectionConfig) bool { + return strings.TrimSpace(m.targetID) == strings.TrimSpace(cfg.JVM.Diagnostic.TargetID) && + strings.TrimSpace(m.baseURL) == strings.TrimSpace(cfg.JVM.Diagnostic.BaseURL) +} + func (r *arthasTunnelSessionRegistry) cancelCommand(sessionID string, commandID string) error { r.mu.Lock() activeCommand := r.active[sessionID] @@ -564,8 +621,12 @@ func (c *arthasTunnelActiveCommand) send(frame arthasTunnelTTYFrame) error { func (c *arthasTunnelActiveCommand) requestCancel() error { c.mu.Lock() c.cancelRequested = true + conn := c.conn c.mu.Unlock() + if conn == nil { + return nil + } return c.send(arthasTunnelTTYFrame{ Action: "read", Data: arthasTunnelInterruptInput, diff --git a/internal/jvm/diagnostic_arthas_tunnel_test.go b/internal/jvm/diagnostic_arthas_tunnel_test.go index f30fbae..04155ab 100644 --- a/internal/jvm/diagnostic_arthas_tunnel_test.go +++ b/internal/jvm/diagnostic_arthas_tunnel_test.go @@ -40,7 +40,7 @@ func newFakeArthasTunnelServer( t.Helper() fake := &fakeArthasTunnelServer{ - t: t, + t: t, upgrader: websocket.Upgrader{ CheckOrigin: func(*http.Request) bool { return true }, }, @@ -317,6 +317,78 @@ func TestDiagnosticArthasTunnelCancelCommandInterruptsActiveCommand(t *testing.T } } +func TestArthasTunnelActiveCommandAcceptsCancelBeforeConnectionAttach(t *testing.T) { + activeCommand := &arthasTunnelActiveCommand{commandID: "cmd-before-attach"} + + if err := activeCommand.requestCancel(); err != nil { + t.Fatalf("expected pre-attach cancel request to be recorded without error, got %v", err) + } + if !activeCommand.isCancelRequested() { + t.Fatal("expected cancelRequested flag to be recorded") + } +} + +func TestDiagnosticArthasTunnelRejectsSessionConfigDrift(t *testing.T) { + server := newFakeArthasTunnelServer(t, func(conn *websocket.Conn, frame fakeArthasTTYFrame) { + if frame.Action == "read" && strings.Contains(frame.Data, "thread -n 5") { + _ = conn.WriteMessage(websocket.TextMessage, []byte("thread top 5\r\n[arthas@12345]$ ")) + } + }) + defer server.close() + + transport, err := NewDiagnosticTransport(DiagnosticTransportArthasTunnel) + if err != nil { + t.Fatalf("NewDiagnosticTransport returned error: %v", err) + } + + tunnel := transport.(*DiagnosticArthasTunnelTransport) + cfg := testArthasTunnelConfig(server.wsURL()) + handle, err := tunnel.StartSession(context.Background(), cfg, DiagnosticSessionRequest{}) + if err != nil { + t.Fatalf("StartSession returned error: %v", err) + } + + driftedCfg := cfg + driftedCfg.JVM.Diagnostic.TargetID = "orders-prod-02" + err = tunnel.ExecuteCommand(context.Background(), driftedCfg, DiagnosticCommandRequest{ + SessionID: handle.SessionID, + CommandID: "cmd-drift", + Command: "thread -n 5", + }) + if err == nil { + t.Fatal("expected config drift to reject stale Arthas Tunnel session") + } + if !strings.Contains(err.Error(), "会话配置已变化") { + t.Fatalf("expected config drift error, got %v", err) + } +} + +func TestArthasTunnelSessionRegistryPrunesExpiredSessions(t *testing.T) { + registry := newArthasTunnelSessionRegistry() + cfg := testArthasTunnelConfig("http://127.0.0.1:7777") + oldHandle := registry.createSession(cfg) + + registry.mu.Lock() + oldMeta := registry.sessions[oldHandle.SessionID] + oldMeta.createdAt = time.Now().Add(-24 * time.Hour).UnixMilli() + registry.sessions[oldHandle.SessionID] = oldMeta + registry.mu.Unlock() + + registry.createSession(cfg) + + registry.mu.Lock() + _, oldExists := registry.sessions[oldHandle.SessionID] + sessionCount := len(registry.sessions) + registry.mu.Unlock() + + if oldExists { + t.Fatalf("expected expired session %s to be pruned", oldHandle.SessionID) + } + if sessionCount != 1 { + t.Fatalf("expected only fresh session to remain, got %d", sessionCount) + } +} + func TestDiagnosticArthasTunnelRequiresTargetID(t *testing.T) { transport, err := NewDiagnosticTransport(DiagnosticTransportArthasTunnel) if err != nil { diff --git a/internal/jvm/guard.go b/internal/jvm/guard.go index f83052c..bde86c9 100644 --- a/internal/jvm/guard.go +++ b/internal/jvm/guard.go @@ -17,19 +17,18 @@ func BuildChangePreview( cfg connection.ConnectionConfig, req ChangeRequest, ) (ChangePreview, error) { + req, err := NormalizeChangeRequest(req) + if err != nil { + return ChangePreview{}, err + } + normalized, err := NormalizeConnectionConfig(cfg) if err != nil { return ChangePreview{}, err } - resourceID := strings.TrimSpace(req.ResourceID) - if resourceID == "" { - return ChangePreview{}, fmt.Errorf("resource id is required") - } - action := strings.TrimSpace(req.Action) - if action == "" { - return ChangePreview{}, fmt.Errorf("action is required") - } + resourceID := req.ResourceID + action := req.Action before := ValueSnapshot{ ResourceID: resourceID, @@ -111,6 +110,28 @@ func BuildChangePreview( return preview, nil } +func NormalizeChangeRequest(req ChangeRequest) (ChangeRequest, error) { + normalized := req + normalized.ProviderMode = strings.ToLower(strings.TrimSpace(normalized.ProviderMode)) + normalized.ResourceID = strings.TrimSpace(normalized.ResourceID) + normalized.Action = strings.TrimSpace(normalized.Action) + normalized.Reason = strings.TrimSpace(normalized.Reason) + normalized.Source = strings.TrimSpace(normalized.Source) + normalized.ExpectedVersion = strings.TrimSpace(normalized.ExpectedVersion) + + if normalized.ResourceID == "" { + return ChangeRequest{}, fmt.Errorf("resource id is required") + } + if normalized.Action == "" { + return ChangeRequest{}, fmt.Errorf("action is required") + } + if normalized.Reason == "" { + return ChangeRequest{}, fmt.Errorf("reason is required") + } + + return normalized, nil +} + func hasSnapshotOverride(snapshot ValueSnapshot) bool { return strings.TrimSpace(snapshot.ResourceID) != "" || strings.TrimSpace(snapshot.Kind) != "" || diff --git a/internal/jvm/guard_test.go b/internal/jvm/guard_test.go index 95ae5d0..9c8f0a1 100644 --- a/internal/jvm/guard_test.go +++ b/internal/jvm/guard_test.go @@ -76,6 +76,32 @@ func TestPreviewChangeBlocksReadOnlyConnection(t *testing.T) { } } +func TestPreviewChangeRejectsMissingReason(t *testing.T) { + readOnly := false + + _, err := BuildChangePreview(context.Background(), fakeGuardProvider{}, connection.ConnectionConfig{ + Type: "jvm", + ID: "conn-writable", + Host: "orders.internal", + JVM: connection.JVMConfig{ + ReadOnly: &readOnly, + PreferredMode: ModeJMX, + AllowedModes: []string{ModeJMX}, + }, + }, ChangeRequest{ + ProviderMode: ModeJMX, + ResourceID: "/cache/orders", + Action: "put", + Reason: " ", + Payload: map[string]any{ + "status": "ready", + }, + }) + if err == nil || !strings.Contains(err.Error(), "reason is required") { + t.Fatalf("expected missing reason to be rejected, got %v", err) + } +} + func TestPreviewChangeReturnsProviderPreviewErrorWhenWriteAllowed(t *testing.T) { readOnly := false diff --git a/internal/jvm/http_contract.go b/internal/jvm/http_contract.go index ef31fb8..f3e015f 100644 --- a/internal/jvm/http_contract.go +++ b/internal/jvm/http_contract.go @@ -53,6 +53,8 @@ func normalizeContractBaseURL(rawBaseURL string, errorPrefix string) (*url.URL, if parsed.Scheme != "http" && parsed.Scheme != "https" { return nil, fmt.Errorf("%s scheme is unsupported: %s", errorPrefix, parsed.Scheme) } + parsed.RawQuery = "" + parsed.Fragment = "" return parsed, nil } diff --git a/internal/jvm/provider_contract_test.go b/internal/jvm/provider_contract_test.go index 7d9fdd4..ea391ed 100644 --- a/internal/jvm/provider_contract_test.go +++ b/internal/jvm/provider_contract_test.go @@ -2,6 +2,8 @@ package jvm import ( "context" + "net/http" + "net/http/httptest" "strings" "testing" @@ -89,6 +91,32 @@ func TestHTTPProviderTestConnectionReturnsErrorWhenBaseURLInvalid(t *testing.T) } } +func TestHTTPProviderProbeStripsBaseURLQueryAndFragment(t *testing.T) { + provider := NewHTTPProvider() + seen := make(chan string, 1) + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + seen <- r.URL.RequestURI() + w.WriteHeader(http.StatusNoContent) + })) + defer server.Close() + + err := provider.TestConnection(context.Background(), connection.ConnectionConfig{ + Type: "jvm", + JVM: connection.JVMConfig{ + Endpoint: connection.JVMEndpointConfig{ + BaseURL: server.URL + "/gonavi/jvm?api_key=secret-token#debug", + }, + }, + }) + + if err != nil { + t.Fatalf("expected probe to succeed, got %v", err) + } + if got := <-seen; got != "/gonavi/jvm" { + t.Fatalf("expected query and fragment to be stripped, got %q", got) + } +} + func TestAgentProviderTestConnectionReturnsErrorWhenBaseURLMissing(t *testing.T) { provider := NewAgentProvider()