diff --git a/cmd/gonavi-mcp-server/README.md b/cmd/gonavi-mcp-server/README.md new file mode 100644 index 0000000..253a1f7 --- /dev/null +++ b/cmd/gonavi-mcp-server/README.md @@ -0,0 +1,130 @@ +# GoNavi MCP Server + +`gonavi-mcp-server` 会把 GoNavi 已保存连接背后的数据库能力通过 MCP `stdio` 暴露给外部客户端。 + +## 当前提供的 tools + +- `get_connections` + - 返回 GoNavi 已保存连接的 `id/name/type/target/defaultDatabase` 等摘要信息 +- `get_databases` + - 入参:`connectionId` +- `get_tables` + - 入参:`connectionId`、可选 `dbName` +- `get_columns` + - 入参:`connectionId`、可选 `dbName`、`tableName` +- `get_table_ddl` + - 入参:`connectionId`、可选 `dbName`、`tableName` +- `execute_sql` + - 入参:`connectionId`、可选 `dbName`、`sql` + - 默认只允许只读 SQL + - 如果 SQL 包含 DDL/DML,必须显式传 `allowMutating=true` + - `maxRowsPerResult` 用来限制单个结果集返回的行数,默认 `200` + +## 运行方式 + +开发态直接运行: + +```powershell +go run ./cmd/gonavi-mcp-server +``` + +也可以先编译: + +```powershell +go build -o .\bin\gonavi-mcp-server.exe .\cmd\gonavi-mcp-server +``` + +## Claude Code / Codex + +正式安装包场景,推荐直接在 GoNavi 里使用“AI 设置 -> MCP 服务 -> 安装到 Claude Code / 安装到 Codex”。 + +它会自动把当前安装的 `GoNavi.exe` 写入 Claude Code 的用户级 `~/.claude.json`,命令形态类似: + +```json +{ + "mcpServers": { + "gonavi": { + "type": "stdio", + "command": "C:\\Program Files\\GoNavi\\GoNavi.exe", + "args": ["mcp-server"], + "env": {} + } + } +} +``` + +这样用户不需要自己找本机 `gonavi-mcp-server.exe` 路径,安装包本体就能直接作为 MCP 入口。 + +Codex 当前使用 `~/.codex/config.toml`,GoNavi 会写入类似下面这段: + +```toml +[mcp_servers.gonavi] +command = 'C:\Program Files\GoNavi\GoNavi.exe' +args = ['mcp-server'] +startup_timeout_sec = 60 +``` + +仓库开发态如果要在本机 `Claude Code CLI` 里稳定使用这个 MCP,仍然推荐走仓库内包装脚本: + +```powershell +.\tools\claude-gonavi-mcp.ps1 -p "必须调用 gonavi MCP 的 get_connections 工具" +``` + +或者: + +```cmd +tools\claude-gonavi-mcp.cmd -p "必须调用 gonavi MCP 的 get_connections 工具" +``` + +这个脚本会先构建 `bin\gonavi-mcp-server.exe`,再通过 `--mcp-config` 和 `--strict-mcp-config` 把 GoNavi MCP 单独注入当前 Claude 会话,避免默认混合 MCP 加载时序导致的首轮工具未挂载问题。 + +## MCP 客户端配置示例 + +开发态: + +```json +{ + "mcpServers": { + "gonavi": { + "command": "go", + "args": ["run", "./cmd/gonavi-mcp-server"] + } + } +} +``` + +Windows 独立 server 编译产物(开发态): + +```json +{ + "mcpServers": { + "gonavi": { + "command": "D:\\Work\\CodeRepos\\GoNavi\\bin\\gonavi-mcp-server.exe", + "args": [] + } + } +} +``` + +Windows 已安装 GoNavi(推荐给最终用户): + +```json +{ + "mcpServers": { + "gonavi": { + "type": "stdio", + "command": "C:\\Program Files\\GoNavi\\GoNavi.exe", + "args": ["mcp-server"], + "env": {} + } + } +} +``` + +## 使用说明 + +- 先调用 `get_connections`,拿到 `connectionId` +- 之后所有数据库工具都只传 `connectionId`,由 GoNavi 服务端内部解析保存连接和密钥 +- 如果 `dbName` 为空,会优先使用该保存连接里的默认数据库 +- Server 会读取 GoNavi 当前活动数据目录里的连接配置,并通过系统 keyring/凭据管理器解析密文 +- 如果本机凭据存储不可用,依赖密钥的连接会返回对应错误 diff --git a/cmd/gonavi-mcp-server/main.go b/cmd/gonavi-mcp-server/main.go new file mode 100644 index 0000000..b91fc74 --- /dev/null +++ b/cmd/gonavi-mcp-server/main.go @@ -0,0 +1,15 @@ +package main + +import ( + "context" + "log" + + "GoNavi-Wails/internal/mcpserver" +) + +func main() { + ctx := context.Background() + if err := mcpserver.RunAppStdioServer(ctx); err != nil { + log.Printf("GoNavi MCP Server 退出: %v", err) + } +} diff --git a/frontend/src/App.tool-center.test.ts b/frontend/src/App.tool-center.test.ts index ed14916..09e4bc6 100644 --- a/frontend/src/App.tool-center.test.ts +++ b/frontend/src/App.tool-center.test.ts @@ -191,6 +191,18 @@ describe('tool center menu entries', () => { expect(appSource).toContain('该异常不一定表现为 viewport ratio drift'); }); + it('captures window state on startup and lifecycle events instead of waiting only for the polling interval', () => { + expect(appSource).toContain('const scheduleWindowStateSave = (delayMs = 120) => {'); + expect(appSource).toContain('if (hydrated) {'); + expect(appSource).toContain('scheduleWindowStateSave(320);'); + expect(appSource).toContain('const unsubscribeHydration = useStore.persist.onFinishHydration(() => {'); + expect(appSource).toContain("window.addEventListener('resize', handleWindowRuntimeChange);"); + expect(appSource).toContain("window.addEventListener('focus', handleWindowRuntimeChange);"); + expect(appSource).toContain("window.addEventListener('pageshow', handleWindowRuntimeChange);"); + expect(appSource).toContain("window.addEventListener('pagehide', handleWindowLifecycleFlush, { capture: true });"); + expect(appSource).toContain("window.addEventListener('beforeunload', handleWindowLifecycleFlush, { capture: true });"); + }); + it('keeps titlebar double-click on maximise while shortcuts may enter macOS fullscreen', () => { expect(appSource).toContain('const handleTitleBarWindowToggle = async (options?: { allowMacNativeFullscreen?: boolean }) => {'); expect(appSource).toContain('const allowMacNativeFullscreen = options?.allowMacNativeFullscreen === true;'); @@ -204,6 +216,12 @@ describe('tool center menu entries', () => { expect(appSource).toContain("window.removeEventListener('keydown', handleGlobalShortcut, true);"); }); + it('skips the native mac titlebar bridge when the current runtime does not expose it', () => { + expect(appSource).toContain("const backendApp = (window as any).go?.app?.App;"); + expect(appSource).toContain("if (typeof backendApp?.SetMacNativeWindowControls !== 'function') {"); + expect(appSource).toContain('void safeWindowRuntimeCall(() => SetMacNativeWindowControls(useNativeMacWindowControls), undefined);'); + }); + it('listens for command search query-tab events and routes them through handleNewQuery', () => { expect(appSource).toContain("window.addEventListener('gonavi:create-query-tab', handleCreateQueryTabEvent as EventListener);"); expect(appSource).toContain("window.removeEventListener('gonavi:create-query-tab', handleCreateQueryTabEvent as EventListener);"); diff --git a/frontend/src/App.tsx b/frontend/src/App.tsx index c5adfe7..dc752bd 100644 --- a/frontend/src/App.tsx +++ b/frontend/src/App.tsx @@ -790,9 +790,15 @@ function App() { // 定时保存窗口状态、尺寸与位置 useEffect(() => { const SAVE_INTERVAL_MS = 2000; + let cancelled = false; + let hydrated = useStore.persist.hasHydrated(); + let eventSaveTimer: number | null = null; let lastSaved = ''; const saveWindowState = async () => { + if (cancelled || !hydrated) { + return; + } try { const [isFs, isMax] = await Promise.all([ safeWindowRuntimeCall(() => WindowIsFullscreen(), false), @@ -836,8 +842,67 @@ function App() { } }; - const timer = window.setInterval(saveWindowState, SAVE_INTERVAL_MS); - return () => window.clearInterval(timer); + const scheduleWindowStateSave = (delayMs = 120) => { + if (cancelled || !hydrated) { + return; + } + if (eventSaveTimer !== null) { + window.clearTimeout(eventSaveTimer); + } + eventSaveTimer = window.setTimeout(() => { + eventSaveTimer = null; + void saveWindowState(); + }, delayMs); + }; + + const handleWindowRuntimeChange = () => { + scheduleWindowStateSave(); + }; + + const handleVisibilityChange = () => { + if (document.visibilityState === 'visible') { + scheduleWindowStateSave(120); + } + }; + + const handleWindowLifecycleFlush = () => { + void saveWindowState(); + }; + + if (hydrated) { + scheduleWindowStateSave(320); + } + const unsubscribeHydration = useStore.persist.onFinishHydration(() => { + if (cancelled || hydrated) { + return; + } + hydrated = true; + scheduleWindowStateSave(320); + }); + + const timer = window.setInterval(() => { + void saveWindowState(); + }, SAVE_INTERVAL_MS); + window.addEventListener('resize', handleWindowRuntimeChange); + window.addEventListener('focus', handleWindowRuntimeChange); + window.addEventListener('pageshow', handleWindowRuntimeChange); + window.addEventListener('pagehide', handleWindowLifecycleFlush, { capture: true }); + window.addEventListener('beforeunload', handleWindowLifecycleFlush, { capture: true }); + document.addEventListener('visibilitychange', handleVisibilityChange); + return () => { + cancelled = true; + if (eventSaveTimer !== null) { + window.clearTimeout(eventSaveTimer); + } + window.clearInterval(timer); + window.removeEventListener('resize', handleWindowRuntimeChange); + window.removeEventListener('focus', handleWindowRuntimeChange); + window.removeEventListener('pageshow', handleWindowRuntimeChange); + window.removeEventListener('pagehide', handleWindowLifecycleFlush, { capture: true }); + window.removeEventListener('beforeunload', handleWindowLifecycleFlush, { capture: true }); + document.removeEventListener('visibilitychange', handleVisibilityChange); + unsubscribeHydration(); + }; }, []); useEffect(() => { @@ -1567,12 +1632,11 @@ function App() { if (!isStoreHydrated || !isMacRuntime) { return; } - - try { - void SetMacNativeWindowControls(useNativeMacWindowControls).catch(() => undefined); - } catch (e) { - console.warn('Wails API: SetMacNativeWindowControls unavailable', e); + const backendApp = (window as any).go?.app?.App; + if (typeof backendApp?.SetMacNativeWindowControls !== 'function') { + return; } + void safeWindowRuntimeCall(() => SetMacNativeWindowControls(useNativeMacWindowControls), undefined); }, [isMacRuntime, isStoreHydrated, useNativeMacWindowControls]); useEffect(() => { diff --git a/frontend/src/components/AISettingsModal.edit-password.test.tsx b/frontend/src/components/AISettingsModal.edit-password.test.tsx index 2db53a3..9ebec72 100644 --- a/frontend/src/components/AISettingsModal.edit-password.test.tsx +++ b/frontend/src/components/AISettingsModal.edit-password.test.tsx @@ -18,6 +18,7 @@ describe('AISettingsModal edit password behavior', () => { }); it('loads MCP servers and skills through the AI service', () => { + expect(source).toContain('Service.AIGetMCPClientInstallStatuses?.()'); expect(source).toContain('Service.AIGetMCPServers?.()'); expect(source).toContain('Service.AIListMCPTools?.()'); expect(source).toContain('Service.AIGetSkills?.()'); @@ -25,6 +26,26 @@ describe('AISettingsModal edit password behavior', () => { expect(source).toContain('新增 Skill'); }); + it('explains external MCP installation and renders selectable client install states', () => { + expect(source).toContain('把 GoNavi 注册成外部 AI 客户端可调用的 MCP Server'); + expect(source).toContain('安装到外部客户端'); + expect(source).toContain('未安装'); + expect(source).toContain('需更新'); + expect(source).toContain('已安装'); + expect(source).toContain('刷新状态'); + expect(source).toContain('复制配置路径'); + expect(source).toContain('复制启动命令'); + expect(source).toContain('handleInstallSelectedMCPClient'); + expect(source).toContain('无需重复安装'); + }); + + it('waits briefly for the AI service bridge before warning and removes noisy provider debug logs', () => { + expect(source).toContain('const resolveAIService = useCallback(async () => {'); + expect(source).toContain('const service = await waitForAIService();'); + expect(source).not.toContain("console.log('[AI] AIGetProviders result:'"); + expect(source).not.toContain("console.log('[AI] AIGetActiveProvider result:'"); + }); + it('keeps the prefilled api key masked by default', () => { expect(source).toContain('const [primaryPasswordVisible, setPrimaryPasswordVisible] = useState(false);'); expect(source).toContain('visible: primaryPasswordVisible,'); diff --git a/frontend/src/components/AISettingsModal.tsx b/frontend/src/components/AISettingsModal.tsx index be8b8db..8f1ecfa 100644 --- a/frontend/src/components/AISettingsModal.tsx +++ b/frontend/src/components/AISettingsModal.tsx @@ -1,7 +1,7 @@ import React, { useState, useEffect, useCallback, useMemo, useRef } from 'react'; import { Modal, Button, Input, Select, Form, message as antdMessage, Tooltip, Tabs, Space, Popconfirm, Slider } from 'antd'; -import { PlusOutlined, DeleteOutlined, EditOutlined, CheckOutlined, ApiOutlined, SafetyCertificateOutlined, RobotOutlined, ThunderboltOutlined, CloudOutlined, ExperimentOutlined, KeyOutlined, LinkOutlined, AppstoreOutlined, ToolOutlined } from '@ant-design/icons'; -import type { AIProviderConfig, AIProviderType, AISafetyLevel, AIContextLevel, AIUserPromptSettings, AIMCPServerConfig, AIMCPToolDescriptor, AISkillConfig, AISkillScope } from '../types'; +import { PlusOutlined, DeleteOutlined, EditOutlined, CheckOutlined, ApiOutlined, SafetyCertificateOutlined, RobotOutlined, ThunderboltOutlined, CloudOutlined, ExperimentOutlined, KeyOutlined, LinkOutlined, AppstoreOutlined, ToolOutlined, ReloadOutlined, CopyOutlined } from '@ant-design/icons'; +import type { AIProviderConfig, AIProviderType, AISafetyLevel, AIContextLevel, AIUserPromptSettings, AIMCPServerConfig, AIMCPToolDescriptor, AIMCPClientInstallStatus, AISkillConfig, AISkillScope } from '../types'; import { QWEN_BAILIAN_ANTHROPIC_BASE_URL, QWEN_CODING_PLAN_ANTHROPIC_BASE_URL, @@ -30,6 +30,17 @@ interface AISettingsModalProps { focusProviderId?: string; } +interface MCPClientInstallResult { + success?: boolean; + client?: string; + message?: string; + configPath?: string; + command?: string; + args?: string[]; +} + +type MCPClientKey = 'claude-code' | 'codex'; + // 预设配置:每个预设映射到后端 type(openai/anthropic/gemini/custom)并附带默认 URL 和 Model interface ProviderPreset { key: string; @@ -97,6 +108,100 @@ const EMPTY_MCP_SERVER = (): AIMCPServerConfig => ({ timeoutSeconds: 20, }); +const EMPTY_MCP_CLIENT_STATUSES: AIMCPClientInstallStatus[] = [ + { + client: 'claude-code', + displayName: 'Claude Code', + installed: false, + matchesCurrent: false, + message: '未安装到 Claude Code 用户级配置', + }, + { + client: 'codex', + displayName: 'Codex', + installed: false, + matchesCurrent: false, + message: '未安装到 Codex 用户级配置', + }, +]; + +const normalizeMCPClientStatuses = (items?: AIMCPClientInstallStatus[]): AIMCPClientInstallStatus[] => { + const baseMap = new Map( + EMPTY_MCP_CLIENT_STATUSES.map((item) => [item.client, { ...item }]), + ); + (Array.isArray(items) ? items : []).forEach((item) => { + if (!item || !item.client) { + return; + } + const base = baseMap.get(item.client) || { + client: item.client, + displayName: item.client, + installed: false, + matchesCurrent: false, + message: '', + }; + baseMap.set(item.client, { + ...base, + ...item, + displayName: item.displayName || base.displayName, + message: item.message || base.message, + args: Array.isArray(item.args) ? item.args : (base.args || []), + }); + }); + return (['claude-code', 'codex'] as MCPClientKey[]) + .map((client) => baseMap.get(client)) + .filter((item): item is AIMCPClientInstallStatus => Boolean(item)); +}; + +const pickPreferredMCPClient = (items: AIMCPClientInstallStatus[], current?: MCPClientKey): MCPClientKey => { + if (current && items.some((item) => item.client === current)) { + return current; + } + const pending = items.find((item) => !item.matchesCurrent); + if (pending?.client === 'claude-code' || pending?.client === 'codex') { + return pending.client; + } + return 'claude-code'; +}; + +const waitFor = (delayMs: number) => new Promise((resolve) => { + window.setTimeout(resolve, delayMs); +}); + +const readAIService = () => (window as any).go?.aiservice?.Service; + +const waitForAIService = async (attempts = 6, delayMs = 80) => { + for (let attempt = 0; attempt < attempts; attempt += 1) { + const service = readAIService(); + if (service) { + return service; + } + if (attempt < attempts - 1) { + await waitFor(delayMs); + } + } + return readAIService(); +}; + +const quoteMCPCommandPart = (value: string): string => { + const text = String(value || '').trim(); + if (!text) { + return ''; + } + return /[\s"]/u.test(text) ? `"${text.replace(/"/g, '\\"')}"` : text; +}; + +const formatMCPLaunchCommand = (input?: Pick | Pick | null): string => { + const command = String(input?.command || '').trim(); + if (!command) { + return ''; + } + const args = Array.isArray(input?.args) + ? input.args.map((item) => String(item || '').trim()).filter(Boolean) + : []; + return [command, ...args].map(quoteMCPCommandPart).filter(Boolean).join(' '); +}; + const EMPTY_SKILL = (): AISkillConfig => ({ id: `skill-draft-${Date.now()}-${Math.random().toString(36).slice(2, 8)}`, name: '', @@ -142,6 +247,9 @@ const AISettingsModal: React.FC = ({ open, onClose, darkMo const [contextLevel, setContextLevel] = useState('schema_only'); const [mcpServers, setMCPServers] = useState([]); const [mcpTools, setMCPTools] = useState([]); + const [mcpClientStatuses, setMCPClientStatuses] = useState(EMPTY_MCP_CLIENT_STATUSES); + const [selectedMCPClient, setSelectedMCPClient] = useState('claude-code'); + const [mcpClientStatusLoading, setMCPClientStatusLoading] = useState(false); const [skills, setSkills] = useState([]); const [editingProvider, setEditingProvider] = useState(null); const [isEditing, setIsEditing] = useState(false); @@ -153,6 +261,7 @@ const AISettingsModal: React.FC = ({ open, onClose, darkMo const [primaryPasswordVisible, setPrimaryPasswordVisible] = useState(false); const [form] = Form.useForm(); const modalBodyRef = useRef(null); + const missingAIServiceWarnedRef = useRef(false); // Modal 内部 toast 通知 const [messageApi, messageContextHolder] = antdMessage.useMessage({ getContainer: () => modalBodyRef.current || document.body }); @@ -163,6 +272,35 @@ const AISettingsModal: React.FC = ({ open, onClose, darkMo const cardHoverBg = darkMode ? 'rgba(255,255,255,0.06)' : 'rgba(0,0,0,0.03)'; const sectionLabelColor = darkMode ? 'rgba(255,255,255,0.5)' : 'rgba(0,0,0,0.4)'; const inputBg = darkMode ? 'rgba(255,255,255,0.04)' : 'rgba(0,0,0,0.02)'; + const getMCPClientStatusTone = useCallback((status?: AIMCPClientInstallStatus) => { + const messageText = String(status?.message || ''); + if (status?.matchesCurrent) { + return { + label: '已安装', + color: '#16a34a', + bg: darkMode ? 'rgba(34,197,94,0.18)' : 'rgba(34,197,94,0.12)', + }; + } + if (status?.installed) { + return { + label: '需更新', + color: '#d97706', + bg: darkMode ? 'rgba(245,158,11,0.18)' : 'rgba(245,158,11,0.12)', + }; + } + if (messageText.includes('失败') || messageText.includes('异常')) { + return { + label: '需检查', + color: '#dc2626', + bg: darkMode ? 'rgba(239,68,68,0.18)' : 'rgba(239,68,68,0.1)', + }; + } + return { + label: '未安装', + color: darkMode ? 'rgba(255,255,255,0.72)' : '#64748b', + bg: darkMode ? 'rgba(255,255,255,0.08)' : 'rgba(100,116,139,0.08)', + }; + }, [darkMode]); // Hook 必须在组件顶层调用,不能在条件分支内 const watchedType = Form.useWatch('type', form); @@ -178,11 +316,71 @@ const AISettingsModal: React.FC = ({ open, onClose, darkMo value: tool.alias, })), ]), [mcpTools]); + const selectedMCPClientStatus = useMemo( + () => mcpClientStatuses.find((item) => item.client === selectedMCPClient) || mcpClientStatuses[0], + [mcpClientStatuses, selectedMCPClient], + ); + const selectedMCPClientCommandText = useMemo( + () => formatMCPLaunchCommand(selectedMCPClientStatus), + [selectedMCPClientStatus], + ); + + const resolveAIService = useCallback(async () => { + const service = await waitForAIService(); + if (service) { + missingAIServiceWarnedRef.current = false; + return service; + } + if (!missingAIServiceWarnedRef.current) { + console.warn('[AI] Service not found on window.go'); + missingAIServiceWarnedRef.current = true; + } + return null; + }, []); + + const loadMCPClientStatuses = useCallback(async (options?: { silent?: boolean }) => { + const silent = options?.silent === true; + if (!silent) { + setMCPClientStatusLoading(true); + } + try { + const Service = await resolveAIService(); + if (typeof Service?.AIGetMCPClientInstallStatuses !== 'function') { + return; + } + const result = await Service.AIGetMCPClientInstallStatuses(); + if (Array.isArray(result)) { + const normalizedStatuses = normalizeMCPClientStatuses(result); + setMCPClientStatuses(normalizedStatuses); + setSelectedMCPClient((prev) => pickPreferredMCPClient(normalizedStatuses, prev)); + } + } catch (e: any) { + if (silent) { + console.warn('[AI] refresh mcp client statuses failed', e); + } else { + void messageApi.error(e?.message || '刷新客户端安装状态失败'); + } + } finally { + if (!silent) { + setMCPClientStatusLoading(false); + } + } + }, [messageApi, resolveAIService]); + + const copyTextToClipboard = useCallback(async (text: string, successMessage: string) => { + if (typeof navigator?.clipboard?.writeText !== 'function') { + throw new Error('当前环境不支持复制到剪贴板'); + } + await navigator.clipboard.writeText(text); + void messageApi.success(successMessage); + }, [messageApi]); const loadConfig = useCallback(async () => { try { - const Service = (window as any).go?.aiservice?.Service; - if (!Service) { console.warn('[AI] Service not found on window.go'); return; } + const Service = await resolveAIService(); + if (!Service) { + return; + } const callOrFallback = async (loader: (() => Promise) | undefined, fallback: T): Promise => { if (typeof loader !== 'function') { return fallback; @@ -194,7 +392,7 @@ const AISettingsModal: React.FC = ({ open, onClose, darkMo return fallback; } }; - const [provRes, safeRes, ctxRes, promptsRes, userPromptsRes, mcpServersRes, mcpToolsRes, skillsRes] = await Promise.all([ + const [provRes, safeRes, ctxRes, promptsRes, userPromptsRes, mcpServersRes, mcpToolsRes, skillsRes, mcpClientStatusesRes] = await Promise.all([ callOrFallback(() => Service.AIGetProviders?.(), []), callOrFallback(() => Service.AIGetSafetyLevel?.(), 'readonly'), callOrFallback(() => Service.AIGetContextLevel?.(), 'schema_only'), @@ -203,12 +401,11 @@ const AISettingsModal: React.FC = ({ open, onClose, darkMo callOrFallback(() => Service.AIGetMCPServers?.(), []), callOrFallback(() => Service.AIListMCPTools?.(), []), callOrFallback(() => Service.AIGetSkills?.(), []), + callOrFallback(() => Service.AIGetMCPClientInstallStatuses?.(), EMPTY_MCP_CLIENT_STATUSES), ]); - console.log('[AI] AIGetProviders result:', JSON.stringify(provRes), 'isArray:', Array.isArray(provRes)); if (Array.isArray(provRes)) { setProviders(provRes); const activeRes = await Service.AIGetActiveProvider?.(); - console.log('[AI] AIGetActiveProvider result:', activeRes); if (activeRes) setActiveProviderId(activeRes); } if (safeRes) setSafetyLevel(safeRes); @@ -223,8 +420,13 @@ const AISettingsModal: React.FC = ({ open, onClose, darkMo if (Array.isArray(mcpServersRes)) setMCPServers(mcpServersRes); if (Array.isArray(mcpToolsRes)) setMCPTools(mcpToolsRes); if (Array.isArray(skillsRes)) setSkills(skillsRes); + if (Array.isArray(mcpClientStatusesRes)) { + const normalizedStatuses = normalizeMCPClientStatuses(mcpClientStatusesRes); + setMCPClientStatuses(normalizedStatuses); + setSelectedMCPClient((prev) => pickPreferredMCPClient(normalizedStatuses, prev)); + } } catch (e) { console.warn('Failed to load AI config', e); } - }, []); + }, [resolveAIService]); useEffect(() => { if (open) void loadConfig(); }, [open, loadConfig]); @@ -491,6 +693,63 @@ const AISettingsModal: React.FC = ({ open, onClose, darkMo } }; + const handleInstallSelectedMCPClient = async () => { + const targetClient = selectedMCPClientStatus?.client === 'codex' ? 'codex' : 'claude-code'; + const targetLabel = selectedMCPClientStatus?.displayName || (targetClient === 'codex' ? 'Codex' : 'Claude Code'); + if (selectedMCPClientStatus?.matchesCurrent) { + void messageApi.success(`${targetLabel} 已安装当前 GoNavi MCP,无需重复安装`); + return; + } + try { + setLoading(true); + const Service = await resolveAIService(); + let result: MCPClientInstallResult; + if (targetClient === 'codex') { + if (typeof Service?.AIInstallCodexMCP !== 'function') { + throw new Error('当前版本暂不支持自动安装 Codex MCP'); + } + result = await Service.AIInstallCodexMCP() as MCPClientInstallResult; + } else { + if (typeof Service?.AIInstallClaudeCodeMCP !== 'function') { + throw new Error('当前版本暂不支持自动安装 Claude Code MCP'); + } + result = await Service.AIInstallClaudeCodeMCP() as MCPClientInstallResult; + } + await loadMCPClientStatuses({ silent: true }); + window.dispatchEvent(new CustomEvent('gonavi:ai:config-changed')); + void messageApi.success(result?.message || `已写入 ${targetLabel} 用户级 MCP 配置`); + } catch (e: any) { + void messageApi.error(e?.message || `安装 ${targetLabel} MCP 失败`); + } finally { + setLoading(false); + } + }; + + const handleCopySelectedMCPConfigPath = useCallback(async () => { + const configPath = String(selectedMCPClientStatus?.configPath || '').trim(); + if (!configPath) { + void messageApi.warning('当前没有可复制的配置文件路径'); + return; + } + try { + await copyTextToClipboard(configPath, '配置文件路径已复制'); + } catch (e: any) { + void messageApi.error(e?.message || '复制配置文件路径失败'); + } + }, [copyTextToClipboard, messageApi, selectedMCPClientStatus]); + + const handleCopySelectedMCPLaunchCommand = useCallback(async () => { + if (!selectedMCPClientCommandText) { + void messageApi.warning('当前没有可复制的启动命令'); + return; + } + try { + await copyTextToClipboard(selectedMCPClientCommandText, '启动命令已复制'); + } catch (e: any) { + void messageApi.error(e?.message || '复制启动命令失败'); + } + }, [copyTextToClipboard, messageApi, selectedMCPClientCommandText]); + const updateSkillDraft = (id: string, patch: Partial) => { setSkills((prev) => prev.map((item) => item.id === id ? { ...item, ...patch } : item)); }; @@ -983,8 +1242,165 @@ const AISettingsModal: React.FC = ({ open, onClose, darkMo const renderMCPSettings = () => (
-
- MCP 会作为外部工具源接入 AI。当前阶段先支持 `stdio` 型服务,不需要为 GoNavi 的 MCP client 单独新建仓库;只有你准备发布独立的 MCP Server 时,才值得拆独立仓库。 +
+ 这里的“安装到客户端”是把 GoNavi 注册成外部 AI 客户端可调用的 MCP Server,供 Claude Code 或 Codex 使用;不是 GoNavi 自己安装自己。 +
+
+
+
安装到外部客户端
+
+ 先选择目标客户端,再把当前 GoNavi 安装路径写入它的用户级 MCP 配置。GoNavi 会自动处理配置文件路径,不需要你自己找本机 exe。 +
+
+ +
+ {mcpClientStatuses.map((status) => { + const active = selectedMCPClient === status.client; + const tone = getMCPClientStatusTone(status); + return ( +
{ + if (status.client === 'claude-code' || status.client === 'codex') { + setSelectedMCPClient(status.client); + } + }} + style={{ + padding: '14px 14px 12px', + borderRadius: 12, + border: `1.5px solid ${active ? overlayTheme.selectedText : cardBorder}`, + background: active ? overlayTheme.selectedBg : (darkMode ? 'rgba(255,255,255,0.02)' : 'rgba(255,255,255,0.7)'), + cursor: 'pointer', + display: 'flex', + flexDirection: 'column', + gap: 10, + transition: 'all 0.2s ease', + }} + > +
+
+ {status.displayName} +
+
+ {tone.label} +
+
+
+ {status.matchesCurrent + ? '当前 GoNavi 安装路径已写入,打开客户端后可直接使用。' + : status.installed + ? '检测到已有安装记录,但建议更新为当前 GoNavi 路径。' + : '当前尚未写入 GoNavi MCP 配置。'} +
+
+ ); + })} +
+ +
+
+
+ {selectedMCPClientStatus?.displayName || '客户端'} 状态 +
+ {selectedMCPClientStatus && ( +
+ {getMCPClientStatusTone(selectedMCPClientStatus).label} +
+ )} +
+
+ {selectedMCPClientStatus?.message || '未检测到安装状态'} +
+ {selectedMCPClientStatus?.configPath && ( +
+ 配置文件:{selectedMCPClientStatus.configPath} +
+ )} + {selectedMCPClientCommandText && ( +
+ 启动命令:{selectedMCPClientCommandText} +
+ )} +
+ + + +
+
+ +
+
+ 安装后重启对应客户端即可生效;若已经是当前路径,会直接提示无需重复安装。 +
+ +
支持命令、参数、环境变量和超时,保存后会自动进入 AI 工具列表。
@@ -1217,7 +1633,7 @@ const AISettingsModal: React.FC = ({ open, onClose, darkMo >
{messageContextHolder} -
+
设置导航
{[ diff --git a/frontend/src/components/FindInDatabaseModal.tsx b/frontend/src/components/FindInDatabaseModal.tsx index a81ea41..338ab9a 100644 --- a/frontend/src/components/FindInDatabaseModal.tsx +++ b/frontend/src/components/FindInDatabaseModal.tsx @@ -341,7 +341,7 @@ const FindInDatabaseModal: React.FC = ({ open, onClose header: { background: 'transparent', borderBottom: 'none', paddingBottom: 8 }, body: { paddingTop: 8 }, }} - destroyOnClose + destroyOnHidden >
{/* 搜索栏 */} diff --git a/frontend/src/main.tsx b/frontend/src/main.tsx index f94f734..0ef2461 100644 --- a/frontend/src/main.tsx +++ b/frontend/src/main.tsx @@ -22,12 +22,14 @@ const resolveDevHarnessMode = (): string => { } }; -if (typeof window !== 'undefined' && !(window as any).go) { +if (typeof window !== 'undefined' && (!(window as any).go?.app?.App || !(window as any).go?.aiservice?.Service)) { const mockConnections: any[] = []; const mockConnectionSecrets = new Map(); const mockProviders: any[] = []; const mockProviderSecrets = new Map(); let mockActiveProviderId = ''; + let mockAISafetyLevel = 'readonly'; + let mockAIContextLevel = 'schema_only'; let mockAIUserPromptSettings: any = { global: '', database: '', @@ -35,6 +37,28 @@ if (typeof window !== 'undefined' && !(window as any).go) { jvmDiagnostic: '', }; let mockMCPServers: any[] = []; + let mockMCPClientStatuses: any[] = [ + { + client: 'claude-code', + displayName: 'Claude Code', + installed: false, + matchesCurrent: false, + message: '未安装到 Claude Code 用户级配置', + configPath: 'C:/Users/mock/.claude.json', + command: 'C:/Program Files/GoNavi/GoNavi.exe', + args: ['mcp-server'], + }, + { + client: 'codex', + displayName: 'Codex', + installed: true, + matchesCurrent: false, + message: '已检测到 Codex 安装记录,但与当前 GoNavi 安装包路径不一致,建议更新安装', + configPath: 'C:/Users/mock/.codex/config.toml', + command: 'C:/Old/GoNavi.exe', + args: ['mcp-server'], + }, + ]; let mockSkills: any[] = []; let mockGlobalProxy: any = { enabled: false, type: 'socks5', host: '', port: 1080, user: '', password: '', hasPassword: false }; let mockDataRootInfo: any = { @@ -154,7 +178,7 @@ if (typeof window !== 'undefined' && !(window as any).go) { return cloneBrowserMockValue(view); }; - (window as any).go = { + const mockGo = { app: { App: { CheckUpdate: async () => ({ success: false }), @@ -291,8 +315,8 @@ if (typeof window !== 'undefined' && !(window as any).go) { mockActiveProviderId = id; return null; }, - AIGetSafetyLevel: async () => 'readonly', - AIGetContextLevel: async () => 'schema_only', + AIGetSafetyLevel: async () => mockAISafetyLevel, + AIGetContextLevel: async () => mockAIContextLevel, AIGetBuiltinPrompts: async () => ({}), AIGetUserPromptSettings: async () => cloneBrowserMockValue(mockAIUserPromptSettings), AISaveUserPromptSettings: async (input: any) => { @@ -304,7 +328,48 @@ if (typeof window !== 'undefined' && !(window as any).go) { }; return null; }, + AIGetMCPClientInstallStatuses: async () => cloneBrowserMockValue(mockMCPClientStatuses), AIGetMCPServers: async () => cloneBrowserMockValue(mockMCPServers), + AIInstallClaudeCodeMCP: async () => { + mockMCPClientStatuses = mockMCPClientStatuses.map((item) => item.client === 'claude-code' + ? { + ...item, + installed: true, + matchesCurrent: true, + message: '已写入 Claude Code 用户级 MCP 配置,重启 Claude CLI 后可在 /mcp 的 User MCPs 中看到 GoNavi。', + command: 'C:/Program Files/GoNavi/GoNavi.exe', + args: ['mcp-server'], + } + : item); + return { + success: true, + client: 'claude-code', + message: '已写入 Claude Code 用户级 MCP 配置,重启 Claude CLI 后可在 /mcp 的 User MCPs 中看到 GoNavi。', + configPath: 'C:/Users/mock/.claude.json', + command: 'C:/Program Files/GoNavi/GoNavi.exe', + args: ['mcp-server'], + }; + }, + AIInstallCodexMCP: async () => { + mockMCPClientStatuses = mockMCPClientStatuses.map((item) => item.client === 'codex' + ? { + ...item, + installed: true, + matchesCurrent: true, + message: '已写入 Codex 用户级 MCP 配置,重启 Codex CLI 或桌面端后可看到 GoNavi。', + command: 'C:/Program Files/GoNavi/GoNavi.exe', + args: ['mcp-server'], + } + : item); + return { + success: true, + client: 'codex', + message: '已写入 Codex 用户级 MCP 配置,重启 Codex CLI 或桌面端后可看到 GoNavi。', + configPath: 'C:/Users/mock/.codex/config.toml', + command: 'C:/Program Files/GoNavi/GoNavi.exe', + args: ['mcp-server'], + }; + }, AISaveMCPServer: async (input: any) => { const next = { id: String(input?.id || `mcp-${Date.now()}`), @@ -363,11 +428,38 @@ if (typeof window !== 'undefined' && !(window as any).go) { success: String(input?.apiKey || '').trim() !== '', message: String(input?.apiKey || '').trim() !== '' ? '端点连通性测试成功!' : '连接测试失败: missing api key', }), - AISetSafetyLevel: async () => null, - AISetContextLevel: async () => null, + AISetSafetyLevel: async (level: string) => { + mockAISafetyLevel = String(level || 'readonly'); + return null; + }, + AISetContextLevel: async (level: string) => { + mockAIContextLevel = String(level || 'schema_only'); + return null; + }, }, } }; + const existingGo = (window as any).go || {}; + (window as any).go = { + ...mockGo, + ...existingGo, + app: { + ...mockGo.app, + ...(existingGo.app || {}), + App: { + ...mockGo.app.App, + ...(existingGo.app?.App || {}), + }, + }, + aiservice: { + ...mockGo.aiservice, + ...(existingGo.aiservice || {}), + Service: { + ...mockGo.aiservice.Service, + ...(existingGo.aiservice?.Service || {}), + }, + }, + }; } const rootNode = document.getElementById('root')!; const devHarnessMode = import.meta.env.DEV ? resolveDevHarnessMode() : ''; diff --git a/frontend/src/types.ts b/frontend/src/types.ts index dfd0968..011a8fa 100644 --- a/frontend/src/types.ts +++ b/frontend/src/types.ts @@ -592,6 +592,17 @@ export interface AIMCPToolCallResult { isError: boolean; } +export interface AIMCPClientInstallStatus { + client: string; + displayName: string; + installed: boolean; + matchesCurrent: boolean; + message: string; + configPath?: string; + command?: string; + args?: string[]; +} + export type AISkillScope = "global" | "database" | "jvm" | "jvmDiagnostic"; export interface AISkillConfig { diff --git a/frontend/wailsjs/go/aiservice/Service.d.ts b/frontend/wailsjs/go/aiservice/Service.d.ts index 93450a1..54efd24 100755 --- a/frontend/wailsjs/go/aiservice/Service.d.ts +++ b/frontend/wailsjs/go/aiservice/Service.d.ts @@ -28,6 +28,8 @@ export function AIGetContextLevel():Promise; export function AIGetEditableProvider(arg1:string):Promise; +export function AIGetMCPClientInstallStatuses():Promise>; + export function AIGetMCPServers():Promise>; export function AIGetProviders():Promise>; @@ -40,6 +42,10 @@ export function AIGetSkills():Promise>; export function AIGetUserPromptSettings():Promise; +export function AIInstallClaudeCodeMCP():Promise; + +export function AIInstallCodexMCP():Promise; + export function AIListMCPTools():Promise>; export function AIListModels():Promise>; diff --git a/frontend/wailsjs/go/aiservice/Service.js b/frontend/wailsjs/go/aiservice/Service.js index 0ee39dc..40900e5 100755 --- a/frontend/wailsjs/go/aiservice/Service.js +++ b/frontend/wailsjs/go/aiservice/Service.js @@ -54,6 +54,10 @@ export function AIGetEditableProvider(arg1) { return window['go']['aiservice']['Service']['AIGetEditableProvider'](arg1); } +export function AIGetMCPClientInstallStatuses() { + return window['go']['aiservice']['Service']['AIGetMCPClientInstallStatuses'](); +} + export function AIGetMCPServers() { return window['go']['aiservice']['Service']['AIGetMCPServers'](); } @@ -78,6 +82,14 @@ export function AIGetUserPromptSettings() { return window['go']['aiservice']['Service']['AIGetUserPromptSettings'](); } +export function AIInstallClaudeCodeMCP() { + return window['go']['aiservice']['Service']['AIInstallClaudeCodeMCP'](); +} + +export function AIInstallCodexMCP() { + return window['go']['aiservice']['Service']['AIInstallCodexMCP'](); +} + export function AIListMCPTools() { return window['go']['aiservice']['Service']['AIListMCPTools'](); } diff --git a/frontend/wailsjs/go/models.ts b/frontend/wailsjs/go/models.ts index 4d5126e..6c34e4a 100755 --- a/frontend/wailsjs/go/models.ts +++ b/frontend/wailsjs/go/models.ts @@ -1,5 +1,53 @@ export namespace ai { + export class MCPClientInstallResult { + success: boolean; + client?: string; + message: string; + configPath?: string; + command?: string; + args?: string[]; + + static createFrom(source: any = {}) { + return new MCPClientInstallResult(source); + } + + constructor(source: any = {}) { + if ('string' === typeof source) source = JSON.parse(source); + this.success = source["success"]; + this.client = source["client"]; + this.message = source["message"]; + this.configPath = source["configPath"]; + this.command = source["command"]; + this.args = source["args"]; + } + } + export class MCPClientInstallStatus { + client: string; + displayName: string; + installed: boolean; + matchesCurrent: boolean; + message: string; + configPath?: string; + command?: string; + args?: string[]; + + static createFrom(source: any = {}) { + return new MCPClientInstallStatus(source); + } + + constructor(source: any = {}) { + if ('string' === typeof source) source = JSON.parse(source); + this.client = source["client"]; + this.displayName = source["displayName"]; + this.installed = source["installed"]; + this.matchesCurrent = source["matchesCurrent"]; + this.message = source["message"]; + this.configPath = source["configPath"]; + this.command = source["command"]; + this.args = source["args"]; + } + } export class MCPServerConfig { id: string; name: string; @@ -1272,4 +1320,3 @@ export namespace sync { } } - diff --git a/internal/ai/service/claude_code_mcp.go b/internal/ai/service/claude_code_mcp.go new file mode 100644 index 0000000..da279df --- /dev/null +++ b/internal/ai/service/claude_code_mcp.go @@ -0,0 +1,680 @@ +package aiservice + +import ( + "encoding/json" + "fmt" + "os" + "path/filepath" + "reflect" + "strconv" + "strings" + + "GoNavi-Wails/internal/ai" +) + +const ( + gonaviMCPServerID = "gonavi" + defaultCodexMCPStartupTimeoutSecond = 60 +) + +var claudeCodeConfigPathFunc = func() (string, error) { + homeDir, err := os.UserHomeDir() + if err != nil { + return "", err + } + homeDir = strings.TrimSpace(homeDir) + if homeDir == "" { + return "", fmt.Errorf("无法确定用户目录") + } + return filepath.Join(homeDir, ".claude.json"), nil +} + +var codexConfigPathFunc = func() (string, error) { + homeDir, err := os.UserHomeDir() + if err != nil { + return "", err + } + homeDir = strings.TrimSpace(homeDir) + if homeDir == "" { + return "", fmt.Errorf("无法确定用户目录") + } + return filepath.Join(homeDir, ".codex", "config.toml"), nil +} + +var localMCPExecutablePathFunc = os.Executable + +type claudeCodeMCPServerConfig struct { + Type string `json:"type"` + Command string `json:"command"` + Args []string `json:"args,omitempty"` + Env map[string]string `json:"env,omitempty"` +} + +type codexMCPServerConfig struct { + Command string + Args []string + StartupTimeoutSec int +} + +// AIGetMCPClientInstallStatuses 返回 GoNavi MCP 在常见外部客户端中的安装状态。 +func (s *Service) AIGetMCPClientInstallStatuses() []ai.MCPClientInstallStatus { + command, args, resolveErr := resolveCurrentLocalMCPCommand() + return []ai.MCPClientInstallStatus{ + inspectClaudeCodeMCPInstallStatus(command, args, resolveErr), + inspectCodexMCPInstallStatus(command, args, resolveErr), + } +} + +// AIInstallClaudeCodeMCP 把 GoNavi 的 MCP server 写入 Claude Code 用户级 MCP 配置。 +func (s *Service) AIInstallClaudeCodeMCP() (ai.MCPClientInstallResult, error) { + configPath, err := claudeCodeConfigPathFunc() + if err != nil { + return ai.MCPClientInstallResult{}, fmt.Errorf("定位 Claude Code 配置失败: %w", err) + } + + executablePath, err := localMCPExecutablePathFunc() + if err != nil { + return ai.MCPClientInstallResult{}, fmt.Errorf("定位当前 GoNavi 可执行文件失败: %w", err) + } + + command, args, err := resolveLocalMCPCommand(executablePath) + if err != nil { + return ai.MCPClientInstallResult{}, err + } + + serverConfig := claudeCodeMCPServerConfig{ + Type: "stdio", + Command: command, + Args: append([]string(nil), args...), + Env: map[string]string{}, + } + if err := upsertClaudeCodeMCPServerConfig(configPath, gonaviMCPServerID, serverConfig); err != nil { + return ai.MCPClientInstallResult{}, err + } + + return ai.MCPClientInstallResult{ + Success: true, + Client: "claude-code", + Message: "已写入 Claude Code 用户级 MCP 配置,重启 Claude CLI 后可在 /mcp 的 User MCPs 中看到 GoNavi。", + ConfigPath: configPath, + Command: command, + Args: append([]string(nil), args...), + }, nil +} + +// AIInstallCodexMCP 把 GoNavi 的 MCP server 写入 Codex 用户级 MCP 配置。 +func (s *Service) AIInstallCodexMCP() (ai.MCPClientInstallResult, error) { + configPath, err := codexConfigPathFunc() + if err != nil { + return ai.MCPClientInstallResult{}, fmt.Errorf("定位 Codex 配置失败: %w", err) + } + + executablePath, err := localMCPExecutablePathFunc() + if err != nil { + return ai.MCPClientInstallResult{}, fmt.Errorf("定位当前 GoNavi 可执行文件失败: %w", err) + } + + command, args, err := resolveLocalMCPCommand(executablePath) + if err != nil { + return ai.MCPClientInstallResult{}, err + } + + serverConfig := codexMCPServerConfig{ + Command: command, + Args: append([]string(nil), args...), + StartupTimeoutSec: defaultCodexMCPStartupTimeoutSecond, + } + if err := upsertCodexMCPServerConfig(configPath, gonaviMCPServerID, serverConfig); err != nil { + return ai.MCPClientInstallResult{}, err + } + + return ai.MCPClientInstallResult{ + Success: true, + Client: "codex", + Message: "已写入 Codex 用户级 MCP 配置,重启 Codex CLI 或桌面端后可看到 GoNavi。", + ConfigPath: configPath, + Command: command, + Args: append([]string(nil), args...), + }, nil +} + +func resolveCurrentLocalMCPCommand() (string, []string, error) { + executablePath, err := localMCPExecutablePathFunc() + if err != nil { + return "", nil, fmt.Errorf("定位当前 GoNavi 可执行文件失败: %w", err) + } + command, args, err := resolveLocalMCPCommand(executablePath) + if err != nil { + return "", nil, err + } + return command, args, nil +} + +func resolveLocalMCPCommand(executablePath string) (string, []string, error) { + executablePath = strings.TrimSpace(executablePath) + if executablePath == "" { + return "", nil, fmt.Errorf("当前 GoNavi 可执行文件路径为空") + } + + cleaned := filepath.Clean(executablePath) + baseName := strings.ToLower(strings.TrimSpace(filepath.Base(cleaned))) + switch baseName { + case "gonavi-mcp-server", "gonavi-mcp-server.exe": + return cleaned, []string{}, nil + default: + return cleaned, []string{"mcp-server"}, nil + } +} + +func inspectClaudeCodeMCPInstallStatus(expectedCommand string, expectedArgs []string, expectedErr error) ai.MCPClientInstallStatus { + configPath, pathErr := claudeCodeConfigPathFunc() + status := ai.MCPClientInstallStatus{ + Client: "claude-code", + DisplayName: "Claude Code", + ConfigPath: strings.TrimSpace(configPath), + Message: "未安装到 Claude Code 用户级配置", + } + if pathErr != nil { + status.Message = fmt.Sprintf("定位 Claude Code 配置失败: %v", pathErr) + return status + } + + serverConfig, found, err := readClaudeCodeMCPServerConfig(configPath, gonaviMCPServerID) + if err != nil { + status.Installed = found + status.Message = err.Error() + if found { + status.Command = strings.TrimSpace(serverConfig.Command) + status.Args = append([]string(nil), serverConfig.Args...) + } + return status + } + if !found { + return status + } + + status.Installed = true + status.Command = strings.TrimSpace(serverConfig.Command) + status.Args = append([]string(nil), serverConfig.Args...) + if expectedErr != nil { + status.Message = fmt.Sprintf("已检测到 Claude Code 安装记录,但当前 GoNavi 安装路径校验失败:%v", expectedErr) + return status + } + + status.MatchesCurrent = strings.EqualFold(strings.TrimSpace(serverConfig.Type), "stdio") && + sameMCPCommand(serverConfig.Command, serverConfig.Args, expectedCommand, expectedArgs) + if status.MatchesCurrent { + status.Message = "已安装到 Claude Code 用户级配置" + return status + } + + status.Message = "已检测到 Claude Code 安装记录,但与当前 GoNavi 安装包路径不一致,建议更新安装" + return status +} + +func inspectCodexMCPInstallStatus(expectedCommand string, expectedArgs []string, expectedErr error) ai.MCPClientInstallStatus { + configPath, pathErr := codexConfigPathFunc() + status := ai.MCPClientInstallStatus{ + Client: "codex", + DisplayName: "Codex", + ConfigPath: strings.TrimSpace(configPath), + Message: "未安装到 Codex 用户级配置", + } + if pathErr != nil { + status.Message = fmt.Sprintf("定位 Codex 配置失败: %v", pathErr) + return status + } + + serverConfig, found, err := readCodexMCPServerConfig(configPath, gonaviMCPServerID) + if err != nil { + status.Installed = found + status.Message = err.Error() + if found { + status.Command = strings.TrimSpace(serverConfig.Command) + status.Args = append([]string(nil), serverConfig.Args...) + } + return status + } + if !found { + return status + } + + status.Installed = true + status.Command = strings.TrimSpace(serverConfig.Command) + status.Args = append([]string(nil), serverConfig.Args...) + if expectedErr != nil { + status.Message = fmt.Sprintf("已检测到 Codex 安装记录,但当前 GoNavi 安装路径校验失败:%v", expectedErr) + return status + } + + status.MatchesCurrent = sameMCPCommand(serverConfig.Command, serverConfig.Args, expectedCommand, expectedArgs) && + (serverConfig.StartupTimeoutSec == 0 || serverConfig.StartupTimeoutSec == defaultCodexMCPStartupTimeoutSecond) + if status.MatchesCurrent { + status.Message = "已安装到 Codex 用户级配置" + return status + } + + status.Message = "已检测到 Codex 安装记录,但与当前 GoNavi 安装包路径不一致,建议更新安装" + return status +} + +func readClaudeCodeMCPServerConfig(configPath string, serverID string) (claudeCodeMCPServerConfig, bool, error) { + root, err := readClaudeCodeConfig(configPath) + if err != nil { + return claudeCodeMCPServerConfig{}, false, err + } + + rawServers, exists := root["mcpServers"] + if !exists || rawServers == nil { + return claudeCodeMCPServerConfig{}, false, nil + } + mcpServers, ok := rawServers.(map[string]any) + if !ok { + return claudeCodeMCPServerConfig{}, false, fmt.Errorf("Claude Code 配置格式异常:mcpServers 不是对象") + } + + rawServer, exists := mcpServers[strings.TrimSpace(serverID)] + if !exists || rawServer == nil { + return claudeCodeMCPServerConfig{}, false, nil + } + serverMap, ok := rawServer.(map[string]any) + if !ok { + return claudeCodeMCPServerConfig{}, true, fmt.Errorf("Claude Code 配置格式异常:mcpServers.%s 不是对象", strings.TrimSpace(serverID)) + } + + args, err := decodeJSONLikeStringSlice(serverMap["args"]) + if err != nil { + return claudeCodeMCPServerConfig{}, true, fmt.Errorf("Claude Code 配置格式异常:mcpServers.%s.args 不是字符串数组", strings.TrimSpace(serverID)) + } + return claudeCodeMCPServerConfig{ + Type: strings.TrimSpace(anyString(serverMap["type"])), + Command: strings.TrimSpace(anyString(serverMap["command"])), + Args: args, + }, true, nil +} + +func upsertClaudeCodeMCPServerConfig(configPath string, serverID string, serverConfig claudeCodeMCPServerConfig) error { + root, err := readClaudeCodeConfig(configPath) + if err != nil { + return err + } + + mcpServers, err := ensureJSONMap(root, "mcpServers") + if err != nil { + return err + } + + mcpServers[strings.TrimSpace(serverID)] = map[string]any{ + "type": serverConfig.Type, + "command": serverConfig.Command, + "args": append([]string(nil), serverConfig.Args...), + "env": cloneStringMap(serverConfig.Env), + } + root["mcpServers"] = mcpServers + + data, err := json.MarshalIndent(root, "", " ") + if err != nil { + return fmt.Errorf("序列化 Claude Code 配置失败: %w", err) + } + + if err := os.MkdirAll(filepath.Dir(configPath), 0o755); err != nil { + return fmt.Errorf("创建 Claude Code 配置目录失败: %w", err) + } + if err := os.WriteFile(configPath, append(data, '\n'), 0o644); err != nil { + return fmt.Errorf("写入 Claude Code 配置失败: %w", err) + } + return nil +} + +func readClaudeCodeConfig(configPath string) (map[string]any, error) { + data, err := os.ReadFile(configPath) + if err != nil { + if os.IsNotExist(err) { + return map[string]any{}, nil + } + return nil, fmt.Errorf("读取 Claude Code 配置失败: %w", err) + } + + if strings.TrimSpace(string(data)) == "" { + return map[string]any{}, nil + } + + var root map[string]any + if err := json.Unmarshal(data, &root); err != nil { + return nil, fmt.Errorf("解析 Claude Code 配置失败: %w", err) + } + if root == nil { + return map[string]any{}, nil + } + return root, nil +} + +func ensureJSONMap(root map[string]any, key string) (map[string]any, error) { + if root == nil { + return nil, fmt.Errorf("JSON 根对象不能为空") + } + + value, exists := root[key] + if !exists || value == nil { + result := map[string]any{} + root[key] = result + return result, nil + } + + typed, ok := value.(map[string]any) + if !ok { + return nil, fmt.Errorf("Claude Code 配置格式异常:%s 不是对象", key) + } + return typed, nil +} + +func readCodexMCPServerConfig(configPath string, serverID string) (codexMCPServerConfig, bool, error) { + data, err := os.ReadFile(configPath) + if err != nil { + if os.IsNotExist(err) { + return codexMCPServerConfig{}, false, nil + } + return codexMCPServerConfig{}, false, fmt.Errorf("读取 Codex 配置失败: %w", err) + } + return parseCodexMCPServerConfig(string(data), serverID) +} + +func upsertCodexMCPServerConfig(configPath string, serverID string, serverConfig codexMCPServerConfig) error { + data, err := os.ReadFile(configPath) + if err != nil && !os.IsNotExist(err) { + return fmt.Errorf("读取 Codex 配置失败: %w", err) + } + + updated := replaceOrAppendCodexMCPServerBlock(string(data), strings.TrimSpace(serverID), renderCodexMCPServerBlock(serverID, serverConfig)) + if err := os.MkdirAll(filepath.Dir(configPath), 0o755); err != nil { + return fmt.Errorf("创建 Codex 配置目录失败: %w", err) + } + if err := os.WriteFile(configPath, []byte(updated), 0o644); err != nil { + return fmt.Errorf("写入 Codex 配置失败: %w", err) + } + return nil +} + +func renderCodexMCPServerBlock(serverID string, serverConfig codexMCPServerConfig) string { + trimmedID := strings.TrimSpace(serverID) + if trimmedID == "" { + trimmedID = gonaviMCPServerID + } + + lines := []string{ + fmt.Sprintf("[mcp_servers.%s]", trimmedID), + fmt.Sprintf("command = %s", tomlString(serverConfig.Command)), + fmt.Sprintf("args = [%s]", strings.Join(renderTomlStringArray(serverConfig.Args), ", ")), + } + if serverConfig.StartupTimeoutSec > 0 { + lines = append(lines, fmt.Sprintf("startup_timeout_sec = %d", serverConfig.StartupTimeoutSec)) + } + return strings.Join(lines, "\n") + "\n" +} + +func parseCodexMCPServerConfig(content string, serverID string) (codexMCPServerConfig, bool, error) { + lines := strings.Split(strings.ReplaceAll(content, "\r\n", "\n"), "\n") + mainHeader := fmt.Sprintf("[mcp_servers.%s]", strings.TrimSpace(serverID)) + result := codexMCPServerConfig{} + found := false + inside := false + + for _, line := range lines { + trimmed := strings.TrimSpace(line) + if !inside { + if trimmed == mainHeader { + inside = true + found = true + } + continue + } + if isTOMLHeaderLine(trimmed) { + break + } + if trimmed == "" || strings.HasPrefix(trimmed, "#") { + continue + } + + key, value, ok := splitTOMLAssignment(trimmed) + if !ok { + continue + } + switch key { + case "command": + parsed, err := parseTOMLString(value) + if err != nil { + return result, true, fmt.Errorf("Codex 配置格式异常:mcp_servers.%s.command 解析失败", strings.TrimSpace(serverID)) + } + result.Command = parsed + case "args": + parsed, err := parseTOMLStringArray(value) + if err != nil { + return result, true, fmt.Errorf("Codex 配置格式异常:mcp_servers.%s.args 解析失败", strings.TrimSpace(serverID)) + } + result.Args = parsed + case "startup_timeout_sec": + parsed, err := strconv.Atoi(strings.TrimSpace(value)) + if err != nil { + return result, true, fmt.Errorf("Codex 配置格式异常:mcp_servers.%s.startup_timeout_sec 解析失败", strings.TrimSpace(serverID)) + } + result.StartupTimeoutSec = parsed + } + } + + return result, found, nil +} + +func replaceOrAppendCodexMCPServerBlock(content string, serverID string, block string) string { + lines := strings.Split(strings.ReplaceAll(content, "\r\n", "\n"), "\n") + mainHeader := fmt.Sprintf("[mcp_servers.%s]", serverID) + nestedPrefix := fmt.Sprintf("[mcp_servers.%s.", serverID) + + start, end := -1, -1 + for index, line := range lines { + trimmed := strings.TrimSpace(line) + if start == -1 { + if trimmed == mainHeader || strings.HasPrefix(trimmed, nestedPrefix) { + start = index + } + continue + } + if isTOMLHeaderLine(trimmed) && trimmed != mainHeader && !strings.HasPrefix(trimmed, nestedPrefix) { + end = index + break + } + } + if start != -1 && end == -1 { + end = len(lines) + } + + rendered := strings.TrimRight(block, "\n") + if start == -1 { + base := strings.TrimSpace(strings.Join(lines, "\n")) + if base == "" { + return rendered + "\n" + } + return strings.TrimRight(strings.Join(lines, "\n"), "\n") + "\n\n" + rendered + "\n" + } + + before := strings.TrimRight(strings.Join(lines[:start], "\n"), "\n") + after := strings.TrimLeft(strings.Join(lines[end:], "\n"), "\n") + switch { + case before == "" && after == "": + return rendered + "\n" + case before == "": + return rendered + "\n\n" + after + case after == "": + return before + "\n\n" + rendered + "\n" + default: + return before + "\n\n" + rendered + "\n\n" + after + } +} + +func renderTomlStringArray(values []string) []string { + rendered := make([]string, 0, len(values)) + for _, value := range values { + rendered = append(rendered, tomlString(value)) + } + return rendered +} + +func tomlString(value string) string { + if !strings.Contains(value, "'") && !strings.Contains(value, "\n") && !strings.Contains(value, "\r") { + return "'" + value + "'" + } + return strconv.Quote(value) +} + +func splitTOMLAssignment(line string) (string, string, bool) { + index := strings.Index(line, "=") + if index <= 0 { + return "", "", false + } + key := strings.TrimSpace(line[:index]) + value := strings.TrimSpace(line[index+1:]) + if key == "" { + return "", "", false + } + return key, value, true +} + +func parseTOMLString(value string) (string, error) { + value = strings.TrimSpace(value) + if len(value) < 2 { + return "", fmt.Errorf("字符串格式非法") + } + switch value[0] { + case '\'': + if value[len(value)-1] != '\'' { + return "", fmt.Errorf("单引号字符串未闭合") + } + return value[1 : len(value)-1], nil + case '"': + parsed, err := strconv.Unquote(value) + if err != nil { + return "", err + } + return parsed, nil + default: + return "", fmt.Errorf("不是字符串") + } +} + +func parseTOMLStringArray(value string) ([]string, error) { + value = strings.TrimSpace(value) + if value == "" { + return []string{}, nil + } + if !strings.HasPrefix(value, "[") || !strings.HasSuffix(value, "]") { + return nil, fmt.Errorf("不是数组") + } + + inner := strings.TrimSpace(value[1 : len(value)-1]) + if inner == "" { + return []string{}, nil + } + + result := make([]string, 0, 4) + for inner != "" { + item, rest, err := consumeTOMLQuotedString(inner) + if err != nil { + return nil, err + } + result = append(result, item) + inner = strings.TrimSpace(rest) + if inner == "" { + break + } + if !strings.HasPrefix(inner, ",") { + return nil, fmt.Errorf("数组分隔符非法") + } + inner = strings.TrimSpace(inner[1:]) + } + return result, nil +} + +func consumeTOMLQuotedString(value string) (string, string, error) { + value = strings.TrimLeft(value, " \t") + if value == "" { + return "", "", fmt.Errorf("字符串为空") + } + switch value[0] { + case '\'': + end := strings.IndexByte(value[1:], '\'') + if end < 0 { + return "", "", fmt.Errorf("单引号字符串未闭合") + } + end++ + return value[1:end], value[end+1:], nil + case '"': + escaped := false + for index := 1; index < len(value); index++ { + ch := value[index] + if escaped { + escaped = false + continue + } + if ch == '\\' { + escaped = true + continue + } + if ch == '"' { + parsed, err := strconv.Unquote(value[:index+1]) + if err != nil { + return "", "", err + } + return parsed, value[index+1:], nil + } + } + return "", "", fmt.Errorf("双引号字符串未闭合") + default: + return "", "", fmt.Errorf("不是字符串") + } +} + +func decodeJSONLikeStringSlice(value any) ([]string, error) { + switch typed := value.(type) { + case nil: + return []string{}, nil + case []string: + return append([]string(nil), typed...), nil + case []any: + result := make([]string, 0, len(typed)) + for _, item := range typed { + str, ok := item.(string) + if !ok { + return nil, fmt.Errorf("数组元素不是字符串") + } + result = append(result, str) + } + return result, nil + default: + return nil, fmt.Errorf("不是字符串数组") + } +} + +func anyString(value any) string { + text, _ := value.(string) + return text +} + +func sameMCPCommand(actualCommand string, actualArgs []string, expectedCommand string, expectedArgs []string) bool { + return strings.TrimSpace(actualCommand) == strings.TrimSpace(expectedCommand) && + reflect.DeepEqual(normalizeStringSlice(actualArgs), normalizeStringSlice(expectedArgs)) +} + +func normalizeStringSlice(values []string) []string { + if len(values) == 0 { + return []string{} + } + result := make([]string, 0, len(values)) + for _, value := range values { + result = append(result, strings.TrimSpace(value)) + } + return result +} + +func isTOMLHeaderLine(line string) bool { + line = strings.TrimSpace(line) + return strings.HasPrefix(line, "[") && strings.HasSuffix(line, "]") +} diff --git a/internal/ai/service/claude_code_mcp_test.go b/internal/ai/service/claude_code_mcp_test.go new file mode 100644 index 0000000..d179ca3 --- /dev/null +++ b/internal/ai/service/claude_code_mcp_test.go @@ -0,0 +1,276 @@ +package aiservice + +import ( + "encoding/json" + "os" + "path/filepath" + "reflect" + "strings" + "testing" +) + +func TestResolveLocalMCPCommandUsesMainBinaryWithArgument(t *testing.T) { + command, args, err := resolveLocalMCPCommand(`C:\Program Files\GoNavi\GoNavi.exe`) + if err != nil { + t.Fatalf("resolveLocalMCPCommand returned error: %v", err) + } + if command != `C:\Program Files\GoNavi\GoNavi.exe` { + t.Fatalf("expected command to keep main binary path, got %q", command) + } + if !reflect.DeepEqual(args, []string{"mcp-server"}) { + t.Fatalf("expected main binary args %#v, got %#v", []string{"mcp-server"}, args) + } +} + +func TestResolveLocalMCPCommandKeepsDedicatedServerBinary(t *testing.T) { + command, args, err := resolveLocalMCPCommand(`D:\Work\CodeRepos\GoNavi\bin\gonavi-mcp-server.exe`) + if err != nil { + t.Fatalf("resolveLocalMCPCommand returned error: %v", err) + } + if command != `D:\Work\CodeRepos\GoNavi\bin\gonavi-mcp-server.exe` { + t.Fatalf("expected dedicated server path to be reused, got %q", command) + } + if len(args) != 0 { + t.Fatalf("expected dedicated server args to be empty, got %#v", args) + } +} + +func TestReadClaudeCodeMCPServerConfigReadsExistingInstall(t *testing.T) { + tempDir := t.TempDir() + configPath := filepath.Join(tempDir, ".claude.json") + initial := map[string]any{ + "mcpServers": map[string]any{ + gonaviMCPServerID: map[string]any{ + "type": "stdio", + "command": `C:\Program Files\GoNavi\GoNavi.exe`, + "args": []string{"mcp-server"}, + }, + }, + } + data, err := json.MarshalIndent(initial, "", " ") + if err != nil { + t.Fatalf("MarshalIndent returned error: %v", err) + } + if err := os.WriteFile(configPath, append(data, '\n'), 0o644); err != nil { + t.Fatalf("WriteFile returned error: %v", err) + } + + cfg, found, err := readClaudeCodeMCPServerConfig(configPath, gonaviMCPServerID) + if err != nil { + t.Fatalf("readClaudeCodeMCPServerConfig returned error: %v", err) + } + if !found { + t.Fatal("expected gonavi install to be detected") + } + if cfg.Command != `C:\Program Files\GoNavi\GoNavi.exe` { + t.Fatalf("unexpected command: %q", cfg.Command) + } + if !reflect.DeepEqual(cfg.Args, []string{"mcp-server"}) { + t.Fatalf("unexpected args: %#v", cfg.Args) + } +} + +func TestUpsertClaudeCodeMCPServerConfigCreatesAndMergesUserConfig(t *testing.T) { + tempDir := t.TempDir() + configPath := filepath.Join(tempDir, ".claude.json") + initial := map[string]any{ + "theme": "dark-daltonized", + "mcpServers": map[string]any{ + "memory": map[string]any{ + "type": "stdio", + "command": "cmd", + }, + }, + } + data, err := json.MarshalIndent(initial, "", " ") + if err != nil { + t.Fatalf("MarshalIndent returned error: %v", err) + } + if err := os.WriteFile(configPath, append(data, '\n'), 0o644); err != nil { + t.Fatalf("WriteFile returned error: %v", err) + } + + err = upsertClaudeCodeMCPServerConfig(configPath, gonaviMCPServerID, claudeCodeMCPServerConfig{ + Type: "stdio", + Command: `C:\Program Files\GoNavi\GoNavi.exe`, + Args: []string{"mcp-server"}, + Env: map[string]string{}, + }) + if err != nil { + t.Fatalf("upsertClaudeCodeMCPServerConfig returned error: %v", err) + } + + updated, err := os.ReadFile(configPath) + if err != nil { + t.Fatalf("ReadFile returned error: %v", err) + } + + var root map[string]any + if err := json.Unmarshal(updated, &root); err != nil { + t.Fatalf("Unmarshal returned error: %v", err) + } + if got := strings.TrimSpace(root["theme"].(string)); got != "dark-daltonized" { + t.Fatalf("expected theme to be preserved, got %q", got) + } + + mcpServers, ok := root["mcpServers"].(map[string]any) + if !ok { + t.Fatalf("expected mcpServers object, got %#v", root["mcpServers"]) + } + if _, ok := mcpServers["memory"]; !ok { + t.Fatalf("expected existing memory server to be preserved, got %#v", mcpServers) + } + + gonavi, ok := mcpServers[gonaviMCPServerID].(map[string]any) + if !ok { + t.Fatalf("expected gonavi server object, got %#v", mcpServers[gonaviMCPServerID]) + } + if got := strings.TrimSpace(gonavi["command"].(string)); got != `C:\Program Files\GoNavi\GoNavi.exe` { + t.Fatalf("expected gonavi command to be written, got %q", got) + } + args, ok := gonavi["args"].([]any) + if !ok || len(args) != 1 || strings.TrimSpace(args[0].(string)) != "mcp-server" { + t.Fatalf("expected gonavi args to contain mcp-server, got %#v", gonavi["args"]) + } +} + +func TestUpsertClaudeCodeMCPServerConfigRejectsInvalidMCPServersShape(t *testing.T) { + tempDir := t.TempDir() + configPath := filepath.Join(tempDir, ".claude.json") + if err := os.WriteFile(configPath, []byte("{\"mcpServers\":[]}"), 0o644); err != nil { + t.Fatalf("WriteFile returned error: %v", err) + } + + err := upsertClaudeCodeMCPServerConfig(configPath, gonaviMCPServerID, claudeCodeMCPServerConfig{ + Type: "stdio", + Command: "GoNavi.exe", + }) + if err == nil { + t.Fatal("expected invalid mcpServers shape to return error") + } + if !strings.Contains(err.Error(), "mcpServers 不是对象") { + t.Fatalf("expected invalid shape error, got %v", err) + } +} + +func TestParseCodexMCPServerConfigDetectsExistingInstall(t *testing.T) { + content := strings.Join([]string{ + `model = "gpt-5.4"`, + ``, + `[mcp_servers.gonavi]`, + `command = 'C:\Program Files\GoNavi\GoNavi.exe'`, + `args = ['mcp-server']`, + `startup_timeout_sec = 60`, + ``, + `[projects.'D:\Work\CodeRepos\GoNavi']`, + `trust_level = "trusted"`, + ``, + }, "\n") + + cfg, found, err := parseCodexMCPServerConfig(content, gonaviMCPServerID) + if err != nil { + t.Fatalf("parseCodexMCPServerConfig returned error: %v", err) + } + if !found { + t.Fatal("expected gonavi install to be detected") + } + if cfg.Command != `C:\Program Files\GoNavi\GoNavi.exe` { + t.Fatalf("unexpected command: %q", cfg.Command) + } + if !reflect.DeepEqual(cfg.Args, []string{"mcp-server"}) { + t.Fatalf("unexpected args: %#v", cfg.Args) + } + if cfg.StartupTimeoutSec != 60 { + t.Fatalf("unexpected startup timeout: %d", cfg.StartupTimeoutSec) + } +} + +func TestUpsertCodexMCPServerConfigCreatesAndMergesConfig(t *testing.T) { + tempDir := t.TempDir() + configPath := filepath.Join(tempDir, "config.toml") + initial := strings.Join([]string{ + `model = "gpt-5.4"`, + ``, + `[mcp_servers.memory]`, + `command = "cmd"`, + `args = ["/c", "npx"]`, + ``, + }, "\n") + if err := os.WriteFile(configPath, []byte(initial), 0o644); err != nil { + t.Fatalf("WriteFile returned error: %v", err) + } + + err := upsertCodexMCPServerConfig(configPath, gonaviMCPServerID, codexMCPServerConfig{ + Command: `C:\Program Files\GoNavi\GoNavi.exe`, + Args: []string{"mcp-server"}, + StartupTimeoutSec: defaultCodexMCPStartupTimeoutSecond, + }) + if err != nil { + t.Fatalf("upsertCodexMCPServerConfig returned error: %v", err) + } + + updated, err := os.ReadFile(configPath) + if err != nil { + t.Fatalf("ReadFile returned error: %v", err) + } + text := string(updated) + if !strings.Contains(text, `[mcp_servers.memory]`) { + t.Fatalf("expected memory server to be preserved, got %s", text) + } + if !strings.Contains(text, `[mcp_servers.gonavi]`) { + t.Fatalf("expected gonavi section to be created, got %s", text) + } + if !strings.Contains(text, `command = 'C:\Program Files\GoNavi\GoNavi.exe'`) { + t.Fatalf("expected gonavi command to be written, got %s", text) + } + if !strings.Contains(text, `args = ['mcp-server']`) { + t.Fatalf("expected gonavi args to be written, got %s", text) + } + if !strings.Contains(text, `startup_timeout_sec = 60`) { + t.Fatalf("expected startup timeout to be written, got %s", text) + } +} + +func TestUpsertCodexMCPServerConfigReplacesExistingBlockAndNestedSections(t *testing.T) { + tempDir := t.TempDir() + configPath := filepath.Join(tempDir, "config.toml") + initial := strings.Join([]string{ + `model = "gpt-5.4"`, + ``, + `[mcp_servers.gonavi]`, + `command = 'old.exe'`, + `args = ['old']`, + `startup_timeout_sec = 15`, + ``, + `[mcp_servers.gonavi.env]`, + `FOO = "bar"`, + ``, + `[projects.'D:\Work\CodeRepos\GoNavi']`, + `trust_level = "trusted"`, + ``, + }, "\n") + if err := os.WriteFile(configPath, []byte(initial), 0o644); err != nil { + t.Fatalf("WriteFile returned error: %v", err) + } + + err := upsertCodexMCPServerConfig(configPath, gonaviMCPServerID, codexMCPServerConfig{ + Command: `C:\Program Files\GoNavi\GoNavi.exe`, + Args: []string{"mcp-server"}, + StartupTimeoutSec: defaultCodexMCPStartupTimeoutSecond, + }) + if err != nil { + t.Fatalf("upsertCodexMCPServerConfig returned error: %v", err) + } + + updated, err := os.ReadFile(configPath) + if err != nil { + t.Fatalf("ReadFile returned error: %v", err) + } + text := string(updated) + if strings.Contains(text, `command = 'old.exe'`) || strings.Contains(text, `[mcp_servers.gonavi.env]`) { + t.Fatalf("expected old gonavi block to be replaced, got %s", text) + } + if !strings.Contains(text, `[projects.'D:\Work\CodeRepos\GoNavi']`) { + t.Fatalf("expected unrelated project config to be preserved, got %s", text) + } +} diff --git a/internal/ai/types.go b/internal/ai/types.go index 0783a77..64a60c0 100644 --- a/internal/ai/types.go +++ b/internal/ai/types.go @@ -136,6 +136,31 @@ type MCPToolCallResult struct { IsError bool `json:"isError"` } +// MCPClientInstallResult 表示安装 GoNavi 到外部 MCP 客户端配置文件的结果。 +type MCPClientInstallResult struct { + Success bool `json:"success"` + Client string `json:"client,omitempty"` + Message string `json:"message"` + ConfigPath string `json:"configPath,omitempty"` + Command string `json:"command,omitempty"` + Args []string `json:"args,omitempty"` +} + +// MCPClientInstallStatus 表示 GoNavi MCP 在外部客户端中的当前安装状态。 +type MCPClientInstallStatus struct { + Client string `json:"client"` + DisplayName string `json:"displayName"` + Installed bool `json:"installed"` + MatchesCurrent bool `json:"matchesCurrent"` + Message string `json:"message"` + ConfigPath string `json:"configPath,omitempty"` + Command string `json:"command,omitempty"` + Args []string `json:"args,omitempty"` +} + +// ClaudeCodeMCPInstallResult 兼容旧命名,便于平滑迁移到通用结果类型。 +type ClaudeCodeMCPInstallResult = MCPClientInstallResult + // SkillScope 表示 Skill 的适用场景 type SkillScope string diff --git a/internal/mcpserver/backend.go b/internal/mcpserver/backend.go new file mode 100644 index 0000000..6923f90 --- /dev/null +++ b/internal/mcpserver/backend.go @@ -0,0 +1,98 @@ +package mcpserver + +import ( + "context" + + "GoNavi-Wails/internal/ai" + aiservice "GoNavi-Wails/internal/ai/service" + appcore "GoNavi-Wails/internal/app" + "GoNavi-Wails/internal/appdata" + "GoNavi-Wails/internal/connection" + "GoNavi-Wails/internal/logger" +) + +// Backend 抽象 GoNavi 后端能力,便于复用真实 App 和单元测试替身。 +type Backend interface { + Close(context.Context) error + GetSavedConnections() ([]connection.SavedConnectionView, error) + GetEditableSavedConnection(id string) (connection.SavedConnectionView, error) + DBGetDatabases(config connection.ConnectionConfig) connection.QueryResult + DBGetTables(config connection.ConnectionConfig, dbName string) connection.QueryResult + DBGetColumns(config connection.ConnectionConfig, dbName string, tableName string) connection.QueryResult + DBShowCreateTable(config connection.ConnectionConfig, dbName string, tableName string) connection.QueryResult + DBQueryMulti(config connection.ConnectionConfig, dbName string, query string, queryID string) connection.QueryResult + InspectSQL(dbType string, sql string) appcore.SQLInspection + GetSQLSafetyLevel() ai.SQLPermissionLevel +} + +// AppBackend 基于现有 internal/app.App 暴露 MCP 所需数据库能力。 +type AppBackend struct { + app *appcore.App +} + +func NewAppBackend(ctx context.Context) *AppBackend { + if ctx == nil { + ctx = context.Background() + } + a := appcore.NewApp() + appcore.InitializeLifecycle(a, ctx) + return &AppBackend{app: a} +} + +func (b *AppBackend) Close(ctx context.Context) error { + if b == nil || b.app == nil { + return nil + } + if ctx == nil { + ctx = context.Background() + } + b.app.Shutdown(ctx) + return nil +} + +func (b *AppBackend) GetSavedConnections() ([]connection.SavedConnectionView, error) { + return b.app.GetSavedConnections() +} + +func (b *AppBackend) GetEditableSavedConnection(id string) (connection.SavedConnectionView, error) { + return b.app.GetEditableSavedConnection(id) +} + +func (b *AppBackend) DBGetDatabases(config connection.ConnectionConfig) connection.QueryResult { + return b.app.DBGetDatabases(config) +} + +func (b *AppBackend) DBGetTables(config connection.ConnectionConfig, dbName string) connection.QueryResult { + return b.app.DBGetTables(config, dbName) +} + +func (b *AppBackend) DBGetColumns(config connection.ConnectionConfig, dbName string, tableName string) connection.QueryResult { + return b.app.DBGetColumns(config, dbName, tableName) +} + +func (b *AppBackend) DBShowCreateTable(config connection.ConnectionConfig, dbName string, tableName string) connection.QueryResult { + return b.app.DBShowCreateTable(config, dbName, tableName) +} + +func (b *AppBackend) DBQueryMulti(config connection.ConnectionConfig, dbName string, query string, queryID string) connection.QueryResult { + return b.app.DBQueryMulti(config, dbName, query, queryID) +} + +func (b *AppBackend) InspectSQL(dbType string, sql string) appcore.SQLInspection { + return appcore.InspectSQL(dbType, sql) +} + +func (b *AppBackend) GetSQLSafetyLevel() ai.SQLPermissionLevel { + inspection, err := aiservice.NewProviderConfigStore(appdata.MustResolveActiveRoot(), nil).Inspect() + if err != nil { + logger.Error(err, "加载 MCP SQL 安全控制失败,按只读模式回退") + return ai.PermissionReadOnly + } + + switch inspection.Snapshot.SafetyLevel { + case ai.PermissionReadOnly, ai.PermissionReadWrite, ai.PermissionFull: + return inspection.Snapshot.SafetyLevel + default: + return ai.PermissionReadOnly + } +} diff --git a/internal/mcpserver/run.go b/internal/mcpserver/run.go new file mode 100644 index 0000000..3374c97 --- /dev/null +++ b/internal/mcpserver/run.go @@ -0,0 +1,29 @@ +package mcpserver + +import ( + "context" + + "github.com/modelcontextprotocol/go-sdk/mcp" +) + +// RunAppStdioServer 启动基于真实 GoNavi App 的 stdio MCP server。 +func RunAppStdioServer(ctx context.Context) error { + if ctx == nil { + ctx = context.Background() + } + + backend := NewAppBackend(ctx) + defer backend.Close(ctx) + + return RunStdioServer(ctx, backend) +} + +// RunStdioServer 使用指定 backend 启动 stdio MCP server。 +func RunStdioServer(ctx context.Context, backend Backend) error { + if ctx == nil { + ctx = context.Background() + } + + server := NewServer(backend) + return server.Run(ctx, &mcp.StdioTransport{}) +} diff --git a/internal/mcpserver/server.go b/internal/mcpserver/server.go new file mode 100644 index 0000000..b7a79e5 --- /dev/null +++ b/internal/mcpserver/server.go @@ -0,0 +1,59 @@ +package mcpserver + +import ( + "runtime/debug" + "strings" + + "github.com/modelcontextprotocol/go-sdk/mcp" +) + +func NewServer(backend Backend) *mcp.Server { + server := mcp.NewServer(&mcp.Implementation{ + Name: "gonavi-ai", + Version: implementationVersion(), + }, nil) + + service := NewService(backend) + + mcp.AddTool(server, &mcp.Tool{ + Name: "get_connections", + Description: "列出当前 GoNavi 已保存的数据库连接,先调用它获取 connectionId。不会返回明文密码等敏感信息。", + }, service.GetConnections) + + mcp.AddTool(server, &mcp.Tool{ + Name: "get_databases", + Description: "根据 connectionId 获取数据库/Schema 列表。", + }, service.GetDatabases) + + mcp.AddTool(server, &mcp.Tool{ + Name: "get_tables", + Description: "根据 connectionId 和可选 dbName 获取表列表。dbName 为空时优先使用保存连接里的默认数据库。", + }, service.GetTables) + + mcp.AddTool(server, &mcp.Tool{ + Name: "get_columns", + Description: "根据 connectionId、可选 dbName、tableName 获取字段定义。", + }, service.GetColumns) + + mcp.AddTool(server, &mcp.Tool{ + Name: "get_table_ddl", + Description: "根据 connectionId、可选 dbName、tableName 获取建表或建视图语句。", + }, service.GetTableDDL) + + mcp.AddTool(server, &mcp.Tool{ + Name: "execute_sql", + Description: "执行 SQL,支持多语句结果集。执行范围受 GoNavi AI 设置中的安全控制约束;命中允许范围内的 DML/DDL 等非只读语句时,仍必须显式传 allowMutating=true。", + }, service.ExecuteSQL) + + return server +} + +func implementationVersion() string { + if info, ok := debug.ReadBuildInfo(); ok { + version := strings.TrimSpace(info.Main.Version) + if version != "" && version != "(devel)" { + return version + } + } + return "dev" +} diff --git a/internal/mcpserver/service.go b/internal/mcpserver/service.go new file mode 100644 index 0000000..50eaaf4 --- /dev/null +++ b/internal/mcpserver/service.go @@ -0,0 +1,682 @@ +package mcpserver + +import ( + "context" + "encoding/json" + "fmt" + "strings" + + "GoNavi-Wails/internal/ai" + appcore "GoNavi-Wails/internal/app" + "GoNavi-Wails/internal/connection" + "github.com/modelcontextprotocol/go-sdk/mcp" +) + +const ( + defaultMaxRowsPerResult = 200 + maxRowsPerResultLimit = 1000 +) + +type Service struct { + backend Backend +} + +func NewService(backend Backend) *Service { + return &Service{backend: backend} +} + +type emptyArgs struct{} + +type connectionIDArgs struct { + ConnectionID string `json:"connectionId" jsonschema:"get_connections 返回的连接 ID"` +} + +type databaseArgs struct { + ConnectionID string `json:"connectionId" jsonschema:"get_connections 返回的连接 ID"` + DBName string `json:"dbName,omitempty" jsonschema:"可选数据库/Schema 名称。为空时优先使用保存连接里的默认数据库"` +} + +type tableArgs struct { + ConnectionID string `json:"connectionId" jsonschema:"get_connections 返回的连接 ID"` + DBName string `json:"dbName,omitempty" jsonschema:"可选数据库/Schema 名称。为空时优先使用保存连接里的默认数据库"` + TableName string `json:"tableName" jsonschema:"目标表或视图名称"` +} + +type executeSQLArgs struct { + ConnectionID string `json:"connectionId" jsonschema:"get_connections 返回的连接 ID"` + DBName string `json:"dbName,omitempty" jsonschema:"可选数据库/Schema 名称。为空时优先使用保存连接里的默认数据库"` + SQL string `json:"sql" jsonschema:"待执行的 SQL 文本,可以包含多条语句"` + AllowMutating bool `json:"allowMutating,omitempty" jsonschema:"当 SQL 包含当前 AI 安全控制允许范围内的 DDL/DML 等非只读语句时,必须显式设为 true"` + MaxRowsPerResult int `json:"maxRowsPerResult,omitempty" jsonschema:"每个结果集最多返回多少行。默认 200,最大 1000"` +} + +type connectionDescriptor struct { + ID string `json:"id"` + Name string `json:"name"` + Type string `json:"type"` + Host string `json:"host,omitempty"` + Port int `json:"port,omitempty"` + Database string `json:"database,omitempty"` + Driver string `json:"driver,omitempty"` + Topology string `json:"topology,omitempty"` + Target string `json:"target,omitempty"` + UseSSH bool `json:"useSSH,omitempty"` + UseProxy bool `json:"useProxy,omitempty"` + UseHTTPTunnel bool `json:"useHttpTunnel,omitempty"` + DefaultDatabase string `json:"defaultDatabase,omitempty"` +} + +type getConnectionsResult struct { + Connections []connectionDescriptor `json:"connections"` +} + +type getDatabasesResult struct { + ConnectionID string `json:"connectionId"` + Databases []string `json:"databases"` +} + +type getTablesResult struct { + ConnectionID string `json:"connectionId"` + DBName string `json:"dbName,omitempty"` + Tables []string `json:"tables"` +} + +type getColumnsResult struct { + ConnectionID string `json:"connectionId"` + DBName string `json:"dbName,omitempty"` + TableName string `json:"tableName"` + Columns []connection.ColumnDefinition `json:"columns"` +} + +type getTableDDLResult struct { + ConnectionID string `json:"connectionId"` + DBName string `json:"dbName,omitempty"` + TableName string `json:"tableName"` + DDL string `json:"ddl"` +} + +type sqlStatementSummary struct { + Index int `json:"index"` + Keyword string `json:"keyword,omitempty"` + ReadOnly bool `json:"readOnly"` +} + +type sqlResultSet struct { + StatementIndex int `json:"statementIndex,omitempty"` + Columns []string `json:"columns"` + Rows []map[string]interface{} `json:"rows"` + Messages []string `json:"messages,omitempty"` + RowCount int `json:"rowCount"` + Truncated bool `json:"truncated,omitempty"` +} + +type executeSQLResult struct { + ConnectionID string `json:"connectionId"` + DBName string `json:"dbName,omitempty"` + StatementCount int `json:"statementCount"` + ReadOnly bool `json:"readOnly"` + QueryID string `json:"queryId,omitempty"` + Message string `json:"message,omitempty"` + Truncated bool `json:"truncated,omitempty"` + Statements []sqlStatementSummary `json:"statements"` + Results []sqlResultSet `json:"results"` +} + +func (s *Service) GetConnections(ctx context.Context, req *mcp.CallToolRequest, args emptyArgs) (*mcp.CallToolResult, getConnectionsResult, error) { + _ = ctx + _ = req + _ = args + + items, err := s.backend.GetSavedConnections() + if err != nil { + return toolError("获取已保存连接失败: %v", err), getConnectionsResult{}, nil + } + + result := getConnectionsResult{ + Connections: make([]connectionDescriptor, 0, len(items)), + } + for _, item := range items { + cfg := item.Config + result.Connections = append(result.Connections, connectionDescriptor{ + ID: item.ID, + Name: item.Name, + Type: strings.TrimSpace(cfg.Type), + Host: strings.TrimSpace(cfg.Host), + Port: cfg.Port, + Database: strings.TrimSpace(cfg.Database), + Driver: strings.TrimSpace(cfg.Driver), + Topology: strings.TrimSpace(cfg.Topology), + Target: describeConnectionTarget(cfg), + UseSSH: cfg.UseSSH, + UseProxy: cfg.UseProxy, + UseHTTPTunnel: cfg.UseHTTPTunnel, + DefaultDatabase: strings.TrimSpace(cfg.Database), + }) + } + return successResult(), result, nil +} + +func (s *Service) GetDatabases(ctx context.Context, req *mcp.CallToolRequest, args connectionIDArgs) (*mcp.CallToolResult, getDatabasesResult, error) { + _ = ctx + _ = req + + view, errResult := s.resolveConnection(args.ConnectionID) + if errResult != nil { + return errResult, getDatabasesResult{}, nil + } + + queryResult := s.backend.DBGetDatabases(view.Config) + if !queryResult.Success { + return toolError("获取数据库列表失败: %s", strings.TrimSpace(queryResult.Message)), getDatabasesResult{}, nil + } + + databases, err := decodeNamedStringSlice(queryResult.Data, "Database", "database", "name") + if err != nil { + return toolError("解析数据库列表失败: %v", err), getDatabasesResult{}, nil + } + + return successResult(), getDatabasesResult{ + ConnectionID: view.ID, + Databases: ensureNonNilStrings(databases), + }, nil +} + +func (s *Service) GetTables(ctx context.Context, req *mcp.CallToolRequest, args databaseArgs) (*mcp.CallToolResult, getTablesResult, error) { + _ = ctx + _ = req + + view, errResult := s.resolveConnection(args.ConnectionID) + if errResult != nil { + return errResult, getTablesResult{}, nil + } + + dbName := effectiveDBName(args.DBName, view.Config) + queryResult := s.backend.DBGetTables(view.Config, dbName) + if !queryResult.Success { + return toolError("获取表列表失败: %s", strings.TrimSpace(queryResult.Message)), getTablesResult{}, nil + } + + tables, err := decodeNamedStringSlice(queryResult.Data, "Table", "table", "name") + if err != nil { + return toolError("解析表列表失败: %v", err), getTablesResult{}, nil + } + + return successResult(), getTablesResult{ + ConnectionID: view.ID, + DBName: dbName, + Tables: ensureNonNilStrings(tables), + }, nil +} + +func (s *Service) GetColumns(ctx context.Context, req *mcp.CallToolRequest, args tableArgs) (*mcp.CallToolResult, getColumnsResult, error) { + _ = ctx + _ = req + + view, errResult := s.resolveConnection(args.ConnectionID) + if errResult != nil { + return errResult, getColumnsResult{}, nil + } + + tableName := strings.TrimSpace(args.TableName) + if tableName == "" { + return toolError("tableName 不能为空"), getColumnsResult{}, nil + } + + dbName := effectiveDBName(args.DBName, view.Config) + queryResult := s.backend.DBGetColumns(view.Config, dbName, tableName) + if !queryResult.Success { + return toolError("获取字段列表失败: %s", strings.TrimSpace(queryResult.Message)), getColumnsResult{}, nil + } + + columns, err := decodeColumns(queryResult.Data) + if err != nil { + return toolError("解析字段列表失败: %v", err), getColumnsResult{}, nil + } + + return successResult(), getColumnsResult{ + ConnectionID: view.ID, + DBName: dbName, + TableName: tableName, + Columns: ensureNonNilColumns(columns), + }, nil +} + +func (s *Service) GetTableDDL(ctx context.Context, req *mcp.CallToolRequest, args tableArgs) (*mcp.CallToolResult, getTableDDLResult, error) { + _ = ctx + _ = req + + view, errResult := s.resolveConnection(args.ConnectionID) + if errResult != nil { + return errResult, getTableDDLResult{}, nil + } + + tableName := strings.TrimSpace(args.TableName) + if tableName == "" { + return toolError("tableName 不能为空"), getTableDDLResult{}, nil + } + + dbName := effectiveDBName(args.DBName, view.Config) + queryResult := s.backend.DBShowCreateTable(view.Config, dbName, tableName) + if !queryResult.Success { + return toolError("获取建表语句失败: %s", strings.TrimSpace(queryResult.Message)), getTableDDLResult{}, nil + } + + ddl, err := decodeString(queryResult.Data) + if err != nil { + return toolError("解析建表语句失败: %v", err), getTableDDLResult{}, nil + } + + return successResult(), getTableDDLResult{ + ConnectionID: view.ID, + DBName: dbName, + TableName: tableName, + DDL: ddl, + }, nil +} + +func (s *Service) ExecuteSQL(ctx context.Context, req *mcp.CallToolRequest, args executeSQLArgs) (*mcp.CallToolResult, executeSQLResult, error) { + _ = ctx + _ = req + + view, errResult := s.resolveConnection(args.ConnectionID) + if errResult != nil { + return errResult, executeSQLResult{}, nil + } + + sqlText := strings.TrimSpace(args.SQL) + if sqlText == "" { + return toolError("sql 不能为空"), executeSQLResult{}, nil + } + + inspection := s.backend.InspectSQL(view.Config.Type, sqlText) + if inspection.StatementCount == 0 { + return toolError("未识别到可执行的 SQL 语句"), executeSQLResult{}, nil + } + + safetyLevel := normalizeSQLSafetyLevel(s.backend.GetSQLSafetyLevel()) + safetyDecision := evaluateSQLSafety(safetyLevel, inspection) + if len(safetyDecision.disallowed) > 0 { + return toolError("%s", buildSafetyDeniedMessage(safetyLevel, safetyDecision.disallowed)), executeSQLResult{}, nil + } + if safetyDecision.requiresConfirm && !args.AllowMutating { + return toolError("当前 SQL 已通过 GoNavi AI 安全控制(%s),但包含非只读语句 %s,请显式传入 allowMutating=true 后重试", safetyLevelDisplayName(safetyLevel), formatSafetyStatements(safetyDecision.confirmRequired)), executeSQLResult{}, nil + } + + dbName := effectiveDBName(args.DBName, view.Config) + queryResult := s.backend.DBQueryMulti(view.Config, dbName, sqlText, "") + if !queryResult.Success { + return toolError("SQL 执行失败: %s", strings.TrimSpace(queryResult.Message)), executeSQLResult{}, nil + } + + resultSets, err := decodeResultSets(queryResult.Data) + if err != nil { + return toolError("解析 SQL 执行结果失败: %v", err), executeSQLResult{}, nil + } + + normalizedResults, truncated := normalizeResultSets(resultSets, normalizeMaxRowsPerResult(args.MaxRowsPerResult)) + return successResult(), executeSQLResult{ + ConnectionID: view.ID, + DBName: dbName, + StatementCount: inspection.StatementCount, + ReadOnly: inspection.ReadOnly, + QueryID: strings.TrimSpace(queryResult.QueryID), + Message: strings.TrimSpace(queryResult.Message), + Truncated: truncated, + Statements: toStatementSummaries(inspection.Statements), + Results: normalizedResults, + }, nil +} + +func successResult() *mcp.CallToolResult { + return &mcp.CallToolResult{} +} + +func toolError(format string, args ...interface{}) *mcp.CallToolResult { + return &mcp.CallToolResult{ + IsError: true, + Content: []mcp.Content{ + &mcp.TextContent{Text: fmt.Sprintf(format, args...)}, + }, + } +} + +func (s *Service) resolveConnection(connectionID string) (connection.SavedConnectionView, *mcp.CallToolResult) { + id := strings.TrimSpace(connectionID) + if id == "" { + return connection.SavedConnectionView{}, toolError("connectionId 不能为空") + } + view, err := s.backend.GetEditableSavedConnection(id) + if err != nil { + return connection.SavedConnectionView{}, toolError("加载连接 %s 失败: %v", id, err) + } + return view, nil +} + +func effectiveDBName(input string, config connection.ConnectionConfig) string { + if trimmed := strings.TrimSpace(input); trimmed != "" { + return trimmed + } + return strings.TrimSpace(config.Database) +} + +func describeConnectionTarget(config connection.ConnectionConfig) string { + dbType := strings.ToLower(strings.TrimSpace(config.Type)) + switch dbType { + case "sqlite", "duckdb": + if path := strings.TrimSpace(config.Database); path != "" { + return path + } + } + if len(config.Hosts) > 0 { + return strings.Join(config.Hosts, ",") + } + if host := strings.TrimSpace(config.Host); host != "" { + if config.Port > 0 { + return fmt.Sprintf("%s:%d", host, config.Port) + } + return host + } + if uri := strings.TrimSpace(config.URI); uri != "" { + return uri + } + if dsn := strings.TrimSpace(config.DSN); dsn != "" { + return dsn + } + return strings.TrimSpace(config.Database) +} + +func decodeNamedStringSlice(data interface{}, keys ...string) ([]string, error) { + switch items := data.(type) { + case nil: + return []string{}, nil + case []string: + return ensureNonNilStrings(append([]string(nil), items...)), nil + case []map[string]string: + result := make([]string, 0, len(items)) + for _, item := range items { + result = append(result, pickNamedStringFromStringMap(item, keys...)) + } + return result, nil + case []map[string]interface{}: + result := make([]string, 0, len(items)) + for _, item := range items { + result = append(result, pickNamedStringFromAnyMap(item, keys...)) + } + return result, nil + default: + var decoded []map[string]interface{} + if err := remarshal(data, &decoded); err != nil { + return nil, err + } + return decodeNamedStringSlice(decoded, keys...) + } +} + +func pickNamedStringFromStringMap(item map[string]string, keys ...string) string { + for _, key := range keys { + if value := strings.TrimSpace(item[key]); value != "" { + return value + } + } + for _, value := range item { + if trimmed := strings.TrimSpace(value); trimmed != "" { + return trimmed + } + } + return "" +} + +func pickNamedStringFromAnyMap(item map[string]interface{}, keys ...string) string { + for _, key := range keys { + if value, ok := item[key]; ok { + if text := strings.TrimSpace(fmt.Sprint(value)); text != "" { + return text + } + } + } + for _, value := range item { + if text := strings.TrimSpace(fmt.Sprint(value)); text != "" { + return text + } + } + return "" +} + +func decodeColumns(data interface{}) ([]connection.ColumnDefinition, error) { + switch cols := data.(type) { + case nil: + return []connection.ColumnDefinition{}, nil + case []connection.ColumnDefinition: + return ensureNonNilColumns(append([]connection.ColumnDefinition(nil), cols...)), nil + default: + var decoded []connection.ColumnDefinition + if err := remarshal(data, &decoded); err != nil { + return nil, err + } + return ensureNonNilColumns(decoded), nil + } +} + +func decodeString(data interface{}) (string, error) { + switch value := data.(type) { + case nil: + return "", nil + case string: + return value, nil + default: + return fmt.Sprint(value), nil + } +} + +func decodeResultSets(data interface{}) ([]connection.ResultSetData, error) { + switch items := data.(type) { + case nil: + return []connection.ResultSetData{}, nil + case []connection.ResultSetData: + return ensureNonNilResultSets(append([]connection.ResultSetData(nil), items...)), nil + default: + var decoded []connection.ResultSetData + if err := remarshal(data, &decoded); err != nil { + return nil, err + } + return ensureNonNilResultSets(decoded), nil + } +} + +func remarshal(from interface{}, to interface{}) error { + payload, err := json.Marshal(from) + if err != nil { + return err + } + return json.Unmarshal(payload, to) +} + +func normalizeMaxRowsPerResult(input int) int { + if input <= 0 { + return defaultMaxRowsPerResult + } + if input > maxRowsPerResultLimit { + return maxRowsPerResultLimit + } + return input +} + +func normalizeResultSets(resultSets []connection.ResultSetData, maxRows int) ([]sqlResultSet, bool) { + normalized := make([]sqlResultSet, 0, len(resultSets)) + truncatedAny := false + for _, resultSet := range resultSets { + rows := ensureNonNilRows(resultSet.Rows) + rowCount := len(rows) + truncated := false + if maxRows > 0 && len(rows) > maxRows { + rows = append([]map[string]interface{}(nil), rows[:maxRows]...) + truncated = true + truncatedAny = true + } + normalized = append(normalized, sqlResultSet{ + StatementIndex: resultSet.StatementIndex, + Columns: ensureNonNilStrings(append([]string(nil), resultSet.Columns...)), + Rows: rows, + Messages: ensureNonNilStrings(append([]string(nil), resultSet.Messages...)), + RowCount: rowCount, + Truncated: truncated, + }) + } + return normalized, truncatedAny +} + +func toStatementSummaries(items []appcore.SQLStatementInspection) []sqlStatementSummary { + result := make([]sqlStatementSummary, 0, len(items)) + for _, item := range items { + result = append(result, sqlStatementSummary{ + Index: item.Index, + Keyword: item.Keyword, + ReadOnly: item.ReadOnly, + }) + } + return result +} + +func ensureNonNilStrings(items []string) []string { + if items == nil { + return []string{} + } + return items +} + +func ensureNonNilColumns(items []connection.ColumnDefinition) []connection.ColumnDefinition { + if items == nil { + return []connection.ColumnDefinition{} + } + return items +} + +func ensureNonNilRows(items []map[string]interface{}) []map[string]interface{} { + if items == nil { + return []map[string]interface{}{} + } + return items +} + +func ensureNonNilResultSets(items []connection.ResultSetData) []connection.ResultSetData { + if items == nil { + return []connection.ResultSetData{} + } + return items +} + +type sqlSafetyStatement struct { + Index int + Keyword string + OperationType ai.SQLOperationType +} + +type sqlSafetyDecision struct { + requiresConfirm bool + disallowed []sqlSafetyStatement + confirmRequired []sqlSafetyStatement +} + +func evaluateSQLSafety(level ai.SQLPermissionLevel, inspection appcore.SQLInspection) sqlSafetyDecision { + decision := sqlSafetyDecision{ + disallowed: []sqlSafetyStatement{}, + confirmRequired: []sqlSafetyStatement{}, + } + + for _, stmt := range inspection.Statements { + statement := sqlSafetyStatement{ + Index: stmt.Index, + Keyword: strings.TrimSpace(stmt.Keyword), + OperationType: classifyStatementOperation(stmt), + } + if !isOperationAllowed(level, statement.OperationType) { + decision.disallowed = append(decision.disallowed, statement) + continue + } + if statement.OperationType != ai.SQLOpQuery { + decision.requiresConfirm = true + decision.confirmRequired = append(decision.confirmRequired, statement) + } + } + + return decision +} + +func classifyStatementOperation(stmt appcore.SQLStatementInspection) ai.SQLOperationType { + if stmt.ReadOnly { + return ai.SQLOpQuery + } + + switch strings.ToLower(strings.TrimSpace(stmt.Keyword)) { + case "insert", "update", "delete", "replace", "merge", "upsert": + return ai.SQLOpDML + case "create", "alter", "drop", "truncate", "rename": + return ai.SQLOpDDL + default: + return ai.SQLOpOther + } +} + +func isOperationAllowed(level ai.SQLPermissionLevel, opType ai.SQLOperationType) bool { + switch normalizeSQLSafetyLevel(level) { + case ai.PermissionReadOnly: + return opType == ai.SQLOpQuery + case ai.PermissionReadWrite: + return opType == ai.SQLOpQuery || opType == ai.SQLOpDML + case ai.PermissionFull: + return opType == ai.SQLOpQuery || opType == ai.SQLOpDML || opType == ai.SQLOpDDL + default: + return opType == ai.SQLOpQuery + } +} + +func normalizeSQLSafetyLevel(level ai.SQLPermissionLevel) ai.SQLPermissionLevel { + switch level { + case ai.PermissionReadOnly, ai.PermissionReadWrite, ai.PermissionFull: + return level + default: + return ai.PermissionReadOnly + } +} + +func buildSafetyDeniedMessage(level ai.SQLPermissionLevel, statements []sqlSafetyStatement) string { + return fmt.Sprintf("当前 GoNavi AI 安全控制为%s,已阻止以下语句:%s。%s", safetyLevelDisplayName(level), formatSafetyStatements(statements), safetyLevelRuleText(level)) +} + +func safetyLevelDisplayName(level ai.SQLPermissionLevel) string { + switch normalizeSQLSafetyLevel(level) { + case ai.PermissionReadOnly: + return "只读模式" + case ai.PermissionReadWrite: + return "读写模式" + case ai.PermissionFull: + return "完全模式" + default: + return "只读模式" + } +} + +func safetyLevelRuleText(level ai.SQLPermissionLevel) string { + switch normalizeSQLSafetyLevel(level) { + case ai.PermissionReadOnly: + return "只读模式仅允许查询语句。" + case ai.PermissionReadWrite: + return "读写模式仅允许查询和 DML 语句。" + case ai.PermissionFull: + return "完全模式仅允许查询、DML 和 DDL;未识别操作仍会被阻止。" + default: + return "只读模式仅允许查询语句。" + } +} + +func formatSafetyStatements(statements []sqlSafetyStatement) string { + parts := make([]string, 0, len(statements)) + for _, stmt := range statements { + keyword := strings.TrimSpace(stmt.Keyword) + if keyword == "" { + keyword = "unknown" + } + parts = append(parts, fmt.Sprintf("#%d %s(%s)", stmt.Index, strings.ToLower(keyword), strings.ToUpper(string(stmt.OperationType)))) + } + return strings.Join(parts, ",") +} diff --git a/internal/mcpserver/service_test.go b/internal/mcpserver/service_test.go new file mode 100644 index 0000000..d8d5fc7 --- /dev/null +++ b/internal/mcpserver/service_test.go @@ -0,0 +1,431 @@ +package mcpserver + +import ( + "context" + "strings" + "testing" + + "GoNavi-Wails/internal/ai" + appcore "GoNavi-Wails/internal/app" + "GoNavi-Wails/internal/connection" + "github.com/modelcontextprotocol/go-sdk/mcp" +) + +type fakeBackend struct { + savedConnections []connection.SavedConnectionView + savedConnectionsErr error + editableConnection connection.SavedConnectionView + editableErr error + databasesResult connection.QueryResult + tablesResult connection.QueryResult + columnsResult connection.QueryResult + ddlResult connection.QueryResult + queryResult connection.QueryResult + inspection appcore.SQLInspection + safetyLevel ai.SQLPermissionLevel + queryCalled bool +} + +func (f *fakeBackend) Close(context.Context) error { + return nil +} + +func (f *fakeBackend) GetSavedConnections() ([]connection.SavedConnectionView, error) { + return f.savedConnections, f.savedConnectionsErr +} + +func (f *fakeBackend) GetEditableSavedConnection(id string) (connection.SavedConnectionView, error) { + return f.editableConnection, f.editableErr +} + +func (f *fakeBackend) DBGetDatabases(config connection.ConnectionConfig) connection.QueryResult { + return f.databasesResult +} + +func (f *fakeBackend) DBGetTables(config connection.ConnectionConfig, dbName string) connection.QueryResult { + return f.tablesResult +} + +func (f *fakeBackend) DBGetColumns(config connection.ConnectionConfig, dbName string, tableName string) connection.QueryResult { + return f.columnsResult +} + +func (f *fakeBackend) DBShowCreateTable(config connection.ConnectionConfig, dbName string, tableName string) connection.QueryResult { + return f.ddlResult +} + +func (f *fakeBackend) DBQueryMulti(config connection.ConnectionConfig, dbName string, query string, queryID string) connection.QueryResult { + f.queryCalled = true + return f.queryResult +} + +func (f *fakeBackend) InspectSQL(dbType string, sql string) appcore.SQLInspection { + return f.inspection +} + +func (f *fakeBackend) GetSQLSafetyLevel() ai.SQLPermissionLevel { + if f.safetyLevel == "" { + return ai.PermissionReadOnly + } + return f.safetyLevel +} + +func TestGetConnectionsReturnsSavedConnectionSummaries(t *testing.T) { + backend := &fakeBackend{ + savedConnections: []connection.SavedConnectionView{ + { + ID: "mysql-main", + Name: "MySQL Main", + Config: connection.ConnectionConfig{ + Type: "mysql", + Host: "10.0.0.8", + Port: 3306, + Database: "app", + UseSSH: true, + }, + }, + { + ID: "duckdb-local", + Name: "DuckDB Local", + Config: connection.ConnectionConfig{ + Type: "duckdb", + Database: `C:\data\example.duckdb`, + }, + }, + }, + } + + service := NewService(backend) + result, out, err := service.GetConnections(context.Background(), nil, emptyArgs{}) + if err != nil { + t.Fatalf("GetConnections returned error: %v", err) + } + if result == nil || result.IsError { + t.Fatalf("expected success result, got %#v", result) + } + if len(out.Connections) != 2 { + t.Fatalf("expected 2 connections, got %d", len(out.Connections)) + } + if out.Connections[0].Target != "10.0.0.8:3306" { + t.Fatalf("unexpected mysql target: %q", out.Connections[0].Target) + } + if out.Connections[1].Target != `C:\data\example.duckdb` { + t.Fatalf("unexpected duckdb target: %q", out.Connections[1].Target) + } +} + +func TestExecuteSQLRejectsMutatingStatementsWithoutAllowMutating(t *testing.T) { + backend := &fakeBackend{ + editableConnection: connection.SavedConnectionView{ + ID: "mysql-main", + Config: connection.ConnectionConfig{ + Type: "mysql", + Database: "app", + }, + }, + inspection: appcore.SQLInspection{ + StatementCount: 1, + ReadOnly: false, + Statements: []appcore.SQLStatementInspection{ + {Index: 1, Keyword: "delete", ReadOnly: false}, + }, + }, + safetyLevel: ai.PermissionReadWrite, + } + + service := NewService(backend) + result, _, err := service.ExecuteSQL(context.Background(), nil, executeSQLArgs{ + ConnectionID: "mysql-main", + SQL: "delete from users where id = 1", + }) + if err != nil { + t.Fatalf("ExecuteSQL returned error: %v", err) + } + if result == nil || !result.IsError { + t.Fatalf("expected tool error, got %#v", result) + } + if !strings.Contains(firstTextContent(result), "allowMutating=true") { + t.Fatalf("unexpected error text: %q", firstTextContent(result)) + } + if backend.queryCalled { + t.Fatalf("expected SQL not to execute when allowMutating is false") + } +} + +func TestExecuteSQLRejectsMutatingStatementsWhenAISafetyIsReadOnly(t *testing.T) { + backend := &fakeBackend{ + editableConnection: connection.SavedConnectionView{ + ID: "mysql-main", + Config: connection.ConnectionConfig{ + Type: "mysql", + Database: "app", + }, + }, + inspection: appcore.SQLInspection{ + StatementCount: 1, + ReadOnly: false, + Statements: []appcore.SQLStatementInspection{ + {Index: 1, Keyword: "delete", ReadOnly: false}, + }, + }, + safetyLevel: ai.PermissionReadOnly, + } + + service := NewService(backend) + result, _, err := service.ExecuteSQL(context.Background(), nil, executeSQLArgs{ + ConnectionID: "mysql-main", + SQL: "delete from users where id = 1", + AllowMutating: true, + }) + if err != nil { + t.Fatalf("ExecuteSQL returned error: %v", err) + } + if result == nil || !result.IsError { + t.Fatalf("expected tool error, got %#v", result) + } + if !strings.Contains(firstTextContent(result), "只读模式") { + t.Fatalf("unexpected error text: %q", firstTextContent(result)) + } + if backend.queryCalled { + t.Fatalf("expected SQL not to execute when AI safety is readonly") + } +} + +func TestExecuteSQLRejectsDDLWhenAISafetyIsReadWrite(t *testing.T) { + backend := &fakeBackend{ + editableConnection: connection.SavedConnectionView{ + ID: "mysql-main", + Config: connection.ConnectionConfig{ + Type: "mysql", + Database: "app", + }, + }, + inspection: appcore.SQLInspection{ + StatementCount: 1, + ReadOnly: false, + Statements: []appcore.SQLStatementInspection{ + {Index: 1, Keyword: "drop", ReadOnly: false}, + }, + }, + safetyLevel: ai.PermissionReadWrite, + } + + service := NewService(backend) + result, _, err := service.ExecuteSQL(context.Background(), nil, executeSQLArgs{ + ConnectionID: "mysql-main", + SQL: "drop table users", + AllowMutating: true, + }) + if err != nil { + t.Fatalf("ExecuteSQL returned error: %v", err) + } + if result == nil || !result.IsError { + t.Fatalf("expected tool error, got %#v", result) + } + text := firstTextContent(result) + if !strings.Contains(text, "读写模式") || !strings.Contains(text, "DDL") { + t.Fatalf("unexpected error text: %q", text) + } + if backend.queryCalled { + t.Fatalf("expected SQL not to execute when AI safety blocks DDL") + } +} + +func TestExecuteSQLRejectsMixedStatementsWhenAISafetyBlocksLaterStatement(t *testing.T) { + backend := &fakeBackend{ + editableConnection: connection.SavedConnectionView{ + ID: "mysql-main", + Config: connection.ConnectionConfig{ + Type: "mysql", + Database: "app", + }, + }, + inspection: appcore.SQLInspection{ + StatementCount: 2, + ReadOnly: false, + Statements: []appcore.SQLStatementInspection{ + {Index: 1, Keyword: "select", ReadOnly: true}, + {Index: 2, Keyword: "delete", ReadOnly: false}, + }, + }, + safetyLevel: ai.PermissionReadOnly, + } + + service := NewService(backend) + result, _, err := service.ExecuteSQL(context.Background(), nil, executeSQLArgs{ + ConnectionID: "mysql-main", + SQL: "select * from users; delete from users where id = 1", + AllowMutating: true, + }) + if err != nil { + t.Fatalf("ExecuteSQL returned error: %v", err) + } + if result == nil || !result.IsError { + t.Fatalf("expected tool error, got %#v", result) + } + if !strings.Contains(firstTextContent(result), "#2 delete") { + t.Fatalf("unexpected error text: %q", firstTextContent(result)) + } + if backend.queryCalled { + t.Fatalf("expected SQL not to execute when a later statement is blocked") + } +} + +func TestExecuteSQLAllowsDMLWhenAISafetyIsReadWriteAndAllowMutating(t *testing.T) { + backend := &fakeBackend{ + editableConnection: connection.SavedConnectionView{ + ID: "mysql-main", + Config: connection.ConnectionConfig{ + Type: "mysql", + Database: "app", + }, + }, + inspection: appcore.SQLInspection{ + StatementCount: 1, + ReadOnly: false, + Statements: []appcore.SQLStatementInspection{ + {Index: 1, Keyword: "insert", ReadOnly: false}, + }, + }, + safetyLevel: ai.PermissionReadWrite, + queryResult: connection.QueryResult{ + Success: true, + Data: []connection.ResultSetData{}, + }, + } + + service := NewService(backend) + result, out, err := service.ExecuteSQL(context.Background(), nil, executeSQLArgs{ + ConnectionID: "mysql-main", + SQL: "insert into users(id) values (1)", + AllowMutating: true, + }) + if err != nil { + t.Fatalf("ExecuteSQL returned error: %v", err) + } + if result == nil || result.IsError { + t.Fatalf("expected success result, got %#v", result) + } + if !backend.queryCalled { + t.Fatalf("expected SQL to be executed") + } + if out.ReadOnly { + t.Fatalf("expected mutating SQL result, got %#v", out) + } +} + +func TestExecuteSQLAllowsDDLWhenAISafetyIsFullAndAllowMutating(t *testing.T) { + backend := &fakeBackend{ + editableConnection: connection.SavedConnectionView{ + ID: "mysql-main", + Config: connection.ConnectionConfig{ + Type: "mysql", + Database: "app", + }, + }, + inspection: appcore.SQLInspection{ + StatementCount: 1, + ReadOnly: false, + Statements: []appcore.SQLStatementInspection{ + {Index: 1, Keyword: "drop", ReadOnly: false}, + }, + }, + safetyLevel: ai.PermissionFull, + queryResult: connection.QueryResult{ + Success: true, + Data: []connection.ResultSetData{}, + }, + } + + service := NewService(backend) + result, _, err := service.ExecuteSQL(context.Background(), nil, executeSQLArgs{ + ConnectionID: "mysql-main", + SQL: "drop table users", + AllowMutating: true, + }) + if err != nil { + t.Fatalf("ExecuteSQL returned error: %v", err) + } + if result == nil || result.IsError { + t.Fatalf("expected success result, got %#v", result) + } + if !backend.queryCalled { + t.Fatalf("expected SQL to be executed") + } +} + +func TestExecuteSQLNormalizesAndTruncatesResultSets(t *testing.T) { + backend := &fakeBackend{ + editableConnection: connection.SavedConnectionView{ + ID: "mysql-main", + Config: connection.ConnectionConfig{ + Type: "mysql", + Database: "app", + }, + }, + inspection: appcore.SQLInspection{ + StatementCount: 1, + ReadOnly: true, + Statements: []appcore.SQLStatementInspection{ + {Index: 1, Keyword: "select", ReadOnly: true}, + }, + }, + queryResult: connection.QueryResult{ + Success: true, + QueryID: "query-1", + Data: []connection.ResultSetData{ + { + StatementIndex: 1, + Columns: []string{"id"}, + Rows: []map[string]interface{}{ + {"id": 1}, + {"id": 2}, + {"id": 3}, + }, + }, + }, + }, + } + + service := NewService(backend) + result, out, err := service.ExecuteSQL(context.Background(), nil, executeSQLArgs{ + ConnectionID: "mysql-main", + SQL: "select id from users", + MaxRowsPerResult: 2, + }) + if err != nil { + t.Fatalf("ExecuteSQL returned error: %v", err) + } + if result == nil || result.IsError { + t.Fatalf("expected success result, got %#v", result) + } + if !backend.queryCalled { + t.Fatalf("expected SQL to be executed") + } + if out.StatementCount != 1 || len(out.Results) != 1 { + t.Fatalf("unexpected output: %#v", out) + } + if out.QueryID != "query-1" { + t.Fatalf("unexpected query id: %q", out.QueryID) + } + if !out.Truncated || !out.Results[0].Truncated { + t.Fatalf("expected truncated result, got %#v", out.Results[0]) + } + if out.Results[0].RowCount != 3 { + t.Fatalf("expected rowCount 3, got %d", out.Results[0].RowCount) + } + if len(out.Results[0].Rows) != 2 { + t.Fatalf("expected 2 returned rows, got %d", len(out.Results[0].Rows)) + } +} + +func firstTextContent(result *mcp.CallToolResult) string { + if result == nil || len(result.Content) == 0 { + return "" + } + text, _ := result.Content[0].(*mcp.TextContent) + if text == nil { + return "" + } + return text.Text +} diff --git a/main.go b/main.go index 0168f65..1629c79 100644 --- a/main.go +++ b/main.go @@ -8,6 +8,7 @@ import ( aiservice "GoNavi-Wails/internal/ai/service" "GoNavi-Wails/internal/app" "GoNavi-Wails/internal/logger" + "GoNavi-Wails/internal/mcpserver" "github.com/wailsapp/wails/v2" "github.com/wailsapp/wails/v2/pkg/options" @@ -17,6 +18,10 @@ import ( ) func main() { + if runSpecialMode(os.Args[1:]) { + return + } + // Create an instance of the app structure application := app.NewApp() aiService := aiservice.NewService() @@ -68,6 +73,30 @@ func main() { } } +func runSpecialMode(args []string) bool { + if !shouldRunMCPServerMode(args) { + return false + } + + if err := mcpserver.RunAppStdioServer(context.Background()); err != nil { + logger.Error(err, "GoNavi MCP Server 退出") + } + return true +} + +func shouldRunMCPServerMode(args []string) bool { + if len(args) == 0 { + return false + } + + switch strings.ToLower(strings.TrimSpace(args[0])) { + case "mcp-server", "--mcp-server": + return true + default: + return false + } +} + func isLowMemoryMode() bool { switch strings.ToLower(strings.TrimSpace(os.Getenv("GONAVI_LOW_MEMORY_MODE"))) { case "1", "true", "yes", "on": diff --git a/main_test.go b/main_test.go index 592d723..e64224e 100644 --- a/main_test.go +++ b/main_test.go @@ -24,3 +24,24 @@ func TestIsLowMemoryMode(t *testing.T) { }) } } + +func TestShouldRunMCPServerMode(t *testing.T) { + cases := []struct { + name string + args []string + want bool + }{ + {name: "empty", args: nil, want: false}, + {name: "mcp-server", args: []string{"mcp-server"}, want: true}, + {name: "flag style", args: []string{"--mcp-server"}, want: true}, + {name: "unknown", args: []string{"serve"}, want: false}, + } + + for _, tc := range cases { + t.Run(tc.name, func(t *testing.T) { + if got := shouldRunMCPServerMode(tc.args); got != tc.want { + t.Fatalf("shouldRunMCPServerMode(%v) = %v, want %v", tc.args, got, tc.want) + } + }) + } +} diff --git a/tools/claude-gonavi-mcp.cmd b/tools/claude-gonavi-mcp.cmd new file mode 100644 index 0000000..0f67343 --- /dev/null +++ b/tools/claude-gonavi-mcp.cmd @@ -0,0 +1,2 @@ +@echo off +powershell -NoProfile -ExecutionPolicy Bypass -File "%~dp0claude-gonavi-mcp.ps1" %* diff --git a/tools/claude-gonavi-mcp.ps1 b/tools/claude-gonavi-mcp.ps1 new file mode 100644 index 0000000..9f54843 --- /dev/null +++ b/tools/claude-gonavi-mcp.ps1 @@ -0,0 +1,44 @@ +param( + [switch]$SkipBuild +) + +$ErrorActionPreference = 'Stop' +$ClaudeArgs = $args + +$repoRoot = (Resolve-Path (Join-Path $PSScriptRoot '..')).Path +$binDir = Join-Path $repoRoot 'bin' +$serverExe = Join-Path $binDir 'gonavi-mcp-server.exe' + +if (-not $SkipBuild) { + if (-not (Test-Path $binDir)) { + New-Item -ItemType Directory -Path $binDir | Out-Null + } + + & go build -o $serverExe .\cmd\gonavi-mcp-server + if ($LASTEXITCODE -ne 0) { + throw "构建 gonavi-mcp-server 失败" + } +} elseif (-not (Test-Path $serverExe)) { + throw "未找到已编译的 gonavi-mcp-server.exe,请去掉 -SkipBuild 或先手动构建" +} + +$mcpConfig = @{ + mcpServers = @{ + gonavi = @{ + type = 'stdio' + command = $serverExe + args = @() + env = @{} + } + } +} | ConvertTo-Json -Compress -Depth 6 + +$tempConfig = Join-Path ([System.IO.Path]::GetTempPath()) ("gonavi-claude-mcp-" + [System.Guid]::NewGuid().ToString("N") + ".json") + +try { + Set-Content -LiteralPath $tempConfig -Value $mcpConfig -Encoding UTF8 + & claude @ClaudeArgs --mcp-config $tempConfig --strict-mcp-config + exit $LASTEXITCODE +} finally { + Remove-Item -LiteralPath $tempConfig -ErrorAction SilentlyContinue +}