Merge branch 'main' into dev

# Conflicts:
#	frontend/src/App.tsx
#	frontend/src/components/ConnectionModal.tsx
#	frontend/src/components/DataGrid.tsx
#	frontend/src/components/DataViewer.tsx
#	frontend/src/components/QueryEditor.tsx
#	internal/app/methods_driver.go
#	internal/app/methods_file_export_test.go
#	internal/db/clickhouse_impl.go
#	internal/db/oracle_impl.go
#	internal/redis/redis_impl.go
This commit is contained in:
Syngnat
2026-03-05 17:11:41 +08:00
48 changed files with 6308 additions and 495 deletions

3
.gitignore vendored
View File

@@ -19,3 +19,6 @@ GoNavi-Wails.exe
.ace-tool/
.claude/
tmpclaude-*
CLAUDE.md
**/CLAUDE.md

View File

@@ -0,0 +1,12 @@
//go:build gonavi_mongodb_driver_v1
package main
import "GoNavi-Wails/internal/db"
func init() {
agentDriverType = "mongodb"
agentDatabaseFactory = func() db.Database {
return &db.MongoDBV1{}
}
}

View File

@@ -1,8 +1,8 @@
import React, { useState, useEffect } from 'react';
import { Layout, Button, ConfigProvider, theme, Dropdown, MenuProps, message, Modal, Spin, Slider, Progress, Switch, Input, InputNumber, Select } from 'antd';
import zhCN from 'antd/locale/zh_CN';
import { PlusOutlined, ConsoleSqlOutlined, UploadOutlined, DownloadOutlined, CloudDownloadOutlined, BugOutlined, ToolOutlined, GlobalOutlined, InfoCircleOutlined, GithubOutlined, SkinOutlined, CheckOutlined, MinusOutlined, BorderOutlined, CloseOutlined, SettingOutlined } from '@ant-design/icons';
import { BrowserOpenURL, Environment, EventsOn, Quit, WindowFullscreen, WindowIsFullscreen, WindowIsMaximised, WindowMaximise, WindowMinimise, WindowToggleMaximise } from '../wailsjs/runtime';
import { PlusOutlined, ConsoleSqlOutlined, UploadOutlined, DownloadOutlined, CloudDownloadOutlined, BugOutlined, ToolOutlined, GlobalOutlined, InfoCircleOutlined, GithubOutlined, SkinOutlined, CheckOutlined, MinusOutlined, BorderOutlined, CloseOutlined, SettingOutlined, LinkOutlined } from '@ant-design/icons';
import { BrowserOpenURL, Environment, EventsOn, Quit, WindowFullscreen, WindowGetSize, WindowIsFullscreen, WindowIsMaximised, WindowMaximise, WindowMinimise, WindowSetSize, WindowToggleMaximise } from '../wailsjs/runtime';
import Sidebar from './components/Sidebar';
import TabManager from './components/TabManager';
import ConnectionModal from './components/ConnectionModal';
@@ -12,6 +12,17 @@ import LogPanel from './components/LogPanel';
import { useStore } from './store';
import { SavedConnection } from './types';
import { blurToFilter, normalizeBlurForPlatform, normalizeOpacityForPlatform, isWindowsPlatform } from './utils/appearance';
import {
SHORTCUT_ACTION_META,
SHORTCUT_ACTION_ORDER,
ShortcutAction,
eventToShortcut,
getShortcutDisplay,
hasModifierKey,
isEditableElement,
isShortcutMatch,
normalizeShortcutCombo,
} from './utils/shortcuts';
import { ConfigureGlobalProxy, SetWindowTranslucency } from '../wailsjs/go/app/App';
import './App.css';
@@ -53,6 +64,9 @@ function App() {
const setStartupFullscreen = useStore(state => state.setStartupFullscreen);
const globalProxy = useStore(state => state.globalProxy);
const setGlobalProxy = useStore(state => state.setGlobalProxy);
const shortcutOptions = useStore(state => state.shortcutOptions);
const updateShortcut = useStore(state => state.updateShortcut);
const resetShortcutOptions = useStore(state => state.resetShortcutOptions);
const darkMode = themeMode === 'dark';
const effectiveUiScale = Math.min(MAX_UI_SCALE, Math.max(MIN_UI_SCALE, Number(uiScale) || DEFAULT_UI_SCALE));
const effectiveFontSize = Math.min(MAX_FONT_SIZE, Math.max(MIN_FONT_SIZE, Math.round(Number(fontSize) || DEFAULT_FONT_SIZE)));
@@ -260,6 +274,80 @@ function App() {
};
}, []);
useEffect(() => {
if (!isWindowsPlatform()) {
return;
}
let cancelled = false;
let inFlight = false;
let lastRatio = Number(window.devicePixelRatio) || 1;
let lastFixAt = 0;
const wait = (ms: number) => new Promise<void>((resolve) => window.setTimeout(resolve, ms));
const fixWindowScaleIfNeeded = async () => {
if (cancelled || inFlight) return;
const now = Date.now();
if (now - lastFixAt < 700) return;
inFlight = true;
try {
const [isFullscreen, isMaximised] = await Promise.all([
WindowIsFullscreen().catch(() => false),
WindowIsMaximised().catch(() => false),
]);
// 避免在全屏/最大化状态下强制改尺寸;这两种状态通常能自行保持 DPI 同步。
if (isFullscreen || isMaximised) {
window.dispatchEvent(new Event('resize'));
lastFixAt = Date.now();
return;
}
const size = await WindowGetSize().catch(() => null);
const width = Math.trunc(Number(size?.w || 0));
const height = Math.trunc(Number(size?.h || 0));
if (width <= 0 || height <= 0) {
window.dispatchEvent(new Event('resize'));
lastFixAt = Date.now();
return;
}
const nudgedWidth = width > 480 ? width - 1 : width + 1;
WindowSetSize(nudgedWidth, height);
await wait(28);
WindowSetSize(width, height);
window.dispatchEvent(new Event('resize'));
lastFixAt = Date.now();
} finally {
inFlight = false;
}
};
const checkDevicePixelRatio = () => {
if (cancelled) return;
const currentRatio = Number(window.devicePixelRatio) || 1;
if (Math.abs(currentRatio - lastRatio) < 0.02) {
return;
}
lastRatio = currentRatio;
void fixWindowScaleIfNeeded();
};
const pollTimer = window.setInterval(checkDevicePixelRatio, 900);
window.addEventListener('resize', checkDevicePixelRatio);
window.addEventListener('focus', checkDevicePixelRatio);
document.addEventListener('visibilitychange', checkDevicePixelRatio);
return () => {
cancelled = true;
window.clearInterval(pollTimer);
window.removeEventListener('resize', checkDevicePixelRatio);
window.removeEventListener('focus', checkDevicePixelRatio);
document.removeEventListener('visibilitychange', checkDevicePixelRatio);
};
}, []);
// Background Helper
const getBg = (darkHex: string) => {
if (!darkMode) return `rgba(255, 255, 255, ${effectiveOpacity})`; // Light mode usually white
@@ -299,7 +387,7 @@ function App() {
const [isAboutOpen, setIsAboutOpen] = useState(false);
const isAboutOpenRef = React.useRef(false);
const [aboutLoading, setAboutLoading] = useState(false);
const [aboutInfo, setAboutInfo] = useState<{ version: string; author: string; buildTime?: string; repoUrl?: string; issueUrl?: string; releaseUrl?: string } | null>(null);
const [aboutInfo, setAboutInfo] = useState<{ version: string; author: string; buildTime?: string; repoUrl?: string; issueUrl?: string; releaseUrl?: string; communityUrl?: string } | null>(null);
const [aboutUpdateStatus, setAboutUpdateStatus] = useState<string>('');
const [lastUpdateInfo, setLastUpdateInfo] = useState<UpdateInfo | null>(null);
const [updateDownloadProgress, setUpdateDownloadProgress] = useState<{
@@ -666,7 +754,7 @@ function App() {
void message.warning("没有连接可导出");
return;
}
const res = await (window as any).go.app.App.ExportData(connections, [], "connections", "json");
const res = await (window as any).go.app.App.ExportData(connections, ['id','name','config','includeDatabases','includeRedisDatabases'], "connections", "json");
if (res.success) {
void message.success("导出成功");
} else if (res.message !== "Cancelled") {
@@ -720,10 +808,18 @@ function App() {
label: '外观设置...',
icon: <SettingOutlined />,
onClick: () => setIsAppearanceModalOpen(true)
},
{
key: 'shortcut-settings',
label: '快捷键管理...',
icon: <LinkOutlined />,
onClick: () => setIsShortcutModalOpen(true)
}
];
const [isAppearanceModalOpen, setIsAppearanceModalOpen] = useState(false);
const [isShortcutModalOpen, setIsShortcutModalOpen] = useState(false);
const [capturingShortcutAction, setCapturingShortcutAction] = useState<ShortcutAction | null>(null);
const [isProxyModalOpen, setIsProxyModalOpen] = useState(false);
@@ -803,7 +899,7 @@ function App() {
};
// Sidebar Resizing
const [sidebarWidth, setSidebarWidth] = useState(300);
const [sidebarWidth, setSidebarWidth] = useState(330);
const sidebarDragRef = React.useRef<{ startX: number, startWidth: number } | null>(null);
const rafRef = React.useRef<number | null>(null);
const ghostRef = React.useRef<HTMLDivElement>(null);
@@ -935,6 +1031,113 @@ function App() {
};
}, []);
useEffect(() => {
const handleOpenShortcutSettingsEvent = () => {
setIsShortcutModalOpen(true);
};
window.addEventListener('gonavi:open-shortcut-settings', handleOpenShortcutSettingsEvent as EventListener);
return () => {
window.removeEventListener('gonavi:open-shortcut-settings', handleOpenShortcutSettingsEvent as EventListener);
};
}, []);
useEffect(() => {
const handleGlobalShortcut = (event: KeyboardEvent) => {
const matchedAction = SHORTCUT_ACTION_ORDER.find((action) => {
const binding = shortcutOptions[action];
if (!binding?.enabled) {
return false;
}
if (isEditableElement(event.target) && !SHORTCUT_ACTION_META[action].allowInEditable) {
return false;
}
return isShortcutMatch(event, binding.combo);
});
if (!matchedAction) {
return;
}
event.preventDefault();
event.stopPropagation();
switch (matchedAction) {
case 'runQuery':
window.dispatchEvent(new CustomEvent('gonavi:run-active-query'));
break;
case 'focusSidebarSearch':
window.dispatchEvent(new CustomEvent('gonavi:focus-sidebar-search'));
break;
case 'newQueryTab':
handleNewQuery();
break;
case 'toggleLogPanel':
setIsLogPanelOpen((prev) => !prev);
break;
case 'toggleTheme':
setTheme(themeMode === 'dark' ? 'light' : 'dark');
break;
case 'openShortcutManager':
setIsShortcutModalOpen(true);
break;
}
};
window.addEventListener('keydown', handleGlobalShortcut);
return () => {
window.removeEventListener('keydown', handleGlobalShortcut);
};
}, [handleNewQuery, shortcutOptions, themeMode, setTheme]);
useEffect(() => {
if (!capturingShortcutAction) {
return;
}
const handleShortcutCapture = (event: KeyboardEvent) => {
event.preventDefault();
event.stopPropagation();
if (event.key === 'Escape') {
setCapturingShortcutAction(null);
return;
}
const combo = eventToShortcut(event);
if (!combo) {
return;
}
if (!hasModifierKey(combo)) {
void message.warning('快捷键至少包含 Ctrl / Alt / Shift / Meta 之一');
return;
}
const normalizedCombo = normalizeShortcutCombo(combo);
const conflictAction = SHORTCUT_ACTION_ORDER.find((action) => {
if (action === capturingShortcutAction) {
return false;
}
const binding = shortcutOptions[action];
if (!binding?.enabled) {
return false;
}
return normalizeShortcutCombo(binding.combo) === normalizedCombo;
});
if (conflictAction) {
void message.warning(`与「${SHORTCUT_ACTION_META[conflictAction].label}」冲突,请换一个快捷键`);
return;
}
updateShortcut(capturingShortcutAction, { combo: normalizedCombo, enabled: true });
setCapturingShortcutAction(null);
};
window.addEventListener('keydown', handleShortcutCapture, true);
return () => {
window.removeEventListener('keydown', handleShortcutCapture, true);
};
}, [capturingShortcutAction, shortcutOptions, updateShortcut]);
const linuxResizeHandleStyleBase = {
position: 'fixed',
zIndex: 12000,
@@ -1221,6 +1424,9 @@ function App() {
<div style={{ display: 'flex', flexDirection: 'column', gap: 8 }}>
<div>{aboutInfo?.version || '未知'}</div>
<div>{aboutInfo?.author || '未知'}</div>
{aboutInfo?.communityUrl ? (
<div><a onClick={(e) => { e.preventDefault(); if (aboutInfo?.communityUrl) BrowserOpenURL(aboutInfo.communityUrl); }} href={aboutInfo.communityUrl}>AI全书</a></div>
) : null}
<div>{aboutUpdateStatus || '未检查'}</div>
<div style={{ display: 'flex', alignItems: 'center', gap: 6 }}>
<GithubOutlined />
@@ -1351,6 +1557,84 @@ function App() {
</div>
</Modal>
<Modal
title="快捷键管理"
open={isShortcutModalOpen}
onCancel={() => {
setIsShortcutModalOpen(false);
setCapturingShortcutAction(null);
}}
width={720}
footer={[
<Button
key="reset"
onClick={() => {
resetShortcutOptions();
setCapturingShortcutAction(null);
void message.success('已恢复默认快捷键');
}}
>
</Button>,
<Button
key="close"
type="primary"
onClick={() => {
setIsShortcutModalOpen(false);
setCapturingShortcutAction(null);
}}
>
</Button>,
]}
>
<div style={{ display: 'flex', flexDirection: 'column', gap: 12, paddingTop: 8 }}>
<div style={{ fontSize: 12, color: '#8c8c8c' }}>
Esc Ctrl/Alt/Shift/Meta
</div>
{SHORTCUT_ACTION_ORDER.map((action) => {
const meta = SHORTCUT_ACTION_META[action];
const binding = shortcutOptions[action] ?? { combo: '', enabled: false };
const isCapturing = capturingShortcutAction === action;
return (
<div
key={action}
style={{
display: 'grid',
gridTemplateColumns: '1fr auto',
gap: 12,
alignItems: 'center',
padding: '10px 12px',
border: '1px solid rgba(128, 128, 128, 0.2)',
borderRadius: 8,
}}
>
<div>
<div style={{ fontWeight: 500 }}>{meta.label}</div>
<div style={{ fontSize: 12, color: '#8c8c8c' }}>{meta.description}</div>
</div>
<div style={{ display: 'flex', alignItems: 'center', gap: 8 }}>
<Input
readOnly
value={isCapturing ? '请按下快捷键...' : getShortcutDisplay(binding.combo)}
style={{ width: 180, fontFamily: 'Consolas, Menlo, Monaco, monospace' }}
/>
<Button
size="small"
onClick={() => setCapturingShortcutAction((prev) => (prev === action ? null : action))}
>
{isCapturing ? '取消' : '录制'}
</Button>
<Switch
checked={binding.enabled}
onChange={(checked) => updateShortcut(action, { enabled: checked })}
/>
</div>
</div>
);
})}
</div>
</Modal>
<Modal
title="全局代理设置"
open={isProxyModalOpen}

View File

@@ -54,6 +54,26 @@ const singleHostUriSchemesByType: Record<string, string[]> = {
vastbase: ['vastbase'],
};
const sslSupportedTypes = new Set([
'mysql',
'mariadb',
'diros',
'sphinx',
'dameng',
'clickhouse',
'postgres',
'sqlserver',
'oracle',
'kingbase',
'highgo',
'vastbase',
'mongodb',
'redis',
'tdengine',
]);
const supportsSSLForType = (type: string) => sslSupportedTypes.has(String(type || '').trim().toLowerCase());
const isFileDatabaseType = (type: string) => type === 'sqlite' || type === 'duckdb';
type DriverStatusSnapshot = {
@@ -78,6 +98,7 @@ const ConnectionModal: React.FC<{
}> = ({ open, onClose, initialValues, onOpenDriverManager }) => {
const [form] = Form.useForm();
const [loading, setLoading] = useState(false);
const [useSSL, setUseSSL] = useState(false);
const [useSSH, setUseSSH] = useState(false);
const [useProxy, setUseProxy] = useState(false);
const [dbType, setDbType] = useState('mysql');
@@ -107,6 +128,17 @@ const ConnectionModal: React.FC<{
const mongoTopology = Form.useWatch('mongoTopology', form) || 'single';
const mongoSrv = Form.useWatch('mongoSrv', form) || false;
const redisTopology = Form.useWatch('redisTopology', form) || 'single';
const isMySQLLike = dbType === 'mysql' || dbType === 'mariadb' || dbType === 'diros' || dbType === 'sphinx';
const isSSLType = supportsSSLForType(dbType);
const sslHintText = isMySQLLike
? '当 MySQL/MariaDB/Doris/Sphinx 开启安全传输策略时,请启用 SSL本地自签证书场景可先用 Preferred 或 Skip Verify。'
: dbType === 'dameng'
? '达梦驱动启用 SSL 需要客户端证书与私钥路径sslCertPath / sslKeyPath。'
: dbType === 'sqlserver'
? 'SQL Server 推荐在生产环境使用 Required并关闭 TrustServerCertificate。'
: dbType === 'mongodb'
? 'MongoDB 可通过 TLS 保护连接,证书校验异常时可先用 Skip Verify 验证连通性。'
: '建议优先使用 Required仅在测试环境或自签证书场景使用 Skip Verify。';
const getSectionBg = (darkHex: string) => {
if (!darkMode) {
@@ -364,7 +396,7 @@ const ConnectionModal: React.FC<{
uriText: string,
expectedSchemes: string[],
defaultPort: number,
): { host: string; port: number; username: string; password: string; database: string } | null => {
): { host: string; port: number; username: string; password: string; database: string; params: URLSearchParams } | null => {
let parsed: ReturnType<typeof parseMultiHostUri> | null = null;
for (const scheme of expectedSchemes) {
parsed = parseMultiHostUri(uriText, scheme);
@@ -392,6 +424,7 @@ const ConnectionModal: React.FC<{
username: parsed.username,
password: parsed.password,
database: parsed.database || '',
params: parsed.params,
};
};
@@ -425,12 +458,22 @@ const ConnectionModal: React.FC<{
const primary = parseHostPort(hostList[0] || `localhost:${mysqlDefaultPort}`, mysqlDefaultPort);
const timeoutValue = Number(parsed.params.get('timeout'));
const topology = String(parsed.params.get('topology') || '').toLowerCase();
const tlsValue = String(parsed.params.get('tls') || '').trim().toLowerCase();
const sslMode = tlsValue === 'true'
? 'required'
: tlsValue === 'skip-verify'
? 'skip-verify'
: tlsValue === 'preferred'
? 'preferred'
: 'disable';
return {
host: primary?.host || 'localhost',
port: primary?.port || mysqlDefaultPort,
user: parsed.username,
password: parsed.password,
database: parsed.database || '',
useSSL: sslMode !== 'disable',
sslMode,
mysqlTopology: hostList.length > 1 || topology === 'replica' ? 'replica' : 'single',
mysqlReplicaHosts: hostList.slice(1),
timeout: Number.isFinite(timeoutValue) && timeoutValue > 0
@@ -451,7 +494,7 @@ const ConnectionModal: React.FC<{
}
if (type === 'redis') {
const parsed = parseMultiHostUri(trimmedUri, 'redis');
const parsed = parseMultiHostUri(trimmedUri, 'redis') || parseMultiHostUri(trimmedUri, 'rediss');
if (!parsed) {
return null;
}
@@ -469,10 +512,15 @@ const ConnectionModal: React.FC<{
const topologyParam = String(parsed.params.get('topology') || '').toLowerCase();
const dbText = String(parsed.database || '').trim().replace(/^\//, '');
const dbIndex = Number(dbText);
const isRediss = trimmedUri.toLowerCase().startsWith('rediss://');
const skipVerifyText = String(parsed.params.get('skip_verify') || '').trim().toLowerCase();
const skipVerify = skipVerifyText === '1' || skipVerifyText === 'true' || skipVerifyText === 'yes' || skipVerifyText === 'on';
return {
host: primary?.host || 'localhost',
port: primary?.port || 6379,
password: parsed.password || '',
useSSL: isRediss,
sslMode: isRediss ? (skipVerify ? 'skip-verify' : 'required') : 'disable',
redisTopology: hostList.length > 1 || topologyParam === 'cluster' ? 'cluster' : 'single',
redisHosts: hostList.slice(1),
redisDB: Number.isFinite(dbIndex) && dbIndex >= 0 && dbIndex <= 15 ? Math.trunc(dbIndex) : 0,
@@ -501,12 +549,18 @@ const ConnectionModal: React.FC<{
? { host: hostList[0] || 'localhost', port: 27017 }
: parseHostPort(hostList[0] || 'localhost:27017', 27017);
const timeoutMs = Number(parsed.params.get('connectTimeoutMS') || parsed.params.get('serverSelectionTimeoutMS'));
const tlsText = String(parsed.params.get('tls') || parsed.params.get('ssl') || '').trim().toLowerCase();
const tlsInsecureText = String(parsed.params.get('tlsInsecure') || parsed.params.get('sslInsecure') || '').trim().toLowerCase();
const tlsEnabled = tlsText === '1' || tlsText === 'true' || tlsText === 'yes' || tlsText === 'on';
const tlsInsecure = tlsInsecureText === '1' || tlsInsecureText === 'true' || tlsInsecureText === 'yes' || tlsInsecureText === 'on';
return {
host: primary?.host || 'localhost',
port: primary?.port || 27017,
user: parsed.username,
password: parsed.password,
database: parsed.database || '',
useSSL: tlsEnabled,
sslMode: tlsEnabled ? (tlsInsecure ? 'skip-verify' : 'required') : 'disable',
mongoTopology: hostList.length > 1 || !!parsed.params.get('replicaSet') ? 'replica' : 'single',
mongoHosts: hostList.slice(1),
mongoSrv: isSrv,
@@ -531,13 +585,94 @@ const ConnectionModal: React.FC<{
// Oracle 需要显式 service name避免 URI 解析后放过必填校验。
return null;
}
return {
const parsedValues: Record<string, any> = {
host: parsed.host,
port: parsed.port,
user: parsed.username,
password: parsed.password,
database: parsed.database,
};
if (supportsSSLForType(type)) {
const normalizeBool = (raw: unknown) => {
const text = String(raw ?? '').trim().toLowerCase();
return text === '1' || text === 'true' || text === 'yes' || text === 'on';
};
if (type === 'postgres' || type === 'kingbase' || type === 'highgo' || type === 'vastbase') {
const sslMode = String(parsed.params.get('sslmode') || '').trim().toLowerCase();
if (sslMode) {
parsedValues.useSSL = sslMode !== 'disable' && sslMode !== 'false';
parsedValues.sslMode = sslMode === 'disable' || sslMode === 'false'
? 'disable'
: 'required';
}
} else if (type === 'sqlserver') {
const encrypt = String(parsed.params.get('encrypt') || '').trim().toLowerCase();
const trust = String(parsed.params.get('TrustServerCertificate') || parsed.params.get('trustservercertificate') || '').trim().toLowerCase();
const encrypted = encrypt === 'true' || encrypt === 'mandatory' || encrypt === 'yes' || encrypt === '1' || encrypt === 'strict';
if (encrypted) {
parsedValues.useSSL = true;
parsedValues.sslMode = trust === 'true' || trust === '1' || trust === 'yes' ? 'skip-verify' : 'required';
} else if (encrypt) {
parsedValues.useSSL = false;
parsedValues.sslMode = 'disable';
}
} else if (type === 'clickhouse') {
const secure = String(parsed.params.get('secure') || parsed.params.get('tls') || '').trim().toLowerCase();
const skipVerify = normalizeBool(parsed.params.get('skip_verify'));
if (secure) {
parsedValues.useSSL = normalizeBool(secure);
parsedValues.sslMode = skipVerify ? 'skip-verify' : (parsedValues.useSSL ? 'required' : 'disable');
}
} else if (type === 'dameng') {
const certPath = String(
parsed.params.get('SSL_CERT_PATH')
|| parsed.params.get('ssl_cert_path')
|| parsed.params.get('sslCertPath')
|| ''
).trim();
const keyPath = String(
parsed.params.get('SSL_KEY_PATH')
|| parsed.params.get('ssl_key_path')
|| parsed.params.get('sslKeyPath')
|| ''
).trim();
parsedValues.sslCertPath = certPath;
parsedValues.sslKeyPath = keyPath;
if (certPath || keyPath) {
parsedValues.useSSL = true;
parsedValues.sslMode = 'required';
}
} else if (type === 'oracle') {
const ssl = String(parsed.params.get('SSL') || parsed.params.get('ssl') || '').trim().toLowerCase();
const sslVerify = String(
parsed.params.get('SSL VERIFY')
|| parsed.params.get('ssl verify')
|| parsed.params.get('SSL_VERIFY')
|| parsed.params.get('ssl_verify')
|| ''
).trim().toLowerCase();
if (ssl) {
parsedValues.useSSL = normalizeBool(ssl);
if (!parsedValues.useSSL) {
parsedValues.sslMode = 'disable';
} else {
parsedValues.sslMode = normalizeBool(sslVerify || 'true') ? 'required' : 'skip-verify';
}
}
} else if (type === 'tdengine') {
const protocol = String(parsed.params.get('protocol') || '').trim().toLowerCase();
const skipVerify = normalizeBool(parsed.params.get('skip_verify'));
if (protocol === 'wss') {
parsedValues.useSSL = true;
parsedValues.sslMode = skipVerify ? 'skip-verify' : 'required';
} else if (protocol === 'ws') {
parsedValues.useSSL = false;
parsedValues.sslMode = 'disable';
}
}
};
return parsedValues;
}
return null;
@@ -609,6 +744,16 @@ const ConnectionModal: React.FC<{
if (hosts.length > 1 || values.mysqlTopology === 'replica') {
params.set('topology', 'replica');
}
if (values.useSSL) {
const mode = String(values.sslMode || 'preferred').trim().toLowerCase();
if (mode === 'required') {
params.set('tls', 'true');
} else if (mode === 'skip-verify') {
params.set('tls', 'skip-verify');
} else {
params.set('tls', 'preferred');
}
}
if (Number.isFinite(timeout) && timeout > 0) {
params.set('timeout', String(timeout));
}
@@ -634,8 +779,15 @@ const ConnectionModal: React.FC<{
? Math.max(0, Math.min(15, Math.trunc(Number(values.redisDB))))
: 0;
const dbPath = `/${redisDB}`;
if (values.useSSL) {
const mode = String(values.sslMode || 'preferred').trim().toLowerCase();
if (mode === 'skip-verify' || mode === 'preferred') {
params.set('skip_verify', 'true');
}
}
const query = params.toString();
return `redis://${redisAuth}${hosts.join(',')}${dbPath}${query ? `?${query}` : ''}`;
const scheme = values.useSSL ? 'rediss' : 'redis';
return `${scheme}://${redisAuth}${hosts.join(',')}${dbPath}${query ? `?${query}` : ''}`;
}
if (isFileDatabaseType(type)) {
@@ -675,6 +827,15 @@ const ConnectionModal: React.FC<{
if (authMechanism) {
params.set('authMechanism', authMechanism);
}
if (values.useSSL) {
const mode = String(values.sslMode || 'preferred').trim().toLowerCase();
params.set('tls', 'true');
if (mode === 'skip-verify' || mode === 'preferred') {
params.set('tlsInsecure', 'true');
} else {
params.delete('tlsInsecure');
}
}
if (Number.isFinite(timeout) && timeout > 0) {
params.set('connectTimeoutMS', String(timeout * 1000));
params.set('serverSelectionTimeoutMS', String(timeout * 1000));
@@ -686,7 +847,45 @@ const ConnectionModal: React.FC<{
const scheme = type === 'postgres' ? 'postgresql' : type;
const dbPath = database ? `/${encodeURIComponent(database)}` : '';
return `${scheme}://${encodedAuth}${toAddress(host, port, defaultPort)}${dbPath}`;
const params = new URLSearchParams();
if (supportsSSLForType(type) && values.useSSL) {
const mode = String(values.sslMode || 'preferred').trim().toLowerCase();
if (type === 'postgres' || type === 'kingbase' || type === 'highgo' || type === 'vastbase') {
params.set('sslmode', 'require');
} else if (type === 'sqlserver') {
params.set('encrypt', 'true');
params.set('TrustServerCertificate', mode === 'skip-verify' || mode === 'preferred' ? 'true' : 'false');
} else if (type === 'clickhouse') {
params.set('secure', 'true');
if (mode === 'skip-verify' || mode === 'preferred') {
params.set('skip_verify', 'true');
}
} else if (type === 'dameng') {
const certPath = String(values.sslCertPath || '').trim();
const keyPath = String(values.sslKeyPath || '').trim();
if (certPath) params.set('SSL_CERT_PATH', certPath);
if (keyPath) params.set('SSL_KEY_PATH', keyPath);
} else if (type === 'oracle') {
params.set('SSL', 'TRUE');
params.set('SSL VERIFY', mode === 'required' ? 'TRUE' : 'FALSE');
} else if (type === 'tdengine') {
params.set('protocol', 'wss');
if (mode === 'skip-verify' || mode === 'preferred') {
params.set('skip_verify', 'true');
}
}
} else if (supportsSSLForType(type)) {
if (type === 'postgres' || type === 'kingbase' || type === 'highgo' || type === 'vastbase') {
params.set('sslmode', 'disable');
} else if (type === 'sqlserver') {
params.set('encrypt', 'disable');
params.set('TrustServerCertificate', 'true');
} else if (type === 'tdengine') {
params.set('protocol', 'ws');
}
}
const query = params.toString();
return `${scheme}://${encodedAuth}${toAddress(host, port, defaultPort)}${dbPath}${query ? `?${query}` : ''}`;
};
const handleGenerateURI = () => {
@@ -838,6 +1037,10 @@ const ConnectionModal: React.FC<{
uri: config.uri || '',
includeDatabases: initialValues.includeDatabases,
includeRedisDatabases: initialValues.includeRedisDatabases,
useSSL: !!config.useSSL,
sslMode: config.sslMode || 'preferred',
sslCertPath: config.sslCertPath || '',
sslKeyPath: config.sslKeyPath || '',
useSSH: config.useSSH,
sshHost: config.ssh?.host,
sshPort: config.ssh?.port,
@@ -871,6 +1074,7 @@ const ConnectionModal: React.FC<{
mongoReplicaUser: config.mongoReplicaUser || '',
mongoReplicaPassword: config.mongoReplicaPassword || ''
});
setUseSSL(!!config.useSSL);
setUseSSH(config.useSSH || false);
setUseProxy(config.useProxy || false);
setDbType(configType);
@@ -882,6 +1086,7 @@ const ConnectionModal: React.FC<{
// Create mode: Start at step 1
setStep(1);
form.resetFields();
setUseSSL(false);
setUseSSH(false);
setUseProxy(false);
setDbType('mysql');
@@ -932,6 +1137,7 @@ const ConnectionModal: React.FC<{
setLoading(false);
form.resetFields();
setUseSSL(false);
setUseSSH(false);
setUseProxy(false);
setDbType('mysql');
@@ -1073,6 +1279,21 @@ const ConnectionModal: React.FC<{
const type = String(mergedValues.type || '').toLowerCase();
const defaultPort = getDefaultPortByType(type);
const isFileDbType = isFileDatabaseType(type);
const sslCapableType = supportsSSLForType(type);
const sslModeRaw = String(mergedValues.sslMode || 'preferred').trim().toLowerCase();
const sslMode: 'preferred' | 'required' | 'skip-verify' | 'disable' = sslModeRaw === 'required'
? 'required'
: sslModeRaw === 'skip-verify'
? 'skip-verify'
: sslModeRaw === 'disable'
? 'disable'
: 'preferred';
const effectiveUseSSL = sslCapableType && !!mergedValues.useSSL;
const sslCertPath = sslCapableType ? String(mergedValues.sslCertPath || '').trim() : '';
const sslKeyPath = sslCapableType ? String(mergedValues.sslKeyPath || '').trim() : '';
if (type === 'dameng' && effectiveUseSSL && (!sslCertPath || !sslKeyPath)) {
throw new Error('达梦启用 SSL 时必须填写证书路径与私钥路径');
}
let primaryHost = 'localhost';
let primaryPort = defaultPort;
@@ -1194,6 +1415,10 @@ const ConnectionModal: React.FC<{
password: keepPassword ? (mergedValues.password || "") : "",
savePassword: savePassword,
database: mergedValues.database || "",
useSSL: effectiveUseSSL,
sslMode: effectiveUseSSL ? sslMode : 'disable',
sslCertPath: sslCertPath,
sslKeyPath: sslKeyPath,
useSSH: !!mergedValues.useSSH,
ssh: sshConfig,
useProxy: effectiveUseProxy,
@@ -1233,6 +1458,7 @@ const ConnectionModal: React.FC<{
const defaultPort = getDefaultPortByType(type);
if (isFileDatabaseType(type)) {
setUseSSL(false);
setUseSSH(false);
setUseProxy(false);
form.setFieldsValue({
@@ -1241,6 +1467,10 @@ const ConnectionModal: React.FC<{
user: '',
password: '',
database: '',
useSSL: false,
sslMode: 'preferred',
sslCertPath: '',
sslKeyPath: '',
useSSH: false,
sshHost: '',
sshPort: 22,
@@ -1273,10 +1503,16 @@ const ConnectionModal: React.FC<{
});
} else if (type !== 'custom') {
const defaultUser = type === 'clickhouse' ? 'default' : 'root';
const sslCapableType = supportsSSLForType(type);
setUseSSL(false);
form.setFieldsValue({
user: defaultUser,
database: '',
port: defaultPort,
useSSL: sslCapableType ? false : undefined,
sslMode: sslCapableType ? 'preferred' : undefined,
sslCertPath: sslCapableType ? '' : undefined,
sslKeyPath: sslCapableType ? '' : undefined,
mysqlTopology: 'single',
redisTopology: 'single',
mongoTopology: 'single',
@@ -1420,6 +1656,10 @@ const ConnectionModal: React.FC<{
port: 3306,
database: '',
user: 'root',
useSSL: false,
sslMode: 'preferred',
sslCertPath: '',
sslKeyPath: '',
useSSH: false,
sshPort: 22,
useProxy: false,
@@ -1451,6 +1691,7 @@ const ConnectionModal: React.FC<{
if (changed.uri !== undefined || changed.type !== undefined) {
setUriFeedback(null);
}
if (changed.useSSL !== undefined) setUseSSL(changed.useSSL);
if (changed.useSSH !== undefined) setUseSSH(changed.useSSH);
if (changed.useProxy !== undefined) setUseProxy(changed.useProxy);
if (changed.proxyType !== undefined) {
@@ -1835,6 +2076,56 @@ const ConnectionModal: React.FC<{
{!isFileDb && (
<>
{isSSLType && (
<>
<Divider style={{ margin: '12px 0' }} />
<Form.Item name="useSSL" valuePropName="checked" style={{ marginBottom: 0 }}>
<Checkbox>使 SSL/TLS</Checkbox>
</Form.Item>
{useSSL && (
<div style={tunnelSectionStyle}>
<Form.Item
name="sslMode"
label="SSL 模式"
rules={[{ required: true, message: '请选择 SSL 模式' }]}
style={{ marginBottom: 8 }}
>
<Select
options={[
{ value: 'preferred', label: 'Preferred优先 SSL推荐' },
{ value: 'required', label: 'Required必须 SSL校验证书' },
{ value: 'skip-verify', label: 'Skip Verify必须 SSL跳过证书校验' },
]}
/>
</Form.Item>
{dbType === 'dameng' && (
<>
<Form.Item
name="sslCertPath"
label="客户端证书路径 (SSL_CERT_PATH)"
rules={[{ required: true, message: '达梦 SSL 需要证书路径' }]}
style={{ marginBottom: 8 }}
>
<Input placeholder="例如: C:\\certs\\client-cert.pem" />
</Form.Item>
<Form.Item
name="sslKeyPath"
label="客户端私钥路径 (SSL_KEY_PATH)"
rules={[{ required: true, message: '达梦 SSL 需要私钥路径' }]}
style={{ marginBottom: 8 }}
>
<Input placeholder="例如: C:\\certs\\client-key.pem" />
</Form.Item>
</>
)}
<Text type="secondary" style={{ fontSize: 12 }}>
{sslHintText}
</Text>
</div>
)}
</>
)}
<Divider style={{ margin: '12px 0' }} />
<Form.Item name="useSSH" valuePropName="checked" style={{ marginBottom: 0 }}>
<Checkbox>使 SSH (SSH Tunnel)</Checkbox>

View File

@@ -537,6 +537,7 @@ const ContextMenuRow = React.memo(({ children, record, ...props }: any) => {
{ key: 'exp-xlsx', label: 'Excel', onClick: () => handleExportSelected('xlsx', record) },
{ key: 'exp-json', label: 'JSON', onClick: () => handleExportSelected('json', record) },
{ key: 'exp-md', label: 'Markdown', onClick: () => handleExportSelected('md', record) },
{ key: 'exp-html', label: 'HTML', onClick: () => handleExportSelected('html', record) },
]
}
];
@@ -577,7 +578,9 @@ interface DataGridProps {
// Filtering
showFilter?: boolean;
onToggleFilter?: () => void;
exportSqlWithFilter?: string;
onApplyFilter?: (conditions: GridFilterCondition[]) => void;
appliedFilterConditions?: FilterCondition[];
}
type GridFilterCondition = FilterCondition & {
@@ -595,9 +598,9 @@ type ColumnMeta = {
comment: string;
};
const DataGrid: React.FC<DataGridProps> = ({
const DataGrid: React.FC<DataGridProps> = ({
data, columnNames, loading, tableName, exportScope = 'table', resultSql, dbName, connectionId, pkColumns = [], readOnly = false,
onReload, onSort, onPageChange, pagination, onRequestTotalCount, onCancelTotalCount, sortInfoExternal, showFilter, onToggleFilter, onApplyFilter
onReload, onSort, onPageChange, pagination, onRequestTotalCount, onCancelTotalCount, sortInfoExternal, showFilter, onToggleFilter, exportSqlWithFilter, onApplyFilter, appliedFilterConditions
}) => {
const connections = useStore(state => state.connections);
const addSqlLog = useStore(state => state.addSqlLog);
@@ -620,6 +623,8 @@ const DataGrid: React.FC<DataGridProps> = ({
const isQueryResultExport = exportScope === 'queryResult';
const canImport = exportScope === 'table' && !!tableName;
const canExport = !!connectionId && (isQueryResultExport || !!tableName);
const filteredExportSql = useMemo(() => String(exportSqlWithFilter || '').trim(), [exportSqlWithFilter]);
const hasFilteredExportSql = exportScope === 'table' && filteredExportSql.length > 0;
// Background Helper
const getBg = (darkHex: string) => {
@@ -1060,10 +1065,40 @@ const DataGrid: React.FC<DataGridProps> = ({
const [modifiedRows, setModifiedRows] = useState<Record<string, any>>({});
const [deletedRowKeys, setDeletedRowKeys] = useState<Set<string>>(new Set());
const normalizeFilterLogic = useCallback((logic: unknown): 'AND' | 'OR' => {
return String(logic || '').trim().toUpperCase() === 'OR' ? 'OR' : 'AND';
}, []);
const normalizeGridFilterConditions = useCallback((conditions?: FilterCondition[]): GridFilterCondition[] => {
if (!Array.isArray(conditions)) return [];
return conditions.map((cond, index) => {
const fallbackId = index + 1;
const nextId = Number.isFinite(Number(cond?.id)) ? Number(cond?.id) : fallbackId;
const op = String(cond?.op || '=');
const rawColumn = String(cond?.column || '');
return {
id: nextId,
enabled: cond?.enabled !== false,
logic: normalizeFilterLogic(cond?.logic),
column: rawColumn || (op === 'CUSTOM' ? '' : String(columnNames[0] || '')),
op,
value: String(cond?.value ?? ''),
value2: String(cond?.value2 ?? ''),
};
});
}, [columnNames, normalizeFilterLogic]);
// Filter State
const [filterConditions, setFilterConditions] = useState<GridFilterCondition[]>([]);
const [nextFilterId, setNextFilterId] = useState(1);
useEffect(() => {
const nextConditions = normalizeGridFilterConditions(appliedFilterConditions);
setFilterConditions(nextConditions);
const maxId = nextConditions.reduce((max, cond) => (cond.id > max ? cond.id : max), 0);
setNextFilterId(Math.max(1, maxId + 1));
}, [appliedFilterConditions, normalizeGridFilterConditions]);
const selectedRowKeysRef = useRef(selectedRowKeys);
const displayDataRef = useRef<any[]>([]);
@@ -2481,6 +2516,23 @@ const DataGrid: React.FC<DataGridProps> = ({
});
};
const handleExportFilteredAll = async (format: string) => {
if (!connectionId || !tableName) return;
if (!filteredExportSql) {
message.warning('当前未应用筛选条件');
return;
}
if (!supportsSqlQueryExport) {
message.error('当前数据源不支持按筛选结果导出');
return;
}
if (hasChanges) {
message.warning("当前存在未提交修改,筛选结果导出基于数据库已提交数据。");
}
await exportByQuery(filteredExportSql, format, `${tableName || 'export'}_filtered`);
};
const handleImport = async () => {
if (!connectionId || !tableName) return;
const config = buildConnConfig();
@@ -2526,6 +2578,10 @@ const DataGrid: React.FC<DataGridProps> = ({
{ value: 'NOT_IN', label: '不在列表' },
{ value: 'CUSTOM', label: '[自定义]' },
]), []);
const filterLogicOptions = useMemo(() => ([
{ value: 'AND', label: '且 (AND)' },
{ value: 'OR', label: '或 (OR)' },
]), []);
const isNoValueOp = useCallback((op: string) => (
op === 'IS_NULL' || op === 'IS_NOT_NULL' || op === 'IS_EMPTY' || op === 'IS_NOT_EMPTY'
@@ -2534,7 +2590,18 @@ const DataGrid: React.FC<DataGridProps> = ({
const isListOp = useCallback((op: string) => op === 'IN' || op === 'NOT_IN', []);
const addFilter = () => {
setFilterConditions([...filterConditions, { id: nextFilterId, enabled: true, column: columnNames[0] || '', op: '=', value: '', value2: '' }]);
setFilterConditions([
...filterConditions,
{
id: nextFilterId,
enabled: true,
logic: 'AND',
column: columnNames[0] || '',
op: '=',
value: '',
value2: '',
}
]);
setNextFilterId(nextFilterId + 1);
};
const updateFilter = (id: number, field: keyof GridFilterCondition, val: string | boolean) => {
@@ -2562,11 +2629,28 @@ const DataGrid: React.FC<DataGridProps> = ({
if (onApplyFilter) onApplyFilter(filterConditions);
};
const exportMenu: MenuProps['items'] = [
const exportMenu: MenuProps['items'] = hasFilteredExportSql ? [
{ type: 'group', label: '筛选结果', children: [
{ key: 'filtered-csv', label: 'CSV', onClick: () => handleExportFilteredAll('csv') },
{ key: 'filtered-xlsx', label: 'Excel (XLSX)', onClick: () => handleExportFilteredAll('xlsx') },
{ key: 'filtered-json', label: 'JSON', onClick: () => handleExportFilteredAll('json') },
{ key: 'filtered-md', label: 'Markdown', onClick: () => handleExportFilteredAll('md') },
{ key: 'filtered-html', label: 'HTML', onClick: () => handleExportFilteredAll('html') },
]},
{ type: 'divider' },
{ type: 'group', label: '全表', children: [
{ key: 'table-csv', label: 'CSV', onClick: () => handleExport('csv') },
{ key: 'table-xlsx', label: 'Excel (XLSX)', onClick: () => handleExport('xlsx') },
{ key: 'table-json', label: 'JSON', onClick: () => handleExport('json') },
{ key: 'table-md', label: 'Markdown', onClick: () => handleExport('md') },
{ key: 'table-html', label: 'HTML', onClick: () => handleExport('html') },
]},
] : [
{ key: 'csv', label: 'CSV', onClick: () => handleExport('csv') },
{ key: 'xlsx', label: 'Excel (XLSX)', onClick: () => handleExport('xlsx') },
{ key: 'json', label: 'JSON', onClick: () => handleExport('json') },
{ key: 'md', label: 'Markdown', onClick: () => handleExport('md') },
{ key: 'html', label: 'HTML', onClick: () => handleExport('html') },
];
const columnInfoSettingContent = (
@@ -2705,29 +2789,31 @@ const DataGrid: React.FC<DataGridProps> = ({
horizontalSyncSourceRef.current = '';
}, []);
const handleExternalHorizontalWheel = useCallback((event: React.WheelEvent<HTMLDivElement>) => {
// 非虚拟模式:外部水平滚动条的 wheel 处理(通过原生事件绑定,确保 preventDefault 生效)
useEffect(() => {
const externalScroll = externalHScrollRef.current;
if (!(externalScroll instanceof HTMLDivElement)) {
return;
}
const dominantDelta = Math.abs(event.deltaX) > Math.abs(event.deltaY) ? event.deltaX : event.deltaY;
if (!Number.isFinite(dominantDelta) || Math.abs(dominantDelta) < 0.5) {
return;
}
if (!externalScroll || !horizontalScrollVisible) return;
const maxScrollLeft = Math.max(0, externalScroll.scrollWidth - externalScroll.clientWidth);
if (maxScrollLeft <= 0) {
return;
}
const handleExternalWheel = (e: WheelEvent) => {
// 鼠标在水平滚动条区域时,始终阻止垂直滚动冒泡
e.preventDefault();
e.stopPropagation();
const nextScrollLeft = Math.max(0, Math.min(maxScrollLeft, externalScroll.scrollLeft + dominantDelta));
if (Math.abs(nextScrollLeft - externalScroll.scrollLeft) < 0.5) {
return;
}
const dominantDelta = Math.abs(e.deltaX) > Math.abs(e.deltaY) ? e.deltaX : e.deltaY;
if (!Number.isFinite(dominantDelta) || Math.abs(dominantDelta) < 0.5) return;
event.preventDefault();
externalScroll.scrollLeft = nextScrollLeft;
}, []);
const maxScrollLeft = Math.max(0, externalScroll.scrollWidth - externalScroll.clientWidth);
if (maxScrollLeft <= 0) return;
const nextScrollLeft = Math.max(0, Math.min(maxScrollLeft, externalScroll.scrollLeft + dominantDelta));
externalScroll.scrollLeft = nextScrollLeft;
};
externalScroll.addEventListener('wheel', handleExternalWheel, { passive: false, capture: true });
return () => {
externalScroll.removeEventListener('wheel', handleExternalWheel, { capture: true } as EventListenerOptions);
};
}, [horizontalScrollVisible]);
useEffect(() => {
if (viewMode !== 'table') return;
@@ -2735,19 +2821,24 @@ const DataGrid: React.FC<DataGridProps> = ({
return () => cancelAnimationFrame(rafId);
}, [viewMode, totalWidth, mergedDisplayData.length, recalculateTableMetrics]);
// 虚拟模式下,为 rc-virtual-list 的内置水平滚动条添加鼠标滚轮支持
// rc-virtual-list 的 ScrollBar 组件原生只支持拖拽,不支持 wheel 事件
// 方案:使用 MutationObserver 发现滚动条元素后直接绑定 wheel 事件
// 虚拟模式下,在容器级别监听 wheel 事件,当鼠标在底部水平滚动条区域时拦截并转为水平滚动
useEffect(() => {
if (viewMode !== 'table' || !enableVirtual) return;
const container = tableContainerRef.current;
if (!container) return;
let currentScrollbarEl: HTMLElement | null = null;
// 滚动条区域高度:滚动条高度 + 间距 + 容错
const scrollbarZoneHeight = floatingScrollbarHeight + floatingScrollbarGap + 8;
const handleScrollbarWheel = (e: WheelEvent) => {
const innerEl = container.querySelector('.rc-virtual-list-holder-inner') as HTMLElement | null;
const holderEl = container.querySelector('.rc-virtual-list-holder') as HTMLElement | null;
const handleContainerWheel = (e: WheelEvent) => {
// 判断鼠标是否在底部滚动条区域
const containerRect = container.getBoundingClientRect();
if (e.clientY < containerRect.bottom - scrollbarZoneHeight) return;
// 适配 antd 的虚拟列表类名
const holderEl = container.querySelector('.ant-table-tbody-virtual-holder') as HTMLElement | null;
const innerEl = holderEl?.querySelector('.ant-table-tbody-virtual-holder-inner') as HTMLElement | null;
if (!innerEl || !holderEl) return;
const dominantDelta = Math.abs(e.deltaX) > Math.abs(e.deltaY) ? e.deltaX : e.deltaY;
@@ -2769,12 +2860,13 @@ const DataGrid: React.FC<DataGridProps> = ({
innerEl.style.marginLeft = `${-newOffset}px`;
// 同步 scrollbar thumb 位置
if (currentScrollbarEl && maxScroll > 0) {
const thumbEl = currentScrollbarEl.querySelector('[class*="scrollbar-thumb"]') as HTMLElement | null;
const scrollbarEl = container.querySelector('.ant-table-tbody-virtual-scrollbar-horizontal') as HTMLElement | null;
if (scrollbarEl && maxScroll > 0) {
const thumbEl = scrollbarEl.querySelector('[class*="scrollbar-thumb"]') as HTMLElement | null;
if (thumbEl) {
const ratio = newOffset / maxScroll;
const thumbWidth = parseFloat(thumbEl.style.width) || thumbEl.offsetWidth;
const trackWidth = currentScrollbarEl.clientWidth;
const trackWidth = scrollbarEl.clientWidth;
const thumbMaxOffset = trackWidth - thumbWidth;
thumbEl.style.left = `${ratio * thumbMaxOffset}px`;
}
@@ -2787,33 +2879,12 @@ const DataGrid: React.FC<DataGridProps> = ({
}
};
const bindScrollbar = () => {
const el = container.querySelector('.ant-table-tbody-virtual-scrollbar-horizontal') as HTMLElement | null;
if (el && el !== currentScrollbarEl) {
if (currentScrollbarEl) {
currentScrollbarEl.removeEventListener('wheel', handleScrollbarWheel);
}
currentScrollbarEl = el;
el.addEventListener('wheel', handleScrollbarWheel, { passive: false });
}
};
// 初次尝试绑定
bindScrollbar();
// 使用 MutationObserver 监听 DOM 变化,确保即使元素延迟渲染也能绑定
const observer = new MutationObserver(() => {
bindScrollbar();
});
observer.observe(container, { childList: true, subtree: true });
container.addEventListener('wheel', handleContainerWheel, { passive: false, capture: true });
return () => {
observer.disconnect();
if (currentScrollbarEl) {
currentScrollbarEl.removeEventListener('wheel', handleScrollbarWheel);
}
container.removeEventListener('wheel', handleContainerWheel, { capture: true } as EventListenerOptions);
};
}, [viewMode, enableVirtual, tableScrollX, mergedDisplayData.length]);
}, [viewMode, enableVirtual, tableScrollX, floatingScrollbarHeight, floatingScrollbarGap]);
useEffect(() => {
if (viewMode !== 'table') return;
@@ -3041,7 +3112,7 @@ const DataGrid: React.FC<DataGridProps> = ({
background: 'transparent',
boxSizing: 'border-box',
}}>
{filterConditions.map(cond => (
{filterConditions.map((cond, condIndex) => (
<div key={cond.id} style={{ display: 'flex', gap: 8, marginBottom: 8, alignItems: 'flex-start', opacity: cond.enabled === false ? 0.58 : 1 }}>
<Checkbox
checked={cond.enabled !== false}
@@ -3050,13 +3121,33 @@ const DataGrid: React.FC<DataGridProps> = ({
>
</Checkbox>
<Select
style={{ width: 180 }}
value={cond.column}
onChange={v => updateFilter(cond.id, 'column', v)}
options={columnNames.map(c => ({ value: c, label: c }))}
disabled={cond.op === 'CUSTOM'}
/>
{condIndex === 0 ? (
<div style={{ width: 96, marginTop: 7, textAlign: 'center', fontSize: 12, color: '#8c8c8c' }}>
</div>
) : (
<Select
style={{ width: 96 }}
value={cond.logic === 'OR' ? 'OR' : 'AND'}
onChange={v => updateFilter(cond.id, 'logic', v)}
options={filterLogicOptions as any}
/>
)}
<Select
style={{ width: 180 }}
value={cond.column}
onChange={v => updateFilter(cond.id, 'column', v)}
options={columnNames.map(c => ({ value: c, label: c }))}
showSearch
optionFilterProp="label"
filterOption={(input, option) =>
String(option?.label ?? '')
.toLowerCase()
.includes(String(input || '').trim().toLowerCase())
}
placeholder="搜索字段名"
disabled={cond.op === 'CUSTOM'}
/>
<Select
style={{ width: 140 }}
value={cond.op}
@@ -3307,7 +3398,6 @@ const DataGrid: React.FC<DataGridProps> = ({
className="data-grid-external-hscroll"
aria-hidden={!horizontalScrollVisible}
onScroll={applyExternalScrollToTableTargets}
onWheel={handleExternalHorizontalWheel}
style={{
opacity: horizontalScrollVisible ? 1 : 0,
pointerEvents: horizontalScrollVisible ? 'auto' : 'none',
@@ -3552,6 +3642,21 @@ const DataGrid: React.FC<DataGridProps> = ({
>
JSON
</div>
<div
style={{
padding: '8px 12px',
cursor: 'pointer',
transition: 'background 0.2s',
}}
onMouseEnter={(e) => e.currentTarget.style.background = darkMode ? '#303030' : '#f5f5f5'}
onMouseLeave={(e) => e.currentTarget.style.background = 'transparent'}
onClick={() => {
if (cellContextMenu.record) handleExportSelected('html', cellContextMenu.record);
setCellContextMenu(prev => ({ ...prev, visible: false }));
}}
>
HTML
</div>
</div>,
document.body
)}

View File

@@ -1,10 +1,11 @@
import React, { useEffect, useState, useCallback, useRef } from 'react';
import React, { useEffect, useState, useCallback, useRef, useMemo } from 'react';
import { message } from 'antd';
import { TabData, ColumnDefinition } from '../types';
import { useStore } from '../store';
import { DBQuery, DBGetColumns } from '../../wailsjs/go/app/App';
import DataGrid, { GONAVI_ROW_KEY } from './DataGrid';
import { buildOrderBySQL, buildWhereSQL, quoteIdentPart, quoteQualifiedIdent, withSortBufferTuningSQL, type FilterCondition } from '../utils/sql';
import { buildMongoCountCommand, buildMongoFilter, buildMongoFindCommand, buildMongoSort } from '../utils/mongodb';
import { getDataSourceCapabilities } from '../utils/dataSourceCapabilities';
type ViewerPaginationState = {
@@ -151,6 +152,37 @@ const reverseOrderBySQL = (orderBySQL: string): string => {
return ` ORDER BY ${parts.join(', ')}`;
};
type ViewerFilterSnapshot = {
showFilter: boolean;
conditions: FilterCondition[];
};
const viewerFilterSnapshotsByTab = new Map<string, ViewerFilterSnapshot>();
const normalizeViewerFilterConditions = (conditions: FilterCondition[] | undefined): FilterCondition[] => {
if (!Array.isArray(conditions)) return [];
return conditions.map((cond) => ({
id: Number.isFinite(Number(cond?.id)) ? Number(cond?.id) : undefined,
enabled: cond?.enabled !== false,
logic: String(cond?.logic || '').trim().toUpperCase() === 'OR' ? 'OR' : 'AND',
column: String(cond?.column || ''),
op: String(cond?.op || '='),
value: String(cond?.value ?? ''),
value2: String(cond?.value2 ?? ''),
}));
};
const getViewerFilterSnapshot = (tabId: string): ViewerFilterSnapshot => {
const cached = viewerFilterSnapshotsByTab.get(String(tabId || '').trim());
if (!cached) {
return { showFilter: false, conditions: [] };
}
return {
showFilter: cached.showFilter === true,
conditions: normalizeViewerFilterConditions(cached.conditions),
};
};
const DataViewer: React.FC<{ tab: TabData }> = ({ tab }) => {
const [data, setData] = useState<any[]>([]);
const [columnNames, setColumnNames] = useState<string[]>([]);
@@ -185,14 +217,27 @@ const DataViewer: React.FC<{ tab: TabData }> = ({ tab }) => {
const [sortInfo, setSortInfo] = useState<{ columnKey: string, order: string } | null>(null);
const [showFilter, setShowFilter] = useState(false);
const [filterConditions, setFilterConditions] = useState<FilterCondition[]>([]);
const [showFilter, setShowFilter] = useState<boolean>(() => getViewerFilterSnapshot(tab.id).showFilter);
const [filterConditions, setFilterConditions] = useState<FilterCondition[]>(() => getViewerFilterSnapshot(tab.id).conditions);
const duckdbSafeSelectCacheRef = useRef<Record<string, string>>({});
const currentConnConfig = connections.find(c => c.id === tab.connectionId)?.config;
const currentConnCaps = getDataSourceCapabilities(currentConnConfig);
const currentConnType = currentConnCaps.type;
const forceReadOnly = currentConnCaps.forceReadOnlyQueryResult;
useEffect(() => {
const snapshot = getViewerFilterSnapshot(tab.id);
setShowFilter(snapshot.showFilter);
setFilterConditions(snapshot.conditions);
}, [tab.id]);
useEffect(() => {
viewerFilterSnapshotsByTab.set(tab.id, {
showFilter,
conditions: normalizeViewerFilterConditions(filterConditions),
});
}, [tab.id, showFilter, filterConditions]);
useEffect(() => {
setPkColumns([]);
pkKeyRef.current = '';
@@ -315,42 +360,67 @@ const DataViewer: React.FC<{ tab: TabData }> = ({ tab }) => {
const dbName = tab.dbName || '';
const tableName = tab.tableName || '';
const isMongoDB = dbTypeLower === 'mongodb';
let mongoFilter: Record<string, unknown> | undefined;
if (isMongoDB) {
try {
mongoFilter = buildMongoFilter(filterConditions);
} catch (e: any) {
message.error(`Mongo 筛选条件无效:${String(e?.message || e || '解析失败')}`);
if (fetchSeqRef.current === seq) setLoading(false);
return;
}
}
const whereSQL = buildWhereSQL(dbType, filterConditions);
const countSql = `SELECT COUNT(*) as total FROM ${quoteQualifiedIdent(dbType, tableName)} ${whereSQL}`;
const baseSql = `SELECT * FROM ${quoteQualifiedIdent(dbType, tableName)} ${whereSQL}`;
const orderBySQL = buildOrderBySQL(dbType, sortInfo, pkColumns);
let sql = `${baseSql}${orderBySQL}`;
const whereSQL = isMongoDB
? JSON.stringify(mongoFilter || {})
: buildWhereSQL(dbType, filterConditions);
const countSql = isMongoDB
? buildMongoCountCommand(tableName, mongoFilter || {})
: `SELECT COUNT(*) as total FROM ${quoteQualifiedIdent(dbType, tableName)} ${whereSQL}`;
const orderBySQL = isMongoDB ? '' : buildOrderBySQL(dbType, sortInfo, pkColumns);
const totalRows = Number(pagination.total);
const hasFiniteTotal = Number.isFinite(totalRows) && totalRows >= 0;
const totalKnown = pagination.totalKnown && hasFiniteTotal;
const totalPages = hasFiniteTotal ? Math.max(1, Math.ceil(totalRows / size)) : 0;
const currentPage = totalPages > 0 ? Math.min(Math.max(1, page), totalPages) : Math.max(1, page);
const offset = (currentPage - 1) * size;
const isClickHouse = dbTypeLower === 'clickhouse';
const isClickHouse = !isMongoDB && dbTypeLower === 'clickhouse';
const reverseOrderSQL = isClickHouse ? reverseOrderBySQL(orderBySQL) : '';
let useClickHouseReversePagination = false;
let clickHouseReverseLimit = 0;
let clickHouseReverseHasMore = false;
// ClickHouse 深分页在超大 OFFSET 下容易超时。对于总数已知且存在 ORDER BY 的场景,
// 当“尾部偏移”小于“头部偏移”时,改为反向 ORDER BY + 小 OFFSET并在前端翻转结果。
if (isClickHouse && totalKnown && offset > 0 && reverseOrderSQL) {
const pageRowCount = Math.max(0, Math.min(size, totalRows - offset));
if (pageRowCount > 0) {
const tailOffset = Math.max(0, totalRows - (offset + pageRowCount));
if (tailOffset < offset) {
sql = `${baseSql}${reverseOrderSQL} LIMIT ${pageRowCount} OFFSET ${tailOffset}`;
useClickHouseReversePagination = true;
clickHouseReverseLimit = pageRowCount;
clickHouseReverseHasMore = currentPage < totalPages;
let sql = '';
if (isMongoDB) {
const mongoSort = buildMongoSort(sortInfo, pkColumns);
sql = buildMongoFindCommand({
collection: tableName,
filter: mongoFilter || {},
sort: mongoSort,
limit: size + 1,
skip: offset,
});
} else {
const baseSql = `SELECT * FROM ${quoteQualifiedIdent(dbType, tableName)} ${whereSQL}`;
sql = `${baseSql}${orderBySQL}`;
// ClickHouse 深分页在超大 OFFSET 下容易超时。对于总数已知且存在 ORDER BY 的场景,
// 当“尾部偏移”小于“头部偏移”时,改为反向 ORDER BY + 小 OFFSET并在前端翻转结果。
if (isClickHouse && totalKnown && offset > 0 && reverseOrderSQL) {
const pageRowCount = Math.max(0, Math.min(size, totalRows - offset));
if (pageRowCount > 0) {
const tailOffset = Math.max(0, totalRows - (offset + pageRowCount));
if (tailOffset < offset) {
sql = `${baseSql}${reverseOrderSQL} LIMIT ${pageRowCount} OFFSET ${tailOffset}`;
useClickHouseReversePagination = true;
clickHouseReverseLimit = pageRowCount;
clickHouseReverseHasMore = currentPage < totalPages;
}
}
}
}
if (!useClickHouseReversePagination) {
// 大表性能:打开表不阻塞在 COUNT(*),先通过多取 1 条判断是否还有下一页;总数在后台统计并异步回填。
sql += ` LIMIT ${size + 1} OFFSET ${offset}`;
if (!useClickHouseReversePagination) {
// 大表性能:打开表不阻塞在 COUNT(*),先通过多取 1 条判断是否还有下一页;总数在后台统计并异步回填。
sql += ` LIMIT ${size + 1} OFFSET ${offset}`;
}
}
const requestStartTime = Date.now();
@@ -676,6 +746,24 @@ const DataViewer: React.FC<{ tab: TabData }> = ({ tab }) => {
const handleToggleFilter = useCallback(() => setShowFilter(prev => !prev), []);
const handleApplyFilter = useCallback((conditions: FilterCondition[]) => setFilterConditions(conditions), []);
const exportSqlWithFilter = useMemo(() => {
const tableName = String(tab.tableName || '').trim();
const dbType = String(currentConnConfig?.type || '').trim();
if (!tableName || !dbType) return '';
const whereSQL = buildWhereSQL(dbType, filterConditions);
if (!whereSQL) return '';
let sql = `SELECT * FROM ${quoteQualifiedIdent(dbType, tableName)} ${whereSQL}`;
sql += buildOrderBySQL(dbType, sortInfo, pkColumns);
const normalizedType = dbType.toLowerCase();
const hasExplicitSort = !!sortInfo?.columnKey && (sortInfo?.order === 'ascend' || sortInfo?.order === 'descend');
if (hasExplicitSort && (normalizedType === 'mysql' || normalizedType === 'mariadb')) {
sql = withSortBufferTuningSQL(normalizedType, sql, 32 * 1024 * 1024);
}
return sql;
}, [tab.tableName, currentConnConfig?.type, filterConditions, sortInfo, pkColumns]);
useEffect(() => {
fetchData(1, pagination.pageSize);
}, [tab, sortInfo, filterConditions]); // Initial load and re-load on sort/filter
@@ -700,8 +788,10 @@ const DataViewer: React.FC<{ tab: TabData }> = ({ tab }) => {
showFilter={showFilter}
onToggleFilter={handleToggleFilter}
onApplyFilter={handleApplyFilter}
appliedFilterConditions={filterConditions}
readOnly={forceReadOnly}
sortInfoExternal={sortInfo}
exportSqlWithFilter={exportSqlWithFilter || undefined}
/>
</div>
);

View File

@@ -1,13 +1,16 @@
import React, { useState, useEffect, useRef, useMemo } from 'react';
import Editor, { OnMount } from '@monaco-editor/react';
import { Button, message, Modal, Input, Form, Dropdown, MenuProps, Tooltip, Select, Tabs } from 'antd';
import { PlayCircleOutlined, SaveOutlined, FormatPainterOutlined, SettingOutlined, CloseOutlined } from '@ant-design/icons';
import { PlayCircleOutlined, SaveOutlined, FormatPainterOutlined, SettingOutlined, CloseOutlined, StopOutlined } from '@ant-design/icons';
import { format } from 'sql-formatter';
import { v4 as uuidv4 } from 'uuid';
import { TabData, ColumnDefinition } from '../types';
import { useStore } from '../store';
import { DBQuery, DBGetTables, DBGetAllColumns, DBGetDatabases, DBGetColumns } from '../../wailsjs/go/app/App';
import { DBQuery, DBQueryWithCancel, DBGetTables, DBGetAllColumns, DBGetDatabases, DBGetColumns, CancelQuery, GenerateQueryID } from '../../wailsjs/go/app/App';
import DataGrid, { GONAVI_ROW_KEY } from './DataGrid';
import { getDataSourceCapabilities } from '../utils/dataSourceCapabilities';
import { convertMongoShellToJsonCommand } from '../utils/mongodb';
import { getShortcutDisplay, isEditableElement, isShortcutMatch } from '../utils/shortcuts';
const QueryEditor: React.FC<{ tab: TabData }> = ({ tab }) => {
const [query, setQuery] = useState(tab.query || 'SELECT * FROM ');
@@ -30,7 +33,9 @@ const QueryEditor: React.FC<{ tab: TabData }> = ({ tab }) => {
const [activeResultKey, setActiveResultKey] = useState<string>('');
const [loading, setLoading] = useState(false);
const [currentQueryId, setCurrentQueryId] = useState<string>('');
const runSeqRef = useRef(0);
const currentQueryIdRef = useRef('');
const [isSaveModalOpen, setIsSaveModalOpen] = useState(false);
const [saveForm] = Form.useForm();
@@ -65,6 +70,8 @@ const QueryEditor: React.FC<{ tab: TabData }> = ({ tab }) => {
const setSqlFormatOptions = useStore(state => state.setSqlFormatOptions);
const queryOptions = useStore(state => state.queryOptions);
const setQueryOptions = useStore(state => state.setQueryOptions);
const shortcutOptions = useStore(state => state.shortcutOptions);
const activeTabId = useStore(state => state.activeTabId);
useEffect(() => {
currentConnectionIdRef.current = currentConnectionId;
@@ -186,6 +193,17 @@ const QueryEditor: React.FC<{ tab: TabData }> = ({ tab }) => {
fetchMetadata();
}, [currentConnectionId, connections, dbList]); // dbList 变化时触发重新加载
// Query ID management helpers
const setQueryId = (id: string) => {
currentQueryIdRef.current = id;
setCurrentQueryId(id);
};
const clearQueryId = () => {
currentQueryIdRef.current = '';
setCurrentQueryId('');
};
// Handle Resizing
const handleMouseDown = (e: React.MouseEvent) => {
e.preventDefault();
@@ -254,6 +272,19 @@ const QueryEditor: React.FC<{ tab: TabData }> = ({ tab }) => {
return parts[parts.length - 1] || raw;
};
const splitSchemaAndTable = (qualified: string): { schema: string; table: string } => {
const raw = normalizeQualifiedName(qualified);
if (!raw) return { schema: '', table: '' };
const parts = raw.split('.').filter(Boolean);
if (parts.length >= 2) {
return {
schema: parts[parts.length - 2] || '',
table: parts[parts.length - 1] || '',
};
}
return { schema: '', table: parts[0] || '' };
};
const buildConnConfig = () => {
const connId = currentConnectionIdRef.current;
const conn = connectionsRef.current.find(c => c.id === connId);
@@ -326,13 +357,14 @@ const QueryEditor: React.FC<{ tab: TabData }> = ({ tab }) => {
if (qualifierMatch) {
const qualifier = stripQuotes(qualifierMatch[1]);
const prefix = (qualifierMatch[2] || '').toLowerCase();
const qualifierLower = qualifier.toLowerCase();
// 首先检查 qualifier 是否是数据库名(跨库表提示)
const visibleDbs = visibleDbsRef.current;
if (visibleDbs.some(db => db.toLowerCase() === qualifier.toLowerCase())) {
if (visibleDbs.some(db => db.toLowerCase() === qualifierLower)) {
// qualifier 是数据库名,提示该库的表
const tables = tablesRef.current.filter(t =>
(t.dbName || '').toLowerCase() === qualifier.toLowerCase()
(t.dbName || '').toLowerCase() === qualifierLower
);
const filtered = prefix
? tables.filter(t => (t.tableName || '').toLowerCase().startsWith(prefix))
@@ -349,6 +381,34 @@ const QueryEditor: React.FC<{ tab: TabData }> = ({ tab }) => {
return { suggestions };
}
// qualifier 是 schema如 dbo/public仅补全表名避免输入 dbo. 后再补成 dbo.dbo.table
const schemaTables = tablesRef.current
.map(t => {
const parsed = splitSchemaAndTable(t.tableName || '');
return {
dbName: t.dbName || '',
schema: parsed.schema,
table: parsed.table,
};
})
.filter(t => t.schema.toLowerCase() === qualifierLower && !!t.table);
if (schemaTables.length > 0) {
const filtered = prefix
? schemaTables.filter(t => t.table.toLowerCase().startsWith(prefix))
: schemaTables;
const suggestions = filtered.map(t => ({
label: t.table,
kind: monaco.languages.CompletionItemKind.Class,
insertText: t.table,
detail: `Table (${t.dbName}${t.schema ? '.' + t.schema : ''})`,
range,
sortText: '0' + t.table
}));
return { suggestions };
}
// 否则检查是否是表别名或表名,提示列
const reserved = new Set([
'where', 'on', 'group', 'order', 'limit', 'having',
@@ -517,6 +577,12 @@ const QueryEditor: React.FC<{ tab: TabData }> = ({ tab }) => {
icon: sqlFormatOptions.keywordCase === 'lower' ? '✓' : undefined,
onClick: () => setSqlFormatOptions({ keywordCase: 'lower' })
},
{ type: 'divider' },
{
key: 'shortcut-settings',
label: '快捷键管理...',
onClick: () => window.dispatchEvent(new CustomEvent('gonavi:open-shortcut-settings')),
},
];
const splitSQLStatements = (sql: string): string[] => {
@@ -984,6 +1050,16 @@ const QueryEditor: React.FC<{ tab: TabData }> = ({ tab }) => {
message.error("请先选择数据库");
return;
}
// 如果已有查询在运行,先取消它
if (currentQueryIdRef.current) {
try {
await CancelQuery(currentQueryIdRef.current);
} catch (error) {
// 忽略取消错误,可能查询已完成
}
// 清除旧查询ID
clearQueryId();
}
const runSeq = ++runSeqRef.current;
setLoading(true);
const runStartTime = Date.now();
@@ -1011,7 +1087,15 @@ const QueryEditor: React.FC<{ tab: TabData }> = ({ tab }) => {
try {
const rawSQL = getSelectedSQL() || query;
const statements = splitSQLStatements(rawSQL);
const dbType = String((config as any).type || 'mysql');
const normalizedDbType = dbType.trim().toLowerCase();
const normalizedRawSQL = String(rawSQL || '').replace(//g, ';');
const splitInput = normalizedDbType === 'mongodb'
? normalizedRawSQL
.replace(/^\s*\/\/.*$/gm, '')
.replace(/^\s*#.*$/gm, '')
: normalizedRawSQL;
const statements = splitSQLStatements(splitInput);
if (statements.length === 0) {
message.info('没有可执行的 SQL。');
setResultSets([]);
@@ -1021,7 +1105,6 @@ const QueryEditor: React.FC<{ tab: TabData }> = ({ tab }) => {
const nextResultSets: ResultSet[] = [];
const maxRows = Number(queryOptions?.maxRows) || 0;
const dbType = String((config as any).type || 'mysql');
const forceReadOnlyResult = connCaps.forceReadOnlyQueryResult;
const wantsLimitProbe = Number.isFinite(maxRows) && maxRows > 0;
const probeLimit = wantsLimitProbe ? (maxRows + 1) : 0;
@@ -1035,9 +1118,35 @@ const QueryEditor: React.FC<{ tab: TabData }> = ({ tab }) => {
const limitApplied = shouldAutoLimit && wantsLimitProbe;
const limited = limitApplied ? applyAutoLimit(rawStatement, dbType, probeLimit) : { sql: rawStatement, applied: false, maxRows: probeLimit };
const executedSql = limited.sql;
let executedSql = limited.sql;
if (String(dbType || '').trim().toLowerCase() === 'mongodb') {
const shellConvert = convertMongoShellToJsonCommand(executedSql);
if (shellConvert.recognized) {
if (shellConvert.error) {
const prefix = statements.length > 1 ? `${idx + 1} 条语句执行失败:` : '';
message.error(prefix + shellConvert.error);
setResultSets([]);
setActiveResultKey('');
return;
}
if (shellConvert.command) {
executedSql = shellConvert.command;
}
}
}
const startTime = Date.now();
const res = await DBQuery(config as any, currentDb, executedSql);
// Generate query ID for cancellation using backend UUID with fallback
let queryId: string;
try {
queryId = await GenerateQueryID();
} catch (error) {
console.warn('GenerateQueryID failed, using local UUID fallback:', error);
queryId = 'query-' + uuidv4();
}
setQueryId(queryId);
const res = await DBQueryWithCancel(config as any, currentDb, executedSql, queryId);
const duration = Date.now() - startTime;
addSqlLog({
@@ -1052,6 +1161,32 @@ const QueryEditor: React.FC<{ tab: TabData }> = ({ tab }) => {
});
if (!res.success) {
// 检查是否为查询取消错误
const errorMsg = res.message.toLowerCase();
const isCancelledError = errorMsg.includes('context canceled') ||
errorMsg.includes('查询已取消') ||
errorMsg.includes('canceled') ||
errorMsg.includes('cancelled') ||
errorMsg.includes('statement canceled') ||
errorMsg.includes('sql: statement canceled');
// 确保不是超时错误
const isTimeoutError = errorMsg.includes('context deadline exceeded') ||
errorMsg.includes('timeout') ||
errorMsg.includes('超时') ||
errorMsg.includes('deadline exceeded');
if (isCancelledError && !isTimeoutError) {
// 查询已被用户取消,不显示错误消息,清理状态
setResultSets([]);
setActiveResultKey('');
// 清除查询ID与handleCancel保持一致
if (currentQueryIdRef.current) {
clearQueryId();
}
return;
}
const prefix = statements.length > 1 ? `${idx + 1} 条语句执行失败:` : '';
message.error(prefix + res.message);
setResultSets([]);
@@ -1157,9 +1292,75 @@ const QueryEditor: React.FC<{ tab: TabData }> = ({ tab }) => {
setActiveResultKey('');
} finally {
if (runSeqRef.current === runSeq) setLoading(false);
// Clear query ID after execution completes
clearQueryId();
}
};
const handleCancel = async () => {
if (!currentQueryIdRef.current) {
message.warning('没有正在运行的查询可取消');
return;
}
const queryIdToCancel = currentQueryIdRef.current;
try {
const res = await CancelQuery(queryIdToCancel);
if (res.success) {
message.success('查询已取消');
// Clear query ID after successful cancellation
if (currentQueryIdRef.current === queryIdToCancel) {
clearQueryId()
}
} else {
message.warning(res.message);
}
} catch (error: any) {
message.error('取消查询失败: ' + error.message);
}
};
useEffect(() => {
const binding = shortcutOptions.runQuery;
if (!binding?.enabled || !binding.combo) {
return;
}
const handleRunShortcut = (event: KeyboardEvent) => {
if (activeTabId !== tab.id) {
return;
}
if (!isShortcutMatch(event, binding.combo)) {
return;
}
const editorHasFocus = !!editorRef.current?.hasTextFocus?.();
if (!editorHasFocus && !isEditableElement(event.target)) {
return;
}
event.preventDefault();
event.stopPropagation();
void handleRun();
};
window.addEventListener('keydown', handleRunShortcut);
return () => {
window.removeEventListener('keydown', handleRunShortcut);
};
}, [activeTabId, tab.id, shortcutOptions.runQuery, handleRun]);
useEffect(() => {
const handleRunActiveQuery = () => {
if (activeTabId !== tab.id) {
return;
}
void handleRun();
};
window.addEventListener('gonavi:run-active-query', handleRunActiveQuery as EventListener);
return () => {
window.removeEventListener('gonavi:run-active-query', handleRunActiveQuery as EventListener);
};
}, [activeTabId, tab.id, handleRun]);
const handleSave = async () => {
try {
const values = await saveForm.validateFields();
@@ -1271,9 +1472,24 @@ const QueryEditor: React.FC<{ tab: TabData }> = ({ tab }) => {
]}
/>
</Tooltip>
<Button type="primary" icon={<PlayCircleOutlined />} onClick={handleRun} loading={loading}>
</Button>
<Button.Group>
<Tooltip
title={
shortcutOptions.runQuery?.enabled && shortcutOptions.runQuery?.combo
? `运行(${getShortcutDisplay(shortcutOptions.runQuery.combo)}`
: '运行'
}
>
<Button type="primary" icon={<PlayCircleOutlined />} onClick={handleRun} loading={loading}>
</Button>
</Tooltip>
{loading && (
<Button type="primary" danger icon={<StopOutlined />} onClick={handleCancel}>
</Button>
)}
</Button.Group>
<Button icon={<SaveOutlined />} onClick={() => {
saveForm.setFieldsValue({ name: tab.title.replace('Query (', '').replace(')', '') });
setIsSaveModalOpen(true);

View File

@@ -1,11 +1,12 @@
import React, { useEffect, useState, useMemo, useRef } from 'react';
import { Tree, message, Dropdown, MenuProps, Input, Button, Modal, Form, Badge, Checkbox, Space, Select } from 'antd';
import { Tree, message, Dropdown, MenuProps, Input, Button, Modal, Form, Badge, Checkbox, Space, Select, Popover, Tooltip } from 'antd';
import {
DatabaseOutlined,
TableOutlined,
EyeOutlined,
ConsoleSqlOutlined,
HddOutlined,
FolderOutlined,
FolderOpenOutlined,
FileTextOutlined,
CopyOutlined,
@@ -42,13 +43,14 @@ interface TreeNode {
children?: TreeNode[];
icon?: React.ReactNode;
dataRef?: any;
type?: 'connection' | 'database' | 'table' | 'view' | 'db-trigger' | 'routine' | 'object-group' | 'queries-folder' | 'saved-query' | 'folder-columns' | 'folder-indexes' | 'folder-fks' | 'folder-triggers' | 'redis-db';
type?: 'connection' | 'database' | 'table' | 'view' | 'db-trigger' | 'routine' | 'object-group' | 'queries-folder' | 'saved-query' | 'folder-columns' | 'folder-indexes' | 'folder-fks' | 'folder-triggers' | 'redis-db' | 'tag';
}
type BatchTableExportMode = 'schema' | 'backup' | 'dataOnly';
type BatchObjectType = 'table' | 'view';
type BatchObjectFilterType = 'all' | BatchObjectType;
type BatchSelectionScope = 'filtered' | 'all';
type SearchScope = 'smart' | 'object' | 'database' | 'host' | 'tag';
interface BatchObjectItem {
title: string;
@@ -58,12 +60,32 @@ interface BatchObjectItem {
dataRef: any;
}
const SEARCH_SCOPE_OPTIONS: Array<{ value: SearchScope; label: string }> = [
{ value: 'smart', label: '智能' },
{ value: 'object', label: '表对象' },
{ value: 'database', label: '库' },
{ value: 'host', label: 'Host' },
{ value: 'tag', label: '标签' },
];
const SEARCH_SCOPE_LABEL_MAP: Record<SearchScope, string> = SEARCH_SCOPE_OPTIONS.reduce((acc, option) => {
acc[option.value] = option.label;
return acc;
}, {} as Record<SearchScope, string>);
const Sidebar: React.FC<{ onEditConnection?: (conn: SavedConnection) => void }> = ({ onEditConnection }) => {
const connections = useStore(state => state.connections);
const savedQueries = useStore(state => state.savedQueries);
const addConnection = useStore(state => state.addConnection);
const addTab = useStore(state => state.addTab);
const setActiveContext = useStore(state => state.setActiveContext);
const removeConnection = useStore(state => state.removeConnection);
const connectionTags = useStore(state => state.connectionTags);
const addConnectionTag = useStore(state => state.addConnectionTag);
const updateConnectionTag = useStore(state => state.updateConnectionTag);
const removeConnectionTag = useStore(state => state.removeConnectionTag);
const moveConnectionToTag = useStore(state => state.moveConnectionToTag);
const reorderTags = useStore(state => state.reorderTags);
const closeTabsByConnection = useStore(state => state.closeTabsByConnection);
const closeTabsByDatabase = useStore(state => state.closeTabsByDatabase);
const theme = useStore(state => state.theme);
@@ -87,6 +109,9 @@ const Sidebar: React.FC<{ onEditConnection?: (conn: SavedConnection) => void }>
};
const bgMain = getBg('#141414');
const [searchValue, setSearchValue] = useState('');
const [searchScopes, setSearchScopes] = useState<SearchScope[]>(['smart']);
const [isSearchScopePopoverOpen, setIsSearchScopePopoverOpen] = useState(false);
const searchInputRef = useRef<any>(null);
const [expandedKeys, setExpandedKeys] = useState<React.Key[]>([]);
const [autoExpandParent, setAutoExpandParent] = useState(true);
const [loadedKeys, setLoadedKeys] = useState<React.Key[]>([]);
@@ -109,6 +134,21 @@ const Sidebar: React.FC<{ onEditConnection?: (conn: SavedConnection) => void }>
resizeObserver.observe(treeContainerRef.current);
return () => resizeObserver.disconnect();
}, []);
useEffect(() => {
const handleFocusSidebarSearch = () => {
const inputEl = searchInputRef.current?.input as HTMLInputElement | undefined;
if (!inputEl) {
return;
}
inputEl.focus();
inputEl.select();
};
window.addEventListener('gonavi:focus-sidebar-search', handleFocusSidebarSearch as EventListener);
return () => {
window.removeEventListener('gonavi:focus-sidebar-search', handleFocusSidebarSearch as EventListener);
};
}, []);
// Connection Status State: key -> 'success' | 'error'
const [connectionStates, setConnectionStates] = useState<Record<string, 'success' | 'error'>>({});
@@ -127,6 +167,10 @@ const Sidebar: React.FC<{ onEditConnection?: (conn: SavedConnection) => void }>
const [renameViewForm] = Form.useForm();
const [renameViewTarget, setRenameViewTarget] = useState<any>(null);
// Connection Tag Modals
const [isCreateTagModalOpen, setIsCreateTagModalOpen] = useState(false);
const [createTagForm] = Form.useForm();
// Batch Operations Modal
const [isBatchModalOpen, setIsBatchModalOpen] = useState(false);
const [batchTables, setBatchTables] = useState<BatchObjectItem[]>([]);
@@ -208,11 +252,21 @@ const Sidebar: React.FC<{ onEditConnection?: (conn: SavedConnection) => void }>
useEffect(() => {
setTreeData((prev) => {
const prevMap = new Map<string, TreeNode>();
prev.forEach((node) => {
prevMap.set(String(node.key), node);
});
return connections.map((conn) => {
// We need to recursively extract connections from old tag structures
// so if a user expands a connection that was tagged, the state remains
const recurseCollect = (nodes: TreeNode[]) => {
nodes.forEach((node) => {
if (node.type === 'tag') {
if (node.children) recurseCollect(node.children);
} else if (node.type === 'connection') {
prevMap.set(String(node.key), node);
}
});
};
recurseCollect(prev);
const buildConnectionNode = (conn: SavedConnection): TreeNode => {
const existing = prevMap.get(conn.id);
return {
title: conn.name,
@@ -223,10 +277,145 @@ const Sidebar: React.FC<{ onEditConnection?: (conn: SavedConnection) => void }>
isLeaf: false,
children: existing?.children,
} as TreeNode;
});
});
}, [connections]);
};
const taggedConnIds = new Set<string>();
const tagNodes: TreeNode[] = connectionTags.map((tag) => {
tag.connectionIds.forEach(id => taggedConnIds.add(id));
return {
title: tag.name,
key: `tag-${tag.id}`,
icon: <FolderOutlined style={{ color: '#faad14' }} />,
type: 'tag',
dataRef: tag,
isLeaf: false,
children: tag.connectionIds
.map(cid => connections.find(c => c.id === cid))
.filter(Boolean)
.map(conn => buildConnectionNode(conn!)),
} as TreeNode;
});
const ungroupedNodes: TreeNode[] = connections
.filter(c => !taggedConnIds.has(c.id))
.map(conn => buildConnectionNode(conn));
return [...tagNodes, ...ungroupedNodes];
});
}, [connections, connectionTags]);
const buildDuplicateConnectionName = (rawName: string): string => {
const baseName = String(rawName || '').trim() || '连接';
const suffix = ' - 副本';
const usedNames = new Set(connections.map(conn => String(conn.name || '').trim()));
let candidate = `${baseName}${suffix}`;
let counter = 2;
while (usedNames.has(candidate)) {
candidate = `${baseName}${suffix} ${counter}`;
counter += 1;
}
return candidate;
};
const cloneConnectionConfig = (config: SavedConnection['config']): SavedConnection['config'] => {
const raw: any = config || {};
let cloned: any = {};
try {
cloned = typeof structuredClone === 'function'
? structuredClone(raw)
: JSON.parse(JSON.stringify(raw));
} catch {
cloned = { ...raw };
}
const readString = (...values: unknown[]): string => {
for (const value of values) {
if (typeof value === 'string') {
return value;
}
}
return '';
};
const readBool = (fallback: boolean, ...values: unknown[]): boolean => {
for (const value of values) {
if (typeof value === 'boolean') {
return value;
}
}
return fallback;
};
const readNumber = (fallback: number, ...values: unknown[]): number => {
for (const value of values) {
const num = Number(value);
if (Number.isFinite(num)) {
return num;
}
}
return fallback;
};
const rawSSH = (cloned.ssh ?? cloned.SSH ?? {}) as Record<string, unknown>;
const normalizedSSH = {
host: readString(rawSSH.host, rawSSH.Host, cloned.sshHost, cloned.SSHHost),
port: readNumber(22, rawSSH.port, rawSSH.Port, cloned.sshPort, cloned.SSHPort),
user: readString(rawSSH.user, rawSSH.User, cloned.sshUser, cloned.SSHUser),
password: readString(rawSSH.password, rawSSH.Password, cloned.sshPassword, cloned.SSHPassword),
keyPath: readString(rawSSH.keyPath, rawSSH.KeyPath, cloned.sshKeyPath, cloned.SSHKeyPath),
};
const hasSSHDetail = Boolean(
normalizedSSH.host
|| normalizedSSH.user
|| normalizedSSH.password
|| normalizedSSH.keyPath
);
const rawProxy = (cloned.proxy ?? cloned.Proxy ?? {}) as Record<string, unknown>;
const proxyTypeRaw = readString(rawProxy.type, rawProxy.Type, cloned.proxyType, cloned.ProxyType).toLowerCase();
const proxyType: 'socks5' | 'http' = proxyTypeRaw === 'http' ? 'http' : 'socks5';
const normalizedProxy = {
type: proxyType,
host: readString(rawProxy.host, rawProxy.Host, cloned.proxyHost, cloned.ProxyHost),
port: readNumber(proxyType === 'http' ? 8080 : 1080, rawProxy.port, rawProxy.Port, cloned.proxyPort, cloned.ProxyPort),
user: readString(rawProxy.user, rawProxy.User, cloned.proxyUser, cloned.ProxyUser),
password: readString(rawProxy.password, rawProxy.Password, cloned.proxyPassword, cloned.ProxyPassword),
};
const hasProxyDetail = Boolean(normalizedProxy.host || normalizedProxy.user || normalizedProxy.password);
const rawHosts = Array.isArray(cloned.hosts)
? cloned.hosts
: (Array.isArray(cloned.Hosts) ? cloned.Hosts : []);
const normalizedHosts = rawHosts
.map((entry: unknown) => String(entry || '').trim())
.filter((entry: string) => !!entry);
return {
...(cloned as SavedConnection['config']),
useSSH: readBool(hasSSHDetail, cloned.useSSH, cloned.UseSSH),
ssh: normalizedSSH,
useProxy: readBool(hasProxyDetail, cloned.useProxy, cloned.UseProxy),
proxy: normalizedProxy,
hosts: normalizedHosts,
timeout: readNumber(30, cloned.timeout, cloned.Timeout),
};
};
const handleDuplicateConnection = (conn: SavedConnection) => {
if (!conn) return;
const duplicatedConnection: SavedConnection = {
...conn,
id: `${Date.now()}-${Math.random().toString(36).slice(2, 8)}`,
name: buildDuplicateConnectionName(conn.name),
config: cloneConnectionConfig(conn.config),
includeDatabases: conn.includeDatabases ? [...conn.includeDatabases] : undefined,
includeRedisDatabases: conn.includeRedisDatabases ? [...conn.includeRedisDatabases] : undefined,
};
addConnection(duplicatedConnection);
message.success(`已复制连接: ${duplicatedConnection.name}`);
};
const updateTreeData = (list: TreeNode[], key: React.Key, children: TreeNode[] | undefined): TreeNode[] => {
return list.map(node => {
if (node.key === key) {
@@ -705,7 +894,8 @@ const Sidebar: React.FC<{ onEditConnection?: (conn: SavedConnection) => void }>
const res = await (window as any).go.app.App.RedisGetDatabases(config);
if (res.success) {
setConnectionStates(prev => ({ ...prev, [conn.id]: 'success' }));
let dbs = (res.data as any[]).map((db: any) => ({
const redisRows: any[] = Array.isArray(res.data) ? res.data : [];
let dbs = redisRows.map((db: any) => ({
title: `db${db.index}${db.keys > 0 ? ` (${db.keys})` : ''}`,
key: `${conn.id}-db${db.index}`,
icon: <DatabaseOutlined style={{ color: '#DC382D' }} />,
@@ -736,7 +926,8 @@ const Sidebar: React.FC<{ onEditConnection?: (conn: SavedConnection) => void }>
const res = await DBGetDatabases(config as any);
if (res.success) {
setConnectionStates(prev => ({ ...prev, [conn.id]: 'success' }));
let dbs = (res.data as any[]).map((row: any) => ({
const dbRows: any[] = Array.isArray(res.data) ? res.data : [];
let dbs = dbRows.map((row: any) => ({
title: row.Database || row.database,
key: `${conn.id}-${row.Database || row.database}`,
icon: <DatabaseOutlined />,
@@ -751,10 +942,13 @@ const Sidebar: React.FC<{ onEditConnection?: (conn: SavedConnection) => void }>
}
setTreeData(origin => updateTreeData(origin, node.key, dbs));
} else {
setConnectionStates(prev => ({ ...prev, [conn.id]: 'error' }));
message.error({ content: res.message, key: `conn-${conn.id}-dbs` });
}
} else {
setConnectionStates(prev => ({ ...prev, [conn.id]: 'error' }));
message.error({ content: res.message, key: `conn-${conn.id}-dbs` });
}
} catch (e: any) {
setConnectionStates(prev => ({ ...prev, [conn.id]: 'error' }));
message.error({ content: '连接失败: ' + (e?.message || String(e)), key: `conn-${conn.id}-dbs` });
} finally {
loadingNodesRef.current.delete(loadKey);
}
@@ -799,7 +993,8 @@ const Sidebar: React.FC<{ onEditConnection?: (conn: SavedConnection) => void }>
if (res.success) {
setConnectionStates(prev => ({ ...prev, [key as string]: 'success' }));
const tableEntries = (res.data as any[]).map((row: any) => {
const tableRows: any[] = Array.isArray(res.data) ? res.data : [];
const tableEntries = tableRows.map((row: any) => {
const tableName = Object.values(row)[0] as string;
const parsed = splitQualifiedName(tableName);
return {
@@ -815,7 +1010,11 @@ const Sidebar: React.FC<{ onEditConnection?: (conn: SavedConnection) => void }>
loadFunctions(conn, conn.dbName),
]);
const viewEntries = viewsResult.views.map((viewName) => {
const viewRows: string[] = Array.isArray(viewsResult.views) ? viewsResult.views : [];
const triggerRows: any[] = Array.isArray(triggersResult.triggers) ? triggersResult.triggers : [];
const routineRows: any[] = Array.isArray(routinesResult.routines) ? routinesResult.routines : [];
const viewEntries = viewRows.map((viewName: string) => {
const parsed = splitQualifiedName(viewName);
return {
viewName,
@@ -829,7 +1028,7 @@ const Sidebar: React.FC<{ onEditConnection?: (conn: SavedConnection) => void }>
const triggerSeen = new Set<string>();
const metadataDialect = getMetadataDialect(conn as SavedConnection);
triggersResult.triggers.forEach((trigger) => {
triggerRows.forEach((trigger: any) => {
const triggerParsed = splitQualifiedName(trigger.triggerName);
const tableParsed = splitQualifiedName(trigger.tableName);
const schemaName = tableParsed.schemaName || triggerParsed.schemaName || String(conn.dbName || '').trim();
@@ -854,7 +1053,7 @@ const Sidebar: React.FC<{ onEditConnection?: (conn: SavedConnection) => void }>
return deduped;
})();
const routineEntries = routinesResult.routines.map((routine) => {
const routineEntries = routineRows.map((routine: any) => {
const parsed = splitQualifiedName(routine.routineName);
const typeLabel = routine.routineType === 'PROCEDURE' ? 'P' : 'F';
return {
@@ -1036,12 +1235,16 @@ const Sidebar: React.FC<{ onEditConnection?: (conn: SavedConnection) => void }>
setConnectionStates(prev => ({ ...prev, [key as string]: 'error' }));
message.error({ content: res.message, key: `db-${key}-tables` });
}
} catch (e: any) {
setConnectionStates(prev => ({ ...prev, [key as string]: 'error' }));
message.error({ content: '加载表失败: ' + (e?.message || String(e)), key: `db-${key}-tables` });
} finally {
loadingNodesRef.current.delete(loadKey);
}
};
const onLoadData = async ({ key, children, dataRef, type }: any) => {
if (type === 'tag') return;
if (children) return;
if (type === 'connection') {
@@ -1380,7 +1583,8 @@ const Sidebar: React.FC<{ onEditConnection?: (conn: SavedConnection) => void }>
const res = await DBGetDatabases(config as any);
if (res.success) {
let dbs = (res.data as any[]).map((row: any) => {
const dbRows: any[] = Array.isArray(res.data) ? res.data : [];
let dbs = dbRows.map((row: any) => {
const dbName = row.Database || row.database;
return {
title: dbName,
@@ -1421,9 +1625,11 @@ const Sidebar: React.FC<{ onEditConnection?: (conn: SavedConnection) => void }>
return;
}
const viewSet = new Set(viewResult.views.map(view => view.toLowerCase()));
const tableRows: any[] = Array.isArray(res.data) ? res.data : [];
const viewRows: string[] = Array.isArray(viewResult.views) ? viewResult.views : [];
const viewSet = new Set(viewRows.map((view: string) => view.toLowerCase()));
const tableObjects: BatchObjectItem[] = (res.data as any[])
const tableObjects: BatchObjectItem[] = tableRows
.map((row: any) => Object.values(row)[0] as string)
.filter((tableName: string) => !viewSet.has(tableName.toLowerCase()))
.map((tableName: string) => ({
@@ -1434,7 +1640,7 @@ const Sidebar: React.FC<{ onEditConnection?: (conn: SavedConnection) => void }>
dataRef: { ...conn, tableName, dbName, objectType: 'table' },
}));
const viewObjects: BatchObjectItem[] = viewResult.views.map((viewName: string) => ({
const viewObjects: BatchObjectItem[] = viewRows.map((viewName: string) => ({
title: getSidebarTableDisplayName(conn, viewName),
key: `${conn.id}-${dbName}-view-${viewName}`,
objectName: viewName,
@@ -1600,7 +1806,8 @@ const Sidebar: React.FC<{ onEditConnection?: (conn: SavedConnection) => void }>
const res = await DBGetDatabases(config as any);
if (res.success) {
let dbs = (res.data as any[]).map((row: any) => {
const dbRows: any[] = Array.isArray(res.data) ? res.data : [];
let dbs = dbRows.map((row: any) => {
const dbName = row.Database || row.database;
return {
title: dbName,
@@ -2187,28 +2394,205 @@ const Sidebar: React.FC<{ onEditConnection?: (conn: SavedConnection) => void }>
setSearchValue(value);
};
const loop = (data: TreeNode[]): TreeNode[] => {
const result: TreeNode[] = [];
data.forEach(item => {
const match = item.title.toLowerCase().indexOf(searchValue.toLowerCase()) > -1;
if (item.children) {
const filteredChildren = loop(item.children);
if (filteredChildren.length > 0 || match) {
result.push({ ...item, children: filteredChildren });
}
const toggleSearchScope = (scope: SearchScope) => {
setSearchScopes((prev) => {
if (scope === 'smart') {
return ['smart'];
}
const withoutSmart = prev.filter((item) => item !== 'smart');
if (withoutSmart.includes(scope)) {
const next = withoutSmart.filter((item) => item !== scope);
return next.length > 0 ? next : ['smart'];
}
return [...withoutSmart, scope];
});
};
const setSearchScopeChecked = (scope: SearchScope, checked: boolean) => {
if (scope === 'smart') {
if (checked) {
setSearchScopes(['smart']);
} else if (searchScopes.length === 1 && searchScopes[0] === 'smart') {
setSearchScopes(['smart']);
} else {
if (match) {
setSearchScopes((prev) => {
const next = prev.filter((item) => item !== 'smart');
return next.length > 0 ? next : ['smart'];
});
}
return;
}
if (checked) {
setSearchScopes((prev) => {
const withoutSmart = prev.filter((item) => item !== 'smart');
if (withoutSmart.includes(scope)) {
return withoutSmart;
}
return [...withoutSmart, scope];
});
} else {
setSearchScopes((prev) => {
const next = prev.filter((item) => item !== scope && item !== 'smart');
return next.length > 0 ? next : ['smart'];
});
}
};
const searchScopeSummary = useMemo(() => {
if (searchScopes.includes('smart')) {
return '智能';
}
return searchScopes.map((scope) => SEARCH_SCOPE_LABEL_MAP[scope]).join(' + ');
}, [searchScopes]);
const searchScopePopoverContent = useMemo(() => {
const smartSelected = searchScopes.includes('smart');
const scopedOptions = SEARCH_SCOPE_OPTIONS.filter((option) => option.value !== 'smart');
return (
<div style={{ minWidth: 220, display: 'flex', flexDirection: 'column', gap: 8 }}>
<div style={{ fontSize: 12, color: '#8c8c8c' }}></div>
<Checkbox
checked={smartSelected}
onChange={(e) => setSearchScopeChecked('smart', e.target.checked)}
>
</Checkbox>
<div style={{ paddingLeft: 12, display: 'grid', gap: 6 }}>
{scopedOptions.map((option) => (
<Checkbox
key={option.value}
checked={searchScopes.includes(option.value)}
onChange={(e) => setSearchScopeChecked(option.value, e.target.checked)}
>
{option.label}
</Checkbox>
))}
</div>
<div style={{ fontSize: 12, color: '#8c8c8c' }}>
</div>
</div>
);
}, [searchScopes]);
const parseHostOnlyToken = (value: unknown): string[] => {
const raw = String(value || '').trim();
if (!raw) {
return [];
}
let text = raw.replace(/^[a-z][a-z0-9+.-]*:\/\//i, '');
if (text.includes('/')) {
text = text.split('/')[0];
}
if (text.includes('?')) {
text = text.split('?')[0];
}
if (text.includes('@')) {
text = text.split('@').pop() || '';
}
return text
.split(',')
.map((entry) => {
const token = entry.trim();
if (!token) return '';
if (token.startsWith('[')) {
const rightBracketIndex = token.indexOf(']');
if (rightBracketIndex > 0) {
return token.slice(0, rightBracketIndex + 1).toLowerCase();
}
}
const colonIndex = token.lastIndexOf(':');
if (colonIndex > 0) {
return token.slice(0, colonIndex).toLowerCase();
}
return token.toLowerCase();
})
.filter(Boolean);
};
const getConnectionHostSearchText = (node: TreeNode): string => {
if (node.type !== 'connection') return '';
const config = node.dataRef?.config || {};
const hostTokens = [
...parseHostOnlyToken(config.host),
...(Array.isArray(config.hosts) ? config.hosts.flatMap((entry: string) => parseHostOnlyToken(entry)) : []),
...parseHostOnlyToken(config.uri),
];
const uniqueHosts = Array.from(new Set(hostTokens));
return uniqueHosts.join(' ');
};
const getConnectionNameSearchText = (node: TreeNode): string => {
if (node.type !== 'connection') return '';
const name = node.dataRef?.name ?? node.title;
return String(name || '').toLowerCase();
};
const isObjectNode = (node: TreeNode): boolean => {
return node.type === 'table'
|| node.type === 'view'
|| node.type === 'db-trigger'
|| node.type === 'routine'
|| node.type === 'object-group';
};
const matchByScopes = (node: TreeNode, keyword: string, scopes: SearchScope[]): boolean => {
const title = String(node.title || '').toLowerCase();
if (scopes.includes('database') && node.type === 'database' && title.includes(keyword)) {
return true;
}
if (scopes.includes('tag') && node.type === 'tag' && title.includes(keyword)) {
return true;
}
if (scopes.includes('host') && node.type === 'connection' && getConnectionHostSearchText(node).includes(keyword)) {
return true;
}
if (scopes.includes('object') && isObjectNode(node) && title.includes(keyword)) {
return true;
}
return false;
};
const loop = (data: TreeNode[], keyword: string): TreeNode[] => {
const isSmartMode = searchScopes.includes('smart');
const result: TreeNode[] = [];
data.forEach((item) => {
const titleMatch = String(item.title || '').toLowerCase().includes(keyword);
const smartMatch = item.type === 'connection'
? getConnectionNameSearchText(item).includes(keyword) || getConnectionHostSearchText(item).includes(keyword)
: titleMatch;
const scopedMatch = matchByScopes(item, keyword, searchScopes);
const selfMatch = isSmartMode ? smartMatch : scopedMatch;
const filteredChildren = item.children ? loop(item.children, keyword) : [];
if (selfMatch) {
const shouldKeepFullSubtree = isSmartMode
|| item.type === 'connection'
|| item.type === 'database'
|| item.type === 'tag';
if (item.children && shouldKeepFullSubtree) {
result.push(item);
} else if (item.children && filteredChildren.length > 0) {
result.push({ ...item, children: filteredChildren });
} else {
result.push(item);
}
return;
}
if (filteredChildren.length > 0) {
result.push({ ...item, children: filteredChildren });
}
});
return result;
};
const displayTreeData = useMemo(() => {
if (!searchValue) return treeData;
return loop(treeData);
}, [searchValue, treeData]);
const keyword = searchValue.trim().toLowerCase();
if (!keyword) return treeData;
return loop(treeData, keyword);
}, [searchValue, searchScopes, treeData]);
const getNodeMenuItems = (node: any): MenuProps['items'] => {
const conn = node.dataRef as SavedConnection;
@@ -2284,6 +2668,38 @@ const Sidebar: React.FC<{ onEditConnection?: (conn: SavedConnection) => void }>
return routineMenu;
}
// Connection Tag Menu — must be BEFORE the connection check
if (node.type === 'tag') {
return [
{
key: 'edit-tag',
label: '编辑标签',
icon: <EditOutlined />,
onClick: () => {
createTagForm.setFieldsValue({ name: node.title, connectionIds: node.dataRef.connectionIds });
setRenameViewTarget(node);
setIsCreateTagModalOpen(true);
}
},
{ type: 'divider' },
{
key: 'delete-tag',
label: '删除标签',
icon: <DeleteOutlined />,
danger: true,
onClick: () => {
Modal.confirm({
title: '确认删除',
content: `确定要删除标签 "${node.title}" 吗?这不会删除里面的连接。`,
onOk: () => {
removeConnectionTag(node.dataRef.id);
}
});
}
}
];
}
if (node.type === 'connection') {
// Redis connection menu
if (isRedis) {
@@ -2318,6 +2734,12 @@ const Sidebar: React.FC<{ onEditConnection?: (conn: SavedConnection) => void }>
if (onEditConnection) onEditConnection(node.dataRef);
}
},
{
key: 'copy-connection',
label: '复制连接',
icon: <CopyOutlined />,
onClick: () => handleDuplicateConnection(node.dataRef as SavedConnection)
},
{
key: 'disconnect',
label: '断开连接',
@@ -2358,6 +2780,22 @@ const Sidebar: React.FC<{ onEditConnection?: (conn: SavedConnection) => void }>
];
}
// Tag submenu for connection
const tagSubMenuItems: MenuProps['items'] = connectionTags.map(tag => ({
key: `move-to-tag-${tag.id}`,
label: tag.name,
icon: <FolderOutlined />,
onClick: () => moveConnectionToTag(node.key, tag.id)
}));
if (connectionTags.length > 0) {
tagSubMenuItems.push({ type: 'divider' });
}
tagSubMenuItems.push({
key: 'move-to-ungrouped',
label: '移出标签',
onClick: () => moveConnectionToTag(node.key, null)
});
// Regular database connection menu
return [
{
@@ -2400,11 +2838,30 @@ const Sidebar: React.FC<{ onEditConnection?: (conn: SavedConnection) => void }>
if (onEditConnection) onEditConnection(node.dataRef);
}
},
{
key: 'copy-connection',
label: '复制连接',
icon: <CopyOutlined />,
onClick: () => handleDuplicateConnection(node.dataRef as SavedConnection)
},
{
key: 'move-to-tag',
label: '移至标签',
icon: <FolderOpenOutlined />,
children: tagSubMenuItems
},
{
key: 'disconnect',
label: '断开连接',
icon: <DisconnectOutlined />,
onClick: () => {
const connId = String(node.key || '');
// 强制清理该连接相关的 loading 标记,避免网络卡住后重连仍被短路。
Array.from(loadingNodesRef.current).forEach((loadingKey) => {
if (loadingKey === `dbs-${connId}` || loadingKey.startsWith(`tables-${connId}-`)) {
loadingNodesRef.current.delete(loadingKey);
}
});
// Reset status recursively
setConnectionStates(prev => {
const next = { ...prev };
@@ -2526,6 +2983,7 @@ const Sidebar: React.FC<{ onEditConnection?: (conn: SavedConnection) => void }>
onClick: () => {
const dbConnId = String(node.dataRef?.id || '');
const dbName = String(node.dataRef?.dbName || node.title || '').trim();
loadingNodesRef.current.delete(`tables-${dbConnId}-${dbName}`);
setConnectionStates(prev => {
const next = { ...prev };
delete next[node.key];
@@ -2707,6 +3165,7 @@ const Sidebar: React.FC<{ onEditConnection?: (conn: SavedConnection) => void }>
{ key: 'export-xlsx', label: '导出 Excel (XLSX)', onClick: () => handleExport(node, 'xlsx') },
{ key: 'export-json', label: '导出 JSON', onClick: () => handleExport(node, 'json') },
{ key: 'export-md', label: '导出 Markdown', onClick: () => handleExport(node, 'md') },
{ key: 'export-html', label: '导出 HTML', onClick: () => handleExport(node, 'html') },
]
}
];
@@ -2741,6 +3200,72 @@ const Sidebar: React.FC<{ onEditConnection?: (conn: SavedConnection) => void }>
return <span title={hoverTitle}>{statusBadge}{displayTitle}</span>;
};
const handleDrop = (info: any) => {
const dropKey = info.node.key;
const dragKey = info.dragNode.key;
const dropPos = info.node.pos.split('-');
const dropPosition = info.dropPosition - Number(dropPos[dropPos.length - 1]);
const dragNode = info.dragNode;
const dropNode = info.node;
// Tag to Tag reordering
if (dragNode.type === 'tag') {
// You can only drop tags onto the root level (before/after other tags or connections at root)
if (dropNode.type === 'tag' || dropNode.type === 'connection') {
// Get current order
const currentTagOrder = connectionTags.map(t => t.id);
const dragTagId = dragNode.dataRef.id;
// Filter out the dragging tag
const newOrder = currentTagOrder.filter(id => id !== dragTagId);
let insertIndex = newOrder.length;
if (dropNode.type === 'tag') {
const dropTagId = dropNode.dataRef.id;
const dropIndex = newOrder.indexOf(dropTagId);
if (dropPosition === -1) {
insertIndex = dropIndex;
} else {
insertIndex = dropIndex + 1;
}
} else {
// Dropped onto a root connection, usually meaning moving to the end of tags
// Since tags are always displayed before ungrouped connections, just put it at the end
insertIndex = newOrder.length;
}
newOrder.splice(insertIndex, 0, dragTagId);
reorderTags(newOrder);
}
return;
}
// Connection moving to tag (any drop position on a tag node counts as "into")
if (dragNode.type === 'connection' && dropNode.type === 'tag') {
moveConnectionToTag(dragNode.key, dropNode.dataRef.id);
return;
}
// Connection moving to another connection inside a tag
if (dragNode.type === 'connection' && dropNode.type === 'connection') {
// Find if drop target is under a tag
const targetTag = connectionTags.find(t => t.connectionIds.includes(dropNode.key));
if (targetTag) {
moveConnectionToTag(dragNode.key, targetTag.id);
return;
}
// Drop target is NOT under a tag (ungrouped) -> move OUT of tag
const sourceTag = connectionTags.find(t => t.connectionIds.includes(dragNode.key));
if (sourceTag) {
moveConnectionToTag(dragNode.key, null);
return;
}
}
};
const onRightClick = ({ event, node }: any) => {
const items = getNodeMenuItems(node);
if (items && items.length > 0) {
@@ -2755,16 +3280,49 @@ const Sidebar: React.FC<{ onEditConnection?: (conn: SavedConnection) => void }>
return (
<div style={{ display: 'flex', flexDirection: 'column', height: '100%' }}>
<div style={{ padding: '4px 8px' }}>
<Search placeholder="搜索..." onChange={onSearch} size="small" />
<Space.Compact block size="small">
<Search
ref={searchInputRef}
placeholder="搜索..."
onChange={onSearch}
size="small"
style={{ width: '100%' }}
/>
<Popover
content={searchScopePopoverContent}
trigger="click"
placement="bottomRight"
open={isSearchScopePopoverOpen}
onOpenChange={setIsSearchScopePopoverOpen}
>
<Tooltip title={`搜索范围:${searchScopeSummary}`}>
<Button size="small" icon={<DownOutlined />} style={{ width: 86 }}>
{searchScopes.includes('smart') ? '(智)' : `(${searchScopes.length})`}
</Button>
</Tooltip>
</Popover>
</Space.Compact>
</div>
{/* Toolbar for batch operations - always visible */}
<div style={{ padding: '4px 8px', borderBottom: 'none', display: 'flex', gap: 4 }}>
{/* Toolbar */}
<div style={{ padding: '4px 8px', borderBottom: 'none', display: 'flex', flexWrap: 'wrap', gap: 4 }}>
<Button
size="small"
icon={<FolderOpenOutlined />}
onClick={() => {
setRenameViewTarget(null); // Create mode
createTagForm.resetFields();
setIsCreateTagModalOpen(true);
}}
style={{ flex: '1 1 auto' }}
>
</Button>
<Button
size="small"
icon={<CheckSquareOutlined />}
onClick={() => openBatchOperationModal()}
style={{ flex: 1 }}
style={{ flex: '1 1 auto' }}
>
</Button>
@@ -2772,7 +3330,7 @@ const Sidebar: React.FC<{ onEditConnection?: (conn: SavedConnection) => void }>
size="small"
icon={<CheckSquareOutlined />}
onClick={() => openBatchDatabaseModal()}
style={{ flex: 1 }}
style={{ flex: '1 1 auto' }}
>
</Button>
@@ -2781,6 +3339,11 @@ const Sidebar: React.FC<{ onEditConnection?: (conn: SavedConnection) => void }>
<div ref={treeContainerRef} style={{ flex: 1, overflow: 'hidden', minHeight: 0 }}>
<Tree
showIcon
draggable={{
icon: false,
nodeDraggable: (node: any) => node.type === 'connection' || node.type === 'tag'
}}
onDrop={handleDrop}
loadData={onLoadData}
treeData={displayTreeData}
onDoubleClick={onDoubleClick}
@@ -2809,6 +3372,60 @@ const Sidebar: React.FC<{ onEditConnection?: (conn: SavedConnection) => void }>
</Dropdown>
)}
<Modal
title={renameViewTarget?.type === 'tag' ? "编辑标签" : "新建组"}
open={isCreateTagModalOpen}
onOk={() => {
createTagForm.validateFields().then(values => {
if (renameViewTarget?.type === 'tag') {
// Rename
updateConnectionTag({
...renameViewTarget.dataRef,
name: values.name,
connectionIds: values.connectionIds || []
});
// update cross-connections
const allOtherTagsIds = connectionTags.filter(t => t.id !== renameViewTarget.dataRef.id).flatMap(t => t.connectionIds);
(values.connectionIds || []).forEach((cid: string) => {
if (allOtherTagsIds.includes(cid)) {
moveConnectionToTag(cid, renameViewTarget.dataRef.id);
}
});
} else {
// Create
const tagId = Date.now().toString();
addConnectionTag({
id: tagId,
name: values.name,
connectionIds: values.connectionIds || []
});
(values.connectionIds || []).forEach((cid: string) => {
moveConnectionToTag(cid, tagId);
});
}
setIsCreateTagModalOpen(false);
});
}}
onCancel={() => setIsCreateTagModalOpen(false)}
>
<Form form={createTagForm} layout="vertical">
<Form.Item name="name" label="标签名称" rules={[{ required: true, message: '请输入标签名称' }]}>
<Input />
</Form.Item>
<Form.Item name="connectionIds" label="选择连接">
<Checkbox.Group style={{ width: '100%' }}>
<Space direction="vertical" style={{ width: '100%', maxHeight: '400px', overflowY: 'auto' }}>
{connections.map(conn => (
<Checkbox key={conn.id} value={conn.id}>
{conn.name} {conn.config.host ? `(${conn.config.host})` : ''}
</Checkbox>
))}
</Space>
</Checkbox.Group>
</Form.Item>
</Form>
</Modal>
<Modal
title="新建数据库"
open={isCreateDbModalOpen}

View File

@@ -261,9 +261,18 @@ const TableDesigner: React.FC<{ tab: TabData }> = ({ tab }) => {
const darkMode = theme === 'dark';
const resizeGuideColor = darkMode ? '#f6c453' : '#1890ff';
const readOnly = !!tab.readOnly;
const panelRadius = 10;
const panelFrameColor = darkMode ? 'rgba(0, 0, 0, 0.18)' : 'rgba(0, 0, 0, 0.12)';
const panelToolbarBorder = darkMode ? 'rgba(255, 255, 255, 0.12)' : 'rgba(0, 0, 0, 0.10)';
const panelToolbarBg = darkMode ? 'rgba(20, 20, 20, 0.35)' : 'rgba(255, 255, 255, 0.72)';
const panelBodyBg = darkMode ? 'rgba(0, 0, 0, 0.24)' : 'rgba(255, 255, 255, 0.82)';
const focusRowBg = darkMode ? 'rgba(246, 196, 83, 0.22)' : 'rgba(24, 144, 255, 0.12)';
const [tableHeight, setTableHeight] = useState(500);
const containerRef = useRef<HTMLDivElement>(null);
const pendingFocusColumnKeyRef = useRef<string | null>(null);
const focusHighlightTimerRef = useRef<number | null>(null);
const [focusColumnKey, setFocusColumnKey] = useState('');
const openCommentEditor = useCallback((record: EditableColumn) => {
if (!record?._key) return;
@@ -346,6 +355,61 @@ const TableDesigner: React.FC<{ tab: TabData }> = ({ tab }) => {
setSelectedColumnRowKeys(prev => prev.filter(key => columns.some(c => c._key === key)));
}, [columns]);
useEffect(() => {
return () => {
if (focusHighlightTimerRef.current !== null) {
window.clearTimeout(focusHighlightTimerRef.current);
}
};
}, []);
const focusColumnRow = useCallback((targetKey: string): boolean => {
if (activeKey !== 'columns') return false;
const tableBody = containerRef.current?.querySelector('.ant-table-body') as HTMLElement | null;
if (!tableBody) return false;
const row = tableBody.querySelector(`tr[data-row-key="${targetKey}"]`) as HTMLTableRowElement | null;
if (!row) return false;
row.scrollIntoView({ behavior: 'smooth', block: 'nearest' });
setFocusColumnKey(targetKey);
if (focusHighlightTimerRef.current !== null) {
window.clearTimeout(focusHighlightTimerRef.current);
}
focusHighlightTimerRef.current = window.setTimeout(() => {
setFocusColumnKey(prev => (prev === targetKey ? '' : prev));
}, 1600);
if (!readOnly) {
const firstInput = row.querySelector('input') as HTMLInputElement | null;
if (firstInput) {
firstInput.focus();
firstInput.select();
}
}
return true;
}, [activeKey, readOnly]);
useEffect(() => {
const pendingKey = pendingFocusColumnKeyRef.current;
if (!pendingKey || activeKey !== 'columns') return;
let cancelled = false;
const tryFocus = () => {
if (cancelled) return;
if (focusColumnRow(pendingKey)) {
pendingFocusColumnKeyRef.current = null;
}
};
const timerA = window.setTimeout(tryFocus, 0);
const timerB = window.setTimeout(tryFocus, 96);
return () => {
cancelled = true;
window.clearTimeout(timerA);
window.clearTimeout(timerB);
};
}, [activeKey, columns, focusColumnRow]);
// Initial Columns Definition
useEffect(() => {
const initialCols = [
@@ -886,21 +950,46 @@ ${selectedTrigger.statement}`;
}));
};
const handleAddColumn = () => {
const newCol: EditableColumn = {
name: isNewTable ? 'new_column' : `new_col_${columns.length + 1}`,
type: 'varchar(255)',
nullable: 'YES',
key: '',
extra: '',
comment: '',
default: '',
_key: `new-${Date.now()}`,
isNew: true,
isAutoIncrement: false
};
setColumns([...columns, newCol]);
};
const createNewColumn = useCallback((indexHint: number): EditableColumn => ({
name: isNewTable ? 'new_column' : `new_col_${indexHint}`,
type: 'varchar(255)',
nullable: 'YES',
key: '',
extra: '',
comment: '',
default: '',
_key: `new-${Date.now()}-${Math.random().toString(36).slice(2, 8)}`,
isNew: true,
isAutoIncrement: false
}), [isNewTable]);
const handleAddColumn = useCallback((insertAfterKey?: string) => {
const newCol = createNewColumn(columns.length + 1);
setColumns(prev => {
const next = [...prev];
if (insertAfterKey) {
const insertIndex = next.findIndex(col => col._key === insertAfterKey);
if (insertIndex >= 0) {
next.splice(insertIndex + 1, 0, newCol);
return next;
}
}
next.push(newCol);
return next;
});
setSelectedColumnRowKeys([newCol._key]);
pendingFocusColumnKeyRef.current = newCol._key;
}, [columns.length, createNewColumn]);
const handleAddColumnAfterSelected = useCallback(() => {
const selectedSet = new Set(selectedColumnRowKeys);
const anchor = columns.find(col => selectedSet.has(col._key));
if (!anchor) {
message.warning('请先选择一个字段,再执行插入。');
return;
}
handleAddColumn(anchor._key);
}, [columns, handleAddColumn, selectedColumnRowKeys]);
const handleDeleteColumn = (key: string) => {
setColumns(prev => prev.filter(c => c._key !== key));
@@ -1920,22 +2009,35 @@ END;`;
}));
const columnsTabContent = (
<div ref={containerRef} className="table-designer-wrapper" style={{ height: '100%', overflow: 'hidden', position: 'relative' }}>
<div
ref={containerRef}
className="table-designer-wrapper"
style={{
height: '100%',
overflow: 'hidden',
position: 'relative',
background: panelBodyBg
}}
>
<style>{`
.table-designer-wrapper .ant-table-body {
max-height: ${tableHeight}px !important;
}
}
.table-designer-wrapper .table-designer-focus-row > .ant-table-cell {
background: ${focusRowBg} !important;
}
`}</style>
{readOnly ? (
<Table
dataSource={columns}
columns={resizableColumns}
rowKey="_key"
rowClassName={(record: EditableColumn) => record._key === focusColumnKey ? 'table-designer-focus-row' : ''}
size="small"
pagination={false}
loading={loading}
scroll={{ y: tableHeight }}
bordered
bordered={false}
components={{
header: {
cell: ResizableTitle,
@@ -1953,11 +2055,12 @@ END;`;
onChange: (nextSelectedRowKeys) => setSelectedColumnRowKeys(nextSelectedRowKeys as string[]),
}}
rowKey="_key"
rowClassName={(record: EditableColumn) => record._key === focusColumnKey ? 'table-designer-focus-row' : ''}
size="small"
pagination={false}
loading={loading}
scroll={{ y: tableHeight }}
bordered
bordered={false}
components={{
body: { row: SortableRow },
header: { cell: ResizableTitle }
@@ -1985,8 +2088,63 @@ END;`;
);
return (
<div style={{ display: 'flex', flexDirection: 'column', height: '100%' }}>
<div style={{ padding: '8px', borderBottom: '1px solid #eee', display: 'flex', gap: '8px', alignItems: 'center' }}>
<div className="table-designer-shell" style={{ display: 'flex', flexDirection: 'column', height: '100%', minHeight: 0, padding: '6px 0' }}>
<style>{`
.table-designer-shell .ant-table,
.table-designer-shell .ant-table-wrapper,
.table-designer-shell .ant-table-container {
background: transparent !important;
}
.table-designer-shell .ant-table-wrapper,
.table-designer-shell .ant-table-container {
border: none !important;
overflow: hidden !important;
}
.table-designer-shell .ant-table-thead > tr > th {
background: transparent !important;
border-bottom: 1px solid ${darkMode ? 'rgba(255,255,255,0.06)' : 'rgba(0,0,0,0.06)'} !important;
border-inline-end: 1px solid transparent !important;
}
.table-designer-shell .ant-table-tbody > tr > td,
.table-designer-shell .ant-table-tbody .ant-table-row > .ant-table-cell {
background: transparent !important;
border-bottom: 1px solid ${darkMode ? 'rgba(255,255,255,0.05)' : 'rgba(0,0,0,0.05)'} !important;
border-inline-end: 1px solid transparent !important;
}
.table-designer-shell .ant-table-thead > tr > th::before {
display: none !important;
}
.table-designer-shell .ant-table-tbody > tr:hover > td,
.table-designer-shell .ant-table-tbody .ant-table-row:hover > .ant-table-cell {
background: ${darkMode ? 'rgba(255,255,255,0.06)' : 'rgba(0,0,0,0.02)'} !important;
}
.table-designer-shell .ant-tabs-nav {
margin-bottom: 8px !important;
}
.table-designer-shell .ant-tabs-nav::before {
border-bottom-color: ${darkMode ? 'rgba(255,255,255,0.08)' : 'rgba(0,0,0,0.08)'} !important;
}
.table-designer-shell .ant-tabs-content-holder,
.table-designer-shell .ant-tabs-content,
.table-designer-shell .ant-tabs-tabpane {
height: 100%;
}
`}</style>
<div
style={{
padding: '10px 12px 8px 12px',
borderBottom: `1px solid ${panelToolbarBorder}`,
borderTopLeftRadius: panelRadius,
borderTopRightRadius: panelRadius,
borderLeft: `1px solid ${panelFrameColor}`,
borderRight: `1px solid ${panelFrameColor}`,
borderTop: `1px solid ${panelFrameColor}`,
background: panelToolbarBg,
display: 'flex',
gap: '8px',
alignItems: 'center'
}}
>
{isNewTable && (
<>
<Input
@@ -2014,14 +2172,25 @@ END;`;
/>
</>
)}
{!readOnly && <Button icon={<SaveOutlined />} type="primary" onClick={generateDDL}></Button>}
{!isNewTable && <Button icon={<ReloadOutlined />} onClick={fetchData}></Button>}
{!readOnly && <Button size="small" icon={<SaveOutlined />} type="primary" onClick={generateDDL}></Button>}
{!isNewTable && <Button size="small" icon={<ReloadOutlined />} onClick={fetchData}></Button>}
{!isNewTable && !readOnly && supportsTableCommentOps() && (
<Button icon={<EditOutlined />} onClick={openTableCommentModal}></Button>
<Button size="small" icon={<EditOutlined />} onClick={openTableCommentModal}></Button>
)}
{!readOnly && <Button icon={<PlusOutlined />} onClick={handleAddColumn}></Button>}
{!readOnly && <Button size="small" icon={<PlusOutlined />} onClick={() => handleAddColumn()}></Button>}
{!readOnly && (
<Button
size="small"
icon={<PlusOutlined />}
onClick={handleAddColumnAfterSelected}
disabled={selectedColumnRowKeys.length === 0}
>
</Button>
)}
{!readOnly && (
<Button
size="small"
icon={<CopyOutlined />}
onClick={openCopySelectedColumnsModal}
disabled={selectedColumns.length === 0}
@@ -2034,7 +2203,17 @@ END;`;
<Tabs
activeKey={activeKey}
onChange={setActiveKey}
style={{ flex: 1, padding: '0 10px' }}
style={{
flex: 1,
minHeight: 0,
padding: '8px 10px 10px 10px',
borderBottomLeftRadius: panelRadius,
borderBottomRightRadius: panelRadius,
borderLeft: `1px solid ${panelFrameColor}`,
borderRight: `1px solid ${panelFrameColor}`,
borderBottom: `1px solid ${panelFrameColor}`,
background: panelBodyBg
}}
items={[
{
key: 'columns',
@@ -2276,7 +2455,7 @@ END;`;
label: 'DDL',
icon: <FileTextOutlined />,
children: (
<div style={{ height: 'calc(100vh - 200px)', border: darkMode ? '1px solid #303030' : '1px solid #d9d9d9', borderRadius: 4 }}>
<div style={{ height: 'calc(100vh - 200px)', border: `1px solid ${panelFrameColor}`, borderRadius: panelRadius, background: panelBodyBg }}>
<Editor
height="100%"
language="sql"
@@ -2517,7 +2696,7 @@ END;`;
<span><strong>:</strong> {selectedTrigger.timing}</span>
<span><strong>:</strong> {selectedTrigger.event}</span>
</div>
<div style={{ border: darkMode ? '1px solid #303030' : '1px solid #d9d9d9', borderRadius: 4 }}>
<div style={{ border: `1px solid ${panelFrameColor}`, borderRadius: panelRadius, background: panelBodyBg }}>
<Editor
height="350px"
language="sql"
@@ -2553,7 +2732,7 @@ END;`;
<span></span>
)}
</div>
<div style={{ border: darkMode ? '1px solid #303030' : '1px solid #d9d9d9', borderRadius: 4 }}>
<div style={{ border: `1px solid ${panelFrameColor}`, borderRadius: panelRadius, background: panelBodyBg }}>
<Editor
height="350px"
language="sql"

View File

@@ -1,6 +1,14 @@
import { create } from 'zustand';
import { persist } from 'zustand/middleware';
import { ConnectionConfig, ProxyConfig, SavedConnection, TabData, SavedQuery } from './types';
import { ConnectionConfig, ProxyConfig, SavedConnection, TabData, SavedQuery, ConnectionTag } from './types';
import {
ShortcutAction,
ShortcutBinding,
ShortcutOptions,
DEFAULT_SHORTCUT_OPTIONS,
cloneShortcutOptions,
sanitizeShortcutOptions,
} from './utils/shortcuts';
const DEFAULT_APPEARANCE = { opacity: 1.0, blur: 0 };
const DEFAULT_UI_SCALE = 1.0;
@@ -17,7 +25,7 @@ const MAX_HOST_ENTRY_LENGTH = 512;
const MAX_HOST_ENTRIES = 64;
const DEFAULT_TIMEOUT_SECONDS = 30;
const MAX_TIMEOUT_SECONDS = 3600;
const PERSIST_VERSION = 4;
const PERSIST_VERSION = 5;
const DEFAULT_CONNECTION_TYPE = 'mysql';
const DEFAULT_GLOBAL_PROXY: GlobalProxyConfig = {
enabled: false,
@@ -48,6 +56,23 @@ const SUPPORTED_CONNECTION_TYPES = new Set([
'duckdb',
'custom',
]);
const SSL_SUPPORTED_CONNECTION_TYPES = new Set([
'mysql',
'mariadb',
'diros',
'sphinx',
'dameng',
'clickhouse',
'postgres',
'sqlserver',
'oracle',
'kingbase',
'highgo',
'vastbase',
'mongodb',
'redis',
'tdengine',
]);
const getDefaultPortByType = (type: string): number => {
switch (type) {
@@ -177,6 +202,16 @@ const sanitizeConnectionConfig = (value: unknown): ConnectionConfig => {
const defaultPort = getDefaultPortByType(type);
const savePassword = typeof raw.savePassword === 'boolean' ? raw.savePassword : true;
const mongoSrv = !!raw.mongoSrv;
const sslCapable = SSL_SUPPORTED_CONNECTION_TYPES.has(type);
const sslModeRaw = toTrimmedString(raw.sslMode, 'preferred').toLowerCase();
const sslMode: 'preferred' | 'required' | 'skip-verify' | 'disable' =
sslModeRaw === 'required'
? 'required'
: sslModeRaw === 'skip-verify'
? 'skip-verify'
: sslModeRaw === 'disable'
? 'disable'
: 'preferred';
const sshRaw = (raw.ssh && typeof raw.ssh === 'object') ? raw.ssh as Record<string, unknown> : {};
const ssh = {
@@ -206,6 +241,10 @@ const sanitizeConnectionConfig = (value: unknown): ConnectionConfig => {
password: savePassword ? toTrimmedString(raw.password) : '',
savePassword,
database: toTrimmedString(raw.database),
useSSL: sslCapable ? !!raw.useSSL : false,
sslMode: sslCapable ? sslMode : 'disable',
sslCertPath: sslCapable ? toTrimmedString(raw.sslCertPath) : '',
sslKeyPath: sslCapable ? toTrimmedString(raw.sslKeyPath) : '',
useSSH: !!raw.useSSH,
ssh,
useProxy: !!raw.useProxy,
@@ -293,6 +332,27 @@ const sanitizeConnections = (value: unknown): SavedConnection[] => {
return result;
};
const sanitizeConnectionTags = (value: unknown): ConnectionTag[] => {
if (!Array.isArray(value)) return [];
const result: ConnectionTag[] = [];
const idSet = new Set<string>();
value.forEach((entry, index) => {
if (!entry || typeof entry !== 'object') return;
const raw = entry as Record<string, unknown>;
const id = toTrimmedString(raw.id, `tag-${index + 1}`) || `tag-${index + 1}`;
if (idSet.has(id)) return;
idSet.add(id);
const name = toTrimmedString(raw.name, `标签-${index + 1}`) || `标签-${index + 1}`;
const connectionIds = sanitizeStringArray(raw.connectionIds, 256);
result.push({ id, name, connectionIds });
});
return result;
};
const isLegacyDefaultAppearance = (appearance: Partial<{ opacity: number; blur: number }> | undefined): boolean => {
if (!appearance) {
return true;
@@ -325,6 +385,7 @@ export interface GlobalProxyConfig extends ProxyConfig {
interface AppState {
connections: SavedConnection[];
connectionTags: ConnectionTag[];
tabs: TabData[];
activeTabId: string | null;
activeContext: { connectionId: string; dbName: string } | null;
@@ -337,6 +398,7 @@ interface AppState {
globalProxy: GlobalProxyConfig;
sqlFormatOptions: { keywordCase: 'upper' | 'lower' };
queryOptions: QueryOptions;
shortcutOptions: ShortcutOptions;
sqlLogs: SqlLog[];
tableAccessCount: Record<string, number>;
tableSortPreference: Record<string, 'name' | 'frequency'>;
@@ -345,6 +407,12 @@ interface AppState {
updateConnection: (conn: SavedConnection) => void;
removeConnection: (id: string) => void;
addConnectionTag: (tag: ConnectionTag) => void;
updateConnectionTag: (tag: ConnectionTag) => void;
removeConnectionTag: (id: string) => void;
moveConnectionToTag: (connectionId: string, targetTagId: string | null) => void;
reorderTags: (tagIds: string[]) => void;
addTab: (tab: TabData) => void;
closeTab: (id: string) => void;
closeOtherTabs: (id: string) => void;
@@ -368,6 +436,8 @@ interface AppState {
setGlobalProxy: (proxy: Partial<GlobalProxyConfig>) => void;
setSqlFormatOptions: (options: { keywordCase: 'upper' | 'lower' }) => void;
setQueryOptions: (options: Partial<QueryOptions>) => void;
updateShortcut: (action: ShortcutAction, binding: Partial<ShortcutBinding>) => void;
resetShortcutOptions: () => void;
addSqlLog: (log: SqlLog) => void;
clearSqlLogs: () => void;
@@ -496,6 +566,7 @@ export const useStore = create<AppState>()(
persist(
(set) => ({
connections: [],
connectionTags: [],
tabs: [],
activeTabId: null,
activeContext: null,
@@ -508,6 +579,7 @@ export const useStore = create<AppState>()(
globalProxy: { ...DEFAULT_GLOBAL_PROXY },
sqlFormatOptions: { keywordCase: 'upper' },
queryOptions: { maxRows: 5000, showColumnComment: true, showColumnType: true },
shortcutOptions: cloneShortcutOptions(DEFAULT_SHORTCUT_OPTIONS),
sqlLogs: [],
tableAccessCount: {},
tableSortPreference: {},
@@ -516,7 +588,46 @@ export const useStore = create<AppState>()(
updateConnection: (conn) => set((state) => ({
connections: state.connections.map(c => c.id === conn.id ? conn : c)
})),
removeConnection: (id) => set((state) => ({ connections: state.connections.filter(c => c.id !== id) })),
removeConnection: (id) => set((state) => ({
connections: state.connections.filter(c => c.id !== id),
connectionTags: state.connectionTags.map(tag => ({
...tag,
connectionIds: tag.connectionIds.filter(cid => cid !== id)
}))
})),
addConnectionTag: (tag) => set((state) => ({ connectionTags: [...state.connectionTags, tag] })),
updateConnectionTag: (tag) => set((state) => ({
connectionTags: state.connectionTags.map(t => t.id === tag.id ? tag : t)
})),
removeConnectionTag: (id) => set((state) => ({
connectionTags: state.connectionTags.filter(t => t.id !== id)
})),
moveConnectionToTag: (connectionId, targetTagId) => set((state) => {
const newTags = state.connectionTags.map(tag => {
//先从所有tag中移除该connection
const filteredIds = tag.connectionIds.filter(id => id !== connectionId);
if (tag.id === targetTagId) {
return { ...tag, connectionIds: [...filteredIds, connectionId] };
}
return { ...tag, connectionIds: filteredIds };
});
return { connectionTags: newTags };
}),
reorderTags: (tagIds) => set((state) => {
const tagMap = new Map(state.connectionTags.map(t => [t.id, t]));
const newTags: ConnectionTag[] = [];
tagIds.forEach(id => {
const tag = tagMap.get(id);
if (tag) {
newTags.push(tag);
tagMap.delete(id);
}
});
// 追加未指定的tag如果有的话
newTags.push(...Array.from(tagMap.values()));
return { connectionTags: newTags };
}),
addTab: (tab) => set((state) => {
const index = state.tabs.findIndex(t => t.id === tab.id);
@@ -640,6 +751,16 @@ export const useStore = create<AppState>()(
setGlobalProxy: (proxy) => set((state) => ({ globalProxy: sanitizeGlobalProxy({ ...state.globalProxy, ...proxy }) })),
setSqlFormatOptions: (options) => set({ sqlFormatOptions: options }),
setQueryOptions: (options) => set((state) => ({ queryOptions: { ...state.queryOptions, ...options } })),
updateShortcut: (action, binding) => set((state) => ({
shortcutOptions: {
...state.shortcutOptions,
[action]: {
...state.shortcutOptions[action],
...binding,
},
},
})),
resetShortcutOptions: () => set({ shortcutOptions: cloneShortcutOptions(DEFAULT_SHORTCUT_OPTIONS) }),
addSqlLog: (log) => set((state) => ({ sqlLogs: [log, ...state.sqlLogs].slice(0, 1000) })), // Keep last 1000 logs
clearSqlLogs: () => set({ sqlLogs: [] }),
@@ -672,6 +793,11 @@ export const useStore = create<AppState>()(
const state = unwrapPersistedAppState(persistedState) as Partial<AppState>;
const nextState: Partial<AppState> = { ...state };
nextState.connections = sanitizeConnections(state.connections);
if (version < 5) {
nextState.connectionTags = sanitizeConnectionTags(state.connectionTags);
} else {
nextState.connectionTags = sanitizeConnectionTags(state.connectionTags);
}
nextState.savedQueries = sanitizeSavedQueries(state.savedQueries);
nextState.theme = sanitizeTheme(state.theme);
nextState.appearance = sanitizeAppearance(state.appearance, version);
@@ -681,6 +807,7 @@ export const useStore = create<AppState>()(
nextState.globalProxy = sanitizeGlobalProxy(state.globalProxy);
nextState.sqlFormatOptions = sanitizeSqlFormatOptions(state.sqlFormatOptions);
nextState.queryOptions = sanitizeQueryOptions(state.queryOptions);
nextState.shortcutOptions = sanitizeShortcutOptions(state.shortcutOptions);
nextState.tableAccessCount = sanitizeTableAccessCount(state.tableAccessCount);
nextState.tableSortPreference = sanitizeTableSortPreference(state.tableSortPreference);
return nextState as AppState;
@@ -691,6 +818,7 @@ export const useStore = create<AppState>()(
...currentState,
...state,
connections: sanitizeConnections(state.connections),
connectionTags: sanitizeConnectionTags(state.connectionTags),
savedQueries: sanitizeSavedQueries(state.savedQueries),
theme: sanitizeTheme(state.theme),
appearance: sanitizeAppearance(state.appearance, PERSIST_VERSION),
@@ -700,12 +828,14 @@ export const useStore = create<AppState>()(
globalProxy: sanitizeGlobalProxy(state.globalProxy),
sqlFormatOptions: sanitizeSqlFormatOptions(state.sqlFormatOptions),
queryOptions: sanitizeQueryOptions(state.queryOptions),
shortcutOptions: sanitizeShortcutOptions(state.shortcutOptions),
tableAccessCount: sanitizeTableAccessCount(state.tableAccessCount),
tableSortPreference: sanitizeTableSortPreference(state.tableSortPreference),
};
},
partialize: (state) => ({
connections: state.connections,
connectionTags: state.connectionTags,
savedQueries: state.savedQueries,
theme: state.theme,
appearance: state.appearance,
@@ -715,6 +845,7 @@ export const useStore = create<AppState>()(
globalProxy: state.globalProxy,
sqlFormatOptions: state.sqlFormatOptions,
queryOptions: state.queryOptions,
shortcutOptions: state.shortcutOptions,
tableAccessCount: state.tableAccessCount,
tableSortPreference: state.tableSortPreference
}), // Don't persist logs

View File

@@ -22,6 +22,10 @@ export interface ConnectionConfig {
password?: string;
savePassword?: boolean;
database?: string;
useSSL?: boolean;
sslMode?: 'preferred' | 'required' | 'skip-verify' | 'disable';
sslCertPath?: string;
sslKeyPath?: string;
useSSH?: boolean;
ssh?: SSHConfig;
useProxy?: boolean;
@@ -61,6 +65,12 @@ export interface SavedConnection {
includeRedisDatabases?: number[]; // Redis databases to show (0-15)
}
export interface ConnectionTag {
id: string;
name: string;
connectionIds: string[];
}
export interface ColumnDefinition {
name: string;
type: string;

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,258 @@
import type { KeyboardEvent as ReactKeyboardEvent } from 'react';
export type ShortcutAction =
| 'runQuery'
| 'focusSidebarSearch'
| 'newQueryTab'
| 'toggleLogPanel'
| 'toggleTheme'
| 'openShortcutManager';
export interface ShortcutBinding {
combo: string;
enabled: boolean;
}
export type ShortcutOptions = Record<ShortcutAction, ShortcutBinding>;
export interface ShortcutActionMeta {
label: string;
description: string;
allowInEditable?: boolean;
}
const MODIFIER_ORDER = ['Ctrl', 'Meta', 'Alt', 'Shift'] as const;
const MODIFIER_SET = new Set(MODIFIER_ORDER);
const KEY_ALIASES: Record<string, string> = {
control: 'Ctrl',
ctrl: 'Ctrl',
command: 'Meta',
cmd: 'Meta',
meta: 'Meta',
option: 'Alt',
alt: 'Alt',
shift: 'Shift',
escape: 'Esc',
esc: 'Esc',
return: 'Enter',
enter: 'Enter',
tab: 'Tab',
space: 'Space',
' ': 'Space',
backspace: 'Backspace',
delete: 'Delete',
del: 'Delete',
arrowup: 'Up',
up: 'Up',
arrowdown: 'Down',
down: 'Down',
arrowleft: 'Left',
left: 'Left',
arrowright: 'Right',
right: 'Right',
pagedown: 'PageDown',
pageup: 'PageUp',
home: 'Home',
end: 'End',
insert: 'Insert',
',': ',',
'.': '.',
'/': '/',
';': ';',
"'": "'",
'[': '[',
']': ']',
'\\': '\\',
'-': '-',
'=': '=',
'`': '`',
};
export const SHORTCUT_ACTION_ORDER: ShortcutAction[] = [
'runQuery',
'focusSidebarSearch',
'newQueryTab',
'toggleLogPanel',
'toggleTheme',
'openShortcutManager',
];
export const SHORTCUT_ACTION_META: Record<ShortcutAction, ShortcutActionMeta> = {
runQuery: {
label: '执行 SQL',
description: '在当前查询页执行 SQL',
},
focusSidebarSearch: {
label: '聚焦侧边栏搜索',
description: '定位到左侧连接树搜索框',
allowInEditable: true,
},
newQueryTab: {
label: '新建查询页',
description: '创建一个新的 SQL 查询标签页',
},
toggleLogPanel: {
label: '切换日志面板',
description: '打开或关闭 SQL 执行日志面板',
},
toggleTheme: {
label: '切换主题',
description: '在亮色和暗色主题之间切换',
},
openShortcutManager: {
label: '打开快捷键管理',
description: '打开快捷键设置面板',
allowInEditable: true,
},
};
export const DEFAULT_SHORTCUT_OPTIONS: ShortcutOptions = {
runQuery: { combo: 'Ctrl+Shift+R', enabled: true },
focusSidebarSearch: { combo: 'Ctrl+F', enabled: true },
newQueryTab: { combo: 'Ctrl+Shift+N', enabled: true },
toggleLogPanel: { combo: 'Ctrl+Shift+L', enabled: true },
toggleTheme: { combo: 'Ctrl+Shift+D', enabled: true },
openShortcutManager: { combo: 'Ctrl+,', enabled: true },
};
const normalizeKeyToken = (value: string): string => {
const token = String(value || '').trim();
if (!token) return '';
const alias = KEY_ALIASES[token.toLowerCase()];
if (alias) return alias;
if (/^f([1-9]|1[0-2])$/i.test(token)) {
return token.toUpperCase();
}
if (token.length === 1) {
return token === '+' ? '+' : token.toUpperCase();
}
return token.length > 1 ? token[0].toUpperCase() + token.slice(1).toLowerCase() : token;
};
export const normalizeShortcutCombo = (combo: string): string => {
const raw = String(combo || '').trim();
if (!raw) return '';
const pieces = raw
.split('+')
.map(part => part.trim())
.filter(Boolean);
const modifiers: string[] = [];
let key = '';
pieces.forEach((part) => {
const normalized = normalizeKeyToken(part);
if (!normalized) return;
if (MODIFIER_SET.has(normalized as typeof MODIFIER_ORDER[number])) {
if (!modifiers.includes(normalized)) {
modifiers.push(normalized);
}
return;
}
key = normalized;
});
modifiers.sort((a, b) => MODIFIER_ORDER.indexOf(a as typeof MODIFIER_ORDER[number]) - MODIFIER_ORDER.indexOf(b as typeof MODIFIER_ORDER[number]));
if (!key) {
return modifiers.join('+');
}
return [...modifiers, key].join('+');
};
const normalizeKeyboardKey = (key: string): string => {
const token = String(key || '').trim();
if (!token) return '';
const alias = KEY_ALIASES[token.toLowerCase()];
if (alias) return alias;
if (token.length === 1) {
if (token === ' ') return 'Space';
return token.toUpperCase();
}
if (/^f([1-9]|1[0-2])$/i.test(token)) {
return token.toUpperCase();
}
return token.length > 1 ? token[0].toUpperCase() + token.slice(1) : token;
};
export const eventToShortcut = (event: KeyboardEvent | ReactKeyboardEvent): string => {
const key = normalizeKeyboardKey(event.key);
if (!key || MODIFIER_SET.has(key as typeof MODIFIER_ORDER[number])) {
return '';
}
const modifiers: string[] = [];
if (event.ctrlKey) modifiers.push('Ctrl');
if (event.metaKey) modifiers.push('Meta');
if (event.altKey) modifiers.push('Alt');
if (event.shiftKey) modifiers.push('Shift');
return normalizeShortcutCombo([...modifiers, key].join('+'));
};
export const isShortcutMatch = (event: KeyboardEvent | ReactKeyboardEvent, combo: string): boolean => {
const expected = normalizeShortcutCombo(combo);
if (!expected) return false;
const actual = eventToShortcut(event);
return actual === expected;
};
export const hasModifierKey = (combo: string): boolean => {
const normalized = normalizeShortcutCombo(combo);
if (!normalized) return false;
return normalized.split('+').some(part => MODIFIER_SET.has(part as typeof MODIFIER_ORDER[number]));
};
export const cloneShortcutOptions = (value: ShortcutOptions): ShortcutOptions => {
return SHORTCUT_ACTION_ORDER.reduce((acc, action) => {
acc[action] = {
combo: normalizeShortcutCombo(value[action]?.combo || DEFAULT_SHORTCUT_OPTIONS[action].combo),
enabled: value[action]?.enabled !== false,
};
return acc;
}, {} as ShortcutOptions);
};
export const sanitizeShortcutOptions = (value: unknown): ShortcutOptions => {
const raw = (value && typeof value === 'object') ? value as Record<string, unknown> : {};
const defaults = cloneShortcutOptions(DEFAULT_SHORTCUT_OPTIONS);
SHORTCUT_ACTION_ORDER.forEach((action) => {
const actionRaw = raw[action];
if (!actionRaw || typeof actionRaw !== 'object') {
return;
}
const binding = actionRaw as Record<string, unknown>;
const combo = normalizeShortcutCombo(String(binding.combo || defaults[action].combo));
defaults[action] = {
combo: combo || defaults[action].combo,
enabled: binding.enabled === false ? false : true,
};
});
return defaults;
};
export const isEditableElement = (target: EventTarget | null): boolean => {
if (!(target instanceof HTMLElement)) {
return false;
}
const tag = target.tagName.toLowerCase();
if (target.isContentEditable) {
return true;
}
if (tag === 'input' || tag === 'textarea' || tag === 'select') {
return true;
}
if (target.closest('.monaco-editor, .monaco-inputbox, .ant-select, .ant-picker, .ant-input')) {
return true;
}
return false;
};
export const getShortcutDisplay = (combo: string): string => {
const normalized = normalizeShortcutCombo(combo);
return normalized || '-';
};

View File

@@ -1,6 +1,7 @@
export type FilterCondition = {
id?: number;
enabled?: boolean;
logic?: 'AND' | 'OR';
column?: string;
op?: string;
value?: string;
@@ -142,8 +143,12 @@ export const parseListValues = (val: string) => {
.filter(Boolean);
};
const normalizeConditionLogic = (logic: unknown): 'AND' | 'OR' => {
return String(logic || '').trim().toUpperCase() === 'OR' ? 'OR' : 'AND';
};
export const buildWhereSQL = (dbType: string, conditions: FilterCondition[]) => {
const whereParts: string[] = [];
const whereParts: Array<{ expr: string; logic: 'AND' | 'OR' }> = [];
(conditions || []).forEach((cond) => {
if (cond?.enabled === false) return;
@@ -152,10 +157,17 @@ export const buildWhereSQL = (dbType: string, conditions: FilterCondition[]) =>
const column = (cond?.column || '').trim();
const value = (cond?.value ?? '').toString();
const value2 = (cond?.value2 ?? '').toString();
const logic = normalizeConditionLogic(cond?.logic);
const appendWherePart = (expr: string) => {
const normalizedExpr = String(expr || '').trim();
if (!normalizedExpr) return;
whereParts.push({ expr: normalizedExpr, logic });
};
if (op === 'CUSTOM') {
const expr = value.trim();
if (expr) whereParts.push(`(${expr})`);
if (expr) appendWherePart(`(${expr})`);
return;
}
@@ -165,80 +177,80 @@ export const buildWhereSQL = (dbType: string, conditions: FilterCondition[]) =>
switch (op) {
case 'IS_NULL':
whereParts.push(`${col} IS NULL`);
appendWherePart(`${col} IS NULL`);
return;
case 'IS_NOT_NULL':
whereParts.push(`${col} IS NOT NULL`);
appendWherePart(`${col} IS NOT NULL`);
return;
case 'IS_EMPTY':
// 兼容:空值通常理解为 NULL 或空字符串
whereParts.push(`(${col} IS NULL OR ${col} = '')`);
appendWherePart(`(${col} IS NULL OR ${col} = '')`);
return;
case 'IS_NOT_EMPTY':
whereParts.push(`(${col} IS NOT NULL AND ${col} <> '')`);
appendWherePart(`(${col} IS NOT NULL AND ${col} <> '')`);
return;
case 'BETWEEN': {
const v1 = value.trim();
const v2 = value2.trim();
if (!v1 || !v2) return;
whereParts.push(`${col} BETWEEN '${escapeLiteral(v1)}' AND '${escapeLiteral(v2)}'`);
appendWherePart(`${col} BETWEEN '${escapeLiteral(v1)}' AND '${escapeLiteral(v2)}'`);
return;
}
case 'NOT_BETWEEN': {
const v1 = value.trim();
const v2 = value2.trim();
if (!v1 || !v2) return;
whereParts.push(`${col} NOT BETWEEN '${escapeLiteral(v1)}' AND '${escapeLiteral(v2)}'`);
appendWherePart(`${col} NOT BETWEEN '${escapeLiteral(v1)}' AND '${escapeLiteral(v2)}'`);
return;
}
case 'IN': {
const items = parseListValues(value);
if (items.length === 0) return;
const list = items.map(v => `'${escapeLiteral(v)}'`).join(', ');
whereParts.push(`${col} IN (${list})`);
appendWherePart(`${col} IN (${list})`);
return;
}
case 'NOT_IN': {
const items = parseListValues(value);
if (items.length === 0) return;
const list = items.map(v => `'${escapeLiteral(v)}'`).join(', ');
whereParts.push(`${col} NOT IN (${list})`);
appendWherePart(`${col} NOT IN (${list})`);
return;
}
case 'CONTAINS': {
const v = value.trim();
if (!v) return;
whereParts.push(`${col} LIKE '%${escapeLiteral(v)}%'`);
appendWherePart(`${col} LIKE '%${escapeLiteral(v)}%'`);
return;
}
case 'NOT_CONTAINS': {
const v = value.trim();
if (!v) return;
whereParts.push(`${col} NOT LIKE '%${escapeLiteral(v)}%'`);
appendWherePart(`${col} NOT LIKE '%${escapeLiteral(v)}%'`);
return;
}
case 'STARTS_WITH': {
const v = value.trim();
if (!v) return;
whereParts.push(`${col} LIKE '${escapeLiteral(v)}%'`);
appendWherePart(`${col} LIKE '${escapeLiteral(v)}%'`);
return;
}
case 'NOT_STARTS_WITH': {
const v = value.trim();
if (!v) return;
whereParts.push(`${col} NOT LIKE '${escapeLiteral(v)}%'`);
appendWherePart(`${col} NOT LIKE '${escapeLiteral(v)}%'`);
return;
}
case 'ENDS_WITH': {
const v = value.trim();
if (!v) return;
whereParts.push(`${col} LIKE '%${escapeLiteral(v)}'`);
appendWherePart(`${col} LIKE '%${escapeLiteral(v)}'`);
return;
}
case 'NOT_ENDS_WITH': {
const v = value.trim();
if (!v) return;
whereParts.push(`${col} NOT LIKE '%${escapeLiteral(v)}'`);
appendWherePart(`${col} NOT LIKE '%${escapeLiteral(v)}'`);
return;
}
case '=':
@@ -249,7 +261,7 @@ export const buildWhereSQL = (dbType: string, conditions: FilterCondition[]) =>
case '>=': {
const v = value.trim();
if (!v) return;
whereParts.push(`${col} ${op} '${escapeLiteral(v)}'`);
appendWherePart(`${col} ${op} '${escapeLiteral(v)}'`);
return;
}
default: {
@@ -257,16 +269,23 @@ export const buildWhereSQL = (dbType: string, conditions: FilterCondition[]) =>
if (op.toUpperCase() === 'LIKE') {
const v = value.trim();
if (!v) return;
whereParts.push(`${col} LIKE '%${escapeLiteral(v)}%'`);
appendWherePart(`${col} LIKE '%${escapeLiteral(v)}%'`);
return;
}
const v = value.trim();
if (!v) return;
whereParts.push(`${col} ${op} '${escapeLiteral(v)}'`);
appendWherePart(`${col} ${op} '${escapeLiteral(v)}'`);
}
}
});
return whereParts.length > 0 ? `WHERE ${whereParts.join(' AND ')}` : '';
if (whereParts.length === 0) return '';
let whereExpr = `(${whereParts[0].expr})`;
for (let i = 1; i < whereParts.length; i++) {
const part = whereParts[i];
whereExpr = `(${whereExpr} ${part.logic} (${part.expr}))`;
}
return `WHERE ${whereExpr}`;
};

View File

@@ -1,15 +1,20 @@
// Cynhyrchwyd y ffeil hon yn awtomatig. PEIDIWCH Â MODIWL
// This file is automatically generated. DO NOT EDIT
import {connection} from '../models';
import {time} from '../models';
import {sync} from '../models';
import {redis} from '../models';
export function ApplyChanges(arg1:connection.ConnectionConfig,arg2:string,arg3:string,arg4:connection.ChangeSet):Promise<connection.QueryResult>;
export function CancelQuery(arg1:string):Promise<connection.QueryResult>;
export function CheckDriverNetworkStatus():Promise<connection.QueryResult>;
export function CheckForUpdates():Promise<connection.QueryResult>;
export function CleanupStaleQueries(arg1:time.Duration):Promise<void>;
export function ConfigureDriverRuntimeDirectory(arg1:string):Promise<connection.QueryResult>;
export function ConfigureGlobalProxy(arg1:boolean,arg2:connection.ProxyConfig):Promise<connection.QueryResult>;
@@ -36,6 +41,8 @@ export function DBQuery(arg1:connection.ConnectionConfig,arg2:string,arg3:string
export function DBQueryIsolated(arg1:connection.ConnectionConfig,arg2:string,arg3:string):Promise<connection.QueryResult>;
export function DBQueryWithCancel(arg1:connection.ConnectionConfig,arg2:string,arg3:string,arg4:string):Promise<connection.QueryResult>;
export function DBShowCreateTable(arg1:connection.ConnectionConfig,arg2:string,arg3:string):Promise<connection.QueryResult>;
export function DataSync(arg1:sync.SyncConfig):Promise<sync.SyncResult>;
@@ -68,6 +75,8 @@ export function ExportTablesDataSQL(arg1:connection.ConnectionConfig,arg2:string
export function ExportTablesSQL(arg1:connection.ConnectionConfig,arg2:string,arg3:Array<string>,arg4:boolean):Promise<connection.QueryResult>;
export function GenerateQueryID():Promise<string>;
export function GetAppInfo():Promise<connection.QueryResult>;
export function GetDriverStatusList(arg1:string,arg2:string):Promise<connection.QueryResult>;

View File

@@ -6,6 +6,10 @@ export function ApplyChanges(arg1, arg2, arg3, arg4) {
return window['go']['app']['App']['ApplyChanges'](arg1, arg2, arg3, arg4);
}
export function CancelQuery(arg1) {
return window['go']['app']['App']['CancelQuery'](arg1);
}
export function CheckDriverNetworkStatus() {
return window['go']['app']['App']['CheckDriverNetworkStatus']();
}
@@ -14,6 +18,10 @@ export function CheckForUpdates() {
return window['go']['app']['App']['CheckForUpdates']();
}
export function CleanupStaleQueries(arg1) {
return window['go']['app']['App']['CleanupStaleQueries'](arg1);
}
export function ConfigureDriverRuntimeDirectory(arg1) {
return window['go']['app']['App']['ConfigureDriverRuntimeDirectory'](arg1);
}
@@ -66,6 +74,10 @@ export function DBQueryIsolated(arg1, arg2, arg3) {
return window['go']['app']['App']['DBQueryIsolated'](arg1, arg2, arg3);
}
export function DBQueryWithCancel(arg1, arg2, arg3, arg4) {
return window['go']['app']['App']['DBQueryWithCancel'](arg1, arg2, arg3, arg4);
}
export function DBShowCreateTable(arg1, arg2, arg3) {
return window['go']['app']['App']['DBShowCreateTable'](arg1, arg2, arg3);
}
@@ -130,6 +142,10 @@ export function ExportTablesSQL(arg1, arg2, arg3, arg4) {
return window['go']['app']['App']['ExportTablesSQL'](arg1, arg2, arg3, arg4);
}
export function GenerateQueryID() {
return window['go']['app']['App']['GenerateQueryID']();
}
export function GetAppInfo() {
return window['go']['app']['App']['GetAppInfo']();
}

View File

@@ -96,6 +96,10 @@ export namespace connection {
password: string;
savePassword?: boolean;
database: string;
useSSL?: boolean;
sslMode?: string;
sslCertPath?: string;
sslKeyPath?: string;
useSSH: boolean;
ssh: SSHConfig;
useProxy?: boolean;
@@ -130,6 +134,10 @@ export namespace connection {
this.password = source["password"];
this.savePassword = source["savePassword"];
this.database = source["database"];
this.useSSL = source["useSSL"];
this.sslMode = source["sslMode"];
this.sslCertPath = source["sslCertPath"];
this.sslKeyPath = source["sslKeyPath"];
this.useSSH = source["useSSH"];
this.ssh = this.convertValues(source["ssh"], SSHConfig);
this.useProxy = source["useProxy"];
@@ -176,6 +184,7 @@ export namespace connection {
message: string;
data: any;
fields?: string[];
queryId?: string;
static createFrom(source: any = {}) {
return new QueryResult(source);
@@ -187,6 +196,7 @@ export namespace connection {
this.message = source["message"];
this.data = source["data"];
this.fields = source["fields"];
this.queryId = source["queryId"];
}
}

4
go.mod
View File

@@ -8,6 +8,7 @@ require (
github.com/ClickHouse/clickhouse-go/v2 v2.43.0
github.com/duckdb/duckdb-go/v2 v2.5.5
github.com/go-sql-driver/mysql v1.9.3
github.com/google/uuid v1.6.0
github.com/highgo/pq-sm3 v0.0.0
github.com/lib/pq v1.11.1
github.com/microsoft/go-mssqldb v1.9.6
@@ -16,6 +17,7 @@ require (
github.com/taosdata/driver-go/v3 v3.7.8
github.com/wailsapp/wails/v2 v2.11.0
github.com/xuri/excelize/v2 v2.10.0
go.mongodb.org/mongo-driver v1.17.9
go.mongodb.org/mongo-driver/v2 v2.5.0
golang.org/x/crypto v0.47.0
golang.org/x/mod v0.32.0
@@ -49,7 +51,6 @@ require (
github.com/golang-sql/sqlexp v0.1.0 // indirect
github.com/golang/snappy v1.0.0 // indirect
github.com/google/flatbuffers v25.12.19+incompatible // indirect
github.com/google/uuid v1.6.0 // indirect
github.com/gorilla/websocket v1.5.3 // indirect
github.com/hashicorp/go-version v1.8.0 // indirect
github.com/jchv/go-winloader v0.0.0-20210711035445-715c2860da7e // indirect
@@ -66,6 +67,7 @@ require (
github.com/mattn/go-isatty v0.0.20 // indirect
github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd // indirect
github.com/modern-go/reflect2 v1.0.2 // indirect
github.com/montanaflynn/stats v0.7.1 // indirect
github.com/ncruces/go-strftime v1.0.0 // indirect
github.com/paulmach/orb v0.12.0 // indirect
github.com/pierrec/lz4/v4 v4.1.25 // indirect

4
go.sum
View File

@@ -156,6 +156,8 @@ github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd/go.mod h1:6dJ
github.com/modern-go/reflect2 v1.0.2 h1:xBagoLtFs94CBntxluKeaWgTMpvLxC4ur3nMaC9Gz0M=
github.com/modern-go/reflect2 v1.0.2/go.mod h1:yWuevngMOJpCy52FWWMvUC8ws7m/LJsjYzDa0/r8luk=
github.com/montanaflynn/stats v0.0.0-20171201202039-1bf9dbcd8cbe/go.mod h1:wL8QJuTMNUDYhXwkmfOly8iTdp5TEcJFWZD2D7SIkUc=
github.com/montanaflynn/stats v0.7.1 h1:etflOAAHORrCC44V+aR6Ftzort912ZU+YLiSTuV8eaE=
github.com/montanaflynn/stats v0.7.1/go.mod h1:etXPPgVO6n31NxCd9KQUMvCM+ve0ruNzt6R8Bnaayow=
github.com/ncruces/go-strftime v1.0.0 h1:HMFp8mLCTPp341M/ZnA4qaf7ZlsbTc+miZjCLOFAw7w=
github.com/ncruces/go-strftime v1.0.0/go.mod h1:Fwc5htZGVVkseilnfgOVb9mKy6w1naJmn9CehxcKcls=
github.com/niemeyer/pretty v0.0.0-20200227124842-a10e7caefd8e h1:fD57ERR4JtEqsWbfPhv4DMiApHyliiK5xCTNVSPiaAs=
@@ -248,6 +250,8 @@ github.com/zeebo/assert v1.3.0/go.mod h1:Pq9JiuJQpG8JLJdtkwrJESF0Foym2/D9XMU5ciN
github.com/zeebo/xxh3 v1.1.0 h1:s7DLGDK45Dyfg7++yxI0khrfwq9661w9EN78eP/UZVs=
github.com/zeebo/xxh3 v1.1.0/go.mod h1:IisAie1LELR4xhVinxWS5+zf1lA4p0MW4T+w+W07F5s=
go.mongodb.org/mongo-driver v1.11.4/go.mod h1:PTSz5yu21bkT/wXpkS7WR5f0ddqw5quethTUn9WM+2g=
go.mongodb.org/mongo-driver v1.17.9 h1:IexDdCuuNJ3BHrELgBlyaH9p60JXAvdzWR128q+U5tU=
go.mongodb.org/mongo-driver v1.17.9/go.mod h1:LlOhpH5NUEfhxcAwG0UEkMqwYcc4JU18gtCdGudk/tQ=
go.mongodb.org/mongo-driver/v2 v2.5.0 h1:yXUhImUjjAInNcpTcAlPHiT7bIXhshCTL3jVBkF3xaE=
go.mongodb.org/mongo-driver/v2 v2.5.0/go.mod h1:yOI9kBsufol30iFsl1slpdq1I0eHPzybRWdyYUs8K/0=
go.opentelemetry.io/otel v1.39.0 h1:8yPrr/S0ND9QEfTfdP9V+SiwT4E0G7Y5MO7p85nis48=

View File

@@ -16,6 +16,7 @@ import (
"GoNavi-Wails/internal/db"
"GoNavi-Wails/internal/logger"
proxytunnel "GoNavi-Wails/internal/proxy"
"github.com/google/uuid"
)
const dbCachePingInterval = 30 * time.Second
@@ -25,19 +26,27 @@ type cachedDatabase struct {
lastPing time.Time
}
type queryContext struct {
cancel context.CancelFunc
started time.Time
}
// App struct
type App struct {
ctx context.Context
dbCache map[string]cachedDatabase // Cache for DB connections
mu sync.RWMutex // Mutex for cache access
updateMu sync.Mutex
updateState updateState
ctx context.Context
dbCache map[string]cachedDatabase // Cache for DB connections
mu sync.RWMutex // Mutex for cache access
updateMu sync.Mutex
updateState updateState
queryMu sync.RWMutex
runningQueries map[string]queryContext // queryID -> cancelFunc and start time
}
// NewApp creates a new App application struct
func NewApp() *App {
return &App{
dbCache: make(map[string]cachedDatabase),
dbCache: make(map[string]cachedDatabase),
runningQueries: make(map[string]queryContext),
}
}
@@ -139,6 +148,67 @@ func getCacheKey(config connection.ConnectionConfig) string {
return hex.EncodeToString(sum[:])
}
func shortCacheKey(cacheKey string) string {
shortKey := cacheKey
if len(shortKey) > 12 {
shortKey = shortKey[:12]
}
return shortKey
}
func shouldRefreshCachedConnection(err error) bool {
if err == nil {
return false
}
normalized := strings.ToLower(normalizeErrorMessage(err))
if normalized == "" {
return false
}
patterns := []string{
"invalid connection",
"bad connection",
"database is closed",
"connection is already closed",
"use of closed network connection",
"broken pipe",
"connection reset by peer",
"server has gone away",
"eof",
}
for _, pattern := range patterns {
if strings.Contains(normalized, pattern) {
return true
}
}
return false
}
func (a *App) invalidateCachedDatabase(config connection.ConnectionConfig, reason error) bool {
effectiveConfig := applyGlobalProxyToConnection(config)
key := getCacheKey(effectiveConfig)
shortKey := shortCacheKey(key)
a.mu.Lock()
defer a.mu.Unlock()
entry, exists := a.dbCache[key]
if !exists || entry.inst == nil {
return false
}
if closeErr := entry.inst.Close(); closeErr != nil {
logger.Error(closeErr, "关闭失效缓存连接失败缓存Key=%s", shortKey)
}
delete(a.dbCache, key)
if reason != nil {
logger.Errorf("检测到连接失效,已清理缓存连接:%s 缓存Key=%s 原因=%s", formatConnSummary(effectiveConfig), shortKey, normalizeErrorMessage(reason))
} else {
logger.Infof("已清理缓存连接:%s 缓存Key=%s", formatConnSummary(effectiveConfig), shortKey)
}
return true
}
func wrapConnectError(config connection.ConnectionConfig, err error) error {
if err == nil {
return nil
@@ -408,3 +478,43 @@ func (a *App) getDatabaseWithPing(config connection.ConnectionConfig, forcePing
logger.Infof("数据库连接成功并写入缓存:%s 缓存Key=%s", formatConnSummary(effectiveConfig), shortKey)
return dbInst, nil
}
// generateQueryID generates a unique ID for a query using UUID v4
func generateQueryID() string {
return "query-" + uuid.New().String()
}
// CancelQuery cancels a running query by its ID
func (a *App) CancelQuery(queryID string) connection.QueryResult {
a.queryMu.Lock()
defer a.queryMu.Unlock()
if ctx, exists := a.runningQueries[queryID]; exists {
ctx.cancel()
delete(a.runningQueries, queryID)
logger.Infof("查询已取消queryID=%s", queryID)
return connection.QueryResult{Success: true, Message: "查询已取消"}
}
logger.Warnf("取消查询失败queryID=%s 不存在或已完成", queryID)
return connection.QueryResult{Success: false, Message: "查询不存在或已完成"}
}
// CleanupStaleQueries removes queries older than maxAge
func (a *App) CleanupStaleQueries(maxAge time.Duration) {
a.queryMu.Lock()
defer a.queryMu.Unlock()
now := time.Now()
for id, ctx := range a.runningQueries {
if now.Sub(ctx.started) > maxAge {
// Query likely finished or stuck, remove from tracking
delete(a.runningQueries, id)
// Query expired, silently remove
}
}
}
// GenerateQueryID generates a unique query ID for cancellation tracking
func (a *App) GenerateQueryID() string {
return generateQueryID()
}

View File

@@ -36,6 +36,17 @@ func normalizeSchemaAndTable(config connection.ConnectionConfig, dbName string,
return rawDB, rawTable
}
dbType := strings.ToLower(strings.TrimSpace(config.Type))
if dbType == "sqlserver" {
// SQL Server 的 DB 接口约定第一个参数是数据库名schema 由 tableName(如 dbo.users) 自行解析。
// 不能把 schema(dbo) 传到第一个参数,否则会拼出 dbo.sys.columns 等无效对象名。
targetDB := rawDB
if targetDB == "" {
targetDB = strings.TrimSpace(config.Database)
}
return targetDB, rawTable
}
if parts := strings.SplitN(rawTable, ".", 2); len(parts) == 2 {
schema := strings.TrimSpace(parts[0])
table := strings.TrimSpace(parts[1])
@@ -44,13 +55,10 @@ func normalizeSchemaAndTable(config connection.ConnectionConfig, dbName string,
}
}
switch strings.ToLower(strings.TrimSpace(config.Type)) {
switch dbType {
case "postgres", "kingbase", "highgo", "vastbase":
// PG/金仓/瀚高/海量dbName 在 UI 里是"数据库"schema 需从 tableName 或使用默认 public。
return "public", rawTable
case "sqlserver":
// SQL ServerdbName 表示数据库schema 默认 dbo
return "dbo", rawTable
default:
// MySQLdbName 表示数据库Oracle/达梦dbName 表示 schema/owner。
return rawDB, rawTable

View File

@@ -0,0 +1,51 @@
package app
import (
"testing"
"GoNavi-Wails/internal/connection"
)
func TestNormalizeSchemaAndTable_SQLServerKeepsDatabaseAndQualifiedTable(t *testing.T) {
t.Parallel()
schemaOrDb, table := normalizeSchemaAndTable(connection.ConnectionConfig{
Type: "sqlserver",
Database: "master",
}, "biz_db", "dbo.users")
if schemaOrDb != "biz_db" {
t.Fatalf("expected sqlserver first return value as database name, got %q", schemaOrDb)
}
if table != "dbo.users" {
t.Fatalf("expected sqlserver table name keep qualified form, got %q", table)
}
}
func TestNormalizeSchemaAndTable_SQLServerFallbackToConfigDatabase(t *testing.T) {
t.Parallel()
schemaOrDb, table := normalizeSchemaAndTable(connection.ConnectionConfig{
Type: "sqlserver",
Database: "biz_db",
}, "", "dbo.users")
if schemaOrDb != "biz_db" {
t.Fatalf("expected sqlserver fallback database from config, got %q", schemaOrDb)
}
if table != "dbo.users" {
t.Fatalf("expected sqlserver table name keep qualified form, got %q", table)
}
}
func TestNormalizeSchemaAndTable_PostgresStillSplitsQualifiedName(t *testing.T) {
t.Parallel()
schema, table := normalizeSchemaAndTable(connection.ConnectionConfig{
Type: "postgres",
}, "demo_db", "public.orders")
if schema != "public" || table != "orders" {
t.Fatalf("expected postgres qualified split to public.orders, got %q.%q", schema, table)
}
}

View File

@@ -376,12 +376,21 @@ func (a *App) MySQLShowCreateTable(config connection.ConnectionConfig, dbName st
}
func (a *App) DBQuery(config connection.ConnectionConfig, dbName string, query string) connection.QueryResult {
return a.DBQueryWithCancel(config, dbName, query, "")
}
func (a *App) DBQueryWithCancel(config connection.ConnectionConfig, dbName string, query string, queryID string) connection.QueryResult {
runConfig := normalizeRunConfig(config, dbName)
// Generate query ID if not provided
if queryID == "" {
queryID = generateQueryID()
}
dbInst, err := a.getDatabase(runConfig)
if err != nil {
logger.Error(err, "DBQuery 获取连接失败:%s", formatConnSummary(runConfig))
return connection.QueryResult{Success: false, Message: err.Error()}
return connection.QueryResult{Success: false, Message: err.Error(), QueryID: queryID}
}
query = sanitizeSQLForPgLike(runConfig.Type, query)
@@ -392,41 +401,80 @@ func (a *App) DBQuery(config connection.ConnectionConfig, dbName string, query s
ctx, cancel := utils.ContextWithTimeout(time.Duration(timeoutSeconds) * time.Second)
defer cancel()
// Store cancel function for potential manual cancellation
a.queryMu.Lock()
a.runningQueries[queryID] = queryContext{
cancel: cancel,
started: time.Now(),
}
a.queryMu.Unlock()
// Ensure query is removed from tracking when done
defer func() {
a.queryMu.Lock()
delete(a.runningQueries, queryID)
a.queryMu.Unlock()
}()
lowerQuery := strings.TrimSpace(strings.ToLower(query))
isReadQuery := strings.HasPrefix(lowerQuery, "select") || strings.HasPrefix(lowerQuery, "show") || strings.HasPrefix(lowerQuery, "describe") || strings.HasPrefix(lowerQuery, "explain")
// MongoDB JSON 命令中的 find/count/aggregate 也属于读查询
if !isReadQuery && strings.ToLower(strings.TrimSpace(runConfig.Type)) == "mongodb" && strings.HasPrefix(strings.TrimSpace(query), "{") {
isReadQuery = true
}
if isReadQuery {
var data []map[string]interface{}
var columns []string
if q, ok := dbInst.(interface {
runReadQuery := func(inst db.Database) ([]map[string]interface{}, []string, error) {
if q, ok := inst.(interface {
QueryContext(context.Context, string) ([]map[string]interface{}, []string, error)
}); ok {
data, columns, err = q.QueryContext(ctx, query)
} else {
data, columns, err = dbInst.Query(query)
return q.QueryContext(ctx, query)
}
return inst.Query(query)
}
runExecQuery := func(inst db.Database) (int64, error) {
if e, ok := inst.(interface {
ExecContext(context.Context, string) (int64, error)
}); ok {
return e.ExecContext(ctx, query)
}
return inst.Exec(query)
}
if isReadQuery {
data, columns, err := runReadQuery(dbInst)
if err != nil && shouldRefreshCachedConnection(err) {
if a.invalidateCachedDatabase(runConfig, err) {
retryInst, retryErr := a.getDatabaseForcePing(runConfig)
if retryErr != nil {
logger.Error(retryErr, "DBQuery 重建连接失败:%s SQL片段=%q", formatConnSummary(runConfig), sqlSnippet(query))
return connection.QueryResult{Success: false, Message: retryErr.Error()}
}
data, columns, err = runReadQuery(retryInst)
}
}
if err != nil {
logger.Error(err, "DBQuery 查询失败:%s SQL片段=%q", formatConnSummary(runConfig), sqlSnippet(query))
return connection.QueryResult{Success: false, Message: err.Error()}
return connection.QueryResult{Success: false, Message: err.Error(), QueryID: queryID}
}
return connection.QueryResult{Success: true, Data: data, Fields: columns}
return connection.QueryResult{Success: true, Data: data, Fields: columns, QueryID: queryID}
} else {
var affected int64
if e, ok := dbInst.(interface {
ExecContext(context.Context, string) (int64, error)
}); ok {
affected, err = e.ExecContext(ctx, query)
} else {
affected, err = dbInst.Exec(query)
affected, err := runExecQuery(dbInst)
if err != nil && shouldRefreshCachedConnection(err) {
if a.invalidateCachedDatabase(runConfig, err) {
retryInst, retryErr := a.getDatabaseForcePing(runConfig)
if retryErr != nil {
logger.Error(retryErr, "DBQuery 重建连接失败:%s SQL片段=%q", formatConnSummary(runConfig), sqlSnippet(query))
return connection.QueryResult{Success: false, Message: retryErr.Error()}
}
affected, err = runExecQuery(retryInst)
}
}
if err != nil {
logger.Error(err, "DBQuery 执行失败:%s SQL片段=%q", formatConnSummary(runConfig), sqlSnippet(query))
return connection.QueryResult{Success: false, Message: err.Error()}
return connection.QueryResult{Success: false, Message: err.Error(), QueryID: queryID}
}
return connection.QueryResult{Success: true, Data: map[string]int64{"affectedRows": affected}}
return connection.QueryResult{Success: true, Data: map[string]int64{"affectedRows": affected}, QueryID: queryID}
}
}
@@ -500,15 +548,26 @@ func sqlSnippet(query string) string {
}
func (a *App) DBGetDatabases(config connection.ConnectionConfig) connection.QueryResult {
dbInst, err := a.getDatabase(config)
runConfig := normalizeRunConfig(config, "")
dbInst, err := a.getDatabase(runConfig)
if err != nil {
logger.Error(err, "DBGetDatabases 获取连接失败:%s", formatConnSummary(config))
logger.Error(err, "DBGetDatabases 获取连接失败:%s", formatConnSummary(runConfig))
return connection.QueryResult{Success: false, Message: err.Error()}
}
dbs, err := dbInst.GetDatabases()
if err != nil && shouldRefreshCachedConnection(err) {
if a.invalidateCachedDatabase(runConfig, err) {
retryInst, retryErr := a.getDatabaseForcePing(runConfig)
if retryErr != nil {
logger.Error(retryErr, "DBGetDatabases 重建连接失败:%s", formatConnSummary(runConfig))
return connection.QueryResult{Success: false, Message: retryErr.Error()}
}
dbs, err = retryInst.GetDatabases()
}
}
if err != nil {
logger.Error(err, "DBGetDatabases 获取数据库列表失败:%s", formatConnSummary(config))
logger.Error(err, "DBGetDatabases 获取数据库列表失败:%s", formatConnSummary(runConfig))
return connection.QueryResult{Success: false, Message: err.Error()}
}
@@ -530,6 +589,16 @@ func (a *App) DBGetTables(config connection.ConnectionConfig, dbName string) con
}
tables, err := dbInst.GetTables(dbName)
if err != nil && shouldRefreshCachedConnection(err) {
if a.invalidateCachedDatabase(runConfig, err) {
retryInst, retryErr := a.getDatabaseForcePing(runConfig)
if retryErr != nil {
logger.Error(retryErr, "DBGetTables 重建连接失败:%s", formatConnSummary(runConfig))
return connection.QueryResult{Success: false, Message: retryErr.Error()}
}
tables, err = retryInst.GetTables(dbName)
}
}
if err != nil {
logger.Error(err, "DBGetTables 获取表列表失败:%s", formatConnSummary(runConfig))
return connection.QueryResult{Success: false, Message: err.Error()}

View File

@@ -0,0 +1,149 @@
package app
import (
"context"
"strings"
"testing"
"time"
"GoNavi-Wails/internal/connection"
)
func TestGenerateQueryID(t *testing.T) {
app := NewApp()
id := app.GenerateQueryID()
if id == "" {
t.Fatal("GenerateQueryID returned empty string")
}
// Should start with "query-"
if !strings.HasPrefix(id, "query-") {
t.Fatalf("Expected query ID to start with 'query-', got: %s", id)
}
// Should be reasonably unique (not equal to another generated ID)
id2 := app.GenerateQueryID()
if id == id2 {
t.Fatal("Two consecutive GenerateQueryID calls returned identical IDs")
}
}
func TestCancelQuery_NonExistent(t *testing.T) {
app := NewApp()
res := app.CancelQuery("non-existent-query-id")
if res.Success {
t.Fatal("CancelQuery should fail for non-existent query ID")
}
if !strings.Contains(res.Message, "不存在") && !strings.Contains(res.Message, "not exist") {
t.Fatalf("Expected error message about query not existing, got: %s", res.Message)
}
}
func TestCancelQuery_ValidQuery(t *testing.T) {
app := NewApp()
// First, generate a query ID and simulate a running query
queryID := app.GenerateQueryID()
// Store a cancel function in runningQueries map
_, cancel := context.WithCancel(context.Background())
app.queryMu.Lock()
app.runningQueries[queryID] = queryContext{
cancel: cancel,
started: time.Now(),
}
app.queryMu.Unlock()
// Ensure cleanup after test
defer func() {
app.queryMu.Lock()
delete(app.runningQueries, queryID)
app.queryMu.Unlock()
}()
// Cancel the query
res := app.CancelQuery(queryID)
if !res.Success {
t.Fatalf("CancelQuery should succeed for valid query ID, got: %s", res.Message)
}
// Verify query removed from map
app.queryMu.Lock()
_, exists := app.runningQueries[queryID]
app.queryMu.Unlock()
if exists {
t.Fatal("Query should be removed from runningQueries after cancellation")
}
}
func TestCleanupStaleQueries(t *testing.T) {
app := NewApp()
// Add a stale query (started 2 hours ago)
queryID := app.GenerateQueryID()
_, cancel := context.WithCancel(context.Background())
app.queryMu.Lock()
app.runningQueries[queryID] = queryContext{
cancel: cancel,
started: time.Now().Add(-2 * time.Hour),
}
app.queryMu.Unlock()
// Cleanup queries older than 1 hour
app.CleanupStaleQueries(1 * time.Hour)
// Verify stale query was removed
app.queryMu.Lock()
_, exists := app.runningQueries[queryID]
app.queryMu.Unlock()
if exists {
t.Fatal("Stale query should be removed by CleanupStaleQueries")
}
// Add a fresh query (started 30 minutes ago)
freshID := app.GenerateQueryID()
_, cancel2 := context.WithCancel(context.Background())
app.queryMu.Lock()
app.runningQueries[freshID] = queryContext{
cancel: cancel2,
started: time.Now().Add(-30 * time.Minute),
}
app.queryMu.Unlock()
defer cancel2()
// Cleanup queries older than 1 hour
app.CleanupStaleQueries(1 * time.Hour)
// Verify fresh query still exists
app.queryMu.Lock()
_, exists = app.runningQueries[freshID]
app.queryMu.Unlock()
if !exists {
t.Fatal("Fresh query should not be removed by CleanupStaleQueries")
}
// Clean up
app.queryMu.Lock()
delete(app.runningQueries, freshID)
app.queryMu.Unlock()
}
func TestDBQueryWithCancel_QueryIDPropagation(t *testing.T) {
// This test verifies that query ID is properly propagated in QueryResult
// Since we can't easily mock database connections, we'll test the integration
// by checking that DBQueryWithCancel returns a QueryResult with QueryID field
app := NewApp()
// Create a minimal config for a database type that doesn't require actual connection
config := connection.ConnectionConfig{
Type: "duckdb",
Host: ":memory:", // In-memory duckdb for testing
}
// This will fail because we can't actually connect, but we can test the error path
result := app.DBQueryWithCancel(config, "", "SELECT 1", "test-query-id")
// The query should fail (no actual database), but QueryID should be present
if result.QueryID != "test-query-id" {
t.Fatalf("Expected QueryID 'test-query-id' in result, got: %s", result.QueryID)
}
}

View File

@@ -304,13 +304,33 @@ var driverGoModulePathMap = map[string]string{
"clickhouse": "github.com/ClickHouse/clickhouse-go/v2",
}
var driverGoModuleAliasPathMap = map[string][]string{
"mongodb": {
"go.mongodb.org/mongo-driver",
},
}
var driverExtraHistoryLimitMap = map[string]int{
"mongodb": 10,
}
var fallbackRecentDriverVersionsMap = map[string][]goModuleVersionMeta{
"mongodb": {
{Version: "2.5.0"},
{Version: "2.4.2"},
{Version: "2.4.1"},
{Version: "2.4.0"},
{Version: "2.3.1"},
{Version: "2.3.0"},
{Version: "2.2.3"},
{Version: "1.17.9"},
{Version: "1.17.8"},
{Version: "1.17.7"},
{Version: "1.17.6"},
{Version: "1.17.4"},
{Version: "1.17.3"},
{Version: "1.17.2"},
{Version: "1.17.1"},
{Version: "1.17.0"},
{Version: "1.16.1"},
},
}
@@ -696,11 +716,11 @@ func (a *App) CheckDriverNetworkStatus() connection.QueryResult {
}
data := map[string]interface{}{
"reachable": allReachable,
"summary": summary,
"recommendedProxy": !allReachable,
"proxyConfigured": proxyConfigured,
"proxyEnv": proxyEnv,
"reachable": allReachable,
"summary": summary,
"recommendedProxy": !allReachable,
"proxyConfigured": proxyConfigured,
"proxyEnv": proxyEnv,
"downloadChainReachable": downloadChainReachable,
"downloadRequiredHosts": []string{
"github.com",
@@ -709,8 +729,8 @@ func (a *App) CheckDriverNetworkStatus() connection.QueryResult {
"objects.githubusercontent.com",
"raw.githubusercontent.com",
},
"checkedAt": time.Now().Format(time.RFC3339),
"checks": checks,
"checkedAt": time.Now().Format(time.RFC3339),
"checks": checks,
}
if logPath := strings.TrimSpace(logger.Path()); logPath != "" {
data["logPath"] = logPath
@@ -1600,17 +1620,57 @@ func resolveRecentDriverVersionMetas(driverType string, limit int) []goModuleVer
if normalized == "" {
return nil
}
if modulePath := strings.TrimSpace(driverGoModulePathMap[normalized]); modulePath != "" {
if metas := fetchGoModuleVersionMetasCached(modulePath); len(metas) > 0 {
if len(metas) > limit {
return append([]goModuleVersionMeta(nil), metas[:limit]...)
modulePaths := resolveDriverGoModulePaths(normalized)
if len(modulePaths) > 0 {
result := make([]goModuleVersionMeta, 0, limit)
seen := make(map[string]struct{}, limit)
appendUnique := func(values []goModuleVersionMeta, maxAppend int) {
if maxAppend <= 0 {
return
}
return append([]goModuleVersionMeta(nil), metas...)
appended := 0
for _, meta := range values {
version := normalizeVersion(strings.TrimSpace(meta.Version))
if version == "" {
continue
}
key := strings.ToLower(version)
if _, ok := seen[key]; ok {
continue
}
meta.Version = version
result = append(result, meta)
seen[key] = struct{}{}
appended++
if appended >= maxAppend {
return
}
}
}
appendUnique(fetchGoModuleVersionMetasCached(modulePaths[0]), limit)
extraLimit := resolveDriverExtraHistoryLimit(normalized)
for _, modulePath := range modulePaths[1:] {
if extraLimit <= 0 {
break
}
before := len(result)
appendUnique(fetchGoModuleVersionMetasCached(modulePath), extraLimit)
extraLimit -= len(result) - before
}
if len(result) > 0 {
return result
}
}
fallbackLimit := limit + resolveDriverExtraHistoryLimit(normalized)
if fallbackLimit <= 0 {
fallbackLimit = limit
}
if fallback := fallbackRecentDriverVersionsMap[normalized]; len(fallback) > 0 {
if len(fallback) > limit {
return append([]goModuleVersionMeta(nil), fallback[:limit]...)
if len(fallback) > fallbackLimit {
return append([]goModuleVersionMeta(nil), fallback[:fallbackLimit]...)
}
return append([]goModuleVersionMeta(nil), fallback...)
}
@@ -1635,15 +1695,13 @@ func triggerDriverVersionMetadataWarmup(definitions []driverDefinition) {
if driverType == "" || !db.IsOptionalGoDriver(driverType) {
continue
}
modulePath := strings.TrimSpace(driverGoModulePathMap[driverType])
if modulePath == "" {
continue
for _, modulePath := range resolveDriverGoModulePaths(driverType) {
if _, ok := seenModule[modulePath]; ok {
continue
}
seenModule[modulePath] = struct{}{}
modulePaths = append(modulePaths, modulePath)
}
if _, ok := seenModule[modulePath]; ok {
continue
}
seenModule[modulePath] = struct{}{}
modulePaths = append(modulePaths, modulePath)
}
if len(modulePaths) == 0 {
@@ -1663,6 +1721,40 @@ func triggerDriverVersionMetadataWarmup(definitions []driverDefinition) {
}(append([]string(nil), modulePaths...))
}
func resolveDriverGoModulePaths(driverType string) []string {
normalized := normalizeDriverType(driverType)
if normalized == "" {
return nil
}
paths := make([]string, 0, 3)
seen := make(map[string]struct{}, 3)
appendPath := func(path string) {
trimmed := strings.TrimSpace(path)
if trimmed == "" {
return
}
if _, ok := seen[trimmed]; ok {
return
}
seen[trimmed] = struct{}{}
paths = append(paths, trimmed)
}
appendPath(driverGoModulePathMap[normalized])
for _, alias := range driverGoModuleAliasPathMap[normalized] {
appendPath(alias)
}
return paths
}
func resolveDriverExtraHistoryLimit(driverType string) int {
limit := driverExtraHistoryLimitMap[normalizeDriverType(driverType)]
if limit < 0 {
return 0
}
return limit
}
func tryStartDriverVersionMetadataWarmup(now time.Time) bool {
driverVersionWarmupMu.Lock()
defer driverVersionWarmupMu.Unlock()
@@ -2356,16 +2448,23 @@ func hashFileSHA256(filePath string) (string, error) {
func installOptionalDriverAgentPackage(a *App, definition driverDefinition, selectedVersion string, resolvedDir string, downloadURL string) (installedDriverPackage, error) {
driverType := normalizeDriverType(definition.Type)
executablePath, err := db.ResolveOptionalDriverAgentExecutablePath(resolvedDir, driverType)
installPath, err := db.ResolveOptionalDriverAgentExecutablePathForVersion(resolvedDir, driverType, selectedVersion)
if err != nil {
return installedDriverPackage{}, err
}
downloadSource, hash, err := ensureOptionalDriverAgentBinary(a, definition, executablePath, downloadURL)
runtimePath, err := db.ResolveOptionalDriverAgentExecutablePath(resolvedDir, driverType)
if err != nil {
return installedDriverPackage{}, err
}
downloadSource, hash, err := ensureOptionalDriverAgentBinary(a, definition, installPath, downloadURL, selectedVersion)
if err != nil {
return installedDriverPackage{}, err
}
if activateErr := activateOptionalDriverAgentBinary(installPath, runtimePath); activateErr != nil {
return installedDriverPackage{}, fmt.Errorf("activate %s driver agent failed: %w", resolveDriverDisplayName(definition), activateErr)
}
if strings.TrimSpace(hash) == "" {
hash, err = hashFileSHA256(executablePath)
hash, err = hashFileSHA256(installPath)
if err != nil {
return installedDriverPackage{}, fmt.Errorf("计算 %s 驱动代理摘要失败:%w", resolveDriverDisplayName(definition), err)
}
@@ -2376,9 +2475,9 @@ func installOptionalDriverAgentPackage(a *App, definition driverDefinition, sele
return installedDriverPackage{
DriverType: driverType,
Version: strings.TrimSpace(selectedVersion),
FilePath: executablePath,
FileName: filepath.Base(executablePath),
ExecutablePath: executablePath,
FilePath: installPath,
FileName: filepath.Base(installPath),
ExecutablePath: runtimePath,
DownloadURL: strings.TrimSpace(downloadSource),
SHA256: hash,
DownloadedAt: time.Now().Format(time.RFC3339),
@@ -2686,9 +2785,11 @@ func installOptionalDriverAgentFromLocalZip(zipPath string, definition driverDef
return filepath.ToSlash(strings.TrimPrefix(strings.TrimSpace(entry.Name), "./")), nil
}
func ensureOptionalDriverAgentBinary(a *App, definition driverDefinition, executablePath string, downloadURL string) (string, string, error) {
func ensureOptionalDriverAgentBinary(a *App, definition driverDefinition, executablePath string, downloadURL string, selectedVersion string) (string, string, error) {
driverType := normalizeDriverType(definition.Type)
displayName := resolveDriverDisplayName(definition)
forceSourceBuild := shouldForceSourceBuildForVersion(driverType, selectedVersion)
skipReuseCandidate := shouldSkipReusableAgentCandidate(driverType, selectedVersion)
info, err := os.Stat(executablePath)
if err == nil && !info.IsDir() {
@@ -2708,49 +2809,53 @@ func ensureOptionalDriverAgentBinary(a *App, definition driverDefinition, execut
if a != nil {
a.emitDriverDownloadProgress(driverType, "downloading", 10, 100, "检查本地驱动代理缓存")
}
if sourcePath, ok := findExistingOptionalDriverAgentCandidate(definition, executablePath); ok {
if copyErr := copyAgentBinary(sourcePath, executablePath); copyErr != nil {
return "", "", fmt.Errorf("复制预置 %s 驱动代理失败:%w", displayName, copyErr)
if !skipReuseCandidate {
if sourcePath, ok := findExistingOptionalDriverAgentCandidate(definition, executablePath); ok {
if copyErr := copyAgentBinary(sourcePath, executablePath); copyErr != nil {
return "", "", fmt.Errorf("复制预置 %s 驱动代理失败:%w", displayName, copyErr)
}
hash, hashErr := hashFileSHA256(executablePath)
if hashErr != nil {
return "", "", fmt.Errorf("计算预置 %s 驱动代理摘要失败:%w", displayName, hashErr)
}
return "file://" + sourcePath, hash, nil
}
hash, hashErr := hashFileSHA256(executablePath)
if hashErr != nil {
return "", "", fmt.Errorf("计算预置 %s 驱动代理摘要失败:%w", displayName, hashErr)
}
return "file://" + sourcePath, hash, nil
}
downloadURLs := resolveOptionalDriverAgentDownloadURLs(definition, downloadURL)
var downloadErrs []string
if len(downloadURLs) > 0 {
for _, candidateURL := range downloadURLs {
if a != nil {
a.emitDriverDownloadProgress(driverType, "downloading", 20, 100, fmt.Sprintf("下载预编译 %s 驱动代理", displayName))
if !forceSourceBuild {
downloadURLs := resolveOptionalDriverAgentDownloadURLs(definition, downloadURL)
if len(downloadURLs) > 0 {
for _, candidateURL := range downloadURLs {
if a != nil {
a.emitDriverDownloadProgress(driverType, "downloading", 20, 100, fmt.Sprintf("下载预编译 %s 驱动代理", displayName))
}
hash, dlErr := downloadOptionalDriverAgentBinary(a, definition, candidateURL, executablePath)
if dlErr == nil {
return candidateURL, hash, nil
}
downloadErrs = append(downloadErrs, fmt.Sprintf("%s: %s", candidateURL, strings.TrimSpace(dlErr.Error())))
}
hash, dlErr := downloadOptionalDriverAgentBinary(a, definition, candidateURL, executablePath)
if dlErr == nil {
return candidateURL, hash, nil
}
downloadErrs = append(downloadErrs, fmt.Sprintf("%s: %s", candidateURL, strings.TrimSpace(dlErr.Error())))
}
}
bundleURLs := resolveOptionalDriverBundleDownloadURLs()
if len(bundleURLs) > 0 {
for _, bundleURL := range bundleURLs {
if a != nil {
a.emitDriverDownloadProgress(driverType, "downloading", 20, 100, fmt.Sprintf("从驱动总包提取 %s 代理", displayName))
bundleURLs := resolveOptionalDriverBundleDownloadURLs()
if len(bundleURLs) > 0 {
for _, bundleURL := range bundleURLs {
if a != nil {
a.emitDriverDownloadProgress(driverType, "downloading", 20, 100, fmt.Sprintf("从驱动总包提取 %s 代理", displayName))
}
source, hash, bundleErr := downloadOptionalDriverAgentFromBundle(a, definition, bundleURL, executablePath)
if bundleErr == nil {
return source, hash, nil
}
downloadErrs = append(downloadErrs, fmt.Sprintf("%s: %s", bundleURL, strings.TrimSpace(bundleErr.Error())))
}
source, hash, bundleErr := downloadOptionalDriverAgentFromBundle(a, definition, bundleURL, executablePath)
if bundleErr == nil {
return source, hash, nil
}
downloadErrs = append(downloadErrs, fmt.Sprintf("%s: %s", bundleURL, strings.TrimSpace(bundleErr.Error())))
}
}
if a != nil {
a.emitDriverDownloadProgress(driverType, "downloading", 92, 100, "未命中预编译包,尝试开发态本地构建")
}
hash, buildErr := buildOptionalDriverAgentFromSource(definition, executablePath)
hash, buildErr := buildOptionalDriverAgentFromSource(definition, executablePath, selectedVersion)
if buildErr == nil {
return fmt.Sprintf("local://go-build/%s-driver-agent", driverType), hash, nil
}
@@ -2912,7 +3017,7 @@ func downloadOptionalDriverAgentFromBundle(a *App, definition driverDefinition,
return source, hash, nil
}
func buildOptionalDriverAgentFromSource(definition driverDefinition, executablePath string) (string, error) {
func buildOptionalDriverAgentFromSource(definition driverDefinition, executablePath string, selectedVersion string) (string, error) {
driverType := normalizeDriverType(definition.Type)
displayName := resolveDriverDisplayName(definition)
goPath, lookErr := exec.LookPath("go")
@@ -2920,7 +3025,7 @@ func buildOptionalDriverAgentFromSource(definition driverDefinition, executableP
return "", fmt.Errorf("当前环境未安装 Go且未找到可用的 %s 预编译代理包", displayName)
}
tagName, tagErr := optionalDriverBuildTag(driverType)
tagName, tagErr := optionalDriverBuildTag(driverType, selectedVersion)
if tagErr != nil {
return "", tagErr
}
@@ -2931,6 +3036,7 @@ func buildOptionalDriverAgentFromSource(definition driverDefinition, executableP
}
cmd := exec.Command(goPath, "build", "-tags", tagName, "-trimpath", "-ldflags", "-s -w", "-o", executablePath, "./cmd/optional-driver-agent")
cmd.Dir = projectRoot
cmd.Env = append(os.Environ(), "GOTOOLCHAIN=auto")
output, buildErr := cmd.CombinedOutput()
if buildErr != nil {
return "", fmt.Errorf("构建 %s 驱动代理失败:%v输出%s", displayName, buildErr, strings.TrimSpace(string(output)))
@@ -2945,7 +3051,31 @@ func buildOptionalDriverAgentFromSource(definition driverDefinition, executableP
return hash, nil
}
func optionalDriverBuildTag(driverType string) (string, error) {
func resolveMongoDriverMajorFromVersion(version string) int {
trimmed := strings.TrimSpace(version)
trimmed = strings.TrimPrefix(trimmed, "v")
if strings.HasPrefix(trimmed, "1.") || trimmed == "1" {
return 1
}
return 2
}
func shouldForceSourceBuildForVersion(driverType string, selectedVersion string) bool {
if normalizeDriverType(driverType) != "mongodb" {
return false
}
return resolveMongoDriverMajorFromVersion(selectedVersion) == 1
}
func shouldSkipReusableAgentCandidate(driverType string, selectedVersion string) bool {
if normalizeDriverType(driverType) != "mongodb" {
return false
}
_ = selectedVersion
return true
}
func optionalDriverBuildTag(driverType string, selectedVersion string) (string, error) {
switch normalizeDriverType(driverType) {
case "mysql":
return "gonavi_mysql_driver", nil
@@ -2970,6 +3100,9 @@ func optionalDriverBuildTag(driverType string) (string, error) {
case "vastbase":
return "gonavi_vastbase_driver", nil
case "mongodb":
if resolveMongoDriverMajorFromVersion(selectedVersion) == 1 {
return "gonavi_mongodb_driver_v1", nil
}
return "gonavi_mongodb_driver", nil
case "tdengine":
return "gonavi_tdengine_driver", nil
@@ -3310,6 +3443,30 @@ func resolveDriverDisplayName(definition driverDefinition) string {
return "未知"
}
func activateOptionalDriverAgentBinary(installPath string, runtimePath string) error {
source := strings.TrimSpace(installPath)
target := strings.TrimSpace(runtimePath)
if source == "" || target == "" {
return fmt.Errorf("agent path is empty")
}
if source == target {
return nil
}
absSource := source
absTarget := target
if value, err := filepath.Abs(source); err == nil && strings.TrimSpace(value) != "" {
absSource = value
}
if value, err := filepath.Abs(target); err == nil && strings.TrimSpace(value) != "" {
absTarget = value
}
if strings.EqualFold(absSource, absTarget) {
return nil
}
return copyAgentBinary(source, target)
}
func copyAgentBinary(sourcePath, targetPath string) error {
src, err := os.Open(sourcePath)
if err != nil {

View File

@@ -6,6 +6,7 @@ import (
"encoding/csv"
"encoding/json"
"fmt"
"html"
"math"
"os"
"path/filepath"
@@ -1595,6 +1596,26 @@ func writeRowsToFile(f *os.File, data []map[string]interface{}, columns []string
return writeRowsToXlsx(f.Name(), data, columns)
}
// html 使用内嵌 CSS 输出可直接浏览器预览的独立页面
if format == "html" {
return writeRowsToHTML(f, data, columns)
}
// 如果列名为空但数据不为空,从所有数据行提取所有键
if len(columns) == 0 && len(data) > 0 {
keySet := make(map[string]bool)
for _, row := range data {
for key := range row {
keySet[key] = true
}
}
// 排序以确保输出一致
for key := range keySet {
columns = append(columns, key)
}
sort.Strings(columns)
}
var csvWriter *csv.Writer
var jsonEncoder *json.Encoder
isJsonFirstRow := true
@@ -1688,6 +1709,188 @@ func writeRowsToFile(f *os.File, data []map[string]interface{}, columns []string
return nil
}
func formatExportHTMLCell(val interface{}) string {
text := formatExportCellText(val)
escaped := html.EscapeString(text)
escaped = strings.ReplaceAll(escaped, "\r\n", "\n")
escaped = strings.ReplaceAll(escaped, "\r", "\n")
return strings.ReplaceAll(escaped, "\n", "<br>")
}
func writeRowsToHTML(f *os.File, data []map[string]interface{}, columns []string) error {
w := bufio.NewWriterSize(f, 1024*256)
if _, err := w.WriteString(`<!DOCTYPE html>
<html lang="en">
<head>
<meta charset="UTF-8">
<meta name="viewport" content="width=device-width, initial-scale=1.0">
<title>GoNavi Export</title>
<style>
:root {
color-scheme: light;
--bg: #f8f9fa;
--card: #ffffff;
--line: #dee2e6;
--text: #212529;
--muted: #6c757d;
--hover: #f1f3f5;
--zebra: #f8f9fa;
--head: #ffffff;
}
* { box-sizing: border-box; }
body {
margin: 0;
padding: 24px;
background: var(--bg);
color: var(--text);
font-family: -apple-system, BlinkMacSystemFont, "Segoe UI", Roboto, "Helvetica Neue", Arial, "Noto Sans", "PingFang SC", "Microsoft YaHei", sans-serif;
line-height: 1.6;
}
.export-wrap {
max-width: 100%;
margin: 0 auto;
background: var(--card);
border: 1px solid var(--line);
border-radius: 8px;
overflow: hidden;
}
.export-head {
padding: 16px 20px;
background: var(--head);
border-bottom: 2px solid var(--line);
}
.export-head h1 {
margin: 0;
font-size: 16px;
font-weight: 600;
color: var(--text);
}
.export-meta {
margin-top: 6px;
color: var(--muted);
font-size: 13px;
}
.table-wrap {
width: 100%;
overflow: auto;
padding: 16px;
}
table {
border-collapse: collapse;
width: auto;
font-size: 13px;
}
thead th {
position: sticky;
top: 0;
z-index: 2;
background: var(--head);
text-align: left;
font-weight: 600;
white-space: nowrap;
border-bottom: 2px solid var(--line);
color: var(--text);
padding: 12px 16px;
}
td {
padding: 10px 16px;
border-bottom: 1px solid var(--line);
vertical-align: top;
white-space: pre-wrap;
word-wrap: break-word;
overflow-wrap: anywhere;
max-width: 500px;
color: var(--text);
}
tbody tr:nth-child(even) {
background: var(--zebra);
}
tbody tr:hover {
background: var(--hover);
}
td.empty {
text-align: center;
color: var(--muted);
font-style: italic;
}
@media (max-width: 768px) {
body { padding: 16px; }
.export-head { padding: 12px 16px; }
.table-wrap { padding: 12px; }
th, td { padding: 8px 12px; font-size: 12px; }
}
@media print {
body { background: white; padding: 0; }
.export-wrap { border: none; }
}
</style>
</head>
<body>
<div class="export-wrap">
<div class="export-head">
<h1>GoNavi Data Export</h1>
<div class="export-meta">`); err != nil {
return err
}
if _, err := fmt.Fprintf(w, "Rows: %d · Columns: %d · Generated: %s", len(data), len(columns), time.Now().Format("2006-01-02 15:04:05")); err != nil {
return err
}
if _, err := w.WriteString(`</div>
</div>
<div class="table-wrap">
<table>
<thead><tr>`); err != nil {
return err
}
for _, col := range columns {
if _, err := fmt.Fprintf(w, "<th>%s</th>", html.EscapeString(col)); err != nil {
return err
}
}
if _, err := w.WriteString(`</tr></thead><tbody>`); err != nil {
return err
}
if len(data) == 0 {
colspan := len(columns)
if colspan <= 0 {
colspan = 1
}
if _, err := fmt.Fprintf(w, `<tr><td class="empty" colspan="%d">(0 rows)</td></tr>`, colspan); err != nil {
return err
}
} else {
for _, rowMap := range data {
if _, err := w.WriteString("<tr>"); err != nil {
return err
}
for _, col := range columns {
if _, err := fmt.Fprintf(w, "<td>%s</td>", formatExportHTMLCell(rowMap[col])); err != nil {
return err
}
}
if _, err := w.WriteString("</tr>"); err != nil {
return err
}
}
}
if _, err := w.WriteString(`</tbody></table>
</div>
</div>
</body>
</html>`); err != nil {
return err
}
return w.Flush()
}
func formatExportCellText(val interface{}) string {
if val == nil {
return "NULL"

View File

@@ -203,3 +203,73 @@ func TestGetExportQueryTimeout_CustomClickHouseUsesLongerMinimum(t *testing.T) {
t.Fatalf("custom clickhouse 导出超时下限异常want=%s got=%s", minClickHouseExportQueryTimeout, timeout)
}
}
func TestWriteRowsToFile_HTML_EscapeAndStyle(t *testing.T) {
f, err := os.CreateTemp("", "gonavi-export-*.html")
if err != nil {
t.Fatalf("创建临时文件失败: %v", err)
}
defer os.Remove(f.Name())
defer f.Close()
data := []map[string]interface{}{
{
"name": "<script>alert(1)</script>",
"note": "line1\nline2",
"nullable": nil,
},
}
columns := []string{"name", "note", "nullable"}
if err := writeRowsToFile(f, data, columns, "html"); err != nil {
t.Fatalf("写入 html 失败: %v", err)
}
contentBytes, err := os.ReadFile(f.Name())
if err != nil {
t.Fatalf("读取 html 失败: %v", err)
}
content := string(contentBytes)
if !strings.Contains(content, "<!DOCTYPE html>") {
t.Fatalf("html 导出缺少 doctype: %s", content)
}
if !strings.Contains(content, "position: sticky") {
t.Fatalf("html 导出缺少表头吸顶样式: %s", content)
}
if !strings.Contains(content, "tbody tr:nth-child(even)") {
t.Fatalf("html 导出缺少斑马纹样式: %s", content)
}
if !strings.Contains(content, "&lt;script&gt;alert(1)&lt;/script&gt;") {
t.Fatalf("html 导出未进行 XSS 转义: %s", content)
}
if strings.Contains(content, "<script>alert(1)</script>") {
t.Fatalf("html 导出包含未转义脚本: %s", content)
}
if !strings.Contains(content, "line1<br>line2") {
t.Fatalf("html 导出换行未转为 <br>: %s", content)
}
if !strings.Contains(content, "<td>NULL</td>") {
t.Fatalf("html 导出空值显示异常: %s", content)
}
}
func TestWriteRowsToFile_HTML_EscapeHeader(t *testing.T) {
f, err := os.CreateTemp("", "gonavi-export-*.html")
if err != nil {
t.Fatalf("创建临时文件失败: %v", err)
}
defer os.Remove(f.Name())
defer f.Close()
columnName := "<b>name</b>"
data := []map[string]interface{}{{columnName: "ok"}}
if err := writeRowsToFile(f, data, []string{columnName}, "html"); err != nil {
t.Fatalf("写入 html 失败: %v", err)
}
contentBytes, _ := os.ReadFile(f.Name())
content := string(contentBytes)
if !strings.Contains(content, "<th>&lt;b&gt;name&lt;/b&gt;</th>") || strings.Contains(content, "<th><b>name</b></th>") {
t.Fatalf("html 表头未正确转义: %s", content)
}
}

View File

@@ -51,12 +51,13 @@ type UpdateInfo struct {
}
type AppInfo struct {
Version string `json:"version"`
Author string `json:"author"`
RepoURL string `json:"repoUrl,omitempty"`
IssueURL string `json:"issueUrl,omitempty"`
ReleaseURL string `json:"releaseUrl,omitempty"`
BuildTime string `json:"buildTime,omitempty"`
Version string `json:"version"`
Author string `json:"author"`
RepoURL string `json:"repoUrl,omitempty"`
IssueURL string `json:"issueUrl,omitempty"`
ReleaseURL string `json:"releaseUrl,omitempty"`
CommunityURL string `json:"communityUrl,omitempty"`
BuildTime string `json:"buildTime,omitempty"`
}
type updateDownloadResult struct {
@@ -137,12 +138,13 @@ func (a *App) CheckForUpdates() connection.QueryResult {
func (a *App) GetAppInfo() connection.QueryResult {
info := AppInfo{
Version: getCurrentVersion(),
Author: getCurrentAuthor(),
RepoURL: "https://github.com/" + updateRepo,
IssueURL: "https://github.com/" + updateRepo + "/issues",
ReleaseURL: "https://github.com/" + updateRepo + "/releases",
BuildTime: strings.TrimSpace(AppBuildTime),
Version: getCurrentVersion(),
Author: getCurrentAuthor(),
RepoURL: "https://github.com/" + updateRepo,
IssueURL: "https://github.com/" + updateRepo + "/issues",
ReleaseURL: "https://github.com/" + updateRepo + "/releases",
CommunityURL: "https://aibook.ren",
BuildTime: strings.TrimSpace(AppBuildTime),
}
return connection.QueryResult{Success: true, Message: "OK", Data: info}
}

View File

@@ -27,6 +27,10 @@ type ConnectionConfig struct {
Password string `json:"password"`
SavePassword bool `json:"savePassword,omitempty"` // Persist password in saved connection
Database string `json:"database"`
UseSSL bool `json:"useSSL,omitempty"` // MySQL-like SSL/TLS switch
SSLMode string `json:"sslMode,omitempty"` // preferred | required | skip-verify | disable
SSLCertPath string `json:"sslCertPath,omitempty"` // TLS client certificate path (e.g., Dameng)
SSLKeyPath string `json:"sslKeyPath,omitempty"` // TLS client private key path (e.g., Dameng)
UseSSH bool `json:"useSSH"`
SSH SSHConfig `json:"ssh"`
UseProxy bool `json:"useProxy,omitempty"`
@@ -55,6 +59,7 @@ type QueryResult struct {
Message string `json:"message"`
Data interface{} `json:"data"`
Fields []string `json:"fields,omitempty"`
QueryID string `json:"queryId,omitempty"` // Unique ID for query cancellation
}
// ColumnDefinition represents a table column

View File

@@ -107,7 +107,7 @@ func (c *ClickHouseDB) buildClickHouseOptions(config connection.ConnectionConfig
if readTimeout < minClickHouseReadTimeout {
readTimeout = minClickHouseReadTimeout
}
return &clickhouse.Options{
opts := &clickhouse.Options{
Addr: []string{
net.JoinHostPort(config.Host, strconv.Itoa(config.Port)),
},
@@ -119,6 +119,10 @@ func (c *ClickHouseDB) buildClickHouseOptions(config connection.ConnectionConfig
DialTimeout: connectTimeout,
ReadTimeout: readTimeout,
}
if tlsConfig := resolveGenericTLSConfig(config); tlsConfig != nil {
opts.TLS = tlsConfig
}
return opts
}
func (c *ClickHouseDB) Connect(config connection.ConnectionConfig) error {
@@ -165,13 +169,30 @@ func (c *ClickHouseDB) Connect(config connection.ConnectionConfig) error {
logger.Infof("ClickHouse 通过本地端口转发连接:%s -> %s:%d", forwarder.LocalAddr, config.Host, config.Port)
}
c.conn = clickhouse.OpenDB(c.buildClickHouseOptions(runConfig))
if err := c.Ping(); err != nil {
_ = c.Close()
return fmt.Errorf("连接建立后验证失败:%w", err)
attempts := []connection.ConnectionConfig{runConfig}
if shouldTrySSLPreferredFallback(runConfig) {
attempts = append(attempts, withSSLDisabled(runConfig))
}
return nil
var failures []string
for idx, attempt := range attempts {
c.conn = clickhouse.OpenDB(c.buildClickHouseOptions(attempt))
if err := c.Ping(); err != nil {
failures = append(failures, fmt.Sprintf("第%d次连接验证失败: %v", idx+1, err))
if c.conn != nil {
_ = c.conn.Close()
c.conn = nil
}
continue
}
if idx > 0 {
logger.Warnf("ClickHouse SSL 优先连接失败,已回退至明文连接")
}
return nil
}
_ = c.Close()
return fmt.Errorf("连接建立后验证失败:%s", strings.Join(failures, ""))
}
func (c *ClickHouseDB) Close() error {

View File

@@ -36,6 +36,14 @@ func (d *DamengDB) getDSN(config connection.ConnectionConfig) string {
if config.Database != "" {
q.Set("schema", config.Database)
}
if config.UseSSL {
if certPath := strings.TrimSpace(config.SSLCertPath); certPath != "" {
q.Set("SSL_CERT_PATH", certPath)
}
if keyPath := strings.TrimSpace(config.SSLKeyPath); keyPath != "" {
q.Set("SSL_KEY_PATH", keyPath)
}
}
if escapedPassword != config.Password {
// 达梦驱动要求密码包含特殊字符时password 需 PathEscape并添加 escapeProcess=true 让驱动解码。
q.Set("escapeProcess", "true")
@@ -50,8 +58,12 @@ func (d *DamengDB) getDSN(config connection.ConnectionConfig) string {
}
func (d *DamengDB) Connect(config connection.ConnectionConfig) error {
var dsn string
var err error
runConfig := config
if runConfig.UseSSL {
if strings.TrimSpace(runConfig.SSLCertPath) == "" || strings.TrimSpace(runConfig.SSLKeyPath) == "" {
return fmt.Errorf("达梦启用 SSL 需要同时配置证书路径(sslCertPath)与私钥路径(sslKeyPath)")
}
}
if config.UseSSH {
// Create SSH tunnel with local port forwarding
@@ -80,22 +92,37 @@ func (d *DamengDB) Connect(config connection.ConnectionConfig) error {
localConfig.Port = port
localConfig.UseSSH = false
dsn = d.getDSN(localConfig)
runConfig = localConfig
logger.Infof("达梦数据库通过本地端口转发连接:%s -> %s:%d", forwarder.LocalAddr, config.Host, config.Port)
} else {
dsn = d.getDSN(config)
}
db, err := sql.Open("dm", dsn)
if err != nil {
return fmt.Errorf("打开数据库连接失败:%w", err)
attempts := []connection.ConnectionConfig{runConfig}
if shouldTrySSLPreferredFallback(runConfig) {
attempts = append(attempts, withSSLDisabled(runConfig))
}
d.conn = db
d.pingTimeout = getConnectTimeout(config)
if err := d.Ping(); err != nil {
return fmt.Errorf("连接建立后验证失败:%w", err)
var failures []string
for idx, attempt := range attempts {
dsn := d.getDSN(attempt)
db, err := sql.Open("dm", dsn)
if err != nil {
failures = append(failures, fmt.Sprintf("第%d次连接打开失败: %v", idx+1, err))
continue
}
d.conn = db
d.pingTimeout = getConnectTimeout(attempt)
if err := d.Ping(); err != nil {
_ = db.Close()
d.conn = nil
failures = append(failures, fmt.Sprintf("第%d次连接验证失败: %v", idx+1, err))
continue
}
if idx > 0 {
logger.Warnf("达梦 SSL 优先连接失败,已回退至明文连接")
}
return nil
}
return nil
return fmt.Errorf("连接建立后验证失败:%s", strings.Join(failures, ""))
}
func (d *DamengDB) Close() error {

View File

@@ -151,9 +151,10 @@ func (d *DirosDB) getDSN(config connection.ConnectionConfig) string {
}
timeout := getConnectTimeoutSeconds(config)
tlsMode := resolveMySQLTLSMode(config)
return fmt.Sprintf("%s:%s@%s(%s)/%s?charset=utf8mb4&parseTime=True&loc=Local&timeout=%ds",
config.User, config.Password, protocol, address, database, timeout)
return fmt.Sprintf("%s:%s@%s(%s)/%s?charset=utf8mb4&parseTime=True&loc=Local&timeout=%ds&tls=%s",
config.User, config.Password, protocol, address, database, timeout, url.QueryEscape(tlsMode))
}
func resolveDirosCredential(config connection.ConnectionConfig, addressIndex int) (string, string) {

View File

@@ -33,6 +33,44 @@ func TestPostgresDSN_EscapesPassword(t *testing.T) {
}
}
func TestPostgresDSN_SSLModeRequireWhenEnabled(t *testing.T) {
p := &PostgresDB{}
cfg := connection.ConnectionConfig{
Type: "postgres",
Host: "127.0.0.1",
Port: 5432,
User: "user",
Password: "pass",
Database: "db",
UseSSL: true,
SSLMode: "required",
}
dsn := p.getDSN(cfg)
if !strings.Contains(dsn, "sslmode=require") {
t.Fatalf("dsn 缺少 sslmode=require 参数:%s", dsn)
}
}
func TestMySQLDSN_UsesTLSParamWhenSSLEnabled(t *testing.T) {
m := &MySQLDB{}
cfg := connection.ConnectionConfig{
Type: "mysql",
Host: "127.0.0.1",
Port: 3306,
User: "root",
Password: "pass",
Database: "db",
UseSSL: true,
SSLMode: "required",
}
dsn := m.getDSN(cfg)
if !strings.Contains(dsn, "tls=true") {
t.Fatalf("dsn 缺少 tls=true 参数:%s", dsn)
}
}
func TestOracleDSN_EscapesUserAndPassword(t *testing.T) {
o := &OracleDB{}
cfg := connection.ConnectionConfig{
@@ -82,6 +120,30 @@ func TestDamengDSN_EscapesPasswordAndEnablesEscapeProcess(t *testing.T) {
}
}
func TestDamengDSN_AppendsSSLCertAndKeyParams(t *testing.T) {
d := &DamengDB{}
cfg := connection.ConnectionConfig{
Type: "dameng",
Host: "127.0.0.1",
Port: 5236,
User: "SYSDBA",
Password: "pass",
Database: "DBName",
UseSSL: true,
SSLMode: "required",
SSLCertPath: "C:\\certs\\client-cert.pem",
SSLKeyPath: "C:\\certs\\client-key.pem",
}
dsn := d.getDSN(cfg)
if !strings.Contains(dsn, "SSL_CERT_PATH=") {
t.Fatalf("dsn 缺少 SSL_CERT_PATH 参数:%s", dsn)
}
if !strings.Contains(dsn, "SSL_KEY_PATH=") {
t.Fatalf("dsn 缺少 SSL_KEY_PATH 参数:%s", dsn)
}
}
func TestKingbaseDSN_QuotesPasswordWithSpaces(t *testing.T) {
k := &KingbaseDB{}
cfg := connection.ConnectionConfig{
@@ -116,6 +178,47 @@ func TestTDengineDSN_UsesWebSocketFormat(t *testing.T) {
}
}
func TestTDengineDSN_UsesSecureWebSocketWhenSSLEnabled(t *testing.T) {
td := &TDengineDB{}
cfg := connection.ConnectionConfig{
Type: "tdengine",
Host: "127.0.0.1",
Port: 6041,
User: "root",
Password: "taosdata",
Database: "power",
UseSSL: true,
SSLMode: "required",
}
dsn := td.getDSN(cfg)
if !strings.HasPrefix(dsn, "root:taosdata@wss(127.0.0.1:6041)/power") {
t.Fatalf("tdengine ssl dsn 格式不正确:%s", dsn)
}
}
func TestSQLServerDSN_EncryptMapping(t *testing.T) {
s := &SqlServerDB{}
cfg := connection.ConnectionConfig{
Type: "sqlserver",
Host: "127.0.0.1",
Port: 1433,
User: "sa",
Password: "pass",
Database: "master",
UseSSL: true,
SSLMode: "required",
}
dsn := s.getDSN(cfg)
if !strings.Contains(strings.ToLower(dsn), "encrypt=true") {
t.Fatalf("sqlserver dsn 缺少 encrypt=true%s", dsn)
}
if !strings.Contains(strings.ToLower(dsn), "trustservercertificate=false") {
t.Fatalf("sqlserver dsn 缺少 TrustServerCertificate=false%s", dsn)
}
}
func TestClickHouseOptions_UsesStructuredTimeoutAndAuth(t *testing.T) {
c := &ClickHouseDB{}
cfg := normalizeClickHouseConfig(connection.ConnectionConfig{

View File

@@ -42,7 +42,7 @@ func (h *HighGoDB) getDSN(config connection.ConnectionConfig) string {
}
u.User = url.UserPassword(config.User, config.Password)
q := url.Values{}
q.Set("sslmode", "disable")
q.Set("sslmode", resolvePostgresSSLMode(config))
q.Set("connect_timeout", strconv.Itoa(getConnectTimeoutSeconds(config)))
u.RawQuery = q.Encode()
@@ -50,7 +50,7 @@ func (h *HighGoDB) getDSN(config connection.ConnectionConfig) string {
}
func (h *HighGoDB) Connect(config connection.ConnectionConfig) error {
var dsn string
runConfig := config
if config.UseSSH {
logger.Infof("HighGo 使用 SSH 连接:地址=%s:%d 用户=%s", config.Host, config.Port, config.User)
@@ -76,23 +76,37 @@ func (h *HighGoDB) Connect(config connection.ConnectionConfig) error {
localConfig.Port = port
localConfig.UseSSH = false
dsn = h.getDSN(localConfig)
runConfig = localConfig
logger.Infof("HighGo 通过本地端口转发连接:%s -> %s:%d", forwarder.LocalAddr, config.Host, config.Port)
} else {
dsn = h.getDSN(config)
}
db, err := sql.Open("highgo", dsn)
if err != nil {
return fmt.Errorf("打开数据库连接失败:%w", err)
attempts := []connection.ConnectionConfig{runConfig}
if shouldTrySSLPreferredFallback(runConfig) {
attempts = append(attempts, withSSLDisabled(runConfig))
}
h.conn = db
h.pingTimeout = getConnectTimeout(config)
if err := h.Ping(); err != nil {
return fmt.Errorf("连接建立后验证失败:%w", err)
var failures []string
for idx, attempt := range attempts {
dsn := h.getDSN(attempt)
db, err := sql.Open("highgo", dsn)
if err != nil {
failures = append(failures, fmt.Sprintf("第%d次连接打开失败: %v", idx+1, err))
continue
}
h.conn = db
h.pingTimeout = getConnectTimeout(attempt)
if err := h.Ping(); err != nil {
_ = db.Close()
h.conn = nil
failures = append(failures, fmt.Sprintf("第%d次连接验证失败: %v", idx+1, err))
continue
}
if idx > 0 {
logger.Warnf("HighGo SSL 优先连接失败,已回退至明文连接")
}
return nil
}
return nil
return fmt.Errorf("连接建立后验证失败:%s", strings.Join(failures, ""))
}
func (h *HighGoDB) Close() error {

View File

@@ -65,12 +65,13 @@ func (k *KingbaseDB) getDSN(config connection.ConnectionConfig) string {
port := config.Port
// Construct DSN
dsn := fmt.Sprintf("host=%s port=%d user=%s password=%s dbname=%s sslmode=disable connect_timeout=%d",
dsn := fmt.Sprintf("host=%s port=%d user=%s password=%s dbname=%s sslmode=%s connect_timeout=%d",
quoteConnValue(address),
port,
quoteConnValue(config.User),
quoteConnValue(config.Password),
quoteConnValue(config.Database),
quoteConnValue(resolvePostgresSSLMode(config)),
getConnectTimeoutSeconds(config),
)
@@ -78,8 +79,7 @@ func (k *KingbaseDB) getDSN(config connection.ConnectionConfig) string {
}
func (k *KingbaseDB) Connect(config connection.ConnectionConfig) error {
var dsn string
var err error
runConfig := config
if config.UseSSH {
// Create SSH tunnel with local port forwarding
@@ -108,23 +108,37 @@ func (k *KingbaseDB) Connect(config connection.ConnectionConfig) error {
localConfig.Port = port
localConfig.UseSSH = false
dsn = k.getDSN(localConfig)
runConfig = localConfig
logger.Infof("人大金仓通过本地端口转发连接:%s -> %s:%d", forwarder.LocalAddr, config.Host, config.Port)
} else {
dsn = k.getDSN(config)
}
// Open using "kingbase" driver
db, err := sql.Open("kingbase", dsn)
if err != nil {
return fmt.Errorf("打开数据库连接失败:%w", err)
attempts := []connection.ConnectionConfig{runConfig}
if shouldTrySSLPreferredFallback(runConfig) {
attempts = append(attempts, withSSLDisabled(runConfig))
}
k.conn = db
k.pingTimeout = getConnectTimeout(config)
if err := k.Ping(); err != nil {
return fmt.Errorf("连接建立后验证失败:%w", err)
var failures []string
for idx, attempt := range attempts {
dsn := k.getDSN(attempt)
db, err := sql.Open("kingbase", dsn)
if err != nil {
failures = append(failures, fmt.Sprintf("第%d次连接打开失败: %v", idx+1, err))
continue
}
k.conn = db
k.pingTimeout = getConnectTimeout(attempt)
if err := k.Ping(); err != nil {
_ = db.Close()
k.conn = nil
failures = append(failures, fmt.Sprintf("第%d次连接验证失败: %v", idx+1, err))
continue
}
if idx > 0 {
logger.Warnf("人大金仓 SSL 优先连接失败,已回退至明文连接")
}
return nil
}
return nil
return fmt.Errorf("连接建立后验证失败:%s", strings.Join(failures, ""))
}
func (k *KingbaseDB) Close() error {

View File

@@ -6,6 +6,7 @@ import (
"context"
"database/sql"
"fmt"
"net/url"
"strings"
"time"
@@ -40,9 +41,10 @@ func (m *MariaDB) getDSN(config connection.ConnectionConfig) string {
}
timeout := getConnectTimeoutSeconds(config)
tlsMode := resolveMySQLTLSMode(config)
return fmt.Sprintf("%s:%s@%s(%s)/%s?charset=utf8mb4&parseTime=True&loc=Local&timeout=%ds",
config.User, config.Password, protocol, address, database, timeout)
return fmt.Sprintf("%s:%s@%s(%s)/%s?charset=utf8mb4&parseTime=True&loc=Local&timeout=%ds&tls=%s",
config.User, config.Password, protocol, address, database, timeout, url.QueryEscape(tlsMode))
}
func (m *MariaDB) Connect(config connection.ConnectionConfig) error {

View File

@@ -4,6 +4,7 @@ package db
import (
"context"
"crypto/tls"
"fmt"
"net"
"net/url"
@@ -327,35 +328,60 @@ func (m *MongoDB) Connect(config connection.ConnectionConfig) error {
m.database = "admin"
}
attemptConfigs := buildMongoAuthAttempts(connectConfig)
sslAttempts := []connection.ConnectionConfig{connectConfig}
if shouldTrySSLPreferredFallback(connectConfig) {
sslAttempts = append(sslAttempts, withSSLDisabled(connectConfig))
}
var errorDetails []string
for index, attemptConfig := range attemptConfigs {
authLabel := "主库凭据"
if index > 0 {
authLabel = "从库凭据"
for sslIndex, sslConfig := range sslAttempts {
sslLabel := "SSL"
if sslIndex > 0 {
sslLabel = "明文回退"
}
uri := m.getURI(attemptConfig)
clientOpts := options.Client().ApplyURI(uri)
if attemptConfig.UseProxy {
clientOpts.SetDialer(&mongoProxyDialer{proxyConfig: attemptConfig.Proxy})
}
client, err := mongo.Connect(clientOpts)
if err != nil {
errorDetails = append(errorDetails, fmt.Sprintf("%s连接失败: %v", authLabel, err))
continue
}
attemptConfigs := buildMongoAuthAttempts(sslConfig)
for index, attemptConfig := range attemptConfigs {
authLabel := "主库凭据"
if index > 0 {
authLabel = "从库凭据"
}
m.client = client
if err := m.Ping(); err != nil {
ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second)
_ = client.Disconnect(ctx)
cancel()
m.client = nil
errorDetails = append(errorDetails, fmt.Sprintf("%s验证失败: %v", authLabel, err))
continue
if sslIndex > 0 {
attemptConfig.URI = ""
}
uri := m.getURI(attemptConfig)
clientOpts := options.Client().ApplyURI(uri)
tlsEnabled, tlsInsecure := resolveMongoTLSSettings(attemptConfig)
if tlsEnabled {
clientOpts.SetTLSConfig(&tls.Config{
MinVersion: tls.VersionTLS12,
InsecureSkipVerify: tlsInsecure,
})
}
if attemptConfig.UseProxy {
clientOpts.SetDialer(&mongoProxyDialer{proxyConfig: attemptConfig.Proxy})
}
client, err := mongo.Connect(clientOpts)
if err != nil {
errorDetails = append(errorDetails, fmt.Sprintf("%s %s连接失败: %v", sslLabel, authLabel, err))
continue
}
m.client = client
if err := m.Ping(); err != nil {
ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second)
_ = client.Disconnect(ctx)
cancel()
m.client = nil
errorDetails = append(errorDetails, fmt.Sprintf("%s %s验证失败: %v", sslLabel, authLabel, err))
continue
}
if sslIndex > 0 {
logger.Warnf("MongoDB SSL 优先连接失败,已回退至明文连接")
}
return nil
}
return nil
}
if len(errorDetails) > 0 {

File diff suppressed because it is too large Load Diff

View File

@@ -35,6 +35,32 @@ func ResolveOptionalDriverAgentExecutablePath(downloadDir string, driverType str
return filepath.Join(root, normalized, optionalDriverAgentExecutableName(normalized)), nil
}
func ResolveOptionalDriverAgentExecutablePathForVersion(downloadDir string, driverType string, version string) (string, error) {
normalized := normalizeRuntimeDriverType(driverType)
if strings.TrimSpace(normalized) == "" {
return "", fmt.Errorf("驱动类型为空")
}
root, err := resolveExternalDriverRoot(downloadDir)
if err != nil {
return "", err
}
if normalized != "mongodb" {
return filepath.Join(root, normalized, optionalDriverAgentExecutableName(normalized)), nil
}
baseName := optionalDriverAgentExecutableName(normalized)
ext := filepath.Ext(baseName)
stem := strings.TrimSuffix(baseName, ext)
major := 2
trimmed := strings.TrimSpace(version)
trimmed = strings.TrimPrefix(trimmed, "v")
if strings.HasPrefix(trimmed, "1.") || trimmed == "1" {
major = 1
}
versionedName := fmt.Sprintf("%s-v%d%s", stem, major, ext)
return filepath.Join(root, normalized, versionedName), nil
}
func ResolveMySQLAgentExecutablePath(downloadDir string) (string, error) {
return ResolveOptionalDriverAgentExecutablePath(downloadDir, "mysql")
}

View File

@@ -184,9 +184,10 @@ func (m *MySQLDB) getDSN(config connection.ConnectionConfig) string {
}
timeout := getConnectTimeoutSeconds(config)
tlsMode := resolveMySQLTLSMode(config)
return fmt.Sprintf("%s:%s@%s(%s)/%s?charset=utf8mb4&parseTime=True&loc=Local&timeout=%ds",
config.User, config.Password, protocol, address, database, timeout)
return fmt.Sprintf("%s:%s@%s(%s)/%s?charset=utf8mb4&parseTime=True&loc=Local&timeout=%ds&tls=%s",
config.User, config.Password, protocol, address, database, timeout, url.QueryEscape(tlsMode))
}
func resolveMySQLCredential(config connection.ConnectionConfig, addressIndex int) (string, string) {

View File

@@ -35,12 +35,23 @@ func (o *OracleDB) getDSN(config connection.ConnectionConfig) string {
}
u.User = url.UserPassword(config.User, config.Password)
u.RawPath = "/" + url.PathEscape(database)
q := url.Values{}
switch normalizedSSLMode(config) {
case sslModeRequired:
q.Set("SSL", "TRUE")
q.Set("SSL VERIFY", "TRUE")
case sslModeSkipVerify, sslModePreferred:
q.Set("SSL", "TRUE")
q.Set("SSL VERIFY", "FALSE")
}
if encoded := q.Encode(); encoded != "" {
u.RawQuery = encoded
}
return u.String()
}
func (o *OracleDB) Connect(config connection.ConnectionConfig) error {
var dsn string
var err error
runConfig := config
serviceName := strings.TrimSpace(config.Database)
if serviceName == "" {
return fmt.Errorf("Oracle 连接缺少服务名Service Name请在连接配置中填写例如 ORCLPDB1")
@@ -73,22 +84,37 @@ func (o *OracleDB) Connect(config connection.ConnectionConfig) error {
localConfig.Port = port
localConfig.UseSSH = false
dsn = o.getDSN(localConfig)
runConfig = localConfig
logger.Infof("Oracle 通过本地端口转发连接:%s -> %s:%d", forwarder.LocalAddr, config.Host, config.Port)
} else {
dsn = o.getDSN(config)
}
db, err := sql.Open("oracle", dsn)
if err != nil {
return fmt.Errorf("打开数据库连接失败:%w", err)
attempts := []connection.ConnectionConfig{runConfig}
if shouldTrySSLPreferredFallback(runConfig) {
attempts = append(attempts, withSSLDisabled(runConfig))
}
o.conn = db
o.pingTimeout = getConnectTimeout(config)
if err := o.Ping(); err != nil {
return fmt.Errorf("连接建立后验证失败:%w", err)
var failures []string
for idx, attempt := range attempts {
dsn := o.getDSN(attempt)
db, err := sql.Open("oracle", dsn)
if err != nil {
failures = append(failures, fmt.Sprintf("第%d次连接打开失败: %v", idx+1, err))
continue
}
o.conn = db
o.pingTimeout = getConnectTimeout(attempt)
if err := o.Ping(); err != nil {
_ = db.Close()
o.conn = nil
failures = append(failures, fmt.Sprintf("第%d次连接验证失败: %v", idx+1, err))
continue
}
if idx > 0 {
logger.Warnf("Oracle SSL 优先连接失败,已回退至明文连接")
}
return nil
}
return nil
return fmt.Errorf("连接建立后验证失败:%s", strings.Join(failures, ""))
}
func (o *OracleDB) Close() error {

View File

@@ -62,7 +62,7 @@ func (p *PostgresDB) getDSN(config connection.ConnectionConfig) string {
}
u.User = url.UserPassword(config.User, config.Password)
q := url.Values{}
q.Set("sslmode", "disable")
q.Set("sslmode", resolvePostgresSSLMode(config))
q.Set("connect_timeout", strconv.Itoa(getConnectTimeoutSeconds(config)))
u.RawQuery = q.Encode()
@@ -126,34 +126,49 @@ func (p *PostgresDB) Connect(config connection.ConnectionConfig) error {
logger.Infof("PostgreSQL 通过本地端口转发连接:%s -> %s:%d", forwarder.LocalAddr, config.Host, config.Port)
}
attemptDBs := resolvePostgresConnectDatabases(runConfig)
sslAttempts := []connection.ConnectionConfig{runConfig}
if shouldTrySSLPreferredFallback(runConfig) {
sslAttempts = append(sslAttempts, withSSLDisabled(runConfig))
}
var failures []string
for _, dbName := range attemptDBs {
attemptConfig := runConfig
attemptConfig.Database = dbName
dsn := p.getDSN(attemptConfig)
dbConn, err := sql.Open("postgres", dsn)
if err != nil {
failures = append(failures, fmt.Sprintf("数据库=%s 打开连接失败: %v", dbName, err))
continue
}
p.conn = dbConn
// Force verification
if err := p.Ping(); err != nil {
failures = append(failures, fmt.Sprintf("数据库=%s 验证失败: %v", dbName, err))
_ = dbConn.Close()
p.conn = nil
continue
for sslIndex, sslConfig := range sslAttempts {
sslLabel := "SSL"
if sslIndex > 0 {
sslLabel = "明文回退"
}
if strings.TrimSpace(config.Database) == "" && !strings.EqualFold(dbName, "postgres") {
logger.Infof("PostgreSQL 自动选择连接数据库:%s", dbName)
}
attemptDBs := resolvePostgresConnectDatabases(sslConfig)
for _, dbName := range attemptDBs {
attemptConfig := sslConfig
attemptConfig.Database = dbName
dsn := p.getDSN(attemptConfig)
cleanupOnFailure = false
return nil
dbConn, err := sql.Open("postgres", dsn)
if err != nil {
failures = append(failures, fmt.Sprintf("%s 数据库=%s 打开连接失败: %v", sslLabel, dbName, err))
continue
}
p.conn = dbConn
// Force verification
if err := p.Ping(); err != nil {
failures = append(failures, fmt.Sprintf("%s 数据库=%s 验证失败: %v", sslLabel, dbName, err))
_ = dbConn.Close()
p.conn = nil
continue
}
if sslIndex > 0 {
logger.Warnf("PostgreSQL SSL 优先连接失败,已回退至明文连接")
}
if strings.TrimSpace(config.Database) == "" && !strings.EqualFold(dbName, "postgres") {
logger.Infof("PostgreSQL 自动选择连接数据库:%s", dbName)
}
cleanupOnFailure = false
return nil
}
}
if len(failures) == 0 {

View File

@@ -47,8 +47,9 @@ func (s *SqlServerDB) getDSN(config connection.ConnectionConfig) string {
q := url.Values{}
q.Set("database", dbname)
q.Set("connection timeout", strconv.Itoa(getConnectTimeoutSeconds(config)))
q.Set("encrypt", "disable")
q.Set("TrustServerCertificate", "true")
encrypt, trustServerCertificate := resolveSQLServerTLSSettings(config)
q.Set("encrypt", encrypt)
q.Set("TrustServerCertificate", trustServerCertificate)
u.RawQuery = q.Encode()
return u.String()

122
internal/db/ssl_mode.go Normal file
View File

@@ -0,0 +1,122 @@
package db
import (
"crypto/tls"
"strings"
"GoNavi-Wails/internal/connection"
)
const (
sslModeDisable = "disable"
sslModePreferred = "preferred"
sslModeRequired = "required"
sslModeSkipVerify = "skip-verify"
)
func normalizeSSLModeValue(raw string) string {
mode := strings.ToLower(strings.TrimSpace(raw))
switch mode {
case "", sslModePreferred, "prefer":
return sslModePreferred
case sslModeRequired, "require", "on", "true", "mandatory", "strict":
return sslModeRequired
case sslModeSkipVerify, "insecure", "skipverify", "skip_verify", "insecure-skip-verify":
return sslModeSkipVerify
case sslModeDisable, "disabled", "off", "false", "none":
return sslModeDisable
default:
return sslModePreferred
}
}
func normalizedSSLMode(config connection.ConnectionConfig) string {
if !config.UseSSL {
return sslModeDisable
}
return normalizeSSLModeValue(config.SSLMode)
}
func shouldTrySSLPreferredFallback(config connection.ConnectionConfig) bool {
return config.UseSSL && normalizeSSLModeValue(config.SSLMode) == sslModePreferred
}
func withSSLDisabled(config connection.ConnectionConfig) connection.ConnectionConfig {
next := config
next.UseSSL = false
next.SSLMode = sslModeDisable
return next
}
func resolveMySQLTLSMode(config connection.ConnectionConfig) string {
switch normalizedSSLMode(config) {
case sslModeDisable:
return "false"
case sslModeRequired:
return "true"
case sslModeSkipVerify:
return "skip-verify"
default:
return "preferred"
}
}
func resolvePostgresSSLMode(config connection.ConnectionConfig) string {
switch normalizedSSLMode(config) {
case sslModeDisable:
return "disable"
case sslModeRequired:
return "require"
case sslModeSkipVerify:
return "require"
default:
return "require"
}
}
func resolveSQLServerTLSSettings(config connection.ConnectionConfig) (encrypt string, trustServerCertificate string) {
switch normalizedSSLMode(config) {
case sslModeDisable:
return "disable", "true"
case sslModeRequired:
return "true", "false"
case sslModeSkipVerify:
return "true", "true"
default:
return "false", "true"
}
}
func resolveGenericTLSConfig(config connection.ConnectionConfig) *tls.Config {
switch normalizedSSLMode(config) {
case sslModeDisable:
return nil
case sslModeRequired:
return &tls.Config{MinVersion: tls.VersionTLS12}
case sslModeSkipVerify:
return &tls.Config{MinVersion: tls.VersionTLS12, InsecureSkipVerify: true}
default:
// Preferred: 先尝试 TLS为提升兼容性默认跳过证书校验失败时由调用方按需回退明文。
return &tls.Config{MinVersion: tls.VersionTLS12, InsecureSkipVerify: true}
}
}
func resolveMongoTLSSettings(config connection.ConnectionConfig) (enabled bool, insecure bool) {
switch normalizedSSLMode(config) {
case sslModeDisable:
return false, false
case sslModeRequired:
return true, false
case sslModeSkipVerify:
return true, true
default:
return true, true
}
}
func resolveTDengineNet(config connection.ConnectionConfig) string {
if normalizedSSLMode(config) == sslModeDisable {
return "ws"
}
return "wss"
}

View File

@@ -40,11 +40,12 @@ func (t *TDengineDB) getDSN(config connection.ConnectionConfig) string {
path = "/" + dbName
}
return fmt.Sprintf("%s:%s@ws(%s)%s", user, pass, net.JoinHostPort(config.Host, strconv.Itoa(config.Port)), path)
netType := resolveTDengineNet(config)
return fmt.Sprintf("%s:%s@%s(%s)%s", user, pass, netType, net.JoinHostPort(config.Host, strconv.Itoa(config.Port)), path)
}
func (t *TDengineDB) Connect(config connection.ConnectionConfig) error {
var dsn string
runConfig := config
if config.UseSSH {
logger.Infof("TDengine 使用 SSH 连接:地址=%s:%d 用户=%s", config.Host, config.Port, config.User)
@@ -68,23 +69,38 @@ func (t *TDengineDB) Connect(config connection.ConnectionConfig) error {
localConfig.Host = host
localConfig.Port = port
localConfig.UseSSH = false
dsn = t.getDSN(localConfig)
runConfig = localConfig
logger.Infof("TDengine 通过本地端口转发连接:%s -> %s:%d", forwarder.LocalAddr, config.Host, config.Port)
} else {
dsn = t.getDSN(config)
}
db, err := sql.Open("taosWS", dsn)
if err != nil {
return fmt.Errorf("打开数据库连接失败:%w", err)
attempts := []connection.ConnectionConfig{runConfig}
if shouldTrySSLPreferredFallback(runConfig) {
attempts = append(attempts, withSSLDisabled(runConfig))
}
t.conn = db
t.pingTimeout = getConnectTimeout(config)
if err := t.Ping(); err != nil {
return fmt.Errorf("连接建立后验证失败:%w", err)
var failures []string
for idx, attempt := range attempts {
dsn := t.getDSN(attempt)
db, err := sql.Open("taosWS", dsn)
if err != nil {
failures = append(failures, fmt.Sprintf("第%d次连接打开失败: %v", idx+1, err))
continue
}
t.conn = db
t.pingTimeout = getConnectTimeout(attempt)
if err := t.Ping(); err != nil {
_ = db.Close()
t.conn = nil
failures = append(failures, fmt.Sprintf("第%d次连接验证失败: %v", idx+1, err))
continue
}
if idx > 0 {
logger.Warnf("TDengine SSL 优先连接失败,已回退至明文连接")
}
return nil
}
return nil
return fmt.Errorf("连接建立后验证失败:%s", strings.Join(failures, ""))
}
func (t *TDengineDB) Close() error {

View File

@@ -41,7 +41,7 @@ func (v *VastbaseDB) getDSN(config connection.ConnectionConfig) string {
}
u.User = url.UserPassword(config.User, config.Password)
q := url.Values{}
q.Set("sslmode", "disable")
q.Set("sslmode", resolvePostgresSSLMode(config))
q.Set("connect_timeout", strconv.Itoa(getConnectTimeoutSeconds(config)))
u.RawQuery = q.Encode()
@@ -49,7 +49,7 @@ func (v *VastbaseDB) getDSN(config connection.ConnectionConfig) string {
}
func (v *VastbaseDB) Connect(config connection.ConnectionConfig) error {
var dsn string
runConfig := config
if config.UseSSH {
logger.Infof("Vastbase 使用 SSH 连接:地址=%s:%d 用户=%s", config.Host, config.Port, config.User)
@@ -75,23 +75,37 @@ func (v *VastbaseDB) Connect(config connection.ConnectionConfig) error {
localConfig.Port = port
localConfig.UseSSH = false
dsn = v.getDSN(localConfig)
runConfig = localConfig
logger.Infof("Vastbase 通过本地端口转发连接:%s -> %s:%d", forwarder.LocalAddr, config.Host, config.Port)
} else {
dsn = v.getDSN(config)
}
db, err := sql.Open("postgres", dsn)
if err != nil {
return fmt.Errorf("打开数据库连接失败:%w", err)
attempts := []connection.ConnectionConfig{runConfig}
if shouldTrySSLPreferredFallback(runConfig) {
attempts = append(attempts, withSSLDisabled(runConfig))
}
v.conn = db
v.pingTimeout = getConnectTimeout(config)
if err := v.Ping(); err != nil {
return fmt.Errorf("连接建立后验证失败:%w", err)
var failures []string
for idx, attempt := range attempts {
dsn := v.getDSN(attempt)
db, err := sql.Open("postgres", dsn)
if err != nil {
failures = append(failures, fmt.Sprintf("第%d次连接打开失败: %v", idx+1, err))
continue
}
v.conn = db
v.pingTimeout = getConnectTimeout(attempt)
if err := v.Ping(); err != nil {
_ = db.Close()
v.conn = nil
failures = append(failures, fmt.Sprintf("第%d次连接验证失败: %v", idx+1, err))
continue
}
if idx > 0 {
logger.Warnf("Vastbase SSL 优先连接失败,已回退至明文连接")
}
return nil
}
return nil
return fmt.Errorf("连接建立后验证失败:%s", strings.Join(failures, ""))
}
func (v *VastbaseDB) Close() error {

View File

@@ -2,6 +2,7 @@ package redis
import (
"context"
"crypto/tls"
"fmt"
"net"
"strconv"
@@ -201,25 +202,48 @@ func (r *RedisClientImpl) Connect(config connection.ConnectionConfig) error {
timeout := normalizeRedisTimeout(config.Timeout)
if r.isCluster {
opts := &redis.ClusterOptions{
Addrs: seedAddrs,
Username: strings.TrimSpace(config.User),
Password: config.Password,
DialTimeout: timeout,
ReadTimeout: timeout,
WriteTimeout: timeout,
attempts := []connection.ConnectionConfig{config}
if shouldTryRedisSSLPreferredFallback(config) {
attempts = append(attempts, withRedisSSLDisabled(config))
}
clusterClient := redis.NewClusterClient(opts)
ctx, cancel := context.WithTimeout(context.Background(), timeout)
defer cancel()
if err := clusterClient.Ping(ctx).Err(); err != nil {
clusterClient.Close()
return fmt.Errorf("Redis 集群连接失败: %w", err)
var failures []string
for idx, attempt := range attempts {
var tlsConfig *tls.Config
if cfg := resolveRedisTLSConfig(attempt); cfg != nil {
if host, _, err := net.SplitHostPort(seedAddrs[0]); err == nil && host != "" {
cfg.ServerName = host
}
tlsConfig = cfg
}
opts := &redis.ClusterOptions{
Addrs: seedAddrs,
Username: strings.TrimSpace(attempt.User),
Password: attempt.Password,
DialTimeout: timeout,
ReadTimeout: timeout,
WriteTimeout: timeout,
TLSConfig: tlsConfig,
}
clusterClient := redis.NewClusterClient(opts)
ctx, cancel := context.WithTimeout(context.Background(), timeout)
pingErr := clusterClient.Ping(ctx).Err()
cancel()
if pingErr != nil {
clusterClient.Close()
failures = append(failures, fmt.Sprintf("第%d次连接失败: %v", idx+1, pingErr))
continue
}
r.client = clusterClient
r.clusterClient = clusterClient
r.config = attempt
if idx > 0 {
logger.Warnf("Redis 集群 SSL 优先连接失败,已回退至明文连接")
}
logger.Infof("Redis 集群连接成功: seeds=%s 逻辑库=db%d", strings.Join(seedAddrs, ","), r.currentDB)
return nil
}
r.client = clusterClient
r.clusterClient = clusterClient
logger.Infof("Redis 集群连接成功: seeds=%s 逻辑库=db%d", strings.Join(seedAddrs, ","), r.currentDB)
return nil
return fmt.Errorf("Redis 集群连接失败: %s", strings.Join(failures, ""))
}
addr := seedAddrs[0]
@@ -233,29 +257,53 @@ func (r *RedisClientImpl) Connect(config connection.ConnectionConfig) error {
logger.Infof("Redis 通过 SSH 隧道连接: %s -> %s:%d", addr, config.Host, config.Port)
}
opts := &redis.Options{
Addr: addr,
Username: strings.TrimSpace(config.User),
Password: config.Password,
DB: r.currentDB,
DialTimeout: timeout,
ReadTimeout: timeout,
WriteTimeout: timeout,
attempts := []connection.ConnectionConfig{config}
if shouldTryRedisSSLPreferredFallback(config) {
attempts = append(attempts, withRedisSSLDisabled(config))
}
singleClient := redis.NewClient(opts)
ctx, cancel := context.WithTimeout(context.Background(), timeout)
defer cancel()
var failures []string
for idx, attempt := range attempts {
var tlsConfig *tls.Config
if cfg := resolveRedisTLSConfig(attempt); cfg != nil {
if host, _, err := net.SplitHostPort(addr); err == nil && host != "" {
cfg.ServerName = host
}
tlsConfig = cfg
}
if err := singleClient.Ping(ctx).Err(); err != nil {
singleClient.Close()
return fmt.Errorf("Redis 连接失败: %w", err)
opts := &redis.Options{
Addr: addr,
Username: strings.TrimSpace(attempt.User),
Password: attempt.Password,
DB: r.currentDB,
DialTimeout: timeout,
ReadTimeout: timeout,
WriteTimeout: timeout,
TLSConfig: tlsConfig,
}
singleClient := redis.NewClient(opts)
ctx, cancel := context.WithTimeout(context.Background(), timeout)
pingErr := singleClient.Ping(ctx).Err()
cancel()
if pingErr != nil {
singleClient.Close()
failures = append(failures, fmt.Sprintf("第%d次连接失败: %v", idx+1, pingErr))
continue
}
r.client = singleClient
r.singleClient = singleClient
r.config = attempt
if idx > 0 {
logger.Warnf("Redis SSL 优先连接失败,已回退至明文连接")
}
logger.Infof("Redis 连接成功: %s DB=%d", addr, r.currentDB)
return nil
}
r.client = singleClient
r.singleClient = singleClient
logger.Infof("Redis 连接成功: %s DB=%d", addr, r.currentDB)
return nil
return fmt.Errorf("Redis 连接失败: %s", strings.Join(failures, ""))
}
// Close closes the Redis connection

View File

@@ -0,0 +1,55 @@
package redis
import (
"crypto/tls"
"strings"
"GoNavi-Wails/internal/connection"
)
func normalizeRedisSSLMode(raw string) string {
mode := strings.ToLower(strings.TrimSpace(raw))
switch mode {
case "", "preferred", "prefer":
return "preferred"
case "required", "require", "on", "true", "mandatory", "strict":
return "required"
case "skip-verify", "insecure", "skipverify", "skip_verify", "insecure-skip-verify":
return "skip-verify"
case "disable", "disabled", "off", "false", "none":
return "disable"
default:
return "preferred"
}
}
func redisSSLMode(config connection.ConnectionConfig) string {
if !config.UseSSL {
return "disable"
}
return normalizeRedisSSLMode(config.SSLMode)
}
func shouldTryRedisSSLPreferredFallback(config connection.ConnectionConfig) bool {
return config.UseSSL && normalizeRedisSSLMode(config.SSLMode) == "preferred"
}
func withRedisSSLDisabled(config connection.ConnectionConfig) connection.ConnectionConfig {
next := config
next.UseSSL = false
next.SSLMode = "disable"
return next
}
func resolveRedisTLSConfig(config connection.ConnectionConfig) *tls.Config {
switch redisSSLMode(config) {
case "disable":
return nil
case "required":
return &tls.Config{MinVersion: tls.VersionTLS12}
case "skip-verify":
return &tls.Config{MinVersion: tls.VersionTLS12, InsecureSkipVerify: true}
default:
return &tls.Config{MinVersion: tls.VersionTLS12, InsecureSkipVerify: true}
}
}