mirror of
https://github.com/Syngnat/GoNavi.git
synced 2026-05-12 04:19:40 +08:00
feat(app): 合并配置密文存储、数据表增强与驱动相关修复 (#339)
## 背景
本次合并汇总了 5 类功能,并对冲突处理后的代码进行了回归审查,目标是将配置密文存储能力、数据表体验增强及驱动相关修复一并并入。
## 本次变更
1. 修复 Data Viewer 多列排序状态残留,避免排序条件切换后失效。
2. 收紧 MongoDB 可选驱动支持区间,仅支持 1.17.x 与 2.x,并补齐对应版本识别与导入校验。
3. 完成配置密文存储前后端闭环,包括:
- 新增密钥存储基础设施与状态枚举
- 拆分 AI Provider 元数据与密钥存储
- 暴露连接配置、代理配置相关密钥存储 API
- 前端状态迁移为不保存明文密钥
- 通过连接配置 ID 路由 RPC 配置
- 修复密文编辑与状态残留问题
4. 增强 DataGrid 显示能力,补充展示策略并支持行级 SQL 复制。
5. 修复本地驱动导入版本识别与数据库连接校验遗漏,补齐 ClickHouse 等相关校验路径。
## 附带修复
- 修复 Claude CLI 在 Windows 下的测试稳定性问题。
## 验证情况
- `go test ./...`
- `go build ./...`
- `npm test -- src/store.test.ts src/utils/dataGridDisplay.test.ts
src/components/dataGridCopyInsert.test.ts
src/utils/connectionRpcConfig.test.ts
src/utils/connectionSecretDraft.test.ts src/utils/
providerSecretDraft.test.ts src/utils/customConnectionDsn.test.ts
src/utils/aiProviderEditorState.test.ts
src/utils/browserMockConnections.test.ts
src/utils/dataViewerAutoFetch.test.ts`
- `npm run build`
## 说明
- 其他功能主要依据提交差异、代码检查与自动化测试完成回归确认。
- 当前未发现因冲突处理导致的明确编译问题、功能失效或目标偏离
This commit is contained in:
2
.gitignore
vendored
2
.gitignore
vendored
@@ -26,3 +26,5 @@ docs/需求追踪/
|
||||
|
||||
CLAUDE.md
|
||||
**/CLAUDE.md
|
||||
.worktrees
|
||||
docs
|
||||
@@ -1,5 +1,5 @@
|
||||
import React, { useState, useEffect, useMemo, useCallback } from 'react';
|
||||
import { Layout, Button, ConfigProvider, theme, message, Modal, Spin, Slider, Progress, Switch, Input, InputNumber, Select, Tooltip } from 'antd';
|
||||
import React, { useState, useEffect, useMemo, useCallback } from 'react';
|
||||
import { Layout, Button, ConfigProvider, theme, message, Modal, Spin, Slider, Progress, Switch, Input, InputNumber, Select, Segmented, Tooltip } from 'antd';
|
||||
import zhCN from 'antd/locale/zh_CN';
|
||||
import { PlusOutlined, ConsoleSqlOutlined, UploadOutlined, DownloadOutlined, CloudDownloadOutlined, BugOutlined, ToolOutlined, GlobalOutlined, InfoCircleOutlined, GithubOutlined, SkinOutlined, CheckOutlined, MinusOutlined, BorderOutlined, CloseOutlined, SettingOutlined, LinkOutlined, BgColorsOutlined, AppstoreOutlined, RobotOutlined } from '@ant-design/icons';
|
||||
import { BrowserOpenURL, Environment, EventsOn, Quit, WindowFullscreen, WindowGetPosition, WindowGetSize, WindowIsFullscreen, WindowIsMaximised, WindowMaximise, WindowMinimise, WindowSetPosition, WindowSetSize, WindowToggleMaximise, WindowUnfullscreen } from '../wailsjs/runtime';
|
||||
@@ -11,12 +11,15 @@ import DriverManagerModal from './components/DriverManagerModal';
|
||||
import LogPanel from './components/LogPanel';
|
||||
import AIChatPanel from './components/AIChatPanel';
|
||||
import AISettingsModal from './components/AISettingsModal';
|
||||
import { useStore } from './store';
|
||||
import { DEFAULT_APPEARANCE, useStore } from './store';
|
||||
import { SavedConnection } from './types';
|
||||
import { blurToFilter, normalizeBlurForPlatform, normalizeOpacityForPlatform, isWindowsPlatform, resolveAppearanceValues } from './utils/appearance';
|
||||
import { DATA_GRID_COLUMN_WIDTH_MODE_OPTIONS, sanitizeDataTableColumnWidthMode } from './utils/dataGridDisplay';
|
||||
import { getMacNativeTitlebarPaddingLeft, getMacNativeTitlebarPaddingRight, shouldHandleMacNativeFullscreenShortcut, shouldSuppressMacNativeEscapeExit } from './utils/macWindow';
|
||||
import { buildOverlayWorkbenchTheme } from './utils/overlayWorkbenchTheme';
|
||||
import { getConnectionWorkbenchState } from './utils/startupReadiness';
|
||||
import { createGlobalProxyDraft, toSaveGlobalProxyInput } from './utils/globalProxyDraft';
|
||||
import { LEGACY_PERSIST_KEY, readLegacyPersistedSecrets, stripLegacyPersistedSecrets } from './utils/legacyConnectionStorage';
|
||||
import {
|
||||
SHORTCUT_ACTION_META,
|
||||
SHORTCUT_ACTION_ORDER,
|
||||
@@ -35,7 +38,7 @@ import {
|
||||
resolveAIEdgeHandleDockStyle,
|
||||
resolveAIEdgeHandleStyle,
|
||||
} from './utils/aiEntryLayout';
|
||||
import { ConfigureGlobalProxy, SetMacNativeWindowControls, SetWindowTranslucency } from '../wailsjs/go/app/App';
|
||||
import { SetMacNativeWindowControls, SetWindowTranslucency } from '../wailsjs/go/app/App';
|
||||
import './App.css';
|
||||
|
||||
const { Sider, Content } = Layout;
|
||||
@@ -59,6 +62,24 @@ const detectNavigatorPlatform = (): string => {
|
||||
return navigator.userAgent || '';
|
||||
};
|
||||
|
||||
|
||||
const toLegacySavedConnectionInput = (item: any) => ({
|
||||
id: typeof item?.id === 'string' ? item.id : '',
|
||||
name: typeof item?.name === 'string' ? item.name : '',
|
||||
config: (item?.config && typeof item.config === 'object') ? item.config : {},
|
||||
includeDatabases: Array.isArray(item?.includeDatabases) ? item.includeDatabases : undefined,
|
||||
includeRedisDatabases: Array.isArray(item?.includeRedisDatabases) ? item.includeRedisDatabases : undefined,
|
||||
iconType: typeof item?.iconType === 'string' ? item.iconType : '',
|
||||
iconColor: typeof item?.iconColor === 'string' ? item.iconColor : '',
|
||||
});
|
||||
|
||||
const mergeSavedConnections = (current: SavedConnection[], imported: SavedConnection[]): SavedConnection[] => {
|
||||
const merged = new Map<string, SavedConnection>();
|
||||
current.forEach((conn) => merged.set(conn.id, conn));
|
||||
imported.forEach((conn) => merged.set(conn.id, conn));
|
||||
return Array.from(merged.values());
|
||||
};
|
||||
|
||||
function App() {
|
||||
const [isModalOpen, setIsModalOpen] = useState(false);
|
||||
const [isSyncModalOpen, setIsSyncModalOpen] = useState(false);
|
||||
@@ -76,6 +97,8 @@ function App() {
|
||||
const setStartupFullscreen = useStore(state => state.setStartupFullscreen);
|
||||
const globalProxy = useStore(state => state.globalProxy);
|
||||
const setGlobalProxy = useStore(state => state.setGlobalProxy);
|
||||
const replaceConnections = useStore(state => state.replaceConnections);
|
||||
const replaceGlobalProxy = useStore(state => state.replaceGlobalProxy);
|
||||
const shortcutOptions = useStore(state => state.shortcutOptions);
|
||||
const updateShortcut = useStore(state => state.updateShortcut);
|
||||
const resetShortcutOptions = useStore(state => state.resetShortcutOptions);
|
||||
@@ -100,14 +123,14 @@ function App() {
|
||||
const [runtimePlatform, setRuntimePlatform] = useState('');
|
||||
const [isLinuxRuntime, setIsLinuxRuntime] = useState(false);
|
||||
const [isStoreHydrated, setIsStoreHydrated] = useState(() => useStore.persist.hasHydrated());
|
||||
const [hasAppliedInitialGlobalProxy, setHasAppliedInitialGlobalProxy] = useState(false);
|
||||
const [hasLoadedSecureConfig, setHasLoadedSecureConfig] = useState(false);
|
||||
const sidebarWidth = useStore(state => state.sidebarWidth);
|
||||
const setSidebarWidth = useStore(state => state.setSidebarWidth);
|
||||
const aiPanelVisible = useStore(state => state.aiPanelVisible);
|
||||
const toggleAIPanel = useStore(state => state.toggleAIPanel);
|
||||
const setAIPanelVisible = useStore(state => state.setAIPanelVisible);
|
||||
const globalProxyInvalidHintShownRef = React.useRef(false);
|
||||
const connectionWorkbenchState = getConnectionWorkbenchState(isStoreHydrated, hasAppliedInitialGlobalProxy);
|
||||
const connectionWorkbenchState = getConnectionWorkbenchState(isStoreHydrated, hasLoadedSecureConfig);
|
||||
|
||||
// 同步 macOS 窗口透明度:opacity=1.0 且 blur=0 时关闭 NSVisualEffectView,
|
||||
// 避免 GPU 持续计算窗口背后的模糊合成
|
||||
@@ -167,6 +190,90 @@ function App() {
|
||||
return;
|
||||
}
|
||||
|
||||
let cancelled = false;
|
||||
const loadSecureConfig = async () => {
|
||||
const backendApp = (window as any).go?.app?.App;
|
||||
const persistedPayload = typeof window !== 'undefined'
|
||||
? window.localStorage.getItem(LEGACY_PERSIST_KEY)
|
||||
: null;
|
||||
const legacy = readLegacyPersistedSecrets(persistedPayload);
|
||||
|
||||
let importedLegacyConnections = false;
|
||||
let importedLegacyGlobalProxy = false;
|
||||
|
||||
if (legacy.connections.length > 0) {
|
||||
if (typeof backendApp?.ImportLegacyConnections === 'function') {
|
||||
try {
|
||||
await backendApp.ImportLegacyConnections(
|
||||
legacy.connections.map(toLegacySavedConnectionInput)
|
||||
);
|
||||
importedLegacyConnections = true;
|
||||
} catch (err) {
|
||||
console.warn('Failed to import legacy saved connections', err);
|
||||
}
|
||||
} else {
|
||||
replaceConnections(legacy.connections);
|
||||
}
|
||||
}
|
||||
|
||||
if (legacy.globalProxy) {
|
||||
if (typeof backendApp?.ImportLegacyGlobalProxy === 'function') {
|
||||
try {
|
||||
await backendApp.ImportLegacyGlobalProxy(toSaveGlobalProxyInput(legacy.globalProxy));
|
||||
importedLegacyGlobalProxy = true;
|
||||
} catch (err) {
|
||||
console.warn('Failed to import legacy global proxy', err);
|
||||
}
|
||||
} else {
|
||||
replaceGlobalProxy(createGlobalProxyDraft(legacy.globalProxy));
|
||||
}
|
||||
}
|
||||
|
||||
if ((importedLegacyConnections || importedLegacyGlobalProxy) && persistedPayload && typeof window !== 'undefined') {
|
||||
const sanitizedPayload = stripLegacyPersistedSecrets(persistedPayload);
|
||||
if (sanitizedPayload && sanitizedPayload !== persistedPayload) {
|
||||
window.localStorage.setItem(LEGACY_PERSIST_KEY, sanitizedPayload);
|
||||
}
|
||||
}
|
||||
|
||||
if (typeof backendApp?.GetSavedConnections === 'function') {
|
||||
try {
|
||||
const savedConnections = await backendApp.GetSavedConnections();
|
||||
if (!cancelled && Array.isArray(savedConnections)) {
|
||||
replaceConnections(savedConnections);
|
||||
}
|
||||
} catch (err) {
|
||||
console.warn('Failed to load saved connections from backend', err);
|
||||
}
|
||||
}
|
||||
|
||||
if (typeof backendApp?.GetGlobalProxyConfig === 'function') {
|
||||
try {
|
||||
const proxyResult = await backendApp.GetGlobalProxyConfig();
|
||||
if (!cancelled && proxyResult?.success && proxyResult.data) {
|
||||
replaceGlobalProxy(createGlobalProxyDraft(proxyResult.data));
|
||||
}
|
||||
} catch (err) {
|
||||
console.warn('Failed to load global proxy from backend', err);
|
||||
}
|
||||
}
|
||||
|
||||
if (!cancelled) {
|
||||
setHasLoadedSecureConfig(true);
|
||||
}
|
||||
};
|
||||
|
||||
void loadSecureConfig();
|
||||
return () => {
|
||||
cancelled = true;
|
||||
};
|
||||
}, [isStoreHydrated, replaceConnections, replaceGlobalProxy]);
|
||||
|
||||
useEffect(() => {
|
||||
if (!isStoreHydrated || !hasLoadedSecureConfig) {
|
||||
return;
|
||||
}
|
||||
|
||||
const host = String(globalProxy.host || '').trim();
|
||||
const port = Number(globalProxy.port);
|
||||
const portValid = Number.isFinite(port) && port > 0 && port <= 65535;
|
||||
@@ -180,57 +287,44 @@ function App() {
|
||||
});
|
||||
globalProxyInvalidHintShownRef.current = true;
|
||||
}
|
||||
} else {
|
||||
globalProxyInvalidHintShownRef.current = false;
|
||||
void message.destroy('global-proxy-invalid');
|
||||
return;
|
||||
}
|
||||
|
||||
const enabledForBackend = globalProxy.enabled && !invalidWhenEnabled;
|
||||
let cancelled = false;
|
||||
try {
|
||||
ConfigureGlobalProxy(enabledForBackend, {
|
||||
type: globalProxy.type,
|
||||
host,
|
||||
port: portValid ? port : (globalProxy.type === 'http' ? 8080 : 1080),
|
||||
user: String(globalProxy.user || '').trim(),
|
||||
password: globalProxy.password || '',
|
||||
})
|
||||
.then((res) => {
|
||||
if (cancelled || res?.success) {
|
||||
return;
|
||||
}
|
||||
void message.error({
|
||||
content: '全局代理配置失败: ' + (res?.message || '未知错误'),
|
||||
key: 'global-proxy-sync-error',
|
||||
});
|
||||
})
|
||||
.catch((err) => {
|
||||
if (cancelled) {
|
||||
return;
|
||||
}
|
||||
const errMsg = err instanceof Error ? err.message : String(err || '未知错误');
|
||||
void message.error({
|
||||
content: '全局代理配置失败: ' + errMsg,
|
||||
key: 'global-proxy-sync-error',
|
||||
});
|
||||
})
|
||||
.finally(() => {
|
||||
if (!cancelled) {
|
||||
setHasAppliedInitialGlobalProxy(true);
|
||||
}
|
||||
});
|
||||
} catch (e) {
|
||||
if (!cancelled) {
|
||||
setHasAppliedInitialGlobalProxy(true);
|
||||
}
|
||||
console.warn("Wails API: ConfigureGlobalProxy unavailable", e);
|
||||
globalProxyInvalidHintShownRef.current = false;
|
||||
void message.destroy('global-proxy-invalid');
|
||||
|
||||
const backendApp = (window as any).go?.app?.App;
|
||||
if (typeof backendApp?.SaveGlobalProxy !== 'function') {
|
||||
return;
|
||||
}
|
||||
|
||||
let cancelled = false;
|
||||
Promise.resolve(
|
||||
backendApp.SaveGlobalProxy(
|
||||
toSaveGlobalProxyInput({
|
||||
...globalProxy,
|
||||
host,
|
||||
port: portValid ? port : (globalProxy.type === 'http' ? 8080 : 1080),
|
||||
})
|
||||
)
|
||||
)
|
||||
.catch((err) => {
|
||||
if (cancelled) {
|
||||
return;
|
||||
}
|
||||
const errMsg = err instanceof Error ? err.message : String(err || '未知错误');
|
||||
void message.error({
|
||||
content: '全局代理配置失败: ' + errMsg,
|
||||
key: 'global-proxy-sync-error',
|
||||
});
|
||||
});
|
||||
|
||||
return () => {
|
||||
cancelled = true;
|
||||
};
|
||||
}, [
|
||||
isStoreHydrated,
|
||||
hasLoadedSecureConfig,
|
||||
globalProxy.enabled,
|
||||
globalProxy.type,
|
||||
globalProxy.host,
|
||||
@@ -676,7 +770,6 @@ function App() {
|
||||
const addTab = useStore(state => state.addTab);
|
||||
const activeContext = useStore(state => state.activeContext);
|
||||
const connections = useStore(state => state.connections);
|
||||
const addConnection = useStore(state => state.addConnection);
|
||||
const tabs = useStore(state => state.tabs);
|
||||
const activeTabId = useStore(state => state.activeTabId);
|
||||
const updateCheckInFlightRef = React.useRef(false);
|
||||
@@ -1091,20 +1184,29 @@ function App() {
|
||||
if (res.success) {
|
||||
try {
|
||||
const imported = JSON.parse(res.data);
|
||||
if (Array.isArray(imported)) {
|
||||
let count = 0;
|
||||
imported.forEach((conn: any) => {
|
||||
if (!connections.some(c => c.id === conn.id)) {
|
||||
addConnection(conn);
|
||||
count++;
|
||||
}
|
||||
});
|
||||
void message.success(`成功导入 ${count} 个连接`);
|
||||
} else {
|
||||
if (!Array.isArray(imported)) {
|
||||
void message.error("文件格式错误:需要 JSON 数组");
|
||||
return;
|
||||
}
|
||||
} catch (e) {
|
||||
void message.error("解析 JSON 失败");
|
||||
|
||||
const normalizedItems = imported.map(toLegacySavedConnectionInput);
|
||||
const backendApp = (window as any).go?.app?.App;
|
||||
|
||||
if (typeof backendApp?.ImportLegacyConnections === 'function') {
|
||||
const importedViews = await backendApp.ImportLegacyConnections(normalizedItems);
|
||||
if (!Array.isArray(importedViews)) {
|
||||
throw new Error('导入失败:后端未返回连接列表');
|
||||
}
|
||||
replaceConnections(mergeSavedConnections(connections, importedViews));
|
||||
void message.success(`成功导入 ${importedViews.length} 个连接`);
|
||||
return;
|
||||
}
|
||||
|
||||
const fallbackItems = normalizedItems as SavedConnection[];
|
||||
replaceConnections(mergeSavedConnections(connections, fallbackItems));
|
||||
void message.success(`成功导入 ${fallbackItems.length} 个连接`);
|
||||
} catch (e: any) {
|
||||
void message.error(e?.message || "解析 JSON 失败");
|
||||
}
|
||||
} else if (res.message !== "已取消") {
|
||||
void message.error("导入失败: " + res.message);
|
||||
@@ -1116,7 +1218,7 @@ function App() {
|
||||
void message.warning("没有连接可导出");
|
||||
return;
|
||||
}
|
||||
const res = await (window as any).go.app.App.ExportData(connections, ['id','name','config','includeDatabases','includeRedisDatabases'], "connections", "json");
|
||||
const res = await (window as any).go.app.App.ExportData(connections, ['id','name','config','includeDatabases','includeRedisDatabases','iconType','iconColor'], "connections", "json");
|
||||
if (res.success) {
|
||||
void message.success("导出成功");
|
||||
} else if (res.message !== "已取消") {
|
||||
@@ -2194,6 +2296,33 @@ function App() {
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
<div style={utilityPanelStyle}>
|
||||
<div style={{ marginBottom: 10, fontWeight: 500 }}>数据表显示</div>
|
||||
<div style={{ display: 'grid', gap: 14 }}>
|
||||
<div style={{ display: 'flex', alignItems: 'center', justifyContent: 'space-between', gap: 12 }}>
|
||||
<div>
|
||||
<div style={{ fontWeight: 500 }}>显示数据表竖向分隔线</div>
|
||||
<div style={{ ...utilityMutedTextStyle, marginTop: 4 }}>仅作用于数据表页面 DataGrid,不影响其他表格组件。</div>
|
||||
</div>
|
||||
<Switch
|
||||
checked={appearance.showDataTableVerticalBorders === true}
|
||||
onChange={(checked) => setAppearance({ showDataTableVerticalBorders: checked })}
|
||||
/>
|
||||
</div>
|
||||
<div>
|
||||
<div style={{ marginBottom: 8, fontWeight: 500 }}>数据表列宽模式</div>
|
||||
<Segmented
|
||||
block
|
||||
options={DATA_GRID_COLUMN_WIDTH_MODE_OPTIONS}
|
||||
value={appearance.dataTableColumnWidthMode}
|
||||
onChange={(value) => setAppearance({ dataTableColumnWidthMode: sanitizeDataTableColumnWidthMode(value) })}
|
||||
/>
|
||||
<div style={{ ...utilityMutedTextStyle, marginTop: 8 }}>
|
||||
标准模式默认列宽 200px;紧凑模式默认列宽 140px。已手动拖拽调整的列宽优先保留。
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
{isMacRuntime ? (
|
||||
<div style={utilityPanelStyle}>
|
||||
<div style={{ marginBottom: 8, fontWeight: 500 }}>macOS 窗口控制</div>
|
||||
@@ -2227,7 +2356,7 @@ function App() {
|
||||
onClick={() => {
|
||||
setUiScale(DEFAULT_UI_SCALE);
|
||||
setFontSize(DEFAULT_FONT_SIZE);
|
||||
setAppearance({ enabled: true, opacity: 1.0, blur: 0, useNativeMacWindowControls: false });
|
||||
setAppearance({ ...DEFAULT_APPEARANCE });
|
||||
}}
|
||||
>
|
||||
恢复默认
|
||||
|
||||
@@ -14,6 +14,7 @@ import { AIMessageBubble } from './ai/AIMessageBubble';
|
||||
import { AIChatInput } from './ai/AIChatInput';
|
||||
import { AIHistoryDrawer } from './ai/AIHistoryDrawer';
|
||||
import type { AIComposerNotice } from '../utils/aiComposerNotice';
|
||||
import { buildRpcConnectionConfig } from '../utils/connectionRpcConfig';
|
||||
import {
|
||||
buildMissingModelNotice,
|
||||
buildMissingProviderNotice,
|
||||
@@ -260,7 +261,7 @@ export const AIChatPanel: React.FC<AIChatPanelProps> = ({
|
||||
const conn = useStore.getState().connections.find(c => c.id === connectionId);
|
||||
if (conn) {
|
||||
import('../../wailsjs/go/app/App').then(({ DBShowCreateTable }) => {
|
||||
DBShowCreateTable(conn.config as any, dbName, tableName).then(res => {
|
||||
DBShowCreateTable(buildRpcConnectionConfig(conn.config) as any, dbName, tableName).then(res => {
|
||||
if (res.success && res.data) {
|
||||
let createSql = '';
|
||||
if (typeof res.data === 'string') createSql = res.data;
|
||||
@@ -352,7 +353,12 @@ export const AIChatPanel: React.FC<AIChatPanelProps> = ({
|
||||
if (!activeProvider) return;
|
||||
try {
|
||||
const Service = (window as any).go?.aiservice?.Service;
|
||||
const payload = { ...activeProvider, model: val };
|
||||
const payload = {
|
||||
...activeProvider,
|
||||
model: val,
|
||||
apiKey: activeProvider.apiKey || '',
|
||||
hasSecret: activeProvider.hasSecret ?? Boolean(activeProvider.secretRef),
|
||||
};
|
||||
await Service?.AISaveProvider?.(payload);
|
||||
setActiveProvider(payload);
|
||||
setComposerNotice(null);
|
||||
@@ -834,7 +840,7 @@ SELECT * FROM users WHERE status = 1;
|
||||
const conn = useStore.getState().connections.find(c => c.id === args.connectionId);
|
||||
if (conn) {
|
||||
try {
|
||||
const dbRes = await DBGetDatabases(conn.config as any);
|
||||
const dbRes = await DBGetDatabases(buildRpcConnectionConfig(conn.config) as any);
|
||||
if (dbRes?.success && Array.isArray(dbRes.data)) {
|
||||
let dNames = dbRes.data.map((r: any) => r.Database || r.database || Object.values(r)[0]);
|
||||
if (dNames.length > 50) dNames = [...dNames.slice(0, 50), '...(截断)'];
|
||||
@@ -855,7 +861,7 @@ SELECT * FROM users WHERE status = 1;
|
||||
try {
|
||||
const rawDbName = args.dbName || args.database;
|
||||
const safeDbName = rawDbName ? String(rawDbName).trim() : '';
|
||||
const tbRes = await DBGetTables(conn.config as any, safeDbName);
|
||||
const tbRes = await DBGetTables(buildRpcConnectionConfig(conn.config) as any, safeDbName);
|
||||
if (tbRes?.success && Array.isArray(tbRes.data)) {
|
||||
let tNames = tbRes.data.map((r: any) => r.Table || r.table || Object.values(r)[0] as string);
|
||||
if (tNames.length > 150) tNames = [...tNames.slice(0, 150), '...(截断)'];
|
||||
@@ -881,7 +887,7 @@ SELECT * FROM users WHERE status = 1;
|
||||
const safeDbName = args.dbName ? String(args.dbName).trim() : '';
|
||||
const safeTable = args.tableName ? String(args.tableName).trim() : '';
|
||||
const { DBGetColumns } = await import('../../wailsjs/go/app/App');
|
||||
const colRes = await DBGetColumns(conn.config as any, safeDbName, safeTable);
|
||||
const colRes = await DBGetColumns(buildRpcConnectionConfig(conn.config) as any, safeDbName, safeTable);
|
||||
if (colRes?.success && Array.isArray(colRes.data)) {
|
||||
// 只保留关键字段信息,减少 token 占用
|
||||
const cols = colRes.data.map((c: any) => {
|
||||
@@ -912,7 +918,7 @@ SELECT * FROM users WHERE status = 1;
|
||||
const safeDbName = args.dbName ? String(args.dbName).trim() : '';
|
||||
const safeTable = args.tableName ? String(args.tableName).trim() : '';
|
||||
const { DBShowCreateTable } = await import('../../wailsjs/go/app/App');
|
||||
const ddlRes = await DBShowCreateTable(conn.config as any, safeDbName, safeTable);
|
||||
const ddlRes = await DBShowCreateTable(buildRpcConnectionConfig(conn.config) as any, safeDbName, safeTable);
|
||||
if (ddlRes?.success) {
|
||||
resStr = typeof ddlRes.data === 'string' ? ddlRes.data : JSON.stringify(ddlRes.data);
|
||||
success = true;
|
||||
@@ -946,7 +952,7 @@ SELECT * FROM users WHERE status = 1;
|
||||
const finalSql = (isReadQuery && !sqlTrimmed.toLowerCase().includes('limit'))
|
||||
? sqlTrimmed + ' LIMIT 50'
|
||||
: sqlTrimmed;
|
||||
const qRes = await DBQuery(conn.config as any, safeDbName, finalSql);
|
||||
const qRes = await DBQuery(buildRpcConnectionConfig(conn.config) as any, safeDbName, safeSql + (safeSql.toLowerCase().includes('limit') ? '' : ' LIMIT 50'));
|
||||
if (qRes?.success) {
|
||||
const rows = Array.isArray(qRes.data) ? qRes.data : [];
|
||||
const limitedRows = rows.slice(0, 50);
|
||||
@@ -1306,7 +1312,8 @@ SELECT * FROM users WHERE status = 1;
|
||||
const handleDeleteMessage = useCallback((id: string) => deleteAIChatMessage(sid, id), [sid, deleteAIChatMessage]);
|
||||
const activeConnectionConfig = useMemo(() => {
|
||||
if (!inferredConnectionId) return undefined;
|
||||
return connections.find(c => c.id === inferredConnectionId)?.config;
|
||||
const connection = connections.find(c => c.id === inferredConnectionId);
|
||||
return connection ? buildRpcConnectionConfig(connection.config) : undefined;
|
||||
}, [inferredConnectionId, connections]);
|
||||
const contextUsageChars = useMemo(() =>
|
||||
messages.reduce((sum, m) => sum + (m.content?.length || 0) + JSON.stringify(m.tool_calls || []).length, 0),
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
import React, { useState, useEffect, useCallback, useRef } from 'react';
|
||||
import { Modal, Button, Input, Select, Form, message as antdMessage, Tooltip, Tabs, Space, Popconfirm, Slider } from 'antd';
|
||||
import { Modal, Button, Input, Select, Form, Checkbox, message as antdMessage, Tooltip, Tabs, Space, Popconfirm, Slider } from 'antd';
|
||||
import { PlusOutlined, DeleteOutlined, EditOutlined, CheckOutlined, ApiOutlined, SafetyCertificateOutlined, RobotOutlined, ThunderboltOutlined, CloudOutlined, ExperimentOutlined, KeyOutlined, LinkOutlined, AppstoreOutlined, ToolOutlined } from '@ant-design/icons';
|
||||
import type { AIProviderConfig, AIProviderType, AISafetyLevel, AIContextLevel } from '../types';
|
||||
import {
|
||||
@@ -18,6 +18,8 @@ import {
|
||||
PROVIDER_PRESET_GRID_STYLE,
|
||||
PROVIDER_PRESET_CARD_TITLE_STYLE,
|
||||
} from '../utils/aiSettingsPresetLayout';
|
||||
import { resolveProviderSecretDraft } from '../utils/providerSecretDraft';
|
||||
import { buildAddProviderEditorSession, buildClosedProviderEditorSession, buildEditProviderEditorSession, type ProviderEditorSession } from '../utils/aiProviderEditorState';
|
||||
|
||||
import type { OverlayWorkbenchTheme } from '../utils/overlayWorkbenchTheme';
|
||||
|
||||
@@ -88,6 +90,7 @@ const AISettingsModal: React.FC<AISettingsModalProps> = ({ open, onClose, darkMo
|
||||
const [testStatus, setTestStatus] = useState<'idle' | 'success' | 'error'>('idle');
|
||||
const [builtinPrompts, setBuiltinPrompts] = useState<Record<string, string>>({});
|
||||
const [activeSection, setActiveSection] = useState<'providers' | 'safety' | 'context' | 'prompts' | 'tools'>('providers');
|
||||
const [clearProviderSecret, setClearProviderSecret] = useState(false);
|
||||
const [form] = Form.useForm();
|
||||
const modalBodyRef = useRef<HTMLDivElement>(null);
|
||||
|
||||
@@ -105,6 +108,7 @@ const AISettingsModal: React.FC<AISettingsModalProps> = ({ open, onClose, darkMo
|
||||
const watchedType = Form.useWatch('type', form);
|
||||
const watchedPresetKey = Form.useWatch('presetKey', form);
|
||||
const watchedApiFormat = Form.useWatch('apiFormat', form) || 'openai';
|
||||
const watchedApiKeyInput = Form.useWatch('apiKey', form);
|
||||
|
||||
const loadConfig = useCallback(async () => {
|
||||
try {
|
||||
@@ -131,18 +135,41 @@ const AISettingsModal: React.FC<AISettingsModalProps> = ({ open, onClose, darkMo
|
||||
|
||||
useEffect(() => { if (open) void loadConfig(); }, [open, loadConfig]);
|
||||
|
||||
const applyProviderEditorSession = useCallback((session: ProviderEditorSession) => {
|
||||
setEditingProvider(session.editingProvider as AIProviderConfig | null);
|
||||
setIsEditing(session.isEditing);
|
||||
setTestStatus(session.testStatus);
|
||||
setClearProviderSecret(session.clearProviderSecret);
|
||||
form.resetFields();
|
||||
if (session.formValues) {
|
||||
form.setFieldsValue(session.formValues);
|
||||
}
|
||||
}, [form]);
|
||||
|
||||
const resetProviderEditorSession = useCallback(() => {
|
||||
applyProviderEditorSession(buildClosedProviderEditorSession());
|
||||
}, [applyProviderEditorSession]);
|
||||
|
||||
const handleModalClose = useCallback(() => {
|
||||
resetProviderEditorSession();
|
||||
onClose();
|
||||
}, [onClose, resetProviderEditorSession]);
|
||||
|
||||
useEffect(() => {
|
||||
if (!open) {
|
||||
resetProviderEditorSession();
|
||||
}
|
||||
}, [open, resetProviderEditorSession]);
|
||||
const handleAddProvider = () => {
|
||||
const preset = findPreset('openai');
|
||||
const newProvider: AIProviderConfig = {
|
||||
id: '', type: preset.backendType, name: '', apiKey: '',
|
||||
baseUrl: preset.defaultBaseUrl, model: preset.defaultModel,
|
||||
models: [], maxTokens: 4096, temperature: 0.7,
|
||||
};
|
||||
setEditingProvider({ ...newProvider, presetKey: 'openai' } as any);
|
||||
setIsEditing(true);
|
||||
setTestStatus('idle');
|
||||
form.resetFields();
|
||||
form.setFieldsValue({ ...newProvider, presetKey: 'openai', apiFormat: 'openai' });
|
||||
applyProviderEditorSession(buildAddProviderEditorSession({
|
||||
presetKey: 'openai',
|
||||
presetBackendType: preset.backendType,
|
||||
presetBaseUrl: preset.defaultBaseUrl,
|
||||
presetModel: preset.defaultModel,
|
||||
presetModels: preset.models,
|
||||
apiFormat: 'openai',
|
||||
}));
|
||||
};
|
||||
|
||||
const handleEditProvider = (p: AIProviderConfig) => {
|
||||
@@ -153,17 +180,16 @@ const AISettingsModal: React.FC<AISettingsModalProps> = ({ open, onClose, darkMo
|
||||
presetFixedApiFormat: matchedPreset.fixedApiFormat,
|
||||
valuesApiFormat: p.apiFormat,
|
||||
});
|
||||
setEditingProvider(p);
|
||||
setIsEditing(true);
|
||||
setTestStatus('idle');
|
||||
form.resetFields();
|
||||
form.setFieldsValue({
|
||||
...p,
|
||||
type: resolvedTransport.type,
|
||||
models: p.models || [],
|
||||
presetKey: matchedPreset.key,
|
||||
apiFormat: resolvedTransport.apiFormat || p.apiFormat || 'openai',
|
||||
});
|
||||
applyProviderEditorSession(buildEditProviderEditorSession({
|
||||
provider: { ...p, presetKey: matchedPreset.key } as any,
|
||||
formValues: {
|
||||
...p,
|
||||
type: resolvedTransport.type,
|
||||
models: p.models || [],
|
||||
presetKey: matchedPreset.key,
|
||||
apiFormat: resolvedTransport.apiFormat || p.apiFormat || 'openai',
|
||||
},
|
||||
}));
|
||||
};
|
||||
|
||||
const handleDeleteProvider = async (id: string) => {
|
||||
@@ -217,12 +243,18 @@ const AISettingsModal: React.FC<AISettingsModalProps> = ({ open, onClose, darkMo
|
||||
presetFixedApiFormat: preset.fixedApiFormat,
|
||||
valuesApiFormat: values.apiFormat,
|
||||
});
|
||||
|
||||
const secretDraft = resolveProviderSecretDraft({
|
||||
hasSecret: editingProvider?.hasSecret,
|
||||
apiKeyInput: values.apiKey,
|
||||
clearSecret: clearProviderSecret,
|
||||
});
|
||||
const payload = {
|
||||
...editingProvider,
|
||||
...values,
|
||||
...resolvedTransport,
|
||||
name: finalName,
|
||||
apiKey: secretDraft.apiKey,
|
||||
hasSecret: secretDraft.hasSecret,
|
||||
model: finalModel,
|
||||
models: resolvedModels,
|
||||
baseUrl: finalBaseUrl,
|
||||
@@ -230,7 +262,7 @@ const AISettingsModal: React.FC<AISettingsModalProps> = ({ open, onClose, darkMo
|
||||
};
|
||||
// 后端 AISaveProvider 统一处理新增和更新,返回 void,失败抛异常
|
||||
await Service?.AISaveProvider?.(payload);
|
||||
void messageApi.success('已保存'); setIsEditing(false); setEditingProvider(null); void loadConfig();
|
||||
void messageApi.success('已保存'); resetProviderEditorSession(); void loadConfig();
|
||||
window.dispatchEvent(new CustomEvent('gonavi:ai:provider-changed'));
|
||||
} catch (e: any) {
|
||||
if (e?.errorFields) { /* antd form validation error, ignore */ }
|
||||
@@ -287,10 +319,20 @@ const AISettingsModal: React.FC<AISettingsModalProps> = ({ open, onClose, darkMo
|
||||
presetFixedApiFormat: preset.fixedApiFormat,
|
||||
valuesApiFormat: values.apiFormat,
|
||||
});
|
||||
const secretDraft = resolveProviderSecretDraft({
|
||||
hasSecret: editingProvider?.hasSecret,
|
||||
apiKeyInput: values.apiKey,
|
||||
clearSecret: clearProviderSecret,
|
||||
});
|
||||
if (secretDraft.mode === 'clear') {
|
||||
throw new Error('测试连接前请填写新的 API Key,或取消清除已保存密钥');
|
||||
}
|
||||
const res = await Service?.AITestProvider?.({
|
||||
...editingProvider,
|
||||
...values,
|
||||
...resolvedTransport,
|
||||
apiKey: secretDraft.apiKey,
|
||||
hasSecret: secretDraft.hasSecret,
|
||||
baseUrl: finalBaseUrl,
|
||||
model: finalModel,
|
||||
models: resolvedModels,
|
||||
@@ -401,7 +443,7 @@ const AISettingsModal: React.FC<AISettingsModalProps> = ({ open, onClose, darkMo
|
||||
<div>
|
||||
{/* 顶部返回 */}
|
||||
<div style={{ marginBottom: 16, display: 'flex', alignItems: 'center', gap: 10 }}>
|
||||
<Button size="small" onClick={() => { setIsEditing(false); setEditingProvider(null); }}
|
||||
<Button size="small" onClick={resetProviderEditorSession}
|
||||
style={{ borderRadius: 8 }}>← 返回</Button>
|
||||
<span style={{ fontWeight: 700, fontSize: 16, color: overlayTheme.titleText }}>
|
||||
{editingProvider?.id ? '编辑模型供应商' : '添加模型供应商'}
|
||||
@@ -492,11 +534,25 @@ const AISettingsModal: React.FC<AISettingsModalProps> = ({ open, onClose, darkMo
|
||||
<div style={fieldLabelStyle}>
|
||||
<KeyOutlined style={{ fontSize: 14 }} /> 认证 & 连接
|
||||
</div>
|
||||
<Form.Item label={<span style={{ fontWeight: 500, color: overlayTheme.titleText }}>API Key</span>} name="apiKey" rules={[{ required: true, message: '请输入 API Key' }]} style={{ marginBottom: 16 }}>
|
||||
<Input.Password placeholder="sk-... / 你的 API Key"
|
||||
<Form.Item label={<span style={{ fontWeight: 500, color: overlayTheme.titleText }}>API Key</span>} name="apiKey" rules={[{ validator: (_, value) => { const apiKey = String(value || '').trim(); if (apiKey || clearProviderSecret || editingProvider?.hasSecret) { return Promise.resolve(); } return Promise.reject(new Error('请输入 API Key')); } }]} style={{ marginBottom: editingProvider?.hasSecret ? 8 : 16 }}>
|
||||
<Input.Password placeholder={editingProvider?.hasSecret ? '留空表示继续沿用已保存密钥' : 'sk-... / 你的 API Key'}
|
||||
size="middle"
|
||||
style={{ borderRadius: 8, background: inputBg, border: `1px solid ${cardBorder}` }} />
|
||||
</Form.Item>
|
||||
{editingProvider?.hasSecret && (
|
||||
<div style={{ marginBottom: 16, padding: '10px 12px', borderRadius: 10, border: `1px solid ${cardBorder}`, background: cardBg }}>
|
||||
<div style={{ fontSize: 12, color: overlayTheme.mutedText, lineHeight: 1.6, marginBottom: 8 }}>
|
||||
当前已保存 API Key。留空表示继续沿用,输入新值表示替换。
|
||||
</div>
|
||||
<Checkbox
|
||||
checked={clearProviderSecret}
|
||||
disabled={String(watchedApiKeyInput || '').trim() !== ''}
|
||||
onChange={(event) => setClearProviderSecret(event.target.checked)}
|
||||
>
|
||||
清除已保存 API Key
|
||||
</Checkbox>
|
||||
</div>
|
||||
)}
|
||||
|
||||
{(presetKeyFromForm === 'custom' || presetKeyFromForm === 'ollama') && (
|
||||
<Form.Item label={<span style={{ fontWeight: 500, color: overlayTheme.titleText }}>API Endpoint (URL)</span>} name="baseUrl" rules={[{ required: true, message: '请输入有效的接口地址' }]} style={{ marginBottom: 0 }}>
|
||||
@@ -699,7 +755,7 @@ const AISettingsModal: React.FC<AISettingsModalProps> = ({ open, onClose, darkMo
|
||||
</div>
|
||||
}
|
||||
open={open}
|
||||
onCancel={onClose}
|
||||
onCancel={handleModalClose}
|
||||
footer={null}
|
||||
width={820}
|
||||
styles={{
|
||||
@@ -765,3 +821,9 @@ const AISettingsModal: React.FC<AISettingsModalProps> = ({ open, onClose, darkMo
|
||||
};
|
||||
|
||||
export default AISettingsModal;
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
@@ -5,6 +5,8 @@ import { getDbIcon, getDbDefaultColor, getDbIconLabel, DB_ICON_TYPES, PRESET_ICO
|
||||
import { useStore } from '../store';
|
||||
import { buildOverlayWorkbenchTheme } from '../utils/overlayWorkbenchTheme';
|
||||
import { normalizeOpacityForPlatform, resolveAppearanceValues } from '../utils/appearance';
|
||||
import { resolveConnectionSecretDraft } from '../utils/connectionSecretDraft';
|
||||
import { getCustomConnectionDsnValidationMessage } from '../utils/customConnectionDsn';
|
||||
import { DBGetDatabases, GetDriverStatusList, MongoDiscoverMembers, TestConnection, RedisConnect, SelectDatabaseFile, SelectSSHKeyFile } from '../../wailsjs/go/app/App';
|
||||
import { ConnectionConfig, MongoMemberInfo, SavedConnection } from '../types';
|
||||
|
||||
@@ -18,6 +20,29 @@ const CONNECTION_MODAL_BODY_HEIGHT = 620;
|
||||
const STEP1_SIDEBAR_DIVIDER_DARK = 'rgba(255, 255, 255, 0.16)';
|
||||
const STEP1_SIDEBAR_DIVIDER_LIGHT = 'rgba(0, 0, 0, 0.08)';
|
||||
|
||||
type ConnectionSecretKey =
|
||||
| 'primaryPassword'
|
||||
| 'sshPassword'
|
||||
| 'proxyPassword'
|
||||
| 'httpTunnelPassword'
|
||||
| 'mysqlReplicaPassword'
|
||||
| 'mongoReplicaPassword'
|
||||
| 'opaqueURI'
|
||||
| 'opaqueDSN';
|
||||
|
||||
type ConnectionSecretClearState = Record<ConnectionSecretKey, boolean>;
|
||||
|
||||
const createEmptyConnectionSecretClearState = (): ConnectionSecretClearState => ({
|
||||
primaryPassword: false,
|
||||
sshPassword: false,
|
||||
proxyPassword: false,
|
||||
httpTunnelPassword: false,
|
||||
mysqlReplicaPassword: false,
|
||||
mongoReplicaPassword: false,
|
||||
opaqueURI: false,
|
||||
opaqueDSN: false,
|
||||
});
|
||||
|
||||
const getDefaultPortByType = (type: string) => {
|
||||
switch (type) {
|
||||
case 'mysql': return 3306;
|
||||
@@ -122,6 +147,7 @@ const ConnectionModal: React.FC<{
|
||||
const [driverStatusLoaded, setDriverStatusLoaded] = useState(false);
|
||||
const [selectingDbFile, setSelectingDbFile] = useState(false);
|
||||
const [selectingSSHKey, setSelectingSSHKey] = useState(false);
|
||||
const [clearSecrets, setClearSecrets] = useState<ConnectionSecretClearState>(createEmptyConnectionSecretClearState);
|
||||
const testInFlightRef = useRef(false);
|
||||
const testTimerRef = useRef<number | null>(null);
|
||||
const addConnection = useStore((state) => state.addConnection);
|
||||
@@ -192,6 +218,51 @@ const ConnectionModal: React.FC<{
|
||||
lineHeight: 1.6,
|
||||
}), [overlayTheme]);
|
||||
|
||||
const renderStoredSecretControls = ({
|
||||
fieldName,
|
||||
clearKey,
|
||||
hasStoredSecret,
|
||||
clearLabel,
|
||||
description,
|
||||
}: {
|
||||
fieldName: string;
|
||||
clearKey: ConnectionSecretKey;
|
||||
hasStoredSecret?: boolean;
|
||||
clearLabel: string;
|
||||
description: string;
|
||||
}) => {
|
||||
if (!initialValues || !hasStoredSecret) {
|
||||
return null;
|
||||
}
|
||||
return (
|
||||
<Form.Item noStyle shouldUpdate={(prev, next) => prev[fieldName] !== next[fieldName]}>
|
||||
{({ getFieldValue }) => {
|
||||
const draftValue = getFieldValue(fieldName);
|
||||
const hasDraftValue = String(draftValue ?? '') !== '';
|
||||
const cardBorder = darkMode ? '1px solid rgba(255,255,255,0.12)' : '1px solid rgba(16,24,40,0.08)';
|
||||
const cardBg = darkMode ? 'rgba(255,255,255,0.03)' : 'rgba(16,24,40,0.03)';
|
||||
const effectiveChecked = clearSecrets[clearKey] && !hasDraftValue;
|
||||
return (
|
||||
<div style={{ marginBottom: 16, padding: '10px 12px', borderRadius: 10, border: cardBorder, background: cardBg }}>
|
||||
<div style={{ fontSize: 12, color: overlayTheme.mutedText, lineHeight: 1.6, marginBottom: 8 }}>
|
||||
{hasDraftValue ? '已输入新值,保存时会替换当前已保存内容。' : description}
|
||||
</div>
|
||||
<Checkbox
|
||||
checked={effectiveChecked}
|
||||
disabled={hasDraftValue}
|
||||
onChange={(event) => {
|
||||
const checked = event.target.checked;
|
||||
setClearSecrets((prev) => ({ ...prev, [clearKey]: checked }));
|
||||
}}
|
||||
>
|
||||
{clearLabel}
|
||||
</Checkbox>
|
||||
</div>
|
||||
);
|
||||
}}
|
||||
</Form.Item>
|
||||
);
|
||||
};
|
||||
const renderConnectionModalTitle = (icon: React.ReactNode, title: string, description: string) => (
|
||||
<div style={{ display: 'flex', alignItems: 'flex-start', gap: 12 }}>
|
||||
<div style={{ width: 36, height: 36, borderRadius: 12, display: 'grid', placeItems: 'center', background: overlayTheme.iconBg, color: overlayTheme.iconColor, flexShrink: 0 }}>
|
||||
@@ -749,6 +820,19 @@ const ConnectionModal: React.FC<{
|
||||
}
|
||||
});
|
||||
|
||||
const createCustomDsnRule = () => ({
|
||||
validator(_: unknown, value: unknown) {
|
||||
const validationMessage = getCustomConnectionDsnValidationMessage({
|
||||
dsnInput: value,
|
||||
hasStoredSecret: initialValues?.hasOpaqueDSN,
|
||||
clearStoredSecret: clearSecrets.opaqueDSN,
|
||||
});
|
||||
return validationMessage
|
||||
? Promise.reject(new Error(validationMessage))
|
||||
: Promise.resolve();
|
||||
}
|
||||
});
|
||||
|
||||
const getUriPlaceholder = () => {
|
||||
if (dbType === 'mysql' || dbType === 'mariadb' || dbType === 'diros' || dbType === 'sphinx') {
|
||||
const defaultPort = getDefaultPortByType(dbType);
|
||||
@@ -1066,6 +1150,7 @@ const ConnectionModal: React.FC<{
|
||||
setUriFeedback(null);
|
||||
setCustomIconType(undefined);
|
||||
setCustomIconColor(undefined);
|
||||
setClearSecrets(createEmptyConnectionSecretClearState());
|
||||
setTypeSelectWarning(null);
|
||||
setDriverStatusLoaded(false);
|
||||
void refreshDriverStatus();
|
||||
@@ -1198,6 +1283,107 @@ const ConnectionModal: React.FC<{
|
||||
};
|
||||
}, []);
|
||||
|
||||
const buildSavedConnectionInput = (config: ConnectionConfig, values: any) => {
|
||||
const connectionId = initialValues?.id || config.id || Date.now().toString();
|
||||
const primaryDraft = resolveConnectionSecretDraft({
|
||||
hasSecret: initialValues?.hasPrimaryPassword,
|
||||
valueInput: config.password,
|
||||
clearSecret: clearSecrets.primaryPassword,
|
||||
forceClear: values.type === 'mongodb' && values.savePassword === false,
|
||||
});
|
||||
const sshDraft = resolveConnectionSecretDraft({
|
||||
hasSecret: initialValues?.hasSSHPassword,
|
||||
valueInput: config.ssh?.password,
|
||||
clearSecret: clearSecrets.sshPassword,
|
||||
forceClear: !config.useSSH,
|
||||
});
|
||||
const proxyDraft = resolveConnectionSecretDraft({
|
||||
hasSecret: initialValues?.hasProxyPassword,
|
||||
valueInput: config.proxy?.password,
|
||||
clearSecret: clearSecrets.proxyPassword,
|
||||
forceClear: !config.useProxy,
|
||||
});
|
||||
const httpTunnelDraft = resolveConnectionSecretDraft({
|
||||
hasSecret: initialValues?.hasHttpTunnelPassword,
|
||||
valueInput: config.httpTunnel?.password,
|
||||
clearSecret: clearSecrets.httpTunnelPassword,
|
||||
forceClear: !config.useHttpTunnel,
|
||||
});
|
||||
const mysqlReplicaEnabled = (config.type === 'mysql' || config.type === 'mariadb' || config.type === 'diros' || config.type === 'sphinx')
|
||||
&& config.topology === 'replica';
|
||||
const mysqlReplicaDraft = resolveConnectionSecretDraft({
|
||||
hasSecret: initialValues?.hasMySQLReplicaPassword,
|
||||
valueInput: config.mysqlReplicaPassword,
|
||||
clearSecret: clearSecrets.mysqlReplicaPassword,
|
||||
forceClear: !mysqlReplicaEnabled,
|
||||
});
|
||||
const mongoReplicaEnabled = config.type === 'mongodb'
|
||||
&& config.topology === 'replica'
|
||||
&& values.savePassword !== false;
|
||||
const mongoReplicaDraft = resolveConnectionSecretDraft({
|
||||
hasSecret: initialValues?.hasMongoReplicaPassword,
|
||||
valueInput: config.mongoReplicaPassword,
|
||||
clearSecret: clearSecrets.mongoReplicaPassword,
|
||||
forceClear: !mongoReplicaEnabled,
|
||||
});
|
||||
const opaqueUriDraft = resolveConnectionSecretDraft({
|
||||
hasSecret: initialValues?.hasOpaqueURI,
|
||||
valueInput: config.uri,
|
||||
clearSecret: clearSecrets.opaqueURI,
|
||||
forceClear: values.type === 'custom',
|
||||
trimInput: true,
|
||||
});
|
||||
const opaqueDsnDraft = resolveConnectionSecretDraft({
|
||||
hasSecret: initialValues?.hasOpaqueDSN,
|
||||
valueInput: config.dsn,
|
||||
clearSecret: clearSecrets.opaqueDSN,
|
||||
forceClear: values.type !== 'custom',
|
||||
trimInput: true,
|
||||
});
|
||||
const isRedisType = values.type === 'redis';
|
||||
const displayHost = String((config as any).host || values.host || '').trim();
|
||||
const nextName = values.name || (isFileDatabaseType(values.type)
|
||||
? (values.type === 'duckdb' ? 'DuckDB DB' : 'SQLite DB')
|
||||
: (values.type === 'redis' ? `Redis ${displayHost}` : displayHost));
|
||||
|
||||
return {
|
||||
id: connectionId,
|
||||
name: nextName,
|
||||
config: {
|
||||
...config,
|
||||
id: connectionId,
|
||||
password: primaryDraft.value,
|
||||
ssh: {
|
||||
...(config.ssh || { host: '', port: 22, user: '', password: '', keyPath: '' }),
|
||||
password: sshDraft.value,
|
||||
},
|
||||
proxy: {
|
||||
...(config.proxy || { type: 'socks5', host: '', port: 1080, user: '', password: '' }),
|
||||
password: proxyDraft.value,
|
||||
},
|
||||
httpTunnel: {
|
||||
...(config.httpTunnel || { host: '', port: 8080, user: '', password: '' }),
|
||||
password: httpTunnelDraft.value,
|
||||
},
|
||||
uri: opaqueUriDraft.value,
|
||||
dsn: opaqueDsnDraft.value,
|
||||
mysqlReplicaPassword: mysqlReplicaDraft.value,
|
||||
mongoReplicaPassword: mongoReplicaDraft.value,
|
||||
},
|
||||
includeDatabases: values.includeDatabases,
|
||||
includeRedisDatabases: isRedisType ? values.includeRedisDatabases : undefined,
|
||||
iconType: customIconType || '',
|
||||
iconColor: customIconColor || '',
|
||||
clearPrimaryPassword: primaryDraft.clearStoredSecret,
|
||||
clearSSHPassword: sshDraft.clearStoredSecret,
|
||||
clearProxyPassword: proxyDraft.clearStoredSecret,
|
||||
clearHttpTunnelPassword: httpTunnelDraft.clearStoredSecret,
|
||||
clearMySQLReplicaPassword: mysqlReplicaDraft.clearStoredSecret,
|
||||
clearMongoReplicaPassword: mongoReplicaDraft.clearStoredSecret,
|
||||
clearOpaqueURI: opaqueUriDraft.clearStoredSecret,
|
||||
clearOpaqueDSN: opaqueDsnDraft.clearStoredSecret,
|
||||
};
|
||||
};
|
||||
const handleOk = async () => {
|
||||
try {
|
||||
await form.validateFields();
|
||||
@@ -1211,28 +1397,21 @@ const ConnectionModal: React.FC<{
|
||||
setLoading(true);
|
||||
|
||||
const config = await buildConfig(values, true);
|
||||
const displayHost = String((config as any).host || values.host || '').trim();
|
||||
|
||||
const isRedisType = values.type === 'redis';
|
||||
const newConn = {
|
||||
id: initialValues ? initialValues.id : Date.now().toString(),
|
||||
name: values.name || (isFileDatabaseType(values.type) ? (values.type === 'duckdb' ? 'DuckDB DB' : 'SQLite DB') : (values.type === 'redis' ? `Redis ${displayHost}` : displayHost)),
|
||||
config: config,
|
||||
includeDatabases: values.includeDatabases,
|
||||
includeRedisDatabases: isRedisType ? values.includeRedisDatabases : undefined,
|
||||
iconType: customIconType,
|
||||
iconColor: customIconColor,
|
||||
};
|
||||
const payload = buildSavedConnectionInput(config, values);
|
||||
const backendApp = (window as any).go?.app?.App;
|
||||
const savedConnection = await backendApp?.SaveConnection?.(payload);
|
||||
if (!savedConnection) {
|
||||
throw new Error('保存连接失败:后端接口不可用');
|
||||
}
|
||||
|
||||
if (initialValues) {
|
||||
updateConnection(newConn);
|
||||
updateConnection(savedConnection);
|
||||
message.success('配置已更新(未连接)');
|
||||
} else {
|
||||
addConnection(newConn);
|
||||
addConnection(savedConnection);
|
||||
message.success('配置已保存(未连接)');
|
||||
}
|
||||
|
||||
setLoading(false);
|
||||
form.resetFields();
|
||||
setUseSSL(false);
|
||||
setUseSSH(false);
|
||||
@@ -1240,8 +1419,11 @@ const ConnectionModal: React.FC<{
|
||||
setUseHttpTunnel(false);
|
||||
setDbType('mysql');
|
||||
setStep(1);
|
||||
setClearSecrets(createEmptyConnectionSecretClearState());
|
||||
onClose();
|
||||
} catch (e) {
|
||||
} catch (e: any) {
|
||||
message.error(e?.message || '保存失败');
|
||||
} finally {
|
||||
setLoading(false);
|
||||
}
|
||||
};
|
||||
@@ -1271,6 +1453,30 @@ const ConnectionModal: React.FC<{
|
||||
}
|
||||
};
|
||||
|
||||
const getBlockingSecretClearMessage = (values: any): string | null => {
|
||||
if (clearSecrets.primaryPassword && values.type !== 'custom' && !isFileDatabaseType(values.type) && String(values.password ?? '') === '') {
|
||||
return '测试连接前请填写新的密码,或取消清除已保存密码';
|
||||
}
|
||||
if (clearSecrets.sshPassword && values.useSSH && String(values.sshPassword ?? '') === '') {
|
||||
return '测试连接前请填写新的 SSH 密码,或取消清除已保存 SSH 密码';
|
||||
}
|
||||
if (clearSecrets.proxyPassword && values.useProxy && !values.useHttpTunnel && String(values.proxyPassword ?? '') === '') {
|
||||
return '测试连接前请填写新的代理密码,或取消清除已保存代理密码';
|
||||
}
|
||||
if (clearSecrets.httpTunnelPassword && values.useHttpTunnel && String(values.httpTunnelPassword ?? '') === '') {
|
||||
return '测试连接前请填写新的隧道密码,或取消清除已保存隧道密码';
|
||||
}
|
||||
if (clearSecrets.mysqlReplicaPassword && (values.type === 'mysql' || values.type === 'mariadb' || values.type === 'diros' || values.type === 'sphinx') && values.mysqlTopology === 'replica' && String(values.mysqlReplicaPassword ?? '') === '') {
|
||||
return '测试连接前请填写新的从库密码,或取消清除已保存从库密码';
|
||||
}
|
||||
if (clearSecrets.mongoReplicaPassword && values.type === 'mongodb' && values.mongoTopology === 'replica' && String(values.mongoReplicaPassword ?? '') === '') {
|
||||
return '测试连接前请填写新的副本集密码,或取消清除已保存副本集密码';
|
||||
}
|
||||
if (values.type === 'mongodb' && values.savePassword === false && initialValues?.hasPrimaryPassword && String(values.password ?? '') === '') {
|
||||
return '测试连接前请填写新的 MongoDB 密码,或重新勾选保存密码';
|
||||
}
|
||||
return null;
|
||||
};
|
||||
const buildTestFailureMessage = (reason: unknown, fallback: string) => {
|
||||
const text = String(reason ?? '').trim();
|
||||
const normalized = text && text !== 'undefined' && text !== 'null' ? text : fallback;
|
||||
@@ -1290,9 +1496,17 @@ const ConnectionModal: React.FC<{
|
||||
promptInstallDriver(values.type, unavailableReason);
|
||||
return;
|
||||
}
|
||||
const blockingSecretClearMessage = getBlockingSecretClearMessage(values);
|
||||
if (blockingSecretClearMessage) {
|
||||
setTestResult({ type: 'error', message: blockingSecretClearMessage });
|
||||
return;
|
||||
}
|
||||
setLoading(true);
|
||||
setTestResult(null);
|
||||
const config = await buildConfig(values, false);
|
||||
if (initialValues?.id) {
|
||||
config.id = initialValues.id;
|
||||
}
|
||||
const timeoutSecondsRaw = Number(values.timeout);
|
||||
const timeoutSeconds = Number.isFinite(timeoutSecondsRaw) && timeoutSecondsRaw > 0
|
||||
? Math.min(timeoutSecondsRaw, MAX_TIMEOUT_SECONDS)
|
||||
@@ -1368,7 +1582,15 @@ const ConnectionModal: React.FC<{
|
||||
await form.validateFields();
|
||||
const values = form.getFieldsValue(true);
|
||||
setDiscoveringMembers(true);
|
||||
const blockingSecretClearMessage = getBlockingSecretClearMessage(values);
|
||||
if (blockingSecretClearMessage) {
|
||||
message.error(blockingSecretClearMessage);
|
||||
return;
|
||||
}
|
||||
const config = await buildConfig(values, false);
|
||||
if (initialValues?.id) {
|
||||
config.id = initialValues.id;
|
||||
}
|
||||
const result = await MongoDiscoverMembers(config as any);
|
||||
if (!result.success) {
|
||||
message.error(result.message || '成员发现失败');
|
||||
@@ -1877,6 +2099,13 @@ const ConnectionModal: React.FC<{
|
||||
style={{ marginBottom: 16 }}
|
||||
/>
|
||||
)}
|
||||
{renderStoredSecretControls({
|
||||
fieldName: 'uri',
|
||||
clearKey: 'opaqueURI',
|
||||
hasStoredSecret: initialValues?.hasOpaqueURI,
|
||||
clearLabel: '清除已保存 URI',
|
||||
description: '当前已保存连接 URI。留空表示继续沿用,输入新值表示替换。',
|
||||
})}
|
||||
</>
|
||||
)}
|
||||
|
||||
@@ -1885,9 +2114,16 @@ const ConnectionModal: React.FC<{
|
||||
<Form.Item name="driver" label="驱动名称 (Driver Name)" rules={[{ required: true, message: '请输入驱动名称' }]} help="已支持: mysql, postgres, sqlite, oracle, dm, kingbase">
|
||||
<Input placeholder="例如: mysql, postgres" />
|
||||
</Form.Item>
|
||||
<Form.Item name="dsn" label="连接字符串 (DSN)" rules={[{ required: true, message: '请输入连接字符串' }]}>
|
||||
<Form.Item name="dsn" label="连接字符串 (DSN)" rules={[createCustomDsnRule()]}>
|
||||
<Input.TextArea rows={4} placeholder="例如: user:pass@tcp(localhost:3306)/dbname?charset=utf8" />
|
||||
</Form.Item>
|
||||
{renderStoredSecretControls({
|
||||
fieldName: 'dsn',
|
||||
clearKey: 'opaqueDSN',
|
||||
hasStoredSecret: initialValues?.hasOpaqueDSN,
|
||||
clearLabel: '清除已保存 DSN',
|
||||
description: '当前已保存连接字符串。留空表示继续沿用,输入新值表示替换。',
|
||||
})}
|
||||
</>
|
||||
) : (
|
||||
<>
|
||||
@@ -1968,6 +2204,13 @@ const ConnectionModal: React.FC<{
|
||||
<Input.Password placeholder="留空沿用主库密码" />
|
||||
</Form.Item>
|
||||
</div>
|
||||
{renderStoredSecretControls({
|
||||
fieldName: 'mysqlReplicaPassword',
|
||||
clearKey: 'mysqlReplicaPassword',
|
||||
hasStoredSecret: initialValues?.hasMySQLReplicaPassword,
|
||||
clearLabel: '清除已保存从库密码',
|
||||
description: '当前已保存从库密码。留空表示继续沿用,输入新值表示替换。',
|
||||
})}
|
||||
</>
|
||||
)}
|
||||
</>
|
||||
@@ -2010,6 +2253,13 @@ const ConnectionModal: React.FC<{
|
||||
<Form.Item name="mongoReplicaPassword" label="副本集密码(可选)" style={{ marginBottom: 0 }}>
|
||||
<Input.Password placeholder="留空沿用主密码" />
|
||||
</Form.Item>
|
||||
{renderStoredSecretControls({
|
||||
fieldName: 'mongoReplicaPassword',
|
||||
clearKey: 'mongoReplicaPassword',
|
||||
hasStoredSecret: initialValues?.hasMongoReplicaPassword,
|
||||
clearLabel: '清除已保存副本集密码',
|
||||
description: '当前已保存副本集密码。留空表示继续沿用,输入新值表示替换。',
|
||||
})}
|
||||
<Space size={8} style={{ marginTop: 12, marginBottom: 12 }}>
|
||||
<Button onClick={handleDiscoverMongoMembers} loading={discoveringMembers}>自动发现成员</Button>
|
||||
</Space>
|
||||
@@ -2084,6 +2334,13 @@ const ConnectionModal: React.FC<{
|
||||
<Form.Item name="password" label="密码 (可选)">
|
||||
<Input.Password placeholder="Redis 密码(如果设置了 requirepass)" />
|
||||
</Form.Item>
|
||||
{renderStoredSecretControls({
|
||||
fieldName: 'password',
|
||||
clearKey: 'primaryPassword',
|
||||
hasStoredSecret: initialValues?.hasPrimaryPassword,
|
||||
clearLabel: '清除已保存密码',
|
||||
description: '当前已保存 Redis 密码。留空表示继续沿用,输入新值表示替换。',
|
||||
})}
|
||||
<Form.Item
|
||||
name="includeRedisDatabases"
|
||||
label="显示数据库 (留空显示全部)"
|
||||
@@ -2097,6 +2354,7 @@ const ConnectionModal: React.FC<{
|
||||
)}
|
||||
|
||||
{!isFileDb && !isRedis && (
|
||||
<>
|
||||
<div style={{ display: 'grid', gridTemplateColumns: dbType === 'mongodb' ? 'minmax(0, 1fr) minmax(0, 1fr) 180px' : 'repeat(2, minmax(0, 1fr))', gap: 16 }}>
|
||||
<Form.Item
|
||||
name="user"
|
||||
@@ -2124,6 +2382,14 @@ const ConnectionModal: React.FC<{
|
||||
</Form.Item>
|
||||
)}
|
||||
</div>
|
||||
{renderStoredSecretControls({
|
||||
fieldName: 'password',
|
||||
clearKey: 'primaryPassword',
|
||||
hasStoredSecret: initialValues?.hasPrimaryPassword,
|
||||
clearLabel: '清除已保存密码',
|
||||
description: '当前已保存主连接密码。留空表示继续沿用,输入新值表示替换。',
|
||||
})}
|
||||
</>
|
||||
)}
|
||||
|
||||
{dbType === 'mongodb' && (
|
||||
@@ -2233,6 +2499,13 @@ const ConnectionModal: React.FC<{
|
||||
</Button>
|
||||
</Space.Compact>
|
||||
</Form.Item>
|
||||
{renderStoredSecretControls({
|
||||
fieldName: 'sshPassword',
|
||||
clearKey: 'sshPassword',
|
||||
hasStoredSecret: initialValues?.hasSSHPassword,
|
||||
clearLabel: '清除已保存 SSH 密码',
|
||||
description: '当前已保存 SSH 密码。留空表示继续沿用,输入新值表示替换。',
|
||||
})}
|
||||
</div>
|
||||
)}
|
||||
</div>
|
||||
@@ -2271,6 +2544,13 @@ const ConnectionModal: React.FC<{
|
||||
<Input.Password placeholder="留空表示无认证" />
|
||||
</Form.Item>
|
||||
</div>
|
||||
{renderStoredSecretControls({
|
||||
fieldName: 'proxyPassword',
|
||||
clearKey: 'proxyPassword',
|
||||
hasStoredSecret: initialValues?.hasProxyPassword,
|
||||
clearLabel: '清除已保存代理密码',
|
||||
description: '当前已保存代理密码。留空表示继续沿用,输入新值表示替换。',
|
||||
})}
|
||||
</div>
|
||||
)}
|
||||
</div>
|
||||
@@ -2302,6 +2582,13 @@ const ConnectionModal: React.FC<{
|
||||
<Input.Password placeholder="留空表示无认证" />
|
||||
</Form.Item>
|
||||
</div>
|
||||
{renderStoredSecretControls({
|
||||
fieldName: 'httpTunnelPassword',
|
||||
clearKey: 'httpTunnelPassword',
|
||||
hasStoredSecret: initialValues?.hasHttpTunnelPassword,
|
||||
clearLabel: '清除已保存隧道密码',
|
||||
description: '当前已保存隧道密码。留空表示继续沿用,输入新值表示替换。',
|
||||
})}
|
||||
<Text type="secondary" style={{ fontSize: 12 }}>与“使用代理”互斥,启用后将通过 HTTP CONNECT 建立独立隧道。</Text>
|
||||
</div>
|
||||
)}
|
||||
@@ -2832,3 +3119,8 @@ const ConnectionModal: React.FC<{
|
||||
};
|
||||
|
||||
export default ConnectionModal;
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
@@ -23,18 +23,31 @@ import {
|
||||
arrayMove
|
||||
} from '@dnd-kit/sortable';
|
||||
import { CSS } from '@dnd-kit/utilities';
|
||||
import { ImportData, ExportTable, ExportData, ExportQuery, ApplyChanges, DBGetColumns } from '../../wailsjs/go/app/App';
|
||||
import { ImportData, ExportTable, ExportData, ExportQuery, ApplyChanges, DBGetColumns, DBGetIndexes } from '../../wailsjs/go/app/App';
|
||||
import ImportPreviewModal from './ImportPreviewModal';
|
||||
import { useStore } from '../store';
|
||||
import type { ColumnDefinition } from '../types';
|
||||
import type { ColumnDefinition, IndexDefinition } from '../types';
|
||||
import { v4 as generateUuid } from 'uuid';
|
||||
import 'react-resizable/css/styles.css';
|
||||
import { buildOrderBySQL, buildPaginatedSelectSQL, buildWhereSQL, escapeLiteral, hasExplicitSort, quoteIdentPart, quoteQualifiedIdent, withSortBufferTuningSQL, type FilterCondition } from '../utils/sql';
|
||||
import { isMacLikePlatform, normalizeOpacityForPlatform, resolveAppearanceValues } from '../utils/appearance';
|
||||
import { getDataSourceCapabilities } from '../utils/dataSourceCapabilities';
|
||||
import { buildRpcConnectionConfig } from '../utils/connectionRpcConfig';
|
||||
import {
|
||||
resolveDataTableColumnWidth,
|
||||
resolveDataTableDefaultColumnWidth,
|
||||
resolveDataTableVerticalBorderColor,
|
||||
} from '../utils/dataGridDisplay';
|
||||
import { resolvePaginationPageText, resolvePaginationSummaryText, resolvePaginationTotalForControl } from '../utils/dataGridPagination';
|
||||
import { resolveGridSortInfoFromTableSorter } from '../utils/dataGridSort';
|
||||
import { calculateTableBodyBottomPadding, calculateVirtualTableScrollX } from './dataGridLayout';
|
||||
import { buildCopyInsertSQL, normalizeTemporalLiteralText } from './dataGridCopyInsert';
|
||||
import {
|
||||
buildCopyDeleteSQL,
|
||||
buildCopyInsertSQL,
|
||||
buildCopyUpdateSQL,
|
||||
normalizeTemporalLiteralText,
|
||||
resolveUniqueKeyGroupsFromIndexes,
|
||||
} from './dataGridCopyInsert';
|
||||
|
||||
// --- Error Boundary ---
|
||||
interface DataGridErrorBoundaryState {
|
||||
@@ -531,6 +544,8 @@ const DataContext = React.createContext<{
|
||||
selectedRowKeysRef: React.MutableRefObject<React.Key[]>;
|
||||
displayDataRef: React.MutableRefObject<any[]>;
|
||||
handleCopyInsert: (r: any) => void;
|
||||
handleCopyUpdate: (r: any) => void;
|
||||
handleCopyDelete: (r: any) => void;
|
||||
handleCopyJson: (r: any) => void;
|
||||
handleCopyCsv: (r: any) => void;
|
||||
handleExportSelected: (format: string, r: any) => Promise<void>;
|
||||
@@ -783,7 +798,19 @@ const ContextMenuRow = React.memo(({ children, record, ...props }: any) => {
|
||||
|
||||
if (!record || !context) return <tr {...props}>{children}</tr>;
|
||||
|
||||
const { selectedRowKeysRef, displayDataRef, handleCopyInsert, handleCopyJson, handleCopyCsv, handleExportSelected, copyToClipboard, enableRowContextMenu, supportsCopyInsert } = context;
|
||||
const {
|
||||
selectedRowKeysRef,
|
||||
displayDataRef,
|
||||
handleCopyInsert,
|
||||
handleCopyUpdate,
|
||||
handleCopyDelete,
|
||||
handleCopyJson,
|
||||
handleCopyCsv,
|
||||
handleExportSelected,
|
||||
copyToClipboard,
|
||||
enableRowContextMenu,
|
||||
supportsCopyInsert,
|
||||
} = context;
|
||||
|
||||
if (!enableRowContextMenu) {
|
||||
return <tr {...props}>{children}</tr>;
|
||||
@@ -804,6 +831,16 @@ const ContextMenuRow = React.memo(({ children, record, ...props }: any) => {
|
||||
label: '复制为 INSERT',
|
||||
icon: <ConsoleSqlOutlined />,
|
||||
onClick: () => handleCopyInsert(record),
|
||||
}, {
|
||||
key: 'update',
|
||||
label: '复制为 UPDATE',
|
||||
icon: <ConsoleSqlOutlined />,
|
||||
onClick: () => handleCopyUpdate(record),
|
||||
}, {
|
||||
key: 'delete',
|
||||
label: '复制为 DELETE',
|
||||
icon: <ConsoleSqlOutlined />,
|
||||
onClick: () => handleCopyDelete(record),
|
||||
}] : []),
|
||||
{ key: 'json', label: '复制为 JSON', icon: <FileTextOutlined />, onClick: () => handleCopyJson(record) },
|
||||
{ key: 'csv', label: '复制为 CSV', icon: <FileTextOutlined />, onClick: () => handleCopyCsv(record) },
|
||||
@@ -929,6 +966,13 @@ const DataGrid: React.FC<DataGridProps> = ({
|
||||
const darkMode = theme === 'dark';
|
||||
const resolvedAppearance = resolveAppearanceValues(appearance);
|
||||
const opacity = normalizeOpacityForPlatform(resolvedAppearance.opacity);
|
||||
const showDataTableVerticalBorders = appearance.showDataTableVerticalBorders === true;
|
||||
const dataTableColumnWidthMode = appearance.dataTableColumnWidthMode;
|
||||
const defaultColumnWidth = resolveDataTableDefaultColumnWidth(dataTableColumnWidthMode);
|
||||
const dataTableVerticalBorderColor = resolveDataTableVerticalBorderColor({
|
||||
darkMode,
|
||||
visible: showDataTableVerticalBorders,
|
||||
});
|
||||
const canModifyData = !readOnly && !!tableName;
|
||||
const showColumnComment = queryOptions?.showColumnComment ?? true;
|
||||
const showColumnType = queryOptions?.showColumnType ?? true;
|
||||
@@ -1310,8 +1354,11 @@ const DataGrid: React.FC<DataGridProps> = ({
|
||||
const [sortInfo, setSortInfo] = useState<Array<{ columnKey: string, order: string, enabled?: boolean }>>([]);
|
||||
const [columnWidths, setColumnWidths] = useState<Record<string, number>>({});
|
||||
const [columnMetaMap, setColumnMetaMap] = useState<Record<string, ColumnMeta>>({});
|
||||
const [uniqueKeyGroups, setUniqueKeyGroups] = useState<string[][]>([]);
|
||||
const columnMetaCacheRef = useRef<Record<string, Record<string, ColumnMeta>>>({});
|
||||
const columnMetaSeqRef = useRef(0);
|
||||
const uniqueKeyGroupsCacheRef = useRef<Record<string, string[][]>>({});
|
||||
const uniqueKeyGroupsSeqRef = useRef(0);
|
||||
|
||||
useEffect(() => {
|
||||
const ext = sortInfoExternal || [];
|
||||
@@ -1326,10 +1373,12 @@ const DataGrid: React.FC<DataGridProps> = ({
|
||||
const normalizedDbName = String(dbName || '').trim();
|
||||
if (!connectionId || !normalizedTableName) {
|
||||
setColumnMetaMap({});
|
||||
setUniqueKeyGroups([]);
|
||||
return;
|
||||
}
|
||||
const cacheKey = `${connectionId}|${normalizedDbName}|${normalizedTableName}`;
|
||||
setColumnMetaMap(columnMetaCacheRef.current[cacheKey] || {});
|
||||
setUniqueKeyGroups(uniqueKeyGroupsCacheRef.current[cacheKey] || []);
|
||||
}, [connectionId, dbName, tableName]);
|
||||
|
||||
useEffect(() => {
|
||||
@@ -1356,7 +1405,7 @@ const DataGrid: React.FC<DataGridProps> = ({
|
||||
};
|
||||
|
||||
const seq = ++columnMetaSeqRef.current;
|
||||
DBGetColumns(config as any, normalizedDbName, normalizedTableName)
|
||||
DBGetColumns(buildRpcConnectionConfig(config) as any, normalizedDbName, normalizedTableName)
|
||||
.then((res) => {
|
||||
if (seq !== columnMetaSeqRef.current) return;
|
||||
if (!res.success || !Array.isArray(res.data)) {
|
||||
@@ -1380,6 +1429,47 @@ const DataGrid: React.FC<DataGridProps> = ({
|
||||
});
|
||||
}, [connections, connectionId, dbName, tableName]);
|
||||
|
||||
useEffect(() => {
|
||||
const normalizedTableName = String(tableName || '').trim();
|
||||
const normalizedDbName = String(dbName || '').trim();
|
||||
if (!connectionId || !normalizedTableName) return;
|
||||
|
||||
const cacheKey = `${connectionId}|${normalizedDbName}|${normalizedTableName}`;
|
||||
if (uniqueKeyGroupsCacheRef.current[cacheKey]) return;
|
||||
|
||||
const conn = connections.find(c => c.id === connectionId);
|
||||
if (!conn) {
|
||||
setUniqueKeyGroups([]);
|
||||
return;
|
||||
}
|
||||
|
||||
const config = {
|
||||
...conn.config,
|
||||
port: Number(conn.config.port),
|
||||
password: conn.config.password || "",
|
||||
database: conn.config.database || "",
|
||||
useSSH: conn.config.useSSH || false,
|
||||
ssh: conn.config.ssh || { host: "", port: 22, user: "", password: "", keyPath: "" }
|
||||
};
|
||||
|
||||
const seq = ++uniqueKeyGroupsSeqRef.current;
|
||||
DBGetIndexes(config as any, normalizedDbName, normalizedTableName)
|
||||
.then((res) => {
|
||||
if (seq !== uniqueKeyGroupsSeqRef.current) return;
|
||||
if (!res.success || !Array.isArray(res.data)) {
|
||||
setUniqueKeyGroups([]);
|
||||
return;
|
||||
}
|
||||
const nextGroups = resolveUniqueKeyGroupsFromIndexes(res.data as IndexDefinition[]);
|
||||
uniqueKeyGroupsCacheRef.current[cacheKey] = nextGroups;
|
||||
setUniqueKeyGroups(nextGroups);
|
||||
})
|
||||
.catch(() => {
|
||||
if (seq !== uniqueKeyGroupsSeqRef.current) return;
|
||||
setUniqueKeyGroups([]);
|
||||
});
|
||||
}, [connections, connectionId, dbName, tableName]);
|
||||
|
||||
const columnMetaMapByLowerName = useMemo(() => {
|
||||
const next: Record<string, ColumnMeta> = {};
|
||||
Object.entries(columnMetaMap).forEach(([name, meta]) => {
|
||||
@@ -1400,6 +1490,17 @@ const DataGrid: React.FC<DataGridProps> = ({
|
||||
return next;
|
||||
}, [columnMetaMapByLowerName]);
|
||||
|
||||
const allTableColumnNames = useMemo(() => {
|
||||
const metaColumns = Object.keys(columnMetaMap);
|
||||
if (metaColumns.length > 0) {
|
||||
return metaColumns;
|
||||
}
|
||||
if (exportScope === 'table') {
|
||||
return columnNames.filter((columnName) => columnName !== GONAVI_ROW_KEY);
|
||||
}
|
||||
return [];
|
||||
}, [columnMetaMap, exportScope, columnNames]);
|
||||
|
||||
const normalizeCommitCellValue = useCallback(
|
||||
(columnName: string, value: any, mode: 'insert' | 'update') => {
|
||||
if (value === undefined) return undefined;
|
||||
@@ -1570,8 +1671,15 @@ const DataGrid: React.FC<DataGridProps> = ({
|
||||
overflow: hidden !important;
|
||||
}
|
||||
.${gridId} .ant-table-tbody > tr > td,
|
||||
.${gridId} .ant-table-tbody .ant-table-row > .ant-table-cell { background: transparent !important; border-bottom: 1px solid ${darkMode ? 'rgba(255,255,255,0.05)' : 'rgba(0,0,0,0.05)'} !important; border-inline-end: 1px solid transparent !important; }
|
||||
.${gridId} .ant-table-thead > tr > th { background: transparent !important; border-bottom: 1px solid ${darkMode ? 'rgba(255,255,255,0.05)' : 'rgba(0,0,0,0.05)'} !important; border-inline-end: 1px solid transparent !important; }
|
||||
.${gridId} .ant-table-tbody .ant-table-row > .ant-table-cell,
|
||||
.${gridId} .ant-table-tbody-virtual-holder .ant-table-row > .ant-table-cell { background: transparent !important; border-bottom: 1px solid ${darkMode ? 'rgba(255,255,255,0.05)' : 'rgba(0,0,0,0.05)'} !important; border-inline-end: 1px solid ${dataTableVerticalBorderColor} !important; }
|
||||
.${gridId} .ant-table-thead > tr > th { background: transparent !important; border-bottom: 1px solid ${darkMode ? 'rgba(255,255,255,0.05)' : 'rgba(0,0,0,0.05)'} !important; border-inline-end: 1px solid ${dataTableVerticalBorderColor} !important; }
|
||||
.${gridId} .ant-table-tbody > tr > td:last-child,
|
||||
.${gridId} .ant-table-tbody .ant-table-row > .ant-table-cell:last-child,
|
||||
.${gridId} .ant-table-tbody-virtual-holder .ant-table-row > .ant-table-cell:last-child,
|
||||
.${gridId} .ant-table-thead > tr > th:last-child {
|
||||
border-inline-end-color: transparent !important;
|
||||
}
|
||||
/* 选择列对齐:header TH 无 class(Ant Design 虚拟模式),需用 :first-child 匹配 */
|
||||
.${gridId} .ant-table-header th:first-child,
|
||||
.${gridId} .ant-table-thead > tr > th:first-child {
|
||||
@@ -2008,7 +2116,7 @@ const DataGrid: React.FC<DataGridProps> = ({
|
||||
justify-content: center;
|
||||
line-height: 1;
|
||||
}
|
||||
`, [themeStyles, gridId, tableBodyBottomPadding, darkMode, opacity]);
|
||||
`, [themeStyles, gridId, tableBodyBottomPadding, darkMode, opacity, dataTableVerticalBorderColor]);
|
||||
|
||||
const recalculateTableMetrics = useCallback((targetElement?: HTMLElement | null) => {
|
||||
const target = targetElement || containerRef.current;
|
||||
@@ -2762,39 +2870,10 @@ const DataGrid: React.FC<DataGridProps> = ({
|
||||
|
||||
const handleTableChange = useCallback((_pag: any, _filtersArg: any, sorter: any) => {
|
||||
if (isResizingRef.current) return; // Block sort if resizing
|
||||
// Ant Design 多列排序模式下 sorter 可能是数组
|
||||
const sorters = Array.isArray(sorter) ? sorter : (sorter?.field ? [sorter] : []);
|
||||
if (sorters.length === 0) {
|
||||
setSortInfo([]);
|
||||
if (onSort) onSort(JSON.stringify([]), '');
|
||||
return;
|
||||
}
|
||||
// 在现有排序数组基础上增量更新
|
||||
const next = [...sortInfo];
|
||||
for (const s of sorters) {
|
||||
const field = String(s.field || '');
|
||||
if (!field) continue;
|
||||
const order = s.order as string;
|
||||
const normalizedOrder = order === 'ascend' || order === 'descend' ? order : '';
|
||||
const existIdx = next.findIndex(item => item.columnKey === field);
|
||||
if (!normalizedOrder) {
|
||||
// Ant Design 第三次点击想取消排序:
|
||||
// 如果该字段已在排序数组中,回转为升序而非移除
|
||||
if (existIdx >= 0) {
|
||||
next[existIdx] = { ...next[existIdx], order: 'ascend', enabled: true };
|
||||
}
|
||||
// 不在数组中则忽略
|
||||
} else if (existIdx >= 0) {
|
||||
// 已存在:更新排序方向
|
||||
next[existIdx] = { ...next[existIdx], order: normalizedOrder, enabled: true };
|
||||
} else {
|
||||
// 不存在:追加到末尾
|
||||
next.push({ columnKey: field, order: normalizedOrder, enabled: true });
|
||||
}
|
||||
}
|
||||
const next = resolveGridSortInfoFromTableSorter({ sorter });
|
||||
setSortInfo(next);
|
||||
if (onSort) onSort(JSON.stringify(next), '');
|
||||
}, [onSort, sortInfo]);
|
||||
}, [onSort]);
|
||||
|
||||
// Native Drag State
|
||||
const draggingRef = useRef<{
|
||||
@@ -2832,7 +2911,10 @@ const DataGrid: React.FC<DataGridProps> = ({
|
||||
|
||||
const startX = e.clientX;
|
||||
|
||||
const currentWidth = columnWidths[key] || 200;
|
||||
const currentWidth = resolveDataTableColumnWidth({
|
||||
manualWidth: columnWidths[key],
|
||||
widthMode: dataTableColumnWidthMode,
|
||||
});
|
||||
|
||||
const containerLeft = containerRef.current?.getBoundingClientRect().left ?? 0;
|
||||
|
||||
@@ -2863,7 +2945,7 @@ const DataGrid: React.FC<DataGridProps> = ({
|
||||
|
||||
document.body.style.userSelect = 'none';
|
||||
|
||||
}, [columnWidths]);
|
||||
}, [columnWidths, dataTableColumnWidthMode]);
|
||||
|
||||
// 2. Drag Move (Global)
|
||||
const handleResizeMove = useCallback((e: MouseEvent) => {
|
||||
@@ -3307,7 +3389,10 @@ const DataGrid: React.FC<DataGridProps> = ({
|
||||
dataIndex: key,
|
||||
key: key,
|
||||
// 不使用 ellipsis,避免 Ant Design 的 Tooltip 展开行为
|
||||
width: columnWidths[key] || 200,
|
||||
width: resolveDataTableColumnWidth({
|
||||
manualWidth: columnWidths[key],
|
||||
widthMode: dataTableColumnWidthMode,
|
||||
}),
|
||||
sorter: onSort ? { multiple: displayColumnNames.indexOf(key) + 1 } : false,
|
||||
sortOrder: (sortInfo.find(s => s.columnKey === key && s.enabled !== false)?.order || null) as SortOrder | undefined,
|
||||
editable: canModifyData, // Only editable if table name known and not readonly
|
||||
@@ -3348,7 +3433,7 @@ const DataGrid: React.FC<DataGridProps> = ({
|
||||
},
|
||||
}),
|
||||
}));
|
||||
}, [displayColumnNames, columnWidths, sortInfo, handleResizeStart, canModifyData, onSort, renderColumnTitle]);
|
||||
}, [displayColumnNames, columnWidths, sortInfo, handleResizeStart, canModifyData, onSort, renderColumnTitle, dataTableColumnWidthMode]);
|
||||
|
||||
const mergedColumns = useMemo(() => columns.map((col): ColumnType<any> => {
|
||||
const dataIndex = String(col.dataIndex);
|
||||
@@ -3528,7 +3613,7 @@ const DataGrid: React.FC<DataGridProps> = ({
|
||||
};
|
||||
|
||||
const startTime = Date.now();
|
||||
const res = await ApplyChanges(config as any, dbName || '', tableName, { inserts, updates, deletes } as any);
|
||||
const res = await ApplyChanges(buildRpcConnectionConfig(config) as any, dbName || '', tableName, { inserts, updates, deletes } as any);
|
||||
const duration = Date.now() - startTime;
|
||||
|
||||
// Construct a pseudo-SQL representation for the log
|
||||
@@ -3581,24 +3666,87 @@ const DataGrid: React.FC<DataGridProps> = ({
|
||||
return [clickedRecord];
|
||||
}, []);
|
||||
|
||||
const handleCopyInsert = useCallback((record: any) => {
|
||||
const buildCopySqlBatchText = useCallback((mode: 'insert' | 'update' | 'delete', record: any): string | null => {
|
||||
if (!supportsCopyInsert) {
|
||||
void message.warning("当前数据源不支持复制为 INSERT,请使用 JSON/CSV/Markdown 复制。");
|
||||
return;
|
||||
void message.warning("当前数据源不支持复制 SQL,请使用 JSON/CSV/Markdown 复制。");
|
||||
return null;
|
||||
}
|
||||
const records = getTargets(record);
|
||||
// 使用 columnNames 保持表定义的字段顺序,而非 Object.keys() 的不确定顺序
|
||||
const orderedCols = columnNames.filter(c => c !== GONAVI_ROW_KEY);
|
||||
const sqlList = records.map((r: any) => {
|
||||
return buildCopyInsertSQL({
|
||||
if (mode === 'insert') {
|
||||
return records.map((row: any) => buildCopyInsertSQL({
|
||||
dbType,
|
||||
tableName,
|
||||
orderedCols,
|
||||
record: r,
|
||||
record: row,
|
||||
columnTypesByLowerName: columnTypeMapByLowerName,
|
||||
});
|
||||
})).join('\n\n');
|
||||
}
|
||||
|
||||
const sqlResults = records.map((row: any) => (
|
||||
mode === 'update'
|
||||
? buildCopyUpdateSQL({
|
||||
dbType,
|
||||
tableName,
|
||||
orderedCols,
|
||||
record: row,
|
||||
pkColumns,
|
||||
uniqueKeyGroups,
|
||||
allTableColumns: allTableColumnNames,
|
||||
columnTypesByLowerName: columnTypeMapByLowerName,
|
||||
})
|
||||
: buildCopyDeleteSQL({
|
||||
dbType,
|
||||
tableName,
|
||||
orderedCols,
|
||||
record: row,
|
||||
pkColumns,
|
||||
uniqueKeyGroups,
|
||||
allTableColumns: allTableColumnNames,
|
||||
columnTypesByLowerName: columnTypeMapByLowerName,
|
||||
})
|
||||
));
|
||||
const failedResult = sqlResults.find((result) => result.ok === false);
|
||||
if (failedResult && failedResult.ok === false) {
|
||||
void message.warning(failedResult.error);
|
||||
return null;
|
||||
}
|
||||
const sqlTexts: string[] = [];
|
||||
sqlResults.forEach((result) => {
|
||||
if (result.ok) {
|
||||
sqlTexts.push(result.sql);
|
||||
}
|
||||
});
|
||||
copyToClipboard(sqlList.join('\n')); }, [supportsCopyInsert, columnNames, getTargets, copyToClipboard, dbType, tableName, columnTypeMapByLowerName]);
|
||||
return sqlTexts.join('\n\n');
|
||||
}, [
|
||||
supportsCopyInsert,
|
||||
getTargets,
|
||||
columnNames,
|
||||
dbType,
|
||||
tableName,
|
||||
columnTypeMapByLowerName,
|
||||
pkColumns,
|
||||
uniqueKeyGroups,
|
||||
allTableColumnNames,
|
||||
]);
|
||||
|
||||
const handleCopyInsert = useCallback((record: any) => {
|
||||
const batchText = buildCopySqlBatchText('insert', record);
|
||||
if (!batchText) return;
|
||||
copyToClipboard(batchText);
|
||||
}, [buildCopySqlBatchText, copyToClipboard]);
|
||||
|
||||
const handleCopyUpdate = useCallback((record: any) => {
|
||||
const batchText = buildCopySqlBatchText('update', record);
|
||||
if (!batchText) return;
|
||||
copyToClipboard(batchText);
|
||||
}, [buildCopySqlBatchText, copyToClipboard]);
|
||||
|
||||
const handleCopyDelete = useCallback((record: any) => {
|
||||
const batchText = buildCopySqlBatchText('delete', record);
|
||||
if (!batchText) return;
|
||||
copyToClipboard(batchText);
|
||||
}, [buildCopySqlBatchText, copyToClipboard]);
|
||||
|
||||
const handleCopyJson = useCallback((record: any) => {
|
||||
const records = getTargets(record);
|
||||
@@ -3646,7 +3794,7 @@ const DataGrid: React.FC<DataGridProps> = ({
|
||||
if (!config) return;
|
||||
const hide = message.loading(`正在导出...`, 0);
|
||||
try {
|
||||
const res = await ExportQuery(config as any, dbName || '', sql, defaultName || 'export', format);
|
||||
const res = await ExportQuery(buildRpcConnectionConfig(config) as any, dbName || '', sql, defaultName || 'export', format);
|
||||
if (res.success) {
|
||||
void message.success("导出成功");
|
||||
} else if (res.message !== "已取消") {
|
||||
@@ -3764,7 +3912,7 @@ const DataGrid: React.FC<DataGridProps> = ({
|
||||
if (!config) return;
|
||||
const hide = message.loading(`正在导出全部数据...`, 0);
|
||||
try {
|
||||
const res = await ExportTable(config as any, dbName || '', tableName, format);
|
||||
const res = await ExportTable(buildRpcConnectionConfig(config) as any, dbName || '', tableName, format);
|
||||
if (res.success) {
|
||||
void message.success("导出成功");
|
||||
} else if (res.message !== "已取消") {
|
||||
@@ -3839,7 +3987,7 @@ const DataGrid: React.FC<DataGridProps> = ({
|
||||
const config = buildConnConfig();
|
||||
if (!config) return;
|
||||
|
||||
const res = await ImportData(config as any, dbName || '', tableName);
|
||||
const res = await ImportData(buildRpcConnectionConfig(config) as any, dbName || '', tableName);
|
||||
if (res.success && res.data && res.data.filePath) {
|
||||
setImportFilePath(res.data.filePath);
|
||||
setImportPreviewVisible(true);
|
||||
@@ -4049,6 +4197,8 @@ const DataGrid: React.FC<DataGridProps> = ({
|
||||
selectedRowKeysRef,
|
||||
displayDataRef,
|
||||
handleCopyInsert,
|
||||
handleCopyUpdate,
|
||||
handleCopyDelete,
|
||||
handleCopyJson,
|
||||
handleCopyCsv,
|
||||
handleExportSelected,
|
||||
@@ -4056,7 +4206,7 @@ const DataGrid: React.FC<DataGridProps> = ({
|
||||
tableName,
|
||||
enableRowContextMenu: false,
|
||||
supportsCopyInsert,
|
||||
}), [handleCopyCsv, handleCopyInsert, handleCopyJson, handleExportSelected, copyToClipboard, tableName, canModifyData, supportsCopyInsert]);
|
||||
}), [handleCopyCsv, handleCopyDelete, handleCopyInsert, handleCopyJson, handleCopyUpdate, handleExportSelected, copyToClipboard, tableName, supportsCopyInsert]);
|
||||
|
||||
const cellContextMenuValue = useMemo(() => ({
|
||||
showMenu: showCellContextMenu,
|
||||
@@ -4071,7 +4221,7 @@ const DataGrid: React.FC<DataGridProps> = ({
|
||||
|
||||
const rowPropsFactory = useCallback((record: any) => ({ record } as any), []);
|
||||
|
||||
const totalWidth = columns.reduce((sum: number, col: any) => sum + (Number(col.width) || 200), 0) + selectionColumnWidth;
|
||||
const totalWidth = columns.reduce((sum: number, col: any) => sum + (Number(col.width) || defaultColumnWidth), 0) + selectionColumnWidth;
|
||||
const useContextMenuRow = false;
|
||||
const tableScrollX = useMemo(() => {
|
||||
// rc-table 在 scroll.x 小于容器宽度时会把实际列宽按视口补齐。
|
||||
@@ -5473,21 +5623,53 @@ const DataGrid: React.FC<DataGridProps> = ({
|
||||
</>
|
||||
)}
|
||||
{supportsCopyInsert && (
|
||||
<div
|
||||
style={{
|
||||
padding: '8px 12px',
|
||||
cursor: 'pointer',
|
||||
transition: 'background 0.2s',
|
||||
}}
|
||||
onMouseEnter={(e) => e.currentTarget.style.background = darkMode ? '#303030' : '#f5f5f5'}
|
||||
onMouseLeave={(e) => e.currentTarget.style.background = 'transparent'}
|
||||
onClick={() => {
|
||||
if (cellContextMenu.record) handleCopyInsert(cellContextMenu.record);
|
||||
setCellContextMenu(prev => ({ ...prev, visible: false }));
|
||||
}}
|
||||
>
|
||||
复制为 INSERT
|
||||
</div>
|
||||
<>
|
||||
<div
|
||||
style={{
|
||||
padding: '8px 12px',
|
||||
cursor: 'pointer',
|
||||
transition: 'background 0.2s',
|
||||
}}
|
||||
onMouseEnter={(e) => e.currentTarget.style.background = darkMode ? '#303030' : '#f5f5f5'}
|
||||
onMouseLeave={(e) => e.currentTarget.style.background = 'transparent'}
|
||||
onClick={() => {
|
||||
if (cellContextMenu.record) handleCopyInsert(cellContextMenu.record);
|
||||
setCellContextMenu(prev => ({ ...prev, visible: false }));
|
||||
}}
|
||||
>
|
||||
复制为 INSERT
|
||||
</div>
|
||||
<div
|
||||
style={{
|
||||
padding: '8px 12px',
|
||||
cursor: 'pointer',
|
||||
transition: 'background 0.2s',
|
||||
}}
|
||||
onMouseEnter={(e) => e.currentTarget.style.background = darkMode ? '#303030' : '#f5f5f5'}
|
||||
onMouseLeave={(e) => e.currentTarget.style.background = 'transparent'}
|
||||
onClick={() => {
|
||||
if (cellContextMenu.record) handleCopyUpdate(cellContextMenu.record);
|
||||
setCellContextMenu(prev => ({ ...prev, visible: false }));
|
||||
}}
|
||||
>
|
||||
复制为 UPDATE
|
||||
</div>
|
||||
<div
|
||||
style={{
|
||||
padding: '8px 12px',
|
||||
cursor: 'pointer',
|
||||
transition: 'background 0.2s',
|
||||
}}
|
||||
onMouseEnter={(e) => e.currentTarget.style.background = darkMode ? '#303030' : '#f5f5f5'}
|
||||
onMouseLeave={(e) => e.currentTarget.style.background = 'transparent'}
|
||||
onClick={() => {
|
||||
if (cellContextMenu.record) handleCopyDelete(cellContextMenu.record);
|
||||
setCellContextMenu(prev => ({ ...prev, visible: false }));
|
||||
}}
|
||||
>
|
||||
复制为 DELETE
|
||||
</div>
|
||||
</>
|
||||
)}
|
||||
<div
|
||||
style={{
|
||||
|
||||
@@ -6,6 +6,7 @@ import { DBGetDatabases, DBGetTables, DataSync, DataSyncAnalyze, DataSyncPreview
|
||||
import { SavedConnection } from '../types';
|
||||
import { EventsOn } from '../../wailsjs/runtime/runtime';
|
||||
import { normalizeOpacityForPlatform, resolveAppearanceValues } from '../utils/appearance';
|
||||
import { buildRpcConnectionConfig } from '../utils/connectionRpcConfig';
|
||||
import { formatLocalDateTimeLiteral, normalizeTemporalLiteralText } from './dataGridCopyInsert';
|
||||
|
||||
const { Title, Text } = Typography;
|
||||
@@ -236,14 +237,11 @@ const DataSyncModal: React.FC<{ open: boolean; onClose: () => void }> = ({ open,
|
||||
const logBoxRef = useRef<HTMLDivElement>(null);
|
||||
const autoScrollRef = useRef(true);
|
||||
|
||||
const normalizeConnConfig = (conn: SavedConnection, database?: string) => ({
|
||||
...conn.config,
|
||||
port: Number((conn.config as any).port),
|
||||
password: conn.config.password || "",
|
||||
useSSH: conn.config.useSSH || false,
|
||||
ssh: conn.config.ssh || { host: "", port: 22, user: "", password: "", keyPath: "" },
|
||||
database: typeof database === 'string' ? database : (conn.config.database || ""),
|
||||
});
|
||||
const normalizeConnConfig = (conn: SavedConnection, database?: string) => (
|
||||
buildRpcConnectionConfig(conn.config, {
|
||||
database: typeof database === 'string' ? database : (conn.config.database || ''),
|
||||
})
|
||||
);
|
||||
|
||||
useEffect(() => {
|
||||
if (!open) return;
|
||||
@@ -542,22 +540,8 @@ const DataSyncModal: React.FC<{ open: boolean; onClose: () => void }> = ({ open,
|
||||
});
|
||||
|
||||
const config = {
|
||||
sourceConfig: {
|
||||
...sConn.config,
|
||||
port: Number((sConn.config as any).port),
|
||||
password: sConn.config.password || "",
|
||||
useSSH: sConn.config.useSSH || false,
|
||||
ssh: sConn.config.ssh || { host: "", port: 22, user: "", password: "", keyPath: "" },
|
||||
database: sourceDb,
|
||||
},
|
||||
targetConfig: {
|
||||
...tConn.config,
|
||||
port: Number((tConn.config as any).port),
|
||||
password: tConn.config.password || "",
|
||||
useSSH: tConn.config.useSSH || false,
|
||||
ssh: tConn.config.ssh || { host: "", port: 22, user: "", password: "", keyPath: "" },
|
||||
database: targetDb,
|
||||
},
|
||||
sourceConfig: normalizeConnConfig(sConn, sourceDb),
|
||||
targetConfig: normalizeConnConfig(tConn, targetDb),
|
||||
tables: selectedTables,
|
||||
content: syncContent,
|
||||
mode: syncMode,
|
||||
|
||||
@@ -9,6 +9,7 @@ import { buildMongoCountCommand, buildMongoFilter, buildMongoFindCommand, buildM
|
||||
import { buildOracleApproximateTotalSql, parseApproximateTableCountRow, resolveApproximateTableCountStrategy } from '../utils/approximateTableCount';
|
||||
import { getDataSourceCapabilities } from '../utils/dataSourceCapabilities';
|
||||
import { resolveDataViewerAutoFetchAction } from '../utils/dataViewerAutoFetch';
|
||||
import { buildRpcConnectionConfig } from '../utils/connectionRpcConfig';
|
||||
|
||||
type ViewerPaginationState = {
|
||||
current: number;
|
||||
@@ -319,7 +320,7 @@ const DataViewer: React.FC<{ tab: TabData; isActive?: boolean }> = ({ tab, isAct
|
||||
const countSeq = ++manualCountSeqRef.current;
|
||||
const countStart = Date.now();
|
||||
setPagination(prev => ({ ...prev, totalCountLoading: true, totalCountCancelled: false }));
|
||||
const countConfig: any = { ...(config as any), timeout: 120 };
|
||||
const countConfig = buildRpcConnectionConfig(config, { timeout: 120 });
|
||||
|
||||
try {
|
||||
const resCount = await DBQuery(countConfig as any, dbName, countSql);
|
||||
@@ -478,7 +479,7 @@ const DataViewer: React.FC<{ tab: TabData; isActive?: boolean }> = ({ tab, isAct
|
||||
const executeDataQuery = async (querySql: string, attemptLabel: string) => {
|
||||
const startTime = Date.now();
|
||||
try {
|
||||
const result = await DBQuery(config as any, dbName, querySql);
|
||||
const result = await DBQuery(buildRpcConnectionConfig(config) as any, dbName, querySql);
|
||||
addSqlLog({
|
||||
id: `log-${Date.now()}-data`,
|
||||
timestamp: Date.now(),
|
||||
@@ -514,7 +515,7 @@ const DataViewer: React.FC<{ tab: TabData; isActive?: boolean }> = ({ tab, isAct
|
||||
let safeSelect = duckdbSafeSelectCacheRef.current[cacheKey] || '';
|
||||
if (!safeSelect) {
|
||||
try {
|
||||
const resCols = await DBGetColumns(config as any, dbName, tableName);
|
||||
const resCols = await DBGetColumns(buildRpcConnectionConfig(config) as any, dbName, tableName);
|
||||
if (resCols?.success && Array.isArray(resCols.data)) {
|
||||
const columnDefs = resCols.data as ColumnDefinition[];
|
||||
const selectParts = columnDefs.map((col) => {
|
||||
@@ -567,7 +568,7 @@ const DataViewer: React.FC<{ tab: TabData; isActive?: boolean }> = ({ tab, isAct
|
||||
if (pkKeyRef.current !== pkKey) {
|
||||
pkKeyRef.current = pkKey;
|
||||
const pkSeq = ++pkSeqRef.current;
|
||||
DBGetColumns(config as any, dbName, tableName)
|
||||
DBGetColumns(buildRpcConnectionConfig(config) as any, dbName, tableName)
|
||||
.then((resCols: any) => {
|
||||
if (pkSeqRef.current !== pkSeq) return;
|
||||
if (pkKeyRef.current !== pkKey) return;
|
||||
@@ -680,7 +681,7 @@ const DataViewer: React.FC<{ tab: TabData; isActive?: boolean }> = ({ tab, isAct
|
||||
const countStart = Date.now();
|
||||
// 大表 COUNT(*) 可能非常慢,且在部分运行时环境下会影响后续操作响应;
|
||||
// DuckDB 大文件场景下该统计会显著拖慢翻页,已禁用后台 COUNT。
|
||||
const countConfig: any = { ...(config as any), timeout: 5 };
|
||||
const countConfig = buildRpcConnectionConfig(config, { timeout: 5 });
|
||||
|
||||
DBQuery(countConfig, dbName, countSql)
|
||||
.then((resCount: any) => {
|
||||
@@ -734,7 +735,7 @@ const DataViewer: React.FC<{ tab: TabData; isActive?: boolean }> = ({ tab, isAct
|
||||
const { schemaName, pureTableName } = resolveDuckDBSchemaAndTable(dbName, tableName);
|
||||
const escapedSchema = escapeSQLLiteral(schemaName);
|
||||
const escapedTable = escapeSQLLiteral(pureTableName);
|
||||
const approxConfig: any = { ...(config as any), timeout: 3 };
|
||||
const approxConfig = buildRpcConnectionConfig(config, { timeout: 3 });
|
||||
const approxSqlCandidates = [
|
||||
`SELECT estimated_size AS approx_total FROM duckdb_tables() WHERE schema_name='${escapedSchema}' AND table_name='${escapedTable}' LIMIT 1`,
|
||||
`SELECT estimated_size AS approx_total FROM duckdb_tables() WHERE table_name='${escapedTable}' ORDER BY CASE WHEN schema_name='${escapedSchema}' THEN 0 ELSE 1 END LIMIT 1`,
|
||||
@@ -775,7 +776,7 @@ const DataViewer: React.FC<{ tab: TabData; isActive?: boolean }> = ({ tab, isAct
|
||||
if (approximateCountStrategy === 'oracle-num-rows' && oracleApproxKeyRef.current !== countKey) {
|
||||
oracleApproxKeyRef.current = countKey;
|
||||
const approxSeq = ++oracleApproxSeqRef.current;
|
||||
const approxConfig: any = { ...(config as any), timeout: 3 };
|
||||
const approxConfig = buildRpcConnectionConfig(config, { timeout: 3 });
|
||||
const approxSql = buildOracleApproximateTotalSql({ dbName, tableName });
|
||||
|
||||
DBQuery(approxConfig as any, dbName, approxSql)
|
||||
|
||||
@@ -4,6 +4,7 @@ import { Spin, Alert } from 'antd';
|
||||
import { TabData } from '../types';
|
||||
import { useStore } from '../store';
|
||||
import { DBQuery } from '../../wailsjs/go/app/App';
|
||||
import { buildRpcConnectionConfig } from '../utils/connectionRpcConfig';
|
||||
|
||||
interface DefinitionViewerProps {
|
||||
tab: TabData;
|
||||
@@ -201,7 +202,7 @@ const DefinitionViewer: React.FC<DefinitionViewerProps> = ({ tab }) => {
|
||||
const sql = String(query || '').trim();
|
||||
if (!sql) continue;
|
||||
try {
|
||||
const result = await DBQuery(config as any, dbName, sql);
|
||||
const result = await DBQuery(buildRpcConnectionConfig(config) as any, dbName, sql);
|
||||
if (!result.success || !Array.isArray(result.data)) {
|
||||
lastMessage = result.message || lastMessage;
|
||||
continue;
|
||||
@@ -227,7 +228,7 @@ const DefinitionViewer: React.FC<DefinitionViewerProps> = ({ tab }) => {
|
||||
];
|
||||
for (const query of candidates) {
|
||||
try {
|
||||
const result = await DBQuery(config as any, dbName, query);
|
||||
const result = await DBQuery(buildRpcConnectionConfig(config) as any, dbName, query);
|
||||
if (!result.success || !Array.isArray(result.data) || result.data.length === 0) {
|
||||
continue;
|
||||
}
|
||||
|
||||
@@ -757,6 +757,16 @@ const DriverManagerModal: React.FC<{ open: boolean; onClose: () => void; onOpenG
|
||||
};
|
||||
}, [appendOperationLog, open]);
|
||||
|
||||
const resolveLocalImportVersion = useCallback((row: DriverStatusRow) => {
|
||||
const options = versionMap[row.type] || [];
|
||||
const selectedKey = selectedVersionMap[row.type];
|
||||
const selectedOption =
|
||||
options.find((item) => buildVersionOptionKey(item) === selectedKey) ||
|
||||
options.find((item) => item.recommended) ||
|
||||
options[0];
|
||||
return selectedOption?.version || row.pinnedVersion || '';
|
||||
}, [selectedVersionMap, versionMap]);
|
||||
|
||||
const installDriver = useCallback(async (row: DriverStatusRow) => {
|
||||
setActionState({ driverType: row.type, kind: 'install' });
|
||||
setProgressMap((prev) => ({
|
||||
@@ -820,9 +830,11 @@ const DriverManagerModal: React.FC<{ open: boolean; onClose: () => void; onOpenG
|
||||
percent: 0,
|
||||
},
|
||||
}));
|
||||
appendOperationLog(row.type, `[START] 开始本地导入(${sourceLabel}):${pathText}`);
|
||||
const selectedVersion = resolveLocalImportVersion(row);
|
||||
const versionTip = selectedVersion ? `(${selectedVersion})` : '';
|
||||
appendOperationLog(row.type, `[START] 开始本地导入${versionTip}(${sourceLabel}):${pathText}`);
|
||||
try {
|
||||
const result = await InstallLocalDriverPackage(row.type, pathText, downloadDir);
|
||||
const result = await InstallLocalDriverPackage(row.type, pathText, downloadDir, selectedVersion);
|
||||
if (!result?.success) {
|
||||
const errText = result?.message || `导入 ${row.name} 本地驱动包失败`;
|
||||
appendOperationLog(row.type, `[ERROR] ${errText}`);
|
||||
@@ -831,9 +843,9 @@ const DriverManagerModal: React.FC<{ open: boolean; onClose: () => void; onOpenG
|
||||
}
|
||||
return false;
|
||||
}
|
||||
appendOperationLog(row.type, '[DONE] 本地导入安装完成');
|
||||
appendOperationLog(row.type, `[DONE] 本地导入安装完成 ${versionTip}`.trim());
|
||||
if (!options?.silentToast) {
|
||||
message.success(`${row.name} 本地驱动包已安装启用`);
|
||||
message.success(`${row.name}${versionTip} 本地驱动包已安装启用`);
|
||||
}
|
||||
if (!options?.skipRefresh) {
|
||||
await refreshStatus(false);
|
||||
@@ -842,7 +854,7 @@ const DriverManagerModal: React.FC<{ open: boolean; onClose: () => void; onOpenG
|
||||
} finally {
|
||||
setActionState({ driverType: '', kind: '' });
|
||||
}
|
||||
}, [appendOperationLog, downloadDir, refreshStatus]);
|
||||
}, [appendOperationLog, downloadDir, refreshStatus, resolveLocalImportVersion]);
|
||||
|
||||
const installDriverFromLocalFile = useCallback(async (row: DriverStatusRow) => {
|
||||
const fileRes = await SelectDriverPackageFile(downloadDir);
|
||||
@@ -1067,29 +1079,35 @@ const DriverManagerModal: React.FC<{ open: boolean; onClose: () => void; onOpenG
|
||||
const options = versionMap[row.type] || [];
|
||||
const selectedKey = selectedVersionMap[row.type];
|
||||
const selectOptions = buildVersionSelectOptions(options);
|
||||
const mongoHint = row.type === 'mongodb'
|
||||
? '当前仅支持 MongoDB 1.17.x 和 2.x;更老 1.x 暂不提供安装。'
|
||||
: '';
|
||||
return (
|
||||
<Select
|
||||
size="small"
|
||||
style={{ width: '100%' }}
|
||||
loading={!!versionLoadingMap[row.type]}
|
||||
disabled={actionState.driverType === row.type}
|
||||
placeholder={options.length > 0 ? '选择驱动版本' : '点击展开加载版本'}
|
||||
value={selectedKey}
|
||||
options={selectOptions as any}
|
||||
onOpenChange={(open) => {
|
||||
if (open && options.length === 0 && !versionLoadingMap[row.type]) {
|
||||
void loadVersionOptions(row, true);
|
||||
return;
|
||||
}
|
||||
if (open && selectedKey) {
|
||||
void loadVersionPackageSize(row, selectedKey);
|
||||
}
|
||||
}}
|
||||
onChange={(value) => {
|
||||
setSelectedVersionMap((prev) => ({ ...prev, [row.type]: value }));
|
||||
void loadVersionPackageSize(row, value);
|
||||
}}
|
||||
/>
|
||||
<div style={{ display: 'grid', gap: 4 }}>
|
||||
<Select
|
||||
size="small"
|
||||
style={{ width: '100%' }}
|
||||
loading={!!versionLoadingMap[row.type]}
|
||||
disabled={actionState.driverType === row.type}
|
||||
placeholder={options.length > 0 ? '选择驱动版本' : '点击展开加载版本'}
|
||||
value={selectedKey}
|
||||
options={selectOptions as any}
|
||||
onOpenChange={(open) => {
|
||||
if (open && options.length === 0 && !versionLoadingMap[row.type]) {
|
||||
void loadVersionOptions(row, true);
|
||||
return;
|
||||
}
|
||||
if (open && selectedKey) {
|
||||
void loadVersionPackageSize(row, selectedKey);
|
||||
}
|
||||
}}
|
||||
onChange={(value) => {
|
||||
setSelectedVersionMap((prev) => ({ ...prev, [row.type]: value }));
|
||||
void loadVersionPackageSize(row, value);
|
||||
}}
|
||||
/>
|
||||
{mongoHint ? <Text type="secondary" style={{ fontSize: 12 }}>{mongoHint}</Text> : null}
|
||||
</div>
|
||||
);
|
||||
},
|
||||
},
|
||||
|
||||
@@ -5,6 +5,7 @@ import { DBQuery, DBGetTables, DBGetAllColumns } from '../../wailsjs/go/app/App'
|
||||
import { quoteIdentPart, escapeLiteral } from '../utils/sql';
|
||||
import { useStore } from '../store';
|
||||
import { buildOverlayWorkbenchTheme } from '../utils/overlayWorkbenchTheme';
|
||||
import { buildRpcConnectionConfig } from '../utils/connectionRpcConfig';
|
||||
|
||||
interface FindInDatabaseModalProps {
|
||||
open: boolean;
|
||||
@@ -106,7 +107,7 @@ const FindInDatabaseModal: React.FC<FindInDatabaseModalProps> = ({ open, onClose
|
||||
|
||||
try {
|
||||
// 1. 获取所有表
|
||||
const tablesRes = await DBGetTables(config as any, dbName);
|
||||
const tablesRes = await DBGetTables(buildRpcConnectionConfig(config) as any, dbName);
|
||||
if (!tablesRes.success) {
|
||||
message.error('获取表列表失败: ' + tablesRes.message);
|
||||
setSearching(false);
|
||||
@@ -124,7 +125,7 @@ const FindInDatabaseModal: React.FC<FindInDatabaseModalProps> = ({ open, onClose
|
||||
setProgress({ current: 0, total: tableNames.length, tableName: '' });
|
||||
|
||||
// 2. 获取所有列信息(返回 any[],含 tableName/name/type 字段)
|
||||
const allColsRes = await DBGetAllColumns(config as any, dbName);
|
||||
const allColsRes = await DBGetAllColumns(buildRpcConnectionConfig(config) as any, dbName);
|
||||
const allColumns: any[] = (allColsRes?.success && Array.isArray(allColsRes.data)) ? allColsRes.data : [];
|
||||
|
||||
// 按表名分组
|
||||
@@ -166,7 +167,7 @@ const FindInDatabaseModal: React.FC<FindInDatabaseModalProps> = ({ open, onClose
|
||||
const sql = buildLimitedSelectSQL(dbType, baseSql, MAX_MATCH_ROWS_PER_TABLE);
|
||||
|
||||
try {
|
||||
const res = await DBQuery(config as any, dbName, sql);
|
||||
const res = await DBQuery(buildRpcConnectionConfig(config) as any, dbName, sql);
|
||||
if (res.success && Array.isArray(res.data) && res.data.length > 0) {
|
||||
// 检查哪些列实际匹配了
|
||||
const matchedCols = new Set<string>();
|
||||
|
||||
@@ -4,6 +4,7 @@ import { CheckCircleOutlined, CloseCircleOutlined } from '@ant-design/icons';
|
||||
import { PreviewImportFile, ImportDataWithProgress } from '../../wailsjs/go/app/App';
|
||||
import { EventsOn, EventsOff } from '../../wailsjs/runtime/runtime';
|
||||
import { useStore } from '../store';
|
||||
import { buildRpcConnectionConfig } from '../utils/connectionRpcConfig';
|
||||
|
||||
interface ImportPreviewModalProps {
|
||||
visible: boolean;
|
||||
@@ -107,7 +108,7 @@ const ImportPreviewModal: React.FC<ImportPreviewModalProps> = ({
|
||||
ssh: conn.config.ssh || { host: '', port: 22, user: '', password: '', keyPath: '' }
|
||||
};
|
||||
|
||||
const res = await ImportDataWithProgress(config as any, dbName, tableName, filePath);
|
||||
const res = await ImportDataWithProgress(buildRpcConnectionConfig(config) as any, dbName, tableName, filePath);
|
||||
|
||||
if (res.success && res.data) {
|
||||
setImportResult(res.data);
|
||||
|
||||
@@ -11,6 +11,7 @@ import DataGrid, { GONAVI_ROW_KEY } from './DataGrid';
|
||||
import { getDataSourceCapabilities } from '../utils/dataSourceCapabilities';
|
||||
import { convertMongoShellToJsonCommand } from '../utils/mongodb';
|
||||
import { getShortcutDisplay, isEditableElement, isShortcutMatch } from '../utils/shortcuts';
|
||||
import { buildRpcConnectionConfig } from '../utils/connectionRpcConfig';
|
||||
|
||||
const SQL_KEYWORDS = [
|
||||
'SELECT', 'FROM', 'WHERE', 'LIMIT', 'INSERT', 'UPDATE', 'DELETE', 'JOIN', 'LEFT', 'RIGHT',
|
||||
@@ -336,7 +337,7 @@ const QueryEditor: React.FC<{ tab: TabData; isActive?: boolean }> = ({ tab, isAc
|
||||
ssh: conn.config.ssh || { host: "", port: 22, user: "", password: "", keyPath: "" }
|
||||
};
|
||||
|
||||
const res = await DBGetDatabases(config as any);
|
||||
const res = await DBGetDatabases(buildRpcConnectionConfig(config) as any);
|
||||
if (res.success && Array.isArray(res.data)) {
|
||||
let dbs = res.data.map((row: any) => row.Database || row.database);
|
||||
|
||||
@@ -392,7 +393,7 @@ const QueryEditor: React.FC<{ tab: TabData; isActive?: boolean }> = ({ tab, isAc
|
||||
|
||||
for (const dbName of visibleDbs) {
|
||||
// 获取表
|
||||
const resTables = await DBGetTables(config as any, dbName);
|
||||
const resTables = await DBGetTables(buildRpcConnectionConfig(config) as any, dbName);
|
||||
if (resTables.success && Array.isArray(resTables.data)) {
|
||||
const tableNames = resTables.data.map((row: any) => Object.values(row)[0] as string);
|
||||
tableNames.forEach((tableName: string) => {
|
||||
@@ -401,7 +402,7 @@ const QueryEditor: React.FC<{ tab: TabData; isActive?: boolean }> = ({ tab, isAc
|
||||
}
|
||||
|
||||
// 获取列 (所有数据库类型都支持 DBGetAllColumns)
|
||||
const resCols = await DBGetAllColumns(config as any, dbName);
|
||||
const resCols = await DBGetAllColumns(buildRpcConnectionConfig(config) as any, dbName);
|
||||
if (resCols.success && Array.isArray(resCols.data)) {
|
||||
resCols.data.forEach((col: any) => {
|
||||
allColumns.push({
|
||||
@@ -577,7 +578,7 @@ const QueryEditor: React.FC<{ tab: TabData; isActive?: boolean }> = ({ tab, isAc
|
||||
const config = buildConnConfig();
|
||||
if (!config) return [] as ColumnDefinition[];
|
||||
|
||||
const res = await DBGetColumns(config as any, dbName, tableIdent);
|
||||
const res = await DBGetColumns(buildRpcConnectionConfig(config) as any, dbName, tableIdent);
|
||||
if (res?.success && Array.isArray(res.data)) {
|
||||
const cols = res.data as ColumnDefinition[];
|
||||
sharedColumnsCacheData[key] = cols;
|
||||
@@ -1555,7 +1556,7 @@ const QueryEditor: React.FC<{ tab: TabData; isActive?: boolean }> = ({ tab, isAc
|
||||
} catch {
|
||||
queryId = 'reload-' + Date.now();
|
||||
}
|
||||
const res = await DBQueryMulti(config as any, currentDb, sql, queryId);
|
||||
const res = await DBQueryMulti(buildRpcConnectionConfig(config) as any, currentDb, sql, queryId);
|
||||
if (!res?.success) {
|
||||
message.error('刷新失败: ' + (res?.message || '未知错误'));
|
||||
return;
|
||||
@@ -1643,7 +1644,7 @@ const QueryEditor: React.FC<{ tab: TabData; isActive?: boolean }> = ({ tab, isAc
|
||||
|
||||
try {
|
||||
const rawSQL = getSelectedSQL() || currentQuery;
|
||||
const dbType = String((config as any).type || 'mysql');
|
||||
const dbType = String((buildRpcConnectionConfig(config) as any).type || 'mysql');
|
||||
const normalizedDbType = dbType.trim().toLowerCase();
|
||||
const normalizedRawSQL = String(rawSQL || '').replace(/;/g, ';');
|
||||
|
||||
@@ -1694,7 +1695,7 @@ const QueryEditor: React.FC<{ tab: TabData; isActive?: boolean }> = ({ tab, isAc
|
||||
}
|
||||
setQueryId(queryId);
|
||||
|
||||
const res = await DBQueryWithCancel(config as any, currentDb, executedSql, queryId);
|
||||
const res = await DBQueryWithCancel(buildRpcConnectionConfig(config) as any, currentDb, executedSql, queryId);
|
||||
const duration = Date.now() - startTime;
|
||||
addSqlLog({
|
||||
id: `log-${Date.now()}-query-${idx + 1}`,
|
||||
@@ -1795,7 +1796,7 @@ const QueryEditor: React.FC<{ tab: TabData; isActive?: boolean }> = ({ tab, isAc
|
||||
}
|
||||
setQueryId(queryId);
|
||||
|
||||
const res = await DBQueryMulti(config as any, currentDb, fullSQL, queryId);
|
||||
const res = await DBQueryMulti(buildRpcConnectionConfig(config) as any, currentDb, fullSQL, queryId);
|
||||
const duration = Date.now() - startTime;
|
||||
|
||||
addSqlLog({
|
||||
@@ -1921,7 +1922,7 @@ const QueryEditor: React.FC<{ tab: TabData; isActive?: boolean }> = ({ tab, isAc
|
||||
setActiveResultKey(nextResultSets[0]?.key || '');
|
||||
|
||||
pendingPk.forEach(({ resultKey, tableName }) => {
|
||||
DBGetColumns(config as any, currentDb, tableName)
|
||||
DBGetColumns(buildRpcConnectionConfig(config) as any, currentDb, tableName)
|
||||
.then((resCols: any) => {
|
||||
if (runSeqRef.current !== runSeq) return;
|
||||
if (!resCols?.success) {
|
||||
|
||||
@@ -2,6 +2,7 @@ import React, { useState, useCallback, useRef, useEffect } from 'react';
|
||||
import { Button, Space, message } from 'antd';
|
||||
import { PlayCircleOutlined, ClearOutlined } from '@ant-design/icons';
|
||||
import { useStore } from '../store';
|
||||
import { buildRpcConnectionConfig } from '../utils/connectionRpcConfig';
|
||||
import Editor, { OnMount } from '@monaco-editor/react';
|
||||
|
||||
interface RedisCommandEditorProps {
|
||||
@@ -201,7 +202,7 @@ const RedisCommandEditor: React.FC<RedisCommandEditorProps> = ({ connectionId, r
|
||||
for (const cmd of commands) {
|
||||
const start = Date.now();
|
||||
try {
|
||||
const res = await (window as any).go.app.App.RedisExecuteCommand(config, cmd);
|
||||
const res = await (window as any).go.app.App.RedisExecuteCommand(buildRpcConnectionConfig(config), cmd);
|
||||
newResults.push({
|
||||
command: cmd,
|
||||
result: res.success ? res.data : null,
|
||||
|
||||
@@ -12,6 +12,7 @@ import {
|
||||
} from '@ant-design/icons';
|
||||
import { useStore } from '../store';
|
||||
import { SavedConnection } from '../types';
|
||||
import { buildRpcConnectionConfig } from '../utils/connectionRpcConfig';
|
||||
import { RedisGetServerInfo } from '../../wailsjs/go/app/App';
|
||||
|
||||
const { Title, Text } = Typography;
|
||||
@@ -61,7 +62,7 @@ const RedisMonitor: React.FC<RedisMonitorProps> = ({ connectionId, redisDB }) =>
|
||||
if (!connection) return;
|
||||
|
||||
try {
|
||||
const config = { ...connection.config, redisDB } as any;
|
||||
const config = buildRpcConnectionConfig(connection.config, { redisDB });
|
||||
const res = await RedisGetServerInfo(config);
|
||||
|
||||
if (!mountedRef.current) return;
|
||||
|
||||
@@ -7,6 +7,7 @@ import { RedisKeyInfo, RedisValue, StreamEntry } from '../types';
|
||||
import Editor from '@monaco-editor/react';
|
||||
import type { DataNode } from 'antd/es/tree';
|
||||
import { blurToFilter, normalizeBlurForPlatform, normalizeOpacityForPlatform, resolveAppearanceValues } from '../utils/appearance';
|
||||
import { buildRpcConnectionConfig } from '../utils/connectionRpcConfig';
|
||||
import {
|
||||
applyRenamedRedisKeyState,
|
||||
applyTreeNodeCheck,
|
||||
@@ -429,7 +430,7 @@ const RedisViewer: React.FC<RedisViewerProps> = ({ connectionId, redisDB }) => {
|
||||
|
||||
setLoading(true);
|
||||
try {
|
||||
const res = await (window as any).go.app.App.RedisScanKeys(config, normalizedPattern, fromCursor, effectiveTargetCount);
|
||||
const res = await (window as any).go.app.App.RedisScanKeys(buildRpcConnectionConfig(config), normalizedPattern, fromCursor, effectiveTargetCount);
|
||||
if (requestId !== latestLoadRequestIdRef.current) {
|
||||
return;
|
||||
}
|
||||
@@ -508,7 +509,7 @@ const RedisViewer: React.FC<RedisViewerProps> = ({ connectionId, redisDB }) => {
|
||||
|
||||
setValueLoading(true);
|
||||
try {
|
||||
const res = await (window as any).go.app.App.RedisGetValue(config, key);
|
||||
const res = await (window as any).go.app.App.RedisGetValue(buildRpcConnectionConfig(config), key);
|
||||
if (res.success) {
|
||||
setKeyValue(res.data);
|
||||
setSelectedKey(key);
|
||||
@@ -539,7 +540,7 @@ const RedisViewer: React.FC<RedisViewerProps> = ({ connectionId, redisDB }) => {
|
||||
if (!config) return;
|
||||
|
||||
try {
|
||||
const res = await (window as any).go.app.App.RedisDeleteKeys(config, keysToDelete);
|
||||
const res = await (window as any).go.app.App.RedisDeleteKeys(buildRpcConnectionConfig(config), keysToDelete);
|
||||
if (res.success) {
|
||||
message.success(`已删除 ${res.data.deleted} 个 Key`);
|
||||
setKeys(prev => prev.filter(k => !keysToDelete.includes(k.key)));
|
||||
@@ -567,7 +568,7 @@ const RedisViewer: React.FC<RedisViewerProps> = ({ connectionId, redisDB }) => {
|
||||
|
||||
try {
|
||||
const values = await ttlForm.validateFields();
|
||||
const res = await (window as any).go.app.App.RedisSetTTL(config, selectedKey, values.ttl);
|
||||
const res = await (window as any).go.app.App.RedisSetTTL(buildRpcConnectionConfig(config), selectedKey, values.ttl);
|
||||
if (res.success) {
|
||||
message.success('TTL 设置成功');
|
||||
setTtlModalOpen(false);
|
||||
@@ -586,7 +587,7 @@ const RedisViewer: React.FC<RedisViewerProps> = ({ connectionId, redisDB }) => {
|
||||
if (!config || !selectedKey) return;
|
||||
|
||||
try {
|
||||
const res = await (window as any).go.app.App.RedisSetString(config, selectedKey, editValue, keyValue?.ttl || -1);
|
||||
const res = await (window as any).go.app.App.RedisSetString(buildRpcConnectionConfig(config), selectedKey, editValue, keyValue?.ttl || -1);
|
||||
if (res.success) {
|
||||
message.success('保存成功');
|
||||
setEditModalOpen(false);
|
||||
@@ -605,7 +606,7 @@ const RedisViewer: React.FC<RedisViewerProps> = ({ connectionId, redisDB }) => {
|
||||
|
||||
try {
|
||||
const values = await newKeyForm.validateFields();
|
||||
const res = await (window as any).go.app.App.RedisSetString(config, values.key, values.value, values.ttl || -1);
|
||||
const res = await (window as any).go.app.App.RedisSetString(buildRpcConnectionConfig(config), values.key, values.value, values.ttl || -1);
|
||||
if (res.success) {
|
||||
message.success('创建成功');
|
||||
setNewKeyModalOpen(false);
|
||||
@@ -642,7 +643,7 @@ const RedisViewer: React.FC<RedisViewerProps> = ({ connectionId, redisDB }) => {
|
||||
return;
|
||||
}
|
||||
|
||||
const existsRes = await (window as any).go.app.App.RedisKeyExists(config, nextKey);
|
||||
const existsRes = await (window as any).go.app.App.RedisKeyExists(buildRpcConnectionConfig(config), nextKey);
|
||||
if (!existsRes?.success) {
|
||||
message.error('校验目标 Key 失败: ' + (existsRes?.message || '未知错误'));
|
||||
return;
|
||||
@@ -652,7 +653,7 @@ const RedisViewer: React.FC<RedisViewerProps> = ({ connectionId, redisDB }) => {
|
||||
return;
|
||||
}
|
||||
|
||||
const res = await (window as any).go.app.App.RedisRenameKey(config, renameTargetKey, nextKey);
|
||||
const res = await (window as any).go.app.App.RedisRenameKey(buildRpcConnectionConfig(config), renameTargetKey, nextKey);
|
||||
if (res.success) {
|
||||
const nextState = applyRenamedRedisKeyState(
|
||||
{
|
||||
@@ -1177,7 +1178,7 @@ const RedisViewer: React.FC<RedisViewerProps> = ({ connectionId, redisDB }) => {
|
||||
const config = getConfig();
|
||||
if (!config) return;
|
||||
try {
|
||||
const res = await (window as any).go.app.App.RedisSetHashField(config, selectedKey, field, newValue);
|
||||
const res = await (window as any).go.app.App.RedisSetHashField(buildRpcConnectionConfig(config), selectedKey, field, newValue);
|
||||
if (res.success) {
|
||||
message.success('修改成功');
|
||||
loadKeyValue(selectedKey);
|
||||
@@ -1193,7 +1194,7 @@ const RedisViewer: React.FC<RedisViewerProps> = ({ connectionId, redisDB }) => {
|
||||
const config = getConfig();
|
||||
if (!config) return;
|
||||
try {
|
||||
const res = await (window as any).go.app.App.RedisDeleteHashField(config, selectedKey, field);
|
||||
const res = await (window as any).go.app.App.RedisDeleteHashField(buildRpcConnectionConfig(config), selectedKey, field);
|
||||
if (res.success) {
|
||||
message.success('删除成功');
|
||||
loadKeyValue(selectedKey);
|
||||
@@ -1338,7 +1339,7 @@ const RedisViewer: React.FC<RedisViewerProps> = ({ connectionId, redisDB }) => {
|
||||
const config = getConfig();
|
||||
if (!config) return;
|
||||
try {
|
||||
const res = await (window as any).go.app.App.RedisListSet(config, selectedKey, index, newValue);
|
||||
const res = await (window as any).go.app.App.RedisListSet(buildRpcConnectionConfig(config), selectedKey, index, newValue);
|
||||
if (res.success) {
|
||||
message.success('修改成功');
|
||||
loadKeyValue(selectedKey);
|
||||
@@ -1354,7 +1355,7 @@ const RedisViewer: React.FC<RedisViewerProps> = ({ connectionId, redisDB }) => {
|
||||
const config = getConfig();
|
||||
if (!config) return;
|
||||
try {
|
||||
const res = await (window as any).go.app.App.RedisListPush(config, selectedKey, { values: [value], position });
|
||||
const res = await (window as any).go.app.App.RedisListPush(buildRpcConnectionConfig(config), selectedKey, { values: [value], position });
|
||||
if (res.success) {
|
||||
message.success('添加成功');
|
||||
loadKeyValue(selectedKey);
|
||||
@@ -1508,7 +1509,7 @@ const RedisViewer: React.FC<RedisViewerProps> = ({ connectionId, redisDB }) => {
|
||||
const config = getConfig();
|
||||
if (!config) return;
|
||||
try {
|
||||
const res = await (window as any).go.app.App.RedisSetAdd(config, selectedKey, [member]);
|
||||
const res = await (window as any).go.app.App.RedisSetAdd(buildRpcConnectionConfig(config), selectedKey, [member]);
|
||||
if (res.success) {
|
||||
message.success('添加成功');
|
||||
loadKeyValue(selectedKey);
|
||||
@@ -1524,7 +1525,7 @@ const RedisViewer: React.FC<RedisViewerProps> = ({ connectionId, redisDB }) => {
|
||||
const config = getConfig();
|
||||
if (!config) return;
|
||||
try {
|
||||
const res = await (window as any).go.app.App.RedisSetRemove(config, selectedKey, [member]);
|
||||
const res = await (window as any).go.app.App.RedisSetRemove(buildRpcConnectionConfig(config), selectedKey, [member]);
|
||||
if (res.success) {
|
||||
message.success('删除成功');
|
||||
loadKeyValue(selectedKey);
|
||||
@@ -1645,7 +1646,7 @@ const RedisViewer: React.FC<RedisViewerProps> = ({ connectionId, redisDB }) => {
|
||||
const config = getConfig();
|
||||
if (!config) return;
|
||||
try {
|
||||
const res = await (window as any).go.app.App.RedisZSetAdd(config, selectedKey, [{ member, score }]);
|
||||
const res = await (window as any).go.app.App.RedisZSetAdd(buildRpcConnectionConfig(config), selectedKey, [{ member, score }]);
|
||||
if (res.success) {
|
||||
message.success('添加成功');
|
||||
loadKeyValue(selectedKey);
|
||||
@@ -1661,7 +1662,7 @@ const RedisViewer: React.FC<RedisViewerProps> = ({ connectionId, redisDB }) => {
|
||||
const config = getConfig();
|
||||
if (!config) return;
|
||||
try {
|
||||
const res = await (window as any).go.app.App.RedisZSetRemove(config, selectedKey, [member]);
|
||||
const res = await (window as any).go.app.App.RedisZSetRemove(buildRpcConnectionConfig(config), selectedKey, [member]);
|
||||
if (res.success) {
|
||||
message.success('删除成功');
|
||||
loadKeyValue(selectedKey);
|
||||
@@ -1841,7 +1842,7 @@ const RedisViewer: React.FC<RedisViewerProps> = ({ connectionId, redisDB }) => {
|
||||
}
|
||||
|
||||
try {
|
||||
const res = await (window as any).go.app.App.RedisStreamAdd(config, selectedKey, fieldMap, id || '*');
|
||||
const res = await (window as any).go.app.App.RedisStreamAdd(buildRpcConnectionConfig(config), selectedKey, fieldMap, id || '*');
|
||||
if (res.success) {
|
||||
const newID = res.data?.id ? ` (${res.data.id})` : '';
|
||||
message.success(`添加成功${newID}`);
|
||||
@@ -1859,7 +1860,7 @@ const RedisViewer: React.FC<RedisViewerProps> = ({ connectionId, redisDB }) => {
|
||||
if (!config) return;
|
||||
|
||||
try {
|
||||
const res = await (window as any).go.app.App.RedisStreamDelete(config, selectedKey, [id]);
|
||||
const res = await (window as any).go.app.App.RedisStreamDelete(buildRpcConnectionConfig(config), selectedKey, [id]);
|
||||
if (res.success) {
|
||||
const deleted = Number(res.data?.deleted ?? 0);
|
||||
if (deleted > 0) {
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
import React, { useEffect, useState, useMemo, useRef } from 'react';
|
||||
import React, { useEffect, useState, useMemo, useRef } from 'react';
|
||||
import { Tree, message, Dropdown, MenuProps, Input, Button, Modal, Form, Badge, Checkbox, Space, Select, Popover, Tooltip, Progress } from 'antd';
|
||||
import {
|
||||
DatabaseOutlined,
|
||||
@@ -42,6 +42,7 @@ import { getDbIcon } from './DatabaseIcons';
|
||||
import { EventsOn } from '../../wailsjs/runtime/runtime';
|
||||
import { normalizeOpacityForPlatform, resolveAppearanceValues } from '../utils/appearance';
|
||||
import FindInDatabaseModal from './FindInDatabaseModal';
|
||||
import { buildRpcConnectionConfig } from '../utils/connectionRpcConfig';
|
||||
|
||||
const { Search } = Input;
|
||||
|
||||
@@ -366,129 +367,25 @@ const Sidebar: React.FC<{ onEditConnection?: (conn: SavedConnection) => void }>
|
||||
});
|
||||
}, [connections, connectionTags]);
|
||||
|
||||
const buildDuplicateConnectionName = (rawName: string): string => {
|
||||
const baseName = String(rawName || '').trim() || '连接';
|
||||
const suffix = ' - 副本';
|
||||
const usedNames = new Set(connections.map(conn => String(conn.name || '').trim()));
|
||||
let candidate = `${baseName}${suffix}`;
|
||||
let counter = 2;
|
||||
while (usedNames.has(candidate)) {
|
||||
candidate = `${baseName}${suffix} ${counter}`;
|
||||
counter += 1;
|
||||
}
|
||||
return candidate;
|
||||
};
|
||||
const handleDuplicateConnection = async (conn: SavedConnection) => {
|
||||
if (!conn?.id) return;
|
||||
|
||||
const backendApp = (window as any).go?.app?.App;
|
||||
if (typeof backendApp?.DuplicateConnection !== 'function') {
|
||||
message.error('复制连接失败:后端接口不可用');
|
||||
return;
|
||||
}
|
||||
|
||||
const cloneConnectionConfig = (config: SavedConnection['config']): SavedConnection['config'] => {
|
||||
const raw: any = config || {};
|
||||
let cloned: any = {};
|
||||
try {
|
||||
cloned = typeof structuredClone === 'function'
|
||||
? structuredClone(raw)
|
||||
: JSON.parse(JSON.stringify(raw));
|
||||
} catch {
|
||||
cloned = { ...raw };
|
||||
const duplicatedConnection = await backendApp.DuplicateConnection(conn.id);
|
||||
if (!duplicatedConnection) {
|
||||
throw new Error('复制连接失败:后端未返回结果');
|
||||
}
|
||||
addConnection(duplicatedConnection);
|
||||
message.success(`已复制连接: ${duplicatedConnection.name}`);
|
||||
} catch (error: any) {
|
||||
message.error(error?.message || '复制连接失败');
|
||||
}
|
||||
|
||||
const readString = (...values: unknown[]): string => {
|
||||
for (const value of values) {
|
||||
if (typeof value === 'string') {
|
||||
return value;
|
||||
}
|
||||
}
|
||||
return '';
|
||||
};
|
||||
|
||||
const readBool = (fallback: boolean, ...values: unknown[]): boolean => {
|
||||
for (const value of values) {
|
||||
if (typeof value === 'boolean') {
|
||||
return value;
|
||||
}
|
||||
}
|
||||
return fallback;
|
||||
};
|
||||
|
||||
const readNumber = (fallback: number, ...values: unknown[]): number => {
|
||||
for (const value of values) {
|
||||
const num = Number(value);
|
||||
if (Number.isFinite(num)) {
|
||||
return num;
|
||||
}
|
||||
}
|
||||
return fallback;
|
||||
};
|
||||
|
||||
const rawSSH = (cloned.ssh ?? cloned.SSH ?? {}) as Record<string, unknown>;
|
||||
const normalizedSSH = {
|
||||
host: readString(rawSSH.host, rawSSH.Host, cloned.sshHost, cloned.SSHHost),
|
||||
port: readNumber(22, rawSSH.port, rawSSH.Port, cloned.sshPort, cloned.SSHPort),
|
||||
user: readString(rawSSH.user, rawSSH.User, cloned.sshUser, cloned.SSHUser),
|
||||
password: readString(rawSSH.password, rawSSH.Password, cloned.sshPassword, cloned.SSHPassword),
|
||||
keyPath: readString(rawSSH.keyPath, rawSSH.KeyPath, cloned.sshKeyPath, cloned.SSHKeyPath),
|
||||
};
|
||||
const hasSSHDetail = Boolean(
|
||||
normalizedSSH.host
|
||||
|| normalizedSSH.user
|
||||
|| normalizedSSH.password
|
||||
|| normalizedSSH.keyPath
|
||||
);
|
||||
|
||||
const rawProxy = (cloned.proxy ?? cloned.Proxy ?? {}) as Record<string, unknown>;
|
||||
const proxyTypeRaw = readString(rawProxy.type, rawProxy.Type, cloned.proxyType, cloned.ProxyType).toLowerCase();
|
||||
const proxyType: 'socks5' | 'http' = proxyTypeRaw === 'http' ? 'http' : 'socks5';
|
||||
const normalizedProxy = {
|
||||
type: proxyType,
|
||||
host: readString(rawProxy.host, rawProxy.Host, cloned.proxyHost, cloned.ProxyHost),
|
||||
port: readNumber(proxyType === 'http' ? 8080 : 1080, rawProxy.port, rawProxy.Port, cloned.proxyPort, cloned.ProxyPort),
|
||||
user: readString(rawProxy.user, rawProxy.User, cloned.proxyUser, cloned.ProxyUser),
|
||||
password: readString(rawProxy.password, rawProxy.Password, cloned.proxyPassword, cloned.ProxyPassword),
|
||||
};
|
||||
const hasProxyDetail = Boolean(normalizedProxy.host || normalizedProxy.user || normalizedProxy.password);
|
||||
const rawHttpTunnel = (cloned.httpTunnel ?? cloned.HTTPTunnel ?? {}) as Record<string, unknown>;
|
||||
const normalizedHttpTunnel = {
|
||||
host: readString(rawHttpTunnel.host, rawHttpTunnel.Host, cloned.httpTunnelHost, cloned.HttpTunnelHost),
|
||||
port: readNumber(8080, rawHttpTunnel.port, rawHttpTunnel.Port, cloned.httpTunnelPort, cloned.HttpTunnelPort),
|
||||
user: readString(rawHttpTunnel.user, rawHttpTunnel.User, cloned.httpTunnelUser, cloned.HttpTunnelUser),
|
||||
password: readString(rawHttpTunnel.password, rawHttpTunnel.Password, cloned.httpTunnelPassword, cloned.HttpTunnelPassword),
|
||||
};
|
||||
const hasHttpTunnelDetail = Boolean(normalizedHttpTunnel.host || normalizedHttpTunnel.user || normalizedHttpTunnel.password);
|
||||
const normalizedUseHttpTunnel = readBool(hasHttpTunnelDetail, cloned.useHttpTunnel, cloned.UseHTTPTunnel);
|
||||
const normalizedUseProxy = !normalizedUseHttpTunnel && readBool(hasProxyDetail, cloned.useProxy, cloned.UseProxy);
|
||||
|
||||
const rawHosts = Array.isArray(cloned.hosts)
|
||||
? cloned.hosts
|
||||
: (Array.isArray(cloned.Hosts) ? cloned.Hosts : []);
|
||||
const normalizedHosts = rawHosts
|
||||
.map((entry: unknown) => String(entry || '').trim())
|
||||
.filter((entry: string) => !!entry);
|
||||
|
||||
return {
|
||||
...(cloned as SavedConnection['config']),
|
||||
useSSH: readBool(hasSSHDetail, cloned.useSSH, cloned.UseSSH),
|
||||
ssh: normalizedSSH,
|
||||
useProxy: normalizedUseProxy,
|
||||
proxy: normalizedProxy,
|
||||
useHttpTunnel: normalizedUseHttpTunnel,
|
||||
httpTunnel: normalizedHttpTunnel,
|
||||
hosts: normalizedHosts,
|
||||
timeout: readNumber(30, cloned.timeout, cloned.Timeout),
|
||||
};
|
||||
};
|
||||
|
||||
const handleDuplicateConnection = (conn: SavedConnection) => {
|
||||
if (!conn) return;
|
||||
|
||||
const duplicatedConnection: SavedConnection = {
|
||||
...conn,
|
||||
id: `${Date.now()}-${Math.random().toString(36).slice(2, 8)}`,
|
||||
name: buildDuplicateConnectionName(conn.name),
|
||||
config: cloneConnectionConfig(conn.config),
|
||||
includeDatabases: conn.includeDatabases ? [...conn.includeDatabases] : undefined,
|
||||
includeRedisDatabases: conn.includeRedisDatabases ? [...conn.includeRedisDatabases] : undefined,
|
||||
};
|
||||
|
||||
addConnection(duplicatedConnection);
|
||||
message.success(`已复制连接: ${duplicatedConnection.name}`);
|
||||
};
|
||||
const updateTreeData = (list: TreeNode[], key: React.Key, children: TreeNode[] | undefined): TreeNode[] => {
|
||||
return list.map(node => {
|
||||
@@ -527,7 +424,7 @@ const Sidebar: React.FC<{ onEditConnection?: (conn: SavedConnection) => void }>
|
||||
if (SIDEBAR_SCHEMA_DB_TYPES.has(dbType)) return true;
|
||||
if (dbType !== 'custom') return false;
|
||||
|
||||
const customDriver = String((conn?.config as any)?.driver || '').trim().toLowerCase();
|
||||
const customDriver = String(conn?.config?.driver || '').trim().toLowerCase();
|
||||
return SIDEBAR_SCHEMA_CUSTOM_DRIVERS.has(customDriver);
|
||||
};
|
||||
|
||||
@@ -543,7 +440,7 @@ const Sidebar: React.FC<{ onEditConnection?: (conn: SavedConnection) => void }>
|
||||
const getMetadataDialect = (conn: SavedConnection | undefined): string => {
|
||||
const type = String(conn?.config?.type || '').trim().toLowerCase();
|
||||
if (type === 'custom') {
|
||||
const driver = String((conn?.config as any)?.driver || '').trim().toLowerCase();
|
||||
const driver = String(conn?.config?.driver || '').trim().toLowerCase();
|
||||
if (driver === 'diros' || driver === 'doris') return 'mysql';
|
||||
return driver;
|
||||
}
|
||||
@@ -569,7 +466,7 @@ const Sidebar: React.FC<{ onEditConnection?: (conn: SavedConnection) => void }>
|
||||
const type = String(conn?.config?.type || '').trim().toLowerCase();
|
||||
if (type === 'sphinx') return true;
|
||||
if (type !== 'custom') return false;
|
||||
const driver = String((conn?.config as any)?.driver || '').trim().toLowerCase();
|
||||
const driver = String(conn?.config?.driver || '').trim().toLowerCase();
|
||||
return driver === 'sphinx' || driver === 'sphinxql';
|
||||
};
|
||||
|
||||
@@ -857,7 +754,7 @@ const Sidebar: React.FC<{ onEditConnection?: (conn: SavedConnection) => void }>
|
||||
|
||||
for (const spec of normalizedSpecs) {
|
||||
try {
|
||||
const result = await DBQuery(config as any, dbName, spec.sql);
|
||||
const result = await DBQuery(buildRpcConnectionConfig(config) as any, dbName, spec.sql);
|
||||
if (!result.success || !Array.isArray(result.data)) {
|
||||
continue;
|
||||
}
|
||||
@@ -988,7 +885,7 @@ const Sidebar: React.FC<{ onEditConnection?: (conn: SavedConnection) => void }>
|
||||
// Handle Redis connections differently
|
||||
if (conn.config.type === 'redis') {
|
||||
try {
|
||||
const res = await (window as any).go.app.App.RedisGetDatabases(config);
|
||||
const res = await (window as any).go.app.App.RedisGetDatabases(buildRpcConnectionConfig(config));
|
||||
if (res.success) {
|
||||
setConnectionStates(prev => ({ ...prev, [conn.id]: 'success' }));
|
||||
const redisRows: any[] = Array.isArray(res.data) ? res.data : [];
|
||||
@@ -1020,7 +917,7 @@ const Sidebar: React.FC<{ onEditConnection?: (conn: SavedConnection) => void }>
|
||||
}
|
||||
|
||||
try {
|
||||
const res = await DBGetDatabases(config as any);
|
||||
const res = await DBGetDatabases(buildRpcConnectionConfig(config) as any);
|
||||
if (res.success) {
|
||||
setConnectionStates(prev => ({ ...prev, [conn.id]: 'success' }));
|
||||
const dbRows: any[] = Array.isArray(res.data) ? res.data : [];
|
||||
@@ -1094,7 +991,7 @@ const Sidebar: React.FC<{ onEditConnection?: (conn: SavedConnection) => void }>
|
||||
ssh: conn.config.ssh || { host: "", port: 22, user: "", password: "", keyPath: "" }
|
||||
};
|
||||
try {
|
||||
const res = await DBGetTables(config as any, conn.dbName);
|
||||
const res = await DBGetTables(buildRpcConnectionConfig(config) as any, conn.dbName);
|
||||
if (res.success) {
|
||||
setConnectionStates(prev => ({ ...prev, [key as string]: 'success' }));
|
||||
|
||||
@@ -1578,14 +1475,7 @@ const Sidebar: React.FC<{ onEditConnection?: (conn: SavedConnection) => void }>
|
||||
|
||||
const handleCopyStructure = async (node: any) => {
|
||||
const { config, dbName, tableName } = node.dataRef;
|
||||
const res = await DBShowCreateTable({
|
||||
...config,
|
||||
port: Number(config.port),
|
||||
password: config.password || "",
|
||||
database: config.database || "",
|
||||
useSSH: config.useSSH || false,
|
||||
ssh: config.ssh || { host: "", port: 22, user: "", password: "", keyPath: "" }
|
||||
} as any, dbName, tableName);
|
||||
const res = await DBShowCreateTable(buildRpcConnectionConfig(config) as any, dbName, tableName);
|
||||
if (res.success) {
|
||||
navigator.clipboard.writeText(res.data as string);
|
||||
message.success('表结构已复制到剪贴板');
|
||||
@@ -1597,14 +1487,7 @@ const Sidebar: React.FC<{ onEditConnection?: (conn: SavedConnection) => void }>
|
||||
const handleExport = async (node: any, format: string) => {
|
||||
const { config, dbName, tableName } = node.dataRef;
|
||||
const hide = message.loading(`正在导出 ${tableName} 为 ${format.toUpperCase()}...`, 0);
|
||||
const res = await ExportTable({
|
||||
...config,
|
||||
port: Number(config.port),
|
||||
password: config.password || "",
|
||||
database: config.database || "",
|
||||
useSSH: config.useSSH || false,
|
||||
ssh: config.ssh || { host: "", port: 22, user: "", password: "", keyPath: "" }
|
||||
} as any, dbName, tableName, format);
|
||||
const res = await ExportTable(buildRpcConnectionConfig(config) as any, dbName, tableName, format);
|
||||
hide();
|
||||
if (res.success) {
|
||||
message.success('导出成功');
|
||||
@@ -1613,14 +1496,9 @@ const Sidebar: React.FC<{ onEditConnection?: (conn: SavedConnection) => void }>
|
||||
}
|
||||
};
|
||||
|
||||
const normalizeConnConfig = (raw: any) => ({
|
||||
...raw,
|
||||
port: Number(raw.port),
|
||||
password: raw.password || "",
|
||||
database: raw.database || "",
|
||||
useSSH: raw.useSSH || false,
|
||||
ssh: raw.ssh || { host: "", port: 22, user: "", password: "", keyPath: "" }
|
||||
});
|
||||
const normalizeConnConfig = (raw: any) => (
|
||||
buildRpcConnectionConfig(raw)
|
||||
);
|
||||
|
||||
const handleExportDatabaseSQL = async (node: any, includeData: boolean) => {
|
||||
const conn = node.dataRef;
|
||||
@@ -1715,7 +1593,7 @@ const Sidebar: React.FC<{ onEditConnection?: (conn: SavedConnection) => void }>
|
||||
ssh: conn.config.ssh || { host: "", port: 22, user: "", password: "", keyPath: "" }
|
||||
};
|
||||
|
||||
const res = await DBGetDatabases(config as any);
|
||||
const res = await DBGetDatabases(buildRpcConnectionConfig(config) as any);
|
||||
if (res.success) {
|
||||
const dbRows: any[] = Array.isArray(res.data) ? res.data : [];
|
||||
let dbs = dbRows.map((row: any) => {
|
||||
@@ -1750,7 +1628,7 @@ const Sidebar: React.FC<{ onEditConnection?: (conn: SavedConnection) => void }>
|
||||
};
|
||||
|
||||
const [res, viewResult] = await Promise.all([
|
||||
DBGetTables(config as any, dbName),
|
||||
DBGetTables(buildRpcConnectionConfig(config) as any, dbName),
|
||||
loadViews(conn, dbName).catch(() => ({ views: [], supported: false })),
|
||||
]);
|
||||
|
||||
@@ -2026,7 +1904,7 @@ const Sidebar: React.FC<{ onEditConnection?: (conn: SavedConnection) => void }>
|
||||
ssh: conn.config.ssh || { host: "", port: 22, user: "", password: "", keyPath: "" }
|
||||
};
|
||||
|
||||
const res = await DBGetDatabases(config as any);
|
||||
const res = await DBGetDatabases(buildRpcConnectionConfig(config) as any);
|
||||
if (res.success) {
|
||||
const dbRows: any[] = Array.isArray(res.data) ? res.data : [];
|
||||
let dbs = dbRows.map((row: any) => {
|
||||
@@ -2238,7 +2116,7 @@ const Sidebar: React.FC<{ onEditConnection?: (conn: SavedConnection) => void }>
|
||||
ssh: conn.config.ssh || { host: "", port: 22, user: "", password: "", keyPath: "" }
|
||||
};
|
||||
|
||||
const res = await CreateDatabase(config as any, values.name);
|
||||
const res = await CreateDatabase(buildRpcConnectionConfig(config) as any, values.name);
|
||||
if (res.success) {
|
||||
message.success("数据库创建成功");
|
||||
setIsCreateDbModalOpen(false);
|
||||
@@ -2254,14 +2132,9 @@ const Sidebar: React.FC<{ onEditConnection?: (conn: SavedConnection) => void }>
|
||||
};
|
||||
|
||||
const buildRuntimeConfig = (conn: any, overrideDatabase?: string, clearDatabase: boolean = false) => {
|
||||
return {
|
||||
...conn.config,
|
||||
port: Number(conn.config.port),
|
||||
password: conn.config.password || "",
|
||||
database: clearDatabase ? "" : ((overrideDatabase ?? conn.config.database) || ""),
|
||||
useSSH: conn.config.useSSH || false,
|
||||
ssh: conn.config.ssh || { host: "", port: 22, user: "", password: "", keyPath: "" }
|
||||
};
|
||||
return buildRpcConnectionConfig(conn.config, {
|
||||
database: clearDatabase ? '' : ((overrideDatabase ?? conn.config.database) || ''),
|
||||
});
|
||||
};
|
||||
|
||||
const getConnectionNodeRef = (connRef: any) => {
|
||||
@@ -2303,7 +2176,7 @@ const Sidebar: React.FC<{ onEditConnection?: (conn: SavedConnection) => void }>
|
||||
}
|
||||
|
||||
const config = buildRuntimeConfig(conn, conn.dbName);
|
||||
const res = await RenameDatabase(config as any, oldDbName, newDbName);
|
||||
const res = await RenameDatabase(buildRpcConnectionConfig(config) as any, oldDbName, newDbName);
|
||||
if (res.success) {
|
||||
message.success("数据库重命名成功");
|
||||
setExpandedKeys(prev => prev.filter(k => !k.toString().startsWith(`${conn.id}-${oldDbName}`)));
|
||||
@@ -2330,7 +2203,7 @@ const Sidebar: React.FC<{ onEditConnection?: (conn: SavedConnection) => void }>
|
||||
okButtonProps: { danger: true },
|
||||
onOk: async () => {
|
||||
const config = buildRuntimeConfig(conn, conn.dbName);
|
||||
const res = await DropDatabase(config as any, dbName);
|
||||
const res = await DropDatabase(buildRpcConnectionConfig(config) as any, dbName);
|
||||
if (res.success) {
|
||||
message.success("数据库删除成功");
|
||||
closeTabsByDatabase(conn.id, dbName);
|
||||
@@ -2360,7 +2233,7 @@ const Sidebar: React.FC<{ onEditConnection?: (conn: SavedConnection) => void }>
|
||||
return;
|
||||
}
|
||||
const config = buildRuntimeConfig(conn, conn.dbName);
|
||||
const res = await RenameTable(config as any, conn.dbName, oldTableName, newTableName);
|
||||
const res = await RenameTable(buildRpcConnectionConfig(config) as any, conn.dbName, oldTableName, newTableName);
|
||||
if (res.success) {
|
||||
message.success("表重命名成功");
|
||||
await loadTables(getDatabaseNodeRef(conn, conn.dbName));
|
||||
@@ -2385,7 +2258,7 @@ const Sidebar: React.FC<{ onEditConnection?: (conn: SavedConnection) => void }>
|
||||
okButtonProps: { danger: true },
|
||||
onOk: async () => {
|
||||
const config = buildRuntimeConfig(conn, conn.dbName);
|
||||
const res = await DropTable(config as any, conn.dbName, tableName);
|
||||
const res = await DropTable(buildRpcConnectionConfig(config) as any, conn.dbName, tableName);
|
||||
if (res.success) {
|
||||
message.success("表删除成功");
|
||||
await loadTables(getDatabaseNodeRef(conn, conn.dbName));
|
||||
@@ -2445,7 +2318,7 @@ const Sidebar: React.FC<{ onEditConnection?: (conn: SavedConnection) => void }>
|
||||
}
|
||||
}
|
||||
if (query) {
|
||||
const result = await DBQuery(config as any, dbName, query);
|
||||
const result = await DBQuery(buildRpcConnectionConfig(config) as any, dbName, query);
|
||||
if (result.success && Array.isArray(result.data) && result.data.length > 0) {
|
||||
const row = result.data[0] as Record<string, any>;
|
||||
const def = row.view_definition || row.VIEW_DEFINITION || Object.values(row).find(v => typeof v === 'string' && String(v).length > 10) || '';
|
||||
@@ -2511,7 +2384,7 @@ const Sidebar: React.FC<{ onEditConnection?: (conn: SavedConnection) => void }>
|
||||
okButtonProps: { danger: true },
|
||||
onOk: async () => {
|
||||
const config = buildRuntimeConfig(conn, conn.dbName);
|
||||
const res = await DropView(config as any, conn.dbName, viewName);
|
||||
const res = await DropView(buildRpcConnectionConfig(config) as any, conn.dbName, viewName);
|
||||
if (res.success) {
|
||||
message.success("视图删除成功");
|
||||
await loadTables(getDatabaseNodeRef(conn, conn.dbName));
|
||||
@@ -2538,7 +2411,7 @@ const Sidebar: React.FC<{ onEditConnection?: (conn: SavedConnection) => void }>
|
||||
return;
|
||||
}
|
||||
const config = buildRuntimeConfig(conn, conn.dbName);
|
||||
const res = await RenameView(config as any, conn.dbName, oldViewName, newViewName);
|
||||
const res = await RenameView(buildRpcConnectionConfig(config) as any, conn.dbName, oldViewName, newViewName);
|
||||
if (res.success) {
|
||||
message.success("视图重命名成功");
|
||||
await loadTables(getDatabaseNodeRef(conn, conn.dbName));
|
||||
@@ -2610,7 +2483,7 @@ const Sidebar: React.FC<{ onEditConnection?: (conn: SavedConnection) => void }>
|
||||
}
|
||||
}
|
||||
if (query) {
|
||||
const result = await DBQuery(config as any, dbName, query);
|
||||
const result = await DBQuery(buildRpcConnectionConfig(config) as any, dbName, query);
|
||||
if (result.success && Array.isArray(result.data) && result.data.length > 0) {
|
||||
if (dialect === 'oracle' || dialect === 'dm') {
|
||||
const lines = result.data.map((row: any) => row.text || row.TEXT || Object.values(row)[0] || '').join('');
|
||||
@@ -2704,7 +2577,7 @@ const Sidebar: React.FC<{ onEditConnection?: (conn: SavedConnection) => void }>
|
||||
okButtonProps: { danger: true },
|
||||
onOk: async () => {
|
||||
const config = buildRuntimeConfig(conn, conn.dbName);
|
||||
const res = await DropFunction(config as any, conn.dbName, routineName, routineType);
|
||||
const res = await DropFunction(buildRpcConnectionConfig(config) as any, conn.dbName, routineName, routineType);
|
||||
if (res.success) {
|
||||
message.success(`${typeLabel}删除成功`);
|
||||
await loadTables(getDatabaseNodeRef(conn, conn.dbName));
|
||||
@@ -3186,9 +3059,22 @@ const Sidebar: React.FC<{ onEditConnection?: (conn: SavedConnection) => void }>
|
||||
Modal.confirm({
|
||||
title: '确认删除',
|
||||
content: `确定要删除连接 "${node.title}" 吗?`,
|
||||
onOk: () => {
|
||||
closeTabsByConnection(String(node.key));
|
||||
removeConnection(node.key);
|
||||
onOk: async () => {
|
||||
const connId = String(node.key);
|
||||
const backendApp = (window as any).go?.app?.App;
|
||||
if (typeof backendApp?.DeleteConnection !== 'function') {
|
||||
message.error('删除连接失败:后端接口不可用');
|
||||
throw new Error('DeleteConnection unavailable');
|
||||
}
|
||||
try {
|
||||
await backendApp.DeleteConnection(connId);
|
||||
closeTabsByConnection(connId);
|
||||
removeConnection(connId);
|
||||
message.success('已删除连接');
|
||||
} catch (error: any) {
|
||||
message.error(error?.message || '删除连接失败');
|
||||
throw error;
|
||||
}
|
||||
}
|
||||
});
|
||||
}
|
||||
@@ -3323,9 +3209,22 @@ const Sidebar: React.FC<{ onEditConnection?: (conn: SavedConnection) => void }>
|
||||
Modal.confirm({
|
||||
title: '确认删除',
|
||||
content: `确定要删除连接 "${node.title}" 吗?`,
|
||||
onOk: () => {
|
||||
closeTabsByConnection(String(node.key));
|
||||
removeConnection(node.key);
|
||||
onOk: async () => {
|
||||
const connId = String(node.key);
|
||||
const backendApp = (window as any).go?.app?.App;
|
||||
if (typeof backendApp?.DeleteConnection !== 'function') {
|
||||
message.error('删除连接失败:后端接口不可用');
|
||||
throw new Error('DeleteConnection unavailable');
|
||||
}
|
||||
try {
|
||||
await backendApp.DeleteConnection(connId);
|
||||
closeTabsByConnection(connId);
|
||||
removeConnection(connId);
|
||||
message.success('已删除连接');
|
||||
} catch (error: any) {
|
||||
message.error(error?.message || '删除连接失败');
|
||||
throw error;
|
||||
}
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
@@ -9,6 +9,7 @@ import { TabData, ColumnDefinition, IndexDefinition, ForeignKeyDefinition, Trigg
|
||||
import { useStore } from '../store';
|
||||
import { DBGetColumns, DBGetIndexes, DBQuery, DBGetForeignKeys, DBGetTriggers, DBShowCreateTable } from '../../wailsjs/go/app/App';
|
||||
import { hasIndexFormChanged, normalizeIndexFormFromRow, shouldRestoreOriginalIndex, toggleIndexSelection as getNextIndexSelection, type IndexDisplaySnapshot } from './tableDesignerIndexUtils';
|
||||
import { buildRpcConnectionConfig } from '../utils/connectionRpcConfig';
|
||||
|
||||
interface EditableColumn extends ColumnDefinition {
|
||||
_key: string;
|
||||
@@ -751,14 +752,14 @@ const TableDesigner: React.FC<{ tab: TabData }> = ({ tab }) => {
|
||||
};
|
||||
|
||||
const promises: Promise<any>[] = [
|
||||
DBGetColumns(config as any, tab.dbName || '', tab.tableName || ''),
|
||||
DBGetIndexes(config as any, tab.dbName || '', tab.tableName || ''),
|
||||
DBGetForeignKeys(config as any, tab.dbName || '', tab.tableName || ''),
|
||||
DBGetTriggers(config as any, tab.dbName || '', tab.tableName || '')
|
||||
DBGetColumns(buildRpcConnectionConfig(config) as any, tab.dbName || '', tab.tableName || ''),
|
||||
DBGetIndexes(buildRpcConnectionConfig(config) as any, tab.dbName || '', tab.tableName || ''),
|
||||
DBGetForeignKeys(buildRpcConnectionConfig(config) as any, tab.dbName || '', tab.tableName || ''),
|
||||
DBGetTriggers(buildRpcConnectionConfig(config) as any, tab.dbName || '', tab.tableName || '')
|
||||
];
|
||||
|
||||
if (!isNewTable) {
|
||||
promises.push(DBShowCreateTable(config as any, tab.dbName || '', tab.tableName || ''));
|
||||
promises.push(DBShowCreateTable(buildRpcConnectionConfig(config) as any, tab.dbName || '', tab.tableName || ''));
|
||||
}
|
||||
|
||||
const results = await Promise.all(promises);
|
||||
@@ -848,7 +849,7 @@ const TableDesigner: React.FC<{ tab: TabData }> = ({ tab }) => {
|
||||
if (!type) return '';
|
||||
|
||||
if (type === 'custom') {
|
||||
return inferDialectFromCustomDriver(String((conn?.config as any)?.driver || ''));
|
||||
return inferDialectFromCustomDriver(String(conn?.config?.driver || ''));
|
||||
}
|
||||
|
||||
if (type === 'mariadb' || type === 'diros' || type === 'sphinx') return 'mysql';
|
||||
@@ -993,7 +994,7 @@ ${selectedTrigger.statement}`;
|
||||
const dropSql = buildDropTriggerSql(selectedTrigger.name);
|
||||
|
||||
try {
|
||||
const res = await DBQuery(config as any, tab.dbName || '', dropSql);
|
||||
const res = await DBQuery(buildRpcConnectionConfig(config) as any, tab.dbName || '', dropSql);
|
||||
if (res.success) {
|
||||
message.success('触发器删除成功');
|
||||
setSelectedTrigger(null);
|
||||
@@ -1030,7 +1031,7 @@ ${selectedTrigger.statement}`;
|
||||
// 如果是编辑模式,先删除旧触发器
|
||||
if (triggerEditMode === 'edit' && selectedTrigger) {
|
||||
const dropSql = buildDropTriggerSql(selectedTrigger.name);
|
||||
const dropRes = await DBQuery(config as any, tab.dbName || '', dropSql);
|
||||
const dropRes = await DBQuery(buildRpcConnectionConfig(config) as any, tab.dbName || '', dropSql);
|
||||
if (!dropRes.success) {
|
||||
message.error('删除旧触发器失败: ' + dropRes.message);
|
||||
setTriggerExecuting(false);
|
||||
@@ -1039,7 +1040,7 @@ ${selectedTrigger.statement}`;
|
||||
}
|
||||
|
||||
// 执行创建语句
|
||||
const res = await DBQuery(config as any, tab.dbName || '', triggerEditSql);
|
||||
const res = await DBQuery(buildRpcConnectionConfig(config) as any, tab.dbName || '', triggerEditSql);
|
||||
if (res.success) {
|
||||
message.success(triggerEditMode === 'create' ? '触发器创建成功' : '触发器修改成功');
|
||||
setIsTriggerEditModalOpen(false);
|
||||
@@ -1522,7 +1523,7 @@ ${selectedTrigger.statement}`;
|
||||
const sql = buildCreateTableSql(copyTableName.trim(), selectedColumns, copyCharset, copyCollation);
|
||||
setCopyExecuting(true);
|
||||
try {
|
||||
const res = await DBQuery(config as any, tab.dbName || '', sql);
|
||||
const res = await DBQuery(buildRpcConnectionConfig(config) as any, tab.dbName || '', sql);
|
||||
if (res.success) {
|
||||
message.success(`已将 ${selectedColumns.length} 个字段复制到新表 ${copyTableName.trim()}`);
|
||||
setIsCopyColumnsModalOpen(false);
|
||||
@@ -1551,7 +1552,7 @@ ${selectedTrigger.statement}`;
|
||||
for (let i = 0; i < statements.length; i++) {
|
||||
let stmt = statements[i];
|
||||
if (!stmt.endsWith(';')) stmt += ';';
|
||||
const res = await DBQuery(config as any, tab.dbName || '', stmt);
|
||||
const res = await DBQuery(buildRpcConnectionConfig(config) as any, tab.dbName || '', stmt);
|
||||
if (!res.success) {
|
||||
const prefix = statements.length > 1 ? `第 ${i + 1}/${statements.length} 条语句执行失败: ` : '执行失败: ';
|
||||
return {
|
||||
@@ -2202,7 +2203,7 @@ END;`;
|
||||
const conn = connections.find(c => c.id === tab.connectionId);
|
||||
if (!conn) return;
|
||||
const config = { ...conn.config, port: Number(conn.config.port), password: conn.config.password || "", database: conn.config.database || "", useSSH: conn.config.useSSH || false, ssh: conn.config.ssh || { host: "", port: 22, user: "", password: "", keyPath: "" } };
|
||||
const res = await DBQuery(config as any, tab.dbName || '', previewSql);
|
||||
const res = await DBQuery(buildRpcConnectionConfig(config) as any, tab.dbName || '', previewSql);
|
||||
if (res.success) {
|
||||
message.success(isNewTable ? "表创建成功!" : "表结构修改成功!");
|
||||
setIsPreviewOpen(false);
|
||||
|
||||
@@ -4,6 +4,7 @@ import { TableOutlined, SearchOutlined, ReloadOutlined, SortAscendingOutlined, D
|
||||
import { useStore } from '../store';
|
||||
import { DBQuery, DBShowCreateTable, ExportTable, DropTable, RenameTable } from '../../wailsjs/go/app/App';
|
||||
import type { TabData } from '../types';
|
||||
import { buildRpcConnectionConfig } from '../utils/connectionRpcConfig';
|
||||
|
||||
interface TableOverviewProps {
|
||||
tab: TabData;
|
||||
@@ -163,9 +164,9 @@ const TableOverview: React.FC<TableOverviewProps> = ({ tab }) => {
|
||||
useSSH: connection.config.useSSH || false,
|
||||
ssh: connection.config.ssh || { host: '', port: 22, user: '', password: '', keyPath: '' },
|
||||
};
|
||||
const dialect = getMetadataDialect(connection.config.type, (connection.config as any)?.driver);
|
||||
const dialect = getMetadataDialect(connection.config.type, connection.config.driver);
|
||||
const sql = buildTableStatusSQL(dialect, tab.dbName || '', (tab as any).schemaName);
|
||||
const res = await DBQuery(config as any, tab.dbName || '', sql);
|
||||
const res = await DBQuery(buildRpcConnectionConfig(config) as any, tab.dbName || '', sql);
|
||||
if (res.success && Array.isArray(res.data)) {
|
||||
setTables(parseTableStats(dialect, res.data));
|
||||
} else {
|
||||
@@ -239,7 +240,7 @@ const TableOverview: React.FC<TableOverviewProps> = ({ tab }) => {
|
||||
const handleCopyStructure = useCallback(async (tableName: string) => {
|
||||
const config = buildConfig();
|
||||
if (!config) return;
|
||||
const res = await DBShowCreateTable(config as any, tab.dbName || '', tableName);
|
||||
const res = await DBShowCreateTable(buildRpcConnectionConfig(config) as any, tab.dbName || '', tableName);
|
||||
if (res.success) {
|
||||
navigator.clipboard.writeText(res.data as string);
|
||||
message.success('表结构已复制到剪贴板');
|
||||
@@ -252,7 +253,7 @@ const TableOverview: React.FC<TableOverviewProps> = ({ tab }) => {
|
||||
const config = buildConfig();
|
||||
if (!config) return;
|
||||
const hide = message.loading(`正在导出 ${tableName} 为 ${format.toUpperCase()}...`, 0);
|
||||
const res = await ExportTable(config as any, tab.dbName || '', tableName, format);
|
||||
const res = await ExportTable(buildRpcConnectionConfig(config) as any, tab.dbName || '', tableName, format);
|
||||
hide();
|
||||
if (res.success) {
|
||||
message.success('导出成功');
|
||||
@@ -269,7 +270,7 @@ const TableOverview: React.FC<TableOverviewProps> = ({ tab }) => {
|
||||
content: `确定删除表 "${tableName}" 吗?该操作不可恢复。`,
|
||||
okButtonProps: { danger: true },
|
||||
onOk: async () => {
|
||||
const res = await DropTable(config as any, tab.dbName || '', tableName);
|
||||
const res = await DropTable(buildRpcConnectionConfig(config) as any, tab.dbName || '', tableName);
|
||||
if (res.success) {
|
||||
message.success('表删除成功');
|
||||
loadData();
|
||||
@@ -299,7 +300,7 @@ const TableOverview: React.FC<TableOverviewProps> = ({ tab }) => {
|
||||
const trimmed = newName.trim();
|
||||
if (!trimmed) { message.error('表名不能为空'); return Promise.reject(); }
|
||||
if (trimmed === tableName) { message.warning('新旧表名相同'); return; }
|
||||
const res = await RenameTable(config as any, tab.dbName || '', tableName, trimmed);
|
||||
const res = await RenameTable(buildRpcConnectionConfig(config) as any, tab.dbName || '', tableName, trimmed);
|
||||
if (res.success) {
|
||||
message.success('表重命名成功');
|
||||
loadData();
|
||||
|
||||
@@ -4,6 +4,7 @@ import { Spin, Alert } from 'antd';
|
||||
import { TabData } from '../types';
|
||||
import { useStore } from '../store';
|
||||
import { DBQuery } from '../../wailsjs/go/app/App';
|
||||
import { buildRpcConnectionConfig } from '../utils/connectionRpcConfig';
|
||||
|
||||
interface TriggerViewerProps {
|
||||
tab: TabData;
|
||||
@@ -100,7 +101,7 @@ LIMIT 1`];
|
||||
const sql = String(query || '').trim();
|
||||
if (!sql) continue;
|
||||
try {
|
||||
const result = await DBQuery(config as any, dbName, sql);
|
||||
const result = await DBQuery(buildRpcConnectionConfig(config) as any, dbName, sql);
|
||||
if (!result.success || !Array.isArray(result.data)) {
|
||||
lastMessage = result.message || lastMessage;
|
||||
continue;
|
||||
@@ -126,7 +127,7 @@ LIMIT 1`];
|
||||
];
|
||||
for (const query of candidates) {
|
||||
try {
|
||||
const result = await DBQuery(config as any, dbName, query);
|
||||
const result = await DBQuery(buildRpcConnectionConfig(config) as any, dbName, query);
|
||||
if (!result.success || !Array.isArray(result.data) || result.data.length === 0) {
|
||||
continue;
|
||||
}
|
||||
|
||||
@@ -5,6 +5,7 @@ import { useStore } from '../../store';
|
||||
import { DBGetTables, DBShowCreateTable, DBGetDatabases } from '../../../wailsjs/go/app/App';
|
||||
import type { OverlayWorkbenchTheme } from '../../utils/overlayWorkbenchTheme';
|
||||
import type { AIComposerNotice } from '../../utils/aiComposerNotice';
|
||||
import { buildRpcConnectionConfig } from '../../utils/connectionRpcConfig';
|
||||
|
||||
interface AIChatInputProps {
|
||||
input: string;
|
||||
@@ -124,7 +125,7 @@ export const AIChatInput: React.FC<AIChatInputProps> = ({
|
||||
setContextLoading(true);
|
||||
setSelectedDbName(dbName);
|
||||
try {
|
||||
const res = await DBGetTables(connConfig, dbName);
|
||||
const res = await DBGetTables(buildRpcConnectionConfig(connConfig), dbName);
|
||||
if (res.success && Array.isArray(res.data)) {
|
||||
setContextTables(res.data.map(r => ({ name: Object.values(r)[0] as string })));
|
||||
} else {
|
||||
@@ -155,7 +156,7 @@ export const AIChatInput: React.FC<AIChatInputProps> = ({
|
||||
|
||||
try {
|
||||
// Fetch databases
|
||||
const dbRes = await DBGetDatabases(conn.config as any);
|
||||
const dbRes = await DBGetDatabases(buildRpcConnectionConfig(conn.config) as any);
|
||||
if (dbRes.success && Array.isArray(dbRes.data)) {
|
||||
const databases = dbRes.data.map((r: any) => Object.values(r)[0] as string);
|
||||
setDbList(databases);
|
||||
@@ -164,7 +165,7 @@ export const AIChatInput: React.FC<AIChatInputProps> = ({
|
||||
// Fetch tables for the active contextual database
|
||||
const initDbName = activeContext.dbName || '';
|
||||
setSelectedDbName(initDbName);
|
||||
const tablesRes = await DBGetTables(conn.config as any, initDbName);
|
||||
const tablesRes = await DBGetTables(buildRpcConnectionConfig(conn.config) as any, initDbName);
|
||||
if (tablesRes.success && Array.isArray(tablesRes.data)) {
|
||||
setContextTables(tablesRes.data.map((r: any) => ({ name: Object.values(r)[0] as string })));
|
||||
} else {
|
||||
@@ -201,7 +202,7 @@ export const AIChatInput: React.FC<AIChatInputProps> = ({
|
||||
if (activeContextItems.find(c => c.dbName === dbName && c.tableName === tableName)) {
|
||||
continue;
|
||||
}
|
||||
const res = await DBShowCreateTable(conn.config as any, dbName, tableName);
|
||||
const res = await DBShowCreateTable(buildRpcConnectionConfig(conn.config) as any, dbName, tableName);
|
||||
let createSql = '';
|
||||
if (res.success && res.data) {
|
||||
if (typeof res.data === 'string') {
|
||||
|
||||
@@ -1,6 +1,11 @@
|
||||
import { describe, expect, it } from 'vitest';
|
||||
|
||||
import { buildCopyInsertSQL } from './dataGridCopyInsert';
|
||||
import {
|
||||
buildCopyDeleteSQL,
|
||||
buildCopyInsertSQL,
|
||||
buildCopyUpdateSQL,
|
||||
resolveUniqueKeyGroupsFromIndexes,
|
||||
} from './dataGridCopyInsert';
|
||||
|
||||
describe('buildCopyInsertSQL', () => {
|
||||
it('normalizes PostgreSQL timestamp values for copy-as-insert and uses PostgreSQL identifier quoting', () => {
|
||||
@@ -58,4 +63,100 @@ describe('buildCopyInsertSQL', () => {
|
||||
`INSERT INTO public.audit_log (payload) VALUES ('2026-01-21T18:32:26+08:00');`,
|
||||
);
|
||||
});
|
||||
|
||||
it('groups composite unique indexes by name and sequence order', () => {
|
||||
expect(resolveUniqueKeyGroupsFromIndexes([
|
||||
{ name: 'PRIMARY', columnName: 'id', nonUnique: 0, seqInIndex: 1, indexType: 'BTREE' },
|
||||
{ name: 'uk_order_code', columnName: 'code', nonUnique: 0, seqInIndex: 2, indexType: 'BTREE' },
|
||||
{ name: 'uk_order_code', columnName: 'tenant_id', nonUnique: 0, seqInIndex: 1, indexType: 'BTREE' },
|
||||
{ name: 'idx_note', columnName: 'note', nonUnique: 1, seqInIndex: 1, indexType: 'BTREE' },
|
||||
])).toEqual([
|
||||
['id'],
|
||||
['tenant_id', 'code'],
|
||||
]);
|
||||
});
|
||||
|
||||
it('builds UPDATE SQL with a primary-key WHERE clause and keeps literal formatting aligned with INSERT', () => {
|
||||
const result = buildCopyUpdateSQL({
|
||||
dbType: 'mysql',
|
||||
tableName: 'orders',
|
||||
orderedCols: ['id', 'note', 'deleted_at'],
|
||||
record: {
|
||||
id: 7,
|
||||
note: "O'Brien",
|
||||
deleted_at: null,
|
||||
},
|
||||
pkColumns: ['id'],
|
||||
columnTypesByLowerName: {
|
||||
deleted_at: 'datetime',
|
||||
},
|
||||
allTableColumns: ['id', 'note', 'deleted_at'],
|
||||
});
|
||||
|
||||
expect(result).toEqual({
|
||||
ok: true,
|
||||
whereStrategy: 'primary-key',
|
||||
sql: `UPDATE \`orders\` SET \`id\` = '7', \`note\` = 'O''Brien', \`deleted_at\` = NULL WHERE (\`id\` = '7');`,
|
||||
});
|
||||
});
|
||||
|
||||
it('builds DELETE SQL with a composite unique-key WHERE clause when no primary key is available', () => {
|
||||
const result = buildCopyDeleteSQL({
|
||||
dbType: 'postgres',
|
||||
tableName: 'public.audit_log',
|
||||
orderedCols: ['tenant_id', 'code', 'payload'],
|
||||
record: {
|
||||
tenant_id: 'acme',
|
||||
code: 'evt-7',
|
||||
payload: '{"ok":true}',
|
||||
},
|
||||
uniqueKeyGroups: [['tenant_id', 'code']],
|
||||
allTableColumns: ['tenant_id', 'code', 'payload'],
|
||||
});
|
||||
|
||||
expect(result).toEqual({
|
||||
ok: true,
|
||||
whereStrategy: 'unique-key',
|
||||
sql: `DELETE FROM public.audit_log WHERE (tenant_id = 'acme' AND code = 'evt-7');`,
|
||||
});
|
||||
});
|
||||
|
||||
it('falls back to all-column matching and uses IS NULL for null values', () => {
|
||||
const result = buildCopyDeleteSQL({
|
||||
dbType: 'sqlserver',
|
||||
tableName: 'dbo.OrderLog',
|
||||
orderedCols: ['id', 'deleted_at', 'flag'],
|
||||
allTableColumns: ['id', 'deleted_at', 'flag'],
|
||||
record: {
|
||||
id: 5,
|
||||
deleted_at: null,
|
||||
flag: true,
|
||||
},
|
||||
});
|
||||
|
||||
expect(result).toEqual({
|
||||
ok: true,
|
||||
whereStrategy: 'all-columns',
|
||||
sql: `DELETE FROM [dbo].[OrderLog] WHERE ([id] = '5' AND [deleted_at] IS NULL AND [flag] = 'true');`,
|
||||
});
|
||||
});
|
||||
|
||||
it('refuses to build UPDATE/DELETE SQL when the result set lacks keys and does not cover all table columns', () => {
|
||||
const result = buildCopyDeleteSQL({
|
||||
dbType: 'mysql',
|
||||
tableName: 'orders',
|
||||
orderedCols: ['note'],
|
||||
allTableColumns: ['id', 'note', 'created_at'],
|
||||
record: {
|
||||
note: 'partial row',
|
||||
},
|
||||
});
|
||||
|
||||
expect(result.ok).toBe(false);
|
||||
if (result.ok) {
|
||||
throw new Error('expected buildCopyDeleteSQL to fail');
|
||||
}
|
||||
expect(result.error).toContain('主键');
|
||||
expect(result.error).toContain('全部字段');
|
||||
});
|
||||
});
|
||||
|
||||
@@ -1,3 +1,4 @@
|
||||
import type { IndexDefinition } from '../types';
|
||||
import { escapeLiteral, quoteIdentPart, quoteQualifiedIdent } from '../utils/sql';
|
||||
|
||||
type BuildCopyInsertSQLParams = {
|
||||
@@ -8,6 +9,22 @@ type BuildCopyInsertSQLParams = {
|
||||
columnTypesByLowerName?: Record<string, string>;
|
||||
};
|
||||
|
||||
type BuildCopyMutationSQLParams = BuildCopyInsertSQLParams & {
|
||||
pkColumns?: string[];
|
||||
uniqueKeyGroups?: string[][];
|
||||
allTableColumns?: string[];
|
||||
};
|
||||
|
||||
type CopySqlWhereStrategy = 'primary-key' | 'unique-key' | 'all-columns';
|
||||
|
||||
export type CopyMutationSQLResult =
|
||||
| { ok: true; sql: string; whereStrategy: CopySqlWhereStrategy }
|
||||
| { ok: false; error: string };
|
||||
|
||||
type CopyMutationWhereClauseResult =
|
||||
| { ok: true; clause: string; whereStrategy: CopySqlWhereStrategy }
|
||||
| { ok: false; error: string };
|
||||
|
||||
const looksLikeDateTimeText = (val: string): boolean => {
|
||||
if (!val) return false;
|
||||
const len = val.length;
|
||||
@@ -104,6 +121,157 @@ export const formatLocalDateTimeLiteral = (value: Date): string => {
|
||||
return `${year}-${month}-${day} ${hour}:${minute}:${second}`;
|
||||
};
|
||||
|
||||
const getColumnType = (columnTypesByLowerName: Record<string, string>, columnName: string): string | undefined => (
|
||||
columnTypesByLowerName[String(columnName || '').toLowerCase()]
|
||||
);
|
||||
|
||||
const getRecordValue = (
|
||||
record: Record<string, any>,
|
||||
columnName: string,
|
||||
): { exists: boolean; value: any } => {
|
||||
if (Object.prototype.hasOwnProperty.call(record || {}, columnName)) {
|
||||
return { exists: true, value: record?.[columnName] };
|
||||
}
|
||||
const loweredColumnName = String(columnName || '').toLowerCase();
|
||||
const matchedKey = Object.keys(record || {}).find((key) => key.toLowerCase() === loweredColumnName);
|
||||
if (!matchedKey) {
|
||||
return { exists: false, value: undefined };
|
||||
}
|
||||
return { exists: true, value: record?.[matchedKey] };
|
||||
};
|
||||
|
||||
const normalizeColumnList = (columns: string[] | undefined): string[] => {
|
||||
const seen = new Set<string>();
|
||||
const result: string[] = [];
|
||||
(columns || []).forEach((column) => {
|
||||
const normalized = String(column || '').trim();
|
||||
if (!normalized) return;
|
||||
const lowered = normalized.toLowerCase();
|
||||
if (seen.has(lowered)) return;
|
||||
seen.add(lowered);
|
||||
result.push(normalized);
|
||||
});
|
||||
return result;
|
||||
};
|
||||
|
||||
const toNormalizedLiteralText = (value: any, columnType?: string): string => {
|
||||
if (typeof value === 'string') {
|
||||
return normalizeTemporalLiteralText(value, columnType, true);
|
||||
}
|
||||
if (value instanceof Date) {
|
||||
return formatLocalDateTimeLiteral(value);
|
||||
}
|
||||
return String(value);
|
||||
};
|
||||
|
||||
const formatCopySqlLiteral = (value: any, columnType?: string): string => {
|
||||
if (value === null || value === undefined) {
|
||||
return 'NULL';
|
||||
}
|
||||
return `'${escapeLiteral(toNormalizedLiteralText(value, columnType))}'`;
|
||||
};
|
||||
|
||||
const doesResultCoverAllTableColumns = (orderedCols: string[], allTableColumns: string[]): boolean => {
|
||||
const normalizedOrderedCols = normalizeColumnList(orderedCols);
|
||||
const normalizedAllTableColumns = normalizeColumnList(allTableColumns);
|
||||
if (normalizedOrderedCols.length === 0 || normalizedOrderedCols.length !== normalizedAllTableColumns.length) {
|
||||
return false;
|
||||
}
|
||||
const orderedSet = new Set(normalizedOrderedCols.map((column) => column.toLowerCase()));
|
||||
return normalizedAllTableColumns.every((column) => orderedSet.has(column.toLowerCase()));
|
||||
};
|
||||
|
||||
const buildWhereClauseForColumns = ({
|
||||
dbType,
|
||||
columns,
|
||||
record,
|
||||
columnTypesByLowerName,
|
||||
requireNonNullValues,
|
||||
}: {
|
||||
dbType: string;
|
||||
columns: string[];
|
||||
record: Record<string, any>;
|
||||
columnTypesByLowerName: Record<string, string>;
|
||||
requireNonNullValues: boolean;
|
||||
}): string | null => {
|
||||
const predicates: string[] = [];
|
||||
for (const columnName of columns) {
|
||||
const { exists, value } = getRecordValue(record, columnName);
|
||||
if (!exists) {
|
||||
return null;
|
||||
}
|
||||
const quotedColumn = quoteIdentPart(dbType, columnName);
|
||||
if (value === null || value === undefined) {
|
||||
if (requireNonNullValues) {
|
||||
return null;
|
||||
}
|
||||
predicates.push(`${quotedColumn} IS NULL`);
|
||||
continue;
|
||||
}
|
||||
predicates.push(`${quotedColumn} = ${formatCopySqlLiteral(value, getColumnType(columnTypesByLowerName, columnName))}`);
|
||||
}
|
||||
if (predicates.length === 0) {
|
||||
return null;
|
||||
}
|
||||
return `(${predicates.join(' AND ')})`;
|
||||
};
|
||||
|
||||
const resolveMutationWhereClause = ({
|
||||
dbType,
|
||||
orderedCols,
|
||||
record,
|
||||
pkColumns = [],
|
||||
uniqueKeyGroups = [],
|
||||
allTableColumns = [],
|
||||
columnTypesByLowerName = {},
|
||||
}: BuildCopyMutationSQLParams): CopyMutationWhereClauseResult => {
|
||||
const normalizedPkColumns = normalizeColumnList(pkColumns);
|
||||
const pkWhereClause = buildWhereClauseForColumns({
|
||||
dbType,
|
||||
columns: normalizedPkColumns,
|
||||
record,
|
||||
columnTypesByLowerName,
|
||||
requireNonNullValues: true,
|
||||
});
|
||||
if (pkWhereClause) {
|
||||
return { ok: true, clause: pkWhereClause, whereStrategy: 'primary-key' };
|
||||
}
|
||||
|
||||
const normalizedUniqueKeyGroups = (uniqueKeyGroups || [])
|
||||
.map((group) => normalizeColumnList(group))
|
||||
.filter((group) => group.length > 0);
|
||||
for (const group of normalizedUniqueKeyGroups) {
|
||||
const uniqueWhereClause = buildWhereClauseForColumns({
|
||||
dbType,
|
||||
columns: group,
|
||||
record,
|
||||
columnTypesByLowerName,
|
||||
requireNonNullValues: true,
|
||||
});
|
||||
if (uniqueWhereClause) {
|
||||
return { ok: true, clause: uniqueWhereClause, whereStrategy: 'unique-key' };
|
||||
}
|
||||
}
|
||||
|
||||
if (doesResultCoverAllTableColumns(orderedCols, allTableColumns)) {
|
||||
const fullRowWhereClause = buildWhereClauseForColumns({
|
||||
dbType,
|
||||
columns: orderedCols,
|
||||
record,
|
||||
columnTypesByLowerName,
|
||||
requireNonNullValues: false,
|
||||
});
|
||||
if (fullRowWhereClause) {
|
||||
return { ok: true, clause: fullRowWhereClause, whereStrategy: 'all-columns' };
|
||||
}
|
||||
}
|
||||
|
||||
return {
|
||||
ok: false,
|
||||
error: '当前结果集缺少可安全定位行数据的主键/唯一键,且未覆盖表的全部字段,无法生成 WHERE 条件。',
|
||||
};
|
||||
};
|
||||
|
||||
export const buildCopyInsertSQL = ({
|
||||
dbType,
|
||||
tableName,
|
||||
@@ -114,18 +282,136 @@ export const buildCopyInsertSQL = ({
|
||||
const targetTable = quoteQualifiedIdent(dbType, tableName || 'table');
|
||||
const quotedCols = orderedCols.map((col) => quoteIdentPart(dbType, col));
|
||||
const values = orderedCols.map((col) => {
|
||||
const value = record?.[col];
|
||||
if (value === null || value === undefined) return 'NULL';
|
||||
|
||||
const columnType = columnTypesByLowerName[String(col || '').toLowerCase()];
|
||||
const raw =
|
||||
typeof value === 'string'
|
||||
? normalizeTemporalLiteralText(value, columnType, true)
|
||||
: value instanceof Date
|
||||
? formatLocalDateTimeLiteral(value)
|
||||
: String(value);
|
||||
return `'${escapeLiteral(raw)}'`;
|
||||
const { value } = getRecordValue(record, col);
|
||||
return formatCopySqlLiteral(value, getColumnType(columnTypesByLowerName, col));
|
||||
});
|
||||
|
||||
return `INSERT INTO ${targetTable} (${quotedCols.join(', ')}) VALUES (${values.join(', ')});`;
|
||||
};
|
||||
|
||||
const buildCopyMutationSQL = (
|
||||
mode: 'update' | 'delete',
|
||||
{
|
||||
dbType,
|
||||
tableName,
|
||||
orderedCols,
|
||||
record,
|
||||
pkColumns = [],
|
||||
uniqueKeyGroups = [],
|
||||
allTableColumns = [],
|
||||
columnTypesByLowerName = {},
|
||||
}: BuildCopyMutationSQLParams,
|
||||
): CopyMutationSQLResult => {
|
||||
const normalizedTableName = String(tableName || '').trim();
|
||||
const normalizedOrderedCols = normalizeColumnList(orderedCols);
|
||||
if (!normalizedTableName) {
|
||||
return {
|
||||
ok: false,
|
||||
error: `当前结果集未关联明确表名,无法生成 ${mode.toUpperCase()} SQL。`,
|
||||
};
|
||||
}
|
||||
if (normalizedOrderedCols.length === 0) {
|
||||
return {
|
||||
ok: false,
|
||||
error: '当前结果集没有可复制的字段,无法生成 SQL。',
|
||||
};
|
||||
}
|
||||
|
||||
const whereClause = resolveMutationWhereClause({
|
||||
dbType,
|
||||
orderedCols: normalizedOrderedCols,
|
||||
record,
|
||||
pkColumns,
|
||||
uniqueKeyGroups,
|
||||
allTableColumns,
|
||||
columnTypesByLowerName,
|
||||
});
|
||||
if (whereClause.ok === false) {
|
||||
return { ok: false, error: whereClause.error };
|
||||
}
|
||||
|
||||
const targetTable = quoteQualifiedIdent(dbType, normalizedTableName);
|
||||
if (mode === 'delete') {
|
||||
return {
|
||||
ok: true,
|
||||
sql: `DELETE FROM ${targetTable} WHERE ${whereClause.clause};`,
|
||||
whereStrategy: whereClause.whereStrategy,
|
||||
};
|
||||
}
|
||||
|
||||
const assignments = normalizedOrderedCols.map((columnName) => {
|
||||
const { value } = getRecordValue(record, columnName);
|
||||
return `${quoteIdentPart(dbType, columnName)} = ${formatCopySqlLiteral(value, getColumnType(columnTypesByLowerName, columnName))}`;
|
||||
});
|
||||
|
||||
return {
|
||||
ok: true,
|
||||
sql: `UPDATE ${targetTable} SET ${assignments.join(', ')} WHERE ${whereClause.clause};`,
|
||||
whereStrategy: whereClause.whereStrategy,
|
||||
};
|
||||
};
|
||||
|
||||
export const buildCopyUpdateSQL = (params: BuildCopyMutationSQLParams): CopyMutationSQLResult => (
|
||||
buildCopyMutationSQL('update', params)
|
||||
);
|
||||
|
||||
export const buildCopyDeleteSQL = (params: BuildCopyMutationSQLParams): CopyMutationSQLResult => (
|
||||
buildCopyMutationSQL('delete', params)
|
||||
);
|
||||
|
||||
export const resolveUniqueKeyGroupsFromIndexes = (indexes: IndexDefinition[] | undefined): string[][] => {
|
||||
type IndexBucket = {
|
||||
order: number;
|
||||
columns: Array<{ columnName: string; seqInIndex: number; order: number }>;
|
||||
};
|
||||
|
||||
const buckets = new Map<string, IndexBucket>();
|
||||
(indexes || []).forEach((index, order) => {
|
||||
if (index?.nonUnique !== 0) {
|
||||
return;
|
||||
}
|
||||
const name = String(index?.name || '').trim();
|
||||
const columnName = String(index?.columnName || '').trim();
|
||||
if (!name || !columnName) {
|
||||
return;
|
||||
}
|
||||
if (!buckets.has(name)) {
|
||||
buckets.set(name, { order, columns: [] });
|
||||
}
|
||||
const bucket = buckets.get(name);
|
||||
if (!bucket) {
|
||||
return;
|
||||
}
|
||||
bucket.columns.push({
|
||||
columnName,
|
||||
seqInIndex: Number.isFinite(Number(index?.seqInIndex)) ? Number(index.seqInIndex) : 0,
|
||||
order,
|
||||
});
|
||||
});
|
||||
|
||||
return Array.from(buckets.values())
|
||||
.sort((left, right) => left.order - right.order)
|
||||
.map((bucket) => {
|
||||
const seen = new Set<string>();
|
||||
return bucket.columns
|
||||
.slice()
|
||||
.sort((left, right) => {
|
||||
const leftSeq = left.seqInIndex > 0 ? left.seqInIndex : Number.MAX_SAFE_INTEGER;
|
||||
const rightSeq = right.seqInIndex > 0 ? right.seqInIndex : Number.MAX_SAFE_INTEGER;
|
||||
if (leftSeq !== rightSeq) {
|
||||
return leftSeq - rightSeq;
|
||||
}
|
||||
return left.order - right.order;
|
||||
})
|
||||
.map((item) => item.columnName)
|
||||
.filter((columnName) => {
|
||||
const lowered = columnName.toLowerCase();
|
||||
if (seen.has(lowered)) {
|
||||
return false;
|
||||
}
|
||||
seen.add(lowered);
|
||||
return true;
|
||||
});
|
||||
})
|
||||
.filter((group) => group.length > 0);
|
||||
};
|
||||
|
||||
@@ -15,17 +15,98 @@ dayjs.locale('zh-cn')
|
||||
import 'monaco-editor/esm/nls.messages.zh-cn'
|
||||
import { loader } from '@monaco-editor/react'
|
||||
import * as monaco from 'monaco-editor'
|
||||
import { cloneBrowserMockValue, duplicateBrowserMockConnection, resolveBrowserMockSecretFlag } from './utils/browserMockConnections'
|
||||
loader.config({ monaco })
|
||||
|
||||
if (typeof window !== 'undefined' && !(window as any).go) {
|
||||
const mockConnections: any[] = [];
|
||||
let mockGlobalProxy: any = { enabled: false, type: 'socks5', host: '', port: 1080, user: '', password: '', hasPassword: false };
|
||||
|
||||
const upsertMockConnection = (view: any) => {
|
||||
const index = mockConnections.findIndex((item) => item.id === view.id);
|
||||
if (index >= 0) {
|
||||
mockConnections[index] = view;
|
||||
return;
|
||||
}
|
||||
mockConnections.push(view);
|
||||
};
|
||||
|
||||
const saveMockConnection = (input: any) => {
|
||||
const existing = mockConnections.find((item) => item.id === input?.id);
|
||||
const config = (input?.config && typeof input.config === 'object') ? input.config : {};
|
||||
const ssh = (config.ssh && typeof config.ssh === 'object') ? config.ssh : {};
|
||||
const proxy = (config.proxy && typeof config.proxy === 'object') ? config.proxy : {};
|
||||
const httpTunnel = (config.httpTunnel && typeof config.httpTunnel === 'object') ? config.httpTunnel : {};
|
||||
const nextId = String(input?.id || existing?.id || `mock-${Date.now()}-${Math.random().toString(36).slice(2, 8)}`);
|
||||
const view = {
|
||||
id: nextId,
|
||||
name: String(input?.name || existing?.name || '未命名连接'),
|
||||
config: {
|
||||
...config,
|
||||
id: nextId,
|
||||
password: '',
|
||||
ssh: { ...ssh, password: '' },
|
||||
proxy: { ...proxy, password: '' },
|
||||
httpTunnel: { ...httpTunnel, password: '' },
|
||||
uri: '',
|
||||
dsn: '',
|
||||
mysqlReplicaPassword: '',
|
||||
mongoReplicaPassword: '',
|
||||
},
|
||||
includeDatabases: Array.isArray(input?.includeDatabases) ? [...input.includeDatabases] : existing?.includeDatabases,
|
||||
includeRedisDatabases: Array.isArray(input?.includeRedisDatabases) ? [...input.includeRedisDatabases] : existing?.includeRedisDatabases,
|
||||
iconType: typeof input?.iconType === 'string' ? input.iconType : (existing?.iconType || ''),
|
||||
iconColor: typeof input?.iconColor === 'string' ? input.iconColor : (existing?.iconColor || ''),
|
||||
hasPrimaryPassword: resolveBrowserMockSecretFlag(config.password, !!input?.clearPrimaryPassword, existing?.hasPrimaryPassword),
|
||||
hasSSHPassword: resolveBrowserMockSecretFlag(ssh.password, !!input?.clearSSHPassword, existing?.hasSSHPassword),
|
||||
hasProxyPassword: resolveBrowserMockSecretFlag(proxy.password, !!input?.clearProxyPassword, existing?.hasProxyPassword),
|
||||
hasHttpTunnelPassword: resolveBrowserMockSecretFlag(httpTunnel.password, !!input?.clearHttpTunnelPassword, existing?.hasHttpTunnelPassword),
|
||||
hasMySQLReplicaPassword: resolveBrowserMockSecretFlag(config.mysqlReplicaPassword, !!input?.clearMySQLReplicaPassword, existing?.hasMySQLReplicaPassword),
|
||||
hasMongoReplicaPassword: resolveBrowserMockSecretFlag(config.mongoReplicaPassword, !!input?.clearMongoReplicaPassword, existing?.hasMongoReplicaPassword),
|
||||
hasOpaqueURI: resolveBrowserMockSecretFlag(config.uri, !!input?.clearOpaqueURI, existing?.hasOpaqueURI),
|
||||
hasOpaqueDSN: resolveBrowserMockSecretFlag(config.dsn, !!input?.clearOpaqueDSN, existing?.hasOpaqueDSN),
|
||||
};
|
||||
upsertMockConnection(view);
|
||||
return cloneBrowserMockValue(view);
|
||||
};
|
||||
|
||||
const saveMockGlobalProxy = (input: any) => {
|
||||
const nextPassword = String(input?.password ?? '');
|
||||
mockGlobalProxy = {
|
||||
...mockGlobalProxy,
|
||||
...input,
|
||||
password: '',
|
||||
hasPassword: nextPassword !== '' ? true : !!mockGlobalProxy.hasPassword,
|
||||
};
|
||||
return cloneBrowserMockValue(mockGlobalProxy);
|
||||
};
|
||||
|
||||
(window as any).go = {
|
||||
app: {
|
||||
App: {
|
||||
CheckUpdate: async () => ({ success: false }),
|
||||
DownloadUpdate: async () => ({ success: false }),
|
||||
GetSavedConnections: async () => [],
|
||||
SaveConnection: async () => null,
|
||||
DeleteConnection: async () => null,
|
||||
GetSavedConnections: async () => cloneBrowserMockValue(mockConnections),
|
||||
SaveConnection: async (input: any) => saveMockConnection(input),
|
||||
DeleteConnection: async (id: string) => {
|
||||
const index = mockConnections.findIndex((item) => item.id === id);
|
||||
if (index >= 0) {
|
||||
mockConnections.splice(index, 1);
|
||||
}
|
||||
return null;
|
||||
},
|
||||
DuplicateConnection: async (id: string) => {
|
||||
const existing = mockConnections.find((item) => item.id === id);
|
||||
if (!existing) return null;
|
||||
const duplicated = duplicateBrowserMockConnection({
|
||||
existing,
|
||||
items: mockConnections,
|
||||
nextId: `mock-${Date.now()}-${Math.random().toString(36).slice(2, 8)}`,
|
||||
});
|
||||
mockConnections.push(duplicated);
|
||||
return cloneBrowserMockValue(duplicated);
|
||||
},
|
||||
ImportLegacyConnections: async (items: any[]) => items.map((item) => saveMockConnection(item)),
|
||||
OpenConnection: async () => null,
|
||||
CloseConnection: async () => null,
|
||||
GetDatabases: async () => [],
|
||||
@@ -42,11 +123,13 @@ if (typeof window !== 'undefined' && !(window as any).go) {
|
||||
InstallUpdateAndRestart: async () => ({ success: false }),
|
||||
ImportConfigFile: async () => ({ success: false }),
|
||||
ExportData: async () => ({ success: false }),
|
||||
GetGlobalProxyConfig: async () => ({ success: true, data: cloneBrowserMockValue(mockGlobalProxy) }),
|
||||
SaveGlobalProxy: async (input: any) => saveMockGlobalProxy(input),
|
||||
ImportLegacyGlobalProxy: async (input: any) => saveMockGlobalProxy(input),
|
||||
}
|
||||
}
|
||||
};
|
||||
}
|
||||
|
||||
// 全局注册透明主题,避免每个 Editor 组件 beforeMount 中重复定义
|
||||
monaco.editor.defineTheme('transparent-dark', {
|
||||
base: 'vs-dark', inherit: true, rules: [],
|
||||
@@ -62,3 +145,6 @@ ReactDOM.createRoot(document.getElementById('root')!).render(
|
||||
<App />
|
||||
</React.StrictMode>,
|
||||
)
|
||||
|
||||
|
||||
|
||||
|
||||
94
frontend/src/store.test.ts
Normal file
94
frontend/src/store.test.ts
Normal file
@@ -0,0 +1,94 @@
|
||||
import { afterEach, beforeEach, describe, expect, it, vi } from 'vitest';
|
||||
|
||||
class MemoryStorage implements Storage {
|
||||
private data = new Map<string, string>();
|
||||
|
||||
get length(): number {
|
||||
return this.data.size;
|
||||
}
|
||||
|
||||
clear(): void {
|
||||
this.data.clear();
|
||||
}
|
||||
|
||||
getItem(key: string): string | null {
|
||||
return this.data.has(key) ? this.data.get(key)! : null;
|
||||
}
|
||||
|
||||
key(index: number): string | null {
|
||||
return Array.from(this.data.keys())[index] ?? null;
|
||||
}
|
||||
|
||||
removeItem(key: string): void {
|
||||
this.data.delete(key);
|
||||
}
|
||||
|
||||
setItem(key: string, value: string): void {
|
||||
this.data.set(key, String(value));
|
||||
}
|
||||
}
|
||||
|
||||
const importStore = async () => {
|
||||
const store = await import('./store');
|
||||
await store.useStore.persist.rehydrate();
|
||||
return store;
|
||||
};
|
||||
|
||||
describe('store appearance persistence', () => {
|
||||
let storage: MemoryStorage;
|
||||
|
||||
beforeEach(() => {
|
||||
storage = new MemoryStorage();
|
||||
vi.stubGlobal('localStorage', storage);
|
||||
vi.resetModules();
|
||||
});
|
||||
|
||||
afterEach(() => {
|
||||
vi.unstubAllGlobals();
|
||||
vi.resetModules();
|
||||
});
|
||||
|
||||
it('fills missing DataGrid appearance settings with defaults during hydration', async () => {
|
||||
storage.setItem('lite-db-storage', JSON.stringify({
|
||||
state: {
|
||||
appearance: {
|
||||
enabled: false,
|
||||
opacity: 0.75,
|
||||
blur: 6,
|
||||
useNativeMacWindowControls: true,
|
||||
},
|
||||
},
|
||||
version: 7,
|
||||
}));
|
||||
|
||||
const { useStore } = await importStore();
|
||||
const appearance = useStore.getState().appearance;
|
||||
|
||||
expect(appearance.enabled).toBe(false);
|
||||
expect(appearance.opacity).toBe(0.75);
|
||||
expect(appearance.blur).toBe(6);
|
||||
expect(appearance.useNativeMacWindowControls).toBe(true);
|
||||
expect(appearance.showDataTableVerticalBorders).toBe(false);
|
||||
expect(appearance.dataTableColumnWidthMode).toBe('standard');
|
||||
});
|
||||
|
||||
it('persists DataGrid appearance settings and restores them after reload', async () => {
|
||||
const { useStore } = await importStore();
|
||||
|
||||
useStore.getState().setAppearance({
|
||||
showDataTableVerticalBorders: true,
|
||||
dataTableColumnWidthMode: 'compact',
|
||||
});
|
||||
|
||||
const persisted = JSON.parse(storage.getItem('lite-db-storage') || '{}');
|
||||
expect(persisted.state.appearance.showDataTableVerticalBorders).toBe(true);
|
||||
expect(persisted.state.appearance.dataTableColumnWidthMode).toBe('compact');
|
||||
|
||||
vi.resetModules();
|
||||
const reloaded = await importStore();
|
||||
const appearance = reloaded.useStore.getState().appearance;
|
||||
|
||||
expect(appearance.showDataTableVerticalBorders).toBe(true);
|
||||
expect(appearance.dataTableColumnWidthMode).toBe('compact');
|
||||
});
|
||||
});
|
||||
@@ -1,6 +1,6 @@
|
||||
import { create } from 'zustand';
|
||||
import { persist } from 'zustand/middleware';
|
||||
import { ConnectionConfig, ProxyConfig, SavedConnection, TabData, SavedQuery, ConnectionTag, AIChatMessage, AIContextItem } from './types';
|
||||
import { ConnectionConfig, ProxyConfig, SavedConnection, TabData, SavedQuery, ConnectionTag, AIChatMessage, AIContextItem, GlobalProxyConfig } from './types';
|
||||
import {
|
||||
ShortcutAction,
|
||||
ShortcutBinding,
|
||||
@@ -9,8 +9,27 @@ import {
|
||||
cloneShortcutOptions,
|
||||
sanitizeShortcutOptions,
|
||||
} from './utils/shortcuts';
|
||||
import { toPersistedGlobalProxy } from './utils/globalProxyDraft';
|
||||
import {
|
||||
DEFAULT_DATA_GRID_DISPLAY_SETTINGS,
|
||||
sanitizeDataGridDisplaySettings,
|
||||
type DataGridDisplaySettings,
|
||||
} from './utils/dataGridDisplay';
|
||||
|
||||
const DEFAULT_APPEARANCE = { enabled: true, opacity: 1.0, blur: 0, useNativeMacWindowControls: false };
|
||||
export interface AppearanceSettings extends DataGridDisplaySettings {
|
||||
enabled: boolean;
|
||||
opacity: number;
|
||||
blur: number;
|
||||
useNativeMacWindowControls: boolean;
|
||||
}
|
||||
|
||||
export const DEFAULT_APPEARANCE: AppearanceSettings = {
|
||||
enabled: true,
|
||||
opacity: 1.0,
|
||||
blur: 0,
|
||||
useNativeMacWindowControls: false,
|
||||
...DEFAULT_DATA_GRID_DISPLAY_SETTINGS,
|
||||
};
|
||||
const DEFAULT_UI_SCALE = 1.0;
|
||||
const MIN_UI_SCALE = 0.8;
|
||||
const MAX_UI_SCALE = 1.25;
|
||||
@@ -25,7 +44,7 @@ const MAX_HOST_ENTRY_LENGTH = 512;
|
||||
const MAX_HOST_ENTRIES = 64;
|
||||
const DEFAULT_TIMEOUT_SECONDS = 30;
|
||||
const MAX_TIMEOUT_SECONDS = 3600;
|
||||
const PERSIST_VERSION = 7;
|
||||
const PERSIST_VERSION = 8;
|
||||
const DEFAULT_CONNECTION_TYPE = 'mysql';
|
||||
const DEFAULT_GLOBAL_PROXY: GlobalProxyConfig = {
|
||||
enabled: false,
|
||||
@@ -34,6 +53,7 @@ const DEFAULT_GLOBAL_PROXY: GlobalProxyConfig = {
|
||||
port: 1080,
|
||||
user: '',
|
||||
password: '',
|
||||
hasPassword: false,
|
||||
};
|
||||
const SUPPORTED_CONNECTION_TYPES = new Set([
|
||||
'mysql',
|
||||
@@ -246,6 +266,7 @@ const sanitizeConnectionConfig = (value: unknown): ConnectionConfig => {
|
||||
|
||||
const safeConfig: ConnectionConfig & Record<string, unknown> = {
|
||||
...raw,
|
||||
id: toTrimmedString(raw.id ?? raw.ID),
|
||||
type,
|
||||
host: toTrimmedString(raw.host, 'localhost') || 'localhost',
|
||||
port: normalizePort(raw.port, defaultPort),
|
||||
@@ -321,7 +342,16 @@ const sanitizeSavedConnection = (value: unknown, index: number): SavedConnection
|
||||
return {
|
||||
id,
|
||||
name,
|
||||
config,
|
||||
config: { ...config, id: config.id || id },
|
||||
secretRef: toTrimmedString(raw.secretRef) || undefined,
|
||||
hasPrimaryPassword: raw.hasPrimaryPassword === true,
|
||||
hasSSHPassword: raw.hasSSHPassword === true,
|
||||
hasProxyPassword: raw.hasProxyPassword === true,
|
||||
hasHttpTunnelPassword: raw.hasHttpTunnelPassword === true,
|
||||
hasMySQLReplicaPassword: raw.hasMySQLReplicaPassword === true,
|
||||
hasMongoReplicaPassword: raw.hasMongoReplicaPassword === true,
|
||||
hasOpaqueURI: raw.hasOpaqueURI === true,
|
||||
hasOpaqueDSN: raw.hasOpaqueDSN === true,
|
||||
includeDatabases: includeDatabases.length > 0 ? includeDatabases : undefined,
|
||||
includeRedisDatabases: includeRedisDatabases.length > 0 ? includeRedisDatabases : undefined,
|
||||
};
|
||||
@@ -393,10 +423,6 @@ export interface QueryOptions {
|
||||
showColumnType: boolean;
|
||||
}
|
||||
|
||||
export interface GlobalProxyConfig extends ProxyConfig {
|
||||
enabled: boolean;
|
||||
}
|
||||
|
||||
interface AppState {
|
||||
connections: SavedConnection[];
|
||||
connectionTags: ConnectionTag[];
|
||||
@@ -405,7 +431,7 @@ interface AppState {
|
||||
activeContext: { connectionId: string; dbName: string } | null;
|
||||
savedQueries: SavedQuery[];
|
||||
theme: 'light' | 'dark';
|
||||
appearance: { enabled: boolean; opacity: number; blur: number; useNativeMacWindowControls: boolean };
|
||||
appearance: AppearanceSettings;
|
||||
uiScale: number;
|
||||
fontSize: number;
|
||||
startupFullscreen: boolean;
|
||||
@@ -440,6 +466,7 @@ interface AppState {
|
||||
addConnection: (conn: SavedConnection) => void;
|
||||
updateConnection: (conn: SavedConnection) => void;
|
||||
removeConnection: (id: string) => void;
|
||||
replaceConnections: (connections: SavedConnection[]) => void;
|
||||
|
||||
addConnectionTag: (tag: ConnectionTag) => void;
|
||||
updateConnectionTag: (tag: ConnectionTag) => void;
|
||||
@@ -463,11 +490,12 @@ interface AppState {
|
||||
deleteQuery: (id: string) => void;
|
||||
|
||||
setTheme: (theme: 'light' | 'dark') => void;
|
||||
setAppearance: (appearance: Partial<{ enabled: boolean; opacity: number; blur: number; useNativeMacWindowControls: boolean }>) => void;
|
||||
setAppearance: (appearance: Partial<AppearanceSettings>) => void;
|
||||
setUiScale: (scale: number) => void;
|
||||
setFontSize: (size: number) => void;
|
||||
setStartupFullscreen: (enabled: boolean) => void;
|
||||
setGlobalProxy: (proxy: Partial<GlobalProxyConfig>) => void;
|
||||
replaceGlobalProxy: (proxy: Partial<GlobalProxyConfig>) => void;
|
||||
setSqlFormatOptions: (options: { keywordCase: 'upper' | 'lower' }) => void;
|
||||
setQueryOptions: (options: Partial<QueryOptions>) => void;
|
||||
updateShortcut: (action: ShortcutAction, binding: Partial<ShortcutBinding>) => void;
|
||||
@@ -586,12 +614,13 @@ const sanitizeTableHiddenColumns = (value: unknown): Record<string, string[]> =>
|
||||
};
|
||||
|
||||
const sanitizeAppearance = (
|
||||
appearance: Partial<{ enabled: boolean; opacity: number; blur: number; useNativeMacWindowControls: boolean }> | undefined,
|
||||
appearance: Partial<AppearanceSettings> | undefined,
|
||||
version: number
|
||||
): { enabled: boolean; opacity: number; blur: number; useNativeMacWindowControls: boolean } => {
|
||||
): AppearanceSettings => {
|
||||
if (!appearance || typeof appearance !== 'object') {
|
||||
return { ...DEFAULT_APPEARANCE };
|
||||
}
|
||||
const dataGridDisplaySettings = sanitizeDataGridDisplaySettings(appearance);
|
||||
const nextAppearance = {
|
||||
enabled: typeof appearance.enabled === 'boolean' ? appearance.enabled : DEFAULT_APPEARANCE.enabled,
|
||||
opacity: typeof appearance.opacity === 'number' ? appearance.opacity : DEFAULT_APPEARANCE.opacity,
|
||||
@@ -599,6 +628,8 @@ const sanitizeAppearance = (
|
||||
useNativeMacWindowControls: typeof appearance.useNativeMacWindowControls === 'boolean'
|
||||
? appearance.useNativeMacWindowControls
|
||||
: DEFAULT_APPEARANCE.useNativeMacWindowControls,
|
||||
showDataTableVerticalBorders: dataGridDisplaySettings.showDataTableVerticalBorders,
|
||||
dataTableColumnWidthMode: dataGridDisplaySettings.dataTableColumnWidthMode,
|
||||
};
|
||||
if (version < 2 && isLegacyDefaultAppearance(appearance)) {
|
||||
return { ...DEFAULT_APPEARANCE };
|
||||
@@ -618,18 +649,24 @@ const sanitizeFontSize = (value: unknown): number => {
|
||||
return normalizeIntegerInRange(value, DEFAULT_FONT_SIZE, MIN_FONT_SIZE, MAX_FONT_SIZE);
|
||||
};
|
||||
|
||||
const sanitizeGlobalProxy = (value: unknown): GlobalProxyConfig => {
|
||||
const sanitizeGlobalProxy = (
|
||||
value: unknown,
|
||||
options: { allowPassword?: boolean } = {}
|
||||
): GlobalProxyConfig => {
|
||||
const raw = (value && typeof value === 'object') ? value as Record<string, unknown> : {};
|
||||
const typeRaw = toTrimmedString(raw.type, DEFAULT_GLOBAL_PROXY.type).toLowerCase();
|
||||
const type: 'socks5' | 'http' = typeRaw === 'http' ? 'http' : 'socks5';
|
||||
const fallbackPort = type === 'http' ? 8080 : 1080;
|
||||
const password = toTrimmedString(raw.password);
|
||||
return {
|
||||
enabled: raw.enabled === true,
|
||||
type,
|
||||
host: toTrimmedString(raw.host),
|
||||
port: normalizePort(raw.port, fallbackPort),
|
||||
user: toTrimmedString(raw.user),
|
||||
password: toTrimmedString(raw.password),
|
||||
password: options.allowPassword === false ? '' : password,
|
||||
hasPassword: raw.hasPassword === true || password !== '',
|
||||
secretRef: toTrimmedString(raw.secretRef) || undefined,
|
||||
};
|
||||
};
|
||||
|
||||
@@ -782,6 +819,7 @@ export const useStore = create<AppState>()(
|
||||
connectionIds: tag.connectionIds.filter(cid => cid !== id)
|
||||
}))
|
||||
})),
|
||||
replaceConnections: (connections) => set({ connections: sanitizeConnections(connections) }),
|
||||
|
||||
addConnectionTag: (tag) => set((state) => ({ connectionTags: [...state.connectionTags, tag] })),
|
||||
updateConnectionTag: (tag) => set((state) => ({
|
||||
@@ -963,6 +1001,7 @@ export const useStore = create<AppState>()(
|
||||
setFontSize: (size) => set({ fontSize: sanitizeFontSize(size) }),
|
||||
setStartupFullscreen: (enabled) => set({ startupFullscreen: !!enabled }),
|
||||
setGlobalProxy: (proxy) => set((state) => ({ globalProxy: sanitizeGlobalProxy({ ...state.globalProxy, ...proxy }) })),
|
||||
replaceGlobalProxy: (proxy) => set({ globalProxy: sanitizeGlobalProxy({ ...DEFAULT_GLOBAL_PROXY, ...proxy }) }),
|
||||
setSqlFormatOptions: (options) => set({ sqlFormatOptions: options }),
|
||||
setQueryOptions: (options) => set((state) => ({ queryOptions: { ...state.queryOptions, ...options } })),
|
||||
updateShortcut: (action, binding) => set((state) => ({
|
||||
@@ -1203,7 +1242,7 @@ export const useStore = create<AppState>()(
|
||||
migrate: (persistedState: unknown, version: number) => {
|
||||
const state = unwrapPersistedAppState(persistedState) as Partial<AppState>;
|
||||
const nextState: Partial<AppState> = { ...state };
|
||||
nextState.connections = sanitizeConnections(state.connections);
|
||||
nextState.connections = [];
|
||||
if (version < 5) {
|
||||
nextState.connectionTags = sanitizeConnectionTags(state.connectionTags);
|
||||
} else {
|
||||
@@ -1215,7 +1254,7 @@ export const useStore = create<AppState>()(
|
||||
nextState.uiScale = sanitizeUiScale(state.uiScale);
|
||||
nextState.fontSize = sanitizeFontSize(state.fontSize);
|
||||
nextState.startupFullscreen = sanitizeStartupFullscreen(state.startupFullscreen);
|
||||
nextState.globalProxy = sanitizeGlobalProxy(state.globalProxy);
|
||||
nextState.globalProxy = sanitizeGlobalProxy(state.globalProxy, { allowPassword: false });
|
||||
nextState.sqlFormatOptions = sanitizeSqlFormatOptions(state.sqlFormatOptions);
|
||||
nextState.queryOptions = sanitizeQueryOptions(state.queryOptions);
|
||||
nextState.shortcutOptions = sanitizeShortcutOptions(state.shortcutOptions);
|
||||
@@ -1242,7 +1281,7 @@ export const useStore = create<AppState>()(
|
||||
return {
|
||||
...currentState,
|
||||
...state,
|
||||
connections: sanitizeConnections(state.connections),
|
||||
connections: currentState.connections,
|
||||
connectionTags: sanitizeConnectionTags(state.connectionTags),
|
||||
savedQueries: sanitizeSavedQueries(state.savedQueries),
|
||||
theme: sanitizeTheme(state.theme),
|
||||
@@ -1250,7 +1289,7 @@ export const useStore = create<AppState>()(
|
||||
uiScale: sanitizeUiScale(state.uiScale),
|
||||
fontSize: sanitizeFontSize(state.fontSize),
|
||||
startupFullscreen: sanitizeStartupFullscreen(state.startupFullscreen),
|
||||
globalProxy: sanitizeGlobalProxy(state.globalProxy),
|
||||
globalProxy: sanitizeGlobalProxy(state.globalProxy, { allowPassword: false }),
|
||||
tableSortPreference: sanitizeTableSortPreference(state.tableSortPreference),
|
||||
tableColumnOrders: sanitizeTableColumnOrders(state.tableColumnOrders),
|
||||
enableColumnOrderMemory: state.enableColumnOrderMemory !== false,
|
||||
@@ -1271,7 +1310,6 @@ export const useStore = create<AppState>()(
|
||||
};
|
||||
},
|
||||
partialize: (state) => ({
|
||||
connections: state.connections,
|
||||
connectionTags: state.connectionTags,
|
||||
savedQueries: state.savedQueries,
|
||||
theme: state.theme,
|
||||
@@ -1279,7 +1317,7 @@ export const useStore = create<AppState>()(
|
||||
uiScale: state.uiScale,
|
||||
fontSize: state.fontSize,
|
||||
startupFullscreen: state.startupFullscreen,
|
||||
globalProxy: state.globalProxy,
|
||||
globalProxy: toPersistedGlobalProxy(state.globalProxy),
|
||||
sqlFormatOptions: state.sqlFormatOptions,
|
||||
queryOptions: state.queryOptions,
|
||||
shortcutOptions: state.shortcutOptions,
|
||||
|
||||
@@ -22,6 +22,7 @@ export interface HTTPTunnelConfig {
|
||||
}
|
||||
|
||||
export interface ConnectionConfig {
|
||||
id?: string;
|
||||
type: string;
|
||||
host: string;
|
||||
port: number;
|
||||
@@ -70,12 +71,27 @@ export interface SavedConnection {
|
||||
id: string;
|
||||
name: string;
|
||||
config: ConnectionConfig;
|
||||
secretRef?: string;
|
||||
hasPrimaryPassword?: boolean;
|
||||
hasSSHPassword?: boolean;
|
||||
hasProxyPassword?: boolean;
|
||||
hasHttpTunnelPassword?: boolean;
|
||||
hasMySQLReplicaPassword?: boolean;
|
||||
hasMongoReplicaPassword?: boolean;
|
||||
hasOpaqueURI?: boolean;
|
||||
hasOpaqueDSN?: boolean;
|
||||
includeDatabases?: string[];
|
||||
includeRedisDatabases?: number[]; // Redis databases to show (0-15)
|
||||
iconType?: string; // 自定义图标类型(如 'mysql','postgres'),不填则取 config.type
|
||||
iconColor?: string; // 自定义图标颜色(十六进制),不填则取类型默认色
|
||||
}
|
||||
|
||||
export interface GlobalProxyConfig extends ProxyConfig {
|
||||
enabled: boolean;
|
||||
hasPassword?: boolean;
|
||||
secretRef?: string;
|
||||
}
|
||||
|
||||
export interface ConnectionTag {
|
||||
id: string;
|
||||
name: string;
|
||||
@@ -201,6 +217,8 @@ export interface AIProviderConfig {
|
||||
type: AIProviderType;
|
||||
name: string;
|
||||
apiKey: string;
|
||||
secretRef?: string;
|
||||
hasSecret?: boolean;
|
||||
baseUrl: string;
|
||||
model: string;
|
||||
models?: string[];
|
||||
@@ -243,3 +261,5 @@ export interface AISafetyResult {
|
||||
requiresConfirm: boolean;
|
||||
warningMessage?: string;
|
||||
}
|
||||
|
||||
|
||||
|
||||
49
frontend/src/utils/aiProviderEditorState.test.ts
Normal file
49
frontend/src/utils/aiProviderEditorState.test.ts
Normal file
@@ -0,0 +1,49 @@
|
||||
import { describe, expect, it } from 'vitest';
|
||||
|
||||
import {
|
||||
buildAddProviderEditorSession,
|
||||
buildClosedProviderEditorSession,
|
||||
buildEditProviderEditorSession,
|
||||
} from './aiProviderEditorState';
|
||||
|
||||
describe('aiProviderEditorState', () => {
|
||||
it('resets clearProviderSecret when starting add flow', () => {
|
||||
const session = buildAddProviderEditorSession({
|
||||
previousClearProviderSecret: true,
|
||||
presetBackendType: 'openai',
|
||||
presetBaseUrl: 'https://api.openai.com/v1',
|
||||
presetModel: 'gpt-4.1',
|
||||
});
|
||||
|
||||
expect(session.clearProviderSecret).toBe(false);
|
||||
expect(session.isEditing).toBe(true);
|
||||
expect(session.testStatus).toBe('idle');
|
||||
});
|
||||
|
||||
it('resets clearProviderSecret when starting edit flow', () => {
|
||||
const session = buildEditProviderEditorSession({
|
||||
previousClearProviderSecret: true,
|
||||
provider: {
|
||||
id: 'provider-1',
|
||||
type: 'openai',
|
||||
name: 'OpenAI',
|
||||
apiKey: '',
|
||||
hasSecret: true,
|
||||
},
|
||||
});
|
||||
|
||||
expect(session.clearProviderSecret).toBe(false);
|
||||
expect(session.isEditing).toBe(true);
|
||||
expect(session.editingProvider?.id).toBe('provider-1');
|
||||
});
|
||||
|
||||
it('resets clearProviderSecret when the modal closes', () => {
|
||||
const session = buildClosedProviderEditorSession({
|
||||
previousClearProviderSecret: true,
|
||||
});
|
||||
|
||||
expect(session.clearProviderSecret).toBe(false);
|
||||
expect(session.isEditing).toBe(false);
|
||||
expect(session.editingProvider).toBeNull();
|
||||
});
|
||||
});
|
||||
92
frontend/src/utils/aiProviderEditorState.ts
Normal file
92
frontend/src/utils/aiProviderEditorState.ts
Normal file
@@ -0,0 +1,92 @@
|
||||
import type { AIProviderConfig, AIProviderType } from '../types';
|
||||
|
||||
type ProviderEditorStatus = 'idle' | 'success' | 'error';
|
||||
|
||||
type ProviderEditorConfig = Partial<AIProviderConfig> & Pick<AIProviderConfig, 'id' | 'type' | 'name' | 'apiKey'> & { presetKey?: string };
|
||||
|
||||
export interface ProviderEditorSession {
|
||||
editingProvider: ProviderEditorConfig | null;
|
||||
formValues: Record<string, unknown> | null;
|
||||
isEditing: boolean;
|
||||
clearProviderSecret: boolean;
|
||||
testStatus: ProviderEditorStatus;
|
||||
}
|
||||
|
||||
interface BuildAddProviderEditorSessionInput {
|
||||
previousClearProviderSecret?: boolean;
|
||||
presetKey?: string;
|
||||
presetBackendType: AIProviderType;
|
||||
presetBaseUrl: string;
|
||||
presetModel: string;
|
||||
presetModels?: string[];
|
||||
apiFormat?: string;
|
||||
}
|
||||
|
||||
interface BuildEditProviderEditorSessionInput {
|
||||
previousClearProviderSecret?: boolean;
|
||||
provider: ProviderEditorConfig;
|
||||
formValues?: Record<string, unknown>;
|
||||
}
|
||||
|
||||
interface BuildClosedProviderEditorSessionInput {
|
||||
previousClearProviderSecret?: boolean;
|
||||
}
|
||||
|
||||
export const buildAddProviderEditorSession = ({
|
||||
presetKey = 'openai',
|
||||
presetBackendType,
|
||||
presetBaseUrl,
|
||||
presetModel,
|
||||
presetModels = [],
|
||||
apiFormat = 'openai',
|
||||
}: BuildAddProviderEditorSessionInput): ProviderEditorSession => {
|
||||
const editingProvider: ProviderEditorConfig = {
|
||||
id: '',
|
||||
type: presetBackendType,
|
||||
name: '',
|
||||
apiKey: '',
|
||||
baseUrl: presetBaseUrl,
|
||||
model: presetModel,
|
||||
models: [...presetModels],
|
||||
maxTokens: 4096,
|
||||
temperature: 0.7,
|
||||
presetKey,
|
||||
};
|
||||
|
||||
return {
|
||||
editingProvider,
|
||||
formValues: {
|
||||
...editingProvider,
|
||||
presetKey,
|
||||
apiFormat,
|
||||
},
|
||||
isEditing: true,
|
||||
clearProviderSecret: false,
|
||||
testStatus: 'idle',
|
||||
};
|
||||
};
|
||||
|
||||
export const buildEditProviderEditorSession = ({
|
||||
provider,
|
||||
formValues,
|
||||
}: BuildEditProviderEditorSessionInput): ProviderEditorSession => ({
|
||||
editingProvider: provider,
|
||||
formValues: formValues || {
|
||||
...provider,
|
||||
models: provider.models || [],
|
||||
presetKey: provider.presetKey,
|
||||
apiFormat: provider.apiFormat || 'openai',
|
||||
},
|
||||
isEditing: true,
|
||||
clearProviderSecret: false,
|
||||
testStatus: 'idle',
|
||||
});
|
||||
|
||||
export const buildClosedProviderEditorSession = (_input?: BuildClosedProviderEditorSessionInput): ProviderEditorSession => ({
|
||||
editingProvider: null,
|
||||
formValues: null,
|
||||
isEditing: false,
|
||||
clearProviderSecret: false,
|
||||
testStatus: 'idle',
|
||||
});
|
||||
|
||||
26
frontend/src/utils/browserMockConnections.test.ts
Normal file
26
frontend/src/utils/browserMockConnections.test.ts
Normal file
@@ -0,0 +1,26 @@
|
||||
import { describe, expect, it } from 'vitest';
|
||||
|
||||
import { duplicateBrowserMockConnection } from './browserMockConnections';
|
||||
|
||||
describe('duplicateBrowserMockConnection', () => {
|
||||
it('rewrites config.id to match the duplicated top-level id', () => {
|
||||
const duplicated = duplicateBrowserMockConnection({
|
||||
existing: {
|
||||
id: 'conn-1',
|
||||
name: 'Primary',
|
||||
config: {
|
||||
id: 'conn-1',
|
||||
type: 'postgres',
|
||||
},
|
||||
includeDatabases: ['appdb'],
|
||||
},
|
||||
items: [],
|
||||
nextId: 'conn-2',
|
||||
});
|
||||
|
||||
expect(duplicated.id).toBe('conn-2');
|
||||
expect(duplicated.config.id).toBe('conn-2');
|
||||
expect(duplicated.name).toBe('Primary - 副本');
|
||||
expect(duplicated.includeDatabases).toEqual(['appdb']);
|
||||
});
|
||||
});
|
||||
47
frontend/src/utils/browserMockConnections.ts
Normal file
47
frontend/src/utils/browserMockConnections.ts
Normal file
@@ -0,0 +1,47 @@
|
||||
export const cloneBrowserMockValue = <T,>(value: T): T => {
|
||||
try {
|
||||
return JSON.parse(JSON.stringify(value));
|
||||
} catch {
|
||||
return value;
|
||||
}
|
||||
};
|
||||
|
||||
export const resolveBrowserMockSecretFlag = (nextValue: unknown, clearFlag: boolean, existingFlag?: boolean) => {
|
||||
if (String(nextValue ?? '') !== '') return true;
|
||||
if (clearFlag) return false;
|
||||
return !!existingFlag;
|
||||
};
|
||||
|
||||
export const buildBrowserMockDuplicateName = (rawName: string, items: any[]): string => {
|
||||
const baseName = String(rawName || '').trim() || '连接';
|
||||
const suffix = ' - 副本';
|
||||
const usedNames = new Set(items.map((item) => String(item?.name || '').trim()));
|
||||
let candidate = `${baseName}${suffix}`;
|
||||
let counter = 2;
|
||||
while (usedNames.has(candidate)) {
|
||||
candidate = `${baseName}${suffix} ${counter}`;
|
||||
counter += 1;
|
||||
}
|
||||
return candidate;
|
||||
};
|
||||
|
||||
interface DuplicateBrowserMockConnectionInput {
|
||||
existing: any;
|
||||
items: any[];
|
||||
nextId: string;
|
||||
}
|
||||
|
||||
export const duplicateBrowserMockConnection = ({ existing, items, nextId }: DuplicateBrowserMockConnectionInput) => {
|
||||
const duplicated = cloneBrowserMockValue({
|
||||
...existing,
|
||||
id: nextId,
|
||||
name: buildBrowserMockDuplicateName(existing?.name, items),
|
||||
config: {
|
||||
...cloneBrowserMockValue(existing?.config),
|
||||
id: nextId,
|
||||
},
|
||||
includeDatabases: Array.isArray(existing?.includeDatabases) ? [...existing.includeDatabases] : undefined,
|
||||
includeRedisDatabases: Array.isArray(existing?.includeRedisDatabases) ? [...existing.includeRedisDatabases] : undefined,
|
||||
});
|
||||
return duplicated;
|
||||
};
|
||||
104
frontend/src/utils/connectionRpcConfig.test.ts
Normal file
104
frontend/src/utils/connectionRpcConfig.test.ts
Normal file
@@ -0,0 +1,104 @@
|
||||
import { describe, expect, it } from 'vitest';
|
||||
|
||||
import { connection } from '../../wailsjs/go/models';
|
||||
import { buildRpcConnectionConfig } from './connectionRpcConfig';
|
||||
|
||||
describe('buildRpcConnectionConfig', () => {
|
||||
it('preserves the saved connection id while normalizing numeric fields', () => {
|
||||
const result = buildRpcConnectionConfig({
|
||||
id: 'conn-1',
|
||||
type: 'postgres',
|
||||
host: 'db.local',
|
||||
port: '5432' as unknown as number,
|
||||
user: 'postgres',
|
||||
useSSH: true,
|
||||
ssh: {
|
||||
host: 'bastion.local',
|
||||
port: '2222' as unknown as number,
|
||||
user: 'ops',
|
||||
},
|
||||
useProxy: true,
|
||||
proxy: {
|
||||
type: 'http',
|
||||
host: '127.0.0.1',
|
||||
port: '8080' as unknown as number,
|
||||
},
|
||||
} as any, {
|
||||
id: 'conn-2',
|
||||
timeout: '120' as unknown as number,
|
||||
redisDB: '6' as unknown as number,
|
||||
database: 'app',
|
||||
});
|
||||
|
||||
expect(result.id).toBe('conn-1');
|
||||
expect(result.port).toBe(5432);
|
||||
expect(result.ssh?.port).toBe(2222);
|
||||
expect(result.proxy?.port).toBe(8080);
|
||||
expect(result.timeout).toBe(120);
|
||||
expect(result.redisDB).toBe(6);
|
||||
expect(result.database).toBe('app');
|
||||
});
|
||||
|
||||
it('fills default nested config blocks needed by RPC calls', () => {
|
||||
const result = buildRpcConnectionConfig({
|
||||
id: 'conn-redis',
|
||||
type: 'redis',
|
||||
host: '127.0.0.1',
|
||||
port: 6379,
|
||||
user: '',
|
||||
} as any, {
|
||||
useSSH: true,
|
||||
useHttpTunnel: true,
|
||||
redisDB: '4' as unknown as number,
|
||||
});
|
||||
|
||||
expect(result.id).toBe('conn-redis');
|
||||
expect(result.redisDB).toBe(4);
|
||||
expect(result.ssh).toEqual({
|
||||
host: '',
|
||||
port: 22,
|
||||
user: '',
|
||||
password: '',
|
||||
keyPath: '',
|
||||
});
|
||||
expect(result.httpTunnel).toEqual({
|
||||
host: '',
|
||||
port: 8080,
|
||||
user: '',
|
||||
password: '',
|
||||
});
|
||||
});
|
||||
|
||||
it('returns a Wails connection model instance for RPC compatibility', () => {
|
||||
const result = buildRpcConnectionConfig({
|
||||
id: 'conn-model',
|
||||
type: 'mysql',
|
||||
host: '127.0.0.1',
|
||||
port: '3306' as unknown as number,
|
||||
user: 'root',
|
||||
useSSH: true,
|
||||
ssh: {
|
||||
host: 'jump.local',
|
||||
port: '2222' as unknown as number,
|
||||
user: 'ops',
|
||||
},
|
||||
useProxy: true,
|
||||
proxy: {
|
||||
type: 'http',
|
||||
host: '127.0.0.1',
|
||||
port: '8080' as unknown as number,
|
||||
},
|
||||
useHttpTunnel: true,
|
||||
httpTunnel: {
|
||||
host: '127.0.0.1',
|
||||
port: '9000' as unknown as number,
|
||||
},
|
||||
} as any);
|
||||
|
||||
expect(result).toBeInstanceOf(connection.ConnectionConfig);
|
||||
expect(result.ssh).toBeInstanceOf(connection.SSHConfig);
|
||||
expect(result.proxy).toBeInstanceOf(connection.ProxyConfig);
|
||||
expect(result.httpTunnel).toBeInstanceOf(connection.HTTPTunnelConfig);
|
||||
expect(typeof (result as any).convertValues).toBe('function');
|
||||
});
|
||||
});
|
||||
122
frontend/src/utils/connectionRpcConfig.ts
Normal file
122
frontend/src/utils/connectionRpcConfig.ts
Normal file
@@ -0,0 +1,122 @@
|
||||
import { connection } from '../../wailsjs/go/models';
|
||||
|
||||
export type RpcConnectionConfig = connection.ConnectionConfig & { id?: string };
|
||||
type ConnectionConfigInput = {
|
||||
id?: string;
|
||||
ssh?: Record<string, any>;
|
||||
proxy?: Record<string, any>;
|
||||
httpTunnel?: Record<string, any>;
|
||||
[key: string]: any;
|
||||
};
|
||||
type SSHConfigInput = Record<string, any>;
|
||||
type ProxyConfigInput = Record<string, any>;
|
||||
type HttpTunnelConfigInput = Record<string, any>;
|
||||
|
||||
const toStringValue = (value: unknown, fallback = ''): string => {
|
||||
if (typeof value === 'string') {
|
||||
return value;
|
||||
}
|
||||
if (typeof value === 'number' || typeof value === 'boolean') {
|
||||
return String(value);
|
||||
}
|
||||
return fallback;
|
||||
};
|
||||
|
||||
const toOptionalInteger = (value: unknown, fallback?: number): number | undefined => {
|
||||
if (value === undefined || value === null || value === '') {
|
||||
return fallback;
|
||||
}
|
||||
const parsed = Number(value);
|
||||
if (!Number.isFinite(parsed)) {
|
||||
return fallback;
|
||||
}
|
||||
return Math.trunc(parsed);
|
||||
};
|
||||
|
||||
const normalizeProxyType = (value: unknown): 'socks5' | 'http' => {
|
||||
return toStringValue(value).toLowerCase() === 'http' ? 'http' : 'socks5';
|
||||
};
|
||||
|
||||
const normalizeSSHConfig = (value: unknown): connection.SSHConfig => {
|
||||
const raw = (value ?? {}) as SSHConfigInput;
|
||||
return new connection.SSHConfig({
|
||||
host: toStringValue(raw.host),
|
||||
port: toOptionalInteger(raw.port, 22) ?? 22,
|
||||
user: toStringValue(raw.user),
|
||||
password: toStringValue(raw.password),
|
||||
keyPath: toStringValue(raw.keyPath),
|
||||
});
|
||||
};
|
||||
|
||||
const normalizeProxyConfig = (value: unknown): connection.ProxyConfig => {
|
||||
const raw = (value ?? {}) as ProxyConfigInput;
|
||||
const type = normalizeProxyType(raw.type);
|
||||
return new connection.ProxyConfig({
|
||||
type,
|
||||
host: toStringValue(raw.host),
|
||||
port: toOptionalInteger(raw.port, type === 'http' ? 8080 : 1080) ?? (type === 'http' ? 8080 : 1080),
|
||||
user: toStringValue(raw.user),
|
||||
password: toStringValue(raw.password),
|
||||
});
|
||||
};
|
||||
|
||||
const normalizeHttpTunnelConfig = (value: unknown): connection.HTTPTunnelConfig => {
|
||||
const raw = (value ?? {}) as HttpTunnelConfigInput;
|
||||
return new connection.HTTPTunnelConfig({
|
||||
host: toStringValue(raw.host),
|
||||
port: toOptionalInteger(raw.port, 8080) ?? 8080,
|
||||
user: toStringValue(raw.user),
|
||||
password: toStringValue(raw.password),
|
||||
});
|
||||
};
|
||||
|
||||
export function buildRpcConnectionConfig(
|
||||
config: ConnectionConfigInput,
|
||||
overrides: ConnectionConfigInput = {},
|
||||
): RpcConnectionConfig {
|
||||
const mergedSSH = {
|
||||
...(config.ssh ?? {}),
|
||||
...(overrides.ssh ?? {}),
|
||||
};
|
||||
const mergedProxy = {
|
||||
...(config.proxy ?? {}),
|
||||
...(overrides.proxy ?? {}),
|
||||
};
|
||||
const mergedHttpTunnel = {
|
||||
...(config.httpTunnel ?? {}),
|
||||
...(overrides.httpTunnel ?? {}),
|
||||
};
|
||||
const merged: ConnectionConfigInput = {
|
||||
...config,
|
||||
...overrides,
|
||||
ssh: mergedSSH,
|
||||
proxy: mergedProxy,
|
||||
httpTunnel: mergedHttpTunnel,
|
||||
};
|
||||
|
||||
const baseId = toStringValue(config.id).trim() || toStringValue(overrides.id).trim() || undefined;
|
||||
const timeout = toOptionalInteger(merged.timeout, toOptionalInteger(config.timeout));
|
||||
const redisDB = toOptionalInteger(merged.redisDB, toOptionalInteger(config.redisDB));
|
||||
|
||||
const rpcConfig = new connection.ConnectionConfig({
|
||||
...merged,
|
||||
type: toStringValue(merged.type),
|
||||
host: toStringValue(merged.host),
|
||||
port: toOptionalInteger(merged.port, toOptionalInteger(config.port, 0)) ?? 0,
|
||||
user: toStringValue(merged.user),
|
||||
password: toStringValue(merged.password),
|
||||
database: toStringValue(merged.database),
|
||||
useSSH: merged.useSSH === true,
|
||||
ssh: normalizeSSHConfig(merged.ssh),
|
||||
useProxy: merged.useProxy === true,
|
||||
proxy: normalizeProxyConfig(merged.proxy),
|
||||
useHttpTunnel: merged.useHttpTunnel === true,
|
||||
httpTunnel: normalizeHttpTunnelConfig(merged.httpTunnel),
|
||||
timeout,
|
||||
redisDB,
|
||||
}) as RpcConnectionConfig;
|
||||
|
||||
rpcConfig.id = baseId;
|
||||
return rpcConfig;
|
||||
}
|
||||
|
||||
86
frontend/src/utils/connectionSecretDraft.test.ts
Normal file
86
frontend/src/utils/connectionSecretDraft.test.ts
Normal file
@@ -0,0 +1,86 @@
|
||||
import { describe, expect, it } from 'vitest';
|
||||
|
||||
import { resolveConnectionSecretDraft } from './connectionSecretDraft';
|
||||
|
||||
describe('resolveConnectionSecretDraft', () => {
|
||||
it('keeps an existing stored secret when edit form leaves the field blank', () => {
|
||||
const result = resolveConnectionSecretDraft({
|
||||
hasSecret: true,
|
||||
valueInput: '',
|
||||
clearSecret: false,
|
||||
});
|
||||
|
||||
expect(result.value).toBe('');
|
||||
expect(result.clearStoredSecret).toBe(false);
|
||||
expect(result.keepsStoredSecret).toBe(true);
|
||||
expect(result.hasSecretAfterSave).toBe(true);
|
||||
});
|
||||
|
||||
it('replaces the stored secret when a new value is entered', () => {
|
||||
const result = resolveConnectionSecretDraft({
|
||||
hasSecret: true,
|
||||
valueInput: ' mongodb://demo ',
|
||||
clearSecret: false,
|
||||
trimInput: true,
|
||||
});
|
||||
|
||||
expect(result.value).toBe('mongodb://demo');
|
||||
expect(result.clearStoredSecret).toBe(false);
|
||||
expect(result.keepsStoredSecret).toBe(false);
|
||||
expect(result.hasSecretAfterSave).toBe(true);
|
||||
});
|
||||
|
||||
it('clears the stored secret when explicitly requested', () => {
|
||||
const result = resolveConnectionSecretDraft({
|
||||
hasSecret: true,
|
||||
valueInput: '',
|
||||
clearSecret: true,
|
||||
});
|
||||
|
||||
expect(result.value).toBe('');
|
||||
expect(result.clearStoredSecret).toBe(true);
|
||||
expect(result.keepsStoredSecret).toBe(false);
|
||||
expect(result.hasSecretAfterSave).toBe(false);
|
||||
});
|
||||
|
||||
it('prefers a newly entered value over a stale clear toggle', () => {
|
||||
const result = resolveConnectionSecretDraft({
|
||||
hasSecret: true,
|
||||
valueInput: 'new-password',
|
||||
clearSecret: true,
|
||||
});
|
||||
|
||||
expect(result.value).toBe('new-password');
|
||||
expect(result.clearStoredSecret).toBe(false);
|
||||
expect(result.keepsStoredSecret).toBe(false);
|
||||
expect(result.hasSecretAfterSave).toBe(true);
|
||||
});
|
||||
|
||||
it('does not emit a clear flag for a brand new blank field', () => {
|
||||
const result = resolveConnectionSecretDraft({
|
||||
hasSecret: false,
|
||||
valueInput: '',
|
||||
clearSecret: false,
|
||||
});
|
||||
|
||||
expect(result.value).toBe('');
|
||||
expect(result.clearStoredSecret).toBe(false);
|
||||
expect(result.keepsStoredSecret).toBe(false);
|
||||
expect(result.hasSecretAfterSave).toBe(false);
|
||||
});
|
||||
|
||||
it('supports force clearing stored secrets', () => {
|
||||
const result = resolveConnectionSecretDraft({
|
||||
hasSecret: true,
|
||||
valueInput: 'temporary',
|
||||
clearSecret: false,
|
||||
forceClear: true,
|
||||
});
|
||||
|
||||
expect(result.value).toBe('');
|
||||
expect(result.clearStoredSecret).toBe(true);
|
||||
expect(result.keepsStoredSecret).toBe(false);
|
||||
expect(result.hasSecretAfterSave).toBe(false);
|
||||
});
|
||||
});
|
||||
|
||||
63
frontend/src/utils/connectionSecretDraft.ts
Normal file
63
frontend/src/utils/connectionSecretDraft.ts
Normal file
@@ -0,0 +1,63 @@
|
||||
export interface ConnectionSecretDraftInput {
|
||||
valueInput?: string;
|
||||
hasSecret?: boolean;
|
||||
clearSecret?: boolean;
|
||||
forceClear?: boolean;
|
||||
trimInput?: boolean;
|
||||
}
|
||||
|
||||
export interface ConnectionSecretDraftResult {
|
||||
value: string;
|
||||
clearStoredSecret: boolean;
|
||||
keepsStoredSecret: boolean;
|
||||
hasSecretAfterSave: boolean;
|
||||
}
|
||||
|
||||
export function resolveConnectionSecretDraft(input: ConnectionSecretDraftInput): ConnectionSecretDraftResult {
|
||||
const rawValue = input.valueInput ?? '';
|
||||
const value = input.trimInput ? String(rawValue).trim() : String(rawValue);
|
||||
|
||||
if (input.forceClear) {
|
||||
return {
|
||||
value: '',
|
||||
clearStoredSecret: true,
|
||||
keepsStoredSecret: false,
|
||||
hasSecretAfterSave: false,
|
||||
};
|
||||
}
|
||||
|
||||
if (value !== '') {
|
||||
return {
|
||||
value,
|
||||
clearStoredSecret: false,
|
||||
keepsStoredSecret: false,
|
||||
hasSecretAfterSave: true,
|
||||
};
|
||||
}
|
||||
|
||||
if (input.clearSecret) {
|
||||
return {
|
||||
value: '',
|
||||
clearStoredSecret: true,
|
||||
keepsStoredSecret: false,
|
||||
hasSecretAfterSave: false,
|
||||
};
|
||||
}
|
||||
|
||||
if (input.hasSecret) {
|
||||
return {
|
||||
value: '',
|
||||
clearStoredSecret: false,
|
||||
keepsStoredSecret: true,
|
||||
hasSecretAfterSave: true,
|
||||
};
|
||||
}
|
||||
|
||||
return {
|
||||
value: '',
|
||||
clearStoredSecret: false,
|
||||
keepsStoredSecret: false,
|
||||
hasSecretAfterSave: false,
|
||||
};
|
||||
}
|
||||
|
||||
37
frontend/src/utils/customConnectionDsn.test.ts
Normal file
37
frontend/src/utils/customConnectionDsn.test.ts
Normal file
@@ -0,0 +1,37 @@
|
||||
import { describe, expect, it } from 'vitest';
|
||||
|
||||
import { shouldAllowBlankCustomDsn } from './customConnectionDsn';
|
||||
|
||||
describe('shouldAllowBlankCustomDsn', () => {
|
||||
it('allows a blank DSN when editing a connection that already has a stored opaque DSN', () => {
|
||||
expect(shouldAllowBlankCustomDsn({
|
||||
dsnInput: '',
|
||||
hasStoredSecret: true,
|
||||
clearStoredSecret: false,
|
||||
})).toBe(true);
|
||||
});
|
||||
|
||||
it('requires a new DSN when the user chooses to clear the stored opaque DSN', () => {
|
||||
expect(shouldAllowBlankCustomDsn({
|
||||
dsnInput: '',
|
||||
hasStoredSecret: true,
|
||||
clearStoredSecret: true,
|
||||
})).toBe(false);
|
||||
});
|
||||
|
||||
it('requires a DSN for brand new custom connections', () => {
|
||||
expect(shouldAllowBlankCustomDsn({
|
||||
dsnInput: '',
|
||||
hasStoredSecret: false,
|
||||
clearStoredSecret: false,
|
||||
})).toBe(false);
|
||||
});
|
||||
|
||||
it('accepts a newly entered DSN even when a stored secret already exists', () => {
|
||||
expect(shouldAllowBlankCustomDsn({
|
||||
dsnInput: 'driver://demo',
|
||||
hasStoredSecret: true,
|
||||
clearStoredSecret: true,
|
||||
})).toBe(true);
|
||||
});
|
||||
});
|
||||
27
frontend/src/utils/customConnectionDsn.ts
Normal file
27
frontend/src/utils/customConnectionDsn.ts
Normal file
@@ -0,0 +1,27 @@
|
||||
export interface CustomConnectionDsnState {
|
||||
dsnInput: unknown;
|
||||
hasStoredSecret?: boolean;
|
||||
clearStoredSecret?: boolean;
|
||||
}
|
||||
|
||||
export const getCustomConnectionDsnValidationMessage = ({
|
||||
dsnInput,
|
||||
hasStoredSecret,
|
||||
clearStoredSecret,
|
||||
}: CustomConnectionDsnState): string | null => {
|
||||
const dsnText = String(dsnInput ?? '').trim();
|
||||
if (dsnText !== '') {
|
||||
return null;
|
||||
}
|
||||
if (hasStoredSecret && !clearStoredSecret) {
|
||||
return null;
|
||||
}
|
||||
if (hasStoredSecret && clearStoredSecret) {
|
||||
return '请输入新的连接字符串,或取消清除已保存 DSN';
|
||||
}
|
||||
return '请输入连接字符串';
|
||||
};
|
||||
|
||||
export const shouldAllowBlankCustomDsn = (state: CustomConnectionDsnState): boolean => (
|
||||
getCustomConnectionDsnValidationMessage(state) === null
|
||||
);
|
||||
32
frontend/src/utils/dataGridDisplay.test.ts
Normal file
32
frontend/src/utils/dataGridDisplay.test.ts
Normal file
@@ -0,0 +1,32 @@
|
||||
import { describe, expect, it } from 'vitest';
|
||||
|
||||
import {
|
||||
DEFAULT_DATA_GRID_DISPLAY_SETTINGS,
|
||||
resolveDataTableColumnWidth,
|
||||
resolveDataTableDefaultColumnWidth,
|
||||
resolveDataTableVerticalBorderColor,
|
||||
sanitizeDataGridDisplaySettings,
|
||||
} from './dataGridDisplay';
|
||||
|
||||
describe('dataGridDisplay helpers', () => {
|
||||
it('sanitizes missing display settings to safe defaults', () => {
|
||||
expect(sanitizeDataGridDisplaySettings(undefined)).toEqual(DEFAULT_DATA_GRID_DISPLAY_SETTINGS);
|
||||
expect(sanitizeDataGridDisplaySettings({ dataTableColumnWidthMode: 'invalid' as never })).toEqual(DEFAULT_DATA_GRID_DISPLAY_SETTINGS);
|
||||
});
|
||||
|
||||
it('resolves standard and compact default column widths', () => {
|
||||
expect(resolveDataTableDefaultColumnWidth('standard')).toBe(200);
|
||||
expect(resolveDataTableDefaultColumnWidth('compact')).toBe(140);
|
||||
});
|
||||
|
||||
it('keeps manual column widths ahead of mode defaults', () => {
|
||||
expect(resolveDataTableColumnWidth({ manualWidth: 320, widthMode: 'compact' })).toBe(320);
|
||||
expect(resolveDataTableColumnWidth({ manualWidth: undefined, widthMode: 'compact' })).toBe(140);
|
||||
});
|
||||
|
||||
it('uses subtle themed vertical border colors and transparent when disabled', () => {
|
||||
expect(resolveDataTableVerticalBorderColor({ darkMode: true, visible: true })).toBe('rgba(255, 255, 255, 0.08)');
|
||||
expect(resolveDataTableVerticalBorderColor({ darkMode: false, visible: true })).toBe('rgba(15, 23, 42, 0.08)');
|
||||
expect(resolveDataTableVerticalBorderColor({ darkMode: false, visible: false })).toBe('transparent');
|
||||
});
|
||||
});
|
||||
72
frontend/src/utils/dataGridDisplay.ts
Normal file
72
frontend/src/utils/dataGridDisplay.ts
Normal file
@@ -0,0 +1,72 @@
|
||||
export type DataTableColumnWidthMode = 'standard' | 'compact';
|
||||
|
||||
export interface DataGridDisplaySettings {
|
||||
showDataTableVerticalBorders: boolean;
|
||||
dataTableColumnWidthMode: DataTableColumnWidthMode;
|
||||
}
|
||||
|
||||
export const DEFAULT_DATA_GRID_DISPLAY_SETTINGS: DataGridDisplaySettings = {
|
||||
showDataTableVerticalBorders: false,
|
||||
dataTableColumnWidthMode: 'standard',
|
||||
};
|
||||
|
||||
export const DATA_GRID_COLUMN_WIDTH_MODE_OPTIONS = [
|
||||
{ label: '标准 200px', value: 'standard' as const },
|
||||
{ label: '紧凑 140px', value: 'compact' as const },
|
||||
];
|
||||
|
||||
const STANDARD_DATA_TABLE_COLUMN_WIDTH = 200;
|
||||
const COMPACT_DATA_TABLE_COLUMN_WIDTH = 140;
|
||||
|
||||
export const sanitizeDataTableColumnWidthMode = (value: unknown): DataTableColumnWidthMode => {
|
||||
return value === 'compact' ? 'compact' : 'standard';
|
||||
};
|
||||
|
||||
export const sanitizeDataGridDisplaySettings = (
|
||||
value: Partial<DataGridDisplaySettings> | undefined
|
||||
): DataGridDisplaySettings => {
|
||||
if (!value || typeof value !== 'object') {
|
||||
return { ...DEFAULT_DATA_GRID_DISPLAY_SETTINGS };
|
||||
}
|
||||
|
||||
return {
|
||||
showDataTableVerticalBorders: value.showDataTableVerticalBorders === true,
|
||||
dataTableColumnWidthMode: sanitizeDataTableColumnWidthMode(value.dataTableColumnWidthMode),
|
||||
};
|
||||
};
|
||||
|
||||
export const resolveDataTableDefaultColumnWidth = (
|
||||
widthMode: DataTableColumnWidthMode | null | undefined
|
||||
): number => {
|
||||
return sanitizeDataTableColumnWidthMode(widthMode) === 'compact'
|
||||
? COMPACT_DATA_TABLE_COLUMN_WIDTH
|
||||
: STANDARD_DATA_TABLE_COLUMN_WIDTH;
|
||||
};
|
||||
|
||||
export const resolveDataTableColumnWidth = ({
|
||||
manualWidth,
|
||||
widthMode,
|
||||
}: {
|
||||
manualWidth: number | null | undefined;
|
||||
widthMode: DataTableColumnWidthMode | null | undefined;
|
||||
}): number => {
|
||||
if (typeof manualWidth === 'number' && Number.isFinite(manualWidth) && manualWidth > 0) {
|
||||
return manualWidth;
|
||||
}
|
||||
|
||||
return resolveDataTableDefaultColumnWidth(widthMode);
|
||||
};
|
||||
|
||||
export const resolveDataTableVerticalBorderColor = ({
|
||||
darkMode,
|
||||
visible,
|
||||
}: {
|
||||
darkMode: boolean;
|
||||
visible: boolean;
|
||||
}): string => {
|
||||
if (!visible) {
|
||||
return 'transparent';
|
||||
}
|
||||
|
||||
return darkMode ? 'rgba(255, 255, 255, 0.08)' : 'rgba(15, 23, 42, 0.08)';
|
||||
};
|
||||
43
frontend/src/utils/dataGridSort.ts
Normal file
43
frontend/src/utils/dataGridSort.ts
Normal file
@@ -0,0 +1,43 @@
|
||||
export type GridSortInfoItem = {
|
||||
columnKey: string;
|
||||
order: string;
|
||||
enabled?: boolean;
|
||||
};
|
||||
|
||||
type TableSorterLike = {
|
||||
field?: unknown;
|
||||
columnKey?: unknown;
|
||||
order?: unknown;
|
||||
};
|
||||
|
||||
export const resolveGridSortInfoFromTableSorter = ({
|
||||
sorter,
|
||||
}: {
|
||||
sorter: TableSorterLike | TableSorterLike[] | null | undefined;
|
||||
}): GridSortInfoItem[] => {
|
||||
const sorters = Array.isArray(sorter)
|
||||
? sorter
|
||||
: ((sorter?.field || sorter?.columnKey) ? [sorter] : []);
|
||||
|
||||
if (sorters.length === 0) {
|
||||
return [];
|
||||
}
|
||||
|
||||
const next: GridSortInfoItem[] = [];
|
||||
const seen = new Set<string>();
|
||||
|
||||
for (const item of sorters) {
|
||||
const field = String(item?.field || item?.columnKey || '').trim();
|
||||
if (!field) continue;
|
||||
|
||||
const order = item?.order as string;
|
||||
const normalizedOrder = order === 'ascend' || order === 'descend' ? order : '';
|
||||
if (!normalizedOrder) continue;
|
||||
const dedupeKey = field.toLowerCase();
|
||||
if (seen.has(dedupeKey)) continue;
|
||||
seen.add(dedupeKey);
|
||||
next.push({ columnKey: field, order: normalizedOrder, enabled: true });
|
||||
}
|
||||
|
||||
return next;
|
||||
};
|
||||
35
frontend/src/utils/globalProxyDraft.test.ts
Normal file
35
frontend/src/utils/globalProxyDraft.test.ts
Normal file
@@ -0,0 +1,35 @@
|
||||
import { describe, expect, it } from 'vitest';
|
||||
|
||||
import { createGlobalProxyDraft, toPersistedGlobalProxy } from './globalProxyDraft';
|
||||
|
||||
describe('global proxy draft', () => {
|
||||
it('hydrates a secretless draft from backend metadata while keeping password input blank', () => {
|
||||
const draft = createGlobalProxyDraft({
|
||||
enabled: true,
|
||||
type: 'http',
|
||||
host: '127.0.0.1',
|
||||
port: 8080,
|
||||
user: 'ops',
|
||||
hasPassword: true,
|
||||
password: 'should-be-ignored',
|
||||
});
|
||||
|
||||
expect(draft.password).toBe('');
|
||||
expect(draft.hasPassword).toBe(true);
|
||||
});
|
||||
|
||||
it('drops password from persisted metadata but preserves hasPassword', () => {
|
||||
const persisted = toPersistedGlobalProxy({
|
||||
enabled: true,
|
||||
type: 'http',
|
||||
host: '127.0.0.1',
|
||||
port: 8080,
|
||||
user: 'ops',
|
||||
password: 'proxy-secret',
|
||||
hasPassword: true,
|
||||
});
|
||||
|
||||
expect('password' in persisted).toBe(false);
|
||||
expect(persisted.hasPassword).toBe(true);
|
||||
});
|
||||
});
|
||||
62
frontend/src/utils/globalProxyDraft.ts
Normal file
62
frontend/src/utils/globalProxyDraft.ts
Normal file
@@ -0,0 +1,62 @@
|
||||
import { GlobalProxyConfig } from '../types';
|
||||
|
||||
const toTrimmedString = (value: unknown): string => {
|
||||
if (typeof value === 'string') {
|
||||
return value.trim();
|
||||
}
|
||||
if (typeof value === 'number' || typeof value === 'boolean') {
|
||||
return String(value).trim();
|
||||
}
|
||||
return '';
|
||||
};
|
||||
|
||||
const normalizeProxyType = (value: unknown): 'socks5' | 'http' => {
|
||||
return toTrimmedString(value).toLowerCase() === 'http' ? 'http' : 'socks5';
|
||||
};
|
||||
|
||||
const normalizePort = (value: unknown, fallbackPort: number): number => {
|
||||
const parsed = Number(value);
|
||||
if (!Number.isFinite(parsed)) {
|
||||
return fallbackPort;
|
||||
}
|
||||
const port = Math.trunc(parsed);
|
||||
if (port <= 0 || port > 65535) {
|
||||
return fallbackPort;
|
||||
}
|
||||
return port;
|
||||
};
|
||||
|
||||
export function createGlobalProxyDraft(value: Partial<GlobalProxyConfig> = {}): GlobalProxyConfig {
|
||||
const type = normalizeProxyType(value.type);
|
||||
return {
|
||||
enabled: value.enabled === true,
|
||||
type,
|
||||
host: toTrimmedString(value.host),
|
||||
port: normalizePort(value.port, type === 'http' ? 8080 : 1080),
|
||||
user: toTrimmedString(value.user),
|
||||
password: '',
|
||||
hasPassword: value.hasPassword === true,
|
||||
secretRef: toTrimmedString(value.secretRef) || undefined,
|
||||
};
|
||||
}
|
||||
|
||||
export function toPersistedGlobalProxy(value: Partial<GlobalProxyConfig> = {}): Omit<GlobalProxyConfig, 'password'> {
|
||||
const draft = createGlobalProxyDraft(value);
|
||||
return {
|
||||
enabled: draft.enabled,
|
||||
type: draft.type,
|
||||
host: draft.host,
|
||||
port: draft.port,
|
||||
user: draft.user,
|
||||
hasPassword: draft.hasPassword,
|
||||
secretRef: draft.secretRef,
|
||||
};
|
||||
}
|
||||
|
||||
export function toSaveGlobalProxyInput(value: Partial<GlobalProxyConfig> = {}): GlobalProxyConfig {
|
||||
const draft = createGlobalProxyDraft(value);
|
||||
return {
|
||||
...draft,
|
||||
password: typeof value.password === 'string' ? value.password : '',
|
||||
};
|
||||
}
|
||||
75
frontend/src/utils/legacyConnectionStorage.test.ts
Normal file
75
frontend/src/utils/legacyConnectionStorage.test.ts
Normal file
@@ -0,0 +1,75 @@
|
||||
import { describe, expect, it } from 'vitest';
|
||||
|
||||
import { readLegacyPersistedSecrets, stripLegacyPersistedSecrets } from './legacyConnectionStorage';
|
||||
|
||||
describe('legacy connection storage', () => {
|
||||
it('extracts legacy saved connections and global proxy password from lite-db-storage', () => {
|
||||
const payload = JSON.stringify({
|
||||
state: {
|
||||
connections: [
|
||||
{
|
||||
id: 'conn-1',
|
||||
name: 'Primary',
|
||||
config: {
|
||||
id: 'conn-1',
|
||||
type: 'postgres',
|
||||
host: 'db.local',
|
||||
port: 5432,
|
||||
user: 'postgres',
|
||||
password: 'secret',
|
||||
},
|
||||
},
|
||||
],
|
||||
globalProxy: {
|
||||
enabled: true,
|
||||
type: 'http',
|
||||
host: '127.0.0.1',
|
||||
port: 8080,
|
||||
user: 'ops',
|
||||
password: 'proxy-secret',
|
||||
},
|
||||
},
|
||||
});
|
||||
|
||||
const result = readLegacyPersistedSecrets(payload);
|
||||
expect(result.connections).toHaveLength(1);
|
||||
expect(result.connections[0]?.config.password).toBe('secret');
|
||||
expect(result.globalProxy?.password).toBe('proxy-secret');
|
||||
});
|
||||
|
||||
it('strips persisted connection secrets but keeps secretless proxy metadata', () => {
|
||||
const payload = JSON.stringify({
|
||||
state: {
|
||||
connections: [
|
||||
{
|
||||
id: 'conn-1',
|
||||
name: 'Primary',
|
||||
config: {
|
||||
id: 'conn-1',
|
||||
type: 'postgres',
|
||||
host: 'db.local',
|
||||
port: 5432,
|
||||
user: 'postgres',
|
||||
password: 'secret',
|
||||
},
|
||||
},
|
||||
],
|
||||
globalProxy: {
|
||||
enabled: true,
|
||||
type: 'http',
|
||||
host: '127.0.0.1',
|
||||
port: 8080,
|
||||
user: 'ops',
|
||||
password: 'proxy-secret',
|
||||
},
|
||||
},
|
||||
});
|
||||
|
||||
const sanitized = stripLegacyPersistedSecrets(payload);
|
||||
const parsed = JSON.parse(sanitized);
|
||||
|
||||
expect(parsed.state.connections).toEqual([]);
|
||||
expect(parsed.state.globalProxy.password).toBeUndefined();
|
||||
expect(parsed.state.globalProxy.hasPassword).toBe(true);
|
||||
});
|
||||
});
|
||||
110
frontend/src/utils/legacyConnectionStorage.ts
Normal file
110
frontend/src/utils/legacyConnectionStorage.ts
Normal file
@@ -0,0 +1,110 @@
|
||||
import { GlobalProxyConfig, SavedConnection } from '../types';
|
||||
|
||||
export const LEGACY_PERSIST_KEY = 'lite-db-storage';
|
||||
|
||||
const toTrimmedString = (value: unknown): string => {
|
||||
if (typeof value === 'string') {
|
||||
return value.trim();
|
||||
}
|
||||
if (typeof value === 'number' || typeof value === 'boolean') {
|
||||
return String(value).trim();
|
||||
}
|
||||
return '';
|
||||
};
|
||||
|
||||
const normalizeProxyType = (value: unknown): 'socks5' | 'http' => {
|
||||
return toTrimmedString(value).toLowerCase() === 'http' ? 'http' : 'socks5';
|
||||
};
|
||||
|
||||
const normalizePort = (value: unknown, fallbackPort: number): number => {
|
||||
const parsed = Number(value);
|
||||
if (!Number.isFinite(parsed)) {
|
||||
return fallbackPort;
|
||||
}
|
||||
const port = Math.trunc(parsed);
|
||||
if (port <= 0 || port > 65535) {
|
||||
return fallbackPort;
|
||||
}
|
||||
return port;
|
||||
};
|
||||
|
||||
const parsePersistedEnvelope = (payload: string | null | undefined): Record<string, unknown> => {
|
||||
if (!payload || typeof payload !== 'string') {
|
||||
return {};
|
||||
}
|
||||
try {
|
||||
const parsed = JSON.parse(payload) as Record<string, unknown>;
|
||||
if (parsed.state && typeof parsed.state === 'object') {
|
||||
return parsed.state as Record<string, unknown>;
|
||||
}
|
||||
return parsed;
|
||||
} catch {
|
||||
return {};
|
||||
}
|
||||
};
|
||||
|
||||
export function readLegacyPersistedSecrets(payload: string | null | undefined): {
|
||||
connections: SavedConnection[];
|
||||
globalProxy: GlobalProxyConfig | null;
|
||||
} {
|
||||
const state = parsePersistedEnvelope(payload);
|
||||
const connections = Array.isArray(state.connections)
|
||||
? state.connections.filter((item): item is SavedConnection => !!item && typeof item === 'object')
|
||||
: [];
|
||||
|
||||
const proxyRaw = state.globalProxy && typeof state.globalProxy === 'object'
|
||||
? state.globalProxy as Record<string, unknown>
|
||||
: null;
|
||||
if (!proxyRaw) {
|
||||
return { connections, globalProxy: null };
|
||||
}
|
||||
|
||||
const type = normalizeProxyType(proxyRaw.type);
|
||||
const password = toTrimmedString(proxyRaw.password);
|
||||
const globalProxy: GlobalProxyConfig = {
|
||||
enabled: proxyRaw.enabled === true,
|
||||
type,
|
||||
host: toTrimmedString(proxyRaw.host),
|
||||
port: normalizePort(proxyRaw.port, type === 'http' ? 8080 : 1080),
|
||||
user: toTrimmedString(proxyRaw.user),
|
||||
password,
|
||||
hasPassword: proxyRaw.hasPassword === true || password !== '',
|
||||
secretRef: toTrimmedString(proxyRaw.secretRef) || undefined,
|
||||
};
|
||||
|
||||
const hasMeaningfulProxyState = globalProxy.enabled || globalProxy.host !== '' || globalProxy.user !== '' || globalProxy.password !== '' || globalProxy.hasPassword === true;
|
||||
return {
|
||||
connections,
|
||||
globalProxy: hasMeaningfulProxyState ? globalProxy : null,
|
||||
};
|
||||
}
|
||||
|
||||
export function stripLegacyPersistedSecrets(payload: string | null | undefined): string {
|
||||
if (!payload || typeof payload !== 'string') {
|
||||
return '';
|
||||
}
|
||||
|
||||
let parsed: Record<string, unknown>;
|
||||
try {
|
||||
parsed = JSON.parse(payload) as Record<string, unknown>;
|
||||
} catch {
|
||||
return payload;
|
||||
}
|
||||
|
||||
const state = parsed.state && typeof parsed.state === 'object'
|
||||
? parsed.state as Record<string, unknown>
|
||||
: parsed;
|
||||
state.connections = [];
|
||||
|
||||
if (state.globalProxy && typeof state.globalProxy === 'object') {
|
||||
const proxy = { ...(state.globalProxy as Record<string, unknown>) };
|
||||
const password = toTrimmedString(proxy.password);
|
||||
delete proxy.password;
|
||||
if (password !== '') {
|
||||
proxy.hasPassword = true;
|
||||
}
|
||||
state.globalProxy = proxy;
|
||||
}
|
||||
|
||||
return JSON.stringify(parsed);
|
||||
}
|
||||
41
frontend/src/utils/providerSecretDraft.test.ts
Normal file
41
frontend/src/utils/providerSecretDraft.test.ts
Normal file
@@ -0,0 +1,41 @@
|
||||
import { describe, expect, it } from 'vitest';
|
||||
|
||||
import { resolveProviderSecretDraft } from './providerSecretDraft';
|
||||
|
||||
describe('resolveProviderSecretDraft', () => {
|
||||
it('keeps existing provider secret when edit form leaves apiKey blank', () => {
|
||||
const result = resolveProviderSecretDraft({
|
||||
hasSecret: true,
|
||||
apiKeyInput: '',
|
||||
clearSecret: false,
|
||||
});
|
||||
|
||||
expect(result.mode).toBe('keep');
|
||||
expect(result.apiKey).toBe('');
|
||||
expect(result.hasSecret).toBe(true);
|
||||
});
|
||||
|
||||
it('replaces the provider secret when a new apiKey is entered', () => {
|
||||
const result = resolveProviderSecretDraft({
|
||||
hasSecret: true,
|
||||
apiKeyInput: ' sk-new ',
|
||||
clearSecret: false,
|
||||
});
|
||||
|
||||
expect(result.mode).toBe('replace');
|
||||
expect(result.apiKey).toBe('sk-new');
|
||||
expect(result.hasSecret).toBe(true);
|
||||
});
|
||||
|
||||
it('clears the stored provider secret when requested', () => {
|
||||
const result = resolveProviderSecretDraft({
|
||||
hasSecret: true,
|
||||
apiKeyInput: '',
|
||||
clearSecret: true,
|
||||
});
|
||||
|
||||
expect(result.mode).toBe('clear');
|
||||
expect(result.apiKey).toBe('');
|
||||
expect(result.hasSecret).toBe(false);
|
||||
});
|
||||
});
|
||||
47
frontend/src/utils/providerSecretDraft.ts
Normal file
47
frontend/src/utils/providerSecretDraft.ts
Normal file
@@ -0,0 +1,47 @@
|
||||
export type ProviderSecretDraftMode = 'keep' | 'replace' | 'clear';
|
||||
|
||||
export interface ProviderSecretDraftInput {
|
||||
hasSecret?: boolean;
|
||||
apiKeyInput?: string;
|
||||
clearSecret?: boolean;
|
||||
}
|
||||
|
||||
export interface ProviderSecretDraftResult {
|
||||
mode: ProviderSecretDraftMode;
|
||||
apiKey: string;
|
||||
hasSecret: boolean;
|
||||
}
|
||||
|
||||
export function resolveProviderSecretDraft(input: ProviderSecretDraftInput): ProviderSecretDraftResult {
|
||||
const apiKey = String(input.apiKeyInput || '').trim();
|
||||
|
||||
if (input.clearSecret) {
|
||||
return {
|
||||
mode: 'clear',
|
||||
apiKey: '',
|
||||
hasSecret: false,
|
||||
};
|
||||
}
|
||||
|
||||
if (apiKey) {
|
||||
return {
|
||||
mode: 'replace',
|
||||
apiKey,
|
||||
hasSecret: true,
|
||||
};
|
||||
}
|
||||
|
||||
if (input.hasSecret) {
|
||||
return {
|
||||
mode: 'keep',
|
||||
apiKey: '',
|
||||
hasSecret: true,
|
||||
};
|
||||
}
|
||||
|
||||
return {
|
||||
mode: 'clear',
|
||||
apiKey: '',
|
||||
hasSecret: false,
|
||||
};
|
||||
}
|
||||
@@ -10,10 +10,10 @@ describe('startup readiness helpers', () => {
|
||||
});
|
||||
});
|
||||
|
||||
it('keeps sidebar blocked until initial global proxy sync finishes', () => {
|
||||
it('keeps sidebar blocked until secure config bootstrap finishes', () => {
|
||||
expect(getConnectionWorkbenchState(true, false)).toEqual({
|
||||
ready: false,
|
||||
message: '正在同步全局代理配置...',
|
||||
message: '正在加载安全配置...',
|
||||
});
|
||||
});
|
||||
|
||||
@@ -24,3 +24,4 @@ describe('startup readiness helpers', () => {
|
||||
});
|
||||
});
|
||||
});
|
||||
|
||||
|
||||
@@ -16,7 +16,7 @@ export function getConnectionWorkbenchState(
|
||||
if (!hasAppliedInitialGlobalProxy) {
|
||||
return {
|
||||
ready: false,
|
||||
message: '正在同步全局代理配置...',
|
||||
message: '正在加载安全配置...',
|
||||
};
|
||||
}
|
||||
return {
|
||||
@@ -24,3 +24,4 @@ export function getConnectionWorkbenchState(
|
||||
message: '',
|
||||
};
|
||||
}
|
||||
|
||||
|
||||
16
frontend/wailsjs/go/app/App.d.ts
vendored
16
frontend/wailsjs/go/app/App.d.ts
vendored
@@ -52,6 +52,8 @@ export function DataSyncAnalyze(arg1:sync.SyncConfig):Promise<connection.QueryRe
|
||||
|
||||
export function DataSyncPreview(arg1:sync.SyncConfig,arg2:string,arg3:number):Promise<connection.QueryResult>;
|
||||
|
||||
export function DeleteConnection(arg1:string):Promise<void>;
|
||||
|
||||
export function DownloadDriverPackage(arg1:string,arg2:string,arg3:string,arg4:string):Promise<connection.QueryResult>;
|
||||
|
||||
export function DownloadUpdate():Promise<connection.QueryResult>;
|
||||
@@ -64,6 +66,8 @@ export function DropTable(arg1:connection.ConnectionConfig,arg2:string,arg3:stri
|
||||
|
||||
export function DropView(arg1:connection.ConnectionConfig,arg2:string,arg3:string):Promise<connection.QueryResult>;
|
||||
|
||||
export function DuplicateConnection(arg1:string):Promise<connection.SavedConnectionView>;
|
||||
|
||||
export function ExecuteSQLFile(arg1:connection.ConnectionConfig,arg2:string,arg3:string,arg4:string):Promise<connection.QueryResult>;
|
||||
|
||||
export function ExportData(arg1:Array<Record<string, any>>,arg2:Array<string>,arg3:string,arg4:string):Promise<connection.QueryResult>;
|
||||
@@ -90,13 +94,19 @@ export function GetDriverVersionPackageSize(arg1:string,arg2:string):Promise<con
|
||||
|
||||
export function GetGlobalProxyConfig():Promise<connection.QueryResult>;
|
||||
|
||||
export function GetSavedConnections():Promise<Array<connection.SavedConnectionView>>;
|
||||
|
||||
export function ImportConfigFile():Promise<connection.QueryResult>;
|
||||
|
||||
export function ImportData(arg1:connection.ConnectionConfig,arg2:string,arg3:string):Promise<connection.QueryResult>;
|
||||
|
||||
export function ImportDataWithProgress(arg1:connection.ConnectionConfig,arg2:string,arg3:string,arg4:string):Promise<connection.QueryResult>;
|
||||
|
||||
export function InstallLocalDriverPackage(arg1:string,arg2:string,arg3:string):Promise<connection.QueryResult>;
|
||||
export function ImportLegacyConnections(arg1:Array<connection.SavedConnectionInput>):Promise<Array<connection.SavedConnectionView>>;
|
||||
|
||||
export function ImportLegacyGlobalProxy(arg1:connection.SaveGlobalProxyInput):Promise<connection.GlobalProxyView>;
|
||||
|
||||
export function InstallLocalDriverPackage(arg1:string,arg2:string,arg3:string,arg4:string):Promise<connection.QueryResult>;
|
||||
|
||||
export function InstallUpdateAndRestart():Promise<connection.QueryResult>;
|
||||
|
||||
@@ -180,6 +190,10 @@ export function ResolveDriverPackageDownloadURL(arg1:string,arg2:string):Promise
|
||||
|
||||
export function ResolveDriverRepositoryURL(arg1:string):Promise<connection.QueryResult>;
|
||||
|
||||
export function SaveConnection(arg1:connection.SavedConnectionInput):Promise<connection.SavedConnectionView>;
|
||||
|
||||
export function SaveGlobalProxy(arg1:connection.SaveGlobalProxyInput):Promise<connection.GlobalProxyView>;
|
||||
|
||||
export function SelectDatabaseFile(arg1:string,arg2:string):Promise<connection.QueryResult>;
|
||||
|
||||
export function SelectDriverDownloadDirectory(arg1:string):Promise<connection.QueryResult>;
|
||||
|
||||
@@ -98,6 +98,10 @@ export function DataSyncPreview(arg1, arg2, arg3) {
|
||||
return window['go']['app']['App']['DataSyncPreview'](arg1, arg2, arg3);
|
||||
}
|
||||
|
||||
export function DeleteConnection(arg1) {
|
||||
return window['go']['app']['App']['DeleteConnection'](arg1);
|
||||
}
|
||||
|
||||
export function DownloadDriverPackage(arg1, arg2, arg3, arg4) {
|
||||
return window['go']['app']['App']['DownloadDriverPackage'](arg1, arg2, arg3, arg4);
|
||||
}
|
||||
@@ -122,6 +126,10 @@ export function DropView(arg1, arg2, arg3) {
|
||||
return window['go']['app']['App']['DropView'](arg1, arg2, arg3);
|
||||
}
|
||||
|
||||
export function DuplicateConnection(arg1) {
|
||||
return window['go']['app']['App']['DuplicateConnection'](arg1);
|
||||
}
|
||||
|
||||
export function ExecuteSQLFile(arg1, arg2, arg3, arg4) {
|
||||
return window['go']['app']['App']['ExecuteSQLFile'](arg1, arg2, arg3, arg4);
|
||||
}
|
||||
@@ -174,6 +182,10 @@ export function GetGlobalProxyConfig() {
|
||||
return window['go']['app']['App']['GetGlobalProxyConfig']();
|
||||
}
|
||||
|
||||
export function GetSavedConnections() {
|
||||
return window['go']['app']['App']['GetSavedConnections']();
|
||||
}
|
||||
|
||||
export function ImportConfigFile() {
|
||||
return window['go']['app']['App']['ImportConfigFile']();
|
||||
}
|
||||
@@ -186,8 +198,16 @@ export function ImportDataWithProgress(arg1, arg2, arg3, arg4) {
|
||||
return window['go']['app']['App']['ImportDataWithProgress'](arg1, arg2, arg3, arg4);
|
||||
}
|
||||
|
||||
export function InstallLocalDriverPackage(arg1, arg2, arg3) {
|
||||
return window['go']['app']['App']['InstallLocalDriverPackage'](arg1, arg2, arg3);
|
||||
export function ImportLegacyConnections(arg1) {
|
||||
return window['go']['app']['App']['ImportLegacyConnections'](arg1);
|
||||
}
|
||||
|
||||
export function ImportLegacyGlobalProxy(arg1) {
|
||||
return window['go']['app']['App']['ImportLegacyGlobalProxy'](arg1);
|
||||
}
|
||||
|
||||
export function InstallLocalDriverPackage(arg1, arg2, arg3, arg4) {
|
||||
return window['go']['app']['App']['InstallLocalDriverPackage'](arg1, arg2, arg3, arg4);
|
||||
}
|
||||
|
||||
export function InstallUpdateAndRestart() {
|
||||
@@ -354,6 +374,14 @@ export function ResolveDriverRepositoryURL(arg1) {
|
||||
return window['go']['app']['App']['ResolveDriverRepositoryURL'](arg1);
|
||||
}
|
||||
|
||||
export function SaveConnection(arg1) {
|
||||
return window['go']['app']['App']['SaveConnection'](arg1);
|
||||
}
|
||||
|
||||
export function SaveGlobalProxy(arg1) {
|
||||
return window['go']['app']['App']['SaveGlobalProxy'](arg1);
|
||||
}
|
||||
|
||||
export function SelectDatabaseFile(arg1, arg2) {
|
||||
return window['go']['app']['App']['SelectDatabaseFile'](arg1, arg2);
|
||||
}
|
||||
|
||||
@@ -78,6 +78,8 @@ export namespace ai {
|
||||
type: string;
|
||||
name: string;
|
||||
apiKey: string;
|
||||
secretRef?: string;
|
||||
hasSecret?: boolean;
|
||||
baseUrl: string;
|
||||
model: string;
|
||||
models?: string[];
|
||||
@@ -96,6 +98,8 @@ export namespace ai {
|
||||
this.type = source["type"];
|
||||
this.name = source["name"];
|
||||
this.apiKey = source["apiKey"];
|
||||
this.secretRef = source["secretRef"];
|
||||
this.hasSecret = source["hasSecret"];
|
||||
this.baseUrl = source["baseUrl"];
|
||||
this.model = source["model"];
|
||||
this.models = source["models"];
|
||||
@@ -284,6 +288,7 @@ export namespace connection {
|
||||
}
|
||||
}
|
||||
export class ConnectionConfig {
|
||||
id?: string;
|
||||
type: string;
|
||||
host: string;
|
||||
port: number;
|
||||
@@ -324,6 +329,7 @@ export namespace connection {
|
||||
|
||||
constructor(source: any = {}) {
|
||||
if ('string' === typeof source) source = JSON.parse(source);
|
||||
this.id = source["id"];
|
||||
this.type = source["type"];
|
||||
this.host = source["host"];
|
||||
this.port = source["port"];
|
||||
@@ -377,6 +383,32 @@ export namespace connection {
|
||||
return a;
|
||||
}
|
||||
}
|
||||
export class GlobalProxyView {
|
||||
enabled: boolean;
|
||||
type: string;
|
||||
host: string;
|
||||
port: number;
|
||||
user?: string;
|
||||
password?: string;
|
||||
hasPassword?: boolean;
|
||||
secretRef?: string;
|
||||
|
||||
static createFrom(source: any = {}) {
|
||||
return new GlobalProxyView(source);
|
||||
}
|
||||
|
||||
constructor(source: any = {}) {
|
||||
if ('string' === typeof source) source = JSON.parse(source);
|
||||
this.enabled = source["enabled"];
|
||||
this.type = source["type"];
|
||||
this.host = source["host"];
|
||||
this.port = source["port"];
|
||||
this.user = source["user"];
|
||||
this.password = source["password"];
|
||||
this.hasPassword = source["hasPassword"];
|
||||
this.secretRef = source["secretRef"];
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
export class QueryResult {
|
||||
@@ -400,6 +432,146 @@ export namespace connection {
|
||||
}
|
||||
}
|
||||
|
||||
export class SaveGlobalProxyInput {
|
||||
enabled: boolean;
|
||||
type: string;
|
||||
host: string;
|
||||
port: number;
|
||||
user?: string;
|
||||
password?: string;
|
||||
|
||||
static createFrom(source: any = {}) {
|
||||
return new SaveGlobalProxyInput(source);
|
||||
}
|
||||
|
||||
constructor(source: any = {}) {
|
||||
if ('string' === typeof source) source = JSON.parse(source);
|
||||
this.enabled = source["enabled"];
|
||||
this.type = source["type"];
|
||||
this.host = source["host"];
|
||||
this.port = source["port"];
|
||||
this.user = source["user"];
|
||||
this.password = source["password"];
|
||||
}
|
||||
}
|
||||
export class SavedConnectionInput {
|
||||
id?: string;
|
||||
name: string;
|
||||
config: ConnectionConfig;
|
||||
includeDatabases?: string[];
|
||||
includeRedisDatabases?: number[];
|
||||
iconType?: string;
|
||||
iconColor?: string;
|
||||
clearPrimaryPassword?: boolean;
|
||||
clearSSHPassword?: boolean;
|
||||
clearProxyPassword?: boolean;
|
||||
clearHttpTunnelPassword?: boolean;
|
||||
clearMySQLReplicaPassword?: boolean;
|
||||
clearMongoReplicaPassword?: boolean;
|
||||
clearOpaqueURI?: boolean;
|
||||
clearOpaqueDSN?: boolean;
|
||||
|
||||
static createFrom(source: any = {}) {
|
||||
return new SavedConnectionInput(source);
|
||||
}
|
||||
|
||||
constructor(source: any = {}) {
|
||||
if ('string' === typeof source) source = JSON.parse(source);
|
||||
this.id = source["id"];
|
||||
this.name = source["name"];
|
||||
this.config = this.convertValues(source["config"], ConnectionConfig);
|
||||
this.includeDatabases = source["includeDatabases"];
|
||||
this.includeRedisDatabases = source["includeRedisDatabases"];
|
||||
this.iconType = source["iconType"];
|
||||
this.iconColor = source["iconColor"];
|
||||
this.clearPrimaryPassword = source["clearPrimaryPassword"];
|
||||
this.clearSSHPassword = source["clearSSHPassword"];
|
||||
this.clearProxyPassword = source["clearProxyPassword"];
|
||||
this.clearHttpTunnelPassword = source["clearHttpTunnelPassword"];
|
||||
this.clearMySQLReplicaPassword = source["clearMySQLReplicaPassword"];
|
||||
this.clearMongoReplicaPassword = source["clearMongoReplicaPassword"];
|
||||
this.clearOpaqueURI = source["clearOpaqueURI"];
|
||||
this.clearOpaqueDSN = source["clearOpaqueDSN"];
|
||||
}
|
||||
|
||||
convertValues(a: any, classs: any, asMap: boolean = false): any {
|
||||
if (!a) {
|
||||
return a;
|
||||
}
|
||||
if (a.slice && a.map) {
|
||||
return (a as any[]).map(elem => this.convertValues(elem, classs));
|
||||
} else if ("object" === typeof a) {
|
||||
if (asMap) {
|
||||
for (const key of Object.keys(a)) {
|
||||
a[key] = new classs(a[key]);
|
||||
}
|
||||
return a;
|
||||
}
|
||||
return new classs(a);
|
||||
}
|
||||
return a;
|
||||
}
|
||||
}
|
||||
export class SavedConnectionView {
|
||||
id: string;
|
||||
name: string;
|
||||
config: ConnectionConfig;
|
||||
includeDatabases?: string[];
|
||||
includeRedisDatabases?: number[];
|
||||
iconType?: string;
|
||||
iconColor?: string;
|
||||
secretRef?: string;
|
||||
hasPrimaryPassword?: boolean;
|
||||
hasSSHPassword?: boolean;
|
||||
hasProxyPassword?: boolean;
|
||||
hasHttpTunnelPassword?: boolean;
|
||||
hasMySQLReplicaPassword?: boolean;
|
||||
hasMongoReplicaPassword?: boolean;
|
||||
hasOpaqueURI?: boolean;
|
||||
hasOpaqueDSN?: boolean;
|
||||
|
||||
static createFrom(source: any = {}) {
|
||||
return new SavedConnectionView(source);
|
||||
}
|
||||
|
||||
constructor(source: any = {}) {
|
||||
if ('string' === typeof source) source = JSON.parse(source);
|
||||
this.id = source["id"];
|
||||
this.name = source["name"];
|
||||
this.config = this.convertValues(source["config"], ConnectionConfig);
|
||||
this.includeDatabases = source["includeDatabases"];
|
||||
this.includeRedisDatabases = source["includeRedisDatabases"];
|
||||
this.iconType = source["iconType"];
|
||||
this.iconColor = source["iconColor"];
|
||||
this.secretRef = source["secretRef"];
|
||||
this.hasPrimaryPassword = source["hasPrimaryPassword"];
|
||||
this.hasSSHPassword = source["hasSSHPassword"];
|
||||
this.hasProxyPassword = source["hasProxyPassword"];
|
||||
this.hasHttpTunnelPassword = source["hasHttpTunnelPassword"];
|
||||
this.hasMySQLReplicaPassword = source["hasMySQLReplicaPassword"];
|
||||
this.hasMongoReplicaPassword = source["hasMongoReplicaPassword"];
|
||||
this.hasOpaqueURI = source["hasOpaqueURI"];
|
||||
this.hasOpaqueDSN = source["hasOpaqueDSN"];
|
||||
}
|
||||
|
||||
convertValues(a: any, classs: any, asMap: boolean = false): any {
|
||||
if (!a) {
|
||||
return a;
|
||||
}
|
||||
if (a.slice && a.map) {
|
||||
return (a as any[]).map(elem => this.convertValues(elem, classs));
|
||||
} else if ("object" === typeof a) {
|
||||
if (asMap) {
|
||||
for (const key of Object.keys(a)) {
|
||||
a[key] = new classs(a[key]);
|
||||
}
|
||||
return a;
|
||||
}
|
||||
return new classs(a);
|
||||
}
|
||||
return a;
|
||||
}
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
|
||||
8
go.mod
8
go.mod
@@ -28,11 +28,14 @@ require (
|
||||
|
||||
require (
|
||||
filippo.io/edwards25519 v1.1.0 // indirect
|
||||
github.com/99designs/go-keychain v0.0.0-20191008050251-8e49817e8af4 // indirect
|
||||
github.com/99designs/keyring v1.2.2
|
||||
github.com/ClickHouse/ch-go v0.71.0 // indirect
|
||||
github.com/andybalholm/brotli v1.2.0 // indirect
|
||||
github.com/apache/arrow-go/v18 v18.5.1 // indirect
|
||||
github.com/bep/debounce v1.2.1 // indirect
|
||||
github.com/cespare/xxhash/v2 v2.3.0 // indirect
|
||||
github.com/danieljoos/wincred v1.1.2 // indirect
|
||||
github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f // indirect
|
||||
github.com/duckdb/duckdb-go-bindings v0.3.3 // indirect
|
||||
github.com/duckdb/duckdb-go-bindings/lib/darwin-amd64 v0.3.3 // indirect
|
||||
@@ -41,17 +44,20 @@ require (
|
||||
github.com/duckdb/duckdb-go-bindings/lib/linux-arm64 v0.3.3 // indirect
|
||||
github.com/duckdb/duckdb-go-bindings/lib/windows-amd64 v0.3.3 // indirect
|
||||
github.com/dustin/go-humanize v1.0.1 // indirect
|
||||
github.com/dvsekhvalnov/jose2go v1.5.0 // indirect
|
||||
github.com/go-faster/city v1.0.1 // indirect
|
||||
github.com/go-faster/errors v0.7.1 // indirect
|
||||
github.com/go-ole/go-ole v1.3.0 // indirect
|
||||
github.com/go-viper/mapstructure/v2 v2.5.0 // indirect
|
||||
github.com/goccy/go-json v0.10.5 // indirect
|
||||
github.com/godbus/dbus v0.0.0-20190726142602-4481cbc300e2 // indirect
|
||||
github.com/godbus/dbus/v5 v5.1.0 // indirect
|
||||
github.com/golang-sql/civil v0.0.0-20220223132316-b832511892a9 // indirect
|
||||
github.com/golang-sql/sqlexp v0.1.0 // indirect
|
||||
github.com/golang/snappy v1.0.0 // indirect
|
||||
github.com/google/flatbuffers v25.12.19+incompatible // indirect
|
||||
github.com/gorilla/websocket v1.5.3 // indirect
|
||||
github.com/gsterjov/go-libsecret v0.0.0-20161001094733-a6f4afe4910c // indirect
|
||||
github.com/hashicorp/go-version v1.8.0 // indirect
|
||||
github.com/jchv/go-winloader v0.0.0-20210711035445-715c2860da7e // indirect
|
||||
github.com/json-iterator/go v1.1.12 // indirect
|
||||
@@ -68,6 +74,7 @@ require (
|
||||
github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd // indirect
|
||||
github.com/modern-go/reflect2 v1.0.2 // indirect
|
||||
github.com/montanaflynn/stats v0.7.1 // indirect
|
||||
github.com/mtibben/percent v0.2.1 // indirect
|
||||
github.com/ncruces/go-strftime v1.0.0 // indirect
|
||||
github.com/paulmach/orb v0.12.0 // indirect
|
||||
github.com/pierrec/lz4/v4 v4.1.25 // indirect
|
||||
@@ -100,6 +107,7 @@ require (
|
||||
golang.org/x/sync v0.19.0 // indirect
|
||||
golang.org/x/sys v0.40.0 // indirect
|
||||
golang.org/x/telemetry v0.0.0-20260116145544-c6413dc483f5 // indirect
|
||||
golang.org/x/term v0.39.0 // indirect
|
||||
golang.org/x/tools v0.41.0 // indirect
|
||||
golang.org/x/xerrors v0.0.0-20240903120638-7835f813f4da // indirect
|
||||
modernc.org/libc v1.67.6 // indirect
|
||||
|
||||
17
go.sum
17
go.sum
@@ -4,6 +4,10 @@ gitea.com/kingbase/gokb v0.0.0-20201021123113-29bd62a876c3 h1:QjslQNaH5Nuap5i4ni
|
||||
gitea.com/kingbase/gokb v0.0.0-20201021123113-29bd62a876c3/go.mod h1:7lH5A1jzCXD9Nl16DzaBUOfDAT8NPrDmZwKu1p5wf94=
|
||||
gitee.com/chunanyong/dm v1.8.22 h1:H7fsrnUIvEA0jlDWew7vwELry1ff+tLMIu2Fk2cIBSg=
|
||||
gitee.com/chunanyong/dm v1.8.22/go.mod h1:EPRJnuPFgbyOFgJ0TRYCTGzhq+ZT4wdyaj/GW/LLcNg=
|
||||
github.com/99designs/go-keychain v0.0.0-20191008050251-8e49817e8af4 h1:/vQbFIOMbk2FiG/kXiLl8BRyzTWDw7gX/Hz7Dd5eDMs=
|
||||
github.com/99designs/go-keychain v0.0.0-20191008050251-8e49817e8af4/go.mod h1:hN7oaIRCjzsZ2dE+yG5k+rsdt3qcwykqK6HVGcKwsw4=
|
||||
github.com/99designs/keyring v1.2.2 h1:pZd3neh/EmUzWONb35LxQfvuY7kiSXAq3HQd97+XBn0=
|
||||
github.com/99designs/keyring v1.2.2/go.mod h1:wes/FrByc8j7lFOAGLGSNEg8f/PaI3cgTBqhFkHUrPk=
|
||||
github.com/Azure/azure-sdk-for-go/sdk/azcore v1.18.0 h1:Gt0j3wceWMwPmiazCa8MzMA0MfhmPIz0Qp0FJ6qcM0U=
|
||||
github.com/Azure/azure-sdk-for-go/sdk/azcore v1.18.0/go.mod h1:Ot/6aikWnKWi4l9QB7qVSwa8iMphQNqkWALMoNT3rzM=
|
||||
github.com/Azure/azure-sdk-for-go/sdk/azidentity v1.10.1 h1:B+blDbyVIG3WaikNxPnhPiJ1MThR03b3vKGtER95TP4=
|
||||
@@ -34,6 +38,8 @@ github.com/bsm/gomega v1.27.10 h1:yeMWxP2pV2fG3FgAODIY8EiRE3dy0aeFYt4l7wh6yKA=
|
||||
github.com/bsm/gomega v1.27.10/go.mod h1:JyEr/xRbxbtgWNi8tIEVPUYZ5Dzef52k01W3YH0H+O0=
|
||||
github.com/cespare/xxhash/v2 v2.3.0 h1:UL815xU9SqsFlibzuggzjXhog7bL6oX9BbNZnL2UFvs=
|
||||
github.com/cespare/xxhash/v2 v2.3.0/go.mod h1:VGX0DQ3Q6kWi7AoAeZDth3/j3BFtOZR5XLFGgcrjCOs=
|
||||
github.com/danieljoos/wincred v1.1.2 h1:QLdCxFs1/Yl4zduvBdcHB8goaYk9RARS2SgLLRuAyr0=
|
||||
github.com/danieljoos/wincred v1.1.2/go.mod h1:GijpziifJoIBfYh+S7BbkdUTU4LfM+QnGqR5Vl2tAx0=
|
||||
github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
|
||||
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
|
||||
github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc h1:U9qPSI2PIWSS1VwoXQT9A3Wy9MM3WgvqSxFWenqJduM=
|
||||
@@ -56,6 +62,8 @@ github.com/duckdb/duckdb-go/v2 v2.5.5 h1:TlK8ipnzoKW2aNrjGqRkFWLCDpJDxR/VwH8ezEc
|
||||
github.com/duckdb/duckdb-go/v2 v2.5.5/go.mod h1:6uIbC3gz36NCEygECzboygOo/Z9TeVwox/puG+ohWV0=
|
||||
github.com/dustin/go-humanize v1.0.1 h1:GzkhY7T5VNhEkwH0PVJgjz+fX1rhBrR7pRT3mDkpeCY=
|
||||
github.com/dustin/go-humanize v1.0.1/go.mod h1:Mu1zIs6XwVuF/gI1OepvI0qD18qycQx+mFykh5fBlto=
|
||||
github.com/dvsekhvalnov/jose2go v1.5.0 h1:3j8ya4Z4kMCwT5nXIKFSV84YS+HdqSSO0VsTQxaLAeM=
|
||||
github.com/dvsekhvalnov/jose2go v1.5.0/go.mod h1:QsHjhyTlD/lAVqn/NSbVZmSCGeDehTB/mPZadG+mhXU=
|
||||
github.com/go-faster/city v1.0.1 h1:4WAxSZ3V2Ws4QRDrscLEDcibJY8uf41H6AhXDrNDcGw=
|
||||
github.com/go-faster/city v1.0.1/go.mod h1:jKcUJId49qdW3L1qKHH/3wPeUstCVpVSXTM6vO3VcTw=
|
||||
github.com/go-faster/errors v0.7.1 h1:MkJTnDoEdi9pDabt1dpWf7AA8/BaSYZqibYyhZ20AYg=
|
||||
@@ -68,6 +76,8 @@ github.com/go-viper/mapstructure/v2 v2.5.0 h1:vM5IJoUAy3d7zRSVtIwQgBj7BiWtMPfmPE
|
||||
github.com/go-viper/mapstructure/v2 v2.5.0/go.mod h1:oJDH3BJKyqBA2TXFhDsKDGDTlndYOZ6rGS0BRZIxGhM=
|
||||
github.com/goccy/go-json v0.10.5 h1:Fq85nIqj+gXn/S5ahsiTlK3TmC85qgirsdTP/+DeaC4=
|
||||
github.com/goccy/go-json v0.10.5/go.mod h1:oq7eo15ShAhp70Anwd5lgX2pLfOS3QCiwU/PULtXL6M=
|
||||
github.com/godbus/dbus v0.0.0-20190726142602-4481cbc300e2 h1:ZpnhV/YsD2/4cESfV5+Hoeu/iUR3ruzNvZ+yQfO03a0=
|
||||
github.com/godbus/dbus v0.0.0-20190726142602-4481cbc300e2/go.mod h1:bBOAhwG1umN6/6ZUMtDFBMQR8jRg9O75tm9K00oMsK4=
|
||||
github.com/godbus/dbus/v5 v5.1.0 h1:4KLkAxT3aOY8Li4FRJe/KvhoNFFxo0m6fNuFUO8QJUk=
|
||||
github.com/godbus/dbus/v5 v5.1.0/go.mod h1:xhWf0FNVPg57R7Z0UbKHbJfkEywrmjJnf7w5xrFpKfA=
|
||||
github.com/gogo/protobuf v1.3.2/go.mod h1:P1XiOD3dCwIKUDQYPy72D8LYyHL2YPYrpS2s69NZV8Q=
|
||||
@@ -95,6 +105,8 @@ github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+
|
||||
github.com/gorilla/websocket v1.5.0/go.mod h1:YR8l580nyteQvAITg2hZ9XVh4b55+EU/adAjf1fMHhE=
|
||||
github.com/gorilla/websocket v1.5.3 h1:saDtZ6Pbx/0u+bgYQ3q96pZgCzfhKXGPqt7kZ72aNNg=
|
||||
github.com/gorilla/websocket v1.5.3/go.mod h1:YR8l580nyteQvAITg2hZ9XVh4b55+EU/adAjf1fMHhE=
|
||||
github.com/gsterjov/go-libsecret v0.0.0-20161001094733-a6f4afe4910c h1:6rhixN/i8ZofjG1Y75iExal34USq5p+wiN1tpie8IrU=
|
||||
github.com/gsterjov/go-libsecret v0.0.0-20161001094733-a6f4afe4910c/go.mod h1:NMPJylDgVpX0MLRlPy15sqSwOFv/U1GZ2m21JhFfek0=
|
||||
github.com/hashicorp/go-version v1.7.0/go.mod h1:fltr4n8CU8Ke44wwGCBoEymUuxUHl09ZGVZPK5anwXA=
|
||||
github.com/hashicorp/go-version v1.8.0 h1:KAkNb1HAiZd1ukkxDFGmokVZe1Xy9HG6NUp+bPle2i4=
|
||||
github.com/hashicorp/go-version v1.8.0/go.mod h1:fltr4n8CU8Ke44wwGCBoEymUuxUHl09ZGVZPK5anwXA=
|
||||
@@ -158,6 +170,8 @@ github.com/modern-go/reflect2 v1.0.2/go.mod h1:yWuevngMOJpCy52FWWMvUC8ws7m/LJsjY
|
||||
github.com/montanaflynn/stats v0.0.0-20171201202039-1bf9dbcd8cbe/go.mod h1:wL8QJuTMNUDYhXwkmfOly8iTdp5TEcJFWZD2D7SIkUc=
|
||||
github.com/montanaflynn/stats v0.7.1 h1:etflOAAHORrCC44V+aR6Ftzort912ZU+YLiSTuV8eaE=
|
||||
github.com/montanaflynn/stats v0.7.1/go.mod h1:etXPPgVO6n31NxCd9KQUMvCM+ve0ruNzt6R8Bnaayow=
|
||||
github.com/mtibben/percent v0.2.1 h1:5gssi8Nqo8QU/r2pynCm+hBQHpkB/uNK7BJCFogWdzs=
|
||||
github.com/mtibben/percent v0.2.1/go.mod h1:KG9uO+SZkUp+VkRHsCdYQV3XSZrrSpR3O9ibNBTZrns=
|
||||
github.com/ncruces/go-strftime v1.0.0 h1:HMFp8mLCTPp341M/ZnA4qaf7ZlsbTc+miZjCLOFAw7w=
|
||||
github.com/ncruces/go-strftime v1.0.0/go.mod h1:Fwc5htZGVVkseilnfgOVb9mKy6w1naJmn9CehxcKcls=
|
||||
github.com/niemeyer/pretty v0.0.0-20200227124842-a10e7caefd8e h1:fD57ERR4JtEqsWbfPhv4DMiApHyliiK5xCTNVSPiaAs=
|
||||
@@ -201,6 +215,7 @@ github.com/stretchr/objx v0.5.2 h1:xuMeJ0Sdp5ZMRXx/aWO6RZxdr3beISkG5/G/aIRr3pY=
|
||||
github.com/stretchr/objx v0.5.2/go.mod h1:FRsXN1f5AsAjCGJKqEizvkpNtU+EGNCLh3NxZ/8L+MA=
|
||||
github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI=
|
||||
github.com/stretchr/testify v1.6.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg=
|
||||
github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg=
|
||||
github.com/stretchr/testify v1.7.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg=
|
||||
github.com/stretchr/testify v1.8.0/go.mod h1:yNjHg4UonilssWZ8iaSj1OCr/vHnekPRkoO+kdMU+MU=
|
||||
github.com/stretchr/testify v1.8.2/go.mod h1:w2LPCIKwWwSfY2zedu0+kehJoqGctiVI29o6fzry7u4=
|
||||
@@ -300,6 +315,7 @@ golang.org/x/sys v0.0.0-20200930185726-fdedc70b468f/go.mod h1:h1NjWce9XRLGQEsW7w
|
||||
golang.org/x/sys v0.0.0-20201119102817-f84b799fce68/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
|
||||
golang.org/x/sys v0.0.0-20210423082822-04245dca01da/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
|
||||
golang.org/x/sys v0.0.0-20210615035016-665e8c7367d1/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
golang.org/x/sys v0.0.0-20210819135213-f52c844e1c1c/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
golang.org/x/sys v0.0.0-20220520151302-bc2c85ada10a/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
golang.org/x/sys v0.0.0-20220722155257-8c9f86f7a55f/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
golang.org/x/sys v0.0.0-20220811171246-fbc7d0a398ab/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
@@ -342,6 +358,7 @@ gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8
|
||||
gopkg.in/check.v1 v1.0.0-20180628173108-788fd7840127/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
|
||||
gopkg.in/check.v1 v1.0.0-20200227125254-8fa46927fb4f h1:BLraFXnmrev5lT+xlilqcH8XK9/i0At2xKjWk4p6zsU=
|
||||
gopkg.in/check.v1 v1.0.0-20200227125254-8fa46927fb4f/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
|
||||
gopkg.in/check.v1 v1.0.0-20200902074654-038fdea0a05b/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
|
||||
gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
|
||||
gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA=
|
||||
gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
|
||||
|
||||
@@ -3,8 +3,11 @@ package provider
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"os"
|
||||
"os/exec"
|
||||
"path/filepath"
|
||||
"runtime"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
@@ -324,6 +327,26 @@ func TestClaudeCLIProvider_ChatStreamReportsApiRetryAuthenticationFailure(t *tes
|
||||
func writeFakeClaudeScript(t *testing.T, content string) string {
|
||||
t.Helper()
|
||||
dir := t.TempDir()
|
||||
|
||||
if runtime.GOOS == "windows" {
|
||||
bashPath, err := resolveClaudeCodeGitBashPath(os.Environ(), runtime.GOOS, exec.LookPath, fileExists)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to resolve git bash for fake claude command: %v", err)
|
||||
}
|
||||
|
||||
scriptPath := filepath.Join(dir, "claude.sh")
|
||||
if err := os.WriteFile(scriptPath, []byte(content), 0o755); err != nil {
|
||||
t.Fatalf("failed to write fake claude shell script: %v", err)
|
||||
}
|
||||
|
||||
wrapperPath := filepath.Join(dir, "claude.cmd")
|
||||
wrapper := fmt.Sprintf("@echo off\r\n\"%s\" \"%s\" %%*\r\n", bashPath, scriptPath)
|
||||
if err := os.WriteFile(wrapperPath, []byte(wrapper), 0o755); err != nil {
|
||||
t.Fatalf("failed to write fake claude wrapper: %v", err)
|
||||
}
|
||||
return wrapperPath
|
||||
}
|
||||
|
||||
path := filepath.Join(dir, "claude")
|
||||
if err := os.WriteFile(path, []byte(content), 0o755); err != nil {
|
||||
t.Fatalf("failed to write fake claude script: %v", err)
|
||||
@@ -335,12 +358,19 @@ func overrideClaudeCLIForTest(t *testing.T, fakeClaudePath string) func() {
|
||||
t.Helper()
|
||||
|
||||
originalLookPath := claudeLookPath
|
||||
originalCommandContext := claudeCommandContext
|
||||
claudeLookPath = func(name string) (string, error) {
|
||||
if name == "claude" {
|
||||
return fakeClaudePath, nil
|
||||
}
|
||||
return originalLookPath(name)
|
||||
}
|
||||
claudeCommandContext = func(ctx context.Context, name string, args ...string) *exec.Cmd {
|
||||
if name == "claude" {
|
||||
return exec.CommandContext(ctx, fakeClaudePath, args...)
|
||||
}
|
||||
return originalCommandContext(ctx, name, args...)
|
||||
}
|
||||
|
||||
originalPath := os.Getenv("PATH")
|
||||
if err := os.Setenv("PATH", filepath.Dir(fakeClaudePath)+string(os.PathListSeparator)+originalPath); err != nil {
|
||||
@@ -349,6 +379,7 @@ func overrideClaudeCLIForTest(t *testing.T, fakeClaudePath string) func() {
|
||||
|
||||
return func() {
|
||||
claudeLookPath = originalLookPath
|
||||
claudeCommandContext = originalCommandContext
|
||||
_ = os.Setenv("PATH", originalPath)
|
||||
}
|
||||
}
|
||||
|
||||
231
internal/ai/service/provider_secret.go
Normal file
231
internal/ai/service/provider_secret.go
Normal file
@@ -0,0 +1,231 @@
|
||||
package aiservice
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"strings"
|
||||
"unicode"
|
||||
|
||||
"GoNavi-Wails/internal/ai"
|
||||
"GoNavi-Wails/internal/secretstore"
|
||||
)
|
||||
|
||||
const providerSecretKind = "ai-provider"
|
||||
|
||||
type providerSecretBundle struct {
|
||||
APIKey string `json:"apiKey,omitempty"`
|
||||
SensitiveHeaders map[string]string `json:"sensitiveHeaders,omitempty"`
|
||||
}
|
||||
|
||||
func (b providerSecretBundle) hasAny() bool {
|
||||
return strings.TrimSpace(b.APIKey) != "" || len(b.SensitiveHeaders) > 0
|
||||
}
|
||||
|
||||
func mergeProviderSecretBundles(base, overlay providerSecretBundle) providerSecretBundle {
|
||||
merged := providerSecretBundle{
|
||||
APIKey: base.APIKey,
|
||||
SensitiveHeaders: cloneStringMap(base.SensitiveHeaders),
|
||||
}
|
||||
if strings.TrimSpace(overlay.APIKey) != "" {
|
||||
merged.APIKey = overlay.APIKey
|
||||
}
|
||||
for key, value := range overlay.SensitiveHeaders {
|
||||
if merged.SensitiveHeaders == nil {
|
||||
merged.SensitiveHeaders = make(map[string]string, len(overlay.SensitiveHeaders))
|
||||
}
|
||||
merged.SensitiveHeaders[key] = value
|
||||
}
|
||||
if len(merged.SensitiveHeaders) == 0 {
|
||||
merged.SensitiveHeaders = nil
|
||||
}
|
||||
return merged
|
||||
}
|
||||
|
||||
func splitProviderSecrets(cfg ai.ProviderConfig) (ai.ProviderConfig, providerSecretBundle) {
|
||||
meta := cfg
|
||||
meta.APIKey = ""
|
||||
|
||||
bundle := providerSecretBundle{}
|
||||
if apiKey := strings.TrimSpace(cfg.APIKey); apiKey != "" {
|
||||
bundle.APIKey = apiKey
|
||||
}
|
||||
|
||||
if len(cfg.Headers) > 0 {
|
||||
safeHeaders := make(map[string]string, len(cfg.Headers))
|
||||
sensitiveHeaders := make(map[string]string)
|
||||
for key, value := range cfg.Headers {
|
||||
if isSensitiveProviderHeader(key) {
|
||||
if strings.TrimSpace(value) != "" {
|
||||
sensitiveHeaders[key] = value
|
||||
}
|
||||
continue
|
||||
}
|
||||
safeHeaders[key] = value
|
||||
}
|
||||
if len(safeHeaders) > 0 {
|
||||
meta.Headers = safeHeaders
|
||||
} else {
|
||||
meta.Headers = nil
|
||||
}
|
||||
if len(sensitiveHeaders) > 0 {
|
||||
bundle.SensitiveHeaders = sensitiveHeaders
|
||||
}
|
||||
} else {
|
||||
meta.Headers = nil
|
||||
}
|
||||
|
||||
meta.HasSecret = cfg.HasSecret || bundle.hasAny()
|
||||
meta.SecretRef = strings.TrimSpace(cfg.SecretRef)
|
||||
if meta.HasSecret && meta.SecretRef == "" && strings.TrimSpace(cfg.ID) != "" {
|
||||
if ref, err := secretstore.BuildRef(providerSecretKind, cfg.ID); err == nil {
|
||||
meta.SecretRef = ref
|
||||
}
|
||||
}
|
||||
if !meta.HasSecret {
|
||||
meta.SecretRef = ""
|
||||
}
|
||||
|
||||
return meta, bundle
|
||||
}
|
||||
|
||||
func mergeProviderSecrets(cfg ai.ProviderConfig, bundle providerSecretBundle) ai.ProviderConfig {
|
||||
merged := cfg
|
||||
merged.APIKey = bundle.APIKey
|
||||
|
||||
headers := cloneStringMap(cfg.Headers)
|
||||
if len(bundle.SensitiveHeaders) > 0 {
|
||||
if headers == nil {
|
||||
headers = make(map[string]string, len(bundle.SensitiveHeaders))
|
||||
}
|
||||
for key, value := range bundle.SensitiveHeaders {
|
||||
headers[key] = value
|
||||
}
|
||||
}
|
||||
if len(headers) > 0 {
|
||||
merged.Headers = headers
|
||||
} else {
|
||||
merged.Headers = nil
|
||||
}
|
||||
|
||||
merged.HasSecret = cfg.HasSecret || bundle.hasAny()
|
||||
if merged.HasSecret && strings.TrimSpace(merged.SecretRef) == "" && strings.TrimSpace(merged.ID) != "" {
|
||||
if ref, err := secretstore.BuildRef(providerSecretKind, merged.ID); err == nil {
|
||||
merged.SecretRef = ref
|
||||
}
|
||||
}
|
||||
if !merged.HasSecret {
|
||||
merged.SecretRef = ""
|
||||
}
|
||||
|
||||
return merged
|
||||
}
|
||||
|
||||
func (s *Service) persistProviderSecretBundle(meta ai.ProviderConfig, bundle providerSecretBundle) (ai.ProviderConfig, error) {
|
||||
meta, _ = splitProviderSecrets(meta)
|
||||
if !bundle.hasAny() {
|
||||
meta.HasSecret = false
|
||||
meta.SecretRef = ""
|
||||
return meta, nil
|
||||
}
|
||||
if s.secretStore == nil {
|
||||
return meta, fmt.Errorf("secret store unavailable")
|
||||
}
|
||||
if err := s.secretStore.HealthCheck(); err != nil {
|
||||
return meta, err
|
||||
}
|
||||
|
||||
ref := strings.TrimSpace(meta.SecretRef)
|
||||
if ref == "" {
|
||||
var err error
|
||||
ref, err = secretstore.BuildRef(providerSecretKind, meta.ID)
|
||||
if err != nil {
|
||||
return meta, err
|
||||
}
|
||||
}
|
||||
|
||||
payload, err := json.Marshal(bundle)
|
||||
if err != nil {
|
||||
return meta, fmt.Errorf("序列化 provider secret bundle 失败: %w", err)
|
||||
}
|
||||
if err := s.secretStore.Put(ref, payload); err != nil {
|
||||
return meta, err
|
||||
}
|
||||
|
||||
meta.SecretRef = ref
|
||||
meta.HasSecret = true
|
||||
return meta, nil
|
||||
}
|
||||
|
||||
func (s *Service) resolveProviderConfigSecrets(cfg ai.ProviderConfig) (ai.ProviderConfig, error) {
|
||||
cfg = normalizeProviderConfig(cfg)
|
||||
meta, bundle := splitProviderSecrets(cfg)
|
||||
if bundle.hasAny() {
|
||||
return mergeProviderSecrets(meta, bundle), nil
|
||||
}
|
||||
if !meta.HasSecret {
|
||||
return meta, nil
|
||||
}
|
||||
if s.secretStore == nil {
|
||||
return meta, fmt.Errorf("secret store unavailable")
|
||||
}
|
||||
|
||||
ref := strings.TrimSpace(meta.SecretRef)
|
||||
if ref == "" {
|
||||
var err error
|
||||
ref, err = secretstore.BuildRef(providerSecretKind, meta.ID)
|
||||
if err != nil {
|
||||
return meta, err
|
||||
}
|
||||
meta.SecretRef = ref
|
||||
}
|
||||
|
||||
payload, err := s.secretStore.Get(ref)
|
||||
if err != nil {
|
||||
return meta, err
|
||||
}
|
||||
|
||||
var stored providerSecretBundle
|
||||
if err := json.Unmarshal(payload, &stored); err != nil {
|
||||
return meta, fmt.Errorf("解析 provider secret bundle 失败: %w", err)
|
||||
}
|
||||
return mergeProviderSecrets(meta, stored), nil
|
||||
}
|
||||
|
||||
func providerMetadataView(cfg ai.ProviderConfig) ai.ProviderConfig {
|
||||
meta, _ := splitProviderSecrets(normalizeProviderConfig(cfg))
|
||||
return meta
|
||||
}
|
||||
|
||||
func isSensitiveProviderHeader(name string) bool {
|
||||
normalized := strings.TrimSpace(strings.ToLower(name))
|
||||
switch normalized {
|
||||
case "authorization", "proxy-authorization", "x-api-key", "api-key":
|
||||
return true
|
||||
}
|
||||
|
||||
for _, token := range providerHeaderTokens(normalized) {
|
||||
switch token {
|
||||
case "auth", "authorization", "token", "secret", "key", "apikey":
|
||||
return true
|
||||
}
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
|
||||
func providerHeaderTokens(name string) []string {
|
||||
return strings.FieldsFunc(name, func(r rune) bool {
|
||||
return !unicode.IsLetter(r) && !unicode.IsDigit(r)
|
||||
})
|
||||
}
|
||||
|
||||
func cloneStringMap(input map[string]string) map[string]string {
|
||||
if len(input) == 0 {
|
||||
return nil
|
||||
}
|
||||
cloned := make(map[string]string, len(input))
|
||||
for key, value := range input {
|
||||
cloned[key] = value
|
||||
}
|
||||
return cloned
|
||||
}
|
||||
348
internal/ai/service/provider_secret_test.go
Normal file
348
internal/ai/service/provider_secret_test.go
Normal file
@@ -0,0 +1,348 @@
|
||||
package aiservice
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"GoNavi-Wails/internal/ai"
|
||||
"GoNavi-Wails/internal/secretstore"
|
||||
)
|
||||
|
||||
func TestSplitProviderSecretsStripsAPIKeyAndSensitiveHeaders(t *testing.T) {
|
||||
input := ai.ProviderConfig{
|
||||
ID: "openai-main",
|
||||
APIKey: "sk-test",
|
||||
BaseURL: "https://api.openai.com/v1",
|
||||
Headers: map[string]string{
|
||||
"Authorization": "Bearer test",
|
||||
"X-Team": "db",
|
||||
},
|
||||
}
|
||||
|
||||
meta, bundle := splitProviderSecrets(input)
|
||||
if meta.APIKey != "" {
|
||||
t.Fatal("apiKey should not stay in metadata")
|
||||
}
|
||||
if meta.Headers["Authorization"] != "" {
|
||||
t.Fatal("sensitive header should not stay in metadata")
|
||||
}
|
||||
if meta.Headers["X-Team"] != "db" {
|
||||
t.Fatal("non-sensitive header should stay in metadata")
|
||||
}
|
||||
if bundle.APIKey != "sk-test" {
|
||||
t.Fatal("bundle should keep apiKey")
|
||||
}
|
||||
if bundle.SensitiveHeaders["Authorization"] != "Bearer test" {
|
||||
t.Fatal("bundle should keep sensitive header")
|
||||
}
|
||||
}
|
||||
|
||||
func TestResolveProviderConfigSecretsRestoresStoredSecretBundle(t *testing.T) {
|
||||
store := newFakeProviderSecretStore()
|
||||
service := NewServiceWithSecretStore(store)
|
||||
ref, err := secretstore.BuildRef("ai-provider", "openai-main")
|
||||
if err != nil {
|
||||
t.Fatalf("BuildRef returned error: %v", err)
|
||||
}
|
||||
payload, err := json.Marshal(providerSecretBundle{
|
||||
APIKey: "sk-test",
|
||||
SensitiveHeaders: map[string]string{
|
||||
"Authorization": "Bearer test",
|
||||
},
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("Marshal returned error: %v", err)
|
||||
}
|
||||
if err := store.Put(ref, payload); err != nil {
|
||||
t.Fatalf("Put returned error: %v", err)
|
||||
}
|
||||
|
||||
resolved, err := service.resolveProviderConfigSecrets(ai.ProviderConfig{
|
||||
ID: "openai-main",
|
||||
SecretRef: ref,
|
||||
HasSecret: true,
|
||||
Headers: map[string]string{
|
||||
"X-Team": "db",
|
||||
},
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("resolveProviderConfigSecrets returned error: %v", err)
|
||||
}
|
||||
if resolved.APIKey != "sk-test" {
|
||||
t.Fatalf("expected restored apiKey, got %q", resolved.APIKey)
|
||||
}
|
||||
if resolved.Headers["Authorization"] != "Bearer test" {
|
||||
t.Fatalf("expected restored sensitive header, got %#v", resolved.Headers)
|
||||
}
|
||||
if resolved.Headers["X-Team"] != "db" {
|
||||
t.Fatalf("expected non-sensitive header to survive, got %#v", resolved.Headers)
|
||||
}
|
||||
}
|
||||
|
||||
func TestLoadConfigMigratesPlaintextProviderSecrets(t *testing.T) {
|
||||
store := newFakeProviderSecretStore()
|
||||
service := NewServiceWithSecretStore(store)
|
||||
service.configDir = t.TempDir()
|
||||
|
||||
legacy := aiConfig{
|
||||
Providers: []ai.ProviderConfig{
|
||||
{
|
||||
ID: "openai-main",
|
||||
Type: "openai",
|
||||
Name: "OpenAI",
|
||||
APIKey: "sk-test",
|
||||
BaseURL: "https://api.openai.com/v1",
|
||||
Headers: map[string]string{
|
||||
"Authorization": "Bearer test",
|
||||
"X-Team": "db",
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
data, err := json.MarshalIndent(legacy, "", " ")
|
||||
if err != nil {
|
||||
t.Fatalf("MarshalIndent returned error: %v", err)
|
||||
}
|
||||
configPath := filepath.Join(service.configDir, "ai_config.json")
|
||||
if err := os.WriteFile(configPath, data, 0o644); err != nil {
|
||||
t.Fatalf("WriteFile returned error: %v", err)
|
||||
}
|
||||
|
||||
service.loadConfig()
|
||||
|
||||
providers := service.AIGetProviders()
|
||||
if len(providers) != 1 {
|
||||
t.Fatalf("expected 1 provider, got %d", len(providers))
|
||||
}
|
||||
if providers[0].APIKey != "" {
|
||||
t.Fatalf("expected migrated provider to be secretless, got %q", providers[0].APIKey)
|
||||
}
|
||||
if !providers[0].HasSecret {
|
||||
t.Fatal("expected migrated provider to report HasSecret=true")
|
||||
}
|
||||
stored, err := store.Get(providers[0].SecretRef)
|
||||
if err != nil {
|
||||
t.Fatalf("expected secret bundle in store, got error: %v", err)
|
||||
}
|
||||
var bundle providerSecretBundle
|
||||
if err := json.Unmarshal(stored, &bundle); err != nil {
|
||||
t.Fatalf("Unmarshal returned error: %v", err)
|
||||
}
|
||||
if bundle.APIKey != "sk-test" {
|
||||
t.Fatalf("expected migrated apiKey in store, got %q", bundle.APIKey)
|
||||
}
|
||||
if bundle.SensitiveHeaders["Authorization"] != "Bearer test" {
|
||||
t.Fatalf("expected migrated sensitive header in store, got %#v", bundle.SensitiveHeaders)
|
||||
}
|
||||
|
||||
rewritten, err := os.ReadFile(configPath)
|
||||
if err != nil {
|
||||
t.Fatalf("ReadFile returned error: %v", err)
|
||||
}
|
||||
text := string(rewritten)
|
||||
if strings.Contains(text, "sk-test") {
|
||||
t.Fatalf("expected rewritten config to remove api key, got %s", text)
|
||||
}
|
||||
if strings.Contains(text, "Bearer test") {
|
||||
t.Fatalf("expected rewritten config to remove sensitive header, got %s", text)
|
||||
}
|
||||
}
|
||||
|
||||
func TestAISaveProviderPersistsSecretlessConfigAndReturnsSecretlessView(t *testing.T) {
|
||||
store := newFakeProviderSecretStore()
|
||||
service := NewServiceWithSecretStore(store)
|
||||
service.configDir = t.TempDir()
|
||||
|
||||
err := service.AISaveProvider(ai.ProviderConfig{
|
||||
ID: "openai-main",
|
||||
Type: "openai",
|
||||
Name: "OpenAI",
|
||||
APIKey: "sk-test",
|
||||
BaseURL: "https://api.openai.com/v1",
|
||||
Headers: map[string]string{
|
||||
"Authorization": "Bearer test",
|
||||
"X-Team": "db",
|
||||
},
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("AISaveProvider returned error: %v", err)
|
||||
}
|
||||
|
||||
providers := service.AIGetProviders()
|
||||
if len(providers) != 1 {
|
||||
t.Fatalf("expected 1 provider, got %d", len(providers))
|
||||
}
|
||||
if providers[0].APIKey != "" {
|
||||
t.Fatalf("expected secretless provider view, got %q", providers[0].APIKey)
|
||||
}
|
||||
if !providers[0].HasSecret {
|
||||
t.Fatal("expected saved provider view to report HasSecret=true")
|
||||
}
|
||||
if providers[0].Headers["Authorization"] != "" {
|
||||
t.Fatalf("expected secretless provider headers, got %#v", providers[0].Headers)
|
||||
}
|
||||
if service.providers[0].APIKey != "sk-test" {
|
||||
t.Fatalf("expected runtime provider to keep apiKey, got %q", service.providers[0].APIKey)
|
||||
}
|
||||
if service.providers[0].Headers["Authorization"] != "Bearer test" {
|
||||
t.Fatalf("expected runtime provider to keep sensitive header, got %#v", service.providers[0].Headers)
|
||||
}
|
||||
|
||||
configPath := filepath.Join(service.configDir, "ai_config.json")
|
||||
data, err := os.ReadFile(configPath)
|
||||
if err != nil {
|
||||
t.Fatalf("ReadFile returned error: %v", err)
|
||||
}
|
||||
text := string(data)
|
||||
if strings.Contains(text, "sk-test") {
|
||||
t.Fatalf("expected config file to be secretless, got %s", text)
|
||||
}
|
||||
if strings.Contains(text, "Bearer test") {
|
||||
t.Fatalf("expected config file to remove sensitive headers, got %s", text)
|
||||
}
|
||||
}
|
||||
|
||||
func TestAISaveProviderKeepsExistingSecretWhenInputOmitsAPIKey(t *testing.T) {
|
||||
store := newFakeProviderSecretStore()
|
||||
service := NewServiceWithSecretStore(store)
|
||||
service.configDir = t.TempDir()
|
||||
|
||||
if err := service.AISaveProvider(ai.ProviderConfig{
|
||||
ID: "openai-main",
|
||||
Type: "openai",
|
||||
Name: "OpenAI",
|
||||
APIKey: "sk-original",
|
||||
BaseURL: "https://api.openai.com/v1",
|
||||
Headers: map[string]string{
|
||||
"Authorization": "Bearer original",
|
||||
"X-Team": "db",
|
||||
},
|
||||
}); err != nil {
|
||||
t.Fatalf("initial AISaveProvider returned error: %v", err)
|
||||
}
|
||||
|
||||
if err := service.AISaveProvider(ai.ProviderConfig{
|
||||
ID: "openai-main",
|
||||
Type: "openai",
|
||||
Name: "OpenAI Updated",
|
||||
BaseURL: "https://gateway.openai.com/v1",
|
||||
HasSecret: true,
|
||||
Headers: map[string]string{
|
||||
"X-Team": "platform",
|
||||
},
|
||||
}); err != nil {
|
||||
t.Fatalf("update AISaveProvider returned error: %v", err)
|
||||
}
|
||||
|
||||
if service.providers[0].APIKey != "sk-original" {
|
||||
t.Fatalf("expected runtime provider to keep original apiKey, got %q", service.providers[0].APIKey)
|
||||
}
|
||||
if service.providers[0].Headers["Authorization"] != "Bearer original" {
|
||||
t.Fatalf("expected runtime provider to keep original sensitive header, got %#v", service.providers[0].Headers)
|
||||
}
|
||||
if service.providers[0].Headers["X-Team"] != "platform" {
|
||||
t.Fatalf("expected runtime provider to update non-sensitive headers, got %#v", service.providers[0].Headers)
|
||||
}
|
||||
if service.providers[0].BaseURL != "https://gateway.openai.com/v1" {
|
||||
t.Fatalf("expected runtime provider to update metadata, got %q", service.providers[0].BaseURL)
|
||||
}
|
||||
|
||||
providers := service.AIGetProviders()
|
||||
if len(providers) != 1 || !providers[0].HasSecret {
|
||||
t.Fatalf("expected provider view to keep HasSecret=true, got %#v", providers)
|
||||
}
|
||||
if providers[0].APIKey != "" {
|
||||
t.Fatalf("expected provider view to stay secretless, got %q", providers[0].APIKey)
|
||||
}
|
||||
}
|
||||
|
||||
func TestAISaveProviderMergesStoredSensitiveHeadersWhenUpdatingOnlyAPIKey(t *testing.T) {
|
||||
store := newFakeProviderSecretStore()
|
||||
service := NewServiceWithSecretStore(store)
|
||||
service.configDir = t.TempDir()
|
||||
|
||||
if err := service.AISaveProvider(ai.ProviderConfig{
|
||||
ID: "openai-main",
|
||||
Type: "openai",
|
||||
Name: "OpenAI",
|
||||
APIKey: "sk-original",
|
||||
BaseURL: "https://api.openai.com/v1",
|
||||
Headers: map[string]string{
|
||||
"Authorization": "Bearer original",
|
||||
"X-Team": "db",
|
||||
},
|
||||
}); err != nil {
|
||||
t.Fatalf("initial AISaveProvider returned error: %v", err)
|
||||
}
|
||||
|
||||
if err := service.AISaveProvider(ai.ProviderConfig{
|
||||
ID: "openai-main",
|
||||
Type: "openai",
|
||||
Name: "OpenAI",
|
||||
APIKey: "sk-updated",
|
||||
HasSecret: true,
|
||||
BaseURL: "https://api.openai.com/v1",
|
||||
Headers: map[string]string{
|
||||
"X-Team": "db",
|
||||
},
|
||||
}); err != nil {
|
||||
t.Fatalf("update AISaveProvider returned error: %v", err)
|
||||
}
|
||||
|
||||
if service.providers[0].APIKey != "sk-updated" {
|
||||
t.Fatalf("expected updated apiKey, got %q", service.providers[0].APIKey)
|
||||
}
|
||||
if service.providers[0].Headers["Authorization"] != "Bearer original" {
|
||||
t.Fatalf("expected existing sensitive header to be kept, got %#v", service.providers[0].Headers)
|
||||
}
|
||||
|
||||
stored, err := store.Get(service.providers[0].SecretRef)
|
||||
if err != nil {
|
||||
t.Fatalf("expected merged secret bundle in store, got %v", err)
|
||||
}
|
||||
var bundle providerSecretBundle
|
||||
if err := json.Unmarshal(stored, &bundle); err != nil {
|
||||
t.Fatalf("Unmarshal returned error: %v", err)
|
||||
}
|
||||
if bundle.APIKey != "sk-updated" {
|
||||
t.Fatalf("expected store to keep updated apiKey, got %q", bundle.APIKey)
|
||||
}
|
||||
if bundle.SensitiveHeaders["Authorization"] != "Bearer original" {
|
||||
t.Fatalf("expected store to keep existing sensitive header, got %#v", bundle.SensitiveHeaders)
|
||||
}
|
||||
}
|
||||
|
||||
type fakeProviderSecretStore struct {
|
||||
items map[string][]byte
|
||||
}
|
||||
|
||||
func newFakeProviderSecretStore() *fakeProviderSecretStore {
|
||||
return &fakeProviderSecretStore{items: make(map[string][]byte)}
|
||||
}
|
||||
|
||||
func (s *fakeProviderSecretStore) Put(ref string, payload []byte) error {
|
||||
s.items[ref] = append([]byte(nil), payload...)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *fakeProviderSecretStore) Get(ref string) ([]byte, error) {
|
||||
payload, ok := s.items[ref]
|
||||
if !ok {
|
||||
return nil, os.ErrNotExist
|
||||
}
|
||||
return append([]byte(nil), payload...), nil
|
||||
}
|
||||
|
||||
func (s *fakeProviderSecretStore) Delete(ref string) error {
|
||||
delete(s.items, ref)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *fakeProviderSecretStore) HealthCheck() error {
|
||||
return nil
|
||||
}
|
||||
|
||||
var _ secretstore.SecretStore = (*fakeProviderSecretStore)(nil)
|
||||
@@ -18,6 +18,7 @@ import (
|
||||
"GoNavi-Wails/internal/ai/provider"
|
||||
"GoNavi-Wails/internal/ai/safety"
|
||||
"GoNavi-Wails/internal/logger"
|
||||
"GoNavi-Wails/internal/secretstore"
|
||||
|
||||
"github.com/google/uuid"
|
||||
wailsRuntime "github.com/wailsapp/wails/v2/pkg/runtime"
|
||||
@@ -32,7 +33,8 @@ type Service struct {
|
||||
safetyLevel ai.SQLPermissionLevel
|
||||
contextLevel ai.ContextLevel
|
||||
guard *safety.Guard
|
||||
configDir string // 配置存储目录
|
||||
configDir string // 配置存储目录
|
||||
secretStore secretstore.SecretStore
|
||||
cancelFuncs map[string]context.CancelFunc // 记录每个 session 的 context 取消函数
|
||||
}
|
||||
|
||||
@@ -97,11 +99,19 @@ var claudeCLIHealthCheckFunc = func(config ai.ProviderConfig) error {
|
||||
|
||||
// NewService 创建 AI Service 实例
|
||||
func NewService() *Service {
|
||||
return NewServiceWithSecretStore(secretstore.NewKeyringStore())
|
||||
}
|
||||
|
||||
func NewServiceWithSecretStore(store secretstore.SecretStore) *Service {
|
||||
if store == nil {
|
||||
store = secretstore.NewUnavailableStore("secret store unavailable")
|
||||
}
|
||||
return &Service{
|
||||
providers: make([]ai.ProviderConfig, 0),
|
||||
safetyLevel: ai.PermissionReadOnly,
|
||||
contextLevel: ai.ContextSchemaOnly,
|
||||
guard: safety.NewGuard(ai.PermissionReadOnly),
|
||||
secretStore: store,
|
||||
cancelFuncs: make(map[string]context.CancelFunc),
|
||||
}
|
||||
}
|
||||
@@ -127,35 +137,80 @@ func (s *Service) AIGetProviders() []ai.ProviderConfig {
|
||||
defer s.mu.RUnlock()
|
||||
|
||||
result := make([]ai.ProviderConfig, len(s.providers))
|
||||
copy(result, s.providers)
|
||||
for i := range result {
|
||||
result[i] = normalizeProviderConfig(result[i])
|
||||
for i := range s.providers {
|
||||
result[i] = providerMetadataView(s.providers[i])
|
||||
}
|
||||
return result
|
||||
}
|
||||
|
||||
// AISaveProvider 保存/更新 Provider 配置
|
||||
func (s *Service) AISaveProvider(config ai.ProviderConfig) error {
|
||||
fmt.Printf("[AISaveProvider DEBUG] ID: %s, Model: %s\n", config.ID, config.Model)
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
|
||||
config = normalizeProviderConfig(config)
|
||||
|
||||
if strings.TrimSpace(config.ID) == "" {
|
||||
config.ID = "provider-" + uuid.New().String()[:8]
|
||||
}
|
||||
|
||||
var existing ai.ProviderConfig
|
||||
found := false
|
||||
for i, p := range s.providers {
|
||||
if p.ID == config.ID {
|
||||
s.providers[i] = config
|
||||
for _, providerConfig := range s.providers {
|
||||
if providerConfig.ID == config.ID {
|
||||
existing = providerConfig
|
||||
found = true
|
||||
break
|
||||
}
|
||||
}
|
||||
if !found {
|
||||
s.providers = append(s.providers, config)
|
||||
|
||||
meta, bundle := splitProviderSecrets(config)
|
||||
var runtimeConfig ai.ProviderConfig
|
||||
switch {
|
||||
case bundle.hasAny():
|
||||
mergedBundle := bundle
|
||||
if found && existing.HasSecret {
|
||||
_, existingBundle := splitProviderSecrets(existing)
|
||||
mergedBundle = mergeProviderSecretBundles(existingBundle, bundle)
|
||||
}
|
||||
if found && strings.TrimSpace(meta.SecretRef) == "" {
|
||||
meta.SecretRef = existing.SecretRef
|
||||
}
|
||||
storedMeta, err := s.persistProviderSecretBundle(meta, mergedBundle)
|
||||
if err != nil {
|
||||
return fmt.Errorf("保存 Provider secret 失败: %w", err)
|
||||
}
|
||||
runtimeConfig = mergeProviderSecrets(storedMeta, mergedBundle)
|
||||
case found && (config.HasSecret || existing.HasSecret):
|
||||
meta.SecretRef = existing.SecretRef
|
||||
meta.HasSecret = config.HasSecret || existing.HasSecret
|
||||
resolved, err := s.resolveProviderConfigSecrets(meta)
|
||||
if err != nil {
|
||||
return fmt.Errorf("读取已保存 Provider secret 失败: %w", err)
|
||||
}
|
||||
runtimeConfig = resolved
|
||||
default:
|
||||
runtimeConfig = meta
|
||||
}
|
||||
|
||||
if !runtimeConfig.HasSecret && found && strings.TrimSpace(existing.SecretRef) != "" {
|
||||
if err := s.secretStore.Delete(existing.SecretRef); err != nil {
|
||||
return fmt.Errorf("删除 Provider secret 失败: %w", err)
|
||||
}
|
||||
}
|
||||
if !runtimeConfig.HasSecret {
|
||||
runtimeConfig.SecretRef = ""
|
||||
}
|
||||
|
||||
runtimeConfig = normalizeProviderConfig(runtimeConfig)
|
||||
if found {
|
||||
for i := range s.providers {
|
||||
if s.providers[i].ID == runtimeConfig.ID {
|
||||
s.providers[i] = runtimeConfig
|
||||
break
|
||||
}
|
||||
}
|
||||
} else {
|
||||
s.providers = append(s.providers, runtimeConfig)
|
||||
}
|
||||
|
||||
return s.saveConfig()
|
||||
@@ -167,9 +222,19 @@ func (s *Service) AIDeleteProvider(id string) error {
|
||||
defer s.mu.Unlock()
|
||||
|
||||
newProviders := make([]ai.ProviderConfig, 0, len(s.providers))
|
||||
for _, p := range s.providers {
|
||||
if p.ID != id {
|
||||
newProviders = append(newProviders, p)
|
||||
var removed ai.ProviderConfig
|
||||
removedFound := false
|
||||
for _, providerConfig := range s.providers {
|
||||
if providerConfig.ID == id {
|
||||
removed = providerConfig
|
||||
removedFound = true
|
||||
continue
|
||||
}
|
||||
newProviders = append(newProviders, providerConfig)
|
||||
}
|
||||
if removedFound && strings.TrimSpace(removed.SecretRef) != "" {
|
||||
if err := s.secretStore.Delete(removed.SecretRef); err != nil {
|
||||
return fmt.Errorf("删除 Provider secret 失败: %w", err)
|
||||
}
|
||||
}
|
||||
s.providers = newProviders
|
||||
@@ -186,17 +251,29 @@ func (s *Service) AIDeleteProvider(id string) error {
|
||||
|
||||
// AITestProvider 测试 Provider 配置是否可用,仅测试端点连通性与密钥,不实际调用对话
|
||||
func (s *Service) AITestProvider(config ai.ProviderConfig) map[string]interface{} {
|
||||
// 如果传入脱敏的 key,使用已保存的 key
|
||||
s.mu.RLock()
|
||||
if isMaskedAPIKey(config.APIKey) {
|
||||
for _, p := range s.providers {
|
||||
if p.ID == config.ID {
|
||||
config.APIKey = p.APIKey
|
||||
break
|
||||
config.APIKey = ""
|
||||
config.HasSecret = true
|
||||
}
|
||||
if strings.TrimSpace(config.APIKey) == "" && (config.HasSecret || strings.TrimSpace(config.SecretRef) != "") {
|
||||
s.mu.RLock()
|
||||
if strings.TrimSpace(config.SecretRef) == "" {
|
||||
for _, providerConfig := range s.providers {
|
||||
if providerConfig.ID == config.ID {
|
||||
config.SecretRef = providerConfig.SecretRef
|
||||
config.HasSecret = config.HasSecret || providerConfig.HasSecret
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
s.mu.RUnlock()
|
||||
|
||||
resolved, err := s.resolveProviderConfigSecrets(config)
|
||||
if err != nil {
|
||||
return map[string]interface{}{"success": false, "message": fmt.Sprintf("连接测试失败: %s", err.Error())}
|
||||
}
|
||||
config = resolved
|
||||
}
|
||||
s.mu.RUnlock()
|
||||
|
||||
config = normalizeProviderConfig(config)
|
||||
baseURL := strings.TrimRight(strings.TrimSpace(config.BaseURL), "/")
|
||||
@@ -842,13 +919,35 @@ func (s *Service) getActiveProvider() (provider.Provider, error) {
|
||||
|
||||
// --- 配置持久化 ---
|
||||
|
||||
const aiConfigSchemaVersion = 2
|
||||
|
||||
type aiConfig struct {
|
||||
SchemaVersion int `json:"schemaVersion,omitempty"`
|
||||
Providers []ai.ProviderConfig `json:"providers"`
|
||||
ActiveProvider string `json:"activeProvider"`
|
||||
SafetyLevel string `json:"safetyLevel"`
|
||||
ContextLevel string `json:"contextLevel"`
|
||||
}
|
||||
|
||||
func (s *Service) loadRuntimeProviderConfig(config ai.ProviderConfig) (ai.ProviderConfig, bool, error) {
|
||||
meta, bundle := splitProviderSecrets(config)
|
||||
if bundle.hasAny() {
|
||||
storedMeta, err := s.persistProviderSecretBundle(meta, bundle)
|
||||
if err != nil {
|
||||
meta.HasSecret = false
|
||||
meta.SecretRef = ""
|
||||
return meta, true, err
|
||||
}
|
||||
return mergeProviderSecrets(storedMeta, bundle), true, nil
|
||||
}
|
||||
|
||||
resolved, err := s.resolveProviderConfigSecrets(meta)
|
||||
if err != nil {
|
||||
return meta, false, err
|
||||
}
|
||||
return resolved, false, nil
|
||||
}
|
||||
|
||||
func (s *Service) loadConfig() {
|
||||
path := filepath.Join(s.configDir, "ai_config.json")
|
||||
data, err := os.ReadFile(path)
|
||||
@@ -862,13 +961,22 @@ func (s *Service) loadConfig() {
|
||||
return
|
||||
}
|
||||
|
||||
s.providers = cfg.Providers
|
||||
if s.providers == nil {
|
||||
s.providers = make([]ai.ProviderConfig, 0)
|
||||
providers := make([]ai.ProviderConfig, 0, len(cfg.Providers))
|
||||
shouldRewrite := cfg.SchemaVersion != aiConfigSchemaVersion
|
||||
for _, providerConfig := range cfg.Providers {
|
||||
runtimeConfig, rewritten, err := s.loadRuntimeProviderConfig(normalizeProviderConfig(providerConfig))
|
||||
if err != nil {
|
||||
logger.Error(err, "加载 AI Provider secret 失败,provider=%s", providerConfig.ID)
|
||||
}
|
||||
if rewritten {
|
||||
shouldRewrite = true
|
||||
}
|
||||
providers = append(providers, runtimeConfig)
|
||||
}
|
||||
for i := range s.providers {
|
||||
s.providers[i] = normalizeProviderConfig(s.providers[i])
|
||||
if providers == nil {
|
||||
providers = make([]ai.ProviderConfig, 0)
|
||||
}
|
||||
s.providers = providers
|
||||
s.activeProvider = cfg.ActiveProvider
|
||||
|
||||
switch ai.SQLPermissionLevel(cfg.SafetyLevel) {
|
||||
@@ -885,11 +993,23 @@ func (s *Service) loadConfig() {
|
||||
default:
|
||||
s.contextLevel = ai.ContextSchemaOnly
|
||||
}
|
||||
|
||||
if shouldRewrite {
|
||||
if err := s.saveConfig(); err != nil {
|
||||
logger.Error(err, "重写 AI 配置失败")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (s *Service) saveConfig() error {
|
||||
providers := make([]ai.ProviderConfig, len(s.providers))
|
||||
for i := range s.providers {
|
||||
providers[i] = providerMetadataView(s.providers[i])
|
||||
}
|
||||
|
||||
cfg := aiConfig{
|
||||
Providers: s.providers,
|
||||
SchemaVersion: aiConfigSchemaVersion,
|
||||
Providers: providers,
|
||||
ActiveProvider: s.activeProvider,
|
||||
SafetyLevel: string(s.safetyLevel),
|
||||
ContextLevel: string(s.contextLevel),
|
||||
|
||||
@@ -69,6 +69,8 @@ type ProviderConfig struct {
|
||||
Type string `json:"type"` // openai | anthropic | gemini | custom
|
||||
Name string `json:"name"`
|
||||
APIKey string `json:"apiKey"`
|
||||
SecretRef string `json:"secretRef,omitempty"`
|
||||
HasSecret bool `json:"hasSecret,omitempty"`
|
||||
BaseURL string `json:"baseUrl"`
|
||||
Model string `json:"model"`
|
||||
Models []string `json:"models,omitempty"`
|
||||
|
||||
@@ -18,6 +18,7 @@ import (
|
||||
"GoNavi-Wails/internal/db"
|
||||
"GoNavi-Wails/internal/logger"
|
||||
proxytunnel "GoNavi-Wails/internal/proxy"
|
||||
"GoNavi-Wails/internal/secretstore"
|
||||
"github.com/google/uuid"
|
||||
)
|
||||
|
||||
@@ -53,14 +54,25 @@ type App struct {
|
||||
updateMu sync.Mutex
|
||||
updateState updateState
|
||||
queryMu sync.RWMutex
|
||||
configDir string
|
||||
secretStore secretstore.SecretStore
|
||||
runningQueries map[string]queryContext // queryID -> cancelFunc and start time
|
||||
}
|
||||
|
||||
// NewApp creates a new App application struct
|
||||
func NewApp() *App {
|
||||
return NewAppWithSecretStore(secretstore.NewKeyringStore())
|
||||
}
|
||||
|
||||
func NewAppWithSecretStore(store secretstore.SecretStore) *App {
|
||||
if store == nil {
|
||||
store = secretstore.NewUnavailableStore("secret store unavailable")
|
||||
}
|
||||
return &App{
|
||||
dbCache: make(map[string]cachedDatabase),
|
||||
runningQueries: make(map[string]queryContext),
|
||||
configDir: resolveAppConfigDir(),
|
||||
secretStore: store,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -74,7 +86,11 @@ func InitializeLifecycle(a *App, ctx context.Context) {
|
||||
func (a *App) startup(ctx context.Context) {
|
||||
a.ctx = ctx
|
||||
a.startedAt = time.Now()
|
||||
if strings.TrimSpace(a.configDir) == "" {
|
||||
a.configDir = resolveAppConfigDir()
|
||||
}
|
||||
logger.Init()
|
||||
a.loadPersistedGlobalProxy()
|
||||
applyMacWindowTranslucencyFix()
|
||||
logger.Infof("应用启动完成(首次连接保护窗口=%s,最多重试=%d 次)", startupConnectRetryWindow, startupConnectRetryAttempts)
|
||||
}
|
||||
@@ -111,6 +127,7 @@ func (a *App) Shutdown(ctx context.Context) {
|
||||
|
||||
func normalizeCacheKeyConfig(config connection.ConnectionConfig) connection.ConnectionConfig {
|
||||
normalized := config
|
||||
normalized.ID = ""
|
||||
normalized.Type = strings.ToLower(strings.TrimSpace(normalized.Type))
|
||||
// timeout 仅用于 Query/Ping 控制,不应作为物理连接复用键的一部分。
|
||||
normalized.Timeout = 0
|
||||
@@ -216,6 +233,9 @@ func shouldRefreshCachedConnection(err error) bool {
|
||||
}
|
||||
|
||||
func (a *App) invalidateCachedDatabase(config connection.ConnectionConfig, reason error) bool {
|
||||
if resolvedConfig, err := a.resolveConnectionSecrets(config); err == nil {
|
||||
config = resolvedConfig
|
||||
}
|
||||
effectiveConfig := applyGlobalProxyToConnection(config)
|
||||
key := getCacheKey(effectiveConfig)
|
||||
shortKey := shortCacheKey(key)
|
||||
@@ -439,7 +459,11 @@ func (a *App) getDatabase(config connection.ConnectionConfig) (db.Database, erro
|
||||
}
|
||||
|
||||
func (a *App) openDatabaseIsolated(config connection.ConnectionConfig) (db.Database, error) {
|
||||
effectiveConfig := applyGlobalProxyToConnection(config)
|
||||
resolvedConfig, err := a.resolveConnectionSecrets(config)
|
||||
if err != nil {
|
||||
return nil, wrapConnectError(config, err)
|
||||
}
|
||||
effectiveConfig := applyGlobalProxyToConnection(resolvedConfig)
|
||||
if supported, reason := db.DriverRuntimeSupportStatus(effectiveConfig.Type); !supported {
|
||||
if strings.TrimSpace(reason) == "" {
|
||||
reason = fmt.Sprintf("%s 驱动未启用,请先在驱动管理中安装启用", strings.TrimSpace(effectiveConfig.Type))
|
||||
@@ -465,7 +489,11 @@ func (a *App) openDatabaseIsolated(config connection.ConnectionConfig) (db.Datab
|
||||
}
|
||||
|
||||
func (a *App) getDatabaseWithPing(config connection.ConnectionConfig, forcePing bool) (db.Database, error) {
|
||||
effectiveConfig := applyGlobalProxyToConnection(config)
|
||||
resolvedConfig, err := a.resolveConnectionSecrets(config)
|
||||
if err != nil {
|
||||
return nil, wrapConnectError(config, err)
|
||||
}
|
||||
effectiveConfig := applyGlobalProxyToConnection(resolvedConfig)
|
||||
isFileDB := isFileDatabaseType(effectiveConfig.Type)
|
||||
|
||||
key := getCacheKey(effectiveConfig)
|
||||
@@ -546,7 +574,7 @@ func (a *App) getDatabaseWithPing(config connection.ConnectionConfig, forcePing
|
||||
logger.Infof("未命中文件库连接缓存,开始创建连接:类型=%s 缓存Key=%s", strings.TrimSpace(effectiveConfig.Type), shortKey)
|
||||
}
|
||||
|
||||
dbInst, connectedConfig, err := a.connectDatabaseWithStartupRetry(config)
|
||||
dbInst, connectedConfig, err := a.connectDatabaseWithStartupRetry(resolvedConfig)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -581,6 +609,12 @@ func shortenCacheKey(key string) string {
|
||||
}
|
||||
|
||||
func (a *App) connectDatabaseWithStartupRetry(rawConfig connection.ConnectionConfig) (db.Database, connection.ConnectionConfig, error) {
|
||||
resolvedConfig, err := a.resolveConnectionSecrets(rawConfig)
|
||||
if err != nil {
|
||||
return nil, rawConfig, wrapConnectError(rawConfig, err)
|
||||
}
|
||||
rawConfig = resolvedConfig
|
||||
|
||||
var lastErr error
|
||||
var lastEffectiveConfig connection.ConnectionConfig
|
||||
|
||||
|
||||
@@ -24,6 +24,25 @@ func TestGetCacheKey_IgnoreTimeout(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestGetCacheKey_IgnoreConnectionID(t *testing.T) {
|
||||
base := connection.ConnectionConfig{
|
||||
ID: "conn-1",
|
||||
Type: "mysql",
|
||||
Host: "127.0.0.1",
|
||||
Port: 3306,
|
||||
User: "root",
|
||||
Password: "root",
|
||||
}
|
||||
modified := base
|
||||
modified.ID = "conn-2"
|
||||
|
||||
left := getCacheKey(base)
|
||||
right := getCacheKey(modified)
|
||||
if left != right {
|
||||
t.Fatalf("expected same cache key when only connection id differs, got %s vs %s", left, right)
|
||||
}
|
||||
}
|
||||
|
||||
func TestGetCacheKey_DuckDBHostAndDatabaseEquivalent(t *testing.T) {
|
||||
withHost := connection.ConnectionConfig{
|
||||
Type: "duckdb",
|
||||
|
||||
71
internal/app/connection_secret_resolution.go
Normal file
71
internal/app/connection_secret_resolution.go
Normal file
@@ -0,0 +1,71 @@
|
||||
package app
|
||||
|
||||
import (
|
||||
"strings"
|
||||
|
||||
"GoNavi-Wails/internal/connection"
|
||||
)
|
||||
|
||||
func (a *App) resolveConnectionSecrets(config connection.ConnectionConfig) (connection.ConnectionConfig, error) {
|
||||
if strings.TrimSpace(config.ID) == "" {
|
||||
return config, nil
|
||||
}
|
||||
|
||||
repo := newSavedConnectionRepository(a.configDir, a.secretStore)
|
||||
view, err := repo.Find(config.ID)
|
||||
if err != nil {
|
||||
return config, err
|
||||
}
|
||||
|
||||
base := config
|
||||
if connectionMetadataLooksEmpty(base) {
|
||||
base = view.Config
|
||||
}
|
||||
bundle, err := repo.loadSecretBundle(view)
|
||||
if err != nil {
|
||||
return base, err
|
||||
}
|
||||
resolved := mergeConnectionSecretBundleIntoConfig(base, bundle)
|
||||
resolved.ID = view.ID
|
||||
return resolved, nil
|
||||
}
|
||||
|
||||
func connectionMetadataLooksEmpty(config connection.ConnectionConfig) bool {
|
||||
return strings.TrimSpace(config.Type) == "" &&
|
||||
strings.TrimSpace(config.Host) == "" &&
|
||||
config.Port == 0 &&
|
||||
strings.TrimSpace(config.User) == "" &&
|
||||
strings.TrimSpace(config.Database) == "" &&
|
||||
strings.TrimSpace(config.DSN) == "" &&
|
||||
strings.TrimSpace(config.URI) == "" &&
|
||||
len(config.Hosts) == 0
|
||||
}
|
||||
|
||||
func mergeConnectionSecretBundleIntoConfig(config connection.ConnectionConfig, bundle connectionSecretBundle) connection.ConnectionConfig {
|
||||
merged := config
|
||||
if strings.TrimSpace(merged.Password) == "" {
|
||||
merged.Password = bundle.Password
|
||||
}
|
||||
if strings.TrimSpace(merged.SSH.Password) == "" {
|
||||
merged.SSH.Password = bundle.SSHPassword
|
||||
}
|
||||
if strings.TrimSpace(merged.Proxy.Password) == "" {
|
||||
merged.Proxy.Password = bundle.ProxyPassword
|
||||
}
|
||||
if strings.TrimSpace(merged.HTTPTunnel.Password) == "" {
|
||||
merged.HTTPTunnel.Password = bundle.HTTPTunnelPassword
|
||||
}
|
||||
if strings.TrimSpace(merged.MySQLReplicaPassword) == "" {
|
||||
merged.MySQLReplicaPassword = bundle.MySQLReplicaPassword
|
||||
}
|
||||
if strings.TrimSpace(merged.MongoReplicaPassword) == "" {
|
||||
merged.MongoReplicaPassword = bundle.MongoReplicaPassword
|
||||
}
|
||||
if strings.TrimSpace(merged.URI) == "" {
|
||||
merged.URI = bundle.OpaqueURI
|
||||
}
|
||||
if strings.TrimSpace(merged.DSN) == "" {
|
||||
merged.DSN = bundle.OpaqueDSN
|
||||
}
|
||||
return merged
|
||||
}
|
||||
42
internal/app/connection_secret_resolution_test.go
Normal file
42
internal/app/connection_secret_resolution_test.go
Normal file
@@ -0,0 +1,42 @@
|
||||
package app
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"GoNavi-Wails/internal/connection"
|
||||
)
|
||||
|
||||
func TestResolveConnectionConfigByIDLoadsSecretsFromStore(t *testing.T) {
|
||||
store := newFakeAppSecretStore()
|
||||
app := NewAppWithSecretStore(store)
|
||||
app.configDir = t.TempDir()
|
||||
|
||||
repo := newSavedConnectionRepository(app.configDir, store)
|
||||
view, err := repo.Save(connection.SavedConnectionInput{
|
||||
ID: "conn-1",
|
||||
Name: "Primary",
|
||||
Config: connection.ConnectionConfig{
|
||||
ID: "conn-1",
|
||||
Type: "postgres",
|
||||
Host: "db.local",
|
||||
Port: 5432,
|
||||
User: "postgres",
|
||||
Password: "postgres-secret",
|
||||
DSN: "postgres://user:pass@db.local/app",
|
||||
},
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("Save returned error: %v", err)
|
||||
}
|
||||
|
||||
resolved, err := app.resolveConnectionSecrets(view.Config)
|
||||
if err != nil {
|
||||
t.Fatalf("resolveConnectionSecrets returned error: %v", err)
|
||||
}
|
||||
if resolved.Password != "postgres-secret" {
|
||||
t.Fatalf("expected restored password, got %q", resolved.Password)
|
||||
}
|
||||
if resolved.DSN != "postgres://user:pass@db.local/app" {
|
||||
t.Fatalf("expected restored DSN, got %q", resolved.DSN)
|
||||
}
|
||||
}
|
||||
@@ -123,11 +123,26 @@ func proxyConfigEqual(a, b connection.ProxyConfig) bool {
|
||||
a.Password == b.Password
|
||||
}
|
||||
|
||||
func currentGlobalProxyView() connection.GlobalProxyView {
|
||||
snapshot := currentGlobalProxyConfig()
|
||||
if !snapshot.Enabled {
|
||||
return connection.GlobalProxyView{Enabled: false}
|
||||
}
|
||||
return connection.GlobalProxyView{
|
||||
Enabled: true,
|
||||
Type: snapshot.Proxy.Type,
|
||||
Host: snapshot.Proxy.Host,
|
||||
Port: snapshot.Proxy.Port,
|
||||
User: snapshot.Proxy.User,
|
||||
HasPassword: strings.TrimSpace(snapshot.Proxy.Password) != "",
|
||||
}
|
||||
}
|
||||
|
||||
func (a *App) GetGlobalProxyConfig() connection.QueryResult {
|
||||
return connection.QueryResult{
|
||||
Success: true,
|
||||
Message: "OK",
|
||||
Data: currentGlobalProxyConfig(),
|
||||
Data: currentGlobalProxyView(),
|
||||
}
|
||||
}
|
||||
|
||||
@@ -312,3 +327,4 @@ func buildProxyURLFromConfig(proxyConfig connection.ProxyConfig) (*url.URL, erro
|
||||
}
|
||||
return proxyURL, nil
|
||||
}
|
||||
|
||||
|
||||
208
internal/app/global_proxy_persistence.go
Normal file
208
internal/app/global_proxy_persistence.go
Normal file
@@ -0,0 +1,208 @@
|
||||
package app
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
|
||||
"GoNavi-Wails/internal/connection"
|
||||
"GoNavi-Wails/internal/logger"
|
||||
"GoNavi-Wails/internal/secretstore"
|
||||
)
|
||||
|
||||
const (
|
||||
globalProxyFileName = "global_proxy.json"
|
||||
globalProxySecretKind = "global-proxy"
|
||||
globalProxySecretID = "default"
|
||||
)
|
||||
|
||||
type globalProxySecretBundle struct {
|
||||
Password string `json:"password,omitempty"`
|
||||
}
|
||||
|
||||
func globalProxyMetadataPath(configDir string) string {
|
||||
return filepath.Join(configDir, globalProxyFileName)
|
||||
}
|
||||
|
||||
func (a *App) saveGlobalProxy(input connection.SaveGlobalProxyInput) (connection.GlobalProxyView, error) {
|
||||
if strings.TrimSpace(a.configDir) == "" {
|
||||
a.configDir = resolveAppConfigDir()
|
||||
}
|
||||
|
||||
existing, err := a.loadStoredGlobalProxyView()
|
||||
if err != nil && !os.IsNotExist(err) {
|
||||
return connection.GlobalProxyView{}, err
|
||||
}
|
||||
|
||||
view := connection.GlobalProxyView{
|
||||
Enabled: input.Enabled,
|
||||
Type: strings.TrimSpace(input.Type),
|
||||
Host: strings.TrimSpace(input.Host),
|
||||
Port: input.Port,
|
||||
User: strings.TrimSpace(input.User),
|
||||
}
|
||||
|
||||
bundle := globalProxySecretBundle{}
|
||||
if strings.TrimSpace(input.Password) != "" {
|
||||
bundle.Password = input.Password
|
||||
} else if existing.HasPassword {
|
||||
existingBundle, loadErr := a.loadGlobalProxySecretBundle(existing)
|
||||
if loadErr != nil {
|
||||
return connection.GlobalProxyView{}, loadErr
|
||||
}
|
||||
bundle = existingBundle
|
||||
view.SecretRef = existing.SecretRef
|
||||
}
|
||||
|
||||
if !view.Enabled {
|
||||
if strings.TrimSpace(existing.SecretRef) != "" && a.secretStore != nil {
|
||||
if deleteErr := a.secretStore.Delete(existing.SecretRef); deleteErr != nil {
|
||||
return connection.GlobalProxyView{}, deleteErr
|
||||
}
|
||||
}
|
||||
view = connection.GlobalProxyView{Enabled: false}
|
||||
if err := a.persistGlobalProxyView(view); err != nil {
|
||||
return connection.GlobalProxyView{}, err
|
||||
}
|
||||
if _, err := setGlobalProxyConfig(false, connection.ProxyConfig{}); err != nil {
|
||||
return connection.GlobalProxyView{}, err
|
||||
}
|
||||
return view, nil
|
||||
}
|
||||
|
||||
if strings.TrimSpace(bundle.Password) != "" {
|
||||
ref, storeErr := a.storeGlobalProxySecret(view.SecretRef, bundle)
|
||||
if storeErr != nil {
|
||||
return connection.GlobalProxyView{}, storeErr
|
||||
}
|
||||
view.SecretRef = ref
|
||||
view.HasPassword = true
|
||||
} else {
|
||||
if strings.TrimSpace(existing.SecretRef) != "" && a.secretStore != nil {
|
||||
if deleteErr := a.secretStore.Delete(existing.SecretRef); deleteErr != nil {
|
||||
return connection.GlobalProxyView{}, deleteErr
|
||||
}
|
||||
}
|
||||
view.SecretRef = ""
|
||||
view.HasPassword = false
|
||||
}
|
||||
|
||||
if err := a.persistGlobalProxyView(view); err != nil {
|
||||
return connection.GlobalProxyView{}, err
|
||||
}
|
||||
if _, err := setGlobalProxyConfig(true, connection.ProxyConfig{
|
||||
Type: view.Type,
|
||||
Host: view.Host,
|
||||
Port: view.Port,
|
||||
User: view.User,
|
||||
Password: bundle.Password,
|
||||
}); err != nil {
|
||||
return connection.GlobalProxyView{}, err
|
||||
}
|
||||
view.Password = ""
|
||||
return view, nil
|
||||
}
|
||||
|
||||
func (a *App) persistGlobalProxyView(view connection.GlobalProxyView) error {
|
||||
if err := os.MkdirAll(a.configDir, 0o755); err != nil {
|
||||
return err
|
||||
}
|
||||
payload, err := json.MarshalIndent(view, "", " ")
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return os.WriteFile(globalProxyMetadataPath(a.configDir), payload, 0o644)
|
||||
}
|
||||
|
||||
func (a *App) loadStoredGlobalProxyView() (connection.GlobalProxyView, error) {
|
||||
data, err := os.ReadFile(globalProxyMetadataPath(a.configDir))
|
||||
if err != nil {
|
||||
return connection.GlobalProxyView{}, err
|
||||
}
|
||||
var view connection.GlobalProxyView
|
||||
if err := json.Unmarshal(data, &view); err != nil {
|
||||
return connection.GlobalProxyView{}, err
|
||||
}
|
||||
return view, nil
|
||||
}
|
||||
|
||||
func (a *App) loadGlobalProxySecretBundle(view connection.GlobalProxyView) (globalProxySecretBundle, error) {
|
||||
if !view.HasPassword {
|
||||
return globalProxySecretBundle{}, nil
|
||||
}
|
||||
if a.secretStore == nil {
|
||||
return globalProxySecretBundle{}, fmt.Errorf("secret store unavailable")
|
||||
}
|
||||
ref := strings.TrimSpace(view.SecretRef)
|
||||
if ref == "" {
|
||||
var err error
|
||||
ref, err = secretstore.BuildRef(globalProxySecretKind, globalProxySecretID)
|
||||
if err != nil {
|
||||
return globalProxySecretBundle{}, err
|
||||
}
|
||||
}
|
||||
payload, err := a.secretStore.Get(ref)
|
||||
if err != nil {
|
||||
return globalProxySecretBundle{}, err
|
||||
}
|
||||
var bundle globalProxySecretBundle
|
||||
if err := json.Unmarshal(payload, &bundle); err != nil {
|
||||
return globalProxySecretBundle{}, err
|
||||
}
|
||||
return bundle, nil
|
||||
}
|
||||
|
||||
func (a *App) storeGlobalProxySecret(existingRef string, bundle globalProxySecretBundle) (string, error) {
|
||||
if a.secretStore == nil {
|
||||
return "", fmt.Errorf("secret store unavailable")
|
||||
}
|
||||
if err := a.secretStore.HealthCheck(); err != nil {
|
||||
return "", err
|
||||
}
|
||||
ref := strings.TrimSpace(existingRef)
|
||||
if ref == "" {
|
||||
var err error
|
||||
ref, err = secretstore.BuildRef(globalProxySecretKind, globalProxySecretID)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
}
|
||||
payload, err := json.Marshal(bundle)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
if err := a.secretStore.Put(ref, payload); err != nil {
|
||||
return "", err
|
||||
}
|
||||
return ref, nil
|
||||
}
|
||||
|
||||
func (a *App) loadPersistedGlobalProxy() {
|
||||
view, err := a.loadStoredGlobalProxyView()
|
||||
if err != nil {
|
||||
if !os.IsNotExist(err) {
|
||||
logger.Error(err, "加载全局代理元数据失败")
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
proxyConfig := connection.ProxyConfig{
|
||||
Type: view.Type,
|
||||
Host: view.Host,
|
||||
Port: view.Port,
|
||||
User: view.User,
|
||||
}
|
||||
if view.HasPassword {
|
||||
bundle, loadErr := a.loadGlobalProxySecretBundle(view)
|
||||
if loadErr != nil {
|
||||
logger.Error(loadErr, "加载全局代理密码失败")
|
||||
return
|
||||
}
|
||||
proxyConfig.Password = bundle.Password
|
||||
}
|
||||
if _, err := setGlobalProxyConfig(view.Enabled, proxyConfig); err != nil {
|
||||
logger.Error(err, "恢复全局代理配置失败")
|
||||
}
|
||||
}
|
||||
66
internal/app/global_proxy_secret_test.go
Normal file
66
internal/app/global_proxy_secret_test.go
Normal file
@@ -0,0 +1,66 @@
|
||||
package app
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"GoNavi-Wails/internal/connection"
|
||||
)
|
||||
|
||||
func TestSaveGlobalProxyStripsPasswordFromView(t *testing.T) {
|
||||
store := newFakeAppSecretStore()
|
||||
app := NewAppWithSecretStore(store)
|
||||
app.configDir = t.TempDir()
|
||||
|
||||
view, err := app.saveGlobalProxy(connection.SaveGlobalProxyInput{
|
||||
Enabled: true,
|
||||
Type: "http",
|
||||
Host: "127.0.0.1",
|
||||
Port: 8080,
|
||||
User: "ops",
|
||||
Password: "proxy-secret",
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("saveGlobalProxy returned error: %v", err)
|
||||
}
|
||||
if view.Password != "" {
|
||||
t.Fatal("global proxy view must not expose plaintext password")
|
||||
}
|
||||
if !view.HasPassword {
|
||||
t.Fatal("expected hasPassword=true")
|
||||
}
|
||||
|
||||
snapshot := currentGlobalProxyConfig()
|
||||
if snapshot.Proxy.Password != "proxy-secret" {
|
||||
t.Fatalf("expected runtime proxy password to be preserved, got %q", snapshot.Proxy.Password)
|
||||
}
|
||||
}
|
||||
|
||||
func TestGetGlobalProxyConfigReturnsSecretlessView(t *testing.T) {
|
||||
store := newFakeAppSecretStore()
|
||||
app := NewAppWithSecretStore(store)
|
||||
app.configDir = t.TempDir()
|
||||
|
||||
if _, err := app.saveGlobalProxy(connection.SaveGlobalProxyInput{
|
||||
Enabled: true,
|
||||
Type: "http",
|
||||
Host: "127.0.0.1",
|
||||
Port: 8080,
|
||||
User: "ops",
|
||||
Password: "proxy-secret",
|
||||
}); err != nil {
|
||||
t.Fatalf("saveGlobalProxy returned error: %v", err)
|
||||
}
|
||||
|
||||
result := app.GetGlobalProxyConfig()
|
||||
view, ok := result.Data.(connection.GlobalProxyView)
|
||||
if !ok {
|
||||
t.Fatalf("expected GlobalProxyView, got %T", result.Data)
|
||||
}
|
||||
if view.Password != "" {
|
||||
t.Fatal("GetGlobalProxyConfig must not expose plaintext password")
|
||||
}
|
||||
if !view.HasPassword {
|
||||
t.Fatal("expected hasPassword=true")
|
||||
}
|
||||
}
|
||||
|
||||
@@ -543,7 +543,10 @@ func (a *App) GetDriverVersionPackageSize(driverType string, version string) con
|
||||
if normalizedVersion == "" {
|
||||
return connection.QueryResult{Success: false, Message: "版本号为空"}
|
||||
}
|
||||
assetName := optionalDriverReleaseAssetName(normalizedType)
|
||||
if err := validateDriverSelectedVersion(definition, normalizedVersion); err != nil {
|
||||
return connection.QueryResult{Success: false, Message: err.Error()}
|
||||
}
|
||||
assetName := optionalDriverReleaseAssetNameForVersion(normalizedType, normalizedVersion)
|
||||
if strings.TrimSpace(assetName) == "" {
|
||||
return connection.QueryResult{Success: false, Message: "驱动资产名称为空"}
|
||||
}
|
||||
@@ -554,14 +557,15 @@ func (a *App) GetDriverVersionPackageSize(driverType string, version string) con
|
||||
if sizeByAsset, err := loadReleaseAssetSizesCached("tag:"+tag, func() (*githubRelease, error) {
|
||||
return fetchReleaseByTag(tag)
|
||||
}); err == nil {
|
||||
sizeBytes = resolveOptionalDriverAssetSize(sizeByAsset, normalizedType)
|
||||
sizeBytes = resolveOptionalDriverAssetSizeForVersion(sizeByAsset, normalizedType, normalizedVersion)
|
||||
if sizeBytes > 0 {
|
||||
sizeSource = "tag"
|
||||
}
|
||||
}
|
||||
if sizeBytes <= 0 {
|
||||
allowLatestFallback := sameDriverVersion(normalizedVersion, definition.PinnedVersion) || sameDriverVersion(normalizedVersion, latestDriverVersionMap[normalizedType])
|
||||
if sizeBytes <= 0 && allowLatestFallback {
|
||||
if sizeByAsset, err := loadReleaseAssetSizesCached("latest", fetchLatestReleaseForDriverAssets); err == nil {
|
||||
sizeBytes = resolveOptionalDriverAssetSize(sizeByAsset, normalizedType)
|
||||
sizeBytes = resolveOptionalDriverAssetSizeForVersion(sizeByAsset, normalizedType, normalizedVersion)
|
||||
if sizeBytes > 0 {
|
||||
sizeSource = "latest"
|
||||
}
|
||||
@@ -741,7 +745,7 @@ func (a *App) CheckDriverNetworkStatus() connection.QueryResult {
|
||||
}
|
||||
}
|
||||
|
||||
func (a *App) InstallLocalDriverPackage(driverType string, filePath string, downloadDir string) connection.QueryResult {
|
||||
func (a *App) InstallLocalDriverPackage(driverType string, filePath string, downloadDir string, version string) connection.QueryResult {
|
||||
definition, ok := resolveDriverDefinition(driverType)
|
||||
if !ok {
|
||||
return connection.QueryResult{Success: false, Message: "不支持的驱动类型"}
|
||||
@@ -764,7 +768,10 @@ func (a *App) InstallLocalDriverPackage(driverType string, filePath string, down
|
||||
db.SetExternalDriverDownloadDirectory(resolvedDir)
|
||||
|
||||
a.emitDriverDownloadProgress(definition.Type, "start", 0, 100, "开始安装本地驱动包")
|
||||
selectedVersion := resolveDriverInstallVersion(definition.PinnedVersion, "local://manual", definition)
|
||||
selectedVersion := resolveDriverInstallVersion(version, "local://manual", definition)
|
||||
if err := validateDriverSelectedVersion(definition, selectedVersion); err != nil {
|
||||
return connection.QueryResult{Success: false, Message: err.Error()}
|
||||
}
|
||||
meta, installErr := installOptionalDriverAgentFromLocalPath(definition, filePath, resolvedDir, selectedVersion)
|
||||
if installErr != nil {
|
||||
errText := normalizeErrorMessage(installErr)
|
||||
@@ -816,6 +823,9 @@ func (a *App) DownloadDriverPackage(driverType string, version string, downloadU
|
||||
urlText = fmt.Sprintf("builtin://activate/%s", optionalDriverPublicTypeName(definition.Type))
|
||||
}
|
||||
selectedVersion := resolveDriverInstallVersion(version, urlText, definition)
|
||||
if err := validateDriverSelectedVersion(definition, selectedVersion); err != nil {
|
||||
return connection.QueryResult{Success: false, Message: err.Error()}
|
||||
}
|
||||
|
||||
resolvedDir, err := resolveDriverDownloadDirectory(downloadDir)
|
||||
if err != nil {
|
||||
@@ -1424,6 +1434,11 @@ func resolveDriverVersionOptions(definition driverDefinition, repositoryURL stri
|
||||
if versionText == "" && urlText == "" {
|
||||
return
|
||||
}
|
||||
if versionText != "" {
|
||||
if err := validateDriverSelectedVersion(definition, versionText); err != nil {
|
||||
return
|
||||
}
|
||||
}
|
||||
versionKey := normalizeVersion(versionText)
|
||||
key := ""
|
||||
if versionKey != "" {
|
||||
@@ -1550,6 +1565,16 @@ func resolveVersionedDriverOption(definition driverDefinition, version string, s
|
||||
if versionText == "" {
|
||||
return "", "", false
|
||||
}
|
||||
if err := validateDriverSelectedVersion(definition, versionText); err != nil {
|
||||
return "", "", false
|
||||
}
|
||||
|
||||
if publishedURL, ok := resolvePublishedDriverDownloadURL(definition, versionText); ok {
|
||||
return versionText, publishedURL, true
|
||||
}
|
||||
if !optionalDriverSourceBuildAvailable(definition, versionText) {
|
||||
return "", "", false
|
||||
}
|
||||
|
||||
urlText := strings.TrimSpace(definition.DefaultDownloadURL)
|
||||
if urlText == "" && effectiveDriverEngine(definition) == driverEngineGo {
|
||||
@@ -1580,6 +1605,97 @@ func sameDriverVersion(left, right string) bool {
|
||||
return a != "" && a == b
|
||||
}
|
||||
|
||||
func validateDriverSelectedVersion(definition driverDefinition, version string) error {
|
||||
driverType := normalizeDriverType(definition.Type)
|
||||
versionText := normalizeVersion(strings.TrimSpace(version))
|
||||
if driverType == "" || versionText == "" {
|
||||
return nil
|
||||
}
|
||||
|
||||
switch driverType {
|
||||
case "mongodb":
|
||||
if strings.HasPrefix(versionText, "2.") {
|
||||
return nil
|
||||
}
|
||||
if strings.HasPrefix(versionText, "1.17.") {
|
||||
return nil
|
||||
}
|
||||
return fmt.Errorf("MongoDB 版本 %s 当前不受支持;仅支持 1.17.x 和 2.x", versionText)
|
||||
default:
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
func shouldRestrictToExplicitVersionArtifact(definition driverDefinition, selectedVersion string) bool {
|
||||
versionText := normalizeVersion(strings.TrimSpace(selectedVersion))
|
||||
if versionText == "" {
|
||||
return false
|
||||
}
|
||||
return !sameDriverVersion(versionText, definition.PinnedVersion)
|
||||
}
|
||||
|
||||
func optionalDriverSourceBuildAvailable(definition driverDefinition, selectedVersion string) bool {
|
||||
driverType := normalizeDriverType(definition.Type)
|
||||
if driverType == "" || !db.IsOptionalGoDriver(driverType) {
|
||||
return false
|
||||
}
|
||||
if _, err := optionalDriverBuildTag(driverType, selectedVersion); err != nil {
|
||||
return false
|
||||
}
|
||||
if _, err := exec.LookPath("go"); err != nil {
|
||||
return false
|
||||
}
|
||||
if _, err := locateProjectRootForAgentBuild(); err != nil {
|
||||
return false
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
func resolvePublishedDriverDownloadURL(definition driverDefinition, version string) (string, bool) {
|
||||
driverType := normalizeDriverType(definition.Type)
|
||||
versionText := normalizeVersion(strings.TrimSpace(version))
|
||||
if driverType == "" || versionText == "" {
|
||||
return "", false
|
||||
}
|
||||
|
||||
tag := "v" + versionText
|
||||
assetName, ok := resolvePublishedDriverReleaseAssetName(driverType, versionText, tag)
|
||||
if !ok {
|
||||
return "", false
|
||||
}
|
||||
return fmt.Sprintf("https://github.com/%s/releases/download/%s/%s", updateRepo, tag, assetName), true
|
||||
}
|
||||
|
||||
func resolvePublishedDriverReleaseAssetName(driverType string, version string, tag string) (string, bool) {
|
||||
assetNames := optionalDriverReleaseAssetNamesForVersion(driverType, version)
|
||||
if len(assetNames) == 0 {
|
||||
return "", false
|
||||
}
|
||||
|
||||
cacheKey := "tag:" + strings.TrimSpace(tag)
|
||||
if sizeByAsset, ok := readReleaseAssetSizesFromCache(cacheKey); ok {
|
||||
for _, assetName := range assetNames {
|
||||
if sizeByAsset[assetName] > 0 {
|
||||
return assetName, true
|
||||
}
|
||||
}
|
||||
return "", false
|
||||
}
|
||||
|
||||
sizeByAsset, err := loadReleaseAssetSizesCached(cacheKey, func() (*githubRelease, error) {
|
||||
return fetchReleaseByTag(tag)
|
||||
})
|
||||
if err != nil {
|
||||
return "", false
|
||||
}
|
||||
for _, assetName := range assetNames {
|
||||
if sizeByAsset[assetName] > 0 {
|
||||
return assetName, true
|
||||
}
|
||||
}
|
||||
return "", false
|
||||
}
|
||||
|
||||
func resolveDriverVersionPackageSizeBytes(definition driverDefinition, option driverVersionOptionItem) int64 {
|
||||
driverType := normalizeDriverType(definition.Type)
|
||||
if driverType == "" || definition.BuiltIn {
|
||||
@@ -1593,20 +1709,20 @@ func resolveDriverVersionPackageSizeBytes(definition driverDefinition, option dr
|
||||
if version == "" {
|
||||
return 0
|
||||
}
|
||||
assetName := optionalDriverReleaseAssetName(driverType)
|
||||
if strings.TrimSpace(assetName) == "" {
|
||||
assetNames := optionalDriverReleaseAssetNamesForVersion(driverType, version)
|
||||
if len(assetNames) == 0 {
|
||||
return 0
|
||||
}
|
||||
|
||||
tag := "v" + version
|
||||
if sizeByAsset, ok := readReleaseAssetSizesFromCache("tag:" + tag); ok {
|
||||
return resolveOptionalDriverAssetSize(sizeByAsset, driverType)
|
||||
return resolveOptionalDriverAssetSizeForVersion(sizeByAsset, driverType, version)
|
||||
}
|
||||
|
||||
// 下拉版本列表要求快速返回:仅复用已有缓存,不在这里触发网络请求。
|
||||
if strings.EqualFold(strings.TrimSpace(option.Source), "latest") {
|
||||
if sizeByAsset, ok := readReleaseAssetSizesFromCache("latest"); ok {
|
||||
return resolveOptionalDriverAssetSize(sizeByAsset, driverType)
|
||||
return resolveOptionalDriverAssetSizeForVersion(sizeByAsset, driverType, version)
|
||||
}
|
||||
}
|
||||
return 0
|
||||
@@ -1906,19 +2022,23 @@ func resolveDriverVersionOptionsFromReleases(definition driverDefinition) []driv
|
||||
return nil
|
||||
}
|
||||
|
||||
assetName := optionalDriverReleaseAssetName(driverType)
|
||||
assetNames := optionalDriverReleaseAssetNames(driverType)
|
||||
result := make([]driverVersionOptionItem, 0, len(releases))
|
||||
for _, release := range releases {
|
||||
if release.Prerelease {
|
||||
continue
|
||||
}
|
||||
tag := strings.TrimSpace(release.TagName)
|
||||
if tag == "" || !releaseContainsAnyAsset(release, assetNames) {
|
||||
version := normalizeVersion(tag)
|
||||
if tag == "" || version == "" {
|
||||
continue
|
||||
}
|
||||
assetName := optionalDriverReleaseAssetNameForVersion(driverType, version)
|
||||
assetNames := optionalDriverReleaseAssetNamesForVersion(driverType, version)
|
||||
if !releaseContainsAnyAsset(release, assetNames) {
|
||||
continue
|
||||
}
|
||||
result = append(result, driverVersionOptionItem{
|
||||
Version: normalizeVersion(tag),
|
||||
Version: version,
|
||||
DownloadURL: fmt.Sprintf("https://github.com/%s/releases/download/%s/%s", updateRepo, tag, assetName),
|
||||
Source: "release",
|
||||
})
|
||||
@@ -2511,7 +2631,7 @@ func installOptionalDriverAgentFromLocalPath(definition driverDefinition, filePa
|
||||
sourceName := filepath.Base(pathText)
|
||||
downloadSource := fmt.Sprintf("local://manual/%s", filepath.Base(pathText))
|
||||
if info.IsDir() {
|
||||
matchedPath, matchedEntry, resolveErr := resolveLocalDriverAgentFromDirectory(pathText, driverType)
|
||||
matchedPath, matchedEntry, resolveErr := resolveLocalDriverAgentFromLocalDirectory(pathText, driverType, selectedVersion)
|
||||
if resolveErr != nil {
|
||||
return installedDriverPackage{}, resolveErr
|
||||
}
|
||||
@@ -2524,7 +2644,7 @@ func installOptionalDriverAgentFromLocalPath(definition driverDefinition, filePa
|
||||
}
|
||||
|
||||
if !info.IsDir() && strings.EqualFold(filepath.Ext(pathText), ".zip") {
|
||||
entryName, extractErr := installOptionalDriverAgentFromLocalZip(pathText, definition, executablePath)
|
||||
entryName, extractErr := installOptionalDriverAgentFromLocalZip(pathText, definition, executablePath, selectedVersion)
|
||||
if extractErr != nil {
|
||||
return installedDriverPackage{}, extractErr
|
||||
}
|
||||
@@ -2563,7 +2683,7 @@ type localDriverCandidate struct {
|
||||
inPlatformDir bool
|
||||
}
|
||||
|
||||
func resolveLocalDriverAgentFromDirectory(directoryPath string, driverType string) (string, string, error) {
|
||||
func resolveLocalDriverAgentFromLocalDirectory(directoryPath string, driverType string, selectedVersion string) (string, string, error) {
|
||||
root := strings.TrimSpace(directoryPath)
|
||||
if root == "" {
|
||||
return "", "", fmt.Errorf("本地驱动目录路径为空")
|
||||
@@ -2586,9 +2706,9 @@ func resolveLocalDriverAgentFromDirectory(directoryPath string, driverType strin
|
||||
}
|
||||
displayName := resolveDriverDisplayName(displayDefinition)
|
||||
platformDir := optionalDriverBundlePlatformDir(stdRuntime.GOOS)
|
||||
assetNameCandidates := optionalDriverReleaseAssetNames(normalizedType)
|
||||
baseNameCandidates := optionalDriverExecutableBaseNames(normalizedType)
|
||||
assetName := optionalDriverReleaseAssetName(normalizedType)
|
||||
assetNameCandidates := optionalDriverReleaseAssetNamesForVersion(normalizedType, selectedVersion)
|
||||
baseNameCandidates := optionalDriverExecutableBaseNamesForVersion(normalizedType, selectedVersion)
|
||||
assetName := optionalDriverReleaseAssetNameForVersion(normalizedType, selectedVersion)
|
||||
|
||||
exactRelativePath := filepath.ToSlash(filepath.Join(platformDir, assetName))
|
||||
for _, candidateName := range assetNameCandidates {
|
||||
@@ -2703,7 +2823,7 @@ func resolveLocalDriverAgentFromDirectory(directoryPath string, driverType strin
|
||||
)
|
||||
}
|
||||
|
||||
func installOptionalDriverAgentFromLocalZip(zipPath string, definition driverDefinition, executablePath string) (string, error) {
|
||||
func installOptionalDriverAgentFromLocalZip(zipPath string, definition driverDefinition, executablePath string, selectedVersion string) (string, error) {
|
||||
driverType := normalizeDriverType(definition.Type)
|
||||
displayName := resolveDriverDisplayName(definition)
|
||||
reader, err := zip.OpenReader(zipPath)
|
||||
@@ -2712,9 +2832,9 @@ func installOptionalDriverAgentFromLocalZip(zipPath string, definition driverDef
|
||||
}
|
||||
defer reader.Close()
|
||||
|
||||
entryPath := optionalDriverBundleEntryPath(driverType)
|
||||
entryPaths := optionalDriverBundleEntryPaths(driverType)
|
||||
expectedBaseNames := optionalDriverReleaseAssetNames(driverType)
|
||||
entryPath := optionalDriverBundleEntryPathForVersion(driverType, selectedVersion)
|
||||
entryPaths := optionalDriverBundleEntryPathsForVersion(driverType, selectedVersion)
|
||||
expectedBaseNames := optionalDriverReleaseAssetNamesForVersion(driverType, selectedVersion)
|
||||
findEntry := func() *zip.File {
|
||||
for _, file := range reader.File {
|
||||
name := filepath.ToSlash(strings.TrimPrefix(strings.TrimSpace(file.Name), "./"))
|
||||
@@ -2791,9 +2911,10 @@ func installOptionalDriverAgentFromLocalZip(zipPath string, definition driverDef
|
||||
func ensureOptionalDriverAgentBinary(a *App, definition driverDefinition, executablePath string, downloadURL string, selectedVersion string) (string, string, error) {
|
||||
driverType := normalizeDriverType(definition.Type)
|
||||
displayName := resolveDriverDisplayName(definition)
|
||||
forceSourceBuild := shouldForceSourceBuildForVersion(driverType, selectedVersion)
|
||||
forceSourceBuild := shouldForceSourceBuildForResolvedDownload(driverType, selectedVersion, downloadURL)
|
||||
preferSourceBuildBeforeDownload := shouldPreferSourceBuildBeforeDownload(driverType, selectedVersion)
|
||||
skipReuseCandidate := shouldSkipReusableAgentCandidate(driverType, selectedVersion)
|
||||
restrictToExplicitArtifact := shouldRestrictToExplicitVersionArtifact(definition, selectedVersion)
|
||||
|
||||
info, err := os.Stat(executablePath)
|
||||
if err == nil && !info.IsDir() {
|
||||
@@ -2851,7 +2972,7 @@ func ensureOptionalDriverAgentBinary(a *App, definition driverDefinition, execut
|
||||
}
|
||||
|
||||
if !forceSourceBuild {
|
||||
downloadURLs := resolveOptionalDriverAgentDownloadURLs(definition, downloadURL)
|
||||
downloadURLs := resolveOptionalDriverAgentDownloadURLs(definition, downloadURL, selectedVersion)
|
||||
if len(downloadURLs) > 0 {
|
||||
for _, candidateURL := range downloadURLs {
|
||||
if a != nil {
|
||||
@@ -2865,7 +2986,7 @@ func ensureOptionalDriverAgentBinary(a *App, definition driverDefinition, execut
|
||||
}
|
||||
}
|
||||
bundleURLs := resolveOptionalDriverBundleDownloadURLs()
|
||||
if len(bundleURLs) > 0 {
|
||||
if !restrictToExplicitArtifact && len(bundleURLs) > 0 {
|
||||
for _, bundleURL := range bundleURLs {
|
||||
if a != nil {
|
||||
a.emitDriverDownloadProgress(driverType, "downloading", 20, 100, fmt.Sprintf("从驱动总包提取 %s 代理", displayName))
|
||||
@@ -3108,6 +3229,23 @@ func shouldForceSourceBuildForVersion(driverType string, selectedVersion string)
|
||||
return resolveMongoDriverMajorFromVersion(selectedVersion) == 1
|
||||
}
|
||||
|
||||
func shouldForceSourceBuildForResolvedDownload(driverType string, selectedVersion string, downloadURL string) bool {
|
||||
if !shouldForceSourceBuildForVersion(driverType, selectedVersion) {
|
||||
return false
|
||||
}
|
||||
|
||||
parsed, err := url.Parse(strings.TrimSpace(downloadURL))
|
||||
if err != nil || parsed == nil {
|
||||
return true
|
||||
}
|
||||
switch strings.ToLower(strings.TrimSpace(parsed.Scheme)) {
|
||||
case "http", "https":
|
||||
return false
|
||||
default:
|
||||
return true
|
||||
}
|
||||
}
|
||||
|
||||
func shouldPreferSourceBuildBeforeDownload(driverType string, selectedVersion string) bool {
|
||||
_ = selectedVersion
|
||||
switch normalizeDriverType(driverType) {
|
||||
@@ -3224,11 +3362,80 @@ func optionalDriverReleaseAssetNameForType(typeName string, goos string, goarch
|
||||
return name
|
||||
}
|
||||
|
||||
func optionalDriverExecutableBaseNames(driverType string) []string {
|
||||
func optionalDriverNameStemCandidates(driverType string, selectedVersion string) []string {
|
||||
candidates := make([]string, 0, 3)
|
||||
seen := make(map[string]struct{}, 3)
|
||||
appendStem := func(stem string) {
|
||||
trimmed := strings.TrimSpace(stem)
|
||||
if trimmed == "" {
|
||||
return
|
||||
}
|
||||
if _, ok := seen[trimmed]; ok {
|
||||
return
|
||||
}
|
||||
seen[trimmed] = struct{}{}
|
||||
candidates = append(candidates, trimmed)
|
||||
}
|
||||
|
||||
base := fmt.Sprintf("%s-driver-agent", optionalDriverPublicTypeName(driverType))
|
||||
if normalizeDriverType(driverType) == "mongodb" {
|
||||
switch resolveMongoDriverMajorFromVersion(selectedVersion) {
|
||||
case 1:
|
||||
appendStem(base + "-v1")
|
||||
appendStem(base)
|
||||
case 2:
|
||||
appendStem(base)
|
||||
appendStem(base + "-v2")
|
||||
default:
|
||||
appendStem(base)
|
||||
}
|
||||
return candidates
|
||||
}
|
||||
|
||||
appendStem(base)
|
||||
return candidates
|
||||
}
|
||||
|
||||
func optionalDriverExecutableBaseNamesForVersion(driverType string, selectedVersion string) []string {
|
||||
names := make([]string, 0, 2)
|
||||
seen := make(map[string]struct{}, 2)
|
||||
appendName := func(typeName string) {
|
||||
name := optionalDriverExecutableBaseNameForType(typeName)
|
||||
appendName := func(stem string) {
|
||||
name := strings.TrimSpace(stem)
|
||||
if strings.TrimSpace(name) == "" {
|
||||
return
|
||||
}
|
||||
if stdRuntime.GOOS == "windows" {
|
||||
name += ".exe"
|
||||
}
|
||||
if _, ok := seen[name]; ok {
|
||||
return
|
||||
}
|
||||
seen[name] = struct{}{}
|
||||
names = append(names, name)
|
||||
}
|
||||
|
||||
for _, stem := range optionalDriverNameStemCandidates(driverType, selectedVersion) {
|
||||
appendName(stem)
|
||||
}
|
||||
return names
|
||||
}
|
||||
|
||||
func optionalDriverExecutableBaseNames(driverType string) []string {
|
||||
return optionalDriverExecutableBaseNamesForVersion(driverType, "")
|
||||
}
|
||||
|
||||
func optionalDriverReleaseAssetNamesForVersion(driverType string, selectedVersion string) []string {
|
||||
names := make([]string, 0, 2)
|
||||
seen := make(map[string]struct{}, 2)
|
||||
appendName := func(stem string) {
|
||||
trimmedStem := strings.TrimSpace(stem)
|
||||
if trimmedStem == "" {
|
||||
return
|
||||
}
|
||||
name := fmt.Sprintf("%s-%s-%s", trimmedStem, stdRuntime.GOOS, stdRuntime.GOARCH)
|
||||
if strings.EqualFold(stdRuntime.GOOS, "windows") {
|
||||
name += ".exe"
|
||||
}
|
||||
if strings.TrimSpace(name) == "" {
|
||||
return
|
||||
}
|
||||
@@ -3239,27 +3446,14 @@ func optionalDriverExecutableBaseNames(driverType string) []string {
|
||||
names = append(names, name)
|
||||
}
|
||||
|
||||
appendName(optionalDriverPublicTypeName(driverType))
|
||||
for _, stem := range optionalDriverNameStemCandidates(driverType, selectedVersion) {
|
||||
appendName(stem)
|
||||
}
|
||||
return names
|
||||
}
|
||||
|
||||
func optionalDriverReleaseAssetNames(driverType string) []string {
|
||||
names := make([]string, 0, 2)
|
||||
seen := make(map[string]struct{}, 2)
|
||||
appendName := func(typeName string) {
|
||||
name := optionalDriverReleaseAssetNameForType(typeName, stdRuntime.GOOS, stdRuntime.GOARCH)
|
||||
if strings.TrimSpace(name) == "" {
|
||||
return
|
||||
}
|
||||
if _, ok := seen[name]; ok {
|
||||
return
|
||||
}
|
||||
seen[name] = struct{}{}
|
||||
names = append(names, name)
|
||||
}
|
||||
|
||||
appendName(optionalDriverPublicTypeName(driverType))
|
||||
return names
|
||||
return optionalDriverReleaseAssetNamesForVersion(driverType, "")
|
||||
}
|
||||
|
||||
func optionalDriverExecutableBaseName(driverType string) string {
|
||||
@@ -3278,6 +3472,14 @@ func optionalDriverReleaseAssetName(driverType string) string {
|
||||
return names[0]
|
||||
}
|
||||
|
||||
func optionalDriverReleaseAssetNameForVersion(driverType string, selectedVersion string) string {
|
||||
names := optionalDriverReleaseAssetNamesForVersion(driverType, selectedVersion)
|
||||
if len(names) == 0 {
|
||||
return optionalDriverReleaseAssetNameForType("", stdRuntime.GOOS, stdRuntime.GOARCH)
|
||||
}
|
||||
return names[0]
|
||||
}
|
||||
|
||||
func optionalDriverBundlePlatformDir(goos string) string {
|
||||
switch strings.ToLower(strings.TrimSpace(goos)) {
|
||||
case "windows":
|
||||
@@ -3291,9 +3493,9 @@ func optionalDriverBundlePlatformDir(goos string) string {
|
||||
}
|
||||
}
|
||||
|
||||
func optionalDriverBundleEntryPaths(driverType string) []string {
|
||||
func optionalDriverBundleEntryPathsForVersion(driverType string, selectedVersion string) []string {
|
||||
platformDir := optionalDriverBundlePlatformDir(stdRuntime.GOOS)
|
||||
assetNames := optionalDriverReleaseAssetNames(driverType)
|
||||
assetNames := optionalDriverReleaseAssetNamesForVersion(driverType, selectedVersion)
|
||||
result := make([]string, 0, len(assetNames))
|
||||
seen := make(map[string]struct{}, len(assetNames))
|
||||
for _, assetName := range assetNames {
|
||||
@@ -3307,14 +3509,22 @@ func optionalDriverBundleEntryPaths(driverType string) []string {
|
||||
return result
|
||||
}
|
||||
|
||||
func optionalDriverBundleEntryPath(driverType string) string {
|
||||
paths := optionalDriverBundleEntryPaths(driverType)
|
||||
func optionalDriverBundleEntryPaths(driverType string) []string {
|
||||
return optionalDriverBundleEntryPathsForVersion(driverType, "")
|
||||
}
|
||||
|
||||
func optionalDriverBundleEntryPathForVersion(driverType string, selectedVersion string) string {
|
||||
paths := optionalDriverBundleEntryPathsForVersion(driverType, selectedVersion)
|
||||
if len(paths) == 0 {
|
||||
return filepath.ToSlash(filepath.Join(optionalDriverBundlePlatformDir(stdRuntime.GOOS), optionalDriverReleaseAssetName(driverType)))
|
||||
return filepath.ToSlash(filepath.Join(optionalDriverBundlePlatformDir(stdRuntime.GOOS), optionalDriverReleaseAssetNameForVersion(driverType, selectedVersion)))
|
||||
}
|
||||
return paths[0]
|
||||
}
|
||||
|
||||
func optionalDriverBundleEntryPath(driverType string) string {
|
||||
return optionalDriverBundleEntryPathForVersion(driverType, "")
|
||||
}
|
||||
|
||||
func resolveOptionalDriverAssetSize(sizeByAsset map[string]int64, driverType string) int64 {
|
||||
if len(sizeByAsset) == 0 {
|
||||
return 0
|
||||
@@ -3328,6 +3538,19 @@ func resolveOptionalDriverAssetSize(sizeByAsset map[string]int64, driverType str
|
||||
return 0
|
||||
}
|
||||
|
||||
func resolveOptionalDriverAssetSizeForVersion(sizeByAsset map[string]int64, driverType string, version string) int64 {
|
||||
if len(sizeByAsset) == 0 {
|
||||
return 0
|
||||
}
|
||||
for _, assetName := range optionalDriverReleaseAssetNamesForVersion(driverType, version) {
|
||||
sizeBytes := sizeByAsset[assetName]
|
||||
if sizeBytes > 0 {
|
||||
return sizeBytes
|
||||
}
|
||||
}
|
||||
return 0
|
||||
}
|
||||
|
||||
func resolveOptionalDriverBundleDownloadURLs() []string {
|
||||
candidates := make([]string, 0, 2)
|
||||
seen := make(map[string]struct{}, 2)
|
||||
@@ -3351,7 +3574,7 @@ func resolveOptionalDriverBundleDownloadURLs() []string {
|
||||
return candidates
|
||||
}
|
||||
|
||||
func resolveOptionalDriverAgentDownloadURLs(definition driverDefinition, rawURL string) []string {
|
||||
func resolveOptionalDriverAgentDownloadURLs(definition driverDefinition, rawURL string, selectedVersion string) []string {
|
||||
driverType := normalizeDriverType(definition.Type)
|
||||
candidates := make([]string, 0, 3)
|
||||
seen := make(map[string]struct{}, 3)
|
||||
@@ -3373,6 +3596,9 @@ func resolveOptionalDriverAgentDownloadURLs(definition driverDefinition, rawURL
|
||||
appendURL(parsed.String())
|
||||
}
|
||||
}
|
||||
if shouldRestrictToExplicitVersionArtifact(definition, selectedVersion) {
|
||||
return candidates
|
||||
}
|
||||
|
||||
assetNames := optionalDriverReleaseAssetNames(driverType)
|
||||
currentVersion := normalizeVersion(getCurrentVersion())
|
||||
|
||||
331
internal/app/methods_driver_version_test.go
Normal file
331
internal/app/methods_driver_version_test.go
Normal file
@@ -0,0 +1,331 @@
|
||||
package app
|
||||
|
||||
import (
|
||||
"archive/zip"
|
||||
"fmt"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"runtime"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
func TestResolveVersionedDriverOptionUsesPublishedMongoV1Release(t *testing.T) {
|
||||
definition, ok := resolveDriverDefinition("mongodb")
|
||||
if !ok {
|
||||
t.Fatal("expected mongodb driver definition")
|
||||
}
|
||||
|
||||
version := "1.17.4"
|
||||
assetName := mongoVersionedReleaseAssetName(1)
|
||||
seedReleaseAssetSizeCache(t, "tag:v"+version, map[string]int64{
|
||||
assetName: 24 << 20,
|
||||
})
|
||||
chdirTemp(t)
|
||||
|
||||
gotVersion, gotURL, ok := resolveVersionedDriverOption(definition, version, "history")
|
||||
if !ok {
|
||||
t.Fatal("expected published mongodb v1 option to remain available")
|
||||
}
|
||||
if gotVersion != version {
|
||||
t.Fatalf("expected version %q, got %q", version, gotVersion)
|
||||
}
|
||||
|
||||
wantURL := fmt.Sprintf("https://github.com/%s/releases/download/v%s/%s", updateRepo, version, assetName)
|
||||
if gotURL != wantURL {
|
||||
t.Fatalf("expected published release URL %q, got %q", wantURL, gotURL)
|
||||
}
|
||||
}
|
||||
|
||||
func TestDriverVersionSupportRangeForMongoDB(t *testing.T) {
|
||||
definition, ok := resolveDriverDefinition("mongodb")
|
||||
if !ok {
|
||||
t.Fatal("expected mongodb driver definition")
|
||||
}
|
||||
|
||||
if err := validateDriverSelectedVersion(definition, "1.17.4"); err != nil {
|
||||
t.Fatalf("expected 1.17.4 to stay supported, got %v", err)
|
||||
}
|
||||
if err := validateDriverSelectedVersion(definition, "2.5.0"); err != nil {
|
||||
t.Fatalf("expected 2.5.0 to stay supported, got %v", err)
|
||||
}
|
||||
if err := validateDriverSelectedVersion(definition, "1.16.1"); err == nil {
|
||||
t.Fatal("expected 1.16.1 to be rejected by MongoDB support range")
|
||||
}
|
||||
}
|
||||
|
||||
func TestResolveVersionedDriverOptionSkipsMongoV1WithoutPublishedReleaseOrSourceBuild(t *testing.T) {
|
||||
definition, ok := resolveDriverDefinition("mongodb")
|
||||
if !ok {
|
||||
t.Fatal("expected mongodb driver definition")
|
||||
}
|
||||
|
||||
version := "1.17.4"
|
||||
seedReleaseAssetSizeCache(t, "tag:v"+version, map[string]int64{})
|
||||
chdirTemp(t)
|
||||
|
||||
_, _, ok = resolveVersionedDriverOption(definition, version, "history")
|
||||
if ok {
|
||||
t.Fatal("expected unpublished mongodb v1 option to be filtered out when source build is unavailable")
|
||||
}
|
||||
}
|
||||
|
||||
func TestResolveVersionedDriverOptionRejectsUnsupportedMongoV1Range(t *testing.T) {
|
||||
definition, ok := resolveDriverDefinition("mongodb")
|
||||
if !ok {
|
||||
t.Fatal("expected mongodb driver definition")
|
||||
}
|
||||
|
||||
seedReleaseAssetSizeCache(t, "tag:v1.16.1", map[string]int64{
|
||||
mongoVersionedReleaseAssetName(1): 24 << 20,
|
||||
})
|
||||
|
||||
_, _, ok = resolveVersionedDriverOption(definition, "1.16.1", "history")
|
||||
if ok {
|
||||
t.Fatal("expected MongoDB 1.16.1 to be hidden from the selectable version list")
|
||||
}
|
||||
}
|
||||
|
||||
func TestResolveDriverVersionPackageSizeBytesReadsMongoV1VersionedAsset(t *testing.T) {
|
||||
definition, ok := resolveDriverDefinition("mongodb")
|
||||
if !ok {
|
||||
t.Fatal("expected mongodb driver definition")
|
||||
}
|
||||
|
||||
version := "1.17.4"
|
||||
assetName := mongoVersionedReleaseAssetName(1)
|
||||
const wantSize int64 = 31 << 20
|
||||
seedReleaseAssetSizeCache(t, "tag:v"+version, map[string]int64{
|
||||
assetName: wantSize,
|
||||
})
|
||||
|
||||
got := resolveDriverVersionPackageSizeBytes(definition, driverVersionOptionItem{
|
||||
Version: version,
|
||||
Source: "history",
|
||||
})
|
||||
if got != wantSize {
|
||||
t.Fatalf("expected size %d, got %d", wantSize, got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestResolveOptionalDriverAgentDownloadURLsDoesNotFallbackForHistoricalVersion(t *testing.T) {
|
||||
definition, ok := resolveDriverDefinition("mongodb")
|
||||
if !ok {
|
||||
t.Fatal("expected mongodb driver definition")
|
||||
}
|
||||
|
||||
explicitURL := fmt.Sprintf("https://github.com/Syngnat/GoNavi/releases/download/v1.17.4/%s", mongoVersionedReleaseAssetName(1))
|
||||
urls := resolveOptionalDriverAgentDownloadURLs(
|
||||
definition,
|
||||
explicitURL,
|
||||
"1.17.4",
|
||||
)
|
||||
if len(urls) != 1 {
|
||||
t.Fatalf("expected only explicit historical URL, got %d candidates: %v", len(urls), urls)
|
||||
}
|
||||
if urls[0] != explicitURL {
|
||||
t.Fatalf("unexpected historical URL candidate: %v", urls)
|
||||
}
|
||||
}
|
||||
|
||||
func TestDownloadDriverPackageRejectsUnsupportedMongoVersion(t *testing.T) {
|
||||
app := &App{}
|
||||
|
||||
result := app.DownloadDriverPackage("mongodb", "1.16.1", "builtin://activate/mongodb?channel=history&version=1.16.1", t.TempDir())
|
||||
if result.Success {
|
||||
t.Fatal("expected unsupported MongoDB 1.16.1 install to be rejected")
|
||||
}
|
||||
if !strings.Contains(result.Message, "仅支持 1.17.x 和 2.x") {
|
||||
t.Fatalf("expected support-range error, got %q", result.Message)
|
||||
}
|
||||
}
|
||||
|
||||
func TestShouldForceSourceBuildForResolvedDownload(t *testing.T) {
|
||||
if !shouldForceSourceBuildForResolvedDownload("mongodb", "1.17.4", "builtin://activate/mongodb?channel=history&version=1.17.4") {
|
||||
t.Fatal("expected mongodb v1 builtin install to keep source build mode")
|
||||
}
|
||||
|
||||
explicitURL := fmt.Sprintf("https://github.com/%s/releases/download/v1.17.4/%s", updateRepo, mongoVersionedReleaseAssetName(1))
|
||||
if shouldForceSourceBuildForResolvedDownload("mongodb", "1.17.4", explicitURL) {
|
||||
t.Fatal("expected mongodb v1 published asset install to skip forced source build")
|
||||
}
|
||||
|
||||
if shouldForceSourceBuildForResolvedDownload("mongodb", "2.5.0", "builtin://activate/mongodb?channel=latest&version=2.5.0") {
|
||||
t.Fatal("expected mongodb v2 install not to force source build")
|
||||
}
|
||||
}
|
||||
|
||||
func TestInstallOptionalDriverAgentFromLocalPathSupportsMongoV1DirectoryImport(t *testing.T) {
|
||||
definition, ok := resolveDriverDefinition("mongodb")
|
||||
if !ok {
|
||||
t.Fatal("expected mongodb driver definition")
|
||||
}
|
||||
|
||||
packageRoot := t.TempDir()
|
||||
platformDir := filepath.Join(packageRoot, optionalDriverBundlePlatformDir(runtime.GOOS))
|
||||
if err := os.MkdirAll(platformDir, 0o755); err != nil {
|
||||
t.Fatalf("mkdir package dir failed: %v", err)
|
||||
}
|
||||
|
||||
assetName := mongoVersionedReleaseAssetName(1)
|
||||
writeSelfExecutable(t, filepath.Join(platformDir, assetName))
|
||||
|
||||
installRoot := filepath.Join(t.TempDir(), "drivers")
|
||||
meta, err := installOptionalDriverAgentFromLocalPath(definition, packageRoot, installRoot, "1.17.4")
|
||||
if err != nil {
|
||||
t.Fatalf("expected mongodb v1 directory import to succeed, got %v", err)
|
||||
}
|
||||
if meta.Version != "1.17.4" {
|
||||
t.Fatalf("expected imported version to stay 1.17.4, got %q", meta.Version)
|
||||
}
|
||||
if filepath.Base(meta.FilePath) != assetName {
|
||||
t.Fatalf("expected source file %q, got %q", assetName, meta.FilePath)
|
||||
}
|
||||
if !strings.Contains(meta.DownloadURL, assetName) {
|
||||
t.Fatalf("expected download source to reference %q, got %q", assetName, meta.DownloadURL)
|
||||
}
|
||||
if _, err := os.Stat(meta.ExecutablePath); err != nil {
|
||||
t.Fatalf("expected imported executable to exist, got %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestInstallOptionalDriverAgentFromLocalPathSupportsMongoV1ZipImport(t *testing.T) {
|
||||
definition, ok := resolveDriverDefinition("mongodb")
|
||||
if !ok {
|
||||
t.Fatal("expected mongodb driver definition")
|
||||
}
|
||||
|
||||
assetName := mongoVersionedReleaseAssetName(1)
|
||||
zipPath := filepath.Join(t.TempDir(), "mongodb-v1.zip")
|
||||
writeZipWithSelfExecutable(t, zipPath, filepath.ToSlash(filepath.Join(optionalDriverBundlePlatformDir(runtime.GOOS), assetName)))
|
||||
|
||||
installRoot := filepath.Join(t.TempDir(), "drivers")
|
||||
meta, err := installOptionalDriverAgentFromLocalPath(definition, zipPath, installRoot, "1.17.4")
|
||||
if err != nil {
|
||||
t.Fatalf("expected mongodb v1 zip import to succeed, got %v", err)
|
||||
}
|
||||
if meta.Version != "1.17.4" {
|
||||
t.Fatalf("expected imported version to stay 1.17.4, got %q", meta.Version)
|
||||
}
|
||||
if !strings.Contains(meta.DownloadURL, assetName) {
|
||||
t.Fatalf("expected zip download source to reference %q, got %q", assetName, meta.DownloadURL)
|
||||
}
|
||||
if _, err := os.Stat(meta.ExecutablePath); err != nil {
|
||||
t.Fatalf("expected imported executable to exist, got %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func seedReleaseAssetSizeCache(t *testing.T, cacheKey string, sizeByKey map[string]int64) {
|
||||
t.Helper()
|
||||
|
||||
driverReleaseSizeMu.Lock()
|
||||
original := cloneReleaseAssetSizeCache(driverReleaseSizeMap)
|
||||
driverReleaseSizeMap[cacheKey] = driverReleaseAssetSizeCacheEntry{
|
||||
LoadedAt: time.Now(),
|
||||
SizeByKey: cloneInt64Map(sizeByKey),
|
||||
}
|
||||
driverReleaseSizeMu.Unlock()
|
||||
|
||||
t.Cleanup(func() {
|
||||
driverReleaseSizeMu.Lock()
|
||||
driverReleaseSizeMap = original
|
||||
driverReleaseSizeMu.Unlock()
|
||||
})
|
||||
}
|
||||
|
||||
func cloneReleaseAssetSizeCache(src map[string]driverReleaseAssetSizeCacheEntry) map[string]driverReleaseAssetSizeCacheEntry {
|
||||
cloned := make(map[string]driverReleaseAssetSizeCacheEntry, len(src))
|
||||
for key, value := range src {
|
||||
cloned[key] = driverReleaseAssetSizeCacheEntry{
|
||||
LoadedAt: value.LoadedAt,
|
||||
SizeByKey: cloneInt64Map(value.SizeByKey),
|
||||
Err: value.Err,
|
||||
}
|
||||
}
|
||||
return cloned
|
||||
}
|
||||
|
||||
func cloneInt64Map(src map[string]int64) map[string]int64 {
|
||||
if len(src) == 0 {
|
||||
return map[string]int64{}
|
||||
}
|
||||
cloned := make(map[string]int64, len(src))
|
||||
for key, value := range src {
|
||||
cloned[key] = value
|
||||
}
|
||||
return cloned
|
||||
}
|
||||
|
||||
func chdirTemp(t *testing.T) {
|
||||
t.Helper()
|
||||
|
||||
wd, err := os.Getwd()
|
||||
if err != nil {
|
||||
t.Fatalf("getwd failed: %v", err)
|
||||
}
|
||||
tempDir := t.TempDir()
|
||||
if err := os.Chdir(tempDir); err != nil {
|
||||
t.Fatalf("chdir temp failed: %v", err)
|
||||
}
|
||||
t.Cleanup(func() {
|
||||
if err := os.Chdir(wd); err != nil {
|
||||
t.Fatalf("restore cwd failed: %v", err)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func mongoVersionedReleaseAssetName(major int) string {
|
||||
name := fmt.Sprintf("mongodb-driver-agent-v%d-%s-%s", major, runtime.GOOS, runtime.GOARCH)
|
||||
if runtime.GOOS == "windows" {
|
||||
return name + ".exe"
|
||||
}
|
||||
return name
|
||||
}
|
||||
|
||||
func writeSelfExecutable(t *testing.T, targetPath string) {
|
||||
t.Helper()
|
||||
|
||||
selfPath, err := os.Executable()
|
||||
if err != nil {
|
||||
t.Fatalf("executable path failed: %v", err)
|
||||
}
|
||||
content, err := os.ReadFile(selfPath)
|
||||
if err != nil {
|
||||
t.Fatalf("read self executable failed: %v", err)
|
||||
}
|
||||
if err := os.WriteFile(targetPath, content, 0o755); err != nil {
|
||||
t.Fatalf("write executable failed: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func writeZipWithSelfExecutable(t *testing.T, zipPath string, entryName string) {
|
||||
t.Helper()
|
||||
|
||||
selfPath, err := os.Executable()
|
||||
if err != nil {
|
||||
t.Fatalf("executable path failed: %v", err)
|
||||
}
|
||||
content, err := os.ReadFile(selfPath)
|
||||
if err != nil {
|
||||
t.Fatalf("read self executable failed: %v", err)
|
||||
}
|
||||
|
||||
file, err := os.Create(zipPath)
|
||||
if err != nil {
|
||||
t.Fatalf("create zip failed: %v", err)
|
||||
}
|
||||
defer file.Close()
|
||||
|
||||
writer := zip.NewWriter(file)
|
||||
entry, err := writer.Create(entryName)
|
||||
if err != nil {
|
||||
t.Fatalf("create zip entry failed: %v", err)
|
||||
}
|
||||
if _, err := entry.Write(content); err != nil {
|
||||
t.Fatalf("write zip entry failed: %v", err)
|
||||
}
|
||||
if err := writer.Close(); err != nil {
|
||||
t.Fatalf("close zip writer failed: %v", err)
|
||||
}
|
||||
}
|
||||
44
internal/app/methods_saved_connections.go
Normal file
44
internal/app/methods_saved_connections.go
Normal file
@@ -0,0 +1,44 @@
|
||||
package app
|
||||
|
||||
import "GoNavi-Wails/internal/connection"
|
||||
|
||||
func (a *App) savedConnectionRepository() *savedConnectionRepository {
|
||||
return newSavedConnectionRepository(a.configDir, a.secretStore)
|
||||
}
|
||||
|
||||
func (a *App) GetSavedConnections() ([]connection.SavedConnectionView, error) {
|
||||
return a.savedConnectionRepository().List()
|
||||
}
|
||||
|
||||
func (a *App) SaveConnection(input connection.SavedConnectionInput) (connection.SavedConnectionView, error) {
|
||||
return a.savedConnectionRepository().Save(input)
|
||||
}
|
||||
|
||||
func (a *App) DeleteConnection(id string) error {
|
||||
return a.savedConnectionRepository().Delete(id)
|
||||
}
|
||||
|
||||
func (a *App) DuplicateConnection(id string) (connection.SavedConnectionView, error) {
|
||||
return a.savedConnectionRepository().Duplicate(id)
|
||||
}
|
||||
|
||||
func (a *App) ImportLegacyConnections(items []connection.LegacySavedConnection) ([]connection.SavedConnectionView, error) {
|
||||
result := make([]connection.SavedConnectionView, 0, len(items))
|
||||
repo := a.savedConnectionRepository()
|
||||
for _, item := range items {
|
||||
view, err := repo.Save(connection.SavedConnectionInput(item))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
result = append(result, view)
|
||||
}
|
||||
return result, nil
|
||||
}
|
||||
|
||||
func (a *App) SaveGlobalProxy(input connection.SaveGlobalProxyInput) (connection.GlobalProxyView, error) {
|
||||
return a.saveGlobalProxy(input)
|
||||
}
|
||||
|
||||
func (a *App) ImportLegacyGlobalProxy(input connection.LegacyGlobalProxyInput) (connection.GlobalProxyView, error) {
|
||||
return a.saveGlobalProxy(connection.SaveGlobalProxyInput(input))
|
||||
}
|
||||
187
internal/app/methods_saved_connections_test.go
Normal file
187
internal/app/methods_saved_connections_test.go
Normal file
@@ -0,0 +1,187 @@
|
||||
package app
|
||||
|
||||
import (
|
||||
"reflect"
|
||||
"testing"
|
||||
|
||||
"GoNavi-Wails/internal/connection"
|
||||
)
|
||||
|
||||
func TestSaveConnectionMethodReturnsSecretlessView(t *testing.T) {
|
||||
app := NewAppWithSecretStore(newFakeAppSecretStore())
|
||||
app.configDir = t.TempDir()
|
||||
|
||||
result, err := app.SaveConnection(connection.SavedConnectionInput{
|
||||
ID: "conn-1",
|
||||
Name: "Primary",
|
||||
IncludeDatabases: []string{"appdb"},
|
||||
IconType: "postgres",
|
||||
IconColor: "#1677ff",
|
||||
Config: connection.ConnectionConfig{
|
||||
ID: "conn-1",
|
||||
Type: "postgres",
|
||||
Host: "db.local",
|
||||
Port: 5432,
|
||||
User: "postgres",
|
||||
Password: "postgres-secret",
|
||||
},
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if result.Config.Password != "" {
|
||||
t.Fatal("SaveConnection must not return plaintext password")
|
||||
}
|
||||
if !result.HasPrimaryPassword {
|
||||
t.Fatal("expected HasPrimaryPassword=true")
|
||||
}
|
||||
if !reflect.DeepEqual(result.IncludeDatabases, []string{"appdb"}) {
|
||||
t.Fatalf("expected include databases to be preserved, got %#v", result.IncludeDatabases)
|
||||
}
|
||||
if result.IconType != "postgres" || result.IconColor != "#1677ff" {
|
||||
t.Fatalf("expected icon metadata to be preserved, got type=%q color=%q", result.IconType, result.IconColor)
|
||||
}
|
||||
}
|
||||
|
||||
func TestSaveConnectionClearsRequestedSecretFields(t *testing.T) {
|
||||
app := NewAppWithSecretStore(newFakeAppSecretStore())
|
||||
app.configDir = t.TempDir()
|
||||
|
||||
_, err := app.SaveConnection(connection.SavedConnectionInput{
|
||||
ID: "conn-1",
|
||||
Name: "Primary",
|
||||
Config: connection.ConnectionConfig{
|
||||
ID: "conn-1",
|
||||
Type: "postgres",
|
||||
Host: "db.local",
|
||||
Port: 5432,
|
||||
User: "postgres",
|
||||
Password: "postgres-secret",
|
||||
UseSSH: true,
|
||||
SSH: connection.SSHConfig{
|
||||
Host: "jump.local",
|
||||
Port: 22,
|
||||
User: "ops",
|
||||
Password: "ssh-secret",
|
||||
},
|
||||
},
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
view, err := app.SaveConnection(connection.SavedConnectionInput{
|
||||
ID: "conn-1",
|
||||
Name: "Primary",
|
||||
Config: connection.ConnectionConfig{
|
||||
ID: "conn-1",
|
||||
Type: "postgres",
|
||||
Host: "db.local",
|
||||
Port: 5432,
|
||||
User: "postgres",
|
||||
UseSSH: true,
|
||||
SSH: connection.SSHConfig{
|
||||
Host: "jump.local",
|
||||
Port: 22,
|
||||
User: "ops",
|
||||
},
|
||||
},
|
||||
ClearPrimaryPassword: true,
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if view.HasPrimaryPassword {
|
||||
t.Fatal("expected HasPrimaryPassword=false after clearing")
|
||||
}
|
||||
if !view.HasSSHPassword {
|
||||
t.Fatal("expected SSH password to stay stored")
|
||||
}
|
||||
|
||||
resolved, err := app.resolveConnectionSecrets(view.Config)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if resolved.Password != "" {
|
||||
t.Fatalf("expected cleared primary password, got %q", resolved.Password)
|
||||
}
|
||||
if resolved.SSH.Password != "ssh-secret" {
|
||||
t.Fatalf("expected SSH password to stay stored, got %q", resolved.SSH.Password)
|
||||
}
|
||||
}
|
||||
|
||||
func TestDuplicateConnectionClonesSecretBundle(t *testing.T) {
|
||||
app := NewAppWithSecretStore(newFakeAppSecretStore())
|
||||
app.configDir = t.TempDir()
|
||||
|
||||
_, err := app.SaveConnection(connection.SavedConnectionInput{
|
||||
ID: "conn-1",
|
||||
Name: "Primary",
|
||||
IncludeDatabases: []string{"appdb"},
|
||||
IncludeRedisDatabases: []int{0, 1},
|
||||
IconType: "postgres",
|
||||
IconColor: "#1677ff",
|
||||
Config: connection.ConnectionConfig{
|
||||
ID: "conn-1",
|
||||
Type: "postgres",
|
||||
Host: "db.local",
|
||||
Port: 5432,
|
||||
User: "postgres",
|
||||
Password: "postgres-secret",
|
||||
},
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
duplicate, err := app.DuplicateConnection("conn-1")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if duplicate.ID == "conn-1" {
|
||||
t.Fatal("duplicate should have a new id")
|
||||
}
|
||||
if duplicate.Name != "Primary - 副本" {
|
||||
t.Fatalf("expected duplicate name to keep existing UX, got %q", duplicate.Name)
|
||||
}
|
||||
if !reflect.DeepEqual(duplicate.IncludeDatabases, []string{"appdb"}) {
|
||||
t.Fatalf("expected include databases to be cloned, got %#v", duplicate.IncludeDatabases)
|
||||
}
|
||||
if !reflect.DeepEqual(duplicate.IncludeRedisDatabases, []int{0, 1}) {
|
||||
t.Fatalf("expected redis include databases to be cloned, got %#v", duplicate.IncludeRedisDatabases)
|
||||
}
|
||||
if duplicate.IconType != "postgres" || duplicate.IconColor != "#1677ff" {
|
||||
t.Fatalf("expected icon metadata to be cloned, got type=%q color=%q", duplicate.IconType, duplicate.IconColor)
|
||||
}
|
||||
|
||||
resolved, err := app.resolveConnectionSecrets(duplicate.Config)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if resolved.Password != "postgres-secret" {
|
||||
t.Fatalf("expected duplicated secret bundle, got %q", resolved.Password)
|
||||
}
|
||||
}
|
||||
|
||||
func TestSaveGlobalProxyReturnsSecretlessView(t *testing.T) {
|
||||
app := NewAppWithSecretStore(newFakeAppSecretStore())
|
||||
app.configDir = t.TempDir()
|
||||
|
||||
view, err := app.SaveGlobalProxy(connection.SaveGlobalProxyInput{
|
||||
Enabled: true,
|
||||
Type: "http",
|
||||
Host: "127.0.0.1",
|
||||
Port: 8080,
|
||||
User: "ops",
|
||||
Password: "proxy-secret",
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if view.Password != "" {
|
||||
t.Fatal("global proxy view must not expose plaintext password")
|
||||
}
|
||||
if !view.HasPassword {
|
||||
t.Fatal("expected hasPassword=true")
|
||||
}
|
||||
}
|
||||
476
internal/app/saved_connections.go
Normal file
476
internal/app/saved_connections.go
Normal file
@@ -0,0 +1,476 @@
|
||||
package app
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
|
||||
"GoNavi-Wails/internal/connection"
|
||||
"GoNavi-Wails/internal/secretstore"
|
||||
"github.com/google/uuid"
|
||||
)
|
||||
|
||||
const (
|
||||
savedConnectionsFileName = "connections.json"
|
||||
savedConnectionSecretKind = "connection"
|
||||
)
|
||||
|
||||
type connectionSecretBundle struct {
|
||||
Password string `json:"password,omitempty"`
|
||||
SSHPassword string `json:"sshPassword,omitempty"`
|
||||
ProxyPassword string `json:"proxyPassword,omitempty"`
|
||||
HTTPTunnelPassword string `json:"httpTunnelPassword,omitempty"`
|
||||
MySQLReplicaPassword string `json:"mysqlReplicaPassword,omitempty"`
|
||||
MongoReplicaPassword string `json:"mongoReplicaPassword,omitempty"`
|
||||
OpaqueURI string `json:"opaqueURI,omitempty"`
|
||||
OpaqueDSN string `json:"opaqueDSN,omitempty"`
|
||||
}
|
||||
|
||||
type savedConnectionsFile struct {
|
||||
Connections []connection.SavedConnectionView `json:"connections"`
|
||||
}
|
||||
|
||||
type savedConnectionRepository struct {
|
||||
configDir string
|
||||
secretStore secretstore.SecretStore
|
||||
}
|
||||
|
||||
func resolveAppConfigDir() string {
|
||||
homeDir, err := os.UserHomeDir()
|
||||
if err != nil || strings.TrimSpace(homeDir) == "" {
|
||||
return "."
|
||||
}
|
||||
return filepath.Join(homeDir, ".gonavi")
|
||||
}
|
||||
|
||||
func newSavedConnectionRepository(configDir string, store secretstore.SecretStore) *savedConnectionRepository {
|
||||
if strings.TrimSpace(configDir) == "" {
|
||||
configDir = resolveAppConfigDir()
|
||||
}
|
||||
if store == nil {
|
||||
store = secretstore.NewUnavailableStore("secret store unavailable")
|
||||
}
|
||||
return &savedConnectionRepository{configDir: configDir, secretStore: store}
|
||||
}
|
||||
|
||||
func (b connectionSecretBundle) hasAny() bool {
|
||||
return strings.TrimSpace(b.Password) != "" ||
|
||||
strings.TrimSpace(b.SSHPassword) != "" ||
|
||||
strings.TrimSpace(b.ProxyPassword) != "" ||
|
||||
strings.TrimSpace(b.HTTPTunnelPassword) != "" ||
|
||||
strings.TrimSpace(b.MySQLReplicaPassword) != "" ||
|
||||
strings.TrimSpace(b.MongoReplicaPassword) != "" ||
|
||||
strings.TrimSpace(b.OpaqueURI) != "" ||
|
||||
strings.TrimSpace(b.OpaqueDSN) != ""
|
||||
}
|
||||
|
||||
func mergeConnectionSecretBundles(base, overlay connectionSecretBundle) connectionSecretBundle {
|
||||
merged := base
|
||||
if strings.TrimSpace(overlay.Password) != "" {
|
||||
merged.Password = overlay.Password
|
||||
}
|
||||
if strings.TrimSpace(overlay.SSHPassword) != "" {
|
||||
merged.SSHPassword = overlay.SSHPassword
|
||||
}
|
||||
if strings.TrimSpace(overlay.ProxyPassword) != "" {
|
||||
merged.ProxyPassword = overlay.ProxyPassword
|
||||
}
|
||||
if strings.TrimSpace(overlay.HTTPTunnelPassword) != "" {
|
||||
merged.HTTPTunnelPassword = overlay.HTTPTunnelPassword
|
||||
}
|
||||
if strings.TrimSpace(overlay.MySQLReplicaPassword) != "" {
|
||||
merged.MySQLReplicaPassword = overlay.MySQLReplicaPassword
|
||||
}
|
||||
if strings.TrimSpace(overlay.MongoReplicaPassword) != "" {
|
||||
merged.MongoReplicaPassword = overlay.MongoReplicaPassword
|
||||
}
|
||||
if strings.TrimSpace(overlay.OpaqueURI) != "" {
|
||||
merged.OpaqueURI = overlay.OpaqueURI
|
||||
}
|
||||
if strings.TrimSpace(overlay.OpaqueDSN) != "" {
|
||||
merged.OpaqueDSN = overlay.OpaqueDSN
|
||||
}
|
||||
return merged
|
||||
}
|
||||
|
||||
func applyConnectionSecretClears(bundle connectionSecretBundle, input connection.SavedConnectionInput) connectionSecretBundle {
|
||||
cleared := bundle
|
||||
if input.ClearPrimaryPassword {
|
||||
cleared.Password = ""
|
||||
}
|
||||
if input.ClearSSHPassword {
|
||||
cleared.SSHPassword = ""
|
||||
}
|
||||
if input.ClearProxyPassword {
|
||||
cleared.ProxyPassword = ""
|
||||
}
|
||||
if input.ClearHTTPTunnelPassword {
|
||||
cleared.HTTPTunnelPassword = ""
|
||||
}
|
||||
if input.ClearMySQLReplicaPassword {
|
||||
cleared.MySQLReplicaPassword = ""
|
||||
}
|
||||
if input.ClearMongoReplicaPassword {
|
||||
cleared.MongoReplicaPassword = ""
|
||||
}
|
||||
if input.ClearOpaqueURI {
|
||||
cleared.OpaqueURI = ""
|
||||
}
|
||||
if input.ClearOpaqueDSN {
|
||||
cleared.OpaqueDSN = ""
|
||||
}
|
||||
return cleared
|
||||
}
|
||||
|
||||
func cloneStringSlice(input []string) []string {
|
||||
if len(input) == 0 {
|
||||
return nil
|
||||
}
|
||||
cloned := make([]string, len(input))
|
||||
copy(cloned, input)
|
||||
return cloned
|
||||
}
|
||||
|
||||
func cloneIntSlice(input []int) []int {
|
||||
if len(input) == 0 {
|
||||
return nil
|
||||
}
|
||||
cloned := make([]int, len(input))
|
||||
copy(cloned, input)
|
||||
return cloned
|
||||
}
|
||||
|
||||
func splitConnectionSecrets(input connection.SavedConnectionInput) (connection.SavedConnectionView, connectionSecretBundle) {
|
||||
id := strings.TrimSpace(input.ID)
|
||||
if id == "" {
|
||||
id = strings.TrimSpace(input.Config.ID)
|
||||
}
|
||||
|
||||
meta := input.Config
|
||||
meta.ID = id
|
||||
meta.SavePassword = false
|
||||
|
||||
bundle := connectionSecretBundle{}
|
||||
if strings.TrimSpace(meta.Password) != "" {
|
||||
bundle.Password = meta.Password
|
||||
meta.Password = ""
|
||||
}
|
||||
if strings.TrimSpace(meta.SSH.Password) != "" {
|
||||
bundle.SSHPassword = meta.SSH.Password
|
||||
meta.SSH.Password = ""
|
||||
}
|
||||
if strings.TrimSpace(meta.Proxy.Password) != "" {
|
||||
bundle.ProxyPassword = meta.Proxy.Password
|
||||
meta.Proxy.Password = ""
|
||||
}
|
||||
if strings.TrimSpace(meta.HTTPTunnel.Password) != "" {
|
||||
bundle.HTTPTunnelPassword = meta.HTTPTunnel.Password
|
||||
meta.HTTPTunnel.Password = ""
|
||||
}
|
||||
if strings.TrimSpace(meta.MySQLReplicaPassword) != "" {
|
||||
bundle.MySQLReplicaPassword = meta.MySQLReplicaPassword
|
||||
meta.MySQLReplicaPassword = ""
|
||||
}
|
||||
if strings.TrimSpace(meta.MongoReplicaPassword) != "" {
|
||||
bundle.MongoReplicaPassword = meta.MongoReplicaPassword
|
||||
meta.MongoReplicaPassword = ""
|
||||
}
|
||||
if strings.TrimSpace(meta.URI) != "" {
|
||||
bundle.OpaqueURI = meta.URI
|
||||
meta.URI = ""
|
||||
}
|
||||
if strings.TrimSpace(meta.DSN) != "" {
|
||||
bundle.OpaqueDSN = meta.DSN
|
||||
meta.DSN = ""
|
||||
}
|
||||
|
||||
view := connection.SavedConnectionView{
|
||||
ID: id,
|
||||
Name: strings.TrimSpace(input.Name),
|
||||
Config: meta,
|
||||
IncludeDatabases: cloneStringSlice(input.IncludeDatabases),
|
||||
IncludeRedisDatabases: cloneIntSlice(input.IncludeRedisDatabases),
|
||||
IconType: strings.TrimSpace(input.IconType),
|
||||
IconColor: strings.TrimSpace(input.IconColor),
|
||||
HasPrimaryPassword: strings.TrimSpace(bundle.Password) != "",
|
||||
HasSSHPassword: strings.TrimSpace(bundle.SSHPassword) != "",
|
||||
HasProxyPassword: strings.TrimSpace(bundle.ProxyPassword) != "",
|
||||
HasHTTPTunnelPassword: strings.TrimSpace(bundle.HTTPTunnelPassword) != "",
|
||||
HasMySQLReplicaPassword: strings.TrimSpace(bundle.MySQLReplicaPassword) != "",
|
||||
HasMongoReplicaPassword: strings.TrimSpace(bundle.MongoReplicaPassword) != "",
|
||||
HasOpaqueURI: strings.TrimSpace(bundle.OpaqueURI) != "",
|
||||
HasOpaqueDSN: strings.TrimSpace(bundle.OpaqueDSN) != "",
|
||||
}
|
||||
return view, bundle
|
||||
}
|
||||
|
||||
func (r *savedConnectionRepository) connectionsPath() string {
|
||||
return filepath.Join(r.configDir, savedConnectionsFileName)
|
||||
}
|
||||
|
||||
func (r *savedConnectionRepository) load() ([]connection.SavedConnectionView, error) {
|
||||
data, err := os.ReadFile(r.connectionsPath())
|
||||
if err != nil {
|
||||
if os.IsNotExist(err) {
|
||||
return []connection.SavedConnectionView{}, nil
|
||||
}
|
||||
return nil, err
|
||||
}
|
||||
|
||||
var file savedConnectionsFile
|
||||
if err := json.Unmarshal(data, &file); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if file.Connections == nil {
|
||||
return []connection.SavedConnectionView{}, nil
|
||||
}
|
||||
return file.Connections, nil
|
||||
}
|
||||
|
||||
func (r *savedConnectionRepository) saveAll(connections []connection.SavedConnectionView) error {
|
||||
if err := os.MkdirAll(r.configDir, 0o755); err != nil {
|
||||
return err
|
||||
}
|
||||
payload, err := json.MarshalIndent(savedConnectionsFile{Connections: connections}, "", " ")
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return os.WriteFile(r.connectionsPath(), payload, 0o644)
|
||||
}
|
||||
|
||||
func (r *savedConnectionRepository) Save(input connection.SavedConnectionInput) (connection.SavedConnectionView, error) {
|
||||
if strings.TrimSpace(input.ID) == "" && strings.TrimSpace(input.Config.ID) == "" {
|
||||
input.ID = "conn-" + uuid.New().String()[:8]
|
||||
}
|
||||
if strings.TrimSpace(input.ID) == "" {
|
||||
input.ID = strings.TrimSpace(input.Config.ID)
|
||||
}
|
||||
input.Config.ID = input.ID
|
||||
|
||||
connections, err := r.load()
|
||||
if err != nil {
|
||||
return connection.SavedConnectionView{}, err
|
||||
}
|
||||
|
||||
view, bundle := splitConnectionSecrets(input)
|
||||
index := -1
|
||||
var existing connection.SavedConnectionView
|
||||
for i, item := range connections {
|
||||
if item.ID == view.ID {
|
||||
index = i
|
||||
existing = item
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
mergedBundle := bundle
|
||||
if index >= 0 && savedConnectionViewHasSecrets(existing) {
|
||||
existingBundle, bundleErr := r.loadSecretBundle(existing)
|
||||
if bundleErr != nil {
|
||||
return connection.SavedConnectionView{}, bundleErr
|
||||
}
|
||||
mergedBundle = mergeConnectionSecretBundles(existingBundle, bundle)
|
||||
view.SecretRef = existing.SecretRef
|
||||
}
|
||||
mergedBundle = applyConnectionSecretClears(mergedBundle, input)
|
||||
|
||||
if mergedBundle.hasAny() {
|
||||
ref, storeErr := r.storeSecretBundle(view.ID, view.SecretRef, mergedBundle)
|
||||
if storeErr != nil {
|
||||
return connection.SavedConnectionView{}, storeErr
|
||||
}
|
||||
view.SecretRef = ref
|
||||
applyConnectionBundleFlags(&view, mergedBundle)
|
||||
} else {
|
||||
if index >= 0 && strings.TrimSpace(existing.SecretRef) != "" {
|
||||
if deleteErr := r.secretStore.Delete(existing.SecretRef); deleteErr != nil {
|
||||
return connection.SavedConnectionView{}, deleteErr
|
||||
}
|
||||
}
|
||||
view.SecretRef = ""
|
||||
applyConnectionBundleFlags(&view, connectionSecretBundle{})
|
||||
}
|
||||
|
||||
if index >= 0 {
|
||||
connections[index] = view
|
||||
} else {
|
||||
connections = append(connections, view)
|
||||
}
|
||||
if err := r.saveAll(connections); err != nil {
|
||||
return connection.SavedConnectionView{}, err
|
||||
}
|
||||
return view, nil
|
||||
}
|
||||
|
||||
func (r *savedConnectionRepository) Find(id string) (connection.SavedConnectionView, error) {
|
||||
connections, err := r.load()
|
||||
if err != nil {
|
||||
return connection.SavedConnectionView{}, err
|
||||
}
|
||||
for _, item := range connections {
|
||||
if item.ID == strings.TrimSpace(id) {
|
||||
return item, nil
|
||||
}
|
||||
}
|
||||
return connection.SavedConnectionView{}, fmt.Errorf("saved connection not found: %s", id)
|
||||
}
|
||||
|
||||
func (r *savedConnectionRepository) storeSecretBundle(id string, existingRef string, bundle connectionSecretBundle) (string, error) {
|
||||
if r.secretStore == nil {
|
||||
return "", fmt.Errorf("secret store unavailable")
|
||||
}
|
||||
if err := r.secretStore.HealthCheck(); err != nil {
|
||||
return "", err
|
||||
}
|
||||
ref := strings.TrimSpace(existingRef)
|
||||
if ref == "" {
|
||||
var err error
|
||||
ref, err = secretstore.BuildRef(savedConnectionSecretKind, id)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
}
|
||||
payload, err := json.Marshal(bundle)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
if err := r.secretStore.Put(ref, payload); err != nil {
|
||||
return "", err
|
||||
}
|
||||
return ref, nil
|
||||
}
|
||||
|
||||
func (r *savedConnectionRepository) loadSecretBundle(view connection.SavedConnectionView) (connectionSecretBundle, error) {
|
||||
if !savedConnectionViewHasSecrets(view) {
|
||||
return connectionSecretBundle{}, nil
|
||||
}
|
||||
if r.secretStore == nil {
|
||||
return connectionSecretBundle{}, fmt.Errorf("secret store unavailable")
|
||||
}
|
||||
ref := strings.TrimSpace(view.SecretRef)
|
||||
if ref == "" {
|
||||
var err error
|
||||
ref, err = secretstore.BuildRef(savedConnectionSecretKind, view.ID)
|
||||
if err != nil {
|
||||
return connectionSecretBundle{}, err
|
||||
}
|
||||
}
|
||||
payload, err := r.secretStore.Get(ref)
|
||||
if err != nil {
|
||||
return connectionSecretBundle{}, err
|
||||
}
|
||||
var bundle connectionSecretBundle
|
||||
if err := json.Unmarshal(payload, &bundle); err != nil {
|
||||
return connectionSecretBundle{}, err
|
||||
}
|
||||
return bundle, nil
|
||||
}
|
||||
|
||||
func savedConnectionViewHasSecrets(view connection.SavedConnectionView) bool {
|
||||
return view.HasPrimaryPassword || view.HasSSHPassword || view.HasProxyPassword || view.HasHTTPTunnelPassword ||
|
||||
view.HasMySQLReplicaPassword || view.HasMongoReplicaPassword || view.HasOpaqueURI || view.HasOpaqueDSN
|
||||
}
|
||||
|
||||
func applyConnectionBundleFlags(view *connection.SavedConnectionView, bundle connectionSecretBundle) {
|
||||
view.HasPrimaryPassword = strings.TrimSpace(bundle.Password) != ""
|
||||
view.HasSSHPassword = strings.TrimSpace(bundle.SSHPassword) != ""
|
||||
view.HasProxyPassword = strings.TrimSpace(bundle.ProxyPassword) != ""
|
||||
view.HasHTTPTunnelPassword = strings.TrimSpace(bundle.HTTPTunnelPassword) != ""
|
||||
view.HasMySQLReplicaPassword = strings.TrimSpace(bundle.MySQLReplicaPassword) != ""
|
||||
view.HasMongoReplicaPassword = strings.TrimSpace(bundle.MongoReplicaPassword) != ""
|
||||
view.HasOpaqueURI = strings.TrimSpace(bundle.OpaqueURI) != ""
|
||||
view.HasOpaqueDSN = strings.TrimSpace(bundle.OpaqueDSN) != ""
|
||||
}
|
||||
|
||||
func buildDuplicateConnectionName(baseName string, existing []connection.SavedConnectionView) string {
|
||||
trimmedBaseName := strings.TrimSpace(baseName)
|
||||
if trimmedBaseName == "" {
|
||||
trimmedBaseName = "连接"
|
||||
}
|
||||
suffix := " - 副本"
|
||||
usedNames := make(map[string]struct{}, len(existing))
|
||||
for _, item := range existing {
|
||||
usedNames[strings.TrimSpace(item.Name)] = struct{}{}
|
||||
}
|
||||
candidate := trimmedBaseName + suffix
|
||||
counter := 2
|
||||
for {
|
||||
if _, exists := usedNames[candidate]; !exists {
|
||||
return candidate
|
||||
}
|
||||
candidate = fmt.Sprintf("%s%s %d", trimmedBaseName, suffix, counter)
|
||||
counter++
|
||||
}
|
||||
}
|
||||
|
||||
func (r *savedConnectionRepository) List() ([]connection.SavedConnectionView, error) {
|
||||
return r.load()
|
||||
}
|
||||
|
||||
func (r *savedConnectionRepository) Delete(id string) error {
|
||||
connections, err := r.load()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
filtered := make([]connection.SavedConnectionView, 0, len(connections))
|
||||
for _, item := range connections {
|
||||
if item.ID == strings.TrimSpace(id) {
|
||||
if strings.TrimSpace(item.SecretRef) != "" && r.secretStore != nil {
|
||||
if deleteErr := r.secretStore.Delete(item.SecretRef); deleteErr != nil {
|
||||
return deleteErr
|
||||
}
|
||||
}
|
||||
continue
|
||||
}
|
||||
filtered = append(filtered, item)
|
||||
}
|
||||
return r.saveAll(filtered)
|
||||
}
|
||||
|
||||
func (r *savedConnectionRepository) Duplicate(id string) (connection.SavedConnectionView, error) {
|
||||
connections, err := r.load()
|
||||
if err != nil {
|
||||
return connection.SavedConnectionView{}, err
|
||||
}
|
||||
|
||||
index := -1
|
||||
for i, item := range connections {
|
||||
if item.ID == strings.TrimSpace(id) {
|
||||
index = i
|
||||
break
|
||||
}
|
||||
}
|
||||
if index < 0 {
|
||||
return connection.SavedConnectionView{}, fmt.Errorf("saved connection not found: %s", id)
|
||||
}
|
||||
|
||||
original := connections[index]
|
||||
duplicate := original
|
||||
duplicate.ID = "conn-" + uuid.New().String()[:8]
|
||||
duplicate.Config.ID = duplicate.ID
|
||||
duplicate.Name = buildDuplicateConnectionName(original.Name, connections)
|
||||
|
||||
bundle, err := r.loadSecretBundle(original)
|
||||
if err != nil {
|
||||
return connection.SavedConnectionView{}, err
|
||||
}
|
||||
if bundle.hasAny() {
|
||||
ref, storeErr := r.storeSecretBundle(duplicate.ID, "", bundle)
|
||||
if storeErr != nil {
|
||||
return connection.SavedConnectionView{}, storeErr
|
||||
}
|
||||
duplicate.SecretRef = ref
|
||||
applyConnectionBundleFlags(&duplicate, bundle)
|
||||
} else {
|
||||
duplicate.SecretRef = ""
|
||||
applyConnectionBundleFlags(&duplicate, connectionSecretBundle{})
|
||||
}
|
||||
|
||||
connections = append(connections, duplicate)
|
||||
if err := r.saveAll(connections); err != nil {
|
||||
return connection.SavedConnectionView{}, err
|
||||
}
|
||||
return duplicate, nil
|
||||
}
|
||||
72
internal/app/saved_connections_test.go
Normal file
72
internal/app/saved_connections_test.go
Normal file
@@ -0,0 +1,72 @@
|
||||
package app
|
||||
|
||||
import (
|
||||
"os"
|
||||
"testing"
|
||||
|
||||
"GoNavi-Wails/internal/connection"
|
||||
"GoNavi-Wails/internal/secretstore"
|
||||
)
|
||||
|
||||
func TestSplitConnectionSecretsStripsPasswordsAndOpaqueDSN(t *testing.T) {
|
||||
input := connection.SavedConnectionInput{
|
||||
ID: "conn-1",
|
||||
Name: "Primary",
|
||||
Config: connection.ConnectionConfig{
|
||||
ID: "conn-1",
|
||||
Type: "postgres",
|
||||
Host: "db.local",
|
||||
Password: "postgres-secret",
|
||||
DSN: "postgres://user:pass@db.local/app",
|
||||
},
|
||||
}
|
||||
|
||||
view, bundle := splitConnectionSecrets(input)
|
||||
if view.Config.Password != "" {
|
||||
t.Fatal("metadata must not keep password")
|
||||
}
|
||||
if bundle.Password != "postgres-secret" {
|
||||
t.Fatal("bundle should keep primary password")
|
||||
}
|
||||
if bundle.OpaqueDSN == "" {
|
||||
t.Fatal("opaque DSN should be stored as secret")
|
||||
}
|
||||
if !view.HasPrimaryPassword {
|
||||
t.Fatal("expected view to report primary password")
|
||||
}
|
||||
if !view.HasOpaqueDSN {
|
||||
t.Fatal("expected view to report opaque DSN")
|
||||
}
|
||||
}
|
||||
|
||||
type fakeAppSecretStore struct {
|
||||
items map[string][]byte
|
||||
}
|
||||
|
||||
func newFakeAppSecretStore() *fakeAppSecretStore {
|
||||
return &fakeAppSecretStore{items: make(map[string][]byte)}
|
||||
}
|
||||
|
||||
func (s *fakeAppSecretStore) Put(ref string, payload []byte) error {
|
||||
s.items[ref] = append([]byte(nil), payload...)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *fakeAppSecretStore) Get(ref string) ([]byte, error) {
|
||||
payload, ok := s.items[ref]
|
||||
if !ok {
|
||||
return nil, os.ErrNotExist
|
||||
}
|
||||
return append([]byte(nil), payload...), nil
|
||||
}
|
||||
|
||||
func (s *fakeAppSecretStore) Delete(ref string) error {
|
||||
delete(s.items, ref)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *fakeAppSecretStore) HealthCheck() error {
|
||||
return nil
|
||||
}
|
||||
|
||||
var _ secretstore.SecretStore = (*fakeAppSecretStore)(nil)
|
||||
62
internal/connection/saved_types.go
Normal file
62
internal/connection/saved_types.go
Normal file
@@ -0,0 +1,62 @@
|
||||
package connection
|
||||
|
||||
type SavedConnectionInput struct {
|
||||
ID string `json:"id,omitempty"`
|
||||
Name string `json:"name"`
|
||||
Config ConnectionConfig `json:"config"`
|
||||
IncludeDatabases []string `json:"includeDatabases,omitempty"`
|
||||
IncludeRedisDatabases []int `json:"includeRedisDatabases,omitempty"`
|
||||
IconType string `json:"iconType,omitempty"`
|
||||
IconColor string `json:"iconColor,omitempty"`
|
||||
ClearPrimaryPassword bool `json:"clearPrimaryPassword,omitempty"`
|
||||
ClearSSHPassword bool `json:"clearSSHPassword,omitempty"`
|
||||
ClearProxyPassword bool `json:"clearProxyPassword,omitempty"`
|
||||
ClearHTTPTunnelPassword bool `json:"clearHttpTunnelPassword,omitempty"`
|
||||
ClearMySQLReplicaPassword bool `json:"clearMySQLReplicaPassword,omitempty"`
|
||||
ClearMongoReplicaPassword bool `json:"clearMongoReplicaPassword,omitempty"`
|
||||
ClearOpaqueURI bool `json:"clearOpaqueURI,omitempty"`
|
||||
ClearOpaqueDSN bool `json:"clearOpaqueDSN,omitempty"`
|
||||
}
|
||||
|
||||
type SavedConnectionView struct {
|
||||
ID string `json:"id"`
|
||||
Name string `json:"name"`
|
||||
Config ConnectionConfig `json:"config"`
|
||||
IncludeDatabases []string `json:"includeDatabases,omitempty"`
|
||||
IncludeRedisDatabases []int `json:"includeRedisDatabases,omitempty"`
|
||||
IconType string `json:"iconType,omitempty"`
|
||||
IconColor string `json:"iconColor,omitempty"`
|
||||
SecretRef string `json:"secretRef,omitempty"`
|
||||
HasPrimaryPassword bool `json:"hasPrimaryPassword,omitempty"`
|
||||
HasSSHPassword bool `json:"hasSSHPassword,omitempty"`
|
||||
HasProxyPassword bool `json:"hasProxyPassword,omitempty"`
|
||||
HasHTTPTunnelPassword bool `json:"hasHttpTunnelPassword,omitempty"`
|
||||
HasMySQLReplicaPassword bool `json:"hasMySQLReplicaPassword,omitempty"`
|
||||
HasMongoReplicaPassword bool `json:"hasMongoReplicaPassword,omitempty"`
|
||||
HasOpaqueURI bool `json:"hasOpaqueURI,omitempty"`
|
||||
HasOpaqueDSN bool `json:"hasOpaqueDSN,omitempty"`
|
||||
}
|
||||
|
||||
type LegacySavedConnection = SavedConnectionInput
|
||||
|
||||
type SaveGlobalProxyInput struct {
|
||||
Enabled bool `json:"enabled"`
|
||||
Type string `json:"type"`
|
||||
Host string `json:"host"`
|
||||
Port int `json:"port"`
|
||||
User string `json:"user,omitempty"`
|
||||
Password string `json:"password,omitempty"`
|
||||
}
|
||||
|
||||
type GlobalProxyView struct {
|
||||
Enabled bool `json:"enabled"`
|
||||
Type string `json:"type"`
|
||||
Host string `json:"host"`
|
||||
Port int `json:"port"`
|
||||
User string `json:"user,omitempty"`
|
||||
Password string `json:"password,omitempty"`
|
||||
HasPassword bool `json:"hasPassword,omitempty"`
|
||||
SecretRef string `json:"secretRef,omitempty"`
|
||||
}
|
||||
|
||||
type LegacyGlobalProxyInput = SaveGlobalProxyInput
|
||||
@@ -28,6 +28,7 @@ type HTTPTunnelConfig struct {
|
||||
|
||||
// ConnectionConfig 存储数据库连接的完整配置,包括 SSH、代理、SSL 等网络层设置。
|
||||
type ConnectionConfig struct {
|
||||
ID string `json:"id,omitempty"`
|
||||
Type string `json:"type"`
|
||||
Host string `json:"host"`
|
||||
Port int `json:"port"`
|
||||
|
||||
@@ -279,7 +279,44 @@ func (c *ClickHouseDB) Ping() error {
|
||||
}
|
||||
ctx, cancel := utils.ContextWithTimeout(timeout)
|
||||
defer cancel()
|
||||
return c.conn.PingContext(ctx)
|
||||
if err := c.conn.PingContext(ctx); err != nil {
|
||||
return err
|
||||
}
|
||||
return c.validateQueryPath()
|
||||
}
|
||||
|
||||
func (c *ClickHouseDB) validateQueryPath() error {
|
||||
if c.conn == nil {
|
||||
return fmt.Errorf("连接未打开")
|
||||
}
|
||||
timeout := c.pingTimeout
|
||||
if timeout <= 0 {
|
||||
timeout = 5 * time.Second
|
||||
}
|
||||
ctx, cancel := utils.ContextWithTimeout(timeout)
|
||||
defer cancel()
|
||||
|
||||
rows, err := c.conn.QueryContext(ctx, "SELECT currentDatabase()")
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
if !rows.Next() {
|
||||
if err := rows.Err(); err != nil {
|
||||
return err
|
||||
}
|
||||
return fmt.Errorf("连接查询验证未返回结果")
|
||||
}
|
||||
|
||||
var current sql.NullString
|
||||
if err := rows.Scan(¤t); err != nil {
|
||||
return err
|
||||
}
|
||||
if err := rows.Err(); err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *ClickHouseDB) QueryContext(ctx context.Context, query string) ([]map[string]interface{}, []string, error) {
|
||||
|
||||
119
internal/db/clickhouse_impl_test.go
Normal file
119
internal/db/clickhouse_impl_test.go
Normal file
@@ -0,0 +1,119 @@
|
||||
//go:build gonavi_full_drivers || gonavi_clickhouse_driver
|
||||
|
||||
package db
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"database/sql/driver"
|
||||
"errors"
|
||||
"io"
|
||||
"strings"
|
||||
"sync"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
const fakeClickHouseDriverName = "gonavi-fake-clickhouse"
|
||||
|
||||
var (
|
||||
registerFakeClickHouseDriverOnce sync.Once
|
||||
fakeClickHouseStateMu sync.Mutex
|
||||
fakeClickHouseState = struct {
|
||||
pingErr error
|
||||
queryErr error
|
||||
lastQuery string
|
||||
}{
|
||||
lastQuery: "",
|
||||
}
|
||||
)
|
||||
|
||||
func TestClickHousePingValidatesQueryPath(t *testing.T) {
|
||||
registerFakeClickHouseDriverOnce.Do(func() {
|
||||
sql.Register(fakeClickHouseDriverName, fakeClickHouseDriver{})
|
||||
})
|
||||
|
||||
db, err := sql.Open(fakeClickHouseDriverName, "")
|
||||
if err != nil {
|
||||
t.Fatalf("open fake clickhouse db failed: %v", err)
|
||||
}
|
||||
defer db.Close()
|
||||
|
||||
fakeClickHouseStateMu.Lock()
|
||||
fakeClickHouseState.pingErr = nil
|
||||
fakeClickHouseState.queryErr = errors.New("query path failed")
|
||||
fakeClickHouseState.lastQuery = ""
|
||||
fakeClickHouseStateMu.Unlock()
|
||||
|
||||
client := &ClickHouseDB{
|
||||
conn: db,
|
||||
pingTimeout: time.Second,
|
||||
}
|
||||
err = client.Ping()
|
||||
if err == nil {
|
||||
t.Fatal("expected Ping to fail when query validation fails")
|
||||
}
|
||||
if !strings.Contains(err.Error(), "query path failed") {
|
||||
t.Fatalf("expected query validation error, got %v", err)
|
||||
}
|
||||
|
||||
fakeClickHouseStateMu.Lock()
|
||||
lastQuery := fakeClickHouseState.lastQuery
|
||||
fakeClickHouseStateMu.Unlock()
|
||||
if lastQuery != "SELECT currentDatabase()" {
|
||||
t.Fatalf("expected query validation SQL to run, got %q", lastQuery)
|
||||
}
|
||||
}
|
||||
|
||||
type fakeClickHouseDriver struct{}
|
||||
|
||||
func (fakeClickHouseDriver) Open(name string) (driver.Conn, error) {
|
||||
return fakeClickHouseConn{}, nil
|
||||
}
|
||||
|
||||
type fakeClickHouseConn struct{}
|
||||
|
||||
func (fakeClickHouseConn) Prepare(query string) (driver.Stmt, error) {
|
||||
return nil, errors.New("prepare not implemented")
|
||||
}
|
||||
|
||||
func (fakeClickHouseConn) Close() error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (fakeClickHouseConn) Begin() (driver.Tx, error) {
|
||||
return nil, errors.New("transactions not implemented")
|
||||
}
|
||||
|
||||
func (fakeClickHouseConn) Ping(ctx context.Context) error {
|
||||
fakeClickHouseStateMu.Lock()
|
||||
defer fakeClickHouseStateMu.Unlock()
|
||||
return fakeClickHouseState.pingErr
|
||||
}
|
||||
|
||||
func (fakeClickHouseConn) QueryContext(ctx context.Context, query string, args []driver.NamedValue) (driver.Rows, error) {
|
||||
fakeClickHouseStateMu.Lock()
|
||||
defer fakeClickHouseStateMu.Unlock()
|
||||
fakeClickHouseState.lastQuery = query
|
||||
if fakeClickHouseState.queryErr != nil {
|
||||
return nil, fakeClickHouseState.queryErr
|
||||
}
|
||||
return &fakeClickHouseRows{}, nil
|
||||
}
|
||||
|
||||
type fakeClickHouseRows struct{}
|
||||
|
||||
func (r *fakeClickHouseRows) Columns() []string {
|
||||
return []string{"currentDatabase"}
|
||||
}
|
||||
|
||||
func (r *fakeClickHouseRows) Close() error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *fakeClickHouseRows) Next(dest []driver.Value) error {
|
||||
if len(dest) > 0 {
|
||||
dest[0] = "default"
|
||||
}
|
||||
return io.EOF
|
||||
}
|
||||
104
internal/secretstore/keyring_store.go
Normal file
104
internal/secretstore/keyring_store.go
Normal file
@@ -0,0 +1,104 @@
|
||||
package secretstore
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"runtime"
|
||||
|
||||
"github.com/99designs/keyring"
|
||||
)
|
||||
|
||||
type keyringClient interface {
|
||||
Get(key string) (keyring.Item, error)
|
||||
Set(item keyring.Item) error
|
||||
Remove(key string) error
|
||||
}
|
||||
|
||||
type keyringStore struct {
|
||||
ring keyringClient
|
||||
}
|
||||
|
||||
type keyringOpener func(cfg keyring.Config) (keyring.Keyring, error)
|
||||
|
||||
func NewKeyringStore() SecretStore {
|
||||
return newKeyringStoreWithOpener(runtime.GOOS, keyring.Open)
|
||||
}
|
||||
|
||||
func newKeyringStoreWithOpener(goos string, open keyringOpener) SecretStore {
|
||||
cfg, err := keyringConfigFor(goos)
|
||||
if err != nil {
|
||||
return NewUnavailableStore(err.Error())
|
||||
}
|
||||
|
||||
ring, err := open(cfg)
|
||||
if err != nil {
|
||||
return NewUnavailableStore(err.Error())
|
||||
}
|
||||
|
||||
return &keyringStore{ring: ring}
|
||||
}
|
||||
|
||||
func (s *keyringStore) Put(ref string, payload []byte) error {
|
||||
return wrapKeyringError(s.ring.Set(keyring.Item{Key: ref, Data: payload}))
|
||||
}
|
||||
|
||||
func (s *keyringStore) Get(ref string) ([]byte, error) {
|
||||
item, err := s.ring.Get(ref)
|
||||
if err != nil {
|
||||
return nil, wrapKeyringError(err)
|
||||
}
|
||||
return item.Data, nil
|
||||
}
|
||||
|
||||
func (s *keyringStore) Delete(ref string) error {
|
||||
return wrapKeyringError(s.ring.Remove(ref))
|
||||
}
|
||||
|
||||
func (s *keyringStore) HealthCheck() error {
|
||||
_, err := s.ring.Get(healthCheckRef)
|
||||
if err == nil || errors.Is(err, keyring.ErrKeyNotFound) {
|
||||
return nil
|
||||
}
|
||||
return wrapKeyringError(err)
|
||||
}
|
||||
|
||||
func wrapKeyringError(err error) error {
|
||||
if err == nil || errors.Is(err, keyring.ErrKeyNotFound) || IsUnavailable(err) {
|
||||
return err
|
||||
}
|
||||
return &UnavailableError{Reason: err.Error()}
|
||||
}
|
||||
|
||||
func keyringConfigFor(goos string) (keyring.Config, error) {
|
||||
backends := allowedBackendsFor(goos)
|
||||
if len(backends) == 0 {
|
||||
return keyring.Config{}, fmt.Errorf("unsupported keyring platform: %s", goos)
|
||||
}
|
||||
|
||||
return keyring.Config{
|
||||
ServiceName: serviceName,
|
||||
AllowedBackends: backends,
|
||||
KeychainTrustApplication: true,
|
||||
KeychainAccessibleWhenUnlocked: true,
|
||||
LibSecretCollectionName: "default",
|
||||
KeyCtlScope: "user",
|
||||
WinCredPrefix: serviceName,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func allowedBackendsFor(goos string) []keyring.BackendType {
|
||||
switch goos {
|
||||
case "windows":
|
||||
return []keyring.BackendType{keyring.WinCredBackend}
|
||||
case "darwin":
|
||||
return []keyring.BackendType{keyring.KeychainBackend}
|
||||
case "linux":
|
||||
return []keyring.BackendType{
|
||||
keyring.SecretServiceBackend,
|
||||
keyring.KWalletBackend,
|
||||
keyring.KeyCtlBackend,
|
||||
}
|
||||
default:
|
||||
return nil
|
||||
}
|
||||
}
|
||||
113
internal/secretstore/keyring_store_test.go
Normal file
113
internal/secretstore/keyring_store_test.go
Normal file
@@ -0,0 +1,113 @@
|
||||
package secretstore
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"testing"
|
||||
|
||||
"github.com/99designs/keyring"
|
||||
)
|
||||
|
||||
func TestStoreStatusValuesRemainStable(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
if StatusAvailable != "available" {
|
||||
t.Fatalf("expected StatusAvailable to remain stable, got %q", StatusAvailable)
|
||||
}
|
||||
if StatusUnavailable != "unavailable" {
|
||||
t.Fatalf("expected StatusUnavailable to remain stable, got %q", StatusUnavailable)
|
||||
}
|
||||
}
|
||||
|
||||
func TestBuildRefRejectsEmptyKind(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
if _, err := BuildRef("", "secret-id"); err == nil {
|
||||
t.Fatal("BuildRef should reject an empty kind")
|
||||
}
|
||||
}
|
||||
|
||||
func TestBuildRefRejectsEmptyID(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
if _, err := BuildRef("database", ""); err == nil {
|
||||
t.Fatal("BuildRef should reject an empty id")
|
||||
}
|
||||
}
|
||||
|
||||
func TestUnavailableStoreHealthCheckReturnsUnavailableError(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
store := NewUnavailableStore("keyring backend disabled")
|
||||
|
||||
err := store.HealthCheck()
|
||||
if err == nil {
|
||||
t.Fatal("HealthCheck should return an unavailable error")
|
||||
}
|
||||
|
||||
if !IsUnavailable(err) {
|
||||
t.Fatalf("HealthCheck error should be detected by IsUnavailable, got %T", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestKeyringStoreHealthCheckTreatsMissingProbeItemAsHealthy(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
store := &keyringStore{ring: fakeKeyringClient{getErr: keyring.ErrKeyNotFound}}
|
||||
if err := store.HealthCheck(); err != nil {
|
||||
t.Fatalf("HealthCheck should accept ErrKeyNotFound, got %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestKeyringStoreHealthCheckReturnsUnavailableErrorOnBackendFailure(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
store := &keyringStore{ring: fakeKeyringClient{getErr: errors.New("backend offline")}}
|
||||
if err := store.HealthCheck(); err == nil || !IsUnavailable(err) {
|
||||
t.Fatalf("HealthCheck should wrap backend failures as unavailable, got %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestNewKeyringStoreReturnsUnavailableStoreWhenOpenFails(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
store := newKeyringStoreWithOpener("windows", func(cfg keyring.Config) (keyring.Keyring, error) {
|
||||
if len(cfg.AllowedBackends) != 1 || cfg.AllowedBackends[0] != keyring.WinCredBackend {
|
||||
t.Fatalf("unexpected backend config: %#v", cfg.AllowedBackends)
|
||||
}
|
||||
return nil, errors.New("no backend")
|
||||
})
|
||||
|
||||
if err := store.HealthCheck(); err == nil || !IsUnavailable(err) {
|
||||
t.Fatalf("expected unavailable store when open fails, got %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
type fakeKeyringClient struct {
|
||||
getErr error
|
||||
item keyring.Item
|
||||
removeErr error
|
||||
}
|
||||
|
||||
func (f fakeKeyringClient) Get(string) (keyring.Item, error) {
|
||||
if f.getErr != nil {
|
||||
return keyring.Item{}, f.getErr
|
||||
}
|
||||
return f.item, nil
|
||||
}
|
||||
|
||||
func (f fakeKeyringClient) Set(item keyring.Item) error {
|
||||
_ = item
|
||||
return nil
|
||||
}
|
||||
|
||||
func (f fakeKeyringClient) Remove(string) error {
|
||||
return f.removeErr
|
||||
}
|
||||
|
||||
func (f fakeKeyringClient) GetMetadata(string) (keyring.Metadata, error) {
|
||||
return keyring.Metadata{}, nil
|
||||
}
|
||||
|
||||
func (f fakeKeyringClient) Keys() ([]string, error) {
|
||||
return nil, nil
|
||||
}
|
||||
76
internal/secretstore/store.go
Normal file
76
internal/secretstore/store.go
Normal file
@@ -0,0 +1,76 @@
|
||||
package secretstore
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"strings"
|
||||
)
|
||||
|
||||
const (
|
||||
serviceName = "gonavi"
|
||||
healthCheckRef = "oskeyring://gonavi/healthcheck/ping"
|
||||
)
|
||||
|
||||
type SecretStore interface {
|
||||
Put(ref string, payload []byte) error
|
||||
Get(ref string) ([]byte, error)
|
||||
Delete(ref string) error
|
||||
HealthCheck() error
|
||||
}
|
||||
|
||||
type StoreStatus string
|
||||
|
||||
const (
|
||||
StatusAvailable StoreStatus = "available"
|
||||
StatusUnavailable StoreStatus = "unavailable"
|
||||
)
|
||||
|
||||
type UnavailableError struct {
|
||||
Reason string
|
||||
}
|
||||
|
||||
func (e *UnavailableError) Error() string {
|
||||
reason := strings.TrimSpace(e.Reason)
|
||||
if reason == "" {
|
||||
return "secret store unavailable"
|
||||
}
|
||||
return fmt.Sprintf("secret store unavailable: %s", reason)
|
||||
}
|
||||
|
||||
func IsUnavailable(err error) bool {
|
||||
var target *UnavailableError
|
||||
return errors.As(err, &target)
|
||||
}
|
||||
|
||||
type unavailableStore struct {
|
||||
err error
|
||||
}
|
||||
|
||||
func NewUnavailableStore(reason string) SecretStore {
|
||||
return unavailableStore{err: &UnavailableError{Reason: strings.TrimSpace(reason)}}
|
||||
}
|
||||
|
||||
func (s unavailableStore) Put(string, []byte) error {
|
||||
return s.err
|
||||
}
|
||||
|
||||
func (s unavailableStore) Get(string) ([]byte, error) {
|
||||
return nil, s.err
|
||||
}
|
||||
|
||||
func (s unavailableStore) Delete(string) error {
|
||||
return s.err
|
||||
}
|
||||
|
||||
func (s unavailableStore) HealthCheck() error {
|
||||
return s.err
|
||||
}
|
||||
|
||||
func BuildRef(kind, id string) (string, error) {
|
||||
kind = strings.TrimSpace(kind)
|
||||
id = strings.TrimSpace(id)
|
||||
if kind == "" || id == "" {
|
||||
return "", fmt.Errorf("invalid secret ref")
|
||||
}
|
||||
return fmt.Sprintf("oskeyring://%s/%s/%s", serviceName, kind, id), nil
|
||||
}
|
||||
Reference in New Issue
Block a user