diff --git a/frontend/package.json.md5 b/frontend/package.json.md5 index 3018db7..efbd2b6 100755 --- a/frontend/package.json.md5 +++ b/frontend/package.json.md5 @@ -1 +1 @@ -dcb87159cf0f1f6f750d1c4870911d3f \ No newline at end of file +6ba85e4f456d2c0d230cab198c7dc02b \ No newline at end of file diff --git a/frontend/src/components/AIChatPanel.css b/frontend/src/components/AIChatPanel.css index 88978ec..5b935eb 100644 --- a/frontend/src/components/AIChatPanel.css +++ b/frontend/src/components/AIChatPanel.css @@ -483,3 +483,15 @@ @keyframes ai-spin-anim { to { transform: rotate(360deg); } } + +/* 面板/弹窗内部 toast 定位覆盖:从 fixed(视口顶部)改为 absolute(容器内部顶部) */ +.ai-chat-panel .ant-message, +.ai-settings-body .ant-message { + position: absolute !important; + top: 16px !important; + left: 50% !important; + transform: translateX(-50%) !important; + right: auto !important; + width: max-content; + z-index: 100; +} diff --git a/frontend/src/components/AIChatPanel.tsx b/frontend/src/components/AIChatPanel.tsx index 49c6660..6f2702e 100644 --- a/frontend/src/components/AIChatPanel.tsx +++ b/frontend/src/components/AIChatPanel.tsx @@ -1,12 +1,12 @@ import React, { useState, useRef, useEffect, useCallback, useMemo } from 'react'; import { createPortal } from 'react-dom'; -import { useStore } from '../store'; +import { useStore, loadAISessionsFromBackend, loadAISessionFromBackend } from '../store'; import { EventsOn, EventsOff } from '../../wailsjs/runtime'; import { DBGetDatabases, DBGetTables } from '../../wailsjs/go/app/App'; import type { OverlayWorkbenchTheme } from '../utils/overlayWorkbenchTheme'; import { AIChatMessage, AIToolCall } from '../types'; import { DownOutlined } from '@ant-design/icons'; -import { message } from 'antd'; +import { message as antdMessage } from 'antd'; import './AIChatPanel.css'; import { AIChatHeader } from './ai/AIChatHeader'; @@ -224,6 +224,9 @@ export const AIChatPanel: React.FC = ({ const panelRef = useRef(null); // 面板 DOM ref,用于拖拽时直接操作宽度 const dragWidthRef = useRef(0); // 拖拽过程中的实时宽度(不触发 React 重渲染) + // 面板内部 toast 通知(不在屏幕顶部,而在面板容器内显示) + const [messageApi, messageContextHolder] = antdMessage.useMessage({ getContainer: () => panelRef.current || document.body }); + const aiChatHistory = useStore(state => state.aiChatHistory); const aiActiveSessionId = useStore(state => state.aiActiveSessionId); const createNewAISession = useStore(state => state.createNewAISession); @@ -280,6 +283,21 @@ export const AIChatPanel: React.FC = ({ }, [aiActiveSessionId, createNewAISession]); const sid = aiActiveSessionId || 'session-fallback'; + + // 面板首次可见时从后端加载会话列表 + const sessionsLoadedRef = useRef(false); + useEffect(() => { + if (!aiPanelVisible || sessionsLoadedRef.current) return; + sessionsLoadedRef.current = true; + loadAISessionsFromBackend(); + }, [aiPanelVisible]); + + // 切换会话时按需从后端加载消息 + useEffect(() => { + if (sid && sid !== 'session-fallback') { + loadAISessionFromBackend(sid); + } + }, [sid]); const messages = aiChatHistory[sid] || []; const getConnectionName = useCallback(() => { @@ -314,6 +332,17 @@ export const AIChatPanel: React.FC = ({ useEffect(() => { loadActiveProvider(); }, [loadActiveProvider]); + // 监听供应商配置变更(来自设置面板的删除/新增/切换操作),重新加载 active provider 并清空已缓存的模型 + useEffect(() => { + const handler = () => { + setDynamicModels([]); + activeProviderIdRef.current = null; + loadActiveProvider(); + }; + window.addEventListener('gonavi:ai:provider-changed', handler); + return () => window.removeEventListener('gonavi:ai:provider-changed', handler); + }, [loadActiveProvider]); + const handleModelChange = async (val: string) => { if (!activeProvider) return; try { @@ -331,7 +360,12 @@ export const AIChatPanel: React.FC = ({ setDynamicModels([]); activeProviderIdRef.current = activeProvider.id; } - }, [activeProvider?.id]); + // 供应商被删除后 activeProvider 变为 null,此时也必须清空残留模型 + if (!activeProvider) { + setDynamicModels([]); + activeProviderIdRef.current = null; + } + }, [activeProvider?.id, activeProvider]); // dynamicModels 仅在内存中使用,不再写回供应商配置,避免污染静态 models 列表 @@ -346,11 +380,11 @@ export const AIChatPanel: React.FC = ({ const sortedModels = [...result.models].sort((a, b) => a.localeCompare(b)); setDynamicModels(sortedModels); } else if (result && !result.success) { - message.warning(result.error || '获取模型列表失败,可手动输入模型名称'); + messageApi.warning(result.error || '获取模型列表失败,可手动输入模型名称'); } } catch (e: any) { console.warn('Failed to fetch models', e); - message.warning('获取模型列表失败: ' + (e?.message || '未知错误')); + messageApi.warning('获取模型列表失败: ' + (e?.message || '未知错误')); } finally { setLoadingModels(false); } @@ -993,6 +1027,17 @@ SELECT * FROM users WHERE status = 1; const handleSend = useCallback(async () => { const text = input.trim(); if ((!text && draftImages.length === 0) || sending) return; + + // 前置校验:必须配置供应商且选择模型后才能发送 + if (!activeProvider) { + messageApi.warning('请先在 AI 设置中配置供应商'); + return; + } + if (!activeProvider.model || !activeProvider.model.trim()) { + messageApi.warning('请先选择模型 ID(点击工具栏的模型下拉框选择)'); + return; + } + toolCallRoundRef.current = 0; // 重置工具调用轮次计数 nudgeCountRef.current = 0; // 重置催促计数 @@ -1083,7 +1128,7 @@ SELECT * FROM users WHERE status = 1; addAIChatMessage(sid, { id: genId(), role: 'assistant', content: `❌ 发送失败: ${cleanE2}`, rawError: cleanE2 !== rawE2 ? rawE2 : undefined, timestamp: Date.now() }); setSending(false); } - }, [input, draftImages, sending, messages, addAIChatMessage, sid]); + }, [input, draftImages, sending, messages, addAIChatMessage, sid, activeProvider]); const handleKeyDown = useCallback((e: React.KeyboardEvent) => { if (e.key === 'Enter' && !e.shiftKey) { @@ -1213,6 +1258,7 @@ SELECT * FROM users WHERE status = 1; return (
+ {messageContextHolder}
{isResizing && panelRect.current && createPortal( diff --git a/frontend/src/components/AISettingsModal.tsx b/frontend/src/components/AISettingsModal.tsx index c9c00b3..83378d9 100644 --- a/frontend/src/components/AISettingsModal.tsx +++ b/frontend/src/components/AISettingsModal.tsx @@ -1,5 +1,5 @@ -import React, { useState, useEffect, useCallback } from 'react'; -import { Modal, Button, Input, Select, Form, message, Tooltip, Tabs, Space, Popconfirm, Slider } from 'antd'; +import React, { useState, useEffect, useCallback, useRef } from 'react'; +import { Modal, Button, Input, Select, Form, message as antdMessage, Tooltip, Tabs, Space, Popconfirm, Slider } from 'antd'; import { PlusOutlined, DeleteOutlined, EditOutlined, CheckOutlined, ApiOutlined, SafetyCertificateOutlined, RobotOutlined, ThunderboltOutlined, CloudOutlined, ExperimentOutlined, KeyOutlined, LinkOutlined, AppstoreOutlined, ToolOutlined } from '@ant-design/icons'; import type { AIProviderConfig, AIProviderType, AISafetyLevel, AIContextLevel } from '../types'; @@ -26,21 +26,40 @@ interface ProviderPreset { } const PROVIDER_PRESETS: ProviderPreset[] = [ - { key: 'openai', label: 'OpenAI', icon: , desc: 'GPT-5.4 / 5.3 系列', color: '#10b981', backendType: 'openai', defaultBaseUrl: 'https://api.openai.com/v1', defaultModel: 'gpt-5.4', models: ['gpt-5.4', 'gpt-5.4-mini', 'gpt-5.4-nano', 'gpt-5.3'] }, - { key: 'deepseek', label: 'DeepSeek', icon: , desc: 'DeepSeek-V4 / R1', color: '#3b82f6', backendType: 'openai', defaultBaseUrl: 'https://api.deepseek.com/v1', defaultModel: 'deepseek-chat', models: ['deepseek-chat', 'deepseek-reasoner'] }, - { key: 'qwen', label: '通义千问', icon: , desc: 'Qwen3.5 / Qwen3 系列', color: '#6366f1', backendType: 'openai', defaultBaseUrl: 'https://dashscope.aliyuncs.com/compatible-mode/v1', defaultModel: 'qwen3.5-max', models: ['qwen3.5-max', 'qwen3-plus', 'qwen3-turbo'] }, - { key: 'zhipu', label: '智谱 GLM', icon: , desc: 'GLM-5 / GLM-5-Turbo', color: '#0ea5e9', backendType: 'openai', defaultBaseUrl: 'https://open.bigmodel.cn/api/paas/v4', defaultModel: 'glm-5', models: ['glm-5', 'glm-5-turbo', 'glm-4.7-flash'] }, - { key: 'moonshot', label: 'Kimi', icon: , desc: 'Kimi K2.5 系列', color: '#0d9488', backendType: 'openai', defaultBaseUrl: 'https://api.moonshot.cn/v1', defaultModel: 'kimi-k2.5', models: ['kimi-k2.5', 'kimi-k2-turbo-preview', 'kimi-k2-thinking'] }, - { key: 'anthropic', label: 'Claude', icon: , desc: 'Claude Opus/Sonnet 4.6', color: '#d97706', backendType: 'anthropic', defaultBaseUrl: 'https://api.anthropic.com', defaultModel: 'claude-sonnet-4-6', models: ['claude-opus-4-6', 'claude-sonnet-4-6'] }, - { key: 'gemini', label: 'Gemini', icon: , desc: 'Gemini 3.1 / 2.5 系列', color: '#059669', backendType: 'gemini', defaultBaseUrl: 'https://generativelanguage.googleapis.com', defaultModel: 'gemini-2.5-flash', models: ['gemini-3.1-pro', 'gemini-2.5-flash', 'gemini-2.5-pro'] }, + { key: 'openai', label: 'OpenAI', icon: , desc: 'GPT-5.4 / 5.3 系列', color: '#10b981', backendType: 'openai', defaultBaseUrl: 'https://api.openai.com/v1', defaultModel: 'gpt-4o', models: [] }, + { key: 'deepseek', label: 'DeepSeek', icon: , desc: 'DeepSeek-V4 / R1', color: '#3b82f6', backendType: 'openai', defaultBaseUrl: 'https://api.deepseek.com/v1', defaultModel: 'deepseek-chat', models: [] }, + { key: 'qwen', label: '通义千问', icon: , desc: 'Qwen3.5 / Qwen3 系列', color: '#6366f1', backendType: 'openai', defaultBaseUrl: 'https://dashscope.aliyuncs.com/compatible-mode/v1', defaultModel: 'qwen-max', models: [] }, + { key: 'zhipu', label: '智谱 GLM', icon: , desc: 'GLM-5 / GLM-5-Turbo', color: '#0ea5e9', backendType: 'openai', defaultBaseUrl: 'https://open.bigmodel.cn/api/paas/v4', defaultModel: 'glm-4', models: [] }, + { key: 'moonshot', label: 'Kimi', icon: , desc: 'Kimi K2.5 (Anthropic 兼容)', color: '#0d9488', backendType: 'anthropic', defaultBaseUrl: 'https://api.moonshot.cn/anthropic', defaultModel: 'moonshot-v1-8k', models: [] }, + { key: 'anthropic', label: 'Claude', icon: , desc: 'Claude Opus/Sonnet', color: '#d97706', backendType: 'anthropic', defaultBaseUrl: 'https://api.anthropic.com', defaultModel: 'claude-3-5-sonnet-20241022', models: [] }, + { key: 'gemini', label: 'Gemini', icon: , desc: 'Gemini 3.1 / 2.5 系列', color: '#059669', backendType: 'gemini', defaultBaseUrl: 'https://generativelanguage.googleapis.com', defaultModel: 'gemini-2.5-flash', models: [] }, { key: 'volcengine', label: '火山引擎', icon: , desc: '火山方舟 / 豆包大模型', color: '#0ea5e9', backendType: 'openai', defaultBaseUrl: 'https://ark.cn-beijing.volces.com/api/v3', defaultModel: 'ep-xxxxxx', models: [] }, - { key: 'minimax', label: 'MiniMax', icon: , desc: 'abab6.5 / abab7 系列', color: '#e11d48', backendType: 'anthropic', defaultBaseUrl: 'https://api.minimaxi.com/anthropic', defaultModel: 'MiniMax-Text-01', models: ['MiniMax-Text-01', 'MiniMax-Text-01-vision', 'MiniMax-Text-01-search', 'MiniMax-Text-01-code', 'MiniMax-Text-01-web', 'MiniMax-Text-01-sql', 'MiniMax-Text-01-python', 'MiniMax-Text-01-math', 'MiniMax-Text-01-doc'] }, + { key: 'minimax', label: 'MiniMax', icon: , desc: 'M2.7 / M2.5 系列 (Anthropic 兼容)', color: '#e11d48', backendType: 'anthropic', defaultBaseUrl: 'https://api.minimaxi.com/anthropic', defaultModel: 'MiniMax-M2.7', models: ['MiniMax-M2.7', 'MiniMax-M2.7-highspeed', 'MiniMax-M2.5', 'MiniMax-M2.5-highspeed', 'MiniMax-M2.1', 'MiniMax-M2.1-highspeed', 'MiniMax-M2'] }, { key: 'ollama', label: 'Ollama', icon: , desc: '本地部署开源模型', color: '#78716c', backendType: 'openai', defaultBaseUrl: 'http://localhost:11434/v1', defaultModel: 'llama3', models: [] }, { key: 'custom', label: '自定义', icon: , desc: '自定义 API 端点', color: '#64748b', backendType: 'custom', defaultBaseUrl: '', defaultModel: '', models: [] }, ]; const findPreset = (key: string): ProviderPreset => PROVIDER_PRESETS.find(p => p.key === key) || PROVIDER_PRESETS[PROVIDER_PRESETS.length - 1]; +const getProviderHostname = (raw?: string): string => { + if (!raw) return ''; + try { + return new URL(raw).hostname.toLowerCase(); + } catch { + return ''; + } +}; + +const matchProviderPreset = (provider: Pick): ProviderPreset => { + const host = getProviderHostname(provider.baseUrl); + if (host.endsWith('moonshot.cn')) { + return findPreset('moonshot'); + } + return PROVIDER_PRESETS.find(pr => pr.backendType === provider.type && host !== '' && host === getProviderHostname(pr.defaultBaseUrl)) + || PROVIDER_PRESETS.find(pr => pr.backendType === provider.type) + || findPreset('custom'); +}; + const SAFETY_OPTIONS: { label: string; value: AISafetyLevel; desc: string; color: string; icon: string }[] = [ { label: '只读模式', value: 'readonly', desc: 'AI 仅可执行 SELECT 等查询操作,最安全', color: '#22c55e', icon: '🔒' }, { label: '读写模式', value: 'readwrite', desc: 'AI 可执行 INSERT/UPDATE/DELETE,危险操作需二次确认', color: '#f59e0b', icon: '⚠️' }, @@ -65,6 +84,10 @@ const AISettingsModal: React.FC = ({ open, onClose, darkMo const [builtinPrompts, setBuiltinPrompts] = useState>({}); const [activeSection, setActiveSection] = useState<'providers' | 'safety' | 'context' | 'prompts' | 'tools'>('providers'); const [form] = Form.useForm(); + const modalBodyRef = useRef(null); + + // Modal 内部 toast 通知 + const [messageApi, messageContextHolder] = antdMessage.useMessage({ getContainer: () => modalBodyRef.current || document.body }); // 主题色 const cardBg = darkMode ? 'rgba(255,255,255,0.04)' : 'rgba(0,0,0,0.02)'; @@ -108,31 +131,45 @@ const AISettingsModal: React.FC = ({ open, onClose, darkMo const newProvider: AIProviderConfig = { id: '', type: preset.backendType, name: '', apiKey: '', baseUrl: preset.defaultBaseUrl, model: preset.defaultModel, - maxTokens: 4096, temperature: 0.7, + models: [], maxTokens: 4096, temperature: 0.7, }; setEditingProvider({ ...newProvider, presetKey: 'openai' } as any); setIsEditing(true); setTestStatus('idle'); + form.resetFields(); form.setFieldsValue({ ...newProvider, presetKey: 'openai', apiFormat: 'openai' }); }; const handleEditProvider = (p: AIProviderConfig) => { // 尝试根据 baseUrl 和 type 推断 preset - const matchedPreset = PROVIDER_PRESETS.find(pr => pr.backendType === p.type && p.baseUrl?.includes(new URL(pr.defaultBaseUrl || 'http://x').hostname)) - || PROVIDER_PRESETS.find(pr => pr.backendType === p.type) - || findPreset('custom'); + const matchedPreset = matchProviderPreset(p); setEditingProvider(p); setIsEditing(true); setTestStatus('idle'); - form.setFieldsValue({ ...p, presetKey: matchedPreset.key, apiFormat: p.apiFormat || 'openai' }); + form.resetFields(); + form.setFieldsValue({ ...p, type: matchedPreset.backendType, models: p.models || [], presetKey: matchedPreset.key, apiFormat: p.apiFormat || 'openai' }); }; const handleDeleteProvider = async (id: string) => { try { const Service = (window as any).go?.aiservice?.Service; + const wasActive = id === activeProviderId; await Service?.AIDeleteProvider?.(id); - void message.success('已删除'); void loadConfig(); - } catch (e: any) { void message.error(e?.message || '删除失败'); } + await loadConfig(); + // 合并提示:删除的是当前激活的供应商时,附带自动切换信息 + if (wasActive) { + const newProviders: any[] = await Service?.AIGetProviders?.() || []; + if (newProviders.length > 0) { + const newActiveName = newProviders[0]?.name || '下一个供应商'; + void messageApi.success(`已删除,自动切换到「${newActiveName}」`); + } else { + void messageApi.success('已删除'); + } + } else { + void messageApi.success('已删除'); + } + window.dispatchEvent(new CustomEvent('gonavi:ai:provider-changed')); + } catch (e: any) { void messageApi.error(e?.message || '删除失败'); } }; const handleSaveProvider = async () => { @@ -150,20 +187,24 @@ const AISettingsModal: React.FC = ({ open, onClose, darkMo // 内置供应商自动使用 preset label 作为名称 const finalName = isCustomLike ? (values.name || preset.label) : preset.label; + const finalBaseUrl = values.baseUrl || preset.defaultBaseUrl; + const payload = { ...editingProvider, ...values, name: finalName, model: finalModel, models: resolvedModels, + baseUrl: finalBaseUrl, apiFormat: values.apiFormat || 'openai', }; // 后端 AISaveProvider 统一处理新增和更新,返回 void,失败抛异常 await Service?.AISaveProvider?.(payload); - void message.success('已保存'); setIsEditing(false); setEditingProvider(null); void loadConfig(); + void messageApi.success('已保存'); setIsEditing(false); setEditingProvider(null); void loadConfig(); + window.dispatchEvent(new CustomEvent('gonavi:ai:provider-changed')); } catch (e: any) { if (e?.errorFields) { /* antd form validation error, ignore */ } - else void message.error(e?.message || '保存失败'); + else void messageApi.error(e?.message || '保存失败'); } finally { setLoading(false); } }; @@ -171,8 +212,9 @@ const AISettingsModal: React.FC = ({ open, onClose, darkMo try { const Service = (window as any).go?.aiservice?.Service; await Service?.AISetActiveProvider?.(id); - setActiveProviderId(id); void message.success('已切换'); - } catch (e: any) { void message.error(e?.message || '切换失败'); } + setActiveProviderId(id); void messageApi.success('已切换'); + window.dispatchEvent(new CustomEvent('gonavi:ai:provider-changed')); + } catch (e: any) { void messageApi.error(e?.message || '切换失败'); } }; const handleSafetyChange = async (level: AISafetyLevel) => { @@ -197,10 +239,12 @@ const AISettingsModal: React.FC = ({ open, onClose, darkMo setLoading(true); setTestStatus('idle'); const Service = (window as any).go?.aiservice?.Service; - const res = await Service?.AITestProvider?.({ ...values, maxTokens: Number(values.maxTokens) || 4096, temperature: Number(values.temperature) ?? 0.7 }); - if (res?.success) { setTestStatus('success'); void message.success('连接成功'); } - else { setTestStatus('error'); void message.error(`测试失败: ${res?.message || '未知错误'}`); } - } catch (e: any) { setTestStatus('error'); void message.error(e?.message || '测试失败'); } + const preset = findPreset(values.presetKey || 'openai'); + const finalBaseUrl = values.baseUrl || preset.defaultBaseUrl; + const res = await Service?.AITestProvider?.({ ...values, baseUrl: finalBaseUrl, maxTokens: Number(values.maxTokens) || 4096, temperature: Number(values.temperature) ?? 0.7 }); + if (res?.success) { setTestStatus('success'); void messageApi.success('连接成功'); } + else { setTestStatus('error'); void messageApi.error(`测试失败: ${res?.message || '未知错误'}`); } + } catch (e: any) { setTestStatus('error'); void messageApi.error(e?.message || '测试失败'); } finally { setLoading(false); } }; @@ -238,9 +282,7 @@ const AISettingsModal: React.FC = ({ open, onClose, darkMo
)} {providers.map(p => { - const matchedPreset = PROVIDER_PRESETS.find(pr => pr.backendType === p.type && p.baseUrl?.includes(new URL(pr.defaultBaseUrl || 'http://x').hostname)) - || PROVIDER_PRESETS.find(pr => pr.backendType === p.type) - || findPreset('custom'); + const matchedPreset = matchProviderPreset(p); const isActive = p.id === activeProviderId; return (
handleSetActive(p.id)} style={{ @@ -605,7 +647,8 @@ const AISettingsModal: React.FC = ({ open, onClose, darkMo body: { paddingTop: 8, height: 620, overflow: 'hidden' }, }} > -
+
+ {messageContextHolder}
设置导航
diff --git a/frontend/src/components/ai/AIMessageBubble.tsx b/frontend/src/components/ai/AIMessageBubble.tsx index 453d1b6..93348b5 100644 --- a/frontend/src/components/ai/AIMessageBubble.tsx +++ b/frontend/src/components/ai/AIMessageBubble.tsx @@ -544,7 +544,28 @@ export const AIMessageBubble: React.FC = React.memo(({ msg const [isCopied, setIsCopied] = useState(false); const isUser = msg.role === 'user'; - const displayContent = msg.content; + // 从 content 中提取 ... 标签内容(部分模型如 MiniMax、DeepSeek 会以文本形式返回思考过程) + const { displayContent, parsedThinking } = React.useMemo(() => { + const content = msg.content || ''; + // 优先使用后端已结构化的 thinking 字段(如 Claude API 原生 thinking) + if (msg.thinking) { + return { displayContent: content, parsedThinking: msg.thinking }; + } + // 尝试从 content 中提取 ... 标签 + const thinkRegex = /([\s\S]*?)(?:<\/think>|$)/g; + let thinkParts: string[] = []; + let cleanContent = content; + let match; + while ((match = thinkRegex.exec(content)) !== null) { + thinkParts.push(match[1].trim()); + } + if (thinkParts.length > 0) { + // 移除所有 ... 标签(含未闭合的) + cleanContent = content.replace(/[\s\S]*?(?:<\/think>|$)/g, '').trim(); + return { displayContent: cleanContent, parsedThinking: thinkParts.join('\n\n') }; + } + return { displayContent: content, parsedThinking: '' }; + }, [msg.content, msg.thinking]); const isTypingThinking = !!(msg.loading && msg.phase === 'thinking'); if (msg.role === 'tool') return null; @@ -568,11 +589,11 @@ export const AIMessageBubble: React.FC = React.memo(({ msg
{/* 即使在波纹过渡态,如果有 thinking / tool_calls 也要显示出来,只是把它们压在波纹下面 */} -
0) ? 12 : 0 }}> - {!isUser && msg.thinking && ( +
0) ? 12 : 0 }}> + {!isUser && parsedThinking && ( = React.memo(({ msg
)} {/* 可折叠思考过程 */} - {!isUser && msg.thinking && ( + {!isUser && parsedThinking && ( > = {}; + +function _debouncedPersistSession(sessionId: string) { + if (_persistTimers[sessionId]) clearTimeout(_persistTimers[sessionId]); + _persistTimers[sessionId] = setTimeout(() => { + delete _persistTimers[sessionId]; + const state = useStore.getState(); + const messages = state.aiChatHistory[sessionId]; + const sessionMeta = state.aiChatSessions.find(s => s.id === sessionId); + if (!messages && !sessionMeta) return; // session 已被删除,跳过 + const title = sessionMeta?.title || '新的对话'; + const updatedAt = sessionMeta?.updatedAt || Date.now(); + const messagesJSON = JSON.stringify(messages || []); + const Service = (window as any).go?.aiservice?.Service; + Service?.AISaveSession?.(sessionId, title, updatedAt, messagesJSON).catch((e: any) => { + console.error('[AI Session Persist] 持久化失败:', sessionId, e); + }); + }, 2000); +} + +/** 从后端加载会话列表(仅元数据,不含消息体) */ +export async function loadAISessionsFromBackend(): Promise<{ id: string; title: string; updatedAt: number }[]> { + const Service = (window as any).go?.aiservice?.Service; + if (!Service?.AIGetSessions) return []; + try { + const sessions = await Service.AIGetSessions(); + if (Array.isArray(sessions)) { + useStore.setState({ aiChatSessions: sessions }); + return sessions; + } + } catch (e) { + console.error('[AI Session] 加载会话列表失败:', e); + } + return []; +} + +/** 从后端加载指定会话的消息数据到内存 */ +export async function loadAISessionFromBackend(sessionId: string): Promise { + const state = useStore.getState(); + // 如果内存中已有消息,跳过重复加载 + if (state.aiChatHistory[sessionId]?.length > 0) return true; + + const Service = (window as any).go?.aiservice?.Service; + if (!Service?.AILoadSession) return false; + try { + const result = await Service.AILoadSession(sessionId); + if (result?.success) { + let messages = result.messages; + // messages 可能是 JSON string 或已解析的数组 + if (typeof messages === 'string') { + try { messages = JSON.parse(messages); } catch { messages = []; } + } + if (Array.isArray(messages)) { + useStore.setState((prev) => ({ + aiChatHistory: { ...prev.aiChatHistory, [sessionId]: messages }, + })); + return true; + } + } + } catch (e) { + console.error('[AI Session] 加载会话消息失败:', sessionId, e); + } + return false; +} + export const useStore = create()( persist( (set) => ({ @@ -986,99 +1054,123 @@ export const useStore = create()( // AI actions toggleAIPanel: () => set((state) => ({ aiPanelVisible: !state.aiPanelVisible })), setAIPanelVisible: (visible) => set({ aiPanelVisible: visible }), - addAIChatMessage: (sessionId, message) => set((state) => { - const history = { ...state.aiChatHistory }; - const messages = history[sessionId] || []; - history[sessionId] = [...messages, message]; - - let newSessions = [...state.aiChatSessions]; - const existingSession = newSessions.find(s => s.id === sessionId); - - if (!existingSession) { - // 生成标题(首个 user message 内容前 20 字符) - let title = message.role === 'user' ? message.content : '新的对话'; - if (title.length > 20) { - title = title.substring(0, 20) + '...'; - } - newSessions.unshift({ id: sessionId, title, updatedAt: Date.now() }); - } else { - // 提至最新 - newSessions = newSessions.filter(s => s.id !== sessionId); - newSessions.unshift({ ...existingSession, updatedAt: Date.now() }); - } - - return { aiChatHistory: history, aiChatSessions: newSessions }; - }), - updateAIChatMessage: (sessionId, messageId, updates) => set((state) => { - const messages = state.aiChatHistory[sessionId]; - if (!messages) return state; - // 🔧 性能优化:用 findIndex + 定点替换代替全量 map,长对话场景下从 O(n) 降至 O(1) - const idx = messages.findIndex(m => m.id === messageId); - if (idx < 0) return state; - const newMessages = [...messages]; - newMessages[idx] = { ...newMessages[idx], ...updates }; - const history = { ...state.aiChatHistory, [sessionId]: newMessages }; - // 仅当非纯 content 追加时才重排 session 顺序(性能优化:流式打字时跳过) - const isContentOnlyUpdate = Object.keys(updates).length === 1 && 'content' in updates; - if (!isContentOnlyUpdate) { - let newSessions = [...state.aiChatSessions]; - const existingSession = newSessions.find(s => s.id === sessionId); - if (existingSession) { - newSessions = newSessions.filter(s => s.id !== sessionId); - newSessions.unshift({ ...existingSession, updatedAt: Date.now() }); - } - return { aiChatHistory: history, aiChatSessions: newSessions }; - } - return { aiChatHistory: history }; - }), - deleteAIChatMessage: (sessionId, messageId) => set((state) => { - const history = { ...state.aiChatHistory }; - if (history[sessionId]) { - history[sessionId] = history[sessionId].filter(m => m.id !== messageId); - } - return { aiChatHistory: history }; - }), - truncateAIChatMessages: (sessionId, upToMessageId) => set((state) => { - const history = { ...state.aiChatHistory }; - const messages = history[sessionId]; - if (messages) { - const idx = messages.findIndex(m => m.id === upToMessageId); - if (idx >= 0) { - history[sessionId] = messages.slice(0, idx + 1); - } - } - return { aiChatHistory: history }; - }), - clearAIChatHistory: (sessionId) => set((state) => { - const history = { ...state.aiChatHistory }; - delete history[sessionId]; - return { aiChatHistory: history }; - }), - replaceAIChatHistory: (sessionId, messages) => set((state) => { - const history = { ...state.aiChatHistory }; - history[sessionId] = messages; - return { aiChatHistory: history }; - }), - deleteAISession: (sessionId) => set((state) => { - const history = { ...state.aiChatHistory }; - delete history[sessionId]; - const newSessions = state.aiChatSessions.filter(s => s.id !== sessionId); - const newActive = state.aiActiveSessionId === sessionId ? null : state.aiActiveSessionId; - return { aiChatHistory: history, aiChatSessions: newSessions, aiActiveSessionId: newActive }; - }), + addAIChatMessage: (sessionId, message) => { + set((state) => { + const history = { ...state.aiChatHistory }; + const messages = history[sessionId] || []; + history[sessionId] = [...messages, message]; + + let newSessions = [...state.aiChatSessions]; + const existingSession = newSessions.find(s => s.id === sessionId); + + if (!existingSession) { + let title = message.role === 'user' ? message.content : '新的对话'; + if (title.length > 20) { + title = title.substring(0, 20) + '...'; + } + newSessions.unshift({ id: sessionId, title, updatedAt: Date.now() }); + } else { + newSessions = newSessions.filter(s => s.id !== sessionId); + newSessions.unshift({ ...existingSession, updatedAt: Date.now() }); + } + + return { aiChatHistory: history, aiChatSessions: newSessions }; + }); + // 异步持久化到文件(fire-and-forget,防抖由外层控制) + _debouncedPersistSession(sessionId); + }, + updateAIChatMessage: (sessionId, messageId, updates) => { + set((state) => { + const messages = state.aiChatHistory[sessionId]; + if (!messages) return state; + const idx = messages.findIndex(m => m.id === messageId); + if (idx < 0) return state; + const newMessages = [...messages]; + newMessages[idx] = { ...newMessages[idx], ...updates }; + const history = { ...state.aiChatHistory, [sessionId]: newMessages }; + const isContentOnlyUpdate = Object.keys(updates).length === 1 && 'content' in updates; + if (!isContentOnlyUpdate) { + let newSessions = [...state.aiChatSessions]; + const existingSession = newSessions.find(s => s.id === sessionId); + if (existingSession) { + newSessions = newSessions.filter(s => s.id !== sessionId); + newSessions.unshift({ ...existingSession, updatedAt: Date.now() }); + } + return { aiChatHistory: history, aiChatSessions: newSessions }; + } + return { aiChatHistory: history }; + }); + // 流式打字高频调用,防抖 2 秒后才写磁盘 + _debouncedPersistSession(sessionId); + }, + deleteAIChatMessage: (sessionId, messageId) => { + set((state) => { + const history = { ...state.aiChatHistory }; + if (history[sessionId]) { + history[sessionId] = history[sessionId].filter(m => m.id !== messageId); + } + return { aiChatHistory: history }; + }); + _debouncedPersistSession(sessionId); + }, + truncateAIChatMessages: (sessionId, upToMessageId) => { + set((state) => { + const history = { ...state.aiChatHistory }; + const messages = history[sessionId]; + if (messages) { + const idx = messages.findIndex(m => m.id === upToMessageId); + if (idx >= 0) { + history[sessionId] = messages.slice(0, idx + 1); + } + } + return { aiChatHistory: history }; + }); + _debouncedPersistSession(sessionId); + }, + clearAIChatHistory: (sessionId) => { + set((state) => { + const history = { ...state.aiChatHistory }; + delete history[sessionId]; + return { aiChatHistory: history }; + }); + _debouncedPersistSession(sessionId); + }, + replaceAIChatHistory: (sessionId, messages) => { + set((state) => { + const history = { ...state.aiChatHistory }; + history[sessionId] = messages; + return { aiChatHistory: history }; + }); + _debouncedPersistSession(sessionId); + }, + deleteAISession: (sessionId) => { + set((state) => { + const history = { ...state.aiChatHistory }; + delete history[sessionId]; + const newSessions = state.aiChatSessions.filter(s => s.id !== sessionId); + const newActive = state.aiActiveSessionId === sessionId ? null : state.aiActiveSessionId; + return { aiChatHistory: history, aiChatSessions: newSessions, aiActiveSessionId: newActive }; + }); + // 删除文件 + const Service = (window as any).go?.aiservice?.Service; + Service?.AIDeleteSession?.(sessionId).catch(() => {}); + }, createNewAISession: () => set(() => { const newId = `session-${Date.now()}`; return { aiActiveSessionId: newId }; }), setAIActiveSessionId: (sessionId) => set({ aiActiveSessionId: sessionId }), - updateAISessionTitle: (sessionId, title) => set((state) => { - const newSessions = [...state.aiChatSessions]; - const session = newSessions.find(s => s.id === sessionId); - if (session) { - session.title = title; - } - return { aiChatSessions: newSessions }; - }), + updateAISessionTitle: (sessionId, title) => { + set((state) => { + const newSessions = [...state.aiChatSessions]; + const session = newSessions.find(s => s.id === sessionId); + if (session) { + session.title = title; + } + return { aiChatSessions: newSessions }; + }); + _debouncedPersistSession(sessionId); + }, addAIContext: (connectionKey, context) => set((state) => { const contexts = state.aiContexts[connectionKey] || []; if (contexts.find(c => c.dbName === context.dbName && c.tableName === context.tableName)) { @@ -1173,8 +1265,9 @@ export const useStore = create()( shortcutOptions: sanitizeShortcutOptions(state.shortcutOptions), tableAccessCount: sanitizeTableAccessCount(state.tableAccessCount), - aiChatHistory: (state.aiChatHistory && typeof state.aiChatHistory === 'object') ? state.aiChatHistory : {}, - aiChatSessions: Array.isArray(state.aiChatSessions) ? state.aiChatSessions : [], + // AI 会话数据不再从 localStorage 恢复,改为从后端文件加载 + aiChatHistory: {}, + aiChatSessions: [], }; }, partialize: (state) => ({ @@ -1200,17 +1293,7 @@ export const useStore = create()( windowState: state.windowState, sidebarWidth: state.sidebarWidth, - // 只持久化最近 20 个会话的聊天记录,防止 localStorage 膨胀 - aiChatHistory: (() => { - const MAX_PERSIST_SESSIONS = 20; - const recentIds = new Set(state.aiChatSessions.slice(0, MAX_PERSIST_SESSIONS).map(s => s.id)); - const trimmed: Record = {}; - for (const id of recentIds) { - if (state.aiChatHistory[id]) trimmed[id] = state.aiChatHistory[id]; - } - return trimmed; - })(), - aiChatSessions: state.aiChatSessions.slice(0, 50), + // AI 会话数据已迁移到后端文件持久化(~/.gonavi/sessions/),不再写入 localStorage }), // Don't persist logs } ) diff --git a/frontend/wailsjs/go/aiservice/Service.d.ts b/frontend/wailsjs/go/aiservice/Service.d.ts index 6ffc07a..52dc1a6 100755 --- a/frontend/wailsjs/go/aiservice/Service.d.ts +++ b/frontend/wailsjs/go/aiservice/Service.d.ts @@ -13,6 +13,8 @@ export function AICheckSQL(arg1:string):Promise; export function AIDeleteProvider(arg1:string):Promise; +export function AIDeleteSession(arg1:string):Promise; + export function AIGetActiveProvider():Promise; export function AIGetBuiltinPrompts():Promise>; @@ -23,10 +25,16 @@ export function AIGetProviders():Promise>; export function AIGetSafetyLevel():Promise; +export function AIGetSessions():Promise>>; + export function AIListModels():Promise>; +export function AILoadSession(arg1:string):Promise>; + export function AISaveProvider(arg1:ai.ProviderConfig):Promise; +export function AISaveSession(arg1:string,arg2:string,arg3:number,arg4:string):Promise; + export function AISetActiveProvider(arg1:string):Promise; export function AISetContextLevel(arg1:string):Promise; diff --git a/frontend/wailsjs/go/aiservice/Service.js b/frontend/wailsjs/go/aiservice/Service.js index 7f5de4a..acebb37 100755 --- a/frontend/wailsjs/go/aiservice/Service.js +++ b/frontend/wailsjs/go/aiservice/Service.js @@ -22,6 +22,10 @@ export function AIDeleteProvider(arg1) { return window['go']['aiservice']['Service']['AIDeleteProvider'](arg1); } +export function AIDeleteSession(arg1) { + return window['go']['aiservice']['Service']['AIDeleteSession'](arg1); +} + export function AIGetActiveProvider() { return window['go']['aiservice']['Service']['AIGetActiveProvider'](); } @@ -42,14 +46,26 @@ export function AIGetSafetyLevel() { return window['go']['aiservice']['Service']['AIGetSafetyLevel'](); } +export function AIGetSessions() { + return window['go']['aiservice']['Service']['AIGetSessions'](); +} + export function AIListModels() { return window['go']['aiservice']['Service']['AIListModels'](); } +export function AILoadSession(arg1) { + return window['go']['aiservice']['Service']['AILoadSession'](arg1); +} + export function AISaveProvider(arg1) { return window['go']['aiservice']['Service']['AISaveProvider'](arg1); } +export function AISaveSession(arg1, arg2, arg3, arg4) { + return window['go']['aiservice']['Service']['AISaveSession'](arg1, arg2, arg3, arg4); +} + export function AISetActiveProvider(arg1) { return window['go']['aiservice']['Service']['AISetActiveProvider'](arg1); } diff --git a/internal/ai/provider/anthropic.go b/internal/ai/provider/anthropic.go index 1104680..1733222 100644 --- a/internal/ai/provider/anthropic.go +++ b/internal/ai/provider/anthropic.go @@ -15,10 +15,23 @@ import ( const ( defaultAnthropicBaseURL = "https://api.anthropic.com" - defaultAnthropicModel = "claude-3-5-sonnet-20241022" anthropicAPIVersion = "2023-06-01" ) +func normalizeAnthropicMessagesURL(baseURL string) string { + url := strings.TrimRight(strings.TrimSpace(baseURL), "/") + if url == "" { + url = defaultAnthropicBaseURL + } + if strings.HasSuffix(url, "/messages") { + return url + } + if strings.HasSuffix(url, "/v1") { + return url + "/messages" + } + return url + "/v1/messages" +} + // AnthropicProvider 实现 Anthropic Claude API 的 Provider type AnthropicProvider struct { config ai.ProviderConfig @@ -34,7 +47,7 @@ func NewAnthropicProvider(config ai.ProviderConfig) (Provider, error) { } model := strings.TrimSpace(config.Model) if model == "" { - model = defaultAnthropicModel + return nil, fmt.Errorf("模型 ID 不能为空,请在设置中选择或输入模型") } maxTokens := config.MaxTokens if maxTokens <= 0 { @@ -425,10 +438,7 @@ func (p *AnthropicProvider) doRequest(ctx context.Context, body interface{}) (io return nil, fmt.Errorf("序列化请求失败: %w", err) } - url := p.baseURL + "/v1/messages" - if strings.HasSuffix(p.baseURL, "/v1") { - url = p.baseURL + "/messages" - } + url := normalizeAnthropicMessagesURL(p.baseURL) httpReq, err := http.NewRequestWithContext(ctx, http.MethodPost, url, bytes.NewReader(jsonBody)) if err != nil { diff --git a/internal/ai/provider/anthropic_test.go b/internal/ai/provider/anthropic_test.go new file mode 100644 index 0000000..3c119f2 --- /dev/null +++ b/internal/ai/provider/anthropic_test.go @@ -0,0 +1,24 @@ +package provider + +import "testing" + +func TestNormalizeAnthropicMessagesURL_AppendsMessagesSuffix(t *testing.T) { + url := normalizeAnthropicMessagesURL("https://api.anthropic.com") + if url != "https://api.anthropic.com/v1/messages" { + t.Fatalf("expected normalized anthropic messages url, got %q", url) + } +} + +func TestNormalizeAnthropicMessagesURL_UsesMoonshotAnthropicMessagesEndpoint(t *testing.T) { + url := normalizeAnthropicMessagesURL("https://api.moonshot.cn/anthropic") + if url != "https://api.moonshot.cn/anthropic/v1/messages" { + t.Fatalf("expected moonshot anthropic messages url, got %q", url) + } +} + +func TestNormalizeAnthropicMessagesURL_PreservesExplicitMessagesPath(t *testing.T) { + url := normalizeAnthropicMessagesURL("https://api.moonshot.cn/anthropic/v1/messages") + if url != "https://api.moonshot.cn/anthropic/v1/messages" { + t.Fatalf("expected explicit messages path to be preserved, got %q", url) + } +} diff --git a/internal/ai/provider/gemini.go b/internal/ai/provider/gemini.go index b4cf910..ebadde9 100644 --- a/internal/ai/provider/gemini.go +++ b/internal/ai/provider/gemini.go @@ -15,7 +15,6 @@ import ( const ( defaultGeminiBaseURL = "https://generativelanguage.googleapis.com" - defaultGeminiModel = "gemini-2.0-flash" ) // GeminiProvider 实现 Google Gemini API 的 Provider @@ -33,7 +32,7 @@ func NewGeminiProvider(config ai.ProviderConfig) (Provider, error) { } model := strings.TrimSpace(config.Model) if model == "" { - model = defaultGeminiModel + return nil, fmt.Errorf("模型 ID 不能为空,请在设置中选择或输入模型") } maxTokens := config.MaxTokens if maxTokens <= 0 { diff --git a/internal/ai/provider/openai.go b/internal/ai/provider/openai.go index 5ff9f10..1e86caf 100644 --- a/internal/ai/provider/openai.go +++ b/internal/ai/provider/openai.go @@ -16,7 +16,6 @@ import ( const ( defaultOpenAIBaseURL = "https://api.openai.com/v1" - defaultOpenAIModel = "gpt-4o" defaultOpenAIMaxTokens = 4096 defaultOpenAITemperature = 0.7 openAIHTTPTimeout = 120 * time.Second @@ -41,7 +40,7 @@ func NewOpenAIProvider(config ai.ProviderConfig) (Provider, error) { } model := strings.TrimSpace(config.Model) if model == "" { - model = defaultOpenAIModel + return nil, fmt.Errorf("模型 ID 不能为空,请在设置中选择或输入模型") } maxTokens := config.MaxTokens if maxTokens <= 0 { diff --git a/internal/ai/provider/openai_test.go b/internal/ai/provider/openai_test.go index 94671a4..c200178 100644 --- a/internal/ai/provider/openai_test.go +++ b/internal/ai/provider/openai_test.go @@ -28,18 +28,24 @@ func TestOpenAIProvider_Validate_Valid(t *testing.T) { } func TestOpenAIProvider_Name_Custom(t *testing.T) { - p, _ := NewOpenAIProvider(ai.ProviderConfig{ - Type: "openai", Name: "My OpenAI", APIKey: "sk-test", + p, err := NewOpenAIProvider(ai.ProviderConfig{ + Type: "openai", Name: "My OpenAI", APIKey: "sk-test", Model: "gpt-4o", }) + if err != nil { + t.Fatalf("unexpected constructor error: %v", err) + } if p.Name() != "My OpenAI" { t.Fatalf("expected name 'My OpenAI', got '%s'", p.Name()) } } func TestOpenAIProvider_Name_Default(t *testing.T) { - p, _ := NewOpenAIProvider(ai.ProviderConfig{ - Type: "openai", APIKey: "sk-test", + p, err := NewOpenAIProvider(ai.ProviderConfig{ + Type: "openai", APIKey: "sk-test", Model: "gpt-4o", }) + if err != nil { + t.Fatalf("unexpected constructor error: %v", err) + } if p.Name() != "OpenAI" { t.Fatalf("expected default name 'OpenAI', got '%s'", p.Name()) } @@ -56,29 +62,34 @@ func TestOpenAIProvider_DefaultBaseURL(t *testing.T) { } func TestOpenAIProvider_CustomBaseURL(t *testing.T) { - p, _ := NewOpenAIProvider(ai.ProviderConfig{ - Type: "openai", APIKey: "sk-test", BaseURL: "https://my-proxy.com/v1", + p, err := NewOpenAIProvider(ai.ProviderConfig{ + Type: "openai", APIKey: "sk-test", BaseURL: "https://my-proxy.com/v1", Model: "gpt-4o", }) + if err != nil { + t.Fatalf("unexpected constructor error: %v", err) + } op := p.(*OpenAIProvider) if op.baseURL != "https://my-proxy.com/v1" { t.Fatalf("expected custom base URL, got '%s'", op.baseURL) } } -func TestOpenAIProvider_DefaultModel(t *testing.T) { - p, _ := NewOpenAIProvider(ai.ProviderConfig{ +func TestOpenAIProvider_RejectsMissingModel(t *testing.T) { + _, err := NewOpenAIProvider(ai.ProviderConfig{ Type: "openai", APIKey: "sk-test", }) - op := p.(*OpenAIProvider) - if op.config.Model != "gpt-4o" { - t.Fatalf("expected default model 'gpt-4o', got '%s'", op.config.Model) + if err == nil { + t.Fatal("expected constructor error for missing model") } } func TestOpenAIProvider_DefaultMaxTokens(t *testing.T) { - p, _ := NewOpenAIProvider(ai.ProviderConfig{ - Type: "openai", APIKey: "sk-test", + p, err := NewOpenAIProvider(ai.ProviderConfig{ + Type: "openai", APIKey: "sk-test", Model: "gpt-4o", }) + if err != nil { + t.Fatalf("unexpected constructor error: %v", err) + } op := p.(*OpenAIProvider) if op.config.MaxTokens != 4096 { t.Fatalf("expected default max tokens 4096, got %d", op.config.MaxTokens) diff --git a/internal/ai/service/service.go b/internal/ai/service/service.go index addab29..bf21638 100644 --- a/internal/ai/service/service.go +++ b/internal/ai/service/service.go @@ -35,6 +35,16 @@ type Service struct { cancelFuncs map[string]context.CancelFunc // 记录每个 session 的 context 取消函数 } +var miniMaxAnthropicModels = []string{ + "MiniMax-M2.7", + "MiniMax-M2.7-highspeed", + "MiniMax-M2.5", + "MiniMax-M2.5-highspeed", + "MiniMax-M2.1", + "MiniMax-M2.1-highspeed", + "MiniMax-M2", +} + // NewService 创建 AI Service 实例 func NewService() *Service { return &Service{ @@ -63,6 +73,9 @@ func (s *Service) AIGetProviders() []ai.ProviderConfig { result := make([]ai.ProviderConfig, len(s.providers)) copy(result, s.providers) + for i := range result { + result[i] = normalizeProviderConfig(result[i]) + } return result } @@ -72,6 +85,8 @@ func (s *Service) AISaveProvider(config ai.ProviderConfig) error { s.mu.Lock() defer s.mu.Unlock() + config = normalizeProviderConfig(config) + if strings.TrimSpace(config.ID) == "" { config.ID = "provider-" + uuid.New().String()[:8] } @@ -128,65 +143,36 @@ func (s *Service) AITestProvider(config ai.ProviderConfig) map[string]interface{ } s.mu.RUnlock() + config = normalizeProviderConfig(config) baseURL := strings.TrimRight(strings.TrimSpace(config.BaseURL), "/") - providerType := config.Type - if providerType == "custom" && config.APIFormat != "" { - providerType = config.APIFormat - } + providerType := normalizedProviderType(config) client := &http.Client{Timeout: 10 * time.Second} var err error switch providerType { - case "openai": - if baseURL == "" { - baseURL = "https://api.openai.com/v1" - } - if !strings.HasSuffix(baseURL, "/v1") && !strings.Contains(baseURL, "/v1/") { - baseURL = baseURL + "/v1" - } - // 使用 /models 端点验证连通性和鉴权 - req, _ := http.NewRequest("GET", baseURL+"/models", nil) - req.Header.Set("Authorization", "Bearer "+config.APIKey) - for k, v := range config.Headers { - req.Header.Set(k, v) + case "openai", "anthropic", "gemini": + req, reqErr := newProviderHealthCheckRequest(config) + if reqErr != nil { + err = reqErr + break } resp, reqErr := client.Do(req) if reqErr != nil { err = reqErr } else { defer resp.Body.Close() - if resp.StatusCode == http.StatusUnauthorized { - err = fmt.Errorf("API Key 验证失败 (HTTP %d)", resp.StatusCode) + if resp.StatusCode == http.StatusUnauthorized || resp.StatusCode == http.StatusForbidden { + err = fmt.Errorf("API Key 无效或请求错误 (HTTP %d)", resp.StatusCode) + } else if providerType == "gemini" && resp.StatusCode == http.StatusBadRequest { + err = fmt.Errorf("API Key 无效或请求错误 (HTTP %d)", resp.StatusCode) + } else if resp.StatusCode >= 400 { + body, _ := io.ReadAll(io.LimitReader(resp.Body, 512)) + err = fmt.Errorf("接口返回异常 (HTTP %d): %s", resp.StatusCode, string(body)) } else if resp.StatusCode >= 500 { err = fmt.Errorf("上游服务器内部错误 (HTTP %d)", resp.StatusCode) } } - case "anthropic": - if baseURL == "" { - baseURL = "https://api.anthropic.com" - } - req, _ := http.NewRequest("GET", baseURL, nil) - resp, reqErr := client.Do(req) - if reqErr != nil { - err = reqErr - } else { - resp.Body.Close() - } - case "gemini": - if baseURL == "" { - baseURL = "https://generativelanguage.googleapis.com" - } - req, _ := http.NewRequest("GET", baseURL+"/v1beta/models?key="+config.APIKey, nil) - resp, reqErr := client.Do(req) - if reqErr != nil { - err = reqErr - } else { - defer resp.Body.Close() - if resp.StatusCode == http.StatusUnauthorized || resp.StatusCode == http.StatusBadRequest { - err = fmt.Errorf("API Key 无效或请求错误 (HTTP %d)", resp.StatusCode) - } - } default: if baseURL != "" { req, _ := http.NewRequest("GET", baseURL, nil) @@ -209,6 +195,153 @@ func (s *Service) AITestProvider(config ai.ProviderConfig) map[string]interface{ } } +func normalizedProviderType(config ai.ProviderConfig) string { + providerType := strings.ToLower(strings.TrimSpace(config.Type)) + if providerType == "custom" && strings.TrimSpace(config.APIFormat) != "" { + return strings.ToLower(strings.TrimSpace(config.APIFormat)) + } + return providerType +} + +func isMiniMaxAnthropicProvider(config ai.ProviderConfig) bool { + if normalizedProviderType(config) != "anthropic" { + return false + } + baseURL := strings.ToLower(strings.TrimRight(strings.TrimSpace(config.BaseURL), "/")) + return strings.Contains(baseURL, "api.minimax.io") || strings.Contains(baseURL, "api.minimaxi.com") +} + +func isMoonshotAnthropicProvider(config ai.ProviderConfig) bool { + if normalizedProviderType(config) != "anthropic" { + return false + } + baseURL := strings.ToLower(strings.TrimRight(strings.TrimSpace(config.BaseURL), "/")) + return strings.Contains(baseURL, "api.moonshot.cn") +} + +func defaultStaticModelsForProvider(config ai.ProviderConfig) []string { + if isMiniMaxAnthropicProvider(config) { + return append([]string(nil), miniMaxAnthropicModels...) + } + return nil +} + +func normalizeProviderConfig(config ai.ProviderConfig) ai.ProviderConfig { + staticModels := defaultStaticModelsForProvider(config) + if len(staticModels) > 0 && len(config.Models) == 0 { + config.Models = staticModels + } + model := strings.TrimSpace(config.Model) + if isMiniMaxAnthropicProvider(config) && (model == "" || strings.HasPrefix(strings.ToLower(model), "minimax-text-")) { + config.Model = miniMaxAnthropicModels[0] + } + return config +} + +func resolveModelsURL(config ai.ProviderConfig) string { + config = normalizeProviderConfig(config) + providerType := normalizedProviderType(config) + baseURL := strings.TrimRight(strings.TrimSpace(config.BaseURL), "/") + + switch providerType { + case "anthropic": + if isMoonshotAnthropicProvider(config) { + return "https://api.moonshot.cn/v1/models" + } + if baseURL == "" { + baseURL = "https://api.anthropic.com" + } + if !strings.HasSuffix(baseURL, "/v1") && !strings.Contains(baseURL, "/v1/") { + baseURL = baseURL + "/v1" + } + return baseURL + "/models" + case "gemini": + if baseURL == "" { + baseURL = "https://generativelanguage.googleapis.com" + } + return baseURL + "/v1beta/models?key=" + config.APIKey + case "openai": + fallthrough + default: + if baseURL == "" { + baseURL = "https://api.openai.com/v1" + } + if !strings.HasSuffix(baseURL, "/v1") && !strings.Contains(baseURL, "/v1/") { + baseURL = baseURL + "/v1" + } + return baseURL + "/models" + } +} + +func newModelsRequest(config ai.ProviderConfig) (*http.Request, error) { + config = normalizeProviderConfig(config) + url := resolveModelsURL(config) + req, err := http.NewRequest("GET", url, nil) + if err != nil { + return nil, fmt.Errorf("创建请求失败: %w", err) + } + + switch normalizedProviderType(config) { + case "anthropic": + req.Header.Set("x-api-key", config.APIKey) + req.Header.Set("anthropic-version", "2023-06-01") + req.Header.Set("Authorization", "Bearer "+config.APIKey) + case "gemini": + // Gemini 使用 query string 传递 key,无需额外鉴权头 + default: + req.Header.Set("Authorization", "Bearer "+config.APIKey) + } + + for k, v := range config.Headers { + req.Header.Set(k, v) + } + + return req, nil +} + +func resolveAnthropicMessagesURL(baseURL string) string { + url := strings.TrimRight(strings.TrimSpace(baseURL), "/") + if url == "" { + url = "https://api.anthropic.com" + } + if strings.HasSuffix(url, "/messages") { + return url + } + if strings.HasSuffix(url, "/v1") { + return url + "/messages" + } + return url + "/v1/messages" +} + +func newProviderHealthCheckRequest(config ai.ProviderConfig) (*http.Request, error) { + config = normalizeProviderConfig(config) + if isMiniMaxAnthropicProvider(config) { + body := map[string]interface{}{ + "model": config.Model, + "max_tokens": 1, + "messages": []map[string]string{ + {"role": "user", "content": "ping"}, + }, + } + bodyBytes, err := json.Marshal(body) + if err != nil { + return nil, fmt.Errorf("序列化请求失败: %w", err) + } + req, err := http.NewRequest("POST", resolveAnthropicMessagesURL(config.BaseURL), strings.NewReader(string(bodyBytes))) + if err != nil { + return nil, fmt.Errorf("创建请求失败: %w", err) + } + req.Header.Set("Content-Type", "application/json") + req.Header.Set("x-api-key", config.APIKey) + req.Header.Set("anthropic-version", "2023-06-01") + for k, v := range config.Headers { + req.Header.Set(k, v) + } + return req, nil + } + return newModelsRequest(config) +} + // AISetActiveProvider 设置活动 Provider func (s *Service) AISetActiveProvider(id string) { s.mu.Lock() @@ -261,17 +394,16 @@ func (s *Service) AIListModels() map[string]interface{} { // fetchModels 从供应商 API 获取可用模型列表 func fetchModels(config ai.ProviderConfig) ([]string, error) { - providerType := config.Type - if providerType == "custom" && config.APIFormat != "" { - providerType = config.APIFormat + providerType := normalizedProviderType(config) + if staticModels := defaultStaticModelsForProvider(config); len(staticModels) > 0 { + return staticModels, nil } switch providerType { case "openai": return fetchOpenAIModels(config) case "anthropic": - // Anthropic 没有公开的 /models 端点,返回硬编码列表 - return []string{"claude-opus-4-6", "claude-sonnet-4-6"}, nil + return fetchAnthropicModels(config) case "gemini": return fetchGeminiModels(config) default: @@ -281,20 +413,45 @@ func fetchModels(config ai.ProviderConfig) ([]string, error) { // fetchOpenAIModels 获取 OpenAI 兼容 API 的模型列表 func fetchOpenAIModels(config ai.ProviderConfig) ([]string, error) { - baseURL := strings.TrimRight(strings.TrimSpace(config.BaseURL), "/") - if baseURL == "" { - baseURL = "https://api.openai.com/v1" - } - // 确保 baseURL 以 /v1 结尾 - if !strings.HasSuffix(baseURL, "/v1") { - baseURL = baseURL + "/v1" + req, err := newModelsRequest(config) + if err != nil { + return nil, err } - req, err := http.NewRequest("GET", baseURL+"/models", nil) + client := &http.Client{Timeout: 15 * time.Second} + resp, err := client.Do(req) if err != nil { - return nil, fmt.Errorf("创建请求失败: %w", err) + return nil, fmt.Errorf("请求模型列表失败: %w", err) + } + defer resp.Body.Close() + + if resp.StatusCode != http.StatusOK { + body, _ := io.ReadAll(io.LimitReader(resp.Body, 1024)) + return nil, fmt.Errorf("获取模型列表失败 (HTTP %d): %s", resp.StatusCode, string(body)) + } + + var result struct { + Data []struct { + ID string `json:"id"` + } `json:"data"` + } + if err := json.NewDecoder(resp.Body).Decode(&result); err != nil { + return nil, fmt.Errorf("解析模型列表失败: %w", err) + } + + models := make([]string, 0, len(result.Data)) + for _, m := range result.Data { + models = append(models, m.ID) + } + return models, nil +} + +// fetchAnthropicModels 获取 Anthropic API 的模型列表 +func fetchAnthropicModels(config ai.ProviderConfig) ([]string, error) { + req, err := newModelsRequest(config) + if err != nil { + return nil, err } - req.Header.Set("Authorization", "Bearer "+config.APIKey) client := &http.Client{Timeout: 15 * time.Second} resp, err := client.Do(req) @@ -515,7 +672,7 @@ func (s *Service) getActiveProvider() (provider.Provider, error) { for _, cfg := range s.providers { if cfg.ID == s.activeProvider { - return provider.NewProvider(cfg) + return provider.NewProvider(normalizeProviderConfig(cfg)) } } @@ -548,6 +705,9 @@ func (s *Service) loadConfig() { if s.providers == nil { s.providers = make([]ai.ProviderConfig, 0) } + for i := range s.providers { + s.providers[i] = normalizeProviderConfig(s.providers[i]) + } s.activeProvider = cfg.ActiveProvider switch ai.SQLPermissionLevel(cfg.SafetyLevel) { @@ -591,14 +751,122 @@ func (s *Service) saveConfig() error { return nil } +// --- 会话文件持久化 --- + +// sessionFileData 会话文件的 JSON 结构 +type sessionFileData struct { + ID string `json:"id"` + Title string `json:"title"` + UpdatedAt int64 `json:"updatedAt"` + Messages json.RawMessage `json:"messages"` // 透传前端格式,后端不解析消息体 +} + +func (s *Service) sessionsDir() string { + return filepath.Join(s.configDir, "sessions") +} + +// AIGetSessions 获取所有会话的元数据列表(不含消息体) +func (s *Service) AIGetSessions() []map[string]interface{} { + dir := s.sessionsDir() + entries, err := os.ReadDir(dir) + if err != nil { + return []map[string]interface{}{} + } + + var sessions []map[string]interface{} + for _, entry := range entries { + if entry.IsDir() || !strings.HasSuffix(entry.Name(), ".json") { + continue + } + data, err := os.ReadFile(filepath.Join(dir, entry.Name())) + if err != nil { + continue + } + var sfd sessionFileData + if err := json.Unmarshal(data, &sfd); err != nil { + continue + } + sessions = append(sessions, map[string]interface{}{ + "id": sfd.ID, + "title": sfd.Title, + "updatedAt": sfd.UpdatedAt, + }) + } + + // 按 updatedAt 降序排列 + for i := 0; i < len(sessions); i++ { + for j := i + 1; j < len(sessions); j++ { + ti, _ := sessions[i]["updatedAt"].(int64) + tj, _ := sessions[j]["updatedAt"].(int64) + if tj > ti { + sessions[i], sessions[j] = sessions[j], sessions[i] + } + } + } + + return sessions +} + +// AILoadSession 加载指定会话的完整数据(含消息) +func (s *Service) AILoadSession(sessionID string) map[string]interface{} { + path := filepath.Join(s.sessionsDir(), sessionID+".json") + data, err := os.ReadFile(path) + if err != nil { + return map[string]interface{}{"success": false, "error": "会话不存在"} + } + var sfd sessionFileData + if err := json.Unmarshal(data, &sfd); err != nil { + return map[string]interface{}{"success": false, "error": "会话数据损坏"} + } + return map[string]interface{}{ + "success": true, + "id": sfd.ID, + "title": sfd.Title, + "updatedAt": sfd.UpdatedAt, + "messages": sfd.Messages, + } +} + +// AISaveSession 保存会话数据到文件 +func (s *Service) AISaveSession(sessionID string, title string, updatedAt float64, messagesJSON string) error { + dir := s.sessionsDir() + if err := os.MkdirAll(dir, 0o755); err != nil { + return fmt.Errorf("创建 sessions 目录失败: %w", err) + } + + sfd := sessionFileData{ + ID: sessionID, + Title: title, + UpdatedAt: int64(updatedAt), + Messages: json.RawMessage(messagesJSON), + } + + data, err := json.MarshalIndent(sfd, "", " ") + if err != nil { + return fmt.Errorf("序列化会话数据失败: %w", err) + } + + path := filepath.Join(dir, sessionID+".json") + return os.WriteFile(path, data, 0o644) +} + +// AIDeleteSession 删除会话文件 +func (s *Service) AIDeleteSession(sessionID string) error { + path := filepath.Join(s.sessionsDir(), sessionID+".json") + if err := os.Remove(path); err != nil && !os.IsNotExist(err) { + return fmt.Errorf("删除会话失败: %w", err) + } + return nil +} + // --- 工具函数 --- func resolveConfigDir() string { - configDir, err := os.UserConfigDir() + homeDir, err := os.UserHomeDir() if err != nil { - configDir = "." + homeDir = "." } - return filepath.Join(configDir, "GoNavi") + return filepath.Join(homeDir, ".gonavi") } func maskAPIKey(apiKey string) string { diff --git a/internal/ai/service/service_test.go b/internal/ai/service/service_test.go new file mode 100644 index 0000000..6574b5a --- /dev/null +++ b/internal/ai/service/service_test.go @@ -0,0 +1,78 @@ +package aiservice + +import ( + "reflect" + "testing" + + "GoNavi-Wails/internal/ai" +) + +func TestResolveModelsURL_UsesMoonshotOpenAIModelsEndpointForKimiAnthropicBaseURL(t *testing.T) { + url := resolveModelsURL(ai.ProviderConfig{ + Type: "anthropic", + BaseURL: "https://api.moonshot.cn/anthropic", + }) + if url != "https://api.moonshot.cn/v1/models" { + t.Fatalf("expected moonshot models endpoint, got %q", url) + } +} + +func TestResolveModelsURL_UsesAnthropicModelsEndpointForOfficialAnthropic(t *testing.T) { + url := resolveModelsURL(ai.ProviderConfig{ + Type: "anthropic", + BaseURL: "https://api.anthropic.com", + }) + if url != "https://api.anthropic.com/v1/models" { + t.Fatalf("expected anthropic models endpoint, got %q", url) + } +} + +func TestResolveModelsURL_UsesOpenAIModelsEndpointForOpenAICompatibleProvider(t *testing.T) { + url := resolveModelsURL(ai.ProviderConfig{ + Type: "openai", + BaseURL: "https://api.openai.com/v1", + }) + if url != "https://api.openai.com/v1/models" { + t.Fatalf("expected openai models endpoint, got %q", url) + } +} + +func TestDefaultStaticModelsForProvider_ReturnsMiniMaxAnthropicModels(t *testing.T) { + models := defaultStaticModelsForProvider(ai.ProviderConfig{ + Type: "anthropic", + BaseURL: "https://api.minimaxi.com/anthropic", + }) + expected := []string{ + "MiniMax-M2.7", + "MiniMax-M2.7-highspeed", + "MiniMax-M2.5", + "MiniMax-M2.5-highspeed", + "MiniMax-M2.1", + "MiniMax-M2.1-highspeed", + "MiniMax-M2", + } + if !reflect.DeepEqual(models, expected) { + t.Fatalf("expected MiniMax static models %v, got %v", expected, models) + } +} + +func TestNewProviderHealthCheckRequest_UsesMessagesEndpointForMiniMaxAnthropic(t *testing.T) { + req, err := newProviderHealthCheckRequest(ai.ProviderConfig{ + Type: "anthropic", + BaseURL: "https://api.minimaxi.com/anthropic", + Model: "MiniMax-M2.7", + APIKey: "sk-test", + }) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if req.Method != "POST" { + t.Fatalf("expected POST request, got %s", req.Method) + } + if req.URL.String() != "https://api.minimaxi.com/anthropic/v1/messages" { + t.Fatalf("expected MiniMax messages endpoint, got %q", req.URL.String()) + } + if got := req.Header.Get("x-api-key"); got != "sk-test" { + t.Fatalf("expected x-api-key header to be set, got %q", got) + } +}