mirror of
https://github.com/Syngnat/GoNavi.git
synced 2026-05-30 07:09:35 +08:00
feat(mongodb,connection-tree,query-editor,sidebar,sqlserver,table-designer,ssl): 完成 MongoDB v1/v2 驱动切换与复制连接,增强快捷键/搜索/筛选与设计表体验,并修复 SQLServer、SSL 及连接稳定性问题 (#180)
* feat(mongodb-driver,connection-tree): 支持 MongoDB v1/v2 切换并新增复制连接 * fix(mongodb-query): 修复 MongoDB 筛选不生效并兼容 shell 语法执行 refs #153 * fix(query-editor): 修复 SQLServer 自动补全回车重复 dbo 前缀 refs #159 * fix(sqlserver-table-designer): 修复设计表读取列时错误使用 schema 作为数据库名 refs #156 * feat(shortcuts): 增加快捷键设置并支持 SQL 执行/侧边栏搜索 refs #158 * fix(sidebar-search): 优化范围搜索匹配与交互 refs #158 * fix(filter,connection-recovery): 保持筛选状态并修复连接失效卡死 refs #165 同步修复连接失效后侧栏持续转圈、断开后无法恢复的问题 * feat(table-designer): 统一设计表界面风格并优化字段新增交互 - 统一设计表页面与数据面板的视觉风格,覆盖工具栏、Tabs、表格与编辑区域 - 移除默认硬边框,改为透明背景与细分隔线,提升整体观感一致性 - 添加字段后自动滚动到新行并高亮,且自动聚焦输入框 - 新增" 在选中字段后添加\,支持按选中字段位置插入字段 * feat(data-grid-filter): 筛选字段支持快捷搜索 - 在筛选条件字段下拉启用可搜索(showSearch) - 支持字段名大小写不敏感模糊匹配 - 表字段较多时可快速定位目标字段,减少下拉查找耗时 refs #171 * fix(db-ssl): 支持多数据源 SSL/TLS 连接并补齐达梦证书配置 refs #167 * fix(sidebar): 修复数据库加载时 null.map 导致表加载失败 * fix(query-editor): 合并运行按钮并保留 SQL 停止执行入口
This commit is contained in:
12
cmd/optional-driver-agent/provider_mongodb_v1.go
Normal file
12
cmd/optional-driver-agent/provider_mongodb_v1.go
Normal file
@@ -0,0 +1,12 @@
|
||||
//go:build gonavi_mongodb_driver_v1
|
||||
|
||||
package main
|
||||
|
||||
import "GoNavi-Wails/internal/db"
|
||||
|
||||
func init() {
|
||||
agentDriverType = "mongodb"
|
||||
agentDatabaseFactory = func() db.Database {
|
||||
return &db.MongoDBV1{}
|
||||
}
|
||||
}
|
||||
@@ -2,7 +2,7 @@ import React, { useState, useEffect } from 'react';
|
||||
import { Layout, Button, ConfigProvider, theme, Dropdown, MenuProps, message, Modal, Spin, Slider, Progress, Switch, Input, InputNumber, Select } from 'antd';
|
||||
import zhCN from 'antd/locale/zh_CN';
|
||||
import { PlusOutlined, ConsoleSqlOutlined, UploadOutlined, DownloadOutlined, CloudDownloadOutlined, BugOutlined, ToolOutlined, GlobalOutlined, InfoCircleOutlined, GithubOutlined, SkinOutlined, CheckOutlined, MinusOutlined, BorderOutlined, CloseOutlined, SettingOutlined, LinkOutlined } from '@ant-design/icons';
|
||||
import { BrowserOpenURL, Environment, EventsOn, Quit, WindowFullscreen, WindowIsFullscreen, WindowIsMaximised, WindowMaximise, WindowMinimise, WindowToggleMaximise } from '../wailsjs/runtime';
|
||||
import { BrowserOpenURL, Environment, EventsOn, Quit, WindowFullscreen, WindowGetSize, WindowIsFullscreen, WindowIsMaximised, WindowMaximise, WindowMinimise, WindowSetSize, WindowToggleMaximise } from '../wailsjs/runtime';
|
||||
import Sidebar from './components/Sidebar';
|
||||
import TabManager from './components/TabManager';
|
||||
import ConnectionModal from './components/ConnectionModal';
|
||||
@@ -12,6 +12,17 @@ import LogPanel from './components/LogPanel';
|
||||
import { useStore } from './store';
|
||||
import { SavedConnection } from './types';
|
||||
import { blurToFilter, normalizeBlurForPlatform, normalizeOpacityForPlatform, isWindowsPlatform } from './utils/appearance';
|
||||
import {
|
||||
SHORTCUT_ACTION_META,
|
||||
SHORTCUT_ACTION_ORDER,
|
||||
ShortcutAction,
|
||||
eventToShortcut,
|
||||
getShortcutDisplay,
|
||||
hasModifierKey,
|
||||
isEditableElement,
|
||||
isShortcutMatch,
|
||||
normalizeShortcutCombo,
|
||||
} from './utils/shortcuts';
|
||||
import { ConfigureGlobalProxy, SetWindowTranslucency } from '../wailsjs/go/app/App';
|
||||
import './App.css';
|
||||
|
||||
@@ -53,6 +64,9 @@ function App() {
|
||||
const setStartupFullscreen = useStore(state => state.setStartupFullscreen);
|
||||
const globalProxy = useStore(state => state.globalProxy);
|
||||
const setGlobalProxy = useStore(state => state.setGlobalProxy);
|
||||
const shortcutOptions = useStore(state => state.shortcutOptions);
|
||||
const updateShortcut = useStore(state => state.updateShortcut);
|
||||
const resetShortcutOptions = useStore(state => state.resetShortcutOptions);
|
||||
const darkMode = themeMode === 'dark';
|
||||
const effectiveUiScale = Math.min(MAX_UI_SCALE, Math.max(MIN_UI_SCALE, Number(uiScale) || DEFAULT_UI_SCALE));
|
||||
const effectiveFontSize = Math.min(MAX_FONT_SIZE, Math.max(MIN_FONT_SIZE, Math.round(Number(fontSize) || DEFAULT_FONT_SIZE)));
|
||||
@@ -260,6 +274,80 @@ function App() {
|
||||
};
|
||||
}, []);
|
||||
|
||||
useEffect(() => {
|
||||
if (!isWindowsPlatform()) {
|
||||
return;
|
||||
}
|
||||
|
||||
let cancelled = false;
|
||||
let inFlight = false;
|
||||
let lastRatio = Number(window.devicePixelRatio) || 1;
|
||||
let lastFixAt = 0;
|
||||
|
||||
const wait = (ms: number) => new Promise<void>((resolve) => window.setTimeout(resolve, ms));
|
||||
|
||||
const fixWindowScaleIfNeeded = async () => {
|
||||
if (cancelled || inFlight) return;
|
||||
const now = Date.now();
|
||||
if (now - lastFixAt < 700) return;
|
||||
inFlight = true;
|
||||
try {
|
||||
const [isFullscreen, isMaximised] = await Promise.all([
|
||||
WindowIsFullscreen().catch(() => false),
|
||||
WindowIsMaximised().catch(() => false),
|
||||
]);
|
||||
|
||||
// 避免在全屏/最大化状态下强制改尺寸;这两种状态通常能自行保持 DPI 同步。
|
||||
if (isFullscreen || isMaximised) {
|
||||
window.dispatchEvent(new Event('resize'));
|
||||
lastFixAt = Date.now();
|
||||
return;
|
||||
}
|
||||
|
||||
const size = await WindowGetSize().catch(() => null);
|
||||
const width = Math.trunc(Number(size?.w || 0));
|
||||
const height = Math.trunc(Number(size?.h || 0));
|
||||
if (width <= 0 || height <= 0) {
|
||||
window.dispatchEvent(new Event('resize'));
|
||||
lastFixAt = Date.now();
|
||||
return;
|
||||
}
|
||||
|
||||
const nudgedWidth = width > 480 ? width - 1 : width + 1;
|
||||
WindowSetSize(nudgedWidth, height);
|
||||
await wait(28);
|
||||
WindowSetSize(width, height);
|
||||
window.dispatchEvent(new Event('resize'));
|
||||
lastFixAt = Date.now();
|
||||
} finally {
|
||||
inFlight = false;
|
||||
}
|
||||
};
|
||||
|
||||
const checkDevicePixelRatio = () => {
|
||||
if (cancelled) return;
|
||||
const currentRatio = Number(window.devicePixelRatio) || 1;
|
||||
if (Math.abs(currentRatio - lastRatio) < 0.02) {
|
||||
return;
|
||||
}
|
||||
lastRatio = currentRatio;
|
||||
void fixWindowScaleIfNeeded();
|
||||
};
|
||||
|
||||
const pollTimer = window.setInterval(checkDevicePixelRatio, 900);
|
||||
window.addEventListener('resize', checkDevicePixelRatio);
|
||||
window.addEventListener('focus', checkDevicePixelRatio);
|
||||
document.addEventListener('visibilitychange', checkDevicePixelRatio);
|
||||
|
||||
return () => {
|
||||
cancelled = true;
|
||||
window.clearInterval(pollTimer);
|
||||
window.removeEventListener('resize', checkDevicePixelRatio);
|
||||
window.removeEventListener('focus', checkDevicePixelRatio);
|
||||
document.removeEventListener('visibilitychange', checkDevicePixelRatio);
|
||||
};
|
||||
}, []);
|
||||
|
||||
// Background Helper
|
||||
const getBg = (darkHex: string) => {
|
||||
if (!darkMode) return `rgba(255, 255, 255, ${effectiveOpacity})`; // Light mode usually white
|
||||
@@ -720,10 +808,18 @@ function App() {
|
||||
label: '外观设置...',
|
||||
icon: <SettingOutlined />,
|
||||
onClick: () => setIsAppearanceModalOpen(true)
|
||||
},
|
||||
{
|
||||
key: 'shortcut-settings',
|
||||
label: '快捷键管理...',
|
||||
icon: <LinkOutlined />,
|
||||
onClick: () => setIsShortcutModalOpen(true)
|
||||
}
|
||||
];
|
||||
|
||||
const [isAppearanceModalOpen, setIsAppearanceModalOpen] = useState(false);
|
||||
const [isShortcutModalOpen, setIsShortcutModalOpen] = useState(false);
|
||||
const [capturingShortcutAction, setCapturingShortcutAction] = useState<ShortcutAction | null>(null);
|
||||
const [isProxyModalOpen, setIsProxyModalOpen] = useState(false);
|
||||
|
||||
|
||||
@@ -935,6 +1031,113 @@ function App() {
|
||||
};
|
||||
}, []);
|
||||
|
||||
useEffect(() => {
|
||||
const handleOpenShortcutSettingsEvent = () => {
|
||||
setIsShortcutModalOpen(true);
|
||||
};
|
||||
window.addEventListener('gonavi:open-shortcut-settings', handleOpenShortcutSettingsEvent as EventListener);
|
||||
return () => {
|
||||
window.removeEventListener('gonavi:open-shortcut-settings', handleOpenShortcutSettingsEvent as EventListener);
|
||||
};
|
||||
}, []);
|
||||
|
||||
useEffect(() => {
|
||||
const handleGlobalShortcut = (event: KeyboardEvent) => {
|
||||
const matchedAction = SHORTCUT_ACTION_ORDER.find((action) => {
|
||||
const binding = shortcutOptions[action];
|
||||
if (!binding?.enabled) {
|
||||
return false;
|
||||
}
|
||||
if (isEditableElement(event.target) && !SHORTCUT_ACTION_META[action].allowInEditable) {
|
||||
return false;
|
||||
}
|
||||
return isShortcutMatch(event, binding.combo);
|
||||
});
|
||||
|
||||
if (!matchedAction) {
|
||||
return;
|
||||
}
|
||||
|
||||
event.preventDefault();
|
||||
event.stopPropagation();
|
||||
|
||||
switch (matchedAction) {
|
||||
case 'runQuery':
|
||||
window.dispatchEvent(new CustomEvent('gonavi:run-active-query'));
|
||||
break;
|
||||
case 'focusSidebarSearch':
|
||||
window.dispatchEvent(new CustomEvent('gonavi:focus-sidebar-search'));
|
||||
break;
|
||||
case 'newQueryTab':
|
||||
handleNewQuery();
|
||||
break;
|
||||
case 'toggleLogPanel':
|
||||
setIsLogPanelOpen((prev) => !prev);
|
||||
break;
|
||||
case 'toggleTheme':
|
||||
setTheme(themeMode === 'dark' ? 'light' : 'dark');
|
||||
break;
|
||||
case 'openShortcutManager':
|
||||
setIsShortcutModalOpen(true);
|
||||
break;
|
||||
}
|
||||
};
|
||||
|
||||
window.addEventListener('keydown', handleGlobalShortcut);
|
||||
return () => {
|
||||
window.removeEventListener('keydown', handleGlobalShortcut);
|
||||
};
|
||||
}, [handleNewQuery, shortcutOptions, themeMode, setTheme]);
|
||||
|
||||
useEffect(() => {
|
||||
if (!capturingShortcutAction) {
|
||||
return;
|
||||
}
|
||||
|
||||
const handleShortcutCapture = (event: KeyboardEvent) => {
|
||||
event.preventDefault();
|
||||
event.stopPropagation();
|
||||
|
||||
if (event.key === 'Escape') {
|
||||
setCapturingShortcutAction(null);
|
||||
return;
|
||||
}
|
||||
|
||||
const combo = eventToShortcut(event);
|
||||
if (!combo) {
|
||||
return;
|
||||
}
|
||||
if (!hasModifierKey(combo)) {
|
||||
void message.warning('快捷键至少包含 Ctrl / Alt / Shift / Meta 之一');
|
||||
return;
|
||||
}
|
||||
|
||||
const normalizedCombo = normalizeShortcutCombo(combo);
|
||||
const conflictAction = SHORTCUT_ACTION_ORDER.find((action) => {
|
||||
if (action === capturingShortcutAction) {
|
||||
return false;
|
||||
}
|
||||
const binding = shortcutOptions[action];
|
||||
if (!binding?.enabled) {
|
||||
return false;
|
||||
}
|
||||
return normalizeShortcutCombo(binding.combo) === normalizedCombo;
|
||||
});
|
||||
if (conflictAction) {
|
||||
void message.warning(`与「${SHORTCUT_ACTION_META[conflictAction].label}」冲突,请换一个快捷键`);
|
||||
return;
|
||||
}
|
||||
|
||||
updateShortcut(capturingShortcutAction, { combo: normalizedCombo, enabled: true });
|
||||
setCapturingShortcutAction(null);
|
||||
};
|
||||
|
||||
window.addEventListener('keydown', handleShortcutCapture, true);
|
||||
return () => {
|
||||
window.removeEventListener('keydown', handleShortcutCapture, true);
|
||||
};
|
||||
}, [capturingShortcutAction, shortcutOptions, updateShortcut]);
|
||||
|
||||
const linuxResizeHandleStyleBase = {
|
||||
position: 'fixed',
|
||||
zIndex: 12000,
|
||||
@@ -1354,6 +1557,84 @@ function App() {
|
||||
</div>
|
||||
</Modal>
|
||||
|
||||
<Modal
|
||||
title="快捷键管理"
|
||||
open={isShortcutModalOpen}
|
||||
onCancel={() => {
|
||||
setIsShortcutModalOpen(false);
|
||||
setCapturingShortcutAction(null);
|
||||
}}
|
||||
width={720}
|
||||
footer={[
|
||||
<Button
|
||||
key="reset"
|
||||
onClick={() => {
|
||||
resetShortcutOptions();
|
||||
setCapturingShortcutAction(null);
|
||||
void message.success('已恢复默认快捷键');
|
||||
}}
|
||||
>
|
||||
恢复默认
|
||||
</Button>,
|
||||
<Button
|
||||
key="close"
|
||||
type="primary"
|
||||
onClick={() => {
|
||||
setIsShortcutModalOpen(false);
|
||||
setCapturingShortcutAction(null);
|
||||
}}
|
||||
>
|
||||
关闭
|
||||
</Button>,
|
||||
]}
|
||||
>
|
||||
<div style={{ display: 'flex', flexDirection: 'column', gap: 12, paddingTop: 8 }}>
|
||||
<div style={{ fontSize: 12, color: '#8c8c8c' }}>
|
||||
点击“录制”后按下快捷键。按 Esc 可取消录制。建议至少包含一个修饰键(Ctrl/Alt/Shift/Meta)。
|
||||
</div>
|
||||
{SHORTCUT_ACTION_ORDER.map((action) => {
|
||||
const meta = SHORTCUT_ACTION_META[action];
|
||||
const binding = shortcutOptions[action] ?? { combo: '', enabled: false };
|
||||
const isCapturing = capturingShortcutAction === action;
|
||||
return (
|
||||
<div
|
||||
key={action}
|
||||
style={{
|
||||
display: 'grid',
|
||||
gridTemplateColumns: '1fr auto',
|
||||
gap: 12,
|
||||
alignItems: 'center',
|
||||
padding: '10px 12px',
|
||||
border: '1px solid rgba(128, 128, 128, 0.2)',
|
||||
borderRadius: 8,
|
||||
}}
|
||||
>
|
||||
<div>
|
||||
<div style={{ fontWeight: 500 }}>{meta.label}</div>
|
||||
<div style={{ fontSize: 12, color: '#8c8c8c' }}>{meta.description}</div>
|
||||
</div>
|
||||
<div style={{ display: 'flex', alignItems: 'center', gap: 8 }}>
|
||||
<Input
|
||||
readOnly
|
||||
value={isCapturing ? '请按下快捷键...' : getShortcutDisplay(binding.combo)}
|
||||
style={{ width: 180, fontFamily: 'Consolas, Menlo, Monaco, monospace' }}
|
||||
/>
|
||||
<Button
|
||||
size="small"
|
||||
onClick={() => setCapturingShortcutAction((prev) => (prev === action ? null : action))}
|
||||
>
|
||||
{isCapturing ? '取消' : '录制'}
|
||||
</Button>
|
||||
<Switch
|
||||
checked={binding.enabled}
|
||||
onChange={(checked) => updateShortcut(action, { enabled: checked })}
|
||||
/>
|
||||
</div>
|
||||
</div>
|
||||
);
|
||||
})}
|
||||
</div>
|
||||
</Modal>
|
||||
<Modal
|
||||
title="全局代理设置"
|
||||
open={isProxyModalOpen}
|
||||
|
||||
@@ -54,6 +54,26 @@ const singleHostUriSchemesByType: Record<string, string[]> = {
|
||||
vastbase: ['vastbase'],
|
||||
};
|
||||
|
||||
const sslSupportedTypes = new Set([
|
||||
'mysql',
|
||||
'mariadb',
|
||||
'diros',
|
||||
'sphinx',
|
||||
'dameng',
|
||||
'clickhouse',
|
||||
'postgres',
|
||||
'sqlserver',
|
||||
'oracle',
|
||||
'kingbase',
|
||||
'highgo',
|
||||
'vastbase',
|
||||
'mongodb',
|
||||
'redis',
|
||||
'tdengine',
|
||||
]);
|
||||
|
||||
const supportsSSLForType = (type: string) => sslSupportedTypes.has(String(type || '').trim().toLowerCase());
|
||||
|
||||
const isFileDatabaseType = (type: string) => type === 'sqlite' || type === 'duckdb';
|
||||
|
||||
type DriverStatusSnapshot = {
|
||||
@@ -78,6 +98,7 @@ const ConnectionModal: React.FC<{
|
||||
}> = ({ open, onClose, initialValues, onOpenDriverManager }) => {
|
||||
const [form] = Form.useForm();
|
||||
const [loading, setLoading] = useState(false);
|
||||
const [useSSL, setUseSSL] = useState(false);
|
||||
const [useSSH, setUseSSH] = useState(false);
|
||||
const [useProxy, setUseProxy] = useState(false);
|
||||
const [dbType, setDbType] = useState('mysql');
|
||||
@@ -107,6 +128,17 @@ const ConnectionModal: React.FC<{
|
||||
const mongoTopology = Form.useWatch('mongoTopology', form) || 'single';
|
||||
const mongoSrv = Form.useWatch('mongoSrv', form) || false;
|
||||
const redisTopology = Form.useWatch('redisTopology', form) || 'single';
|
||||
const isMySQLLike = dbType === 'mysql' || dbType === 'mariadb' || dbType === 'diros' || dbType === 'sphinx';
|
||||
const isSSLType = supportsSSLForType(dbType);
|
||||
const sslHintText = isMySQLLike
|
||||
? '当 MySQL/MariaDB/Doris/Sphinx 开启安全传输策略时,请启用 SSL;本地自签证书场景可先用 Preferred 或 Skip Verify。'
|
||||
: dbType === 'dameng'
|
||||
? '达梦驱动启用 SSL 需要客户端证书与私钥路径(sslCertPath / sslKeyPath)。'
|
||||
: dbType === 'sqlserver'
|
||||
? 'SQL Server 推荐在生产环境使用 Required,并关闭 TrustServerCertificate。'
|
||||
: dbType === 'mongodb'
|
||||
? 'MongoDB 可通过 TLS 保护连接,证书校验异常时可先用 Skip Verify 验证连通性。'
|
||||
: '建议优先使用 Required;仅在测试环境或自签证书场景使用 Skip Verify。';
|
||||
|
||||
const getSectionBg = (darkHex: string) => {
|
||||
if (!darkMode) {
|
||||
@@ -364,7 +396,7 @@ const ConnectionModal: React.FC<{
|
||||
uriText: string,
|
||||
expectedSchemes: string[],
|
||||
defaultPort: number,
|
||||
): { host: string; port: number; username: string; password: string; database: string } | null => {
|
||||
): { host: string; port: number; username: string; password: string; database: string; params: URLSearchParams } | null => {
|
||||
let parsed: ReturnType<typeof parseMultiHostUri> | null = null;
|
||||
for (const scheme of expectedSchemes) {
|
||||
parsed = parseMultiHostUri(uriText, scheme);
|
||||
@@ -392,6 +424,7 @@ const ConnectionModal: React.FC<{
|
||||
username: parsed.username,
|
||||
password: parsed.password,
|
||||
database: parsed.database || '',
|
||||
params: parsed.params,
|
||||
};
|
||||
};
|
||||
|
||||
@@ -425,12 +458,22 @@ const ConnectionModal: React.FC<{
|
||||
const primary = parseHostPort(hostList[0] || `localhost:${mysqlDefaultPort}`, mysqlDefaultPort);
|
||||
const timeoutValue = Number(parsed.params.get('timeout'));
|
||||
const topology = String(parsed.params.get('topology') || '').toLowerCase();
|
||||
const tlsValue = String(parsed.params.get('tls') || '').trim().toLowerCase();
|
||||
const sslMode = tlsValue === 'true'
|
||||
? 'required'
|
||||
: tlsValue === 'skip-verify'
|
||||
? 'skip-verify'
|
||||
: tlsValue === 'preferred'
|
||||
? 'preferred'
|
||||
: 'disable';
|
||||
return {
|
||||
host: primary?.host || 'localhost',
|
||||
port: primary?.port || mysqlDefaultPort,
|
||||
user: parsed.username,
|
||||
password: parsed.password,
|
||||
database: parsed.database || '',
|
||||
useSSL: sslMode !== 'disable',
|
||||
sslMode,
|
||||
mysqlTopology: hostList.length > 1 || topology === 'replica' ? 'replica' : 'single',
|
||||
mysqlReplicaHosts: hostList.slice(1),
|
||||
timeout: Number.isFinite(timeoutValue) && timeoutValue > 0
|
||||
@@ -451,7 +494,7 @@ const ConnectionModal: React.FC<{
|
||||
}
|
||||
|
||||
if (type === 'redis') {
|
||||
const parsed = parseMultiHostUri(trimmedUri, 'redis');
|
||||
const parsed = parseMultiHostUri(trimmedUri, 'redis') || parseMultiHostUri(trimmedUri, 'rediss');
|
||||
if (!parsed) {
|
||||
return null;
|
||||
}
|
||||
@@ -469,10 +512,15 @@ const ConnectionModal: React.FC<{
|
||||
const topologyParam = String(parsed.params.get('topology') || '').toLowerCase();
|
||||
const dbText = String(parsed.database || '').trim().replace(/^\//, '');
|
||||
const dbIndex = Number(dbText);
|
||||
const isRediss = trimmedUri.toLowerCase().startsWith('rediss://');
|
||||
const skipVerifyText = String(parsed.params.get('skip_verify') || '').trim().toLowerCase();
|
||||
const skipVerify = skipVerifyText === '1' || skipVerifyText === 'true' || skipVerifyText === 'yes' || skipVerifyText === 'on';
|
||||
return {
|
||||
host: primary?.host || 'localhost',
|
||||
port: primary?.port || 6379,
|
||||
password: parsed.password || '',
|
||||
useSSL: isRediss,
|
||||
sslMode: isRediss ? (skipVerify ? 'skip-verify' : 'required') : 'disable',
|
||||
redisTopology: hostList.length > 1 || topologyParam === 'cluster' ? 'cluster' : 'single',
|
||||
redisHosts: hostList.slice(1),
|
||||
redisDB: Number.isFinite(dbIndex) && dbIndex >= 0 && dbIndex <= 15 ? Math.trunc(dbIndex) : 0,
|
||||
@@ -501,12 +549,18 @@ const ConnectionModal: React.FC<{
|
||||
? { host: hostList[0] || 'localhost', port: 27017 }
|
||||
: parseHostPort(hostList[0] || 'localhost:27017', 27017);
|
||||
const timeoutMs = Number(parsed.params.get('connectTimeoutMS') || parsed.params.get('serverSelectionTimeoutMS'));
|
||||
const tlsText = String(parsed.params.get('tls') || parsed.params.get('ssl') || '').trim().toLowerCase();
|
||||
const tlsInsecureText = String(parsed.params.get('tlsInsecure') || parsed.params.get('sslInsecure') || '').trim().toLowerCase();
|
||||
const tlsEnabled = tlsText === '1' || tlsText === 'true' || tlsText === 'yes' || tlsText === 'on';
|
||||
const tlsInsecure = tlsInsecureText === '1' || tlsInsecureText === 'true' || tlsInsecureText === 'yes' || tlsInsecureText === 'on';
|
||||
return {
|
||||
host: primary?.host || 'localhost',
|
||||
port: primary?.port || 27017,
|
||||
user: parsed.username,
|
||||
password: parsed.password,
|
||||
database: parsed.database || '',
|
||||
useSSL: tlsEnabled,
|
||||
sslMode: tlsEnabled ? (tlsInsecure ? 'skip-verify' : 'required') : 'disable',
|
||||
mongoTopology: hostList.length > 1 || !!parsed.params.get('replicaSet') ? 'replica' : 'single',
|
||||
mongoHosts: hostList.slice(1),
|
||||
mongoSrv: isSrv,
|
||||
@@ -531,13 +585,94 @@ const ConnectionModal: React.FC<{
|
||||
// Oracle 需要显式 service name,避免 URI 解析后放过必填校验。
|
||||
return null;
|
||||
}
|
||||
return {
|
||||
const parsedValues: Record<string, any> = {
|
||||
host: parsed.host,
|
||||
port: parsed.port,
|
||||
user: parsed.username,
|
||||
password: parsed.password,
|
||||
database: parsed.database,
|
||||
};
|
||||
|
||||
if (supportsSSLForType(type)) {
|
||||
const normalizeBool = (raw: unknown) => {
|
||||
const text = String(raw ?? '').trim().toLowerCase();
|
||||
return text === '1' || text === 'true' || text === 'yes' || text === 'on';
|
||||
};
|
||||
if (type === 'postgres' || type === 'kingbase' || type === 'highgo' || type === 'vastbase') {
|
||||
const sslMode = String(parsed.params.get('sslmode') || '').trim().toLowerCase();
|
||||
if (sslMode) {
|
||||
parsedValues.useSSL = sslMode !== 'disable' && sslMode !== 'false';
|
||||
parsedValues.sslMode = sslMode === 'disable' || sslMode === 'false'
|
||||
? 'disable'
|
||||
: 'required';
|
||||
}
|
||||
} else if (type === 'sqlserver') {
|
||||
const encrypt = String(parsed.params.get('encrypt') || '').trim().toLowerCase();
|
||||
const trust = String(parsed.params.get('TrustServerCertificate') || parsed.params.get('trustservercertificate') || '').trim().toLowerCase();
|
||||
const encrypted = encrypt === 'true' || encrypt === 'mandatory' || encrypt === 'yes' || encrypt === '1' || encrypt === 'strict';
|
||||
if (encrypted) {
|
||||
parsedValues.useSSL = true;
|
||||
parsedValues.sslMode = trust === 'true' || trust === '1' || trust === 'yes' ? 'skip-verify' : 'required';
|
||||
} else if (encrypt) {
|
||||
parsedValues.useSSL = false;
|
||||
parsedValues.sslMode = 'disable';
|
||||
}
|
||||
} else if (type === 'clickhouse') {
|
||||
const secure = String(parsed.params.get('secure') || parsed.params.get('tls') || '').trim().toLowerCase();
|
||||
const skipVerify = normalizeBool(parsed.params.get('skip_verify'));
|
||||
if (secure) {
|
||||
parsedValues.useSSL = normalizeBool(secure);
|
||||
parsedValues.sslMode = skipVerify ? 'skip-verify' : (parsedValues.useSSL ? 'required' : 'disable');
|
||||
}
|
||||
} else if (type === 'dameng') {
|
||||
const certPath = String(
|
||||
parsed.params.get('SSL_CERT_PATH')
|
||||
|| parsed.params.get('ssl_cert_path')
|
||||
|| parsed.params.get('sslCertPath')
|
||||
|| ''
|
||||
).trim();
|
||||
const keyPath = String(
|
||||
parsed.params.get('SSL_KEY_PATH')
|
||||
|| parsed.params.get('ssl_key_path')
|
||||
|| parsed.params.get('sslKeyPath')
|
||||
|| ''
|
||||
).trim();
|
||||
parsedValues.sslCertPath = certPath;
|
||||
parsedValues.sslKeyPath = keyPath;
|
||||
if (certPath || keyPath) {
|
||||
parsedValues.useSSL = true;
|
||||
parsedValues.sslMode = 'required';
|
||||
}
|
||||
} else if (type === 'oracle') {
|
||||
const ssl = String(parsed.params.get('SSL') || parsed.params.get('ssl') || '').trim().toLowerCase();
|
||||
const sslVerify = String(
|
||||
parsed.params.get('SSL VERIFY')
|
||||
|| parsed.params.get('ssl verify')
|
||||
|| parsed.params.get('SSL_VERIFY')
|
||||
|| parsed.params.get('ssl_verify')
|
||||
|| ''
|
||||
).trim().toLowerCase();
|
||||
if (ssl) {
|
||||
parsedValues.useSSL = normalizeBool(ssl);
|
||||
if (!parsedValues.useSSL) {
|
||||
parsedValues.sslMode = 'disable';
|
||||
} else {
|
||||
parsedValues.sslMode = normalizeBool(sslVerify || 'true') ? 'required' : 'skip-verify';
|
||||
}
|
||||
}
|
||||
} else if (type === 'tdengine') {
|
||||
const protocol = String(parsed.params.get('protocol') || '').trim().toLowerCase();
|
||||
const skipVerify = normalizeBool(parsed.params.get('skip_verify'));
|
||||
if (protocol === 'wss') {
|
||||
parsedValues.useSSL = true;
|
||||
parsedValues.sslMode = skipVerify ? 'skip-verify' : 'required';
|
||||
} else if (protocol === 'ws') {
|
||||
parsedValues.useSSL = false;
|
||||
parsedValues.sslMode = 'disable';
|
||||
}
|
||||
}
|
||||
};
|
||||
return parsedValues;
|
||||
}
|
||||
|
||||
return null;
|
||||
@@ -609,6 +744,16 @@ const ConnectionModal: React.FC<{
|
||||
if (hosts.length > 1 || values.mysqlTopology === 'replica') {
|
||||
params.set('topology', 'replica');
|
||||
}
|
||||
if (values.useSSL) {
|
||||
const mode = String(values.sslMode || 'preferred').trim().toLowerCase();
|
||||
if (mode === 'required') {
|
||||
params.set('tls', 'true');
|
||||
} else if (mode === 'skip-verify') {
|
||||
params.set('tls', 'skip-verify');
|
||||
} else {
|
||||
params.set('tls', 'preferred');
|
||||
}
|
||||
}
|
||||
if (Number.isFinite(timeout) && timeout > 0) {
|
||||
params.set('timeout', String(timeout));
|
||||
}
|
||||
@@ -634,8 +779,15 @@ const ConnectionModal: React.FC<{
|
||||
? Math.max(0, Math.min(15, Math.trunc(Number(values.redisDB))))
|
||||
: 0;
|
||||
const dbPath = `/${redisDB}`;
|
||||
if (values.useSSL) {
|
||||
const mode = String(values.sslMode || 'preferred').trim().toLowerCase();
|
||||
if (mode === 'skip-verify' || mode === 'preferred') {
|
||||
params.set('skip_verify', 'true');
|
||||
}
|
||||
}
|
||||
const query = params.toString();
|
||||
return `redis://${redisAuth}${hosts.join(',')}${dbPath}${query ? `?${query}` : ''}`;
|
||||
const scheme = values.useSSL ? 'rediss' : 'redis';
|
||||
return `${scheme}://${redisAuth}${hosts.join(',')}${dbPath}${query ? `?${query}` : ''}`;
|
||||
}
|
||||
|
||||
if (isFileDatabaseType(type)) {
|
||||
@@ -675,6 +827,15 @@ const ConnectionModal: React.FC<{
|
||||
if (authMechanism) {
|
||||
params.set('authMechanism', authMechanism);
|
||||
}
|
||||
if (values.useSSL) {
|
||||
const mode = String(values.sslMode || 'preferred').trim().toLowerCase();
|
||||
params.set('tls', 'true');
|
||||
if (mode === 'skip-verify' || mode === 'preferred') {
|
||||
params.set('tlsInsecure', 'true');
|
||||
} else {
|
||||
params.delete('tlsInsecure');
|
||||
}
|
||||
}
|
||||
if (Number.isFinite(timeout) && timeout > 0) {
|
||||
params.set('connectTimeoutMS', String(timeout * 1000));
|
||||
params.set('serverSelectionTimeoutMS', String(timeout * 1000));
|
||||
@@ -686,7 +847,45 @@ const ConnectionModal: React.FC<{
|
||||
|
||||
const scheme = type === 'postgres' ? 'postgresql' : type;
|
||||
const dbPath = database ? `/${encodeURIComponent(database)}` : '';
|
||||
return `${scheme}://${encodedAuth}${toAddress(host, port, defaultPort)}${dbPath}`;
|
||||
const params = new URLSearchParams();
|
||||
if (supportsSSLForType(type) && values.useSSL) {
|
||||
const mode = String(values.sslMode || 'preferred').trim().toLowerCase();
|
||||
if (type === 'postgres' || type === 'kingbase' || type === 'highgo' || type === 'vastbase') {
|
||||
params.set('sslmode', 'require');
|
||||
} else if (type === 'sqlserver') {
|
||||
params.set('encrypt', 'true');
|
||||
params.set('TrustServerCertificate', mode === 'skip-verify' || mode === 'preferred' ? 'true' : 'false');
|
||||
} else if (type === 'clickhouse') {
|
||||
params.set('secure', 'true');
|
||||
if (mode === 'skip-verify' || mode === 'preferred') {
|
||||
params.set('skip_verify', 'true');
|
||||
}
|
||||
} else if (type === 'dameng') {
|
||||
const certPath = String(values.sslCertPath || '').trim();
|
||||
const keyPath = String(values.sslKeyPath || '').trim();
|
||||
if (certPath) params.set('SSL_CERT_PATH', certPath);
|
||||
if (keyPath) params.set('SSL_KEY_PATH', keyPath);
|
||||
} else if (type === 'oracle') {
|
||||
params.set('SSL', 'TRUE');
|
||||
params.set('SSL VERIFY', mode === 'required' ? 'TRUE' : 'FALSE');
|
||||
} else if (type === 'tdengine') {
|
||||
params.set('protocol', 'wss');
|
||||
if (mode === 'skip-verify' || mode === 'preferred') {
|
||||
params.set('skip_verify', 'true');
|
||||
}
|
||||
}
|
||||
} else if (supportsSSLForType(type)) {
|
||||
if (type === 'postgres' || type === 'kingbase' || type === 'highgo' || type === 'vastbase') {
|
||||
params.set('sslmode', 'disable');
|
||||
} else if (type === 'sqlserver') {
|
||||
params.set('encrypt', 'disable');
|
||||
params.set('TrustServerCertificate', 'true');
|
||||
} else if (type === 'tdengine') {
|
||||
params.set('protocol', 'ws');
|
||||
}
|
||||
}
|
||||
const query = params.toString();
|
||||
return `${scheme}://${encodedAuth}${toAddress(host, port, defaultPort)}${dbPath}${query ? `?${query}` : ''}`;
|
||||
};
|
||||
|
||||
const handleGenerateURI = () => {
|
||||
@@ -838,6 +1037,10 @@ const ConnectionModal: React.FC<{
|
||||
uri: config.uri || '',
|
||||
includeDatabases: initialValues.includeDatabases,
|
||||
includeRedisDatabases: initialValues.includeRedisDatabases,
|
||||
useSSL: !!config.useSSL,
|
||||
sslMode: config.sslMode || 'preferred',
|
||||
sslCertPath: config.sslCertPath || '',
|
||||
sslKeyPath: config.sslKeyPath || '',
|
||||
useSSH: config.useSSH,
|
||||
sshHost: config.ssh?.host,
|
||||
sshPort: config.ssh?.port,
|
||||
@@ -871,6 +1074,7 @@ const ConnectionModal: React.FC<{
|
||||
mongoReplicaUser: config.mongoReplicaUser || '',
|
||||
mongoReplicaPassword: config.mongoReplicaPassword || ''
|
||||
});
|
||||
setUseSSL(!!config.useSSL);
|
||||
setUseSSH(config.useSSH || false);
|
||||
setUseProxy(config.useProxy || false);
|
||||
setDbType(configType);
|
||||
@@ -882,6 +1086,7 @@ const ConnectionModal: React.FC<{
|
||||
// Create mode: Start at step 1
|
||||
setStep(1);
|
||||
form.resetFields();
|
||||
setUseSSL(false);
|
||||
setUseSSH(false);
|
||||
setUseProxy(false);
|
||||
setDbType('mysql');
|
||||
@@ -932,6 +1137,7 @@ const ConnectionModal: React.FC<{
|
||||
|
||||
setLoading(false);
|
||||
form.resetFields();
|
||||
setUseSSL(false);
|
||||
setUseSSH(false);
|
||||
setUseProxy(false);
|
||||
setDbType('mysql');
|
||||
@@ -1073,6 +1279,21 @@ const ConnectionModal: React.FC<{
|
||||
const type = String(mergedValues.type || '').toLowerCase();
|
||||
const defaultPort = getDefaultPortByType(type);
|
||||
const isFileDbType = isFileDatabaseType(type);
|
||||
const sslCapableType = supportsSSLForType(type);
|
||||
const sslModeRaw = String(mergedValues.sslMode || 'preferred').trim().toLowerCase();
|
||||
const sslMode: 'preferred' | 'required' | 'skip-verify' | 'disable' = sslModeRaw === 'required'
|
||||
? 'required'
|
||||
: sslModeRaw === 'skip-verify'
|
||||
? 'skip-verify'
|
||||
: sslModeRaw === 'disable'
|
||||
? 'disable'
|
||||
: 'preferred';
|
||||
const effectiveUseSSL = sslCapableType && !!mergedValues.useSSL;
|
||||
const sslCertPath = sslCapableType ? String(mergedValues.sslCertPath || '').trim() : '';
|
||||
const sslKeyPath = sslCapableType ? String(mergedValues.sslKeyPath || '').trim() : '';
|
||||
if (type === 'dameng' && effectiveUseSSL && (!sslCertPath || !sslKeyPath)) {
|
||||
throw new Error('达梦启用 SSL 时必须填写证书路径与私钥路径');
|
||||
}
|
||||
|
||||
let primaryHost = 'localhost';
|
||||
let primaryPort = defaultPort;
|
||||
@@ -1194,6 +1415,10 @@ const ConnectionModal: React.FC<{
|
||||
password: keepPassword ? (mergedValues.password || "") : "",
|
||||
savePassword: savePassword,
|
||||
database: mergedValues.database || "",
|
||||
useSSL: effectiveUseSSL,
|
||||
sslMode: effectiveUseSSL ? sslMode : 'disable',
|
||||
sslCertPath: sslCertPath,
|
||||
sslKeyPath: sslKeyPath,
|
||||
useSSH: !!mergedValues.useSSH,
|
||||
ssh: sshConfig,
|
||||
useProxy: effectiveUseProxy,
|
||||
@@ -1233,6 +1458,7 @@ const ConnectionModal: React.FC<{
|
||||
|
||||
const defaultPort = getDefaultPortByType(type);
|
||||
if (isFileDatabaseType(type)) {
|
||||
setUseSSL(false);
|
||||
setUseSSH(false);
|
||||
setUseProxy(false);
|
||||
form.setFieldsValue({
|
||||
@@ -1241,6 +1467,10 @@ const ConnectionModal: React.FC<{
|
||||
user: '',
|
||||
password: '',
|
||||
database: '',
|
||||
useSSL: false,
|
||||
sslMode: 'preferred',
|
||||
sslCertPath: '',
|
||||
sslKeyPath: '',
|
||||
useSSH: false,
|
||||
sshHost: '',
|
||||
sshPort: 22,
|
||||
@@ -1273,10 +1503,16 @@ const ConnectionModal: React.FC<{
|
||||
});
|
||||
} else if (type !== 'custom') {
|
||||
const defaultUser = type === 'clickhouse' ? 'default' : 'root';
|
||||
const sslCapableType = supportsSSLForType(type);
|
||||
setUseSSL(false);
|
||||
form.setFieldsValue({
|
||||
user: defaultUser,
|
||||
database: '',
|
||||
port: defaultPort,
|
||||
useSSL: sslCapableType ? false : undefined,
|
||||
sslMode: sslCapableType ? 'preferred' : undefined,
|
||||
sslCertPath: sslCapableType ? '' : undefined,
|
||||
sslKeyPath: sslCapableType ? '' : undefined,
|
||||
mysqlTopology: 'single',
|
||||
redisTopology: 'single',
|
||||
mongoTopology: 'single',
|
||||
@@ -1420,6 +1656,10 @@ const ConnectionModal: React.FC<{
|
||||
port: 3306,
|
||||
database: '',
|
||||
user: 'root',
|
||||
useSSL: false,
|
||||
sslMode: 'preferred',
|
||||
sslCertPath: '',
|
||||
sslKeyPath: '',
|
||||
useSSH: false,
|
||||
sshPort: 22,
|
||||
useProxy: false,
|
||||
@@ -1451,6 +1691,7 @@ const ConnectionModal: React.FC<{
|
||||
if (changed.uri !== undefined || changed.type !== undefined) {
|
||||
setUriFeedback(null);
|
||||
}
|
||||
if (changed.useSSL !== undefined) setUseSSL(changed.useSSL);
|
||||
if (changed.useSSH !== undefined) setUseSSH(changed.useSSH);
|
||||
if (changed.useProxy !== undefined) setUseProxy(changed.useProxy);
|
||||
if (changed.proxyType !== undefined) {
|
||||
@@ -1835,6 +2076,56 @@ const ConnectionModal: React.FC<{
|
||||
|
||||
{!isFileDb && (
|
||||
<>
|
||||
{isSSLType && (
|
||||
<>
|
||||
<Divider style={{ margin: '12px 0' }} />
|
||||
<Form.Item name="useSSL" valuePropName="checked" style={{ marginBottom: 0 }}>
|
||||
<Checkbox>使用 SSL/TLS</Checkbox>
|
||||
</Form.Item>
|
||||
{useSSL && (
|
||||
<div style={tunnelSectionStyle}>
|
||||
<Form.Item
|
||||
name="sslMode"
|
||||
label="SSL 模式"
|
||||
rules={[{ required: true, message: '请选择 SSL 模式' }]}
|
||||
style={{ marginBottom: 8 }}
|
||||
>
|
||||
<Select
|
||||
options={[
|
||||
{ value: 'preferred', label: 'Preferred(优先 SSL,推荐)' },
|
||||
{ value: 'required', label: 'Required(必须 SSL,校验证书)' },
|
||||
{ value: 'skip-verify', label: 'Skip Verify(必须 SSL,跳过证书校验)' },
|
||||
]}
|
||||
/>
|
||||
</Form.Item>
|
||||
{dbType === 'dameng' && (
|
||||
<>
|
||||
<Form.Item
|
||||
name="sslCertPath"
|
||||
label="客户端证书路径 (SSL_CERT_PATH)"
|
||||
rules={[{ required: true, message: '达梦 SSL 需要证书路径' }]}
|
||||
style={{ marginBottom: 8 }}
|
||||
>
|
||||
<Input placeholder="例如: C:\\certs\\client-cert.pem" />
|
||||
</Form.Item>
|
||||
<Form.Item
|
||||
name="sslKeyPath"
|
||||
label="客户端私钥路径 (SSL_KEY_PATH)"
|
||||
rules={[{ required: true, message: '达梦 SSL 需要私钥路径' }]}
|
||||
style={{ marginBottom: 8 }}
|
||||
>
|
||||
<Input placeholder="例如: C:\\certs\\client-key.pem" />
|
||||
</Form.Item>
|
||||
</>
|
||||
)}
|
||||
<Text type="secondary" style={{ fontSize: 12 }}>
|
||||
{sslHintText}
|
||||
</Text>
|
||||
</div>
|
||||
)}
|
||||
</>
|
||||
)}
|
||||
|
||||
<Divider style={{ margin: '12px 0' }} />
|
||||
<Form.Item name="useSSH" valuePropName="checked" style={{ marginBottom: 0 }}>
|
||||
<Checkbox>使用 SSH 隧道 (SSH Tunnel)</Checkbox>
|
||||
|
||||
@@ -580,6 +580,7 @@ interface DataGridProps {
|
||||
onToggleFilter?: () => void;
|
||||
exportSqlWithFilter?: string;
|
||||
onApplyFilter?: (conditions: GridFilterCondition[]) => void;
|
||||
appliedFilterConditions?: FilterCondition[];
|
||||
}
|
||||
|
||||
type GridFilterCondition = FilterCondition & {
|
||||
@@ -599,7 +600,7 @@ type ColumnMeta = {
|
||||
|
||||
const DataGrid: React.FC<DataGridProps> = ({
|
||||
data, columnNames, loading, tableName, exportScope = 'table', resultSql, dbName, connectionId, pkColumns = [], readOnly = false,
|
||||
onReload, onSort, onPageChange, pagination, onRequestTotalCount, onCancelTotalCount, sortInfoExternal, showFilter, onToggleFilter, exportSqlWithFilter, onApplyFilter
|
||||
onReload, onSort, onPageChange, pagination, onRequestTotalCount, onCancelTotalCount, sortInfoExternal, showFilter, onToggleFilter, exportSqlWithFilter, onApplyFilter, appliedFilterConditions
|
||||
}) => {
|
||||
const connections = useStore(state => state.connections);
|
||||
const addSqlLog = useStore(state => state.addSqlLog);
|
||||
@@ -1064,10 +1065,40 @@ const DataGrid: React.FC<DataGridProps> = ({
|
||||
const [modifiedRows, setModifiedRows] = useState<Record<string, any>>({});
|
||||
const [deletedRowKeys, setDeletedRowKeys] = useState<Set<string>>(new Set());
|
||||
|
||||
const normalizeFilterLogic = useCallback((logic: unknown): 'AND' | 'OR' => {
|
||||
return String(logic || '').trim().toUpperCase() === 'OR' ? 'OR' : 'AND';
|
||||
}, []);
|
||||
|
||||
const normalizeGridFilterConditions = useCallback((conditions?: FilterCondition[]): GridFilterCondition[] => {
|
||||
if (!Array.isArray(conditions)) return [];
|
||||
return conditions.map((cond, index) => {
|
||||
const fallbackId = index + 1;
|
||||
const nextId = Number.isFinite(Number(cond?.id)) ? Number(cond?.id) : fallbackId;
|
||||
const op = String(cond?.op || '=');
|
||||
const rawColumn = String(cond?.column || '');
|
||||
return {
|
||||
id: nextId,
|
||||
enabled: cond?.enabled !== false,
|
||||
logic: normalizeFilterLogic(cond?.logic),
|
||||
column: rawColumn || (op === 'CUSTOM' ? '' : String(columnNames[0] || '')),
|
||||
op,
|
||||
value: String(cond?.value ?? ''),
|
||||
value2: String(cond?.value2 ?? ''),
|
||||
};
|
||||
});
|
||||
}, [columnNames, normalizeFilterLogic]);
|
||||
|
||||
// Filter State
|
||||
const [filterConditions, setFilterConditions] = useState<GridFilterCondition[]>([]);
|
||||
const [nextFilterId, setNextFilterId] = useState(1);
|
||||
|
||||
useEffect(() => {
|
||||
const nextConditions = normalizeGridFilterConditions(appliedFilterConditions);
|
||||
setFilterConditions(nextConditions);
|
||||
const maxId = nextConditions.reduce((max, cond) => (cond.id > max ? cond.id : max), 0);
|
||||
setNextFilterId(Math.max(1, maxId + 1));
|
||||
}, [appliedFilterConditions, normalizeGridFilterConditions]);
|
||||
|
||||
const selectedRowKeysRef = useRef(selectedRowKeys);
|
||||
const displayDataRef = useRef<any[]>([]);
|
||||
|
||||
@@ -2547,6 +2578,10 @@ const DataGrid: React.FC<DataGridProps> = ({
|
||||
{ value: 'NOT_IN', label: '不在列表' },
|
||||
{ value: 'CUSTOM', label: '[自定义]' },
|
||||
]), []);
|
||||
const filterLogicOptions = useMemo(() => ([
|
||||
{ value: 'AND', label: '且 (AND)' },
|
||||
{ value: 'OR', label: '或 (OR)' },
|
||||
]), []);
|
||||
|
||||
const isNoValueOp = useCallback((op: string) => (
|
||||
op === 'IS_NULL' || op === 'IS_NOT_NULL' || op === 'IS_EMPTY' || op === 'IS_NOT_EMPTY'
|
||||
@@ -2555,7 +2590,18 @@ const DataGrid: React.FC<DataGridProps> = ({
|
||||
const isListOp = useCallback((op: string) => op === 'IN' || op === 'NOT_IN', []);
|
||||
|
||||
const addFilter = () => {
|
||||
setFilterConditions([...filterConditions, { id: nextFilterId, enabled: true, column: columnNames[0] || '', op: '=', value: '', value2: '' }]);
|
||||
setFilterConditions([
|
||||
...filterConditions,
|
||||
{
|
||||
id: nextFilterId,
|
||||
enabled: true,
|
||||
logic: 'AND',
|
||||
column: columnNames[0] || '',
|
||||
op: '=',
|
||||
value: '',
|
||||
value2: '',
|
||||
}
|
||||
]);
|
||||
setNextFilterId(nextFilterId + 1);
|
||||
};
|
||||
const updateFilter = (id: number, field: keyof GridFilterCondition, val: string | boolean) => {
|
||||
@@ -3066,7 +3112,7 @@ const DataGrid: React.FC<DataGridProps> = ({
|
||||
background: 'transparent',
|
||||
boxSizing: 'border-box',
|
||||
}}>
|
||||
{filterConditions.map(cond => (
|
||||
{filterConditions.map((cond, condIndex) => (
|
||||
<div key={cond.id} style={{ display: 'flex', gap: 8, marginBottom: 8, alignItems: 'flex-start', opacity: cond.enabled === false ? 0.58 : 1 }}>
|
||||
<Checkbox
|
||||
checked={cond.enabled !== false}
|
||||
@@ -3075,13 +3121,33 @@ const DataGrid: React.FC<DataGridProps> = ({
|
||||
>
|
||||
启用
|
||||
</Checkbox>
|
||||
<Select
|
||||
style={{ width: 180 }}
|
||||
value={cond.column}
|
||||
onChange={v => updateFilter(cond.id, 'column', v)}
|
||||
options={columnNames.map(c => ({ value: c, label: c }))}
|
||||
disabled={cond.op === 'CUSTOM'}
|
||||
/>
|
||||
{condIndex === 0 ? (
|
||||
<div style={{ width: 96, marginTop: 7, textAlign: 'center', fontSize: 12, color: '#8c8c8c' }}>
|
||||
首条
|
||||
</div>
|
||||
) : (
|
||||
<Select
|
||||
style={{ width: 96 }}
|
||||
value={cond.logic === 'OR' ? 'OR' : 'AND'}
|
||||
onChange={v => updateFilter(cond.id, 'logic', v)}
|
||||
options={filterLogicOptions as any}
|
||||
/>
|
||||
)}
|
||||
<Select
|
||||
style={{ width: 180 }}
|
||||
value={cond.column}
|
||||
onChange={v => updateFilter(cond.id, 'column', v)}
|
||||
options={columnNames.map(c => ({ value: c, label: c }))}
|
||||
showSearch
|
||||
optionFilterProp="label"
|
||||
filterOption={(input, option) =>
|
||||
String(option?.label ?? '')
|
||||
.toLowerCase()
|
||||
.includes(String(input || '').trim().toLowerCase())
|
||||
}
|
||||
placeholder="搜索字段名"
|
||||
disabled={cond.op === 'CUSTOM'}
|
||||
/>
|
||||
<Select
|
||||
style={{ width: 140 }}
|
||||
value={cond.op}
|
||||
|
||||
@@ -5,6 +5,7 @@ import { useStore } from '../store';
|
||||
import { DBQuery, DBGetColumns } from '../../wailsjs/go/app/App';
|
||||
import DataGrid, { GONAVI_ROW_KEY } from './DataGrid';
|
||||
import { buildOrderBySQL, buildWhereSQL, quoteIdentPart, quoteQualifiedIdent, withSortBufferTuningSQL, type FilterCondition } from '../utils/sql';
|
||||
import { buildMongoCountCommand, buildMongoFilter, buildMongoFindCommand, buildMongoSort } from '../utils/mongodb';
|
||||
import { getDataSourceCapabilities } from '../utils/dataSourceCapabilities';
|
||||
|
||||
type ViewerPaginationState = {
|
||||
@@ -151,6 +152,37 @@ const reverseOrderBySQL = (orderBySQL: string): string => {
|
||||
return ` ORDER BY ${parts.join(', ')}`;
|
||||
};
|
||||
|
||||
type ViewerFilterSnapshot = {
|
||||
showFilter: boolean;
|
||||
conditions: FilterCondition[];
|
||||
};
|
||||
|
||||
const viewerFilterSnapshotsByTab = new Map<string, ViewerFilterSnapshot>();
|
||||
|
||||
const normalizeViewerFilterConditions = (conditions: FilterCondition[] | undefined): FilterCondition[] => {
|
||||
if (!Array.isArray(conditions)) return [];
|
||||
return conditions.map((cond) => ({
|
||||
id: Number.isFinite(Number(cond?.id)) ? Number(cond?.id) : undefined,
|
||||
enabled: cond?.enabled !== false,
|
||||
logic: String(cond?.logic || '').trim().toUpperCase() === 'OR' ? 'OR' : 'AND',
|
||||
column: String(cond?.column || ''),
|
||||
op: String(cond?.op || '='),
|
||||
value: String(cond?.value ?? ''),
|
||||
value2: String(cond?.value2 ?? ''),
|
||||
}));
|
||||
};
|
||||
|
||||
const getViewerFilterSnapshot = (tabId: string): ViewerFilterSnapshot => {
|
||||
const cached = viewerFilterSnapshotsByTab.get(String(tabId || '').trim());
|
||||
if (!cached) {
|
||||
return { showFilter: false, conditions: [] };
|
||||
}
|
||||
return {
|
||||
showFilter: cached.showFilter === true,
|
||||
conditions: normalizeViewerFilterConditions(cached.conditions),
|
||||
};
|
||||
};
|
||||
|
||||
const DataViewer: React.FC<{ tab: TabData }> = ({ tab }) => {
|
||||
const [data, setData] = useState<any[]>([]);
|
||||
const [columnNames, setColumnNames] = useState<string[]>([]);
|
||||
@@ -185,14 +217,27 @@ const DataViewer: React.FC<{ tab: TabData }> = ({ tab }) => {
|
||||
|
||||
const [sortInfo, setSortInfo] = useState<{ columnKey: string, order: string } | null>(null);
|
||||
|
||||
const [showFilter, setShowFilter] = useState(false);
|
||||
const [filterConditions, setFilterConditions] = useState<FilterCondition[]>([]);
|
||||
const [showFilter, setShowFilter] = useState<boolean>(() => getViewerFilterSnapshot(tab.id).showFilter);
|
||||
const [filterConditions, setFilterConditions] = useState<FilterCondition[]>(() => getViewerFilterSnapshot(tab.id).conditions);
|
||||
const duckdbSafeSelectCacheRef = useRef<Record<string, string>>({});
|
||||
const currentConnConfig = connections.find(c => c.id === tab.connectionId)?.config;
|
||||
const currentConnCaps = getDataSourceCapabilities(currentConnConfig);
|
||||
const currentConnType = currentConnCaps.type;
|
||||
const forceReadOnly = currentConnCaps.forceReadOnlyQueryResult;
|
||||
|
||||
useEffect(() => {
|
||||
const snapshot = getViewerFilterSnapshot(tab.id);
|
||||
setShowFilter(snapshot.showFilter);
|
||||
setFilterConditions(snapshot.conditions);
|
||||
}, [tab.id]);
|
||||
|
||||
useEffect(() => {
|
||||
viewerFilterSnapshotsByTab.set(tab.id, {
|
||||
showFilter,
|
||||
conditions: normalizeViewerFilterConditions(filterConditions),
|
||||
});
|
||||
}, [tab.id, showFilter, filterConditions]);
|
||||
|
||||
useEffect(() => {
|
||||
setPkColumns([]);
|
||||
pkKeyRef.current = '';
|
||||
@@ -315,42 +360,67 @@ const DataViewer: React.FC<{ tab: TabData }> = ({ tab }) => {
|
||||
|
||||
const dbName = tab.dbName || '';
|
||||
const tableName = tab.tableName || '';
|
||||
const isMongoDB = dbTypeLower === 'mongodb';
|
||||
let mongoFilter: Record<string, unknown> | undefined;
|
||||
if (isMongoDB) {
|
||||
try {
|
||||
mongoFilter = buildMongoFilter(filterConditions);
|
||||
} catch (e: any) {
|
||||
message.error(`Mongo 筛选条件无效:${String(e?.message || e || '解析失败')}`);
|
||||
if (fetchSeqRef.current === seq) setLoading(false);
|
||||
return;
|
||||
}
|
||||
}
|
||||
|
||||
const whereSQL = buildWhereSQL(dbType, filterConditions);
|
||||
|
||||
const countSql = `SELECT COUNT(*) as total FROM ${quoteQualifiedIdent(dbType, tableName)} ${whereSQL}`;
|
||||
|
||||
const baseSql = `SELECT * FROM ${quoteQualifiedIdent(dbType, tableName)} ${whereSQL}`;
|
||||
const orderBySQL = buildOrderBySQL(dbType, sortInfo, pkColumns);
|
||||
let sql = `${baseSql}${orderBySQL}`;
|
||||
const whereSQL = isMongoDB
|
||||
? JSON.stringify(mongoFilter || {})
|
||||
: buildWhereSQL(dbType, filterConditions);
|
||||
const countSql = isMongoDB
|
||||
? buildMongoCountCommand(tableName, mongoFilter || {})
|
||||
: `SELECT COUNT(*) as total FROM ${quoteQualifiedIdent(dbType, tableName)} ${whereSQL}`;
|
||||
const orderBySQL = isMongoDB ? '' : buildOrderBySQL(dbType, sortInfo, pkColumns);
|
||||
const totalRows = Number(pagination.total);
|
||||
const hasFiniteTotal = Number.isFinite(totalRows) && totalRows >= 0;
|
||||
const totalKnown = pagination.totalKnown && hasFiniteTotal;
|
||||
const totalPages = hasFiniteTotal ? Math.max(1, Math.ceil(totalRows / size)) : 0;
|
||||
const currentPage = totalPages > 0 ? Math.min(Math.max(1, page), totalPages) : Math.max(1, page);
|
||||
const offset = (currentPage - 1) * size;
|
||||
const isClickHouse = dbTypeLower === 'clickhouse';
|
||||
const isClickHouse = !isMongoDB && dbTypeLower === 'clickhouse';
|
||||
const reverseOrderSQL = isClickHouse ? reverseOrderBySQL(orderBySQL) : '';
|
||||
let useClickHouseReversePagination = false;
|
||||
let clickHouseReverseLimit = 0;
|
||||
let clickHouseReverseHasMore = false;
|
||||
// ClickHouse 深分页在超大 OFFSET 下容易超时。对于总数已知且存在 ORDER BY 的场景,
|
||||
// 当“尾部偏移”小于“头部偏移”时,改为反向 ORDER BY + 小 OFFSET,并在前端翻转结果。
|
||||
if (isClickHouse && totalKnown && offset > 0 && reverseOrderSQL) {
|
||||
const pageRowCount = Math.max(0, Math.min(size, totalRows - offset));
|
||||
if (pageRowCount > 0) {
|
||||
const tailOffset = Math.max(0, totalRows - (offset + pageRowCount));
|
||||
if (tailOffset < offset) {
|
||||
sql = `${baseSql}${reverseOrderSQL} LIMIT ${pageRowCount} OFFSET ${tailOffset}`;
|
||||
useClickHouseReversePagination = true;
|
||||
clickHouseReverseLimit = pageRowCount;
|
||||
clickHouseReverseHasMore = currentPage < totalPages;
|
||||
let sql = '';
|
||||
if (isMongoDB) {
|
||||
const mongoSort = buildMongoSort(sortInfo, pkColumns);
|
||||
sql = buildMongoFindCommand({
|
||||
collection: tableName,
|
||||
filter: mongoFilter || {},
|
||||
sort: mongoSort,
|
||||
limit: size + 1,
|
||||
skip: offset,
|
||||
});
|
||||
} else {
|
||||
const baseSql = `SELECT * FROM ${quoteQualifiedIdent(dbType, tableName)} ${whereSQL}`;
|
||||
sql = `${baseSql}${orderBySQL}`;
|
||||
// ClickHouse 深分页在超大 OFFSET 下容易超时。对于总数已知且存在 ORDER BY 的场景,
|
||||
// 当“尾部偏移”小于“头部偏移”时,改为反向 ORDER BY + 小 OFFSET,并在前端翻转结果。
|
||||
if (isClickHouse && totalKnown && offset > 0 && reverseOrderSQL) {
|
||||
const pageRowCount = Math.max(0, Math.min(size, totalRows - offset));
|
||||
if (pageRowCount > 0) {
|
||||
const tailOffset = Math.max(0, totalRows - (offset + pageRowCount));
|
||||
if (tailOffset < offset) {
|
||||
sql = `${baseSql}${reverseOrderSQL} LIMIT ${pageRowCount} OFFSET ${tailOffset}`;
|
||||
useClickHouseReversePagination = true;
|
||||
clickHouseReverseLimit = pageRowCount;
|
||||
clickHouseReverseHasMore = currentPage < totalPages;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
if (!useClickHouseReversePagination) {
|
||||
// 大表性能:打开表不阻塞在 COUNT(*),先通过多取 1 条判断是否还有下一页;总数在后台统计并异步回填。
|
||||
sql += ` LIMIT ${size + 1} OFFSET ${offset}`;
|
||||
if (!useClickHouseReversePagination) {
|
||||
// 大表性能:打开表不阻塞在 COUNT(*),先通过多取 1 条判断是否还有下一页;总数在后台统计并异步回填。
|
||||
sql += ` LIMIT ${size + 1} OFFSET ${offset}`;
|
||||
}
|
||||
}
|
||||
|
||||
const requestStartTime = Date.now();
|
||||
@@ -718,6 +788,7 @@ const DataViewer: React.FC<{ tab: TabData }> = ({ tab }) => {
|
||||
showFilter={showFilter}
|
||||
onToggleFilter={handleToggleFilter}
|
||||
onApplyFilter={handleApplyFilter}
|
||||
appliedFilterConditions={filterConditions}
|
||||
readOnly={forceReadOnly}
|
||||
sortInfoExternal={sortInfo}
|
||||
exportSqlWithFilter={exportSqlWithFilter || undefined}
|
||||
|
||||
@@ -9,6 +9,8 @@ import { useStore } from '../store';
|
||||
import { DBQuery, DBQueryWithCancel, DBGetTables, DBGetAllColumns, DBGetDatabases, DBGetColumns, CancelQuery, GenerateQueryID } from '../../wailsjs/go/app/App';
|
||||
import DataGrid, { GONAVI_ROW_KEY } from './DataGrid';
|
||||
import { getDataSourceCapabilities } from '../utils/dataSourceCapabilities';
|
||||
import { convertMongoShellToJsonCommand } from '../utils/mongodb';
|
||||
import { getShortcutDisplay, isEditableElement, isShortcutMatch } from '../utils/shortcuts';
|
||||
|
||||
const QueryEditor: React.FC<{ tab: TabData }> = ({ tab }) => {
|
||||
const [query, setQuery] = useState(tab.query || 'SELECT * FROM ');
|
||||
@@ -68,6 +70,8 @@ const QueryEditor: React.FC<{ tab: TabData }> = ({ tab }) => {
|
||||
const setSqlFormatOptions = useStore(state => state.setSqlFormatOptions);
|
||||
const queryOptions = useStore(state => state.queryOptions);
|
||||
const setQueryOptions = useStore(state => state.setQueryOptions);
|
||||
const shortcutOptions = useStore(state => state.shortcutOptions);
|
||||
const activeTabId = useStore(state => state.activeTabId);
|
||||
|
||||
useEffect(() => {
|
||||
currentConnectionIdRef.current = currentConnectionId;
|
||||
@@ -268,6 +272,19 @@ const QueryEditor: React.FC<{ tab: TabData }> = ({ tab }) => {
|
||||
return parts[parts.length - 1] || raw;
|
||||
};
|
||||
|
||||
const splitSchemaAndTable = (qualified: string): { schema: string; table: string } => {
|
||||
const raw = normalizeQualifiedName(qualified);
|
||||
if (!raw) return { schema: '', table: '' };
|
||||
const parts = raw.split('.').filter(Boolean);
|
||||
if (parts.length >= 2) {
|
||||
return {
|
||||
schema: parts[parts.length - 2] || '',
|
||||
table: parts[parts.length - 1] || '',
|
||||
};
|
||||
}
|
||||
return { schema: '', table: parts[0] || '' };
|
||||
};
|
||||
|
||||
const buildConnConfig = () => {
|
||||
const connId = currentConnectionIdRef.current;
|
||||
const conn = connectionsRef.current.find(c => c.id === connId);
|
||||
@@ -340,13 +357,14 @@ const QueryEditor: React.FC<{ tab: TabData }> = ({ tab }) => {
|
||||
if (qualifierMatch) {
|
||||
const qualifier = stripQuotes(qualifierMatch[1]);
|
||||
const prefix = (qualifierMatch[2] || '').toLowerCase();
|
||||
const qualifierLower = qualifier.toLowerCase();
|
||||
|
||||
// 首先检查 qualifier 是否是数据库名(跨库表提示)
|
||||
const visibleDbs = visibleDbsRef.current;
|
||||
if (visibleDbs.some(db => db.toLowerCase() === qualifier.toLowerCase())) {
|
||||
if (visibleDbs.some(db => db.toLowerCase() === qualifierLower)) {
|
||||
// qualifier 是数据库名,提示该库的表
|
||||
const tables = tablesRef.current.filter(t =>
|
||||
(t.dbName || '').toLowerCase() === qualifier.toLowerCase()
|
||||
(t.dbName || '').toLowerCase() === qualifierLower
|
||||
);
|
||||
const filtered = prefix
|
||||
? tables.filter(t => (t.tableName || '').toLowerCase().startsWith(prefix))
|
||||
@@ -363,6 +381,34 @@ const QueryEditor: React.FC<{ tab: TabData }> = ({ tab }) => {
|
||||
return { suggestions };
|
||||
}
|
||||
|
||||
// qualifier 是 schema(如 dbo/public)时,仅补全表名,避免输入 dbo. 后再补成 dbo.dbo.table
|
||||
const schemaTables = tablesRef.current
|
||||
.map(t => {
|
||||
const parsed = splitSchemaAndTable(t.tableName || '');
|
||||
return {
|
||||
dbName: t.dbName || '',
|
||||
schema: parsed.schema,
|
||||
table: parsed.table,
|
||||
};
|
||||
})
|
||||
.filter(t => t.schema.toLowerCase() === qualifierLower && !!t.table);
|
||||
|
||||
if (schemaTables.length > 0) {
|
||||
const filtered = prefix
|
||||
? schemaTables.filter(t => t.table.toLowerCase().startsWith(prefix))
|
||||
: schemaTables;
|
||||
|
||||
const suggestions = filtered.map(t => ({
|
||||
label: t.table,
|
||||
kind: monaco.languages.CompletionItemKind.Class,
|
||||
insertText: t.table,
|
||||
detail: `Table (${t.dbName}${t.schema ? '.' + t.schema : ''})`,
|
||||
range,
|
||||
sortText: '0' + t.table
|
||||
}));
|
||||
return { suggestions };
|
||||
}
|
||||
|
||||
// 否则检查是否是表别名或表名,提示列
|
||||
const reserved = new Set([
|
||||
'where', 'on', 'group', 'order', 'limit', 'having',
|
||||
@@ -531,6 +577,12 @@ const QueryEditor: React.FC<{ tab: TabData }> = ({ tab }) => {
|
||||
icon: sqlFormatOptions.keywordCase === 'lower' ? '✓' : undefined,
|
||||
onClick: () => setSqlFormatOptions({ keywordCase: 'lower' })
|
||||
},
|
||||
{ type: 'divider' },
|
||||
{
|
||||
key: 'shortcut-settings',
|
||||
label: '快捷键管理...',
|
||||
onClick: () => window.dispatchEvent(new CustomEvent('gonavi:open-shortcut-settings')),
|
||||
},
|
||||
];
|
||||
|
||||
const splitSQLStatements = (sql: string): string[] => {
|
||||
@@ -1035,7 +1087,15 @@ const QueryEditor: React.FC<{ tab: TabData }> = ({ tab }) => {
|
||||
|
||||
try {
|
||||
const rawSQL = getSelectedSQL() || query;
|
||||
const statements = splitSQLStatements(rawSQL);
|
||||
const dbType = String((config as any).type || 'mysql');
|
||||
const normalizedDbType = dbType.trim().toLowerCase();
|
||||
const normalizedRawSQL = String(rawSQL || '').replace(/;/g, ';');
|
||||
const splitInput = normalizedDbType === 'mongodb'
|
||||
? normalizedRawSQL
|
||||
.replace(/^\s*\/\/.*$/gm, '')
|
||||
.replace(/^\s*#.*$/gm, '')
|
||||
: normalizedRawSQL;
|
||||
const statements = splitSQLStatements(splitInput);
|
||||
if (statements.length === 0) {
|
||||
message.info('没有可执行的 SQL。');
|
||||
setResultSets([]);
|
||||
@@ -1045,7 +1105,6 @@ const QueryEditor: React.FC<{ tab: TabData }> = ({ tab }) => {
|
||||
|
||||
const nextResultSets: ResultSet[] = [];
|
||||
const maxRows = Number(queryOptions?.maxRows) || 0;
|
||||
const dbType = String((config as any).type || 'mysql');
|
||||
const forceReadOnlyResult = connCaps.forceReadOnlyQueryResult;
|
||||
const wantsLimitProbe = Number.isFinite(maxRows) && maxRows > 0;
|
||||
const probeLimit = wantsLimitProbe ? (maxRows + 1) : 0;
|
||||
@@ -1059,9 +1118,24 @@ const QueryEditor: React.FC<{ tab: TabData }> = ({ tab }) => {
|
||||
|
||||
const limitApplied = shouldAutoLimit && wantsLimitProbe;
|
||||
const limited = limitApplied ? applyAutoLimit(rawStatement, dbType, probeLimit) : { sql: rawStatement, applied: false, maxRows: probeLimit };
|
||||
const executedSql = limited.sql;
|
||||
let executedSql = limited.sql;
|
||||
if (String(dbType || '').trim().toLowerCase() === 'mongodb') {
|
||||
const shellConvert = convertMongoShellToJsonCommand(executedSql);
|
||||
if (shellConvert.recognized) {
|
||||
if (shellConvert.error) {
|
||||
const prefix = statements.length > 1 ? `第 ${idx + 1} 条语句执行失败:` : '';
|
||||
message.error(prefix + shellConvert.error);
|
||||
setResultSets([]);
|
||||
setActiveResultKey('');
|
||||
return;
|
||||
}
|
||||
if (shellConvert.command) {
|
||||
executedSql = shellConvert.command;
|
||||
}
|
||||
}
|
||||
}
|
||||
const startTime = Date.now();
|
||||
|
||||
|
||||
// Generate query ID for cancellation using backend UUID with fallback
|
||||
let queryId: string;
|
||||
try {
|
||||
@@ -1071,7 +1145,7 @@ const QueryEditor: React.FC<{ tab: TabData }> = ({ tab }) => {
|
||||
queryId = 'query-' + uuidv4();
|
||||
}
|
||||
setQueryId(queryId);
|
||||
|
||||
|
||||
const res = await DBQueryWithCancel(config as any, currentDb, executedSql, queryId);
|
||||
const duration = Date.now() - startTime;
|
||||
|
||||
@@ -1089,19 +1163,19 @@ const QueryEditor: React.FC<{ tab: TabData }> = ({ tab }) => {
|
||||
if (!res.success) {
|
||||
// 检查是否为查询取消错误
|
||||
const errorMsg = res.message.toLowerCase();
|
||||
const isCancelledError = errorMsg.includes('context canceled') ||
|
||||
const isCancelledError = errorMsg.includes('context canceled') ||
|
||||
errorMsg.includes('查询已取消') ||
|
||||
errorMsg.includes('canceled') ||
|
||||
errorMsg.includes('cancelled') ||
|
||||
errorMsg.includes('statement canceled') ||
|
||||
errorMsg.includes('sql: statement canceled');
|
||||
|
||||
|
||||
// 确保不是超时错误
|
||||
const isTimeoutError = errorMsg.includes('context deadline exceeded') ||
|
||||
errorMsg.includes('timeout') ||
|
||||
errorMsg.includes('超时') ||
|
||||
errorMsg.includes('deadline exceeded');
|
||||
|
||||
|
||||
if (isCancelledError && !isTimeoutError) {
|
||||
// 查询已被用户取消,不显示错误消息,清理状态
|
||||
setResultSets([]);
|
||||
@@ -1112,7 +1186,7 @@ const QueryEditor: React.FC<{ tab: TabData }> = ({ tab }) => {
|
||||
}
|
||||
return;
|
||||
}
|
||||
|
||||
|
||||
const prefix = statements.length > 1 ? `第 ${idx + 1} 条语句执行失败:` : '';
|
||||
message.error(prefix + res.message);
|
||||
setResultSets([]);
|
||||
@@ -1245,6 +1319,48 @@ const QueryEditor: React.FC<{ tab: TabData }> = ({ tab }) => {
|
||||
}
|
||||
};
|
||||
|
||||
useEffect(() => {
|
||||
const binding = shortcutOptions.runQuery;
|
||||
if (!binding?.enabled || !binding.combo) {
|
||||
return;
|
||||
}
|
||||
|
||||
const handleRunShortcut = (event: KeyboardEvent) => {
|
||||
if (activeTabId !== tab.id) {
|
||||
return;
|
||||
}
|
||||
if (!isShortcutMatch(event, binding.combo)) {
|
||||
return;
|
||||
}
|
||||
const editorHasFocus = !!editorRef.current?.hasTextFocus?.();
|
||||
if (!editorHasFocus && !isEditableElement(event.target)) {
|
||||
return;
|
||||
}
|
||||
event.preventDefault();
|
||||
event.stopPropagation();
|
||||
void handleRun();
|
||||
};
|
||||
|
||||
window.addEventListener('keydown', handleRunShortcut);
|
||||
return () => {
|
||||
window.removeEventListener('keydown', handleRunShortcut);
|
||||
};
|
||||
}, [activeTabId, tab.id, shortcutOptions.runQuery, handleRun]);
|
||||
|
||||
useEffect(() => {
|
||||
const handleRunActiveQuery = () => {
|
||||
if (activeTabId !== tab.id) {
|
||||
return;
|
||||
}
|
||||
void handleRun();
|
||||
};
|
||||
|
||||
window.addEventListener('gonavi:run-active-query', handleRunActiveQuery as EventListener);
|
||||
return () => {
|
||||
window.removeEventListener('gonavi:run-active-query', handleRunActiveQuery as EventListener);
|
||||
};
|
||||
}, [activeTabId, tab.id, handleRun]);
|
||||
|
||||
const handleSave = async () => {
|
||||
try {
|
||||
const values = await saveForm.validateFields();
|
||||
@@ -1357,9 +1473,17 @@ const QueryEditor: React.FC<{ tab: TabData }> = ({ tab }) => {
|
||||
/>
|
||||
</Tooltip>
|
||||
<Button.Group>
|
||||
<Button type="primary" icon={<PlayCircleOutlined />} onClick={handleRun} loading={loading}>
|
||||
运行
|
||||
</Button>
|
||||
<Tooltip
|
||||
title={
|
||||
shortcutOptions.runQuery?.enabled && shortcutOptions.runQuery?.combo
|
||||
? `运行(${getShortcutDisplay(shortcutOptions.runQuery.combo)})`
|
||||
: '运行'
|
||||
}
|
||||
>
|
||||
<Button type="primary" icon={<PlayCircleOutlined />} onClick={handleRun} loading={loading}>
|
||||
运行
|
||||
</Button>
|
||||
</Tooltip>
|
||||
{loading && (
|
||||
<Button type="primary" danger icon={<StopOutlined />} onClick={handleCancel}>
|
||||
停止
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
import React, { useEffect, useState, useMemo, useRef } from 'react';
|
||||
import { Tree, message, Dropdown, MenuProps, Input, Button, Modal, Form, Badge, Checkbox, Space, Select } from 'antd';
|
||||
import { Tree, message, Dropdown, MenuProps, Input, Button, Modal, Form, Badge, Checkbox, Space, Select, Popover, Tooltip } from 'antd';
|
||||
import {
|
||||
DatabaseOutlined,
|
||||
TableOutlined,
|
||||
@@ -50,6 +50,7 @@ type BatchTableExportMode = 'schema' | 'backup' | 'dataOnly';
|
||||
type BatchObjectType = 'table' | 'view';
|
||||
type BatchObjectFilterType = 'all' | BatchObjectType;
|
||||
type BatchSelectionScope = 'filtered' | 'all';
|
||||
type SearchScope = 'smart' | 'object' | 'database' | 'host' | 'tag';
|
||||
|
||||
interface BatchObjectItem {
|
||||
title: string;
|
||||
@@ -59,9 +60,23 @@ interface BatchObjectItem {
|
||||
dataRef: any;
|
||||
}
|
||||
|
||||
const SEARCH_SCOPE_OPTIONS: Array<{ value: SearchScope; label: string }> = [
|
||||
{ value: 'smart', label: '智能' },
|
||||
{ value: 'object', label: '表对象' },
|
||||
{ value: 'database', label: '库' },
|
||||
{ value: 'host', label: 'Host' },
|
||||
{ value: 'tag', label: '标签' },
|
||||
];
|
||||
|
||||
const SEARCH_SCOPE_LABEL_MAP: Record<SearchScope, string> = SEARCH_SCOPE_OPTIONS.reduce((acc, option) => {
|
||||
acc[option.value] = option.label;
|
||||
return acc;
|
||||
}, {} as Record<SearchScope, string>);
|
||||
|
||||
const Sidebar: React.FC<{ onEditConnection?: (conn: SavedConnection) => void }> = ({ onEditConnection }) => {
|
||||
const connections = useStore(state => state.connections);
|
||||
const savedQueries = useStore(state => state.savedQueries);
|
||||
const addConnection = useStore(state => state.addConnection);
|
||||
const addTab = useStore(state => state.addTab);
|
||||
const setActiveContext = useStore(state => state.setActiveContext);
|
||||
const removeConnection = useStore(state => state.removeConnection);
|
||||
@@ -94,6 +109,9 @@ const Sidebar: React.FC<{ onEditConnection?: (conn: SavedConnection) => void }>
|
||||
};
|
||||
const bgMain = getBg('#141414');
|
||||
const [searchValue, setSearchValue] = useState('');
|
||||
const [searchScopes, setSearchScopes] = useState<SearchScope[]>(['smart']);
|
||||
const [isSearchScopePopoverOpen, setIsSearchScopePopoverOpen] = useState(false);
|
||||
const searchInputRef = useRef<any>(null);
|
||||
const [expandedKeys, setExpandedKeys] = useState<React.Key[]>([]);
|
||||
const [autoExpandParent, setAutoExpandParent] = useState(true);
|
||||
const [loadedKeys, setLoadedKeys] = useState<React.Key[]>([]);
|
||||
@@ -116,6 +134,21 @@ const Sidebar: React.FC<{ onEditConnection?: (conn: SavedConnection) => void }>
|
||||
resizeObserver.observe(treeContainerRef.current);
|
||||
return () => resizeObserver.disconnect();
|
||||
}, []);
|
||||
|
||||
useEffect(() => {
|
||||
const handleFocusSidebarSearch = () => {
|
||||
const inputEl = searchInputRef.current?.input as HTMLInputElement | undefined;
|
||||
if (!inputEl) {
|
||||
return;
|
||||
}
|
||||
inputEl.focus();
|
||||
inputEl.select();
|
||||
};
|
||||
window.addEventListener('gonavi:focus-sidebar-search', handleFocusSidebarSearch as EventListener);
|
||||
return () => {
|
||||
window.removeEventListener('gonavi:focus-sidebar-search', handleFocusSidebarSearch as EventListener);
|
||||
};
|
||||
}, []);
|
||||
|
||||
// Connection Status State: key -> 'success' | 'error'
|
||||
const [connectionStates, setConnectionStates] = useState<Record<string, 'success' | 'error'>>({});
|
||||
@@ -219,7 +252,7 @@ const Sidebar: React.FC<{ onEditConnection?: (conn: SavedConnection) => void }>
|
||||
useEffect(() => {
|
||||
setTreeData((prev) => {
|
||||
const prevMap = new Map<string, TreeNode>();
|
||||
|
||||
|
||||
// We need to recursively extract connections from old tag structures
|
||||
// so if a user expands a connection that was tagged, the state remains
|
||||
const recurseCollect = (nodes: TreeNode[]) => {
|
||||
@@ -271,6 +304,118 @@ const Sidebar: React.FC<{ onEditConnection?: (conn: SavedConnection) => void }>
|
||||
});
|
||||
}, [connections, connectionTags]);
|
||||
|
||||
const buildDuplicateConnectionName = (rawName: string): string => {
|
||||
const baseName = String(rawName || '').trim() || '连接';
|
||||
const suffix = ' - 副本';
|
||||
const usedNames = new Set(connections.map(conn => String(conn.name || '').trim()));
|
||||
let candidate = `${baseName}${suffix}`;
|
||||
let counter = 2;
|
||||
while (usedNames.has(candidate)) {
|
||||
candidate = `${baseName}${suffix} ${counter}`;
|
||||
counter += 1;
|
||||
}
|
||||
return candidate;
|
||||
};
|
||||
|
||||
const cloneConnectionConfig = (config: SavedConnection['config']): SavedConnection['config'] => {
|
||||
const raw: any = config || {};
|
||||
let cloned: any = {};
|
||||
try {
|
||||
cloned = typeof structuredClone === 'function'
|
||||
? structuredClone(raw)
|
||||
: JSON.parse(JSON.stringify(raw));
|
||||
} catch {
|
||||
cloned = { ...raw };
|
||||
}
|
||||
|
||||
const readString = (...values: unknown[]): string => {
|
||||
for (const value of values) {
|
||||
if (typeof value === 'string') {
|
||||
return value;
|
||||
}
|
||||
}
|
||||
return '';
|
||||
};
|
||||
|
||||
const readBool = (fallback: boolean, ...values: unknown[]): boolean => {
|
||||
for (const value of values) {
|
||||
if (typeof value === 'boolean') {
|
||||
return value;
|
||||
}
|
||||
}
|
||||
return fallback;
|
||||
};
|
||||
|
||||
const readNumber = (fallback: number, ...values: unknown[]): number => {
|
||||
for (const value of values) {
|
||||
const num = Number(value);
|
||||
if (Number.isFinite(num)) {
|
||||
return num;
|
||||
}
|
||||
}
|
||||
return fallback;
|
||||
};
|
||||
|
||||
const rawSSH = (cloned.ssh ?? cloned.SSH ?? {}) as Record<string, unknown>;
|
||||
const normalizedSSH = {
|
||||
host: readString(rawSSH.host, rawSSH.Host, cloned.sshHost, cloned.SSHHost),
|
||||
port: readNumber(22, rawSSH.port, rawSSH.Port, cloned.sshPort, cloned.SSHPort),
|
||||
user: readString(rawSSH.user, rawSSH.User, cloned.sshUser, cloned.SSHUser),
|
||||
password: readString(rawSSH.password, rawSSH.Password, cloned.sshPassword, cloned.SSHPassword),
|
||||
keyPath: readString(rawSSH.keyPath, rawSSH.KeyPath, cloned.sshKeyPath, cloned.SSHKeyPath),
|
||||
};
|
||||
const hasSSHDetail = Boolean(
|
||||
normalizedSSH.host
|
||||
|| normalizedSSH.user
|
||||
|| normalizedSSH.password
|
||||
|| normalizedSSH.keyPath
|
||||
);
|
||||
|
||||
const rawProxy = (cloned.proxy ?? cloned.Proxy ?? {}) as Record<string, unknown>;
|
||||
const proxyTypeRaw = readString(rawProxy.type, rawProxy.Type, cloned.proxyType, cloned.ProxyType).toLowerCase();
|
||||
const proxyType: 'socks5' | 'http' = proxyTypeRaw === 'http' ? 'http' : 'socks5';
|
||||
const normalizedProxy = {
|
||||
type: proxyType,
|
||||
host: readString(rawProxy.host, rawProxy.Host, cloned.proxyHost, cloned.ProxyHost),
|
||||
port: readNumber(proxyType === 'http' ? 8080 : 1080, rawProxy.port, rawProxy.Port, cloned.proxyPort, cloned.ProxyPort),
|
||||
user: readString(rawProxy.user, rawProxy.User, cloned.proxyUser, cloned.ProxyUser),
|
||||
password: readString(rawProxy.password, rawProxy.Password, cloned.proxyPassword, cloned.ProxyPassword),
|
||||
};
|
||||
const hasProxyDetail = Boolean(normalizedProxy.host || normalizedProxy.user || normalizedProxy.password);
|
||||
|
||||
const rawHosts = Array.isArray(cloned.hosts)
|
||||
? cloned.hosts
|
||||
: (Array.isArray(cloned.Hosts) ? cloned.Hosts : []);
|
||||
const normalizedHosts = rawHosts
|
||||
.map((entry: unknown) => String(entry || '').trim())
|
||||
.filter((entry: string) => !!entry);
|
||||
|
||||
return {
|
||||
...(cloned as SavedConnection['config']),
|
||||
useSSH: readBool(hasSSHDetail, cloned.useSSH, cloned.UseSSH),
|
||||
ssh: normalizedSSH,
|
||||
useProxy: readBool(hasProxyDetail, cloned.useProxy, cloned.UseProxy),
|
||||
proxy: normalizedProxy,
|
||||
hosts: normalizedHosts,
|
||||
timeout: readNumber(30, cloned.timeout, cloned.Timeout),
|
||||
};
|
||||
};
|
||||
|
||||
const handleDuplicateConnection = (conn: SavedConnection) => {
|
||||
if (!conn) return;
|
||||
|
||||
const duplicatedConnection: SavedConnection = {
|
||||
...conn,
|
||||
id: `${Date.now()}-${Math.random().toString(36).slice(2, 8)}`,
|
||||
name: buildDuplicateConnectionName(conn.name),
|
||||
config: cloneConnectionConfig(conn.config),
|
||||
includeDatabases: conn.includeDatabases ? [...conn.includeDatabases] : undefined,
|
||||
includeRedisDatabases: conn.includeRedisDatabases ? [...conn.includeRedisDatabases] : undefined,
|
||||
};
|
||||
|
||||
addConnection(duplicatedConnection);
|
||||
message.success(`已复制连接: ${duplicatedConnection.name}`);
|
||||
};
|
||||
const updateTreeData = (list: TreeNode[], key: React.Key, children: TreeNode[] | undefined): TreeNode[] => {
|
||||
return list.map(node => {
|
||||
if (node.key === key) {
|
||||
@@ -749,7 +894,8 @@ const Sidebar: React.FC<{ onEditConnection?: (conn: SavedConnection) => void }>
|
||||
const res = await (window as any).go.app.App.RedisGetDatabases(config);
|
||||
if (res.success) {
|
||||
setConnectionStates(prev => ({ ...prev, [conn.id]: 'success' }));
|
||||
let dbs = (res.data as any[]).map((db: any) => ({
|
||||
const redisRows: any[] = Array.isArray(res.data) ? res.data : [];
|
||||
let dbs = redisRows.map((db: any) => ({
|
||||
title: `db${db.index}${db.keys > 0 ? ` (${db.keys})` : ''}`,
|
||||
key: `${conn.id}-db${db.index}`,
|
||||
icon: <DatabaseOutlined style={{ color: '#DC382D' }} />,
|
||||
@@ -780,7 +926,8 @@ const Sidebar: React.FC<{ onEditConnection?: (conn: SavedConnection) => void }>
|
||||
const res = await DBGetDatabases(config as any);
|
||||
if (res.success) {
|
||||
setConnectionStates(prev => ({ ...prev, [conn.id]: 'success' }));
|
||||
let dbs = (res.data as any[]).map((row: any) => ({
|
||||
const dbRows: any[] = Array.isArray(res.data) ? res.data : [];
|
||||
let dbs = dbRows.map((row: any) => ({
|
||||
title: row.Database || row.database,
|
||||
key: `${conn.id}-${row.Database || row.database}`,
|
||||
icon: <DatabaseOutlined />,
|
||||
@@ -795,10 +942,13 @@ const Sidebar: React.FC<{ onEditConnection?: (conn: SavedConnection) => void }>
|
||||
}
|
||||
|
||||
setTreeData(origin => updateTreeData(origin, node.key, dbs));
|
||||
} else {
|
||||
setConnectionStates(prev => ({ ...prev, [conn.id]: 'error' }));
|
||||
message.error({ content: res.message, key: `conn-${conn.id}-dbs` });
|
||||
}
|
||||
} else {
|
||||
setConnectionStates(prev => ({ ...prev, [conn.id]: 'error' }));
|
||||
message.error({ content: res.message, key: `conn-${conn.id}-dbs` });
|
||||
}
|
||||
} catch (e: any) {
|
||||
setConnectionStates(prev => ({ ...prev, [conn.id]: 'error' }));
|
||||
message.error({ content: '连接失败: ' + (e?.message || String(e)), key: `conn-${conn.id}-dbs` });
|
||||
} finally {
|
||||
loadingNodesRef.current.delete(loadKey);
|
||||
}
|
||||
@@ -843,7 +993,8 @@ const Sidebar: React.FC<{ onEditConnection?: (conn: SavedConnection) => void }>
|
||||
if (res.success) {
|
||||
setConnectionStates(prev => ({ ...prev, [key as string]: 'success' }));
|
||||
|
||||
const tableEntries = (res.data as any[]).map((row: any) => {
|
||||
const tableRows: any[] = Array.isArray(res.data) ? res.data : [];
|
||||
const tableEntries = tableRows.map((row: any) => {
|
||||
const tableName = Object.values(row)[0] as string;
|
||||
const parsed = splitQualifiedName(tableName);
|
||||
return {
|
||||
@@ -859,7 +1010,11 @@ const Sidebar: React.FC<{ onEditConnection?: (conn: SavedConnection) => void }>
|
||||
loadFunctions(conn, conn.dbName),
|
||||
]);
|
||||
|
||||
const viewEntries = viewsResult.views.map((viewName) => {
|
||||
const viewRows: string[] = Array.isArray(viewsResult.views) ? viewsResult.views : [];
|
||||
const triggerRows: any[] = Array.isArray(triggersResult.triggers) ? triggersResult.triggers : [];
|
||||
const routineRows: any[] = Array.isArray(routinesResult.routines) ? routinesResult.routines : [];
|
||||
|
||||
const viewEntries = viewRows.map((viewName: string) => {
|
||||
const parsed = splitQualifiedName(viewName);
|
||||
return {
|
||||
viewName,
|
||||
@@ -873,7 +1028,7 @@ const Sidebar: React.FC<{ onEditConnection?: (conn: SavedConnection) => void }>
|
||||
const triggerSeen = new Set<string>();
|
||||
const metadataDialect = getMetadataDialect(conn as SavedConnection);
|
||||
|
||||
triggersResult.triggers.forEach((trigger) => {
|
||||
triggerRows.forEach((trigger: any) => {
|
||||
const triggerParsed = splitQualifiedName(trigger.triggerName);
|
||||
const tableParsed = splitQualifiedName(trigger.tableName);
|
||||
const schemaName = tableParsed.schemaName || triggerParsed.schemaName || String(conn.dbName || '').trim();
|
||||
@@ -898,7 +1053,7 @@ const Sidebar: React.FC<{ onEditConnection?: (conn: SavedConnection) => void }>
|
||||
return deduped;
|
||||
})();
|
||||
|
||||
const routineEntries = routinesResult.routines.map((routine) => {
|
||||
const routineEntries = routineRows.map((routine: any) => {
|
||||
const parsed = splitQualifiedName(routine.routineName);
|
||||
const typeLabel = routine.routineType === 'PROCEDURE' ? 'P' : 'F';
|
||||
return {
|
||||
@@ -1080,6 +1235,9 @@ const Sidebar: React.FC<{ onEditConnection?: (conn: SavedConnection) => void }>
|
||||
setConnectionStates(prev => ({ ...prev, [key as string]: 'error' }));
|
||||
message.error({ content: res.message, key: `db-${key}-tables` });
|
||||
}
|
||||
} catch (e: any) {
|
||||
setConnectionStates(prev => ({ ...prev, [key as string]: 'error' }));
|
||||
message.error({ content: '加载表失败: ' + (e?.message || String(e)), key: `db-${key}-tables` });
|
||||
} finally {
|
||||
loadingNodesRef.current.delete(loadKey);
|
||||
}
|
||||
@@ -1425,7 +1583,8 @@ const Sidebar: React.FC<{ onEditConnection?: (conn: SavedConnection) => void }>
|
||||
|
||||
const res = await DBGetDatabases(config as any);
|
||||
if (res.success) {
|
||||
let dbs = (res.data as any[]).map((row: any) => {
|
||||
const dbRows: any[] = Array.isArray(res.data) ? res.data : [];
|
||||
let dbs = dbRows.map((row: any) => {
|
||||
const dbName = row.Database || row.database;
|
||||
return {
|
||||
title: dbName,
|
||||
@@ -1466,9 +1625,11 @@ const Sidebar: React.FC<{ onEditConnection?: (conn: SavedConnection) => void }>
|
||||
return;
|
||||
}
|
||||
|
||||
const viewSet = new Set(viewResult.views.map(view => view.toLowerCase()));
|
||||
const tableRows: any[] = Array.isArray(res.data) ? res.data : [];
|
||||
const viewRows: string[] = Array.isArray(viewResult.views) ? viewResult.views : [];
|
||||
const viewSet = new Set(viewRows.map((view: string) => view.toLowerCase()));
|
||||
|
||||
const tableObjects: BatchObjectItem[] = (res.data as any[])
|
||||
const tableObjects: BatchObjectItem[] = tableRows
|
||||
.map((row: any) => Object.values(row)[0] as string)
|
||||
.filter((tableName: string) => !viewSet.has(tableName.toLowerCase()))
|
||||
.map((tableName: string) => ({
|
||||
@@ -1479,7 +1640,7 @@ const Sidebar: React.FC<{ onEditConnection?: (conn: SavedConnection) => void }>
|
||||
dataRef: { ...conn, tableName, dbName, objectType: 'table' },
|
||||
}));
|
||||
|
||||
const viewObjects: BatchObjectItem[] = viewResult.views.map((viewName: string) => ({
|
||||
const viewObjects: BatchObjectItem[] = viewRows.map((viewName: string) => ({
|
||||
title: getSidebarTableDisplayName(conn, viewName),
|
||||
key: `${conn.id}-${dbName}-view-${viewName}`,
|
||||
objectName: viewName,
|
||||
@@ -1645,7 +1806,8 @@ const Sidebar: React.FC<{ onEditConnection?: (conn: SavedConnection) => void }>
|
||||
|
||||
const res = await DBGetDatabases(config as any);
|
||||
if (res.success) {
|
||||
let dbs = (res.data as any[]).map((row: any) => {
|
||||
const dbRows: any[] = Array.isArray(res.data) ? res.data : [];
|
||||
let dbs = dbRows.map((row: any) => {
|
||||
const dbName = row.Database || row.database;
|
||||
return {
|
||||
title: dbName,
|
||||
@@ -2232,28 +2394,205 @@ const Sidebar: React.FC<{ onEditConnection?: (conn: SavedConnection) => void }>
|
||||
setSearchValue(value);
|
||||
};
|
||||
|
||||
const loop = (data: TreeNode[]): TreeNode[] => {
|
||||
const result: TreeNode[] = [];
|
||||
data.forEach(item => {
|
||||
const match = item.title.toLowerCase().indexOf(searchValue.toLowerCase()) > -1;
|
||||
if (item.children) {
|
||||
const filteredChildren = loop(item.children);
|
||||
if (filteredChildren.length > 0 || match) {
|
||||
result.push({ ...item, children: filteredChildren });
|
||||
}
|
||||
const toggleSearchScope = (scope: SearchScope) => {
|
||||
setSearchScopes((prev) => {
|
||||
if (scope === 'smart') {
|
||||
return ['smart'];
|
||||
}
|
||||
const withoutSmart = prev.filter((item) => item !== 'smart');
|
||||
if (withoutSmart.includes(scope)) {
|
||||
const next = withoutSmart.filter((item) => item !== scope);
|
||||
return next.length > 0 ? next : ['smart'];
|
||||
}
|
||||
return [...withoutSmart, scope];
|
||||
});
|
||||
};
|
||||
|
||||
const setSearchScopeChecked = (scope: SearchScope, checked: boolean) => {
|
||||
if (scope === 'smart') {
|
||||
if (checked) {
|
||||
setSearchScopes(['smart']);
|
||||
} else if (searchScopes.length === 1 && searchScopes[0] === 'smart') {
|
||||
setSearchScopes(['smart']);
|
||||
} else {
|
||||
if (match) {
|
||||
setSearchScopes((prev) => {
|
||||
const next = prev.filter((item) => item !== 'smart');
|
||||
return next.length > 0 ? next : ['smart'];
|
||||
});
|
||||
}
|
||||
return;
|
||||
}
|
||||
|
||||
if (checked) {
|
||||
setSearchScopes((prev) => {
|
||||
const withoutSmart = prev.filter((item) => item !== 'smart');
|
||||
if (withoutSmart.includes(scope)) {
|
||||
return withoutSmart;
|
||||
}
|
||||
return [...withoutSmart, scope];
|
||||
});
|
||||
} else {
|
||||
setSearchScopes((prev) => {
|
||||
const next = prev.filter((item) => item !== scope && item !== 'smart');
|
||||
return next.length > 0 ? next : ['smart'];
|
||||
});
|
||||
}
|
||||
};
|
||||
|
||||
const searchScopeSummary = useMemo(() => {
|
||||
if (searchScopes.includes('smart')) {
|
||||
return '智能';
|
||||
}
|
||||
return searchScopes.map((scope) => SEARCH_SCOPE_LABEL_MAP[scope]).join(' + ');
|
||||
}, [searchScopes]);
|
||||
|
||||
const searchScopePopoverContent = useMemo(() => {
|
||||
const smartSelected = searchScopes.includes('smart');
|
||||
const scopedOptions = SEARCH_SCOPE_OPTIONS.filter((option) => option.value !== 'smart');
|
||||
return (
|
||||
<div style={{ minWidth: 220, display: 'flex', flexDirection: 'column', gap: 8 }}>
|
||||
<div style={{ fontSize: 12, color: '#8c8c8c' }}>搜索范围</div>
|
||||
<Checkbox
|
||||
checked={smartSelected}
|
||||
onChange={(e) => setSearchScopeChecked('smart', e.target.checked)}
|
||||
>
|
||||
智能(推荐)
|
||||
</Checkbox>
|
||||
<div style={{ paddingLeft: 12, display: 'grid', gap: 6 }}>
|
||||
{scopedOptions.map((option) => (
|
||||
<Checkbox
|
||||
key={option.value}
|
||||
checked={searchScopes.includes(option.value)}
|
||||
onChange={(e) => setSearchScopeChecked(option.value, e.target.checked)}
|
||||
>
|
||||
{option.label}
|
||||
</Checkbox>
|
||||
))}
|
||||
</div>
|
||||
<div style={{ fontSize: 12, color: '#8c8c8c' }}>
|
||||
智能与其他项互斥;其他项支持多选。
|
||||
</div>
|
||||
</div>
|
||||
);
|
||||
}, [searchScopes]);
|
||||
|
||||
const parseHostOnlyToken = (value: unknown): string[] => {
|
||||
const raw = String(value || '').trim();
|
||||
if (!raw) {
|
||||
return [];
|
||||
}
|
||||
let text = raw.replace(/^[a-z][a-z0-9+.-]*:\/\//i, '');
|
||||
if (text.includes('/')) {
|
||||
text = text.split('/')[0];
|
||||
}
|
||||
if (text.includes('?')) {
|
||||
text = text.split('?')[0];
|
||||
}
|
||||
if (text.includes('@')) {
|
||||
text = text.split('@').pop() || '';
|
||||
}
|
||||
return text
|
||||
.split(',')
|
||||
.map((entry) => {
|
||||
const token = entry.trim();
|
||||
if (!token) return '';
|
||||
if (token.startsWith('[')) {
|
||||
const rightBracketIndex = token.indexOf(']');
|
||||
if (rightBracketIndex > 0) {
|
||||
return token.slice(0, rightBracketIndex + 1).toLowerCase();
|
||||
}
|
||||
}
|
||||
const colonIndex = token.lastIndexOf(':');
|
||||
if (colonIndex > 0) {
|
||||
return token.slice(0, colonIndex).toLowerCase();
|
||||
}
|
||||
return token.toLowerCase();
|
||||
})
|
||||
.filter(Boolean);
|
||||
};
|
||||
|
||||
const getConnectionHostSearchText = (node: TreeNode): string => {
|
||||
if (node.type !== 'connection') return '';
|
||||
const config = node.dataRef?.config || {};
|
||||
const hostTokens = [
|
||||
...parseHostOnlyToken(config.host),
|
||||
...(Array.isArray(config.hosts) ? config.hosts.flatMap((entry: string) => parseHostOnlyToken(entry)) : []),
|
||||
...parseHostOnlyToken(config.uri),
|
||||
];
|
||||
const uniqueHosts = Array.from(new Set(hostTokens));
|
||||
return uniqueHosts.join(' ');
|
||||
};
|
||||
|
||||
const getConnectionNameSearchText = (node: TreeNode): string => {
|
||||
if (node.type !== 'connection') return '';
|
||||
const name = node.dataRef?.name ?? node.title;
|
||||
return String(name || '').toLowerCase();
|
||||
};
|
||||
|
||||
const isObjectNode = (node: TreeNode): boolean => {
|
||||
return node.type === 'table'
|
||||
|| node.type === 'view'
|
||||
|| node.type === 'db-trigger'
|
||||
|| node.type === 'routine'
|
||||
|| node.type === 'object-group';
|
||||
};
|
||||
|
||||
const matchByScopes = (node: TreeNode, keyword: string, scopes: SearchScope[]): boolean => {
|
||||
const title = String(node.title || '').toLowerCase();
|
||||
if (scopes.includes('database') && node.type === 'database' && title.includes(keyword)) {
|
||||
return true;
|
||||
}
|
||||
if (scopes.includes('tag') && node.type === 'tag' && title.includes(keyword)) {
|
||||
return true;
|
||||
}
|
||||
if (scopes.includes('host') && node.type === 'connection' && getConnectionHostSearchText(node).includes(keyword)) {
|
||||
return true;
|
||||
}
|
||||
if (scopes.includes('object') && isObjectNode(node) && title.includes(keyword)) {
|
||||
return true;
|
||||
}
|
||||
return false;
|
||||
};
|
||||
|
||||
const loop = (data: TreeNode[], keyword: string): TreeNode[] => {
|
||||
const isSmartMode = searchScopes.includes('smart');
|
||||
const result: TreeNode[] = [];
|
||||
data.forEach((item) => {
|
||||
const titleMatch = String(item.title || '').toLowerCase().includes(keyword);
|
||||
const smartMatch = item.type === 'connection'
|
||||
? getConnectionNameSearchText(item).includes(keyword) || getConnectionHostSearchText(item).includes(keyword)
|
||||
: titleMatch;
|
||||
const scopedMatch = matchByScopes(item, keyword, searchScopes);
|
||||
const selfMatch = isSmartMode ? smartMatch : scopedMatch;
|
||||
const filteredChildren = item.children ? loop(item.children, keyword) : [];
|
||||
|
||||
if (selfMatch) {
|
||||
const shouldKeepFullSubtree = isSmartMode
|
||||
|| item.type === 'connection'
|
||||
|| item.type === 'database'
|
||||
|| item.type === 'tag';
|
||||
if (item.children && shouldKeepFullSubtree) {
|
||||
result.push(item);
|
||||
} else if (item.children && filteredChildren.length > 0) {
|
||||
result.push({ ...item, children: filteredChildren });
|
||||
} else {
|
||||
result.push(item);
|
||||
}
|
||||
return;
|
||||
}
|
||||
|
||||
if (filteredChildren.length > 0) {
|
||||
result.push({ ...item, children: filteredChildren });
|
||||
}
|
||||
});
|
||||
return result;
|
||||
};
|
||||
|
||||
const displayTreeData = useMemo(() => {
|
||||
if (!searchValue) return treeData;
|
||||
return loop(treeData);
|
||||
}, [searchValue, treeData]);
|
||||
const keyword = searchValue.trim().toLowerCase();
|
||||
if (!keyword) return treeData;
|
||||
return loop(treeData, keyword);
|
||||
}, [searchValue, searchScopes, treeData]);
|
||||
|
||||
const getNodeMenuItems = (node: any): MenuProps['items'] => {
|
||||
const conn = node.dataRef as SavedConnection;
|
||||
@@ -2395,6 +2734,12 @@ const Sidebar: React.FC<{ onEditConnection?: (conn: SavedConnection) => void }>
|
||||
if (onEditConnection) onEditConnection(node.dataRef);
|
||||
}
|
||||
},
|
||||
{
|
||||
key: 'copy-connection',
|
||||
label: '复制连接',
|
||||
icon: <CopyOutlined />,
|
||||
onClick: () => handleDuplicateConnection(node.dataRef as SavedConnection)
|
||||
},
|
||||
{
|
||||
key: 'disconnect',
|
||||
label: '断开连接',
|
||||
@@ -2493,6 +2838,12 @@ const Sidebar: React.FC<{ onEditConnection?: (conn: SavedConnection) => void }>
|
||||
if (onEditConnection) onEditConnection(node.dataRef);
|
||||
}
|
||||
},
|
||||
{
|
||||
key: 'copy-connection',
|
||||
label: '复制连接',
|
||||
icon: <CopyOutlined />,
|
||||
onClick: () => handleDuplicateConnection(node.dataRef as SavedConnection)
|
||||
},
|
||||
{
|
||||
key: 'move-to-tag',
|
||||
label: '移至标签',
|
||||
@@ -2504,6 +2855,13 @@ const Sidebar: React.FC<{ onEditConnection?: (conn: SavedConnection) => void }>
|
||||
label: '断开连接',
|
||||
icon: <DisconnectOutlined />,
|
||||
onClick: () => {
|
||||
const connId = String(node.key || '');
|
||||
// 强制清理该连接相关的 loading 标记,避免网络卡住后重连仍被短路。
|
||||
Array.from(loadingNodesRef.current).forEach((loadingKey) => {
|
||||
if (loadingKey === `dbs-${connId}` || loadingKey.startsWith(`tables-${connId}-`)) {
|
||||
loadingNodesRef.current.delete(loadingKey);
|
||||
}
|
||||
});
|
||||
// Reset status recursively
|
||||
setConnectionStates(prev => {
|
||||
const next = { ...prev };
|
||||
@@ -2625,6 +2983,7 @@ const Sidebar: React.FC<{ onEditConnection?: (conn: SavedConnection) => void }>
|
||||
onClick: () => {
|
||||
const dbConnId = String(node.dataRef?.id || '');
|
||||
const dbName = String(node.dataRef?.dbName || node.title || '').trim();
|
||||
loadingNodesRef.current.delete(`tables-${dbConnId}-${dbName}`);
|
||||
setConnectionStates(prev => {
|
||||
const next = { ...prev };
|
||||
delete next[node.key];
|
||||
@@ -2857,15 +3216,15 @@ const Sidebar: React.FC<{ onEditConnection?: (conn: SavedConnection) => void }>
|
||||
// Get current order
|
||||
const currentTagOrder = connectionTags.map(t => t.id);
|
||||
const dragTagId = dragNode.dataRef.id;
|
||||
|
||||
|
||||
// Filter out the dragging tag
|
||||
const newOrder = currentTagOrder.filter(id => id !== dragTagId);
|
||||
|
||||
|
||||
let insertIndex = newOrder.length;
|
||||
if (dropNode.type === 'tag') {
|
||||
const dropTagId = dropNode.dataRef.id;
|
||||
const dropIndex = newOrder.indexOf(dropTagId);
|
||||
|
||||
|
||||
if (dropPosition === -1) {
|
||||
insertIndex = dropIndex;
|
||||
} else {
|
||||
@@ -2876,7 +3235,7 @@ const Sidebar: React.FC<{ onEditConnection?: (conn: SavedConnection) => void }>
|
||||
// Since tags are always displayed before ungrouped connections, just put it at the end
|
||||
insertIndex = newOrder.length;
|
||||
}
|
||||
|
||||
|
||||
newOrder.splice(insertIndex, 0, dragTagId);
|
||||
reorderTags(newOrder);
|
||||
}
|
||||
@@ -2897,7 +3256,7 @@ const Sidebar: React.FC<{ onEditConnection?: (conn: SavedConnection) => void }>
|
||||
moveConnectionToTag(dragNode.key, targetTag.id);
|
||||
return;
|
||||
}
|
||||
|
||||
|
||||
// Drop target is NOT under a tag (ungrouped) -> move OUT of tag
|
||||
const sourceTag = connectionTags.find(t => t.connectionIds.includes(dragNode.key));
|
||||
if (sourceTag) {
|
||||
@@ -2921,7 +3280,28 @@ const Sidebar: React.FC<{ onEditConnection?: (conn: SavedConnection) => void }>
|
||||
return (
|
||||
<div style={{ display: 'flex', flexDirection: 'column', height: '100%' }}>
|
||||
<div style={{ padding: '4px 8px' }}>
|
||||
<Search placeholder="搜索..." onChange={onSearch} size="small" />
|
||||
<Space.Compact block size="small">
|
||||
<Search
|
||||
ref={searchInputRef}
|
||||
placeholder="搜索..."
|
||||
onChange={onSearch}
|
||||
size="small"
|
||||
style={{ width: '100%' }}
|
||||
/>
|
||||
<Popover
|
||||
content={searchScopePopoverContent}
|
||||
trigger="click"
|
||||
placement="bottomRight"
|
||||
open={isSearchScopePopoverOpen}
|
||||
onOpenChange={setIsSearchScopePopoverOpen}
|
||||
>
|
||||
<Tooltip title={`搜索范围:${searchScopeSummary}`}>
|
||||
<Button size="small" icon={<DownOutlined />} style={{ width: 86 }}>
|
||||
范围{searchScopes.includes('smart') ? '(智)' : `(${searchScopes.length})`}
|
||||
</Button>
|
||||
</Tooltip>
|
||||
</Popover>
|
||||
</Space.Compact>
|
||||
</div>
|
||||
|
||||
{/* Toolbar */}
|
||||
|
||||
@@ -261,9 +261,18 @@ const TableDesigner: React.FC<{ tab: TabData }> = ({ tab }) => {
|
||||
const darkMode = theme === 'dark';
|
||||
const resizeGuideColor = darkMode ? '#f6c453' : '#1890ff';
|
||||
const readOnly = !!tab.readOnly;
|
||||
const panelRadius = 10;
|
||||
const panelFrameColor = darkMode ? 'rgba(0, 0, 0, 0.18)' : 'rgba(0, 0, 0, 0.12)';
|
||||
const panelToolbarBorder = darkMode ? 'rgba(255, 255, 255, 0.12)' : 'rgba(0, 0, 0, 0.10)';
|
||||
const panelToolbarBg = darkMode ? 'rgba(20, 20, 20, 0.35)' : 'rgba(255, 255, 255, 0.72)';
|
||||
const panelBodyBg = darkMode ? 'rgba(0, 0, 0, 0.24)' : 'rgba(255, 255, 255, 0.82)';
|
||||
const focusRowBg = darkMode ? 'rgba(246, 196, 83, 0.22)' : 'rgba(24, 144, 255, 0.12)';
|
||||
|
||||
const [tableHeight, setTableHeight] = useState(500);
|
||||
const containerRef = useRef<HTMLDivElement>(null);
|
||||
const pendingFocusColumnKeyRef = useRef<string | null>(null);
|
||||
const focusHighlightTimerRef = useRef<number | null>(null);
|
||||
const [focusColumnKey, setFocusColumnKey] = useState('');
|
||||
|
||||
const openCommentEditor = useCallback((record: EditableColumn) => {
|
||||
if (!record?._key) return;
|
||||
@@ -346,6 +355,61 @@ const TableDesigner: React.FC<{ tab: TabData }> = ({ tab }) => {
|
||||
setSelectedColumnRowKeys(prev => prev.filter(key => columns.some(c => c._key === key)));
|
||||
}, [columns]);
|
||||
|
||||
useEffect(() => {
|
||||
return () => {
|
||||
if (focusHighlightTimerRef.current !== null) {
|
||||
window.clearTimeout(focusHighlightTimerRef.current);
|
||||
}
|
||||
};
|
||||
}, []);
|
||||
|
||||
const focusColumnRow = useCallback((targetKey: string): boolean => {
|
||||
if (activeKey !== 'columns') return false;
|
||||
const tableBody = containerRef.current?.querySelector('.ant-table-body') as HTMLElement | null;
|
||||
if (!tableBody) return false;
|
||||
const row = tableBody.querySelector(`tr[data-row-key="${targetKey}"]`) as HTMLTableRowElement | null;
|
||||
if (!row) return false;
|
||||
|
||||
row.scrollIntoView({ behavior: 'smooth', block: 'nearest' });
|
||||
setFocusColumnKey(targetKey);
|
||||
if (focusHighlightTimerRef.current !== null) {
|
||||
window.clearTimeout(focusHighlightTimerRef.current);
|
||||
}
|
||||
focusHighlightTimerRef.current = window.setTimeout(() => {
|
||||
setFocusColumnKey(prev => (prev === targetKey ? '' : prev));
|
||||
}, 1600);
|
||||
|
||||
if (!readOnly) {
|
||||
const firstInput = row.querySelector('input') as HTMLInputElement | null;
|
||||
if (firstInput) {
|
||||
firstInput.focus();
|
||||
firstInput.select();
|
||||
}
|
||||
}
|
||||
return true;
|
||||
}, [activeKey, readOnly]);
|
||||
|
||||
useEffect(() => {
|
||||
const pendingKey = pendingFocusColumnKeyRef.current;
|
||||
if (!pendingKey || activeKey !== 'columns') return;
|
||||
|
||||
let cancelled = false;
|
||||
const tryFocus = () => {
|
||||
if (cancelled) return;
|
||||
if (focusColumnRow(pendingKey)) {
|
||||
pendingFocusColumnKeyRef.current = null;
|
||||
}
|
||||
};
|
||||
|
||||
const timerA = window.setTimeout(tryFocus, 0);
|
||||
const timerB = window.setTimeout(tryFocus, 96);
|
||||
return () => {
|
||||
cancelled = true;
|
||||
window.clearTimeout(timerA);
|
||||
window.clearTimeout(timerB);
|
||||
};
|
||||
}, [activeKey, columns, focusColumnRow]);
|
||||
|
||||
// Initial Columns Definition
|
||||
useEffect(() => {
|
||||
const initialCols = [
|
||||
@@ -886,21 +950,46 @@ ${selectedTrigger.statement}`;
|
||||
}));
|
||||
};
|
||||
|
||||
const handleAddColumn = () => {
|
||||
const newCol: EditableColumn = {
|
||||
name: isNewTable ? 'new_column' : `new_col_${columns.length + 1}`,
|
||||
type: 'varchar(255)',
|
||||
nullable: 'YES',
|
||||
key: '',
|
||||
extra: '',
|
||||
comment: '',
|
||||
default: '',
|
||||
_key: `new-${Date.now()}`,
|
||||
isNew: true,
|
||||
isAutoIncrement: false
|
||||
};
|
||||
setColumns([...columns, newCol]);
|
||||
};
|
||||
const createNewColumn = useCallback((indexHint: number): EditableColumn => ({
|
||||
name: isNewTable ? 'new_column' : `new_col_${indexHint}`,
|
||||
type: 'varchar(255)',
|
||||
nullable: 'YES',
|
||||
key: '',
|
||||
extra: '',
|
||||
comment: '',
|
||||
default: '',
|
||||
_key: `new-${Date.now()}-${Math.random().toString(36).slice(2, 8)}`,
|
||||
isNew: true,
|
||||
isAutoIncrement: false
|
||||
}), [isNewTable]);
|
||||
|
||||
const handleAddColumn = useCallback((insertAfterKey?: string) => {
|
||||
const newCol = createNewColumn(columns.length + 1);
|
||||
setColumns(prev => {
|
||||
const next = [...prev];
|
||||
if (insertAfterKey) {
|
||||
const insertIndex = next.findIndex(col => col._key === insertAfterKey);
|
||||
if (insertIndex >= 0) {
|
||||
next.splice(insertIndex + 1, 0, newCol);
|
||||
return next;
|
||||
}
|
||||
}
|
||||
next.push(newCol);
|
||||
return next;
|
||||
});
|
||||
setSelectedColumnRowKeys([newCol._key]);
|
||||
pendingFocusColumnKeyRef.current = newCol._key;
|
||||
}, [columns.length, createNewColumn]);
|
||||
|
||||
const handleAddColumnAfterSelected = useCallback(() => {
|
||||
const selectedSet = new Set(selectedColumnRowKeys);
|
||||
const anchor = columns.find(col => selectedSet.has(col._key));
|
||||
if (!anchor) {
|
||||
message.warning('请先选择一个字段,再执行插入。');
|
||||
return;
|
||||
}
|
||||
handleAddColumn(anchor._key);
|
||||
}, [columns, handleAddColumn, selectedColumnRowKeys]);
|
||||
|
||||
const handleDeleteColumn = (key: string) => {
|
||||
setColumns(prev => prev.filter(c => c._key !== key));
|
||||
@@ -1920,22 +2009,35 @@ END;`;
|
||||
}));
|
||||
|
||||
const columnsTabContent = (
|
||||
<div ref={containerRef} className="table-designer-wrapper" style={{ height: '100%', overflow: 'hidden', position: 'relative' }}>
|
||||
<div
|
||||
ref={containerRef}
|
||||
className="table-designer-wrapper"
|
||||
style={{
|
||||
height: '100%',
|
||||
overflow: 'hidden',
|
||||
position: 'relative',
|
||||
background: panelBodyBg
|
||||
}}
|
||||
>
|
||||
<style>{`
|
||||
.table-designer-wrapper .ant-table-body {
|
||||
max-height: ${tableHeight}px !important;
|
||||
}
|
||||
}
|
||||
.table-designer-wrapper .table-designer-focus-row > .ant-table-cell {
|
||||
background: ${focusRowBg} !important;
|
||||
}
|
||||
`}</style>
|
||||
{readOnly ? (
|
||||
<Table
|
||||
dataSource={columns}
|
||||
columns={resizableColumns}
|
||||
rowKey="_key"
|
||||
rowClassName={(record: EditableColumn) => record._key === focusColumnKey ? 'table-designer-focus-row' : ''}
|
||||
size="small"
|
||||
pagination={false}
|
||||
loading={loading}
|
||||
scroll={{ y: tableHeight }}
|
||||
bordered
|
||||
bordered={false}
|
||||
components={{
|
||||
header: {
|
||||
cell: ResizableTitle,
|
||||
@@ -1953,11 +2055,12 @@ END;`;
|
||||
onChange: (nextSelectedRowKeys) => setSelectedColumnRowKeys(nextSelectedRowKeys as string[]),
|
||||
}}
|
||||
rowKey="_key"
|
||||
rowClassName={(record: EditableColumn) => record._key === focusColumnKey ? 'table-designer-focus-row' : ''}
|
||||
size="small"
|
||||
pagination={false}
|
||||
loading={loading}
|
||||
scroll={{ y: tableHeight }}
|
||||
bordered
|
||||
bordered={false}
|
||||
components={{
|
||||
body: { row: SortableRow },
|
||||
header: { cell: ResizableTitle }
|
||||
@@ -1985,8 +2088,63 @@ END;`;
|
||||
);
|
||||
|
||||
return (
|
||||
<div style={{ display: 'flex', flexDirection: 'column', height: '100%' }}>
|
||||
<div style={{ padding: '8px', borderBottom: '1px solid #eee', display: 'flex', gap: '8px', alignItems: 'center' }}>
|
||||
<div className="table-designer-shell" style={{ display: 'flex', flexDirection: 'column', height: '100%', minHeight: 0, padding: '6px 0' }}>
|
||||
<style>{`
|
||||
.table-designer-shell .ant-table,
|
||||
.table-designer-shell .ant-table-wrapper,
|
||||
.table-designer-shell .ant-table-container {
|
||||
background: transparent !important;
|
||||
}
|
||||
.table-designer-shell .ant-table-wrapper,
|
||||
.table-designer-shell .ant-table-container {
|
||||
border: none !important;
|
||||
overflow: hidden !important;
|
||||
}
|
||||
.table-designer-shell .ant-table-thead > tr > th {
|
||||
background: transparent !important;
|
||||
border-bottom: 1px solid ${darkMode ? 'rgba(255,255,255,0.06)' : 'rgba(0,0,0,0.06)'} !important;
|
||||
border-inline-end: 1px solid transparent !important;
|
||||
}
|
||||
.table-designer-shell .ant-table-tbody > tr > td,
|
||||
.table-designer-shell .ant-table-tbody .ant-table-row > .ant-table-cell {
|
||||
background: transparent !important;
|
||||
border-bottom: 1px solid ${darkMode ? 'rgba(255,255,255,0.05)' : 'rgba(0,0,0,0.05)'} !important;
|
||||
border-inline-end: 1px solid transparent !important;
|
||||
}
|
||||
.table-designer-shell .ant-table-thead > tr > th::before {
|
||||
display: none !important;
|
||||
}
|
||||
.table-designer-shell .ant-table-tbody > tr:hover > td,
|
||||
.table-designer-shell .ant-table-tbody .ant-table-row:hover > .ant-table-cell {
|
||||
background: ${darkMode ? 'rgba(255,255,255,0.06)' : 'rgba(0,0,0,0.02)'} !important;
|
||||
}
|
||||
.table-designer-shell .ant-tabs-nav {
|
||||
margin-bottom: 8px !important;
|
||||
}
|
||||
.table-designer-shell .ant-tabs-nav::before {
|
||||
border-bottom-color: ${darkMode ? 'rgba(255,255,255,0.08)' : 'rgba(0,0,0,0.08)'} !important;
|
||||
}
|
||||
.table-designer-shell .ant-tabs-content-holder,
|
||||
.table-designer-shell .ant-tabs-content,
|
||||
.table-designer-shell .ant-tabs-tabpane {
|
||||
height: 100%;
|
||||
}
|
||||
`}</style>
|
||||
<div
|
||||
style={{
|
||||
padding: '10px 12px 8px 12px',
|
||||
borderBottom: `1px solid ${panelToolbarBorder}`,
|
||||
borderTopLeftRadius: panelRadius,
|
||||
borderTopRightRadius: panelRadius,
|
||||
borderLeft: `1px solid ${panelFrameColor}`,
|
||||
borderRight: `1px solid ${panelFrameColor}`,
|
||||
borderTop: `1px solid ${panelFrameColor}`,
|
||||
background: panelToolbarBg,
|
||||
display: 'flex',
|
||||
gap: '8px',
|
||||
alignItems: 'center'
|
||||
}}
|
||||
>
|
||||
{isNewTable && (
|
||||
<>
|
||||
<Input
|
||||
@@ -2014,14 +2172,25 @@ END;`;
|
||||
/>
|
||||
</>
|
||||
)}
|
||||
{!readOnly && <Button icon={<SaveOutlined />} type="primary" onClick={generateDDL}>保存</Button>}
|
||||
{!isNewTable && <Button icon={<ReloadOutlined />} onClick={fetchData}>刷新</Button>}
|
||||
{!readOnly && <Button size="small" icon={<SaveOutlined />} type="primary" onClick={generateDDL}>保存</Button>}
|
||||
{!isNewTable && <Button size="small" icon={<ReloadOutlined />} onClick={fetchData}>刷新</Button>}
|
||||
{!isNewTable && !readOnly && supportsTableCommentOps() && (
|
||||
<Button icon={<EditOutlined />} onClick={openTableCommentModal}>表备注</Button>
|
||||
<Button size="small" icon={<EditOutlined />} onClick={openTableCommentModal}>表备注</Button>
|
||||
)}
|
||||
{!readOnly && <Button icon={<PlusOutlined />} onClick={handleAddColumn}>添加字段</Button>}
|
||||
{!readOnly && <Button size="small" icon={<PlusOutlined />} onClick={() => handleAddColumn()}>添加字段</Button>}
|
||||
{!readOnly && (
|
||||
<Button
|
||||
size="small"
|
||||
icon={<PlusOutlined />}
|
||||
onClick={handleAddColumnAfterSelected}
|
||||
disabled={selectedColumnRowKeys.length === 0}
|
||||
>
|
||||
在选中字段后添加
|
||||
</Button>
|
||||
)}
|
||||
{!readOnly && (
|
||||
<Button
|
||||
size="small"
|
||||
icon={<CopyOutlined />}
|
||||
onClick={openCopySelectedColumnsModal}
|
||||
disabled={selectedColumns.length === 0}
|
||||
@@ -2034,7 +2203,17 @@ END;`;
|
||||
<Tabs
|
||||
activeKey={activeKey}
|
||||
onChange={setActiveKey}
|
||||
style={{ flex: 1, padding: '0 10px' }}
|
||||
style={{
|
||||
flex: 1,
|
||||
minHeight: 0,
|
||||
padding: '8px 10px 10px 10px',
|
||||
borderBottomLeftRadius: panelRadius,
|
||||
borderBottomRightRadius: panelRadius,
|
||||
borderLeft: `1px solid ${panelFrameColor}`,
|
||||
borderRight: `1px solid ${panelFrameColor}`,
|
||||
borderBottom: `1px solid ${panelFrameColor}`,
|
||||
background: panelBodyBg
|
||||
}}
|
||||
items={[
|
||||
{
|
||||
key: 'columns',
|
||||
@@ -2276,7 +2455,7 @@ END;`;
|
||||
label: 'DDL',
|
||||
icon: <FileTextOutlined />,
|
||||
children: (
|
||||
<div style={{ height: 'calc(100vh - 200px)', border: darkMode ? '1px solid #303030' : '1px solid #d9d9d9', borderRadius: 4 }}>
|
||||
<div style={{ height: 'calc(100vh - 200px)', border: `1px solid ${panelFrameColor}`, borderRadius: panelRadius, background: panelBodyBg }}>
|
||||
<Editor
|
||||
height="100%"
|
||||
language="sql"
|
||||
@@ -2517,7 +2696,7 @@ END;`;
|
||||
<span><strong>时机:</strong> {selectedTrigger.timing}</span>
|
||||
<span><strong>事件:</strong> {selectedTrigger.event}</span>
|
||||
</div>
|
||||
<div style={{ border: darkMode ? '1px solid #303030' : '1px solid #d9d9d9', borderRadius: 4 }}>
|
||||
<div style={{ border: `1px solid ${panelFrameColor}`, borderRadius: panelRadius, background: panelBodyBg }}>
|
||||
<Editor
|
||||
height="350px"
|
||||
language="sql"
|
||||
@@ -2553,7 +2732,7 @@ END;`;
|
||||
<span>修改触发器时会先删除原触发器,再创建新触发器。</span>
|
||||
)}
|
||||
</div>
|
||||
<div style={{ border: darkMode ? '1px solid #303030' : '1px solid #d9d9d9', borderRadius: 4 }}>
|
||||
<div style={{ border: `1px solid ${panelFrameColor}`, borderRadius: panelRadius, background: panelBodyBg }}>
|
||||
<Editor
|
||||
height="350px"
|
||||
language="sql"
|
||||
|
||||
@@ -1,6 +1,14 @@
|
||||
import { create } from 'zustand';
|
||||
import { persist } from 'zustand/middleware';
|
||||
import { ConnectionConfig, ProxyConfig, SavedConnection, TabData, SavedQuery, ConnectionTag } from './types';
|
||||
import {
|
||||
ShortcutAction,
|
||||
ShortcutBinding,
|
||||
ShortcutOptions,
|
||||
DEFAULT_SHORTCUT_OPTIONS,
|
||||
cloneShortcutOptions,
|
||||
sanitizeShortcutOptions,
|
||||
} from './utils/shortcuts';
|
||||
|
||||
const DEFAULT_APPEARANCE = { opacity: 1.0, blur: 0 };
|
||||
const DEFAULT_UI_SCALE = 1.0;
|
||||
@@ -48,6 +56,23 @@ const SUPPORTED_CONNECTION_TYPES = new Set([
|
||||
'duckdb',
|
||||
'custom',
|
||||
]);
|
||||
const SSL_SUPPORTED_CONNECTION_TYPES = new Set([
|
||||
'mysql',
|
||||
'mariadb',
|
||||
'diros',
|
||||
'sphinx',
|
||||
'dameng',
|
||||
'clickhouse',
|
||||
'postgres',
|
||||
'sqlserver',
|
||||
'oracle',
|
||||
'kingbase',
|
||||
'highgo',
|
||||
'vastbase',
|
||||
'mongodb',
|
||||
'redis',
|
||||
'tdengine',
|
||||
]);
|
||||
|
||||
const getDefaultPortByType = (type: string): number => {
|
||||
switch (type) {
|
||||
@@ -177,6 +202,16 @@ const sanitizeConnectionConfig = (value: unknown): ConnectionConfig => {
|
||||
const defaultPort = getDefaultPortByType(type);
|
||||
const savePassword = typeof raw.savePassword === 'boolean' ? raw.savePassword : true;
|
||||
const mongoSrv = !!raw.mongoSrv;
|
||||
const sslCapable = SSL_SUPPORTED_CONNECTION_TYPES.has(type);
|
||||
const sslModeRaw = toTrimmedString(raw.sslMode, 'preferred').toLowerCase();
|
||||
const sslMode: 'preferred' | 'required' | 'skip-verify' | 'disable' =
|
||||
sslModeRaw === 'required'
|
||||
? 'required'
|
||||
: sslModeRaw === 'skip-verify'
|
||||
? 'skip-verify'
|
||||
: sslModeRaw === 'disable'
|
||||
? 'disable'
|
||||
: 'preferred';
|
||||
|
||||
const sshRaw = (raw.ssh && typeof raw.ssh === 'object') ? raw.ssh as Record<string, unknown> : {};
|
||||
const ssh = {
|
||||
@@ -206,6 +241,10 @@ const sanitizeConnectionConfig = (value: unknown): ConnectionConfig => {
|
||||
password: savePassword ? toTrimmedString(raw.password) : '',
|
||||
savePassword,
|
||||
database: toTrimmedString(raw.database),
|
||||
useSSL: sslCapable ? !!raw.useSSL : false,
|
||||
sslMode: sslCapable ? sslMode : 'disable',
|
||||
sslCertPath: sslCapable ? toTrimmedString(raw.sslCertPath) : '',
|
||||
sslKeyPath: sslCapable ? toTrimmedString(raw.sslKeyPath) : '',
|
||||
useSSH: !!raw.useSSH,
|
||||
ssh,
|
||||
useProxy: !!raw.useProxy,
|
||||
@@ -359,6 +398,7 @@ interface AppState {
|
||||
globalProxy: GlobalProxyConfig;
|
||||
sqlFormatOptions: { keywordCase: 'upper' | 'lower' };
|
||||
queryOptions: QueryOptions;
|
||||
shortcutOptions: ShortcutOptions;
|
||||
sqlLogs: SqlLog[];
|
||||
tableAccessCount: Record<string, number>;
|
||||
tableSortPreference: Record<string, 'name' | 'frequency'>;
|
||||
@@ -396,6 +436,8 @@ interface AppState {
|
||||
setGlobalProxy: (proxy: Partial<GlobalProxyConfig>) => void;
|
||||
setSqlFormatOptions: (options: { keywordCase: 'upper' | 'lower' }) => void;
|
||||
setQueryOptions: (options: Partial<QueryOptions>) => void;
|
||||
updateShortcut: (action: ShortcutAction, binding: Partial<ShortcutBinding>) => void;
|
||||
resetShortcutOptions: () => void;
|
||||
|
||||
addSqlLog: (log: SqlLog) => void;
|
||||
clearSqlLogs: () => void;
|
||||
@@ -537,6 +579,7 @@ export const useStore = create<AppState>()(
|
||||
globalProxy: { ...DEFAULT_GLOBAL_PROXY },
|
||||
sqlFormatOptions: { keywordCase: 'upper' },
|
||||
queryOptions: { maxRows: 5000, showColumnComment: true, showColumnType: true },
|
||||
shortcutOptions: cloneShortcutOptions(DEFAULT_SHORTCUT_OPTIONS),
|
||||
sqlLogs: [],
|
||||
tableAccessCount: {},
|
||||
tableSortPreference: {},
|
||||
@@ -708,6 +751,16 @@ export const useStore = create<AppState>()(
|
||||
setGlobalProxy: (proxy) => set((state) => ({ globalProxy: sanitizeGlobalProxy({ ...state.globalProxy, ...proxy }) })),
|
||||
setSqlFormatOptions: (options) => set({ sqlFormatOptions: options }),
|
||||
setQueryOptions: (options) => set((state) => ({ queryOptions: { ...state.queryOptions, ...options } })),
|
||||
updateShortcut: (action, binding) => set((state) => ({
|
||||
shortcutOptions: {
|
||||
...state.shortcutOptions,
|
||||
[action]: {
|
||||
...state.shortcutOptions[action],
|
||||
...binding,
|
||||
},
|
||||
},
|
||||
})),
|
||||
resetShortcutOptions: () => set({ shortcutOptions: cloneShortcutOptions(DEFAULT_SHORTCUT_OPTIONS) }),
|
||||
|
||||
addSqlLog: (log) => set((state) => ({ sqlLogs: [log, ...state.sqlLogs].slice(0, 1000) })), // Keep last 1000 logs
|
||||
clearSqlLogs: () => set({ sqlLogs: [] }),
|
||||
@@ -754,6 +807,7 @@ export const useStore = create<AppState>()(
|
||||
nextState.globalProxy = sanitizeGlobalProxy(state.globalProxy);
|
||||
nextState.sqlFormatOptions = sanitizeSqlFormatOptions(state.sqlFormatOptions);
|
||||
nextState.queryOptions = sanitizeQueryOptions(state.queryOptions);
|
||||
nextState.shortcutOptions = sanitizeShortcutOptions(state.shortcutOptions);
|
||||
nextState.tableAccessCount = sanitizeTableAccessCount(state.tableAccessCount);
|
||||
nextState.tableSortPreference = sanitizeTableSortPreference(state.tableSortPreference);
|
||||
return nextState as AppState;
|
||||
@@ -774,6 +828,7 @@ export const useStore = create<AppState>()(
|
||||
globalProxy: sanitizeGlobalProxy(state.globalProxy),
|
||||
sqlFormatOptions: sanitizeSqlFormatOptions(state.sqlFormatOptions),
|
||||
queryOptions: sanitizeQueryOptions(state.queryOptions),
|
||||
shortcutOptions: sanitizeShortcutOptions(state.shortcutOptions),
|
||||
tableAccessCount: sanitizeTableAccessCount(state.tableAccessCount),
|
||||
tableSortPreference: sanitizeTableSortPreference(state.tableSortPreference),
|
||||
};
|
||||
@@ -790,6 +845,7 @@ export const useStore = create<AppState>()(
|
||||
globalProxy: state.globalProxy,
|
||||
sqlFormatOptions: state.sqlFormatOptions,
|
||||
queryOptions: state.queryOptions,
|
||||
shortcutOptions: state.shortcutOptions,
|
||||
tableAccessCount: state.tableAccessCount,
|
||||
tableSortPreference: state.tableSortPreference
|
||||
}), // Don't persist logs
|
||||
|
||||
@@ -22,6 +22,10 @@ export interface ConnectionConfig {
|
||||
password?: string;
|
||||
savePassword?: boolean;
|
||||
database?: string;
|
||||
useSSL?: boolean;
|
||||
sslMode?: 'preferred' | 'required' | 'skip-verify' | 'disable';
|
||||
sslCertPath?: string;
|
||||
sslKeyPath?: string;
|
||||
useSSH?: boolean;
|
||||
ssh?: SSHConfig;
|
||||
useProxy?: boolean;
|
||||
|
||||
1014
frontend/src/utils/mongodb.ts
Normal file
1014
frontend/src/utils/mongodb.ts
Normal file
File diff suppressed because it is too large
Load Diff
258
frontend/src/utils/shortcuts.ts
Normal file
258
frontend/src/utils/shortcuts.ts
Normal file
@@ -0,0 +1,258 @@
|
||||
import type { KeyboardEvent as ReactKeyboardEvent } from 'react';
|
||||
|
||||
export type ShortcutAction =
|
||||
| 'runQuery'
|
||||
| 'focusSidebarSearch'
|
||||
| 'newQueryTab'
|
||||
| 'toggleLogPanel'
|
||||
| 'toggleTheme'
|
||||
| 'openShortcutManager';
|
||||
|
||||
export interface ShortcutBinding {
|
||||
combo: string;
|
||||
enabled: boolean;
|
||||
}
|
||||
|
||||
export type ShortcutOptions = Record<ShortcutAction, ShortcutBinding>;
|
||||
|
||||
export interface ShortcutActionMeta {
|
||||
label: string;
|
||||
description: string;
|
||||
allowInEditable?: boolean;
|
||||
}
|
||||
|
||||
const MODIFIER_ORDER = ['Ctrl', 'Meta', 'Alt', 'Shift'] as const;
|
||||
const MODIFIER_SET = new Set(MODIFIER_ORDER);
|
||||
|
||||
const KEY_ALIASES: Record<string, string> = {
|
||||
control: 'Ctrl',
|
||||
ctrl: 'Ctrl',
|
||||
command: 'Meta',
|
||||
cmd: 'Meta',
|
||||
meta: 'Meta',
|
||||
option: 'Alt',
|
||||
alt: 'Alt',
|
||||
shift: 'Shift',
|
||||
escape: 'Esc',
|
||||
esc: 'Esc',
|
||||
return: 'Enter',
|
||||
enter: 'Enter',
|
||||
tab: 'Tab',
|
||||
space: 'Space',
|
||||
' ': 'Space',
|
||||
backspace: 'Backspace',
|
||||
delete: 'Delete',
|
||||
del: 'Delete',
|
||||
arrowup: 'Up',
|
||||
up: 'Up',
|
||||
arrowdown: 'Down',
|
||||
down: 'Down',
|
||||
arrowleft: 'Left',
|
||||
left: 'Left',
|
||||
arrowright: 'Right',
|
||||
right: 'Right',
|
||||
pagedown: 'PageDown',
|
||||
pageup: 'PageUp',
|
||||
home: 'Home',
|
||||
end: 'End',
|
||||
insert: 'Insert',
|
||||
',': ',',
|
||||
'.': '.',
|
||||
'/': '/',
|
||||
';': ';',
|
||||
"'": "'",
|
||||
'[': '[',
|
||||
']': ']',
|
||||
'\\': '\\',
|
||||
'-': '-',
|
||||
'=': '=',
|
||||
'`': '`',
|
||||
};
|
||||
|
||||
export const SHORTCUT_ACTION_ORDER: ShortcutAction[] = [
|
||||
'runQuery',
|
||||
'focusSidebarSearch',
|
||||
'newQueryTab',
|
||||
'toggleLogPanel',
|
||||
'toggleTheme',
|
||||
'openShortcutManager',
|
||||
];
|
||||
|
||||
export const SHORTCUT_ACTION_META: Record<ShortcutAction, ShortcutActionMeta> = {
|
||||
runQuery: {
|
||||
label: '执行 SQL',
|
||||
description: '在当前查询页执行 SQL',
|
||||
},
|
||||
focusSidebarSearch: {
|
||||
label: '聚焦侧边栏搜索',
|
||||
description: '定位到左侧连接树搜索框',
|
||||
allowInEditable: true,
|
||||
},
|
||||
newQueryTab: {
|
||||
label: '新建查询页',
|
||||
description: '创建一个新的 SQL 查询标签页',
|
||||
},
|
||||
toggleLogPanel: {
|
||||
label: '切换日志面板',
|
||||
description: '打开或关闭 SQL 执行日志面板',
|
||||
},
|
||||
toggleTheme: {
|
||||
label: '切换主题',
|
||||
description: '在亮色和暗色主题之间切换',
|
||||
},
|
||||
openShortcutManager: {
|
||||
label: '打开快捷键管理',
|
||||
description: '打开快捷键设置面板',
|
||||
allowInEditable: true,
|
||||
},
|
||||
};
|
||||
|
||||
export const DEFAULT_SHORTCUT_OPTIONS: ShortcutOptions = {
|
||||
runQuery: { combo: 'Ctrl+Shift+R', enabled: true },
|
||||
focusSidebarSearch: { combo: 'Ctrl+F', enabled: true },
|
||||
newQueryTab: { combo: 'Ctrl+Shift+N', enabled: true },
|
||||
toggleLogPanel: { combo: 'Ctrl+Shift+L', enabled: true },
|
||||
toggleTheme: { combo: 'Ctrl+Shift+D', enabled: true },
|
||||
openShortcutManager: { combo: 'Ctrl+,', enabled: true },
|
||||
};
|
||||
|
||||
const normalizeKeyToken = (value: string): string => {
|
||||
const token = String(value || '').trim();
|
||||
if (!token) return '';
|
||||
const alias = KEY_ALIASES[token.toLowerCase()];
|
||||
if (alias) return alias;
|
||||
if (/^f([1-9]|1[0-2])$/i.test(token)) {
|
||||
return token.toUpperCase();
|
||||
}
|
||||
if (token.length === 1) {
|
||||
return token === '+' ? '+' : token.toUpperCase();
|
||||
}
|
||||
return token.length > 1 ? token[0].toUpperCase() + token.slice(1).toLowerCase() : token;
|
||||
};
|
||||
|
||||
export const normalizeShortcutCombo = (combo: string): string => {
|
||||
const raw = String(combo || '').trim();
|
||||
if (!raw) return '';
|
||||
|
||||
const pieces = raw
|
||||
.split('+')
|
||||
.map(part => part.trim())
|
||||
.filter(Boolean);
|
||||
|
||||
const modifiers: string[] = [];
|
||||
let key = '';
|
||||
|
||||
pieces.forEach((part) => {
|
||||
const normalized = normalizeKeyToken(part);
|
||||
if (!normalized) return;
|
||||
if (MODIFIER_SET.has(normalized as typeof MODIFIER_ORDER[number])) {
|
||||
if (!modifiers.includes(normalized)) {
|
||||
modifiers.push(normalized);
|
||||
}
|
||||
return;
|
||||
}
|
||||
key = normalized;
|
||||
});
|
||||
|
||||
modifiers.sort((a, b) => MODIFIER_ORDER.indexOf(a as typeof MODIFIER_ORDER[number]) - MODIFIER_ORDER.indexOf(b as typeof MODIFIER_ORDER[number]));
|
||||
if (!key) {
|
||||
return modifiers.join('+');
|
||||
}
|
||||
return [...modifiers, key].join('+');
|
||||
};
|
||||
|
||||
const normalizeKeyboardKey = (key: string): string => {
|
||||
const token = String(key || '').trim();
|
||||
if (!token) return '';
|
||||
const alias = KEY_ALIASES[token.toLowerCase()];
|
||||
if (alias) return alias;
|
||||
if (token.length === 1) {
|
||||
if (token === ' ') return 'Space';
|
||||
return token.toUpperCase();
|
||||
}
|
||||
if (/^f([1-9]|1[0-2])$/i.test(token)) {
|
||||
return token.toUpperCase();
|
||||
}
|
||||
return token.length > 1 ? token[0].toUpperCase() + token.slice(1) : token;
|
||||
};
|
||||
|
||||
export const eventToShortcut = (event: KeyboardEvent | ReactKeyboardEvent): string => {
|
||||
const key = normalizeKeyboardKey(event.key);
|
||||
if (!key || MODIFIER_SET.has(key as typeof MODIFIER_ORDER[number])) {
|
||||
return '';
|
||||
}
|
||||
|
||||
const modifiers: string[] = [];
|
||||
if (event.ctrlKey) modifiers.push('Ctrl');
|
||||
if (event.metaKey) modifiers.push('Meta');
|
||||
if (event.altKey) modifiers.push('Alt');
|
||||
if (event.shiftKey) modifiers.push('Shift');
|
||||
|
||||
return normalizeShortcutCombo([...modifiers, key].join('+'));
|
||||
};
|
||||
|
||||
export const isShortcutMatch = (event: KeyboardEvent | ReactKeyboardEvent, combo: string): boolean => {
|
||||
const expected = normalizeShortcutCombo(combo);
|
||||
if (!expected) return false;
|
||||
const actual = eventToShortcut(event);
|
||||
return actual === expected;
|
||||
};
|
||||
|
||||
export const hasModifierKey = (combo: string): boolean => {
|
||||
const normalized = normalizeShortcutCombo(combo);
|
||||
if (!normalized) return false;
|
||||
return normalized.split('+').some(part => MODIFIER_SET.has(part as typeof MODIFIER_ORDER[number]));
|
||||
};
|
||||
|
||||
export const cloneShortcutOptions = (value: ShortcutOptions): ShortcutOptions => {
|
||||
return SHORTCUT_ACTION_ORDER.reduce((acc, action) => {
|
||||
acc[action] = {
|
||||
combo: normalizeShortcutCombo(value[action]?.combo || DEFAULT_SHORTCUT_OPTIONS[action].combo),
|
||||
enabled: value[action]?.enabled !== false,
|
||||
};
|
||||
return acc;
|
||||
}, {} as ShortcutOptions);
|
||||
};
|
||||
|
||||
export const sanitizeShortcutOptions = (value: unknown): ShortcutOptions => {
|
||||
const raw = (value && typeof value === 'object') ? value as Record<string, unknown> : {};
|
||||
const defaults = cloneShortcutOptions(DEFAULT_SHORTCUT_OPTIONS);
|
||||
|
||||
SHORTCUT_ACTION_ORDER.forEach((action) => {
|
||||
const actionRaw = raw[action];
|
||||
if (!actionRaw || typeof actionRaw !== 'object') {
|
||||
return;
|
||||
}
|
||||
const binding = actionRaw as Record<string, unknown>;
|
||||
const combo = normalizeShortcutCombo(String(binding.combo || defaults[action].combo));
|
||||
defaults[action] = {
|
||||
combo: combo || defaults[action].combo,
|
||||
enabled: binding.enabled === false ? false : true,
|
||||
};
|
||||
});
|
||||
|
||||
return defaults;
|
||||
};
|
||||
|
||||
export const isEditableElement = (target: EventTarget | null): boolean => {
|
||||
if (!(target instanceof HTMLElement)) {
|
||||
return false;
|
||||
}
|
||||
const tag = target.tagName.toLowerCase();
|
||||
if (target.isContentEditable) {
|
||||
return true;
|
||||
}
|
||||
if (tag === 'input' || tag === 'textarea' || tag === 'select') {
|
||||
return true;
|
||||
}
|
||||
if (target.closest('.monaco-editor, .monaco-inputbox, .ant-select, .ant-picker, .ant-input')) {
|
||||
return true;
|
||||
}
|
||||
return false;
|
||||
};
|
||||
|
||||
export const getShortcutDisplay = (combo: string): string => {
|
||||
const normalized = normalizeShortcutCombo(combo);
|
||||
return normalized || '-';
|
||||
};
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
export type FilterCondition = {
|
||||
id?: number;
|
||||
enabled?: boolean;
|
||||
logic?: 'AND' | 'OR';
|
||||
column?: string;
|
||||
op?: string;
|
||||
value?: string;
|
||||
@@ -142,8 +143,12 @@ export const parseListValues = (val: string) => {
|
||||
.filter(Boolean);
|
||||
};
|
||||
|
||||
const normalizeConditionLogic = (logic: unknown): 'AND' | 'OR' => {
|
||||
return String(logic || '').trim().toUpperCase() === 'OR' ? 'OR' : 'AND';
|
||||
};
|
||||
|
||||
export const buildWhereSQL = (dbType: string, conditions: FilterCondition[]) => {
|
||||
const whereParts: string[] = [];
|
||||
const whereParts: Array<{ expr: string; logic: 'AND' | 'OR' }> = [];
|
||||
|
||||
(conditions || []).forEach((cond) => {
|
||||
if (cond?.enabled === false) return;
|
||||
@@ -152,10 +157,17 @@ export const buildWhereSQL = (dbType: string, conditions: FilterCondition[]) =>
|
||||
const column = (cond?.column || '').trim();
|
||||
const value = (cond?.value ?? '').toString();
|
||||
const value2 = (cond?.value2 ?? '').toString();
|
||||
const logic = normalizeConditionLogic(cond?.logic);
|
||||
|
||||
const appendWherePart = (expr: string) => {
|
||||
const normalizedExpr = String(expr || '').trim();
|
||||
if (!normalizedExpr) return;
|
||||
whereParts.push({ expr: normalizedExpr, logic });
|
||||
};
|
||||
|
||||
if (op === 'CUSTOM') {
|
||||
const expr = value.trim();
|
||||
if (expr) whereParts.push(`(${expr})`);
|
||||
if (expr) appendWherePart(`(${expr})`);
|
||||
return;
|
||||
}
|
||||
|
||||
@@ -165,80 +177,80 @@ export const buildWhereSQL = (dbType: string, conditions: FilterCondition[]) =>
|
||||
|
||||
switch (op) {
|
||||
case 'IS_NULL':
|
||||
whereParts.push(`${col} IS NULL`);
|
||||
appendWherePart(`${col} IS NULL`);
|
||||
return;
|
||||
case 'IS_NOT_NULL':
|
||||
whereParts.push(`${col} IS NOT NULL`);
|
||||
appendWherePart(`${col} IS NOT NULL`);
|
||||
return;
|
||||
case 'IS_EMPTY':
|
||||
// 兼容:空值通常理解为 NULL 或空字符串
|
||||
whereParts.push(`(${col} IS NULL OR ${col} = '')`);
|
||||
appendWherePart(`(${col} IS NULL OR ${col} = '')`);
|
||||
return;
|
||||
case 'IS_NOT_EMPTY':
|
||||
whereParts.push(`(${col} IS NOT NULL AND ${col} <> '')`);
|
||||
appendWherePart(`(${col} IS NOT NULL AND ${col} <> '')`);
|
||||
return;
|
||||
case 'BETWEEN': {
|
||||
const v1 = value.trim();
|
||||
const v2 = value2.trim();
|
||||
if (!v1 || !v2) return;
|
||||
whereParts.push(`${col} BETWEEN '${escapeLiteral(v1)}' AND '${escapeLiteral(v2)}'`);
|
||||
appendWherePart(`${col} BETWEEN '${escapeLiteral(v1)}' AND '${escapeLiteral(v2)}'`);
|
||||
return;
|
||||
}
|
||||
case 'NOT_BETWEEN': {
|
||||
const v1 = value.trim();
|
||||
const v2 = value2.trim();
|
||||
if (!v1 || !v2) return;
|
||||
whereParts.push(`${col} NOT BETWEEN '${escapeLiteral(v1)}' AND '${escapeLiteral(v2)}'`);
|
||||
appendWherePart(`${col} NOT BETWEEN '${escapeLiteral(v1)}' AND '${escapeLiteral(v2)}'`);
|
||||
return;
|
||||
}
|
||||
case 'IN': {
|
||||
const items = parseListValues(value);
|
||||
if (items.length === 0) return;
|
||||
const list = items.map(v => `'${escapeLiteral(v)}'`).join(', ');
|
||||
whereParts.push(`${col} IN (${list})`);
|
||||
appendWherePart(`${col} IN (${list})`);
|
||||
return;
|
||||
}
|
||||
case 'NOT_IN': {
|
||||
const items = parseListValues(value);
|
||||
if (items.length === 0) return;
|
||||
const list = items.map(v => `'${escapeLiteral(v)}'`).join(', ');
|
||||
whereParts.push(`${col} NOT IN (${list})`);
|
||||
appendWherePart(`${col} NOT IN (${list})`);
|
||||
return;
|
||||
}
|
||||
case 'CONTAINS': {
|
||||
const v = value.trim();
|
||||
if (!v) return;
|
||||
whereParts.push(`${col} LIKE '%${escapeLiteral(v)}%'`);
|
||||
appendWherePart(`${col} LIKE '%${escapeLiteral(v)}%'`);
|
||||
return;
|
||||
}
|
||||
case 'NOT_CONTAINS': {
|
||||
const v = value.trim();
|
||||
if (!v) return;
|
||||
whereParts.push(`${col} NOT LIKE '%${escapeLiteral(v)}%'`);
|
||||
appendWherePart(`${col} NOT LIKE '%${escapeLiteral(v)}%'`);
|
||||
return;
|
||||
}
|
||||
case 'STARTS_WITH': {
|
||||
const v = value.trim();
|
||||
if (!v) return;
|
||||
whereParts.push(`${col} LIKE '${escapeLiteral(v)}%'`);
|
||||
appendWherePart(`${col} LIKE '${escapeLiteral(v)}%'`);
|
||||
return;
|
||||
}
|
||||
case 'NOT_STARTS_WITH': {
|
||||
const v = value.trim();
|
||||
if (!v) return;
|
||||
whereParts.push(`${col} NOT LIKE '${escapeLiteral(v)}%'`);
|
||||
appendWherePart(`${col} NOT LIKE '${escapeLiteral(v)}%'`);
|
||||
return;
|
||||
}
|
||||
case 'ENDS_WITH': {
|
||||
const v = value.trim();
|
||||
if (!v) return;
|
||||
whereParts.push(`${col} LIKE '%${escapeLiteral(v)}'`);
|
||||
appendWherePart(`${col} LIKE '%${escapeLiteral(v)}'`);
|
||||
return;
|
||||
}
|
||||
case 'NOT_ENDS_WITH': {
|
||||
const v = value.trim();
|
||||
if (!v) return;
|
||||
whereParts.push(`${col} NOT LIKE '%${escapeLiteral(v)}'`);
|
||||
appendWherePart(`${col} NOT LIKE '%${escapeLiteral(v)}'`);
|
||||
return;
|
||||
}
|
||||
case '=':
|
||||
@@ -249,7 +261,7 @@ export const buildWhereSQL = (dbType: string, conditions: FilterCondition[]) =>
|
||||
case '>=': {
|
||||
const v = value.trim();
|
||||
if (!v) return;
|
||||
whereParts.push(`${col} ${op} '${escapeLiteral(v)}'`);
|
||||
appendWherePart(`${col} ${op} '${escapeLiteral(v)}'`);
|
||||
return;
|
||||
}
|
||||
default: {
|
||||
@@ -257,16 +269,23 @@ export const buildWhereSQL = (dbType: string, conditions: FilterCondition[]) =>
|
||||
if (op.toUpperCase() === 'LIKE') {
|
||||
const v = value.trim();
|
||||
if (!v) return;
|
||||
whereParts.push(`${col} LIKE '%${escapeLiteral(v)}%'`);
|
||||
appendWherePart(`${col} LIKE '%${escapeLiteral(v)}%'`);
|
||||
return;
|
||||
}
|
||||
|
||||
const v = value.trim();
|
||||
if (!v) return;
|
||||
whereParts.push(`${col} ${op} '${escapeLiteral(v)}'`);
|
||||
appendWherePart(`${col} ${op} '${escapeLiteral(v)}'`);
|
||||
}
|
||||
}
|
||||
});
|
||||
|
||||
return whereParts.length > 0 ? `WHERE ${whereParts.join(' AND ')}` : '';
|
||||
if (whereParts.length === 0) return '';
|
||||
|
||||
let whereExpr = `(${whereParts[0].expr})`;
|
||||
for (let i = 1; i < whereParts.length; i++) {
|
||||
const part = whereParts[i];
|
||||
whereExpr = `(${whereExpr} ${part.logic} (${part.expr}))`;
|
||||
}
|
||||
return `WHERE ${whereExpr}`;
|
||||
};
|
||||
|
||||
@@ -96,6 +96,10 @@ export namespace connection {
|
||||
password: string;
|
||||
savePassword?: boolean;
|
||||
database: string;
|
||||
useSSL?: boolean;
|
||||
sslMode?: string;
|
||||
sslCertPath?: string;
|
||||
sslKeyPath?: string;
|
||||
useSSH: boolean;
|
||||
ssh: SSHConfig;
|
||||
useProxy?: boolean;
|
||||
@@ -130,6 +134,10 @@ export namespace connection {
|
||||
this.password = source["password"];
|
||||
this.savePassword = source["savePassword"];
|
||||
this.database = source["database"];
|
||||
this.useSSL = source["useSSL"];
|
||||
this.sslMode = source["sslMode"];
|
||||
this.sslCertPath = source["sslCertPath"];
|
||||
this.sslKeyPath = source["sslKeyPath"];
|
||||
this.useSSH = source["useSSH"];
|
||||
this.ssh = this.convertValues(source["ssh"], SSHConfig);
|
||||
this.useProxy = source["useProxy"];
|
||||
|
||||
2
go.mod
2
go.mod
@@ -17,6 +17,7 @@ require (
|
||||
github.com/taosdata/driver-go/v3 v3.7.8
|
||||
github.com/wailsapp/wails/v2 v2.11.0
|
||||
github.com/xuri/excelize/v2 v2.10.0
|
||||
go.mongodb.org/mongo-driver v1.17.9
|
||||
go.mongodb.org/mongo-driver/v2 v2.5.0
|
||||
golang.org/x/crypto v0.47.0
|
||||
golang.org/x/mod v0.32.0
|
||||
@@ -66,6 +67,7 @@ require (
|
||||
github.com/mattn/go-isatty v0.0.20 // indirect
|
||||
github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd // indirect
|
||||
github.com/modern-go/reflect2 v1.0.2 // indirect
|
||||
github.com/montanaflynn/stats v0.7.1 // indirect
|
||||
github.com/ncruces/go-strftime v1.0.0 // indirect
|
||||
github.com/paulmach/orb v0.12.0 // indirect
|
||||
github.com/pierrec/lz4/v4 v4.1.25 // indirect
|
||||
|
||||
4
go.sum
4
go.sum
@@ -156,6 +156,8 @@ github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd/go.mod h1:6dJ
|
||||
github.com/modern-go/reflect2 v1.0.2 h1:xBagoLtFs94CBntxluKeaWgTMpvLxC4ur3nMaC9Gz0M=
|
||||
github.com/modern-go/reflect2 v1.0.2/go.mod h1:yWuevngMOJpCy52FWWMvUC8ws7m/LJsjYzDa0/r8luk=
|
||||
github.com/montanaflynn/stats v0.0.0-20171201202039-1bf9dbcd8cbe/go.mod h1:wL8QJuTMNUDYhXwkmfOly8iTdp5TEcJFWZD2D7SIkUc=
|
||||
github.com/montanaflynn/stats v0.7.1 h1:etflOAAHORrCC44V+aR6Ftzort912ZU+YLiSTuV8eaE=
|
||||
github.com/montanaflynn/stats v0.7.1/go.mod h1:etXPPgVO6n31NxCd9KQUMvCM+ve0ruNzt6R8Bnaayow=
|
||||
github.com/ncruces/go-strftime v1.0.0 h1:HMFp8mLCTPp341M/ZnA4qaf7ZlsbTc+miZjCLOFAw7w=
|
||||
github.com/ncruces/go-strftime v1.0.0/go.mod h1:Fwc5htZGVVkseilnfgOVb9mKy6w1naJmn9CehxcKcls=
|
||||
github.com/niemeyer/pretty v0.0.0-20200227124842-a10e7caefd8e h1:fD57ERR4JtEqsWbfPhv4DMiApHyliiK5xCTNVSPiaAs=
|
||||
@@ -248,6 +250,8 @@ github.com/zeebo/assert v1.3.0/go.mod h1:Pq9JiuJQpG8JLJdtkwrJESF0Foym2/D9XMU5ciN
|
||||
github.com/zeebo/xxh3 v1.1.0 h1:s7DLGDK45Dyfg7++yxI0khrfwq9661w9EN78eP/UZVs=
|
||||
github.com/zeebo/xxh3 v1.1.0/go.mod h1:IisAie1LELR4xhVinxWS5+zf1lA4p0MW4T+w+W07F5s=
|
||||
go.mongodb.org/mongo-driver v1.11.4/go.mod h1:PTSz5yu21bkT/wXpkS7WR5f0ddqw5quethTUn9WM+2g=
|
||||
go.mongodb.org/mongo-driver v1.17.9 h1:IexDdCuuNJ3BHrELgBlyaH9p60JXAvdzWR128q+U5tU=
|
||||
go.mongodb.org/mongo-driver v1.17.9/go.mod h1:LlOhpH5NUEfhxcAwG0UEkMqwYcc4JU18gtCdGudk/tQ=
|
||||
go.mongodb.org/mongo-driver/v2 v2.5.0 h1:yXUhImUjjAInNcpTcAlPHiT7bIXhshCTL3jVBkF3xaE=
|
||||
go.mongodb.org/mongo-driver/v2 v2.5.0/go.mod h1:yOI9kBsufol30iFsl1slpdq1I0eHPzybRWdyYUs8K/0=
|
||||
go.opentelemetry.io/otel v1.39.0 h1:8yPrr/S0ND9QEfTfdP9V+SiwT4E0G7Y5MO7p85nis48=
|
||||
|
||||
@@ -148,6 +148,67 @@ func getCacheKey(config connection.ConnectionConfig) string {
|
||||
return hex.EncodeToString(sum[:])
|
||||
}
|
||||
|
||||
func shortCacheKey(cacheKey string) string {
|
||||
shortKey := cacheKey
|
||||
if len(shortKey) > 12 {
|
||||
shortKey = shortKey[:12]
|
||||
}
|
||||
return shortKey
|
||||
}
|
||||
|
||||
func shouldRefreshCachedConnection(err error) bool {
|
||||
if err == nil {
|
||||
return false
|
||||
}
|
||||
normalized := strings.ToLower(normalizeErrorMessage(err))
|
||||
if normalized == "" {
|
||||
return false
|
||||
}
|
||||
|
||||
patterns := []string{
|
||||
"invalid connection",
|
||||
"bad connection",
|
||||
"database is closed",
|
||||
"connection is already closed",
|
||||
"use of closed network connection",
|
||||
"broken pipe",
|
||||
"connection reset by peer",
|
||||
"server has gone away",
|
||||
"eof",
|
||||
}
|
||||
for _, pattern := range patterns {
|
||||
if strings.Contains(normalized, pattern) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func (a *App) invalidateCachedDatabase(config connection.ConnectionConfig, reason error) bool {
|
||||
effectiveConfig := applyGlobalProxyToConnection(config)
|
||||
key := getCacheKey(effectiveConfig)
|
||||
shortKey := shortCacheKey(key)
|
||||
|
||||
a.mu.Lock()
|
||||
defer a.mu.Unlock()
|
||||
|
||||
entry, exists := a.dbCache[key]
|
||||
if !exists || entry.inst == nil {
|
||||
return false
|
||||
}
|
||||
|
||||
if closeErr := entry.inst.Close(); closeErr != nil {
|
||||
logger.Error(closeErr, "关闭失效缓存连接失败:缓存Key=%s", shortKey)
|
||||
}
|
||||
delete(a.dbCache, key)
|
||||
if reason != nil {
|
||||
logger.Errorf("检测到连接失效,已清理缓存连接:%s 缓存Key=%s 原因=%s", formatConnSummary(effectiveConfig), shortKey, normalizeErrorMessage(reason))
|
||||
} else {
|
||||
logger.Infof("已清理缓存连接:%s 缓存Key=%s", formatConnSummary(effectiveConfig), shortKey)
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
func wrapConnectError(config connection.ConnectionConfig, err error) error {
|
||||
if err == nil {
|
||||
return nil
|
||||
|
||||
@@ -36,6 +36,17 @@ func normalizeSchemaAndTable(config connection.ConnectionConfig, dbName string,
|
||||
return rawDB, rawTable
|
||||
}
|
||||
|
||||
dbType := strings.ToLower(strings.TrimSpace(config.Type))
|
||||
if dbType == "sqlserver" {
|
||||
// SQL Server 的 DB 接口约定:第一个参数是数据库名,schema 由 tableName(如 dbo.users) 自行解析。
|
||||
// 不能把 schema(dbo) 传到第一个参数,否则会拼出 dbo.sys.columns 等无效对象名。
|
||||
targetDB := rawDB
|
||||
if targetDB == "" {
|
||||
targetDB = strings.TrimSpace(config.Database)
|
||||
}
|
||||
return targetDB, rawTable
|
||||
}
|
||||
|
||||
if parts := strings.SplitN(rawTable, ".", 2); len(parts) == 2 {
|
||||
schema := strings.TrimSpace(parts[0])
|
||||
table := strings.TrimSpace(parts[1])
|
||||
@@ -44,13 +55,10 @@ func normalizeSchemaAndTable(config connection.ConnectionConfig, dbName string,
|
||||
}
|
||||
}
|
||||
|
||||
switch strings.ToLower(strings.TrimSpace(config.Type)) {
|
||||
switch dbType {
|
||||
case "postgres", "kingbase", "highgo", "vastbase":
|
||||
// PG/金仓/瀚高/海量:dbName 在 UI 里是"数据库",schema 需从 tableName 或使用默认 public。
|
||||
return "public", rawTable
|
||||
case "sqlserver":
|
||||
// SQL Server:dbName 表示数据库,schema 默认 dbo
|
||||
return "dbo", rawTable
|
||||
default:
|
||||
// MySQL:dbName 表示数据库;Oracle/达梦:dbName 表示 schema/owner。
|
||||
return rawDB, rawTable
|
||||
|
||||
51
internal/app/db_context_test.go
Normal file
51
internal/app/db_context_test.go
Normal file
@@ -0,0 +1,51 @@
|
||||
package app
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"GoNavi-Wails/internal/connection"
|
||||
)
|
||||
|
||||
func TestNormalizeSchemaAndTable_SQLServerKeepsDatabaseAndQualifiedTable(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
schemaOrDb, table := normalizeSchemaAndTable(connection.ConnectionConfig{
|
||||
Type: "sqlserver",
|
||||
Database: "master",
|
||||
}, "biz_db", "dbo.users")
|
||||
|
||||
if schemaOrDb != "biz_db" {
|
||||
t.Fatalf("expected sqlserver first return value as database name, got %q", schemaOrDb)
|
||||
}
|
||||
if table != "dbo.users" {
|
||||
t.Fatalf("expected sqlserver table name keep qualified form, got %q", table)
|
||||
}
|
||||
}
|
||||
|
||||
func TestNormalizeSchemaAndTable_SQLServerFallbackToConfigDatabase(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
schemaOrDb, table := normalizeSchemaAndTable(connection.ConnectionConfig{
|
||||
Type: "sqlserver",
|
||||
Database: "biz_db",
|
||||
}, "", "dbo.users")
|
||||
|
||||
if schemaOrDb != "biz_db" {
|
||||
t.Fatalf("expected sqlserver fallback database from config, got %q", schemaOrDb)
|
||||
}
|
||||
if table != "dbo.users" {
|
||||
t.Fatalf("expected sqlserver table name keep qualified form, got %q", table)
|
||||
}
|
||||
}
|
||||
|
||||
func TestNormalizeSchemaAndTable_PostgresStillSplitsQualifiedName(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
schema, table := normalizeSchemaAndTable(connection.ConnectionConfig{
|
||||
Type: "postgres",
|
||||
}, "demo_db", "public.orders")
|
||||
|
||||
if schema != "public" || table != "orders" {
|
||||
t.Fatalf("expected postgres qualified split to public.orders, got %q.%q", schema, table)
|
||||
}
|
||||
}
|
||||
@@ -422,15 +422,36 @@ func (a *App) DBQueryWithCancel(config connection.ConnectionConfig, dbName strin
|
||||
if !isReadQuery && strings.ToLower(strings.TrimSpace(runConfig.Type)) == "mongodb" && strings.HasPrefix(strings.TrimSpace(query), "{") {
|
||||
isReadQuery = true
|
||||
}
|
||||
if isReadQuery {
|
||||
var data []map[string]interface{}
|
||||
var columns []string
|
||||
if q, ok := dbInst.(interface {
|
||||
|
||||
runReadQuery := func(inst db.Database) ([]map[string]interface{}, []string, error) {
|
||||
if q, ok := inst.(interface {
|
||||
QueryContext(context.Context, string) ([]map[string]interface{}, []string, error)
|
||||
}); ok {
|
||||
data, columns, err = q.QueryContext(ctx, query)
|
||||
} else {
|
||||
data, columns, err = dbInst.Query(query)
|
||||
return q.QueryContext(ctx, query)
|
||||
}
|
||||
return inst.Query(query)
|
||||
}
|
||||
|
||||
runExecQuery := func(inst db.Database) (int64, error) {
|
||||
if e, ok := inst.(interface {
|
||||
ExecContext(context.Context, string) (int64, error)
|
||||
}); ok {
|
||||
return e.ExecContext(ctx, query)
|
||||
}
|
||||
return inst.Exec(query)
|
||||
}
|
||||
|
||||
if isReadQuery {
|
||||
data, columns, err := runReadQuery(dbInst)
|
||||
if err != nil && shouldRefreshCachedConnection(err) {
|
||||
if a.invalidateCachedDatabase(runConfig, err) {
|
||||
retryInst, retryErr := a.getDatabaseForcePing(runConfig)
|
||||
if retryErr != nil {
|
||||
logger.Error(retryErr, "DBQuery 重建连接失败:%s SQL片段=%q", formatConnSummary(runConfig), sqlSnippet(query))
|
||||
return connection.QueryResult{Success: false, Message: retryErr.Error()}
|
||||
}
|
||||
data, columns, err = runReadQuery(retryInst)
|
||||
}
|
||||
}
|
||||
if err != nil {
|
||||
logger.Error(err, "DBQuery 查询失败:%s SQL片段=%q", formatConnSummary(runConfig), sqlSnippet(query))
|
||||
@@ -438,13 +459,16 @@ func (a *App) DBQueryWithCancel(config connection.ConnectionConfig, dbName strin
|
||||
}
|
||||
return connection.QueryResult{Success: true, Data: data, Fields: columns, QueryID: queryID}
|
||||
} else {
|
||||
var affected int64
|
||||
if e, ok := dbInst.(interface {
|
||||
ExecContext(context.Context, string) (int64, error)
|
||||
}); ok {
|
||||
affected, err = e.ExecContext(ctx, query)
|
||||
} else {
|
||||
affected, err = dbInst.Exec(query)
|
||||
affected, err := runExecQuery(dbInst)
|
||||
if err != nil && shouldRefreshCachedConnection(err) {
|
||||
if a.invalidateCachedDatabase(runConfig, err) {
|
||||
retryInst, retryErr := a.getDatabaseForcePing(runConfig)
|
||||
if retryErr != nil {
|
||||
logger.Error(retryErr, "DBQuery 重建连接失败:%s SQL片段=%q", formatConnSummary(runConfig), sqlSnippet(query))
|
||||
return connection.QueryResult{Success: false, Message: retryErr.Error()}
|
||||
}
|
||||
affected, err = runExecQuery(retryInst)
|
||||
}
|
||||
}
|
||||
if err != nil {
|
||||
logger.Error(err, "DBQuery 执行失败:%s SQL片段=%q", formatConnSummary(runConfig), sqlSnippet(query))
|
||||
@@ -524,15 +548,26 @@ func sqlSnippet(query string) string {
|
||||
}
|
||||
|
||||
func (a *App) DBGetDatabases(config connection.ConnectionConfig) connection.QueryResult {
|
||||
dbInst, err := a.getDatabase(config)
|
||||
runConfig := normalizeRunConfig(config, "")
|
||||
dbInst, err := a.getDatabase(runConfig)
|
||||
if err != nil {
|
||||
logger.Error(err, "DBGetDatabases 获取连接失败:%s", formatConnSummary(config))
|
||||
logger.Error(err, "DBGetDatabases 获取连接失败:%s", formatConnSummary(runConfig))
|
||||
return connection.QueryResult{Success: false, Message: err.Error()}
|
||||
}
|
||||
|
||||
dbs, err := dbInst.GetDatabases()
|
||||
if err != nil && shouldRefreshCachedConnection(err) {
|
||||
if a.invalidateCachedDatabase(runConfig, err) {
|
||||
retryInst, retryErr := a.getDatabaseForcePing(runConfig)
|
||||
if retryErr != nil {
|
||||
logger.Error(retryErr, "DBGetDatabases 重建连接失败:%s", formatConnSummary(runConfig))
|
||||
return connection.QueryResult{Success: false, Message: retryErr.Error()}
|
||||
}
|
||||
dbs, err = retryInst.GetDatabases()
|
||||
}
|
||||
}
|
||||
if err != nil {
|
||||
logger.Error(err, "DBGetDatabases 获取数据库列表失败:%s", formatConnSummary(config))
|
||||
logger.Error(err, "DBGetDatabases 获取数据库列表失败:%s", formatConnSummary(runConfig))
|
||||
return connection.QueryResult{Success: false, Message: err.Error()}
|
||||
}
|
||||
|
||||
@@ -554,6 +589,16 @@ func (a *App) DBGetTables(config connection.ConnectionConfig, dbName string) con
|
||||
}
|
||||
|
||||
tables, err := dbInst.GetTables(dbName)
|
||||
if err != nil && shouldRefreshCachedConnection(err) {
|
||||
if a.invalidateCachedDatabase(runConfig, err) {
|
||||
retryInst, retryErr := a.getDatabaseForcePing(runConfig)
|
||||
if retryErr != nil {
|
||||
logger.Error(retryErr, "DBGetTables 重建连接失败:%s", formatConnSummary(runConfig))
|
||||
return connection.QueryResult{Success: false, Message: retryErr.Error()}
|
||||
}
|
||||
tables, err = retryInst.GetTables(dbName)
|
||||
}
|
||||
}
|
||||
if err != nil {
|
||||
logger.Error(err, "DBGetTables 获取表列表失败:%s", formatConnSummary(runConfig))
|
||||
return connection.QueryResult{Success: false, Message: err.Error()}
|
||||
|
||||
@@ -304,13 +304,33 @@ var driverGoModulePathMap = map[string]string{
|
||||
"clickhouse": "github.com/ClickHouse/clickhouse-go/v2",
|
||||
}
|
||||
|
||||
var driverGoModuleAliasPathMap = map[string][]string{
|
||||
"mongodb": {
|
||||
"go.mongodb.org/mongo-driver",
|
||||
},
|
||||
}
|
||||
|
||||
var driverExtraHistoryLimitMap = map[string]int{
|
||||
"mongodb": 10,
|
||||
}
|
||||
|
||||
var fallbackRecentDriverVersionsMap = map[string][]goModuleVersionMeta{
|
||||
"mongodb": {
|
||||
{Version: "2.5.0"},
|
||||
{Version: "2.4.2"},
|
||||
{Version: "2.4.1"},
|
||||
{Version: "2.4.0"},
|
||||
{Version: "2.3.1"},
|
||||
{Version: "2.3.0"},
|
||||
{Version: "2.2.3"},
|
||||
{Version: "1.17.9"},
|
||||
{Version: "1.17.8"},
|
||||
{Version: "1.17.7"},
|
||||
{Version: "1.17.6"},
|
||||
{Version: "1.17.4"},
|
||||
{Version: "1.17.3"},
|
||||
{Version: "1.17.2"},
|
||||
{Version: "1.17.1"},
|
||||
{Version: "1.17.0"},
|
||||
{Version: "1.16.1"},
|
||||
},
|
||||
}
|
||||
|
||||
@@ -1600,17 +1620,57 @@ func resolveRecentDriverVersionMetas(driverType string, limit int) []goModuleVer
|
||||
if normalized == "" {
|
||||
return nil
|
||||
}
|
||||
if modulePath := strings.TrimSpace(driverGoModulePathMap[normalized]); modulePath != "" {
|
||||
if metas := fetchGoModuleVersionMetasCached(modulePath); len(metas) > 0 {
|
||||
if len(metas) > limit {
|
||||
return append([]goModuleVersionMeta(nil), metas[:limit]...)
|
||||
modulePaths := resolveDriverGoModulePaths(normalized)
|
||||
if len(modulePaths) > 0 {
|
||||
result := make([]goModuleVersionMeta, 0, limit)
|
||||
seen := make(map[string]struct{}, limit)
|
||||
appendUnique := func(values []goModuleVersionMeta, maxAppend int) {
|
||||
if maxAppend <= 0 {
|
||||
return
|
||||
}
|
||||
return append([]goModuleVersionMeta(nil), metas...)
|
||||
appended := 0
|
||||
for _, meta := range values {
|
||||
version := normalizeVersion(strings.TrimSpace(meta.Version))
|
||||
if version == "" {
|
||||
continue
|
||||
}
|
||||
key := strings.ToLower(version)
|
||||
if _, ok := seen[key]; ok {
|
||||
continue
|
||||
}
|
||||
meta.Version = version
|
||||
result = append(result, meta)
|
||||
seen[key] = struct{}{}
|
||||
appended++
|
||||
if appended >= maxAppend {
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
appendUnique(fetchGoModuleVersionMetasCached(modulePaths[0]), limit)
|
||||
|
||||
extraLimit := resolveDriverExtraHistoryLimit(normalized)
|
||||
for _, modulePath := range modulePaths[1:] {
|
||||
if extraLimit <= 0 {
|
||||
break
|
||||
}
|
||||
before := len(result)
|
||||
appendUnique(fetchGoModuleVersionMetasCached(modulePath), extraLimit)
|
||||
extraLimit -= len(result) - before
|
||||
}
|
||||
if len(result) > 0 {
|
||||
return result
|
||||
}
|
||||
}
|
||||
|
||||
fallbackLimit := limit + resolveDriverExtraHistoryLimit(normalized)
|
||||
if fallbackLimit <= 0 {
|
||||
fallbackLimit = limit
|
||||
}
|
||||
if fallback := fallbackRecentDriverVersionsMap[normalized]; len(fallback) > 0 {
|
||||
if len(fallback) > limit {
|
||||
return append([]goModuleVersionMeta(nil), fallback[:limit]...)
|
||||
if len(fallback) > fallbackLimit {
|
||||
return append([]goModuleVersionMeta(nil), fallback[:fallbackLimit]...)
|
||||
}
|
||||
return append([]goModuleVersionMeta(nil), fallback...)
|
||||
}
|
||||
@@ -1635,15 +1695,13 @@ func triggerDriverVersionMetadataWarmup(definitions []driverDefinition) {
|
||||
if driverType == "" || !db.IsOptionalGoDriver(driverType) {
|
||||
continue
|
||||
}
|
||||
modulePath := strings.TrimSpace(driverGoModulePathMap[driverType])
|
||||
if modulePath == "" {
|
||||
continue
|
||||
for _, modulePath := range resolveDriverGoModulePaths(driverType) {
|
||||
if _, ok := seenModule[modulePath]; ok {
|
||||
continue
|
||||
}
|
||||
seenModule[modulePath] = struct{}{}
|
||||
modulePaths = append(modulePaths, modulePath)
|
||||
}
|
||||
if _, ok := seenModule[modulePath]; ok {
|
||||
continue
|
||||
}
|
||||
seenModule[modulePath] = struct{}{}
|
||||
modulePaths = append(modulePaths, modulePath)
|
||||
}
|
||||
|
||||
if len(modulePaths) == 0 {
|
||||
@@ -1663,6 +1721,40 @@ func triggerDriverVersionMetadataWarmup(definitions []driverDefinition) {
|
||||
}(append([]string(nil), modulePaths...))
|
||||
}
|
||||
|
||||
func resolveDriverGoModulePaths(driverType string) []string {
|
||||
normalized := normalizeDriverType(driverType)
|
||||
if normalized == "" {
|
||||
return nil
|
||||
}
|
||||
paths := make([]string, 0, 3)
|
||||
seen := make(map[string]struct{}, 3)
|
||||
appendPath := func(path string) {
|
||||
trimmed := strings.TrimSpace(path)
|
||||
if trimmed == "" {
|
||||
return
|
||||
}
|
||||
if _, ok := seen[trimmed]; ok {
|
||||
return
|
||||
}
|
||||
seen[trimmed] = struct{}{}
|
||||
paths = append(paths, trimmed)
|
||||
}
|
||||
|
||||
appendPath(driverGoModulePathMap[normalized])
|
||||
for _, alias := range driverGoModuleAliasPathMap[normalized] {
|
||||
appendPath(alias)
|
||||
}
|
||||
return paths
|
||||
}
|
||||
|
||||
func resolveDriverExtraHistoryLimit(driverType string) int {
|
||||
limit := driverExtraHistoryLimitMap[normalizeDriverType(driverType)]
|
||||
if limit < 0 {
|
||||
return 0
|
||||
}
|
||||
return limit
|
||||
}
|
||||
|
||||
func tryStartDriverVersionMetadataWarmup(now time.Time) bool {
|
||||
driverVersionWarmupMu.Lock()
|
||||
defer driverVersionWarmupMu.Unlock()
|
||||
@@ -2356,16 +2448,23 @@ func hashFileSHA256(filePath string) (string, error) {
|
||||
|
||||
func installOptionalDriverAgentPackage(a *App, definition driverDefinition, selectedVersion string, resolvedDir string, downloadURL string) (installedDriverPackage, error) {
|
||||
driverType := normalizeDriverType(definition.Type)
|
||||
executablePath, err := db.ResolveOptionalDriverAgentExecutablePath(resolvedDir, driverType)
|
||||
installPath, err := db.ResolveOptionalDriverAgentExecutablePathForVersion(resolvedDir, driverType, selectedVersion)
|
||||
if err != nil {
|
||||
return installedDriverPackage{}, err
|
||||
}
|
||||
downloadSource, hash, err := ensureOptionalDriverAgentBinary(a, definition, executablePath, downloadURL)
|
||||
runtimePath, err := db.ResolveOptionalDriverAgentExecutablePath(resolvedDir, driverType)
|
||||
if err != nil {
|
||||
return installedDriverPackage{}, err
|
||||
}
|
||||
downloadSource, hash, err := ensureOptionalDriverAgentBinary(a, definition, installPath, downloadURL, selectedVersion)
|
||||
if err != nil {
|
||||
return installedDriverPackage{}, err
|
||||
}
|
||||
if activateErr := activateOptionalDriverAgentBinary(installPath, runtimePath); activateErr != nil {
|
||||
return installedDriverPackage{}, fmt.Errorf("activate %s driver agent failed: %w", resolveDriverDisplayName(definition), activateErr)
|
||||
}
|
||||
if strings.TrimSpace(hash) == "" {
|
||||
hash, err = hashFileSHA256(executablePath)
|
||||
hash, err = hashFileSHA256(installPath)
|
||||
if err != nil {
|
||||
return installedDriverPackage{}, fmt.Errorf("计算 %s 驱动代理摘要失败:%w", resolveDriverDisplayName(definition), err)
|
||||
}
|
||||
@@ -2376,9 +2475,9 @@ func installOptionalDriverAgentPackage(a *App, definition driverDefinition, sele
|
||||
return installedDriverPackage{
|
||||
DriverType: driverType,
|
||||
Version: strings.TrimSpace(selectedVersion),
|
||||
FilePath: executablePath,
|
||||
FileName: filepath.Base(executablePath),
|
||||
ExecutablePath: executablePath,
|
||||
FilePath: installPath,
|
||||
FileName: filepath.Base(installPath),
|
||||
ExecutablePath: runtimePath,
|
||||
DownloadURL: strings.TrimSpace(downloadSource),
|
||||
SHA256: hash,
|
||||
DownloadedAt: time.Now().Format(time.RFC3339),
|
||||
@@ -2686,9 +2785,11 @@ func installOptionalDriverAgentFromLocalZip(zipPath string, definition driverDef
|
||||
return filepath.ToSlash(strings.TrimPrefix(strings.TrimSpace(entry.Name), "./")), nil
|
||||
}
|
||||
|
||||
func ensureOptionalDriverAgentBinary(a *App, definition driverDefinition, executablePath string, downloadURL string) (string, string, error) {
|
||||
func ensureOptionalDriverAgentBinary(a *App, definition driverDefinition, executablePath string, downloadURL string, selectedVersion string) (string, string, error) {
|
||||
driverType := normalizeDriverType(definition.Type)
|
||||
displayName := resolveDriverDisplayName(definition)
|
||||
forceSourceBuild := shouldForceSourceBuildForVersion(driverType, selectedVersion)
|
||||
skipReuseCandidate := shouldSkipReusableAgentCandidate(driverType, selectedVersion)
|
||||
|
||||
info, err := os.Stat(executablePath)
|
||||
if err == nil && !info.IsDir() {
|
||||
@@ -2708,49 +2809,53 @@ func ensureOptionalDriverAgentBinary(a *App, definition driverDefinition, execut
|
||||
if a != nil {
|
||||
a.emitDriverDownloadProgress(driverType, "downloading", 10, 100, "检查本地驱动代理缓存")
|
||||
}
|
||||
if sourcePath, ok := findExistingOptionalDriverAgentCandidate(definition, executablePath); ok {
|
||||
if copyErr := copyAgentBinary(sourcePath, executablePath); copyErr != nil {
|
||||
return "", "", fmt.Errorf("复制预置 %s 驱动代理失败:%w", displayName, copyErr)
|
||||
if !skipReuseCandidate {
|
||||
if sourcePath, ok := findExistingOptionalDriverAgentCandidate(definition, executablePath); ok {
|
||||
if copyErr := copyAgentBinary(sourcePath, executablePath); copyErr != nil {
|
||||
return "", "", fmt.Errorf("复制预置 %s 驱动代理失败:%w", displayName, copyErr)
|
||||
}
|
||||
hash, hashErr := hashFileSHA256(executablePath)
|
||||
if hashErr != nil {
|
||||
return "", "", fmt.Errorf("计算预置 %s 驱动代理摘要失败:%w", displayName, hashErr)
|
||||
}
|
||||
return "file://" + sourcePath, hash, nil
|
||||
}
|
||||
hash, hashErr := hashFileSHA256(executablePath)
|
||||
if hashErr != nil {
|
||||
return "", "", fmt.Errorf("计算预置 %s 驱动代理摘要失败:%w", displayName, hashErr)
|
||||
}
|
||||
return "file://" + sourcePath, hash, nil
|
||||
}
|
||||
|
||||
downloadURLs := resolveOptionalDriverAgentDownloadURLs(definition, downloadURL)
|
||||
var downloadErrs []string
|
||||
if len(downloadURLs) > 0 {
|
||||
for _, candidateURL := range downloadURLs {
|
||||
if a != nil {
|
||||
a.emitDriverDownloadProgress(driverType, "downloading", 20, 100, fmt.Sprintf("下载预编译 %s 驱动代理", displayName))
|
||||
if !forceSourceBuild {
|
||||
downloadURLs := resolveOptionalDriverAgentDownloadURLs(definition, downloadURL)
|
||||
if len(downloadURLs) > 0 {
|
||||
for _, candidateURL := range downloadURLs {
|
||||
if a != nil {
|
||||
a.emitDriverDownloadProgress(driverType, "downloading", 20, 100, fmt.Sprintf("下载预编译 %s 驱动代理", displayName))
|
||||
}
|
||||
hash, dlErr := downloadOptionalDriverAgentBinary(a, definition, candidateURL, executablePath)
|
||||
if dlErr == nil {
|
||||
return candidateURL, hash, nil
|
||||
}
|
||||
downloadErrs = append(downloadErrs, fmt.Sprintf("%s: %s", candidateURL, strings.TrimSpace(dlErr.Error())))
|
||||
}
|
||||
hash, dlErr := downloadOptionalDriverAgentBinary(a, definition, candidateURL, executablePath)
|
||||
if dlErr == nil {
|
||||
return candidateURL, hash, nil
|
||||
}
|
||||
downloadErrs = append(downloadErrs, fmt.Sprintf("%s: %s", candidateURL, strings.TrimSpace(dlErr.Error())))
|
||||
}
|
||||
}
|
||||
bundleURLs := resolveOptionalDriverBundleDownloadURLs()
|
||||
if len(bundleURLs) > 0 {
|
||||
for _, bundleURL := range bundleURLs {
|
||||
if a != nil {
|
||||
a.emitDriverDownloadProgress(driverType, "downloading", 20, 100, fmt.Sprintf("从驱动总包提取 %s 代理", displayName))
|
||||
bundleURLs := resolveOptionalDriverBundleDownloadURLs()
|
||||
if len(bundleURLs) > 0 {
|
||||
for _, bundleURL := range bundleURLs {
|
||||
if a != nil {
|
||||
a.emitDriverDownloadProgress(driverType, "downloading", 20, 100, fmt.Sprintf("从驱动总包提取 %s 代理", displayName))
|
||||
}
|
||||
source, hash, bundleErr := downloadOptionalDriverAgentFromBundle(a, definition, bundleURL, executablePath)
|
||||
if bundleErr == nil {
|
||||
return source, hash, nil
|
||||
}
|
||||
downloadErrs = append(downloadErrs, fmt.Sprintf("%s: %s", bundleURL, strings.TrimSpace(bundleErr.Error())))
|
||||
}
|
||||
source, hash, bundleErr := downloadOptionalDriverAgentFromBundle(a, definition, bundleURL, executablePath)
|
||||
if bundleErr == nil {
|
||||
return source, hash, nil
|
||||
}
|
||||
downloadErrs = append(downloadErrs, fmt.Sprintf("%s: %s", bundleURL, strings.TrimSpace(bundleErr.Error())))
|
||||
}
|
||||
}
|
||||
if a != nil {
|
||||
a.emitDriverDownloadProgress(driverType, "downloading", 92, 100, "未命中预编译包,尝试开发态本地构建")
|
||||
}
|
||||
|
||||
hash, buildErr := buildOptionalDriverAgentFromSource(definition, executablePath)
|
||||
hash, buildErr := buildOptionalDriverAgentFromSource(definition, executablePath, selectedVersion)
|
||||
if buildErr == nil {
|
||||
return fmt.Sprintf("local://go-build/%s-driver-agent", driverType), hash, nil
|
||||
}
|
||||
@@ -2912,7 +3017,7 @@ func downloadOptionalDriverAgentFromBundle(a *App, definition driverDefinition,
|
||||
return source, hash, nil
|
||||
}
|
||||
|
||||
func buildOptionalDriverAgentFromSource(definition driverDefinition, executablePath string) (string, error) {
|
||||
func buildOptionalDriverAgentFromSource(definition driverDefinition, executablePath string, selectedVersion string) (string, error) {
|
||||
driverType := normalizeDriverType(definition.Type)
|
||||
displayName := resolveDriverDisplayName(definition)
|
||||
goPath, lookErr := exec.LookPath("go")
|
||||
@@ -2920,7 +3025,7 @@ func buildOptionalDriverAgentFromSource(definition driverDefinition, executableP
|
||||
return "", fmt.Errorf("当前环境未安装 Go,且未找到可用的 %s 预编译代理包", displayName)
|
||||
}
|
||||
|
||||
tagName, tagErr := optionalDriverBuildTag(driverType)
|
||||
tagName, tagErr := optionalDriverBuildTag(driverType, selectedVersion)
|
||||
if tagErr != nil {
|
||||
return "", tagErr
|
||||
}
|
||||
@@ -2931,6 +3036,7 @@ func buildOptionalDriverAgentFromSource(definition driverDefinition, executableP
|
||||
}
|
||||
cmd := exec.Command(goPath, "build", "-tags", tagName, "-trimpath", "-ldflags", "-s -w", "-o", executablePath, "./cmd/optional-driver-agent")
|
||||
cmd.Dir = projectRoot
|
||||
cmd.Env = append(os.Environ(), "GOTOOLCHAIN=auto")
|
||||
output, buildErr := cmd.CombinedOutput()
|
||||
if buildErr != nil {
|
||||
return "", fmt.Errorf("构建 %s 驱动代理失败:%v,输出:%s", displayName, buildErr, strings.TrimSpace(string(output)))
|
||||
@@ -2945,7 +3051,31 @@ func buildOptionalDriverAgentFromSource(definition driverDefinition, executableP
|
||||
return hash, nil
|
||||
}
|
||||
|
||||
func optionalDriverBuildTag(driverType string) (string, error) {
|
||||
func resolveMongoDriverMajorFromVersion(version string) int {
|
||||
trimmed := strings.TrimSpace(version)
|
||||
trimmed = strings.TrimPrefix(trimmed, "v")
|
||||
if strings.HasPrefix(trimmed, "1.") || trimmed == "1" {
|
||||
return 1
|
||||
}
|
||||
return 2
|
||||
}
|
||||
|
||||
func shouldForceSourceBuildForVersion(driverType string, selectedVersion string) bool {
|
||||
if normalizeDriverType(driverType) != "mongodb" {
|
||||
return false
|
||||
}
|
||||
return resolveMongoDriverMajorFromVersion(selectedVersion) == 1
|
||||
}
|
||||
|
||||
func shouldSkipReusableAgentCandidate(driverType string, selectedVersion string) bool {
|
||||
if normalizeDriverType(driverType) != "mongodb" {
|
||||
return false
|
||||
}
|
||||
_ = selectedVersion
|
||||
return true
|
||||
}
|
||||
|
||||
func optionalDriverBuildTag(driverType string, selectedVersion string) (string, error) {
|
||||
switch normalizeDriverType(driverType) {
|
||||
case "mysql":
|
||||
return "gonavi_mysql_driver", nil
|
||||
@@ -2970,6 +3100,9 @@ func optionalDriverBuildTag(driverType string) (string, error) {
|
||||
case "vastbase":
|
||||
return "gonavi_vastbase_driver", nil
|
||||
case "mongodb":
|
||||
if resolveMongoDriverMajorFromVersion(selectedVersion) == 1 {
|
||||
return "gonavi_mongodb_driver_v1", nil
|
||||
}
|
||||
return "gonavi_mongodb_driver", nil
|
||||
case "tdengine":
|
||||
return "gonavi_tdengine_driver", nil
|
||||
@@ -3310,6 +3443,30 @@ func resolveDriverDisplayName(definition driverDefinition) string {
|
||||
return "未知"
|
||||
}
|
||||
|
||||
func activateOptionalDriverAgentBinary(installPath string, runtimePath string) error {
|
||||
source := strings.TrimSpace(installPath)
|
||||
target := strings.TrimSpace(runtimePath)
|
||||
if source == "" || target == "" {
|
||||
return fmt.Errorf("agent path is empty")
|
||||
}
|
||||
if source == target {
|
||||
return nil
|
||||
}
|
||||
|
||||
absSource := source
|
||||
absTarget := target
|
||||
if value, err := filepath.Abs(source); err == nil && strings.TrimSpace(value) != "" {
|
||||
absSource = value
|
||||
}
|
||||
if value, err := filepath.Abs(target); err == nil && strings.TrimSpace(value) != "" {
|
||||
absTarget = value
|
||||
}
|
||||
if strings.EqualFold(absSource, absTarget) {
|
||||
return nil
|
||||
}
|
||||
return copyAgentBinary(source, target)
|
||||
}
|
||||
|
||||
func copyAgentBinary(sourcePath, targetPath string) error {
|
||||
src, err := os.Open(sourcePath)
|
||||
if err != nil {
|
||||
|
||||
@@ -27,6 +27,10 @@ type ConnectionConfig struct {
|
||||
Password string `json:"password"`
|
||||
SavePassword bool `json:"savePassword,omitempty"` // Persist password in saved connection
|
||||
Database string `json:"database"`
|
||||
UseSSL bool `json:"useSSL,omitempty"` // MySQL-like SSL/TLS switch
|
||||
SSLMode string `json:"sslMode,omitempty"` // preferred | required | skip-verify | disable
|
||||
SSLCertPath string `json:"sslCertPath,omitempty"` // TLS client certificate path (e.g., Dameng)
|
||||
SSLKeyPath string `json:"sslKeyPath,omitempty"` // TLS client private key path (e.g., Dameng)
|
||||
UseSSH bool `json:"useSSH"`
|
||||
SSH SSHConfig `json:"ssh"`
|
||||
UseProxy bool `json:"useProxy,omitempty"`
|
||||
|
||||
@@ -107,7 +107,7 @@ func (c *ClickHouseDB) buildClickHouseOptions(config connection.ConnectionConfig
|
||||
if readTimeout < minClickHouseReadTimeout {
|
||||
readTimeout = minClickHouseReadTimeout
|
||||
}
|
||||
return &clickhouse.Options{
|
||||
opts := &clickhouse.Options{
|
||||
Addr: []string{
|
||||
net.JoinHostPort(config.Host, strconv.Itoa(config.Port)),
|
||||
},
|
||||
@@ -119,6 +119,10 @@ func (c *ClickHouseDB) buildClickHouseOptions(config connection.ConnectionConfig
|
||||
DialTimeout: connectTimeout,
|
||||
ReadTimeout: readTimeout,
|
||||
}
|
||||
if tlsConfig := resolveGenericTLSConfig(config); tlsConfig != nil {
|
||||
opts.TLS = tlsConfig
|
||||
}
|
||||
return opts
|
||||
}
|
||||
|
||||
func (c *ClickHouseDB) Connect(config connection.ConnectionConfig) error {
|
||||
@@ -165,13 +169,30 @@ func (c *ClickHouseDB) Connect(config connection.ConnectionConfig) error {
|
||||
logger.Infof("ClickHouse 通过本地端口转发连接:%s -> %s:%d", forwarder.LocalAddr, config.Host, config.Port)
|
||||
}
|
||||
|
||||
c.conn = clickhouse.OpenDB(c.buildClickHouseOptions(runConfig))
|
||||
|
||||
if err := c.Ping(); err != nil {
|
||||
_ = c.Close()
|
||||
return fmt.Errorf("连接建立后验证失败:%w", err)
|
||||
attempts := []connection.ConnectionConfig{runConfig}
|
||||
if shouldTrySSLPreferredFallback(runConfig) {
|
||||
attempts = append(attempts, withSSLDisabled(runConfig))
|
||||
}
|
||||
return nil
|
||||
|
||||
var failures []string
|
||||
for idx, attempt := range attempts {
|
||||
c.conn = clickhouse.OpenDB(c.buildClickHouseOptions(attempt))
|
||||
if err := c.Ping(); err != nil {
|
||||
failures = append(failures, fmt.Sprintf("第%d次连接验证失败: %v", idx+1, err))
|
||||
if c.conn != nil {
|
||||
_ = c.conn.Close()
|
||||
c.conn = nil
|
||||
}
|
||||
continue
|
||||
}
|
||||
if idx > 0 {
|
||||
logger.Warnf("ClickHouse SSL 优先连接失败,已回退至明文连接")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
_ = c.Close()
|
||||
return fmt.Errorf("连接建立后验证失败:%s", strings.Join(failures, ";"))
|
||||
}
|
||||
|
||||
func (c *ClickHouseDB) Close() error {
|
||||
|
||||
@@ -36,6 +36,14 @@ func (d *DamengDB) getDSN(config connection.ConnectionConfig) string {
|
||||
if config.Database != "" {
|
||||
q.Set("schema", config.Database)
|
||||
}
|
||||
if config.UseSSL {
|
||||
if certPath := strings.TrimSpace(config.SSLCertPath); certPath != "" {
|
||||
q.Set("SSL_CERT_PATH", certPath)
|
||||
}
|
||||
if keyPath := strings.TrimSpace(config.SSLKeyPath); keyPath != "" {
|
||||
q.Set("SSL_KEY_PATH", keyPath)
|
||||
}
|
||||
}
|
||||
if escapedPassword != config.Password {
|
||||
// 达梦驱动要求:密码包含特殊字符时,password 需 PathEscape,并添加 escapeProcess=true 让驱动解码。
|
||||
q.Set("escapeProcess", "true")
|
||||
@@ -50,8 +58,12 @@ func (d *DamengDB) getDSN(config connection.ConnectionConfig) string {
|
||||
}
|
||||
|
||||
func (d *DamengDB) Connect(config connection.ConnectionConfig) error {
|
||||
var dsn string
|
||||
var err error
|
||||
runConfig := config
|
||||
if runConfig.UseSSL {
|
||||
if strings.TrimSpace(runConfig.SSLCertPath) == "" || strings.TrimSpace(runConfig.SSLKeyPath) == "" {
|
||||
return fmt.Errorf("达梦启用 SSL 需要同时配置证书路径(sslCertPath)与私钥路径(sslKeyPath)")
|
||||
}
|
||||
}
|
||||
|
||||
if config.UseSSH {
|
||||
// Create SSH tunnel with local port forwarding
|
||||
@@ -80,22 +92,37 @@ func (d *DamengDB) Connect(config connection.ConnectionConfig) error {
|
||||
localConfig.Port = port
|
||||
localConfig.UseSSH = false
|
||||
|
||||
dsn = d.getDSN(localConfig)
|
||||
runConfig = localConfig
|
||||
logger.Infof("达梦数据库通过本地端口转发连接:%s -> %s:%d", forwarder.LocalAddr, config.Host, config.Port)
|
||||
} else {
|
||||
dsn = d.getDSN(config)
|
||||
}
|
||||
|
||||
db, err := sql.Open("dm", dsn)
|
||||
if err != nil {
|
||||
return fmt.Errorf("打开数据库连接失败:%w", err)
|
||||
attempts := []connection.ConnectionConfig{runConfig}
|
||||
if shouldTrySSLPreferredFallback(runConfig) {
|
||||
attempts = append(attempts, withSSLDisabled(runConfig))
|
||||
}
|
||||
d.conn = db
|
||||
d.pingTimeout = getConnectTimeout(config)
|
||||
if err := d.Ping(); err != nil {
|
||||
return fmt.Errorf("连接建立后验证失败:%w", err)
|
||||
|
||||
var failures []string
|
||||
for idx, attempt := range attempts {
|
||||
dsn := d.getDSN(attempt)
|
||||
db, err := sql.Open("dm", dsn)
|
||||
if err != nil {
|
||||
failures = append(failures, fmt.Sprintf("第%d次连接打开失败: %v", idx+1, err))
|
||||
continue
|
||||
}
|
||||
d.conn = db
|
||||
d.pingTimeout = getConnectTimeout(attempt)
|
||||
if err := d.Ping(); err != nil {
|
||||
_ = db.Close()
|
||||
d.conn = nil
|
||||
failures = append(failures, fmt.Sprintf("第%d次连接验证失败: %v", idx+1, err))
|
||||
continue
|
||||
}
|
||||
if idx > 0 {
|
||||
logger.Warnf("达梦 SSL 优先连接失败,已回退至明文连接")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
return nil
|
||||
return fmt.Errorf("连接建立后验证失败:%s", strings.Join(failures, ";"))
|
||||
}
|
||||
|
||||
func (d *DamengDB) Close() error {
|
||||
|
||||
@@ -151,9 +151,10 @@ func (d *DirosDB) getDSN(config connection.ConnectionConfig) string {
|
||||
}
|
||||
|
||||
timeout := getConnectTimeoutSeconds(config)
|
||||
tlsMode := resolveMySQLTLSMode(config)
|
||||
|
||||
return fmt.Sprintf("%s:%s@%s(%s)/%s?charset=utf8mb4&parseTime=True&loc=Local&timeout=%ds",
|
||||
config.User, config.Password, protocol, address, database, timeout)
|
||||
return fmt.Sprintf("%s:%s@%s(%s)/%s?charset=utf8mb4&parseTime=True&loc=Local&timeout=%ds&tls=%s",
|
||||
config.User, config.Password, protocol, address, database, timeout, url.QueryEscape(tlsMode))
|
||||
}
|
||||
|
||||
func resolveDirosCredential(config connection.ConnectionConfig, addressIndex int) (string, string) {
|
||||
|
||||
@@ -33,6 +33,44 @@ func TestPostgresDSN_EscapesPassword(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestPostgresDSN_SSLModeRequireWhenEnabled(t *testing.T) {
|
||||
p := &PostgresDB{}
|
||||
cfg := connection.ConnectionConfig{
|
||||
Type: "postgres",
|
||||
Host: "127.0.0.1",
|
||||
Port: 5432,
|
||||
User: "user",
|
||||
Password: "pass",
|
||||
Database: "db",
|
||||
UseSSL: true,
|
||||
SSLMode: "required",
|
||||
}
|
||||
|
||||
dsn := p.getDSN(cfg)
|
||||
if !strings.Contains(dsn, "sslmode=require") {
|
||||
t.Fatalf("dsn 缺少 sslmode=require 参数:%s", dsn)
|
||||
}
|
||||
}
|
||||
|
||||
func TestMySQLDSN_UsesTLSParamWhenSSLEnabled(t *testing.T) {
|
||||
m := &MySQLDB{}
|
||||
cfg := connection.ConnectionConfig{
|
||||
Type: "mysql",
|
||||
Host: "127.0.0.1",
|
||||
Port: 3306,
|
||||
User: "root",
|
||||
Password: "pass",
|
||||
Database: "db",
|
||||
UseSSL: true,
|
||||
SSLMode: "required",
|
||||
}
|
||||
|
||||
dsn := m.getDSN(cfg)
|
||||
if !strings.Contains(dsn, "tls=true") {
|
||||
t.Fatalf("dsn 缺少 tls=true 参数:%s", dsn)
|
||||
}
|
||||
}
|
||||
|
||||
func TestOracleDSN_EscapesUserAndPassword(t *testing.T) {
|
||||
o := &OracleDB{}
|
||||
cfg := connection.ConnectionConfig{
|
||||
@@ -82,6 +120,30 @@ func TestDamengDSN_EscapesPasswordAndEnablesEscapeProcess(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestDamengDSN_AppendsSSLCertAndKeyParams(t *testing.T) {
|
||||
d := &DamengDB{}
|
||||
cfg := connection.ConnectionConfig{
|
||||
Type: "dameng",
|
||||
Host: "127.0.0.1",
|
||||
Port: 5236,
|
||||
User: "SYSDBA",
|
||||
Password: "pass",
|
||||
Database: "DBName",
|
||||
UseSSL: true,
|
||||
SSLMode: "required",
|
||||
SSLCertPath: "C:\\certs\\client-cert.pem",
|
||||
SSLKeyPath: "C:\\certs\\client-key.pem",
|
||||
}
|
||||
|
||||
dsn := d.getDSN(cfg)
|
||||
if !strings.Contains(dsn, "SSL_CERT_PATH=") {
|
||||
t.Fatalf("dsn 缺少 SSL_CERT_PATH 参数:%s", dsn)
|
||||
}
|
||||
if !strings.Contains(dsn, "SSL_KEY_PATH=") {
|
||||
t.Fatalf("dsn 缺少 SSL_KEY_PATH 参数:%s", dsn)
|
||||
}
|
||||
}
|
||||
|
||||
func TestKingbaseDSN_QuotesPasswordWithSpaces(t *testing.T) {
|
||||
k := &KingbaseDB{}
|
||||
cfg := connection.ConnectionConfig{
|
||||
@@ -116,6 +178,47 @@ func TestTDengineDSN_UsesWebSocketFormat(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestTDengineDSN_UsesSecureWebSocketWhenSSLEnabled(t *testing.T) {
|
||||
td := &TDengineDB{}
|
||||
cfg := connection.ConnectionConfig{
|
||||
Type: "tdengine",
|
||||
Host: "127.0.0.1",
|
||||
Port: 6041,
|
||||
User: "root",
|
||||
Password: "taosdata",
|
||||
Database: "power",
|
||||
UseSSL: true,
|
||||
SSLMode: "required",
|
||||
}
|
||||
|
||||
dsn := td.getDSN(cfg)
|
||||
if !strings.HasPrefix(dsn, "root:taosdata@wss(127.0.0.1:6041)/power") {
|
||||
t.Fatalf("tdengine ssl dsn 格式不正确:%s", dsn)
|
||||
}
|
||||
}
|
||||
|
||||
func TestSQLServerDSN_EncryptMapping(t *testing.T) {
|
||||
s := &SqlServerDB{}
|
||||
cfg := connection.ConnectionConfig{
|
||||
Type: "sqlserver",
|
||||
Host: "127.0.0.1",
|
||||
Port: 1433,
|
||||
User: "sa",
|
||||
Password: "pass",
|
||||
Database: "master",
|
||||
UseSSL: true,
|
||||
SSLMode: "required",
|
||||
}
|
||||
|
||||
dsn := s.getDSN(cfg)
|
||||
if !strings.Contains(strings.ToLower(dsn), "encrypt=true") {
|
||||
t.Fatalf("sqlserver dsn 缺少 encrypt=true:%s", dsn)
|
||||
}
|
||||
if !strings.Contains(strings.ToLower(dsn), "trustservercertificate=false") {
|
||||
t.Fatalf("sqlserver dsn 缺少 TrustServerCertificate=false:%s", dsn)
|
||||
}
|
||||
}
|
||||
|
||||
func TestClickHouseOptions_UsesStructuredTimeoutAndAuth(t *testing.T) {
|
||||
c := &ClickHouseDB{}
|
||||
cfg := normalizeClickHouseConfig(connection.ConnectionConfig{
|
||||
|
||||
@@ -42,7 +42,7 @@ func (h *HighGoDB) getDSN(config connection.ConnectionConfig) string {
|
||||
}
|
||||
u.User = url.UserPassword(config.User, config.Password)
|
||||
q := url.Values{}
|
||||
q.Set("sslmode", "disable")
|
||||
q.Set("sslmode", resolvePostgresSSLMode(config))
|
||||
q.Set("connect_timeout", strconv.Itoa(getConnectTimeoutSeconds(config)))
|
||||
u.RawQuery = q.Encode()
|
||||
|
||||
@@ -50,7 +50,7 @@ func (h *HighGoDB) getDSN(config connection.ConnectionConfig) string {
|
||||
}
|
||||
|
||||
func (h *HighGoDB) Connect(config connection.ConnectionConfig) error {
|
||||
var dsn string
|
||||
runConfig := config
|
||||
|
||||
if config.UseSSH {
|
||||
logger.Infof("HighGo 使用 SSH 连接:地址=%s:%d 用户=%s", config.Host, config.Port, config.User)
|
||||
@@ -76,23 +76,37 @@ func (h *HighGoDB) Connect(config connection.ConnectionConfig) error {
|
||||
localConfig.Port = port
|
||||
localConfig.UseSSH = false
|
||||
|
||||
dsn = h.getDSN(localConfig)
|
||||
runConfig = localConfig
|
||||
logger.Infof("HighGo 通过本地端口转发连接:%s -> %s:%d", forwarder.LocalAddr, config.Host, config.Port)
|
||||
} else {
|
||||
dsn = h.getDSN(config)
|
||||
}
|
||||
|
||||
db, err := sql.Open("highgo", dsn)
|
||||
if err != nil {
|
||||
return fmt.Errorf("打开数据库连接失败:%w", err)
|
||||
attempts := []connection.ConnectionConfig{runConfig}
|
||||
if shouldTrySSLPreferredFallback(runConfig) {
|
||||
attempts = append(attempts, withSSLDisabled(runConfig))
|
||||
}
|
||||
h.conn = db
|
||||
h.pingTimeout = getConnectTimeout(config)
|
||||
|
||||
if err := h.Ping(); err != nil {
|
||||
return fmt.Errorf("连接建立后验证失败:%w", err)
|
||||
var failures []string
|
||||
for idx, attempt := range attempts {
|
||||
dsn := h.getDSN(attempt)
|
||||
db, err := sql.Open("highgo", dsn)
|
||||
if err != nil {
|
||||
failures = append(failures, fmt.Sprintf("第%d次连接打开失败: %v", idx+1, err))
|
||||
continue
|
||||
}
|
||||
h.conn = db
|
||||
h.pingTimeout = getConnectTimeout(attempt)
|
||||
if err := h.Ping(); err != nil {
|
||||
_ = db.Close()
|
||||
h.conn = nil
|
||||
failures = append(failures, fmt.Sprintf("第%d次连接验证失败: %v", idx+1, err))
|
||||
continue
|
||||
}
|
||||
if idx > 0 {
|
||||
logger.Warnf("HighGo SSL 优先连接失败,已回退至明文连接")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
return nil
|
||||
return fmt.Errorf("连接建立后验证失败:%s", strings.Join(failures, ";"))
|
||||
}
|
||||
|
||||
func (h *HighGoDB) Close() error {
|
||||
|
||||
@@ -65,12 +65,13 @@ func (k *KingbaseDB) getDSN(config connection.ConnectionConfig) string {
|
||||
port := config.Port
|
||||
|
||||
// Construct DSN
|
||||
dsn := fmt.Sprintf("host=%s port=%d user=%s password=%s dbname=%s sslmode=disable connect_timeout=%d",
|
||||
dsn := fmt.Sprintf("host=%s port=%d user=%s password=%s dbname=%s sslmode=%s connect_timeout=%d",
|
||||
quoteConnValue(address),
|
||||
port,
|
||||
quoteConnValue(config.User),
|
||||
quoteConnValue(config.Password),
|
||||
quoteConnValue(config.Database),
|
||||
quoteConnValue(resolvePostgresSSLMode(config)),
|
||||
getConnectTimeoutSeconds(config),
|
||||
)
|
||||
|
||||
@@ -78,8 +79,7 @@ func (k *KingbaseDB) getDSN(config connection.ConnectionConfig) string {
|
||||
}
|
||||
|
||||
func (k *KingbaseDB) Connect(config connection.ConnectionConfig) error {
|
||||
var dsn string
|
||||
var err error
|
||||
runConfig := config
|
||||
|
||||
if config.UseSSH {
|
||||
// Create SSH tunnel with local port forwarding
|
||||
@@ -108,23 +108,37 @@ func (k *KingbaseDB) Connect(config connection.ConnectionConfig) error {
|
||||
localConfig.Port = port
|
||||
localConfig.UseSSH = false
|
||||
|
||||
dsn = k.getDSN(localConfig)
|
||||
runConfig = localConfig
|
||||
logger.Infof("人大金仓通过本地端口转发连接:%s -> %s:%d", forwarder.LocalAddr, config.Host, config.Port)
|
||||
} else {
|
||||
dsn = k.getDSN(config)
|
||||
}
|
||||
|
||||
// Open using "kingbase" driver
|
||||
db, err := sql.Open("kingbase", dsn)
|
||||
if err != nil {
|
||||
return fmt.Errorf("打开数据库连接失败:%w", err)
|
||||
attempts := []connection.ConnectionConfig{runConfig}
|
||||
if shouldTrySSLPreferredFallback(runConfig) {
|
||||
attempts = append(attempts, withSSLDisabled(runConfig))
|
||||
}
|
||||
k.conn = db
|
||||
k.pingTimeout = getConnectTimeout(config)
|
||||
if err := k.Ping(); err != nil {
|
||||
return fmt.Errorf("连接建立后验证失败:%w", err)
|
||||
|
||||
var failures []string
|
||||
for idx, attempt := range attempts {
|
||||
dsn := k.getDSN(attempt)
|
||||
db, err := sql.Open("kingbase", dsn)
|
||||
if err != nil {
|
||||
failures = append(failures, fmt.Sprintf("第%d次连接打开失败: %v", idx+1, err))
|
||||
continue
|
||||
}
|
||||
k.conn = db
|
||||
k.pingTimeout = getConnectTimeout(attempt)
|
||||
if err := k.Ping(); err != nil {
|
||||
_ = db.Close()
|
||||
k.conn = nil
|
||||
failures = append(failures, fmt.Sprintf("第%d次连接验证失败: %v", idx+1, err))
|
||||
continue
|
||||
}
|
||||
if idx > 0 {
|
||||
logger.Warnf("人大金仓 SSL 优先连接失败,已回退至明文连接")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
return nil
|
||||
return fmt.Errorf("连接建立后验证失败:%s", strings.Join(failures, ";"))
|
||||
}
|
||||
|
||||
func (k *KingbaseDB) Close() error {
|
||||
|
||||
@@ -6,6 +6,7 @@ import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"fmt"
|
||||
"net/url"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
@@ -40,9 +41,10 @@ func (m *MariaDB) getDSN(config connection.ConnectionConfig) string {
|
||||
}
|
||||
|
||||
timeout := getConnectTimeoutSeconds(config)
|
||||
tlsMode := resolveMySQLTLSMode(config)
|
||||
|
||||
return fmt.Sprintf("%s:%s@%s(%s)/%s?charset=utf8mb4&parseTime=True&loc=Local&timeout=%ds",
|
||||
config.User, config.Password, protocol, address, database, timeout)
|
||||
return fmt.Sprintf("%s:%s@%s(%s)/%s?charset=utf8mb4&parseTime=True&loc=Local&timeout=%ds&tls=%s",
|
||||
config.User, config.Password, protocol, address, database, timeout, url.QueryEscape(tlsMode))
|
||||
}
|
||||
|
||||
func (m *MariaDB) Connect(config connection.ConnectionConfig) error {
|
||||
|
||||
@@ -4,6 +4,7 @@ package db
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/tls"
|
||||
"fmt"
|
||||
"net"
|
||||
"net/url"
|
||||
@@ -327,35 +328,60 @@ func (m *MongoDB) Connect(config connection.ConnectionConfig) error {
|
||||
m.database = "admin"
|
||||
}
|
||||
|
||||
attemptConfigs := buildMongoAuthAttempts(connectConfig)
|
||||
sslAttempts := []connection.ConnectionConfig{connectConfig}
|
||||
if shouldTrySSLPreferredFallback(connectConfig) {
|
||||
sslAttempts = append(sslAttempts, withSSLDisabled(connectConfig))
|
||||
}
|
||||
|
||||
var errorDetails []string
|
||||
for index, attemptConfig := range attemptConfigs {
|
||||
authLabel := "主库凭据"
|
||||
if index > 0 {
|
||||
authLabel = "从库凭据"
|
||||
for sslIndex, sslConfig := range sslAttempts {
|
||||
sslLabel := "SSL"
|
||||
if sslIndex > 0 {
|
||||
sslLabel = "明文回退"
|
||||
}
|
||||
|
||||
uri := m.getURI(attemptConfig)
|
||||
clientOpts := options.Client().ApplyURI(uri)
|
||||
if attemptConfig.UseProxy {
|
||||
clientOpts.SetDialer(&mongoProxyDialer{proxyConfig: attemptConfig.Proxy})
|
||||
}
|
||||
client, err := mongo.Connect(clientOpts)
|
||||
if err != nil {
|
||||
errorDetails = append(errorDetails, fmt.Sprintf("%s连接失败: %v", authLabel, err))
|
||||
continue
|
||||
}
|
||||
attemptConfigs := buildMongoAuthAttempts(sslConfig)
|
||||
for index, attemptConfig := range attemptConfigs {
|
||||
authLabel := "主库凭据"
|
||||
if index > 0 {
|
||||
authLabel = "从库凭据"
|
||||
}
|
||||
|
||||
m.client = client
|
||||
if err := m.Ping(); err != nil {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second)
|
||||
_ = client.Disconnect(ctx)
|
||||
cancel()
|
||||
m.client = nil
|
||||
errorDetails = append(errorDetails, fmt.Sprintf("%s验证失败: %v", authLabel, err))
|
||||
continue
|
||||
if sslIndex > 0 {
|
||||
attemptConfig.URI = ""
|
||||
}
|
||||
uri := m.getURI(attemptConfig)
|
||||
clientOpts := options.Client().ApplyURI(uri)
|
||||
tlsEnabled, tlsInsecure := resolveMongoTLSSettings(attemptConfig)
|
||||
if tlsEnabled {
|
||||
clientOpts.SetTLSConfig(&tls.Config{
|
||||
MinVersion: tls.VersionTLS12,
|
||||
InsecureSkipVerify: tlsInsecure,
|
||||
})
|
||||
}
|
||||
if attemptConfig.UseProxy {
|
||||
clientOpts.SetDialer(&mongoProxyDialer{proxyConfig: attemptConfig.Proxy})
|
||||
}
|
||||
client, err := mongo.Connect(clientOpts)
|
||||
if err != nil {
|
||||
errorDetails = append(errorDetails, fmt.Sprintf("%s %s连接失败: %v", sslLabel, authLabel, err))
|
||||
continue
|
||||
}
|
||||
|
||||
m.client = client
|
||||
if err := m.Ping(); err != nil {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second)
|
||||
_ = client.Disconnect(ctx)
|
||||
cancel()
|
||||
m.client = nil
|
||||
errorDetails = append(errorDetails, fmt.Sprintf("%s %s验证失败: %v", sslLabel, authLabel, err))
|
||||
continue
|
||||
}
|
||||
if sslIndex > 0 {
|
||||
logger.Warnf("MongoDB SSL 优先连接失败,已回退至明文连接")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
if len(errorDetails) > 0 {
|
||||
|
||||
1187
internal/db/mongodb_impl_v1.go
Normal file
1187
internal/db/mongodb_impl_v1.go
Normal file
File diff suppressed because it is too large
Load Diff
@@ -35,6 +35,32 @@ func ResolveOptionalDriverAgentExecutablePath(downloadDir string, driverType str
|
||||
return filepath.Join(root, normalized, optionalDriverAgentExecutableName(normalized)), nil
|
||||
}
|
||||
|
||||
func ResolveOptionalDriverAgentExecutablePathForVersion(downloadDir string, driverType string, version string) (string, error) {
|
||||
normalized := normalizeRuntimeDriverType(driverType)
|
||||
if strings.TrimSpace(normalized) == "" {
|
||||
return "", fmt.Errorf("驱动类型为空")
|
||||
}
|
||||
root, err := resolveExternalDriverRoot(downloadDir)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
if normalized != "mongodb" {
|
||||
return filepath.Join(root, normalized, optionalDriverAgentExecutableName(normalized)), nil
|
||||
}
|
||||
|
||||
baseName := optionalDriverAgentExecutableName(normalized)
|
||||
ext := filepath.Ext(baseName)
|
||||
stem := strings.TrimSuffix(baseName, ext)
|
||||
major := 2
|
||||
trimmed := strings.TrimSpace(version)
|
||||
trimmed = strings.TrimPrefix(trimmed, "v")
|
||||
if strings.HasPrefix(trimmed, "1.") || trimmed == "1" {
|
||||
major = 1
|
||||
}
|
||||
versionedName := fmt.Sprintf("%s-v%d%s", stem, major, ext)
|
||||
return filepath.Join(root, normalized, versionedName), nil
|
||||
}
|
||||
func ResolveMySQLAgentExecutablePath(downloadDir string) (string, error) {
|
||||
return ResolveOptionalDriverAgentExecutablePath(downloadDir, "mysql")
|
||||
}
|
||||
|
||||
@@ -184,9 +184,10 @@ func (m *MySQLDB) getDSN(config connection.ConnectionConfig) string {
|
||||
}
|
||||
|
||||
timeout := getConnectTimeoutSeconds(config)
|
||||
tlsMode := resolveMySQLTLSMode(config)
|
||||
|
||||
return fmt.Sprintf("%s:%s@%s(%s)/%s?charset=utf8mb4&parseTime=True&loc=Local&timeout=%ds",
|
||||
config.User, config.Password, protocol, address, database, timeout)
|
||||
return fmt.Sprintf("%s:%s@%s(%s)/%s?charset=utf8mb4&parseTime=True&loc=Local&timeout=%ds&tls=%s",
|
||||
config.User, config.Password, protocol, address, database, timeout, url.QueryEscape(tlsMode))
|
||||
}
|
||||
|
||||
func resolveMySQLCredential(config connection.ConnectionConfig, addressIndex int) (string, string) {
|
||||
|
||||
@@ -35,12 +35,23 @@ func (o *OracleDB) getDSN(config connection.ConnectionConfig) string {
|
||||
}
|
||||
u.User = url.UserPassword(config.User, config.Password)
|
||||
u.RawPath = "/" + url.PathEscape(database)
|
||||
q := url.Values{}
|
||||
switch normalizedSSLMode(config) {
|
||||
case sslModeRequired:
|
||||
q.Set("SSL", "TRUE")
|
||||
q.Set("SSL VERIFY", "TRUE")
|
||||
case sslModeSkipVerify, sslModePreferred:
|
||||
q.Set("SSL", "TRUE")
|
||||
q.Set("SSL VERIFY", "FALSE")
|
||||
}
|
||||
if encoded := q.Encode(); encoded != "" {
|
||||
u.RawQuery = encoded
|
||||
}
|
||||
return u.String()
|
||||
}
|
||||
|
||||
func (o *OracleDB) Connect(config connection.ConnectionConfig) error {
|
||||
var dsn string
|
||||
var err error
|
||||
runConfig := config
|
||||
serviceName := strings.TrimSpace(config.Database)
|
||||
if serviceName == "" {
|
||||
return fmt.Errorf("Oracle 连接缺少服务名(Service Name),请在连接配置中填写,例如 ORCLPDB1")
|
||||
@@ -73,22 +84,37 @@ func (o *OracleDB) Connect(config connection.ConnectionConfig) error {
|
||||
localConfig.Port = port
|
||||
localConfig.UseSSH = false
|
||||
|
||||
dsn = o.getDSN(localConfig)
|
||||
runConfig = localConfig
|
||||
logger.Infof("Oracle 通过本地端口转发连接:%s -> %s:%d", forwarder.LocalAddr, config.Host, config.Port)
|
||||
} else {
|
||||
dsn = o.getDSN(config)
|
||||
}
|
||||
|
||||
db, err := sql.Open("oracle", dsn)
|
||||
if err != nil {
|
||||
return fmt.Errorf("打开数据库连接失败:%w", err)
|
||||
attempts := []connection.ConnectionConfig{runConfig}
|
||||
if shouldTrySSLPreferredFallback(runConfig) {
|
||||
attempts = append(attempts, withSSLDisabled(runConfig))
|
||||
}
|
||||
o.conn = db
|
||||
o.pingTimeout = getConnectTimeout(config)
|
||||
if err := o.Ping(); err != nil {
|
||||
return fmt.Errorf("连接建立后验证失败:%w", err)
|
||||
|
||||
var failures []string
|
||||
for idx, attempt := range attempts {
|
||||
dsn := o.getDSN(attempt)
|
||||
db, err := sql.Open("oracle", dsn)
|
||||
if err != nil {
|
||||
failures = append(failures, fmt.Sprintf("第%d次连接打开失败: %v", idx+1, err))
|
||||
continue
|
||||
}
|
||||
o.conn = db
|
||||
o.pingTimeout = getConnectTimeout(attempt)
|
||||
if err := o.Ping(); err != nil {
|
||||
_ = db.Close()
|
||||
o.conn = nil
|
||||
failures = append(failures, fmt.Sprintf("第%d次连接验证失败: %v", idx+1, err))
|
||||
continue
|
||||
}
|
||||
if idx > 0 {
|
||||
logger.Warnf("Oracle SSL 优先连接失败,已回退至明文连接")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
return nil
|
||||
return fmt.Errorf("连接建立后验证失败:%s", strings.Join(failures, ";"))
|
||||
}
|
||||
|
||||
func (o *OracleDB) Close() error {
|
||||
|
||||
@@ -62,7 +62,7 @@ func (p *PostgresDB) getDSN(config connection.ConnectionConfig) string {
|
||||
}
|
||||
u.User = url.UserPassword(config.User, config.Password)
|
||||
q := url.Values{}
|
||||
q.Set("sslmode", "disable")
|
||||
q.Set("sslmode", resolvePostgresSSLMode(config))
|
||||
q.Set("connect_timeout", strconv.Itoa(getConnectTimeoutSeconds(config)))
|
||||
u.RawQuery = q.Encode()
|
||||
|
||||
@@ -126,34 +126,49 @@ func (p *PostgresDB) Connect(config connection.ConnectionConfig) error {
|
||||
logger.Infof("PostgreSQL 通过本地端口转发连接:%s -> %s:%d", forwarder.LocalAddr, config.Host, config.Port)
|
||||
}
|
||||
|
||||
attemptDBs := resolvePostgresConnectDatabases(runConfig)
|
||||
sslAttempts := []connection.ConnectionConfig{runConfig}
|
||||
if shouldTrySSLPreferredFallback(runConfig) {
|
||||
sslAttempts = append(sslAttempts, withSSLDisabled(runConfig))
|
||||
}
|
||||
|
||||
var failures []string
|
||||
for _, dbName := range attemptDBs {
|
||||
attemptConfig := runConfig
|
||||
attemptConfig.Database = dbName
|
||||
dsn := p.getDSN(attemptConfig)
|
||||
|
||||
dbConn, err := sql.Open("postgres", dsn)
|
||||
if err != nil {
|
||||
failures = append(failures, fmt.Sprintf("数据库=%s 打开连接失败: %v", dbName, err))
|
||||
continue
|
||||
}
|
||||
p.conn = dbConn
|
||||
|
||||
// Force verification
|
||||
if err := p.Ping(); err != nil {
|
||||
failures = append(failures, fmt.Sprintf("数据库=%s 验证失败: %v", dbName, err))
|
||||
_ = dbConn.Close()
|
||||
p.conn = nil
|
||||
continue
|
||||
for sslIndex, sslConfig := range sslAttempts {
|
||||
sslLabel := "SSL"
|
||||
if sslIndex > 0 {
|
||||
sslLabel = "明文回退"
|
||||
}
|
||||
|
||||
if strings.TrimSpace(config.Database) == "" && !strings.EqualFold(dbName, "postgres") {
|
||||
logger.Infof("PostgreSQL 自动选择连接数据库:%s", dbName)
|
||||
}
|
||||
attemptDBs := resolvePostgresConnectDatabases(sslConfig)
|
||||
for _, dbName := range attemptDBs {
|
||||
attemptConfig := sslConfig
|
||||
attemptConfig.Database = dbName
|
||||
dsn := p.getDSN(attemptConfig)
|
||||
|
||||
cleanupOnFailure = false
|
||||
return nil
|
||||
dbConn, err := sql.Open("postgres", dsn)
|
||||
if err != nil {
|
||||
failures = append(failures, fmt.Sprintf("%s 数据库=%s 打开连接失败: %v", sslLabel, dbName, err))
|
||||
continue
|
||||
}
|
||||
p.conn = dbConn
|
||||
|
||||
// Force verification
|
||||
if err := p.Ping(); err != nil {
|
||||
failures = append(failures, fmt.Sprintf("%s 数据库=%s 验证失败: %v", sslLabel, dbName, err))
|
||||
_ = dbConn.Close()
|
||||
p.conn = nil
|
||||
continue
|
||||
}
|
||||
|
||||
if sslIndex > 0 {
|
||||
logger.Warnf("PostgreSQL SSL 优先连接失败,已回退至明文连接")
|
||||
}
|
||||
if strings.TrimSpace(config.Database) == "" && !strings.EqualFold(dbName, "postgres") {
|
||||
logger.Infof("PostgreSQL 自动选择连接数据库:%s", dbName)
|
||||
}
|
||||
|
||||
cleanupOnFailure = false
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
if len(failures) == 0 {
|
||||
|
||||
@@ -47,8 +47,9 @@ func (s *SqlServerDB) getDSN(config connection.ConnectionConfig) string {
|
||||
q := url.Values{}
|
||||
q.Set("database", dbname)
|
||||
q.Set("connection timeout", strconv.Itoa(getConnectTimeoutSeconds(config)))
|
||||
q.Set("encrypt", "disable")
|
||||
q.Set("TrustServerCertificate", "true")
|
||||
encrypt, trustServerCertificate := resolveSQLServerTLSSettings(config)
|
||||
q.Set("encrypt", encrypt)
|
||||
q.Set("TrustServerCertificate", trustServerCertificate)
|
||||
u.RawQuery = q.Encode()
|
||||
|
||||
return u.String()
|
||||
|
||||
122
internal/db/ssl_mode.go
Normal file
122
internal/db/ssl_mode.go
Normal file
@@ -0,0 +1,122 @@
|
||||
package db
|
||||
|
||||
import (
|
||||
"crypto/tls"
|
||||
"strings"
|
||||
|
||||
"GoNavi-Wails/internal/connection"
|
||||
)
|
||||
|
||||
const (
|
||||
sslModeDisable = "disable"
|
||||
sslModePreferred = "preferred"
|
||||
sslModeRequired = "required"
|
||||
sslModeSkipVerify = "skip-verify"
|
||||
)
|
||||
|
||||
func normalizeSSLModeValue(raw string) string {
|
||||
mode := strings.ToLower(strings.TrimSpace(raw))
|
||||
switch mode {
|
||||
case "", sslModePreferred, "prefer":
|
||||
return sslModePreferred
|
||||
case sslModeRequired, "require", "on", "true", "mandatory", "strict":
|
||||
return sslModeRequired
|
||||
case sslModeSkipVerify, "insecure", "skipverify", "skip_verify", "insecure-skip-verify":
|
||||
return sslModeSkipVerify
|
||||
case sslModeDisable, "disabled", "off", "false", "none":
|
||||
return sslModeDisable
|
||||
default:
|
||||
return sslModePreferred
|
||||
}
|
||||
}
|
||||
|
||||
func normalizedSSLMode(config connection.ConnectionConfig) string {
|
||||
if !config.UseSSL {
|
||||
return sslModeDisable
|
||||
}
|
||||
return normalizeSSLModeValue(config.SSLMode)
|
||||
}
|
||||
|
||||
func shouldTrySSLPreferredFallback(config connection.ConnectionConfig) bool {
|
||||
return config.UseSSL && normalizeSSLModeValue(config.SSLMode) == sslModePreferred
|
||||
}
|
||||
|
||||
func withSSLDisabled(config connection.ConnectionConfig) connection.ConnectionConfig {
|
||||
next := config
|
||||
next.UseSSL = false
|
||||
next.SSLMode = sslModeDisable
|
||||
return next
|
||||
}
|
||||
|
||||
func resolveMySQLTLSMode(config connection.ConnectionConfig) string {
|
||||
switch normalizedSSLMode(config) {
|
||||
case sslModeDisable:
|
||||
return "false"
|
||||
case sslModeRequired:
|
||||
return "true"
|
||||
case sslModeSkipVerify:
|
||||
return "skip-verify"
|
||||
default:
|
||||
return "preferred"
|
||||
}
|
||||
}
|
||||
|
||||
func resolvePostgresSSLMode(config connection.ConnectionConfig) string {
|
||||
switch normalizedSSLMode(config) {
|
||||
case sslModeDisable:
|
||||
return "disable"
|
||||
case sslModeRequired:
|
||||
return "require"
|
||||
case sslModeSkipVerify:
|
||||
return "require"
|
||||
default:
|
||||
return "require"
|
||||
}
|
||||
}
|
||||
|
||||
func resolveSQLServerTLSSettings(config connection.ConnectionConfig) (encrypt string, trustServerCertificate string) {
|
||||
switch normalizedSSLMode(config) {
|
||||
case sslModeDisable:
|
||||
return "disable", "true"
|
||||
case sslModeRequired:
|
||||
return "true", "false"
|
||||
case sslModeSkipVerify:
|
||||
return "true", "true"
|
||||
default:
|
||||
return "false", "true"
|
||||
}
|
||||
}
|
||||
|
||||
func resolveGenericTLSConfig(config connection.ConnectionConfig) *tls.Config {
|
||||
switch normalizedSSLMode(config) {
|
||||
case sslModeDisable:
|
||||
return nil
|
||||
case sslModeRequired:
|
||||
return &tls.Config{MinVersion: tls.VersionTLS12}
|
||||
case sslModeSkipVerify:
|
||||
return &tls.Config{MinVersion: tls.VersionTLS12, InsecureSkipVerify: true}
|
||||
default:
|
||||
// Preferred: 先尝试 TLS(为提升兼容性默认跳过证书校验),失败时由调用方按需回退明文。
|
||||
return &tls.Config{MinVersion: tls.VersionTLS12, InsecureSkipVerify: true}
|
||||
}
|
||||
}
|
||||
|
||||
func resolveMongoTLSSettings(config connection.ConnectionConfig) (enabled bool, insecure bool) {
|
||||
switch normalizedSSLMode(config) {
|
||||
case sslModeDisable:
|
||||
return false, false
|
||||
case sslModeRequired:
|
||||
return true, false
|
||||
case sslModeSkipVerify:
|
||||
return true, true
|
||||
default:
|
||||
return true, true
|
||||
}
|
||||
}
|
||||
|
||||
func resolveTDengineNet(config connection.ConnectionConfig) string {
|
||||
if normalizedSSLMode(config) == sslModeDisable {
|
||||
return "ws"
|
||||
}
|
||||
return "wss"
|
||||
}
|
||||
@@ -40,11 +40,12 @@ func (t *TDengineDB) getDSN(config connection.ConnectionConfig) string {
|
||||
path = "/" + dbName
|
||||
}
|
||||
|
||||
return fmt.Sprintf("%s:%s@ws(%s)%s", user, pass, net.JoinHostPort(config.Host, strconv.Itoa(config.Port)), path)
|
||||
netType := resolveTDengineNet(config)
|
||||
return fmt.Sprintf("%s:%s@%s(%s)%s", user, pass, netType, net.JoinHostPort(config.Host, strconv.Itoa(config.Port)), path)
|
||||
}
|
||||
|
||||
func (t *TDengineDB) Connect(config connection.ConnectionConfig) error {
|
||||
var dsn string
|
||||
runConfig := config
|
||||
|
||||
if config.UseSSH {
|
||||
logger.Infof("TDengine 使用 SSH 连接:地址=%s:%d 用户=%s", config.Host, config.Port, config.User)
|
||||
@@ -68,23 +69,38 @@ func (t *TDengineDB) Connect(config connection.ConnectionConfig) error {
|
||||
localConfig.Host = host
|
||||
localConfig.Port = port
|
||||
localConfig.UseSSH = false
|
||||
dsn = t.getDSN(localConfig)
|
||||
runConfig = localConfig
|
||||
logger.Infof("TDengine 通过本地端口转发连接:%s -> %s:%d", forwarder.LocalAddr, config.Host, config.Port)
|
||||
} else {
|
||||
dsn = t.getDSN(config)
|
||||
}
|
||||
|
||||
db, err := sql.Open("taosWS", dsn)
|
||||
if err != nil {
|
||||
return fmt.Errorf("打开数据库连接失败:%w", err)
|
||||
attempts := []connection.ConnectionConfig{runConfig}
|
||||
if shouldTrySSLPreferredFallback(runConfig) {
|
||||
attempts = append(attempts, withSSLDisabled(runConfig))
|
||||
}
|
||||
t.conn = db
|
||||
t.pingTimeout = getConnectTimeout(config)
|
||||
|
||||
if err := t.Ping(); err != nil {
|
||||
return fmt.Errorf("连接建立后验证失败:%w", err)
|
||||
var failures []string
|
||||
for idx, attempt := range attempts {
|
||||
dsn := t.getDSN(attempt)
|
||||
db, err := sql.Open("taosWS", dsn)
|
||||
if err != nil {
|
||||
failures = append(failures, fmt.Sprintf("第%d次连接打开失败: %v", idx+1, err))
|
||||
continue
|
||||
}
|
||||
t.conn = db
|
||||
t.pingTimeout = getConnectTimeout(attempt)
|
||||
|
||||
if err := t.Ping(); err != nil {
|
||||
_ = db.Close()
|
||||
t.conn = nil
|
||||
failures = append(failures, fmt.Sprintf("第%d次连接验证失败: %v", idx+1, err))
|
||||
continue
|
||||
}
|
||||
if idx > 0 {
|
||||
logger.Warnf("TDengine SSL 优先连接失败,已回退至明文连接")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
return nil
|
||||
return fmt.Errorf("连接建立后验证失败:%s", strings.Join(failures, ";"))
|
||||
}
|
||||
|
||||
func (t *TDengineDB) Close() error {
|
||||
|
||||
@@ -41,7 +41,7 @@ func (v *VastbaseDB) getDSN(config connection.ConnectionConfig) string {
|
||||
}
|
||||
u.User = url.UserPassword(config.User, config.Password)
|
||||
q := url.Values{}
|
||||
q.Set("sslmode", "disable")
|
||||
q.Set("sslmode", resolvePostgresSSLMode(config))
|
||||
q.Set("connect_timeout", strconv.Itoa(getConnectTimeoutSeconds(config)))
|
||||
u.RawQuery = q.Encode()
|
||||
|
||||
@@ -49,7 +49,7 @@ func (v *VastbaseDB) getDSN(config connection.ConnectionConfig) string {
|
||||
}
|
||||
|
||||
func (v *VastbaseDB) Connect(config connection.ConnectionConfig) error {
|
||||
var dsn string
|
||||
runConfig := config
|
||||
|
||||
if config.UseSSH {
|
||||
logger.Infof("Vastbase 使用 SSH 连接:地址=%s:%d 用户=%s", config.Host, config.Port, config.User)
|
||||
@@ -75,23 +75,37 @@ func (v *VastbaseDB) Connect(config connection.ConnectionConfig) error {
|
||||
localConfig.Port = port
|
||||
localConfig.UseSSH = false
|
||||
|
||||
dsn = v.getDSN(localConfig)
|
||||
runConfig = localConfig
|
||||
logger.Infof("Vastbase 通过本地端口转发连接:%s -> %s:%d", forwarder.LocalAddr, config.Host, config.Port)
|
||||
} else {
|
||||
dsn = v.getDSN(config)
|
||||
}
|
||||
|
||||
db, err := sql.Open("postgres", dsn)
|
||||
if err != nil {
|
||||
return fmt.Errorf("打开数据库连接失败:%w", err)
|
||||
attempts := []connection.ConnectionConfig{runConfig}
|
||||
if shouldTrySSLPreferredFallback(runConfig) {
|
||||
attempts = append(attempts, withSSLDisabled(runConfig))
|
||||
}
|
||||
v.conn = db
|
||||
v.pingTimeout = getConnectTimeout(config)
|
||||
|
||||
if err := v.Ping(); err != nil {
|
||||
return fmt.Errorf("连接建立后验证失败:%w", err)
|
||||
var failures []string
|
||||
for idx, attempt := range attempts {
|
||||
dsn := v.getDSN(attempt)
|
||||
db, err := sql.Open("postgres", dsn)
|
||||
if err != nil {
|
||||
failures = append(failures, fmt.Sprintf("第%d次连接打开失败: %v", idx+1, err))
|
||||
continue
|
||||
}
|
||||
v.conn = db
|
||||
v.pingTimeout = getConnectTimeout(attempt)
|
||||
if err := v.Ping(); err != nil {
|
||||
_ = db.Close()
|
||||
v.conn = nil
|
||||
failures = append(failures, fmt.Sprintf("第%d次连接验证失败: %v", idx+1, err))
|
||||
continue
|
||||
}
|
||||
if idx > 0 {
|
||||
logger.Warnf("Vastbase SSL 优先连接失败,已回退至明文连接")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
return nil
|
||||
return fmt.Errorf("连接建立后验证失败:%s", strings.Join(failures, ";"))
|
||||
}
|
||||
|
||||
func (v *VastbaseDB) Close() error {
|
||||
|
||||
@@ -2,6 +2,7 @@ package redis
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/tls"
|
||||
"fmt"
|
||||
"net"
|
||||
"strconv"
|
||||
@@ -201,25 +202,48 @@ func (r *RedisClientImpl) Connect(config connection.ConnectionConfig) error {
|
||||
|
||||
timeout := normalizeRedisTimeout(config.Timeout)
|
||||
if r.isCluster {
|
||||
opts := &redis.ClusterOptions{
|
||||
Addrs: seedAddrs,
|
||||
Username: strings.TrimSpace(config.User),
|
||||
Password: config.Password,
|
||||
DialTimeout: timeout,
|
||||
ReadTimeout: timeout,
|
||||
WriteTimeout: timeout,
|
||||
attempts := []connection.ConnectionConfig{config}
|
||||
if shouldTryRedisSSLPreferredFallback(config) {
|
||||
attempts = append(attempts, withRedisSSLDisabled(config))
|
||||
}
|
||||
clusterClient := redis.NewClusterClient(opts)
|
||||
ctx, cancel := context.WithTimeout(context.Background(), timeout)
|
||||
defer cancel()
|
||||
if err := clusterClient.Ping(ctx).Err(); err != nil {
|
||||
clusterClient.Close()
|
||||
return fmt.Errorf("Redis 集群连接失败: %w", err)
|
||||
|
||||
var failures []string
|
||||
for idx, attempt := range attempts {
|
||||
var tlsConfig *tls.Config
|
||||
if cfg := resolveRedisTLSConfig(attempt); cfg != nil {
|
||||
if host, _, err := net.SplitHostPort(seedAddrs[0]); err == nil && host != "" {
|
||||
cfg.ServerName = host
|
||||
}
|
||||
tlsConfig = cfg
|
||||
}
|
||||
opts := &redis.ClusterOptions{
|
||||
Addrs: seedAddrs,
|
||||
Username: strings.TrimSpace(attempt.User),
|
||||
Password: attempt.Password,
|
||||
DialTimeout: timeout,
|
||||
ReadTimeout: timeout,
|
||||
WriteTimeout: timeout,
|
||||
TLSConfig: tlsConfig,
|
||||
}
|
||||
clusterClient := redis.NewClusterClient(opts)
|
||||
ctx, cancel := context.WithTimeout(context.Background(), timeout)
|
||||
pingErr := clusterClient.Ping(ctx).Err()
|
||||
cancel()
|
||||
if pingErr != nil {
|
||||
clusterClient.Close()
|
||||
failures = append(failures, fmt.Sprintf("第%d次连接失败: %v", idx+1, pingErr))
|
||||
continue
|
||||
}
|
||||
r.client = clusterClient
|
||||
r.clusterClient = clusterClient
|
||||
r.config = attempt
|
||||
if idx > 0 {
|
||||
logger.Warnf("Redis 集群 SSL 优先连接失败,已回退至明文连接")
|
||||
}
|
||||
logger.Infof("Redis 集群连接成功: seeds=%s 逻辑库=db%d", strings.Join(seedAddrs, ","), r.currentDB)
|
||||
return nil
|
||||
}
|
||||
r.client = clusterClient
|
||||
r.clusterClient = clusterClient
|
||||
logger.Infof("Redis 集群连接成功: seeds=%s 逻辑库=db%d", strings.Join(seedAddrs, ","), r.currentDB)
|
||||
return nil
|
||||
return fmt.Errorf("Redis 集群连接失败: %s", strings.Join(failures, ";"))
|
||||
}
|
||||
|
||||
addr := seedAddrs[0]
|
||||
@@ -233,29 +257,53 @@ func (r *RedisClientImpl) Connect(config connection.ConnectionConfig) error {
|
||||
logger.Infof("Redis 通过 SSH 隧道连接: %s -> %s:%d", addr, config.Host, config.Port)
|
||||
}
|
||||
|
||||
opts := &redis.Options{
|
||||
Addr: addr,
|
||||
Username: strings.TrimSpace(config.User),
|
||||
Password: config.Password,
|
||||
DB: r.currentDB,
|
||||
DialTimeout: timeout,
|
||||
ReadTimeout: timeout,
|
||||
WriteTimeout: timeout,
|
||||
attempts := []connection.ConnectionConfig{config}
|
||||
if shouldTryRedisSSLPreferredFallback(config) {
|
||||
attempts = append(attempts, withRedisSSLDisabled(config))
|
||||
}
|
||||
|
||||
singleClient := redis.NewClient(opts)
|
||||
ctx, cancel := context.WithTimeout(context.Background(), timeout)
|
||||
defer cancel()
|
||||
var failures []string
|
||||
for idx, attempt := range attempts {
|
||||
var tlsConfig *tls.Config
|
||||
if cfg := resolveRedisTLSConfig(attempt); cfg != nil {
|
||||
if host, _, err := net.SplitHostPort(addr); err == nil && host != "" {
|
||||
cfg.ServerName = host
|
||||
}
|
||||
tlsConfig = cfg
|
||||
}
|
||||
|
||||
if err := singleClient.Ping(ctx).Err(); err != nil {
|
||||
singleClient.Close()
|
||||
return fmt.Errorf("Redis 连接失败: %w", err)
|
||||
opts := &redis.Options{
|
||||
Addr: addr,
|
||||
Username: strings.TrimSpace(attempt.User),
|
||||
Password: attempt.Password,
|
||||
DB: r.currentDB,
|
||||
DialTimeout: timeout,
|
||||
ReadTimeout: timeout,
|
||||
WriteTimeout: timeout,
|
||||
TLSConfig: tlsConfig,
|
||||
}
|
||||
|
||||
singleClient := redis.NewClient(opts)
|
||||
ctx, cancel := context.WithTimeout(context.Background(), timeout)
|
||||
pingErr := singleClient.Ping(ctx).Err()
|
||||
cancel()
|
||||
if pingErr != nil {
|
||||
singleClient.Close()
|
||||
failures = append(failures, fmt.Sprintf("第%d次连接失败: %v", idx+1, pingErr))
|
||||
continue
|
||||
}
|
||||
|
||||
r.client = singleClient
|
||||
r.singleClient = singleClient
|
||||
r.config = attempt
|
||||
if idx > 0 {
|
||||
logger.Warnf("Redis SSL 优先连接失败,已回退至明文连接")
|
||||
}
|
||||
logger.Infof("Redis 连接成功: %s DB=%d", addr, r.currentDB)
|
||||
return nil
|
||||
}
|
||||
|
||||
r.client = singleClient
|
||||
r.singleClient = singleClient
|
||||
logger.Infof("Redis 连接成功: %s DB=%d", addr, r.currentDB)
|
||||
return nil
|
||||
return fmt.Errorf("Redis 连接失败: %s", strings.Join(failures, ";"))
|
||||
}
|
||||
|
||||
// Close closes the Redis connection
|
||||
|
||||
55
internal/redis/ssl_mode.go
Normal file
55
internal/redis/ssl_mode.go
Normal file
@@ -0,0 +1,55 @@
|
||||
package redis
|
||||
|
||||
import (
|
||||
"crypto/tls"
|
||||
"strings"
|
||||
|
||||
"GoNavi-Wails/internal/connection"
|
||||
)
|
||||
|
||||
func normalizeRedisSSLMode(raw string) string {
|
||||
mode := strings.ToLower(strings.TrimSpace(raw))
|
||||
switch mode {
|
||||
case "", "preferred", "prefer":
|
||||
return "preferred"
|
||||
case "required", "require", "on", "true", "mandatory", "strict":
|
||||
return "required"
|
||||
case "skip-verify", "insecure", "skipverify", "skip_verify", "insecure-skip-verify":
|
||||
return "skip-verify"
|
||||
case "disable", "disabled", "off", "false", "none":
|
||||
return "disable"
|
||||
default:
|
||||
return "preferred"
|
||||
}
|
||||
}
|
||||
|
||||
func redisSSLMode(config connection.ConnectionConfig) string {
|
||||
if !config.UseSSL {
|
||||
return "disable"
|
||||
}
|
||||
return normalizeRedisSSLMode(config.SSLMode)
|
||||
}
|
||||
|
||||
func shouldTryRedisSSLPreferredFallback(config connection.ConnectionConfig) bool {
|
||||
return config.UseSSL && normalizeRedisSSLMode(config.SSLMode) == "preferred"
|
||||
}
|
||||
|
||||
func withRedisSSLDisabled(config connection.ConnectionConfig) connection.ConnectionConfig {
|
||||
next := config
|
||||
next.UseSSL = false
|
||||
next.SSLMode = "disable"
|
||||
return next
|
||||
}
|
||||
|
||||
func resolveRedisTLSConfig(config connection.ConnectionConfig) *tls.Config {
|
||||
switch redisSSLMode(config) {
|
||||
case "disable":
|
||||
return nil
|
||||
case "required":
|
||||
return &tls.Config{MinVersion: tls.VersionTLS12}
|
||||
case "skip-verify":
|
||||
return &tls.Config{MinVersion: tls.VersionTLS12, InsecureSkipVerify: true}
|
||||
default:
|
||||
return &tls.Config{MinVersion: tls.VersionTLS12, InsecureSkipVerify: true}
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user