🐛 fix(jvm): 加固诊断与变更安全边界

- 诊断 SSE 支持空心跳事件,避免无输出时解码失败

- Arthas Tunnel 增加会话过期清理、配置漂移校验和取消兜底

- Provider 合约清理 Base URL 查询参数和片段,避免探测泄露敏感信息

- JVM 变更请求强制校验原因并规范化写入审计字段
This commit is contained in:
Syngnat
2026-04-26 14:34:43 +08:00
parent 38e71119a4
commit f16e2f15c2
10 changed files with 343 additions and 17 deletions

View File

@@ -82,6 +82,12 @@ func (a *App) JVMGetValue(cfg connection.ConnectionConfig, resourcePath string)
}
func (a *App) JVMPreviewChange(cfg connection.ConnectionConfig, req jvm.ChangeRequest) connection.QueryResult {
var err error
req, err = jvm.NormalizeChangeRequest(req)
if err != nil {
return connection.QueryResult{Success: false, Message: err.Error()}
}
normalized, provider, err := resolveJVMProviderForMode(cfg, req.ProviderMode)
if err != nil {
return connection.QueryResult{Success: false, Message: err.Error()}
@@ -96,6 +102,12 @@ func (a *App) JVMPreviewChange(cfg connection.ConnectionConfig, req jvm.ChangeRe
}
func (a *App) JVMApplyChange(cfg connection.ConnectionConfig, req jvm.ChangeRequest) connection.QueryResult {
var err error
req, err = jvm.NormalizeChangeRequest(req)
if err != nil {
return connection.QueryResult{Success: false, Message: err.Error()}
}
normalized, provider, err := resolveJVMProviderForMode(cfg, req.ProviderMode)
if err != nil {
return connection.QueryResult{Success: false, Message: err.Error()}

View File

@@ -25,6 +25,8 @@ type fakeJVMProvider struct {
previewErr error
apply jvm.ApplyResult
applyErr error
previewReq *jvm.ChangeRequest
applyReq *jvm.ChangeRequest
}
func (f fakeJVMProvider) Mode() string { return jvm.ModeJMX }
@@ -40,13 +42,19 @@ func (f fakeJVMProvider) ListResources(context.Context, connection.ConnectionCon
func (f fakeJVMProvider) GetValue(context.Context, connection.ConnectionConfig, string) (jvm.ValueSnapshot, error) {
return f.value, f.valueErr
}
func (f fakeJVMProvider) PreviewChange(context.Context, connection.ConnectionConfig, jvm.ChangeRequest) (jvm.ChangePreview, error) {
func (f fakeJVMProvider) PreviewChange(_ context.Context, _ connection.ConnectionConfig, req jvm.ChangeRequest) (jvm.ChangePreview, error) {
if f.previewReq != nil {
*f.previewReq = req
}
if !f.previewSet {
return jvm.ChangePreview{Allowed: true, Summary: "preview", RiskLevel: "low"}, f.previewErr
}
return f.preview, f.previewErr
}
func (f fakeJVMProvider) ApplyChange(context.Context, connection.ConnectionConfig, jvm.ChangeRequest) (jvm.ApplyResult, error) {
func (f fakeJVMProvider) ApplyChange(_ context.Context, _ connection.ConnectionConfig, req jvm.ChangeRequest) (jvm.ApplyResult, error) {
if f.applyReq != nil {
*f.applyReq = req
}
return f.apply, f.applyErr
}
@@ -578,6 +586,75 @@ func TestJVMApplyChangePersistsAuditSource(t *testing.T) {
}
}
func TestJVMApplyChangeNormalizesRequestBeforeProviderAndAudit(t *testing.T) {
app := NewAppWithSecretStore(nil)
app.configDir = t.TempDir()
readOnly := false
var previewReq jvm.ChangeRequest
var applyReq jvm.ChangeRequest
restore := swapJVMProviderFactory(func(mode string) (jvm.Provider, error) {
return fakeJVMProvider{
value: jvm.ValueSnapshot{
ResourceID: "/cache/orders",
Kind: "entry",
Format: "json",
},
previewReq: &previewReq,
applyReq: &applyReq,
apply: jvm.ApplyResult{
Status: "applied",
UpdatedValue: jvm.ValueSnapshot{
ResourceID: "/cache/orders",
Kind: "entry",
Format: "json",
},
},
}, nil
})
defer restore()
res := app.JVMApplyChange(connection.ConnectionConfig{
Type: "jvm",
ID: "conn-orders",
Host: "orders.internal",
JVM: connection.JVMConfig{
ReadOnly: &readOnly,
PreferredMode: "endpoint",
AllowedModes: []string{"endpoint"},
},
}, jvm.ChangeRequest{
ProviderMode: " endpoint ",
ResourceID: " /cache/orders ",
Action: " put ",
Reason: " repair cache ",
Source: " manual ",
Payload: map[string]any{
"status": "ready",
},
})
if !res.Success {
t.Fatalf("expected success, got %+v", res)
}
if previewReq.ProviderMode != "endpoint" || previewReq.ResourceID != "/cache/orders" || previewReq.Action != "put" || previewReq.Reason != "repair cache" {
t.Fatalf("expected normalized preview request, got %#v", previewReq)
}
if applyReq.ProviderMode != "endpoint" || applyReq.ResourceID != "/cache/orders" || applyReq.Action != "put" || applyReq.Reason != "repair cache" || applyReq.Source != "manual" {
t.Fatalf("expected normalized apply request, got %#v", applyReq)
}
listRes := app.JVMListAuditRecords("conn-orders", 10)
if !listRes.Success {
t.Fatalf("expected audit list success, got %+v", listRes)
}
records, ok := listRes.Data.([]jvm.AuditRecord)
if !ok || len(records) != 1 {
t.Fatalf("expected one audit record, got %#v", listRes.Data)
}
if records[0].ProviderMode != "endpoint" || records[0].ResourceID != "/cache/orders" || records[0].Action != "put" || records[0].Reason != "repair cache" || records[0].Source != "manual" {
t.Fatalf("expected normalized audit record, got %#v", records[0])
}
}
func TestJVMPreviewChangeRejectsModeOutsideAllowedModes(t *testing.T) {
app := NewAppWithSecretStore(nil)

View File

@@ -164,8 +164,15 @@ func consumeDiagnosticSSE(body io.Reader, sink DiagnosticEventSink) error {
return nil
}
dataPayload := bytes.Join(stringSliceToBytes(dataLines), []byte("\n"))
if len(bytes.TrimSpace(dataPayload)) == 0 {
eventName = ""
dataLines = dataLines[:0]
return nil
}
var chunk DiagnosticEventChunk
if err := json.Unmarshal([]byte(bytes.Join(stringSliceToBytes(dataLines), []byte("\n"))), &chunk); err != nil {
if err := json.Unmarshal(dataPayload, &chunk); err != nil {
return fmt.Errorf("diagnostic sse decode failed: %w", err)
}
if chunk.Event == "" {

View File

@@ -5,6 +5,7 @@ import (
"encoding/json"
"net/http"
"net/http/httptest"
"strings"
"testing"
"GoNavi-Wails/internal/connection"
@@ -132,3 +133,22 @@ func TestDiagnosticAgentBridgeCancelCommandSendsRequest(t *testing.T) {
t.Fatalf("unexpected cancel payload: %#v", cancelPayload)
}
}
func TestConsumeDiagnosticSSEToleratesEmptyHeartbeatEvents(t *testing.T) {
input := strings.NewReader(": ping\n\ndata:\n\nevent: chunk\ndata: {\"sessionId\":\"sess-1\",\"commandId\":\"cmd-1\",\"phase\":\"running\",\"content\":\"ok\"}\n\n")
var chunks []DiagnosticEventChunk
err := consumeDiagnosticSSE(input, func(chunk DiagnosticEventChunk) {
chunks = append(chunks, chunk)
})
if err != nil {
t.Fatalf("consumeDiagnosticSSE returned error for heartbeat-only event: %v", err)
}
if len(chunks) != 1 {
t.Fatalf("expected exactly one diagnostic chunk, got %#v", chunks)
}
if chunks[0].Content != "ok" || chunks[0].Event != "chunk" {
t.Fatalf("unexpected diagnostic chunk: %#v", chunks[0])
}
}

View File

@@ -9,6 +9,7 @@ import (
"net/http"
"net/url"
"regexp"
"sort"
"strings"
"sync"
"time"
@@ -19,11 +20,13 @@ import (
)
const (
arthasTunnelDefaultCols = 160
arthasTunnelDefaultRows = 48
arthasTunnelDefaultCols = 160
arthasTunnelDefaultRows = 48
arthasTunnelReadStep = 250 * time.Millisecond
arthasTunnelPromptDetectionTail = 96
arthasTunnelInterruptInput = "\u0003"
arthasTunnelSessionTTL = 12 * time.Hour
arthasTunnelMaxSessions = 128
)
var arthasPromptPattern = regexp.MustCompile(`\[arthas@[^\]]+\]\$ `)
@@ -143,7 +146,7 @@ func (t *DiagnosticArthasTunnelTransport) ExecuteCommand(ctx context.Context, cf
commandCtx, cancel := context.WithTimeout(ctx, runtime.timeout)
defer cancel()
activeCommand, err := diagnosticArthasTunnelSessions.beginCommand(req.SessionID, req.CommandID)
activeCommand, err := diagnosticArthasTunnelSessions.beginCommand(req.SessionID, req.CommandID, cfg)
if err != nil {
return err
}
@@ -156,6 +159,11 @@ func (t *DiagnosticArthasTunnelTransport) ExecuteCommand(ctx context.Context, cf
activeCommand.attachConn(conn)
defer conn.Close()
if activeCommand.isCancelRequested() {
t.emitChunk(req, "canceled", "Arthas 命令已取消")
return fmt.Errorf("arthas tunnel command canceled")
}
if err := activeCommand.send(arthasTunnelTTYFrame{
Action: "resize",
Cols: arthasTunnelDefaultCols,
@@ -462,6 +470,7 @@ func (r *arthasTunnelSessionRegistry) createSession(cfg connection.ConnectionCon
sessionID := "arthas-" + uuid.NewString()
startedAt := time.Now().UnixMilli()
r.pruneLocked(startedAt)
r.sessions[sessionID] = arthasTunnelSessionMeta{
createdAt: startedAt,
targetID: strings.TrimSpace(cfg.JVM.Diagnostic.TargetID),
@@ -475,13 +484,18 @@ func (r *arthasTunnelSessionRegistry) createSession(cfg connection.ConnectionCon
}
}
func (r *arthasTunnelSessionRegistry) beginCommand(sessionID string, commandID string) (*arthasTunnelActiveCommand, error) {
func (r *arthasTunnelSessionRegistry) beginCommand(sessionID string, commandID string, cfg connection.ConnectionConfig) (*arthasTunnelActiveCommand, error) {
r.mu.Lock()
defer r.mu.Unlock()
if _, ok := r.sessions[sessionID]; !ok {
r.pruneLocked(time.Now().UnixMilli())
meta, ok := r.sessions[sessionID]
if !ok {
return nil, errors.New("诊断会话不存在,请重新创建 Arthas Tunnel 会话")
}
if !meta.matchesConfig(cfg) {
return nil, errors.New("Arthas Tunnel 会话配置已变化,请重新创建诊断会话")
}
if existing := r.active[sessionID]; existing != nil {
return nil, errors.New("当前 Arthas Tunnel 会话已有命令在执行,请先等待完成或取消")
}
@@ -504,6 +518,49 @@ func (r *arthasTunnelSessionRegistry) finishCommand(sessionID string, commandID
}
}
func (r *arthasTunnelSessionRegistry) pruneLocked(nowMillis int64) {
if len(r.sessions) == 0 {
return
}
cutoff := nowMillis - int64(arthasTunnelSessionTTL/time.Millisecond)
for sessionID, meta := range r.sessions {
if meta.createdAt > 0 && meta.createdAt < cutoff {
delete(r.sessions, sessionID)
delete(r.active, sessionID)
}
}
if len(r.sessions) <= arthasTunnelMaxSessions {
return
}
type sessionAge struct {
sessionID string
createdAt int64
}
items := make([]sessionAge, 0, len(r.sessions))
for sessionID, meta := range r.sessions {
items = append(items, sessionAge{sessionID: sessionID, createdAt: meta.createdAt})
}
sort.Slice(items, func(i, j int) bool {
return items[i].createdAt < items[j].createdAt
})
for len(r.sessions) > arthasTunnelMaxSessions && len(items) > 0 {
victim := items[0].sessionID
items = items[1:]
if _, active := r.active[victim]; active {
continue
}
delete(r.sessions, victim)
}
}
func (m arthasTunnelSessionMeta) matchesConfig(cfg connection.ConnectionConfig) bool {
return strings.TrimSpace(m.targetID) == strings.TrimSpace(cfg.JVM.Diagnostic.TargetID) &&
strings.TrimSpace(m.baseURL) == strings.TrimSpace(cfg.JVM.Diagnostic.BaseURL)
}
func (r *arthasTunnelSessionRegistry) cancelCommand(sessionID string, commandID string) error {
r.mu.Lock()
activeCommand := r.active[sessionID]
@@ -564,8 +621,12 @@ func (c *arthasTunnelActiveCommand) send(frame arthasTunnelTTYFrame) error {
func (c *arthasTunnelActiveCommand) requestCancel() error {
c.mu.Lock()
c.cancelRequested = true
conn := c.conn
c.mu.Unlock()
if conn == nil {
return nil
}
return c.send(arthasTunnelTTYFrame{
Action: "read",
Data: arthasTunnelInterruptInput,

View File

@@ -40,7 +40,7 @@ func newFakeArthasTunnelServer(
t.Helper()
fake := &fakeArthasTunnelServer{
t: t,
t: t,
upgrader: websocket.Upgrader{
CheckOrigin: func(*http.Request) bool { return true },
},
@@ -317,6 +317,78 @@ func TestDiagnosticArthasTunnelCancelCommandInterruptsActiveCommand(t *testing.T
}
}
func TestArthasTunnelActiveCommandAcceptsCancelBeforeConnectionAttach(t *testing.T) {
activeCommand := &arthasTunnelActiveCommand{commandID: "cmd-before-attach"}
if err := activeCommand.requestCancel(); err != nil {
t.Fatalf("expected pre-attach cancel request to be recorded without error, got %v", err)
}
if !activeCommand.isCancelRequested() {
t.Fatal("expected cancelRequested flag to be recorded")
}
}
func TestDiagnosticArthasTunnelRejectsSessionConfigDrift(t *testing.T) {
server := newFakeArthasTunnelServer(t, func(conn *websocket.Conn, frame fakeArthasTTYFrame) {
if frame.Action == "read" && strings.Contains(frame.Data, "thread -n 5") {
_ = conn.WriteMessage(websocket.TextMessage, []byte("thread top 5\r\n[arthas@12345]$ "))
}
})
defer server.close()
transport, err := NewDiagnosticTransport(DiagnosticTransportArthasTunnel)
if err != nil {
t.Fatalf("NewDiagnosticTransport returned error: %v", err)
}
tunnel := transport.(*DiagnosticArthasTunnelTransport)
cfg := testArthasTunnelConfig(server.wsURL())
handle, err := tunnel.StartSession(context.Background(), cfg, DiagnosticSessionRequest{})
if err != nil {
t.Fatalf("StartSession returned error: %v", err)
}
driftedCfg := cfg
driftedCfg.JVM.Diagnostic.TargetID = "orders-prod-02"
err = tunnel.ExecuteCommand(context.Background(), driftedCfg, DiagnosticCommandRequest{
SessionID: handle.SessionID,
CommandID: "cmd-drift",
Command: "thread -n 5",
})
if err == nil {
t.Fatal("expected config drift to reject stale Arthas Tunnel session")
}
if !strings.Contains(err.Error(), "会话配置已变化") {
t.Fatalf("expected config drift error, got %v", err)
}
}
func TestArthasTunnelSessionRegistryPrunesExpiredSessions(t *testing.T) {
registry := newArthasTunnelSessionRegistry()
cfg := testArthasTunnelConfig("http://127.0.0.1:7777")
oldHandle := registry.createSession(cfg)
registry.mu.Lock()
oldMeta := registry.sessions[oldHandle.SessionID]
oldMeta.createdAt = time.Now().Add(-24 * time.Hour).UnixMilli()
registry.sessions[oldHandle.SessionID] = oldMeta
registry.mu.Unlock()
registry.createSession(cfg)
registry.mu.Lock()
_, oldExists := registry.sessions[oldHandle.SessionID]
sessionCount := len(registry.sessions)
registry.mu.Unlock()
if oldExists {
t.Fatalf("expected expired session %s to be pruned", oldHandle.SessionID)
}
if sessionCount != 1 {
t.Fatalf("expected only fresh session to remain, got %d", sessionCount)
}
}
func TestDiagnosticArthasTunnelRequiresTargetID(t *testing.T) {
transport, err := NewDiagnosticTransport(DiagnosticTransportArthasTunnel)
if err != nil {

View File

@@ -17,19 +17,18 @@ func BuildChangePreview(
cfg connection.ConnectionConfig,
req ChangeRequest,
) (ChangePreview, error) {
req, err := NormalizeChangeRequest(req)
if err != nil {
return ChangePreview{}, err
}
normalized, err := NormalizeConnectionConfig(cfg)
if err != nil {
return ChangePreview{}, err
}
resourceID := strings.TrimSpace(req.ResourceID)
if resourceID == "" {
return ChangePreview{}, fmt.Errorf("resource id is required")
}
action := strings.TrimSpace(req.Action)
if action == "" {
return ChangePreview{}, fmt.Errorf("action is required")
}
resourceID := req.ResourceID
action := req.Action
before := ValueSnapshot{
ResourceID: resourceID,
@@ -111,6 +110,28 @@ func BuildChangePreview(
return preview, nil
}
func NormalizeChangeRequest(req ChangeRequest) (ChangeRequest, error) {
normalized := req
normalized.ProviderMode = strings.ToLower(strings.TrimSpace(normalized.ProviderMode))
normalized.ResourceID = strings.TrimSpace(normalized.ResourceID)
normalized.Action = strings.TrimSpace(normalized.Action)
normalized.Reason = strings.TrimSpace(normalized.Reason)
normalized.Source = strings.TrimSpace(normalized.Source)
normalized.ExpectedVersion = strings.TrimSpace(normalized.ExpectedVersion)
if normalized.ResourceID == "" {
return ChangeRequest{}, fmt.Errorf("resource id is required")
}
if normalized.Action == "" {
return ChangeRequest{}, fmt.Errorf("action is required")
}
if normalized.Reason == "" {
return ChangeRequest{}, fmt.Errorf("reason is required")
}
return normalized, nil
}
func hasSnapshotOverride(snapshot ValueSnapshot) bool {
return strings.TrimSpace(snapshot.ResourceID) != "" ||
strings.TrimSpace(snapshot.Kind) != "" ||

View File

@@ -76,6 +76,32 @@ func TestPreviewChangeBlocksReadOnlyConnection(t *testing.T) {
}
}
func TestPreviewChangeRejectsMissingReason(t *testing.T) {
readOnly := false
_, err := BuildChangePreview(context.Background(), fakeGuardProvider{}, connection.ConnectionConfig{
Type: "jvm",
ID: "conn-writable",
Host: "orders.internal",
JVM: connection.JVMConfig{
ReadOnly: &readOnly,
PreferredMode: ModeJMX,
AllowedModes: []string{ModeJMX},
},
}, ChangeRequest{
ProviderMode: ModeJMX,
ResourceID: "/cache/orders",
Action: "put",
Reason: " ",
Payload: map[string]any{
"status": "ready",
},
})
if err == nil || !strings.Contains(err.Error(), "reason is required") {
t.Fatalf("expected missing reason to be rejected, got %v", err)
}
}
func TestPreviewChangeReturnsProviderPreviewErrorWhenWriteAllowed(t *testing.T) {
readOnly := false

View File

@@ -53,6 +53,8 @@ func normalizeContractBaseURL(rawBaseURL string, errorPrefix string) (*url.URL,
if parsed.Scheme != "http" && parsed.Scheme != "https" {
return nil, fmt.Errorf("%s scheme is unsupported: %s", errorPrefix, parsed.Scheme)
}
parsed.RawQuery = ""
parsed.Fragment = ""
return parsed, nil
}

View File

@@ -2,6 +2,8 @@ package jvm
import (
"context"
"net/http"
"net/http/httptest"
"strings"
"testing"
@@ -89,6 +91,32 @@ func TestHTTPProviderTestConnectionReturnsErrorWhenBaseURLInvalid(t *testing.T)
}
}
func TestHTTPProviderProbeStripsBaseURLQueryAndFragment(t *testing.T) {
provider := NewHTTPProvider()
seen := make(chan string, 1)
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
seen <- r.URL.RequestURI()
w.WriteHeader(http.StatusNoContent)
}))
defer server.Close()
err := provider.TestConnection(context.Background(), connection.ConnectionConfig{
Type: "jvm",
JVM: connection.JVMConfig{
Endpoint: connection.JVMEndpointConfig{
BaseURL: server.URL + "/gonavi/jvm?api_key=secret-token#debug",
},
},
})
if err != nil {
t.Fatalf("expected probe to succeed, got %v", err)
}
if got := <-seen; got != "/gonavi/jvm" {
t.Fatalf("expected query and fragment to be stripped, got %q", got)
}
}
func TestAgentProviderTestConnectionReturnsErrorWhenBaseURLMissing(t *testing.T) {
provider := NewAgentProvider()