feat(mongodb,connection-tree,query-editor,sidebar,sqlserver,table-designer,ssl): 完成 MongoDB v1/v2 驱动切换与复制连接,增强快捷键/搜索/筛选与设计表体验,并修复 SQLServer、SSL 及连接稳定性问题 (#180)

* feat(mongodb-driver,connection-tree): 支持 MongoDB v1/v2 切换并新增复制连接

* fix(mongodb-query): 修复 MongoDB 筛选不生效并兼容 shell 语法执行

refs #153

* fix(query-editor): 修复 SQLServer 自动补全回车重复 dbo 前缀

refs #159

* fix(sqlserver-table-designer): 修复设计表读取列时错误使用 schema 作为数据库名

refs #156

* feat(shortcuts): 增加快捷键设置并支持 SQL 执行/侧边栏搜索

refs #158

* fix(sidebar-search): 优化范围搜索匹配与交互

refs #158

* fix(filter,connection-recovery): 保持筛选状态并修复连接失效卡死

refs #165

同步修复连接失效后侧栏持续转圈、断开后无法恢复的问题

* feat(table-designer): 统一设计表界面风格并优化字段新增交互

- 统一设计表页面与数据面板的视觉风格,覆盖工具栏、Tabs、表格与编辑区域

- 移除默认硬边框,改为透明背景与细分隔线,提升整体观感一致性

- 添加字段后自动滚动到新行并高亮,且自动聚焦输入框

- 新增" 在选中字段后添加\,支持按选中字段位置插入字段

* feat(data-grid-filter): 筛选字段支持快捷搜索

- 在筛选条件字段下拉启用可搜索(showSearch)

- 支持字段名大小写不敏感模糊匹配

- 表字段较多时可快速定位目标字段,减少下拉查找耗时

refs #171

* fix(db-ssl): 支持多数据源 SSL/TLS 连接并补齐达梦证书配置

refs #167

* fix(sidebar): 修复数据库加载时 null.map 导致表加载失败

* fix(query-editor): 合并运行按钮并保留 SQL 停止执行入口
This commit is contained in:
辣条
2026-03-05 16:52:06 +08:00
committed by GitHub
parent 69942bb77e
commit 71b41459e7
41 changed files with 5211 additions and 397 deletions

View File

@@ -0,0 +1,12 @@
//go:build gonavi_mongodb_driver_v1
package main
import "GoNavi-Wails/internal/db"
func init() {
agentDriverType = "mongodb"
agentDatabaseFactory = func() db.Database {
return &db.MongoDBV1{}
}
}

View File

@@ -2,7 +2,7 @@ import React, { useState, useEffect } from 'react';
import { Layout, Button, ConfigProvider, theme, Dropdown, MenuProps, message, Modal, Spin, Slider, Progress, Switch, Input, InputNumber, Select } from 'antd';
import zhCN from 'antd/locale/zh_CN';
import { PlusOutlined, ConsoleSqlOutlined, UploadOutlined, DownloadOutlined, CloudDownloadOutlined, BugOutlined, ToolOutlined, GlobalOutlined, InfoCircleOutlined, GithubOutlined, SkinOutlined, CheckOutlined, MinusOutlined, BorderOutlined, CloseOutlined, SettingOutlined, LinkOutlined } from '@ant-design/icons';
import { BrowserOpenURL, Environment, EventsOn, Quit, WindowFullscreen, WindowIsFullscreen, WindowIsMaximised, WindowMaximise, WindowMinimise, WindowToggleMaximise } from '../wailsjs/runtime';
import { BrowserOpenURL, Environment, EventsOn, Quit, WindowFullscreen, WindowGetSize, WindowIsFullscreen, WindowIsMaximised, WindowMaximise, WindowMinimise, WindowSetSize, WindowToggleMaximise } from '../wailsjs/runtime';
import Sidebar from './components/Sidebar';
import TabManager from './components/TabManager';
import ConnectionModal from './components/ConnectionModal';
@@ -12,6 +12,17 @@ import LogPanel from './components/LogPanel';
import { useStore } from './store';
import { SavedConnection } from './types';
import { blurToFilter, normalizeBlurForPlatform, normalizeOpacityForPlatform, isWindowsPlatform } from './utils/appearance';
import {
SHORTCUT_ACTION_META,
SHORTCUT_ACTION_ORDER,
ShortcutAction,
eventToShortcut,
getShortcutDisplay,
hasModifierKey,
isEditableElement,
isShortcutMatch,
normalizeShortcutCombo,
} from './utils/shortcuts';
import { ConfigureGlobalProxy, SetWindowTranslucency } from '../wailsjs/go/app/App';
import './App.css';
@@ -53,6 +64,9 @@ function App() {
const setStartupFullscreen = useStore(state => state.setStartupFullscreen);
const globalProxy = useStore(state => state.globalProxy);
const setGlobalProxy = useStore(state => state.setGlobalProxy);
const shortcutOptions = useStore(state => state.shortcutOptions);
const updateShortcut = useStore(state => state.updateShortcut);
const resetShortcutOptions = useStore(state => state.resetShortcutOptions);
const darkMode = themeMode === 'dark';
const effectiveUiScale = Math.min(MAX_UI_SCALE, Math.max(MIN_UI_SCALE, Number(uiScale) || DEFAULT_UI_SCALE));
const effectiveFontSize = Math.min(MAX_FONT_SIZE, Math.max(MIN_FONT_SIZE, Math.round(Number(fontSize) || DEFAULT_FONT_SIZE)));
@@ -260,6 +274,80 @@ function App() {
};
}, []);
useEffect(() => {
if (!isWindowsPlatform()) {
return;
}
let cancelled = false;
let inFlight = false;
let lastRatio = Number(window.devicePixelRatio) || 1;
let lastFixAt = 0;
const wait = (ms: number) => new Promise<void>((resolve) => window.setTimeout(resolve, ms));
const fixWindowScaleIfNeeded = async () => {
if (cancelled || inFlight) return;
const now = Date.now();
if (now - lastFixAt < 700) return;
inFlight = true;
try {
const [isFullscreen, isMaximised] = await Promise.all([
WindowIsFullscreen().catch(() => false),
WindowIsMaximised().catch(() => false),
]);
// 避免在全屏/最大化状态下强制改尺寸;这两种状态通常能自行保持 DPI 同步。
if (isFullscreen || isMaximised) {
window.dispatchEvent(new Event('resize'));
lastFixAt = Date.now();
return;
}
const size = await WindowGetSize().catch(() => null);
const width = Math.trunc(Number(size?.w || 0));
const height = Math.trunc(Number(size?.h || 0));
if (width <= 0 || height <= 0) {
window.dispatchEvent(new Event('resize'));
lastFixAt = Date.now();
return;
}
const nudgedWidth = width > 480 ? width - 1 : width + 1;
WindowSetSize(nudgedWidth, height);
await wait(28);
WindowSetSize(width, height);
window.dispatchEvent(new Event('resize'));
lastFixAt = Date.now();
} finally {
inFlight = false;
}
};
const checkDevicePixelRatio = () => {
if (cancelled) return;
const currentRatio = Number(window.devicePixelRatio) || 1;
if (Math.abs(currentRatio - lastRatio) < 0.02) {
return;
}
lastRatio = currentRatio;
void fixWindowScaleIfNeeded();
};
const pollTimer = window.setInterval(checkDevicePixelRatio, 900);
window.addEventListener('resize', checkDevicePixelRatio);
window.addEventListener('focus', checkDevicePixelRatio);
document.addEventListener('visibilitychange', checkDevicePixelRatio);
return () => {
cancelled = true;
window.clearInterval(pollTimer);
window.removeEventListener('resize', checkDevicePixelRatio);
window.removeEventListener('focus', checkDevicePixelRatio);
document.removeEventListener('visibilitychange', checkDevicePixelRatio);
};
}, []);
// Background Helper
const getBg = (darkHex: string) => {
if (!darkMode) return `rgba(255, 255, 255, ${effectiveOpacity})`; // Light mode usually white
@@ -720,10 +808,18 @@ function App() {
label: '外观设置...',
icon: <SettingOutlined />,
onClick: () => setIsAppearanceModalOpen(true)
},
{
key: 'shortcut-settings',
label: '快捷键管理...',
icon: <LinkOutlined />,
onClick: () => setIsShortcutModalOpen(true)
}
];
const [isAppearanceModalOpen, setIsAppearanceModalOpen] = useState(false);
const [isShortcutModalOpen, setIsShortcutModalOpen] = useState(false);
const [capturingShortcutAction, setCapturingShortcutAction] = useState<ShortcutAction | null>(null);
const [isProxyModalOpen, setIsProxyModalOpen] = useState(false);
@@ -935,6 +1031,113 @@ function App() {
};
}, []);
useEffect(() => {
const handleOpenShortcutSettingsEvent = () => {
setIsShortcutModalOpen(true);
};
window.addEventListener('gonavi:open-shortcut-settings', handleOpenShortcutSettingsEvent as EventListener);
return () => {
window.removeEventListener('gonavi:open-shortcut-settings', handleOpenShortcutSettingsEvent as EventListener);
};
}, []);
useEffect(() => {
const handleGlobalShortcut = (event: KeyboardEvent) => {
const matchedAction = SHORTCUT_ACTION_ORDER.find((action) => {
const binding = shortcutOptions[action];
if (!binding?.enabled) {
return false;
}
if (isEditableElement(event.target) && !SHORTCUT_ACTION_META[action].allowInEditable) {
return false;
}
return isShortcutMatch(event, binding.combo);
});
if (!matchedAction) {
return;
}
event.preventDefault();
event.stopPropagation();
switch (matchedAction) {
case 'runQuery':
window.dispatchEvent(new CustomEvent('gonavi:run-active-query'));
break;
case 'focusSidebarSearch':
window.dispatchEvent(new CustomEvent('gonavi:focus-sidebar-search'));
break;
case 'newQueryTab':
handleNewQuery();
break;
case 'toggleLogPanel':
setIsLogPanelOpen((prev) => !prev);
break;
case 'toggleTheme':
setTheme(themeMode === 'dark' ? 'light' : 'dark');
break;
case 'openShortcutManager':
setIsShortcutModalOpen(true);
break;
}
};
window.addEventListener('keydown', handleGlobalShortcut);
return () => {
window.removeEventListener('keydown', handleGlobalShortcut);
};
}, [handleNewQuery, shortcutOptions, themeMode, setTheme]);
useEffect(() => {
if (!capturingShortcutAction) {
return;
}
const handleShortcutCapture = (event: KeyboardEvent) => {
event.preventDefault();
event.stopPropagation();
if (event.key === 'Escape') {
setCapturingShortcutAction(null);
return;
}
const combo = eventToShortcut(event);
if (!combo) {
return;
}
if (!hasModifierKey(combo)) {
void message.warning('快捷键至少包含 Ctrl / Alt / Shift / Meta 之一');
return;
}
const normalizedCombo = normalizeShortcutCombo(combo);
const conflictAction = SHORTCUT_ACTION_ORDER.find((action) => {
if (action === capturingShortcutAction) {
return false;
}
const binding = shortcutOptions[action];
if (!binding?.enabled) {
return false;
}
return normalizeShortcutCombo(binding.combo) === normalizedCombo;
});
if (conflictAction) {
void message.warning(`与「${SHORTCUT_ACTION_META[conflictAction].label}」冲突,请换一个快捷键`);
return;
}
updateShortcut(capturingShortcutAction, { combo: normalizedCombo, enabled: true });
setCapturingShortcutAction(null);
};
window.addEventListener('keydown', handleShortcutCapture, true);
return () => {
window.removeEventListener('keydown', handleShortcutCapture, true);
};
}, [capturingShortcutAction, shortcutOptions, updateShortcut]);
const linuxResizeHandleStyleBase = {
position: 'fixed',
zIndex: 12000,
@@ -1354,6 +1557,84 @@ function App() {
</div>
</Modal>
<Modal
title="快捷键管理"
open={isShortcutModalOpen}
onCancel={() => {
setIsShortcutModalOpen(false);
setCapturingShortcutAction(null);
}}
width={720}
footer={[
<Button
key="reset"
onClick={() => {
resetShortcutOptions();
setCapturingShortcutAction(null);
void message.success('已恢复默认快捷键');
}}
>
</Button>,
<Button
key="close"
type="primary"
onClick={() => {
setIsShortcutModalOpen(false);
setCapturingShortcutAction(null);
}}
>
</Button>,
]}
>
<div style={{ display: 'flex', flexDirection: 'column', gap: 12, paddingTop: 8 }}>
<div style={{ fontSize: 12, color: '#8c8c8c' }}>
Esc Ctrl/Alt/Shift/Meta
</div>
{SHORTCUT_ACTION_ORDER.map((action) => {
const meta = SHORTCUT_ACTION_META[action];
const binding = shortcutOptions[action] ?? { combo: '', enabled: false };
const isCapturing = capturingShortcutAction === action;
return (
<div
key={action}
style={{
display: 'grid',
gridTemplateColumns: '1fr auto',
gap: 12,
alignItems: 'center',
padding: '10px 12px',
border: '1px solid rgba(128, 128, 128, 0.2)',
borderRadius: 8,
}}
>
<div>
<div style={{ fontWeight: 500 }}>{meta.label}</div>
<div style={{ fontSize: 12, color: '#8c8c8c' }}>{meta.description}</div>
</div>
<div style={{ display: 'flex', alignItems: 'center', gap: 8 }}>
<Input
readOnly
value={isCapturing ? '请按下快捷键...' : getShortcutDisplay(binding.combo)}
style={{ width: 180, fontFamily: 'Consolas, Menlo, Monaco, monospace' }}
/>
<Button
size="small"
onClick={() => setCapturingShortcutAction((prev) => (prev === action ? null : action))}
>
{isCapturing ? '取消' : '录制'}
</Button>
<Switch
checked={binding.enabled}
onChange={(checked) => updateShortcut(action, { enabled: checked })}
/>
</div>
</div>
);
})}
</div>
</Modal>
<Modal
title="全局代理设置"
open={isProxyModalOpen}

View File

@@ -54,6 +54,26 @@ const singleHostUriSchemesByType: Record<string, string[]> = {
vastbase: ['vastbase'],
};
const sslSupportedTypes = new Set([
'mysql',
'mariadb',
'diros',
'sphinx',
'dameng',
'clickhouse',
'postgres',
'sqlserver',
'oracle',
'kingbase',
'highgo',
'vastbase',
'mongodb',
'redis',
'tdengine',
]);
const supportsSSLForType = (type: string) => sslSupportedTypes.has(String(type || '').trim().toLowerCase());
const isFileDatabaseType = (type: string) => type === 'sqlite' || type === 'duckdb';
type DriverStatusSnapshot = {
@@ -78,6 +98,7 @@ const ConnectionModal: React.FC<{
}> = ({ open, onClose, initialValues, onOpenDriverManager }) => {
const [form] = Form.useForm();
const [loading, setLoading] = useState(false);
const [useSSL, setUseSSL] = useState(false);
const [useSSH, setUseSSH] = useState(false);
const [useProxy, setUseProxy] = useState(false);
const [dbType, setDbType] = useState('mysql');
@@ -107,6 +128,17 @@ const ConnectionModal: React.FC<{
const mongoTopology = Form.useWatch('mongoTopology', form) || 'single';
const mongoSrv = Form.useWatch('mongoSrv', form) || false;
const redisTopology = Form.useWatch('redisTopology', form) || 'single';
const isMySQLLike = dbType === 'mysql' || dbType === 'mariadb' || dbType === 'diros' || dbType === 'sphinx';
const isSSLType = supportsSSLForType(dbType);
const sslHintText = isMySQLLike
? '当 MySQL/MariaDB/Doris/Sphinx 开启安全传输策略时,请启用 SSL本地自签证书场景可先用 Preferred 或 Skip Verify。'
: dbType === 'dameng'
? '达梦驱动启用 SSL 需要客户端证书与私钥路径sslCertPath / sslKeyPath。'
: dbType === 'sqlserver'
? 'SQL Server 推荐在生产环境使用 Required并关闭 TrustServerCertificate。'
: dbType === 'mongodb'
? 'MongoDB 可通过 TLS 保护连接,证书校验异常时可先用 Skip Verify 验证连通性。'
: '建议优先使用 Required仅在测试环境或自签证书场景使用 Skip Verify。';
const getSectionBg = (darkHex: string) => {
if (!darkMode) {
@@ -364,7 +396,7 @@ const ConnectionModal: React.FC<{
uriText: string,
expectedSchemes: string[],
defaultPort: number,
): { host: string; port: number; username: string; password: string; database: string } | null => {
): { host: string; port: number; username: string; password: string; database: string; params: URLSearchParams } | null => {
let parsed: ReturnType<typeof parseMultiHostUri> | null = null;
for (const scheme of expectedSchemes) {
parsed = parseMultiHostUri(uriText, scheme);
@@ -392,6 +424,7 @@ const ConnectionModal: React.FC<{
username: parsed.username,
password: parsed.password,
database: parsed.database || '',
params: parsed.params,
};
};
@@ -425,12 +458,22 @@ const ConnectionModal: React.FC<{
const primary = parseHostPort(hostList[0] || `localhost:${mysqlDefaultPort}`, mysqlDefaultPort);
const timeoutValue = Number(parsed.params.get('timeout'));
const topology = String(parsed.params.get('topology') || '').toLowerCase();
const tlsValue = String(parsed.params.get('tls') || '').trim().toLowerCase();
const sslMode = tlsValue === 'true'
? 'required'
: tlsValue === 'skip-verify'
? 'skip-verify'
: tlsValue === 'preferred'
? 'preferred'
: 'disable';
return {
host: primary?.host || 'localhost',
port: primary?.port || mysqlDefaultPort,
user: parsed.username,
password: parsed.password,
database: parsed.database || '',
useSSL: sslMode !== 'disable',
sslMode,
mysqlTopology: hostList.length > 1 || topology === 'replica' ? 'replica' : 'single',
mysqlReplicaHosts: hostList.slice(1),
timeout: Number.isFinite(timeoutValue) && timeoutValue > 0
@@ -451,7 +494,7 @@ const ConnectionModal: React.FC<{
}
if (type === 'redis') {
const parsed = parseMultiHostUri(trimmedUri, 'redis');
const parsed = parseMultiHostUri(trimmedUri, 'redis') || parseMultiHostUri(trimmedUri, 'rediss');
if (!parsed) {
return null;
}
@@ -469,10 +512,15 @@ const ConnectionModal: React.FC<{
const topologyParam = String(parsed.params.get('topology') || '').toLowerCase();
const dbText = String(parsed.database || '').trim().replace(/^\//, '');
const dbIndex = Number(dbText);
const isRediss = trimmedUri.toLowerCase().startsWith('rediss://');
const skipVerifyText = String(parsed.params.get('skip_verify') || '').trim().toLowerCase();
const skipVerify = skipVerifyText === '1' || skipVerifyText === 'true' || skipVerifyText === 'yes' || skipVerifyText === 'on';
return {
host: primary?.host || 'localhost',
port: primary?.port || 6379,
password: parsed.password || '',
useSSL: isRediss,
sslMode: isRediss ? (skipVerify ? 'skip-verify' : 'required') : 'disable',
redisTopology: hostList.length > 1 || topologyParam === 'cluster' ? 'cluster' : 'single',
redisHosts: hostList.slice(1),
redisDB: Number.isFinite(dbIndex) && dbIndex >= 0 && dbIndex <= 15 ? Math.trunc(dbIndex) : 0,
@@ -501,12 +549,18 @@ const ConnectionModal: React.FC<{
? { host: hostList[0] || 'localhost', port: 27017 }
: parseHostPort(hostList[0] || 'localhost:27017', 27017);
const timeoutMs = Number(parsed.params.get('connectTimeoutMS') || parsed.params.get('serverSelectionTimeoutMS'));
const tlsText = String(parsed.params.get('tls') || parsed.params.get('ssl') || '').trim().toLowerCase();
const tlsInsecureText = String(parsed.params.get('tlsInsecure') || parsed.params.get('sslInsecure') || '').trim().toLowerCase();
const tlsEnabled = tlsText === '1' || tlsText === 'true' || tlsText === 'yes' || tlsText === 'on';
const tlsInsecure = tlsInsecureText === '1' || tlsInsecureText === 'true' || tlsInsecureText === 'yes' || tlsInsecureText === 'on';
return {
host: primary?.host || 'localhost',
port: primary?.port || 27017,
user: parsed.username,
password: parsed.password,
database: parsed.database || '',
useSSL: tlsEnabled,
sslMode: tlsEnabled ? (tlsInsecure ? 'skip-verify' : 'required') : 'disable',
mongoTopology: hostList.length > 1 || !!parsed.params.get('replicaSet') ? 'replica' : 'single',
mongoHosts: hostList.slice(1),
mongoSrv: isSrv,
@@ -531,13 +585,94 @@ const ConnectionModal: React.FC<{
// Oracle 需要显式 service name避免 URI 解析后放过必填校验。
return null;
}
return {
const parsedValues: Record<string, any> = {
host: parsed.host,
port: parsed.port,
user: parsed.username,
password: parsed.password,
database: parsed.database,
};
if (supportsSSLForType(type)) {
const normalizeBool = (raw: unknown) => {
const text = String(raw ?? '').trim().toLowerCase();
return text === '1' || text === 'true' || text === 'yes' || text === 'on';
};
if (type === 'postgres' || type === 'kingbase' || type === 'highgo' || type === 'vastbase') {
const sslMode = String(parsed.params.get('sslmode') || '').trim().toLowerCase();
if (sslMode) {
parsedValues.useSSL = sslMode !== 'disable' && sslMode !== 'false';
parsedValues.sslMode = sslMode === 'disable' || sslMode === 'false'
? 'disable'
: 'required';
}
} else if (type === 'sqlserver') {
const encrypt = String(parsed.params.get('encrypt') || '').trim().toLowerCase();
const trust = String(parsed.params.get('TrustServerCertificate') || parsed.params.get('trustservercertificate') || '').trim().toLowerCase();
const encrypted = encrypt === 'true' || encrypt === 'mandatory' || encrypt === 'yes' || encrypt === '1' || encrypt === 'strict';
if (encrypted) {
parsedValues.useSSL = true;
parsedValues.sslMode = trust === 'true' || trust === '1' || trust === 'yes' ? 'skip-verify' : 'required';
} else if (encrypt) {
parsedValues.useSSL = false;
parsedValues.sslMode = 'disable';
}
} else if (type === 'clickhouse') {
const secure = String(parsed.params.get('secure') || parsed.params.get('tls') || '').trim().toLowerCase();
const skipVerify = normalizeBool(parsed.params.get('skip_verify'));
if (secure) {
parsedValues.useSSL = normalizeBool(secure);
parsedValues.sslMode = skipVerify ? 'skip-verify' : (parsedValues.useSSL ? 'required' : 'disable');
}
} else if (type === 'dameng') {
const certPath = String(
parsed.params.get('SSL_CERT_PATH')
|| parsed.params.get('ssl_cert_path')
|| parsed.params.get('sslCertPath')
|| ''
).trim();
const keyPath = String(
parsed.params.get('SSL_KEY_PATH')
|| parsed.params.get('ssl_key_path')
|| parsed.params.get('sslKeyPath')
|| ''
).trim();
parsedValues.sslCertPath = certPath;
parsedValues.sslKeyPath = keyPath;
if (certPath || keyPath) {
parsedValues.useSSL = true;
parsedValues.sslMode = 'required';
}
} else if (type === 'oracle') {
const ssl = String(parsed.params.get('SSL') || parsed.params.get('ssl') || '').trim().toLowerCase();
const sslVerify = String(
parsed.params.get('SSL VERIFY')
|| parsed.params.get('ssl verify')
|| parsed.params.get('SSL_VERIFY')
|| parsed.params.get('ssl_verify')
|| ''
).trim().toLowerCase();
if (ssl) {
parsedValues.useSSL = normalizeBool(ssl);
if (!parsedValues.useSSL) {
parsedValues.sslMode = 'disable';
} else {
parsedValues.sslMode = normalizeBool(sslVerify || 'true') ? 'required' : 'skip-verify';
}
}
} else if (type === 'tdengine') {
const protocol = String(parsed.params.get('protocol') || '').trim().toLowerCase();
const skipVerify = normalizeBool(parsed.params.get('skip_verify'));
if (protocol === 'wss') {
parsedValues.useSSL = true;
parsedValues.sslMode = skipVerify ? 'skip-verify' : 'required';
} else if (protocol === 'ws') {
parsedValues.useSSL = false;
parsedValues.sslMode = 'disable';
}
}
};
return parsedValues;
}
return null;
@@ -609,6 +744,16 @@ const ConnectionModal: React.FC<{
if (hosts.length > 1 || values.mysqlTopology === 'replica') {
params.set('topology', 'replica');
}
if (values.useSSL) {
const mode = String(values.sslMode || 'preferred').trim().toLowerCase();
if (mode === 'required') {
params.set('tls', 'true');
} else if (mode === 'skip-verify') {
params.set('tls', 'skip-verify');
} else {
params.set('tls', 'preferred');
}
}
if (Number.isFinite(timeout) && timeout > 0) {
params.set('timeout', String(timeout));
}
@@ -634,8 +779,15 @@ const ConnectionModal: React.FC<{
? Math.max(0, Math.min(15, Math.trunc(Number(values.redisDB))))
: 0;
const dbPath = `/${redisDB}`;
if (values.useSSL) {
const mode = String(values.sslMode || 'preferred').trim().toLowerCase();
if (mode === 'skip-verify' || mode === 'preferred') {
params.set('skip_verify', 'true');
}
}
const query = params.toString();
return `redis://${redisAuth}${hosts.join(',')}${dbPath}${query ? `?${query}` : ''}`;
const scheme = values.useSSL ? 'rediss' : 'redis';
return `${scheme}://${redisAuth}${hosts.join(',')}${dbPath}${query ? `?${query}` : ''}`;
}
if (isFileDatabaseType(type)) {
@@ -675,6 +827,15 @@ const ConnectionModal: React.FC<{
if (authMechanism) {
params.set('authMechanism', authMechanism);
}
if (values.useSSL) {
const mode = String(values.sslMode || 'preferred').trim().toLowerCase();
params.set('tls', 'true');
if (mode === 'skip-verify' || mode === 'preferred') {
params.set('tlsInsecure', 'true');
} else {
params.delete('tlsInsecure');
}
}
if (Number.isFinite(timeout) && timeout > 0) {
params.set('connectTimeoutMS', String(timeout * 1000));
params.set('serverSelectionTimeoutMS', String(timeout * 1000));
@@ -686,7 +847,45 @@ const ConnectionModal: React.FC<{
const scheme = type === 'postgres' ? 'postgresql' : type;
const dbPath = database ? `/${encodeURIComponent(database)}` : '';
return `${scheme}://${encodedAuth}${toAddress(host, port, defaultPort)}${dbPath}`;
const params = new URLSearchParams();
if (supportsSSLForType(type) && values.useSSL) {
const mode = String(values.sslMode || 'preferred').trim().toLowerCase();
if (type === 'postgres' || type === 'kingbase' || type === 'highgo' || type === 'vastbase') {
params.set('sslmode', 'require');
} else if (type === 'sqlserver') {
params.set('encrypt', 'true');
params.set('TrustServerCertificate', mode === 'skip-verify' || mode === 'preferred' ? 'true' : 'false');
} else if (type === 'clickhouse') {
params.set('secure', 'true');
if (mode === 'skip-verify' || mode === 'preferred') {
params.set('skip_verify', 'true');
}
} else if (type === 'dameng') {
const certPath = String(values.sslCertPath || '').trim();
const keyPath = String(values.sslKeyPath || '').trim();
if (certPath) params.set('SSL_CERT_PATH', certPath);
if (keyPath) params.set('SSL_KEY_PATH', keyPath);
} else if (type === 'oracle') {
params.set('SSL', 'TRUE');
params.set('SSL VERIFY', mode === 'required' ? 'TRUE' : 'FALSE');
} else if (type === 'tdengine') {
params.set('protocol', 'wss');
if (mode === 'skip-verify' || mode === 'preferred') {
params.set('skip_verify', 'true');
}
}
} else if (supportsSSLForType(type)) {
if (type === 'postgres' || type === 'kingbase' || type === 'highgo' || type === 'vastbase') {
params.set('sslmode', 'disable');
} else if (type === 'sqlserver') {
params.set('encrypt', 'disable');
params.set('TrustServerCertificate', 'true');
} else if (type === 'tdengine') {
params.set('protocol', 'ws');
}
}
const query = params.toString();
return `${scheme}://${encodedAuth}${toAddress(host, port, defaultPort)}${dbPath}${query ? `?${query}` : ''}`;
};
const handleGenerateURI = () => {
@@ -838,6 +1037,10 @@ const ConnectionModal: React.FC<{
uri: config.uri || '',
includeDatabases: initialValues.includeDatabases,
includeRedisDatabases: initialValues.includeRedisDatabases,
useSSL: !!config.useSSL,
sslMode: config.sslMode || 'preferred',
sslCertPath: config.sslCertPath || '',
sslKeyPath: config.sslKeyPath || '',
useSSH: config.useSSH,
sshHost: config.ssh?.host,
sshPort: config.ssh?.port,
@@ -871,6 +1074,7 @@ const ConnectionModal: React.FC<{
mongoReplicaUser: config.mongoReplicaUser || '',
mongoReplicaPassword: config.mongoReplicaPassword || ''
});
setUseSSL(!!config.useSSL);
setUseSSH(config.useSSH || false);
setUseProxy(config.useProxy || false);
setDbType(configType);
@@ -882,6 +1086,7 @@ const ConnectionModal: React.FC<{
// Create mode: Start at step 1
setStep(1);
form.resetFields();
setUseSSL(false);
setUseSSH(false);
setUseProxy(false);
setDbType('mysql');
@@ -932,6 +1137,7 @@ const ConnectionModal: React.FC<{
setLoading(false);
form.resetFields();
setUseSSL(false);
setUseSSH(false);
setUseProxy(false);
setDbType('mysql');
@@ -1073,6 +1279,21 @@ const ConnectionModal: React.FC<{
const type = String(mergedValues.type || '').toLowerCase();
const defaultPort = getDefaultPortByType(type);
const isFileDbType = isFileDatabaseType(type);
const sslCapableType = supportsSSLForType(type);
const sslModeRaw = String(mergedValues.sslMode || 'preferred').trim().toLowerCase();
const sslMode: 'preferred' | 'required' | 'skip-verify' | 'disable' = sslModeRaw === 'required'
? 'required'
: sslModeRaw === 'skip-verify'
? 'skip-verify'
: sslModeRaw === 'disable'
? 'disable'
: 'preferred';
const effectiveUseSSL = sslCapableType && !!mergedValues.useSSL;
const sslCertPath = sslCapableType ? String(mergedValues.sslCertPath || '').trim() : '';
const sslKeyPath = sslCapableType ? String(mergedValues.sslKeyPath || '').trim() : '';
if (type === 'dameng' && effectiveUseSSL && (!sslCertPath || !sslKeyPath)) {
throw new Error('达梦启用 SSL 时必须填写证书路径与私钥路径');
}
let primaryHost = 'localhost';
let primaryPort = defaultPort;
@@ -1194,6 +1415,10 @@ const ConnectionModal: React.FC<{
password: keepPassword ? (mergedValues.password || "") : "",
savePassword: savePassword,
database: mergedValues.database || "",
useSSL: effectiveUseSSL,
sslMode: effectiveUseSSL ? sslMode : 'disable',
sslCertPath: sslCertPath,
sslKeyPath: sslKeyPath,
useSSH: !!mergedValues.useSSH,
ssh: sshConfig,
useProxy: effectiveUseProxy,
@@ -1233,6 +1458,7 @@ const ConnectionModal: React.FC<{
const defaultPort = getDefaultPortByType(type);
if (isFileDatabaseType(type)) {
setUseSSL(false);
setUseSSH(false);
setUseProxy(false);
form.setFieldsValue({
@@ -1241,6 +1467,10 @@ const ConnectionModal: React.FC<{
user: '',
password: '',
database: '',
useSSL: false,
sslMode: 'preferred',
sslCertPath: '',
sslKeyPath: '',
useSSH: false,
sshHost: '',
sshPort: 22,
@@ -1273,10 +1503,16 @@ const ConnectionModal: React.FC<{
});
} else if (type !== 'custom') {
const defaultUser = type === 'clickhouse' ? 'default' : 'root';
const sslCapableType = supportsSSLForType(type);
setUseSSL(false);
form.setFieldsValue({
user: defaultUser,
database: '',
port: defaultPort,
useSSL: sslCapableType ? false : undefined,
sslMode: sslCapableType ? 'preferred' : undefined,
sslCertPath: sslCapableType ? '' : undefined,
sslKeyPath: sslCapableType ? '' : undefined,
mysqlTopology: 'single',
redisTopology: 'single',
mongoTopology: 'single',
@@ -1420,6 +1656,10 @@ const ConnectionModal: React.FC<{
port: 3306,
database: '',
user: 'root',
useSSL: false,
sslMode: 'preferred',
sslCertPath: '',
sslKeyPath: '',
useSSH: false,
sshPort: 22,
useProxy: false,
@@ -1451,6 +1691,7 @@ const ConnectionModal: React.FC<{
if (changed.uri !== undefined || changed.type !== undefined) {
setUriFeedback(null);
}
if (changed.useSSL !== undefined) setUseSSL(changed.useSSL);
if (changed.useSSH !== undefined) setUseSSH(changed.useSSH);
if (changed.useProxy !== undefined) setUseProxy(changed.useProxy);
if (changed.proxyType !== undefined) {
@@ -1835,6 +2076,56 @@ const ConnectionModal: React.FC<{
{!isFileDb && (
<>
{isSSLType && (
<>
<Divider style={{ margin: '12px 0' }} />
<Form.Item name="useSSL" valuePropName="checked" style={{ marginBottom: 0 }}>
<Checkbox>使 SSL/TLS</Checkbox>
</Form.Item>
{useSSL && (
<div style={tunnelSectionStyle}>
<Form.Item
name="sslMode"
label="SSL 模式"
rules={[{ required: true, message: '请选择 SSL 模式' }]}
style={{ marginBottom: 8 }}
>
<Select
options={[
{ value: 'preferred', label: 'Preferred优先 SSL推荐' },
{ value: 'required', label: 'Required必须 SSL校验证书' },
{ value: 'skip-verify', label: 'Skip Verify必须 SSL跳过证书校验' },
]}
/>
</Form.Item>
{dbType === 'dameng' && (
<>
<Form.Item
name="sslCertPath"
label="客户端证书路径 (SSL_CERT_PATH)"
rules={[{ required: true, message: '达梦 SSL 需要证书路径' }]}
style={{ marginBottom: 8 }}
>
<Input placeholder="例如: C:\\certs\\client-cert.pem" />
</Form.Item>
<Form.Item
name="sslKeyPath"
label="客户端私钥路径 (SSL_KEY_PATH)"
rules={[{ required: true, message: '达梦 SSL 需要私钥路径' }]}
style={{ marginBottom: 8 }}
>
<Input placeholder="例如: C:\\certs\\client-key.pem" />
</Form.Item>
</>
)}
<Text type="secondary" style={{ fontSize: 12 }}>
{sslHintText}
</Text>
</div>
)}
</>
)}
<Divider style={{ margin: '12px 0' }} />
<Form.Item name="useSSH" valuePropName="checked" style={{ marginBottom: 0 }}>
<Checkbox>使 SSH (SSH Tunnel)</Checkbox>

View File

@@ -580,6 +580,7 @@ interface DataGridProps {
onToggleFilter?: () => void;
exportSqlWithFilter?: string;
onApplyFilter?: (conditions: GridFilterCondition[]) => void;
appliedFilterConditions?: FilterCondition[];
}
type GridFilterCondition = FilterCondition & {
@@ -599,7 +600,7 @@ type ColumnMeta = {
const DataGrid: React.FC<DataGridProps> = ({
data, columnNames, loading, tableName, exportScope = 'table', resultSql, dbName, connectionId, pkColumns = [], readOnly = false,
onReload, onSort, onPageChange, pagination, onRequestTotalCount, onCancelTotalCount, sortInfoExternal, showFilter, onToggleFilter, exportSqlWithFilter, onApplyFilter
onReload, onSort, onPageChange, pagination, onRequestTotalCount, onCancelTotalCount, sortInfoExternal, showFilter, onToggleFilter, exportSqlWithFilter, onApplyFilter, appliedFilterConditions
}) => {
const connections = useStore(state => state.connections);
const addSqlLog = useStore(state => state.addSqlLog);
@@ -1064,10 +1065,40 @@ const DataGrid: React.FC<DataGridProps> = ({
const [modifiedRows, setModifiedRows] = useState<Record<string, any>>({});
const [deletedRowKeys, setDeletedRowKeys] = useState<Set<string>>(new Set());
const normalizeFilterLogic = useCallback((logic: unknown): 'AND' | 'OR' => {
return String(logic || '').trim().toUpperCase() === 'OR' ? 'OR' : 'AND';
}, []);
const normalizeGridFilterConditions = useCallback((conditions?: FilterCondition[]): GridFilterCondition[] => {
if (!Array.isArray(conditions)) return [];
return conditions.map((cond, index) => {
const fallbackId = index + 1;
const nextId = Number.isFinite(Number(cond?.id)) ? Number(cond?.id) : fallbackId;
const op = String(cond?.op || '=');
const rawColumn = String(cond?.column || '');
return {
id: nextId,
enabled: cond?.enabled !== false,
logic: normalizeFilterLogic(cond?.logic),
column: rawColumn || (op === 'CUSTOM' ? '' : String(columnNames[0] || '')),
op,
value: String(cond?.value ?? ''),
value2: String(cond?.value2 ?? ''),
};
});
}, [columnNames, normalizeFilterLogic]);
// Filter State
const [filterConditions, setFilterConditions] = useState<GridFilterCondition[]>([]);
const [nextFilterId, setNextFilterId] = useState(1);
useEffect(() => {
const nextConditions = normalizeGridFilterConditions(appliedFilterConditions);
setFilterConditions(nextConditions);
const maxId = nextConditions.reduce((max, cond) => (cond.id > max ? cond.id : max), 0);
setNextFilterId(Math.max(1, maxId + 1));
}, [appliedFilterConditions, normalizeGridFilterConditions]);
const selectedRowKeysRef = useRef(selectedRowKeys);
const displayDataRef = useRef<any[]>([]);
@@ -2547,6 +2578,10 @@ const DataGrid: React.FC<DataGridProps> = ({
{ value: 'NOT_IN', label: '不在列表' },
{ value: 'CUSTOM', label: '[自定义]' },
]), []);
const filterLogicOptions = useMemo(() => ([
{ value: 'AND', label: '且 (AND)' },
{ value: 'OR', label: '或 (OR)' },
]), []);
const isNoValueOp = useCallback((op: string) => (
op === 'IS_NULL' || op === 'IS_NOT_NULL' || op === 'IS_EMPTY' || op === 'IS_NOT_EMPTY'
@@ -2555,7 +2590,18 @@ const DataGrid: React.FC<DataGridProps> = ({
const isListOp = useCallback((op: string) => op === 'IN' || op === 'NOT_IN', []);
const addFilter = () => {
setFilterConditions([...filterConditions, { id: nextFilterId, enabled: true, column: columnNames[0] || '', op: '=', value: '', value2: '' }]);
setFilterConditions([
...filterConditions,
{
id: nextFilterId,
enabled: true,
logic: 'AND',
column: columnNames[0] || '',
op: '=',
value: '',
value2: '',
}
]);
setNextFilterId(nextFilterId + 1);
};
const updateFilter = (id: number, field: keyof GridFilterCondition, val: string | boolean) => {
@@ -3066,7 +3112,7 @@ const DataGrid: React.FC<DataGridProps> = ({
background: 'transparent',
boxSizing: 'border-box',
}}>
{filterConditions.map(cond => (
{filterConditions.map((cond, condIndex) => (
<div key={cond.id} style={{ display: 'flex', gap: 8, marginBottom: 8, alignItems: 'flex-start', opacity: cond.enabled === false ? 0.58 : 1 }}>
<Checkbox
checked={cond.enabled !== false}
@@ -3075,13 +3121,33 @@ const DataGrid: React.FC<DataGridProps> = ({
>
</Checkbox>
<Select
style={{ width: 180 }}
value={cond.column}
onChange={v => updateFilter(cond.id, 'column', v)}
options={columnNames.map(c => ({ value: c, label: c }))}
disabled={cond.op === 'CUSTOM'}
/>
{condIndex === 0 ? (
<div style={{ width: 96, marginTop: 7, textAlign: 'center', fontSize: 12, color: '#8c8c8c' }}>
</div>
) : (
<Select
style={{ width: 96 }}
value={cond.logic === 'OR' ? 'OR' : 'AND'}
onChange={v => updateFilter(cond.id, 'logic', v)}
options={filterLogicOptions as any}
/>
)}
<Select
style={{ width: 180 }}
value={cond.column}
onChange={v => updateFilter(cond.id, 'column', v)}
options={columnNames.map(c => ({ value: c, label: c }))}
showSearch
optionFilterProp="label"
filterOption={(input, option) =>
String(option?.label ?? '')
.toLowerCase()
.includes(String(input || '').trim().toLowerCase())
}
placeholder="搜索字段名"
disabled={cond.op === 'CUSTOM'}
/>
<Select
style={{ width: 140 }}
value={cond.op}

View File

@@ -5,6 +5,7 @@ import { useStore } from '../store';
import { DBQuery, DBGetColumns } from '../../wailsjs/go/app/App';
import DataGrid, { GONAVI_ROW_KEY } from './DataGrid';
import { buildOrderBySQL, buildWhereSQL, quoteIdentPart, quoteQualifiedIdent, withSortBufferTuningSQL, type FilterCondition } from '../utils/sql';
import { buildMongoCountCommand, buildMongoFilter, buildMongoFindCommand, buildMongoSort } from '../utils/mongodb';
import { getDataSourceCapabilities } from '../utils/dataSourceCapabilities';
type ViewerPaginationState = {
@@ -151,6 +152,37 @@ const reverseOrderBySQL = (orderBySQL: string): string => {
return ` ORDER BY ${parts.join(', ')}`;
};
type ViewerFilterSnapshot = {
showFilter: boolean;
conditions: FilterCondition[];
};
const viewerFilterSnapshotsByTab = new Map<string, ViewerFilterSnapshot>();
const normalizeViewerFilterConditions = (conditions: FilterCondition[] | undefined): FilterCondition[] => {
if (!Array.isArray(conditions)) return [];
return conditions.map((cond) => ({
id: Number.isFinite(Number(cond?.id)) ? Number(cond?.id) : undefined,
enabled: cond?.enabled !== false,
logic: String(cond?.logic || '').trim().toUpperCase() === 'OR' ? 'OR' : 'AND',
column: String(cond?.column || ''),
op: String(cond?.op || '='),
value: String(cond?.value ?? ''),
value2: String(cond?.value2 ?? ''),
}));
};
const getViewerFilterSnapshot = (tabId: string): ViewerFilterSnapshot => {
const cached = viewerFilterSnapshotsByTab.get(String(tabId || '').trim());
if (!cached) {
return { showFilter: false, conditions: [] };
}
return {
showFilter: cached.showFilter === true,
conditions: normalizeViewerFilterConditions(cached.conditions),
};
};
const DataViewer: React.FC<{ tab: TabData }> = ({ tab }) => {
const [data, setData] = useState<any[]>([]);
const [columnNames, setColumnNames] = useState<string[]>([]);
@@ -185,14 +217,27 @@ const DataViewer: React.FC<{ tab: TabData }> = ({ tab }) => {
const [sortInfo, setSortInfo] = useState<{ columnKey: string, order: string } | null>(null);
const [showFilter, setShowFilter] = useState(false);
const [filterConditions, setFilterConditions] = useState<FilterCondition[]>([]);
const [showFilter, setShowFilter] = useState<boolean>(() => getViewerFilterSnapshot(tab.id).showFilter);
const [filterConditions, setFilterConditions] = useState<FilterCondition[]>(() => getViewerFilterSnapshot(tab.id).conditions);
const duckdbSafeSelectCacheRef = useRef<Record<string, string>>({});
const currentConnConfig = connections.find(c => c.id === tab.connectionId)?.config;
const currentConnCaps = getDataSourceCapabilities(currentConnConfig);
const currentConnType = currentConnCaps.type;
const forceReadOnly = currentConnCaps.forceReadOnlyQueryResult;
useEffect(() => {
const snapshot = getViewerFilterSnapshot(tab.id);
setShowFilter(snapshot.showFilter);
setFilterConditions(snapshot.conditions);
}, [tab.id]);
useEffect(() => {
viewerFilterSnapshotsByTab.set(tab.id, {
showFilter,
conditions: normalizeViewerFilterConditions(filterConditions),
});
}, [tab.id, showFilter, filterConditions]);
useEffect(() => {
setPkColumns([]);
pkKeyRef.current = '';
@@ -315,42 +360,67 @@ const DataViewer: React.FC<{ tab: TabData }> = ({ tab }) => {
const dbName = tab.dbName || '';
const tableName = tab.tableName || '';
const isMongoDB = dbTypeLower === 'mongodb';
let mongoFilter: Record<string, unknown> | undefined;
if (isMongoDB) {
try {
mongoFilter = buildMongoFilter(filterConditions);
} catch (e: any) {
message.error(`Mongo 筛选条件无效:${String(e?.message || e || '解析失败')}`);
if (fetchSeqRef.current === seq) setLoading(false);
return;
}
}
const whereSQL = buildWhereSQL(dbType, filterConditions);
const countSql = `SELECT COUNT(*) as total FROM ${quoteQualifiedIdent(dbType, tableName)} ${whereSQL}`;
const baseSql = `SELECT * FROM ${quoteQualifiedIdent(dbType, tableName)} ${whereSQL}`;
const orderBySQL = buildOrderBySQL(dbType, sortInfo, pkColumns);
let sql = `${baseSql}${orderBySQL}`;
const whereSQL = isMongoDB
? JSON.stringify(mongoFilter || {})
: buildWhereSQL(dbType, filterConditions);
const countSql = isMongoDB
? buildMongoCountCommand(tableName, mongoFilter || {})
: `SELECT COUNT(*) as total FROM ${quoteQualifiedIdent(dbType, tableName)} ${whereSQL}`;
const orderBySQL = isMongoDB ? '' : buildOrderBySQL(dbType, sortInfo, pkColumns);
const totalRows = Number(pagination.total);
const hasFiniteTotal = Number.isFinite(totalRows) && totalRows >= 0;
const totalKnown = pagination.totalKnown && hasFiniteTotal;
const totalPages = hasFiniteTotal ? Math.max(1, Math.ceil(totalRows / size)) : 0;
const currentPage = totalPages > 0 ? Math.min(Math.max(1, page), totalPages) : Math.max(1, page);
const offset = (currentPage - 1) * size;
const isClickHouse = dbTypeLower === 'clickhouse';
const isClickHouse = !isMongoDB && dbTypeLower === 'clickhouse';
const reverseOrderSQL = isClickHouse ? reverseOrderBySQL(orderBySQL) : '';
let useClickHouseReversePagination = false;
let clickHouseReverseLimit = 0;
let clickHouseReverseHasMore = false;
// ClickHouse 深分页在超大 OFFSET 下容易超时。对于总数已知且存在 ORDER BY 的场景,
// 当“尾部偏移”小于“头部偏移”时,改为反向 ORDER BY + 小 OFFSET并在前端翻转结果。
if (isClickHouse && totalKnown && offset > 0 && reverseOrderSQL) {
const pageRowCount = Math.max(0, Math.min(size, totalRows - offset));
if (pageRowCount > 0) {
const tailOffset = Math.max(0, totalRows - (offset + pageRowCount));
if (tailOffset < offset) {
sql = `${baseSql}${reverseOrderSQL} LIMIT ${pageRowCount} OFFSET ${tailOffset}`;
useClickHouseReversePagination = true;
clickHouseReverseLimit = pageRowCount;
clickHouseReverseHasMore = currentPage < totalPages;
let sql = '';
if (isMongoDB) {
const mongoSort = buildMongoSort(sortInfo, pkColumns);
sql = buildMongoFindCommand({
collection: tableName,
filter: mongoFilter || {},
sort: mongoSort,
limit: size + 1,
skip: offset,
});
} else {
const baseSql = `SELECT * FROM ${quoteQualifiedIdent(dbType, tableName)} ${whereSQL}`;
sql = `${baseSql}${orderBySQL}`;
// ClickHouse 深分页在超大 OFFSET 下容易超时。对于总数已知且存在 ORDER BY 的场景,
// 当“尾部偏移”小于“头部偏移”时,改为反向 ORDER BY + 小 OFFSET并在前端翻转结果。
if (isClickHouse && totalKnown && offset > 0 && reverseOrderSQL) {
const pageRowCount = Math.max(0, Math.min(size, totalRows - offset));
if (pageRowCount > 0) {
const tailOffset = Math.max(0, totalRows - (offset + pageRowCount));
if (tailOffset < offset) {
sql = `${baseSql}${reverseOrderSQL} LIMIT ${pageRowCount} OFFSET ${tailOffset}`;
useClickHouseReversePagination = true;
clickHouseReverseLimit = pageRowCount;
clickHouseReverseHasMore = currentPage < totalPages;
}
}
}
}
if (!useClickHouseReversePagination) {
// 大表性能:打开表不阻塞在 COUNT(*),先通过多取 1 条判断是否还有下一页;总数在后台统计并异步回填。
sql += ` LIMIT ${size + 1} OFFSET ${offset}`;
if (!useClickHouseReversePagination) {
// 大表性能:打开表不阻塞在 COUNT(*),先通过多取 1 条判断是否还有下一页;总数在后台统计并异步回填。
sql += ` LIMIT ${size + 1} OFFSET ${offset}`;
}
}
const requestStartTime = Date.now();
@@ -718,6 +788,7 @@ const DataViewer: React.FC<{ tab: TabData }> = ({ tab }) => {
showFilter={showFilter}
onToggleFilter={handleToggleFilter}
onApplyFilter={handleApplyFilter}
appliedFilterConditions={filterConditions}
readOnly={forceReadOnly}
sortInfoExternal={sortInfo}
exportSqlWithFilter={exportSqlWithFilter || undefined}

View File

@@ -9,6 +9,8 @@ import { useStore } from '../store';
import { DBQuery, DBQueryWithCancel, DBGetTables, DBGetAllColumns, DBGetDatabases, DBGetColumns, CancelQuery, GenerateQueryID } from '../../wailsjs/go/app/App';
import DataGrid, { GONAVI_ROW_KEY } from './DataGrid';
import { getDataSourceCapabilities } from '../utils/dataSourceCapabilities';
import { convertMongoShellToJsonCommand } from '../utils/mongodb';
import { getShortcutDisplay, isEditableElement, isShortcutMatch } from '../utils/shortcuts';
const QueryEditor: React.FC<{ tab: TabData }> = ({ tab }) => {
const [query, setQuery] = useState(tab.query || 'SELECT * FROM ');
@@ -68,6 +70,8 @@ const QueryEditor: React.FC<{ tab: TabData }> = ({ tab }) => {
const setSqlFormatOptions = useStore(state => state.setSqlFormatOptions);
const queryOptions = useStore(state => state.queryOptions);
const setQueryOptions = useStore(state => state.setQueryOptions);
const shortcutOptions = useStore(state => state.shortcutOptions);
const activeTabId = useStore(state => state.activeTabId);
useEffect(() => {
currentConnectionIdRef.current = currentConnectionId;
@@ -268,6 +272,19 @@ const QueryEditor: React.FC<{ tab: TabData }> = ({ tab }) => {
return parts[parts.length - 1] || raw;
};
const splitSchemaAndTable = (qualified: string): { schema: string; table: string } => {
const raw = normalizeQualifiedName(qualified);
if (!raw) return { schema: '', table: '' };
const parts = raw.split('.').filter(Boolean);
if (parts.length >= 2) {
return {
schema: parts[parts.length - 2] || '',
table: parts[parts.length - 1] || '',
};
}
return { schema: '', table: parts[0] || '' };
};
const buildConnConfig = () => {
const connId = currentConnectionIdRef.current;
const conn = connectionsRef.current.find(c => c.id === connId);
@@ -340,13 +357,14 @@ const QueryEditor: React.FC<{ tab: TabData }> = ({ tab }) => {
if (qualifierMatch) {
const qualifier = stripQuotes(qualifierMatch[1]);
const prefix = (qualifierMatch[2] || '').toLowerCase();
const qualifierLower = qualifier.toLowerCase();
// 首先检查 qualifier 是否是数据库名(跨库表提示)
const visibleDbs = visibleDbsRef.current;
if (visibleDbs.some(db => db.toLowerCase() === qualifier.toLowerCase())) {
if (visibleDbs.some(db => db.toLowerCase() === qualifierLower)) {
// qualifier 是数据库名,提示该库的表
const tables = tablesRef.current.filter(t =>
(t.dbName || '').toLowerCase() === qualifier.toLowerCase()
(t.dbName || '').toLowerCase() === qualifierLower
);
const filtered = prefix
? tables.filter(t => (t.tableName || '').toLowerCase().startsWith(prefix))
@@ -363,6 +381,34 @@ const QueryEditor: React.FC<{ tab: TabData }> = ({ tab }) => {
return { suggestions };
}
// qualifier 是 schema如 dbo/public仅补全表名避免输入 dbo. 后再补成 dbo.dbo.table
const schemaTables = tablesRef.current
.map(t => {
const parsed = splitSchemaAndTable(t.tableName || '');
return {
dbName: t.dbName || '',
schema: parsed.schema,
table: parsed.table,
};
})
.filter(t => t.schema.toLowerCase() === qualifierLower && !!t.table);
if (schemaTables.length > 0) {
const filtered = prefix
? schemaTables.filter(t => t.table.toLowerCase().startsWith(prefix))
: schemaTables;
const suggestions = filtered.map(t => ({
label: t.table,
kind: monaco.languages.CompletionItemKind.Class,
insertText: t.table,
detail: `Table (${t.dbName}${t.schema ? '.' + t.schema : ''})`,
range,
sortText: '0' + t.table
}));
return { suggestions };
}
// 否则检查是否是表别名或表名,提示列
const reserved = new Set([
'where', 'on', 'group', 'order', 'limit', 'having',
@@ -531,6 +577,12 @@ const QueryEditor: React.FC<{ tab: TabData }> = ({ tab }) => {
icon: sqlFormatOptions.keywordCase === 'lower' ? '✓' : undefined,
onClick: () => setSqlFormatOptions({ keywordCase: 'lower' })
},
{ type: 'divider' },
{
key: 'shortcut-settings',
label: '快捷键管理...',
onClick: () => window.dispatchEvent(new CustomEvent('gonavi:open-shortcut-settings')),
},
];
const splitSQLStatements = (sql: string): string[] => {
@@ -1035,7 +1087,15 @@ const QueryEditor: React.FC<{ tab: TabData }> = ({ tab }) => {
try {
const rawSQL = getSelectedSQL() || query;
const statements = splitSQLStatements(rawSQL);
const dbType = String((config as any).type || 'mysql');
const normalizedDbType = dbType.trim().toLowerCase();
const normalizedRawSQL = String(rawSQL || '').replace(//g, ';');
const splitInput = normalizedDbType === 'mongodb'
? normalizedRawSQL
.replace(/^\s*\/\/.*$/gm, '')
.replace(/^\s*#.*$/gm, '')
: normalizedRawSQL;
const statements = splitSQLStatements(splitInput);
if (statements.length === 0) {
message.info('没有可执行的 SQL。');
setResultSets([]);
@@ -1045,7 +1105,6 @@ const QueryEditor: React.FC<{ tab: TabData }> = ({ tab }) => {
const nextResultSets: ResultSet[] = [];
const maxRows = Number(queryOptions?.maxRows) || 0;
const dbType = String((config as any).type || 'mysql');
const forceReadOnlyResult = connCaps.forceReadOnlyQueryResult;
const wantsLimitProbe = Number.isFinite(maxRows) && maxRows > 0;
const probeLimit = wantsLimitProbe ? (maxRows + 1) : 0;
@@ -1059,9 +1118,24 @@ const QueryEditor: React.FC<{ tab: TabData }> = ({ tab }) => {
const limitApplied = shouldAutoLimit && wantsLimitProbe;
const limited = limitApplied ? applyAutoLimit(rawStatement, dbType, probeLimit) : { sql: rawStatement, applied: false, maxRows: probeLimit };
const executedSql = limited.sql;
let executedSql = limited.sql;
if (String(dbType || '').trim().toLowerCase() === 'mongodb') {
const shellConvert = convertMongoShellToJsonCommand(executedSql);
if (shellConvert.recognized) {
if (shellConvert.error) {
const prefix = statements.length > 1 ? `${idx + 1} 条语句执行失败:` : '';
message.error(prefix + shellConvert.error);
setResultSets([]);
setActiveResultKey('');
return;
}
if (shellConvert.command) {
executedSql = shellConvert.command;
}
}
}
const startTime = Date.now();
// Generate query ID for cancellation using backend UUID with fallback
let queryId: string;
try {
@@ -1071,7 +1145,7 @@ const QueryEditor: React.FC<{ tab: TabData }> = ({ tab }) => {
queryId = 'query-' + uuidv4();
}
setQueryId(queryId);
const res = await DBQueryWithCancel(config as any, currentDb, executedSql, queryId);
const duration = Date.now() - startTime;
@@ -1089,19 +1163,19 @@ const QueryEditor: React.FC<{ tab: TabData }> = ({ tab }) => {
if (!res.success) {
// 检查是否为查询取消错误
const errorMsg = res.message.toLowerCase();
const isCancelledError = errorMsg.includes('context canceled') ||
const isCancelledError = errorMsg.includes('context canceled') ||
errorMsg.includes('查询已取消') ||
errorMsg.includes('canceled') ||
errorMsg.includes('cancelled') ||
errorMsg.includes('statement canceled') ||
errorMsg.includes('sql: statement canceled');
// 确保不是超时错误
const isTimeoutError = errorMsg.includes('context deadline exceeded') ||
errorMsg.includes('timeout') ||
errorMsg.includes('超时') ||
errorMsg.includes('deadline exceeded');
if (isCancelledError && !isTimeoutError) {
// 查询已被用户取消,不显示错误消息,清理状态
setResultSets([]);
@@ -1112,7 +1186,7 @@ const QueryEditor: React.FC<{ tab: TabData }> = ({ tab }) => {
}
return;
}
const prefix = statements.length > 1 ? `${idx + 1} 条语句执行失败:` : '';
message.error(prefix + res.message);
setResultSets([]);
@@ -1245,6 +1319,48 @@ const QueryEditor: React.FC<{ tab: TabData }> = ({ tab }) => {
}
};
useEffect(() => {
const binding = shortcutOptions.runQuery;
if (!binding?.enabled || !binding.combo) {
return;
}
const handleRunShortcut = (event: KeyboardEvent) => {
if (activeTabId !== tab.id) {
return;
}
if (!isShortcutMatch(event, binding.combo)) {
return;
}
const editorHasFocus = !!editorRef.current?.hasTextFocus?.();
if (!editorHasFocus && !isEditableElement(event.target)) {
return;
}
event.preventDefault();
event.stopPropagation();
void handleRun();
};
window.addEventListener('keydown', handleRunShortcut);
return () => {
window.removeEventListener('keydown', handleRunShortcut);
};
}, [activeTabId, tab.id, shortcutOptions.runQuery, handleRun]);
useEffect(() => {
const handleRunActiveQuery = () => {
if (activeTabId !== tab.id) {
return;
}
void handleRun();
};
window.addEventListener('gonavi:run-active-query', handleRunActiveQuery as EventListener);
return () => {
window.removeEventListener('gonavi:run-active-query', handleRunActiveQuery as EventListener);
};
}, [activeTabId, tab.id, handleRun]);
const handleSave = async () => {
try {
const values = await saveForm.validateFields();
@@ -1357,9 +1473,17 @@ const QueryEditor: React.FC<{ tab: TabData }> = ({ tab }) => {
/>
</Tooltip>
<Button.Group>
<Button type="primary" icon={<PlayCircleOutlined />} onClick={handleRun} loading={loading}>
</Button>
<Tooltip
title={
shortcutOptions.runQuery?.enabled && shortcutOptions.runQuery?.combo
? `运行(${getShortcutDisplay(shortcutOptions.runQuery.combo)}`
: '运行'
}
>
<Button type="primary" icon={<PlayCircleOutlined />} onClick={handleRun} loading={loading}>
</Button>
</Tooltip>
{loading && (
<Button type="primary" danger icon={<StopOutlined />} onClick={handleCancel}>

View File

@@ -1,5 +1,5 @@
import React, { useEffect, useState, useMemo, useRef } from 'react';
import { Tree, message, Dropdown, MenuProps, Input, Button, Modal, Form, Badge, Checkbox, Space, Select } from 'antd';
import { Tree, message, Dropdown, MenuProps, Input, Button, Modal, Form, Badge, Checkbox, Space, Select, Popover, Tooltip } from 'antd';
import {
DatabaseOutlined,
TableOutlined,
@@ -50,6 +50,7 @@ type BatchTableExportMode = 'schema' | 'backup' | 'dataOnly';
type BatchObjectType = 'table' | 'view';
type BatchObjectFilterType = 'all' | BatchObjectType;
type BatchSelectionScope = 'filtered' | 'all';
type SearchScope = 'smart' | 'object' | 'database' | 'host' | 'tag';
interface BatchObjectItem {
title: string;
@@ -59,9 +60,23 @@ interface BatchObjectItem {
dataRef: any;
}
const SEARCH_SCOPE_OPTIONS: Array<{ value: SearchScope; label: string }> = [
{ value: 'smart', label: '智能' },
{ value: 'object', label: '表对象' },
{ value: 'database', label: '库' },
{ value: 'host', label: 'Host' },
{ value: 'tag', label: '标签' },
];
const SEARCH_SCOPE_LABEL_MAP: Record<SearchScope, string> = SEARCH_SCOPE_OPTIONS.reduce((acc, option) => {
acc[option.value] = option.label;
return acc;
}, {} as Record<SearchScope, string>);
const Sidebar: React.FC<{ onEditConnection?: (conn: SavedConnection) => void }> = ({ onEditConnection }) => {
const connections = useStore(state => state.connections);
const savedQueries = useStore(state => state.savedQueries);
const addConnection = useStore(state => state.addConnection);
const addTab = useStore(state => state.addTab);
const setActiveContext = useStore(state => state.setActiveContext);
const removeConnection = useStore(state => state.removeConnection);
@@ -94,6 +109,9 @@ const Sidebar: React.FC<{ onEditConnection?: (conn: SavedConnection) => void }>
};
const bgMain = getBg('#141414');
const [searchValue, setSearchValue] = useState('');
const [searchScopes, setSearchScopes] = useState<SearchScope[]>(['smart']);
const [isSearchScopePopoverOpen, setIsSearchScopePopoverOpen] = useState(false);
const searchInputRef = useRef<any>(null);
const [expandedKeys, setExpandedKeys] = useState<React.Key[]>([]);
const [autoExpandParent, setAutoExpandParent] = useState(true);
const [loadedKeys, setLoadedKeys] = useState<React.Key[]>([]);
@@ -116,6 +134,21 @@ const Sidebar: React.FC<{ onEditConnection?: (conn: SavedConnection) => void }>
resizeObserver.observe(treeContainerRef.current);
return () => resizeObserver.disconnect();
}, []);
useEffect(() => {
const handleFocusSidebarSearch = () => {
const inputEl = searchInputRef.current?.input as HTMLInputElement | undefined;
if (!inputEl) {
return;
}
inputEl.focus();
inputEl.select();
};
window.addEventListener('gonavi:focus-sidebar-search', handleFocusSidebarSearch as EventListener);
return () => {
window.removeEventListener('gonavi:focus-sidebar-search', handleFocusSidebarSearch as EventListener);
};
}, []);
// Connection Status State: key -> 'success' | 'error'
const [connectionStates, setConnectionStates] = useState<Record<string, 'success' | 'error'>>({});
@@ -219,7 +252,7 @@ const Sidebar: React.FC<{ onEditConnection?: (conn: SavedConnection) => void }>
useEffect(() => {
setTreeData((prev) => {
const prevMap = new Map<string, TreeNode>();
// We need to recursively extract connections from old tag structures
// so if a user expands a connection that was tagged, the state remains
const recurseCollect = (nodes: TreeNode[]) => {
@@ -271,6 +304,118 @@ const Sidebar: React.FC<{ onEditConnection?: (conn: SavedConnection) => void }>
});
}, [connections, connectionTags]);
const buildDuplicateConnectionName = (rawName: string): string => {
const baseName = String(rawName || '').trim() || '连接';
const suffix = ' - 副本';
const usedNames = new Set(connections.map(conn => String(conn.name || '').trim()));
let candidate = `${baseName}${suffix}`;
let counter = 2;
while (usedNames.has(candidate)) {
candidate = `${baseName}${suffix} ${counter}`;
counter += 1;
}
return candidate;
};
const cloneConnectionConfig = (config: SavedConnection['config']): SavedConnection['config'] => {
const raw: any = config || {};
let cloned: any = {};
try {
cloned = typeof structuredClone === 'function'
? structuredClone(raw)
: JSON.parse(JSON.stringify(raw));
} catch {
cloned = { ...raw };
}
const readString = (...values: unknown[]): string => {
for (const value of values) {
if (typeof value === 'string') {
return value;
}
}
return '';
};
const readBool = (fallback: boolean, ...values: unknown[]): boolean => {
for (const value of values) {
if (typeof value === 'boolean') {
return value;
}
}
return fallback;
};
const readNumber = (fallback: number, ...values: unknown[]): number => {
for (const value of values) {
const num = Number(value);
if (Number.isFinite(num)) {
return num;
}
}
return fallback;
};
const rawSSH = (cloned.ssh ?? cloned.SSH ?? {}) as Record<string, unknown>;
const normalizedSSH = {
host: readString(rawSSH.host, rawSSH.Host, cloned.sshHost, cloned.SSHHost),
port: readNumber(22, rawSSH.port, rawSSH.Port, cloned.sshPort, cloned.SSHPort),
user: readString(rawSSH.user, rawSSH.User, cloned.sshUser, cloned.SSHUser),
password: readString(rawSSH.password, rawSSH.Password, cloned.sshPassword, cloned.SSHPassword),
keyPath: readString(rawSSH.keyPath, rawSSH.KeyPath, cloned.sshKeyPath, cloned.SSHKeyPath),
};
const hasSSHDetail = Boolean(
normalizedSSH.host
|| normalizedSSH.user
|| normalizedSSH.password
|| normalizedSSH.keyPath
);
const rawProxy = (cloned.proxy ?? cloned.Proxy ?? {}) as Record<string, unknown>;
const proxyTypeRaw = readString(rawProxy.type, rawProxy.Type, cloned.proxyType, cloned.ProxyType).toLowerCase();
const proxyType: 'socks5' | 'http' = proxyTypeRaw === 'http' ? 'http' : 'socks5';
const normalizedProxy = {
type: proxyType,
host: readString(rawProxy.host, rawProxy.Host, cloned.proxyHost, cloned.ProxyHost),
port: readNumber(proxyType === 'http' ? 8080 : 1080, rawProxy.port, rawProxy.Port, cloned.proxyPort, cloned.ProxyPort),
user: readString(rawProxy.user, rawProxy.User, cloned.proxyUser, cloned.ProxyUser),
password: readString(rawProxy.password, rawProxy.Password, cloned.proxyPassword, cloned.ProxyPassword),
};
const hasProxyDetail = Boolean(normalizedProxy.host || normalizedProxy.user || normalizedProxy.password);
const rawHosts = Array.isArray(cloned.hosts)
? cloned.hosts
: (Array.isArray(cloned.Hosts) ? cloned.Hosts : []);
const normalizedHosts = rawHosts
.map((entry: unknown) => String(entry || '').trim())
.filter((entry: string) => !!entry);
return {
...(cloned as SavedConnection['config']),
useSSH: readBool(hasSSHDetail, cloned.useSSH, cloned.UseSSH),
ssh: normalizedSSH,
useProxy: readBool(hasProxyDetail, cloned.useProxy, cloned.UseProxy),
proxy: normalizedProxy,
hosts: normalizedHosts,
timeout: readNumber(30, cloned.timeout, cloned.Timeout),
};
};
const handleDuplicateConnection = (conn: SavedConnection) => {
if (!conn) return;
const duplicatedConnection: SavedConnection = {
...conn,
id: `${Date.now()}-${Math.random().toString(36).slice(2, 8)}`,
name: buildDuplicateConnectionName(conn.name),
config: cloneConnectionConfig(conn.config),
includeDatabases: conn.includeDatabases ? [...conn.includeDatabases] : undefined,
includeRedisDatabases: conn.includeRedisDatabases ? [...conn.includeRedisDatabases] : undefined,
};
addConnection(duplicatedConnection);
message.success(`已复制连接: ${duplicatedConnection.name}`);
};
const updateTreeData = (list: TreeNode[], key: React.Key, children: TreeNode[] | undefined): TreeNode[] => {
return list.map(node => {
if (node.key === key) {
@@ -749,7 +894,8 @@ const Sidebar: React.FC<{ onEditConnection?: (conn: SavedConnection) => void }>
const res = await (window as any).go.app.App.RedisGetDatabases(config);
if (res.success) {
setConnectionStates(prev => ({ ...prev, [conn.id]: 'success' }));
let dbs = (res.data as any[]).map((db: any) => ({
const redisRows: any[] = Array.isArray(res.data) ? res.data : [];
let dbs = redisRows.map((db: any) => ({
title: `db${db.index}${db.keys > 0 ? ` (${db.keys})` : ''}`,
key: `${conn.id}-db${db.index}`,
icon: <DatabaseOutlined style={{ color: '#DC382D' }} />,
@@ -780,7 +926,8 @@ const Sidebar: React.FC<{ onEditConnection?: (conn: SavedConnection) => void }>
const res = await DBGetDatabases(config as any);
if (res.success) {
setConnectionStates(prev => ({ ...prev, [conn.id]: 'success' }));
let dbs = (res.data as any[]).map((row: any) => ({
const dbRows: any[] = Array.isArray(res.data) ? res.data : [];
let dbs = dbRows.map((row: any) => ({
title: row.Database || row.database,
key: `${conn.id}-${row.Database || row.database}`,
icon: <DatabaseOutlined />,
@@ -795,10 +942,13 @@ const Sidebar: React.FC<{ onEditConnection?: (conn: SavedConnection) => void }>
}
setTreeData(origin => updateTreeData(origin, node.key, dbs));
} else {
setConnectionStates(prev => ({ ...prev, [conn.id]: 'error' }));
message.error({ content: res.message, key: `conn-${conn.id}-dbs` });
}
} else {
setConnectionStates(prev => ({ ...prev, [conn.id]: 'error' }));
message.error({ content: res.message, key: `conn-${conn.id}-dbs` });
}
} catch (e: any) {
setConnectionStates(prev => ({ ...prev, [conn.id]: 'error' }));
message.error({ content: '连接失败: ' + (e?.message || String(e)), key: `conn-${conn.id}-dbs` });
} finally {
loadingNodesRef.current.delete(loadKey);
}
@@ -843,7 +993,8 @@ const Sidebar: React.FC<{ onEditConnection?: (conn: SavedConnection) => void }>
if (res.success) {
setConnectionStates(prev => ({ ...prev, [key as string]: 'success' }));
const tableEntries = (res.data as any[]).map((row: any) => {
const tableRows: any[] = Array.isArray(res.data) ? res.data : [];
const tableEntries = tableRows.map((row: any) => {
const tableName = Object.values(row)[0] as string;
const parsed = splitQualifiedName(tableName);
return {
@@ -859,7 +1010,11 @@ const Sidebar: React.FC<{ onEditConnection?: (conn: SavedConnection) => void }>
loadFunctions(conn, conn.dbName),
]);
const viewEntries = viewsResult.views.map((viewName) => {
const viewRows: string[] = Array.isArray(viewsResult.views) ? viewsResult.views : [];
const triggerRows: any[] = Array.isArray(triggersResult.triggers) ? triggersResult.triggers : [];
const routineRows: any[] = Array.isArray(routinesResult.routines) ? routinesResult.routines : [];
const viewEntries = viewRows.map((viewName: string) => {
const parsed = splitQualifiedName(viewName);
return {
viewName,
@@ -873,7 +1028,7 @@ const Sidebar: React.FC<{ onEditConnection?: (conn: SavedConnection) => void }>
const triggerSeen = new Set<string>();
const metadataDialect = getMetadataDialect(conn as SavedConnection);
triggersResult.triggers.forEach((trigger) => {
triggerRows.forEach((trigger: any) => {
const triggerParsed = splitQualifiedName(trigger.triggerName);
const tableParsed = splitQualifiedName(trigger.tableName);
const schemaName = tableParsed.schemaName || triggerParsed.schemaName || String(conn.dbName || '').trim();
@@ -898,7 +1053,7 @@ const Sidebar: React.FC<{ onEditConnection?: (conn: SavedConnection) => void }>
return deduped;
})();
const routineEntries = routinesResult.routines.map((routine) => {
const routineEntries = routineRows.map((routine: any) => {
const parsed = splitQualifiedName(routine.routineName);
const typeLabel = routine.routineType === 'PROCEDURE' ? 'P' : 'F';
return {
@@ -1080,6 +1235,9 @@ const Sidebar: React.FC<{ onEditConnection?: (conn: SavedConnection) => void }>
setConnectionStates(prev => ({ ...prev, [key as string]: 'error' }));
message.error({ content: res.message, key: `db-${key}-tables` });
}
} catch (e: any) {
setConnectionStates(prev => ({ ...prev, [key as string]: 'error' }));
message.error({ content: '加载表失败: ' + (e?.message || String(e)), key: `db-${key}-tables` });
} finally {
loadingNodesRef.current.delete(loadKey);
}
@@ -1425,7 +1583,8 @@ const Sidebar: React.FC<{ onEditConnection?: (conn: SavedConnection) => void }>
const res = await DBGetDatabases(config as any);
if (res.success) {
let dbs = (res.data as any[]).map((row: any) => {
const dbRows: any[] = Array.isArray(res.data) ? res.data : [];
let dbs = dbRows.map((row: any) => {
const dbName = row.Database || row.database;
return {
title: dbName,
@@ -1466,9 +1625,11 @@ const Sidebar: React.FC<{ onEditConnection?: (conn: SavedConnection) => void }>
return;
}
const viewSet = new Set(viewResult.views.map(view => view.toLowerCase()));
const tableRows: any[] = Array.isArray(res.data) ? res.data : [];
const viewRows: string[] = Array.isArray(viewResult.views) ? viewResult.views : [];
const viewSet = new Set(viewRows.map((view: string) => view.toLowerCase()));
const tableObjects: BatchObjectItem[] = (res.data as any[])
const tableObjects: BatchObjectItem[] = tableRows
.map((row: any) => Object.values(row)[0] as string)
.filter((tableName: string) => !viewSet.has(tableName.toLowerCase()))
.map((tableName: string) => ({
@@ -1479,7 +1640,7 @@ const Sidebar: React.FC<{ onEditConnection?: (conn: SavedConnection) => void }>
dataRef: { ...conn, tableName, dbName, objectType: 'table' },
}));
const viewObjects: BatchObjectItem[] = viewResult.views.map((viewName: string) => ({
const viewObjects: BatchObjectItem[] = viewRows.map((viewName: string) => ({
title: getSidebarTableDisplayName(conn, viewName),
key: `${conn.id}-${dbName}-view-${viewName}`,
objectName: viewName,
@@ -1645,7 +1806,8 @@ const Sidebar: React.FC<{ onEditConnection?: (conn: SavedConnection) => void }>
const res = await DBGetDatabases(config as any);
if (res.success) {
let dbs = (res.data as any[]).map((row: any) => {
const dbRows: any[] = Array.isArray(res.data) ? res.data : [];
let dbs = dbRows.map((row: any) => {
const dbName = row.Database || row.database;
return {
title: dbName,
@@ -2232,28 +2394,205 @@ const Sidebar: React.FC<{ onEditConnection?: (conn: SavedConnection) => void }>
setSearchValue(value);
};
const loop = (data: TreeNode[]): TreeNode[] => {
const result: TreeNode[] = [];
data.forEach(item => {
const match = item.title.toLowerCase().indexOf(searchValue.toLowerCase()) > -1;
if (item.children) {
const filteredChildren = loop(item.children);
if (filteredChildren.length > 0 || match) {
result.push({ ...item, children: filteredChildren });
}
const toggleSearchScope = (scope: SearchScope) => {
setSearchScopes((prev) => {
if (scope === 'smart') {
return ['smart'];
}
const withoutSmart = prev.filter((item) => item !== 'smart');
if (withoutSmart.includes(scope)) {
const next = withoutSmart.filter((item) => item !== scope);
return next.length > 0 ? next : ['smart'];
}
return [...withoutSmart, scope];
});
};
const setSearchScopeChecked = (scope: SearchScope, checked: boolean) => {
if (scope === 'smart') {
if (checked) {
setSearchScopes(['smart']);
} else if (searchScopes.length === 1 && searchScopes[0] === 'smart') {
setSearchScopes(['smart']);
} else {
if (match) {
setSearchScopes((prev) => {
const next = prev.filter((item) => item !== 'smart');
return next.length > 0 ? next : ['smart'];
});
}
return;
}
if (checked) {
setSearchScopes((prev) => {
const withoutSmart = prev.filter((item) => item !== 'smart');
if (withoutSmart.includes(scope)) {
return withoutSmart;
}
return [...withoutSmart, scope];
});
} else {
setSearchScopes((prev) => {
const next = prev.filter((item) => item !== scope && item !== 'smart');
return next.length > 0 ? next : ['smart'];
});
}
};
const searchScopeSummary = useMemo(() => {
if (searchScopes.includes('smart')) {
return '智能';
}
return searchScopes.map((scope) => SEARCH_SCOPE_LABEL_MAP[scope]).join(' + ');
}, [searchScopes]);
const searchScopePopoverContent = useMemo(() => {
const smartSelected = searchScopes.includes('smart');
const scopedOptions = SEARCH_SCOPE_OPTIONS.filter((option) => option.value !== 'smart');
return (
<div style={{ minWidth: 220, display: 'flex', flexDirection: 'column', gap: 8 }}>
<div style={{ fontSize: 12, color: '#8c8c8c' }}></div>
<Checkbox
checked={smartSelected}
onChange={(e) => setSearchScopeChecked('smart', e.target.checked)}
>
</Checkbox>
<div style={{ paddingLeft: 12, display: 'grid', gap: 6 }}>
{scopedOptions.map((option) => (
<Checkbox
key={option.value}
checked={searchScopes.includes(option.value)}
onChange={(e) => setSearchScopeChecked(option.value, e.target.checked)}
>
{option.label}
</Checkbox>
))}
</div>
<div style={{ fontSize: 12, color: '#8c8c8c' }}>
</div>
</div>
);
}, [searchScopes]);
const parseHostOnlyToken = (value: unknown): string[] => {
const raw = String(value || '').trim();
if (!raw) {
return [];
}
let text = raw.replace(/^[a-z][a-z0-9+.-]*:\/\//i, '');
if (text.includes('/')) {
text = text.split('/')[0];
}
if (text.includes('?')) {
text = text.split('?')[0];
}
if (text.includes('@')) {
text = text.split('@').pop() || '';
}
return text
.split(',')
.map((entry) => {
const token = entry.trim();
if (!token) return '';
if (token.startsWith('[')) {
const rightBracketIndex = token.indexOf(']');
if (rightBracketIndex > 0) {
return token.slice(0, rightBracketIndex + 1).toLowerCase();
}
}
const colonIndex = token.lastIndexOf(':');
if (colonIndex > 0) {
return token.slice(0, colonIndex).toLowerCase();
}
return token.toLowerCase();
})
.filter(Boolean);
};
const getConnectionHostSearchText = (node: TreeNode): string => {
if (node.type !== 'connection') return '';
const config = node.dataRef?.config || {};
const hostTokens = [
...parseHostOnlyToken(config.host),
...(Array.isArray(config.hosts) ? config.hosts.flatMap((entry: string) => parseHostOnlyToken(entry)) : []),
...parseHostOnlyToken(config.uri),
];
const uniqueHosts = Array.from(new Set(hostTokens));
return uniqueHosts.join(' ');
};
const getConnectionNameSearchText = (node: TreeNode): string => {
if (node.type !== 'connection') return '';
const name = node.dataRef?.name ?? node.title;
return String(name || '').toLowerCase();
};
const isObjectNode = (node: TreeNode): boolean => {
return node.type === 'table'
|| node.type === 'view'
|| node.type === 'db-trigger'
|| node.type === 'routine'
|| node.type === 'object-group';
};
const matchByScopes = (node: TreeNode, keyword: string, scopes: SearchScope[]): boolean => {
const title = String(node.title || '').toLowerCase();
if (scopes.includes('database') && node.type === 'database' && title.includes(keyword)) {
return true;
}
if (scopes.includes('tag') && node.type === 'tag' && title.includes(keyword)) {
return true;
}
if (scopes.includes('host') && node.type === 'connection' && getConnectionHostSearchText(node).includes(keyword)) {
return true;
}
if (scopes.includes('object') && isObjectNode(node) && title.includes(keyword)) {
return true;
}
return false;
};
const loop = (data: TreeNode[], keyword: string): TreeNode[] => {
const isSmartMode = searchScopes.includes('smart');
const result: TreeNode[] = [];
data.forEach((item) => {
const titleMatch = String(item.title || '').toLowerCase().includes(keyword);
const smartMatch = item.type === 'connection'
? getConnectionNameSearchText(item).includes(keyword) || getConnectionHostSearchText(item).includes(keyword)
: titleMatch;
const scopedMatch = matchByScopes(item, keyword, searchScopes);
const selfMatch = isSmartMode ? smartMatch : scopedMatch;
const filteredChildren = item.children ? loop(item.children, keyword) : [];
if (selfMatch) {
const shouldKeepFullSubtree = isSmartMode
|| item.type === 'connection'
|| item.type === 'database'
|| item.type === 'tag';
if (item.children && shouldKeepFullSubtree) {
result.push(item);
} else if (item.children && filteredChildren.length > 0) {
result.push({ ...item, children: filteredChildren });
} else {
result.push(item);
}
return;
}
if (filteredChildren.length > 0) {
result.push({ ...item, children: filteredChildren });
}
});
return result;
};
const displayTreeData = useMemo(() => {
if (!searchValue) return treeData;
return loop(treeData);
}, [searchValue, treeData]);
const keyword = searchValue.trim().toLowerCase();
if (!keyword) return treeData;
return loop(treeData, keyword);
}, [searchValue, searchScopes, treeData]);
const getNodeMenuItems = (node: any): MenuProps['items'] => {
const conn = node.dataRef as SavedConnection;
@@ -2395,6 +2734,12 @@ const Sidebar: React.FC<{ onEditConnection?: (conn: SavedConnection) => void }>
if (onEditConnection) onEditConnection(node.dataRef);
}
},
{
key: 'copy-connection',
label: '复制连接',
icon: <CopyOutlined />,
onClick: () => handleDuplicateConnection(node.dataRef as SavedConnection)
},
{
key: 'disconnect',
label: '断开连接',
@@ -2493,6 +2838,12 @@ const Sidebar: React.FC<{ onEditConnection?: (conn: SavedConnection) => void }>
if (onEditConnection) onEditConnection(node.dataRef);
}
},
{
key: 'copy-connection',
label: '复制连接',
icon: <CopyOutlined />,
onClick: () => handleDuplicateConnection(node.dataRef as SavedConnection)
},
{
key: 'move-to-tag',
label: '移至标签',
@@ -2504,6 +2855,13 @@ const Sidebar: React.FC<{ onEditConnection?: (conn: SavedConnection) => void }>
label: '断开连接',
icon: <DisconnectOutlined />,
onClick: () => {
const connId = String(node.key || '');
// 强制清理该连接相关的 loading 标记,避免网络卡住后重连仍被短路。
Array.from(loadingNodesRef.current).forEach((loadingKey) => {
if (loadingKey === `dbs-${connId}` || loadingKey.startsWith(`tables-${connId}-`)) {
loadingNodesRef.current.delete(loadingKey);
}
});
// Reset status recursively
setConnectionStates(prev => {
const next = { ...prev };
@@ -2625,6 +2983,7 @@ const Sidebar: React.FC<{ onEditConnection?: (conn: SavedConnection) => void }>
onClick: () => {
const dbConnId = String(node.dataRef?.id || '');
const dbName = String(node.dataRef?.dbName || node.title || '').trim();
loadingNodesRef.current.delete(`tables-${dbConnId}-${dbName}`);
setConnectionStates(prev => {
const next = { ...prev };
delete next[node.key];
@@ -2857,15 +3216,15 @@ const Sidebar: React.FC<{ onEditConnection?: (conn: SavedConnection) => void }>
// Get current order
const currentTagOrder = connectionTags.map(t => t.id);
const dragTagId = dragNode.dataRef.id;
// Filter out the dragging tag
const newOrder = currentTagOrder.filter(id => id !== dragTagId);
let insertIndex = newOrder.length;
if (dropNode.type === 'tag') {
const dropTagId = dropNode.dataRef.id;
const dropIndex = newOrder.indexOf(dropTagId);
if (dropPosition === -1) {
insertIndex = dropIndex;
} else {
@@ -2876,7 +3235,7 @@ const Sidebar: React.FC<{ onEditConnection?: (conn: SavedConnection) => void }>
// Since tags are always displayed before ungrouped connections, just put it at the end
insertIndex = newOrder.length;
}
newOrder.splice(insertIndex, 0, dragTagId);
reorderTags(newOrder);
}
@@ -2897,7 +3256,7 @@ const Sidebar: React.FC<{ onEditConnection?: (conn: SavedConnection) => void }>
moveConnectionToTag(dragNode.key, targetTag.id);
return;
}
// Drop target is NOT under a tag (ungrouped) -> move OUT of tag
const sourceTag = connectionTags.find(t => t.connectionIds.includes(dragNode.key));
if (sourceTag) {
@@ -2921,7 +3280,28 @@ const Sidebar: React.FC<{ onEditConnection?: (conn: SavedConnection) => void }>
return (
<div style={{ display: 'flex', flexDirection: 'column', height: '100%' }}>
<div style={{ padding: '4px 8px' }}>
<Search placeholder="搜索..." onChange={onSearch} size="small" />
<Space.Compact block size="small">
<Search
ref={searchInputRef}
placeholder="搜索..."
onChange={onSearch}
size="small"
style={{ width: '100%' }}
/>
<Popover
content={searchScopePopoverContent}
trigger="click"
placement="bottomRight"
open={isSearchScopePopoverOpen}
onOpenChange={setIsSearchScopePopoverOpen}
>
<Tooltip title={`搜索范围:${searchScopeSummary}`}>
<Button size="small" icon={<DownOutlined />} style={{ width: 86 }}>
{searchScopes.includes('smart') ? '(智)' : `(${searchScopes.length})`}
</Button>
</Tooltip>
</Popover>
</Space.Compact>
</div>
{/* Toolbar */}

View File

@@ -261,9 +261,18 @@ const TableDesigner: React.FC<{ tab: TabData }> = ({ tab }) => {
const darkMode = theme === 'dark';
const resizeGuideColor = darkMode ? '#f6c453' : '#1890ff';
const readOnly = !!tab.readOnly;
const panelRadius = 10;
const panelFrameColor = darkMode ? 'rgba(0, 0, 0, 0.18)' : 'rgba(0, 0, 0, 0.12)';
const panelToolbarBorder = darkMode ? 'rgba(255, 255, 255, 0.12)' : 'rgba(0, 0, 0, 0.10)';
const panelToolbarBg = darkMode ? 'rgba(20, 20, 20, 0.35)' : 'rgba(255, 255, 255, 0.72)';
const panelBodyBg = darkMode ? 'rgba(0, 0, 0, 0.24)' : 'rgba(255, 255, 255, 0.82)';
const focusRowBg = darkMode ? 'rgba(246, 196, 83, 0.22)' : 'rgba(24, 144, 255, 0.12)';
const [tableHeight, setTableHeight] = useState(500);
const containerRef = useRef<HTMLDivElement>(null);
const pendingFocusColumnKeyRef = useRef<string | null>(null);
const focusHighlightTimerRef = useRef<number | null>(null);
const [focusColumnKey, setFocusColumnKey] = useState('');
const openCommentEditor = useCallback((record: EditableColumn) => {
if (!record?._key) return;
@@ -346,6 +355,61 @@ const TableDesigner: React.FC<{ tab: TabData }> = ({ tab }) => {
setSelectedColumnRowKeys(prev => prev.filter(key => columns.some(c => c._key === key)));
}, [columns]);
useEffect(() => {
return () => {
if (focusHighlightTimerRef.current !== null) {
window.clearTimeout(focusHighlightTimerRef.current);
}
};
}, []);
const focusColumnRow = useCallback((targetKey: string): boolean => {
if (activeKey !== 'columns') return false;
const tableBody = containerRef.current?.querySelector('.ant-table-body') as HTMLElement | null;
if (!tableBody) return false;
const row = tableBody.querySelector(`tr[data-row-key="${targetKey}"]`) as HTMLTableRowElement | null;
if (!row) return false;
row.scrollIntoView({ behavior: 'smooth', block: 'nearest' });
setFocusColumnKey(targetKey);
if (focusHighlightTimerRef.current !== null) {
window.clearTimeout(focusHighlightTimerRef.current);
}
focusHighlightTimerRef.current = window.setTimeout(() => {
setFocusColumnKey(prev => (prev === targetKey ? '' : prev));
}, 1600);
if (!readOnly) {
const firstInput = row.querySelector('input') as HTMLInputElement | null;
if (firstInput) {
firstInput.focus();
firstInput.select();
}
}
return true;
}, [activeKey, readOnly]);
useEffect(() => {
const pendingKey = pendingFocusColumnKeyRef.current;
if (!pendingKey || activeKey !== 'columns') return;
let cancelled = false;
const tryFocus = () => {
if (cancelled) return;
if (focusColumnRow(pendingKey)) {
pendingFocusColumnKeyRef.current = null;
}
};
const timerA = window.setTimeout(tryFocus, 0);
const timerB = window.setTimeout(tryFocus, 96);
return () => {
cancelled = true;
window.clearTimeout(timerA);
window.clearTimeout(timerB);
};
}, [activeKey, columns, focusColumnRow]);
// Initial Columns Definition
useEffect(() => {
const initialCols = [
@@ -886,21 +950,46 @@ ${selectedTrigger.statement}`;
}));
};
const handleAddColumn = () => {
const newCol: EditableColumn = {
name: isNewTable ? 'new_column' : `new_col_${columns.length + 1}`,
type: 'varchar(255)',
nullable: 'YES',
key: '',
extra: '',
comment: '',
default: '',
_key: `new-${Date.now()}`,
isNew: true,
isAutoIncrement: false
};
setColumns([...columns, newCol]);
};
const createNewColumn = useCallback((indexHint: number): EditableColumn => ({
name: isNewTable ? 'new_column' : `new_col_${indexHint}`,
type: 'varchar(255)',
nullable: 'YES',
key: '',
extra: '',
comment: '',
default: '',
_key: `new-${Date.now()}-${Math.random().toString(36).slice(2, 8)}`,
isNew: true,
isAutoIncrement: false
}), [isNewTable]);
const handleAddColumn = useCallback((insertAfterKey?: string) => {
const newCol = createNewColumn(columns.length + 1);
setColumns(prev => {
const next = [...prev];
if (insertAfterKey) {
const insertIndex = next.findIndex(col => col._key === insertAfterKey);
if (insertIndex >= 0) {
next.splice(insertIndex + 1, 0, newCol);
return next;
}
}
next.push(newCol);
return next;
});
setSelectedColumnRowKeys([newCol._key]);
pendingFocusColumnKeyRef.current = newCol._key;
}, [columns.length, createNewColumn]);
const handleAddColumnAfterSelected = useCallback(() => {
const selectedSet = new Set(selectedColumnRowKeys);
const anchor = columns.find(col => selectedSet.has(col._key));
if (!anchor) {
message.warning('请先选择一个字段,再执行插入。');
return;
}
handleAddColumn(anchor._key);
}, [columns, handleAddColumn, selectedColumnRowKeys]);
const handleDeleteColumn = (key: string) => {
setColumns(prev => prev.filter(c => c._key !== key));
@@ -1920,22 +2009,35 @@ END;`;
}));
const columnsTabContent = (
<div ref={containerRef} className="table-designer-wrapper" style={{ height: '100%', overflow: 'hidden', position: 'relative' }}>
<div
ref={containerRef}
className="table-designer-wrapper"
style={{
height: '100%',
overflow: 'hidden',
position: 'relative',
background: panelBodyBg
}}
>
<style>{`
.table-designer-wrapper .ant-table-body {
max-height: ${tableHeight}px !important;
}
}
.table-designer-wrapper .table-designer-focus-row > .ant-table-cell {
background: ${focusRowBg} !important;
}
`}</style>
{readOnly ? (
<Table
dataSource={columns}
columns={resizableColumns}
rowKey="_key"
rowClassName={(record: EditableColumn) => record._key === focusColumnKey ? 'table-designer-focus-row' : ''}
size="small"
pagination={false}
loading={loading}
scroll={{ y: tableHeight }}
bordered
bordered={false}
components={{
header: {
cell: ResizableTitle,
@@ -1953,11 +2055,12 @@ END;`;
onChange: (nextSelectedRowKeys) => setSelectedColumnRowKeys(nextSelectedRowKeys as string[]),
}}
rowKey="_key"
rowClassName={(record: EditableColumn) => record._key === focusColumnKey ? 'table-designer-focus-row' : ''}
size="small"
pagination={false}
loading={loading}
scroll={{ y: tableHeight }}
bordered
bordered={false}
components={{
body: { row: SortableRow },
header: { cell: ResizableTitle }
@@ -1985,8 +2088,63 @@ END;`;
);
return (
<div style={{ display: 'flex', flexDirection: 'column', height: '100%' }}>
<div style={{ padding: '8px', borderBottom: '1px solid #eee', display: 'flex', gap: '8px', alignItems: 'center' }}>
<div className="table-designer-shell" style={{ display: 'flex', flexDirection: 'column', height: '100%', minHeight: 0, padding: '6px 0' }}>
<style>{`
.table-designer-shell .ant-table,
.table-designer-shell .ant-table-wrapper,
.table-designer-shell .ant-table-container {
background: transparent !important;
}
.table-designer-shell .ant-table-wrapper,
.table-designer-shell .ant-table-container {
border: none !important;
overflow: hidden !important;
}
.table-designer-shell .ant-table-thead > tr > th {
background: transparent !important;
border-bottom: 1px solid ${darkMode ? 'rgba(255,255,255,0.06)' : 'rgba(0,0,0,0.06)'} !important;
border-inline-end: 1px solid transparent !important;
}
.table-designer-shell .ant-table-tbody > tr > td,
.table-designer-shell .ant-table-tbody .ant-table-row > .ant-table-cell {
background: transparent !important;
border-bottom: 1px solid ${darkMode ? 'rgba(255,255,255,0.05)' : 'rgba(0,0,0,0.05)'} !important;
border-inline-end: 1px solid transparent !important;
}
.table-designer-shell .ant-table-thead > tr > th::before {
display: none !important;
}
.table-designer-shell .ant-table-tbody > tr:hover > td,
.table-designer-shell .ant-table-tbody .ant-table-row:hover > .ant-table-cell {
background: ${darkMode ? 'rgba(255,255,255,0.06)' : 'rgba(0,0,0,0.02)'} !important;
}
.table-designer-shell .ant-tabs-nav {
margin-bottom: 8px !important;
}
.table-designer-shell .ant-tabs-nav::before {
border-bottom-color: ${darkMode ? 'rgba(255,255,255,0.08)' : 'rgba(0,0,0,0.08)'} !important;
}
.table-designer-shell .ant-tabs-content-holder,
.table-designer-shell .ant-tabs-content,
.table-designer-shell .ant-tabs-tabpane {
height: 100%;
}
`}</style>
<div
style={{
padding: '10px 12px 8px 12px',
borderBottom: `1px solid ${panelToolbarBorder}`,
borderTopLeftRadius: panelRadius,
borderTopRightRadius: panelRadius,
borderLeft: `1px solid ${panelFrameColor}`,
borderRight: `1px solid ${panelFrameColor}`,
borderTop: `1px solid ${panelFrameColor}`,
background: panelToolbarBg,
display: 'flex',
gap: '8px',
alignItems: 'center'
}}
>
{isNewTable && (
<>
<Input
@@ -2014,14 +2172,25 @@ END;`;
/>
</>
)}
{!readOnly && <Button icon={<SaveOutlined />} type="primary" onClick={generateDDL}></Button>}
{!isNewTable && <Button icon={<ReloadOutlined />} onClick={fetchData}></Button>}
{!readOnly && <Button size="small" icon={<SaveOutlined />} type="primary" onClick={generateDDL}></Button>}
{!isNewTable && <Button size="small" icon={<ReloadOutlined />} onClick={fetchData}></Button>}
{!isNewTable && !readOnly && supportsTableCommentOps() && (
<Button icon={<EditOutlined />} onClick={openTableCommentModal}></Button>
<Button size="small" icon={<EditOutlined />} onClick={openTableCommentModal}></Button>
)}
{!readOnly && <Button icon={<PlusOutlined />} onClick={handleAddColumn}></Button>}
{!readOnly && <Button size="small" icon={<PlusOutlined />} onClick={() => handleAddColumn()}></Button>}
{!readOnly && (
<Button
size="small"
icon={<PlusOutlined />}
onClick={handleAddColumnAfterSelected}
disabled={selectedColumnRowKeys.length === 0}
>
</Button>
)}
{!readOnly && (
<Button
size="small"
icon={<CopyOutlined />}
onClick={openCopySelectedColumnsModal}
disabled={selectedColumns.length === 0}
@@ -2034,7 +2203,17 @@ END;`;
<Tabs
activeKey={activeKey}
onChange={setActiveKey}
style={{ flex: 1, padding: '0 10px' }}
style={{
flex: 1,
minHeight: 0,
padding: '8px 10px 10px 10px',
borderBottomLeftRadius: panelRadius,
borderBottomRightRadius: panelRadius,
borderLeft: `1px solid ${panelFrameColor}`,
borderRight: `1px solid ${panelFrameColor}`,
borderBottom: `1px solid ${panelFrameColor}`,
background: panelBodyBg
}}
items={[
{
key: 'columns',
@@ -2276,7 +2455,7 @@ END;`;
label: 'DDL',
icon: <FileTextOutlined />,
children: (
<div style={{ height: 'calc(100vh - 200px)', border: darkMode ? '1px solid #303030' : '1px solid #d9d9d9', borderRadius: 4 }}>
<div style={{ height: 'calc(100vh - 200px)', border: `1px solid ${panelFrameColor}`, borderRadius: panelRadius, background: panelBodyBg }}>
<Editor
height="100%"
language="sql"
@@ -2517,7 +2696,7 @@ END;`;
<span><strong>:</strong> {selectedTrigger.timing}</span>
<span><strong>:</strong> {selectedTrigger.event}</span>
</div>
<div style={{ border: darkMode ? '1px solid #303030' : '1px solid #d9d9d9', borderRadius: 4 }}>
<div style={{ border: `1px solid ${panelFrameColor}`, borderRadius: panelRadius, background: panelBodyBg }}>
<Editor
height="350px"
language="sql"
@@ -2553,7 +2732,7 @@ END;`;
<span></span>
)}
</div>
<div style={{ border: darkMode ? '1px solid #303030' : '1px solid #d9d9d9', borderRadius: 4 }}>
<div style={{ border: `1px solid ${panelFrameColor}`, borderRadius: panelRadius, background: panelBodyBg }}>
<Editor
height="350px"
language="sql"

View File

@@ -1,6 +1,14 @@
import { create } from 'zustand';
import { persist } from 'zustand/middleware';
import { ConnectionConfig, ProxyConfig, SavedConnection, TabData, SavedQuery, ConnectionTag } from './types';
import {
ShortcutAction,
ShortcutBinding,
ShortcutOptions,
DEFAULT_SHORTCUT_OPTIONS,
cloneShortcutOptions,
sanitizeShortcutOptions,
} from './utils/shortcuts';
const DEFAULT_APPEARANCE = { opacity: 1.0, blur: 0 };
const DEFAULT_UI_SCALE = 1.0;
@@ -48,6 +56,23 @@ const SUPPORTED_CONNECTION_TYPES = new Set([
'duckdb',
'custom',
]);
const SSL_SUPPORTED_CONNECTION_TYPES = new Set([
'mysql',
'mariadb',
'diros',
'sphinx',
'dameng',
'clickhouse',
'postgres',
'sqlserver',
'oracle',
'kingbase',
'highgo',
'vastbase',
'mongodb',
'redis',
'tdengine',
]);
const getDefaultPortByType = (type: string): number => {
switch (type) {
@@ -177,6 +202,16 @@ const sanitizeConnectionConfig = (value: unknown): ConnectionConfig => {
const defaultPort = getDefaultPortByType(type);
const savePassword = typeof raw.savePassword === 'boolean' ? raw.savePassword : true;
const mongoSrv = !!raw.mongoSrv;
const sslCapable = SSL_SUPPORTED_CONNECTION_TYPES.has(type);
const sslModeRaw = toTrimmedString(raw.sslMode, 'preferred').toLowerCase();
const sslMode: 'preferred' | 'required' | 'skip-verify' | 'disable' =
sslModeRaw === 'required'
? 'required'
: sslModeRaw === 'skip-verify'
? 'skip-verify'
: sslModeRaw === 'disable'
? 'disable'
: 'preferred';
const sshRaw = (raw.ssh && typeof raw.ssh === 'object') ? raw.ssh as Record<string, unknown> : {};
const ssh = {
@@ -206,6 +241,10 @@ const sanitizeConnectionConfig = (value: unknown): ConnectionConfig => {
password: savePassword ? toTrimmedString(raw.password) : '',
savePassword,
database: toTrimmedString(raw.database),
useSSL: sslCapable ? !!raw.useSSL : false,
sslMode: sslCapable ? sslMode : 'disable',
sslCertPath: sslCapable ? toTrimmedString(raw.sslCertPath) : '',
sslKeyPath: sslCapable ? toTrimmedString(raw.sslKeyPath) : '',
useSSH: !!raw.useSSH,
ssh,
useProxy: !!raw.useProxy,
@@ -359,6 +398,7 @@ interface AppState {
globalProxy: GlobalProxyConfig;
sqlFormatOptions: { keywordCase: 'upper' | 'lower' };
queryOptions: QueryOptions;
shortcutOptions: ShortcutOptions;
sqlLogs: SqlLog[];
tableAccessCount: Record<string, number>;
tableSortPreference: Record<string, 'name' | 'frequency'>;
@@ -396,6 +436,8 @@ interface AppState {
setGlobalProxy: (proxy: Partial<GlobalProxyConfig>) => void;
setSqlFormatOptions: (options: { keywordCase: 'upper' | 'lower' }) => void;
setQueryOptions: (options: Partial<QueryOptions>) => void;
updateShortcut: (action: ShortcutAction, binding: Partial<ShortcutBinding>) => void;
resetShortcutOptions: () => void;
addSqlLog: (log: SqlLog) => void;
clearSqlLogs: () => void;
@@ -537,6 +579,7 @@ export const useStore = create<AppState>()(
globalProxy: { ...DEFAULT_GLOBAL_PROXY },
sqlFormatOptions: { keywordCase: 'upper' },
queryOptions: { maxRows: 5000, showColumnComment: true, showColumnType: true },
shortcutOptions: cloneShortcutOptions(DEFAULT_SHORTCUT_OPTIONS),
sqlLogs: [],
tableAccessCount: {},
tableSortPreference: {},
@@ -708,6 +751,16 @@ export const useStore = create<AppState>()(
setGlobalProxy: (proxy) => set((state) => ({ globalProxy: sanitizeGlobalProxy({ ...state.globalProxy, ...proxy }) })),
setSqlFormatOptions: (options) => set({ sqlFormatOptions: options }),
setQueryOptions: (options) => set((state) => ({ queryOptions: { ...state.queryOptions, ...options } })),
updateShortcut: (action, binding) => set((state) => ({
shortcutOptions: {
...state.shortcutOptions,
[action]: {
...state.shortcutOptions[action],
...binding,
},
},
})),
resetShortcutOptions: () => set({ shortcutOptions: cloneShortcutOptions(DEFAULT_SHORTCUT_OPTIONS) }),
addSqlLog: (log) => set((state) => ({ sqlLogs: [log, ...state.sqlLogs].slice(0, 1000) })), // Keep last 1000 logs
clearSqlLogs: () => set({ sqlLogs: [] }),
@@ -754,6 +807,7 @@ export const useStore = create<AppState>()(
nextState.globalProxy = sanitizeGlobalProxy(state.globalProxy);
nextState.sqlFormatOptions = sanitizeSqlFormatOptions(state.sqlFormatOptions);
nextState.queryOptions = sanitizeQueryOptions(state.queryOptions);
nextState.shortcutOptions = sanitizeShortcutOptions(state.shortcutOptions);
nextState.tableAccessCount = sanitizeTableAccessCount(state.tableAccessCount);
nextState.tableSortPreference = sanitizeTableSortPreference(state.tableSortPreference);
return nextState as AppState;
@@ -774,6 +828,7 @@ export const useStore = create<AppState>()(
globalProxy: sanitizeGlobalProxy(state.globalProxy),
sqlFormatOptions: sanitizeSqlFormatOptions(state.sqlFormatOptions),
queryOptions: sanitizeQueryOptions(state.queryOptions),
shortcutOptions: sanitizeShortcutOptions(state.shortcutOptions),
tableAccessCount: sanitizeTableAccessCount(state.tableAccessCount),
tableSortPreference: sanitizeTableSortPreference(state.tableSortPreference),
};
@@ -790,6 +845,7 @@ export const useStore = create<AppState>()(
globalProxy: state.globalProxy,
sqlFormatOptions: state.sqlFormatOptions,
queryOptions: state.queryOptions,
shortcutOptions: state.shortcutOptions,
tableAccessCount: state.tableAccessCount,
tableSortPreference: state.tableSortPreference
}), // Don't persist logs

View File

@@ -22,6 +22,10 @@ export interface ConnectionConfig {
password?: string;
savePassword?: boolean;
database?: string;
useSSL?: boolean;
sslMode?: 'preferred' | 'required' | 'skip-verify' | 'disable';
sslCertPath?: string;
sslKeyPath?: string;
useSSH?: boolean;
ssh?: SSHConfig;
useProxy?: boolean;

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,258 @@
import type { KeyboardEvent as ReactKeyboardEvent } from 'react';
export type ShortcutAction =
| 'runQuery'
| 'focusSidebarSearch'
| 'newQueryTab'
| 'toggleLogPanel'
| 'toggleTheme'
| 'openShortcutManager';
export interface ShortcutBinding {
combo: string;
enabled: boolean;
}
export type ShortcutOptions = Record<ShortcutAction, ShortcutBinding>;
export interface ShortcutActionMeta {
label: string;
description: string;
allowInEditable?: boolean;
}
const MODIFIER_ORDER = ['Ctrl', 'Meta', 'Alt', 'Shift'] as const;
const MODIFIER_SET = new Set(MODIFIER_ORDER);
const KEY_ALIASES: Record<string, string> = {
control: 'Ctrl',
ctrl: 'Ctrl',
command: 'Meta',
cmd: 'Meta',
meta: 'Meta',
option: 'Alt',
alt: 'Alt',
shift: 'Shift',
escape: 'Esc',
esc: 'Esc',
return: 'Enter',
enter: 'Enter',
tab: 'Tab',
space: 'Space',
' ': 'Space',
backspace: 'Backspace',
delete: 'Delete',
del: 'Delete',
arrowup: 'Up',
up: 'Up',
arrowdown: 'Down',
down: 'Down',
arrowleft: 'Left',
left: 'Left',
arrowright: 'Right',
right: 'Right',
pagedown: 'PageDown',
pageup: 'PageUp',
home: 'Home',
end: 'End',
insert: 'Insert',
',': ',',
'.': '.',
'/': '/',
';': ';',
"'": "'",
'[': '[',
']': ']',
'\\': '\\',
'-': '-',
'=': '=',
'`': '`',
};
export const SHORTCUT_ACTION_ORDER: ShortcutAction[] = [
'runQuery',
'focusSidebarSearch',
'newQueryTab',
'toggleLogPanel',
'toggleTheme',
'openShortcutManager',
];
export const SHORTCUT_ACTION_META: Record<ShortcutAction, ShortcutActionMeta> = {
runQuery: {
label: '执行 SQL',
description: '在当前查询页执行 SQL',
},
focusSidebarSearch: {
label: '聚焦侧边栏搜索',
description: '定位到左侧连接树搜索框',
allowInEditable: true,
},
newQueryTab: {
label: '新建查询页',
description: '创建一个新的 SQL 查询标签页',
},
toggleLogPanel: {
label: '切换日志面板',
description: '打开或关闭 SQL 执行日志面板',
},
toggleTheme: {
label: '切换主题',
description: '在亮色和暗色主题之间切换',
},
openShortcutManager: {
label: '打开快捷键管理',
description: '打开快捷键设置面板',
allowInEditable: true,
},
};
export const DEFAULT_SHORTCUT_OPTIONS: ShortcutOptions = {
runQuery: { combo: 'Ctrl+Shift+R', enabled: true },
focusSidebarSearch: { combo: 'Ctrl+F', enabled: true },
newQueryTab: { combo: 'Ctrl+Shift+N', enabled: true },
toggleLogPanel: { combo: 'Ctrl+Shift+L', enabled: true },
toggleTheme: { combo: 'Ctrl+Shift+D', enabled: true },
openShortcutManager: { combo: 'Ctrl+,', enabled: true },
};
const normalizeKeyToken = (value: string): string => {
const token = String(value || '').trim();
if (!token) return '';
const alias = KEY_ALIASES[token.toLowerCase()];
if (alias) return alias;
if (/^f([1-9]|1[0-2])$/i.test(token)) {
return token.toUpperCase();
}
if (token.length === 1) {
return token === '+' ? '+' : token.toUpperCase();
}
return token.length > 1 ? token[0].toUpperCase() + token.slice(1).toLowerCase() : token;
};
export const normalizeShortcutCombo = (combo: string): string => {
const raw = String(combo || '').trim();
if (!raw) return '';
const pieces = raw
.split('+')
.map(part => part.trim())
.filter(Boolean);
const modifiers: string[] = [];
let key = '';
pieces.forEach((part) => {
const normalized = normalizeKeyToken(part);
if (!normalized) return;
if (MODIFIER_SET.has(normalized as typeof MODIFIER_ORDER[number])) {
if (!modifiers.includes(normalized)) {
modifiers.push(normalized);
}
return;
}
key = normalized;
});
modifiers.sort((a, b) => MODIFIER_ORDER.indexOf(a as typeof MODIFIER_ORDER[number]) - MODIFIER_ORDER.indexOf(b as typeof MODIFIER_ORDER[number]));
if (!key) {
return modifiers.join('+');
}
return [...modifiers, key].join('+');
};
const normalizeKeyboardKey = (key: string): string => {
const token = String(key || '').trim();
if (!token) return '';
const alias = KEY_ALIASES[token.toLowerCase()];
if (alias) return alias;
if (token.length === 1) {
if (token === ' ') return 'Space';
return token.toUpperCase();
}
if (/^f([1-9]|1[0-2])$/i.test(token)) {
return token.toUpperCase();
}
return token.length > 1 ? token[0].toUpperCase() + token.slice(1) : token;
};
export const eventToShortcut = (event: KeyboardEvent | ReactKeyboardEvent): string => {
const key = normalizeKeyboardKey(event.key);
if (!key || MODIFIER_SET.has(key as typeof MODIFIER_ORDER[number])) {
return '';
}
const modifiers: string[] = [];
if (event.ctrlKey) modifiers.push('Ctrl');
if (event.metaKey) modifiers.push('Meta');
if (event.altKey) modifiers.push('Alt');
if (event.shiftKey) modifiers.push('Shift');
return normalizeShortcutCombo([...modifiers, key].join('+'));
};
export const isShortcutMatch = (event: KeyboardEvent | ReactKeyboardEvent, combo: string): boolean => {
const expected = normalizeShortcutCombo(combo);
if (!expected) return false;
const actual = eventToShortcut(event);
return actual === expected;
};
export const hasModifierKey = (combo: string): boolean => {
const normalized = normalizeShortcutCombo(combo);
if (!normalized) return false;
return normalized.split('+').some(part => MODIFIER_SET.has(part as typeof MODIFIER_ORDER[number]));
};
export const cloneShortcutOptions = (value: ShortcutOptions): ShortcutOptions => {
return SHORTCUT_ACTION_ORDER.reduce((acc, action) => {
acc[action] = {
combo: normalizeShortcutCombo(value[action]?.combo || DEFAULT_SHORTCUT_OPTIONS[action].combo),
enabled: value[action]?.enabled !== false,
};
return acc;
}, {} as ShortcutOptions);
};
export const sanitizeShortcutOptions = (value: unknown): ShortcutOptions => {
const raw = (value && typeof value === 'object') ? value as Record<string, unknown> : {};
const defaults = cloneShortcutOptions(DEFAULT_SHORTCUT_OPTIONS);
SHORTCUT_ACTION_ORDER.forEach((action) => {
const actionRaw = raw[action];
if (!actionRaw || typeof actionRaw !== 'object') {
return;
}
const binding = actionRaw as Record<string, unknown>;
const combo = normalizeShortcutCombo(String(binding.combo || defaults[action].combo));
defaults[action] = {
combo: combo || defaults[action].combo,
enabled: binding.enabled === false ? false : true,
};
});
return defaults;
};
export const isEditableElement = (target: EventTarget | null): boolean => {
if (!(target instanceof HTMLElement)) {
return false;
}
const tag = target.tagName.toLowerCase();
if (target.isContentEditable) {
return true;
}
if (tag === 'input' || tag === 'textarea' || tag === 'select') {
return true;
}
if (target.closest('.monaco-editor, .monaco-inputbox, .ant-select, .ant-picker, .ant-input')) {
return true;
}
return false;
};
export const getShortcutDisplay = (combo: string): string => {
const normalized = normalizeShortcutCombo(combo);
return normalized || '-';
};

View File

@@ -1,6 +1,7 @@
export type FilterCondition = {
id?: number;
enabled?: boolean;
logic?: 'AND' | 'OR';
column?: string;
op?: string;
value?: string;
@@ -142,8 +143,12 @@ export const parseListValues = (val: string) => {
.filter(Boolean);
};
const normalizeConditionLogic = (logic: unknown): 'AND' | 'OR' => {
return String(logic || '').trim().toUpperCase() === 'OR' ? 'OR' : 'AND';
};
export const buildWhereSQL = (dbType: string, conditions: FilterCondition[]) => {
const whereParts: string[] = [];
const whereParts: Array<{ expr: string; logic: 'AND' | 'OR' }> = [];
(conditions || []).forEach((cond) => {
if (cond?.enabled === false) return;
@@ -152,10 +157,17 @@ export const buildWhereSQL = (dbType: string, conditions: FilterCondition[]) =>
const column = (cond?.column || '').trim();
const value = (cond?.value ?? '').toString();
const value2 = (cond?.value2 ?? '').toString();
const logic = normalizeConditionLogic(cond?.logic);
const appendWherePart = (expr: string) => {
const normalizedExpr = String(expr || '').trim();
if (!normalizedExpr) return;
whereParts.push({ expr: normalizedExpr, logic });
};
if (op === 'CUSTOM') {
const expr = value.trim();
if (expr) whereParts.push(`(${expr})`);
if (expr) appendWherePart(`(${expr})`);
return;
}
@@ -165,80 +177,80 @@ export const buildWhereSQL = (dbType: string, conditions: FilterCondition[]) =>
switch (op) {
case 'IS_NULL':
whereParts.push(`${col} IS NULL`);
appendWherePart(`${col} IS NULL`);
return;
case 'IS_NOT_NULL':
whereParts.push(`${col} IS NOT NULL`);
appendWherePart(`${col} IS NOT NULL`);
return;
case 'IS_EMPTY':
// 兼容:空值通常理解为 NULL 或空字符串
whereParts.push(`(${col} IS NULL OR ${col} = '')`);
appendWherePart(`(${col} IS NULL OR ${col} = '')`);
return;
case 'IS_NOT_EMPTY':
whereParts.push(`(${col} IS NOT NULL AND ${col} <> '')`);
appendWherePart(`(${col} IS NOT NULL AND ${col} <> '')`);
return;
case 'BETWEEN': {
const v1 = value.trim();
const v2 = value2.trim();
if (!v1 || !v2) return;
whereParts.push(`${col} BETWEEN '${escapeLiteral(v1)}' AND '${escapeLiteral(v2)}'`);
appendWherePart(`${col} BETWEEN '${escapeLiteral(v1)}' AND '${escapeLiteral(v2)}'`);
return;
}
case 'NOT_BETWEEN': {
const v1 = value.trim();
const v2 = value2.trim();
if (!v1 || !v2) return;
whereParts.push(`${col} NOT BETWEEN '${escapeLiteral(v1)}' AND '${escapeLiteral(v2)}'`);
appendWherePart(`${col} NOT BETWEEN '${escapeLiteral(v1)}' AND '${escapeLiteral(v2)}'`);
return;
}
case 'IN': {
const items = parseListValues(value);
if (items.length === 0) return;
const list = items.map(v => `'${escapeLiteral(v)}'`).join(', ');
whereParts.push(`${col} IN (${list})`);
appendWherePart(`${col} IN (${list})`);
return;
}
case 'NOT_IN': {
const items = parseListValues(value);
if (items.length === 0) return;
const list = items.map(v => `'${escapeLiteral(v)}'`).join(', ');
whereParts.push(`${col} NOT IN (${list})`);
appendWherePart(`${col} NOT IN (${list})`);
return;
}
case 'CONTAINS': {
const v = value.trim();
if (!v) return;
whereParts.push(`${col} LIKE '%${escapeLiteral(v)}%'`);
appendWherePart(`${col} LIKE '%${escapeLiteral(v)}%'`);
return;
}
case 'NOT_CONTAINS': {
const v = value.trim();
if (!v) return;
whereParts.push(`${col} NOT LIKE '%${escapeLiteral(v)}%'`);
appendWherePart(`${col} NOT LIKE '%${escapeLiteral(v)}%'`);
return;
}
case 'STARTS_WITH': {
const v = value.trim();
if (!v) return;
whereParts.push(`${col} LIKE '${escapeLiteral(v)}%'`);
appendWherePart(`${col} LIKE '${escapeLiteral(v)}%'`);
return;
}
case 'NOT_STARTS_WITH': {
const v = value.trim();
if (!v) return;
whereParts.push(`${col} NOT LIKE '${escapeLiteral(v)}%'`);
appendWherePart(`${col} NOT LIKE '${escapeLiteral(v)}%'`);
return;
}
case 'ENDS_WITH': {
const v = value.trim();
if (!v) return;
whereParts.push(`${col} LIKE '%${escapeLiteral(v)}'`);
appendWherePart(`${col} LIKE '%${escapeLiteral(v)}'`);
return;
}
case 'NOT_ENDS_WITH': {
const v = value.trim();
if (!v) return;
whereParts.push(`${col} NOT LIKE '%${escapeLiteral(v)}'`);
appendWherePart(`${col} NOT LIKE '%${escapeLiteral(v)}'`);
return;
}
case '=':
@@ -249,7 +261,7 @@ export const buildWhereSQL = (dbType: string, conditions: FilterCondition[]) =>
case '>=': {
const v = value.trim();
if (!v) return;
whereParts.push(`${col} ${op} '${escapeLiteral(v)}'`);
appendWherePart(`${col} ${op} '${escapeLiteral(v)}'`);
return;
}
default: {
@@ -257,16 +269,23 @@ export const buildWhereSQL = (dbType: string, conditions: FilterCondition[]) =>
if (op.toUpperCase() === 'LIKE') {
const v = value.trim();
if (!v) return;
whereParts.push(`${col} LIKE '%${escapeLiteral(v)}%'`);
appendWherePart(`${col} LIKE '%${escapeLiteral(v)}%'`);
return;
}
const v = value.trim();
if (!v) return;
whereParts.push(`${col} ${op} '${escapeLiteral(v)}'`);
appendWherePart(`${col} ${op} '${escapeLiteral(v)}'`);
}
}
});
return whereParts.length > 0 ? `WHERE ${whereParts.join(' AND ')}` : '';
if (whereParts.length === 0) return '';
let whereExpr = `(${whereParts[0].expr})`;
for (let i = 1; i < whereParts.length; i++) {
const part = whereParts[i];
whereExpr = `(${whereExpr} ${part.logic} (${part.expr}))`;
}
return `WHERE ${whereExpr}`;
};

View File

@@ -96,6 +96,10 @@ export namespace connection {
password: string;
savePassword?: boolean;
database: string;
useSSL?: boolean;
sslMode?: string;
sslCertPath?: string;
sslKeyPath?: string;
useSSH: boolean;
ssh: SSHConfig;
useProxy?: boolean;
@@ -130,6 +134,10 @@ export namespace connection {
this.password = source["password"];
this.savePassword = source["savePassword"];
this.database = source["database"];
this.useSSL = source["useSSL"];
this.sslMode = source["sslMode"];
this.sslCertPath = source["sslCertPath"];
this.sslKeyPath = source["sslKeyPath"];
this.useSSH = source["useSSH"];
this.ssh = this.convertValues(source["ssh"], SSHConfig);
this.useProxy = source["useProxy"];

2
go.mod
View File

@@ -17,6 +17,7 @@ require (
github.com/taosdata/driver-go/v3 v3.7.8
github.com/wailsapp/wails/v2 v2.11.0
github.com/xuri/excelize/v2 v2.10.0
go.mongodb.org/mongo-driver v1.17.9
go.mongodb.org/mongo-driver/v2 v2.5.0
golang.org/x/crypto v0.47.0
golang.org/x/mod v0.32.0
@@ -66,6 +67,7 @@ require (
github.com/mattn/go-isatty v0.0.20 // indirect
github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd // indirect
github.com/modern-go/reflect2 v1.0.2 // indirect
github.com/montanaflynn/stats v0.7.1 // indirect
github.com/ncruces/go-strftime v1.0.0 // indirect
github.com/paulmach/orb v0.12.0 // indirect
github.com/pierrec/lz4/v4 v4.1.25 // indirect

4
go.sum
View File

@@ -156,6 +156,8 @@ github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd/go.mod h1:6dJ
github.com/modern-go/reflect2 v1.0.2 h1:xBagoLtFs94CBntxluKeaWgTMpvLxC4ur3nMaC9Gz0M=
github.com/modern-go/reflect2 v1.0.2/go.mod h1:yWuevngMOJpCy52FWWMvUC8ws7m/LJsjYzDa0/r8luk=
github.com/montanaflynn/stats v0.0.0-20171201202039-1bf9dbcd8cbe/go.mod h1:wL8QJuTMNUDYhXwkmfOly8iTdp5TEcJFWZD2D7SIkUc=
github.com/montanaflynn/stats v0.7.1 h1:etflOAAHORrCC44V+aR6Ftzort912ZU+YLiSTuV8eaE=
github.com/montanaflynn/stats v0.7.1/go.mod h1:etXPPgVO6n31NxCd9KQUMvCM+ve0ruNzt6R8Bnaayow=
github.com/ncruces/go-strftime v1.0.0 h1:HMFp8mLCTPp341M/ZnA4qaf7ZlsbTc+miZjCLOFAw7w=
github.com/ncruces/go-strftime v1.0.0/go.mod h1:Fwc5htZGVVkseilnfgOVb9mKy6w1naJmn9CehxcKcls=
github.com/niemeyer/pretty v0.0.0-20200227124842-a10e7caefd8e h1:fD57ERR4JtEqsWbfPhv4DMiApHyliiK5xCTNVSPiaAs=
@@ -248,6 +250,8 @@ github.com/zeebo/assert v1.3.0/go.mod h1:Pq9JiuJQpG8JLJdtkwrJESF0Foym2/D9XMU5ciN
github.com/zeebo/xxh3 v1.1.0 h1:s7DLGDK45Dyfg7++yxI0khrfwq9661w9EN78eP/UZVs=
github.com/zeebo/xxh3 v1.1.0/go.mod h1:IisAie1LELR4xhVinxWS5+zf1lA4p0MW4T+w+W07F5s=
go.mongodb.org/mongo-driver v1.11.4/go.mod h1:PTSz5yu21bkT/wXpkS7WR5f0ddqw5quethTUn9WM+2g=
go.mongodb.org/mongo-driver v1.17.9 h1:IexDdCuuNJ3BHrELgBlyaH9p60JXAvdzWR128q+U5tU=
go.mongodb.org/mongo-driver v1.17.9/go.mod h1:LlOhpH5NUEfhxcAwG0UEkMqwYcc4JU18gtCdGudk/tQ=
go.mongodb.org/mongo-driver/v2 v2.5.0 h1:yXUhImUjjAInNcpTcAlPHiT7bIXhshCTL3jVBkF3xaE=
go.mongodb.org/mongo-driver/v2 v2.5.0/go.mod h1:yOI9kBsufol30iFsl1slpdq1I0eHPzybRWdyYUs8K/0=
go.opentelemetry.io/otel v1.39.0 h1:8yPrr/S0ND9QEfTfdP9V+SiwT4E0G7Y5MO7p85nis48=

View File

@@ -148,6 +148,67 @@ func getCacheKey(config connection.ConnectionConfig) string {
return hex.EncodeToString(sum[:])
}
func shortCacheKey(cacheKey string) string {
shortKey := cacheKey
if len(shortKey) > 12 {
shortKey = shortKey[:12]
}
return shortKey
}
func shouldRefreshCachedConnection(err error) bool {
if err == nil {
return false
}
normalized := strings.ToLower(normalizeErrorMessage(err))
if normalized == "" {
return false
}
patterns := []string{
"invalid connection",
"bad connection",
"database is closed",
"connection is already closed",
"use of closed network connection",
"broken pipe",
"connection reset by peer",
"server has gone away",
"eof",
}
for _, pattern := range patterns {
if strings.Contains(normalized, pattern) {
return true
}
}
return false
}
func (a *App) invalidateCachedDatabase(config connection.ConnectionConfig, reason error) bool {
effectiveConfig := applyGlobalProxyToConnection(config)
key := getCacheKey(effectiveConfig)
shortKey := shortCacheKey(key)
a.mu.Lock()
defer a.mu.Unlock()
entry, exists := a.dbCache[key]
if !exists || entry.inst == nil {
return false
}
if closeErr := entry.inst.Close(); closeErr != nil {
logger.Error(closeErr, "关闭失效缓存连接失败缓存Key=%s", shortKey)
}
delete(a.dbCache, key)
if reason != nil {
logger.Errorf("检测到连接失效,已清理缓存连接:%s 缓存Key=%s 原因=%s", formatConnSummary(effectiveConfig), shortKey, normalizeErrorMessage(reason))
} else {
logger.Infof("已清理缓存连接:%s 缓存Key=%s", formatConnSummary(effectiveConfig), shortKey)
}
return true
}
func wrapConnectError(config connection.ConnectionConfig, err error) error {
if err == nil {
return nil

View File

@@ -36,6 +36,17 @@ func normalizeSchemaAndTable(config connection.ConnectionConfig, dbName string,
return rawDB, rawTable
}
dbType := strings.ToLower(strings.TrimSpace(config.Type))
if dbType == "sqlserver" {
// SQL Server 的 DB 接口约定第一个参数是数据库名schema 由 tableName(如 dbo.users) 自行解析。
// 不能把 schema(dbo) 传到第一个参数,否则会拼出 dbo.sys.columns 等无效对象名。
targetDB := rawDB
if targetDB == "" {
targetDB = strings.TrimSpace(config.Database)
}
return targetDB, rawTable
}
if parts := strings.SplitN(rawTable, ".", 2); len(parts) == 2 {
schema := strings.TrimSpace(parts[0])
table := strings.TrimSpace(parts[1])
@@ -44,13 +55,10 @@ func normalizeSchemaAndTable(config connection.ConnectionConfig, dbName string,
}
}
switch strings.ToLower(strings.TrimSpace(config.Type)) {
switch dbType {
case "postgres", "kingbase", "highgo", "vastbase":
// PG/金仓/瀚高/海量dbName 在 UI 里是"数据库"schema 需从 tableName 或使用默认 public。
return "public", rawTable
case "sqlserver":
// SQL ServerdbName 表示数据库schema 默认 dbo
return "dbo", rawTable
default:
// MySQLdbName 表示数据库Oracle/达梦dbName 表示 schema/owner。
return rawDB, rawTable

View File

@@ -0,0 +1,51 @@
package app
import (
"testing"
"GoNavi-Wails/internal/connection"
)
func TestNormalizeSchemaAndTable_SQLServerKeepsDatabaseAndQualifiedTable(t *testing.T) {
t.Parallel()
schemaOrDb, table := normalizeSchemaAndTable(connection.ConnectionConfig{
Type: "sqlserver",
Database: "master",
}, "biz_db", "dbo.users")
if schemaOrDb != "biz_db" {
t.Fatalf("expected sqlserver first return value as database name, got %q", schemaOrDb)
}
if table != "dbo.users" {
t.Fatalf("expected sqlserver table name keep qualified form, got %q", table)
}
}
func TestNormalizeSchemaAndTable_SQLServerFallbackToConfigDatabase(t *testing.T) {
t.Parallel()
schemaOrDb, table := normalizeSchemaAndTable(connection.ConnectionConfig{
Type: "sqlserver",
Database: "biz_db",
}, "", "dbo.users")
if schemaOrDb != "biz_db" {
t.Fatalf("expected sqlserver fallback database from config, got %q", schemaOrDb)
}
if table != "dbo.users" {
t.Fatalf("expected sqlserver table name keep qualified form, got %q", table)
}
}
func TestNormalizeSchemaAndTable_PostgresStillSplitsQualifiedName(t *testing.T) {
t.Parallel()
schema, table := normalizeSchemaAndTable(connection.ConnectionConfig{
Type: "postgres",
}, "demo_db", "public.orders")
if schema != "public" || table != "orders" {
t.Fatalf("expected postgres qualified split to public.orders, got %q.%q", schema, table)
}
}

View File

@@ -422,15 +422,36 @@ func (a *App) DBQueryWithCancel(config connection.ConnectionConfig, dbName strin
if !isReadQuery && strings.ToLower(strings.TrimSpace(runConfig.Type)) == "mongodb" && strings.HasPrefix(strings.TrimSpace(query), "{") {
isReadQuery = true
}
if isReadQuery {
var data []map[string]interface{}
var columns []string
if q, ok := dbInst.(interface {
runReadQuery := func(inst db.Database) ([]map[string]interface{}, []string, error) {
if q, ok := inst.(interface {
QueryContext(context.Context, string) ([]map[string]interface{}, []string, error)
}); ok {
data, columns, err = q.QueryContext(ctx, query)
} else {
data, columns, err = dbInst.Query(query)
return q.QueryContext(ctx, query)
}
return inst.Query(query)
}
runExecQuery := func(inst db.Database) (int64, error) {
if e, ok := inst.(interface {
ExecContext(context.Context, string) (int64, error)
}); ok {
return e.ExecContext(ctx, query)
}
return inst.Exec(query)
}
if isReadQuery {
data, columns, err := runReadQuery(dbInst)
if err != nil && shouldRefreshCachedConnection(err) {
if a.invalidateCachedDatabase(runConfig, err) {
retryInst, retryErr := a.getDatabaseForcePing(runConfig)
if retryErr != nil {
logger.Error(retryErr, "DBQuery 重建连接失败:%s SQL片段=%q", formatConnSummary(runConfig), sqlSnippet(query))
return connection.QueryResult{Success: false, Message: retryErr.Error()}
}
data, columns, err = runReadQuery(retryInst)
}
}
if err != nil {
logger.Error(err, "DBQuery 查询失败:%s SQL片段=%q", formatConnSummary(runConfig), sqlSnippet(query))
@@ -438,13 +459,16 @@ func (a *App) DBQueryWithCancel(config connection.ConnectionConfig, dbName strin
}
return connection.QueryResult{Success: true, Data: data, Fields: columns, QueryID: queryID}
} else {
var affected int64
if e, ok := dbInst.(interface {
ExecContext(context.Context, string) (int64, error)
}); ok {
affected, err = e.ExecContext(ctx, query)
} else {
affected, err = dbInst.Exec(query)
affected, err := runExecQuery(dbInst)
if err != nil && shouldRefreshCachedConnection(err) {
if a.invalidateCachedDatabase(runConfig, err) {
retryInst, retryErr := a.getDatabaseForcePing(runConfig)
if retryErr != nil {
logger.Error(retryErr, "DBQuery 重建连接失败:%s SQL片段=%q", formatConnSummary(runConfig), sqlSnippet(query))
return connection.QueryResult{Success: false, Message: retryErr.Error()}
}
affected, err = runExecQuery(retryInst)
}
}
if err != nil {
logger.Error(err, "DBQuery 执行失败:%s SQL片段=%q", formatConnSummary(runConfig), sqlSnippet(query))
@@ -524,15 +548,26 @@ func sqlSnippet(query string) string {
}
func (a *App) DBGetDatabases(config connection.ConnectionConfig) connection.QueryResult {
dbInst, err := a.getDatabase(config)
runConfig := normalizeRunConfig(config, "")
dbInst, err := a.getDatabase(runConfig)
if err != nil {
logger.Error(err, "DBGetDatabases 获取连接失败:%s", formatConnSummary(config))
logger.Error(err, "DBGetDatabases 获取连接失败:%s", formatConnSummary(runConfig))
return connection.QueryResult{Success: false, Message: err.Error()}
}
dbs, err := dbInst.GetDatabases()
if err != nil && shouldRefreshCachedConnection(err) {
if a.invalidateCachedDatabase(runConfig, err) {
retryInst, retryErr := a.getDatabaseForcePing(runConfig)
if retryErr != nil {
logger.Error(retryErr, "DBGetDatabases 重建连接失败:%s", formatConnSummary(runConfig))
return connection.QueryResult{Success: false, Message: retryErr.Error()}
}
dbs, err = retryInst.GetDatabases()
}
}
if err != nil {
logger.Error(err, "DBGetDatabases 获取数据库列表失败:%s", formatConnSummary(config))
logger.Error(err, "DBGetDatabases 获取数据库列表失败:%s", formatConnSummary(runConfig))
return connection.QueryResult{Success: false, Message: err.Error()}
}
@@ -554,6 +589,16 @@ func (a *App) DBGetTables(config connection.ConnectionConfig, dbName string) con
}
tables, err := dbInst.GetTables(dbName)
if err != nil && shouldRefreshCachedConnection(err) {
if a.invalidateCachedDatabase(runConfig, err) {
retryInst, retryErr := a.getDatabaseForcePing(runConfig)
if retryErr != nil {
logger.Error(retryErr, "DBGetTables 重建连接失败:%s", formatConnSummary(runConfig))
return connection.QueryResult{Success: false, Message: retryErr.Error()}
}
tables, err = retryInst.GetTables(dbName)
}
}
if err != nil {
logger.Error(err, "DBGetTables 获取表列表失败:%s", formatConnSummary(runConfig))
return connection.QueryResult{Success: false, Message: err.Error()}

View File

@@ -304,13 +304,33 @@ var driverGoModulePathMap = map[string]string{
"clickhouse": "github.com/ClickHouse/clickhouse-go/v2",
}
var driverGoModuleAliasPathMap = map[string][]string{
"mongodb": {
"go.mongodb.org/mongo-driver",
},
}
var driverExtraHistoryLimitMap = map[string]int{
"mongodb": 10,
}
var fallbackRecentDriverVersionsMap = map[string][]goModuleVersionMeta{
"mongodb": {
{Version: "2.5.0"},
{Version: "2.4.2"},
{Version: "2.4.1"},
{Version: "2.4.0"},
{Version: "2.3.1"},
{Version: "2.3.0"},
{Version: "2.2.3"},
{Version: "1.17.9"},
{Version: "1.17.8"},
{Version: "1.17.7"},
{Version: "1.17.6"},
{Version: "1.17.4"},
{Version: "1.17.3"},
{Version: "1.17.2"},
{Version: "1.17.1"},
{Version: "1.17.0"},
{Version: "1.16.1"},
},
}
@@ -1600,17 +1620,57 @@ func resolveRecentDriverVersionMetas(driverType string, limit int) []goModuleVer
if normalized == "" {
return nil
}
if modulePath := strings.TrimSpace(driverGoModulePathMap[normalized]); modulePath != "" {
if metas := fetchGoModuleVersionMetasCached(modulePath); len(metas) > 0 {
if len(metas) > limit {
return append([]goModuleVersionMeta(nil), metas[:limit]...)
modulePaths := resolveDriverGoModulePaths(normalized)
if len(modulePaths) > 0 {
result := make([]goModuleVersionMeta, 0, limit)
seen := make(map[string]struct{}, limit)
appendUnique := func(values []goModuleVersionMeta, maxAppend int) {
if maxAppend <= 0 {
return
}
return append([]goModuleVersionMeta(nil), metas...)
appended := 0
for _, meta := range values {
version := normalizeVersion(strings.TrimSpace(meta.Version))
if version == "" {
continue
}
key := strings.ToLower(version)
if _, ok := seen[key]; ok {
continue
}
meta.Version = version
result = append(result, meta)
seen[key] = struct{}{}
appended++
if appended >= maxAppend {
return
}
}
}
appendUnique(fetchGoModuleVersionMetasCached(modulePaths[0]), limit)
extraLimit := resolveDriverExtraHistoryLimit(normalized)
for _, modulePath := range modulePaths[1:] {
if extraLimit <= 0 {
break
}
before := len(result)
appendUnique(fetchGoModuleVersionMetasCached(modulePath), extraLimit)
extraLimit -= len(result) - before
}
if len(result) > 0 {
return result
}
}
fallbackLimit := limit + resolveDriverExtraHistoryLimit(normalized)
if fallbackLimit <= 0 {
fallbackLimit = limit
}
if fallback := fallbackRecentDriverVersionsMap[normalized]; len(fallback) > 0 {
if len(fallback) > limit {
return append([]goModuleVersionMeta(nil), fallback[:limit]...)
if len(fallback) > fallbackLimit {
return append([]goModuleVersionMeta(nil), fallback[:fallbackLimit]...)
}
return append([]goModuleVersionMeta(nil), fallback...)
}
@@ -1635,15 +1695,13 @@ func triggerDriverVersionMetadataWarmup(definitions []driverDefinition) {
if driverType == "" || !db.IsOptionalGoDriver(driverType) {
continue
}
modulePath := strings.TrimSpace(driverGoModulePathMap[driverType])
if modulePath == "" {
continue
for _, modulePath := range resolveDriverGoModulePaths(driverType) {
if _, ok := seenModule[modulePath]; ok {
continue
}
seenModule[modulePath] = struct{}{}
modulePaths = append(modulePaths, modulePath)
}
if _, ok := seenModule[modulePath]; ok {
continue
}
seenModule[modulePath] = struct{}{}
modulePaths = append(modulePaths, modulePath)
}
if len(modulePaths) == 0 {
@@ -1663,6 +1721,40 @@ func triggerDriverVersionMetadataWarmup(definitions []driverDefinition) {
}(append([]string(nil), modulePaths...))
}
func resolveDriverGoModulePaths(driverType string) []string {
normalized := normalizeDriverType(driverType)
if normalized == "" {
return nil
}
paths := make([]string, 0, 3)
seen := make(map[string]struct{}, 3)
appendPath := func(path string) {
trimmed := strings.TrimSpace(path)
if trimmed == "" {
return
}
if _, ok := seen[trimmed]; ok {
return
}
seen[trimmed] = struct{}{}
paths = append(paths, trimmed)
}
appendPath(driverGoModulePathMap[normalized])
for _, alias := range driverGoModuleAliasPathMap[normalized] {
appendPath(alias)
}
return paths
}
func resolveDriverExtraHistoryLimit(driverType string) int {
limit := driverExtraHistoryLimitMap[normalizeDriverType(driverType)]
if limit < 0 {
return 0
}
return limit
}
func tryStartDriverVersionMetadataWarmup(now time.Time) bool {
driverVersionWarmupMu.Lock()
defer driverVersionWarmupMu.Unlock()
@@ -2356,16 +2448,23 @@ func hashFileSHA256(filePath string) (string, error) {
func installOptionalDriverAgentPackage(a *App, definition driverDefinition, selectedVersion string, resolvedDir string, downloadURL string) (installedDriverPackage, error) {
driverType := normalizeDriverType(definition.Type)
executablePath, err := db.ResolveOptionalDriverAgentExecutablePath(resolvedDir, driverType)
installPath, err := db.ResolveOptionalDriverAgentExecutablePathForVersion(resolvedDir, driverType, selectedVersion)
if err != nil {
return installedDriverPackage{}, err
}
downloadSource, hash, err := ensureOptionalDriverAgentBinary(a, definition, executablePath, downloadURL)
runtimePath, err := db.ResolveOptionalDriverAgentExecutablePath(resolvedDir, driverType)
if err != nil {
return installedDriverPackage{}, err
}
downloadSource, hash, err := ensureOptionalDriverAgentBinary(a, definition, installPath, downloadURL, selectedVersion)
if err != nil {
return installedDriverPackage{}, err
}
if activateErr := activateOptionalDriverAgentBinary(installPath, runtimePath); activateErr != nil {
return installedDriverPackage{}, fmt.Errorf("activate %s driver agent failed: %w", resolveDriverDisplayName(definition), activateErr)
}
if strings.TrimSpace(hash) == "" {
hash, err = hashFileSHA256(executablePath)
hash, err = hashFileSHA256(installPath)
if err != nil {
return installedDriverPackage{}, fmt.Errorf("计算 %s 驱动代理摘要失败:%w", resolveDriverDisplayName(definition), err)
}
@@ -2376,9 +2475,9 @@ func installOptionalDriverAgentPackage(a *App, definition driverDefinition, sele
return installedDriverPackage{
DriverType: driverType,
Version: strings.TrimSpace(selectedVersion),
FilePath: executablePath,
FileName: filepath.Base(executablePath),
ExecutablePath: executablePath,
FilePath: installPath,
FileName: filepath.Base(installPath),
ExecutablePath: runtimePath,
DownloadURL: strings.TrimSpace(downloadSource),
SHA256: hash,
DownloadedAt: time.Now().Format(time.RFC3339),
@@ -2686,9 +2785,11 @@ func installOptionalDriverAgentFromLocalZip(zipPath string, definition driverDef
return filepath.ToSlash(strings.TrimPrefix(strings.TrimSpace(entry.Name), "./")), nil
}
func ensureOptionalDriverAgentBinary(a *App, definition driverDefinition, executablePath string, downloadURL string) (string, string, error) {
func ensureOptionalDriverAgentBinary(a *App, definition driverDefinition, executablePath string, downloadURL string, selectedVersion string) (string, string, error) {
driverType := normalizeDriverType(definition.Type)
displayName := resolveDriverDisplayName(definition)
forceSourceBuild := shouldForceSourceBuildForVersion(driverType, selectedVersion)
skipReuseCandidate := shouldSkipReusableAgentCandidate(driverType, selectedVersion)
info, err := os.Stat(executablePath)
if err == nil && !info.IsDir() {
@@ -2708,49 +2809,53 @@ func ensureOptionalDriverAgentBinary(a *App, definition driverDefinition, execut
if a != nil {
a.emitDriverDownloadProgress(driverType, "downloading", 10, 100, "检查本地驱动代理缓存")
}
if sourcePath, ok := findExistingOptionalDriverAgentCandidate(definition, executablePath); ok {
if copyErr := copyAgentBinary(sourcePath, executablePath); copyErr != nil {
return "", "", fmt.Errorf("复制预置 %s 驱动代理失败:%w", displayName, copyErr)
if !skipReuseCandidate {
if sourcePath, ok := findExistingOptionalDriverAgentCandidate(definition, executablePath); ok {
if copyErr := copyAgentBinary(sourcePath, executablePath); copyErr != nil {
return "", "", fmt.Errorf("复制预置 %s 驱动代理失败:%w", displayName, copyErr)
}
hash, hashErr := hashFileSHA256(executablePath)
if hashErr != nil {
return "", "", fmt.Errorf("计算预置 %s 驱动代理摘要失败:%w", displayName, hashErr)
}
return "file://" + sourcePath, hash, nil
}
hash, hashErr := hashFileSHA256(executablePath)
if hashErr != nil {
return "", "", fmt.Errorf("计算预置 %s 驱动代理摘要失败:%w", displayName, hashErr)
}
return "file://" + sourcePath, hash, nil
}
downloadURLs := resolveOptionalDriverAgentDownloadURLs(definition, downloadURL)
var downloadErrs []string
if len(downloadURLs) > 0 {
for _, candidateURL := range downloadURLs {
if a != nil {
a.emitDriverDownloadProgress(driverType, "downloading", 20, 100, fmt.Sprintf("下载预编译 %s 驱动代理", displayName))
if !forceSourceBuild {
downloadURLs := resolveOptionalDriverAgentDownloadURLs(definition, downloadURL)
if len(downloadURLs) > 0 {
for _, candidateURL := range downloadURLs {
if a != nil {
a.emitDriverDownloadProgress(driverType, "downloading", 20, 100, fmt.Sprintf("下载预编译 %s 驱动代理", displayName))
}
hash, dlErr := downloadOptionalDriverAgentBinary(a, definition, candidateURL, executablePath)
if dlErr == nil {
return candidateURL, hash, nil
}
downloadErrs = append(downloadErrs, fmt.Sprintf("%s: %s", candidateURL, strings.TrimSpace(dlErr.Error())))
}
hash, dlErr := downloadOptionalDriverAgentBinary(a, definition, candidateURL, executablePath)
if dlErr == nil {
return candidateURL, hash, nil
}
downloadErrs = append(downloadErrs, fmt.Sprintf("%s: %s", candidateURL, strings.TrimSpace(dlErr.Error())))
}
}
bundleURLs := resolveOptionalDriverBundleDownloadURLs()
if len(bundleURLs) > 0 {
for _, bundleURL := range bundleURLs {
if a != nil {
a.emitDriverDownloadProgress(driverType, "downloading", 20, 100, fmt.Sprintf("从驱动总包提取 %s 代理", displayName))
bundleURLs := resolveOptionalDriverBundleDownloadURLs()
if len(bundleURLs) > 0 {
for _, bundleURL := range bundleURLs {
if a != nil {
a.emitDriverDownloadProgress(driverType, "downloading", 20, 100, fmt.Sprintf("从驱动总包提取 %s 代理", displayName))
}
source, hash, bundleErr := downloadOptionalDriverAgentFromBundle(a, definition, bundleURL, executablePath)
if bundleErr == nil {
return source, hash, nil
}
downloadErrs = append(downloadErrs, fmt.Sprintf("%s: %s", bundleURL, strings.TrimSpace(bundleErr.Error())))
}
source, hash, bundleErr := downloadOptionalDriverAgentFromBundle(a, definition, bundleURL, executablePath)
if bundleErr == nil {
return source, hash, nil
}
downloadErrs = append(downloadErrs, fmt.Sprintf("%s: %s", bundleURL, strings.TrimSpace(bundleErr.Error())))
}
}
if a != nil {
a.emitDriverDownloadProgress(driverType, "downloading", 92, 100, "未命中预编译包,尝试开发态本地构建")
}
hash, buildErr := buildOptionalDriverAgentFromSource(definition, executablePath)
hash, buildErr := buildOptionalDriverAgentFromSource(definition, executablePath, selectedVersion)
if buildErr == nil {
return fmt.Sprintf("local://go-build/%s-driver-agent", driverType), hash, nil
}
@@ -2912,7 +3017,7 @@ func downloadOptionalDriverAgentFromBundle(a *App, definition driverDefinition,
return source, hash, nil
}
func buildOptionalDriverAgentFromSource(definition driverDefinition, executablePath string) (string, error) {
func buildOptionalDriverAgentFromSource(definition driverDefinition, executablePath string, selectedVersion string) (string, error) {
driverType := normalizeDriverType(definition.Type)
displayName := resolveDriverDisplayName(definition)
goPath, lookErr := exec.LookPath("go")
@@ -2920,7 +3025,7 @@ func buildOptionalDriverAgentFromSource(definition driverDefinition, executableP
return "", fmt.Errorf("当前环境未安装 Go且未找到可用的 %s 预编译代理包", displayName)
}
tagName, tagErr := optionalDriverBuildTag(driverType)
tagName, tagErr := optionalDriverBuildTag(driverType, selectedVersion)
if tagErr != nil {
return "", tagErr
}
@@ -2931,6 +3036,7 @@ func buildOptionalDriverAgentFromSource(definition driverDefinition, executableP
}
cmd := exec.Command(goPath, "build", "-tags", tagName, "-trimpath", "-ldflags", "-s -w", "-o", executablePath, "./cmd/optional-driver-agent")
cmd.Dir = projectRoot
cmd.Env = append(os.Environ(), "GOTOOLCHAIN=auto")
output, buildErr := cmd.CombinedOutput()
if buildErr != nil {
return "", fmt.Errorf("构建 %s 驱动代理失败:%v输出%s", displayName, buildErr, strings.TrimSpace(string(output)))
@@ -2945,7 +3051,31 @@ func buildOptionalDriverAgentFromSource(definition driverDefinition, executableP
return hash, nil
}
func optionalDriverBuildTag(driverType string) (string, error) {
func resolveMongoDriverMajorFromVersion(version string) int {
trimmed := strings.TrimSpace(version)
trimmed = strings.TrimPrefix(trimmed, "v")
if strings.HasPrefix(trimmed, "1.") || trimmed == "1" {
return 1
}
return 2
}
func shouldForceSourceBuildForVersion(driverType string, selectedVersion string) bool {
if normalizeDriverType(driverType) != "mongodb" {
return false
}
return resolveMongoDriverMajorFromVersion(selectedVersion) == 1
}
func shouldSkipReusableAgentCandidate(driverType string, selectedVersion string) bool {
if normalizeDriverType(driverType) != "mongodb" {
return false
}
_ = selectedVersion
return true
}
func optionalDriverBuildTag(driverType string, selectedVersion string) (string, error) {
switch normalizeDriverType(driverType) {
case "mysql":
return "gonavi_mysql_driver", nil
@@ -2970,6 +3100,9 @@ func optionalDriverBuildTag(driverType string) (string, error) {
case "vastbase":
return "gonavi_vastbase_driver", nil
case "mongodb":
if resolveMongoDriverMajorFromVersion(selectedVersion) == 1 {
return "gonavi_mongodb_driver_v1", nil
}
return "gonavi_mongodb_driver", nil
case "tdengine":
return "gonavi_tdengine_driver", nil
@@ -3310,6 +3443,30 @@ func resolveDriverDisplayName(definition driverDefinition) string {
return "未知"
}
func activateOptionalDriverAgentBinary(installPath string, runtimePath string) error {
source := strings.TrimSpace(installPath)
target := strings.TrimSpace(runtimePath)
if source == "" || target == "" {
return fmt.Errorf("agent path is empty")
}
if source == target {
return nil
}
absSource := source
absTarget := target
if value, err := filepath.Abs(source); err == nil && strings.TrimSpace(value) != "" {
absSource = value
}
if value, err := filepath.Abs(target); err == nil && strings.TrimSpace(value) != "" {
absTarget = value
}
if strings.EqualFold(absSource, absTarget) {
return nil
}
return copyAgentBinary(source, target)
}
func copyAgentBinary(sourcePath, targetPath string) error {
src, err := os.Open(sourcePath)
if err != nil {

View File

@@ -27,6 +27,10 @@ type ConnectionConfig struct {
Password string `json:"password"`
SavePassword bool `json:"savePassword,omitempty"` // Persist password in saved connection
Database string `json:"database"`
UseSSL bool `json:"useSSL,omitempty"` // MySQL-like SSL/TLS switch
SSLMode string `json:"sslMode,omitempty"` // preferred | required | skip-verify | disable
SSLCertPath string `json:"sslCertPath,omitempty"` // TLS client certificate path (e.g., Dameng)
SSLKeyPath string `json:"sslKeyPath,omitempty"` // TLS client private key path (e.g., Dameng)
UseSSH bool `json:"useSSH"`
SSH SSHConfig `json:"ssh"`
UseProxy bool `json:"useProxy,omitempty"`

View File

@@ -107,7 +107,7 @@ func (c *ClickHouseDB) buildClickHouseOptions(config connection.ConnectionConfig
if readTimeout < minClickHouseReadTimeout {
readTimeout = minClickHouseReadTimeout
}
return &clickhouse.Options{
opts := &clickhouse.Options{
Addr: []string{
net.JoinHostPort(config.Host, strconv.Itoa(config.Port)),
},
@@ -119,6 +119,10 @@ func (c *ClickHouseDB) buildClickHouseOptions(config connection.ConnectionConfig
DialTimeout: connectTimeout,
ReadTimeout: readTimeout,
}
if tlsConfig := resolveGenericTLSConfig(config); tlsConfig != nil {
opts.TLS = tlsConfig
}
return opts
}
func (c *ClickHouseDB) Connect(config connection.ConnectionConfig) error {
@@ -165,13 +169,30 @@ func (c *ClickHouseDB) Connect(config connection.ConnectionConfig) error {
logger.Infof("ClickHouse 通过本地端口转发连接:%s -> %s:%d", forwarder.LocalAddr, config.Host, config.Port)
}
c.conn = clickhouse.OpenDB(c.buildClickHouseOptions(runConfig))
if err := c.Ping(); err != nil {
_ = c.Close()
return fmt.Errorf("连接建立后验证失败:%w", err)
attempts := []connection.ConnectionConfig{runConfig}
if shouldTrySSLPreferredFallback(runConfig) {
attempts = append(attempts, withSSLDisabled(runConfig))
}
return nil
var failures []string
for idx, attempt := range attempts {
c.conn = clickhouse.OpenDB(c.buildClickHouseOptions(attempt))
if err := c.Ping(); err != nil {
failures = append(failures, fmt.Sprintf("第%d次连接验证失败: %v", idx+1, err))
if c.conn != nil {
_ = c.conn.Close()
c.conn = nil
}
continue
}
if idx > 0 {
logger.Warnf("ClickHouse SSL 优先连接失败,已回退至明文连接")
}
return nil
}
_ = c.Close()
return fmt.Errorf("连接建立后验证失败:%s", strings.Join(failures, ""))
}
func (c *ClickHouseDB) Close() error {

View File

@@ -36,6 +36,14 @@ func (d *DamengDB) getDSN(config connection.ConnectionConfig) string {
if config.Database != "" {
q.Set("schema", config.Database)
}
if config.UseSSL {
if certPath := strings.TrimSpace(config.SSLCertPath); certPath != "" {
q.Set("SSL_CERT_PATH", certPath)
}
if keyPath := strings.TrimSpace(config.SSLKeyPath); keyPath != "" {
q.Set("SSL_KEY_PATH", keyPath)
}
}
if escapedPassword != config.Password {
// 达梦驱动要求密码包含特殊字符时password 需 PathEscape并添加 escapeProcess=true 让驱动解码。
q.Set("escapeProcess", "true")
@@ -50,8 +58,12 @@ func (d *DamengDB) getDSN(config connection.ConnectionConfig) string {
}
func (d *DamengDB) Connect(config connection.ConnectionConfig) error {
var dsn string
var err error
runConfig := config
if runConfig.UseSSL {
if strings.TrimSpace(runConfig.SSLCertPath) == "" || strings.TrimSpace(runConfig.SSLKeyPath) == "" {
return fmt.Errorf("达梦启用 SSL 需要同时配置证书路径(sslCertPath)与私钥路径(sslKeyPath)")
}
}
if config.UseSSH {
// Create SSH tunnel with local port forwarding
@@ -80,22 +92,37 @@ func (d *DamengDB) Connect(config connection.ConnectionConfig) error {
localConfig.Port = port
localConfig.UseSSH = false
dsn = d.getDSN(localConfig)
runConfig = localConfig
logger.Infof("达梦数据库通过本地端口转发连接:%s -> %s:%d", forwarder.LocalAddr, config.Host, config.Port)
} else {
dsn = d.getDSN(config)
}
db, err := sql.Open("dm", dsn)
if err != nil {
return fmt.Errorf("打开数据库连接失败:%w", err)
attempts := []connection.ConnectionConfig{runConfig}
if shouldTrySSLPreferredFallback(runConfig) {
attempts = append(attempts, withSSLDisabled(runConfig))
}
d.conn = db
d.pingTimeout = getConnectTimeout(config)
if err := d.Ping(); err != nil {
return fmt.Errorf("连接建立后验证失败:%w", err)
var failures []string
for idx, attempt := range attempts {
dsn := d.getDSN(attempt)
db, err := sql.Open("dm", dsn)
if err != nil {
failures = append(failures, fmt.Sprintf("第%d次连接打开失败: %v", idx+1, err))
continue
}
d.conn = db
d.pingTimeout = getConnectTimeout(attempt)
if err := d.Ping(); err != nil {
_ = db.Close()
d.conn = nil
failures = append(failures, fmt.Sprintf("第%d次连接验证失败: %v", idx+1, err))
continue
}
if idx > 0 {
logger.Warnf("达梦 SSL 优先连接失败,已回退至明文连接")
}
return nil
}
return nil
return fmt.Errorf("连接建立后验证失败:%s", strings.Join(failures, ""))
}
func (d *DamengDB) Close() error {

View File

@@ -151,9 +151,10 @@ func (d *DirosDB) getDSN(config connection.ConnectionConfig) string {
}
timeout := getConnectTimeoutSeconds(config)
tlsMode := resolveMySQLTLSMode(config)
return fmt.Sprintf("%s:%s@%s(%s)/%s?charset=utf8mb4&parseTime=True&loc=Local&timeout=%ds",
config.User, config.Password, protocol, address, database, timeout)
return fmt.Sprintf("%s:%s@%s(%s)/%s?charset=utf8mb4&parseTime=True&loc=Local&timeout=%ds&tls=%s",
config.User, config.Password, protocol, address, database, timeout, url.QueryEscape(tlsMode))
}
func resolveDirosCredential(config connection.ConnectionConfig, addressIndex int) (string, string) {

View File

@@ -33,6 +33,44 @@ func TestPostgresDSN_EscapesPassword(t *testing.T) {
}
}
func TestPostgresDSN_SSLModeRequireWhenEnabled(t *testing.T) {
p := &PostgresDB{}
cfg := connection.ConnectionConfig{
Type: "postgres",
Host: "127.0.0.1",
Port: 5432,
User: "user",
Password: "pass",
Database: "db",
UseSSL: true,
SSLMode: "required",
}
dsn := p.getDSN(cfg)
if !strings.Contains(dsn, "sslmode=require") {
t.Fatalf("dsn 缺少 sslmode=require 参数:%s", dsn)
}
}
func TestMySQLDSN_UsesTLSParamWhenSSLEnabled(t *testing.T) {
m := &MySQLDB{}
cfg := connection.ConnectionConfig{
Type: "mysql",
Host: "127.0.0.1",
Port: 3306,
User: "root",
Password: "pass",
Database: "db",
UseSSL: true,
SSLMode: "required",
}
dsn := m.getDSN(cfg)
if !strings.Contains(dsn, "tls=true") {
t.Fatalf("dsn 缺少 tls=true 参数:%s", dsn)
}
}
func TestOracleDSN_EscapesUserAndPassword(t *testing.T) {
o := &OracleDB{}
cfg := connection.ConnectionConfig{
@@ -82,6 +120,30 @@ func TestDamengDSN_EscapesPasswordAndEnablesEscapeProcess(t *testing.T) {
}
}
func TestDamengDSN_AppendsSSLCertAndKeyParams(t *testing.T) {
d := &DamengDB{}
cfg := connection.ConnectionConfig{
Type: "dameng",
Host: "127.0.0.1",
Port: 5236,
User: "SYSDBA",
Password: "pass",
Database: "DBName",
UseSSL: true,
SSLMode: "required",
SSLCertPath: "C:\\certs\\client-cert.pem",
SSLKeyPath: "C:\\certs\\client-key.pem",
}
dsn := d.getDSN(cfg)
if !strings.Contains(dsn, "SSL_CERT_PATH=") {
t.Fatalf("dsn 缺少 SSL_CERT_PATH 参数:%s", dsn)
}
if !strings.Contains(dsn, "SSL_KEY_PATH=") {
t.Fatalf("dsn 缺少 SSL_KEY_PATH 参数:%s", dsn)
}
}
func TestKingbaseDSN_QuotesPasswordWithSpaces(t *testing.T) {
k := &KingbaseDB{}
cfg := connection.ConnectionConfig{
@@ -116,6 +178,47 @@ func TestTDengineDSN_UsesWebSocketFormat(t *testing.T) {
}
}
func TestTDengineDSN_UsesSecureWebSocketWhenSSLEnabled(t *testing.T) {
td := &TDengineDB{}
cfg := connection.ConnectionConfig{
Type: "tdengine",
Host: "127.0.0.1",
Port: 6041,
User: "root",
Password: "taosdata",
Database: "power",
UseSSL: true,
SSLMode: "required",
}
dsn := td.getDSN(cfg)
if !strings.HasPrefix(dsn, "root:taosdata@wss(127.0.0.1:6041)/power") {
t.Fatalf("tdengine ssl dsn 格式不正确:%s", dsn)
}
}
func TestSQLServerDSN_EncryptMapping(t *testing.T) {
s := &SqlServerDB{}
cfg := connection.ConnectionConfig{
Type: "sqlserver",
Host: "127.0.0.1",
Port: 1433,
User: "sa",
Password: "pass",
Database: "master",
UseSSL: true,
SSLMode: "required",
}
dsn := s.getDSN(cfg)
if !strings.Contains(strings.ToLower(dsn), "encrypt=true") {
t.Fatalf("sqlserver dsn 缺少 encrypt=true%s", dsn)
}
if !strings.Contains(strings.ToLower(dsn), "trustservercertificate=false") {
t.Fatalf("sqlserver dsn 缺少 TrustServerCertificate=false%s", dsn)
}
}
func TestClickHouseOptions_UsesStructuredTimeoutAndAuth(t *testing.T) {
c := &ClickHouseDB{}
cfg := normalizeClickHouseConfig(connection.ConnectionConfig{

View File

@@ -42,7 +42,7 @@ func (h *HighGoDB) getDSN(config connection.ConnectionConfig) string {
}
u.User = url.UserPassword(config.User, config.Password)
q := url.Values{}
q.Set("sslmode", "disable")
q.Set("sslmode", resolvePostgresSSLMode(config))
q.Set("connect_timeout", strconv.Itoa(getConnectTimeoutSeconds(config)))
u.RawQuery = q.Encode()
@@ -50,7 +50,7 @@ func (h *HighGoDB) getDSN(config connection.ConnectionConfig) string {
}
func (h *HighGoDB) Connect(config connection.ConnectionConfig) error {
var dsn string
runConfig := config
if config.UseSSH {
logger.Infof("HighGo 使用 SSH 连接:地址=%s:%d 用户=%s", config.Host, config.Port, config.User)
@@ -76,23 +76,37 @@ func (h *HighGoDB) Connect(config connection.ConnectionConfig) error {
localConfig.Port = port
localConfig.UseSSH = false
dsn = h.getDSN(localConfig)
runConfig = localConfig
logger.Infof("HighGo 通过本地端口转发连接:%s -> %s:%d", forwarder.LocalAddr, config.Host, config.Port)
} else {
dsn = h.getDSN(config)
}
db, err := sql.Open("highgo", dsn)
if err != nil {
return fmt.Errorf("打开数据库连接失败:%w", err)
attempts := []connection.ConnectionConfig{runConfig}
if shouldTrySSLPreferredFallback(runConfig) {
attempts = append(attempts, withSSLDisabled(runConfig))
}
h.conn = db
h.pingTimeout = getConnectTimeout(config)
if err := h.Ping(); err != nil {
return fmt.Errorf("连接建立后验证失败:%w", err)
var failures []string
for idx, attempt := range attempts {
dsn := h.getDSN(attempt)
db, err := sql.Open("highgo", dsn)
if err != nil {
failures = append(failures, fmt.Sprintf("第%d次连接打开失败: %v", idx+1, err))
continue
}
h.conn = db
h.pingTimeout = getConnectTimeout(attempt)
if err := h.Ping(); err != nil {
_ = db.Close()
h.conn = nil
failures = append(failures, fmt.Sprintf("第%d次连接验证失败: %v", idx+1, err))
continue
}
if idx > 0 {
logger.Warnf("HighGo SSL 优先连接失败,已回退至明文连接")
}
return nil
}
return nil
return fmt.Errorf("连接建立后验证失败:%s", strings.Join(failures, ""))
}
func (h *HighGoDB) Close() error {

View File

@@ -65,12 +65,13 @@ func (k *KingbaseDB) getDSN(config connection.ConnectionConfig) string {
port := config.Port
// Construct DSN
dsn := fmt.Sprintf("host=%s port=%d user=%s password=%s dbname=%s sslmode=disable connect_timeout=%d",
dsn := fmt.Sprintf("host=%s port=%d user=%s password=%s dbname=%s sslmode=%s connect_timeout=%d",
quoteConnValue(address),
port,
quoteConnValue(config.User),
quoteConnValue(config.Password),
quoteConnValue(config.Database),
quoteConnValue(resolvePostgresSSLMode(config)),
getConnectTimeoutSeconds(config),
)
@@ -78,8 +79,7 @@ func (k *KingbaseDB) getDSN(config connection.ConnectionConfig) string {
}
func (k *KingbaseDB) Connect(config connection.ConnectionConfig) error {
var dsn string
var err error
runConfig := config
if config.UseSSH {
// Create SSH tunnel with local port forwarding
@@ -108,23 +108,37 @@ func (k *KingbaseDB) Connect(config connection.ConnectionConfig) error {
localConfig.Port = port
localConfig.UseSSH = false
dsn = k.getDSN(localConfig)
runConfig = localConfig
logger.Infof("人大金仓通过本地端口转发连接:%s -> %s:%d", forwarder.LocalAddr, config.Host, config.Port)
} else {
dsn = k.getDSN(config)
}
// Open using "kingbase" driver
db, err := sql.Open("kingbase", dsn)
if err != nil {
return fmt.Errorf("打开数据库连接失败:%w", err)
attempts := []connection.ConnectionConfig{runConfig}
if shouldTrySSLPreferredFallback(runConfig) {
attempts = append(attempts, withSSLDisabled(runConfig))
}
k.conn = db
k.pingTimeout = getConnectTimeout(config)
if err := k.Ping(); err != nil {
return fmt.Errorf("连接建立后验证失败:%w", err)
var failures []string
for idx, attempt := range attempts {
dsn := k.getDSN(attempt)
db, err := sql.Open("kingbase", dsn)
if err != nil {
failures = append(failures, fmt.Sprintf("第%d次连接打开失败: %v", idx+1, err))
continue
}
k.conn = db
k.pingTimeout = getConnectTimeout(attempt)
if err := k.Ping(); err != nil {
_ = db.Close()
k.conn = nil
failures = append(failures, fmt.Sprintf("第%d次连接验证失败: %v", idx+1, err))
continue
}
if idx > 0 {
logger.Warnf("人大金仓 SSL 优先连接失败,已回退至明文连接")
}
return nil
}
return nil
return fmt.Errorf("连接建立后验证失败:%s", strings.Join(failures, ""))
}
func (k *KingbaseDB) Close() error {

View File

@@ -6,6 +6,7 @@ import (
"context"
"database/sql"
"fmt"
"net/url"
"strings"
"time"
@@ -40,9 +41,10 @@ func (m *MariaDB) getDSN(config connection.ConnectionConfig) string {
}
timeout := getConnectTimeoutSeconds(config)
tlsMode := resolveMySQLTLSMode(config)
return fmt.Sprintf("%s:%s@%s(%s)/%s?charset=utf8mb4&parseTime=True&loc=Local&timeout=%ds",
config.User, config.Password, protocol, address, database, timeout)
return fmt.Sprintf("%s:%s@%s(%s)/%s?charset=utf8mb4&parseTime=True&loc=Local&timeout=%ds&tls=%s",
config.User, config.Password, protocol, address, database, timeout, url.QueryEscape(tlsMode))
}
func (m *MariaDB) Connect(config connection.ConnectionConfig) error {

View File

@@ -4,6 +4,7 @@ package db
import (
"context"
"crypto/tls"
"fmt"
"net"
"net/url"
@@ -327,35 +328,60 @@ func (m *MongoDB) Connect(config connection.ConnectionConfig) error {
m.database = "admin"
}
attemptConfigs := buildMongoAuthAttempts(connectConfig)
sslAttempts := []connection.ConnectionConfig{connectConfig}
if shouldTrySSLPreferredFallback(connectConfig) {
sslAttempts = append(sslAttempts, withSSLDisabled(connectConfig))
}
var errorDetails []string
for index, attemptConfig := range attemptConfigs {
authLabel := "主库凭据"
if index > 0 {
authLabel = "从库凭据"
for sslIndex, sslConfig := range sslAttempts {
sslLabel := "SSL"
if sslIndex > 0 {
sslLabel = "明文回退"
}
uri := m.getURI(attemptConfig)
clientOpts := options.Client().ApplyURI(uri)
if attemptConfig.UseProxy {
clientOpts.SetDialer(&mongoProxyDialer{proxyConfig: attemptConfig.Proxy})
}
client, err := mongo.Connect(clientOpts)
if err != nil {
errorDetails = append(errorDetails, fmt.Sprintf("%s连接失败: %v", authLabel, err))
continue
}
attemptConfigs := buildMongoAuthAttempts(sslConfig)
for index, attemptConfig := range attemptConfigs {
authLabel := "主库凭据"
if index > 0 {
authLabel = "从库凭据"
}
m.client = client
if err := m.Ping(); err != nil {
ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second)
_ = client.Disconnect(ctx)
cancel()
m.client = nil
errorDetails = append(errorDetails, fmt.Sprintf("%s验证失败: %v", authLabel, err))
continue
if sslIndex > 0 {
attemptConfig.URI = ""
}
uri := m.getURI(attemptConfig)
clientOpts := options.Client().ApplyURI(uri)
tlsEnabled, tlsInsecure := resolveMongoTLSSettings(attemptConfig)
if tlsEnabled {
clientOpts.SetTLSConfig(&tls.Config{
MinVersion: tls.VersionTLS12,
InsecureSkipVerify: tlsInsecure,
})
}
if attemptConfig.UseProxy {
clientOpts.SetDialer(&mongoProxyDialer{proxyConfig: attemptConfig.Proxy})
}
client, err := mongo.Connect(clientOpts)
if err != nil {
errorDetails = append(errorDetails, fmt.Sprintf("%s %s连接失败: %v", sslLabel, authLabel, err))
continue
}
m.client = client
if err := m.Ping(); err != nil {
ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second)
_ = client.Disconnect(ctx)
cancel()
m.client = nil
errorDetails = append(errorDetails, fmt.Sprintf("%s %s验证失败: %v", sslLabel, authLabel, err))
continue
}
if sslIndex > 0 {
logger.Warnf("MongoDB SSL 优先连接失败,已回退至明文连接")
}
return nil
}
return nil
}
if len(errorDetails) > 0 {

File diff suppressed because it is too large Load Diff

View File

@@ -35,6 +35,32 @@ func ResolveOptionalDriverAgentExecutablePath(downloadDir string, driverType str
return filepath.Join(root, normalized, optionalDriverAgentExecutableName(normalized)), nil
}
func ResolveOptionalDriverAgentExecutablePathForVersion(downloadDir string, driverType string, version string) (string, error) {
normalized := normalizeRuntimeDriverType(driverType)
if strings.TrimSpace(normalized) == "" {
return "", fmt.Errorf("驱动类型为空")
}
root, err := resolveExternalDriverRoot(downloadDir)
if err != nil {
return "", err
}
if normalized != "mongodb" {
return filepath.Join(root, normalized, optionalDriverAgentExecutableName(normalized)), nil
}
baseName := optionalDriverAgentExecutableName(normalized)
ext := filepath.Ext(baseName)
stem := strings.TrimSuffix(baseName, ext)
major := 2
trimmed := strings.TrimSpace(version)
trimmed = strings.TrimPrefix(trimmed, "v")
if strings.HasPrefix(trimmed, "1.") || trimmed == "1" {
major = 1
}
versionedName := fmt.Sprintf("%s-v%d%s", stem, major, ext)
return filepath.Join(root, normalized, versionedName), nil
}
func ResolveMySQLAgentExecutablePath(downloadDir string) (string, error) {
return ResolveOptionalDriverAgentExecutablePath(downloadDir, "mysql")
}

View File

@@ -184,9 +184,10 @@ func (m *MySQLDB) getDSN(config connection.ConnectionConfig) string {
}
timeout := getConnectTimeoutSeconds(config)
tlsMode := resolveMySQLTLSMode(config)
return fmt.Sprintf("%s:%s@%s(%s)/%s?charset=utf8mb4&parseTime=True&loc=Local&timeout=%ds",
config.User, config.Password, protocol, address, database, timeout)
return fmt.Sprintf("%s:%s@%s(%s)/%s?charset=utf8mb4&parseTime=True&loc=Local&timeout=%ds&tls=%s",
config.User, config.Password, protocol, address, database, timeout, url.QueryEscape(tlsMode))
}
func resolveMySQLCredential(config connection.ConnectionConfig, addressIndex int) (string, string) {

View File

@@ -35,12 +35,23 @@ func (o *OracleDB) getDSN(config connection.ConnectionConfig) string {
}
u.User = url.UserPassword(config.User, config.Password)
u.RawPath = "/" + url.PathEscape(database)
q := url.Values{}
switch normalizedSSLMode(config) {
case sslModeRequired:
q.Set("SSL", "TRUE")
q.Set("SSL VERIFY", "TRUE")
case sslModeSkipVerify, sslModePreferred:
q.Set("SSL", "TRUE")
q.Set("SSL VERIFY", "FALSE")
}
if encoded := q.Encode(); encoded != "" {
u.RawQuery = encoded
}
return u.String()
}
func (o *OracleDB) Connect(config connection.ConnectionConfig) error {
var dsn string
var err error
runConfig := config
serviceName := strings.TrimSpace(config.Database)
if serviceName == "" {
return fmt.Errorf("Oracle 连接缺少服务名Service Name请在连接配置中填写例如 ORCLPDB1")
@@ -73,22 +84,37 @@ func (o *OracleDB) Connect(config connection.ConnectionConfig) error {
localConfig.Port = port
localConfig.UseSSH = false
dsn = o.getDSN(localConfig)
runConfig = localConfig
logger.Infof("Oracle 通过本地端口转发连接:%s -> %s:%d", forwarder.LocalAddr, config.Host, config.Port)
} else {
dsn = o.getDSN(config)
}
db, err := sql.Open("oracle", dsn)
if err != nil {
return fmt.Errorf("打开数据库连接失败:%w", err)
attempts := []connection.ConnectionConfig{runConfig}
if shouldTrySSLPreferredFallback(runConfig) {
attempts = append(attempts, withSSLDisabled(runConfig))
}
o.conn = db
o.pingTimeout = getConnectTimeout(config)
if err := o.Ping(); err != nil {
return fmt.Errorf("连接建立后验证失败:%w", err)
var failures []string
for idx, attempt := range attempts {
dsn := o.getDSN(attempt)
db, err := sql.Open("oracle", dsn)
if err != nil {
failures = append(failures, fmt.Sprintf("第%d次连接打开失败: %v", idx+1, err))
continue
}
o.conn = db
o.pingTimeout = getConnectTimeout(attempt)
if err := o.Ping(); err != nil {
_ = db.Close()
o.conn = nil
failures = append(failures, fmt.Sprintf("第%d次连接验证失败: %v", idx+1, err))
continue
}
if idx > 0 {
logger.Warnf("Oracle SSL 优先连接失败,已回退至明文连接")
}
return nil
}
return nil
return fmt.Errorf("连接建立后验证失败:%s", strings.Join(failures, ""))
}
func (o *OracleDB) Close() error {

View File

@@ -62,7 +62,7 @@ func (p *PostgresDB) getDSN(config connection.ConnectionConfig) string {
}
u.User = url.UserPassword(config.User, config.Password)
q := url.Values{}
q.Set("sslmode", "disable")
q.Set("sslmode", resolvePostgresSSLMode(config))
q.Set("connect_timeout", strconv.Itoa(getConnectTimeoutSeconds(config)))
u.RawQuery = q.Encode()
@@ -126,34 +126,49 @@ func (p *PostgresDB) Connect(config connection.ConnectionConfig) error {
logger.Infof("PostgreSQL 通过本地端口转发连接:%s -> %s:%d", forwarder.LocalAddr, config.Host, config.Port)
}
attemptDBs := resolvePostgresConnectDatabases(runConfig)
sslAttempts := []connection.ConnectionConfig{runConfig}
if shouldTrySSLPreferredFallback(runConfig) {
sslAttempts = append(sslAttempts, withSSLDisabled(runConfig))
}
var failures []string
for _, dbName := range attemptDBs {
attemptConfig := runConfig
attemptConfig.Database = dbName
dsn := p.getDSN(attemptConfig)
dbConn, err := sql.Open("postgres", dsn)
if err != nil {
failures = append(failures, fmt.Sprintf("数据库=%s 打开连接失败: %v", dbName, err))
continue
}
p.conn = dbConn
// Force verification
if err := p.Ping(); err != nil {
failures = append(failures, fmt.Sprintf("数据库=%s 验证失败: %v", dbName, err))
_ = dbConn.Close()
p.conn = nil
continue
for sslIndex, sslConfig := range sslAttempts {
sslLabel := "SSL"
if sslIndex > 0 {
sslLabel = "明文回退"
}
if strings.TrimSpace(config.Database) == "" && !strings.EqualFold(dbName, "postgres") {
logger.Infof("PostgreSQL 自动选择连接数据库:%s", dbName)
}
attemptDBs := resolvePostgresConnectDatabases(sslConfig)
for _, dbName := range attemptDBs {
attemptConfig := sslConfig
attemptConfig.Database = dbName
dsn := p.getDSN(attemptConfig)
cleanupOnFailure = false
return nil
dbConn, err := sql.Open("postgres", dsn)
if err != nil {
failures = append(failures, fmt.Sprintf("%s 数据库=%s 打开连接失败: %v", sslLabel, dbName, err))
continue
}
p.conn = dbConn
// Force verification
if err := p.Ping(); err != nil {
failures = append(failures, fmt.Sprintf("%s 数据库=%s 验证失败: %v", sslLabel, dbName, err))
_ = dbConn.Close()
p.conn = nil
continue
}
if sslIndex > 0 {
logger.Warnf("PostgreSQL SSL 优先连接失败,已回退至明文连接")
}
if strings.TrimSpace(config.Database) == "" && !strings.EqualFold(dbName, "postgres") {
logger.Infof("PostgreSQL 自动选择连接数据库:%s", dbName)
}
cleanupOnFailure = false
return nil
}
}
if len(failures) == 0 {

View File

@@ -47,8 +47,9 @@ func (s *SqlServerDB) getDSN(config connection.ConnectionConfig) string {
q := url.Values{}
q.Set("database", dbname)
q.Set("connection timeout", strconv.Itoa(getConnectTimeoutSeconds(config)))
q.Set("encrypt", "disable")
q.Set("TrustServerCertificate", "true")
encrypt, trustServerCertificate := resolveSQLServerTLSSettings(config)
q.Set("encrypt", encrypt)
q.Set("TrustServerCertificate", trustServerCertificate)
u.RawQuery = q.Encode()
return u.String()

122
internal/db/ssl_mode.go Normal file
View File

@@ -0,0 +1,122 @@
package db
import (
"crypto/tls"
"strings"
"GoNavi-Wails/internal/connection"
)
const (
sslModeDisable = "disable"
sslModePreferred = "preferred"
sslModeRequired = "required"
sslModeSkipVerify = "skip-verify"
)
func normalizeSSLModeValue(raw string) string {
mode := strings.ToLower(strings.TrimSpace(raw))
switch mode {
case "", sslModePreferred, "prefer":
return sslModePreferred
case sslModeRequired, "require", "on", "true", "mandatory", "strict":
return sslModeRequired
case sslModeSkipVerify, "insecure", "skipverify", "skip_verify", "insecure-skip-verify":
return sslModeSkipVerify
case sslModeDisable, "disabled", "off", "false", "none":
return sslModeDisable
default:
return sslModePreferred
}
}
func normalizedSSLMode(config connection.ConnectionConfig) string {
if !config.UseSSL {
return sslModeDisable
}
return normalizeSSLModeValue(config.SSLMode)
}
func shouldTrySSLPreferredFallback(config connection.ConnectionConfig) bool {
return config.UseSSL && normalizeSSLModeValue(config.SSLMode) == sslModePreferred
}
func withSSLDisabled(config connection.ConnectionConfig) connection.ConnectionConfig {
next := config
next.UseSSL = false
next.SSLMode = sslModeDisable
return next
}
func resolveMySQLTLSMode(config connection.ConnectionConfig) string {
switch normalizedSSLMode(config) {
case sslModeDisable:
return "false"
case sslModeRequired:
return "true"
case sslModeSkipVerify:
return "skip-verify"
default:
return "preferred"
}
}
func resolvePostgresSSLMode(config connection.ConnectionConfig) string {
switch normalizedSSLMode(config) {
case sslModeDisable:
return "disable"
case sslModeRequired:
return "require"
case sslModeSkipVerify:
return "require"
default:
return "require"
}
}
func resolveSQLServerTLSSettings(config connection.ConnectionConfig) (encrypt string, trustServerCertificate string) {
switch normalizedSSLMode(config) {
case sslModeDisable:
return "disable", "true"
case sslModeRequired:
return "true", "false"
case sslModeSkipVerify:
return "true", "true"
default:
return "false", "true"
}
}
func resolveGenericTLSConfig(config connection.ConnectionConfig) *tls.Config {
switch normalizedSSLMode(config) {
case sslModeDisable:
return nil
case sslModeRequired:
return &tls.Config{MinVersion: tls.VersionTLS12}
case sslModeSkipVerify:
return &tls.Config{MinVersion: tls.VersionTLS12, InsecureSkipVerify: true}
default:
// Preferred: 先尝试 TLS为提升兼容性默认跳过证书校验失败时由调用方按需回退明文。
return &tls.Config{MinVersion: tls.VersionTLS12, InsecureSkipVerify: true}
}
}
func resolveMongoTLSSettings(config connection.ConnectionConfig) (enabled bool, insecure bool) {
switch normalizedSSLMode(config) {
case sslModeDisable:
return false, false
case sslModeRequired:
return true, false
case sslModeSkipVerify:
return true, true
default:
return true, true
}
}
func resolveTDengineNet(config connection.ConnectionConfig) string {
if normalizedSSLMode(config) == sslModeDisable {
return "ws"
}
return "wss"
}

View File

@@ -40,11 +40,12 @@ func (t *TDengineDB) getDSN(config connection.ConnectionConfig) string {
path = "/" + dbName
}
return fmt.Sprintf("%s:%s@ws(%s)%s", user, pass, net.JoinHostPort(config.Host, strconv.Itoa(config.Port)), path)
netType := resolveTDengineNet(config)
return fmt.Sprintf("%s:%s@%s(%s)%s", user, pass, netType, net.JoinHostPort(config.Host, strconv.Itoa(config.Port)), path)
}
func (t *TDengineDB) Connect(config connection.ConnectionConfig) error {
var dsn string
runConfig := config
if config.UseSSH {
logger.Infof("TDengine 使用 SSH 连接:地址=%s:%d 用户=%s", config.Host, config.Port, config.User)
@@ -68,23 +69,38 @@ func (t *TDengineDB) Connect(config connection.ConnectionConfig) error {
localConfig.Host = host
localConfig.Port = port
localConfig.UseSSH = false
dsn = t.getDSN(localConfig)
runConfig = localConfig
logger.Infof("TDengine 通过本地端口转发连接:%s -> %s:%d", forwarder.LocalAddr, config.Host, config.Port)
} else {
dsn = t.getDSN(config)
}
db, err := sql.Open("taosWS", dsn)
if err != nil {
return fmt.Errorf("打开数据库连接失败:%w", err)
attempts := []connection.ConnectionConfig{runConfig}
if shouldTrySSLPreferredFallback(runConfig) {
attempts = append(attempts, withSSLDisabled(runConfig))
}
t.conn = db
t.pingTimeout = getConnectTimeout(config)
if err := t.Ping(); err != nil {
return fmt.Errorf("连接建立后验证失败:%w", err)
var failures []string
for idx, attempt := range attempts {
dsn := t.getDSN(attempt)
db, err := sql.Open("taosWS", dsn)
if err != nil {
failures = append(failures, fmt.Sprintf("第%d次连接打开失败: %v", idx+1, err))
continue
}
t.conn = db
t.pingTimeout = getConnectTimeout(attempt)
if err := t.Ping(); err != nil {
_ = db.Close()
t.conn = nil
failures = append(failures, fmt.Sprintf("第%d次连接验证失败: %v", idx+1, err))
continue
}
if idx > 0 {
logger.Warnf("TDengine SSL 优先连接失败,已回退至明文连接")
}
return nil
}
return nil
return fmt.Errorf("连接建立后验证失败:%s", strings.Join(failures, ""))
}
func (t *TDengineDB) Close() error {

View File

@@ -41,7 +41,7 @@ func (v *VastbaseDB) getDSN(config connection.ConnectionConfig) string {
}
u.User = url.UserPassword(config.User, config.Password)
q := url.Values{}
q.Set("sslmode", "disable")
q.Set("sslmode", resolvePostgresSSLMode(config))
q.Set("connect_timeout", strconv.Itoa(getConnectTimeoutSeconds(config)))
u.RawQuery = q.Encode()
@@ -49,7 +49,7 @@ func (v *VastbaseDB) getDSN(config connection.ConnectionConfig) string {
}
func (v *VastbaseDB) Connect(config connection.ConnectionConfig) error {
var dsn string
runConfig := config
if config.UseSSH {
logger.Infof("Vastbase 使用 SSH 连接:地址=%s:%d 用户=%s", config.Host, config.Port, config.User)
@@ -75,23 +75,37 @@ func (v *VastbaseDB) Connect(config connection.ConnectionConfig) error {
localConfig.Port = port
localConfig.UseSSH = false
dsn = v.getDSN(localConfig)
runConfig = localConfig
logger.Infof("Vastbase 通过本地端口转发连接:%s -> %s:%d", forwarder.LocalAddr, config.Host, config.Port)
} else {
dsn = v.getDSN(config)
}
db, err := sql.Open("postgres", dsn)
if err != nil {
return fmt.Errorf("打开数据库连接失败:%w", err)
attempts := []connection.ConnectionConfig{runConfig}
if shouldTrySSLPreferredFallback(runConfig) {
attempts = append(attempts, withSSLDisabled(runConfig))
}
v.conn = db
v.pingTimeout = getConnectTimeout(config)
if err := v.Ping(); err != nil {
return fmt.Errorf("连接建立后验证失败:%w", err)
var failures []string
for idx, attempt := range attempts {
dsn := v.getDSN(attempt)
db, err := sql.Open("postgres", dsn)
if err != nil {
failures = append(failures, fmt.Sprintf("第%d次连接打开失败: %v", idx+1, err))
continue
}
v.conn = db
v.pingTimeout = getConnectTimeout(attempt)
if err := v.Ping(); err != nil {
_ = db.Close()
v.conn = nil
failures = append(failures, fmt.Sprintf("第%d次连接验证失败: %v", idx+1, err))
continue
}
if idx > 0 {
logger.Warnf("Vastbase SSL 优先连接失败,已回退至明文连接")
}
return nil
}
return nil
return fmt.Errorf("连接建立后验证失败:%s", strings.Join(failures, ""))
}
func (v *VastbaseDB) Close() error {

View File

@@ -2,6 +2,7 @@ package redis
import (
"context"
"crypto/tls"
"fmt"
"net"
"strconv"
@@ -201,25 +202,48 @@ func (r *RedisClientImpl) Connect(config connection.ConnectionConfig) error {
timeout := normalizeRedisTimeout(config.Timeout)
if r.isCluster {
opts := &redis.ClusterOptions{
Addrs: seedAddrs,
Username: strings.TrimSpace(config.User),
Password: config.Password,
DialTimeout: timeout,
ReadTimeout: timeout,
WriteTimeout: timeout,
attempts := []connection.ConnectionConfig{config}
if shouldTryRedisSSLPreferredFallback(config) {
attempts = append(attempts, withRedisSSLDisabled(config))
}
clusterClient := redis.NewClusterClient(opts)
ctx, cancel := context.WithTimeout(context.Background(), timeout)
defer cancel()
if err := clusterClient.Ping(ctx).Err(); err != nil {
clusterClient.Close()
return fmt.Errorf("Redis 集群连接失败: %w", err)
var failures []string
for idx, attempt := range attempts {
var tlsConfig *tls.Config
if cfg := resolveRedisTLSConfig(attempt); cfg != nil {
if host, _, err := net.SplitHostPort(seedAddrs[0]); err == nil && host != "" {
cfg.ServerName = host
}
tlsConfig = cfg
}
opts := &redis.ClusterOptions{
Addrs: seedAddrs,
Username: strings.TrimSpace(attempt.User),
Password: attempt.Password,
DialTimeout: timeout,
ReadTimeout: timeout,
WriteTimeout: timeout,
TLSConfig: tlsConfig,
}
clusterClient := redis.NewClusterClient(opts)
ctx, cancel := context.WithTimeout(context.Background(), timeout)
pingErr := clusterClient.Ping(ctx).Err()
cancel()
if pingErr != nil {
clusterClient.Close()
failures = append(failures, fmt.Sprintf("第%d次连接失败: %v", idx+1, pingErr))
continue
}
r.client = clusterClient
r.clusterClient = clusterClient
r.config = attempt
if idx > 0 {
logger.Warnf("Redis 集群 SSL 优先连接失败,已回退至明文连接")
}
logger.Infof("Redis 集群连接成功: seeds=%s 逻辑库=db%d", strings.Join(seedAddrs, ","), r.currentDB)
return nil
}
r.client = clusterClient
r.clusterClient = clusterClient
logger.Infof("Redis 集群连接成功: seeds=%s 逻辑库=db%d", strings.Join(seedAddrs, ","), r.currentDB)
return nil
return fmt.Errorf("Redis 集群连接失败: %s", strings.Join(failures, ""))
}
addr := seedAddrs[0]
@@ -233,29 +257,53 @@ func (r *RedisClientImpl) Connect(config connection.ConnectionConfig) error {
logger.Infof("Redis 通过 SSH 隧道连接: %s -> %s:%d", addr, config.Host, config.Port)
}
opts := &redis.Options{
Addr: addr,
Username: strings.TrimSpace(config.User),
Password: config.Password,
DB: r.currentDB,
DialTimeout: timeout,
ReadTimeout: timeout,
WriteTimeout: timeout,
attempts := []connection.ConnectionConfig{config}
if shouldTryRedisSSLPreferredFallback(config) {
attempts = append(attempts, withRedisSSLDisabled(config))
}
singleClient := redis.NewClient(opts)
ctx, cancel := context.WithTimeout(context.Background(), timeout)
defer cancel()
var failures []string
for idx, attempt := range attempts {
var tlsConfig *tls.Config
if cfg := resolveRedisTLSConfig(attempt); cfg != nil {
if host, _, err := net.SplitHostPort(addr); err == nil && host != "" {
cfg.ServerName = host
}
tlsConfig = cfg
}
if err := singleClient.Ping(ctx).Err(); err != nil {
singleClient.Close()
return fmt.Errorf("Redis 连接失败: %w", err)
opts := &redis.Options{
Addr: addr,
Username: strings.TrimSpace(attempt.User),
Password: attempt.Password,
DB: r.currentDB,
DialTimeout: timeout,
ReadTimeout: timeout,
WriteTimeout: timeout,
TLSConfig: tlsConfig,
}
singleClient := redis.NewClient(opts)
ctx, cancel := context.WithTimeout(context.Background(), timeout)
pingErr := singleClient.Ping(ctx).Err()
cancel()
if pingErr != nil {
singleClient.Close()
failures = append(failures, fmt.Sprintf("第%d次连接失败: %v", idx+1, pingErr))
continue
}
r.client = singleClient
r.singleClient = singleClient
r.config = attempt
if idx > 0 {
logger.Warnf("Redis SSL 优先连接失败,已回退至明文连接")
}
logger.Infof("Redis 连接成功: %s DB=%d", addr, r.currentDB)
return nil
}
r.client = singleClient
r.singleClient = singleClient
logger.Infof("Redis 连接成功: %s DB=%d", addr, r.currentDB)
return nil
return fmt.Errorf("Redis 连接失败: %s", strings.Join(failures, ""))
}
// Close closes the Redis connection

View File

@@ -0,0 +1,55 @@
package redis
import (
"crypto/tls"
"strings"
"GoNavi-Wails/internal/connection"
)
func normalizeRedisSSLMode(raw string) string {
mode := strings.ToLower(strings.TrimSpace(raw))
switch mode {
case "", "preferred", "prefer":
return "preferred"
case "required", "require", "on", "true", "mandatory", "strict":
return "required"
case "skip-verify", "insecure", "skipverify", "skip_verify", "insecure-skip-verify":
return "skip-verify"
case "disable", "disabled", "off", "false", "none":
return "disable"
default:
return "preferred"
}
}
func redisSSLMode(config connection.ConnectionConfig) string {
if !config.UseSSL {
return "disable"
}
return normalizeRedisSSLMode(config.SSLMode)
}
func shouldTryRedisSSLPreferredFallback(config connection.ConnectionConfig) bool {
return config.UseSSL && normalizeRedisSSLMode(config.SSLMode) == "preferred"
}
func withRedisSSLDisabled(config connection.ConnectionConfig) connection.ConnectionConfig {
next := config
next.UseSSL = false
next.SSLMode = "disable"
return next
}
func resolveRedisTLSConfig(config connection.ConnectionConfig) *tls.Config {
switch redisSSLMode(config) {
case "disable":
return nil
case "required":
return &tls.Config{MinVersion: tls.VersionTLS12}
case "skip-verify":
return &tls.Config{MinVersion: tls.VersionTLS12, InsecureSkipVerify: true}
default:
return &tls.Config{MinVersion: tls.VersionTLS12, InsecureSkipVerify: true}
}
}