mirror of
https://github.com/Syngnat/GoNavi.git
synced 2026-06-23 06:53:52 +08:00
✨ feat(ai): 接入 Cursor Cloud Agents API
- 新增 cursor-agent provider,支持创建 agent、轮询 run 状态和 SSE 流式响应 - 接入 AITestProvider 与 AIListModels,支持 Cursor 官方 /v1/models 连通性测试和模型发现 - 在 AI 设置中新增 Cursor 供应商预设,固定 cursor-agent 协议并补齐默认端点配置 - 调整 provider readiness 与 insights 规则,允许 Cursor 未显式选模型时走官方默认模型 - 补充后端 provider/service 测试和前端 preset、表单、readiness 相关用例 Close #576
This commit is contained in:
@@ -345,7 +345,7 @@ const AISettingsModal: React.FC<AISettingsModalProps> = ({ open, onClose, darkMo
|
||||
|
||||
// 构建 payload,处理 model/models 逻辑
|
||||
const preset = findPreset(values.presetKey);
|
||||
const isCustomLike = values.presetKey === 'custom' || values.presetKey === 'ollama' || values.presetKey === 'codebuddy';
|
||||
const isCustomLike = values.presetKey === 'custom' || values.presetKey === 'ollama' || values.presetKey === 'codebuddy' || values.presetKey === 'cursor';
|
||||
const { model: finalModel, models: resolvedModels } = resolvePresetModelSelection({
|
||||
presetKey: values.presetKey,
|
||||
presetDefaultModel: preset.defaultModel,
|
||||
|
||||
@@ -9,6 +9,7 @@ import AISettingsProvidersSection from './AISettingsProvidersSection';
|
||||
|
||||
const providerPresets = [
|
||||
{ key: 'openai', label: 'OpenAI', icon: <span>O</span>, desc: 'GPT', defaultBaseUrl: 'https://api.openai.com/v1' },
|
||||
{ key: 'cursor', label: 'Cursor', icon: <span>R</span>, desc: 'Cursor API', defaultBaseUrl: 'https://api.cursor.com/v1' },
|
||||
{ key: 'custom', label: '自定义', icon: <span>C</span>, desc: '自定义接口', defaultBaseUrl: 'https://example.com' },
|
||||
];
|
||||
|
||||
@@ -150,4 +151,44 @@ describe('AISettingsProvidersSection', () => {
|
||||
expect(markup).toContain('本机 CodeBuddy CLI 已登录账号');
|
||||
expect(markup).toContain('留空则使用 CodeBuddy CLI 默认网关');
|
||||
});
|
||||
|
||||
it('renders automatic-model copy for the Cursor preset', () => {
|
||||
const Wrap = () => {
|
||||
const [form] = Form.useForm();
|
||||
return (
|
||||
<AISettingsProvidersSection
|
||||
providers={[provider]}
|
||||
activeProviderId="provider-1"
|
||||
editingProvider={{ ...provider, apiFormat: 'cursor-agent', baseUrl: 'https://api.cursor.com/v1' }}
|
||||
isEditing
|
||||
form={form}
|
||||
providerPresets={providerPresets}
|
||||
watchedPresetKey="cursor"
|
||||
watchedApiFormat="cursor-agent"
|
||||
loading={false}
|
||||
testStatus="idle"
|
||||
primaryPasswordVisible={false}
|
||||
darkMode={false}
|
||||
overlayTheme={overlayTheme}
|
||||
cardBg="#fff"
|
||||
cardBorder="rgba(0,0,0,0.08)"
|
||||
inputBg="#fff"
|
||||
onPrimaryPasswordVisibleChange={() => {}}
|
||||
resolveProviderPreset={() => ({ label: 'Cursor', icon: <span>R</span> })}
|
||||
resolvePresetByKey={(key) => providerPresets.find((item) => item.key === key) || providerPresets[0]}
|
||||
onAddProvider={() => {}}
|
||||
onEditProvider={() => {}}
|
||||
onDeleteProvider={() => {}}
|
||||
onSetActiveProvider={() => {}}
|
||||
onCancelEdit={() => {}}
|
||||
onPresetChange={() => {}}
|
||||
onTestProvider={() => {}}
|
||||
onSaveProvider={() => {}}
|
||||
/>
|
||||
);
|
||||
};
|
||||
|
||||
const markup = renderToStaticMarkup(<Wrap />);
|
||||
expect(markup).toContain('可选:预填常用 Cursor 模型 ID;留空则由 Cursor 默认模型自动选择');
|
||||
});
|
||||
});
|
||||
|
||||
@@ -106,8 +106,9 @@ const AISettingsProvidersSection: React.FC<AISettingsProvidersSectionProps> = ({
|
||||
onSaveProvider,
|
||||
}) => {
|
||||
const presetKeyFromForm = watchedPresetKey || (editingProvider as (AIProviderConfig & { presetKey?: string }) | null)?.presetKey || 'openai';
|
||||
const supportsAdvancedEndpoint = presetKeyFromForm === 'custom' || presetKeyFromForm === 'ollama' || presetKeyFromForm === 'codebuddy';
|
||||
const supportsAdvancedEndpoint = presetKeyFromForm === 'custom' || presetKeyFromForm === 'ollama' || presetKeyFromForm === 'codebuddy' || presetKeyFromForm === 'cursor';
|
||||
const codeBuddyUsesOptionalSecret = presetKeyFromForm === 'codebuddy';
|
||||
const cursorUsesOptionalModel = presetKeyFromForm === 'cursor';
|
||||
const sectionLabelColor = darkMode ? 'rgba(255,255,255,0.5)' : 'rgba(0,0,0,0.4)';
|
||||
const currentFieldGroupStyle = fieldGroupStyle(cardBorder, cardBg);
|
||||
const currentFieldLabelStyle = fieldLabelStyle(sectionLabelColor);
|
||||
@@ -134,7 +135,7 @@ const AISettingsProvidersSection: React.FC<AISettingsProvidersSectionProps> = ({
|
||||
{providers.map((provider) => {
|
||||
const matchedPreset = resolveProviderPreset(provider);
|
||||
const isActive = provider.id === activeProviderId;
|
||||
const modelLabel = provider.model || (provider.apiFormat === 'codebuddy-cli' ? '自动选择' : '未选择模型');
|
||||
const modelLabel = provider.model || (provider.apiFormat === 'codebuddy-cli' || provider.apiFormat === 'cursor-agent' ? '自动选择' : '未选择模型');
|
||||
return (
|
||||
<div
|
||||
key={provider.id}
|
||||
@@ -295,7 +296,7 @@ const AISettingsProvidersSection: React.FC<AISettingsProvidersSectionProps> = ({
|
||||
borderRadius: 8,
|
||||
gap: 4,
|
||||
}}>
|
||||
{[{ value: 'openai', label: 'OpenAI' }, { value: 'anthropic', label: 'Anthropic' }, { value: 'gemini', label: 'Gemini' }, { value: 'claude-cli', label: 'Claude CLI' }].map((format) => (
|
||||
{[{ value: 'openai', label: 'OpenAI' }, { value: 'anthropic', label: 'Anthropic' }, { value: 'gemini', label: 'Gemini' }, { value: 'cursor-agent', label: 'Cursor Agent' }, { value: 'claude-cli', label: 'Claude CLI' }].map((format) => (
|
||||
<div
|
||||
key={format.value}
|
||||
onClick={() => form.setFieldsValue({ apiFormat: format.value })}
|
||||
@@ -319,7 +320,16 @@ const AISettingsProvidersSection: React.FC<AISettingsProvidersSectionProps> = ({
|
||||
)}
|
||||
|
||||
<Form.Item label={<span style={{ fontWeight: 500, color: overlayTheme.titleText }}>可用模型列表(可选配置)</span>} name="models" style={{ marginBottom: 0 }}>
|
||||
<Select mode="tags" size="middle" placeholder={codeBuddyUsesOptionalSecret ? '可选:预填常用模型;留空则由 CodeBuddy CLI 或服务端自动选择' : '配置指定的模型ID,留空则默认去服务端拉取'} style={{ width: '100%' }} />
|
||||
<Select
|
||||
mode="tags"
|
||||
size="middle"
|
||||
placeholder={codeBuddyUsesOptionalSecret
|
||||
? '可选:预填常用模型;留空则由 CodeBuddy CLI 或服务端自动选择'
|
||||
: cursorUsesOptionalModel
|
||||
? '可选:预填常用 Cursor 模型 ID;留空则由 Cursor 默认模型自动选择'
|
||||
: '配置指定的模型ID,留空则默认去服务端拉取'}
|
||||
style={{ width: '100%' }}
|
||||
/>
|
||||
</Form.Item>
|
||||
</div>
|
||||
)}
|
||||
|
||||
@@ -116,4 +116,28 @@ describe('buildAIChatReadinessSnapshot', () => {
|
||||
expect(snapshot.title).toContain('CodeBuddy');
|
||||
expect(snapshot.title).toContain('自动选择');
|
||||
});
|
||||
|
||||
it('treats Cursor Agent as ready without an explicit model', () => {
|
||||
const snapshot = buildAIChatReadinessSnapshot({
|
||||
providers: [{
|
||||
id: 'provider-1',
|
||||
type: 'custom',
|
||||
name: 'Cursor',
|
||||
apiKey: '',
|
||||
hasSecret: true,
|
||||
baseUrl: 'https://api.cursor.com/v1',
|
||||
model: '',
|
||||
apiFormat: 'cursor-agent',
|
||||
models: [],
|
||||
maxTokens: 4096,
|
||||
temperature: 0.2,
|
||||
}],
|
||||
activeProviderId: 'provider-1',
|
||||
});
|
||||
|
||||
expect(snapshot.status).toBe('ready');
|
||||
expect(snapshot.ready).toBe(true);
|
||||
expect(snapshot.title).toContain('Cursor');
|
||||
expect(snapshot.title).toContain('自动选择');
|
||||
});
|
||||
});
|
||||
|
||||
@@ -66,7 +66,7 @@ const isBaseURLOptionalProvider = (provider: AIProviderConfig): boolean =>
|
||||
provider.type === 'custom' && trimText(provider.apiFormat) === 'codebuddy-cli';
|
||||
|
||||
const isModelOptionalProvider = (provider: AIProviderConfig): boolean =>
|
||||
provider.type === 'custom' && trimText(provider.apiFormat) === 'codebuddy-cli';
|
||||
provider.type === 'custom' && ['codebuddy-cli', 'cursor-agent'].includes(trimText(provider.apiFormat));
|
||||
|
||||
const getSelectedProvider = (params: {
|
||||
providers?: AIProviderConfig[];
|
||||
|
||||
@@ -70,4 +70,26 @@ describe('aiProviderInsights', () => {
|
||||
expect(JSON.stringify(snapshot)).not.toContain('apiKey');
|
||||
expect(JSON.stringify(snapshot)).not.toContain('secret-token');
|
||||
});
|
||||
|
||||
it('does not flag Cursor Agent for a missing selected model', () => {
|
||||
const snapshot = buildAIProviderSnapshot({
|
||||
providers: [{
|
||||
id: 'provider-cursor',
|
||||
type: 'custom',
|
||||
name: 'Cursor',
|
||||
apiKey: '',
|
||||
hasSecret: true,
|
||||
baseUrl: 'https://api.cursor.com/v1',
|
||||
model: '',
|
||||
models: [],
|
||||
apiFormat: 'cursor-agent',
|
||||
maxTokens: 4096,
|
||||
temperature: 0.2,
|
||||
}],
|
||||
activeProviderId: 'provider-cursor',
|
||||
});
|
||||
|
||||
expect(snapshot.missingSelectedModelCount).toBe(0);
|
||||
expect(snapshot.providers[0].issues).toEqual(['missing_declared_models']);
|
||||
});
|
||||
});
|
||||
|
||||
@@ -17,6 +17,12 @@ const trimText = (value: unknown): string => String(value || '').trim();
|
||||
const hasProviderSecret = (provider: AIProviderConfig): boolean =>
|
||||
provider.hasSecret ?? Boolean(provider.secretRef || provider.apiKey);
|
||||
|
||||
const isBaseURLOptionalProvider = (provider: AIProviderConfig): boolean =>
|
||||
provider.type === 'custom' && trimText(provider.apiFormat) === 'codebuddy-cli';
|
||||
|
||||
const isModelOptionalProvider = (provider: AIProviderConfig): boolean =>
|
||||
provider.type === 'custom' && ['codebuddy-cli', 'cursor-agent'].includes(trimText(provider.apiFormat));
|
||||
|
||||
const getProviderHost = (baseUrl: string): string => {
|
||||
const normalized = trimText(baseUrl);
|
||||
if (!normalized) {
|
||||
@@ -41,10 +47,10 @@ const buildProviderIssues = (provider: AIProviderConfig): string[] => {
|
||||
if (!hasSecret) {
|
||||
issues.push('missing_secret');
|
||||
}
|
||||
if (!baseUrl) {
|
||||
if (!isBaseURLOptionalProvider(provider) && !baseUrl) {
|
||||
issues.push('missing_base_url');
|
||||
}
|
||||
if (!model) {
|
||||
if (!isModelOptionalProvider(provider) && !model) {
|
||||
issues.push('missing_selected_model');
|
||||
}
|
||||
if (declaredModels.length === 0) {
|
||||
|
||||
@@ -35,6 +35,16 @@ describe('aiSettingsModalConfig', () => {
|
||||
expect(preset.key).toBe('codebuddy');
|
||||
});
|
||||
|
||||
it('matches a Cursor Agent provider back to the dedicated preset', () => {
|
||||
const preset = matchProviderPreset({
|
||||
type: 'custom',
|
||||
baseUrl: 'https://api.cursor.com/v1',
|
||||
apiFormat: 'cursor-agent',
|
||||
});
|
||||
|
||||
expect(preset.key).toBe('cursor');
|
||||
});
|
||||
|
||||
it('creates MCP server drafts and skill drafts with stable defaults', () => {
|
||||
const server = EMPTY_MCP_SERVER({ name: 'Browser', args: ['stdio'] });
|
||||
const skill = EMPTY_SKILL();
|
||||
@@ -49,6 +59,7 @@ describe('aiSettingsModalConfig', () => {
|
||||
it('keeps the provider preset list available for the settings modal', () => {
|
||||
expect(PROVIDER_PRESETS.some((item) => item.key === 'codex')).toBe(false);
|
||||
expect(PROVIDER_PRESETS.some((item) => item.key === 'codebuddy')).toBe(true);
|
||||
expect(PROVIDER_PRESETS.some((item) => item.key === 'cursor')).toBe(true);
|
||||
expect(PROVIDER_PRESETS.some((item) => item.key === 'openai')).toBe(true);
|
||||
expect(PROVIDER_PRESETS.some((item) => item.key === 'custom')).toBe(true);
|
||||
});
|
||||
|
||||
@@ -47,6 +47,7 @@ export const PROVIDER_PRESETS: ProviderPreset[] = [
|
||||
{ key: 'volcengine-coding', label: '火山 Coding Plan', icon: <CloudOutlined />, desc: 'Ark Code / Coding Plan', color: '#0284c7', backendType: 'openai', defaultBaseUrl: 'https://ark.cn-beijing.volces.com/api/coding/v3', defaultModel: '', models: [] },
|
||||
{ key: 'minimax', label: 'MiniMax', icon: <ExperimentOutlined />, desc: 'M3 / M2.7 系列 (Anthropic 兼容)', color: '#e11d48', backendType: 'anthropic', defaultBaseUrl: 'https://api.minimaxi.com/anthropic', defaultModel: 'MiniMax-M3', models: ['MiniMax-M3', 'MiniMax-M2.7', 'MiniMax-M2.7-highspeed'] },
|
||||
{ key: 'codebuddy', label: 'CodeBuddy', icon: <ApiOutlined />, desc: '本地 CodeBuddy CLI / 官方登录态', color: '#2563eb', backendType: 'custom', fixedApiFormat: 'codebuddy-cli', defaultBaseUrl: '', defaultModel: '', models: [] },
|
||||
{ key: 'cursor', label: 'Cursor', icon: <ApiOutlined />, desc: 'Cloud Agents API / 官方 API Key', color: '#7c3aed', backendType: 'custom', fixedApiFormat: 'cursor-agent', defaultBaseUrl: 'https://api.cursor.com/v1', defaultModel: '', models: [] },
|
||||
{ key: 'ollama', label: 'Ollama', icon: <AppstoreOutlined />, desc: '本地部署开源模型', color: '#78716c', backendType: 'openai', defaultBaseUrl: 'http://localhost:11434/v1', defaultModel: 'llama3', models: [] },
|
||||
{ key: 'custom', label: '自定义', icon: <AppstoreOutlined />, desc: '自定义 API 端点', color: '#64748b', backendType: 'custom', defaultBaseUrl: '', defaultModel: '', models: [] },
|
||||
];
|
||||
|
||||
@@ -600,7 +600,7 @@ export interface AIProviderConfig {
|
||||
baseUrl: string;
|
||||
model: string;
|
||||
models?: string[];
|
||||
apiFormat?: string; // custom 专用: openai | anthropic | gemini | claude-cli | codebuddy-cli
|
||||
apiFormat?: string; // custom 专用: openai | anthropic | gemini | cursor-agent | claude-cli | codebuddy-cli
|
||||
headers?: Record<string, string>;
|
||||
maxTokens: number;
|
||||
temperature: number;
|
||||
|
||||
@@ -30,6 +30,7 @@ const PRESETS: PresetMatcher[] = [
|
||||
fixedApiFormat: 'claude-cli',
|
||||
},
|
||||
{ key: 'codebuddy', backendType: 'custom', defaultBaseUrl: '', fixedApiFormat: 'codebuddy-cli' },
|
||||
{ key: 'cursor', backendType: 'custom', defaultBaseUrl: 'https://api.cursor.com/v1', fixedApiFormat: 'cursor-agent' },
|
||||
{ key: 'custom', backendType: 'custom', defaultBaseUrl: '' },
|
||||
];
|
||||
|
||||
@@ -103,6 +104,19 @@ describe('ai provider preset helpers', () => {
|
||||
});
|
||||
});
|
||||
|
||||
it('keeps Cursor model empty when only a suggested model list is configured', () => {
|
||||
expect(resolvePresetModelSelection({
|
||||
presetKey: 'cursor',
|
||||
presetDefaultModel: '',
|
||||
presetModels: [],
|
||||
valuesModel: '',
|
||||
customModels: ['composer-2', 'composer-latest'],
|
||||
})).toEqual({
|
||||
model: '',
|
||||
models: ['composer-2', 'composer-latest'],
|
||||
});
|
||||
});
|
||||
|
||||
it('forces built-in presets back to their standard base URL when saving or testing', () => {
|
||||
expect(resolvePresetBaseURL({
|
||||
presetKey: 'qwen-bailian',
|
||||
@@ -119,6 +133,14 @@ describe('ai provider preset helpers', () => {
|
||||
})).toBe('https://example-proxy.internal/v1');
|
||||
});
|
||||
|
||||
it('keeps the user-entered base URL for the Cursor preset', () => {
|
||||
expect(resolvePresetBaseURL({
|
||||
presetKey: 'cursor',
|
||||
presetDefaultBaseUrl: 'https://api.cursor.com/v1',
|
||||
valuesBaseUrl: 'https://cursor-proxy.internal/v1',
|
||||
})).toBe('https://cursor-proxy.internal/v1');
|
||||
});
|
||||
|
||||
it('forces qwen coding plan to save as custom plus claude-cli', () => {
|
||||
expect(resolvePresetTransport({
|
||||
presetBackendType: 'custom',
|
||||
@@ -197,4 +219,18 @@ describe('resolveProviderPresetKey', () => {
|
||||
|
||||
expect(key).toBe('codebuddy');
|
||||
});
|
||||
|
||||
it('能识别 Cursor Agent 预设', () => {
|
||||
const key = resolveProviderPresetKey(
|
||||
{
|
||||
type: 'custom',
|
||||
apiFormat: 'cursor-agent',
|
||||
baseUrl: 'https://api.cursor.com/v1',
|
||||
},
|
||||
PRESETS,
|
||||
'custom',
|
||||
);
|
||||
|
||||
expect(key).toBe('cursor');
|
||||
});
|
||||
});
|
||||
|
||||
@@ -17,7 +17,7 @@ export const QWEN_CODING_PLAN_MODELS = [
|
||||
'glm-4.7',
|
||||
];
|
||||
|
||||
const CUSTOM_LIKE_PRESET_KEYS = new Set(['custom', 'ollama', 'codebuddy']);
|
||||
const CUSTOM_LIKE_PRESET_KEYS = new Set(['custom', 'ollama', 'codebuddy', 'cursor']);
|
||||
|
||||
export interface ResolvePresetModelSelectionInput {
|
||||
presetKey: string;
|
||||
@@ -183,6 +183,12 @@ export const resolvePresetModelSelection = ({
|
||||
}: ResolvePresetModelSelectionInput): ResolvePresetModelSelectionResult => {
|
||||
const isCustomLike = CUSTOM_LIKE_PRESET_KEYS.has(presetKey);
|
||||
const resolvedModels = isCustomLike ? (customModels || []) : presetModels;
|
||||
if (presetKey === 'cursor') {
|
||||
return {
|
||||
models: resolvedModels,
|
||||
model: valuesModel || '',
|
||||
};
|
||||
}
|
||||
const fallbackModel = resolvedModels.length > 0 ? resolvedModels[0] : '';
|
||||
return {
|
||||
models: resolvedModels,
|
||||
|
||||
568
internal/ai/provider/cursor_agent.go
Normal file
568
internal/ai/provider/cursor_agent.go
Normal file
@@ -0,0 +1,568 @@
|
||||
package provider
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"GoNavi-Wails/internal/ai"
|
||||
)
|
||||
|
||||
const (
|
||||
defaultCursorAPIBaseURL = "https://api.cursor.com/v1"
|
||||
cursorHTTPTimeout = 120 * time.Second
|
||||
cursorRunPollInterval = time.Second
|
||||
)
|
||||
|
||||
// CursorAgentProvider 通过 Cursor Cloud Agents API 发起对话。
|
||||
// 当前实现为无状态适配:每次请求都创建一个新的 agent,再消费本次 run 的结果。
|
||||
type CursorAgentProvider struct {
|
||||
config ai.ProviderConfig
|
||||
baseURL string
|
||||
client *http.Client
|
||||
}
|
||||
|
||||
// NewCursorAgentProvider 创建 Cursor Agent Provider。
|
||||
func NewCursorAgentProvider(config ai.ProviderConfig) (Provider, error) {
|
||||
normalized := config
|
||||
normalized.BaseURL = NormalizeCursorAPIBaseURL(config.BaseURL)
|
||||
normalized.Model = strings.TrimSpace(config.Model)
|
||||
|
||||
return &CursorAgentProvider{
|
||||
config: normalized,
|
||||
baseURL: normalized.BaseURL,
|
||||
client: &http.Client{
|
||||
Timeout: cursorHTTPTimeout,
|
||||
},
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (p *CursorAgentProvider) Name() string {
|
||||
if strings.TrimSpace(p.config.Name) != "" {
|
||||
return p.config.Name
|
||||
}
|
||||
return "Cursor"
|
||||
}
|
||||
|
||||
func (p *CursorAgentProvider) Validate() error {
|
||||
if strings.TrimSpace(p.config.APIKey) == "" {
|
||||
return fmt.Errorf("API Key 不能为空")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// NormalizeCursorAPIBaseURL 归一化 Cursor API 的 base URL。
|
||||
func NormalizeCursorAPIBaseURL(raw string) string {
|
||||
trimmed := strings.TrimSpace(raw)
|
||||
if trimmed == "" {
|
||||
return defaultCursorAPIBaseURL
|
||||
}
|
||||
|
||||
parsed, err := url.Parse(trimmed)
|
||||
if err != nil || parsed.Scheme == "" || parsed.Host == "" {
|
||||
return normalizeCursorAPIBaseURLString(trimmed)
|
||||
}
|
||||
|
||||
parsed.RawQuery = ""
|
||||
parsed.Fragment = ""
|
||||
parsed.Path = normalizeCursorAPIPath(parsed.Path)
|
||||
return strings.TrimRight(parsed.String(), "/")
|
||||
}
|
||||
|
||||
// ResolveCursorAPIEndpoint 基于归一化后的 base URL 生成具体接口地址。
|
||||
func ResolveCursorAPIEndpoint(baseURL string, endpoint string) string {
|
||||
normalizedBaseURL := NormalizeCursorAPIBaseURL(baseURL)
|
||||
normalizedEndpoint := strings.TrimLeft(strings.TrimSpace(endpoint), "/")
|
||||
if normalizedEndpoint == "" {
|
||||
return normalizedBaseURL
|
||||
}
|
||||
return normalizedBaseURL + "/" + normalizedEndpoint
|
||||
}
|
||||
|
||||
func normalizeCursorAPIBaseURLString(raw string) string {
|
||||
normalized := strings.TrimRight(strings.TrimSpace(raw), "/")
|
||||
if normalized == "" {
|
||||
return defaultCursorAPIBaseURL
|
||||
}
|
||||
|
||||
lower := strings.ToLower(normalized)
|
||||
switch {
|
||||
case strings.HasSuffix(lower, "/v1/agents"):
|
||||
normalized = normalized[:len(normalized)-len("/v1/agents")]
|
||||
case strings.HasSuffix(lower, "/agents"):
|
||||
normalized = normalized[:len(normalized)-len("/agents")]
|
||||
case strings.HasSuffix(lower, "/v1/models"):
|
||||
normalized = normalized[:len(normalized)-len("/v1/models")]
|
||||
case strings.HasSuffix(lower, "/models"):
|
||||
normalized = normalized[:len(normalized)-len("/models")]
|
||||
}
|
||||
normalized = strings.TrimRight(normalized, "/")
|
||||
if strings.HasSuffix(strings.ToLower(normalized), "/v1") {
|
||||
return normalized
|
||||
}
|
||||
return normalized + "/v1"
|
||||
}
|
||||
|
||||
func normalizeCursorAPIPath(path string) string {
|
||||
normalized := strings.TrimRight(strings.TrimSpace(path), "/")
|
||||
lower := strings.ToLower(normalized)
|
||||
switch {
|
||||
case strings.HasSuffix(lower, "/v1/agents"):
|
||||
normalized = normalized[:len(normalized)-len("/v1/agents")]
|
||||
case strings.HasSuffix(lower, "/agents"):
|
||||
normalized = normalized[:len(normalized)-len("/agents")]
|
||||
case strings.HasSuffix(lower, "/v1/models"):
|
||||
normalized = normalized[:len(normalized)-len("/v1/models")]
|
||||
case strings.HasSuffix(lower, "/models"):
|
||||
normalized = normalized[:len(normalized)-len("/models")]
|
||||
}
|
||||
normalized = strings.TrimRight(normalized, "/")
|
||||
if strings.HasSuffix(strings.ToLower(normalized), "/v1") {
|
||||
return normalized
|
||||
}
|
||||
if normalized == "" {
|
||||
return "/v1"
|
||||
}
|
||||
return normalized + "/v1"
|
||||
}
|
||||
|
||||
type cursorPrompt struct {
|
||||
Text string `json:"text"`
|
||||
}
|
||||
|
||||
type cursorModelSelection struct {
|
||||
ID string `json:"id"`
|
||||
}
|
||||
|
||||
type cursorCreateAgentRequest struct {
|
||||
Prompt cursorPrompt `json:"prompt"`
|
||||
Model *cursorModelSelection `json:"model,omitempty"`
|
||||
}
|
||||
|
||||
type cursorCreateAgentResponse struct {
|
||||
Agent struct {
|
||||
ID string `json:"id"`
|
||||
} `json:"agent"`
|
||||
Run struct {
|
||||
ID string `json:"id"`
|
||||
AgentID string `json:"agentId"`
|
||||
} `json:"run"`
|
||||
}
|
||||
|
||||
type cursorRunResponse struct {
|
||||
ID string `json:"id"`
|
||||
AgentID string `json:"agentId"`
|
||||
Status string `json:"status"`
|
||||
Result string `json:"result"`
|
||||
DurationMS int `json:"durationMs"`
|
||||
}
|
||||
|
||||
type cursorAssistantEvent struct {
|
||||
Text string `json:"text"`
|
||||
}
|
||||
|
||||
type cursorErrorEvent struct {
|
||||
Code string `json:"code"`
|
||||
Message string `json:"message"`
|
||||
}
|
||||
|
||||
type cursorResultEvent struct {
|
||||
RunID string `json:"runId"`
|
||||
Status string `json:"status"`
|
||||
Text string `json:"text"`
|
||||
DurationMS int `json:"durationMs"`
|
||||
}
|
||||
|
||||
func (p *CursorAgentProvider) Chat(ctx context.Context, req ai.ChatRequest) (*ai.ChatResponse, error) {
|
||||
if err := p.Validate(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
agentID, runID, err := p.createAgent(ctx, req)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
run, err := p.waitForRun(ctx, agentID, runID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &ai.ChatResponse{
|
||||
Content: strings.TrimSpace(run.Result),
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (p *CursorAgentProvider) ChatStream(ctx context.Context, req ai.ChatRequest, callback func(ai.StreamChunk)) error {
|
||||
if err := p.Validate(); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
agentID, runID, err := p.createAgent(ctx, req)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
stream, err := p.openRunStream(ctx, agentID, runID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer stream.Close()
|
||||
|
||||
scanner := bufio.NewScanner(stream)
|
||||
scanner.Buffer(make([]byte, 0, 64*1024), 1024*1024)
|
||||
|
||||
var (
|
||||
currentEventType string
|
||||
currentDataLines []string
|
||||
receivedAssistantText bool
|
||||
receivedResultText bool
|
||||
completedExplicitly bool
|
||||
)
|
||||
|
||||
dispatchEvent := func(eventType string, dataLines []string) (bool, error) {
|
||||
if strings.TrimSpace(eventType) == "" {
|
||||
eventType = "message"
|
||||
}
|
||||
payload := strings.TrimSpace(strings.Join(dataLines, "\n"))
|
||||
switch eventType {
|
||||
case "assistant":
|
||||
if payload == "" {
|
||||
return false, nil
|
||||
}
|
||||
var event cursorAssistantEvent
|
||||
if err := json.Unmarshal([]byte(payload), &event); err != nil {
|
||||
return false, nil
|
||||
}
|
||||
if strings.TrimSpace(event.Text) != "" {
|
||||
receivedAssistantText = true
|
||||
callback(ai.StreamChunk{Content: event.Text})
|
||||
}
|
||||
case "thinking":
|
||||
if payload == "" {
|
||||
return false, nil
|
||||
}
|
||||
var event cursorAssistantEvent
|
||||
if err := json.Unmarshal([]byte(payload), &event); err != nil {
|
||||
return false, nil
|
||||
}
|
||||
if strings.TrimSpace(event.Text) != "" {
|
||||
callback(ai.StreamChunk{
|
||||
Thinking: event.Text,
|
||||
ReasoningContent: event.Text,
|
||||
})
|
||||
}
|
||||
case "result":
|
||||
if payload == "" {
|
||||
return false, nil
|
||||
}
|
||||
var event cursorResultEvent
|
||||
if err := json.Unmarshal([]byte(payload), &event); err != nil {
|
||||
return false, nil
|
||||
}
|
||||
if !receivedAssistantText && strings.TrimSpace(event.Text) != "" {
|
||||
receivedResultText = true
|
||||
callback(ai.StreamChunk{Content: event.Text})
|
||||
}
|
||||
if isCursorRunFailureStatus(event.Status) {
|
||||
callback(ai.StreamChunk{
|
||||
Error: cursorRunStatusMessage(event.Status, event.Text),
|
||||
Done: true,
|
||||
})
|
||||
completedExplicitly = true
|
||||
return true, nil
|
||||
}
|
||||
case "error":
|
||||
if payload == "" {
|
||||
callback(ai.StreamChunk{Error: "Cursor 流式请求失败", Done: true})
|
||||
completedExplicitly = true
|
||||
return true, nil
|
||||
}
|
||||
var event cursorErrorEvent
|
||||
if err := json.Unmarshal([]byte(payload), &event); err != nil {
|
||||
callback(ai.StreamChunk{Error: "Cursor 流式请求失败", Done: true})
|
||||
completedExplicitly = true
|
||||
return true, nil
|
||||
}
|
||||
errMessage := strings.TrimSpace(event.Message)
|
||||
if errMessage == "" {
|
||||
errMessage = "Cursor 流式请求失败"
|
||||
}
|
||||
callback(ai.StreamChunk{Error: errMessage, Done: true})
|
||||
completedExplicitly = true
|
||||
return true, nil
|
||||
case "done":
|
||||
callback(ai.StreamChunk{Done: true})
|
||||
completedExplicitly = true
|
||||
return true, nil
|
||||
}
|
||||
return false, nil
|
||||
}
|
||||
|
||||
for scanner.Scan() {
|
||||
line := scanner.Text()
|
||||
switch {
|
||||
case strings.TrimSpace(line) == "":
|
||||
done, dispatchErr := dispatchEvent(currentEventType, currentDataLines)
|
||||
currentEventType = ""
|
||||
currentDataLines = nil
|
||||
if dispatchErr != nil {
|
||||
return dispatchErr
|
||||
}
|
||||
if done {
|
||||
return nil
|
||||
}
|
||||
case strings.HasPrefix(line, "event:"):
|
||||
currentEventType = strings.TrimSpace(strings.TrimPrefix(line, "event:"))
|
||||
case strings.HasPrefix(line, "data:"):
|
||||
currentDataLines = append(currentDataLines, strings.TrimSpace(strings.TrimPrefix(line, "data:")))
|
||||
}
|
||||
}
|
||||
|
||||
if err := scanner.Err(); err != nil {
|
||||
return fmt.Errorf("读取 Cursor 流式响应失败: %w", err)
|
||||
}
|
||||
|
||||
if len(currentDataLines) > 0 || strings.TrimSpace(currentEventType) != "" {
|
||||
done, dispatchErr := dispatchEvent(currentEventType, currentDataLines)
|
||||
if dispatchErr != nil {
|
||||
return dispatchErr
|
||||
}
|
||||
if done {
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
if !completedExplicitly {
|
||||
if !receivedAssistantText && !receivedResultText {
|
||||
callback(ai.StreamChunk{Error: "未收到任何有效响应内容,请检查 Cursor 配置或模型权限", Done: true})
|
||||
return nil
|
||||
}
|
||||
callback(ai.StreamChunk{Done: true})
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (p *CursorAgentProvider) createAgent(ctx context.Context, req ai.ChatRequest) (string, string, error) {
|
||||
requestBody, err := buildCursorCreateAgentRequest(req, p.config.Model)
|
||||
if err != nil {
|
||||
return "", "", err
|
||||
}
|
||||
|
||||
responseBody := cursorCreateAgentResponse{}
|
||||
if err := p.doJSONRequest(ctx, http.MethodPost, ResolveCursorAPIEndpoint(p.baseURL, "agents"), requestBody, &responseBody, "application/json"); err != nil {
|
||||
return "", "", err
|
||||
}
|
||||
|
||||
agentID := strings.TrimSpace(responseBody.Agent.ID)
|
||||
runID := strings.TrimSpace(responseBody.Run.ID)
|
||||
if agentID == "" || runID == "" {
|
||||
return "", "", fmt.Errorf("Cursor 创建 agent 成功,但未返回有效的 agentId/runId")
|
||||
}
|
||||
return agentID, runID, nil
|
||||
}
|
||||
|
||||
func buildCursorCreateAgentRequest(req ai.ChatRequest, model string) (cursorCreateAgentRequest, error) {
|
||||
prompt, err := buildCursorPrompt(req.Messages)
|
||||
if err != nil {
|
||||
return cursorCreateAgentRequest{}, err
|
||||
}
|
||||
|
||||
requestBody := cursorCreateAgentRequest{
|
||||
Prompt: cursorPrompt{
|
||||
Text: prompt,
|
||||
},
|
||||
}
|
||||
|
||||
if trimmedModel := strings.TrimSpace(model); trimmedModel != "" {
|
||||
requestBody.Model = &cursorModelSelection{ID: trimmedModel}
|
||||
}
|
||||
|
||||
return requestBody, nil
|
||||
}
|
||||
|
||||
func buildCursorPrompt(messages []ai.Message) (string, error) {
|
||||
requestMessages := messages
|
||||
if requestMessagesContainImages(messages) {
|
||||
requestMessages = stripImagesFromRequestMessages(messages)
|
||||
}
|
||||
|
||||
prompt := strings.TrimSpace(buildPrompt(requestMessages))
|
||||
if prompt == "" {
|
||||
return "", fmt.Errorf("请求内容不能为空")
|
||||
}
|
||||
return prompt, nil
|
||||
}
|
||||
|
||||
func (p *CursorAgentProvider) waitForRun(ctx context.Context, agentID string, runID string) (*cursorRunResponse, error) {
|
||||
ticker := time.NewTicker(cursorRunPollInterval)
|
||||
defer ticker.Stop()
|
||||
|
||||
for {
|
||||
run, err := p.getRun(ctx, agentID, runID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if isCursorRunTerminalStatus(run.Status) {
|
||||
if isCursorRunFailureStatus(run.Status) {
|
||||
return nil, fmt.Errorf("%s", cursorRunStatusMessage(run.Status, run.Result))
|
||||
}
|
||||
return run, nil
|
||||
}
|
||||
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return nil, ctx.Err()
|
||||
case <-ticker.C:
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (p *CursorAgentProvider) getRun(ctx context.Context, agentID string, runID string) (*cursorRunResponse, error) {
|
||||
endpoint := ResolveCursorAPIEndpoint(p.baseURL, fmt.Sprintf("agents/%s/runs/%s", agentID, runID))
|
||||
httpReq, err := http.NewRequestWithContext(ctx, http.MethodGet, endpoint, nil)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("创建 Cursor run 查询失败: %w", err)
|
||||
}
|
||||
httpReq.Header.Set("Accept", "application/json")
|
||||
httpReq.Header.Set("Authorization", "Bearer "+p.config.APIKey)
|
||||
for k, v := range p.config.Headers {
|
||||
httpReq.Header.Set(k, v)
|
||||
}
|
||||
|
||||
resp, err := p.client.Do(httpReq)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("查询 Cursor run 状态失败: %w", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
if resp.StatusCode < 200 || resp.StatusCode >= 300 {
|
||||
bodyBytes, _ := io.ReadAll(io.LimitReader(resp.Body, 4096))
|
||||
return nil, fmt.Errorf("Cursor run 查询失败 (HTTP %d): %s", resp.StatusCode, strings.TrimSpace(string(bodyBytes)))
|
||||
}
|
||||
|
||||
responseBody := cursorRunResponse{}
|
||||
if err := json.NewDecoder(resp.Body).Decode(&responseBody); err != nil {
|
||||
return nil, fmt.Errorf("解析 Cursor run 响应失败: %w", err)
|
||||
}
|
||||
return &responseBody, nil
|
||||
}
|
||||
|
||||
func (p *CursorAgentProvider) openRunStream(ctx context.Context, agentID string, runID string) (io.ReadCloser, error) {
|
||||
endpoint := ResolveCursorAPIEndpoint(p.baseURL, fmt.Sprintf("agents/%s/runs/%s/stream", agentID, runID))
|
||||
requestLog := logAIUpstreamRequestStart(p.Name(), http.MethodGet, endpoint, nil)
|
||||
|
||||
httpReq, err := http.NewRequestWithContext(ctx, http.MethodGet, endpoint, nil)
|
||||
if err != nil {
|
||||
logAIUpstreamRequestFinish(requestLog, 0, err)
|
||||
return nil, fmt.Errorf("创建 Cursor 流式请求失败: %w", err)
|
||||
}
|
||||
httpReq.Header.Set("Authorization", "Bearer "+p.config.APIKey)
|
||||
httpReq.Header.Set("Accept", "text/event-stream")
|
||||
httpReq.Header.Set("Cache-Control", "no-cache")
|
||||
for k, v := range p.config.Headers {
|
||||
httpReq.Header.Set(k, v)
|
||||
}
|
||||
|
||||
resp, err := p.client.Do(httpReq)
|
||||
if err != nil {
|
||||
logAIUpstreamRequestFinish(requestLog, 0, err)
|
||||
return nil, fmt.Errorf("发送 Cursor 流式请求失败: %w", err)
|
||||
}
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
defer resp.Body.Close()
|
||||
bodyBytes, _ := io.ReadAll(io.LimitReader(resp.Body, 4096))
|
||||
statusErr := fmt.Errorf("Cursor API 返回错误 (HTTP %d): %s", resp.StatusCode, strings.TrimSpace(string(bodyBytes)))
|
||||
logAIUpstreamRequestFinish(requestLog, resp.StatusCode, statusErr)
|
||||
return nil, statusErr
|
||||
}
|
||||
|
||||
logAIUpstreamRequestFinish(requestLog, resp.StatusCode, nil)
|
||||
return resp.Body, nil
|
||||
}
|
||||
|
||||
func (p *CursorAgentProvider) doJSONRequest(ctx context.Context, method string, endpoint string, body any, target any, accept string) error {
|
||||
var requestBody io.Reader
|
||||
if body != nil {
|
||||
bodyBytes, err := json.Marshal(body)
|
||||
if err != nil {
|
||||
return fmt.Errorf("序列化 Cursor 请求失败: %w", err)
|
||||
}
|
||||
requestBody = bytes.NewReader(bodyBytes)
|
||||
}
|
||||
|
||||
requestLog := logAIUpstreamRequestStart(p.Name(), method, endpoint, body)
|
||||
httpReq, err := http.NewRequestWithContext(ctx, method, endpoint, requestBody)
|
||||
if err != nil {
|
||||
logAIUpstreamRequestFinish(requestLog, 0, err)
|
||||
return fmt.Errorf("创建 Cursor 请求失败: %w", err)
|
||||
}
|
||||
|
||||
if body != nil {
|
||||
httpReq.Header.Set("Content-Type", "application/json")
|
||||
}
|
||||
if strings.TrimSpace(accept) != "" {
|
||||
httpReq.Header.Set("Accept", accept)
|
||||
}
|
||||
httpReq.Header.Set("Authorization", "Bearer "+p.config.APIKey)
|
||||
for k, v := range p.config.Headers {
|
||||
httpReq.Header.Set(k, v)
|
||||
}
|
||||
|
||||
resp, err := p.client.Do(httpReq)
|
||||
if err != nil {
|
||||
logAIUpstreamRequestFinish(requestLog, 0, err)
|
||||
return fmt.Errorf("发送 Cursor 请求失败: %w", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
if resp.StatusCode < 200 || resp.StatusCode >= 300 {
|
||||
bodyBytes, _ := io.ReadAll(io.LimitReader(resp.Body, 4096))
|
||||
statusErr := fmt.Errorf("Cursor API 返回错误 (HTTP %d): %s", resp.StatusCode, strings.TrimSpace(string(bodyBytes)))
|
||||
logAIUpstreamRequestFinish(requestLog, resp.StatusCode, statusErr)
|
||||
return statusErr
|
||||
}
|
||||
|
||||
if target != nil {
|
||||
if err := json.NewDecoder(resp.Body).Decode(target); err != nil {
|
||||
logAIUpstreamRequestFinish(requestLog, resp.StatusCode, err)
|
||||
return fmt.Errorf("解析 Cursor 响应失败: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
logAIUpstreamRequestFinish(requestLog, resp.StatusCode, nil)
|
||||
return nil
|
||||
}
|
||||
|
||||
func isCursorRunTerminalStatus(status string) bool {
|
||||
switch strings.ToUpper(strings.TrimSpace(status)) {
|
||||
case "FINISHED", "ERROR", "CANCELLED", "EXPIRED":
|
||||
return true
|
||||
default:
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
func isCursorRunFailureStatus(status string) bool {
|
||||
switch strings.ToUpper(strings.TrimSpace(status)) {
|
||||
case "ERROR", "CANCELLED", "EXPIRED":
|
||||
return true
|
||||
default:
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
func cursorRunStatusMessage(status string, result string) string {
|
||||
normalizedStatus := strings.ToUpper(strings.TrimSpace(status))
|
||||
if text := strings.TrimSpace(result); text != "" {
|
||||
return fmt.Sprintf("Cursor 运行结束(%s):%s", normalizedStatus, text)
|
||||
}
|
||||
return fmt.Sprintf("Cursor 运行结束(%s)", normalizedStatus)
|
||||
}
|
||||
166
internal/ai/provider/cursor_agent_test.go
Normal file
166
internal/ai/provider/cursor_agent_test.go
Normal file
@@ -0,0 +1,166 @@
|
||||
package provider
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"strings"
|
||||
"sync/atomic"
|
||||
"testing"
|
||||
|
||||
"GoNavi-Wails/internal/ai"
|
||||
)
|
||||
|
||||
func TestCursorAgentProviderChat_PollsUntilFinished(t *testing.T) {
|
||||
var (
|
||||
receivedAuthorization string
|
||||
receivedPromptText string
|
||||
pollCount int32
|
||||
)
|
||||
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
switch {
|
||||
case r.Method == http.MethodPost && r.URL.Path == "/v1/agents":
|
||||
receivedAuthorization = r.Header.Get("Authorization")
|
||||
var body struct {
|
||||
Prompt struct {
|
||||
Text string `json:"text"`
|
||||
} `json:"prompt"`
|
||||
Model *struct {
|
||||
ID string `json:"id"`
|
||||
} `json:"model"`
|
||||
}
|
||||
if err := json.NewDecoder(r.Body).Decode(&body); err != nil {
|
||||
t.Fatalf("decode create agent body: %v", err)
|
||||
}
|
||||
receivedPromptText = body.Prompt.Text
|
||||
if body.Model == nil || body.Model.ID != "composer-latest" {
|
||||
t.Fatalf("expected model to be forwarded, got %#v", body.Model)
|
||||
}
|
||||
_ = json.NewEncoder(w).Encode(map[string]any{
|
||||
"agent": map[string]any{"id": "bc-1"},
|
||||
"run": map[string]any{"id": "run-1", "agentId": "bc-1"},
|
||||
})
|
||||
case r.Method == http.MethodGet && r.URL.Path == "/v1/agents/bc-1/runs/run-1":
|
||||
next := atomic.AddInt32(&pollCount, 1)
|
||||
if next == 1 {
|
||||
_ = json.NewEncoder(w).Encode(map[string]any{
|
||||
"id": "run-1",
|
||||
"agentId": "bc-1",
|
||||
"status": "RUNNING",
|
||||
})
|
||||
return
|
||||
}
|
||||
_ = json.NewEncoder(w).Encode(map[string]any{
|
||||
"id": "run-1",
|
||||
"agentId": "bc-1",
|
||||
"status": "FINISHED",
|
||||
"result": "done from cursor",
|
||||
"durationMs": 1234,
|
||||
})
|
||||
default:
|
||||
http.NotFound(w, r)
|
||||
}
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
provider, err := NewCursorAgentProvider(ai.ProviderConfig{
|
||||
Name: "Cursor",
|
||||
BaseURL: server.URL + "/v1",
|
||||
APIKey: "cursor-key",
|
||||
Model: "composer-latest",
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
|
||||
resp, err := provider.Chat(context.Background(), ai.ChatRequest{
|
||||
Messages: []ai.Message{
|
||||
{Role: "system", Content: "You are helpful"},
|
||||
{Role: "user", Content: "hello cursor"},
|
||||
},
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("chat failed: %v", err)
|
||||
}
|
||||
|
||||
if receivedAuthorization != "Bearer cursor-key" {
|
||||
t.Fatalf("expected bearer auth header, got %q", receivedAuthorization)
|
||||
}
|
||||
if !strings.Contains(receivedPromptText, "You are helpful") || !strings.Contains(receivedPromptText, "hello cursor") {
|
||||
t.Fatalf("expected prompt text to include flattened history, got %q", receivedPromptText)
|
||||
}
|
||||
if resp.Content != "done from cursor" {
|
||||
t.Fatalf("expected final result content, got %q", resp.Content)
|
||||
}
|
||||
if atomic.LoadInt32(&pollCount) < 2 {
|
||||
t.Fatalf("expected provider to poll until terminal status, got %d polls", pollCount)
|
||||
}
|
||||
}
|
||||
|
||||
func TestCursorAgentProviderChatStream_MapsAssistantAndThinkingEvents(t *testing.T) {
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
switch {
|
||||
case r.Method == http.MethodPost && r.URL.Path == "/v1/agents":
|
||||
_ = json.NewEncoder(w).Encode(map[string]any{
|
||||
"agent": map[string]any{"id": "bc-2"},
|
||||
"run": map[string]any{"id": "run-2", "agentId": "bc-2"},
|
||||
})
|
||||
case r.Method == http.MethodGet && r.URL.Path == "/v1/agents/bc-2/runs/run-2/stream":
|
||||
w.Header().Set("Content-Type", "text/event-stream")
|
||||
_, _ = w.Write([]byte("event: status\n"))
|
||||
_, _ = w.Write([]byte("data: {\"runId\":\"run-2\",\"status\":\"RUNNING\"}\n\n"))
|
||||
_, _ = w.Write([]byte("event: thinking\n"))
|
||||
_, _ = w.Write([]byte("data: {\"text\":\"plan first\"}\n\n"))
|
||||
_, _ = w.Write([]byte("event: tool_call\n"))
|
||||
_, _ = w.Write([]byte("data: {\"callId\":\"tool-1\",\"name\":\"shell\",\"status\":\"running\"}\n\n"))
|
||||
_, _ = w.Write([]byte("event: assistant\n"))
|
||||
_, _ = w.Write([]byte("data: {\"text\":\"partial answer\"}\n\n"))
|
||||
_, _ = w.Write([]byte("event: result\n"))
|
||||
_, _ = w.Write([]byte("data: {\"runId\":\"run-2\",\"status\":\"FINISHED\",\"text\":\"final answer\"}\n\n"))
|
||||
_, _ = w.Write([]byte("event: done\n"))
|
||||
_, _ = w.Write([]byte("data: {}\n\n"))
|
||||
default:
|
||||
http.NotFound(w, r)
|
||||
}
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
provider, err := NewCursorAgentProvider(ai.ProviderConfig{
|
||||
Name: "Cursor",
|
||||
BaseURL: server.URL + "/v1",
|
||||
APIKey: "cursor-key",
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
|
||||
var chunks []ai.StreamChunk
|
||||
err = provider.ChatStream(context.Background(), ai.ChatRequest{
|
||||
Messages: []ai.Message{
|
||||
{Role: "user", Content: "stream this"},
|
||||
},
|
||||
}, func(chunk ai.StreamChunk) {
|
||||
chunks = append(chunks, chunk)
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("chat stream failed: %v", err)
|
||||
}
|
||||
|
||||
if len(chunks) < 3 {
|
||||
t.Fatalf("expected multiple stream chunks, got %d", len(chunks))
|
||||
}
|
||||
if chunks[0].Thinking != "plan first" {
|
||||
t.Fatalf("expected thinking chunk, got %#v", chunks[0])
|
||||
}
|
||||
if chunks[1].Content != "partial answer" {
|
||||
t.Fatalf("expected assistant content chunk, got %#v", chunks[1])
|
||||
}
|
||||
if len(chunks[1].ToolCalls) != 0 {
|
||||
t.Fatalf("expected cursor tool_call events to stay unmapped, got %#v", chunks[1].ToolCalls)
|
||||
}
|
||||
if !chunks[len(chunks)-1].Done {
|
||||
t.Fatalf("expected final done chunk, got %#v", chunks[len(chunks)-1])
|
||||
}
|
||||
}
|
||||
@@ -9,7 +9,7 @@ import (
|
||||
)
|
||||
|
||||
// CustomProvider 自定义 Provider,根据 apiFormat 选择底层协议
|
||||
// 支持 openai / anthropic / gemini 三种 API 格式
|
||||
// 支持 openai / anthropic / gemini / cursor-agent 等 API 格式
|
||||
type CustomProvider struct {
|
||||
inner Provider
|
||||
name string
|
||||
@@ -33,6 +33,8 @@ func NewCustomProvider(config ai.ProviderConfig) (Provider, error) {
|
||||
innerProvider, err = NewAnthropicProvider(config)
|
||||
case "gemini":
|
||||
innerProvider, err = NewGeminiProvider(config)
|
||||
case "cursor-agent":
|
||||
innerProvider, err = NewCursorAgentProvider(config)
|
||||
case "claude-cli":
|
||||
innerProvider, err = NewClaudeCLIProvider(config)
|
||||
case "codebuddy-cli":
|
||||
|
||||
@@ -510,7 +510,7 @@ func (s *Service) AITestProvider(config ai.ProviderConfig) map[string]interface{
|
||||
var err error
|
||||
|
||||
switch providerType {
|
||||
case "openai", "anthropic", "gemini":
|
||||
case "openai", "anthropic", "gemini", "cursor-agent":
|
||||
req, reqErr := newProviderHealthCheckRequest(config)
|
||||
if reqErr != nil {
|
||||
err = s.localizeProviderHealthCheckRequestError(reqErr)
|
||||
@@ -750,6 +750,8 @@ func resolveModelsURL(config ai.ProviderConfig) string {
|
||||
baseURL = "https://generativelanguage.googleapis.com"
|
||||
}
|
||||
return baseURL + "/v1beta/models?key=" + config.APIKey
|
||||
case "cursor-agent":
|
||||
return provider.ResolveCursorAPIEndpoint(baseURL, "models")
|
||||
case "codebuddy-cli":
|
||||
return ""
|
||||
case "openai":
|
||||
@@ -779,6 +781,8 @@ func newModelsRequest(config ai.ProviderConfig) (*http.Request, error) {
|
||||
}
|
||||
case "gemini":
|
||||
// Gemini 使用 query string 传递 key,无需额外鉴权头
|
||||
case "cursor-agent":
|
||||
req.Header.Set("Authorization", "Bearer "+config.APIKey)
|
||||
default:
|
||||
req.Header.Set("Authorization", "Bearer "+config.APIKey)
|
||||
}
|
||||
@@ -935,6 +939,8 @@ func fetchModels(config ai.ProviderConfig) ([]string, error) {
|
||||
return fetchAnthropicModels(config)
|
||||
case "gemini":
|
||||
return fetchGeminiModels(config)
|
||||
case "cursor-agent":
|
||||
return fetchCursorModels(config)
|
||||
case "codebuddy-cli":
|
||||
return append([]string(nil), config.Models...), nil
|
||||
default:
|
||||
@@ -1057,6 +1063,42 @@ func fetchGeminiModels(config ai.ProviderConfig) ([]string, error) {
|
||||
return models, nil
|
||||
}
|
||||
|
||||
func fetchCursorModels(config ai.ProviderConfig) ([]string, error) {
|
||||
req, err := newModelsRequest(config)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
client := &http.Client{Timeout: 15 * time.Second}
|
||||
resp, err := client.Do(req)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("请求模型列表失败: %w", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
body, _ := io.ReadAll(io.LimitReader(resp.Body, 1024))
|
||||
return nil, fmt.Errorf("获取模型列表失败 (HTTP %d): %s", resp.StatusCode, string(body))
|
||||
}
|
||||
|
||||
var result struct {
|
||||
Items []struct {
|
||||
ID string `json:"id"`
|
||||
} `json:"items"`
|
||||
}
|
||||
if err := json.NewDecoder(resp.Body).Decode(&result); err != nil {
|
||||
return nil, fmt.Errorf("解析模型列表失败: %w", err)
|
||||
}
|
||||
|
||||
models := make([]string, 0, len(result.Items))
|
||||
for _, item := range result.Items {
|
||||
if strings.TrimSpace(item.ID) != "" {
|
||||
models = append(models, item.ID)
|
||||
}
|
||||
}
|
||||
return models, nil
|
||||
}
|
||||
|
||||
// --- 安全控制 ---
|
||||
|
||||
// AIGetSafetyLevel 获取当前安全级别
|
||||
|
||||
101
internal/ai/service/service_cursor_test.go
Normal file
101
internal/ai/service/service_cursor_test.go
Normal file
@@ -0,0 +1,101 @@
|
||||
package aiservice
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
|
||||
"GoNavi-Wails/internal/ai"
|
||||
)
|
||||
|
||||
func TestResolveModelsURL_UsesCursorModelsEndpoint(t *testing.T) {
|
||||
url := resolveModelsURL(ai.ProviderConfig{
|
||||
Type: "custom",
|
||||
APIFormat: "cursor-agent",
|
||||
BaseURL: "https://api.cursor.com/v1",
|
||||
})
|
||||
if url != "https://api.cursor.com/v1/models" {
|
||||
t.Fatalf("expected cursor models endpoint, got %q", url)
|
||||
}
|
||||
}
|
||||
|
||||
func TestAITestProvider_UsesCursorModelsEndpointAndBearerAuth(t *testing.T) {
|
||||
var (
|
||||
receivedPath string
|
||||
receivedAuthorization string
|
||||
)
|
||||
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
receivedPath = r.URL.Path
|
||||
receivedAuthorization = r.Header.Get("Authorization")
|
||||
_ = json.NewEncoder(w).Encode(map[string]any{
|
||||
"items": []map[string]any{
|
||||
{"id": "composer-2"},
|
||||
},
|
||||
})
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
service := NewService()
|
||||
result := service.AITestProvider(ai.ProviderConfig{
|
||||
Type: "custom",
|
||||
APIFormat: "cursor-agent",
|
||||
BaseURL: server.URL + "/v1",
|
||||
APIKey: "cursor-key",
|
||||
})
|
||||
|
||||
if result["success"] != true {
|
||||
t.Fatalf("expected AITestProvider to succeed, got %#v", result)
|
||||
}
|
||||
if receivedPath != "/v1/models" {
|
||||
t.Fatalf("expected cursor health check to hit /v1/models, got %q", receivedPath)
|
||||
}
|
||||
if receivedAuthorization != "Bearer cursor-key" {
|
||||
t.Fatalf("expected bearer auth header, got %q", receivedAuthorization)
|
||||
}
|
||||
}
|
||||
|
||||
func TestAIListModels_FetchesCursorModelItems(t *testing.T) {
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
if r.URL.Path != "/v1/models" {
|
||||
http.NotFound(w, r)
|
||||
return
|
||||
}
|
||||
_ = json.NewEncoder(w).Encode(map[string]any{
|
||||
"items": []map[string]any{
|
||||
{"id": "composer-2"},
|
||||
{"id": "composer-latest"},
|
||||
},
|
||||
})
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
service := NewService()
|
||||
service.providers = []ai.ProviderConfig{
|
||||
{
|
||||
ID: "provider-cursor",
|
||||
Type: "custom",
|
||||
APIFormat: "cursor-agent",
|
||||
BaseURL: server.URL + "/v1",
|
||||
APIKey: "cursor-key",
|
||||
},
|
||||
}
|
||||
service.activeProvider = "provider-cursor"
|
||||
|
||||
result := service.AIListModels()
|
||||
if result["success"] != true {
|
||||
t.Fatalf("expected AIListModels to succeed, got %#v", result)
|
||||
}
|
||||
|
||||
models, ok := result["models"].([]string)
|
||||
if !ok {
|
||||
t.Fatalf("expected []string models, got %#v", result["models"])
|
||||
}
|
||||
if len(models) != 2 || models[0] != "composer-2" || models[1] != "composer-latest" {
|
||||
t.Fatalf("unexpected models: %#v", models)
|
||||
}
|
||||
if source, _ := result["source"].(string); source != "api" {
|
||||
t.Fatalf("expected api source, got %#v", result["source"])
|
||||
}
|
||||
}
|
||||
@@ -80,7 +80,7 @@ type ProviderConfig struct {
|
||||
BaseURL string `json:"baseUrl"`
|
||||
Model string `json:"model"`
|
||||
Models []string `json:"models,omitempty"`
|
||||
APIFormat string `json:"apiFormat,omitempty"` // custom 专用: openai | anthropic | gemini | claude-cli | codebuddy-cli
|
||||
APIFormat string `json:"apiFormat,omitempty"` // custom 专用: openai | anthropic | gemini | cursor-agent | claude-cli | codebuddy-cli
|
||||
Headers map[string]string `json:"headers,omitempty"`
|
||||
MaxTokens int `json:"maxTokens"`
|
||||
Temperature float64 `json:"temperature"`
|
||||
|
||||
Reference in New Issue
Block a user