diff --git a/frontend/src/components/ai/aiChatPayloadDispatch.test.ts b/frontend/src/components/ai/aiChatPayloadDispatch.test.ts index 4109bd6..7b506ac 100644 --- a/frontend/src/components/ai/aiChatPayloadDispatch.test.ts +++ b/frontend/src/components/ai/aiChatPayloadDispatch.test.ts @@ -38,8 +38,8 @@ describe('aiChatPayloadDispatch', () => { expect(setSending).not.toHaveBeenCalled(); }); - it('appends a non-stream assistant message when only AIChatSend is available', async () => { - const AIChatSend = vi.fn().mockResolvedValue({ + it('appends a non-stream assistant message when session-aware send is available', async () => { + const AIChatSendInSession = vi.fn().mockResolvedValue({ success: true, content: 'done', reasoning_content: 'thinking', @@ -51,7 +51,7 @@ describe('aiChatPayloadDispatch', () => { (globalThis as any).window = { go: { aiservice: { - Service: { AIChatSend }, + Service: { AIChatSendInSession }, }, }, }; @@ -67,6 +67,7 @@ describe('aiChatPayloadDispatch', () => { }); expect(result).toBe('send'); + expect(AIChatSendInSession).toHaveBeenCalledWith('session-1', [{ role: 'user', content: 'hello' }], []); expect(addAIChatMessage).toHaveBeenCalledWith('session-1', expect.objectContaining({ id: 'msg-send', role: 'assistant', @@ -79,7 +80,7 @@ describe('aiChatPayloadDispatch', () => { }); it('settles the pending assistant message when falling back to non-stream send', async () => { - const AIChatSend = vi.fn().mockResolvedValue({ + const AIChatSendInSession = vi.fn().mockResolvedValue({ success: true, content: 'done', reasoning_content: 'thinking', @@ -91,7 +92,7 @@ describe('aiChatPayloadDispatch', () => { (globalThis as any).window = { go: { aiservice: { - Service: { AIChatSend }, + Service: { AIChatSendInSession }, }, }, }; @@ -108,6 +109,7 @@ describe('aiChatPayloadDispatch', () => { }); expect(result).toBe('send'); + expect(AIChatSendInSession).toHaveBeenCalledWith('session-1', [{ role: 'user', content: 'hello' }], []); expect(addAIChatMessage).not.toHaveBeenCalled(); expect(updateAIChatMessage).toHaveBeenCalledWith('session-1', 'assistant-connecting', expect.objectContaining({ content: 'done', @@ -119,6 +121,44 @@ describe('aiChatPayloadDispatch', () => { expect(setSending).toHaveBeenCalledWith(false); }); + it('falls back to stateless AIChatSend when session-aware send is unavailable', async () => { + const AIChatSend = vi.fn().mockResolvedValue({ + success: true, + content: 'done', + reasoning_content: 'thinking', + }); + const addAIChatMessage = vi.fn(); + const setSending = vi.fn(); + + (globalThis as any).window = { + go: { + aiservice: { + Service: { AIChatSend }, + }, + }, + }; + + const result = await dispatchAIChatPayload({ + sid: 'session-1', + messages: [{ role: 'user', content: 'hello' }], + tools: [], + addAIChatMessage, + setSending, + nextMessageId: () => 'msg-send', + }); + + expect(result).toBe('send'); + expect(AIChatSend).toHaveBeenCalledWith([{ role: 'user', content: 'hello' }], []); + expect(addAIChatMessage).toHaveBeenCalledWith('session-1', expect.objectContaining({ + id: 'msg-send', + role: 'assistant', + content: 'done', + thinking: 'thinking', + reasoning_content: 'thinking', + })); + expect(setSending).toHaveBeenCalledWith(false); + }); + it('emits the unavailable message when the AI service is missing', async () => { const addAIChatMessage = vi.fn(); const setSending = vi.fn(); diff --git a/frontend/src/components/ai/aiChatPayloadDispatch.ts b/frontend/src/components/ai/aiChatPayloadDispatch.ts index 14660be..40f6333 100644 --- a/frontend/src/components/ai/aiChatPayloadDispatch.ts +++ b/frontend/src/components/ai/aiChatPayloadDispatch.ts @@ -8,6 +8,7 @@ import { sanitizeErrorMsg } from '../../utils/aiChatRuntime'; interface AIChatService { AIChatStream?: (sid: string, messages: any[], tools: AIChatToolDefinition[]) => Promise; + AIChatSendInSession?: (sid: string, messages: any[], tools: AIChatToolDefinition[]) => Promise; AIChatSend?: (messages: any[], tools: AIChatToolDefinition[]) => Promise; } @@ -92,8 +93,10 @@ export const dispatchAIChatPayload = async ({ return 'stream'; } - if (service?.AIChatSend) { - const result = await service.AIChatSend(messages, tools); + if (service?.AIChatSendInSession || service?.AIChatSend) { + const result = service?.AIChatSendInSession + ? await service.AIChatSendInSession(sid, messages, tools) + : await service!.AIChatSend!(messages, tools); const rawError = result?.error || '未知错误'; const cleanError = sanitizeErrorMsg(rawError); diff --git a/frontend/wailsjs/go/aiservice/Service.d.ts b/frontend/wailsjs/go/aiservice/Service.d.ts index 9f85ebc..847aa0e 100755 --- a/frontend/wailsjs/go/aiservice/Service.d.ts +++ b/frontend/wailsjs/go/aiservice/Service.d.ts @@ -8,6 +8,8 @@ export function AIChatCancel(arg1:string):Promise; export function AIChatSend(arg1:Array,arg2:Array):Promise>; +export function AIChatSendInSession(arg1:string,arg2:Array,arg3:Array):Promise>; + export function AIChatStream(arg1:string,arg2:Array,arg3:Array):Promise; export function AICheckSQL(arg1:string):Promise; diff --git a/frontend/wailsjs/go/aiservice/Service.js b/frontend/wailsjs/go/aiservice/Service.js index 0b44f60..5a7f078 100755 --- a/frontend/wailsjs/go/aiservice/Service.js +++ b/frontend/wailsjs/go/aiservice/Service.js @@ -14,6 +14,10 @@ export function AIChatSend(arg1, arg2) { return window['go']['aiservice']['Service']['AIChatSend'](arg1, arg2); } +export function AIChatSendInSession(arg1, arg2, arg3) { + return window['go']['aiservice']['Service']['AIChatSendInSession'](arg1, arg2, arg3); +} + export function AIChatStream(arg1, arg2, arg3) { return window['go']['aiservice']['Service']['AIChatStream'](arg1, arg2, arg3); } diff --git a/internal/ai/provider/codebuddy_cli.go b/internal/ai/provider/codebuddy_cli.go index 84714c8..4bfb169 100644 --- a/internal/ai/provider/codebuddy_cli.go +++ b/internal/ai/provider/codebuddy_cli.go @@ -24,6 +24,10 @@ type CodeBuddyCLIProvider struct { config ai.ProviderConfig } +type codebuddySessionState struct { + SessionID string `json:"sessionId,omitempty"` +} + // NewCodeBuddyCLIProvider 创建 CodeBuddyCLIProvider 实例。 func NewCodeBuddyCLIProvider(config ai.ProviderConfig) (Provider, error) { return &CodeBuddyCLIProvider{config: config}, nil @@ -42,8 +46,13 @@ func (p *CodeBuddyCLIProvider) Validate() error { } func (p *CodeBuddyCLIProvider) Chat(ctx context.Context, req ai.ChatRequest) (*ai.ChatResponse, error) { + resp, _, err := p.ChatWithState(ctx, nil, req) + return resp, err +} + +func (p *CodeBuddyCLIProvider) ChatWithState(ctx context.Context, state json.RawMessage, req ai.ChatRequest) (*ai.ChatResponse, json.RawMessage, error) { if err := p.Validate(); err != nil { - return nil, err + return nil, nil, err } ctx, cancel := ensureClaudeCLITimeout(ctx, codebuddyCLIRequestTimeout) @@ -51,18 +60,25 @@ func (p *CodeBuddyCLIProvider) Chat(ctx context.Context, req ai.ChatRequest) (*a commandName, err := resolveCodeBuddyCLICommand(codebuddyLookPath) if err != nil { - return nil, err + return nil, nil, err } + sessionState, err := parseCodeBuddySessionState(state) + if err != nil { + return nil, nil, err + } prompt := buildPrompt(req.Messages) - args := []string{"-p", prompt, "--output-format", "json", "--no-session-persistence"} + args := []string{"-p", prompt, "--output-format", "json", "--enable-session-tracking"} if strings.TrimSpace(p.config.Model) != "" { args = append(args, "--model", strings.TrimSpace(p.config.Model)) } + if strings.TrimSpace(sessionState.SessionID) != "" { + args = append(args, "--resume", strings.TrimSpace(sessionState.SessionID)) + } cmd := codebuddyCommandContext(ctx, commandName, args...) if err := p.setEnv(cmd); err != nil { - return nil, err + return nil, nil, err } requestLog := logAIUpstreamRequestStart( @@ -80,27 +96,55 @@ func (p *CodeBuddyCLIProvider) Chat(ctx context.Context, req ai.ChatRequest) (*a if err != nil { if isClaudeCLITimeout(ctx, err) { requestErr = fmt.Errorf("CodeBuddy CLI 执行超时(%s),当前登录态、Base URL 或 API Key 可能没有返回有效响应", codebuddyCLIRequestTimeout) - return nil, requestErr + return nil, nil, requestErr } if exitErr, ok := err.(*exec.ExitError); ok { requestErr = fmt.Errorf("CodeBuddy CLI 执行失败: %s", string(exitErr.Stderr)) - return nil, requestErr + return nil, nil, requestErr } requestErr = fmt.Errorf("CodeBuddy CLI 执行失败: %w", err) - return nil, requestErr + return nil, nil, requestErr } - resp, parseErr := parseCodeBuddyCLIChatOutput(output) + resp, nextSessionID, parseErr := parseCodeBuddyCLIChatOutput(output) if parseErr != nil { requestErr = parseErr - return nil, requestErr + return nil, nil, requestErr } - return resp, nil + if strings.TrimSpace(nextSessionID) == "" { + nextSessionID = strings.TrimSpace(sessionState.SessionID) + } + nextState, err := marshalCodeBuddySessionState(nextSessionID) + if err != nil { + return nil, nil, err + } + return resp, nextState, nil } func (p *CodeBuddyCLIProvider) ChatStream(ctx context.Context, req ai.ChatRequest, callback func(ai.StreamChunk)) error { + _, err := p.ChatStreamWithState(ctx, nil, req, callback) + return err +} + +func (p *CodeBuddyCLIProvider) ChatStreamWithState(ctx context.Context, state json.RawMessage, req ai.ChatRequest, callback func(ai.StreamChunk)) (json.RawMessage, error) { + sessionState, err := parseCodeBuddySessionState(state) + if err != nil { + return nil, err + } + + sessionID, err := p.chatStreamWithSession(ctx, strings.TrimSpace(sessionState.SessionID), req, callback) + if err != nil { + return nil, err + } + if strings.TrimSpace(sessionID) == "" { + sessionID = strings.TrimSpace(sessionState.SessionID) + } + return marshalCodeBuddySessionState(sessionID) +} + +func (p *CodeBuddyCLIProvider) chatStreamWithSession(ctx context.Context, resumeSessionID string, req ai.ChatRequest, callback func(ai.StreamChunk)) (string, error) { if err := p.Validate(); err != nil { - return err + return "", err } ctx, cancel := ensureClaudeCLITimeout(ctx, codebuddyCLIRequestTimeout) @@ -108,18 +152,21 @@ func (p *CodeBuddyCLIProvider) ChatStream(ctx context.Context, req ai.ChatReques commandName, err := resolveCodeBuddyCLICommand(codebuddyLookPath) if err != nil { - return err + return "", err } prompt := buildPrompt(req.Messages) - args := []string{"-p", prompt, "--output-format", "stream-json", "--verbose", "--include-partial-messages", "--no-session-persistence"} + args := []string{"-p", prompt, "--output-format", "stream-json", "--verbose", "--include-partial-messages", "--enable-session-tracking"} if strings.TrimSpace(p.config.Model) != "" { args = append(args, "--model", strings.TrimSpace(p.config.Model)) } + if strings.TrimSpace(resumeSessionID) != "" { + args = append(args, "--resume", strings.TrimSpace(resumeSessionID)) + } cmd := codebuddyCommandContext(ctx, commandName, args...) if err := p.setEnv(cmd); err != nil { - return err + return "", err } requestLog := logAIUpstreamRequestStart( @@ -138,7 +185,7 @@ func (p *CodeBuddyCLIProvider) ChatStream(ctx context.Context, req ai.ChatReques stdout, err := cmd.StdoutPipe() if err != nil { requestErr = fmt.Errorf("创建 stdout 管道失败: %w", err) - return requestErr + return "", requestErr } var stderrBuf bytes.Buffer @@ -146,7 +193,7 @@ func (p *CodeBuddyCLIProvider) ChatStream(ctx context.Context, req ai.ChatReques if err := cmd.Start(); err != nil { requestErr = fmt.Errorf("启动 CodeBuddy CLI 失败: %w", err) - return requestErr + return "", requestErr } if cmd.Process != nil { @@ -155,6 +202,7 @@ func (p *CodeBuddyCLIProvider) ChatStream(ctx context.Context, req ai.ChatReques scanner := bufio.NewScanner(stdout) scanner.Buffer(make([]byte, 64*1024), 1024*1024) + currentSessionID := strings.TrimSpace(resumeSessionID) for scanner.Scan() { line := scanner.Text() @@ -167,6 +215,9 @@ func (p *CodeBuddyCLIProvider) ChatStream(ctx context.Context, req ai.ChatReques logger.Warnf("CodeBuddyCLI 忽略非 JSON 输出:requestId=%s line=%s", requestLog.id, RedactAIUpstreamLogText(line)) continue } + if strings.TrimSpace(event.SessionID) != "" { + currentSessionID = strings.TrimSpace(event.SessionID) + } switch event.Type { case "system": @@ -178,7 +229,7 @@ func (p *CodeBuddyCLIProvider) ChatStream(ctx context.Context, req ai.ChatReques _ = cmd.Process.Kill() } _ = cmd.Wait() - return nil + return "", nil } } case "assistant": @@ -186,7 +237,7 @@ func (p *CodeBuddyCLIProvider) ChatStream(ctx context.Context, req ai.ChatReques callback(ai.StreamChunk{Error: errMsg, Done: true}) requestErr = fmt.Errorf("CodeBuddy CLI 返回错误: %s", errMsg) _ = cmd.Wait() - return nil + return "", nil } if event.Message.Content != nil { for _, block := range event.Message.Content { @@ -208,17 +259,17 @@ func (p *CodeBuddyCLIProvider) ChatStream(ctx context.Context, req ai.ChatReques callback(ai.StreamChunk{Error: errMsg, Done: true}) requestErr = fmt.Errorf("CodeBuddy CLI 返回错误: %s", errMsg) _ = cmd.Wait() - return nil + return "", nil } callback(ai.StreamChunk{Done: true}) _ = cmd.Wait() - return nil + return currentSessionID, nil case "error": errMsg, _ := extractCodeBuddyCLIEventError(event) callback(ai.StreamChunk{Error: errMsg, Done: true}) requestErr = fmt.Errorf("CodeBuddy CLI 返回错误: %s", errMsg) _ = cmd.Wait() - return nil + return "", nil } } @@ -231,7 +282,7 @@ func (p *CodeBuddyCLIProvider) ChatStream(ctx context.Context, req ai.ChatReques Error: requestErr.Error(), Done: true, }) - return nil + return "", nil } if waitErr != nil { @@ -241,11 +292,38 @@ func (p *CodeBuddyCLIProvider) ChatStream(ctx context.Context, req ai.ChatReques } requestErr = fmt.Errorf("%s", errMsg) callback(ai.StreamChunk{Error: errMsg, Done: true}) - return nil + return "", nil } callback(ai.StreamChunk{Done: true}) - return nil + return currentSessionID, nil +} + +func parseCodeBuddySessionState(state json.RawMessage) (codebuddySessionState, error) { + trimmed := bytes.TrimSpace(state) + if len(trimmed) == 0 { + return codebuddySessionState{}, nil + } + + var sessionState codebuddySessionState + if err := json.Unmarshal(trimmed, &sessionState); err != nil { + return codebuddySessionState{}, fmt.Errorf("解析 CodeBuddy 会话状态失败: %w", err) + } + sessionState.SessionID = strings.TrimSpace(sessionState.SessionID) + return sessionState, nil +} + +func marshalCodeBuddySessionState(sessionID string) (json.RawMessage, error) { + sessionID = strings.TrimSpace(sessionID) + if sessionID == "" { + return nil, nil + } + + payload, err := json.Marshal(codebuddySessionState{SessionID: sessionID}) + if err != nil { + return nil, fmt.Errorf("序列化 CodeBuddy 会话状态失败: %w", err) + } + return json.RawMessage(payload), nil } func resolveCodeBuddyCLICommand(lookPath func(string) (string, error)) (string, error) { @@ -280,10 +358,10 @@ func buildCodeBuddyCLIRequestLogBody(outputFormat string, commandName string, ar } } -func parseCodeBuddyCLIChatOutput(output []byte) (*ai.ChatResponse, error) { +func parseCodeBuddyCLIChatOutput(output []byte) (*ai.ChatResponse, string, error) { trimmed := bytes.TrimSpace(output) if len(trimmed) == 0 { - return &ai.ChatResponse{}, nil + return &ai.ChatResponse{}, "", nil } var events []cliStreamEvent @@ -296,20 +374,24 @@ func parseCodeBuddyCLIChatOutput(output []byte) (*ai.ChatResponse, error) { return buildCodeBuddyCLIResponseFromEvents([]cliStreamEvent{event}) } - return &ai.ChatResponse{Content: strings.TrimSpace(string(output))}, nil + return &ai.ChatResponse{Content: strings.TrimSpace(string(output))}, "", nil } -func buildCodeBuddyCLIResponseFromEvents(events []cliStreamEvent) (*ai.ChatResponse, error) { +func buildCodeBuddyCLIResponseFromEvents(events []cliStreamEvent) (*ai.ChatResponse, string, error) { parts := make([]string, 0, len(events)) resultText := "" + sessionID := "" for _, event := range events { if errMsg, hasError := extractCodeBuddyCLIEventError(event); hasError { - return nil, fmt.Errorf("CodeBuddy CLI 返回错误: %s", errMsg) + return nil, "", fmt.Errorf("CodeBuddy CLI 返回错误: %s", errMsg) } if strings.TrimSpace(event.Result) != "" { resultText = strings.TrimSpace(event.Result) } + if strings.TrimSpace(event.SessionID) != "" { + sessionID = strings.TrimSpace(event.SessionID) + } for _, block := range event.Message.Content { if block.Type == "text" && strings.TrimSpace(block.Text) != "" { parts = append(parts, block.Text) @@ -318,12 +400,12 @@ func buildCodeBuddyCLIResponseFromEvents(events []cliStreamEvent) (*ai.ChatRespo } if resultText != "" { - return &ai.ChatResponse{Content: resultText}, nil + return &ai.ChatResponse{Content: resultText}, sessionID, nil } if len(parts) > 0 { - return &ai.ChatResponse{Content: strings.Join(parts, "")}, nil + return &ai.ChatResponse{Content: strings.Join(parts, "")}, sessionID, nil } - return &ai.ChatResponse{}, nil + return &ai.ChatResponse{}, sessionID, nil } func (p *CodeBuddyCLIProvider) setEnv(cmd *exec.Cmd) error { diff --git a/internal/ai/provider/codebuddy_cli_test.go b/internal/ai/provider/codebuddy_cli_test.go index 6112fa8..a65941a 100644 --- a/internal/ai/provider/codebuddy_cli_test.go +++ b/internal/ai/provider/codebuddy_cli_test.go @@ -2,6 +2,7 @@ package provider import ( "context" + "encoding/json" "errors" "os" "os/exec" @@ -79,6 +80,191 @@ func TestCodeBuddyCLIProvider_ChatParsesJSONEventArray(t *testing.T) { } } +func TestCodeBuddyCLIProviderChatWithState_StartsTrackedSession(t *testing.T) { + fakeCodeBuddy := writeFakeCodeBuddyScript(t, "#!/bin/sh\necho '[{\"type\":\"assistant\",\"session_id\":\"session-new\",\"message\":{\"content\":[{\"type\":\"text\",\"text\":\"hello \"}]}},{\"type\":\"result\",\"subtype\":\"success\",\"is_error\":false,\"result\":\"hello world\",\"session_id\":\"session-new\"}]'\n") + var capturedArgs []string + restore := overrideCodeBuddyCLIForTestWithCapture(t, fakeCodeBuddy, func(args []string) { + capturedArgs = append([]string(nil), args...) + }) + defer restore() + + providerInstance, err := NewCodeBuddyCLIProvider(ai.ProviderConfig{ + APIKey: "cb-test", + Model: "deepseek-v3", + }) + if err != nil { + t.Fatalf("unexpected provider error: %v", err) + } + + resp, nextState, err := providerInstance.(SessionChatProvider).ChatWithState( + context.Background(), + nil, + ai.ChatRequest{ + Messages: []ai.Message{{Role: "user", Content: "ping"}}, + }, + ) + if err != nil { + t.Fatalf("expected chat with state to succeed, got %v", err) + } + + if resp == nil || resp.Content != "hello world" { + t.Fatalf("unexpected response: %#v", resp) + } + if string(nextState) != `{"sessionId":"session-new"}` { + t.Fatalf("expected new session state, got %s", string(nextState)) + } + if !hasArg(capturedArgs, "--enable-session-tracking") { + t.Fatalf("expected session tracking flag, got args %#v", capturedArgs) + } + if hasArg(capturedArgs, "--no-session-persistence") { + t.Fatalf("did not expect no-session-persistence flag, got args %#v", capturedArgs) + } + if hasArg(capturedArgs, "--resume") { + t.Fatalf("did not expect resume flag for first session, got args %#v", capturedArgs) + } + if !hasArgSequence(capturedArgs, "--model", "deepseek-v3") { + t.Fatalf("expected model flag to be preserved, got args %#v", capturedArgs) + } +} + +func TestCodeBuddyCLIProviderChatWithState_ResumesExistingSession(t *testing.T) { + fakeCodeBuddy := writeFakeCodeBuddyScript(t, "#!/bin/sh\necho '[{\"type\":\"assistant\",\"message\":{\"content\":[{\"type\":\"text\",\"text\":\"continued\"}]}}]'\n") + var capturedArgs []string + restore := overrideCodeBuddyCLIForTestWithCapture(t, fakeCodeBuddy, func(args []string) { + capturedArgs = append([]string(nil), args...) + }) + defer restore() + + providerInstance, err := NewCodeBuddyCLIProvider(ai.ProviderConfig{ + APIKey: "cb-test", + }) + if err != nil { + t.Fatalf("unexpected provider error: %v", err) + } + + resp, nextState, err := providerInstance.(SessionChatProvider).ChatWithState( + context.Background(), + json.RawMessage(`{"sessionId":"session-existing"}`), + ai.ChatRequest{ + Messages: []ai.Message{{Role: "user", Content: "ping again"}}, + }, + ) + if err != nil { + t.Fatalf("expected resumed chat with state to succeed, got %v", err) + } + + if resp == nil || resp.Content != "continued" { + t.Fatalf("unexpected response: %#v", resp) + } + if string(nextState) != `{"sessionId":"session-existing"}` { + t.Fatalf("expected existing session state to be preserved, got %s", string(nextState)) + } + if !hasArgSequence(capturedArgs, "--resume", "session-existing") { + t.Fatalf("expected resume args, got %#v", capturedArgs) + } + if !hasArg(capturedArgs, "--enable-session-tracking") { + t.Fatalf("expected session tracking flag, got args %#v", capturedArgs) + } +} + +func TestCodeBuddyCLIProviderChatStreamWithState_StartsTrackedSession(t *testing.T) { + fakeCodeBuddy := writeFakeCodeBuddyScript(t, "#!/bin/sh\nprintf '%s\\n' '{\"type\":\"system\",\"session_id\":\"session-new\"}' '{\"type\":\"assistant\",\"message\":{\"content\":[{\"type\":\"text\",\"text\":\"hello from codebuddy\"}]}}' '{\"type\":\"result\",\"subtype\":\"success\",\"is_error\":false,\"result\":\"hello from codebuddy\",\"session_id\":\"session-new\"}'\n") + var capturedArgs []string + restore := overrideCodeBuddyCLIForTestWithCapture(t, fakeCodeBuddy, func(args []string) { + capturedArgs = append([]string(nil), args...) + }) + defer restore() + + providerInstance, err := NewCodeBuddyCLIProvider(ai.ProviderConfig{ + APIKey: "cb-test", + Model: "deepseek-v3", + }) + if err != nil { + t.Fatalf("unexpected provider error: %v", err) + } + + var chunks []ai.StreamChunk + nextState, err := providerInstance.(SessionStreamProvider).ChatStreamWithState( + context.Background(), + nil, + ai.ChatRequest{ + Messages: []ai.Message{{Role: "user", Content: "ping"}}, + }, + func(chunk ai.StreamChunk) { + chunks = append(chunks, chunk) + }, + ) + if err != nil { + t.Fatalf("expected chat stream with state to succeed, got %v", err) + } + + if string(nextState) != `{"sessionId":"session-new"}` { + t.Fatalf("expected new session state, got %s", string(nextState)) + } + if len(chunks) < 2 || chunks[0].Content != "hello from codebuddy" || !chunks[len(chunks)-1].Done { + t.Fatalf("unexpected stream chunks: %#v", chunks) + } + if !hasArg(capturedArgs, "--enable-session-tracking") { + t.Fatalf("expected session tracking flag, got args %#v", capturedArgs) + } + if hasArg(capturedArgs, "--no-session-persistence") { + t.Fatalf("did not expect no-session-persistence flag, got args %#v", capturedArgs) + } + if hasArg(capturedArgs, "--resume") { + t.Fatalf("did not expect resume flag for first session, got args %#v", capturedArgs) + } + if !hasArgSequence(capturedArgs, "--model", "deepseek-v3") { + t.Fatalf("expected model flag to be preserved, got args %#v", capturedArgs) + } +} + +func TestCodeBuddyCLIProviderChatStreamWithState_ResumesExistingSessionWithoutDroppingState(t *testing.T) { + fakeCodeBuddy := writeFakeCodeBuddyScript(t, "#!/bin/sh\nprintf '%s\\n' '{\"type\":\"assistant\",\"message\":{\"content\":[{\"type\":\"text\",\"text\":\"continued\"}]}}' '{\"type\":\"result\",\"subtype\":\"success\",\"is_error\":false,\"result\":\"continued\"}'\n") + var capturedArgs []string + restore := overrideCodeBuddyCLIForTestWithCapture(t, fakeCodeBuddy, func(args []string) { + capturedArgs = append([]string(nil), args...) + }) + defer restore() + + providerInstance, err := NewCodeBuddyCLIProvider(ai.ProviderConfig{ + APIKey: "cb-test", + }) + if err != nil { + t.Fatalf("unexpected provider error: %v", err) + } + + var chunks []ai.StreamChunk + nextState, err := providerInstance.(SessionStreamProvider).ChatStreamWithState( + context.Background(), + json.RawMessage(`{"sessionId":"session-existing"}`), + ai.ChatRequest{ + Messages: []ai.Message{{Role: "user", Content: "ping again"}}, + }, + func(chunk ai.StreamChunk) { + chunks = append(chunks, chunk) + }, + ) + if err != nil { + t.Fatalf("expected resumed chat stream to succeed, got %v", err) + } + + if string(nextState) != `{"sessionId":"session-existing"}` { + t.Fatalf("expected existing session state to be preserved, got %s", string(nextState)) + } + if len(chunks) < 2 || chunks[0].Content != "continued" || !chunks[len(chunks)-1].Done { + t.Fatalf("unexpected stream chunks: %#v", chunks) + } + if !hasArgSequence(capturedArgs, "--resume", "session-existing") { + t.Fatalf("expected resume args, got %#v", capturedArgs) + } + if !hasArg(capturedArgs, "--enable-session-tracking") { + t.Fatalf("expected session tracking flag, got args %#v", capturedArgs) + } + if hasArg(capturedArgs, "--no-session-persistence") { + t.Fatalf("did not expect no-session-persistence flag, got args %#v", capturedArgs) + } +} + func writeFakeCodeBuddyScript(t *testing.T, content string) string { t.Helper() dir := t.TempDir() @@ -138,3 +324,54 @@ func overrideCodeBuddyCLIForTest(t *testing.T, fakeCodeBuddyPath string) func() _ = os.Setenv("PATH", originalPath) } } + +func overrideCodeBuddyCLIForTestWithCapture(t *testing.T, fakeCodeBuddyPath string, capture func(args []string)) func() { + t.Helper() + + originalLookPath := codebuddyLookPath + originalCommandContext := codebuddyCommandContext + codebuddyLookPath = func(name string) (string, error) { + if name == "codebuddy" || name == "cbc" { + return fakeCodeBuddyPath, nil + } + return originalLookPath(name) + } + codebuddyCommandContext = func(ctx context.Context, name string, args ...string) *exec.Cmd { + if name == "codebuddy" || name == "cbc" { + if capture != nil { + capture(args) + } + return exec.CommandContext(ctx, fakeCodeBuddyPath, args...) + } + return originalCommandContext(ctx, name, args...) + } + + originalPath := os.Getenv("PATH") + if err := os.Setenv("PATH", filepath.Dir(fakeCodeBuddyPath)+string(os.PathListSeparator)+originalPath); err != nil { + t.Fatalf("failed to override PATH: %v", err) + } + + return func() { + codebuddyLookPath = originalLookPath + codebuddyCommandContext = originalCommandContext + _ = os.Setenv("PATH", originalPath) + } +} + +func hasArg(args []string, target string) bool { + for _, arg := range args { + if arg == target { + return true + } + } + return false +} + +func hasArgSequence(args []string, key string, value string) bool { + for index := 0; index < len(args)-1; index++ { + if args[index] == key && args[index+1] == value { + return true + } + } + return false +} diff --git a/internal/ai/provider/cursor_agent.go b/internal/ai/provider/cursor_agent.go index 7f110bf..f158b82 100644 --- a/internal/ai/provider/cursor_agent.go +++ b/internal/ai/provider/cursor_agent.go @@ -22,13 +22,24 @@ const ( ) // CursorAgentProvider 通过 Cursor Cloud Agents API 发起对话。 -// 当前实现为无状态适配:每次请求都创建一个新的 agent,再消费本次 run 的结果。 +// 支持基于 session state 复用已有 agent,并对 follow-up runs 继续追加上下文。 type CursorAgentProvider struct { config ai.ProviderConfig baseURL string client *http.Client } +type cursorSessionState struct { + AgentID string `json:"agentId,omitempty"` + LastRunID string `json:"lastRunId,omitempty"` +} + +type cursorImageInput struct { + Data string `json:"data,omitempty"` + URL string `json:"url,omitempty"` + MimeType string `json:"mimeType,omitempty"` +} + // NewCursorAgentProvider 创建 Cursor Agent Provider。 func NewCursorAgentProvider(config ai.ProviderConfig) (Provider, error) { normalized := config @@ -134,7 +145,8 @@ func normalizeCursorAPIPath(path string) string { } type cursorPrompt struct { - Text string `json:"text"` + Text string `json:"text"` + Images []cursorImageInput `json:"images,omitempty"` } type cursorModelSelection struct { @@ -146,6 +158,10 @@ type cursorCreateAgentRequest struct { Model *cursorModelSelection `json:"model,omitempty"` } +type cursorCreateRunRequest struct { + Prompt cursorPrompt `json:"prompt"` +} + type cursorCreateAgentResponse struct { Agent struct { ID string `json:"id"` @@ -181,38 +197,85 @@ type cursorResultEvent struct { } func (p *CursorAgentProvider) Chat(ctx context.Context, req ai.ChatRequest) (*ai.ChatResponse, error) { + resp, _, err := p.ChatWithState(ctx, nil, req) + return resp, err +} + +func (p *CursorAgentProvider) ChatWithState(ctx context.Context, state json.RawMessage, req ai.ChatRequest) (*ai.ChatResponse, json.RawMessage, error) { if err := p.Validate(); err != nil { - return nil, err + return nil, nil, err } - agentID, runID, err := p.createAgent(ctx, req) + sessionState, err := parseCursorSessionState(state) if err != nil { - return nil, err + return nil, nil, err + } + + agentID := strings.TrimSpace(sessionState.AgentID) + runID := "" + if agentID == "" { + agentID, runID, err = p.createAgent(ctx, req) + if err != nil { + return nil, nil, err + } + } else { + runID, err = p.createRun(ctx, agentID, req) + if err != nil { + return nil, nil, err + } } run, err := p.waitForRun(ctx, agentID, runID) if err != nil { - return nil, err + return nil, nil, err + } + + sessionState.AgentID = agentID + sessionState.LastRunID = runID + nextState, err := json.Marshal(sessionState) + if err != nil { + return nil, nil, fmt.Errorf("序列化 Cursor 会话状态失败: %w", err) } return &ai.ChatResponse{ Content: strings.TrimSpace(run.Result), - }, nil + }, json.RawMessage(nextState), nil } func (p *CursorAgentProvider) ChatStream(ctx context.Context, req ai.ChatRequest, callback func(ai.StreamChunk)) error { + _, err := p.ChatStreamWithState(ctx, nil, req, callback) + return err +} + +func (p *CursorAgentProvider) ChatStreamWithState(ctx context.Context, state json.RawMessage, req ai.ChatRequest, callback func(ai.StreamChunk)) (json.RawMessage, error) { if err := p.Validate(); err != nil { - return err + return nil, err } - agentID, runID, err := p.createAgent(ctx, req) + sessionState, err := parseCursorSessionState(state) if err != nil { - return err + return nil, err } + agentID := strings.TrimSpace(sessionState.AgentID) + runID := "" + if agentID == "" { + agentID, runID, err = p.createAgent(ctx, req) + if err != nil { + return nil, err + } + } else { + runID, err = p.createRun(ctx, agentID, req) + if err != nil { + return nil, err + } + } + sessionState.AgentID = agentID + sessionState.LastRunID = runID + stream, err := p.openRunStream(ctx, agentID, runID) if err != nil { - return err + return nil, err } defer stream.Close() @@ -314,10 +377,10 @@ func (p *CursorAgentProvider) ChatStream(ctx context.Context, req ai.ChatRequest currentEventType = "" currentDataLines = nil if dispatchErr != nil { - return dispatchErr + return nil, dispatchErr } if done { - return nil + return marshalCursorSessionState(sessionState) } case strings.HasPrefix(line, "event:"): currentEventType = strings.TrimSpace(strings.TrimPrefix(line, "event:")) @@ -327,27 +390,27 @@ func (p *CursorAgentProvider) ChatStream(ctx context.Context, req ai.ChatRequest } if err := scanner.Err(); err != nil { - return fmt.Errorf("读取 Cursor 流式响应失败: %w", err) + return nil, fmt.Errorf("读取 Cursor 流式响应失败: %w", err) } if len(currentDataLines) > 0 || strings.TrimSpace(currentEventType) != "" { done, dispatchErr := dispatchEvent(currentEventType, currentDataLines) if dispatchErr != nil { - return dispatchErr + return nil, dispatchErr } if done { - return nil + return marshalCursorSessionState(sessionState) } } if !completedExplicitly { if !receivedAssistantText && !receivedResultText { callback(ai.StreamChunk{Error: "未收到任何有效响应内容,请检查 Cursor 配置或模型权限", Done: true}) - return nil + return marshalCursorSessionState(sessionState) } callback(ai.StreamChunk{Done: true}) } - return nil + return marshalCursorSessionState(sessionState) } func (p *CursorAgentProvider) createAgent(ctx context.Context, req ai.ChatRequest) (string, string, error) { @@ -370,15 +433,13 @@ func (p *CursorAgentProvider) createAgent(ctx context.Context, req ai.ChatReques } func buildCursorCreateAgentRequest(req ai.ChatRequest, model string) (cursorCreateAgentRequest, error) { - prompt, err := buildCursorPrompt(req.Messages) + prompt, err := buildCursorPromptInput(req.Messages) if err != nil { return cursorCreateAgentRequest{}, err } requestBody := cursorCreateAgentRequest{ - Prompt: cursorPrompt{ - Text: prompt, - }, + Prompt: prompt, } if trimmedModel := strings.TrimSpace(model); trimmedModel != "" { @@ -389,18 +450,59 @@ func buildCursorCreateAgentRequest(req ai.ChatRequest, model string) (cursorCrea } func buildCursorPrompt(messages []ai.Message) (string, error) { - requestMessages := messages - if requestMessagesContainImages(messages) { - requestMessages = stripImagesFromRequestMessages(messages) + prompt := strings.TrimSpace(buildPrompt(messages)) + if prompt == "" && requestMessagesContainImages(messages) { + return "请结合这些图片继续分析并回答。", nil } - - prompt := strings.TrimSpace(buildPrompt(requestMessages)) if prompt == "" { return "", fmt.Errorf("请求内容不能为空") } return prompt, nil } +func buildCursorPromptInput(messages []ai.Message) (cursorPrompt, error) { + text, err := buildCursorPrompt(messages) + if err != nil { + return cursorPrompt{}, err + } + images, err := buildCursorImageInputs(messages) + if err != nil { + return cursorPrompt{}, err + } + return cursorPrompt{ + Text: text, + Images: images, + }, nil +} + +func buildCursorImageInputs(messages []ai.Message) ([]cursorImageInput, error) { + images := make([]cursorImageInput, 0) + for _, message := range messages { + for _, img := range message.Images { + trimmed := strings.TrimSpace(img) + if trimmed == "" { + continue + } + if strings.HasPrefix(trimmed, "http://") || strings.HasPrefix(trimmed, "https://") { + images = append(images, cursorImageInput{URL: trimmed}) + continue + } + mimeType, rawBase64, err := ParseDataURI(trimmed) + if err != nil { + return nil, fmt.Errorf("解析图片数据失败: %w", err) + } + images = append(images, cursorImageInput{ + Data: rawBase64, + MimeType: mimeType, + }) + } + } + if len(images) > 5 { + return nil, fmt.Errorf("Cursor 最多支持 5 张图片,当前请求包含 %d 张", len(images)) + } + return images, nil +} + func (p *CursorAgentProvider) waitForRun(ctx context.Context, agentID string, runID string) (*cursorRunResponse, error) { ticker := time.NewTicker(cursorRunPollInterval) defer ticker.Stop() @@ -455,6 +557,31 @@ func (p *CursorAgentProvider) getRun(ctx context.Context, agentID string, runID return &responseBody, nil } +func (p *CursorAgentProvider) createRun(ctx context.Context, agentID string, req ai.ChatRequest) (string, error) { + prompt, err := buildCursorPromptInput(req.Messages) + if err != nil { + return "", err + } + + requestBody := cursorCreateRunRequest{ + Prompt: prompt, + } + + var responseBody struct { + Run struct { + ID string `json:"id"` + } `json:"run"` + } + if err := p.doJSONRequest(ctx, http.MethodPost, ResolveCursorAPIEndpoint(p.baseURL, fmt.Sprintf("agents/%s/runs", agentID)), requestBody, &responseBody, "application/json"); err != nil { + return "", err + } + runID := strings.TrimSpace(responseBody.Run.ID) + if runID == "" { + return "", fmt.Errorf("Cursor 创建 follow-up run 成功,但未返回有效 runId") + } + return runID, nil +} + func (p *CursorAgentProvider) openRunStream(ctx context.Context, agentID string, runID string) (io.ReadCloser, error) { endpoint := ResolveCursorAPIEndpoint(p.baseURL, fmt.Sprintf("agents/%s/runs/%s/stream", agentID, runID)) requestLog := logAIUpstreamRequestStart(p.Name(), http.MethodGet, endpoint, nil) @@ -541,6 +668,28 @@ func (p *CursorAgentProvider) doJSONRequest(ctx context.Context, method string, return nil } +func parseCursorSessionState(state json.RawMessage) (cursorSessionState, error) { + if len(state) == 0 { + return cursorSessionState{}, nil + } + var result cursorSessionState + if err := json.Unmarshal(state, &result); err != nil { + return cursorSessionState{}, fmt.Errorf("解析 Cursor 会话状态失败: %w", err) + } + return result, nil +} + +func marshalCursorSessionState(state cursorSessionState) (json.RawMessage, error) { + if strings.TrimSpace(state.AgentID) == "" { + return nil, nil + } + bytes, err := json.Marshal(state) + if err != nil { + return nil, fmt.Errorf("序列化 Cursor 会话状态失败: %w", err) + } + return json.RawMessage(bytes), nil +} + func isCursorRunTerminalStatus(status string) bool { switch strings.ToUpper(strings.TrimSpace(status)) { case "FINISHED", "ERROR", "CANCELLED", "EXPIRED": diff --git a/internal/ai/provider/cursor_agent_test.go b/internal/ai/provider/cursor_agent_test.go index 10d7ed5..4ed4a21 100644 --- a/internal/ai/provider/cursor_agent_test.go +++ b/internal/ai/provider/cursor_agent_test.go @@ -99,6 +99,85 @@ func TestCursorAgentProviderChat_PollsUntilFinished(t *testing.T) { } } +func TestCursorAgentProviderChatWithState_UsesFollowUpRunsAndPreservesAgent(t *testing.T) { + var ( + createAgentCalls int32 + createRunCalls int32 + receivedPrompt string + ) + + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + switch { + case r.Method == http.MethodPost && r.URL.Path == "/v1/agents": + atomic.AddInt32(&createAgentCalls, 1) + t.Fatalf("expected follow-up request to avoid creating a new agent") + case r.Method == http.MethodPost && r.URL.Path == "/v1/agents/bc-existing/runs": + atomic.AddInt32(&createRunCalls, 1) + var body struct { + Prompt struct { + Text string `json:"text"` + } `json:"prompt"` + } + if err := json.NewDecoder(r.Body).Decode(&body); err != nil { + t.Fatalf("decode follow-up run body: %v", err) + } + receivedPrompt = body.Prompt.Text + _ = json.NewEncoder(w).Encode(map[string]any{ + "run": map[string]any{"id": "run-next"}, + }) + case r.Method == http.MethodGet && r.URL.Path == "/v1/agents/bc-existing/runs/run-next": + _ = json.NewEncoder(w).Encode(map[string]any{ + "id": "run-next", + "agentId": "bc-existing", + "status": "FINISHED", + "result": "done from follow-up", + "durationMs": 456, + }) + default: + http.NotFound(w, r) + } + })) + defer server.Close() + + providerInstance, err := NewCursorAgentProvider(ai.ProviderConfig{ + Name: "Cursor", + BaseURL: server.URL + "/v1", + APIKey: "cursor-key", + }) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + + resp, nextState, err := providerInstance.(SessionChatProvider).ChatWithState( + context.Background(), + json.RawMessage(`{"agentId":"bc-existing","lastRunId":"run-old"}`), + ai.ChatRequest{ + Messages: []ai.Message{ + {Role: "user", Content: "follow this up"}, + }, + }, + ) + if err != nil { + t.Fatalf("follow-up chat failed: %v", err) + } + + if atomic.LoadInt32(&createAgentCalls) != 0 { + t.Fatalf("expected no create-agent calls, got %d", createAgentCalls) + } + if atomic.LoadInt32(&createRunCalls) != 1 { + t.Fatalf("expected exactly one follow-up run call, got %d", createRunCalls) + } + if !strings.Contains(receivedPrompt, "follow this up") { + t.Fatalf("expected follow-up prompt text, got %q", receivedPrompt) + } + if resp == nil || resp.Content != "done from follow-up" { + t.Fatalf("unexpected response: %#v", resp) + } + if string(nextState) != `{"agentId":"bc-existing","lastRunId":"run-next"}` { + t.Fatalf("unexpected next session state: %s", string(nextState)) + } +} + func TestCursorAgentProviderChatStream_MapsAssistantAndThinkingEvents(t *testing.T) { server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { switch { @@ -164,3 +243,109 @@ func TestCursorAgentProviderChatStream_MapsAssistantAndThinkingEvents(t *testing t.Fatalf("expected final done chunk, got %#v", chunks[len(chunks)-1]) } } + +func TestCursorAgentProviderChatStreamWithState_UsesFollowUpRunsAndPreservesAgent(t *testing.T) { + var ( + createAgentCalls int32 + createRunCalls int32 + receivedPrompt string + ) + + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + switch { + case r.Method == http.MethodPost && r.URL.Path == "/v1/agents": + atomic.AddInt32(&createAgentCalls, 1) + t.Fatalf("expected follow-up request to avoid creating a new agent") + case r.Method == http.MethodPost && r.URL.Path == "/v1/agents/bc-existing/runs": + atomic.AddInt32(&createRunCalls, 1) + var body struct { + Prompt struct { + Text string `json:"text"` + } `json:"prompt"` + } + if err := json.NewDecoder(r.Body).Decode(&body); err != nil { + t.Fatalf("decode follow-up run body: %v", err) + } + receivedPrompt = body.Prompt.Text + _ = json.NewEncoder(w).Encode(map[string]any{ + "run": map[string]any{"id": "run-next"}, + }) + case r.Method == http.MethodGet && r.URL.Path == "/v1/agents/bc-existing/runs/run-next/stream": + w.Header().Set("Content-Type", "text/event-stream") + _, _ = w.Write([]byte("event: assistant\n")) + _, _ = w.Write([]byte("data: {\"text\":\"done\"}\n\n")) + _, _ = w.Write([]byte("event: done\n")) + _, _ = w.Write([]byte("data: {}\n\n")) + default: + http.NotFound(w, r) + } + })) + defer server.Close() + + providerInstance, err := NewCursorAgentProvider(ai.ProviderConfig{ + Name: "Cursor", + BaseURL: server.URL + "/v1", + APIKey: "cursor-key", + }) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + + sessionState := json.RawMessage(`{"agentId":"bc-existing","lastRunId":"run-old"}`) + var chunks []ai.StreamChunk + nextState, err := providerInstance.(SessionStreamProvider).ChatStreamWithState( + context.Background(), + sessionState, + ai.ChatRequest{ + Messages: []ai.Message{ + {Role: "user", Content: "follow this up"}, + }, + }, + func(chunk ai.StreamChunk) { + chunks = append(chunks, chunk) + }, + ) + if err != nil { + t.Fatalf("follow-up stream failed: %v", err) + } + + if atomic.LoadInt32(&createAgentCalls) != 0 { + t.Fatalf("expected no create-agent calls, got %d", createAgentCalls) + } + if atomic.LoadInt32(&createRunCalls) != 1 { + t.Fatalf("expected exactly one follow-up run call, got %d", createRunCalls) + } + if !strings.Contains(receivedPrompt, "follow this up") { + t.Fatalf("expected follow-up prompt text, got %q", receivedPrompt) + } + if string(nextState) != `{"agentId":"bc-existing","lastRunId":"run-next"}` { + t.Fatalf("unexpected next session state: %s", string(nextState)) + } + if len(chunks) == 0 || chunks[0].Content != "done" { + t.Fatalf("expected streamed assistant content, got %#v", chunks) + } +} + +func TestCursorAgentProviderCreateAgentRequest_IncludesImageInputs(t *testing.T) { + requestBody, err := buildCursorCreateAgentRequest(ai.ChatRequest{ + Messages: []ai.Message{ + { + Role: "user", + Content: "look at this", + Images: []string{"data:image/png;base64,aGVsbG8="}, + }, + }, + }, "composer-latest") + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if len(requestBody.Prompt.Images) != 1 { + t.Fatalf("expected one image payload, got %#v", requestBody.Prompt.Images) + } + if requestBody.Prompt.Images[0].Data != "aGVsbG8=" || requestBody.Prompt.Images[0].MimeType != "image/png" { + t.Fatalf("unexpected image payload: %#v", requestBody.Prompt.Images[0]) + } + if requestBody.Model == nil || requestBody.Model.ID != "composer-latest" { + t.Fatalf("expected model selection to be preserved, got %#v", requestBody.Model) + } +} diff --git a/internal/ai/provider/custom.go b/internal/ai/provider/custom.go index 18f7df4..f8f3376 100644 --- a/internal/ai/provider/custom.go +++ b/internal/ai/provider/custom.go @@ -2,6 +2,7 @@ package provider import ( "context" + "encoding/json" "fmt" "strings" @@ -75,3 +76,20 @@ func (p *CustomProvider) Chat(ctx context.Context, req ai.ChatRequest) (*ai.Chat func (p *CustomProvider) ChatStream(ctx context.Context, req ai.ChatRequest, callback func(ai.StreamChunk)) error { return p.inner.ChatStream(ctx, req, callback) } + +func (p *CustomProvider) ChatWithState(ctx context.Context, state json.RawMessage, req ai.ChatRequest) (*ai.ChatResponse, json.RawMessage, error) { + sessionProvider, ok := p.inner.(SessionChatProvider) + if !ok { + resp, err := p.inner.Chat(ctx, req) + return resp, nil, err + } + return sessionProvider.ChatWithState(ctx, state, req) +} + +func (p *CustomProvider) ChatStreamWithState(ctx context.Context, state json.RawMessage, req ai.ChatRequest, callback func(ai.StreamChunk)) (json.RawMessage, error) { + sessionProvider, ok := p.inner.(SessionStreamProvider) + if !ok { + return nil, p.inner.ChatStream(ctx, req, callback) + } + return sessionProvider.ChatStreamWithState(ctx, state, req, callback) +} diff --git a/internal/ai/provider/provider.go b/internal/ai/provider/provider.go index e9f1d8e..1916c13 100644 --- a/internal/ai/provider/provider.go +++ b/internal/ai/provider/provider.go @@ -2,6 +2,7 @@ package provider import ( "context" + "encoding/json" "GoNavi-Wails/internal/ai" ) @@ -17,3 +18,24 @@ type Provider interface { // Validate 校验配置是否有效 Validate() error } + +// SessionStreamProvider 表示支持按会话复用上游状态的流式 Provider。 +// state 为 Provider 自己维护的持久化状态;返回值为更新后的状态快照。 +type SessionStreamProvider interface { + ChatStreamWithState( + ctx context.Context, + state json.RawMessage, + req ai.ChatRequest, + callback func(ai.StreamChunk), + ) (json.RawMessage, error) +} + +// SessionChatProvider 表示支持按会话复用上游状态的非流式 Provider。 +// state 为 Provider 自己维护的持久化状态;返回值为响应体和更新后的状态快照。 +type SessionChatProvider interface { + ChatWithState( + ctx context.Context, + state json.RawMessage, + req ai.ChatRequest, + ) (*ai.ChatResponse, json.RawMessage, error) +} diff --git a/internal/ai/service/service.go b/internal/ai/service/service.go index d735434..0405a82 100644 --- a/internal/ai/service/service.go +++ b/internal/ai/service/service.go @@ -9,6 +9,7 @@ import ( "net/url" "os" "path/filepath" + "reflect" "strings" "sync" "time" @@ -42,11 +43,18 @@ type Service struct { secretStore secretstore.SecretStore localizer *i18n.Localizer cancelFuncs map[string]context.CancelFunc // 记录每个 session 的 context 取消函数 + sessionProviders map[string]aiSessionProviderRuntime mcpHTTPMu sync.Mutex mcpHTTP *mcpHTTPServerRuntime mcpHTTPLast ai.MCPHTTPServerStatus } +type aiSessionProviderRuntime struct { + ProviderKey string + State json.RawMessage + Messages []ai.Message +} + var miniMaxAnthropicModels = []string{ "MiniMax-M3", "MiniMax-M2.7", @@ -131,15 +139,16 @@ func NewServiceWithSecretStore(store secretstore.SecretStore) *Service { store = secretstore.NewUnavailableStore("secret store unavailable") } return &Service{ - providers: make([]ai.ProviderConfig, 0), - safetyLevel: ai.PermissionReadOnly, - contextLevel: ai.ContextSchemaOnly, - mcpServers: make([]ai.MCPServerConfig, 0), - skills: make([]ai.SkillConfig, 0), - guard: safety.NewGuard(ai.PermissionReadOnly), - secretStore: store, - localizer: newServiceLocalizer(), - cancelFuncs: make(map[string]context.CancelFunc), + providers: make([]ai.ProviderConfig, 0), + safetyLevel: ai.PermissionReadOnly, + contextLevel: ai.ContextSchemaOnly, + mcpServers: make([]ai.MCPServerConfig, 0), + skills: make([]ai.SkillConfig, 0), + guard: safety.NewGuard(ai.PermissionReadOnly), + secretStore: store, + localizer: newServiceLocalizer(), + cancelFuncs: make(map[string]context.CancelFunc), + sessionProviders: make(map[string]aiSessionProviderRuntime), } } @@ -1150,7 +1159,16 @@ func (s *Service) AISetContextLevel(level string) { // AIChatSend 非流式发送 AI 对话 func (s *Service) AIChatSend(messages []ai.Message, tools []ai.Tool) map[string]interface{} { - p, err := s.getActiveProvider() + return s.aiChatSend("", messages, tools, false) +} + +// AIChatSendInSession 非流式发送 AI 对话,并在支持的 Provider 上复用会话态。 +func (s *Service) AIChatSendInSession(sessionID string, messages []ai.Message, tools []ai.Tool) map[string]interface{} { + return s.aiChatSend(sessionID, messages, tools, true) +} + +func (s *Service) aiChatSend(sessionID string, messages []ai.Message, tools []ai.Tool, allowSessionReuse bool) map[string]interface{} { + p, config, err := s.getActiveProviderRuntime() if err != nil { logger.Error(err, "AIChatSend 获取 Provider 失败:messages=%d tools=%d", len(messages), len(tools)) return map[string]interface{}{"success": false, "error": err.Error()} @@ -1158,14 +1176,62 @@ func (s *Service) AIChatSend(messages []ai.Message, tools []ai.Tool) map[string] started := time.Now() providerName := p.Name() - logger.Infof("AIChatSend 开始:provider=%s messages=%d tools=%d", providerName, len(messages), len(tools)) - resp, err := p.Chat(context.Background(), ai.ChatRequest{Messages: messages, Tools: tools}) + logger.Infof("AIChatSend 开始:sessionID=%s provider=%s messages=%d tools=%d sessionReuse=%t", sessionID, providerName, len(messages), len(tools), allowSessionReuse) + requestMessages := cloneAIMessages(messages) + var updatedProviderState json.RawMessage + if allowSessionReuse && strings.TrimSpace(sessionID) != "" { + if sessionAwareProvider, ok := p.(provider.SessionChatProvider); ok { + providerKey := providerSessionKey(config) + providerState, deltaMessages := s.resolveSessionProviderRequest(sessionID, providerKey, messages) + requestMessages = deltaMessages + resp, updatedState, err := sessionAwareProvider.ChatWithState(context.Background(), providerState, ai.ChatRequest{Messages: requestMessages, Tools: tools}) + if err != nil { + logger.Warnf("AIChatSend 失败:sessionID=%s provider=%s messages=%d tools=%d duration=%s err=%s", sessionID, providerName, len(messages), len(tools), time.Since(started).Round(time.Millisecond), provider.RedactAIUpstreamLogText(err.Error())) + return map[string]interface{}{"success": false, "error": err.Error()} + } + updatedProviderState = updatedState + historyAfterSend := cloneAIMessages(messages) + if assistantMessage, hasAssistantMessage := buildAssistantMessageFromChatResponse(resp); hasAssistantMessage { + historyAfterSend = append(historyAfterSend, assistantMessage) + } + if persistErr := s.storeSessionProviderRuntime(sessionID, providerKey, updatedProviderState, historyAfterSend); persistErr != nil { + logger.Warnf("AIChatSend 保存会话 Provider 状态失败:sessionID=%s provider=%s err=%s", sessionID, providerName, provider.RedactAIUpstreamLogText(persistErr.Error())) + } + logger.Infof( + "AIChatSend 完成:sessionID=%s provider=%s messages=%d tools=%d toolCalls=%d promptTokens=%d completionTokens=%d totalTokens=%d duration=%s sessionReuse=%t", + sessionID, + providerName, + len(messages), + len(tools), + len(resp.ToolCalls), + resp.TokensUsed.PromptTokens, + resp.TokensUsed.CompletionTokens, + resp.TokensUsed.TotalTokens, + time.Since(started).Round(time.Millisecond), + true, + ) + return map[string]interface{}{ + "success": true, + "content": resp.Content, + "reasoning_content": resp.ReasoningContent, + "tool_calls": resp.ToolCalls, + "tokensUsed": map[string]int{ + "promptTokens": resp.TokensUsed.PromptTokens, + "completionTokens": resp.TokensUsed.CompletionTokens, + "totalTokens": resp.TokensUsed.TotalTokens, + }, + } + } + } + + resp, err := p.Chat(context.Background(), ai.ChatRequest{Messages: requestMessages, Tools: tools}) if err != nil { - logger.Warnf("AIChatSend 失败:provider=%s messages=%d tools=%d duration=%s err=%s", providerName, len(messages), len(tools), time.Since(started).Round(time.Millisecond), provider.RedactAIUpstreamLogText(err.Error())) + logger.Warnf("AIChatSend 失败:sessionID=%s provider=%s messages=%d tools=%d duration=%s err=%s", sessionID, providerName, len(messages), len(tools), time.Since(started).Round(time.Millisecond), provider.RedactAIUpstreamLogText(err.Error())) return map[string]interface{}{"success": false, "error": err.Error()} } logger.Infof( - "AIChatSend 完成:provider=%s messages=%d tools=%d toolCalls=%d promptTokens=%d completionTokens=%d totalTokens=%d duration=%s", + "AIChatSend 完成:sessionID=%s provider=%s messages=%d tools=%d toolCalls=%d promptTokens=%d completionTokens=%d totalTokens=%d duration=%s sessionReuse=%t", + sessionID, providerName, len(messages), len(tools), @@ -1174,6 +1240,7 @@ func (s *Service) AIChatSend(messages []ai.Message, tools []ai.Tool) map[string] resp.TokensUsed.CompletionTokens, resp.TokensUsed.TotalTokens, time.Since(started).Round(time.Millisecond), + false, ) return map[string]interface{}{ @@ -1204,7 +1271,7 @@ func (s *Service) AIChatStream(sessionID string, messages []ai.Message, tools [] cancel() // 确保释放 }() - p, err := s.getActiveProvider() + p, config, err := s.getActiveProviderRuntime() if err != nil { logger.Error(err, "AIChatStream 获取 Provider 失败:sessionID=%s messages=%d tools=%d", sessionID, len(messages), len(tools)) wailsRuntime.EventsEmit(s.ctx, "ai:stream:"+sessionID, map[string]interface{}{ @@ -1220,29 +1287,67 @@ func (s *Service) AIChatStream(sessionID string, messages []ai.Message, tools [] thinkingChunks := 0 toolCallChunks := 0 errorChunks := 0 + var assistantContent strings.Builder + var assistantReasoning strings.Builder + var assistantToolCalls []ai.ToolCall + var updatedProviderState json.RawMessage + requestMessages := cloneAIMessages(messages) logger.Infof("AIChatStream 开始:sessionID=%s provider=%s messages=%d tools=%d", sessionID, providerName, len(messages), len(tools)) - err = p.ChatStream(streamCtx, ai.ChatRequest{Messages: messages, Tools: tools}, func(chunk ai.StreamChunk) { - if chunk.Content != "" { - contentChunks++ - } - if chunk.Thinking != "" || chunk.ReasoningContent != "" { - thinkingChunks++ - } - if len(chunk.ToolCalls) > 0 { - toolCallChunks++ - } - if chunk.Error != "" { - errorChunks++ - } - wailsRuntime.EventsEmit(s.ctx, "ai:stream:"+sessionID, map[string]interface{}{ - "content": chunk.Content, - "thinking": chunk.Thinking, - "reasoning_content": chunk.ReasoningContent, - "tool_calls": chunk.ToolCalls, - "done": chunk.Done, - "error": chunk.Error, + if sessionAwareProvider, ok := p.(provider.SessionStreamProvider); ok { + providerKey := providerSessionKey(config) + providerState, deltaMessages := s.resolveSessionProviderRequest(sessionID, providerKey, messages) + requestMessages = deltaMessages + updatedProviderState, err = sessionAwareProvider.ChatStreamWithState(streamCtx, providerState, ai.ChatRequest{Messages: requestMessages, Tools: tools}, func(chunk ai.StreamChunk) { + if chunk.Content != "" { + contentChunks++ + assistantContent.WriteString(chunk.Content) + } + if chunk.Thinking != "" || chunk.ReasoningContent != "" { + thinkingChunks++ + if chunk.ReasoningContent != "" { + assistantReasoning.WriteString(chunk.ReasoningContent) + } + } + if len(chunk.ToolCalls) > 0 { + toolCallChunks++ + assistantToolCalls = append([]ai.ToolCall(nil), chunk.ToolCalls...) + } + if chunk.Error != "" { + errorChunks++ + } + wailsRuntime.EventsEmit(s.ctx, "ai:stream:"+sessionID, map[string]interface{}{ + "content": chunk.Content, + "thinking": chunk.Thinking, + "reasoning_content": chunk.ReasoningContent, + "tool_calls": chunk.ToolCalls, + "done": chunk.Done, + "error": chunk.Error, + }) }) - }) + } else { + err = p.ChatStream(streamCtx, ai.ChatRequest{Messages: messages, Tools: tools}, func(chunk ai.StreamChunk) { + if chunk.Content != "" { + contentChunks++ + } + if chunk.Thinking != "" || chunk.ReasoningContent != "" { + thinkingChunks++ + } + if len(chunk.ToolCalls) > 0 { + toolCallChunks++ + } + if chunk.Error != "" { + errorChunks++ + } + wailsRuntime.EventsEmit(s.ctx, "ai:stream:"+sessionID, map[string]interface{}{ + "content": chunk.Content, + "thinking": chunk.Thinking, + "reasoning_content": chunk.ReasoningContent, + "tool_calls": chunk.ToolCalls, + "done": chunk.Done, + "error": chunk.Error, + }) + }) + } // 当 context 被主动 cancel 的时候,不把这个视为向外抛的 error if err != nil && err != context.Canceled { @@ -1257,6 +1362,16 @@ func (s *Service) AIChatStream(sessionID string, messages []ai.Message, tools [] logger.Infof("AIChatStream 已取消:sessionID=%s provider=%s duration=%s", sessionID, providerName, time.Since(started).Round(time.Millisecond)) return } + if _, ok := p.(provider.SessionStreamProvider); ok && errorChunks == 0 { + providerKey := providerSessionKey(config) + historyAfterStream := cloneAIMessages(messages) + if assistantMessage, hasAssistantMessage := buildAssistantMessageFromStreamResult(assistantContent.String(), assistantReasoning.String(), assistantToolCalls); hasAssistantMessage { + historyAfterStream = append(historyAfterStream, assistantMessage) + } + if persistErr := s.storeSessionProviderRuntime(sessionID, providerKey, updatedProviderState, historyAfterStream); persistErr != nil { + logger.Warnf("AIChatStream 保存会话 Provider 状态失败:sessionID=%s provider=%s err=%s", sessionID, providerName, provider.RedactAIUpstreamLogText(persistErr.Error())) + } + } logger.Infof( "AIChatStream 完成:sessionID=%s provider=%s messages=%d tools=%d contentChunks=%d thinkingChunks=%d toolCallChunks=%d errorChunks=%d duration=%s", sessionID, @@ -1292,6 +1407,11 @@ func (s *Service) AICheckSQL(sql string) ai.SafetyResult { // --- 内部方法 --- func (s *Service) getActiveProvider() (provider.Provider, error) { + p, _, err := s.getActiveProviderRuntime() + return p, err +} + +func (s *Service) getActiveProviderRuntime() (provider.Provider, ai.ProviderConfig, error) { s.mu.RLock() defer s.mu.RUnlock() @@ -1301,11 +1421,174 @@ func (s *Service) getActiveProvider() (provider.Provider, error) { for _, cfg := range s.providers { if cfg.ID == s.activeProvider { - return provider.NewProvider(normalizeProviderConfig(cfg)) + normalized := normalizeProviderConfig(cfg) + p, err := provider.NewProvider(normalized) + return p, normalized, err } } - return nil, fmt.Errorf("未配置 AI Provider,请先在设置中配置") + return nil, ai.ProviderConfig{}, fmt.Errorf("未配置 AI Provider,请先在设置中配置") +} + +func providerSessionKey(config ai.ProviderConfig) string { + return strings.Join([]string{ + strings.TrimSpace(config.ID), + strings.ToLower(strings.TrimSpace(config.Type)), + strings.ToLower(strings.TrimSpace(config.APIFormat)), + strings.TrimSpace(config.BaseURL), + strings.TrimSpace(config.Model), + }, "|") +} + +func cloneAIMessages(messages []ai.Message) []ai.Message { + if len(messages) == 0 { + return nil + } + cloned := make([]ai.Message, len(messages)) + for index, message := range messages { + cloned[index] = message + if len(message.Images) > 0 { + cloned[index].Images = append([]string(nil), message.Images...) + } + if len(message.ToolCalls) > 0 { + cloned[index].ToolCalls = append([]ai.ToolCall(nil), message.ToolCalls...) + } + } + return cloned +} + +func buildAssistantMessageFromStreamResult(content string, reasoning string, toolCalls []ai.ToolCall) (ai.Message, bool) { + message := ai.Message{ + Role: "assistant", + Content: content, + ReasoningContent: reasoning, + } + if len(toolCalls) > 0 { + message.ToolCalls = append([]ai.ToolCall(nil), toolCalls...) + } + hasPayload := strings.TrimSpace(message.Content) != "" || strings.TrimSpace(message.ReasoningContent) != "" || len(message.ToolCalls) > 0 + return message, hasPayload +} + +func buildAssistantMessageFromChatResponse(resp *ai.ChatResponse) (ai.Message, bool) { + if resp == nil { + return ai.Message{}, false + } + return buildAssistantMessageFromStreamResult(resp.Content, resp.ReasoningContent, resp.ToolCalls) +} + +func messagesHavePrefix(messages []ai.Message, prefix []ai.Message) bool { + if len(prefix) == 0 { + return true + } + if len(messages) < len(prefix) { + return false + } + for index := range prefix { + if !reflect.DeepEqual(messages[index], prefix[index]) { + return false + } + } + return true +} + +func (s *Service) resolveSessionProviderRequest(sessionID string, providerKey string, messages []ai.Message) (json.RawMessage, []ai.Message) { + runtimeState, ok := s.loadSessionProviderRuntime(sessionID, providerKey) + if !ok || len(runtimeState.State) == 0 || len(runtimeState.Messages) == 0 { + return nil, cloneAIMessages(messages) + } + if !messagesHavePrefix(messages, runtimeState.Messages) { + return nil, cloneAIMessages(messages) + } + deltaMessages := cloneAIMessages(messages[len(runtimeState.Messages):]) + if len(deltaMessages) == 0 { + return nil, cloneAIMessages(messages) + } + return runtimeState.State, deltaMessages +} + +func (s *Service) loadSessionProviderRuntime(sessionID string, providerKey string) (aiSessionProviderRuntime, bool) { + s.mu.RLock() + runtimeState, ok := s.sessionProviders[sessionID] + s.mu.RUnlock() + if ok && runtimeState.ProviderKey == providerKey { + return aiSessionProviderRuntime{ + ProviderKey: runtimeState.ProviderKey, + State: append(json.RawMessage(nil), runtimeState.State...), + Messages: cloneAIMessages(runtimeState.Messages), + }, true + } + + sessionData, err := s.loadSessionFile(sessionID) + if err != nil { + return aiSessionProviderRuntime{}, false + } + if strings.TrimSpace(sessionData.ProviderKey) == "" || sessionData.ProviderKey != providerKey || len(sessionData.ProviderState) == 0 { + return aiSessionProviderRuntime{}, false + } + var providerMessages []ai.Message + if len(sessionData.ProviderMessages) > 0 { + if err := json.Unmarshal(sessionData.ProviderMessages, &providerMessages); err != nil { + return aiSessionProviderRuntime{}, false + } + } + + runtimeState = aiSessionProviderRuntime{ + ProviderKey: sessionData.ProviderKey, + State: append(json.RawMessage(nil), sessionData.ProviderState...), + Messages: providerMessages, + } + s.mu.Lock() + s.sessionProviders[sessionID] = runtimeState + s.mu.Unlock() + return aiSessionProviderRuntime{ + ProviderKey: runtimeState.ProviderKey, + State: append(json.RawMessage(nil), runtimeState.State...), + Messages: cloneAIMessages(runtimeState.Messages), + }, true +} + +func (s *Service) storeSessionProviderRuntime(sessionID string, providerKey string, state json.RawMessage, messages []ai.Message) error { + if strings.TrimSpace(providerKey) == "" { + return nil + } + + runtimeState := aiSessionProviderRuntime{ + ProviderKey: providerKey, + State: append(json.RawMessage(nil), state...), + Messages: cloneAIMessages(messages), + } + s.mu.Lock() + if len(state) == 0 { + delete(s.sessionProviders, sessionID) + } else { + s.sessionProviders[sessionID] = runtimeState + } + s.mu.Unlock() + + sessionData, err := s.loadOrCreateSessionFile(sessionID) + if err != nil { + return err + } + if len(state) == 0 { + sessionData.ProviderKey = "" + sessionData.ProviderState = nil + sessionData.ProviderMessages = nil + return s.saveSessionFile(sessionID, sessionData) + } + + sessionData.ProviderKey = providerKey + sessionData.ProviderState = append(json.RawMessage(nil), state...) + if len(messages) == 0 { + sessionData.ProviderMessages = nil + } else { + messageBytes, err := json.Marshal(messages) + if err != nil { + return fmt.Errorf("序列化会话 Provider 消息失败: %w", err) + } + sessionData.ProviderMessages = json.RawMessage(messageBytes) + } + return s.saveSessionFile(sessionID, sessionData) } // --- 配置持久化 --- @@ -1363,16 +1646,69 @@ func normalizeUserPromptText(value string) string { // sessionFileData 会话文件的 JSON 结构 type sessionFileData struct { - ID string `json:"id"` - Title string `json:"title"` - UpdatedAt int64 `json:"updatedAt"` - Messages json.RawMessage `json:"messages"` // 透传前端格式,后端不解析消息体 + ID string `json:"id"` + Title string `json:"title"` + UpdatedAt int64 `json:"updatedAt"` + Messages json.RawMessage `json:"messages"` // 透传前端格式,后端不解析消息体 + ProviderKey string `json:"providerKey,omitempty"` + ProviderState json.RawMessage `json:"providerState,omitempty"` + ProviderMessages json.RawMessage `json:"providerMessages,omitempty"` } func (s *Service) sessionsDir() string { return filepath.Join(s.configDir, "sessions") } +func (s *Service) sessionFilePath(sessionID string) string { + return filepath.Join(s.sessionsDir(), sessionID+".json") +} + +func (s *Service) loadSessionFile(sessionID string) (sessionFileData, error) { + data, err := os.ReadFile(s.sessionFilePath(sessionID)) + if err != nil { + return sessionFileData{}, err + } + var sessionData sessionFileData + if err := json.Unmarshal(data, &sessionData); err != nil { + return sessionFileData{}, err + } + return sessionData, nil +} + +func (s *Service) loadOrCreateSessionFile(sessionID string) (sessionFileData, error) { + sessionData, err := s.loadSessionFile(sessionID) + if err == nil { + return sessionData, nil + } + if !os.IsNotExist(err) { + return sessionFileData{}, err + } + return sessionFileData{ + ID: sessionID, + Title: "新的对话", + UpdatedAt: time.Now().UnixMilli(), + Messages: json.RawMessage("[]"), + }, nil +} + +func (s *Service) saveSessionFile(sessionID string, sessionData sessionFileData) error { + dir := s.sessionsDir() + if err := os.MkdirAll(dir, 0o755); err != nil { + return fmt.Errorf("创建 sessions 目录失败: %w", err) + } + if strings.TrimSpace(sessionData.ID) == "" { + sessionData.ID = sessionID + } + if len(sessionData.Messages) == 0 { + sessionData.Messages = json.RawMessage("[]") + } + data, err := json.Marshal(sessionData) + if err != nil { + return fmt.Errorf("序列化会话数据失败: %w", err) + } + return os.WriteFile(s.sessionFilePath(sessionID), data, 0o644) +} + // AIGetSessions 获取所有会话的元数据列表(不含消息体) func (s *Service) AIGetSessions() []map[string]interface{} { dir := s.sessionsDir() @@ -1417,53 +1753,40 @@ func (s *Service) AIGetSessions() []map[string]interface{} { // AILoadSession 加载指定会话的完整数据(含消息) func (s *Service) AILoadSession(sessionID string) map[string]interface{} { - path := filepath.Join(s.sessionsDir(), sessionID+".json") - data, err := os.ReadFile(path) + sessionData, err := s.loadSessionFile(sessionID) if err != nil { return map[string]interface{}{"success": false, "error": "会话不存在"} } - var sfd sessionFileData - if err := json.Unmarshal(data, &sfd); err != nil { - return map[string]interface{}{"success": false, "error": "会话数据损坏"} - } return map[string]interface{}{ "success": true, - "id": sfd.ID, - "title": sfd.Title, - "updatedAt": sfd.UpdatedAt, - "messages": sfd.Messages, + "id": sessionData.ID, + "title": sessionData.Title, + "updatedAt": sessionData.UpdatedAt, + "messages": sessionData.Messages, } } // AISaveSession 保存会话数据到文件 func (s *Service) AISaveSession(sessionID string, title string, updatedAt float64, messagesJSON string) error { - dir := s.sessionsDir() - if err := os.MkdirAll(dir, 0o755); err != nil { - return fmt.Errorf("创建 sessions 目录失败: %w", err) - } - - sfd := sessionFileData{ - ID: sessionID, - Title: title, - UpdatedAt: int64(updatedAt), - Messages: json.RawMessage(messagesJSON), - } - - data, err := json.Marshal(sfd) + sessionData, err := s.loadOrCreateSessionFile(sessionID) if err != nil { - return fmt.Errorf("序列化会话数据失败: %w", err) + return err } - - path := filepath.Join(dir, sessionID+".json") - return os.WriteFile(path, data, 0o644) + sessionData.ID = sessionID + sessionData.Title = title + sessionData.UpdatedAt = int64(updatedAt) + sessionData.Messages = json.RawMessage(messagesJSON) + return s.saveSessionFile(sessionID, sessionData) } // AIDeleteSession 删除会话文件 func (s *Service) AIDeleteSession(sessionID string) error { - path := filepath.Join(s.sessionsDir(), sessionID+".json") - if err := os.Remove(path); err != nil && !os.IsNotExist(err) { + if err := os.Remove(s.sessionFilePath(sessionID)); err != nil && !os.IsNotExist(err) { return fmt.Errorf("删除会话失败: %w", err) } + s.mu.Lock() + delete(s.sessionProviders, sessionID) + s.mu.Unlock() return nil } diff --git a/internal/ai/service/service_cursor_test.go b/internal/ai/service/service_cursor_test.go index cc2a5fe..374b860 100644 --- a/internal/ai/service/service_cursor_test.go +++ b/internal/ai/service/service_cursor_test.go @@ -4,6 +4,7 @@ import ( "encoding/json" "net/http" "net/http/httptest" + "reflect" "testing" "GoNavi-Wails/internal/ai" @@ -99,3 +100,192 @@ func TestAIListModels_FetchesCursorModelItems(t *testing.T) { t.Fatalf("expected api source, got %#v", result["source"]) } } + +func TestResolveSessionProviderRequest_ReusesStoredStateOnlyForHistoryExtension(t *testing.T) { + service := NewService() + service.sessionProviders["session-1"] = aiSessionProviderRuntime{ + ProviderKey: "cursor-provider", + State: json.RawMessage(`{"agentId":"bc-1"}`), + Messages: []ai.Message{ + {Role: "user", Content: "hello"}, + {Role: "assistant", Content: "world"}, + }, + } + + state, delta := service.resolveSessionProviderRequest("session-1", "cursor-provider", []ai.Message{ + {Role: "user", Content: "hello"}, + {Role: "assistant", Content: "world"}, + {Role: "user", Content: "next"}, + }) + if string(state) != `{"agentId":"bc-1"}` { + t.Fatalf("expected stored provider state, got %s", string(state)) + } + expectedDelta := []ai.Message{{Role: "user", Content: "next"}} + if !reflect.DeepEqual(delta, expectedDelta) { + t.Fatalf("unexpected delta messages: %#v", delta) + } + + state, delta = service.resolveSessionProviderRequest("session-1", "cursor-provider", []ai.Message{ + {Role: "user", Content: "hello changed"}, + }) + if len(state) != 0 { + t.Fatalf("expected mismatched history to reset provider state, got %s", string(state)) + } + if len(delta) != 1 || delta[0].Content != "hello changed" { + t.Fatalf("expected full messages after mismatch, got %#v", delta) + } +} + +func TestAISaveSession_PreservesProviderRuntimeMetadata(t *testing.T) { + service := NewService() + service.configDir = t.TempDir() + + err := service.storeSessionProviderRuntime( + "session-1", + "cursor-provider", + json.RawMessage(`{"agentId":"bc-1","lastRunId":"run-1"}`), + []ai.Message{{Role: "user", Content: "hello"}}, + ) + if err != nil { + t.Fatalf("store provider runtime: %v", err) + } + + err = service.AISaveSession("session-1", "标题", 123, `[{"id":"m1","role":"user","content":"hello","timestamp":1}]`) + if err != nil { + t.Fatalf("save session: %v", err) + } + + sessionData, err := service.loadSessionFile("session-1") + if err != nil { + t.Fatalf("load session file: %v", err) + } + if sessionData.ProviderKey != "cursor-provider" { + t.Fatalf("expected provider key to be preserved, got %q", sessionData.ProviderKey) + } + if string(sessionData.ProviderState) != `{"agentId":"bc-1","lastRunId":"run-1"}` { + t.Fatalf("expected provider state to be preserved, got %s", string(sessionData.ProviderState)) + } + var providerMessages []ai.Message + if err := json.Unmarshal(sessionData.ProviderMessages, &providerMessages); err != nil { + t.Fatalf("unmarshal provider messages: %v", err) + } + if len(providerMessages) != 1 || providerMessages[0].Content != "hello" { + t.Fatalf("unexpected provider messages: %#v", providerMessages) + } +} + +func TestAIChatSendInSession_ReusesCursorProviderStateAndPersistsFollowUpRuns(t *testing.T) { + var ( + createAgentCalls int + createRunCalls int + createRunPrompt string + ) + + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + switch { + case r.Method == http.MethodPost && r.URL.Path == "/v1/agents": + createAgentCalls++ + var body struct { + Prompt struct { + Text string `json:"text"` + } `json:"prompt"` + } + if err := json.NewDecoder(r.Body).Decode(&body); err != nil { + t.Fatalf("decode create agent body: %v", err) + } + if body.Prompt.Text == "" { + t.Fatalf("expected first prompt text") + } + _ = json.NewEncoder(w).Encode(map[string]any{ + "agent": map[string]any{"id": "bc-1"}, + "run": map[string]any{"id": "run-1", "agentId": "bc-1"}, + }) + case r.Method == http.MethodGet && r.URL.Path == "/v1/agents/bc-1/runs/run-1": + _ = json.NewEncoder(w).Encode(map[string]any{ + "id": "run-1", + "agentId": "bc-1", + "status": "FINISHED", + "result": "first answer", + "durationMs": 100, + }) + case r.Method == http.MethodPost && r.URL.Path == "/v1/agents/bc-1/runs": + createRunCalls++ + var body struct { + Prompt struct { + Text string `json:"text"` + } `json:"prompt"` + } + if err := json.NewDecoder(r.Body).Decode(&body); err != nil { + t.Fatalf("decode follow-up run body: %v", err) + } + createRunPrompt = body.Prompt.Text + _ = json.NewEncoder(w).Encode(map[string]any{ + "run": map[string]any{"id": "run-2"}, + }) + case r.Method == http.MethodGet && r.URL.Path == "/v1/agents/bc-1/runs/run-2": + _ = json.NewEncoder(w).Encode(map[string]any{ + "id": "run-2", + "agentId": "bc-1", + "status": "FINISHED", + "result": "second answer", + "durationMs": 120, + }) + default: + http.NotFound(w, r) + } + })) + defer server.Close() + + service := NewService() + service.configDir = t.TempDir() + service.providers = []ai.ProviderConfig{ + { + ID: "provider-cursor", + Type: "custom", + APIFormat: "cursor-agent", + BaseURL: server.URL + "/v1", + APIKey: "cursor-key", + }, + } + service.activeProvider = "provider-cursor" + + firstResult := service.AIChatSendInSession("session-1", []ai.Message{ + {Role: "user", Content: "hello"}, + }, nil) + if firstResult["success"] != true { + t.Fatalf("expected first send to succeed, got %#v", firstResult) + } + secondResult := service.AIChatSendInSession("session-1", []ai.Message{ + {Role: "user", Content: "hello"}, + {Role: "assistant", Content: "first answer"}, + {Role: "user", Content: "next"}, + }, nil) + if secondResult["success"] != true { + t.Fatalf("expected second send to succeed, got %#v", secondResult) + } + + if createAgentCalls != 1 { + t.Fatalf("expected exactly one create-agent call, got %d", createAgentCalls) + } + if createRunCalls != 1 { + t.Fatalf("expected exactly one follow-up run call, got %d", createRunCalls) + } + if createRunPrompt != "next" { + t.Fatalf("expected follow-up run to send only delta message, got %q", createRunPrompt) + } + + sessionData, err := service.loadSessionFile("session-1") + if err != nil { + t.Fatalf("load session file: %v", err) + } + if string(sessionData.ProviderState) != `{"agentId":"bc-1","lastRunId":"run-2"}` { + t.Fatalf("unexpected provider state: %s", string(sessionData.ProviderState)) + } + var providerMessages []ai.Message + if err := json.Unmarshal(sessionData.ProviderMessages, &providerMessages); err != nil { + t.Fatalf("unmarshal provider messages: %v", err) + } + if len(providerMessages) != 4 || providerMessages[3].Content != "second answer" { + t.Fatalf("unexpected provider messages: %#v", providerMessages) + } +}