mirror of
https://github.com/Syngnat/GoNavi.git
synced 2026-05-19 02:29:30 +08:00
Compare commits
6 Commits
feature/ai
...
release/0.
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
09d013f27d | ||
|
|
09aa526570 | ||
|
|
5844cd7c01 | ||
|
|
4f74c44147 | ||
|
|
a5fdfefa2d | ||
|
|
37ac13b94e |
@@ -6,7 +6,6 @@ import { DBGetDatabases, DBGetTables } from '../../wailsjs/go/app/App';
|
||||
import type { OverlayWorkbenchTheme } from '../utils/overlayWorkbenchTheme';
|
||||
import { AIChatMessage, AIToolCall } from '../types';
|
||||
import { DownOutlined } from '@ant-design/icons';
|
||||
import { message as antdMessage } from 'antd';
|
||||
import './AIChatPanel.css';
|
||||
|
||||
import { AIChatHeader } from './ai/AIChatHeader';
|
||||
@@ -14,6 +13,12 @@ import { AIChatWelcome } from './ai/AIChatWelcome';
|
||||
import { AIMessageBubble } from './ai/AIMessageBubble';
|
||||
import { AIChatInput } from './ai/AIChatInput';
|
||||
import { AIHistoryDrawer } from './ai/AIHistoryDrawer';
|
||||
import type { AIComposerNotice } from '../utils/aiComposerNotice';
|
||||
import {
|
||||
buildMissingModelNotice,
|
||||
buildMissingProviderNotice,
|
||||
buildModelFetchFailedNotice,
|
||||
} from '../utils/aiComposerNotice';
|
||||
|
||||
interface AIChatPanelProps {
|
||||
width?: number;
|
||||
@@ -211,6 +216,7 @@ export const AIChatPanel: React.FC<AIChatPanelProps> = ({
|
||||
const [dynamicModels, setDynamicModels] = useState<string[]>([]);
|
||||
const [showScrollBottom, setShowScrollBottom] = useState(false);
|
||||
const [loadingModels, setLoadingModels] = useState(false);
|
||||
const [composerNotice, setComposerNotice] = useState<AIComposerNotice | null>(null);
|
||||
const [panelWidth, setPanelWidth] = useState(width);
|
||||
const [isResizing, setIsResizing] = useState(false);
|
||||
const [historyOpen, setHistoryOpen] = useState(false);
|
||||
@@ -224,9 +230,6 @@ export const AIChatPanel: React.FC<AIChatPanelProps> = ({
|
||||
const panelRef = useRef<HTMLDivElement>(null); // 面板 DOM ref,用于拖拽时直接操作宽度
|
||||
const dragWidthRef = useRef(0); // 拖拽过程中的实时宽度(不触发 React 重渲染)
|
||||
|
||||
// 面板内部 toast 通知(不在屏幕顶部,而在面板容器内显示)
|
||||
const [messageApi, messageContextHolder] = antdMessage.useMessage({ getContainer: () => panelRef.current || document.body });
|
||||
|
||||
const aiChatHistory = useStore(state => state.aiChatHistory);
|
||||
const aiActiveSessionId = useStore(state => state.aiActiveSessionId);
|
||||
const createNewAISession = useStore(state => state.createNewAISession);
|
||||
@@ -336,6 +339,7 @@ export const AIChatPanel: React.FC<AIChatPanelProps> = ({
|
||||
useEffect(() => {
|
||||
const handler = () => {
|
||||
setDynamicModels([]);
|
||||
setComposerNotice(null);
|
||||
activeProviderIdRef.current = null;
|
||||
loadActiveProvider();
|
||||
};
|
||||
@@ -350,6 +354,7 @@ export const AIChatPanel: React.FC<AIChatPanelProps> = ({
|
||||
const payload = { ...activeProvider, model: val };
|
||||
await Service?.AISaveProvider?.(payload);
|
||||
setActiveProvider(payload);
|
||||
setComposerNotice(null);
|
||||
} catch (e) { console.warn('Failed to update provider model', e); }
|
||||
};
|
||||
|
||||
@@ -358,33 +363,45 @@ export const AIChatPanel: React.FC<AIChatPanelProps> = ({
|
||||
useEffect(() => {
|
||||
if (activeProvider?.id && activeProvider.id !== activeProviderIdRef.current) {
|
||||
setDynamicModels([]);
|
||||
setComposerNotice(null);
|
||||
activeProviderIdRef.current = activeProvider.id;
|
||||
}
|
||||
// 供应商被删除后 activeProvider 变为 null,此时也必须清空残留模型
|
||||
if (!activeProvider) {
|
||||
setDynamicModels([]);
|
||||
setComposerNotice(null);
|
||||
activeProviderIdRef.current = null;
|
||||
}
|
||||
}, [activeProvider?.id, activeProvider]);
|
||||
|
||||
useEffect(() => {
|
||||
if (activeProvider?.model && String(activeProvider.model).trim()) {
|
||||
setComposerNotice(null);
|
||||
}
|
||||
}, [activeProvider?.model]);
|
||||
|
||||
|
||||
// dynamicModels 仅在内存中使用,不再写回供应商配置,避免污染静态 models 列表
|
||||
|
||||
const fetchDynamicModels = useCallback(async () => {
|
||||
try {
|
||||
setLoadingModels(true);
|
||||
setComposerNotice(null);
|
||||
const Service = (window as any).go?.aiservice?.Service;
|
||||
if (!Service) return;
|
||||
const result = await Service.AIListModels?.();
|
||||
if (result?.success && Array.isArray(result.models) && result.models.length > 0) {
|
||||
const sortedModels = [...result.models].sort((a, b) => a.localeCompare(b));
|
||||
setDynamicModels(sortedModels);
|
||||
setComposerNotice(null);
|
||||
} else if (result && !result.success) {
|
||||
messageApi.warning(result.error || '获取模型列表失败,可手动输入模型名称');
|
||||
setDynamicModels([]);
|
||||
setComposerNotice(buildModelFetchFailedNotice(result.error));
|
||||
}
|
||||
} catch (e: any) {
|
||||
console.warn('Failed to fetch models', e);
|
||||
messageApi.warning('获取模型列表失败: ' + (e?.message || '未知错误'));
|
||||
setDynamicModels([]);
|
||||
setComposerNotice(buildModelFetchFailedNotice('获取模型列表失败:' + (e?.message || '未知错误')));
|
||||
} finally {
|
||||
setLoadingModels(false);
|
||||
}
|
||||
@@ -1030,13 +1047,14 @@ SELECT * FROM users WHERE status = 1;
|
||||
|
||||
// 前置校验:必须配置供应商且选择模型后才能发送
|
||||
if (!activeProvider) {
|
||||
messageApi.warning('请先在 AI 设置中配置供应商');
|
||||
setComposerNotice(buildMissingProviderNotice());
|
||||
return;
|
||||
}
|
||||
if (!activeProvider.model || !activeProvider.model.trim()) {
|
||||
messageApi.warning('请先选择模型 ID(点击工具栏的模型下拉框选择)');
|
||||
setComposerNotice(buildMissingModelNotice());
|
||||
return;
|
||||
}
|
||||
setComposerNotice(null);
|
||||
|
||||
toolCallRoundRef.current = 0; // 重置工具调用轮次计数
|
||||
nudgeCountRef.current = 0; // 重置催促计数
|
||||
@@ -1258,7 +1276,6 @@ SELECT * FROM users WHERE status = 1;
|
||||
|
||||
return (
|
||||
<div ref={panelRef} className="ai-chat-panel" style={{ width: panelWidth, background: bgColor || 'transparent', color: textColor, borderLeft: overlayTheme.shellBorder, position: 'relative' }}>
|
||||
{messageContextHolder}
|
||||
<div className={`ai-resize-handle${isResizing ? ' active' : ''}`} onMouseDown={handleResizeStart} />
|
||||
|
||||
{isResizing && panelRect.current && createPortal(
|
||||
@@ -1366,6 +1383,7 @@ SELECT * FROM users WHERE status = 1;
|
||||
activeProvider={activeProvider}
|
||||
dynamicModels={dynamicModels}
|
||||
loadingModels={loadingModels}
|
||||
composerNotice={composerNotice}
|
||||
onModelChange={handleModelChange}
|
||||
onFetchModels={fetchDynamicModels}
|
||||
textareaRef={textareaRef}
|
||||
|
||||
@@ -2,6 +2,24 @@ import React, { useState, useEffect, useCallback, useRef } from 'react';
|
||||
import { Modal, Button, Input, Select, Form, message as antdMessage, Tooltip, Tabs, Space, Popconfirm, Slider } from 'antd';
|
||||
import { PlusOutlined, DeleteOutlined, EditOutlined, CheckOutlined, ApiOutlined, SafetyCertificateOutlined, RobotOutlined, ThunderboltOutlined, CloudOutlined, ExperimentOutlined, KeyOutlined, LinkOutlined, AppstoreOutlined, ToolOutlined } from '@ant-design/icons';
|
||||
import type { AIProviderConfig, AIProviderType, AISafetyLevel, AIContextLevel } from '../types';
|
||||
import {
|
||||
getProviderFingerprint,
|
||||
getProviderHostname,
|
||||
matchQwenPresetKey,
|
||||
QWEN_BAILIAN_ANTHROPIC_BASE_URL,
|
||||
QWEN_CODING_PLAN_ANTHROPIC_BASE_URL,
|
||||
QWEN_CODING_PLAN_MODELS,
|
||||
resolvePresetBaseURL,
|
||||
resolvePresetModelSelection,
|
||||
resolvePresetTransport,
|
||||
} from '../utils/aiProviderPresets';
|
||||
import {
|
||||
PROVIDER_PRESET_CARD_BASE_STYLE,
|
||||
PROVIDER_PRESET_CARD_CONTENT_STYLE,
|
||||
PROVIDER_PRESET_CARD_DESCRIPTION_STYLE,
|
||||
PROVIDER_PRESET_GRID_STYLE,
|
||||
PROVIDER_PRESET_CARD_TITLE_STYLE,
|
||||
} from '../utils/aiSettingsPresetLayout';
|
||||
|
||||
import type { OverlayWorkbenchTheme } from '../utils/overlayWorkbenchTheme';
|
||||
|
||||
@@ -20,6 +38,7 @@ interface ProviderPreset {
|
||||
desc: string;
|
||||
color: string;
|
||||
backendType: AIProviderType;
|
||||
fixedApiFormat?: string;
|
||||
defaultBaseUrl: string;
|
||||
defaultModel: string;
|
||||
models: string[];
|
||||
@@ -28,12 +47,14 @@ interface ProviderPreset {
|
||||
const PROVIDER_PRESETS: ProviderPreset[] = [
|
||||
{ key: 'openai', label: 'OpenAI', icon: <ApiOutlined />, desc: 'GPT-5.4 / 5.3 系列', color: '#10b981', backendType: 'openai', defaultBaseUrl: 'https://api.openai.com/v1', defaultModel: 'gpt-4o', models: [] },
|
||||
{ key: 'deepseek', label: 'DeepSeek', icon: <ThunderboltOutlined />, desc: 'DeepSeek-V4 / R1', color: '#3b82f6', backendType: 'openai', defaultBaseUrl: 'https://api.deepseek.com/v1', defaultModel: 'deepseek-chat', models: [] },
|
||||
{ key: 'qwen', label: '通义千问', icon: <CloudOutlined />, desc: 'Qwen3.5 / Qwen3 系列', color: '#6366f1', backendType: 'openai', defaultBaseUrl: 'https://dashscope.aliyuncs.com/compatible-mode/v1', defaultModel: 'qwen-max', models: [] },
|
||||
{ key: 'qwen-bailian', label: '通义千问(百炼通用)', icon: <CloudOutlined />, desc: '百炼 Anthropic 兼容 / 模型从远端拉取', color: '#6366f1', backendType: 'anthropic', defaultBaseUrl: QWEN_BAILIAN_ANTHROPIC_BASE_URL, defaultModel: '', models: [] },
|
||||
{ key: 'qwen-coding-plan', label: '通义千问(Coding Plan)', icon: <CloudOutlined />, desc: 'Claude Code CLI 代理链路 / 使用官方支持模型清单', color: '#4f46e5', backendType: 'custom', fixedApiFormat: 'claude-cli', defaultBaseUrl: QWEN_CODING_PLAN_ANTHROPIC_BASE_URL, defaultModel: '', models: QWEN_CODING_PLAN_MODELS },
|
||||
{ key: 'zhipu', label: '智谱 GLM', icon: <ExperimentOutlined />, desc: 'GLM-5 / GLM-5-Turbo', color: '#0ea5e9', backendType: 'openai', defaultBaseUrl: 'https://open.bigmodel.cn/api/paas/v4', defaultModel: 'glm-4', models: [] },
|
||||
{ key: 'moonshot', label: 'Kimi', icon: <ExperimentOutlined />, desc: 'Kimi K2.5 (Anthropic 兼容)', color: '#0d9488', backendType: 'anthropic', defaultBaseUrl: 'https://api.moonshot.cn/anthropic', defaultModel: 'moonshot-v1-8k', models: [] },
|
||||
{ key: 'anthropic', label: 'Claude', icon: <ExperimentOutlined />, desc: 'Claude Opus/Sonnet', color: '#d97706', backendType: 'anthropic', defaultBaseUrl: 'https://api.anthropic.com', defaultModel: 'claude-3-5-sonnet-20241022', models: [] },
|
||||
{ key: 'gemini', label: 'Gemini', icon: <CloudOutlined />, desc: 'Gemini 3.1 / 2.5 系列', color: '#059669', backendType: 'gemini', defaultBaseUrl: 'https://generativelanguage.googleapis.com', defaultModel: 'gemini-2.5-flash', models: [] },
|
||||
{ key: 'volcengine', label: '火山引擎', icon: <CloudOutlined />, desc: '火山方舟 / 豆包大模型', color: '#0ea5e9', backendType: 'openai', defaultBaseUrl: 'https://ark.cn-beijing.volces.com/api/v3', defaultModel: 'ep-xxxxxx', models: [] },
|
||||
{ key: 'volcengine-ark', label: '火山方舟', icon: <CloudOutlined />, desc: 'Ark 通用推理 / 豆包模型', color: '#0ea5e9', backendType: 'openai', defaultBaseUrl: 'https://ark.cn-beijing.volces.com/api/v3', defaultModel: '', models: [] },
|
||||
{ key: 'volcengine-coding', label: '火山 Coding Plan', icon: <CloudOutlined />, desc: 'Ark Code / Coding Plan', color: '#0284c7', backendType: 'openai', defaultBaseUrl: 'https://ark.cn-beijing.volces.com/api/coding/v3', defaultModel: '', models: [] },
|
||||
{ key: 'minimax', label: 'MiniMax', icon: <ExperimentOutlined />, desc: 'M2.7 / M2.5 系列 (Anthropic 兼容)', color: '#e11d48', backendType: 'anthropic', defaultBaseUrl: 'https://api.minimaxi.com/anthropic', defaultModel: 'MiniMax-M2.7', models: ['MiniMax-M2.7', 'MiniMax-M2.7-highspeed', 'MiniMax-M2.5', 'MiniMax-M2.5-highspeed', 'MiniMax-M2.1', 'MiniMax-M2.1-highspeed', 'MiniMax-M2'] },
|
||||
{ key: 'ollama', label: 'Ollama', icon: <AppstoreOutlined />, desc: '本地部署开源模型', color: '#78716c', backendType: 'openai', defaultBaseUrl: 'http://localhost:11434/v1', defaultModel: 'llama3', models: [] },
|
||||
{ key: 'custom', label: '自定义', icon: <AppstoreOutlined />, desc: '自定义 API 端点', color: '#64748b', backendType: 'custom', defaultBaseUrl: '', defaultModel: '', models: [] },
|
||||
@@ -41,16 +62,21 @@ const PROVIDER_PRESETS: ProviderPreset[] = [
|
||||
|
||||
const findPreset = (key: string): ProviderPreset => PROVIDER_PRESETS.find(p => p.key === key) || PROVIDER_PRESETS[PROVIDER_PRESETS.length - 1];
|
||||
|
||||
const getProviderHostname = (raw?: string): string => {
|
||||
if (!raw) return '';
|
||||
try {
|
||||
return new URL(raw).hostname.toLowerCase();
|
||||
} catch {
|
||||
return '';
|
||||
}
|
||||
};
|
||||
|
||||
const matchProviderPreset = (provider: Pick<AIProviderConfig, 'type' | 'baseUrl'>): ProviderPreset => {
|
||||
const qwenPresetKey = matchQwenPresetKey(provider);
|
||||
if (qwenPresetKey) {
|
||||
return findPreset(qwenPresetKey);
|
||||
}
|
||||
const fingerprint = getProviderFingerprint(provider.baseUrl);
|
||||
const exactPreset = PROVIDER_PRESETS.find(pr =>
|
||||
pr.backendType === provider.type
|
||||
&& fingerprint !== ''
|
||||
&& fingerprint === getProviderFingerprint(pr.defaultBaseUrl)
|
||||
);
|
||||
if (exactPreset) {
|
||||
return exactPreset;
|
||||
}
|
||||
|
||||
const host = getProviderHostname(provider.baseUrl);
|
||||
if (host.endsWith('moonshot.cn')) {
|
||||
return findPreset('moonshot');
|
||||
@@ -143,11 +169,22 @@ const AISettingsModal: React.FC<AISettingsModalProps> = ({ open, onClose, darkMo
|
||||
const handleEditProvider = (p: AIProviderConfig) => {
|
||||
// 尝试根据 baseUrl 和 type 推断 preset
|
||||
const matchedPreset = matchProviderPreset(p);
|
||||
const resolvedTransport = resolvePresetTransport({
|
||||
presetBackendType: matchedPreset.backendType,
|
||||
presetFixedApiFormat: matchedPreset.fixedApiFormat,
|
||||
valuesApiFormat: p.apiFormat,
|
||||
});
|
||||
setEditingProvider(p);
|
||||
setIsEditing(true);
|
||||
setTestStatus('idle');
|
||||
form.resetFields();
|
||||
form.setFieldsValue({ ...p, type: matchedPreset.backendType, models: p.models || [], presetKey: matchedPreset.key, apiFormat: p.apiFormat || 'openai' });
|
||||
form.setFieldsValue({
|
||||
...p,
|
||||
type: resolvedTransport.type,
|
||||
models: p.models || [],
|
||||
presetKey: matchedPreset.key,
|
||||
apiFormat: resolvedTransport.apiFormat || p.apiFormat || 'openai',
|
||||
});
|
||||
};
|
||||
|
||||
const handleDeleteProvider = async (id: string) => {
|
||||
@@ -179,24 +216,38 @@ const AISettingsModal: React.FC<AISettingsModalProps> = ({ open, onClose, darkMo
|
||||
const Service = (window as any).go?.aiservice?.Service;
|
||||
|
||||
// 构建 payload,处理 model/models 逻辑
|
||||
const isCustomLike = values.presetKey === 'custom' || values.presetKey === 'ollama';
|
||||
const preset = findPreset(values.presetKey);
|
||||
const resolvedModels = isCustomLike ? (values.models || []) : preset.models;
|
||||
const fallbackModel = resolvedModels.length > 0 ? resolvedModels[0] : '';
|
||||
const finalModel = isCustomLike ? fallbackModel : (values.model || fallbackModel);
|
||||
const isCustomLike = values.presetKey === 'custom' || values.presetKey === 'ollama';
|
||||
const { model: finalModel, models: resolvedModels } = resolvePresetModelSelection({
|
||||
presetKey: values.presetKey,
|
||||
presetDefaultModel: preset.defaultModel,
|
||||
presetModels: preset.models,
|
||||
valuesModel: values.model,
|
||||
customModels: values.models,
|
||||
});
|
||||
// 内置供应商自动使用 preset label 作为名称
|
||||
const finalName = isCustomLike ? (values.name || preset.label) : preset.label;
|
||||
|
||||
const finalBaseUrl = values.baseUrl || preset.defaultBaseUrl;
|
||||
const finalBaseUrl = resolvePresetBaseURL({
|
||||
presetKey: values.presetKey,
|
||||
presetDefaultBaseUrl: preset.defaultBaseUrl,
|
||||
valuesBaseUrl: values.baseUrl,
|
||||
});
|
||||
const resolvedTransport = resolvePresetTransport({
|
||||
presetBackendType: preset.backendType,
|
||||
presetFixedApiFormat: preset.fixedApiFormat,
|
||||
valuesApiFormat: values.apiFormat,
|
||||
});
|
||||
|
||||
const payload = {
|
||||
...editingProvider,
|
||||
...values,
|
||||
...resolvedTransport,
|
||||
name: finalName,
|
||||
model: finalModel,
|
||||
models: resolvedModels,
|
||||
baseUrl: finalBaseUrl,
|
||||
apiFormat: values.apiFormat || 'openai',
|
||||
apiFormat: resolvedTransport.apiFormat,
|
||||
};
|
||||
// 后端 AISaveProvider 统一处理新增和更新,返回 void,失败抛异常
|
||||
await Service?.AISaveProvider?.(payload);
|
||||
@@ -240,8 +291,34 @@ const AISettingsModal: React.FC<AISettingsModalProps> = ({ open, onClose, darkMo
|
||||
setTestStatus('idle');
|
||||
const Service = (window as any).go?.aiservice?.Service;
|
||||
const preset = findPreset(values.presetKey || 'openai');
|
||||
const finalBaseUrl = values.baseUrl || preset.defaultBaseUrl;
|
||||
const res = await Service?.AITestProvider?.({ ...values, baseUrl: finalBaseUrl, maxTokens: Number(values.maxTokens) || 4096, temperature: Number(values.temperature) ?? 0.7 });
|
||||
const finalBaseUrl = resolvePresetBaseURL({
|
||||
presetKey: values.presetKey || 'openai',
|
||||
presetDefaultBaseUrl: preset.defaultBaseUrl,
|
||||
valuesBaseUrl: values.baseUrl,
|
||||
});
|
||||
const { model: finalModel, models: resolvedModels } = resolvePresetModelSelection({
|
||||
presetKey: values.presetKey || 'openai',
|
||||
presetDefaultModel: preset.defaultModel,
|
||||
presetModels: preset.models,
|
||||
valuesModel: values.model,
|
||||
customModels: values.models,
|
||||
});
|
||||
const resolvedTransport = resolvePresetTransport({
|
||||
presetBackendType: preset.backendType,
|
||||
presetFixedApiFormat: preset.fixedApiFormat,
|
||||
valuesApiFormat: values.apiFormat,
|
||||
});
|
||||
const res = await Service?.AITestProvider?.({
|
||||
...editingProvider,
|
||||
...values,
|
||||
...resolvedTransport,
|
||||
baseUrl: finalBaseUrl,
|
||||
model: finalModel,
|
||||
models: resolvedModels,
|
||||
maxTokens: Number(values.maxTokens) || 4096,
|
||||
temperature: Number(values.temperature) ?? 0.7,
|
||||
apiFormat: resolvedTransport.apiFormat,
|
||||
});
|
||||
if (res?.success) { setTestStatus('success'); void messageApi.success('连接成功'); }
|
||||
else { setTestStatus('error'); void messageApi.error(`测试失败: ${res?.message || '未知错误'}`); }
|
||||
} catch (e: any) { setTestStatus('error'); void messageApi.error(e?.message || '测试失败'); }
|
||||
@@ -250,9 +327,15 @@ const AISettingsModal: React.FC<AISettingsModalProps> = ({ open, onClose, darkMo
|
||||
|
||||
const handlePresetChange = (presetKey: string) => {
|
||||
const preset = findPreset(presetKey);
|
||||
const resolvedTransport = resolvePresetTransport({
|
||||
presetBackendType: preset.backendType,
|
||||
presetFixedApiFormat: preset.fixedApiFormat,
|
||||
valuesApiFormat: form.getFieldValue('apiFormat'),
|
||||
});
|
||||
form.setFieldsValue({
|
||||
presetKey,
|
||||
type: preset.backendType,
|
||||
type: resolvedTransport.type,
|
||||
apiFormat: resolvedTransport.apiFormat || 'openai',
|
||||
baseUrl: preset.defaultBaseUrl,
|
||||
model: preset.defaultModel,
|
||||
});
|
||||
@@ -307,7 +390,7 @@ const AISettingsModal: React.FC<AISettingsModalProps> = ({ open, onClose, darkMo
|
||||
<div style={{ fontSize: 12, color: overlayTheme.mutedText, marginTop: 4, display: 'flex', alignItems: 'center', gap: 6 }}>
|
||||
<span>{matchedPreset.label}</span>
|
||||
<span style={{ opacity: 0.4 }}>·</span>
|
||||
<span style={{ fontFamily: 'monospace', fontSize: 12 }}>{p.model}</span>
|
||||
<span style={{ fontFamily: 'monospace', fontSize: 12 }}>{p.model || '未选择模型'}</span>
|
||||
</div>
|
||||
</div>
|
||||
<Space size={2}>
|
||||
@@ -353,25 +436,24 @@ const AISettingsModal: React.FC<AISettingsModalProps> = ({ open, onClose, darkMo
|
||||
<AppstoreOutlined style={{ fontSize: 14 }} /> 服务类型
|
||||
</div>
|
||||
<Form.Item name="presetKey" noStyle>
|
||||
<div style={{ display: 'grid', gridTemplateColumns: '1fr 1fr 1fr', gap: 6 }}>
|
||||
<div style={PROVIDER_PRESET_GRID_STYLE}>
|
||||
{PROVIDER_PRESETS.map(pt => (
|
||||
<div key={pt.key} onClick={() => { form.setFieldValue('presetKey', pt.key); handlePresetChange(pt.key); }}
|
||||
style={{
|
||||
padding: '12px 14px', borderRadius: 12, cursor: 'pointer', transition: 'all 0.2s ease',
|
||||
...PROVIDER_PRESET_CARD_BASE_STYLE,
|
||||
border: `1.5px solid ${presetKeyFromForm === pt.key ? overlayTheme.selectedText : 'transparent'}`,
|
||||
background: presetKeyFromForm === pt.key ? overlayTheme.selectedBg : (darkMode ? 'rgba(255,255,255,0.02)' : 'rgba(255,255,255,0.72)'),
|
||||
boxShadow: presetKeyFromForm === pt.key ? 'none' : (darkMode ? 'inset 0 0 0 1px rgba(255,255,255,0.028)' : 'inset 0 0 0 1px rgba(16,24,40,0.03)'),
|
||||
display: 'flex', alignItems: 'flex-start', gap: 10,
|
||||
}}>
|
||||
<div style={{
|
||||
color: presetKeyFromForm === pt.key ? overlayTheme.iconColor : overlayTheme.mutedText,
|
||||
fontSize: 18, marginTop: 2, transition: 'all 0.2s ease',
|
||||
fontSize: 18, marginTop: 2, transition: 'all 0.2s ease', flexShrink: 0,
|
||||
}}>
|
||||
{pt.icon}
|
||||
</div>
|
||||
<div>
|
||||
<div style={{ fontSize: 13, fontWeight: 700, color: overlayTheme.titleText, lineHeight: 1.3 }}>{pt.label}</div>
|
||||
<div style={{ fontSize: 12, color: overlayTheme.mutedText, marginTop: 4, lineHeight: 1.4 }}>{pt.desc}</div>
|
||||
<div style={PROVIDER_PRESET_CARD_CONTENT_STYLE}>
|
||||
<div style={{ ...PROVIDER_PRESET_CARD_TITLE_STYLE, fontSize: 13, fontWeight: 700, color: overlayTheme.titleText, lineHeight: 1.3 }}>{pt.label}</div>
|
||||
<div style={{ ...PROVIDER_PRESET_CARD_DESCRIPTION_STYLE, fontSize: 12, color: overlayTheme.mutedText, lineHeight: 1.4 }}>{pt.desc}</div>
|
||||
</div>
|
||||
</div>
|
||||
))}
|
||||
|
||||
61
frontend/src/components/ai/AIChatInput.notice.test.tsx
Normal file
61
frontend/src/components/ai/AIChatInput.notice.test.tsx
Normal file
@@ -0,0 +1,61 @@
|
||||
import React from 'react';
|
||||
import { renderToStaticMarkup } from 'react-dom/server';
|
||||
import { describe, expect, it, vi } from 'vitest';
|
||||
|
||||
import { AIChatInput } from './AIChatInput';
|
||||
import { buildOverlayWorkbenchTheme } from '../../utils/overlayWorkbenchTheme';
|
||||
|
||||
vi.mock('../../store', () => ({
|
||||
useStore: (selector: (state: any) => any) => selector({
|
||||
aiContexts: {},
|
||||
addAIContext: vi.fn(),
|
||||
removeAIContext: vi.fn(),
|
||||
}),
|
||||
}));
|
||||
|
||||
vi.mock('../../../wailsjs/go/app/App', () => ({
|
||||
DBGetTables: vi.fn(),
|
||||
DBShowCreateTable: vi.fn(),
|
||||
DBGetDatabases: vi.fn(),
|
||||
}));
|
||||
|
||||
describe('AIChatInput notice layout', () => {
|
||||
it('renders the composer notice above the input editor', () => {
|
||||
const markup = renderToStaticMarkup(
|
||||
<AIChatInput
|
||||
input=""
|
||||
setInput={() => {}}
|
||||
draftImages={[]}
|
||||
setDraftImages={() => {}}
|
||||
sending={false}
|
||||
onSend={() => {}}
|
||||
onStop={() => {}}
|
||||
handleKeyDown={() => {}}
|
||||
activeConnName=""
|
||||
activeContext={null}
|
||||
activeProvider={{ model: '', models: [] }}
|
||||
dynamicModels={[]}
|
||||
loadingModels={false}
|
||||
composerNotice={{
|
||||
tone: 'error',
|
||||
title: '模型列表加载失败',
|
||||
description: '请检查供应商入口和 API Key。',
|
||||
}}
|
||||
onModelChange={() => {}}
|
||||
onFetchModels={() => {}}
|
||||
textareaRef={React.createRef<HTMLTextAreaElement>()}
|
||||
darkMode={false}
|
||||
textColor="#162033"
|
||||
mutedColor="rgba(16,24,40,0.55)"
|
||||
overlayTheme={buildOverlayWorkbenchTheme(false)}
|
||||
/>
|
||||
);
|
||||
|
||||
const noticeIndex = markup.indexOf('data-ai-chat-composer-notice="true"');
|
||||
const inputIndex = markup.indexOf('data-ai-chat-composer-input="true"');
|
||||
|
||||
expect(noticeIndex).toBeGreaterThanOrEqual(0);
|
||||
expect(inputIndex).toBeGreaterThanOrEqual(0);
|
||||
expect(noticeIndex).toBeLessThan(inputIndex);
|
||||
});
|
||||
});
|
||||
@@ -1,9 +1,10 @@
|
||||
import React from 'react';
|
||||
import { Input, Select, AutoComplete, Tooltip, Modal, Checkbox, Spin, message, Button, Tag } from 'antd';
|
||||
import { DatabaseOutlined, SendOutlined, TableOutlined, SearchOutlined, PictureOutlined } from '@ant-design/icons';
|
||||
import { DatabaseOutlined, SendOutlined, TableOutlined, SearchOutlined, PictureOutlined, ExclamationCircleFilled } from '@ant-design/icons';
|
||||
import { useStore } from '../../store';
|
||||
import { DBGetTables, DBShowCreateTable, DBGetDatabases } from '../../../wailsjs/go/app/App';
|
||||
import type { OverlayWorkbenchTheme } from '../../utils/overlayWorkbenchTheme';
|
||||
import type { AIComposerNotice } from '../../utils/aiComposerNotice';
|
||||
|
||||
interface AIChatInputProps {
|
||||
input: string;
|
||||
@@ -19,6 +20,7 @@ interface AIChatInputProps {
|
||||
activeProvider: any;
|
||||
dynamicModels: string[];
|
||||
loadingModels: boolean;
|
||||
composerNotice?: AIComposerNotice | null;
|
||||
onModelChange: (val: string) => void;
|
||||
onFetchModels: () => void;
|
||||
textareaRef: React.RefObject<HTMLTextAreaElement>;
|
||||
@@ -33,6 +35,7 @@ interface AIChatInputProps {
|
||||
export const AIChatInput: React.FC<AIChatInputProps> = ({
|
||||
input, setInput, draftImages, setDraftImages, sending, onSend, onStop, handleKeyDown,
|
||||
activeConnName, activeContext, activeProvider, dynamicModels, loadingModels,
|
||||
composerNotice,
|
||||
onModelChange, onFetchModels, textareaRef, darkMode, textColor, mutedColor, overlayTheme,
|
||||
contextUsageChars, maxContextChars
|
||||
}) => {
|
||||
@@ -67,6 +70,33 @@ export const AIChatInput: React.FC<AIChatInputProps> = ({
|
||||
|
||||
const filteredTables = contextTables.filter(t => t.name.toLowerCase().includes(searchText.toLowerCase()));
|
||||
const [contextExpanded, setContextExpanded] = React.useState(false);
|
||||
const composerNoticePalette = React.useMemo(() => {
|
||||
if (composerNotice?.tone === 'error') {
|
||||
return darkMode
|
||||
? {
|
||||
background: 'rgba(255,120,117,0.12)',
|
||||
borderColor: 'rgba(255,120,117,0.24)',
|
||||
iconColor: '#ff7875',
|
||||
}
|
||||
: {
|
||||
background: 'rgba(255,77,79,0.08)',
|
||||
borderColor: 'rgba(255,77,79,0.16)',
|
||||
iconColor: '#ff4d4f',
|
||||
};
|
||||
}
|
||||
|
||||
return darkMode
|
||||
? {
|
||||
background: 'rgba(250,173,20,0.12)',
|
||||
borderColor: 'rgba(250,173,20,0.22)',
|
||||
iconColor: '#ffd666',
|
||||
}
|
||||
: {
|
||||
background: 'rgba(250,173,20,0.08)',
|
||||
borderColor: 'rgba(250,173,20,0.18)',
|
||||
iconColor: '#d48806',
|
||||
};
|
||||
}, [composerNotice, darkMode]);
|
||||
|
||||
// Slash commands
|
||||
const [showSlashMenu, setShowSlashMenu] = React.useState(false);
|
||||
@@ -258,7 +288,31 @@ export const AIChatInput: React.FC<AIChatInputProps> = ({
|
||||
</div>
|
||||
))}
|
||||
</div>
|
||||
<div style={{ position: 'relative' }}>
|
||||
{composerNotice && (
|
||||
<div
|
||||
data-ai-chat-composer-notice="true"
|
||||
style={{
|
||||
display: 'flex',
|
||||
alignItems: 'flex-start',
|
||||
gap: 8,
|
||||
padding: '8px 10px',
|
||||
borderRadius: 12,
|
||||
background: composerNoticePalette.background,
|
||||
border: `1px solid ${composerNoticePalette.borderColor}`,
|
||||
}}
|
||||
>
|
||||
<ExclamationCircleFilled style={{ color: composerNoticePalette.iconColor, fontSize: 14, marginTop: 1, flexShrink: 0 }} />
|
||||
<div style={{ minWidth: 0 }}>
|
||||
<div style={{ fontSize: 12, fontWeight: 600, color: textColor, lineHeight: 1.4 }}>
|
||||
{composerNotice.title}
|
||||
</div>
|
||||
<div style={{ fontSize: 11, color: mutedColor, lineHeight: 1.5, marginTop: 2, wordBreak: 'break-word' }}>
|
||||
{composerNotice.description}
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
)}
|
||||
<div data-ai-chat-composer-input="true" style={{ position: 'relative' }}>
|
||||
{showSlashMenu && filteredSlashCmds.length > 0 && (
|
||||
<div style={{
|
||||
position: 'absolute', bottom: '100%', left: 0, right: 0, marginBottom: 4,
|
||||
@@ -354,9 +408,13 @@ export const AIChatInput: React.FC<AIChatInputProps> = ({
|
||||
<Select
|
||||
size="small"
|
||||
variant="filled"
|
||||
value={activeProvider.model || (dynamicModels.length > 0 ? dynamicModels[0] : activeProvider.models?.[0])}
|
||||
value={activeProvider.model || undefined}
|
||||
onChange={onModelChange}
|
||||
onDropdownVisibleChange={(open) => { if (open && dynamicModels.length === 0) onFetchModels(); }}
|
||||
onDropdownVisibleChange={(open) => {
|
||||
if (open && dynamicModels.length === 0 && (activeProvider.models || []).length === 0) {
|
||||
onFetchModels();
|
||||
}
|
||||
}}
|
||||
loading={loadingModels}
|
||||
options={(dynamicModels.length > 0 ? dynamicModels : (activeProvider.models || [])).map((m: string) => ({ label: m, value: m }))}
|
||||
style={{ width: 130, fontSize: 11, background: 'transparent' }}
|
||||
|
||||
@@ -204,7 +204,7 @@ export interface AIProviderConfig {
|
||||
baseUrl: string;
|
||||
model: string;
|
||||
models?: string[];
|
||||
apiFormat?: string; // custom 专用: openai | anthropic | gemini
|
||||
apiFormat?: string; // custom 专用: openai | anthropic | gemini | claude-cli
|
||||
headers?: Record<string, string>;
|
||||
maxTokens: number;
|
||||
temperature: number;
|
||||
@@ -243,4 +243,3 @@ export interface AISafetyResult {
|
||||
requiresConfirm: boolean;
|
||||
warningMessage?: string;
|
||||
}
|
||||
|
||||
|
||||
33
frontend/src/utils/aiComposerNotice.test.ts
Normal file
33
frontend/src/utils/aiComposerNotice.test.ts
Normal file
@@ -0,0 +1,33 @@
|
||||
import { describe, expect, it } from 'vitest';
|
||||
|
||||
import {
|
||||
buildModelFetchFailedNotice,
|
||||
buildMissingModelNotice,
|
||||
buildMissingProviderNotice,
|
||||
} from './aiComposerNotice';
|
||||
|
||||
describe('ai composer notice helpers', () => {
|
||||
it('builds a compact notice for missing provider', () => {
|
||||
expect(buildMissingProviderNotice()).toEqual({
|
||||
tone: 'warning',
|
||||
title: '还没有可用供应商',
|
||||
description: '先在 AI 设置里添加并启用一个模型供应商。',
|
||||
});
|
||||
});
|
||||
|
||||
it('builds a compact notice for missing model selection', () => {
|
||||
expect(buildMissingModelNotice()).toEqual({
|
||||
tone: 'warning',
|
||||
title: '先选择一个模型',
|
||||
description: '打开下方模型下拉并选择模型;如果列表为空,请检查供应商入口和 API Key。',
|
||||
});
|
||||
});
|
||||
|
||||
it('builds a readable inline notice for model fetch failures', () => {
|
||||
expect(buildModelFetchFailedNotice('当前接口未返回可用模型')).toEqual({
|
||||
tone: 'error',
|
||||
title: '模型列表加载失败',
|
||||
description: '当前接口未返回可用模型',
|
||||
});
|
||||
});
|
||||
});
|
||||
27
frontend/src/utils/aiComposerNotice.ts
Normal file
27
frontend/src/utils/aiComposerNotice.ts
Normal file
@@ -0,0 +1,27 @@
|
||||
export type AIComposerNoticeTone = 'warning' | 'error';
|
||||
|
||||
export interface AIComposerNotice {
|
||||
tone: AIComposerNoticeTone;
|
||||
title: string;
|
||||
description: string;
|
||||
}
|
||||
|
||||
const defaultModelFetchFailedDescription = '请检查供应商入口、API Key 或账号权限,然后重新打开模型下拉。';
|
||||
|
||||
export const buildMissingProviderNotice = (): AIComposerNotice => ({
|
||||
tone: 'warning',
|
||||
title: '还没有可用供应商',
|
||||
description: '先在 AI 设置里添加并启用一个模型供应商。',
|
||||
});
|
||||
|
||||
export const buildMissingModelNotice = (): AIComposerNotice => ({
|
||||
tone: 'warning',
|
||||
title: '先选择一个模型',
|
||||
description: '打开下方模型下拉并选择模型;如果列表为空,请检查供应商入口和 API Key。',
|
||||
});
|
||||
|
||||
export const buildModelFetchFailedNotice = (error?: string): AIComposerNotice => ({
|
||||
tone: 'error',
|
||||
title: '模型列表加载失败',
|
||||
description: String(error || '').trim() || defaultModelFetchFailedDescription,
|
||||
});
|
||||
111
frontend/src/utils/aiProviderPresets.test.ts
Normal file
111
frontend/src/utils/aiProviderPresets.test.ts
Normal file
@@ -0,0 +1,111 @@
|
||||
import { describe, expect, it } from 'vitest';
|
||||
|
||||
import {
|
||||
matchQwenPresetKey,
|
||||
QWEN_BAILIAN_MODELS_BASE_URL,
|
||||
QWEN_CODING_PLAN_ANTHROPIC_BASE_URL,
|
||||
QWEN_CODING_PLAN_MODELS,
|
||||
resolvePresetBaseURL,
|
||||
resolvePresetModelSelection,
|
||||
resolvePresetTransport,
|
||||
} from './aiProviderPresets';
|
||||
|
||||
describe('ai provider preset helpers', () => {
|
||||
it('maps legacy Bailian compatible-mode URL back to the Bailian preset', () => {
|
||||
expect(matchQwenPresetKey({
|
||||
type: 'openai',
|
||||
baseUrl: QWEN_BAILIAN_MODELS_BASE_URL,
|
||||
})).toBe('qwen-bailian');
|
||||
});
|
||||
|
||||
it('maps Coding Plan anthropic URL to the dedicated Coding Plan preset', () => {
|
||||
expect(matchQwenPresetKey({
|
||||
type: 'anthropic',
|
||||
baseUrl: QWEN_CODING_PLAN_ANTHROPIC_BASE_URL,
|
||||
})).toBe('qwen-coding-plan');
|
||||
});
|
||||
|
||||
it('maps Coding Plan Claude CLI config back to the dedicated Coding Plan preset', () => {
|
||||
expect(matchQwenPresetKey({
|
||||
type: 'custom',
|
||||
apiFormat: 'claude-cli',
|
||||
baseUrl: QWEN_CODING_PLAN_ANTHROPIC_BASE_URL,
|
||||
})).toBe('qwen-coding-plan');
|
||||
});
|
||||
|
||||
it('does not keep a baked-in model list for the Coding Plan preset', () => {
|
||||
expect(QWEN_CODING_PLAN_MODELS).toEqual([
|
||||
'qwen3.5-plus',
|
||||
'kimi-k2.5',
|
||||
'glm-5',
|
||||
'MiniMax-M2.5',
|
||||
'qwen3-max-2026-01-23',
|
||||
'qwen3-coder-next',
|
||||
'qwen3-coder-plus',
|
||||
'glm-4.7',
|
||||
]);
|
||||
});
|
||||
|
||||
it('keeps built-in preset model empty when the preset intentionally requires an explicit selection', () => {
|
||||
expect(resolvePresetModelSelection({
|
||||
presetKey: 'qwen-coding-plan',
|
||||
presetDefaultModel: '',
|
||||
presetModels: QWEN_CODING_PLAN_MODELS,
|
||||
valuesModel: '',
|
||||
customModels: [],
|
||||
})).toEqual({
|
||||
model: '',
|
||||
models: QWEN_CODING_PLAN_MODELS,
|
||||
});
|
||||
});
|
||||
|
||||
it('still falls back to the first configured model for custom-like presets', () => {
|
||||
expect(resolvePresetModelSelection({
|
||||
presetKey: 'custom',
|
||||
presetDefaultModel: '',
|
||||
presetModels: [],
|
||||
valuesModel: '',
|
||||
customModels: ['foo-model', 'bar-model'],
|
||||
})).toEqual({
|
||||
model: 'foo-model',
|
||||
models: ['foo-model', 'bar-model'],
|
||||
});
|
||||
});
|
||||
|
||||
it('forces built-in presets back to their standard base URL when saving or testing', () => {
|
||||
expect(resolvePresetBaseURL({
|
||||
presetKey: 'qwen-bailian',
|
||||
presetDefaultBaseUrl: 'https://dashscope.aliyuncs.com/apps/anthropic',
|
||||
valuesBaseUrl: 'https://dashscope.aliyuncs.com/compatible-mode/v1',
|
||||
})).toBe('https://dashscope.aliyuncs.com/apps/anthropic');
|
||||
});
|
||||
|
||||
it('keeps the user-entered base URL for custom-like presets', () => {
|
||||
expect(resolvePresetBaseURL({
|
||||
presetKey: 'custom',
|
||||
presetDefaultBaseUrl: '',
|
||||
valuesBaseUrl: 'https://example-proxy.internal/v1',
|
||||
})).toBe('https://example-proxy.internal/v1');
|
||||
});
|
||||
|
||||
it('forces qwen coding plan to save as custom plus claude-cli', () => {
|
||||
expect(resolvePresetTransport({
|
||||
presetBackendType: 'custom',
|
||||
presetFixedApiFormat: 'claude-cli',
|
||||
valuesApiFormat: 'anthropic',
|
||||
})).toEqual({
|
||||
type: 'custom',
|
||||
apiFormat: 'claude-cli',
|
||||
});
|
||||
});
|
||||
|
||||
it('keeps custom preset transport editable', () => {
|
||||
expect(resolvePresetTransport({
|
||||
presetBackendType: 'custom',
|
||||
valuesApiFormat: 'gemini',
|
||||
})).toEqual({
|
||||
type: 'custom',
|
||||
apiFormat: 'gemini',
|
||||
});
|
||||
});
|
||||
});
|
||||
143
frontend/src/utils/aiProviderPresets.ts
Normal file
143
frontend/src/utils/aiProviderPresets.ts
Normal file
@@ -0,0 +1,143 @@
|
||||
import type { AIProviderConfig, AIProviderType } from '../types';
|
||||
|
||||
export const LEGACY_QWEN_BAILIAN_OPENAI_BASE_URL = 'https://dashscope.aliyuncs.com/compatible-mode/v1';
|
||||
export const LEGACY_QWEN_CODING_PLAN_OPENAI_BASE_URL = 'https://coding.dashscope.aliyuncs.com/v1';
|
||||
export const QWEN_BAILIAN_ANTHROPIC_BASE_URL = 'https://dashscope.aliyuncs.com/apps/anthropic';
|
||||
export const QWEN_CODING_PLAN_ANTHROPIC_BASE_URL = 'https://coding.dashscope.aliyuncs.com/apps/anthropic';
|
||||
export const QWEN_BAILIAN_MODELS_BASE_URL = LEGACY_QWEN_BAILIAN_OPENAI_BASE_URL;
|
||||
|
||||
export const QWEN_CODING_PLAN_MODELS = [
|
||||
'qwen3.5-plus',
|
||||
'kimi-k2.5',
|
||||
'glm-5',
|
||||
'MiniMax-M2.5',
|
||||
'qwen3-max-2026-01-23',
|
||||
'qwen3-coder-next',
|
||||
'qwen3-coder-plus',
|
||||
'glm-4.7',
|
||||
];
|
||||
|
||||
const CUSTOM_LIKE_PRESET_KEYS = new Set(['custom', 'ollama']);
|
||||
|
||||
export interface ResolvePresetModelSelectionInput {
|
||||
presetKey: string;
|
||||
presetDefaultModel: string;
|
||||
presetModels: string[];
|
||||
valuesModel?: string;
|
||||
customModels?: string[];
|
||||
}
|
||||
|
||||
export interface ResolvePresetModelSelectionResult {
|
||||
model: string;
|
||||
models: string[];
|
||||
}
|
||||
|
||||
export interface ResolvePresetBaseURLInput {
|
||||
presetKey: string;
|
||||
presetDefaultBaseUrl: string;
|
||||
valuesBaseUrl?: string;
|
||||
}
|
||||
|
||||
export interface ResolvePresetTransportInput {
|
||||
presetBackendType: AIProviderType;
|
||||
presetFixedApiFormat?: string;
|
||||
valuesApiFormat?: string;
|
||||
}
|
||||
|
||||
export interface ResolvePresetTransportResult {
|
||||
type: AIProviderType;
|
||||
apiFormat?: string;
|
||||
}
|
||||
|
||||
export const getProviderHostname = (raw?: string): string => {
|
||||
if (!raw) return '';
|
||||
try {
|
||||
return new URL(raw).hostname.toLowerCase();
|
||||
} catch {
|
||||
return '';
|
||||
}
|
||||
};
|
||||
|
||||
export const getProviderFingerprint = (raw?: string): string => {
|
||||
if (!raw) return '';
|
||||
try {
|
||||
const url = new URL(raw);
|
||||
const normalizedPath = url.pathname.replace(/\/+$/, '').toLowerCase();
|
||||
return `${url.hostname.toLowerCase()}${normalizedPath}`;
|
||||
} catch {
|
||||
return '';
|
||||
}
|
||||
};
|
||||
|
||||
export const matchQwenPresetKey = (provider: Pick<AIProviderConfig, 'type' | 'baseUrl' | 'apiFormat'>): string | null => {
|
||||
const fingerprint = getProviderFingerprint(provider.baseUrl);
|
||||
const bailianFingerprints = new Set([
|
||||
getProviderFingerprint(LEGACY_QWEN_BAILIAN_OPENAI_BASE_URL),
|
||||
getProviderFingerprint(QWEN_BAILIAN_ANTHROPIC_BASE_URL),
|
||||
]);
|
||||
if (fingerprint !== '' && bailianFingerprints.has(fingerprint)) {
|
||||
return 'qwen-bailian';
|
||||
}
|
||||
|
||||
const codingPlanFingerprints = new Set([
|
||||
getProviderFingerprint(LEGACY_QWEN_CODING_PLAN_OPENAI_BASE_URL),
|
||||
getProviderFingerprint(QWEN_CODING_PLAN_ANTHROPIC_BASE_URL),
|
||||
]);
|
||||
if (fingerprint !== '' && codingPlanFingerprints.has(fingerprint)) {
|
||||
return 'qwen-coding-plan';
|
||||
}
|
||||
|
||||
return null;
|
||||
};
|
||||
|
||||
export const resolvePresetModelSelection = ({
|
||||
presetKey,
|
||||
presetDefaultModel,
|
||||
presetModels,
|
||||
valuesModel,
|
||||
customModels,
|
||||
}: ResolvePresetModelSelectionInput): ResolvePresetModelSelectionResult => {
|
||||
const isCustomLike = CUSTOM_LIKE_PRESET_KEYS.has(presetKey);
|
||||
const resolvedModels = isCustomLike ? (customModels || []) : presetModels;
|
||||
const fallbackModel = resolvedModels.length > 0 ? resolvedModels[0] : '';
|
||||
return {
|
||||
models: resolvedModels,
|
||||
model: isCustomLike ? (valuesModel || fallbackModel) : (valuesModel || presetDefaultModel),
|
||||
};
|
||||
};
|
||||
|
||||
export const resolvePresetBaseURL = ({
|
||||
presetKey,
|
||||
presetDefaultBaseUrl,
|
||||
valuesBaseUrl,
|
||||
}: ResolvePresetBaseURLInput): string => {
|
||||
if (CUSTOM_LIKE_PRESET_KEYS.has(presetKey)) {
|
||||
return valuesBaseUrl || presetDefaultBaseUrl;
|
||||
}
|
||||
return presetDefaultBaseUrl;
|
||||
};
|
||||
|
||||
export const resolvePresetTransport = ({
|
||||
presetBackendType,
|
||||
presetFixedApiFormat,
|
||||
valuesApiFormat,
|
||||
}: ResolvePresetTransportInput): ResolvePresetTransportResult => {
|
||||
if (presetFixedApiFormat) {
|
||||
return {
|
||||
type: presetBackendType,
|
||||
apiFormat: presetFixedApiFormat,
|
||||
};
|
||||
}
|
||||
|
||||
if (presetBackendType === 'custom') {
|
||||
return {
|
||||
type: presetBackendType,
|
||||
apiFormat: valuesApiFormat || 'openai',
|
||||
};
|
||||
}
|
||||
|
||||
return {
|
||||
type: presetBackendType,
|
||||
apiFormat: undefined,
|
||||
};
|
||||
};
|
||||
56
frontend/src/utils/aiSettingsPresetLayout.test.ts
Normal file
56
frontend/src/utils/aiSettingsPresetLayout.test.ts
Normal file
@@ -0,0 +1,56 @@
|
||||
import { describe, expect, it } from 'vitest';
|
||||
|
||||
import {
|
||||
PROVIDER_PRESET_CARD_BASE_STYLE,
|
||||
PROVIDER_PRESET_CARD_CONTENT_STYLE,
|
||||
PROVIDER_PRESET_CARD_DESCRIPTION_STYLE,
|
||||
PROVIDER_PRESET_GRID_STYLE,
|
||||
PROVIDER_PRESET_CARD_TITLE_STYLE,
|
||||
} from './aiSettingsPresetLayout';
|
||||
|
||||
describe('ai settings preset layout', () => {
|
||||
it('uses a fixed grid auto row height so provider bubbles stay visually consistent across rows', () => {
|
||||
expect(PROVIDER_PRESET_GRID_STYLE).toMatchObject({
|
||||
display: 'grid',
|
||||
gridTemplateColumns: 'repeat(3, minmax(0, 1fr))',
|
||||
gap: 6,
|
||||
gridAutoRows: '96px',
|
||||
alignItems: 'stretch',
|
||||
});
|
||||
});
|
||||
|
||||
it('stretches each provider card to fill the row height', () => {
|
||||
expect(PROVIDER_PRESET_CARD_BASE_STYLE).toMatchObject({
|
||||
display: 'flex',
|
||||
alignItems: 'flex-start',
|
||||
gap: 10,
|
||||
height: '100%',
|
||||
minHeight: '96px',
|
||||
overflow: 'hidden',
|
||||
});
|
||||
});
|
||||
|
||||
it('keeps the text column compact instead of pinning the description to the bottom', () => {
|
||||
expect(PROVIDER_PRESET_CARD_CONTENT_STYLE).toMatchObject({
|
||||
minWidth: 0,
|
||||
flex: 1,
|
||||
display: 'flex',
|
||||
flexDirection: 'column',
|
||||
});
|
||||
|
||||
expect(PROVIDER_PRESET_CARD_DESCRIPTION_STYLE).toMatchObject({
|
||||
marginTop: 4,
|
||||
display: '-webkit-box',
|
||||
WebkitLineClamp: 2,
|
||||
WebkitBoxOrient: 'vertical',
|
||||
overflow: 'hidden',
|
||||
});
|
||||
|
||||
expect(PROVIDER_PRESET_CARD_TITLE_STYLE).toMatchObject({
|
||||
display: '-webkit-box',
|
||||
WebkitLineClamp: 2,
|
||||
WebkitBoxOrient: 'vertical',
|
||||
overflow: 'hidden',
|
||||
});
|
||||
});
|
||||
});
|
||||
47
frontend/src/utils/aiSettingsPresetLayout.ts
Normal file
47
frontend/src/utils/aiSettingsPresetLayout.ts
Normal file
@@ -0,0 +1,47 @@
|
||||
import type { CSSProperties } from 'react';
|
||||
|
||||
export const PROVIDER_PRESET_CARD_HEIGHT = 96;
|
||||
|
||||
export const PROVIDER_PRESET_GRID_STYLE: CSSProperties = {
|
||||
display: 'grid',
|
||||
gridTemplateColumns: 'repeat(3, minmax(0, 1fr))',
|
||||
gap: 6,
|
||||
gridAutoRows: `${PROVIDER_PRESET_CARD_HEIGHT}px`,
|
||||
alignItems: 'stretch',
|
||||
};
|
||||
|
||||
export const PROVIDER_PRESET_CARD_BASE_STYLE: CSSProperties = {
|
||||
padding: '12px 14px',
|
||||
borderRadius: 12,
|
||||
cursor: 'pointer',
|
||||
transition: 'all 0.2s ease',
|
||||
display: 'flex',
|
||||
alignItems: 'flex-start',
|
||||
gap: 10,
|
||||
height: '100%',
|
||||
minHeight: `${PROVIDER_PRESET_CARD_HEIGHT}px`,
|
||||
boxSizing: 'border-box',
|
||||
overflow: 'hidden',
|
||||
};
|
||||
|
||||
export const PROVIDER_PRESET_CARD_CONTENT_STYLE: CSSProperties = {
|
||||
minWidth: 0,
|
||||
flex: 1,
|
||||
display: 'flex',
|
||||
flexDirection: 'column',
|
||||
};
|
||||
|
||||
export const PROVIDER_PRESET_CARD_DESCRIPTION_STYLE: CSSProperties = {
|
||||
marginTop: 4,
|
||||
display: '-webkit-box',
|
||||
WebkitLineClamp: 2,
|
||||
WebkitBoxOrient: 'vertical',
|
||||
overflow: 'hidden',
|
||||
};
|
||||
|
||||
export const PROVIDER_PRESET_CARD_TITLE_STYLE: CSSProperties = {
|
||||
display: '-webkit-box',
|
||||
WebkitLineClamp: 2,
|
||||
WebkitBoxOrient: 'vertical',
|
||||
overflow: 'hidden',
|
||||
};
|
||||
3
frontend/wailsjs/go/aiservice/Service.d.ts
vendored
3
frontend/wailsjs/go/aiservice/Service.d.ts
vendored
@@ -1,7 +1,6 @@
|
||||
// Cynhyrchwyd y ffeil hon yn awtomatig. PEIDIWCH Â MODIWL
|
||||
// This file is automatically generated. DO NOT EDIT
|
||||
import {ai} from '../models';
|
||||
import {context} from '../models';
|
||||
|
||||
export function AIChatCancel(arg1:string):Promise<void>;
|
||||
|
||||
@@ -42,5 +41,3 @@ export function AISetContextLevel(arg1:string):Promise<void>;
|
||||
export function AISetSafetyLevel(arg1:string):Promise<void>;
|
||||
|
||||
export function AITestProvider(arg1:ai.ProviderConfig):Promise<Record<string, any>>;
|
||||
|
||||
export function Startup(arg1:context.Context):Promise<void>;
|
||||
|
||||
@@ -81,7 +81,3 @@ export function AISetSafetyLevel(arg1) {
|
||||
export function AITestProvider(arg1) {
|
||||
return window['go']['aiservice']['Service']['AITestProvider'](arg1);
|
||||
}
|
||||
|
||||
export function Startup(arg1) {
|
||||
return window['go']['aiservice']['Service']['Startup'](arg1);
|
||||
}
|
||||
|
||||
6
frontend/wailsjs/go/app/App.d.ts
vendored
6
frontend/wailsjs/go/app/App.d.ts
vendored
@@ -1,10 +1,8 @@
|
||||
// Cynhyrchwyd y ffeil hon yn awtomatig. PEIDIWCH Â MODIWL
|
||||
// This file is automatically generated. DO NOT EDIT
|
||||
import {connection} from '../models';
|
||||
import {time} from '../models';
|
||||
import {sync} from '../models';
|
||||
import {redis} from '../models';
|
||||
import {context} from '../models';
|
||||
|
||||
export function ApplyChanges(arg1:connection.ConnectionConfig,arg2:string,arg3:string,arg4:connection.ChangeSet):Promise<connection.QueryResult>;
|
||||
|
||||
@@ -16,8 +14,6 @@ export function CheckDriverNetworkStatus():Promise<connection.QueryResult>;
|
||||
|
||||
export function CheckForUpdates():Promise<connection.QueryResult>;
|
||||
|
||||
export function CleanupStaleQueries(arg1:time.Duration):Promise<void>;
|
||||
|
||||
export function ConfigureDriverRuntimeDirectory(arg1:string):Promise<connection.QueryResult>;
|
||||
|
||||
export function ConfigureGlobalProxy(arg1:boolean,arg2:connection.ProxyConfig):Promise<connection.QueryResult>;
|
||||
@@ -198,8 +194,6 @@ export function SetMacNativeWindowControls(arg1:boolean):Promise<void>;
|
||||
|
||||
export function SetWindowTranslucency(arg1:number,arg2:number):Promise<void>;
|
||||
|
||||
export function Startup(arg1:context.Context):Promise<void>;
|
||||
|
||||
export function TestConnection(arg1:connection.ConnectionConfig):Promise<connection.QueryResult>;
|
||||
|
||||
export function TruncateTables(arg1:connection.ConnectionConfig,arg2:string,arg3:Array<string>):Promise<connection.QueryResult>;
|
||||
|
||||
@@ -22,10 +22,6 @@ export function CheckForUpdates() {
|
||||
return window['go']['app']['App']['CheckForUpdates']();
|
||||
}
|
||||
|
||||
export function CleanupStaleQueries(arg1) {
|
||||
return window['go']['app']['App']['CleanupStaleQueries'](arg1);
|
||||
}
|
||||
|
||||
export function ConfigureDriverRuntimeDirectory(arg1) {
|
||||
return window['go']['app']['App']['ConfigureDriverRuntimeDirectory'](arg1);
|
||||
}
|
||||
@@ -386,10 +382,6 @@ export function SetWindowTranslucency(arg1, arg2) {
|
||||
return window['go']['app']['App']['SetWindowTranslucency'](arg1, arg2);
|
||||
}
|
||||
|
||||
export function Startup(arg1) {
|
||||
return window['go']['app']['App']['Startup'](arg1);
|
||||
}
|
||||
|
||||
export function TestConnection(arg1) {
|
||||
return window['go']['app']['App']['TestConnection'](arg1);
|
||||
}
|
||||
|
||||
@@ -8,6 +8,7 @@ import (
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"strings"
|
||||
|
||||
"GoNavi-Wails/internal/ai"
|
||||
@@ -32,6 +33,25 @@ func normalizeAnthropicMessagesURL(baseURL string) string {
|
||||
return url + "/v1/messages"
|
||||
}
|
||||
|
||||
func IsDashScopeAnthropicCompatibleBaseURL(baseURL string) bool {
|
||||
parsed, err := url.Parse(strings.TrimSpace(baseURL))
|
||||
if err != nil {
|
||||
return false
|
||||
}
|
||||
host := strings.ToLower(parsed.Hostname())
|
||||
return host == "dashscope.aliyuncs.com" || host == "coding.dashscope.aliyuncs.com"
|
||||
}
|
||||
|
||||
func ApplyAnthropicAuthHeaders(headers http.Header, baseURL string, apiKey string) {
|
||||
headers.Set("x-api-key", apiKey)
|
||||
if IsDashScopeAnthropicCompatibleBaseURL(baseURL) {
|
||||
headers.Set("Authorization", "Bearer "+apiKey)
|
||||
headers.Del("anthropic-version")
|
||||
return
|
||||
}
|
||||
headers.Set("anthropic-version", anthropicAPIVersion)
|
||||
}
|
||||
|
||||
// AnthropicProvider 实现 Anthropic Claude API 的 Provider
|
||||
type AnthropicProvider struct {
|
||||
config ai.ProviderConfig
|
||||
@@ -446,8 +466,7 @@ func (p *AnthropicProvider) doRequest(ctx context.Context, body interface{}) (io
|
||||
}
|
||||
|
||||
httpReq.Header.Set("Content-Type", "application/json")
|
||||
httpReq.Header.Set("x-api-key", p.config.APIKey)
|
||||
httpReq.Header.Set("anthropic-version", anthropicAPIVersion)
|
||||
ApplyAnthropicAuthHeaders(httpReq.Header, p.baseURL, p.config.APIKey)
|
||||
|
||||
if strings.Contains(string(jsonBody), `"stream":true`) || strings.Contains(string(jsonBody), `"stream": true`) {
|
||||
httpReq.Header.Set("Accept", "text/event-stream")
|
||||
|
||||
@@ -1,6 +1,9 @@
|
||||
package provider
|
||||
|
||||
import "testing"
|
||||
import (
|
||||
"net/http"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestNormalizeAnthropicMessagesURL_AppendsMessagesSuffix(t *testing.T) {
|
||||
url := normalizeAnthropicMessagesURL("https://api.anthropic.com")
|
||||
@@ -22,3 +25,33 @@ func TestNormalizeAnthropicMessagesURL_PreservesExplicitMessagesPath(t *testing.
|
||||
t.Fatalf("expected explicit messages path to be preserved, got %q", url)
|
||||
}
|
||||
}
|
||||
|
||||
func TestApplyAnthropicAuthHeaders_UsesOfficialAnthropicHeadersForAnthropicAPI(t *testing.T) {
|
||||
headers := http.Header{}
|
||||
ApplyAnthropicAuthHeaders(headers, "https://api.anthropic.com", "sk-test")
|
||||
|
||||
if got := headers.Get("x-api-key"); got != "sk-test" {
|
||||
t.Fatalf("expected x-api-key header, got %q", got)
|
||||
}
|
||||
if got := headers.Get("anthropic-version"); got != anthropicAPIVersion {
|
||||
t.Fatalf("expected anthropic-version header, got %q", got)
|
||||
}
|
||||
if got := headers.Get("Authorization"); got != "" {
|
||||
t.Fatalf("expected no authorization header for official anthropic, got %q", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestApplyAnthropicAuthHeaders_UsesBearerForDashScopeCompatibleAnthropic(t *testing.T) {
|
||||
headers := http.Header{}
|
||||
ApplyAnthropicAuthHeaders(headers, "https://coding.dashscope.aliyuncs.com/apps/anthropic", "sk-sp-test")
|
||||
|
||||
if got := headers.Get("Authorization"); got != "Bearer sk-sp-test" {
|
||||
t.Fatalf("expected bearer authorization header, got %q", got)
|
||||
}
|
||||
if got := headers.Get("x-api-key"); got != "sk-sp-test" {
|
||||
t.Fatalf("expected x-api-key header, got %q", got)
|
||||
}
|
||||
if got := headers.Get("anthropic-version"); got != "" {
|
||||
t.Fatalf("expected no anthropic-version header for DashScope, got %q", got)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -5,16 +5,20 @@ import (
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"os"
|
||||
"os/exec"
|
||||
"runtime"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
ai "GoNavi-Wails/internal/ai"
|
||||
)
|
||||
|
||||
var claudeLookPath = exec.LookPath
|
||||
var claudeCommandContext = exec.CommandContext
|
||||
var claudeCLIRequestTimeout = 90 * time.Second
|
||||
|
||||
// ClaudeCLIProvider 通过 Claude Code CLI 发送聊天请求
|
||||
// 适用于 anyrouter/newapi 等只支持 Claude Code 协议的代理服务
|
||||
@@ -48,19 +52,25 @@ func (p *ClaudeCLIProvider) Chat(ctx context.Context, req ai.ChatRequest) (*ai.C
|
||||
return nil, err
|
||||
}
|
||||
|
||||
ctx, cancel := ensureClaudeCLITimeout(ctx, claudeCLIRequestTimeout)
|
||||
defer cancel()
|
||||
|
||||
prompt := buildPrompt(req.Messages)
|
||||
args := []string{"-p", prompt, "--output-format", "json", "--no-session-persistence"}
|
||||
if p.config.Model != "" {
|
||||
args = append(args, "--model", p.config.Model)
|
||||
}
|
||||
|
||||
cmd := exec.CommandContext(ctx, "claude", args...)
|
||||
cmd := claudeCommandContext(ctx, "claude", args...)
|
||||
if err := p.setEnv(cmd); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
output, err := cmd.Output()
|
||||
if err != nil {
|
||||
if isClaudeCLITimeout(ctx, err) {
|
||||
return nil, fmt.Errorf("claude CLI 执行超时(%s),当前 Base URL 或 API Key 可能没有返回有效响应", claudeCLIRequestTimeout)
|
||||
}
|
||||
if exitErr, ok := err.(*exec.ExitError); ok {
|
||||
return nil, fmt.Errorf("claude CLI 执行失败: %s", string(exitErr.Stderr))
|
||||
}
|
||||
@@ -68,13 +78,14 @@ func (p *ClaudeCLIProvider) Chat(ctx context.Context, req ai.ChatRequest) (*ai.C
|
||||
}
|
||||
|
||||
// 解析 JSON 输出
|
||||
var result struct {
|
||||
Result string `json:"result"`
|
||||
}
|
||||
var result cliStreamEvent
|
||||
if err := json.Unmarshal(output, &result); err != nil {
|
||||
// 如果 JSON 解析失败,直接返回原始文本
|
||||
return &ai.ChatResponse{Content: strings.TrimSpace(string(output))}, nil
|
||||
}
|
||||
if errMsg, hasError := extractClaudeCLIEventError(result); hasError {
|
||||
return nil, fmt.Errorf("claude CLI 返回错误: %s", errMsg)
|
||||
}
|
||||
|
||||
return &ai.ChatResponse{Content: result.Result}, nil
|
||||
}
|
||||
@@ -85,6 +96,9 @@ func (p *ClaudeCLIProvider) ChatStream(ctx context.Context, req ai.ChatRequest,
|
||||
return err
|
||||
}
|
||||
|
||||
ctx, cancel := ensureClaudeCLITimeout(ctx, claudeCLIRequestTimeout)
|
||||
defer cancel()
|
||||
|
||||
prompt := buildPrompt(req.Messages)
|
||||
args := []string{"-p", prompt, "--output-format", "stream-json", "--verbose", "--include-partial-messages", "--no-session-persistence"}
|
||||
if p.config.Model != "" {
|
||||
@@ -93,7 +107,7 @@ func (p *ClaudeCLIProvider) ChatStream(ctx context.Context, req ai.ChatRequest,
|
||||
|
||||
fmt.Printf("[ClaudeCLI DEBUG] Running: claude %v\n", args)
|
||||
|
||||
cmd := exec.CommandContext(ctx, "claude", args...)
|
||||
cmd := claudeCommandContext(ctx, "claude", args...)
|
||||
if err := p.setEnv(cmd); err != nil {
|
||||
return err
|
||||
}
|
||||
@@ -137,7 +151,23 @@ func (p *ClaudeCLIProvider) ChatStream(ctx context.Context, req ai.ChatRequest,
|
||||
}
|
||||
|
||||
switch event.Type {
|
||||
case "system":
|
||||
if isClaudeCLISystemRetryEvent(event) {
|
||||
if errMsg, hasError := extractClaudeCLISystemRetryError(event); hasError {
|
||||
callback(ai.StreamChunk{Error: errMsg, Done: true})
|
||||
if cmd.Process != nil {
|
||||
_ = cmd.Process.Kill()
|
||||
}
|
||||
_ = cmd.Wait()
|
||||
return nil
|
||||
}
|
||||
}
|
||||
case "assistant":
|
||||
if errMsg, hasError := extractClaudeCLIEventError(event); hasError {
|
||||
callback(ai.StreamChunk{Error: errMsg, Done: true})
|
||||
_ = cmd.Wait()
|
||||
return nil
|
||||
}
|
||||
// 助手消息开始或文本内容
|
||||
if event.Message.Content != nil {
|
||||
for _, block := range event.Message.Content {
|
||||
@@ -156,12 +186,18 @@ func (p *ClaudeCLIProvider) ChatStream(ctx context.Context, req ai.ChatRequest,
|
||||
callback(ai.StreamChunk{Content: event.Delta.Text})
|
||||
}
|
||||
case "result":
|
||||
if errMsg, hasError := extractClaudeCLIEventError(event); hasError {
|
||||
callback(ai.StreamChunk{Error: errMsg, Done: true})
|
||||
_ = cmd.Wait()
|
||||
return nil
|
||||
}
|
||||
// 最终结果事件 — 不发送 content(assistant 事件已包含),只标记完成
|
||||
callback(ai.StreamChunk{Done: true})
|
||||
_ = cmd.Wait()
|
||||
return nil
|
||||
case "error":
|
||||
callback(ai.StreamChunk{Error: event.Error.Message, Done: true})
|
||||
errMsg, _ := extractClaudeCLIEventError(event)
|
||||
callback(ai.StreamChunk{Error: errMsg, Done: true})
|
||||
_ = cmd.Wait()
|
||||
return nil
|
||||
}
|
||||
@@ -171,6 +207,14 @@ func (p *ClaudeCLIProvider) ChatStream(ctx context.Context, req ai.ChatRequest,
|
||||
stderrStr := strings.TrimSpace(stderrBuf.String())
|
||||
fmt.Printf("[ClaudeCLI DEBUG] Process exited. stderr: %s\n", stderrStr)
|
||||
|
||||
if isClaudeCLITimeout(ctx, waitErr) {
|
||||
callback(ai.StreamChunk{
|
||||
Error: fmt.Sprintf("claude CLI 执行超时(%s),当前 Base URL 或 API Key 可能没有返回有效响应", claudeCLIRequestTimeout),
|
||||
Done: true,
|
||||
})
|
||||
return nil
|
||||
}
|
||||
|
||||
if waitErr != nil {
|
||||
errMsg := fmt.Sprintf("claude CLI 异常退出: %v", waitErr)
|
||||
if stderrStr != "" {
|
||||
@@ -184,6 +228,20 @@ func (p *ClaudeCLIProvider) ChatStream(ctx context.Context, req ai.ChatRequest,
|
||||
return nil
|
||||
}
|
||||
|
||||
func ensureClaudeCLITimeout(ctx context.Context, timeout time.Duration) (context.Context, context.CancelFunc) {
|
||||
if _, hasDeadline := ctx.Deadline(); hasDeadline || timeout <= 0 {
|
||||
return ctx, func() {}
|
||||
}
|
||||
return context.WithTimeout(ctx, timeout)
|
||||
}
|
||||
|
||||
func isClaudeCLITimeout(ctx context.Context, err error) bool {
|
||||
if err == nil {
|
||||
return false
|
||||
}
|
||||
return errors.Is(ctx.Err(), context.DeadlineExceeded) || errors.Is(err, context.DeadlineExceeded)
|
||||
}
|
||||
|
||||
// setEnv 设置 Claude CLI 的环境变量
|
||||
func (p *ClaudeCLIProvider) setEnv(cmd *exec.Cmd) error {
|
||||
env, err := buildClaudeCLIEnv(p.config, cmd.Environ(), runtime.GOOS, claudeLookPath, fileExists)
|
||||
@@ -200,6 +258,7 @@ func buildClaudeCLIEnv(config ai.ProviderConfig, baseEnv []string, goos string,
|
||||
env = upsertEnv(env, "ANTHROPIC_BASE_URL", strings.TrimRight(config.BaseURL, "/"))
|
||||
}
|
||||
if config.APIKey != "" {
|
||||
env = upsertEnv(env, "ANTHROPIC_AUTH_TOKEN", config.APIKey)
|
||||
env = upsertEnv(env, "ANTHROPIC_API_KEY", config.APIKey)
|
||||
}
|
||||
|
||||
@@ -354,8 +413,15 @@ func buildPrompt(messages []ai.Message) string {
|
||||
|
||||
// cliStreamEvent Claude CLI stream-json 输出的事件结构
|
||||
type cliStreamEvent struct {
|
||||
Type string `json:"type"`
|
||||
Message struct {
|
||||
Type string `json:"type"`
|
||||
Subtype string `json:"subtype,omitempty"`
|
||||
IsError bool `json:"is_error,omitempty"`
|
||||
Attempt int `json:"attempt,omitempty"`
|
||||
MaxRetries int `json:"max_retries,omitempty"`
|
||||
RetryDelayMS float64 `json:"retry_delay_ms,omitempty"`
|
||||
ErrorStatus int `json:"error_status,omitempty"`
|
||||
SessionID string `json:"session_id,omitempty"`
|
||||
Message struct {
|
||||
Content []struct {
|
||||
Type string `json:"type"`
|
||||
Text string `json:"text"`
|
||||
@@ -367,8 +433,79 @@ type cliStreamEvent struct {
|
||||
Text string `json:"text"`
|
||||
Thinking string `json:"thinking"`
|
||||
} `json:"delta,omitempty"`
|
||||
Result string `json:"result,omitempty"`
|
||||
Error struct {
|
||||
Message string `json:"message"`
|
||||
} `json:"error,omitempty"`
|
||||
Result string `json:"result,omitempty"`
|
||||
Error cliStreamEventError `json:"error,omitempty"`
|
||||
}
|
||||
|
||||
type cliStreamEventError struct {
|
||||
Message string
|
||||
}
|
||||
|
||||
func (e *cliStreamEventError) UnmarshalJSON(data []byte) error {
|
||||
trimmed := strings.TrimSpace(string(data))
|
||||
if trimmed == "" || trimmed == "null" {
|
||||
e.Message = ""
|
||||
return nil
|
||||
}
|
||||
|
||||
var text string
|
||||
if err := json.Unmarshal(data, &text); err == nil {
|
||||
e.Message = strings.TrimSpace(text)
|
||||
return nil
|
||||
}
|
||||
|
||||
var payload struct {
|
||||
Message string `json:"message"`
|
||||
}
|
||||
if err := json.Unmarshal(data, &payload); err != nil {
|
||||
return err
|
||||
}
|
||||
e.Message = strings.TrimSpace(payload.Message)
|
||||
return nil
|
||||
}
|
||||
|
||||
func extractClaudeCLIEventError(event cliStreamEvent) (string, bool) {
|
||||
if event.Type != "error" && !event.IsError {
|
||||
return "", false
|
||||
}
|
||||
|
||||
if msg := strings.TrimSpace(event.Result); msg != "" {
|
||||
return msg, true
|
||||
}
|
||||
|
||||
for _, block := range event.Message.Content {
|
||||
if block.Type == "text" && strings.TrimSpace(block.Text) != "" {
|
||||
return strings.TrimSpace(block.Text), true
|
||||
}
|
||||
}
|
||||
|
||||
if msg := strings.TrimSpace(event.Error.Message); msg != "" {
|
||||
return msg, true
|
||||
}
|
||||
|
||||
return "claude CLI 返回未知错误", true
|
||||
}
|
||||
|
||||
func isClaudeCLISystemRetryEvent(event cliStreamEvent) bool {
|
||||
return event.Type == "system" && event.Subtype == "api_retry"
|
||||
}
|
||||
|
||||
func extractClaudeCLISystemRetryError(event cliStreamEvent) (string, bool) {
|
||||
if !isClaudeCLISystemRetryEvent(event) {
|
||||
return "", false
|
||||
}
|
||||
|
||||
errText := strings.TrimSpace(event.Error.Message)
|
||||
if event.ErrorStatus != 401 && event.ErrorStatus != 403 && !strings.EqualFold(errText, "authentication_failed") {
|
||||
return "", false
|
||||
}
|
||||
|
||||
if errText == "" {
|
||||
errText = "authentication_failed"
|
||||
}
|
||||
|
||||
if event.ErrorStatus > 0 {
|
||||
return fmt.Sprintf("claude CLI 鉴权失败 (HTTP %d): %s", event.ErrorStatus, errText), true
|
||||
}
|
||||
return fmt.Sprintf("claude CLI 鉴权失败: %s", errText), true
|
||||
}
|
||||
|
||||
@@ -1,9 +1,13 @@
|
||||
package provider
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"GoNavi-Wails/internal/ai"
|
||||
)
|
||||
@@ -26,6 +30,9 @@ func TestBuildClaudeCLIEnv_IncludesAnthropicProxyEnv(t *testing.T) {
|
||||
if got := envValue(env, "ANTHROPIC_API_KEY"); got != "sk-test" {
|
||||
t.Fatalf("expected api key in env, got %q", got)
|
||||
}
|
||||
if got := envValue(env, "ANTHROPIC_AUTH_TOKEN"); got != "sk-test" {
|
||||
t.Fatalf("expected auth token in env, got %q", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestBuildClaudeCLIEnv_UsesDetectedGitBashOnWindows(t *testing.T) {
|
||||
@@ -67,3 +74,281 @@ func TestBuildClaudeCLIEnv_ReturnsActionableErrorWhenGitBashMissingOnWindows(t *
|
||||
t.Fatalf("expected env var hint, got %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestClaudeCLIProvider_ChatTimesOutWhenCommandDoesNotFinish(t *testing.T) {
|
||||
fakeClaude := writeFakeClaudeScript(t, "#!/bin/sh\nsleep 5\n")
|
||||
restore := overrideClaudeCLIForTest(t, fakeClaude)
|
||||
defer restore()
|
||||
|
||||
originalRequestTimeout := claudeCLIRequestTimeout
|
||||
claudeCLIRequestTimeout = 200 * time.Millisecond
|
||||
defer func() {
|
||||
claudeCLIRequestTimeout = originalRequestTimeout
|
||||
}()
|
||||
|
||||
provider, err := NewClaudeCLIProvider(ai.ProviderConfig{
|
||||
BaseURL: "https://coding.dashscope.aliyuncs.com/apps/anthropic",
|
||||
APIKey: "sk-test",
|
||||
Model: "qwen3.5-plus",
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected provider error: %v", err)
|
||||
}
|
||||
|
||||
start := time.Now()
|
||||
_, err = provider.Chat(context.Background(), ai.ChatRequest{
|
||||
Messages: []ai.Message{{Role: "user", Content: "ping"}},
|
||||
})
|
||||
if err == nil {
|
||||
t.Fatal("expected chat timeout error")
|
||||
}
|
||||
if !strings.Contains(err.Error(), "执行超时") {
|
||||
t.Fatalf("expected timeout error, got %v", err)
|
||||
}
|
||||
if time.Since(start) < 200*time.Millisecond {
|
||||
t.Fatalf("expected timeout path to wait for configured deadline, took %s", time.Since(start))
|
||||
}
|
||||
}
|
||||
|
||||
func TestClaudeCLIProvider_ChatStreamUsesRequestTimeoutWhenNoMeaningfulResponseArrives(t *testing.T) {
|
||||
fakeClaude := writeFakeClaudeScript(t, "#!/bin/sh\necho '{\"type\":\"system\",\"subtype\":\"init\"}'\nexec sleep 5\n")
|
||||
restore := overrideClaudeCLIForTest(t, fakeClaude)
|
||||
defer restore()
|
||||
|
||||
originalRequestTimeout := claudeCLIRequestTimeout
|
||||
claudeCLIRequestTimeout = 200 * time.Millisecond
|
||||
defer func() {
|
||||
claudeCLIRequestTimeout = originalRequestTimeout
|
||||
}()
|
||||
|
||||
provider, err := NewClaudeCLIProvider(ai.ProviderConfig{
|
||||
BaseURL: "https://coding.dashscope.aliyuncs.com/apps/anthropic",
|
||||
APIKey: "sk-test",
|
||||
Model: "qwen3.5-plus",
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected provider error: %v", err)
|
||||
}
|
||||
|
||||
var chunks []ai.StreamChunk
|
||||
err = provider.ChatStream(context.Background(), ai.ChatRequest{
|
||||
Messages: []ai.Message{{Role: "user", Content: "ping"}},
|
||||
}, func(chunk ai.StreamChunk) {
|
||||
chunks = append(chunks, chunk)
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("expected stream provider to report timeout via callback, got %v", err)
|
||||
}
|
||||
if len(chunks) == 0 {
|
||||
t.Fatal("expected timeout chunk")
|
||||
}
|
||||
lastChunk := chunks[len(chunks)-1]
|
||||
if !lastChunk.Done {
|
||||
t.Fatalf("expected timeout chunk to terminate stream, got %#v", lastChunk)
|
||||
}
|
||||
if !strings.Contains(lastChunk.Error, "执行超时") {
|
||||
t.Fatalf("expected request timeout message, got %#v", lastChunk)
|
||||
}
|
||||
}
|
||||
|
||||
func TestClaudeCLIProvider_ChatStreamAllowsDelayedMeaningfulResponse(t *testing.T) {
|
||||
fakeClaude := writeFakeClaudeScript(t, "#!/bin/sh\necho '{\"type\":\"system\",\"subtype\":\"init\"}'\nsleep 0.2\necho '{\"type\":\"assistant\",\"message\":{\"content\":[{\"type\":\"text\",\"text\":\"OK\"}]}}'\necho '{\"type\":\"result\",\"subtype\":\"success\",\"is_error\":false,\"result\":\"OK\"}'\n")
|
||||
restore := overrideClaudeCLIForTest(t, fakeClaude)
|
||||
defer restore()
|
||||
|
||||
originalRequestTimeout := claudeCLIRequestTimeout
|
||||
claudeCLIRequestTimeout = 1 * time.Second
|
||||
defer func() {
|
||||
claudeCLIRequestTimeout = originalRequestTimeout
|
||||
}()
|
||||
|
||||
provider, err := NewClaudeCLIProvider(ai.ProviderConfig{
|
||||
BaseURL: "https://coding.dashscope.aliyuncs.com/apps/anthropic",
|
||||
APIKey: "sk-test",
|
||||
Model: "qwen3.5-plus",
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected provider error: %v", err)
|
||||
}
|
||||
|
||||
var chunks []ai.StreamChunk
|
||||
err = provider.ChatStream(context.Background(), ai.ChatRequest{
|
||||
Messages: []ai.Message{{Role: "user", Content: "ping"}},
|
||||
}, func(chunk ai.StreamChunk) {
|
||||
chunks = append(chunks, chunk)
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("expected delayed response to complete via callback, got %v", err)
|
||||
}
|
||||
if len(chunks) == 0 {
|
||||
t.Fatal("expected delayed response chunks")
|
||||
}
|
||||
if chunks[0].Content != "OK" {
|
||||
t.Fatalf("expected delayed content chunk, got %#v", chunks)
|
||||
}
|
||||
if !chunks[len(chunks)-1].Done {
|
||||
t.Fatalf("expected terminal done chunk, got %#v", chunks[len(chunks)-1])
|
||||
}
|
||||
}
|
||||
|
||||
func TestClaudeCLIProvider_ChatReturnsErrorWhenJSONResponseIsError(t *testing.T) {
|
||||
fakeClaude := writeFakeClaudeScript(t, "#!/bin/sh\necho '{\"type\":\"result\",\"subtype\":\"success\",\"is_error\":true,\"result\":\"API Error: Unable to connect to API (ECONNRESET)\",\"error\":\"unknown\"}'\n")
|
||||
restore := overrideClaudeCLIForTest(t, fakeClaude)
|
||||
defer restore()
|
||||
|
||||
provider, err := NewClaudeCLIProvider(ai.ProviderConfig{
|
||||
BaseURL: "https://coding.dashscope.aliyuncs.com/apps/anthropic",
|
||||
APIKey: "sk-test",
|
||||
Model: "qwen3.5-plus",
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected provider error: %v", err)
|
||||
}
|
||||
|
||||
_, err = provider.Chat(context.Background(), ai.ChatRequest{
|
||||
Messages: []ai.Message{{Role: "user", Content: "ping"}},
|
||||
})
|
||||
if err == nil {
|
||||
t.Fatal("expected chat error when CLI JSON marks request as failed")
|
||||
}
|
||||
if !strings.Contains(err.Error(), "Unable to connect to API") {
|
||||
t.Fatalf("expected upstream API error, got %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestClaudeCLIProvider_ChatStreamReportsAssistantErrorEvent(t *testing.T) {
|
||||
fakeClaude := writeFakeClaudeScript(t, "#!/bin/sh\necho '{\"type\":\"assistant\",\"is_error\":true,\"message\":{\"content\":[{\"type\":\"text\",\"text\":\"API Error: Unable to connect to API (ECONNRESET)\"}]},\"error\":\"unknown\"}'\n")
|
||||
restore := overrideClaudeCLIForTest(t, fakeClaude)
|
||||
defer restore()
|
||||
|
||||
provider, err := NewClaudeCLIProvider(ai.ProviderConfig{
|
||||
BaseURL: "https://coding.dashscope.aliyuncs.com/apps/anthropic",
|
||||
APIKey: "sk-test",
|
||||
Model: "qwen3.5-plus",
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected provider error: %v", err)
|
||||
}
|
||||
|
||||
var chunks []ai.StreamChunk
|
||||
err = provider.ChatStream(context.Background(), ai.ChatRequest{
|
||||
Messages: []ai.Message{{Role: "user", Content: "ping"}},
|
||||
}, func(chunk ai.StreamChunk) {
|
||||
chunks = append(chunks, chunk)
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("expected stream provider to report error via callback, got %v", err)
|
||||
}
|
||||
if len(chunks) != 1 {
|
||||
t.Fatalf("expected a single terminal error chunk, got %#v", chunks)
|
||||
}
|
||||
if chunks[0].Content != "" {
|
||||
t.Fatalf("expected assistant error event to avoid content output, got %#v", chunks[0])
|
||||
}
|
||||
if !chunks[0].Done || !strings.Contains(chunks[0].Error, "Unable to connect to API") {
|
||||
t.Fatalf("expected upstream API error chunk, got %#v", chunks[0])
|
||||
}
|
||||
}
|
||||
|
||||
func TestClaudeCLIProvider_ChatStreamReportsResultErrorEvent(t *testing.T) {
|
||||
fakeClaude := writeFakeClaudeScript(t, "#!/bin/sh\necho '{\"type\":\"result\",\"subtype\":\"success\",\"is_error\":true,\"result\":\"API Error: Unable to connect to API (ECONNRESET)\",\"error\":\"unknown\"}'\n")
|
||||
restore := overrideClaudeCLIForTest(t, fakeClaude)
|
||||
defer restore()
|
||||
|
||||
provider, err := NewClaudeCLIProvider(ai.ProviderConfig{
|
||||
BaseURL: "https://coding.dashscope.aliyuncs.com/apps/anthropic",
|
||||
APIKey: "sk-test",
|
||||
Model: "qwen3.5-plus",
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected provider error: %v", err)
|
||||
}
|
||||
|
||||
var chunks []ai.StreamChunk
|
||||
err = provider.ChatStream(context.Background(), ai.ChatRequest{
|
||||
Messages: []ai.Message{{Role: "user", Content: "ping"}},
|
||||
}, func(chunk ai.StreamChunk) {
|
||||
chunks = append(chunks, chunk)
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("expected stream provider to report error via callback, got %v", err)
|
||||
}
|
||||
if len(chunks) != 1 {
|
||||
t.Fatalf("expected a single terminal error chunk, got %#v", chunks)
|
||||
}
|
||||
if chunks[0].Content != "" {
|
||||
t.Fatalf("expected result error event to avoid content output, got %#v", chunks[0])
|
||||
}
|
||||
if !chunks[0].Done || !strings.Contains(chunks[0].Error, "Unable to connect to API") {
|
||||
t.Fatalf("expected upstream API error chunk, got %#v", chunks[0])
|
||||
}
|
||||
}
|
||||
|
||||
func TestClaudeCLIProvider_ChatStreamReportsApiRetryAuthenticationFailure(t *testing.T) {
|
||||
fakeClaude := writeFakeClaudeScript(t, "#!/bin/sh\necho '{\"type\":\"system\",\"subtype\":\"api_retry\",\"attempt\":1,\"max_retries\":10,\"retry_delay_ms\":536.11,\"error_status\":401,\"error\":\"authentication_failed\",\"session_id\":\"retry-1\"}'\nexec sleep 5\n")
|
||||
restore := overrideClaudeCLIForTest(t, fakeClaude)
|
||||
defer restore()
|
||||
|
||||
provider, err := NewClaudeCLIProvider(ai.ProviderConfig{
|
||||
BaseURL: "https://coding.dashscope.aliyuncs.com/apps/anthropic",
|
||||
APIKey: "sk-test",
|
||||
Model: "qwen3.5-plus",
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected provider error: %v", err)
|
||||
}
|
||||
|
||||
var chunks []ai.StreamChunk
|
||||
err = provider.ChatStream(context.Background(), ai.ChatRequest{
|
||||
Messages: []ai.Message{{Role: "user", Content: "ping"}},
|
||||
}, func(chunk ai.StreamChunk) {
|
||||
chunks = append(chunks, chunk)
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("expected stream provider to report authentication error via callback, got %v", err)
|
||||
}
|
||||
if len(chunks) != 1 {
|
||||
t.Fatalf("expected a single terminal error chunk, got %#v", chunks)
|
||||
}
|
||||
if !chunks[0].Done {
|
||||
t.Fatalf("expected terminal error chunk, got %#v", chunks[0])
|
||||
}
|
||||
if strings.Contains(chunks[0].Error, "未收到模型响应") {
|
||||
t.Fatalf("expected auth failure instead of startup timeout, got %#v", chunks[0])
|
||||
}
|
||||
if !strings.Contains(chunks[0].Error, "401") || !strings.Contains(chunks[0].Error, "authentication_failed") {
|
||||
t.Fatalf("expected auth retry error details, got %#v", chunks[0])
|
||||
}
|
||||
}
|
||||
|
||||
func writeFakeClaudeScript(t *testing.T, content string) string {
|
||||
t.Helper()
|
||||
dir := t.TempDir()
|
||||
path := filepath.Join(dir, "claude")
|
||||
if err := os.WriteFile(path, []byte(content), 0o755); err != nil {
|
||||
t.Fatalf("failed to write fake claude script: %v", err)
|
||||
}
|
||||
return path
|
||||
}
|
||||
|
||||
func overrideClaudeCLIForTest(t *testing.T, fakeClaudePath string) func() {
|
||||
t.Helper()
|
||||
|
||||
originalLookPath := claudeLookPath
|
||||
claudeLookPath = func(name string) (string, error) {
|
||||
if name == "claude" {
|
||||
return fakeClaudePath, nil
|
||||
}
|
||||
return originalLookPath(name)
|
||||
}
|
||||
|
||||
originalPath := os.Getenv("PATH")
|
||||
if err := os.Setenv("PATH", filepath.Dir(fakeClaudePath)+string(os.PathListSeparator)+originalPath); err != nil {
|
||||
t.Fatalf("failed to override PATH: %v", err)
|
||||
}
|
||||
|
||||
return func() {
|
||||
claudeLookPath = originalLookPath
|
||||
_ = os.Setenv("PATH", originalPath)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -2,9 +2,13 @@ package provider
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net/url"
|
||||
"regexp"
|
||||
"strings"
|
||||
)
|
||||
|
||||
var openAICompatibleVersionSuffixPattern = regexp.MustCompile(`(?i)(^|/)v\d+$`)
|
||||
|
||||
// ParseDataURI 解析前端传递的 Data URI,返回 mimeType 和去掉前缀的 rawBase64
|
||||
func ParseDataURI(dataURI string) (mimeType, rawBase64 string, err error) {
|
||||
if !strings.HasPrefix(dataURI, "data:") {
|
||||
@@ -24,3 +28,70 @@ func ParseDataURI(dataURI string) (mimeType, rawBase64 string, err error) {
|
||||
rawBase64 = parts[1]
|
||||
return mimeType, rawBase64, nil
|
||||
}
|
||||
|
||||
// NormalizeOpenAICompatibleBaseURL 统一归一化 OpenAI 兼容服务的 base URL。
|
||||
func NormalizeOpenAICompatibleBaseURL(raw string) string {
|
||||
trimmed := strings.TrimSpace(raw)
|
||||
if trimmed == "" {
|
||||
return defaultOpenAIBaseURL
|
||||
}
|
||||
|
||||
parsed, err := url.Parse(trimmed)
|
||||
if err != nil || parsed.Scheme == "" || parsed.Host == "" {
|
||||
return normalizeOpenAICompatibleBaseURLString(trimmed)
|
||||
}
|
||||
|
||||
parsed.RawQuery = ""
|
||||
parsed.Fragment = ""
|
||||
parsed.Path = normalizeOpenAICompatiblePath(parsed.Path)
|
||||
return strings.TrimRight(parsed.String(), "/")
|
||||
}
|
||||
|
||||
// ResolveOpenAICompatibleEndpoint 基于归一化 base URL 拼接 OpenAI 兼容接口路径。
|
||||
func ResolveOpenAICompatibleEndpoint(baseURL string, endpoint string) string {
|
||||
normalizedBaseURL := NormalizeOpenAICompatibleBaseURL(baseURL)
|
||||
normalizedEndpoint := strings.TrimLeft(strings.TrimSpace(endpoint), "/")
|
||||
if normalizedEndpoint == "" {
|
||||
return normalizedBaseURL
|
||||
}
|
||||
return normalizedBaseURL + "/" + normalizedEndpoint
|
||||
}
|
||||
|
||||
func normalizeOpenAICompatibleBaseURLString(raw string) string {
|
||||
normalized := strings.TrimRight(strings.TrimSpace(raw), "/")
|
||||
if normalized == "" {
|
||||
return defaultOpenAIBaseURL
|
||||
}
|
||||
|
||||
lower := strings.ToLower(normalized)
|
||||
switch {
|
||||
case strings.HasSuffix(lower, "/chat/completions"):
|
||||
normalized = normalized[:len(normalized)-len("/chat/completions")]
|
||||
case strings.HasSuffix(lower, "/models"):
|
||||
normalized = normalized[:len(normalized)-len("/models")]
|
||||
}
|
||||
normalized = strings.TrimRight(normalized, "/")
|
||||
if openAICompatibleVersionSuffixPattern.MatchString(normalized) {
|
||||
return normalized
|
||||
}
|
||||
return normalized + "/v1"
|
||||
}
|
||||
|
||||
func normalizeOpenAICompatiblePath(path string) string {
|
||||
normalized := strings.TrimRight(strings.TrimSpace(path), "/")
|
||||
lower := strings.ToLower(normalized)
|
||||
switch {
|
||||
case strings.HasSuffix(lower, "/chat/completions"):
|
||||
normalized = normalized[:len(normalized)-len("/chat/completions")]
|
||||
case strings.HasSuffix(lower, "/models"):
|
||||
normalized = normalized[:len(normalized)-len("/models")]
|
||||
}
|
||||
normalized = strings.TrimRight(normalized, "/")
|
||||
if openAICompatibleVersionSuffixPattern.MatchString(normalized) {
|
||||
return normalized
|
||||
}
|
||||
if normalized == "" {
|
||||
return "/v1"
|
||||
}
|
||||
return normalized + "/v1"
|
||||
}
|
||||
|
||||
@@ -30,14 +30,7 @@ type OpenAIProvider struct {
|
||||
|
||||
// NewOpenAIProvider 创建 OpenAI Provider 实例
|
||||
func NewOpenAIProvider(config ai.ProviderConfig) (Provider, error) {
|
||||
baseURL := strings.TrimRight(strings.TrimSpace(config.BaseURL), "/")
|
||||
if baseURL == "" {
|
||||
baseURL = defaultOpenAIBaseURL
|
||||
}
|
||||
// 确保 baseURL 包含 /v1 路径(兼容用户只填域名的情况,如 https://anyrouter.top)
|
||||
if !strings.HasSuffix(baseURL, "/v1") && !strings.Contains(baseURL, "/v1/") {
|
||||
baseURL = baseURL + "/v1"
|
||||
}
|
||||
baseURL := NormalizeOpenAICompatibleBaseURL(config.BaseURL)
|
||||
model := strings.TrimSpace(config.Model)
|
||||
if model == "" {
|
||||
return nil, fmt.Errorf("模型 ID 不能为空,请在设置中选择或输入模型")
|
||||
@@ -315,7 +308,7 @@ func (p *OpenAIProvider) ChatStream(ctx context.Context, req ai.ChatRequest, cal
|
||||
}
|
||||
if len(chunk.Choices) > 0 {
|
||||
choice := chunk.Choices[0]
|
||||
|
||||
|
||||
// Handle ToolCalls delta
|
||||
if len(choice.Delta.ToolCalls) > 0 {
|
||||
receivedContent = true
|
||||
@@ -383,10 +376,7 @@ func (p *OpenAIProvider) doRequest(ctx context.Context, body interface{}) (io.Re
|
||||
return nil, fmt.Errorf("序列化请求失败: %w", err)
|
||||
}
|
||||
|
||||
url := p.baseURL + "/chat/completions"
|
||||
|
||||
|
||||
|
||||
url := ResolveOpenAICompatibleEndpoint(p.baseURL, "chat/completions")
|
||||
httpReq, err := http.NewRequestWithContext(ctx, http.MethodPost, url, bytes.NewReader(jsonBody))
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("创建 HTTP 请求失败: %w", err)
|
||||
|
||||
@@ -5,6 +5,76 @@ import (
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestNormalizeOpenAICompatibleBaseURL(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
raw string
|
||||
want string
|
||||
}{
|
||||
{
|
||||
name: "empty uses default openai base url",
|
||||
raw: "",
|
||||
want: "https://api.openai.com/v1",
|
||||
},
|
||||
{
|
||||
name: "domain only appends v1",
|
||||
raw: "https://api.openai.com",
|
||||
want: "https://api.openai.com/v1",
|
||||
},
|
||||
{
|
||||
name: "keeps existing v1 suffix",
|
||||
raw: "https://api.deepseek.com/v1",
|
||||
want: "https://api.deepseek.com/v1",
|
||||
},
|
||||
{
|
||||
name: "keeps dashscope compatible mode path",
|
||||
raw: "https://dashscope.aliyuncs.com/compatible-mode/v1",
|
||||
want: "https://dashscope.aliyuncs.com/compatible-mode/v1",
|
||||
},
|
||||
{
|
||||
name: "keeps zhipu v4 path",
|
||||
raw: "https://open.bigmodel.cn/api/paas/v4",
|
||||
want: "https://open.bigmodel.cn/api/paas/v4",
|
||||
},
|
||||
{
|
||||
name: "keeps volcengine ark v3 path",
|
||||
raw: "https://ark.cn-beijing.volces.com/api/v3",
|
||||
want: "https://ark.cn-beijing.volces.com/api/v3",
|
||||
},
|
||||
{
|
||||
name: "keeps volcengine coding plan v3 path",
|
||||
raw: "https://ark.cn-beijing.volces.com/api/coding/v3",
|
||||
want: "https://ark.cn-beijing.volces.com/api/coding/v3",
|
||||
},
|
||||
{
|
||||
name: "strips chat completions suffix before normalizing",
|
||||
raw: "https://api.openai.com/v1/chat/completions",
|
||||
want: "https://api.openai.com/v1",
|
||||
},
|
||||
{
|
||||
name: "strips models suffix before normalizing",
|
||||
raw: "https://ark.cn-beijing.volces.com/api/coding/v3/models",
|
||||
want: "https://ark.cn-beijing.volces.com/api/coding/v3",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
if got := NormalizeOpenAICompatibleBaseURL(tt.raw); got != tt.want {
|
||||
t.Fatalf("expected normalized base url %q, got %q", tt.want, got)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestResolveOpenAICompatibleEndpoint(t *testing.T) {
|
||||
got := ResolveOpenAICompatibleEndpoint("https://ark.cn-beijing.volces.com/api/coding/v3/models", "chat/completions")
|
||||
want := "https://ark.cn-beijing.volces.com/api/coding/v3/chat/completions"
|
||||
if got != want {
|
||||
t.Fatalf("expected endpoint %q, got %q", want, got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestOpenAIProvider_Validate_MissingAPIKey(t *testing.T) {
|
||||
p, err := NewOpenAIProvider(ai.ProviderConfig{Type: "openai", Model: "gpt-4o"})
|
||||
if err != nil {
|
||||
|
||||
@@ -6,6 +6,7 @@ import (
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
@@ -31,7 +32,7 @@ type Service struct {
|
||||
safetyLevel ai.SQLPermissionLevel
|
||||
contextLevel ai.ContextLevel
|
||||
guard *safety.Guard
|
||||
configDir string // 配置存储目录
|
||||
configDir string // 配置存储目录
|
||||
cancelFuncs map[string]context.CancelFunc // 记录每个 session 的 context 取消函数
|
||||
}
|
||||
|
||||
@@ -45,6 +46,55 @@ var miniMaxAnthropicModels = []string{
|
||||
"MiniMax-M2",
|
||||
}
|
||||
|
||||
var dashScopeCodingPlanModels = []string{
|
||||
"qwen3.5-plus",
|
||||
"kimi-k2.5",
|
||||
"glm-5",
|
||||
"MiniMax-M2.5",
|
||||
"qwen3-max-2026-01-23",
|
||||
"qwen3-coder-next",
|
||||
"qwen3-coder-plus",
|
||||
"glm-4.7",
|
||||
}
|
||||
|
||||
const dashScopeCodingPlanAnthropicBaseURL = "https://coding.dashscope.aliyuncs.com/apps/anthropic"
|
||||
|
||||
var volcengineCodingPlanAllowedExactModels = []string{
|
||||
"auto",
|
||||
}
|
||||
|
||||
var volcengineCodingPlanAllowedModelFamilies = []string{
|
||||
"doubao-seed-2.0-code",
|
||||
"doubao-seed-2.0-pro",
|
||||
"doubao-seed-2.0-lite",
|
||||
"doubao-seed-code",
|
||||
"minimax-m2.5",
|
||||
"glm-4.7",
|
||||
"deepseek-v3.2",
|
||||
"kimi-k2",
|
||||
}
|
||||
|
||||
const volcengineCodingPlanEmptyModelsError = `当前接口未返回可用的火山 Coding Plan 模型,请检查账号权限或切换到"火山方舟"供应商`
|
||||
|
||||
var claudeCLIHealthCheckFunc = func(config ai.ProviderConfig) error {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
|
||||
defer cancel()
|
||||
|
||||
cliProvider, err := provider.NewProvider(config)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
_, err = cliProvider.Chat(ctx, ai.ChatRequest{
|
||||
Messages: []ai.Message{
|
||||
{Role: "user", Content: "ping"},
|
||||
},
|
||||
MaxTokens: 1,
|
||||
Temperature: 0,
|
||||
})
|
||||
return err
|
||||
}
|
||||
|
||||
// NewService 创建 AI Service 实例
|
||||
func NewService() *Service {
|
||||
return &Service{
|
||||
@@ -56,8 +106,13 @@ func NewService() *Service {
|
||||
}
|
||||
}
|
||||
|
||||
// Startup Wails 生命周期回调
|
||||
func (s *Service) Startup(ctx context.Context) {
|
||||
// InitializeLifecycle attaches runtime context without exposing lifecycle internals to Wails bindings.
|
||||
func InitializeLifecycle(s *Service, ctx context.Context) {
|
||||
s.startup(ctx)
|
||||
}
|
||||
|
||||
// startup Wails 生命周期回调
|
||||
func (s *Service) startup(ctx context.Context) {
|
||||
s.ctx = ctx
|
||||
s.configDir = resolveConfigDir()
|
||||
s.loadConfig()
|
||||
@@ -173,6 +228,12 @@ func (s *Service) AITestProvider(config ai.ProviderConfig) map[string]interface{
|
||||
err = fmt.Errorf("上游服务器内部错误 (HTTP %d)", resp.StatusCode)
|
||||
}
|
||||
}
|
||||
case "claude-cli":
|
||||
testConfig := config
|
||||
if strings.TrimSpace(testConfig.Model) == "" && isDashScopeCodingPlanProvider(testConfig) && len(dashScopeCodingPlanModels) > 0 {
|
||||
testConfig.Model = dashScopeCodingPlanModels[0]
|
||||
}
|
||||
err = claudeCLIHealthCheckFunc(testConfig)
|
||||
default:
|
||||
if baseURL != "" {
|
||||
req, _ := http.NewRequest("GET", baseURL, nil)
|
||||
@@ -219,18 +280,104 @@ func isMoonshotAnthropicProvider(config ai.ProviderConfig) bool {
|
||||
return strings.Contains(baseURL, "api.moonshot.cn")
|
||||
}
|
||||
|
||||
func parseProviderBaseURL(raw string) (string, string) {
|
||||
parsed, err := url.Parse(strings.TrimSpace(raw))
|
||||
if err != nil {
|
||||
return "", ""
|
||||
}
|
||||
return strings.ToLower(parsed.Hostname()), strings.TrimRight(strings.ToLower(parsed.Path), "/")
|
||||
}
|
||||
|
||||
func isDashScopeBailianAnthropicProvider(config ai.ProviderConfig) bool {
|
||||
if normalizedProviderType(config) != "anthropic" {
|
||||
return false
|
||||
}
|
||||
host, path := parseProviderBaseURL(config.BaseURL)
|
||||
return host == "dashscope.aliyuncs.com" && strings.HasPrefix(path, "/apps/anthropic")
|
||||
}
|
||||
|
||||
func isDashScopeCodingPlanAnthropicProvider(config ai.ProviderConfig) bool {
|
||||
if normalizedProviderType(config) != "anthropic" {
|
||||
return false
|
||||
}
|
||||
return isDashScopeCodingPlanProvider(config)
|
||||
}
|
||||
|
||||
func isDashScopeCodingPlanProvider(config ai.ProviderConfig) bool {
|
||||
host, path := parseProviderBaseURL(config.BaseURL)
|
||||
return host == "coding.dashscope.aliyuncs.com" && (strings.HasPrefix(path, "/apps/anthropic") || strings.HasPrefix(path, "/v1"))
|
||||
}
|
||||
|
||||
func isVolcengineCodingPlanProvider(config ai.ProviderConfig) bool {
|
||||
if normalizedProviderType(config) != "openai" {
|
||||
return false
|
||||
}
|
||||
host, path := parseProviderBaseURL(provider.NormalizeOpenAICompatibleBaseURL(config.BaseURL))
|
||||
return host == "ark.cn-beijing.volces.com" && path == "/api/coding/v3"
|
||||
}
|
||||
|
||||
func filterVolcengineCodingPlanModels(models []string) []string {
|
||||
filtered := make([]string, 0, len(models))
|
||||
for _, model := range models {
|
||||
lowerModel := strings.ToLower(strings.TrimSpace(model))
|
||||
matched := false
|
||||
for _, exactModel := range volcengineCodingPlanAllowedExactModels {
|
||||
if lowerModel == exactModel {
|
||||
filtered = append(filtered, model)
|
||||
matched = true
|
||||
break
|
||||
}
|
||||
}
|
||||
if matched {
|
||||
continue
|
||||
}
|
||||
for _, family := range volcengineCodingPlanAllowedModelFamilies {
|
||||
if strings.Contains(lowerModel, family) {
|
||||
filtered = append(filtered, model)
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
return filtered
|
||||
}
|
||||
|
||||
func filterFetchedModelsForProvider(config ai.ProviderConfig, models []string) ([]string, error) {
|
||||
if !isVolcengineCodingPlanProvider(config) {
|
||||
return models, nil
|
||||
}
|
||||
filtered := filterVolcengineCodingPlanModels(models)
|
||||
if len(filtered) == 0 {
|
||||
return nil, fmt.Errorf(volcengineCodingPlanEmptyModelsError)
|
||||
}
|
||||
return filtered, nil
|
||||
}
|
||||
|
||||
func defaultStaticModelsForProvider(config ai.ProviderConfig) []string {
|
||||
if isMiniMaxAnthropicProvider(config) {
|
||||
return append([]string(nil), miniMaxAnthropicModels...)
|
||||
}
|
||||
if isDashScopeCodingPlanProvider(config) {
|
||||
return append([]string(nil), dashScopeCodingPlanModels...)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func normalizeProviderConfig(config ai.ProviderConfig) ai.ProviderConfig {
|
||||
staticModels := defaultStaticModelsForProvider(config)
|
||||
if len(staticModels) > 0 && len(config.Models) == 0 {
|
||||
config.Models = staticModels
|
||||
switch {
|
||||
case isDashScopeBailianAnthropicProvider(config):
|
||||
config.Models = nil
|
||||
case isDashScopeCodingPlanProvider(config):
|
||||
config.Type = "custom"
|
||||
config.APIFormat = "claude-cli"
|
||||
config.BaseURL = dashScopeCodingPlanAnthropicBaseURL
|
||||
config.Models = append([]string(nil), dashScopeCodingPlanModels...)
|
||||
default:
|
||||
staticModels := defaultStaticModelsForProvider(config)
|
||||
if len(staticModels) > 0 && len(config.Models) == 0 {
|
||||
config.Models = staticModels
|
||||
}
|
||||
}
|
||||
|
||||
model := strings.TrimSpace(config.Model)
|
||||
if isMiniMaxAnthropicProvider(config) && (model == "" || strings.HasPrefix(strings.ToLower(model), "minimax-text-")) {
|
||||
config.Model = miniMaxAnthropicModels[0]
|
||||
@@ -248,6 +395,9 @@ func resolveModelsURL(config ai.ProviderConfig) string {
|
||||
if isMoonshotAnthropicProvider(config) {
|
||||
return "https://api.moonshot.cn/v1/models"
|
||||
}
|
||||
if isDashScopeBailianAnthropicProvider(config) {
|
||||
return "https://dashscope.aliyuncs.com/compatible-mode/v1/models"
|
||||
}
|
||||
if baseURL == "" {
|
||||
baseURL = "https://api.anthropic.com"
|
||||
}
|
||||
@@ -263,13 +413,7 @@ func resolveModelsURL(config ai.ProviderConfig) string {
|
||||
case "openai":
|
||||
fallthrough
|
||||
default:
|
||||
if baseURL == "" {
|
||||
baseURL = "https://api.openai.com/v1"
|
||||
}
|
||||
if !strings.HasSuffix(baseURL, "/v1") && !strings.Contains(baseURL, "/v1/") {
|
||||
baseURL = baseURL + "/v1"
|
||||
}
|
||||
return baseURL + "/models"
|
||||
return provider.ResolveOpenAICompatibleEndpoint(baseURL, "models")
|
||||
}
|
||||
}
|
||||
|
||||
@@ -283,9 +427,11 @@ func newModelsRequest(config ai.ProviderConfig) (*http.Request, error) {
|
||||
|
||||
switch normalizedProviderType(config) {
|
||||
case "anthropic":
|
||||
req.Header.Set("x-api-key", config.APIKey)
|
||||
req.Header.Set("anthropic-version", "2023-06-01")
|
||||
req.Header.Set("Authorization", "Bearer "+config.APIKey)
|
||||
if isDashScopeBailianAnthropicProvider(config) {
|
||||
req.Header.Set("Authorization", "Bearer "+config.APIKey)
|
||||
} else {
|
||||
provider.ApplyAnthropicAuthHeaders(req.Header, config.BaseURL, config.APIKey)
|
||||
}
|
||||
case "gemini":
|
||||
// Gemini 使用 query string 传递 key,无需额外鉴权头
|
||||
default:
|
||||
@@ -315,33 +461,36 @@ func resolveAnthropicMessagesURL(baseURL string) string {
|
||||
|
||||
func newProviderHealthCheckRequest(config ai.ProviderConfig) (*http.Request, error) {
|
||||
config = normalizeProviderConfig(config)
|
||||
if isMiniMaxAnthropicProvider(config) {
|
||||
body := map[string]interface{}{
|
||||
"model": config.Model,
|
||||
"max_tokens": 1,
|
||||
"messages": []map[string]string{
|
||||
{"role": "user", "content": "ping"},
|
||||
},
|
||||
}
|
||||
bodyBytes, err := json.Marshal(body)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("序列化请求失败: %w", err)
|
||||
}
|
||||
req, err := http.NewRequest("POST", resolveAnthropicMessagesURL(config.BaseURL), strings.NewReader(string(bodyBytes)))
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("创建请求失败: %w", err)
|
||||
}
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
req.Header.Set("x-api-key", config.APIKey)
|
||||
req.Header.Set("anthropic-version", "2023-06-01")
|
||||
for k, v := range config.Headers {
|
||||
req.Header.Set(k, v)
|
||||
}
|
||||
return req, nil
|
||||
if isMiniMaxAnthropicProvider(config) || isDashScopeBailianAnthropicProvider(config) || isDashScopeCodingPlanAnthropicProvider(config) {
|
||||
return newAnthropicMessagesHealthCheckRequest(config)
|
||||
}
|
||||
return newModelsRequest(config)
|
||||
}
|
||||
|
||||
func newAnthropicMessagesHealthCheckRequest(config ai.ProviderConfig) (*http.Request, error) {
|
||||
body := map[string]interface{}{
|
||||
"model": config.Model,
|
||||
"max_tokens": 1,
|
||||
"messages": []map[string]string{
|
||||
{"role": "user", "content": "ping"},
|
||||
},
|
||||
}
|
||||
bodyBytes, err := json.Marshal(body)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("序列化请求失败: %w", err)
|
||||
}
|
||||
req, err := http.NewRequest("POST", resolveAnthropicMessagesURL(config.BaseURL), strings.NewReader(string(bodyBytes)))
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("创建请求失败: %w", err)
|
||||
}
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
provider.ApplyAnthropicAuthHeaders(req.Header, config.BaseURL, config.APIKey)
|
||||
for k, v := range config.Headers {
|
||||
req.Header.Set(k, v)
|
||||
}
|
||||
return req, nil
|
||||
}
|
||||
|
||||
// AISetActiveProvider 设置活动 Provider
|
||||
func (s *Service) AISetActiveProvider(id string) {
|
||||
s.mu.Lock()
|
||||
@@ -380,7 +529,12 @@ func (s *Service) AIListModels() map[string]interface{} {
|
||||
return map[string]interface{}{"success": false, "models": []string{}, "error": "未找到活跃 Provider"}
|
||||
}
|
||||
|
||||
models, err := fetchModels(config)
|
||||
config = normalizeProviderConfig(config)
|
||||
if staticModels := defaultStaticModelsForProvider(config); len(staticModels) > 0 {
|
||||
return map[string]interface{}{"success": true, "models": staticModels, "source": "static"}
|
||||
}
|
||||
|
||||
models, err := fetchModelsFunc(config)
|
||||
if err != nil {
|
||||
// 回退到配置中的静态模型列表
|
||||
if len(config.Models) > 0 {
|
||||
@@ -389,10 +543,17 @@ func (s *Service) AIListModels() map[string]interface{} {
|
||||
return map[string]interface{}{"success": false, "models": []string{}, "error": err.Error()}
|
||||
}
|
||||
|
||||
models, err = filterFetchedModelsForProvider(config, models)
|
||||
if err != nil {
|
||||
return map[string]interface{}{"success": false, "models": []string{}, "error": err.Error()}
|
||||
}
|
||||
|
||||
return map[string]interface{}{"success": true, "models": models, "source": "api"}
|
||||
}
|
||||
|
||||
// fetchModels 从供应商 API 获取可用模型列表
|
||||
var fetchModelsFunc = fetchModels
|
||||
|
||||
func fetchModels(config ai.ProviderConfig) ([]string, error) {
|
||||
providerType := normalizedProviderType(config)
|
||||
if staticModels := defaultStaticModelsForProvider(config); len(staticModels) > 0 {
|
||||
@@ -588,8 +749,8 @@ func (s *Service) AIChatSend(messages []ai.Message, tools []ai.Tool) map[string]
|
||||
}
|
||||
|
||||
return map[string]interface{}{
|
||||
"success": true,
|
||||
"content": resp.Content,
|
||||
"success": true,
|
||||
"content": resp.Content,
|
||||
"tool_calls": resp.ToolCalls,
|
||||
"tokensUsed": map[string]int{
|
||||
"promptTokens": resp.TokensUsed.PromptTokens,
|
||||
|
||||
157
internal/ai/service/service_qwen_test.go
Normal file
157
internal/ai/service/service_qwen_test.go
Normal file
@@ -0,0 +1,157 @@
|
||||
package aiservice
|
||||
|
||||
import (
|
||||
"reflect"
|
||||
"testing"
|
||||
|
||||
"GoNavi-Wails/internal/ai"
|
||||
)
|
||||
|
||||
func TestDefaultStaticModelsForProvider_DoesNotReturnBailianStaticModels(t *testing.T) {
|
||||
models := defaultStaticModelsForProvider(ai.ProviderConfig{
|
||||
Type: "anthropic",
|
||||
BaseURL: "https://dashscope.aliyuncs.com/apps/anthropic",
|
||||
})
|
||||
if len(models) != 0 {
|
||||
t.Fatalf("expected Bailian provider to rely on remote model list, got %v", models)
|
||||
}
|
||||
}
|
||||
|
||||
func TestDefaultStaticModelsForProvider_ReturnsDashScopeCodingPlanSupportedModels(t *testing.T) {
|
||||
expected := []string{
|
||||
"qwen3.5-plus",
|
||||
"kimi-k2.5",
|
||||
"glm-5",
|
||||
"MiniMax-M2.5",
|
||||
"qwen3-max-2026-01-23",
|
||||
"qwen3-coder-next",
|
||||
"qwen3-coder-plus",
|
||||
"glm-4.7",
|
||||
}
|
||||
testCases := []ai.ProviderConfig{
|
||||
{
|
||||
Type: "anthropic",
|
||||
BaseURL: "https://coding.dashscope.aliyuncs.com/apps/anthropic",
|
||||
},
|
||||
{
|
||||
Type: "custom",
|
||||
APIFormat: "claude-cli",
|
||||
BaseURL: "https://coding.dashscope.aliyuncs.com/apps/anthropic",
|
||||
},
|
||||
}
|
||||
|
||||
for _, testCase := range testCases {
|
||||
models := defaultStaticModelsForProvider(testCase)
|
||||
if !reflect.DeepEqual(models, expected) {
|
||||
t.Fatalf("expected Coding Plan supported models %v, got %v for config %#v", expected, models, testCase)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestNormalizeProviderConfig_DoesNotForceModelForDashScopeProviders(t *testing.T) {
|
||||
bailian := normalizeProviderConfig(ai.ProviderConfig{
|
||||
Type: "anthropic",
|
||||
BaseURL: "https://dashscope.aliyuncs.com/apps/anthropic",
|
||||
})
|
||||
if bailian.Model != "" {
|
||||
t.Fatalf("expected Bailian model to remain empty until explicit selection, got %q", bailian.Model)
|
||||
}
|
||||
|
||||
codingPlan := normalizeProviderConfig(ai.ProviderConfig{
|
||||
Type: "anthropic",
|
||||
BaseURL: "https://coding.dashscope.aliyuncs.com/apps/anthropic",
|
||||
})
|
||||
if codingPlan.Type != "custom" {
|
||||
t.Fatalf("expected Coding Plan provider type to normalize to custom, got %q", codingPlan.Type)
|
||||
}
|
||||
if codingPlan.APIFormat != "claude-cli" {
|
||||
t.Fatalf("expected Coding Plan provider api format to normalize to claude-cli, got %q", codingPlan.APIFormat)
|
||||
}
|
||||
if codingPlan.Model != "" {
|
||||
t.Fatalf("expected Coding Plan model to remain empty until explicit selection, got %q", codingPlan.Model)
|
||||
}
|
||||
if len(codingPlan.Models) == 0 {
|
||||
t.Fatal("expected Coding Plan provider to expose official supported models")
|
||||
}
|
||||
if codingPlan.Models[0] != "qwen3.5-plus" {
|
||||
t.Fatalf("expected Coding Plan provider to expose latest supported models, got %v", codingPlan.Models)
|
||||
}
|
||||
}
|
||||
|
||||
func TestResolveModelsURL_UsesDashScopeCompatibleModelsEndpointForBailianAnthropic(t *testing.T) {
|
||||
url := resolveModelsURL(ai.ProviderConfig{
|
||||
Type: "anthropic",
|
||||
BaseURL: "https://dashscope.aliyuncs.com/apps/anthropic",
|
||||
})
|
||||
if url != "https://dashscope.aliyuncs.com/compatible-mode/v1/models" {
|
||||
t.Fatalf("expected Bailian models endpoint, got %q", url)
|
||||
}
|
||||
}
|
||||
|
||||
func TestAIListModels_ReturnsStaticModelsForDashScopeCodingPlanWithoutRemoteFetch(t *testing.T) {
|
||||
originalFetchModelsFunc := fetchModelsFunc
|
||||
fetchModelsFunc = func(config ai.ProviderConfig) ([]string, error) {
|
||||
t.Fatalf("expected Coding Plan model list to stay static and skip remote fetch, got config %#v", config)
|
||||
return nil, nil
|
||||
}
|
||||
defer func() {
|
||||
fetchModelsFunc = originalFetchModelsFunc
|
||||
}()
|
||||
|
||||
service := NewService()
|
||||
service.providers = []ai.ProviderConfig{
|
||||
{
|
||||
ID: "provider-coding-plan",
|
||||
Type: "anthropic",
|
||||
BaseURL: "https://coding.dashscope.aliyuncs.com/apps/anthropic",
|
||||
},
|
||||
}
|
||||
service.activeProvider = "provider-coding-plan"
|
||||
|
||||
result := service.AIListModels()
|
||||
if result["success"] != true {
|
||||
t.Fatalf("expected AIListModels to succeed, got %#v", result)
|
||||
}
|
||||
models, ok := result["models"].([]string)
|
||||
if !ok {
|
||||
t.Fatalf("expected []string models, got %#v", result["models"])
|
||||
}
|
||||
if len(models) == 0 || models[0] != "qwen3.5-plus" {
|
||||
t.Fatalf("expected official static Coding Plan models, got %#v", models)
|
||||
}
|
||||
if source, _ := result["source"].(string); source != "static" {
|
||||
t.Fatalf("expected static source, got %#v", result["source"])
|
||||
}
|
||||
}
|
||||
|
||||
func TestAITestProvider_UsesClaudeCLIHealthCheckForDashScopeCodingPlan(t *testing.T) {
|
||||
originalClaudeCLIHealthCheckFunc := claudeCLIHealthCheckFunc
|
||||
defer func() {
|
||||
claudeCLIHealthCheckFunc = originalClaudeCLIHealthCheckFunc
|
||||
}()
|
||||
|
||||
var received ai.ProviderConfig
|
||||
claudeCLIHealthCheckFunc = func(config ai.ProviderConfig) error {
|
||||
received = config
|
||||
return nil
|
||||
}
|
||||
|
||||
service := NewService()
|
||||
result := service.AITestProvider(ai.ProviderConfig{
|
||||
Type: "anthropic",
|
||||
BaseURL: "https://coding.dashscope.aliyuncs.com/apps/anthropic",
|
||||
APIKey: "sk-test",
|
||||
})
|
||||
if result["success"] != true {
|
||||
t.Fatalf("expected AITestProvider to succeed, got %#v", result)
|
||||
}
|
||||
if received.Type != "custom" {
|
||||
t.Fatalf("expected Coding Plan test to use custom provider type, got %q", received.Type)
|
||||
}
|
||||
if received.APIFormat != "claude-cli" {
|
||||
t.Fatalf("expected Coding Plan test to use claude-cli api format, got %q", received.APIFormat)
|
||||
}
|
||||
if received.Model != "qwen3.5-plus" {
|
||||
t.Fatalf("expected Coding Plan test to default probe model to qwen3.5-plus, got %q", received.Model)
|
||||
}
|
||||
}
|
||||
@@ -37,6 +37,43 @@ func TestResolveModelsURL_UsesOpenAIModelsEndpointForOpenAICompatibleProvider(t
|
||||
}
|
||||
}
|
||||
|
||||
func TestResolveModelsURL_UsesVersionedVolcengineCodingPlanPath(t *testing.T) {
|
||||
url := resolveModelsURL(ai.ProviderConfig{
|
||||
Type: "openai",
|
||||
BaseURL: "https://ark.cn-beijing.volces.com/api/coding/v3",
|
||||
})
|
||||
if url != "https://ark.cn-beijing.volces.com/api/coding/v3/models" {
|
||||
t.Fatalf("expected volcengine coding plan models endpoint, got %q", url)
|
||||
}
|
||||
}
|
||||
|
||||
func TestResolveModelsURL_UsesVersionedZhipuPath(t *testing.T) {
|
||||
url := resolveModelsURL(ai.ProviderConfig{
|
||||
Type: "openai",
|
||||
BaseURL: "https://open.bigmodel.cn/api/paas/v4",
|
||||
})
|
||||
if url != "https://open.bigmodel.cn/api/paas/v4/models" {
|
||||
t.Fatalf("expected zhipu models endpoint, got %q", url)
|
||||
}
|
||||
}
|
||||
|
||||
func TestNewModelsRequest_StripsChatCompletionsSuffixForOpenAICompatibleProvider(t *testing.T) {
|
||||
req, err := newModelsRequest(ai.ProviderConfig{
|
||||
Type: "openai",
|
||||
BaseURL: "https://ark.cn-beijing.volces.com/api/v3/chat/completions",
|
||||
APIKey: "sk-test",
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
if req.URL.String() != "https://ark.cn-beijing.volces.com/api/v3/models" {
|
||||
t.Fatalf("expected normalized models endpoint, got %q", req.URL.String())
|
||||
}
|
||||
if got := req.Header.Get("Authorization"); got != "Bearer sk-test" {
|
||||
t.Fatalf("expected bearer auth header, got %q", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestDefaultStaticModelsForProvider_ReturnsMiniMaxAnthropicModels(t *testing.T) {
|
||||
models := defaultStaticModelsForProvider(ai.ProviderConfig{
|
||||
Type: "anthropic",
|
||||
@@ -56,6 +93,16 @@ func TestDefaultStaticModelsForProvider_ReturnsMiniMaxAnthropicModels(t *testing
|
||||
}
|
||||
}
|
||||
|
||||
func TestDefaultStaticModelsForProvider_DoesNotReturnDashScopeBailianStaticModels(t *testing.T) {
|
||||
models := defaultStaticModelsForProvider(ai.ProviderConfig{
|
||||
Type: "anthropic",
|
||||
BaseURL: "https://dashscope.aliyuncs.com/apps/anthropic",
|
||||
})
|
||||
if len(models) != 0 {
|
||||
t.Fatalf("expected Bailian provider to fetch models remotely, got %v", models)
|
||||
}
|
||||
}
|
||||
|
||||
func TestNewProviderHealthCheckRequest_UsesMessagesEndpointForMiniMaxAnthropic(t *testing.T) {
|
||||
req, err := newProviderHealthCheckRequest(ai.ProviderConfig{
|
||||
Type: "anthropic",
|
||||
@@ -76,3 +123,30 @@ func TestNewProviderHealthCheckRequest_UsesMessagesEndpointForMiniMaxAnthropic(t
|
||||
t.Fatalf("expected x-api-key header to be set, got %q", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestNewProviderHealthCheckRequest_UsesMessagesEndpointForDashScopeAnthropic(t *testing.T) {
|
||||
req, err := newProviderHealthCheckRequest(ai.ProviderConfig{
|
||||
Type: "anthropic",
|
||||
BaseURL: "https://dashscope.aliyuncs.com/apps/anthropic",
|
||||
Model: "qwen3.5-plus",
|
||||
APIKey: "sk-test",
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
if req.Method != "POST" {
|
||||
t.Fatalf("expected POST request, got %s", req.Method)
|
||||
}
|
||||
if req.URL.String() != "https://dashscope.aliyuncs.com/apps/anthropic/v1/messages" {
|
||||
t.Fatalf("expected DashScope messages endpoint, got %q", req.URL.String())
|
||||
}
|
||||
if got := req.Header.Get("x-api-key"); got != "sk-test" {
|
||||
t.Fatalf("expected x-api-key header to be set, got %q", got)
|
||||
}
|
||||
if got := req.Header.Get("Authorization"); got != "Bearer sk-test" {
|
||||
t.Fatalf("expected bearer authorization header, got %q", got)
|
||||
}
|
||||
if got := req.Header.Get("anthropic-version"); got != "" {
|
||||
t.Fatalf("expected no anthropic-version header for DashScope, got %q", got)
|
||||
}
|
||||
}
|
||||
|
||||
111
internal/ai/service/service_volcengine_test.go
Normal file
111
internal/ai/service/service_volcengine_test.go
Normal file
@@ -0,0 +1,111 @@
|
||||
package aiservice
|
||||
|
||||
import (
|
||||
"reflect"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"GoNavi-Wails/internal/ai"
|
||||
)
|
||||
|
||||
func TestIsVolcengineCodingPlanProvider_MatchesCodingPlanBaseURL(t *testing.T) {
|
||||
if !isVolcengineCodingPlanProvider(ai.ProviderConfig{
|
||||
Type: "openai",
|
||||
BaseURL: "https://ark.cn-beijing.volces.com/api/coding/v3",
|
||||
}) {
|
||||
t.Fatal("expected volcengine coding plan provider to be detected")
|
||||
}
|
||||
}
|
||||
|
||||
func TestFilterVolcengineCodingPlanModels_KeepsOnlySupportedFamilies(t *testing.T) {
|
||||
filtered := filterVolcengineCodingPlanModels([]string{
|
||||
"Auto",
|
||||
"qwen3-14b-20250429",
|
||||
"wan2-1-14b-t2v-250225",
|
||||
"Doubao-Seed-2.0-Code",
|
||||
"Doubao-Seed-2.0-pro",
|
||||
"Doubao-Seed-2.0-lite",
|
||||
"doubao-seed-code-32k-250615",
|
||||
"MiniMax-M2.5",
|
||||
"GLM-4.7",
|
||||
"DeepSeek-V3.2",
|
||||
"kimi-k2-turbo-preview",
|
||||
})
|
||||
|
||||
expected := []string{
|
||||
"Auto",
|
||||
"Doubao-Seed-2.0-Code",
|
||||
"Doubao-Seed-2.0-pro",
|
||||
"Doubao-Seed-2.0-lite",
|
||||
"doubao-seed-code-32k-250615",
|
||||
"MiniMax-M2.5",
|
||||
"GLM-4.7",
|
||||
"DeepSeek-V3.2",
|
||||
"kimi-k2-turbo-preview",
|
||||
}
|
||||
if !reflect.DeepEqual(filtered, expected) {
|
||||
t.Fatalf("expected filtered models %v, got %v", expected, filtered)
|
||||
}
|
||||
}
|
||||
|
||||
func TestFilterVolcengineCodingPlanModels_DoesNotBroadlyMatchAutoKeyword(t *testing.T) {
|
||||
filtered := filterVolcengineCodingPlanModels([]string{
|
||||
"Auto",
|
||||
"automatic-router-preview",
|
||||
})
|
||||
|
||||
expected := []string{"Auto"}
|
||||
if !reflect.DeepEqual(filtered, expected) {
|
||||
t.Fatalf("expected only exact Auto model to remain, got %v", filtered)
|
||||
}
|
||||
}
|
||||
|
||||
func TestFilterFetchedModelsForProvider_DoesNotFilterVolcengineArk(t *testing.T) {
|
||||
rawModels := []string{
|
||||
"qwen3-14b-20250429",
|
||||
"wan2-1-14b-t2v-250225",
|
||||
}
|
||||
|
||||
filtered, err := filterFetchedModelsForProvider(ai.ProviderConfig{
|
||||
Type: "openai",
|
||||
BaseURL: "https://ark.cn-beijing.volces.com/api/v3",
|
||||
}, rawModels)
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
if !reflect.DeepEqual(filtered, rawModels) {
|
||||
t.Fatalf("expected ark models to stay untouched, got %v", filtered)
|
||||
}
|
||||
}
|
||||
|
||||
func TestAIListModels_ReturnsFailureWhenVolcengineCodingPlanModelsAreFilteredEmpty(t *testing.T) {
|
||||
originalFetchModelsFunc := fetchModelsFunc
|
||||
fetchModelsFunc = func(config ai.ProviderConfig) ([]string, error) {
|
||||
return []string{
|
||||
"qwen3-14b-20250429",
|
||||
"wan2-1-14b-t2v-250225",
|
||||
}, nil
|
||||
}
|
||||
defer func() {
|
||||
fetchModelsFunc = originalFetchModelsFunc
|
||||
}()
|
||||
|
||||
service := NewService()
|
||||
service.providers = []ai.ProviderConfig{
|
||||
{
|
||||
ID: "provider-coding",
|
||||
Type: "openai",
|
||||
BaseURL: "https://ark.cn-beijing.volces.com/api/coding/v3",
|
||||
},
|
||||
}
|
||||
service.activeProvider = "provider-coding"
|
||||
|
||||
result := service.AIListModels()
|
||||
if result["success"] != false {
|
||||
t.Fatalf("expected AIListModels to fail, got %#v", result)
|
||||
}
|
||||
errorMessage, _ := result["error"].(string)
|
||||
if !strings.Contains(errorMessage, "当前接口未返回可用的火山 Coding Plan 模型") {
|
||||
t.Fatalf("expected specific coding plan error, got %q", errorMessage)
|
||||
}
|
||||
}
|
||||
@@ -25,7 +25,7 @@ type Tool struct {
|
||||
|
||||
// Message 表示一条对话消息
|
||||
type Message struct {
|
||||
Role string `json:"role"` // "system" | "user" | "assistant" | "tool"
|
||||
Role string `json:"role"` // "system" | "user" | "assistant" | "tool"
|
||||
Content string `json:"content"`
|
||||
Images []string `json:"images,omitempty"` // base64 encoded images with data:image/png;base64,... prefix
|
||||
ToolCallID string `json:"tool_call_id,omitempty"` // 当 role 为 "tool" 时必须传递
|
||||
@@ -66,13 +66,13 @@ type StreamChunk struct {
|
||||
// ProviderConfig AI Provider 配置
|
||||
type ProviderConfig struct {
|
||||
ID string `json:"id"`
|
||||
Type string `json:"type"` // openai | anthropic | gemini | custom
|
||||
Type string `json:"type"` // openai | anthropic | gemini | custom
|
||||
Name string `json:"name"`
|
||||
APIKey string `json:"apiKey"`
|
||||
BaseURL string `json:"baseUrl"`
|
||||
Model string `json:"model"`
|
||||
Models []string `json:"models,omitempty"`
|
||||
APIFormat string `json:"apiFormat,omitempty"` // custom 专用: openai | anthropic | gemini
|
||||
APIFormat string `json:"apiFormat,omitempty"` // custom 专用: openai | anthropic | gemini | claude-cli
|
||||
Headers map[string]string `json:"headers,omitempty"`
|
||||
MaxTokens int `json:"maxTokens"`
|
||||
Temperature float64 `json:"temperature"`
|
||||
|
||||
@@ -64,9 +64,14 @@ func NewApp() *App {
|
||||
}
|
||||
}
|
||||
|
||||
// Startup is called when the app starts. The context is saved
|
||||
// so we can call the runtime methods
|
||||
func (a *App) Startup(ctx context.Context) {
|
||||
// InitializeLifecycle attaches runtime context without exposing lifecycle internals to Wails bindings.
|
||||
func InitializeLifecycle(a *App, ctx context.Context) {
|
||||
a.startup(ctx)
|
||||
}
|
||||
|
||||
// startup is called when the app starts. The context is saved
|
||||
// so we can call the runtime methods.
|
||||
func (a *App) startup(ctx context.Context) {
|
||||
a.ctx = ctx
|
||||
a.startedAt = time.Now()
|
||||
logger.Init()
|
||||
@@ -603,7 +608,7 @@ func (a *App) connectDatabaseWithStartupRetry(rawConfig connection.ConnectionCon
|
||||
|
||||
if err := dbInst.Connect(connectConfig); err == nil {
|
||||
if attempt > 1 {
|
||||
logger.Warnf("数据库连接在启动保护重试后成功:%s 缓存Key=%s 尝试=%d/%d", formatConnSummary(effectiveConfig), cacheKey, attempt, startupConnectRetryAttempts)
|
||||
logger.Warnf("数据库连接在重试后成功:%s 缓存Key=%s 尝试=%d/%d", formatConnSummary(effectiveConfig), cacheKey, attempt, startupConnectRetryAttempts)
|
||||
}
|
||||
return dbInst, effectiveConfig, nil
|
||||
} else {
|
||||
@@ -611,10 +616,10 @@ func (a *App) connectDatabaseWithStartupRetry(rawConfig connection.ConnectionCon
|
||||
wrapped := wrapConnectError(effectiveConfig, err)
|
||||
lastErr = wrapped
|
||||
logger.Error(wrapped, "建立数据库连接失败:%s 缓存Key=%s", formatConnSummary(effectiveConfig), cacheKey)
|
||||
if !a.shouldRetryStartupConnect(err, attempt) {
|
||||
if !a.shouldRetryConnect(err, attempt) {
|
||||
return nil, effectiveConfig, wrapped
|
||||
}
|
||||
logger.Warnf("检测到启动期瞬时网络失败,准备重试连接:%s 缓存Key=%s 尝试=%d/%d 延迟=%s 原因=%s",
|
||||
logger.Warnf("检测到瞬时网络失败,准备重试连接:%s 缓存Key=%s 尝试=%d/%d 延迟=%s 原因=%s",
|
||||
formatConnSummary(effectiveConfig), cacheKey, attempt, startupConnectRetryAttempts, startupConnectRetryDelay, normalizeErrorMessage(err))
|
||||
time.Sleep(startupConnectRetryDelay)
|
||||
}
|
||||
@@ -645,18 +650,21 @@ func (a *App) startupPhaseLabel() string {
|
||||
return fmt.Sprintf("稳定期(age=%s)", age)
|
||||
}
|
||||
|
||||
func (a *App) shouldRetryStartupConnect(err error, attempt int) bool {
|
||||
func (a *App) shouldRetryConnect(err error, attempt int) bool {
|
||||
if attempt >= startupConnectRetryAttempts {
|
||||
return false
|
||||
}
|
||||
if a == nil || a.startedAt.IsZero() {
|
||||
if !isTransientStartupConnectError(err) {
|
||||
return false
|
||||
}
|
||||
age := time.Since(a.startedAt)
|
||||
if age < 0 || age > startupConnectRetryWindow {
|
||||
return false
|
||||
if a != nil && !a.startedAt.IsZero() {
|
||||
age := time.Since(a.startedAt)
|
||||
if age >= 0 && age <= startupConnectRetryWindow {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return isTransientStartupConnectError(err)
|
||||
// Outside startup window, still grant one retry for transient network glitches.
|
||||
return attempt == 1
|
||||
}
|
||||
|
||||
func isTransientStartupConnectError(err error) bool {
|
||||
@@ -700,8 +708,8 @@ func (a *App) CancelQuery(queryID string) connection.QueryResult {
|
||||
return connection.QueryResult{Success: false, Message: "查询不存在或已完成"}
|
||||
}
|
||||
|
||||
// CleanupStaleQueries removes queries older than maxAge
|
||||
func (a *App) CleanupStaleQueries(maxAge time.Duration) {
|
||||
// cleanupStaleQueries removes queries older than maxAge.
|
||||
func (a *App) cleanupStaleQueries(maxAge time.Duration) {
|
||||
a.queryMu.Lock()
|
||||
defer a.queryMu.Unlock()
|
||||
|
||||
|
||||
@@ -2,11 +2,14 @@ package app
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"os"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"GoNavi-Wails/internal/connection"
|
||||
"GoNavi-Wails/internal/db"
|
||||
"GoNavi-Wails/internal/logger"
|
||||
)
|
||||
|
||||
type fakeStartupRetryDB struct {
|
||||
@@ -106,7 +109,7 @@ func TestConnectDatabaseWithStartupRetry_RetriesTransientFailureAndReappliesGlob
|
||||
}
|
||||
}
|
||||
|
||||
func TestConnectDatabaseWithStartupRetry_DoesNotRetryOutsideStartupWindow(t *testing.T) {
|
||||
func TestConnectDatabaseWithStartupRetry_RetriesOnceOutsideStartupWindow(t *testing.T) {
|
||||
originalNewDatabaseFunc := newDatabaseFunc
|
||||
originalResolveDialConfigWithProxyFunc := resolveDialConfigWithProxyFunc
|
||||
defer func() {
|
||||
@@ -130,12 +133,165 @@ func TestConnectDatabaseWithStartupRetry_DoesNotRetryOutsideStartupWindow(t *tes
|
||||
a := &App{startedAt: time.Now().Add(-startupConnectRetryWindow - time.Second)}
|
||||
rawConfig := connection.ConnectionConfig{Type: "postgres", Host: "10.1.131.86", Port: 5432, User: "postgres"}
|
||||
|
||||
_, _, err := a.connectDatabaseWithStartupRetry(rawConfig)
|
||||
if err == nil {
|
||||
t.Fatal("expected error, got nil")
|
||||
}
|
||||
if connectCalls != 2 {
|
||||
t.Fatalf("expected 2 connect attempts outside startup window, got %d", connectCalls)
|
||||
}
|
||||
}
|
||||
|
||||
func TestConnectDatabaseWithStartupRetry_DoesNotRetryOutsideStartupWindowForNonTransientError(t *testing.T) {
|
||||
originalNewDatabaseFunc := newDatabaseFunc
|
||||
originalResolveDialConfigWithProxyFunc := resolveDialConfigWithProxyFunc
|
||||
defer func() {
|
||||
newDatabaseFunc = originalNewDatabaseFunc
|
||||
resolveDialConfigWithProxyFunc = originalResolveDialConfigWithProxyFunc
|
||||
}()
|
||||
|
||||
connectCalls := 0
|
||||
newDatabaseFunc = func(dbType string) (db.Database, error) {
|
||||
return &fakeStartupRetryDB{
|
||||
connect: func(config connection.ConnectionConfig) error {
|
||||
connectCalls++
|
||||
return errors.New("pq: password authentication failed")
|
||||
},
|
||||
}, nil
|
||||
}
|
||||
resolveDialConfigWithProxyFunc = func(raw connection.ConnectionConfig) (connection.ConnectionConfig, error) {
|
||||
return raw, nil
|
||||
}
|
||||
|
||||
a := &App{startedAt: time.Now().Add(-startupConnectRetryWindow - time.Second)}
|
||||
rawConfig := connection.ConnectionConfig{Type: "postgres", Host: "10.1.131.86", Port: 5432, User: "postgres"}
|
||||
|
||||
_, _, err := a.connectDatabaseWithStartupRetry(rawConfig)
|
||||
if err == nil {
|
||||
t.Fatal("expected error, got nil")
|
||||
}
|
||||
if connectCalls != 1 {
|
||||
t.Fatalf("expected 1 connect attempt outside startup window, got %d", connectCalls)
|
||||
t.Fatalf("expected 1 connect attempt outside startup window for non-transient error, got %d", connectCalls)
|
||||
}
|
||||
}
|
||||
|
||||
func TestConnectDatabaseWithStartupRetry_LogsRetryHintOutsideStartupWindow(t *testing.T) {
|
||||
originalNewDatabaseFunc := newDatabaseFunc
|
||||
originalResolveDialConfigWithProxyFunc := resolveDialConfigWithProxyFunc
|
||||
defer func() {
|
||||
newDatabaseFunc = originalNewDatabaseFunc
|
||||
resolveDialConfigWithProxyFunc = originalResolveDialConfigWithProxyFunc
|
||||
}()
|
||||
|
||||
logPath := logger.Path()
|
||||
beforeSize := int64(0)
|
||||
if fi, err := os.Stat(logPath); err == nil {
|
||||
beforeSize = fi.Size()
|
||||
}
|
||||
|
||||
connectCalls := 0
|
||||
newDatabaseFunc = func(dbType string) (db.Database, error) {
|
||||
return &fakeStartupRetryDB{
|
||||
connect: func(config connection.ConnectionConfig) error {
|
||||
connectCalls++
|
||||
if connectCalls == 1 {
|
||||
return errors.New("dial tcp 10.1.131.86:5432: connect: no route to host")
|
||||
}
|
||||
return nil
|
||||
},
|
||||
}, nil
|
||||
}
|
||||
resolveDialConfigWithProxyFunc = func(raw connection.ConnectionConfig) (connection.ConnectionConfig, error) {
|
||||
return raw, nil
|
||||
}
|
||||
|
||||
a := &App{startedAt: time.Now().Add(-startupConnectRetryWindow - time.Second)}
|
||||
rawConfig := connection.ConnectionConfig{Type: "postgres", Host: "10.1.131.86", Port: 5432, User: "postgres"}
|
||||
|
||||
_, _, err := a.connectDatabaseWithStartupRetry(rawConfig)
|
||||
if err != nil {
|
||||
t.Fatalf("expected success after retry, got error: %v", err)
|
||||
}
|
||||
if connectCalls != 2 {
|
||||
t.Fatalf("expected 2 connect attempts, got %d", connectCalls)
|
||||
}
|
||||
|
||||
logContent, readErr := os.ReadFile(logPath)
|
||||
if readErr != nil {
|
||||
t.Fatalf("read log failed: %v", readErr)
|
||||
}
|
||||
if int64(len(logContent)) < beforeSize {
|
||||
t.Fatalf("expected log file to grow, before=%d after=%d", beforeSize, len(logContent))
|
||||
}
|
||||
appended := string(logContent[beforeSize:])
|
||||
if !strings.Contains(appended, "检测到瞬时网络失败,准备重试连接") {
|
||||
t.Fatalf("expected retry hint log in appended segment, got: %s", appended)
|
||||
}
|
||||
}
|
||||
|
||||
func TestConnectDatabaseWithStartupRetry_OutsideStartupWindowTransientFailureStopsAfterOneRetry(t *testing.T) {
|
||||
originalNewDatabaseFunc := newDatabaseFunc
|
||||
originalResolveDialConfigWithProxyFunc := resolveDialConfigWithProxyFunc
|
||||
defer func() {
|
||||
newDatabaseFunc = originalNewDatabaseFunc
|
||||
resolveDialConfigWithProxyFunc = originalResolveDialConfigWithProxyFunc
|
||||
}()
|
||||
|
||||
connectCalls := 0
|
||||
newDatabaseFunc = func(dbType string) (db.Database, error) {
|
||||
return &fakeStartupRetryDB{
|
||||
connect: func(config connection.ConnectionConfig) error {
|
||||
connectCalls++
|
||||
return errors.New("dial tcp 10.1.131.86:5432: connect: no route to host")
|
||||
},
|
||||
}, nil
|
||||
}
|
||||
resolveDialConfigWithProxyFunc = func(raw connection.ConnectionConfig) (connection.ConnectionConfig, error) {
|
||||
return raw, nil
|
||||
}
|
||||
|
||||
a := &App{startedAt: time.Now().Add(-startupConnectRetryWindow - time.Second)}
|
||||
rawConfig := connection.ConnectionConfig{Type: "postgres", Host: "10.1.131.86", Port: 5432, User: "postgres"}
|
||||
|
||||
_, _, err := a.connectDatabaseWithStartupRetry(rawConfig)
|
||||
if err == nil {
|
||||
t.Fatal("expected error, got nil")
|
||||
}
|
||||
if connectCalls != 2 {
|
||||
t.Fatalf("expected 2 connect attempts outside startup window for transient error, got %d", connectCalls)
|
||||
}
|
||||
}
|
||||
|
||||
func TestConnectDatabaseWithStartupRetry_StartupWindowTransientFailureUsesFullRetryBudget(t *testing.T) {
|
||||
originalNewDatabaseFunc := newDatabaseFunc
|
||||
originalResolveDialConfigWithProxyFunc := resolveDialConfigWithProxyFunc
|
||||
defer func() {
|
||||
newDatabaseFunc = originalNewDatabaseFunc
|
||||
resolveDialConfigWithProxyFunc = originalResolveDialConfigWithProxyFunc
|
||||
}()
|
||||
|
||||
connectCalls := 0
|
||||
newDatabaseFunc = func(dbType string) (db.Database, error) {
|
||||
return &fakeStartupRetryDB{
|
||||
connect: func(config connection.ConnectionConfig) error {
|
||||
connectCalls++
|
||||
return errors.New("dial tcp 10.1.131.86:5432: connect: no route to host")
|
||||
},
|
||||
}, nil
|
||||
}
|
||||
resolveDialConfigWithProxyFunc = func(raw connection.ConnectionConfig) (connection.ConnectionConfig, error) {
|
||||
return raw, nil
|
||||
}
|
||||
|
||||
a := &App{startedAt: time.Now()}
|
||||
rawConfig := connection.ConnectionConfig{Type: "postgres", Host: "10.1.131.86", Port: 5432, User: "postgres"}
|
||||
|
||||
_, _, err := a.connectDatabaseWithStartupRetry(rawConfig)
|
||||
if err == nil {
|
||||
t.Fatal("expected error, got nil")
|
||||
}
|
||||
if connectCalls != startupConnectRetryAttempts {
|
||||
t.Fatalf("expected %d connect attempts in startup window, got %d", startupConnectRetryAttempts, connectCalls)
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -88,7 +88,7 @@ func TestCleanupStaleQueries(t *testing.T) {
|
||||
app.queryMu.Unlock()
|
||||
|
||||
// Cleanup queries older than 1 hour
|
||||
app.CleanupStaleQueries(1 * time.Hour)
|
||||
app.cleanupStaleQueries(1 * time.Hour)
|
||||
|
||||
// Verify stale query was removed
|
||||
app.queryMu.Lock()
|
||||
@@ -110,7 +110,7 @@ func TestCleanupStaleQueries(t *testing.T) {
|
||||
defer cancel2()
|
||||
|
||||
// Cleanup queries older than 1 hour
|
||||
app.CleanupStaleQueries(1 * time.Hour)
|
||||
app.cleanupStaleQueries(1 * time.Hour)
|
||||
|
||||
// Verify fresh query still exists
|
||||
app.queryMu.Lock()
|
||||
|
||||
4
main.go
4
main.go
@@ -34,8 +34,8 @@ func main() {
|
||||
},
|
||||
BackgroundColour: &options.RGBA{R: 0, G: 0, B: 0, A: 0},
|
||||
OnStartup: func(ctx context.Context) {
|
||||
application.Startup(ctx)
|
||||
aiService.Startup(ctx)
|
||||
app.InitializeLifecycle(application, ctx)
|
||||
aiservice.InitializeLifecycle(aiService, ctx)
|
||||
},
|
||||
OnShutdown: application.Shutdown,
|
||||
Bind: []interface{}{
|
||||
|
||||
Reference in New Issue
Block a user