mirror of
https://github.com/Syngnat/GoNavi.git
synced 2026-06-14 10:29:52 +08:00
✨ feat(ai-mcp): 完善外部客户端安装链路并收紧 SQL 安全控制
- 新增 GoNavi MCP stdio server 与 Claude/Codex 用户级安装入口 - 增加安装状态检测、刷新复制能力和浏览器联调 mock - 外部 execute_sql 对齐 GoNavi safetyLevel 并补充前端/后端验证
This commit is contained in:
130
cmd/gonavi-mcp-server/README.md
Normal file
130
cmd/gonavi-mcp-server/README.md
Normal file
@@ -0,0 +1,130 @@
|
||||
# GoNavi MCP Server
|
||||
|
||||
`gonavi-mcp-server` 会把 GoNavi 已保存连接背后的数据库能力通过 MCP `stdio` 暴露给外部客户端。
|
||||
|
||||
## 当前提供的 tools
|
||||
|
||||
- `get_connections`
|
||||
- 返回 GoNavi 已保存连接的 `id/name/type/target/defaultDatabase` 等摘要信息
|
||||
- `get_databases`
|
||||
- 入参:`connectionId`
|
||||
- `get_tables`
|
||||
- 入参:`connectionId`、可选 `dbName`
|
||||
- `get_columns`
|
||||
- 入参:`connectionId`、可选 `dbName`、`tableName`
|
||||
- `get_table_ddl`
|
||||
- 入参:`connectionId`、可选 `dbName`、`tableName`
|
||||
- `execute_sql`
|
||||
- 入参:`connectionId`、可选 `dbName`、`sql`
|
||||
- 默认只允许只读 SQL
|
||||
- 如果 SQL 包含 DDL/DML,必须显式传 `allowMutating=true`
|
||||
- `maxRowsPerResult` 用来限制单个结果集返回的行数,默认 `200`
|
||||
|
||||
## 运行方式
|
||||
|
||||
开发态直接运行:
|
||||
|
||||
```powershell
|
||||
go run ./cmd/gonavi-mcp-server
|
||||
```
|
||||
|
||||
也可以先编译:
|
||||
|
||||
```powershell
|
||||
go build -o .\bin\gonavi-mcp-server.exe .\cmd\gonavi-mcp-server
|
||||
```
|
||||
|
||||
## Claude Code / Codex
|
||||
|
||||
正式安装包场景,推荐直接在 GoNavi 里使用“AI 设置 -> MCP 服务 -> 安装到 Claude Code / 安装到 Codex”。
|
||||
|
||||
它会自动把当前安装的 `GoNavi.exe` 写入 Claude Code 的用户级 `~/.claude.json`,命令形态类似:
|
||||
|
||||
```json
|
||||
{
|
||||
"mcpServers": {
|
||||
"gonavi": {
|
||||
"type": "stdio",
|
||||
"command": "C:\\Program Files\\GoNavi\\GoNavi.exe",
|
||||
"args": ["mcp-server"],
|
||||
"env": {}
|
||||
}
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
这样用户不需要自己找本机 `gonavi-mcp-server.exe` 路径,安装包本体就能直接作为 MCP 入口。
|
||||
|
||||
Codex 当前使用 `~/.codex/config.toml`,GoNavi 会写入类似下面这段:
|
||||
|
||||
```toml
|
||||
[mcp_servers.gonavi]
|
||||
command = 'C:\Program Files\GoNavi\GoNavi.exe'
|
||||
args = ['mcp-server']
|
||||
startup_timeout_sec = 60
|
||||
```
|
||||
|
||||
仓库开发态如果要在本机 `Claude Code CLI` 里稳定使用这个 MCP,仍然推荐走仓库内包装脚本:
|
||||
|
||||
```powershell
|
||||
.\tools\claude-gonavi-mcp.ps1 -p "必须调用 gonavi MCP 的 get_connections 工具"
|
||||
```
|
||||
|
||||
或者:
|
||||
|
||||
```cmd
|
||||
tools\claude-gonavi-mcp.cmd -p "必须调用 gonavi MCP 的 get_connections 工具"
|
||||
```
|
||||
|
||||
这个脚本会先构建 `bin\gonavi-mcp-server.exe`,再通过 `--mcp-config` 和 `--strict-mcp-config` 把 GoNavi MCP 单独注入当前 Claude 会话,避免默认混合 MCP 加载时序导致的首轮工具未挂载问题。
|
||||
|
||||
## MCP 客户端配置示例
|
||||
|
||||
开发态:
|
||||
|
||||
```json
|
||||
{
|
||||
"mcpServers": {
|
||||
"gonavi": {
|
||||
"command": "go",
|
||||
"args": ["run", "./cmd/gonavi-mcp-server"]
|
||||
}
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
Windows 独立 server 编译产物(开发态):
|
||||
|
||||
```json
|
||||
{
|
||||
"mcpServers": {
|
||||
"gonavi": {
|
||||
"command": "D:\\Work\\CodeRepos\\GoNavi\\bin\\gonavi-mcp-server.exe",
|
||||
"args": []
|
||||
}
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
Windows 已安装 GoNavi(推荐给最终用户):
|
||||
|
||||
```json
|
||||
{
|
||||
"mcpServers": {
|
||||
"gonavi": {
|
||||
"type": "stdio",
|
||||
"command": "C:\\Program Files\\GoNavi\\GoNavi.exe",
|
||||
"args": ["mcp-server"],
|
||||
"env": {}
|
||||
}
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
## 使用说明
|
||||
|
||||
- 先调用 `get_connections`,拿到 `connectionId`
|
||||
- 之后所有数据库工具都只传 `connectionId`,由 GoNavi 服务端内部解析保存连接和密钥
|
||||
- 如果 `dbName` 为空,会优先使用该保存连接里的默认数据库
|
||||
- Server 会读取 GoNavi 当前活动数据目录里的连接配置,并通过系统 keyring/凭据管理器解析密文
|
||||
- 如果本机凭据存储不可用,依赖密钥的连接会返回对应错误
|
||||
15
cmd/gonavi-mcp-server/main.go
Normal file
15
cmd/gonavi-mcp-server/main.go
Normal file
@@ -0,0 +1,15 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"context"
|
||||
"log"
|
||||
|
||||
"GoNavi-Wails/internal/mcpserver"
|
||||
)
|
||||
|
||||
func main() {
|
||||
ctx := context.Background()
|
||||
if err := mcpserver.RunAppStdioServer(ctx); err != nil {
|
||||
log.Printf("GoNavi MCP Server 退出: %v", err)
|
||||
}
|
||||
}
|
||||
@@ -191,6 +191,18 @@ describe('tool center menu entries', () => {
|
||||
expect(appSource).toContain('该异常不一定表现为 viewport ratio drift');
|
||||
});
|
||||
|
||||
it('captures window state on startup and lifecycle events instead of waiting only for the polling interval', () => {
|
||||
expect(appSource).toContain('const scheduleWindowStateSave = (delayMs = 120) => {');
|
||||
expect(appSource).toContain('if (hydrated) {');
|
||||
expect(appSource).toContain('scheduleWindowStateSave(320);');
|
||||
expect(appSource).toContain('const unsubscribeHydration = useStore.persist.onFinishHydration(() => {');
|
||||
expect(appSource).toContain("window.addEventListener('resize', handleWindowRuntimeChange);");
|
||||
expect(appSource).toContain("window.addEventListener('focus', handleWindowRuntimeChange);");
|
||||
expect(appSource).toContain("window.addEventListener('pageshow', handleWindowRuntimeChange);");
|
||||
expect(appSource).toContain("window.addEventListener('pagehide', handleWindowLifecycleFlush, { capture: true });");
|
||||
expect(appSource).toContain("window.addEventListener('beforeunload', handleWindowLifecycleFlush, { capture: true });");
|
||||
});
|
||||
|
||||
it('keeps titlebar double-click on maximise while shortcuts may enter macOS fullscreen', () => {
|
||||
expect(appSource).toContain('const handleTitleBarWindowToggle = async (options?: { allowMacNativeFullscreen?: boolean }) => {');
|
||||
expect(appSource).toContain('const allowMacNativeFullscreen = options?.allowMacNativeFullscreen === true;');
|
||||
@@ -204,6 +216,12 @@ describe('tool center menu entries', () => {
|
||||
expect(appSource).toContain("window.removeEventListener('keydown', handleGlobalShortcut, true);");
|
||||
});
|
||||
|
||||
it('skips the native mac titlebar bridge when the current runtime does not expose it', () => {
|
||||
expect(appSource).toContain("const backendApp = (window as any).go?.app?.App;");
|
||||
expect(appSource).toContain("if (typeof backendApp?.SetMacNativeWindowControls !== 'function') {");
|
||||
expect(appSource).toContain('void safeWindowRuntimeCall(() => SetMacNativeWindowControls(useNativeMacWindowControls), undefined);');
|
||||
});
|
||||
|
||||
it('listens for command search query-tab events and routes them through handleNewQuery', () => {
|
||||
expect(appSource).toContain("window.addEventListener('gonavi:create-query-tab', handleCreateQueryTabEvent as EventListener);");
|
||||
expect(appSource).toContain("window.removeEventListener('gonavi:create-query-tab', handleCreateQueryTabEvent as EventListener);");
|
||||
|
||||
@@ -790,9 +790,15 @@ function App() {
|
||||
// 定时保存窗口状态、尺寸与位置
|
||||
useEffect(() => {
|
||||
const SAVE_INTERVAL_MS = 2000;
|
||||
let cancelled = false;
|
||||
let hydrated = useStore.persist.hasHydrated();
|
||||
let eventSaveTimer: number | null = null;
|
||||
let lastSaved = '';
|
||||
|
||||
const saveWindowState = async () => {
|
||||
if (cancelled || !hydrated) {
|
||||
return;
|
||||
}
|
||||
try {
|
||||
const [isFs, isMax] = await Promise.all([
|
||||
safeWindowRuntimeCall(() => WindowIsFullscreen(), false),
|
||||
@@ -836,8 +842,67 @@ function App() {
|
||||
}
|
||||
};
|
||||
|
||||
const timer = window.setInterval(saveWindowState, SAVE_INTERVAL_MS);
|
||||
return () => window.clearInterval(timer);
|
||||
const scheduleWindowStateSave = (delayMs = 120) => {
|
||||
if (cancelled || !hydrated) {
|
||||
return;
|
||||
}
|
||||
if (eventSaveTimer !== null) {
|
||||
window.clearTimeout(eventSaveTimer);
|
||||
}
|
||||
eventSaveTimer = window.setTimeout(() => {
|
||||
eventSaveTimer = null;
|
||||
void saveWindowState();
|
||||
}, delayMs);
|
||||
};
|
||||
|
||||
const handleWindowRuntimeChange = () => {
|
||||
scheduleWindowStateSave();
|
||||
};
|
||||
|
||||
const handleVisibilityChange = () => {
|
||||
if (document.visibilityState === 'visible') {
|
||||
scheduleWindowStateSave(120);
|
||||
}
|
||||
};
|
||||
|
||||
const handleWindowLifecycleFlush = () => {
|
||||
void saveWindowState();
|
||||
};
|
||||
|
||||
if (hydrated) {
|
||||
scheduleWindowStateSave(320);
|
||||
}
|
||||
const unsubscribeHydration = useStore.persist.onFinishHydration(() => {
|
||||
if (cancelled || hydrated) {
|
||||
return;
|
||||
}
|
||||
hydrated = true;
|
||||
scheduleWindowStateSave(320);
|
||||
});
|
||||
|
||||
const timer = window.setInterval(() => {
|
||||
void saveWindowState();
|
||||
}, SAVE_INTERVAL_MS);
|
||||
window.addEventListener('resize', handleWindowRuntimeChange);
|
||||
window.addEventListener('focus', handleWindowRuntimeChange);
|
||||
window.addEventListener('pageshow', handleWindowRuntimeChange);
|
||||
window.addEventListener('pagehide', handleWindowLifecycleFlush, { capture: true });
|
||||
window.addEventListener('beforeunload', handleWindowLifecycleFlush, { capture: true });
|
||||
document.addEventListener('visibilitychange', handleVisibilityChange);
|
||||
return () => {
|
||||
cancelled = true;
|
||||
if (eventSaveTimer !== null) {
|
||||
window.clearTimeout(eventSaveTimer);
|
||||
}
|
||||
window.clearInterval(timer);
|
||||
window.removeEventListener('resize', handleWindowRuntimeChange);
|
||||
window.removeEventListener('focus', handleWindowRuntimeChange);
|
||||
window.removeEventListener('pageshow', handleWindowRuntimeChange);
|
||||
window.removeEventListener('pagehide', handleWindowLifecycleFlush, { capture: true });
|
||||
window.removeEventListener('beforeunload', handleWindowLifecycleFlush, { capture: true });
|
||||
document.removeEventListener('visibilitychange', handleVisibilityChange);
|
||||
unsubscribeHydration();
|
||||
};
|
||||
}, []);
|
||||
|
||||
useEffect(() => {
|
||||
@@ -1567,12 +1632,11 @@ function App() {
|
||||
if (!isStoreHydrated || !isMacRuntime) {
|
||||
return;
|
||||
}
|
||||
|
||||
try {
|
||||
void SetMacNativeWindowControls(useNativeMacWindowControls).catch(() => undefined);
|
||||
} catch (e) {
|
||||
console.warn('Wails API: SetMacNativeWindowControls unavailable', e);
|
||||
const backendApp = (window as any).go?.app?.App;
|
||||
if (typeof backendApp?.SetMacNativeWindowControls !== 'function') {
|
||||
return;
|
||||
}
|
||||
void safeWindowRuntimeCall(() => SetMacNativeWindowControls(useNativeMacWindowControls), undefined);
|
||||
}, [isMacRuntime, isStoreHydrated, useNativeMacWindowControls]);
|
||||
|
||||
useEffect(() => {
|
||||
|
||||
@@ -18,6 +18,7 @@ describe('AISettingsModal edit password behavior', () => {
|
||||
});
|
||||
|
||||
it('loads MCP servers and skills through the AI service', () => {
|
||||
expect(source).toContain('Service.AIGetMCPClientInstallStatuses?.()');
|
||||
expect(source).toContain('Service.AIGetMCPServers?.()');
|
||||
expect(source).toContain('Service.AIListMCPTools?.()');
|
||||
expect(source).toContain('Service.AIGetSkills?.()');
|
||||
@@ -25,6 +26,26 @@ describe('AISettingsModal edit password behavior', () => {
|
||||
expect(source).toContain('新增 Skill');
|
||||
});
|
||||
|
||||
it('explains external MCP installation and renders selectable client install states', () => {
|
||||
expect(source).toContain('把 GoNavi 注册成外部 AI 客户端可调用的 MCP Server');
|
||||
expect(source).toContain('安装到外部客户端');
|
||||
expect(source).toContain('未安装');
|
||||
expect(source).toContain('需更新');
|
||||
expect(source).toContain('已安装');
|
||||
expect(source).toContain('刷新状态');
|
||||
expect(source).toContain('复制配置路径');
|
||||
expect(source).toContain('复制启动命令');
|
||||
expect(source).toContain('handleInstallSelectedMCPClient');
|
||||
expect(source).toContain('无需重复安装');
|
||||
});
|
||||
|
||||
it('waits briefly for the AI service bridge before warning and removes noisy provider debug logs', () => {
|
||||
expect(source).toContain('const resolveAIService = useCallback(async () => {');
|
||||
expect(source).toContain('const service = await waitForAIService();');
|
||||
expect(source).not.toContain("console.log('[AI] AIGetProviders result:'");
|
||||
expect(source).not.toContain("console.log('[AI] AIGetActiveProvider result:'");
|
||||
});
|
||||
|
||||
it('keeps the prefilled api key masked by default', () => {
|
||||
expect(source).toContain('const [primaryPasswordVisible, setPrimaryPasswordVisible] = useState(false);');
|
||||
expect(source).toContain('visible: primaryPasswordVisible,');
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
import React, { useState, useEffect, useCallback, useMemo, useRef } from 'react';
|
||||
import { Modal, Button, Input, Select, Form, message as antdMessage, Tooltip, Tabs, Space, Popconfirm, Slider } from 'antd';
|
||||
import { PlusOutlined, DeleteOutlined, EditOutlined, CheckOutlined, ApiOutlined, SafetyCertificateOutlined, RobotOutlined, ThunderboltOutlined, CloudOutlined, ExperimentOutlined, KeyOutlined, LinkOutlined, AppstoreOutlined, ToolOutlined } from '@ant-design/icons';
|
||||
import type { AIProviderConfig, AIProviderType, AISafetyLevel, AIContextLevel, AIUserPromptSettings, AIMCPServerConfig, AIMCPToolDescriptor, AISkillConfig, AISkillScope } from '../types';
|
||||
import { PlusOutlined, DeleteOutlined, EditOutlined, CheckOutlined, ApiOutlined, SafetyCertificateOutlined, RobotOutlined, ThunderboltOutlined, CloudOutlined, ExperimentOutlined, KeyOutlined, LinkOutlined, AppstoreOutlined, ToolOutlined, ReloadOutlined, CopyOutlined } from '@ant-design/icons';
|
||||
import type { AIProviderConfig, AIProviderType, AISafetyLevel, AIContextLevel, AIUserPromptSettings, AIMCPServerConfig, AIMCPToolDescriptor, AIMCPClientInstallStatus, AISkillConfig, AISkillScope } from '../types';
|
||||
import {
|
||||
QWEN_BAILIAN_ANTHROPIC_BASE_URL,
|
||||
QWEN_CODING_PLAN_ANTHROPIC_BASE_URL,
|
||||
@@ -30,6 +30,17 @@ interface AISettingsModalProps {
|
||||
focusProviderId?: string;
|
||||
}
|
||||
|
||||
interface MCPClientInstallResult {
|
||||
success?: boolean;
|
||||
client?: string;
|
||||
message?: string;
|
||||
configPath?: string;
|
||||
command?: string;
|
||||
args?: string[];
|
||||
}
|
||||
|
||||
type MCPClientKey = 'claude-code' | 'codex';
|
||||
|
||||
// 预设配置:每个预设映射到后端 type(openai/anthropic/gemini/custom)并附带默认 URL 和 Model
|
||||
interface ProviderPreset {
|
||||
key: string;
|
||||
@@ -97,6 +108,100 @@ const EMPTY_MCP_SERVER = (): AIMCPServerConfig => ({
|
||||
timeoutSeconds: 20,
|
||||
});
|
||||
|
||||
const EMPTY_MCP_CLIENT_STATUSES: AIMCPClientInstallStatus[] = [
|
||||
{
|
||||
client: 'claude-code',
|
||||
displayName: 'Claude Code',
|
||||
installed: false,
|
||||
matchesCurrent: false,
|
||||
message: '未安装到 Claude Code 用户级配置',
|
||||
},
|
||||
{
|
||||
client: 'codex',
|
||||
displayName: 'Codex',
|
||||
installed: false,
|
||||
matchesCurrent: false,
|
||||
message: '未安装到 Codex 用户级配置',
|
||||
},
|
||||
];
|
||||
|
||||
const normalizeMCPClientStatuses = (items?: AIMCPClientInstallStatus[]): AIMCPClientInstallStatus[] => {
|
||||
const baseMap = new Map<string, AIMCPClientInstallStatus>(
|
||||
EMPTY_MCP_CLIENT_STATUSES.map((item) => [item.client, { ...item }]),
|
||||
);
|
||||
(Array.isArray(items) ? items : []).forEach((item) => {
|
||||
if (!item || !item.client) {
|
||||
return;
|
||||
}
|
||||
const base = baseMap.get(item.client) || {
|
||||
client: item.client,
|
||||
displayName: item.client,
|
||||
installed: false,
|
||||
matchesCurrent: false,
|
||||
message: '',
|
||||
};
|
||||
baseMap.set(item.client, {
|
||||
...base,
|
||||
...item,
|
||||
displayName: item.displayName || base.displayName,
|
||||
message: item.message || base.message,
|
||||
args: Array.isArray(item.args) ? item.args : (base.args || []),
|
||||
});
|
||||
});
|
||||
return (['claude-code', 'codex'] as MCPClientKey[])
|
||||
.map((client) => baseMap.get(client))
|
||||
.filter((item): item is AIMCPClientInstallStatus => Boolean(item));
|
||||
};
|
||||
|
||||
const pickPreferredMCPClient = (items: AIMCPClientInstallStatus[], current?: MCPClientKey): MCPClientKey => {
|
||||
if (current && items.some((item) => item.client === current)) {
|
||||
return current;
|
||||
}
|
||||
const pending = items.find((item) => !item.matchesCurrent);
|
||||
if (pending?.client === 'claude-code' || pending?.client === 'codex') {
|
||||
return pending.client;
|
||||
}
|
||||
return 'claude-code';
|
||||
};
|
||||
|
||||
const waitFor = (delayMs: number) => new Promise<void>((resolve) => {
|
||||
window.setTimeout(resolve, delayMs);
|
||||
});
|
||||
|
||||
const readAIService = () => (window as any).go?.aiservice?.Service;
|
||||
|
||||
const waitForAIService = async (attempts = 6, delayMs = 80) => {
|
||||
for (let attempt = 0; attempt < attempts; attempt += 1) {
|
||||
const service = readAIService();
|
||||
if (service) {
|
||||
return service;
|
||||
}
|
||||
if (attempt < attempts - 1) {
|
||||
await waitFor(delayMs);
|
||||
}
|
||||
}
|
||||
return readAIService();
|
||||
};
|
||||
|
||||
const quoteMCPCommandPart = (value: string): string => {
|
||||
const text = String(value || '').trim();
|
||||
if (!text) {
|
||||
return '';
|
||||
}
|
||||
return /[\s"]/u.test(text) ? `"${text.replace(/"/g, '\\"')}"` : text;
|
||||
};
|
||||
|
||||
const formatMCPLaunchCommand = (input?: Pick<AIMCPClientInstallStatus, 'command' | 'args'> | Pick<MCPClientInstallResult, 'command' | 'args'> | null): string => {
|
||||
const command = String(input?.command || '').trim();
|
||||
if (!command) {
|
||||
return '';
|
||||
}
|
||||
const args = Array.isArray(input?.args)
|
||||
? input.args.map((item) => String(item || '').trim()).filter(Boolean)
|
||||
: [];
|
||||
return [command, ...args].map(quoteMCPCommandPart).filter(Boolean).join(' ');
|
||||
};
|
||||
|
||||
const EMPTY_SKILL = (): AISkillConfig => ({
|
||||
id: `skill-draft-${Date.now()}-${Math.random().toString(36).slice(2, 8)}`,
|
||||
name: '',
|
||||
@@ -142,6 +247,9 @@ const AISettingsModal: React.FC<AISettingsModalProps> = ({ open, onClose, darkMo
|
||||
const [contextLevel, setContextLevel] = useState<AIContextLevel>('schema_only');
|
||||
const [mcpServers, setMCPServers] = useState<AIMCPServerConfig[]>([]);
|
||||
const [mcpTools, setMCPTools] = useState<AIMCPToolDescriptor[]>([]);
|
||||
const [mcpClientStatuses, setMCPClientStatuses] = useState<AIMCPClientInstallStatus[]>(EMPTY_MCP_CLIENT_STATUSES);
|
||||
const [selectedMCPClient, setSelectedMCPClient] = useState<MCPClientKey>('claude-code');
|
||||
const [mcpClientStatusLoading, setMCPClientStatusLoading] = useState(false);
|
||||
const [skills, setSkills] = useState<AISkillConfig[]>([]);
|
||||
const [editingProvider, setEditingProvider] = useState<AIProviderConfig | null>(null);
|
||||
const [isEditing, setIsEditing] = useState(false);
|
||||
@@ -153,6 +261,7 @@ const AISettingsModal: React.FC<AISettingsModalProps> = ({ open, onClose, darkMo
|
||||
const [primaryPasswordVisible, setPrimaryPasswordVisible] = useState(false);
|
||||
const [form] = Form.useForm();
|
||||
const modalBodyRef = useRef<HTMLDivElement>(null);
|
||||
const missingAIServiceWarnedRef = useRef(false);
|
||||
|
||||
// Modal 内部 toast 通知
|
||||
const [messageApi, messageContextHolder] = antdMessage.useMessage({ getContainer: () => modalBodyRef.current || document.body });
|
||||
@@ -163,6 +272,35 @@ const AISettingsModal: React.FC<AISettingsModalProps> = ({ open, onClose, darkMo
|
||||
const cardHoverBg = darkMode ? 'rgba(255,255,255,0.06)' : 'rgba(0,0,0,0.03)';
|
||||
const sectionLabelColor = darkMode ? 'rgba(255,255,255,0.5)' : 'rgba(0,0,0,0.4)';
|
||||
const inputBg = darkMode ? 'rgba(255,255,255,0.04)' : 'rgba(0,0,0,0.02)';
|
||||
const getMCPClientStatusTone = useCallback((status?: AIMCPClientInstallStatus) => {
|
||||
const messageText = String(status?.message || '');
|
||||
if (status?.matchesCurrent) {
|
||||
return {
|
||||
label: '已安装',
|
||||
color: '#16a34a',
|
||||
bg: darkMode ? 'rgba(34,197,94,0.18)' : 'rgba(34,197,94,0.12)',
|
||||
};
|
||||
}
|
||||
if (status?.installed) {
|
||||
return {
|
||||
label: '需更新',
|
||||
color: '#d97706',
|
||||
bg: darkMode ? 'rgba(245,158,11,0.18)' : 'rgba(245,158,11,0.12)',
|
||||
};
|
||||
}
|
||||
if (messageText.includes('失败') || messageText.includes('异常')) {
|
||||
return {
|
||||
label: '需检查',
|
||||
color: '#dc2626',
|
||||
bg: darkMode ? 'rgba(239,68,68,0.18)' : 'rgba(239,68,68,0.1)',
|
||||
};
|
||||
}
|
||||
return {
|
||||
label: '未安装',
|
||||
color: darkMode ? 'rgba(255,255,255,0.72)' : '#64748b',
|
||||
bg: darkMode ? 'rgba(255,255,255,0.08)' : 'rgba(100,116,139,0.08)',
|
||||
};
|
||||
}, [darkMode]);
|
||||
|
||||
// Hook 必须在组件顶层调用,不能在条件分支内
|
||||
const watchedType = Form.useWatch('type', form);
|
||||
@@ -178,11 +316,71 @@ const AISettingsModal: React.FC<AISettingsModalProps> = ({ open, onClose, darkMo
|
||||
value: tool.alias,
|
||||
})),
|
||||
]), [mcpTools]);
|
||||
const selectedMCPClientStatus = useMemo(
|
||||
() => mcpClientStatuses.find((item) => item.client === selectedMCPClient) || mcpClientStatuses[0],
|
||||
[mcpClientStatuses, selectedMCPClient],
|
||||
);
|
||||
const selectedMCPClientCommandText = useMemo(
|
||||
() => formatMCPLaunchCommand(selectedMCPClientStatus),
|
||||
[selectedMCPClientStatus],
|
||||
);
|
||||
|
||||
const resolveAIService = useCallback(async () => {
|
||||
const service = await waitForAIService();
|
||||
if (service) {
|
||||
missingAIServiceWarnedRef.current = false;
|
||||
return service;
|
||||
}
|
||||
if (!missingAIServiceWarnedRef.current) {
|
||||
console.warn('[AI] Service not found on window.go');
|
||||
missingAIServiceWarnedRef.current = true;
|
||||
}
|
||||
return null;
|
||||
}, []);
|
||||
|
||||
const loadMCPClientStatuses = useCallback(async (options?: { silent?: boolean }) => {
|
||||
const silent = options?.silent === true;
|
||||
if (!silent) {
|
||||
setMCPClientStatusLoading(true);
|
||||
}
|
||||
try {
|
||||
const Service = await resolveAIService();
|
||||
if (typeof Service?.AIGetMCPClientInstallStatuses !== 'function') {
|
||||
return;
|
||||
}
|
||||
const result = await Service.AIGetMCPClientInstallStatuses();
|
||||
if (Array.isArray(result)) {
|
||||
const normalizedStatuses = normalizeMCPClientStatuses(result);
|
||||
setMCPClientStatuses(normalizedStatuses);
|
||||
setSelectedMCPClient((prev) => pickPreferredMCPClient(normalizedStatuses, prev));
|
||||
}
|
||||
} catch (e: any) {
|
||||
if (silent) {
|
||||
console.warn('[AI] refresh mcp client statuses failed', e);
|
||||
} else {
|
||||
void messageApi.error(e?.message || '刷新客户端安装状态失败');
|
||||
}
|
||||
} finally {
|
||||
if (!silent) {
|
||||
setMCPClientStatusLoading(false);
|
||||
}
|
||||
}
|
||||
}, [messageApi, resolveAIService]);
|
||||
|
||||
const copyTextToClipboard = useCallback(async (text: string, successMessage: string) => {
|
||||
if (typeof navigator?.clipboard?.writeText !== 'function') {
|
||||
throw new Error('当前环境不支持复制到剪贴板');
|
||||
}
|
||||
await navigator.clipboard.writeText(text);
|
||||
void messageApi.success(successMessage);
|
||||
}, [messageApi]);
|
||||
|
||||
const loadConfig = useCallback(async () => {
|
||||
try {
|
||||
const Service = (window as any).go?.aiservice?.Service;
|
||||
if (!Service) { console.warn('[AI] Service not found on window.go'); return; }
|
||||
const Service = await resolveAIService();
|
||||
if (!Service) {
|
||||
return;
|
||||
}
|
||||
const callOrFallback = async <T,>(loader: (() => Promise<T>) | undefined, fallback: T): Promise<T> => {
|
||||
if (typeof loader !== 'function') {
|
||||
return fallback;
|
||||
@@ -194,7 +392,7 @@ const AISettingsModal: React.FC<AISettingsModalProps> = ({ open, onClose, darkMo
|
||||
return fallback;
|
||||
}
|
||||
};
|
||||
const [provRes, safeRes, ctxRes, promptsRes, userPromptsRes, mcpServersRes, mcpToolsRes, skillsRes] = await Promise.all([
|
||||
const [provRes, safeRes, ctxRes, promptsRes, userPromptsRes, mcpServersRes, mcpToolsRes, skillsRes, mcpClientStatusesRes] = await Promise.all([
|
||||
callOrFallback(() => Service.AIGetProviders?.(), []),
|
||||
callOrFallback<AISafetyLevel>(() => Service.AIGetSafetyLevel?.(), 'readonly'),
|
||||
callOrFallback<AIContextLevel>(() => Service.AIGetContextLevel?.(), 'schema_only'),
|
||||
@@ -203,12 +401,11 @@ const AISettingsModal: React.FC<AISettingsModalProps> = ({ open, onClose, darkMo
|
||||
callOrFallback(() => Service.AIGetMCPServers?.(), []),
|
||||
callOrFallback(() => Service.AIListMCPTools?.(), []),
|
||||
callOrFallback(() => Service.AIGetSkills?.(), []),
|
||||
callOrFallback<AIMCPClientInstallStatus[]>(() => Service.AIGetMCPClientInstallStatuses?.(), EMPTY_MCP_CLIENT_STATUSES),
|
||||
]);
|
||||
console.log('[AI] AIGetProviders result:', JSON.stringify(provRes), 'isArray:', Array.isArray(provRes));
|
||||
if (Array.isArray(provRes)) {
|
||||
setProviders(provRes);
|
||||
const activeRes = await Service.AIGetActiveProvider?.();
|
||||
console.log('[AI] AIGetActiveProvider result:', activeRes);
|
||||
if (activeRes) setActiveProviderId(activeRes);
|
||||
}
|
||||
if (safeRes) setSafetyLevel(safeRes);
|
||||
@@ -223,8 +420,13 @@ const AISettingsModal: React.FC<AISettingsModalProps> = ({ open, onClose, darkMo
|
||||
if (Array.isArray(mcpServersRes)) setMCPServers(mcpServersRes);
|
||||
if (Array.isArray(mcpToolsRes)) setMCPTools(mcpToolsRes);
|
||||
if (Array.isArray(skillsRes)) setSkills(skillsRes);
|
||||
if (Array.isArray(mcpClientStatusesRes)) {
|
||||
const normalizedStatuses = normalizeMCPClientStatuses(mcpClientStatusesRes);
|
||||
setMCPClientStatuses(normalizedStatuses);
|
||||
setSelectedMCPClient((prev) => pickPreferredMCPClient(normalizedStatuses, prev));
|
||||
}
|
||||
} catch (e) { console.warn('Failed to load AI config', e); }
|
||||
}, []);
|
||||
}, [resolveAIService]);
|
||||
|
||||
useEffect(() => { if (open) void loadConfig(); }, [open, loadConfig]);
|
||||
|
||||
@@ -491,6 +693,63 @@ const AISettingsModal: React.FC<AISettingsModalProps> = ({ open, onClose, darkMo
|
||||
}
|
||||
};
|
||||
|
||||
const handleInstallSelectedMCPClient = async () => {
|
||||
const targetClient = selectedMCPClientStatus?.client === 'codex' ? 'codex' : 'claude-code';
|
||||
const targetLabel = selectedMCPClientStatus?.displayName || (targetClient === 'codex' ? 'Codex' : 'Claude Code');
|
||||
if (selectedMCPClientStatus?.matchesCurrent) {
|
||||
void messageApi.success(`${targetLabel} 已安装当前 GoNavi MCP,无需重复安装`);
|
||||
return;
|
||||
}
|
||||
try {
|
||||
setLoading(true);
|
||||
const Service = await resolveAIService();
|
||||
let result: MCPClientInstallResult;
|
||||
if (targetClient === 'codex') {
|
||||
if (typeof Service?.AIInstallCodexMCP !== 'function') {
|
||||
throw new Error('当前版本暂不支持自动安装 Codex MCP');
|
||||
}
|
||||
result = await Service.AIInstallCodexMCP() as MCPClientInstallResult;
|
||||
} else {
|
||||
if (typeof Service?.AIInstallClaudeCodeMCP !== 'function') {
|
||||
throw new Error('当前版本暂不支持自动安装 Claude Code MCP');
|
||||
}
|
||||
result = await Service.AIInstallClaudeCodeMCP() as MCPClientInstallResult;
|
||||
}
|
||||
await loadMCPClientStatuses({ silent: true });
|
||||
window.dispatchEvent(new CustomEvent('gonavi:ai:config-changed'));
|
||||
void messageApi.success(result?.message || `已写入 ${targetLabel} 用户级 MCP 配置`);
|
||||
} catch (e: any) {
|
||||
void messageApi.error(e?.message || `安装 ${targetLabel} MCP 失败`);
|
||||
} finally {
|
||||
setLoading(false);
|
||||
}
|
||||
};
|
||||
|
||||
const handleCopySelectedMCPConfigPath = useCallback(async () => {
|
||||
const configPath = String(selectedMCPClientStatus?.configPath || '').trim();
|
||||
if (!configPath) {
|
||||
void messageApi.warning('当前没有可复制的配置文件路径');
|
||||
return;
|
||||
}
|
||||
try {
|
||||
await copyTextToClipboard(configPath, '配置文件路径已复制');
|
||||
} catch (e: any) {
|
||||
void messageApi.error(e?.message || '复制配置文件路径失败');
|
||||
}
|
||||
}, [copyTextToClipboard, messageApi, selectedMCPClientStatus]);
|
||||
|
||||
const handleCopySelectedMCPLaunchCommand = useCallback(async () => {
|
||||
if (!selectedMCPClientCommandText) {
|
||||
void messageApi.warning('当前没有可复制的启动命令');
|
||||
return;
|
||||
}
|
||||
try {
|
||||
await copyTextToClipboard(selectedMCPClientCommandText, '启动命令已复制');
|
||||
} catch (e: any) {
|
||||
void messageApi.error(e?.message || '复制启动命令失败');
|
||||
}
|
||||
}, [copyTextToClipboard, messageApi, selectedMCPClientCommandText]);
|
||||
|
||||
const updateSkillDraft = (id: string, patch: Partial<AISkillConfig>) => {
|
||||
setSkills((prev) => prev.map((item) => item.id === id ? { ...item, ...patch } : item));
|
||||
};
|
||||
@@ -983,8 +1242,165 @@ const AISettingsModal: React.FC<AISettingsModalProps> = ({ open, onClose, darkMo
|
||||
|
||||
const renderMCPSettings = () => (
|
||||
<div style={{ display: 'flex', flexDirection: 'column', gap: 12 }}>
|
||||
<div style={{ fontSize: 13, color: overlayTheme.mutedText, marginBottom: 4 }}>
|
||||
MCP 会作为外部工具源接入 AI。当前阶段先支持 `stdio` 型服务,不需要为 GoNavi 的 MCP client 单独新建仓库;只有你准备发布独立的 MCP Server 时,才值得拆独立仓库。
|
||||
<div style={{ fontSize: 13, color: overlayTheme.mutedText, marginBottom: 4, lineHeight: 1.7 }}>
|
||||
这里的“安装到客户端”是把 GoNavi 注册成外部 AI 客户端可调用的 MCP Server,供 Claude Code 或 Codex 使用;不是 GoNavi 自己安装自己。
|
||||
</div>
|
||||
<div style={{
|
||||
padding: '16px',
|
||||
borderRadius: 14,
|
||||
border: `1px solid ${cardBorder}`,
|
||||
background: cardBg,
|
||||
display: 'flex',
|
||||
flexDirection: 'column',
|
||||
gap: 14,
|
||||
}}>
|
||||
<div style={{ display: 'flex', flexDirection: 'column', gap: 4 }}>
|
||||
<div style={{ fontWeight: 700, fontSize: 14, color: overlayTheme.titleText }}>安装到外部客户端</div>
|
||||
<div style={{ fontSize: 12, color: overlayTheme.mutedText, lineHeight: 1.7 }}>
|
||||
先选择目标客户端,再把当前 GoNavi 安装路径写入它的用户级 MCP 配置。GoNavi 会自动处理配置文件路径,不需要你自己找本机 exe。
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<div style={{ display: 'grid', gridTemplateColumns: 'repeat(auto-fit, minmax(220px, 1fr))', gap: 10 }}>
|
||||
{mcpClientStatuses.map((status) => {
|
||||
const active = selectedMCPClient === status.client;
|
||||
const tone = getMCPClientStatusTone(status);
|
||||
return (
|
||||
<div
|
||||
key={status.client}
|
||||
onClick={() => {
|
||||
if (status.client === 'claude-code' || status.client === 'codex') {
|
||||
setSelectedMCPClient(status.client);
|
||||
}
|
||||
}}
|
||||
style={{
|
||||
padding: '14px 14px 12px',
|
||||
borderRadius: 12,
|
||||
border: `1.5px solid ${active ? overlayTheme.selectedText : cardBorder}`,
|
||||
background: active ? overlayTheme.selectedBg : (darkMode ? 'rgba(255,255,255,0.02)' : 'rgba(255,255,255,0.7)'),
|
||||
cursor: 'pointer',
|
||||
display: 'flex',
|
||||
flexDirection: 'column',
|
||||
gap: 10,
|
||||
transition: 'all 0.2s ease',
|
||||
}}
|
||||
>
|
||||
<div style={{ display: 'flex', alignItems: 'center', justifyContent: 'space-between', gap: 12 }}>
|
||||
<div style={{ fontWeight: 700, fontSize: 14, color: overlayTheme.titleText }}>
|
||||
{status.displayName}
|
||||
</div>
|
||||
<div style={{
|
||||
padding: '4px 10px',
|
||||
borderRadius: 999,
|
||||
fontSize: 12,
|
||||
fontWeight: 700,
|
||||
color: tone.color,
|
||||
background: tone.bg,
|
||||
whiteSpace: 'nowrap',
|
||||
}}>
|
||||
{tone.label}
|
||||
</div>
|
||||
</div>
|
||||
<div style={{ fontSize: 12, color: overlayTheme.mutedText, lineHeight: 1.6 }}>
|
||||
{status.matchesCurrent
|
||||
? '当前 GoNavi 安装路径已写入,打开客户端后可直接使用。'
|
||||
: status.installed
|
||||
? '检测到已有安装记录,但建议更新为当前 GoNavi 路径。'
|
||||
: '当前尚未写入 GoNavi MCP 配置。'}
|
||||
</div>
|
||||
</div>
|
||||
);
|
||||
})}
|
||||
</div>
|
||||
|
||||
<div style={{
|
||||
padding: '12px 14px',
|
||||
borderRadius: 12,
|
||||
border: `1px solid ${cardBorder}`,
|
||||
background: darkMode ? 'rgba(255,255,255,0.03)' : 'rgba(255,255,255,0.78)',
|
||||
display: 'flex',
|
||||
flexDirection: 'column',
|
||||
gap: 6,
|
||||
}}>
|
||||
<div style={{ display: 'flex', alignItems: 'center', gap: 8, flexWrap: 'wrap' }}>
|
||||
<div style={{ fontWeight: 700, fontSize: 13, color: overlayTheme.titleText }}>
|
||||
{selectedMCPClientStatus?.displayName || '客户端'} 状态
|
||||
</div>
|
||||
{selectedMCPClientStatus && (
|
||||
<div style={{
|
||||
padding: '3px 9px',
|
||||
borderRadius: 999,
|
||||
fontSize: 11,
|
||||
fontWeight: 700,
|
||||
color: getMCPClientStatusTone(selectedMCPClientStatus).color,
|
||||
background: getMCPClientStatusTone(selectedMCPClientStatus).bg,
|
||||
}}>
|
||||
{getMCPClientStatusTone(selectedMCPClientStatus).label}
|
||||
</div>
|
||||
)}
|
||||
</div>
|
||||
<div style={{ fontSize: 12, color: overlayTheme.mutedText, lineHeight: 1.7 }}>
|
||||
{selectedMCPClientStatus?.message || '未检测到安装状态'}
|
||||
</div>
|
||||
{selectedMCPClientStatus?.configPath && (
|
||||
<div style={{ fontSize: 12, color: overlayTheme.mutedText, lineHeight: 1.6, fontFamily: 'var(--gn-font-mono)' }}>
|
||||
配置文件:{selectedMCPClientStatus.configPath}
|
||||
</div>
|
||||
)}
|
||||
{selectedMCPClientCommandText && (
|
||||
<div style={{ fontSize: 12, color: overlayTheme.mutedText, lineHeight: 1.6, fontFamily: 'var(--gn-font-mono)' }}>
|
||||
启动命令:{selectedMCPClientCommandText}
|
||||
</div>
|
||||
)}
|
||||
<div style={{ display: 'flex', gap: 8, flexWrap: 'wrap' }}>
|
||||
<Button
|
||||
size="small"
|
||||
icon={<ReloadOutlined />}
|
||||
loading={mcpClientStatusLoading}
|
||||
onClick={() => void loadMCPClientStatuses()}
|
||||
style={{ borderRadius: 8 }}
|
||||
>
|
||||
刷新状态
|
||||
</Button>
|
||||
<Button
|
||||
size="small"
|
||||
icon={<CopyOutlined />}
|
||||
disabled={!selectedMCPClientStatus?.configPath}
|
||||
onClick={() => void handleCopySelectedMCPConfigPath()}
|
||||
style={{ borderRadius: 8 }}
|
||||
>
|
||||
复制配置路径
|
||||
</Button>
|
||||
<Button
|
||||
size="small"
|
||||
icon={<CopyOutlined />}
|
||||
disabled={!selectedMCPClientCommandText}
|
||||
onClick={() => void handleCopySelectedMCPLaunchCommand()}
|
||||
style={{ borderRadius: 8 }}
|
||||
>
|
||||
复制启动命令
|
||||
</Button>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<div style={{ display: 'flex', justifyContent: 'space-between', alignItems: 'center', gap: 12, flexWrap: 'wrap' }}>
|
||||
<div style={{ fontSize: 12, color: overlayTheme.mutedText, lineHeight: 1.6 }}>
|
||||
安装后重启对应客户端即可生效;若已经是当前路径,会直接提示无需重复安装。
|
||||
</div>
|
||||
<Button
|
||||
type={selectedMCPClientStatus?.matchesCurrent ? 'default' : 'primary'}
|
||||
onClick={handleInstallSelectedMCPClient}
|
||||
loading={loading}
|
||||
disabled={Boolean(selectedMCPClientStatus?.matchesCurrent)}
|
||||
style={{ borderRadius: 10, fontWeight: 600, minWidth: 176, height: 40 }}
|
||||
>
|
||||
{selectedMCPClientStatus?.matchesCurrent
|
||||
? `${selectedMCPClientStatus.displayName} 已安装`
|
||||
: selectedMCPClientStatus?.installed
|
||||
? `更新到 ${selectedMCPClientStatus?.displayName || '客户端'}`
|
||||
: `安装到 ${selectedMCPClientStatus?.displayName || '客户端'}`}
|
||||
</Button>
|
||||
</div>
|
||||
</div>
|
||||
<div style={{ display: 'flex', justifyContent: 'space-between', alignItems: 'center', gap: 12 }}>
|
||||
<div style={{ fontSize: 12, color: overlayTheme.mutedText }}>支持命令、参数、环境变量和超时,保存后会自动进入 AI 工具列表。</div>
|
||||
@@ -1217,7 +1633,7 @@ const AISettingsModal: React.FC<AISettingsModalProps> = ({ open, onClose, darkMo
|
||||
>
|
||||
<div ref={modalBodyRef} className="ai-settings-body" style={{ display: 'grid', gridTemplateColumns: '180px minmax(0, 1fr)', gap: 16, padding: '12px 0', height: '100%', minHeight: 0, overflow: 'hidden', alignItems: 'stretch', position: 'relative' }}>
|
||||
{messageContextHolder}
|
||||
<div style={{ padding: '0 12px', height: 'fit-content' }}>
|
||||
<div style={{ minHeight: 0, height: '100%', overflowY: 'auto', overflowX: 'hidden', padding: '0 6px 28px 12px' }}>
|
||||
<div style={{ marginBottom: 12, fontWeight: 600, color: overlayTheme.titleText }}>设置导航</div>
|
||||
<div style={{ display: 'grid', gap: 10 }}>
|
||||
{[
|
||||
|
||||
@@ -341,7 +341,7 @@ const FindInDatabaseModal: React.FC<FindInDatabaseModalProps> = ({ open, onClose
|
||||
header: { background: 'transparent', borderBottom: 'none', paddingBottom: 8 },
|
||||
body: { paddingTop: 8 },
|
||||
}}
|
||||
destroyOnClose
|
||||
destroyOnHidden
|
||||
>
|
||||
<div style={{ display: 'flex', flexDirection: 'column', gap: 16 }}>
|
||||
{/* 搜索栏 */}
|
||||
|
||||
@@ -22,12 +22,14 @@ const resolveDevHarnessMode = (): string => {
|
||||
}
|
||||
};
|
||||
|
||||
if (typeof window !== 'undefined' && !(window as any).go) {
|
||||
if (typeof window !== 'undefined' && (!(window as any).go?.app?.App || !(window as any).go?.aiservice?.Service)) {
|
||||
const mockConnections: any[] = [];
|
||||
const mockConnectionSecrets = new Map<string, any>();
|
||||
const mockProviders: any[] = [];
|
||||
const mockProviderSecrets = new Map<string, string>();
|
||||
let mockActiveProviderId = '';
|
||||
let mockAISafetyLevel = 'readonly';
|
||||
let mockAIContextLevel = 'schema_only';
|
||||
let mockAIUserPromptSettings: any = {
|
||||
global: '',
|
||||
database: '',
|
||||
@@ -35,6 +37,28 @@ if (typeof window !== 'undefined' && !(window as any).go) {
|
||||
jvmDiagnostic: '',
|
||||
};
|
||||
let mockMCPServers: any[] = [];
|
||||
let mockMCPClientStatuses: any[] = [
|
||||
{
|
||||
client: 'claude-code',
|
||||
displayName: 'Claude Code',
|
||||
installed: false,
|
||||
matchesCurrent: false,
|
||||
message: '未安装到 Claude Code 用户级配置',
|
||||
configPath: 'C:/Users/mock/.claude.json',
|
||||
command: 'C:/Program Files/GoNavi/GoNavi.exe',
|
||||
args: ['mcp-server'],
|
||||
},
|
||||
{
|
||||
client: 'codex',
|
||||
displayName: 'Codex',
|
||||
installed: true,
|
||||
matchesCurrent: false,
|
||||
message: '已检测到 Codex 安装记录,但与当前 GoNavi 安装包路径不一致,建议更新安装',
|
||||
configPath: 'C:/Users/mock/.codex/config.toml',
|
||||
command: 'C:/Old/GoNavi.exe',
|
||||
args: ['mcp-server'],
|
||||
},
|
||||
];
|
||||
let mockSkills: any[] = [];
|
||||
let mockGlobalProxy: any = { enabled: false, type: 'socks5', host: '', port: 1080, user: '', password: '', hasPassword: false };
|
||||
let mockDataRootInfo: any = {
|
||||
@@ -154,7 +178,7 @@ if (typeof window !== 'undefined' && !(window as any).go) {
|
||||
return cloneBrowserMockValue(view);
|
||||
};
|
||||
|
||||
(window as any).go = {
|
||||
const mockGo = {
|
||||
app: {
|
||||
App: {
|
||||
CheckUpdate: async () => ({ success: false }),
|
||||
@@ -291,8 +315,8 @@ if (typeof window !== 'undefined' && !(window as any).go) {
|
||||
mockActiveProviderId = id;
|
||||
return null;
|
||||
},
|
||||
AIGetSafetyLevel: async () => 'readonly',
|
||||
AIGetContextLevel: async () => 'schema_only',
|
||||
AIGetSafetyLevel: async () => mockAISafetyLevel,
|
||||
AIGetContextLevel: async () => mockAIContextLevel,
|
||||
AIGetBuiltinPrompts: async () => ({}),
|
||||
AIGetUserPromptSettings: async () => cloneBrowserMockValue(mockAIUserPromptSettings),
|
||||
AISaveUserPromptSettings: async (input: any) => {
|
||||
@@ -304,7 +328,48 @@ if (typeof window !== 'undefined' && !(window as any).go) {
|
||||
};
|
||||
return null;
|
||||
},
|
||||
AIGetMCPClientInstallStatuses: async () => cloneBrowserMockValue(mockMCPClientStatuses),
|
||||
AIGetMCPServers: async () => cloneBrowserMockValue(mockMCPServers),
|
||||
AIInstallClaudeCodeMCP: async () => {
|
||||
mockMCPClientStatuses = mockMCPClientStatuses.map((item) => item.client === 'claude-code'
|
||||
? {
|
||||
...item,
|
||||
installed: true,
|
||||
matchesCurrent: true,
|
||||
message: '已写入 Claude Code 用户级 MCP 配置,重启 Claude CLI 后可在 /mcp 的 User MCPs 中看到 GoNavi。',
|
||||
command: 'C:/Program Files/GoNavi/GoNavi.exe',
|
||||
args: ['mcp-server'],
|
||||
}
|
||||
: item);
|
||||
return {
|
||||
success: true,
|
||||
client: 'claude-code',
|
||||
message: '已写入 Claude Code 用户级 MCP 配置,重启 Claude CLI 后可在 /mcp 的 User MCPs 中看到 GoNavi。',
|
||||
configPath: 'C:/Users/mock/.claude.json',
|
||||
command: 'C:/Program Files/GoNavi/GoNavi.exe',
|
||||
args: ['mcp-server'],
|
||||
};
|
||||
},
|
||||
AIInstallCodexMCP: async () => {
|
||||
mockMCPClientStatuses = mockMCPClientStatuses.map((item) => item.client === 'codex'
|
||||
? {
|
||||
...item,
|
||||
installed: true,
|
||||
matchesCurrent: true,
|
||||
message: '已写入 Codex 用户级 MCP 配置,重启 Codex CLI 或桌面端后可看到 GoNavi。',
|
||||
command: 'C:/Program Files/GoNavi/GoNavi.exe',
|
||||
args: ['mcp-server'],
|
||||
}
|
||||
: item);
|
||||
return {
|
||||
success: true,
|
||||
client: 'codex',
|
||||
message: '已写入 Codex 用户级 MCP 配置,重启 Codex CLI 或桌面端后可看到 GoNavi。',
|
||||
configPath: 'C:/Users/mock/.codex/config.toml',
|
||||
command: 'C:/Program Files/GoNavi/GoNavi.exe',
|
||||
args: ['mcp-server'],
|
||||
};
|
||||
},
|
||||
AISaveMCPServer: async (input: any) => {
|
||||
const next = {
|
||||
id: String(input?.id || `mcp-${Date.now()}`),
|
||||
@@ -363,11 +428,38 @@ if (typeof window !== 'undefined' && !(window as any).go) {
|
||||
success: String(input?.apiKey || '').trim() !== '',
|
||||
message: String(input?.apiKey || '').trim() !== '' ? '端点连通性测试成功!' : '连接测试失败: missing api key',
|
||||
}),
|
||||
AISetSafetyLevel: async () => null,
|
||||
AISetContextLevel: async () => null,
|
||||
AISetSafetyLevel: async (level: string) => {
|
||||
mockAISafetyLevel = String(level || 'readonly');
|
||||
return null;
|
||||
},
|
||||
AISetContextLevel: async (level: string) => {
|
||||
mockAIContextLevel = String(level || 'schema_only');
|
||||
return null;
|
||||
},
|
||||
},
|
||||
}
|
||||
};
|
||||
const existingGo = (window as any).go || {};
|
||||
(window as any).go = {
|
||||
...mockGo,
|
||||
...existingGo,
|
||||
app: {
|
||||
...mockGo.app,
|
||||
...(existingGo.app || {}),
|
||||
App: {
|
||||
...mockGo.app.App,
|
||||
...(existingGo.app?.App || {}),
|
||||
},
|
||||
},
|
||||
aiservice: {
|
||||
...mockGo.aiservice,
|
||||
...(existingGo.aiservice || {}),
|
||||
Service: {
|
||||
...mockGo.aiservice.Service,
|
||||
...(existingGo.aiservice?.Service || {}),
|
||||
},
|
||||
},
|
||||
};
|
||||
}
|
||||
const rootNode = document.getElementById('root')!;
|
||||
const devHarnessMode = import.meta.env.DEV ? resolveDevHarnessMode() : '';
|
||||
|
||||
@@ -592,6 +592,17 @@ export interface AIMCPToolCallResult {
|
||||
isError: boolean;
|
||||
}
|
||||
|
||||
export interface AIMCPClientInstallStatus {
|
||||
client: string;
|
||||
displayName: string;
|
||||
installed: boolean;
|
||||
matchesCurrent: boolean;
|
||||
message: string;
|
||||
configPath?: string;
|
||||
command?: string;
|
||||
args?: string[];
|
||||
}
|
||||
|
||||
export type AISkillScope = "global" | "database" | "jvm" | "jvmDiagnostic";
|
||||
|
||||
export interface AISkillConfig {
|
||||
|
||||
6
frontend/wailsjs/go/aiservice/Service.d.ts
vendored
6
frontend/wailsjs/go/aiservice/Service.d.ts
vendored
@@ -28,6 +28,8 @@ export function AIGetContextLevel():Promise<string>;
|
||||
|
||||
export function AIGetEditableProvider(arg1:string):Promise<ai.ProviderConfig>;
|
||||
|
||||
export function AIGetMCPClientInstallStatuses():Promise<Array<ai.MCPClientInstallStatus>>;
|
||||
|
||||
export function AIGetMCPServers():Promise<Array<ai.MCPServerConfig>>;
|
||||
|
||||
export function AIGetProviders():Promise<Array<ai.ProviderConfig>>;
|
||||
@@ -40,6 +42,10 @@ export function AIGetSkills():Promise<Array<ai.SkillConfig>>;
|
||||
|
||||
export function AIGetUserPromptSettings():Promise<ai.UserPromptSettings>;
|
||||
|
||||
export function AIInstallClaudeCodeMCP():Promise<ai.MCPClientInstallResult>;
|
||||
|
||||
export function AIInstallCodexMCP():Promise<ai.MCPClientInstallResult>;
|
||||
|
||||
export function AIListMCPTools():Promise<Array<ai.MCPToolDescriptor>>;
|
||||
|
||||
export function AIListModels():Promise<Record<string, any>>;
|
||||
|
||||
@@ -54,6 +54,10 @@ export function AIGetEditableProvider(arg1) {
|
||||
return window['go']['aiservice']['Service']['AIGetEditableProvider'](arg1);
|
||||
}
|
||||
|
||||
export function AIGetMCPClientInstallStatuses() {
|
||||
return window['go']['aiservice']['Service']['AIGetMCPClientInstallStatuses']();
|
||||
}
|
||||
|
||||
export function AIGetMCPServers() {
|
||||
return window['go']['aiservice']['Service']['AIGetMCPServers']();
|
||||
}
|
||||
@@ -78,6 +82,14 @@ export function AIGetUserPromptSettings() {
|
||||
return window['go']['aiservice']['Service']['AIGetUserPromptSettings']();
|
||||
}
|
||||
|
||||
export function AIInstallClaudeCodeMCP() {
|
||||
return window['go']['aiservice']['Service']['AIInstallClaudeCodeMCP']();
|
||||
}
|
||||
|
||||
export function AIInstallCodexMCP() {
|
||||
return window['go']['aiservice']['Service']['AIInstallCodexMCP']();
|
||||
}
|
||||
|
||||
export function AIListMCPTools() {
|
||||
return window['go']['aiservice']['Service']['AIListMCPTools']();
|
||||
}
|
||||
|
||||
@@ -1,5 +1,53 @@
|
||||
export namespace ai {
|
||||
|
||||
export class MCPClientInstallResult {
|
||||
success: boolean;
|
||||
client?: string;
|
||||
message: string;
|
||||
configPath?: string;
|
||||
command?: string;
|
||||
args?: string[];
|
||||
|
||||
static createFrom(source: any = {}) {
|
||||
return new MCPClientInstallResult(source);
|
||||
}
|
||||
|
||||
constructor(source: any = {}) {
|
||||
if ('string' === typeof source) source = JSON.parse(source);
|
||||
this.success = source["success"];
|
||||
this.client = source["client"];
|
||||
this.message = source["message"];
|
||||
this.configPath = source["configPath"];
|
||||
this.command = source["command"];
|
||||
this.args = source["args"];
|
||||
}
|
||||
}
|
||||
export class MCPClientInstallStatus {
|
||||
client: string;
|
||||
displayName: string;
|
||||
installed: boolean;
|
||||
matchesCurrent: boolean;
|
||||
message: string;
|
||||
configPath?: string;
|
||||
command?: string;
|
||||
args?: string[];
|
||||
|
||||
static createFrom(source: any = {}) {
|
||||
return new MCPClientInstallStatus(source);
|
||||
}
|
||||
|
||||
constructor(source: any = {}) {
|
||||
if ('string' === typeof source) source = JSON.parse(source);
|
||||
this.client = source["client"];
|
||||
this.displayName = source["displayName"];
|
||||
this.installed = source["installed"];
|
||||
this.matchesCurrent = source["matchesCurrent"];
|
||||
this.message = source["message"];
|
||||
this.configPath = source["configPath"];
|
||||
this.command = source["command"];
|
||||
this.args = source["args"];
|
||||
}
|
||||
}
|
||||
export class MCPServerConfig {
|
||||
id: string;
|
||||
name: string;
|
||||
@@ -1272,4 +1320,3 @@ export namespace sync {
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
|
||||
680
internal/ai/service/claude_code_mcp.go
Normal file
680
internal/ai/service/claude_code_mcp.go
Normal file
@@ -0,0 +1,680 @@
|
||||
package aiservice
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"reflect"
|
||||
"strconv"
|
||||
"strings"
|
||||
|
||||
"GoNavi-Wails/internal/ai"
|
||||
)
|
||||
|
||||
const (
|
||||
gonaviMCPServerID = "gonavi"
|
||||
defaultCodexMCPStartupTimeoutSecond = 60
|
||||
)
|
||||
|
||||
var claudeCodeConfigPathFunc = func() (string, error) {
|
||||
homeDir, err := os.UserHomeDir()
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
homeDir = strings.TrimSpace(homeDir)
|
||||
if homeDir == "" {
|
||||
return "", fmt.Errorf("无法确定用户目录")
|
||||
}
|
||||
return filepath.Join(homeDir, ".claude.json"), nil
|
||||
}
|
||||
|
||||
var codexConfigPathFunc = func() (string, error) {
|
||||
homeDir, err := os.UserHomeDir()
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
homeDir = strings.TrimSpace(homeDir)
|
||||
if homeDir == "" {
|
||||
return "", fmt.Errorf("无法确定用户目录")
|
||||
}
|
||||
return filepath.Join(homeDir, ".codex", "config.toml"), nil
|
||||
}
|
||||
|
||||
var localMCPExecutablePathFunc = os.Executable
|
||||
|
||||
type claudeCodeMCPServerConfig struct {
|
||||
Type string `json:"type"`
|
||||
Command string `json:"command"`
|
||||
Args []string `json:"args,omitempty"`
|
||||
Env map[string]string `json:"env,omitempty"`
|
||||
}
|
||||
|
||||
type codexMCPServerConfig struct {
|
||||
Command string
|
||||
Args []string
|
||||
StartupTimeoutSec int
|
||||
}
|
||||
|
||||
// AIGetMCPClientInstallStatuses 返回 GoNavi MCP 在常见外部客户端中的安装状态。
|
||||
func (s *Service) AIGetMCPClientInstallStatuses() []ai.MCPClientInstallStatus {
|
||||
command, args, resolveErr := resolveCurrentLocalMCPCommand()
|
||||
return []ai.MCPClientInstallStatus{
|
||||
inspectClaudeCodeMCPInstallStatus(command, args, resolveErr),
|
||||
inspectCodexMCPInstallStatus(command, args, resolveErr),
|
||||
}
|
||||
}
|
||||
|
||||
// AIInstallClaudeCodeMCP 把 GoNavi 的 MCP server 写入 Claude Code 用户级 MCP 配置。
|
||||
func (s *Service) AIInstallClaudeCodeMCP() (ai.MCPClientInstallResult, error) {
|
||||
configPath, err := claudeCodeConfigPathFunc()
|
||||
if err != nil {
|
||||
return ai.MCPClientInstallResult{}, fmt.Errorf("定位 Claude Code 配置失败: %w", err)
|
||||
}
|
||||
|
||||
executablePath, err := localMCPExecutablePathFunc()
|
||||
if err != nil {
|
||||
return ai.MCPClientInstallResult{}, fmt.Errorf("定位当前 GoNavi 可执行文件失败: %w", err)
|
||||
}
|
||||
|
||||
command, args, err := resolveLocalMCPCommand(executablePath)
|
||||
if err != nil {
|
||||
return ai.MCPClientInstallResult{}, err
|
||||
}
|
||||
|
||||
serverConfig := claudeCodeMCPServerConfig{
|
||||
Type: "stdio",
|
||||
Command: command,
|
||||
Args: append([]string(nil), args...),
|
||||
Env: map[string]string{},
|
||||
}
|
||||
if err := upsertClaudeCodeMCPServerConfig(configPath, gonaviMCPServerID, serverConfig); err != nil {
|
||||
return ai.MCPClientInstallResult{}, err
|
||||
}
|
||||
|
||||
return ai.MCPClientInstallResult{
|
||||
Success: true,
|
||||
Client: "claude-code",
|
||||
Message: "已写入 Claude Code 用户级 MCP 配置,重启 Claude CLI 后可在 /mcp 的 User MCPs 中看到 GoNavi。",
|
||||
ConfigPath: configPath,
|
||||
Command: command,
|
||||
Args: append([]string(nil), args...),
|
||||
}, nil
|
||||
}
|
||||
|
||||
// AIInstallCodexMCP 把 GoNavi 的 MCP server 写入 Codex 用户级 MCP 配置。
|
||||
func (s *Service) AIInstallCodexMCP() (ai.MCPClientInstallResult, error) {
|
||||
configPath, err := codexConfigPathFunc()
|
||||
if err != nil {
|
||||
return ai.MCPClientInstallResult{}, fmt.Errorf("定位 Codex 配置失败: %w", err)
|
||||
}
|
||||
|
||||
executablePath, err := localMCPExecutablePathFunc()
|
||||
if err != nil {
|
||||
return ai.MCPClientInstallResult{}, fmt.Errorf("定位当前 GoNavi 可执行文件失败: %w", err)
|
||||
}
|
||||
|
||||
command, args, err := resolveLocalMCPCommand(executablePath)
|
||||
if err != nil {
|
||||
return ai.MCPClientInstallResult{}, err
|
||||
}
|
||||
|
||||
serverConfig := codexMCPServerConfig{
|
||||
Command: command,
|
||||
Args: append([]string(nil), args...),
|
||||
StartupTimeoutSec: defaultCodexMCPStartupTimeoutSecond,
|
||||
}
|
||||
if err := upsertCodexMCPServerConfig(configPath, gonaviMCPServerID, serverConfig); err != nil {
|
||||
return ai.MCPClientInstallResult{}, err
|
||||
}
|
||||
|
||||
return ai.MCPClientInstallResult{
|
||||
Success: true,
|
||||
Client: "codex",
|
||||
Message: "已写入 Codex 用户级 MCP 配置,重启 Codex CLI 或桌面端后可看到 GoNavi。",
|
||||
ConfigPath: configPath,
|
||||
Command: command,
|
||||
Args: append([]string(nil), args...),
|
||||
}, nil
|
||||
}
|
||||
|
||||
func resolveCurrentLocalMCPCommand() (string, []string, error) {
|
||||
executablePath, err := localMCPExecutablePathFunc()
|
||||
if err != nil {
|
||||
return "", nil, fmt.Errorf("定位当前 GoNavi 可执行文件失败: %w", err)
|
||||
}
|
||||
command, args, err := resolveLocalMCPCommand(executablePath)
|
||||
if err != nil {
|
||||
return "", nil, err
|
||||
}
|
||||
return command, args, nil
|
||||
}
|
||||
|
||||
func resolveLocalMCPCommand(executablePath string) (string, []string, error) {
|
||||
executablePath = strings.TrimSpace(executablePath)
|
||||
if executablePath == "" {
|
||||
return "", nil, fmt.Errorf("当前 GoNavi 可执行文件路径为空")
|
||||
}
|
||||
|
||||
cleaned := filepath.Clean(executablePath)
|
||||
baseName := strings.ToLower(strings.TrimSpace(filepath.Base(cleaned)))
|
||||
switch baseName {
|
||||
case "gonavi-mcp-server", "gonavi-mcp-server.exe":
|
||||
return cleaned, []string{}, nil
|
||||
default:
|
||||
return cleaned, []string{"mcp-server"}, nil
|
||||
}
|
||||
}
|
||||
|
||||
func inspectClaudeCodeMCPInstallStatus(expectedCommand string, expectedArgs []string, expectedErr error) ai.MCPClientInstallStatus {
|
||||
configPath, pathErr := claudeCodeConfigPathFunc()
|
||||
status := ai.MCPClientInstallStatus{
|
||||
Client: "claude-code",
|
||||
DisplayName: "Claude Code",
|
||||
ConfigPath: strings.TrimSpace(configPath),
|
||||
Message: "未安装到 Claude Code 用户级配置",
|
||||
}
|
||||
if pathErr != nil {
|
||||
status.Message = fmt.Sprintf("定位 Claude Code 配置失败: %v", pathErr)
|
||||
return status
|
||||
}
|
||||
|
||||
serverConfig, found, err := readClaudeCodeMCPServerConfig(configPath, gonaviMCPServerID)
|
||||
if err != nil {
|
||||
status.Installed = found
|
||||
status.Message = err.Error()
|
||||
if found {
|
||||
status.Command = strings.TrimSpace(serverConfig.Command)
|
||||
status.Args = append([]string(nil), serverConfig.Args...)
|
||||
}
|
||||
return status
|
||||
}
|
||||
if !found {
|
||||
return status
|
||||
}
|
||||
|
||||
status.Installed = true
|
||||
status.Command = strings.TrimSpace(serverConfig.Command)
|
||||
status.Args = append([]string(nil), serverConfig.Args...)
|
||||
if expectedErr != nil {
|
||||
status.Message = fmt.Sprintf("已检测到 Claude Code 安装记录,但当前 GoNavi 安装路径校验失败:%v", expectedErr)
|
||||
return status
|
||||
}
|
||||
|
||||
status.MatchesCurrent = strings.EqualFold(strings.TrimSpace(serverConfig.Type), "stdio") &&
|
||||
sameMCPCommand(serverConfig.Command, serverConfig.Args, expectedCommand, expectedArgs)
|
||||
if status.MatchesCurrent {
|
||||
status.Message = "已安装到 Claude Code 用户级配置"
|
||||
return status
|
||||
}
|
||||
|
||||
status.Message = "已检测到 Claude Code 安装记录,但与当前 GoNavi 安装包路径不一致,建议更新安装"
|
||||
return status
|
||||
}
|
||||
|
||||
func inspectCodexMCPInstallStatus(expectedCommand string, expectedArgs []string, expectedErr error) ai.MCPClientInstallStatus {
|
||||
configPath, pathErr := codexConfigPathFunc()
|
||||
status := ai.MCPClientInstallStatus{
|
||||
Client: "codex",
|
||||
DisplayName: "Codex",
|
||||
ConfigPath: strings.TrimSpace(configPath),
|
||||
Message: "未安装到 Codex 用户级配置",
|
||||
}
|
||||
if pathErr != nil {
|
||||
status.Message = fmt.Sprintf("定位 Codex 配置失败: %v", pathErr)
|
||||
return status
|
||||
}
|
||||
|
||||
serverConfig, found, err := readCodexMCPServerConfig(configPath, gonaviMCPServerID)
|
||||
if err != nil {
|
||||
status.Installed = found
|
||||
status.Message = err.Error()
|
||||
if found {
|
||||
status.Command = strings.TrimSpace(serverConfig.Command)
|
||||
status.Args = append([]string(nil), serverConfig.Args...)
|
||||
}
|
||||
return status
|
||||
}
|
||||
if !found {
|
||||
return status
|
||||
}
|
||||
|
||||
status.Installed = true
|
||||
status.Command = strings.TrimSpace(serverConfig.Command)
|
||||
status.Args = append([]string(nil), serverConfig.Args...)
|
||||
if expectedErr != nil {
|
||||
status.Message = fmt.Sprintf("已检测到 Codex 安装记录,但当前 GoNavi 安装路径校验失败:%v", expectedErr)
|
||||
return status
|
||||
}
|
||||
|
||||
status.MatchesCurrent = sameMCPCommand(serverConfig.Command, serverConfig.Args, expectedCommand, expectedArgs) &&
|
||||
(serverConfig.StartupTimeoutSec == 0 || serverConfig.StartupTimeoutSec == defaultCodexMCPStartupTimeoutSecond)
|
||||
if status.MatchesCurrent {
|
||||
status.Message = "已安装到 Codex 用户级配置"
|
||||
return status
|
||||
}
|
||||
|
||||
status.Message = "已检测到 Codex 安装记录,但与当前 GoNavi 安装包路径不一致,建议更新安装"
|
||||
return status
|
||||
}
|
||||
|
||||
func readClaudeCodeMCPServerConfig(configPath string, serverID string) (claudeCodeMCPServerConfig, bool, error) {
|
||||
root, err := readClaudeCodeConfig(configPath)
|
||||
if err != nil {
|
||||
return claudeCodeMCPServerConfig{}, false, err
|
||||
}
|
||||
|
||||
rawServers, exists := root["mcpServers"]
|
||||
if !exists || rawServers == nil {
|
||||
return claudeCodeMCPServerConfig{}, false, nil
|
||||
}
|
||||
mcpServers, ok := rawServers.(map[string]any)
|
||||
if !ok {
|
||||
return claudeCodeMCPServerConfig{}, false, fmt.Errorf("Claude Code 配置格式异常:mcpServers 不是对象")
|
||||
}
|
||||
|
||||
rawServer, exists := mcpServers[strings.TrimSpace(serverID)]
|
||||
if !exists || rawServer == nil {
|
||||
return claudeCodeMCPServerConfig{}, false, nil
|
||||
}
|
||||
serverMap, ok := rawServer.(map[string]any)
|
||||
if !ok {
|
||||
return claudeCodeMCPServerConfig{}, true, fmt.Errorf("Claude Code 配置格式异常:mcpServers.%s 不是对象", strings.TrimSpace(serverID))
|
||||
}
|
||||
|
||||
args, err := decodeJSONLikeStringSlice(serverMap["args"])
|
||||
if err != nil {
|
||||
return claudeCodeMCPServerConfig{}, true, fmt.Errorf("Claude Code 配置格式异常:mcpServers.%s.args 不是字符串数组", strings.TrimSpace(serverID))
|
||||
}
|
||||
return claudeCodeMCPServerConfig{
|
||||
Type: strings.TrimSpace(anyString(serverMap["type"])),
|
||||
Command: strings.TrimSpace(anyString(serverMap["command"])),
|
||||
Args: args,
|
||||
}, true, nil
|
||||
}
|
||||
|
||||
func upsertClaudeCodeMCPServerConfig(configPath string, serverID string, serverConfig claudeCodeMCPServerConfig) error {
|
||||
root, err := readClaudeCodeConfig(configPath)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
mcpServers, err := ensureJSONMap(root, "mcpServers")
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
mcpServers[strings.TrimSpace(serverID)] = map[string]any{
|
||||
"type": serverConfig.Type,
|
||||
"command": serverConfig.Command,
|
||||
"args": append([]string(nil), serverConfig.Args...),
|
||||
"env": cloneStringMap(serverConfig.Env),
|
||||
}
|
||||
root["mcpServers"] = mcpServers
|
||||
|
||||
data, err := json.MarshalIndent(root, "", " ")
|
||||
if err != nil {
|
||||
return fmt.Errorf("序列化 Claude Code 配置失败: %w", err)
|
||||
}
|
||||
|
||||
if err := os.MkdirAll(filepath.Dir(configPath), 0o755); err != nil {
|
||||
return fmt.Errorf("创建 Claude Code 配置目录失败: %w", err)
|
||||
}
|
||||
if err := os.WriteFile(configPath, append(data, '\n'), 0o644); err != nil {
|
||||
return fmt.Errorf("写入 Claude Code 配置失败: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func readClaudeCodeConfig(configPath string) (map[string]any, error) {
|
||||
data, err := os.ReadFile(configPath)
|
||||
if err != nil {
|
||||
if os.IsNotExist(err) {
|
||||
return map[string]any{}, nil
|
||||
}
|
||||
return nil, fmt.Errorf("读取 Claude Code 配置失败: %w", err)
|
||||
}
|
||||
|
||||
if strings.TrimSpace(string(data)) == "" {
|
||||
return map[string]any{}, nil
|
||||
}
|
||||
|
||||
var root map[string]any
|
||||
if err := json.Unmarshal(data, &root); err != nil {
|
||||
return nil, fmt.Errorf("解析 Claude Code 配置失败: %w", err)
|
||||
}
|
||||
if root == nil {
|
||||
return map[string]any{}, nil
|
||||
}
|
||||
return root, nil
|
||||
}
|
||||
|
||||
func ensureJSONMap(root map[string]any, key string) (map[string]any, error) {
|
||||
if root == nil {
|
||||
return nil, fmt.Errorf("JSON 根对象不能为空")
|
||||
}
|
||||
|
||||
value, exists := root[key]
|
||||
if !exists || value == nil {
|
||||
result := map[string]any{}
|
||||
root[key] = result
|
||||
return result, nil
|
||||
}
|
||||
|
||||
typed, ok := value.(map[string]any)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("Claude Code 配置格式异常:%s 不是对象", key)
|
||||
}
|
||||
return typed, nil
|
||||
}
|
||||
|
||||
func readCodexMCPServerConfig(configPath string, serverID string) (codexMCPServerConfig, bool, error) {
|
||||
data, err := os.ReadFile(configPath)
|
||||
if err != nil {
|
||||
if os.IsNotExist(err) {
|
||||
return codexMCPServerConfig{}, false, nil
|
||||
}
|
||||
return codexMCPServerConfig{}, false, fmt.Errorf("读取 Codex 配置失败: %w", err)
|
||||
}
|
||||
return parseCodexMCPServerConfig(string(data), serverID)
|
||||
}
|
||||
|
||||
func upsertCodexMCPServerConfig(configPath string, serverID string, serverConfig codexMCPServerConfig) error {
|
||||
data, err := os.ReadFile(configPath)
|
||||
if err != nil && !os.IsNotExist(err) {
|
||||
return fmt.Errorf("读取 Codex 配置失败: %w", err)
|
||||
}
|
||||
|
||||
updated := replaceOrAppendCodexMCPServerBlock(string(data), strings.TrimSpace(serverID), renderCodexMCPServerBlock(serverID, serverConfig))
|
||||
if err := os.MkdirAll(filepath.Dir(configPath), 0o755); err != nil {
|
||||
return fmt.Errorf("创建 Codex 配置目录失败: %w", err)
|
||||
}
|
||||
if err := os.WriteFile(configPath, []byte(updated), 0o644); err != nil {
|
||||
return fmt.Errorf("写入 Codex 配置失败: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func renderCodexMCPServerBlock(serverID string, serverConfig codexMCPServerConfig) string {
|
||||
trimmedID := strings.TrimSpace(serverID)
|
||||
if trimmedID == "" {
|
||||
trimmedID = gonaviMCPServerID
|
||||
}
|
||||
|
||||
lines := []string{
|
||||
fmt.Sprintf("[mcp_servers.%s]", trimmedID),
|
||||
fmt.Sprintf("command = %s", tomlString(serverConfig.Command)),
|
||||
fmt.Sprintf("args = [%s]", strings.Join(renderTomlStringArray(serverConfig.Args), ", ")),
|
||||
}
|
||||
if serverConfig.StartupTimeoutSec > 0 {
|
||||
lines = append(lines, fmt.Sprintf("startup_timeout_sec = %d", serverConfig.StartupTimeoutSec))
|
||||
}
|
||||
return strings.Join(lines, "\n") + "\n"
|
||||
}
|
||||
|
||||
func parseCodexMCPServerConfig(content string, serverID string) (codexMCPServerConfig, bool, error) {
|
||||
lines := strings.Split(strings.ReplaceAll(content, "\r\n", "\n"), "\n")
|
||||
mainHeader := fmt.Sprintf("[mcp_servers.%s]", strings.TrimSpace(serverID))
|
||||
result := codexMCPServerConfig{}
|
||||
found := false
|
||||
inside := false
|
||||
|
||||
for _, line := range lines {
|
||||
trimmed := strings.TrimSpace(line)
|
||||
if !inside {
|
||||
if trimmed == mainHeader {
|
||||
inside = true
|
||||
found = true
|
||||
}
|
||||
continue
|
||||
}
|
||||
if isTOMLHeaderLine(trimmed) {
|
||||
break
|
||||
}
|
||||
if trimmed == "" || strings.HasPrefix(trimmed, "#") {
|
||||
continue
|
||||
}
|
||||
|
||||
key, value, ok := splitTOMLAssignment(trimmed)
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
switch key {
|
||||
case "command":
|
||||
parsed, err := parseTOMLString(value)
|
||||
if err != nil {
|
||||
return result, true, fmt.Errorf("Codex 配置格式异常:mcp_servers.%s.command 解析失败", strings.TrimSpace(serverID))
|
||||
}
|
||||
result.Command = parsed
|
||||
case "args":
|
||||
parsed, err := parseTOMLStringArray(value)
|
||||
if err != nil {
|
||||
return result, true, fmt.Errorf("Codex 配置格式异常:mcp_servers.%s.args 解析失败", strings.TrimSpace(serverID))
|
||||
}
|
||||
result.Args = parsed
|
||||
case "startup_timeout_sec":
|
||||
parsed, err := strconv.Atoi(strings.TrimSpace(value))
|
||||
if err != nil {
|
||||
return result, true, fmt.Errorf("Codex 配置格式异常:mcp_servers.%s.startup_timeout_sec 解析失败", strings.TrimSpace(serverID))
|
||||
}
|
||||
result.StartupTimeoutSec = parsed
|
||||
}
|
||||
}
|
||||
|
||||
return result, found, nil
|
||||
}
|
||||
|
||||
func replaceOrAppendCodexMCPServerBlock(content string, serverID string, block string) string {
|
||||
lines := strings.Split(strings.ReplaceAll(content, "\r\n", "\n"), "\n")
|
||||
mainHeader := fmt.Sprintf("[mcp_servers.%s]", serverID)
|
||||
nestedPrefix := fmt.Sprintf("[mcp_servers.%s.", serverID)
|
||||
|
||||
start, end := -1, -1
|
||||
for index, line := range lines {
|
||||
trimmed := strings.TrimSpace(line)
|
||||
if start == -1 {
|
||||
if trimmed == mainHeader || strings.HasPrefix(trimmed, nestedPrefix) {
|
||||
start = index
|
||||
}
|
||||
continue
|
||||
}
|
||||
if isTOMLHeaderLine(trimmed) && trimmed != mainHeader && !strings.HasPrefix(trimmed, nestedPrefix) {
|
||||
end = index
|
||||
break
|
||||
}
|
||||
}
|
||||
if start != -1 && end == -1 {
|
||||
end = len(lines)
|
||||
}
|
||||
|
||||
rendered := strings.TrimRight(block, "\n")
|
||||
if start == -1 {
|
||||
base := strings.TrimSpace(strings.Join(lines, "\n"))
|
||||
if base == "" {
|
||||
return rendered + "\n"
|
||||
}
|
||||
return strings.TrimRight(strings.Join(lines, "\n"), "\n") + "\n\n" + rendered + "\n"
|
||||
}
|
||||
|
||||
before := strings.TrimRight(strings.Join(lines[:start], "\n"), "\n")
|
||||
after := strings.TrimLeft(strings.Join(lines[end:], "\n"), "\n")
|
||||
switch {
|
||||
case before == "" && after == "":
|
||||
return rendered + "\n"
|
||||
case before == "":
|
||||
return rendered + "\n\n" + after
|
||||
case after == "":
|
||||
return before + "\n\n" + rendered + "\n"
|
||||
default:
|
||||
return before + "\n\n" + rendered + "\n\n" + after
|
||||
}
|
||||
}
|
||||
|
||||
func renderTomlStringArray(values []string) []string {
|
||||
rendered := make([]string, 0, len(values))
|
||||
for _, value := range values {
|
||||
rendered = append(rendered, tomlString(value))
|
||||
}
|
||||
return rendered
|
||||
}
|
||||
|
||||
func tomlString(value string) string {
|
||||
if !strings.Contains(value, "'") && !strings.Contains(value, "\n") && !strings.Contains(value, "\r") {
|
||||
return "'" + value + "'"
|
||||
}
|
||||
return strconv.Quote(value)
|
||||
}
|
||||
|
||||
func splitTOMLAssignment(line string) (string, string, bool) {
|
||||
index := strings.Index(line, "=")
|
||||
if index <= 0 {
|
||||
return "", "", false
|
||||
}
|
||||
key := strings.TrimSpace(line[:index])
|
||||
value := strings.TrimSpace(line[index+1:])
|
||||
if key == "" {
|
||||
return "", "", false
|
||||
}
|
||||
return key, value, true
|
||||
}
|
||||
|
||||
func parseTOMLString(value string) (string, error) {
|
||||
value = strings.TrimSpace(value)
|
||||
if len(value) < 2 {
|
||||
return "", fmt.Errorf("字符串格式非法")
|
||||
}
|
||||
switch value[0] {
|
||||
case '\'':
|
||||
if value[len(value)-1] != '\'' {
|
||||
return "", fmt.Errorf("单引号字符串未闭合")
|
||||
}
|
||||
return value[1 : len(value)-1], nil
|
||||
case '"':
|
||||
parsed, err := strconv.Unquote(value)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
return parsed, nil
|
||||
default:
|
||||
return "", fmt.Errorf("不是字符串")
|
||||
}
|
||||
}
|
||||
|
||||
func parseTOMLStringArray(value string) ([]string, error) {
|
||||
value = strings.TrimSpace(value)
|
||||
if value == "" {
|
||||
return []string{}, nil
|
||||
}
|
||||
if !strings.HasPrefix(value, "[") || !strings.HasSuffix(value, "]") {
|
||||
return nil, fmt.Errorf("不是数组")
|
||||
}
|
||||
|
||||
inner := strings.TrimSpace(value[1 : len(value)-1])
|
||||
if inner == "" {
|
||||
return []string{}, nil
|
||||
}
|
||||
|
||||
result := make([]string, 0, 4)
|
||||
for inner != "" {
|
||||
item, rest, err := consumeTOMLQuotedString(inner)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
result = append(result, item)
|
||||
inner = strings.TrimSpace(rest)
|
||||
if inner == "" {
|
||||
break
|
||||
}
|
||||
if !strings.HasPrefix(inner, ",") {
|
||||
return nil, fmt.Errorf("数组分隔符非法")
|
||||
}
|
||||
inner = strings.TrimSpace(inner[1:])
|
||||
}
|
||||
return result, nil
|
||||
}
|
||||
|
||||
func consumeTOMLQuotedString(value string) (string, string, error) {
|
||||
value = strings.TrimLeft(value, " \t")
|
||||
if value == "" {
|
||||
return "", "", fmt.Errorf("字符串为空")
|
||||
}
|
||||
switch value[0] {
|
||||
case '\'':
|
||||
end := strings.IndexByte(value[1:], '\'')
|
||||
if end < 0 {
|
||||
return "", "", fmt.Errorf("单引号字符串未闭合")
|
||||
}
|
||||
end++
|
||||
return value[1:end], value[end+1:], nil
|
||||
case '"':
|
||||
escaped := false
|
||||
for index := 1; index < len(value); index++ {
|
||||
ch := value[index]
|
||||
if escaped {
|
||||
escaped = false
|
||||
continue
|
||||
}
|
||||
if ch == '\\' {
|
||||
escaped = true
|
||||
continue
|
||||
}
|
||||
if ch == '"' {
|
||||
parsed, err := strconv.Unquote(value[:index+1])
|
||||
if err != nil {
|
||||
return "", "", err
|
||||
}
|
||||
return parsed, value[index+1:], nil
|
||||
}
|
||||
}
|
||||
return "", "", fmt.Errorf("双引号字符串未闭合")
|
||||
default:
|
||||
return "", "", fmt.Errorf("不是字符串")
|
||||
}
|
||||
}
|
||||
|
||||
func decodeJSONLikeStringSlice(value any) ([]string, error) {
|
||||
switch typed := value.(type) {
|
||||
case nil:
|
||||
return []string{}, nil
|
||||
case []string:
|
||||
return append([]string(nil), typed...), nil
|
||||
case []any:
|
||||
result := make([]string, 0, len(typed))
|
||||
for _, item := range typed {
|
||||
str, ok := item.(string)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("数组元素不是字符串")
|
||||
}
|
||||
result = append(result, str)
|
||||
}
|
||||
return result, nil
|
||||
default:
|
||||
return nil, fmt.Errorf("不是字符串数组")
|
||||
}
|
||||
}
|
||||
|
||||
func anyString(value any) string {
|
||||
text, _ := value.(string)
|
||||
return text
|
||||
}
|
||||
|
||||
func sameMCPCommand(actualCommand string, actualArgs []string, expectedCommand string, expectedArgs []string) bool {
|
||||
return strings.TrimSpace(actualCommand) == strings.TrimSpace(expectedCommand) &&
|
||||
reflect.DeepEqual(normalizeStringSlice(actualArgs), normalizeStringSlice(expectedArgs))
|
||||
}
|
||||
|
||||
func normalizeStringSlice(values []string) []string {
|
||||
if len(values) == 0 {
|
||||
return []string{}
|
||||
}
|
||||
result := make([]string, 0, len(values))
|
||||
for _, value := range values {
|
||||
result = append(result, strings.TrimSpace(value))
|
||||
}
|
||||
return result
|
||||
}
|
||||
|
||||
func isTOMLHeaderLine(line string) bool {
|
||||
line = strings.TrimSpace(line)
|
||||
return strings.HasPrefix(line, "[") && strings.HasSuffix(line, "]")
|
||||
}
|
||||
276
internal/ai/service/claude_code_mcp_test.go
Normal file
276
internal/ai/service/claude_code_mcp_test.go
Normal file
@@ -0,0 +1,276 @@
|
||||
package aiservice
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"reflect"
|
||||
"strings"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestResolveLocalMCPCommandUsesMainBinaryWithArgument(t *testing.T) {
|
||||
command, args, err := resolveLocalMCPCommand(`C:\Program Files\GoNavi\GoNavi.exe`)
|
||||
if err != nil {
|
||||
t.Fatalf("resolveLocalMCPCommand returned error: %v", err)
|
||||
}
|
||||
if command != `C:\Program Files\GoNavi\GoNavi.exe` {
|
||||
t.Fatalf("expected command to keep main binary path, got %q", command)
|
||||
}
|
||||
if !reflect.DeepEqual(args, []string{"mcp-server"}) {
|
||||
t.Fatalf("expected main binary args %#v, got %#v", []string{"mcp-server"}, args)
|
||||
}
|
||||
}
|
||||
|
||||
func TestResolveLocalMCPCommandKeepsDedicatedServerBinary(t *testing.T) {
|
||||
command, args, err := resolveLocalMCPCommand(`D:\Work\CodeRepos\GoNavi\bin\gonavi-mcp-server.exe`)
|
||||
if err != nil {
|
||||
t.Fatalf("resolveLocalMCPCommand returned error: %v", err)
|
||||
}
|
||||
if command != `D:\Work\CodeRepos\GoNavi\bin\gonavi-mcp-server.exe` {
|
||||
t.Fatalf("expected dedicated server path to be reused, got %q", command)
|
||||
}
|
||||
if len(args) != 0 {
|
||||
t.Fatalf("expected dedicated server args to be empty, got %#v", args)
|
||||
}
|
||||
}
|
||||
|
||||
func TestReadClaudeCodeMCPServerConfigReadsExistingInstall(t *testing.T) {
|
||||
tempDir := t.TempDir()
|
||||
configPath := filepath.Join(tempDir, ".claude.json")
|
||||
initial := map[string]any{
|
||||
"mcpServers": map[string]any{
|
||||
gonaviMCPServerID: map[string]any{
|
||||
"type": "stdio",
|
||||
"command": `C:\Program Files\GoNavi\GoNavi.exe`,
|
||||
"args": []string{"mcp-server"},
|
||||
},
|
||||
},
|
||||
}
|
||||
data, err := json.MarshalIndent(initial, "", " ")
|
||||
if err != nil {
|
||||
t.Fatalf("MarshalIndent returned error: %v", err)
|
||||
}
|
||||
if err := os.WriteFile(configPath, append(data, '\n'), 0o644); err != nil {
|
||||
t.Fatalf("WriteFile returned error: %v", err)
|
||||
}
|
||||
|
||||
cfg, found, err := readClaudeCodeMCPServerConfig(configPath, gonaviMCPServerID)
|
||||
if err != nil {
|
||||
t.Fatalf("readClaudeCodeMCPServerConfig returned error: %v", err)
|
||||
}
|
||||
if !found {
|
||||
t.Fatal("expected gonavi install to be detected")
|
||||
}
|
||||
if cfg.Command != `C:\Program Files\GoNavi\GoNavi.exe` {
|
||||
t.Fatalf("unexpected command: %q", cfg.Command)
|
||||
}
|
||||
if !reflect.DeepEqual(cfg.Args, []string{"mcp-server"}) {
|
||||
t.Fatalf("unexpected args: %#v", cfg.Args)
|
||||
}
|
||||
}
|
||||
|
||||
func TestUpsertClaudeCodeMCPServerConfigCreatesAndMergesUserConfig(t *testing.T) {
|
||||
tempDir := t.TempDir()
|
||||
configPath := filepath.Join(tempDir, ".claude.json")
|
||||
initial := map[string]any{
|
||||
"theme": "dark-daltonized",
|
||||
"mcpServers": map[string]any{
|
||||
"memory": map[string]any{
|
||||
"type": "stdio",
|
||||
"command": "cmd",
|
||||
},
|
||||
},
|
||||
}
|
||||
data, err := json.MarshalIndent(initial, "", " ")
|
||||
if err != nil {
|
||||
t.Fatalf("MarshalIndent returned error: %v", err)
|
||||
}
|
||||
if err := os.WriteFile(configPath, append(data, '\n'), 0o644); err != nil {
|
||||
t.Fatalf("WriteFile returned error: %v", err)
|
||||
}
|
||||
|
||||
err = upsertClaudeCodeMCPServerConfig(configPath, gonaviMCPServerID, claudeCodeMCPServerConfig{
|
||||
Type: "stdio",
|
||||
Command: `C:\Program Files\GoNavi\GoNavi.exe`,
|
||||
Args: []string{"mcp-server"},
|
||||
Env: map[string]string{},
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("upsertClaudeCodeMCPServerConfig returned error: %v", err)
|
||||
}
|
||||
|
||||
updated, err := os.ReadFile(configPath)
|
||||
if err != nil {
|
||||
t.Fatalf("ReadFile returned error: %v", err)
|
||||
}
|
||||
|
||||
var root map[string]any
|
||||
if err := json.Unmarshal(updated, &root); err != nil {
|
||||
t.Fatalf("Unmarshal returned error: %v", err)
|
||||
}
|
||||
if got := strings.TrimSpace(root["theme"].(string)); got != "dark-daltonized" {
|
||||
t.Fatalf("expected theme to be preserved, got %q", got)
|
||||
}
|
||||
|
||||
mcpServers, ok := root["mcpServers"].(map[string]any)
|
||||
if !ok {
|
||||
t.Fatalf("expected mcpServers object, got %#v", root["mcpServers"])
|
||||
}
|
||||
if _, ok := mcpServers["memory"]; !ok {
|
||||
t.Fatalf("expected existing memory server to be preserved, got %#v", mcpServers)
|
||||
}
|
||||
|
||||
gonavi, ok := mcpServers[gonaviMCPServerID].(map[string]any)
|
||||
if !ok {
|
||||
t.Fatalf("expected gonavi server object, got %#v", mcpServers[gonaviMCPServerID])
|
||||
}
|
||||
if got := strings.TrimSpace(gonavi["command"].(string)); got != `C:\Program Files\GoNavi\GoNavi.exe` {
|
||||
t.Fatalf("expected gonavi command to be written, got %q", got)
|
||||
}
|
||||
args, ok := gonavi["args"].([]any)
|
||||
if !ok || len(args) != 1 || strings.TrimSpace(args[0].(string)) != "mcp-server" {
|
||||
t.Fatalf("expected gonavi args to contain mcp-server, got %#v", gonavi["args"])
|
||||
}
|
||||
}
|
||||
|
||||
func TestUpsertClaudeCodeMCPServerConfigRejectsInvalidMCPServersShape(t *testing.T) {
|
||||
tempDir := t.TempDir()
|
||||
configPath := filepath.Join(tempDir, ".claude.json")
|
||||
if err := os.WriteFile(configPath, []byte("{\"mcpServers\":[]}"), 0o644); err != nil {
|
||||
t.Fatalf("WriteFile returned error: %v", err)
|
||||
}
|
||||
|
||||
err := upsertClaudeCodeMCPServerConfig(configPath, gonaviMCPServerID, claudeCodeMCPServerConfig{
|
||||
Type: "stdio",
|
||||
Command: "GoNavi.exe",
|
||||
})
|
||||
if err == nil {
|
||||
t.Fatal("expected invalid mcpServers shape to return error")
|
||||
}
|
||||
if !strings.Contains(err.Error(), "mcpServers 不是对象") {
|
||||
t.Fatalf("expected invalid shape error, got %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestParseCodexMCPServerConfigDetectsExistingInstall(t *testing.T) {
|
||||
content := strings.Join([]string{
|
||||
`model = "gpt-5.4"`,
|
||||
``,
|
||||
`[mcp_servers.gonavi]`,
|
||||
`command = 'C:\Program Files\GoNavi\GoNavi.exe'`,
|
||||
`args = ['mcp-server']`,
|
||||
`startup_timeout_sec = 60`,
|
||||
``,
|
||||
`[projects.'D:\Work\CodeRepos\GoNavi']`,
|
||||
`trust_level = "trusted"`,
|
||||
``,
|
||||
}, "\n")
|
||||
|
||||
cfg, found, err := parseCodexMCPServerConfig(content, gonaviMCPServerID)
|
||||
if err != nil {
|
||||
t.Fatalf("parseCodexMCPServerConfig returned error: %v", err)
|
||||
}
|
||||
if !found {
|
||||
t.Fatal("expected gonavi install to be detected")
|
||||
}
|
||||
if cfg.Command != `C:\Program Files\GoNavi\GoNavi.exe` {
|
||||
t.Fatalf("unexpected command: %q", cfg.Command)
|
||||
}
|
||||
if !reflect.DeepEqual(cfg.Args, []string{"mcp-server"}) {
|
||||
t.Fatalf("unexpected args: %#v", cfg.Args)
|
||||
}
|
||||
if cfg.StartupTimeoutSec != 60 {
|
||||
t.Fatalf("unexpected startup timeout: %d", cfg.StartupTimeoutSec)
|
||||
}
|
||||
}
|
||||
|
||||
func TestUpsertCodexMCPServerConfigCreatesAndMergesConfig(t *testing.T) {
|
||||
tempDir := t.TempDir()
|
||||
configPath := filepath.Join(tempDir, "config.toml")
|
||||
initial := strings.Join([]string{
|
||||
`model = "gpt-5.4"`,
|
||||
``,
|
||||
`[mcp_servers.memory]`,
|
||||
`command = "cmd"`,
|
||||
`args = ["/c", "npx"]`,
|
||||
``,
|
||||
}, "\n")
|
||||
if err := os.WriteFile(configPath, []byte(initial), 0o644); err != nil {
|
||||
t.Fatalf("WriteFile returned error: %v", err)
|
||||
}
|
||||
|
||||
err := upsertCodexMCPServerConfig(configPath, gonaviMCPServerID, codexMCPServerConfig{
|
||||
Command: `C:\Program Files\GoNavi\GoNavi.exe`,
|
||||
Args: []string{"mcp-server"},
|
||||
StartupTimeoutSec: defaultCodexMCPStartupTimeoutSecond,
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("upsertCodexMCPServerConfig returned error: %v", err)
|
||||
}
|
||||
|
||||
updated, err := os.ReadFile(configPath)
|
||||
if err != nil {
|
||||
t.Fatalf("ReadFile returned error: %v", err)
|
||||
}
|
||||
text := string(updated)
|
||||
if !strings.Contains(text, `[mcp_servers.memory]`) {
|
||||
t.Fatalf("expected memory server to be preserved, got %s", text)
|
||||
}
|
||||
if !strings.Contains(text, `[mcp_servers.gonavi]`) {
|
||||
t.Fatalf("expected gonavi section to be created, got %s", text)
|
||||
}
|
||||
if !strings.Contains(text, `command = 'C:\Program Files\GoNavi\GoNavi.exe'`) {
|
||||
t.Fatalf("expected gonavi command to be written, got %s", text)
|
||||
}
|
||||
if !strings.Contains(text, `args = ['mcp-server']`) {
|
||||
t.Fatalf("expected gonavi args to be written, got %s", text)
|
||||
}
|
||||
if !strings.Contains(text, `startup_timeout_sec = 60`) {
|
||||
t.Fatalf("expected startup timeout to be written, got %s", text)
|
||||
}
|
||||
}
|
||||
|
||||
func TestUpsertCodexMCPServerConfigReplacesExistingBlockAndNestedSections(t *testing.T) {
|
||||
tempDir := t.TempDir()
|
||||
configPath := filepath.Join(tempDir, "config.toml")
|
||||
initial := strings.Join([]string{
|
||||
`model = "gpt-5.4"`,
|
||||
``,
|
||||
`[mcp_servers.gonavi]`,
|
||||
`command = 'old.exe'`,
|
||||
`args = ['old']`,
|
||||
`startup_timeout_sec = 15`,
|
||||
``,
|
||||
`[mcp_servers.gonavi.env]`,
|
||||
`FOO = "bar"`,
|
||||
``,
|
||||
`[projects.'D:\Work\CodeRepos\GoNavi']`,
|
||||
`trust_level = "trusted"`,
|
||||
``,
|
||||
}, "\n")
|
||||
if err := os.WriteFile(configPath, []byte(initial), 0o644); err != nil {
|
||||
t.Fatalf("WriteFile returned error: %v", err)
|
||||
}
|
||||
|
||||
err := upsertCodexMCPServerConfig(configPath, gonaviMCPServerID, codexMCPServerConfig{
|
||||
Command: `C:\Program Files\GoNavi\GoNavi.exe`,
|
||||
Args: []string{"mcp-server"},
|
||||
StartupTimeoutSec: defaultCodexMCPStartupTimeoutSecond,
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("upsertCodexMCPServerConfig returned error: %v", err)
|
||||
}
|
||||
|
||||
updated, err := os.ReadFile(configPath)
|
||||
if err != nil {
|
||||
t.Fatalf("ReadFile returned error: %v", err)
|
||||
}
|
||||
text := string(updated)
|
||||
if strings.Contains(text, `command = 'old.exe'`) || strings.Contains(text, `[mcp_servers.gonavi.env]`) {
|
||||
t.Fatalf("expected old gonavi block to be replaced, got %s", text)
|
||||
}
|
||||
if !strings.Contains(text, `[projects.'D:\Work\CodeRepos\GoNavi']`) {
|
||||
t.Fatalf("expected unrelated project config to be preserved, got %s", text)
|
||||
}
|
||||
}
|
||||
@@ -136,6 +136,31 @@ type MCPToolCallResult struct {
|
||||
IsError bool `json:"isError"`
|
||||
}
|
||||
|
||||
// MCPClientInstallResult 表示安装 GoNavi 到外部 MCP 客户端配置文件的结果。
|
||||
type MCPClientInstallResult struct {
|
||||
Success bool `json:"success"`
|
||||
Client string `json:"client,omitempty"`
|
||||
Message string `json:"message"`
|
||||
ConfigPath string `json:"configPath,omitempty"`
|
||||
Command string `json:"command,omitempty"`
|
||||
Args []string `json:"args,omitempty"`
|
||||
}
|
||||
|
||||
// MCPClientInstallStatus 表示 GoNavi MCP 在外部客户端中的当前安装状态。
|
||||
type MCPClientInstallStatus struct {
|
||||
Client string `json:"client"`
|
||||
DisplayName string `json:"displayName"`
|
||||
Installed bool `json:"installed"`
|
||||
MatchesCurrent bool `json:"matchesCurrent"`
|
||||
Message string `json:"message"`
|
||||
ConfigPath string `json:"configPath,omitempty"`
|
||||
Command string `json:"command,omitempty"`
|
||||
Args []string `json:"args,omitempty"`
|
||||
}
|
||||
|
||||
// ClaudeCodeMCPInstallResult 兼容旧命名,便于平滑迁移到通用结果类型。
|
||||
type ClaudeCodeMCPInstallResult = MCPClientInstallResult
|
||||
|
||||
// SkillScope 表示 Skill 的适用场景
|
||||
type SkillScope string
|
||||
|
||||
|
||||
98
internal/mcpserver/backend.go
Normal file
98
internal/mcpserver/backend.go
Normal file
@@ -0,0 +1,98 @@
|
||||
package mcpserver
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
"GoNavi-Wails/internal/ai"
|
||||
aiservice "GoNavi-Wails/internal/ai/service"
|
||||
appcore "GoNavi-Wails/internal/app"
|
||||
"GoNavi-Wails/internal/appdata"
|
||||
"GoNavi-Wails/internal/connection"
|
||||
"GoNavi-Wails/internal/logger"
|
||||
)
|
||||
|
||||
// Backend 抽象 GoNavi 后端能力,便于复用真实 App 和单元测试替身。
|
||||
type Backend interface {
|
||||
Close(context.Context) error
|
||||
GetSavedConnections() ([]connection.SavedConnectionView, error)
|
||||
GetEditableSavedConnection(id string) (connection.SavedConnectionView, error)
|
||||
DBGetDatabases(config connection.ConnectionConfig) connection.QueryResult
|
||||
DBGetTables(config connection.ConnectionConfig, dbName string) connection.QueryResult
|
||||
DBGetColumns(config connection.ConnectionConfig, dbName string, tableName string) connection.QueryResult
|
||||
DBShowCreateTable(config connection.ConnectionConfig, dbName string, tableName string) connection.QueryResult
|
||||
DBQueryMulti(config connection.ConnectionConfig, dbName string, query string, queryID string) connection.QueryResult
|
||||
InspectSQL(dbType string, sql string) appcore.SQLInspection
|
||||
GetSQLSafetyLevel() ai.SQLPermissionLevel
|
||||
}
|
||||
|
||||
// AppBackend 基于现有 internal/app.App 暴露 MCP 所需数据库能力。
|
||||
type AppBackend struct {
|
||||
app *appcore.App
|
||||
}
|
||||
|
||||
func NewAppBackend(ctx context.Context) *AppBackend {
|
||||
if ctx == nil {
|
||||
ctx = context.Background()
|
||||
}
|
||||
a := appcore.NewApp()
|
||||
appcore.InitializeLifecycle(a, ctx)
|
||||
return &AppBackend{app: a}
|
||||
}
|
||||
|
||||
func (b *AppBackend) Close(ctx context.Context) error {
|
||||
if b == nil || b.app == nil {
|
||||
return nil
|
||||
}
|
||||
if ctx == nil {
|
||||
ctx = context.Background()
|
||||
}
|
||||
b.app.Shutdown(ctx)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (b *AppBackend) GetSavedConnections() ([]connection.SavedConnectionView, error) {
|
||||
return b.app.GetSavedConnections()
|
||||
}
|
||||
|
||||
func (b *AppBackend) GetEditableSavedConnection(id string) (connection.SavedConnectionView, error) {
|
||||
return b.app.GetEditableSavedConnection(id)
|
||||
}
|
||||
|
||||
func (b *AppBackend) DBGetDatabases(config connection.ConnectionConfig) connection.QueryResult {
|
||||
return b.app.DBGetDatabases(config)
|
||||
}
|
||||
|
||||
func (b *AppBackend) DBGetTables(config connection.ConnectionConfig, dbName string) connection.QueryResult {
|
||||
return b.app.DBGetTables(config, dbName)
|
||||
}
|
||||
|
||||
func (b *AppBackend) DBGetColumns(config connection.ConnectionConfig, dbName string, tableName string) connection.QueryResult {
|
||||
return b.app.DBGetColumns(config, dbName, tableName)
|
||||
}
|
||||
|
||||
func (b *AppBackend) DBShowCreateTable(config connection.ConnectionConfig, dbName string, tableName string) connection.QueryResult {
|
||||
return b.app.DBShowCreateTable(config, dbName, tableName)
|
||||
}
|
||||
|
||||
func (b *AppBackend) DBQueryMulti(config connection.ConnectionConfig, dbName string, query string, queryID string) connection.QueryResult {
|
||||
return b.app.DBQueryMulti(config, dbName, query, queryID)
|
||||
}
|
||||
|
||||
func (b *AppBackend) InspectSQL(dbType string, sql string) appcore.SQLInspection {
|
||||
return appcore.InspectSQL(dbType, sql)
|
||||
}
|
||||
|
||||
func (b *AppBackend) GetSQLSafetyLevel() ai.SQLPermissionLevel {
|
||||
inspection, err := aiservice.NewProviderConfigStore(appdata.MustResolveActiveRoot(), nil).Inspect()
|
||||
if err != nil {
|
||||
logger.Error(err, "加载 MCP SQL 安全控制失败,按只读模式回退")
|
||||
return ai.PermissionReadOnly
|
||||
}
|
||||
|
||||
switch inspection.Snapshot.SafetyLevel {
|
||||
case ai.PermissionReadOnly, ai.PermissionReadWrite, ai.PermissionFull:
|
||||
return inspection.Snapshot.SafetyLevel
|
||||
default:
|
||||
return ai.PermissionReadOnly
|
||||
}
|
||||
}
|
||||
29
internal/mcpserver/run.go
Normal file
29
internal/mcpserver/run.go
Normal file
@@ -0,0 +1,29 @@
|
||||
package mcpserver
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
"github.com/modelcontextprotocol/go-sdk/mcp"
|
||||
)
|
||||
|
||||
// RunAppStdioServer 启动基于真实 GoNavi App 的 stdio MCP server。
|
||||
func RunAppStdioServer(ctx context.Context) error {
|
||||
if ctx == nil {
|
||||
ctx = context.Background()
|
||||
}
|
||||
|
||||
backend := NewAppBackend(ctx)
|
||||
defer backend.Close(ctx)
|
||||
|
||||
return RunStdioServer(ctx, backend)
|
||||
}
|
||||
|
||||
// RunStdioServer 使用指定 backend 启动 stdio MCP server。
|
||||
func RunStdioServer(ctx context.Context, backend Backend) error {
|
||||
if ctx == nil {
|
||||
ctx = context.Background()
|
||||
}
|
||||
|
||||
server := NewServer(backend)
|
||||
return server.Run(ctx, &mcp.StdioTransport{})
|
||||
}
|
||||
59
internal/mcpserver/server.go
Normal file
59
internal/mcpserver/server.go
Normal file
@@ -0,0 +1,59 @@
|
||||
package mcpserver
|
||||
|
||||
import (
|
||||
"runtime/debug"
|
||||
"strings"
|
||||
|
||||
"github.com/modelcontextprotocol/go-sdk/mcp"
|
||||
)
|
||||
|
||||
func NewServer(backend Backend) *mcp.Server {
|
||||
server := mcp.NewServer(&mcp.Implementation{
|
||||
Name: "gonavi-ai",
|
||||
Version: implementationVersion(),
|
||||
}, nil)
|
||||
|
||||
service := NewService(backend)
|
||||
|
||||
mcp.AddTool(server, &mcp.Tool{
|
||||
Name: "get_connections",
|
||||
Description: "列出当前 GoNavi 已保存的数据库连接,先调用它获取 connectionId。不会返回明文密码等敏感信息。",
|
||||
}, service.GetConnections)
|
||||
|
||||
mcp.AddTool(server, &mcp.Tool{
|
||||
Name: "get_databases",
|
||||
Description: "根据 connectionId 获取数据库/Schema 列表。",
|
||||
}, service.GetDatabases)
|
||||
|
||||
mcp.AddTool(server, &mcp.Tool{
|
||||
Name: "get_tables",
|
||||
Description: "根据 connectionId 和可选 dbName 获取表列表。dbName 为空时优先使用保存连接里的默认数据库。",
|
||||
}, service.GetTables)
|
||||
|
||||
mcp.AddTool(server, &mcp.Tool{
|
||||
Name: "get_columns",
|
||||
Description: "根据 connectionId、可选 dbName、tableName 获取字段定义。",
|
||||
}, service.GetColumns)
|
||||
|
||||
mcp.AddTool(server, &mcp.Tool{
|
||||
Name: "get_table_ddl",
|
||||
Description: "根据 connectionId、可选 dbName、tableName 获取建表或建视图语句。",
|
||||
}, service.GetTableDDL)
|
||||
|
||||
mcp.AddTool(server, &mcp.Tool{
|
||||
Name: "execute_sql",
|
||||
Description: "执行 SQL,支持多语句结果集。执行范围受 GoNavi AI 设置中的安全控制约束;命中允许范围内的 DML/DDL 等非只读语句时,仍必须显式传 allowMutating=true。",
|
||||
}, service.ExecuteSQL)
|
||||
|
||||
return server
|
||||
}
|
||||
|
||||
func implementationVersion() string {
|
||||
if info, ok := debug.ReadBuildInfo(); ok {
|
||||
version := strings.TrimSpace(info.Main.Version)
|
||||
if version != "" && version != "(devel)" {
|
||||
return version
|
||||
}
|
||||
}
|
||||
return "dev"
|
||||
}
|
||||
682
internal/mcpserver/service.go
Normal file
682
internal/mcpserver/service.go
Normal file
@@ -0,0 +1,682 @@
|
||||
package mcpserver
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"strings"
|
||||
|
||||
"GoNavi-Wails/internal/ai"
|
||||
appcore "GoNavi-Wails/internal/app"
|
||||
"GoNavi-Wails/internal/connection"
|
||||
"github.com/modelcontextprotocol/go-sdk/mcp"
|
||||
)
|
||||
|
||||
const (
|
||||
defaultMaxRowsPerResult = 200
|
||||
maxRowsPerResultLimit = 1000
|
||||
)
|
||||
|
||||
type Service struct {
|
||||
backend Backend
|
||||
}
|
||||
|
||||
func NewService(backend Backend) *Service {
|
||||
return &Service{backend: backend}
|
||||
}
|
||||
|
||||
type emptyArgs struct{}
|
||||
|
||||
type connectionIDArgs struct {
|
||||
ConnectionID string `json:"connectionId" jsonschema:"get_connections 返回的连接 ID"`
|
||||
}
|
||||
|
||||
type databaseArgs struct {
|
||||
ConnectionID string `json:"connectionId" jsonschema:"get_connections 返回的连接 ID"`
|
||||
DBName string `json:"dbName,omitempty" jsonschema:"可选数据库/Schema 名称。为空时优先使用保存连接里的默认数据库"`
|
||||
}
|
||||
|
||||
type tableArgs struct {
|
||||
ConnectionID string `json:"connectionId" jsonschema:"get_connections 返回的连接 ID"`
|
||||
DBName string `json:"dbName,omitempty" jsonschema:"可选数据库/Schema 名称。为空时优先使用保存连接里的默认数据库"`
|
||||
TableName string `json:"tableName" jsonschema:"目标表或视图名称"`
|
||||
}
|
||||
|
||||
type executeSQLArgs struct {
|
||||
ConnectionID string `json:"connectionId" jsonschema:"get_connections 返回的连接 ID"`
|
||||
DBName string `json:"dbName,omitempty" jsonschema:"可选数据库/Schema 名称。为空时优先使用保存连接里的默认数据库"`
|
||||
SQL string `json:"sql" jsonschema:"待执行的 SQL 文本,可以包含多条语句"`
|
||||
AllowMutating bool `json:"allowMutating,omitempty" jsonschema:"当 SQL 包含当前 AI 安全控制允许范围内的 DDL/DML 等非只读语句时,必须显式设为 true"`
|
||||
MaxRowsPerResult int `json:"maxRowsPerResult,omitempty" jsonschema:"每个结果集最多返回多少行。默认 200,最大 1000"`
|
||||
}
|
||||
|
||||
type connectionDescriptor struct {
|
||||
ID string `json:"id"`
|
||||
Name string `json:"name"`
|
||||
Type string `json:"type"`
|
||||
Host string `json:"host,omitempty"`
|
||||
Port int `json:"port,omitempty"`
|
||||
Database string `json:"database,omitempty"`
|
||||
Driver string `json:"driver,omitempty"`
|
||||
Topology string `json:"topology,omitempty"`
|
||||
Target string `json:"target,omitempty"`
|
||||
UseSSH bool `json:"useSSH,omitempty"`
|
||||
UseProxy bool `json:"useProxy,omitempty"`
|
||||
UseHTTPTunnel bool `json:"useHttpTunnel,omitempty"`
|
||||
DefaultDatabase string `json:"defaultDatabase,omitempty"`
|
||||
}
|
||||
|
||||
type getConnectionsResult struct {
|
||||
Connections []connectionDescriptor `json:"connections"`
|
||||
}
|
||||
|
||||
type getDatabasesResult struct {
|
||||
ConnectionID string `json:"connectionId"`
|
||||
Databases []string `json:"databases"`
|
||||
}
|
||||
|
||||
type getTablesResult struct {
|
||||
ConnectionID string `json:"connectionId"`
|
||||
DBName string `json:"dbName,omitempty"`
|
||||
Tables []string `json:"tables"`
|
||||
}
|
||||
|
||||
type getColumnsResult struct {
|
||||
ConnectionID string `json:"connectionId"`
|
||||
DBName string `json:"dbName,omitempty"`
|
||||
TableName string `json:"tableName"`
|
||||
Columns []connection.ColumnDefinition `json:"columns"`
|
||||
}
|
||||
|
||||
type getTableDDLResult struct {
|
||||
ConnectionID string `json:"connectionId"`
|
||||
DBName string `json:"dbName,omitempty"`
|
||||
TableName string `json:"tableName"`
|
||||
DDL string `json:"ddl"`
|
||||
}
|
||||
|
||||
type sqlStatementSummary struct {
|
||||
Index int `json:"index"`
|
||||
Keyword string `json:"keyword,omitempty"`
|
||||
ReadOnly bool `json:"readOnly"`
|
||||
}
|
||||
|
||||
type sqlResultSet struct {
|
||||
StatementIndex int `json:"statementIndex,omitempty"`
|
||||
Columns []string `json:"columns"`
|
||||
Rows []map[string]interface{} `json:"rows"`
|
||||
Messages []string `json:"messages,omitempty"`
|
||||
RowCount int `json:"rowCount"`
|
||||
Truncated bool `json:"truncated,omitempty"`
|
||||
}
|
||||
|
||||
type executeSQLResult struct {
|
||||
ConnectionID string `json:"connectionId"`
|
||||
DBName string `json:"dbName,omitempty"`
|
||||
StatementCount int `json:"statementCount"`
|
||||
ReadOnly bool `json:"readOnly"`
|
||||
QueryID string `json:"queryId,omitempty"`
|
||||
Message string `json:"message,omitempty"`
|
||||
Truncated bool `json:"truncated,omitempty"`
|
||||
Statements []sqlStatementSummary `json:"statements"`
|
||||
Results []sqlResultSet `json:"results"`
|
||||
}
|
||||
|
||||
func (s *Service) GetConnections(ctx context.Context, req *mcp.CallToolRequest, args emptyArgs) (*mcp.CallToolResult, getConnectionsResult, error) {
|
||||
_ = ctx
|
||||
_ = req
|
||||
_ = args
|
||||
|
||||
items, err := s.backend.GetSavedConnections()
|
||||
if err != nil {
|
||||
return toolError("获取已保存连接失败: %v", err), getConnectionsResult{}, nil
|
||||
}
|
||||
|
||||
result := getConnectionsResult{
|
||||
Connections: make([]connectionDescriptor, 0, len(items)),
|
||||
}
|
||||
for _, item := range items {
|
||||
cfg := item.Config
|
||||
result.Connections = append(result.Connections, connectionDescriptor{
|
||||
ID: item.ID,
|
||||
Name: item.Name,
|
||||
Type: strings.TrimSpace(cfg.Type),
|
||||
Host: strings.TrimSpace(cfg.Host),
|
||||
Port: cfg.Port,
|
||||
Database: strings.TrimSpace(cfg.Database),
|
||||
Driver: strings.TrimSpace(cfg.Driver),
|
||||
Topology: strings.TrimSpace(cfg.Topology),
|
||||
Target: describeConnectionTarget(cfg),
|
||||
UseSSH: cfg.UseSSH,
|
||||
UseProxy: cfg.UseProxy,
|
||||
UseHTTPTunnel: cfg.UseHTTPTunnel,
|
||||
DefaultDatabase: strings.TrimSpace(cfg.Database),
|
||||
})
|
||||
}
|
||||
return successResult(), result, nil
|
||||
}
|
||||
|
||||
func (s *Service) GetDatabases(ctx context.Context, req *mcp.CallToolRequest, args connectionIDArgs) (*mcp.CallToolResult, getDatabasesResult, error) {
|
||||
_ = ctx
|
||||
_ = req
|
||||
|
||||
view, errResult := s.resolveConnection(args.ConnectionID)
|
||||
if errResult != nil {
|
||||
return errResult, getDatabasesResult{}, nil
|
||||
}
|
||||
|
||||
queryResult := s.backend.DBGetDatabases(view.Config)
|
||||
if !queryResult.Success {
|
||||
return toolError("获取数据库列表失败: %s", strings.TrimSpace(queryResult.Message)), getDatabasesResult{}, nil
|
||||
}
|
||||
|
||||
databases, err := decodeNamedStringSlice(queryResult.Data, "Database", "database", "name")
|
||||
if err != nil {
|
||||
return toolError("解析数据库列表失败: %v", err), getDatabasesResult{}, nil
|
||||
}
|
||||
|
||||
return successResult(), getDatabasesResult{
|
||||
ConnectionID: view.ID,
|
||||
Databases: ensureNonNilStrings(databases),
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (s *Service) GetTables(ctx context.Context, req *mcp.CallToolRequest, args databaseArgs) (*mcp.CallToolResult, getTablesResult, error) {
|
||||
_ = ctx
|
||||
_ = req
|
||||
|
||||
view, errResult := s.resolveConnection(args.ConnectionID)
|
||||
if errResult != nil {
|
||||
return errResult, getTablesResult{}, nil
|
||||
}
|
||||
|
||||
dbName := effectiveDBName(args.DBName, view.Config)
|
||||
queryResult := s.backend.DBGetTables(view.Config, dbName)
|
||||
if !queryResult.Success {
|
||||
return toolError("获取表列表失败: %s", strings.TrimSpace(queryResult.Message)), getTablesResult{}, nil
|
||||
}
|
||||
|
||||
tables, err := decodeNamedStringSlice(queryResult.Data, "Table", "table", "name")
|
||||
if err != nil {
|
||||
return toolError("解析表列表失败: %v", err), getTablesResult{}, nil
|
||||
}
|
||||
|
||||
return successResult(), getTablesResult{
|
||||
ConnectionID: view.ID,
|
||||
DBName: dbName,
|
||||
Tables: ensureNonNilStrings(tables),
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (s *Service) GetColumns(ctx context.Context, req *mcp.CallToolRequest, args tableArgs) (*mcp.CallToolResult, getColumnsResult, error) {
|
||||
_ = ctx
|
||||
_ = req
|
||||
|
||||
view, errResult := s.resolveConnection(args.ConnectionID)
|
||||
if errResult != nil {
|
||||
return errResult, getColumnsResult{}, nil
|
||||
}
|
||||
|
||||
tableName := strings.TrimSpace(args.TableName)
|
||||
if tableName == "" {
|
||||
return toolError("tableName 不能为空"), getColumnsResult{}, nil
|
||||
}
|
||||
|
||||
dbName := effectiveDBName(args.DBName, view.Config)
|
||||
queryResult := s.backend.DBGetColumns(view.Config, dbName, tableName)
|
||||
if !queryResult.Success {
|
||||
return toolError("获取字段列表失败: %s", strings.TrimSpace(queryResult.Message)), getColumnsResult{}, nil
|
||||
}
|
||||
|
||||
columns, err := decodeColumns(queryResult.Data)
|
||||
if err != nil {
|
||||
return toolError("解析字段列表失败: %v", err), getColumnsResult{}, nil
|
||||
}
|
||||
|
||||
return successResult(), getColumnsResult{
|
||||
ConnectionID: view.ID,
|
||||
DBName: dbName,
|
||||
TableName: tableName,
|
||||
Columns: ensureNonNilColumns(columns),
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (s *Service) GetTableDDL(ctx context.Context, req *mcp.CallToolRequest, args tableArgs) (*mcp.CallToolResult, getTableDDLResult, error) {
|
||||
_ = ctx
|
||||
_ = req
|
||||
|
||||
view, errResult := s.resolveConnection(args.ConnectionID)
|
||||
if errResult != nil {
|
||||
return errResult, getTableDDLResult{}, nil
|
||||
}
|
||||
|
||||
tableName := strings.TrimSpace(args.TableName)
|
||||
if tableName == "" {
|
||||
return toolError("tableName 不能为空"), getTableDDLResult{}, nil
|
||||
}
|
||||
|
||||
dbName := effectiveDBName(args.DBName, view.Config)
|
||||
queryResult := s.backend.DBShowCreateTable(view.Config, dbName, tableName)
|
||||
if !queryResult.Success {
|
||||
return toolError("获取建表语句失败: %s", strings.TrimSpace(queryResult.Message)), getTableDDLResult{}, nil
|
||||
}
|
||||
|
||||
ddl, err := decodeString(queryResult.Data)
|
||||
if err != nil {
|
||||
return toolError("解析建表语句失败: %v", err), getTableDDLResult{}, nil
|
||||
}
|
||||
|
||||
return successResult(), getTableDDLResult{
|
||||
ConnectionID: view.ID,
|
||||
DBName: dbName,
|
||||
TableName: tableName,
|
||||
DDL: ddl,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (s *Service) ExecuteSQL(ctx context.Context, req *mcp.CallToolRequest, args executeSQLArgs) (*mcp.CallToolResult, executeSQLResult, error) {
|
||||
_ = ctx
|
||||
_ = req
|
||||
|
||||
view, errResult := s.resolveConnection(args.ConnectionID)
|
||||
if errResult != nil {
|
||||
return errResult, executeSQLResult{}, nil
|
||||
}
|
||||
|
||||
sqlText := strings.TrimSpace(args.SQL)
|
||||
if sqlText == "" {
|
||||
return toolError("sql 不能为空"), executeSQLResult{}, nil
|
||||
}
|
||||
|
||||
inspection := s.backend.InspectSQL(view.Config.Type, sqlText)
|
||||
if inspection.StatementCount == 0 {
|
||||
return toolError("未识别到可执行的 SQL 语句"), executeSQLResult{}, nil
|
||||
}
|
||||
|
||||
safetyLevel := normalizeSQLSafetyLevel(s.backend.GetSQLSafetyLevel())
|
||||
safetyDecision := evaluateSQLSafety(safetyLevel, inspection)
|
||||
if len(safetyDecision.disallowed) > 0 {
|
||||
return toolError("%s", buildSafetyDeniedMessage(safetyLevel, safetyDecision.disallowed)), executeSQLResult{}, nil
|
||||
}
|
||||
if safetyDecision.requiresConfirm && !args.AllowMutating {
|
||||
return toolError("当前 SQL 已通过 GoNavi AI 安全控制(%s),但包含非只读语句 %s,请显式传入 allowMutating=true 后重试", safetyLevelDisplayName(safetyLevel), formatSafetyStatements(safetyDecision.confirmRequired)), executeSQLResult{}, nil
|
||||
}
|
||||
|
||||
dbName := effectiveDBName(args.DBName, view.Config)
|
||||
queryResult := s.backend.DBQueryMulti(view.Config, dbName, sqlText, "")
|
||||
if !queryResult.Success {
|
||||
return toolError("SQL 执行失败: %s", strings.TrimSpace(queryResult.Message)), executeSQLResult{}, nil
|
||||
}
|
||||
|
||||
resultSets, err := decodeResultSets(queryResult.Data)
|
||||
if err != nil {
|
||||
return toolError("解析 SQL 执行结果失败: %v", err), executeSQLResult{}, nil
|
||||
}
|
||||
|
||||
normalizedResults, truncated := normalizeResultSets(resultSets, normalizeMaxRowsPerResult(args.MaxRowsPerResult))
|
||||
return successResult(), executeSQLResult{
|
||||
ConnectionID: view.ID,
|
||||
DBName: dbName,
|
||||
StatementCount: inspection.StatementCount,
|
||||
ReadOnly: inspection.ReadOnly,
|
||||
QueryID: strings.TrimSpace(queryResult.QueryID),
|
||||
Message: strings.TrimSpace(queryResult.Message),
|
||||
Truncated: truncated,
|
||||
Statements: toStatementSummaries(inspection.Statements),
|
||||
Results: normalizedResults,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func successResult() *mcp.CallToolResult {
|
||||
return &mcp.CallToolResult{}
|
||||
}
|
||||
|
||||
func toolError(format string, args ...interface{}) *mcp.CallToolResult {
|
||||
return &mcp.CallToolResult{
|
||||
IsError: true,
|
||||
Content: []mcp.Content{
|
||||
&mcp.TextContent{Text: fmt.Sprintf(format, args...)},
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
func (s *Service) resolveConnection(connectionID string) (connection.SavedConnectionView, *mcp.CallToolResult) {
|
||||
id := strings.TrimSpace(connectionID)
|
||||
if id == "" {
|
||||
return connection.SavedConnectionView{}, toolError("connectionId 不能为空")
|
||||
}
|
||||
view, err := s.backend.GetEditableSavedConnection(id)
|
||||
if err != nil {
|
||||
return connection.SavedConnectionView{}, toolError("加载连接 %s 失败: %v", id, err)
|
||||
}
|
||||
return view, nil
|
||||
}
|
||||
|
||||
func effectiveDBName(input string, config connection.ConnectionConfig) string {
|
||||
if trimmed := strings.TrimSpace(input); trimmed != "" {
|
||||
return trimmed
|
||||
}
|
||||
return strings.TrimSpace(config.Database)
|
||||
}
|
||||
|
||||
func describeConnectionTarget(config connection.ConnectionConfig) string {
|
||||
dbType := strings.ToLower(strings.TrimSpace(config.Type))
|
||||
switch dbType {
|
||||
case "sqlite", "duckdb":
|
||||
if path := strings.TrimSpace(config.Database); path != "" {
|
||||
return path
|
||||
}
|
||||
}
|
||||
if len(config.Hosts) > 0 {
|
||||
return strings.Join(config.Hosts, ",")
|
||||
}
|
||||
if host := strings.TrimSpace(config.Host); host != "" {
|
||||
if config.Port > 0 {
|
||||
return fmt.Sprintf("%s:%d", host, config.Port)
|
||||
}
|
||||
return host
|
||||
}
|
||||
if uri := strings.TrimSpace(config.URI); uri != "" {
|
||||
return uri
|
||||
}
|
||||
if dsn := strings.TrimSpace(config.DSN); dsn != "" {
|
||||
return dsn
|
||||
}
|
||||
return strings.TrimSpace(config.Database)
|
||||
}
|
||||
|
||||
func decodeNamedStringSlice(data interface{}, keys ...string) ([]string, error) {
|
||||
switch items := data.(type) {
|
||||
case nil:
|
||||
return []string{}, nil
|
||||
case []string:
|
||||
return ensureNonNilStrings(append([]string(nil), items...)), nil
|
||||
case []map[string]string:
|
||||
result := make([]string, 0, len(items))
|
||||
for _, item := range items {
|
||||
result = append(result, pickNamedStringFromStringMap(item, keys...))
|
||||
}
|
||||
return result, nil
|
||||
case []map[string]interface{}:
|
||||
result := make([]string, 0, len(items))
|
||||
for _, item := range items {
|
||||
result = append(result, pickNamedStringFromAnyMap(item, keys...))
|
||||
}
|
||||
return result, nil
|
||||
default:
|
||||
var decoded []map[string]interface{}
|
||||
if err := remarshal(data, &decoded); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return decodeNamedStringSlice(decoded, keys...)
|
||||
}
|
||||
}
|
||||
|
||||
func pickNamedStringFromStringMap(item map[string]string, keys ...string) string {
|
||||
for _, key := range keys {
|
||||
if value := strings.TrimSpace(item[key]); value != "" {
|
||||
return value
|
||||
}
|
||||
}
|
||||
for _, value := range item {
|
||||
if trimmed := strings.TrimSpace(value); trimmed != "" {
|
||||
return trimmed
|
||||
}
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
func pickNamedStringFromAnyMap(item map[string]interface{}, keys ...string) string {
|
||||
for _, key := range keys {
|
||||
if value, ok := item[key]; ok {
|
||||
if text := strings.TrimSpace(fmt.Sprint(value)); text != "" {
|
||||
return text
|
||||
}
|
||||
}
|
||||
}
|
||||
for _, value := range item {
|
||||
if text := strings.TrimSpace(fmt.Sprint(value)); text != "" {
|
||||
return text
|
||||
}
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
func decodeColumns(data interface{}) ([]connection.ColumnDefinition, error) {
|
||||
switch cols := data.(type) {
|
||||
case nil:
|
||||
return []connection.ColumnDefinition{}, nil
|
||||
case []connection.ColumnDefinition:
|
||||
return ensureNonNilColumns(append([]connection.ColumnDefinition(nil), cols...)), nil
|
||||
default:
|
||||
var decoded []connection.ColumnDefinition
|
||||
if err := remarshal(data, &decoded); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return ensureNonNilColumns(decoded), nil
|
||||
}
|
||||
}
|
||||
|
||||
func decodeString(data interface{}) (string, error) {
|
||||
switch value := data.(type) {
|
||||
case nil:
|
||||
return "", nil
|
||||
case string:
|
||||
return value, nil
|
||||
default:
|
||||
return fmt.Sprint(value), nil
|
||||
}
|
||||
}
|
||||
|
||||
func decodeResultSets(data interface{}) ([]connection.ResultSetData, error) {
|
||||
switch items := data.(type) {
|
||||
case nil:
|
||||
return []connection.ResultSetData{}, nil
|
||||
case []connection.ResultSetData:
|
||||
return ensureNonNilResultSets(append([]connection.ResultSetData(nil), items...)), nil
|
||||
default:
|
||||
var decoded []connection.ResultSetData
|
||||
if err := remarshal(data, &decoded); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return ensureNonNilResultSets(decoded), nil
|
||||
}
|
||||
}
|
||||
|
||||
func remarshal(from interface{}, to interface{}) error {
|
||||
payload, err := json.Marshal(from)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return json.Unmarshal(payload, to)
|
||||
}
|
||||
|
||||
func normalizeMaxRowsPerResult(input int) int {
|
||||
if input <= 0 {
|
||||
return defaultMaxRowsPerResult
|
||||
}
|
||||
if input > maxRowsPerResultLimit {
|
||||
return maxRowsPerResultLimit
|
||||
}
|
||||
return input
|
||||
}
|
||||
|
||||
func normalizeResultSets(resultSets []connection.ResultSetData, maxRows int) ([]sqlResultSet, bool) {
|
||||
normalized := make([]sqlResultSet, 0, len(resultSets))
|
||||
truncatedAny := false
|
||||
for _, resultSet := range resultSets {
|
||||
rows := ensureNonNilRows(resultSet.Rows)
|
||||
rowCount := len(rows)
|
||||
truncated := false
|
||||
if maxRows > 0 && len(rows) > maxRows {
|
||||
rows = append([]map[string]interface{}(nil), rows[:maxRows]...)
|
||||
truncated = true
|
||||
truncatedAny = true
|
||||
}
|
||||
normalized = append(normalized, sqlResultSet{
|
||||
StatementIndex: resultSet.StatementIndex,
|
||||
Columns: ensureNonNilStrings(append([]string(nil), resultSet.Columns...)),
|
||||
Rows: rows,
|
||||
Messages: ensureNonNilStrings(append([]string(nil), resultSet.Messages...)),
|
||||
RowCount: rowCount,
|
||||
Truncated: truncated,
|
||||
})
|
||||
}
|
||||
return normalized, truncatedAny
|
||||
}
|
||||
|
||||
func toStatementSummaries(items []appcore.SQLStatementInspection) []sqlStatementSummary {
|
||||
result := make([]sqlStatementSummary, 0, len(items))
|
||||
for _, item := range items {
|
||||
result = append(result, sqlStatementSummary{
|
||||
Index: item.Index,
|
||||
Keyword: item.Keyword,
|
||||
ReadOnly: item.ReadOnly,
|
||||
})
|
||||
}
|
||||
return result
|
||||
}
|
||||
|
||||
func ensureNonNilStrings(items []string) []string {
|
||||
if items == nil {
|
||||
return []string{}
|
||||
}
|
||||
return items
|
||||
}
|
||||
|
||||
func ensureNonNilColumns(items []connection.ColumnDefinition) []connection.ColumnDefinition {
|
||||
if items == nil {
|
||||
return []connection.ColumnDefinition{}
|
||||
}
|
||||
return items
|
||||
}
|
||||
|
||||
func ensureNonNilRows(items []map[string]interface{}) []map[string]interface{} {
|
||||
if items == nil {
|
||||
return []map[string]interface{}{}
|
||||
}
|
||||
return items
|
||||
}
|
||||
|
||||
func ensureNonNilResultSets(items []connection.ResultSetData) []connection.ResultSetData {
|
||||
if items == nil {
|
||||
return []connection.ResultSetData{}
|
||||
}
|
||||
return items
|
||||
}
|
||||
|
||||
type sqlSafetyStatement struct {
|
||||
Index int
|
||||
Keyword string
|
||||
OperationType ai.SQLOperationType
|
||||
}
|
||||
|
||||
type sqlSafetyDecision struct {
|
||||
requiresConfirm bool
|
||||
disallowed []sqlSafetyStatement
|
||||
confirmRequired []sqlSafetyStatement
|
||||
}
|
||||
|
||||
func evaluateSQLSafety(level ai.SQLPermissionLevel, inspection appcore.SQLInspection) sqlSafetyDecision {
|
||||
decision := sqlSafetyDecision{
|
||||
disallowed: []sqlSafetyStatement{},
|
||||
confirmRequired: []sqlSafetyStatement{},
|
||||
}
|
||||
|
||||
for _, stmt := range inspection.Statements {
|
||||
statement := sqlSafetyStatement{
|
||||
Index: stmt.Index,
|
||||
Keyword: strings.TrimSpace(stmt.Keyword),
|
||||
OperationType: classifyStatementOperation(stmt),
|
||||
}
|
||||
if !isOperationAllowed(level, statement.OperationType) {
|
||||
decision.disallowed = append(decision.disallowed, statement)
|
||||
continue
|
||||
}
|
||||
if statement.OperationType != ai.SQLOpQuery {
|
||||
decision.requiresConfirm = true
|
||||
decision.confirmRequired = append(decision.confirmRequired, statement)
|
||||
}
|
||||
}
|
||||
|
||||
return decision
|
||||
}
|
||||
|
||||
func classifyStatementOperation(stmt appcore.SQLStatementInspection) ai.SQLOperationType {
|
||||
if stmt.ReadOnly {
|
||||
return ai.SQLOpQuery
|
||||
}
|
||||
|
||||
switch strings.ToLower(strings.TrimSpace(stmt.Keyword)) {
|
||||
case "insert", "update", "delete", "replace", "merge", "upsert":
|
||||
return ai.SQLOpDML
|
||||
case "create", "alter", "drop", "truncate", "rename":
|
||||
return ai.SQLOpDDL
|
||||
default:
|
||||
return ai.SQLOpOther
|
||||
}
|
||||
}
|
||||
|
||||
func isOperationAllowed(level ai.SQLPermissionLevel, opType ai.SQLOperationType) bool {
|
||||
switch normalizeSQLSafetyLevel(level) {
|
||||
case ai.PermissionReadOnly:
|
||||
return opType == ai.SQLOpQuery
|
||||
case ai.PermissionReadWrite:
|
||||
return opType == ai.SQLOpQuery || opType == ai.SQLOpDML
|
||||
case ai.PermissionFull:
|
||||
return opType == ai.SQLOpQuery || opType == ai.SQLOpDML || opType == ai.SQLOpDDL
|
||||
default:
|
||||
return opType == ai.SQLOpQuery
|
||||
}
|
||||
}
|
||||
|
||||
func normalizeSQLSafetyLevel(level ai.SQLPermissionLevel) ai.SQLPermissionLevel {
|
||||
switch level {
|
||||
case ai.PermissionReadOnly, ai.PermissionReadWrite, ai.PermissionFull:
|
||||
return level
|
||||
default:
|
||||
return ai.PermissionReadOnly
|
||||
}
|
||||
}
|
||||
|
||||
func buildSafetyDeniedMessage(level ai.SQLPermissionLevel, statements []sqlSafetyStatement) string {
|
||||
return fmt.Sprintf("当前 GoNavi AI 安全控制为%s,已阻止以下语句:%s。%s", safetyLevelDisplayName(level), formatSafetyStatements(statements), safetyLevelRuleText(level))
|
||||
}
|
||||
|
||||
func safetyLevelDisplayName(level ai.SQLPermissionLevel) string {
|
||||
switch normalizeSQLSafetyLevel(level) {
|
||||
case ai.PermissionReadOnly:
|
||||
return "只读模式"
|
||||
case ai.PermissionReadWrite:
|
||||
return "读写模式"
|
||||
case ai.PermissionFull:
|
||||
return "完全模式"
|
||||
default:
|
||||
return "只读模式"
|
||||
}
|
||||
}
|
||||
|
||||
func safetyLevelRuleText(level ai.SQLPermissionLevel) string {
|
||||
switch normalizeSQLSafetyLevel(level) {
|
||||
case ai.PermissionReadOnly:
|
||||
return "只读模式仅允许查询语句。"
|
||||
case ai.PermissionReadWrite:
|
||||
return "读写模式仅允许查询和 DML 语句。"
|
||||
case ai.PermissionFull:
|
||||
return "完全模式仅允许查询、DML 和 DDL;未识别操作仍会被阻止。"
|
||||
default:
|
||||
return "只读模式仅允许查询语句。"
|
||||
}
|
||||
}
|
||||
|
||||
func formatSafetyStatements(statements []sqlSafetyStatement) string {
|
||||
parts := make([]string, 0, len(statements))
|
||||
for _, stmt := range statements {
|
||||
keyword := strings.TrimSpace(stmt.Keyword)
|
||||
if keyword == "" {
|
||||
keyword = "unknown"
|
||||
}
|
||||
parts = append(parts, fmt.Sprintf("#%d %s(%s)", stmt.Index, strings.ToLower(keyword), strings.ToUpper(string(stmt.OperationType))))
|
||||
}
|
||||
return strings.Join(parts, ",")
|
||||
}
|
||||
431
internal/mcpserver/service_test.go
Normal file
431
internal/mcpserver/service_test.go
Normal file
@@ -0,0 +1,431 @@
|
||||
package mcpserver
|
||||
|
||||
import (
|
||||
"context"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"GoNavi-Wails/internal/ai"
|
||||
appcore "GoNavi-Wails/internal/app"
|
||||
"GoNavi-Wails/internal/connection"
|
||||
"github.com/modelcontextprotocol/go-sdk/mcp"
|
||||
)
|
||||
|
||||
type fakeBackend struct {
|
||||
savedConnections []connection.SavedConnectionView
|
||||
savedConnectionsErr error
|
||||
editableConnection connection.SavedConnectionView
|
||||
editableErr error
|
||||
databasesResult connection.QueryResult
|
||||
tablesResult connection.QueryResult
|
||||
columnsResult connection.QueryResult
|
||||
ddlResult connection.QueryResult
|
||||
queryResult connection.QueryResult
|
||||
inspection appcore.SQLInspection
|
||||
safetyLevel ai.SQLPermissionLevel
|
||||
queryCalled bool
|
||||
}
|
||||
|
||||
func (f *fakeBackend) Close(context.Context) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (f *fakeBackend) GetSavedConnections() ([]connection.SavedConnectionView, error) {
|
||||
return f.savedConnections, f.savedConnectionsErr
|
||||
}
|
||||
|
||||
func (f *fakeBackend) GetEditableSavedConnection(id string) (connection.SavedConnectionView, error) {
|
||||
return f.editableConnection, f.editableErr
|
||||
}
|
||||
|
||||
func (f *fakeBackend) DBGetDatabases(config connection.ConnectionConfig) connection.QueryResult {
|
||||
return f.databasesResult
|
||||
}
|
||||
|
||||
func (f *fakeBackend) DBGetTables(config connection.ConnectionConfig, dbName string) connection.QueryResult {
|
||||
return f.tablesResult
|
||||
}
|
||||
|
||||
func (f *fakeBackend) DBGetColumns(config connection.ConnectionConfig, dbName string, tableName string) connection.QueryResult {
|
||||
return f.columnsResult
|
||||
}
|
||||
|
||||
func (f *fakeBackend) DBShowCreateTable(config connection.ConnectionConfig, dbName string, tableName string) connection.QueryResult {
|
||||
return f.ddlResult
|
||||
}
|
||||
|
||||
func (f *fakeBackend) DBQueryMulti(config connection.ConnectionConfig, dbName string, query string, queryID string) connection.QueryResult {
|
||||
f.queryCalled = true
|
||||
return f.queryResult
|
||||
}
|
||||
|
||||
func (f *fakeBackend) InspectSQL(dbType string, sql string) appcore.SQLInspection {
|
||||
return f.inspection
|
||||
}
|
||||
|
||||
func (f *fakeBackend) GetSQLSafetyLevel() ai.SQLPermissionLevel {
|
||||
if f.safetyLevel == "" {
|
||||
return ai.PermissionReadOnly
|
||||
}
|
||||
return f.safetyLevel
|
||||
}
|
||||
|
||||
func TestGetConnectionsReturnsSavedConnectionSummaries(t *testing.T) {
|
||||
backend := &fakeBackend{
|
||||
savedConnections: []connection.SavedConnectionView{
|
||||
{
|
||||
ID: "mysql-main",
|
||||
Name: "MySQL Main",
|
||||
Config: connection.ConnectionConfig{
|
||||
Type: "mysql",
|
||||
Host: "10.0.0.8",
|
||||
Port: 3306,
|
||||
Database: "app",
|
||||
UseSSH: true,
|
||||
},
|
||||
},
|
||||
{
|
||||
ID: "duckdb-local",
|
||||
Name: "DuckDB Local",
|
||||
Config: connection.ConnectionConfig{
|
||||
Type: "duckdb",
|
||||
Database: `C:\data\example.duckdb`,
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
service := NewService(backend)
|
||||
result, out, err := service.GetConnections(context.Background(), nil, emptyArgs{})
|
||||
if err != nil {
|
||||
t.Fatalf("GetConnections returned error: %v", err)
|
||||
}
|
||||
if result == nil || result.IsError {
|
||||
t.Fatalf("expected success result, got %#v", result)
|
||||
}
|
||||
if len(out.Connections) != 2 {
|
||||
t.Fatalf("expected 2 connections, got %d", len(out.Connections))
|
||||
}
|
||||
if out.Connections[0].Target != "10.0.0.8:3306" {
|
||||
t.Fatalf("unexpected mysql target: %q", out.Connections[0].Target)
|
||||
}
|
||||
if out.Connections[1].Target != `C:\data\example.duckdb` {
|
||||
t.Fatalf("unexpected duckdb target: %q", out.Connections[1].Target)
|
||||
}
|
||||
}
|
||||
|
||||
func TestExecuteSQLRejectsMutatingStatementsWithoutAllowMutating(t *testing.T) {
|
||||
backend := &fakeBackend{
|
||||
editableConnection: connection.SavedConnectionView{
|
||||
ID: "mysql-main",
|
||||
Config: connection.ConnectionConfig{
|
||||
Type: "mysql",
|
||||
Database: "app",
|
||||
},
|
||||
},
|
||||
inspection: appcore.SQLInspection{
|
||||
StatementCount: 1,
|
||||
ReadOnly: false,
|
||||
Statements: []appcore.SQLStatementInspection{
|
||||
{Index: 1, Keyword: "delete", ReadOnly: false},
|
||||
},
|
||||
},
|
||||
safetyLevel: ai.PermissionReadWrite,
|
||||
}
|
||||
|
||||
service := NewService(backend)
|
||||
result, _, err := service.ExecuteSQL(context.Background(), nil, executeSQLArgs{
|
||||
ConnectionID: "mysql-main",
|
||||
SQL: "delete from users where id = 1",
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("ExecuteSQL returned error: %v", err)
|
||||
}
|
||||
if result == nil || !result.IsError {
|
||||
t.Fatalf("expected tool error, got %#v", result)
|
||||
}
|
||||
if !strings.Contains(firstTextContent(result), "allowMutating=true") {
|
||||
t.Fatalf("unexpected error text: %q", firstTextContent(result))
|
||||
}
|
||||
if backend.queryCalled {
|
||||
t.Fatalf("expected SQL not to execute when allowMutating is false")
|
||||
}
|
||||
}
|
||||
|
||||
func TestExecuteSQLRejectsMutatingStatementsWhenAISafetyIsReadOnly(t *testing.T) {
|
||||
backend := &fakeBackend{
|
||||
editableConnection: connection.SavedConnectionView{
|
||||
ID: "mysql-main",
|
||||
Config: connection.ConnectionConfig{
|
||||
Type: "mysql",
|
||||
Database: "app",
|
||||
},
|
||||
},
|
||||
inspection: appcore.SQLInspection{
|
||||
StatementCount: 1,
|
||||
ReadOnly: false,
|
||||
Statements: []appcore.SQLStatementInspection{
|
||||
{Index: 1, Keyword: "delete", ReadOnly: false},
|
||||
},
|
||||
},
|
||||
safetyLevel: ai.PermissionReadOnly,
|
||||
}
|
||||
|
||||
service := NewService(backend)
|
||||
result, _, err := service.ExecuteSQL(context.Background(), nil, executeSQLArgs{
|
||||
ConnectionID: "mysql-main",
|
||||
SQL: "delete from users where id = 1",
|
||||
AllowMutating: true,
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("ExecuteSQL returned error: %v", err)
|
||||
}
|
||||
if result == nil || !result.IsError {
|
||||
t.Fatalf("expected tool error, got %#v", result)
|
||||
}
|
||||
if !strings.Contains(firstTextContent(result), "只读模式") {
|
||||
t.Fatalf("unexpected error text: %q", firstTextContent(result))
|
||||
}
|
||||
if backend.queryCalled {
|
||||
t.Fatalf("expected SQL not to execute when AI safety is readonly")
|
||||
}
|
||||
}
|
||||
|
||||
func TestExecuteSQLRejectsDDLWhenAISafetyIsReadWrite(t *testing.T) {
|
||||
backend := &fakeBackend{
|
||||
editableConnection: connection.SavedConnectionView{
|
||||
ID: "mysql-main",
|
||||
Config: connection.ConnectionConfig{
|
||||
Type: "mysql",
|
||||
Database: "app",
|
||||
},
|
||||
},
|
||||
inspection: appcore.SQLInspection{
|
||||
StatementCount: 1,
|
||||
ReadOnly: false,
|
||||
Statements: []appcore.SQLStatementInspection{
|
||||
{Index: 1, Keyword: "drop", ReadOnly: false},
|
||||
},
|
||||
},
|
||||
safetyLevel: ai.PermissionReadWrite,
|
||||
}
|
||||
|
||||
service := NewService(backend)
|
||||
result, _, err := service.ExecuteSQL(context.Background(), nil, executeSQLArgs{
|
||||
ConnectionID: "mysql-main",
|
||||
SQL: "drop table users",
|
||||
AllowMutating: true,
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("ExecuteSQL returned error: %v", err)
|
||||
}
|
||||
if result == nil || !result.IsError {
|
||||
t.Fatalf("expected tool error, got %#v", result)
|
||||
}
|
||||
text := firstTextContent(result)
|
||||
if !strings.Contains(text, "读写模式") || !strings.Contains(text, "DDL") {
|
||||
t.Fatalf("unexpected error text: %q", text)
|
||||
}
|
||||
if backend.queryCalled {
|
||||
t.Fatalf("expected SQL not to execute when AI safety blocks DDL")
|
||||
}
|
||||
}
|
||||
|
||||
func TestExecuteSQLRejectsMixedStatementsWhenAISafetyBlocksLaterStatement(t *testing.T) {
|
||||
backend := &fakeBackend{
|
||||
editableConnection: connection.SavedConnectionView{
|
||||
ID: "mysql-main",
|
||||
Config: connection.ConnectionConfig{
|
||||
Type: "mysql",
|
||||
Database: "app",
|
||||
},
|
||||
},
|
||||
inspection: appcore.SQLInspection{
|
||||
StatementCount: 2,
|
||||
ReadOnly: false,
|
||||
Statements: []appcore.SQLStatementInspection{
|
||||
{Index: 1, Keyword: "select", ReadOnly: true},
|
||||
{Index: 2, Keyword: "delete", ReadOnly: false},
|
||||
},
|
||||
},
|
||||
safetyLevel: ai.PermissionReadOnly,
|
||||
}
|
||||
|
||||
service := NewService(backend)
|
||||
result, _, err := service.ExecuteSQL(context.Background(), nil, executeSQLArgs{
|
||||
ConnectionID: "mysql-main",
|
||||
SQL: "select * from users; delete from users where id = 1",
|
||||
AllowMutating: true,
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("ExecuteSQL returned error: %v", err)
|
||||
}
|
||||
if result == nil || !result.IsError {
|
||||
t.Fatalf("expected tool error, got %#v", result)
|
||||
}
|
||||
if !strings.Contains(firstTextContent(result), "#2 delete") {
|
||||
t.Fatalf("unexpected error text: %q", firstTextContent(result))
|
||||
}
|
||||
if backend.queryCalled {
|
||||
t.Fatalf("expected SQL not to execute when a later statement is blocked")
|
||||
}
|
||||
}
|
||||
|
||||
func TestExecuteSQLAllowsDMLWhenAISafetyIsReadWriteAndAllowMutating(t *testing.T) {
|
||||
backend := &fakeBackend{
|
||||
editableConnection: connection.SavedConnectionView{
|
||||
ID: "mysql-main",
|
||||
Config: connection.ConnectionConfig{
|
||||
Type: "mysql",
|
||||
Database: "app",
|
||||
},
|
||||
},
|
||||
inspection: appcore.SQLInspection{
|
||||
StatementCount: 1,
|
||||
ReadOnly: false,
|
||||
Statements: []appcore.SQLStatementInspection{
|
||||
{Index: 1, Keyword: "insert", ReadOnly: false},
|
||||
},
|
||||
},
|
||||
safetyLevel: ai.PermissionReadWrite,
|
||||
queryResult: connection.QueryResult{
|
||||
Success: true,
|
||||
Data: []connection.ResultSetData{},
|
||||
},
|
||||
}
|
||||
|
||||
service := NewService(backend)
|
||||
result, out, err := service.ExecuteSQL(context.Background(), nil, executeSQLArgs{
|
||||
ConnectionID: "mysql-main",
|
||||
SQL: "insert into users(id) values (1)",
|
||||
AllowMutating: true,
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("ExecuteSQL returned error: %v", err)
|
||||
}
|
||||
if result == nil || result.IsError {
|
||||
t.Fatalf("expected success result, got %#v", result)
|
||||
}
|
||||
if !backend.queryCalled {
|
||||
t.Fatalf("expected SQL to be executed")
|
||||
}
|
||||
if out.ReadOnly {
|
||||
t.Fatalf("expected mutating SQL result, got %#v", out)
|
||||
}
|
||||
}
|
||||
|
||||
func TestExecuteSQLAllowsDDLWhenAISafetyIsFullAndAllowMutating(t *testing.T) {
|
||||
backend := &fakeBackend{
|
||||
editableConnection: connection.SavedConnectionView{
|
||||
ID: "mysql-main",
|
||||
Config: connection.ConnectionConfig{
|
||||
Type: "mysql",
|
||||
Database: "app",
|
||||
},
|
||||
},
|
||||
inspection: appcore.SQLInspection{
|
||||
StatementCount: 1,
|
||||
ReadOnly: false,
|
||||
Statements: []appcore.SQLStatementInspection{
|
||||
{Index: 1, Keyword: "drop", ReadOnly: false},
|
||||
},
|
||||
},
|
||||
safetyLevel: ai.PermissionFull,
|
||||
queryResult: connection.QueryResult{
|
||||
Success: true,
|
||||
Data: []connection.ResultSetData{},
|
||||
},
|
||||
}
|
||||
|
||||
service := NewService(backend)
|
||||
result, _, err := service.ExecuteSQL(context.Background(), nil, executeSQLArgs{
|
||||
ConnectionID: "mysql-main",
|
||||
SQL: "drop table users",
|
||||
AllowMutating: true,
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("ExecuteSQL returned error: %v", err)
|
||||
}
|
||||
if result == nil || result.IsError {
|
||||
t.Fatalf("expected success result, got %#v", result)
|
||||
}
|
||||
if !backend.queryCalled {
|
||||
t.Fatalf("expected SQL to be executed")
|
||||
}
|
||||
}
|
||||
|
||||
func TestExecuteSQLNormalizesAndTruncatesResultSets(t *testing.T) {
|
||||
backend := &fakeBackend{
|
||||
editableConnection: connection.SavedConnectionView{
|
||||
ID: "mysql-main",
|
||||
Config: connection.ConnectionConfig{
|
||||
Type: "mysql",
|
||||
Database: "app",
|
||||
},
|
||||
},
|
||||
inspection: appcore.SQLInspection{
|
||||
StatementCount: 1,
|
||||
ReadOnly: true,
|
||||
Statements: []appcore.SQLStatementInspection{
|
||||
{Index: 1, Keyword: "select", ReadOnly: true},
|
||||
},
|
||||
},
|
||||
queryResult: connection.QueryResult{
|
||||
Success: true,
|
||||
QueryID: "query-1",
|
||||
Data: []connection.ResultSetData{
|
||||
{
|
||||
StatementIndex: 1,
|
||||
Columns: []string{"id"},
|
||||
Rows: []map[string]interface{}{
|
||||
{"id": 1},
|
||||
{"id": 2},
|
||||
{"id": 3},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
service := NewService(backend)
|
||||
result, out, err := service.ExecuteSQL(context.Background(), nil, executeSQLArgs{
|
||||
ConnectionID: "mysql-main",
|
||||
SQL: "select id from users",
|
||||
MaxRowsPerResult: 2,
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("ExecuteSQL returned error: %v", err)
|
||||
}
|
||||
if result == nil || result.IsError {
|
||||
t.Fatalf("expected success result, got %#v", result)
|
||||
}
|
||||
if !backend.queryCalled {
|
||||
t.Fatalf("expected SQL to be executed")
|
||||
}
|
||||
if out.StatementCount != 1 || len(out.Results) != 1 {
|
||||
t.Fatalf("unexpected output: %#v", out)
|
||||
}
|
||||
if out.QueryID != "query-1" {
|
||||
t.Fatalf("unexpected query id: %q", out.QueryID)
|
||||
}
|
||||
if !out.Truncated || !out.Results[0].Truncated {
|
||||
t.Fatalf("expected truncated result, got %#v", out.Results[0])
|
||||
}
|
||||
if out.Results[0].RowCount != 3 {
|
||||
t.Fatalf("expected rowCount 3, got %d", out.Results[0].RowCount)
|
||||
}
|
||||
if len(out.Results[0].Rows) != 2 {
|
||||
t.Fatalf("expected 2 returned rows, got %d", len(out.Results[0].Rows))
|
||||
}
|
||||
}
|
||||
|
||||
func firstTextContent(result *mcp.CallToolResult) string {
|
||||
if result == nil || len(result.Content) == 0 {
|
||||
return ""
|
||||
}
|
||||
text, _ := result.Content[0].(*mcp.TextContent)
|
||||
if text == nil {
|
||||
return ""
|
||||
}
|
||||
return text.Text
|
||||
}
|
||||
29
main.go
29
main.go
@@ -8,6 +8,7 @@ import (
|
||||
aiservice "GoNavi-Wails/internal/ai/service"
|
||||
"GoNavi-Wails/internal/app"
|
||||
"GoNavi-Wails/internal/logger"
|
||||
"GoNavi-Wails/internal/mcpserver"
|
||||
|
||||
"github.com/wailsapp/wails/v2"
|
||||
"github.com/wailsapp/wails/v2/pkg/options"
|
||||
@@ -17,6 +18,10 @@ import (
|
||||
)
|
||||
|
||||
func main() {
|
||||
if runSpecialMode(os.Args[1:]) {
|
||||
return
|
||||
}
|
||||
|
||||
// Create an instance of the app structure
|
||||
application := app.NewApp()
|
||||
aiService := aiservice.NewService()
|
||||
@@ -68,6 +73,30 @@ func main() {
|
||||
}
|
||||
}
|
||||
|
||||
func runSpecialMode(args []string) bool {
|
||||
if !shouldRunMCPServerMode(args) {
|
||||
return false
|
||||
}
|
||||
|
||||
if err := mcpserver.RunAppStdioServer(context.Background()); err != nil {
|
||||
logger.Error(err, "GoNavi MCP Server 退出")
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
func shouldRunMCPServerMode(args []string) bool {
|
||||
if len(args) == 0 {
|
||||
return false
|
||||
}
|
||||
|
||||
switch strings.ToLower(strings.TrimSpace(args[0])) {
|
||||
case "mcp-server", "--mcp-server":
|
||||
return true
|
||||
default:
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
func isLowMemoryMode() bool {
|
||||
switch strings.ToLower(strings.TrimSpace(os.Getenv("GONAVI_LOW_MEMORY_MODE"))) {
|
||||
case "1", "true", "yes", "on":
|
||||
|
||||
21
main_test.go
21
main_test.go
@@ -24,3 +24,24 @@ func TestIsLowMemoryMode(t *testing.T) {
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestShouldRunMCPServerMode(t *testing.T) {
|
||||
cases := []struct {
|
||||
name string
|
||||
args []string
|
||||
want bool
|
||||
}{
|
||||
{name: "empty", args: nil, want: false},
|
||||
{name: "mcp-server", args: []string{"mcp-server"}, want: true},
|
||||
{name: "flag style", args: []string{"--mcp-server"}, want: true},
|
||||
{name: "unknown", args: []string{"serve"}, want: false},
|
||||
}
|
||||
|
||||
for _, tc := range cases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
if got := shouldRunMCPServerMode(tc.args); got != tc.want {
|
||||
t.Fatalf("shouldRunMCPServerMode(%v) = %v, want %v", tc.args, got, tc.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
2
tools/claude-gonavi-mcp.cmd
Normal file
2
tools/claude-gonavi-mcp.cmd
Normal file
@@ -0,0 +1,2 @@
|
||||
@echo off
|
||||
powershell -NoProfile -ExecutionPolicy Bypass -File "%~dp0claude-gonavi-mcp.ps1" %*
|
||||
44
tools/claude-gonavi-mcp.ps1
Normal file
44
tools/claude-gonavi-mcp.ps1
Normal file
@@ -0,0 +1,44 @@
|
||||
param(
|
||||
[switch]$SkipBuild
|
||||
)
|
||||
|
||||
$ErrorActionPreference = 'Stop'
|
||||
$ClaudeArgs = $args
|
||||
|
||||
$repoRoot = (Resolve-Path (Join-Path $PSScriptRoot '..')).Path
|
||||
$binDir = Join-Path $repoRoot 'bin'
|
||||
$serverExe = Join-Path $binDir 'gonavi-mcp-server.exe'
|
||||
|
||||
if (-not $SkipBuild) {
|
||||
if (-not (Test-Path $binDir)) {
|
||||
New-Item -ItemType Directory -Path $binDir | Out-Null
|
||||
}
|
||||
|
||||
& go build -o $serverExe .\cmd\gonavi-mcp-server
|
||||
if ($LASTEXITCODE -ne 0) {
|
||||
throw "构建 gonavi-mcp-server 失败"
|
||||
}
|
||||
} elseif (-not (Test-Path $serverExe)) {
|
||||
throw "未找到已编译的 gonavi-mcp-server.exe,请去掉 -SkipBuild 或先手动构建"
|
||||
}
|
||||
|
||||
$mcpConfig = @{
|
||||
mcpServers = @{
|
||||
gonavi = @{
|
||||
type = 'stdio'
|
||||
command = $serverExe
|
||||
args = @()
|
||||
env = @{}
|
||||
}
|
||||
}
|
||||
} | ConvertTo-Json -Compress -Depth 6
|
||||
|
||||
$tempConfig = Join-Path ([System.IO.Path]::GetTempPath()) ("gonavi-claude-mcp-" + [System.Guid]::NewGuid().ToString("N") + ".json")
|
||||
|
||||
try {
|
||||
Set-Content -LiteralPath $tempConfig -Value $mcpConfig -Encoding UTF8
|
||||
& claude @ClaudeArgs --mcp-config $tempConfig --strict-mcp-config
|
||||
exit $LASTEXITCODE
|
||||
} finally {
|
||||
Remove-Item -LiteralPath $tempConfig -ErrorAction SilentlyContinue
|
||||
}
|
||||
Reference in New Issue
Block a user