feat(ai): 补齐 Cursor 与 CodeBuddy 会话态聊天链路

- 新增 SessionChatProvider 接口,补齐非流式对话的会话态复用能力
- 为 Cursor Agent 和 CodeBuddy CLI 同步实现流式与非流式会话续接及状态持久化
- CustomProvider 补充会话态透传,统一 custom provider 的会话复用行为
- Service 新增 AIChatSendInSession,聊天主链路非流式回退改走带 session 的发送接口
- 保留原 AIChatSend 无状态语义,避免标题生成和记忆压缩污染主会话上下文
- 补充前后端定向测试,覆盖会话恢复、续接发送和前端回退分流
This commit is contained in:
Syngnat
2026-06-18 13:35:08 +08:00
parent b588235b62
commit 06dd9507ee
12 changed files with 1392 additions and 137 deletions

View File

@@ -24,6 +24,10 @@ type CodeBuddyCLIProvider struct {
config ai.ProviderConfig
}
type codebuddySessionState struct {
SessionID string `json:"sessionId,omitempty"`
}
// NewCodeBuddyCLIProvider 创建 CodeBuddyCLIProvider 实例。
func NewCodeBuddyCLIProvider(config ai.ProviderConfig) (Provider, error) {
return &CodeBuddyCLIProvider{config: config}, nil
@@ -42,8 +46,13 @@ func (p *CodeBuddyCLIProvider) Validate() error {
}
func (p *CodeBuddyCLIProvider) Chat(ctx context.Context, req ai.ChatRequest) (*ai.ChatResponse, error) {
resp, _, err := p.ChatWithState(ctx, nil, req)
return resp, err
}
func (p *CodeBuddyCLIProvider) ChatWithState(ctx context.Context, state json.RawMessage, req ai.ChatRequest) (*ai.ChatResponse, json.RawMessage, error) {
if err := p.Validate(); err != nil {
return nil, err
return nil, nil, err
}
ctx, cancel := ensureClaudeCLITimeout(ctx, codebuddyCLIRequestTimeout)
@@ -51,18 +60,25 @@ func (p *CodeBuddyCLIProvider) Chat(ctx context.Context, req ai.ChatRequest) (*a
commandName, err := resolveCodeBuddyCLICommand(codebuddyLookPath)
if err != nil {
return nil, err
return nil, nil, err
}
sessionState, err := parseCodeBuddySessionState(state)
if err != nil {
return nil, nil, err
}
prompt := buildPrompt(req.Messages)
args := []string{"-p", prompt, "--output-format", "json", "--no-session-persistence"}
args := []string{"-p", prompt, "--output-format", "json", "--enable-session-tracking"}
if strings.TrimSpace(p.config.Model) != "" {
args = append(args, "--model", strings.TrimSpace(p.config.Model))
}
if strings.TrimSpace(sessionState.SessionID) != "" {
args = append(args, "--resume", strings.TrimSpace(sessionState.SessionID))
}
cmd := codebuddyCommandContext(ctx, commandName, args...)
if err := p.setEnv(cmd); err != nil {
return nil, err
return nil, nil, err
}
requestLog := logAIUpstreamRequestStart(
@@ -80,27 +96,55 @@ func (p *CodeBuddyCLIProvider) Chat(ctx context.Context, req ai.ChatRequest) (*a
if err != nil {
if isClaudeCLITimeout(ctx, err) {
requestErr = fmt.Errorf("CodeBuddy CLI 执行超时(%s当前登录态、Base URL 或 API Key 可能没有返回有效响应", codebuddyCLIRequestTimeout)
return nil, requestErr
return nil, nil, requestErr
}
if exitErr, ok := err.(*exec.ExitError); ok {
requestErr = fmt.Errorf("CodeBuddy CLI 执行失败: %s", string(exitErr.Stderr))
return nil, requestErr
return nil, nil, requestErr
}
requestErr = fmt.Errorf("CodeBuddy CLI 执行失败: %w", err)
return nil, requestErr
return nil, nil, requestErr
}
resp, parseErr := parseCodeBuddyCLIChatOutput(output)
resp, nextSessionID, parseErr := parseCodeBuddyCLIChatOutput(output)
if parseErr != nil {
requestErr = parseErr
return nil, requestErr
return nil, nil, requestErr
}
return resp, nil
if strings.TrimSpace(nextSessionID) == "" {
nextSessionID = strings.TrimSpace(sessionState.SessionID)
}
nextState, err := marshalCodeBuddySessionState(nextSessionID)
if err != nil {
return nil, nil, err
}
return resp, nextState, nil
}
func (p *CodeBuddyCLIProvider) ChatStream(ctx context.Context, req ai.ChatRequest, callback func(ai.StreamChunk)) error {
_, err := p.ChatStreamWithState(ctx, nil, req, callback)
return err
}
func (p *CodeBuddyCLIProvider) ChatStreamWithState(ctx context.Context, state json.RawMessage, req ai.ChatRequest, callback func(ai.StreamChunk)) (json.RawMessage, error) {
sessionState, err := parseCodeBuddySessionState(state)
if err != nil {
return nil, err
}
sessionID, err := p.chatStreamWithSession(ctx, strings.TrimSpace(sessionState.SessionID), req, callback)
if err != nil {
return nil, err
}
if strings.TrimSpace(sessionID) == "" {
sessionID = strings.TrimSpace(sessionState.SessionID)
}
return marshalCodeBuddySessionState(sessionID)
}
func (p *CodeBuddyCLIProvider) chatStreamWithSession(ctx context.Context, resumeSessionID string, req ai.ChatRequest, callback func(ai.StreamChunk)) (string, error) {
if err := p.Validate(); err != nil {
return err
return "", err
}
ctx, cancel := ensureClaudeCLITimeout(ctx, codebuddyCLIRequestTimeout)
@@ -108,18 +152,21 @@ func (p *CodeBuddyCLIProvider) ChatStream(ctx context.Context, req ai.ChatReques
commandName, err := resolveCodeBuddyCLICommand(codebuddyLookPath)
if err != nil {
return err
return "", err
}
prompt := buildPrompt(req.Messages)
args := []string{"-p", prompt, "--output-format", "stream-json", "--verbose", "--include-partial-messages", "--no-session-persistence"}
args := []string{"-p", prompt, "--output-format", "stream-json", "--verbose", "--include-partial-messages", "--enable-session-tracking"}
if strings.TrimSpace(p.config.Model) != "" {
args = append(args, "--model", strings.TrimSpace(p.config.Model))
}
if strings.TrimSpace(resumeSessionID) != "" {
args = append(args, "--resume", strings.TrimSpace(resumeSessionID))
}
cmd := codebuddyCommandContext(ctx, commandName, args...)
if err := p.setEnv(cmd); err != nil {
return err
return "", err
}
requestLog := logAIUpstreamRequestStart(
@@ -138,7 +185,7 @@ func (p *CodeBuddyCLIProvider) ChatStream(ctx context.Context, req ai.ChatReques
stdout, err := cmd.StdoutPipe()
if err != nil {
requestErr = fmt.Errorf("创建 stdout 管道失败: %w", err)
return requestErr
return "", requestErr
}
var stderrBuf bytes.Buffer
@@ -146,7 +193,7 @@ func (p *CodeBuddyCLIProvider) ChatStream(ctx context.Context, req ai.ChatReques
if err := cmd.Start(); err != nil {
requestErr = fmt.Errorf("启动 CodeBuddy CLI 失败: %w", err)
return requestErr
return "", requestErr
}
if cmd.Process != nil {
@@ -155,6 +202,7 @@ func (p *CodeBuddyCLIProvider) ChatStream(ctx context.Context, req ai.ChatReques
scanner := bufio.NewScanner(stdout)
scanner.Buffer(make([]byte, 64*1024), 1024*1024)
currentSessionID := strings.TrimSpace(resumeSessionID)
for scanner.Scan() {
line := scanner.Text()
@@ -167,6 +215,9 @@ func (p *CodeBuddyCLIProvider) ChatStream(ctx context.Context, req ai.ChatReques
logger.Warnf("CodeBuddyCLI 忽略非 JSON 输出requestId=%s line=%s", requestLog.id, RedactAIUpstreamLogText(line))
continue
}
if strings.TrimSpace(event.SessionID) != "" {
currentSessionID = strings.TrimSpace(event.SessionID)
}
switch event.Type {
case "system":
@@ -178,7 +229,7 @@ func (p *CodeBuddyCLIProvider) ChatStream(ctx context.Context, req ai.ChatReques
_ = cmd.Process.Kill()
}
_ = cmd.Wait()
return nil
return "", nil
}
}
case "assistant":
@@ -186,7 +237,7 @@ func (p *CodeBuddyCLIProvider) ChatStream(ctx context.Context, req ai.ChatReques
callback(ai.StreamChunk{Error: errMsg, Done: true})
requestErr = fmt.Errorf("CodeBuddy CLI 返回错误: %s", errMsg)
_ = cmd.Wait()
return nil
return "", nil
}
if event.Message.Content != nil {
for _, block := range event.Message.Content {
@@ -208,17 +259,17 @@ func (p *CodeBuddyCLIProvider) ChatStream(ctx context.Context, req ai.ChatReques
callback(ai.StreamChunk{Error: errMsg, Done: true})
requestErr = fmt.Errorf("CodeBuddy CLI 返回错误: %s", errMsg)
_ = cmd.Wait()
return nil
return "", nil
}
callback(ai.StreamChunk{Done: true})
_ = cmd.Wait()
return nil
return currentSessionID, nil
case "error":
errMsg, _ := extractCodeBuddyCLIEventError(event)
callback(ai.StreamChunk{Error: errMsg, Done: true})
requestErr = fmt.Errorf("CodeBuddy CLI 返回错误: %s", errMsg)
_ = cmd.Wait()
return nil
return "", nil
}
}
@@ -231,7 +282,7 @@ func (p *CodeBuddyCLIProvider) ChatStream(ctx context.Context, req ai.ChatReques
Error: requestErr.Error(),
Done: true,
})
return nil
return "", nil
}
if waitErr != nil {
@@ -241,11 +292,38 @@ func (p *CodeBuddyCLIProvider) ChatStream(ctx context.Context, req ai.ChatReques
}
requestErr = fmt.Errorf("%s", errMsg)
callback(ai.StreamChunk{Error: errMsg, Done: true})
return nil
return "", nil
}
callback(ai.StreamChunk{Done: true})
return nil
return currentSessionID, nil
}
func parseCodeBuddySessionState(state json.RawMessage) (codebuddySessionState, error) {
trimmed := bytes.TrimSpace(state)
if len(trimmed) == 0 {
return codebuddySessionState{}, nil
}
var sessionState codebuddySessionState
if err := json.Unmarshal(trimmed, &sessionState); err != nil {
return codebuddySessionState{}, fmt.Errorf("解析 CodeBuddy 会话状态失败: %w", err)
}
sessionState.SessionID = strings.TrimSpace(sessionState.SessionID)
return sessionState, nil
}
func marshalCodeBuddySessionState(sessionID string) (json.RawMessage, error) {
sessionID = strings.TrimSpace(sessionID)
if sessionID == "" {
return nil, nil
}
payload, err := json.Marshal(codebuddySessionState{SessionID: sessionID})
if err != nil {
return nil, fmt.Errorf("序列化 CodeBuddy 会话状态失败: %w", err)
}
return json.RawMessage(payload), nil
}
func resolveCodeBuddyCLICommand(lookPath func(string) (string, error)) (string, error) {
@@ -280,10 +358,10 @@ func buildCodeBuddyCLIRequestLogBody(outputFormat string, commandName string, ar
}
}
func parseCodeBuddyCLIChatOutput(output []byte) (*ai.ChatResponse, error) {
func parseCodeBuddyCLIChatOutput(output []byte) (*ai.ChatResponse, string, error) {
trimmed := bytes.TrimSpace(output)
if len(trimmed) == 0 {
return &ai.ChatResponse{}, nil
return &ai.ChatResponse{}, "", nil
}
var events []cliStreamEvent
@@ -296,20 +374,24 @@ func parseCodeBuddyCLIChatOutput(output []byte) (*ai.ChatResponse, error) {
return buildCodeBuddyCLIResponseFromEvents([]cliStreamEvent{event})
}
return &ai.ChatResponse{Content: strings.TrimSpace(string(output))}, nil
return &ai.ChatResponse{Content: strings.TrimSpace(string(output))}, "", nil
}
func buildCodeBuddyCLIResponseFromEvents(events []cliStreamEvent) (*ai.ChatResponse, error) {
func buildCodeBuddyCLIResponseFromEvents(events []cliStreamEvent) (*ai.ChatResponse, string, error) {
parts := make([]string, 0, len(events))
resultText := ""
sessionID := ""
for _, event := range events {
if errMsg, hasError := extractCodeBuddyCLIEventError(event); hasError {
return nil, fmt.Errorf("CodeBuddy CLI 返回错误: %s", errMsg)
return nil, "", fmt.Errorf("CodeBuddy CLI 返回错误: %s", errMsg)
}
if strings.TrimSpace(event.Result) != "" {
resultText = strings.TrimSpace(event.Result)
}
if strings.TrimSpace(event.SessionID) != "" {
sessionID = strings.TrimSpace(event.SessionID)
}
for _, block := range event.Message.Content {
if block.Type == "text" && strings.TrimSpace(block.Text) != "" {
parts = append(parts, block.Text)
@@ -318,12 +400,12 @@ func buildCodeBuddyCLIResponseFromEvents(events []cliStreamEvent) (*ai.ChatRespo
}
if resultText != "" {
return &ai.ChatResponse{Content: resultText}, nil
return &ai.ChatResponse{Content: resultText}, sessionID, nil
}
if len(parts) > 0 {
return &ai.ChatResponse{Content: strings.Join(parts, "")}, nil
return &ai.ChatResponse{Content: strings.Join(parts, "")}, sessionID, nil
}
return &ai.ChatResponse{}, nil
return &ai.ChatResponse{}, sessionID, nil
}
func (p *CodeBuddyCLIProvider) setEnv(cmd *exec.Cmd) error {

View File

@@ -2,6 +2,7 @@ package provider
import (
"context"
"encoding/json"
"errors"
"os"
"os/exec"
@@ -79,6 +80,191 @@ func TestCodeBuddyCLIProvider_ChatParsesJSONEventArray(t *testing.T) {
}
}
func TestCodeBuddyCLIProviderChatWithState_StartsTrackedSession(t *testing.T) {
fakeCodeBuddy := writeFakeCodeBuddyScript(t, "#!/bin/sh\necho '[{\"type\":\"assistant\",\"session_id\":\"session-new\",\"message\":{\"content\":[{\"type\":\"text\",\"text\":\"hello \"}]}},{\"type\":\"result\",\"subtype\":\"success\",\"is_error\":false,\"result\":\"hello world\",\"session_id\":\"session-new\"}]'\n")
var capturedArgs []string
restore := overrideCodeBuddyCLIForTestWithCapture(t, fakeCodeBuddy, func(args []string) {
capturedArgs = append([]string(nil), args...)
})
defer restore()
providerInstance, err := NewCodeBuddyCLIProvider(ai.ProviderConfig{
APIKey: "cb-test",
Model: "deepseek-v3",
})
if err != nil {
t.Fatalf("unexpected provider error: %v", err)
}
resp, nextState, err := providerInstance.(SessionChatProvider).ChatWithState(
context.Background(),
nil,
ai.ChatRequest{
Messages: []ai.Message{{Role: "user", Content: "ping"}},
},
)
if err != nil {
t.Fatalf("expected chat with state to succeed, got %v", err)
}
if resp == nil || resp.Content != "hello world" {
t.Fatalf("unexpected response: %#v", resp)
}
if string(nextState) != `{"sessionId":"session-new"}` {
t.Fatalf("expected new session state, got %s", string(nextState))
}
if !hasArg(capturedArgs, "--enable-session-tracking") {
t.Fatalf("expected session tracking flag, got args %#v", capturedArgs)
}
if hasArg(capturedArgs, "--no-session-persistence") {
t.Fatalf("did not expect no-session-persistence flag, got args %#v", capturedArgs)
}
if hasArg(capturedArgs, "--resume") {
t.Fatalf("did not expect resume flag for first session, got args %#v", capturedArgs)
}
if !hasArgSequence(capturedArgs, "--model", "deepseek-v3") {
t.Fatalf("expected model flag to be preserved, got args %#v", capturedArgs)
}
}
func TestCodeBuddyCLIProviderChatWithState_ResumesExistingSession(t *testing.T) {
fakeCodeBuddy := writeFakeCodeBuddyScript(t, "#!/bin/sh\necho '[{\"type\":\"assistant\",\"message\":{\"content\":[{\"type\":\"text\",\"text\":\"continued\"}]}}]'\n")
var capturedArgs []string
restore := overrideCodeBuddyCLIForTestWithCapture(t, fakeCodeBuddy, func(args []string) {
capturedArgs = append([]string(nil), args...)
})
defer restore()
providerInstance, err := NewCodeBuddyCLIProvider(ai.ProviderConfig{
APIKey: "cb-test",
})
if err != nil {
t.Fatalf("unexpected provider error: %v", err)
}
resp, nextState, err := providerInstance.(SessionChatProvider).ChatWithState(
context.Background(),
json.RawMessage(`{"sessionId":"session-existing"}`),
ai.ChatRequest{
Messages: []ai.Message{{Role: "user", Content: "ping again"}},
},
)
if err != nil {
t.Fatalf("expected resumed chat with state to succeed, got %v", err)
}
if resp == nil || resp.Content != "continued" {
t.Fatalf("unexpected response: %#v", resp)
}
if string(nextState) != `{"sessionId":"session-existing"}` {
t.Fatalf("expected existing session state to be preserved, got %s", string(nextState))
}
if !hasArgSequence(capturedArgs, "--resume", "session-existing") {
t.Fatalf("expected resume args, got %#v", capturedArgs)
}
if !hasArg(capturedArgs, "--enable-session-tracking") {
t.Fatalf("expected session tracking flag, got args %#v", capturedArgs)
}
}
func TestCodeBuddyCLIProviderChatStreamWithState_StartsTrackedSession(t *testing.T) {
fakeCodeBuddy := writeFakeCodeBuddyScript(t, "#!/bin/sh\nprintf '%s\\n' '{\"type\":\"system\",\"session_id\":\"session-new\"}' '{\"type\":\"assistant\",\"message\":{\"content\":[{\"type\":\"text\",\"text\":\"hello from codebuddy\"}]}}' '{\"type\":\"result\",\"subtype\":\"success\",\"is_error\":false,\"result\":\"hello from codebuddy\",\"session_id\":\"session-new\"}'\n")
var capturedArgs []string
restore := overrideCodeBuddyCLIForTestWithCapture(t, fakeCodeBuddy, func(args []string) {
capturedArgs = append([]string(nil), args...)
})
defer restore()
providerInstance, err := NewCodeBuddyCLIProvider(ai.ProviderConfig{
APIKey: "cb-test",
Model: "deepseek-v3",
})
if err != nil {
t.Fatalf("unexpected provider error: %v", err)
}
var chunks []ai.StreamChunk
nextState, err := providerInstance.(SessionStreamProvider).ChatStreamWithState(
context.Background(),
nil,
ai.ChatRequest{
Messages: []ai.Message{{Role: "user", Content: "ping"}},
},
func(chunk ai.StreamChunk) {
chunks = append(chunks, chunk)
},
)
if err != nil {
t.Fatalf("expected chat stream with state to succeed, got %v", err)
}
if string(nextState) != `{"sessionId":"session-new"}` {
t.Fatalf("expected new session state, got %s", string(nextState))
}
if len(chunks) < 2 || chunks[0].Content != "hello from codebuddy" || !chunks[len(chunks)-1].Done {
t.Fatalf("unexpected stream chunks: %#v", chunks)
}
if !hasArg(capturedArgs, "--enable-session-tracking") {
t.Fatalf("expected session tracking flag, got args %#v", capturedArgs)
}
if hasArg(capturedArgs, "--no-session-persistence") {
t.Fatalf("did not expect no-session-persistence flag, got args %#v", capturedArgs)
}
if hasArg(capturedArgs, "--resume") {
t.Fatalf("did not expect resume flag for first session, got args %#v", capturedArgs)
}
if !hasArgSequence(capturedArgs, "--model", "deepseek-v3") {
t.Fatalf("expected model flag to be preserved, got args %#v", capturedArgs)
}
}
func TestCodeBuddyCLIProviderChatStreamWithState_ResumesExistingSessionWithoutDroppingState(t *testing.T) {
fakeCodeBuddy := writeFakeCodeBuddyScript(t, "#!/bin/sh\nprintf '%s\\n' '{\"type\":\"assistant\",\"message\":{\"content\":[{\"type\":\"text\",\"text\":\"continued\"}]}}' '{\"type\":\"result\",\"subtype\":\"success\",\"is_error\":false,\"result\":\"continued\"}'\n")
var capturedArgs []string
restore := overrideCodeBuddyCLIForTestWithCapture(t, fakeCodeBuddy, func(args []string) {
capturedArgs = append([]string(nil), args...)
})
defer restore()
providerInstance, err := NewCodeBuddyCLIProvider(ai.ProviderConfig{
APIKey: "cb-test",
})
if err != nil {
t.Fatalf("unexpected provider error: %v", err)
}
var chunks []ai.StreamChunk
nextState, err := providerInstance.(SessionStreamProvider).ChatStreamWithState(
context.Background(),
json.RawMessage(`{"sessionId":"session-existing"}`),
ai.ChatRequest{
Messages: []ai.Message{{Role: "user", Content: "ping again"}},
},
func(chunk ai.StreamChunk) {
chunks = append(chunks, chunk)
},
)
if err != nil {
t.Fatalf("expected resumed chat stream to succeed, got %v", err)
}
if string(nextState) != `{"sessionId":"session-existing"}` {
t.Fatalf("expected existing session state to be preserved, got %s", string(nextState))
}
if len(chunks) < 2 || chunks[0].Content != "continued" || !chunks[len(chunks)-1].Done {
t.Fatalf("unexpected stream chunks: %#v", chunks)
}
if !hasArgSequence(capturedArgs, "--resume", "session-existing") {
t.Fatalf("expected resume args, got %#v", capturedArgs)
}
if !hasArg(capturedArgs, "--enable-session-tracking") {
t.Fatalf("expected session tracking flag, got args %#v", capturedArgs)
}
if hasArg(capturedArgs, "--no-session-persistence") {
t.Fatalf("did not expect no-session-persistence flag, got args %#v", capturedArgs)
}
}
func writeFakeCodeBuddyScript(t *testing.T, content string) string {
t.Helper()
dir := t.TempDir()
@@ -138,3 +324,54 @@ func overrideCodeBuddyCLIForTest(t *testing.T, fakeCodeBuddyPath string) func()
_ = os.Setenv("PATH", originalPath)
}
}
func overrideCodeBuddyCLIForTestWithCapture(t *testing.T, fakeCodeBuddyPath string, capture func(args []string)) func() {
t.Helper()
originalLookPath := codebuddyLookPath
originalCommandContext := codebuddyCommandContext
codebuddyLookPath = func(name string) (string, error) {
if name == "codebuddy" || name == "cbc" {
return fakeCodeBuddyPath, nil
}
return originalLookPath(name)
}
codebuddyCommandContext = func(ctx context.Context, name string, args ...string) *exec.Cmd {
if name == "codebuddy" || name == "cbc" {
if capture != nil {
capture(args)
}
return exec.CommandContext(ctx, fakeCodeBuddyPath, args...)
}
return originalCommandContext(ctx, name, args...)
}
originalPath := os.Getenv("PATH")
if err := os.Setenv("PATH", filepath.Dir(fakeCodeBuddyPath)+string(os.PathListSeparator)+originalPath); err != nil {
t.Fatalf("failed to override PATH: %v", err)
}
return func() {
codebuddyLookPath = originalLookPath
codebuddyCommandContext = originalCommandContext
_ = os.Setenv("PATH", originalPath)
}
}
func hasArg(args []string, target string) bool {
for _, arg := range args {
if arg == target {
return true
}
}
return false
}
func hasArgSequence(args []string, key string, value string) bool {
for index := 0; index < len(args)-1; index++ {
if args[index] == key && args[index+1] == value {
return true
}
}
return false
}

View File

@@ -22,13 +22,24 @@ const (
)
// CursorAgentProvider 通过 Cursor Cloud Agents API 发起对话。
// 当前实现为无状态适配:每次请求都创建一个新的 agent再消费本次 run 的结果
// 支持基于 session state 复用已有 agent并对 follow-up runs 继续追加上下文
type CursorAgentProvider struct {
config ai.ProviderConfig
baseURL string
client *http.Client
}
type cursorSessionState struct {
AgentID string `json:"agentId,omitempty"`
LastRunID string `json:"lastRunId,omitempty"`
}
type cursorImageInput struct {
Data string `json:"data,omitempty"`
URL string `json:"url,omitempty"`
MimeType string `json:"mimeType,omitempty"`
}
// NewCursorAgentProvider 创建 Cursor Agent Provider。
func NewCursorAgentProvider(config ai.ProviderConfig) (Provider, error) {
normalized := config
@@ -134,7 +145,8 @@ func normalizeCursorAPIPath(path string) string {
}
type cursorPrompt struct {
Text string `json:"text"`
Text string `json:"text"`
Images []cursorImageInput `json:"images,omitempty"`
}
type cursorModelSelection struct {
@@ -146,6 +158,10 @@ type cursorCreateAgentRequest struct {
Model *cursorModelSelection `json:"model,omitempty"`
}
type cursorCreateRunRequest struct {
Prompt cursorPrompt `json:"prompt"`
}
type cursorCreateAgentResponse struct {
Agent struct {
ID string `json:"id"`
@@ -181,38 +197,85 @@ type cursorResultEvent struct {
}
func (p *CursorAgentProvider) Chat(ctx context.Context, req ai.ChatRequest) (*ai.ChatResponse, error) {
resp, _, err := p.ChatWithState(ctx, nil, req)
return resp, err
}
func (p *CursorAgentProvider) ChatWithState(ctx context.Context, state json.RawMessage, req ai.ChatRequest) (*ai.ChatResponse, json.RawMessage, error) {
if err := p.Validate(); err != nil {
return nil, err
return nil, nil, err
}
agentID, runID, err := p.createAgent(ctx, req)
sessionState, err := parseCursorSessionState(state)
if err != nil {
return nil, err
return nil, nil, err
}
agentID := strings.TrimSpace(sessionState.AgentID)
runID := ""
if agentID == "" {
agentID, runID, err = p.createAgent(ctx, req)
if err != nil {
return nil, nil, err
}
} else {
runID, err = p.createRun(ctx, agentID, req)
if err != nil {
return nil, nil, err
}
}
run, err := p.waitForRun(ctx, agentID, runID)
if err != nil {
return nil, err
return nil, nil, err
}
sessionState.AgentID = agentID
sessionState.LastRunID = runID
nextState, err := json.Marshal(sessionState)
if err != nil {
return nil, nil, fmt.Errorf("序列化 Cursor 会话状态失败: %w", err)
}
return &ai.ChatResponse{
Content: strings.TrimSpace(run.Result),
}, nil
}, json.RawMessage(nextState), nil
}
func (p *CursorAgentProvider) ChatStream(ctx context.Context, req ai.ChatRequest, callback func(ai.StreamChunk)) error {
_, err := p.ChatStreamWithState(ctx, nil, req, callback)
return err
}
func (p *CursorAgentProvider) ChatStreamWithState(ctx context.Context, state json.RawMessage, req ai.ChatRequest, callback func(ai.StreamChunk)) (json.RawMessage, error) {
if err := p.Validate(); err != nil {
return err
return nil, err
}
agentID, runID, err := p.createAgent(ctx, req)
sessionState, err := parseCursorSessionState(state)
if err != nil {
return err
return nil, err
}
agentID := strings.TrimSpace(sessionState.AgentID)
runID := ""
if agentID == "" {
agentID, runID, err = p.createAgent(ctx, req)
if err != nil {
return nil, err
}
} else {
runID, err = p.createRun(ctx, agentID, req)
if err != nil {
return nil, err
}
}
sessionState.AgentID = agentID
sessionState.LastRunID = runID
stream, err := p.openRunStream(ctx, agentID, runID)
if err != nil {
return err
return nil, err
}
defer stream.Close()
@@ -314,10 +377,10 @@ func (p *CursorAgentProvider) ChatStream(ctx context.Context, req ai.ChatRequest
currentEventType = ""
currentDataLines = nil
if dispatchErr != nil {
return dispatchErr
return nil, dispatchErr
}
if done {
return nil
return marshalCursorSessionState(sessionState)
}
case strings.HasPrefix(line, "event:"):
currentEventType = strings.TrimSpace(strings.TrimPrefix(line, "event:"))
@@ -327,27 +390,27 @@ func (p *CursorAgentProvider) ChatStream(ctx context.Context, req ai.ChatRequest
}
if err := scanner.Err(); err != nil {
return fmt.Errorf("读取 Cursor 流式响应失败: %w", err)
return nil, fmt.Errorf("读取 Cursor 流式响应失败: %w", err)
}
if len(currentDataLines) > 0 || strings.TrimSpace(currentEventType) != "" {
done, dispatchErr := dispatchEvent(currentEventType, currentDataLines)
if dispatchErr != nil {
return dispatchErr
return nil, dispatchErr
}
if done {
return nil
return marshalCursorSessionState(sessionState)
}
}
if !completedExplicitly {
if !receivedAssistantText && !receivedResultText {
callback(ai.StreamChunk{Error: "未收到任何有效响应内容,请检查 Cursor 配置或模型权限", Done: true})
return nil
return marshalCursorSessionState(sessionState)
}
callback(ai.StreamChunk{Done: true})
}
return nil
return marshalCursorSessionState(sessionState)
}
func (p *CursorAgentProvider) createAgent(ctx context.Context, req ai.ChatRequest) (string, string, error) {
@@ -370,15 +433,13 @@ func (p *CursorAgentProvider) createAgent(ctx context.Context, req ai.ChatReques
}
func buildCursorCreateAgentRequest(req ai.ChatRequest, model string) (cursorCreateAgentRequest, error) {
prompt, err := buildCursorPrompt(req.Messages)
prompt, err := buildCursorPromptInput(req.Messages)
if err != nil {
return cursorCreateAgentRequest{}, err
}
requestBody := cursorCreateAgentRequest{
Prompt: cursorPrompt{
Text: prompt,
},
Prompt: prompt,
}
if trimmedModel := strings.TrimSpace(model); trimmedModel != "" {
@@ -389,18 +450,59 @@ func buildCursorCreateAgentRequest(req ai.ChatRequest, model string) (cursorCrea
}
func buildCursorPrompt(messages []ai.Message) (string, error) {
requestMessages := messages
if requestMessagesContainImages(messages) {
requestMessages = stripImagesFromRequestMessages(messages)
prompt := strings.TrimSpace(buildPrompt(messages))
if prompt == "" && requestMessagesContainImages(messages) {
return "请结合这些图片继续分析并回答。", nil
}
prompt := strings.TrimSpace(buildPrompt(requestMessages))
if prompt == "" {
return "", fmt.Errorf("请求内容不能为空")
}
return prompt, nil
}
func buildCursorPromptInput(messages []ai.Message) (cursorPrompt, error) {
text, err := buildCursorPrompt(messages)
if err != nil {
return cursorPrompt{}, err
}
images, err := buildCursorImageInputs(messages)
if err != nil {
return cursorPrompt{}, err
}
return cursorPrompt{
Text: text,
Images: images,
}, nil
}
func buildCursorImageInputs(messages []ai.Message) ([]cursorImageInput, error) {
images := make([]cursorImageInput, 0)
for _, message := range messages {
for _, img := range message.Images {
trimmed := strings.TrimSpace(img)
if trimmed == "" {
continue
}
if strings.HasPrefix(trimmed, "http://") || strings.HasPrefix(trimmed, "https://") {
images = append(images, cursorImageInput{URL: trimmed})
continue
}
mimeType, rawBase64, err := ParseDataURI(trimmed)
if err != nil {
return nil, fmt.Errorf("解析图片数据失败: %w", err)
}
images = append(images, cursorImageInput{
Data: rawBase64,
MimeType: mimeType,
})
}
}
if len(images) > 5 {
return nil, fmt.Errorf("Cursor 最多支持 5 张图片,当前请求包含 %d 张", len(images))
}
return images, nil
}
func (p *CursorAgentProvider) waitForRun(ctx context.Context, agentID string, runID string) (*cursorRunResponse, error) {
ticker := time.NewTicker(cursorRunPollInterval)
defer ticker.Stop()
@@ -455,6 +557,31 @@ func (p *CursorAgentProvider) getRun(ctx context.Context, agentID string, runID
return &responseBody, nil
}
func (p *CursorAgentProvider) createRun(ctx context.Context, agentID string, req ai.ChatRequest) (string, error) {
prompt, err := buildCursorPromptInput(req.Messages)
if err != nil {
return "", err
}
requestBody := cursorCreateRunRequest{
Prompt: prompt,
}
var responseBody struct {
Run struct {
ID string `json:"id"`
} `json:"run"`
}
if err := p.doJSONRequest(ctx, http.MethodPost, ResolveCursorAPIEndpoint(p.baseURL, fmt.Sprintf("agents/%s/runs", agentID)), requestBody, &responseBody, "application/json"); err != nil {
return "", err
}
runID := strings.TrimSpace(responseBody.Run.ID)
if runID == "" {
return "", fmt.Errorf("Cursor 创建 follow-up run 成功,但未返回有效 runId")
}
return runID, nil
}
func (p *CursorAgentProvider) openRunStream(ctx context.Context, agentID string, runID string) (io.ReadCloser, error) {
endpoint := ResolveCursorAPIEndpoint(p.baseURL, fmt.Sprintf("agents/%s/runs/%s/stream", agentID, runID))
requestLog := logAIUpstreamRequestStart(p.Name(), http.MethodGet, endpoint, nil)
@@ -541,6 +668,28 @@ func (p *CursorAgentProvider) doJSONRequest(ctx context.Context, method string,
return nil
}
func parseCursorSessionState(state json.RawMessage) (cursorSessionState, error) {
if len(state) == 0 {
return cursorSessionState{}, nil
}
var result cursorSessionState
if err := json.Unmarshal(state, &result); err != nil {
return cursorSessionState{}, fmt.Errorf("解析 Cursor 会话状态失败: %w", err)
}
return result, nil
}
func marshalCursorSessionState(state cursorSessionState) (json.RawMessage, error) {
if strings.TrimSpace(state.AgentID) == "" {
return nil, nil
}
bytes, err := json.Marshal(state)
if err != nil {
return nil, fmt.Errorf("序列化 Cursor 会话状态失败: %w", err)
}
return json.RawMessage(bytes), nil
}
func isCursorRunTerminalStatus(status string) bool {
switch strings.ToUpper(strings.TrimSpace(status)) {
case "FINISHED", "ERROR", "CANCELLED", "EXPIRED":

View File

@@ -99,6 +99,85 @@ func TestCursorAgentProviderChat_PollsUntilFinished(t *testing.T) {
}
}
func TestCursorAgentProviderChatWithState_UsesFollowUpRunsAndPreservesAgent(t *testing.T) {
var (
createAgentCalls int32
createRunCalls int32
receivedPrompt string
)
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
switch {
case r.Method == http.MethodPost && r.URL.Path == "/v1/agents":
atomic.AddInt32(&createAgentCalls, 1)
t.Fatalf("expected follow-up request to avoid creating a new agent")
case r.Method == http.MethodPost && r.URL.Path == "/v1/agents/bc-existing/runs":
atomic.AddInt32(&createRunCalls, 1)
var body struct {
Prompt struct {
Text string `json:"text"`
} `json:"prompt"`
}
if err := json.NewDecoder(r.Body).Decode(&body); err != nil {
t.Fatalf("decode follow-up run body: %v", err)
}
receivedPrompt = body.Prompt.Text
_ = json.NewEncoder(w).Encode(map[string]any{
"run": map[string]any{"id": "run-next"},
})
case r.Method == http.MethodGet && r.URL.Path == "/v1/agents/bc-existing/runs/run-next":
_ = json.NewEncoder(w).Encode(map[string]any{
"id": "run-next",
"agentId": "bc-existing",
"status": "FINISHED",
"result": "done from follow-up",
"durationMs": 456,
})
default:
http.NotFound(w, r)
}
}))
defer server.Close()
providerInstance, err := NewCursorAgentProvider(ai.ProviderConfig{
Name: "Cursor",
BaseURL: server.URL + "/v1",
APIKey: "cursor-key",
})
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
resp, nextState, err := providerInstance.(SessionChatProvider).ChatWithState(
context.Background(),
json.RawMessage(`{"agentId":"bc-existing","lastRunId":"run-old"}`),
ai.ChatRequest{
Messages: []ai.Message{
{Role: "user", Content: "follow this up"},
},
},
)
if err != nil {
t.Fatalf("follow-up chat failed: %v", err)
}
if atomic.LoadInt32(&createAgentCalls) != 0 {
t.Fatalf("expected no create-agent calls, got %d", createAgentCalls)
}
if atomic.LoadInt32(&createRunCalls) != 1 {
t.Fatalf("expected exactly one follow-up run call, got %d", createRunCalls)
}
if !strings.Contains(receivedPrompt, "follow this up") {
t.Fatalf("expected follow-up prompt text, got %q", receivedPrompt)
}
if resp == nil || resp.Content != "done from follow-up" {
t.Fatalf("unexpected response: %#v", resp)
}
if string(nextState) != `{"agentId":"bc-existing","lastRunId":"run-next"}` {
t.Fatalf("unexpected next session state: %s", string(nextState))
}
}
func TestCursorAgentProviderChatStream_MapsAssistantAndThinkingEvents(t *testing.T) {
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
switch {
@@ -164,3 +243,109 @@ func TestCursorAgentProviderChatStream_MapsAssistantAndThinkingEvents(t *testing
t.Fatalf("expected final done chunk, got %#v", chunks[len(chunks)-1])
}
}
func TestCursorAgentProviderChatStreamWithState_UsesFollowUpRunsAndPreservesAgent(t *testing.T) {
var (
createAgentCalls int32
createRunCalls int32
receivedPrompt string
)
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
switch {
case r.Method == http.MethodPost && r.URL.Path == "/v1/agents":
atomic.AddInt32(&createAgentCalls, 1)
t.Fatalf("expected follow-up request to avoid creating a new agent")
case r.Method == http.MethodPost && r.URL.Path == "/v1/agents/bc-existing/runs":
atomic.AddInt32(&createRunCalls, 1)
var body struct {
Prompt struct {
Text string `json:"text"`
} `json:"prompt"`
}
if err := json.NewDecoder(r.Body).Decode(&body); err != nil {
t.Fatalf("decode follow-up run body: %v", err)
}
receivedPrompt = body.Prompt.Text
_ = json.NewEncoder(w).Encode(map[string]any{
"run": map[string]any{"id": "run-next"},
})
case r.Method == http.MethodGet && r.URL.Path == "/v1/agents/bc-existing/runs/run-next/stream":
w.Header().Set("Content-Type", "text/event-stream")
_, _ = w.Write([]byte("event: assistant\n"))
_, _ = w.Write([]byte("data: {\"text\":\"done\"}\n\n"))
_, _ = w.Write([]byte("event: done\n"))
_, _ = w.Write([]byte("data: {}\n\n"))
default:
http.NotFound(w, r)
}
}))
defer server.Close()
providerInstance, err := NewCursorAgentProvider(ai.ProviderConfig{
Name: "Cursor",
BaseURL: server.URL + "/v1",
APIKey: "cursor-key",
})
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
sessionState := json.RawMessage(`{"agentId":"bc-existing","lastRunId":"run-old"}`)
var chunks []ai.StreamChunk
nextState, err := providerInstance.(SessionStreamProvider).ChatStreamWithState(
context.Background(),
sessionState,
ai.ChatRequest{
Messages: []ai.Message{
{Role: "user", Content: "follow this up"},
},
},
func(chunk ai.StreamChunk) {
chunks = append(chunks, chunk)
},
)
if err != nil {
t.Fatalf("follow-up stream failed: %v", err)
}
if atomic.LoadInt32(&createAgentCalls) != 0 {
t.Fatalf("expected no create-agent calls, got %d", createAgentCalls)
}
if atomic.LoadInt32(&createRunCalls) != 1 {
t.Fatalf("expected exactly one follow-up run call, got %d", createRunCalls)
}
if !strings.Contains(receivedPrompt, "follow this up") {
t.Fatalf("expected follow-up prompt text, got %q", receivedPrompt)
}
if string(nextState) != `{"agentId":"bc-existing","lastRunId":"run-next"}` {
t.Fatalf("unexpected next session state: %s", string(nextState))
}
if len(chunks) == 0 || chunks[0].Content != "done" {
t.Fatalf("expected streamed assistant content, got %#v", chunks)
}
}
func TestCursorAgentProviderCreateAgentRequest_IncludesImageInputs(t *testing.T) {
requestBody, err := buildCursorCreateAgentRequest(ai.ChatRequest{
Messages: []ai.Message{
{
Role: "user",
Content: "look at this",
Images: []string{"data:image/png;base64,aGVsbG8="},
},
},
}, "composer-latest")
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if len(requestBody.Prompt.Images) != 1 {
t.Fatalf("expected one image payload, got %#v", requestBody.Prompt.Images)
}
if requestBody.Prompt.Images[0].Data != "aGVsbG8=" || requestBody.Prompt.Images[0].MimeType != "image/png" {
t.Fatalf("unexpected image payload: %#v", requestBody.Prompt.Images[0])
}
if requestBody.Model == nil || requestBody.Model.ID != "composer-latest" {
t.Fatalf("expected model selection to be preserved, got %#v", requestBody.Model)
}
}

View File

@@ -2,6 +2,7 @@ package provider
import (
"context"
"encoding/json"
"fmt"
"strings"
@@ -75,3 +76,20 @@ func (p *CustomProvider) Chat(ctx context.Context, req ai.ChatRequest) (*ai.Chat
func (p *CustomProvider) ChatStream(ctx context.Context, req ai.ChatRequest, callback func(ai.StreamChunk)) error {
return p.inner.ChatStream(ctx, req, callback)
}
func (p *CustomProvider) ChatWithState(ctx context.Context, state json.RawMessage, req ai.ChatRequest) (*ai.ChatResponse, json.RawMessage, error) {
sessionProvider, ok := p.inner.(SessionChatProvider)
if !ok {
resp, err := p.inner.Chat(ctx, req)
return resp, nil, err
}
return sessionProvider.ChatWithState(ctx, state, req)
}
func (p *CustomProvider) ChatStreamWithState(ctx context.Context, state json.RawMessage, req ai.ChatRequest, callback func(ai.StreamChunk)) (json.RawMessage, error) {
sessionProvider, ok := p.inner.(SessionStreamProvider)
if !ok {
return nil, p.inner.ChatStream(ctx, req, callback)
}
return sessionProvider.ChatStreamWithState(ctx, state, req, callback)
}

View File

@@ -2,6 +2,7 @@ package provider
import (
"context"
"encoding/json"
"GoNavi-Wails/internal/ai"
)
@@ -17,3 +18,24 @@ type Provider interface {
// Validate 校验配置是否有效
Validate() error
}
// SessionStreamProvider 表示支持按会话复用上游状态的流式 Provider。
// state 为 Provider 自己维护的持久化状态;返回值为更新后的状态快照。
type SessionStreamProvider interface {
ChatStreamWithState(
ctx context.Context,
state json.RawMessage,
req ai.ChatRequest,
callback func(ai.StreamChunk),
) (json.RawMessage, error)
}
// SessionChatProvider 表示支持按会话复用上游状态的非流式 Provider。
// state 为 Provider 自己维护的持久化状态;返回值为响应体和更新后的状态快照。
type SessionChatProvider interface {
ChatWithState(
ctx context.Context,
state json.RawMessage,
req ai.ChatRequest,
) (*ai.ChatResponse, json.RawMessage, error)
}

View File

@@ -9,6 +9,7 @@ import (
"net/url"
"os"
"path/filepath"
"reflect"
"strings"
"sync"
"time"
@@ -42,11 +43,18 @@ type Service struct {
secretStore secretstore.SecretStore
localizer *i18n.Localizer
cancelFuncs map[string]context.CancelFunc // 记录每个 session 的 context 取消函数
sessionProviders map[string]aiSessionProviderRuntime
mcpHTTPMu sync.Mutex
mcpHTTP *mcpHTTPServerRuntime
mcpHTTPLast ai.MCPHTTPServerStatus
}
type aiSessionProviderRuntime struct {
ProviderKey string
State json.RawMessage
Messages []ai.Message
}
var miniMaxAnthropicModels = []string{
"MiniMax-M3",
"MiniMax-M2.7",
@@ -131,15 +139,16 @@ func NewServiceWithSecretStore(store secretstore.SecretStore) *Service {
store = secretstore.NewUnavailableStore("secret store unavailable")
}
return &Service{
providers: make([]ai.ProviderConfig, 0),
safetyLevel: ai.PermissionReadOnly,
contextLevel: ai.ContextSchemaOnly,
mcpServers: make([]ai.MCPServerConfig, 0),
skills: make([]ai.SkillConfig, 0),
guard: safety.NewGuard(ai.PermissionReadOnly),
secretStore: store,
localizer: newServiceLocalizer(),
cancelFuncs: make(map[string]context.CancelFunc),
providers: make([]ai.ProviderConfig, 0),
safetyLevel: ai.PermissionReadOnly,
contextLevel: ai.ContextSchemaOnly,
mcpServers: make([]ai.MCPServerConfig, 0),
skills: make([]ai.SkillConfig, 0),
guard: safety.NewGuard(ai.PermissionReadOnly),
secretStore: store,
localizer: newServiceLocalizer(),
cancelFuncs: make(map[string]context.CancelFunc),
sessionProviders: make(map[string]aiSessionProviderRuntime),
}
}
@@ -1150,7 +1159,16 @@ func (s *Service) AISetContextLevel(level string) {
// AIChatSend 非流式发送 AI 对话
func (s *Service) AIChatSend(messages []ai.Message, tools []ai.Tool) map[string]interface{} {
p, err := s.getActiveProvider()
return s.aiChatSend("", messages, tools, false)
}
// AIChatSendInSession 非流式发送 AI 对话,并在支持的 Provider 上复用会话态。
func (s *Service) AIChatSendInSession(sessionID string, messages []ai.Message, tools []ai.Tool) map[string]interface{} {
return s.aiChatSend(sessionID, messages, tools, true)
}
func (s *Service) aiChatSend(sessionID string, messages []ai.Message, tools []ai.Tool, allowSessionReuse bool) map[string]interface{} {
p, config, err := s.getActiveProviderRuntime()
if err != nil {
logger.Error(err, "AIChatSend 获取 Provider 失败messages=%d tools=%d", len(messages), len(tools))
return map[string]interface{}{"success": false, "error": err.Error()}
@@ -1158,14 +1176,62 @@ func (s *Service) AIChatSend(messages []ai.Message, tools []ai.Tool) map[string]
started := time.Now()
providerName := p.Name()
logger.Infof("AIChatSend 开始provider=%s messages=%d tools=%d", providerName, len(messages), len(tools))
resp, err := p.Chat(context.Background(), ai.ChatRequest{Messages: messages, Tools: tools})
logger.Infof("AIChatSend 开始:sessionID=%s provider=%s messages=%d tools=%d sessionReuse=%t", sessionID, providerName, len(messages), len(tools), allowSessionReuse)
requestMessages := cloneAIMessages(messages)
var updatedProviderState json.RawMessage
if allowSessionReuse && strings.TrimSpace(sessionID) != "" {
if sessionAwareProvider, ok := p.(provider.SessionChatProvider); ok {
providerKey := providerSessionKey(config)
providerState, deltaMessages := s.resolveSessionProviderRequest(sessionID, providerKey, messages)
requestMessages = deltaMessages
resp, updatedState, err := sessionAwareProvider.ChatWithState(context.Background(), providerState, ai.ChatRequest{Messages: requestMessages, Tools: tools})
if err != nil {
logger.Warnf("AIChatSend 失败sessionID=%s provider=%s messages=%d tools=%d duration=%s err=%s", sessionID, providerName, len(messages), len(tools), time.Since(started).Round(time.Millisecond), provider.RedactAIUpstreamLogText(err.Error()))
return map[string]interface{}{"success": false, "error": err.Error()}
}
updatedProviderState = updatedState
historyAfterSend := cloneAIMessages(messages)
if assistantMessage, hasAssistantMessage := buildAssistantMessageFromChatResponse(resp); hasAssistantMessage {
historyAfterSend = append(historyAfterSend, assistantMessage)
}
if persistErr := s.storeSessionProviderRuntime(sessionID, providerKey, updatedProviderState, historyAfterSend); persistErr != nil {
logger.Warnf("AIChatSend 保存会话 Provider 状态失败sessionID=%s provider=%s err=%s", sessionID, providerName, provider.RedactAIUpstreamLogText(persistErr.Error()))
}
logger.Infof(
"AIChatSend 完成sessionID=%s provider=%s messages=%d tools=%d toolCalls=%d promptTokens=%d completionTokens=%d totalTokens=%d duration=%s sessionReuse=%t",
sessionID,
providerName,
len(messages),
len(tools),
len(resp.ToolCalls),
resp.TokensUsed.PromptTokens,
resp.TokensUsed.CompletionTokens,
resp.TokensUsed.TotalTokens,
time.Since(started).Round(time.Millisecond),
true,
)
return map[string]interface{}{
"success": true,
"content": resp.Content,
"reasoning_content": resp.ReasoningContent,
"tool_calls": resp.ToolCalls,
"tokensUsed": map[string]int{
"promptTokens": resp.TokensUsed.PromptTokens,
"completionTokens": resp.TokensUsed.CompletionTokens,
"totalTokens": resp.TokensUsed.TotalTokens,
},
}
}
}
resp, err := p.Chat(context.Background(), ai.ChatRequest{Messages: requestMessages, Tools: tools})
if err != nil {
logger.Warnf("AIChatSend 失败provider=%s messages=%d tools=%d duration=%s err=%s", providerName, len(messages), len(tools), time.Since(started).Round(time.Millisecond), provider.RedactAIUpstreamLogText(err.Error()))
logger.Warnf("AIChatSend 失败:sessionID=%s provider=%s messages=%d tools=%d duration=%s err=%s", sessionID, providerName, len(messages), len(tools), time.Since(started).Round(time.Millisecond), provider.RedactAIUpstreamLogText(err.Error()))
return map[string]interface{}{"success": false, "error": err.Error()}
}
logger.Infof(
"AIChatSend 完成provider=%s messages=%d tools=%d toolCalls=%d promptTokens=%d completionTokens=%d totalTokens=%d duration=%s",
"AIChatSend 完成:sessionID=%s provider=%s messages=%d tools=%d toolCalls=%d promptTokens=%d completionTokens=%d totalTokens=%d duration=%s sessionReuse=%t",
sessionID,
providerName,
len(messages),
len(tools),
@@ -1174,6 +1240,7 @@ func (s *Service) AIChatSend(messages []ai.Message, tools []ai.Tool) map[string]
resp.TokensUsed.CompletionTokens,
resp.TokensUsed.TotalTokens,
time.Since(started).Round(time.Millisecond),
false,
)
return map[string]interface{}{
@@ -1204,7 +1271,7 @@ func (s *Service) AIChatStream(sessionID string, messages []ai.Message, tools []
cancel() // 确保释放
}()
p, err := s.getActiveProvider()
p, config, err := s.getActiveProviderRuntime()
if err != nil {
logger.Error(err, "AIChatStream 获取 Provider 失败sessionID=%s messages=%d tools=%d", sessionID, len(messages), len(tools))
wailsRuntime.EventsEmit(s.ctx, "ai:stream:"+sessionID, map[string]interface{}{
@@ -1220,29 +1287,67 @@ func (s *Service) AIChatStream(sessionID string, messages []ai.Message, tools []
thinkingChunks := 0
toolCallChunks := 0
errorChunks := 0
var assistantContent strings.Builder
var assistantReasoning strings.Builder
var assistantToolCalls []ai.ToolCall
var updatedProviderState json.RawMessage
requestMessages := cloneAIMessages(messages)
logger.Infof("AIChatStream 开始sessionID=%s provider=%s messages=%d tools=%d", sessionID, providerName, len(messages), len(tools))
err = p.ChatStream(streamCtx, ai.ChatRequest{Messages: messages, Tools: tools}, func(chunk ai.StreamChunk) {
if chunk.Content != "" {
contentChunks++
}
if chunk.Thinking != "" || chunk.ReasoningContent != "" {
thinkingChunks++
}
if len(chunk.ToolCalls) > 0 {
toolCallChunks++
}
if chunk.Error != "" {
errorChunks++
}
wailsRuntime.EventsEmit(s.ctx, "ai:stream:"+sessionID, map[string]interface{}{
"content": chunk.Content,
"thinking": chunk.Thinking,
"reasoning_content": chunk.ReasoningContent,
"tool_calls": chunk.ToolCalls,
"done": chunk.Done,
"error": chunk.Error,
if sessionAwareProvider, ok := p.(provider.SessionStreamProvider); ok {
providerKey := providerSessionKey(config)
providerState, deltaMessages := s.resolveSessionProviderRequest(sessionID, providerKey, messages)
requestMessages = deltaMessages
updatedProviderState, err = sessionAwareProvider.ChatStreamWithState(streamCtx, providerState, ai.ChatRequest{Messages: requestMessages, Tools: tools}, func(chunk ai.StreamChunk) {
if chunk.Content != "" {
contentChunks++
assistantContent.WriteString(chunk.Content)
}
if chunk.Thinking != "" || chunk.ReasoningContent != "" {
thinkingChunks++
if chunk.ReasoningContent != "" {
assistantReasoning.WriteString(chunk.ReasoningContent)
}
}
if len(chunk.ToolCalls) > 0 {
toolCallChunks++
assistantToolCalls = append([]ai.ToolCall(nil), chunk.ToolCalls...)
}
if chunk.Error != "" {
errorChunks++
}
wailsRuntime.EventsEmit(s.ctx, "ai:stream:"+sessionID, map[string]interface{}{
"content": chunk.Content,
"thinking": chunk.Thinking,
"reasoning_content": chunk.ReasoningContent,
"tool_calls": chunk.ToolCalls,
"done": chunk.Done,
"error": chunk.Error,
})
})
})
} else {
err = p.ChatStream(streamCtx, ai.ChatRequest{Messages: messages, Tools: tools}, func(chunk ai.StreamChunk) {
if chunk.Content != "" {
contentChunks++
}
if chunk.Thinking != "" || chunk.ReasoningContent != "" {
thinkingChunks++
}
if len(chunk.ToolCalls) > 0 {
toolCallChunks++
}
if chunk.Error != "" {
errorChunks++
}
wailsRuntime.EventsEmit(s.ctx, "ai:stream:"+sessionID, map[string]interface{}{
"content": chunk.Content,
"thinking": chunk.Thinking,
"reasoning_content": chunk.ReasoningContent,
"tool_calls": chunk.ToolCalls,
"done": chunk.Done,
"error": chunk.Error,
})
})
}
// 当 context 被主动 cancel 的时候,不把这个视为向外抛的 error
if err != nil && err != context.Canceled {
@@ -1257,6 +1362,16 @@ func (s *Service) AIChatStream(sessionID string, messages []ai.Message, tools []
logger.Infof("AIChatStream 已取消sessionID=%s provider=%s duration=%s", sessionID, providerName, time.Since(started).Round(time.Millisecond))
return
}
if _, ok := p.(provider.SessionStreamProvider); ok && errorChunks == 0 {
providerKey := providerSessionKey(config)
historyAfterStream := cloneAIMessages(messages)
if assistantMessage, hasAssistantMessage := buildAssistantMessageFromStreamResult(assistantContent.String(), assistantReasoning.String(), assistantToolCalls); hasAssistantMessage {
historyAfterStream = append(historyAfterStream, assistantMessage)
}
if persistErr := s.storeSessionProviderRuntime(sessionID, providerKey, updatedProviderState, historyAfterStream); persistErr != nil {
logger.Warnf("AIChatStream 保存会话 Provider 状态失败sessionID=%s provider=%s err=%s", sessionID, providerName, provider.RedactAIUpstreamLogText(persistErr.Error()))
}
}
logger.Infof(
"AIChatStream 完成sessionID=%s provider=%s messages=%d tools=%d contentChunks=%d thinkingChunks=%d toolCallChunks=%d errorChunks=%d duration=%s",
sessionID,
@@ -1292,6 +1407,11 @@ func (s *Service) AICheckSQL(sql string) ai.SafetyResult {
// --- 内部方法 ---
func (s *Service) getActiveProvider() (provider.Provider, error) {
p, _, err := s.getActiveProviderRuntime()
return p, err
}
func (s *Service) getActiveProviderRuntime() (provider.Provider, ai.ProviderConfig, error) {
s.mu.RLock()
defer s.mu.RUnlock()
@@ -1301,11 +1421,174 @@ func (s *Service) getActiveProvider() (provider.Provider, error) {
for _, cfg := range s.providers {
if cfg.ID == s.activeProvider {
return provider.NewProvider(normalizeProviderConfig(cfg))
normalized := normalizeProviderConfig(cfg)
p, err := provider.NewProvider(normalized)
return p, normalized, err
}
}
return nil, fmt.Errorf("未配置 AI Provider请先在设置中配置")
return nil, ai.ProviderConfig{}, fmt.Errorf("未配置 AI Provider请先在设置中配置")
}
func providerSessionKey(config ai.ProviderConfig) string {
return strings.Join([]string{
strings.TrimSpace(config.ID),
strings.ToLower(strings.TrimSpace(config.Type)),
strings.ToLower(strings.TrimSpace(config.APIFormat)),
strings.TrimSpace(config.BaseURL),
strings.TrimSpace(config.Model),
}, "|")
}
func cloneAIMessages(messages []ai.Message) []ai.Message {
if len(messages) == 0 {
return nil
}
cloned := make([]ai.Message, len(messages))
for index, message := range messages {
cloned[index] = message
if len(message.Images) > 0 {
cloned[index].Images = append([]string(nil), message.Images...)
}
if len(message.ToolCalls) > 0 {
cloned[index].ToolCalls = append([]ai.ToolCall(nil), message.ToolCalls...)
}
}
return cloned
}
func buildAssistantMessageFromStreamResult(content string, reasoning string, toolCalls []ai.ToolCall) (ai.Message, bool) {
message := ai.Message{
Role: "assistant",
Content: content,
ReasoningContent: reasoning,
}
if len(toolCalls) > 0 {
message.ToolCalls = append([]ai.ToolCall(nil), toolCalls...)
}
hasPayload := strings.TrimSpace(message.Content) != "" || strings.TrimSpace(message.ReasoningContent) != "" || len(message.ToolCalls) > 0
return message, hasPayload
}
func buildAssistantMessageFromChatResponse(resp *ai.ChatResponse) (ai.Message, bool) {
if resp == nil {
return ai.Message{}, false
}
return buildAssistantMessageFromStreamResult(resp.Content, resp.ReasoningContent, resp.ToolCalls)
}
func messagesHavePrefix(messages []ai.Message, prefix []ai.Message) bool {
if len(prefix) == 0 {
return true
}
if len(messages) < len(prefix) {
return false
}
for index := range prefix {
if !reflect.DeepEqual(messages[index], prefix[index]) {
return false
}
}
return true
}
func (s *Service) resolveSessionProviderRequest(sessionID string, providerKey string, messages []ai.Message) (json.RawMessage, []ai.Message) {
runtimeState, ok := s.loadSessionProviderRuntime(sessionID, providerKey)
if !ok || len(runtimeState.State) == 0 || len(runtimeState.Messages) == 0 {
return nil, cloneAIMessages(messages)
}
if !messagesHavePrefix(messages, runtimeState.Messages) {
return nil, cloneAIMessages(messages)
}
deltaMessages := cloneAIMessages(messages[len(runtimeState.Messages):])
if len(deltaMessages) == 0 {
return nil, cloneAIMessages(messages)
}
return runtimeState.State, deltaMessages
}
func (s *Service) loadSessionProviderRuntime(sessionID string, providerKey string) (aiSessionProviderRuntime, bool) {
s.mu.RLock()
runtimeState, ok := s.sessionProviders[sessionID]
s.mu.RUnlock()
if ok && runtimeState.ProviderKey == providerKey {
return aiSessionProviderRuntime{
ProviderKey: runtimeState.ProviderKey,
State: append(json.RawMessage(nil), runtimeState.State...),
Messages: cloneAIMessages(runtimeState.Messages),
}, true
}
sessionData, err := s.loadSessionFile(sessionID)
if err != nil {
return aiSessionProviderRuntime{}, false
}
if strings.TrimSpace(sessionData.ProviderKey) == "" || sessionData.ProviderKey != providerKey || len(sessionData.ProviderState) == 0 {
return aiSessionProviderRuntime{}, false
}
var providerMessages []ai.Message
if len(sessionData.ProviderMessages) > 0 {
if err := json.Unmarshal(sessionData.ProviderMessages, &providerMessages); err != nil {
return aiSessionProviderRuntime{}, false
}
}
runtimeState = aiSessionProviderRuntime{
ProviderKey: sessionData.ProviderKey,
State: append(json.RawMessage(nil), sessionData.ProviderState...),
Messages: providerMessages,
}
s.mu.Lock()
s.sessionProviders[sessionID] = runtimeState
s.mu.Unlock()
return aiSessionProviderRuntime{
ProviderKey: runtimeState.ProviderKey,
State: append(json.RawMessage(nil), runtimeState.State...),
Messages: cloneAIMessages(runtimeState.Messages),
}, true
}
func (s *Service) storeSessionProviderRuntime(sessionID string, providerKey string, state json.RawMessage, messages []ai.Message) error {
if strings.TrimSpace(providerKey) == "" {
return nil
}
runtimeState := aiSessionProviderRuntime{
ProviderKey: providerKey,
State: append(json.RawMessage(nil), state...),
Messages: cloneAIMessages(messages),
}
s.mu.Lock()
if len(state) == 0 {
delete(s.sessionProviders, sessionID)
} else {
s.sessionProviders[sessionID] = runtimeState
}
s.mu.Unlock()
sessionData, err := s.loadOrCreateSessionFile(sessionID)
if err != nil {
return err
}
if len(state) == 0 {
sessionData.ProviderKey = ""
sessionData.ProviderState = nil
sessionData.ProviderMessages = nil
return s.saveSessionFile(sessionID, sessionData)
}
sessionData.ProviderKey = providerKey
sessionData.ProviderState = append(json.RawMessage(nil), state...)
if len(messages) == 0 {
sessionData.ProviderMessages = nil
} else {
messageBytes, err := json.Marshal(messages)
if err != nil {
return fmt.Errorf("序列化会话 Provider 消息失败: %w", err)
}
sessionData.ProviderMessages = json.RawMessage(messageBytes)
}
return s.saveSessionFile(sessionID, sessionData)
}
// --- 配置持久化 ---
@@ -1363,16 +1646,69 @@ func normalizeUserPromptText(value string) string {
// sessionFileData 会话文件的 JSON 结构
type sessionFileData struct {
ID string `json:"id"`
Title string `json:"title"`
UpdatedAt int64 `json:"updatedAt"`
Messages json.RawMessage `json:"messages"` // 透传前端格式,后端不解析消息体
ID string `json:"id"`
Title string `json:"title"`
UpdatedAt int64 `json:"updatedAt"`
Messages json.RawMessage `json:"messages"` // 透传前端格式,后端不解析消息体
ProviderKey string `json:"providerKey,omitempty"`
ProviderState json.RawMessage `json:"providerState,omitempty"`
ProviderMessages json.RawMessage `json:"providerMessages,omitempty"`
}
func (s *Service) sessionsDir() string {
return filepath.Join(s.configDir, "sessions")
}
func (s *Service) sessionFilePath(sessionID string) string {
return filepath.Join(s.sessionsDir(), sessionID+".json")
}
func (s *Service) loadSessionFile(sessionID string) (sessionFileData, error) {
data, err := os.ReadFile(s.sessionFilePath(sessionID))
if err != nil {
return sessionFileData{}, err
}
var sessionData sessionFileData
if err := json.Unmarshal(data, &sessionData); err != nil {
return sessionFileData{}, err
}
return sessionData, nil
}
func (s *Service) loadOrCreateSessionFile(sessionID string) (sessionFileData, error) {
sessionData, err := s.loadSessionFile(sessionID)
if err == nil {
return sessionData, nil
}
if !os.IsNotExist(err) {
return sessionFileData{}, err
}
return sessionFileData{
ID: sessionID,
Title: "新的对话",
UpdatedAt: time.Now().UnixMilli(),
Messages: json.RawMessage("[]"),
}, nil
}
func (s *Service) saveSessionFile(sessionID string, sessionData sessionFileData) error {
dir := s.sessionsDir()
if err := os.MkdirAll(dir, 0o755); err != nil {
return fmt.Errorf("创建 sessions 目录失败: %w", err)
}
if strings.TrimSpace(sessionData.ID) == "" {
sessionData.ID = sessionID
}
if len(sessionData.Messages) == 0 {
sessionData.Messages = json.RawMessage("[]")
}
data, err := json.Marshal(sessionData)
if err != nil {
return fmt.Errorf("序列化会话数据失败: %w", err)
}
return os.WriteFile(s.sessionFilePath(sessionID), data, 0o644)
}
// AIGetSessions 获取所有会话的元数据列表(不含消息体)
func (s *Service) AIGetSessions() []map[string]interface{} {
dir := s.sessionsDir()
@@ -1417,53 +1753,40 @@ func (s *Service) AIGetSessions() []map[string]interface{} {
// AILoadSession 加载指定会话的完整数据(含消息)
func (s *Service) AILoadSession(sessionID string) map[string]interface{} {
path := filepath.Join(s.sessionsDir(), sessionID+".json")
data, err := os.ReadFile(path)
sessionData, err := s.loadSessionFile(sessionID)
if err != nil {
return map[string]interface{}{"success": false, "error": "会话不存在"}
}
var sfd sessionFileData
if err := json.Unmarshal(data, &sfd); err != nil {
return map[string]interface{}{"success": false, "error": "会话数据损坏"}
}
return map[string]interface{}{
"success": true,
"id": sfd.ID,
"title": sfd.Title,
"updatedAt": sfd.UpdatedAt,
"messages": sfd.Messages,
"id": sessionData.ID,
"title": sessionData.Title,
"updatedAt": sessionData.UpdatedAt,
"messages": sessionData.Messages,
}
}
// AISaveSession 保存会话数据到文件
func (s *Service) AISaveSession(sessionID string, title string, updatedAt float64, messagesJSON string) error {
dir := s.sessionsDir()
if err := os.MkdirAll(dir, 0o755); err != nil {
return fmt.Errorf("创建 sessions 目录失败: %w", err)
}
sfd := sessionFileData{
ID: sessionID,
Title: title,
UpdatedAt: int64(updatedAt),
Messages: json.RawMessage(messagesJSON),
}
data, err := json.Marshal(sfd)
sessionData, err := s.loadOrCreateSessionFile(sessionID)
if err != nil {
return fmt.Errorf("序列化会话数据失败: %w", err)
return err
}
path := filepath.Join(dir, sessionID+".json")
return os.WriteFile(path, data, 0o644)
sessionData.ID = sessionID
sessionData.Title = title
sessionData.UpdatedAt = int64(updatedAt)
sessionData.Messages = json.RawMessage(messagesJSON)
return s.saveSessionFile(sessionID, sessionData)
}
// AIDeleteSession 删除会话文件
func (s *Service) AIDeleteSession(sessionID string) error {
path := filepath.Join(s.sessionsDir(), sessionID+".json")
if err := os.Remove(path); err != nil && !os.IsNotExist(err) {
if err := os.Remove(s.sessionFilePath(sessionID)); err != nil && !os.IsNotExist(err) {
return fmt.Errorf("删除会话失败: %w", err)
}
s.mu.Lock()
delete(s.sessionProviders, sessionID)
s.mu.Unlock()
return nil
}

View File

@@ -4,6 +4,7 @@ import (
"encoding/json"
"net/http"
"net/http/httptest"
"reflect"
"testing"
"GoNavi-Wails/internal/ai"
@@ -99,3 +100,192 @@ func TestAIListModels_FetchesCursorModelItems(t *testing.T) {
t.Fatalf("expected api source, got %#v", result["source"])
}
}
func TestResolveSessionProviderRequest_ReusesStoredStateOnlyForHistoryExtension(t *testing.T) {
service := NewService()
service.sessionProviders["session-1"] = aiSessionProviderRuntime{
ProviderKey: "cursor-provider",
State: json.RawMessage(`{"agentId":"bc-1"}`),
Messages: []ai.Message{
{Role: "user", Content: "hello"},
{Role: "assistant", Content: "world"},
},
}
state, delta := service.resolveSessionProviderRequest("session-1", "cursor-provider", []ai.Message{
{Role: "user", Content: "hello"},
{Role: "assistant", Content: "world"},
{Role: "user", Content: "next"},
})
if string(state) != `{"agentId":"bc-1"}` {
t.Fatalf("expected stored provider state, got %s", string(state))
}
expectedDelta := []ai.Message{{Role: "user", Content: "next"}}
if !reflect.DeepEqual(delta, expectedDelta) {
t.Fatalf("unexpected delta messages: %#v", delta)
}
state, delta = service.resolveSessionProviderRequest("session-1", "cursor-provider", []ai.Message{
{Role: "user", Content: "hello changed"},
})
if len(state) != 0 {
t.Fatalf("expected mismatched history to reset provider state, got %s", string(state))
}
if len(delta) != 1 || delta[0].Content != "hello changed" {
t.Fatalf("expected full messages after mismatch, got %#v", delta)
}
}
func TestAISaveSession_PreservesProviderRuntimeMetadata(t *testing.T) {
service := NewService()
service.configDir = t.TempDir()
err := service.storeSessionProviderRuntime(
"session-1",
"cursor-provider",
json.RawMessage(`{"agentId":"bc-1","lastRunId":"run-1"}`),
[]ai.Message{{Role: "user", Content: "hello"}},
)
if err != nil {
t.Fatalf("store provider runtime: %v", err)
}
err = service.AISaveSession("session-1", "标题", 123, `[{"id":"m1","role":"user","content":"hello","timestamp":1}]`)
if err != nil {
t.Fatalf("save session: %v", err)
}
sessionData, err := service.loadSessionFile("session-1")
if err != nil {
t.Fatalf("load session file: %v", err)
}
if sessionData.ProviderKey != "cursor-provider" {
t.Fatalf("expected provider key to be preserved, got %q", sessionData.ProviderKey)
}
if string(sessionData.ProviderState) != `{"agentId":"bc-1","lastRunId":"run-1"}` {
t.Fatalf("expected provider state to be preserved, got %s", string(sessionData.ProviderState))
}
var providerMessages []ai.Message
if err := json.Unmarshal(sessionData.ProviderMessages, &providerMessages); err != nil {
t.Fatalf("unmarshal provider messages: %v", err)
}
if len(providerMessages) != 1 || providerMessages[0].Content != "hello" {
t.Fatalf("unexpected provider messages: %#v", providerMessages)
}
}
func TestAIChatSendInSession_ReusesCursorProviderStateAndPersistsFollowUpRuns(t *testing.T) {
var (
createAgentCalls int
createRunCalls int
createRunPrompt string
)
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
switch {
case r.Method == http.MethodPost && r.URL.Path == "/v1/agents":
createAgentCalls++
var body struct {
Prompt struct {
Text string `json:"text"`
} `json:"prompt"`
}
if err := json.NewDecoder(r.Body).Decode(&body); err != nil {
t.Fatalf("decode create agent body: %v", err)
}
if body.Prompt.Text == "" {
t.Fatalf("expected first prompt text")
}
_ = json.NewEncoder(w).Encode(map[string]any{
"agent": map[string]any{"id": "bc-1"},
"run": map[string]any{"id": "run-1", "agentId": "bc-1"},
})
case r.Method == http.MethodGet && r.URL.Path == "/v1/agents/bc-1/runs/run-1":
_ = json.NewEncoder(w).Encode(map[string]any{
"id": "run-1",
"agentId": "bc-1",
"status": "FINISHED",
"result": "first answer",
"durationMs": 100,
})
case r.Method == http.MethodPost && r.URL.Path == "/v1/agents/bc-1/runs":
createRunCalls++
var body struct {
Prompt struct {
Text string `json:"text"`
} `json:"prompt"`
}
if err := json.NewDecoder(r.Body).Decode(&body); err != nil {
t.Fatalf("decode follow-up run body: %v", err)
}
createRunPrompt = body.Prompt.Text
_ = json.NewEncoder(w).Encode(map[string]any{
"run": map[string]any{"id": "run-2"},
})
case r.Method == http.MethodGet && r.URL.Path == "/v1/agents/bc-1/runs/run-2":
_ = json.NewEncoder(w).Encode(map[string]any{
"id": "run-2",
"agentId": "bc-1",
"status": "FINISHED",
"result": "second answer",
"durationMs": 120,
})
default:
http.NotFound(w, r)
}
}))
defer server.Close()
service := NewService()
service.configDir = t.TempDir()
service.providers = []ai.ProviderConfig{
{
ID: "provider-cursor",
Type: "custom",
APIFormat: "cursor-agent",
BaseURL: server.URL + "/v1",
APIKey: "cursor-key",
},
}
service.activeProvider = "provider-cursor"
firstResult := service.AIChatSendInSession("session-1", []ai.Message{
{Role: "user", Content: "hello"},
}, nil)
if firstResult["success"] != true {
t.Fatalf("expected first send to succeed, got %#v", firstResult)
}
secondResult := service.AIChatSendInSession("session-1", []ai.Message{
{Role: "user", Content: "hello"},
{Role: "assistant", Content: "first answer"},
{Role: "user", Content: "next"},
}, nil)
if secondResult["success"] != true {
t.Fatalf("expected second send to succeed, got %#v", secondResult)
}
if createAgentCalls != 1 {
t.Fatalf("expected exactly one create-agent call, got %d", createAgentCalls)
}
if createRunCalls != 1 {
t.Fatalf("expected exactly one follow-up run call, got %d", createRunCalls)
}
if createRunPrompt != "next" {
t.Fatalf("expected follow-up run to send only delta message, got %q", createRunPrompt)
}
sessionData, err := service.loadSessionFile("session-1")
if err != nil {
t.Fatalf("load session file: %v", err)
}
if string(sessionData.ProviderState) != `{"agentId":"bc-1","lastRunId":"run-2"}` {
t.Fatalf("unexpected provider state: %s", string(sessionData.ProviderState))
}
var providerMessages []ai.Message
if err := json.Unmarshal(sessionData.ProviderMessages, &providerMessages); err != nil {
t.Fatalf("unmarshal provider messages: %v", err)
}
if len(providerMessages) != 4 || providerMessages[3].Content != "second answer" {
t.Fatalf("unexpected provider messages: %#v", providerMessages)
}
}