mirror of
https://github.com/Syngnat/GoNavi.git
synced 2026-05-07 04:22:48 +08:00
🐛 fix(ai/provider): 修复 Kimi 与 MiniMax 供应商兼容路由
- 调整 Kimi 预设为 Anthropic 兼容入口并修正 Moonshot 域名回显 - 修复 Anthropic 请求地址归一化,确保聊天请求正确落到 /v1/messages - 修正 Kimi 模型列表与测试连接路由,固定使用 Moonshot /v1/models - 修正 MiniMax 默认模型与兼容模型集合,避免请求不存在的 /anthropic/v1/models - 为 MiniMax 健康检查改用最小化 messages 请求,并兼容旧模型名配置 - 补充 Kimi 与 MiniMax 供应商回归测试,更新需求追踪文档
This commit is contained in:
@@ -1 +1 @@
|
||||
dcb87159cf0f1f6f750d1c4870911d3f
|
||||
6ba85e4f456d2c0d230cab198c7dc02b
|
||||
@@ -483,3 +483,15 @@
|
||||
@keyframes ai-spin-anim {
|
||||
to { transform: rotate(360deg); }
|
||||
}
|
||||
|
||||
/* 面板/弹窗内部 toast 定位覆盖:从 fixed(视口顶部)改为 absolute(容器内部顶部) */
|
||||
.ai-chat-panel .ant-message,
|
||||
.ai-settings-body .ant-message {
|
||||
position: absolute !important;
|
||||
top: 16px !important;
|
||||
left: 50% !important;
|
||||
transform: translateX(-50%) !important;
|
||||
right: auto !important;
|
||||
width: max-content;
|
||||
z-index: 100;
|
||||
}
|
||||
|
||||
@@ -1,12 +1,12 @@
|
||||
import React, { useState, useRef, useEffect, useCallback, useMemo } from 'react';
|
||||
import { createPortal } from 'react-dom';
|
||||
import { useStore } from '../store';
|
||||
import { useStore, loadAISessionsFromBackend, loadAISessionFromBackend } from '../store';
|
||||
import { EventsOn, EventsOff } from '../../wailsjs/runtime';
|
||||
import { DBGetDatabases, DBGetTables } from '../../wailsjs/go/app/App';
|
||||
import type { OverlayWorkbenchTheme } from '../utils/overlayWorkbenchTheme';
|
||||
import { AIChatMessage, AIToolCall } from '../types';
|
||||
import { DownOutlined } from '@ant-design/icons';
|
||||
import { message } from 'antd';
|
||||
import { message as antdMessage } from 'antd';
|
||||
import './AIChatPanel.css';
|
||||
|
||||
import { AIChatHeader } from './ai/AIChatHeader';
|
||||
@@ -224,6 +224,9 @@ export const AIChatPanel: React.FC<AIChatPanelProps> = ({
|
||||
const panelRef = useRef<HTMLDivElement>(null); // 面板 DOM ref,用于拖拽时直接操作宽度
|
||||
const dragWidthRef = useRef(0); // 拖拽过程中的实时宽度(不触发 React 重渲染)
|
||||
|
||||
// 面板内部 toast 通知(不在屏幕顶部,而在面板容器内显示)
|
||||
const [messageApi, messageContextHolder] = antdMessage.useMessage({ getContainer: () => panelRef.current || document.body });
|
||||
|
||||
const aiChatHistory = useStore(state => state.aiChatHistory);
|
||||
const aiActiveSessionId = useStore(state => state.aiActiveSessionId);
|
||||
const createNewAISession = useStore(state => state.createNewAISession);
|
||||
@@ -280,6 +283,21 @@ export const AIChatPanel: React.FC<AIChatPanelProps> = ({
|
||||
}, [aiActiveSessionId, createNewAISession]);
|
||||
|
||||
const sid = aiActiveSessionId || 'session-fallback';
|
||||
|
||||
// 面板首次可见时从后端加载会话列表
|
||||
const sessionsLoadedRef = useRef(false);
|
||||
useEffect(() => {
|
||||
if (!aiPanelVisible || sessionsLoadedRef.current) return;
|
||||
sessionsLoadedRef.current = true;
|
||||
loadAISessionsFromBackend();
|
||||
}, [aiPanelVisible]);
|
||||
|
||||
// 切换会话时按需从后端加载消息
|
||||
useEffect(() => {
|
||||
if (sid && sid !== 'session-fallback') {
|
||||
loadAISessionFromBackend(sid);
|
||||
}
|
||||
}, [sid]);
|
||||
const messages = aiChatHistory[sid] || [];
|
||||
|
||||
const getConnectionName = useCallback(() => {
|
||||
@@ -314,6 +332,17 @@ export const AIChatPanel: React.FC<AIChatPanelProps> = ({
|
||||
|
||||
useEffect(() => { loadActiveProvider(); }, [loadActiveProvider]);
|
||||
|
||||
// 监听供应商配置变更(来自设置面板的删除/新增/切换操作),重新加载 active provider 并清空已缓存的模型
|
||||
useEffect(() => {
|
||||
const handler = () => {
|
||||
setDynamicModels([]);
|
||||
activeProviderIdRef.current = null;
|
||||
loadActiveProvider();
|
||||
};
|
||||
window.addEventListener('gonavi:ai:provider-changed', handler);
|
||||
return () => window.removeEventListener('gonavi:ai:provider-changed', handler);
|
||||
}, [loadActiveProvider]);
|
||||
|
||||
const handleModelChange = async (val: string) => {
|
||||
if (!activeProvider) return;
|
||||
try {
|
||||
@@ -331,7 +360,12 @@ export const AIChatPanel: React.FC<AIChatPanelProps> = ({
|
||||
setDynamicModels([]);
|
||||
activeProviderIdRef.current = activeProvider.id;
|
||||
}
|
||||
}, [activeProvider?.id]);
|
||||
// 供应商被删除后 activeProvider 变为 null,此时也必须清空残留模型
|
||||
if (!activeProvider) {
|
||||
setDynamicModels([]);
|
||||
activeProviderIdRef.current = null;
|
||||
}
|
||||
}, [activeProvider?.id, activeProvider]);
|
||||
|
||||
|
||||
// dynamicModels 仅在内存中使用,不再写回供应商配置,避免污染静态 models 列表
|
||||
@@ -346,11 +380,11 @@ export const AIChatPanel: React.FC<AIChatPanelProps> = ({
|
||||
const sortedModels = [...result.models].sort((a, b) => a.localeCompare(b));
|
||||
setDynamicModels(sortedModels);
|
||||
} else if (result && !result.success) {
|
||||
message.warning(result.error || '获取模型列表失败,可手动输入模型名称');
|
||||
messageApi.warning(result.error || '获取模型列表失败,可手动输入模型名称');
|
||||
}
|
||||
} catch (e: any) {
|
||||
console.warn('Failed to fetch models', e);
|
||||
message.warning('获取模型列表失败: ' + (e?.message || '未知错误'));
|
||||
messageApi.warning('获取模型列表失败: ' + (e?.message || '未知错误'));
|
||||
} finally {
|
||||
setLoadingModels(false);
|
||||
}
|
||||
@@ -993,6 +1027,17 @@ SELECT * FROM users WHERE status = 1;
|
||||
const handleSend = useCallback(async () => {
|
||||
const text = input.trim();
|
||||
if ((!text && draftImages.length === 0) || sending) return;
|
||||
|
||||
// 前置校验:必须配置供应商且选择模型后才能发送
|
||||
if (!activeProvider) {
|
||||
messageApi.warning('请先在 AI 设置中配置供应商');
|
||||
return;
|
||||
}
|
||||
if (!activeProvider.model || !activeProvider.model.trim()) {
|
||||
messageApi.warning('请先选择模型 ID(点击工具栏的模型下拉框选择)');
|
||||
return;
|
||||
}
|
||||
|
||||
toolCallRoundRef.current = 0; // 重置工具调用轮次计数
|
||||
nudgeCountRef.current = 0; // 重置催促计数
|
||||
|
||||
@@ -1083,7 +1128,7 @@ SELECT * FROM users WHERE status = 1;
|
||||
addAIChatMessage(sid, { id: genId(), role: 'assistant', content: `❌ 发送失败: ${cleanE2}`, rawError: cleanE2 !== rawE2 ? rawE2 : undefined, timestamp: Date.now() });
|
||||
setSending(false);
|
||||
}
|
||||
}, [input, draftImages, sending, messages, addAIChatMessage, sid]);
|
||||
}, [input, draftImages, sending, messages, addAIChatMessage, sid, activeProvider]);
|
||||
|
||||
const handleKeyDown = useCallback((e: React.KeyboardEvent) => {
|
||||
if (e.key === 'Enter' && !e.shiftKey) {
|
||||
@@ -1213,6 +1258,7 @@ SELECT * FROM users WHERE status = 1;
|
||||
|
||||
return (
|
||||
<div ref={panelRef} className="ai-chat-panel" style={{ width: panelWidth, background: bgColor || 'transparent', color: textColor, borderLeft: overlayTheme.shellBorder, position: 'relative' }}>
|
||||
{messageContextHolder}
|
||||
<div className={`ai-resize-handle${isResizing ? ' active' : ''}`} onMouseDown={handleResizeStart} />
|
||||
|
||||
{isResizing && panelRect.current && createPortal(
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
import React, { useState, useEffect, useCallback } from 'react';
|
||||
import { Modal, Button, Input, Select, Form, message, Tooltip, Tabs, Space, Popconfirm, Slider } from 'antd';
|
||||
import React, { useState, useEffect, useCallback, useRef } from 'react';
|
||||
import { Modal, Button, Input, Select, Form, message as antdMessage, Tooltip, Tabs, Space, Popconfirm, Slider } from 'antd';
|
||||
import { PlusOutlined, DeleteOutlined, EditOutlined, CheckOutlined, ApiOutlined, SafetyCertificateOutlined, RobotOutlined, ThunderboltOutlined, CloudOutlined, ExperimentOutlined, KeyOutlined, LinkOutlined, AppstoreOutlined, ToolOutlined } from '@ant-design/icons';
|
||||
import type { AIProviderConfig, AIProviderType, AISafetyLevel, AIContextLevel } from '../types';
|
||||
|
||||
@@ -26,21 +26,40 @@ interface ProviderPreset {
|
||||
}
|
||||
|
||||
const PROVIDER_PRESETS: ProviderPreset[] = [
|
||||
{ key: 'openai', label: 'OpenAI', icon: <ThunderboltOutlined />, desc: 'GPT-5.4 / 5.3 系列', color: '#10b981', backendType: 'openai', defaultBaseUrl: 'https://api.openai.com/v1', defaultModel: 'gpt-5.4', models: ['gpt-5.4', 'gpt-5.4-mini', 'gpt-5.4-nano', 'gpt-5.3'] },
|
||||
{ key: 'deepseek', label: 'DeepSeek', icon: <ThunderboltOutlined />, desc: 'DeepSeek-V4 / R1', color: '#3b82f6', backendType: 'openai', defaultBaseUrl: 'https://api.deepseek.com/v1', defaultModel: 'deepseek-chat', models: ['deepseek-chat', 'deepseek-reasoner'] },
|
||||
{ key: 'qwen', label: '通义千问', icon: <CloudOutlined />, desc: 'Qwen3.5 / Qwen3 系列', color: '#6366f1', backendType: 'openai', defaultBaseUrl: 'https://dashscope.aliyuncs.com/compatible-mode/v1', defaultModel: 'qwen3.5-max', models: ['qwen3.5-max', 'qwen3-plus', 'qwen3-turbo'] },
|
||||
{ key: 'zhipu', label: '智谱 GLM', icon: <ExperimentOutlined />, desc: 'GLM-5 / GLM-5-Turbo', color: '#0ea5e9', backendType: 'openai', defaultBaseUrl: 'https://open.bigmodel.cn/api/paas/v4', defaultModel: 'glm-5', models: ['glm-5', 'glm-5-turbo', 'glm-4.7-flash'] },
|
||||
{ key: 'moonshot', label: 'Kimi', icon: <ExperimentOutlined />, desc: 'Kimi K2.5 系列', color: '#0d9488', backendType: 'openai', defaultBaseUrl: 'https://api.moonshot.cn/v1', defaultModel: 'kimi-k2.5', models: ['kimi-k2.5', 'kimi-k2-turbo-preview', 'kimi-k2-thinking'] },
|
||||
{ key: 'anthropic', label: 'Claude', icon: <ExperimentOutlined />, desc: 'Claude Opus/Sonnet 4.6', color: '#d97706', backendType: 'anthropic', defaultBaseUrl: 'https://api.anthropic.com', defaultModel: 'claude-sonnet-4-6', models: ['claude-opus-4-6', 'claude-sonnet-4-6'] },
|
||||
{ key: 'gemini', label: 'Gemini', icon: <CloudOutlined />, desc: 'Gemini 3.1 / 2.5 系列', color: '#059669', backendType: 'gemini', defaultBaseUrl: 'https://generativelanguage.googleapis.com', defaultModel: 'gemini-2.5-flash', models: ['gemini-3.1-pro', 'gemini-2.5-flash', 'gemini-2.5-pro'] },
|
||||
{ key: 'openai', label: 'OpenAI', icon: <ApiOutlined />, desc: 'GPT-5.4 / 5.3 系列', color: '#10b981', backendType: 'openai', defaultBaseUrl: 'https://api.openai.com/v1', defaultModel: 'gpt-4o', models: [] },
|
||||
{ key: 'deepseek', label: 'DeepSeek', icon: <ThunderboltOutlined />, desc: 'DeepSeek-V4 / R1', color: '#3b82f6', backendType: 'openai', defaultBaseUrl: 'https://api.deepseek.com/v1', defaultModel: 'deepseek-chat', models: [] },
|
||||
{ key: 'qwen', label: '通义千问', icon: <CloudOutlined />, desc: 'Qwen3.5 / Qwen3 系列', color: '#6366f1', backendType: 'openai', defaultBaseUrl: 'https://dashscope.aliyuncs.com/compatible-mode/v1', defaultModel: 'qwen-max', models: [] },
|
||||
{ key: 'zhipu', label: '智谱 GLM', icon: <ExperimentOutlined />, desc: 'GLM-5 / GLM-5-Turbo', color: '#0ea5e9', backendType: 'openai', defaultBaseUrl: 'https://open.bigmodel.cn/api/paas/v4', defaultModel: 'glm-4', models: [] },
|
||||
{ key: 'moonshot', label: 'Kimi', icon: <ExperimentOutlined />, desc: 'Kimi K2.5 (Anthropic 兼容)', color: '#0d9488', backendType: 'anthropic', defaultBaseUrl: 'https://api.moonshot.cn/anthropic', defaultModel: 'moonshot-v1-8k', models: [] },
|
||||
{ key: 'anthropic', label: 'Claude', icon: <ExperimentOutlined />, desc: 'Claude Opus/Sonnet', color: '#d97706', backendType: 'anthropic', defaultBaseUrl: 'https://api.anthropic.com', defaultModel: 'claude-3-5-sonnet-20241022', models: [] },
|
||||
{ key: 'gemini', label: 'Gemini', icon: <CloudOutlined />, desc: 'Gemini 3.1 / 2.5 系列', color: '#059669', backendType: 'gemini', defaultBaseUrl: 'https://generativelanguage.googleapis.com', defaultModel: 'gemini-2.5-flash', models: [] },
|
||||
{ key: 'volcengine', label: '火山引擎', icon: <CloudOutlined />, desc: '火山方舟 / 豆包大模型', color: '#0ea5e9', backendType: 'openai', defaultBaseUrl: 'https://ark.cn-beijing.volces.com/api/v3', defaultModel: 'ep-xxxxxx', models: [] },
|
||||
{ key: 'minimax', label: 'MiniMax', icon: <ExperimentOutlined />, desc: 'abab6.5 / abab7 系列', color: '#e11d48', backendType: 'anthropic', defaultBaseUrl: 'https://api.minimaxi.com/anthropic', defaultModel: 'MiniMax-Text-01', models: ['MiniMax-Text-01', 'MiniMax-Text-01-vision', 'MiniMax-Text-01-search', 'MiniMax-Text-01-code', 'MiniMax-Text-01-web', 'MiniMax-Text-01-sql', 'MiniMax-Text-01-python', 'MiniMax-Text-01-math', 'MiniMax-Text-01-doc'] },
|
||||
{ key: 'minimax', label: 'MiniMax', icon: <ExperimentOutlined />, desc: 'M2.7 / M2.5 系列 (Anthropic 兼容)', color: '#e11d48', backendType: 'anthropic', defaultBaseUrl: 'https://api.minimaxi.com/anthropic', defaultModel: 'MiniMax-M2.7', models: ['MiniMax-M2.7', 'MiniMax-M2.7-highspeed', 'MiniMax-M2.5', 'MiniMax-M2.5-highspeed', 'MiniMax-M2.1', 'MiniMax-M2.1-highspeed', 'MiniMax-M2'] },
|
||||
{ key: 'ollama', label: 'Ollama', icon: <AppstoreOutlined />, desc: '本地部署开源模型', color: '#78716c', backendType: 'openai', defaultBaseUrl: 'http://localhost:11434/v1', defaultModel: 'llama3', models: [] },
|
||||
{ key: 'custom', label: '自定义', icon: <AppstoreOutlined />, desc: '自定义 API 端点', color: '#64748b', backendType: 'custom', defaultBaseUrl: '', defaultModel: '', models: [] },
|
||||
];
|
||||
|
||||
const findPreset = (key: string): ProviderPreset => PROVIDER_PRESETS.find(p => p.key === key) || PROVIDER_PRESETS[PROVIDER_PRESETS.length - 1];
|
||||
|
||||
const getProviderHostname = (raw?: string): string => {
|
||||
if (!raw) return '';
|
||||
try {
|
||||
return new URL(raw).hostname.toLowerCase();
|
||||
} catch {
|
||||
return '';
|
||||
}
|
||||
};
|
||||
|
||||
const matchProviderPreset = (provider: Pick<AIProviderConfig, 'type' | 'baseUrl'>): ProviderPreset => {
|
||||
const host = getProviderHostname(provider.baseUrl);
|
||||
if (host.endsWith('moonshot.cn')) {
|
||||
return findPreset('moonshot');
|
||||
}
|
||||
return PROVIDER_PRESETS.find(pr => pr.backendType === provider.type && host !== '' && host === getProviderHostname(pr.defaultBaseUrl))
|
||||
|| PROVIDER_PRESETS.find(pr => pr.backendType === provider.type)
|
||||
|| findPreset('custom');
|
||||
};
|
||||
|
||||
const SAFETY_OPTIONS: { label: string; value: AISafetyLevel; desc: string; color: string; icon: string }[] = [
|
||||
{ label: '只读模式', value: 'readonly', desc: 'AI 仅可执行 SELECT 等查询操作,最安全', color: '#22c55e', icon: '🔒' },
|
||||
{ label: '读写模式', value: 'readwrite', desc: 'AI 可执行 INSERT/UPDATE/DELETE,危险操作需二次确认', color: '#f59e0b', icon: '⚠️' },
|
||||
@@ -65,6 +84,10 @@ const AISettingsModal: React.FC<AISettingsModalProps> = ({ open, onClose, darkMo
|
||||
const [builtinPrompts, setBuiltinPrompts] = useState<Record<string, string>>({});
|
||||
const [activeSection, setActiveSection] = useState<'providers' | 'safety' | 'context' | 'prompts' | 'tools'>('providers');
|
||||
const [form] = Form.useForm();
|
||||
const modalBodyRef = useRef<HTMLDivElement>(null);
|
||||
|
||||
// Modal 内部 toast 通知
|
||||
const [messageApi, messageContextHolder] = antdMessage.useMessage({ getContainer: () => modalBodyRef.current || document.body });
|
||||
|
||||
// 主题色
|
||||
const cardBg = darkMode ? 'rgba(255,255,255,0.04)' : 'rgba(0,0,0,0.02)';
|
||||
@@ -108,31 +131,45 @@ const AISettingsModal: React.FC<AISettingsModalProps> = ({ open, onClose, darkMo
|
||||
const newProvider: AIProviderConfig = {
|
||||
id: '', type: preset.backendType, name: '', apiKey: '',
|
||||
baseUrl: preset.defaultBaseUrl, model: preset.defaultModel,
|
||||
maxTokens: 4096, temperature: 0.7,
|
||||
models: [], maxTokens: 4096, temperature: 0.7,
|
||||
};
|
||||
setEditingProvider({ ...newProvider, presetKey: 'openai' } as any);
|
||||
setIsEditing(true);
|
||||
setTestStatus('idle');
|
||||
form.resetFields();
|
||||
form.setFieldsValue({ ...newProvider, presetKey: 'openai', apiFormat: 'openai' });
|
||||
};
|
||||
|
||||
const handleEditProvider = (p: AIProviderConfig) => {
|
||||
// 尝试根据 baseUrl 和 type 推断 preset
|
||||
const matchedPreset = PROVIDER_PRESETS.find(pr => pr.backendType === p.type && p.baseUrl?.includes(new URL(pr.defaultBaseUrl || 'http://x').hostname))
|
||||
|| PROVIDER_PRESETS.find(pr => pr.backendType === p.type)
|
||||
|| findPreset('custom');
|
||||
const matchedPreset = matchProviderPreset(p);
|
||||
setEditingProvider(p);
|
||||
setIsEditing(true);
|
||||
setTestStatus('idle');
|
||||
form.setFieldsValue({ ...p, presetKey: matchedPreset.key, apiFormat: p.apiFormat || 'openai' });
|
||||
form.resetFields();
|
||||
form.setFieldsValue({ ...p, type: matchedPreset.backendType, models: p.models || [], presetKey: matchedPreset.key, apiFormat: p.apiFormat || 'openai' });
|
||||
};
|
||||
|
||||
const handleDeleteProvider = async (id: string) => {
|
||||
try {
|
||||
const Service = (window as any).go?.aiservice?.Service;
|
||||
const wasActive = id === activeProviderId;
|
||||
await Service?.AIDeleteProvider?.(id);
|
||||
void message.success('已删除'); void loadConfig();
|
||||
} catch (e: any) { void message.error(e?.message || '删除失败'); }
|
||||
await loadConfig();
|
||||
// 合并提示:删除的是当前激活的供应商时,附带自动切换信息
|
||||
if (wasActive) {
|
||||
const newProviders: any[] = await Service?.AIGetProviders?.() || [];
|
||||
if (newProviders.length > 0) {
|
||||
const newActiveName = newProviders[0]?.name || '下一个供应商';
|
||||
void messageApi.success(`已删除,自动切换到「${newActiveName}」`);
|
||||
} else {
|
||||
void messageApi.success('已删除');
|
||||
}
|
||||
} else {
|
||||
void messageApi.success('已删除');
|
||||
}
|
||||
window.dispatchEvent(new CustomEvent('gonavi:ai:provider-changed'));
|
||||
} catch (e: any) { void messageApi.error(e?.message || '删除失败'); }
|
||||
};
|
||||
|
||||
const handleSaveProvider = async () => {
|
||||
@@ -150,20 +187,24 @@ const AISettingsModal: React.FC<AISettingsModalProps> = ({ open, onClose, darkMo
|
||||
// 内置供应商自动使用 preset label 作为名称
|
||||
const finalName = isCustomLike ? (values.name || preset.label) : preset.label;
|
||||
|
||||
const finalBaseUrl = values.baseUrl || preset.defaultBaseUrl;
|
||||
|
||||
const payload = {
|
||||
...editingProvider,
|
||||
...values,
|
||||
name: finalName,
|
||||
model: finalModel,
|
||||
models: resolvedModels,
|
||||
baseUrl: finalBaseUrl,
|
||||
apiFormat: values.apiFormat || 'openai',
|
||||
};
|
||||
// 后端 AISaveProvider 统一处理新增和更新,返回 void,失败抛异常
|
||||
await Service?.AISaveProvider?.(payload);
|
||||
void message.success('已保存'); setIsEditing(false); setEditingProvider(null); void loadConfig();
|
||||
void messageApi.success('已保存'); setIsEditing(false); setEditingProvider(null); void loadConfig();
|
||||
window.dispatchEvent(new CustomEvent('gonavi:ai:provider-changed'));
|
||||
} catch (e: any) {
|
||||
if (e?.errorFields) { /* antd form validation error, ignore */ }
|
||||
else void message.error(e?.message || '保存失败');
|
||||
else void messageApi.error(e?.message || '保存失败');
|
||||
} finally { setLoading(false); }
|
||||
};
|
||||
|
||||
@@ -171,8 +212,9 @@ const AISettingsModal: React.FC<AISettingsModalProps> = ({ open, onClose, darkMo
|
||||
try {
|
||||
const Service = (window as any).go?.aiservice?.Service;
|
||||
await Service?.AISetActiveProvider?.(id);
|
||||
setActiveProviderId(id); void message.success('已切换');
|
||||
} catch (e: any) { void message.error(e?.message || '切换失败'); }
|
||||
setActiveProviderId(id); void messageApi.success('已切换');
|
||||
window.dispatchEvent(new CustomEvent('gonavi:ai:provider-changed'));
|
||||
} catch (e: any) { void messageApi.error(e?.message || '切换失败'); }
|
||||
};
|
||||
|
||||
const handleSafetyChange = async (level: AISafetyLevel) => {
|
||||
@@ -197,10 +239,12 @@ const AISettingsModal: React.FC<AISettingsModalProps> = ({ open, onClose, darkMo
|
||||
setLoading(true);
|
||||
setTestStatus('idle');
|
||||
const Service = (window as any).go?.aiservice?.Service;
|
||||
const res = await Service?.AITestProvider?.({ ...values, maxTokens: Number(values.maxTokens) || 4096, temperature: Number(values.temperature) ?? 0.7 });
|
||||
if (res?.success) { setTestStatus('success'); void message.success('连接成功'); }
|
||||
else { setTestStatus('error'); void message.error(`测试失败: ${res?.message || '未知错误'}`); }
|
||||
} catch (e: any) { setTestStatus('error'); void message.error(e?.message || '测试失败'); }
|
||||
const preset = findPreset(values.presetKey || 'openai');
|
||||
const finalBaseUrl = values.baseUrl || preset.defaultBaseUrl;
|
||||
const res = await Service?.AITestProvider?.({ ...values, baseUrl: finalBaseUrl, maxTokens: Number(values.maxTokens) || 4096, temperature: Number(values.temperature) ?? 0.7 });
|
||||
if (res?.success) { setTestStatus('success'); void messageApi.success('连接成功'); }
|
||||
else { setTestStatus('error'); void messageApi.error(`测试失败: ${res?.message || '未知错误'}`); }
|
||||
} catch (e: any) { setTestStatus('error'); void messageApi.error(e?.message || '测试失败'); }
|
||||
finally { setLoading(false); }
|
||||
};
|
||||
|
||||
@@ -238,9 +282,7 @@ const AISettingsModal: React.FC<AISettingsModalProps> = ({ open, onClose, darkMo
|
||||
</div>
|
||||
)}
|
||||
{providers.map(p => {
|
||||
const matchedPreset = PROVIDER_PRESETS.find(pr => pr.backendType === p.type && p.baseUrl?.includes(new URL(pr.defaultBaseUrl || 'http://x').hostname))
|
||||
|| PROVIDER_PRESETS.find(pr => pr.backendType === p.type)
|
||||
|| findPreset('custom');
|
||||
const matchedPreset = matchProviderPreset(p);
|
||||
const isActive = p.id === activeProviderId;
|
||||
return (
|
||||
<div key={p.id} onClick={() => handleSetActive(p.id)} style={{
|
||||
@@ -605,7 +647,8 @@ const AISettingsModal: React.FC<AISettingsModalProps> = ({ open, onClose, darkMo
|
||||
body: { paddingTop: 8, height: 620, overflow: 'hidden' },
|
||||
}}
|
||||
>
|
||||
<div style={{ display: 'grid', gridTemplateColumns: '180px minmax(0, 1fr)', gap: 16, padding: '12px 0', height: '100%', minHeight: 0, overflow: 'hidden', alignItems: 'stretch' }}>
|
||||
<div ref={modalBodyRef} className="ai-settings-body" style={{ display: 'grid', gridTemplateColumns: '180px minmax(0, 1fr)', gap: 16, padding: '12px 0', height: '100%', minHeight: 0, overflow: 'hidden', alignItems: 'stretch', position: 'relative' }}>
|
||||
{messageContextHolder}
|
||||
<div style={{ padding: '0 12px', height: 'fit-content' }}>
|
||||
<div style={{ marginBottom: 12, fontWeight: 600, color: overlayTheme.titleText }}>设置导航</div>
|
||||
<div style={{ display: 'grid', gap: 10 }}>
|
||||
|
||||
@@ -544,7 +544,28 @@ export const AIMessageBubble: React.FC<AIMessageBubbleProps> = React.memo(({ msg
|
||||
const [isCopied, setIsCopied] = useState(false);
|
||||
const isUser = msg.role === 'user';
|
||||
|
||||
const displayContent = msg.content;
|
||||
// 从 content 中提取 <think>...</think> 标签内容(部分模型如 MiniMax、DeepSeek 会以文本形式返回思考过程)
|
||||
const { displayContent, parsedThinking } = React.useMemo(() => {
|
||||
const content = msg.content || '';
|
||||
// 优先使用后端已结构化的 thinking 字段(如 Claude API 原生 thinking)
|
||||
if (msg.thinking) {
|
||||
return { displayContent: content, parsedThinking: msg.thinking };
|
||||
}
|
||||
// 尝试从 content 中提取 <think>...</think> 标签
|
||||
const thinkRegex = /<think>([\s\S]*?)(?:<\/think>|$)/g;
|
||||
let thinkParts: string[] = [];
|
||||
let cleanContent = content;
|
||||
let match;
|
||||
while ((match = thinkRegex.exec(content)) !== null) {
|
||||
thinkParts.push(match[1].trim());
|
||||
}
|
||||
if (thinkParts.length > 0) {
|
||||
// 移除所有 <think>...</think> 标签(含未闭合的)
|
||||
cleanContent = content.replace(/<think>[\s\S]*?(?:<\/think>|$)/g, '').trim();
|
||||
return { displayContent: cleanContent, parsedThinking: thinkParts.join('\n\n') };
|
||||
}
|
||||
return { displayContent: content, parsedThinking: '' };
|
||||
}, [msg.content, msg.thinking]);
|
||||
const isTypingThinking = !!(msg.loading && msg.phase === 'thinking');
|
||||
|
||||
if (msg.role === 'tool') return null;
|
||||
@@ -568,11 +589,11 @@ export const AIMessageBubble: React.FC<AIMessageBubbleProps> = React.memo(({ msg
|
||||
</div>
|
||||
|
||||
{/* 即使在波纹过渡态,如果有 thinking / tool_calls 也要显示出来,只是把它们压在波纹下面 */}
|
||||
<div style={{ marginTop: msg.thinking || (msg.tool_calls && msg.tool_calls.length > 0) ? 12 : 0 }}>
|
||||
{!isUser && msg.thinking && (
|
||||
<div style={{ marginTop: parsedThinking || (msg.tool_calls && msg.tool_calls.length > 0) ? 12 : 0 }}>
|
||||
{!isUser && parsedThinking && (
|
||||
<ThinkingBlock
|
||||
displayThinking={msg.thinking}
|
||||
totalLen={msg.thinking.length}
|
||||
displayThinking={parsedThinking}
|
||||
totalLen={parsedThinking.length}
|
||||
isTyping={isTypingThinking}
|
||||
isGlobalLoading={!!msg.loading}
|
||||
darkMode={darkMode}
|
||||
@@ -649,10 +670,10 @@ export const AIMessageBubble: React.FC<AIMessageBubbleProps> = React.memo(({ msg
|
||||
</div>
|
||||
)}
|
||||
{/* 可折叠思考过程 */}
|
||||
{!isUser && msg.thinking && (
|
||||
{!isUser && parsedThinking && (
|
||||
<ThinkingBlock
|
||||
displayThinking={msg.thinking}
|
||||
totalLen={msg.thinking.length}
|
||||
displayThinking={parsedThinking}
|
||||
totalLen={parsedThinking.length}
|
||||
isTyping={isTypingThinking}
|
||||
isGlobalLoading={!!msg.loading}
|
||||
darkMode={darkMode}
|
||||
|
||||
@@ -667,6 +667,74 @@ const unwrapPersistedAppState = (persistedState: unknown): Record<string, unknow
|
||||
return raw;
|
||||
};
|
||||
|
||||
// --- AI 会话文件持久化辅助函数 ---
|
||||
|
||||
/** 每个 session 独立防抖定时器(2秒) */
|
||||
const _persistTimers: Record<string, ReturnType<typeof setTimeout>> = {};
|
||||
|
||||
function _debouncedPersistSession(sessionId: string) {
|
||||
if (_persistTimers[sessionId]) clearTimeout(_persistTimers[sessionId]);
|
||||
_persistTimers[sessionId] = setTimeout(() => {
|
||||
delete _persistTimers[sessionId];
|
||||
const state = useStore.getState();
|
||||
const messages = state.aiChatHistory[sessionId];
|
||||
const sessionMeta = state.aiChatSessions.find(s => s.id === sessionId);
|
||||
if (!messages && !sessionMeta) return; // session 已被删除,跳过
|
||||
const title = sessionMeta?.title || '新的对话';
|
||||
const updatedAt = sessionMeta?.updatedAt || Date.now();
|
||||
const messagesJSON = JSON.stringify(messages || []);
|
||||
const Service = (window as any).go?.aiservice?.Service;
|
||||
Service?.AISaveSession?.(sessionId, title, updatedAt, messagesJSON).catch((e: any) => {
|
||||
console.error('[AI Session Persist] 持久化失败:', sessionId, e);
|
||||
});
|
||||
}, 2000);
|
||||
}
|
||||
|
||||
/** 从后端加载会话列表(仅元数据,不含消息体) */
|
||||
export async function loadAISessionsFromBackend(): Promise<{ id: string; title: string; updatedAt: number }[]> {
|
||||
const Service = (window as any).go?.aiservice?.Service;
|
||||
if (!Service?.AIGetSessions) return [];
|
||||
try {
|
||||
const sessions = await Service.AIGetSessions();
|
||||
if (Array.isArray(sessions)) {
|
||||
useStore.setState({ aiChatSessions: sessions });
|
||||
return sessions;
|
||||
}
|
||||
} catch (e) {
|
||||
console.error('[AI Session] 加载会话列表失败:', e);
|
||||
}
|
||||
return [];
|
||||
}
|
||||
|
||||
/** 从后端加载指定会话的消息数据到内存 */
|
||||
export async function loadAISessionFromBackend(sessionId: string): Promise<boolean> {
|
||||
const state = useStore.getState();
|
||||
// 如果内存中已有消息,跳过重复加载
|
||||
if (state.aiChatHistory[sessionId]?.length > 0) return true;
|
||||
|
||||
const Service = (window as any).go?.aiservice?.Service;
|
||||
if (!Service?.AILoadSession) return false;
|
||||
try {
|
||||
const result = await Service.AILoadSession(sessionId);
|
||||
if (result?.success) {
|
||||
let messages = result.messages;
|
||||
// messages 可能是 JSON string 或已解析的数组
|
||||
if (typeof messages === 'string') {
|
||||
try { messages = JSON.parse(messages); } catch { messages = []; }
|
||||
}
|
||||
if (Array.isArray(messages)) {
|
||||
useStore.setState((prev) => ({
|
||||
aiChatHistory: { ...prev.aiChatHistory, [sessionId]: messages },
|
||||
}));
|
||||
return true;
|
||||
}
|
||||
}
|
||||
} catch (e) {
|
||||
console.error('[AI Session] 加载会话消息失败:', sessionId, e);
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
export const useStore = create<AppState>()(
|
||||
persist(
|
||||
(set) => ({
|
||||
@@ -986,99 +1054,123 @@ export const useStore = create<AppState>()(
|
||||
// AI actions
|
||||
toggleAIPanel: () => set((state) => ({ aiPanelVisible: !state.aiPanelVisible })),
|
||||
setAIPanelVisible: (visible) => set({ aiPanelVisible: visible }),
|
||||
addAIChatMessage: (sessionId, message) => set((state) => {
|
||||
const history = { ...state.aiChatHistory };
|
||||
const messages = history[sessionId] || [];
|
||||
history[sessionId] = [...messages, message];
|
||||
|
||||
let newSessions = [...state.aiChatSessions];
|
||||
const existingSession = newSessions.find(s => s.id === sessionId);
|
||||
|
||||
if (!existingSession) {
|
||||
// 生成标题(首个 user message 内容前 20 字符)
|
||||
let title = message.role === 'user' ? message.content : '新的对话';
|
||||
if (title.length > 20) {
|
||||
title = title.substring(0, 20) + '...';
|
||||
}
|
||||
newSessions.unshift({ id: sessionId, title, updatedAt: Date.now() });
|
||||
} else {
|
||||
// 提至最新
|
||||
newSessions = newSessions.filter(s => s.id !== sessionId);
|
||||
newSessions.unshift({ ...existingSession, updatedAt: Date.now() });
|
||||
}
|
||||
|
||||
return { aiChatHistory: history, aiChatSessions: newSessions };
|
||||
}),
|
||||
updateAIChatMessage: (sessionId, messageId, updates) => set((state) => {
|
||||
const messages = state.aiChatHistory[sessionId];
|
||||
if (!messages) return state;
|
||||
// 🔧 性能优化:用 findIndex + 定点替换代替全量 map,长对话场景下从 O(n) 降至 O(1)
|
||||
const idx = messages.findIndex(m => m.id === messageId);
|
||||
if (idx < 0) return state;
|
||||
const newMessages = [...messages];
|
||||
newMessages[idx] = { ...newMessages[idx], ...updates };
|
||||
const history = { ...state.aiChatHistory, [sessionId]: newMessages };
|
||||
// 仅当非纯 content 追加时才重排 session 顺序(性能优化:流式打字时跳过)
|
||||
const isContentOnlyUpdate = Object.keys(updates).length === 1 && 'content' in updates;
|
||||
if (!isContentOnlyUpdate) {
|
||||
let newSessions = [...state.aiChatSessions];
|
||||
const existingSession = newSessions.find(s => s.id === sessionId);
|
||||
if (existingSession) {
|
||||
newSessions = newSessions.filter(s => s.id !== sessionId);
|
||||
newSessions.unshift({ ...existingSession, updatedAt: Date.now() });
|
||||
}
|
||||
return { aiChatHistory: history, aiChatSessions: newSessions };
|
||||
}
|
||||
return { aiChatHistory: history };
|
||||
}),
|
||||
deleteAIChatMessage: (sessionId, messageId) => set((state) => {
|
||||
const history = { ...state.aiChatHistory };
|
||||
if (history[sessionId]) {
|
||||
history[sessionId] = history[sessionId].filter(m => m.id !== messageId);
|
||||
}
|
||||
return { aiChatHistory: history };
|
||||
}),
|
||||
truncateAIChatMessages: (sessionId, upToMessageId) => set((state) => {
|
||||
const history = { ...state.aiChatHistory };
|
||||
const messages = history[sessionId];
|
||||
if (messages) {
|
||||
const idx = messages.findIndex(m => m.id === upToMessageId);
|
||||
if (idx >= 0) {
|
||||
history[sessionId] = messages.slice(0, idx + 1);
|
||||
}
|
||||
}
|
||||
return { aiChatHistory: history };
|
||||
}),
|
||||
clearAIChatHistory: (sessionId) => set((state) => {
|
||||
const history = { ...state.aiChatHistory };
|
||||
delete history[sessionId];
|
||||
return { aiChatHistory: history };
|
||||
}),
|
||||
replaceAIChatHistory: (sessionId, messages) => set((state) => {
|
||||
const history = { ...state.aiChatHistory };
|
||||
history[sessionId] = messages;
|
||||
return { aiChatHistory: history };
|
||||
}),
|
||||
deleteAISession: (sessionId) => set((state) => {
|
||||
const history = { ...state.aiChatHistory };
|
||||
delete history[sessionId];
|
||||
const newSessions = state.aiChatSessions.filter(s => s.id !== sessionId);
|
||||
const newActive = state.aiActiveSessionId === sessionId ? null : state.aiActiveSessionId;
|
||||
return { aiChatHistory: history, aiChatSessions: newSessions, aiActiveSessionId: newActive };
|
||||
}),
|
||||
addAIChatMessage: (sessionId, message) => {
|
||||
set((state) => {
|
||||
const history = { ...state.aiChatHistory };
|
||||
const messages = history[sessionId] || [];
|
||||
history[sessionId] = [...messages, message];
|
||||
|
||||
let newSessions = [...state.aiChatSessions];
|
||||
const existingSession = newSessions.find(s => s.id === sessionId);
|
||||
|
||||
if (!existingSession) {
|
||||
let title = message.role === 'user' ? message.content : '新的对话';
|
||||
if (title.length > 20) {
|
||||
title = title.substring(0, 20) + '...';
|
||||
}
|
||||
newSessions.unshift({ id: sessionId, title, updatedAt: Date.now() });
|
||||
} else {
|
||||
newSessions = newSessions.filter(s => s.id !== sessionId);
|
||||
newSessions.unshift({ ...existingSession, updatedAt: Date.now() });
|
||||
}
|
||||
|
||||
return { aiChatHistory: history, aiChatSessions: newSessions };
|
||||
});
|
||||
// 异步持久化到文件(fire-and-forget,防抖由外层控制)
|
||||
_debouncedPersistSession(sessionId);
|
||||
},
|
||||
updateAIChatMessage: (sessionId, messageId, updates) => {
|
||||
set((state) => {
|
||||
const messages = state.aiChatHistory[sessionId];
|
||||
if (!messages) return state;
|
||||
const idx = messages.findIndex(m => m.id === messageId);
|
||||
if (idx < 0) return state;
|
||||
const newMessages = [...messages];
|
||||
newMessages[idx] = { ...newMessages[idx], ...updates };
|
||||
const history = { ...state.aiChatHistory, [sessionId]: newMessages };
|
||||
const isContentOnlyUpdate = Object.keys(updates).length === 1 && 'content' in updates;
|
||||
if (!isContentOnlyUpdate) {
|
||||
let newSessions = [...state.aiChatSessions];
|
||||
const existingSession = newSessions.find(s => s.id === sessionId);
|
||||
if (existingSession) {
|
||||
newSessions = newSessions.filter(s => s.id !== sessionId);
|
||||
newSessions.unshift({ ...existingSession, updatedAt: Date.now() });
|
||||
}
|
||||
return { aiChatHistory: history, aiChatSessions: newSessions };
|
||||
}
|
||||
return { aiChatHistory: history };
|
||||
});
|
||||
// 流式打字高频调用,防抖 2 秒后才写磁盘
|
||||
_debouncedPersistSession(sessionId);
|
||||
},
|
||||
deleteAIChatMessage: (sessionId, messageId) => {
|
||||
set((state) => {
|
||||
const history = { ...state.aiChatHistory };
|
||||
if (history[sessionId]) {
|
||||
history[sessionId] = history[sessionId].filter(m => m.id !== messageId);
|
||||
}
|
||||
return { aiChatHistory: history };
|
||||
});
|
||||
_debouncedPersistSession(sessionId);
|
||||
},
|
||||
truncateAIChatMessages: (sessionId, upToMessageId) => {
|
||||
set((state) => {
|
||||
const history = { ...state.aiChatHistory };
|
||||
const messages = history[sessionId];
|
||||
if (messages) {
|
||||
const idx = messages.findIndex(m => m.id === upToMessageId);
|
||||
if (idx >= 0) {
|
||||
history[sessionId] = messages.slice(0, idx + 1);
|
||||
}
|
||||
}
|
||||
return { aiChatHistory: history };
|
||||
});
|
||||
_debouncedPersistSession(sessionId);
|
||||
},
|
||||
clearAIChatHistory: (sessionId) => {
|
||||
set((state) => {
|
||||
const history = { ...state.aiChatHistory };
|
||||
delete history[sessionId];
|
||||
return { aiChatHistory: history };
|
||||
});
|
||||
_debouncedPersistSession(sessionId);
|
||||
},
|
||||
replaceAIChatHistory: (sessionId, messages) => {
|
||||
set((state) => {
|
||||
const history = { ...state.aiChatHistory };
|
||||
history[sessionId] = messages;
|
||||
return { aiChatHistory: history };
|
||||
});
|
||||
_debouncedPersistSession(sessionId);
|
||||
},
|
||||
deleteAISession: (sessionId) => {
|
||||
set((state) => {
|
||||
const history = { ...state.aiChatHistory };
|
||||
delete history[sessionId];
|
||||
const newSessions = state.aiChatSessions.filter(s => s.id !== sessionId);
|
||||
const newActive = state.aiActiveSessionId === sessionId ? null : state.aiActiveSessionId;
|
||||
return { aiChatHistory: history, aiChatSessions: newSessions, aiActiveSessionId: newActive };
|
||||
});
|
||||
// 删除文件
|
||||
const Service = (window as any).go?.aiservice?.Service;
|
||||
Service?.AIDeleteSession?.(sessionId).catch(() => {});
|
||||
},
|
||||
createNewAISession: () => set(() => {
|
||||
const newId = `session-${Date.now()}`;
|
||||
return { aiActiveSessionId: newId };
|
||||
}),
|
||||
setAIActiveSessionId: (sessionId) => set({ aiActiveSessionId: sessionId }),
|
||||
updateAISessionTitle: (sessionId, title) => set((state) => {
|
||||
const newSessions = [...state.aiChatSessions];
|
||||
const session = newSessions.find(s => s.id === sessionId);
|
||||
if (session) {
|
||||
session.title = title;
|
||||
}
|
||||
return { aiChatSessions: newSessions };
|
||||
}),
|
||||
updateAISessionTitle: (sessionId, title) => {
|
||||
set((state) => {
|
||||
const newSessions = [...state.aiChatSessions];
|
||||
const session = newSessions.find(s => s.id === sessionId);
|
||||
if (session) {
|
||||
session.title = title;
|
||||
}
|
||||
return { aiChatSessions: newSessions };
|
||||
});
|
||||
_debouncedPersistSession(sessionId);
|
||||
},
|
||||
addAIContext: (connectionKey, context) => set((state) => {
|
||||
const contexts = state.aiContexts[connectionKey] || [];
|
||||
if (contexts.find(c => c.dbName === context.dbName && c.tableName === context.tableName)) {
|
||||
@@ -1173,8 +1265,9 @@ export const useStore = create<AppState>()(
|
||||
shortcutOptions: sanitizeShortcutOptions(state.shortcutOptions),
|
||||
tableAccessCount: sanitizeTableAccessCount(state.tableAccessCount),
|
||||
|
||||
aiChatHistory: (state.aiChatHistory && typeof state.aiChatHistory === 'object') ? state.aiChatHistory : {},
|
||||
aiChatSessions: Array.isArray(state.aiChatSessions) ? state.aiChatSessions : [],
|
||||
// AI 会话数据不再从 localStorage 恢复,改为从后端文件加载
|
||||
aiChatHistory: {},
|
||||
aiChatSessions: [],
|
||||
};
|
||||
},
|
||||
partialize: (state) => ({
|
||||
@@ -1200,17 +1293,7 @@ export const useStore = create<AppState>()(
|
||||
windowState: state.windowState,
|
||||
sidebarWidth: state.sidebarWidth,
|
||||
|
||||
// 只持久化最近 20 个会话的聊天记录,防止 localStorage 膨胀
|
||||
aiChatHistory: (() => {
|
||||
const MAX_PERSIST_SESSIONS = 20;
|
||||
const recentIds = new Set(state.aiChatSessions.slice(0, MAX_PERSIST_SESSIONS).map(s => s.id));
|
||||
const trimmed: Record<string, any> = {};
|
||||
for (const id of recentIds) {
|
||||
if (state.aiChatHistory[id]) trimmed[id] = state.aiChatHistory[id];
|
||||
}
|
||||
return trimmed;
|
||||
})(),
|
||||
aiChatSessions: state.aiChatSessions.slice(0, 50),
|
||||
// AI 会话数据已迁移到后端文件持久化(~/.gonavi/sessions/),不再写入 localStorage
|
||||
}), // Don't persist logs
|
||||
}
|
||||
)
|
||||
|
||||
8
frontend/wailsjs/go/aiservice/Service.d.ts
vendored
8
frontend/wailsjs/go/aiservice/Service.d.ts
vendored
@@ -13,6 +13,8 @@ export function AICheckSQL(arg1:string):Promise<ai.SafetyResult>;
|
||||
|
||||
export function AIDeleteProvider(arg1:string):Promise<void>;
|
||||
|
||||
export function AIDeleteSession(arg1:string):Promise<void>;
|
||||
|
||||
export function AIGetActiveProvider():Promise<string>;
|
||||
|
||||
export function AIGetBuiltinPrompts():Promise<Record<string, string>>;
|
||||
@@ -23,10 +25,16 @@ export function AIGetProviders():Promise<Array<ai.ProviderConfig>>;
|
||||
|
||||
export function AIGetSafetyLevel():Promise<string>;
|
||||
|
||||
export function AIGetSessions():Promise<Array<Record<string, any>>>;
|
||||
|
||||
export function AIListModels():Promise<Record<string, any>>;
|
||||
|
||||
export function AILoadSession(arg1:string):Promise<Record<string, any>>;
|
||||
|
||||
export function AISaveProvider(arg1:ai.ProviderConfig):Promise<void>;
|
||||
|
||||
export function AISaveSession(arg1:string,arg2:string,arg3:number,arg4:string):Promise<void>;
|
||||
|
||||
export function AISetActiveProvider(arg1:string):Promise<void>;
|
||||
|
||||
export function AISetContextLevel(arg1:string):Promise<void>;
|
||||
|
||||
@@ -22,6 +22,10 @@ export function AIDeleteProvider(arg1) {
|
||||
return window['go']['aiservice']['Service']['AIDeleteProvider'](arg1);
|
||||
}
|
||||
|
||||
export function AIDeleteSession(arg1) {
|
||||
return window['go']['aiservice']['Service']['AIDeleteSession'](arg1);
|
||||
}
|
||||
|
||||
export function AIGetActiveProvider() {
|
||||
return window['go']['aiservice']['Service']['AIGetActiveProvider']();
|
||||
}
|
||||
@@ -42,14 +46,26 @@ export function AIGetSafetyLevel() {
|
||||
return window['go']['aiservice']['Service']['AIGetSafetyLevel']();
|
||||
}
|
||||
|
||||
export function AIGetSessions() {
|
||||
return window['go']['aiservice']['Service']['AIGetSessions']();
|
||||
}
|
||||
|
||||
export function AIListModels() {
|
||||
return window['go']['aiservice']['Service']['AIListModels']();
|
||||
}
|
||||
|
||||
export function AILoadSession(arg1) {
|
||||
return window['go']['aiservice']['Service']['AILoadSession'](arg1);
|
||||
}
|
||||
|
||||
export function AISaveProvider(arg1) {
|
||||
return window['go']['aiservice']['Service']['AISaveProvider'](arg1);
|
||||
}
|
||||
|
||||
export function AISaveSession(arg1, arg2, arg3, arg4) {
|
||||
return window['go']['aiservice']['Service']['AISaveSession'](arg1, arg2, arg3, arg4);
|
||||
}
|
||||
|
||||
export function AISetActiveProvider(arg1) {
|
||||
return window['go']['aiservice']['Service']['AISetActiveProvider'](arg1);
|
||||
}
|
||||
|
||||
@@ -15,10 +15,23 @@ import (
|
||||
|
||||
const (
|
||||
defaultAnthropicBaseURL = "https://api.anthropic.com"
|
||||
defaultAnthropicModel = "claude-3-5-sonnet-20241022"
|
||||
anthropicAPIVersion = "2023-06-01"
|
||||
)
|
||||
|
||||
func normalizeAnthropicMessagesURL(baseURL string) string {
|
||||
url := strings.TrimRight(strings.TrimSpace(baseURL), "/")
|
||||
if url == "" {
|
||||
url = defaultAnthropicBaseURL
|
||||
}
|
||||
if strings.HasSuffix(url, "/messages") {
|
||||
return url
|
||||
}
|
||||
if strings.HasSuffix(url, "/v1") {
|
||||
return url + "/messages"
|
||||
}
|
||||
return url + "/v1/messages"
|
||||
}
|
||||
|
||||
// AnthropicProvider 实现 Anthropic Claude API 的 Provider
|
||||
type AnthropicProvider struct {
|
||||
config ai.ProviderConfig
|
||||
@@ -34,7 +47,7 @@ func NewAnthropicProvider(config ai.ProviderConfig) (Provider, error) {
|
||||
}
|
||||
model := strings.TrimSpace(config.Model)
|
||||
if model == "" {
|
||||
model = defaultAnthropicModel
|
||||
return nil, fmt.Errorf("模型 ID 不能为空,请在设置中选择或输入模型")
|
||||
}
|
||||
maxTokens := config.MaxTokens
|
||||
if maxTokens <= 0 {
|
||||
@@ -425,10 +438,7 @@ func (p *AnthropicProvider) doRequest(ctx context.Context, body interface{}) (io
|
||||
return nil, fmt.Errorf("序列化请求失败: %w", err)
|
||||
}
|
||||
|
||||
url := p.baseURL + "/v1/messages"
|
||||
if strings.HasSuffix(p.baseURL, "/v1") {
|
||||
url = p.baseURL + "/messages"
|
||||
}
|
||||
url := normalizeAnthropicMessagesURL(p.baseURL)
|
||||
|
||||
httpReq, err := http.NewRequestWithContext(ctx, http.MethodPost, url, bytes.NewReader(jsonBody))
|
||||
if err != nil {
|
||||
|
||||
24
internal/ai/provider/anthropic_test.go
Normal file
24
internal/ai/provider/anthropic_test.go
Normal file
@@ -0,0 +1,24 @@
|
||||
package provider
|
||||
|
||||
import "testing"
|
||||
|
||||
func TestNormalizeAnthropicMessagesURL_AppendsMessagesSuffix(t *testing.T) {
|
||||
url := normalizeAnthropicMessagesURL("https://api.anthropic.com")
|
||||
if url != "https://api.anthropic.com/v1/messages" {
|
||||
t.Fatalf("expected normalized anthropic messages url, got %q", url)
|
||||
}
|
||||
}
|
||||
|
||||
func TestNormalizeAnthropicMessagesURL_UsesMoonshotAnthropicMessagesEndpoint(t *testing.T) {
|
||||
url := normalizeAnthropicMessagesURL("https://api.moonshot.cn/anthropic")
|
||||
if url != "https://api.moonshot.cn/anthropic/v1/messages" {
|
||||
t.Fatalf("expected moonshot anthropic messages url, got %q", url)
|
||||
}
|
||||
}
|
||||
|
||||
func TestNormalizeAnthropicMessagesURL_PreservesExplicitMessagesPath(t *testing.T) {
|
||||
url := normalizeAnthropicMessagesURL("https://api.moonshot.cn/anthropic/v1/messages")
|
||||
if url != "https://api.moonshot.cn/anthropic/v1/messages" {
|
||||
t.Fatalf("expected explicit messages path to be preserved, got %q", url)
|
||||
}
|
||||
}
|
||||
@@ -15,7 +15,6 @@ import (
|
||||
|
||||
const (
|
||||
defaultGeminiBaseURL = "https://generativelanguage.googleapis.com"
|
||||
defaultGeminiModel = "gemini-2.0-flash"
|
||||
)
|
||||
|
||||
// GeminiProvider 实现 Google Gemini API 的 Provider
|
||||
@@ -33,7 +32,7 @@ func NewGeminiProvider(config ai.ProviderConfig) (Provider, error) {
|
||||
}
|
||||
model := strings.TrimSpace(config.Model)
|
||||
if model == "" {
|
||||
model = defaultGeminiModel
|
||||
return nil, fmt.Errorf("模型 ID 不能为空,请在设置中选择或输入模型")
|
||||
}
|
||||
maxTokens := config.MaxTokens
|
||||
if maxTokens <= 0 {
|
||||
|
||||
@@ -16,7 +16,6 @@ import (
|
||||
|
||||
const (
|
||||
defaultOpenAIBaseURL = "https://api.openai.com/v1"
|
||||
defaultOpenAIModel = "gpt-4o"
|
||||
defaultOpenAIMaxTokens = 4096
|
||||
defaultOpenAITemperature = 0.7
|
||||
openAIHTTPTimeout = 120 * time.Second
|
||||
@@ -41,7 +40,7 @@ func NewOpenAIProvider(config ai.ProviderConfig) (Provider, error) {
|
||||
}
|
||||
model := strings.TrimSpace(config.Model)
|
||||
if model == "" {
|
||||
model = defaultOpenAIModel
|
||||
return nil, fmt.Errorf("模型 ID 不能为空,请在设置中选择或输入模型")
|
||||
}
|
||||
maxTokens := config.MaxTokens
|
||||
if maxTokens <= 0 {
|
||||
|
||||
@@ -28,18 +28,24 @@ func TestOpenAIProvider_Validate_Valid(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestOpenAIProvider_Name_Custom(t *testing.T) {
|
||||
p, _ := NewOpenAIProvider(ai.ProviderConfig{
|
||||
Type: "openai", Name: "My OpenAI", APIKey: "sk-test",
|
||||
p, err := NewOpenAIProvider(ai.ProviderConfig{
|
||||
Type: "openai", Name: "My OpenAI", APIKey: "sk-test", Model: "gpt-4o",
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected constructor error: %v", err)
|
||||
}
|
||||
if p.Name() != "My OpenAI" {
|
||||
t.Fatalf("expected name 'My OpenAI', got '%s'", p.Name())
|
||||
}
|
||||
}
|
||||
|
||||
func TestOpenAIProvider_Name_Default(t *testing.T) {
|
||||
p, _ := NewOpenAIProvider(ai.ProviderConfig{
|
||||
Type: "openai", APIKey: "sk-test",
|
||||
p, err := NewOpenAIProvider(ai.ProviderConfig{
|
||||
Type: "openai", APIKey: "sk-test", Model: "gpt-4o",
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected constructor error: %v", err)
|
||||
}
|
||||
if p.Name() != "OpenAI" {
|
||||
t.Fatalf("expected default name 'OpenAI', got '%s'", p.Name())
|
||||
}
|
||||
@@ -56,29 +62,34 @@ func TestOpenAIProvider_DefaultBaseURL(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestOpenAIProvider_CustomBaseURL(t *testing.T) {
|
||||
p, _ := NewOpenAIProvider(ai.ProviderConfig{
|
||||
Type: "openai", APIKey: "sk-test", BaseURL: "https://my-proxy.com/v1",
|
||||
p, err := NewOpenAIProvider(ai.ProviderConfig{
|
||||
Type: "openai", APIKey: "sk-test", BaseURL: "https://my-proxy.com/v1", Model: "gpt-4o",
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected constructor error: %v", err)
|
||||
}
|
||||
op := p.(*OpenAIProvider)
|
||||
if op.baseURL != "https://my-proxy.com/v1" {
|
||||
t.Fatalf("expected custom base URL, got '%s'", op.baseURL)
|
||||
}
|
||||
}
|
||||
|
||||
func TestOpenAIProvider_DefaultModel(t *testing.T) {
|
||||
p, _ := NewOpenAIProvider(ai.ProviderConfig{
|
||||
func TestOpenAIProvider_RejectsMissingModel(t *testing.T) {
|
||||
_, err := NewOpenAIProvider(ai.ProviderConfig{
|
||||
Type: "openai", APIKey: "sk-test",
|
||||
})
|
||||
op := p.(*OpenAIProvider)
|
||||
if op.config.Model != "gpt-4o" {
|
||||
t.Fatalf("expected default model 'gpt-4o', got '%s'", op.config.Model)
|
||||
if err == nil {
|
||||
t.Fatal("expected constructor error for missing model")
|
||||
}
|
||||
}
|
||||
|
||||
func TestOpenAIProvider_DefaultMaxTokens(t *testing.T) {
|
||||
p, _ := NewOpenAIProvider(ai.ProviderConfig{
|
||||
Type: "openai", APIKey: "sk-test",
|
||||
p, err := NewOpenAIProvider(ai.ProviderConfig{
|
||||
Type: "openai", APIKey: "sk-test", Model: "gpt-4o",
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected constructor error: %v", err)
|
||||
}
|
||||
op := p.(*OpenAIProvider)
|
||||
if op.config.MaxTokens != 4096 {
|
||||
t.Fatalf("expected default max tokens 4096, got %d", op.config.MaxTokens)
|
||||
|
||||
@@ -35,6 +35,16 @@ type Service struct {
|
||||
cancelFuncs map[string]context.CancelFunc // 记录每个 session 的 context 取消函数
|
||||
}
|
||||
|
||||
var miniMaxAnthropicModels = []string{
|
||||
"MiniMax-M2.7",
|
||||
"MiniMax-M2.7-highspeed",
|
||||
"MiniMax-M2.5",
|
||||
"MiniMax-M2.5-highspeed",
|
||||
"MiniMax-M2.1",
|
||||
"MiniMax-M2.1-highspeed",
|
||||
"MiniMax-M2",
|
||||
}
|
||||
|
||||
// NewService 创建 AI Service 实例
|
||||
func NewService() *Service {
|
||||
return &Service{
|
||||
@@ -63,6 +73,9 @@ func (s *Service) AIGetProviders() []ai.ProviderConfig {
|
||||
|
||||
result := make([]ai.ProviderConfig, len(s.providers))
|
||||
copy(result, s.providers)
|
||||
for i := range result {
|
||||
result[i] = normalizeProviderConfig(result[i])
|
||||
}
|
||||
return result
|
||||
}
|
||||
|
||||
@@ -72,6 +85,8 @@ func (s *Service) AISaveProvider(config ai.ProviderConfig) error {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
|
||||
config = normalizeProviderConfig(config)
|
||||
|
||||
if strings.TrimSpace(config.ID) == "" {
|
||||
config.ID = "provider-" + uuid.New().String()[:8]
|
||||
}
|
||||
@@ -128,65 +143,36 @@ func (s *Service) AITestProvider(config ai.ProviderConfig) map[string]interface{
|
||||
}
|
||||
s.mu.RUnlock()
|
||||
|
||||
config = normalizeProviderConfig(config)
|
||||
baseURL := strings.TrimRight(strings.TrimSpace(config.BaseURL), "/")
|
||||
providerType := config.Type
|
||||
if providerType == "custom" && config.APIFormat != "" {
|
||||
providerType = config.APIFormat
|
||||
}
|
||||
providerType := normalizedProviderType(config)
|
||||
|
||||
client := &http.Client{Timeout: 10 * time.Second}
|
||||
var err error
|
||||
|
||||
switch providerType {
|
||||
case "openai":
|
||||
if baseURL == "" {
|
||||
baseURL = "https://api.openai.com/v1"
|
||||
}
|
||||
if !strings.HasSuffix(baseURL, "/v1") && !strings.Contains(baseURL, "/v1/") {
|
||||
baseURL = baseURL + "/v1"
|
||||
}
|
||||
// 使用 /models 端点验证连通性和鉴权
|
||||
req, _ := http.NewRequest("GET", baseURL+"/models", nil)
|
||||
req.Header.Set("Authorization", "Bearer "+config.APIKey)
|
||||
for k, v := range config.Headers {
|
||||
req.Header.Set(k, v)
|
||||
case "openai", "anthropic", "gemini":
|
||||
req, reqErr := newProviderHealthCheckRequest(config)
|
||||
if reqErr != nil {
|
||||
err = reqErr
|
||||
break
|
||||
}
|
||||
resp, reqErr := client.Do(req)
|
||||
if reqErr != nil {
|
||||
err = reqErr
|
||||
} else {
|
||||
defer resp.Body.Close()
|
||||
if resp.StatusCode == http.StatusUnauthorized {
|
||||
err = fmt.Errorf("API Key 验证失败 (HTTP %d)", resp.StatusCode)
|
||||
if resp.StatusCode == http.StatusUnauthorized || resp.StatusCode == http.StatusForbidden {
|
||||
err = fmt.Errorf("API Key 无效或请求错误 (HTTP %d)", resp.StatusCode)
|
||||
} else if providerType == "gemini" && resp.StatusCode == http.StatusBadRequest {
|
||||
err = fmt.Errorf("API Key 无效或请求错误 (HTTP %d)", resp.StatusCode)
|
||||
} else if resp.StatusCode >= 400 {
|
||||
body, _ := io.ReadAll(io.LimitReader(resp.Body, 512))
|
||||
err = fmt.Errorf("接口返回异常 (HTTP %d): %s", resp.StatusCode, string(body))
|
||||
} else if resp.StatusCode >= 500 {
|
||||
err = fmt.Errorf("上游服务器内部错误 (HTTP %d)", resp.StatusCode)
|
||||
}
|
||||
}
|
||||
case "anthropic":
|
||||
if baseURL == "" {
|
||||
baseURL = "https://api.anthropic.com"
|
||||
}
|
||||
req, _ := http.NewRequest("GET", baseURL, nil)
|
||||
resp, reqErr := client.Do(req)
|
||||
if reqErr != nil {
|
||||
err = reqErr
|
||||
} else {
|
||||
resp.Body.Close()
|
||||
}
|
||||
case "gemini":
|
||||
if baseURL == "" {
|
||||
baseURL = "https://generativelanguage.googleapis.com"
|
||||
}
|
||||
req, _ := http.NewRequest("GET", baseURL+"/v1beta/models?key="+config.APIKey, nil)
|
||||
resp, reqErr := client.Do(req)
|
||||
if reqErr != nil {
|
||||
err = reqErr
|
||||
} else {
|
||||
defer resp.Body.Close()
|
||||
if resp.StatusCode == http.StatusUnauthorized || resp.StatusCode == http.StatusBadRequest {
|
||||
err = fmt.Errorf("API Key 无效或请求错误 (HTTP %d)", resp.StatusCode)
|
||||
}
|
||||
}
|
||||
default:
|
||||
if baseURL != "" {
|
||||
req, _ := http.NewRequest("GET", baseURL, nil)
|
||||
@@ -209,6 +195,153 @@ func (s *Service) AITestProvider(config ai.ProviderConfig) map[string]interface{
|
||||
}
|
||||
}
|
||||
|
||||
func normalizedProviderType(config ai.ProviderConfig) string {
|
||||
providerType := strings.ToLower(strings.TrimSpace(config.Type))
|
||||
if providerType == "custom" && strings.TrimSpace(config.APIFormat) != "" {
|
||||
return strings.ToLower(strings.TrimSpace(config.APIFormat))
|
||||
}
|
||||
return providerType
|
||||
}
|
||||
|
||||
func isMiniMaxAnthropicProvider(config ai.ProviderConfig) bool {
|
||||
if normalizedProviderType(config) != "anthropic" {
|
||||
return false
|
||||
}
|
||||
baseURL := strings.ToLower(strings.TrimRight(strings.TrimSpace(config.BaseURL), "/"))
|
||||
return strings.Contains(baseURL, "api.minimax.io") || strings.Contains(baseURL, "api.minimaxi.com")
|
||||
}
|
||||
|
||||
func isMoonshotAnthropicProvider(config ai.ProviderConfig) bool {
|
||||
if normalizedProviderType(config) != "anthropic" {
|
||||
return false
|
||||
}
|
||||
baseURL := strings.ToLower(strings.TrimRight(strings.TrimSpace(config.BaseURL), "/"))
|
||||
return strings.Contains(baseURL, "api.moonshot.cn")
|
||||
}
|
||||
|
||||
func defaultStaticModelsForProvider(config ai.ProviderConfig) []string {
|
||||
if isMiniMaxAnthropicProvider(config) {
|
||||
return append([]string(nil), miniMaxAnthropicModels...)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func normalizeProviderConfig(config ai.ProviderConfig) ai.ProviderConfig {
|
||||
staticModels := defaultStaticModelsForProvider(config)
|
||||
if len(staticModels) > 0 && len(config.Models) == 0 {
|
||||
config.Models = staticModels
|
||||
}
|
||||
model := strings.TrimSpace(config.Model)
|
||||
if isMiniMaxAnthropicProvider(config) && (model == "" || strings.HasPrefix(strings.ToLower(model), "minimax-text-")) {
|
||||
config.Model = miniMaxAnthropicModels[0]
|
||||
}
|
||||
return config
|
||||
}
|
||||
|
||||
func resolveModelsURL(config ai.ProviderConfig) string {
|
||||
config = normalizeProviderConfig(config)
|
||||
providerType := normalizedProviderType(config)
|
||||
baseURL := strings.TrimRight(strings.TrimSpace(config.BaseURL), "/")
|
||||
|
||||
switch providerType {
|
||||
case "anthropic":
|
||||
if isMoonshotAnthropicProvider(config) {
|
||||
return "https://api.moonshot.cn/v1/models"
|
||||
}
|
||||
if baseURL == "" {
|
||||
baseURL = "https://api.anthropic.com"
|
||||
}
|
||||
if !strings.HasSuffix(baseURL, "/v1") && !strings.Contains(baseURL, "/v1/") {
|
||||
baseURL = baseURL + "/v1"
|
||||
}
|
||||
return baseURL + "/models"
|
||||
case "gemini":
|
||||
if baseURL == "" {
|
||||
baseURL = "https://generativelanguage.googleapis.com"
|
||||
}
|
||||
return baseURL + "/v1beta/models?key=" + config.APIKey
|
||||
case "openai":
|
||||
fallthrough
|
||||
default:
|
||||
if baseURL == "" {
|
||||
baseURL = "https://api.openai.com/v1"
|
||||
}
|
||||
if !strings.HasSuffix(baseURL, "/v1") && !strings.Contains(baseURL, "/v1/") {
|
||||
baseURL = baseURL + "/v1"
|
||||
}
|
||||
return baseURL + "/models"
|
||||
}
|
||||
}
|
||||
|
||||
func newModelsRequest(config ai.ProviderConfig) (*http.Request, error) {
|
||||
config = normalizeProviderConfig(config)
|
||||
url := resolveModelsURL(config)
|
||||
req, err := http.NewRequest("GET", url, nil)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("创建请求失败: %w", err)
|
||||
}
|
||||
|
||||
switch normalizedProviderType(config) {
|
||||
case "anthropic":
|
||||
req.Header.Set("x-api-key", config.APIKey)
|
||||
req.Header.Set("anthropic-version", "2023-06-01")
|
||||
req.Header.Set("Authorization", "Bearer "+config.APIKey)
|
||||
case "gemini":
|
||||
// Gemini 使用 query string 传递 key,无需额外鉴权头
|
||||
default:
|
||||
req.Header.Set("Authorization", "Bearer "+config.APIKey)
|
||||
}
|
||||
|
||||
for k, v := range config.Headers {
|
||||
req.Header.Set(k, v)
|
||||
}
|
||||
|
||||
return req, nil
|
||||
}
|
||||
|
||||
func resolveAnthropicMessagesURL(baseURL string) string {
|
||||
url := strings.TrimRight(strings.TrimSpace(baseURL), "/")
|
||||
if url == "" {
|
||||
url = "https://api.anthropic.com"
|
||||
}
|
||||
if strings.HasSuffix(url, "/messages") {
|
||||
return url
|
||||
}
|
||||
if strings.HasSuffix(url, "/v1") {
|
||||
return url + "/messages"
|
||||
}
|
||||
return url + "/v1/messages"
|
||||
}
|
||||
|
||||
func newProviderHealthCheckRequest(config ai.ProviderConfig) (*http.Request, error) {
|
||||
config = normalizeProviderConfig(config)
|
||||
if isMiniMaxAnthropicProvider(config) {
|
||||
body := map[string]interface{}{
|
||||
"model": config.Model,
|
||||
"max_tokens": 1,
|
||||
"messages": []map[string]string{
|
||||
{"role": "user", "content": "ping"},
|
||||
},
|
||||
}
|
||||
bodyBytes, err := json.Marshal(body)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("序列化请求失败: %w", err)
|
||||
}
|
||||
req, err := http.NewRequest("POST", resolveAnthropicMessagesURL(config.BaseURL), strings.NewReader(string(bodyBytes)))
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("创建请求失败: %w", err)
|
||||
}
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
req.Header.Set("x-api-key", config.APIKey)
|
||||
req.Header.Set("anthropic-version", "2023-06-01")
|
||||
for k, v := range config.Headers {
|
||||
req.Header.Set(k, v)
|
||||
}
|
||||
return req, nil
|
||||
}
|
||||
return newModelsRequest(config)
|
||||
}
|
||||
|
||||
// AISetActiveProvider 设置活动 Provider
|
||||
func (s *Service) AISetActiveProvider(id string) {
|
||||
s.mu.Lock()
|
||||
@@ -261,17 +394,16 @@ func (s *Service) AIListModels() map[string]interface{} {
|
||||
|
||||
// fetchModels 从供应商 API 获取可用模型列表
|
||||
func fetchModels(config ai.ProviderConfig) ([]string, error) {
|
||||
providerType := config.Type
|
||||
if providerType == "custom" && config.APIFormat != "" {
|
||||
providerType = config.APIFormat
|
||||
providerType := normalizedProviderType(config)
|
||||
if staticModels := defaultStaticModelsForProvider(config); len(staticModels) > 0 {
|
||||
return staticModels, nil
|
||||
}
|
||||
|
||||
switch providerType {
|
||||
case "openai":
|
||||
return fetchOpenAIModels(config)
|
||||
case "anthropic":
|
||||
// Anthropic 没有公开的 /models 端点,返回硬编码列表
|
||||
return []string{"claude-opus-4-6", "claude-sonnet-4-6"}, nil
|
||||
return fetchAnthropicModels(config)
|
||||
case "gemini":
|
||||
return fetchGeminiModels(config)
|
||||
default:
|
||||
@@ -281,20 +413,45 @@ func fetchModels(config ai.ProviderConfig) ([]string, error) {
|
||||
|
||||
// fetchOpenAIModels 获取 OpenAI 兼容 API 的模型列表
|
||||
func fetchOpenAIModels(config ai.ProviderConfig) ([]string, error) {
|
||||
baseURL := strings.TrimRight(strings.TrimSpace(config.BaseURL), "/")
|
||||
if baseURL == "" {
|
||||
baseURL = "https://api.openai.com/v1"
|
||||
}
|
||||
// 确保 baseURL 以 /v1 结尾
|
||||
if !strings.HasSuffix(baseURL, "/v1") {
|
||||
baseURL = baseURL + "/v1"
|
||||
req, err := newModelsRequest(config)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
req, err := http.NewRequest("GET", baseURL+"/models", nil)
|
||||
client := &http.Client{Timeout: 15 * time.Second}
|
||||
resp, err := client.Do(req)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("创建请求失败: %w", err)
|
||||
return nil, fmt.Errorf("请求模型列表失败: %w", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
body, _ := io.ReadAll(io.LimitReader(resp.Body, 1024))
|
||||
return nil, fmt.Errorf("获取模型列表失败 (HTTP %d): %s", resp.StatusCode, string(body))
|
||||
}
|
||||
|
||||
var result struct {
|
||||
Data []struct {
|
||||
ID string `json:"id"`
|
||||
} `json:"data"`
|
||||
}
|
||||
if err := json.NewDecoder(resp.Body).Decode(&result); err != nil {
|
||||
return nil, fmt.Errorf("解析模型列表失败: %w", err)
|
||||
}
|
||||
|
||||
models := make([]string, 0, len(result.Data))
|
||||
for _, m := range result.Data {
|
||||
models = append(models, m.ID)
|
||||
}
|
||||
return models, nil
|
||||
}
|
||||
|
||||
// fetchAnthropicModels 获取 Anthropic API 的模型列表
|
||||
func fetchAnthropicModels(config ai.ProviderConfig) ([]string, error) {
|
||||
req, err := newModelsRequest(config)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
req.Header.Set("Authorization", "Bearer "+config.APIKey)
|
||||
|
||||
client := &http.Client{Timeout: 15 * time.Second}
|
||||
resp, err := client.Do(req)
|
||||
@@ -515,7 +672,7 @@ func (s *Service) getActiveProvider() (provider.Provider, error) {
|
||||
|
||||
for _, cfg := range s.providers {
|
||||
if cfg.ID == s.activeProvider {
|
||||
return provider.NewProvider(cfg)
|
||||
return provider.NewProvider(normalizeProviderConfig(cfg))
|
||||
}
|
||||
}
|
||||
|
||||
@@ -548,6 +705,9 @@ func (s *Service) loadConfig() {
|
||||
if s.providers == nil {
|
||||
s.providers = make([]ai.ProviderConfig, 0)
|
||||
}
|
||||
for i := range s.providers {
|
||||
s.providers[i] = normalizeProviderConfig(s.providers[i])
|
||||
}
|
||||
s.activeProvider = cfg.ActiveProvider
|
||||
|
||||
switch ai.SQLPermissionLevel(cfg.SafetyLevel) {
|
||||
@@ -591,14 +751,122 @@ func (s *Service) saveConfig() error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// --- 会话文件持久化 ---
|
||||
|
||||
// sessionFileData 会话文件的 JSON 结构
|
||||
type sessionFileData struct {
|
||||
ID string `json:"id"`
|
||||
Title string `json:"title"`
|
||||
UpdatedAt int64 `json:"updatedAt"`
|
||||
Messages json.RawMessage `json:"messages"` // 透传前端格式,后端不解析消息体
|
||||
}
|
||||
|
||||
func (s *Service) sessionsDir() string {
|
||||
return filepath.Join(s.configDir, "sessions")
|
||||
}
|
||||
|
||||
// AIGetSessions 获取所有会话的元数据列表(不含消息体)
|
||||
func (s *Service) AIGetSessions() []map[string]interface{} {
|
||||
dir := s.sessionsDir()
|
||||
entries, err := os.ReadDir(dir)
|
||||
if err != nil {
|
||||
return []map[string]interface{}{}
|
||||
}
|
||||
|
||||
var sessions []map[string]interface{}
|
||||
for _, entry := range entries {
|
||||
if entry.IsDir() || !strings.HasSuffix(entry.Name(), ".json") {
|
||||
continue
|
||||
}
|
||||
data, err := os.ReadFile(filepath.Join(dir, entry.Name()))
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
var sfd sessionFileData
|
||||
if err := json.Unmarshal(data, &sfd); err != nil {
|
||||
continue
|
||||
}
|
||||
sessions = append(sessions, map[string]interface{}{
|
||||
"id": sfd.ID,
|
||||
"title": sfd.Title,
|
||||
"updatedAt": sfd.UpdatedAt,
|
||||
})
|
||||
}
|
||||
|
||||
// 按 updatedAt 降序排列
|
||||
for i := 0; i < len(sessions); i++ {
|
||||
for j := i + 1; j < len(sessions); j++ {
|
||||
ti, _ := sessions[i]["updatedAt"].(int64)
|
||||
tj, _ := sessions[j]["updatedAt"].(int64)
|
||||
if tj > ti {
|
||||
sessions[i], sessions[j] = sessions[j], sessions[i]
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return sessions
|
||||
}
|
||||
|
||||
// AILoadSession 加载指定会话的完整数据(含消息)
|
||||
func (s *Service) AILoadSession(sessionID string) map[string]interface{} {
|
||||
path := filepath.Join(s.sessionsDir(), sessionID+".json")
|
||||
data, err := os.ReadFile(path)
|
||||
if err != nil {
|
||||
return map[string]interface{}{"success": false, "error": "会话不存在"}
|
||||
}
|
||||
var sfd sessionFileData
|
||||
if err := json.Unmarshal(data, &sfd); err != nil {
|
||||
return map[string]interface{}{"success": false, "error": "会话数据损坏"}
|
||||
}
|
||||
return map[string]interface{}{
|
||||
"success": true,
|
||||
"id": sfd.ID,
|
||||
"title": sfd.Title,
|
||||
"updatedAt": sfd.UpdatedAt,
|
||||
"messages": sfd.Messages,
|
||||
}
|
||||
}
|
||||
|
||||
// AISaveSession 保存会话数据到文件
|
||||
func (s *Service) AISaveSession(sessionID string, title string, updatedAt float64, messagesJSON string) error {
|
||||
dir := s.sessionsDir()
|
||||
if err := os.MkdirAll(dir, 0o755); err != nil {
|
||||
return fmt.Errorf("创建 sessions 目录失败: %w", err)
|
||||
}
|
||||
|
||||
sfd := sessionFileData{
|
||||
ID: sessionID,
|
||||
Title: title,
|
||||
UpdatedAt: int64(updatedAt),
|
||||
Messages: json.RawMessage(messagesJSON),
|
||||
}
|
||||
|
||||
data, err := json.MarshalIndent(sfd, "", " ")
|
||||
if err != nil {
|
||||
return fmt.Errorf("序列化会话数据失败: %w", err)
|
||||
}
|
||||
|
||||
path := filepath.Join(dir, sessionID+".json")
|
||||
return os.WriteFile(path, data, 0o644)
|
||||
}
|
||||
|
||||
// AIDeleteSession 删除会话文件
|
||||
func (s *Service) AIDeleteSession(sessionID string) error {
|
||||
path := filepath.Join(s.sessionsDir(), sessionID+".json")
|
||||
if err := os.Remove(path); err != nil && !os.IsNotExist(err) {
|
||||
return fmt.Errorf("删除会话失败: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// --- 工具函数 ---
|
||||
|
||||
func resolveConfigDir() string {
|
||||
configDir, err := os.UserConfigDir()
|
||||
homeDir, err := os.UserHomeDir()
|
||||
if err != nil {
|
||||
configDir = "."
|
||||
homeDir = "."
|
||||
}
|
||||
return filepath.Join(configDir, "GoNavi")
|
||||
return filepath.Join(homeDir, ".gonavi")
|
||||
}
|
||||
|
||||
func maskAPIKey(apiKey string) string {
|
||||
|
||||
78
internal/ai/service/service_test.go
Normal file
78
internal/ai/service/service_test.go
Normal file
@@ -0,0 +1,78 @@
|
||||
package aiservice
|
||||
|
||||
import (
|
||||
"reflect"
|
||||
"testing"
|
||||
|
||||
"GoNavi-Wails/internal/ai"
|
||||
)
|
||||
|
||||
func TestResolveModelsURL_UsesMoonshotOpenAIModelsEndpointForKimiAnthropicBaseURL(t *testing.T) {
|
||||
url := resolveModelsURL(ai.ProviderConfig{
|
||||
Type: "anthropic",
|
||||
BaseURL: "https://api.moonshot.cn/anthropic",
|
||||
})
|
||||
if url != "https://api.moonshot.cn/v1/models" {
|
||||
t.Fatalf("expected moonshot models endpoint, got %q", url)
|
||||
}
|
||||
}
|
||||
|
||||
func TestResolveModelsURL_UsesAnthropicModelsEndpointForOfficialAnthropic(t *testing.T) {
|
||||
url := resolveModelsURL(ai.ProviderConfig{
|
||||
Type: "anthropic",
|
||||
BaseURL: "https://api.anthropic.com",
|
||||
})
|
||||
if url != "https://api.anthropic.com/v1/models" {
|
||||
t.Fatalf("expected anthropic models endpoint, got %q", url)
|
||||
}
|
||||
}
|
||||
|
||||
func TestResolveModelsURL_UsesOpenAIModelsEndpointForOpenAICompatibleProvider(t *testing.T) {
|
||||
url := resolveModelsURL(ai.ProviderConfig{
|
||||
Type: "openai",
|
||||
BaseURL: "https://api.openai.com/v1",
|
||||
})
|
||||
if url != "https://api.openai.com/v1/models" {
|
||||
t.Fatalf("expected openai models endpoint, got %q", url)
|
||||
}
|
||||
}
|
||||
|
||||
func TestDefaultStaticModelsForProvider_ReturnsMiniMaxAnthropicModels(t *testing.T) {
|
||||
models := defaultStaticModelsForProvider(ai.ProviderConfig{
|
||||
Type: "anthropic",
|
||||
BaseURL: "https://api.minimaxi.com/anthropic",
|
||||
})
|
||||
expected := []string{
|
||||
"MiniMax-M2.7",
|
||||
"MiniMax-M2.7-highspeed",
|
||||
"MiniMax-M2.5",
|
||||
"MiniMax-M2.5-highspeed",
|
||||
"MiniMax-M2.1",
|
||||
"MiniMax-M2.1-highspeed",
|
||||
"MiniMax-M2",
|
||||
}
|
||||
if !reflect.DeepEqual(models, expected) {
|
||||
t.Fatalf("expected MiniMax static models %v, got %v", expected, models)
|
||||
}
|
||||
}
|
||||
|
||||
func TestNewProviderHealthCheckRequest_UsesMessagesEndpointForMiniMaxAnthropic(t *testing.T) {
|
||||
req, err := newProviderHealthCheckRequest(ai.ProviderConfig{
|
||||
Type: "anthropic",
|
||||
BaseURL: "https://api.minimaxi.com/anthropic",
|
||||
Model: "MiniMax-M2.7",
|
||||
APIKey: "sk-test",
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
if req.Method != "POST" {
|
||||
t.Fatalf("expected POST request, got %s", req.Method)
|
||||
}
|
||||
if req.URL.String() != "https://api.minimaxi.com/anthropic/v1/messages" {
|
||||
t.Fatalf("expected MiniMax messages endpoint, got %q", req.URL.String())
|
||||
}
|
||||
if got := req.Header.Get("x-api-key"); got != "sk-test" {
|
||||
t.Fatalf("expected x-api-key header to be set, got %q", got)
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user