mirror of
https://github.com/Syngnat/GoNavi.git
synced 2026-06-28 17:31:32 +08:00
✨ feat(ai): 补齐 Cursor 与 CodeBuddy 会话态聊天链路
- 新增 SessionChatProvider 接口,补齐非流式对话的会话态复用能力 - 为 Cursor Agent 和 CodeBuddy CLI 同步实现流式与非流式会话续接及状态持久化 - CustomProvider 补充会话态透传,统一 custom provider 的会话复用行为 - Service 新增 AIChatSendInSession,聊天主链路非流式回退改走带 session 的发送接口 - 保留原 AIChatSend 无状态语义,避免标题生成和记忆压缩污染主会话上下文 - 补充前后端定向测试,覆盖会话恢复、续接发送和前端回退分流
This commit is contained in:
@@ -24,6 +24,10 @@ type CodeBuddyCLIProvider struct {
|
||||
config ai.ProviderConfig
|
||||
}
|
||||
|
||||
type codebuddySessionState struct {
|
||||
SessionID string `json:"sessionId,omitempty"`
|
||||
}
|
||||
|
||||
// NewCodeBuddyCLIProvider 创建 CodeBuddyCLIProvider 实例。
|
||||
func NewCodeBuddyCLIProvider(config ai.ProviderConfig) (Provider, error) {
|
||||
return &CodeBuddyCLIProvider{config: config}, nil
|
||||
@@ -42,8 +46,13 @@ func (p *CodeBuddyCLIProvider) Validate() error {
|
||||
}
|
||||
|
||||
func (p *CodeBuddyCLIProvider) Chat(ctx context.Context, req ai.ChatRequest) (*ai.ChatResponse, error) {
|
||||
resp, _, err := p.ChatWithState(ctx, nil, req)
|
||||
return resp, err
|
||||
}
|
||||
|
||||
func (p *CodeBuddyCLIProvider) ChatWithState(ctx context.Context, state json.RawMessage, req ai.ChatRequest) (*ai.ChatResponse, json.RawMessage, error) {
|
||||
if err := p.Validate(); err != nil {
|
||||
return nil, err
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
ctx, cancel := ensureClaudeCLITimeout(ctx, codebuddyCLIRequestTimeout)
|
||||
@@ -51,18 +60,25 @@ func (p *CodeBuddyCLIProvider) Chat(ctx context.Context, req ai.ChatRequest) (*a
|
||||
|
||||
commandName, err := resolveCodeBuddyCLICommand(codebuddyLookPath)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
sessionState, err := parseCodeBuddySessionState(state)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
prompt := buildPrompt(req.Messages)
|
||||
args := []string{"-p", prompt, "--output-format", "json", "--no-session-persistence"}
|
||||
args := []string{"-p", prompt, "--output-format", "json", "--enable-session-tracking"}
|
||||
if strings.TrimSpace(p.config.Model) != "" {
|
||||
args = append(args, "--model", strings.TrimSpace(p.config.Model))
|
||||
}
|
||||
if strings.TrimSpace(sessionState.SessionID) != "" {
|
||||
args = append(args, "--resume", strings.TrimSpace(sessionState.SessionID))
|
||||
}
|
||||
|
||||
cmd := codebuddyCommandContext(ctx, commandName, args...)
|
||||
if err := p.setEnv(cmd); err != nil {
|
||||
return nil, err
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
requestLog := logAIUpstreamRequestStart(
|
||||
@@ -80,27 +96,55 @@ func (p *CodeBuddyCLIProvider) Chat(ctx context.Context, req ai.ChatRequest) (*a
|
||||
if err != nil {
|
||||
if isClaudeCLITimeout(ctx, err) {
|
||||
requestErr = fmt.Errorf("CodeBuddy CLI 执行超时(%s),当前登录态、Base URL 或 API Key 可能没有返回有效响应", codebuddyCLIRequestTimeout)
|
||||
return nil, requestErr
|
||||
return nil, nil, requestErr
|
||||
}
|
||||
if exitErr, ok := err.(*exec.ExitError); ok {
|
||||
requestErr = fmt.Errorf("CodeBuddy CLI 执行失败: %s", string(exitErr.Stderr))
|
||||
return nil, requestErr
|
||||
return nil, nil, requestErr
|
||||
}
|
||||
requestErr = fmt.Errorf("CodeBuddy CLI 执行失败: %w", err)
|
||||
return nil, requestErr
|
||||
return nil, nil, requestErr
|
||||
}
|
||||
|
||||
resp, parseErr := parseCodeBuddyCLIChatOutput(output)
|
||||
resp, nextSessionID, parseErr := parseCodeBuddyCLIChatOutput(output)
|
||||
if parseErr != nil {
|
||||
requestErr = parseErr
|
||||
return nil, requestErr
|
||||
return nil, nil, requestErr
|
||||
}
|
||||
return resp, nil
|
||||
if strings.TrimSpace(nextSessionID) == "" {
|
||||
nextSessionID = strings.TrimSpace(sessionState.SessionID)
|
||||
}
|
||||
nextState, err := marshalCodeBuddySessionState(nextSessionID)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
return resp, nextState, nil
|
||||
}
|
||||
|
||||
func (p *CodeBuddyCLIProvider) ChatStream(ctx context.Context, req ai.ChatRequest, callback func(ai.StreamChunk)) error {
|
||||
_, err := p.ChatStreamWithState(ctx, nil, req, callback)
|
||||
return err
|
||||
}
|
||||
|
||||
func (p *CodeBuddyCLIProvider) ChatStreamWithState(ctx context.Context, state json.RawMessage, req ai.ChatRequest, callback func(ai.StreamChunk)) (json.RawMessage, error) {
|
||||
sessionState, err := parseCodeBuddySessionState(state)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
sessionID, err := p.chatStreamWithSession(ctx, strings.TrimSpace(sessionState.SessionID), req, callback)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if strings.TrimSpace(sessionID) == "" {
|
||||
sessionID = strings.TrimSpace(sessionState.SessionID)
|
||||
}
|
||||
return marshalCodeBuddySessionState(sessionID)
|
||||
}
|
||||
|
||||
func (p *CodeBuddyCLIProvider) chatStreamWithSession(ctx context.Context, resumeSessionID string, req ai.ChatRequest, callback func(ai.StreamChunk)) (string, error) {
|
||||
if err := p.Validate(); err != nil {
|
||||
return err
|
||||
return "", err
|
||||
}
|
||||
|
||||
ctx, cancel := ensureClaudeCLITimeout(ctx, codebuddyCLIRequestTimeout)
|
||||
@@ -108,18 +152,21 @@ func (p *CodeBuddyCLIProvider) ChatStream(ctx context.Context, req ai.ChatReques
|
||||
|
||||
commandName, err := resolveCodeBuddyCLICommand(codebuddyLookPath)
|
||||
if err != nil {
|
||||
return err
|
||||
return "", err
|
||||
}
|
||||
|
||||
prompt := buildPrompt(req.Messages)
|
||||
args := []string{"-p", prompt, "--output-format", "stream-json", "--verbose", "--include-partial-messages", "--no-session-persistence"}
|
||||
args := []string{"-p", prompt, "--output-format", "stream-json", "--verbose", "--include-partial-messages", "--enable-session-tracking"}
|
||||
if strings.TrimSpace(p.config.Model) != "" {
|
||||
args = append(args, "--model", strings.TrimSpace(p.config.Model))
|
||||
}
|
||||
if strings.TrimSpace(resumeSessionID) != "" {
|
||||
args = append(args, "--resume", strings.TrimSpace(resumeSessionID))
|
||||
}
|
||||
|
||||
cmd := codebuddyCommandContext(ctx, commandName, args...)
|
||||
if err := p.setEnv(cmd); err != nil {
|
||||
return err
|
||||
return "", err
|
||||
}
|
||||
|
||||
requestLog := logAIUpstreamRequestStart(
|
||||
@@ -138,7 +185,7 @@ func (p *CodeBuddyCLIProvider) ChatStream(ctx context.Context, req ai.ChatReques
|
||||
stdout, err := cmd.StdoutPipe()
|
||||
if err != nil {
|
||||
requestErr = fmt.Errorf("创建 stdout 管道失败: %w", err)
|
||||
return requestErr
|
||||
return "", requestErr
|
||||
}
|
||||
|
||||
var stderrBuf bytes.Buffer
|
||||
@@ -146,7 +193,7 @@ func (p *CodeBuddyCLIProvider) ChatStream(ctx context.Context, req ai.ChatReques
|
||||
|
||||
if err := cmd.Start(); err != nil {
|
||||
requestErr = fmt.Errorf("启动 CodeBuddy CLI 失败: %w", err)
|
||||
return requestErr
|
||||
return "", requestErr
|
||||
}
|
||||
|
||||
if cmd.Process != nil {
|
||||
@@ -155,6 +202,7 @@ func (p *CodeBuddyCLIProvider) ChatStream(ctx context.Context, req ai.ChatReques
|
||||
|
||||
scanner := bufio.NewScanner(stdout)
|
||||
scanner.Buffer(make([]byte, 64*1024), 1024*1024)
|
||||
currentSessionID := strings.TrimSpace(resumeSessionID)
|
||||
|
||||
for scanner.Scan() {
|
||||
line := scanner.Text()
|
||||
@@ -167,6 +215,9 @@ func (p *CodeBuddyCLIProvider) ChatStream(ctx context.Context, req ai.ChatReques
|
||||
logger.Warnf("CodeBuddyCLI 忽略非 JSON 输出:requestId=%s line=%s", requestLog.id, RedactAIUpstreamLogText(line))
|
||||
continue
|
||||
}
|
||||
if strings.TrimSpace(event.SessionID) != "" {
|
||||
currentSessionID = strings.TrimSpace(event.SessionID)
|
||||
}
|
||||
|
||||
switch event.Type {
|
||||
case "system":
|
||||
@@ -178,7 +229,7 @@ func (p *CodeBuddyCLIProvider) ChatStream(ctx context.Context, req ai.ChatReques
|
||||
_ = cmd.Process.Kill()
|
||||
}
|
||||
_ = cmd.Wait()
|
||||
return nil
|
||||
return "", nil
|
||||
}
|
||||
}
|
||||
case "assistant":
|
||||
@@ -186,7 +237,7 @@ func (p *CodeBuddyCLIProvider) ChatStream(ctx context.Context, req ai.ChatReques
|
||||
callback(ai.StreamChunk{Error: errMsg, Done: true})
|
||||
requestErr = fmt.Errorf("CodeBuddy CLI 返回错误: %s", errMsg)
|
||||
_ = cmd.Wait()
|
||||
return nil
|
||||
return "", nil
|
||||
}
|
||||
if event.Message.Content != nil {
|
||||
for _, block := range event.Message.Content {
|
||||
@@ -208,17 +259,17 @@ func (p *CodeBuddyCLIProvider) ChatStream(ctx context.Context, req ai.ChatReques
|
||||
callback(ai.StreamChunk{Error: errMsg, Done: true})
|
||||
requestErr = fmt.Errorf("CodeBuddy CLI 返回错误: %s", errMsg)
|
||||
_ = cmd.Wait()
|
||||
return nil
|
||||
return "", nil
|
||||
}
|
||||
callback(ai.StreamChunk{Done: true})
|
||||
_ = cmd.Wait()
|
||||
return nil
|
||||
return currentSessionID, nil
|
||||
case "error":
|
||||
errMsg, _ := extractCodeBuddyCLIEventError(event)
|
||||
callback(ai.StreamChunk{Error: errMsg, Done: true})
|
||||
requestErr = fmt.Errorf("CodeBuddy CLI 返回错误: %s", errMsg)
|
||||
_ = cmd.Wait()
|
||||
return nil
|
||||
return "", nil
|
||||
}
|
||||
}
|
||||
|
||||
@@ -231,7 +282,7 @@ func (p *CodeBuddyCLIProvider) ChatStream(ctx context.Context, req ai.ChatReques
|
||||
Error: requestErr.Error(),
|
||||
Done: true,
|
||||
})
|
||||
return nil
|
||||
return "", nil
|
||||
}
|
||||
|
||||
if waitErr != nil {
|
||||
@@ -241,11 +292,38 @@ func (p *CodeBuddyCLIProvider) ChatStream(ctx context.Context, req ai.ChatReques
|
||||
}
|
||||
requestErr = fmt.Errorf("%s", errMsg)
|
||||
callback(ai.StreamChunk{Error: errMsg, Done: true})
|
||||
return nil
|
||||
return "", nil
|
||||
}
|
||||
|
||||
callback(ai.StreamChunk{Done: true})
|
||||
return nil
|
||||
return currentSessionID, nil
|
||||
}
|
||||
|
||||
func parseCodeBuddySessionState(state json.RawMessage) (codebuddySessionState, error) {
|
||||
trimmed := bytes.TrimSpace(state)
|
||||
if len(trimmed) == 0 {
|
||||
return codebuddySessionState{}, nil
|
||||
}
|
||||
|
||||
var sessionState codebuddySessionState
|
||||
if err := json.Unmarshal(trimmed, &sessionState); err != nil {
|
||||
return codebuddySessionState{}, fmt.Errorf("解析 CodeBuddy 会话状态失败: %w", err)
|
||||
}
|
||||
sessionState.SessionID = strings.TrimSpace(sessionState.SessionID)
|
||||
return sessionState, nil
|
||||
}
|
||||
|
||||
func marshalCodeBuddySessionState(sessionID string) (json.RawMessage, error) {
|
||||
sessionID = strings.TrimSpace(sessionID)
|
||||
if sessionID == "" {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
payload, err := json.Marshal(codebuddySessionState{SessionID: sessionID})
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("序列化 CodeBuddy 会话状态失败: %w", err)
|
||||
}
|
||||
return json.RawMessage(payload), nil
|
||||
}
|
||||
|
||||
func resolveCodeBuddyCLICommand(lookPath func(string) (string, error)) (string, error) {
|
||||
@@ -280,10 +358,10 @@ func buildCodeBuddyCLIRequestLogBody(outputFormat string, commandName string, ar
|
||||
}
|
||||
}
|
||||
|
||||
func parseCodeBuddyCLIChatOutput(output []byte) (*ai.ChatResponse, error) {
|
||||
func parseCodeBuddyCLIChatOutput(output []byte) (*ai.ChatResponse, string, error) {
|
||||
trimmed := bytes.TrimSpace(output)
|
||||
if len(trimmed) == 0 {
|
||||
return &ai.ChatResponse{}, nil
|
||||
return &ai.ChatResponse{}, "", nil
|
||||
}
|
||||
|
||||
var events []cliStreamEvent
|
||||
@@ -296,20 +374,24 @@ func parseCodeBuddyCLIChatOutput(output []byte) (*ai.ChatResponse, error) {
|
||||
return buildCodeBuddyCLIResponseFromEvents([]cliStreamEvent{event})
|
||||
}
|
||||
|
||||
return &ai.ChatResponse{Content: strings.TrimSpace(string(output))}, nil
|
||||
return &ai.ChatResponse{Content: strings.TrimSpace(string(output))}, "", nil
|
||||
}
|
||||
|
||||
func buildCodeBuddyCLIResponseFromEvents(events []cliStreamEvent) (*ai.ChatResponse, error) {
|
||||
func buildCodeBuddyCLIResponseFromEvents(events []cliStreamEvent) (*ai.ChatResponse, string, error) {
|
||||
parts := make([]string, 0, len(events))
|
||||
resultText := ""
|
||||
sessionID := ""
|
||||
|
||||
for _, event := range events {
|
||||
if errMsg, hasError := extractCodeBuddyCLIEventError(event); hasError {
|
||||
return nil, fmt.Errorf("CodeBuddy CLI 返回错误: %s", errMsg)
|
||||
return nil, "", fmt.Errorf("CodeBuddy CLI 返回错误: %s", errMsg)
|
||||
}
|
||||
if strings.TrimSpace(event.Result) != "" {
|
||||
resultText = strings.TrimSpace(event.Result)
|
||||
}
|
||||
if strings.TrimSpace(event.SessionID) != "" {
|
||||
sessionID = strings.TrimSpace(event.SessionID)
|
||||
}
|
||||
for _, block := range event.Message.Content {
|
||||
if block.Type == "text" && strings.TrimSpace(block.Text) != "" {
|
||||
parts = append(parts, block.Text)
|
||||
@@ -318,12 +400,12 @@ func buildCodeBuddyCLIResponseFromEvents(events []cliStreamEvent) (*ai.ChatRespo
|
||||
}
|
||||
|
||||
if resultText != "" {
|
||||
return &ai.ChatResponse{Content: resultText}, nil
|
||||
return &ai.ChatResponse{Content: resultText}, sessionID, nil
|
||||
}
|
||||
if len(parts) > 0 {
|
||||
return &ai.ChatResponse{Content: strings.Join(parts, "")}, nil
|
||||
return &ai.ChatResponse{Content: strings.Join(parts, "")}, sessionID, nil
|
||||
}
|
||||
return &ai.ChatResponse{}, nil
|
||||
return &ai.ChatResponse{}, sessionID, nil
|
||||
}
|
||||
|
||||
func (p *CodeBuddyCLIProvider) setEnv(cmd *exec.Cmd) error {
|
||||
|
||||
@@ -2,6 +2,7 @@ package provider
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"os"
|
||||
"os/exec"
|
||||
@@ -79,6 +80,191 @@ func TestCodeBuddyCLIProvider_ChatParsesJSONEventArray(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestCodeBuddyCLIProviderChatWithState_StartsTrackedSession(t *testing.T) {
|
||||
fakeCodeBuddy := writeFakeCodeBuddyScript(t, "#!/bin/sh\necho '[{\"type\":\"assistant\",\"session_id\":\"session-new\",\"message\":{\"content\":[{\"type\":\"text\",\"text\":\"hello \"}]}},{\"type\":\"result\",\"subtype\":\"success\",\"is_error\":false,\"result\":\"hello world\",\"session_id\":\"session-new\"}]'\n")
|
||||
var capturedArgs []string
|
||||
restore := overrideCodeBuddyCLIForTestWithCapture(t, fakeCodeBuddy, func(args []string) {
|
||||
capturedArgs = append([]string(nil), args...)
|
||||
})
|
||||
defer restore()
|
||||
|
||||
providerInstance, err := NewCodeBuddyCLIProvider(ai.ProviderConfig{
|
||||
APIKey: "cb-test",
|
||||
Model: "deepseek-v3",
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected provider error: %v", err)
|
||||
}
|
||||
|
||||
resp, nextState, err := providerInstance.(SessionChatProvider).ChatWithState(
|
||||
context.Background(),
|
||||
nil,
|
||||
ai.ChatRequest{
|
||||
Messages: []ai.Message{{Role: "user", Content: "ping"}},
|
||||
},
|
||||
)
|
||||
if err != nil {
|
||||
t.Fatalf("expected chat with state to succeed, got %v", err)
|
||||
}
|
||||
|
||||
if resp == nil || resp.Content != "hello world" {
|
||||
t.Fatalf("unexpected response: %#v", resp)
|
||||
}
|
||||
if string(nextState) != `{"sessionId":"session-new"}` {
|
||||
t.Fatalf("expected new session state, got %s", string(nextState))
|
||||
}
|
||||
if !hasArg(capturedArgs, "--enable-session-tracking") {
|
||||
t.Fatalf("expected session tracking flag, got args %#v", capturedArgs)
|
||||
}
|
||||
if hasArg(capturedArgs, "--no-session-persistence") {
|
||||
t.Fatalf("did not expect no-session-persistence flag, got args %#v", capturedArgs)
|
||||
}
|
||||
if hasArg(capturedArgs, "--resume") {
|
||||
t.Fatalf("did not expect resume flag for first session, got args %#v", capturedArgs)
|
||||
}
|
||||
if !hasArgSequence(capturedArgs, "--model", "deepseek-v3") {
|
||||
t.Fatalf("expected model flag to be preserved, got args %#v", capturedArgs)
|
||||
}
|
||||
}
|
||||
|
||||
func TestCodeBuddyCLIProviderChatWithState_ResumesExistingSession(t *testing.T) {
|
||||
fakeCodeBuddy := writeFakeCodeBuddyScript(t, "#!/bin/sh\necho '[{\"type\":\"assistant\",\"message\":{\"content\":[{\"type\":\"text\",\"text\":\"continued\"}]}}]'\n")
|
||||
var capturedArgs []string
|
||||
restore := overrideCodeBuddyCLIForTestWithCapture(t, fakeCodeBuddy, func(args []string) {
|
||||
capturedArgs = append([]string(nil), args...)
|
||||
})
|
||||
defer restore()
|
||||
|
||||
providerInstance, err := NewCodeBuddyCLIProvider(ai.ProviderConfig{
|
||||
APIKey: "cb-test",
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected provider error: %v", err)
|
||||
}
|
||||
|
||||
resp, nextState, err := providerInstance.(SessionChatProvider).ChatWithState(
|
||||
context.Background(),
|
||||
json.RawMessage(`{"sessionId":"session-existing"}`),
|
||||
ai.ChatRequest{
|
||||
Messages: []ai.Message{{Role: "user", Content: "ping again"}},
|
||||
},
|
||||
)
|
||||
if err != nil {
|
||||
t.Fatalf("expected resumed chat with state to succeed, got %v", err)
|
||||
}
|
||||
|
||||
if resp == nil || resp.Content != "continued" {
|
||||
t.Fatalf("unexpected response: %#v", resp)
|
||||
}
|
||||
if string(nextState) != `{"sessionId":"session-existing"}` {
|
||||
t.Fatalf("expected existing session state to be preserved, got %s", string(nextState))
|
||||
}
|
||||
if !hasArgSequence(capturedArgs, "--resume", "session-existing") {
|
||||
t.Fatalf("expected resume args, got %#v", capturedArgs)
|
||||
}
|
||||
if !hasArg(capturedArgs, "--enable-session-tracking") {
|
||||
t.Fatalf("expected session tracking flag, got args %#v", capturedArgs)
|
||||
}
|
||||
}
|
||||
|
||||
func TestCodeBuddyCLIProviderChatStreamWithState_StartsTrackedSession(t *testing.T) {
|
||||
fakeCodeBuddy := writeFakeCodeBuddyScript(t, "#!/bin/sh\nprintf '%s\\n' '{\"type\":\"system\",\"session_id\":\"session-new\"}' '{\"type\":\"assistant\",\"message\":{\"content\":[{\"type\":\"text\",\"text\":\"hello from codebuddy\"}]}}' '{\"type\":\"result\",\"subtype\":\"success\",\"is_error\":false,\"result\":\"hello from codebuddy\",\"session_id\":\"session-new\"}'\n")
|
||||
var capturedArgs []string
|
||||
restore := overrideCodeBuddyCLIForTestWithCapture(t, fakeCodeBuddy, func(args []string) {
|
||||
capturedArgs = append([]string(nil), args...)
|
||||
})
|
||||
defer restore()
|
||||
|
||||
providerInstance, err := NewCodeBuddyCLIProvider(ai.ProviderConfig{
|
||||
APIKey: "cb-test",
|
||||
Model: "deepseek-v3",
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected provider error: %v", err)
|
||||
}
|
||||
|
||||
var chunks []ai.StreamChunk
|
||||
nextState, err := providerInstance.(SessionStreamProvider).ChatStreamWithState(
|
||||
context.Background(),
|
||||
nil,
|
||||
ai.ChatRequest{
|
||||
Messages: []ai.Message{{Role: "user", Content: "ping"}},
|
||||
},
|
||||
func(chunk ai.StreamChunk) {
|
||||
chunks = append(chunks, chunk)
|
||||
},
|
||||
)
|
||||
if err != nil {
|
||||
t.Fatalf("expected chat stream with state to succeed, got %v", err)
|
||||
}
|
||||
|
||||
if string(nextState) != `{"sessionId":"session-new"}` {
|
||||
t.Fatalf("expected new session state, got %s", string(nextState))
|
||||
}
|
||||
if len(chunks) < 2 || chunks[0].Content != "hello from codebuddy" || !chunks[len(chunks)-1].Done {
|
||||
t.Fatalf("unexpected stream chunks: %#v", chunks)
|
||||
}
|
||||
if !hasArg(capturedArgs, "--enable-session-tracking") {
|
||||
t.Fatalf("expected session tracking flag, got args %#v", capturedArgs)
|
||||
}
|
||||
if hasArg(capturedArgs, "--no-session-persistence") {
|
||||
t.Fatalf("did not expect no-session-persistence flag, got args %#v", capturedArgs)
|
||||
}
|
||||
if hasArg(capturedArgs, "--resume") {
|
||||
t.Fatalf("did not expect resume flag for first session, got args %#v", capturedArgs)
|
||||
}
|
||||
if !hasArgSequence(capturedArgs, "--model", "deepseek-v3") {
|
||||
t.Fatalf("expected model flag to be preserved, got args %#v", capturedArgs)
|
||||
}
|
||||
}
|
||||
|
||||
func TestCodeBuddyCLIProviderChatStreamWithState_ResumesExistingSessionWithoutDroppingState(t *testing.T) {
|
||||
fakeCodeBuddy := writeFakeCodeBuddyScript(t, "#!/bin/sh\nprintf '%s\\n' '{\"type\":\"assistant\",\"message\":{\"content\":[{\"type\":\"text\",\"text\":\"continued\"}]}}' '{\"type\":\"result\",\"subtype\":\"success\",\"is_error\":false,\"result\":\"continued\"}'\n")
|
||||
var capturedArgs []string
|
||||
restore := overrideCodeBuddyCLIForTestWithCapture(t, fakeCodeBuddy, func(args []string) {
|
||||
capturedArgs = append([]string(nil), args...)
|
||||
})
|
||||
defer restore()
|
||||
|
||||
providerInstance, err := NewCodeBuddyCLIProvider(ai.ProviderConfig{
|
||||
APIKey: "cb-test",
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected provider error: %v", err)
|
||||
}
|
||||
|
||||
var chunks []ai.StreamChunk
|
||||
nextState, err := providerInstance.(SessionStreamProvider).ChatStreamWithState(
|
||||
context.Background(),
|
||||
json.RawMessage(`{"sessionId":"session-existing"}`),
|
||||
ai.ChatRequest{
|
||||
Messages: []ai.Message{{Role: "user", Content: "ping again"}},
|
||||
},
|
||||
func(chunk ai.StreamChunk) {
|
||||
chunks = append(chunks, chunk)
|
||||
},
|
||||
)
|
||||
if err != nil {
|
||||
t.Fatalf("expected resumed chat stream to succeed, got %v", err)
|
||||
}
|
||||
|
||||
if string(nextState) != `{"sessionId":"session-existing"}` {
|
||||
t.Fatalf("expected existing session state to be preserved, got %s", string(nextState))
|
||||
}
|
||||
if len(chunks) < 2 || chunks[0].Content != "continued" || !chunks[len(chunks)-1].Done {
|
||||
t.Fatalf("unexpected stream chunks: %#v", chunks)
|
||||
}
|
||||
if !hasArgSequence(capturedArgs, "--resume", "session-existing") {
|
||||
t.Fatalf("expected resume args, got %#v", capturedArgs)
|
||||
}
|
||||
if !hasArg(capturedArgs, "--enable-session-tracking") {
|
||||
t.Fatalf("expected session tracking flag, got args %#v", capturedArgs)
|
||||
}
|
||||
if hasArg(capturedArgs, "--no-session-persistence") {
|
||||
t.Fatalf("did not expect no-session-persistence flag, got args %#v", capturedArgs)
|
||||
}
|
||||
}
|
||||
|
||||
func writeFakeCodeBuddyScript(t *testing.T, content string) string {
|
||||
t.Helper()
|
||||
dir := t.TempDir()
|
||||
@@ -138,3 +324,54 @@ func overrideCodeBuddyCLIForTest(t *testing.T, fakeCodeBuddyPath string) func()
|
||||
_ = os.Setenv("PATH", originalPath)
|
||||
}
|
||||
}
|
||||
|
||||
func overrideCodeBuddyCLIForTestWithCapture(t *testing.T, fakeCodeBuddyPath string, capture func(args []string)) func() {
|
||||
t.Helper()
|
||||
|
||||
originalLookPath := codebuddyLookPath
|
||||
originalCommandContext := codebuddyCommandContext
|
||||
codebuddyLookPath = func(name string) (string, error) {
|
||||
if name == "codebuddy" || name == "cbc" {
|
||||
return fakeCodeBuddyPath, nil
|
||||
}
|
||||
return originalLookPath(name)
|
||||
}
|
||||
codebuddyCommandContext = func(ctx context.Context, name string, args ...string) *exec.Cmd {
|
||||
if name == "codebuddy" || name == "cbc" {
|
||||
if capture != nil {
|
||||
capture(args)
|
||||
}
|
||||
return exec.CommandContext(ctx, fakeCodeBuddyPath, args...)
|
||||
}
|
||||
return originalCommandContext(ctx, name, args...)
|
||||
}
|
||||
|
||||
originalPath := os.Getenv("PATH")
|
||||
if err := os.Setenv("PATH", filepath.Dir(fakeCodeBuddyPath)+string(os.PathListSeparator)+originalPath); err != nil {
|
||||
t.Fatalf("failed to override PATH: %v", err)
|
||||
}
|
||||
|
||||
return func() {
|
||||
codebuddyLookPath = originalLookPath
|
||||
codebuddyCommandContext = originalCommandContext
|
||||
_ = os.Setenv("PATH", originalPath)
|
||||
}
|
||||
}
|
||||
|
||||
func hasArg(args []string, target string) bool {
|
||||
for _, arg := range args {
|
||||
if arg == target {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func hasArgSequence(args []string, key string, value string) bool {
|
||||
for index := 0; index < len(args)-1; index++ {
|
||||
if args[index] == key && args[index+1] == value {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
@@ -22,13 +22,24 @@ const (
|
||||
)
|
||||
|
||||
// CursorAgentProvider 通过 Cursor Cloud Agents API 发起对话。
|
||||
// 当前实现为无状态适配:每次请求都创建一个新的 agent,再消费本次 run 的结果。
|
||||
// 支持基于 session state 复用已有 agent,并对 follow-up runs 继续追加上下文。
|
||||
type CursorAgentProvider struct {
|
||||
config ai.ProviderConfig
|
||||
baseURL string
|
||||
client *http.Client
|
||||
}
|
||||
|
||||
type cursorSessionState struct {
|
||||
AgentID string `json:"agentId,omitempty"`
|
||||
LastRunID string `json:"lastRunId,omitempty"`
|
||||
}
|
||||
|
||||
type cursorImageInput struct {
|
||||
Data string `json:"data,omitempty"`
|
||||
URL string `json:"url,omitempty"`
|
||||
MimeType string `json:"mimeType,omitempty"`
|
||||
}
|
||||
|
||||
// NewCursorAgentProvider 创建 Cursor Agent Provider。
|
||||
func NewCursorAgentProvider(config ai.ProviderConfig) (Provider, error) {
|
||||
normalized := config
|
||||
@@ -134,7 +145,8 @@ func normalizeCursorAPIPath(path string) string {
|
||||
}
|
||||
|
||||
type cursorPrompt struct {
|
||||
Text string `json:"text"`
|
||||
Text string `json:"text"`
|
||||
Images []cursorImageInput `json:"images,omitempty"`
|
||||
}
|
||||
|
||||
type cursorModelSelection struct {
|
||||
@@ -146,6 +158,10 @@ type cursorCreateAgentRequest struct {
|
||||
Model *cursorModelSelection `json:"model,omitempty"`
|
||||
}
|
||||
|
||||
type cursorCreateRunRequest struct {
|
||||
Prompt cursorPrompt `json:"prompt"`
|
||||
}
|
||||
|
||||
type cursorCreateAgentResponse struct {
|
||||
Agent struct {
|
||||
ID string `json:"id"`
|
||||
@@ -181,38 +197,85 @@ type cursorResultEvent struct {
|
||||
}
|
||||
|
||||
func (p *CursorAgentProvider) Chat(ctx context.Context, req ai.ChatRequest) (*ai.ChatResponse, error) {
|
||||
resp, _, err := p.ChatWithState(ctx, nil, req)
|
||||
return resp, err
|
||||
}
|
||||
|
||||
func (p *CursorAgentProvider) ChatWithState(ctx context.Context, state json.RawMessage, req ai.ChatRequest) (*ai.ChatResponse, json.RawMessage, error) {
|
||||
if err := p.Validate(); err != nil {
|
||||
return nil, err
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
agentID, runID, err := p.createAgent(ctx, req)
|
||||
sessionState, err := parseCursorSessionState(state)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
agentID := strings.TrimSpace(sessionState.AgentID)
|
||||
runID := ""
|
||||
if agentID == "" {
|
||||
agentID, runID, err = p.createAgent(ctx, req)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
} else {
|
||||
runID, err = p.createRun(ctx, agentID, req)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
}
|
||||
|
||||
run, err := p.waitForRun(ctx, agentID, runID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
sessionState.AgentID = agentID
|
||||
sessionState.LastRunID = runID
|
||||
nextState, err := json.Marshal(sessionState)
|
||||
if err != nil {
|
||||
return nil, nil, fmt.Errorf("序列化 Cursor 会话状态失败: %w", err)
|
||||
}
|
||||
|
||||
return &ai.ChatResponse{
|
||||
Content: strings.TrimSpace(run.Result),
|
||||
}, nil
|
||||
}, json.RawMessage(nextState), nil
|
||||
}
|
||||
|
||||
func (p *CursorAgentProvider) ChatStream(ctx context.Context, req ai.ChatRequest, callback func(ai.StreamChunk)) error {
|
||||
_, err := p.ChatStreamWithState(ctx, nil, req, callback)
|
||||
return err
|
||||
}
|
||||
|
||||
func (p *CursorAgentProvider) ChatStreamWithState(ctx context.Context, state json.RawMessage, req ai.ChatRequest, callback func(ai.StreamChunk)) (json.RawMessage, error) {
|
||||
if err := p.Validate(); err != nil {
|
||||
return err
|
||||
return nil, err
|
||||
}
|
||||
|
||||
agentID, runID, err := p.createAgent(ctx, req)
|
||||
sessionState, err := parseCursorSessionState(state)
|
||||
if err != nil {
|
||||
return err
|
||||
return nil, err
|
||||
}
|
||||
|
||||
agentID := strings.TrimSpace(sessionState.AgentID)
|
||||
runID := ""
|
||||
if agentID == "" {
|
||||
agentID, runID, err = p.createAgent(ctx, req)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
} else {
|
||||
runID, err = p.createRun(ctx, agentID, req)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
sessionState.AgentID = agentID
|
||||
sessionState.LastRunID = runID
|
||||
|
||||
stream, err := p.openRunStream(ctx, agentID, runID)
|
||||
if err != nil {
|
||||
return err
|
||||
return nil, err
|
||||
}
|
||||
defer stream.Close()
|
||||
|
||||
@@ -314,10 +377,10 @@ func (p *CursorAgentProvider) ChatStream(ctx context.Context, req ai.ChatRequest
|
||||
currentEventType = ""
|
||||
currentDataLines = nil
|
||||
if dispatchErr != nil {
|
||||
return dispatchErr
|
||||
return nil, dispatchErr
|
||||
}
|
||||
if done {
|
||||
return nil
|
||||
return marshalCursorSessionState(sessionState)
|
||||
}
|
||||
case strings.HasPrefix(line, "event:"):
|
||||
currentEventType = strings.TrimSpace(strings.TrimPrefix(line, "event:"))
|
||||
@@ -327,27 +390,27 @@ func (p *CursorAgentProvider) ChatStream(ctx context.Context, req ai.ChatRequest
|
||||
}
|
||||
|
||||
if err := scanner.Err(); err != nil {
|
||||
return fmt.Errorf("读取 Cursor 流式响应失败: %w", err)
|
||||
return nil, fmt.Errorf("读取 Cursor 流式响应失败: %w", err)
|
||||
}
|
||||
|
||||
if len(currentDataLines) > 0 || strings.TrimSpace(currentEventType) != "" {
|
||||
done, dispatchErr := dispatchEvent(currentEventType, currentDataLines)
|
||||
if dispatchErr != nil {
|
||||
return dispatchErr
|
||||
return nil, dispatchErr
|
||||
}
|
||||
if done {
|
||||
return nil
|
||||
return marshalCursorSessionState(sessionState)
|
||||
}
|
||||
}
|
||||
|
||||
if !completedExplicitly {
|
||||
if !receivedAssistantText && !receivedResultText {
|
||||
callback(ai.StreamChunk{Error: "未收到任何有效响应内容,请检查 Cursor 配置或模型权限", Done: true})
|
||||
return nil
|
||||
return marshalCursorSessionState(sessionState)
|
||||
}
|
||||
callback(ai.StreamChunk{Done: true})
|
||||
}
|
||||
return nil
|
||||
return marshalCursorSessionState(sessionState)
|
||||
}
|
||||
|
||||
func (p *CursorAgentProvider) createAgent(ctx context.Context, req ai.ChatRequest) (string, string, error) {
|
||||
@@ -370,15 +433,13 @@ func (p *CursorAgentProvider) createAgent(ctx context.Context, req ai.ChatReques
|
||||
}
|
||||
|
||||
func buildCursorCreateAgentRequest(req ai.ChatRequest, model string) (cursorCreateAgentRequest, error) {
|
||||
prompt, err := buildCursorPrompt(req.Messages)
|
||||
prompt, err := buildCursorPromptInput(req.Messages)
|
||||
if err != nil {
|
||||
return cursorCreateAgentRequest{}, err
|
||||
}
|
||||
|
||||
requestBody := cursorCreateAgentRequest{
|
||||
Prompt: cursorPrompt{
|
||||
Text: prompt,
|
||||
},
|
||||
Prompt: prompt,
|
||||
}
|
||||
|
||||
if trimmedModel := strings.TrimSpace(model); trimmedModel != "" {
|
||||
@@ -389,18 +450,59 @@ func buildCursorCreateAgentRequest(req ai.ChatRequest, model string) (cursorCrea
|
||||
}
|
||||
|
||||
func buildCursorPrompt(messages []ai.Message) (string, error) {
|
||||
requestMessages := messages
|
||||
if requestMessagesContainImages(messages) {
|
||||
requestMessages = stripImagesFromRequestMessages(messages)
|
||||
prompt := strings.TrimSpace(buildPrompt(messages))
|
||||
if prompt == "" && requestMessagesContainImages(messages) {
|
||||
return "请结合这些图片继续分析并回答。", nil
|
||||
}
|
||||
|
||||
prompt := strings.TrimSpace(buildPrompt(requestMessages))
|
||||
if prompt == "" {
|
||||
return "", fmt.Errorf("请求内容不能为空")
|
||||
}
|
||||
return prompt, nil
|
||||
}
|
||||
|
||||
func buildCursorPromptInput(messages []ai.Message) (cursorPrompt, error) {
|
||||
text, err := buildCursorPrompt(messages)
|
||||
if err != nil {
|
||||
return cursorPrompt{}, err
|
||||
}
|
||||
images, err := buildCursorImageInputs(messages)
|
||||
if err != nil {
|
||||
return cursorPrompt{}, err
|
||||
}
|
||||
return cursorPrompt{
|
||||
Text: text,
|
||||
Images: images,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func buildCursorImageInputs(messages []ai.Message) ([]cursorImageInput, error) {
|
||||
images := make([]cursorImageInput, 0)
|
||||
for _, message := range messages {
|
||||
for _, img := range message.Images {
|
||||
trimmed := strings.TrimSpace(img)
|
||||
if trimmed == "" {
|
||||
continue
|
||||
}
|
||||
if strings.HasPrefix(trimmed, "http://") || strings.HasPrefix(trimmed, "https://") {
|
||||
images = append(images, cursorImageInput{URL: trimmed})
|
||||
continue
|
||||
}
|
||||
mimeType, rawBase64, err := ParseDataURI(trimmed)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("解析图片数据失败: %w", err)
|
||||
}
|
||||
images = append(images, cursorImageInput{
|
||||
Data: rawBase64,
|
||||
MimeType: mimeType,
|
||||
})
|
||||
}
|
||||
}
|
||||
if len(images) > 5 {
|
||||
return nil, fmt.Errorf("Cursor 最多支持 5 张图片,当前请求包含 %d 张", len(images))
|
||||
}
|
||||
return images, nil
|
||||
}
|
||||
|
||||
func (p *CursorAgentProvider) waitForRun(ctx context.Context, agentID string, runID string) (*cursorRunResponse, error) {
|
||||
ticker := time.NewTicker(cursorRunPollInterval)
|
||||
defer ticker.Stop()
|
||||
@@ -455,6 +557,31 @@ func (p *CursorAgentProvider) getRun(ctx context.Context, agentID string, runID
|
||||
return &responseBody, nil
|
||||
}
|
||||
|
||||
func (p *CursorAgentProvider) createRun(ctx context.Context, agentID string, req ai.ChatRequest) (string, error) {
|
||||
prompt, err := buildCursorPromptInput(req.Messages)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
requestBody := cursorCreateRunRequest{
|
||||
Prompt: prompt,
|
||||
}
|
||||
|
||||
var responseBody struct {
|
||||
Run struct {
|
||||
ID string `json:"id"`
|
||||
} `json:"run"`
|
||||
}
|
||||
if err := p.doJSONRequest(ctx, http.MethodPost, ResolveCursorAPIEndpoint(p.baseURL, fmt.Sprintf("agents/%s/runs", agentID)), requestBody, &responseBody, "application/json"); err != nil {
|
||||
return "", err
|
||||
}
|
||||
runID := strings.TrimSpace(responseBody.Run.ID)
|
||||
if runID == "" {
|
||||
return "", fmt.Errorf("Cursor 创建 follow-up run 成功,但未返回有效 runId")
|
||||
}
|
||||
return runID, nil
|
||||
}
|
||||
|
||||
func (p *CursorAgentProvider) openRunStream(ctx context.Context, agentID string, runID string) (io.ReadCloser, error) {
|
||||
endpoint := ResolveCursorAPIEndpoint(p.baseURL, fmt.Sprintf("agents/%s/runs/%s/stream", agentID, runID))
|
||||
requestLog := logAIUpstreamRequestStart(p.Name(), http.MethodGet, endpoint, nil)
|
||||
@@ -541,6 +668,28 @@ func (p *CursorAgentProvider) doJSONRequest(ctx context.Context, method string,
|
||||
return nil
|
||||
}
|
||||
|
||||
func parseCursorSessionState(state json.RawMessage) (cursorSessionState, error) {
|
||||
if len(state) == 0 {
|
||||
return cursorSessionState{}, nil
|
||||
}
|
||||
var result cursorSessionState
|
||||
if err := json.Unmarshal(state, &result); err != nil {
|
||||
return cursorSessionState{}, fmt.Errorf("解析 Cursor 会话状态失败: %w", err)
|
||||
}
|
||||
return result, nil
|
||||
}
|
||||
|
||||
func marshalCursorSessionState(state cursorSessionState) (json.RawMessage, error) {
|
||||
if strings.TrimSpace(state.AgentID) == "" {
|
||||
return nil, nil
|
||||
}
|
||||
bytes, err := json.Marshal(state)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("序列化 Cursor 会话状态失败: %w", err)
|
||||
}
|
||||
return json.RawMessage(bytes), nil
|
||||
}
|
||||
|
||||
func isCursorRunTerminalStatus(status string) bool {
|
||||
switch strings.ToUpper(strings.TrimSpace(status)) {
|
||||
case "FINISHED", "ERROR", "CANCELLED", "EXPIRED":
|
||||
|
||||
@@ -99,6 +99,85 @@ func TestCursorAgentProviderChat_PollsUntilFinished(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestCursorAgentProviderChatWithState_UsesFollowUpRunsAndPreservesAgent(t *testing.T) {
|
||||
var (
|
||||
createAgentCalls int32
|
||||
createRunCalls int32
|
||||
receivedPrompt string
|
||||
)
|
||||
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
switch {
|
||||
case r.Method == http.MethodPost && r.URL.Path == "/v1/agents":
|
||||
atomic.AddInt32(&createAgentCalls, 1)
|
||||
t.Fatalf("expected follow-up request to avoid creating a new agent")
|
||||
case r.Method == http.MethodPost && r.URL.Path == "/v1/agents/bc-existing/runs":
|
||||
atomic.AddInt32(&createRunCalls, 1)
|
||||
var body struct {
|
||||
Prompt struct {
|
||||
Text string `json:"text"`
|
||||
} `json:"prompt"`
|
||||
}
|
||||
if err := json.NewDecoder(r.Body).Decode(&body); err != nil {
|
||||
t.Fatalf("decode follow-up run body: %v", err)
|
||||
}
|
||||
receivedPrompt = body.Prompt.Text
|
||||
_ = json.NewEncoder(w).Encode(map[string]any{
|
||||
"run": map[string]any{"id": "run-next"},
|
||||
})
|
||||
case r.Method == http.MethodGet && r.URL.Path == "/v1/agents/bc-existing/runs/run-next":
|
||||
_ = json.NewEncoder(w).Encode(map[string]any{
|
||||
"id": "run-next",
|
||||
"agentId": "bc-existing",
|
||||
"status": "FINISHED",
|
||||
"result": "done from follow-up",
|
||||
"durationMs": 456,
|
||||
})
|
||||
default:
|
||||
http.NotFound(w, r)
|
||||
}
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
providerInstance, err := NewCursorAgentProvider(ai.ProviderConfig{
|
||||
Name: "Cursor",
|
||||
BaseURL: server.URL + "/v1",
|
||||
APIKey: "cursor-key",
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
|
||||
resp, nextState, err := providerInstance.(SessionChatProvider).ChatWithState(
|
||||
context.Background(),
|
||||
json.RawMessage(`{"agentId":"bc-existing","lastRunId":"run-old"}`),
|
||||
ai.ChatRequest{
|
||||
Messages: []ai.Message{
|
||||
{Role: "user", Content: "follow this up"},
|
||||
},
|
||||
},
|
||||
)
|
||||
if err != nil {
|
||||
t.Fatalf("follow-up chat failed: %v", err)
|
||||
}
|
||||
|
||||
if atomic.LoadInt32(&createAgentCalls) != 0 {
|
||||
t.Fatalf("expected no create-agent calls, got %d", createAgentCalls)
|
||||
}
|
||||
if atomic.LoadInt32(&createRunCalls) != 1 {
|
||||
t.Fatalf("expected exactly one follow-up run call, got %d", createRunCalls)
|
||||
}
|
||||
if !strings.Contains(receivedPrompt, "follow this up") {
|
||||
t.Fatalf("expected follow-up prompt text, got %q", receivedPrompt)
|
||||
}
|
||||
if resp == nil || resp.Content != "done from follow-up" {
|
||||
t.Fatalf("unexpected response: %#v", resp)
|
||||
}
|
||||
if string(nextState) != `{"agentId":"bc-existing","lastRunId":"run-next"}` {
|
||||
t.Fatalf("unexpected next session state: %s", string(nextState))
|
||||
}
|
||||
}
|
||||
|
||||
func TestCursorAgentProviderChatStream_MapsAssistantAndThinkingEvents(t *testing.T) {
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
switch {
|
||||
@@ -164,3 +243,109 @@ func TestCursorAgentProviderChatStream_MapsAssistantAndThinkingEvents(t *testing
|
||||
t.Fatalf("expected final done chunk, got %#v", chunks[len(chunks)-1])
|
||||
}
|
||||
}
|
||||
|
||||
func TestCursorAgentProviderChatStreamWithState_UsesFollowUpRunsAndPreservesAgent(t *testing.T) {
|
||||
var (
|
||||
createAgentCalls int32
|
||||
createRunCalls int32
|
||||
receivedPrompt string
|
||||
)
|
||||
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
switch {
|
||||
case r.Method == http.MethodPost && r.URL.Path == "/v1/agents":
|
||||
atomic.AddInt32(&createAgentCalls, 1)
|
||||
t.Fatalf("expected follow-up request to avoid creating a new agent")
|
||||
case r.Method == http.MethodPost && r.URL.Path == "/v1/agents/bc-existing/runs":
|
||||
atomic.AddInt32(&createRunCalls, 1)
|
||||
var body struct {
|
||||
Prompt struct {
|
||||
Text string `json:"text"`
|
||||
} `json:"prompt"`
|
||||
}
|
||||
if err := json.NewDecoder(r.Body).Decode(&body); err != nil {
|
||||
t.Fatalf("decode follow-up run body: %v", err)
|
||||
}
|
||||
receivedPrompt = body.Prompt.Text
|
||||
_ = json.NewEncoder(w).Encode(map[string]any{
|
||||
"run": map[string]any{"id": "run-next"},
|
||||
})
|
||||
case r.Method == http.MethodGet && r.URL.Path == "/v1/agents/bc-existing/runs/run-next/stream":
|
||||
w.Header().Set("Content-Type", "text/event-stream")
|
||||
_, _ = w.Write([]byte("event: assistant\n"))
|
||||
_, _ = w.Write([]byte("data: {\"text\":\"done\"}\n\n"))
|
||||
_, _ = w.Write([]byte("event: done\n"))
|
||||
_, _ = w.Write([]byte("data: {}\n\n"))
|
||||
default:
|
||||
http.NotFound(w, r)
|
||||
}
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
providerInstance, err := NewCursorAgentProvider(ai.ProviderConfig{
|
||||
Name: "Cursor",
|
||||
BaseURL: server.URL + "/v1",
|
||||
APIKey: "cursor-key",
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
|
||||
sessionState := json.RawMessage(`{"agentId":"bc-existing","lastRunId":"run-old"}`)
|
||||
var chunks []ai.StreamChunk
|
||||
nextState, err := providerInstance.(SessionStreamProvider).ChatStreamWithState(
|
||||
context.Background(),
|
||||
sessionState,
|
||||
ai.ChatRequest{
|
||||
Messages: []ai.Message{
|
||||
{Role: "user", Content: "follow this up"},
|
||||
},
|
||||
},
|
||||
func(chunk ai.StreamChunk) {
|
||||
chunks = append(chunks, chunk)
|
||||
},
|
||||
)
|
||||
if err != nil {
|
||||
t.Fatalf("follow-up stream failed: %v", err)
|
||||
}
|
||||
|
||||
if atomic.LoadInt32(&createAgentCalls) != 0 {
|
||||
t.Fatalf("expected no create-agent calls, got %d", createAgentCalls)
|
||||
}
|
||||
if atomic.LoadInt32(&createRunCalls) != 1 {
|
||||
t.Fatalf("expected exactly one follow-up run call, got %d", createRunCalls)
|
||||
}
|
||||
if !strings.Contains(receivedPrompt, "follow this up") {
|
||||
t.Fatalf("expected follow-up prompt text, got %q", receivedPrompt)
|
||||
}
|
||||
if string(nextState) != `{"agentId":"bc-existing","lastRunId":"run-next"}` {
|
||||
t.Fatalf("unexpected next session state: %s", string(nextState))
|
||||
}
|
||||
if len(chunks) == 0 || chunks[0].Content != "done" {
|
||||
t.Fatalf("expected streamed assistant content, got %#v", chunks)
|
||||
}
|
||||
}
|
||||
|
||||
func TestCursorAgentProviderCreateAgentRequest_IncludesImageInputs(t *testing.T) {
|
||||
requestBody, err := buildCursorCreateAgentRequest(ai.ChatRequest{
|
||||
Messages: []ai.Message{
|
||||
{
|
||||
Role: "user",
|
||||
Content: "look at this",
|
||||
Images: []string{"data:image/png;base64,aGVsbG8="},
|
||||
},
|
||||
},
|
||||
}, "composer-latest")
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
if len(requestBody.Prompt.Images) != 1 {
|
||||
t.Fatalf("expected one image payload, got %#v", requestBody.Prompt.Images)
|
||||
}
|
||||
if requestBody.Prompt.Images[0].Data != "aGVsbG8=" || requestBody.Prompt.Images[0].MimeType != "image/png" {
|
||||
t.Fatalf("unexpected image payload: %#v", requestBody.Prompt.Images[0])
|
||||
}
|
||||
if requestBody.Model == nil || requestBody.Model.ID != "composer-latest" {
|
||||
t.Fatalf("expected model selection to be preserved, got %#v", requestBody.Model)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -2,6 +2,7 @@ package provider
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"strings"
|
||||
|
||||
@@ -75,3 +76,20 @@ func (p *CustomProvider) Chat(ctx context.Context, req ai.ChatRequest) (*ai.Chat
|
||||
func (p *CustomProvider) ChatStream(ctx context.Context, req ai.ChatRequest, callback func(ai.StreamChunk)) error {
|
||||
return p.inner.ChatStream(ctx, req, callback)
|
||||
}
|
||||
|
||||
func (p *CustomProvider) ChatWithState(ctx context.Context, state json.RawMessage, req ai.ChatRequest) (*ai.ChatResponse, json.RawMessage, error) {
|
||||
sessionProvider, ok := p.inner.(SessionChatProvider)
|
||||
if !ok {
|
||||
resp, err := p.inner.Chat(ctx, req)
|
||||
return resp, nil, err
|
||||
}
|
||||
return sessionProvider.ChatWithState(ctx, state, req)
|
||||
}
|
||||
|
||||
func (p *CustomProvider) ChatStreamWithState(ctx context.Context, state json.RawMessage, req ai.ChatRequest, callback func(ai.StreamChunk)) (json.RawMessage, error) {
|
||||
sessionProvider, ok := p.inner.(SessionStreamProvider)
|
||||
if !ok {
|
||||
return nil, p.inner.ChatStream(ctx, req, callback)
|
||||
}
|
||||
return sessionProvider.ChatStreamWithState(ctx, state, req, callback)
|
||||
}
|
||||
|
||||
@@ -2,6 +2,7 @@ package provider
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
|
||||
"GoNavi-Wails/internal/ai"
|
||||
)
|
||||
@@ -17,3 +18,24 @@ type Provider interface {
|
||||
// Validate 校验配置是否有效
|
||||
Validate() error
|
||||
}
|
||||
|
||||
// SessionStreamProvider 表示支持按会话复用上游状态的流式 Provider。
|
||||
// state 为 Provider 自己维护的持久化状态;返回值为更新后的状态快照。
|
||||
type SessionStreamProvider interface {
|
||||
ChatStreamWithState(
|
||||
ctx context.Context,
|
||||
state json.RawMessage,
|
||||
req ai.ChatRequest,
|
||||
callback func(ai.StreamChunk),
|
||||
) (json.RawMessage, error)
|
||||
}
|
||||
|
||||
// SessionChatProvider 表示支持按会话复用上游状态的非流式 Provider。
|
||||
// state 为 Provider 自己维护的持久化状态;返回值为响应体和更新后的状态快照。
|
||||
type SessionChatProvider interface {
|
||||
ChatWithState(
|
||||
ctx context.Context,
|
||||
state json.RawMessage,
|
||||
req ai.ChatRequest,
|
||||
) (*ai.ChatResponse, json.RawMessage, error)
|
||||
}
|
||||
|
||||
@@ -9,6 +9,7 @@ import (
|
||||
"net/url"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"reflect"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
@@ -42,11 +43,18 @@ type Service struct {
|
||||
secretStore secretstore.SecretStore
|
||||
localizer *i18n.Localizer
|
||||
cancelFuncs map[string]context.CancelFunc // 记录每个 session 的 context 取消函数
|
||||
sessionProviders map[string]aiSessionProviderRuntime
|
||||
mcpHTTPMu sync.Mutex
|
||||
mcpHTTP *mcpHTTPServerRuntime
|
||||
mcpHTTPLast ai.MCPHTTPServerStatus
|
||||
}
|
||||
|
||||
type aiSessionProviderRuntime struct {
|
||||
ProviderKey string
|
||||
State json.RawMessage
|
||||
Messages []ai.Message
|
||||
}
|
||||
|
||||
var miniMaxAnthropicModels = []string{
|
||||
"MiniMax-M3",
|
||||
"MiniMax-M2.7",
|
||||
@@ -131,15 +139,16 @@ func NewServiceWithSecretStore(store secretstore.SecretStore) *Service {
|
||||
store = secretstore.NewUnavailableStore("secret store unavailable")
|
||||
}
|
||||
return &Service{
|
||||
providers: make([]ai.ProviderConfig, 0),
|
||||
safetyLevel: ai.PermissionReadOnly,
|
||||
contextLevel: ai.ContextSchemaOnly,
|
||||
mcpServers: make([]ai.MCPServerConfig, 0),
|
||||
skills: make([]ai.SkillConfig, 0),
|
||||
guard: safety.NewGuard(ai.PermissionReadOnly),
|
||||
secretStore: store,
|
||||
localizer: newServiceLocalizer(),
|
||||
cancelFuncs: make(map[string]context.CancelFunc),
|
||||
providers: make([]ai.ProviderConfig, 0),
|
||||
safetyLevel: ai.PermissionReadOnly,
|
||||
contextLevel: ai.ContextSchemaOnly,
|
||||
mcpServers: make([]ai.MCPServerConfig, 0),
|
||||
skills: make([]ai.SkillConfig, 0),
|
||||
guard: safety.NewGuard(ai.PermissionReadOnly),
|
||||
secretStore: store,
|
||||
localizer: newServiceLocalizer(),
|
||||
cancelFuncs: make(map[string]context.CancelFunc),
|
||||
sessionProviders: make(map[string]aiSessionProviderRuntime),
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1150,7 +1159,16 @@ func (s *Service) AISetContextLevel(level string) {
|
||||
|
||||
// AIChatSend 非流式发送 AI 对话
|
||||
func (s *Service) AIChatSend(messages []ai.Message, tools []ai.Tool) map[string]interface{} {
|
||||
p, err := s.getActiveProvider()
|
||||
return s.aiChatSend("", messages, tools, false)
|
||||
}
|
||||
|
||||
// AIChatSendInSession 非流式发送 AI 对话,并在支持的 Provider 上复用会话态。
|
||||
func (s *Service) AIChatSendInSession(sessionID string, messages []ai.Message, tools []ai.Tool) map[string]interface{} {
|
||||
return s.aiChatSend(sessionID, messages, tools, true)
|
||||
}
|
||||
|
||||
func (s *Service) aiChatSend(sessionID string, messages []ai.Message, tools []ai.Tool, allowSessionReuse bool) map[string]interface{} {
|
||||
p, config, err := s.getActiveProviderRuntime()
|
||||
if err != nil {
|
||||
logger.Error(err, "AIChatSend 获取 Provider 失败:messages=%d tools=%d", len(messages), len(tools))
|
||||
return map[string]interface{}{"success": false, "error": err.Error()}
|
||||
@@ -1158,14 +1176,62 @@ func (s *Service) AIChatSend(messages []ai.Message, tools []ai.Tool) map[string]
|
||||
|
||||
started := time.Now()
|
||||
providerName := p.Name()
|
||||
logger.Infof("AIChatSend 开始:provider=%s messages=%d tools=%d", providerName, len(messages), len(tools))
|
||||
resp, err := p.Chat(context.Background(), ai.ChatRequest{Messages: messages, Tools: tools})
|
||||
logger.Infof("AIChatSend 开始:sessionID=%s provider=%s messages=%d tools=%d sessionReuse=%t", sessionID, providerName, len(messages), len(tools), allowSessionReuse)
|
||||
requestMessages := cloneAIMessages(messages)
|
||||
var updatedProviderState json.RawMessage
|
||||
if allowSessionReuse && strings.TrimSpace(sessionID) != "" {
|
||||
if sessionAwareProvider, ok := p.(provider.SessionChatProvider); ok {
|
||||
providerKey := providerSessionKey(config)
|
||||
providerState, deltaMessages := s.resolveSessionProviderRequest(sessionID, providerKey, messages)
|
||||
requestMessages = deltaMessages
|
||||
resp, updatedState, err := sessionAwareProvider.ChatWithState(context.Background(), providerState, ai.ChatRequest{Messages: requestMessages, Tools: tools})
|
||||
if err != nil {
|
||||
logger.Warnf("AIChatSend 失败:sessionID=%s provider=%s messages=%d tools=%d duration=%s err=%s", sessionID, providerName, len(messages), len(tools), time.Since(started).Round(time.Millisecond), provider.RedactAIUpstreamLogText(err.Error()))
|
||||
return map[string]interface{}{"success": false, "error": err.Error()}
|
||||
}
|
||||
updatedProviderState = updatedState
|
||||
historyAfterSend := cloneAIMessages(messages)
|
||||
if assistantMessage, hasAssistantMessage := buildAssistantMessageFromChatResponse(resp); hasAssistantMessage {
|
||||
historyAfterSend = append(historyAfterSend, assistantMessage)
|
||||
}
|
||||
if persistErr := s.storeSessionProviderRuntime(sessionID, providerKey, updatedProviderState, historyAfterSend); persistErr != nil {
|
||||
logger.Warnf("AIChatSend 保存会话 Provider 状态失败:sessionID=%s provider=%s err=%s", sessionID, providerName, provider.RedactAIUpstreamLogText(persistErr.Error()))
|
||||
}
|
||||
logger.Infof(
|
||||
"AIChatSend 完成:sessionID=%s provider=%s messages=%d tools=%d toolCalls=%d promptTokens=%d completionTokens=%d totalTokens=%d duration=%s sessionReuse=%t",
|
||||
sessionID,
|
||||
providerName,
|
||||
len(messages),
|
||||
len(tools),
|
||||
len(resp.ToolCalls),
|
||||
resp.TokensUsed.PromptTokens,
|
||||
resp.TokensUsed.CompletionTokens,
|
||||
resp.TokensUsed.TotalTokens,
|
||||
time.Since(started).Round(time.Millisecond),
|
||||
true,
|
||||
)
|
||||
return map[string]interface{}{
|
||||
"success": true,
|
||||
"content": resp.Content,
|
||||
"reasoning_content": resp.ReasoningContent,
|
||||
"tool_calls": resp.ToolCalls,
|
||||
"tokensUsed": map[string]int{
|
||||
"promptTokens": resp.TokensUsed.PromptTokens,
|
||||
"completionTokens": resp.TokensUsed.CompletionTokens,
|
||||
"totalTokens": resp.TokensUsed.TotalTokens,
|
||||
},
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
resp, err := p.Chat(context.Background(), ai.ChatRequest{Messages: requestMessages, Tools: tools})
|
||||
if err != nil {
|
||||
logger.Warnf("AIChatSend 失败:provider=%s messages=%d tools=%d duration=%s err=%s", providerName, len(messages), len(tools), time.Since(started).Round(time.Millisecond), provider.RedactAIUpstreamLogText(err.Error()))
|
||||
logger.Warnf("AIChatSend 失败:sessionID=%s provider=%s messages=%d tools=%d duration=%s err=%s", sessionID, providerName, len(messages), len(tools), time.Since(started).Round(time.Millisecond), provider.RedactAIUpstreamLogText(err.Error()))
|
||||
return map[string]interface{}{"success": false, "error": err.Error()}
|
||||
}
|
||||
logger.Infof(
|
||||
"AIChatSend 完成:provider=%s messages=%d tools=%d toolCalls=%d promptTokens=%d completionTokens=%d totalTokens=%d duration=%s",
|
||||
"AIChatSend 完成:sessionID=%s provider=%s messages=%d tools=%d toolCalls=%d promptTokens=%d completionTokens=%d totalTokens=%d duration=%s sessionReuse=%t",
|
||||
sessionID,
|
||||
providerName,
|
||||
len(messages),
|
||||
len(tools),
|
||||
@@ -1174,6 +1240,7 @@ func (s *Service) AIChatSend(messages []ai.Message, tools []ai.Tool) map[string]
|
||||
resp.TokensUsed.CompletionTokens,
|
||||
resp.TokensUsed.TotalTokens,
|
||||
time.Since(started).Round(time.Millisecond),
|
||||
false,
|
||||
)
|
||||
|
||||
return map[string]interface{}{
|
||||
@@ -1204,7 +1271,7 @@ func (s *Service) AIChatStream(sessionID string, messages []ai.Message, tools []
|
||||
cancel() // 确保释放
|
||||
}()
|
||||
|
||||
p, err := s.getActiveProvider()
|
||||
p, config, err := s.getActiveProviderRuntime()
|
||||
if err != nil {
|
||||
logger.Error(err, "AIChatStream 获取 Provider 失败:sessionID=%s messages=%d tools=%d", sessionID, len(messages), len(tools))
|
||||
wailsRuntime.EventsEmit(s.ctx, "ai:stream:"+sessionID, map[string]interface{}{
|
||||
@@ -1220,29 +1287,67 @@ func (s *Service) AIChatStream(sessionID string, messages []ai.Message, tools []
|
||||
thinkingChunks := 0
|
||||
toolCallChunks := 0
|
||||
errorChunks := 0
|
||||
var assistantContent strings.Builder
|
||||
var assistantReasoning strings.Builder
|
||||
var assistantToolCalls []ai.ToolCall
|
||||
var updatedProviderState json.RawMessage
|
||||
requestMessages := cloneAIMessages(messages)
|
||||
logger.Infof("AIChatStream 开始:sessionID=%s provider=%s messages=%d tools=%d", sessionID, providerName, len(messages), len(tools))
|
||||
err = p.ChatStream(streamCtx, ai.ChatRequest{Messages: messages, Tools: tools}, func(chunk ai.StreamChunk) {
|
||||
if chunk.Content != "" {
|
||||
contentChunks++
|
||||
}
|
||||
if chunk.Thinking != "" || chunk.ReasoningContent != "" {
|
||||
thinkingChunks++
|
||||
}
|
||||
if len(chunk.ToolCalls) > 0 {
|
||||
toolCallChunks++
|
||||
}
|
||||
if chunk.Error != "" {
|
||||
errorChunks++
|
||||
}
|
||||
wailsRuntime.EventsEmit(s.ctx, "ai:stream:"+sessionID, map[string]interface{}{
|
||||
"content": chunk.Content,
|
||||
"thinking": chunk.Thinking,
|
||||
"reasoning_content": chunk.ReasoningContent,
|
||||
"tool_calls": chunk.ToolCalls,
|
||||
"done": chunk.Done,
|
||||
"error": chunk.Error,
|
||||
if sessionAwareProvider, ok := p.(provider.SessionStreamProvider); ok {
|
||||
providerKey := providerSessionKey(config)
|
||||
providerState, deltaMessages := s.resolveSessionProviderRequest(sessionID, providerKey, messages)
|
||||
requestMessages = deltaMessages
|
||||
updatedProviderState, err = sessionAwareProvider.ChatStreamWithState(streamCtx, providerState, ai.ChatRequest{Messages: requestMessages, Tools: tools}, func(chunk ai.StreamChunk) {
|
||||
if chunk.Content != "" {
|
||||
contentChunks++
|
||||
assistantContent.WriteString(chunk.Content)
|
||||
}
|
||||
if chunk.Thinking != "" || chunk.ReasoningContent != "" {
|
||||
thinkingChunks++
|
||||
if chunk.ReasoningContent != "" {
|
||||
assistantReasoning.WriteString(chunk.ReasoningContent)
|
||||
}
|
||||
}
|
||||
if len(chunk.ToolCalls) > 0 {
|
||||
toolCallChunks++
|
||||
assistantToolCalls = append([]ai.ToolCall(nil), chunk.ToolCalls...)
|
||||
}
|
||||
if chunk.Error != "" {
|
||||
errorChunks++
|
||||
}
|
||||
wailsRuntime.EventsEmit(s.ctx, "ai:stream:"+sessionID, map[string]interface{}{
|
||||
"content": chunk.Content,
|
||||
"thinking": chunk.Thinking,
|
||||
"reasoning_content": chunk.ReasoningContent,
|
||||
"tool_calls": chunk.ToolCalls,
|
||||
"done": chunk.Done,
|
||||
"error": chunk.Error,
|
||||
})
|
||||
})
|
||||
})
|
||||
} else {
|
||||
err = p.ChatStream(streamCtx, ai.ChatRequest{Messages: messages, Tools: tools}, func(chunk ai.StreamChunk) {
|
||||
if chunk.Content != "" {
|
||||
contentChunks++
|
||||
}
|
||||
if chunk.Thinking != "" || chunk.ReasoningContent != "" {
|
||||
thinkingChunks++
|
||||
}
|
||||
if len(chunk.ToolCalls) > 0 {
|
||||
toolCallChunks++
|
||||
}
|
||||
if chunk.Error != "" {
|
||||
errorChunks++
|
||||
}
|
||||
wailsRuntime.EventsEmit(s.ctx, "ai:stream:"+sessionID, map[string]interface{}{
|
||||
"content": chunk.Content,
|
||||
"thinking": chunk.Thinking,
|
||||
"reasoning_content": chunk.ReasoningContent,
|
||||
"tool_calls": chunk.ToolCalls,
|
||||
"done": chunk.Done,
|
||||
"error": chunk.Error,
|
||||
})
|
||||
})
|
||||
}
|
||||
|
||||
// 当 context 被主动 cancel 的时候,不把这个视为向外抛的 error
|
||||
if err != nil && err != context.Canceled {
|
||||
@@ -1257,6 +1362,16 @@ func (s *Service) AIChatStream(sessionID string, messages []ai.Message, tools []
|
||||
logger.Infof("AIChatStream 已取消:sessionID=%s provider=%s duration=%s", sessionID, providerName, time.Since(started).Round(time.Millisecond))
|
||||
return
|
||||
}
|
||||
if _, ok := p.(provider.SessionStreamProvider); ok && errorChunks == 0 {
|
||||
providerKey := providerSessionKey(config)
|
||||
historyAfterStream := cloneAIMessages(messages)
|
||||
if assistantMessage, hasAssistantMessage := buildAssistantMessageFromStreamResult(assistantContent.String(), assistantReasoning.String(), assistantToolCalls); hasAssistantMessage {
|
||||
historyAfterStream = append(historyAfterStream, assistantMessage)
|
||||
}
|
||||
if persistErr := s.storeSessionProviderRuntime(sessionID, providerKey, updatedProviderState, historyAfterStream); persistErr != nil {
|
||||
logger.Warnf("AIChatStream 保存会话 Provider 状态失败:sessionID=%s provider=%s err=%s", sessionID, providerName, provider.RedactAIUpstreamLogText(persistErr.Error()))
|
||||
}
|
||||
}
|
||||
logger.Infof(
|
||||
"AIChatStream 完成:sessionID=%s provider=%s messages=%d tools=%d contentChunks=%d thinkingChunks=%d toolCallChunks=%d errorChunks=%d duration=%s",
|
||||
sessionID,
|
||||
@@ -1292,6 +1407,11 @@ func (s *Service) AICheckSQL(sql string) ai.SafetyResult {
|
||||
// --- 内部方法 ---
|
||||
|
||||
func (s *Service) getActiveProvider() (provider.Provider, error) {
|
||||
p, _, err := s.getActiveProviderRuntime()
|
||||
return p, err
|
||||
}
|
||||
|
||||
func (s *Service) getActiveProviderRuntime() (provider.Provider, ai.ProviderConfig, error) {
|
||||
s.mu.RLock()
|
||||
defer s.mu.RUnlock()
|
||||
|
||||
@@ -1301,11 +1421,174 @@ func (s *Service) getActiveProvider() (provider.Provider, error) {
|
||||
|
||||
for _, cfg := range s.providers {
|
||||
if cfg.ID == s.activeProvider {
|
||||
return provider.NewProvider(normalizeProviderConfig(cfg))
|
||||
normalized := normalizeProviderConfig(cfg)
|
||||
p, err := provider.NewProvider(normalized)
|
||||
return p, normalized, err
|
||||
}
|
||||
}
|
||||
|
||||
return nil, fmt.Errorf("未配置 AI Provider,请先在设置中配置")
|
||||
return nil, ai.ProviderConfig{}, fmt.Errorf("未配置 AI Provider,请先在设置中配置")
|
||||
}
|
||||
|
||||
func providerSessionKey(config ai.ProviderConfig) string {
|
||||
return strings.Join([]string{
|
||||
strings.TrimSpace(config.ID),
|
||||
strings.ToLower(strings.TrimSpace(config.Type)),
|
||||
strings.ToLower(strings.TrimSpace(config.APIFormat)),
|
||||
strings.TrimSpace(config.BaseURL),
|
||||
strings.TrimSpace(config.Model),
|
||||
}, "|")
|
||||
}
|
||||
|
||||
func cloneAIMessages(messages []ai.Message) []ai.Message {
|
||||
if len(messages) == 0 {
|
||||
return nil
|
||||
}
|
||||
cloned := make([]ai.Message, len(messages))
|
||||
for index, message := range messages {
|
||||
cloned[index] = message
|
||||
if len(message.Images) > 0 {
|
||||
cloned[index].Images = append([]string(nil), message.Images...)
|
||||
}
|
||||
if len(message.ToolCalls) > 0 {
|
||||
cloned[index].ToolCalls = append([]ai.ToolCall(nil), message.ToolCalls...)
|
||||
}
|
||||
}
|
||||
return cloned
|
||||
}
|
||||
|
||||
func buildAssistantMessageFromStreamResult(content string, reasoning string, toolCalls []ai.ToolCall) (ai.Message, bool) {
|
||||
message := ai.Message{
|
||||
Role: "assistant",
|
||||
Content: content,
|
||||
ReasoningContent: reasoning,
|
||||
}
|
||||
if len(toolCalls) > 0 {
|
||||
message.ToolCalls = append([]ai.ToolCall(nil), toolCalls...)
|
||||
}
|
||||
hasPayload := strings.TrimSpace(message.Content) != "" || strings.TrimSpace(message.ReasoningContent) != "" || len(message.ToolCalls) > 0
|
||||
return message, hasPayload
|
||||
}
|
||||
|
||||
func buildAssistantMessageFromChatResponse(resp *ai.ChatResponse) (ai.Message, bool) {
|
||||
if resp == nil {
|
||||
return ai.Message{}, false
|
||||
}
|
||||
return buildAssistantMessageFromStreamResult(resp.Content, resp.ReasoningContent, resp.ToolCalls)
|
||||
}
|
||||
|
||||
func messagesHavePrefix(messages []ai.Message, prefix []ai.Message) bool {
|
||||
if len(prefix) == 0 {
|
||||
return true
|
||||
}
|
||||
if len(messages) < len(prefix) {
|
||||
return false
|
||||
}
|
||||
for index := range prefix {
|
||||
if !reflect.DeepEqual(messages[index], prefix[index]) {
|
||||
return false
|
||||
}
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
func (s *Service) resolveSessionProviderRequest(sessionID string, providerKey string, messages []ai.Message) (json.RawMessage, []ai.Message) {
|
||||
runtimeState, ok := s.loadSessionProviderRuntime(sessionID, providerKey)
|
||||
if !ok || len(runtimeState.State) == 0 || len(runtimeState.Messages) == 0 {
|
||||
return nil, cloneAIMessages(messages)
|
||||
}
|
||||
if !messagesHavePrefix(messages, runtimeState.Messages) {
|
||||
return nil, cloneAIMessages(messages)
|
||||
}
|
||||
deltaMessages := cloneAIMessages(messages[len(runtimeState.Messages):])
|
||||
if len(deltaMessages) == 0 {
|
||||
return nil, cloneAIMessages(messages)
|
||||
}
|
||||
return runtimeState.State, deltaMessages
|
||||
}
|
||||
|
||||
func (s *Service) loadSessionProviderRuntime(sessionID string, providerKey string) (aiSessionProviderRuntime, bool) {
|
||||
s.mu.RLock()
|
||||
runtimeState, ok := s.sessionProviders[sessionID]
|
||||
s.mu.RUnlock()
|
||||
if ok && runtimeState.ProviderKey == providerKey {
|
||||
return aiSessionProviderRuntime{
|
||||
ProviderKey: runtimeState.ProviderKey,
|
||||
State: append(json.RawMessage(nil), runtimeState.State...),
|
||||
Messages: cloneAIMessages(runtimeState.Messages),
|
||||
}, true
|
||||
}
|
||||
|
||||
sessionData, err := s.loadSessionFile(sessionID)
|
||||
if err != nil {
|
||||
return aiSessionProviderRuntime{}, false
|
||||
}
|
||||
if strings.TrimSpace(sessionData.ProviderKey) == "" || sessionData.ProviderKey != providerKey || len(sessionData.ProviderState) == 0 {
|
||||
return aiSessionProviderRuntime{}, false
|
||||
}
|
||||
var providerMessages []ai.Message
|
||||
if len(sessionData.ProviderMessages) > 0 {
|
||||
if err := json.Unmarshal(sessionData.ProviderMessages, &providerMessages); err != nil {
|
||||
return aiSessionProviderRuntime{}, false
|
||||
}
|
||||
}
|
||||
|
||||
runtimeState = aiSessionProviderRuntime{
|
||||
ProviderKey: sessionData.ProviderKey,
|
||||
State: append(json.RawMessage(nil), sessionData.ProviderState...),
|
||||
Messages: providerMessages,
|
||||
}
|
||||
s.mu.Lock()
|
||||
s.sessionProviders[sessionID] = runtimeState
|
||||
s.mu.Unlock()
|
||||
return aiSessionProviderRuntime{
|
||||
ProviderKey: runtimeState.ProviderKey,
|
||||
State: append(json.RawMessage(nil), runtimeState.State...),
|
||||
Messages: cloneAIMessages(runtimeState.Messages),
|
||||
}, true
|
||||
}
|
||||
|
||||
func (s *Service) storeSessionProviderRuntime(sessionID string, providerKey string, state json.RawMessage, messages []ai.Message) error {
|
||||
if strings.TrimSpace(providerKey) == "" {
|
||||
return nil
|
||||
}
|
||||
|
||||
runtimeState := aiSessionProviderRuntime{
|
||||
ProviderKey: providerKey,
|
||||
State: append(json.RawMessage(nil), state...),
|
||||
Messages: cloneAIMessages(messages),
|
||||
}
|
||||
s.mu.Lock()
|
||||
if len(state) == 0 {
|
||||
delete(s.sessionProviders, sessionID)
|
||||
} else {
|
||||
s.sessionProviders[sessionID] = runtimeState
|
||||
}
|
||||
s.mu.Unlock()
|
||||
|
||||
sessionData, err := s.loadOrCreateSessionFile(sessionID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if len(state) == 0 {
|
||||
sessionData.ProviderKey = ""
|
||||
sessionData.ProviderState = nil
|
||||
sessionData.ProviderMessages = nil
|
||||
return s.saveSessionFile(sessionID, sessionData)
|
||||
}
|
||||
|
||||
sessionData.ProviderKey = providerKey
|
||||
sessionData.ProviderState = append(json.RawMessage(nil), state...)
|
||||
if len(messages) == 0 {
|
||||
sessionData.ProviderMessages = nil
|
||||
} else {
|
||||
messageBytes, err := json.Marshal(messages)
|
||||
if err != nil {
|
||||
return fmt.Errorf("序列化会话 Provider 消息失败: %w", err)
|
||||
}
|
||||
sessionData.ProviderMessages = json.RawMessage(messageBytes)
|
||||
}
|
||||
return s.saveSessionFile(sessionID, sessionData)
|
||||
}
|
||||
|
||||
// --- 配置持久化 ---
|
||||
@@ -1363,16 +1646,69 @@ func normalizeUserPromptText(value string) string {
|
||||
|
||||
// sessionFileData 会话文件的 JSON 结构
|
||||
type sessionFileData struct {
|
||||
ID string `json:"id"`
|
||||
Title string `json:"title"`
|
||||
UpdatedAt int64 `json:"updatedAt"`
|
||||
Messages json.RawMessage `json:"messages"` // 透传前端格式,后端不解析消息体
|
||||
ID string `json:"id"`
|
||||
Title string `json:"title"`
|
||||
UpdatedAt int64 `json:"updatedAt"`
|
||||
Messages json.RawMessage `json:"messages"` // 透传前端格式,后端不解析消息体
|
||||
ProviderKey string `json:"providerKey,omitempty"`
|
||||
ProviderState json.RawMessage `json:"providerState,omitempty"`
|
||||
ProviderMessages json.RawMessage `json:"providerMessages,omitempty"`
|
||||
}
|
||||
|
||||
func (s *Service) sessionsDir() string {
|
||||
return filepath.Join(s.configDir, "sessions")
|
||||
}
|
||||
|
||||
func (s *Service) sessionFilePath(sessionID string) string {
|
||||
return filepath.Join(s.sessionsDir(), sessionID+".json")
|
||||
}
|
||||
|
||||
func (s *Service) loadSessionFile(sessionID string) (sessionFileData, error) {
|
||||
data, err := os.ReadFile(s.sessionFilePath(sessionID))
|
||||
if err != nil {
|
||||
return sessionFileData{}, err
|
||||
}
|
||||
var sessionData sessionFileData
|
||||
if err := json.Unmarshal(data, &sessionData); err != nil {
|
||||
return sessionFileData{}, err
|
||||
}
|
||||
return sessionData, nil
|
||||
}
|
||||
|
||||
func (s *Service) loadOrCreateSessionFile(sessionID string) (sessionFileData, error) {
|
||||
sessionData, err := s.loadSessionFile(sessionID)
|
||||
if err == nil {
|
||||
return sessionData, nil
|
||||
}
|
||||
if !os.IsNotExist(err) {
|
||||
return sessionFileData{}, err
|
||||
}
|
||||
return sessionFileData{
|
||||
ID: sessionID,
|
||||
Title: "新的对话",
|
||||
UpdatedAt: time.Now().UnixMilli(),
|
||||
Messages: json.RawMessage("[]"),
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (s *Service) saveSessionFile(sessionID string, sessionData sessionFileData) error {
|
||||
dir := s.sessionsDir()
|
||||
if err := os.MkdirAll(dir, 0o755); err != nil {
|
||||
return fmt.Errorf("创建 sessions 目录失败: %w", err)
|
||||
}
|
||||
if strings.TrimSpace(sessionData.ID) == "" {
|
||||
sessionData.ID = sessionID
|
||||
}
|
||||
if len(sessionData.Messages) == 0 {
|
||||
sessionData.Messages = json.RawMessage("[]")
|
||||
}
|
||||
data, err := json.Marshal(sessionData)
|
||||
if err != nil {
|
||||
return fmt.Errorf("序列化会话数据失败: %w", err)
|
||||
}
|
||||
return os.WriteFile(s.sessionFilePath(sessionID), data, 0o644)
|
||||
}
|
||||
|
||||
// AIGetSessions 获取所有会话的元数据列表(不含消息体)
|
||||
func (s *Service) AIGetSessions() []map[string]interface{} {
|
||||
dir := s.sessionsDir()
|
||||
@@ -1417,53 +1753,40 @@ func (s *Service) AIGetSessions() []map[string]interface{} {
|
||||
|
||||
// AILoadSession 加载指定会话的完整数据(含消息)
|
||||
func (s *Service) AILoadSession(sessionID string) map[string]interface{} {
|
||||
path := filepath.Join(s.sessionsDir(), sessionID+".json")
|
||||
data, err := os.ReadFile(path)
|
||||
sessionData, err := s.loadSessionFile(sessionID)
|
||||
if err != nil {
|
||||
return map[string]interface{}{"success": false, "error": "会话不存在"}
|
||||
}
|
||||
var sfd sessionFileData
|
||||
if err := json.Unmarshal(data, &sfd); err != nil {
|
||||
return map[string]interface{}{"success": false, "error": "会话数据损坏"}
|
||||
}
|
||||
return map[string]interface{}{
|
||||
"success": true,
|
||||
"id": sfd.ID,
|
||||
"title": sfd.Title,
|
||||
"updatedAt": sfd.UpdatedAt,
|
||||
"messages": sfd.Messages,
|
||||
"id": sessionData.ID,
|
||||
"title": sessionData.Title,
|
||||
"updatedAt": sessionData.UpdatedAt,
|
||||
"messages": sessionData.Messages,
|
||||
}
|
||||
}
|
||||
|
||||
// AISaveSession 保存会话数据到文件
|
||||
func (s *Service) AISaveSession(sessionID string, title string, updatedAt float64, messagesJSON string) error {
|
||||
dir := s.sessionsDir()
|
||||
if err := os.MkdirAll(dir, 0o755); err != nil {
|
||||
return fmt.Errorf("创建 sessions 目录失败: %w", err)
|
||||
}
|
||||
|
||||
sfd := sessionFileData{
|
||||
ID: sessionID,
|
||||
Title: title,
|
||||
UpdatedAt: int64(updatedAt),
|
||||
Messages: json.RawMessage(messagesJSON),
|
||||
}
|
||||
|
||||
data, err := json.Marshal(sfd)
|
||||
sessionData, err := s.loadOrCreateSessionFile(sessionID)
|
||||
if err != nil {
|
||||
return fmt.Errorf("序列化会话数据失败: %w", err)
|
||||
return err
|
||||
}
|
||||
|
||||
path := filepath.Join(dir, sessionID+".json")
|
||||
return os.WriteFile(path, data, 0o644)
|
||||
sessionData.ID = sessionID
|
||||
sessionData.Title = title
|
||||
sessionData.UpdatedAt = int64(updatedAt)
|
||||
sessionData.Messages = json.RawMessage(messagesJSON)
|
||||
return s.saveSessionFile(sessionID, sessionData)
|
||||
}
|
||||
|
||||
// AIDeleteSession 删除会话文件
|
||||
func (s *Service) AIDeleteSession(sessionID string) error {
|
||||
path := filepath.Join(s.sessionsDir(), sessionID+".json")
|
||||
if err := os.Remove(path); err != nil && !os.IsNotExist(err) {
|
||||
if err := os.Remove(s.sessionFilePath(sessionID)); err != nil && !os.IsNotExist(err) {
|
||||
return fmt.Errorf("删除会话失败: %w", err)
|
||||
}
|
||||
s.mu.Lock()
|
||||
delete(s.sessionProviders, sessionID)
|
||||
s.mu.Unlock()
|
||||
return nil
|
||||
}
|
||||
|
||||
|
||||
@@ -4,6 +4,7 @@ import (
|
||||
"encoding/json"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"reflect"
|
||||
"testing"
|
||||
|
||||
"GoNavi-Wails/internal/ai"
|
||||
@@ -99,3 +100,192 @@ func TestAIListModels_FetchesCursorModelItems(t *testing.T) {
|
||||
t.Fatalf("expected api source, got %#v", result["source"])
|
||||
}
|
||||
}
|
||||
|
||||
func TestResolveSessionProviderRequest_ReusesStoredStateOnlyForHistoryExtension(t *testing.T) {
|
||||
service := NewService()
|
||||
service.sessionProviders["session-1"] = aiSessionProviderRuntime{
|
||||
ProviderKey: "cursor-provider",
|
||||
State: json.RawMessage(`{"agentId":"bc-1"}`),
|
||||
Messages: []ai.Message{
|
||||
{Role: "user", Content: "hello"},
|
||||
{Role: "assistant", Content: "world"},
|
||||
},
|
||||
}
|
||||
|
||||
state, delta := service.resolveSessionProviderRequest("session-1", "cursor-provider", []ai.Message{
|
||||
{Role: "user", Content: "hello"},
|
||||
{Role: "assistant", Content: "world"},
|
||||
{Role: "user", Content: "next"},
|
||||
})
|
||||
if string(state) != `{"agentId":"bc-1"}` {
|
||||
t.Fatalf("expected stored provider state, got %s", string(state))
|
||||
}
|
||||
expectedDelta := []ai.Message{{Role: "user", Content: "next"}}
|
||||
if !reflect.DeepEqual(delta, expectedDelta) {
|
||||
t.Fatalf("unexpected delta messages: %#v", delta)
|
||||
}
|
||||
|
||||
state, delta = service.resolveSessionProviderRequest("session-1", "cursor-provider", []ai.Message{
|
||||
{Role: "user", Content: "hello changed"},
|
||||
})
|
||||
if len(state) != 0 {
|
||||
t.Fatalf("expected mismatched history to reset provider state, got %s", string(state))
|
||||
}
|
||||
if len(delta) != 1 || delta[0].Content != "hello changed" {
|
||||
t.Fatalf("expected full messages after mismatch, got %#v", delta)
|
||||
}
|
||||
}
|
||||
|
||||
func TestAISaveSession_PreservesProviderRuntimeMetadata(t *testing.T) {
|
||||
service := NewService()
|
||||
service.configDir = t.TempDir()
|
||||
|
||||
err := service.storeSessionProviderRuntime(
|
||||
"session-1",
|
||||
"cursor-provider",
|
||||
json.RawMessage(`{"agentId":"bc-1","lastRunId":"run-1"}`),
|
||||
[]ai.Message{{Role: "user", Content: "hello"}},
|
||||
)
|
||||
if err != nil {
|
||||
t.Fatalf("store provider runtime: %v", err)
|
||||
}
|
||||
|
||||
err = service.AISaveSession("session-1", "标题", 123, `[{"id":"m1","role":"user","content":"hello","timestamp":1}]`)
|
||||
if err != nil {
|
||||
t.Fatalf("save session: %v", err)
|
||||
}
|
||||
|
||||
sessionData, err := service.loadSessionFile("session-1")
|
||||
if err != nil {
|
||||
t.Fatalf("load session file: %v", err)
|
||||
}
|
||||
if sessionData.ProviderKey != "cursor-provider" {
|
||||
t.Fatalf("expected provider key to be preserved, got %q", sessionData.ProviderKey)
|
||||
}
|
||||
if string(sessionData.ProviderState) != `{"agentId":"bc-1","lastRunId":"run-1"}` {
|
||||
t.Fatalf("expected provider state to be preserved, got %s", string(sessionData.ProviderState))
|
||||
}
|
||||
var providerMessages []ai.Message
|
||||
if err := json.Unmarshal(sessionData.ProviderMessages, &providerMessages); err != nil {
|
||||
t.Fatalf("unmarshal provider messages: %v", err)
|
||||
}
|
||||
if len(providerMessages) != 1 || providerMessages[0].Content != "hello" {
|
||||
t.Fatalf("unexpected provider messages: %#v", providerMessages)
|
||||
}
|
||||
}
|
||||
|
||||
func TestAIChatSendInSession_ReusesCursorProviderStateAndPersistsFollowUpRuns(t *testing.T) {
|
||||
var (
|
||||
createAgentCalls int
|
||||
createRunCalls int
|
||||
createRunPrompt string
|
||||
)
|
||||
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
switch {
|
||||
case r.Method == http.MethodPost && r.URL.Path == "/v1/agents":
|
||||
createAgentCalls++
|
||||
var body struct {
|
||||
Prompt struct {
|
||||
Text string `json:"text"`
|
||||
} `json:"prompt"`
|
||||
}
|
||||
if err := json.NewDecoder(r.Body).Decode(&body); err != nil {
|
||||
t.Fatalf("decode create agent body: %v", err)
|
||||
}
|
||||
if body.Prompt.Text == "" {
|
||||
t.Fatalf("expected first prompt text")
|
||||
}
|
||||
_ = json.NewEncoder(w).Encode(map[string]any{
|
||||
"agent": map[string]any{"id": "bc-1"},
|
||||
"run": map[string]any{"id": "run-1", "agentId": "bc-1"},
|
||||
})
|
||||
case r.Method == http.MethodGet && r.URL.Path == "/v1/agents/bc-1/runs/run-1":
|
||||
_ = json.NewEncoder(w).Encode(map[string]any{
|
||||
"id": "run-1",
|
||||
"agentId": "bc-1",
|
||||
"status": "FINISHED",
|
||||
"result": "first answer",
|
||||
"durationMs": 100,
|
||||
})
|
||||
case r.Method == http.MethodPost && r.URL.Path == "/v1/agents/bc-1/runs":
|
||||
createRunCalls++
|
||||
var body struct {
|
||||
Prompt struct {
|
||||
Text string `json:"text"`
|
||||
} `json:"prompt"`
|
||||
}
|
||||
if err := json.NewDecoder(r.Body).Decode(&body); err != nil {
|
||||
t.Fatalf("decode follow-up run body: %v", err)
|
||||
}
|
||||
createRunPrompt = body.Prompt.Text
|
||||
_ = json.NewEncoder(w).Encode(map[string]any{
|
||||
"run": map[string]any{"id": "run-2"},
|
||||
})
|
||||
case r.Method == http.MethodGet && r.URL.Path == "/v1/agents/bc-1/runs/run-2":
|
||||
_ = json.NewEncoder(w).Encode(map[string]any{
|
||||
"id": "run-2",
|
||||
"agentId": "bc-1",
|
||||
"status": "FINISHED",
|
||||
"result": "second answer",
|
||||
"durationMs": 120,
|
||||
})
|
||||
default:
|
||||
http.NotFound(w, r)
|
||||
}
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
service := NewService()
|
||||
service.configDir = t.TempDir()
|
||||
service.providers = []ai.ProviderConfig{
|
||||
{
|
||||
ID: "provider-cursor",
|
||||
Type: "custom",
|
||||
APIFormat: "cursor-agent",
|
||||
BaseURL: server.URL + "/v1",
|
||||
APIKey: "cursor-key",
|
||||
},
|
||||
}
|
||||
service.activeProvider = "provider-cursor"
|
||||
|
||||
firstResult := service.AIChatSendInSession("session-1", []ai.Message{
|
||||
{Role: "user", Content: "hello"},
|
||||
}, nil)
|
||||
if firstResult["success"] != true {
|
||||
t.Fatalf("expected first send to succeed, got %#v", firstResult)
|
||||
}
|
||||
secondResult := service.AIChatSendInSession("session-1", []ai.Message{
|
||||
{Role: "user", Content: "hello"},
|
||||
{Role: "assistant", Content: "first answer"},
|
||||
{Role: "user", Content: "next"},
|
||||
}, nil)
|
||||
if secondResult["success"] != true {
|
||||
t.Fatalf("expected second send to succeed, got %#v", secondResult)
|
||||
}
|
||||
|
||||
if createAgentCalls != 1 {
|
||||
t.Fatalf("expected exactly one create-agent call, got %d", createAgentCalls)
|
||||
}
|
||||
if createRunCalls != 1 {
|
||||
t.Fatalf("expected exactly one follow-up run call, got %d", createRunCalls)
|
||||
}
|
||||
if createRunPrompt != "next" {
|
||||
t.Fatalf("expected follow-up run to send only delta message, got %q", createRunPrompt)
|
||||
}
|
||||
|
||||
sessionData, err := service.loadSessionFile("session-1")
|
||||
if err != nil {
|
||||
t.Fatalf("load session file: %v", err)
|
||||
}
|
||||
if string(sessionData.ProviderState) != `{"agentId":"bc-1","lastRunId":"run-2"}` {
|
||||
t.Fatalf("unexpected provider state: %s", string(sessionData.ProviderState))
|
||||
}
|
||||
var providerMessages []ai.Message
|
||||
if err := json.Unmarshal(sessionData.ProviderMessages, &providerMessages); err != nil {
|
||||
t.Fatalf("unmarshal provider messages: %v", err)
|
||||
}
|
||||
if len(providerMessages) != 4 || providerMessages[3].Content != "second answer" {
|
||||
t.Fatalf("unexpected provider messages: %#v", providerMessages)
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user