From acfa112415b1f64ff9d82cb871ef52f6118249ea Mon Sep 17 00:00:00 2001 From: Syngnat Date: Tue, 9 Jun 2026 16:45:39 +0800 Subject: [PATCH] =?UTF-8?q?=F0=9F=90=9B=20fix(ai-chat):=20=E4=BF=AE?= =?UTF-8?q?=E5=A4=8D=E6=B5=81=E5=BC=8F=E5=9B=9E=E5=A4=8D=E5=88=86=E8=A3=82?= =?UTF-8?q?=E4=B8=BA=E5=A4=9A=E4=B8=AA=E6=B0=94=E6=B3=A1?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - 持久化流式回复状态,避免重渲染后丢失当前 assistant 消息 - 补充回归测试覆盖 chunk 追加到同一气泡 --- .../AIChatPanel.message-boundary.test.tsx | 7 +- .../ai/useAIChatStreamSubscription.test.tsx | 167 ++++++++++++++++++ .../ai/useAIChatStreamSubscription.ts | 91 ++++++---- 3 files changed, 232 insertions(+), 33 deletions(-) create mode 100644 frontend/src/components/ai/useAIChatStreamSubscription.test.tsx diff --git a/frontend/src/components/AIChatPanel.message-boundary.test.tsx b/frontend/src/components/AIChatPanel.message-boundary.test.tsx index a86b554..008baa7 100644 --- a/frontend/src/components/AIChatPanel.message-boundary.test.tsx +++ b/frontend/src/components/AIChatPanel.message-boundary.test.tsx @@ -12,6 +12,7 @@ const resizeSource = readFileSync(new URL('./ai/useAIChatPanelResize.ts', import const runtimeResourcesSource = readFileSync(new URL('./ai/useAIChatRuntimeResources.ts', import.meta.url), 'utf8'); const sessionStateSource = readFileSync(new URL('./ai/useAIChatSessionState.ts', import.meta.url), 'utf8'); const streamSubscriptionSource = readFileSync(new URL('./ai/useAIChatStreamSubscription.ts', import.meta.url), 'utf8'); +const inspectionGuidanceSource = readFileSync(new URL('./ai/aiSystemInspectionGuidance.ts', import.meta.url), 'utf8'); const systemContextSource = readFileSync(new URL('./ai/aiSystemContextMessages.ts', import.meta.url), 'utf8'); const runtimeSource = readFileSync(new URL('../utils/aiChatRuntime.ts', import.meta.url), 'utf8'); @@ -48,9 +49,9 @@ describe('AIChatPanel message render isolation', () => { expect(systemContextSource).toContain('get_indexes、get_foreign_keys、get_triggers、get_table_ddl'); expect(systemContextSource).toContain('inspect_active_tab 读取当前活动页签上下文'); expect(systemContextSource).toContain('inspect_workspace_tabs 盘点当前工作区'); - expect(systemContextSource).toContain('inspect_current_connection'); - expect(systemContextSource).toContain('inspect_external_sql_directories'); - expect(systemContextSource).toContain('inspect_external_sql_file'); + expect(inspectionGuidanceSource).toContain('inspect_current_connection'); + expect(inspectionGuidanceSource).toContain('inspect_external_sql_directories'); + expect(inspectionGuidanceSource).toContain('inspect_external_sql_file'); expect(source).toContain('tabs: useStore.getState().tabs'); expect(source).toContain('activeTabId: useStore.getState().activeTabId'); expect(source).toContain('externalSQLDirectories: useStore.getState().externalSQLDirectories'); diff --git a/frontend/src/components/ai/useAIChatStreamSubscription.test.tsx b/frontend/src/components/ai/useAIChatStreamSubscription.test.tsx new file mode 100644 index 0000000..a02a0e0 --- /dev/null +++ b/frontend/src/components/ai/useAIChatStreamSubscription.test.tsx @@ -0,0 +1,167 @@ +import React, { useRef, useState } from 'react'; +import { afterEach, beforeEach, describe, expect, it, vi } from 'vitest'; +import { act, create, type ReactTestRenderer } from 'react-test-renderer'; + +import { useStore } from '../../store'; +import { useAIChatStreamSubscription } from './useAIChatStreamSubscription'; + +const runtimeMock = vi.hoisted(() => { + const handlers = new Map void>(); + return { + handlers, + EventsOn: vi.fn((eventName: string, handler: (data: any) => void) => { + handlers.set(eventName, handler); + }), + EventsOff: vi.fn((eventName: string) => { + handlers.delete(eventName); + }), + }; +}); + +vi.mock('../../../wailsjs/runtime', () => ({ + EventsOn: runtimeMock.EventsOn, + EventsOff: runtimeMock.EventsOff, +})); + +const SESSION_ID = 'session-stream'; +let nextId = 0; + +const emitStreamChunk = async (data: any) => { + const handler = runtimeMock.handlers.get(`ai:stream:${SESSION_ID}`); + expect(handler).toBeTypeOf('function'); + await act(async () => { + handler?.(data); + await Promise.resolve(); + }); +}; + +const appendMessage = ( + sessionId: string, + message: Parameters['addAIChatMessage']>[1], +) => { + useStore.setState((state) => { + const messages = state.aiChatHistory[sessionId] || []; + return { + aiChatHistory: { + ...state.aiChatHistory, + [sessionId]: [...messages, message], + }, + }; + }); +}; + +const patchMessage = ( + sessionId: string, + messageId: string, + patch: Parameters['updateAIChatMessage']>[2], +) => { + useStore.setState((state) => { + const messages = state.aiChatHistory[sessionId]; + if (!messages) { + return state; + } + return { + aiChatHistory: { + ...state.aiChatHistory, + [sessionId]: messages.map((message) => + message.id === messageId ? { ...message, ...patch } : message, + ), + }, + }; + }); +}; + +const StreamHarness = () => { + const [sending, setSending] = useState(true); + const nudgeCountRef = useRef(0); + const pendingJVMPlanContextRef = useRef(undefined); + const pendingJVMDiagnosticPlanContextRef = useRef(undefined); + + useAIChatStreamSubscription({ + sid: SESSION_ID, + sending, + setSending, + availableTools: [], + addAIChatMessage: appendMessage, + updateAIChatMessage: patchMessage, + buildSystemContextMessages: async () => [], + executeLocalTools: async () => {}, + generateTitleForSession: async () => {}, + nextMessageId: () => `assistant-created-${++nextId}`, + nudgeCountRef, + pendingJVMPlanContextRef, + pendingJVMDiagnosticPlanContextRef, + }); + + return null; +}; + +describe('useAIChatStreamSubscription', () => { + beforeEach(() => { + nextId = 0; + runtimeMock.handlers.clear(); + runtimeMock.EventsOn.mockClear(); + runtimeMock.EventsOff.mockClear(); + vi.stubGlobal('requestAnimationFrame', (callback: FrameRequestCallback) => { + callback(0); + return 1; + }); + useStore.setState({ + aiChatHistory: { + [SESSION_ID]: [ + { + id: 'user-1', + role: 'user', + content: 'hello', + timestamp: 1, + }, + { + id: 'assistant-connecting', + role: 'assistant', + phase: 'connecting', + content: '', + timestamp: 2, + loading: true, + }, + ], + }, + aiChatSessions: [{ id: SESSION_ID, title: 'hello', updatedAt: 1 }], + aiActiveSessionId: SESSION_ID, + }); + }); + + afterEach(() => { + vi.unstubAllGlobals(); + useStore.setState({ + aiChatHistory: {}, + aiChatSessions: [], + aiActiveSessionId: null, + }); + }); + + it('keeps streamed chunks in the same assistant message after a parent rerender', async () => { + let renderer: ReactTestRenderer | undefined; + + await act(async () => { + renderer = create(); + }); + + await emitStreamChunk({ content: 'Hello' }); + await emitStreamChunk({ content: ' world' }); + + const messages = useStore.getState().aiChatHistory[SESSION_ID] || []; + const assistantMessages = messages.filter((message) => message.role === 'assistant'); + + expect(assistantMessages).toHaveLength(1); + expect(assistantMessages[0]).toMatchObject({ + id: 'assistant-connecting', + phase: 'generating', + content: 'Hello world', + loading: true, + }); + + await act(async () => { + renderer?.unmount(); + }); + }); +}); diff --git a/frontend/src/components/ai/useAIChatStreamSubscription.ts b/frontend/src/components/ai/useAIChatStreamSubscription.ts index 99f6d28..5c5d90c 100644 --- a/frontend/src/components/ai/useAIChatStreamSubscription.ts +++ b/frontend/src/components/ai/useAIChatStreamSubscription.ts @@ -45,6 +45,35 @@ interface UseAIChatStreamSubscriptionOptions { pendingJVMDiagnosticPlanContextRef: MutableRefObject; } +interface AIChatStreamState { + sid: string; + assistantMsgId: string; + isFirstCompletion: boolean; + streamBuffer: { + thinking: string; + reasoningContent: string; + content: string; + }; + flushPending: boolean; +} + +const createAIChatStreamState = (sid: string): AIChatStreamState => ({ + sid, + assistantMsgId: '', + isFirstCompletion: false, + streamBuffer: { thinking: '', reasoningContent: '', content: '' }, + flushPending: false, +}); + +const resetAIChatStreamProgress = (state: AIChatStreamState) => { + state.assistantMsgId = ''; + state.isFirstCompletion = false; + state.streamBuffer.thinking = ''; + state.streamBuffer.reasoningContent = ''; + state.streamBuffer.content = ''; + state.flushPending = false; +}; + export const useAIChatStreamSubscription = ({ sid, sending, @@ -61,6 +90,7 @@ export const useAIChatStreamSubscription = ({ pendingJVMDiagnosticPlanContextRef, }: UseAIChatStreamSubscriptionOptions) => { const sendingRef = useRef(sending); + const streamStateRef = useRef(createAIChatStreamState(sid)); useEffect(() => { sendingRef.current = sending; @@ -68,17 +98,19 @@ export const useAIChatStreamSubscription = ({ useEffect(() => { const eventName = `ai:stream:${sid}`; - let assistantMsgId = ''; - let isFirstCompletion = false; + if (streamStateRef.current.sid !== sid) { + streamStateRef.current = createAIChatStreamState(sid); + } + const streamState = streamStateRef.current; // 缓冲高频 token,避免把流式吞吐直接转成同步重绘风暴 - const streamBuffer = { thinking: '', reasoningContent: '', content: '' }; - let flushPending = false; + const streamBuffer = streamState.streamBuffer; const flushStreamBuffer = () => { - if (!assistantMsgId) return; + streamState.flushPending = false; + if (!streamState.assistantMsgId) return; const current = useStore.getState().aiChatHistory[sid]; - const existing = current?.find((message) => message.id === assistantMsgId); + const existing = current?.find((message) => message.id === streamState.assistantMsgId); if (!existing) return; const updates: Partial = {}; @@ -98,26 +130,25 @@ export const useAIChatStreamSubscription = ({ } if (Object.keys(updates).length > 0) { - updateAIChatMessage(sid, assistantMsgId, updates); + updateAIChatMessage(sid, streamState.assistantMsgId, updates); } - flushPending = false; }; const handler = (data: AIChatStreamChunk) => { - if (!assistantMsgId) { + if (!streamState.assistantMsgId) { const history = useStore.getState().aiChatHistory[sid] || []; const lastMsg = history[history.length - 1]; if (lastMsg && lastMsg.role === 'assistant' && lastMsg.loading && lastMsg.phase === 'connecting') { - assistantMsgId = lastMsg.id; - updateAIChatMessage(sid, assistantMsgId, { content: '' }); + streamState.assistantMsgId = lastMsg.id; + updateAIChatMessage(sid, streamState.assistantMsgId, { content: '' }); } } if (data.error) { const cleanErr = sanitizeErrorMsg(data.error); const rawErr = cleanErr !== data.error ? data.error : undefined; - if (assistantMsgId) { - updateAIChatMessage(sid, assistantMsgId, { + if (streamState.assistantMsgId) { + updateAIChatMessage(sid, streamState.assistantMsgId, { content: `❌ 错误: ${cleanErr}`, phase: 'idle', loading: false, @@ -135,18 +166,18 @@ export const useAIChatStreamSubscription = ({ jvmDiagnosticPlanContext: pendingJVMDiagnosticPlanContextRef.current, }); } - assistantMsgId = ''; + resetAIChatStreamProgress(streamState); setSending(false); return; } if (data.tool_calls && data.tool_calls.length > 0) { - if (assistantMsgId) { - updateAIChatMessage(sid, assistantMsgId, { tool_calls: data.tool_calls, phase: 'tool_calling' }); + if (streamState.assistantMsgId) { + updateAIChatMessage(sid, streamState.assistantMsgId, { tool_calls: data.tool_calls, phase: 'tool_calling' }); } else { - assistantMsgId = nextMessageId(); + streamState.assistantMsgId = nextMessageId(); addAIChatMessage(sid, { - id: assistantMsgId, + id: streamState.assistantMsgId, role: 'assistant', phase: 'tool_calling', content: '', @@ -161,10 +192,10 @@ export const useAIChatStreamSubscription = ({ const displayThinking = data.thinking || data.reasoning_content || ''; if (displayThinking || data.reasoning_content) { - if (!assistantMsgId) { - assistantMsgId = nextMessageId(); + if (!streamState.assistantMsgId) { + streamState.assistantMsgId = nextMessageId(); addAIChatMessage(sid, { - id: assistantMsgId, + id: streamState.assistantMsgId, role: 'assistant', phase: 'thinking', content: '', @@ -186,10 +217,10 @@ export const useAIChatStreamSubscription = ({ } if (data.content) { - if (!assistantMsgId) { - assistantMsgId = nextMessageId(); + if (!streamState.assistantMsgId) { + streamState.assistantMsgId = nextMessageId(); addAIChatMessage(sid, { - id: assistantMsgId, + id: streamState.assistantMsgId, role: 'assistant', phase: 'generating', content: data.content, @@ -200,7 +231,7 @@ export const useAIChatStreamSubscription = ({ }); setSending(false); const currentHistory = useStore.getState().aiChatHistory[sid] || []; - if (currentHistory.length <= 1) isFirstCompletion = true; + if (currentHistory.length <= 1) streamState.isFirstCompletion = true; } else { streamBuffer.content += data.content; if (sendingRef.current) setSending(false); @@ -208,8 +239,8 @@ export const useAIChatStreamSubscription = ({ } if (streamBuffer.thinking || streamBuffer.reasoningContent || streamBuffer.content) { - if (!flushPending) { - flushPending = true; + if (!streamState.flushPending) { + streamState.flushPending = true; requestAnimationFrame(flushStreamBuffer); } } @@ -218,9 +249,9 @@ export const useAIChatStreamSubscription = ({ if (streamBuffer.thinking || streamBuffer.reasoningContent || streamBuffer.content) { flushStreamBuffer(); } - const doneAssistantId = assistantMsgId; - const doneIsFirst = isFirstCompletion; - assistantMsgId = ''; + const doneAssistantId = streamState.assistantMsgId; + const doneIsFirst = streamState.isFirstCompletion; + resetAIChatStreamProgress(streamState); setTimeout(() => { const currentMsgs = useStore.getState().aiChatHistory[sid] || []; for (const msg of currentMsgs) {